├── .env.sample ├── .gitignore ├── LICENSE ├── README.md ├── benchrag ├── README.md ├── get_results.py ├── plot_results.py ├── run_inference.py └── run_judging.py ├── configs └── deepspeed │ ├── stage2_no_offloading_accelerate.conf │ ├── stage2_offloading_accelerate.conf │ ├── stage3_no_offloading.conf │ ├── stage3_no_offloading_accelerate.conf │ ├── stage3_offloading.conf │ └── stage3_offloading_accelerate.conf ├── finetunerag ├── arguments.py ├── finetune.py ├── model_utils.py └── utils.py ├── media ├── accuracy_by_steps.png ├── depth_by_steps.png ├── helpfulness_by_steps.png ├── pints_ai-banner.png └── relevance_by_steps.png ├── prepare_dataset ├── README.md ├── content_generation │ ├── generate_answer.py │ ├── generate_fictitious_content.py │ └── generate_question.py └── formatting │ ├── formatter.py │ ├── generate_training_data.py │ └── split_data.py ├── prompts ├── answer_generation_prompt.py ├── fictitious_content_generation_prompt.py ├── judging_prompts.py ├── prompt_styles.py ├── question_generation_prompt.py └── rag_prompt.py ├── requirements.txt └── utils ├── dataset_utils.py ├── logger.py └── openai.py /.env.sample: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY="" 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | dataset/ 4 | logs/ 5 | ragbench/inferences 6 | ragbench/judge_results 7 | wandb/ 8 | 9 | .env -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Pints.ai 3 |

4 | 5 | # Finetune-RAG: Fine-tuning Models to Tackle Retrieval-Augmented Generation (RAG) Hallucination 6 | 7 | This repository provides an open-source framework to fine-tune large language models (LLMs) for improving their ability to discern correct information from irrelevant or fictitious data when using Retrieval-Augmented Generation (RAG) systems. By training models to distinguish between relevant and misleading contexts, we aim to reduce the hallucination problem in LLM-generated responses, enhancing the reliability of models in real-world use cases. 8 | 9 | ## Paper & Citation 10 | 11 | ```latex 12 | @misc{lee2025finetuneragfinetuninglanguagemodels, 13 | title={Finetune-RAG: Fine-Tuning Language Models to Resist Hallucination in Retrieval-Augmented Generation}, 14 | author={Zhan Peng Lee and Andre Lin and Calvin Tan}, 15 | year={2025}, 16 | eprint={2505.10792}, 17 | archivePrefix={arXiv}, 18 | primaryClass={cs.CL}, 19 | url={https://arxiv.org/abs/2505.10792}, 20 | } 21 | ``` 22 | 23 | ## Problem Overview 24 | 25 | When integrating retrieval into LLM workflows, models often rely on external documents to provide factual information in response to user queries. However, if incorrect or irrelevant documents are retrieved, the model may generate incorrect responses by "hallucinating" based on misleading data. This repository addresses the issue by fine-tuning models to: 26 | 27 | - Recognize and ignore fictitious or irrelevant documents. 28 | - Focus on relevant, factually correct context. 29 | - Generate accurate answers, even when faced with conflicting or unreliable data. 30 | 31 | ## Approach 32 | 33 | Our method involves fine-tuning LLMs using carefully designed prompts that provide two types of data: 34 | 35 | 1. **A correct, factually grounded chunk**. 36 | 2. **A fictitious, misleading chunk**. 37 | 38 | The fine-tuning process teaches the model to focus solely on the correct chunk, filtering out the irrelevant or false context. The training labels are the correct answers based on the factual context, guiding the model to avoid using the fictitious information during generation. 39 | 40 | ### Key Steps 41 | 42 | 1. **Data Construction**: 43 | - For each training example, the model is given a question, one chunk of data that contains the correct information, and a second chunk that contains fictitious information. 44 | 45 | 2. **Training Objective**: 46 | - The model is trained to generate a correct answer by leveraging the factual chunk while ignoring the fictitious one. 47 | 48 | 3. **Evaluation**: 49 | - We evaluate using **Bench-RAG**, our devised LLM-as-a-judge framework. During evaluation, the model is tested on its ability to provide accurate answers, with performance measured by its capacity to ignore misleading data and produce responses based only on the correct context, as judged by GPT-4o. 50 | 51 | ## Setup 52 | 53 | ### Install conda 54 | 55 | ```bash 56 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 57 | sh Miniconda3-latest-Linux-x86_64.sh 58 | ``` 59 | 60 | Source your shell configuration to make sure the `conda` CLI is available: 61 | 62 | ```bash 63 | source ~/.bashrc 64 | ``` 65 | 66 | If you still face `conda: command not found`, locate the installation and source it directly: 67 | 68 | `Note: This path assumes you used the default installation settings.
Otherwise, find where you installed it.` 69 | 70 | ```bash 71 | source ~/miniconda3/etc/profile.d/conda.sh 72 | ``` 73 | 74 | ## Clone this repo 75 | 76 | ```bash 77 | git clone https://github.com/Pints-AI/Finetune-Bench-RAG.git && \ 78 | cd Finetune-Bench-RAG 79 | ``` 80 | 81 | ## Create conda env 82 | 83 | ```bash 84 | conda create --prefix ./.conda python=3.10 && \ 85 | conda activate ./.conda 86 | ``` 87 | 88 | `Note`: Stick to Python 3.10. 3.12 breaks a lot of things as of now (23 Feb 2024), and 3.11 has not been tested. 89 | 90 | ## Install CUDA toolkit 91 | 92 | ```bash 93 | conda install nvidia/label/cuda-12.1.1::cuda-toolkit 94 | ``` 95 | 96 | ## Install requirements 97 | 98 | ```bash 99 | pip install torch==2.6.0 && \ 100 | pip install -r requirements.txt 101 | ``` 102 | 103 | ## Training 104 | 105 | ### Dataset Preparation 106 | 107 | The dataset used in this project consists of over 1,600 documents manually scraped from a wide range of sources. These documents are organised into categories, including legal documents (e.g., contracts, governance, compliance), research papers from multiple scientific fields, books (both fiction and non-fiction), web content, news articles, parliamentary debates, and government publications. Additionally, the dataset includes industry-specific documents such as technical documentation, patents, market research, and code repositories. 108 | 109 | This diverse dataset ensures that the model is exposed to varied contexts, allowing it to learn how to identify and filter out irrelevant or fictitious information across multiple domains. You can access it from our [Hugging Face dataset page](https://huggingface.co/datasets/pints-ai/Finetune-RAG). 110 | 111 | Before training, download the dataset: 112 | 113 | ```bash 114 | huggingface-cli download --local-dir dataset/ --repo-type dataset pints-ai/Finetune-RAG 115 | ``` 116 | Afterwards, process the dataset to suit the training pipeline. Refer to the [`prepare_dataset/`](prepare_dataset/) folder for more information. 117 | 118 | ### Model preparation 119 | 120 | Generally, most huggingface-compatible causal language models should work fine with our codebase, potentially with some adjustments for different tokenizers, etc. Some models may require additional requests to download. E.g., for LLaMa, please consult [the Hugging Face documentation](https://huggingface.co/docs/transformers/model_doc/llama) for requesting access and converting them to a huggingface-compatible format. 121 | 122 | Notably, we finetune instruct models so that they retain their conversational capabilities. 123 | 124 | ### Prompt Strategy 125 | 126 | We have standardised the way we include the retrieved content in the prompt: 127 | 128 | #### System Message 129 | 130 | For all training, our system message is as follows: 131 | 132 | ```python 133 | SYSTEM_PROMPT = 'Some information is retrieved from the database as provided based on the user’s question. The assistant is to answer the question to the best of his/her ability, using only the information provided. The assistant must not add his/her own knowledge.' 134 | ``` 135 | 136 | The goal of SFT is to enhance the model's fictitious content recognition capabilities. As such, we do not wish to overly prompt-engineer and influence the model in its choice of content used in its answer generation. 137 | 138 | #### User Message 139 | 140 | We define the user message in two formats: Baseline & XML. 141 | 142 | The Baseline user message generically provides the content to the model.
We train the model with one genuine content, and one fictitious content, with its order at random. The format is as follows: 143 | 144 | ``` 145 | Filename: {filename1} 146 | Information: 147 | {content1} 148 | 149 | Filename: {filename2} 150 | Information: 151 | {content2} 152 | 153 | Question: {question} 154 | ``` 155 | 156 | The XML user message is similar, but with its content encapsulated with XML tags: 157 | ``` 158 | 159 | 160 | {filename1} 161 | {content1} 162 | 163 | 164 | {filename2} 165 | {content2} 166 | 167 | 168 | 169 | Question: {question} 170 | ``` 171 | 172 | We conduct separate SFTs for both formats. Results are below. 173 | 174 | ### Finetuning 175 | 176 | The codebase is `deepspeed`-with-`accelerate` enabled. If the [various existing `deepspeed` configs](configs/deepspeed/) are not what you are looking for, path your custom config into the execution command. 177 | 178 | Below is an example command to run finetuning on Llama3.1-8B-Instruct model: 179 | 180 | ```bash 181 | accelerate launch \ 182 | --mixed_precision bf16 \ 183 | --num_machines 1 \ 184 | --num_processes 1 \ 185 | --use_deepspeed \ 186 | --deepspeed_config_file configs/ds_configs/stage2_offloading_accelerate.conf \ 187 | finetunerag/finetune.py \ 188 | --model_name_or_path ../Llama-3.1-8B-Instruct \ 189 | --tokenizer_name_or_path ../Llama-3.1-8B-Instruct \ 190 | --use_flash_attn \ 191 | --max_seq_length 4096 \ 192 | --preprocessing_num_workers 128 \ 193 | --per_device_train_batch_size 1 \ 194 | --gradient_accumulation_steps 64 \ 195 | --learning_rate 2e-5 \ 196 | --lr_scheduler_type cosine \ 197 | --beta1 0.9 \ 198 | --beta2 0.95 \ 199 | --warmup_ratio 0.1 \ 200 | --weight_decay 0.1 \ 201 | --num_train_epochs 1 \ 202 | --enable_wandb \ 203 | --logging_steps 1 \ 204 | --checkpointing_steps 2 \ 205 | --prompt_style llama3.1 \ 206 | --validation_step 1 \ 207 | --wandb_project Finetuning \ 208 | --wandb_name Llama-3.1-8B-Instruct-Finetuned \ 209 | --train_file dataset/splits/train.jsonl \ 210 | --validation_file dataset/splits/validation.jsonl \ 211 | --output_dir ../models/Llama-3.1-8B-Instruct-Baseline/ 212 | ``` 213 | 214 | Make sure to adjust `model_name_or_path`, `tokenizer_name_or_path`, `train_file`, and `output_dir` to your models / data / setting. 215 | 216 | ### Released Checkpoints 217 | 218 | We have finetuned Llama-3.1-8B-Instruct to tackle RAG hallucination. We used 1xH100 GPU, with a micro-batch size of 1, and a batch size of 64 per step. See the sample finetuning command above for more information of the hyperparameters used. Our checkpoints can be found here: 219 | 220 | - [Baseline-tuned(checkpoints: steps_2-10)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_Baseline_tuned-1). 221 | - [Baseline-tuned(checkpoints: steps_12-20)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_Baseline_tuned-2). 222 | - [XML-tuned(checkpoints: steps_2-10)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_XML_tuned-1). 223 | - [XML-tuned(checkpoints: steps_12-20)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_XML_tuned-2). 224 | 225 | ## Evaluation 226 | 227 | ### Bench-RAG 228 | 229 | We have devised a benchmark strategy, namely **Bench-RAG**, with the help of GPT-4o to identify whether the outputs generated by the model is factually accurate even after provided with ficticious data in its prompt. See the [`benchrag/`](benchrag/) folder for more information about how it works and its execution. 
230 | 231 | Below are the results of the finetuned Llama-3.1-8B-Instruct: 232 | 233 | Accuracy of answers generated by Llama-3.1-8B-Instruct over various steps of SFT 234 | Depth of answers generated by Llama-3.1-8B-Instruct over various steps of SFT 235 | Helpfulness of answers generated by Llama-3.1-8B-Instruct over various steps of SFT 236 | Relevance of answers generated by Llama-3.1-8B-Instruct over various steps of SFT 237 | 238 | ## Acknowledgements 239 | The structure of our codebase is adapted from the [Open-Instruct repository](https://github.com/allenai/open-instruct). 240 | 241 | ## Licence 242 | This codebase is licensed under Apache 2.0 as given in LICENSE. 243 | -------------------------------------------------------------------------------- /benchrag/README.md: -------------------------------------------------------------------------------- 1 | # Bench-RAG 2 | 3 | Our evaluation tool is designed to assess whether the outputs generated by a model remain factually accurate, even when presented with fictitious data in its prompt. It uses a structured evaluation system powered by OpenAI models (e.g., GPT-4o). In addition to factual accuracy, it also evaluates other key aspects such as helpfulness, relevance, and depth, using tailored prompts for each metric. The tool provides a detailed assessment, offering a True/False rating for accuracy and scores from 1 to 10 for the other dimensions, along with explanations for each evaluation. 4 | 5 | ## Features 6 | 7 | - **Accuracy Evaluation**: Determines if the response introduces any extra information not found in the provided context. 8 | - **Helpfulness Evaluation**: Rates how useful the response is in answering the user's question. 9 | - **Relevance Evaluation**: Scores the response based on how well it addresses the question. 10 | - **Depth Evaluation**: Measures the level of detail provided in the response. 11 | 12 | ## Inferencing prior to Bench-RAG 13 | 14 | Bench-RAG reads `jsonl` file(s), each containing multiple inferences from the same model variant. Each of these inferences should be a JSON object in the following format: 15 | 16 | ```json 17 | { 18 | "filename": "The unique identifier for the file", 19 | "content": "The non-fictitious content from which the response was supposed to generate the answer", 20 | "question": "The user's question", 21 | "response": "The fine-tuned LLM's generated response to the question. This response should have been generated with both the non-fictitious and the fictitious content provided." 22 | } 23 | ``` 24 | 25 | You can run inference on your fine-tuned model using the `benchrag/run_inference.py` script from the root directory. 26 | 27 | - The script supports multi-GPU inference. Set the appropriate `num_processes` before executing. 28 | - The data format used for inference should follow the `messages` format used during fine-tuning. 29 | 30 | ```bash 31 | # Run on a directory with multiple models; an example use case is the checkpoints saved during fine-tuning. 32 | # Note that the script searches for checkpoints with the prefix 'steps_'. Change it accordingly to suit your needs.
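# (The checkpoints directory is assumed to contain one sub-folder per checkpoint, e.g. steps_2, steps_4, ..., each holding a HuggingFace-format model.)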
33 | accelerate launch -m \ 34 | --multi_gpu \ # Only pass the --multi_gpu flag if you have more than one GPU, else it will throw an error 35 | --num_processes=2 \ 36 | --mixed_precision=bf16 \ 37 | benchrag.run_inference \ 38 | --checkpoints_directory path/to/multiple/fine-tuned/models \ 39 | --data_directory path/to/testing/data \ 40 | --tokenizer_directory path/to/tokenizer \ 41 | --custom_chat_template llama3.1 42 | 43 | # Run on a specific model directory. 44 | accelerate launch -m \ 45 | --multi_gpu \ # Only pass the --multi_gpu flag if you have more than one GPU, else it will throw an error 46 | --num_processes=2 \ 47 | --mixed_precision=bf16 \ 48 | benchrag.run_inference \ 49 | --specific_checkpoint_directory path/to/fine-tuned/model \ 50 | --data_directory path/to/testing/data \ 51 | --tokenizer_directory path/to/tokenizer \ 52 | --custom_chat_template llama3.1 53 | ``` 54 | 55 | > [!IMPORTANT] 56 | > If you are using deepspeed for training, remember to convert the checkpoint weights before inference. 57 | 58 | ## Executing Bench-RAG 59 | 60 | With your curated model inferences or the output from `benchrag/run_inference.py`, run the benchmark either on a directory of jsonl files or on a specific jsonl file. Refer to the official OpenAI documentation for the models available as evaluators. 61 | 62 | ```bash 63 | # Run the benchmark on a directory with multiple jsonl files from various checkpoints. 64 | python -m benchrag.run_judging \ 65 | --openai_evaluator gpt-4o \ 66 | --answers_directory path/to/multiple/jsonl/files \ 67 | --output_directory path/to/output/directory 68 | 69 | # Run the benchmark on a specific jsonl file. 70 | python -m benchrag.run_judging \ 71 | --openai_evaluator gpt-4o \ 72 | --answers_file path/to/single/jsonl/file.jsonl \ 73 | --output_directory path/to/output/directory 74 | ``` 75 | 76 | ## Collate results 77 | 78 | You can aggregate the results via the `benchrag/get_results.py` script. Simply pass the path to the directory that contains all the generated jsonl files.
79 | 80 | ```bash 81 | python -m benchrag.get_results --directory_path path/to/benchrag/results 82 | ``` 83 | -------------------------------------------------------------------------------- /benchrag/get_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from pathlib import Path 5 | 6 | def read_jsonl(file_path): 7 | """Reads a JSONL file and returns the data as a list of dictionaries.""" 8 | data = [] 9 | with open(file_path, 'r') as file: 10 | for line in file: 11 | data.append(json.loads(line.strip())) 12 | return data 13 | 14 | 15 | def compute_summary_metrics(data): 16 | """Computes summary metrics from a list of data dictionaries.""" 17 | metrics = { 18 | 'total_records': 0, 19 | 'accuracy_true_count': 0, 20 | 'accuracy_false_count': 0, 21 | 'total_helpfulness': 0, 22 | 'total_relevance': 0, 23 | 'total_depth': 0, 24 | 'total_average': 0.0, 25 | } 26 | 27 | # Iterate over each record to update the metrics 28 | for record in data: 29 | metrics['total_records'] += 1 30 | if record['accuracy']: 31 | metrics['accuracy_true_count'] += 1 32 | else: 33 | metrics['accuracy_false_count'] += 1 34 | metrics['total_helpfulness'] += record['helpfulness'] 35 | metrics['total_relevance'] += record['relevance'] 36 | metrics['total_depth'] += record['depth'] 37 | metrics['total_average'] += record['average'] 38 | 39 | # Calculate averages if there are any records 40 | if metrics['total_records'] > 0: 41 | metrics['average_helpfulness'] = ( 42 | metrics['total_helpfulness'] / metrics['total_records'] 43 | ) 44 | metrics['average_relevance'] = ( 45 | metrics['total_relevance'] / metrics['total_records'] 46 | ) 47 | metrics['average_depth'] = metrics['total_depth'] / metrics['total_records'] 48 | metrics['average_average'] = metrics['total_average'] / metrics['total_records'] 49 | else: 50 | metrics['average_helpfulness'] = 0 51 | metrics['average_relevance'] = 0 52 | metrics['average_depth'] = 0 53 | metrics['average_average'] = 0 54 | 55 | return metrics 56 | 57 | 58 | def process_directory(directory_path: Path): 59 | """Processes all JSONL files in the specified directory.""" 60 | filenames = sorted(os.listdir(directory_path)) 61 | for filename in filenames: 62 | if filename.endswith('.jsonl'): 63 | file_path = os.path.join(directory_path, filename) 64 | data = read_jsonl(file_path) 65 | metrics = compute_summary_metrics(data) 66 | print(f'Summary Metrics for {filename}:') 67 | print(f"Total Records: {metrics['total_records']}") 68 | accuracy = metrics['accuracy_true_count'] / metrics['total_records'] * 100 69 | print(f'Accuracy: {accuracy:.2f}%') 70 | print(f"Average Helpfulness: {metrics['average_helpfulness']:.2f}") 71 | print(f"Average Relevance: {metrics['average_relevance']:.2f}") 72 | print(f"Average Depth: {metrics['average_depth']:.2f}") 73 | print(f"Average of Averages: {metrics['average_average']:.2f}") 74 | print('') 75 | 76 | if __name__ == "__main__": 77 | from jsonargparse import CLI 78 | CLI(process_directory, as_positional=False) 79 | -------------------------------------------------------------------------------- /benchrag/plot_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample usage: python -m benchrag.plot_results --result_directory ragbench/results/Llama-3.1-8B-Instruct-Baseline ragbench/results/Llama-3.1-8B-Instruct-XML ragbench/results/Llama-3.1-8B-Instruct-Enhanced --output_directory ragbench/judge_results/ 3 | """ 4 | 5 | import json 6 
| import re 7 | from argparse import ArgumentParser 8 | from pathlib import Path 9 | 10 | from matplotlib import pyplot 11 | 12 | 13 | def start(result_paths: list[Path], output_directory: Path): 14 | aggregated_results = [] 15 | for result_path in result_paths: 16 | aggregated_results.extend(aggregate_scores(result_path)) 17 | 18 | plot_scores_by_metric(aggregated_results, output_directory) 19 | 20 | 21 | def aggregate_scores(result_path: Path): 22 | jsonl_file_paths = list(result_path.glob('*.jsonl')) 23 | 24 | results = [] 25 | for jsonl_file_path in jsonl_file_paths: 26 | total_accuracy = 0 27 | total_helpfulness = 0 28 | total_relevance = 0 29 | total_depth = 0 30 | count = 0 31 | 32 | with open(jsonl_file_path, 'r') as jsonl_file: 33 | for jsonl_line in jsonl_file: 34 | jsonl_data = json.loads(jsonl_line.strip()) 35 | 36 | total_accuracy += jsonl_data.get('accuracy', 0) 37 | total_helpfulness += jsonl_data.get('helpfulness', 0) 38 | total_relevance += jsonl_data.get('relevance', 0) 39 | total_depth += jsonl_data.get('depth', 0) 40 | count += 1 41 | 42 | avg_accuracy = total_accuracy / count 43 | avg_helpfulness = total_helpfulness / count 44 | avg_relevance = total_relevance / count 45 | avg_depth = total_depth / count 46 | 47 | results.append( 48 | { 49 | 'template_type': result_path.name.rsplit('-', 1)[-1], 50 | 'file': jsonl_file_path.name, 51 | 'avg_accuracy': avg_accuracy, 52 | 'avg_helpfulness': avg_helpfulness, 53 | 'avg_relevance': avg_relevance, 54 | 'avg_depth': avg_depth, 55 | } 56 | ) 57 | 58 | return results 59 | 60 | 61 | def extract_number(filename): 62 | match = re.search(r'steps_(\d+)', filename) 63 | return int(match.group(1)) if match else float('inf') 64 | 65 | 66 | def plot_scores_by_metric(aggregated_results: list[dict], output_directory: Path): 67 | template_types = list(set(result['template_type'] for result in aggregated_results)) 68 | 69 | output_directory.mkdir(parents=True, exist_ok=True) 70 | 71 | metrics = ['accuracy', 'helpfulness', 'relevance', 'depth'] 72 | metric_labels = ['Accuracy', 'Helpfulness', 'Relevance', 'Depth'] 73 | 74 | for i, metric in enumerate(metrics): 75 | fig, ax = pyplot.subplots(figsize=(12, 8)) 76 | 77 | for template_type in template_types: 78 | files = [ 79 | result['file'] 80 | for result in aggregated_results 81 | if result['template_type'] == template_type 82 | ] 83 | values = [ 84 | result[f'avg_{metric}'] 85 | for result in aggregated_results 86 | if result['template_type'] == template_type 87 | ] 88 | 89 | sorted_pairs = sorted( 90 | zip(files, values), key=lambda x: extract_number(x[0]) 91 | ) 92 | sorted_files, sorted_values = zip(*sorted_pairs) 93 | sorted_files = list(map(lambda name: name[:-6], sorted_files)) 94 | ax.plot(sorted_files, sorted_values, marker='o', label=f'{template_type}') 95 | 96 | ax.set_title(f'Average {metric_labels[i]} of Finetuned Llama-3.1-8B-Instruct') 97 | ax.set_xlabel('Checkpoint') 98 | ax.set_ylabel(f'Average {metric_labels[i]}') 99 | ax.legend() 100 | ax.tick_params(axis='x', rotation=45) 101 | 102 | # Save the PNG file 103 | pyplot.tight_layout() 104 | pyplot.savefig(output_directory / f'{metric}_by_steps.png') 105 | pyplot.close() 106 | 107 | 108 | if __name__ == '__main__': 109 | parser = ArgumentParser(description='Plot line graphs of ragbench results.') 110 | parser.add_argument( 111 | '--result_directory', 112 | type=str, 113 | nargs='+', 114 | required=True, 115 | help='Directories containing ragbench results in JSONL files. 
Multiple directories can be provided.', 116 | ) 117 | parser.add_argument( 118 | '--output_directory', 119 | type=str, 120 | required=True, 121 | help='Output directory to save the plots.', 122 | ) 123 | 124 | arguments = parser.parse_args() 125 | 126 | result_paths = [ 127 | Path(result_directory) for result_directory in arguments.result_directory 128 | ] 129 | output_path = Path(arguments.output_directory) 130 | start(result_paths, output_path) 131 | -------------------------------------------------------------------------------- /benchrag/run_inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | 4 | from accelerate import Accelerator 5 | from utils.dataset_utils import load_jsonl_file 6 | from dataclasses import dataclass, field 7 | from utils.logger import setup_logger 8 | from prompts.prompt_styles import PromptStyle 9 | from finetunerag.utils import ArgumentParserPlus 10 | from pathlib import Path 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | from typing import Optional 13 | 14 | @dataclass 15 | class InferencingArguments: 16 | """ 17 | Full arguments class for inferencing from checkpoints. 18 | """ 19 | 20 | checkpoints_directory: Optional[str] = field( 21 | default=None, 22 | metadata={ 23 | 'help': 'Checkpoints directory containing the checkpoints of a model. This cannot be provided along with --specific_checkpoint_directory.' 24 | }, 25 | ) 26 | specific_checkpoint_directory: Optional[str] = field( 27 | default=None, 28 | metadata={'help': 'Path to a specific checkpoint folder. This cannot be provided along with --checkpoints_directory.'}, 29 | ) 30 | data_directory: Optional[str] = field( 31 | default=None, 32 | metadata={'help': 'Data directory containing the content used to generate the prompts.'}, 33 | ) 34 | output_directory: str = field( 35 | default='ragbench/inferences/', 36 | metadata={'help': 'Output directory to save the generated answers.'}, 37 | ) 38 | tokenizer_directory: Optional[str] = field( 39 | default=None, 40 | metadata={'help': 'Path to the tokenizer directory.'}, 41 | ) 42 | custom_chat_template: Optional[str] = field( 43 | default=None, 44 | metadata={'help': 'Name of custom chat template if the chat template from tokenizer is not desired.'}, 45 | ) 46 | 47 | def __post_init__(self): 48 | if self.checkpoints_directory == None and self.specific_checkpoint_directory == None: 49 | raise ValueError("No checkpoints directory or specific checkpoint directory provided.") 50 | 51 | if self.checkpoints_directory != None and self.specific_checkpoint_directory != None: 52 | raise ValueError("Both checkpoints directory and specific checkpoint directory provided. Please only provide either one.") 53 | 54 | if self.data_directory == None: 55 | raise ValueError('No data directory provided. 
Unable to generate from model.') 56 | 57 | if self.tokenizer_directory == None: 58 | raise ValueError('No tokenizer directory provided.') 59 | 60 | self.data_path = Path(self.data_directory) 61 | self.output_path = Path(self.output_directory) 62 | self.tokenizer_path = Path(self.tokenizer_directory) 63 | self.checkpoints_path = Path(self.checkpoints_directory) if self.checkpoints_directory else None 64 | self.specific_checkpoint_path = Path(self.specific_checkpoint_directory) if self.specific_checkpoint_directory else None 65 | 66 | # Global logger 67 | logger = setup_logger(Path(__file__).name) 68 | 69 | def start(args: InferencingArguments): 70 | dataset = load_jsonl_file(args.data_path) 71 | args.output_path.mkdir(parents=True, exist_ok=True) 72 | 73 | if args.checkpoints_path: 74 | # Retrieve all checkpoints available from the path to the checkpoints 75 | checkpoint_paths = list(args.checkpoints_path.glob('steps_*')) 76 | # Sort the checkpoints by their step count (assumes folders are named 'steps_<count>') 77 | checkpoint_paths = sorted(checkpoint_paths, key=lambda checkpoint_folder: int(checkpoint_folder.name.rsplit('_', 1)[-1])) 78 | logger.debug(f'These are the checkpoints identified: {checkpoint_paths}') 79 | else: 80 | checkpoint_paths = [Path(args.specific_checkpoint_directory)] 81 | 82 | for checkpoint_path in checkpoint_paths: 83 | generate_responses(checkpoint_path, args.tokenizer_path, args.output_path, dataset, args.custom_chat_template) 84 | logger.info(f"Finished processing {checkpoint_path}.") 85 | 86 | logger.info(f"Inference for all checkpoints complete!") 87 | 88 | 89 | def generate_responses( 90 | checkpoint_path: Path, 91 | tokenizer_path: Path, 92 | output_path: Path, 93 | dataset: list, 94 | custom_chat_template: str 95 | ): 96 | accelerator = Accelerator() 97 | 98 | prompt_styler = PromptStyle.from_name(custom_chat_template) 99 | 100 | model = AutoModelForCausalLM.from_pretrained( 101 | checkpoint_path, 102 | torch_dtype=torch.bfloat16, 103 | attn_implementation='flash_attention_2', 104 | device_map={'': accelerator.process_index}, 105 | ) 106 | model = torch.compile(model) 107 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 108 | 109 | model, tokenizer = accelerator.prepare(model, tokenizer) 110 | accelerator.wait_for_everyone() # Synchronise all processes to ensure readiness before starting generation 111 | 112 | generated_responses = [] 113 | 114 | with accelerator.split_between_processes(dataset) as documents: 115 | for index, datarow in enumerate(documents): 116 | prompt_text = prompt_styler.apply(datarow['messages'], append_assistant_header=True) 117 | inputs = tokenizer(prompt_text, return_tensors="pt").to(accelerator.device) 118 | generation_output = model.generate( 119 | inputs.input_ids, 120 | max_length=6000, 121 | do_sample=False, 122 | temperature=None, 123 | pad_token_id=tokenizer.eos_token_id, 124 | ) 125 | 126 | generated_ids = generation_output[0][inputs.input_ids.shape[1]:-1] 127 | response = tokenizer.decode(generated_ids) 128 | 129 | generated_responses.append({ 130 | 'filename': datarow['filename'], 131 | 'content': datarow['content'], 132 | 'question': datarow['question'], 133 | 'response': response 134 | }) 135 | 136 | gathered_responses = accelerator.gather(generated_responses) 137 | 138 | if accelerator.is_main_process: 139 | response_file_path = output_path / f'{checkpoint_path.name}.jsonl' 140 | with open(response_file_path, 'w') as response_file: 141 | for response_data in gathered_responses: 142 | response_file.write(json.dumps(response_data) + '\n') 143 | 144 | 
accelerator.wait_for_everyone() 145 | 146 | if __name__ == '__main__': 147 | parser = ArgumentParserPlus((InferencingArguments)) 148 | args = parser.parse() 149 | start(args) 150 | -------------------------------------------------------------------------------- /benchrag/run_judging.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numbers 3 | from dataclasses import dataclass, field 4 | from pathlib import Path 5 | from typing import List, Optional 6 | 7 | from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam 8 | 9 | from utils.dataset_utils import ( 10 | load_jsonl_file, 11 | ) 12 | from finetunerag.utils import ArgumentParserPlus 13 | from utils.openai import call_openai_api 14 | from prompts.judging_prompts import ( 15 | OpenAIJudgeResponse, 16 | get_judge_user_prompt, 17 | judge_accuracy_system_prompt, 18 | judge_depth_system_prompt, 19 | judge_helpfulness_system_prompt, 20 | judge_relevance_system_prompt, 21 | ) 22 | from utils.logger import setup_logger 23 | 24 | 25 | @dataclass 26 | class InferencingArguments: 27 | """ 28 | Full arguments class for inferencing from checkpoints. 29 | """ 30 | 31 | answers_directory: Optional[str] = field( 32 | default=None, 33 | metadata={ 34 | 'help': 'Path to an answers directory that potentially contains multiple set of answers from a model in the form of jsonl files. This cannot be provided along with --answers_file.' 35 | }, 36 | ) 37 | answers_file: Optional[str] = field( 38 | default=None, 39 | metadata={ 40 | 'help': 'Path to a specific answers file from a model in the form of jsonl. This cannot be provided along with --answers_directory.' 41 | }, 42 | ) 43 | openai_evaluator: str = field( 44 | default='gpt-4o', 45 | metadata={ 46 | 'help': 'The evaluator to use from OpenAI. Refer to their documentation for available models.' 47 | }, 48 | ) 49 | output_directory: str = field( 50 | default='ragbench/judged_scores', 51 | metadata={'help': 'Output directory to save the judging results.'}, 52 | ) 53 | 54 | def __post_init__(self): 55 | if self.answers_directory is None and self.answers_file is None: 56 | raise ValueError('No answers directory or answers file provided.') 57 | 58 | if self.answers_directory is not None and self.answers_file is not None: 59 | raise ValueError( 60 | 'Both answers directory and answers file provided. Please only provide either one.' 
61 | ) 62 | 63 | self.output_path = Path(self.output_directory) 64 | self.answers_files_path = ( 65 | Path(self.answers_directory) if self.answers_directory else None 66 | ) 67 | self.answers_file_path = Path(self.answers_file) if self.answers_file else None 68 | 69 | 70 | JUDGE_SYSTEM_PROMPTS = [ 71 | judge_accuracy_system_prompt, 72 | judge_helpfulness_system_prompt, 73 | judge_relevance_system_prompt, 74 | judge_depth_system_prompt, 75 | ] 76 | 77 | # Global logger 78 | logger = setup_logger(Path(__file__).name) 79 | 80 | 81 | def start(args: InferencingArguments): 82 | files_to_judge: Path = [] 83 | if args.answers_files_path: 84 | for file_path in args.answers_files_path.iterdir(): 85 | if file_path.is_file() and file_path.suffix == '.jsonl': 86 | files_to_judge.append(file_path) 87 | else: 88 | files_to_judge.append(args.answers_file_path) 89 | 90 | logger.debug( 91 | f'These are the files identified: {list(map(lambda file_path: file_path.name, files_to_judge))}' 92 | ) 93 | for file_to_judge in files_to_judge: 94 | aggregate_file(file_to_judge, args) 95 | 96 | 97 | def aggregate_file(answers_file_path: Path, args: InferencingArguments): 98 | cumulative_stats = {'scores': 0, 'n': 0, 'failed': 0, 'trues': 0, 'falses': 0} 99 | 100 | documents = load_jsonl_file(answers_file_path) 101 | 102 | logger.debug(f'Total of {len(documents)} samples to rate.') 103 | 104 | args.output_path.mkdir(parents=True, exist_ok=True) 105 | output_file_path = args.output_path / answers_file_path.name 106 | 107 | files_done = set() 108 | # resume from checkpoint 109 | if output_file_path.is_file(): 110 | output_list = load_jsonl_file(output_file_path) 111 | 112 | for json_row in output_list: 113 | files_done.add(json_row['filename']) 114 | 115 | # Update statistics 116 | if json_row['accuracy']: 117 | cumulative_stats['trues'] += 1 118 | else: 119 | cumulative_stats['falses'] += 1 120 | 121 | cumulative_stats['n'] += 1 122 | if json_row['average'] is not None: 123 | cumulative_stats['scores'] += json_row['average'] 124 | 125 | logger.info(f"{cumulative_stats['n']} sample(s) already done.") 126 | 127 | # call openAI 128 | for index, document in enumerate(documents): 129 | if document['filename'] in files_done: 130 | continue 131 | 132 | try: 133 | logger.info(f"Attempting to judge index {index}: {document['filename']}...") 134 | parsed_response: OpenAIJudgeResponse = judge( 135 | document=document, evaluator=args.openai_evaluator 136 | ) 137 | 138 | except Exception as error: 139 | logger.error(f"Filename: [{document['filename']}] errored.") 140 | logger.error(error) 141 | cumulative_stats['failed'] += 1 142 | continue 143 | 144 | # Compute the average score 145 | scores: List[numbers.Real] = [ 146 | score 147 | for score in parsed_response.values() 148 | if isinstance(score, numbers.Real) and not isinstance(score, bool) 149 | ] 150 | if len(scores) > 0: 151 | parsed_response['average'] = sum(scores) / len(scores) 152 | else: 153 | parsed_response['average'] = None 154 | 155 | # Add the filename into the dataset 156 | parsed_response['filename'] = document['filename'] 157 | 158 | parsed_response_str = json.dumps(parsed_response) 159 | with open(output_file_path, 'a', encoding='utf-8') as responses_file: 160 | responses_file.write(parsed_response_str + '\n') 161 | 162 | # Update statistics 163 | cumulative_stats['n'] += 1 164 | logger.info(f"{cumulative_stats['n']}/{len(documents)} inferences done") 165 | 166 | if parsed_response['accuracy']: 167 | cumulative_stats['trues'] += 1 168 | else: 169 | 
cumulative_stats['falses'] += 1 170 | 171 | if parsed_response['average'] is not None: 172 | cumulative_stats['scores'] += parsed_response['average'] 173 | 174 | logger.info(cumulative_stats) 175 | logger.info(f"Final average: {cumulative_stats['scores'] / cumulative_stats['n']}") 176 | 177 | 178 | def judge( 179 | document, 180 | evaluator, 181 | ) -> OpenAIJudgeResponse: 182 | parsed_responses = {} 183 | for judge_system_prompt in JUDGE_SYSTEM_PROMPTS: 184 | messages: List[ChatCompletionMessageParam] = [] 185 | 186 | user_prompt = get_judge_user_prompt(document) 187 | 188 | messages.append(judge_system_prompt) 189 | messages.append(user_prompt) 190 | 191 | parsed_response = call_openai_api(messages=messages, model=evaluator, output_as_json=True) 192 | parsed_responses.update(parsed_response) 193 | 194 | return parsed_responses 195 | 196 | 197 | if __name__ == '__main__': 198 | parser = ArgumentParserPlus((InferencingArguments)) 199 | args = parser.parse() 200 | start(args) 201 | -------------------------------------------------------------------------------- /configs/deepspeed/stage2_no_offloading_accelerate.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 2, 7 | "overlap_comm": true, 8 | "contiguous_gradients": true, 9 | "sub_group_size": 1e9, 10 | "reduce_bucket_size": "auto" 11 | }, 12 | "gradient_accumulation_steps": "auto", 13 | "gradient_clipping": "auto", 14 | "steps_per_print": 1e5, 15 | "train_batch_size": "auto", 16 | "train_micro_batch_size_per_gpu": "auto", 17 | "wall_clock_breakdown": false 18 | } -------------------------------------------------------------------------------- /configs/deepspeed/stage2_offloading_accelerate.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 2, 7 | "offload_optimizer": { 8 | "device": "cpu", 9 | "pin_memory": true 10 | }, 11 | "overlap_comm": true, 12 | "contiguous_gradients": true, 13 | "sub_group_size": 1e9, 14 | "reduce_bucket_size": "auto" 15 | }, 16 | "gradient_accumulation_steps": "auto", 17 | "gradient_clipping": "auto", 18 | "steps_per_print": 1e5, 19 | "train_batch_size": "auto", 20 | "train_micro_batch_size_per_gpu": "auto", 21 | "wall_clock_breakdown": false 22 | } -------------------------------------------------------------------------------- /configs/deepspeed/stage3_no_offloading.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "overlap_comm": true, 26 | "contiguous_gradients": true, 27 | "sub_group_size": 1e9, 28 | "reduce_bucket_size": "auto", 29 | "stage3_prefetch_bucket_size": "auto", 30 | "stage3_param_persistence_threshold": "auto", 31 | "stage3_max_live_parameters": 1e9, 32 | "stage3_max_reuse_distance": 1e9, 33 | "stage3_gather_16bit_weights_on_model_save": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 1e5, 38 | 
"train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /configs/deepspeed/stage3_no_offloading_accelerate.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "overlap_comm": true, 8 | "contiguous_gradients": true, 9 | "sub_group_size": 1e9, 10 | "reduce_bucket_size": "auto", 11 | "stage3_prefetch_bucket_size": "auto", 12 | "stage3_param_persistence_threshold": "auto", 13 | "stage3_max_live_parameters": 1e9, 14 | "stage3_max_reuse_distance": 1e9, 15 | "stage3_gather_16bit_weights_on_model_save": true 16 | }, 17 | "gradient_accumulation_steps": "auto", 18 | "gradient_clipping": "auto", 19 | "steps_per_print": 1e5, 20 | "train_batch_size": "auto", 21 | "train_micro_batch_size_per_gpu": "auto", 22 | "wall_clock_breakdown": false 23 | } -------------------------------------------------------------------------------- /configs/deepspeed/stage3_offloading.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "cpu", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "cpu", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "sub_group_size": 1e9, 36 | "reduce_bucket_size": "auto", 37 | "stage3_prefetch_bucket_size": "auto", 38 | "stage3_param_persistence_threshold": "auto", 39 | "stage3_max_live_parameters": 1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": true 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 1e5, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } -------------------------------------------------------------------------------- /configs/deepspeed/stage3_offloading_accelerate.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "offload_optimizer": { 8 | "device": "cpu", 9 | "pin_memory": true 10 | }, 11 | "offload_param": { 12 | "device": "cpu", 13 | "pin_memory": true 14 | }, 15 | "overlap_comm": true, 16 | "contiguous_gradients": true, 17 | "sub_group_size": 1e9, 18 | "reduce_bucket_size": "auto", 19 | "stage3_prefetch_bucket_size": "auto", 20 | "stage3_param_persistence_threshold": "auto", 21 | "stage3_max_live_parameters": 1e9, 22 | "stage3_max_reuse_distance": 1e9, 23 | "stage3_gather_16bit_weights_on_model_save": true 24 | }, 25 | "gradient_accumulation_steps": "auto", 26 | "gradient_clipping": "auto", 27 | "steps_per_print": 1e5, 28 | "train_batch_size": "auto", 29 | "train_micro_batch_size_per_gpu": "auto", 30 | "wall_clock_breakdown": false 31 | } 32 | 
-------------------------------------------------------------------------------- /finetunerag/arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from dataclasses import dataclass, field 4 | from typing import Optional 5 | 6 | 7 | @dataclass 8 | class FinetuneArguments: 9 | """ 10 | Full arguments class for fine-tuning. 11 | """ 12 | 13 | exp_name: str = field( 14 | default=os.path.basename(__file__)[: -len('.py')], 15 | metadata={ 16 | 'help': ( 17 | "The name of this experiment." 18 | ) 19 | }, 20 | ) 21 | model_name_or_path: Optional[str] = field( 22 | default=None, 23 | metadata={ 24 | 'help': ( 25 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 26 | ) 27 | }, 28 | ) 29 | tokenizer_name_or_path: Optional[str] = field( 30 | default=None, 31 | metadata={ 32 | 'help': 'Pretrained tokenizer name or path if not the same as model_name_or_path' 33 | }, 34 | ) 35 | tokenizer_revision: Optional[str] = field( 36 | default='main', 37 | metadata={ 38 | 'help': 'The specific model version to use (can be a branch name, tag name or commit id).' 39 | }, 40 | ) 41 | prompt_style: Optional[str] = field( 42 | default='default', 43 | metadata={ 44 | 'help': 'The specific prompt template to use (should be registered under one of the custom prompts).' 45 | }, 46 | ) 47 | use_flash_attn: bool = field( 48 | default=True, 49 | metadata={'help': 'Whether to use flash attention in the model training'}, 50 | ) 51 | model_revision: str = field( 52 | default='main', 53 | metadata={ 54 | 'help': 'The specific model version to use (can be a branch name, tag name or commit id).' 55 | }, 56 | ) 57 | trust_remote_code: bool = field( 58 | default=False, 59 | metadata={ 60 | 'help': ( 61 | 'Whether or not to allow for custom models defined on the Hub in their own modeling files. ' 62 | 'This option should only be set to `True` for repositories you trust and in which you ' 63 | 'have read the code, as it will execute code present on the Hub on your local machine.' 64 | ) 65 | }, 66 | ) 67 | low_cpu_mem_usage: bool = field( 68 | default=False, 69 | metadata={ 70 | 'help': ( 71 | 'It is an option to create the model as an empty shell, ' 72 | 'then only materialize its parameters when the pretrained weights are loaded. ' 73 | 'set True will benefit LLM loading time and RAM consumption.' 74 | ) 75 | }, 76 | ) 77 | train_file: Optional[str] = field( 78 | default=None, 79 | metadata={'help': 'The input training data file (a json/jsonl file).'}, 80 | ) 81 | validation_file: Optional[str] = field( 82 | default=None, 83 | metadata={'help': 'The input validation file (a json/jsonl file).'}, 84 | ) 85 | max_train_samples: Optional[int] = field( 86 | default=None, 87 | metadata={ 88 | 'help': ( 89 | 'For debugging purposes or quicker training, truncate the number of training examples to this ' 90 | 'value if set.' 91 | ) 92 | }, 93 | ) 94 | preprocessing_num_workers: Optional[int] = field( 95 | default=None, 96 | metadata={'help': 'The number of processes to use for the preprocessing.'}, 97 | ) 98 | max_seq_length: Optional[int] = field( 99 | default=None, 100 | metadata={ 101 | 'help': ( 102 | 'The maximum total input sequence length after tokenization. 
' 103 | 'Sequences longer than this will be truncated,' 104 | ) 105 | }, 106 | ) 107 | overwrite_cache: bool = field( 108 | default=False, 109 | metadata={'help': 'Overwrite the cached training and evaluation sets'}, 110 | ) 111 | add_bos: bool = field( 112 | default=False, 113 | metadata={ 114 | 'help': 'Forcibly add bos token to the beginning of the input sequence.' 115 | ' Use only when tokenizer does not add bos token by default.' 116 | }, 117 | ) 118 | clip_grad_norm: float = field( 119 | default=-1, 120 | metadata={ 121 | 'help': 'Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead).' 122 | }, 123 | ) 124 | gradient_accumulation_steps: int = field( 125 | default=1, 126 | metadata={ 127 | 'help': 'Number of updates steps to accumulate before performing a backward/update pass.' 128 | }, 129 | ) 130 | learning_rate: float = field( 131 | default=2e-5, 132 | metadata={'help': 'The initial learning rate for AdamW optimizer.'}, 133 | ) 134 | beta1: float = field( 135 | default=0.9, 136 | metadata={ 137 | 'help': 'The coefficient used for computing running averages of gradient and its square within the optimiser.' 138 | }, 139 | ) 140 | beta2: float = field( 141 | default=0.95, 142 | metadata={ 143 | 'help': 'The coefficient used for computing running averages of gradient and its square within the optimiser.' 144 | }, 145 | ) 146 | logging_steps: Optional[int] = field( 147 | default=None, 148 | metadata={ 149 | 'help': 'Log the training loss and learning rate every logging_steps steps.' 150 | }, 151 | ) 152 | lr_scheduler_type: str = field( 153 | default='linear', 154 | metadata={ 155 | 'help': 'The scheduler type to use for learning rate adjustment.', 156 | 'choices': [ 157 | 'linear', 158 | 'cosine', 159 | 'cosine_with_restarts', 160 | 'polynomial', 161 | 'constant', 162 | 'constant_with_warmup', 163 | ], 164 | }, 165 | ) 166 | num_train_epochs: int = field( 167 | default=2, 168 | metadata={'help': 'Total number of training epochs to perform.'}, 169 | ) 170 | output_dir: str = field( 171 | default='output/', 172 | metadata={ 173 | 'help': 'The output directory where the model predictions and checkpoints will be written.' 174 | }, 175 | ) 176 | per_device_train_batch_size: int = field( 177 | default=8, 178 | metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}, 179 | ) 180 | use_8bit_optimizer: bool = field( 181 | default=False, 182 | metadata={ 183 | 'help': 'Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed.' 184 | }, 185 | ) 186 | warmup_ratio: float = field( 187 | default=0.03, 188 | metadata={'help': 'Linear warmup over warmup_ratio fraction of total steps.'}, 189 | ) 190 | weight_decay: float = field( 191 | default=0.0, 192 | metadata={'help': 'Weight decay for AdamW if we apply some.'}, 193 | ) 194 | timeout: int = field( 195 | default=1800, 196 | metadata={ 197 | 'help': 'Timeout for the training process in seconds.' 198 | 'Useful if tokenization process is long. Default is 1800 seconds (30 minutes).' 199 | }, 200 | ) 201 | reduce_loss: str = field( 202 | default='mean', 203 | metadata={ 204 | 'help': "How to reduce loss over tokens. Options are 'mean' or 'sum'." 205 | "Using 'sum' can improve chat model performance." 
206 | }, 207 | ) 208 | resume_from_checkpoint: Optional[str] = field( 209 | default=None, 210 | metadata={'help': 'If the training should continue from a checkpoint folder.'}, 211 | ) 212 | enable_wandb: bool = field( 213 | default=False, 214 | metadata={'help': 'Whether to enable wandb for logging.'}, 215 | ) 216 | wandb_entity: Optional[str] = field( 217 | default=None, 218 | metadata={'help': 'Entity to use for logging to wandb.'}, 219 | ) 220 | wandb_project: Optional[str] = field( 221 | default='test-runs', 222 | metadata={'help': 'Project name to use when logging to wandb.'}, 223 | ) 224 | wandb_name: Optional[str] = field( 225 | default='wandb', 226 | metadata={'help': 'Run name to use when logging to wandb.'}, 227 | ) 228 | gradient_checkpointing: bool = field( 229 | default=False, 230 | metadata={ 231 | 'help': 'Turn on gradient checkpointing. Saves memory but slows training.' 232 | }, 233 | ) 234 | max_train_steps: Optional[int] = field( 235 | default=None, 236 | metadata={ 237 | 'help': 'If set, overrides the number of training steps. Otherwise, num_train_epochs is used.' 238 | }, 239 | ) 240 | seed: int = field( 241 | default=42, 242 | metadata={'help': 'Random seed for initialization and dataset shuffling.'}, 243 | ) 244 | checkpointing_steps: Optional[str] = field( 245 | default=None, 246 | metadata={ 247 | 'help': "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa 248 | }, 249 | ) 250 | validation_steps: Optional[int] = field( 251 | default=None, 252 | metadata={ 253 | 'help': 'Compute loss on validation data at the end of every n steps' 254 | }, 255 | ) 256 | keep_last_n_checkpoints: Optional[int] = field( 257 | default=None, 258 | metadata={ 259 | 'help': 'How many checkpoints to keep in the output directory. -1 for all.' 260 | }, 261 | ) 262 | overwrite_output_dir: bool = field( 263 | default=False, 264 | metadata={ 265 | 'help': 'Overwrite the content of the output directory. Means that resumption will always start from scratch.' 266 | }, 267 | ) 268 | fused_optimizer: bool = field( 269 | default=True, 270 | metadata={ 271 | 'help': 'Whether to use fused AdamW or not.', 272 | }, 273 | ) 274 | 275 | def __post_init__(self): 276 | if self.model_name_or_path is None: 277 | raise ValueError('The Hugging Face model name or path is not indicated.') 278 | 279 | if self.tokenizer_name_or_path is None: 280 | warnings.warn('The tokenizer name or path is not indicated. Defaulting it to model_name_or_path.') 281 | self.tokenizer_name_or_path = self.model_name_or_path 282 | 283 | if self.reduce_loss not in ['mean', 'sum']: 284 | raise ValueError("reduce_loss must be either 'mean' or 'sum'") 285 | 286 | if self.train_file is None: 287 | raise ValueError('Need either a dataset name, dataset mixer, or a training file.') 288 | else: 289 | extension = self.train_file.split('.')[-1] 290 | assert extension in [ 291 | 'json', 292 | 'jsonl', 293 | ], '`train_file` should be a json or a jsonl file.' 294 | 295 | if self.validation_steps and not self.validation_file: 296 | raise ValueError( 297 | "The number of steps for every validation is indicated. However, the path to the validation dataset is not provided. Please provide it with '--validation_file'." 298 | ) 299 | 300 | if self.validation_file and not self.validation_steps: 301 | warnings.warn( 302 | 'The path to the validation dataset is provided. However, the number of steps for every validation is not indicated. Defaulted to 1.' 
303 | ) 304 | self.validation_steps = 1 -------------------------------------------------------------------------------- /finetunerag/finetune.py: -------------------------------------------------------------------------------- 1 | # Finetuning code referenced from https://github.com/allenai/open-instruct 2 | 3 | import logging 4 | import math 5 | import os 6 | import random 7 | from datetime import timedelta 8 | from functools import partial 9 | 10 | import datasets 11 | import deepspeed 12 | import torch 13 | import transformers 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import InitProcessGroupKwargs, set_seed 17 | from datasets import load_dataset 18 | from torch.utils.data import DataLoader 19 | from tqdm.auto import tqdm 20 | from transformers import ( 21 | AutoConfig, 22 | AutoModelForCausalLM, 23 | AutoTokenizer, 24 | DataCollatorForSeq2Seq, 25 | get_scheduler, 26 | ) 27 | print("Current Working Directory:", os.getcwd()) 28 | 29 | from prompts.prompt_styles import PromptStyle 30 | from guardrag.arguments import FinetuneArguments 31 | from guardrag.model_utils import save_with_accelerate 32 | from guardrag.utils import ( 33 | ArgumentParserPlus, 34 | clean_last_n_checkpoints, 35 | get_last_checkpoint_path, 36 | get_wandb_tags, 37 | ) 38 | 39 | logger = get_logger(__name__) 40 | 41 | def main(args: FinetuneArguments): 42 | 43 | ########################## 44 | # Initialise Accelerator # 45 | ########################## 46 | accelerator_log_kwargs = {} 47 | 48 | if args.enable_wandb: 49 | accelerator_log_kwargs['log_with'] = 'wandb' 50 | accelerator_log_kwargs['project_dir'] = args.output_dir 51 | 52 | # If you get timeouts (e.g. due to long tokenization) increase this. 53 | timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout)) 54 | 55 | accelerator = Accelerator( 56 | gradient_accumulation_steps=args.gradient_accumulation_steps, 57 | use_seedable_sampler=True, 58 | **accelerator_log_kwargs, 59 | kwargs_handlers=[timeout_kwargs], 60 | ) 61 | 62 | if args.seed is not None: 63 | set_seed(args.seed) 64 | 65 | if args.enable_wandb: 66 | experiment_config = vars(args) 67 | 68 | accelerator.init_trackers( 69 | args.wandb_project, 70 | experiment_config, 71 | init_kwargs={ 72 | 'wandb': { 73 | 'entity': args.wandb_entity, 74 | 'name': args.wandb_name, 75 | 'tags': [args.exp_name] + get_wandb_tags(), 76 | } 77 | }, 78 | ) 79 | 80 | ##################### 81 | # Configure Logging # 82 | ##################### 83 | logging.basicConfig( 84 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 85 | datefmt='%m/%d/%Y %H:%M:%S', 86 | level=logging.INFO, 87 | ) 88 | logger.info(accelerator.state, main_process_only=False) 89 | if accelerator.is_local_main_process: 90 | datasets.utils.logging.set_verbosity_warning() 91 | transformers.utils.logging.set_verbosity_info() 92 | else: 93 | datasets.utils.logging.set_verbosity_error() 94 | transformers.utils.logging.set_verbosity_error() 95 | 96 | accelerator.wait_for_everyone() 97 | 98 | ########################################## 99 | # Load Pretrained HF Model and Tokenizer # 100 | ########################################## 101 | config = AutoConfig.from_pretrained( 102 | args.model_name_or_path, 103 | trust_remote_code=args.trust_remote_code, 104 | revision=args.model_revision, 105 | token=os.getenv('HF_TOKEN', None), 106 | ) 107 | 108 | tokenizer_revision = ( 109 | args.model_revision 110 | if args.tokenizer_revision is None 111 | else 
args.tokenizer_revision 112 | ) 113 | 114 | if tokenizer_revision != args.model_revision: 115 | # Warn user if tokenizer and model use different revisions; this is an unusual use case. 116 | warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different 117 | from the model revision `{args.model_revision}`.""" 118 | logger.warn(warning) 119 | 120 | tokenizer = AutoTokenizer.from_pretrained( 121 | args.tokenizer_name_or_path, 122 | trust_remote_code=args.trust_remote_code, 123 | revision=tokenizer_revision, 124 | token=os.getenv('HF_TOKEN', None), 125 | ) 126 | 127 | if args.model_name_or_path: 128 | model = AutoModelForCausalLM.from_pretrained( 129 | args.model_name_or_path, 130 | from_tf=bool('.ckpt' in args.model_name_or_path), 131 | config=config, 132 | trust_remote_code=args.trust_remote_code, 133 | low_cpu_mem_usage=args.low_cpu_mem_usage, 134 | torch_dtype=torch.bfloat16, 135 | attn_implementation='flash_attention_2' 136 | if args.use_flash_attn 137 | else 'eager', 138 | revision=args.model_revision, 139 | token=os.getenv('HF_TOKEN', None), 140 | ) 141 | else: 142 | logger.info('Training new model from scratch') 143 | model = AutoModelForCausalLM.from_config(config) 144 | 145 | 146 | ###################### 147 | # Embedding Resizing # 148 | ###################### 149 | 150 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 151 | # gather deepspeed to get 'real' embedding size 152 | embeddings = model.get_input_embeddings() 153 | with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): 154 | embedding_size = embeddings.weight.shape[0] 155 | # resize does its own gather 156 | if len(tokenizer) > embedding_size: 157 | # pad to multiple for tensor cores. 158 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) 159 | # update embedding size after resizing for sum loss 160 | embeddings = model.get_input_embeddings() 161 | with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): 162 | embedding_size = embeddings.weight.shape[0] 163 | 164 | 165 | ####################################### 166 | # Prepare Dataset & Set up Dataloader # 167 | ####################################### 168 | data_files = {} 169 | dataset_args = {} 170 | 171 | if args.train_file: 172 | data_files['train'] = args.train_file 173 | if args.validation_file: 174 | data_files['validation'] = args.validation_file 175 | 176 | raw_datasets = load_dataset( 177 | 'json', 178 | data_files=data_files, 179 | **dataset_args, 180 | ) 181 | 182 | train_dataset = raw_datasets['train'] 183 | validation_dataset = raw_datasets.get('validation', None) 184 | 185 | if 'messages' not in train_dataset.column_names: 186 | raise ValueError("You need to have 'messages' in your training data.") 187 | if validation_dataset and 'messages' not in validation_dataset.column_names: 188 | raise ValueError("You need to have 'messages' in your validation data.") 189 | 190 | # Limit training samples. Used for debugging. 
191 | if args.max_train_samples is not None: 192 | max_train_samples = min(len(train_dataset), args.max_train_samples) 193 | logger.info(f'Limiting training samples to {max_train_samples} from {len(train_dataset)}.') 194 | train_dataset = train_dataset.select(range(max_train_samples)) 195 | 196 | encode_function = partial( 197 | encode_messages, 198 | tokenizer=tokenizer, 199 | max_seq_length=args.max_seq_length, 200 | prompt_style=args.prompt_style, 201 | add_bos=args.add_bos, 202 | ) 203 | 204 | with accelerator.main_process_first(): 205 | train_dataset = train_dataset.map( 206 | encode_function, 207 | batched=False, 208 | num_proc=args.preprocessing_num_workers, 209 | load_from_cache_file=not args.overwrite_cache, 210 | remove_columns=[ 211 | name 212 | for name in train_dataset.column_names 213 | if name not in ['input_ids', 'labels', 'attention_mask'] 214 | ], 215 | desc='Tokenizing and reformatting instruction data', 216 | ) 217 | 218 | train_dataset.set_format(type='pt') 219 | train_dataset = train_dataset.filter( 220 | lambda example: (example['labels'] != -100).any() 221 | ) 222 | 223 | # Log a few random samples from the training set. 224 | for index in random.sample(range(len(train_dataset)), 3): 225 | logger.info(f'Sample {index} of the training set: {train_dataset[index]}.') 226 | 227 | train_dataloader = DataLoader( 228 | train_dataset, 229 | shuffle=True, 230 | collate_fn=DataCollatorForSeq2Seq( 231 | tokenizer=tokenizer, model=model, padding='longest' 232 | ), 233 | batch_size=args.per_device_train_batch_size, 234 | ) 235 | 236 | if validation_dataset: 237 | validation_dataset = validation_dataset.map( 238 | encode_function, 239 | batched=False, 240 | num_proc=args.preprocessing_num_workers, 241 | load_from_cache_file=not args.overwrite_cache, 242 | remove_columns=[ 243 | name 244 | for name in validation_dataset.column_names 245 | if name not in ['input_ids', 'labels', 'attention_mask'] 246 | ], 247 | desc='Tokenizing and reformatting validation data', 248 | ) 249 | validation_dataset.set_format(type='pt') 250 | validation_dataset = validation_dataset.filter(lambda example: (example['labels'] != -100).any()) 251 | 252 | validation_dataloader = DataLoader( 253 | validation_dataset, 254 | shuffle=True, 255 | collate_fn=DataCollatorForSeq2Seq( 256 | tokenizer=tokenizer, model=model, padding='longest' 257 | ), 258 | batch_size=args.per_device_train_batch_size, 259 | ) 260 | 261 | 262 | ############################## 263 | # Optimizer and LR Scheduler # 264 | ############################## 265 | 266 | # Split weights in two groups, one with weight decay and the other not. 267 | no_decay = ['bias', 'layer_norm.weight'] 268 | optimizer_grouped_parameters = [ 269 | { 270 | 'params': [ 271 | p 272 | for n, p in model.named_parameters() 273 | if not any(nd in n for nd in no_decay) 274 | ], 275 | 'weight_decay': args.weight_decay, 276 | }, 277 | { 278 | 'params': [ 279 | p 280 | for n, p in model.named_parameters() 281 | if any(nd in n for nd in no_decay) 282 | ], 283 | 'weight_decay': 0.0, 284 | }, 285 | ] 286 | 287 | optimizer = torch.optim.AdamW( 288 | optimizer_grouped_parameters, 289 | lr=args.learning_rate, 290 | fused=args.fused_optimizer, 291 | betas=(args.beta1, args.beta2), 292 | ) 293 | 294 | # Scheduler and math around the number of training steps. 
295 | overrode_max_train_steps = False 296 | num_update_steps_per_epoch = math.ceil( 297 | len(train_dataloader) / args.gradient_accumulation_steps 298 | ) 299 | if args.max_train_steps is None: 300 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 301 | overrode_max_train_steps = True 302 | 303 | # Create the learning rate scheduler. 304 | # Note: the current accelerator.step() calls the .step() of the real scheduler 305 | # for the `num_processes` times. This is because they assume 306 | # the user initialize the scheduler with the entire training set. 307 | # In the case of data parallel training, each process only 308 | # sees a subset (1/num_processes) of the training set. 309 | # So each time the process needs to update the lr multiple times so that the total 310 | # number of updates in the end matches the num_training_steps here. 311 | # Here we need to set the num_training_steps to either using the 312 | # entire training set (when epochs is specified) or we need to multiply the 313 | # num_training_steps by num_processes so that the total number of 314 | # updates matches the num_training_steps. 315 | num_training_steps_for_scheduler = ( 316 | args.max_train_steps 317 | if overrode_max_train_steps 318 | else args.max_train_steps * accelerator.num_processes 319 | ) 320 | lr_scheduler = get_scheduler( 321 | name=args.lr_scheduler_type, 322 | optimizer=optimizer, 323 | num_training_steps=num_training_steps_for_scheduler, 324 | num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio), 325 | ) 326 | 327 | ################# 328 | # Miscellaneous # 329 | ################# 330 | 331 | if accelerator.is_main_process and args.output_dir is not None: 332 | os.makedirs(args.output_dir, exist_ok=True) 333 | 334 | validation_steps = args.validation_steps 335 | if validation_steps is not None: 336 | validation_steps = int(validation_steps) 337 | 338 | if args.gradient_checkpointing: 339 | model.gradient_checkpointing_enable() 340 | 341 | # Prepare everything with `accelerator`. 342 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 343 | model, optimizer, train_dataloader, lr_scheduler 344 | ) 345 | 346 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 347 | num_update_steps_per_epoch = math.ceil( 348 | len(train_dataloader) / args.gradient_accumulation_steps 349 | ) 350 | if overrode_max_train_steps: 351 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 352 | # Afterwards we recalculate our number of training epochs 353 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 354 | 355 | # Figure out how many steps we should save the Accelerator states 356 | checkpointing_steps = args.checkpointing_steps 357 | if checkpointing_steps is not None and str(checkpointing_steps).lower() != 'epoch': 358 | checkpointing_steps = int(checkpointing_steps) 359 | 360 | 361 | ############## 362 | # Finetuning # 363 | ############## 364 | 365 | total_batch_size = ( 366 | args.per_device_train_batch_size 367 | * accelerator.num_processes 368 | * args.gradient_accumulation_steps 369 | ) 370 | 371 | logger.info('***** Running training *****') 372 | logger.info(f' Num examples = {len(train_dataset)}') 373 | logger.info(f' Num Epochs = {args.num_train_epochs}') 374 | logger.info(f' Instantaneous batch size per device = {args.per_device_train_batch_size}') 375 | logger.info(f' Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}') 376 | logger.info(f' Gradient Accumulation steps = {args.gradient_accumulation_steps}') 377 | logger.info(f' Total optimization steps = {args.max_train_steps}') 378 | # Only show the progress bar once on each machine. 379 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 380 | completed_steps = 0 381 | starting_epoch = 0 382 | 383 | # Potentially load in the weights and states from a previous save 384 | last_checkpoint_path = get_last_checkpoint_path(args) 385 | if last_checkpoint_path: 386 | accelerator.print(f'Resumed from checkpoint: {last_checkpoint_path}') 387 | accelerator.load_state(last_checkpoint_path) 388 | # Extract `epoch_{i}` or `step_{i}` 389 | last_checkpoint_path = os.path.basename(last_checkpoint_path) 390 | training_difference = os.path.splitext(last_checkpoint_path)[0] 391 | 392 | if 'epoch' in training_difference: 393 | starting_epoch = int(training_difference.replace('epoch_', '')) + 1 394 | resume_step = None 395 | completed_steps = starting_epoch * num_update_steps_per_epoch 396 | else: 397 | # need to multiply `gradient_accumulation_steps` to reflect real steps 398 | resume_step = (int(training_difference.replace('step_', '')) * args.gradient_accumulation_steps) 399 | starting_epoch = resume_step // len(train_dataloader) 400 | completed_steps = resume_step // args.gradient_accumulation_steps 401 | resume_step -= starting_epoch * len(train_dataloader) 402 | 403 | print(f'Starting from epoch {starting_epoch} and step {completed_steps}.') 404 | # update the progress_bar if load from checkpoint 405 | progress_bar.update(completed_steps) 406 | 407 | for epoch in range(starting_epoch, args.num_train_epochs): 408 | model.train() 409 | train_dataloader.set_epoch(epoch) 410 | total_loss = 0 411 | if last_checkpoint_path and resume_step is not None: 412 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint 413 | active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) 414 | else: 415 | active_dataloader = train_dataloader 416 | for batch in active_dataloader: 417 | with accelerator.accumulate(model): 418 | loss = calculate_loss_and_backpropagate( 419 | model, 420 | batch, 421 | accelerator, 422 | optimizer, 423 | lr_scheduler, 424 | embedding_size, 425 | args 426 | ) 427 | # We keep track of the loss at each logged step 428 | total_loss += loss.detach().float() 429 | 430 | # Checks if the accelerator has performed an optimization step behind the scenes 431 | if accelerator.sync_gradients: 432 | progress_bar.update(1) 433 | completed_steps += 1 434 | 435 | # Perform Validation (only done by main) 436 | avg_val_loss = None 437 | if ( 438 | accelerator.is_local_main_process 439 | and validation_steps 440 | and completed_steps % validation_steps == 0 441 | ): 442 | if completed_steps % validation_steps == 0: 443 | model.eval() 444 | full_val_loss = 0 445 | num_val_batches = 0 446 | with torch.no_grad(): 447 | for val_batch in validation_dataloader: 448 | val_batch = { 449 | key: value.to(accelerator.device) 450 | for key, value in val_batch.items() 451 | } 452 | val_loss = calculate_loss(model, val_batch, embedding_size, args) 453 | 454 | full_val_loss += val_loss.detach().float() 455 | num_val_batches += 1 456 | 457 | avg_val_loss = full_val_loss / num_val_batches 458 | model.train() 459 | 460 | if args.logging_steps and completed_steps % args.logging_steps == 0: 461 | avg_loss = ( 462 | 
accelerator.gather(total_loss).mean().item() 463 | / args.gradient_accumulation_steps 464 | / args.logging_steps 465 | ) 466 | 467 | val_loss_log = f', Val Loss: {avg_val_loss}' if avg_val_loss else '' 468 | logger.info(f' Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Train Loss: {avg_loss}{val_loss_log}') 469 | 470 | if args.enable_wandb: 471 | log_data = { 472 | 'learning_rate': lr_scheduler.get_last_lr()[0], 473 | 'train_loss': avg_loss, 474 | } 475 | if validation_steps: 476 | log_data['validation_loss'] = avg_val_loss 477 | accelerator.log(log_data, step=completed_steps) 478 | 479 | total_loss = 0 480 | 481 | if isinstance(checkpointing_steps, int): 482 | if completed_steps % checkpointing_steps == 0: 483 | output_dir = f'step_{completed_steps}' 484 | if args.output_dir is not None: 485 | output_dir = os.path.join(args.output_dir, output_dir) 486 | accelerator.save_state(output_dir) 487 | # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints 488 | with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), 'COMPLETED'), 'w') as f: 489 | f.write('COMPLETED') 490 | 491 | if (accelerator.is_local_main_process and args.keep_last_n_checkpoints): 492 | clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) 493 | accelerator.wait_for_everyone() 494 | 495 | if completed_steps >= args.max_train_steps: 496 | break 497 | 498 | if checkpointing_steps == 'epoch': 499 | output_dir = f'epoch_{epoch}' 500 | if args.output_dir is not None: 501 | output_dir = os.path.join(args.output_dir, output_dir) 502 | # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints 503 | with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), 'COMPLETED'), 'w') as f: 504 | f.write('COMPLETED') 505 | 506 | if accelerator.is_local_main_process and args.keep_last_n_checkpoints: 507 | clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints) 508 | accelerator.wait_for_everyone() 509 | 510 | if args.output_dir is not None: 511 | save_with_accelerate( 512 | accelerator, 513 | model, 514 | tokenizer, 515 | args.output_dir, 516 | ) 517 | 518 | accelerator.wait_for_everyone() 519 | if args.enable_wandb: 520 | accelerator.end_training() 521 | 522 | 523 | def encode_messages( 524 | example, 525 | tokenizer, 526 | max_seq_length, 527 | prompt_style, 528 | add_bos=False 529 | ): 530 | """ 531 | Here we assume each example has a 'messages' field. 532 | 'messages' is a list of messages. 533 | Each message is a dict with 'role' and 'content' fields. 534 | We concatenate all messages with the roles as delimiters and tokenize them together. 535 | """ 536 | messages = example['messages'] 537 | if len(messages) == 0: 538 | raise ValueError('messages field is empty.') 539 | 540 | style = PromptStyle.from_name(prompt_style) 541 | 542 | # To support multi-turn, we need to mask all non-assistant messages. 
543 | # To do this, we compute the length of each tokenized non-assistant messages then mask it 544 | segmented_prompts = [] 545 | segmented_prompts_and_responses = [] 546 | start_idx, end_idx = 0, 0 547 | while end_idx < len(messages): 548 | while end_idx < len(messages) and messages[end_idx]['role'] != 'assistant': 549 | end_idx += 1 550 | if start_idx <= end_idx: 551 | if start_idx == 0: 552 | # expect system prompt 553 | segmented_prompts.append( 554 | style.apply( 555 | messages[start_idx:end_idx], append_assistant_header=True 556 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx], tokenize=False, add_generation_prompt=True) 557 | ) 558 | segmented_prompts_and_responses.append( 559 | style.apply( 560 | messages[start_idx : end_idx + 1], append_assistant_header=False 561 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx + 1], tokenize=False, add_generation_prompt=False) 562 | ) 563 | else: 564 | # should not have system prompt for subsequent turns 565 | segmented_prompts.append( 566 | style.apply( 567 | messages[start_idx:end_idx], 568 | no_system=True, 569 | append_assistant_header=True, 570 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx], tokenize=False, add_generation_prompt=True) 571 | ) 572 | segmented_prompts_and_responses.append( 573 | style.apply( 574 | messages[start_idx : end_idx + 1], 575 | no_system=True, 576 | append_assistant_header=False, 577 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx + 1], tokenize=False, add_generation_prompt=False) 578 | ) 579 | start_idx = end_idx + 1 580 | end_idx += 1 # should be same as start_idx 581 | 582 | if add_bos: 583 | # add bos token to the first prompt 584 | segmented_prompts[0] = tokenizer.bos_token + segmented_prompts[0] 585 | segmented_prompts_and_responses[0] = ( 586 | tokenizer.bos_token + segmented_prompts_and_responses[0] 587 | ) 588 | encoded_segmented_prompts = list( 589 | map( 590 | lambda prompt: tokenizer( 591 | prompt, return_tensors='pt', max_length=max_seq_length, truncation=True 592 | ).input_ids.flatten(), 593 | segmented_prompts, 594 | ) 595 | ) 596 | encoded_segmented_prompts_and_responses = list( 597 | map( 598 | lambda prompt_and_response: tokenizer( 599 | prompt_and_response, 600 | return_tensors='pt', 601 | max_length=max_seq_length, 602 | truncation=True, 603 | ).input_ids.flatten(), 604 | segmented_prompts_and_responses, 605 | ) 606 | ) 607 | 608 | # Achieve the same effect as 'masking' by simply using ignore_index 609 | masked_labels = [] 610 | num_split = len(encoded_segmented_prompts) 611 | for i in range(num_split): 612 | encoded_prompt = encoded_segmented_prompts[i] 613 | encoded_prompt_and_response = encoded_segmented_prompts_and_responses[i] 614 | label = encoded_prompt_and_response.clone() 615 | label[: len(encoded_prompt)] = 0 616 | masked_labels.append(label) 617 | 618 | # concatenate the segments 619 | encoded_prompts_and_responses = torch.cat(encoded_segmented_prompts_and_responses) 620 | labels = torch.cat(masked_labels) 621 | attention_mask = torch.ones_like(encoded_prompts_and_responses) 622 | 623 | return { 624 | 'input_ids': encoded_prompts_and_responses, 625 | 'labels': labels, 626 | 'attention_mask': attention_mask, 627 | } 628 | 629 | def calculate_loss(model, batch, embedding_size, args): 630 | outputs = model(**batch, use_cache=False) 631 | if args.reduce_loss == 'mean': 632 | loss = outputs.loss 633 | else: 634 | # reduce loss is sum 635 | # this ensures that we weight all tokens in 
the dataset equally, 636 | # rather than weighting each overall example equally when 637 | # using high amounts of gradient accumulation. 638 | # this can result in > 5 point improvements in AlpacaEval 639 | # see https://github.com/huggingface/transformers/issues/24725 for 640 | # more discussion and details. 641 | logits = outputs.logits 642 | labels = batch['labels'] 643 | # Shift so that tokens < n predict n 644 | shift_logits = logits[..., :-1, :].contiguous() 645 | shift_labels = labels[..., 1:].contiguous() 646 | # Flatten the tokens 647 | loss_fct = torch.nn.CrossEntropyLoss(reduction='sum') 648 | shift_logits = shift_logits.view(-1, embedding_size) 649 | shift_labels = shift_labels.view(-1) 650 | # Enable model parallelism 651 | shift_labels = shift_labels.to(shift_logits.device) 652 | loss = loss_fct(shift_logits, shift_labels) 653 | 654 | return loss 655 | 656 | def calculate_loss_and_backpropagate( 657 | model, 658 | batch, 659 | accelerator, 660 | optimizer, 661 | lr_scheduler, 662 | embedding_size, 663 | args 664 | ): 665 | loss = calculate_loss(model, batch, embedding_size, args) 666 | accelerator.backward(loss) 667 | # clip gradient norm. don't do this with deepspeed 668 | if accelerator.sync_gradients and args.clip_grad_norm > 0: 669 | accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 670 | optimizer.step() 671 | optimizer.zero_grad() 672 | lr_scheduler.step() 673 | 674 | return loss 675 | 676 | if __name__ == '__main__': 677 | parser = ArgumentParserPlus(FinetuneArguments) 678 | args = parser.parse() 679 | main(args) 680 | -------------------------------------------------------------------------------- /finetunerag/model_utils.py: -------------------------------------------------------------------------------- 1 | # Taken and modified from https://github.com/huggingface/trl 2 | # Copyright 2022 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import itertools 18 | from contextlib import contextmanager 19 | from dataclasses import dataclass 20 | from typing import List, Literal, Optional, Tuple, Union 21 | 22 | from guardrag.utils import retry_on_exception 23 | 24 | try: 25 | import deepspeed 26 | from deepspeed.runtime.engine import DeepSpeedEngine 27 | except ImportError: 28 | pass 29 | import pandas as pd 30 | import torch 31 | import transformers 32 | from accelerate import Accelerator 33 | from accelerate.state import AcceleratorState 34 | from huggingface_hub import HfApi 35 | from rich import print as rprint 36 | from rich.console import Console 37 | from rich.table import Table 38 | from rich.text import Text 39 | from torch.nn.parallel.distributed import DistributedDataParallel 40 | from transformers import PreTrainedModel, PreTrainedTokenizer 41 | 42 | 43 | @dataclass 44 | class ModelConfig: 45 | model_name_or_path: Optional[str] = None 46 | """The model checkpoint for weights initialization.""" 47 | model_revision: str = "main" 48 | """The specific model version to use (can be a branch name, tag name or commit id).""" 49 | trust_remote_code: bool = False 50 | """Trust remote code when loading a model.""" 51 | torch_dtype: Optional[str] = None 52 | """Override the default `torch.dtype` and load the model under this dtype.""" 53 | attn_implementation: Optional[Literal["flash_attention_2"]] = None 54 | """Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case 55 | you must install this manually by running `pip install flash-attn --no-build-isolation`""" 56 | use_cache: Optional[bool] = None 57 | """Whether to use cache in the model.""" 58 | gradient_checkpointing: Optional[bool] = None 59 | """Whether to use gradient checkpointing in the model.""" 60 | 61 | # PEFT-related args 62 | use_peft: bool = False 63 | """Whether to use PEFT or not for training.""" 64 | lora_r: Optional[int] = 16 65 | """LoRA R value.""" 66 | lora_alpha: Optional[int] = 32 67 | """LoRA alpha.""" 68 | lora_dropout: Optional[float] = 0.05 69 | """LoRA dropout.""" 70 | lora_target_modules: Optional[List[str]] = None 71 | """LoRA target modules.""" 72 | lora_modules_to_save: Optional[List[str]] = None 73 | """Model layers to unfreeze & train""" 74 | lora_task_type: str = "CAUSAL_LM" 75 | """The task_type to pass for LoRA (use SEQ_CLS for reward modeling)""" 76 | 77 | # quantization args 78 | load_in_8bit: bool = False 79 | """use 8 bit precision for the base model - works only with LoRA""" 80 | load_in_4bit: bool = False 81 | """use 4 bit precision for the base model - works only with LoRA""" 82 | bnb_4bit_quant_type: Optional[str] = "nf4" 83 | """precise the quantization type (fp4 or nf4)""" 84 | use_bnb_nested_quant: bool = False 85 | """use nested quantization""" 86 | 87 | def __post_init__(self): 88 | # `use_cache=True` is incompatible with gradient checkpointing. 
89 | # https://github.com/huggingface/transformers/blob/d6751d91c8f58cdeb35af6adae182d7dc90aa883/src/transformers/models/llama/modeling_llama.py#L945 90 | if self.gradient_checkpointing: 91 | self.use_cache = False 92 | 93 | 94 | # ---------------------------------------------------------------------------- 95 | # Model utilities; reward model stuff 96 | def disable_dropout_in_model(model: torch.nn.Module) -> None: 97 | for module in model.modules(): 98 | if isinstance(module, torch.nn.Dropout): 99 | module.p = 0 100 | 101 | 102 | def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor: 103 | """ 104 | Finds the index of the first `True` value in each row of a boolean tensor. If no `True` value exists in a row, 105 | it returns the length of the row. 106 | 107 | Args: 108 | bools (torch.Tensor): A boolean tensor of shape (batch_size, sequence_length), where `True` values indicate 109 | the positions of interest. 110 | dtype (torch.dtype): The data type to use for the output indices (default is torch.long). 111 | 112 | Returns: 113 | torch.Tensor: A tensor of shape (batch_size,) containing the index of the first `True` value in each row. 114 | If a row has no `True` value, the index will be the length of the row. 115 | """ 116 | 117 | # Get the length of each row (i.e., the number of columns in the last dimension) 118 | # row_len is a scalar representing the length of each sequence (sequence_length) 119 | row_len = bools.size(-1) 120 | 121 | # Calculate the index positions for the first `True` in each row 122 | # ~bools: Invert the boolean values (True becomes False and vice versa) 123 | # ~bools.type(dtype): Convert the inverted boolean tensor to the specified dtype (0 for True, 1 for False) 124 | # row_len * (~bools).type(dtype): For `False` values, this will give `row_len`, for `True` values it gives 0. 125 | # torch.arange(row_len, dtype=dtype, device=bools.device): Generates a tensor with values [0, 1, 2, ..., row_len-1] 126 | # for each row. Shape: (sequence_length,) 127 | # zero_or_index: Shape (batch_size, sequence_length). This tensor contains the indices for `True` values and `row_len` 128 | # for `False` values. 129 | zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) 130 | 131 | # Return the minimum value in each row (i.e., the first `True` index or `row_len` if none exist) 132 | # torch.min(zero_or_index, dim=-1).values: This returns the minimum value in each row, which corresponds to the first 133 | # `True` value's index or `row_len` if there is no `True` in that row. 134 | # The returned tensor has shape (batch_size,) 135 | return torch.min(zero_or_index, dim=-1).values 136 | 137 | 138 | def get_reward( 139 | model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int 140 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 141 | """ 142 | This function computes reward scores for a batch of query responses based on a pre-trained reward model. 143 | 144 | Args: 145 | model (torch.nn.Module): The pre-trained reward model. 146 | query_responses (torch.Tensor): Tensor containing the tokenized responses for which to compute rewards. 147 | Shape: (batch_size, sequence_length) 148 | pad_token_id (int): The ID used for padding tokens in the tokenized sequences. 149 | context_length (int): The length of the prompt or context preceding the completions. 
150 | 151 | Returns: 152 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: 153 | - reward_logits: The logits output from the model for all tokens in the sequences. 154 | Shape: (batch_size, sequence_length) 155 | - final_scores: The final reward scores, one for each sequence, after adjusting for sequence lengths. 156 | Shape: (batch_size,) 157 | - sequence_lengths: The lengths of each sequence (excluding padding). 158 | Shape: (batch_size,) 159 | """ 160 | 161 | # Create an attention mask where tokens that are not padding have a value of 1, and padding tokens have a value of 0 162 | # Shape: (batch_size, sequence_length) 163 | attention_mask = query_responses != pad_token_id 164 | 165 | # Calculate position IDs for each token, considering the cumulative sum of the attention mask (to exclude padding) 166 | # Shape: (batch_size, sequence_length) 167 | position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum 168 | 169 | # Access the LM backbone from the reward model using its base model prefix 170 | lm_backbone = getattr(model, model.base_model_prefix) 171 | 172 | # Replace padding tokens with zeros in the input IDs (so padding tokens won't affect the model's processing) 173 | # Shape: (batch_size, sequence_length) 174 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) 175 | output = lm_backbone( 176 | input_ids=input_ids, 177 | attention_mask=attention_mask, 178 | position_ids=position_ids, 179 | return_dict=True, 180 | output_hidden_states=True, 181 | use_cache=False, # otherwise mistral-based RM would error out 182 | ) 183 | reward_logits = model.score(output.hidden_states[-1]) # (batch_size, sequence_length) 184 | 185 | # Calculate the length of each sequence by finding the first occurrence of a padding token after the context 186 | # sequence_lengths shape: (batch_size,) 187 | sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length 188 | assert ( 189 | reward_logits.shape[-1] == 1 190 | ), "Reward model should output a single scalar per token. Check if you added `num_labels=1` when doing `AutoModelForSequenceClassification.from_pretrained(...)`." 191 | # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 192 | 193 | # Return the reward logits for all tokens, the final reward scores for each sequence, and the sequence lengths 194 | return ( 195 | # reward_logits shape: (batch_size, sequence_length) 196 | reward_logits, 197 | # final_scores shape: (batch_size,) 198 | reward_logits[ 199 | torch.arange(reward_logits.size(0), device=reward_logits.device), 200 | sequence_lengths, 201 | ].squeeze( 202 | -1 203 | ), # Shape: (batch_size,) 204 | sequence_lengths, 205 | ) 206 | 207 | 208 | def forward( 209 | model: torch.nn.Module, 210 | query_responses: torch.Tensor, 211 | pad_token_id: int, 212 | ) -> torch.nn.Module: 213 | """ 214 | Performs a forward pass through the model with the given query responses and pad token ID. 215 | Args: 216 | model (`torch.nn.Module`): 217 | The model to perform the forward pass. 218 | query_responses (`torch.Tensor`): 219 | The tensor containing the query responses. 220 | pad_token_id (`int`): 221 | The token ID representing the pad token. 222 | Returns: 223 | `torch.nn.Module`: 224 | The output of the model, including hidden states. 
225 | """ 226 | attention_mask = query_responses != pad_token_id 227 | position_ids = attention_mask.cumsum(1) - attention_mask.long() 228 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) 229 | return model( 230 | input_ids=input_ids, 231 | attention_mask=attention_mask, 232 | position_ids=position_ids, 233 | return_dict=True, 234 | output_hidden_states=True, 235 | ) 236 | 237 | 238 | def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Tensor): 239 | """ 240 | Truncates the responses at the first occurrence of the stop token, filling the rest with pad tokens. 241 | Args: 242 | stop_token_id (`int`): 243 | The token ID representing the stop token where truncation occurs. 244 | pad_token_id (`int`): 245 | The token ID representing the pad token used to fill the truncated responses. 246 | responses (`torch.Tensor`): 247 | The tensor containing the responses to be truncated. 248 | Returns: 249 | `torch.Tensor`: 250 | The truncated responses tensor with pad tokens filled after the stop token. 251 | """ 252 | trunc_idxs = first_true_indices(responses == stop_token_id).unsqueeze(-1) 253 | new_size = [1] * (len(responses.size()) - 1) + [responses.shape[1]] 254 | idxs = torch.arange(responses.shape[1], device=responses.device).view(*new_size) 255 | postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id) 256 | return postprocessed_responses 257 | 258 | 259 | def generate( 260 | lm_backbone: torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: dict 261 | ) -> Tuple[torch.Tensor, torch.Tensor]: 262 | """ 263 | Generates sequences from the language model backbone in a way that does not affect padding tokens. 264 | Args: 265 | lm_backbone (`torch.nn.Module`): 266 | The language model backbone used for generation. 267 | queries (`torch.Tensor`): 268 | The tensor containing the input queries. 269 | pad_token_id (`int`): 270 | The token ID representing the pad token. 271 | generation_config (`dict`): 272 | The configuration dictionary for generation settings. 273 | Returns: 274 | tuple: 275 | - `generated_sequences` (`torch.Tensor`): 276 | The concatenated tensor of input queries and generated sequences. 277 | - `logits` (`torch.Tensor`): 278 | The logits output from the generation process. 
279 | """ 280 | context_length = queries.shape[1] 281 | attention_mask = queries != pad_token_id 282 | input_ids = torch.masked_fill(queries, ~attention_mask, 0) 283 | output = lm_backbone.generate( 284 | input_ids=input_ids, 285 | attention_mask=attention_mask, 286 | # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # not needed: already adjusted in generations 287 | # https://github.com/huggingface/transformers/blob/ac33aeeeee2a7a89b89c93c2962e6feb90daef0a/src/transformers/models/gpt2/modeling_gpt2.py#L1227-L1250 288 | generation_config=generation_config, 289 | return_dict_in_generate=True, 290 | output_logits=True, 291 | ) 292 | logits = torch.stack(output.logits, 1) 293 | return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits 294 | 295 | 296 | @torch.no_grad() 297 | def batch_generation( 298 | model: torch.nn.Module, 299 | queries: torch.Tensor, 300 | local_rollout_forward_batch_size: int, 301 | pad_token_id: int, 302 | generation_config: dict, 303 | ): 304 | query_responses = [] 305 | logitss = [] 306 | for i in range(0, queries.shape[0], local_rollout_forward_batch_size): 307 | query = queries[i : i + local_rollout_forward_batch_size] 308 | query_response, logits = generate( 309 | model, 310 | query, 311 | pad_token_id, 312 | generation_config, 313 | ) 314 | query_responses.append(query_response) 315 | logitss.append(logits) 316 | return torch.cat(query_responses, 0), torch.cat(logitss, 0) 317 | 318 | 319 | def save_with_accelerate( 320 | accelerator: Accelerator, 321 | model: torch.nn.Module, 322 | tokenizer: PreTrainedTokenizer, 323 | output_dir: str, 324 | use_lora: bool = False, 325 | ) -> None: 326 | # set the generation config to an empty setting to be safe. 327 | # we usually do greedy decoding for generation, so this should be okay. 328 | # otherwise, we get an error thrown at save time. 329 | model.generation_config = transformers.GenerationConfig( 330 | temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id 331 | ) 332 | 333 | unwrapped_model: PreTrainedModel = accelerator.unwrap_model(model) 334 | # When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict. 335 | # Otherwise, sometimes the model will be saved with only part of the parameters. 336 | # Also, accelerator needs to use the wrapped model to get the state_dict. 337 | state_dict = accelerator.get_state_dict(model) 338 | if use_lora: 339 | # When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process 340 | # and has its own save_pretrained function for only saving lora modules. 341 | # We have to manually specify the is_main_process outside the save_pretrained function. 
342 | if accelerator.is_main_process: 343 | unwrapped_model.save_pretrained(output_dir, state_dict=state_dict) 344 | else: 345 | # don't use safetensors for saving for now 346 | unwrapped_model.save_pretrained( 347 | output_dir, 348 | is_main_process=accelerator.is_main_process, 349 | save_function=accelerator.save, 350 | state_dict=state_dict, 351 | safe_serialization=False, 352 | ) 353 | 354 | if accelerator.is_main_process: 355 | tokenizer.save_pretrained(output_dir) 356 | # customize model card (TODO (Costa): this can be prettier) 357 | 358 | 359 | @retry_on_exception() 360 | def push_folder_to_hub( 361 | accelerator: Accelerator, 362 | output_dir: str, 363 | hf_repo_id: Optional[str] = None, 364 | hf_repo_revision: Optional[str] = None, 365 | private: bool = True, 366 | ): 367 | if accelerator.is_main_process: 368 | hf_repo_url = f"https://huggingface.co/{hf_repo_id}/tree/{hf_repo_revision}" 369 | api = HfApi() 370 | if not api.repo_exists(hf_repo_id): 371 | api.create_repo(hf_repo_id, exist_ok=True, private=private) 372 | if hf_repo_revision is not None: 373 | api.create_branch(repo_id=hf_repo_id, branch=hf_repo_revision, exist_ok=True) 374 | api.upload_folder( 375 | repo_id=hf_repo_id, 376 | revision=hf_repo_revision, 377 | folder_path=output_dir, 378 | commit_message="upload checkpoint", 379 | run_as_future=False, 380 | ) 381 | print(f"🔥 pushed to {hf_repo_url}") 382 | 383 | 384 | # ---------------------------------------------------------------------------- 385 | # DeepSpeed utilities 386 | def get_all_parameters(sub_module, recurse=False): 387 | return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) 388 | 389 | 390 | def iter_params(module, recurse=False): 391 | return [param for _, param in get_all_parameters(module, recurse)] 392 | 393 | 394 | def remove_hooks(model: "DeepSpeedEngine") -> None: 395 | """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.""" 396 | if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): 397 | optimizer_offload = model.optimizer.parameter_offload 398 | elif model.optimizer is not None: 399 | optimizer_offload = model.optimizer 400 | 401 | for param in iter_params(optimizer_offload.module, recurse=True): 402 | param.ds_active_sub_modules.clear() 403 | 404 | for hook in optimizer_offload.forward_hooks: 405 | hook.remove() 406 | for hook in optimizer_offload.backward_hooks: 407 | hook.remove() 408 | 409 | optimizer_offload.forward_hooks = [] 410 | optimizer_offload.backward_hooks = [] 411 | 412 | 413 | def add_hooks(model: "DeepSpeedEngine") -> None: 414 | """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.""" 415 | if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): 416 | optimizer_offload = model.optimizer.parameter_offload 417 | elif model.optimizer is not None: 418 | optimizer_offload = model.optimizer 419 | optimizer_offload._register_hooks_recursively(optimizer_offload.module) 420 | 421 | 422 | @contextmanager 423 | def unwrap_model_for_generation( 424 | model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False 425 | ) -> Union["transformers.PreTrainedModel", "DeepSpeedEngine"]: 426 | """Context manager to unwrap a model for generation. 427 | For ZeRO-3 models, we gather the weights once to speed up generation. 
428 | """ 429 | unwrapped_model = accelerator.unwrap_model(model) 430 | if is_peft_model: 431 | unwrapped_model.pretrained_model.disable_adapter() 432 | if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: 433 | with deepspeed.zero.GatheredParameters(model.parameters()): 434 | remove_hooks(model) 435 | yield accelerator.unwrap_model(model) 436 | add_hooks(model) 437 | else: 438 | yield unwrapped_model 439 | 440 | 441 | def prepare_deepspeed(model: torch.nn.Module, per_device_train_batch_size: int, mixed_precision: str): 442 | """ 443 | Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings based on the model and 444 | batch size. 445 | Args: 446 | model (`torch.nn.Module`): 447 | The model to be prepared for DeepSpeed training. 448 | per_device_train_batch_size (`int`): 449 | The training batch size per device. 450 | mixed_precision (`str`): 451 | The mixed precision setting to use. 452 | Returns: 453 | `torch.nn.Module`: 454 | The model initialized and configured with DeepSpeed for training. 455 | """ 456 | import deepspeed 457 | 458 | deepspeed_plugin = AcceleratorState().deepspeed_plugin 459 | config_kwargs = deepspeed_plugin.deepspeed_config 460 | if config_kwargs["zero_optimization"]["stage"] != 3: 461 | config_kwargs["train_micro_batch_size_per_gpu"] = per_device_train_batch_size 462 | config_kwargs = { 463 | "train_micro_batch_size_per_gpu": config_kwargs["train_micro_batch_size_per_gpu"], 464 | "prescale_gradients": False, 465 | "wall_clock_breakdown": False, 466 | } 467 | if mixed_precision in ["bf16", "fp16"]: 468 | config_kwargs[mixed_precision] = {"enabled": True} 469 | else: 470 | if hasattr(model, "config"): 471 | hidden_size = ( 472 | max(model.config.hidden_sizes) 473 | if getattr(model.config, "hidden_sizes", None) 474 | else getattr(model.config, "hidden_size", None) 475 | ) 476 | if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: 477 | # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` 478 | # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 479 | config_kwargs.update( 480 | { 481 | "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, 482 | "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, 483 | "zero_optimization.stage3_prefetch_bucket_size": 0, 484 | } 485 | ) 486 | model, *_ = deepspeed.initialize(model=model, config=config_kwargs) 487 | model.eval() 488 | return model 489 | 490 | 491 | # ---------------------------------------------------------------------------- 492 | # Quality of life utilities 493 | def print_rich_table(df: pd.DataFrame) -> Table: 494 | console = Console() 495 | table = Table(show_lines=True) 496 | for column in df.columns: 497 | table.add_column(column) 498 | for _, row in df.iterrows(): 499 | table.add_row(*row.astype(str).tolist()) 500 | console.print(table) 501 | 502 | 503 | def format_value(value): 504 | if isinstance(value, float): 505 | if abs(value) < 1e-5: 506 | return f"{value:.2e}" 507 | return f"{value:.2f}" 508 | return str(value) 509 | 510 | 511 | def print_rich_single_line_metrics(metrics): 512 | formatted_metrics = [] 513 | for key, value in metrics.items(): 514 | # Shortening the key names 515 | short_key = key.split("/")[-1] if "/" in key else key 516 | 517 | # Create a colored text object 518 | metric_text = 
Text() 519 | metric_text.append(short_key + ": ", style="bold cyan") # Keys in cyan 520 | metric_text.append(format_value(value), style="yellow") # Values in yellow 521 | 522 | formatted_metrics.append(metric_text) 523 | 524 | rprint(" | ".join(str(metric) for metric in formatted_metrics)) 525 | 526 | 527 | def exact_div(a, b, custom_error_message=""): 528 | q = a // b 529 | if a != q * b: 530 | raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}") 531 | return q -------------------------------------------------------------------------------- /finetunerag/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 AllenAI. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import dataclasses 16 | import functools 17 | import json 18 | import logging 19 | import os 20 | import shutil 21 | import subprocess 22 | import sys 23 | import time 24 | from dataclasses import dataclass 25 | from typing import Any, List, NewType, Optional, Tuple, Union 26 | 27 | import requests 28 | from accelerate.logging import get_logger 29 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk 30 | from datasets.builder import DatasetGenerationError 31 | from huggingface_hub import HfApi 32 | from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser 33 | 34 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 35 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 36 | 37 | logger = get_logger(__name__) 38 | 39 | DataClassType = NewType("DataClassType", Any) 40 | 41 | """ 42 | Notes: 43 | Inspired by Alignment Handbook Parser and Dataset Mixer 44 | https://github.com/huggingface/alignment-handbook/blob/main/src/alignment/configs.py 45 | https://github.com/huggingface/alignment-handbook/blob/main/src/alignment/data.py 46 | 47 | Migrated Args from 48 | https://github.com/allenai/open-instruct/blob/98ccfb460ae4fb98140783b6cf54241926160a06/open_instruct/finetune_trainer.py 49 | 50 | Commented out Args not currently used 51 | """ 52 | 53 | 54 | # ---------------------------------------------------------------------------- 55 | # Dataset utilities 56 | def is_openai_format(messages: Any) -> bool: 57 | """ 58 | Check if the input messages are in OpenAI format. 59 | Args: 60 | messages (`Any`): 61 | Messages to check. 62 | Returns: 63 | `bool`: Whether the messages are in OpenAI format. 64 | """ 65 | if isinstance(messages, list) and all(isinstance(message, dict) for message in messages): 66 | return all("role" in message and "content" in message for message in messages) 67 | return False 68 | 69 | 70 | # functions for handling different formats of messages 71 | def convert_alpaca_gpt4_to_messages(example): 72 | """ 73 | Convert an instruction in inst-output to a list of messages. 74 | e.g. 
vicgalle/alpaca-gpt4""" 75 | messages = [ 76 | { 77 | "role": "user", 78 | "content": ( 79 | "Below is an instruction that describes a task, paired with an input that provides " 80 | "further context. Write a response that appropriately completes the request.\n\n" 81 | f"### Instruction:\n{example['instruction']}\n\n" 82 | f"### Input:\n{example['input']}\n\n" 83 | "### Response:" 84 | ), 85 | }, 86 | {"role": "assistant", "content": example["output"]}, 87 | ] 88 | example["messages"] = messages 89 | return example 90 | 91 | 92 | def convert_codefeedback_single_turn_to_messages(example): 93 | """ 94 | Convert a query-answer pair to a list of messages. 95 | e.g. m-a-p/CodeFeedback-Filtered-Instruction""" 96 | messages = [ 97 | {"role": "user", "content": example["query"]}, 98 | {"role": "assistant", "content": example["answer"]}, 99 | ] 100 | example["messages"] = messages 101 | return example 102 | 103 | 104 | def convert_metamath_qa_to_messages(example): 105 | """ 106 | Convert a query-response pair to a list of messages. 107 | e.g. meta-math/MetaMathQA""" 108 | messages = [ 109 | {"role": "user", "content": example["query"]}, 110 | {"role": "assistant", "content": example["response"]}, 111 | ] 112 | example["messages"] = messages 113 | return example 114 | 115 | 116 | def convert_code_alpaca_to_messages(example): 117 | """ 118 | Convert a prompt-completion pair to a list of messages. 119 | e.g. HuggingFaceH4/CodeAlpaca_20K""" 120 | messages = [ 121 | {"role": "user", "content": example["prompt"]}, 122 | {"role": "assistant", "content": example["completion"]}, 123 | ] 124 | example["messages"] = messages 125 | return example 126 | 127 | 128 | def convert_open_orca_to_messages(example): 129 | """ 130 | Convert a question-response pair to a list of messages. 131 | e.g. Open-Orca/OpenOrca""" 132 | messages = [ 133 | {"role": "system", "content": example["system_prompt"]}, 134 | {"role": "user", "content": example["question"]}, 135 | {"role": "assistant", "content": example["response"]}, 136 | ] 137 | example["messages"] = messages 138 | return example 139 | 140 | 141 | def conversations_to_messages(example): 142 | """ 143 | Convert from conversations format to messages. 144 | 145 | E.g. change "from": "user" to "role": "user" 146 | and "value" to "content" 147 | and "gpt" to "assistant" 148 | 149 | WizardLMTeam/WizardLM_evol_instruct_V2_196k 150 | """ 151 | name_mapping = { 152 | "gpt": "assistant", 153 | "Assistant": "assistant", 154 | "assistant": "assistant", 155 | "user": "user", 156 | "User": "user", 157 | "human": "user", 158 | } 159 | messages = [{"role": name_mapping[conv["from"]], "content": conv["value"]} for conv in example["conversations"]] 160 | example["messages"] = messages 161 | return example 162 | 163 | 164 | def convert_rejection_samples_to_messages(example): 165 | """ 166 | Convert a rejection sampling dataset to messages. 167 | """ 168 | example["messages"] = example["chosen"] 169 | return example 170 | 171 | 172 | def get_datasets( 173 | dataset_mixer: Union[dict, list], 174 | splits: Optional[List[str]] = None, 175 | configs: Optional[List[str]] = None, 176 | columns_to_keep: Optional[List[str]] = None, 177 | shuffle: bool = True, 178 | save_data_dir: Optional[str] = None, 179 | need_columns: Optional[List[str]] = None, 180 | keep_ids: bool = False, 181 | ) -> DatasetDict: 182 | """ 183 | Loads and mixes datasets according to proportions specified in `dataset_mixer`. 
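    Example (illustrative only; the mixer entries below are hypothetical — a Hub dataset sampled at 50% mixed with a local jsonl used in full):

        mixer = {"HuggingFaceH4/CodeAlpaca_20K": 0.5, "dataset/splits/train.jsonl": 1.0}
        raw_datasets = get_datasets(mixer, splits=["train"], columns_to_keep=["messages"])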
184 | 185 | Args: 186 | dataset_mixer (`list` or `dict`): 187 | Dictionary or list containing the dataset names and their training proportions. 188 | By default, all test proportions are 1. Lists are formatted as 189 | `key1 value1 key2 value2 ...` If a list is passed in, it will be converted to a dictionary. 190 | splits (Optional[List[str]], *optional*, defaults to `None`): 191 | Dataset splits to load and mix. Assumes the splits exist in 192 | all datasets and have a `train_` or `test_` prefix. 193 | configs (Optional[List[str]], *optional*, defaults to `None`): 194 | List of dataset config names. If given must be the same length as 'dataset_mixer' keys. 195 | columns_to_keep (Optional[List[str]], *optional*, defaults to `None`): 196 | Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts, 197 | and for cpt this should be (at least) the text column. 198 | shuffle (`bool`, *optional*, defaults to `True`): 199 | Whether to shuffle the training and testing/validation data. 200 | save_data_dir (Optional[str], *optional*, defaults to `None`): 201 | Optional directory to save training/test mixes on. 202 | need_columns (Optional[List[str]], *optional*, defaults to `None`): 203 | Column names that are required to be in the dataset. 204 | Quick debugging when mixing heterogeneous datasets. 205 | keep_ids (`bool`, *optional*, defaults to `False`): 206 | Whether to keep ids for training that are added during mixing. 207 | Used primarily in mix_data.py for saving, or the saved dataset has IDs already. 208 | """ 209 | if isinstance(dataset_mixer, list): 210 | assert len(dataset_mixer) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer}" 211 | mixer_dict = {} 212 | i = 0 213 | while i < len(dataset_mixer) - 1: 214 | assert isinstance(dataset_mixer[i], str), f"Invalid type in data mixer: {dataset_mixer}" 215 | if "." 
in dataset_mixer[i + 1]: 216 | value = float(dataset_mixer[i + 1]) 217 | else: 218 | value = int(dataset_mixer[i + 1]) 219 | mixer_dict[dataset_mixer[i]] = value 220 | i += 2 221 | dataset_mixer = mixer_dict 222 | 223 | splits = ["train", "test"] if splits is None else splits 224 | configs = [None] * len(dataset_mixer) if not configs else configs 225 | columns_to_keep = [] if columns_to_keep is None else columns_to_keep 226 | 227 | if configs is not None and len(configs) != len(dataset_mixer): 228 | raise ValueError("The number of given dataset config names must be the same as the given number of datasets.") 229 | 230 | # print save location 231 | if save_data_dir: 232 | print(f"Saving mixed dataset to {save_data_dir}") 233 | 234 | raw_datasets = DatasetDict() 235 | raw_train_datasets = [] 236 | raw_val_datasets = [] 237 | frac_or_sample_list = [] 238 | for (ds, frac_or_samples), ds_config in zip(dataset_mixer.items(), configs): 239 | frac_or_sample_list.append(frac_or_samples) 240 | for split in splits: 241 | # if dataset ends with .json or .jsonl, load from file 242 | if ds.endswith(".json") or ds.endswith(".jsonl"): 243 | dataset = load_dataset("json", data_files=ds, split=split) 244 | else: 245 | try: 246 | # Try first if dataset on a Hub repo 247 | dataset = load_dataset(ds, ds_config, split=split) 248 | except DatasetGenerationError: 249 | # If not, check local dataset 250 | dataset = load_from_disk(os.path.join(ds, split)) 251 | 252 | # shuffle dataset if set 253 | if shuffle: 254 | dataset = dataset.shuffle(seed=42) 255 | 256 | # assert that needed columns are present 257 | if need_columns: 258 | if not all(col in dataset.column_names for col in need_columns): 259 | raise ValueError(f"Needed column {need_columns} not found in dataset {dataset.column_names}.") 260 | 261 | # handle per-case conversions 262 | # if "instruction" and "output" columns are present and "messages" is not, convert to messages 263 | if ( 264 | "instruction" in dataset.column_names 265 | and "output" in dataset.column_names 266 | and "messages" not in dataset.column_names 267 | ): 268 | dataset = dataset.map(convert_alpaca_gpt4_to_messages, num_proc=10) 269 | elif ( 270 | "prompt" in dataset.column_names 271 | and "completion" in dataset.column_names 272 | and "messages" not in dataset.column_names 273 | ): 274 | dataset = dataset.map(convert_code_alpaca_to_messages, num_proc=10) 275 | elif "conversations" in dataset.column_names and "messages" not in dataset.column_names: 276 | dataset = dataset.map(conversations_to_messages, num_proc=10) 277 | elif ( 278 | "question" in dataset.column_names 279 | and "response" in dataset.column_names 280 | and "messages" not in dataset.column_names 281 | ): 282 | dataset = dataset.map(convert_open_orca_to_messages, num_proc=10) 283 | elif ( 284 | "query" in dataset.column_names 285 | and "answer" in dataset.column_names 286 | and "messages" not in dataset.column_names 287 | ): 288 | dataset = dataset.map(convert_codefeedback_single_turn_to_messages, num_proc=10) 289 | elif ( 290 | "query" in dataset.column_names 291 | and "response" in dataset.column_names 292 | and "messages" not in dataset.column_names 293 | ): 294 | dataset = dataset.map(convert_metamath_qa_to_messages, num_proc=10) 295 | elif ( 296 | "chosen" in dataset.column_names 297 | and "rejected" in dataset.column_names 298 | and "reference_completion" in dataset.column_names 299 | and "messages" not in dataset.column_names 300 | ): 301 | dataset = dataset.map(convert_rejection_samples_to_messages, num_proc=10) 
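            # At this point every loaded example is expected to expose a unified "messages" column.
            # Illustrative record (hypothetical contents) after one of the conversions above:
            #   {"messages": [{"role": "user", "content": "What is RAG?"},
            #                 {"role": "assistant", "content": "Retrieval-augmented generation is ..."}]}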
302 | 303 | # if id not in dataset, create it as ds-{index} 304 | if "id" not in dataset.column_names: 305 | id_col = [f"{ds}_{i}" for i in range(len(dataset))] 306 | dataset = dataset.add_column("id", id_col) 307 | 308 | # Remove redundant columns to avoid schema conflicts on load 309 | dataset = dataset.remove_columns( 310 | [col for col in dataset.column_names if col not in (columns_to_keep + ["id"])] 311 | ) 312 | 313 | # add tag to the dataset corresponding to where it was sourced from, for 314 | if "train" in split: 315 | raw_train_datasets.append(dataset) 316 | elif "test" in split: 317 | raw_val_datasets.append(dataset) 318 | else: 319 | raise ValueError(f"Split type {split} not recognized as one of test or train.") 320 | 321 | if len(raw_val_datasets) == 0 and len(raw_train_datasets) == 0: 322 | raise ValueError("No datasets loaded.") 323 | elif len(raw_train_datasets) == 0: 324 | # target features are the features of the first dataset post load 325 | target_features = raw_val_datasets[0].features 326 | else: 327 | # target features are the features of the first dataset post load 328 | target_features = raw_train_datasets[0].features 329 | 330 | if any(frac_or_samples < 0 for frac_or_samples in frac_or_sample_list): 331 | raise ValueError("Dataset fractions / lengths cannot be negative.") 332 | 333 | # if any > 1, use count 334 | if any(frac_or_samples > 1 for frac_or_samples in frac_or_sample_list): 335 | is_count = True 336 | # assert that all are integers 337 | if not all(isinstance(frac_or_samples, int) for frac_or_samples in frac_or_sample_list): 338 | raise NotImplementedError("Cannot mix fractions and counts, yet.") 339 | else: 340 | is_count = False 341 | 342 | if len(raw_train_datasets) > 0: 343 | train_subsets = [] 344 | # Manage proportions 345 | for dataset, frac_or_samples in zip(raw_train_datasets, frac_or_sample_list): 346 | # cast features (TODO, add more feature regularization) 347 | dataset = dataset.cast(target_features) 348 | # TODO selection can be randomized. 349 | if is_count: 350 | train_subset = dataset.select(range(frac_or_samples)) 351 | else: 352 | train_subset = dataset.select(range(int(frac_or_samples * len(dataset)))) 353 | train_subsets.append(train_subset) 354 | 355 | raw_datasets["train"] = concatenate_datasets(train_subsets) 356 | 357 | # No subsampling for test datasets to enable fair comparison across models 358 | if len(raw_val_datasets) > 0: 359 | for dataset in raw_val_datasets: 360 | # cast features (TODO, add more feature regularization) 361 | dataset = dataset.cast(target_features) 362 | 363 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets) 364 | 365 | if len(raw_datasets) == 0: 366 | raise ValueError( 367 | f"Dataset {dataset_mixer} not recognized with splits {splits}." 368 | "Check the dataset has been correctly formatted." 
369 | ) 370 | 371 | # optional save 372 | if save_data_dir: 373 | for split in raw_datasets: 374 | raw_datasets[split].to_json(save_data_dir + f"mixed_ds_{split}.json") 375 | 376 | if not keep_ids: 377 | # remove id column 378 | if len(raw_train_datasets) > 0: 379 | if "id" in raw_datasets["train"].column_names: 380 | raw_datasets["train"] = raw_datasets["train"].remove_columns("id") 381 | if len(raw_val_datasets) > 0: 382 | if "id" in raw_datasets["test"].column_names: 383 | raw_datasets["test"] = raw_datasets["test"].remove_columns("id") 384 | 385 | return raw_datasets 386 | 387 | 388 | # ---------------------------------------------------------------------------- 389 | # Arguments utilities 390 | class ArgumentParserPlus(HfArgumentParser): 391 | def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]: 392 | """ 393 | Parse a YAML file and overwrite the default/loaded values with the values provided to the command line. 394 | 395 | Args: 396 | yaml_arg (`str`): 397 | The path to the config file used 398 | other_args (`List[str]`, *optional`): 399 | A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2']. 400 | 401 | Returns: 402 | [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line 403 | """ 404 | arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg)) 405 | 406 | outputs = [] 407 | # strip other args list into dict of key-value pairs 408 | other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args} 409 | used_args = {} 410 | 411 | # overwrite the default/loaded value with the value provided to the command line 412 | # noqa adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327 413 | for data_yaml, data_class in zip(arg_list, self.dataclass_types): 414 | keys = {f.name for f in dataclasses.fields(data_yaml) if f.init} 415 | inputs = {k: v for k, v in vars(data_yaml).items() if k in keys} 416 | for arg, val in other_args.items(): 417 | # add only if in keys 418 | 419 | if arg in keys: 420 | base_type = data_yaml.__dataclass_fields__[arg].type 421 | inputs[arg] = val 422 | 423 | # cast type for ints, floats (default to strings) 424 | if base_type in [int, float]: 425 | inputs[arg] = base_type(val) 426 | 427 | if base_type == List[str]: 428 | inputs[arg] = [str(v) for v in val.split(",")] 429 | 430 | # bool of a non-empty string is True, so we manually check for bools 431 | if base_type == bool: 432 | if val in ["true", "True"]: 433 | inputs[arg] = True 434 | else: 435 | inputs[arg] = False 436 | 437 | # add to used-args so we can check if double add 438 | if arg not in used_args: 439 | used_args[arg] = val 440 | else: 441 | raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior") 442 | 443 | obj = data_class(**inputs) 444 | outputs.append(obj) 445 | 446 | return outputs 447 | 448 | def parse(self) -> Union[DataClassType, Tuple[DataClassType]]: 449 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 450 | # If we pass only one argument to the script and it's the path to a YAML file, 451 | # let's parse it to get our arguments. 
452 | output = self.parse_yaml_file(os.path.abspath(sys.argv[1])) 453 | # parse command line args and yaml file 454 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): 455 | output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:]) 456 | # parse command line args only 457 | else: 458 | output = self.parse_args_into_dataclasses() 459 | 460 | if len(output) == 1: 461 | output = output[0] 462 | return output 463 | 464 | 465 | # ---------------------------------------------------------------------------- 466 | # Experiment tracking utilities 467 | def get_git_tag() -> str: 468 | """Try to get the latest Git tag (e.g., `no-tag-404-g98dc659` or `v1.0.0-4-g98dc659`)""" 469 | git_tag = "" 470 | try: 471 | git_tag = ( 472 | subprocess.check_output(["git", "describe", "--tags"], stderr=subprocess.DEVNULL).decode("ascii").strip() 473 | ) 474 | except subprocess.CalledProcessError as e: 475 | logging.debug(f"Failed to get Git tag: {e}") 476 | 477 | # If no Git tag found, create a custom tag based on commit count and hash 478 | if len(git_tag) == 0: 479 | try: 480 | count = int( 481 | subprocess.check_output(["git", "rev-list", "--count", "HEAD"], stderr=subprocess.DEVNULL) 482 | .decode("ascii") 483 | .strip() 484 | ) 485 | hash = ( 486 | subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], stderr=subprocess.DEVNULL) 487 | .decode("ascii") 488 | .strip() 489 | ) 490 | git_tag = f"no-tag-{count}-g{hash}" 491 | except subprocess.CalledProcessError as e: 492 | logging.debug(f"Failed to get commit count and hash: {e}") 493 | 494 | return git_tag 495 | 496 | 497 | def get_pr_tag() -> str: 498 | """Try to find associated pull request on GitHub (e.g., `pr-123`)""" 499 | pr_tag = "" 500 | try: 501 | git_commit = ( 502 | subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"], stderr=subprocess.DEVNULL) 503 | .decode("ascii") 504 | .strip() 505 | ) 506 | # try finding the pull request number on github 507 | prs = requests.get(f"https://api.github.com/search/issues?q=repo:allenai/open-instruct+is:pr+{git_commit}") 508 | if prs.status_code == 200: 509 | prs = prs.json() 510 | if len(prs["items"]) > 0: 511 | pr = prs["items"][0] 512 | pr_number = pr["number"] 513 | pr_tag = f"pr-{pr_number}" 514 | except Exception as e: 515 | logging.debug(f"Failed to get PR number: {e}") 516 | 517 | return pr_tag 518 | 519 | 520 | def get_wandb_tags() -> List[str]: 521 | """Get tags for Weights & Biases (e.g., `no-tag-404-g98dc659,pr-123`)""" 522 | existing_wandb_tags = os.environ.get("WANDB_TAGS", "") 523 | git_tag = get_git_tag() 524 | pr_tag = get_pr_tag() 525 | non_empty_tags = [tag for tag in [existing_wandb_tags, git_tag, pr_tag] if len(tag) > 0] 526 | return non_empty_tags 527 | 528 | 529 | # ---------------------------------------------------------------------------- 530 | # Check pointing utilities 531 | def get_last_checkpoint(folder: str, incomplete: bool = False) -> Optional[str]: 532 | content = os.listdir(folder) 533 | checkpoint_steps = [path for path in content if path.startswith("step_")] 534 | checkpoint_epochs = [path for path in content if path.startswith("epoch_")] 535 | if len(checkpoint_steps) > 0 and len(checkpoint_epochs) > 0: 536 | logger.info("Mixed step and epoch checkpoints found. 
Using step checkpoints.") 537 | checkpoints = checkpoint_steps 538 | elif len(checkpoint_steps) == 0: 539 | checkpoints = checkpoint_epochs 540 | else: 541 | checkpoints = checkpoint_steps 542 | if not incomplete: 543 | checkpoints = [path for path in checkpoints if os.path.exists(os.path.join(folder, path, "COMPLETED"))] 544 | if len(checkpoints) == 0: 545 | return 546 | return os.path.join(folder, max(checkpoints, key=lambda x: x.split("_")[-1])) 547 | 548 | 549 | def get_last_checkpoint_path(args, incomplete: bool = False) -> str: 550 | # if output already exists and user does not allow overwriting, resume from there. 551 | # otherwise, resume if the user specifies a checkpoint. 552 | # else, start from scratch. 553 | # if incomplete is true, include folders without "COMPLETE" in the folder. 554 | last_checkpoint_path = None 555 | if args.output_dir and os.path.isdir(args.output_dir) and not args.overwrite_output_dir: 556 | last_checkpoint_path = get_last_checkpoint(args.output_dir, incomplete=incomplete) 557 | if last_checkpoint_path is None: 558 | logger.warning("Output directory exists but no checkpoint found. Starting from scratch.") 559 | elif args.resume_from_checkpoint: 560 | last_checkpoint_path = args.resume_from_checkpoint 561 | return last_checkpoint_path 562 | 563 | 564 | def is_checkpoint_folder(dir: str, folder: str) -> bool: 565 | return (folder.startswith("step_") or folder.startswith("epoch_")) and os.path.isdir(os.path.join(dir, folder)) 566 | 567 | 568 | def clean_last_n_checkpoints(output_dir: str, keep_last_n_checkpoints: int) -> None: 569 | # remove the last checkpoint to save space 570 | folders = [f for f in os.listdir(output_dir) if is_checkpoint_folder(output_dir, f)] 571 | # find the checkpoint with the largest step 572 | checkpoints = sorted(folders, key=lambda x: int(x.split("_")[-1])) 573 | if len(checkpoints) > keep_last_n_checkpoints: 574 | for checkpoint in checkpoints[: len(checkpoints) - keep_last_n_checkpoints]: 575 | logger.info(f"Removing checkpoint {checkpoint}") 576 | shutil.rmtree(os.path.join(output_dir, checkpoint)) 577 | logger.info("Remaining files:" + str(os.listdir(output_dir))) 578 | 579 | 580 | # ---------------------------------------------------------------------------- 581 | # Ai2 user utilities 582 | @dataclass 583 | class BeakerRuntimeConfig: 584 | beaker_workload_id: str 585 | beaker_node_hostname: Optional[List[str]] = None 586 | beaker_experiment_url: Optional[List[str]] = None 587 | beaker_dataset_ids: Optional[List[str]] = None 588 | beaker_dataset_id_urls: Optional[List[str]] = None 589 | 590 | 591 | def is_beaker_job() -> bool: 592 | return "BEAKER_JOB_ID" in os.environ 593 | 594 | 595 | def get_beaker_experiment_info(experiment_id: str) -> Optional[dict]: 596 | get_experiment_command = f"beaker experiment get {experiment_id} --format json" 597 | process = subprocess.Popen(["bash", "-c", get_experiment_command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 598 | stdout, stderr = process.communicate() 599 | if process.returncode != 0: 600 | print(f"Failed to get Beaker experiment: {stderr}") 601 | return None 602 | return json.loads(stdout)[0] 603 | 604 | 605 | def beaker_experiment_succeeded(experiment_id: str) -> bool: 606 | experiment = get_beaker_experiment_info(experiment_id) 607 | if not experiment: 608 | return False 609 | return all(["finalized" in job["status"] and job["status"]["exitCode"] == 0 for job in experiment["jobs"]]) 610 | 611 | 612 | def get_beaker_dataset_ids(experiment_id: str) -> Optional[List[str]]: 
613 | experiment = get_beaker_experiment_info(experiment_id) 614 | if not experiment: 615 | return None 616 | result_ids = [job["result"]["beaker"] for job in experiment["jobs"]] 617 | dataset_ids = [] 618 | for result_id in result_ids: 619 | get_dataset_command = f"beaker dataset get {result_id} --format json" 620 | process = subprocess.Popen(["bash", "-c", get_dataset_command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 621 | stdout, stderr = process.communicate() 622 | if process.returncode != 0: 623 | print(f"Failed to get Beaker dataset: {stderr}") 624 | return None 625 | datasets = json.loads(stdout) 626 | dataset_ids.extend([dataset["id"] for dataset in datasets]) 627 | return dataset_ids 628 | 629 | 630 | def get_beaker_whoami() -> Optional[str]: 631 | get_beaker_whoami_command = "beaker account whoami --format json" 632 | process = subprocess.Popen( 633 | ["bash", "-c", get_beaker_whoami_command], stdout=subprocess.PIPE, stderr=subprocess.PIPE 634 | ) 635 | stdout, stderr = process.communicate() 636 | if process.returncode != 0: 637 | print(f"Failed to get Beaker account: {stderr}") 638 | return None 639 | accounts = json.loads(stdout) 640 | return accounts[0]["name"] 641 | 642 | 643 | def maybe_get_beaker_config(): 644 | beaker_dataset_ids = get_beaker_dataset_ids(os.environ["BEAKER_WORKLOAD_ID"]) 645 | # fix condition on basic interactive jobs 646 | if beaker_dataset_ids is None: 647 | beaker_dataset_id_urls = [] 648 | else: 649 | beaker_dataset_id_urls = [f"https://beaker.org/ds/{dataset_id}" for dataset_id in beaker_dataset_ids] 650 | return BeakerRuntimeConfig( 651 | beaker_workload_id=os.environ["BEAKER_WORKLOAD_ID"], 652 | beaker_node_hostname=os.environ["BEAKER_NODE_HOSTNAME"], 653 | beaker_experiment_url=f"https://beaker.org/ex/{os.environ['BEAKER_WORKLOAD_ID']}/", 654 | beaker_dataset_ids=get_beaker_dataset_ids(os.environ["BEAKER_WORKLOAD_ID"]), 655 | beaker_dataset_id_urls=beaker_dataset_id_urls, 656 | ) 657 | 658 | 659 | def retry_on_exception(max_attempts=4, delay=1, backoff=2): 660 | """ 661 | Retry a function on exception. Useful for HF API calls that may fail due to 662 | network issues. E.g., https://beaker.org/ex/01J69P87HJQQ7X5DXE1CPWF974 663 | `huggingface_hub.utils._errors.HfHubHTTPError: 429 Client Error` 664 | 665 | We can test it with the following code. 666 | @retry_on_exception(max_attempts=4, delay=1, backoff=2) 667 | def test(): 668 | raise Exception("Test exception") 669 | 670 | test() 671 | """ 672 | 673 | def decorator(func): 674 | @functools.wraps(func) 675 | def wrapper(*args, **kwargs): 676 | attempts = 0 677 | local_delay = delay 678 | while attempts < max_attempts: 679 | try: 680 | return func(*args, **kwargs) 681 | except Exception as e: 682 | attempts += 1 683 | if attempts == max_attempts: 684 | raise e 685 | print(f"Attempt {attempts} failed. Retrying in {local_delay} seconds...") 686 | time.sleep(local_delay) 687 | local_delay *= backoff 688 | return None 689 | 690 | return wrapper 691 | 692 | return decorator 693 | 694 | 695 | @retry_on_exception() 696 | def maybe_use_ai2_wandb_entity() -> Optional[str]: 697 | """Ai2 internal logic: try use the ai2-llm team if possible. 
Should not affect external users.""" 698 | import wandb 699 | 700 | wandb.login() 701 | api = wandb.Api() 702 | current_user = api.viewer 703 | teams = current_user.teams 704 | if "ai2-llm" in teams: 705 | return "ai2-llm" 706 | else: 707 | return None 708 | 709 | 710 | @retry_on_exception() 711 | def maybe_use_ai2_hf_entity() -> Optional[str]: 712 | """Ai2 internal logic: try use the allenai entity if possible. Should not affect external users.""" 713 | orgs = HfApi().whoami() 714 | orgs = [item["name"] for item in orgs["orgs"]] 715 | if "allenai" in orgs: 716 | return "allenai" 717 | else: 718 | return None 719 | 720 | 721 | def submit_beaker_eval_jobs( 722 | model_name: str, 723 | location: str, 724 | hf_repo_revision: str = "", 725 | workspace: str = "tulu-3-results", 726 | beaker_image: str = "nathanl/open_instruct_auto", 727 | upload_to_hf: str = "allenai/tulu-3-evals", 728 | run_oe_eval_experiments: bool = False, 729 | ) -> None: 730 | command = f""" 731 | python scripts/submit_eval_jobs.py \ 732 | --model_name {model_name} \ 733 | --location {location} \ 734 | --is_tuned \ 735 | --workspace {workspace} \ 736 | --preemptible \ 737 | --use_hf_tokenizer_template \ 738 | --beaker_image {beaker_image} \ 739 | """ 740 | if len(hf_repo_revision) > 0: 741 | command += f" --hf_revision {hf_repo_revision}" 742 | if len(upload_to_hf) > 0: 743 | command += f" --upload_to_hf {upload_to_hf}" 744 | if run_oe_eval_experiments: 745 | command += " --run_oe_eval_experiments" 746 | 747 | process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 748 | stdout, stderr = process.communicate() 749 | 750 | print(f"Beaker evaluation jobs: Stdout:\n{stdout.decode()}") 751 | print(f"Beaker evaluation jobs: Stderr:\n{stderr.decode()}") 752 | print(f"Beaker evaluation jobs: process return code: {process.returncode}") 753 | 754 | 755 | @retry_on_exception() 756 | def upload_metadata_to_hf( 757 | metadata_dict, 758 | filename, 759 | hf_dataset_name, 760 | hf_dataset_save_dir, 761 | ): 762 | # upload a random dict to HF. Originally for uploading metadata to HF 763 | # about a model for leaderboard displays. 
764 |     with open("tmp.json", "w") as f:
765 |         json.dump(metadata_dict, f)
766 |     api = HfApi()
767 |     api.upload_file(
768 |         path_or_fileobj="tmp.json",
769 |         path_in_repo=f"{hf_dataset_save_dir}/{filename}",
770 |         repo_id=hf_dataset_name,
771 |         repo_type="dataset",
772 |     )
773 |     os.remove("tmp.json")
774 | 
--------------------------------------------------------------------------------
/media/accuracy_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/accuracy_by_steps.png
--------------------------------------------------------------------------------
/media/depth_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/depth_by_steps.png
--------------------------------------------------------------------------------
/media/helpfulness_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/helpfulness_by_steps.png
--------------------------------------------------------------------------------
/media/pints_ai-banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/pints_ai-banner.png
--------------------------------------------------------------------------------
/media/relevance_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/relevance_by_steps.png
--------------------------------------------------------------------------------
/prepare_dataset/README.md:
--------------------------------------------------------------------------------
1 | # Dataset
2 | 
3 | The motivation here is to improve RAG (after the retrieval phase) by finetuning a model to select the intended (correct) information and filter out false positives from what is retrieved. As such, the dataset is generated with the intended (correct) information and some fictitious data (simulating the retrieval of incorrect data).
4 | 
5 | ## Uploaded Dataset
6 | 
7 | We have uploaded the dataset we used for training to HuggingFace [here](https://huggingface.co/datasets/pints-ai/Finetune-RAG). We have curated a total of `1653` documents that are ready for a train-validation-test split before fine-tuning.
8 | 
9 | ## Fictitious Data Generation [OPTIONAL]
10 | 
11 | If you have your own set of document chunks and would like to curate questions, answers, and fictitious data from them for finetuning, you may do so by preparing a jsonl file that contains one document chunk per line. The structure of each line should be as follows:
12 | 
13 | ```json
14 | {
15 |     "content": "",
16 |     "filename": "",
17 | }
18 | ```
19 | 
20 | > [!WARNING]
21 | > With our method, your documents will be passed to GPT-4o for dataset generation. If you plan to do so, remember to include the `.env` file containing your OpenAI API key at the root level. An example is given in `.env.sample`.
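
If your chunks are currently sitting in individual text files, a small script along the following lines can produce the expected jsonl. This is only a sketch: the `documents/` input folder is a placeholder for your own data, and only the `content` and `filename` keys are required by the generation scripts.

```python
import json
from pathlib import Path

input_dir = Path("documents")  # hypothetical folder containing one .txt file per chunk
output_path = Path("dataset/custom_chunks.jsonl")
output_path.parent.mkdir(parents=True, exist_ok=True)

with open(output_path, "w", encoding="utf-8") as out:
    for txt_file in sorted(input_dir.glob("*.txt")):
        # Each line carries exactly the keys expected by the generation scripts.
        record = {
            "content": txt_file.read_text(encoding="utf-8"),
            "filename": txt_file.name,
        }
        out.write(json.dumps(record, ensure_ascii=False) + "\n")
```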
22 | 23 | ### Usage 24 | 25 | Suppose your custom jsonl file is at `dataset/custom_chunks.jsonl`, run as a module at the root level: 26 | ```bash 27 | python -m prepare_dataset.content_generation.generate_question --content_path dataset/custom_chunks.jsonl && \ 28 | python -m prepare_dataset.content_generation.generate_answer && \ 29 | python -m prepare_dataset.content_generation.generate_fictitious_content 30 | ``` 31 | 32 | ## Prepare Dataset For Training 33 | 34 | `prepare_dataset/formatting/generate_training_data.py` is the script to process the data generated for training. Essentially, it allows you to specify 2 different dialogue formats that is used to tune the model. 35 | 36 | Using our prepared dataset from Hugging Face, or your own generated dataset, run through this preparation before fine-tuning. 37 | 38 | ### Baseline Format 39 | 40 | ``` 41 | Filename: {filename1} 42 | Information: 43 | {content1} 44 | 45 | Filename: {filename2} 46 | Information: 47 | {content2} 48 | 49 | Question: {question} 50 | ``` 51 | 52 | ### XML Format 53 | 54 | ``` 55 | 56 | 57 | {filename1} 58 | {content1} 59 | 60 | 61 | {filename2} 62 | {content2} 63 | 64 | 65 | 66 | Question: {question} 67 | ``` 68 | 69 | ### Usage 70 | 71 | ```bash 72 | python -m prepare_dataset.formatting.generate_training_data --content_path dataset/finetunerag_dataset.jsonl --format baseline 73 | ``` 74 | 75 | ## Split the dataset 76 | 77 | Simply run the `prepare_dataset/formatting/split_data.py` script to get your train-validation-test splits after preparing your dataset via the instructions above. 78 | 79 | ```bash 80 | python -m prepare_dataset.formatting.split_data --dataset_path dataset/adjusted_finetunerag_dataset.jsonl 81 | ``` 82 | -------------------------------------------------------------------------------- /prepare_dataset/content_generation/generate_answer.py: -------------------------------------------------------------------------------- 1 | from logging import Logger 2 | from pathlib import Path 3 | from typing import List, Dict 4 | import copy 5 | 6 | from prompts.answer_generation_prompt import SYSTEM_MESSAGE, USER_MESSAGE_TEMPLATE 7 | from utils.dataset_utils import load_jsonl_file, write_jsonl_file 8 | from utils.logger import setup_logger 9 | from utils.openai import call_openai_api 10 | 11 | def start( 12 | content_path: Path = "dataset/contents_w_questions.jsonl", 13 | output_path: Path = "dataset/contents_w_qa.jsonl", 14 | ): 15 | assert ( 16 | content_path.is_file() 17 | and content_path.suffix == ".jsonl" 18 | ), "Path to content data is not a jsonl file." 19 | 20 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name) 21 | logger.info(f"{content_path} recognized.") 22 | 23 | all_content_qn_data: List[Dict] = load_jsonl_file(content_path) 24 | 25 | # Ensure dataset is ready before generating answers 26 | for index, content_data in enumerate(all_content_qn_data): 27 | missing_keys = [key for key in ["content", "filename", "question"] if key not in content_data] 28 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]" 29 | 30 | logger.info(f"Dataset checked and ready for answer generation. 
Generating answers...") 31 | content_data_with_answer: List[Dict] = [] 32 | for index, content_data in enumerate(all_content_qn_data): 33 | filename: str = content_data["filename"] 34 | content: str = content_data["content"] 35 | question: str = content_data["question"] 36 | generated_answer: str = generate_answer(filename=filename, content=content, question=question) 37 | logger.info(f"Line {index + 1} of jsonl file --> Answer generated for question: {question[:60].encode('unicode_escape').decode()}... Answer: {generated_answer[:60]}...") 38 | 39 | cloned_content_data: Dict = copy.deepcopy(content_data) 40 | cloned_content_data["answer"] = generated_answer 41 | content_data_with_answer.append(cloned_content_data) 42 | 43 | write_jsonl_file(content=content_data_with_answer, output_path=output_path) 44 | logger.info(f"Generation of answers completed. Updated jsonl file saved at {output_path}") 45 | return 46 | 47 | def generate_answer( 48 | filename: str, 49 | content: str, 50 | question: str, 51 | ): 52 | user_message: str = USER_MESSAGE_TEMPLATE.format(filename=filename, content=content, question=question) 53 | messages = [ 54 | {"role": "system", "content": SYSTEM_MESSAGE}, 55 | {"role": "user", "content": user_message} 56 | ] 57 | 58 | generated_openai_json: Dict = call_openai_api(messages=messages, output_as_json=True) 59 | 60 | assert "answer" in generated_openai_json, "OpenAI model did not generate correct json format" 61 | return generated_openai_json["answer"] 62 | 63 | if __name__ == "__main__": 64 | from jsonargparse import CLI 65 | CLI(start, as_positional=False) 66 | # python -m prepare_dataset.generate_answer --content_qn_data_path dataset/documents_w_qns/sample.jsonl --output_data_path dataset/documents_w_qa/sample.jsonl 67 | -------------------------------------------------------------------------------- /prepare_dataset/content_generation/generate_fictitious_content.py: -------------------------------------------------------------------------------- 1 | from logging import Logger 2 | from pathlib import Path 3 | from typing import List, Dict 4 | import copy 5 | 6 | from prompts.fictitious_content_generation_prompt import SYSTEM_MESSAGE, USER_MESSAGE_TEMPLATE 7 | from utils.dataset_utils import load_jsonl_file, write_jsonl_file 8 | from utils.logger import setup_logger 9 | from utils.openai import call_openai_api 10 | 11 | def start( 12 | content_path: Path = "dataset/contents_w_qa.jsonl", 13 | output_path: Path = "dataset/finetunerag_dataset.jsonl", 14 | num_fictitious_content: int = 2 15 | ): 16 | assert ( 17 | content_path.is_file() 18 | and content_path.suffix == ".jsonl" 19 | ), "Path to content data is not a jsonl file." 20 | 21 | assert num_fictitious_content > 0, "Number of fictitious content should be more than 0" 22 | assert isinstance(num_fictitious_content, int), "Number of fictitious content should be an integer" 23 | 24 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name) 25 | logger.info(f"{content_path} recognized.") 26 | 27 | all_content_data: List[Dict] = load_jsonl_file(content_path) 28 | 29 | # Ensure dataset is ready before generating questions 30 | for index, content_data in enumerate(all_content_data): 31 | missing_keys = [key for key in ["content", "filename"] if key not in content_data] 32 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]" 33 | 34 | logger.info(f"Dataset checked and ready for fictitious content generation. 
Generating content...") 35 | fictitious_content_data: List[Dict] = [] 36 | for index, content_data in enumerate(all_content_data): 37 | filename: str = content_data["filename"] 38 | content: str = content_data["content"] 39 | 40 | cloned_content_data: Dict = copy.deepcopy(content_data) 41 | for fictitious_content_count in range(1, num_fictitious_content + 1): 42 | fictitious_filename, fictitious_content = generate_fictitious_content(filename=filename, content=content) 43 | logger.info(f"Line {index + 1} of jsonl file --> Fictitious content {fictitious_content_count} generated. Filename: {fictitious_filename[:60]}..., Content: {fictitious_content[:60]}...") 44 | 45 | cloned_content_data[f"fictitious_filename{fictitious_content_count}"] = fictitious_filename 46 | cloned_content_data[f"fictitious_content{fictitious_content_count}"] = fictitious_content 47 | 48 | fictitious_content_data.append(cloned_content_data) 49 | 50 | write_jsonl_file(content=fictitious_content_data, output_path=output_path) 51 | logger.info(f"Generation of fictitious contents completed. Updated jsonl file saved at {output_path}") 52 | return 53 | 54 | def generate_fictitious_content( 55 | filename: str, 56 | content: str, 57 | ): 58 | user_message: str = USER_MESSAGE_TEMPLATE.format(filename=filename, content=content) 59 | messages = [ 60 | {"role": "system", "content": SYSTEM_MESSAGE}, 61 | {"role": "user", "content": user_message} 62 | ] 63 | 64 | # Set temperature to 0.5 for more spurious content output generation 65 | generated_openai_json: Dict = call_openai_api(messages=messages, output_as_json=True, temperature=0.5) 66 | assert ( 67 | "fictitious_filename" in generated_openai_json 68 | and "fictitious_content" in generated_openai_json 69 | ), "OpenAI model did not generate correct JSON format" 70 | 71 | return generated_openai_json["fictitious_filename"], generated_openai_json["fictitious_content"] 72 | 73 | if __name__ == "__main__": 74 | from jsonargparse import CLI 75 | CLI(start, as_positional=False) 76 | -------------------------------------------------------------------------------- /prepare_dataset/content_generation/generate_question.py: -------------------------------------------------------------------------------- 1 | from logging import Logger 2 | from pathlib import Path 3 | from typing import List, Dict 4 | import copy 5 | 6 | from prompts.question_generation_prompt import SYSTEM_MESSAGE, USER_MESSAGE_TEMPLATE 7 | from utils.dataset_utils import load_jsonl_file, write_jsonl_file 8 | from utils.logger import setup_logger 9 | from utils.openai import call_openai_api 10 | 11 | def start( 12 | content_path: Path = "dataset/contents.jsonl", 13 | output_path: Path = "dataset/contents_w_questions.jsonl", 14 | ): 15 | assert ( 16 | content_path.is_file() 17 | and content_path.suffix == ".jsonl" 18 | ), "Path to content data is not a jsonl file." 19 | 20 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name) 21 | logger.info(f"{content_path} recognized.") 22 | 23 | all_content_data: List[Dict] = load_jsonl_file(content_path) 24 | 25 | # Ensure dataset is ready before generating questions 26 | for index, content_data in enumerate(all_content_data): 27 | missing_keys = [key for key in ["content", "filename"] if key not in content_data] 28 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]" 29 | 30 | logger.info(f"Dataset checked and ready for question generation. 
Generating questions...") 31 | content_data_with_question: List[Dict] = [] 32 | for index, content_data in enumerate(all_content_data): 33 | filename: str = content_data["filename"] 34 | content: str = content_data["content"] 35 | generated_question: str = generate_question(filename=filename, content=content) 36 | logger.info(f"Line {index + 1} of jsonl file --> Question generated for content: {content[:60].encode('unicode_escape').decode()}... Question: {generated_question}") 37 | 38 | cloned_content_data: Dict = copy.deepcopy(content_data) 39 | cloned_content_data["question"] = generated_question 40 | content_data_with_question.append(cloned_content_data) 41 | 42 | write_jsonl_file(content=content_data_with_question, output_path=output_path) 43 | logger.info(f"Generation of questions completed. Updated jsonl file saved at {output_path}") 44 | return 45 | 46 | def generate_question( 47 | filename: str, 48 | content: str, 49 | ): 50 | user_message: str = USER_MESSAGE_TEMPLATE.format(filename=filename, content=content) 51 | messages = [ 52 | {"role": "system", "content": SYSTEM_MESSAGE}, 53 | {"role": "user", "content": user_message} 54 | ] 55 | 56 | generated_openai_json: Dict = call_openai_api(messages=messages, output_as_json=True) 57 | 58 | assert "question" in generated_openai_json, "OpenAI model did not generate correct json format" 59 | return generated_openai_json["question"] 60 | 61 | if __name__ == "__main__": 62 | from jsonargparse import CLI 63 | CLI(start, as_positional=False) 64 | 65 | -------------------------------------------------------------------------------- /prepare_dataset/formatting/formatter.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | from prompts.rag_prompt import SYSTEM_MESSAGE, BASELINE_TEMPLATE, XML_TEMPLATE 4 | 5 | templates = { 6 | "baseline": BASELINE_TEMPLATE, 7 | "xml": XML_TEMPLATE, 8 | } 9 | 10 | def formatter( 11 | filename: str, 12 | content: str, 13 | fictitious_filename: str, 14 | fictitious_content: str, 15 | question: str, 16 | answer: str, 17 | decider: Generator, 18 | template_type: str = "baseline", 19 | ): 20 | assert template_type in templates, f"{template_type} template does not exist. 
Available templates: {templates.keys()}" 21 | 22 | filename1, filename2 = filename, fictitious_filename 23 | content1, content2 = content, fictitious_content 24 | 25 | # Whether to change the order of the fictitious and non-fictitious 26 | flip_content_order = next(decider) 27 | if flip_content_order: 28 | filename2, filename1 = filename1, filename2 29 | content2, content1 = content1, content2 30 | 31 | user_message: str = templates.get(template_type).format( 32 | filename1=filename1, 33 | content1=content1, 34 | filename2=filename2, 35 | content2=content2, 36 | question=question 37 | ) 38 | 39 | return { 40 | "messages": [ 41 | {"role": "system", "content": SYSTEM_MESSAGE}, 42 | {"role": "user", "content": user_message}, 43 | {"role": "assistant", "content": answer} 44 | ], 45 | "content": content, 46 | "question": question, 47 | "filename": filename, 48 | } 49 | -------------------------------------------------------------------------------- /prepare_dataset/formatting/generate_training_data.py: -------------------------------------------------------------------------------- 1 | from logging import Logger 2 | from pathlib import Path 3 | from typing import List, Literal, Dict 4 | 5 | from prepare_dataset.formatting.formatter import formatter 6 | from utils.dataset_utils import get_decider, load_jsonl_file, write_jsonl_file, get_dice 7 | from utils.logger import setup_logger 8 | 9 | def start( 10 | content_path: Path = "dataset/finetunerag_dataset.jsonl", 11 | output_path: Path = "dataset/adjusted_finetunerag_dataset.jsonl", 12 | format: Literal["baseline", "xml"] = "baseline", 13 | ): 14 | assert ( 15 | content_path.is_file() 16 | and content_path.suffix == ".jsonl" 17 | ), "Path to content data is not a jsonl file." 18 | 19 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name) 20 | logger.info(f"{content_path} recognized.") 21 | 22 | all_content_data: List[Dict] = load_jsonl_file(content_path) 23 | 24 | # Ensure dataset is ready before generating questions 25 | for index, content_data in enumerate(all_content_data): 26 | missing_keys = [key for key in ["content", "filename", "question", "answer", "fictitious_filename1", "fictitious_content1"] if key not in content_data] 27 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]" 28 | 29 | logger.info(f"Dataset checked and ready for formatting. Preparing training data...") 30 | 31 | # It is assumed that every datapoint has the same number of fictitious content. 
Refering to the first datapoint to retrieve the number of ficitious content 32 | num_fictitious_content = sum(1 for key in all_content_data[0] if key.startswith("fictitious_content")) 33 | dice = get_dice(num_choices=num_fictitious_content) 34 | 35 | decider = get_decider() 36 | all_formatted_data = [] 37 | for index, content_data in enumerate(all_content_data): 38 | filename: str = content_data["filename"] 39 | content: str = content_data["content"] 40 | question: str = content_data["question"] 41 | answer: str = content_data["answer"] 42 | 43 | 44 | selected_content_index = next(dice) 45 | fictitious_filename = content_data[f"fictitious_filename{selected_content_index}"] 46 | fictitious_content = content_data[f"fictitious_content{selected_content_index}"] 47 | 48 | formatted_data = formatter( 49 | filename=filename, 50 | content=content, 51 | fictitious_filename=fictitious_filename, 52 | fictitious_content=fictitious_content, 53 | question=question, 54 | answer=answer, 55 | decider=decider, 56 | template_type=format, 57 | ) 58 | all_formatted_data.append(formatted_data) 59 | 60 | write_jsonl_file(content=all_formatted_data, output_path=output_path) 61 | logger.info(f"Preparation of training data completed. Training data saved at {output_path}") 62 | return 63 | 64 | if __name__ == "__main__": 65 | from jsonargparse import CLI 66 | CLI(start, as_positional=False) 67 | 68 | -------------------------------------------------------------------------------- /prepare_dataset/formatting/split_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | from pathlib import Path 5 | 6 | def start( 7 | dataset_path: Path = "dataset/adjusted_finetunerag_dataset.jsonl", 8 | output_folder_path: Path = "dataset/splits", 9 | seed: int = 888, 10 | ): 11 | with open(dataset_path, "r") as f: 12 | lines = f.readlines() 13 | 14 | random.seed(seed) 15 | random.shuffle(lines) 16 | 17 | total = len(lines) 18 | test_size = int(0.1 * total) 19 | train_size = total - (test_size * 2) 20 | 21 | train_data = lines[:train_size] 22 | val_data = lines[train_size:train_size + test_size] 23 | test_data = lines[train_size + test_size:] 24 | 25 | trimmed_test_data = [] 26 | for line in test_data: 27 | obj = json.loads(line) 28 | if isinstance(obj.get("messages"), list) and obj["messages"]: 29 | obj["messages"] = obj["messages"][:-1] 30 | trimmed_test_data.append(json.dumps(obj) + "\n") 31 | 32 | output_folder_path.mkdir(parents=True, exist_ok=True) 33 | 34 | (output_folder_path / "train.jsonl").write_text("".join(train_data)) 35 | (output_folder_path / "validation.jsonl").write_text("".join(val_data)) 36 | (output_folder_path / "test.jsonl").write_text("".join(trimmed_test_data)) 37 | 38 | print(f"Done! Files created: {(output_folder_path / 'train.jsonl')}, {(output_folder_path / 'validation.jsonl')}, {(output_folder_path / 'test.jsonl')}") 39 | 40 | if __name__ == "__main__": 41 | from jsonargparse import CLI 42 | CLI(start, as_positional=False) 43 | -------------------------------------------------------------------------------- /prompts/answer_generation_prompt.py: -------------------------------------------------------------------------------- 1 | SYSTEM_MESSAGE = ( 2 | "You are an AI Data Scientist who creates high quality datasets that can be used for fine-tuning of Large Language Models." 3 | " Follow the user's instruction closely to create a dataset based on the given context." 
4 | ) 5 | 6 | USER_MESSAGE_TEMPLATE = """You are to read the following information and answer the question. 7 | Filename: {filename} 8 | Information: {content} 9 | 10 | Now, answer the question, and you may elaborate as necessary. Do not create information that is not found in the information provided. 11 | Question: "{question}" 12 | 13 | You will reply in the following JSON format: 14 | {{ 15 | 'answer': "" 16 | }}""" 17 | -------------------------------------------------------------------------------- /prompts/fictitious_content_generation_prompt.py: -------------------------------------------------------------------------------- 1 | SYSTEM_MESSAGE = ( 2 | "You are an AI Data Scientist who creates high quality datasets that can be used for fine-tuning of Large Language Models." 3 | " Follow the user's instruction closely to create a dataset based on the given context." 4 | ) 5 | 6 | USER_MESSAGE_TEMPLATE = """You are tasked with creating a fictitious set of information that must be thematically similar to the original user's filename and content. The fictitious file name should be from a different company, different person, or a different place etc. The fictitious content should have some similar parts but also some parts entirely fabricated with a different structure and also random incorrect content. The fictitious content must still match the original content in length, preserving patterns such as spurious use of line breaks (\\n), spewed Unicode, or broken words. In particular, ensure that the frequency and placement of line breaks (\\n) and unicode is similar to the original. The generation of fictitious content should not always be simple rephrasing, but should feel like it comes from a different document. The output should contain two JSON keys: 'fictitious_filename' and 'fictitious_content'. 7 | 8 | Here is the original user's filename and content: 9 | Filename: {filename} 10 | Information: {content} 11 | 12 | Now, generate 1 set of ficticious filename and content. 13 | 14 | You will reply in the following JSON format: 15 | {{ 16 | 'fictitious_filename': "", 17 | 'fictitious_content': "" 18 | }}""" 19 | -------------------------------------------------------------------------------- /prompts/judging_prompts.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict 2 | from openai.types.chat.chat_completion_message_param import ( 3 | ChatCompletionSystemMessageParam, 4 | ChatCompletionUserMessageParam, 5 | ) 6 | 7 | class OpenAIJudgeResponse(TypedDict): 8 | accuracy_explanation: str 9 | accuracy: bool 10 | helpfulness_explanation: str 11 | helpfulness: int 12 | relevance_explanation: str 13 | relevance: int 14 | depth_explanation: str 15 | depth: int 16 | 17 | # This prompt has been deprecated. Evaluation of multiple factors with the same call 18 | # can lead to biased results. The result for each criterion might directly 19 | # influence the ones after, resulting in an inaccurate or less objective assessment. 20 | judge_system_prompt: ChatCompletionSystemMessageParam = { 21 | 'role': 'system', 22 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider factors such as the accuracy, helpfulness, relevance, and depth of the response. 23 | 24 | 1. 
Accuracy - You will check whether the response contains extra details not found in the piece of information provided. If extra details are found, accuracy is false. Otherwise, accuracy is true. Take note that if the response partially addresses the question, but did not provide extra details not found in the piece of information provided, the response will still be considered accurate (hence accuracy = true). 25 | 2. Helpfulness - The helpfulness of the AI assistant in answering the question. 26 | 3. Relevance - Whether the response fully addresses the question. 27 | 4. Depth - The level of detail of the response in answering the question. 28 | 29 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate each factor on a scale of 1 to 10 (with the exception of accuracy, where it is true or false) by strictly following this JSON format: 30 | { 31 | "accuracy_explanation": , 32 | "accuracy": , 33 | "helpfulness_explanation": , 34 | "helpfulness": , 35 | "relevance_explanation": , 36 | "relevance": , 37 | "depth_explanation": , 38 | "depth": 39 | } 40 | """, 41 | } 42 | 43 | judge_accuracy_system_prompt: ChatCompletionSystemMessageParam = { 44 | 'role': 'system', 45 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the accuracy of the response. 46 | 47 | You will check whether the response contains extra details not found in the piece of information provided. If extra details are found, accuracy is false. Otherwise, accuracy is true. Take note that if the response partially addresses the question, but did not provide extra details not found in the piece of information provided, the response will still be considered accurate (hence accuracy = true). 48 | 49 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the accuracy with true or false by strictly following this JSON format: 50 | { 51 | "accuracy_explanation": , 52 | "accuracy": 53 | } 54 | """, 55 | } 56 | 57 | judge_helpfulness_system_prompt: ChatCompletionSystemMessageParam = { 58 | 'role': 'system', 59 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the helpfulness of the response. 60 | 61 | You will check whether the AI assistant is helpful in answering the question based on the response. 62 | 63 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the helpfulness on a scale of 1 to 10 by strictly following this JSON format: 64 | { 65 | "helpfulness_explanation": , 66 | "helpfulness": 67 | } 68 | """, 69 | } 70 | 71 | judge_relevance_system_prompt: ChatCompletionSystemMessageParam = { 72 | 'role': 'system', 73 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the relevance of the response. 
74 | 75 | You will check the relevance of the response by evaluating whether the response fully addresses the question. 76 | 77 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the relevance on a scale of 1 to 10 by strictly following this JSON format: 78 | { 79 | "relevance_explanation": , 80 | "relevance": 81 | } 82 | """, 83 | } 84 | 85 | judge_depth_system_prompt: ChatCompletionSystemMessageParam = { 86 | 'role': 'system', 87 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the depth of the response. 88 | 89 | You will check the depth of the response by evaluating the level of detail of the response in answering the question. 90 | 91 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the depth on a scale of 1 to 10 by strictly following this JSON format: 92 | { 93 | "depth_explanation": , 94 | "depth": 95 | } 96 | """, 97 | } 98 | 99 | 100 | def get_judge_user_prompt(document) -> ChatCompletionUserMessageParam: 101 | return { 102 | 'role': 'user', 103 | 'content': f"""[The Start of Provided Information Extracted from a File] 104 | Filename: {document['filename']} 105 | 106 | Information: {document['content']} 107 | [The End of Provided Information] 108 | 109 | [Question] 110 | {document['question']} 111 | 112 | [The Start of Assistant's Response] 113 | {document['response']} 114 | [The End of Assistant's Response]""", 115 | } 116 | -------------------------------------------------------------------------------- /prompts/prompt_styles.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import abstractmethod 3 | from typing import Dict, List, Type, Union 4 | 5 | 6 | class PromptStyle: 7 | """Base interface for prompt styles.""" 8 | 9 | @abstractmethod 10 | def apply(self, prompt: str, **kwargs: str) -> str: 11 | raise NotImplementedError('PromptStyle.apply() is an abstract method.') 12 | 13 | @classmethod 14 | def from_name(cls, name: str) -> 'PromptStyle': 15 | if name not in prompt_styles: 16 | return None 17 | return prompt_styles[name]() 18 | 19 | 20 | # This default style is exported from the original open_instruct repository: 21 | # https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L396-L407 22 | class Default(PromptStyle): 23 | def apply( 24 | self, 25 | prompt: Union[str, List[Dict[str, str]]], 26 | **kwargs: str, 27 | ) -> str: 28 | message_text = '' 29 | for message in prompt: 30 | if message['role'] == 'system': 31 | message_text += '<|system|>\n' + message['content'].strip() + '\n' 32 | elif message['role'] == 'user': 33 | message_text += '<|user|>\n' + message['content'].strip() + '\n' 34 | elif message['role'] == 'assistant': 35 | message_text += ( 36 | '<|assistant|>\n' + message['content'].strip() + '' + '\n' 37 | ) 38 | else: 39 | raise ValueError(f"Invalid role: {message['role']}") 40 | return message_text 41 | 42 | 43 | class Llama3_1(PromptStyle): 44 | DEFAULT_SYSTEM_MESSAGE = 'You are a helpful assistant.' 
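    # For reference, apply() renders a conversation into the Llama 3.1 header/eot layout,
    # e.g. (illustrative, with the default system message and a hypothetical user turn;
    # shown on separate lines here for readability, the actual output is one string):
    #   <|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>
    #   <|start_header_id|>user<|end_header_id|>\n\nWhat is in the provided file?<|eot_id|>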
45 | 46 | def apply( 47 | self, 48 | prompt: Union[str, List[Dict[str, str]]], 49 | no_system: bool = False, 50 | append_assistant_header: bool = False, 51 | **kwargs: str, 52 | ) -> str: 53 | assert isinstance( 54 | prompt, list 55 | ), f'Unsupported prompt type: {type(prompt)}. prompt should be a list of dicts' 56 | 57 | tokens = [] 58 | if not no_system and not self.has_system_prompt(prompt): 59 | tokens.extend( 60 | self.encode_message( 61 | {'role': 'system', 'content': Llama3_1.DEFAULT_SYSTEM_MESSAGE} 62 | ) 63 | ) 64 | 65 | for i, message in enumerate(prompt): 66 | if i != 0 and message['role'] == 'system': 67 | raise ValueError("'system' role is only allowed at the beginning of the conversation list.") 68 | if message['role'] not in ['assistant', 'user', 'system']: 69 | warnings.warn( 70 | f"Unknown role: '{message['role']}'. Supported roles are 'assistant', 'user', and 'system'. Proceeding on the assumption that this is intentional.", 71 | UserWarning, 72 | ) 73 | 74 | tokens.extend(self.encode_message(message)) 75 | 76 | if append_assistant_header: 77 | tokens.extend(self.encode_header('assistant')) 78 | 79 | return ''.join(tokens) 80 | 81 | def encode_header(self, role: str) -> List[str]: 82 | return [f'<|start_header_id|>{role}<|end_header_id|>\n\n'] 83 | 84 | def encode_message(self, message: Dict[str, str]) -> List[str]: 85 | tokens = self.encode_header(message['role']) 86 | # NOTE: Meta stripped this. I'm not sure I agree, but who am I to argue? 87 | tokens.append(message['content'].strip()) 88 | tokens.append('<|eot_id|>') 89 | return tokens 90 | 91 | def has_system_prompt(self, messages: List[Dict[str, str]]) -> bool: 92 | return messages[0].get('role', '') == 'system' if len(messages) else False 93 | 94 | 95 | class OLMoE(PromptStyle): 96 | def apply( 97 | self, 98 | prompt: Union[str, List[Dict[str, str]]], 99 | add_generation_prompt: bool = False, 100 | bos_token: str = '', 101 | eos_token: str = '', 102 | **kwargs, 103 | ) -> str: 104 | assert isinstance( 105 | prompt, list 106 | ), f"Expected prompt to be a list of messages, got {type(prompt)}" 107 | 108 | result = [bos_token] 109 | for i, message in enumerate(prompt): 110 | role = message.get("role") 111 | content = message.get("content", "").strip() 112 | 113 | if role == "system": 114 | result.append("<|system|>\n" + content) 115 | elif role == "user": 116 | result.append("<|user|>\n" + content) 117 | elif role == "assistant": 118 | result.append("<|assistant|>\n" + content + eos_token) 119 | else: 120 | raise ValueError(f"Unsupported message role: {role}") 121 | 122 | # Append assistant header if it's the last message and generation is expected 123 | if i == len(prompt) - 1 and add_generation_prompt: 124 | result.append("<|assistant|>") 125 | 126 | return "\n".join(result) + "\n" 127 | 128 | 129 | prompt_styles: Dict[str, Type[PromptStyle]] = { 130 | 'default': Default, 131 | 'llama3.1': Llama3_1, 132 | 'olmoe': OLMoE, 133 | } 134 | -------------------------------------------------------------------------------- /prompts/question_generation_prompt.py: -------------------------------------------------------------------------------- 1 | SYSTEM_MESSAGE = ( 2 | "You are an AI Data Scientist who creates high-quality datasets that can be used for fine-tuning of Large Language Models." 3 | " Follow the user's instruction closely to create a dataset based on the given context." 4 | ) 5 | 6 | USER_MESSAGE_TEMPLATE = """You are to read the following information and generate a question.
7 | Filename: {filename} 8 | Information: {content} 9 | 10 | Now, generate only 1 straightforward, broad, simple and general question. 11 | 12 | You will reply in the following JSON format: 13 | {{ 14 | "question": "" 15 | }}""" 16 | -------------------------------------------------------------------------------- /prompts/rag_prompt.py: -------------------------------------------------------------------------------- 1 | SYSTEM_MESSAGE = ( 2 | "Some information is retrieved from the database as provided based on the user’s question." 3 | " The assistant is to answer the question to the best of his/her ability, using only the information provided." 4 | " The assistant must not add his/her own knowledge." 5 | ) 6 | 7 | templates = [] 8 | 9 | BASELINE_TEMPLATE = """Filename: {filename1} 10 | Information: 11 | {content1} 12 | 13 | Filename: {filename2} 14 | Information: 15 | {content2} 16 | 17 | Question: {question}""" 18 | templates.append(BASELINE_TEMPLATE) 19 | 20 | XML_TEMPLATE = """<documents> 21 | <document> 22 | <filename>{filename1}</filename> 23 | <content>{content1}</content> 24 | </document> 25 | <document> 26 | <filename>{filename2}</filename> 27 | <content>{content2}</content> 28 | </document> 29 | </documents> 30 | 31 | Question: {question}""" 32 | templates.append(XML_TEMPLATE) 33 | 34 | # Name-to-template mapping used by get_template(); the positional `templates` list above is kept for index-based selection. 35 | templates_by_name = {'baseline': BASELINE_TEMPLATE, 'xml': XML_TEMPLATE} 36 | 37 | def get_template(template: str) -> str: 38 | if template not in templates_by_name: 39 | raise KeyError(f"Template '{template}' does not exist. Available templates: {list(templates_by_name.keys())}") 40 | 41 | return templates_by_name[template] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.6.0 # SHOULD BE DONE OUTSIDE OF REQUIREMENTS.TXT 2 | accelerate==0.31.0 3 | datasets==3.2.0 4 | deepspeed==0.16.3 5 | flash-attn==2.6.3 6 | jsonargparse==4.36.0 7 | openai==1.60.0 8 | protobuf==5.29.3 9 | python-dotenv==1.0.1 10 | rich==13.9.4 11 | sentencepiece==0.2.0 12 | transformers==4.48.3 13 | matplotlib==3.10.1 -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | from pathlib import Path 5 | from typing import List, Dict, Generator 6 | 7 | def load_jsonl_file(file_path: Path) -> List[Dict]: 8 | data = [] 9 | with open(file_path, 'r', encoding='utf-8') as file: 10 | for line in file: 11 | dict_obj = json.loads(line) 12 | data.append(dict_obj) 13 | 14 | assert len(data) > 0, f'{file_path.name} is empty!'
15 | return data 16 | 17 | def write_jsonl_file(content: List[dict], output_path: Path): 18 | # Create all ancestor directories if they don't exist 19 | output_path.parent.mkdir(parents=True, exist_ok=True) 20 | with open(output_path, 'w', encoding='utf-8') as f: 21 | for entry in content: 22 | f.write(json.dumps(entry) + '\n') 23 | 24 | def get_dice(num_choices: int, seed: int=69) -> Generator[int, None, None]: 25 | if seed: 26 | random.seed(seed) 27 | 28 | choices = list(range(1, num_choices + 1)) 29 | while True: 30 | yield random.choice(choices) 31 | 32 | def get_decider(seed: int=69) -> Generator[bool, None, None]: 33 | if seed: 34 | random.seed(seed) 35 | 36 | choices = [True, False] 37 | while True: 38 | yield random.choice(choices) 39 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Optional 4 | import threading 5 | 6 | BASE_DIR = Path(__file__).resolve().parent.parent 7 | LOGS_DIR = (BASE_DIR / 'logs').resolve() 8 | LOGS_DIR.mkdir(exist_ok=True) 9 | 10 | # Create a lock object for thread safety 11 | log_lock = threading.Lock() 12 | 13 | def setup_logger( 14 | namespace: str, logging_level=logging.DEBUG, logfile_name: Optional[str] = None 15 | ): 16 | with log_lock: 17 | # Create a custom logger 18 | logger = logging.getLogger(namespace) 19 | 20 | # Check if the logger already has handlers to prevent adding multiple 21 | if not logger.hasHandlers(): 22 | # Set the logging level 23 | logger.setLevel(logging_level) 24 | 25 | # Create handlers 26 | logfile = f'{logfile_name if logfile_name else namespace}.log' 27 | logfile_path = LOGS_DIR / logfile 28 | file_handler = logging.FileHandler(logfile_path) 29 | console_handler = logging.StreamHandler() 30 | 31 | # Set the logging level for the handlers 32 | file_handler.setLevel(logging_level) 33 | console_handler.setLevel(logging_level) 34 | 35 | # Create a formatter and set it for the handlers 36 | formatter = logging.Formatter( 37 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 38 | ) 39 | file_handler.setFormatter(formatter) 40 | console_handler.setFormatter(formatter) 41 | 42 | # Add the handlers to the logger 43 | logger.addHandler(file_handler) 44 | logger.addHandler(console_handler) 45 | 46 | return logger 47 | -------------------------------------------------------------------------------- /utils/openai.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from dotenv import load_dotenv 4 | from openai import OpenAI 5 | from os import environ 6 | from typing import Dict, List 7 | 8 | load_dotenv() 9 | CLIENT = OpenAI(api_key=environ.get('OPENAI_API_KEY')) 10 | 11 | def call_openai_api( 12 | messages: List[Dict[str, str]], 13 | model: str = 'gpt-4o', 14 | output_as_json: bool = False, 15 | temperature: float = 0, 16 | ): 17 | # We do not check whether the indicated model supports json_object. 18 | # Refer to https://platform.openai.com/docs/guides/json-mode for more information.
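# Note (per the JSON-mode guide linked above): JSON mode only guarantees syntactically valid JSON,
# not any particular schema, and the API expects the word "JSON" to appear somewhere in the messages;
# the prompts in this repository generally satisfy this by explicitly asking for a reply in a given JSON format.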
19 | response_format = {'type': 'json_object'} if output_as_json else {'type': 'text'} 20 | 21 | try: 22 | chat_completion = CLIENT.chat.completions.create( 23 | messages=messages, 24 | model=model, 25 | n=1, 26 | response_format=response_format, 27 | temperature=temperature, 28 | ) 29 | response = chat_completion.choices[0].message.content 30 | except Exception as e: 31 | # Handle any unexpected error 32 | raise Exception(f"An unexpected error occurred: {str(e)}") 33 | 34 | if output_as_json: 35 | try: 36 | json_response = json.loads(response, strict=False) 37 | return json_response 38 | except json.JSONDecodeError as error: 39 | raise ValueError(f"JSON decoding failed: {error}") 40 | else: 41 | return response 42 | --------------------------------------------------------------------------------
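For orientation, here is a minimal usage sketch of how the pieces above compose: a judging system prompt plus get_judge_user_prompt() builds the message list, and call_openai_api() returns the judge's verdict as parsed JSON. The record values below are hypothetical, and the snippet assumes the repository root is on PYTHONPATH with OPENAI_API_KEY set as in .env.sample; it is a sketch, not a replacement for the benchrag scripts.

# Sketch: judge a single RAG response for accuracy (hypothetical example data).
from prompts.judging_prompts import get_judge_user_prompt, judge_accuracy_system_prompt
from utils.openai import call_openai_api

# The keys mirror the fields consumed by get_judge_user_prompt().
document = {
    'filename': 'acme_company_profile.txt',          # hypothetical retrieved file
    'content': 'Acme Pte Ltd was founded in 2012.',  # piece of information given to the judge
    'question': 'When was Acme Pte Ltd founded?',
    'response': 'Acme Pte Ltd was founded in 2012.',
}

messages = [judge_accuracy_system_prompt, get_judge_user_prompt(document)]
verdict = call_openai_api(messages, output_as_json=True)  # defaults: model='gpt-4o', temperature=0
print(verdict['accuracy'], verdict['accuracy_explanation'])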