├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── docker ├── Dockerfile └── requirements.txt └── scripts ├── config_models.sh ├── config_tasks.sh ├── data ├── prepare.py ├── synthetic │ ├── common_words_extraction.py │ ├── constants.py │ ├── freq_words_extraction.py │ ├── json │ │ ├── PaulGrahamEssays_URLs.txt │ │ ├── download_paulgraham_essay.py │ │ ├── download_qa_dataset.sh │ │ └── english_words.json │ ├── niah.py │ ├── qa.py │ └── variable_tracking.py ├── template.py └── tokenizer.py ├── eval ├── evaluate.py └── synthetic │ └── constants.py ├── pred ├── call_api.py ├── client_wrappers.py ├── model_wrappers.py ├── serve_trt.py └── serve_vllm.py ├── run.sh └── synthetic.yaml /.gitattributes: -------------------------------------------------------------------------------- 1 | *.jsonl filter=lfs diff=lfs merge=lfs -text 2 | *.jsonls filter=lfs diff=lfs merge=lfs -text 3 | *.json filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | *.jsonl 4 | .vscode/ 5 | *.out 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📏 RULER: What’s the Real Context Size of Your Long-Context Language Models? 2 | 3 | This repository contains code for our paper [RULER: What’s the Real Context Size of Your Long-Context Language Models](https://arxiv.org/abs/2404.06654). RULER generates synthetic examples to evaluate long-context language models with configurable sequence length and task complexity. 
We benchmark 17 open-source models across 4 task categories (in total 13 tasks) in RULER, evaluating long-context capabilities beyond simple in-context recall. Here are our main results. 4 | 5 | |Models|Claimed Length|Effective Length|4K|8K|16K|32K|64K|128K|Avg.|wAvg. (inc)|wAvg. (dec)| 6 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 7 | |[Llama2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) (7B)|4K||85.6| 8 | [Jamba-1.5-large*](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Large) (94B/398B)|256k|>128k|96.7|96.6|96.4|96.0|95.4|95.1|96.0|95.7 **(1st)**|96.3 **(1st)**| 9 | [Gemini-1.5-pro](https://ai.google.dev/gemini-api/docs/models/gemini#:~:text=Gemini-,Gemini%201.5%20Pro%20(Preview%20only),-Text%20and%20images)|1M|>128K|96.7|95.8|96.0|95.9|95.9|94.4|95.8|95.5 **(2nd)**|96.1 **(2nd)**| 10 | [Jamba-1.5-mini](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Mini) (12B/52B)|256K|>128K|95.6|95.6|94.8|94.6|92.8|90.0|93.9|93.1 **(3rd)**|94.8 **(3rd)** 11 | [GPT-4-1106-preview](https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4#:~:text=gpt%2D4%2D1106%2Dpreview,Up%20to%20Apr%202023)|128K|64K|96.6|96.3|95.2|93.2|87.0|81.2|91.6|89.0 **(4th)**|94.1 **(4th)**| 12 | [Llama3.1](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) (70B)|128K|64K|96.5|95.8|95.4|94.8|88.4|66.6|89.6|85.5 **(10th)**|93.7 **(5th)**| 13 | [Mistral-Large-2411](https://huggingface.co/mistralai/Mistral-Large-Instruct-2411) (123B)|128K|64K|96.4|96.3|95.3|94.0|85.9|48.1|86.0|79.5 **(18th)**|92.5 **(6th)**| 14 | [Command-R-plus-0824](https://huggingface.co/CohereForAI/c4ai-command-r-plus-08-2024) (104B)|128K|32K|96.0|95.1|94.0|92.4|85.4|64.6|87.9|83.4 **(13th)**|92.4 **(7th)**| 15 | [Qwen2](https://huggingface.co/Qwen/Qwen2-72B-Instruct) (72B)|128K|32K|96.9|96.1|94.9|94.1|79.8|53.7|85.9|79.6 **(17th)**|92.3 **(8th)**| 16 | [Command-R-plus](https://huggingface.co/CohereForAI/c4ai-command-r-plus) (104B)|128K|32K|95.6|95.2|94.2|92.0|84.3|63.1|87.4|82.7 **(14th)**|92.1 **(9th)**| 17 | [Command-R-0824](https://huggingface.co/CohereForAI/c4ai-command-r-08-2024) (32B)|128K|64K|94.7|93.7|93.1|90.8|86.6|74.7|88.9|86.0 **(8th)**|91.9 **(10th)**| 18 | [GLM4](https://huggingface.co/THUDM/glm-4-9b-chat-1m) (9B)|1M|64K|94.7|92.8|92.1|89.9|86.7|83.1|89.9|88.0 **(5th)**|91.7 **(11th)**| 19 | [Llama3.1](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) (8B)|128K|32K|95.5|93.8|91.6|87.4|84.7|77.0|88.3|85.4 **(11th)**|91.3 **(12th)**| 20 | [ProLong](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct) (8B)|512K|32K|94.5|92.5|92.3|89.3|83.2|81.6|88.9|86.6 **(7th)**|91.2 **(13th)**| 21 | [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) (35B)|128K|32K|93.8|93.3|92.4|89.5|84.9|76.0|88.3|85.5 **(9th)**|91.1 **(14th)**| 22 | [MegaBeam-Mistral](https://huggingface.co/aws-prototyping/MegaBeam-Mistral-7B-512k) (7B)|512K|32K|93.8|92.5|92.0|89.2|83.7|83.7|89.1|87.3 **(6th)**|91.0 **(15th)**| 23 | [Mistral-Large-2407](https://huggingface.co/mistralai/Mistral-Large-Instruct-2407) (123B)|128K|32K|96.2|96.1|95.1|93.0|78.8|23.7|80.5|70.6 **(24th)**|90.4 **(16th)**| 24 | [GradientAI/Llama3](https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k) (70B)|1M|16K|95.1|94.4|90.8|85.4|80.9|72.1|86.5|82.6 **(15th)**|90.3 **(17th)**| 25 | [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-instruct-v0.1) (39B/141B)|64K|32K|95.6|94.9|93.4|90.9|84.7|31.7|81.9|73.5 **(22nd)**|90.3 **(18th)**| 26 | [Yi](https://huggingface.co/01-ai/Yi-34B-200K) 
(34B)|200K|32K|93.3|92.2|91.3|87.5|83.2|77.3|87.5|84.8 **(12th)**|90.1 **(19th)**| 27 | [Phi3-mini](https://huggingface.co/microsoft/Phi-3-mini-128K-instruct) (3.8B)|128K|32K|92.2|91.5|90.7|87.5|80.6|66.7|84.8|80.9 **(16th)**|88.7 **(20th)**| 28 | [Phi3-medium](https://huggingface.co/microsoft/Phi-3-medium-128K-instruct) (14B)|128K|32K|93.3|93.2|91.1|86.8|78.6|46.1|81.5|74.8 **(21st)**|88.3 **(21st)**| 29 | [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-instruct-v0.1) (12.9B/46.7B)|32K|32K|94.9|92.1|92.5|85.9|72.4|44.5|80.4|72.8 **(23rd)**|87.9 **(22nd)**| 30 | [GradientAI/Llama3](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k) (8B)|1M|16K|92.8|90.3|85.7|79.9|76.3|69.5|82.4|78.5 **(19th)**|86.3 **(23rd)**| 31 | [FILM-7B*](https://arxiv.org/pdf/2404.16811) (7B)|32K|32K|92.8|88.2|88.1|86.9|70.1|27.1|75.5|66.4 **(26th)**|84.7 **(24th)**| 32 | [InternLM2.5](https://huggingface.co/internlm/internlm2_5-7b-chat-1m) (7B)|1M|4K|88.1|85.5|84.5|82.7|75.5|68.9|80.9| 77.8 **(20th)**|83.9 **(25th)**| 33 | [Mistral](https://huggingface.co/mistralai/Mistral-7B-instruct-v0.2) (7B)|32K|16K|93.6|91.2|87.2|75.4|49.0|13.8|68.4|55.6 **(28th)**|81.2 **(26th)**| 34 | [Mistral-Nemo](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)|128K|16K|87.8|87.2|87.7|69.0|46.8|19.0|66.2|54.7 **(29th)**|77.8 **(27th)**| 35 | [GLM3](https://huggingface.co/THUDM/chatglm3-6b-128K) (6B)|128K|4K|87.8|83.4|78.6|69.9|56.0|42.0|69.6|62.0 **(27th)**|77.2 **(28th)**| 36 | [LWM](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M) (7B)|1M|<4K|82.3|78.4|73.7|69.1|68.1|65.0|72.8|69.9 **(25th)**|75.7 **(29th)**| 37 | [DBRX](https://huggingface.co/databricKs/dbrx-instruct) (36B/132B)|32K|8K|95.1|93.8|83.6|63.1|2.4|0.0|56.3|38.0 **(30th)**|74.7 **(30th)**| 38 | [Qwen1.5](https://huggingface.co/Qwen/Qwen1.5-72B-Chat) (72B)|32K|8K|94.9|93.8|78.0|67.8|0.0|0.0|55.7|37.5 **(31st)**|74.0 **(31st)**| 39 | [Together](https://huggingface.co/togethercomputer/Llama-2-7B-32K-instruct) (7B)|32K|4K|88.2|81.1|69.4|63.0|0.0|0.0|50.3|33.8 **(32nd)**|66.7 **(32nd)**| 40 | [LongChat](https://huggingface.co/lmsys/longchat-7b-v1.5-32K) (7B)|32K|<4K|84.7|79.9|70.8|59.3|0.0|0.0|49.1|33.1 **(33rd)**|65.2 **(33rd)**| 41 | [LongAlpaca](https://huggingface.co/YuKang/LongAlpaca-13B) (13B)| 32K|<4K|60.6|57.0|56.6|43.6|0.0|0.0|36.3|24.7 **(34th)**|47.9 **(34th)**| 42 | 43 | - Despite achieving nearly perfect performance on the vanilla needle-in-a-haystack (NIAH) test, most models exhibit large degradation on tasks in RULER as sequence length increases. 44 | - While all models claim context size of 32k tokens or greater, only half of them can effectively handle sequence length of 32K by exceeding a qualitative threshold, Llama-2-7b performance at 4K (85.6%). The performance exceeding the threshold is underlined. 45 | - Almost all models fall below the threshold before reaching the claimed context lengths. 46 | - Notes 47 | - Jamba-1.5-large results are reported by authors from this [report](https://arxiv.org/pdf/2408.12570). 48 | - FILM-7B results are reported by authors of this [paper](https://arxiv.org/pdf/2404.16811). They use [YaRN](https://arxiv.org/pdf/2309.00071) without further training for the evaluation length exceeding 32K (64K and 128K). They do not use the one-shot example for the CWE task. 49 | 50 | ## 💡 Requirements 51 | 52 | - Docker container: `docker pull cphsieh/ruler:0.2.0` 53 | - The requirements are listed in `docker/Dockerfile` and `docker/requirements.txt`. 
Use the following command to build the container based on NVIDIA's PyTorch container `nvcr.io/nvidia/pytorch:23.10-py3`. 54 | ``` 55 | cd docker/ 56 | DOCKER_BUILDKIT=1 docker build -f Dockerfile -t cphsieh/ruler:0.2.0 . 57 | ``` 58 | 59 | 60 | ## 🔍 Evaluate long-context LMs 61 | ### 1. Download data 62 | - Paul Graham Essays for NIAH are downloaded from [NIAH Github](https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main/needlehaystack/PaulGrahamEssays) and [Paul Graham Blog](https://paulgraham.com/articles.html). 63 | - QA datasets are downloaded from [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) and [HotpotQA](https://hotpotqa.github.io/). 64 | ``` 65 | cd scripts/data/synthetic/json/ 66 | python download_paulgraham_essay.py 67 | bash download_qa_dataset.sh 68 | ``` 69 | ### 2. Download model 70 | - We download the models from [Huggingface](https://huggingface.co/models). 71 | - The input template of each model is stored in `scripts/data/template.py`. Please add a new model template if your model uses a different chat template. 72 | - Increase `max_position_embeddings` in `config.json` if you want to run inference at lengths longer than the model's defined maximum. 73 | - (Optional) If you are using [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main), please build your model engine based on their example scripts (e.g., [Llama](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama)) with their [Docker container](https://github.com/NVIDIA/TensorRT-LLM/tree/main?tab=readme-ov-file#installation). 74 | 75 | ### 3. Run evaluation pipeline 76 | 77 | - **Setup `run.sh`** 78 | ``` 79 | GPUS="" # number of GPUs 80 | ROOT_DIR="" # the path that stores generated task samples and model predictions. 81 | MODEL_DIR="" # the path that contains individual model folders from Huggingface. 82 | ENGINE_DIR="" # the path that contains individual engine folders from TensorRT-LLM. 83 | ``` 84 | - **Setup `config_models.sh`** 85 | ``` 86 | case $MODEL_NAME in 87 | YOUR_HF_MODEL_NAME) 88 | MODEL_PATH=${MODEL_DIR}/YOUR_MODEL_FOLDER 89 | MODEL_TEMPLATE_TYPE="" # base, meta-chat, etc. defined in `scripts/data/template.py` 90 | MODEL_FRAMEWORK="" # hf or vllm 91 | ;; 92 | YOUR_TRTLLM_ENGINE_NAME) 93 | MODEL_PATH=${ENGINE_DIR}/YOUR_ENGINE_FOLDER 94 | MODEL_TEMPLATE_TYPE="" # base, meta-chat, etc. defined in `scripts/data/template.py` 95 | MODEL_FRAMEWORK="trtllm" 96 | ;; 97 | YOUR_OPENAI_MODEL_NAME) 98 | MODEL_PATH="" # OpenAI model name listed in https://platform.openai.com/docs/models/ 99 | MODEL_TEMPLATE_TYPE="base" 100 | MODEL_FRAMEWORK="openai" 101 | TOKENIZER_PATH="cl100k_base" 102 | TOKENIZER_TYPE="openai" 103 | OPENAI_API_KEY="" # your OpenAI API key 104 | ;; 105 | YOUR_GEMINI_MODEL_NAME) 106 | MODEL_PATH="" # Gemini model name listed in https://ai.google.dev/gemini-api/docs/models/gemini 107 | MODEL_TEMPLATE_TYPE="base" 108 | MODEL_FRAMEWORK="gemini" 109 | TOKENIZER_PATH=$MODEL_PATH 110 | TOKENIZER_TYPE="gemini" 111 | GEMINI_API_KEY="" # your Gemini API key 112 | ;; 113 | ``` 114 | 115 | - **Start evaluation based on our default `synthetic` benchmark** 116 | ``` 117 | bash run.sh YOUR_MODEL_NAME synthetic 118 | ``` 119 | 120 | ## 🧠 (Optional) Customize task complexity 121 | The tasks to be evaluated on are stored in `scripts/config_tasks.sh`. Configuration of each task is defined in `scripts/synthetic.yaml`. The complexity of each task can be configured by changing the arguments, which we describe in detail below.
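For orientation before the argument reference below, here is a sketch of what a task entry in `scripts/synthetic.yaml` looks like as `scripts/data/prepare.py` consumes it: a top-level key per task name (e.g., `niah_single_1` from `scripts/config_tasks.sh`), a `task` field naming the base task script, and an `args` block holding the complexity arguments. The argument values shown are illustrative placeholders rather than the shipped defaults; see the table below for what each argument means.
```
niah_single_1:
  task: niah            # base task script in scripts/data/synthetic/
  args:
    type_haystack: repeat
    type_needle_k: words
    type_needle_v: numbers
    num_needle_k: 1
    num_needle_v: 1
    num_needle_q: 1
```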
122 | 123 | | Category |Task name | Configurations |
124 | |:--------------------:|:---------------------------:|--------------------|
125 | | Retrieval | niah |**type_haystack**: `repeat/essay/needle`<br># repeat: repeated noise sentences<br># essay: Paul Graham Essays<br># needle: distractor needles<br><br>**type_needle_k**: `words/numbers/uuids`<br>**type_needle_v**: `words/numbers/uuids`<br># words: adjective-noun<br># numbers: 7 digits<br># uuids: 32 digits<br><br>**num_needle_k**: `int >= 1`<br># add multiple needles in haystack<br>**num_needle_v**: `int >= 1`<br># retrieve multiple values from a single key<br>**num_needle_q**: `int >= 1`<br># retrieve multiple values from multiple keys |
126 | | Multi-hop<br>Tracing | variable_tracking | **num_chains**: `int >= 1`<br># number of variable name-binding chains<br>**num_hops**: `int >= 1`<br># number of name-binding hops in each chain |
127 | | Aggregation | common_words_extraction |**freq_cw**: `int >= 1`<br># frequency of common words<br>**freq_ucw**: `int >= 1`<br># frequency of uncommon words<br>**num_cw**: `int >= 1`<br># number of common words |
128 | | Aggregation | freq_words_extraction |**alpha**: `float > 1.0`<br># parameter of the distribution used to draw synthetic words. Reduce alpha to increase the difficulty of this task. Note that increasing the number of words to return also increases the difficulty; we use `3` in our evaluations because models show worse performance at short context sizes when more words need to be returned. |
129 | | Question<br>Answering | qa |**dataset**: `squad` or `hotpotqa`<br># the short-context qa dataset we use |
130 | 131 | 132 | 133 | ## 🚀 (Optional) Contribute a new synthetic task 134 | ### 1. Create a python script for data preparation 135 | * Add basic arguments (required) and complexity configurations in the python script. 136 | * Verify the script is reproducible given a tokenizer, a sequence length, and a random seed. 137 | * Save the script under the folder `scripts/data/synthetic`. 138 | 139 | ### 2. Add task template 140 | * Add `template` and `tokens_to_generate` in `scripts/data/synthetic/constants.py`. 141 | * Add `answer_prefix` to prevent the model from refusing to answer. 142 | 143 | ### 3. Add evaluation metric 144 | * Add the automatic metric to evaluate your task in `scripts/eval/synthetic/constants.py` 145 | 146 | ### 4. Add required configurations 147 | * Define your task name and complexity configurations in `scripts/synthetic.yaml`. 148 | * Add your task name in `scripts/config_tasks.sh` 149 | 150 | ## 🛠️ Limitations 151 | While tasks in RULER are designed to be configurable, we only evaluate the above models with 13 task configurations. These tasks were selected because most models can achieve good (some almost perfect) performance at short context sizes (<= 4K), which leaves ample room to observe degradation as we extend the input length. We did not include more complex tasks in RULER on which models show worse performance at short context sizes. We also did not stress test every model with more difficult task configurations. Although RULER covers four task categories, extending previous evaluation protocols, and provides a clean test bed for sanity-checking LMs with a known upper-bound performance, it is by no means comprehensive, and it cannot replace realistic tasks, which remain preferable. We welcome people to contribute new tasks and/or new task categories to help evaluate long-context capabilities. 152 | 153 | 154 | ## 📝 Citation 155 | ``` 156 | @article{hsieh2024ruler, 157 | title={RULER: What's the Real Context Size of Your Long-Context Language Models?}, 158 | author={Cheng-Ping Hsieh and Simeng Sun and Samuel Kriman and Shantanu Acharya and Dima Rekesh and Fei Jia and Yang Zhang and Boris Ginsburg}, 159 | year={2024}, 160 | journal={arXiv preprint arXiv:2404.06654}, 161 | } 162 | ``` 163 | Disclaimer: This project is strictly for research purposes, and not an official product from NVIDIA. 164 |
-------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | FROM nvcr.io/nvidia/pytorch:23.10-py3 16 | 17 | WORKDIR /workspace/ 18 | 19 | COPY ./requirements.txt .
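# Note on the following RUN layer: it installs the benchmark's Python requirements
# plus version-pinned kernels: flash-attn and its rotary-embedding extension for
# attention, and causal-conv1d/mamba-ssm, presumably needed by the Mamba/Jamba-style
# hybrid models on the leaderboard (the "why" here is an inference, not stated upstream).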
20 | RUN pip install --upgrade pip \ 21 | && pip install -r requirements.txt \ 22 | && pip install flash-attn==2.6.0.post1 --no-build-isolation \ 23 | && pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary \ 24 | && pip install causal-conv1d==1.4.0 \ 25 | && pip install mamba-ssm==2.2.2 26 | 27 | RUN [ "python3", "-c", "import nltk; nltk.download('punkt')"] 28 | 29 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | nemo-toolkit[all] 2 | tritonclient[all] 3 | transformer_engine[pytorch] 4 | flask 5 | flask_restful 6 | html2text 7 | google-generativeai 8 | sshtunnel_requests 9 | wonderwords 10 | openai 11 | tiktoken 12 | tenacity 13 | accelerate 14 | huggingface_hub==0.23.4 15 | transformers==4.44.2 16 | vllm==0.5.4 -------------------------------------------------------------------------------- /scripts/config_models.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | TEMPERATURE="0.0" # greedy 16 | TOP_P="1.0" 17 | TOP_K="32" 18 | SEQ_LENGTHS=( 19 | 131072 20 | 65536 21 | 32768 22 | 16384 23 | 8192 24 | 4096 25 | ) 26 | 27 | MODEL_SELECT() { 28 | MODEL_NAME=$1 29 | MODEL_DIR=$2 30 | ENGINE_DIR=$3 31 | 32 | case $MODEL_NAME in 33 | llama2-7b-chat) 34 | MODEL_PATH="${MODEL_DIR}/llama2-7b-chat-hf" 35 | MODEL_TEMPLATE_TYPE="meta-chat" 36 | MODEL_FRAMEWORK="vllm" 37 | ;; 38 | llama3.1-8b-chat) 39 | MODEL_PATH="${MODEL_DIR}/llama3.1-8b-Instruct" 40 | MODEL_TEMPLATE_TYPE="meta-llama3" 41 | MODEL_FRAMEWORK="vllm" 42 | ;; 43 | jamba1.5-mini) 44 | MODEL_PATH="${MODEL_DIR}/Jamba-1.5-Mini" 45 | MODEL_TEMPLATE_TYPE="jamba" 46 | MODEL_FRAMEWORK="vllm" 47 | ;; 48 | gpt-3.5-turbo) 49 | MODEL_PATH="gpt-3.5-turbo-0125" 50 | MODEL_TEMPLATE_TYPE="base" 51 | MODEL_FRAMEWORK="openai" 52 | TOKENIZER_PATH="cl100k_base" 53 | TOKENIZER_TYPE="openai" 54 | OPENAI_API_KEY="" 55 | AZURE_ID="" 56 | AZURE_SECRET="" 57 | AZURE_ENDPOINT="" 58 | ;; 59 | gpt-4-turbo) 60 | MODEL_PATH="gpt-4" 61 | MODEL_TEMPLATE_TYPE="base" 62 | MODEL_FRAMEWORK="openai" 63 | TOKENIZER_PATH="cl100k_base" 64 | TOKENIZER_TYPE="openai" 65 | OPENAI_API_KEY="" 66 | AZURE_ID="" 67 | AZURE_SECRET="" 68 | AZURE_ENDPOINT="" 69 | ;; 70 | gemini_1.0_pro) 71 | MODEL_PATH="gemini-1.0-pro-latest" 72 | MODEL_TEMPLATE_TYPE="base" 73 | MODEL_FRAMEWORK="gemini" 74 | TOKENIZER_PATH=$MODEL_PATH 75 | TOKENIZER_TYPE="gemini" 76 | GEMINI_API_KEY="" 77 | ;; 78 | gemini_1.5_pro) 79 | MODEL_PATH="gemini-1.5-pro-latest" 80 | MODEL_TEMPLATE_TYPE="base" 81 | MODEL_FRAMEWORK="gemini" 82 | TOKENIZER_PATH=$MODEL_PATH 83 | TOKENIZER_TYPE="gemini" 84 | GEMINI_API_KEY="" 85 | ;; 86 | esac 87 | 88 | 89 | if [ -z "${TOKENIZER_PATH}" ]; then 90 | if [ -f ${MODEL_PATH}/tokenizer.model ]; then 91 | 
TOKENIZER_PATH=${MODEL_PATH}/tokenizer.model 92 | TOKENIZER_TYPE="nemo" 93 | else 94 | TOKENIZER_PATH=${MODEL_PATH} 95 | TOKENIZER_TYPE="hf" 96 | fi 97 | fi 98 | 99 | 100 | echo "$MODEL_PATH:$MODEL_TEMPLATE_TYPE:$MODEL_FRAMEWORK:$TOKENIZER_PATH:$TOKENIZER_TYPE:$OPENAI_API_KEY:$GEMINI_API_KEY:$AZURE_ID:$AZURE_SECRET:$AZURE_ENDPOINT" 101 | } 102 | -------------------------------------------------------------------------------- /scripts/config_tasks.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | NUM_SAMPLES=500 16 | REMOVE_NEWLINE_TAB=false 17 | STOP_WORDS="" 18 | 19 | if [ -z "${STOP_WORDS}" ]; then 20 | STOP_WORDS="" 21 | else 22 | STOP_WORDS="--stop_words \"${STOP_WORDS}\"" 23 | fi 24 | 25 | if [ "${REMOVE_NEWLINE_TAB}" = false ]; then 26 | REMOVE_NEWLINE_TAB="" 27 | else 28 | REMOVE_NEWLINE_TAB="--remove_newline_tab" 29 | fi 30 | 31 | # task name in `synthetic.yaml` 32 | synthetic=( 33 | "niah_single_1" 34 | "niah_single_2" 35 | "niah_single_3" 36 | "niah_multikey_1" 37 | "niah_multikey_2" 38 | "niah_multikey_3" 39 | "niah_multivalue" 40 | "niah_multiquery" 41 | "vt" 42 | "cwe" 43 | "fwe" 44 | "qa_1" 45 | "qa_2" 46 | ) 47 | -------------------------------------------------------------------------------- /scripts/data/prepare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Prepare jsonl with field `input` and `outputs`. 
17 | { 18 | "index" int, 19 | "input": str, 20 | "outputs": [str], 21 | } 22 | 23 | python prepare.py \ 24 | --save_dir ./ \ 25 | --benchmark synthetic \ 26 | --task niah_single_1 \ 27 | --tokenizer_path tokenizer.model \ 28 | --tokenizer_type nemo \ 29 | --max_seq_length 4096 \ 30 | --model_template_type base \ 31 | --num_samples 10 \ 32 | """ 33 | import os 34 | import argparse 35 | import importlib 36 | import subprocess 37 | import time 38 | import yaml 39 | from pathlib import Path 40 | from template import Templates 41 | import nltk 42 | try: 43 | nltk.data.find('tokenizers/punkt') 44 | nltk.data.find('tokenizers/punkt_tab') 45 | except LookupError: 46 | nltk.download('punkt') 47 | nltk.download('punkt_tab') 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 51 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]') 52 | parser.add_argument("--task", type=str, required=True, help='tasks in benchmark') 53 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 54 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 55 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 56 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 57 | parser.add_argument("--num_samples", type=int, default=500, help='maximum number of samples we want to test') 58 | parser.add_argument("--random_seed", type=int, default=42) 59 | parser.add_argument("--model_template_type", type=str, default='base', help='Options in `template.py`') 60 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 61 | parser.add_argument("--chunk_idx", type=int, default=0, help='index of current split chunk') 62 | parser.add_argument("--chunk_amount", type=int, default=1, help='size of split chunk') 63 | parser.add_argument("--prepare_for_ns", action='store_true') 64 | 65 | args = parser.parse_args() 66 | 67 | def main(): 68 | start_time = time.time() 69 | curr_folder = os.path.dirname(os.path.abspath(__file__)) 70 | 71 | try: 72 | module = importlib.import_module(f"{args.benchmark}.constants") 73 | except ImportError: 74 | print(f"Module data.{args.benchmark}.constants not found.") 75 | 76 | tasks_base = module.TASKS 77 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f: 78 | tasks_customized = yaml.safe_load(f) 79 | 80 | if args.task not in tasks_customized: 81 | raise ValueError(f'{args.task} is not found in config_tasks.yaml') 82 | 83 | config = tasks_customized.get(args.task) 84 | config.update(tasks_base[config['task']]) 85 | 86 | # Add templates 87 | assert args.model_template_type in Templates, print(f'{args.model_template_type} is not found in {Templates.keys()}') 88 | model_template = Templates[args.model_template_type] 89 | 90 | if args.prepare_for_ns: 91 | from tokenizer import select_tokenizer 92 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 93 | model_template_token = len(TOKENIZER.text_to_tokens(model_template)) 94 | model_template = Templates['base'] 95 | 96 | task_template = config['template'] 97 | 98 | # Add answer prefix for all models 99 | answer_prefix = config['answer_prefix'] if 'answer_prefix' in config else '' 100 | 101 | 
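# Build the final prompt template: substitute the task template (which carries the
# {context} and {query} placeholders) into the model's chat template from
# scripts/data/template.py, then append the task's answer prefix. For example,
# assuming a hypothetical chat template of the form "[INST] {task_template} [/INST]"
# (illustrative; actual templates are defined in template.py), the result would be
# "[INST] <task prompt with {context}/{query}> [/INST]" + answer_prefix.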
config['template'] = model_template.format(task_template=task_template) + answer_prefix 102 | 103 | # Split task into multiple chunks 104 | chunks = [(args.num_samples // args.chunk_amount) + (1 if i < args.num_samples % args.chunk_amount else 0) for i in range(args.chunk_amount)] 105 | num_samples = chunks[args.chunk_idx] 106 | pre_samples = sum(chunks[:args.chunk_idx]) 107 | 108 | random_seed = args.random_seed + args.chunk_idx 109 | 110 | 111 | save_file = args.save_dir / args.task / f"{args.subset}.jsonl" 112 | file_exists = False 113 | if os.path.exists(save_file): 114 | with open(save_file, "r") as f: 115 | data = f.readlines() 116 | if len(data) == args.num_samples: file_exists = True 117 | 118 | 119 | 120 | if not file_exists: 121 | try: 122 | script = os.path.join(curr_folder, args.benchmark, f"{config['task']}.py") 123 | additional_args = " ".join([f"--{k} {v}" for k, v in config['args'].items()]) 124 | command = f"""python {script} \ 125 | --save_dir {args.save_dir} \ 126 | --save_name {args.task} \ 127 | --subset {args.subset} \ 128 | --tokenizer_path {args.tokenizer_path} \ 129 | --tokenizer_type {args.tokenizer_type} \ 130 | --max_seq_length {args.max_seq_length} \ 131 | --tokens_to_generate {config['tokens_to_generate']} \ 132 | --num_samples {num_samples} \ 133 | --random_seed {random_seed} \ 134 | {additional_args} \ 135 | {f"--remove_newline_tab" if args.remove_newline_tab else ""} \ 136 | {f"--pre_samples {pre_samples}" if config['task'] == 'qa' else ""} \ 137 | --template "{config['template']}" \ 138 | """ 139 | if args.prepare_for_ns: 140 | command += f""" --model_template_token {model_template_token}""" 141 | 142 | print(command) 143 | result = subprocess.run(command, 144 | shell=True, 145 | check=True, 146 | stdout=subprocess.PIPE, 147 | stderr=subprocess.PIPE, 148 | text=True) 149 | 150 | if result.returncode == 0: 151 | print("Output:") 152 | print(result.stdout) 153 | else: 154 | print("Error:") 155 | print(result.stderr) 156 | except subprocess.CalledProcessError as e: 157 | print("Error output:", e.stderr) 158 | 159 | print(f"Prepare {args.task} with lines: {args.num_samples} to {save_file}") 160 | print(f"Used time: {round((time.time() - start_time) / 60, 1)} minutes") 161 | else: 162 | print(f"Skip preparing {args.task} with lines: {args.num_samples} to {save_file} (file exists)") 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /scripts/data/synthetic/common_words_extraction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for common words extraction. 
17 | 18 | python common_words_extraction.py \ 19 | --save_dir=./ \ 20 | --save_name=vt \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type nemo \ 23 | --max_seq_length 4096 \ 24 | --tokens_to_generate 30 \ 25 | --num_samples 10 \ 26 | --random_seed 42 \ 27 | -freq_cw 30 --freq_ucw 3 --num_cw 10 \ 28 | --template "[INST] Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list? [/INST] Answer: The top 10 words that appear most often in the list are:" 29 | """ 30 | 31 | import os 32 | import argparse 33 | from pathlib import Path 34 | from tqdm import tqdm 35 | import random 36 | import wonderwords 37 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest 38 | import sys 39 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 40 | from tokenizer import select_tokenizer 41 | import json 42 | import logging 43 | from constants import TASKS 44 | logging.basicConfig(level=logging.INFO, force=True) 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 50 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 51 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 52 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 53 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 54 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 55 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.') 56 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 57 | parser.add_argument("--random_seed", type=int, default=42) 58 | parser.add_argument("--template", type=str, default='', help='prompt template') 59 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 60 | 61 | parser.add_argument("--freq_cw", type=int, default=30) 62 | parser.add_argument("--freq_ucw", type=int, default=3) 63 | parser.add_argument("--num_cw", type=int, default=10) 64 | parser.add_argument("--num_fewshot", type=int, default=1) 65 | parser.add_argument("--model_template_token", type=int, default=0, help='used for nemo skills, minus num of model template token') 66 | args = parser.parse_args() 67 | random.seed(args.random_seed) 68 | 69 | # Load Tokenizer 70 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 71 | 72 | nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") 73 | adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") 74 | verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt") 75 | words = nouns + adjs + verbs 76 | words = sorted(list(set(words))) 77 | random.Random(args.random_seed).shuffle(words) 78 | logger.info(f'loaded {len(words)} wonderwords') 79 | 80 | # Randleword english words 81 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/english_words.json") , "r") as f: 82 | randle_words = 
list(json.load(f).values()) 83 | logger.info(f'loaded {len(randle_words)} randle words') 84 | 85 | def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10): 86 | if num_words <= len(words): 87 | word_list_full = random.sample(words, num_words) 88 | else: 89 | word_list_full = random.sample(randle_words, num_words) 90 | 91 | common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:] 92 | word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats) 93 | random.Random(args.random_seed).shuffle(word_list) 94 | 95 | # Formatting the word list as "1. word1 2. word2 3. word3 ..." 96 | context = ' '.join([f"{i + 1}. {word}" for i, word in enumerate(word_list)]) 97 | 98 | return context, common 99 | 100 | def generate_input_output(num_words): 101 | few_shots = [] 102 | if args.max_seq_length < 4096: 103 | for _ in range(args.num_fewshot): 104 | context_example, answer_example = get_example(20, 3, 1, args.num_cw) 105 | few_shots.append((context_example, answer_example)) 106 | context, answer = get_example(num_words, 6, 1, args.num_cw) 107 | else: 108 | for _ in range(args.num_fewshot): 109 | context_example, answer_example = get_example(40, 10, 3, args.num_cw) 110 | few_shots.append((context_example, answer_example)) 111 | context, answer = get_example(num_words, args.freq_cw, args.freq_ucw, args.num_cw) 112 | 113 | template = args.template 114 | 115 | for n in range(len(few_shots)): 116 | few_shots[n] = template.format( 117 | num_cw=args.num_cw, 118 | context=few_shots[n][0], 119 | query='', 120 | ) + ' ' + ' '.join([f"{i + 1}. {word}" for i, word in enumerate(few_shots[n][1])]) 121 | 122 | few_shots = "\n".join(few_shots) 123 | input_text = template.format( 124 | num_cw=args.num_cw, 125 | context=context, 126 | query='', 127 | ) 128 | 129 | return few_shots + "\n" + input_text, answer 130 | 131 | def sys_word_pair_random(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10): 132 | write_jsons = [] 133 | tokens_to_generate = args.tokens_to_generate 134 | max_seq_length -= args.model_template_token 135 | 136 | 137 | # Estimate tokens per question to determine reasonable upper bound 138 | sample_input_text, _ = generate_input_output(incremental) 139 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text)) 140 | tokens_per_words = sample_tokens / incremental 141 | 142 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling. 143 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable. 144 | estimated_max_words = int((max_seq_length / tokens_per_words) * 3) 145 | 146 | # Binary search for optimal haystack size 147 | lower_bound = incremental 148 | upper_bound = max(estimated_max_words, incremental * 2) # Ensure upper_bound is reasonable 149 | 150 | optimal_num_words = None 151 | 152 | logger.info(f"Estimated {tokens_per_words:.1f} tokens per haystack") 153 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}") 154 | while lower_bound <= upper_bound: 155 | mid = (lower_bound + upper_bound) // 2 156 | input_text, answer = generate_input_output(mid) 157 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 158 | 159 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}") 160 | 161 | if total_tokens <= max_seq_length: 162 | # This size works, can we go larger? 
163 | optimal_num_words = mid 164 | lower_bound = mid + 1 165 | else: 166 | # Too large, need to go smaller 167 | upper_bound = mid - 1 168 | 169 | num_words = optimal_num_words if optimal_num_words is not None else incremental 170 | logger.info(f'Final optimal haystack size (number of haystack): {num_words}') 171 | 172 | 173 | # Generate samples 174 | for index in tqdm(range(num_samples)): 175 | used_words = num_words 176 | while(True): 177 | try: 178 | input_text, answer = generate_input_output(used_words) 179 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 180 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 181 | break 182 | except: 183 | if used_words > incremental: 184 | used_words -= incremental 185 | 186 | if args.remove_newline_tab: 187 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 188 | 189 | answer_prefix_index = input_text.rfind(TASKS['common_words_extraction']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it 190 | answer_prefix = input_text[answer_prefix_index:] 191 | input_text = input_text[:answer_prefix_index] 192 | 193 | formatted_output = { 194 | 'index': index, 195 | "input": input_text, 196 | "outputs": answer, 197 | "length": length, 198 | 'length_w_model_temp': length + args.model_template_token, 199 | 'answer_prefix': answer_prefix, 200 | } 201 | write_jsons.append(formatted_output) 202 | 203 | return write_jsons 204 | 205 | 206 | def main(): 207 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 208 | save_file.parent.mkdir(parents=True, exist_ok=True) 209 | 210 | write_jsons = sys_word_pair_random(num_samples=args.num_samples, max_seq_length=args.max_seq_length, save_dir=args.save_dir) 211 | 212 | write_manifest(save_file, write_jsons) 213 | 214 | if __name__=="__main__": 215 | main() -------------------------------------------------------------------------------- /scripts/data/synthetic/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Add a new task (required arguments): 17 | 18 | TASK_NAME: { 19 | 'tokens_to_generate': how many tokens we want to generate. 20 | 'template': the template with at least {context} and {query}. 21 | } 22 | """ 23 | 24 | TASKS = { 25 | 'niah': { 26 | 'tokens_to_generate': 128, 27 | 'template': """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. 
I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?""", 28 | 'answer_prefix': """ The special magic {type_needle_v} for {query} mentioned in the provided text are""" 29 | }, 30 | 31 | 'variable_tracking': { 32 | 'tokens_to_generate': 30, 33 | 'template': """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""", 34 | 'answer_prefix': """ Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assigned the value {query}, they are: """ 35 | }, 36 | 37 | 'common_words_extraction': { 38 | 'tokens_to_generate': 120, 39 | 'template': """Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?""", 40 | 'answer_prefix': """ Answer: The top 10 words that appear most often in the list are:""" 41 | }, 42 | 43 | 'freq_words_extraction' : { 44 | 'tokens_to_generate': 50, 45 | 'template': """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""", 46 | 'answer_prefix': """ Answer: According to the coded text above, the three most frequently appeared words are:""" 47 | }, 48 | 49 | 'qa': { 50 | 'tokens_to_generate': 32, 51 | 'template': """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""", 52 | 'answer_prefix': """ Answer:""", 53 | }, 54 | } -------------------------------------------------------------------------------- /scripts/data/synthetic/freq_words_extraction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for frequent words extraction. 17 | 18 | python freq_words_extraction.py \ 19 | --save_dir=./ \ 20 | --save_name=vt \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type nemo \ 23 | --max_seq_length 4096 \ 24 | --tokens_to_generate 30 \ 25 | --num_samples 10 \ 26 | --random_seed 42 \ 27 | --alpha 2.0 \ 28 | --template "[INST] Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. 
What are the three most frequently appeared words in the above coded text? [/INST] Answer: According to the coded text above, the three most frequently appeared words are:" 29 | """ 30 | 31 | import os 32 | import argparse 33 | from pathlib import Path 34 | from tqdm import tqdm 35 | import random 36 | import string 37 | import numpy as np 38 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest 39 | import sys 40 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 41 | from tokenizer import select_tokenizer 42 | from scipy.special import zeta 43 | import logging 44 | 45 | logging.basicConfig(level=logging.INFO) 46 | logger = logging.getLogger(__name__) 47 | from constants import TASKS 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 51 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 52 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 53 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 54 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 55 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 56 | parser.add_argument("--tokens_to_generate", type=int, default=50, help='number of tokens to generate') 57 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 58 | parser.add_argument("--random_seed", type=int, default=42) 59 | parser.add_argument("--template", type=str, default='', help='prompt template') 60 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 61 | parser.add_argument("--coded_wordlen", type=int, default=6, help="length of synthetic word") 62 | parser.add_argument("--vocab_size", type=int, default=-1, help='synthetic vocab size to sample from') 63 | parser.add_argument("--alpha", type=float, default=2.0, help='zeta distribution alpha') 64 | parser.add_argument("--add_fewshot", action="store_true", default=False) 65 | parser.add_argument("--model_template_token", type=int, default=0, help='used for nemo skills, minus num of model template token') 66 | 67 | args = parser.parse_args() 68 | random.seed(args.random_seed) 69 | np.random.seed(args.random_seed) 70 | 71 | # Load Tokenizer 72 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 73 | 74 | def generate_input_output(max_len, num_words=-1, coded_wordlen=6, vocab_size=2000, incremental=10, alpha=2.0): 75 | # generate vocab 76 | vocab = [''.join(random.choices(string.ascii_lowercase, k=coded_wordlen)) for _ in range(vocab_size)] 77 | while len(set(vocab)) < vocab_size: 78 | vocab.append(''.join(random.choices(string.ascii_lowercase, k=coded_wordlen))) 79 | vocab = sorted(list(set(vocab))) 80 | random.Random(args.random_seed).shuffle(vocab) 81 | vocab[0] = '...' 
# treat the top ranked as noise 82 | 83 | # sample words 84 | template = args.template 85 | def gen_text(num_words): 86 | k = np.arange(1, len(vocab)+1) 87 | sampled_cnt = num_words*(k**-alpha)/zeta(alpha) 88 | sampled_words = [[w] * zi for w, zi in zip(vocab, sampled_cnt.astype(int))] 89 | sampled_words = [x for wlst in sampled_words for x in wlst] 90 | random.Random(args.random_seed).shuffle(sampled_words) 91 | return template.format(context=' '.join(sampled_words), query=''), vocab[1:4] 92 | 93 | if num_words > 0: 94 | num_words = num_words 95 | text, answer = gen_text(num_words) 96 | while len(TOKENIZER.text_to_tokens(text)) > max_len: 97 | num_words -= incremental 98 | text, answer = gen_text(num_words) 99 | else: 100 | num_words = max_len // coded_wordlen # init 101 | text, answer = gen_text(num_words) 102 | while len(TOKENIZER.text_to_tokens(text)) < max_len: 103 | num_words += incremental 104 | text, answer = gen_text(num_words) 105 | num_words -= incremental 106 | text, answer = gen_text(num_words) 107 | return text, answer, num_words 108 | 109 | def sys_kwext(num_samples: int, max_seq_length: int, incremental: int = 10): 110 | write_jsons = [] 111 | tokens_to_generate = args.tokens_to_generate 112 | 113 | max_seq_length -= args.model_template_token 114 | vocab_size = max_seq_length // 50 if args.vocab_size == -1 else args.vocab_size 115 | 116 | # get number of words 117 | input_max_len = max_seq_length 118 | _, _, num_example_words = generate_input_output(input_max_len, 119 | coded_wordlen=args.coded_wordlen, 120 | vocab_size=vocab_size, 121 | incremental=input_max_len//32, 122 | alpha=args.alpha) 123 | logger.info('num_example_words:', num_example_words) 124 | # Generate samples 125 | for index in tqdm(range(num_samples)): 126 | 127 | # construct input 128 | input_max_len = max_seq_length 129 | input_text, answer, _ = generate_input_output(input_max_len, 130 | num_words=num_example_words, 131 | coded_wordlen=args.coded_wordlen, 132 | vocab_size=vocab_size, 133 | incremental=input_max_len//32, 134 | alpha=args.alpha) 135 | 136 | 137 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 138 | 139 | if args.remove_newline_tab: 140 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 141 | 142 | answer_prefix_index = input_text.rfind(TASKS['freq_words_extraction']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it 143 | answer_prefix = input_text[answer_prefix_index:] 144 | input_text = input_text[:answer_prefix_index] 145 | formatted_output = { 146 | 'index': index, 147 | "input": input_text, 148 | "outputs": answer, 149 | "length": length, 150 | 'length_w_model_temp': length + args.model_template_token, 151 | 'answer_prefix': answer_prefix, 152 | } 153 | write_jsons.append(formatted_output) 154 | 155 | return write_jsons 156 | 157 | 158 | def main(): 159 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 160 | save_file.parent.mkdir(parents=True, exist_ok=True) 161 | write_jsons = sys_kwext(num_samples=args.num_samples, max_seq_length=args.max_seq_length, 162 | incremental=10) 163 | 164 | write_manifest(save_file, write_jsons) 165 | 166 | if __name__=="__main__": 167 | main() -------------------------------------------------------------------------------- /scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt: -------------------------------------------------------------------------------- 1 | http://www.paulgraham.com/13sentences.html 2 | http://www.paulgraham.com/5founders.html 
3 | http://www.paulgraham.com/6631327.html 4 | http://www.paulgraham.com/95.html 5 | http://www.paulgraham.com/ace.html 6 | http://www.paulgraham.com/airbnb.html 7 | http://www.paulgraham.com/airbnbs.html 8 | http://www.paulgraham.com/alien.html 9 | http://www.paulgraham.com/altair.html 10 | http://www.paulgraham.com/ambitious.html 11 | http://www.paulgraham.com/america.html 12 | http://www.paulgraham.com/angelinvesting.html 13 | http://www.paulgraham.com/artistsship.html 14 | http://www.paulgraham.com/badeconomy.html 15 | http://www.paulgraham.com/better.html 16 | http://www.paulgraham.com/bronze.html 17 | http://www.paulgraham.com/bubble.html 18 | http://www.paulgraham.com/charisma.html 19 | http://www.paulgraham.com/cities.html 20 | http://www.paulgraham.com/college.html 21 | http://www.paulgraham.com/colleges.html 22 | http://www.paulgraham.com/conformism.html 23 | http://www.paulgraham.com/control.html 24 | http://www.paulgraham.com/convergence.html 25 | http://www.paulgraham.com/convince.html 26 | http://www.paulgraham.com/cred.html 27 | http://www.paulgraham.com/credentials.html 28 | http://www.paulgraham.com/determination.html 29 | http://www.paulgraham.com/die.html 30 | http://www.paulgraham.com/disagree.html 31 | http://www.paulgraham.com/disc.html 32 | http://www.paulgraham.com/discover.html 33 | http://www.paulgraham.com/distraction.html 34 | http://www.paulgraham.com/divergence.html 35 | http://www.paulgraham.com/donate.html 36 | http://www.paulgraham.com/ds.html 37 | http://www.paulgraham.com/early.html 38 | http://www.paulgraham.com/earnest.html 39 | http://www.paulgraham.com/equity.html 40 | http://www.paulgraham.com/essay.html 41 | http://www.paulgraham.com/ffb.html 42 | http://www.paulgraham.com/fh.html 43 | http://www.paulgraham.com/fix.html 44 | http://www.paulgraham.com/fn.html 45 | http://www.paulgraham.com/foundersatwork.html 46 | http://www.paulgraham.com/fp.html 47 | http://www.paulgraham.com/fr.html 48 | http://www.paulgraham.com/fundraising.html 49 | http://www.paulgraham.com/future.html 50 | http://www.paulgraham.com/genius.html 51 | http://www.paulgraham.com/getideas.html 52 | http://www.paulgraham.com/good.html 53 | http://www.paulgraham.com/goodart.html 54 | http://www.paulgraham.com/googles.html 55 | http://www.paulgraham.com/greatwork.html 56 | http://www.paulgraham.com/growth.html 57 | http://www.paulgraham.com/guidetoinvestors.html 58 | http://www.paulgraham.com/hackernews.html 59 | http://www.paulgraham.com/head.html 60 | http://www.paulgraham.com/herd.html 61 | http://www.paulgraham.com/heresy.html 62 | http://www.paulgraham.com/heroes.html 63 | http://www.paulgraham.com/highres.html 64 | http://www.paulgraham.com/hiresfund.html 65 | http://www.paulgraham.com/hiring.html 66 | http://www.paulgraham.com/hp.html 67 | http://www.paulgraham.com/hs.html 68 | http://www.paulgraham.com/hundred.html 69 | http://www.paulgraham.com/hw.html 70 | http://www.paulgraham.com/hwh.html 71 | http://www.paulgraham.com/icad.html 72 | http://www.paulgraham.com/ideas.html 73 | http://www.paulgraham.com/identity.html 74 | http://www.paulgraham.com/ineq.html 75 | http://www.paulgraham.com/inequality.html 76 | http://www.paulgraham.com/investors.html 77 | http://www.paulgraham.com/invtrend.html 78 | http://www.paulgraham.com/javacover.html 79 | http://www.paulgraham.com/jessica.html 80 | http://www.paulgraham.com/judgement.html 81 | http://www.paulgraham.com/kate.html 82 | http://www.paulgraham.com/kids.html 83 | http://www.paulgraham.com/ladder.html 84 | 
http://www.paulgraham.com/lesson.html 85 | http://www.paulgraham.com/lies.html 86 | http://www.paulgraham.com/lwba.html 87 | http://www.paulgraham.com/mac.html 88 | http://www.paulgraham.com/makersschedule.html 89 | http://www.paulgraham.com/marginal.html 90 | http://www.paulgraham.com/maybe.html 91 | http://www.paulgraham.com/mean.html 92 | http://www.paulgraham.com/microsoft.html 93 | http://www.paulgraham.com/mit.html 94 | http://www.paulgraham.com/name.html 95 | http://www.paulgraham.com/nerds.html 96 | http://www.paulgraham.com/newthings.html 97 | http://www.paulgraham.com/noob.html 98 | http://www.paulgraham.com/noop.html 99 | http://www.paulgraham.com/notnot.html 100 | http://www.paulgraham.com/nov.html 101 | http://www.paulgraham.com/nthings.html 102 | http://www.paulgraham.com/opensource.html 103 | http://www.paulgraham.com/organic.html 104 | http://www.paulgraham.com/orth.html 105 | http://www.paulgraham.com/own.html 106 | http://www.paulgraham.com/patentpledge.html 107 | http://www.paulgraham.com/pgh.html 108 | http://www.paulgraham.com/pinch.html 109 | http://www.paulgraham.com/polls.html 110 | http://www.paulgraham.com/power.html 111 | http://www.paulgraham.com/prcmc.html 112 | http://www.paulgraham.com/procrastination.html 113 | http://www.paulgraham.com/progbot.html 114 | http://www.paulgraham.com/prop62.html 115 | http://www.paulgraham.com/property.html 116 | http://www.paulgraham.com/publishing.html 117 | http://www.paulgraham.com/pypar.html 118 | http://www.paulgraham.com/ramenprofitable.html 119 | http://www.paulgraham.com/randomness.html 120 | http://www.paulgraham.com/re.html 121 | http://www.paulgraham.com/read.html 122 | http://www.paulgraham.com/real.html 123 | http://www.paulgraham.com/really.html 124 | http://www.paulgraham.com/relres.html 125 | http://www.paulgraham.com/revolution.html 126 | http://www.paulgraham.com/richnow.html 127 | http://www.paulgraham.com/road.html 128 | http://www.paulgraham.com/ronco.html 129 | http://www.paulgraham.com/safe.html 130 | http://www.paulgraham.com/say.html 131 | http://www.paulgraham.com/schlep.html 132 | http://www.paulgraham.com/seesv.html 133 | http://www.paulgraham.com/segway.html 134 | http://www.paulgraham.com/selfindulgence.html 135 | http://www.paulgraham.com/sfp.html 136 | http://www.paulgraham.com/simply.html 137 | http://www.paulgraham.com/smart.html 138 | http://www.paulgraham.com/softwarepatents.html 139 | http://www.paulgraham.com/spam.html 140 | http://www.paulgraham.com/speak.html 141 | http://www.paulgraham.com/start.html 142 | http://www.paulgraham.com/startupfunding.html 143 | http://www.paulgraham.com/startuphubs.html 144 | http://www.paulgraham.com/startupideas.html 145 | http://www.paulgraham.com/startupmistakes.html 146 | http://www.paulgraham.com/stuff.html 147 | http://www.paulgraham.com/superlinear.html 148 | http://www.paulgraham.com/swan.html 149 | http://www.paulgraham.com/tablets.html 150 | http://www.paulgraham.com/talk.html 151 | http://www.paulgraham.com/taste.html 152 | http://www.paulgraham.com/think.html 153 | http://www.paulgraham.com/top.html 154 | http://www.paulgraham.com/trolls.html 155 | http://www.paulgraham.com/twitter.html 156 | http://www.paulgraham.com/usa.html 157 | http://www.paulgraham.com/users.html 158 | http://www.paulgraham.com/venturecapital.html 159 | http://www.paulgraham.com/wealth.html 160 | http://www.paulgraham.com/webstartups.html 161 | http://www.paulgraham.com/whyyc.html 162 | http://www.paulgraham.com/word.html 163 | http://www.paulgraham.com/words.html 164 | 
http://www.paulgraham.com/work.html 165 | http://www.paulgraham.com/writing44.html 166 | http://www.paulgraham.com/wtax.html 167 | http://www.paulgraham.com/yahoo.html 168 | http://www.paulgraham.com/ycombinator.html 169 | http://www.paulgraham.com/ycstart.html 170 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/addiction.txt 171 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/aord.txt 172 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/apple.txt 173 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/avg.txt 174 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/before.txt 175 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/bias.txt 176 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/boss.txt 177 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/copy.txt 178 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/corpdev.txt 179 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/desres.txt 180 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/diff.txt 181 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/ecw.txt 182 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/founders.txt 183 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/foundervisa.txt 184 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/gap.txt 185 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/gba.txt 186 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/gh.txt 187 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/goodtaste.txt 188 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/hubs.txt 189 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/iflisp.txt 190 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/island.txt 191 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/know.txt 192 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/langdes.txt 193 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/laundry.txt 194 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/love.txt 195 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/mod.txt 196 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/newideas.txt 197 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/nft.txt 198 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/philosophy.txt 199 | 
https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/popular.txt 200 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/pow.txt 201 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/rootsoflisp.txt 202 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/rss.txt 203 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/siliconvalley.txt 204 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/startuplessons.txt 205 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/submarine.txt 206 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/sun.txt 207 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/superangels.txt 208 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/todo.txt 209 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/unions.txt 210 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/useful.txt 211 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/vb.txt 212 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/vcsqueeze.txt 213 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/vw.txt 214 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/want.txt 215 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/web20.txt 216 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/weird.txt 217 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/wisdom.txt 218 | https://github.com/gkamradt/LLMTest_NeedleInAHaystack/raw/main/needlehaystack/PaulGrahamEssays/worked.txt -------------------------------------------------------------------------------- /scripts/data/synthetic/json/download_paulgraham_essay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | import os 16 | import shutil 17 | import glob 18 | import json 19 | import urllib.request 20 | import html2text 21 | from bs4 import BeautifulSoup 22 | from tqdm import tqdm 23 | 24 | temp_folder_repo = 'essay_repo' 25 | temp_folder_html = 'essay_html' 26 | os.makedirs(temp_folder_repo, exist_ok=True) 27 | os.makedirs(temp_folder_html, exist_ok=True) 28 | 29 | h = html2text.HTML2Text() 30 | h.ignore_images = True 31 | h.ignore_tables = True 32 | h.escape_all = True 33 | h.reference_links = False 34 | h.mark_code = False 35 | 36 | with open('PaulGrahamEssays_URLs.txt') as f: 37 | urls = [line.strip() for line in f] 38 | 39 | for url in tqdm(urls): 40 | if '.html' in url: 41 | filename = url.split('/')[-1].replace('.html', '.txt') 42 | try: 43 | with urllib.request.urlopen(url) as website: 44 | content = website.read().decode("unicode_escape", "utf-8") 45 | soup = BeautifulSoup(content, 'html.parser') 46 | specific_tag = soup.find('font') 47 | parsed = h.handle(str(specific_tag)) 48 | 49 | with open(os.path.join(temp_folder_html, filename), 'w') as file: 50 | file.write(parsed) 51 | 52 | except Exception as e: 53 | print(f"Fail download {filename}, ({e})") 54 | 55 | else: 56 | filename = url.split('/')[-1] 57 | try: 58 | with urllib.request.urlopen(url) as website: 59 | content = website.read().decode('utf-8') 60 | 61 | with open(os.path.join(temp_folder_repo, filename), 'w') as file: 62 | file.write(content) 63 | 64 | except Exception as e: 65 | print(f"Fail download {filename}, ({e})") 66 | 67 | files_repo = sorted(glob.glob(os.path.join(temp_folder_repo,'*.txt'))) 68 | files_html = sorted(glob.glob(os.path.join(temp_folder_html,'*.txt'))) 69 | print(f'Download {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`') 70 | print(f'Download {len(files_html)} essays from `http://www.paulgraham.com/`') 71 | 72 | text = "" 73 | for file in files_repo + files_html: 74 | with open(file, 'r') as f: 75 | text += f.read() 76 | 77 | with open('PaulGrahamEssays.json', 'w') as f: 78 | json.dump({"text": text}, f) 79 | 80 | 81 | shutil.rmtree(temp_folder_repo) 82 | shutil.rmtree(temp_folder_html) 83 | -------------------------------------------------------------------------------- /scripts/data/synthetic/json/download_qa_dataset.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
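# Fetches the two QA corpora consumed by scripts/data/synthetic/qa.py:
#   squad.json    -> SQuAD v2.0 dev split (single-hop reading comprehension)
#   hotpotqa.json -> HotpotQA dev set, distractor setting (multi-hop QA)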
14 | 15 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O squad.json 16 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json -O hotpotqa.json 17 | -------------------------------------------------------------------------------- /scripts/data/synthetic/json/english_words.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:affcd6d45fdf3cc843d585c99c97ad615094e760e6c4756b654bab6c73bc2eca 3 | size 8564991 4 | -------------------------------------------------------------------------------- /scripts/data/synthetic/niah.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for needle in a haystack. 17 | 18 | python niah.py \ 19 | --save_dir=./ \ 20 | --save_name=niah_single \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type=nemo \ 23 | --max_seq_length=4096 \ 24 | --tokens_to_generate=128 \ 25 | --num_samples=10 \ 26 | --template="Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? 
The special magic {type_needle_v} for {query} mentioned in the provided text are" 27 | """ 28 | import os 29 | import re 30 | import json 31 | import uuid 32 | import argparse 33 | import numpy as np 34 | from pathlib import Path 35 | from tqdm import tqdm 36 | import random 37 | import wonderwords 38 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest 39 | import sys 40 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 41 | from tokenizer import select_tokenizer 42 | from nltk.tokenize import sent_tokenize 43 | import logging 44 | 45 | logging.basicConfig(level=logging.INFO, force=True) 46 | logger = logging.getLogger(__name__) 47 | 48 | 49 | from constants import TASKS 50 | 51 | parser = argparse.ArgumentParser() 52 | # Basic Configurations 53 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 54 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 55 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 56 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 57 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 58 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 59 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.') 60 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 61 | parser.add_argument("--random_seed", type=int, default=42) 62 | parser.add_argument("--template", type=str, default='', help='prompt template') 63 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 64 | 65 | # Complexity Configurations 66 | parser.add_argument("--num_needle_k", type=int, default=1) 67 | parser.add_argument("--num_needle_v", type=int, default=1) 68 | parser.add_argument("--num_needle_q", type=int, default=1) 69 | parser.add_argument("--type_haystack", type=str, default='essay', help='[Options] noise, essay, needle.') 70 | parser.add_argument("--type_needle_k", type=str, default='words', help='[Options] numbers, words, uuids.') 71 | parser.add_argument("--type_needle_v", type=str, default='numbers', help='[Options] numbers, words, uuids.') 72 | parser.add_argument("--model_template_token", type=int, default=0, help='used for nemo skills, minus num of model template token') 73 | 74 | args = parser.parse_args() 75 | random.seed(args.random_seed) 76 | np.random.seed(args.random_seed) 77 | args.num_needle_k = max(args.num_needle_k, args.num_needle_q) 78 | 79 | # Load Tokenizer 80 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 81 | 82 | # Define Needle/Haystack Format 83 | needle = "One of the special magic {type_needle_v} for {key} is: {value}." 84 | if args.type_haystack == 'essay': 85 | essay = os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/PaulGrahamEssays.json") 86 | essay = json.load(open(essay))['text'] 87 | haystack = re.sub(r'\s+', " ", essay).split(" ") 88 | elif args.type_haystack == 'noise': 89 | haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 
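    # The 'noise' haystack repeats the filler sentence above to pad the context;
    # the 'needle' haystack below instead fills the context with distractor
    # needles built from the same template but with freshly sampled keys/values.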
90 | elif args.type_haystack == 'needle': 91 | haystack = needle 92 | else: 93 | raise NotImplementedError(f'{args.type_haystack} is not implemented.') 94 | 95 | 96 | # Words 97 | nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") 98 | adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") 99 | # verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt") 100 | words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] 101 | words = sorted(list(set(words))) 102 | 103 | 104 | # Positions 105 | DEPTHS = list(np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int)) 106 | 107 | 108 | def generate_random_number(num_digits=7): 109 | lower_bound = 10**(num_digits - 1) 110 | upper_bound = 10**num_digits - 1 111 | return str(random.randint(lower_bound, upper_bound)) 112 | 113 | def generate_random_word(): 114 | word = random.choice(words) 115 | return word 116 | 117 | def generate_random_uuid(): 118 | return str(uuid.UUID(int=random.getrandbits(128), version=4)) 119 | 120 | def generate_random(type_needle: str): 121 | if type_needle == 'numbers': 122 | return generate_random_number() 123 | elif type_needle == 'words': 124 | return generate_random_word() 125 | elif type_needle == 'uuids': 126 | return generate_random_uuid() 127 | else: 128 | raise NotImplementedError(f'{args.type_needle} is not implemented.') 129 | 130 | def generate_input_output(num_haystack): 131 | keys, values, needles = [], [], [] 132 | for _ in range(args.num_needle_k): 133 | keys.append(generate_random(args.type_needle_k)) 134 | value = [] 135 | for _ in range(args.num_needle_v): 136 | value.append(generate_random(args.type_needle_v)) 137 | needles.append(needle.format( 138 | type_needle_v=args.type_needle_v, 139 | key=keys[-1], 140 | value=value[-1], 141 | )) 142 | values.append(value) 143 | 144 | random.Random(args.random_seed).shuffle(needles) 145 | 146 | # Context 147 | if args.type_haystack == 'essay': 148 | text = " ".join(haystack[:num_haystack]) 149 | if num_haystack <= len(haystack): 150 | text = " ".join(haystack[:num_haystack]) 151 | else: 152 | # Repeat haystack as many times as needed and slice to num_haystack 153 | repeats = (num_haystack + len(haystack) - 1) // len(haystack) # Ceiling division 154 | text = " ".join((haystack * repeats)[:num_haystack]) 155 | document_sents = sent_tokenize(text.strip()) 156 | insertion_positions = [0] + \ 157 | sorted([int(len(document_sents) * (depth / 100)) for depth in random.sample(DEPTHS, len(needles))]) + \ 158 | [len(document_sents)] 159 | document_sents_list = [] 160 | for i in range(1,len(insertion_positions)): 161 | last_pos = insertion_positions[i-1] 162 | next_pos = insertion_positions[i] 163 | document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) 164 | if i-1 < len(needles): 165 | document_sents_list.append(needles[i-1]) 166 | context = " ".join(document_sents_list) 167 | 168 | else: 169 | if args.type_haystack == 'noise': 170 | sentences = [haystack] * num_haystack 171 | elif args.type_haystack == 'needle': 172 | sentences = [haystack.format( 173 | type_needle_v=args.type_needle_v, 174 | key=generate_random(args.type_needle_k), 175 | value=generate_random(args.type_needle_v), 176 | ) for _ in range(num_haystack)] 177 | 178 | 179 | indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) 180 | for index, element in zip(indexes, needles): 181 | sentences.insert(index, element) 182 | context = "\n".join(sentences) 183 | 184 | 185 | ## Query and Answer 186 | 
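    # Pick num_needle_q of the inserted keys as the query; the gold outputs are
    # every value attached to those keys. Illustrative example: with two sampled
    # keys "alpha-apple" and "brave-bottle", the query string below becomes
    # "alpha-apple, and brave-bottle".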
indices = random.sample(range(args.num_needle_k), args.num_needle_q) 187 | queries = [keys[i] for i in indices] 188 | answers = [a for i in indices for a in values[i]] 189 | query = ', '.join(queries[:-1]) + ', and ' + queries[-1] if len(queries) > 1 else queries[0] 190 | 191 | template = args.template 192 | type_needle_v = args.type_needle_v 193 | if args.num_needle_q * args.num_needle_v == 1: 194 | template = template.replace('Some', 'A') 195 | template = template.replace('are all', 'is') 196 | template = template.replace('are', 'is') 197 | template = template.replace('answers', 'answer') 198 | type_needle_v = type_needle_v[:-1] # remove "s" 199 | 200 | input_text = template.format( 201 | type_needle_v=type_needle_v, 202 | context=context, 203 | query=query, 204 | ) 205 | 206 | return input_text, answers 207 | 208 | 209 | def generate_samples(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 500): 210 | write_jsons = [] 211 | tokens_to_generate = args.tokens_to_generate 212 | max_seq_length -= args.model_template_token 213 | 214 | if args.type_haystack == 'essay': 215 | incremental = 500 216 | elif args.type_haystack == 'noise': 217 | incremental = 25 218 | elif args.type_haystack == 'needle': 219 | incremental = 25 220 | 221 | if args.type_haystack != 'essay' and args.max_seq_length < 4096: 222 | incremental = 5 223 | 224 | # Estimate tokens per question to determine reasonable upper bound 225 | sample_input_text, _ = generate_input_output(incremental) 226 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text)) 227 | tokens_per_haystack = sample_tokens / incremental 228 | 229 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling. 230 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable. 231 | estimated_max_questions = int((max_seq_length / tokens_per_haystack) * 3) 232 | 233 | # Binary search for optimal haystack size 234 | lower_bound = incremental 235 | upper_bound = max(estimated_max_questions, incremental * 2) # Ensure upper_bound is reasonable 236 | 237 | optimal_num_haystack = None 238 | 239 | logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") 240 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}") 241 | 242 | while lower_bound <= upper_bound: 243 | mid = (lower_bound + upper_bound) // 2 244 | input_text, answer = generate_input_output(mid) 245 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 246 | 247 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}") 248 | 249 | if total_tokens <= max_seq_length: 250 | # This size works, can we go larger? 251 | optimal_num_haystack = mid 252 | lower_bound = mid + 1 253 | else: 254 | # Too large, need to go smaller 255 | upper_bound = mid - 1 256 | 257 | num_haystack = optimal_num_haystack if optimal_num_haystack is not None else incremental 258 | logger.info(f'Final optimal haystack size (number of haystack): {num_haystack}') 259 | 260 | 261 | 262 | # Generate samples 263 | for index in tqdm(range(num_samples)): 264 | used_haystack = num_haystack 265 | while(True): 266 | try: 267 | input_text, answer = generate_input_output(used_haystack) 268 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 269 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 
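                # Reaching this point means the sample fits within max_seq_length;
                # otherwise the except branch below shrinks used_haystack by
                # `incremental` and the loop retries with a smaller haystack.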
270 | break 271 | except: 272 | if used_haystack > incremental: 273 | used_haystack -= incremental 274 | 275 | if args.remove_newline_tab: 276 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 277 | answer_prefix_index = input_text.rfind(TASKS['niah']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it 278 | answer_prefix = input_text[answer_prefix_index:] 279 | input_text = input_text[:answer_prefix_index] 280 | # find answer position in text 281 | index = input_text.find(answer[0]) 282 | token_position_answer = len(TOKENIZER.text_to_tokens(input_text[:index])) 283 | formatted_output = { 284 | 'index': index, 285 | "input": input_text, 286 | "outputs": answer, 287 | "length": length, 288 | 'length_w_model_temp': length + args.model_template_token, 289 | 'answer_prefix': answer_prefix, 290 | 'token_position_answer': token_position_answer, 291 | } 292 | write_jsons.append(formatted_output) 293 | 294 | return write_jsons 295 | 296 | 297 | def main(): 298 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 299 | save_file.parent.mkdir(parents=True, exist_ok=True) 300 | write_jsons = generate_samples( 301 | num_samples=args.num_samples, 302 | max_seq_length=args.max_seq_length, 303 | save_dir=args.save_dir 304 | ) 305 | 306 | write_manifest(save_file, write_jsons) 307 | 308 | if __name__ == "__main__": 309 | main() 310 | -------------------------------------------------------------------------------- /scripts/data/synthetic/qa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for QA task. 17 | 18 | python qa.py \ 19 | --save_dir=./ \ 20 | --save_name=niah_single \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type=nemo \ 23 | --max_seq_length=4096 \ 24 | --tokens_to_generate=128 \ 25 | --num_samples=10 \ 26 | --template="Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. 
Only give me the answer and do not output any other words.\n\nQuestion: {query} Answer:" 27 | """ 28 | import os 29 | import re 30 | import json 31 | import argparse 32 | from pathlib import Path 33 | from tqdm import tqdm 34 | import random 35 | import numpy as np 36 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest 37 | import sys 38 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 39 | from tokenizer import select_tokenizer 40 | import logging 41 | 42 | logging.basicConfig(level=logging.INFO, force=True) 43 | logger = logging.getLogger(__name__) 44 | 45 | from constants import TASKS 46 | 47 | parser = argparse.ArgumentParser() 48 | # Basic Configurations 49 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 50 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 51 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 52 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 53 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 54 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 55 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.') 56 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 57 | parser.add_argument("--pre_samples", type=int, default=0, help='number of samples are already generated') 58 | parser.add_argument("--random_seed", type=int, default=42) 59 | parser.add_argument("--template", type=str, required=True, help='prompt template') 60 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 61 | parser.add_argument("--model_template_token", type=int, default=0, help='used for nemo skills, minus num of model template token') 62 | # Complexity Configurations 63 | parser.add_argument("--dataset", type=str, required=True, help='dataset file') 64 | 65 | args = parser.parse_args() 66 | random.seed(args.random_seed) 67 | np.random.seed(args.random_seed) 68 | 69 | # Load Tokenizer 70 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 71 | 72 | # Read SQuAD QA dataset 73 | def read_squad(file): 74 | with open(file) as f: 75 | data = json.load(f) 76 | 77 | total_docs = [p['context'] for d in data['data'] for p in d['paragraphs']] 78 | total_docs = sorted(list(set(total_docs))) 79 | total_docs_dict = {c: idx for idx, c in enumerate(total_docs)} 80 | 81 | total_qas = [] 82 | for d in data['data']: 83 | more_docs = [total_docs_dict[p['context']] for p in d['paragraphs']] 84 | for p in d['paragraphs']: 85 | for qas in p['qas']: 86 | if not qas['is_impossible']: 87 | total_qas.append({ 88 | 'query': qas['question'], 89 | 'outputs': [a['text'] for a in qas['answers']], 90 | 'context': [total_docs_dict[p['context']]], 91 | 'more_context': [idx for idx in more_docs if idx != total_docs_dict[p['context']]] 92 | }) 93 | 94 | return total_qas, total_docs 95 | 96 | # Read Hotpot QA dataset 97 | def read_hotpotqa(file): 98 | with open(file) as f: 99 | data = json.load(f) 100 | 101 | total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d['context']] 102 | total_docs = 
sorted(list(set(total_docs))) 103 | total_docs_dict = {c: idx for idx, c in enumerate(total_docs)} 104 | 105 | total_qas = [] 106 | for d in data: 107 | total_qas.append({ 108 | 'query': d['question'], 109 | 'outputs': [d['answer']], 110 | 'context': [total_docs_dict[f"{t}\n{''.join(p)}"] for t, p in d['context']], 111 | }) 112 | 113 | return total_qas, total_docs 114 | 115 | 116 | DOCUMENT_PROMPT = "Document {i}:\n{document}" 117 | if args.dataset == 'squad': 118 | QAS, DOCS = read_squad(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/squad.json")) 119 | elif args.dataset == 'hotpotqa': 120 | QAS, DOCS = read_hotpotqa(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/hotpotqa.json")) 121 | else: 122 | raise NotImplementedError(f'{args.dataset} is not implemented.') 123 | 124 | 125 | def generate_input_output(index, num_docs): 126 | curr_q = QAS[index]['query'] 127 | curr_a = QAS[index]['outputs'] 128 | curr_docs = QAS[index]['context'] 129 | curr_more = QAS[index].get('more_context', []) 130 | if num_docs < len(DOCS): 131 | if (num_docs - len(curr_docs)) > len(curr_more): 132 | addition_docs = [i for i, d in enumerate(DOCS) if i not in curr_docs + curr_more] 133 | all_docs = curr_docs + curr_more + random.sample(addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more))) 134 | else: 135 | all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs)) 136 | 137 | all_docs = [DOCS[idx] for idx in all_docs] 138 | else: 139 | # Repeat DOCS as many times as needed and slice to num_docs 140 | repeats = (num_docs + len(DOCS) - 1) // len(DOCS) # Ceiling division 141 | all_docs = (DOCS * repeats)[:num_docs] 142 | 143 | random.Random(args.random_seed).shuffle(all_docs) 144 | 145 | context = '\n\n'.join([DOCUMENT_PROMPT.format(i=i+1, document=d) for i, d in enumerate(all_docs)]) 146 | input_text = args.template.format( 147 | context=context, 148 | query=curr_q 149 | ) 150 | return input_text, curr_a 151 | 152 | 153 | def generate_samples(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10): 154 | 155 | write_jsons = [] 156 | tokens_to_generate = args.tokens_to_generate 157 | max_seq_length -= args.model_template_token 158 | 159 | # Estimate tokens per question to determine reasonable upper bound 160 | sample_input_text, _ = generate_input_output(0, incremental) 161 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text)) 162 | tokens_per_doc = sample_tokens / incremental 163 | 164 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling. 165 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable. 
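    # Illustrative numbers (assumed, not measured): with max_seq_length=131072
    # and ~200 tokens per document, the bound below is int(131072 / 200 * 3) = 1966,
    # which the binary search then narrows to the largest document count that fits.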
166 | estimated_max_docs = int((max_seq_length / tokens_per_doc) * 3) 167 | 168 | # Binary search for optimal haystack size 169 | lower_bound = incremental 170 | upper_bound = max(estimated_max_docs, incremental * 2) # Ensure upper_bound is reasonable 171 | 172 | optimal_num_docs = None 173 | 174 | logger.info(f"Estimated {tokens_per_doc:.1f} tokens per doc") 175 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}") 176 | 177 | while lower_bound <= upper_bound: 178 | mid = (lower_bound + upper_bound) // 2 179 | input_text, answer = generate_input_output(0, mid) 180 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 181 | 182 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}") 183 | 184 | if total_tokens <= max_seq_length: 185 | # This size works, can we go larger? 186 | optimal_num_docs = mid 187 | lower_bound = mid + 1 188 | else: 189 | # Too large, need to go smaller 190 | upper_bound = mid - 1 191 | 192 | num_docs = optimal_num_docs if optimal_num_docs is not None else incremental 193 | logger.info(f'Final optimal haystack size (number of docs): {num_docs}') 194 | 195 | # Generate samples 196 | for index in tqdm(range(num_samples)): 197 | used_docs = num_docs 198 | while(True): 199 | try: 200 | input_text, answer = generate_input_output(index + args.pre_samples, used_docs) 201 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 202 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 203 | break 204 | except: 205 | if used_docs > incremental: 206 | used_docs -= incremental 207 | 208 | if args.remove_newline_tab: 209 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 210 | answer_prefix_index = input_text.rfind(TASKS['qa']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it 211 | answer_prefix = input_text[answer_prefix_index:] 212 | input_text = input_text[:answer_prefix_index] 213 | formatted_output = { 214 | "index": index, 215 | "input": input_text, 216 | "outputs": answer, 217 | "length": length, 218 | 'length_w_model_temp': length + args.model_template_token, 219 | 'answer_prefix': answer_prefix, 220 | } 221 | write_jsons.append(formatted_output) 222 | 223 | return write_jsons 224 | 225 | 226 | def main(): 227 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 228 | save_file.parent.mkdir(parents=True, exist_ok=True) 229 | 230 | write_jsons = generate_samples( 231 | num_samples=args.num_samples, 232 | max_seq_length=args.max_seq_length, 233 | save_dir=args.save_dir 234 | ) 235 | 236 | write_manifest(save_file, write_jsons) 237 | 238 | if __name__=="__main__": 239 | main() 240 | -------------------------------------------------------------------------------- /scripts/data/synthetic/variable_tracking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for variable tracking. 17 | 18 | python variable_tracking.py \ 19 | --save_dir=./ \ 20 | --save_name=vt \ 21 | --tokenizer_path='EleutherAI/gpt-neox-20b' \ 22 | --tokenizer_type hf \ 23 | --max_seq_length 4096 \ 24 | --tokens_to_generate 30 \ 25 | --num_samples 10 \ 26 | --random_seed 42 \ 27 | --num_chains 1 --num_hops 4 \ 28 | --template "[INST] Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above. [/INST] Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: " 29 | """ 30 | import os 31 | import json 32 | import re 33 | import argparse 34 | from pathlib import Path 35 | from tqdm import tqdm 36 | import random 37 | import string 38 | from constants import TASKS 39 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest 40 | import sys 41 | import pdb 42 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 43 | from tokenizer import select_tokenizer 44 | from nltk.tokenize import sent_tokenize 45 | import numpy as np 46 | import heapq 47 | import json 48 | import logging 49 | 50 | logging.basicConfig(level=logging.INFO, force=True) 51 | logger = logging.getLogger(__name__) 52 | 53 | 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 56 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 57 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 58 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 59 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 60 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 61 | parser.add_argument("--tokens_to_generate", type=int, default=120, help='number of tokens to generate') 62 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 63 | parser.add_argument("--random_seed", type=int, default=42) 64 | parser.add_argument("--template", type=str, default='', help='prompt template') 65 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 66 | 67 | parser.add_argument("--type_haystack", type=str, default='noise', help='[Options] noise or essay.') 68 | parser.add_argument("--num_chains", type=int, default=1, help='number of inserted variable chains') 69 | parser.add_argument("--num_hops", type=int, default=4, help='number of hops in each chain') 70 | parser.add_argument("--add_fewshot", action="store_true", default=False) 71 | parser.add_argument("--model_template_token", type=int, default=0, help='used for nemo skills, minus num of model template token') 72 | 73 | args = parser.parse_args() 74 | random.seed(args.random_seed) 75 | np.random.seed(args.random_seed) 76 | 77 | # Load Tokenizer 78 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 79 | 80 | # Define Needle/Haystack Format 81 | if args.type_haystack == 'essay': 
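    # Essay haystack: reuse the concatenated Paul Graham essays produced by
    # json/download_paulgraham_essay.py, whitespace-normalized and split into words.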
82 | essay = os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/PaulGrahamEssays.json") 83 | essay = json.load(open(essay))['text'] 84 | haystack = re.sub(r'\s+', " ", essay).split(" ") 85 | elif args.type_haystack == 'noise': 86 | haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 87 | else: 88 | raise NotImplementedError(f'{args.type_haystack} is not implemented.') 89 | 90 | # Positions 91 | DEPTHS = list(np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int)) 92 | 93 | def generate_chains(num_chains, num_hops, is_icl=False): 94 | 95 | vars_all = [] 96 | k = 5 if not is_icl else 3 97 | num_hops = num_hops if not is_icl else min(10, num_hops) 98 | vars_all = [''.join(random.choices(string.ascii_uppercase, k=k)).upper() for _ in range((num_hops+1) * num_chains)] 99 | while len(set(vars_all)) < num_chains * (num_hops+1): 100 | vars_all.append(''.join(random.choices(string.ascii_uppercase, k=k)).upper()) 101 | 102 | vars_ret = [] 103 | chains_ret = [] 104 | for i in range(0, len(vars_all), num_hops+1): 105 | this_vars = vars_all[i:i+num_hops+1] 106 | vars_ret.append(this_vars) 107 | if is_icl: 108 | this_chain = [f"VAR {this_vars[0]} = 12345"] 109 | else: 110 | this_chain = [f"VAR {this_vars[0]} = {str(np.random.randint(10000, 99999))}"] 111 | for j in range(num_hops): 112 | this_chain.append(f"VAR {this_vars[j+1]} = VAR {this_vars[j]} ") 113 | chains_ret.append(this_chain) 114 | return vars_ret, chains_ret 115 | 116 | def shuffle_sublists_heap(lst): 117 | heap = [] 118 | for i in range(len(lst)): 119 | heapq.heappush(heap, (random.random(), i, 0)) # Push first element of each list with random priority 120 | shuffled_result = [] 121 | while heap: 122 | _, list_idx, elem_idx = heapq.heappop(heap) # Get the lowest random priority element 123 | shuffled_result.append(lst[list_idx][elem_idx]) 124 | 125 | # If there are more elements in the same sublist, add the next one 126 | if elem_idx + 1 < len(lst[list_idx]): 127 | heapq.heappush(heap, (random.random(), list_idx, elem_idx + 1)) 128 | return shuffled_result 129 | 130 | def generate_input_output(num_noises, num_chains, num_hops, is_icl=False): 131 | 132 | vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl) 133 | value = chains[0][0].split("=")[-1].strip() 134 | 135 | if args.type_haystack == 'essay': 136 | text = " ".join(haystack[:num_noises]) 137 | document_sents = sent_tokenize(text.strip()) 138 | chains_flat = shuffle_sublists_heap(chains) 139 | insertion_positions = [0] + \ 140 | sorted([int(len(document_sents) * (depth / 100)) for depth in random.sample(DEPTHS, len(chains_flat))]) + \ 141 | [len(document_sents)] 142 | document_sents_list = [] 143 | for i in range(1,len(insertion_positions)): 144 | last_pos = insertion_positions[i-1] 145 | next_pos = insertion_positions[i] 146 | document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) 147 | if i-1 < len(chains_flat): 148 | document_sents_list.append(chains_flat[i-1].strip() + ".") 149 | context = " ".join(document_sents_list) 150 | 151 | elif args.type_haystack == 'noise': 152 | sentences = [haystack] * num_noises 153 | for chain in chains: 154 | positions = list(sorted(random.sample(range(len(sentences)), len(chain)))) 155 | for insert_pi, j in zip(positions, range(len(chain))): 156 | sentences.insert(insert_pi+j, chain[j]) 157 | context = "\n".join(sentences) 158 | 159 | context = context.replace(". 
\n", ".\n") 160 | 161 | template = args.template 162 | if is_icl and template != TASKS['variable_tracking']['template'] + TASKS['variable_tracking']['answer_prefix']: 163 | # remove model template 164 | new_template = "" 165 | if len(TASKS['variable_tracking']['template']) > 0: 166 | new_template = TASKS['variable_tracking']['template'] 167 | if len(TASKS['variable_tracking']['answer_prefix']) > 0: 168 | new_template += TASKS['variable_tracking']['answer_prefix'] 169 | template = new_template 170 | 171 | input_text = template.format( 172 | context=context, 173 | query=value, 174 | num_v=num_hops+1 175 | ) 176 | 177 | return input_text, vars[0] 178 | 179 | def randomize_icl(icl_example): 180 | icl_tgt = icl_example.strip().split()[-args.num_hops-1:] 181 | for item in icl_tgt: 182 | new_item = ''.join(random.choices(string.ascii_uppercase, k=len(item))).upper() 183 | icl_example = icl_example.replace(item, new_item) 184 | 185 | old_value = "12345" 186 | new_value = str(np.random.randint(10000, 99999)) 187 | icl_example = icl_example.replace(old_value, new_value) 188 | 189 | return icl_example 190 | 191 | def sys_vartrack_w_noise_random(num_samples: int, max_seq_length: int, incremental: int = 10, 192 | num_chains: int = 1, num_hops: int = 4, 193 | add_fewshot: bool = True, 194 | icl_example: str = None, 195 | final_output: bool = False): 196 | write_jsons = [] 197 | tokens_to_generate = args.tokens_to_generate if icl_example is not None else 0 198 | max_seq_length -= args.model_template_token 199 | 200 | # Find the perfect num_noises 201 | if icl_example: 202 | if args.type_haystack == 'essay': 203 | incremental = 500 204 | elif args.type_haystack == 'noise': 205 | incremental = 10 206 | 207 | if args.type_haystack != 'essay' and args.max_seq_length < 4096: 208 | incremental = 5 209 | else: 210 | if args.type_haystack == 'essay': 211 | incremental = 50 212 | elif args.type_haystack == 'noise': 213 | incremental = 5 214 | 215 | example_tokens = 0 216 | if add_fewshot and (icl_example is not None): 217 | icl_example_out = ' '.join(icl_example['outputs']) 218 | icl_example = icl_example['input'] + " " + icl_example_out + '\n' 219 | example_tokens = len(TOKENIZER.text_to_tokens(icl_example)) 220 | 221 | # Estimate tokens per question to determine reasonable upper bound 222 | sample_input_text, _ = generate_input_output(incremental, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)) 223 | sample_tokens = len(TOKENIZER.text_to_tokens(sample_input_text)) 224 | tokens_per_haystack = sample_tokens / incremental 225 | 226 | # Let's do 3x to allow for some slack since we can get unlucky due to sampling. 227 | # NOTE: We should test this for really large sequence lengths to make sure it's reasonable. 
228 | estimated_max_noises = int((max_seq_length / tokens_per_haystack) * 3) 229 | 230 | # Binary search for optimal haystack size 231 | lower_bound = incremental 232 | upper_bound = max(estimated_max_noises, incremental * 2) # Ensure upper_bound is reasonable 233 | 234 | optimal_num_noises = None 235 | 236 | logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") 237 | logger.info(f"Starting binary search with bounds: {lower_bound} to {upper_bound}") 238 | while lower_bound <= upper_bound: 239 | mid = (lower_bound + upper_bound) // 2 240 | input_text, answer = generate_input_output(mid, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)) 241 | total_tokens = len(TOKENIZER.text_to_tokens(input_text)) + example_tokens + tokens_to_generate 242 | 243 | logger.info(f"Testing haystack size: {mid}, resulting tokens: {total_tokens}/{max_seq_length}") 244 | 245 | if total_tokens <= max_seq_length: 246 | # This size works, can we go larger? 247 | optimal_num_noises = mid 248 | lower_bound = mid + 1 249 | else: 250 | # Too large, need to go smaller 251 | upper_bound = mid - 1 252 | 253 | num_noises = optimal_num_noises if optimal_num_noises is not None else incremental 254 | logger.info(f'Final optimal haystack size (number of haystack): {num_noises}') 255 | 256 | # Generate samples 257 | for index in tqdm(range(num_samples)): 258 | used_noises = num_noises 259 | while(True): 260 | try: 261 | input_text, answer = generate_input_output(used_noises, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)) 262 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate + example_tokens 263 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 264 | break 265 | except: 266 | if used_noises > incremental: 267 | used_noises -= incremental 268 | 269 | if add_fewshot and (icl_example is not None): 270 | # insert icl_example between model template and input 271 | cutoff = input_text.index(TASKS['variable_tracking']['template'][:20]) 272 | input_text = input_text[:cutoff] + randomize_icl(icl_example) + '\n' + input_text[cutoff:] 273 | if args.remove_newline_tab: 274 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 275 | 276 | if final_output: 277 | answer_prefix_index = input_text.rfind(TASKS['variable_tracking']['answer_prefix'][:10]) # use first 10 char of answer prefix to locate it 278 | answer_prefix = input_text[answer_prefix_index:] 279 | input_text = input_text[:answer_prefix_index] 280 | formatted_output = { 281 | 'index': index, 282 | "input": input_text, 283 | "outputs": answer, 284 | "length": length, 285 | 'length_w_model_temp': length + args.model_template_token, 286 | 'answer_prefix': answer_prefix, 287 | } 288 | else: 289 | formatted_output = { 290 | 'index': index, 291 | "input": input_text, 292 | "outputs": answer, 293 | "length": length, 294 | } 295 | write_jsons.append(formatted_output) 296 | 297 | return write_jsons 298 | 299 | 300 | def main(): 301 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 302 | save_file.parent.mkdir(parents=True, exist_ok=True) 303 | 304 | icl_example = sys_vartrack_w_noise_random(num_samples=1, 305 | max_seq_length=500, 306 | incremental=5, 307 | num_chains=args.num_chains, 308 | num_hops=args.num_hops)[0] 309 | logger.info(icl_example) 310 | write_jsons = sys_vartrack_w_noise_random(num_samples=args.num_samples, 311 | max_seq_length=args.max_seq_length, 312 | num_chains=args.num_chains, 313 | num_hops=args.num_hops, 314 | 
icl_example=icl_example, 315 | final_output=True) 316 | 317 | write_manifest(save_file, write_jsons) 318 | 319 | if __name__=="__main__": 320 | main() 321 | -------------------------------------------------------------------------------- /scripts/data/template.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | Templates = { 16 | 'base': "{task_template}", 17 | 18 | 'meta-chat': "[INST] {task_template} [/INST]", 19 | 20 | 'vicuna-chat': "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {task_template} ASSISTANT:", 21 | 22 | 'lwm-chat': "You are a helpful assistant. USER: {task_template} ASSISTANT: ", 23 | 24 | 'command-r-chat': "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{task_template}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", 25 | 26 | 'chatglm-chat': "[gMASK]sop<|user|> \n {task_template}<|assistant|> \n ", 27 | 28 | 'RWKV': "User: hi\n\nAssistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it\n\nUser: {task_template}\n\nAssistant:", 29 | 30 | 'Phi3': "<|user|>\n{task_template}<|end|>\n<|assistant|>\n", 31 | 32 | 'meta-llama3': "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{task_template}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", 33 | 34 | 'jamba': "<|startoftext|><|bom|><|system|> <|eom|><|bom|><|user|> {task_template}<|eom|><|bom|><|assistant|>", 35 | 36 | 'nemotron5-instruct': "System\n\nUser\n{task_template}\nAssistant\n", 37 | } -------------------------------------------------------------------------------- /scripts/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | 
16 | import os
17 | from typing import List
18 | from tenacity import (
19 |     retry,
20 |     stop_after_attempt,
21 |     wait_fixed,
22 |     wait_random,
23 | )
24 | 
25 | 
26 | def select_tokenizer(tokenizer_type, tokenizer_path):
27 |     if tokenizer_type == 'nemo':
28 |         return NeMoSentencePieceTokenizer(model_path=tokenizer_path)
29 |     elif tokenizer_type == 'nemo_tiktoken':
30 |         return NeMoTikTokenTokenizer(model_path=tokenizer_path)
31 |     elif tokenizer_type == 'hf':
32 |         return HFTokenizer(model_path=tokenizer_path)
33 |     elif tokenizer_type == 'openai':
34 |         return OpenAITokenizer(model_path=tokenizer_path)
35 |     elif tokenizer_type == 'gemini':
36 |         return GeminiTokenizer(model_path=tokenizer_path)
37 |     else:
38 |         raise ValueError(f"Unknown tokenizer_type {tokenizer_type}")
39 | 
40 | 
41 | class NeMoSentencePieceTokenizer:
42 |     """
43 |     Tokenizer from NeMo SentencePieceTokenizer
44 |     """
45 |     def __init__(self, model_path) -> None:
46 |         from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
47 |         self.tokenizer = SentencePieceTokenizer(model_path=model_path)
48 | 
49 |     def text_to_tokens(self, text: str) -> List[str]:
50 |         tokens = self.tokenizer.text_to_tokens(text)
51 |         return tokens
52 | 
53 |     def tokens_to_text(self, tokens: List[int]) -> str:
54 |         text = self.tokenizer.tokens_to_text(tokens)
55 |         return text
56 | 
57 | class NeMoTikTokenTokenizer:
58 |     """
59 |     Tokenizer from NeMo TiktokenTokenizer
60 |     """
61 |     def __init__(self, model_path) -> None:
62 |         from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer
63 |         self.tokenizer = TiktokenTokenizer(vocab_file=model_path)
64 | 
65 |     def text_to_tokens(self, text: str) -> List[str]:
66 |         tokens = self.tokenizer.text_to_tokens(text)
67 |         return tokens
68 | 
69 |     def tokens_to_text(self, tokens: List[int]) -> str:
70 |         text = self.tokenizer.tokens_to_text(tokens)
71 |         return text
72 | 
73 | 
74 | class HFTokenizer:
75 |     """
76 |     Tokenizer from HF models
77 |     """
78 |     def __init__(self, model_path) -> None:
79 |         from transformers import AutoTokenizer
80 |         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
81 | 
82 |     def text_to_tokens(self, text: str) -> List[str]:
83 |         tokens = self.tokenizer.tokenize(text)
84 |         return tokens
85 | 
86 |     def tokens_to_text(self, tokens: List[int]) -> str:
87 |         text = self.tokenizer.convert_tokens_to_string(tokens)
88 |         return text
89 | 
90 | 
91 | class OpenAITokenizer:
92 |     """
93 |     Tokenizer from tiktoken
94 |     """
95 |     def __init__(self, model_path="cl100k_base") -> None:
96 |         import tiktoken
97 |         self.tokenizer = tiktoken.get_encoding(model_path)
98 | 
99 |     def text_to_tokens(self, text: str) -> List[int]:
100 |         tokens = self.tokenizer.encode(text)
101 |         return tokens
102 | 
103 |     def tokens_to_text(self, tokens: List[int]) -> str:
104 |         text = self.tokenizer.decode(tokens)
105 |         return text
106 | 
107 | 
108 | class GeminiTokenizer:
109 |     """
110 |     Tokenizer from gemini
111 |     """
112 |     def __init__(self, model_path="gemini-1.5-pro-latest") -> None:
113 |         import google.generativeai as genai
114 |         genai.configure(api_key=os.environ["GEMINI_API_KEY"])
115 |         self.model = genai.GenerativeModel(model_path)
116 | 
117 |     @retry(wait=wait_fixed(60) + wait_random(0, 10), stop=stop_after_attempt(3))
118 |     def text_to_tokens(self, text: str) -> List[int]:
119 |         tokens = list(range(self.model.count_tokens(text).total_tokens))  # dummy ids; only the token count is used
120 |         return tokens
121 | 
122 |     def tokens_to_text(self, tokens: List[int]) -> str:
123 |         pass  # decoding back to text is not supported for the Gemini API tokenizer
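# Minimal usage sketch (illustrative; the model path is a placeholder). Every wrapper above
# exposes the same text_to_tokens()/tokens_to_text() pair, so the data-preparation scripts can
# stay tokenizer-agnostic and simply count tokens, e.g.:
#
#   TOKENIZER = select_tokenizer('hf', '/path/to/hf_model')
#   num_tokens = len(TOKENIZER.text_to_tokens("A quick length check."))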
-------------------------------------------------------------------------------- /scripts/eval/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Get summary.csv with score and null predictions amount. 17 | 18 | Running 19 | ``` 20 | python evaluate.py \ 21 | --data_dir /path/to/your/prediction_jsonl_folder \ 22 | --benchmark synthetic 23 | ``` 24 | """ 25 | 26 | import re 27 | import os 28 | import argparse 29 | import nltk 30 | try: 31 | nltk.data.find('tokenizers/punkt') 32 | except LookupError: 33 | nltk.download('punkt') 34 | 35 | import pandas as pd 36 | import importlib 37 | import yaml 38 | from pathlib import Path 39 | from tqdm import tqdm 40 | from collections import defaultdict 41 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--data_dir", type=str, required=True, help='path to the prediction jsonl files') 45 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]') 46 | parser.add_argument("--verbose", type=int, default=0, help='how many lines you want to display.') 47 | args = parser.parse_args() 48 | 49 | 50 | def postprocess_pred(predict_str: str, task_config: dict): 51 | 52 | predict_str = predict_str.strip() 53 | 54 | # Remove all non-printable characters 55 | np_pattern = re.compile(r'[\x00-\x1f]') 56 | predict_str = np_pattern.sub('\n', predict_str).strip() 57 | 58 | return predict_str 59 | 60 | 61 | def get_pred_and_ref( 62 | predictions_file: str, 63 | task_config: dict, 64 | input_field: str = 'input', 65 | references_field: str = 'outputs', 66 | prediction_field: str = 'pred', 67 | metadata_field: str = 'others', 68 | ): 69 | lines = read_manifest(predictions_file) 70 | 71 | inputs = [] 72 | predicts = [] 73 | references = [] 74 | indices = [] 75 | 76 | for line in tqdm(lines): 77 | input = line[input_field] 78 | predict = line[prediction_field] 79 | predict = postprocess_pred(predict, task_config) 80 | reference = line.get(references_field, [line.get('output', '')]) 81 | index = line[metadata_field].get('id', line['index']) 82 | 83 | inputs.append(input) 84 | predicts.append(predict) 85 | references.append(reference) 86 | indices.append(index) 87 | 88 | return inputs, predicts, references, indices 89 | 90 | def run_evaluation_per_task(task_config: dict, predictions_file: str, verbose: int = 0): 91 | inputs, predicts, references, indices = get_pred_and_ref( 92 | predictions_file=predictions_file, 93 | task_config=task_config, 94 | ) 95 | 96 | task_nulls = f'{sum([len(x)==0 for x in predicts])}/{len(predicts)}' 97 | 98 | if len(references) > 0 and references[0][0] is not None: 99 | task_score = task_config['metric_fn'](predicts, references) 100 | else: 101 | task_score = 0.0 102 | 103 | if verbose != 0: 
104 | print('=' * 40) 105 | for i, (input, reference, predict) in enumerate(zip(inputs, references, predicts)): 106 | print(f'Input : {input}') 107 | print(f'Reference : {reference}') 108 | print(f'Prediction: {predict}') 109 | print('=' * 40) 110 | if i > verbose: 111 | break 112 | 113 | return task_score, task_nulls, predicts, indices 114 | 115 | 116 | def write_evaluation(results: dict): 117 | tasks = list(results.keys()) 118 | score = [results[task]['score'] for task in tasks] 119 | nulls = [results[task]['nulls'] for task in tasks] 120 | dfs = [ 121 | ['Tasks'] + tasks, 122 | ['Score'] + score, 123 | ['Nulls'] + nulls, 124 | ] 125 | 126 | output_file = os.path.join(args.data_dir, 'summary.csv' if len(tasks) > 1 else f'summary-{tasks[0]}.csv') 127 | df = pd.DataFrame(dfs) 128 | df.to_csv(output_file, index=False) 129 | print('\n=============================================\n') 130 | print(df) 131 | print(f'\nSaved eval results to {output_file}') 132 | 133 | 134 | def write_submission(results: dict): 135 | COLUMNS = ["Task", "ID", "Prediction"] 136 | dfs = pd.DataFrame(columns=COLUMNS, data=[]) 137 | 138 | for task, result in results.items(): 139 | df = pd.DataFrame({ 140 | 'Task': task, 141 | 'ID': result['indices'], 142 | 'Prediction': result['predicts'] 143 | }) 144 | dfs = pd.concat((dfs, df[COLUMNS])) 145 | 146 | output_file = os.path.join(args.data_dir, 'submission.csv') 147 | dfs = dfs.reset_index(drop=True) 148 | dfs.to_csv(output_file, index=False) 149 | print(f'\nSaved submission results to {output_file}') 150 | 151 | 152 | def aggregate_chunk(folder): 153 | jsonl_files = [file for file in os.listdir(folder) if Path(file).suffix == '.jsonl' ] 154 | chunk_files = sorted([file for file in jsonl_files if re.match(r'.*[^_]+-\d+\.jsonl', file)]) 155 | chunk_files_dict = defaultdict(list) 156 | for file in chunk_files: 157 | task = '-'.join(file.split('-')[:-1]) 158 | chunk_files_dict[task].append(file) 159 | 160 | for task, files in chunk_files_dict.items(): 161 | lines = [] 162 | for file in sorted(files): 163 | file = os.path.join(folder, file) 164 | lines += read_manifest(file) 165 | os.remove(file) # Remove chunk files 166 | write_manifest(os.path.join(folder, f'{task}.jsonl'), lines) 167 | 168 | 169 | def main(): 170 | curr_folder = os.path.dirname(os.path.abspath(__file__)) 171 | 172 | try: 173 | module = importlib.import_module(f"{args.benchmark}.constants") 174 | except ImportError: 175 | print(f"Module eval.{args.benchmark}.constants not found.") 176 | 177 | tasks_base = module.TASKS 178 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f: 179 | tasks_customized = yaml.safe_load(f) 180 | 181 | 182 | TASKS = tasks_customized 183 | for _, config in TASKS.items(): 184 | config.update(tasks_base[config['task']]) 185 | 186 | print(f"Total tasks: {list(TASKS.keys())}") 187 | 188 | # Aggregate all prediction files 189 | aggregate_chunk(args.data_dir) 190 | 191 | # Get scores and nulls 192 | jsonl_files = [file for file in os.listdir(args.data_dir) if Path(file).suffix == '.jsonl'] 193 | eval_results = {} 194 | subm_results = {} 195 | 196 | 197 | for task, config in TASKS.items(): 198 | 199 | if f'{task}.jsonl' not in jsonl_files: 200 | print(f'Prediction file {task}.jsonl is not found.') 201 | continue 202 | 203 | print(f'Evaluate task {task}...') 204 | task_score, task_nulls, predicts, indices = run_evaluation_per_task( 205 | predictions_file=os.path.join(args.data_dir, f'{task}.jsonl'), 206 | task_config=config, 207 | ) 208 | eval_results[task] = { 209 
| 'score': task_score, 210 | 'nulls': task_nulls, 211 | } 212 | subm_results[task] = { 213 | 'predicts': predicts, 214 | 'indices':indices, 215 | } 216 | 217 | # Write to csv 218 | write_evaluation(eval_results) 219 | write_submission(subm_results) 220 | 221 | if __name__ == '__main__': 222 | main() -------------------------------------------------------------------------------- /scripts/eval/synthetic/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Add a new task: 17 | 18 | TASK_NAME: { 19 | 'metric_fn': the metric function with input (predictions: [str], references: [[str]]) to compute score. 20 | } 21 | """ 22 | 23 | 24 | def string_match_part(preds, refs): 25 | score = sum([max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) for pred, ref in zip(preds, refs)]) / len(preds) * 100 26 | return round(score, 2) 27 | 28 | def string_match_all(preds, refs): 29 | score = sum([sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) for pred, ref in zip(preds, refs)]) / len(preds) * 100 30 | return round(score, 2) 31 | 32 | 33 | TASKS = { 34 | 'niah': { 35 | 'metric_fn': string_match_all, 36 | }, 37 | 'variable_tracking': { 38 | 'metric_fn': string_match_all, 39 | }, 40 | 'common_words_extraction': { 41 | 'metric_fn': string_match_all, 42 | }, 43 | 'freq_words_extraction': { 44 | 'metric_fn': string_match_all 45 | }, 46 | 'qa': { 47 | 'metric_fn': string_match_part, 48 | }, 49 | } 50 | -------------------------------------------------------------------------------- /scripts/pred/call_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Prepare prediction jsonl with field `pred` . 
17 | dataset jsonl: 18 | { 19 | "index" int, 20 | "input": str, 21 | "outputs": [str], 22 | } 23 | 24 | prediction jsonl: 25 | { 26 | "index" int, 27 | "input": str, 28 | "outputs": [str], 29 | "pred": str, 30 | } 31 | """ 32 | 33 | import argparse 34 | import json 35 | import yaml 36 | import os 37 | import sys 38 | import threading 39 | import importlib 40 | import math 41 | import time 42 | from tqdm import tqdm 43 | from pathlib import Path 44 | import traceback 45 | from nemo.collections.asr.parts.utils.manifest_utils import read_manifest 46 | 47 | SERVER_TYPES = ( 48 | 'trtllm', 49 | 'vllm', 50 | 'sglang', 51 | 'openai', 52 | 'gemini', 53 | 'hf', 54 | 'mamba', 55 | ) 56 | 57 | 58 | class ServerAction(argparse.Action): 59 | def __call__(self, parser, namespace, values, option_string=None): 60 | namespace.server_type = values 61 | 62 | 63 | parser = argparse.ArgumentParser() 64 | # Data 65 | parser.add_argument("--data_dir", type=Path, required=True, help='path to load the dataset jsonl files') 66 | parser.add_argument("--save_dir", type=Path, required=True, help='path to save the prediction jsonl files') 67 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]') 68 | parser.add_argument("--task", type=str, required=True, help='Options: tasks in benchmark') 69 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 70 | parser.add_argument("--chunk_idx", type=int, default=0, help='index of current split chunk') 71 | parser.add_argument("--chunk_amount", type=int, default=1, help='size of split chunk') 72 | 73 | # Server 74 | parser.add_argument("--server_type", default='nemo', action=ServerAction, choices=SERVER_TYPES) 75 | parser.add_argument("--server_host", type=str, default='127.0.0.1') 76 | parser.add_argument("--server_port", type=str, default='5000') 77 | parser.add_argument("--ssh_server", type=str) 78 | parser.add_argument("--ssh_key_path", type=str) 79 | parser.add_argument("--model_name_or_path", type=str, default='gpt-3.5-turbo', 80 | help='supported models from OpenAI or HF (provide a key or a local path to the checkpoint)') 81 | 82 | # Inference 83 | parser.add_argument("--temperature", type=float, default=1.0) 84 | parser.add_argument("--top_k", type=int, default=32) 85 | parser.add_argument("--top_p", type=float, default=1.0) 86 | parser.add_argument("--random_seed", type=int, default=0) 87 | parser.add_argument("--stop_words", type=str, default='') 88 | parser.add_argument("--sliding_window_size", type=int) 89 | parser.add_argument("--threads", type=int, default=4) 90 | parser.add_argument("--batch_size", type=int, default=1) 91 | 92 | args = parser.parse_args() 93 | args.stop_words = list(filter(None, args.stop_words.split(','))) 94 | if args.server_type == 'hf' or args.server_type == 'gemini': 95 | args.threads = 1 96 | 97 | 98 | def get_llm(tokens_to_generate): 99 | if args.server_type == 'trtllm': 100 | from client_wrappers import TRTLLMClient 101 | llm = TRTLLMClient( 102 | server_host=args.server_host, 103 | server_port=args.server_port, 104 | ssh_server=args.ssh_server, 105 | ssh_key_path=args.ssh_key_path, 106 | temperature=args.temperature, 107 | top_k=args.top_k, 108 | top_p=args.top_p, 109 | random_seed=args.random_seed, 110 | stop=args.stop_words, 111 | tokens_to_generate=tokens_to_generate, 112 | max_attention_window_size=args.sliding_window_size, 113 | ) 114 | 115 | elif args.server_type == 'vllm': 116 | from client_wrappers import VLLMClient 117 | llm = VLLMClient( 118 
| server_host=args.server_host, 119 | server_port=args.server_port, 120 | ssh_server=args.ssh_server, 121 | ssh_key_path=args.ssh_key_path, 122 | temperature=args.temperature, 123 | top_k=args.top_k, 124 | top_p=args.top_p, 125 | random_seed=args.random_seed, 126 | stop=args.stop_words, 127 | tokens_to_generate=tokens_to_generate, 128 | ) 129 | 130 | elif args.server_type == 'sglang': 131 | from client_wrappers import SGLClient 132 | llm = SGLClient( 133 | server_host=args.server_host, 134 | server_port=args.server_port, 135 | ssh_server=args.ssh_server, 136 | ssh_key_path=args.ssh_key_path, 137 | temperature=args.temperature, 138 | top_k=args.top_k, 139 | top_p=args.top_p, 140 | random_seed=args.random_seed, 141 | stop=args.stop_words, 142 | tokens_to_generate=tokens_to_generate, 143 | ) 144 | 145 | elif args.server_type == 'openai': 146 | from client_wrappers import OpenAIClient 147 | llm = OpenAIClient( 148 | model_name=args.model_name_or_path, 149 | temperature=args.temperature, 150 | top_k=args.top_k, 151 | top_p=args.top_p, 152 | random_seed=args.random_seed, 153 | stop=args.stop_words, 154 | tokens_to_generate=tokens_to_generate, 155 | ) 156 | 157 | elif args.server_type == 'gemini': 158 | from client_wrappers import GeminiClient 159 | llm = GeminiClient( 160 | model_name=args.model_name_or_path, 161 | temperature=args.temperature, 162 | top_k=args.top_k, 163 | top_p=args.top_p, 164 | random_seed=args.random_seed, 165 | stop=args.stop_words, 166 | tokens_to_generate=tokens_to_generate, 167 | ) 168 | 169 | elif args.server_type == 'hf': 170 | from model_wrappers import HuggingFaceModel 171 | llm = HuggingFaceModel( 172 | name_or_path=args.model_name_or_path, 173 | do_sample=args.temperature > 0, 174 | repetition_penalty=1, 175 | temperature=args.temperature, 176 | top_k=args.top_k, 177 | top_p=args.top_p, 178 | stop=args.stop_words, 179 | max_new_tokens=tokens_to_generate, 180 | ) 181 | 182 | elif args.server_type == 'mamba': 183 | from model_wrappers import MambaModel 184 | # mamba uses its own generation function, do not pass in do_sample 185 | # https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/utils/generation.py#L121 186 | llm = MambaModel( 187 | name_or_path=args.model_name_or_path, 188 | repetition_penalty=1, 189 | temperature=args.temperature, 190 | top_k=args.top_k, 191 | top_p=args.top_p, 192 | stop=args.stop_words, 193 | max_new_tokens=tokens_to_generate, 194 | ) 195 | 196 | else: 197 | raise RuntimeError(f'Unsupported server type {args.server_type}') 198 | 199 | return llm 200 | 201 | 202 | def main(): 203 | start_time = time.time() 204 | 205 | curr_folder = os.path.dirname(os.path.abspath(__file__)) 206 | 207 | try: 208 | sys.path.append(os.path.dirname(curr_folder)) 209 | module = importlib.import_module(f"data.{args.benchmark}.constants") 210 | except ImportError: 211 | print(f"Module data.{args.benchmark}.constants not found.") 212 | 213 | tasks_base = module.TASKS 214 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f: 215 | tasks_customized = yaml.safe_load(f) 216 | 217 | if args.task not in tasks_customized: 218 | raise ValueError(f'{args.task} is not found in config_tasks.yaml') 219 | 220 | config = tasks_customized.get(args.task) 221 | config.update(tasks_base[config['task']]) 222 | 223 | task_file = args.data_dir / args.task / f'{args.subset}.jsonl' 224 | 225 | if args.chunk_amount > 1: 226 | pred_file = args.save_dir / f'{args.task}-{args.chunk_idx}.jsonl' 227 | else: 228 | pred_file = 
args.save_dir / f'{args.task}.jsonl' 229 | 230 | print(f'Predict {args.task} \nfrom {task_file}\nto {pred_file}') 231 | pred_file.parent.mkdir(parents=True, exist_ok=True) 232 | 233 | # Load data 234 | if os.path.exists(pred_file): 235 | pred_index = [sample['index'] for sample in read_manifest(pred_file)] 236 | data = [sample for sample in read_manifest(task_file) if sample['index'] not in pred_index] 237 | else: 238 | data = read_manifest(task_file) 239 | 240 | # Load api 241 | llm = get_llm(config['tokens_to_generate']) 242 | 243 | def get_output(idx_list, index_list, input_list, outputs_list, others_list, truncation_list, length_list): 244 | nonlocal llm 245 | 246 | while True: 247 | try: 248 | pred_list = llm.process_batch(prompts=input_list) 249 | break 250 | except Exception as e: 251 | traceback.print_exc() 252 | 253 | zipped_iter = zip(pred_list, idx_list, index_list, input_list, 254 | outputs_list, others_list, truncation_list, length_list) 255 | 256 | for pred, idx, index, input, outputs, others, truncation, length in zipped_iter: 257 | if isinstance(pred['text'], str): 258 | pred_text = pred['text'] 259 | elif len(pred['text']) > 0: 260 | pred_text = pred['text'][0] 261 | else: 262 | pred_text = '' 263 | 264 | outputs_parallel[idx] = { 265 | 'index': index, 266 | 'pred': pred_text, 267 | 'input': input, 268 | 'outputs': outputs, 269 | 'others': others, 270 | 'truncation': truncation, 271 | 'length': length, 272 | } 273 | 274 | threads = [] 275 | outputs_parallel = [{} for _ in range(len(data))] 276 | 277 | batched_data = [] 278 | batch = [] 279 | for idx, data_point in enumerate(data): 280 | data_point['idx'] = idx 281 | 282 | if len(batch) >= args.batch_size: 283 | batched_data.append(batch) 284 | batch = [] 285 | 286 | batch.append(data_point) 287 | 288 | if len(batch): 289 | batched_data.append(batch) 290 | 291 | # setting buffering=1 to force to dump the output after every line, so that we can see intermediate generations 292 | with open(pred_file, 'at', encoding="utf-8", buffering=1) as fout: 293 | # the data is processed sequentially, so we can store the start and end of current processing window 294 | start_idx = 0 # window: [start_idx, end_idx] 295 | 296 | for batch_idx, batch in tqdm(enumerate(batched_data), total=len(batched_data)): 297 | idx_list = [data_point['idx'] for data_point in batch] 298 | end_idx = idx_list[-1] # the data in a batch is ordered 299 | 300 | thread = threading.Thread( 301 | target=get_output, 302 | kwargs=dict( 303 | idx_list=idx_list, 304 | index_list=[data_point['index'] for data_point in batch], 305 | input_list=[data_point['input'] for data_point in batch], 306 | outputs_list=[data_point['outputs'] for data_point in batch], 307 | others_list=[data_point.get('others', {}) for data_point in batch], 308 | truncation_list=[data_point.get('truncation', -1) for data_point in batch], 309 | length_list=[data_point.get('length', -1) for data_point in batch], 310 | ), 311 | ) 312 | thread.start() 313 | threads.append(thread) 314 | 315 | is_last_batch = (batch_idx == len(batched_data) - 1) 316 | 317 | if (len(threads) == args.threads) or is_last_batch: 318 | for thread in threads: 319 | thread.join() 320 | threads = [] 321 | 322 | # dump the results in current processing window on disk 323 | for idx in range(start_idx, end_idx + 1): 324 | if len(outputs_parallel[idx]) > 0: 325 | fout.write(json.dumps(outputs_parallel[idx]) + '\n') 326 | 327 | start_idx = end_idx + 1 328 | 329 | print(f"Used time: {round((time.time() - start_time) / 60, 1)} minutes") 
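    # Note on the batching above: each batch is handed to its own thread; once args.threads
    # threads are in flight (or the final batch is reached) they are joined, and the completed
    # window [start_idx, end_idx] is flushed to pred_file in index order. Together with the
    # pred_index filtering when pred_file already exists, this allows interrupted runs to resume.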
330 | 331 | 332 | if __name__ == '__main__': 333 | main() 334 | -------------------------------------------------------------------------------- /scripts/pred/client_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import abc 17 | import json 18 | import multiprocessing 19 | import os 20 | import re 21 | import sys 22 | import time 23 | import requests 24 | import traceback 25 | from pathlib import Path 26 | from typing import List, Tuple, Union 27 | from concurrent.futures import ThreadPoolExecutor 28 | from collections import defaultdict 29 | from tenacity import ( 30 | retry, 31 | stop_after_attempt, 32 | wait_random_exponential, 33 | ) 34 | 35 | 36 | class Client(abc.ABC): 37 | def __init__( 38 | self, 39 | server_host, 40 | server_port='5000', 41 | ssh_server=None, 42 | ssh_key_path=None, 43 | **generation_kwargs 44 | ): 45 | self.server_host = server_host 46 | self.server_port = server_port 47 | self.ssh_server = os.getenv("SSH_SERVER", ssh_server) 48 | self.ssh_key_path = os.getenv("SSH_KEY_PATH", ssh_key_path) 49 | self.generation_kwargs = generation_kwargs 50 | 51 | @abc.abstractmethod 52 | def _single_call( 53 | self, 54 | prompts, 55 | ): 56 | pass 57 | 58 | def __call__( 59 | self, 60 | prompt: str, 61 | **kwargs 62 | ): 63 | request = self.generation_kwargs 64 | # prompts are added later 65 | request['prompts'] = [f'{prompt}'] 66 | if 'others' in kwargs: 67 | request['others'] = kwargs['others'] 68 | 69 | outputs = self._single_call(**request) 70 | response = {'text': outputs} 71 | return response 72 | 73 | @retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(3)) 74 | def _send_request(self, request, route="generate"): 75 | if self.ssh_server and self.ssh_key_path: 76 | import sshtunnel_requests 77 | 78 | sshtunnel_request = sshtunnel_requests.from_url(f"ssh://{self.ssh_server}:22", self.ssh_key_path) 79 | outputs = sshtunnel_request.put( 80 | url="http://{}:{}/{}".format(self.server_host, self.server_port, route), 81 | data=json.dumps(request), 82 | headers={"Content-Type": "application/json"}, 83 | ).json() 84 | else: 85 | outputs = requests.put( 86 | url="http://{}:{}/{}".format(self.server_host, self.server_port, route), 87 | data=json.dumps(request), 88 | headers={"Content-Type": "application/json"}, 89 | ).json() 90 | return outputs 91 | 92 | def process_batch(self, prompts: List[str], **kwargs) -> List[dict]: 93 | num_threads = max(96, multiprocessing.cpu_count() * 16) 94 | with ThreadPoolExecutor(num_threads) as executor: 95 | futures = [] 96 | for prompt in prompts: 97 | futures.append( 98 | executor.submit( 99 | self.__call__, 100 | prompt, 101 | **kwargs, 102 | ) 103 | ) 104 | rets = [f.result() for f in futures] 105 | return rets 106 | 107 | 108 | class TRTLLMClient(Client): 109 | def _single_call( 110 | self, 111 | prompts, 112 | 
tokens_to_generate, 113 | temperature, 114 | top_p, 115 | top_k, 116 | random_seed, 117 | stop: List[str], 118 | max_attention_window_size=None, 119 | ): 120 | request = { 121 | "prompts": prompts, 122 | "tokens_to_generate": tokens_to_generate, 123 | "temperature": temperature, 124 | "top_k": top_k, 125 | "top_p": top_p, 126 | "random_seed": random_seed, 127 | 'stop_words_list': ",".join(stop), 128 | } 129 | if max_attention_window_size: 130 | request["max_attention_window_size"] = max_attention_window_size 131 | 132 | outputs = self._send_request(request) 133 | return outputs 134 | 135 | 136 | class VLLMClient(Client): 137 | def _single_call( 138 | self, 139 | prompts, 140 | tokens_to_generate, 141 | temperature, 142 | top_p, 143 | top_k, 144 | random_seed, 145 | stop: List[str], 146 | ): 147 | request = { 148 | "prompt": prompts[0], 149 | "max_tokens": tokens_to_generate, 150 | "temperature": temperature, 151 | "top_k": top_k, 152 | "top_p": top_p, 153 | "stop": stop, 154 | } 155 | # TODO: random seed is not supported? 156 | outputs = self._send_request(request) 157 | outputs = outputs['text'] 158 | return outputs 159 | 160 | 161 | class SGLClient(Client): 162 | def _single_call( 163 | self, 164 | prompts, 165 | tokens_to_generate, 166 | temperature, 167 | top_p, 168 | top_k, 169 | random_seed, 170 | stop: List[str], 171 | ): 172 | request = { 173 | "text": prompts[0], 174 | "sampling_params": { 175 | "max_new_tokens": tokens_to_generate, 176 | "temperature": temperature, 177 | "top_k": top_k, 178 | "top_p": top_p, 179 | "stop": stop, 180 | } 181 | } 182 | # TODO: random seed is not supported? 183 | outputs = self._send_request(request) 184 | outputs = outputs['text'] 185 | return outputs 186 | 187 | 188 | class OpenAIClient: 189 | def __init__( 190 | self, 191 | model_name, 192 | **generation_kwargs 193 | ): 194 | model2length = { 195 | # OpenAI 196 | 'gpt-4': 8192, 197 | 'gpt-4-0613': 8192, 198 | 'gpt-4-1106-preview': 128000, 199 | 'gpt-4-0125-preview': 128000, 200 | 'gpt-4-turbo-preview': 128000, 201 | 'gpt-3.5-turbo-0125': 16385, 202 | 'gpt-3.5-turbo-1106': 16385, 203 | 'gpt-3.5-turbo-0613': 4096, 204 | 'gpt-3.5-turbo': 16385, 205 | 'gpt-3.5-turbo-16k': 16385, 206 | 'gpt-3.5-turbo-16k-0613': 16385, 207 | 208 | # Azure 209 | 'gpt-4-32k': 32768, 210 | 'gpt-4': 128000, 211 | 'gpt-35-turbo-16k': 16384, 212 | } 213 | self.openai_api_key = os.environ["OPENAI_API_KEY"] 214 | self.azure_api_id = os.environ["AZURE_API_ID"] 215 | self.azure_api_secret = os.environ["AZURE_API_SECRET"] 216 | self.azure_api_endpoint = os.environ["AZURE_API_ENDPOINT"] 217 | self.model_name = model_name 218 | 219 | # Azure 220 | if self.azure_api_id and self.azure_api_secret: 221 | if 'gpt-3.5' in model_name: self.model_name = 'gpt-35-turbo-16k' 222 | if 'gpt-4' in model_name: self.model_name = 'gpt-4' 223 | 224 | import tiktoken 225 | self.encoding = tiktoken.get_encoding("cl100k_base") 226 | self.max_length = model2length[self.model_name] 227 | self.generation_kwargs = generation_kwargs 228 | self._create_client() 229 | 230 | def _create_client(self,): 231 | from openai import OpenAI, AzureOpenAI 232 | 233 | # OpenAI 234 | if self.openai_api_key: 235 | self.client = OpenAI( 236 | api_key=self.openai_api_key 237 | ) 238 | 239 | # Azure 240 | elif self.azure_api_id and self.azure_api_secret: 241 | self.client = AzureOpenAI( 242 | api_key=self.get_azure_api_key( 243 | self.azure_api_id, 244 | self.azure_api_secret, 245 | self.azure_api_endpoint, 246 | ), 247 | api_version="2024-02-15-preview", 248 | 
azure_endpoint=os.path.join(self.azure_api_endpoint, "llm/v1/azure"), 249 | ) 250 | 251 | def _count_tokens(self, messages): 252 | tokens_per_message = 3 253 | tokens_per_name = 1 254 | num_tokens = 0 255 | for message in messages: 256 | num_tokens += tokens_per_message 257 | for key, value in message.items(): 258 | num_tokens += len(self.encoding.encode(value)) 259 | if key == "name": 260 | num_tokens += tokens_per_name 261 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 262 | return num_tokens 263 | 264 | @retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(3)) 265 | def _send_request(self, request): 266 | try: 267 | response = self.client.chat.completions.create( 268 | model=self.model_name, 269 | messages=request['msgs'], 270 | max_tokens=request['tokens_to_generate'], 271 | temperature=request['temperature'], 272 | seed=request['random_seed'], 273 | top_p=request['top_p'], 274 | stop=request['stop'], 275 | ) 276 | except Exception as e: 277 | print(f"Error occurred while calling OpenAI: {e}") 278 | if self.azure_api_id and self.azure_api_secret and e.status_code == 401: 279 | # token expired 280 | self._create_client() 281 | 282 | return response 283 | 284 | def __call__( 285 | self, 286 | prompt: str, 287 | ): 288 | # system_msg = [{"role": "system", "content": ""}] 289 | system_msg = [] 290 | user_assistant_msgs = [{"role": "user", "content": prompt}] 291 | msgs = system_msg + user_assistant_msgs 292 | openai_length = self._count_tokens(msgs) 293 | request = self.generation_kwargs 294 | 295 | tokens_to_generate_new = self.max_length - openai_length 296 | if tokens_to_generate_new < request['tokens_to_generate']: 297 | print(f"Reduce generate tokens from {request['tokens_to_generate']} to {tokens_to_generate_new}") 298 | request['tokens_to_generate'] = tokens_to_generate_new 299 | 300 | request["msgs"] = msgs 301 | outputs = self._send_request(request) 302 | response = {'text': [outputs.choices[0].message.content]} 303 | return response 304 | 305 | 306 | def get_azure_api_key( 307 | self, 308 | p_client_id, 309 | p_client_secret, 310 | p_token_url, 311 | p_scope="azureopenai-readwrite", 312 | cache_file="azure_openai_key.json" 313 | ): 314 | base_path = Path(__file__).parent 315 | file_path = Path.joinpath(base_path, cache_file) 316 | 317 | # Check if the token is cached 318 | renew = True 319 | if os.path.exists(file_path): 320 | with open(file_path, "r") as f: 321 | token = json.load(f) 322 | renew = True if time.time() > token["expires_in"] else False 323 | 324 | if renew: 325 | # Get a new token from the OAuth server 326 | response = requests.post( 327 | os.path.join(p_token_url, "oauth/api/v1/ssa/default/token"), 328 | data={"grant_type": "client_credentials", "client_id": p_client_id, 329 | "client_secret": p_client_secret, "scope": p_scope} 330 | ) 331 | response.raise_for_status() 332 | token = response.json() 333 | token["expires_in"] += time.time() 334 | with open(file_path, "w") as f: 335 | json.dump(token, f) 336 | 337 | 338 | authToken = token["access_token"] 339 | return authToken 340 | 341 | 342 | class GeminiClient: 343 | def __init__( 344 | self, 345 | model_name, 346 | **generation_kwargs 347 | ): 348 | model2length = { 349 | 'gemini-1.0-pro-latest': (30720, 2048), 350 | 'gemini-1.5-pro-latest': (1048576, 8192) 351 | } 352 | 353 | self.model_name = model_name 354 | self.model = self._initialize_model() 355 | self.max_input_length = model2length[model_name][0] 356 | self.max_output_length = 
model2length[model_name][1] 357 | assert generation_kwargs['tokens_to_generate'] < self.max_output_length, \ 358 | print(f'tokens_to_generate exceeds {self.max_output_length}') 359 | 360 | import google.generativeai as genai 361 | self.config = genai.GenerationConfig( 362 | candidate_count=1, 363 | stop_sequences=generation_kwargs['stop'], 364 | max_output_tokens=generation_kwargs['tokens_to_generate'], 365 | temperature=generation_kwargs['temperature'], 366 | top_p=generation_kwargs['top_p'], 367 | top_k=generation_kwargs['top_k'], 368 | ) 369 | 370 | from google.generativeai.types import HarmCategory, HarmBlockThreshold 371 | self.safety_settings = { 372 | HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, 373 | HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, 374 | HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, 375 | HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, 376 | } 377 | 378 | @retry(wait=wait_random_exponential(min=60, max=60), stop=stop_after_attempt(3)) 379 | def _send_request(self, request): 380 | try: 381 | response = self.model.generate_content(request['prompt'], 382 | generation_config=request['config'], 383 | safety_settings=self.safety_settings) 384 | except Exception as e: 385 | traceback.print_exc() 386 | return None 387 | return response 388 | 389 | def __call__( 390 | self, 391 | prompt: str, 392 | ): 393 | assert self.model.count_tokens(prompt).total_tokens < self.max_input_length, \ 394 | print(f'input length exceeds {self.max_input_length}') 395 | 396 | request = { 397 | 'prompt': prompt, 398 | 'config': self.config, 399 | } 400 | 401 | outputs = self._send_request(request) 402 | 403 | try: 404 | response = {'text': [outputs.candidates[0].content.parts[0].text]} 405 | except Exception as e: 406 | response = {'text': []} 407 | print(outputs) 408 | traceback.print_exc() 409 | 410 | return response 411 | 412 | def _initialize_model(self): 413 | import google.generativeai as genai 414 | genai.configure(api_key=os.environ["GEMINI_API_KEY"]) 415 | return genai.GenerativeModel(self.model_name) 416 | 417 | -------------------------------------------------------------------------------- /scripts/pred/model_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
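# The two wrappers below mirror the call surface of the HTTP clients in client_wrappers.py:
# HuggingFaceModel runs generation locally through a transformers pipeline (falling back to
# AutoModelForCausalLM.generate if the pipeline cannot be built), and MambaModel wraps
# mamba_ssm's MambaLMHeadModel. Both __call__ and process_batch return predictions in the
# form {'text': [generated_str]}.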
14 | 15 | import json 16 | import logging 17 | import requests 18 | import torch 19 | from typing import Dict, List, Optional 20 | 21 | 22 | class HuggingFaceModel: 23 | def __init__(self, name_or_path: str, **generation_kwargs) -> None: 24 | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline 25 | 26 | self.tokenizer = AutoTokenizer.from_pretrained(name_or_path, trust_remote_code=True) 27 | 28 | if 'Yarn-Llama' in name_or_path: 29 | model_kwargs = None 30 | else: 31 | model_kwargs = {"attn_implementation": "flash_attention_2"} 32 | 33 | try: 34 | self.pipeline = pipeline( 35 | "text-generation", 36 | model=name_or_path, 37 | tokenizer=self.tokenizer, 38 | trust_remote_code=True, 39 | device_map="auto", 40 | torch_dtype=torch.bfloat16, 41 | model_kwargs=model_kwargs, 42 | ) 43 | except: 44 | self.pipeline = None 45 | self.model = AutoModelForCausalLM.from_pretrained(name_or_path, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16,) 46 | 47 | self.generation_kwargs = generation_kwargs 48 | self.stop = self.generation_kwargs.pop('stop') 49 | 50 | if self.tokenizer.pad_token is None: 51 | # add pad token to allow batching (known issue for llama2) 52 | self.tokenizer.padding_side = 'left' 53 | self.tokenizer.pad_token = self.tokenizer.eos_token 54 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 55 | 56 | 57 | def __call__(self, prompt: str, **kwargs) -> dict: 58 | return self.process_batch([prompt], **kwargs)[0] 59 | 60 | def process_batch(self, prompts: List[str], **kwargs) -> List[dict]: 61 | if self.pipeline is None: 62 | inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.model.device) 63 | generated_ids = self.model.generate( 64 | **inputs, 65 | **self.generation_kwargs 66 | ) 67 | generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 68 | else: 69 | output = self.pipeline(text_inputs=prompts, **self.generation_kwargs, ) 70 | assert len(output) == len(prompts) 71 | # output in the form of a list of list of dictionaries 72 | # outer list len = batch size 73 | # inner list len = 1 74 | generated_texts = [llm_result[0]["generated_text"] for llm_result in output] 75 | 76 | results = [] 77 | 78 | for text, prompt in zip(generated_texts, prompts): 79 | # remove the input form the generated text 80 | # This is a workaround for the llama3 tokenizer not being able to reproduce the same prompt after tokenization 81 | # see Issue https://github.com/NVIDIA/RULER/issues/54 for explaination 82 | if self.pipeline is None: 83 | tokenized_prompt = self.tokenizer(prompt, return_tensors="pt", padding=True) 84 | prompt = self.tokenizer.decode(tokenized_prompt.input_ids[0], skip_special_tokens=True) 85 | if text.startswith(prompt): 86 | text = text[len(prompt):] 87 | 88 | if self.stop is not None: 89 | for s in self.stop: 90 | text = text.split(s)[0] 91 | 92 | results.append({'text': [text]}) 93 | 94 | return results 95 | 96 | 97 | class MambaModel: 98 | def __init__(self, name_or_path: str, **generation_kwargs) -> None: 99 | from transformers import AutoTokenizer 100 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 101 | 102 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 103 | self.device = "cuda" 104 | self.model = MambaLMHeadModel.from_pretrained(name_or_path, device=self.device, dtype=torch.bfloat16) 105 | self.generation_kwargs = generation_kwargs 106 | self.stop = self.generation_kwargs.pop('stop') 107 | self.max_genlen = 
self.generation_kwargs.pop('max_new_tokens') 108 | self.minp = 0.0 109 | 110 | def __call__(self, prompt: str, **kwargs) -> Dict[str, List[str]]: 111 | # tokenize 112 | tokens = self.tokenizer(prompt, return_tensors="pt") 113 | input_ids = tokens.input_ids.to(self.device) 114 | max_length = input_ids.shape[1] + self.max_genlen 115 | 116 | # generate 117 | out = self.model.generate( 118 | input_ids=input_ids, 119 | max_length=max_length, 120 | cg=True, 121 | return_dict_in_generate=True, 122 | output_scores=True, 123 | enable_timing=False, 124 | **self.generation_kwargs, 125 | ) 126 | assert len(out.sequences) == 1 127 | # detok 128 | return {'text': [self.tokenizer.decode(out.sequences[0][input_ids.shape[1]:])]} 129 | 130 | def process_batch(self, prompts: List[str], **kwargs) -> List[dict]: 131 | # FIXME: naive implementation 132 | return [self.__call__(prompt, **kwargs) for prompt in prompts] 133 | -------------------------------------------------------------------------------- /scripts/pred/serve_trt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # adapted from https://github.com/Kipok/NeMo-Skills/blob/v0.1/nemo_skills/inference/server/serve_trt.py 16 | 17 | import json 18 | import logging 19 | import sys 20 | import json 21 | from pathlib import Path 22 | from argparse import ArgumentParser 23 | 24 | import numpy as np 25 | import torch 26 | import tensorrt_llm 27 | from flask import Flask, jsonify, request 28 | from flask_restful import Api, Resource 29 | from tensorrt_llm.runtime import ModelRunnerCpp 30 | from mpi4py import MPI 31 | from transformers import AutoTokenizer 32 | 33 | 34 | class TritonServerGenerate(Resource): 35 | def __init__(self, model): 36 | self.model = model 37 | self.comm = MPI.COMM_WORLD 38 | 39 | def generate( 40 | self, 41 | prompts, 42 | max_new_tokens, 43 | temperature, 44 | top_k, 45 | top_p, 46 | repetition_penalty, 47 | random_seed, 48 | stop_words_list, 49 | max_attention_window_size=None 50 | ): 51 | output = self.model.forward( 52 | prompts, 53 | max_output_token=max_new_tokens, 54 | top_k=top_k, 55 | top_p=top_p, 56 | temperature=temperature, 57 | repetition_penalty=repetition_penalty, 58 | random_seed=random_seed, 59 | stop_words_list=stop_words_list, 60 | max_attention_window_size=max_attention_window_size, 61 | ) 62 | return output 63 | 64 | def put(self): 65 | logging.info("request IP: " + str(request.remote_addr)) 66 | logging.info(json.dumps(request.get_json())) 67 | 68 | input_request = request.get_json() 69 | 70 | tokens_to_generate = input_request.get("tokens_to_generate", 64) 71 | temperature = input_request.get("temperature", 1.0) 72 | top_k = input_request.get("top_k", 0) 73 | top_p = input_request.get("top_p", 1.0) 74 | repetition_penalty = input_request.get("repetition_penalty", 1.2) 75 | stop_words_list = 
input_request.get("stop_words_list") 76 | max_attention_window_size = input_request.get("max_attention_window_size") 77 | random_seed = input_request.get("random_seed", 0) 78 | prompts = input_request["prompts"] 79 | 80 | data = dict( 81 | prompts=prompts, 82 | max_new_tokens=tokens_to_generate, 83 | temperature=temperature, 84 | top_k=top_k, 85 | top_p=top_p, 86 | repetition_penalty=repetition_penalty, 87 | random_seed=random_seed, 88 | stop_words_list=stop_words_list, 89 | max_attention_window_size=max_attention_window_size, 90 | ) 91 | self.comm.Barrier() 92 | data = self.comm.bcast(data, root=0) 93 | 94 | out = self.generate(**data) 95 | return jsonify(out) 96 | 97 | 98 | def parse_input(input_texts: str, tokenizer): 99 | batch_input_ids = [ 100 | tokenizer.encode( 101 | input_text, 102 | add_special_tokens=False, # TODO: does this need to be true? 103 | ) 104 | for input_text in input_texts 105 | ] 106 | batch_input_ids = [torch.tensor(x, dtype=torch.int32, device="cuda") for x in batch_input_ids] 107 | input_lengths = [x.size(0) for x in batch_input_ids] 108 | 109 | return batch_input_ids, input_lengths 110 | 111 | 112 | def get_output(output_ids, input_lengths, max_output_len, tokenizer, eos_token): 113 | num_beams = output_ids.size(1) 114 | assert num_beams == 1 115 | output_texts = [] 116 | for idx, input_len in enumerate(input_lengths): 117 | output_begin = input_len 118 | output_end = input_len + max_output_len 119 | outputs = output_ids[idx][0][output_begin:output_end] 120 | eos_ids = (outputs == eos_token).nonzero(as_tuple=True)[-1] 121 | if len(eos_ids) > 0: 122 | outputs = outputs[: eos_ids[0]] 123 | outputs = outputs.tolist() 124 | output_texts.append(tokenizer.decode(outputs)) 125 | return output_texts 126 | 127 | 128 | def prepare_stop_words(stop_words_list, tokenizer): 129 | # adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/b310ec675145c9ee7668592549f733df4abf1e94/tensorrt_llm/runtime/generation.py#L46 130 | flat_ids = [] 131 | offsets = [] 132 | for batch_stop_words in stop_words_list: 133 | item_flat_ids = [] 134 | item_offsets = [] 135 | 136 | for word in batch_stop_words: 137 | # there is a known issue in TensorRT-LLM that word ids are not unique and might change depending on 138 | # where in the text it appears. In our case we mainly need to stop on ids as they appear in the middle 139 | # of the text. The following is a workaround to get such ids that works for both kind of stop 140 | # words as well as newlines that we commonly use. But note that it's not a universal fix, so this might 141 | # require refactoring if different stop words are used in the future. 142 | # Eventually, this needs to be fixed inside TensorRT-LLM itself. 
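            # Illustrative example (tokenizer-dependent, values assumed): encoding 'magic' + '\n'
            # might yield [ids of 'magic'] + [ids of '\n']; dropping the leading ids keeps only the
            # stop-word ids as they would appear mid-text. The slice below assumes 'magic' itself
            # always encodes to exactly two ids.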
143 | ids = tokenizer.encode('magic' + word) 144 | ids = ids[2:] # skipping "magic" 145 | 146 | if len(ids) == 0: 147 | continue 148 | 149 | item_flat_ids += ids 150 | item_offsets.append(len(ids)) 151 | 152 | flat_ids.append(np.array(item_flat_ids)) 153 | offsets.append(np.cumsum(np.array(item_offsets))) 154 | 155 | pad_to = max(1, max(len(ids) for ids in flat_ids)) 156 | 157 | for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): 158 | flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) 159 | offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) 160 | 161 | stop_words = np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) 162 | return torch.Tensor(stop_words).to(torch.int32).to("cuda").contiguous() 163 | 164 | 165 | def load_tokenizer(tokenizer_dir: str): 166 | tokenizer = AutoTokenizer.from_pretrained( 167 | tokenizer_dir, 168 | legacy=False, 169 | trust_remote_code=True, 170 | ) 171 | 172 | if tokenizer.pad_token_id is None: 173 | tokenizer.pad_token_id = tokenizer.eos_token_id 174 | pad_id = tokenizer.pad_token_id 175 | end_id = tokenizer.eos_token_id 176 | 177 | return tokenizer, pad_id, end_id 178 | 179 | 180 | 181 | class TensorRTLLM: 182 | def __init__(self, model_path: str): 183 | self.tokenizer, self.pad_id, self.end_id = load_tokenizer(tokenizer_dir=model_path) 184 | self.runner = ModelRunnerCpp.from_dir(engine_dir=model_path, rank=tensorrt_llm.mpi_rank()) 185 | 186 | @torch.no_grad() 187 | def forward( 188 | self, 189 | input_texts, 190 | max_output_token, 191 | top_k, 192 | top_p, 193 | temperature, 194 | repetition_penalty, 195 | random_seed, 196 | stop_words_list, 197 | max_attention_window_size, 198 | ): 199 | batch_input_ids, input_lengths = parse_input(input_texts, self.tokenizer) 200 | 201 | stop_words_list = [stop_words_list for _ in range(len(input_texts))] 202 | stop_words_list = prepare_stop_words(stop_words_list, self.tokenizer) 203 | 204 | # TODO: return dictionary with a proper error reporting 205 | try: 206 | output_ids = self.runner.generate( 207 | batch_input_ids, 208 | max_new_tokens=max_output_token, 209 | end_id=self.end_id, 210 | pad_id=self.pad_id, 211 | temperature=temperature, 212 | top_k=top_k, 213 | top_p=top_p, 214 | repetition_penalty=repetition_penalty, 215 | random_seed=random_seed, 216 | stop_words_list=stop_words_list, 217 | max_attention_window_size=max_attention_window_size, 218 | return_dict=False, 219 | ) 220 | torch.cuda.synchronize() 221 | 222 | output = get_output(output_ids, input_lengths, max_output_token, self.tokenizer, self.end_id) 223 | except RuntimeError as e: 224 | logging.error("RuntimeError: %s", e) 225 | output = [f"RuntimeError: {e}"] * len(input_texts) 226 | 227 | return output 228 | 229 | 230 | class WrapperServer: 231 | def __init__(self, model_path: str): 232 | self.comm = MPI.COMM_WORLD 233 | self.rank = self.comm.Get_rank() 234 | 235 | self.model = TensorRTLLM(model_path=model_path) 236 | 237 | if self.rank == 0: 238 | self.app = Flask(__file__, static_url_path="") 239 | api = Api(self.app) 240 | api.add_resource(TritonServerGenerate, "/generate", resource_class_args=[self.model]) 241 | 242 | def run(self, url, port=5000): 243 | if self.rank == 0: 244 | self.app.run(url, threaded=True, port=port, debug=False) 245 | else: 246 | self.worker_loop() 247 | 248 | def worker_loop(self): 249 | triton = TritonServerGenerate(self.model) 250 | while True: 251 | self.comm.Barrier() 252 | data = None 253 | data = self.comm.bcast(data, root=0) 254 | triton.generate(**data) 255 | 256 | 
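# For reference, a /generate request is a JSON object sent via HTTP PUT; the field names are the
# ones read in TritonServerGenerate.put above, and the values here are purely illustrative:
#
#   {
#     "prompts": ["Some prompt text"],
#     "tokens_to_generate": 64,
#     "temperature": 1.0,
#     "top_k": 0,
#     "top_p": 1.0,
#     "repetition_penalty": 1.2,
#     "random_seed": 0,
#     "stop_words_list": ".",
#     "max_attention_window_size": 4096
#   }
# (stop_words_list arrives as a comma-separated string, as produced by TRTLLMClient.)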
257 | if __name__ == "__main__": 258 | # TODO: can we reuse normal logger here? 259 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 260 | 261 | parser = ArgumentParser() 262 | parser.add_argument("--model_path", required=True) 263 | parser.add_argument("--host", type=str, default="0.0.0.0") 264 | parser.add_argument("--port", type=int, default=5000) 265 | args = parser.parse_args() 266 | 267 | server = WrapperServer(model_path=args.model_path) 268 | server.run(args.host, args.port) 269 | -------------------------------------------------------------------------------- /scripts/pred/serve_vllm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # adapted from https://github.com/vllm-project/vllm/blob/v0.4.0/vllm/entrypoints/api_server.py 16 | 17 | 18 | import json 19 | from typing import AsyncGenerator 20 | 21 | from fastapi import FastAPI, Request 22 | from fastapi.responses import JSONResponse, Response, StreamingResponse 23 | import uvicorn 24 | 25 | from vllm.engine.arg_utils import AsyncEngineArgs 26 | from vllm.engine.async_llm_engine import AsyncLLMEngine 27 | from vllm.sampling_params import SamplingParams 28 | from vllm.utils import random_uuid 29 | from vllm.utils import FlexibleArgumentParser 30 | 31 | TIMEOUT_KEEP_ALIVE = 5 # seconds. 32 | app = FastAPI() 33 | engine = None 34 | 35 | 36 | @app.get("/health") 37 | async def health() -> Response: 38 | """Health check.""" 39 | return Response(status_code=200) 40 | 41 | 42 | @app.put("/generate") 43 | async def generate(request: Request) -> Response: 44 | """Generate completion for the request. 45 | 46 | The request should be a JSON object with the following fields: 47 | - prompt: the prompt to use for the generation. 48 | - stream: whether to stream the results or not. 49 | - other fields: the sampling parameters (See `SamplingParams` for details). 50 | """ 51 | request_dict = await request.json() 52 | prompt = request_dict.pop("prompt") 53 | stream = request_dict.pop("stream", False) 54 | sampling_params = SamplingParams(**request_dict) 55 | request_id = random_uuid() 56 | 57 | results_generator = engine.generate(prompt, 58 | sampling_params, 59 | request_id) 60 | 61 | # Streaming case 62 | async def stream_results() -> AsyncGenerator[bytes, None]: 63 | async for request_output in results_generator: 64 | prompt = request_output.prompt 65 | text_outputs = [ 66 | prompt + output.text for output in request_output.outputs 67 | ] 68 | ret = {"text": text_outputs} 69 | yield (json.dumps(ret) + "\0").encode("utf-8") 70 | 71 | if stream: 72 | return StreamingResponse(stream_results()) 73 | 74 | # Non-streaming case 75 | final_output = None 76 | async for request_output in results_generator: 77 | if await request.is_disconnected(): 78 | # Abort the request if the client disconnects. 
79 | await engine.abort(request_id) 80 | return Response(status_code=499) 81 | final_output = request_output 82 | assert final_output is not None 83 | text_outputs = [output.text for output in final_output.outputs] 84 | ret = {"text": text_outputs} 85 | return JSONResponse(ret) 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = FlexibleArgumentParser( 90 | description="vLLM API server for serving LLMs asynchronously", 91 | epilog="For more information, see https://docs.vllm.ai/en/stable/configuration/engine_args.html", 92 | ) 93 | parser.add_argument("--host", type=str, default="0.0.0.0") 94 | parser.add_argument("--port", type=int, default=5000) 95 | parser.add_argument("--ssl-keyfile", type=str, default=None) 96 | parser.add_argument("--ssl-certfile", type=str, default=None) 97 | parser.add_argument( 98 | "--root-path", 99 | type=str, 100 | default=None, 101 | help="FastAPI root_path when app is behind a path based routing proxy") 102 | parser = AsyncEngineArgs.add_cli_args(parser) 103 | args = parser.parse_args() 104 | 105 | engine_args = AsyncEngineArgs.from_cli_args(args) 106 | engine = AsyncLLMEngine.from_engine_args(engine_args) 107 | 108 | app.root_path = args.root_path 109 | uvicorn.run(app, 110 | host=args.host, 111 | port=args.port, 112 | log_level="debug", 113 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE, 114 | ssl_keyfile=args.ssl_keyfile, 115 | ssl_certfile=args.ssl_certfile) -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # container: docker.io/cphsieh/ruler:0.1.0 17 | # bash run.sh MODEL_NAME BENCHMARK_NAME 18 | 19 | if [ $# -ne 2 ]; then 20 | echo "Usage: $0 <model_name> <benchmark_name>" 21 | exit 1 22 | fi 23 | 24 | 25 | # Root Directories 26 | GPUS="1" # number of GPUs to use for tensor parallelism. 27 | ROOT_DIR="benchmark_root" # the path that stores generated task samples and model predictions. 28 | MODEL_DIR="../.." # the path that contains individual model folders from Huggingface. 29 | ENGINE_DIR="." # the path that contains individual engine folders from TensorRT-LLM.
30 | BATCH_SIZE=1 # increase to improve GPU utilization. 31 | 32 | 33 | # Model and Tokenizer 34 | source config_models.sh 35 | MODEL_NAME=${1} 36 | MODEL_CONFIG=$(MODEL_SELECT ${MODEL_NAME} ${MODEL_DIR} ${ENGINE_DIR}) 37 | IFS=":" read MODEL_PATH MODEL_TEMPLATE_TYPE MODEL_FRAMEWORK TOKENIZER_PATH TOKENIZER_TYPE OPENAI_API_KEY GEMINI_API_KEY AZURE_ID AZURE_SECRET AZURE_ENDPOINT <<< "$MODEL_CONFIG" 38 | if [ -z "${MODEL_PATH}" ]; then 39 | echo "Model: ${MODEL_NAME} is not supported" 40 | exit 1 41 | fi 42 | 43 | 44 | export OPENAI_API_KEY=${OPENAI_API_KEY} 45 | export GEMINI_API_KEY=${GEMINI_API_KEY} 46 | export AZURE_API_ID=${AZURE_ID} 47 | export AZURE_API_SECRET=${AZURE_SECRET} 48 | export AZURE_API_ENDPOINT=${AZURE_ENDPOINT} 49 | 50 | 51 | # Benchmark and Tasks 52 | source config_tasks.sh 53 | BENCHMARK=${2} 54 | declare -n TASKS=$BENCHMARK 55 | if [ -z "${TASKS}" ]; then 56 | echo "Benchmark: ${BENCHMARK} is not supported" 57 | exit 1 58 | fi 59 | 60 | 61 | # Start server (you may want to run it in a separate container). 62 | if [ "$MODEL_FRAMEWORK" == "vllm" ]; then 63 | python pred/serve_vllm.py \ 64 | --model=${MODEL_PATH} \ 65 | --tensor-parallel-size=${GPUS} \ 66 | --dtype bfloat16 \ 67 | --disable-custom-all-reduce \ 68 | & 69 | 70 | elif [ "$MODEL_FRAMEWORK" == "trtllm" ]; then 71 | python pred/serve_trt.py \ 72 | --model_path=${MODEL_PATH} \ 73 | & 74 | 75 | elif [ "$MODEL_FRAMEWORK" == "sglang" ]; then 76 | python -m sglang.launch_server \ 77 | --model-path ${MODEL_PATH} \ 78 | --tp ${GPUS} \ 79 | --port 5000 \ 80 | --enable-flashinfer \ 81 | & 82 | # use sglang/test/killall_sglang.sh to kill sglang server if it hangs 83 | 84 | fi 85 | 86 | 87 | # Start client (prepare data / call model API / obtain final metrics) 88 | total_time=0 89 | for MAX_SEQ_LENGTH in "${SEQ_LENGTHS[@]}"; do 90 | 91 | RESULTS_DIR="${ROOT_DIR}/${MODEL_NAME}/${BENCHMARK}/${MAX_SEQ_LENGTH}" 92 | DATA_DIR="${RESULTS_DIR}/data" 93 | PRED_DIR="${RESULTS_DIR}/pred" 94 | mkdir -p ${DATA_DIR} 95 | mkdir -p ${PRED_DIR} 96 | 97 | for TASK in "${TASKS[@]}"; do 98 | python data/prepare.py \ 99 | --save_dir ${DATA_DIR} \ 100 | --benchmark ${BENCHMARK} \ 101 | --task ${TASK} \ 102 | --tokenizer_path ${TOKENIZER_PATH} \ 103 | --tokenizer_type ${TOKENIZER_TYPE} \ 104 | --max_seq_length ${MAX_SEQ_LENGTH} \ 105 | --model_template_type ${MODEL_TEMPLATE_TYPE} \ 106 | --num_samples ${NUM_SAMPLES} \ 107 | ${REMOVE_NEWLINE_TAB} 108 | 109 | start_time=$(date +%s) 110 | python pred/call_api.py \ 111 | --data_dir ${DATA_DIR} \ 112 | --save_dir ${PRED_DIR} \ 113 | --benchmark ${BENCHMARK} \ 114 | --task ${TASK} \ 115 | --server_type ${MODEL_FRAMEWORK} \ 116 | --model_name_or_path ${MODEL_PATH} \ 117 | --temperature ${TEMPERATURE} \ 118 | --top_k ${TOP_K} \ 119 | --top_p ${TOP_P} \ 120 | --batch_size ${BATCH_SIZE} \ 121 | ${STOP_WORDS} 122 | end_time=$(date +%s) 123 | time_diff=$((end_time - start_time)) 124 | total_time=$((total_time + time_diff)) 125 | done 126 | 127 | python eval/evaluate.py \ 128 | --data_dir ${PRED_DIR} \ 129 | --benchmark ${BENCHMARK} 130 | done 131 | 132 | echo "Total time spent on call_api: $total_time seconds" 133 | -------------------------------------------------------------------------------- /scripts/synthetic.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | niah_single_1: 16 | task: niah 17 | args: 18 | type_haystack: noise 19 | type_needle_k: words 20 | type_needle_v: numbers 21 | num_needle_k: 1 22 | num_needle_v: 1 23 | num_needle_q: 1 24 | 25 | niah_single_2: 26 | task: niah 27 | args: 28 | type_haystack: essay 29 | type_needle_k: words 30 | type_needle_v: numbers 31 | num_needle_k: 1 32 | num_needle_v: 1 33 | num_needle_q: 1 34 | 35 | niah_single_3: 36 | task: niah 37 | args: 38 | type_haystack: essay 39 | type_needle_k: words 40 | type_needle_v: uuids 41 | num_needle_k: 1 42 | num_needle_v: 1 43 | num_needle_q: 1 44 | 45 | niah_multikey_1: 46 | task: niah 47 | args: 48 | type_haystack: essay 49 | type_needle_k: words 50 | type_needle_v: numbers 51 | num_needle_k: 4 52 | num_needle_v: 1 53 | num_needle_q: 1 54 | 55 | niah_multikey_2: 56 | task: niah 57 | args: 58 | type_haystack: needle 59 | type_needle_k: words 60 | type_needle_v: numbers 61 | num_needle_k: 1 62 | num_needle_v: 1 63 | num_needle_q: 1 64 | 65 | niah_multikey_3: 66 | task: niah 67 | args: 68 | type_haystack: needle 69 | type_needle_k: uuids 70 | type_needle_v: uuids 71 | num_needle_k: 1 72 | num_needle_v: 1 73 | num_needle_q: 1 74 | 75 | niah_multivalue: 76 | task: niah 77 | args: 78 | type_haystack: essay 79 | type_needle_k: words 80 | type_needle_v: numbers 81 | num_needle_k: 1 82 | num_needle_v: 4 83 | num_needle_q: 1 84 | 85 | niah_multiquery: 86 | task: niah 87 | args: 88 | type_haystack: essay 89 | type_needle_k: words 90 | type_needle_v: numbers 91 | num_needle_k: 1 92 | num_needle_v: 1 93 | num_needle_q: 4 94 | 95 | vt: 96 | task: variable_tracking 97 | args: 98 | type_haystack: noise 99 | num_chains: 1 100 | num_hops: 4 101 | 102 | cwe: 103 | task: common_words_extraction 104 | args: 105 | freq_cw: 30 106 | freq_ucw: 3 107 | num_cw: 10 108 | 109 | fwe: 110 | task: freq_words_extraction 111 | args: 112 | alpha: 2.0 113 | 114 | qa_1: 115 | task: qa 116 | args: 117 | dataset: squad 118 | 119 | qa_2: 120 | task: qa 121 | args: 122 | dataset: hotpotqa --------------------------------------------------------------------------------