├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── docker
│   ├── Dockerfile
│   └── requirements.txt
└── scripts
    ├── config_models.sh
    ├── config_tasks.sh
    ├── data
    │   ├── prepare.py
    │   ├── synthetic
    │   │   ├── common_words_extraction.py
    │   │   ├── constants.py
    │   │   ├── freq_words_extraction.py
    │   │   ├── json
    │   │   │   ├── PaulGrahamEssays_URLs.txt
    │   │   │   ├── download_paulgraham_essay.py
    │   │   │   ├── download_qa_dataset.sh
    │   │   │   └── english_words.json
    │   │   ├── niah.py
    │   │   ├── qa.py
    │   │   └── variable_tracking.py
    │   ├── template.py
    │   └── tokenizer.py
    ├── eval
    │   ├── evaluate.py
    │   └── synthetic
    │       └── constants.py
    ├── pred
    │   ├── call_api.py
    │   ├── client_wrappers.py
    │   ├── model_wrappers.py
    │   ├── serve_trt.py
    │   └── serve_vllm.py
    ├── run.sh
    └── synthetic.yaml
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.jsonl filter=lfs diff=lfs merge=lfs -text
2 | *.jsonls filter=lfs diff=lfs merge=lfs -text
3 | *.json filter=lfs diff=lfs merge=lfs -text
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | *.jsonl
4 | .vscode/
5 | *.out
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 📏 RULER: What’s the Real Context Size of Your Long-Context Language Models?
2 |
3 | This repository contains code for our paper [RULER: What’s the Real Context Size of Your Long-Context Language Models](https://arxiv.org/abs/2404.06654). RULER generates synthetic examples to evaluate long-context language models with configurable sequence length and task complexity. We benchmark 17 open-source models across 4 task categories (in total 13 tasks) in RULER, evaluating long-context capabilities beyond simple in-context recall. Here are our main results.
4 |
5 | |Models|Claimed Length|Effective Length|4K|8K|16K|32K|64K|128K|Avg.|wAvg. (inc)|wAvg. (dec)|
6 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
7 | |[Llama2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) (7B)|4K||85.6|
8 | [Jamba-1.5-large*](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Large) (94B/398B)|256K|>128K|96.7|96.6|96.4|96.0|95.4|95.1|96.0|95.7 **(1st)**|96.3 **(1st)**|
9 | [Gemini-1.5-pro](https://ai.google.dev/gemini-api/docs/models/gemini#:~:text=Gemini-,Gemini%201.5%20Pro%20(Preview%20only),-Text%20and%20images)|1M|>128K|96.7|95.8|96.0|95.9|95.9|94.4|95.8|95.5 **(2nd)**|96.1 **(2nd)**|
10 | [Jamba-1.5-mini](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Mini) (12B/52B)|256K|>128K|95.6|95.6|94.8|94.6|92.8|90.0|93.9|93.1 **(3rd)**|94.8 **(3rd)**
11 | [GPT-4-1106-preview](https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4#:~:text=gpt%2D4%2D1106%2Dpreview,Up%20to%20Apr%202023)|128K|64K|96.6|96.3|95.2|93.2|87.0|81.2|91.6|89.0 **(4th)**|94.1 **(4th)**|
12 | [Llama3.1](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) (70B)|128K|64K|96.5|95.8|95.4|94.8|88.4|66.6|89.6|85.5 **(10th)**|93.7 **(5th)**|
13 | [Mistral-Large-2411](https://huggingface.co/mistralai/Mistral-Large-Instruct-2411) (123B)|128K|64K|96.4|96.3|95.3|94.0|85.9|48.1|86.0|79.5 **(18th)**|92.5 **(6th)**|
14 | [Command-R-plus-0824](https://huggingface.co/CohereForAI/c4ai-command-r-plus-08-2024) (104B)|128K|32K|96.0|95.1|94.0|92.4|85.4|64.6|87.9|83.4 **(13th)**|92.4 **(7th)**|
15 | [Qwen2](https://huggingface.co/Qwen/Qwen2-72B-Instruct) (72B)|128K|32K|96.9|96.1|94.9|94.1|79.8|53.7|85.9|79.6 **(17th)**|92.3 **(8th)**|
16 | [Command-R-plus](https://huggingface.co/CohereForAI/c4ai-command-r-plus) (104B)|128K|32K|95.6|95.2|94.2|92.0|84.3|63.1|87.4|82.7 **(14th)**|92.1 **(9th)**|
17 | [Command-R-0824](https://huggingface.co/CohereForAI/c4ai-command-r-08-2024) (32B)|128K|64K|94.7|93.7|93.1|90.8|86.6|74.7|88.9|86.0 **(8th)**|91.9 **(10th)**|
18 | [GLM4](https://huggingface.co/THUDM/glm-4-9b-chat-1m) (9B)|1M|64K|94.7|92.8|92.1|89.9|86.7|83.1|89.9|88.0 **(5th)**|91.7 **(11th)**|
19 | [Llama3.1](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) (8B)|128K|32K|95.5|93.8|91.6|87.4|84.7|77.0|88.3|85.4 **(11th)**|91.3 **(12th)**|
20 | [ProLong](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct) (8B)|512K|32K|94.5|92.5|92.3|89.3|83.2|81.6|88.9|86.6 **(7th)**|91.2 **(13th)**|
21 | [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) (35B)|128K|32K|93.8|93.3|92.4|89.5|84.9|76.0|88.3|85.5 **(9th)**|91.1 **(14th)**|
22 | [MegaBeam-Mistral](https://huggingface.co/aws-prototyping/MegaBeam-Mistral-7B-512k) (7B)|512K|32K|93.8|92.5|92.0|89.2|83.7|83.7|89.1|87.3 **(6th)**|91.0 **(15th)**|
23 | [Mistral-Large-2407](https://huggingface.co/mistralai/Mistral-Large-Instruct-2407) (123B)|128K|32K|96.2|96.1|95.1|93.0|78.8|23.7|80.5|70.6 **(24th)**|90.4 **(16th)**|
24 | [GradientAI/Llama3](https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k) (70B)|1M|16K|95.1|94.4|90.8|85.4|80.9|72.1|86.5|82.6 **(15th)**|90.3 **(17th)**|
25 | [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-instruct-v0.1) (39B/141B)|64K|32K|95.6|94.9|93.4|90.9|84.7|31.7|81.9|73.5 **(22nd)**|90.3 **(18th)**|
26 | [Yi](https://huggingface.co/01-ai/Yi-34B-200K) (34B)|200K|32K|93.3|92.2|91.3|87.5|83.2|77.3|87.5|84.8 **(12th)**|90.1 **(19th)**|
27 | [Phi3-mini](https://huggingface.co/microsoft/Phi-3-mini-128K-instruct) (3.8B)|128K|32K|92.2|91.5|90.7|87.5|80.6|66.7|84.8|80.9 **(16th)**|88.7 **(20th)**|
28 | [Phi3-medium](https://huggingface.co/microsoft/Phi-3-medium-128K-instruct) (14B)|128K|32K|93.3|93.2|91.1|86.8|78.6|46.1|81.5|74.8 **(21st)**|88.3 **(21st)**|
29 | [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-instruct-v0.1) (12.9B/46.7B)|32K|32K|94.9|92.1|92.5|85.9|72.4|44.5|80.4|72.8 **(23rd)**|87.9 **(22nd)**|
30 | [GradientAI/Llama3](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k) (8B)|1M|16K|92.8|90.3|85.7|79.9|76.3|69.5|82.4|78.5 **(19th)**|86.3 **(23rd)**|
31 | [FILM-7B*](https://arxiv.org/pdf/2404.16811) (7B)|32K|32K|92.8|88.2|88.1|86.9|70.1|27.1|75.5|66.4 **(26th)**|84.7 **(24th)**|
32 | [InternLM2.5](https://huggingface.co/internlm/internlm2_5-7b-chat-1m) (7B)|1M|4K|88.1|85.5|84.5|82.7|75.5|68.9|80.9| 77.8 **(20th)**|83.9 **(25th)**|
33 | [Mistral](https://huggingface.co/mistralai/Mistral-7B-instruct-v0.2) (7B)|32K|16K|93.6|91.2|87.2|75.4|49.0|13.8|68.4|55.6 **(28th)**|81.2 **(26th)**|
34 | [Mistral-Nemo](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)|128K|16K|87.8|87.2|87.7|69.0|46.8|19.0|66.2|54.7 **(29th)**|77.8 **(27th)**|
35 | [GLM3](https://huggingface.co/THUDM/chatglm3-6b-128K) (6B)|128K|4K|87.8|83.4|78.6|69.9|56.0|42.0|69.6|62.0 **(27th)**|77.2 **(28th)**|
36 | [LWM](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M) (7B)|1M|<4K|82.3|78.4|73.7|69.1|68.1|65.0|72.8|69.9 **(25th)**|75.7 **(29th)**|
37 | [DBRX](https://huggingface.co/databricks/dbrx-instruct) (36B/132B)|32K|8K|95.1|93.8|83.6|63.1|2.4|0.0|56.3|38.0 **(30th)**|74.7 **(30th)**|
38 | [Qwen1.5](https://huggingface.co/Qwen/Qwen1.5-72B-Chat) (72B)|32K|8K|94.9|93.8|78.0|67.8|0.0|0.0|55.7|37.5 **(31st)**|74.0 **(31st)**|
39 | [Together](https://huggingface.co/togethercomputer/Llama-2-7B-32K-instruct) (7B)|32K|4K|88.2|81.1|69.4|63.0|0.0|0.0|50.3|33.8 **(32nd)**|66.7 **(32nd)**|
40 | [LongChat](https://huggingface.co/lmsys/longchat-7b-v1.5-32K) (7B)|32K|<4K|84.7|79.9|70.8|59.3|0.0|0.0|49.1|33.1 **(33rd)**|65.2 **(33rd)**|
41 | [LongAlpaca](https://huggingface.co/YuKang/LongAlpaca-13B) (13B)| 32K|<4K|60.6|57.0|56.6|43.6|0.0|0.0|36.3|24.7 **(34th)**|47.9 **(34th)**|
42 |
43 | - Despite achieving nearly perfect performance on the vanilla needle-in-a-haystack (NIAH) test, most models exhibit large degradation on tasks in RULER as sequence length increases.
44 | - While all models claim a context size of 32K tokens or greater, only half of them can effectively handle sequence lengths of 32K by exceeding a qualitative threshold, Llama-2-7b performance at 4K (85.6%). Scores exceeding the threshold are underlined.
45 | - Almost all models fall below the threshold before reaching the claimed context lengths.
46 | - Notes
47 | - Jamba-1.5-large results are reported by authors from this [report](https://arxiv.org/pdf/2408.12570).
48 | - FILM-7B results are reported by authors of this [paper](https://arxiv.org/pdf/2404.16811). They use [YaRN](https://arxiv.org/pdf/2309.00071) without further training for the evaluation length exceeding 32K (64K and 128K). They do not use the one-shot example for the CWE task.
49 |
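The wAvg columns aggregate the six per-length scores into a single weighted average. As a minimal sketch (assuming the weights increase linearly with context length for wAvg (inc) and decrease linearly for wAvg (dec); the example scores are the Llama3.1 (70B) row from the table above, and the result matches its reported values):
```
# Hedged sketch of the wAvg aggregation; linear weights are an assumption.
scores = [96.5, 95.8, 95.4, 94.8, 88.4, 66.6]  # 4K, 8K, 16K, 32K, 64K, 128K

def weighted_avg(scores, increasing=True):
    n = len(scores)
    weights = range(1, n + 1) if increasing else range(n, 0, -1)
    return sum(w * s for w, s in zip(weights, scores)) / sum(weights)

print(round(weighted_avg(scores, increasing=True), 1))   # 85.5 -> wAvg (inc)
print(round(weighted_avg(scores, increasing=False), 1))  # 93.7 -> wAvg (dec)
```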
50 | ## 💡 Requirements
51 |
52 | - Docker container: `docker pull cphsieh/ruler:0.2.0`
53 | - The requirements are listed in `docker/Dockerfile` and `docker/requirements.txt`. Use the following command to build the container based on NVIDIA's PyTorch container `nvcr.io/nvidia/pytorch:23.10-py3`.
54 | ```
55 | cd docker/
56 | DOCKER_BUILDKIT=1 docker build -f Dockerfile -t cphsieh/ruler:0.2.0 .
57 | ```
58 |
59 |
60 | ## 🔍 Evaluate long-context LMs
61 | ### 1. Download data
62 | - Paul Graham Essays for NIAH are downloaded from [NIAH Github](https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main/needlehaystack/PaulGrahamEssays) and [Paul Graham Blog](https://paulgraham.com/articles.html).
63 | - QA datasets are downloaded from [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) and [HotpotQA](https://hotpotqa.github.io/).
64 | ```
65 | cd scripts/data/synthetic/json/
66 | python download_paulgraham_essay.py
67 | bash download_qa_dataset.sh
68 | ```
69 | ### 2. Download model
70 | - We download the models from [Huggingface](https://huggingface.co/models).
71 | - The input template of each model is stored in `scripts/data/template.py`. Please add a new template if your model uses a different chat format (a hedged sketch of an entry follows this list).
72 | - Increase `max_position_embeddings` in `config.json` if you want to run inference at lengths longer than the model's defined maximum.
73 | - (Optional) If you are using [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main), please build your model engine based on their example scripts (e.g., [Llama](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama)) with their [Docker container](https://github.com/NVIDIA/TensorRT-LLM/tree/main?tab=readme-ov-file#installation).
74 |
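As a rough illustration of what a template entry looks like, here is a hypothetical addition to `scripts/data/template.py`. The only structure assumed from `scripts/data/prepare.py` is that `Templates` is a dict mapping `MODEL_TEMPLATE_TYPE` to a string containing a `{task_template}` placeholder; the `my-chat` key and its chat markers are placeholders, not a real model format.
```
# scripts/data/template.py (sketch only; `my-chat` and its tags are hypothetical)
Templates = {
    'base': """{task_template}""",

    # Wrap the task prompt in the new model's chat markers.
    'my-chat': """<|user|>\n{task_template}\n<|assistant|>\n""",
}
```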
75 | ### 3. Run evaluation pipeline
76 |
77 | - **Setup `run.sh`**
78 | ```
79 | GPUS="" # number of GPUs
80 | ROOT_DIR="" # the path that stores generated task samples and model predictions.
81 | MODEL_DIR="" # the path that contains individual model folders from Huggingface.
82 | ENGINE_DIR="" # the path that contains individual engine folders from TensorRT-LLM.
83 | ```
84 | - **Setup `config_models.sh`**
85 | ```
86 | case $MODEL_NAME in
87 | YOUR_HF_MODEL_NAME)
88 | MODEL_PATH=${MODEL_DIR}/YOUR_MODEL_FOLDER
89 | MODEL_TEMPLATE_TYPE="" # base, meta-chat, etc. defined in `scripts/data/template.py`
90 | MODEL_FRAMEWORK="" # hf or vllm
91 | ;;
92 | YOUR_TRTLLM_ENGINE_NAME)
93 | MODEL_PATH=${ENGINE_DIR}/YOUR_ENGINE_FOLDER
94 | MODEL_TEMPLATE_TYPE="" # base, meta-chat, etc. defined in `scripts/data/template.py`
95 | MODEL_FRAMEWORK="trtllm"
96 | ;;
97 | YOUR_OPENAI_MODEL_NAME)
98 | MODEL_PATH="" # OpenAI model name listed in https://platform.openai.com/docs/models/
99 | MODEL_TEMPLATE_TYPE="base"
100 | MODEL_FRAMEWORK="openai"
101 | TOKENIZER_PATH="cl100k_base"
102 | TOKENIZER_TYPE="openai"
103 | OPENAI_API_KEY="" # your OpenAI API key
104 | ;;
105 | YOUR_GEMINI_MODEL_NAME)
106 | MODEL_PATH="" # Gemini model name listed in https://ai.google.dev/gemini-api/docs/models/gemini
107 | MODEL_TEMPLATE_TYPE="base"
108 | MODEL_FRAMEWORK="gemini"
109 | TOKENIZER_PATH=$MODEL_PATH
110 | TOKENIZER_TYPE="gemini"
111 | GEMINI_API_KEY="" # your Gemini API key
112 | ;;
113 | ```
114 |
115 | - **Start evaluation based on our default `synthetic` benchmark**
116 | ```
117 | bash run.sh YOUR_MODEL_NAME synthetic
118 | ```
119 |
120 | ## 🧠 (Optional) Customize task complexity
121 | The tasks to be evaluated are listed in `scripts/config_tasks.sh`. The configuration of each task is defined in `scripts/synthetic.yaml`. The complexity of each task can be adjusted via the arguments described in detail below; an example override is sketched after the table.
122 |
123 | | Category |Task name | Configurations |
124 | |:--------------------:|:---------------------------:|--------------------|
125 | | Retrieval | niah |**type_haystack**: `repeat/essay/needle`<br># repeat: repeated noise sentences<br># essay: Paul Graham Essays<br># needle: distracted needles<br>**type_needle_k**: `words/numbers/uuids`<br>**type_needle_v**: `words/numbers/uuids`<br># words: adjective-noun<br># numbers: 7 digits<br># uuids: 32 digits<br>**num_needle_k**: `int >= 1`<br># add multiple needles in haystack<br>**num_needle_v**: `int >= 1`<br># retrieve multiple values from a single key<br>**num_needle_q**: `int >= 1`<br># retrieve multiple values from multiple keys |
126 | | Multi-hop<br>Tracing | variable_tracking | **num_chains**: `int >= 1`<br># number of variable name-binding chains<br>**num_hops**: `int >= 1`<br># number of times binding variable names in each chain |
127 | | Aggregation | common_words_extraction |**freq_cw**: `int >= 1`<br># frequency of common words<br>**freq_ucw**: `int >= 1`<br># frequency of uncommon words<br>**num_cw**: `int >= 1`<br># number of common words |
128 | | Aggregation | freq_words_extraction |**alpha**: `float > 1.0`<br># parameter of the distribution used to draw synthetic words. Reducing alpha increases the difficulty of this task. Note that increasing the number of words to return also increases the difficulty; we use `3` in our evaluations because models show worse performance at short context sizes when more words need to be returned. |
129 | | Question<br>Answering | qa |**dataset**: `squad` or `hotpotqa`<br># the short-context qa dataset we use |
130 |
131 |
132 |
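For example, a harder multi-key NIAH variant could be added as a new entry in `scripts/synthetic.yaml`. The `task`/`args` layout below mirrors how `scripts/data/prepare.py` reads the file; the entry name and argument values are illustrative, not our benchmark defaults:
```
# scripts/synthetic.yaml (illustrative entry, not part of the default benchmark)
niah_multikey_hard:
  task: niah                # script under scripts/data/synthetic to run
  args:
    type_haystack: needle
    type_needle_k: uuids
    type_needle_v: uuids
    num_needle_k: 8
    num_needle_q: 1
```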
133 | ## 🚀 (Optional) Contribute a new synthetic task
134 | ### 1. Create a python script for data preparation
135 | * Add basic arguments (required) and complexity configurations in the python script.
136 | * Verify the script is reproducible given a tokenizer, a sequence length, and a random seed.
137 | * Save the script under the folder `scripts/data/synthetic`.
138 |
139 | ### 2. Add task template
140 | * Add `template` and `tokens_to_generate` in `scripts/data/synthetic/constants.py` (a hedged sketch follows this list).
141 | * Add `answer_prefix` to prevent the model from refusing to answer.
142 |
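Following the structure of the `TASKS` dict in `scripts/data/synthetic/constants.py` (included later in this listing), a hypothetical new entry might look like the sketch below; the task name and prompt wording are placeholders:
```
# scripts/data/synthetic/constants.py (sketch; `my_new_task` is a placeholder)
'my_new_task': {
    'tokens_to_generate': 32,
    'template': """Read the following text carefully.\n{context}\nQuestion: {query}""",
    'answer_prefix': """ Answer:""",
},
```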
143 | ### 3. Add evaluation metric
144 | * Add the automatic metric to evaluate your task in `scripts/eval/synthetic/constants.py`
145 |
146 | ### 4. Add required configurations
147 | * Define your task name and complexity configurations in `scripts/synthetic.yaml`.
148 | * Add your task name in `scripts/config_tasks.sh` (a hypothetical example of both files follows).
149 |
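Putting it together, registering the hypothetical `my_new_task` would take one `scripts/synthetic.yaml` entry (the `task`/`args` keys mirror how `scripts/data/prepare.py` consumes the config) and one line in the `synthetic=( ... )` array of `scripts/config_tasks.sh`; all names and values below are placeholders:
```
# scripts/synthetic.yaml (hypothetical entry)
my_new_task:
  task: my_new_task          # runs scripts/data/synthetic/my_new_task.py
  args:
    difficulty_knob: 1       # placeholder complexity argument

# scripts/config_tasks.sh: add "my_new_task" to the synthetic=( ... ) task list
```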
150 | ## 🛠️ Limitations
151 | While tasks in RULER are designed to be configurable, we only evaluate the above models with 13 task configurations. These tasks were selected because most models can achieve good (some almost perfect) performance at short context sizes (<= 4K), which leaves ample room to observe degradation as we extend the input length. We did not include more complex tasks in RULER on which models show worse performance at short context sizes. We also did not stress test every model with more difficult task configurations. Although RULER covers four task categories, extending previous evaluation protocols, and provides a clean test bed for sanity-checking LMs with known upper-bound performance, it is by no means comprehensive, and it cannot replace more realistic tasks, which remain preferable. We welcome contributions of new tasks and/or new task categories to help evaluate long-context capabilities.
152 |
153 |
154 | ## 📝 Citation
155 | ```
156 | @article{hsieh2024ruler,
157 | title={RULER: What's the Real Context Size of Your Long-Context Language Models?},
158 | author={Cheng-Ping Hsieh and Simeng Sun and Samuel Kriman and Shantanu Acharya and Dima Rekesh and Fei Jia and Yang Zhang and Boris Ginsburg},
159 | year={2024},
160 | journal={arXiv preprint arXiv:2404.06654},
161 | }
162 | ```
163 | Disclaimer: This project is strictly for research purposes, and not an official product from NVIDIA.
164 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | FROM nvcr.io/nvidia/pytorch:23.10-py3
16 |
17 | WORKDIR /workspace/
18 |
19 | COPY ./requirements.txt .
20 | RUN pip install --upgrade pip \
21 | && pip install -r requirements.txt \
22 | && pip install flash-attn==2.6.0.post1 --no-build-isolation \
23 | && pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary \
24 | && pip install causal-conv1d==1.4.0 \
25 | && pip install mamba-ssm==2.2.2
26 |
27 | RUN [ "python3", "-c", "import nltk; nltk.download('punkt')"]
28 |
29 |
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | nemo-toolkit[all]
2 | tritonclient[all]
3 | transformer_engine[pytorch]
4 | flask
5 | flask_restful
6 | html2text
7 | google-generativeai
8 | sshtunnel_requests
9 | wonderwords
10 | openai
11 | tiktoken
12 | tenacity
13 | accelerate
14 | huggingface_hub==0.23.4
15 | transformers==4.44.2
16 | vllm==0.5.4
--------------------------------------------------------------------------------
/scripts/config_models.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | TEMPERATURE="0.0" # greedy
16 | TOP_P="1.0"
17 | TOP_K="32"
18 | SEQ_LENGTHS=(
19 | 131072
20 | 65536
21 | 32768
22 | 16384
23 | 8192
24 | 4096
25 | )
26 |
27 | MODEL_SELECT() {
28 | MODEL_NAME=$1
29 | MODEL_DIR=$2
30 | ENGINE_DIR=$3
31 |
32 | case $MODEL_NAME in
33 | llama2-7b-chat)
34 | MODEL_PATH="${MODEL_DIR}/llama2-7b-chat-hf"
35 | MODEL_TEMPLATE_TYPE="meta-chat"
36 | MODEL_FRAMEWORK="vllm"
37 | ;;
38 | llama3.1-8b-chat)
39 | MODEL_PATH="${MODEL_DIR}/llama3.1-8b-Instruct"
40 | MODEL_TEMPLATE_TYPE="meta-llama3"
41 | MODEL_FRAMEWORK="vllm"
42 | ;;
43 | jamba1.5-mini)
44 | MODEL_PATH="${MODEL_DIR}/Jamba-1.5-Mini"
45 | MODEL_TEMPLATE_TYPE="jamba"
46 | MODEL_FRAMEWORK="vllm"
47 | ;;
48 | gpt-3.5-turbo)
49 | MODEL_PATH="gpt-3.5-turbo-0125"
50 | MODEL_TEMPLATE_TYPE="base"
51 | MODEL_FRAMEWORK="openai"
52 | TOKENIZER_PATH="cl100k_base"
53 | TOKENIZER_TYPE="openai"
54 | OPENAI_API_KEY=""
55 | AZURE_ID=""
56 | AZURE_SECRET=""
57 | AZURE_ENDPOINT=""
58 | ;;
59 | gpt-4-turbo)
60 | MODEL_PATH="gpt-4"
61 | MODEL_TEMPLATE_TYPE="base"
62 | MODEL_FRAMEWORK="openai"
63 | TOKENIZER_PATH="cl100k_base"
64 | TOKENIZER_TYPE="openai"
65 | OPENAI_API_KEY=""
66 | AZURE_ID=""
67 | AZURE_SECRET=""
68 | AZURE_ENDPOINT=""
69 | ;;
70 | gemini_1.0_pro)
71 | MODEL_PATH="gemini-1.0-pro-latest"
72 | MODEL_TEMPLATE_TYPE="base"
73 | MODEL_FRAMEWORK="gemini"
74 | TOKENIZER_PATH=$MODEL_PATH
75 | TOKENIZER_TYPE="gemini"
76 | GEMINI_API_KEY=""
77 | ;;
78 | gemini_1.5_pro)
79 | MODEL_PATH="gemini-1.5-pro-latest"
80 | MODEL_TEMPLATE_TYPE="base"
81 | MODEL_FRAMEWORK="gemini"
82 | TOKENIZER_PATH=$MODEL_PATH
83 | TOKENIZER_TYPE="gemini"
84 | GEMINI_API_KEY=""
85 | ;;
86 | esac
87 |
88 |
89 | if [ -z "${TOKENIZER_PATH}" ]; then
90 | if [ -f ${MODEL_PATH}/tokenizer.model ]; then
91 | TOKENIZER_PATH=${MODEL_PATH}/tokenizer.model
92 | TOKENIZER_TYPE="nemo"
93 | else
94 | TOKENIZER_PATH=${MODEL_PATH}
95 | TOKENIZER_TYPE="hf"
96 | fi
97 | fi
98 |
99 |
100 | echo "$MODEL_PATH:$MODEL_TEMPLATE_TYPE:$MODEL_FRAMEWORK:$TOKENIZER_PATH:$TOKENIZER_TYPE:$OPENAI_API_KEY:$GEMINI_API_KEY:$AZURE_ID:$AZURE_SECRET:$AZURE_ENDPOINT"
101 | }
102 |
--------------------------------------------------------------------------------
/scripts/config_tasks.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | NUM_SAMPLES=500
16 | REMOVE_NEWLINE_TAB=false
17 | STOP_WORDS=""
18 |
19 | if [ -z "${STOP_WORDS}" ]; then
20 | STOP_WORDS=""
21 | else
22 | STOP_WORDS="--stop_words \"${STOP_WORDS}\""
23 | fi
24 |
25 | if [ "${REMOVE_NEWLINE_TAB}" = false ]; then
26 | REMOVE_NEWLINE_TAB=""
27 | else
28 | REMOVE_NEWLINE_TAB="--remove_newline_tab"
29 | fi
30 |
31 | # task name in `synthetic.yaml`
32 | synthetic=(
33 | "niah_single_1"
34 | "niah_single_2"
35 | "niah_single_3"
36 | "niah_multikey_1"
37 | "niah_multikey_2"
38 | "niah_multikey_3"
39 | "niah_multivalue"
40 | "niah_multiquery"
41 | "vt"
42 | "cwe"
43 | "fwe"
44 | "qa_1"
45 | "qa_2"
46 | )
47 |
--------------------------------------------------------------------------------
/scripts/data/prepare.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Prepare jsonl with fields `input` and `outputs`.
17 | {
18 |     "index": int,
19 | "input": str,
20 | "outputs": [str],
21 | }
22 |
23 | python prepare.py \
24 | --save_dir ./ \
25 | --benchmark synthetic \
26 | --task niah_single_1 \
27 | --tokenizer_path tokenizer.model \
28 | --tokenizer_type nemo \
29 | --max_seq_length 4096 \
30 | --model_template_type base \
31 | --num_samples 10 \
32 | """
33 | import os
34 | import argparse
35 | import importlib
36 | import subprocess
37 | import time
38 | import yaml
39 | from pathlib import Path
40 | from template import Templates
41 | import nltk
42 | try:
43 | nltk.data.find('tokenizers/punkt')
44 | nltk.data.find('tokenizers/punkt_tab')
45 | except LookupError:
46 | nltk.download('punkt')
47 | nltk.download('punkt_tab')
48 |
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
51 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]')
52 | parser.add_argument("--task", type=str, required=True, help='tasks in benchmark')
53 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
54 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
55 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.')
56 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.')
57 | parser.add_argument("--num_samples", type=int, default=500, help='maximum number of samples we want to test')
58 | parser.add_argument("--random_seed", type=int, default=42)
59 | parser.add_argument("--model_template_type", type=str, default='base', help='Options in `template.py`')
60 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.')
61 | parser.add_argument("--chunk_idx", type=int, default=0, help='index of current split chunk')
62 | parser.add_argument("--chunk_amount", type=int, default=1, help='size of split chunk')
63 | parser.add_argument("--prepare_for_ns", action='store_true')
64 |
65 | args = parser.parse_args()
66 |
67 | def main():
68 | start_time = time.time()
69 | curr_folder = os.path.dirname(os.path.abspath(__file__))
70 |
71 | try:
72 | module = importlib.import_module(f"{args.benchmark}.constants")
73 | except ImportError:
74 | print(f"Module data.{args.benchmark}.constants not found.")
75 |
76 | tasks_base = module.TASKS
77 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f:
78 | tasks_customized = yaml.safe_load(f)
79 |
80 | if args.task not in tasks_customized:
81 |         raise ValueError(f'{args.task} is not found in {args.benchmark}.yaml')
82 |
83 | config = tasks_customized.get(args.task)
84 | config.update(tasks_base[config['task']])
85 |
86 | # Add templates
87 |     assert args.model_template_type in Templates, f'{args.model_template_type} is not found in {list(Templates.keys())}'
88 | model_template = Templates[args.model_template_type]
89 |
90 | if args.prepare_for_ns:
91 | from tokenizer import select_tokenizer
92 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path)
93 | model_template_token = len(TOKENIZER.text_to_tokens(model_template))
94 | model_template = Templates['base']
95 |
96 | task_template = config['template']
97 |
98 | # Add answer prefix for all models
99 | answer_prefix = config['answer_prefix'] if 'answer_prefix' in config else ''
100 |
101 | config['template'] = model_template.format(task_template=task_template) + answer_prefix
102 |
103 | # Split task into multiple chunks
104 | chunks = [(args.num_samples // args.chunk_amount) + (1 if i < args.num_samples % args.chunk_amount else 0) for i in range(args.chunk_amount)]
105 | num_samples = chunks[args.chunk_idx]
106 | pre_samples = sum(chunks[:args.chunk_idx])
107 |
108 | random_seed = args.random_seed + args.chunk_idx
109 |
110 |
111 | save_file = args.save_dir / args.task / f"{args.subset}.jsonl"
112 | file_exists = False
113 | if os.path.exists(save_file):
114 | with open(save_file, "r") as f:
115 | data = f.readlines()
116 | if len(data) == args.num_samples: file_exists = True
117 |
118 |
119 |
120 | if not file_exists:
121 | try:
122 | script = os.path.join(curr_folder, args.benchmark, f"{config['task']}.py")
123 | additional_args = " ".join([f"--{k} {v}" for k, v in config['args'].items()])
124 | command = f"""python {script} \
125 | --save_dir {args.save_dir} \
126 | --save_name {args.task} \
127 | --subset {args.subset} \
128 | --tokenizer_path {args.tokenizer_path} \
129 | --tokenizer_type {args.tokenizer_type} \
130 | --max_seq_length {args.max_seq_length} \
131 | --tokens_to_generate {config['tokens_to_generate']} \
132 | --num_samples {num_samples} \
133 | --random_seed {random_seed} \
134 | {additional_args} \
135 | {f"--remove_newline_tab" if args.remove_newline_tab else ""} \
136 | {f"--pre_samples {pre_samples}" if config['task'] == 'qa' else ""} \
137 | --template "{config['template']}" \
138 | """
139 | if args.prepare_for_ns:
140 | command += f""" --model_template_token {model_template_token}"""
141 |
142 | print(command)
143 | result = subprocess.run(command,
144 | shell=True,
145 | check=True,
146 | stdout=subprocess.PIPE,
147 | stderr=subprocess.PIPE,
148 | text=True)
149 |
150 | if result.returncode == 0:
151 | print("Output:")
152 | print(result.stdout)
153 | else:
154 | print("Error:")
155 | print(result.stderr)
156 | except subprocess.CalledProcessError as e:
157 | print("Error output:", e.stderr)
158 |
159 | print(f"Prepare {args.task} with lines: {args.num_samples} to {save_file}")
160 | print(f"Used time: {round((time.time() - start_time) / 60, 1)} minutes")
161 | else:
162 | print(f"Skip preparing {args.task} with lines: {args.num_samples} to {save_file} (file exists)")
163 |
164 | if __name__ == '__main__':
165 | main()
166 |
--------------------------------------------------------------------------------
/scripts/data/synthetic/common_words_extraction.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 |
15 | """
16 | Create a dataset jsonl file for common words extraction.
17 |
18 | python common_words_extraction.py \
19 | --save_dir=./ \
20 |     --save_name=cwe \
21 | --tokenizer_path=tokenizer.model \
22 | --tokenizer_type nemo \
23 | --max_seq_length 4096 \
24 | --tokens_to_generate 30 \
25 | --num_samples 10 \
26 | --random_seed 42 \
27 |     --freq_cw 30 --freq_ucw 3 --num_cw 10 \
28 | --template "[INST] Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list? [/INST] Answer: The top 10 words that appear most often in the list are:"
29 | """
30 |
31 | import os
32 | import argparse
33 | from pathlib import Path
34 | from tqdm import tqdm
35 | import random
36 | import wonderwords
37 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
38 | import sys
39 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
40 | from tokenizer import select_tokenizer
41 | import json
42 | import logging
43 | from constants import TASKS
44 | logging.basicConfig(level=logging.INFO, force=True)
45 | logger = logging.getLogger(__name__)
46 |
47 |
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
50 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file')
51 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
52 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
53 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.')
54 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.')
55 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.')
56 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate')
57 | parser.add_argument("--random_seed", type=int, default=42)
58 | parser.add_argument("--template", type=str, default='', help='prompt template')
59 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.')
60 |
61 | parser.add_argument("--freq_cw", type=int, default=30)
62 | parser.add_argument("--freq_ucw", type=int, default=3)
63 | parser.add_argument("--num_cw", type=int, default=10)
64 | parser.add_argument("--num_fewshot", type=int, default=1)
65 | parser.add_argument("--model_template_token", type=int, default=0, help='used for nemo skills, minus num of model template token')
66 | args = parser.parse_args()
67 | random.seed(args.random_seed)
68 |
69 | # Load Tokenizer
70 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path)
71 |
72 | nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt")
73 | adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt")
74 | verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt")
75 | words = nouns + adjs + verbs
76 | words = sorted(list(set(words)))
77 | random.Random(args.random_seed).shuffle(words)
78 | logger.info(f'loaded {len(words)} wonderwords')
79 |
80 | # Randleword english words
81 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/english_words.json") , "r") as f:
82 | randle_words = list(json.load(f).values())
83 | logger.info(f'loaded {len(randle_words)} randle words')
84 |
85 | def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10):
86 | if num_words <= len(words):
87 | word_list_full = random.sample(words, num_words)
88 | else:
89 | word_list_full = random.sample(randle_words, num_words)
90 |
91 | common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:]
92 | word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats)
93 | random.Random(args.random_seed).shuffle(word_list)
94 |
95 | # Formatting the word list as "1. word1 2. word2 3. word3 ..."
96 | context = ' '.join([f"{i + 1}. {word}" for i, word in enumerate(word_list)])
97 |
98 | return context, common
99 |
100 | def generate_input_output(num_words):
101 | few_shots = []
102 | if args.max_seq_length < 4096:
103 | for _ in range(args.num_fewshot):
104 | context_example, answer_example = get_example(20, 3, 1, args.num_cw)
105 | few_shots.append((context_example, answer_example))
106 | context, answer = get_example(num_words, 6, 1, args.num_cw)
107 | else:
108 | for _ in range(args.num_fewshot):
109 | context_example, answer_example = get_example(40, 10, 3, args.num_cw)
110 | few_shots.append((context_example, answer_example))
111 | context, answer = get_example(num_words, args.freq_cw, args.freq_ucw, args.num_cw)
112 |
113 | template = args.template
114 |
115 | for n in range(len(few_shots)):
116 | few_shots[n] = template.format(
117 | num_cw=args.num_cw,
118 | context=few_shots[n][0],
119 | query='',
120 | ) + ' ' + ' '.join([f"{i + 1}. {word}" for i, word in enumerate(few_shots[n][1])])
121 |
122 | few_shots = "\n".join(few_shots)
123 | input_text = template.format(
124 | num_cw=args.num_cw,
125 | context=context,
126 | query='',
127 | )
128 |
129 | return few_shots + "\n" + input_text, answer
130 |
131 | def sys_word_pair_random(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10):
132 | write_jsons = []
133 | tokens_to_generate = args.tokens_to_generate
134 | max_seq_length -= args.model_template_token
135 |
136 |
137 | # Estimate tokens per question to determine reasonable upper bound
138 | sample_input_text, _ = generate_input_output(incremental)
139 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text))
140 | tokens_per_words = sample_tokens / incremental
141 |
142 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling.
143 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable.
144 | estimated_max_words = int((max_seq_length / tokens_per_words) * 3)
145 |
146 | # Binary search for optimal haystack size
147 | lower_bound = incremental
148 | upper_bound = max(estimated_max_words, incremental * 2) # Ensure upper_bound is reasonable
149 |
150 | optimal_num_words = None
151 |
152 | logger.info(f"Estimated {tokens_per_words:.1f} tokens per haystack")
153 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}")
154 | while lower_bound <= upper_bound:
155 | mid = (lower_bound + upper_bound) // 2
156 | input_text, answer = generate_input_output(mid)
157 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
158 |
159 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}")
160 |
161 | if total_tokens <= max_seq_length:
162 | # This size works, can we go larger?
163 | optimal_num_words = mid
164 | lower_bound = mid + 1
165 | else:
166 | # Too large, need to go smaller
167 | upper_bound = mid - 1
168 |
169 | num_words = optimal_num_words if optimal_num_words is not None else incremental
170 | logger.info(f'Final optimal haystack size (number of haystack): {num_words}')
171 |
172 |
173 | # Generate samples
174 | for index in tqdm(range(num_samples)):
175 | used_words = num_words
176 | while(True):
177 | try:
178 | input_text, answer = generate_input_output(used_words)
179 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
180 | assert length <= max_seq_length, f"{length} exceeds max_seq_length."
181 | break
182 | except:
183 | if used_words > incremental:
184 | used_words -= incremental
185 |
186 | if args.remove_newline_tab:
187 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split())
188 |
189 | answer_prefix_index = input_text.rfind(TASKS['common_words_extraction']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it
190 | answer_prefix = input_text[answer_prefix_index:]
191 | input_text = input_text[:answer_prefix_index]
192 |
193 | formatted_output = {
194 | 'index': index,
195 | "input": input_text,
196 | "outputs": answer,
197 | "length": length,
198 | 'length_w_model_temp': length + args.model_template_token,
199 | 'answer_prefix': answer_prefix,
200 | }
201 | write_jsons.append(formatted_output)
202 |
203 | return write_jsons
204 |
205 |
206 | def main():
207 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl'
208 | save_file.parent.mkdir(parents=True, exist_ok=True)
209 |
210 | write_jsons = sys_word_pair_random(num_samples=args.num_samples, max_seq_length=args.max_seq_length, save_dir=args.save_dir)
211 |
212 | write_manifest(save_file, write_jsons)
213 |
214 | if __name__=="__main__":
215 | main()
--------------------------------------------------------------------------------
/scripts/data/synthetic/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 |
15 | """
16 | Add a new task (required arguments):
17 |
18 | TASK_NAME: {
19 | 'tokens_to_generate': how many tokens we want to generate.
20 | 'template': the template with at least {context} and {query}.
21 | }
22 | """
23 |
24 | TASKS = {
25 | 'niah': {
26 | 'tokens_to_generate': 128,
27 | 'template': """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?""",
28 | 'answer_prefix': """ The special magic {type_needle_v} for {query} mentioned in the provided text are"""
29 | },
30 |
31 | 'variable_tracking': {
32 | 'tokens_to_generate': 30,
33 | 'template': """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""",
34 | 'answer_prefix': """ Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assigned the value {query}, they are: """
35 | },
36 |
37 | 'common_words_extraction': {
38 | 'tokens_to_generate': 120,
39 | 'template': """Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?""",
40 | 'answer_prefix': """ Answer: The top 10 words that appear most often in the list are:"""
41 | },
42 |
43 | 'freq_words_extraction' : {
44 | 'tokens_to_generate': 50,
45 | 'template': """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""",
46 | 'answer_prefix': """ Answer: According to the coded text above, the three most frequently appeared words are:"""
47 | },
48 |
49 | 'qa': {
50 | 'tokens_to_generate': 32,
51 | 'template': """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
52 | 'answer_prefix': """ Answer:""",
53 | },
54 | }
--------------------------------------------------------------------------------
/scripts/data/synthetic/freq_words_extraction.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 |
15 | """
16 | Create a dataset jsonl file for frequent words extraction.
17 |
18 | python freq_words_extraction.py \
19 | --save_dir=./ \
20 |     --save_name=fwe \
21 | --tokenizer_path=tokenizer.model \
22 | --tokenizer_type nemo \
23 | --max_seq_length 4096 \
24 | --tokens_to_generate 30 \
25 | --num_samples 10 \
26 | --random_seed 42 \
27 | --alpha 2.0 \
28 | --template "[INST] Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text? [/INST] Answer: According to the coded text above, the three most frequently appeared words are:"
29 | """
30 |
31 | import os
32 | import argparse
33 | from pathlib import Path
34 | from tqdm import tqdm
35 | import random
36 | import string
37 | import numpy as np
38 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
39 | import sys
40 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
41 | from tokenizer import select_tokenizer
42 | from scipy.special import zeta
43 | import logging
44 |
45 | logging.basicConfig(level=logging.INFO)
46 | logger = logging.getLogger(__name__)
47 | from constants import TASKS
48 |
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
51 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file')
52 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
53 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
54 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.')
55 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.')
56 | parser.add_argument("--tokens_to_generate", type=int, default=50, help='number of tokens to generate')
57 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate')
58 | parser.add_argument("--random_seed", type=int, default=42)
59 | parser.add_argument("--template", type=str, default='', help='prompt template')
60 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.')
61 | parser.add_argument("--coded_wordlen", type=int, default=6, help="length of synthetic word")
62 | parser.add_argument("--vocab_size", type=int, default=-1, help='synthetic vocab size to sample from')
63 | parser.add_argument("--alpha", type=float, default=2.0, help='zeta distribution alpha')
64 | parser.add_argument("--add_fewshot", action="store_true", default=False)
65 | parser.add_argument("--model_template_token", type=int, default=0, help='number of model template tokens to subtract from max_seq_length (used for nemo skills)')
66 |
67 | args = parser.parse_args()
68 | random.seed(args.random_seed)
69 | np.random.seed(args.random_seed)
70 |
71 | # Load Tokenizer
72 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path)
73 |
74 | def generate_input_output(max_len, num_words=-1, coded_wordlen=6, vocab_size=2000, incremental=10, alpha=2.0):
75 | # generate vocab
76 | vocab = [''.join(random.choices(string.ascii_lowercase, k=coded_wordlen)) for _ in range(vocab_size)]
77 | while len(set(vocab)) < vocab_size:
78 | vocab.append(''.join(random.choices(string.ascii_lowercase, k=coded_wordlen)))
79 | vocab = sorted(list(set(vocab)))
80 | random.Random(args.random_seed).shuffle(vocab)
81 | vocab[0] = '...' # treat the top ranked as noise
82 |
83 | # sample words
84 | template = args.template
85 | def gen_text(num_words):
86 | k = np.arange(1, len(vocab)+1)
87 | sampled_cnt = num_words*(k**-alpha)/zeta(alpha)
88 | sampled_words = [[w] * zi for w, zi in zip(vocab, sampled_cnt.astype(int))]
89 | sampled_words = [x for wlst in sampled_words for x in wlst]
90 | random.Random(args.random_seed).shuffle(sampled_words)
91 | return template.format(context=' '.join(sampled_words), query=''), vocab[1:4]
92 |
93 | if num_words > 0:
94 |         # start from the requested number of words; trim below until the text fits max_len
95 | text, answer = gen_text(num_words)
96 | while len(TOKENIZER.text_to_tokens(text)) > max_len:
97 | num_words -= incremental
98 | text, answer = gen_text(num_words)
99 | else:
100 | num_words = max_len // coded_wordlen # init
101 | text, answer = gen_text(num_words)
102 | while len(TOKENIZER.text_to_tokens(text)) < max_len:
103 | num_words += incremental
104 | text, answer = gen_text(num_words)
105 | num_words -= incremental
106 | text, answer = gen_text(num_words)
107 | return text, answer, num_words
108 |
109 | def sys_kwext(num_samples: int, max_seq_length: int, incremental: int = 10):
110 | write_jsons = []
111 | tokens_to_generate = args.tokens_to_generate
112 |
113 | max_seq_length -= args.model_template_token
114 | vocab_size = max_seq_length // 50 if args.vocab_size == -1 else args.vocab_size
115 |
116 | # get number of words
117 | input_max_len = max_seq_length
118 | _, _, num_example_words = generate_input_output(input_max_len,
119 | coded_wordlen=args.coded_wordlen,
120 | vocab_size=vocab_size,
121 | incremental=input_max_len//32,
122 | alpha=args.alpha)
123 |     logger.info(f'num_example_words: {num_example_words}')
124 | # Generate samples
125 | for index in tqdm(range(num_samples)):
126 |
127 | # construct input
128 | input_max_len = max_seq_length
129 | input_text, answer, _ = generate_input_output(input_max_len,
130 | num_words=num_example_words,
131 | coded_wordlen=args.coded_wordlen,
132 | vocab_size=vocab_size,
133 | incremental=input_max_len//32,
134 | alpha=args.alpha)
135 |
136 |
137 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
138 |
139 | if args.remove_newline_tab:
140 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split())
141 |
142 | answer_prefix_index = input_text.rfind(TASKS['freq_words_extraction']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it
143 | answer_prefix = input_text[answer_prefix_index:]
144 | input_text = input_text[:answer_prefix_index]
145 | formatted_output = {
146 | 'index': index,
147 | "input": input_text,
148 | "outputs": answer,
149 | "length": length,
150 | 'length_w_model_temp': length + args.model_template_token,
151 | 'answer_prefix': answer_prefix,
152 | }
153 | write_jsons.append(formatted_output)
154 |
155 | return write_jsons
156 |
157 |
158 | def main():
159 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl'
160 | save_file.parent.mkdir(parents=True, exist_ok=True)
161 | write_jsons = sys_kwext(num_samples=args.num_samples, max_seq_length=args.max_seq_length,
162 | incremental=10)
163 |
164 | write_manifest(save_file, write_jsons)
165 |
166 | if __name__=="__main__":
167 | main()
--------------------------------------------------------------------------------
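
generate_input_output above draws each coded word's count from a zeta (Zipf-like) distribution, so the reserved '...' noise token dominates and the answer is always the words ranked 2-4. A small standalone sketch of that count calculation, with illustrative parameter values:

import numpy as np
from scipy.special import zeta

# Count for the k-th ranked vocab entry: num_words * k**(-alpha) / zeta(alpha).
alpha, num_words, vocab_size = 2.0, 1000, 20
k = np.arange(1, vocab_size + 1)
counts = (num_words * k**-alpha / zeta(alpha)).astype(int)
print(counts[:5])  # roughly [607 151 67 37 24]: rank 1 is the '...' noise token, ranks 2-4 are the answer
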
/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt:
--------------------------------------------------------------------------------
1 | http://www.paulgraham.com/13sentences.html
2 | http://www.paulgraham.com/5founders.html
3 | http://www.paulgraham.com/6631327.html
4 | http://www.paulgraham.com/95.html
5 | http://www.paulgraham.com/ace.html
6 | http://www.paulgraham.com/airbnb.html
7 | http://www.paulgraham.com/airbnbs.html
8 | http://www.paulgraham.com/alien.html
9 | http://www.paulgraham.com/altair.html
10 | http://www.paulgraham.com/ambitious.html
11 | http://www.paulgraham.com/america.html
12 | http://www.paulgraham.com/angelinvesting.html
13 | http://www.paulgraham.com/artistsship.html
14 | http://www.paulgraham.com/badeconomy.html
15 | http://www.paulgraham.com/better.html
16 | http://www.paulgraham.com/bronze.html
17 | http://www.paulgraham.com/bubble.html
18 | http://www.paulgraham.com/charisma.html
19 | http://www.paulgraham.com/cities.html
20 | http://www.paulgraham.com/college.html
21 | http://www.paulgraham.com/colleges.html
22 | http://www.paulgraham.com/conformism.html
23 | http://www.paulgraham.com/control.html
24 | http://www.paulgraham.com/convergence.html
25 | http://www.paulgraham.com/convince.html
26 | http://www.paulgraham.com/cred.html
27 | http://www.paulgraham.com/credentials.html
28 | http://www.paulgraham.com/determination.html
29 | http://www.paulgraham.com/die.html
30 | http://www.paulgraham.com/disagree.html
31 | http://www.paulgraham.com/disc.html
32 | http://www.paulgraham.com/discover.html
33 | http://www.paulgraham.com/distraction.html
34 | http://www.paulgraham.com/divergence.html
35 | http://www.paulgraham.com/donate.html
36 | http://www.paulgraham.com/ds.html
37 | http://www.paulgraham.com/early.html
38 | http://www.paulgraham.com/earnest.html
39 | http://www.paulgraham.com/equity.html
40 | http://www.paulgraham.com/essay.html
41 | http://www.paulgraham.com/ffb.html
42 | http://www.paulgraham.com/fh.html
43 | http://www.paulgraham.com/fix.html
44 | http://www.paulgraham.com/fn.html
45 | http://www.paulgraham.com/foundersatwork.html
46 | http://www.paulgraham.com/fp.html
47 | http://www.paulgraham.com/fr.html
48 | http://www.paulgraham.com/fundraising.html
49 | http://www.paulgraham.com/future.html
50 | http://www.paulgraham.com/genius.html
51 | http://www.paulgraham.com/getideas.html
52 | http://www.paulgraham.com/good.html
53 | http://www.paulgraham.com/goodart.html
54 | http://www.paulgraham.com/googles.html
55 | http://www.paulgraham.com/greatwork.html
56 | http://www.paulgraham.com/growth.html
57 | http://www.paulgraham.com/guidetoinvestors.html
58 | http://www.paulgraham.com/hackernews.html
59 | http://www.paulgraham.com/head.html
60 | http://www.paulgraham.com/herd.html
61 | http://www.paulgraham.com/heresy.html
62 | http://www.paulgraham.com/heroes.html
63 | http://www.paulgraham.com/highres.html
64 | http://www.paulgraham.com/hiresfund.html
65 | http://www.paulgraham.com/hiring.html
66 | http://www.paulgraham.com/hp.html
67 | http://www.paulgraham.com/hs.html
68 | http://www.paulgraham.com/hundred.html
69 | http://www.paulgraham.com/hw.html
70 | http://www.paulgraham.com/hwh.html
71 | http://www.paulgraham.com/icad.html
72 | http://www.paulgraham.com/ideas.html
73 | http://www.paulgraham.com/identity.html
74 | http://www.paulgraham.com/ineq.html
75 | http://www.paulgraham.com/inequality.html
76 | http://www.paulgraham.com/investors.html
77 | http://www.paulgraham.com/invtrend.html
78 | http://www.paulgraham.com/javacover.html
79 | http://www.paulgraham.com/jessica.html
80 | http://www.paulgraham.com/judgement.html
81 | http://www.paulgraham.com/kate.html
82 | http://www.paulgraham.com/kids.html
83 | http://www.paulgraham.com/ladder.html
84 | http://www.paulgraham.com/lesson.html
85 | http://www.paulgraham.com/lies.html
86 | http://www.paulgraham.com/lwba.html
87 | http://www.paulgraham.com/mac.html
88 | http://www.paulgraham.com/makersschedule.html
89 | http://www.paulgraham.com/marginal.html
90 | http://www.paulgraham.com/maybe.html
91 | http://www.paulgraham.com/mean.html
92 | http://www.paulgraham.com/microsoft.html
93 | http://www.paulgraham.com/mit.html
94 | http://www.paulgraham.com/name.html
95 | http://www.paulgraham.com/nerds.html
96 | http://www.paulgraham.com/newthings.html
97 | http://www.paulgraham.com/noob.html
98 | http://www.paulgraham.com/noop.html
99 | http://www.paulgraham.com/notnot.html
100 | http://www.paulgraham.com/nov.html
101 | http://www.paulgraham.com/nthings.html
102 | http://www.paulgraham.com/opensource.html
103 | http://www.paulgraham.com/organic.html
104 | http://www.paulgraham.com/orth.html
105 | http://www.paulgraham.com/own.html
106 | http://www.paulgraham.com/patentpledge.html
107 | http://www.paulgraham.com/pgh.html
108 | http://www.paulgraham.com/pinch.html
109 | http://www.paulgraham.com/polls.html
110 | http://www.paulgraham.com/power.html
111 | http://www.paulgraham.com/prcmc.html
112 | http://www.paulgraham.com/procrastination.html
113 | http://www.paulgraham.com/progbot.html
114 | http://www.paulgraham.com/prop62.html
115 | http://www.paulgraham.com/property.html
116 | http://www.paulgraham.com/publishing.html
117 | http://www.paulgraham.com/pypar.html
118 | http://www.paulgraham.com/ramenprofitable.html
119 | http://www.paulgraham.com/randomness.html
120 | http://www.paulgraham.com/re.html
121 | http://www.paulgraham.com/read.html
122 | http://www.paulgraham.com/real.html
123 | http://www.paulgraham.com/really.html
124 | http://www.paulgraham.com/relres.html
125 | http://www.paulgraham.com/revolution.html
126 | http://www.paulgraham.com/richnow.html
127 | http://www.paulgraham.com/road.html
128 | http://www.paulgraham.com/ronco.html
129 | http://www.paulgraham.com/safe.html
130 | http://www.paulgraham.com/say.html
131 | http://www.paulgraham.com/schlep.html
132 | http://www.paulgraham.com/seesv.html
133 | http://www.paulgraham.com/segway.html
134 | http://www.paulgraham.com/selfindulgence.html
135 | http://www.paulgraham.com/sfp.html
136 | http://www.paulgraham.com/simply.html
137 | http://www.paulgraham.com/smart.html
138 | http://www.paulgraham.com/softwarepatents.html
139 | http://www.paulgraham.com/spam.html
140 | http://www.paulgraham.com/speak.html
141 | http://www.paulgraham.com/start.html
142 | http://www.paulgraham.com/startupfunding.html
143 | http://www.paulgraham.com/startuphubs.html
144 | http://www.paulgraham.com/startupideas.html
145 | http://www.paulgraham.com/startupmistakes.html
146 | http://www.paulgraham.com/stuff.html
147 | http://www.paulgraham.com/superlinear.html
148 | http://www.paulgraham.com/swan.html
149 | http://www.paulgraham.com/tablets.html
150 | http://www.paulgraham.com/talk.html
151 | http://www.paulgraham.com/taste.html
152 | http://www.paulgraham.com/think.html
153 | http://www.paulgraham.com/top.html
154 | http://www.paulgraham.com/trolls.html
155 | http://www.paulgraham.com/twitter.html
156 | http://www.paulgraham.com/usa.html
157 | http://www.paulgraham.com/users.html
158 | http://www.paulgraham.com/venturecapital.html
159 | http://www.paulgraham.com/wealth.html
160 | http://www.paulgraham.com/webstartups.html
161 | http://www.paulgraham.com/whyyc.html
162 | http://www.paulgraham.com/word.html
163 | http://www.paulgraham.com/words.html
164 | http://www.paulgraham.com/work.html
165 | http://www.paulgraham.com/writing44.html
166 | http://www.paulgraham.com/wtax.html
167 | http://www.paulgraham.com/yahoo.html
168 | http://www.paulgraham.com/ycombinator.html
169 | http://www.paulgraham.com/ycstart.html
170 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/addiction.txt
171 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/aord.txt
172 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/apple.txt
173 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/avg.txt
174 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/before.txt
175 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/bias.txt
176 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/boss.txt
177 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/copy.txt
178 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/corpdev.txt
179 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/desres.txt
180 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/diff.txt
181 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/ecw.txt
182 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/founders.txt
183 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/foundervisa.txt
184 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/gap.txt
185 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/gba.txt
186 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/gh.txt
187 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/goodtaste.txt
188 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/hubs.txt
189 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/iflisp.txt
190 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/island.txt
191 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/know.txt
192 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/langdes.txt
193 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/laundry.txt
194 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/love.txt
195 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/mod.txt
196 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/newideas.txt
197 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/nft.txt
198 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/philosophy.txt
199 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/popular.txt
200 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/pow.txt
201 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/rootsoflisp.txt
202 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/rss.txt
203 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/siliconvalley.txt
204 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/startuplessons.txt
205 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/submarine.txt
206 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/sun.txt
207 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/superangels.txt
208 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/todo.txt
209 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/unions.txt
210 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/useful.txt
211 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/vb.txt
212 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/vcsqueeze.txt
213 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/vw.txt
214 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/want.txt
215 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/web20.txt
216 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/weird.txt
217 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/wisdom.txt
218 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/worked.txt
--------------------------------------------------------------------------------
/scripts/data/synthetic/json/download_paulgraham_essay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 |
15 | import os
16 | import shutil
17 | import glob
18 | import json
19 | import urllib.request
20 | import html2text
21 | from bs4 import BeautifulSoup
22 | from tqdm import tqdm
23 |
24 | temp_folder_repo = 'essay_repo'
25 | temp_folder_html = 'essay_html'
26 | os.makedirs(temp_folder_repo, exist_ok=True)
27 | os.makedirs(temp_folder_html, exist_ok=True)
28 |
29 | h = html2text.HTML2Text()
30 | h.ignore_images = True
31 | h.ignore_tables = True
32 | h.escape_all = True
33 | h.reference_links = False
34 | h.mark_code = False
35 |
36 | with open('PaulGrahamEssays_URLs.txt') as f:
37 | urls = [line.strip() for line in f]
38 |
39 | for url in tqdm(urls):
40 | if '.html' in url:
41 | filename = url.split('/')[-1].replace('.html', '.txt')
42 | try:
43 | with urllib.request.urlopen(url) as website:
44 | content = website.read().decode("unicode_escape", "utf-8")
45 | soup = BeautifulSoup(content, 'html.parser')
46 | specific_tag = soup.find('font')
47 | parsed = h.handle(str(specific_tag))
48 |
49 | with open(os.path.join(temp_folder_html, filename), 'w') as file:
50 | file.write(parsed)
51 |
52 | except Exception as e:
53 |             print(f"Failed to download {filename} ({e})")
54 |
55 | else:
56 | filename = url.split('/')[-1]
57 | try:
58 | with urllib.request.urlopen(url) as website:
59 | content = website.read().decode('utf-8')
60 |
61 | with open(os.path.join(temp_folder_repo, filename), 'w') as file:
62 | file.write(content)
63 |
64 | except Exception as e:
65 |             print(f"Failed to download {filename} ({e})")
66 |
67 | files_repo = sorted(glob.glob(os.path.join(temp_folder_repo,'*.txt')))
68 | files_html = sorted(glob.glob(os.path.join(temp_folder_html,'*.txt')))
69 | print(f'Downloaded {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`')
70 | print(f'Downloaded {len(files_html)} essays from `http://www.paulgraham.com/`')
71 |
72 | text = ""
73 | for file in files_repo + files_html:
74 | with open(file, 'r') as f:
75 | text += f.read()
76 |
77 | with open('PaulGrahamEssays.json', 'w') as f:
78 | json.dump({"text": text}, f)
79 |
80 |
81 | shutil.rmtree(temp_folder_repo)
82 | shutil.rmtree(temp_folder_html)
83 |
--------------------------------------------------------------------------------
/scripts/data/synthetic/json/download_qa_dataset.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O squad.json
16 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json -O hotpotqa.json
17 |
--------------------------------------------------------------------------------
/scripts/data/synthetic/json/english_words.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:affcd6d45fdf3cc843d585c99c97ad615094e760e6c4756b654bab6c73bc2eca
3 | size 8564991
4 |
--------------------------------------------------------------------------------
/scripts/data/synthetic/niah.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 |
15 | """
16 | Create a dataset jsonl file for needle in a haystack.
17 |
18 | python niah.py \
19 | --save_dir=./ \
20 | --save_name=niah_single \
21 | --tokenizer_path=tokenizer.model \
22 | --tokenizer_type=nemo \
23 | --max_seq_length=4096 \
24 | --tokens_to_generate=128 \
25 | --num_samples=10 \
26 | --template="Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? The special magic {type_needle_v} for {query} mentioned in the provided text are"
27 | """
28 | import os
29 | import re
30 | import json
31 | import uuid
32 | import argparse
33 | import numpy as np
34 | from pathlib import Path
35 | from tqdm import tqdm
36 | import random
37 | import wonderwords
38 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
39 | import sys
40 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
41 | from tokenizer import select_tokenizer
42 | from nltk.tokenize import sent_tokenize
43 | import logging
44 |
45 | logging.basicConfig(level=logging.INFO, force=True)
46 | logger = logging.getLogger(__name__)
47 |
48 |
49 | from constants import TASKS
50 |
51 | parser = argparse.ArgumentParser()
52 | # Basic Configurations
53 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
54 | parser.add_argument("--save_name", type=str, required=True, help='name of the saved dataset jsonl file')
55 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
56 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
57 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.')
58 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.')
59 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected number of tokens to generate.')
60 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate')
61 | parser.add_argument("--random_seed", type=int, default=42)
62 | parser.add_argument("--template", type=str, default='', help='prompt template')
63 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.')
64 |
65 | # Complexity Configurations
66 | parser.add_argument("--num_needle_k", type=int, default=1)
67 | parser.add_argument("--num_needle_v", type=int, default=1)
68 | parser.add_argument("--num_needle_q", type=int, default=1)
69 | parser.add_argument("--type_haystack", type=str, default='essay', help='[Options] noise, essay, needle.')
70 | parser.add_argument("--type_needle_k", type=str, default='words', help='[Options] numbers, words, uuids.')
71 | parser.add_argument("--type_needle_v", type=str, default='numbers', help='[Options] numbers, words, uuids.')
72 | parser.add_argument("--model_template_token", type=int, default=0, help='number of model template tokens to subtract from max_seq_length (used for nemo skills)')
73 |
74 | args = parser.parse_args()
75 | random.seed(args.random_seed)
76 | np.random.seed(args.random_seed)
77 | args.num_needle_k = max(args.num_needle_k, args.num_needle_q)
78 |
79 | # Load Tokenizer
80 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path)
81 |
82 | # Define Needle/Haystack Format
83 | needle = "One of the special magic {type_needle_v} for {key} is: {value}."
84 | if args.type_haystack == 'essay':
85 | essay = os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/PaulGrahamEssays.json")
86 | essay = json.load(open(essay))['text']
87 | haystack = re.sub(r'\s+', " ", essay).split(" ")
88 | elif args.type_haystack == 'noise':
89 | haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
90 | elif args.type_haystack == 'needle':
91 | haystack = needle
92 | else:
93 | raise NotImplementedError(f'{args.type_haystack} is not implemented.')
94 |
95 |
96 | # Words
97 | nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt")
98 | adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt")
99 | # verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt")
100 | words = [f"{adj}-{noun}" for adj in adjs for noun in nouns]
101 | words = sorted(list(set(words)))
102 |
103 |
104 | # Positions
105 | DEPTHS = list(np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int))
106 |
107 |
108 | def generate_random_number(num_digits=7):
109 | lower_bound = 10**(num_digits - 1)
110 | upper_bound = 10**num_digits - 1
111 | return str(random.randint(lower_bound, upper_bound))
112 |
113 | def generate_random_word():
114 | word = random.choice(words)
115 | return word
116 |
117 | def generate_random_uuid():
118 | return str(uuid.UUID(int=random.getrandbits(128), version=4))
119 |
120 | def generate_random(type_needle: str):
121 | if type_needle == 'numbers':
122 | return generate_random_number()
123 | elif type_needle == 'words':
124 | return generate_random_word()
125 | elif type_needle == 'uuids':
126 | return generate_random_uuid()
127 | else:
128 |         raise NotImplementedError(f'{type_needle} is not implemented.')
129 |
130 | def generate_input_output(num_haystack):
131 | keys, values, needles = [], [], []
132 | for _ in range(args.num_needle_k):
133 | keys.append(generate_random(args.type_needle_k))
134 | value = []
135 | for _ in range(args.num_needle_v):
136 | value.append(generate_random(args.type_needle_v))
137 | needles.append(needle.format(
138 | type_needle_v=args.type_needle_v,
139 | key=keys[-1],
140 | value=value[-1],
141 | ))
142 | values.append(value)
143 |
144 | random.Random(args.random_seed).shuffle(needles)
145 |
146 | # Context
147 | if args.type_haystack == 'essay':
148 |         # build the essay context; repeat the corpus below if num_haystack exceeds its length
149 | if num_haystack <= len(haystack):
150 | text = " ".join(haystack[:num_haystack])
151 | else:
152 | # Repeat haystack as many times as needed and slice to num_haystack
153 | repeats = (num_haystack + len(haystack) - 1) // len(haystack) # Ceiling division
154 | text = " ".join((haystack * repeats)[:num_haystack])
155 | document_sents = sent_tokenize(text.strip())
156 | insertion_positions = [0] + \
157 | sorted([int(len(document_sents) * (depth / 100)) for depth in random.sample(DEPTHS, len(needles))]) + \
158 | [len(document_sents)]
159 | document_sents_list = []
160 | for i in range(1,len(insertion_positions)):
161 | last_pos = insertion_positions[i-1]
162 | next_pos = insertion_positions[i]
163 | document_sents_list.append(" ".join(document_sents[last_pos:next_pos]))
164 | if i-1 < len(needles):
165 | document_sents_list.append(needles[i-1])
166 | context = " ".join(document_sents_list)
167 |
168 | else:
169 | if args.type_haystack == 'noise':
170 | sentences = [haystack] * num_haystack
171 | elif args.type_haystack == 'needle':
172 | sentences = [haystack.format(
173 | type_needle_v=args.type_needle_v,
174 | key=generate_random(args.type_needle_k),
175 | value=generate_random(args.type_needle_v),
176 | ) for _ in range(num_haystack)]
177 |
178 |
179 | indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True)
180 | for index, element in zip(indexes, needles):
181 | sentences.insert(index, element)
182 | context = "\n".join(sentences)
183 |
184 |
185 | ## Query and Answer
186 | indices = random.sample(range(args.num_needle_k), args.num_needle_q)
187 | queries = [keys[i] for i in indices]
188 | answers = [a for i in indices for a in values[i]]
189 | query = ', '.join(queries[:-1]) + ', and ' + queries[-1] if len(queries) > 1 else queries[0]
190 |
191 | template = args.template
192 | type_needle_v = args.type_needle_v
193 | if args.num_needle_q * args.num_needle_v == 1:
194 | template = template.replace('Some', 'A')
195 | template = template.replace('are all', 'is')
196 | template = template.replace('are', 'is')
197 | template = template.replace('answers', 'answer')
198 | type_needle_v = type_needle_v[:-1] # remove "s"
199 |
200 | input_text = template.format(
201 | type_needle_v=type_needle_v,
202 | context=context,
203 | query=query,
204 | )
205 |
206 | return input_text, answers
207 |
208 |
209 | def generate_samples(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 500):
210 | write_jsons = []
211 | tokens_to_generate = args.tokens_to_generate
212 | max_seq_length -= args.model_template_token
213 |
214 | if args.type_haystack == 'essay':
215 | incremental = 500
216 | elif args.type_haystack == 'noise':
217 | incremental = 25
218 | elif args.type_haystack == 'needle':
219 | incremental = 25
220 |
221 | if args.type_haystack != 'essay' and args.max_seq_length < 4096:
222 | incremental = 5
223 |
224 | # Estimate tokens per question to determine reasonable upper bound
225 | sample_input_text, _ = generate_input_output(incremental)
226 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text))
227 | tokens_per_haystack = sample_tokens / incremental
228 |
229 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling.
230 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable.
231 | estimated_max_questions = int((max_seq_length / tokens_per_haystack) * 3)
232 |
233 | # Binary search for optimal haystack size
234 | lower_bound = incremental
235 | upper_bound = max(estimated_max_questions, incremental * 2) # Ensure upper_bound is reasonable
236 |
237 | optimal_num_haystack = None
238 |
239 | logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack")
240 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}")
241 |
242 | while lower_bound <= upper_bound:
243 | mid = (lower_bound + upper_bound) // 2
244 | input_text, answer = generate_input_output(mid)
245 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
246 |
247 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}")
248 |
249 | if total_tokens <= max_seq_length:
250 | # This size works, can we go larger?
251 | optimal_num_haystack = mid
252 | lower_bound = mid + 1
253 | else:
254 | # Too large, need to go smaller
255 | upper_bound = mid - 1
256 |
257 | num_haystack = optimal_num_haystack if optimal_num_haystack is not None else incremental
258 | logger.info(f'Final optimal haystack size (number of haystack): {num_haystack}')
259 |
260 |
261 |
262 | # Generate samples
263 | for index in tqdm(range(num_samples)):
264 | used_haystack = num_haystack
265 | while(True):
266 | try:
267 | input_text, answer = generate_input_output(used_haystack)
268 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
269 | assert length <= max_seq_length, f"{length} exceeds max_seq_length."
270 | break
271 | except:
272 | if used_haystack > incremental:
273 | used_haystack -= incremental
274 |
275 | if args.remove_newline_tab:
276 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split())
277 | answer_prefix_index = input_text.rfind(TASKS['niah']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it
278 | answer_prefix = input_text[answer_prefix_index:]
279 | input_text = input_text[:answer_prefix_index]
280 |         # find answer position in text (use a separate name so the sample index is not overwritten)
281 |         answer_char_index = input_text.find(answer[0])
282 |         token_position_answer = len(TOKENIZER.text_to_tokens(input_text[:answer_char_index]))
283 | formatted_output = {
284 | 'index': index,
285 | "input": input_text,
286 | "outputs": answer,
287 | "length": length,
288 | 'length_w_model_temp': length + args.model_template_token,
289 | 'answer_prefix': answer_prefix,
290 | 'token_position_answer': token_position_answer,
291 | }
292 | write_jsons.append(formatted_output)
293 |
294 | return write_jsons
295 |
296 |
297 | def main():
298 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl'
299 | save_file.parent.mkdir(parents=True, exist_ok=True)
300 | write_jsons = generate_samples(
301 | num_samples=args.num_samples,
302 | max_seq_length=args.max_seq_length,
303 | save_dir=args.save_dir
304 | )
305 |
306 | write_manifest(save_file, write_jsons)
307 |
308 | if __name__ == "__main__":
309 | main()
310 |
--------------------------------------------------------------------------------
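
For the essay haystack, generate_input_output above sentence-tokenizes the context and inserts each needle after a cut point derived from a sampled depth percentage. A minimal standalone sketch of that insertion step, with made-up filler sentences and an example needle in place of the real data:

import random

sentences = [f"Filler sentence {i}." for i in range(20)]
needles = ["One of the special magic numbers for quiet-lake is: 1234567."]

# Map sampled depth percentages (0-100) to sentence indices, then interleave the needles.
depths = random.sample(range(0, 101), len(needles))
cuts = [0] + sorted(int(len(sentences) * d / 100) for d in depths) + [len(sentences)]
chunks = []
for i in range(1, len(cuts)):
    chunks.append(" ".join(sentences[cuts[i - 1]:cuts[i]]))
    if i - 1 < len(needles):
        chunks.append(needles[i - 1])
context = " ".join(chunks)
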
/scripts/data/synthetic/qa.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 |
15 | """
16 | Create a dataset jsonl file for QA task.
17 |
18 | python qa.py \
19 | --save_dir=./ \
20 | --save_name=niah_single \
21 | --tokenizer_path=tokenizer.model \
22 | --tokenizer_type=nemo \
23 | --max_seq_length=4096 \
24 | --tokens_to_generate=128 \
25 | --num_samples=10 \
26 | --template="Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query} Answer:"
27 | """
28 | import os
29 | import re
30 | import json
31 | import argparse
32 | from pathlib import Path
33 | from tqdm import tqdm
34 | import random
35 | import numpy as np
36 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
37 | import sys
38 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
39 | from tokenizer import select_tokenizer
40 | import logging
41 |
42 | logging.basicConfig(level=logging.INFO, force=True)
43 | logger = logging.getLogger(__name__)
44 |
45 | from constants import TASKS
46 |
47 | parser = argparse.ArgumentParser()
48 | # Basic Configurations
49 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
50 | parser.add_argument("--save_name", type=str, required=True, help='name of the saved dataset jsonl file')
51 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
52 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
53 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.')
54 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.')
55 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected number of tokens to generate.')
56 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate')
57 | parser.add_argument("--pre_samples", type=int, default=0, help='number of samples already generated')
58 | parser.add_argument("--random_seed", type=int, default=42)
59 | parser.add_argument("--template", type=str, required=True, help='prompt template')
60 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.')
61 | parser.add_argument("--model_template_token", type=int, default=0, help='number of model template tokens to subtract from max_seq_length (used for nemo skills)')
62 | # Complexity Configurations
63 | parser.add_argument("--dataset", type=str, required=True, help='dataset file')
64 |
65 | args = parser.parse_args()
66 | random.seed(args.random_seed)
67 | np.random.seed(args.random_seed)
68 |
69 | # Load Tokenizer
70 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path)
71 |
72 | # Read SQuAD QA dataset
73 | def read_squad(file):
74 | with open(file) as f:
75 | data = json.load(f)
76 |
77 | total_docs = [p['context'] for d in data['data'] for p in d['paragraphs']]
78 | total_docs = sorted(list(set(total_docs)))
79 | total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}
80 |
81 | total_qas = []
82 | for d in data['data']:
83 | more_docs = [total_docs_dict[p['context']] for p in d['paragraphs']]
84 | for p in d['paragraphs']:
85 | for qas in p['qas']:
86 | if not qas['is_impossible']:
87 | total_qas.append({
88 | 'query': qas['question'],
89 | 'outputs': [a['text'] for a in qas['answers']],
90 | 'context': [total_docs_dict[p['context']]],
91 | 'more_context': [idx for idx in more_docs if idx != total_docs_dict[p['context']]]
92 | })
93 |
94 | return total_qas, total_docs
95 |
96 | # Read Hotpot QA dataset
97 | def read_hotpotqa(file):
98 | with open(file) as f:
99 | data = json.load(f)
100 |
101 | total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d['context']]
102 | total_docs = sorted(list(set(total_docs)))
103 | total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}
104 |
105 | total_qas = []
106 | for d in data:
107 | total_qas.append({
108 | 'query': d['question'],
109 | 'outputs': [d['answer']],
110 | 'context': [total_docs_dict[f"{t}\n{''.join(p)}"] for t, p in d['context']],
111 | })
112 |
113 | return total_qas, total_docs
114 |
115 |
116 | DOCUMENT_PROMPT = "Document {i}:\n{document}"
117 | if args.dataset == 'squad':
118 | QAS, DOCS = read_squad(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/squad.json"))
119 | elif args.dataset == 'hotpotqa':
120 | QAS, DOCS = read_hotpotqa(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/hotpotqa.json"))
121 | else:
122 | raise NotImplementedError(f'{args.dataset} is not implemented.')
123 |
124 |
125 | def generate_input_output(index, num_docs):
126 | curr_q = QAS[index]['query']
127 | curr_a = QAS[index]['outputs']
128 | curr_docs = QAS[index]['context']
129 | curr_more = QAS[index].get('more_context', [])
130 | if num_docs < len(DOCS):
131 | if (num_docs - len(curr_docs)) > len(curr_more):
132 | addition_docs = [i for i, d in enumerate(DOCS) if i not in curr_docs + curr_more]
133 | all_docs = curr_docs + curr_more + random.sample(addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more)))
134 | else:
135 | all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs))
136 |
137 | all_docs = [DOCS[idx] for idx in all_docs]
138 | else:
139 | # Repeat DOCS as many times as needed and slice to num_docs
140 | repeats = (num_docs + len(DOCS) - 1) // len(DOCS) # Ceiling division
141 | all_docs = (DOCS * repeats)[:num_docs]
142 |
143 | random.Random(args.random_seed).shuffle(all_docs)
144 |
145 | context = '\n\n'.join([DOCUMENT_PROMPT.format(i=i+1, document=d) for i, d in enumerate(all_docs)])
146 | input_text = args.template.format(
147 | context=context,
148 | query=curr_q
149 | )
150 | return input_text, curr_a
151 |
152 |
153 | def generate_samples(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10):
154 |
155 | write_jsons = []
156 | tokens_to_generate = args.tokens_to_generate
157 | max_seq_length -= args.model_template_token
158 |
159 | # Estimate tokens per question to determine reasonable upper bound
160 | sample_input_text, _ = generate_input_output(0, incremental)
161 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text))
162 | tokens_per_doc = sample_tokens / incremental
163 |
164 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling.
165 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable.
166 | estimated_max_docs = int((max_seq_length / tokens_per_doc) * 3)
167 |
168 | # Binary search for optimal haystack size
169 | lower_bound = incremental
170 | upper_bound = max(estimated_max_docs, incremental * 2) # Ensure upper_bound is reasonable
171 |
172 | optimal_num_docs = None
173 |
174 | logger.info(f"Estimated {tokens_per_doc:.1f} tokens per doc")
175 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}")
176 |
177 | while lower_bound <= upper_bound:
178 | mid = (lower_bound + upper_bound) // 2
179 | input_text, answer = generate_input_output(0, mid)
180 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
181 |
182 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}")
183 |
184 | if total_tokens <= max_seq_length:
185 | # This size works, can we go larger?
186 | optimal_num_docs = mid
187 | lower_bound = mid + 1
188 | else:
189 | # Too large, need to go smaller
190 | upper_bound = mid - 1
191 |
192 | num_docs = optimal_num_docs if optimal_num_docs is not None else incremental
193 | logger.info(f'Final optimal haystack size (number of docs): {num_docs}')
194 |
195 | # Generate samples
196 | for index in tqdm(range(num_samples)):
197 | used_docs = num_docs
198 | while(True):
199 | try:
200 | input_text, answer = generate_input_output(index + args.pre_samples, used_docs)
201 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
202 | assert length <= max_seq_length, f"{length} exceeds max_seq_length."
203 | break
204 | except:
205 | if used_docs > incremental:
206 | used_docs -= incremental
207 |
208 | if args.remove_newline_tab:
209 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split())
210 | answer_prefix_index = input_text.rfind(TASKS['qa']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it
211 | answer_prefix = input_text[answer_prefix_index:]
212 | input_text = input_text[:answer_prefix_index]
213 | formatted_output = {
214 | "index": index,
215 | "input": input_text,
216 | "outputs": answer,
217 | "length": length,
218 | 'length_w_model_temp': length + args.model_template_token,
219 | 'answer_prefix': answer_prefix,
220 | }
221 | write_jsons.append(formatted_output)
222 |
223 | return write_jsons
224 |
225 |
226 | def main():
227 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl'
228 | save_file.parent.mkdir(parents=True, exist_ok=True)
229 |
230 | write_jsons = generate_samples(
231 | num_samples=args.num_samples,
232 | max_seq_length=args.max_seq_length,
233 | save_dir=args.save_dir
234 | )
235 |
236 | write_manifest(save_file, write_jsons)
237 |
238 | if __name__=="__main__":
239 | main()
240 |
--------------------------------------------------------------------------------
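
generate_input_output above pads each question's gold passages with distractor documents until the target haystack size is reached, shuffles them, and renders everything with DOCUMENT_PROMPT. A simplified standalone sketch of that assembly, with placeholder passages standing in for the parsed SQuAD/HotpotQA data:

import random

DOCS = [f"Distractor passage {i}." for i in range(100)]  # stand-in corpus
qa = {'query': 'Example question?', 'outputs': ['example answer'], 'context': [3]}

num_docs = 8
gold = qa['context']
extra = random.sample([i for i in range(len(DOCS)) if i not in gold], num_docs - len(gold))
all_docs = [DOCS[i] for i in gold + extra]
random.shuffle(all_docs)

context = '\n\n'.join(f"Document {i + 1}:\n{d}" for i, d in enumerate(all_docs))
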
/scripts/data/synthetic/variable_tracking.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 |
15 | """
16 | Create a dataset jsonl file for variable tracking.
17 |
18 | python variable_tracking.py \
19 | --save_dir=./ \
20 | --save_name=vt \
21 | --tokenizer_path='EleutherAI/gpt-neox-20b' \
22 | --tokenizer_type hf \
23 | --max_seq_length 4096 \
24 | --tokens_to_generate 30 \
25 | --num_samples 10 \
26 | --random_seed 42 \
27 | --num_chains 1 --num_hops 4 \
28 | --template "[INST] Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above. [/INST] Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: "
29 | """
30 | import os
31 | import json
32 | import re
33 | import argparse
34 | from pathlib import Path
35 | from tqdm import tqdm
36 | import random
37 | import string
38 | from constants import TASKS
39 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
40 | import sys
41 | import pdb
42 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
43 | from tokenizer import select_tokenizer
44 | from nltk.tokenize import sent_tokenize
45 | import numpy as np
46 | import heapq
47 | import json
48 | import logging
49 |
50 | logging.basicConfig(level=logging.INFO, force=True)
51 | logger = logging.getLogger(__name__)
52 |
53 |
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
56 | parser.add_argument("--save_name", type=str, required=True, help='name of the saved dataset jsonl file')
57 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
58 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
59 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.')
60 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.')
61 | parser.add_argument("--tokens_to_generate", type=int, default=120, help='number of tokens to generate')
62 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate')
63 | parser.add_argument("--random_seed", type=int, default=42)
64 | parser.add_argument("--template", type=str, default='', help='prompt template')
65 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.')
66 |
67 | parser.add_argument("--type_haystack", type=str, default='noise', help='[Options] noise or essay.')
68 | parser.add_argument("--num_chains", type=int, default=1, help='number of inserted variable chains')
69 | parser.add_argument("--num_hops", type=int, default=4, help='number of hops in each chain')
70 | parser.add_argument("--add_fewshot", action="store_true", default=False)
71 | parser.add_argument("--model_template_token", type=int, default=0, help='number of model template tokens to subtract from max_seq_length (used for nemo skills)')
72 |
73 | args = parser.parse_args()
74 | random.seed(args.random_seed)
75 | np.random.seed(args.random_seed)
76 |
77 | # Load Tokenizer
78 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path)
79 |
80 | # Define Needle/Haystack Format
81 | if args.type_haystack == 'essay':
82 | essay = os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/PaulGrahamEssays.json")
83 | essay = json.load(open(essay))['text']
84 | haystack = re.sub(r'\s+', " ", essay).split(" ")
85 | elif args.type_haystack == 'noise':
86 | haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
87 | else:
88 | raise NotImplementedError(f'{args.type_haystack} is not implemented.')
89 |
90 | # Positions
91 | DEPTHS = list(np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int))
92 |
93 | def generate_chains(num_chains, num_hops, is_icl=False):
94 |
95 | vars_all = []
96 | k = 5 if not is_icl else 3
97 | num_hops = num_hops if not is_icl else min(10, num_hops)
98 | vars_all = [''.join(random.choices(string.ascii_uppercase, k=k)).upper() for _ in range((num_hops+1) * num_chains)]
99 | while len(set(vars_all)) < num_chains * (num_hops+1):
100 | vars_all.append(''.join(random.choices(string.ascii_uppercase, k=k)).upper())
101 |
102 | vars_ret = []
103 | chains_ret = []
104 | for i in range(0, len(vars_all), num_hops+1):
105 | this_vars = vars_all[i:i+num_hops+1]
106 | vars_ret.append(this_vars)
107 | if is_icl:
108 | this_chain = [f"VAR {this_vars[0]} = 12345"]
109 | else:
110 | this_chain = [f"VAR {this_vars[0]} = {str(np.random.randint(10000, 99999))}"]
111 | for j in range(num_hops):
112 | this_chain.append(f"VAR {this_vars[j+1]} = VAR {this_vars[j]} ")
113 | chains_ret.append(this_chain)
114 | return vars_ret, chains_ret
115 |
116 | def shuffle_sublists_heap(lst):
117 | heap = []
118 | for i in range(len(lst)):
119 | heapq.heappush(heap, (random.random(), i, 0)) # Push first element of each list with random priority
120 | shuffled_result = []
121 | while heap:
122 | _, list_idx, elem_idx = heapq.heappop(heap) # Get the lowest random priority element
123 | shuffled_result.append(lst[list_idx][elem_idx])
124 |
125 | # If there are more elements in the same sublist, add the next one
126 | if elem_idx + 1 < len(lst[list_idx]):
127 | heapq.heappush(heap, (random.random(), list_idx, elem_idx + 1))
128 | return shuffled_result
129 |
130 | def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
131 |
132 | vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl)
133 | value = chains[0][0].split("=")[-1].strip()
134 |
135 | if args.type_haystack == 'essay':
136 | text = " ".join(haystack[:num_noises])
137 | document_sents = sent_tokenize(text.strip())
138 | chains_flat = shuffle_sublists_heap(chains)
139 | insertion_positions = [0] + \
140 | sorted([int(len(document_sents) * (depth / 100)) for depth in random.sample(DEPTHS, len(chains_flat))]) + \
141 | [len(document_sents)]
142 | document_sents_list = []
143 | for i in range(1,len(insertion_positions)):
144 | last_pos = insertion_positions[i-1]
145 | next_pos = insertion_positions[i]
146 | document_sents_list.append(" ".join(document_sents[last_pos:next_pos]))
147 | if i-1 < len(chains_flat):
148 | document_sents_list.append(chains_flat[i-1].strip() + ".")
149 | context = " ".join(document_sents_list)
150 |
151 | elif args.type_haystack == 'noise':
152 | sentences = [haystack] * num_noises
153 | for chain in chains:
154 | positions = list(sorted(random.sample(range(len(sentences)), len(chain))))
155 | for insert_pi, j in zip(positions, range(len(chain))):
156 | sentences.insert(insert_pi+j, chain[j])
157 | context = "\n".join(sentences)
158 |
159 | context = context.replace(". \n", ".\n")
160 |
161 | template = args.template
162 | if is_icl and template != TASKS['variable_tracking']['template'] + TASKS['variable_tracking']['answer_prefix']:
163 | # remove model template
164 | new_template = ""
165 | if len(TASKS['variable_tracking']['template']) > 0:
166 | new_template = TASKS['variable_tracking']['template']
167 | if len(TASKS['variable_tracking']['answer_prefix']) > 0:
168 | new_template += TASKS['variable_tracking']['answer_prefix']
169 | template = new_template
170 |
171 | input_text = template.format(
172 | context=context,
173 | query=value,
174 | num_v=num_hops+1
175 | )
176 |
177 | return input_text, vars[0]
178 |
179 | def randomize_icl(icl_example):
180 | icl_tgt = icl_example.strip().split()[-args.num_hops-1:]
181 | for item in icl_tgt:
182 | new_item = ''.join(random.choices(string.ascii_uppercase, k=len(item))).upper()
183 | icl_example = icl_example.replace(item, new_item)
184 |
185 | old_value = "12345"
186 | new_value = str(np.random.randint(10000, 99999))
187 | icl_example = icl_example.replace(old_value, new_value)
188 |
189 | return icl_example
190 |
191 | def sys_vartrack_w_noise_random(num_samples: int, max_seq_length: int, incremental: int = 10,
192 | num_chains: int = 1, num_hops: int = 4,
193 | add_fewshot: bool = True,
194 | icl_example: str = None,
195 | final_output: bool = False):
196 | write_jsons = []
197 | tokens_to_generate = args.tokens_to_generate if icl_example is not None else 0
198 | max_seq_length -= args.model_template_token
199 |
200 | # Find the perfect num_noises
201 | if icl_example:
202 | if args.type_haystack == 'essay':
203 | incremental = 500
204 | elif args.type_haystack == 'noise':
205 | incremental = 10
206 |
207 | if args.type_haystack != 'essay' and args.max_seq_length < 4096:
208 | incremental = 5
209 | else:
210 | if args.type_haystack == 'essay':
211 | incremental = 50
212 | elif args.type_haystack == 'noise':
213 | incremental = 5
214 |
215 | example_tokens = 0
216 | if add_fewshot and (icl_example is not None):
217 | icl_example_out = ' '.join(icl_example['outputs'])
218 | icl_example = icl_example['input'] + " " + icl_example_out + '\n'
219 | example_tokens = len(TOKENIZER.text_to_tokens(icl_example))
220 |
221 | # Estimate tokens per question to determine reasonable upper bound
222 | sample_input_text, _ = generate_input_output(incremental, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None))
223 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text))
224 | tokens_per_haystack = sample_tokens / incremental
225 |
226 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling.
227 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable.
228 | estimated_max_noises = int((max_seq_length / tokens_per_haystack) * 3)
229 |
230 | # Binary search for optimal haystack size
231 | lower_bound = incremental
232 | upper_bound = max(estimated_max_noises, incremental * 2) # Ensure upper_bound is reasonable
233 |
234 | optimal_num_noises = None
235 |
236 | logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack")
237 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}")
238 | while lower_bound <= upper_bound:
239 | mid = (lower_bound + upper_bound) // 2
240 | input_text, answer = generate_input_output(mid, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None))
241 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + example_tokens + tokens_to_generate
242 |
243 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}")
244 |
245 | if total_tokens <= max_seq_length:
246 | # This size works, can we go larger?
247 | optimal_num_noises = mid
248 | lower_bound = mid + 1
249 | else:
250 | # Too large, need to go smaller
251 | upper_bound = mid - 1
252 |
253 | num_noises = optimal_num_noises if optimal_num_noises is not None else incremental
254 |     logger.info(f'Final optimal haystack size (number of haystack units): {num_noises}')
255 |
256 | # Generate samples
257 | for index in tqdm(range(num_samples)):
258 | used_noises = num_noises
259 |         while True:
260 |             try:
261 |                 input_text, answer = generate_input_output(used_noises, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None))
262 |                 length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate + example_tokens
263 |                 assert length <= max_seq_length, f"{length} exceeds max_seq_length."
264 |                 break
265 |             except Exception:
266 |                 if used_noises <= incremental: raise  # cannot shrink further; fail instead of looping forever
267 |                 used_noises -= incremental
268 |
269 | if add_fewshot and (icl_example is not None):
270 | # insert icl_example between model template and input
271 | cutoff = input_text.index(TASKS['variable_tracking']['template'][:20])
272 | input_text = input_text[:cutoff] + randomize_icl(icl_example) + '\n' + input_text[cutoff:]
273 | if args.remove_newline_tab:
274 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split())
275 |
276 | if final_output:
277 | answer_prefix_index = input_text.rfind(TASKS['variable_tracking']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it
278 | answer_prefix = input_text[answer_prefix_index:]
279 | input_text = input_text[:answer_prefix_index]
280 | formatted_output = {
281 | 'index': index,
282 | "input": input_text,
283 | "outputs": answer,
284 | "length": length,
285 | 'length_w_model_temp': length + args.model_template_token,
286 | 'answer_prefix': answer_prefix,
287 | }
288 | else:
289 | formatted_output = {
290 | 'index': index,
291 | "input": input_text,
292 | "outputs": answer,
293 | "length": length,
294 | }
295 | write_jsons.append(formatted_output)
296 |
297 | return write_jsons
298 |
299 |
300 | def main():
301 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl'
302 | save_file.parent.mkdir(parents=True, exist_ok=True)
303 |
304 | icl_example = sys_vartrack_w_noise_random(num_samples=1,
305 | max_seq_length=500,
306 | incremental=5,
307 | num_chains=args.num_chains,
308 | num_hops=args.num_hops)[0]
309 | logger.info(icl_example)
310 | write_jsons = sys_vartrack_w_noise_random(num_samples=args.num_samples,
311 | max_seq_length=args.max_seq_length,
312 | num_chains=args.num_chains,
313 | num_hops=args.num_hops,
314 | icl_example=icl_example,
315 | final_output=True)
316 |
317 | write_manifest(save_file, write_jsons)
318 |
319 | if __name__=="__main__":
320 | main()
321 |
--------------------------------------------------------------------------------
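Note: the sizing logic in `sys_vartrack_w_noise_random` above binary-searches for the largest haystack that still fits the token budget. A minimal standalone sketch of that pattern (the `build_sample` and `count_tokens` callables are hypothetical stand-ins, not APIs from this repo):

```
def find_max_haystack(build_sample, count_tokens, max_tokens, lo, hi):
    # Largest haystack size whose rendered sample still fits within max_tokens.
    best = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if count_tokens(build_sample(mid)) <= max_tokens:
            best, lo = mid, mid + 1   # fits: try a larger haystack
        else:
            hi = mid - 1              # too long: shrink the search range
    return best
```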
/scripts/data/template.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | Templates = {
16 | 'base': "{task_template}",
17 |
18 | 'meta-chat': "[INST] {task_template} [/INST]",
19 |
20 | 'vicuna-chat': "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {task_template} ASSISTANT:",
21 |
22 | 'lwm-chat': "You are a helpful assistant. USER: {task_template} ASSISTANT: ",
23 |
24 | 'command-r-chat': "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{task_template}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
25 |
26 | 'chatglm-chat': "[gMASK]sop<|user|> \n {task_template}<|assistant|> \n ",
27 |
28 | 'RWKV': "User: hi\n\nAssistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it\n\nUser: {task_template}\n\nAssistant:",
29 |
30 | 'Phi3': "<|user|>\n{task_template}<|end|>\n<|assistant|>\n",
31 |
32 | 'meta-llama3': "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{task_template}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
33 |
34 | 'jamba': "<|startoftext|><|bom|><|system|> <|eom|><|bom|><|user|> {task_template}<|eom|><|bom|><|assistant|>",
35 |
36 | 'nemotron5-instruct': "System\n\nUser\n{task_template}\nAssistant\n",
37 | }
--------------------------------------------------------------------------------
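Note: `Templates` only wraps an already-formatted task prompt in a model-specific chat format. A minimal usage sketch (assuming it is imported from within `scripts/data/`; the prompt text is illustrative):

```
from template import Templates

task_prompt = "Some context ... What is the special magic number? Answer:"
wrapped = Templates['meta-llama3'].format(task_template=task_prompt)
print(wrapped)
```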
/scripts/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | from typing import List
18 | from tenacity import (
19 | retry,
20 | stop_after_attempt,
21 | wait_fixed,
22 | wait_random,
23 | )
24 |
25 |
26 | def select_tokenizer(tokenizer_type, tokenizer_path):
27 | if tokenizer_type == 'nemo':
28 | return NeMoSentencePieceTokenizer(model_path=tokenizer_path)
29 | elif tokenizer_type == 'nemo_tiktoken':
30 | return NeMoTikTokenTokenizer(model_path=tokenizer_path)
31 | elif tokenizer_type == 'hf':
32 | return HFTokenizer(model_path=tokenizer_path)
33 | elif tokenizer_type == 'openai':
34 | return OpenAITokenizer(model_path=tokenizer_path)
35 | elif tokenizer_type == 'gemini':
36 | return GeminiTokenizer(model_path=tokenizer_path)
37 | else:
38 | raise ValueError(f"Unknown tokenizer_type {tokenizer_type}")
39 |
40 |
41 | class NeMoSentencePieceTokenizer:
42 | """
43 | Tokenizer from NeMo SentencePieceTokenizer
44 | """
45 | def __init__(self, model_path) -> None:
46 | from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
47 | self.tokenizer = SentencePieceTokenizer(model_path=model_path)
48 |
49 | def text_to_tokens(self, text: str) -> List[str]:
50 | tokens = self.tokenizer.text_to_tokens(text)
51 | return tokens
52 |
53 | def tokens_to_text(self, tokens: List[int]) -> str:
54 | text = self.tokenizer.tokens_to_text(tokens)
55 | return text
56 |
57 | class NeMoTikTokenTokenizer:
58 | """
59 |     Tokenizer from NeMo TiktokenTokenizer
60 | """
61 | def __init__(self, model_path) -> None:
62 | from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer
63 | self.tokenizer = TiktokenTokenizer(vocab_file=model_path)
64 |
65 | def text_to_tokens(self, text: str) -> List[str]:
66 | tokens = self.tokenizer.text_to_tokens(text)
67 | return tokens
68 |
69 | def tokens_to_text(self, tokens: List[int]) -> str:
70 | text = self.tokenizer.tokens_to_text(tokens)
71 | return text
72 |
73 |
74 | class HFTokenizer:
75 | """
76 | Tokenizer from HF models
77 | """
78 | def __init__(self, model_path) -> None:
79 | from transformers import AutoTokenizer
80 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
81 |
82 | def text_to_tokens(self, text: str) -> List[str]:
83 | tokens = self.tokenizer.tokenize(text)
84 | return tokens
85 |
86 | def tokens_to_text(self, tokens: List[int]) -> str:
87 | text = self.tokenizer.convert_tokens_to_string(tokens)
88 | return text
89 |
90 |
91 | class OpenAITokenizer:
92 | """
93 | Tokenizer from tiktoken
94 | """
95 | def __init__(self, model_path="cl100k_base") -> None:
96 | import tiktoken
97 | self.tokenizer = tiktoken.get_encoding(model_path)
98 |
99 | def text_to_tokens(self, text: str) -> List[int]:
100 | tokens = self.tokenizer.encode(text)
101 | return tokens
102 |
103 | def tokens_to_text(self, tokens: List[int]) -> str:
104 | text = self.tokenizer.decode(tokens)
105 | return text
106 |
107 |
108 | class GeminiTokenizer:
109 | """
110 | Tokenizer from gemini
111 | """
112 | def __init__(self, model_path="gemini-1.5-pro-latest") -> None:
113 | import google.generativeai as genai
114 | genai.configure(api_key=os.environ["GEMINI_API_KEY"])
115 | self.model = genai.GenerativeModel(model_path)
116 |
117 | @retry(wait=wait_fixed(60) + wait_random(0, 10), stop=stop_after_attempt(3))
118 | def text_to_tokens(self, text: str) -> List[int]:
119 | tokens = list(range(self.model.count_tokens(text).total_tokens))
120 | return tokens
121 |
122 | def tokens_to_text(self, tokens: List[int]) -> str:
123 | pass
--------------------------------------------------------------------------------
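Note: a short sketch of counting tokens through the wrappers above (the checkpoint name is a placeholder; any tokenizer loadable via `AutoTokenizer.from_pretrained` would work):

```
from tokenizer import select_tokenizer

tok = select_tokenizer('hf', 'meta-llama/Meta-Llama-3-8B-Instruct')  # placeholder checkpoint
print(len(tok.text_to_tokens("A quick length check.")))
```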
/scripts/eval/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Get summary.csv with the score and the number of null predictions for each task.
17 |
18 | Running
19 | ```
20 | python evaluate.py \
21 | --data_dir /path/to/your/prediction_jsonl_folder \
22 | --benchmark synthetic
23 | ```
24 | """
25 |
26 | import re
27 | import os
28 | import argparse
29 | import nltk
30 | try:
31 | nltk.data.find('tokenizers/punkt')
32 | except LookupError:
33 | nltk.download('punkt')
34 |
35 | import pandas as pd
36 | import importlib
37 | import yaml
38 | from pathlib import Path
39 | from tqdm import tqdm
40 | from collections import defaultdict
41 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
42 |
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--data_dir", type=str, required=True, help='path to the prediction jsonl files')
45 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]')
46 | parser.add_argument("--verbose", type=int, default=0, help='how many lines you want to display.')
47 | args = parser.parse_args()
48 |
49 |
50 | def postprocess_pred(predict_str: str, task_config: dict):
51 |
52 | predict_str = predict_str.strip()
53 |
54 | # Remove all non-printable characters
55 | np_pattern = re.compile(r'[\x00-\x1f]')
56 | predict_str = np_pattern.sub('\n', predict_str).strip()
57 |
58 | return predict_str
59 |
60 |
61 | def get_pred_and_ref(
62 | predictions_file: str,
63 | task_config: dict,
64 | input_field: str = 'input',
65 | references_field: str = 'outputs',
66 | prediction_field: str = 'pred',
67 | metadata_field: str = 'others',
68 | ):
69 | lines = read_manifest(predictions_file)
70 |
71 | inputs = []
72 | predicts = []
73 | references = []
74 | indices = []
75 |
76 | for line in tqdm(lines):
77 | input = line[input_field]
78 | predict = line[prediction_field]
79 | predict = postprocess_pred(predict, task_config)
80 | reference = line.get(references_field, [line.get('output', '')])
81 | index = line[metadata_field].get('id', line['index'])
82 |
83 | inputs.append(input)
84 | predicts.append(predict)
85 | references.append(reference)
86 | indices.append(index)
87 |
88 | return inputs, predicts, references, indices
89 |
90 | def run_evaluation_per_task(task_config: dict, predictions_file: str, verbose: int = 0):
91 | inputs, predicts, references, indices = get_pred_and_ref(
92 | predictions_file=predictions_file,
93 | task_config=task_config,
94 | )
95 |
96 | task_nulls = f'{sum([len(x)==0 for x in predicts])}/{len(predicts)}'
97 |
98 | if len(references) > 0 and references[0][0] is not None:
99 | task_score = task_config['metric_fn'](predicts, references)
100 | else:
101 | task_score = 0.0
102 |
103 | if verbose != 0:
104 | print('=' * 40)
105 | for i, (input, reference, predict) in enumerate(zip(inputs, references, predicts)):
106 | print(f'Input : {input}')
107 | print(f'Reference : {reference}')
108 | print(f'Prediction: {predict}')
109 | print('=' * 40)
110 | if i > verbose:
111 | break
112 |
113 | return task_score, task_nulls, predicts, indices
114 |
115 |
116 | def write_evaluation(results: dict):
117 | tasks = list(results.keys())
118 | score = [results[task]['score'] for task in tasks]
119 | nulls = [results[task]['nulls'] for task in tasks]
120 | dfs = [
121 | ['Tasks'] + tasks,
122 | ['Score'] + score,
123 | ['Nulls'] + nulls,
124 | ]
125 |
126 | output_file = os.path.join(args.data_dir, 'summary.csv' if len(tasks) > 1 else f'summary-{tasks[0]}.csv')
127 | df = pd.DataFrame(dfs)
128 | df.to_csv(output_file, index=False)
129 | print('\n=============================================\n')
130 | print(df)
131 | print(f'\nSaved eval results to {output_file}')
132 |
133 |
134 | def write_submission(results: dict):
135 | COLUMNS = ["Task", "ID", "Prediction"]
136 | dfs = pd.DataFrame(columns=COLUMNS, data=[])
137 |
138 | for task, result in results.items():
139 | df = pd.DataFrame({
140 | 'Task': task,
141 | 'ID': result['indices'],
142 | 'Prediction': result['predicts']
143 | })
144 | dfs = pd.concat((dfs, df[COLUMNS]))
145 |
146 | output_file = os.path.join(args.data_dir, 'submission.csv')
147 | dfs = dfs.reset_index(drop=True)
148 | dfs.to_csv(output_file, index=False)
149 | print(f'\nSaved submission results to {output_file}')
150 |
151 |
152 | def aggregate_chunk(folder):
153 | jsonl_files = [file for file in os.listdir(folder) if Path(file).suffix == '.jsonl' ]
154 | chunk_files = sorted([file for file in jsonl_files if re.match(r'.*[^_]+-\d+\.jsonl', file)])
155 | chunk_files_dict = defaultdict(list)
156 | for file in chunk_files:
157 | task = '-'.join(file.split('-')[:-1])
158 | chunk_files_dict[task].append(file)
159 |
160 | for task, files in chunk_files_dict.items():
161 | lines = []
162 | for file in sorted(files):
163 | file = os.path.join(folder, file)
164 | lines += read_manifest(file)
165 | os.remove(file) # Remove chunk files
166 | write_manifest(os.path.join(folder, f'{task}.jsonl'), lines)
167 |
168 |
169 | def main():
170 | curr_folder = os.path.dirname(os.path.abspath(__file__))
171 |
172 | try:
173 | module = importlib.import_module(f"{args.benchmark}.constants")
174 |     except ImportError:
175 |         print(f"Module eval.{args.benchmark}.constants not found.")
176 |         raise
177 | tasks_base = module.TASKS
178 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f:
179 | tasks_customized = yaml.safe_load(f)
180 |
181 |
182 | TASKS = tasks_customized
183 | for _, config in TASKS.items():
184 | config.update(tasks_base[config['task']])
185 |
186 | print(f"Total tasks: {list(TASKS.keys())}")
187 |
188 | # Aggregate all prediction files
189 | aggregate_chunk(args.data_dir)
190 |
191 | # Get scores and nulls
192 | jsonl_files = [file for file in os.listdir(args.data_dir) if Path(file).suffix == '.jsonl']
193 | eval_results = {}
194 | subm_results = {}
195 |
196 |
197 | for task, config in TASKS.items():
198 |
199 | if f'{task}.jsonl' not in jsonl_files:
200 | print(f'Prediction file {task}.jsonl is not found.')
201 | continue
202 |
203 | print(f'Evaluate task {task}...')
204 | task_score, task_nulls, predicts, indices = run_evaluation_per_task(
205 | predictions_file=os.path.join(args.data_dir, f'{task}.jsonl'),
206 | task_config=config,
207 | )
208 | eval_results[task] = {
209 | 'score': task_score,
210 | 'nulls': task_nulls,
211 | }
212 | subm_results[task] = {
213 | 'predicts': predicts,
214 | 'indices':indices,
215 | }
216 |
217 | # Write to csv
218 | write_evaluation(eval_results)
219 | write_submission(subm_results)
220 |
221 | if __name__ == '__main__':
222 | main()
--------------------------------------------------------------------------------
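Note: `aggregate_chunk` merges per-chunk prediction files named `<task>-<chunk_idx>.jsonl` into a single `<task>.jsonl` before scoring. A toy sketch of the grouping rule it applies (file names are illustrative):

```
chunked = ['niah-0.jsonl', 'niah-1.jsonl', 'qa-0.jsonl']
groups = {}
for name in chunked:
    task = '-'.join(name.split('-')[:-1])   # same rule as aggregate_chunk
    groups.setdefault(task, []).append(name)
print(groups)   # {'niah': ['niah-0.jsonl', 'niah-1.jsonl'], 'qa': ['qa-0.jsonl']}
```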
/scripts/eval/synthetic/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Add a new task:
17 |
18 | TASK_NAME: {
19 | 'metric_fn': the metric function with input (predictions: [str], references: [[str]]) to compute score.
20 | }
21 | """
22 |
23 |
24 | def string_match_part(preds, refs):
25 | score = sum([max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) for pred, ref in zip(preds, refs)]) / len(preds) * 100
26 | return round(score, 2)
27 |
28 | def string_match_all(preds, refs):
29 | score = sum([sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) for pred, ref in zip(preds, refs)]) / len(preds) * 100
30 | return round(score, 2)
31 |
32 |
33 | TASKS = {
34 | 'niah': {
35 | 'metric_fn': string_match_all,
36 | },
37 | 'variable_tracking': {
38 | 'metric_fn': string_match_all,
39 | },
40 | 'common_words_extraction': {
41 | 'metric_fn': string_match_all,
42 | },
43 | 'freq_words_extraction': {
44 | 'metric_fn': string_match_all
45 | },
46 | 'qa': {
47 | 'metric_fn': string_match_part,
48 | },
49 | }
50 |
--------------------------------------------------------------------------------
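Note: a quick sanity check of the two metrics above on toy data (values chosen only for illustration):

```
preds = ["The magic numbers are 11 and 22", "Paris"]
refs = [["11", "22"], ["paris", "france"]]
print(string_match_all(preds, refs))    # (2/2 + 1/2) / 2 * 100 -> 75.0
print(string_match_part(preds, refs))   # at least one reference matched per sample -> 100.0
```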
/scripts/pred/call_api.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Prepare prediction jsonl with field `pred`.
17 | dataset jsonl:
18 | {
19 | "index" int,
20 | "input": str,
21 | "outputs": [str],
22 | }
23 |
24 | prediction jsonl:
25 | {
26 | "index" int,
27 | "input": str,
28 | "outputs": [str],
29 | "pred": str,
30 | }
31 | """
32 |
33 | import argparse
34 | import json
35 | import yaml
36 | import os
37 | import sys
38 | import threading
39 | import importlib
40 | import math
41 | import time
42 | from tqdm import tqdm
43 | from pathlib import Path
44 | import traceback
45 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest
46 |
47 | SERVER_TYPES = (
48 | 'trtllm',
49 | 'vllm',
50 | 'sglang',
51 | 'openai',
52 | 'gemini',
53 | 'hf',
54 | 'mamba',
55 | )
56 |
57 |
58 | class ServerAction(argparse.Action):
59 | def __call__(self, parser, namespace, values, option_string=None):
60 | namespace.server_type = values
61 |
62 |
63 | parser = argparse.ArgumentParser()
64 | # Data
65 | parser.add_argument("--data_dir", type=Path, required=True, help='path to load the dataset jsonl files')
66 | parser.add_argument("--save_dir", type=Path, required=True, help='path to save the prediction jsonl files')
67 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]')
68 | parser.add_argument("--task", type=str, required=True, help='Options: tasks in benchmark')
69 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
70 | parser.add_argument("--chunk_idx", type=int, default=0, help='index of current split chunk')
71 | parser.add_argument("--chunk_amount", type=int, default=1, help='size of split chunk')
72 |
73 | # Server
74 | parser.add_argument("--server_type", default='nemo', action=ServerAction, choices=SERVER_TYPES)
75 | parser.add_argument("--server_host", type=str, default='127.0.0.1')
76 | parser.add_argument("--server_port", type=str, default='5000')
77 | parser.add_argument("--ssh_server", type=str)
78 | parser.add_argument("--ssh_key_path", type=str)
79 | parser.add_argument("--model_name_or_path", type=str, default='gpt-3.5-turbo',
80 | help='supported models from OpenAI or HF (provide a key or a local path to the checkpoint)')
81 |
82 | # Inference
83 | parser.add_argument("--temperature", type=float, default=1.0)
84 | parser.add_argument("--top_k", type=int, default=32)
85 | parser.add_argument("--top_p", type=float, default=1.0)
86 | parser.add_argument("--random_seed", type=int, default=0)
87 | parser.add_argument("--stop_words", type=str, default='')
88 | parser.add_argument("--sliding_window_size", type=int)
89 | parser.add_argument("--threads", type=int, default=4)
90 | parser.add_argument("--batch_size", type=int, default=1)
91 |
92 | args = parser.parse_args()
93 | args.stop_words = list(filter(None, args.stop_words.split(',')))
94 | if args.server_type == 'hf' or args.server_type == 'gemini':
95 | args.threads = 1
96 |
97 |
98 | def get_llm(tokens_to_generate):
99 | if args.server_type == 'trtllm':
100 | from client_wrappers import TRTLLMClient
101 | llm = TRTLLMClient(
102 | server_host=args.server_host,
103 | server_port=args.server_port,
104 | ssh_server=args.ssh_server,
105 | ssh_key_path=args.ssh_key_path,
106 | temperature=args.temperature,
107 | top_k=args.top_k,
108 | top_p=args.top_p,
109 | random_seed=args.random_seed,
110 | stop=args.stop_words,
111 | tokens_to_generate=tokens_to_generate,
112 | max_attention_window_size=args.sliding_window_size,
113 | )
114 |
115 | elif args.server_type == 'vllm':
116 | from client_wrappers import VLLMClient
117 | llm = VLLMClient(
118 | server_host=args.server_host,
119 | server_port=args.server_port,
120 | ssh_server=args.ssh_server,
121 | ssh_key_path=args.ssh_key_path,
122 | temperature=args.temperature,
123 | top_k=args.top_k,
124 | top_p=args.top_p,
125 | random_seed=args.random_seed,
126 | stop=args.stop_words,
127 | tokens_to_generate=tokens_to_generate,
128 | )
129 |
130 | elif args.server_type == 'sglang':
131 | from client_wrappers import SGLClient
132 | llm = SGLClient(
133 | server_host=args.server_host,
134 | server_port=args.server_port,
135 | ssh_server=args.ssh_server,
136 | ssh_key_path=args.ssh_key_path,
137 | temperature=args.temperature,
138 | top_k=args.top_k,
139 | top_p=args.top_p,
140 | random_seed=args.random_seed,
141 | stop=args.stop_words,
142 | tokens_to_generate=tokens_to_generate,
143 | )
144 |
145 | elif args.server_type == 'openai':
146 | from client_wrappers import OpenAIClient
147 | llm = OpenAIClient(
148 | model_name=args.model_name_or_path,
149 | temperature=args.temperature,
150 | top_k=args.top_k,
151 | top_p=args.top_p,
152 | random_seed=args.random_seed,
153 | stop=args.stop_words,
154 | tokens_to_generate=tokens_to_generate,
155 | )
156 |
157 | elif args.server_type == 'gemini':
158 | from client_wrappers import GeminiClient
159 | llm = GeminiClient(
160 | model_name=args.model_name_or_path,
161 | temperature=args.temperature,
162 | top_k=args.top_k,
163 | top_p=args.top_p,
164 | random_seed=args.random_seed,
165 | stop=args.stop_words,
166 | tokens_to_generate=tokens_to_generate,
167 | )
168 |
169 | elif args.server_type == 'hf':
170 | from model_wrappers import HuggingFaceModel
171 | llm = HuggingFaceModel(
172 | name_or_path=args.model_name_or_path,
173 | do_sample=args.temperature > 0,
174 | repetition_penalty=1,
175 | temperature=args.temperature,
176 | top_k=args.top_k,
177 | top_p=args.top_p,
178 | stop=args.stop_words,
179 | max_new_tokens=tokens_to_generate,
180 | )
181 |
182 | elif args.server_type == 'mamba':
183 | from model_wrappers import MambaModel
184 | # mamba uses its own generation function, do not pass in do_sample
185 | # https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/utils/generation.py#L121
186 | llm = MambaModel(
187 | name_or_path=args.model_name_or_path,
188 | repetition_penalty=1,
189 | temperature=args.temperature,
190 | top_k=args.top_k,
191 | top_p=args.top_p,
192 | stop=args.stop_words,
193 | max_new_tokens=tokens_to_generate,
194 | )
195 |
196 | else:
197 | raise RuntimeError(f'Unsupported server type {args.server_type}')
198 |
199 | return llm
200 |
201 |
202 | def main():
203 | start_time = time.time()
204 |
205 | curr_folder = os.path.dirname(os.path.abspath(__file__))
206 |
207 | try:
208 | sys.path.append(os.path.dirname(curr_folder))
209 | module = importlib.import_module(f"data.{args.benchmark}.constants")
210 |     except ImportError:
211 |         print(f"Module data.{args.benchmark}.constants not found.")
212 |         raise
213 | tasks_base = module.TASKS
214 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f:
215 | tasks_customized = yaml.safe_load(f)
216 |
217 | if args.task not in tasks_customized:
218 |         raise ValueError(f'{args.task} is not found in {args.benchmark}.yaml')
219 |
220 | config = tasks_customized.get(args.task)
221 | config.update(tasks_base[config['task']])
222 |
223 | task_file = args.data_dir / args.task / f'{args.subset}.jsonl'
224 |
225 | if args.chunk_amount > 1:
226 | pred_file = args.save_dir / f'{args.task}-{args.chunk_idx}.jsonl'
227 | else:
228 | pred_file = args.save_dir / f'{args.task}.jsonl'
229 |
230 | print(f'Predict {args.task} \nfrom {task_file}\nto {pred_file}')
231 | pred_file.parent.mkdir(parents=True, exist_ok=True)
232 |
233 | # Load data
234 | if os.path.exists(pred_file):
235 | pred_index = [sample['index'] for sample in read_manifest(pred_file)]
236 | data = [sample for sample in read_manifest(task_file) if sample['index'] not in pred_index]
237 | else:
238 | data = read_manifest(task_file)
239 |
240 | # Load api
241 | llm = get_llm(config['tokens_to_generate'])
242 |
243 | def get_output(idx_list, index_list, input_list, outputs_list, others_list, truncation_list, length_list):
244 | nonlocal llm
245 |
246 | while True:
247 | try:
248 | pred_list = llm.process_batch(prompts=input_list)
249 | break
250 | except Exception as e:
251 | traceback.print_exc()
252 |
253 | zipped_iter = zip(pred_list, idx_list, index_list, input_list,
254 | outputs_list, others_list, truncation_list, length_list)
255 |
256 | for pred, idx, index, input, outputs, others, truncation, length in zipped_iter:
257 | if isinstance(pred['text'], str):
258 | pred_text = pred['text']
259 | elif len(pred['text']) > 0:
260 | pred_text = pred['text'][0]
261 | else:
262 | pred_text = ''
263 |
264 | outputs_parallel[idx] = {
265 | 'index': index,
266 | 'pred': pred_text,
267 | 'input': input,
268 | 'outputs': outputs,
269 | 'others': others,
270 | 'truncation': truncation,
271 | 'length': length,
272 | }
273 |
274 | threads = []
275 | outputs_parallel = [{} for _ in range(len(data))]
276 |
277 | batched_data = []
278 | batch = []
279 | for idx, data_point in enumerate(data):
280 | data_point['idx'] = idx
281 |
282 | if len(batch) >= args.batch_size:
283 | batched_data.append(batch)
284 | batch = []
285 |
286 | batch.append(data_point)
287 |
288 | if len(batch):
289 | batched_data.append(batch)
290 |
291 | # setting buffering=1 to force to dump the output after every line, so that we can see intermediate generations
292 | with open(pred_file, 'at', encoding="utf-8", buffering=1) as fout:
293 | # the data is processed sequentially, so we can store the start and end of current processing window
294 | start_idx = 0 # window: [start_idx, end_idx]
295 |
296 | for batch_idx, batch in tqdm(enumerate(batched_data), total=len(batched_data)):
297 | idx_list = [data_point['idx'] for data_point in batch]
298 | end_idx = idx_list[-1] # the data in a batch is ordered
299 |
300 | thread = threading.Thread(
301 | target=get_output,
302 | kwargs=dict(
303 | idx_list=idx_list,
304 | index_list=[data_point['index'] for data_point in batch],
305 | input_list=[data_point['input'] for data_point in batch],
306 | outputs_list=[data_point['outputs'] for data_point in batch],
307 | others_list=[data_point.get('others', {}) for data_point in batch],
308 | truncation_list=[data_point.get('truncation', -1) for data_point in batch],
309 | length_list=[data_point.get('length', -1) for data_point in batch],
310 | ),
311 | )
312 | thread.start()
313 | threads.append(thread)
314 |
315 | is_last_batch = (batch_idx == len(batched_data) - 1)
316 |
317 | if (len(threads) == args.threads) or is_last_batch:
318 | for thread in threads:
319 | thread.join()
320 | threads = []
321 |
322 | # dump the results in current processing window on disk
323 | for idx in range(start_idx, end_idx + 1):
324 | if len(outputs_parallel[idx]) > 0:
325 | fout.write(json.dumps(outputs_parallel[idx]) + '\n')
326 |
327 | start_idx = end_idx + 1
328 |
329 | print(f"Used time: {round((time.time() - start_time) / 60, 1)} minutes")
330 |
331 |
332 | if __name__ == '__main__':
333 | main()
334 |
--------------------------------------------------------------------------------
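Note: each line written to the prediction file follows the schema in the docstring above plus a few bookkeeping fields. An illustrative record (all values are fake):

```
import json

record = {
    "index": 0,
    "pred": "eb4a6c",
    "input": "Some long context ... What is the special magic uuid? Answer:",
    "outputs": ["eb4a6c"],
    "others": {},
    "truncation": -1,
    "length": 4096,
}
print(json.dumps(record))
```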
/scripts/pred/client_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import abc
17 | import json
18 | import multiprocessing
19 | import os
20 | import re
21 | import sys
22 | import time
23 | import requests
24 | import traceback
25 | from pathlib import Path
26 | from typing import List, Tuple, Union
27 | from concurrent.futures import ThreadPoolExecutor
28 | from collections import defaultdict
29 | from tenacity import (
30 | retry,
31 | stop_after_attempt,
32 | wait_random_exponential,
33 | )
34 |
35 |
36 | class Client(abc.ABC):
37 | def __init__(
38 | self,
39 | server_host,
40 | server_port='5000',
41 | ssh_server=None,
42 | ssh_key_path=None,
43 | **generation_kwargs
44 | ):
45 | self.server_host = server_host
46 | self.server_port = server_port
47 | self.ssh_server = os.getenv("SSH_SERVER", ssh_server)
48 | self.ssh_key_path = os.getenv("SSH_KEY_PATH", ssh_key_path)
49 | self.generation_kwargs = generation_kwargs
50 |
51 | @abc.abstractmethod
52 | def _single_call(
53 | self,
54 | prompts,
55 | ):
56 | pass
57 |
58 | def __call__(
59 | self,
60 | prompt: str,
61 | **kwargs
62 | ):
63 | request = self.generation_kwargs
64 | # prompts are added later
65 | request['prompts'] = [f'{prompt}']
66 | if 'others' in kwargs:
67 | request['others'] = kwargs['others']
68 |
69 | outputs = self._single_call(**request)
70 | response = {'text': outputs}
71 | return response
72 |
73 | @retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(3))
74 | def _send_request(self, request, route="generate"):
75 | if self.ssh_server and self.ssh_key_path:
76 | import sshtunnel_requests
77 |
78 | sshtunnel_request = sshtunnel_requests.from_url(f"ssh://{self.ssh_server}:22", self.ssh_key_path)
79 | outputs = sshtunnel_request.put(
80 | url="http://{}:{}/{}".format(self.server_host, self.server_port, route),
81 | data=json.dumps(request),
82 | headers={"Content-Type": "application/json"},
83 | ).json()
84 | else:
85 | outputs = requests.put(
86 | url="http://{}:{}/{}".format(self.server_host, self.server_port, route),
87 | data=json.dumps(request),
88 | headers={"Content-Type": "application/json"},
89 | ).json()
90 | return outputs
91 |
92 | def process_batch(self, prompts: List[str], **kwargs) -> List[dict]:
93 | num_threads = max(96, multiprocessing.cpu_count() * 16)
94 | with ThreadPoolExecutor(num_threads) as executor:
95 | futures = []
96 | for prompt in prompts:
97 | futures.append(
98 | executor.submit(
99 | self.__call__,
100 | prompt,
101 | **kwargs,
102 | )
103 | )
104 | rets = [f.result() for f in futures]
105 | return rets
106 |
107 |
108 | class TRTLLMClient(Client):
109 | def _single_call(
110 | self,
111 | prompts,
112 | tokens_to_generate,
113 | temperature,
114 | top_p,
115 | top_k,
116 | random_seed,
117 | stop: List[str],
118 | max_attention_window_size=None,
119 | ):
120 | request = {
121 | "prompts": prompts,
122 | "tokens_to_generate": tokens_to_generate,
123 | "temperature": temperature,
124 | "top_k": top_k,
125 | "top_p": top_p,
126 | "random_seed": random_seed,
127 | 'stop_words_list': ",".join(stop),
128 | }
129 | if max_attention_window_size:
130 | request["max_attention_window_size"] = max_attention_window_size
131 |
132 | outputs = self._send_request(request)
133 | return outputs
134 |
135 |
136 | class VLLMClient(Client):
137 | def _single_call(
138 | self,
139 | prompts,
140 | tokens_to_generate,
141 | temperature,
142 | top_p,
143 | top_k,
144 | random_seed,
145 | stop: List[str],
146 | ):
147 | request = {
148 | "prompt": prompts[0],
149 | "max_tokens": tokens_to_generate,
150 | "temperature": temperature,
151 | "top_k": top_k,
152 | "top_p": top_p,
153 | "stop": stop,
154 | }
155 | # TODO: random seed is not supported?
156 | outputs = self._send_request(request)
157 | outputs = outputs['text']
158 | return outputs
159 |
160 |
161 | class SGLClient(Client):
162 | def _single_call(
163 | self,
164 | prompts,
165 | tokens_to_generate,
166 | temperature,
167 | top_p,
168 | top_k,
169 | random_seed,
170 | stop: List[str],
171 | ):
172 | request = {
173 | "text": prompts[0],
174 | "sampling_params": {
175 | "max_new_tokens": tokens_to_generate,
176 | "temperature": temperature,
177 | "top_k": top_k,
178 | "top_p": top_p,
179 | "stop": stop,
180 | }
181 | }
182 | # TODO: random seed is not supported?
183 | outputs = self._send_request(request)
184 | outputs = outputs['text']
185 | return outputs
186 |
187 |
188 | class OpenAIClient:
189 | def __init__(
190 | self,
191 | model_name,
192 | **generation_kwargs
193 | ):
194 | model2length = {
195 | # OpenAI
196 | 'gpt-4': 8192,
197 | 'gpt-4-0613': 8192,
198 | 'gpt-4-1106-preview': 128000,
199 | 'gpt-4-0125-preview': 128000,
200 | 'gpt-4-turbo-preview': 128000,
201 | 'gpt-3.5-turbo-0125': 16385,
202 | 'gpt-3.5-turbo-1106': 16385,
203 | 'gpt-3.5-turbo-0613': 4096,
204 | 'gpt-3.5-turbo': 16385,
205 | 'gpt-3.5-turbo-16k': 16385,
206 | 'gpt-3.5-turbo-16k-0613': 16385,
207 |
208 | # Azure
209 | 'gpt-4-32k': 32768,
210 | 'gpt-4': 128000,
211 | 'gpt-35-turbo-16k': 16384,
212 | }
213 |         self.openai_api_key = os.environ.get("OPENAI_API_KEY")
214 |         self.azure_api_id = os.environ.get("AZURE_API_ID")
215 |         self.azure_api_secret = os.environ.get("AZURE_API_SECRET")
216 |         self.azure_api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
217 | self.model_name = model_name
218 |
219 | # Azure
220 | if self.azure_api_id and self.azure_api_secret:
221 | if 'gpt-3.5' in model_name: self.model_name = 'gpt-35-turbo-16k'
222 | if 'gpt-4' in model_name: self.model_name = 'gpt-4'
223 |
224 | import tiktoken
225 | self.encoding = tiktoken.get_encoding("cl100k_base")
226 | self.max_length = model2length[self.model_name]
227 | self.generation_kwargs = generation_kwargs
228 | self._create_client()
229 |
230 | def _create_client(self,):
231 | from openai import OpenAI, AzureOpenAI
232 |
233 | # OpenAI
234 | if self.openai_api_key:
235 | self.client = OpenAI(
236 | api_key=self.openai_api_key
237 | )
238 |
239 | # Azure
240 | elif self.azure_api_id and self.azure_api_secret:
241 | self.client = AzureOpenAI(
242 | api_key=self.get_azure_api_key(
243 | self.azure_api_id,
244 | self.azure_api_secret,
245 | self.azure_api_endpoint,
246 | ),
247 | api_version="2024-02-15-preview",
248 | azure_endpoint=os.path.join(self.azure_api_endpoint, "llm/v1/azure"),
249 | )
250 |
251 | def _count_tokens(self, messages):
252 | tokens_per_message = 3
253 | tokens_per_name = 1
254 | num_tokens = 0
255 | for message in messages:
256 | num_tokens += tokens_per_message
257 | for key, value in message.items():
258 | num_tokens += len(self.encoding.encode(value))
259 | if key == "name":
260 | num_tokens += tokens_per_name
261 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
262 | return num_tokens
263 |
264 | @retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(3))
265 | def _send_request(self, request):
266 | try:
267 | response = self.client.chat.completions.create(
268 | model=self.model_name,
269 | messages=request['msgs'],
270 | max_tokens=request['tokens_to_generate'],
271 | temperature=request['temperature'],
272 | seed=request['random_seed'],
273 | top_p=request['top_p'],
274 | stop=request['stop'],
275 | )
276 |         except Exception as e:
277 |             print(f"Error occurred while calling OpenAI: {e}")
278 |             if self.azure_api_id and self.azure_api_secret and getattr(e, 'status_code', None) == 401:
279 |                 # token expired, recreate the client before the retry decorator fires again
280 |                 self._create_client()
281 |             raise
282 | return response
283 |
284 | def __call__(
285 | self,
286 | prompt: str,
287 | ):
288 | # system_msg = [{"role": "system", "content": ""}]
289 | system_msg = []
290 | user_assistant_msgs = [{"role": "user", "content": prompt}]
291 | msgs = system_msg + user_assistant_msgs
292 | openai_length = self._count_tokens(msgs)
293 | request = self.generation_kwargs
294 |
295 | tokens_to_generate_new = self.max_length - openai_length
296 | if tokens_to_generate_new < request['tokens_to_generate']:
297 | print(f"Reduce generate tokens from {request['tokens_to_generate']} to {tokens_to_generate_new}")
298 | request['tokens_to_generate'] = tokens_to_generate_new
299 |
300 | request["msgs"] = msgs
301 | outputs = self._send_request(request)
302 | response = {'text': [outputs.choices[0].message.content]}
303 | return response
304 |
305 |
306 | def get_azure_api_key(
307 | self,
308 | p_client_id,
309 | p_client_secret,
310 | p_token_url,
311 | p_scope="azureopenai-readwrite",
312 | cache_file="azure_openai_key.json"
313 | ):
314 | base_path = Path(__file__).parent
315 | file_path = Path.joinpath(base_path, cache_file)
316 |
317 | # Check if the token is cached
318 | renew = True
319 | if os.path.exists(file_path):
320 | with open(file_path, "r") as f:
321 | token = json.load(f)
322 | renew = True if time.time() > token["expires_in"] else False
323 |
324 | if renew:
325 | # Get a new token from the OAuth server
326 | response = requests.post(
327 | os.path.join(p_token_url, "oauth/api/v1/ssa/default/token"),
328 | data={"grant_type": "client_credentials", "client_id": p_client_id,
329 | "client_secret": p_client_secret, "scope": p_scope}
330 | )
331 | response.raise_for_status()
332 | token = response.json()
333 | token["expires_in"] += time.time()
334 | with open(file_path, "w") as f:
335 | json.dump(token, f)
336 |
337 |
338 | authToken = token["access_token"]
339 | return authToken
340 |
341 |
342 | class GeminiClient:
343 | def __init__(
344 | self,
345 | model_name,
346 | **generation_kwargs
347 | ):
348 | model2length = {
349 | 'gemini-1.0-pro-latest': (30720, 2048),
350 | 'gemini-1.5-pro-latest': (1048576, 8192)
351 | }
352 |
353 | self.model_name = model_name
354 | self.model = self._initialize_model()
355 | self.max_input_length = model2length[model_name][0]
356 | self.max_output_length = model2length[model_name][1]
357 | assert generation_kwargs['tokens_to_generate'] < self.max_output_length, \
358 |             f'tokens_to_generate exceeds {self.max_output_length}'
359 |
360 | import google.generativeai as genai
361 | self.config = genai.GenerationConfig(
362 | candidate_count=1,
363 | stop_sequences=generation_kwargs['stop'],
364 | max_output_tokens=generation_kwargs['tokens_to_generate'],
365 | temperature=generation_kwargs['temperature'],
366 | top_p=generation_kwargs['top_p'],
367 | top_k=generation_kwargs['top_k'],
368 | )
369 |
370 | from google.generativeai.types import HarmCategory, HarmBlockThreshold
371 | self.safety_settings = {
372 | HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
373 | HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
374 | HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
375 | HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
376 | }
377 |
378 | @retry(wait=wait_random_exponential(min=60, max=60), stop=stop_after_attempt(3))
379 | def _send_request(self, request):
380 | try:
381 | response = self.model.generate_content(request['prompt'],
382 | generation_config=request['config'],
383 | safety_settings=self.safety_settings)
384 | except Exception as e:
385 | traceback.print_exc()
386 | return None
387 | return response
388 |
389 | def __call__(
390 | self,
391 | prompt: str,
392 | ):
393 | assert self.model.count_tokens(prompt).total_tokens < self.max_input_length, \
394 |             f'input length exceeds {self.max_input_length}'
395 |
396 | request = {
397 | 'prompt': prompt,
398 | 'config': self.config,
399 | }
400 |
401 | outputs = self._send_request(request)
402 |
403 | try:
404 | response = {'text': [outputs.candidates[0].content.parts[0].text]}
405 | except Exception as e:
406 | response = {'text': []}
407 | print(outputs)
408 | traceback.print_exc()
409 |
410 | return response
411 |
412 | def _initialize_model(self):
413 | import google.generativeai as genai
414 | genai.configure(api_key=os.environ["GEMINI_API_KEY"])
415 | return genai.GenerativeModel(self.model_name)
416 |
417 |
--------------------------------------------------------------------------------
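Note: supporting another HTTP backend only requires subclassing `Client` and implementing `_single_call`. A hedged sketch (the payload field names are hypothetical; adapt them to the target server's API):

```
class MyServerClient(Client):
    def _single_call(self, prompts, tokens_to_generate, temperature,
                     top_p, top_k, random_seed, stop):
        # Field names below are placeholders for the target server's API.
        request = {
            "prompt": prompts[0],
            "max_new_tokens": tokens_to_generate,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "seed": random_seed,
            "stop": stop,
        }
        return self._send_request(request, route="generate")["text"]
```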
/scripts/pred/model_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | import logging
17 | import requests
18 | import torch
19 | from typing import Dict, List, Optional
20 |
21 |
22 | class HuggingFaceModel:
23 | def __init__(self, name_or_path: str, **generation_kwargs) -> None:
24 | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
25 |
26 | self.tokenizer = AutoTokenizer.from_pretrained(name_or_path, trust_remote_code=True)
27 |
28 | if 'Yarn-Llama' in name_or_path:
29 | model_kwargs = None
30 | else:
31 | model_kwargs = {"attn_implementation": "flash_attention_2"}
32 |
33 | try:
34 | self.pipeline = pipeline(
35 | "text-generation",
36 | model=name_or_path,
37 | tokenizer=self.tokenizer,
38 | trust_remote_code=True,
39 | device_map="auto",
40 | torch_dtype=torch.bfloat16,
41 | model_kwargs=model_kwargs,
42 | )
43 |         except Exception:
44 | self.pipeline = None
45 | self.model = AutoModelForCausalLM.from_pretrained(name_or_path, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16,)
46 |
47 | self.generation_kwargs = generation_kwargs
48 | self.stop = self.generation_kwargs.pop('stop')
49 |
50 | if self.tokenizer.pad_token is None:
51 | # add pad token to allow batching (known issue for llama2)
52 | self.tokenizer.padding_side = 'left'
53 | self.tokenizer.pad_token = self.tokenizer.eos_token
54 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
55 |
56 |
57 | def __call__(self, prompt: str, **kwargs) -> dict:
58 | return self.process_batch([prompt], **kwargs)[0]
59 |
60 | def process_batch(self, prompts: List[str], **kwargs) -> List[dict]:
61 | if self.pipeline is None:
62 | inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.model.device)
63 | generated_ids = self.model.generate(
64 | **inputs,
65 | **self.generation_kwargs
66 | )
67 | generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
68 | else:
69 | output = self.pipeline(text_inputs=prompts, **self.generation_kwargs, )
70 | assert len(output) == len(prompts)
71 | # output in the form of a list of list of dictionaries
72 | # outer list len = batch size
73 | # inner list len = 1
74 | generated_texts = [llm_result[0]["generated_text"] for llm_result in output]
75 |
76 | results = []
77 |
78 | for text, prompt in zip(generated_texts, prompts):
79 |             # remove the input from the generated text
80 | # This is a workaround for the llama3 tokenizer not being able to reproduce the same prompt after tokenization
81 |             # see Issue https://github.com/NVIDIA/RULER/issues/54 for explanation
82 | if self.pipeline is None:
83 | tokenized_prompt = self.tokenizer(prompt, return_tensors="pt", padding=True)
84 | prompt = self.tokenizer.decode(tokenized_prompt.input_ids[0], skip_special_tokens=True)
85 | if text.startswith(prompt):
86 | text = text[len(prompt):]
87 |
88 | if self.stop is not None:
89 | for s in self.stop:
90 | text = text.split(s)[0]
91 |
92 | results.append({'text': [text]})
93 |
94 | return results
95 |
96 |
97 | class MambaModel:
98 | def __init__(self, name_or_path: str, **generation_kwargs) -> None:
99 | from transformers import AutoTokenizer
100 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
101 |
102 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
103 | self.device = "cuda"
104 | self.model = MambaLMHeadModel.from_pretrained(name_or_path, device=self.device, dtype=torch.bfloat16)
105 | self.generation_kwargs = generation_kwargs
106 | self.stop = self.generation_kwargs.pop('stop')
107 | self.max_genlen = self.generation_kwargs.pop('max_new_tokens')
108 | self.minp = 0.0
109 |
110 | def __call__(self, prompt: str, **kwargs) -> Dict[str, List[str]]:
111 | # tokenize
112 | tokens = self.tokenizer(prompt, return_tensors="pt")
113 | input_ids = tokens.input_ids.to(self.device)
114 | max_length = input_ids.shape[1] + self.max_genlen
115 |
116 | # generate
117 | out = self.model.generate(
118 | input_ids=input_ids,
119 | max_length=max_length,
120 | cg=True,
121 | return_dict_in_generate=True,
122 | output_scores=True,
123 | enable_timing=False,
124 | **self.generation_kwargs,
125 | )
126 | assert len(out.sequences) == 1
127 | # detok
128 | return {'text': [self.tokenizer.decode(out.sequences[0][input_ids.shape[1]:])]}
129 |
130 | def process_batch(self, prompts: List[str], **kwargs) -> List[dict]:
131 | # FIXME: naive implementation
132 | return [self.__call__(prompt, **kwargs) for prompt in prompts]
133 |
--------------------------------------------------------------------------------
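Note: a minimal local-inference sketch using `HuggingFaceModel` outside of `call_api.py` (the checkpoint name and generation settings are placeholders that mirror the kwargs `call_api.py` passes in):

```
llm = HuggingFaceModel(
    name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder checkpoint
    do_sample=False,
    repetition_penalty=1,
    stop=["<|eot_id|>"],
    max_new_tokens=32,
)
print(llm("What is 2 + 2? Answer:")['text'][0])
```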
/scripts/pred/serve_trt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # adapted from https://github.com/Kipok/NeMo-Skills/blob/v0.1/nemo_skills/inference/server/serve_trt.py
16 |
17 | import json
18 | import logging
19 | import sys
20 |
21 | from pathlib import Path
22 | from argparse import ArgumentParser
23 |
24 | import numpy as np
25 | import torch
26 | import tensorrt_llm
27 | from flask import Flask, jsonify, request
28 | from flask_restful import Api, Resource
29 | from tensorrt_llm.runtime import ModelRunnerCpp
30 | from mpi4py import MPI
31 | from transformers import AutoTokenizer
32 |
33 |
34 | class TritonServerGenerate(Resource):
35 | def __init__(self, model):
36 | self.model = model
37 | self.comm = MPI.COMM_WORLD
38 |
39 | def generate(
40 | self,
41 | prompts,
42 | max_new_tokens,
43 | temperature,
44 | top_k,
45 | top_p,
46 | repetition_penalty,
47 | random_seed,
48 | stop_words_list,
49 | max_attention_window_size=None
50 | ):
51 | output = self.model.forward(
52 | prompts,
53 | max_output_token=max_new_tokens,
54 | top_k=top_k,
55 | top_p=top_p,
56 | temperature=temperature,
57 | repetition_penalty=repetition_penalty,
58 | random_seed=random_seed,
59 | stop_words_list=stop_words_list,
60 | max_attention_window_size=max_attention_window_size,
61 | )
62 | return output
63 |
64 | def put(self):
65 | logging.info("request IP: " + str(request.remote_addr))
66 | logging.info(json.dumps(request.get_json()))
67 |
68 | input_request = request.get_json()
69 |
70 | tokens_to_generate = input_request.get("tokens_to_generate", 64)
71 | temperature = input_request.get("temperature", 1.0)
72 | top_k = input_request.get("top_k", 0)
73 | top_p = input_request.get("top_p", 1.0)
74 | repetition_penalty = input_request.get("repetition_penalty", 1.2)
75 | stop_words_list = input_request.get("stop_words_list")
76 | max_attention_window_size = input_request.get("max_attention_window_size")
77 | random_seed = input_request.get("random_seed", 0)
78 | prompts = input_request["prompts"]
79 |
80 | data = dict(
81 | prompts=prompts,
82 | max_new_tokens=tokens_to_generate,
83 | temperature=temperature,
84 | top_k=top_k,
85 | top_p=top_p,
86 | repetition_penalty=repetition_penalty,
87 | random_seed=random_seed,
88 | stop_words_list=stop_words_list,
89 | max_attention_window_size=max_attention_window_size,
90 | )
91 | self.comm.Barrier()
92 | data = self.comm.bcast(data, root=0)
93 |
94 | out = self.generate(**data)
95 | return jsonify(out)
96 |
97 |
98 | def parse_input(input_texts: list, tokenizer):
99 | batch_input_ids = [
100 | tokenizer.encode(
101 | input_text,
102 | add_special_tokens=False, # TODO: does this need to be true?
103 | )
104 | for input_text in input_texts
105 | ]
106 | batch_input_ids = [torch.tensor(x, dtype=torch.int32, device="cuda") for x in batch_input_ids]
107 | input_lengths = [x.size(0) for x in batch_input_ids]
108 |
109 | return batch_input_ids, input_lengths
110 |
111 |
112 | def get_output(output_ids, input_lengths, max_output_len, tokenizer, eos_token):
113 | num_beams = output_ids.size(1)
114 | assert num_beams == 1
115 | output_texts = []
116 | for idx, input_len in enumerate(input_lengths):
117 | output_begin = input_len
118 | output_end = input_len + max_output_len
119 | outputs = output_ids[idx][0][output_begin:output_end]
120 | eos_ids = (outputs == eos_token).nonzero(as_tuple=True)[-1]
121 | if len(eos_ids) > 0:
122 | outputs = outputs[: eos_ids[0]]
123 | outputs = outputs.tolist()
124 | output_texts.append(tokenizer.decode(outputs))
125 | return output_texts
126 |
127 |
128 | def prepare_stop_words(stop_words_list, tokenizer):
129 | # adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/b310ec675145c9ee7668592549f733df4abf1e94/tensorrt_llm/runtime/generation.py#L46
130 | flat_ids = []
131 | offsets = []
132 | for batch_stop_words in stop_words_list:
133 | item_flat_ids = []
134 | item_offsets = []
135 |
136 | for word in batch_stop_words:
137 | # there is a known issue in TensorRT-LLM that word ids are not unique and might change depending on
138 | # where in the text it appears. In our case we mainly need to stop on ids as they appear in the middle
139 | # of the text. The following is a workaround to get such ids that works for both kind of stop
140 | # words as well as newlines that we commonly use. But note that it's not a universal fix, so this might
141 | # require refactoring if different stop words are used in the future.
142 | # Eventually, this needs to be fixed inside TensorRT-LLM itself.
143 | ids = tokenizer.encode('magic' + word)
144 | ids = ids[2:] # skipping "magic"
145 |
146 | if len(ids) == 0:
147 | continue
148 |
149 | item_flat_ids += ids
150 | item_offsets.append(len(ids))
151 |
152 | flat_ids.append(np.array(item_flat_ids))
153 | offsets.append(np.cumsum(np.array(item_offsets)))
154 |
155 | pad_to = max(1, max(len(ids) for ids in flat_ids))
156 |
157 | for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
158 | flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
159 | offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)
160 |
161 | stop_words = np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))
162 | return torch.Tensor(stop_words).to(torch.int32).to("cuda").contiguous()
163 |
164 |
165 | def load_tokenizer(tokenizer_dir: str):
166 | tokenizer = AutoTokenizer.from_pretrained(
167 | tokenizer_dir,
168 | legacy=False,
169 | trust_remote_code=True,
170 | )
171 |
172 | if tokenizer.pad_token_id is None:
173 | tokenizer.pad_token_id = tokenizer.eos_token_id
174 | pad_id = tokenizer.pad_token_id
175 | end_id = tokenizer.eos_token_id
176 |
177 | return tokenizer, pad_id, end_id
178 |
179 |
180 |
181 | class TensorRTLLM:
182 | def __init__(self, model_path: str):
183 | self.tokenizer, self.pad_id, self.end_id = load_tokenizer(tokenizer_dir=model_path)
184 | self.runner = ModelRunnerCpp.from_dir(engine_dir=model_path, rank=tensorrt_llm.mpi_rank())
185 |
186 | @torch.no_grad()
187 | def forward(
188 | self,
189 | input_texts,
190 | max_output_token,
191 | top_k,
192 | top_p,
193 | temperature,
194 | repetition_penalty,
195 | random_seed,
196 | stop_words_list,
197 | max_attention_window_size,
198 | ):
199 | batch_input_ids, input_lengths = parse_input(input_texts, self.tokenizer)
200 |
201 | stop_words_list = [stop_words_list for _ in range(len(input_texts))]
202 | stop_words_list = prepare_stop_words(stop_words_list, self.tokenizer)
203 |
204 | # TODO: return dictionary with a proper error reporting
205 | try:
206 | output_ids = self.runner.generate(
207 | batch_input_ids,
208 | max_new_tokens=max_output_token,
209 | end_id=self.end_id,
210 | pad_id=self.pad_id,
211 | temperature=temperature,
212 | top_k=top_k,
213 | top_p=top_p,
214 | repetition_penalty=repetition_penalty,
215 | random_seed=random_seed,
216 | stop_words_list=stop_words_list,
217 | max_attention_window_size=max_attention_window_size,
218 | return_dict=False,
219 | )
220 | torch.cuda.synchronize()
221 |
222 | output = get_output(output_ids, input_lengths, max_output_token, self.tokenizer, self.end_id)
223 | except RuntimeError as e:
224 | logging.error("RuntimeError: %s", e)
225 | output = [f"RuntimeError: {e}"] * len(input_texts)
226 |
227 | return output
228 |
229 |
230 | class WrapperServer:
231 | def __init__(self, model_path: str):
232 | self.comm = MPI.COMM_WORLD
233 | self.rank = self.comm.Get_rank()
234 |
235 | self.model = TensorRTLLM(model_path=model_path)
236 |
237 | if self.rank == 0:
238 | self.app = Flask(__file__, static_url_path="")
239 | api = Api(self.app)
240 | api.add_resource(TritonServerGenerate, "/generate", resource_class_args=[self.model])
241 |
242 | def run(self, url, port=5000):
243 | if self.rank == 0:
244 | self.app.run(url, threaded=True, port=port, debug=False)
245 | else:
246 | self.worker_loop()
247 |
248 | def worker_loop(self):
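    |         # Non-zero MPI ranks stay in this loop: rank 0 (which runs the Flask app) broadcasts each
    |         # request's generation arguments, and every rank then calls generate() so that all
    |         # tensor-parallel workers take part in the same forward pass.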
249 | triton = TritonServerGenerate(self.model)
250 | while True:
251 | self.comm.Barrier()
252 | data = None
253 | data = self.comm.bcast(data, root=0)
254 | triton.generate(**data)
255 |
256 |
257 | if __name__ == "__main__":
258 | # TODO: can we reuse normal logger here?
259 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
260 |
261 | parser = ArgumentParser()
262 | parser.add_argument("--model_path", required=True)
263 | parser.add_argument("--host", type=str, default="0.0.0.0")
264 | parser.add_argument("--port", type=int, default=5000)
265 | args = parser.parse_args()
266 |
267 | server = WrapperServer(model_path=args.model_path)
268 | server.run(args.host, args.port)
269 |
--------------------------------------------------------------------------------
/scripts/pred/serve_vllm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # adapted from https://github.com/vllm-project/vllm/blob/v0.4.0/vllm/entrypoints/api_server.py
16 |
17 |
18 | import json
19 | from typing import AsyncGenerator
20 |
21 | from fastapi import FastAPI, Request
22 | from fastapi.responses import JSONResponse, Response, StreamingResponse
23 | import uvicorn
24 |
25 | from vllm.engine.arg_utils import AsyncEngineArgs
26 | from vllm.engine.async_llm_engine import AsyncLLMEngine
27 | from vllm.sampling_params import SamplingParams
28 | from vllm.utils import random_uuid
29 | from vllm.utils import FlexibleArgumentParser
30 |
31 | TIMEOUT_KEEP_ALIVE = 5 # seconds.
32 | app = FastAPI()
33 | engine = None
34 |
35 |
36 | @app.get("/health")
37 | async def health() -> Response:
38 | """Health check."""
39 | return Response(status_code=200)
40 |
41 |
42 | @app.put("/generate")
43 | async def generate(request: Request) -> Response:
44 | """Generate completion for the request.
45 |
46 | The request should be a JSON object with the following fields:
47 | - prompt: the prompt to use for the generation.
48 | - stream: whether to stream the results or not.
49 | - other fields: the sampling parameters (See `SamplingParams` for details).
50 | """
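   |     # Illustrative request body (fields other than "prompt" and "stream" are forwarded to
   |     # SamplingParams, e.g. temperature, top_p, max_tokens):
   |     #   {"prompt": "The capital of France is", "stream": false, "max_tokens": 16}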
51 | request_dict = await request.json()
52 | prompt = request_dict.pop("prompt")
53 | stream = request_dict.pop("stream", False)
54 | sampling_params = SamplingParams(**request_dict)
55 | request_id = random_uuid()
56 |
57 | results_generator = engine.generate(prompt,
58 | sampling_params,
59 | request_id)
60 |
61 | # Streaming case
62 | async def stream_results() -> AsyncGenerator[bytes, None]:
63 | async for request_output in results_generator:
64 | prompt = request_output.prompt
65 | text_outputs = [
66 | prompt + output.text for output in request_output.outputs
67 | ]
68 | ret = {"text": text_outputs}
69 | yield (json.dumps(ret) + "\0").encode("utf-8")
70 |
71 | if stream:
72 | return StreamingResponse(stream_results())
73 |
74 | # Non-streaming case
75 | final_output = None
76 | async for request_output in results_generator:
77 | if await request.is_disconnected():
78 | # Abort the request if the client disconnects.
79 | await engine.abort(request_id)
80 | return Response(status_code=499)
81 | final_output = request_output
82 | assert final_output is not None
83 | text_outputs = [output.text for output in final_output.outputs]
84 | ret = {"text": text_outputs}
85 | return JSONResponse(ret)
86 |
87 |
88 | if __name__ == "__main__":
89 | parser = FlexibleArgumentParser(
90 | description="vLLM API server for serving LLMs asynchronously",
91 | epilog="For more information, see https://docs.vllm.ai/en/stable/configuration/engine_args.html",
92 | )
93 | parser.add_argument("--host", type=str, default="0.0.0.0")
94 | parser.add_argument("--port", type=int, default=5000)
95 | parser.add_argument("--ssl-keyfile", type=str, default=None)
96 | parser.add_argument("--ssl-certfile", type=str, default=None)
97 | parser.add_argument(
98 | "--root-path",
99 | type=str,
100 | default=None,
101 | help="FastAPI root_path when app is behind a path based routing proxy")
102 | parser = AsyncEngineArgs.add_cli_args(parser)
103 | args = parser.parse_args()
104 |
105 | engine_args = AsyncEngineArgs.from_cli_args(args)
106 | engine = AsyncLLMEngine.from_engine_args(engine_args)
107 |
108 | app.root_path = args.root_path
109 | uvicorn.run(app,
110 | host=args.host,
111 | port=args.port,
112 | log_level="debug",
113 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
114 | ssl_keyfile=args.ssl_keyfile,
115 | ssl_certfile=args.ssl_certfile)
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # container: docker.io/cphsieh/ruler:0.1.0
17 | # bash run.sh MODEL_NAME BENCHMARK_NAME
18 |
19 | if [ $# -ne 2 ]; then
20 |     echo "Usage: $0 MODEL_NAME BENCHMARK_NAME"
21 | exit 1
22 | fi
23 |
24 |
25 | # Root Directories
26 | GPUS="1" # number of GPUs used for tensor parallelism.
27 | ROOT_DIR="benchmark_root" # the path that stores generated task samples and model predictions.
28 | MODEL_DIR="../.." # the path that contains individual model folders from HuggingFace.
29 | ENGINE_DIR="." # the path that contains individual engine folders from TensorRT-LLM.
30 | BATCH_SIZE=1 # increase to improve GPU utilization
31 |
32 |
33 | # Model and Tokenizer
34 | source config_models.sh
35 | MODEL_NAME=${1}
36 | MODEL_CONFIG=$(MODEL_SELECT ${MODEL_NAME} ${MODEL_DIR} ${ENGINE_DIR})
37 | IFS=":" read MODEL_PATH MODEL_TEMPLATE_TYPE MODEL_FRAMEWORK TOKENIZER_PATH TOKENIZER_TYPE OPENAI_API_KEY GEMINI_API_KEY AZURE_ID AZURE_SECRET AZURE_ENDPOINT <<< "$MODEL_CONFIG"
38 | if [ -z "${MODEL_PATH}" ]; then
39 | echo "Model: ${MODEL_NAME} is not supported"
40 | exit 1
41 | fi
42 |
43 |
44 | export OPENAI_API_KEY=${OPENAI_API_KEY}
45 | export GEMINI_API_KEY=${GEMINI_API_KEY}
46 | export AZURE_API_ID=${AZURE_ID}
47 | export AZURE_API_SECRET=${AZURE_SECRET}
48 | export AZURE_API_ENDPOINT=${AZURE_ENDPOINT}
49 |
50 |
51 | # Benchmark and Tasks
52 | source config_tasks.sh
53 | BENCHMARK=${2}
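   | # "declare -n" makes TASKS a nameref: it resolves to the array whose name is stored in $BENCHMARK
   | # (e.g. a "synthetic" task array expected to be defined in config_tasks.sh).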
54 | declare -n TASKS=$BENCHMARK
55 | if [ -z "${TASKS}" ]; then
56 | echo "Benchmark: ${BENCHMARK} is not supported"
57 | exit 1
58 | fi
59 |
60 |
61 | # Start server (you may want to run it in a separate container).
62 | if [ "$MODEL_FRAMEWORK" == "vllm" ]; then
63 | python pred/serve_vllm.py \
64 | --model=${MODEL_PATH} \
65 | --tensor-parallel-size=${GPUS} \
66 | --dtype bfloat16 \
67 | --disable-custom-all-reduce \
68 | &
69 |
70 | elif [ "$MODEL_FRAMEWORK" == "trtllm" ]; then
71 | python pred/serve_trt.py \
72 | --model_path=${MODEL_PATH} \
73 | &
74 |
75 | elif [ "$MODEL_FRAMEWORK" == "sglang" ]; then
76 | python -m sglang.launch_server \
77 | --model-path ${MODEL_PATH} \
78 | --tp ${GPUS} \
79 | --port 5000 \
80 | --enable-flashinfer \
81 | &
82 | # use sglang/test/killall_sglang.sh to kill sglang server if it hangs
83 |
84 | fi
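   | # Note: all three backends end up listening on port 5000 (the default of serve_vllm.py and
   | # serve_trt.py, and --port 5000 for sglang), which the prediction client is assumed to target.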
85 |
86 |
87 | # Start client (prepare data / call model API / obtain final metrics)
88 | total_time=0
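   | # Results are written per sequence length, e.g. benchmark_root/<model>/<benchmark>/4096/{data,pred},
   | # assembled from RESULTS_DIR / DATA_DIR / PRED_DIR below.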
89 | for MAX_SEQ_LENGTH in "${SEQ_LENGTHS[@]}"; do
90 |
91 | RESULTS_DIR="${ROOT_DIR}/${MODEL_NAME}/${BENCHMARK}/${MAX_SEQ_LENGTH}"
92 | DATA_DIR="${RESULTS_DIR}/data"
93 | PRED_DIR="${RESULTS_DIR}/pred"
94 | mkdir -p ${DATA_DIR}
95 | mkdir -p ${PRED_DIR}
96 |
97 | for TASK in "${TASKS[@]}"; do
98 | python data/prepare.py \
99 | --save_dir ${DATA_DIR} \
100 | --benchmark ${BENCHMARK} \
101 | --task ${TASK} \
102 | --tokenizer_path ${TOKENIZER_PATH} \
103 | --tokenizer_type ${TOKENIZER_TYPE} \
104 | --max_seq_length ${MAX_SEQ_LENGTH} \
105 | --model_template_type ${MODEL_TEMPLATE_TYPE} \
106 | --num_samples ${NUM_SAMPLES} \
107 | ${REMOVE_NEWLINE_TAB}
108 |
109 | start_time=$(date +%s)
110 | python pred/call_api.py \
111 | --data_dir ${DATA_DIR} \
112 | --save_dir ${PRED_DIR} \
113 | --benchmark ${BENCHMARK} \
114 | --task ${TASK} \
115 | --server_type ${MODEL_FRAMEWORK} \
116 | --model_name_or_path ${MODEL_PATH} \
117 | --temperature ${TEMPERATURE} \
118 | --top_k ${TOP_K} \
119 | --top_p ${TOP_P} \
120 | --batch_size ${BATCH_SIZE} \
121 | ${STOP_WORDS}
122 | end_time=$(date +%s)
123 | time_diff=$((end_time - start_time))
124 | total_time=$((total_time + time_diff))
125 | done
126 |
127 | python eval/evaluate.py \
128 | --data_dir ${PRED_DIR} \
129 | --benchmark ${BENCHMARK}
130 | done
131 |
132 | echo "Total time spent on call_api: $total_time seconds"
133 |
--------------------------------------------------------------------------------
/scripts/synthetic.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
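   | # Each top-level key below defines one benchmark task: "task" selects the generator script
   | # (e.g. niah.py, variable_tracking.py, qa.py under scripts/data/synthetic) and "args" are forwarded to it.
   | # For the niah tasks, type_haystack picks the filler text (noise / essay / needle), type_needle_k/v
   | # set the format of the needle keys and values, and num_needle_k/v/q set how many keys, values, and
   | # queried needles are used.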
15 | niah_single_1:
16 | task: niah
17 | args:
18 | type_haystack: noise
19 | type_needle_k: words
20 | type_needle_v: numbers
21 | num_needle_k: 1
22 | num_needle_v: 1
23 | num_needle_q: 1
24 |
25 | niah_single_2:
26 | task: niah
27 | args:
28 | type_haystack: essay
29 | type_needle_k: words
30 | type_needle_v: numbers
31 | num_needle_k: 1
32 | num_needle_v: 1
33 | num_needle_q: 1
34 |
35 | niah_single_3:
36 | task: niah
37 | args:
38 | type_haystack: essay
39 | type_needle_k: words
40 | type_needle_v: uuids
41 | num_needle_k: 1
42 | num_needle_v: 1
43 | num_needle_q: 1
44 |
45 | niah_multikey_1:
46 | task: niah
47 | args:
48 | type_haystack: essay
49 | type_needle_k: words
50 | type_needle_v: numbers
51 | num_needle_k: 4
52 | num_needle_v: 1
53 | num_needle_q: 1
54 |
55 | niah_multikey_2:
56 | task: niah
57 | args:
58 | type_haystack: needle
59 | type_needle_k: words
60 | type_needle_v: numbers
61 | num_needle_k: 1
62 | num_needle_v: 1
63 | num_needle_q: 1
64 |
65 | niah_multikey_3:
66 | task: niah
67 | args:
68 | type_haystack: needle
69 | type_needle_k: uuids
70 | type_needle_v: uuids
71 | num_needle_k: 1
72 | num_needle_v: 1
73 | num_needle_q: 1
74 |
75 | niah_multivalue:
76 | task: niah
77 | args:
78 | type_haystack: essay
79 | type_needle_k: words
80 | type_needle_v: numbers
81 | num_needle_k: 1
82 | num_needle_v: 4
83 | num_needle_q: 1
84 |
85 | niah_multiquery:
86 | task: niah
87 | args:
88 | type_haystack: essay
89 | type_needle_k: words
90 | type_needle_v: numbers
91 | num_needle_k: 1
92 | num_needle_v: 1
93 | num_needle_q: 4
94 |
95 | vt:
96 | task: variable_tracking
97 | args:
98 | type_haystack: noise
99 | num_chains: 1
100 | num_hops: 4
101 |
102 | cwe:
103 | task: common_words_extraction
104 | args:
105 | freq_cw: 30
106 | freq_ucw: 3
107 | num_cw: 10
108 |
109 | fwe:
110 | task: freq_words_extraction
111 | args:
112 | alpha: 2.0
113 |
114 | qa_1:
115 | task: qa
116 | args:
117 | dataset: squad
118 |
119 | qa_2:
120 | task: qa
121 | args:
122 | dataset: hotpotqa
--------------------------------------------------------------------------------