├── .gitignore ├── LICENSE ├── README.md ├── analysis ├── change_parameters.py ├── entropy.py └── modeling_olmo_hf.py ├── assets ├── main_figure.png └── main_figure_bg.png ├── configs ├── 1B │ └── 1B_bs128_lr4e4_pubmed_1ep_738k.yaml ├── official │ └── OLMo-1B.yaml └── resuscitation │ └── 1B_bs128_lr4e4_pubmed_1ep_738k_resuscitation.yaml ├── fictional_knowledge ├── fictional_knowledge_paraphrased_once60.json ├── fictional_knowledge_paraphrased_paraphrase70.json └── paraphrase70-once60split.pkl ├── olmo ├── __init__.py ├── aliases.py ├── beam_search.py ├── checkpoint.py ├── config.py ├── data │ ├── __init__.py │ ├── collator.py │ ├── iterable_dataset.py │ └── memmap_dataset.py ├── eval │ ├── __init__.py │ ├── downstream.py │ └── evaluator.py ├── exceptions.py ├── initialization.py ├── model.py ├── optim.py ├── py.typed ├── safetensors_util.py ├── tokenizer.py ├── tokenizers │ ├── allenai_eleuther-ai-gpt-neox-20b-pii-special.json │ └── allenai_gpt-neox-olmo-dolma-v1_5.json ├── torch_util.py ├── train.py ├── util.py └── version.py ├── pyproject.toml └── scripts ├── get_dataorder.sh ├── get_model.sh ├── train.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # python 3 | 4 | *.pyc 5 | *.pyo 6 | __pycache__ 7 | 8 | #data 9 | data/ 10 | !olmo/data/ 11 | 12 | #checkpoints 13 | checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Knowledge Entropy Decay during Language Model Pretraining Hinders New Knowledge Acquisition 2 | 3 | This repository contains the implementation of [**Knowledge Entropy Decay during Language Model Pretraining Hinders New Knowledge Acquisition**](https://arxiv.org/abs/2410.01380). 
4 | 5 | 6 | ![main_figure_bg.png](assets/main_figure_bg.png) 7 | 8 | 9 | The code is based on the [OLMo](https://github.com/allenai/OLMo) project, with modifications to support knowledge injection during training and additional evaluations. 10 | The knowledge injection code is based on the [factual-knowledge-acquisition](https://github.com/kaistAI/factual-knowledge-acquisition.git) project, with minor modifications for the 'Paraphrase or Once' data split and evaluations. 11 | 12 | 13 | ## Key Differences from Original OLMo Repository 14 | 15 | 1. Modified `olmo/train.py` to: 16 | - Apply knowledge injection during training 17 | 2. Modified `olmo/checkpoint.py` to load model checkpoints produced by the resuscitation method 18 | 3. Added calculation of Knowledge Entropy for pretrained models in the `analysis/` folder 19 | 4. Added modification of model checkpoints for the resuscitation method in the `analysis/` folder 20 | 21 | ## Key Differences from Original factual-knowledge-acquisition Repository 22 | 1. Augmented the number of Paraphrase examples using GPT-4, following the method introduced in the original paper [How Do Large Language Models Acquire Factual Knowledge During Pretraining?](https://arxiv.org/abs/2406.11813); the augmented data can be found in the `fictional_knowledge/` folder 23 | - The original Fictional Knowledge dataset can be found at: https://huggingface.co/datasets/kaist-ai/fictional-knowledge 24 | 2. Modified the evaluation code 25 | 26 | 27 | 28 | ## Installation 29 | 30 | 1. Create a conda environment with python>=3.9. 31 | ``` 32 | conda create -n knowledge-entropy python=3.11 -y 33 | conda activate knowledge-entropy 34 | ``` 35 | 36 | 2. Install packages. 37 | ``` 38 | pip install -e . 39 | ``` 40 | 41 | 3. Download the intermediate OLMo checkpoint. 42 | An example of downloading from the official repository is presented in `scripts/get_model.sh`. 43 | First, find the link to the proper checkpoint in the [official list](https://github.com/allenai/OLMo/blob/main/checkpoints/official/OLMo-1B.csv). 44 | Below is an example command for downloading the model whose pretraining step is 738020. 45 | 46 | ``` 47 | bash scripts/get_model.sh https://olmo-checkpoints.org/ai2-llm/olmo-small/oupb6jak/step738020-unsharded/ 48 | ``` 49 | 50 | 51 | ## Knowledge Entropy 52 | 1. Download the official training order of Dolma from the [official link](https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy). 53 | 54 | ``` 55 | bash scripts/get_dataorder.sh 56 | ``` 57 | 58 | 2. Run the knowledge entropy calculation command, where `step`, `data_size`, and `batch_size` should be chosen appropriately. 59 | ``` 60 | python -m analysis.entropy --step 738020 --data_size 2048 --batch_size 4 61 | ``` 62 | 63 | ## Training 64 | Train the OLMo model with a modified config. 65 | You can find an exemplary config at `configs/1B/1B_bs128_lr4e4_pubmed_1ep_738k.yaml`. 66 | 67 | 68 | ``` 69 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29599 -m scripts.train configs/1B/1B_bs128_lr4e4_pubmed_1ep_738k.yaml 70 | ``` 71 | 72 | ## Resuscitation 73 | 1. Save the modified model checkpoint. 74 | This step modifies the pretrained checkpoint for the resuscitation method and saves the result. 75 | Below is an example command using a resuscitation ratio of 50% and an amplifying factor of 2. 76 | 77 | ``` 78 | python -m analysis.change_parameters --step 738020 --resuscitation_ratio 0.5 --amplifying_factor 2 79 | ``` 80 | 81 | 2. Run the training command with the modified config. 82 | The name of the newly saved checkpoint should be specified at `model.resuscitation`, as sketched below.
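A minimal sketch of how this could look in the config YAML (the exact key layout here is an assumption; the exemplary resuscitation config referenced next shows the actual schema):

```
# Hypothetical excerpt: `resuscitation` names the checkpoint saved by
# analysis.change_parameters (assumed here to be the file name without the .pt suffix).
model:
  resuscitation: resuscitation_ratio0.5_amplifying2.0
```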
83 | You can find an exemplary config at `configs/resuscitation/1B_bs128_lr4e4_pubmed_1ep_738k_resuscitation.yaml`. 84 | 85 | ``` 86 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29599 -m scripts.train configs/resuscitation/1B_bs128_lr4e4_pubmed_1ep_738k_resuscitation.yaml 87 | ``` 88 | 89 | 90 | ## Citation 91 | If you find our work helpful, please consider citing our paper, the original OLMo paper, and the factual-knowledge-acquisition paper: 92 | ```bibtex 93 | @inproceedings{Kim2024KnowledgeED, 94 | title={Knowledge Entropy Decay during Language Model Pretraining Hinders New Knowledge Acquisition}, 95 | author={Jiyeon Kim and Hyunji Lee and Hyowon Cho and Joel Jang and Hyeonbin Hwang and Seungpil Won and Youbin Ahn and Dohaeng Lee and Minjoon Seo}, 96 | year={2024}, 97 | url={https://api.semanticscholar.org/CorpusID:273025776} 98 | } 99 | @inproceedings{groeneveld-etal-2024-olmo, 100 | title = "{OLM}o: Accelerating the Science of Language Models", 101 | author = "Groeneveld, Dirk and 102 | Beltagy, Iz and 103 | Walsh, Evan and 104 | Bhagia, Akshita and 105 | Kinney, Rodney and 106 | Tafjord, Oyvind and 107 | Jha, Ananya and 108 | Ivison, Hamish and 109 | Magnusson, Ian and 110 | Wang, Yizhong and 111 | Arora, Shane and 112 | Atkinson, David and 113 | Authur, Russell and 114 | Chandu, Khyathi and 115 | Cohan, Arman and 116 | Dumas, Jennifer and 117 | Elazar, Yanai and 118 | Gu, Yuling and 119 | Hessel, Jack and 120 | Khot, Tushar and 121 | Merrill, William and 122 | Morrison, Jacob and 123 | Muennighoff, Niklas and 124 | Naik, Aakanksha and 125 | Nam, Crystal and 126 | Peters, Matthew and 127 | Pyatkin, Valentina and 128 | Ravichander, Abhilasha and 129 | Schwenk, Dustin and 130 | Shah, Saurabh and 131 | Smith, William and 132 | Strubell, Emma and 133 | Subramani, Nishant and 134 | Wortsman, Mitchell and 135 | Dasigi, Pradeep and 136 | Lambert, Nathan and 137 | Richardson, Kyle and 138 | Zettlemoyer, Luke and 139 | Dodge, Jesse and 140 | Lo, Kyle and 141 | Soldaini, Luca and 142 | Smith, Noah and 143 | Hajishirzi, Hannaneh", 144 | editor = "Ku, Lun-Wei and 145 | Martins, Andre and 146 | Srikumar, Vivek", 147 | booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 148 | month = aug, 149 | year = "2024", 150 | address = "Bangkok, Thailand", 151 | publisher = "Association for Computational Linguistics", 152 | url = "https://aclanthology.org/2024.acl-long.841", 153 | doi = "10.18653/v1/2024.acl-long.841", 154 | pages = "15789--15809", 155 | abstract = "Language models (LMs) have become ubiquitous in both NLP research and in commercial product offerings. As their commercial importance has surged, the most powerful models have become closed off, gated behind proprietary interfaces, with important details of their training data, architectures, and development undisclosed. Given the importance of these details in scientifically studying these models, including their biases and potential risks, we believe it is essential for the research community to have access to powerful, truly open LMs. To this end, we have built OLMo, a competitive, truly Open Language Model, to enable the scientific study of language models. Unlike most prior efforts that have only released model weights and inference code, we release OLMo alongside open training data and training and evaluation code.
We hope this release will empower the open research community and inspire a new wave of innovation.", 156 | } 157 | @article{chang2024large, 158 | title={How Do Large Language Models Acquire Factual Knowledge During Pretraining?}, 159 | author={Chang, Hoyeon and Park, Jinho and Ye, Seonghyeon and Yang, Sohee and Seo, Youngkyung and Chang, Du-Seong and Seo, Minjoon}, 160 | journal={arXiv preprint arXiv:2406.11813}, 161 | year={2024} 162 | } 163 | 164 | ``` 165 | 166 | 167 | ## License 168 | 169 | Apache 2.0 -------------------------------------------------------------------------------- /analysis/change_parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | 5 | def main(step, resuscitation_ratio, amplifying_factor): 6 | average_coef_path = f"checkpoints/pretrained_1B/{step}-unsharded/mlp_average_coefficients.pt" 7 | olmo_model_path = f"checkpoints/pretrained_1B/{step}-unsharded" 8 | if not os.path.isfile(average_coef_path): 9 | print(f"Get average coefficients first, by running `python -m analysis.entropy --step {step} --data_size 2048 --batch_size 4`") 10 | print("-"*50) 11 | return 12 | average_activations = torch.load(average_coef_path) 13 | olmo_model = torch.load(f"{olmo_model_path}/model.pt") 14 | print("Loaded average coefficients and olmo model.") 15 | print(f"Saving new model parameters with resuscitation ratio {resuscitation_ratio}, amplifying factor {amplifying_factor}") 16 | # Identify positions of values below the threshold 17 | mask_list = [] 18 | for layer_idx in range(average_activations.size(0)): 19 | activations = average_activations[layer_idx] 20 | threshold_value = torch.quantile(activations, resuscitation_ratio) 21 | mask = activations <= threshold_value 22 | mask_list.append(mask) 23 | 24 | ratio_list = [] 25 | for layer_idx in range(average_activations.size(0)): 26 | activations = average_activations[layer_idx] 27 | activations = activations.mean() / activations * amplifying_factor 28 | ratio_list.append(activations) 29 | 30 | for layer_idx, mask in enumerate(mask_list): 31 | # Find indices of the lowest p% activations 32 | indices_to_freeze = mask.nonzero(as_tuple=True)[0].to('cpu') 33 | 34 | # Get the target module's weight tensor 35 | ff_proj_weight_tensor = olmo_model[f'transformer.blocks.{layer_idx}.ff_proj.weight'] # Shape [16384, 2048] 36 | 37 | # Reshape ratio_list[layer_idx] to have the correct dimensions for broadcasting 38 | proj_random_tensor = ratio_list[layer_idx].unsqueeze(1).to('cpu') # Shape [8192, 1] 39 | with torch.no_grad(): 40 | # Scale only the specified rows in ff_proj_weight_tensor 41 | ff_proj_weight_tensor[indices_to_freeze, :] *= proj_random_tensor[indices_to_freeze] 42 | 43 | torch.save(olmo_model, f'{olmo_model_path}/resuscitation_ratio{str(resuscitation_ratio)}_amplifying{str(amplifying_factor)}.pt') 44 | print("Saving Complete") 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("--step", type=int, default=None) 49 | parser.add_argument("--resuscitation_ratio", type=float, default=None) 50 | parser.add_argument("--amplifying_factor", type=float, default=None) 51 | 52 | args = parser.parse_args() 53 | main(args.step, args.resuscitation_ratio, args.amplifying_factor) -------------------------------------------------------------------------------- /analysis/entropy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | from
collections import defaultdict 5 | from tqdm import tqdm 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, OlmoForCausalLM 7 | from .modeling_olmo_hf import ExpOlmoForCausalLM 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | 13 | import os 14 | from pathlib import Path 15 | import requests 16 | from urllib.parse import urlparse 17 | 18 | from olmo.config import TrainConfig 19 | from olmo.data import build_memmap_dataset 20 | 21 | 22 | def main(args): 23 | if args.model_size == "1B": 24 | layer_num, mlp_dimension, dim = 16, 8192, 2048 25 | else: 26 | layer_num, mlp_dimension, dim = 32, 11008, 4096 27 | 28 | dataloader, model, model_path = load_model(args) 29 | 30 | 31 | batch_num = 0 32 | non_padding_count = 0 33 | 34 | avg_act_abs = torch.zeros((layer_num, mlp_dimension), device=model.device) 35 | 36 | with torch.no_grad(): 37 | for batch in tqdm(dataloader): 38 | input_ids = batch['input_ids'].to(model.device) 39 | 40 | outputs = model(input_ids=input_ids) 41 | 42 | logits = outputs.logits 43 | bs, seq_len, _ = logits.shape 44 | batch_num += bs 45 | 46 | for layer_idx, activation in enumerate(outputs.activations): 47 | reshaped_activation_abs = torch.abs(activation).view(-1, mlp_dimension) 48 | if 'attention_mask' in batch: 49 | summed_activation = torch.sum(reshaped_activation_abs, dim=0) 50 | avg_act_abs[layer_idx] += summed_activation 51 | else: 52 | summed_activation = torch.sum(reshaped_activation_abs, dim=0) 53 | summed_activation /= activation.shape[1] 54 | avg_act_abs[layer_idx] += summed_activation 55 | 56 | # Clear memory 57 | del input_ids, outputs, logits, reshaped_activation_abs, summed_activation 58 | torch.cuda.empty_cache() 59 | 60 | length = non_padding_count if non_padding_count != 0 else batch_num 61 | avg_act_abs /= length 62 | entropy_avg_act_abs = turn_into_entropy(avg_act_abs).tolist() 63 | 64 | torch.save(avg_act_abs, f"{model_path}/mlp_average_coefficients.pt") 65 | to_save = { 66 | "entropy_avg_act_abs": entropy_avg_act_abs, 67 | } 68 | 69 | print("-"*50) 70 | print("Knowledge Entropy : ", sum(entropy_avg_act_abs)) 71 | print("Knowledge Entropy by layer : ", entropy_avg_act_abs) 72 | print("-"*50) 73 | write_json_file(f"{model_path}/knowledge_entropy.json", to_save) 74 | 75 | 76 | def custom_collate_fn(batch): 77 | # Convert the list of 'input_ids' from each sample to a tensor 78 | input_ids = [torch.tensor(item['input_ids']) for item in batch] 79 | # Stack the tensors into a batch 80 | input_ids = torch.stack(input_ids) 81 | return {'input_ids': input_ids} 82 | 83 | def turn_into_entropy_softmax(logit): 84 | prob = torch.softmax(logit, dim=-1) 85 | log_probabilities = torch.log(prob + 1e-9) 86 | entropy_act_sparsity = -torch.sum(prob * log_probabilities, dim=-1) 87 | return entropy_act_sparsity 88 | 89 | def turn_into_entropy(logit): 90 | prob = logit / torch.sum(logit, dim=-1, keepdim=True) 91 | log_probabilities = torch.log(prob + 1e-9) 92 | entropy_act_sparsity = -torch.sum(prob * log_probabilities, dim=-1) 93 | return entropy_act_sparsity 94 | 95 | 96 | def load_model(args): 97 | 98 | step = args.step 99 | model_size = args.model_size 100 | model_path = f"checkpoints/pretrained{'_1B' if model_size == '1B' else ''}/{step}-unsharded" 101 | base_model = "allenai/OLMo-1B-hf" 102 | 103 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf") 104 | data_path = "data/dolma_1B/first_1k.json" 105 | if not os.path.isfile(data_path): 106 | print("-"*50, "\nSaving decoded 
Dolma first batch file.") 107 | save_dolma() 108 | 109 | dataset = read_json_file(data_path) 110 | dataset = [d['text'] for d in dataset][:args.data_size] 111 | print(f"\n Loaded dolma dataset \n length: {len(dataset)} \n example: {dataset[-1]}") 112 | 113 | instances = [d for d in range(len(dataset))] 114 | subset_dataset = IndexedDataset(dataset, instances, tokenizer=tokenizer) 115 | dataloader = DataLoader(subset_dataset, batch_size=args.batch_size, shuffle=False) 116 | 117 | model = ExpOlmoForCausalLM.from_pretrained(base_model, attn_implementation="eager") #, device_map="auto") 118 | model = convert_to_hf(model, load_path = model_path, model_size=model_size) 119 | 120 | model.eval() 121 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 122 | model.to(device) 123 | print("model_path is ", model_path) 124 | 125 | return dataloader, model, model_path 126 | 127 | def convert_to_hf(model, load_path=None, ckpt_name=None, model_size="1B"): 128 | from olmo.checkpoint import load_state_dict 129 | from olmo.config import TrainConfig 130 | from olmo.util import clean_opt 131 | 132 | pt_name = "model.pt" if ckpt_name is None else f"{ckpt_name}.pt" 133 | ckpt = load_state_dict( 134 | load_path, pt_name, local_cache=None, map_location="cpu" 135 | ) 136 | 137 | 138 | yaml_path = f"configs/official/OLMo-{model_size}.yaml" 139 | args_list = [] 140 | cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list]) 141 | cfg.device_eval_batch_size = 1 142 | 143 | ## CONVERT TO HF-Format 144 | olmo_config = cfg.model 145 | n_layers = olmo_config.n_layers 146 | n_heads = olmo_config.n_heads 147 | dim = olmo_config.d_model 148 | loaded = ckpt 149 | dims_per_head = dim // n_heads 150 | 151 | if olmo_config.n_kv_heads is not None: 152 | num_key_value_heads = olmo_config.n_kv_heads # for GQA / MQA 153 | elif olmo_config.multi_query_attention: # compatibility with other checkpoints 154 | num_key_value_heads = 1 155 | else: 156 | num_key_value_heads = n_heads 157 | 158 | dims_per_head = dim // n_heads 159 | state_dict = {} 160 | for layer_i in range(n_layers): 161 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" 162 | # Unsharded 163 | # TODO: Layernorm stuff 164 | # TODO: multi query attention 165 | fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads] 166 | q_proj_weight, k_proj_weight, v_proj_weight = torch.split( 167 | loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0 168 | ) 169 | up_proj_weight, gate_proj_weight = torch.chunk( 170 | loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0 171 | ) 172 | 173 | state_dict.update({ 174 | f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight, 175 | f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, 176 | f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, 177 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[ 178 | f"transformer.blocks.{layer_i}.attn_out.weight" 179 | ], 180 | f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, 181 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], 182 | f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, 183 | }) 184 | 185 | # state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 186 | 187 | state_dict.update({ 188 | "model.embed_tokens.weight": loaded["transformer.wte.weight"], 189 | "lm_head.weight": loaded["transformer.ff_out.weight"] 190 | if 
"transformer.ff_out.weight" in loaded 191 | else loaded["transformer.wte.weight"], 192 | }) 193 | ### End of Conversion. 194 | 195 | # Load Model. 196 | model.load_state_dict(state_dict) 197 | 198 | return model 199 | 200 | def write_json_file(file_path, res): 201 | with open(file_path, 'w') as f: 202 | json.dump(res, f, indent=4) 203 | print(f"Wrote json file to: {file_path}!") 204 | 205 | def read_json_file(file_path): 206 | with open(file_path, 'r') as f: 207 | res = json.load(f) 208 | return res 209 | 210 | 211 | class IndexedDataset(Dataset): 212 | def __init__(self, dataset, indices, tokenizer=None, config = None, seq_len=2048): 213 | self.data = dataset 214 | self.indices = indices 215 | self.tokenizer = tokenizer 216 | self.seq_len = seq_len 217 | if config: 218 | if config.debug_data: 219 | self.data = self.data[:16] 220 | 221 | def __len__(self): 222 | return len(self.indices) 223 | 224 | def __getitem__(self, idx): 225 | if self.tokenizer: 226 | item_data = self.data[idx] 227 | encoding = self.tokenizer(item_data, max_length=self.seq_len, padding="max_length", truncation=True, return_tensors="pt") 228 | input_ids = encoding["input_ids"].squeeze(0) 229 | attention_mask = encoding["attention_mask"].squeeze(0) 230 | return {"input_ids": input_ids, "attention_mask": attention_mask} 231 | else: 232 | actual_idx = self.indices[idx] 233 | return self.data[actual_idx] 234 | 235 | def save_dolma(): 236 | 237 | epoch = 1 238 | global_train_examples_seen_this_epoch = 0 239 | data_order_file_path="data/global_indices/1B/global_indices.npy" 240 | train_config_path = "configs/official/OLMo-1B.yaml" 241 | 242 | if os.path.isfile(data_order_file_path) and os.path.isfile(train_config_path): 243 | 244 | # Download data-order file if it doesn't exist 245 | global_indices = np.memmap(data_order_file_path, mode="r+", dtype=np.uint32) 246 | print(f"\n Loaded dataset \n epoch: {epoch} \n global_train_examples_seen_this_epoch : {global_train_examples_seen_this_epoch}") 247 | 248 | # Save first batch of Dolma 249 | instances = [] 250 | term = 1 251 | for i in range(2048): 252 | instances.append(global_indices[global_train_examples_seen_this_epoch+i*term]) 253 | 254 | cfg = TrainConfig.load(train_config_path) 255 | dataset = build_memmap_dataset(cfg, cfg.data) 256 | 257 | tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf") 258 | to_save = [] 259 | print("\nStart decoding dataset") 260 | for inst in tqdm(instances,total=2048): 261 | input_ids = dataset[inst]['input_ids'] 262 | text = tokenizer.batch_decode([input_ids]) 263 | to_save.append({ 264 | "id": int(inst), 265 | "text": text[0] 266 | }) 267 | os.makedirs("data/dolma_1B", exist_ok=True) 268 | write_json_file(f"data/dolma_1B/first_1k.json", to_save) 269 | 270 | else: 271 | print("Data order file and config file do not exist. 
Please run `bash scripts/get_dataorder.sh`") 272 | 273 | if __name__ == "__main__": 274 | parser = argparse.ArgumentParser() 275 | parser.add_argument("--step", type=int, default=None) 276 | parser.add_argument("--data_step", type=int, default=None) 277 | parser.add_argument("--batch_size", type=int, default=2) 278 | parser.add_argument("--data_size", type=int, default=4) 279 | parser.add_argument("--data_path", type=str, default=None) 280 | parser.add_argument("--data_manual_start_num", type=int, default=None) 281 | parser.add_argument("--data_manual_epoch", type=int, default=None) 282 | parser.add_argument("--finetuned_path", type=str, default=None) 283 | parser.add_argument("--model_size", type=str, default="1B") 284 | parser.add_argument("--bf16", type=bool, default=False) 285 | parser.add_argument("--temperature", type=bool, default=False) 286 | parser.add_argument("--save_dolma", type=bool, default=False) 287 | parser.add_argument("--sparsity", type=bool, default=False) 288 | parser.add_argument("--attn_initial_temp", type=float, default=None) 289 | parser.add_argument("--mlp_temp_path", type=str, default=None) 290 | parser.add_argument("--save_name", type=str, default="") 291 | 292 | args = parser.parse_args() 293 | main(args) 294 | -------------------------------------------------------------------------------- /assets/main_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistAI/Knowledge-Entropy/85315146a9431548c696ed4b6ebfd630f121aa87/assets/main_figure.png -------------------------------------------------------------------------------- /assets/main_figure_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistAI/Knowledge-Entropy/85315146a9431548c696ed4b6ebfd630f121aa87/assets/main_figure_bg.png -------------------------------------------------------------------------------- /configs/1B/1B_bs128_lr4e4_pubmed_1ep_738k.yaml: -------------------------------------------------------------------------------- 1 | run_name: OLMo-1B_bs128_lr4e4_pubmed_1ep_738k 2 | seed: 6198 3 | dry_run: false 4 | 5 | model: 6 | d_model: 2048 7 | n_heads: 16 8 | n_layers: 16 9 | mlp_ratio: 8 10 | weight_tying: true 11 | alibi: false 12 | rope: true 13 | flash_attention: false # not available on AMD 14 | attention_dropout: 0.0 15 | attention_layer_norm: false 16 | multi_query_attention: false 17 | include_bias: false 18 | block_type: sequential 19 | layer_norm_type: default 20 | layer_norm_with_affine: false 21 | bias_for_layer_norm: false 22 | attention_layer_norm_with_affine: false 23 | activation_type: swiglu 24 | residual_dropout: 0.0 25 | embedding_dropout: 0.0 26 | max_sequence_length: 1024 #####CHECK##### 27 | vocab_size: 50280 28 | embedding_size: 50304 29 | eos_token_id: 50279 30 | pad_token_id: 1 31 | init_device: meta 32 | init_fn: mitchell 33 | 34 | compile: null # causes instability on AMD GPUs 35 | 36 | optimizer: 37 | name: adamw 38 | learning_rate: 4.0e-4 #####CHECK_LR##### 39 | weight_decay: 0.1 40 | betas: 41 | - 0.9 42 | - 0.95 43 | metrics_log_interval: 10 44 | 45 | scheduler: 46 | name: cosine_with_warmup 47 | t_warmup: 80 48 | alpha_f: 0.1 49 | 50 | tokenizer: 51 | identifier: olmo/tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json 52 | truncate_direction: right 53 | 54 | save_folder: checkpoints/1B_baseline/${run_name} 55 | save_overwrite: true 56 | # Sharded checkpoints (best for restarts) 57 | save_interval: 160 
#####CHECK_BS,EP##### = eval_interval 58 | save_num_checkpoints_to_keep: 1 59 | # Unsharded checkpoints (for final storage) 60 | save_interval_unsharded: 1600 #####CHECK_BS,EP##### = eval_interval 61 | save_num_unsharded_checkpoints_to_keep: 100 62 | 63 | inject_indices_map: fictional_knowledge/paraphrase70-once60split.pkl 64 | eval_on_load: true 65 | inject_interval: 160 #####CHECK_BS##### = 1ep steps/10 66 | 67 | eval_interval: 160 #####CHECK_BS,EP##### = inject_interval 68 | max_duration: 1600 #####CHECK_BS,EP##### = 10*inject_interval 69 | stop_at: 1600 #####CHECK_BS,EP##### = 10*inject_interval 70 | global_train_batch_size: 128 #####CHECK_BS##### 71 | device_train_microbatch_size: 16 #####CHECK_BS##### when seq_len 2048, 4GPU->2, 8GPU->4 72 | 73 | precision: amp_bf16 74 | 75 | fsdp: 76 | wrapping_strategy: by_block 77 | precision: mixed 78 | 79 | max_grad_norm: 1.0 80 | max_grad_norm_ratio: null 81 | 82 | speed_monitor: 83 | window_size: 20 84 | 85 | eval_subset_num_batches: -1 86 | device_eval_batch_size: 16 #### 87 | evaluators: 88 | - label: piqa 89 | type: downstream 90 | - label: hellaswag 91 | type: downstream 92 | - label: winogrande 93 | type: downstream 94 | - label: openbook_qa 95 | type: downstream 96 | - label: sciq 97 | type: downstream 98 | - label: arc_easy 99 | type: downstream 100 | - label: copa 101 | type: downstream 102 | - label: rte 103 | type: downstream 104 | - label: commitment_bank 105 | type: downstream 106 | - label: sst2 107 | type: downstream 108 | 109 | - label: paraphrase_memorization 110 | data: 111 | num_workers: 16 112 | drop_last: false 113 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 114 | - label: paraphrase_semantic 115 | data: 116 | num_workers: 16 117 | drop_last: false 118 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 119 | - label: paraphrase_composition 120 | data: 121 | num_workers: 16 122 | drop_last: false 123 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 124 | - label: paraphrase_paragraph 125 | data: 126 | num_workers: 16 127 | drop_last: false 128 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 129 | - label: once_memorization 130 | data: 131 | num_workers: 16 132 | drop_last: false 133 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 134 | - label: once_semantic 135 | data: 136 | num_workers: 16 137 | drop_last: false 138 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 139 | - label: once_composition 140 | data: 141 | num_workers: 16 142 | drop_last: false 143 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 144 | - label: once_paragraph 145 | data: 146 | num_workers: 16 147 | drop_last: false 148 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 149 | 150 | 151 | reset_optimizer_state: true #####CHECK##### if reloading an interrupted optimizer & model, then "false" 152 | reset_trainer_state: true #####CHECK##### if reloading an interrupted optimizer & model, then "false" 153 | load_path: checkpoints/pretrained_1B/738020-unsharded 154 | 155 | data_shuffling: false 156 | 157 | data: 158 | pad_direction: right 159 | num_workers: 16 160 | drop_last: true 161 | pin_memory: true 162 | prefetch_factor: 1 163 | persistent_workers: true 164 | timeout: 0 165 | dataset_path: path_to_dataset 166 |
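# Descriptive note on the schedule above: save_interval = eval_interval = inject_interval = 160,
# and max_duration = stop_at = 10 * inject_interval = 1600, so knowledge injection and evaluation
# occur every 160 steps across the 1600-step run (plus the eval_on_load evaluation at the start).
# `dataset_path: path_to_dataset` is a placeholder; replace it with the path to the tokenized
# training corpus for this run (presumably the PubMed data implied by the run name) before launching.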
-------------------------------------------------------------------------------- /configs/official/OLMo-1B.yaml: -------------------------------------------------------------------------------- 1 | run_name: OLMo-1B 2 | seed: 6198 3 | dry_run: false 4 | 5 | wandb: 6 | name: ${run_name} 7 | project: olmo-small 8 | 9 | model: 10 | d_model: 2048 11 | n_heads: 16 12 | n_layers: 16 13 | mlp_ratio: 8 14 | weight_tying: true 15 | alibi: false 16 | rope: true 17 | flash_attention: false # not available on AMD 18 | attention_dropout: 0.0 19 | attention_layer_norm: false 20 | multi_query_attention: false 21 | include_bias: false 22 | block_type: sequential 23 | layer_norm_type: default 24 | layer_norm_with_affine: false 25 | bias_for_layer_norm: false 26 | attention_layer_norm_with_affine: false 27 | activation_type: swiglu 28 | residual_dropout: 0.0 29 | embedding_dropout: 0.0 30 | max_sequence_length: 2048 31 | vocab_size: 50280 32 | embedding_size: 50304 33 | eos_token_id: 50279 34 | pad_token_id: 1 35 | init_device: meta 36 | init_fn: mitchell 37 | 38 | compile: null # causes instability on AMD GPUs 39 | 40 | optimizer: 41 | name: adamw 42 | learning_rate: 4.0e-4 43 | weight_decay: 0.1 44 | betas: 45 | - 0.9 46 | - 0.95 47 | metrics_log_interval: 10 48 | 49 | scheduler: 50 | name: cosine_with_warmup 51 | t_warmup: 2000 52 | alpha_f: 0.1 53 | 54 | tokenizer: 55 | identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json 56 | truncate_direction: right 57 | 58 | save_folder: runs/${run_name} 59 | save_overwrite: false 60 | # Sharded checkpoints (best for restarts) 61 | save_interval: 1000 62 | save_num_checkpoints_to_keep: 9 63 | # Unsharded checkpoints (for final storage) 64 | save_interval_unsharded: 10000 65 | save_num_unsharded_checkpoints_to_keep: -1 66 | 67 | load_path: null 68 | 69 | max_duration: 739_328 # 3.1T tokens 70 | global_train_batch_size: 2048 71 | device_train_microbatch_size: 8 72 | 73 | precision: amp_bf16 74 | 75 | fsdp: 76 | wrapping_strategy: null 77 | precision: mixed 78 | 79 | max_grad_norm: 1.0 80 | max_grad_norm_ratio: null 81 | 82 | speed_monitor: 83 | window_size: 20 84 | 85 | eval_interval: ${save_interval} 86 | eval_subset_num_batches: -1 87 | device_eval_batch_size: ${device_train_microbatch_size} 88 | evaluators: 89 | # lump all the small datasets together (we still get separate metrics). 
90 | - label: v3-small-ppl-validation 91 | data: 92 | num_workers: 0 93 | drop_last: true 94 | datasets: 95 | v3-small-c4_en-validation: 96 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy 97 | v3-small-dolma_books-validation: 98 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy 99 | v3-small-dolma_common-crawl-validation: 100 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy 101 | v3-small-dolma_pes2o-validation: 102 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy 103 | v3-small-dolma_reddit-validation: 104 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy 105 | v3-small-dolma_stack-validation: 106 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy 107 | v3-small-dolma_wiki-validation: 108 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy 109 | v3-small-ice-validation: 110 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy 111 | v3-small-m2d2_s2orc-validation: 112 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy 113 | v3-small-pile-validation: 114 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy 115 | v3-small-wikitext_103-validation: 116 | - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy 117 | 118 | - label: v2-small-ppl-validation 119 | data: 120 | num_workers: 0 121 | drop_last: true 122 | datasets: 123 | v2-small-4chan-validation: 124 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy 125 | v2-small-c4_100_domains-validation: 126 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy 127 | v2-small-c4_en-validation: 128 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy 129 | v2-small-gab-validation: 130 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy 131 | v2-small-ice-validation: 132 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy 133 | v2-small-m2d2_s2orc-validation: 134 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy 135 | v2-small-m2d2_wiki-validation: 136 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy 137 | v2-small-manosphere-validation: 138 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy 139 | v2-small-mc4_en-validation: 140 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy 141 | v2-small-pile-validation: 142 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy 143 | v2-small-ptb-validation: 144 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy 145 | v2-small-twitterAEE-validation: 146 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy 147 | v2-small-wikitext_103-validation: 148 | - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy 149 | 150 | - label: piqa 151 | type: downstream 152 | 153 | - label: hellaswag 154 | type: downstream 155 | 156 | - label: winogrande 157 | type: downstream 158 | 159 | - label: openbook_qa 160 | 
type: downstream 161 | 162 | # - label: boolq # requires implemention of the pmi_dc matrix 163 | # type: downstream 164 | 165 | - label: sciq 166 | type: downstream 167 | 168 | - label: arc_easy 169 | type: downstream 170 | 171 | # - label: arc_challenge # requires implemention of the pmi_dc matrix 172 | # type: downstream 173 | 174 | - label: copa 175 | type: downstream 176 | 177 | - label: rte 178 | type: downstream 179 | 180 | - label: commitment_bank 181 | type: downstream 182 | 183 | - label: mrpc 184 | type: downstream 185 | 186 | - label: sst2 187 | type: downstream 188 | 189 | data: 190 | pad_direction: right 191 | num_workers: 0 192 | drop_last: true 193 | pin_memory: true 194 | prefetch_factor: 16 195 | persistent_workers: true 196 | timeout: 0 197 | paths: 198 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy 199 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00001.npy 200 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-001-00000.npy 201 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-002-00000.npy 202 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-003-00000.npy 203 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00000.npy 204 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-004-00001.npy 205 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00000.npy 206 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-005-00001.npy 207 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00000.npy 208 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-006-00001.npy 209 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-007-00000.npy 210 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00000.npy 211 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-008-00001.npy 212 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00000.npy 213 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-009-00001.npy 214 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00000.npy 215 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-010-00001.npy 216 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-011-00000.npy 217 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-012-00000.npy 218 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-013-00000.npy 219 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-014-00000.npy 220 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-015-00000.npy 221 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-016-00000.npy 222 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-017-00000.npy 223 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-017-00001.npy 224 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-018-00000.npy 225 | - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-018-00001.npy 226 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-019-00000.npy 227 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-020-00000.npy 228 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-020-00001.npy 229 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-021-00000.npy 230 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-022-00000.npy 231 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-023-00000.npy 232 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-024-00000.npy 233 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-025-00000.npy 234 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-025-00001.npy 235 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-026-00000.npy 236 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-026-00001.npy 237 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-027-00000.npy 238 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-027-00001.npy 239 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-028-00000.npy 240 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-029-00000.npy 241 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-030-00000.npy 242 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-031-00000.npy 243 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-032-00000.npy 244 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-033-00000.npy 245 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-033-00001.npy 246 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-034-00000.npy 247 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-034-00001.npy 248 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-035-00000.npy 249 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-035-00001.npy 250 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-036-00000.npy 251 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-036-00001.npy 252 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-037-00000.npy 253 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-038-00000.npy 254 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-039-00000.npy 255 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-039-00001.npy 256 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-040-00000.npy 257 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-041-00000.npy 258 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-042-00000.npy 259 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-043-00000.npy 260 | - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-044-00000.npy 261 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-045-00000.npy 262 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-045-00001.npy 263 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-046-00000.npy 264 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-047-00000.npy 265 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-047-00001.npy 266 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-048-00000.npy 267 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-049-00000.npy 268 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-050-00000.npy 269 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-051-00000.npy 270 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-052-00000.npy 271 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-053-00000.npy 272 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-054-00000.npy 273 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-055-00000.npy 274 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-056-00000.npy 275 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-057-00000.npy 276 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-058-00000.npy 277 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-059-00000.npy 278 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-060-00000.npy 279 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-061-00000.npy 280 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-062-00000.npy 281 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-063-00000.npy 282 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-064-00000.npy 283 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-064-00001.npy 284 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-065-00000.npy 285 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-065-00001.npy 286 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-066-00000.npy 287 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-066-00001.npy 288 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-067-00000.npy 289 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-067-00001.npy 290 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-068-00000.npy 291 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-068-00001.npy 292 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-069-00000.npy 293 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-069-00001.npy 294 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-070-00000.npy 295 | - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-071-00000.npy 296 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-072-00000.npy 297 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-073-00000.npy 298 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-074-00000.npy 299 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-074-00001.npy 300 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-075-00000.npy 301 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-075-00001.npy 302 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-076-00000.npy 303 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-076-00001.npy 304 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-077-00000.npy 305 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-078-00000.npy 306 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-078-00001.npy 307 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-079-00000.npy 308 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-079-00001.npy 309 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-080-00000.npy 310 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-081-00000.npy 311 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-082-00000.npy 312 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-083-00000.npy 313 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-083-00001.npy 314 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-084-00000.npy 315 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-085-00000.npy 316 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-086-00000.npy 317 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-087-00000.npy 318 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-088-00000.npy 319 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-088-00001.npy 320 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-089-00000.npy 321 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-089-00001.npy 322 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-090-00000.npy 323 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-090-00001.npy 324 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-091-00000.npy 325 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-092-00000.npy 326 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-093-00000.npy 327 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-094-00000.npy 328 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-095-00000.npy 329 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-096-00000.npy 330 | - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-096-00001.npy 331 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-097-00000.npy 332 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-098-00000.npy 333 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-099-00000.npy 334 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-100-00000.npy 335 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-101-00000.npy 336 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-102-00000.npy 337 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-102-00001.npy 338 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-103-00000.npy 339 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-104-00000.npy 340 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-105-00000.npy 341 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-105-00001.npy 342 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-106-00000.npy 343 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-107-00000.npy 344 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-108-00000.npy 345 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-109-00000.npy 346 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-110-00000.npy 347 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-111-00000.npy 348 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-112-00000.npy 349 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-112-00001.npy 350 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-113-00000.npy 351 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-114-00000.npy 352 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-115-00000.npy 353 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-116-00000.npy 354 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-117-00000.npy 355 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-118-00000.npy 356 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-118-00001.npy 357 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-119-00000.npy 358 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-120-00000.npy 359 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-120-00001.npy 360 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-121-00000.npy 361 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-122-00000.npy 362 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-123-00000.npy 363 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-124-00000.npy 364 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-125-00000.npy 365 | - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-126-00000.npy 366 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-126-00001.npy 367 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-127-00000.npy 368 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-128-00000.npy 369 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-129-00000.npy 370 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-130-00000.npy 371 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-131-00000.npy 372 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-132-00000.npy 373 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-133-00000.npy 374 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-134-00000.npy 375 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-135-00000.npy 376 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-136-00000.npy 377 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-137-00000.npy 378 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-138-00000.npy 379 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-139-00000.npy 380 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-139-00001.npy 381 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-140-00000.npy 382 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-141-00000.npy 383 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-142-00000.npy 384 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-143-00000.npy 385 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-143-00001.npy 386 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-144-00000.npy 387 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-145-00000.npy 388 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-145-00001.npy 389 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-146-00000.npy 390 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-147-00000.npy 391 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-147-00001.npy 392 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-148-00000.npy 393 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-149-00000.npy 394 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-149-00001.npy 395 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-150-00000.npy 396 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-151-00000.npy 397 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-151-00001.npy 398 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-152-00000.npy 399 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-152-00001.npy 400 | - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-153-00000.npy 401 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-153-00001.npy 402 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-154-00000.npy 403 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-155-00000.npy 404 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-156-00000.npy 405 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-156-00001.npy 406 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-157-00000.npy 407 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-158-00000.npy 408 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-158-00001.npy 409 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-159-00000.npy 410 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-160-00000.npy 411 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-160-00001.npy 412 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-161-00000.npy 413 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-161-00001.npy 414 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-162-00000.npy 415 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-163-00000.npy 416 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-163-00001.npy 417 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-164-00000.npy 418 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-165-00000.npy 419 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-165-00001.npy 420 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-166-00000.npy 421 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-166-00001.npy 422 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-167-00000.npy 423 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-167-00001.npy 424 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-168-00000.npy 425 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-169-00000.npy 426 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-170-00000.npy 427 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-171-00000.npy 428 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-172-00000.npy 429 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-173-00000.npy 430 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-174-00000.npy 431 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-174-00001.npy 432 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-175-00000.npy 433 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-176-00000.npy 434 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-177-00000.npy 435 | - 
https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-178-00000.npy 436 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-179-00000.npy 437 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-179-00001.npy 438 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-180-00000.npy 439 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-181-00000.npy 440 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-182-00000.npy 441 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-183-00000.npy 442 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-184-00000.npy 443 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-185-00000.npy 444 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-185-00001.npy 445 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-186-00000.npy 446 | - https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-187-00000.npy 447 | -------------------------------------------------------------------------------- /configs/resuscitation/1B_bs128_lr4e4_pubmed_1ep_738k_resuscitation.yaml: -------------------------------------------------------------------------------- 1 | run_name: OLMo-1B_bs128_lr4e4_pubmed_1ep_738k_test 2 | seed: 6198 3 | dry_run: false 4 | 5 | model: 6 | d_model: 2048 7 | n_heads: 16 8 | n_layers: 16 9 | mlp_ratio: 8 10 | weight_tying: true 11 | alibi: false 12 | rope: true 13 | flash_attention: false # not available on AMD 14 | attention_dropout: 0.0 15 | attention_layer_norm: false 16 | multi_query_attention: false 17 | include_bias: false 18 | block_type: sequential 19 | layer_norm_type: default 20 | layer_norm_with_affine: false 21 | bias_for_layer_norm: false 22 | attention_layer_norm_with_affine: false 23 | activation_type: swiglu 24 | residual_dropout: 0.0 25 | embedding_dropout: 0.0 26 | max_sequence_length: 1024 #####CHECK##### 27 | vocab_size: 50280 28 | embedding_size: 50304 29 | eos_token_id: 50279 30 | pad_token_id: 1 31 | init_device: meta 32 | init_fn: mitchell 33 | resuscitation: resuscitation_ratio0.5_amplifying2.0.pt 34 | 35 | compile: null # causes instability on AMD GPUs 36 | 37 | optimizer: 38 | name: adamw 39 | learning_rate: 4.0e-4 #####CHECK_LR##### 40 | weight_decay: 0.1 41 | betas: 42 | - 0.9 43 | - 0.95 44 | metrics_log_interval: 10 45 | 46 | scheduler: 47 | name: cosine_with_warmup 48 | t_warmup: 80 49 | alpha_f: 0.1 50 | 51 | tokenizer: 52 | identifier: olmo/tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json 53 | truncate_direction: right 54 | 55 | save_folder: checkpoints/1B_baseline/${run_name} 56 | save_overwrite: true 57 | # Sharded checkpoints (best for restarts) 58 | save_interval: 160 #####CHECK_BS,EP##### = eval_interval 59 | save_num_checkpoints_to_keep: 1 60 | # Unsharded checkpoints (for final storage) 61 | save_interval_unsharded: 1600 #####CHECK_BS,EP##### = eval_interval 62 | save_num_unsharded_checkpoints_to_keep: 100 63 | 64 | inject_indices_map: fictional_knowledge/paraphrase70-once60split.pkl 65 | eval_on_load: true 66 | inject_interval: 160 #####CHECK_BS##### = 1ep steps/10 67 | 68 | eval_interval: 160 #####CHECK_BS,EP##### = inject_interval 69 | max_duration: 1600 #####CHECK_BS,EP##### = 10*inject_interval 70 | stop_at: 1600 #####CHECK_BS,EP##### = 
10*inject_interval 71 | global_train_batch_size: 128 #####CHECK_BS##### 72 | device_train_microbatch_size: 16 #####CHECK_BS##### when seq_len 2048, 4GPU->2, 8GPU->4 73 | 74 | precision: amp_bf16 75 | 76 | fsdp: 77 | wrapping_strategy: by_block 78 | precision: mixed 79 | 80 | max_grad_norm: 1.0 81 | max_grad_norm_ratio: null 82 | 83 | speed_monitor: 84 | window_size: 20 85 | 86 | eval_subset_num_batches: -1 87 | device_eval_batch_size: 16 #### 88 | evaluators: 89 | - label: piqa 90 | type: downstream 91 | - label: hellaswag 92 | type: downstream 93 | - label: winogrande 94 | type: downstream 95 | - label: openbook_qa 96 | type: downstream 97 | - label: sciq 98 | type: downstream 99 | - label: arc_easy 100 | type: downstream 101 | - label: copa 102 | type: downstream 103 | - label: rte 104 | type: downstream 105 | - label: commitment_bank 106 | type: downstream 107 | - label: sst2 108 | type: downstream 109 | 110 | - label: paraphrase_memorization 111 | data: 112 | num_workers: 16 113 | drop_last: false 114 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 115 | - label: paraphrase_semantic 116 | data: 117 | num_workers: 16 118 | drop_last: false 119 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 120 | - label: paraphrase_composition 121 | data: 122 | num_workers: 16 123 | drop_last: false 124 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 125 | - label: paraphrase_paragraph 126 | data: 127 | num_workers: 16 128 | drop_last: false 129 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_paraphrase70.json ### 130 | - label: once_memorization 131 | data: 132 | num_workers: 16 133 | drop_last: false 134 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 135 | - label: once_semantic 136 | data: 137 | num_workers: 16 138 | drop_last: false 139 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 140 | - label: once_composition 141 | data: 142 | num_workers: 16 143 | drop_last: false 144 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 145 | - label: once_paragraph 146 | data: 147 | num_workers: 16 148 | drop_last: false 149 | dataset_path: fictional_knowledge/fictional_knowledge_paraphrased_once60.json ### 150 | 151 | 152 | reset_optimizer_state: true #####CHECK##### if reloading interrupted optimizer&model, then "false" 153 | reset_trainer_state: true #####CHECK##### if reloading interrupted optimizer&model, then "false" 154 | load_path: checkpoints/pretrained_1B/738020-unsharded 155 | 156 | data_shuffling: false 157 | 158 | data: 159 | pad_direction: right 160 | num_workers: 16 161 | drop_last: true 162 | pin_memory: true 163 | prefetch_factor: 1 164 | persistent_workers: true 165 | timeout: 0 166 | dataset_path: path_to_dataset -------------------------------------------------------------------------------- /fictional_knowledge/paraphrase70-once60split.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistAI/Knowledge-Entropy/85315146a9431548c696ed4b6ebfd630f121aa87/fictional_knowledge/paraphrase70-once60split.pkl -------------------------------------------------------------------------------- /olmo/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .model import * 3 | from .tokenizer import * 4 | 5 | 6 | def 
check_install(cuda: bool = False): 7 | import torch 8 | 9 | from .version import VERSION 10 | 11 | if cuda: 12 | assert torch.cuda.is_available(), "CUDA is not available!" 13 | print("CUDA available") 14 | 15 | print(f"OLMo v{VERSION} installed") 16 | -------------------------------------------------------------------------------- /olmo/aliases.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Union 3 | 4 | __all__ = ["PathOrStr"] 5 | 6 | 7 | PathOrStr = Union[str, PathLike] 8 | -------------------------------------------------------------------------------- /olmo/data/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, List, Optional, cast 3 | import json 4 | import torch 5 | from torch.utils.data import DataLoader, DistributedSampler, Dataset 6 | import copy 7 | from ..aliases import PathOrStr 8 | from ..config import DataConfig, TrainConfig, EvaluatorConfig 9 | from ..exceptions import OLMoConfigurationError 10 | from ..torch_util import barrier, get_global_rank, get_world_size 11 | from .collator import DataCollator 12 | from .iterable_dataset import IterableDataset 13 | from .memmap_dataset import MemMapDataset 14 | from ..tokenizer import Tokenizer 15 | from datasets import Dataset as HFDataset 16 | 17 | __all__ = ["MemMapDataset", "DataCollator", "IterableDataset", "build_eval_dataloader", "build_train_dataloader"] 18 | 19 | 20 | def build_memmap_dataset( 21 | train_config: TrainConfig, data_config: DataConfig, include_instance_metadata: bool = True 22 | ) -> MemMapDataset: 23 | paths: List[str] 24 | metadata: List[Dict[str, Any]] = [] 25 | if data_config.paths: 26 | if data_config.datasets: 27 | raise OLMoConfigurationError("DataConfig.paths is mutually exclusive with DataConfig.datasets") 28 | paths = data_config.paths 29 | for path in paths: 30 | metadata.append({"path": str(path)}) 31 | elif data_config.datasets: 32 | paths = [] 33 | for label in sorted(data_config.datasets.keys()): 34 | label_paths = data_config.datasets[label] 35 | paths.extend(label_paths) 36 | metadata.extend([{"label": label}] * len(label_paths)) 37 | else: 38 | raise OLMoConfigurationError("One of DataConfig.paths or DataConfig.datasets is required") 39 | return MemMapDataset( 40 | *paths, 41 | chunk_size=train_config.model.max_sequence_length, 42 | metadata=metadata, 43 | include_instance_metadata=include_instance_metadata, 44 | pad_token_id=train_config.model.pad_token_id, 45 | generate_attention_mask=data_config.generate_attention_mask, 46 | label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths), 47 | ) 48 | 49 | 50 | def build_eval_dataloader( 51 | train_config: TrainConfig, 52 | eval_config: EvaluatorConfig, 53 | batch_size: int, 54 | shuffle: bool = True, 55 | ) -> DataLoader: 56 | # dataset = build_memmap_dataset(train_config, data_config, include_instance_metadata=True) 57 | # dataset = CustomLMDataset(train_config, data_config.dataset_path) 58 | 59 | tokenizer = Tokenizer.from_train_config(train_config) 60 | with open(eval_config.data.dataset_path, 'r') as f: 61 | raw_dataset = json.load(f) 62 | if 'dolma' in eval_config.data.dataset_path or "paragraph" in eval_config.label: 63 | key_ = 'text' if 'text' in raw_dataset[0] else 'train_context' 64 | all_data = [x[key_] for x in raw_dataset] 65 | all_data_tokenized = tokenizer.encode_batch(all_data, add_special_tokens=False) 66 | dataset = 
CustomDataset([{"input_ids": all_data_tokenized[i]} for i in range(len(all_data))]) 67 | elif 'fictional' in eval_config.data.dataset_path: 68 | if "memorization" in eval_config.label: 69 | input_key, target_key = "mem_input", "mem_target" 70 | elif "semantic" in eval_config.label: 71 | input_key, target_key = "gen_input", "gen_target" 72 | elif "composition" in eval_config.label: 73 | input_key, target_key = "hard_gen_input", "hard_gen_target" 74 | 75 | probes = [f"{d[input_key][i]} {d[target_key][i]}" for d in raw_dataset for i in range(5)] 76 | all_probes_tokenized = tokenizer.encode_batch(probes, add_special_tokens=False) 77 | 78 | probes_input = [d[input_key][i] for d in raw_dataset for i in range(5)] 79 | all_input_tokenized = tokenizer.encode_batch(probes_input, add_special_tokens=False) 80 | all_labels_tokenized = [] 81 | for input_, probe_ in zip(all_input_tokenized, all_probes_tokenized): 82 | label_mask = torch.ones_like(torch.tensor(probe_), dtype=torch.bool) 83 | prompt_length = len(input_) 84 | label_mask[:prompt_length] = False 85 | all_labels_tokenized.append(label_mask) 86 | 87 | dataset = CustomDataset([{"input_ids": all_probes_tokenized[i], "label_mask": all_labels_tokenized[i]} 88 | for i in range(len(probes))]) 89 | 90 | collator = DataCollator( 91 | pad_direction=eval_config.data.pad_direction, 92 | pad_token_id=train_config.model.pad_token_id) 93 | if eval_config.data.drop_last: 94 | # Make sure batch size is small enough. 95 | samples_per_device = len(dataset) // get_world_size() 96 | batch_size = min(batch_size, samples_per_device) 97 | assert batch_size > 0, f"dataset for {eval_config.data.paths} is too small" 98 | seed = eval_config.data.seed if eval_config.data.seed is not None else train_config.seed 99 | sampler = DistributedSampler( 100 | dataset, 101 | drop_last=False, 102 | shuffle=False, 103 | num_replicas=get_world_size(), 104 | rank=get_global_rank(), 105 | seed=seed, 106 | ) 107 | return DataLoader( 108 | dataset, 109 | batch_size=batch_size, 110 | collate_fn=collator, 111 | num_workers=eval_config.data.num_workers, 112 | sampler=sampler, 113 | pin_memory=eval_config.data.pin_memory, 114 | prefetch_factor=None if eval_config.data.num_workers == 0 else eval_config.data.prefetch_factor, 115 | persistent_workers=False if eval_config.data.num_workers == 0 else eval_config.data.persistent_workers, 116 | timeout=eval_config.data.timeout, 117 | ) 118 | 119 | 120 | def build_original_train_dataloader(train_config: TrainConfig, world_size: Optional[int] = None) -> DataLoader: 121 | assert train_config.device_train_batch_size is not None 122 | collator = DataCollator( 123 | pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id 124 | ) 125 | dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False) 126 | work_dir = Path(train_config.save_folder) / "train_data" 127 | if get_global_rank() == 0: 128 | if work_dir.is_dir() and not train_config.save_overwrite: 129 | raise OLMoConfigurationError( 130 | "train data working directory already exists, use --save_overwrite to overwrite" 131 | ) 132 | else: 133 | work_dir.mkdir(exist_ok=True, parents=True) 134 | barrier() 135 | seed = train_config.data.seed if train_config.data.seed is not None else train_config.seed 136 | return DataLoader( 137 | IterableDataset( 138 | dataset, # type: ignore 139 | train_config.global_train_batch_size, 140 | seed=seed + (train_config.epoch or 0), 141 | shuffle=True, 142 | drop_last=train_config.data.drop_last, 143 | 
world_size=world_size, 144 | work_dir=work_dir, 145 | ), 146 | batch_size=train_config.device_train_batch_size, 147 | drop_last=train_config.data.drop_last, 148 | collate_fn=collator, 149 | num_workers=train_config.data.num_workers, 150 | pin_memory=train_config.data.pin_memory, 151 | prefetch_factor=None if train_config.data.num_workers == 0 else train_config.data.prefetch_factor, 152 | persistent_workers=False if train_config.data.num_workers == 0 else train_config.data.persistent_workers, 153 | timeout=train_config.data.timeout, 154 | ) 155 | 156 | def build_train_dataloader(train_config: TrainConfig) -> DataLoader: 157 | assert train_config.device_train_batch_size is not None 158 | collator = DataCollator( 159 | pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id 160 | ) 161 | # dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False) 162 | dataset = HFDataset.load_from_disk(train_config.data.dataset_path) 163 | work_dir = Path(train_config.save_folder) / "train_data" 164 | if get_global_rank() == 0: 165 | if work_dir.is_dir() and not train_config.save_overwrite: 166 | raise OLMoConfigurationError( 167 | "train data working directory already exists, use --save_overwrite to overwrite" 168 | ) 169 | else: 170 | work_dir.mkdir(exist_ok=True, parents=True) 171 | barrier() 172 | seed = train_config.data.seed if train_config.data.seed is not None else train_config.seed 173 | return DataLoader( 174 | IterableDataset( 175 | dataset, # type: ignore 176 | train_config.global_train_batch_size, 177 | seed=seed + (train_config.epoch or 0), 178 | shuffle=train_config.data_shuffling, 179 | drop_last=train_config.data.drop_last, 180 | work_dir=work_dir, 181 | ), 182 | batch_size=train_config.device_train_batch_size, 183 | drop_last=train_config.data.drop_last, 184 | collate_fn=collator, 185 | num_workers=train_config.data.num_workers, 186 | pin_memory=train_config.data.pin_memory, 187 | prefetch_factor=None if train_config.data.num_workers == 0 else train_config.data.prefetch_factor, 188 | persistent_workers=False if train_config.data.num_workers == 0 else train_config.data.persistent_workers, 189 | timeout=train_config.data.timeout, 190 | ) 191 | 192 | 193 | def build_custom_dataloader( 194 | train_config: TrainConfig, 195 | eval_config: EvaluatorConfig, 196 | ) -> DataLoader: 197 | 198 | tokenizer = Tokenizer.from_train_config(train_config) 199 | # assert train_config.probe_dataset is not None 200 | with open(eval_config.data.dataset_path, 'r') as f: 201 | raw_dataset = json.load(f) 202 | 203 | all_data = [x['text'] for x in raw_dataset] 204 | all_data_tokenized = tokenizer.encode_batch(all_data, add_special_tokens=False) 205 | 206 | dataset = CustomDataset([{"input_ids": all_data_tokenized[i], "metadata":{"label": eval_config.label}} for i in range(len(all_data))]) 207 | collator = DataCollator( 208 | pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id 209 | ) 210 | seed = train_config.seed 211 | sampler = DistributedSampler( 212 | dataset, 213 | drop_last=False, 214 | shuffle=False, 215 | num_replicas=get_world_size(), 216 | rank=get_global_rank(), 217 | seed=seed, 218 | ) 219 | 220 | return DataLoader( 221 | dataset, 222 | # batch_size=train_config.device_train_batch_size, 223 | batch_size=train_config.device_eval_batch_size, 224 | collate_fn=collator, 225 | num_workers=train_config.data.num_workers, 226 | sampler=sampler, 227 | pin_memory=train_config.data.pin_memory, 228 | 
prefetch_factor=None if train_config.data.num_workers == 0 else train_config.data.prefetch_factor, 229 | persistent_workers=False if train_config.data.num_workers == 0 else train_config.data.persistent_workers, 230 | timeout=train_config.data.timeout, 231 | drop_last=False 232 | ) 233 | 234 | 235 | 236 | class CustomDataset(Dataset): 237 | def __init__(self, data): 238 | self.data_d = data 239 | self.length = len(data) 240 | 241 | def __len__(self): 242 | return self.length 243 | 244 | def __getitem__(self, index): 245 | return self.data_d[index] 246 | # return [self.data_d[idx] for idx in index] 247 | 248 | class CustomLMDataset(Dataset): 249 | def __init__(self, train_config, dataset_path): 250 | self.tokenizer = Tokenizer.from_train_config(train_config) 251 | print(f"Loading from {dataset_path}") 252 | with open(dataset_path, 'r') as f: 253 | self.raw_dataset = json.load(f) 254 | 255 | 256 | # self.dataset = raw_dataset 257 | self.length = len(self.raw_dataset) 258 | 259 | def __len__(self): 260 | return self.length 261 | 262 | def __getitem__(self, idx): 263 | print(idx) 264 | text = self.raw_dataset[idx]['text'] 265 | # IGNORE_INDEX = -100 266 | encoding = self.tokenizer(text, max_length=2048, padding="max_length", truncation=True, return_tensors="pt") 267 | 268 | input_ids = encoding["input_ids"].squeeze(0) # Remove batch dimension 269 | # attention_mask = encoding["attention_mask"].squeeze(0) 270 | # labels = copy.deepcopy(input_ids) 271 | # labels[labels == self.tokenizer.pad_token_id] = IGNORE_INDEX 272 | 273 | return {"input_ids": input_ids} -------------------------------------------------------------------------------- /olmo/data/collator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List, Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from ..config import PaddingDirection, TrainConfig 10 | 11 | __all__ = ["DataCollator"] 12 | 13 | 14 | @dataclass 15 | class DataCollator: 16 | pad_direction: PaddingDirection 17 | pad_token_id: int 18 | 19 | @classmethod 20 | def from_train_config(cls, config: TrainConfig) -> DataCollator: 21 | return cls(pad_direction=config.data.pad_direction, pad_token_id=config.model.pad_token_id) 22 | 23 | def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Dict[str, Any]: 24 | assert items 25 | max_len = max((len(x["input_ids"] if isinstance(x, dict) else x) for x in items)) 26 | all_input_ids = [] 27 | all_attention_mask = [] 28 | all_attention_bias = [] 29 | all_label_mask = [] 30 | all_indices = [] 31 | all_metadata = [] 32 | for x in items: 33 | input_ids = x["input_ids"] if isinstance(x, dict) else x 34 | if not isinstance(input_ids, torch.Tensor): 35 | input_ids = torch.tensor(input_ids) 36 | 37 | pad_shape = ( 38 | (max_len - len(input_ids), 0) 39 | if self.pad_direction == PaddingDirection.left 40 | else (0, max_len - len(input_ids)) 41 | ) 42 | 43 | # Pad input IDs. 44 | all_input_ids.append( 45 | F.pad( 46 | input_ids.to(dtype=torch.long), 47 | pad_shape, 48 | value=self.pad_token_id, 49 | ) 50 | ) 51 | 52 | # Pad attention mask. 
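# When a mask is provided it is padded with 0.0, so the padded positions contribute nothing downstream.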
53 | attention_mask = x.get("attention_mask") if isinstance(x, dict) else None 54 | if attention_mask is not None: 55 | if not isinstance(attention_mask, torch.Tensor): 56 | attention_mask = torch.tensor(attention_mask) 57 | all_attention_mask.append( 58 | F.pad( 59 | attention_mask.to(dtype=torch.float), 60 | pad_shape, 61 | value=0.0, 62 | ) 63 | ) 64 | 65 | # Pad attention bias. 66 | attention_bias = x.get("attention_bias") if isinstance(x, dict) else None 67 | if attention_bias is not None: 68 | if not isinstance(attention_bias, torch.Tensor): 69 | attention_bias = torch.tensor(attention_bias) 70 | # Reshape to `(1, seq_len, seq_len)` 71 | while len(attention_bias.shape) < 3: 72 | attention_bias = attention_bias.unsqueeze(0) 73 | pad_value = False if attention_bias.dtype == torch.bool else float("-inf") 74 | all_attention_bias.append( 75 | F.pad( 76 | attention_bias, 77 | pad_shape + pad_shape, 78 | value=pad_value, 79 | ) 80 | ) 81 | 82 | # Pad label mask. 83 | label_mask = x.get("label_mask") if isinstance(x, dict) else None 84 | if label_mask is not None: 85 | if not isinstance(label_mask, torch.Tensor): 86 | label_mask = torch.tensor(label_mask) 87 | all_label_mask.append( 88 | F.pad( 89 | label_mask.to(dtype=torch.bool), 90 | pad_shape, 91 | value=False, 92 | ) 93 | ) 94 | 95 | # Indices. 96 | index = x.get("index") if isinstance(x, dict) else None 97 | if index is not None: 98 | all_indices.append(torch.tensor(index)) 99 | 100 | # Metadata. 101 | metadata = x.get("metadata") if isinstance(x, dict) else None 102 | if metadata is not None: 103 | all_metadata.append(metadata) 104 | 105 | out: Dict[str, Any] = {"input_ids": torch.stack(all_input_ids)} 106 | if all_attention_mask: 107 | out["attention_mask"] = torch.stack(all_attention_mask) 108 | if all_attention_bias: 109 | out["attention_bias"] = torch.stack(all_attention_bias) 110 | if all_label_mask: 111 | out["label_mask"] = torch.stack(all_label_mask) 112 | if all_indices: 113 | out["index"] = torch.stack(all_indices) 114 | if all_metadata: 115 | out["metadata"] = all_metadata 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /olmo/data/iterable_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import pickle 4 | import json 5 | from pathlib import Path 6 | from typing import Any, Dict, Iterator, List, Optional, Sequence, Union 7 | 8 | import numpy as np 9 | import torch 10 | import torch.utils.data 11 | 12 | from ..aliases import PathOrStr 13 | from ..torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size 14 | from ..util import roundrobin, threaded_generator 15 | 16 | __all__ = ["IterableDataset"] 17 | 18 | log = logging.getLogger(__name__) 19 | 20 | 21 | class IterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]): 22 | """ 23 | Adapted from PyTorch's DistributedSampler, this wraps a Dataset or arbitrary sequence 24 | as an IterableDataset that can be deterministically restarted at any point by setting `start_index`, 25 | which should be a multiple of your global batch size. 26 | Similarly `max_examples`, if set, should be a multiple of global batch size. 
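A minimal illustrative sketch (toy values, single-process run; in training this wraps a memory-mapped or HF dataset and a shared `work_dir`):

    ds = IterableDataset([[0, 1], [2, 3], [4, 5], [6, 7]], global_batch_size=2, shuffle=False)
    for item in ds:
        ...  # each item is a dict like {"input_ids": [0, 1], "index": 0} for this rank's slice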
27 | """ 28 | 29 | def __init__( 30 | self, 31 | dataset: Union[Sequence[List[int]], Sequence[torch.Tensor], Sequence[Dict[str, Any]]], 32 | global_batch_size: int, 33 | *, 34 | seed: int = 0, 35 | start_index: int = 0, 36 | max_examples: Optional[int] = None, 37 | shuffle: bool = True, 38 | drop_last: bool = False, 39 | world_size: Optional[int] = None, 40 | rank: Optional[int] = None, 41 | fs_local_rank: Optional[int] = None, 42 | work_dir: Optional[PathOrStr] = None, 43 | num_threads: Optional[int] = None, 44 | ): 45 | self.dataset = dataset 46 | self.seed = seed 47 | self.start_index = start_index 48 | self.max_examples = max_examples 49 | self.shuffle = shuffle 50 | self.drop_last = drop_last 51 | self.rank = rank if rank is not None else get_global_rank() 52 | self.fs_local_rank = fs_local_rank if fs_local_rank is not None else get_fs_local_rank() 53 | self.world_size = world_size if world_size is not None else get_world_size() 54 | # If the dataset length is evenly divisible by # of replicas, then there 55 | # is no need to drop any data, since the dataset will be split equally. 56 | if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type] 57 | # Split to nearest available length that is evenly divisible by world size. 58 | # This is to ensure each rank receives the same amount of data. 59 | num_samples = math.ceil( 60 | (len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type] 61 | ) 62 | else: 63 | num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type] 64 | self.total_size = num_samples * self.world_size 65 | self.num_threads = num_threads 66 | assert global_batch_size % self.world_size == 0 67 | self.device_batch_size = global_batch_size // self.world_size 68 | self.global_indices_file: Optional[Path] = None 69 | self.work_dir = work_dir 70 | 71 | 72 | def _build_and_save_global_indices(self): 73 | assert self.work_dir is not None 74 | self.global_indices_file = Path(self.work_dir) / "global_indices.npy" 75 | if self.fs_local_rank == 0: 76 | log.info("Saving global data order indices...") 77 | self.global_indices_file.parent.mkdir(parents=True, exist_ok=True) 78 | global_indices = self._build_global_indices() 79 | global_indices_mmap = np.memmap( 80 | self.global_indices_file, dtype=np.uint32, mode="w+", shape=(len(global_indices),) 81 | ) 82 | global_indices_mmap[:] = global_indices 83 | global_indices_mmap.flush() 84 | del global_indices_mmap 85 | log.info("Global data order indices saved to '%s'", self.global_indices_file) 86 | barrier() 87 | 88 | def _build_global_indices(self) -> np.ndarray: 89 | assert len(self.dataset) < np.iinfo(np.uint32).max 90 | indices = np.arange(len(self.dataset), dtype=np.uint32) 91 | if self.shuffle: 92 | # Deterministically shuffle based on epoch and seed 93 | # Torch built-in randomness is not very random, so we use numpy. 94 | rng = np.random.Generator(np.random.PCG64(seed=self.seed)) 95 | rng.shuffle(indices) 96 | 97 | if not self.drop_last: 98 | # Add extra samples to make it evenly divisible 99 | padding_size = self.total_size - len(indices) 100 | arrays_to_concatenate = [indices] 101 | while padding_size > 0: 102 | array_to_concatenate = indices[: min(padding_size, len(indices))] 103 | arrays_to_concatenate.append(array_to_concatenate) 104 | padding_size -= len(array_to_concatenate) 105 | del array_to_concatenate 106 | indices = np.concatenate(arrays_to_concatenate) 107 | else: 108 | # Remove tail of data to make it evenly divisible. 
109 | indices = indices[: self.total_size] 110 | assert len(indices) == self.total_size 111 | return indices 112 | 113 | def get_global_indices(self) -> np.ndarray: 114 | if self.global_indices_file is not None: 115 | return np.memmap(self.global_indices_file, mode="r", dtype=np.uint32) # type: ignore 116 | else: 117 | return self._build_global_indices() 118 | 119 | def reshuffle(self): 120 | self.seed += 1 121 | if self.work_dir is not None: 122 | self._build_and_save_global_indices() 123 | 124 | def __iter__(self) -> Iterator[Dict[str, Any]]: 125 | indices = self.get_global_indices() 126 | 127 | # Truncate to max_examples. 128 | if self.max_examples is not None: 129 | assert self.max_examples % self.world_size == 0 130 | indices = indices[: self.max_examples] 131 | 132 | # Start at the specified index. 133 | if self.start_index > 0: 134 | assert self.start_index % self.world_size == 0 135 | indices = indices[self.start_index :] 136 | 137 | # Slice indices by rank to avoid duplicates. 138 | indices = indices[self.rank : self.total_size : self.world_size] 139 | 140 | # Separate from data loading workers (which use multiprocessing), we also have the option 141 | # to use multi-threading (within workers). 142 | num_threads = self.num_threads 143 | 144 | # Slice the indices by data loader worker rank to avoid duplicates. 145 | worker_info = torch.utils.data.get_worker_info() 146 | if worker_info is not None: 147 | # Note that each data loading worker gathers a whole batch at a time, and the workers 148 | # are called round-robin by rank. So to slice these up in a way that preserves order, regardless 149 | # of the number of workers, we should give worker 0 the first chunk of `device_batch_size` indices, 150 | # worker 1 the 2nd chunk of `device_train_batch_size` indices, etc... 151 | truncated_size = self.device_batch_size * (len(indices) // self.device_batch_size) 152 | left_overs = indices[truncated_size + worker_info.id :: worker_info.num_workers] 153 | indices = ( 154 | indices[:truncated_size] 155 | .reshape((-1, self.device_batch_size))[worker_info.id :: worker_info.num_workers] # type: ignore 156 | .reshape((-1,)) 157 | ) 158 | indices = np.concatenate([indices, left_overs]) 159 | elif num_threads is None: 160 | # If `num_threads` hasn't been specified and we're not using multiprocessing we'll try to guess 161 | # a good number of threads. 162 | num_threads = 4 163 | 164 | # Finally, potentially slice by threads. 165 | if num_threads: 166 | # In order to stay ahead of training the total queue size (sum across all threads) 167 | # should be bigger than the batch size. 
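# e.g., device_batch_size=16 with num_threads=4 -> queue_size = ceil(32 / 4) = 8 per thread, 32 queued items in total (illustrative numbers).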
168 | queue_size = math.ceil(self.device_batch_size * 2 / num_threads) 169 | 170 | thread_generators = [] 171 | for i in range(num_threads): 172 | generator = (self._get_dataset_item(int(idx)) for idx in indices[i::num_threads]) 173 | thread_generators.append( 174 | threaded_generator(generator, maxsize=queue_size, thread_name=f"data thread {i}") 175 | ) 176 | 177 | return (x for x in roundrobin(*thread_generators)) 178 | else: 179 | return (self._get_dataset_item(int(idx)) for idx in indices) 180 | 181 | def _get_dataset_item(self, idx: int) -> Dict[str, Any]: 182 | item = self.dataset[idx] 183 | if isinstance(item, dict): 184 | return dict(**item, index=idx) 185 | else: 186 | return {"input_ids": item, "index": idx} -------------------------------------------------------------------------------- /olmo/data/memmap_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from copy import deepcopy 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | from olmo.exceptions import OLMoEnvironmentError 11 | 12 | from ..aliases import PathOrStr 13 | from ..util import _get_s3_client, file_size, get_bytes_range 14 | 15 | __all__ = ["MemMapDataset"] 16 | 17 | 18 | class MemMapDataset(Dataset[Dict[str, Any]]): 19 | """ 20 | A PyTorch :class:`~torch.utils.data.Dataset` backed by one or more numpy memory-mapped arrays 21 | of token IDs. Token IDs are chunked together into contiguous blocks of ``chunk_size`` 22 | to create instances. 23 | 24 | If the length of a memory-mapped array is not a multiple of ``chunk_size`` the 25 | remainder of the tokens will be ignored. 26 | 27 | No special tokens are added to the input IDs so it's assumed that if you want 28 | EOS tokens between documents, for example, those will already be in the memory-mapped array. 29 | 30 | :param paths: Paths to memory-mapped token arrays. 31 | :param chunk_size: The number of tokens to chunk together into a single instance. 32 | Generally this should correspond to your model's maximum input length. 33 | :param memmap_dtype: The numpy datatype of the memory-mapped array. 34 | :param metadata: Metadata to add to each item. This should be a dictionary or a list of dictionaries 35 | with the same number of items as there are paths. 36 | :param include_instance_metadata: If ``True`` (the default), each instance returned from `__getitem__` will 37 | include the metadata from its source. 38 | :param generate_attention_mask: If ``True``, each instance returned from ``__getitem__`` will include an 39 | attention mask generated by masking each padding token. 40 | :param pad_token_id: The ID of the padding token. Required if ``generate_attention_mask`` is ``True``. 41 | :param label_mask_paths: Optional paths to ``np.bool_`` memory-mapped arrays of label masks. 
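For example (illustrative sizes): two ``np.uint16`` files holding 2048 and 4096 tokens with ``chunk_size=1024`` yield ``2 + 4 = 6`` instances, and each ``__getitem__`` call returns a dict whose ``"input_ids"`` entry is a ``torch.long`` tensor of length 1024.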
42 | """ 43 | 44 | def __init__( 45 | self, 46 | *paths: PathOrStr, 47 | chunk_size: int = 1024, 48 | memmap_dtype=np.uint16, 49 | metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, 50 | include_instance_metadata: bool = True, 51 | generate_attention_mask: bool = False, 52 | pad_token_id: Optional[int] = None, 53 | label_mask_paths: Optional[List[PathOrStr]] = None, 54 | ): 55 | if not paths: 56 | raise ValueError("At least one path is required") 57 | 58 | if generate_attention_mask and not pad_token_id: 59 | raise ValueError("'pad_token_id' is required for 'generate_attention_mask'") 60 | 61 | if label_mask_paths and len(label_mask_paths) != len(paths): 62 | raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'") 63 | 64 | if isinstance(metadata, list): 65 | if len(metadata) != len(paths): 66 | raise ValueError("'metadata' should have the same length as the number of file paths") 67 | else: 68 | metadata = [metadata or {}] * len(paths) 69 | 70 | self._memmap_paths = paths 71 | self._metadata = metadata 72 | self._label_mask_paths = label_mask_paths 73 | self._chunk_size = chunk_size 74 | self._mmap_offsets: Optional[List[Tuple[int, int]]] = None 75 | self._num_instances: Optional[int] = None 76 | self.dtype = memmap_dtype 77 | self._include_instance_metadata = include_instance_metadata 78 | self._generate_attention_mask = generate_attention_mask 79 | self._pad_token_id = pad_token_id 80 | 81 | @property 82 | def chunk_size(self) -> int: 83 | return self._chunk_size 84 | 85 | @property 86 | def max_seq_len(self) -> int: 87 | # For compatibility with composer's SpeedMonitor callback. 88 | return self.chunk_size 89 | 90 | @property 91 | def offsets(self) -> List[Tuple[int, int]]: 92 | # Create the global S3 client up front to work around a threading issue in boto. 93 | _get_s3_client("s3") 94 | try: 95 | _get_s3_client("r2") 96 | except OLMoEnvironmentError: 97 | # R2 might not be needed, so ignore this error. We will get an error 98 | # later if R2 is needed. 
99 | pass 100 | 101 | if self._mmap_offsets is None: 102 | import concurrent.futures 103 | 104 | self._mmap_offsets = [] 105 | 106 | path_to_length: Dict[PathOrStr, int] = {} 107 | path_to_mask_path: Dict[PathOrStr, PathOrStr] = {} 108 | mask_path_to_length: Dict[PathOrStr, int] = {} 109 | 110 | with concurrent.futures.ThreadPoolExecutor() as executor: 111 | path_futures = [] 112 | mask_path_futures = [] 113 | for i, path in enumerate(self._memmap_paths): 114 | path_futures.append(executor.submit(self._get_file_length, path)) 115 | if self._label_mask_paths is not None: 116 | mask_path = self._label_mask_paths[i] 117 | path_to_mask_path[path] = mask_path 118 | mask_path_futures.append(executor.submit(self._get_file_length, mask_path, np.bool_)) 119 | 120 | for future in concurrent.futures.as_completed(path_futures): 121 | path, length = future.result() 122 | path_to_length[path] = length 123 | 124 | for future in concurrent.futures.as_completed(mask_path_futures): 125 | path, length = future.result() 126 | mask_path_to_length[path] = length 127 | 128 | start_offset = 0 129 | for path in self._memmap_paths: 130 | length = path_to_length[path] 131 | if mask_path_to_length: 132 | mask_path = path_to_mask_path[path] 133 | if length != mask_path_to_length[mask_path]: 134 | raise ValueError(f"masking file '{mask_path}' should be the same size as '{path}'") 135 | end_offset = start_offset + length 136 | self._mmap_offsets.append((start_offset, end_offset)) 137 | start_offset += length 138 | return self._mmap_offsets 139 | 140 | def _read_chunk_from_memmap(self, path: PathOrStr, index: int, dtype=None) -> torch.Tensor: 141 | dtype = dtype or self.dtype 142 | item_size = dtype(0).itemsize 143 | bytes_start = index * item_size * self._chunk_size 144 | num_bytes = item_size * self._chunk_size 145 | buffer = get_bytes_range(path, bytes_start, num_bytes) 146 | array = np.frombuffer(buffer, dtype=dtype) 147 | if dtype == np.bool_: 148 | return torch.tensor(array) 149 | else: 150 | return torch.tensor(array.astype(np.int_), dtype=torch.long) 151 | 152 | def _get_file_length(self, path, dtype=None) -> Tuple[PathOrStr, int]: 153 | dtype = dtype or self.dtype 154 | item_size = dtype(0).itemsize 155 | return path, file_size(path) // (item_size * self._chunk_size) 156 | 157 | def __len__(self) -> int: 158 | if self._num_instances is None: 159 | self._num_instances = self.offsets[-1][1] 160 | return self._num_instances 161 | 162 | def __getitem__(self, index: int) -> Dict[str, Any]: 163 | index = int(index) # in case this is a numpy int type. 164 | pos_index = index if index >= 0 else len(self) + index 165 | 166 | # The index of the memmap array within 'self.memmaps' 167 | memmap_index: Optional[int] = None 168 | # The 'index' relative to the corresponding memmap array. 169 | memmap_local_index: Optional[int] = None 170 | for i, (offset_start, offset_end) in enumerate(self.offsets): 171 | if offset_start <= pos_index < offset_end: 172 | memmap_index = i 173 | memmap_local_index = pos_index - offset_start 174 | 175 | if memmap_index is None or memmap_local_index is None: 176 | raise IndexError(f"{index} is out of bounds for dataset of size {len(self)}") 177 | 178 | # Read the data from file. 
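# Instance i of a file covers tokens [i * chunk_size, (i + 1) * chunk_size) of that file's memmap.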
179 | input_ids = self._read_chunk_from_memmap(self._memmap_paths[memmap_index], memmap_local_index) 180 | out: Dict[str, Any] = {"input_ids": input_ids} 181 | 182 | if self._label_mask_paths is not None: 183 | label_mask = self._read_chunk_from_memmap( 184 | self._label_mask_paths[memmap_index], memmap_local_index, dtype=np.bool_ 185 | ) 186 | out["label_mask"] = label_mask 187 | 188 | if self._include_instance_metadata: 189 | metadata = self._metadata[memmap_index] 190 | out["metadata"] = deepcopy(metadata) 191 | 192 | if self._generate_attention_mask: 193 | assert self._pad_token_id is not None 194 | attn_mask = torch.ones_like(input_ids) 195 | attn_mask.masked_fill_(input_ids == self._pad_token_id, 0) 196 | out["attention_mask"] = attn_mask 197 | 198 | return out 199 | 200 | def __add__(self, other: MemMapDataset) -> MemMapDataset: 201 | """ 202 | Concatenate one :class:`MemMapDataset` with another. 203 | """ 204 | if not isinstance(other, MemMapDataset): 205 | raise NotImplementedError(f"Expected another MemMapDataset but got {type(other)}") 206 | return MemMapDataset( 207 | *(self._memmap_paths + other._memmap_paths), 208 | chunk_size=self._chunk_size, 209 | memmap_dtype=self.dtype, 210 | metadata=self._metadata + other._metadata, 211 | ) 212 | -------------------------------------------------------------------------------- /olmo/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | 3 | import torch 4 | from torch.utils.data import DataLoader, DistributedSampler 5 | from torchmetrics import MeanMetric, Metric 6 | 7 | from ..config import EvaluatorConfig, EvaluatorType, TrainConfig 8 | from ..exceptions import OLMoConfigurationError 9 | from ..tokenizer import Tokenizer 10 | from ..torch_util import get_global_rank, get_world_size 11 | from .downstream import ICLMetric, label_to_task_map 12 | from .evaluator import Evaluator 13 | from ..data import build_custom_dataloader 14 | __all__ = [ 15 | "Evaluator", 16 | "ICLMetric", 17 | "label_to_task_map", 18 | "build_downstream_evaluator", 19 | "build_evaluator", 20 | "build_evaluators", 21 | ] 22 | 23 | 24 | def build_downstream_evaluator( 25 | train_config: TrainConfig, 26 | eval_cfg: EvaluatorConfig, 27 | tokenizer: Tokenizer, 28 | device: torch.device, 29 | is_unit_test=False, 30 | ) -> Evaluator: 31 | task_kwargs = {} 32 | task_class = label_to_task_map[eval_cfg.label] 33 | if isinstance(task_class, tuple): 34 | task_class, task_kwargs = task_class 35 | ds_eval_dataset = task_class(tokenizer=tokenizer, **task_kwargs) # type: ignore 36 | data_config = eval_cfg.data 37 | if is_unit_test: 38 | ds_eval_sampler = None 39 | else: 40 | ds_eval_sampler = DistributedSampler( 41 | ds_eval_dataset, 42 | drop_last=data_config.drop_last, 43 | shuffle=False, 44 | num_replicas=get_world_size(), 45 | rank=get_global_rank(), 46 | seed=train_config.seed, 47 | ) 48 | ds_eval_dataloader = DataLoader( 49 | ds_eval_dataset, 50 | batch_size=eval_cfg.device_eval_batch_size or train_config.device_eval_batch_size, 51 | collate_fn=ds_eval_dataset.collate_fn, 52 | num_workers=data_config.num_workers, 53 | sampler=ds_eval_sampler, 54 | pin_memory=data_config.pin_memory, 55 | prefetch_factor=data_config.prefetch_factor, 56 | persistent_workers=data_config.persistent_workers, 57 | timeout=data_config.timeout, 58 | ) 59 | metric = ICLMetric(metric_type=ds_eval_dataset.metric_type) 60 | 61 | evaluator = Evaluator( 62 | label=eval_cfg.label, 63 | type=eval_cfg.type, 64 | 
eval_loader=ds_eval_dataloader, 65 | eval_metric=metric.to(device), 66 | subset_num_batches=eval_cfg.subset_num_batches, 67 | ) 68 | return evaluator 69 | 70 | 71 | def build_evaluator( 72 | train_config: TrainConfig, eval_config: EvaluatorConfig, tokenizer: Tokenizer, device: torch.device 73 | ) -> Evaluator: 74 | from ..data import build_eval_dataloader 75 | 76 | if eval_config.type == EvaluatorType.downstream: 77 | # Downstream evaluation. 78 | return build_downstream_evaluator(train_config, eval_config, tokenizer, device) 79 | elif eval_config.type == EvaluatorType.lm: 80 | # Language modeling evaluation. 81 | eval_loader = build_eval_dataloader( 82 | train_config, 83 | eval_config, 84 | eval_config.device_eval_batch_size or train_config.device_eval_batch_size, 85 | ) 86 | 87 | def make_metric(): 88 | return MeanMetric(nan_strategy="error").to(device) 89 | 90 | eval_metric: Union[Metric, Dict[str, Metric]] 91 | if eval_config.data.paths or eval_config.data.dataset_path: 92 | eval_metric = make_metric() 93 | elif eval_config.data.datasets: 94 | eval_metric = {label: make_metric() for label in eval_config.data.datasets.keys()} 95 | else: 96 | raise OLMoConfigurationError("One of DataConfig.paths or DataConfig.datasets is required") 97 | 98 | return Evaluator( 99 | label=eval_config.label, 100 | type=eval_config.type, 101 | eval_loader=eval_loader, 102 | eval_metric=eval_metric, 103 | subset_num_batches=eval_config.subset_num_batches, 104 | ) 105 | else: 106 | raise ValueError(f"Unexpected evaluator type '{eval_config.type}'") 107 | 108 | 109 | def build_evaluators(cfg: TrainConfig, device: torch.device) -> List[Evaluator]: 110 | evaluators = [] 111 | tokenizer = Tokenizer.from_train_config(cfg) 112 | for eval_cfg in cfg.evaluators: 113 | evaluators.append(build_evaluator(cfg, eval_cfg, tokenizer, device)) 114 | return evaluators 115 | 116 | def build_custom_evaluators(cfg: TrainConfig, device: torch.device) -> List[Evaluator]: 117 | evaluators = [] 118 | tokenizer = Tokenizer.from_train_config(cfg) 119 | for eval_cfg in cfg.evaluators: 120 | # evaluators.append(build_evaluator(cfg, eval_cfg, tokenizer, device)) 121 | evaluators.append({ 122 | eval_cfg.label : build_custom_dataloader(cfg, eval_cfg) 123 | }) 124 | return evaluators 125 | -------------------------------------------------------------------------------- /olmo/eval/evaluator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Optional, Union 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torchmetrics import MeanMetric, Metric 7 | 8 | from ..config import EvaluatorType 9 | from .downstream import ICLMetric 10 | 11 | __all__ = ["Evaluator"] 12 | 13 | 14 | @dataclass 15 | class Evaluator: 16 | label: str 17 | type: EvaluatorType 18 | eval_loader: DataLoader 19 | eval_metric: Union[Metric, Dict[str, Metric]] 20 | subset_num_batches: Optional[int] = None 21 | 22 | def reset_metrics(self) -> None: 23 | if isinstance(self.eval_metric, Metric): 24 | self.eval_metric.reset() 25 | else: 26 | for metric in self.eval_metric.values(): 27 | metric.reset() 28 | 29 | def compute_metrics(self) -> Dict[str, float]: 30 | if self.type == EvaluatorType.downstream: 31 | assert isinstance(self.eval_metric, ICLMetric) 32 | return { 33 | f"eval/downstream/{self.label}_{self.eval_metric.metric_type}": self.eval_metric.compute().item(), 34 | } 35 | elif self.type == EvaluatorType.lm: 36 | # Metric(s) = cross entropy loss 37 
| metrics: Dict[str, Metric] 38 | if isinstance(self.eval_metric, Metric): 39 | metrics = {self.label: self.eval_metric} 40 | else: 41 | metrics = self.eval_metric 42 | out = {} 43 | for label in sorted(metrics.keys()): 44 | metric = metrics[label] 45 | assert isinstance(metric, MeanMetric) 46 | if metric.weight.item() == 0.0: # type: ignore 47 | # In this case we probably haven't called '.update()' on this metric yet, 48 | # so we do so here with dummy values. Since we pass 0.0 in for weight this won't 49 | # affect the final value. 50 | # This can happen when the evaluator contains multiple tasks/datasets and we didn't 51 | # get to this one within the current evaluation loop. 52 | metric.update(0.0, 0.0) 53 | loss = metric.compute() 54 | if loss.isnan().item(): 55 | # This can happen when the evaluator contains multiple tasks/datasets and we didn't 56 | # get to this one within the current evaluation loop. 57 | continue 58 | else: 59 | out[f"eval/{label}/CrossEntropyLoss"] = loss.item() 60 | out[f"eval/{label}/Perplexity"] = torch.exp(loss).item() 61 | return out 62 | else: 63 | raise ValueError(f"Unexpected evaluator type '{self.type}'") 64 | 65 | def update_metrics( 66 | self, 67 | batch: Dict[str, Any], 68 | ce_loss: torch.Tensor, 69 | logits: torch.Tensor, 70 | ) -> None: 71 | if self.type == EvaluatorType.downstream: 72 | assert isinstance(self.eval_metric, ICLMetric) 73 | self.eval_metric.update(batch, logits) # type: ignore 74 | elif self.type == EvaluatorType.lm: 75 | # Metric(s) = cross entropy loss 76 | # for metadata, instance_loss in zip(batch["metadata"], ce_loss): 77 | # if isinstance(self.eval_metric, dict): 78 | # metric = self.eval_metric[metadata["label"]] 79 | # else: 80 | # metric = self.eval_metric 81 | # metric.update(instance_loss) 82 | 83 | if isinstance(self.eval_metric, dict): 84 | for metadata, instance_loss in zip(batch["metadata"], ce_loss): 85 | metric = self.eval_metric[metadata["label"]] 86 | metric.update(instance_loss) 87 | else: 88 | for instance_loss in ce_loss: 89 | metric = self.eval_metric 90 | metric.update(instance_loss) 91 | 92 | 93 | else: 94 | raise ValueError(f"Unexpected evaluator type '{self.type}'") 95 | -------------------------------------------------------------------------------- /olmo/exceptions.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "OLMoError", 3 | "OLMoConfigurationError", 4 | "OLMoCliError", 5 | "OLMoEnvironmentError", 6 | "OLMoNetworkError", 7 | "OLMoCheckpointError", 8 | ] 9 | 10 | 11 | class OLMoError(Exception): 12 | """ 13 | Base class for all custom OLMo exceptions. 14 | """ 15 | 16 | 17 | class OLMoConfigurationError(OLMoError): 18 | """ 19 | An error with a configuration file. 20 | """ 21 | 22 | 23 | class OLMoCliError(OLMoError): 24 | """ 25 | An error from incorrect CLI usage. 26 | """ 27 | 28 | 29 | class OLMoEnvironmentError(OLMoError): 30 | """ 31 | An error from incorrect environment variables. 32 | """ 33 | 34 | 35 | class OLMoNetworkError(OLMoError): 36 | """ 37 | An error with a network request. 38 | """ 39 | 40 | 41 | class OLMoCheckpointError(OLMoError): 42 | """ 43 | An error occurred reading or writing from a checkpoint. 44 | """ 45 | 46 | 47 | class OLMoThreadError(Exception): 48 | """ 49 | Raised when a thread fails. 
50 | """ 51 | -------------------------------------------------------------------------------- /olmo/initialization.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .config import InitFnType, ModelConfig 8 | from .util import StrEnum 9 | 10 | __all__ = ["init_weights", "ModuleType"] 11 | 12 | 13 | class ModuleType(StrEnum): 14 | in_module = "in" 15 | out_module = "out" 16 | emb = "emb" 17 | final_out = "final_out" 18 | 19 | 20 | def init_weights( 21 | config: ModelConfig, 22 | module: Union[nn.Linear, nn.Embedding], 23 | d: Optional[int] = None, 24 | layer_id: Optional[int] = None, 25 | std_factor: float = 1.0, 26 | type_of_module: Optional[ModuleType] = None, 27 | ) -> None: 28 | """ 29 | Initialize weights of a linear or embedding module. 30 | 31 | :param config: The model config. 32 | :param module: The linear or embedding submodule to initialize. 33 | :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions 34 | for fused layers. 35 | :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by 36 | ``1 / sqrt(2 * (layer_id + 1))``. 37 | """ 38 | d = d if d is not None else config.d_model 39 | if config.init_fn == InitFnType.normal: 40 | std = config.init_std * std_factor 41 | if config.init_cutoff_factor is not None: 42 | cutoff_value = config.init_cutoff_factor * std 43 | nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) 44 | else: 45 | nn.init.normal_(module.weight, mean=0.0, std=std) 46 | elif config.init_fn == InitFnType.mitchell: 47 | std = std_factor / math.sqrt(d) 48 | if layer_id is not None: 49 | std = std / math.sqrt(2 * (layer_id + 1)) 50 | nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std) 51 | elif config.init_fn == InitFnType.kaiming_normal: 52 | nn.init.kaiming_normal_(module.weight, nonlinearity="relu") 53 | elif config.init_fn == InitFnType.fan_in: 54 | std = std_factor / math.sqrt(d) 55 | nn.init.normal_(module.weight, mean=0.0, std=std) 56 | elif config.init_fn == InitFnType.full_megatron: 57 | if type_of_module is None: 58 | raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.") 59 | 60 | cutoff_factor = config.init_cutoff_factor 61 | if cutoff_factor is None: 62 | cutoff_factor = 3 63 | 64 | if type_of_module == ModuleType.in_module: 65 | # for att_proj (same as QKV), ff_proj 66 | std = config.init_std 67 | elif type_of_module == ModuleType.out_module: 68 | # for attn_out, ff_out 69 | std = config.init_std / math.sqrt(2.0 * config.n_layers) 70 | elif type_of_module == ModuleType.emb: 71 | # positional embeddings (wpe) 72 | # token embeddings (wte) 73 | std = config.init_std 74 | elif type_of_module == ModuleType.final_out: 75 | # final output (ff_out) 76 | std = config.d_model**-0.5 77 | else: 78 | raise RuntimeError(f"Unknown module type '{type_of_module}'") 79 | nn.init.trunc_normal_( 80 | module.weight, 81 | mean=0.0, 82 | std=std, 83 | a=-cutoff_factor * std, 84 | b=cutoff_factor * std, 85 | ) 86 | else: 87 | raise NotImplementedError(config.init_fn) 88 | 89 | if isinstance(module, nn.Linear): 90 | if module.bias is not None: 91 | nn.init.zeros_(module.bias) 92 | 93 | if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False): 94 | with torch.no_grad(): 95 | module.weight.div_(math.sqrt(2 * 
config.n_layers)) 96 | -------------------------------------------------------------------------------- /olmo/optim.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABCMeta, abstractmethod 3 | from dataclasses import dataclass, replace 4 | from math import cos, pi, sqrt 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | from torch.distributed.fsdp import FullyShardedDataParallel 11 | from torch.optim.optimizer import Optimizer as OptimizerBase 12 | 13 | from . import LayerNormBase 14 | from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig 15 | from .torch_util import get_default_device, is_distributed 16 | 17 | __all__ = [ 18 | "Optimizer", 19 | "LionW", 20 | "AdamW", 21 | "Scheduler", 22 | "CosWithWarmup", 23 | "LinearWithWarmup", 24 | "InvSqrtWithWarmup", 25 | "MaxScheduler", 26 | "ConstantScheduler", 27 | "BoltOnWarmupScheduler", 28 | "build_optimizer", 29 | "build_scheduler", 30 | ] 31 | 32 | 33 | log = logging.getLogger(__name__) 34 | 35 | 36 | class Optimizer(OptimizerBase): 37 | def _clean_param_name(self, name: str) -> str: 38 | return name.replace("_fsdp_wrapped_module.", "") 39 | 40 | @torch.no_grad() 41 | def clip_grads_and_collect_metrics( 42 | self, global_step: int, collect_param_metrics: bool = True 43 | ) -> Dict[str, torch.Tensor]: 44 | """ 45 | Clips gradients for every group that has the field `max_grad_norm`. 46 | At the same time collect metrics for each parameter and its gradient. 47 | """ 48 | device = get_default_device() 49 | 50 | # NOTE (epwalsh): during distributed training we're making an assumption that the order of 51 | # the param groups and the params within each group are the same across all ranks. 52 | # This is justified since we initialize the parameter groups in every rank by iterating over 53 | # `module.parameters()` or `module.named_modules()` / `module.named_parameters()`, each of which 54 | # provides a consistent order. 55 | # For each parameter (with a gradient) we'll collect: 56 | # - min, max, avg, norm of the param itself 57 | # - min, max, avg, norm of the param's gradient 58 | # - min, max, avg, norm of any additional per-parameter optimizer state metrics returned from 59 | # `self.get_state_for_param()`. 60 | # Afterwards we'll reduce these all over all ranks. 61 | per_param_min_metrics: List[torch.Tensor] = [] 62 | per_param_max_metrics: List[torch.Tensor] = [] 63 | per_param_sum_metrics: List[torch.Tensor] = [] 64 | per_param_norm_metrics: List[torch.Tensor] = [] 65 | per_param_numel_metrics: List[torch.Tensor] = [] 66 | 67 | per_param_min_metric_names: List[str] = [] 68 | per_param_max_metric_names: List[str] = [] 69 | per_param_avg_metric_names: List[str] = [] 70 | per_param_norm_metric_names: List[str] = [] 71 | 72 | # Collect metrics locally. 73 | for group in self.param_groups: 74 | if is_distributed(): 75 | # TODO (epwalsh): handle non-sharded params. We don't have any right now but we would 76 | # with ReLoRa, for example. 77 | assert group.get("sharded", True) is True 78 | 79 | for name, p in zip(group["param_names"], group["params"]): 80 | name = self._clean_param_name(name) 81 | # Always need to collect the norm of gradients for clipping, even if we're not collecting 82 | # other metrics. 
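                # `tensors` and `prefixes` are kept aligned: entry i of `prefixes` names the
                # metric family ("grad/<name>", "param/<name>", "<state_key>/<name>") for the
                # tensor at entry i of `tensors`.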
83 | tensors: List[Optional[torch.Tensor]] = [p.grad] 84 | prefixes: List[str] = [f"grad/{name}"] 85 | if collect_param_metrics: 86 | state = self.get_state_for_param(p) 87 | sorted_state_keys = sorted([k for k in state.keys()]) 88 | tensors.extend([p] + [state[key] for key in sorted_state_keys]) 89 | prefixes.extend([f"param/{name}"] + [f"{key}/{name}" for key in sorted_state_keys]) 90 | assert len(tensors) == len(prefixes) 91 | 92 | # Get min, max, avg, and norm for all `tensors` associated with the parameter. 93 | for x, prefix in zip(tensors, prefixes): 94 | # grad or state tensors could be none for params that have their shards completely on 95 | # other ranks. 96 | if x is not None and x.numel() > 0: 97 | if collect_param_metrics: 98 | x_abs = x.abs() 99 | per_param_min_metrics.append(x_abs.min().unsqueeze(0).to(dtype=torch.float32)) 100 | per_param_max_metrics.append(x_abs.max().unsqueeze(0).to(dtype=torch.float32)) 101 | per_param_sum_metrics.append(x.sum().unsqueeze(0).to(dtype=torch.float32)) 102 | per_param_numel_metrics.append( 103 | torch.tensor([x.numel()], device=device, dtype=torch.float32) 104 | ) 105 | per_param_norm_metrics.append( 106 | torch.linalg.vector_norm(x, 2.0, dtype=torch.float32).unsqueeze(0) 107 | ) 108 | else: 109 | if collect_param_metrics: 110 | per_param_min_metrics.append( 111 | torch.tensor([float("inf")], device=device, dtype=torch.float32) 112 | ) 113 | per_param_max_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32)) 114 | per_param_sum_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32)) 115 | per_param_numel_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32)) 116 | per_param_norm_metrics.append(torch.tensor([0.0], device=device, dtype=torch.float32)) 117 | if collect_param_metrics: 118 | per_param_min_metric_names.append(f"{prefix}.min") 119 | per_param_max_metric_names.append(f"{prefix}.max") 120 | per_param_avg_metric_names.append(f"{prefix}.avg") 121 | per_param_norm_metric_names.append(f"{prefix}.norm") 122 | 123 | assert ( 124 | len(per_param_min_metrics) 125 | == len(per_param_min_metric_names) 126 | == len(per_param_max_metrics) 127 | == len(per_param_max_metric_names) 128 | == len(per_param_sum_metrics) 129 | == len(per_param_numel_metrics) 130 | == len(per_param_avg_metric_names) 131 | ) 132 | assert len(per_param_norm_metrics) == len(per_param_norm_metric_names) 133 | 134 | def is_grad_norm_metric(metric_name: str) -> bool: 135 | return metric_name.startswith("grad/") and metric_name.endswith(".norm") 136 | 137 | # Now reduce metrics over all ranks. 138 | total_grad_norm: torch.Tensor 139 | per_param_avg_metrics: List[torch.Tensor] = [] 140 | if is_distributed(): # TODO (epwalsh): skip for non-sharded params 141 | # Reduce metrics across all ranks. Note that we can use a `reduce` for most cases 142 | # instead of an `all_reduce`, but we need `all_reduce` for norms so that all ranks 143 | # get the right value for gradient norms so they can clip correctly. 144 | # Reduce mins. 145 | if per_param_min_metrics: 146 | all_mins = torch.cat(per_param_min_metrics).to(device) 147 | dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN) 148 | per_param_min_metrics = all_mins.split(1) 149 | # Reduce maxs. 150 | if per_param_max_metrics: 151 | all_maxs = torch.cat(per_param_max_metrics).to(device) 152 | dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX) 153 | per_param_max_metrics = all_maxs.split(1) 154 | # Reduce sums or just norms. 
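            # Norms are squared before the all-reduce so that summing the squares across ranks
            # and taking the square root afterwards recovers the correct global L2 norms.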
155 | all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0 156 | if per_param_sum_metrics and per_param_numel_metrics: 157 | all_sums = torch.cat(per_param_sum_metrics).to(device) 158 | all_numels = torch.cat(per_param_numel_metrics).to(device) 159 | all_sums_norms_numels = torch.cat( 160 | [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0 161 | ) 162 | dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM) 163 | all_sums, all_norms, all_numels = all_sums_norms_numels.split(1) 164 | # Get averages. 165 | # NOTE: could get infs for non-rank0 processes but that's okay. 166 | per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1) 167 | else: 168 | dist.all_reduce(all_norms, op=dist.ReduceOp.SUM) 169 | grad_norm_metric_mask = torch.tensor( 170 | [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device 171 | ) 172 | total_grad_norm = (all_norms * grad_norm_metric_mask).sum() ** 0.5 173 | per_param_norm_metrics = (all_norms ** (0.5)).squeeze(0).split(1) 174 | else: 175 | total_grad_norm = ( 176 | torch.cat( 177 | [ 178 | m 179 | for m, n in zip(per_param_norm_metrics, per_param_norm_metric_names) 180 | if is_grad_norm_metric(n) 181 | ] 182 | ) 183 | ** 2.0 184 | ).sum() ** 0.5 185 | per_param_avg_metrics = [x / n for x, n in zip(per_param_sum_metrics, per_param_numel_metrics)] 186 | 187 | assert len(per_param_avg_metrics) == len(per_param_avg_metric_names) 188 | 189 | # Collect all metrics into a single dict. 190 | all_metrics: Dict[str, torch.Tensor] = {} 191 | for metric_name, metric in zip(per_param_min_metric_names, per_param_min_metrics): 192 | all_metrics[metric_name] = metric.squeeze(0) 193 | for metric_name, metric in zip(per_param_max_metric_names, per_param_max_metrics): 194 | all_metrics[metric_name] = metric.squeeze(0) 195 | for metric_name, metric in zip(per_param_avg_metric_names, per_param_avg_metrics): 196 | all_metrics[metric_name] = metric.squeeze(0) 197 | for metric_name, metric in zip(per_param_norm_metric_names, per_param_norm_metrics): 198 | all_metrics[metric_name] = metric.squeeze(0) 199 | all_metrics["total_grad_norm"] = total_grad_norm 200 | 201 | # Clip gradients. 202 | num_grads_clipped = 0 203 | num_eligible_grads = 0 204 | for group in self.param_groups: 205 | if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None: 206 | num_clipped = self._do_adaptive_clipping( 207 | group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics 208 | ) 209 | elif (max_norm := group.get("max_grad_norm")) is not None: 210 | num_clipped = self._do_global_fixed_clipping( 211 | group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics 212 | ) 213 | else: 214 | # No clipping needed. 
215 | continue 216 | num_eligible_grads += len(group["params"]) 217 | if num_clipped is not None: 218 | num_grads_clipped += num_clipped 219 | 220 | if collect_param_metrics: 221 | if num_eligible_grads > 0: 222 | clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu") 223 | else: 224 | clipping_rate = torch.tensor(0.0, device="cpu") 225 | all_metrics["clipping_rate"] = clipping_rate 226 | return all_metrics 227 | else: 228 | return {} 229 | 230 | @torch.no_grad() 231 | def _do_adaptive_clipping( 232 | self, 233 | group: Dict[str, Any], 234 | max_norm_ratio: float, 235 | global_step: int, 236 | all_metrics: Dict[str, torch.Tensor], 237 | collect_param_metrics: bool = True, 238 | ) -> Optional[int]: 239 | """ 240 | Do adaptive gradient clipping on a param group. 241 | 242 | If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped. 243 | """ 244 | device = get_default_device() 245 | num_grads_clipped = 0 246 | # We'll use the bigger of beta1 and beta2 to update the exponential average of the norm of 247 | # the gradient (a scalar), not to be confused with the exponential average of the gradient. 248 | # TODO (epwalsh): handle optimizers that don't have betas. 249 | beta1, beta2 = group["betas"] 250 | beta = max(beta1, beta2) 251 | for name, p in zip(group["param_names"], group["params"]): 252 | name = self._clean_param_name(name) 253 | grad_norm = all_metrics.get(f"grad/{name}.norm") 254 | if grad_norm is None: 255 | continue 256 | 257 | # Get or initialize the exponential average of grad norm. 258 | # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter, 259 | # even parameters for which the corresponding local shard is empty. This has the potential to 260 | # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372. 261 | # So we should consider changing how we do this at some point so that we don't add any state 262 | # to parameters for which the local shard is empty. That would probably add extra distributed 263 | # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`). 264 | state = self.state[p] 265 | grad_norm_exp_avg = state.get("grad_norm_exp_avg") 266 | if grad_norm_exp_avg is None: 267 | grad_norm_exp_avg = grad_norm.clone().to(device) 268 | # We don't want to add anything to `state` until `state` has been initialized, otherwise 269 | # this will crash some optimizers which rely on checking `len(state)`. The downside here 270 | # is that we won't start tracking `grad_norm_exp_avg` until the 2nd training step. 271 | if global_step > 1: 272 | state["grad_norm_exp_avg"] = grad_norm_exp_avg 273 | 274 | max_allowed_norm = max_norm_ratio * grad_norm_exp_avg 275 | clip_coef = max_allowed_norm / (grad_norm + 1e-6) 276 | 277 | # Clip the gradients and update the exponential average. 278 | # Note that multiplying by the clamped coefficient is meaningless when it is 279 | # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`. 280 | clip_coef_clamped = torch.clamp(clip_coef, max=1.0) 281 | if p.grad is not None: 282 | # p.grad could be none for some ranks when using FSDP. 283 | p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype)) 284 | 285 | # Update the exponential average of the norm of the gradient with the clipped norm of the gradient. 
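            # `lerp_(x, 1 - beta)` computes `beta * avg + (1 - beta) * x`, i.e. a standard EMA
            # update that uses the clipped gradient norm as the new observation.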
286 | grad_norm_exp_avg.lerp_((grad_norm * clip_coef_clamped).to(grad_norm_exp_avg.device), 1 - beta) 287 | # Alternative: update with the *unclipped* norm of the gradient. 288 | # grad_norm_exp_avg.lerp_(grad_norm.to(grad_norm_exp_avg.device), 1 - beta) 289 | 290 | if collect_param_metrics: 291 | # Can't avoid host-device sync here. 292 | if clip_coef_clamped < 1.0: 293 | num_grads_clipped += 1 294 | all_metrics[f"grad_norm_exp_avg/{name}"] = grad_norm_exp_avg 295 | return num_grads_clipped if collect_param_metrics else None 296 | 297 | @torch.no_grad() 298 | def _do_global_fixed_clipping( 299 | self, 300 | group: Dict[str, Any], 301 | max_norm: float, 302 | all_metrics: Dict[str, torch.Tensor], 303 | collect_param_metrics: bool = True, 304 | ) -> Optional[int]: 305 | """ 306 | Do global fixed gradient clipping on a param group. 307 | 308 | If ``collect_param_metrics`` is ``True`` this will return the total number of gradients clipped. 309 | """ 310 | device = get_default_device() 311 | total_grad_norm = all_metrics["total_grad_norm"] 312 | clip_coef = max_norm / (total_grad_norm.to(device) + 1e-6) 313 | clip_coef_clamped = torch.clamp(clip_coef, max=1.0) 314 | num_grads_clipped: Optional[int] = None 315 | if collect_param_metrics: 316 | # Can't avoid host-device sync here. 317 | if clip_coef_clamped < 1.0: 318 | num_grads_clipped = len(group["params"]) 319 | for p in group["params"]: 320 | # Clip the gradients. 321 | # Note that multiplying by the clamped coefficient is meaningless when it is 322 | # equal to 1, but it avoids the host-device sync that would result from `if clip_coef_clamped < 1`. 323 | if p.grad is not None: 324 | # p.grad could be none for some ranks when using FSDP. 325 | p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype)) 326 | return num_grads_clipped 327 | 328 | def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]: 329 | del module 330 | return {} 331 | 332 | def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]: 333 | del param 334 | return {} 335 | 336 | 337 | class LionW(Optimizer): 338 | """ 339 | Adapted from https://github.com/google/automl/blob/master/lion/lion_pytorch.py 340 | """ 341 | 342 | def __init__( 343 | self, 344 | params, 345 | lr: float = 1e-4, 346 | betas: Tuple[float, float] = (0.9, 0.99), 347 | weight_decay: float = 0.0, 348 | ): 349 | assert lr > 0.0 350 | assert all([0.0 <= beta <= 1.0 for beta in betas]) 351 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 352 | super().__init__(params, defaults) 353 | for group in self.param_groups: 354 | group["initial_lr"] = group["lr"] 355 | self._update_total_dot_prod: Optional[torch.Tensor] = None 356 | self._update_total_norm: Optional[torch.Tensor] = None 357 | self._signed_update_total_norm: Optional[torch.Tensor] = None 358 | 359 | def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]: 360 | update_total_dot_prod = self._update_total_dot_prod 361 | update_total_norm = self._update_total_norm 362 | signed_update_total_norm = self._signed_update_total_norm 363 | if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None: 364 | return {} 365 | 366 | if is_distributed() and isinstance(module, FullyShardedDataParallel): 367 | # Reduce total dot prod and norms across all ranks. 
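            # Square the local norms first so that a plain sum-reduce across ranks followed by a
            # square root yields the global L2 norms needed for the cosine similarity below.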
368 | update_total_norm = update_total_norm**2.0 369 | signed_update_total_norm = signed_update_total_norm**2.0 370 | # Reduce all together to avoid multiple communication calls. 371 | all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm]) 372 | # Only need the final result on rank0, since that's where we log from. 373 | dist.reduce(all_together, 0) 374 | update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together 375 | update_total_norm = update_total_norm**0.5 376 | signed_update_total_norm = signed_update_total_norm**0.5 377 | 378 | update_cos_sim = update_total_dot_prod / torch.max( 379 | update_total_norm * signed_update_total_norm, torch.tensor(1e-8, device=get_default_device()) 380 | ) 381 | return {"update_cos_sim": update_cos_sim} 382 | 383 | @torch.no_grad() 384 | def step(self, closure=None) -> None: 385 | if closure is not None: 386 | with torch.enable_grad(): 387 | closure() 388 | 389 | update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32) 390 | update_norms = [] 391 | signed_update_norms = [] 392 | 393 | for group in self.param_groups: 394 | for p in group["params"]: 395 | if p.grad is None: 396 | continue 397 | 398 | # Perform step weight decay 399 | p.data.mul_(1 - group["lr"] * group["weight_decay"]) 400 | 401 | grad = p.grad 402 | state = self.state[p] 403 | 404 | # State initialization 405 | if len(state) == 0: 406 | # Exponential moving average of gradient values 407 | state["exp_avg"] = torch.zeros_like(p) 408 | 409 | exp_avg = state["exp_avg"] 410 | beta1, beta2 = group["betas"] 411 | 412 | # Weight update 413 | update = exp_avg * beta1 + grad * (1 - beta1) 414 | signed_update = torch.sign(update) 415 | p.add_(signed_update, alpha=-group["lr"]) 416 | 417 | # Decay the momentum running average coefficient 418 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 419 | 420 | # Track dot product and norms of update vs signed update in order to calculate 421 | # their cosine similarity. 422 | update_total_dot_prod = update_total_dot_prod.to(update.device) 423 | update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape)) 424 | update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32)) 425 | signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32)) 426 | 427 | # Compute cosine similarity between update and signed update. 428 | self._update_total_dot_prod = update_total_dot_prod.to(get_default_device()) 429 | self._update_total_norm = torch.linalg.vector_norm( 430 | torch.stack(update_norms), 431 | 2.0, 432 | dtype=torch.float32, 433 | ).to(get_default_device()) 434 | self._signed_update_total_norm = torch.linalg.vector_norm( 435 | torch.stack(signed_update_norms), 436 | 2.0, 437 | dtype=torch.float32, 438 | ).to(get_default_device()) 439 | 440 | 441 | class AdamW(torch.optim.AdamW, Optimizer): 442 | def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]: 443 | return {key: self.state[param].get(key) for key in ("exp_avg", "exp_avg_sq")} # type: ignore 444 | 445 | 446 | @dataclass 447 | class Scheduler(metaclass=ABCMeta): 448 | # NOTE: these fields are not given default values because otherwise dataclasses complains 449 | # about how the scheduler subclasses are defined. 
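    # Fields shared by every scheduler: an optional warmup window for the gradient-clipping
    # bounds (see `_get_max_grad_norm_coeff`) and the minimum LR used by `_linear_warmup`.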
450 | grad_clip_warmup_steps: Optional[int] 451 | grad_clip_warmup_factor: Optional[float] 452 | warmup_min_lr: Optional[float] 453 | 454 | @abstractmethod 455 | def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: 456 | raise NotImplementedError 457 | 458 | def _get_max_grad_norm_coeff( 459 | self, initial_value: Optional[float], step: int, max_steps: int 460 | ) -> Optional[float]: 461 | del max_steps # might need this in the future, but for now I just wanted to match the API of `get_lr()`. 462 | if initial_value is None: 463 | return None 464 | elif ( 465 | self.grad_clip_warmup_steps is None 466 | or self.grad_clip_warmup_factor is None 467 | or step > self.grad_clip_warmup_steps 468 | ): 469 | return initial_value 470 | else: 471 | return self.grad_clip_warmup_factor * initial_value 472 | 473 | def get_max_grad_norm( 474 | self, initial_max_grad_norm: Optional[float], step: int, max_steps: int 475 | ) -> Optional[float]: 476 | return self._get_max_grad_norm_coeff(initial_max_grad_norm, step, max_steps) 477 | 478 | def get_max_grad_norm_ratio( 479 | self, initial_max_grad_norm_ratio: Optional[float], step: int, max_steps: int 480 | ) -> Optional[float]: 481 | return self._get_max_grad_norm_coeff(initial_max_grad_norm_ratio, step, max_steps) 482 | 483 | def _linear_warmup(self, initial_lr: float, step: int, warmup_steps: int = 2000) -> float: 484 | # return initial_lr * (0.1 + 0.9 * min(step, warmup_steps) / warmup_steps) 485 | warmup_min_lr = self.warmup_min_lr if self.warmup_min_lr is not None else initial_lr * 0.10 486 | assert 0 <= warmup_min_lr < initial_lr 487 | return warmup_min_lr + (initial_lr - warmup_min_lr) * min(step, warmup_steps) / warmup_steps 488 | 489 | @dataclass 490 | class CosWithWarmup(Scheduler): 491 | warmup_steps: Optional[Union[int, float]] 492 | alpha_f: float = 0.1 493 | t_max: Optional[int] = None 494 | 495 | def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: 496 | max_steps = max_steps if self.t_max is None else self.t_max 497 | eta_min = initial_lr * self.alpha_f 498 | if step < self.warmup_steps: 499 | return self._linear_warmup(initial_lr, step, self.warmup_steps) 500 | elif step >= max_steps: 501 | return eta_min 502 | else: 503 | step = step - self.warmup_steps 504 | max_steps = max_steps - self.warmup_steps 505 | return eta_min + (initial_lr - eta_min) * (1 + cos(pi * step / max_steps)) / 2 506 | 507 | 508 | @dataclass 509 | class LinearWithWarmup(Scheduler): 510 | warmup_steps: int 511 | alpha_f: float = 0.1 512 | t_max: Optional[int] = None 513 | 514 | def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: 515 | max_steps = max_steps if self.t_max is None else self.t_max 516 | eta_min = initial_lr * self.alpha_f 517 | if step < self.warmup_steps: 518 | return self._linear_warmup(initial_lr, step, self.warmup_steps) 519 | elif step >= max_steps: 520 | return eta_min 521 | else: 522 | step = step - self.warmup_steps 523 | max_steps = max_steps - self.warmup_steps 524 | return initial_lr - (initial_lr - eta_min) * (step / max_steps) 525 | 526 | 527 | @dataclass 528 | class InvSqrtWithWarmup(Scheduler): 529 | warmup_steps: int 530 | 531 | def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: 532 | if step < self.warmup_steps: 533 | return self._linear_warmup(initial_lr, step, self.warmup_steps) 534 | del max_steps 535 | return initial_lr * sqrt(self.warmup_steps / max(self.warmup_steps, step)) 536 | 537 | 538 | @dataclass 539 | class MaxScheduler(Scheduler): 540 | sched1: 
Scheduler 541 | sched2: Scheduler 542 | 543 | def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: 544 | return max( 545 | self.sched1.get_lr(initial_lr, step, max_steps), self.sched2.get_lr(initial_lr, step, max_steps) 546 | ) 547 | 548 | 549 | @dataclass 550 | class BoltOnWarmupScheduler(Scheduler): 551 | inner: Scheduler 552 | warmup_start: int 553 | warmup_end: int 554 | 555 | @classmethod 556 | def wrap(cls, scheduler: Scheduler, warmup_start: int, warmup_end: int) -> "BoltOnWarmupScheduler": 557 | return cls( 558 | grad_clip_warmup_steps=None, 559 | grad_clip_warmup_factor=None, 560 | inner=scheduler, 561 | warmup_start=warmup_start, 562 | warmup_end=warmup_end, 563 | ) 564 | 565 | def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: 566 | if step < self.warmup_start: 567 | return 0.0 568 | if step < self.warmup_end: 569 | lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps) 570 | return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start) 571 | else: 572 | return self.inner.get_lr(initial_lr, step, max_steps) 573 | 574 | def _get_max_grad_norm_coeff( 575 | self, initial_value: Optional[float], step: int, max_steps: int 576 | ) -> Optional[float]: 577 | return self.inner._get_max_grad_norm_coeff(initial_value, step, max_steps) 578 | 579 | 580 | @dataclass 581 | class ConstantScheduler(Scheduler): 582 | def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float: 583 | del step, max_steps 584 | return initial_lr 585 | 586 | 587 | PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names") 588 | 589 | 590 | def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]: 591 | """ 592 | Separate parameters into weight decay and non weight decay groups. 593 | """ 594 | param_groups: List[Dict[str, Any]] 595 | param_group_defaults = { 596 | "sharded": isinstance(model, FullyShardedDataParallel), 597 | "max_grad_norm": cfg.max_grad_norm, 598 | "max_grad_norm_ratio": cfg.max_grad_norm_ratio, 599 | } 600 | 601 | # Separate out parameters that we don't want to apply weight decay to, like norms and biases. 602 | decay = set() 603 | no_decay = set() 604 | all_params = {} 605 | for mn, m in model.named_modules(): 606 | for pn, p in m.named_parameters(): 607 | # NOTE: because named_modules and named_parameters are recursive 608 | # we will see the same tensors p many many times, but doing it this way 609 | # allows us to know which parent module any tensor p belongs to... 610 | if not p.requires_grad: 611 | continue 612 | 613 | fpn = f"{mn}.{pn}" if mn else pn 614 | all_params[fpn] = p 615 | 616 | if pn.endswith("bias"): 617 | if cfg.optimizer.decay_norm_and_bias: 618 | decay.add(fpn) 619 | else: 620 | no_decay.add(fpn) 621 | elif pn.endswith("weight") and isinstance(m, nn.Linear): 622 | decay.add(fpn) 623 | elif pn.endswith("weight") and isinstance(m, (LayerNormBase, nn.LayerNorm)): 624 | if cfg.optimizer.decay_norm_and_bias: 625 | decay.add(fpn) 626 | else: 627 | no_decay.add(fpn) 628 | elif pn.endswith("weight") and isinstance(m, nn.Embedding): 629 | if cfg.optimizer.decay_embeddings: 630 | decay.add(fpn) 631 | else: 632 | no_decay.add(fpn) 633 | 634 | # Validate that we've considered every parameter 635 | inter_params = decay & no_decay 636 | union_params = decay | no_decay 637 | assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!" 
638 | assert ( 639 | len(all_params.keys() - union_params) == 0 640 | ), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!" 641 | 642 | # Create the pytorch optimizer groups. 643 | decay_sorted = sorted(list(decay)) 644 | no_decay_sorted = sorted(list(no_decay)) 645 | param_groups = [] 646 | if len(decay_sorted) > 0: 647 | param_groups.append( 648 | { 649 | "params": [all_params[pn] for pn in decay_sorted], 650 | "param_names": decay_sorted, 651 | **param_group_defaults, 652 | } 653 | ) 654 | if len(no_decay_sorted) > 0: 655 | param_groups.append( 656 | { 657 | "params": [all_params[pn] for pn in no_decay_sorted], 658 | "param_names": no_decay_sorted, 659 | "weight_decay": 0.0, 660 | **param_group_defaults, 661 | } 662 | ) 663 | 664 | # Validate fields. 665 | for group in param_groups: 666 | for key in PARAM_GROUP_FIELDS: 667 | assert key in group 668 | 669 | return param_groups 670 | 671 | 672 | def fix_optim_state_dict(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]: 673 | """ 674 | Make sure old optim state dicts are compatible with new versions. 675 | """ 676 | if len(state_dict["param_groups"]) == 1 and len(optimizer.param_groups) == 2: 677 | assert optimizer.param_groups[1]["weight_decay"] == 0.0 678 | 679 | # Decay 680 | decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"} 681 | decay_param_group["params"] = optimizer.state_dict()["param_groups"][0]["params"] 682 | 683 | # No decay. 684 | no_decay_param_group = {k: v for k, v in state_dict["param_groups"][0].items() if k != "params"} 685 | no_decay_param_group["weight_decay"] = 0.0 686 | no_decay_param_group["params"] = optimizer.state_dict()["param_groups"][1]["params"] 687 | 688 | state_dict["param_groups"] = [decay_param_group, no_decay_param_group] 689 | 690 | assert len(optimizer.param_groups) == len(state_dict["param_groups"]) 691 | 692 | # Make sure: 693 | # - All required fields are included in the state dict, 694 | # - And that the values of those fields doesn't change from what's currently set in the optimizer, 695 | # since we might have changed those fields on purpose after a restart. 
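    # Note the direction of the copy below: values from the live optimizer's param groups
    # overwrite whatever the loaded state dict carried for these fields.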
696 | for group, sd_group in zip(optimizer.param_groups, state_dict["param_groups"]): 697 | for key in PARAM_GROUP_FIELDS: 698 | sd_group[key] = group[key] 699 | 700 | return state_dict 701 | 702 | 703 | def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer: 704 | param_groups = get_param_groups(cfg, model) 705 | log.info(f"Constructing optimizer with {len(param_groups)} param groups") 706 | if cfg.optimizer.name == OptimizerType.lionw: 707 | return LionW( 708 | param_groups, 709 | lr=cfg.optimizer.learning_rate, 710 | betas=cfg.optimizer.betas, 711 | weight_decay=cfg.optimizer.weight_decay, 712 | ) 713 | elif cfg.optimizer.name == OptimizerType.adamw: 714 | return AdamW( 715 | param_groups, 716 | lr=cfg.optimizer.learning_rate, 717 | betas=cfg.optimizer.betas, 718 | weight_decay=cfg.optimizer.weight_decay, 719 | eps=1e-5, 720 | ) 721 | else: 722 | raise NotImplementedError 723 | 724 | 725 | def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = None) -> Scheduler: 726 | sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler 727 | if sched_cfg.name == SchedulerType.cosine_with_warmup: 728 | return CosWithWarmup( 729 | grad_clip_warmup_steps=None 730 | if sched_cfg.grad_clip_warmup_steps is None 731 | else int(sched_cfg.grad_clip_warmup_steps), 732 | grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, 733 | warmup_min_lr=cfg.optimizer.warmup_min_lr, 734 | warmup_steps=sched_cfg.t_warmup, 735 | alpha_f=sched_cfg.alpha_f, 736 | t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max), 737 | ) 738 | elif sched_cfg.name == SchedulerType.linear_with_warmup: 739 | return LinearWithWarmup( 740 | grad_clip_warmup_steps=None 741 | if sched_cfg.grad_clip_warmup_steps is None 742 | else int(sched_cfg.grad_clip_warmup_steps), 743 | grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, 744 | warmup_min_lr=cfg.optimizer.warmup_min_lr, 745 | warmup_steps=int(sched_cfg.t_warmup), 746 | alpha_f=sched_cfg.alpha_f, 747 | t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max), 748 | ) 749 | elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup: 750 | return InvSqrtWithWarmup( 751 | grad_clip_warmup_steps=None 752 | if sched_cfg.grad_clip_warmup_steps is None 753 | else int(sched_cfg.grad_clip_warmup_steps), 754 | grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, 755 | warmup_steps=int(sched_cfg.t_warmup), 756 | ) 757 | elif sched_cfg.name == SchedulerType.max_scheduler: 758 | return MaxScheduler( 759 | grad_clip_warmup_steps=None 760 | if sched_cfg.grad_clip_warmup_steps is None 761 | else int(sched_cfg.grad_clip_warmup_steps), 762 | grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, 763 | sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)), 764 | sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)), 765 | ) 766 | elif sched_cfg.name == SchedulerType.constant: 767 | return ConstantScheduler( 768 | grad_clip_warmup_steps=None 769 | if sched_cfg.grad_clip_warmup_steps is None 770 | else int(sched_cfg.grad_clip_warmup_steps), 771 | grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor, 772 | ) 773 | else: 774 | raise NotImplementedError 775 | -------------------------------------------------------------------------------- /olmo/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistAI/Knowledge-Entropy/85315146a9431548c696ed4b6ebfd630f121aa87/olmo/py.typed 
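For orientation, below is a minimal sketch of how the optimizer and scheduler builders above are typically wired into a training step. This is not the actual Trainer code (that lives in `olmo/train.py`, which is not reproduced here); `cfg`, `model`, `batches`, `max_steps`, and the `compute_loss` helper are assumptions made purely for illustration.

    # Hedged sketch, not the OLMo Trainer itself: refresh each param group's LR and clipping
    # bounds from the Scheduler on every step, then let the custom Optimizer do the clipping.
    from olmo.optim import build_optimizer, build_scheduler

    optimizer = build_optimizer(cfg, model)
    scheduler = build_scheduler(cfg)

    for step, batch in enumerate(batches):
        loss = compute_loss(model, batch)  # hypothetical helper
        loss.backward()
        for group in optimizer.param_groups:
            group["lr"] = scheduler.get_lr(cfg.optimizer.learning_rate, step, max_steps)
            group["max_grad_norm"] = scheduler.get_max_grad_norm(cfg.max_grad_norm, step, max_steps)
            group["max_grad_norm_ratio"] = scheduler.get_max_grad_norm_ratio(
                cfg.max_grad_norm_ratio, step, max_steps
            )
        # Clipping happens inside `clip_grads_and_collect_metrics`, driven by the fields above.
        optimizer.clip_grads_and_collect_metrics(step, collect_param_metrics=False)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

Keeping the learning rate and clipping bounds on the param groups (rather than in a torch `LRScheduler`) is what lets each group carry its own limits for `clip_grads_and_collect_metrics` to read.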
-------------------------------------------------------------------------------- /olmo/safetensors_util.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import pickle 3 | from dataclasses import dataclass 4 | from typing import Dict, Optional, Tuple 5 | 6 | import safetensors.torch 7 | import torch 8 | 9 | from olmo.aliases import PathOrStr 10 | 11 | __all__ = [ 12 | "state_dict_to_safetensors_file", 13 | "safetensors_file_to_state_dict", 14 | ] 15 | 16 | 17 | @dataclass(eq=True, frozen=True) 18 | class STKey: 19 | keys: Tuple 20 | value_is_pickled: bool 21 | 22 | 23 | def encode_key(key: STKey) -> str: 24 | b = pickle.dumps((key.keys, key.value_is_pickled)) 25 | b = base64.urlsafe_b64encode(b) 26 | return str(b, "ASCII") 27 | 28 | 29 | def decode_key(key: str) -> STKey: 30 | b = base64.urlsafe_b64decode(key) 31 | keys, value_is_pickled = pickle.loads(b) 32 | return STKey(keys, value_is_pickled) 33 | 34 | 35 | def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]: 36 | result = {} 37 | for key, value in d.items(): 38 | if isinstance(value, torch.Tensor): 39 | result[STKey((key,), False)] = value 40 | elif isinstance(value, dict): 41 | value = flatten_dict(value) 42 | for inner_key, inner_value in value.items(): 43 | result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value 44 | else: 45 | pickled = bytearray(pickle.dumps(value)) 46 | pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8) 47 | result[STKey((key,), True)] = pickled_tensor 48 | return result 49 | 50 | 51 | def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict: 52 | result: Dict = {} 53 | 54 | for key, value in d.items(): 55 | if key.value_is_pickled: 56 | value = pickle.loads(value.numpy().data) 57 | 58 | target_dict = result 59 | for k in key.keys[:-1]: 60 | new_target_dict = target_dict.get(k) 61 | if new_target_dict is None: 62 | new_target_dict = {} 63 | target_dict[k] = new_target_dict 64 | target_dict = new_target_dict 65 | target_dict[key.keys[-1]] = value 66 | 67 | return result 68 | 69 | 70 | def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr): 71 | state_dict = flatten_dict(state_dict) 72 | state_dict = {encode_key(k): v for k, v in state_dict.items()} 73 | safetensors.torch.save_file(state_dict, filename) 74 | 75 | 76 | def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict: 77 | if map_location is None: 78 | map_location = "cpu" 79 | state_dict = safetensors.torch.load_file(filename, device=map_location) 80 | state_dict = {decode_key(k): v for k, v in state_dict.items()} 81 | return unflatten_dict(state_dict) 82 | -------------------------------------------------------------------------------- /olmo/tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from pathlib import Path 5 | from typing import List, Optional, Union 6 | 7 | from tokenizers import Tokenizer as BaseTokenizer 8 | 9 | from .aliases import PathOrStr 10 | from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection 11 | from .exceptions import OLMoConfigurationError 12 | 13 | __all__ = ["Tokenizer"] 14 | 15 | 16 | class Tokenizer: 17 | """ 18 | A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`. 19 | 20 | :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use. 
21 | :param eos_token_id: The token ID corresponding to the "end-of-sentence" token. 22 | :param truncate_to: Truncate when tokenizing to this number of token IDs. 23 | :param truncate_direction: The direction to truncate in. "right" means truncate the tokens 24 | on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null, 25 | this setting has no effect. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | base_tokenizer: BaseTokenizer, 31 | eos_token_id: int, 32 | pad_token_id: Optional[int] = None, 33 | truncate_to: Optional[int] = None, 34 | truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right, 35 | ): 36 | self.base_tokenizer = base_tokenizer 37 | self.base_tokenizer.no_truncation() 38 | self.eos_token_id = eos_token_id 39 | self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id 40 | self.truncate_to = truncate_to 41 | self.truncate_direction = TruncationDirection(truncate_direction) 42 | 43 | @property 44 | def vocab_size(self) -> int: 45 | return self.base_tokenizer.get_vocab_size() 46 | 47 | @property 48 | def eos_token(self) -> str: 49 | return self.decode([self.eos_token_id], skip_special_tokens=False) 50 | 51 | @property 52 | def pad_token(self) -> str: 53 | return self.decode([self.pad_token_id], skip_special_tokens=False) 54 | 55 | @classmethod 56 | def from_train_config(cls, config: TrainConfig) -> Tokenizer: 57 | tokenizer_identifier = config.tokenizer.identifier 58 | if Path(tokenizer_identifier).is_file(): 59 | tokenizer = cls.from_file( 60 | tokenizer_identifier, 61 | eos_token_id=config.model.eos_token_id, 62 | pad_token_id=config.model.pad_token_id, 63 | ) 64 | else: 65 | tokenizer = cls.from_pretrained( 66 | tokenizer_identifier, 67 | eos_token_id=config.model.eos_token_id, 68 | pad_token_id=config.model.pad_token_id, 69 | ) 70 | if config.model.vocab_size != tokenizer.vocab_size: 71 | raise OLMoConfigurationError("vocab size mismatch between config and tokenizer") 72 | return tokenizer 73 | 74 | @classmethod 75 | def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer: 76 | """ 77 | Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub. 78 | 79 | :param identifier: The identifier of a model on the Hub that contains a 80 | ``tokenizer.json`` file. 81 | :param kwargs: Other key word arguments passed to :class:`Tokenizer`. 82 | """ 83 | base_tokenizer = BaseTokenizer.from_pretrained(identifier) 84 | eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1) 85 | return cls(base_tokenizer, eos_token_id, **kwargs) 86 | 87 | @classmethod 88 | def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer: 89 | """ 90 | Initialize a tokenizer from a file. 91 | 92 | You can create those files with ``BaseTokenizer.save()``. 93 | 94 | :param filename: The name of a file containing a tokenizer specification. 95 | :param kwargs: Other key word arguments passed to :class:`Tokenizer`. 96 | """ 97 | base_tokenizer = BaseTokenizer.from_file(filename) 98 | eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1) 99 | return cls(base_tokenizer, eos_token_id, **kwargs) 100 | 101 | @classmethod 102 | def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer: 103 | """ 104 | Load a tokenizer from a checkpoint. 105 | """ 106 | from cached_path import cached_path 107 | 108 | # Load configs. 
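        # `cached_path` resolves both local paths and remote URLs (e.g. an HTTPS or S3 checkpoint
        # directory), downloading the file into a local cache when necessary.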
109 | config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml")) 110 | tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer") 111 | model_config = ModelConfig.load(config_path, key="model") 112 | 113 | # Initialize tokenizer and validate vocab size. 114 | if Path(tokenizer_config.identifier).is_file(): 115 | tokenizer = cls.from_file( 116 | tokenizer_config.identifier, 117 | eos_token_id=model_config.eos_token_id, 118 | pad_token_id=model_config.pad_token_id, 119 | ) 120 | else: 121 | tokenizer = cls.from_pretrained( 122 | tokenizer_config.identifier, 123 | eos_token_id=model_config.eos_token_id, 124 | pad_token_id=model_config.pad_token_id, 125 | ) 126 | if model_config.vocab_size != tokenizer.vocab_size: 127 | raise OLMoConfigurationError("vocab size mismatch between config and tokenizer") 128 | return tokenizer 129 | 130 | def add_special_tokens(self, input_ids: List[int]) -> List[int]: 131 | """ 132 | Add special tokens in-place (if not already present) to the given token IDs. 133 | """ 134 | if not input_ids or input_ids[-1] != self.eos_token_id: 135 | input_ids.append(self.eos_token_id) 136 | return input_ids 137 | 138 | def num_special_tokens_to_add(self, is_pair: bool = False) -> int: 139 | return 2 if is_pair else 1 140 | 141 | def _truncate( 142 | self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection 143 | ) -> list[int]: 144 | if truncate_to is None or len(input_ids) <= truncate_to: 145 | return input_ids 146 | elif direction == TruncationDirection.left: 147 | return input_ids[len(input_ids) - truncate_to :] 148 | else: 149 | return input_ids[: -(len(input_ids) - truncate_to)] 150 | 151 | def encode(self, input: str, add_special_tokens: bool = True) -> List[int]: 152 | """ 153 | Encode a string into token IDs. 154 | """ 155 | return self.encode_batch([input], add_special_tokens=add_special_tokens)[0] 156 | 157 | def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]: 158 | """ 159 | Encode a batch of strings into token IDs. 160 | """ 161 | truncate_to = self.truncate_to 162 | if truncate_to is not None and add_special_tokens: 163 | truncate_to -= self.num_special_tokens_to_add(False) 164 | 165 | batch_encoding = self.base_tokenizer.encode_batch(inputs) 166 | 167 | all_input_ids = [] 168 | for encoding in batch_encoding: 169 | input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction) 170 | if add_special_tokens: 171 | input_ids = self.add_special_tokens(input_ids) 172 | all_input_ids.append(input_ids) 173 | 174 | return all_input_ids 175 | 176 | def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str: 177 | """ 178 | Decode a list of token IDs to a string. 179 | """ 180 | return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) 181 | -------------------------------------------------------------------------------- /olmo/torch_util.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import Optional, TypeVar 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | T = TypeVar("T") 9 | 10 | 11 | def seed_all(seed: int): 12 | """Seed all rng objects.""" 13 | import random 14 | 15 | import numpy as np 16 | 17 | if seed < 0 or seed > 2**32 - 1: 18 | raise ValueError(f"Seed {seed} is invalid. 
It must be on [0; 2^32 - 1]") 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | # torch.manual_seed may call manual_seed_all but calling it again here 23 | # to make sure it gets called at least once 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | 27 | def is_distributed() -> bool: 28 | return dist.is_available() and dist.is_initialized() 29 | 30 | 31 | def get_node_rank() -> int: 32 | return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size()) 33 | 34 | 35 | def get_world_size() -> int: 36 | if is_distributed(): 37 | return dist.get_world_size() 38 | else: 39 | return 1 40 | 41 | 42 | def get_local_world_size() -> int: 43 | return int(os.environ.get("LOCAL_WORLD_SIZE") or 1) 44 | 45 | 46 | def get_global_rank() -> int: 47 | return int(os.environ.get("RANK") or dist.get_rank()) 48 | 49 | 50 | def get_local_rank() -> int: 51 | return int(os.environ.get("LOCAL_RANK") or 0) 52 | 53 | 54 | def get_fs_local_rank() -> int: 55 | """Get the local rank per filesystem, meaning that, regardless of the number of nodes, 56 | if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`, 57 | but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`. 58 | """ 59 | return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank()) 60 | 61 | 62 | def move_to_device(o: T, device: torch.device) -> T: 63 | if isinstance(o, torch.Tensor): 64 | return o.to(device) # type: ignore[return-value] 65 | elif isinstance(o, dict): 66 | return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value] 67 | elif isinstance(o, list): 68 | return [move_to_device(x, device) for x in o] # type: ignore[return-value] 69 | elif isinstance(o, tuple): 70 | return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value] 71 | else: 72 | return o 73 | 74 | 75 | def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False): 76 | """ 77 | Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf`` 78 | is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``. 79 | """ 80 | if check_neg_inf: 81 | x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min) 82 | if check_pos_inf: 83 | x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max) 84 | 85 | 86 | def get_default_device() -> torch.device: 87 | if torch.cuda.is_available() and torch.cuda.is_initialized(): 88 | return torch.device("cuda") 89 | else: 90 | return torch.device("cpu") 91 | 92 | 93 | def barrier() -> None: 94 | if is_distributed(): 95 | dist.barrier() 96 | 97 | 98 | def peak_gpu_memory(reset: bool = False) -> Optional[float]: 99 | """ 100 | Get the peak GPU memory usage in MB across all ranks. 101 | Only rank 0 will get the final result. 102 | """ 103 | if not torch.cuda.is_available(): 104 | return None 105 | 106 | device = torch.device("cuda") 107 | peak_mb = torch.cuda.max_memory_allocated(device) / 1000000 108 | if is_distributed(): 109 | peak_mb_tensor = torch.tensor(peak_mb, device=device) 110 | dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX) 111 | peak_mb = peak_mb_tensor.item() 112 | 113 | if reset: 114 | # Reset peak stats. 
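        # Clears CUDA's peak-allocated counter so the next call measures a fresh window rather
        # than the all-time maximum for the process.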
115 | torch.cuda.reset_max_memory_allocated(device) 116 | 117 | return peak_mb 118 | 119 | 120 | V = TypeVar("V", bool, int, float) 121 | 122 | 123 | def synchronize_value(value: V, device: torch.device) -> V: 124 | if dist.is_available() and dist.is_initialized(): 125 | value_tensor = torch.tensor(value, device=device) 126 | dist.broadcast(value_tensor, 0) 127 | return value_tensor.item() # type: ignore 128 | else: 129 | return value 130 | 131 | 132 | def synchronize_flag(flag: bool, device: torch.device) -> bool: 133 | return synchronize_value(flag, device) 134 | 135 | 136 | def gc_cuda(): 137 | gc.collect() 138 | if torch.cuda.is_available(): 139 | torch.cuda.empty_cache() 140 | -------------------------------------------------------------------------------- /olmo/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import socket 5 | import sys 6 | import time 7 | import warnings 8 | from datetime import datetime 9 | from enum import Enum 10 | from itertools import cycle, islice 11 | from pathlib import Path 12 | from queue import Queue 13 | from threading import Thread 14 | from typing import Any, Callable, Dict, Optional, Union 15 | 16 | import boto3 17 | import botocore.exceptions as boto_exceptions 18 | import rich 19 | from botocore.config import Config 20 | from rich.console import Console, ConsoleRenderable 21 | from rich.highlighter import NullHighlighter 22 | from rich.progress import Progress 23 | from rich.text import Text 24 | from rich.traceback import Traceback 25 | 26 | from .aliases import PathOrStr 27 | from .exceptions import ( 28 | OLMoCliError, 29 | OLMoEnvironmentError, 30 | OLMoError, 31 | OLMoNetworkError, 32 | OLMoThreadError, 33 | ) 34 | from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed 35 | 36 | try: 37 | from functools import cache 38 | except ImportError: 39 | from functools import lru_cache as cache 40 | 41 | 42 | class StrEnum(str, Enum): 43 | """ 44 | This is equivalent to Python's :class:`enum.StrEnum` since version 3.11. 45 | We include this here for compatibility with older version of Python. 46 | """ 47 | 48 | def __str__(self) -> str: 49 | return self.value 50 | 51 | def __repr__(self) -> str: 52 | return f"'{str(self)}'" 53 | 54 | 55 | _log_extra_fields: Dict[str, Any] = {} 56 | log = logging.getLogger(__name__) 57 | 58 | 59 | class LogFilterType(StrEnum): 60 | rank0_only = "rank0_only" 61 | local_rank0_only = "local_rank0_only" 62 | all_ranks = "all_ranks" 63 | 64 | 65 | def log_extra_field(field_name: str, field_value: Any) -> None: 66 | global _log_extra_fields 67 | if field_value is None: 68 | if field_name in _log_extra_fields: 69 | del _log_extra_fields[field_name] 70 | else: 71 | _log_extra_fields[field_name] = field_value 72 | 73 | 74 | def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None: 75 | """ 76 | :param rank0_only: INFO and below messages will only be emitted on the rank0 process. 
77 | """ 78 | log_extra_field("hostname", socket.gethostname()) 79 | if is_distributed(): 80 | log_extra_field("node_rank", get_node_rank()) 81 | log_extra_field("local_rank", get_local_rank()) 82 | log_extra_field("global_rank", get_global_rank()) 83 | else: 84 | log_extra_field("node_rank", 0) 85 | log_extra_field("local_rank", 0) 86 | log_extra_field("global_rank", 0) 87 | 88 | old_log_record_factory = logging.getLogRecordFactory() 89 | 90 | def log_record_factory(*args, **kwargs) -> logging.LogRecord: 91 | record = old_log_record_factory(*args, **kwargs) 92 | for field_name, field_value in _log_extra_fields.items(): 93 | setattr(record, field_name, field_value) 94 | return record 95 | 96 | logging.setLogRecordFactory(log_record_factory) 97 | 98 | handler: logging.Handler 99 | if ( 100 | os.environ.get("OLMo_NONINTERACTIVE", False) 101 | or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive" 102 | or not sys.stdout.isatty() 103 | ): 104 | handler = logging.StreamHandler(sys.stdout) 105 | formatter = logging.Formatter( 106 | "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s" 107 | ) 108 | formatter.default_time_format = "%Y-%m-%d %H:%M:%S" 109 | formatter.default_msec_format = "%s.%03d" 110 | handler.setFormatter(formatter) 111 | else: 112 | handler = RichHandler() 113 | 114 | def rank0_filter(record: logging.LogRecord) -> int: 115 | if record.levelno > logging.INFO: 116 | return 1 117 | if getattr(record, "global_rank", 0) == 0: 118 | return 1 119 | else: 120 | return 0 121 | 122 | def local_rank0_filter(record: logging.LogRecord) -> int: 123 | if record.levelno > logging.INFO: 124 | return 1 125 | if getattr(record, "local_rank", 0) == 0: 126 | return 1 127 | else: 128 | return 0 129 | 130 | if log_filter_type == LogFilterType.rank0_only: 131 | filter = rank0_filter 132 | elif log_filter_type == LogFilterType.local_rank0_only: 133 | filter = local_rank0_filter # type: ignore 134 | elif log_filter_type == LogFilterType.all_ranks: 135 | filter = None 136 | else: 137 | raise ValueError(log_filter_type) 138 | 139 | if filter is not None: 140 | handler.addFilter(filter) # type: ignore 141 | logging.basicConfig(handlers=[handler], level=logging.INFO) 142 | 143 | logging.captureWarnings(True) 144 | logging.getLogger("urllib3").setLevel(logging.ERROR) 145 | 146 | 147 | def excepthook(exctype, value, traceback): 148 | """ 149 | Used to patch `sys.excepthook` in order to log exceptions. 
150 | """ 151 | if issubclass(exctype, KeyboardInterrupt): 152 | sys.__excepthook__(exctype, value, traceback) 153 | elif issubclass(exctype, OLMoCliError): 154 | rich.get_console().print(f"[yellow]{value}[/]", highlight=False) 155 | elif issubclass(exctype, OLMoError): 156 | rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False) 157 | else: 158 | log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback)) 159 | 160 | 161 | def install_excepthook(): 162 | sys.excepthook = excepthook 163 | 164 | 165 | def filter_warnings(): 166 | # Filter internal deprecation warnings from torch 167 | warnings.filterwarnings( 168 | action="ignore", 169 | category=UserWarning, 170 | message="torch.distributed.*_base is a private function and will be deprecated.*", 171 | ) 172 | warnings.filterwarnings( 173 | action="ignore", 174 | category=UserWarning, 175 | message="TypedStorage is deprecated.*", 176 | ) 177 | warnings.filterwarnings( 178 | action="ignore", 179 | category=UserWarning, 180 | message="Please use DTensor instead.*", 181 | ) 182 | # Torchvision warnings. We don't actually use torchvision. 183 | warnings.filterwarnings( 184 | action="ignore", 185 | message="failed to load.*", 186 | module="torchvision.io.image", 187 | ) 188 | 189 | 190 | def set_env_variables(): 191 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 192 | 193 | 194 | def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None): 195 | if log_filter_type is None: 196 | log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only")) 197 | rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True) 198 | setup_logging(log_filter_type=log_filter_type) 199 | install_excepthook() 200 | filter_warnings() 201 | set_env_variables() 202 | 203 | 204 | def clean_opt(arg: str) -> str: 205 | if "=" not in arg: 206 | arg = f"{arg}=True" 207 | name, val = arg.split("=", 1) 208 | name = name.strip("-").replace("-", "_") 209 | return f"{name}={val}" 210 | 211 | 212 | class RichHandler(logging.Handler): 213 | """ 214 | A simplified version of rich.logging.RichHandler from 215 | https://github.com/Textualize/rich/blob/master/rich/logging.py 216 | """ 217 | 218 | def __init__( 219 | self, 220 | *, 221 | level: Union[int, str] = logging.NOTSET, 222 | console: Optional[Console] = None, 223 | markup: bool = False, 224 | ) -> None: 225 | super().__init__(level=level) 226 | self.console = console or rich.get_console() 227 | self.highlighter = NullHighlighter() 228 | self.markup = markup 229 | 230 | def emit(self, record: logging.LogRecord) -> None: 231 | try: 232 | if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"): 233 | self.console.print(record.msg) 234 | else: 235 | msg: Any = record.msg 236 | if isinstance(record.msg, str): 237 | msg = self.render_message(record=record, message=record.getMessage()) 238 | renderables = [ 239 | self.get_time_text(record), 240 | self.get_level_text(record), 241 | self.get_location_text(record), 242 | msg, 243 | ] 244 | if record.exc_info is not None: 245 | tb = Traceback.from_exception(*record.exc_info) # type: ignore 246 | renderables.append(tb) 247 | self.console.print(*renderables) 248 | except Exception: 249 | self.handleError(record) 250 | 251 | def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable: 252 | use_markup = getattr(record, "markup", self.markup) 253 | message_text = Text.from_markup(message) if use_markup else Text(message) 
254 | 255 | highlighter = getattr(record, "highlighter", self.highlighter) 256 | if highlighter: 257 | message_text = highlighter(message_text) 258 | 259 | return message_text 260 | 261 | def get_time_text(self, record: logging.LogRecord) -> Text: 262 | log_time = datetime.fromtimestamp(record.created) 263 | time_str = log_time.strftime("[%Y-%m-%d %X]") 264 | return Text(time_str, style="log.time", end=" ") 265 | 266 | def get_level_text(self, record: logging.LogRecord) -> Text: 267 | level_name = record.levelname 268 | level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}") 269 | level_text.style = "log.level" 270 | level_text.end = " " 271 | return level_text 272 | 273 | def get_location_text(self, record: logging.LogRecord) -> Text: 274 | name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root" 275 | text = f"[{name_and_line}, rank={record.local_rank}]" # type: ignore 276 | return Text(text, style="log.path") 277 | 278 | 279 | def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0): 280 | """Wait for the condition function to return True.""" 281 | start_time = time.monotonic() 282 | while not condition(): 283 | time.sleep(0.5) 284 | if time.monotonic() - start_time > timeout: 285 | raise TimeoutError(f"{description} timed out") 286 | 287 | 288 | def is_url(path: PathOrStr) -> bool: 289 | return re.match(r"[a-z0-9]+://.*", str(path)) is not None 290 | 291 | 292 | def dir_is_empty(dir: PathOrStr) -> bool: 293 | dir = Path(dir) 294 | if not dir.is_dir(): 295 | return True 296 | try: 297 | next(dir.glob("*")) 298 | return False 299 | except StopIteration: 300 | return True 301 | 302 | 303 | def get_progress_bar() -> Progress: 304 | from cached_path import get_download_progress 305 | 306 | return get_download_progress() 307 | 308 | 309 | def resource_path( 310 | folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None 311 | ) -> Path: 312 | if local_cache is not None and (local_path := Path(local_cache) / fname).is_file(): 313 | log.info(f"Found local cache of {fname} at {local_path}") 314 | return local_path 315 | else: 316 | from cached_path import cached_path 317 | 318 | return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress) 319 | 320 | 321 | def file_size(path: PathOrStr) -> int: 322 | """ 323 | Get the size of a local or remote file in bytes. 
324 | """ 325 | if is_url(path): 326 | from urllib.parse import urlparse 327 | 328 | parsed = urlparse(str(path)) 329 | if parsed.scheme == "gs": 330 | return _gcs_file_size(parsed.netloc, parsed.path.strip("/")) 331 | elif parsed.scheme in ("s3", "r2"): 332 | return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/")) 333 | elif parsed.scheme in ("http", "https"): 334 | return _http_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/")) 335 | elif parsed.scheme == "file": 336 | return file_size(str(path).replace("file://", "", 1)) 337 | else: 338 | raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files") 339 | else: 340 | return os.stat(path).st_size 341 | 342 | 343 | def upload(source: PathOrStr, target: str, save_overwrite: bool = False): 344 | """Upload source file to a target location on GCS or S3.""" 345 | from urllib.parse import urlparse 346 | 347 | source = Path(source) 348 | assert source.is_file() 349 | parsed = urlparse(target) 350 | if parsed.scheme == "gs": 351 | _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite) 352 | elif parsed.scheme in ("s3", "r2"): 353 | _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite) 354 | else: 355 | raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme") 356 | 357 | 358 | def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes: 359 | if is_url(source): 360 | from urllib.parse import urlparse 361 | 362 | parsed = urlparse(str(source)) 363 | if parsed.scheme == "gs": 364 | return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes) 365 | elif parsed.scheme in ("s3", "r2"): 366 | return _s3_get_bytes_range( 367 | parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes 368 | ) 369 | elif parsed.scheme in ("http", "https"): 370 | return _http_get_bytes_range( 371 | parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes 372 | ) 373 | elif parsed.scheme == "file": 374 | return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes) 375 | else: 376 | raise NotImplementedError(f"get bytes range not implemented for '{parsed.scheme}' files") 377 | else: 378 | with open(source, "rb") as f: 379 | f.seek(bytes_start) 380 | return f.read(num_bytes) 381 | 382 | 383 | def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: 384 | if is_url(dir): 385 | from urllib.parse import urlparse 386 | 387 | parsed = urlparse(str(dir)) 388 | if parsed.scheme == "gs": 389 | raise NotImplementedError 390 | elif parsed.scheme in ("s3", "r2"): 391 | return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/")) 392 | elif parsed.scheme == "file": 393 | return find_latest_checkpoint(str(dir).replace("file://", "", 1)) 394 | else: 395 | raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files") 396 | else: 397 | latest_step = 0 398 | latest_checkpoint: Optional[Path] = None 399 | for path in Path(dir).glob("step*"): 400 | if path.is_dir(): 401 | try: 402 | step = int(path.name.replace("step", "").replace("-unsharded", "")) 403 | except ValueError: 404 | continue 405 | # We prioritize sharded checkpoints over unsharded checkpoints. 
406 | if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")): 407 | latest_step = step 408 | latest_checkpoint = path 409 | return latest_checkpoint 410 | 411 | 412 | def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False): 413 | from google.cloud import storage as gcs 414 | 415 | storage_client = gcs.Client() 416 | bucket = storage_client.bucket(bucket_name) 417 | blob = bucket.blob(key) 418 | if not save_overwrite and blob.exists(): 419 | raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.") 420 | blob.upload_from_filename(source) 421 | 422 | 423 | def _gcs_file_size(bucket_name: str, key: str) -> int: 424 | from google.api_core.exceptions import NotFound 425 | from google.cloud import storage as gcs 426 | 427 | storage_client = gcs.Client() 428 | bucket = storage_client.bucket(bucket_name) 429 | blob = bucket.blob(key) 430 | try: 431 | blob.reload() 432 | except NotFound: 433 | raise FileNotFoundError(f"gs://{bucket_name}/{key}") 434 | assert blob.size is not None 435 | return blob.size 436 | 437 | 438 | def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes: 439 | from google.api_core.exceptions import NotFound 440 | from google.cloud import storage as gcs 441 | 442 | storage_client = gcs.Client() 443 | bucket = storage_client.bucket(bucket_name) 444 | blob = bucket.blob(key) 445 | try: 446 | blob.reload() 447 | except NotFound: 448 | raise FileNotFoundError(f"gs://{bucket_name}/{key}") 449 | return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1) 450 | 451 | 452 | def _get_s3_profile_name(scheme: str) -> Optional[str]: 453 | if scheme == "s3": 454 | # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set. 455 | return os.environ.get("S3_PROFILE") 456 | if scheme == "r2": 457 | profile_name = os.environ.get("R2_PROFILE") 458 | if profile_name is None: 459 | raise OLMoEnvironmentError( 460 | "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?" 461 | ) 462 | 463 | return profile_name 464 | 465 | raise NotImplementedError(f"Cannot get profile name for scheme {scheme}") 466 | 467 | 468 | def _get_s3_endpoint_url(scheme: str) -> Optional[str]: 469 | if scheme == "s3": 470 | return None 471 | if scheme == "r2": 472 | r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL") 473 | if r2_endpoint_url is None: 474 | raise OLMoEnvironmentError( 475 | "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?" 
476 | ) 477 | 478 | return r2_endpoint_url 479 | 480 | raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}") 481 | 482 | 483 | @cache 484 | def _get_s3_client(scheme: str): 485 | session = boto3.Session(profile_name=_get_s3_profile_name(scheme)) 486 | return session.client( 487 | "s3", 488 | endpoint_url=_get_s3_endpoint_url(scheme), 489 | config=Config(retries={"max_attempts": 10, "mode": "standard"}), 490 | use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")), 491 | ) 492 | 493 | 494 | def _wait_before_retry(attempt: int): 495 | time.sleep(min(0.5 * 2**attempt, 3.0)) 496 | 497 | 498 | def _s3_upload( 499 | source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3 500 | ): 501 | err: Optional[Exception] = None 502 | if not save_overwrite: 503 | for attempt in range(1, max_attempts + 1): 504 | try: 505 | _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key) 506 | raise FileExistsError( 507 | f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." 508 | ) 509 | except boto_exceptions.ClientError as e: 510 | if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: 511 | err = None 512 | break 513 | err = e 514 | 515 | if attempt < max_attempts: 516 | log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err) 517 | _wait_before_retry(attempt) 518 | 519 | if err is not None: 520 | raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err 521 | 522 | try: 523 | _get_s3_client(scheme).upload_file(source, bucket_name, key) 524 | except boto_exceptions.ClientError as e: 525 | raise OLMoNetworkError(f"Failed to upload to {scheme}") from e 526 | 527 | 528 | def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int: 529 | err: Optional[Exception] = None 530 | for attempt in range(1, max_attempts + 1): 531 | try: 532 | return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"] 533 | except boto_exceptions.ClientError as e: 534 | if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: 535 | raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e 536 | err = e 537 | 538 | if attempt < max_attempts: 539 | log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err) 540 | _wait_before_retry(attempt) 541 | 542 | raise OLMoNetworkError(f"Failed to get {scheme} file size") from err 543 | 544 | 545 | def _s3_get_bytes_range( 546 | scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3 547 | ) -> bytes: 548 | err: Optional[Exception] = None 549 | for attempt in range(1, max_attempts + 1): 550 | try: 551 | return ( 552 | _get_s3_client(scheme) 553 | .get_object( 554 | Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}" 555 | )["Body"] 556 | .read() 557 | ) 558 | except boto_exceptions.ClientError as e: 559 | if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: 560 | raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e 561 | err = e 562 | except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e: 563 | # ResponseStreamingError (subclass of HTTPClientError) can happen as 564 | # a result of a failed read from the stream (http.client.IncompleteRead). 565 | # Retrying can help in this case. 
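# We treat these as retriable as well and fall through to the shared back-off/retry logic below.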
566 | err = e 567 | 568 | if attempt < max_attempts: 569 | log.warning( 570 | "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err 571 | ) 572 | _wait_before_retry(attempt) 573 | 574 | # When torch's DataLoader intercepts exceptions, it may try to re-raise them 575 | # by recalling their constructor with a single message arg. Torch has some 576 | # logic to deal with the absence of a single-parameter constructor, but it 577 | # doesn't gracefully handle other possible failures in calling such a constructor. 578 | # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting 579 | # in us losing the true exception info. To avoid this, we change the exception 580 | # to a type that has a single-parameter constructor. 581 | raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err 582 | 583 | 584 | def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]: 585 | if not prefix.endswith("/"): 586 | prefix = f"{prefix}/" 587 | response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/") 588 | assert not response["IsTruncated"] # need to handle this if it happens 589 | latest_step = 0 590 | latest_checkpoint: Optional[str] = None 591 | for item in response["CommonPrefixes"]: 592 | prefix = item["Prefix"].strip("/") 593 | checkpoint_name = os.path.split(prefix)[-1] 594 | if not checkpoint_name.startswith("step"): 595 | continue 596 | try: 597 | step = int(checkpoint_name.replace("step", "").replace("-unsharded", "")) 598 | except ValueError: 599 | continue 600 | # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete 601 | # (upload might have failed part way through). 602 | try: 603 | _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml") 604 | except FileNotFoundError: 605 | continue 606 | # We prioritize sharded checkpoints over unsharded ones.
607 | if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")): 608 | latest_step = step 609 | latest_checkpoint = f"{scheme}://ai2-llm/{prefix}" 610 | return latest_checkpoint 611 | 612 | 613 | def _http_file_size(scheme: str, host_name: str, path: str) -> int: 614 | import requests 615 | 616 | response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True) 617 | return int(response.headers.get("content-length")) 618 | 619 | 620 | def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes: 621 | import requests 622 | 623 | response = requests.get( 624 | f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"} 625 | ) 626 | result = response.content 627 | assert ( 628 | len(result) == num_bytes 629 | ), f"expected {num_bytes} bytes, got {len(result)}" # Some web servers silently ignore range requests and send everything 630 | return result 631 | 632 | 633 | def default_thread_count() -> int: 634 | return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4)) 635 | 636 | 637 | def pass_through_fn(fn, *args, **kwargs): 638 | return fn(*args, **kwargs) 639 | 640 | 641 | def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None): 642 | q: Queue = Queue(maxsize=maxsize) 643 | 644 | sentinel = object() 645 | 646 | def fill_queue(): 647 | try: 648 | for value in g: 649 | q.put(value) 650 | except Exception as e: 651 | q.put(e) 652 | finally: 653 | q.put(sentinel) 654 | 655 | thread_name = thread_name or repr(g) 656 | thread = Thread(name=thread_name, target=fill_queue, daemon=True) 657 | thread.start() 658 | 659 | for x in iter(q.get, sentinel): 660 | if isinstance(x, Exception): 661 | raise OLMoThreadError(f"generator thread {thread_name} failed") from x 662 | else: 663 | yield x 664 | 665 | 666 | def roundrobin(*iterables): 667 | """ 668 | Call the given iterables in a round-robin fashion. For example: 669 | ``roundrobin('ABC', 'D', 'EF') --> A D E B F C`` 670 | """ 671 | # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes 672 | num_active = len(iterables) 673 | nexts = cycle(iter(it).__next__ for it in iterables) 674 | while num_active: 675 | try: 676 | for next in nexts: 677 | yield next() 678 | except StopIteration: 679 | # Remove the iterator we just exhausted from the cycle. 680 | num_active -= 1 681 | nexts = cycle(islice(nexts, num_active)) 682 | -------------------------------------------------------------------------------- /olmo/version.py: -------------------------------------------------------------------------------- 1 | _MAJOR = "0" 2 | _MINOR = "3" 3 | # On main and in a nightly release the patch should be one ahead of the last 4 | # released build. 5 | _PATCH = "0" 6 | # This is mainly for nightly builds which have the suffix ".dev$DATE". See 7 | # https://semver.org/#is-v123-a-semantic-version for the semantics. 
8 | _SUFFIX = "" 9 | 10 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 11 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ai2-olmo" 7 | dynamic = ["version"] 8 | readme = "README.md" 9 | description = "Open Language Model (OLMo)" 10 | authors = [ 11 | { name = "Allen Institute for Artificial Intelligence", email = "olmo@allenai.org" } 12 | ] 13 | requires-python = ">=3.8" 14 | license = { file = "LICENSE" } 15 | dependencies = [ 16 | "numpy", 17 | "torch>=2.1,<2.3", 18 | "omegaconf", 19 | "rich", 20 | "boto3", 21 | "google-cloud-storage", 22 | "tokenizers", 23 | "packaging", 24 | "cached_path>=1.6.2", 25 | "transformers", 26 | "jsonlines", 27 | ] 28 | 29 | [project.optional-dependencies] 30 | dev = [ 31 | "ruff", 32 | "mypy>=1.0,<1.4", 33 | "black>=23.1,<24.0", 34 | "isort>=5.12,<5.13", 35 | "pytest", 36 | "pytest-sphinx", 37 | "twine>=1.11.0", 38 | "setuptools", 39 | "wheel", 40 | "build", 41 | ] 42 | train = [ 43 | "wandb", 44 | "beaker-gantry", 45 | "click", 46 | "torchmetrics", 47 | "smashed[remote]>=0.21.1", 48 | "safetensors", 49 | "datasets", 50 | "scikit-learn", 51 | "msgspec>=0.14.0", 52 | ] 53 | all = [ 54 | "ai2-olmo[dev,train]", 55 | ] 56 | 57 | [project.urls] 58 | Homepage = "https://github.com/allenai/OLMo" 59 | Repository = "https://github.com/allenai/OLMo" 60 | 61 | [tool.setuptools] 62 | include-package-data = true 63 | 64 | [tool.setuptools.package-data] 65 | olmo = ["py.typed"] 66 | 67 | [tool.setuptools.dynamic] 68 | version = { attr = "olmo.version.VERSION" } 69 | 70 | [tool.setuptools.packages.find] 71 | include = ["olmo*", "hf_olmo*"] 72 | exclude = [ 73 | "*.tests", 74 | "*.tests.*", 75 | "tests.*", 76 | "tests", 77 | "test_fixtures", 78 | "test_fixtures.*", 79 | "docs*", 80 | "scripts*", 81 | "olmo_tokenizer.*", 82 | "evaluation.*", 83 | "pretrain_data.*", 84 | "tmp_*", 85 | "inference.*", 86 | ] 87 | 88 | [tool.black] 89 | line-length = 115 90 | include = '\.pyi?$' 91 | exclude = ''' 92 | ( 93 | __pycache__ 94 | | \.git 95 | | \.mypy_cache 96 | | \.pytest_cache 97 | | \.vscode 98 | | \.venv 99 | | \bdist\b 100 | | \bdoc\b 101 | | pretrain_data/ 102 | | inference/ 103 | ) 104 | ''' 105 | 106 | [tool.isort] 107 | profile = "black" 108 | multi_line_output = 3 109 | extend_skip = ["pretrain_data", "tokenizer"] 110 | 111 | [tool.ruff] 112 | line-length = 115 113 | lint.ignore = ["F403", "F405", "E501"] 114 | exclude = [ 115 | ".bzr", 116 | ".direnv", 117 | ".eggs", 118 | ".git", 119 | ".venv", 120 | "venv", 121 | ".mypy_cache", 122 | "__pycache__", 123 | ".nox", 124 | ".pants.d", 125 | ".pytype", 126 | ".ruff_cache", 127 | ".svn", 128 | ".tox", 129 | "__pypackages__", 130 | "_build", 131 | "buck-out", 132 | "build", 133 | "dist", 134 | "node_modules", 135 | "doc", 136 | "pretrain_data", 137 | "inference", 138 | ] 139 | 140 | [tool.ruff.lint.per-file-ignores] 141 | "**/__init__.py" = ["F401"] 142 | 143 | [tool.pyright] 144 | reportPrivateImportUsage = false 145 | exclude = ["pretrain_data/", "tokenizer/"] 146 | 147 | [tool.mypy] 148 | ignore_missing_imports = true 149 | no_site_packages = true 150 | check_untyped_defs = true 151 | exclude = ["pretrain_data/", "inference/compression/dependencies/", "inference/efficiency/dependencies/"] 
152 | no_namespace_packages = true 153 | 154 | [[tool.mypy.overrides]] 155 | module = "tests.*" 156 | strict_optional = false 157 | 158 | [tool.pytest.ini_options] 159 | testpaths = "tests/" 160 | python_classes = [ 161 | "Test*", 162 | "*Test", 163 | ] 164 | log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" 165 | log_level = "DEBUG" 166 | markers = [ 167 | "gpu", 168 | ] 169 | filterwarnings = [ 170 | 'ignore::FutureWarning:huggingface_hub\.file_download', 171 | 'ignore::DeprecationWarning:pkg_resources', 172 | 'ignore::DeprecationWarning:google\.rpc', 173 | ] 174 | -------------------------------------------------------------------------------- /scripts/get_dataorder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the target file path 4 | FILE_PATH="data/global_indices/1B/global_indices.npy" 5 | 6 | # Define the download URL 7 | DOWNLOAD_URL="https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy" 8 | 9 | # Check if the file exists 10 | if [ -f "$FILE_PATH" ]; then 11 | echo "File already exists at: $FILE_PATH. Exit" 12 | else 13 | mkdir -p "$(dirname "$FILE_PATH")" 14 | wget -O "$FILE_PATH" "$DOWNLOAD_URL" 15 | echo "Download complete at $FILE_PATH" 16 | fi -------------------------------------------------------------------------------- /scripts/get_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Assign the single path to a variable 4 | CKPT_PATH="$1" 5 | echo "Processing checkpoint path: $CKPT_PATH" 6 | 7 | SCRIPT_DIR=$(dirname "$(realpath "$BASH_SOURCE")") 8 | ROOT_DIR=$(realpath "$SCRIPT_DIR/..") 9 | echo "Root dir is $ROOT_DIR" 10 | 11 | # Extract the number after "step" 12 | step_number=$(echo "$CKPT_PATH" | sed -n 's/.*step\([0-9]*\)-unsharded\/.*/\1/p') 13 | echo "Extracted step number: $step_number" 14 | 15 | OUTPUT_PATH="checkpoints/pretrained_1B/$step_number-unsharded" 16 | 17 | # Create the output directory if it doesn't exist 18 | mkdir -p "$OUTPUT_PATH" 19 | # mkdir -p "$OUTPUT_PATH/hf" 20 | 21 | # Change to the output directory 22 | cd "$OUTPUT_PATH" 23 | 24 | # Download the config.yaml file 25 | wget "${CKPT_PATH}config.yaml" 26 | wget "${CKPT_PATH}model.pt" 27 | wget "${CKPT_PATH}train.pt" 28 | 29 | cd "$ROOT_DIR" 30 | 31 | # python scripts/convert_olmo_to_hf_new.py --input_dir "${OUTPUT_PATH}" --output_dir "${OUTPUT_PATH}/hf" --tokenizer_json_path tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json 32 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """Run this script with 'torchrun'.""" 2 | import gzip 3 | import logging 4 | import sys 5 | from pathlib import Path 6 | from typing import Optional, TextIO 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | import wandb 12 | from packaging import version 13 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 14 | from torch.distributed.fsdp import ShardingStrategy 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | 17 | from olmo.config import ( 18 | CheckpointType, 19 | DDPGradSyncMode, 20 | DistributedStrategy, 21 | TrainConfig, 22 | ) 23 | from olmo.data import build_train_dataloader, build_original_train_dataloader 24 | from olmo.eval import build_evaluators, build_custom_evaluators 25 | from olmo.exceptions 
import OLMoCliError, OLMoConfigurationError 26 | from olmo.model import OLMo 27 | from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler 28 | from olmo.torch_util import ( 29 | barrier, 30 | get_default_device, 31 | get_global_rank, 32 | get_local_rank, 33 | get_world_size, 34 | peak_gpu_memory, 35 | seed_all, 36 | ) 37 | from olmo.train import Trainer 38 | from olmo.util import clean_opt, log_extra_field, prepare_cli_environment 39 | 40 | log = logging.getLogger("train") 41 | 42 | 43 | def main(cfg: TrainConfig) -> None: 44 | # Ensure run name set. 45 | if cfg.run_name is None: 46 | raise OLMoConfigurationError("--run_name is required") 47 | log_extra_field("run_name", cfg.run_name) 48 | 49 | # Sanity check 50 | if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None: 51 | log.warning( 52 | "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The" 53 | "setting has no effect." 54 | ) 55 | 56 | barrier() 57 | 58 | # Set CUDA device. 59 | torch.cuda.set_device(f"cuda:{get_local_rank()}") 60 | device = torch.device("cuda") 61 | 62 | # Fill some configuration options. 63 | cfg.model.precision = cfg.precision 64 | cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size() 65 | assert cfg.device_train_batch_size is not None # for mypy 66 | cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size 67 | if cfg.optimizer.no_decay_norm_and_bias is not None: 68 | log.warning( 69 | "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this" 70 | "setting will take precedence over all other weight decay configurations. Please change" 71 | "your config to use `decay_norm_and_bias` and `decay_embeddings` instead." 72 | ) 73 | cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias 74 | cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias 75 | cfg.optimizer.no_decay_norm_and_bias = None # So nobody uses this by accident. 76 | 77 | # Display and save configuration. 78 | if get_global_rank() == 0: 79 | if cfg.data.paths is not None and len(cfg.data.paths) < 50: 80 | log.info("Configuration:") 81 | log.info(cfg) 82 | if not cfg.dry_run and (cfg.load_path is None or Path(cfg.load_path).parent != Path(cfg.save_folder)): 83 | # Save config. 84 | save_path = Path(cfg.save_folder) / "config.yaml" 85 | if save_path.is_file() and not cfg.save_overwrite: 86 | raise OLMoConfigurationError(f"{save_path} already exists, use --save_overwrite to overwrite") 87 | else: 88 | log.info(f"Saving config to {save_path}") 89 | save_path.parent.mkdir(exist_ok=True, parents=True) 90 | cfg.save(save_path) 91 | del save_path 92 | 93 | barrier() 94 | 95 | # Maybe start W&B run. 96 | if cfg.wandb is not None and (get_global_rank() == 0 or not cfg.wandb.rank_zero_only): 97 | wandb_dir = Path(cfg.save_folder) / "wandb" 98 | wandb_dir.mkdir(parents=True, exist_ok=True) 99 | wandb.init( 100 | dir=wandb_dir, 101 | project=cfg.wandb.project, 102 | entity=cfg.wandb.entity, 103 | group=cfg.wandb.group, 104 | name=cfg.wandb.name, 105 | tags=cfg.wandb.tags, 106 | config=cfg.asdict(exclude=["wandb"]), 107 | ) 108 | 109 | barrier() 110 | 111 | # Set seed. 112 | seed_all(cfg.seed) 113 | 114 | # Construct data loader. 115 | if cfg.data.dataset_path is not None: 116 | train_loader = build_train_dataloader(cfg) 117 | else: 118 | train_loader = build_original_train_dataloader(cfg) 119 | # Construct evaluators. 
120 | evaluators = build_evaluators(cfg, device) 121 | barrier() 122 | 123 | # Initialize the model. 124 | log.info("Building model...") 125 | olmo_model = OLMo(cfg.model) 126 | log.info(f"Total number of parameters: {olmo_model.num_params():,d}") 127 | log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}") 128 | log.info(f"Peak GPU Memory (MB) before {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}") 129 | 130 | olmo_model.set_activation_checkpointing(cfg.activation_checkpointing) 131 | 132 | if cfg.distributed_strategy == DistributedStrategy.ddp: 133 | log.info("Wrapping model with DDP...") 134 | assert cfg.ddp is not None, "DistributedStrategy ddp needs cfg.ddp to be set!" 135 | 136 | if cfg.model.init_device != "cuda": 137 | raise OLMoConfigurationError("DDP does not work with init_device set to anything other than `cuda`.") 138 | 139 | if cfg.ddp.find_unused_params is True and cfg.ddp.grad_sync_mode != DDPGradSyncMode.micro_batch: 140 | raise OLMoConfigurationError( 141 | "`find_unused_params` is set to True. DDP needs to synchronize gradients for every micro-batch to avoid errors. Set `grad_sync_mode` to `micro_batch`." 142 | ) 143 | 144 | param_init_fn = None 145 | 146 | # move to cuda before calling ddp 147 | dist_model = DDP(olmo_model.to(device), find_unused_parameters=cfg.ddp.find_unused_params) 148 | elif cfg.distributed_strategy == DistributedStrategy.fsdp: 149 | log.info("Wrapping model with FSDP...") 150 | assert cfg.fsdp is not None, "DistributedStrategy fsdp needs cfg.fsdp to be set!" 151 | wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy) 152 | 153 | if version.parse(torch.__version__) >= version.parse("2.1.0"): 154 | # This prevents any parameters from being initialized twice 155 | def dummy_init_fn(module: torch.nn.Module) -> None: 156 | module.to_empty(device=get_default_device()) 157 | 158 | param_init_fn = dummy_init_fn 159 | else: 160 | param_init_fn = None 161 | 162 | dist_model = FSDP( 163 | olmo_model, 164 | sharding_strategy=cfg.fsdp.sharding_strategy, 165 | mixed_precision=cfg.fsdp_precision, 166 | auto_wrap_policy=wrap_policy, 167 | use_orig_params=cfg.fsdp.use_orig_params, # needed for compile and some of our optimizer/parameter metrics 168 | limit_all_gathers=True, 169 | device_id=get_local_rank(), 170 | param_init_fn=param_init_fn, 171 | ) 172 | elif cfg.distributed_strategy is None: 173 | raise NotImplementedError("Single accelerator training not implemented yet!") 174 | 175 | # when param_init_fn is None, FSDP will call reset_parameters() automatically 176 | if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp: 177 | olmo_model.reset_parameters() 178 | 179 | log.info(f"Peak GPU Memory (MB) after {cfg.distributed_strategy}: {int(peak_gpu_memory() or 0)}") 180 | log.info("Model:") 181 | log.info(dist_model) 182 | 183 | # Construct optimizer and learning rate scheduler. 184 | optim = build_optimizer(cfg, dist_model) 185 | scheduler = build_scheduler(cfg) 186 | 187 | # Data indices file.
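# (When cfg.save_data_indices is set, each rank writes the dataset indices it consumes to a gzipped TSV under save_folder/data-indices/, so the exact data order can be recovered later.)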
188 | indices_file: Optional[TextIO] = None 189 | if cfg.save_data_indices: 190 | indices_file_path = Path(cfg.save_folder) / f"data-indices/rank{get_global_rank()}.tsv.gz" 191 | if indices_file_path.exists() and not cfg.save_overwrite: 192 | raise OLMoConfigurationError(f"{indices_file_path} already exists, use --save_overwrite to overwrite") 193 | indices_file_path.parent.mkdir(exist_ok=True, parents=True) 194 | indices_file = gzip.open(indices_file_path, "wt") 195 | 196 | # Consolidate components into `Trainer` object. 197 | with Trainer( 198 | cfg=cfg, 199 | epoch=cfg.epoch, 200 | model=olmo_model, 201 | fsdp_model=dist_model, 202 | optim=optim, 203 | scheduler=scheduler, 204 | train_loader=train_loader, 205 | # custom_loader=custom_loader, 206 | device=device, 207 | evaluators=evaluators, 208 | indices_file=indices_file, 209 | ) as trainer: 210 | if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None: 211 | # checkpoint_type = ( 212 | # CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded 213 | # ) 214 | if cfg.distributed_strategy == DistributedStrategy.ddp: 215 | checkpoint_type = CheckpointType.unsharded 216 | 217 | if cfg.save_interval_unsharded is None: 218 | log.warning( 219 | "DDP requires setting `save_interval_unsharded`. Using the value set for `save_interval`." 220 | ) 221 | cfg.save_interval_unsharded = cfg.save_interval 222 | 223 | if cfg.save_num_unsharded_checkpoints_to_keep == 0: 224 | log.warning( 225 | "DDP requires setting `save_num_unsharded_checkpoints_to_keep`. Using the value set for `save_num_checkpoints_to_keep`." 226 | ) 227 | cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep 228 | elif cfg.distributed_strategy == DistributedStrategy.fsdp: 229 | checkpoint_type = ( 230 | CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded 231 | ) 232 | else: 233 | raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!") 234 | 235 | # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever). 236 | log.info("Saving pre-train checkpoint...") 237 | checkpoint_path, local_checkpoint_cache = trainer.save_checkpoint(checkpoint_type=checkpoint_type) 238 | log.info(f"Checkpoint saved to {checkpoint_path}") 239 | 240 | # And then we verify that we can load it.
241 | log.info("Attempting to load pre-train checkpoint...") 242 | trainer.restore_checkpoint( 243 | checkpoint_path, checkpoint_type=checkpoint_type, local_cache=local_checkpoint_cache 244 | ) 245 | log.info("Checkpoint successfully loaded") 246 | 247 | # NOTE: https://github.com/allenai/LLM/issues/233 248 | # log.info("Removing pre-train checkpoint...") 249 | # trainer.remove_checkpoint(checkpoint_type=checkpoint_type) 250 | # log.info("Successfully removed checkpoint") 251 | 252 | if cfg.load_path is not None: 253 | log.info(f"Loading checkpoint from {cfg.load_path}...") 254 | trainer.restore_checkpoint( 255 | cfg.load_path, 256 | load_optimizer_state=not cfg.reset_optimizer_state, 257 | load_trainer_state=not cfg.reset_trainer_state, 258 | sharded_checkpointer=cfg.load_path_sharded_checkpointer, 259 | ) 260 | log.info("Checkpoint successfully loaded") 261 | 262 | # If we have to, set a new scheduler: 263 | if cfg.reset_optimizer_state and not cfg.reset_trainer_state: 264 | trainer.scheduler = BoltOnWarmupScheduler.wrap( 265 | trainer.scheduler, 266 | trainer.global_step, 267 | int(trainer.global_step + cfg.scheduler.t_warmup), 268 | ) 269 | 270 | 271 | if cfg.force_save_unsharded: 272 | log.info("Saving unsharded checkpoint...") 273 | checkpoint_path, _ = trainer.save_checkpoint(checkpoint_type=CheckpointType.unsharded) 274 | log.info(f"Unsharded checkpoint saved to {checkpoint_path}") 275 | 276 | if cfg.compile is not None: 277 | # TODO (epwalsh): trying to compile the whole train step results in a compile-time error from within 278 | # the optimizer. We should investigate this further at some point. 279 | # trainer.train_step = torch.compile(trainer.train_step, **cfg.compile.asdict()) 280 | trainer.train_batch = torch.compile(trainer.train_batch, **cfg.compile.asdict()) # type: ignore 281 | # TODO (epwalsh): compiling the `eval_batch()` method is a little sketchy since the inputs will look 282 | # different for different eval tasks. That might be okay, but it might not be. 283 | # trainer.eval_batch = torch.compile(trainer.eval_batch, **cfg.compile.asdict()) # type: ignore 284 | # Alternatively, could just do this: 285 | # trainer.fsdp_model = torch.compile(trainer.fsdp_model, **cfg.compile.asdict()) 286 | 287 | if not cfg.dry_run: 288 | log.info("Starting training...") 289 | trainer.fit() 290 | log.info("Training complete") 291 | else: 292 | log.info("Dry run complete") 293 | 294 | 295 | if __name__ == "__main__": 296 | try: 297 | mp.set_start_method("spawn", force=True) 298 | except RuntimeError as e: 299 | print(f"failed to set multiprocessing start method: {e}") 300 | 301 | # Initialize process group. 302 | dist.init_process_group(backend="nccl") 303 | 304 | prepare_cli_environment() 305 | 306 | log.info(f"multiprocessing start method set to '{mp.get_start_method()}'") 307 | 308 | try: 309 | yaml_path, args_list = sys.argv[1], sys.argv[2:] 310 | except IndexError: 311 | raise OLMoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]") 312 | 313 | cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list]) 314 | main(cfg) 315 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=29599 -m scripts.train configs/1B/1B_bs128_lr4e4_pubmed_1ep_738k.yaml --------------------------------------------------------------------------------
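Usage note: scripts/train.py passes every trailing command-line argument through olmo.util.clean_opt before handing it to TrainConfig.load, so dotted overrides can simply be appended to the torchrun command in scripts/train.sh after the YAML path. A minimal sketch of that normalization, assuming illustrative override keys (optimizer.learning_rate, wandb.name) that are not verified against the shipped configs:

# Minimal sketch: how scripts/train.py normalizes trailing CLI arguments with
# olmo.util.clean_opt before TrainConfig.load applies them as dotted overrides.
# The override keys below are illustrative assumptions, not a canonical list.
from olmo.util import clean_opt

raw_args = [
    "--save_overwrite",                # bare flag -> "save_overwrite=True"
    "--optimizer.learning_rate=4e-4",  # leading dashes are stripped from the key, dots are kept
    "--wandb.name=pubmed-1ep-738k",    # dashes inside the value are preserved
]
print([clean_opt(a) for a in raw_args])
# ['save_overwrite=True', 'optimizer.learning_rate=4e-4', 'wandb.name=pubmed-1ep-738k']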