├── .gitignore
├── LICENSE
├── README.md
├── llama
│   ├── install_deps.sh
│   ├── llama_landmark_config.py
│   ├── llama_mem.py
│   ├── ltriton
│   │   ├── flash_landmark_attention.py
│   │   └── test_flash_landmark_attention.py
│   ├── redpajama.py
│   ├── requirements.txt
│   ├── run_test.py
│   ├── train.py
│   ├── urls
│   │   ├── arxiv.txt
│   │   ├── book.txt
│   │   ├── c4.txt
│   │   ├── common_crawl.txt
│   │   ├── github.txt
│   │   ├── stackexchange.txt
│   │   └── wikipedia.txt
│   └── weight_diff.py
├── llama_legacy
│   ├── llama_mem.py
│   ├── redpajama.py
│   ├── requirements.txt
│   ├── run_test.py
│   ├── train.py
│   ├── urls
│   │   ├── arxiv.txt
│   │   ├── book.txt
│   │   ├── c4.txt
│   │   ├── common_crawl.txt
│   │   ├── github.txt
│   │   ├── stackexchange.txt
│   │   └── wikipedia.txt
│   └── weight_diff.py
└── lm_benchmark
    ├── config
    │   ├── __init__.py
    │   └── rotary.py
    ├── data
    │   ├── __init__.py
    │   ├── arxiv_math.py
    │   ├── pg19.py
    │   ├── pg19
    │   │   ├── README.md
    │   │   └── prepare.py
    │   ├── proof-pile
    │   │   └── prepare.py
    │   └── utils.py
    ├── distributed
    │   ├── __init__.py
    │   ├── backend.py
    │   ├── ddp.py
    │   └── single.py
    ├── eval.py
    ├── eval_cmd_generator.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── base_new.py
    │   ├── caches
    │   │   ├── __init__.py
    │   │   ├── cache.py
    │   │   ├── kv_cache.py
    │   │   ├── kv_cache_train.py
    │   │   └── mem_cache.py
    │   ├── landmark.py
    │   ├── landmark_with_cmt.py
    │   └── positional_encoders
    │       ├── __init__.py
    │       ├── encoder.py
    │       ├── rotary.py
    │       ├── rotary_mem_jump.py
    │       └── rotary_utils.py
    ├── optim
    │   ├── base.py
    │   ├── transformer_xl.py
    │   └── utils.py
    └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Dataset folder
2 | data/datasets/
3 | wandb/
4 | exps/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | .DS_Store
137 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Landmark Attention
2 |
3 | This repository contains the implementation of landmark attention as described in our paper:
4 |
5 | **Landmark Attention: Random-Access Infinite Context Length for Transformers**
6 | Amirkeivan Mohtashami, Martin Jaggi
7 | NeurIPS 2023: https://arxiv.org/abs/2305.16300
8 |
9 | ## Repository Structure
10 |
11 | The repository contains three code bases under the following directories:
12 |
13 | 1. `lm_benchmark`: This directory contains the code used for performing language modeling over PG19 and arXiv Math datasets.
14 | 2. `llama_legacy`: This directory contains the code used to obtain the results of fine-tuning LLaMA as reported in the paper. The code in this directory is frozen to allow reproduction of the results. Thus, except when trying to exactly replicate our results, we suggest using the code under the `llama` directory.
15 | 3. `llama`: This directory contains the current implementation of landmark attention. The directory includes both a high-level implementation and a Triton implementation of landmark attention combined with Flash Attention. As an example, the directory contains the code for applying the implementation to LLaMA models.
16 |
17 |
18 | Note: During the development of this project, we made the decision to update the names of certain components. However, as this decision was made later in the project timeline, you may encounter references to the old names within the code (e.g. `mem` instead of `landmark`). We are working to address this issue.
19 |
20 |
21 | ## Language Modeling Benchmarks
22 | ### Training
23 | For training, the landmark tokens are added during data preparation. The following command is an example of training a model on PG19 with landmark tokens added every 50 tokens:
24 | ```
25 | python main.py \
26 | --config_format rotary \
27 | --model landmark \
28 | --n_embd 1024 \
29 | --n_head 8 \
30 | --n_layer 12 \
31 | --batch_size 16 \
32 | --sequence_length 512 \
33 | --acc_steps 8 \
34 | --wandb_project memory-llm \
35 | --dataset pg19 \
36 | --iterations 240000 \
37 | --dropout 0.0 \
38 | --positional_encoder rotary \
39 | --softmax_func mem_opt \
40 | --mem_freq 50 \
41 | --wandb \
42 | --save_checkpoint_freq 20000
43 | ```
44 |
45 | To run on multiple GPUs, use torchrun (e.g. `torchrun --nproc_per_node=4`) and pass `--distributed_backend nccl` to the `main.py` script. We suggest first running the script until training starts on a single GPU before switching to the multi-GPU setting. This is because the first node has to perform the initialization of the data, which can take long enough to cause a timeout on the synchronization in multi-GPU settings. Once the initialization has been performed, the result is stored on disk, so subsequent runs will be quick.
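For instance, a four-GPU run of the example above might look as follows (a sketch; the remaining flags are identical to the single-GPU command):
```
torchrun --nproc_per_node=4 main.py \
    --config_format rotary \
    --model landmark \
    --dataset pg19 \
    --distributed_backend nccl
```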
46 |
47 | You will need to initialize the dataset before running the training script. To do so, use the `prepare.py` script in the corresponding dataset folder inside `data/`.
48 |
49 | ### Inference
50 | The code supports inference in various settings. To perform standard evaluation, disable the cache and use the same chunk size (specified using the `--mid_length` flag) as the evaluation length (specified by `--eval_seq_length`). Using landmarks is possible when using `mem_cache`. The script `eval_cmd_generator.py` can be used to generate a bash script containing the commands to perform the evaluations corresponding to Tables 1 and 2 of the paper. The paths of the output models need to be updated inside the script.
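For illustration, a standard (cache-free) evaluation over sequences of length 512 pairs the two flags like this (a sketch only; the model and dataset flags are omitted here, and `eval_cmd_generator.py` produces the complete commands):
```
python eval.py --mid_length 512 --eval_seq_length 512
```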
51 |
52 | ## LLaMA fine-tuning
53 | The code for fine-tuning LLaMA and testing the final model is available as a standalone project in the sub-directory `llama`. An example command for running the fine-tuning (from inside the sub-directory) is:
54 |
55 | ```
56 | torchrun --nproc_per_node=8 train.py \
57 | --model_name_or_path /llama_weights/7B_hf/ \
58 | --bf16 True \
59 | --output_dir /llama-redpajama-mem-15000-with-mem/ \
60 | --cache_dir /hf-cache/ \
61 | --num_train_epochs 1 \
62 | --per_device_train_batch_size 2 \
63 | --per_device_eval_batch_size 2 \
64 | --gradient_accumulation_steps 8 \
65 | --evaluation_strategy "no" \
66 | --save_strategy "steps" \
67 | --save_steps 2000 \
68 | --save_total_limit 2 \
69 | --learning_rate 2e-5 \
70 | --weight_decay 0.1 \
71 | --warmup_ratio 0.03 \
72 | --lr_scheduler_type "cosine" \
73 | --logging_steps 1 \
74 | --fsdp "full_shard auto_wrap" \
75 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
76 | --tf32 True \
77 | --max_steps 15000
78 | ```
79 |
80 | In the above example, the LLaMA weights (converted to Hugging Face format) should be in `/llama_weights/7B_hf/`.
81 |
82 | ### Fine-tuned Weights
83 | We have released the weight diff between the original LLaMA 7B and the same model fine-tuned for 15000 steps on the [RedPajama](https://github.com/togethercomputer/RedPajama-Data) dataset with landmark attention [here](https://huggingface.co/epfml/landmark-attention-llama7b-wdiff). You may use the `weight_diff.py` script to recover the weights:
84 | ```
85 | python weight_diff.py recover --path_raw --path_diff --path_tuned
86 | ```
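For example (the raw and tuned paths follow the fine-tuning example above; the diff path is illustrative and should point to the downloaded weight diff):
```
python weight_diff.py recover \
    --path_raw /llama_weights/7B_hf/ \
    --path_diff /landmark-attention-llama7b-wdiff/ \
    --path_tuned /llama-redpajama-mem-15000-recovered/
```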
87 | For an example of how to perform inference using landmarks, look at `run_test.py`.
88 |
89 | ### Triton Implementation
90 |
91 | We have added a Triton implementation of the combination of our method and Flash Attention, which significantly reduces memory usage and also improves performance. Using this implementation, we trained LLaMA 7B with a 2048-token context length (instead of 512). Landmark attention can be added to any model by applying the following changes:
92 |
93 | 1. Adding landmark tokens to the input at regular intervals of block size.
94 | 2. (Optional) Creating a boolean mask of which tokens are landmarks. The mask can be passed to the landmark attention function to ensure the landmarks are placed correctly. This step can be skipped to obtain the highest speed.
95 | 3. Replacing `torch.nn.functional.scaled_dot_product_attention` with `fused_landmark_attention` (see the sketch below).
96 |
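Concretely, the call pattern from `ltriton/test_flash_landmark_attention.py` looks like this (a minimal sketch; the shapes and the block size of 64 are illustrative):
```
import torch
from flash_landmark_attention import fused_landmark_attention

batch, nheads, seqlen, head_dim = 2, 8, 1024, 128
block_size = 64  # landmark block size must match the Flash Attention block size
q = torch.rand(batch, nheads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.rand_like(q)
v = torch.rand_like(q)
# Steps 1-2: landmark tokens occupy the last position of every block.
is_mem = torch.arange(seqlen, device="cuda") % block_size == block_size - 1
# Step 3: drop-in replacement for
# torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
out = fused_landmark_attention(q, k, v, is_mem, block_size=block_size)
```
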
97 | Note that the implementation relies on the latest version of Triton, which conflicts with the latest version of PyTorch. Therefore, a special `install_deps.sh` script is provided to install the dependencies.
98 |
99 | Finally, note that the current implementation makes the following assumptions:
100 |
101 | 1. The implementation assumes the landmark blocks have the same size as the blocks used for computing the attention in Flash Attention. This limits the maximum block size, as the whole landmark block must fit into the GPU's local memory. However, using bfloat16, it should be possible to use block sizes as large as 64 or 128, which should be enough for landmark blocks.
102 | 2. The implementation assumes the difference between the number of keys and the number of queries is a multiple of the block size. Therefore, normal attention must be applied in the auto-regressive part of generation, when tokens are generated one by one. The implementation can still be used to process the input before generation starts.
103 | Note that this is not a big limitation since, when generating tokens one by one, the attention matrix has only a single row, limiting the benefits of Flash Attention.
104 | 3. While the high-level implementation allows the landmark tokens to be placed anywhere, the fused implementation assumes the landmark tokens are placed regularly at the end of each block. Since we always use this pattern at inference, this restriction should not be noticeable.
105 |
106 |
--------------------------------------------------------------------------------
/llama/install_deps.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip install -r requirements.txt
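# Triton 2.1 is required but conflicts with the PyTorch 2.0 pin, so it is
# installed from source here instead of via requirements.txt (see the note there).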
4 | pip install "git+https://github.com/openai/triton.git#subdirectory=python"
5 |
--------------------------------------------------------------------------------
/llama/llama_landmark_config.py:
--------------------------------------------------------------------------------
1 | from transformers.models.llama.configuration_llama import LlamaConfig
2 |
3 | class LlamaLandmarkConfig(LlamaConfig):
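    """LlamaConfig extended with landmark-attention settings.

    mem_id: vocabulary id of the landmark ("mem") token.
    mem_freq: number of regular tokens between consecutive landmark tokens.
    train_context_length: context length used during fine-tuning.
    include_landmark_in_loss: whether landmark positions contribute to the LM loss.
    """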
4 | model_type = "llama_with_landmark"
5 |
6 | def __init__(
7 | self,
8 | mem_id=32001,
9 | mem_freq=50,
10 | train_context_length=512,
11 | include_landmark_in_loss=True,
12 | **kwargs,
13 | ):
14 | self.mem_id = mem_id
15 | self.mem_freq = mem_freq
16 | self.train_context_length = train_context_length
17 | self.include_landmark_in_loss = include_landmark_in_loss
18 | super().__init__(**kwargs)
19 |
--------------------------------------------------------------------------------
/llama/ltriton/test_flash_landmark_attention.py:
--------------------------------------------------------------------------------
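# Correctness and benchmark test for the fused Triton landmark attention:
# compares outputs and gradients of fused_landmark_attention against a
# reference PyTorch implementation (landmark_grouped_softmax) on random
# inputs, then times both alongside PyTorch's built-in SDPA.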
1 | import torch
2 | import math
3 |
4 | from flash_landmark_attention import fused_landmark_attention
5 |
6 |
7 | class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
8 |
9 | # Note that forward, setup_context, and backward are @staticmethods
10 | @staticmethod
11 | def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
12 | new_shape = list(x.shape)
13 | new_shape[dim] = mem_cnt # max_mem_cnt.item()
14 | max_by_group = x.new_zeros((*new_shape,))
15 | max_by_group.scatter_reduce_(src=x, index=resp_mem_idx, dim=dim, reduce="amax", include_self=False)
16 |
17 | maxes = torch.gather(max_by_group, dim, resp_mem_idx)
18 | #x_exp = torch.exp(x - torch.where(torch.isinf(maxes), 0, maxes))
19 | x_exp = torch.exp((x - maxes).to(torch.float32))
20 |
21 | cumsum_by_group = torch.zeros_like(max_by_group, dtype=x_exp.dtype)
22 |
23 |         cumsum_by_group.scatter_add_(dim, resp_mem_idx, x_exp)
24 | denom = torch.gather(cumsum_by_group, dim, resp_mem_idx)
25 |
26 | #probs = torch.where(denom < 0.5, 0, x_exp / denom)
27 | probs = x_exp / denom
28 |
29 |
30 | ctx.mem_cnt = mem_cnt
31 | ctx.dim = dim
32 | ctx.save_for_backward(resp_mem_idx, probs)
33 |
34 | return probs
35 |
36 | @staticmethod
37 | def backward(ctx, grad_probs):
38 | mem_cnt = ctx.mem_cnt
39 | dim = ctx.dim
40 | resp_mem_idx, probs = ctx.saved_tensors
41 | grad_x = grad_dim = grad_mem_cnt = grad_resp_mem_idx = None
42 |
43 |         if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
44 | grad_pair = grad_probs * probs
45 |
46 | new_shape = list(probs.shape)
47 | new_shape[dim] = mem_cnt # max_mem_cnt.item()
48 | cumsum_by_group = grad_pair.new_zeros((*new_shape,))
49 | cumsum_by_group.scatter_add_(dim, resp_mem_idx, grad_pair)
50 |
51 |
52 | if ctx.needs_input_grad[0]:
53 | grad_sum = torch.gather(cumsum_by_group, dim, resp_mem_idx)
54 | grad_x = grad_pair - probs * grad_sum
55 | assert not ctx.needs_input_grad[1]
56 | assert not ctx.needs_input_grad[2]
57 | assert not ctx.needs_input_grad[3]
58 |
59 | return grad_x, grad_dim, grad_mem_cnt, grad_resp_mem_idx
60 |
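# Reference (non-fused) landmark softmax: attention scores are grouped by the
# landmark token that closes each block, softmax is taken within each group,
# and each token's weight is scaled by the probability its group's landmark
# receives. Landmark tokens and the current (last) section get full access.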
61 | def landmark_grouped_softmax(x, dim, is_mem, last_section_mask):
62 |
63 | last_and_rest_mask = last_section_mask # | mask
64 |
65 | full_access_mask = is_mem | last_and_rest_mask
66 |
67 | max_mem_cnt = 16
68 | mem_group_idx = torch.cumsum(is_mem, dim=dim)
69 | mem_bucket_id = max_mem_cnt - 1
70 | resp_mem_idx = torch.where(last_and_rest_mask,
71 | max_mem_cnt - 1,
72 | torch.where(is_mem, mem_bucket_id, mem_group_idx))
73 | probs = LandmarkGroupedSoftmaxFunction.apply(x, dim, max_mem_cnt, resp_mem_idx)
74 |
75 | new_shape = list(x.shape)
76 | new_shape[dim] = max_mem_cnt
77 | group_prob = probs.new_zeros((*new_shape, ))
78 | group_prob.scatter_(dim, torch.where(is_mem, mem_group_idx - 1, max_mem_cnt - 1), probs)
79 | probs = probs.mul(torch.where(full_access_mask, last_section_mask, torch.gather(group_prob, dim, resp_mem_idx)))
80 |
81 |
82 | return probs
83 |
84 | batch = 2
85 | nheads = 8
86 | seqlen_q = 1024
87 | seqlen_k = 1024 #512
88 | d = 128
89 | use_I_for_v = False
90 | mem_freq = 63
91 | q = torch.rand((batch, seqlen_q, nheads, d)).cuda().to(torch.bfloat16).transpose(1, 2)
92 | k = torch.rand((batch, seqlen_k, nheads, d)).cuda().to(torch.bfloat16).transpose(1, 2)
93 | if not use_I_for_v:
94 | v = torch.rand((batch, seqlen_k, nheads, d)).cuda().to(torch.bfloat16).transpose(1, 2)
95 | else:
96 | v = torch.eye(seqlen_k, d).cuda().to(torch.bfloat16)
97 | v = v.view(1, 1, seqlen_k, d).expand(batch, nheads, seqlen_k, d)
98 | q.requires_grad = True
99 | k.requires_grad = True
100 | v.requires_grad = True
101 | block_size = mem_freq + 1
102 | is_mem = torch.arange(0, seqlen_k, device=q.device) % block_size == (block_size - 1)
103 | out = fused_landmark_attention(q, k, v, is_mem, block_size=block_size)
104 |
105 | def f():
106 | import math
107 | att = q @ k.transpose(-1, -2) / math.sqrt(d)
108 |     att_mask = torch.tril(torch.ones((1, 1, seqlen_q, seqlen_k), device=q.device), diagonal=seqlen_k - seqlen_q) == 1.
109 |
110 | last_section_mask = (torch.arange(0, seqlen_k, device=q.device) // (mem_freq + 1))[None, :] == (torch.arange(seqlen_k - seqlen_q, seqlen_k, device=q.device) // (mem_freq + 1))[:, None]
111 |
112 | last_section_mask = last_section_mask.unsqueeze(0).unsqueeze(1)
113 | is_mem_ = is_mem.view(1, 1, 1, seqlen_k)
114 | mask = att_mask & ~(last_section_mask & is_mem_)
115 | last_section_mask = last_section_mask & mask
116 | is_mem_ = is_mem_ & mask
117 | is_mem_ = is_mem_.expand(batch, nheads, seqlen_q, seqlen_k)
118 |     last_section_mask = last_section_mask.expand(batch, nheads, seqlen_q, seqlen_k)
119 | att.masked_fill_(~mask, float("-inf"))
120 |
121 |
122 | att = landmark_grouped_softmax(att, -1, is_mem_, last_section_mask).to(q.dtype)
123 | att.masked_fill_(~mask, 0.0)
124 | exact_out = att @ v
125 | return exact_out
126 | exact_out = f()
127 |
128 | def make_f_grad(func):
129 | def f_():
130 | exact_out = func()
131 | return torch.autograd.grad((exact_out**2).sum(), [q, k, v])
132 | return f_
133 |
134 |
135 | def f_exact():
136 | return fused_landmark_attention(q, k, v, is_mem, block_size=block_size)
137 |
138 | def f_torch():
139 | return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
140 |
141 | if use_I_for_v and d >= seqlen_k:
142 | assert torch.allclose(out.sum(-1), torch.ones_like(out.sum(-1)), atol=1e-03, rtol=1e-03), out.sum(-1)
143 | assert torch.allclose(exact_out.sum(-1), torch.ones_like(exact_out.sum(-1)))
144 |
145 | #print("Exact", exact_out[:, :, mem_freq-1:mem_freq+2])
146 | #print("Fused", out[:, :, mem_freq-1:mem_freq+2])
147 | assert torch.allclose(out, exact_out, rtol=1e-02, atol=1e-02), (out, exact_out)
148 | #print("Diff", (out - exact_out).max(dim=-2))
149 |
150 |
151 | #print(last_section_mask[0, 0])
152 | #print(att[0, 0])
153 | #print(is_mem[0, 0])
154 | #print(is_mem[0, 0, -1])
155 |
156 | grads = torch.autograd.grad((out ** 2).sum(), [q, k, v])
157 | exact_grads = torch.autograd.grad((exact_out ** 2).sum(), [q, k, v])
158 | #print(len(exact_grads))
159 |
160 | for grad, exact_grad, t in zip(grads, exact_grads, ["q", "k", "v"]):
161 | torch.set_printoptions(sci_mode=False)
162 | if not torch.allclose(grad, exact_grad, rtol=4e-02, atol=4e-02):
163 | #print((grad, exact_grad, t))
164 | print("Failed d", t)
165 | #print(t, (grad - exact_grad).max())
166 | #print(t, torch.argmax((grad - exact_grad).amax(dim=-1), dim=-1))
167 |
168 | print("Done once")
169 | #print(v.grad)
170 |
171 | import timeit
172 | print("Exact: ", timeit.timeit(f, number=500))
173 | print("Fused: ", timeit.timeit(f_exact, number=500))
174 | print("Torch: ", timeit.timeit(f_torch, number=500))
175 |
176 | print("Exact Grad: ", timeit.timeit(make_f_grad(f), number=500))
177 | print("Fused Grad: ", timeit.timeit(make_f_grad(f_exact), number=500))
178 | print("Torch Grad: ", timeit.timeit(make_f_grad(f_torch), number=500))
179 |
--------------------------------------------------------------------------------
/llama/redpajama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Together Computer
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Lint as: python3
16 | """RedPajama: An Open-Source, Clean-Room 1.2 Trillion Token Dataset."""
17 |
18 |
19 | import json
20 |
21 | import datasets
22 | import traceback
23 | import numpy as np
24 | import math
25 |
26 | logger = datasets.logging.get_logger(__name__)
27 |
28 |
29 | _DESCRIPTION = """\
30 | RedPajama is a clean-room, fully open-source implementation of the LLaMa dataset.
31 | """
32 |
33 | _URL_LISTS = {
34 | "arxiv": "urls/arxiv.txt",
35 | "book": "urls/book.txt",
36 | "c4": "urls/c4.txt",
37 | "common_crawl": "urls/common_crawl.txt",
38 | "github": "urls/github.txt",
39 | "stackexchange": "urls/stackexchange.txt",
40 | "wikipedia": "urls/wikipedia.txt",
41 | }
42 |
43 |
44 | class RedPajama1TConfig(datasets.BuilderConfig):
45 | """BuilderConfig for RedPajama sample."""
46 |
47 | def __init__(self, *args, subsets, p_sample=None, **kwargs):
48 | """BuilderConfig for RedPajama.
49 | Args:
50 |             subsets: subset names to load (keys of _URL_LISTS); p_sample: optional fraction of each subset's URLs to sample; **kwargs: keyword arguments forwarded to super.
51 | """
52 | super(RedPajama1TConfig, self).__init__(**kwargs)
53 |
54 | self.subsets = subsets
55 | self.p_sample = p_sample
56 |
57 |
58 | class RedPajama1T(datasets.GeneratorBasedBuilder):
59 | """RedPajama: Reproducing the LLaMA training dataset of over 1.2 trillion tokens. Version 1.0.0."""
60 | BUILDER_CONFIG_CLASS = RedPajama1TConfig
61 | BUILDER_CONFIGS = [
62 | RedPajama1TConfig(
63 | subsets = list(_URL_LISTS.keys()),
64 | name="plain_text",
65 | version=datasets.Version("1.0.0", ""),
66 | description="Plain text",
67 | ),
68 | RedPajama1TConfig(
69 | subsets = list(_URL_LISTS.keys()),
70 | name="plain_text_tenpercent",
71 | version=datasets.Version("1.0.0", ""),
72 | description="Plain text",
73 | p_sample=0.1
74 | ),
75 | ]
76 |
77 | def _info(self):
78 | return datasets.DatasetInfo(
79 | description=_DESCRIPTION,
80 | features=datasets.Features(
81 | {
82 | "text": datasets.Value("string"),
83 | "meta": datasets.Value("string"),
84 | "red_pajama_subset": datasets.Value("string"),
85 | }
86 | ),
87 | supervised_keys=None,
88 | )
89 |
90 | def _split_generators(self, dl_manager):
91 | url_lists = dl_manager.download_and_extract({
92 | subset: _URL_LISTS[subset] for subset in self.config.subsets
93 | })
94 |
95 | urls = {}
96 | rng = np.random.default_rng(seed=2)
97 |
98 | for subset, url_list in url_lists.items():
99 | with open(url_list, encoding="utf-8") as f:
100 | urls[subset] = [line.strip() for line in f]
101 | if self.config.p_sample is not None:
102 | urls[subset] = rng.choice(
103 | urls[subset],
104 | size=int(math.ceil(len(urls[subset]) * self.config.p_sample)), replace=False).tolist()
105 |
106 | downloaded_files = dl_manager.download(urls)
107 |
108 | return [
109 | datasets.SplitGenerator(
110 | name=datasets.Split.TRAIN,
111 | gen_kwargs = {
112 | "files": {
113 | subset: downloaded_files[subset]
114 | for subset in self.config.subsets
115 | }
116 | }
117 | )
118 | ]
119 |
120 | def _generate_examples(self, files):
121 | """This function returns the examples in the raw (text) form."""
122 | key = 0
123 | for subset in files:
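            # common_crawl shards are zstd-compressed JSONL; the other subsets are plain JSONL.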
124 | if subset == "common_crawl":
125 | import zstandard as zstd
126 |
127 | for path in files[subset]:
128 | with zstd.open(open(path, "rb"), "rt", encoding="utf-8") as f:
129 | for i, row in enumerate(f):
130 | try:
131 | data = json.loads(row)
132 | text = data["text"]
133 | del data["text"]
134 | yield key, {
135 | "text": text,
136 | "meta": json.dumps(data),
137 | "red_pajama_subset": subset,
138 | }
139 | key += 1
140 | except Exception as e:
141 | print(f'Subset: {subset}')
142 | print(f'Path: {path}')
143 | print(f'Row: {row}')
144 | traceback.print_exc()
145 |
146 | raise e
147 | else:
148 | for path in files[subset]:
149 | with open(path, encoding="utf-8") as f:
150 | for i, row in enumerate(f):
151 | try:
152 | data = json.loads(row)
153 | if "meta" not in data:
154 | text = data["text"]
155 | del data["text"]
156 | yield key, {
157 | "text": text,
158 | "meta": json.dumps(data),
159 | "red_pajama_subset": subset,
160 | }
161 | else:
162 | yield key, {
163 | "text": data["text"],
164 | "meta": data["meta"],
165 | "red_pajama_subset": subset,
166 | }
167 | key += 1
168 | except Exception as e:
169 | print(f'Subset: {subset}')
170 | print(f'Path: {path}')
171 | print(f'Row: {row}')
172 | traceback.print_exc()
173 |
174 | raise e
175 |
--------------------------------------------------------------------------------
/llama/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | rouge_score
3 | fire
4 | openai
5 | transformers>=4.28.1
6 | torch>=2.0
7 | sentencepiece
8 | tokenizers>=0.13.3
9 | wandb
10 | accelerate
11 | datasets
12 |
13 |
14 | # We need Triton 2.1, but this currently conflicts with PyTorch 2.0, so the following line is commented out for now. Use install_deps.sh instead.
15 | #git+https://github.com/openai/triton.git#subdirectory=python
16 |
--------------------------------------------------------------------------------
/llama/run_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | import os
18 | import random
19 | import re
20 | import requests
21 |
22 |
23 | llama_weights_7b_base = "/llama_weights/7B_hf/"
24 | llama_weights_7b_tuned = "/llama-redpajama-mem-15000-with-mem/"
25 | cache_path = "/hf-cache/"
26 | use_flash = False # using flash for inference is only implemented when offloading the kv cache to cpu
27 | top_k = 5
28 | dtype = torch.bfloat16
29 |
30 | def make_llama_base_pipe():
31 |
32 | from transformers import pipeline
33 |
34 | from transformers.models.llama import LlamaForCausalLM
35 |
36 | llama_base = LlamaForCausalLM.from_pretrained(
37 | llama_weights_7b_base,
38 | cache_dir=cache_path,
39 | )
40 |
41 | llama_base = llama_base.to('cuda:0')
42 |
43 | import transformers
44 |
45 | tokenizer = transformers.AutoTokenizer.from_pretrained(
46 | llama_weights_7b_base,
47 | cache_dir=cache_path,
48 | model_max_length=2048,
49 | padding_side="right",
50 | use_fast=False,
51 | )
52 |
53 | llama_base_pipe = pipeline("text-generation", model=llama_base, tokenizer=tokenizer, device=llama_base.device)
54 | return llama_base_pipe
55 |
56 |
57 |
58 | llama_base_pipe = make_llama_base_pipe()
59 |
60 | def make_llama_mem_pipe():
61 | from llama_mem import LlamaForCausalLM
62 |
63 | model = LlamaForCausalLM.from_pretrained(
64 | llama_weights_7b_tuned,
65 | cache_dir=cache_path,
66 | torch_dtype=dtype
67 | )
68 |
69 | model.to('cuda:1')
70 |
71 | import transformers
72 |
73 | tokenizer = transformers.AutoTokenizer.from_pretrained(
74 | llama_weights_7b_tuned,
75 | cache_dir=cache_path,
76 | model_max_length=model.config.train_context_length,
77 | padding_side="right",
78 | use_fast=False,
79 | )
80 |     mem_id = tokenizer.convert_tokens_to_ids("<landmark>")
81 | model.set_mem_id(mem_id)
82 | from transformers import pipeline
83 | llama_mem_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device,
84 | offload_cache_to_cpu=use_flash, use_flash=use_flash,
85 | cache_top_k=top_k)
86 | return llama_mem_pipe
87 |
88 |
89 | llama_mem_pipe = make_llama_mem_pipe()
90 |
91 |
92 |
93 | pipes = {"base": llama_base_pipe, "mem": llama_mem_pipe}
94 |
95 |
96 | def generate_prompt(n_garbage):
97 | """Generates a text file and inserts an execute line at a random position."""
98 | n_garbage_prefix = random.randint(0, n_garbage)
99 | n_garbage_suffix = n_garbage - n_garbage_prefix
100 |
101 | task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
102 | garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
103 | garbage_inf = " ".join([garbage] * 2000)
104 | assert len(garbage_inf) >= n_garbage
105 | garbage_prefix = garbage_inf[:n_garbage_prefix]
106 | garbage_suffix = garbage_inf[:n_garbage_suffix]
107 | pass_key = random.randint(1, 50000)
108 | information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
109 | final_question = "What is the pass key? The pass key is"
110 | lines = [
111 | task_description,
112 | garbage_prefix,
113 | information_line,
114 | garbage_suffix,
115 | final_question
116 | ]
117 | return "\n".join(lines), pass_key
118 |
119 |
120 |
121 | def test_model(prompt_text, pass_key, model_name):
122 |     response = pipes[model_name](prompt_text, num_return_sequences=1, max_new_tokens=10)[0]["generated_text"][len(prompt_text):]
123 | assert f"The pass key is {pass_key}" in prompt_text
124 |
125 | try:
126 | pass_key = int(re.search(r'\d+', response).group())
127 | except:
128 | pass_key = response[:20]
129 |
130 | return pass_key
131 |
132 |
133 | n_values = [0, 100, 500, 1000, 5000, 8000, 10000, 12000, 14000, 18000, 20000, 25000, 38000]
134 | num_tests = 50
135 | models = ["base", "mem"]
136 | accuracies = {x: [] for x in models}
137 | individual_results = {x: [] for x in models}
138 |
139 | for n in n_values:
140 |
141 | correct_count = {x: 0 for x in models}
142 |
143 | n_results = {x: [] for x in models}
144 | for i in range(num_tests):
145 | print(f"\nRunning test {i + 1}/{num_tests} for n = {n}...")
146 | prompt_text, pass_key = generate_prompt(n)
147 |
148 |
149 |
150 | for model_name in models:
151 | if pipes[model_name] is None:
152 | continue
153 | num_tokens = len(pipes[model_name].tokenizer.encode(prompt_text))
154 |
155 | print("Number of tokens in this prompt: ", num_tokens)
156 | model_output = test_model(prompt_text, pass_key, model_name)
157 | print(f"Expected number in the prompt: {pass_key}, {model_name} output: {model_output}")
158 |
159 | if pass_key == model_output:
160 | correct_count[model_name] += 1
161 | n_results[model_name].append(1)
162 | print("Success!")
163 | else:
164 | n_results[model_name].append(0)
165 | print("Fail.")
166 |
167 | for model in models:
168 | accuracy = (correct_count[model] / num_tests) * 100
169 | print(f"Accuracy {model} for n = {n}: {accuracy}%")
170 | accuracies[model].append(accuracy)
171 |         individual_results[model].append(n_results[model])
172 |
--------------------------------------------------------------------------------
/llama/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import logging
17 | from dataclasses import dataclass, field
18 | from functools import partial
19 | from typing import Dict, Optional, Sequence
20 |
21 | import torch
22 | import transformers
23 | from torch.utils.data import Dataset
24 | from transformers import Trainer, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup
25 | from llama_mem import LlamaForCausalLM
26 |
27 | from torch.distributed import barrier
28 | import os
29 |
30 |
31 | from datasets import load_dataset
32 |
33 | IGNORE_INDEX = -100
34 | DEFAULT_PAD_TOKEN = "[PAD]"
35 | DEFAULT_EOS_TOKEN = "</s>"
36 | DEFAULT_BOS_TOKEN = "<s>"
37 | DEFAULT_UNK_TOKEN = "<unk>"
38 |
39 |
40 | @dataclass
41 | class ModelArguments:
42 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
43 |
44 | @dataclass
45 | class TrainingArguments(transformers.TrainingArguments):
46 | cache_dir: Optional[str] = field(default=None)
47 | optim: str = field(default="adamw_torch")
48 | model_max_length: int = field(
49 | default=512,
50 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
51 | )
52 | use_flash: bool = field(default=False)
53 | mem_freq: int = field(default=63)
54 |
55 |
56 | class TrainerCosine(Trainer):
57 | def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
58 | """
59 | Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
60 | passed as an argument.
61 |
62 | Args:
63 | num_training_steps (int): The number of training steps to do.
64 | """
65 | if self.args.lr_scheduler_type != "cosine":
66 | return super().create_scheduler(num_training_steps, optimizer)
67 | if self.lr_scheduler is None:
68 | self.lr_scheduler = get_cosine_schedule_with_warmup(
69 | optimizer=self.optimizer if optimizer is None else optimizer,
70 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
71 | num_training_steps=num_training_steps,
72 |                 num_cycles=0.4  # final lr decays to ~10% of the initial lr
73 | )
74 | return self.lr_scheduler
75 |
76 |
77 | def smart_tokenizer_and_embedding_resize(
78 | special_tokens_dict: Dict,
79 | tokenizer: transformers.PreTrainedTokenizer,
80 | model: transformers.PreTrainedModel,
81 | ):
82 | """Resize tokenizer and embedding.
83 |
84 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
85 | """
86 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
87 | model.resize_token_embeddings(len(tokenizer))
88 |
89 | if num_new_tokens > 0:
90 | input_embeddings = model.get_input_embeddings().weight.data
91 | output_embeddings = model.get_output_embeddings().weight.data
92 |
93 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
94 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
95 |
96 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
97 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
98 |
99 | def tokenize_fn(tokenizer, example):
100 | context_length = tokenizer.model_max_length
101 | outputs = tokenizer(
102 | tokenizer.eos_token.join(example["text"]),
103 | truncation=False,
104 | return_tensors="pt",
105 | pad_to_multiple_of=context_length,
106 | padding=True,
107 | )
108 | return {"input_ids": outputs["input_ids"].view(-1, context_length)}
109 |
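# Insert the landmark ("mem") token after every mem_freq input tokens.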
110 | def add_mem_tokens(example, mem_freq, mem_id):
111 | x = example["input_ids"]
112 | ret = []
113 | prev_idx = 0
114 | for t_idx in range(mem_freq, len(x), mem_freq):
115 | ret.extend(x[prev_idx:t_idx])
116 | ret.append(mem_id)
117 | prev_idx = t_idx
118 | ret.extend(x[prev_idx:])
119 | # drop attention_mask
120 | return {"input_ids": ret}
121 |
122 | def train():
123 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
124 | model_args, training_args = parser.parse_args_into_dataclasses()
125 |
126 | model = LlamaForCausalLM.from_pretrained(
127 | model_args.model_name_or_path,
128 | cache_dir=training_args.cache_dir,
129 | mem_freq=training_args.mem_freq,
130 | include_landmark_in_loss=not training_args.use_flash
131 | )
132 |
133 | tokenizer = transformers.AutoTokenizer.from_pretrained(
134 | model_args.model_name_or_path,
135 | cache_dir=training_args.cache_dir,
136 | model_max_length=training_args.model_max_length,
137 | padding_side="right",
138 | use_fast=False,
139 | )
140 | special_tokens_dict = dict()
141 | if tokenizer.pad_token is None:
142 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
143 | if tokenizer.eos_token is None:
144 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
145 | if tokenizer.bos_token is None:
146 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
147 | if tokenizer.unk_token is None:
148 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
149 |     mem_token = "<landmark>"
150 | special_tokens_dict["additional_special_tokens"] = [mem_token]
151 |
152 | smart_tokenizer_and_embedding_resize(
153 | special_tokens_dict=special_tokens_dict,
154 | tokenizer=tokenizer,
155 | model=model,
156 | )
157 |
158 | mem_id = tokenizer.convert_tokens_to_ids(mem_token)
159 | model.set_mem_id(mem_id)
160 |
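    # Rank 0 prepares the dataset first; the other ranks wait at the barrier
    # and then reuse the cached result from disk.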
161 | rank = int(os.environ.get('RANK', -1))
162 | if rank > 0:
163 | barrier()
164 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
165 |
166 |     dataset = dataset.map(partial(tokenize_fn, tokenizer), batched=True, num_proc=32, remove_columns=["text", "meta"])
167 |
168 | if training_args.use_flash:
169 | model.enable_landmark_insertion()
170 | model.enable_flash()
171 | else:
172 | dataset = dataset.map(
173 | partial(
174 | add_mem_tokens,
175 | mem_freq=training_args.mem_freq,
176 | mem_id=mem_id
177 | ), batched=False, num_proc=32)
178 |
179 | if rank == 0:
180 | barrier()
181 | print(dataset)
182 |
183 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
184 |
185 | trainer = TrainerCosine(
186 | model=model, tokenizer=tokenizer, args=training_args,
187 | train_dataset=dataset["train"],
188 | eval_dataset=None,
189 | data_collator=data_collator)
190 | trainer.train()
191 | trainer.save_state()
192 | trainer.save_model(output_dir=training_args.output_dir)
193 |
194 |
195 | if __name__ == "__main__":
196 | train()
197 |
--------------------------------------------------------------------------------
/llama/urls/arxiv.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_023827cd-7ee8-42e6-aa7b-661731f4c70f.jsonl
2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_024de5df-1b7f-447c-8c3a-51407d8d6732.jsonl
3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_03232e26-be3f-4a28-a5d2-ee1d8c0e9831.jsonl
4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_034e819a-cfcb-43c6-ad25-0232ad48823c.jsonl
5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_077ae8de-a68e-47e7-95a6-6d82f8f4eeb9.jsonl
6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0af50072-df4c-4084-a833-cebbd046e70e.jsonl
7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0de84cfc-c080-471f-b139-1bf061db4feb.jsonl
8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0fbdd8ad-32d8-4228-9a40-e09dde689760.jsonl
9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_11c659c1-ffbf-4455-abfd-058f6bbf4bb2.jsonl
10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1958455d-6543-4307-a081-d86ce0637f9a.jsonl
11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1982fb29-c4ed-4dd3-855c-666e63bc62d9.jsonl
12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1caed86f-5625-4941-bdc1-cc57e4fec1cd.jsonl
13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1d3a0cd6-f0e6-4106-a080-524a4bd50016.jsonl
14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29d54f5a-1dd0-4e9a-b783-fb2eec9db072.jsonl
15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29fd3d99-53fb-43e2-a4a5-2fd01bf77258.jsonl
16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2b224cd9-286e-46ac-8c4e-c1e3befc8760.jsonl
17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2c131fca-2a05-4d5f-a805-59d2af3477e2.jsonl
18 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2f28f1a7-6972-48ad-8997-65a5d52e4f1c.jsonl
19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_30440198-cd90-48c6-82c1-ea871b8c21c5.jsonl
20 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_39367d6c-d7d4-45fc-a929-8a17184d1744.jsonl
21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_393d19f2-1cd1-421f-be8a-78d955fdf602.jsonl
22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3a5d4f93-97ec-483a-88ef-324df9651b3f.jsonl
23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3c89ea11-69ff-4049-b775-f0c785997909.jsonl
24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3d5a011a-4bbe-4585-a2bd-ff3e943c8671.jsonl
25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f805f4b-6f7f-42a8-a006-47c1e0401bd7.jsonl
26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f9eb7ad-f266-4154-8d4d-54deeffde075.jsonl
27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_400748d3-0076-4a04-8a1c-6055ba0b5a2d.jsonl
28 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_44e19375-3995-4dff-a3b6-8a25247a165c.jsonl
29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4a8cf52f-81d0-4875-9528-466b1cbc71e1.jsonl
30 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4cc7015c-c39a-4bf6-9686-c00b3343edd9.jsonl
31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_50757a42-079b-41ec-bcca-73759faffd62.jsonl
32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_575ae832-e770-4a89-bfa7-c56f16dbca69.jsonl
33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_580be642-bb73-4d0d-8b5e-f494722934cd.jsonl
34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5a02d9ee-12a0-437d-808f-d26f0eb2012b.jsonl
35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5d8d402b-8277-480a-b5fa-71169726864f.jsonl
36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5ee33ef7-455e-4fd5-9512-c4771dd802c1.jsonl
37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_610c82ed-b9ee-449c-83b0-601205f3a74a.jsonl
38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_629fe3ca-075f-4663-9b81-b807f3b42bf2.jsonl
39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_64e5075e-e87e-4b2a-9e38-e5c102f6f2b1.jsonl
40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_65dd2ff6-dae3-4a60-90d3-c3d7349fc92f.jsonl
41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6719ecd2-fe34-4078-a584-320d921cbf6f.jsonl
42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6938ee72-43ee-4ade-8840-151a402383b0.jsonl
43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_73241940-66c1-481c-b53a-f5e8b9afe9fa.jsonl
44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_751370b5-c7cb-44d8-a039-1468ee6747ab.jsonl
45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_75af5d17-5ebb-4460-9f2a-dc9fe880a936.jsonl
46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_79d50803-f7d9-4aa8-bf1a-d807980a40c6.jsonl
47 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7b26046f-7c8d-405b-911b-df51e1a069fa.jsonl
48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7d1d69dc-bc8e-4817-9cab-afdc002ab7c4.jsonl
49 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7ea7a996-b1bb-4773-a36a-461dce2de861.jsonl
50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8232f276-9e3f-463a-9350-362de1b501d1.jsonl
51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8509f5a7-64a8-4813-92dc-f6eb53e3aacc.jsonl
52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_85b4c166-469d-449c-ab3d-5214c1d80246.jsonl
53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_872b620a-b4fd-45d3-92bc-ff0584447705.jsonl
54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_88f24f8d-16d3-4a21-894d-192033d0fa67.jsonl
55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8e6bd730-0f10-49d9-9b02-5ce16da47483.jsonl
56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8ede1b71-6846-439a-acba-86a57cfec3d2.jsonl
57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8f74f6ba-1c53-42d5-a3c7-e4ef46a71133.jsonl
58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_90fa9c2b-25b0-47b7-af2b-a683356e543b.jsonl
59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_92ec488a-287d-4bf0-977b-6998cf0cf476.jsonl
60 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94a393a1-3b23-4961-a0a6-70bad5b4979c.jsonl
61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94b9df70-a95f-4545-be3a-5a34f7b09fb3.jsonl
62 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_95ffc9e1-c505-4a3b-8fb0-cbc98b8703e1.jsonl
63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_98e718fd-5b0e-439f-a00c-57b61e06b395.jsonl
64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f50a028-2586-4e0d-bcfd-d9d2d74e8953.jsonl
65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f8d7d10-dda7-4e44-b00c-811635a199c8.jsonl
66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a055bd62-1ec2-47cf-bad2-321e3d4f053f.jsonl
67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a1e3430e-ef5c-4a86-914d-88e8fb7818c0.jsonl
68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a2b4cb3d-bea3-478e-82a2-77c00a827250.jsonl
69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a647274d-799d-4e7a-a485-b8632a87061e.jsonl
70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a94ea420-99ae-4d58-9cdc-d4666e3322a7.jsonl
71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ab7c0034-7fc1-4fa8-bae3-e97b85fc16a4.jsonl
72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b0732cff-657e-4e69-87b8-66e8025bf441.jsonl
73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b1f3b912-a2ab-43bd-8811-43d84b422506.jsonl
74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b6c2e4d3-d215-4d99-891f-d20b997d4d5a.jsonl
75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bbfabe6b-b9bc-476b-b8f0-7d6c47e9d2be.jsonl
76 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bf11ef29-a3f9-4a2d-9dbf-ebcc56d39fdb.jsonl
77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c05da42b-4939-4f55-867c-16cf6d228e60.jsonl
78 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c1fc3dd5-861f-4b8d-b7a2-eb8f6887b33b.jsonl
79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c44164b4-0770-48d0-87db-590ca529032a.jsonl
80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6b43cda-ca5c-4855-9c08-0f8264cab1af.jsonl
81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6e57a16-4879-4dcf-b591-503cfb46a360.jsonl
82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ca6e3842-6ca4-4230-84fa-376a3374c380.jsonl
83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_caf769e4-7308-4419-9114-900ca213682a.jsonl
84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d0da78e9-3dcf-4a46-85c1-f23ed00178bc.jsonl
85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d386838c-a51e-4839-9f27-8495b2466e49.jsonl
86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d41edf9f-0ebb-4866-b3fe-50785746b36b.jsonl
87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d6a7fc44-b584-4dd8-9de2-e981afe0bb4a.jsonl
88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d7c49dbb-c008-47fc-9cbe-8d5695842d21.jsonl
89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db3b4be7-4e98-4fe9-96bf-05a5788815e3.jsonl
90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db981f69-9eca-4031-8565-318b949efbfe.jsonl
91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_dbd4105f-7cbb-4483-a7b2-96b17b7fb594.jsonl
92 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de42e348-b333-4d35-b883-9bfc94f29822.jsonl
93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de744938-fa6c-45dd-b600-428dd7c63a73.jsonl
94 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_e8f25867-697d-4f52-84e1-e50a95bc182b.jsonl
95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_eb4e26d4-6625-4f8a-b5fe-6f3a9b8a4b79.jsonl
96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f141b736-5ce4-4f18-bb29-704227ca4bd1.jsonl
97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f50efdc6-f88e-4fa6-9ef6-dd1d8314bb36.jsonl
98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f7680c03-70df-4781-a98d-c88695f92f04.jsonl
99 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fbc62949-624d-4943-9731-f5c46242ba55.jsonl
100 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fd572627-cce7-4667-a684-fef096dfbeb7.jsonl
--------------------------------------------------------------------------------
/llama/urls/book.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/book/book.jsonl
--------------------------------------------------------------------------------
/llama/urls/github.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_08cdfa755e6d4d89b673d5bd1acee5f6.sampled.jsonl
2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f27d10d846a473b96070c3394832f32.sampled.jsonl
3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f979046c8e64e0fb5843d2634a9957d.sampled.jsonl
4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_10f129bfd0af45caa9cd72aa9d863ec5.sampled.jsonl
5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_11a1943edfa349c7939382799599eed6.sampled.jsonl
6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_17197bd2478044bebd9ff4634b6dfcee.sampled.jsonl
7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_1d750c0ce39d40c6bc20bad9469e5a99.sampled.jsonl
8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_21078cf63afb4d9eb4a7876f726a7226.sampled.jsonl
9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_216883d3a669406699428bc485a4c228.sampled.jsonl
10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_24278a707deb445b8e4f59c83dd67910.sampled.jsonl
11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_25989b8233a04ac791b0eccd502e0c7a.sampled.jsonl
12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_26a6fa61c5eb4bb885e7bc643e285f0e.sampled.jsonl
13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_275100c1a44f4451b0343373ebc5637a.sampled.jsonl
14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_27f05c041a1c401783f90b9415e40e4b.sampled.jsonl
15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_28106e66bfd94978abbc15ec845aeddb.sampled.jsonl
16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_2afd093fefad4c8da76cc539e8fb6137.sampled.jsonl
17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_30d01c4edab64866bda8c609e90b4f4e.sampled.jsonl
18 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34b0785b77814b7583433ddb27a61ae0.sampled.jsonl
19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34e78793c1e94eeebd92852399097596.sampled.jsonl
20 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_366012588bef4d749bbbea76ae701141.sampled.jsonl
21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_36c2a955ddd84672bbc778aa4ad2fbaf.sampled.jsonl
22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3be6a5cc7428401393c23e5516a43537.sampled.jsonl
23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3cc6cab9266746a6befa23648aa43119.sampled.jsonl
24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3f09f384f0734a4b912bb75db3f812bc.sampled.jsonl
25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_478afe8aaccb43e6be2de7e34e041ef3.sampled.jsonl
26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_483f11d6bc864f7fbfbe63bdf3583ce2.sampled.jsonl
27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4b6883dc304c4e799620ec95b96dc91a.sampled.jsonl
28 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4bebabdbd8544da7a2071864ccf81f2e.sampled.jsonl
29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4d44caf43c154ae4aaeab36eab0221c9.sampled.jsonl
30 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4f98dc136ba94eeaa1ba6c974814b33c.sampled.jsonl
31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4ff3604761614a3db550ef758e6457b5.sampled.jsonl
32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_50423e84046948b4a2a70e7e4538e12d.sampled.jsonl
33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_58fad5994b4446c6bceb33453484acb4.sampled.jsonl
34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5abf005e4e634f1dbfa8bd20b5687092.sampled.jsonl
35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5f2b3517159b426bb1a9e81ca189abcd.sampled.jsonl
36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_610e908cafaa4d53958de50ad700822a.sampled.jsonl
37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6353ab14cb8f4623a7d75678d9e7f44e.sampled.jsonl
38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_64747f25102740bab0ab54559569342a.sampled.jsonl
39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6742e99802894114a3ba44841f49b168.sampled.jsonl
40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_677e10f9b0af4c489e60670352e7e224.sampled.jsonl
41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_68534e6a093744fa9f38fa1a9cf51232.sampled.jsonl
42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6adcf92fb2ee48059cb60579f2e931f7.sampled.jsonl
43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_70f0aa43987643a7874286bca4faa66b.sampled.jsonl
44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7462f1a34f594e9f8334d9a0cbbf80e7.sampled.jsonl
45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78a13a25258f4d24923702c07445e20e.sampled.jsonl
46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78e3afc1cfee41fbb7eaae2e5bfaa17b.sampled.jsonl
47 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_791ed086a57f4879bb1596bed6d37bb3.sampled.jsonl
48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7b5cc18857a34a0981b54f082de55cf8.sampled.jsonl
49 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7c54b908a2df4ec2ba316d2081fc674e.sampled.jsonl
50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7edd0af0b61c426e93e8bd3f549a8f78.sampled.jsonl
51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7fa232e63a5d44a88a267181e9ac47b4.sampled.jsonl
52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_81138c01f3a84f7fa8b56bf3d8fa35ce.sampled.jsonl
53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_82a9a647b0eb4eb080d7ac15a13c765b.sampled.jsonl
54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_835ac26c7a70447b97f4ec38bcb969ed.sampled.jsonl
55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8bc3c4fae78c41d999b2ae6d97cce96c.sampled.jsonl
56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8dc430f8f7114f018440c7b3d990e602.sampled.jsonl
57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8e1db2cf6c98420a88bad52fd57f4aa7.sampled.jsonl
58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_919156910e704e6ca52e0d4880cdbb63.sampled.jsonl
59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_944beac1347f491faa88f43c25d26fe4.sampled.jsonl
60 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_96175b6e4c764abfbbf14e78d4fd6464.sampled.jsonl
61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_977d1fcdab92452d9dc38b2f4a99389b.sampled.jsonl
62 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_9eea501af8544d0b88f0b002850829d4.sampled.jsonl
63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a26c51ffe2924fd8ad6694c6aa0eacc5.sampled.jsonl
64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3a01a2f81ed4cd2afb30e175200b48f.sampled.jsonl
65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3d3e6f3f7d5495ca9ecf94d808fd350.sampled.jsonl
66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a777da5620f1467f8df3616b17d533dc.sampled.jsonl
67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a8fcae75b0c3410faabcff02f0056a36.sampled.jsonl
68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_aabc2d54d1c946908d3400228e0f238c.sampled.jsonl
69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ad160286662849f49d1e6de27c0f1d15.sampled.jsonl
70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_adda41d791974289aff042e2e3d07ec3.sampled.jsonl
71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ae1813abc63f4b1998dfa608e7fe5588.sampled.jsonl
72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_af7b386db97e4211a7378be08d7b3f4f.sampled.jsonl
73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b15c575fa9f8465d98f70ba2f2f73c6e.sampled.jsonl
74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b40763d0b9ce4e0d8fb5f519f1f49f8c.sampled.jsonl
75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b56453718c5f46efa9c46feb194b0d6e.sampled.jsonl
76 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b6225e159b86432d9fa5bf226bb51393.sampled.jsonl
77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b821f640b8f14ed588bf48ae13f44098.sampled.jsonl
78 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_bfbcad9633f04601ba2f824d082eaacf.sampled.jsonl
79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c6cfd16905814d7c955df1f4754a8b11.sampled.jsonl
80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c82c6775f0b74dbdae4524bb9aebf0ef.sampled.jsonl
81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ca73c69896b34adbbe468d78d9f134bc.sampled.jsonl
82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cac864a472b948b0bfe18c8e9a19aeb5.sampled.jsonl
83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cb330e8dc8ac411eba2dc8676c9c4403.sampled.jsonl
84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d0a054a678fc4c38b496d10e91a2c735.sampled.jsonl
85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d148857715424aabbd32b8ffe56c4082.sampled.jsonl
86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d33701435f964c90a86c22e204dd5fde.sampled.jsonl
87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d448c2553f474193a1224df1c38f74d4.sampled.jsonl
88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d8d45c58819948c9b352d74383944c4a.sampled.jsonl
89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_db611a692a704f6db3c18e77f79fd2f0.sampled.jsonl
90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dcaa5c8b729b4fb599399dbf4557e43e.sampled.jsonl
91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dda863d833614d04b96bbe21b161768d.sampled.jsonl
92 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eabf427d56184fb89f9b5f27e73f7988.sampled.jsonl
93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eee6980922ca4a14b0c3341fa8a904d9.sampled.jsonl
94 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f478ff72b57f4f4283c22ac22ae84134.sampled.jsonl
95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f931feb0e85940879d194c0e20d9e28a.sampled.jsonl
96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f9bbe6a065004a4c8e018c6ad63063b2.sampled.jsonl
97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fbfc552e48164acda6605fa31fc2f563.sampled.jsonl
98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fc817e4c16494957bcb89b166e91434f.sampled.jsonl
--------------------------------------------------------------------------------
/llama/urls/stackexchange.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/stackexchange/stackexchange.jsonl
--------------------------------------------------------------------------------
/llama/urls/wikipedia.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/wikipedia/wiki.jsonl
--------------------------------------------------------------------------------
/llama/weight_diff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This file has been changed by Amirkeivan Mohtashami
16 | # to take into account the new token in the embedding layer
17 |
18 | import os
19 | from typing import Optional
20 |
21 | import fire
22 | import torch
23 | import tqdm
24 | import transformers
25 | from train import smart_tokenizer_and_embedding_resize
26 | import llama_mem
27 |
28 | @torch.inference_mode()
29 | def make_diff(
30 | path_raw: str, path_tuned: str, path_diff: str, device="cpu", # "cuda" or "cpu"
31 | ):
32 | """Make the weight diff.
33 |
34 | This function is provided for full transparency about how the weight diff was created.
35 |
36 | Run:
37 | python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>
38 | """
39 | model_tuned: transformers.PreTrainedModel = llama_mem.LlamaForCausalLM.from_pretrained(
40 | path_tuned,
41 | device_map={"": torch.device(device)},
42 | torch_dtype=torch.float32,
43 | low_cpu_mem_usage=True,
44 | )
45 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
46 | path_raw,
47 | device_map={"": torch.device(device)},
48 | torch_dtype=torch.float32,
49 | low_cpu_mem_usage=True,
50 | )
51 |
52 | tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
53 | path_tuned
54 | )
55 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
56 | path_raw
57 | )
58 | smart_tokenizer_and_embedding_resize(
59 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=["<landmark>"]),
60 | model=model_raw,
61 | tokenizer=tokenizer_raw,
62 | )
63 |
64 |
65 |
66 | state_dict_tuned = model_tuned.state_dict()
67 | state_dict_raw = model_raw.state_dict()
68 | with open(os.path.join(path_diff, "checksum_psum.txt"), "w") as f:
69 | f.write(str(sum(state_dict_tuned[key].sum().item() for key in state_dict_tuned)))
70 |
71 | for key in tqdm.tqdm(state_dict_tuned):
72 | state_dict_tuned[key].add_(-state_dict_raw[key])
73 |
74 | model_tuned.save_pretrained(path_diff)
75 | tokenizer_tuned.save_pretrained(path_diff)
76 |
77 |
78 | @torch.inference_mode()
79 | def recover(
80 | path_raw,
81 | path_diff,
82 | path_tuned: Optional[str] = None,
83 | device="cpu",
84 | test_inference=True,
85 | check_integrity_naively=True,
86 | ):
87 | """Recover the original weights from the released weight diff.
88 |
89 | This function is provided for you to run.
90 |
91 | Things to do before running this:
92 | 1. Convert Meta's released weights into huggingface format. Follow this guide:
93 | https://huggingface.co/docs/transformers/main/model_doc/llama
94 | 2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
95 | https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
96 | 3. Run this function with the correct paths. E.g.,
97 | python weight_diff.py recover --path_raw <your_path_raw> --path_diff <your_path_diff>
98 |
99 | Additional notes:
100 | - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
101 | - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
102 | Next time you can load the recovered weights directly from `<your_path_tuned>`.
103 | """
104 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
105 | path_raw,
106 | device_map={"": torch.device(device)},
107 | torch_dtype=torch.float32,
108 | low_cpu_mem_usage=True,
109 | )
110 | model_recovered: transformers.PreTrainedModel = llama_mem.LlamaForCausalLM.from_pretrained(
111 | path_diff,
112 | device_map={"": torch.device(device)},
113 | torch_dtype=torch.float32,
114 | low_cpu_mem_usage=True,
115 | )
116 |
117 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
118 | path_raw
119 | )
120 | smart_tokenizer_and_embedding_resize(
121 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=["<landmark>"]),
122 | model=model_raw,
123 | tokenizer=tokenizer_raw,
124 | )
125 | tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
126 | path_diff
127 | )
128 |
129 | state_dict_recovered = model_recovered.state_dict()
130 | state_dict_raw = model_raw.state_dict()
131 | for key in tqdm.tqdm(state_dict_recovered):
132 | state_dict_recovered[key].add_(state_dict_raw[key])
133 |
134 | if check_integrity_naively:
135 | # This is not a rigorous, cryptographically strong integrity check :)
136 | allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered)
137 | if os.path.exists(os.path.join(path_diff, "checksum_psum.txt")):
138 | with open(os.path.join(path_diff, "checksum_psum.txt")) as f:
139 | expected_sum = float(f.read())
140 | else:
141 | expected_sum = 49798.7656 # backward compatibility with the first released weights
142 | assert torch.allclose(
143 | allsum, torch.full_like(allsum, fill_value=expected_sum), atol=1e-2, rtol=0
144 | ), "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted."
145 |
146 | if path_tuned is not None:
147 | model_recovered.save_pretrained(path_tuned)
148 | tokenizer_recovered.save_pretrained(path_tuned)
149 |
150 | return model_recovered, tokenizer_recovered
151 |
152 |
153 | def main(task, **kwargs):
154 | globals()[task](**kwargs)
155 |
156 |
157 | if __name__ == "__main__":
158 | fire.Fire(main)
159 |
--------------------------------------------------------------------------------
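A note on the pair above: `make_diff` stores `tuned - raw` in place and writes a naive checksum (the sum of every element of the tuned weights), and `recover` adds `raw` back and re-checks that sum. A minimal sketch of the same arithmetic on toy tensors (hypothetical names, not the repo's code):

```python
import torch

# Hypothetical two-entry state dicts standing in for real checkpoints.
tuned = {"w": torch.tensor([1.5, -0.5]), "b": torch.tensor([0.25])}
raw = {"w": torch.tensor([1.0, 0.0]), "b": torch.tensor([0.0])}

# make_diff: diff = tuned - raw; the naive integrity value is the sum
# of all elements of the tuned weights.
diff = {k: tuned[k] - raw[k] for k in tuned}
psum = sum(t.sum().item() for t in tuned.values())

# recover: tuned' = diff + raw, then verify the same naive sum.
recovered = {k: diff[k] + raw[k] for k in diff}
assert all(torch.allclose(recovered[k], tuned[k]) for k in tuned)
assert abs(sum(t.sum().item() for t in recovered.values()) - psum) < 1e-2
```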
/llama_legacy/redpajama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Together Computer
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Lint as: python3
16 | """RedPajama: An Open-Source, Clean-Room 1.2 Trillion Token Dataset."""
17 |
18 |
19 | import json
20 |
21 | import datasets
22 | import traceback
23 | import numpy as np
24 | import math
25 |
26 | logger = datasets.logging.get_logger(__name__)
27 |
28 |
29 | _DESCRIPTION = """\
30 | RedPajama is a clean-room, fully open-source implementation of the LLaMa dataset.
31 | """
32 |
33 | _URL_LISTS = {
34 | "arxiv": "urls/arxiv.txt",
35 | "book": "urls/book.txt",
36 | "c4": "urls/c4.txt",
37 | "common_crawl": "urls/common_crawl.txt",
38 | "github": "urls/github.txt",
39 | "stackexchange": "urls/stackexchange.txt",
40 | "wikipedia": "urls/wikipedia.txt",
41 | }
42 |
43 |
44 | class RedPajama1TConfig(datasets.BuilderConfig):
45 | """BuilderConfig for RedPajama sample."""
46 |
47 | def __init__(self, *args, subsets, p_sample=None, **kwargs):
48 | """BuilderConfig for RedPajama.
49 | Args:
50 | subsets: subset names to load; p_sample: optional fraction of URLs to sample from each subset; **kwargs: keyword arguments forwarded to super.
51 | """
52 | super(RedPajama1TConfig, self).__init__(**kwargs)
53 |
54 | self.subsets = subsets
55 | self.p_sample = p_sample
56 |
57 |
58 | class RedPajama1T(datasets.GeneratorBasedBuilder):
59 | """RedPajama: Reproducing the LLaMA training dataset of over 1.2 trillion tokens. Version 1.0.0."""
60 | BUILDER_CONFIG_CLASS = RedPajama1TConfig
61 | BUILDER_CONFIGS = [
62 | RedPajama1TConfig(
63 | subsets=list(_URL_LISTS.keys()),
64 | name="plain_text",
65 | version=datasets.Version("1.0.0", ""),
66 | description="Plain text",
67 | ),
68 | RedPajama1TConfig(
69 | subsets=list(_URL_LISTS.keys()),
70 | name="plain_text_tenpercent",
71 | version=datasets.Version("1.0.0", ""),
72 | description="Plain text",
73 | p_sample=0.1
74 | ),
75 | ]
76 |
77 | def _info(self):
78 | return datasets.DatasetInfo(
79 | description=_DESCRIPTION,
80 | features=datasets.Features(
81 | {
82 | "text": datasets.Value("string"),
83 | "meta": datasets.Value("string"),
84 | "red_pajama_subset": datasets.Value("string"),
85 | }
86 | ),
87 | supervised_keys=None,
88 | )
89 |
90 | def _split_generators(self, dl_manager):
91 | url_lists = dl_manager.download_and_extract({
92 | subset: _URL_LISTS[subset] for subset in self.config.subsets
93 | })
94 |
95 | urls = {}
96 | rng = np.random.default_rng(seed=2)
97 |
98 | for subset, url_list in url_lists.items():
99 | with open(url_list, encoding="utf-8") as f:
100 | urls[subset] = [line.strip() for line in f]
101 | if self.config.p_sample is not None:
102 | urls[subset] = rng.choice(
103 | urls[subset],
104 | size=int(math.ceil(len(urls[subset]) * self.config.p_sample)), replace=False).tolist()
105 |
106 | downloaded_files = dl_manager.download(urls)
107 |
108 | return [
109 | datasets.SplitGenerator(
110 | name=datasets.Split.TRAIN,
111 | gen_kwargs={
112 | "files": {
113 | subset: downloaded_files[subset]
114 | for subset in self.config.subsets
115 | }
116 | }
117 | )
118 | ]
119 |
120 | def _generate_examples(self, files):
121 | """This function returns the examples in the raw (text) form."""
122 | key = 0
123 | for subset in files:
124 | if subset == "common_crawl":
125 | import zstandard as zstd
126 |
127 | for path in files[subset]:
128 | with zstd.open(open(path, "rb"), "rt", encoding="utf-8") as f:
129 | for i, row in enumerate(f):
130 | try:
131 | data = json.loads(row)
132 | text = data["text"]
133 | del data["text"]
134 | yield key, {
135 | "text": text,
136 | "meta": json.dumps(data),
137 | "red_pajama_subset": subset,
138 | }
139 | key += 1
140 | except Exception as e:
141 | print(f'Subset: {subset}')
142 | print(f'Path: {path}')
143 | print(f'Row: {row}')
144 | traceback.print_exc()
145 |
146 | raise e
147 | else:
148 | for path in files[subset]:
149 | with open(path, encoding="utf-8") as f:
150 | for i, row in enumerate(f):
151 | try:
152 | data = json.loads(row)
153 | if "meta" not in data:
154 | text = data["text"]
155 | del data["text"]
156 | yield key, {
157 | "text": text,
158 | "meta": json.dumps(data),
159 | "red_pajama_subset": subset,
160 | }
161 | else:
162 | yield key, {
163 | "text": data["text"],
164 | "meta": data["meta"],
165 | "red_pajama_subset": subset,
166 | }
167 | key += 1
168 | except Exception as e:
169 | print(f'Subset: {subset}')
170 | print(f'Path: {path}')
171 | print(f'Row: {row}')
172 | traceback.print_exc()
173 |
174 | raise e
175 |
--------------------------------------------------------------------------------
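Usage sketch for the builder above, assuming a `datasets` version that still accepts local loading scripts; the config names and subset keys come from `BUILDER_CONFIGS` and `_URL_LISTS`:

```python
from datasets import load_dataset

# "plain_text_tenpercent" sets p_sample=0.1, so a seeded rng keeps ~10%
# of each subset's URL list before downloading. Extra kwargs such as
# `subsets` override the corresponding BuilderConfig attribute.
dataset = load_dataset(
    "redpajama.py",               # path to this script, next to urls/
    name="plain_text_tenpercent",
    subsets=["arxiv", "github"],  # restrict to two subsets
)
print(dataset["train"][0]["red_pajama_subset"])
```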
/llama_legacy/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | rouge_score
3 | fire
4 | openai
5 | transformers>=4.28.1
6 | torch
7 | sentencepiece
8 | tokenizers>=0.13.3
9 | wandb
10 | accelerate
11 | datasets
12 |
--------------------------------------------------------------------------------
/llama_legacy/run_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | llama_weights_7b_base = "/llama_weights/7B_hf/"
16 | llama_weights_7b_tuned = "/llama-redpajama-mem-15000-with-mem/"
17 | cache_path = "/hf-cache/"
18 |
19 | def make_llama_base_pipe():
20 |
21 | from transformers import pipeline
22 |
23 | from transformers.models.llama import LlamaForCausalLM
24 |
25 | llama_base = LlamaForCausalLM.from_pretrained(
26 | llama_weights_7b_base,
27 | cache_dir=cache_path,
28 | )
29 |
30 | llama_base = llama_base.to('cuda:0')
31 |
32 | import transformers
33 |
34 | tokenizer = transformers.AutoTokenizer.from_pretrained(
35 | llama_weights_7b_base,
36 | cache_dir=cache_path,
37 | model_max_length=1024,
38 | padding_side="right",
39 | use_fast=False,
40 | )
41 |
42 | llama_base_pipe = pipeline("text-generation", model=llama_base, tokenizer=tokenizer, device=llama_base.device)
43 | return llama_base_pipe
44 |
45 |
46 |
47 | llama_base_pipe = make_llama_base_pipe()
48 |
49 | def make_llama_mem_pipe():
50 | from llama_mem import LlamaForCausalLM
51 |
52 | model = LlamaForCausalLM.from_pretrained(
53 | llama_weights_7b_tuned,
54 | cache_dir=cache_path,
55 | )
56 |
57 | model.to('cuda:1')
58 |
59 | import transformers
60 |
61 | tokenizer = transformers.AutoTokenizer.from_pretrained(
62 | llama_weights_7b_tuned,
63 | cache_dir=cache_path,
64 | model_max_length=512,
65 | padding_side="right",
66 | use_fast=False,
67 | )
68 | from transformers import pipeline
69 | llama_mem_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device)
70 | return llama_mem_pipe
71 |
72 |
73 | llama_mem_pipe = make_llama_mem_pipe()
74 |
75 | mem_id = llama_mem_pipe.tokenizer.convert_tokens_to_ids("<landmark>")
76 | llama_mem_pipe.model.set_mem_id(mem_id)
77 | llama_mem_pipe.model.set_mem_cache_args(max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None)
78 |
79 |
80 | pipes = {"base": llama_base_pipe, "mem": llama_mem_pipe}
81 |
82 | import torch
83 |
84 | import os
85 | import random
86 | import re
87 | import requests
88 |
89 | def generate_prompt(n_garbage):
90 | """Generates a text file and inserts an execute line at a random position."""
91 | n_garbage_prefix = random.randint(0, n_garbage)
92 | n_garbage_suffix = n_garbage - n_garbage_prefix
93 |
94 | task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
95 | garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
96 | garbage_inf = " ".join([garbage] * 2000)
97 | assert len(garbage_inf) >= n_garbage
98 | garbage_prefix = garbage_inf[:n_garbage_prefix]
99 | garbage_suffix = garbage_inf[:n_garbage_suffix]
100 | pass_key = random.randint(1, 50000)
101 | information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
102 | final_question = "What is the pass key? The pass key is"
103 | lines = [
104 | task_description,
105 | garbage_prefix,
106 | information_line,
107 | garbage_suffix,
108 | final_question
109 | ]
110 | return "\n".join(lines), pass_key
111 |
112 |
113 |
114 | def test_model(prompt_text, pass_key, model_name):
115 | response = pipes[model_name](prompt_text,num_return_sequences=1, max_new_tokens=10)[0]["generated_text"][len(prompt_text):]
116 | assert f"The pass key is {pass_key}" in prompt_text
117 |
118 | try:
119 | pass_key = int(re.search(r'\d+', response).group())
120 | except Exception:
121 | pass_key = response[:20]
122 |
123 | return pass_key
124 |
125 |
126 | n_values = [0, 100, 500, 1000, 5000, 8000, 10000, 12000, 14000, 18000, 20000, 25000, 38000]
127 | num_tests = 50
128 | models = ["base", "mem"]
129 | accuracies = {x: [] for x in models}
130 | individual_results = {x: [] for x in models}
131 |
132 | for n in n_values:
133 |
134 | correct_count = {x: 0 for x in models}
135 |
136 | n_results = {x: [] for x in models}
137 | for i in range(num_tests):
138 | print(f"\nRunning test {i + 1}/{num_tests} for n = {n}...")
139 | prompt_text, pass_key = generate_prompt(n)
140 |
141 |
142 |
143 | for model_name in models:
144 | num_tokens = len(pipes[model_name].tokenizer.encode(prompt_text))
145 |
146 | print("Number of tokens in this prompt: ", num_tokens)
147 | model_output = test_model(prompt_text, pass_key, model_name)
148 | print(f"Expected number in the prompt: {pass_key}, {model_name} output: {model_output}")
149 |
150 | if pass_key == model_output:
151 | correct_count[model_name] += 1
152 | n_results[model_name].append(1)
153 | print("Success!")
154 | else:
155 | n_results[model_name].append(0)
156 | print("Fail.")
157 |
158 | for model in models:
159 | accuracy = (correct_count[model] / num_tests) * 100
160 | print(f"Accuracy {model} for n = {n}: {accuracy}%")
161 | accuracies[model].append(accuracy)
162 | individual_results[model].append(n_results[model])
163 |
--------------------------------------------------------------------------------
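The script above sweeps the amount of distractor text (`n_values`) and reports, per model, how often the generated continuation contains the hidden pass key. A self-contained sketch of the prompt construction it relies on (same structure, fixed seed for reproducibility; not the repo's code):

```python
import random

def make_passkey_prompt(n_garbage: int, seed: int = 0):
    """Build a prompt with a pass key buried in filler text."""
    rng = random.Random(seed)
    n_prefix = rng.randint(0, n_garbage)
    filler = ("The grass is green. The sky is blue. The sun is yellow. "
              "Here we go. There and back again. ") * 2000
    pass_key = rng.randint(1, 50000)
    lines = [
        "There is an important info hidden inside a lot of irrelevant text.",
        filler[:n_prefix],                      # garbage before the key
        f"The pass key is {pass_key}. Remember it.",
        filler[:n_garbage - n_prefix],          # garbage after the key
        "What is the pass key? The pass key is",
    ]
    return "\n".join(lines), pass_key

prompt, key = make_passkey_prompt(1000)
assert f"The pass key is {key}" in prompt
```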
/llama_legacy/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import logging
17 | from dataclasses import dataclass, field
18 | from functools import partial
19 | from typing import Dict, Optional, Sequence
20 |
21 | import torch
22 | import transformers
23 | from torch.utils.data import Dataset
24 | from transformers import Trainer, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup
25 | from llama_mem import LlamaForCausalLM
26 |
27 | from torch.distributed import barrier
28 | import os
29 |
30 |
31 | from datasets import load_dataset
32 |
33 | IGNORE_INDEX = -100
34 | DEFAULT_PAD_TOKEN = "[PAD]"
35 | DEFAULT_EOS_TOKEN = "</s>"
36 | DEFAULT_BOS_TOKEN = "<s>"
37 | DEFAULT_UNK_TOKEN = "<unk>"
38 |
39 |
40 | @dataclass
41 | class ModelArguments:
42 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
43 |
44 | @dataclass
45 | class TrainingArguments(transformers.TrainingArguments):
46 | cache_dir: Optional[str] = field(default=None)
47 | optim: str = field(default="adamw_torch")
48 | model_max_length: int = field(
49 | default=512,
50 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
51 | )
52 |
53 |
54 | class TrainerCosine(Trainer):
55 | def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
56 | """
57 | Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
58 | passed as an argument.
59 |
60 | Args:
61 | num_training_steps (int): The number of training steps to do.
62 | """
63 | if self.args.lr_scheduler_type != "cosine":
64 | return super().create_scheduler(num_training_steps, optimizer)
65 | if self.lr_scheduler is None:
66 | self.lr_scheduler = get_cosine_schedule_with_warmup(
67 | optimizer=self.optimizer if optimizer is None else optimizer,
68 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
69 | num_training_steps=num_training_steps,
70 | num_cycles=0.4  # final lr decays to ~10% of the initial lr
71 | )
72 | return self.lr_scheduler
73 |
74 |
75 | def smart_tokenizer_and_embedding_resize(
76 | special_tokens_dict: Dict,
77 | tokenizer: transformers.PreTrainedTokenizer,
78 | model: transformers.PreTrainedModel,
79 | ):
80 | """Resize tokenizer and embedding.
81 |
82 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
83 | """
84 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
85 | model.resize_token_embeddings(len(tokenizer))
86 |
87 | if num_new_tokens > 0:
88 | input_embeddings = model.get_input_embeddings().weight.data
89 | output_embeddings = model.get_output_embeddings().weight.data
90 |
91 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
92 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
93 |
94 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
95 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
96 |
97 | def tokenize_fn(tokenizer, example):
98 | context_length = tokenizer.model_max_length
99 | outputs = tokenizer(
100 | tokenizer.eos_token.join(example["text"]),
101 | truncation=False,
102 | return_tensors="pt",
103 | pad_to_multiple_of=context_length,
104 | padding=True,
105 | )
106 | return {"input_ids": outputs["input_ids"].view(-1, context_length)}
107 |
108 | def add_mem_tokens(example, mem_freq, mem_id):
109 | x = example["input_ids"]
110 | ret = []
111 | prev_idx = 0
112 | for t_idx in range(mem_freq, len(x), mem_freq):
113 | ret.extend(x[prev_idx:t_idx])
114 | ret.append(mem_id)
115 | prev_idx = t_idx
116 | ret.extend(x[prev_idx:])
117 | # drop attention_mask
118 | return {"input_ids": ret}
119 |
120 | def train():
121 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
122 | model_args, training_args = parser.parse_args_into_dataclasses()
123 |
124 | model = LlamaForCausalLM.from_pretrained(
125 | model_args.model_name_or_path,
126 | cache_dir=training_args.cache_dir,
127 | )
128 |
129 | tokenizer = transformers.AutoTokenizer.from_pretrained(
130 | model_args.model_name_or_path,
131 | cache_dir=training_args.cache_dir,
132 | model_max_length=training_args.model_max_length,
133 | padding_side="right",
134 | use_fast=False,
135 | )
136 | special_tokens_dict = dict()
137 | if tokenizer.pad_token is None:
138 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
139 | if tokenizer.eos_token is None:
140 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
141 | if tokenizer.bos_token is None:
142 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
143 | if tokenizer.unk_token is None:
144 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
145 | mem_token = "<landmark>"
146 | special_tokens_dict["additional_special_tokens"] = [mem_token]
147 |
148 | smart_tokenizer_and_embedding_resize(
149 | special_tokens_dict=special_tokens_dict,
150 | tokenizer=tokenizer,
151 | model=model,
152 | )
153 |
154 | mem_id = tokenizer.convert_tokens_to_ids(mem_token)
155 | model.set_mem_id(mem_id)
156 | rank = int(os.environ.get('RANK', -1))
157 | if rank > 0:
158 | barrier()
159 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
160 |
161 | dataset = dataset.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=32, remove_columns=["text", "meta"])
162 |
163 | dataset = dataset.map(
164 | partial(
165 | add_mem_tokens,
166 | mem_freq=50,
167 | mem_id=mem_id
168 | ), batched=False, num_proc=32)
169 | if rank == 0:
170 | barrier()
171 | print(dataset)
172 |
173 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
174 |
175 | trainer = TrainerCosine(
176 | model=model, tokenizer=tokenizer, args=training_args,
177 | train_dataset=dataset["train"],
178 | eval_dataset=None,
179 | data_collator=data_collator)
180 | trainer.train()
181 | trainer.save_state()
182 | trainer.save_model(output_dir=training_args.output_dir)
183 |
184 |
185 | if __name__ == "__main__":
186 | train()
187 |
--------------------------------------------------------------------------------
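The key preprocessing step above is `add_mem_tokens`, which splices the landmark token id into the stream after every `mem_freq` ordinary tokens (the trailing partial block gets none). A toy check with fake ids:

```python
def add_mem_tokens(input_ids, mem_freq, mem_id):
    # Same logic as the training script's add_mem_tokens, on a plain list.
    out, prev = [], 0
    for t in range(mem_freq, len(input_ids), mem_freq):
        out.extend(input_ids[prev:t])
        out.append(mem_id)
        prev = t
    out.extend(input_ids[prev:])
    return out

ids = list(range(1, 11))  # ten ordinary token ids: 1..10
print(add_mem_tokens(ids, mem_freq=4, mem_id=0))
# -> [1, 2, 3, 4, 0, 5, 6, 7, 8, 0, 9, 10]
```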
/llama_legacy/urls/arxiv.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_023827cd-7ee8-42e6-aa7b-661731f4c70f.jsonl
2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_024de5df-1b7f-447c-8c3a-51407d8d6732.jsonl
3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_03232e26-be3f-4a28-a5d2-ee1d8c0e9831.jsonl
4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_034e819a-cfcb-43c6-ad25-0232ad48823c.jsonl
5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_077ae8de-a68e-47e7-95a6-6d82f8f4eeb9.jsonl
6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0af50072-df4c-4084-a833-cebbd046e70e.jsonl
7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0de84cfc-c080-471f-b139-1bf061db4feb.jsonl
8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_0fbdd8ad-32d8-4228-9a40-e09dde689760.jsonl
9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_11c659c1-ffbf-4455-abfd-058f6bbf4bb2.jsonl
10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1958455d-6543-4307-a081-d86ce0637f9a.jsonl
11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1982fb29-c4ed-4dd3-855c-666e63bc62d9.jsonl
12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1caed86f-5625-4941-bdc1-cc57e4fec1cd.jsonl
13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_1d3a0cd6-f0e6-4106-a080-524a4bd50016.jsonl
14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29d54f5a-1dd0-4e9a-b783-fb2eec9db072.jsonl
15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_29fd3d99-53fb-43e2-a4a5-2fd01bf77258.jsonl
16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2b224cd9-286e-46ac-8c4e-c1e3befc8760.jsonl
17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2c131fca-2a05-4d5f-a805-59d2af3477e2.jsonl
18 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_2f28f1a7-6972-48ad-8997-65a5d52e4f1c.jsonl
19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_30440198-cd90-48c6-82c1-ea871b8c21c5.jsonl
20 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_39367d6c-d7d4-45fc-a929-8a17184d1744.jsonl
21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_393d19f2-1cd1-421f-be8a-78d955fdf602.jsonl
22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3a5d4f93-97ec-483a-88ef-324df9651b3f.jsonl
23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3c89ea11-69ff-4049-b775-f0c785997909.jsonl
24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3d5a011a-4bbe-4585-a2bd-ff3e943c8671.jsonl
25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f805f4b-6f7f-42a8-a006-47c1e0401bd7.jsonl
26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_3f9eb7ad-f266-4154-8d4d-54deeffde075.jsonl
27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_400748d3-0076-4a04-8a1c-6055ba0b5a2d.jsonl
28 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_44e19375-3995-4dff-a3b6-8a25247a165c.jsonl
29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4a8cf52f-81d0-4875-9528-466b1cbc71e1.jsonl
30 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_4cc7015c-c39a-4bf6-9686-c00b3343edd9.jsonl
31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_50757a42-079b-41ec-bcca-73759faffd62.jsonl
32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_575ae832-e770-4a89-bfa7-c56f16dbca69.jsonl
33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_580be642-bb73-4d0d-8b5e-f494722934cd.jsonl
34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5a02d9ee-12a0-437d-808f-d26f0eb2012b.jsonl
35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5d8d402b-8277-480a-b5fa-71169726864f.jsonl
36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_5ee33ef7-455e-4fd5-9512-c4771dd802c1.jsonl
37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_610c82ed-b9ee-449c-83b0-601205f3a74a.jsonl
38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_629fe3ca-075f-4663-9b81-b807f3b42bf2.jsonl
39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_64e5075e-e87e-4b2a-9e38-e5c102f6f2b1.jsonl
40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_65dd2ff6-dae3-4a60-90d3-c3d7349fc92f.jsonl
41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6719ecd2-fe34-4078-a584-320d921cbf6f.jsonl
42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_6938ee72-43ee-4ade-8840-151a402383b0.jsonl
43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_73241940-66c1-481c-b53a-f5e8b9afe9fa.jsonl
44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_751370b5-c7cb-44d8-a039-1468ee6747ab.jsonl
45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_75af5d17-5ebb-4460-9f2a-dc9fe880a936.jsonl
46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_79d50803-f7d9-4aa8-bf1a-d807980a40c6.jsonl
47 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7b26046f-7c8d-405b-911b-df51e1a069fa.jsonl
48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7d1d69dc-bc8e-4817-9cab-afdc002ab7c4.jsonl
49 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_7ea7a996-b1bb-4773-a36a-461dce2de861.jsonl
50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8232f276-9e3f-463a-9350-362de1b501d1.jsonl
51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8509f5a7-64a8-4813-92dc-f6eb53e3aacc.jsonl
52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_85b4c166-469d-449c-ab3d-5214c1d80246.jsonl
53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_872b620a-b4fd-45d3-92bc-ff0584447705.jsonl
54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_88f24f8d-16d3-4a21-894d-192033d0fa67.jsonl
55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8e6bd730-0f10-49d9-9b02-5ce16da47483.jsonl
56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8ede1b71-6846-439a-acba-86a57cfec3d2.jsonl
57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_8f74f6ba-1c53-42d5-a3c7-e4ef46a71133.jsonl
58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_90fa9c2b-25b0-47b7-af2b-a683356e543b.jsonl
59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_92ec488a-287d-4bf0-977b-6998cf0cf476.jsonl
60 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94a393a1-3b23-4961-a0a6-70bad5b4979c.jsonl
61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_94b9df70-a95f-4545-be3a-5a34f7b09fb3.jsonl
62 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_95ffc9e1-c505-4a3b-8fb0-cbc98b8703e1.jsonl
63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_98e718fd-5b0e-439f-a00c-57b61e06b395.jsonl
64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f50a028-2586-4e0d-bcfd-d9d2d74e8953.jsonl
65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_9f8d7d10-dda7-4e44-b00c-811635a199c8.jsonl
66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a055bd62-1ec2-47cf-bad2-321e3d4f053f.jsonl
67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a1e3430e-ef5c-4a86-914d-88e8fb7818c0.jsonl
68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a2b4cb3d-bea3-478e-82a2-77c00a827250.jsonl
69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a647274d-799d-4e7a-a485-b8632a87061e.jsonl
70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_a94ea420-99ae-4d58-9cdc-d4666e3322a7.jsonl
71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ab7c0034-7fc1-4fa8-bae3-e97b85fc16a4.jsonl
72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b0732cff-657e-4e69-87b8-66e8025bf441.jsonl
73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b1f3b912-a2ab-43bd-8811-43d84b422506.jsonl
74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_b6c2e4d3-d215-4d99-891f-d20b997d4d5a.jsonl
75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bbfabe6b-b9bc-476b-b8f0-7d6c47e9d2be.jsonl
76 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_bf11ef29-a3f9-4a2d-9dbf-ebcc56d39fdb.jsonl
77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c05da42b-4939-4f55-867c-16cf6d228e60.jsonl
78 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c1fc3dd5-861f-4b8d-b7a2-eb8f6887b33b.jsonl
79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c44164b4-0770-48d0-87db-590ca529032a.jsonl
80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6b43cda-ca5c-4855-9c08-0f8264cab1af.jsonl
81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_c6e57a16-4879-4dcf-b591-503cfb46a360.jsonl
82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_ca6e3842-6ca4-4230-84fa-376a3374c380.jsonl
83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_caf769e4-7308-4419-9114-900ca213682a.jsonl
84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d0da78e9-3dcf-4a46-85c1-f23ed00178bc.jsonl
85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d386838c-a51e-4839-9f27-8495b2466e49.jsonl
86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d41edf9f-0ebb-4866-b3fe-50785746b36b.jsonl
87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d6a7fc44-b584-4dd8-9de2-e981afe0bb4a.jsonl
88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_d7c49dbb-c008-47fc-9cbe-8d5695842d21.jsonl
89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db3b4be7-4e98-4fe9-96bf-05a5788815e3.jsonl
90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_db981f69-9eca-4031-8565-318b949efbfe.jsonl
91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_dbd4105f-7cbb-4483-a7b2-96b17b7fb594.jsonl
92 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de42e348-b333-4d35-b883-9bfc94f29822.jsonl
93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_de744938-fa6c-45dd-b600-428dd7c63a73.jsonl
94 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_e8f25867-697d-4f52-84e1-e50a95bc182b.jsonl
95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_eb4e26d4-6625-4f8a-b5fe-6f3a9b8a4b79.jsonl
96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f141b736-5ce4-4f18-bb29-704227ca4bd1.jsonl
97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f50efdc6-f88e-4fa6-9ef6-dd1d8314bb36.jsonl
98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_f7680c03-70df-4781-a98d-c88695f92f04.jsonl
99 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fbc62949-624d-4943-9731-f5c46242ba55.jsonl
100 | https://data.together.xyz/redpajama-data-1T/v1.0.0/arxiv/arxiv_fd572627-cce7-4667-a684-fef096dfbeb7.jsonl
--------------------------------------------------------------------------------
/llama_legacy/urls/book.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/book/book.jsonl
--------------------------------------------------------------------------------
/llama_legacy/urls/github.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_08cdfa755e6d4d89b673d5bd1acee5f6.sampled.jsonl
2 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f27d10d846a473b96070c3394832f32.sampled.jsonl
3 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_0f979046c8e64e0fb5843d2634a9957d.sampled.jsonl
4 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_10f129bfd0af45caa9cd72aa9d863ec5.sampled.jsonl
5 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_11a1943edfa349c7939382799599eed6.sampled.jsonl
6 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_17197bd2478044bebd9ff4634b6dfcee.sampled.jsonl
7 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_1d750c0ce39d40c6bc20bad9469e5a99.sampled.jsonl
8 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_21078cf63afb4d9eb4a7876f726a7226.sampled.jsonl
9 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_216883d3a669406699428bc485a4c228.sampled.jsonl
10 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_24278a707deb445b8e4f59c83dd67910.sampled.jsonl
11 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_25989b8233a04ac791b0eccd502e0c7a.sampled.jsonl
12 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_26a6fa61c5eb4bb885e7bc643e285f0e.sampled.jsonl
13 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_275100c1a44f4451b0343373ebc5637a.sampled.jsonl
14 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_27f05c041a1c401783f90b9415e40e4b.sampled.jsonl
15 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_28106e66bfd94978abbc15ec845aeddb.sampled.jsonl
16 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_2afd093fefad4c8da76cc539e8fb6137.sampled.jsonl
17 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_30d01c4edab64866bda8c609e90b4f4e.sampled.jsonl
18 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34b0785b77814b7583433ddb27a61ae0.sampled.jsonl
19 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_34e78793c1e94eeebd92852399097596.sampled.jsonl
20 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_366012588bef4d749bbbea76ae701141.sampled.jsonl
21 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_36c2a955ddd84672bbc778aa4ad2fbaf.sampled.jsonl
22 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3be6a5cc7428401393c23e5516a43537.sampled.jsonl
23 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3cc6cab9266746a6befa23648aa43119.sampled.jsonl
24 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_3f09f384f0734a4b912bb75db3f812bc.sampled.jsonl
25 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_478afe8aaccb43e6be2de7e34e041ef3.sampled.jsonl
26 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_483f11d6bc864f7fbfbe63bdf3583ce2.sampled.jsonl
27 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4b6883dc304c4e799620ec95b96dc91a.sampled.jsonl
28 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4bebabdbd8544da7a2071864ccf81f2e.sampled.jsonl
29 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4d44caf43c154ae4aaeab36eab0221c9.sampled.jsonl
30 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4f98dc136ba94eeaa1ba6c974814b33c.sampled.jsonl
31 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_4ff3604761614a3db550ef758e6457b5.sampled.jsonl
32 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_50423e84046948b4a2a70e7e4538e12d.sampled.jsonl
33 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_58fad5994b4446c6bceb33453484acb4.sampled.jsonl
34 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5abf005e4e634f1dbfa8bd20b5687092.sampled.jsonl
35 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_5f2b3517159b426bb1a9e81ca189abcd.sampled.jsonl
36 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_610e908cafaa4d53958de50ad700822a.sampled.jsonl
37 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6353ab14cb8f4623a7d75678d9e7f44e.sampled.jsonl
38 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_64747f25102740bab0ab54559569342a.sampled.jsonl
39 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6742e99802894114a3ba44841f49b168.sampled.jsonl
40 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_677e10f9b0af4c489e60670352e7e224.sampled.jsonl
41 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_68534e6a093744fa9f38fa1a9cf51232.sampled.jsonl
42 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_6adcf92fb2ee48059cb60579f2e931f7.sampled.jsonl
43 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_70f0aa43987643a7874286bca4faa66b.sampled.jsonl
44 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7462f1a34f594e9f8334d9a0cbbf80e7.sampled.jsonl
45 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78a13a25258f4d24923702c07445e20e.sampled.jsonl
46 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_78e3afc1cfee41fbb7eaae2e5bfaa17b.sampled.jsonl
47 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_791ed086a57f4879bb1596bed6d37bb3.sampled.jsonl
48 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7b5cc18857a34a0981b54f082de55cf8.sampled.jsonl
49 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7c54b908a2df4ec2ba316d2081fc674e.sampled.jsonl
50 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7edd0af0b61c426e93e8bd3f549a8f78.sampled.jsonl
51 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_7fa232e63a5d44a88a267181e9ac47b4.sampled.jsonl
52 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_81138c01f3a84f7fa8b56bf3d8fa35ce.sampled.jsonl
53 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_82a9a647b0eb4eb080d7ac15a13c765b.sampled.jsonl
54 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_835ac26c7a70447b97f4ec38bcb969ed.sampled.jsonl
55 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8bc3c4fae78c41d999b2ae6d97cce96c.sampled.jsonl
56 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8dc430f8f7114f018440c7b3d990e602.sampled.jsonl
57 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_8e1db2cf6c98420a88bad52fd57f4aa7.sampled.jsonl
58 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_919156910e704e6ca52e0d4880cdbb63.sampled.jsonl
59 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_944beac1347f491faa88f43c25d26fe4.sampled.jsonl
60 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_96175b6e4c764abfbbf14e78d4fd6464.sampled.jsonl
61 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_977d1fcdab92452d9dc38b2f4a99389b.sampled.jsonl
62 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_9eea501af8544d0b88f0b002850829d4.sampled.jsonl
63 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a26c51ffe2924fd8ad6694c6aa0eacc5.sampled.jsonl
64 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3a01a2f81ed4cd2afb30e175200b48f.sampled.jsonl
65 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a3d3e6f3f7d5495ca9ecf94d808fd350.sampled.jsonl
66 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a777da5620f1467f8df3616b17d533dc.sampled.jsonl
67 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_a8fcae75b0c3410faabcff02f0056a36.sampled.jsonl
68 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_aabc2d54d1c946908d3400228e0f238c.sampled.jsonl
69 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ad160286662849f49d1e6de27c0f1d15.sampled.jsonl
70 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_adda41d791974289aff042e2e3d07ec3.sampled.jsonl
71 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ae1813abc63f4b1998dfa608e7fe5588.sampled.jsonl
72 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_af7b386db97e4211a7378be08d7b3f4f.sampled.jsonl
73 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b15c575fa9f8465d98f70ba2f2f73c6e.sampled.jsonl
74 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b40763d0b9ce4e0d8fb5f519f1f49f8c.sampled.jsonl
75 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b56453718c5f46efa9c46feb194b0d6e.sampled.jsonl
76 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b6225e159b86432d9fa5bf226bb51393.sampled.jsonl
77 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_b821f640b8f14ed588bf48ae13f44098.sampled.jsonl
78 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_bfbcad9633f04601ba2f824d082eaacf.sampled.jsonl
79 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c6cfd16905814d7c955df1f4754a8b11.sampled.jsonl
80 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_c82c6775f0b74dbdae4524bb9aebf0ef.sampled.jsonl
81 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_ca73c69896b34adbbe468d78d9f134bc.sampled.jsonl
82 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cac864a472b948b0bfe18c8e9a19aeb5.sampled.jsonl
83 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_cb330e8dc8ac411eba2dc8676c9c4403.sampled.jsonl
84 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d0a054a678fc4c38b496d10e91a2c735.sampled.jsonl
85 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d148857715424aabbd32b8ffe56c4082.sampled.jsonl
86 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d33701435f964c90a86c22e204dd5fde.sampled.jsonl
87 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d448c2553f474193a1224df1c38f74d4.sampled.jsonl
88 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_d8d45c58819948c9b352d74383944c4a.sampled.jsonl
89 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_db611a692a704f6db3c18e77f79fd2f0.sampled.jsonl
90 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dcaa5c8b729b4fb599399dbf4557e43e.sampled.jsonl
91 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_dda863d833614d04b96bbe21b161768d.sampled.jsonl
92 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eabf427d56184fb89f9b5f27e73f7988.sampled.jsonl
93 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_eee6980922ca4a14b0c3341fa8a904d9.sampled.jsonl
94 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f478ff72b57f4f4283c22ac22ae84134.sampled.jsonl
95 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f931feb0e85940879d194c0e20d9e28a.sampled.jsonl
96 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_f9bbe6a065004a4c8e018c6ad63063b2.sampled.jsonl
97 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fbfc552e48164acda6605fa31fc2f563.sampled.jsonl
98 | https://data.together.xyz/redpajama-data-1T/v1.0.0/github/filtered_fc817e4c16494957bcb89b166e91434f.sampled.jsonl
--------------------------------------------------------------------------------
/llama_legacy/urls/stackexchange.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/stackexchange/stackexchange.jsonl
--------------------------------------------------------------------------------
/llama_legacy/urls/wikipedia.txt:
--------------------------------------------------------------------------------
1 | https://data.together.xyz/redpajama-data-1T/v1.0.0/wikipedia/wiki.jsonl
--------------------------------------------------------------------------------
/llama_legacy/weight_diff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This file has been changed by Amirkeivan Mohtashami
16 | # to take into account the new token in the embedding layer
17 |
18 | from typing import Optional
19 |
20 | import fire
21 | import torch
22 | import tqdm
23 | import transformers
24 | from train import smart_tokenizer_and_embedding_resize
25 |
26 |
27 | @torch.inference_mode()
28 | def make_diff(
29 | path_raw: str, path_tuned: str, path_diff: str, device="cpu", # "cuda" or "cpu"
30 | ):
31 | """Make the weight diff.
32 |
33 | This function is given to present full transparency of how the weight diff was created.
34 |
35 | Run:
36 | python weight_diff.py make_diff --path_raw <your_path_raw> --path_tuned <your_path_tuned> --path_diff <your_path_diff>
37 | """
38 | model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
39 | path_tuned,
40 | device_map={"": torch.device(device)},
41 | torch_dtype=torch.float32,
42 | low_cpu_mem_usage=True,
43 | )
44 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
45 | path_raw,
46 | device_map={"": torch.device(device)},
47 | torch_dtype=torch.float32,
48 | low_cpu_mem_usage=True,
49 | )
50 |
51 | tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
52 | path_tuned
53 | )
54 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
55 | path_raw
56 | )
57 | smart_tokenizer_and_embedding_resize(
58 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=[""]),
59 | model=model_raw,
60 | tokenizer=tokenizer_raw,
61 | )
62 |
63 | state_dict_tuned = model_tuned.state_dict()
64 | state_dict_raw = model_raw.state_dict()
65 | for key in tqdm.tqdm(state_dict_tuned):
66 | state_dict_tuned[key].add_(-state_dict_raw[key])
67 |
68 | model_tuned.save_pretrained(path_diff)
69 | tokenizer_tuned.save_pretrained(path_diff)
70 |
71 |
72 | @torch.inference_mode()
73 | def recover(
74 | path_raw,
75 | path_diff,
76 | path_tuned: Optional[str] = None,
77 | device="cpu",
78 | test_inference=True,
79 | check_integrity_naively=True,
80 | ):
81 | """Recover the original weights from the released weight diff.
82 |
83 | This function is given for you to run.
84 |
85 | Things to do before running this:
86 | 1. Convert Meta's released weights into huggingface format. Follow this guide:
87 | https://huggingface.co/docs/transformers/main/model_doc/llama
88 | 2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
89 | https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
90 | 3. Run this function with the correct paths. E.g.,
91 | python weight_diff.py recover --path_raw <your_path_raw> --path_diff <your_path_diff>
92 |
93 | Additional notes:
94 | - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
95 | - If you want to save the recovered weights, set `--path_tuned <your_path_tuned>`.
96 | Next time you can load the recovered weights directly from `<your_path_tuned>`.
97 | """
98 | model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
99 | path_raw,
100 | device_map={"": torch.device(device)},
101 | torch_dtype=torch.float32,
102 | low_cpu_mem_usage=True,
103 | )
104 | model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
105 | path_diff,
106 | device_map={"": torch.device(device)},
107 | torch_dtype=torch.float32,
108 | low_cpu_mem_usage=True,
109 | )
110 |
111 | tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
112 | path_raw
113 | )
114 | smart_tokenizer_and_embedding_resize(
115 | special_tokens_dict=dict(pad_token="[PAD]", additional_special_tokens=[""]),
116 | model=model_raw,
117 | tokenizer=tokenizer_raw,
118 | )
119 | tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
120 | path_diff
121 | )
122 |
123 | state_dict_recovered = model_recovered.state_dict()
124 | state_dict_raw = model_raw.state_dict()
125 | for key in tqdm.tqdm(state_dict_recovered):
126 | state_dict_recovered[key].add_(state_dict_raw[key])
127 |
128 | if check_integrity_naively:
129 | # This is not a rigorous, cryptographically strong integrity check :)
130 | allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered)
131 | assert torch.allclose(
132 | allsum, torch.full_like(allsum, fill_value=49798.7656), atol=1e-2, rtol=0
133 | ), "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted."
134 |
135 | if path_tuned is not None:
136 | model_recovered.save_pretrained(path_tuned)
137 | tokenizer_recovered.save_pretrained(path_tuned)
138 |
139 | return model_recovered, tokenizer_recovered
140 |
141 |
142 | def main(task, **kwargs):
143 | globals()[task](**kwargs)
144 |
145 |
146 | if __name__ == "__main__":
147 | fire.Fire(main)
148 |
--------------------------------------------------------------------------------
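A note on the arithmetic above: `make_diff` stores `tuned - raw` and `recover` adds `raw` back, both via in-place `add_` over the state dicts. A minimal sketch of that round trip on toy tensors (dummy data standing in for the real checkpoints):

    import torch

    tuned = {"w": torch.tensor([1.5, 2.0])}
    raw = {"w": torch.tensor([1.0, 1.0])}

    # make_diff: diff = tuned - raw, computed in place on a copy
    diff = {k: v.clone() for k, v in tuned.items()}
    for k in diff:
        diff[k].add_(-raw[k])

    # recover: recovered = diff + raw, again in place
    recovered = {k: v.clone() for k, v in diff.items()}
    for k in recovered:
        recovered[k].add_(raw[k])

    assert torch.allclose(recovered["w"], tuned["w"])
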
/lm_benchmark/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import rotary
16 |
17 | CONFIG_FORMAT_TO_MODULE_MAP = {
18 | "rotary": rotary,
19 | }
20 |
21 |
22 | def parse_args_with_format(format, base_parser, args, namespace):
23 | return CONFIG_FORMAT_TO_MODULE_MAP[format].parse_args(base_parser, args, namespace)
24 |
25 |
26 | def registered_formats():
27 | return CONFIG_FORMAT_TO_MODULE_MAP.keys()
28 |
--------------------------------------------------------------------------------
/lm_benchmark/config/rotary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import torch
17 | import distributed
18 | import models
19 |
20 | def none_or_str(value):
21 | if value == 'None':
22 | return None
23 | return value
24 |
25 | def none_or_int(value):
26 | if value == 'None':
27 | return None
28 | return int(value)
29 |
30 | def none_or_float(value):
31 | if value == 'None':
32 | return None
33 | return float(value)
34 |
35 | def parse_args(base_parser, args, namespace):
36 | parser = base_parser
37 | # General training params
38 | parser.add_argument('--batch_size', default=50, type=int)
39 | parser.add_argument('--acc_steps', default=4, type=int)
40 | parser.add_argument('--seed', default=2, type=int)
41 | parser.add_argument('--device', default='cuda:0', type=str)
42 | parser.add_argument('--iterations', default=15000, type=int)
43 | parser.add_argument('--lr', default=2e-3, type=float)
44 | parser.add_argument('--warmup_percent', default=0.02, type=float)
45 | parser.add_argument('--weight_decay', default=1e-3, type=float)
46 | parser.add_argument('--beta1', default=0.9, type=float)
47 | parser.add_argument('--beta2', default=0.95, type=float)
48 | parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none'])
49 | parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'adafactor'])
50 | parser.add_argument('--eval_freq', default=200, type=int) # in iterations
51 | parser.add_argument('--results_base_folder', default="./exps", type=str)
52 | parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
53 |
54 | # Dataset params
55 | parser.add_argument('--dataset', choices=['pg19', 'arxivmath'])
56 | parser.add_argument('--vocab_size', default=50304, type=int)
57 | parser.add_argument('--mem_freq', default=50, type=none_or_int, required=False, help="Frequency of landmark tokens")
58 |
59 | # Model params
60 | parser.add_argument('--model', default='base_rotary', choices=models.registered_models())
61 | parser.add_argument('--dropout', default=0.0, type=float)
62 | parser.add_argument('--group_dropout', default=None, type=float, required=False)
63 | parser.add_argument('--n_head', default=8, type=int)
64 | parser.add_argument('--n_layer', default=12, type=int) # depths in att + ff blocks
65 | parser.add_argument('--n_embd', default=1024, type=int) # embedding size / hidden size ...
66 | parser.add_argument('--sequence_length', default=512, type=int)
67 | parser.add_argument('--dtype', default="torch.bfloat16", type=str)
68 | parser.add_argument('--bias', default=False, type=bool)
69 | parser.add_argument('--no_compile', action='store_true') # if true then model is not compiled
70 | parser.add_argument('--run_prefix', default=None, type=str, required=False) # is added before the autogenerated experiment name
71 | parser.add_argument('--exp_name', default=None, type=str, required=False) # overrides the autogenerated experiment name
72 | parser.add_argument('--softmax_func', default="mem_opt", type=str, required=False,
73 | choices=["mem_opt", "nomem", "mem", "ignore_mem"]) # distributed backend type
74 | parser.add_argument('--positional_encoder', default="rotary", type=str, required=False,
75 | choices=models.positional_encoders.registered_encoders()) # positional encoder type
76 | # logging params (WandB)
77 | parser.add_argument('--wandb', action='store_true') # whether to use wandb or not
78 | parser.add_argument('--wandb_project', default="my-project", type=str)
79 | # Distributed args
80 | parser.add_argument('--distributed_backend', default=None, type=none_or_str, required=False,
81 | choices=distributed.registered_backends()) # distributed backend type
82 | # Landmark tokens
83 | parser.add_argument('--max_groups_for_softmax', default=16, type=int, required=False, help="Should be at least 2 + max. number of landmark tokens in one chunk.")
84 | # Inference
85 | parser.add_argument('--use_cache', action='store_true')
86 | parser.add_argument('--lm_cache', default="none", type=str, required=False,
87 | choices=models.caches.registered_caches())
88 | parser.add_argument('--mem_cache_size', default=None, type=int, required=False)
89 | parser.add_argument('--mem_cache_freq', default=None, type=int, required=False, help="Frequency to add landmark tokens in the input (block size at inference)")
90 | parser.add_argument('--cache_topk', default=1, type=int, required=False)
91 | parser.add_argument('--cache_selection_method', default="per_token_and_head", type=str, required=False,)
92 | parser.add_argument('--eval_seq_length', default=512, type=int, required=False, help="Evaluation Length")
93 | parser.add_argument('--eval_sample_size', default=None, type=none_or_int, required=False, help="Size of the random subset of validation set used for evaluation")
94 | parser.add_argument('--mid_length', default=250, type=int, required=False, help="Size of chunks to break the input into")
95 | parser.add_argument('--allow_cache_during_training', action='store_true')
96 | parser.add_argument('--postpone_lm_cache', action='store_true')
97 | parser.add_argument('--optimization_process', default="landmark", type=str, required=False,
98 | choices=["transformer_xl", "landmark"]) # distributed backend type
99 |
100 | # CMT Token
101 | parser.add_argument('--under_rem_score_prob', default=0., type=none_or_float, required=False)
102 | parser.add_argument('--rem_cutoff', default=None, type=none_or_float, required=False)
103 | parser.add_argument('--enable_rem_score', default=False, action='store_true', required=False)
104 |
105 | # Positional Augmentation
106 | parser.add_argument('--pos_jump_on_mem', default=None, type=none_or_int, required=False)
107 |
108 | # Transformer XL
109 | parser.add_argument('--total_sequence_length', default=None, type=int, required=False)
110 |
111 |
112 | args = parser.parse_args(args, namespace)
113 |
114 | if args.exp_name is None:
115 | special_name_handle_fields = {"model", "lr", "batch_size",
116 | "acc_steps", "seed", "exp_name",
117 | "wandb", "wandb_project",
118 | "run_prefix", "distributed_backend", "config_format",
119 | "sequence_length", "mem_freq"}
120 | overridden_values = []
121 | for key in vars(args):
122 | if key in special_name_handle_fields:
123 | continue
124 | if getattr(args, key) != parser.get_default(key):
125 | overridden_values.append((key, getattr(args, key)))
126 | chunk_len = 10
127 | overridden_values_str_parts = []
128 | for chunk_id in range(0, len(overridden_values), chunk_len):
129 | overridden_values_str = "_".join(["{}={}".format(key, value) for key, value in overridden_values[chunk_id:chunk_id+chunk_len]])
130 | overridden_values_str_parts.append(overridden_values_str)
131 | overridden_values_str = "/".join(overridden_values_str_parts)
132 | exp_name = ""
133 | if args.run_prefix is not None:
134 | exp_name += f"{args.run_prefix}_"
135 | exp_name += f"{args.model}_lr{args.lr}_memfreq{args.mem_freq}_bs{args.batch_size}x{args.acc_steps}_seqlen{args.sequence_length}/{overridden_values_str}_seed={args.seed}"
136 | args.exp_name = exp_name
137 |
138 | args.landmark_id = 50260
139 | if args.dtype == "torch.bfloat16":
140 | args.dtype = torch.bfloat16
141 | elif args.dtype == "torch.float16":
142 | args.dtype = torch.float16
143 |
144 | landmark_freq = max(args.mem_cache_freq or 0, args.mem_freq or 0)
145 | if landmark_freq != 0 and args.max_groups_for_softmax < args.sequence_length // landmark_freq + 1 + 2:
146 | print("CRITICAL WARNING: Maximum number of groups for softmax is too low. Adjust with --max_groups_for_softmax.")
147 |
148 |
149 | return args
150 |
--------------------------------------------------------------------------------
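To make the final warning concrete: a sequence of length `sequence_length` with a landmark every `landmark_freq` tokens contains `sequence_length // landmark_freq` landmarks, and the grouped softmax needs roughly one group per landmark plus a few extra. A quick check with the parser defaults (a sketch of the same arithmetic, not additional validation logic):

    sequence_length, mem_freq, max_groups_for_softmax = 512, 50, 16
    needed = sequence_length // mem_freq + 1 + 2   # 10 landmarks + current block + slack
    assert max_groups_for_softmax >= needed        # 16 >= 13, so no warning is printed
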
/lm_benchmark/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import pg19, arxiv_math
16 |
17 | PREPARE_GET_DATASET_MAP = {
18 | "pg19": (pg19.prepare_pg19_data, pg19.get_pg19_data),
19 | "arxivmath": (arxiv_math.prepare_arxivmath_data, arxiv_math.get_arxivmath_data)
20 | }
21 |
22 |
23 | def prepare_dataset(args):
24 | """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
25 | contained in its own Python file. The expected format at the moment is a dictionary of np.memmap
26 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """
27 | return PREPARE_GET_DATASET_MAP[args.dataset][0](args)
28 |
29 | def get_dataset(args):
30 | """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
31 | contained in its own Python file. The expected format at the moment is a dictionary of np.memmap
32 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """
33 | return PREPARE_GET_DATASET_MAP[args.dataset][1](args)
34 |
--------------------------------------------------------------------------------
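A sketch of how the two entry points fit together in a driver script, mirroring main.py and eval.py. The `args` object is a hypothetical stand-in for the parsed config (only the fields the data layer reads are set), and the raw pg19 bins are assumed to already exist from data/pg19/prepare.py:

    from data import get_dataset, prepare_dataset   # as imported in eval.py
    import types

    args = types.SimpleNamespace(dataset="pg19", mem_freq=50, landmark_id=50260)
    prepare_dataset(args)       # writes train.bin / val.bin once, if missing
    data = get_dataset(args)    # {'train': np.memmap, 'val': np.memmap}
    print(len(data['train']), len(data['val']))
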
/lm_benchmark/data/arxiv_math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import zipfile
17 | import urllib
18 | import numpy as np
19 | import tiktoken
20 | import torch
21 | import regex
22 | import multiprocessing
23 | import itertools
24 | import functools
25 |
26 | from .utils import add_mem_tokens
27 |
28 | ARXIVMATH_ORIGINAL_PATH = "./data/proof-pile/"
29 |
30 |
31 | def get_path(config):
32 | dataset_name = f"arxiv_mem={config.mem_freq}"
33 | return os.path.join(os.path.dirname(__file__), f"datasets/{dataset_name}/")
34 |
35 | def prepare_arxivmath_data(config):
36 | DATA_PATH = get_path(config)
37 | print(DATA_PATH)
38 | os.makedirs(DATA_PATH, exist_ok=True)
39 | if not os.path.exists(os.path.join(DATA_PATH, 'train.bin')):
40 | train_data = np.memmap(os.path.join(ARXIVMATH_ORIGINAL_PATH, 'train.bin'), dtype=np.uint16, mode='r')
41 | raw_tokenized_train = add_mem_tokens(config.landmark_id, train_data, config.mem_freq)
42 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16)
43 | train_tokenized.tofile(os.path.join(DATA_PATH, 'train.bin'))
44 |
45 | if not os.path.exists(os.path.join(DATA_PATH, 'val.bin')):
46 | val_data = np.memmap(os.path.join(ARXIVMATH_ORIGINAL_PATH, 'validation.bin'), dtype=np.uint16, mode='r')
47 | raw_tokenized_eval = add_mem_tokens(config.landmark_id, val_data, config.mem_freq)
48 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16)
49 | eval_tokenized.tofile(os.path.join(DATA_PATH, 'val.bin'))
50 | print("completed the tokenization process!")
51 |
52 |
53 | def get_arxivmath_data(config):
54 | DATA_PATH = get_path(config)
55 |
56 | train_data = np.memmap(os.path.join(DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
57 | val_data = np.memmap(os.path.join(DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
58 |
59 | return {'train': train_data, 'val': val_data}
60 |
--------------------------------------------------------------------------------
/lm_benchmark/data/pg19.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import zipfile
17 | import urllib
18 | import numpy as np
19 | import tiktoken
20 | import torch
21 | import regex
22 | import multiprocessing
23 | import itertools
24 | import functools
25 |
26 | from .utils import add_mem_tokens
27 |
28 |
29 | PG19_ORIGINAL_PATH = "./data/pg19"
30 |
31 |
32 | def get_path(config):
33 | dataset_name = f"pg19_mem={config.mem_freq}"
34 | return os.path.join(os.path.dirname(__file__), f"datasets/{dataset_name}/")
35 |
36 | def prepare_pg19_data(config):
37 | DATA_PATH = get_path(config)
38 | print(DATA_PATH)
39 | os.makedirs(DATA_PATH, exist_ok=True)
40 | if not os.path.exists(os.path.join(DATA_PATH, 'train.bin')):
41 | train_data = np.memmap(os.path.join(PG19_ORIGINAL_PATH, 'train.bin'), dtype=np.uint16, mode='r')
42 | raw_tokenized_train = add_mem_tokens(config.landmark_id, train_data, config.mem_freq)
43 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16)
44 | train_tokenized.tofile(os.path.join(DATA_PATH, 'train.bin'))
45 |
46 | if not os.path.exists(os.path.join(DATA_PATH, 'val.bin')):
47 | val_data = np.memmap(os.path.join(PG19_ORIGINAL_PATH, 'validation.bin'), dtype=np.uint16, mode='r')
48 | raw_tokenized_eval = add_mem_tokens(config.landmark_id, val_data, config.mem_freq)
49 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16)
50 | eval_tokenized.tofile(os.path.join(DATA_PATH, 'val.bin'))
51 | print("completed the tokenization process!")
52 |
53 |
54 | def get_pg19_data(config):
55 | DATA_PATH = get_path(config)
56 |
57 | train_data = np.memmap(os.path.join(DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
58 | val_data = np.memmap(os.path.join(DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
59 |
60 | return {'train': train_data, 'val': val_data}
61 |
--------------------------------------------------------------------------------
/lm_benchmark/data/pg19/README.md:
--------------------------------------------------------------------------------
1 | Download the dataset from https://github.com/deepmind/pg19 into this folder. Then run prepare.py.
2 |
--------------------------------------------------------------------------------
/lm_benchmark/data/pg19/prepare.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import tiktoken
17 |
18 | import numpy as np
19 |
20 |
21 |
22 | gpt2_tokenizer = tiktoken.get_encoding("gpt2")
23 | def _read_directory(path):
24 | texts = []
25 | for filename in os.listdir(path):
26 | if filename.endswith(".txt") and filename[:-4].isnumeric():
27 | print(filename)
28 | with open(os.path.join(path, filename), 'r') as f:
29 | texts += gpt2_tokenizer.encode_ordinary(f.read())
30 | texts.append(gpt2_tokenizer.eot_token)
31 | return np.array(texts, dtype=np.uint16)
32 |
33 |
34 | raw_eval_data = _read_directory("validation")
35 | raw_eval_data.tofile('validation.bin')
36 | raw_train_data = _read_directory("train")
37 | raw_train_data.tofile('train.bin')
38 |
--------------------------------------------------------------------------------
/lm_benchmark/data/proof-pile/prepare.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | import os
17 |
18 | from datasets import load_dataset
19 | import numpy as np
20 | import tiktoken
21 | from tqdm import tqdm
22 |
23 |
24 | dir_path = os.path.dirname(os.path.realpath(__file__))
25 |
26 |
27 | dataset = load_dataset("hoskinson-center/proof-pile", cache_dir=os.path.join(dir_path, "cache"))
28 |
29 |
30 | num_proc = 16
31 | arxiv = dataset.filter(lambda x: json.loads(x['meta']).get('config', None) == "arxiv", num_proc=num_proc)
32 |
33 |
34 | enc = tiktoken.get_encoding("gpt2")
35 | def process(example):
36 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
37 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
38 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
39 | out = {'ids': ids, 'len': len(ids)}
40 | return out
41 |
42 | # tokenize the dataset
43 | tokenized = arxiv.map(
44 | process,
45 | remove_columns=['text'],
46 | desc="tokenizing the splits",
47 | num_proc=num_proc,
48 | )
49 |
50 |
51 | for split, dset in tokenized.items():
52 | arr_len = np.sum(dset['len'])
53 | filename = os.path.join(dir_path, f'{split}.bin')
54 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
55 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
56 | total_batches = 1024
57 |
58 | idx = 0
59 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
60 | # Batch together samples for faster write
61 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
62 | arr_batch = np.concatenate(batch['ids'])
63 | # Write into mmap
64 | arr[idx : idx + len(arr_batch)] = arr_batch
65 | idx += len(arr_batch)
66 | arr.flush()
67 |
--------------------------------------------------------------------------------
/lm_benchmark/data/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 | import multiprocessing
18 | import itertools
19 | import functools
20 |
21 |
22 | def apply_add_mem_tokens(mem_id, tokens_filename, freq, start_idx, end_idx):
23 | tokens = np.memmap(tokens_filename, dtype=np.uint16, mode='r')
24 | print(f"Processing {start_idx}-{end_idx}")
25 | tokens_with_mem = []
26 | for t_idx in range(start_idx, end_idx):
27 | t = tokens[t_idx]
28 | tokens_with_mem.append(t)
29 | if freq is not None and t_idx % freq == freq - 1:
30 | tokens_with_mem.append(mem_id)
31 | return tokens_with_mem
32 |
33 | def add_mem_tokens(mem_id, tokens, freq, n_workers=32):
34 | print(len(tokens))
35 | with multiprocessing.Pool(n_workers) as pool:
36 | ids = list(range(0, len(tokens), 10 * 1000 * 1000))
37 | pair_ids = zip(ids, ids[1:] + [len(tokens)])
38 | apply = functools.partial(apply_add_mem_tokens, mem_id, tokens.filename, freq)
39 | return list(itertools.chain.from_iterable(pool.starmap(apply, pair_ids)))
40 |
--------------------------------------------------------------------------------
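The insertion rule above appends the landmark id after every `freq`-th token (positions `freq-1, 2*freq-1, ...` in the original stream); the multiprocessing is only there to chunk a large memmap. A single-process toy illustration of the same rule:

    def insert_landmarks(tokens, mem_id, freq):
        out = []
        for i, t in enumerate(tokens):
            out.append(t)
            if freq is not None and i % freq == freq - 1:
                out.append(mem_id)
        return out

    # a landmark (99) lands after every 2nd token
    assert insert_landmarks([10, 11, 12, 13, 14], mem_id=99, freq=2) == [10, 11, 99, 12, 13, 99, 14]
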
/lm_benchmark/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from . import ddp
17 | from . import single
18 |
19 | BACKEND_TYPE_TO_MODULE_MAP = {
20 | "nccl": ddp.DataParallelDistributedBackend,
21 | None: single.SingleNodeBackend,
22 | }
23 |
24 |
25 | def make_backend_from_args(args):
26 | return BACKEND_TYPE_TO_MODULE_MAP[args.distributed_backend](args)
27 |
28 |
29 | def registered_backends():
30 | return BACKEND_TYPE_TO_MODULE_MAP.keys()
31 |
--------------------------------------------------------------------------------
/lm_benchmark/distributed/backend.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List
17 |
18 |
19 | class DistributedBackend(object):
20 |
21 | def __init__(self, args):
22 | pass
23 |
24 | def transform_model(self, model):
25 | raise NotImplementedError
26 |
27 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps):
28 | raise NotImplementedError
29 |
30 | def is_master_process(self) -> bool:
31 | raise NotImplementedError
32 |
33 | def get_adjusted_args_for_process(self, args):
34 | raise NotImplementedError
35 |
36 | def get_raw_model(self, model):
37 | raise NotImplementedError
38 |
39 | def translate_model_parameter_name_for_node(self, parameter_name) -> List[str]:
40 | raise NotImplementedError
41 |
42 | def get_world_size(self):
43 | raise NotImplementedError
44 |
45 | def sync(self):
46 | raise NotImplementedError
47 |
48 | def finalize(self):
49 | pass
50 |
--------------------------------------------------------------------------------
/lm_benchmark/distributed/ddp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import math
17 | from contextlib import contextmanager
18 |
19 | from torch.nn.parallel import DistributedDataParallel as DDP
20 | from torch.distributed import init_process_group, destroy_process_group, get_world_size, barrier
21 |
22 | from .backend import DistributedBackend
23 |
24 |
25 | class DataParallelDistributedBackend(DistributedBackend):
26 |
27 | def __init__(self, args):
28 | self.rank = int(os.environ.get('RANK', -1))
29 | assert self.rank != -1, "DDP backend can not be used without rank"
30 | assert "cuda" in args.device, "DDP backend can not be used on non-CUDA devices"
31 | init_process_group(backend=args.distributed_backend)
32 | self.local_rank = int(os.environ['LOCAL_RANK'])
33 |
34 | def get_adjusted_args_for_process(self, args):
35 | effective_batch_size = args.batch_size * args.acc_steps
36 | world_size = self.get_world_size()
37 | if effective_batch_size % world_size != 0:
38 | raise ValueError(f"Effective batch size "
39 | "{effective_batch_size} is not divisible "
40 | "by the world size {world_size}.")
41 | acc_steps_div = math.gcd(args.acc_steps, world_size)
42 | args.acc_steps = args.acc_steps // acc_steps_div
43 | args.batch_size = args.batch_size // (world_size // acc_steps_div)
44 | args.device = f'cuda:{self.local_rank}'
45 | args.seed = args.seed + self.local_rank
46 | return args
47 |
48 | def transform_model(self, model):
49 | return DDP(model, device_ids=[self.local_rank], find_unused_parameters=True)
50 |
51 | @contextmanager
52 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps):
53 | model.require_backward_grad_sync = (
54 | microstep_idx == gradient_accumulation_steps - 1)
55 | yield
56 |
57 | def is_master_process(self) -> bool:
58 | return self.rank == 0
59 |
60 | def get_raw_model(self, model):
61 | return model.module
62 |
63 | def translate_model_parameter_name_for_node(self, parameter_name):
64 | return [f'module.{parameter_name}']
65 |
66 | def get_world_size(self):
67 | return get_world_size()
68 |
69 | def sync(self):
70 | barrier()
71 |
72 | def finalize(self):
73 | destroy_process_group()
74 |
--------------------------------------------------------------------------------
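`get_adjusted_args_for_process` splits the effective batch across ranks while keeping its total size fixed: gradient-accumulation steps are divided down first, and whatever factor of the world size remains is taken out of the per-step batch. A worked example with the config defaults (`batch_size=50`, `acc_steps=4`) on 2 GPUs:

    import math

    batch_size, acc_steps, world_size = 50, 4, 2
    acc_steps_div = math.gcd(acc_steps, world_size)   # 2
    acc_steps //= acc_steps_div                       # 4 -> 2 accumulation steps per rank
    batch_size //= world_size // acc_steps_div        # 50 -> 50 (remaining factor is 1)

    # each of the 2 ranks runs 2 steps of 50 samples: 2 * 2 * 50 == 50 * 4
    assert world_size * acc_steps * batch_size == 50 * 4
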
/lm_benchmark/distributed/single.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from contextlib import nullcontext
16 |
17 | from .backend import DistributedBackend
18 |
19 |
20 | class SingleNodeBackend(DistributedBackend):
21 |
22 | def transform_model(self, model):
23 | return model
24 |
25 | def get_context_for_microstep_forward(self, *args, **kwargs):
26 | return nullcontext()
27 |
28 | def get_adjusted_args_for_process(self, args):
29 | return args
30 |
31 | def is_master_process(self) -> bool:
32 | return True
33 |
34 | def get_raw_model(self, model):
35 | return model
36 |
37 | def get_world_size(self):
38 | return 1
39 |
40 | def sync(self):
41 | pass
42 |
43 | def translate_model_parameter_name_for_node(self, parameter_name):
44 | return [parameter_name]
45 |
--------------------------------------------------------------------------------
/lm_benchmark/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import sys
17 | import numpy as np
18 | import torch
19 | import inspect
20 | import json
21 | import copy
22 | import argparse
23 | import random
24 | import wandb
25 | import logging
26 | from contextlib import nullcontext
27 | from tqdm import tqdm
28 |
29 | import config
30 | import models
31 | from data import get_dataset, prepare_dataset
32 | from optim.base import train_base
33 | import distributed
34 | from optim.utils import get_batch
35 |
36 |
37 |
38 | def get_args():
39 | parser = argparse.ArgumentParser(allow_abbrev=False)
40 | parser.add_argument('--checkpoint', type=str, required=True)
41 |
42 | args, rem_args = parser.parse_known_args()
43 |
44 | if os.path.isfile(args.checkpoint):
45 | args.checkpoint, args.checkpoint_filename = os.path.split(args.checkpoint)
46 | else:
47 | args.checkpoint_filename = "ckpt.pt"
48 |
49 | with open(os.path.join(args.checkpoint, "summary.json")) as f:
50 | summary = json.load(f)
51 |
52 | for k, v in summary['args'].items():
53 | if k not in ["device", "dtype"]:
54 | setattr(args, k, v)
55 |
56 | return config.parse_args_with_format(format=args.config_format, base_parser=argparse.ArgumentParser(allow_abbrev=False), args=rem_args, namespace=args)
57 |
58 |
59 | def get_as_batch(data, seq_length, batch_size, device='cpu', sample_size=None):
60 | all_ix = list(range(0, len(data), seq_length))
61 | assert all_ix[-1] + seq_length + 1 > len(data)
62 | all_ix.pop()
63 | if sample_size is not None:
64 | all_ix = np.random.choice(all_ix, size=sample_size // seq_length, replace=False).tolist()
65 |
66 | idx = 0
67 | for idx in range(0, len(all_ix), batch_size):
68 | ix = all_ix[idx:idx+batch_size]
69 | assert all([idx + seq_length + 1 <= len(data) for idx in ix])
70 | x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix])
71 | y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix])
72 | if device != 'cpu':
73 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
74 | yield x, y
75 |
76 | def iceildiv(x, y):
77 | return (x + y - 1) // y
78 |
79 | def evaluate(model, data, iterations, acc_steps, batch_size, sequence_length, distributed_backend, extra_args):
80 | device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
81 | type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
82 | device_type=device_type, dtype=extra_args.dtype)
83 | itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible
84 |
85 | stats = {}
86 |
87 | num_substeps_per_epoch = len(data['val']) // (batch_size * sequence_length)
88 |
89 | if not extra_args.no_compile:
90 | print(f"Compiling model ...")
91 | import torch._dynamo as torchdynamo
92 | torchdynamo.config.guard_nn_modules = True
93 | # torchdynamo.config.log_level = logging.DEBUG
94 | model = torch.compile(model) # requires pytorch 2.0+
95 |
96 | model.eval()
97 |
98 | loss_list_val, acc_list = [], []
99 | loss_step_list_val = []
100 |
101 | max_num_batches = 400
102 | with torch.no_grad():
103 | mid_length = extra_args.mid_length
104 | print(f"Sending sub-sequences of length at most {mid_length}")
105 | seq_length = extra_args.eval_seq_length
106 | print(f"Using seq length {seq_length}")
107 | torch.set_printoptions(sci_mode=False)
108 | for idx, (x, y) in tqdm(
109 | enumerate(
110 | get_as_batch(
111 | data['val'],
112 | seq_length,
113 | batch_size,
114 | device=extra_args.device,
115 | sample_size=extra_args.eval_sample_size
116 | )
117 | ),
118 | total=iceildiv(
119 | extra_args.eval_sample_size // seq_length if extra_args.eval_sample_size is not None else
120 | iceildiv(len(data['val']), seq_length),
121 | batch_size
122 | )
123 | ):
124 | val_loss = 0.
125 | acc = 0.
126 | cnt = 0
127 | model.clear_state()
128 | for part_idx, i in enumerate(range(0, x.shape[1], mid_length)):
129 | part_len = x[:, i:i + mid_length].shape[1]
130 | with type_ctx:
131 | outputs = model(x[:, i:i + mid_length], targets=y[:, i:i+mid_length].contiguous(), get_logits=True, use_cache=extra_args.use_cache)
132 | val_loss = outputs['loss'] * part_len + val_loss
133 | acc = ((outputs['logits'].argmax(-1) == y[:, i:i+mid_length]).float().sum()) + acc
134 | cnt += part_len
135 | while len(loss_step_list_val) <= part_idx:
136 | loss_step_list_val.append([])
137 | loss_step_list_val[part_idx].append(outputs['loss'].item())
138 | val_loss /= cnt
139 | acc /= cnt
140 |
141 | loss_list_val.append(val_loss.item())
142 | acc_list.append(acc.item())
143 |
144 |
145 | stats['val_acc'] = torch.as_tensor(acc_list).mean().item()
146 | stats['val_loss'] = torch.as_tensor(loss_list_val).mean().item()
147 | stats['val_perplexity'] = np.exp(stats['val_loss'])
148 | stats['val_perplexity_per_chunk'] = torch.exp(torch.as_tensor(loss_step_list_val).mean(dim=1))
149 |
150 | return stats
151 |
152 | def main(args):
153 |
154 |
155 | torch.backends.cuda.matmul.allow_tf32 = True # allows us to make sure we're able to use tensorfloat32 during training
156 | torch.backends.cudnn.allow_tf32 = True
157 |
158 | distributed_backend = distributed.make_backend_from_args(args)
159 | args = distributed_backend.get_adjusted_args_for_process(args)
160 |
161 | args.device = torch.device(args.device)
162 | torch.cuda.set_device(args.device)
163 | device_type = 'cuda' if 'cuda' in str(args.device) else 'cpu'
164 |
165 | torch.manual_seed(args.seed)
166 | random.seed(args.seed)
167 | np.random.seed(args.seed)
168 |
169 | print(f"Loading dataset '{args.dataset}'")
170 |
171 | if distributed_backend.is_master_process():
172 | prepare_dataset(args)
173 | distributed_backend.sync()
174 |
175 | data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized}
176 |
177 | print(f"Num training tokens: {len(data['train'])}")
178 | print(f"Num validation tokens: {len(data['val'])}")
179 |
180 | model = models.make_model_from_args(args).to(args.device)
181 |
182 | checkpoint = torch.load(os.path.join(args.checkpoint, args.checkpoint_filename))
183 | model.load_state_dict({x: y for x, y in checkpoint['model'].items() if "attn.bias" not in x and "wpe" not in x}, strict=False)
184 |
185 | model = distributed_backend.transform_model(model)
186 |
187 | print(f"\Evaluating model={args.model} \n{vars(args)}\n")
188 |
189 | stats = evaluate(model, data, args.iterations, args.acc_steps, args.batch_size, args.sequence_length,
190 | distributed_backend=distributed_backend,
191 | extra_args=args)
192 |
193 | print(stats)
194 |
195 | distributed_backend.finalize()
196 |
197 |
198 | if __name__ == "__main__":
199 | args = get_args()
200 | main(args)
201 |
--------------------------------------------------------------------------------
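Since `get_args` restores most settings from the checkpoint's `summary.json`, a typical invocation only overrides the inference-time flags. A hypothetical example (paths and values are placeholders; every flag is defined in config/rotary.py):

    python eval.py \
        --checkpoint ./exps/pg19_landmark \
        --distributed_backend None \
        --lm_cache mem --mem_cache_size 40 --mem_cache_freq 50 \
        --mem_freq None --eval_seq_length 2048 --cache_topk 2 \
        --mid_length 250 --batch_size 16 --no_compile --use_cache
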
/lm_benchmark/eval_cmd_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import dataclasses
16 |
17 | @dataclasses.dataclass
18 | class Setting(object):
19 | exp_dir: str
20 | eval_len: int
21 | topk: int
22 | mem_size: int
23 | mid_length: int
24 | use_cache: bool = True
25 | selection_method: str = "per_token_and_head"
26 | mem_cache_freq: int = 50
27 | eval_sample_size: int = 4000000
28 | lm_cache: str = "mem"
29 |
30 |
31 | exp_dirs = {
32 | "arxiv_landmark": "./exps/arxiv_landmark",
33 | "arxiv_baseline": "./exps/arxiv_baseline",
34 |
35 | "pg19_landmark": "./exps/pg19_landmark",
36 | "pg19_baseline": "./exps/pg19_baseline",
37 | "pg19_xl": "./exps/pg19_xl",
38 | }
39 | settings = [
40 | dict(exp_dir=exp_dirs["pg19_baseline"], eval_len=360, mid_length=360,
41 | lm_cache="none", mem_cache_freq=None, mem_size=None, topk=None, use_cache=False,eval_sample_size=None),
42 | dict(exp_dir=exp_dirs["pg19_baseline"], eval_len=512, mid_length=512,
43 | lm_cache="none", mem_cache_freq=None, mem_size=None, topk=None, use_cache=False,eval_sample_size=None),
44 |
45 | dict(exp_dir=exp_dirs["pg19_xl"], eval_len=2048, mid_length=256,
46 | lm_cache="kv", mem_cache_freq=None, mem_size=256, topk=None,eval_sample_size=None),
47 | dict(exp_dir=exp_dirs["pg19_xl"], eval_len=4096, mid_length=256,
48 | lm_cache="kv", mem_cache_freq=None, mem_size=256, topk=None,eval_sample_size=None),
49 |
50 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=512, mid_length=250, mem_size=10, topk=2,eval_sample_size=None),
51 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2,eval_sample_size=None),
52 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=350, mem_size=40, topk=2,eval_sample_size=None),
53 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=300, mem_size=40, topk=3,eval_sample_size=None),
54 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=20, topk=4,eval_sample_size=None),
55 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4,eval_sample_size=None),
56 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=40, topk=4,eval_sample_size=None),
57 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=2,eval_sample_size=None),
58 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4,eval_sample_size=None),
59 |
60 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2,eval_sample_size=None),
61 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4,eval_sample_size=None),
62 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=40, topk=4,eval_sample_size=None),
63 |
64 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2,
65 | selection_method="max_over_heads",eval_sample_size=None),
66 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4,
67 | selection_method="max_over_heads",eval_sample_size=None),
68 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4,
69 | selection_method="max_over_heads",eval_sample_size=None),
70 |
71 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2,
72 | selection_method="max_over_tokens",eval_sample_size=None),
73 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4,
74 | selection_method="max_over_tokens",eval_sample_size=None),
75 | dict(exp_dir=exp_dirs["pg19_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4,
76 | selection_method="max_over_tokens",eval_sample_size=None),
77 |
78 | dict(exp_dir=exp_dirs["arxiv_baseline"], eval_len=360, mid_length=360,
79 | lm_cache="none", mem_cache_freq=None, mem_size=None, topk=None, use_cache=False),
80 | dict(exp_dir=exp_dirs["arxiv_baseline"], eval_len=512, mid_length=512,
81 | lm_cache="none", mem_cache_freq=None, mem_size=None, topk=None, use_cache=False),
82 |
83 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=512, mid_length=250, mem_size=10, topk=2),
84 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=2),
85 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=350, mem_size=40, topk=2),
86 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=300, mem_size=40, topk=3),
87 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=250, mem_size=20, topk=4),
88 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=2048, mid_length=250, mem_size=40, topk=4),
89 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=4096, mid_length=250, mem_size=40, topk=4),
90 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=2),
91 | dict(exp_dir=exp_dirs["arxiv_landmark"], eval_len=4096, mid_length=250, mem_size=80, topk=4),
92 | ]
93 |
94 | import itertools
95 | def product_dict(**kwargs):
96 | keys = kwargs.keys()
97 | for instance in itertools.product(*kwargs.values()):
98 | yield dict(zip(keys, instance))
99 |
100 | flat_settings = []
101 | for setting in settings:
102 | flat_settings.extend(product_dict(**{x: y if isinstance(y, list) else [y] for x, y in setting.items()}))
103 |
104 | settings = [Setting(**d) for d in flat_settings]
105 | last_exp_dir = None
106 |
107 | print("#!/bin/bash")
108 | for setting in settings:
109 | s_lines = []
110 | if last_exp_dir != setting.exp_dir:
111 | s_lines.append("""EXP_DIR="{exp_dir}";""".format(**dataclasses.asdict(setting)))
112 | last_exp_dir = setting.exp_dir
113 | use_cache_str = "--use_cache" if setting.use_cache else ""
114 | mem_size_flag = ""
115 | s_lines += ["""
116 | filename="$EXP_DIR/eval-{eval_len}-{selection_method}-{topk}-memsize{mem_size}-midlength{mid_length}-memcachefreq{mem_cache_freq}";
117 | grep val_acc $filename /dev/null;
118 | if [[ $? -ne 0 ]]; then
119 | script -c \\
120 | "python eval.py \\
121 | --checkpoint $EXP_DIR \\
122 | --distributed_backend None \\
123 | --lm_cache {lm_cache} \\""","""
124 | --mem_cache_size {mem_size} \\""" if setting.mem_size is not None else "","""
125 | --mem_cache_freq {mem_cache_freq} \\""" if setting.mem_cache_freq is not None else "", """
126 | --mem_freq None \\
127 | --eval_seq_length {eval_len} \\
128 | --cache_selection_method {selection_method} \\""","""
129 | --cache_topk {topk} \\""" if setting.topk is not None else "", """
130 | --no_compile \\
131 | --batch_size 16 \\
132 | --mid_length {mid_length} \\
133 | --positional_encoder rotary \\
134 | --pos_jump_on_mem 0 \\
135 | {use_cache_str} \\""", """
136 | --eval_sample_size {eval_sample_size}""" if setting.eval_sample_size is not None else "", """
137 | " $filename;
138 | fi;"""]
139 | print("".join(s_lines).format(**dataclasses.asdict(setting), use_cache_str=use_cache_str))
140 |
--------------------------------------------------------------------------------
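The generator above expands each entry of `settings` into the Cartesian product of its list-valued fields before emitting one bash block per run. A standalone sketch of that expansion, with illustrative values:

import itertools

def product_dict(**kwargs):
    # one dict per element of the Cartesian product of the value lists
    keys = kwargs.keys()
    for instance in itertools.product(*kwargs.values()):
        yield dict(zip(keys, instance))

# scalars are wrapped into one-element lists first, so list-valued fields
# multiply out while scalar fields are shared by every generated run
setting = dict(eval_len=2048, mem_size=[40, 80], topk=[2, 4])
flat = list(product_dict(**{k: v if isinstance(v, list) else [v] for k, v in setting.items()}))
print(len(flat))   # 4
print(flat[0])     # {'eval_len': 2048, 'mem_size': 40, 'topk': 2}
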
/lm_benchmark/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import sys
17 | import numpy as np
18 | import torch
19 | import inspect
20 | import json
21 | import copy
22 | import argparse
23 | import random
24 | import wandb
25 |
26 | import config
27 | import models
28 | from data import get_dataset, prepare_dataset
29 | from optim.base import train_base
30 | from optim.transformer_xl import train_xl
31 | import distributed
32 |
33 |
34 | def get_args():
35 | parser = argparse.ArgumentParser(allow_abbrev=False)
36 | parser.add_argument('--config_format', default='base', choices=config.registered_formats())
37 |
38 | args, rem_args = parser.parse_known_args()
39 |
40 | return config.parse_args_with_format(format=args.config_format, base_parser=parser, args=rem_args, namespace=args)
41 |
42 |
43 | def main(args):
44 |
45 |
46 | torch.backends.cuda.matmul.allow_tf32 = True # allow TF32 on matmul for faster training
47 | torch.backends.cudnn.allow_tf32 = True
48 |
49 | distributed_backend = distributed.make_backend_from_args(args)
50 | args = distributed_backend.get_adjusted_args_for_process(args)
51 |
52 | args.device = torch.device(args.device)
53 | torch.cuda.set_device(args.device)
54 | device_type = 'cuda' if 'cuda' in str(args.device) else 'cpu'
55 |
56 | torch.manual_seed(args.seed)
57 | random.seed(args.seed)
58 | np.random.seed(args.seed)
59 |
60 | print(f"Loading dataset '{args.dataset}'")
61 |
62 | if distributed_backend.is_master_process():
63 | prepare_dataset(args)
64 | distributed_backend.sync()
65 |
66 | data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized}
67 |
68 | print(f"Num training tokens: {len(data['train'])}")
69 | print(f"Num validation tokens: {len(data['val'])}")
70 |
71 | model = models.make_model_from_args(args).to(args.device)
72 |
73 | model = distributed_backend.transform_model(model)
74 |
75 | group_specs = distributed_backend.get_raw_model(model).get_parameter_group_specs()
76 | param_name_mapping = {p_name: p for p_name, p in model.named_parameters()}
77 | optimized_params_cnt = 0
78 | for g in group_specs:
79 | params = []
80 | for p_name in g["params"]:
81 | translated_p_names = distributed_backend.translate_model_parameter_name_for_node(p_name)
82 | params += [param_name_mapping[p_name] for p_name in translated_p_names]
83 | g["params"] = params
84 | optimized_params_cnt += sum([p.numel() for p in g["params"]])
85 | print("number of optimized parameters: %.2fM" % (optimized_params_cnt/1e6,))
86 | if args.opt == 'adamw':
87 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
88 | print(f"using fused AdamW: {use_fused}")
89 | extra_args = dict(fused=True) if use_fused else dict()
90 | opt = torch.optim.AdamW(group_specs, lr=args.lr, betas=(args.beta1, args.beta2),
91 | weight_decay=args.weight_decay, **extra_args)
92 | elif args.opt == 'adafactor':
93 | from optim.adafactor import Adafactor
94 | opt = Adafactor(group_specs, lr=args.lr)
95 | else:
96 | opt = torch.optim.SGD(group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
97 |
98 | if args.scheduler != 'none':
99 | if args.scheduler in ['cos', 'linear']:
100 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=args.lr, total_steps=args.iterations,
101 | pct_start=args.warmup_percent, anneal_strategy=args.scheduler,
102 | cycle_momentum=False, div_factor=1e2, final_div_factor=.05)
103 | else:
104 | raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.")
105 | else:
106 | scheduler = None
107 |
108 | args.world_size = distributed_backend.get_world_size()
109 | exp_name = args.exp_name
110 | if distributed_backend.is_master_process() and args.wandb:
111 | params_copy = copy.deepcopy(vars(args))
112 | del params_copy['device']
113 | wandb.init(project=args.wandb_project, name=exp_name, config=params_copy)
114 |
115 | ckpt_path = f"{args.results_base_folder}/{args.dataset}/{args.model}/{exp_name}"
116 | if not os.path.exists(ckpt_path):
117 | if distributed_backend.is_master_process():
118 | os.makedirs(ckpt_path)
119 | else:
120 | if os.path.isfile(f"{ckpt_path}/summary.json"): # the experiment was already completed
121 | print(f"Already found experiment '{ckpt_path}'.\nSkipping.")
122 | sys.exit(0)
123 |
124 | if args.optimization_process == 'transformer_xl':
125 | train = train_xl
126 | else:
127 | train = train_base
128 |
129 | print(f"\nTraining model={args.model} \n{vars(args)}\n")
130 |
131 | stats = train(model, opt, data, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length,
132 | eval_freq=args.eval_freq,
133 | distributed_backend=distributed_backend,
134 | ckpt_path=ckpt_path, extra_args=args)
135 |
136 | args.device = None
137 | args.dtype = None
138 | stats['args'] = vars(args)
139 | if distributed_backend.is_master_process():
140 | with open(f"{ckpt_path}/summary.json", "w") as fs:
141 | json.dump(stats, fs)
142 | distributed_backend.finalize()
143 |
144 |
145 | if __name__ == "__main__":
146 | args = get_args()
147 | main(args)
148 |
--------------------------------------------------------------------------------
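get_args() parses in two stages: a first parser consumes only --config_format, and the remaining argv is handed to a parser built for that format. A minimal sketch of the same pattern (the stand-in parser below is illustrative, not the repo's config module):

import argparse

def parse(argv):
    base = argparse.ArgumentParser(allow_abbrev=False)
    base.add_argument('--config_format', default='base', choices=['base', 'rotary'])
    args, rem = base.parse_known_args(argv)
    # format-specific parser; reuses the base flags via `parents`
    full = argparse.ArgumentParser(parents=[base], add_help=False)
    full.add_argument('--lr', type=float, default=1e-3)  # illustrative flag
    return full.parse_args(rem, namespace=args)

print(parse(['--config_format', 'base', '--lr', '0.01']))
# Namespace(config_format='base', lr=0.01)
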
/lm_benchmark/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import base_new, landmark, landmark_with_cmt
16 |
17 | MODELS = {
18 | "base": base_new.GPTBase,
19 | "landmark": landmark.GPTBase,
20 | "landmark_with_cmt": landmark_with_cmt.GPTBase,
21 | }
22 |
23 |
24 | def make_model_from_args(args):
25 | return MODELS[args.model](args)
26 |
27 |
28 | def registered_models():
29 | return MODELS.keys()
30 |
--------------------------------------------------------------------------------
/lm_benchmark/models/caches/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import cache, mem_cache, kv_cache, kv_cache_train
16 |
17 | CACHES = {
18 | "none": cache.LMCache,
19 | "mem": mem_cache.MemLMCache,
20 | "kv": kv_cache.KVLMCache,
21 | "kv_train": kv_cache_train.KVLMCache
22 | }
23 |
24 |
25 | def get_cache(cache_name):
26 | return CACHES[cache_name]
27 |
28 |
29 | def registered_caches():
30 | return CACHES.keys()
31 |
--------------------------------------------------------------------------------
/lm_benchmark/models/caches/cache.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from torch import nn
16 |
17 | class LMCacheStorage(nn.Module):
18 |
19 | def __init__(self, config, layer):
20 | super().__init__()
21 | self.config = config
22 | #self._layer = [layer]
23 |
24 | #@property
25 | #def layer(self):
26 | # return self._layer[0]
27 |
28 | def store_in_cache(self, keys, values_dict):
29 | pass
30 |
31 | def retrieve_for_query(self, q, cache_context, pos_emb_closure, start_index):
32 | return None, {}
33 |
34 | def clear_state(self):
35 | pass
36 |
37 |
38 | class LMCacheContext(object):
39 | pass
40 |
41 |
42 | class LMCache(nn.Module):
43 |
44 | def __init__(self, config):
45 | super().__init__()
46 | self.config = config
47 | self.layer_storages_map = dict()
48 | self.layer_storages = nn.ModuleList()
49 | self.cache_storage = self.get_cache_storage()
50 | self.context_class = self.get_context_class()
51 |
52 | def get_cache_storage(self):
53 | return LMCacheStorage
54 |
55 | def get_context_class(self):
56 | return LMCacheContext
57 |
58 | def forward(self, x):
59 | return x, 0, self.get_context_class()
60 |
61 | def get_final_logits(self, logits):
62 | return logits
63 |
64 | def get_storage_for_layer(self, l):
65 | if l not in self.layer_storages_map:
66 | self.layer_storages_map[l] = len(self.layer_storages)
67 | self.layer_storages.append(self.cache_storage(self.config, l))
68 | return self.layer_storages[self.layer_storages_map[l]]
69 |
70 | def clear_state(self):
71 | for storage in self.layer_storages:
72 | storage.clear_state()
73 |
--------------------------------------------------------------------------------
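A minimal sketch of the lazy per-layer storage pattern used by LMCache: a storage module is created on first access and kept in an nn.ModuleList, so any registered buffers follow the parent module across .to()/.cuda() (class names below are illustrative):

import torch
from torch import nn

class Store(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.register_buffer("slot", torch.zeros(4), persistent=False)

class Cache(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_map = {}
        self.storages = nn.ModuleList()

    def get_storage_for_layer(self, l):
        if l not in self.layer_map:
            self.layer_map[l] = len(self.storages)
            self.storages.append(Store(l))
        return self.storages[self.layer_map[l]]

cache = Cache()
s0 = cache.get_storage_for_layer(0)
assert cache.get_storage_for_layer(0) is s0  # created once, then reused
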
/lm_benchmark/models/caches/kv_cache.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import torch
18 |
19 | from .cache import LMCache, LMCacheStorage
20 |
21 | class KVLMCacheStorage(LMCacheStorage):
22 |
23 | def __init__(self, config, layer):
24 | super().__init__(config, layer)
25 | n_embd_per_head = config.n_embd // config.n_head
26 | self.max_cache_size = config.mem_cache_size
27 | self.register_buffer("cache_k", torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head)), persistent=False)
28 | self.register_buffer("cache_v", torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head)), persistent=False)
29 | self.cache_iter = 0
30 | self.cache_size = 0
31 | self.clear_state()
32 |
33 | def clear_state(self):
34 | self.cache_iter = 0
35 | self.cache_size = 0
36 |
37 | def retrieve_for_query(self, q, cache_context, pos_emb_closure, start_index):
38 | if self.cache_size == 0:
39 | return None, {}
40 | B, nh, T, hs = q.size() # batch size, num_heads, sequence length, per-head embedding dimensionality (n_embd)
41 | cached_keys = self.cache_k[:B, :, :self.cache_size]
42 | k_indices = torch.cat((
43 | torch.arange(self.cache_size - self.cache_iter, self.cache_size, device=q.device),
44 | torch.arange(self.cache_size - cached_keys.shape[2], self.cache_size - self.cache_iter, device=q.device),
45 | ))
46 | assert self.cache_size == start_index
47 | last_incomplete_k = pos_emb_closure.adapt_keys(cached_keys, indices=k_indices)
48 | att_incomplete = (q @ last_incomplete_k.transpose(-2, -1)) * (1.0 / math.sqrt(last_incomplete_k.size(-1)))
49 | last_incomplete_v = self.cache_v[:B, :, :self.cache_size].unsqueeze(2).expand(B, nh, T, -1, hs)
50 | return att_incomplete, {'v': last_incomplete_v.clone()}
51 |
52 |
53 | def store_in_cache(self, keys, values_dict):
54 | if self.max_cache_size == 0:
55 | return
56 | B, nh, T, hs = keys.size()
57 | k_for_cache = keys[:, :, -self.max_cache_size:]
58 | v_for_cache = values_dict['v'][:, :, -self.max_cache_size:]
59 | self.cache_iter = (self.cache_iter + T - k_for_cache.shape[2]) % self.cache_k.shape[2]
60 | self.cache_size += T - k_for_cache.shape[2]
61 | T = k_for_cache.shape[2]
62 |
63 | if self.cache_iter + T >= self.max_cache_size:
64 | next_iter = (self.cache_iter + T) - self.max_cache_size
65 | rem = (self.max_cache_size - self.cache_iter)
66 | self.cache_k[:B, :, :next_iter].copy_(k_for_cache[:, :, rem:])
67 | self.cache_k[:B, :, self.cache_iter:].copy_(k_for_cache[:, :, :rem])
68 | self.cache_v[:B, :, :next_iter].copy_(v_for_cache[:, :, rem:])
69 | self.cache_v[:B, :, self.cache_iter:].copy_(v_for_cache[:, :, :rem])
70 | else:
71 | next_iter = self.cache_iter + T
72 | self.cache_k[:B, :, self.cache_iter:next_iter].copy_(k_for_cache)
73 | self.cache_v[:B, :, self.cache_iter:next_iter].copy_(v_for_cache)
74 | self.cache_iter = next_iter
75 | self.cache_size += T
76 |
77 |
78 | class KVLMCache(LMCache):
79 |
80 | def __init__(self, config):
81 | super().__init__(config)
82 | self.total_len = 0
83 |
84 | def get_cache_storage(self):
85 | return KVLMCacheStorage
86 |
87 | def forward(self, x):
88 | B, T = x.size()
89 | prev_total_len = self.total_len
90 | self.total_len = self.total_len + x.shape[1]
91 | return x, prev_total_len, self.context_class()
92 |
93 | def clear_state(self):
94 | super().clear_state()
95 | self.total_len = 0
96 |
97 |
--------------------------------------------------------------------------------
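store_in_cache treats cache_k/cache_v as circular buffers over the time dimension: `cache_iter` is the next write position and writes that would run past `max_cache_size` wrap to the front. A standalone replay of the same wrap-around arithmetic on a 1-D buffer:

import torch

def store(buf, it, chunk):
    n, t = buf.shape[0], chunk.shape[0]
    chunk = chunk[-n:]                  # keep only the newest n entries
    it = (it + t - chunk.shape[0]) % n  # advance past entries that were dropped
    t = chunk.shape[0]
    if it + t >= n:                     # write wraps past the end
        rem = n - it
        buf[it:] = chunk[:rem]
        buf[:t - rem] = chunk[rem:]
        return t - rem
    buf[it:it + t] = chunk
    return it + t

buf = torch.zeros(4)
it = store(buf, 0, torch.tensor([1., 2., 3.]))  # buf = [1, 2, 3, 0], it = 3
it = store(buf, it, torch.tensor([4., 5.]))     # wraps: buf = [5, 2, 3, 4], it = 1
print(buf.tolist(), it)
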
/lm_benchmark/models/caches/kv_cache_train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import torch
18 |
19 | from .cache import LMCache, LMCacheStorage
20 |
21 | class KVLMCacheStorage(LMCacheStorage):
22 |
23 | def __init__(self, config, layer):
24 | super().__init__(config, layer)
25 | n_embd_per_head = config.n_embd // config.n_head
26 | self.max_cache_size = config.mem_cache_size
27 | self._cache_k = [torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head), device=torch.device('cuda'))]
28 | self._cache_v = [torch.empty((config.batch_size, config.n_head, config.mem_cache_size, n_embd_per_head), device=torch.device('cuda'))]
29 | self.cache_iter = 0
30 | self.cache_size = 0
31 | self.clear_state()
32 |
33 | @property
34 | def cache_k(self):
35 | return self._cache_k[0]
36 |
37 | @property
38 | def cache_v(self):
39 | return self._cache_v[0]
40 |
41 | def clear_state(self):
42 | self.cache_iter = 0
43 | self.cache_size = 0
44 |
45 | def retrieve_for_query(self, q, cache_context, pos_emb_closure, start_index):
46 | if self.cache_size == 0:
47 | return None, {}
48 | B, nh, T, hs = q.size() # batch size, num_heads, sequence length, per-head embedding dimensionality (n_embd)
49 | cached_keys = self.cache_k[:B, :, :self.cache_size]
50 | k_indices = torch.cat((
51 | torch.arange(self.cache_size - self.cache_iter, self.cache_size, device=q.device),
52 | torch.arange(self.cache_size - cached_keys.shape[2], self.cache_size - self.cache_iter, device=q.device),
53 | ))
54 | assert self.cache_size == start_index
55 | last_incomplete_k = pos_emb_closure.adapt_keys(cached_keys, indices=k_indices)
56 | att_incomplete = (q @ last_incomplete_k.transpose(-2, -1)) * (1.0 / math.sqrt(last_incomplete_k.size(-1)))
57 | last_incomplete_v = self.cache_v[:B, :, :self.cache_size].unsqueeze(2).expand(B, nh, T, -1, hs)
58 | return att_incomplete, {'v': last_incomplete_v.clone()}
59 |
60 |
61 | def store_in_cache(self, keys, values_dict):
62 | if self.max_cache_size == 0:
63 | return
64 | B, nh, T, hs = keys.size()
65 | k_for_cache = keys[:, :, -self.max_cache_size:]
66 | v_for_cache = values_dict['v'][:, :, -self.max_cache_size:]
67 | self.cache_iter = (self.cache_iter + T - k_for_cache.shape[2]) % self.cache_k.shape[2]
68 | self.cache_size += T - k_for_cache.shape[2]
69 | T = k_for_cache.shape[2]
70 |
71 | if self.cache_iter + T >= self.max_cache_size:
72 | next_iter = (self.cache_iter + T) - self.max_cache_size
73 | rem = (self.max_cache_size - self.cache_iter)
74 | self.cache_k[:B, :, :next_iter].copy_(k_for_cache[:, :, rem:])
75 | self.cache_k[:B, :, self.cache_iter:].copy_(k_for_cache[:, :, :rem])
76 | self.cache_v[:B, :, :next_iter].copy_(v_for_cache[:, :, rem:])
77 | self.cache_v[:B, :, self.cache_iter:].copy_(v_for_cache[:, :, :rem])
78 | else:
79 | next_iter = self.cache_iter + T
80 | self.cache_k[:B, :, self.cache_iter:next_iter].copy_(k_for_cache)
81 | self.cache_v[:B, :, self.cache_iter:next_iter].copy_(v_for_cache)
82 | self.cache_iter = next_iter
83 | self.cache_size += T
84 |
85 |
86 | class KVLMCache(LMCache):
87 |
88 | def __init__(self, config):
89 | super().__init__(config)
90 | self.total_len = 0
91 |
92 | def get_cache_storage(self):
93 | return KVLMCacheStorage
94 |
95 | def forward(self, x):
96 | B, T = x.size()
97 | prev_total_len = self.total_len
98 | self.total_len = self.total_len + x.shape[1]
99 | return x, prev_total_len, self.context_class()
100 |
101 | def clear_state(self):
102 | super().clear_state()
103 | self.total_len = 0
104 |
105 |
--------------------------------------------------------------------------------
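kv_cache_train.py is identical to kv_cache.py except that the cache tensors live in plain Python lists behind properties rather than as registered buffers, presumably so module-level transforms applied during training leave them alone; the cost is that they no longer follow the module across device or dtype moves, hence the hardcoded cuda device above. A small sketch of the difference:

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("k", torch.zeros(2), persistent=False)

class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        self._k = [torch.zeros(2)]  # invisible to state_dict(), .to(), DDP, ...

a = WithBuffer().to(torch.float64)
b = WithList().to(torch.float64)
print(a.k.dtype)      # torch.float64 -- the buffer follows the module
print(b._k[0].dtype)  # torch.float32 -- the list member does not
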
/lm_benchmark/models/positional_encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from . import encoder, rotary, rotary_mem_jump
16 |
17 | POS_ENCS = {
18 | "rotary": rotary.RotaryPositionalEncoder,
19 | "rotary_mem_jump": rotary_mem_jump.RotaryJumpMemPositionalEncoder
20 | }
21 |
22 |
23 | def get_encoder(encoder_name):
24 | return POS_ENCS[encoder_name]
25 |
26 |
27 | def registered_encoders():
28 | return POS_ENCS.keys()
29 |
--------------------------------------------------------------------------------
/lm_benchmark/models/positional_encoders/encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from torch import nn
17 |
18 |
19 | class PositionalEncoderClosure(object):
20 |
21 | def __init__(self, encoder):
22 | self.encoder = encoder
23 |
24 | def adapt_model_input(self, x, start_index):
25 | return x
26 |
27 | def adapt_keys(self, k, start_index=None, indices=None):
28 | if indices is None:
29 | T = k.shape[-2]
30 | indices = torch.arange(start_index, T + start_index, device=k.device)
31 | return self._adapt_keys_for_indices(k, indices)
32 |
33 | def _adapt_keys_for_indices(self, k, indices):
34 | return k
35 |
36 | def adapt_queries(self, q, start_index):
37 | return q
38 |
39 | def adapt_attention_before_softmax(self, att, start_query_index=None, start_key_index=None, q_indices=None, k_indices=None):
40 | if q_indices is None:
41 | qT = att.shape[-2]
42 | q_indices = torch.arange(start_query_index, qT + start_query_index, device=att.device)
43 | if k_indices is None:
44 | kT = att.shape[-1]
45 | k_indices = torch.arange(start_key_index, kT + start_key_index, device=att.device)
46 | return self._adapt_attention_before_softmax_for_indices(att, q_indices, k_indices)
47 |
48 | def _adapt_attention_before_softmax_for_indices(self, att, query_indices, key_indices):
49 | return att
50 |
51 |
52 | class PositionalEncoder(nn.Module):
53 |
54 | closure_model = PositionalEncoderClosure
55 |
56 | def __init__(self, config):
57 | super().__init__()
58 | self.config = config
59 |
60 | def forward(self, x):
61 | return x, self.closure_model(self)
62 |
--------------------------------------------------------------------------------
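The encoder/closure split lets attention code position-encode queries and keys at arbitrary start offsets without knowing which encoding is in use. A minimal usage sketch, assuming encoder.py is importable from this directory (the base class is the identity encoding):

import torch
from encoder import PositionalEncoder

enc = PositionalEncoder(config=None)  # the base class only stores the config
x = torch.randn(2, 4, 8, 16)          # (B, heads, T, head_dim)
x, closure = enc(x)
q = closure.adapt_queries(x, start_index=5)
k = closure.adapt_keys(x, start_index=0)
assert torch.equal(q, x) and torch.equal(k, x)  # identity for the base class
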
/lm_benchmark/models/positional_encoders/rotary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from torch import nn
17 |
18 | from .encoder import PositionalEncoder, PositionalEncoderClosure
19 | from .rotary_utils import apply_rotary_emb
20 |
21 |
22 | class RotaryPositionalEncoderClosure(PositionalEncoderClosure):
23 |
24 | def adapt_vector_for_indices(self, v, indices):
25 | #changer = torch.zeros_like(indices)
26 | #changer[50::51] = 1
27 | #indices -= torch.cumsum(changer, dim=-1)
28 |
29 | *other_dims, T, hs = v.shape
30 | if T == 0:
31 | return v
32 | other_dims_prefix = other_dims[:len(other_dims) - len(indices.shape) + 1]
33 | freqs = (indices.unsqueeze(-1) * self.encoder.freqs.view(1, -1)).unsqueeze(-1).expand(*indices.shape, -1, 2).reshape(*indices.shape, hs)
34 | freqs = freqs.view([1] * len(other_dims_prefix) + list(indices.shape) + [hs]).expand(*v.shape)
35 | v = apply_rotary_emb(freqs, v)
36 | return v
37 |
38 | def _adapt_keys_for_indices(self, k, indices):
39 | return self.adapt_vector_for_indices(k, indices)
40 |
41 | def adapt_queries(self, q, start_index):
42 | T = q.shape[-2]
43 | indices = torch.arange(start_index, T + start_index, device=q.device)
44 | return self.adapt_vector_for_indices(q, indices)
45 |
46 |
47 | class RotaryPositionalEncoder(PositionalEncoder):
48 |
49 | def __init__(self, config):
50 | super().__init__(config)
51 | self.max_pos_log = 4
52 | self.max_pos_base = 10
53 | n_embd_per_head = config.n_embd // config.n_head
54 | freqs = (self.max_pos_base ** (-self.max_pos_log * torch.arange(0, n_embd_per_head, 2)[:(n_embd_per_head // 2)].float() / n_embd_per_head))
55 | self.register_buffer("freqs", freqs)
56 |
57 | closure_model = RotaryPositionalEncoderClosure
58 |
--------------------------------------------------------------------------------
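With max_pos_base=10 and max_pos_log=4, the frequency schedule above is algebraically the standard RoPE schedule with base 10000; a quick check:

import torch

d = 64  # per-head embedding size
j = torch.arange(0, d, 2).float()
repo_freqs = 10.0 ** (-4 * j / d)       # as computed in __init__ above
i = torch.arange(d // 2).float()
rope_freqs = 10000.0 ** (-2 * i / d)    # the usual RoPE schedule
print(torch.allclose(repo_freqs, rope_freqs))  # True
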
/lm_benchmark/models/positional_encoders/rotary_mem_jump.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from torch import nn
17 |
18 | from .encoder import PositionalEncoder, PositionalEncoderClosure
19 | from .rotary_utils import apply_rotary_emb
20 |
21 |
22 | class JumpingRotaryPositionalEncoderClosure(PositionalEncoderClosure):
23 |
24 | def __init__(self, encoder, jumps):
25 | super().__init__(encoder)
26 | self.jumps = jumps
27 |
28 | def adapt_vector_for_indices(self, v, indices):
29 | #changer = torch.zeros_like(indices)
30 | #changer[50::51] = 1
31 | #indices -= torch.cumsum(changer, dim=-1)
32 |
33 |
34 | *other_dims, T, hs = v.shape
35 | if T == 0:
36 | return v
37 | other_dims_prefix = other_dims[:len(other_dims) - len(indices.shape) + 1]
38 | if self.jumps is not None:
39 | indices = indices.view([1] * len(other_dims_prefix) + list(indices.shape)).repeat(other_dims_prefix + [1] * len(indices.shape))
40 | indices[..., 1:] = indices[..., 1:] + self.jumps
41 | other_dims_prefix = []
42 | # print(indices)
43 | freqs = (indices.unsqueeze(-1) * self.encoder.freqs.view(1, -1)).unsqueeze(-1).expand(*indices.shape, -1, 2).reshape(*indices.shape, hs)
44 | freqs = freqs.view([1] * len(other_dims_prefix) + list(indices.shape) + [hs]).expand(*v.shape)
45 | v = apply_rotary_emb(freqs, v)
46 | return v
47 |
48 | def _adapt_keys_for_indices(self, k, indices):
49 | return self.adapt_vector_for_indices(k, indices)
50 |
51 | def adapt_queries(self, q, start_index):
52 | T = q.shape[-2]
53 | indices = torch.arange(start_index, T + start_index, device=q.device)
54 | return self.adapt_vector_for_indices(q, indices)
55 |
56 |
57 | class RotaryJumpMemPositionalEncoder(PositionalEncoder):
58 |
59 | def __init__(self, config):
60 | super().__init__(config)
61 | self.max_pos_log = 4
62 | self.max_pos_base = 10
63 | n_embd_per_head = config.n_embd // config.n_head
64 | freqs = (self.max_pos_base ** (-self.max_pos_log * torch.arange(0, n_embd_per_head, 2)[:(n_embd_per_head // 2)].float() / n_embd_per_head))
65 | self.register_buffer("freqs", freqs)
66 |
67 | def forward(self, x):
68 | if self.config.pos_jump_on_mem is not None and self.config.pos_jump_on_mem > 0:
69 | #assert self.config.mem_freq is not None
70 | is_mem = (x == self.config.landmark_id)
71 | jumps = torch.cumsum((is_mem * torch.randint_like(x, self.config.pos_jump_on_mem))[:, :-1], dim=-1)
72 | return x, self.closure_model(self, jumps.unsqueeze(1)) # (B, 1, T)
73 | else:
74 | return x, self.closure_model(self, None)
75 |
76 |
77 | closure_model = JumpingRotaryPositionalEncoderClosure
78 |
--------------------------------------------------------------------------------
/lm_benchmark/models/positional_encoders/rotary_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | def rotate_half(x):
18 | x = x.view(*x.shape[:-1], -1, 2)
19 | x1, x2 = x.unbind(dim = -1)
20 | x = torch.stack((-x2, x1), dim = -1)
21 | return x.view(*x.shape[:-2], -1)
22 |
23 | def apply_rotary_emb(freqs, t, start_index = 0, scale = 1.):
24 | #freqs = freqs.to(t)
25 | rot_dim = freqs.shape[-1]
26 | end_index = start_index + rot_dim
27 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
28 | #t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
29 | t = (t * freqs.cos().to(t) * scale) + (rotate_half(t) * freqs.sin().to(t) * scale)
30 | #return torch.cat((t_left, t, t_right), dim = -1)
31 | return t
32 |
--------------------------------------------------------------------------------
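A quick numeric check of the property apply_rotary_emb provides: attention scores between rotated queries and keys depend only on the position offset, not on absolute positions. Assuming rotary_utils.py is importable from this directory:

import torch
from rotary_utils import apply_rotary_emb

def angles(pos, freqs):
    # the closures build per-feature angles by using each frequency for one
    # consecutive (x1, x2) pair, matching rotate_half's pairing
    return (pos * freqs).repeat_interleave(2)

freqs = 10000.0 ** (-2 * torch.arange(4).float() / 8)  # head dim 8
q, k = torch.randn(8), torch.randn(8)
s1 = apply_rotary_emb(angles(3.0, freqs), q) @ apply_rotary_emb(angles(1.0, freqs), k)
s2 = apply_rotary_emb(angles(7.0, freqs), q) @ apply_rotary_emb(angles(5.0, freqs), k)
print(torch.allclose(s1, s2, atol=1e-5))  # True: offset 2 in both cases
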
/lm_benchmark/optim/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from contextlib import nullcontext
16 |
17 | import torch
18 | import torch.nn.functional as F
19 | import wandb
20 | import time
21 | import copy
22 | import traceback
23 |
24 | from .utils import get_batch, save_checkpoint
25 |
26 | @torch.no_grad()
27 | def eval(model, data_tensor, sequence_length, batch_size, device='cpu', max_num_batches=24, ctx=nullcontext()):
28 | assert model.training == False
29 |
30 | loss_list_val, acc_list = [], []
31 |
32 | for _ in range(max_num_batches):
33 | x, y = get_batch(data_tensor, sequence_length, batch_size, device=device)
34 | with ctx:
35 | outputs = model(x, targets=y, get_logits=True)
36 | val_loss = outputs['loss']
37 | loss_list_val.append(val_loss)
38 | acc_list.append((outputs['logits'].argmax(-1) == y).float().mean())
39 |
40 | val_acc = torch.stack(acc_list).mean().item()
41 | val_loss = torch.stack(loss_list_val).mean().item()
42 | val_perplexity = 2.718281828459045 ** val_loss  # e ** loss, with loss in nats
43 |
44 | return val_acc, val_loss, val_perplexity
45 |
46 |
47 | def train_base(model, opt, data, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
48 | device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
49 | type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
50 | device_type=device_type, dtype=extra_args.dtype)
51 | itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible
52 |
53 | stats = {'train_loss': [], 'val_loss': [], 'val_pp': [], 'val_acc': []}
54 |
55 | num_substeps_per_epoch = len(data['train']) // (batch_size * sequence_length)
56 |
57 | if not extra_args.no_compile:
58 | print(f"Compiling model ...")
59 | import torch._dynamo as torchdynamo
60 | torchdynamo.config.guard_nn_modules = True
61 | model = torch.compile(model) # requires pytorch 2.0+
62 |
63 | model.train()
64 |
65 | t0 = time.time()
66 |
67 | while itr < iterations:
68 |
69 | for microstep_idx in range(acc_steps): # gradient accumulation
70 | x, y = get_batch(data['train'], sequence_length, batch_size, device=extra_args.device)
71 | with type_ctx:
72 | with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps):
73 | if getattr(distributed_backend.get_raw_model(model), "needs_iter", False):
74 | outputs = model(x, targets=y, iter=itr)
75 | else:
76 | outputs = model(x, targets=y)
77 |
78 | loss = outputs['loss']
79 | loss.backward()
80 | substep += 1
81 |
82 | opt.step()
83 | if scheduler is not None: scheduler.step()
84 | opt.zero_grad(set_to_none=True)
85 | itr += 1
86 |
87 | if itr % eval_freq == 0 or itr == iterations: # from here it's only evaluation code, all the training is above
88 | if distributed_backend.is_master_process():
89 | t1 = time.time()
90 | dt = t1 - t0
91 | epoch = substep//num_substeps_per_epoch
92 |
93 | model.eval()
94 | train_loss = loss.detach().cpu().item()
95 | current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr
96 | val_acc, val_loss, val_perplexity = eval(model, data['val'], sequence_length, batch_size,
97 | extra_args.device, max_num_batches=24, ctx=type_ctx)
98 |
99 | print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:.3f}"
100 | print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms"
101 | if scheduler is not None:
102 | print_string += f" [lr] {current_lr:.5f}"
103 | print(print_string)
104 |
105 | if extra_args.wandb:
106 | wandb.log({
107 | "iter": itr,
108 | "train/loss": train_loss,
109 | "val/loss": val_loss,
110 | "val/perplexity": val_perplexity,
111 | "val/acc": val_acc,
112 | "lr": current_lr,
113 | })
114 |
115 | model.train()
116 | t0 = time.time()
117 | if distributed_backend.is_master_process():
118 | if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
119 | print(f"saving checkpoint to {ckpt_path}/ckpt_{itr}.pt")
120 | save_checkpoint(distributed_backend=distributed_backend,
121 | model=model,
122 | opt=opt,
123 | scheduler=scheduler,
124 | itr=itr,
125 | ckpt_path=f"{ckpt_path}/ckpt_{itr}.pt")
126 |
127 | if distributed_backend.is_master_process():
128 | print(f"saving checkpoint to {ckpt_path}")
129 | save_checkpoint(distributed_backend=distributed_backend,
130 | model=model,
131 | opt=opt,
132 | scheduler=scheduler,
133 | itr=itr,
134 | ckpt_path=f"{ckpt_path}/ckpt.pt")
135 |
136 | return stats
137 |
--------------------------------------------------------------------------------
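The inner loop above accumulates gradients over acc_steps micro-batches before a single optimizer step; note that train_base backpropagates each micro-batch loss without dividing by acc_steps, so gradients are summed rather than averaged. Reduced to essentials (dummy model and loss):

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
acc_steps = 4
for _ in range(acc_steps):          # gradient accumulation
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean()   # as in train_base: no / acc_steps
    loss.backward()                 # grads sum across micro-steps
opt.step()
opt.zero_grad(set_to_none=True)
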
/lm_benchmark/optim/transformer_xl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from contextlib import nullcontext
16 |
17 | import torch
18 | import torch.nn.functional as F
19 | import wandb
20 | import time
21 | import copy
22 | import traceback
23 |
24 | from .utils import get_batch, save_checkpoint
25 |
26 |
27 | @torch.no_grad()
28 | def eval(model, data_tensor, sequence_length, total_sequence_length, batch_size, device='cpu', max_num_batches=24, ctx=nullcontext()):
29 | assert model.training == False
30 |
31 | loss_list_val, acc_list = [], []
32 |
33 | for _ in range(max_num_batches):
34 | x, y = get_batch(data_tensor, total_sequence_length, batch_size, device=device)
35 | model.clear_state()
36 | total_loss = None
37 | for idx in range(0, x.shape[1], sequence_length):
38 | x_part = x[:, idx:idx+sequence_length]
39 | y_part = y[:, idx:idx+sequence_length].contiguous()
40 | with ctx:
41 | outputs = model(x_part, targets=y_part, get_logits=True, use_cache=True)
42 | val_loss = outputs['loss']
43 | if idx == 0:
44 | total_loss = val_loss
45 | else:
46 | total_loss += val_loss
47 | loss_list_val.append(total_loss)
48 | acc_list.append((outputs['logits'].argmax(-1) == y_part).float().mean())  # accuracy of the final segment only
49 |
50 | val_acc = torch.stack(acc_list).mean().item()
51 | val_loss = torch.stack(loss_list_val).mean().item()
52 | val_perplexity = 2.718281828459045 ** val_loss  # e ** loss, with loss in nats
53 |
54 | return val_acc, val_loss, val_perplexity
55 |
56 | def train_xl(model, opt, data, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
57 | device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
58 | type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
59 | device_type=device_type, dtype=extra_args.dtype)
60 | itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible
61 |
62 | stats = {'train_loss': [], 'val_loss': [], 'val_pp': [], 'val_acc': []}
63 |
64 | num_substeps_per_epoch = len(data['train']) // (batch_size * sequence_length)
65 |
66 | if not extra_args.no_compile:
67 | print(f"Compiling model ...")
68 | import torch._dynamo as torchdynamo
69 | torchdynamo.config.guard_nn_modules = True
70 | model = torch.compile(model) # requires pytorch 2.0+
71 |
72 | model.train()
73 |
74 | if extra_args.postpone_lm_cache:
75 | distributed_backend.get_raw_model(model).init_cache()
76 |
77 | t0 = time.time()
78 | while itr < iterations:
79 | for microstep_idx in range(acc_steps): # gradient accumulation
80 | x, y = get_batch(data['train'], extra_args.total_sequence_length, batch_size, device=extra_args.device)
81 | distributed_backend.get_raw_model(model).clear_state()
82 | total_loss = None
83 | for idx in range(0, x.shape[1], extra_args.sequence_length):
84 | with type_ctx:
85 | with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps):
86 | outputs = model(x[:, idx:idx+extra_args.sequence_length], targets=y[:, idx:idx+extra_args.sequence_length].contiguous(), use_cache=True)
87 |
88 | loss = outputs['loss']
89 | loss.backward()
90 | if idx == 0:
91 | total_loss = loss
92 | else:
93 | total_loss += loss
94 | substep += 1
95 |
96 | opt.step()
97 | if scheduler is not None: scheduler.step()
98 | opt.zero_grad(set_to_none=True)
99 | itr += 1
100 |
101 | if itr % eval_freq == 0 or itr == iterations: # from here it's only evaluation code, all the training is above
102 | if distributed_backend.is_master_process():
103 | t1 = time.time()
104 | dt = t1 - t0
105 | epoch = substep//num_substeps_per_epoch
106 |
107 | model.eval()
108 | train_loss = loss.detach().cpu().item()
109 | current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr
110 | val_acc, val_loss, val_perplexity = eval(distributed_backend.get_raw_model(model), data['val'], sequence_length, extra_args.total_sequence_length,
111 | batch_size, extra_args.device, max_num_batches=24, ctx=type_ctx)
112 |
113 | print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:.3f}"
114 | print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms"
115 | if scheduler is not None:
116 | print_string += f" [lr] {current_lr:.5f}"
117 | print(print_string)
118 |
119 | if extra_args.wandb:
120 | wandb.log({
121 | "iter": itr,
122 | "train/loss": train_loss,
123 | "val/loss": val_loss,
124 | "val/perplexity": val_perplexity,
125 | "val/acc": val_acc,
126 | "lr": current_lr,
127 | })
128 |
129 | model.train()
130 | t0 = time.time()
131 |
132 | if distributed_backend.is_master_process():
133 | print(f"saving checkpoint to {ckpt_path}")
134 | save_checkpoint(distributed_backend=distributed_backend,
135 | model=model,
136 | opt=opt,
137 | scheduler=scheduler,
138 | itr=itr,
139 | ckpt_path=ckpt_path)
140 |
141 | return stats
142 |
--------------------------------------------------------------------------------
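Note that train_xl calls backward() inside the segment loop and the caches copy keys/values into no-grad storage, so no gradient flows across segment boundaries: earlier segments enter later ones as constants, equivalent to the explicit detach below.

import torch

h = torch.randn(3, requires_grad=True)  # segment 1 activations
cache = h.detach()                      # what entering the KV cache amounts to
loss = (cache * 2).sum()                # segment 2 consumes the cache
print(loss.requires_grad)               # False: nothing to backprop into h
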
/lm_benchmark/optim/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Amirkeivan Mohtashami, Martin Jaggi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn.functional as F
18 | from contextlib import nullcontext, contextmanager, ExitStack
19 |
20 |
21 | def get_batch(data, seq_length, batch_size, device='cpu'):
22 | ix = torch.randint(len(data) - seq_length - 1, (batch_size,))
23 | x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix])
24 | y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length+1]).astype(np.int64)) for i in ix])
25 | y = torch.where(y[:, :-1] == 50260, y[:, 1:], y[:, :-1])  # where the next token is id 50260, target the token after it instead
26 | y = torch.where((x == 50260) | (x == 50256), -1, y)  # mask the loss at inputs with id 50260 or 50256
27 | if device != 'cpu':
28 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
29 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
30 | #x, y = x.to(device), y.to(device)
31 | return x, y
32 |
33 |
34 | def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args):
35 |
36 | checkpoint = dict({
37 | 'model': distributed_backend.get_raw_model(model).state_dict(),
38 | 'optimizer': opt.state_dict(),
39 | 'scheduler': scheduler.state_dict(),
40 | 'itr': itr,
41 | }, **extra_args)
42 |
43 | torch.save(checkpoint, ckpt_path)
44 |
--------------------------------------------------------------------------------
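A toy replay of the target construction in get_batch, for one sequence. The ids are assumptions inferred from context: 50260 presumably the added landmark token, 50256 the GPT-2 <|endoftext|> id. The model is never asked to predict a landmark (the token after it is targeted instead), and no loss is taken where the input itself is a landmark or end-of-text:

import numpy as np
import torch

LMK, EOS = 50260, 50256  # assumed ids, see note above

data = np.array([10, 11, LMK, 12, 13, EOS, 14, 15], dtype=np.int64)
seq_length, i = 6, 0
x = torch.from_numpy(data[i:i + seq_length])
y = torch.from_numpy(data[i + 1:i + 1 + seq_length + 1])  # one extra token
y = torch.where(y[:-1] == LMK, y[1:], y[:-1])    # skip landmarks as targets
y = torch.where((x == LMK) | (x == EOS), -1, y)  # mask landmark/EOS inputs
print(x.tolist())  # [10, 11, 50260, 12, 13, 50256]
print(y.tolist())  # [11, 12, -1, 13, 50256, -1]
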
/lm_benchmark/requirements.txt:
--------------------------------------------------------------------------------
1 | tiktoken
2 | --find-links https://download.pytorch.org/whl/torch_stable.html
3 | torch==2.0.0+cu118
4 | torchaudio==2.0.0+cu118
5 | torchvision==0.15.0+cu118
6 | tqdm
7 | transformers
8 | wandb
--------------------------------------------------------------------------------