├── .gitignore
├── LICENSE
├── README.md
├── llama
│   ├── install_deps.sh
│   ├── llama_landmark_config.py
│   ├── llama_mem.py
│   ├── ltriton
│   │   ├── flash_landmark_attention.py
│   │   └── test_flash_landmark_attention.py
│   ├── redpajama.py
│   ├── requirements.txt
│   ├── run_test.py
│   ├── train.py
│   ├── urls
│   │   ├── arxiv.txt
│   │   ├── book.txt
│   │   ├── c4.txt
│   │   ├── common_crawl.txt
│   │   ├── github.txt
│   │   ├── stackexchange.txt
│   │   └── wikipedia.txt
│   └── weight_diff.py
├── llama_legacy
│   ├── llama_mem.py
│   ├── redpajama.py
│   ├── requirements.txt
│   ├── run_test.py
│   ├── train.py
│   ├── urls
│   │   ├── arxiv.txt
│   │   ├── book.txt
│   │   ├── c4.txt
│   │   ├── common_crawl.txt
│   │   ├── github.txt
│   │   ├── stackexchange.txt
│   │   └── wikipedia.txt
│   └── weight_diff.py
└── lm_benchmark
    ├── config
    │   ├── __init__.py
    │   └── rotary.py
    ├── data
    │   ├── __init__.py
    │   ├── arxiv_math.py
    │   ├── pg19.py
    │   ├── pg19
    │   │   ├── README.md
    │   │   └── prepare.py
    │   ├── proof-pile
    │   │   └── prepare.py
    │   └── utils.py
    ├── distributed
    │   ├── __init__.py
    │   ├── backend.py
    │   ├── ddp.py
    │   └── single.py
    ├── eval.py
    ├── eval_cmd_generator.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── base_new.py
    │   ├── caches
    │   │   ├── __init__.py
    │   │   ├── cache.py
    │   │   ├── kv_cache.py
    │   │   ├── kv_cache_train.py
    │   │   └── mem_cache.py
    │   ├── landmark.py
    │   ├── landmark_with_cmt.py
    │   └── positional_encoders
    │       ├── __init__.py
    │       ├── encoder.py
    │       ├── rotary.py
    │       ├── rotary_mem_jump.py
    │       └── rotary_utils.py
    ├── optim
    │   ├── base.py
    │   ├── transformer_xl.py
    │   └── utils.py
    └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Dataset folder
2 | data/datasets/
3 | wandb/
4 | exps/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | .DS_Store 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Landmark Attention 2 | 3 | This repository contains the implementation of landmark attention as described in our paper: 4 | 5 | **Landmark Attention: Random-Access Infinite Context Length for Transformers**
6 | Amirkeivan Mohtashami, Martin Jaggi
7 | NeurIPS 2023: https://arxiv.org/abs/2305.16300
8 |
9 | ## Repository Structure
10 |
11 | The repository contains three code bases under the following directories:
12 |
13 | 1. `lm_benchmark`: This directory contains the code used for performing language modeling over the PG19 and arXiv Math datasets.
14 | 2. `llama_legacy`: This directory contains the code used to obtain the results of fine-tuning LLaMA as reported in the paper. The code in this directory is frozen to allow reproduction of the results. Thus, except when trying to exactly replicate our results, we suggest using the code under the `llama` directory.
15 | 3. `llama`: This directory contains the current implementation of landmark attention. The directory includes both a high-level implementation and a Triton implementation of landmark attention combined with Flash Attention. As an example, the directory contains the code for applying the implementation to LLaMA models.
16 |
17 |
18 | Note: During the development of this project, we decided to update the names of certain components. However, as this decision was made later in the project timeline, you may encounter references to the old names within the code (e.g. `mem` instead of `landmark`). We are working to address this issue.
19 |
20 |
21 | ## Language Modeling Benchmarks
22 | ### Training
23 | For training, the landmark tokens are added during data preparation (see the sketch at the end of this subsection). The following command is an example of training a model on PG19 with landmark tokens added every 50 tokens:
24 | ```
25 | python main.py \
26 |     --config_format rotary \
27 |     --model landmark \
28 |     --n_embd 1024 \
29 |     --n_head 8 \
30 |     --n_layer 12 \
31 |     --batch_size 16 \
32 |     --sequence_length 512 \
33 |     --acc_steps 8 \
34 |     --wandb_project memory-llm \
35 |     --dataset pg19 \
36 |     --iterations 240000 \
37 |     --dropout 0.0 \
38 |     --positional_encoder rotary \
39 |     --softmax_func mem_opt \
40 |     --mem_freq 50 \
41 |     --wandb \
42 |     --save_checkpoint_freq 20000
43 | ```
44 |
45 | To run on multiple GPUs, use torchrun (e.g. `torchrun --nproc_per_node=4`) and pass `--distributed_backend nccl` to the `main.py` script. We suggest first running the script on a single GPU until training starts before switching to a multi-GPU setting: the first node has to initialize the dataset, which can take long enough to cause a timeout during synchronization in multi-GPU settings. Once the initialization has been performed, the result is cached on disk, so subsequent runs start quickly.
46 |
47 | You will need to initialize the dataset before running the training script. For instructions, use the `prepare.py` script in the corresponding dataset folder located inside `data/`.
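As a rough illustration of this preparation step, the following sketch interleaves a landmark token into a plain stream of token ids at a fixed interval; it mirrors the `add_mem_tokens` helper in `llama/train.py`. The function name and the `landmark_id` value are placeholders for illustration; in practice the id is whatever the tokenizer assigns to the landmark token.

```
# Minimal sketch: insert a landmark token every `mem_freq` tokens.
# `landmark_id` is a placeholder; use the id assigned to the landmark
# token after it is added to the tokenizer's vocabulary.
def insert_landmarks(token_ids, mem_freq=50, landmark_id=32001):
    out = []
    for i, tok in enumerate(token_ids, start=1):
        out.append(tok)
        if i % mem_freq == 0:
            out.append(landmark_id)  # landmark closes the current block
    return out

# e.g. with mem_freq=3: [1, 2, 3, 4, 5] -> [1, 2, 3, <landmark>, 4, 5]
```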
48 |
49 | ### Inference
50 | The code supports inference in various settings. To perform standard evaluation, disable the cache and use the same chunk size (specified using the `--mid_length` flag) as the evaluation length (specified by `--eval_seq_length`). Using landmarks is possible when using `mem_cache`. The script `eval_cmd_generator.py` can be used to generate a bash script containing the commands to perform the evaluations corresponding to Tables 1 and 2 of the paper. The paths of the output models need to be updated inside the script.
51 |
52 | ## LLaMA fine-tuning
53 | The code for fine-tuning LLaMA and testing the final model is available as a standalone project in the sub-directory "llama". An example of running the fine-tuning (from inside the sub-directory) is:
54 |
55 | ```
56 | torchrun --nproc_per_node=8 train.py \
57 |     --model_name_or_path /llama_weights/7B_hf/ \
58 |     --bf16 True \
59 |     --output_dir /llama-redpajama-mem-15000-with-mem/ \
60 |     --cache_dir /hf-cache/ \
61 |     --num_train_epochs 1 \
62 |     --per_device_train_batch_size 2 \
63 |     --per_device_eval_batch_size 2 \
64 |     --gradient_accumulation_steps 8 \
65 |     --evaluation_strategy "no" \
66 |     --save_strategy "steps" \
67 |     --save_steps 2000 \
68 |     --save_total_limit 2 \
69 |     --learning_rate 2e-5 \
70 |     --weight_decay 0.1 \
71 |     --warmup_ratio 0.03 \
72 |     --lr_scheduler_type "cosine" \
73 |     --logging_steps 1 \
74 |     --fsdp "full_shard auto_wrap" \
75 |     --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
76 |     --tf32 True \
77 |     --max_steps 15000
78 | ```
79 |
80 | In the above example, the LLaMA weights (converted to Hugging Face format) should be in `/llama_weights/7B_hf/`.
81 |
82 | ### Fine-tuned Weights
83 | We have released the weight diff between the original LLaMA 7B and the same model fine-tuned for 15000 steps on the [RedPajama](https://github.com/togethercomputer/RedPajama-Data) dataset with landmark attention [here](https://huggingface.co/epfml/landmark-attention-llama7b-wdiff). You may use the `weight_diff.py` script to recover the weights:
84 | ```
85 | python weight_diff.py recover --path_raw <path_raw> --path_diff <path_diff> --path_tuned <path_tuned>
86 | ```
87 | For an example of how to perform inference using landmarks, look at `run_test.py`.
88 |
89 | ### Triton Implementation
90 |
91 | We have added a Triton implementation of the combination of our method and Flash Attention, which significantly reduces memory usage and also improves performance. Using this implementation, we trained LLaMA 7B with a 2048 context length (instead of 512). Also, adding landmark attention to any model can be done by applying the following changes (a sketch follows the list):
92 |
93 | 1. Adding landmark tokens to the input at regular intervals of the block size.
94 | 2. (Optional) Creating a boolean mask of which tokens are landmarks. The mask can be passed to the landmark attention function to ensure the landmarks are placed correctly. This step can be skipped to obtain the highest speed.
95 | 3. Replacing `torch.nn.functional.scaled_dot_product_attention` with `fused_landmark_attention`.
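As a rough illustration of these three steps, the following sketch mirrors the usage in `llama/ltriton/test_flash_landmark_attention.py` (the shapes, dtype, and `mem_freq` value are taken from that test); treat it as a sketch rather than a drop-in patch:

```
import torch
from flash_landmark_attention import fused_landmark_attention

batch, nheads, seqlen, d = 2, 8, 1024, 128
mem_freq = 63                  # tokens between landmarks
block_size = mem_freq + 1      # each landmark closes one block

q = torch.rand(batch, nheads, seqlen, d, device="cuda", dtype=torch.bfloat16)
k = torch.rand(batch, nheads, seqlen, d, device="cuda", dtype=torch.bfloat16)
v = torch.rand(batch, nheads, seqlen, d, device="cuda", dtype=torch.bfloat16)

# Step 2: boolean mask marking which positions are landmark tokens
# (the last position of every block).
is_mem = torch.arange(seqlen, device="cuda") % block_size == block_size - 1

# Step 3: call the fused kernel where scaled_dot_product_attention
# would normally be used.
out = fused_landmark_attention(q, k, v, is_mem, block_size=block_size)
```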
96 |
97 | Note that the implementation relies on the latest version of Triton, which conflicts with the latest release of PyTorch. Therefore, a special `install_deps.sh` script is provided to install the dependencies.
98 |
99 | Finally, note that the current implementation makes the following assumptions:
100 |
101 | 1. The implementation assumes the landmark blocks have the same size as the blocks used for computing the attention in Flash Attention. This limits the maximum block size, as the whole landmark block should fit into the GPU's local memory. However, using bfloat16, it should be possible to use block sizes as large as 64 or 128, which should be enough for landmark blocks.
102 | 2. The implementation assumes the difference between the number of keys and queries is a multiple of the block size. Therefore, normal attention must be applied in the auto-regressive part of generation, when tokens are generated one by one. The implementation can still be used to process the input before generation starts.
103 | Note that this is not a big limitation since, when generating tokens one by one, the attention matrix has only a single row, limiting the benefits of Flash Attention.
104 | 3. While the high-level implementation allows the landmark tokens to be placed anywhere, the fused implementation assumes the landmark tokens are placed regularly at the end of each block. Since we always use this pattern at inference, this limitation should not be noticeable in practice.
105 |
106 |
--------------------------------------------------------------------------------
/llama/install_deps.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip install -r requirements.txt
4 | pip install "git+https://github.com/openai/triton.git#subdirectory=python"
--------------------------------------------------------------------------------
/llama/llama_landmark_config.py:
--------------------------------------------------------------------------------
1 | from transformers.models.llama.configuration_llama import LlamaConfig
2 |
3 | class LlamaLandmarkConfig(LlamaConfig):
4 |     model_type = "llama_with_landmark"
5 |
6 |     def __init__(
7 |         self,
8 |         mem_id=32001,
9 |         mem_freq=50,
10 |         train_context_length=512,
11 |         include_landmark_in_loss=True,
12 |         **kwargs,
13 |     ):
14 |         self.mem_id = mem_id
15 |         self.mem_freq = mem_freq
16 |         self.train_context_length = train_context_length
17 |         self.include_landmark_in_loss = include_landmark_in_loss
18 |         super().__init__(**kwargs)
--------------------------------------------------------------------------------
/llama/ltriton/test_flash_landmark_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | from flash_landmark_attention import fused_landmark_attention
5 |
6 |
7 | class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
8 |
9 |     # Note that forward, setup_context, and backward are @staticmethods
10 |     @staticmethod
11 |     def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
12 |         new_shape = list(x.shape)
13 |         new_shape[dim] = mem_cnt # max_mem_cnt.item()
14 |         max_by_group = x.new_zeros((*new_shape,))
15 |         max_by_group.scatter_reduce_(src=x, index=resp_mem_idx, dim=dim, reduce="amax", include_self=False)
16 |
17 |         maxes = torch.gather(max_by_group, dim, resp_mem_idx)
18 |         #x_exp = torch.exp(x - torch.where(torch.isinf(maxes), 0, maxes))
19 |         x_exp = torch.exp((x - maxes).to(torch.float32))
20 |
21 |         cumsum_by_group = torch.zeros_like(max_by_group, dtype=x_exp.dtype)
22 |
23 |         cumsum_by_group.scatter_add_(dim, resp_mem_idx, x_exp)
24 |         denom = torch.gather(cumsum_by_group, dim, resp_mem_idx)
25 |
26 |         #probs = torch.where(denom < 0.5, 0, x_exp / denom)
27 |         probs = x_exp / denom
28 |
29 |
30 |         ctx.mem_cnt = mem_cnt
31 |         ctx.dim = dim
32 |         ctx.save_for_backward(resp_mem_idx, probs)
33 |
34 |         return probs
35 |
36 |     @staticmethod
37 |     def backward(ctx, grad_probs):
38 |         mem_cnt = ctx.mem_cnt
39 |         dim = ctx.dim
40 |         resp_mem_idx, probs = ctx.saved_tensors
41 |         grad_x = grad_dim = grad_mem_cnt = grad_resp_mem_idx = None
42 |
43 |         if ctx.needs_input_grad[0]:  # forward has four inputs, so only indices 0-3 exist
44 |             grad_pair = grad_probs * probs
45 |
46 |             new_shape = list(probs.shape)
47 |             new_shape[dim] = mem_cnt # max_mem_cnt.item()
48 |             cumsum_by_group = grad_pair.new_zeros((*new_shape,))
49 |             cumsum_by_group.scatter_add_(dim, resp_mem_idx, grad_pair)
50 |
51 |
52 |             if ctx.needs_input_grad[0]:
53 |                 grad_sum = torch.gather(cumsum_by_group, dim, resp_mem_idx)
54 |                 grad_x = grad_pair - probs * grad_sum
55 |         assert not ctx.needs_input_grad[1]
56 |         assert not ctx.needs_input_grad[2]
57 |         assert not ctx.needs_input_grad[3]
58 |
59 |         return grad_x, grad_dim, grad_mem_cnt, grad_resp_mem_idx
60 |
61 | def landmark_grouped_softmax(x, dim, is_mem, last_section_mask):
62 |
63 |     last_and_rest_mask = last_section_mask # | mask
64 |
65 |     full_access_mask = is_mem | last_and_rest_mask
66 |
67 |     max_mem_cnt = 16
68 |     mem_group_idx = torch.cumsum(is_mem, dim=dim)
69 |     mem_bucket_id = max_mem_cnt - 1
70 |     resp_mem_idx = torch.where(last_and_rest_mask,
71 |                                max_mem_cnt - 1,
72 |                                torch.where(is_mem, mem_bucket_id, mem_group_idx))
73 |     probs = LandmarkGroupedSoftmaxFunction.apply(x, dim, max_mem_cnt, resp_mem_idx)
74 |
75 |     new_shape = list(x.shape)
76 |     new_shape[dim] = max_mem_cnt
77 |     group_prob = probs.new_zeros((*new_shape, ))
78 |     group_prob.scatter_(dim, torch.where(is_mem, mem_group_idx - 1, max_mem_cnt - 1), probs)
79 |     probs = probs.mul(torch.where(full_access_mask, last_section_mask, torch.gather(group_prob, dim, resp_mem_idx)))
80 |
81 |
82 |     return probs
83 |
84 | batch = 2
85 | nheads = 8
86 | seqlen_q = 1024
87 | seqlen_k = 1024 #512
88 | d = 128
89 | use_I_for_v = False
90 | mem_freq = 63
91 | q = torch.rand((batch, seqlen_q, nheads, d)).cuda().to(torch.bfloat16).transpose(1, 2)
92 | k = torch.rand((batch, seqlen_k, nheads, d)).cuda().to(torch.bfloat16).transpose(1, 2)
93 | if not use_I_for_v:
94 |     v = torch.rand((batch, seqlen_k, nheads, d)).cuda().to(torch.bfloat16).transpose(1, 2)
95 | else:
96 |     v = torch.eye(seqlen_k, d).cuda().to(torch.bfloat16)
97 |     v = v.view(1, 1, seqlen_k, d).expand(batch, nheads, seqlen_k, d)
98 | q.requires_grad = True
99 | k.requires_grad = True
100 | v.requires_grad = True
101 | block_size = mem_freq + 1
102 | is_mem = torch.arange(0, seqlen_k, device=q.device) % block_size == (block_size - 1)
103 | out = fused_landmark_attention(q, k, v, is_mem, block_size=block_size)
104 |
105 | def f():
106 |     import math
107 |     att = q @ k.transpose(-1, -2) / math.sqrt(d)
108 |     att_mask = torch.tril(torch.ones((1, 1, seqlen_q, seqlen_k), device=q.device), diagonal=seqlen_k - seqlen_q) == 1.
109 |
110 |     last_section_mask = (torch.arange(0, seqlen_k, device=q.device) // (mem_freq + 1))[None, :] == (torch.arange(seqlen_k - seqlen_q, seqlen_k, device=q.device) // (mem_freq + 1))[:, None]
111 |
112 |     last_section_mask = last_section_mask.unsqueeze(0).unsqueeze(1)
113 |     is_mem_ = is_mem.view(1, 1, 1, seqlen_k)
114 |     mask = att_mask & ~(last_section_mask & is_mem_)
115 |     last_section_mask = last_section_mask & mask
116 |     is_mem_ = is_mem_ & mask
117 |     is_mem_ = is_mem_.expand(batch, nheads, seqlen_q, seqlen_k)
118 |     last_section_mask = last_section_mask.expand(batch, nheads, seqlen_q, seqlen_k)
119 |     att.masked_fill_(~mask, float("-inf"))
120 |
121 |
122 |     att = landmark_grouped_softmax(att, -1, is_mem_, last_section_mask).to(q.dtype)
123 |     att.masked_fill_(~mask, 0.0)
124 |     exact_out = att @ v
125 |     return exact_out
126 | exact_out = f()
127 |
128 | def make_f_grad(func):
129 |     def f_():
130 |         exact_out = func()
131 |         return torch.autograd.grad((exact_out**2).sum(), [q, k, v])
132 |     return f_
133 |
134 |
135 | def f_exact():
136 |     return fused_landmark_attention(q, k, v, is_mem, block_size=block_size)
137 |
138 | def f_torch():
139 |     return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
140 |
141 | if use_I_for_v and d >= seqlen_k:
142 |     assert torch.allclose(out.sum(-1), torch.ones_like(out.sum(-1)), atol=1e-03, rtol=1e-03), out.sum(-1)
143 |     assert torch.allclose(exact_out.sum(-1), torch.ones_like(exact_out.sum(-1)))
144 |
145 | #print("Exact", exact_out[:, :, mem_freq-1:mem_freq+2])
146 | #print("Fused", out[:, :, mem_freq-1:mem_freq+2])
147 | assert torch.allclose(out, exact_out, rtol=1e-02, atol=1e-02), (out, exact_out)
148 | #print("Diff", (out - exact_out).max(dim=-2))
149 |
150 |
151 | #print(last_section_mask[0, 0])
152 | #print(att[0, 0])
153 | #print(is_mem[0, 0])
154 | #print(is_mem[0, 0, -1])
155 |
156 | grads = torch.autograd.grad((out ** 2).sum(), [q, k, v])
157 | exact_grads = torch.autograd.grad((exact_out ** 2).sum(), [q, k, v])
158 | #print(len(exact_grads))
159 |
160 | for grad, exact_grad, t in zip(grads, exact_grads, ["q", "k", "v"]):
161 |     torch.set_printoptions(sci_mode=False)
162 |     if not torch.allclose(grad, exact_grad, rtol=4e-02, atol=4e-02):
163 |         #print((grad, exact_grad, t))
164 |         print("Failed d", t)
165 |         #print(t, (grad - exact_grad).max())
166 |         #print(t, torch.argmax((grad - exact_grad).amax(dim=-1), dim=-1))
167 |
168 | print("Done once")
169 | #print(v.grad)
170 |
171 | import timeit
172 | print("Exact: ", timeit.timeit(f, number=500))
173 | print("Fused: ", timeit.timeit(f_exact, number=500))
174 | print("Torch: ", timeit.timeit(f_torch, number=500))
175 |
176 | print("Exact Grad: ", timeit.timeit(make_f_grad(f), number=500))
177 | print("Fused Grad: ", timeit.timeit(make_f_grad(f_exact), number=500))
178 | print("Torch Grad: ", timeit.timeit(make_f_grad(f_torch), number=500))
--------------------------------------------------------------------------------
/llama/redpajama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Together Computer
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """RedPajama: An Open-Source, Clean-Room 1.2 Trillion Token Dataset.""" 17 | 18 | 19 | import json 20 | 21 | import datasets 22 | import traceback 23 | import numpy as np 24 | import math 25 | 26 | logger = datasets.logging.get_logger(__name__) 27 | 28 | 29 | _DESCRIPTION = """\ 30 | RedPajama is a clean-room, fully open-source implementation of the LLaMa dataset. 31 | """ 32 | 33 | _URL_LISTS = { 34 | "arxiv": "urls/arxiv.txt", 35 | "book": "urls/book.txt", 36 | "c4": "urls/c4.txt", 37 | "common_crawl": "urls/common_crawl.txt", 38 | "github": "urls/github.txt", 39 | "stackexchange": "urls/stackexchange.txt", 40 | "wikipedia": "urls/wikipedia.txt", 41 | } 42 | 43 | 44 | class RedPajama1TConfig(datasets.BuilderConfig): 45 | """BuilderConfig for RedPajama sample.""" 46 | 47 | def __init__(self, *args, subsets, p_sample=None, **kwargs): 48 | """BuilderConfig for RedPajama. 49 | Args: 50 | **kwargs: keyword arguments forwarded to super. 51 | """ 52 | super(RedPajama1TConfig, self).__init__(**kwargs) 53 | 54 | self.subsets = subsets 55 | self.p_sample = p_sample 56 | 57 | 58 | class RedPajama1T(datasets.GeneratorBasedBuilder): 59 | """RedPajama: Reproducing the LLaMA training dataset of over 1.2 trillion tokens. Version 1.0.0.""" 60 | BUILDER_CONFIG_CLASS = RedPajama1TConfig 61 | BUILDER_CONFIGS = [ 62 | RedPajama1TConfig( 63 | subsets = list(_URL_LISTS.keys()), 64 | name="plain_text", 65 | version=datasets.Version("1.0.0", ""), 66 | description="Plain text", 67 | ), 68 | RedPajama1TConfig( 69 | subsets = list(_URL_LISTS.keys()), 70 | name="plain_text_tenpercent", 71 | version=datasets.Version("1.0.0", ""), 72 | description="Plain text", 73 | p_sample=0.1 74 | ), 75 | ] 76 | 77 | def _info(self): 78 | return datasets.DatasetInfo( 79 | description=_DESCRIPTION, 80 | features=datasets.Features( 81 | { 82 | "text": datasets.Value("string"), 83 | "meta": datasets.Value("string"), 84 | "red_pajama_subset": datasets.Value("string"), 85 | } 86 | ), 87 | supervised_keys=None, 88 | ) 89 | 90 | def _split_generators(self, dl_manager): 91 | url_lists = dl_manager.download_and_extract({ 92 | subset: _URL_LISTS[subset] for subset in self.config.subsets 93 | }) 94 | 95 | urls = {} 96 | rng = np.random.default_rng(seed=2) 97 | 98 | for subset, url_list in url_lists.items(): 99 | with open(url_list, encoding="utf-8") as f: 100 | urls[subset] = [line.strip() for line in f] 101 | if self.config.p_sample is not None: 102 | urls[subset] = rng.choice( 103 | urls[subset], 104 | size=int(math.ceil(len(urls[subset]) * self.config.p_sample)), replace=False).tolist() 105 | 106 | downloaded_files = dl_manager.download(urls) 107 | 108 | return [ 109 | datasets.SplitGenerator( 110 | name=datasets.Split.TRAIN, 111 | gen_kwargs = { 112 | "files": { 113 | subset: downloaded_files[subset] 114 | for subset in self.config.subsets 115 | } 116 | } 117 | ) 118 | ] 119 | 120 | def _generate_examples(self, files): 121 | """This function returns the examples in the raw (text) form.""" 122 | key = 0 123 | for subset in files: 124 | if subset == "common_crawl": 125 
| import zstandard as zstd
126 |
127 |                 for path in files[subset]:
128 |                     with zstd.open(open(path, "rb"), "rt", encoding="utf-8") as f:
129 |                         for i, row in enumerate(f):
130 |                             try:
131 |                                 data = json.loads(row)
132 |                                 text = data["text"]
133 |                                 del data["text"]
134 |                                 yield key, {
135 |                                     "text": text,
136 |                                     "meta": json.dumps(data),
137 |                                     "red_pajama_subset": subset,
138 |                                 }
139 |                                 key += 1
140 |                             except Exception as e:
141 |                                 print(f'Subset: {subset}')
142 |                                 print(f'Path: {path}')
143 |                                 print(f'Row: {row}')
144 |                                 traceback.print_exc()
145 |
146 |                                 raise e
147 |             else:
148 |                 for path in files[subset]:
149 |                     with open(path, encoding="utf-8") as f:
150 |                         for i, row in enumerate(f):
151 |                             try:
152 |                                 data = json.loads(row)
153 |                                 if "meta" not in data:
154 |                                     text = data["text"]
155 |                                     del data["text"]
156 |                                     yield key, {
157 |                                         "text": text,
158 |                                         "meta": json.dumps(data),
159 |                                         "red_pajama_subset": subset,
160 |                                     }
161 |                                 else:
162 |                                     yield key, {
163 |                                         "text": data["text"],
164 |                                         "meta": data["meta"],
165 |                                         "red_pajama_subset": subset,
166 |                                     }
167 |                                 key += 1
168 |                             except Exception as e:
169 |                                 print(f'Subset: {subset}')
170 |                                 print(f'Path: {path}')
171 |                                 print(f'Row: {row}')
172 |                                 traceback.print_exc()
173 |
174 |                                 raise e
--------------------------------------------------------------------------------
/llama/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | rouge_score
3 | fire
4 | openai
5 | transformers>=4.28.1
6 | torch>=2.0
7 | sentencepiece
8 | tokenizers>=0.13.3
9 | wandb
10 | accelerate
11 | datasets
12 |
13 |
14 | # We need Triton 2.1, but this currently conflicts with PyTorch 2.0. The following line is therefore commented out; use install_deps.sh instead.
15 | #git+https://github.com/openai/triton.git#subdirectory=python
--------------------------------------------------------------------------------
/llama/run_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
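# This script evaluates pass-key retrieval with and without landmarks: a random
# pass key is hidden inside a long stretch of filler text, and the base and
# landmark fine-tuned models are asked to recover it at increasing context lengths.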
14 |
15 | import torch
16 |
17 | import os
18 | import random
19 | import re
20 | import requests
21 |
22 |
23 | llama_weights_7b_base = "/llama_weights/7B_hf/"
24 | llama_weights_7b_tuned = "/llama-redpajama-mem-15000-with-mem/"
25 | cache_path = "/hf-cache/"
26 | use_flash = False # using flash for inference is only implemented when offloading the kv cache to cpu
27 | top_k = 5
28 | dtype = torch.bfloat16
29 |
30 | def make_llama_base_pipe():
31 |
32 |     from transformers import pipeline
33 |
34 |     from transformers.models.llama import LlamaForCausalLM
35 |
36 |     llama_base = LlamaForCausalLM.from_pretrained(
37 |         llama_weights_7b_base,
38 |         cache_dir=cache_path,
39 |     )
40 |
41 |     llama_base = llama_base.to('cuda:0')
42 |
43 |     import transformers
44 |
45 |     tokenizer = transformers.AutoTokenizer.from_pretrained(
46 |         llama_weights_7b_base,
47 |         cache_dir=cache_path,
48 |         model_max_length=2048,
49 |         padding_side="right",
50 |         use_fast=False,
51 |     )
52 |
53 |     llama_base_pipe = pipeline("text-generation", model=llama_base, tokenizer=tokenizer, device=llama_base.device)
54 |     return llama_base_pipe
55 |
56 |
57 |
58 | llama_base_pipe = make_llama_base_pipe()
59 |
60 | def make_llama_mem_pipe():
61 |     from llama_mem import LlamaForCausalLM
62 |
63 |     model = LlamaForCausalLM.from_pretrained(
64 |         llama_weights_7b_tuned,
65 |         cache_dir=cache_path,
66 |         torch_dtype=dtype
67 |     )
68 |
69 |     model.to('cuda:1')
70 |
71 |     import transformers
72 |
73 |     tokenizer = transformers.AutoTokenizer.from_pretrained(
74 |         llama_weights_7b_tuned,
75 |         cache_dir=cache_path,
76 |         model_max_length=model.config.train_context_length,
77 |         padding_side="right",
78 |         use_fast=False,
79 |     )
80 |     mem_id = tokenizer.convert_tokens_to_ids("<landmark>")
81 |     model.set_mem_id(mem_id)
82 |     from transformers import pipeline
83 |     llama_mem_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device,
84 |                               offload_cache_to_cpu=use_flash, use_flash=use_flash,
85 |                               cache_top_k=top_k)
86 |     return llama_mem_pipe
87 |
88 |
89 | llama_mem_pipe = make_llama_mem_pipe()
90 |
91 |
92 |
93 | pipes = {"base": llama_base_pipe, "mem": llama_mem_pipe}
94 |
95 |
96 | def generate_prompt(n_garbage):
97 |     """Generates a prompt that hides a pass key at a random position inside filler text."""
98 |     n_garbage_prefix = random.randint(0, n_garbage)
99 |     n_garbage_suffix = n_garbage - n_garbage_prefix
100 |
101 |     task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
102 |     garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
103 |     garbage_inf = " ".join([garbage] * 2000)
104 |     assert len(garbage_inf) >= n_garbage
105 |     garbage_prefix = garbage_inf[:n_garbage_prefix]
106 |     garbage_suffix = garbage_inf[:n_garbage_suffix]
107 |     pass_key = random.randint(1, 50000)
108 |     information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
The pass key is" 110 | lines = [ 111 | task_description, 112 | garbage_prefix, 113 | information_line, 114 | garbage_suffix, 115 | final_question 116 | ] 117 | return "\n".join(lines), pass_key 118 | 119 | 120 | 121 | def test_model(prompt_text, pass_key, model_name): 122 | response = pipes[model_name](prompt_text,num_return_sequences=1, max_new_tokens=10)[0]["generated_text"][len(prompt_text):] 123 | assert f"The pass key is {pass_key}" in prompt_text 124 | 125 | try: 126 | pass_key = int(re.search(r'\d+', response).group()) 127 | except: 128 | pass_key = response[:20] 129 | 130 | return pass_key 131 | 132 | 133 | n_values = [0, 100, 500, 1000, 5000, 8000, 10000, 12000, 14000, 18000, 20000, 25000, 38000] 134 | num_tests = 50 135 | models = ["base", "mem"] 136 | accuracies = {x: [] for x in models} 137 | individual_results = {x: [] for x in models} 138 | 139 | for n in n_values: 140 | 141 | correct_count = {x: 0 for x in models} 142 | 143 | n_results = {x: [] for x in models} 144 | for i in range(num_tests): 145 | print(f"\nRunning test {i + 1}/{num_tests} for n = {n}...") 146 | prompt_text, pass_key = generate_prompt(n) 147 | 148 | 149 | 150 | for model_name in models: 151 | if pipes[model_name] is None: 152 | continue 153 | num_tokens = len(pipes[model_name].tokenizer.encode(prompt_text)) 154 | 155 | print("Number of tokens in this prompt: ", num_tokens) 156 | model_output = test_model(prompt_text, pass_key, model_name) 157 | print(f"Expected number in the prompt: {pass_key}, {model_name} output: {model_output}") 158 | 159 | if pass_key == model_output: 160 | correct_count[model_name] += 1 161 | n_results[model_name].append(1) 162 | print("Success!") 163 | else: 164 | n_results[model_name].append(0) 165 | print("Fail.") 166 | 167 | for model in models: 168 | accuracy = (correct_count[model] / num_tests) * 100 169 | print(f"Accuracy {model} for n = {n}: {accuracy}%") 170 | accuracies[model].append(accuracy) 171 | individual_results[model].append(n_results) 172 | -------------------------------------------------------------------------------- /llama/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 |
15 | import copy
16 | import logging
17 | from dataclasses import dataclass, field
18 | from functools import partial
19 | from typing import Dict, Optional, Sequence
20 |
21 | import torch
22 | import transformers
23 | from torch.utils.data import Dataset
24 | from transformers import Trainer, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup
25 | from llama_mem import LlamaForCausalLM
26 |
27 | from torch.distributed import barrier
28 | import os
29 |
30 |
31 | from datasets import load_dataset
32 |
33 | IGNORE_INDEX = -100
34 | DEFAULT_PAD_TOKEN = "[PAD]"
35 | DEFAULT_EOS_TOKEN = "</s>"
36 | DEFAULT_BOS_TOKEN = "<s>"
37 | DEFAULT_UNK_TOKEN = "<unk>"
38 |
39 |
40 | @dataclass
41 | class ModelArguments:
42 |     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
43 |
44 | @dataclass
45 | class TrainingArguments(transformers.TrainingArguments):
46 |     cache_dir: Optional[str] = field(default=None)
47 |     optim: str = field(default="adamw_torch")
48 |     model_max_length: int = field(
49 |         default=512,
50 |         metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
51 |     )
52 |     use_flash: bool = field(default=False)
53 |     mem_freq: int = field(default=63)
54 |
55 |
56 | class TrainerCosine(Trainer):
57 |     def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
58 |         """
59 |         Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
60 |         passed as an argument.
61 |
62 |         Args:
63 |             num_training_steps (int): The number of training steps to do.
64 |         """
65 |         if self.args.lr_scheduler_type != "cosine":
66 |             return super().create_scheduler(num_training_steps, optimizer)
67 |         if self.lr_scheduler is None:
68 |             self.lr_scheduler = get_cosine_schedule_with_warmup(
69 |                 optimizer=self.optimizer if optimizer is None else optimizer,
70 |                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
71 |                 num_training_steps=num_training_steps,
72 |                 num_cycles=0.4 # final lr is ~10% of the initial lr
73 |             )
74 |         return self.lr_scheduler
75 |
76 |
77 | def smart_tokenizer_and_embedding_resize(
78 |     special_tokens_dict: Dict,
79 |     tokenizer: transformers.PreTrainedTokenizer,
80 |     model: transformers.PreTrainedModel,
81 | ):
82 |     """Resize tokenizer and embedding.
83 |
84 |     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
85 | """ 86 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 87 | model.resize_token_embeddings(len(tokenizer)) 88 | 89 | if num_new_tokens > 0: 90 | input_embeddings = model.get_input_embeddings().weight.data 91 | output_embeddings = model.get_output_embeddings().weight.data 92 | 93 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 94 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 95 | 96 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 97 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 98 | 99 | def tokenize_fn(tokenizer, example): 100 | context_length = tokenizer.model_max_length 101 | outputs = tokenizer( 102 | tokenizer.eos_token.join(example["text"]), 103 | truncation=False, 104 | return_tensors="pt", 105 | pad_to_multiple_of=context_length, 106 | padding=True, 107 | ) 108 | return {"input_ids": outputs["input_ids"].view(-1, context_length)} 109 | 110 | def add_mem_tokens(example, mem_freq, mem_id): 111 | x = example["input_ids"] 112 | ret = [] 113 | prev_idx = 0 114 | for t_idx in range(mem_freq, len(x), mem_freq): 115 | ret.extend(x[prev_idx:t_idx]) 116 | ret.append(mem_id) 117 | prev_idx = t_idx 118 | ret.extend(x[prev_idx:]) 119 | # drop attention_mask 120 | return {"input_ids": ret} 121 | 122 | def train(): 123 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) 124 | model_args, training_args = parser.parse_args_into_dataclasses() 125 | 126 | model = LlamaForCausalLM.from_pretrained( 127 | model_args.model_name_or_path, 128 | cache_dir=training_args.cache_dir, 129 | mem_freq=training_args.mem_freq, 130 | include_landmark_in_loss=not training_args.use_flash 131 | ) 132 | 133 | tokenizer = transformers.AutoTokenizer.from_pretrained( 134 | model_args.model_name_or_path, 135 | cache_dir=training_args.cache_dir, 136 | model_max_length=training_args.model_max_length, 137 | padding_side="right", 138 | use_fast=False, 139 | ) 140 | special_tokens_dict = dict() 141 | if tokenizer.pad_token is None: 142 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 143 | if tokenizer.eos_token is None: 144 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 145 | if tokenizer.bos_token is None: 146 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 147 | if tokenizer.unk_token is None: 148 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 149 | mem_token = "" 150 | special_tokens_dict["additional_special_tokens"] = [mem_token] 151 | 152 | smart_tokenizer_and_embedding_resize( 153 | special_tokens_dict=special_tokens_dict, 154 | tokenizer=tokenizer, 155 | model=model, 156 | ) 157 | 158 | mem_id = tokenizer.convert_tokens_to_ids(mem_token) 159 | model.set_mem_id(mem_id) 160 | 161 | rank = int(os.environ.get('RANK', -1)) 162 | if rank > 0: 163 | barrier() 164 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir) 165 | 166 | dataset = dataset.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=32, remove_columns=["text", "meta"]) 167 | 168 | if training_args.use_flash: 169 | model.enable_landmark_insertion() 170 | model.enable_flash() 171 | else: 172 | dataset = dataset.map( 173 | partial( 174 | add_mem_tokens, 175 | mem_freq=training_args.mem_freq, 176 | mem_id=mem_id 177 | ), batched=False, num_proc=32) 178 | 179 | if rank == 0: 180 | barrier() 181 | print(dataset) 182 | 183 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 184 | 185 | trainer = 
TrainerCosine( 186 | model=model, tokenizer=tokenizer, args=training_args, 187 | train_dataset=dataset["train"], 188 | eval_dataset=None, 189 | data_collator=data_collator) 190 | trainer.train() 191 | trainer.save_state() 192 | trainer.save_model(output_dir=training_args.output_dir) 193 | 194 | 195 | if __name__ == "__main__": 196 | train() 197 | -------------------------------------------------------------------------------- /llama/urls/arxiv.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_023827cd-7ee8-42e6-aa7b-661731f4c70f.jsonl 2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_024de5df-1b7f-447c-8c3a-51407d8d6732.jsonl 3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_03232e26-be3f-4a28-a5d2-ee1d8c0e9831.jsonl 4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_034e819a-cfcb-43c6-ad25-0232ad48823c.jsonl 5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_077ae8de-a68e-47e7-95a6-6d82f8f4eeb9.jsonl 6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0af50072-df4c-4084-a833-cebbd046e70e.jsonl 7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0de84cfc-c080-471f-b139-1bf061db4feb.jsonl 8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0fbdd8ad-32d8-4228-9a40-e09dde689760.jsonl 9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_11c659c1-ffbf-4455-abfd-058f6bbf4bb2.jsonl 10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1958455d-6543-4307-a081-d86ce0637f9a.jsonl 11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1982fb29-c4ed-4dd3-855c-666e63bc62d9.jsonl 12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1caed86f-5625-4941-bdc1-cc57e4fec1cd.jsonl 13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1d3a0cd6-f0e6-4106-a080-524a4bd50016.jsonl 14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29d54f5a-1dd0-4e9a-b783-fb2eec9db072.jsonl 15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29fd3d99-53fb-43e2-a4a5-2fd01bf77258.jsonl 16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2b224cd9-286e-46ac-8c4e-c1e3befc8760.jsonl 17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2c131fca-2a05-4d5f-a805-59d2af3477e2.jsonl 18 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2f28f1a7-6972-48ad-8997-65a5d52e4f1c.jsonl 19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_30440198-cd90-48c6-82c1-ea871b8c21c5.jsonl 20 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_39367d6c-d7d4-45fc-a929-8a17184d1744.jsonl 21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_393d19f2-1cd1-421f-be8a-78d955fdf602.jsonl 22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3a5d4f93-97ec-483a-88ef-324df9651b3f.jsonl 23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3c89ea11-69ff-4049-b775-f0c785997909.jsonl 24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3d5a011a-4bbe-4585-a2bd-ff3e943c8671.jsonl 25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f805f4b-6f7f-42a8-a006-47c1e0401bd7.jsonl 26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f9eb7ad-f266-4154-8d4d-54deeffde075.jsonl 27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_400748d3-0076-4a04-8a1c-6055ba0b5a2d.jsonl 28 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_44e19375-3995-4dff-a3b6-8a25247a165c.jsonl 29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4a8cf52f-81d0-4875-9528-466b1cbc71e1.jsonl 30 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4cc7015c-c39a-4bf6-9686-c00b3343edd9.jsonl 31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_50757a42-079b-41ec-bcca-73759faffd62.jsonl 32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_575ae832-e770-4a89-bfa7-c56f16dbca69.jsonl 33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_580be642-bb73-4d0d-8b5e-f494722934cd.jsonl 34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5a02d9ee-12a0-437d-808f-d26f0eb2012b.jsonl 35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5d8d402b-8277-480a-b5fa-71169726864f.jsonl 36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5ee33ef7-455e-4fd5-9512-c4771dd802c1.jsonl 37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_610c82ed-b9ee-449c-83b0-601205f3a74a.jsonl 38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_629fe3ca-075f-4663-9b81-b807f3b42bf2.jsonl 39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_64e5075e-e87e-4b2a-9e38-e5c102f6f2b1.jsonl 40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_65dd2ff6-dae3-4a60-90d3-c3d7349fc92f.jsonl 41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6719ecd2-fe34-4078-a584-320d921cbf6f.jsonl 42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6938ee72-43ee-4ade-8840-151a402383b0.jsonl 43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_73241940-66c1-481c-b53a-f5e8b9afe9fa.jsonl 44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_751370b5-c7cb-44d8-a039-1468ee6747ab.jsonl 45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_75af5d17-5ebb-4460-9f2a-dc9fe880a936.jsonl 46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_79d50803-f7d9-4aa8-bf1a-d807980a40c6.jsonl 47 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7b26046f-7c8d-405b-911b-df51e1a069fa.jsonl 48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7d1d69dc-bc8e-4817-9cab-afdc002ab7c4.jsonl 49 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7ea7a996-b1bb-4773-a36a-461dce2de861.jsonl 50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8232f276-9e3f-463a-9350-362de1b501d1.jsonl 51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8509f5a7-64a8-4813-92dc-f6eb53e3aacc.jsonl 52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_85b4c166-469d-449c-ab3d-5214c1d80246.jsonl 53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_872b620a-b4fd-45d3-92bc-ff0584447705.jsonl 54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_88f24f8d-16d3-4a21-894d-192033d0fa67.jsonl 55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8e6bd730-0f10-49d9-9b02-5ce16da47483.jsonl 56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8ede1b71-6846-439a-acba-86a57cfec3d2.jsonl 57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8f74f6ba-1c53-42d5-a3c7-e4ef46a71133.jsonl 58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_90fa9c2b-25b0-47b7-af2b-a683356e543b.jsonl 59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_92ec488a-287d-4bf0-977b-6998cf0cf476.jsonl 60 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94a393a1-3b23-4961-a0a6-70bad5b4979c.jsonl 61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94b9df70-a95f-4545-be3a-5a34f7b09fb3.jsonl 62 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_95ffc9e1-c505-4a3b-8fb0-cbc98b8703e1.jsonl 63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_98e718fd-5b0e-439f-a00c-57b61e06b395.jsonl 64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f50a028-2586-4e0d-bcfd-d9d2d74e8953.jsonl 65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f8d7d10-dda7-4e44-b00c-811635a199c8.jsonl 66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a055bd62-1ec2-47cf-bad2-321e3d4f053f.jsonl 67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a1e3430e-ef5c-4a86-914d-88e8fb7818c0.jsonl 68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a2b4cb3d-bea3-478e-82a2-77c00a827250.jsonl 69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a647274d-799d-4e7a-a485-b8632a87061e.jsonl 70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a94ea420-99ae-4d58-9cdc-d4666e3322a7.jsonl 71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ab7c0034-7fc1-4fa8-bae3-e97b85fc16a4.jsonl 72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b0732cff-657e-4e69-87b8-66e8025bf441.jsonl 73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b1f3b912-a2ab-43bd-8811-43d84b422506.jsonl 74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b6c2e4d3-d215-4d99-891f-d20b997d4d5a.jsonl 75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bbfabe6b-b9bc-476b-b8f0-7d6c47e9d2be.jsonl 76 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bf11ef29-a3f9-4a2d-9dbf-ebcc56d39fdb.jsonl 77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c05da42b-4939-4f55-867c-16cf6d228e60.jsonl 78 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c1fc3dd5-861f-4b8d-b7a2-eb8f6887b33b.jsonl 79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c44164b4-0770-48d0-87db-590ca529032a.jsonl 80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6b43cda-ca5c-4855-9c08-0f8264cab1af.jsonl 81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6e57a16-4879-4dcf-b591-503cfb46a360.jsonl 82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ca6e3842-6ca4-4230-84fa-376a3374c380.jsonl 83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_caf769e4-7308-4419-9114-900ca213682a.jsonl 84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d0da78e9-3dcf-4a46-85c1-f23ed00178bc.jsonl 85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d386838c-a51e-4839-9f27-8495b2466e49.jsonl 86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d41edf9f-0ebb-4866-b3fe-50785746b36b.jsonl 87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d6a7fc44-b584-4dd8-9de2-e981afe0bb4a.jsonl 88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d7c49dbb-c008-47fc-9cbe-8d5695842d21.jsonl 89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db3b4be7-4e98-4fe9-96bf-05a5788815e3.jsonl 90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db981f69-9eca-4031-8565-318b949efbfe.jsonl 91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_dbd4105f-7cbb-4483-a7b2-96b17b7fb594.jsonl 92 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de42e348-b333-4d35-b883-9bfc94f29822.jsonl 93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de744938-fa6c-45dd-b600-428dd7c63a73.jsonl 94 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_e8f25867-697d-4f52-84e1-e50a95bc182b.jsonl 95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_eb4e26d4-6625-4f8a-b5fe-6f3a9b8a4b79.jsonl 96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f141b736-5ce4-4f18-bb29-704227ca4bd1.jsonl 97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f50efdc6-f88e-4fa6-9ef6-dd1d8314bb36.jsonl 98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f7680c03-70df-4781-a98d-c88695f92f04.jsonl 99 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fbc62949-624d-4943-9731-f5c46242ba55.jsonl 100 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fd572627-cce7-4667-a684-fef096dfbeb7.jsonl -------------------------------------------------------------------------------- /llama/urls/book.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/book/book.jsonl -------------------------------------------------------------------------------- /llama/urls/github.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_08cdfa755e6d4d89b673d5bd1acee5f6.sampled.jsonl 2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f27d10d846a473b96070c3394832f32.sampled.jsonl 3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f979046c8e64e0fb5843d2634a9957d.sampled.jsonl 4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_10f129bfd0af45caa9cd72aa9d863ec5.sampled.jsonl 5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_11a1943edfa349c7939382799599eed6.sampled.jsonl 6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_17197bd2478044bebd9ff4634b6dfcee.sampled.jsonl 7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_1d750c0ce39d40c6bc20bad9469e5a99.sampled.jsonl 8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_21078cf63afb4d9eb4a7876f726a7226.sampled.jsonl 9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_216883d3a669406699428bc485a4c228.sampled.jsonl 10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_24278a707deb445b8e4f59c83dd67910.sampled.jsonl 11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_25989b8233a04ac791b0eccd502e0c7a.sampled.jsonl 12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_26a6fa61c5eb4bb885e7bc643e285f0e.sampled.jsonl 13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_275100c1a44f4451b0343373ebc5637a.sampled.jsonl 14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_27f05c041a1c401783f90b9415e40e4b.sampled.jsonl 15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_28106e66bfd94978abbc15ec845aeddb.sampled.jsonl 16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_2afd093fefad4c8da76cc539e8fb6137.sampled.jsonl 17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_30d01c4edab64866bda8c609e90b4f4e.sampled.jsonl 18 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34b0785b77814b7583433ddb27a61ae0.sampled.jsonl 19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34e78793c1e94eeebd92852399097596.sampled.jsonl 20 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_366012588bef4d749bbbea76ae701141.sampled.jsonl 21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_36c2a955ddd84672bbc778aa4ad2fbaf.sampled.jsonl 22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3be6a5cc7428401393c23e5516a43537.sampled.jsonl 23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3cc6cab9266746a6befa23648aa43119.sampled.jsonl 24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3f09f384f0734a4b912bb75db3f812bc.sampled.jsonl 25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_478afe8aaccb43e6be2de7e34e041ef3.sampled.jsonl 26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_483f11d6bc864f7fbfbe63bdf3583ce2.sampled.jsonl 27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4b6883dc304c4e799620ec95b96dc91a.sampled.jsonl 28 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4bebabdbd8544da7a2071864ccf81f2e.sampled.jsonl 29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4d44caf43c154ae4aaeab36eab0221c9.sampled.jsonl 30 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4f98dc136ba94eeaa1ba6c974814b33c.sampled.jsonl 31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4ff3604761614a3db550ef758e6457b5.sampled.jsonl 32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_50423e84046948b4a2a70e7e4538e12d.sampled.jsonl 33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_58fad5994b4446c6bceb33453484acb4.sampled.jsonl 34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5abf005e4e634f1dbfa8bd20b5687092.sampled.jsonl 35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5f2b3517159b426bb1a9e81ca189abcd.sampled.jsonl 36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_610e908cafaa4d53958de50ad700822a.sampled.jsonl 37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6353ab14cb8f4623a7d75678d9e7f44e.sampled.jsonl 38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_64747f25102740bab0ab54559569342a.sampled.jsonl 39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6742e99802894114a3ba44841f49b168.sampled.jsonl 40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_677e10f9b0af4c489e60670352e7e224.sampled.jsonl 41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_68534e6a093744fa9f38fa1a9cf51232.sampled.jsonl 42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6adcf92fb2ee48059cb60579f2e931f7.sampled.jsonl 43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_70f0aa43987643a7874286bca4faa66b.sampled.jsonl 44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7462f1a34f594e9f8334d9a0cbbf80e7.sampled.jsonl 45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78a13a25258f4d24923702c07445e20e.sampled.jsonl 46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78e3afc1cfee41fbb7eaae2e5bfaa17b.sampled.jsonl 47 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_791ed086a57f4879bb1596bed6d37bb3.sampled.jsonl 48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7b5cc18857a34a0981b54f082de55cf8.sampled.jsonl 49 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7c54b908a2df4ec2ba316d2081fc674e.sampled.jsonl 50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7edd0af0b61c426e93e8bd3f549a8f78.sampled.jsonl 51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7fa232e63a5d44a88a267181e9ac47b4.sampled.jsonl 52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_81138c01f3a84f7fa8b56bf3d8fa35ce.sampled.jsonl 53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_82a9a647b0eb4eb080d7ac15a13c765b.sampled.jsonl 54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_835ac26c7a70447b97f4ec38bcb969ed.sampled.jsonl 55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8bc3c4fae78c41d999b2ae6d97cce96c.sampled.jsonl 56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8dc430f8f7114f018440c7b3d990e602.sampled.jsonl 57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8e1db2cf6c98420a88bad52fd57f4aa7.sampled.jsonl 58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_919156910e704e6ca52e0d4880cdbb63.sampled.jsonl 59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_944beac1347f491faa88f43c25d26fe4.sampled.jsonl 60 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_96175b6e4c764abfbbf14e78d4fd6464.sampled.jsonl 61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_977d1fcdab92452d9dc38b2f4a99389b.sampled.jsonl 62 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_9eea501af8544d0b88f0b002850829d4.sampled.jsonl 63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a26c51ffe2924fd8ad6694c6aa0eacc5.sampled.jsonl 64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3a01a2f81ed4cd2afb30e175200b48f.sampled.jsonl 65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3d3e6f3f7d5495ca9ecf94d808fd350.sampled.jsonl 66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a777da5620f1467f8df3616b17d533dc.sampled.jsonl 67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a8fcae75b0c3410faabcff02f0056a36.sampled.jsonl 68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_aabc2d54d1c946908d3400228e0f238c.sampled.jsonl 69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ad160286662849f49d1e6de27c0f1d15.sampled.jsonl 70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_adda41d791974289aff042e2e3d07ec3.sampled.jsonl 71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ae1813abc63f4b1998dfa608e7fe5588.sampled.jsonl 72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_af7b386db97e4211a7378be08d7b3f4f.sampled.jsonl 73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b15c575fa9f8465d98f70ba2f2f73c6e.sampled.jsonl 74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b40763d0b9ce4e0d8fb5f519f1f49f8c.sampled.jsonl 75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b56453718c5f46efa9c46feb194b0d6e.sampled.jsonl 76 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b6225e159b86432d9fa5bf226bb51393.sampled.jsonl 77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b821f640b8f14ed588bf48ae13f44098.sampled.jsonl 78 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_bfbcad9633f04601ba2f824d082eaacf.sampled.jsonl 79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c6cfd16905814d7c955df1f4754a8b11.sampled.jsonl 80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c82c6775f0b74dbdae4524bb9aebf0ef.sampled.jsonl 81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ca73c69896b34adbbe468d78d9f134bc.sampled.jsonl 82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cac864a472b948b0bfe18c8e9a19aeb5.sampled.jsonl 83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cb330e8dc8ac411eba2dc8676c9c4403.sampled.jsonl 84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d0a054a678fc4c38b496d10e91a2c735.sampled.jsonl 85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d148857715424aabbd32b8ffe56c4082.sampled.jsonl 86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d33701435f964c90a86c22e204dd5fde.sampled.jsonl 87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d448c2553f474193a1224df1c38f74d4.sampled.jsonl 88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d8d45c58819948c9b352d74383944c4a.sampled.jsonl 89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_db611a692a704f6db3c18e77f79fd2f0.sampled.jsonl 90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dcaa5c8b729b4fb599399dbf4557e43e.sampled.jsonl 91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dda863d833614d04b96bbe21b161768d.sampled.jsonl 92 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eabf427d56184fb89f9b5f27e73f7988.sampled.jsonl 93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eee6980922ca4a14b0c3341fa8a904d9.sampled.jsonl 94 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f478ff72b57f4f4283c22ac22ae84134.sampled.jsonl 95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f931feb0e85940879d194c0e20d9e28a.sampled.jsonl 96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f9bbe6a065004a4c8e018c6ad63063b2.sampled.jsonl 97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fbfc552e48164acda6605fa31fc2f563.sampled.jsonl 98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fc817e4c16494957bcb89b166e91434f.sampled.jsonl -------------------------------------------------------------------------------- /llama/urls/stackexchange.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/stackexchange/stackexchange.jsonl -------------------------------------------------------------------------------- /llama/urls/wikipedia.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/wikipedia/wiki.jsonl -------------------------------------------------------------------------------- /llama/weight_diff.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This file has been changed by Amirkeivan Mohtashami 16 | # to take into account the new token in the embedding layer 17 | 18 | import os 19 | from typing import Optional 20 | 21 | import fire 22 | import torch 23 | import tqdm 24 | import transformers 25 | from train import smart_tokenizer_and_embedding_resize 26 | import llama_mem 27 | 28 | @torch.inference_mode() 29 | def make_diff( 30 | path_raw: str, path_tuned: str, path_diff: str, device="cpu", # "cuda" or "cpu" 31 | ): 32 | """Make the weight diff. 33 | 34 | This function is given to present full transparency of how the weight diff was created. 35 | 36 | Run: 37 | python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff> 38 | """ 39 | model_tuned: transformers.PreTrainedModel = llama_mem.LlamaForCausalLM.from_pretrained( 40 | path_tuned, 41 | device_map={"": torch.device(device)}, 42 | torch_dtype=torch.float32, 43 | low_cpu_mem_usage=True, 44 | ) 45 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained( 46 | path_raw, 47 | device_map={"": torch.device(device)}, 48 | torch_dtype=torch.float32, 49 | low_cpu_mem_usage=True, 50 | ) 51 | 52 | tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 53 | path_tuned 54 | ) 55 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 56 | path_raw 57 | ) 58 | smart_tokenizer_and_embedding_resize( 59 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=[""]), 60 | model=model_raw, 61 | tokenizer=tokenizer_raw, 62 | ) 63 | 64 | 65 | 66 | state_dict_tuned = model_tuned.state_dict() 67 | state_dict_raw = model_raw.state_dict() 68 | with open(os.path.join(path_diff, "checksum_psum.txt"), "w") as f: 69 | f.write(str(sum(state_dict_tuned[key].sum().item() for key in state_dict_tuned))) 70 | 71 | for key in tqdm.tqdm(state_dict_tuned): 72 | state_dict_tuned[key].add_(-state_dict_raw[key]) 73 | 74 | model_tuned.save_pretrained(path_diff) 75 | tokenizer_tuned.save_pretrained(path_diff) 76 | 77 | 78 | @torch.inference_mode() 79 | def recover( 80 | path_raw, 81 | path_diff, 82 | path_tuned: Optional[str] = None, 83 | device="cpu", 84 | test_inference=True, 85 | check_integrity_naively=True, 86 | ): 87 | """Recover the original weights from the released weight diff. 88 | 89 | This function is given for you to run. 90 | 91 | Things to do before running this: 92 | 1. Convert Meta's released weights into huggingface format. Follow this guide: 93 | https://huggingface.co/docs/transformers/main/model_doc/llama 94 | 2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at: 95 | https://huggingface.co/tatsu-lab/alpaca-7b/tree/main 96 | 3. Run this function with the correct paths.
E.g., 97 | python weight_diff.py recover --path_raw <your_path_raw> --path_diff <your_path_diff> 98 | 99 | Additional notes: 100 | - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`. 101 | - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`. 102 | Next time you can load the recovered weights directly from `<your_path_tuned>`. 103 | """ 104 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained( 105 | path_raw, 106 | device_map={"": torch.device(device)}, 107 | torch_dtype=torch.float32, 108 | low_cpu_mem_usage=True, 109 | ) 110 | model_recovered: transformers.PreTrainedModel = llama_mem.LlamaForCausalLM.from_pretrained( 111 | path_diff, 112 | device_map={"": torch.device(device)}, 113 | torch_dtype=torch.float32, 114 | low_cpu_mem_usage=True, 115 | ) 116 | 117 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 118 | path_raw 119 | ) 120 | smart_tokenizer_and_embedding_resize( 121 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=[""]), 122 | model=model_raw, 123 | tokenizer=tokenizer_raw, 124 | ) 125 | tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 126 | path_diff 127 | ) 128 | 129 | state_dict_recovered = model_recovered.state_dict() 130 | state_dict_raw = model_raw.state_dict() 131 | for key in tqdm.tqdm(state_dict_recovered): 132 | state_dict_recovered[key].add_(state_dict_raw[key]) 133 | 134 | if check_integrity_naively: 135 | # This is not a rigorous, cryptographically strong integrity check :) 136 | allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered) 137 | if os.path.exists(os.path.join(path_diff, "checksum_psum.txt")): 138 | with open(os.path.join(path_diff, "checksum_psum.txt")) as f: 139 | expected_sum = float(f.read()) 140 | else: 141 | expected_sum = 49798.7656 # backward compatibility with the first released weights 142 | assert torch.allclose( 143 | allsum, torch.full_like(allsum, fill_value=expected_sum), atol=1e-2, rtol=0 144 | ), "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted." 145 | 146 | if path_tuned is not None: 147 | model_recovered.save_pretrained(path_tuned) 148 | tokenizer_recovered.save_pretrained(path_tuned) 149 | 150 | return model_recovered, tokenizer_recovered 151 | 152 | 153 | def main(task, **kwargs): 154 | globals()[task](**kwargs) 155 | 156 | 157 | if __name__ == "__main__": 158 | fire.Fire(main) 159 | -------------------------------------------------------------------------------- /llama_legacy/redpajama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Together Computer 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
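The make_diff/recover pair in weight_diff.py above is plain tensor arithmetic: the released diff stores tuned minus raw for every parameter, recovery adds the raw weights back, and a parameter-sum checksum (checksum_psum.txt) naively guards against corrupted shards. Below is a minimal sketch of the same round-trip on toy tensors; the dictionaries are illustrative stand-ins for full HuggingFace state dicts, not part of the repository.

import torch

raw = {"w": torch.randn(4, 4)}                      # stand-in for the base weights
tuned = {"w": raw["w"] + 0.01 * torch.randn(4, 4)}  # stand-in for the fine-tuned weights

# make_diff: store tuned - raw, plus a naive parameter-sum checksum of the tuned weights
diff = {k: tuned[k] - raw[k] for k in tuned}
checksum = sum(tuned[k].sum().item() for k in tuned)

# recover: add the raw weights back, then repeat the naive integrity check
recovered = {k: diff[k] + raw[k] for k in diff}
allsum = sum(recovered[k].sum() for k in recovered)
assert torch.allclose(allsum, torch.full_like(allsum, checksum), atol=1e-2, rtol=0)
assert all(torch.allclose(recovered[k], tuned[k]) for k in tuned)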
14 | 15 | # Lint as: python3 16 | """RedPajama: An Open-Source, Clean-Room 1.2 Trillion Token Dataset.""" 17 | 18 | 19 | import json 20 | 21 | import datasets 22 | import traceback 23 | import numpy as np 24 | import math 25 | 26 | logger = datasets.logging.get_logger(__name__) 27 | 28 | 29 | _DESCRIPTION = """\ 30 | RedPajama is a clean-room, fully open-source implementation of the LLaMa dataset. 31 | """ 32 | 33 | _URL_LISTS = { 34 | "arxiv": "urls/arxiv.txt", 35 | "book": "urls/book.txt", 36 | "c4": "urls/c4.txt", 37 | "common_crawl": "urls/common_crawl.txt", 38 | "github": "urls/github.txt", 39 | "stackexchange": "urls/stackexchange.txt", 40 | "wikipedia": "urls/wikipedia.txt", 41 | } 42 | 43 | 44 | class RedPajama1TConfig(datasets.BuilderConfig): 45 | """BuilderConfig for RedPajama sample.""" 46 | 47 | def __init__(self, *args, subsets, p_sample=None, **kwargs): 48 | """BuilderConfig for RedPajama. 49 | Args: 50 | **kwargs: keyword arguments forwarded to super. 51 | """ 52 | super(RedPajama1TConfig, self).__init__(**kwargs) 53 | 54 | self.subsets = subsets 55 | self.p_sample = p_sample 56 | 57 | 58 | class RedPajama1T(datasets.GeneratorBasedBuilder): 59 | """RedPajama: Reproducing the LLaMA training dataset of over 1.2 trillion tokens. Version 1.0.0.""" 60 | BUILDER_CONFIG_CLASS = RedPajama1TConfig 61 | BUILDER_CONFIGS = [ 62 | RedPajama1TConfig( 63 | subsets = list(_URL_LISTS.keys()), 64 | name="plain_text", 65 | version=datasets.Version("1.0.0", ""), 66 | description="Plain text", 67 | ), 68 | RedPajama1TConfig( 69 | subsets = list(_URL_LISTS.keys()), 70 | name="plain_text_tenpercent", 71 | version=datasets.Version("1.0.0", ""), 72 | description="Plain text", 73 | p_sample=0.1 74 | ), 75 | ] 76 | 77 | def _info(self): 78 | return datasets.DatasetInfo( 79 | description=_DESCRIPTION, 80 | features=datasets.Features( 81 | { 82 | "text": datasets.Value("string"), 83 | "meta": datasets.Value("string"), 84 | "red_pajama_subset": datasets.Value("string"), 85 | } 86 | ), 87 | supervised_keys=None, 88 | ) 89 | 90 | def _split_generators(self, dl_manager): 91 | url_lists = dl_manager.download_and_extract({ 92 | subset: _URL_LISTS[subset] for subset in self.config.subsets 93 | }) 94 | 95 | urls = {} 96 | rng = np.random.default_rng(seed=2) 97 | 98 | for subset, url_list in url_lists.items(): 99 | with open(url_list, encoding="utf-8") as f: 100 | urls[subset] = [line.strip() for line in f] 101 | if self.config.p_sample is not None: 102 | urls[subset] = rng.choice( 103 | urls[subset], 104 | size=int(math.ceil(len(urls[subset]) * self.config.p_sample)), replace=False).tolist() 105 | 106 | downloaded_files = dl_manager.download(urls) 107 | 108 | return [ 109 | datasets.SplitGenerator( 110 | name=datasets.Split.TRAIN, 111 | gen_kwargs = { 112 | "files": { 113 | subset: downloaded_files[subset] 114 | for subset in self.config.subsets 115 | } 116 | } 117 | ) 118 | ] 119 | 120 | def _generate_examples(self, files): 121 | """This function returns the examples in the raw (text) form.""" 122 | key = 0 123 | for subset in files: 124 | if subset == "common_crawl": 125 | import zstandard as zstd 126 | 127 | for path in files[subset]: 128 | with zstd.open(open(path, "rb"), "rt", encoding="utf-8") as f: 129 | for i, row in enumerate(f): 130 | try: 131 | data = json.loads(row) 132 | text = data["text"] 133 | del data["text"] 134 | yield key, { 135 | "text": text, 136 | "meta": json.dumps(data), 137 | "red_pajama_subset": subset, 138 | } 139 | key += 1 140 | except Exception as e: 141 | print(f'Subset: 
{subset}') 142 | print(f'Path: {path}') 143 | print(f'Row: {row}') 144 | traceback.print_exc() 145 | 146 | raise e 147 | else: 148 | for path in files[subset]: 149 | with open(path, encoding="utf-8") as f: 150 | for i, row in enumerate(f): 151 | try: 152 | data = json.loads(row) 153 | if "meta" not in data: 154 | text = data["text"] 155 | del data["text"] 156 | yield key, { 157 | "text": text, 158 | "meta": json.dumps(data), 159 | "red_pajama_subset": subset, 160 | } 161 | else: 162 | yield key, { 163 | "text": data["text"], 164 | "meta": data["meta"], 165 | "red_pajama_subset": subset, 166 | } 167 | key += 1 168 | except Exception as e: 169 | print(f'Subset: {subset}') 170 | print(f'Path: {path}') 171 | print(f'Row: {row}') 172 | traceback.print_exc() 173 | 174 | raise e 175 | -------------------------------------------------------------------------------- /llama_legacy/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | rouge_score 3 | fire 4 | openai 5 | transformers>=4.28.1 6 | torch 7 | sentencepiece 8 | tokenizers>=0.13.3 9 | wandb 10 | accelerate 11 | datasets 12 | -------------------------------------------------------------------------------- /llama_legacy/run_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
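The plain_text_tenpercent config in redpajama.py above thins the shard lists deterministically: a NumPy generator with a fixed seed draws ceil(p_sample * n) URLs without replacement, so every run resolves the same 10% slice. A small sketch of the same scheme; the shard names below are made up for illustration.

import math

import numpy as np

urls = [f"shard_{i:04d}.jsonl" for i in range(100)]  # illustrative shard names
p_sample = 0.1

rng = np.random.default_rng(seed=2)  # fixed seed, as in _split_generators
sampled = rng.choice(
    urls,
    size=int(math.ceil(len(urls) * p_sample)),
    replace=False,
).tolist()
assert len(sampled) == 10  # the same 10% subset on every run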
14 | 15 | llama_weights_7b_base = "/llama_weights/7B_hf/" 16 | llama_weights_7b_tuned = "/llama-redpajama-mem-15000-with-mem/" 17 | cache_path = "/hf-cache/" 18 | 19 | def make_llama_base_pipe(): 20 | 21 | from transformers import pipeline 22 | 23 | from transformers.models.llama import LlamaForCausalLM 24 | 25 | llama_base = LlamaForCausalLM.from_pretrained( 26 | llama_weights_7b_base, 27 | cache_dir=cache_path, 28 | ) 29 | 30 | llama_base = llama_base.to('cuda:0') 31 | 32 | import transformers 33 | 34 | tokenizer = transformers.AutoTokenizer.from_pretrained( 35 | llama_weights_7b_base, 36 | cache_dir=cache_path, 37 | model_max_length=1024, 38 | padding_side="right", 39 | use_fast=False, 40 | ) 41 | 42 | llama_base_pipe = pipeline("text-generation", model=llama_base, tokenizer=tokenizer, device=llama_base.device) 43 | return llama_base_pipe 44 | 45 | 46 | 47 | llama_base_pipe = make_llama_base_pipe() 48 | 49 | def make_llama_mem_pipe(): 50 | from llama_mem import LlamaForCausalLM 51 | 52 | model = LlamaForCausalLM.from_pretrained( 53 | llama_weights_7b_tuned, 54 | cache_dir=cache_path, 55 | ) 56 | 57 | model.to('cuda:1') 58 | 59 | import transformers 60 | 61 | tokenizer = transformers.AutoTokenizer.from_pretrained( 62 | llama_weights_7b_tuned, 63 | cache_dir=cache_path, 64 | model_max_length=512, 65 | padding_side="right", 66 | use_fast=False, 67 | ) 68 | from transformers import pipeline 69 | llama_mem_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device) 70 | return llama_mem_pipe 71 | 72 | 73 | llama_mem_pipe = make_llama_mem_pipe() 74 | 75 | mem_id = llama_mem_pipe.tokenizer.convert_tokens_to_ids("") 76 | llama_mem_pipe.model.set_mem_id(mem_id) 77 | llama_mem_pipe.model.set_mem_cache_args(max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None) 78 | 79 | 80 | pipes = {"base": llama_base_pipe, "mem": llama_mem_pipe} 81 | 82 | import torch 83 | 84 | import os 85 | import random 86 | import re 87 | import requests 88 | 89 | def generate_prompt(n_garbage): 90 | """Generates a text file and inserts an execute line at a random position.""" 91 | n_garbage_prefix = random.randint(0, n_garbage) 92 | n_garbage_suffix = n_garbage - n_garbage_prefix 93 | 94 | task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there." 95 | garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 96 | garbage_inf = " ".join([garbage] * 2000) 97 | assert len(garbage_inf) >= n_garbage 98 | garbage_prefix = garbage_inf[:n_garbage_prefix] 99 | garbage_suffix = garbage_inf[:n_garbage_suffix] 100 | pass_key = random.randint(1, 50000) 101 | information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key." 102 | final_question = "What is the pass key? 
The pass key is" 103 | lines = [ 104 | task_description, 105 | garbage_prefix, 106 | information_line, 107 | garbage_suffix, 108 | final_question 109 | ] 110 | return "\n".join(lines), pass_key 111 | 112 | 113 | 114 | def test_model(prompt_text, pass_key, model_name): 115 | response = pipes[model_name](prompt_text,num_return_sequences=1, max_new_tokens=10)[0]["generated_text"][len(prompt_text):] 116 | assert f"The pass key is {pass_key}" in prompt_text 117 | 118 | try: 119 | pass_key = int(re.search(r'\d+', response).group()) 120 | except: 121 | pass_key = response[:20] 122 | 123 | return pass_key 124 | 125 | 126 | n_values = [0, 100, 500, 1000, 5000, 8000, 10000, 12000, 14000, 18000, 20000, 25000, 38000] 127 | num_tests = 50 128 | models = ["base", "mem"] 129 | accuracies = {x: [] for x in models} 130 | individual_results = {x: [] for x in models} 131 | 132 | for n in n_values: 133 | 134 | correct_count = {x: 0 for x in models} 135 | 136 | n_results = {x: [] for x in models} 137 | for i in range(num_tests): 138 | print(f"\nRunning test {i + 1}/{num_tests} for n = {n}...") 139 | prompt_text, pass_key = generate_prompt(n) 140 | 141 | 142 | 143 | for model_name in models: 144 | num_tokens = len(pipes[model_name].tokenizer.encode(prompt_text)) 145 | 146 | print("Number of tokens in this prompt: ", num_tokens) 147 | model_output = test_model(prompt_text, pass_key, model_name) 148 | print(f"Expected number in the prompt: {pass_key}, {model_name} output: {model_output}") 149 | 150 | if pass_key == model_output: 151 | correct_count[model_name] += 1 152 | n_results[model_name].append(1) 153 | print("Success!") 154 | else: 155 | n_results[model_name].append(0) 156 | print("Fail.") 157 | 158 | for model in models: 159 | accuracy = (correct_count[model] / num_tests) * 100 160 | print(f"Accuracy {model} for n = {n}: {accuracy}%") 161 | accuracies[model].append(accuracy) 162 | individual_results[model].append(n_results) 163 | -------------------------------------------------------------------------------- /llama_legacy/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import copy 16 | import logging 17 | from dataclasses import dataclass, field 18 | from functools import partial 19 | from typing import Dict, Optional, Sequence 20 | 21 | import torch 22 | import transformers 23 | from torch.utils.data import Dataset 24 | from transformers import Trainer, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup 25 | from llama_mem import LlamaForCausalLM 26 | 27 | from torch.distributed import barrier 28 | import os 29 | 30 | 31 | from datasets import load_dataset 32 | 33 | IGNORE_INDEX = -100 34 | DEFAULT_PAD_TOKEN = "[PAD]" 35 | DEFAULT_EOS_TOKEN = "</s>" 36 | DEFAULT_BOS_TOKEN = "<s>" 37 | DEFAULT_UNK_TOKEN = "<unk>" 38 | 39 | 40 | @dataclass 41 | class ModelArguments: 42 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 43 | 44 | @dataclass 45 | class TrainingArguments(transformers.TrainingArguments): 46 | cache_dir: Optional[str] = field(default=None) 47 | optim: str = field(default="adamw_torch") 48 | model_max_length: int = field( 49 | default=512, 50 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, 51 | ) 52 | 53 | 54 | class TrainerCosine(Trainer): 55 | def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): 56 | """ 57 | Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or 58 | passed as an argument. 59 | 60 | Args: 61 | num_training_steps (int): The number of training steps to do. 62 | """ 63 | if self.args.lr_scheduler_type != "cosine": 64 | return super().create_scheduler(num_training_steps, optimizer) 65 | if self.lr_scheduler is None: 66 | self.lr_scheduler = get_cosine_schedule_with_warmup( 67 | optimizer=self.optimizer if optimizer is None else optimizer, 68 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps), 69 | num_training_steps=num_training_steps, 70 | num_cycles=0.4 # ~10% of the init lr 71 | ) 72 | return self.lr_scheduler 73 | 74 | 75 | def smart_tokenizer_and_embedding_resize( 76 | special_tokens_dict: Dict, 77 | tokenizer: transformers.PreTrainedTokenizer, 78 | model: transformers.PreTrainedModel, 79 | ): 80 | """Resize tokenizer and embedding. 81 | 82 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
83 | """ 84 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 85 | model.resize_token_embeddings(len(tokenizer)) 86 | 87 | if num_new_tokens > 0: 88 | input_embeddings = model.get_input_embeddings().weight.data 89 | output_embeddings = model.get_output_embeddings().weight.data 90 | 91 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 92 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 93 | 94 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 95 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 96 | 97 | def tokenize_fn(tokenizer, example): 98 | context_length = tokenizer.model_max_length 99 | outputs = tokenizer( 100 | tokenizer.eos_token.join(example["text"]), 101 | truncation=False, 102 | return_tensors="pt", 103 | pad_to_multiple_of=context_length, 104 | padding=True, 105 | ) 106 | return {"input_ids": outputs["input_ids"].view(-1, context_length)} 107 | 108 | def add_mem_tokens(example, mem_freq, mem_id): 109 | x = example["input_ids"] 110 | ret = [] 111 | prev_idx = 0 112 | for t_idx in range(mem_freq, len(x), mem_freq): 113 | ret.extend(x[prev_idx:t_idx]) 114 | ret.append(mem_id) 115 | prev_idx = t_idx 116 | ret.extend(x[prev_idx:]) 117 | # drop attention_mask 118 | return {"input_ids": ret} 119 | 120 | def train(): 121 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) 122 | model_args, training_args = parser.parse_args_into_dataclasses() 123 | 124 | model = LlamaForCausalLM.from_pretrained( 125 | model_args.model_name_or_path, 126 | cache_dir=training_args.cache_dir, 127 | ) 128 | 129 | tokenizer = transformers.AutoTokenizer.from_pretrained( 130 | model_args.model_name_or_path, 131 | cache_dir=training_args.cache_dir, 132 | model_max_length=training_args.model_max_length, 133 | padding_side="right", 134 | use_fast=False, 135 | ) 136 | special_tokens_dict = dict() 137 | if tokenizer.pad_token is None: 138 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 139 | if tokenizer.eos_token is None: 140 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 141 | if tokenizer.bos_token is None: 142 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 143 | if tokenizer.unk_token is None: 144 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 145 | mem_token = "" 146 | special_tokens_dict["additional_special_tokens"] = [mem_token] 147 | 148 | smart_tokenizer_and_embedding_resize( 149 | special_tokens_dict=special_tokens_dict, 150 | tokenizer=tokenizer, 151 | model=model, 152 | ) 153 | 154 | mem_id = tokenizer.convert_tokens_to_ids(mem_token) 155 | model.set_mem_id(mem_id) 156 | rank = int(os.environ.get('RANK', -1)) 157 | if rank > 0: 158 | barrier() 159 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir) 160 | 161 | dataset = dataset.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=32, remove_columns=["text", "meta"]) 162 | 163 | dataset = dataset.map( 164 | partial( 165 | add_mem_tokens, 166 | mem_freq=50, 167 | mem_id=mem_id 168 | ), batched=False, num_proc=32) 169 | if rank == 0: 170 | barrier() 171 | print(dataset) 172 | 173 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 174 | 175 | trainer = TrainerCosine( 176 | model=model, tokenizer=tokenizer, args=training_args, 177 | train_dataset=dataset["train"], 178 | eval_dataset=None, 179 | data_collator=data_collator) 180 | trainer.train() 181 | trainer.save_state() 182 | 
trainer.save_model(output_dir=training_args.output_dir) 183 | 184 | 185 | if __name__ == "__main__": 186 | train() 187 | -------------------------------------------------------------------------------- /llama_legacy/urls/arxiv.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_023827cd-7ee8-42e6-aa7b-661731f4c70f.jsonl 2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_024de5df-1b7f-447c-8c3a-51407d8d6732.jsonl 3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_03232e26-be3f-4a28-a5d2-ee1d8c0e9831.jsonl 4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_034e819a-cfcb-43c6-ad25-0232ad48823c.jsonl 5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_077ae8de-a68e-47e7-95a6-6d82f8f4eeb9.jsonl 6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0af50072-df4c-4084-a833-cebbd046e70e.jsonl 7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0de84cfc-c080-471f-b139-1bf061db4feb.jsonl 8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0fbdd8ad-32d8-4228-9a40-e09dde689760.jsonl 9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_11c659c1-ffbf-4455-abfd-058f6bbf4bb2.jsonl 10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1958455d-6543-4307-a081-d86ce0637f9a.jsonl 11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1982fb29-c4ed-4dd3-855c-666e63bc62d9.jsonl 12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1caed86f-5625-4941-bdc1-cc57e4fec1cd.jsonl 13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1d3a0cd6-f0e6-4106-a080-524a4bd50016.jsonl 14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29d54f5a-1dd0-4e9a-b783-fb2eec9db072.jsonl 15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29fd3d99-53fb-43e2-a4a5-2fd01bf77258.jsonl 16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2b224cd9-286e-46ac-8c4e-c1e3befc8760.jsonl 17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2c131fca-2a05-4d5f-a805-59d2af3477e2.jsonl 18 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2f28f1a7-6972-48ad-8997-65a5d52e4f1c.jsonl 19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_30440198-cd90-48c6-82c1-ea871b8c21c5.jsonl 20 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_39367d6c-d7d4-45fc-a929-8a17184d1744.jsonl 21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_393d19f2-1cd1-421f-be8a-78d955fdf602.jsonl 22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3a5d4f93-97ec-483a-88ef-324df9651b3f.jsonl 23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3c89ea11-69ff-4049-b775-f0c785997909.jsonl 24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3d5a011a-4bbe-4585-a2bd-ff3e943c8671.jsonl 25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f805f4b-6f7f-42a8-a006-47c1e0401bd7.jsonl 26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f9eb7ad-f266-4154-8d4d-54deeffde075.jsonl 27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_400748d3-0076-4a04-8a1c-6055ba0b5a2d.jsonl 28 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_44e19375-3995-4dff-a3b6-8a25247a165c.jsonl 29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4a8cf52f-81d0-4875-9528-466b1cbc71e1.jsonl 30 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4cc7015c-c39a-4bf6-9686-c00b3343edd9.jsonl 31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_50757a42-079b-41ec-bcca-73759faffd62.jsonl 32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_575ae832-e770-4a89-bfa7-c56f16dbca69.jsonl 33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_580be642-bb73-4d0d-8b5e-f494722934cd.jsonl 34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5a02d9ee-12a0-437d-808f-d26f0eb2012b.jsonl 35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5d8d402b-8277-480a-b5fa-71169726864f.jsonl 36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5ee33ef7-455e-4fd5-9512-c4771dd802c1.jsonl 37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_610c82ed-b9ee-449c-83b0-601205f3a74a.jsonl 38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_629fe3ca-075f-4663-9b81-b807f3b42bf2.jsonl 39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_64e5075e-e87e-4b2a-9e38-e5c102f6f2b1.jsonl 40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_65dd2ff6-dae3-4a60-90d3-c3d7349fc92f.jsonl 41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6719ecd2-fe34-4078-a584-320d921cbf6f.jsonl 42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6938ee72-43ee-4ade-8840-151a402383b0.jsonl 43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_73241940-66c1-481c-b53a-f5e8b9afe9fa.jsonl 44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_751370b5-c7cb-44d8-a039-1468ee6747ab.jsonl 45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_75af5d17-5ebb-4460-9f2a-dc9fe880a936.jsonl 46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_79d50803-f7d9-4aa8-bf1a-d807980a40c6.jsonl 47 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7b26046f-7c8d-405b-911b-df51e1a069fa.jsonl 48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7d1d69dc-bc8e-4817-9cab-afdc002ab7c4.jsonl 49 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7ea7a996-b1bb-4773-a36a-461dce2de861.jsonl 50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8232f276-9e3f-463a-9350-362de1b501d1.jsonl 51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8509f5a7-64a8-4813-92dc-f6eb53e3aacc.jsonl 52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_85b4c166-469d-449c-ab3d-5214c1d80246.jsonl 53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_872b620a-b4fd-45d3-92bc-ff0584447705.jsonl 54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_88f24f8d-16d3-4a21-894d-192033d0fa67.jsonl 55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8e6bd730-0f10-49d9-9b02-5ce16da47483.jsonl 56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8ede1b71-6846-439a-acba-86a57cfec3d2.jsonl 57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8f74f6ba-1c53-42d5-a3c7-e4ef46a71133.jsonl 58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_90fa9c2b-25b0-47b7-af2b-a683356e543b.jsonl 59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_92ec488a-287d-4bf0-977b-6998cf0cf476.jsonl 60 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94a393a1-3b23-4961-a0a6-70bad5b4979c.jsonl 61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94b9df70-a95f-4545-be3a-5a34f7b09fb3.jsonl 62 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_95ffc9e1-c505-4a3b-8fb0-cbc98b8703e1.jsonl 63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_98e718fd-5b0e-439f-a00c-57b61e06b395.jsonl 64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f50a028-2586-4e0d-bcfd-d9d2d74e8953.jsonl 65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f8d7d10-dda7-4e44-b00c-811635a199c8.jsonl 66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a055bd62-1ec2-47cf-bad2-321e3d4f053f.jsonl 67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a1e3430e-ef5c-4a86-914d-88e8fb7818c0.jsonl 68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a2b4cb3d-bea3-478e-82a2-77c00a827250.jsonl 69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a647274d-799d-4e7a-a485-b8632a87061e.jsonl 70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a94ea420-99ae-4d58-9cdc-d4666e3322a7.jsonl 71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ab7c0034-7fc1-4fa8-bae3-e97b85fc16a4.jsonl 72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b0732cff-657e-4e69-87b8-66e8025bf441.jsonl 73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b1f3b912-a2ab-43bd-8811-43d84b422506.jsonl 74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b6c2e4d3-d215-4d99-891f-d20b997d4d5a.jsonl 75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bbfabe6b-b9bc-476b-b8f0-7d6c47e9d2be.jsonl 76 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bf11ef29-a3f9-4a2d-9dbf-ebcc56d39fdb.jsonl 77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c05da42b-4939-4f55-867c-16cf6d228e60.jsonl 78 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c1fc3dd5-861f-4b8d-b7a2-eb8f6887b33b.jsonl 79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c44164b4-0770-48d0-87db-590ca529032a.jsonl 80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6b43cda-ca5c-4855-9c08-0f8264cab1af.jsonl 81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6e57a16-4879-4dcf-b591-503cfb46a360.jsonl 82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ca6e3842-6ca4-4230-84fa-376a3374c380.jsonl 83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_caf769e4-7308-4419-9114-900ca213682a.jsonl 84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d0da78e9-3dcf-4a46-85c1-f23ed00178bc.jsonl 85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d386838c-a51e-4839-9f27-8495b2466e49.jsonl 86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d41edf9f-0ebb-4866-b3fe-50785746b36b.jsonl 87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d6a7fc44-b584-4dd8-9de2-e981afe0bb4a.jsonl 88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d7c49dbb-c008-47fc-9cbe-8d5695842d21.jsonl 89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db3b4be7-4e98-4fe9-96bf-05a5788815e3.jsonl 90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db981f69-9eca-4031-8565-318b949efbfe.jsonl 91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_dbd4105f-7cbb-4483-a7b2-96b17b7fb594.jsonl 92 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de42e348-b333-4d35-b883-9bfc94f29822.jsonl 93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de744938-fa6c-45dd-b600-428dd7c63a73.jsonl 94 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_e8f25867-697d-4f52-84e1-e50a95bc182b.jsonl 95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_eb4e26d4-6625-4f8a-b5fe-6f3a9b8a4b79.jsonl 96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f141b736-5ce4-4f18-bb29-704227ca4bd1.jsonl 97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f50efdc6-f88e-4fa6-9ef6-dd1d8314bb36.jsonl 98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f7680c03-70df-4781-a98d-c88695f92f04.jsonl 99 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fbc62949-624d-4943-9731-f5c46242ba55.jsonl 100 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fd572627-cce7-4667-a684-fef096dfbeb7.jsonl -------------------------------------------------------------------------------- /llama_legacy/urls/book.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/book/book.jsonl -------------------------------------------------------------------------------- /llama_legacy/urls/github.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_08cdfa755e6d4d89b673d5bd1acee5f6.sampled.jsonl 2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f27d10d846a473b96070c3394832f32.sampled.jsonl 3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f979046c8e64e0fb5843d2634a9957d.sampled.jsonl 4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_10f129bfd0af45caa9cd72aa9d863ec5.sampled.jsonl 5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_11a1943edfa349c7939382799599eed6.sampled.jsonl 6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_17197bd2478044bebd9ff4634b6dfcee.sampled.jsonl 7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_1d750c0ce39d40c6bc20bad9469e5a99.sampled.jsonl 8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_21078cf63afb4d9eb4a7876f726a7226.sampled.jsonl 9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_216883d3a669406699428bc485a4c228.sampled.jsonl 10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_24278a707deb445b8e4f59c83dd67910.sampled.jsonl 11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_25989b8233a04ac791b0eccd502e0c7a.sampled.jsonl 12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_26a6fa61c5eb4bb885e7bc643e285f0e.sampled.jsonl 13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_275100c1a44f4451b0343373ebc5637a.sampled.jsonl 14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_27f05c041a1c401783f90b9415e40e4b.sampled.jsonl 15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_28106e66bfd94978abbc15ec845aeddb.sampled.jsonl 16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_2afd093fefad4c8da76cc539e8fb6137.sampled.jsonl 17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_30d01c4edab64866bda8c609e90b4f4e.sampled.jsonl 18 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34b0785b77814b7583433ddb27a61ae0.sampled.jsonl 19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34e78793c1e94eeebd92852399097596.sampled.jsonl 20 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_366012588bef4d749bbbea76ae701141.sampled.jsonl 21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_36c2a955ddd84672bbc778aa4ad2fbaf.sampled.jsonl 22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3be6a5cc7428401393c23e5516a43537.sampled.jsonl 23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3cc6cab9266746a6befa23648aa43119.sampled.jsonl 24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3f09f384f0734a4b912bb75db3f812bc.sampled.jsonl 25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_478afe8aaccb43e6be2de7e34e041ef3.sampled.jsonl 26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_483f11d6bc864f7fbfbe63bdf3583ce2.sampled.jsonl 27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4b6883dc304c4e799620ec95b96dc91a.sampled.jsonl 28 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4bebabdbd8544da7a2071864ccf81f2e.sampled.jsonl 29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4d44caf43c154ae4aaeab36eab0221c9.sampled.jsonl 30 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4f98dc136ba94eeaa1ba6c974814b33c.sampled.jsonl 31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4ff3604761614a3db550ef758e6457b5.sampled.jsonl 32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_50423e84046948b4a2a70e7e4538e12d.sampled.jsonl 33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_58fad5994b4446c6bceb33453484acb4.sampled.jsonl 34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5abf005e4e634f1dbfa8bd20b5687092.sampled.jsonl 35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5f2b3517159b426bb1a9e81ca189abcd.sampled.jsonl 36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_610e908cafaa4d53958de50ad700822a.sampled.jsonl 37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6353ab14cb8f4623a7d75678d9e7f44e.sampled.jsonl 38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_64747f25102740bab0ab54559569342a.sampled.jsonl 39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6742e99802894114a3ba44841f49b168.sampled.jsonl 40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_677e10f9b0af4c489e60670352e7e224.sampled.jsonl 41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_68534e6a093744fa9f38fa1a9cf51232.sampled.jsonl 42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6adcf92fb2ee48059cb60579f2e931f7.sampled.jsonl 43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_70f0aa43987643a7874286bca4faa66b.sampled.jsonl 44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7462f1a34f594e9f8334d9a0cbbf80e7.sampled.jsonl 45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78a13a25258f4d24923702c07445e20e.sampled.jsonl 46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78e3afc1cfee41fbb7eaae2e5bfaa17b.sampled.jsonl 47 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_791ed086a57f4879bb1596bed6d37bb3.sampled.jsonl 48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7b5cc18857a34a0981b54f082de55cf8.sampled.jsonl 49 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7c54b908a2df4ec2ba316d2081fc674e.sampled.jsonl 50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7edd0af0b61c426e93e8bd3f549a8f78.sampled.jsonl 51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7fa232e63a5d44a88a267181e9ac47b4.sampled.jsonl 52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_81138c01f3a84f7fa8b56bf3d8fa35ce.sampled.jsonl 53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_82a9a647b0eb4eb080d7ac15a13c765b.sampled.jsonl 54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_835ac26c7a70447b97f4ec38bcb969ed.sampled.jsonl 55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8bc3c4fae78c41d999b2ae6d97cce96c.sampled.jsonl 56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8dc430f8f7114f018440c7b3d990e602.sampled.jsonl 57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8e1db2cf6c98420a88bad52fd57f4aa7.sampled.jsonl 58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_919156910e704e6ca52e0d4880cdbb63.sampled.jsonl 59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_944beac1347f491faa88f43c25d26fe4.sampled.jsonl 60 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_96175b6e4c764abfbbf14e78d4fd6464.sampled.jsonl 61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_977d1fcdab92452d9dc38b2f4a99389b.sampled.jsonl 62 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_9eea501af8544d0b88f0b002850829d4.sampled.jsonl 63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a26c51ffe2924fd8ad6694c6aa0eacc5.sampled.jsonl 64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3a01a2f81ed4cd2afb30e175200b48f.sampled.jsonl 65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3d3e6f3f7d5495ca9ecf94d808fd350.sampled.jsonl 66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a777da5620f1467f8df3616b17d533dc.sampled.jsonl 67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a8fcae75b0c3410faabcff02f0056a36.sampled.jsonl 68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_aabc2d54d1c946908d3400228e0f238c.sampled.jsonl 69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ad160286662849f49d1e6de27c0f1d15.sampled.jsonl 70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_adda41d791974289aff042e2e3d07ec3.sampled.jsonl 71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ae1813abc63f4b1998dfa608e7fe5588.sampled.jsonl 72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_af7b386db97e4211a7378be08d7b3f4f.sampled.jsonl 73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b15c575fa9f8465d98f70ba2f2f73c6e.sampled.jsonl 74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b40763d0b9ce4e0d8fb5f519f1f49f8c.sampled.jsonl 75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b56453718c5f46efa9c46feb194b0d6e.sampled.jsonl 76 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b6225e159b86432d9fa5bf226bb51393.sampled.jsonl 77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b821f640b8f14ed588bf48ae13f44098.sampled.jsonl 78 | 
https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_bfbcad9633f04601ba2f824d082eaacf.sampled.jsonl 79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c6cfd16905814d7c955df1f4754a8b11.sampled.jsonl 80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c82c6775f0b74dbdae4524bb9aebf0ef.sampled.jsonl 81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ca73c69896b34adbbe468d78d9f134bc.sampled.jsonl 82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cac864a472b948b0bfe18c8e9a19aeb5.sampled.jsonl 83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cb330e8dc8ac411eba2dc8676c9c4403.sampled.jsonl 84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d0a054a678fc4c38b496d10e91a2c735.sampled.jsonl 85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d148857715424aabbd32b8ffe56c4082.sampled.jsonl 86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d33701435f964c90a86c22e204dd5fde.sampled.jsonl 87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d448c2553f474193a1224df1c38f74d4.sampled.jsonl 88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d8d45c58819948c9b352d74383944c4a.sampled.jsonl 89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_db611a692a704f6db3c18e77f79fd2f0.sampled.jsonl 90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dcaa5c8b729b4fb599399dbf4557e43e.sampled.jsonl 91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dda863d833614d04b96bbe21b161768d.sampled.jsonl 92 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eabf427d56184fb89f9b5f27e73f7988.sampled.jsonl 93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eee6980922ca4a14b0c3341fa8a904d9.sampled.jsonl 94 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f478ff72b57f4f4283c22ac22ae84134.sampled.jsonl 95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f931feb0e85940879d194c0e20d9e28a.sampled.jsonl 96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f9bbe6a065004a4c8e018c6ad63063b2.sampled.jsonl 97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fbfc552e48164acda6605fa31fc2f563.sampled.jsonl 98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fc817e4c16494957bcb89b166e91434f.sampled.jsonl -------------------------------------------------------------------------------- /llama_legacy/urls/stackexchange.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/stackexchange/stackexchange.jsonl -------------------------------------------------------------------------------- /llama_legacy/urls/wikipedia.txt: -------------------------------------------------------------------------------- 1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/wikipedia/wiki.jsonl -------------------------------------------------------------------------------- /llama_legacy/weight_diff.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This file has been changed by Amirkeivan Mohtashami 16 | # to take into account the new token in the embedding layer 17 | 18 | from typing import Optional 19 | 20 | import fire 21 | import torch 22 | import tqdm 23 | import transformers 24 | from train import smart_tokenizer_and_embedding_resize 25 | 26 | 27 | @torch.inference_mode() 28 | def make_diff( 29 | path_raw: str, path_tuned: str, path_diff: str, device="cpu", # "cuda" or "cpu" 30 | ): 31 | """Make the weight diff. 32 | 33 | This function is given to present full transparency of how the weight diff was created. 34 | 35 | Run: 36 | python weight_diff.py make_diff --path_raw --path_tuned --path_diff 37 | """ 38 | model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained( 39 | path_tuned, 40 | device_map={"": torch.device(device)}, 41 | torch_dtype=torch.float32, 42 | low_cpu_mem_usage=True, 43 | ) 44 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained( 45 | path_raw, 46 | device_map={"": torch.device(device)}, 47 | torch_dtype=torch.float32, 48 | low_cpu_mem_usage=True, 49 | ) 50 | 51 | tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 52 | path_tuned 53 | ) 54 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 55 | path_raw 56 | ) 57 | smart_tokenizer_and_embedding_resize( 58 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=[""]), 59 | model=model_raw, 60 | tokenizer=tokenizer_raw, 61 | ) 62 | 63 | state_dict_tuned = model_tuned.state_dict() 64 | state_dict_raw = model_raw.state_dict() 65 | for key in tqdm.tqdm(state_dict_tuned): 66 | state_dict_tuned[key].add_(-state_dict_raw[key]) 67 | 68 | model_tuned.save_pretrained(path_diff) 69 | tokenizer_tuned.save_pretrained(path_diff) 70 | 71 | 72 | @torch.inference_mode() 73 | def recover( 74 | path_raw, 75 | path_diff, 76 | path_tuned: Optional[str] = None, 77 | device="cpu", 78 | test_inference=True, 79 | check_integrity_naively=True, 80 | ): 81 | """Recover the original weights from the released weight diff. 82 | 83 | This function is given for you to run. 84 | 85 | Things to do before running this: 86 | 1. Convert Meta's released weights into huggingface format. Follow this guide: 87 | https://huggingface.co/docs/transformers/main/model_doc/llama 88 | 2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at: 89 | https://huggingface.co/tatsu-lab/alpaca-7b/tree/main 90 | 3. Run this function with the correct paths. E.g., 91 | python weight_diff.py recover --path_raw --path_diff 92 | 93 | Additional notes: 94 | - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`. 95 | - If you want to save the recovered weights, set `--path_tuned `. 96 | Next time you can load the recovered weights directly from ``. 
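    For concreteness, here is a hypothetical invocation; the three paths are
    placeholders for wherever you keep the converted raw weights, the cloned
    diff, and the desired output directory (they are not shipped defaults):
        python weight_diff.py recover --path_raw ./llama-7b-hf --path_diff ./landmark-diff --path_tuned ./landmark-7b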
97 | """ 98 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained( 99 | path_raw, 100 | device_map={"": torch.device(device)}, 101 | torch_dtype=torch.float32, 102 | low_cpu_mem_usage=True, 103 | ) 104 | model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained( 105 | path_diff, 106 | device_map={"": torch.device(device)}, 107 | torch_dtype=torch.float32, 108 | low_cpu_mem_usage=True, 109 | ) 110 | 111 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 112 | path_raw 113 | ) 114 | smart_tokenizer_and_embedding_resize( 115 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=[""]), 116 | model=model_raw, 117 | tokenizer=tokenizer_raw, 118 | ) 119 | tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained( 120 | path_diff 121 | ) 122 | 123 | state_dict_recovered = model_recovered.state_dict() 124 | state_dict_raw = model_raw.state_dict() 125 | for key in tqdm.tqdm(state_dict_recovered): 126 | state_dict_recovered[key].add_(state_dict_raw[key]) 127 | 128 | if check_integrity_naively: 129 | # This is not a rigorous, cryptographically strong integrity check :) 130 | allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered) 131 | assert torch.allclose( 132 | allsum, torch.full_like(allsum, fill_value=49798.7656), atol=1e-2, rtol=0 133 | ), "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted." 134 | 135 | if path_tuned is not None: 136 | model_recovered.save_pretrained(path_tuned) 137 | tokenizer_recovered.save_pretrained(path_tuned) 138 | 139 | return model_recovered, tokenizer_recovered 140 | 141 | 142 | def main(task, **kwargs): 143 | globals()[task](**kwargs) 144 | 145 | 146 | if __name__ == "__main__": 147 | fire.Fire(main) 148 | -------------------------------------------------------------------------------- /lm_benchmark/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import rotary 16 | 17 | CONFIG_FORMAT_TO_MODULE_MAP = { 18 | "rotary": rotary, 19 | } 20 | 21 | 22 | def parse_args_with_format(format, base_parser, args, namespace): 23 | return CONFIG_FORMAT_TO_MODULE_MAP[format].parse_args(base_parser, args, namespace) 24 | 25 | 26 | def registered_formats(): 27 | return CONFIG_FORMAT_TO_MODULE_MAP.keys() 28 | -------------------------------------------------------------------------------- /lm_benchmark/config/rotary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import torch 17 | import distributed 18 | import models 19 | 20 | def none_or_str(value): 21 | if value == 'None': 22 | return None 23 | return value 24 | 25 | def none_or_int(value): 26 | if value == 'None': 27 | return None 28 | return int(value) 29 | 30 | def none_or_float(value): 31 | if value == 'None': 32 | return None 33 | return float(value) 34 | 35 | def parse_args(base_parser, args, namespace): 36 | parser = base_parser 37 | # General training params 38 | parser.add_argument('--batch_size', default=50, type=int) 39 | parser.add_argument('--acc_steps', default=4, type=int) 40 | parser.add_argument('--seed', default=2, type=int) 41 | parser.add_argument('--device', default='cuda:0', type=str) 42 | parser.add_argument('--iterations', default=15000, type=int) 43 | parser.add_argument('--lr', default=2e-3, type=float) 44 | parser.add_argument('--warmup_percent', default=0.02, type=float) 45 | parser.add_argument('--weight_decay', default=1e-3, type=float) 46 | parser.add_argument('--beta1', default=0.9, type=float) 47 | parser.add_argument('--beta2', default=0.95, type=float) 48 | parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) 49 | parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'adafactor']) 50 | parser.add_argument('--eval_freq', default=200, type=int) # in iterations 51 | parser.add_argument('--results_base_folder', default="./exps", type=str) 52 | parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) 53 | 54 | # Dataset params 55 | parser.add_argument('--dataset', choices=['pg19', 'arxivmath']) 56 | parser.add_argument('--vocab_size', default=50304, type=int) 57 | parser.add_argument('--mem_freq', default=50, type=none_or_int, required=False, help="Frequency of landmark tokens") 58 | 59 | # Model params 60 | parser.add_argument('--model', default='base_rotary', choices=models.registered_models()) 61 | parser.add_argument('--dropout', default=0.0, type=float) 62 | parser.add_argument('--group_dropout', default=None, type=float, required=False) 63 | parser.add_argument('--n_head', default=8, type=int) 64 | parser.add_argument('--n_layer', default=12, type=int) # depths in att + ff blocks 65 | parser.add_argument('--n_embd', default=1024, type=int) # embedding size / hidden size ... 
66 | parser.add_argument('--sequence_length', default=512, type=int) 67 | parser.add_argument('--dtype', default="torch.bfloat16", type=str) 68 | parser.add_argument('--bias', default=False, type=bool) # caution: argparse's type=bool treats any non-empty string (even "False") as True 69 | parser.add_argument('--no_compile', action='store_true') # if true then model is not compiled 70 | parser.add_argument('--run_prefix', default=None, type=str, required=False) # is added before the autogenerated experiment name 71 | parser.add_argument('--exp_name', default=None, type=str, required=False) # overrides the autogenerated experiment name 72 | parser.add_argument('--softmax_func', default="mem_opt", type=str, required=False, 73 | choices=["mem_opt", "nomem", "mem", "ignore_mem"]) # softmax variant used for landmark attention 74 | parser.add_argument('--positional_encoder', default="rotary", type=str, required=False, 75 | choices=models.positional_encoders.registered_encoders()) # positional encoding scheme 76 | # logging params (WandB) 77 | parser.add_argument('--wandb', action='store_true') # whether to use wandb or not 78 | parser.add_argument('--wandb_project', default="my-project", type=str) 79 | # Distributed args 80 | parser.add_argument('--distributed_backend', default=None, type=none_or_str, required=False, 81 | choices=distributed.registered_backends()) # distributed backend type 82 | # Landmark tokens 83 | parser.add_argument('--max_groups_for_softmax', default=16, type=int, required=False, help="Should be at least 2 + max. number of landmark tokens in one chunk.") 84 | # Inference 85 | parser.add_argument('--use_cache', action='store_true') 86 | parser.add_argument('--lm_cache', default="none", type=str, required=False, 87 | choices=models.caches.registered_caches()) 88 | parser.add_argument('--mem_cache_size', default=None, type=int, required=False) 89 | parser.add_argument('--mem_cache_freq', default=None, type=int, required=False, help="Frequency to add landmark tokens in the input (block size at inference)") 90 | parser.add_argument('--cache_topk', default=1, type=int, required=False) 91 | parser.add_argument('--cache_selection_method', default="per_token_and_head", type=str, required=False) 92 | parser.add_argument('--eval_seq_length', default=512, type=int, required=False, help="Evaluation sequence length") 93 | parser.add_argument('--eval_sample_size', default=None, type=none_or_int, required=False, help="Size of the random subset of validation set used for evaluation") 94 | parser.add_argument('--mid_length', default=250, type=int, required=False, help="Size of chunks to break the input into") 95 | parser.add_argument('--allow_cache_during_training', action='store_true') 96 | parser.add_argument('--postpone_lm_cache', action='store_true') 97 | parser.add_argument('--optimization_process', default="landmark", type=str, required=False, 98 | choices=["transformer_xl", "landmark"]) # training procedure to use 99 | 100 | # CMT Token 101 | parser.add_argument('--under_rem_score_prob', default=0., type=none_or_float, required=False) 102 | parser.add_argument('--rem_cutoff', default=None, type=none_or_float, required=False) 103 | parser.add_argument('--enable_rem_score', default=False, action='store_true', required=False) 104 | 105 | # Positional Augmentation 106 | parser.add_argument('--pos_jump_on_mem', default=None, type=none_or_int, required=False) 107 | 108 | # Transformer XL 109 | parser.add_argument('--total_sequence_length', default=None, type=int, required=False) 110 | 111 | 112 | args = parser.parse_args(args, namespace) 113 | 114 | if args.exp_name is None: 115 |
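# The block below autogenerates the experiment name: it encodes the core
# hyperparameters (model, lr, mem_freq, batch size x acc_steps, sequence
# length) and appends every remaining argument whose value differs from its
# parser default, chunked ten per path segment to keep directory names short.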
special_name_handle_fields = {"model", "lr", "batch_size", 116 | "acc_steps", "seed", "exp_name", 117 | "wandb", "wandb_project", 118 | "run_prefix", "distributed_backend", "config_format", 119 | "sequence_length", "mem_freq"} 120 | overridden_values = [] 121 | for key in vars(args): 122 | if key in special_name_handle_fields: 123 | continue 124 | if getattr(args, key) != parser.get_default(key): 125 | overridden_values.append((key, getattr(args, key))) 126 | chunk_len = 10 127 | overridden_values_str_parts = [] 128 | for chunk_id in range(0, len(overridden_values), chunk_len): 129 | overridden_values_str = "_".join(["{}={}".format(key, value) for key, value in overridden_values[chunk_id:chunk_id+chunk_len]]) 130 | overridden_values_str_parts.append(overridden_values_str) 131 | overridden_values_str = "/".join(overridden_values_str_parts) 132 | exp_name = "" 133 | if args.run_prefix is not None: 134 | exp_name += f"{args.run_prefix}_" 135 | exp_name += f"{args.model}_lr{args.lr}_memfreq{args.mem_freq}_bs{args.batch_size}x{args.acc_steps}_seqlen{args.sequence_length}/{overridden_values_str}_seed={args.seed}" 136 | args.exp_name = exp_name 137 | 138 | args.landmark_id = 50260 139 | if args.dtype == "torch.bfloat16": 140 | args.dtype = torch.bfloat16 141 | elif args.dtype == "torch.float16": 142 | args.dtype = torch.float16 143 | 144 | landmark_freq = max(args.mem_cache_freq or 0, args.mem_freq or 0) 145 | if landmark_freq != 0 and args.max_groups_for_softmax < args.sequence_length // landmark_freq + 1 + 2: 146 | print("CRITICAL WARNING: Maximum number of groups for softmax is too low. Adjust with --max_groups_for_softmax.") 147 | 148 | 149 | return args 150 | -------------------------------------------------------------------------------- /lm_benchmark/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import pg19, arxiv_math 16 | 17 | PREPARE_GET_DATASET_MAP = { 18 | "pg19": (pg19.prepare_pg19_data, pg19.get_pg19_data), 19 | "arxivmath": (arxiv_math.prepare_arxivmath_data, arxiv_math.get_arxivmath_data) 20 | } 21 | 22 | 23 | def prepare_dataset(args): 24 | """ Prepare the dataset given by the args.dataset parameter. The logic for each dataset is 25 | contained in its own python file. The expected format at the moment is a dictionary of np.memmap 26 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """ 27 | return PREPARE_GET_DATASET_MAP[args.dataset][0](args) 28 | 29 | def get_dataset(args): 30 | """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is 31 | contained in its own python file. The expected format at the moment is a dictionary of np.memmap 32 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data.
""" 33 | return PREPARE_GET_DATASET_MAP[args.dataset][1](args) 34 | -------------------------------------------------------------------------------- /lm_benchmark/data/arxiv_math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import zipfile 17 | import urllib 18 | import numpy as np 19 | import tiktoken 20 | import torch 21 | import regex 22 | import multiprocessing 23 | import itertools 24 | import functools 25 | 26 | from .utils import add_mem_tokens 27 | 28 | ARXIVMATH_ORIGINAL_PATH = "./data/proof-pile/" 29 | 30 | 31 | def get_path(config): 32 | dataset_name = f"arxiv_mem={config.mem_freq}" 33 | return os.path.join(os.path.dirname(__file__), f"datasets/{dataset_name}/") 34 | 35 | def prepare_arxivmath_data(config): 36 | DATA_PATH = get_path(config) 37 | print(DATA_PATH) 38 | os.makedirs(DATA_PATH, exist_ok=True) 39 | if not os.path.exists(os.path.join(DATA_PATH, 'train.bin')): 40 | train_data = np.memmap(os.path.join(ARXIVMATH_ORIGINAL_PATH, 'train.bin'), dtype=np.uint16, mode='r') 41 | raw_tokenized_train = add_mem_tokens(config.landmark_id, train_data, config.mem_freq) 42 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) 43 | train_tokenized.tofile(os.path.join(DATA_PATH, 'train.bin')) 44 | 45 | if not os.path.exists(os.path.join(DATA_PATH, 'val.bin')): 46 | val_data = np.memmap(os.path.join(ARXIVMATH_ORIGINAL_PATH, 'validation.bin'), dtype=np.uint16, mode='r') 47 | raw_tokenized_eval = add_mem_tokens(config.landmark_id, val_data, config.mem_freq) 48 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16) 49 | eval_tokenized.tofile(os.path.join(DATA_PATH, 'val.bin')) 50 | print("completed the tokenization process!") 51 | 52 | 53 | def get_arxivmath_data(config): 54 | DATA_PATH = get_path(config) 55 | 56 | train_data = np.memmap(os.path.join(DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 57 | val_data = np.memmap(os.path.join(DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 58 | 59 | return {'train': train_data, 'val': val_data} 60 | -------------------------------------------------------------------------------- /lm_benchmark/data/pg19.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import zipfile 17 | import urllib 18 | import numpy as np 19 | import tiktoken 20 | import torch 21 | import regex 22 | import multiprocessing 23 | import itertools 24 | import functools 25 | 26 | from .utils import add_mem_tokens 27 | 28 | 29 | PG19_ORIGINAL_PATH = "./data/pg19" 30 | 31 | 32 | def get_path(config): 33 | dataset_name = f"pg19_mem={config.mem_freq}" 34 | return os.path.join(os.path.dirname(__file__), f"datasets/{dataset_name}/") 35 | 36 | def prepare_pg19_data(config): 37 | DATA_PATH = get_path(config) 38 | print(DATA_PATH) 39 | os.makedirs(DATA_PATH, exist_ok=True) 40 | if not os.path.exists(os.path.join(DATA_PATH, 'train.bin')): 41 | train_data = np.memmap(os.path.join(PG19_ORIGINAL_PATH, 'train.bin'), dtype=np.uint16, mode='r') 42 | raw_tokenized_train = add_mem_tokens(config.landmark_id, train_data, config.mem_freq) 43 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) 44 | train_tokenized.tofile(os.path.join(DATA_PATH, 'train.bin')) 45 | 46 | if not os.path.exists(os.path.join(DATA_PATH, 'val.bin')): 47 | val_data = np.memmap(os.path.join(PG19_ORIGINAL_PATH, 'validation.bin'), dtype=np.uint16, mode='r') 48 | raw_tokenized_eval = add_mem_tokens(config.landmark_id, val_data, config.mem_freq) 49 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16) 50 | eval_tokenized.tofile(os.path.join(DATA_PATH, 'val.bin')) 51 | print("completed the tokenization process!") 52 | 53 | 54 | def get_pg19_data(config): 55 | DATA_PATH = get_path(config) 56 | 57 | train_data = np.memmap(os.path.join(DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 58 | val_data = np.memmap(os.path.join(DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 59 | 60 | return {'train': train_data, 'val': val_data} 61 | -------------------------------------------------------------------------------- /lm_benchmark/data/pg19/README.md: -------------------------------------------------------------------------------- 1 | Download the dataset from https://github.com/deepmind/pg19 into this folder. Then run prepare.py 2 | -------------------------------------------------------------------------------- /lm_benchmark/data/pg19/prepare.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
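# Overview (a summary of the code below, given the layout the README in this
# folder describes): the PG-19 'train' and 'validation' directories sit next
# to this file and contain numbered .txt books. Each book is encoded with the
# GPT-2 BPE via tiktoken, the end-of-text token is appended after each book,
# and the concatenated stream is written as uint16 to train.bin /
# validation.bin for np.memmap consumption by data/pg19.py.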
14 | 15 | import os 16 | import tiktoken 17 | 18 | import numpy as np 19 | 20 | 21 | 22 | gpt2_tokenizer = tiktoken.get_encoding("gpt2") 23 | def _read_directory(path): 24 | texts = [] 25 | for filename in os.listdir(path): 26 | if filename.endswith(".txt") and filename[:-4].isnumeric(): 27 | print(filename) 28 | with open(os.path.join(path, filename), 'r') as f: 29 | texts += gpt2_tokenizer.encode_ordinary(f.read()) 30 | texts.append(gpt2_tokenizer.eot_token) 31 | return np.array(texts, dtype=np.uint16) 32 | 33 | 34 | raw_eval_data = _read_directory("validation") 35 | raw_eval_data.tofile('validation.bin') 36 | raw_train_data = _read_directory("train") 37 | raw_train_data.tofile('train.bin') 38 | -------------------------------------------------------------------------------- /lm_benchmark/data/proof-pile/prepare.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import os 17 | 18 | from datasets import load_dataset 19 | import numpy as np 20 | import tiktoken 21 | from tqdm import tqdm 22 | 23 | 24 | dir_path = os.path.dirname(os.path.realpath(__file__)) 25 | 26 | 27 | dataset = load_dataset("hoskinson-center/proof-pile", cache_dir=os.path.join(dir_path, "cache")) 28 | 29 | 30 | num_proc = 16 31 | arxiv = dataset.filter(lambda x: json.loads(x['meta']).get('config', None) == "arxiv", num_proc=num_proc) 32 | 33 | 34 | enc = tiktoken.get_encoding("gpt2") 35 | def process(example): 36 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 37 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 38 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
39 | out = {'ids': ids, 'len': len(ids)} 40 | return out 41 | 42 | # tokenize the dataset 43 | tokenized = arxiv.map( 44 | process, 45 | remove_columns=['text'], 46 | desc="tokenizing the splits", 47 | num_proc=num_proc, 48 | ) 49 | 50 | 51 | for split, dset in tokenized.items(): 52 | arr_len = np.sum(dset['len']) 53 | filename = os.path.join(dir_path, f'{split}.bin') 54 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 55 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 56 | total_batches = 1024 57 | 58 | idx = 0 59 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 60 | # Batch together samples for faster write 61 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 62 | arr_batch = np.concatenate(batch['ids']) 63 | # Write into mmap 64 | arr[idx : idx + len(arr_batch)] = arr_batch 65 | idx += len(arr_batch) 66 | arr.flush() 67 | -------------------------------------------------------------------------------- /lm_benchmark/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | import multiprocessing 18 | import itertools 19 | import functools 20 | 21 | 22 | def apply_add_mem_tokens(mem_id, tokens_filename, freq, start_idx, end_idx): 23 | tokens = np.memmap(tokens_filename, dtype=np.uint16, mode='r') 24 | print(f"Processing {start_idx}-{end_idx}") 25 | tokens_with_mem = [] 26 | for t_idx in range(start_idx, end_idx): 27 | t = tokens[t_idx] 28 | tokens_with_mem.append(t) 29 | if freq is not None and t_idx % freq == freq - 1: 30 | tokens_with_mem.append(mem_id) 31 | return tokens_with_mem 32 | 33 | def add_mem_tokens(mem_id, tokens, freq, n_workers=32): 34 | print(len(tokens)) 35 | with multiprocessing.Pool(n_workers) as pool: 36 | ids = list(range(0, len(tokens), 10 * 1000 * 1000)) 37 | pair_ids = zip(ids, ids[1:] + [len(tokens)]) 38 | apply = functools.partial(apply_add_mem_tokens, mem_id, tokens.filename, freq) 39 | return list(itertools.chain.from_iterable(pool.starmap(apply, pair_ids))) 40 | -------------------------------------------------------------------------------- /lm_benchmark/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from . import ddp 17 | from . import single 18 | 19 | BACKEND_TYPE_TO_MODULE_MAP = { 20 | "nccl": ddp.DataParallelDistributedBackend, 21 | None: single.SingleNodeBackend, 22 | } 23 | 24 | 25 | def make_backend_from_args(args): 26 | return BACKEND_TYPE_TO_MODULE_MAP[args.distributed_backend](args) 27 | 28 | 29 | def registered_backends(): 30 | return BACKEND_TYPE_TO_MODULE_MAP.keys() 31 | -------------------------------------------------------------------------------- /lm_benchmark/distributed/backend.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List 17 | 18 | 19 | class DistributedBackend(object): 20 | 21 | def __init__(self, args): 22 | pass 23 | 24 | def transform_model(self, model): 25 | raise NotImplementedError 26 | 27 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps): 28 | raise NotImplementedError 29 | 30 | def is_master_process(self) -> bool: 31 | raise NotImplementedError 32 | 33 | def get_adjusted_args_for_process(self, args): 34 | raise NotImplementedError 35 | 36 | def get_raw_model(self, model): 37 | raise NotImplementedError 38 | 39 | def translate_model_parameter_name_for_node(self, parameter_name) -> List[str]: 40 | raise NotImplementedError 41 | 42 | def get_world_size(self): 43 | raise NotImplementedError 44 | 45 | def sync(self): 46 | raise NotImplementedError 47 | 48 | def finalize(self): 49 | pass 50 | -------------------------------------------------------------------------------- /lm_benchmark/distributed/ddp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
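# How this backend splits work, as implemented in get_adjusted_args_for_process
# below: the effective batch size (batch_size * acc_steps) must be divisible by
# the world size; gradient-accumulation steps are first reduced by
# gcd(acc_steps, world_size), and the remaining factor divides the per-step
# batch size. Worked example: batch_size=50, acc_steps=4 on 2 GPUs gives
# gcd(4, 2) = 2, so each rank runs acc_steps=2 at batch_size=50, and
# 2 ranks * 2 steps * 50 sequences = 200 per update, matching the
# single-process schedule. Each rank also offsets the seed by its local rank.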
14 | 15 | import os 16 | import math 17 | from contextlib import contextmanager 18 | 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | from torch.distributed import init_process_group, destroy_process_group, get_world_size, barrier 21 | 22 | from .backend import DistributedBackend 23 | 24 | 25 | class DataParallelDistributedBackend(DistributedBackend): 26 | 27 | def __init__(self, args): 28 | self.rank = int(os.environ.get('RANK', -1)) 29 | assert self.rank != -1, "DDP backend cannot be used without rank" 30 | assert "cuda" in args.device, "DDP backend cannot be used on non-CUDA devices" 31 | init_process_group(backend=args.distributed_backend) 32 | self.local_rank = int(os.environ['LOCAL_RANK']) 33 | 34 | def get_adjusted_args_for_process(self, args): 35 | effective_batch_size = args.batch_size * args.acc_steps 36 | world_size = self.get_world_size() 37 | if effective_batch_size % world_size != 0: 38 | raise ValueError(f"Effective batch size " 39 | f"{effective_batch_size} is not divisible " 40 | f"by the world size {world_size}.") 41 | acc_steps_div = math.gcd(args.acc_steps, world_size) 42 | args.acc_steps = args.acc_steps // acc_steps_div 43 | args.batch_size = args.batch_size // (world_size // acc_steps_div) 44 | args.device = f'cuda:{self.local_rank}' 45 | args.seed = args.seed + self.local_rank 46 | return args 47 | 48 | def transform_model(self, model): 49 | return DDP(model, device_ids=[self.local_rank], find_unused_parameters=True) 50 | 51 | @contextmanager 52 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps): 53 | model.require_backward_grad_sync = ( 54 | microstep_idx == gradient_accumulation_steps - 1) 55 | yield 56 | 57 | def is_master_process(self) -> bool: 58 | return self.rank == 0 59 | 60 | def get_raw_model(self, model): 61 | return model.module 62 | 63 | def translate_model_parameter_name_for_node(self, parameter_name): 64 | return [f'module.{parameter_name}'] 65 | 66 | def get_world_size(self): 67 | return get_world_size() 68 | 69 | def sync(self): 70 | barrier() 71 | 72 | def finalize(self): 73 | destroy_process_group() 74 | -------------------------------------------------------------------------------- /lm_benchmark/distributed/single.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
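# Fallback backend for single-process runs: it implements the DistributedBackend
# interface with identity/no-op behaviour, so training and evaluation code stays
# agnostic to whether DDP is active.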
14 | 15 | from contextlib import nullcontext 16 | 17 | from .backend import DistributedBackend 18 | 19 | 20 | class SingleNodeBackend(DistributedBackend): 21 | 22 | def transform_model(self, model): 23 | return model 24 | 25 | def get_context_for_microstep_forward(self, *args, **kwargs): 26 | return nullcontext() 27 | 28 | def get_adjusted_args_for_process(self, args): 29 | return args 30 | 31 | def is_master_process(self) -> bool: 32 | return True 33 | 34 | def get_raw_model(self, model): 35 | return model 36 | 37 | def get_world_size(self): 38 | return 1 39 | 40 | def sync(self): 41 | pass 42 | 43 | def translate_model_parameter_name_for_node(self, parameter_name): 44 | return [parameter_name] 45 | -------------------------------------------------------------------------------- /lm_benchmark/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | import numpy as np 18 | import torch 19 | import inspect 20 | import json 21 | import copy 22 | import argparse 23 | import random 24 | import wandb 25 | import logging 26 | from contextlib import nullcontext 27 | from tqdm import tqdm 28 | 29 | import config 30 | import models 31 | from data import get_dataset, prepare_dataset 32 | from optim.base import train_base 33 | import distributed 34 | from optim.utils import get_batch 35 | 36 | 37 | 38 | def get_args(): 39 | parser = argparse.ArgumentParser(allow_abbrev=False) 40 | parser.add_argument('--checkpoint', type=str, required=True) 41 | 42 | args, rem_args = parser.parse_known_args() 43 | 44 | if os.path.isfile(args.checkpoint): 45 | args.checkpoint, args.checkpoint_filename = os.path.split(args.checkpoint) 46 | else: 47 | args.checkpoint_filename = "ckpt.pt" 48 | 49 | with open(os.path.join(args.checkpoint, "summary.json")) as f: 50 | summary = json.load(f) 51 | 52 | for k, v in summary['args'].items(): 53 | if k not in ["device", "dtype"]: 54 | setattr(args, k, v) 55 | 56 | return config.parse_args_with_format(format=args.config_format, base_parser=argparse.ArgumentParser(allow_abbrev=False), args=rem_args, namespace=args) 57 | 58 | 59 | def get_as_batch(data, seq_length, batch_size, device='cpu', sample_size=None): 60 | all_ix = list(range(0, len(data), seq_length)) 61 | assert all_ix[-1] + seq_length + 1 > len(data) 62 | all_ix.pop() 63 | if sample_size is not None: 64 | all_ix = np.random.choice(all_ix, size=sample_size // seq_length, replace=False).tolist() 65 | 66 | idx = 0 67 | for idx in range(0, len(all_ix), batch_size): 68 | ix = all_ix[idx:idx+batch_size] 69 | assert all([idx + seq_length + 1 <= len(data) for idx in ix]) 70 | x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix]) 71 | y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix]) 72 | if device != 'cpu': 73 | x, y = x.pin_memory().to(device, non_blocking=True),
y.pin_memory().to(device, non_blocking=True) 74 | yield x, y 75 | 76 | def iceildiv(x, y): 77 | return (x + y - 1) // y 78 | 79 | def evaluate(model, data, iterations, acc_steps, batch_size, sequence_length, distributed_backend, extra_args): 80 | device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu' 81 | type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast( 82 | device_type=device_type, dtype=extra_args.dtype) 83 | itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible 84 | 85 | stats = {} 86 | 87 | num_substeps_per_epoch = len(data['val']) // (batch_size * sequence_length) 88 | 89 | if not extra_args.no_compile: 90 | print(f"Compiling model ...") 91 | import torch._dynamo as torchdynamo 92 | torchdynamo.config.guard_nn_modules = True 93 | # torchdynamo.config.log_level = logging.DEBUG 94 | model = torch.compile(model) # requires pytorch 2.0+ 95 | 96 | model.eval() 97 | 98 | loss_list_val, acc_list = [], [] 99 | loss_step_list_val = [] 100 | 101 | max_num_batches = 400 102 | with torch.no_grad(): 103 | mid_length = extra_args.mid_length 104 | print(f"Sending sub-sequences of length at most {mid_length}") 105 | seq_length = extra_args.eval_seq_length 106 | print(f"Using seq length {seq_length}") 107 | torch.set_printoptions(sci_mode=False) 108 | for idx, (x, y) in tqdm( 109 | enumerate( 110 | get_as_batch( 111 | data['val'], 112 | seq_length, 113 | batch_size, 114 | device=extra_args.device, 115 | sample_size=extra_args.eval_sample_size 116 | ) 117 | ), 118 | total=iceildiv( 119 | extra_args.eval_sample_size // seq_length if extra_args.eval_sample_size is not None else 120 | iceildiv(len(data['val']), seq_length), 121 | batch_size 122 | ) 123 | ): 124 | val_loss = 0. 125 | acc = 0.
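# Each validation sequence is fed through the model in chunks of at most
# mid_length tokens; cache state persists across chunks (and is reset per
# sequence by clear_state() below), and loss/accuracy are accumulated
# token-weighted so chunks of different lengths average correctly.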
126 | cnt = 0 127 | model.clear_state() 128 | for part_idx, i in enumerate(range(0, x.shape[1], mid_length)): 129 | part_len = x[:, i:i + mid_length].shape[1] 130 | with type_ctx: 131 | outputs = model(x[:, i:i + mid_length], targets=y[:, i:i+mid_length].contiguous(), get_logits=True, use_cache=extra_args.use_cache) 132 | val_loss = outputs['loss'] * part_len + val_loss 133 | acc = ((outputs['logits'].argmax(-1) == y[:, i:i+mid_length]).float().sum()) + acc 134 | cnt += part_len 135 | while len(loss_step_list_val) <= part_idx: 136 | loss_step_list_val.append([]) 137 | loss_step_list_val[part_idx].append(outputs['loss'].item()) 138 | val_loss /= cnt 139 | acc /= cnt 140 | 141 | loss_list_val.append(val_loss.item()) 142 | acc_list.append(acc.item()) 143 | 144 | 145 | stats['val_acc'] = torch.as_tensor(acc_list).mean().item() 146 | stats['val_loss'] = torch.as_tensor(loss_list_val).mean().item() 147 | stats['val_perplexity'] = float(np.exp(stats['val_loss'])) # perplexity = e ** mean NLL 148 | stats['val_perplexity_per_chunk'] = torch.exp(torch.as_tensor(loss_step_list_val).mean(dim=1)) 149 | 150 | return stats 151 | 152 | def main(args): 153 | 154 | 155 | torch.backends.cuda.matmul.allow_tf32 = True # allows us to make sure we're able to use tensorfloat32 during training 156 | torch.backends.cudnn.allow_tf32 = True 157 | 158 | distributed_backend = distributed.make_backend_from_args(args) 159 | args = distributed_backend.get_adjusted_args_for_process(args) 160 | 161 | args.device = torch.device(args.device) 162 | torch.cuda.set_device(args.device) 163 | device_type = 'cuda' if 'cuda' in str(args.device) else 'cpu' 164 | 165 | torch.manual_seed(args.seed) 166 | random.seed(args.seed) 167 | np.random.seed(args.seed) 168 | 169 | print(f"Loading dataset '{args.dataset}'") 170 | 171 | if distributed_backend.is_master_process(): 172 | prepare_dataset(args) 173 | distributed_backend.sync() 174 | 175 | data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized} 176 | 177 | print(f"Num training tokens: {len(data['train'])}") 178 | print(f"Num validation tokens: {len(data['val'])}") 179 | 180 | model = models.make_model_from_args(args).to(args.device) 181 | 182 | checkpoint = torch.load(os.path.join(args.checkpoint, args.checkpoint_filename)) 183 | model.load_state_dict({x: y for x, y in checkpoint['model'].items() if "attn.bias" not in x and "wpe" not in x}, strict=False) 184 | 185 | model = distributed_backend.transform_model(model) 186 | 187 | print(f"\nEvaluating model={args.model} \n{vars(args)}\n") 188 | 189 | stats = evaluate(model, data, args.iterations, args.acc_steps, args.batch_size, args.sequence_length, 190 | distributed_backend=distributed_backend, 191 | extra_args=args) 192 | 193 | print(stats) 194 | 195 | distributed_backend.finalize() 196 | 197 | 198 | if __name__ == "__main__": 199 | args = get_args() 200 | main(args) 201 | -------------------------------------------------------------------------------- /lm_benchmark/eval_cmd_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import dataclasses 16 | 17 | @dataclasses.dataclass 18 | class Setting(object): 19 | exp_dir: str 20 | eval_len: int 21 | topk: int 22 | mem_size: int 23 | mid_length: int 24 | use_cache: bool = True 25 | selection_method: str = "per_token_and_head" 26 | mem_cache_freq: int = 50 27 | eval_sample_size: int = 4000000 28 | lm_cache: str = "mem" 29 | 30 | 31 | exp_dirs = { 32 | "arxiv_landmark": "./exps/arxiv_landmark", 33 | "arxiv_baseline": "./exps/arxiv_baseline", 34 | 35 | "pg19_landmark": "./exps/pg19_landmark", 36 | "pg19_baseline": "./exps/pg19_baseline", 37 | "pg19_xl": "./exps/pg19_xl", 38 | } 39 | settings = [ 40 | dict(exp_dir=exp_dirs["pg19_baseline"], eval_len=360, mid_length=360, 41 | lm_cache="none", mem_cache_freq=None, mem_size=None, topk=None, use_cache=False,eval_sample_size=None), 42 | dict(exp_dir=exp_dirs["pg19_baseline"], eval_len=512, mid_length=512, 43 | lm_cache="none", mem_cache_freq=None, mem_size=None, topk=None, use_cache=False,eval_sample_size=None), 44 | 45 | dict(exp_dir=exp_dirs["pg19_xl"], eval_len=2048, mid_length=256, 46 | lm_cache="kv", mem_cache_freq=None, mem_size=256, topk=None,eval_sample_size=None), 47 | dict(exp_dir=exp_dirs["pg19_xl"], eval_len=4096, mid_length=256, 48 | lm_cache="kv", mem_cache_freq=None, mem_size=256, topk=None,eval_sample_size=None), 49 | 50 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=512, mid_length=250, mem_size=10, topk=2,eval_sample_size=None), 51 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2,eval_sample_size=None), 52 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=350, mem_size=40, topk=2,eval_sample_size=None), 53 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=300, mem_size=40, topk=3,eval_sample_size=None), 54 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=20, topk=4,eval_sample_size=None), 55 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4,eval_sample_size=None), 56 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=40, topk=4,eval_sample_size=None), 57 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=2,eval_sample_size=None), 58 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4,eval_sample_size=None), 59 | 60 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2,eval_sample_size=None), 61 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4,eval_sample_size=None), 62 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=40, topk=4,eval_sample_size=None), 63 | 64 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2, 65 | selection_method="max_over_heads",eval_sample_size=None), 66 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4, 67 | selection_method="max_over_heads",eval_sample_size=None), 68 | 
dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4, 69 | selection_method="max_over_heads",eval_sample_size=None), 70 | 71 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2, 72 | selection_method="max_over_tokens",eval_sample_size=None), 73 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4, 74 | selection_method="max_over_tokens",eval_sample_size=None), 75 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4, 76 | selection_method="max_over_tokens",eval_sample_size=None), 77 | 78 | dict(exp_dir=exp_dirs["arxiv_baseline"], eval_len=360, mid_length=360, 79 | lm_cache=None, mem_cache_freq=None, mem_size=None, topk=None, use_cache=False), 80 | dict(exp_dir=exp_dirs["arxiv_baseline"], eval_len=512, mid_length=512, 81 | lm_cache=None, mem_cache_freq=None, mem_size=None, topk=None, use_cache=False), 82 | 83 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=512, mid_length=250, mem_size=10, topk=2), 84 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2), 85 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=350, mem_size=40, topk=2), 86 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=300, mem_size=40, topk=3), 87 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=250, mem_size=20, topk=4), 88 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4), 89 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=4096, mid_length=250, mem_size=40, topk=4), 90 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=2), 91 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4), 92 | ] 93 | 94 | import itertools 95 | def product_dict(**kwargs): 96 | keys = kwargs.keys() 97 | for instance in itertools.product(*kwargs.values()): 98 | yield dict(zip(keys, instance)) 99 | 100 | flat_settings = [] 101 | for setting in settings: 102 | flat_settings.extend(product_dict(**{x: y if isinstance(y, list) else [y] for x, y in setting.items()})) 103 | 104 | settings = [Setting(**d) for d in flat_settings] 105 | last_exp_dir = None 106 | 107 | print ("#!/bin/bash") 108 | for setting in settings: 109 | s_lines = [] 110 | if last_exp_dir != setting.exp_dir: 111 | s_lines.append("""EXP_DIR="{exp_dir}";""".format(**dataclasses.asdict(setting))) 112 | last_exp_dir = setting.exp_dir 113 | use_cache_str = "--use_cache" if setting.use_cache else "" 114 | mem_size_flag = "" 115 | s_lines += [""" 116 | filename="$EXP_DIR/eval-{eval_len}-{selection_method}-{topk}-memsize{mem_size}-midlength{mid_length}-memcachefreq{mem_cache_freq}"; 117 | grep val_acc $filename /dev/null; 118 | if [[ $? 
-ne 0 ]]; then 119 | script -c \\ 120 | "python eval.py \\ 121 | --checkpoint $EXP_DIR \\ 122 | --distributed_backend None \\ 123 | --lm_cache {lm_cache} \\""",""" 124 | --mem_cache_size {mem_size} \\""" if setting.mem_size is not None else "",""" 125 | --mem_cache_freq {mem_cache_freq} \\""" if setting.mem_cache_freq is not None else "", """ 126 | --mem_freq None \\ 127 | --eval_seq_length {eval_len} \\ 128 | --cache_selection_method {selection_method} \\""",""" 129 | --cache_topk {topk} \\""" if setting.topk is not None else "", """ 130 | --no_compile \\ 131 | --batch_size 16 \\ 132 | --mid_length {mid_length} \\ 133 | --positional_encoder rotary \\ 134 | --pos_jump_on_mem 0 \\ 135 | {use_cache_str} \\""", """ 136 | --eval_sample_size {eval_sample_size}""" if setting.eval_sample_size is not None else "", """ 137 | " $filename; 138 | fi;"""] 139 | print ("".join(s_lines).format(**dataclasses.asdict(setting), use_cache_str=use_cache_str)) 140 | -------------------------------------------------------------------------------- /lm_benchmark/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
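# High-level flow of this training entrypoint, as implemented below:
# 1. parse the config (--config_format selects the argument schema),
# 2. build the distributed backend and let it adjust the per-process args,
# 3. prepare the tokenized dataset on the master process, then sync and load it,
# 4. construct the model and its parameter groups, then the optimizer (AdamW
#    with the fused kernel when available, else SGD or Adafactor),
# 5. optionally attach a OneCycleLR schedule and dispatch to the landmark or
#    transformer-XL training loop, writing summary.json at the end.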
14 | 15 | import os 16 | import sys 17 | import numpy as np 18 | import torch 19 | import inspect 20 | import json 21 | import copy 22 | import argparse 23 | import random 24 | import wandb 25 | 26 | import config 27 | import models 28 | from data import get_dataset, prepare_dataset 29 | from optim.base import train_base 30 | from optim.transformer_xl import train_xl 31 | import distributed 32 | 33 | 34 | def get_args(): 35 | parser = argparse.ArgumentParser(allow_abbrev=False) 36 | parser.add_argument('--config_format', default='base', choices=config.registered_formats()) 37 | 38 | args, rem_args = parser.parse_known_args() 39 | 40 | return config.parse_args_with_format(format=args.config_format, base_parser=parser, args=rem_args, namespace=args) 41 | 42 | 43 | def main(args): 44 | 45 | 46 | torch.backends.cuda.matmul.allow_tf32 = True # allows us to make sure we're able to use tensorfloat32 during training 47 | torch.backends.cudnn.allow_tf32 = True 48 | 49 | distributed_backend = distributed.make_backend_from_args(args) 50 | args = distributed_backend.get_adjusted_args_for_process(args) 51 | 52 | args.device = torch.device(args.device) 53 | torch.cuda.set_device(args.device) 54 | device_type = 'cuda' if 'cuda' in str(args.device) else 'cpu' 55 | 56 | torch.manual_seed(args.seed) 57 | random.seed(args.seed) 58 | np.random.seed(args.seed) 59 | 60 | print(f"Loading dataset '{args.dataset}'") 61 | 62 | if distributed_backend.is_master_process(): 63 | prepare_dataset(args) 64 | distributed_backend.sync() 65 | 66 | data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized} 67 | 68 | print(f"Num training tokens: {len(data['train'])}") 69 | print(f"Num validation tokens: {len(data['val'])}") 70 | 71 | model = models.make_model_from_args(args).to(args.device) 72 | 73 | model = distributed_backend.transform_model(model) 74 | 75 | group_specs = distributed_backend.get_raw_model(model).get_parameter_group_specs() 76 | param_name_mapping = {p_name: p for p_name, p in model.named_parameters()} 77 | optimized_params_cnt = 0 78 | for g in group_specs: 79 | params = [] 80 | for p_name in g["params"]: 81 | translated_p_names = distributed_backend.translate_model_parameter_name_for_node(p_name) 82 | params += [param_name_mapping[p_name] for p_name in translated_p_names] 83 | g["params"] = params 84 | optimized_params_cnt += sum([p.numel() for p in g["params"]]) 85 | print("number of optimized parameters: %.2fM" % (optimized_params_cnt/1e6,)) 86 | if args.opt == 'adamw': 87 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) 88 | print(f"using fused AdamW: {use_fused}") 89 | extra_args = dict(fused=True) if use_fused else dict() 90 | opt = torch.optim.AdamW(group_specs, lr=args.lr, betas=(args.beta1, args.beta2), 91 | weight_decay=args.weight_decay, **extra_args) 92 | elif args.opt == 'adafactor': 93 | from optim.adafactor import Adafactor 94 | opt = Adafactor(group_specs, lr=args.lr) 95 | else: 96 | opt = torch.optim.SGD(group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 97 | 98 | if args.scheduler != 'none': 99 | if args.scheduler in ['cos', 'linear']: 100 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=args.lr, total_steps=args.iterations, 101 | pct_start=args.warmup_percent, anneal_strategy=args.scheduler, 102 | cycle_momentum=False, div_factor=1e2, final_div_factor=.05) 103 | else: 104 | raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") 105 | else: 106 
| scheduler = None 107 | 108 | args.world_size = distributed_backend.get_world_size() 109 | exp_name = args.exp_name 110 | if distributed_backend.is_master_process() and args.wandb: 111 | params_copy = copy.deepcopy(vars(args)) 112 | del params_copy['device'] 113 | wandb.init(project=args.wandb_project, name=exp_name, config=params_copy) 114 | 115 | ckpt_path = f"{args.results_base_folder}/{args.dataset}/{args.model}/{exp_name}" 116 | if not os.path.exists(ckpt_path): 117 | if distributed_backend.is_master_process(): 118 | os.makedirs(ckpt_path) 119 | else: 120 | if os.path.isfile(f"{ckpt_path}/summary.json"): # the experiment was already completed 121 | print(f"Already found experiment '{ckpt_path}'.\nSkipping.") 122 | sys.exit(0) 123 | 124 | if args.optimization_process == 'transformer_xl': 125 | train = train_xl 126 | else: 127 | train = train_base 128 | 129 | print(f"\nTraining model={args.model} \n{vars(args)}\n") 130 | 131 | stats = train(model, opt, data, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length, 132 | eval_freq=args.eval_freq, 133 | distributed_backend=distributed_backend, 134 | ckpt_path=ckpt_path, extra_args=args) 135 | 136 | args.device = None 137 | args.dtype = None 138 | stats['args'] = vars(args) 139 | if distributed_backend.is_master_process(): 140 | with open(f"{ckpt_path}/summary.json", "w") as fs: 141 | json.dump(stats, fs) 142 | distributed_backend.finalize() 143 | 144 | 145 | if __name__ == "__main__": 146 | args = get_args() 147 | main(args) 148 | -------------------------------------------------------------------------------- /lm_benchmark/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import base_new, landmark, landmark_with_cmt 16 | 17 | MODELS = { 18 | "base": base_new.GPTBase, 19 | "landmark": landmark.GPTBase, 20 | "landmark_with_cmt": landmark_with_cmt.GPTBase, 21 | } 22 | 23 | 24 | def make_model_from_args(args): 25 | return MODELS[args.model](args) 26 | 27 | 28 | def registered_models(): 29 | return MODELS.keys() 30 | -------------------------------------------------------------------------------- /lm_benchmark/models/caches/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import cache, mem_cache, kv_cache, kv_cache_train 16 | 17 | CACHES = { 18 | "none": cache.LMCache, 19 | "mem": mem_cache.MemLMCache, 20 | "kv": kv_cache.KVLMCache, 21 | "kv_train": kv_cache_train.KVLMCache 22 | } 23 | 24 | 25 | def get_cache(cache_name): 26 | return CACHES[cache_name] 27 | 28 | 29 | def registered_caches(): 30 | return CACHES.keys() 31 | -------------------------------------------------------------------------------- /lm_benchmark/models/caches/cache.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from torch import nn 16 | 17 | class LMCacheStorage(nn.Module): 18 | 19 | def __init__(self, config, layer): 20 | super().__init__() 21 | self.config = config 22 | #self._layer = [layer] 23 | 24 | #@property 25 | #def layer(self): 26 | # return self._layer[0] 27 | 28 | def store_in_cache(self, keys, values_dict): 29 | pass 30 | 31 | def retrieve_for_query(self, q, cache_context, pos_emb_closure, start_index): 32 | return None, {} 33 | 34 | def clear_state(self): 35 | pass 36 | 37 | 38 | class LMCacheContext(object): 39 | pass 40 | 41 | 42 | class LMCache(nn.Module): 43 | 44 | def __init__(self, config): 45 | super().__init__() 46 | self.config = config 47 | self.layer_storages_map = dict() 48 | self.layer_storages = nn.ModuleList() 49 | self.cache_storage = self.get_cache_storage() 50 | self.context_class = self.get_context_class() 51 | 52 | def get_cache_storage(self): 53 | return LMCacheStorage 54 | 55 | def get_context_class(self): 56 | return LMCacheContext 57 | 58 | def forward(self, x): 59 | return x, 0, self.get_context_class() 60 | 61 | def get_final_logits(self, logits): 62 | return logits 63 | 64 | def get_storage_for_layer(self, l): 65 | if l not in self.layer_storages_map: 66 | self.layer_storages_map[l] = len(self.layer_storages) 67 | self.layer_storages.append(self.cache_storage(self.config, l)) 68 | return self.layer_storages[self.layer_storages_map[l]] 69 | 70 | def clear_state(self): 71 | for storage in self.layer_storages: 72 | storage.clear_state() 73 | -------------------------------------------------------------------------------- /lm_benchmark/models/caches/kv_cache.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import math
16 | 
17 | import torch
18 | 
19 | from .cache import LMCache, LMCacheStorage
20 | 
21 | class KVLMCacheStorage(LMCacheStorage):
22 | 
23 |     def __init__(self, config, layer):
24 |         super().__init__(config, layer)
25 |         n_embd_per_head = config.n_embd // config.n_head
26 |         self.max_cache_size = config.mem_cache_size
27 |         self.register_buffer("cache_k", torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head)), persistent=False)
28 |         self.register_buffer("cache_v", torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head)), persistent=False)
29 |         self.cache_iter = 0
30 |         self.cache_size = 0
31 |         self.clear_state()
32 | 
33 |     def clear_state(self):
34 |         self.cache_iter = 0
35 |         self.cache_size = 0
36 | 
37 |     def retrieve_for_query(self, q, cache_context, pos_emb_closure, start_index):
38 |         if self.cache_size == 0:
39 |             return None, {}
40 |         B, nh, T, hs = q.size()  # batch size, num_heads, sequence length, per-head embedding dimensionality (n_embd)
41 |         cached_keys = self.cache_k[:B, :, :self.cache_size]
42 |         k_indices = torch.cat((
43 |             torch.arange(self.cache_size - self.cache_iter, self.cache_size, device=q.device),
44 |             torch.arange(self.cache_size - cached_keys.shape[2], self.cache_size - self.cache_iter, device=q.device),
45 |         ))
46 |         assert self.cache_size == start_index
47 |         last_incomplete_k = pos_emb_closure.adapt_keys(cached_keys, indices=k_indices)
48 |         att_incomplete = (q @ last_incomplete_k.transpose(-2, -1)) * (1.0 / math.sqrt(last_incomplete_k.size(-1)))
49 |         last_incomplete_v = self.cache_v[:B, :, :self.cache_size].unsqueeze(2).expand(B, nh, T, -1, hs)
50 |         return att_incomplete, {'v': last_incomplete_v.clone()}
51 | 
52 | 
53 |     def store_in_cache(self, keys, values_dict):
54 |         if self.max_cache_size == 0:
55 |             return
56 |         B, nh, T, hs = keys.size()
57 |         k_for_cache = keys[:, :, -self.max_cache_size:]
58 |         v_for_cache = values_dict['v'][:, :, -self.max_cache_size:]
59 |         self.cache_iter = (self.cache_iter + T - k_for_cache.shape[2]) % self.cache_k.shape[2]
60 |         self.cache_size += T - k_for_cache.shape[2]
61 |         T = k_for_cache.shape[2]
62 | 
63 |         if self.cache_iter + T >= self.max_cache_size:
64 |             next_iter = (self.cache_iter + T) - self.max_cache_size
65 |             rem = (self.max_cache_size - self.cache_iter)
66 |             self.cache_k[:B, :, :next_iter].copy_(k_for_cache[:, :, rem:])
67 |             self.cache_k[:B, :, self.cache_iter:].copy_(k_for_cache[:, :, :rem])
68 |             self.cache_v[:B, :, :next_iter].copy_(v_for_cache[:, :, rem:])
69 |             self.cache_v[:B, :, self.cache_iter:].copy_(v_for_cache[:, :, :rem])
70 |         else:
71 |             next_iter = self.cache_iter + T
72 |             self.cache_k[:B, :, self.cache_iter:next_iter].copy_(k_for_cache)
73 |             self.cache_v[:B, :, self.cache_iter:next_iter].copy_(v_for_cache)
74 |         self.cache_iter = next_iter  # advance the ring-buffer pointer after both the wrap-around and the in-place branch
75 |         self.cache_size += T
76 | 
77 | 
78 | class KVLMCache(LMCache):
79 | 
80 |     def __init__(self, config):
81 |         super().__init__(config)
82 |         self.total_len = 0
83 | 
84 |     def get_cache_storage(self):
85 |         return KVLMCacheStorage
86 | 
87 |     def forward(self, x):
88 |         B, T = x.size()
89
| prev_total_len = self.total_len
90 |         self.total_len = self.total_len + x.shape[1]
91 |         return x, prev_total_len, self.context_class()
92 | 
93 |     def clear_state(self):
94 |         super().clear_state()
95 |         self.total_len = 0
96 | 
97 | 
-------------------------------------------------------------------------------- /lm_benchmark/models/caches/kv_cache_train.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import math
16 | 
17 | import torch
18 | 
19 | from .cache import LMCache, LMCacheStorage
20 | 
21 | class KVLMCacheStorage(LMCacheStorage):
22 | 
23 |     def __init__(self, config, layer):
24 |         super().__init__(config, layer)
25 |         n_embd_per_head = config.n_embd // config.n_head
26 |         self.max_cache_size = config.mem_cache_size
27 |         self._cache_k = [torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head), device=torch.device('cuda'))]  # wrapped in a plain list (not register_buffer), presumably to keep this large CUDA buffer out of the module's state_dict
28 |         self._cache_v = [torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head), device=torch.device('cuda'))]
29 |         self.cache_iter = 0
30 |         self.cache_size = 0
31 |         self.clear_state()
32 | 
33 |     @property
34 |     def cache_k(self):
35 |         return self._cache_k[0]
36 | 
37 |     @property
38 |     def cache_v(self):
39 |         return self._cache_v[0]
40 | 
41 |     def clear_state(self):
42 |         self.cache_iter = 0
43 |         self.cache_size = 0
44 | 
45 |     def retrieve_for_query(self, q, cache_context, pos_emb_closure, start_index):
46 |         if self.cache_size == 0:
47 |             return None, {}
48 |         B, nh, T, hs = q.size()  # batch size, num_heads, sequence length, per-head embedding dimensionality (n_embd)
49 |         cached_keys = self.cache_k[:B, :, :self.cache_size]
50 |         k_indices = torch.cat((
51 |             torch.arange(self.cache_size - self.cache_iter, self.cache_size, device=q.device),
52 |             torch.arange(self.cache_size - cached_keys.shape[2], self.cache_size - self.cache_iter, device=q.device),
53 |         ))
54 |         assert self.cache_size == start_index
55 |         last_incomplete_k = pos_emb_closure.adapt_keys(cached_keys, indices=k_indices)
56 |         att_incomplete = (q @ last_incomplete_k.transpose(-2, -1)) * (1.0 / math.sqrt(last_incomplete_k.size(-1)))
57 |         last_incomplete_v = self.cache_v[:B, :, :self.cache_size].unsqueeze(2).expand(B, nh, T, -1, hs)
58 |         return att_incomplete, {'v': last_incomplete_v.clone()}
59 | 
60 | 
61 |     def store_in_cache(self, keys, values_dict):
62 |         if self.max_cache_size == 0:
63 |             return
64 |         B, nh, T, hs = keys.size()
65 |         k_for_cache = keys[:, :, -self.max_cache_size:]
66 |         v_for_cache = values_dict['v'][:, :, -self.max_cache_size:]
67 |         self.cache_iter = (self.cache_iter + T - k_for_cache.shape[2]) % self.cache_k.shape[2]
68 |         self.cache_size += T - k_for_cache.shape[2]
69 |         T = k_for_cache.shape[2]
70 | 
71 |         if self.cache_iter + T >= self.max_cache_size:
72 |             next_iter = (self.cache_iter + T) - self.max_cache_size
73 |             rem = (self.max_cache_size - self.cache_iter)
74 |
self.cache_k[:B, :, :next_iter].copy_(k_for_cache[:, :, rem:])
75 |             self.cache_k[:B, :, self.cache_iter:].copy_(k_for_cache[:, :, :rem])
76 |             self.cache_v[:B, :, :next_iter].copy_(v_for_cache[:, :, rem:])
77 |             self.cache_v[:B, :, self.cache_iter:].copy_(v_for_cache[:, :, :rem])
78 |         else:
79 |             next_iter = self.cache_iter + T
80 |             self.cache_k[:B, :, self.cache_iter:next_iter].copy_(k_for_cache)
81 |             self.cache_v[:B, :, self.cache_iter:next_iter].copy_(v_for_cache)
82 |         self.cache_iter = next_iter  # advance the ring-buffer pointer after both branches, mirroring kv_cache.py
83 |         self.cache_size += T
84 | 
85 | 
86 | class KVLMCache(LMCache):
87 | 
88 |     def __init__(self, config):
89 |         super().__init__(config)
90 |         self.total_len = 0
91 | 
92 |     def get_cache_storage(self):
93 |         return KVLMCacheStorage
94 | 
95 |     def forward(self, x):
96 |         B, T = x.size()
97 |         prev_total_len = self.total_len
98 |         self.total_len = self.total_len + x.shape[1]
99 |         return x, prev_total_len, self.context_class()
100 | 
101 |     def clear_state(self):
102 |         super().clear_state()
103 |         self.total_len = 0
104 | 
105 | 
-------------------------------------------------------------------------------- /lm_benchmark/models/positional_encoders/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from . import encoder, rotary, rotary_mem_jump
16 | 
17 | POS_ENCS = {
18 |     "rotary": rotary.RotaryPositionalEncoder,
19 |     "rotary_mem_jump": rotary_mem_jump.RotaryJumpMemPositionalEncoder
20 | }
21 | 
22 | 
23 | def get_encoder(encoder_name):
24 |     return POS_ENCS[encoder_name]
25 | 
26 | 
27 | def registered_encoders():
28 |     return POS_ENCS.keys()
29 | 
-------------------------------------------------------------------------------- /lm_benchmark/models/positional_encoders/encoder.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
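
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). The closure
# pattern defined below lets the model run the positional encoder once per
# forward pass and hand every attention layer a `closure` object that applies
# positions to queries/keys, instead of adding absolute position embeddings
# to the input. Assuming `enc` is a PositionalEncoder subclass and q, k are
# (B, n_head, T, head_dim) tensors, the intended call order is roughly:
#
#     x, closure = enc(x)
#     q = closure.adapt_queries(q, start_index=0)
#     k = closure.adapt_keys(k, start_index=0)
#     att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
#     att = closure.adapt_attention_before_softmax(att, start_query_index=0,
#                                                  start_key_index=0)
#
# `start_index` is the absolute position of the first token in the current
# window; the KV caches above rely on this (they call adapt_keys with explicit
# `indices` to replay the positions of tokens cached from earlier windows).
# ---------------------------------------------------------------------------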
14 | 15 | import torch 16 | from torch import nn 17 | 18 | 19 | class PositionalEncoderClosure(object): 20 | 21 | def __init__(self, encoder): 22 | self.encoder = encoder 23 | 24 | def adapt_model_input(self, x, start_index): 25 | return x 26 | 27 | def adapt_keys(self, k, start_index=None, indices=None): 28 | if indices is None: 29 | T = k.shape[-2] 30 | indices = torch.arange(start_index, T + start_index, device=k.device) 31 | return self._adapt_keys_for_indices(k, indices) 32 | 33 | def _adapt_keys_for_indices(self, k, indices): 34 | return k 35 | 36 | def adapt_queries(self, q, start_index): 37 | return q 38 | 39 | def adapt_attention_before_softmax(self, att, start_query_index=None, start_key_index=None, q_indices=None, k_indices=None): 40 | if q_indices is None: 41 | qT = att.shape[-2] 42 | q_indices = torch.arange(start_query_index, qT + start_query_index, device=att.device) 43 | if k_indices is None: 44 | kT = att.shape[-1] 45 | k_indices = torch.arange(start_key_index, kT + start_key_index, device=att.device) 46 | return self._adapt_attention_before_softmax_for_indices(att, q_indices, k_indices) 47 | 48 | def _adapt_attention_before_softmax_for_indices(self, att, query_indices, key_indices): 49 | return att 50 | 51 | 52 | class PositionalEncoder(nn.Module): 53 | 54 | closure_model = PositionalEncoderClosure 55 | 56 | def __init__(self, config): 57 | super().__init__() 58 | self.config = config 59 | 60 | def forward(self, x): 61 | return x, self.closure_model(self) 62 | -------------------------------------------------------------------------------- /lm_benchmark/models/positional_encoders/rotary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
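
# Note on the frequency table built in RotaryPositionalEncoder.__init__ below:
# with max_pos_base = 10 and max_pos_log = 4, the i-th channel pair gets
#     freq_i = 10 ** (-4 * (2i / d)) == 10000 ** (-2i / d),
# i.e. the standard RoPE inverse frequencies with base theta = 10000, written
# with the exponent split into a base and a log-scale. A quick self-contained
# check (illustrative):
#
#     import torch
#     d = 8  # head dimension
#     a = 10 ** (-4 * torch.arange(0, d, 2).float() / d)
#     b = 10000 ** (-torch.arange(0, d, 2).float() / d)
#     assert torch.allclose(a, b)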
14 | 15 | import torch 16 | from torch import nn 17 | 18 | from .encoder import PositionalEncoder, PositionalEncoderClosure 19 | from .rotary_utils import apply_rotary_emb 20 | 21 | 22 | class RotaryPositionalEncoderClosure(PositionalEncoderClosure): 23 | 24 | def adapt_vector_for_indices(self, v, indices): 25 | #changer = torch.zeros_like(indices) 26 | #changer[50::51] = 1 27 | #indices -= torch.cumsum(changer, dim=-1) 28 | 29 | *other_dims, T, hs = v.shape 30 | if T == 0: 31 | return v 32 | other_dims_prefix = other_dims[:len(other_dims) - len(indices.shape) + 1] 33 | freqs = (indices.unsqueeze(-1) * self.encoder.freqs.view(1, -1)).unsqueeze(-1).expand(*indices.shape, -1, 2).reshape(*indices.shape, hs) 34 | freqs = freqs.view([1] * len(other_dims_prefix) + list(indices.shape) + [hs]).expand(*v.shape) 35 | v = apply_rotary_emb(freqs, v) 36 | return v 37 | 38 | def _adapt_keys_for_indices(self, k, indices): 39 | return self.adapt_vector_for_indices(k, indices) 40 | 41 | def adapt_queries(self, q, start_index): 42 | T = q.shape[-2] 43 | indices = torch.arange(start_index, T + start_index, device=q.device) 44 | return self.adapt_vector_for_indices(q, indices) 45 | 46 | 47 | class RotaryPositionalEncoder(PositionalEncoder): 48 | 49 | def __init__(self, config): 50 | super().__init__(config) 51 | self.max_pos_log = 4 52 | self.max_pos_base = 10 53 | n_embd_per_head = config.n_embd // config.n_head 54 | freqs = (self.max_pos_base ** (-self.max_pos_log * torch.arange(0, n_embd_per_head, 2)[:(n_embd_per_head // 2)].float() / n_embd_per_head)) 55 | self.register_buffer("freqs", freqs) 56 | 57 | closure_model = RotaryPositionalEncoderClosure 58 | -------------------------------------------------------------------------------- /lm_benchmark/models/positional_encoders/rotary_mem_jump.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
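
# How the "jump" variant below differs from plain rotary: in forward(), every
# landmark token (x == config.landmark_id) draws a random offset in
# [0, pos_jump_on_mem), and the rotary position of every *later* token is
# shifted by the running sum of those offsets. Worked example with assumed
# draws (LM = landmark token):
#
#     x     = [ t,  t, LM,  t,  t, LM,  t ]
#     draws = [ 0,  0,  3,  0,  0,  2,  0 ]   # nonzero only at landmarks
#     jumps = cumsum(draws[:-1]) = [0, 0, 3, 3, 3, 5], added to positions 1..6
#     pos   = [ 0,  1,  2,  6,  7,  8, 11 ]
#
# which presumably exposes the model during training to the larger position
# gaps that occur at inference when distant memory blocks are retrieved.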
14 | 15 | import torch 16 | from torch import nn 17 | 18 | from .encoder import PositionalEncoder, PositionalEncoderClosure 19 | from .rotary_utils import apply_rotary_emb 20 | 21 | 22 | class JumpingRotaryPositionalEncoderClosure(PositionalEncoderClosure): 23 | 24 | def __init__(self, encoder, jumps): 25 | super().__init__(encoder) 26 | self.jumps = jumps 27 | 28 | def adapt_vector_for_indices(self, v, indices): 29 | #changer = torch.zeros_like(indices) 30 | #changer[50::51] = 1 31 | #indices -= torch.cumsum(changer, dim=-1) 32 | 33 | 34 | *other_dims, T, hs = v.shape 35 | if T == 0: 36 | return v 37 | other_dims_prefix = other_dims[:len(other_dims) - len(indices.shape) + 1] 38 | if self.jumps is not None: 39 | indices = indices.view([1] * len(other_dims_prefix) + list(indices.shape)).repeat(other_dims_prefix + [1] * len(indices.shape)) 40 | indices[..., 1:] = indices[..., 1:] + self.jumps 41 | other_dims_prefix = [] 42 | # print(indices) 43 | freqs = (indices.unsqueeze(-1) * self.encoder.freqs.view(1, -1)).unsqueeze(-1).expand(*indices.shape, -1, 2).reshape(*indices.shape, hs) 44 | freqs = freqs.view([1] * len(other_dims_prefix) + list(indices.shape) + [hs]).expand(*v.shape) 45 | v = apply_rotary_emb(freqs, v) 46 | return v 47 | 48 | def _adapt_keys_for_indices(self, k, indices): 49 | return self.adapt_vector_for_indices(k, indices) 50 | 51 | def adapt_queries(self, q, start_index): 52 | T = q.shape[-2] 53 | indices = torch.arange(start_index, T + start_index, device=q.device) 54 | return self.adapt_vector_for_indices(q, indices) 55 | 56 | 57 | class RotaryJumpMemPositionalEncoder(PositionalEncoder): 58 | 59 | def __init__(self, config): 60 | super().__init__(config) 61 | self.max_pos_log = 4 62 | self.max_pos_base = 10 63 | n_embd_per_head = config.n_embd // config.n_head 64 | freqs = (self.max_pos_base ** (-self.max_pos_log * torch.arange(0, n_embd_per_head, 2)[:(n_embd_per_head // 2)].float() / n_embd_per_head)) 65 | self.register_buffer("freqs", freqs) 66 | 67 | def forward(self, x): 68 | if self.config.pos_jump_on_mem is not None and self.config.pos_jump_on_mem > 0: 69 | #assert self.config.mem_freq is not None 70 | is_mem = (x == self.config.landmark_id) 71 | jumps = torch.cumsum((is_mem * torch.randint_like(x, self.config.pos_jump_on_mem))[:, :-1], dim=-1) 72 | return x, self.closure_model(self, jumps.unsqueeze(1)) # (B, 1, T) 73 | else: 74 | return x, self.closure_model(self, None) 75 | 76 | 77 | closure_model = JumpingRotaryPositionalEncoderClosure 78 | -------------------------------------------------------------------------------- /lm_benchmark/models/positional_encoders/rotary_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
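
# rotate_half/apply_rotary_emb below implement the interleaved-pair RoPE
# convention: the feature vector is grouped into pairs (x0, x1), (x2, x3), ...
# and each pair is rotated by its per-position angle. For a single pair and
# angle theta,
#
#     out = (x0 * cos(theta) - x1 * sin(theta),
#            x1 * cos(theta) + x0 * sin(theta))
#
# which is apply_rotary_emb with freqs = [theta, theta] on t = [x0, x1]. The
# `freqs` tensor is therefore expected to repeat each frequency twice along
# the last dimension, exactly what the encoders above build with their
# expand(..., 2).reshape(...) step.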
14 | 
15 | import torch
16 | 
17 | def rotate_half(x):
18 |     x = x.view(*x.shape[:-1], -1, 2)
19 |     x1, x2 = x.unbind(dim = -1)
20 |     x = torch.stack((-x2, x1), dim = -1)
21 |     return x.view(*x.shape[:-2], -1)
22 | 
23 | def apply_rotary_emb(freqs, t, start_index = 0, scale = 1.):
24 |     #freqs = freqs.to(t)
25 |     rot_dim = freqs.shape[-1]
26 |     end_index = start_index + rot_dim
27 |     assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is too small to rotate all {rot_dim} rotary positions'
28 |     #t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
29 |     t = (t * freqs.cos().to(t) * scale) + (rotate_half(t) * freqs.sin().to(t) * scale)
30 |     #return torch.cat((t_left, t, t_right), dim = -1)
31 |     return t
32 | 
-------------------------------------------------------------------------------- /lm_benchmark/optim/base.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 | from contextlib import nullcontext
16 | 
17 | import torch
18 | import torch.nn.functional as F
19 | import wandb
20 | import time
21 | import copy
22 | import traceback
23 | 
24 | from .utils import get_batch, save_checkpoint
25 | 
26 | @torch.no_grad()
27 | def eval(model, data_tensor, sequence_length, batch_size, device='cpu', max_num_batches=24, ctx=nullcontext()):
28 |     assert not model.training
29 | 
30 |     loss_list_val, acc_list = [], []
31 | 
32 |     for _ in range(max_num_batches):
33 |         x, y = get_batch(data_tensor, sequence_length, batch_size, device=device)
34 |         with ctx:
35 |             outputs = model(x, targets=y, get_logits=True)
36 |         val_loss = outputs['loss']
37 |         loss_list_val.append(val_loss)
38 |         acc_list.append((outputs['logits'].argmax(-1) == y).float().mean())
39 | 
40 |     val_acc = torch.stack(acc_list).mean().item()
41 |     val_loss = torch.stack(loss_list_val).mean().item()
42 |     val_perplexity = math.exp(val_loss)
43 | 
44 |     return val_acc, val_loss, val_perplexity
45 | 
46 | 
47 | def train_base(model, opt, data, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
48 |     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
49 |     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
50 |         device_type=device_type, dtype=extra_args.dtype)
51 |     itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None  # best_val_loss not used atm, early stopping not recommended but possible
52 | 
53 |     stats = {'train_loss': [], 'val_loss': [], 'val_pp': [], 'val_acc': []}
54 | 
55 |     num_substeps_per_epoch = len(data['train']) // (batch_size * sequence_length)
56 | 
57 |     if not extra_args.no_compile:
58 |         print("Compiling model ...")
59 |         import torch._dynamo as torchdynamo
60 |         torchdynamo.config.guard_nn_modules = True
61 |         model = torch.compile(model)  # requires pytorch 2.0+
62 | 
63 | 
model.train()
64 | 
65 |     t0 = time.time()
66 | 
67 |     while itr < iterations:
68 | 
69 |         for microstep_idx in range(acc_steps):  # gradient accumulation
70 |             x, y = get_batch(data['train'], sequence_length, batch_size, device=extra_args.device)
71 |             with type_ctx:
72 |                 with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps):
73 |                     if getattr(distributed_backend.get_raw_model(model), "needs_iter", False):
74 |                         outputs = model(x, targets=y, iter=itr)
75 |                     else:
76 |                         outputs = model(x, targets=y)
77 | 
78 |             loss = outputs['loss']
79 |             loss.backward()
80 |             substep += 1
81 | 
82 |         opt.step()
83 |         if scheduler is not None: scheduler.step()  # guard: main.py passes scheduler=None when args.scheduler == 'none'
84 |         opt.zero_grad(set_to_none=True)
85 |         itr += 1
86 | 
87 |         if itr % eval_freq == 0 or itr == iterations:  # from here it's only evaluation code, all the training is above
88 |             if distributed_backend.is_master_process():
89 |                 t1 = time.time()
90 |                 dt = t1 - t0
91 |                 epoch = substep // num_substeps_per_epoch
92 | 
93 |                 model.eval()
94 |                 train_loss = loss.detach().cpu().item()
95 |                 current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr
96 |                 val_acc, val_loss, val_perplexity = eval(model, data['val'], sequence_length, batch_size,
97 |                                                          extra_args.device, max_num_batches=24, ctx=type_ctx)
98 | 
99 |                 print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:.3f}"
100 |                 print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms"
101 |                 if scheduler is not None:
102 |                     print_string += f" [lr] {current_lr:.5f}"
103 |                 print(print_string)
104 | 
105 |                 if extra_args.wandb:
106 |                     wandb.log({
107 |                         "iter": itr,
108 |                         "train/loss": train_loss,
109 |                         "val/loss": val_loss,
110 |                         "val/perplexity": val_perplexity,
111 |                         "val/acc": val_acc,
112 |                         "lr": current_lr,
113 |                     })
114 |                 stats['train_loss'].append(train_loss); stats['val_loss'].append(val_loss); stats['val_pp'].append(val_perplexity); stats['val_acc'].append(val_acc)  # accumulate eval stats for the returned summary
115 |                 model.train()
116 |                 t0 = time.time()
117 |             if distributed_backend.is_master_process():
118 |                 if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
119 |                     print(f"saving checkpoint to {ckpt_path}/ckpt_{itr}.pt")
120 |                     save_checkpoint(distributed_backend=distributed_backend,
121 |                                     model=model,
122 |                                     opt=opt,
123 |                                     scheduler=scheduler,
124 |                                     itr=itr,
125 |                                     ckpt_path=f"{ckpt_path}/ckpt_{itr}.pt")
126 | 
127 |     if distributed_backend.is_master_process():
128 |         print(f"saving checkpoint to {ckpt_path}")
129 |         save_checkpoint(distributed_backend=distributed_backend,
130 |                         model=model,
131 |                         opt=opt,
132 |                         scheduler=scheduler,
133 |                         itr=itr,
134 |                         ckpt_path=f"{ckpt_path}/ckpt.pt")
135 | 
136 |     return stats
137 | 
-------------------------------------------------------------------------------- /lm_benchmark/optim/transformer_xl.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
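
# Training scheme implemented in this file (sketch): each batch is one long
# sequence of extra_args.total_sequence_length tokens, processed in windows
# of extra_args.sequence_length. The LM cache is kept across windows via
# use_cache=True, but backward() runs per window, so gradients do not flow
# into the cached keys/values, i.e. the Transformer-XL style stop-gradient
# between segments. Schematically (names as used below):
#
#     model.clear_state()
#     for idx in range(0, total_sequence_length, sequence_length):
#         outputs = model(x[:, idx:idx+sequence_length],
#                         targets=y[:, idx:idx+sequence_length].contiguous(),
#                         use_cache=True)
#         outputs['loss'].backward()   # one backward per window
#     opt.step()                       # single optimizer step per iteration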
14 | import math
15 | from contextlib import nullcontext
16 | 
17 | import torch
18 | import torch.nn.functional as F
19 | import wandb
20 | import time
21 | import copy
22 | import traceback
23 | 
24 | from .utils import get_batch, save_checkpoint
25 | 
26 | 
27 | @torch.no_grad()
28 | def eval(model, data_tensor, sequence_length, total_sequence_length, batch_size, device='cpu', max_num_batches=24, ctx=nullcontext()):
29 |     assert not model.training
30 | 
31 |     loss_list_val, acc_list = [], []
32 | 
33 |     for _ in range(max_num_batches):
34 |         x, y = get_batch(data_tensor, total_sequence_length, batch_size, device=device)
35 |         model.clear_state()
36 |         total_loss = None
37 |         for idx in range(0, x.shape[1], sequence_length):
38 |             x_part = x[:, idx:idx+sequence_length]
39 |             y_part = y[:, idx:idx+sequence_length].contiguous()
40 |             with ctx:
41 |                 outputs = model(x_part, targets=y_part, get_logits=True, use_cache=True)
42 |             val_loss = outputs['loss']
43 |             if idx == 0:
44 |                 total_loss = val_loss
45 |             else:
46 |                 total_loss += val_loss
47 |         loss_list_val.append(total_loss)
48 |         acc_list.append((outputs['logits'].argmax(-1) == y_part).float().mean())  # note: accuracy is measured on the final window only
49 | 
50 |     val_acc = torch.stack(acc_list).mean().item()
51 |     val_loss = torch.stack(loss_list_val).mean().item()
52 |     val_perplexity = math.exp(val_loss)
53 | 
54 |     return val_acc, val_loss, val_perplexity
55 | 
56 | def train_xl(model, opt, data, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
57 |     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
58 |     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
59 |         device_type=device_type, dtype=extra_args.dtype)
60 |     itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None  # best_val_loss not used atm, early stopping not recommended but possible
61 | 
62 |     stats = {'train_loss': [], 'val_loss': [], 'val_pp': [], 'val_acc': []}
63 | 
64 |     num_substeps_per_epoch = len(data['train']) // (batch_size * sequence_length)
65 | 
66 |     if not extra_args.no_compile:
67 |         print("Compiling model ...")
68 |         import torch._dynamo as torchdynamo
69 |         torchdynamo.config.guard_nn_modules = True
70 |         model = torch.compile(model)  # requires pytorch 2.0+
71 | 
72 |     model.train()
73 | 
74 |     if extra_args.postpone_lm_cache:
75 |         distributed_backend.get_raw_model(model).init_cache()
76 | 
77 |     t0 = time.time()
78 |     while itr < iterations:
79 |         for microstep_idx in range(acc_steps):  # gradient accumulation
80 |             x, y = get_batch(data['train'], extra_args.total_sequence_length, batch_size, device=extra_args.device)
81 |             distributed_backend.get_raw_model(model).clear_state()
82 |             total_loss = None
83 |             for idx in range(0, x.shape[1], extra_args.sequence_length):
84 |                 with type_ctx:
85 |                     with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps):
86 |                         outputs = model(x[:, idx:idx+extra_args.sequence_length], targets=y[:, idx:idx+extra_args.sequence_length].contiguous(), use_cache=True)
87 | 
88 |                 loss = outputs['loss']
89 |                 loss.backward()
90 |                 if idx == 0:
91 |                     total_loss = loss
92 |                 else:
93 |                     total_loss += loss
94 |             substep += 1
95 | 
96 |         opt.step()
97 |         if scheduler is not None: scheduler.step()  # guard: scheduler may be None when args.scheduler == 'none'
98 |         opt.zero_grad(set_to_none=True)
99 |         itr += 1
100 | 
101 |         if itr % eval_freq == 0 or itr == iterations:  # from here it's only evaluation code, all the training is above
102 |             if distributed_backend.is_master_process():
103 |                 t1 = time.time()
104 |                 dt = t1 - t0
105 |
epoch = substep // num_substeps_per_epoch
106 | 
107 |                 model.eval()
108 |                 train_loss = loss.detach().cpu().item()
109 |                 current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr
110 |                 val_acc, val_loss, val_perplexity = eval(distributed_backend.get_raw_model(model), data['val'], sequence_length, extra_args.total_sequence_length,
111 |                                                          batch_size, extra_args.device, max_num_batches=24, ctx=type_ctx)
112 | 
113 |                 print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:.3f}"
114 |                 print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms"
115 |                 if scheduler is not None:
116 |                     print_string += f" [lr] {current_lr:.5f}"
117 |                 print(print_string)
118 | 
119 |                 if extra_args.wandb:
120 |                     wandb.log({
121 |                         "iter": itr,
122 |                         "train/loss": train_loss,
123 |                         "val/loss": val_loss,
124 |                         "val/perplexity": val_perplexity,
125 |                         "val/acc": val_acc,
126 |                         "lr": current_lr,
127 |                     })
128 |                 stats['train_loss'].append(train_loss); stats['val_loss'].append(val_loss); stats['val_pp'].append(val_perplexity); stats['val_acc'].append(val_acc)  # accumulate eval stats for the returned summary
129 |                 model.train()
130 |                 t0 = time.time()
131 | 
132 |     if distributed_backend.is_master_process():
133 |         print(f"saving checkpoint to {ckpt_path}")
134 |         save_checkpoint(distributed_backend=distributed_backend,
135 |                         model=model,
136 |                         opt=opt,
137 |                         scheduler=scheduler,
138 |                         itr=itr,
139 |                         ckpt_path=f"{ckpt_path}/ckpt.pt")  # ckpt_path is the experiment folder; write a file inside it, matching train_base and the summary path in main.py
140 | 
141 |     return stats
142 | 
-------------------------------------------------------------------------------- /lm_benchmark/optim/utils.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
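
# Worked example of the target construction in get_batch below. Token ids are
# taken from the code: 50256 is the GPT-2 <|endoftext|> token, and 50260 is
# presumably the landmark token id in this setup. With LM = 50260:
#
#     x        = [ a,   b,  LM,   c ]
#     raw y    = [ b,  LM,   c,   d,  e ]   # seq_length + 1 tokens
#     skip LM  = [ b,   c,   c,   d ]       # where the next token is LM, target the one after it
#     masked   = [ b,   c,  -1,   d ]       # positions whose input is LM/EOT get label -1
#
# so the model is never asked to predict a landmark token, and the position
# where a landmark sits in the input is excluded from the loss.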
14 | 
15 | import numpy as np
16 | import torch
17 | import torch.nn.functional as F
18 | from contextlib import nullcontext, contextmanager, ExitStack
19 | 
20 | 
21 | def get_batch(data, seq_length, batch_size, device='cpu'):
22 |     ix = torch.randint(len(data) - seq_length - 1, (batch_size,))  # leave room for the extra lookahead token read below
23 |     x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix])
24 |     y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length+1]).astype(np.int64)) for i in ix])  # seq_length + 1 tokens: one extra for the landmark skip below
25 |     y = torch.where(y[:, :-1] == 50260, y[:, 1:], y[:, :-1])  # where the next token is 50260 (presumably the landmark token), target the token after it instead
26 |     y = torch.where((x == 50260) | (x == 50256), -1, y)  # positions whose input is the landmark or <|endoftext|> (GPT-2 id 50256) are excluded from the loss
27 |     if device != 'cpu':
28 |         # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
29 |         x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
30 |         #x, y = x.to(device), y.to(device)
31 |     return x, y
32 | 
33 | 
34 | def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args):
35 | 
36 |     checkpoint = dict({
37 |         'model': distributed_backend.get_raw_model(model).state_dict(),
38 |         'optimizer': opt.state_dict(),
39 |         'scheduler': scheduler.state_dict() if scheduler is not None else None,
40 |         'itr': itr,
41 |     }, **extra_args)
42 | 
43 |     torch.save(checkpoint, ckpt_path)
44 | 
-------------------------------------------------------------------------------- /lm_benchmark/requirements.txt: --------------------------------------------------------------------------------
1 | tiktoken
2 | --find-links https://download.pytorch.org/whl/torch_stable.html
3 | torch==2.0.0+cu118
4 | torchaudio==2.0.0+cu118
5 | torchvision==0.15.0+cu118
6 | tqdm
7 | transformers
8 | wandb
--------------------------------------------------------------------------------
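
# Quick smoke test for the batching logic above (illustrative; assumes it is
# run from inside lm_benchmark/ so that `optim.utils` is importable, and that
# `data` is a flat token array as produced by the dataset prepare scripts):
#
#     import numpy as np
#     from optim.utils import get_batch
#
#     data = np.random.randint(0, 50257, size=10_000).astype(np.uint16)  # fake token stream
#     x, y = get_batch(data, seq_length=128, batch_size=4, device='cpu')
#     assert x.shape == (4, 128) and y.shape == (4, 128)
#     assert (y[x == 50256] == -1).all()  # <|endoftext|> positions are masked out of the loss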