├── .env.sample
├── .gitignore
├── LICENSE
├── README.md
├── benchrag
│   ├── README.md
│   ├── get_results.py
│   ├── plot_results.py
│   ├── run_inference.py
│   └── run_judging.py
├── configs
│   └── deepspeed
│       ├── stage2_no_offloading_accelerate.conf
│       ├── stage2_offloading_accelerate.conf
│       ├── stage3_no_offloading.conf
│       ├── stage3_no_offloading_accelerate.conf
│       ├── stage3_offloading.conf
│       └── stage3_offloading_accelerate.conf
├── finetunerag
│   ├── arguments.py
│   ├── finetune.py
│   ├── model_utils.py
│   └── utils.py
├── media
│   ├── accuracy_by_steps.png
│   ├── depth_by_steps.png
│   ├── helpfulness_by_steps.png
│   ├── pints_ai-banner.png
│   └── relevance_by_steps.png
├── prepare_dataset
│   ├── README.md
│   ├── content_generation
│   │   ├── generate_answer.py
│   │   ├── generate_fictitious_content.py
│   │   └── generate_question.py
│   └── formatting
│       ├── formatter.py
│       ├── generate_training_data.py
│       └── split_data.py
├── prompts
│   ├── answer_generation_prompt.py
│   ├── fictitious_content_generation_prompt.py
│   ├── judging_prompts.py
│   ├── prompt_styles.py
│   ├── question_generation_prompt.py
│   └── rag_prompt.py
├── requirements.txt
└── utils
    ├── dataset_utils.py
    ├── logger.py
    └── openai.py
/.env.sample:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=""
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
3 | dataset/
4 | logs/
5 | ragbench/inferences
6 | ragbench/judge_results
7 | wandb/
8 |
9 | .env
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![Pints AI](media/pints_ai-banner.png)
2 |
3 |
4 |
5 | # Finetune-RAG: Fine-tuning Models to Tackle Retrieval-Augmented Generation (RAG) Hallucination
6 |
7 | This repository provides an open-source framework to fine-tune large language models (LLMs) for improving their ability to discern correct information from irrelevant or fictitious data when using Retrieval-Augmented Generation (RAG) systems. By training models to distinguish between relevant and misleading contexts, we aim to reduce the hallucination problem in LLM-generated responses, enhancing the reliability of models in real-world use cases.
8 |
9 | ## Paper & Citation
10 |
11 | ```latex
12 | @misc{lee2025finetuneragfinetuninglanguagemodels,
13 | title={Finetune-RAG: Fine-Tuning Language Models to Resist Hallucination in Retrieval-Augmented Generation},
14 | author={Zhan Peng Lee and Andre Lin and Calvin Tan},
15 | year={2025},
16 | eprint={2505.10792},
17 | archivePrefix={arXiv},
18 | primaryClass={cs.CL},
19 | url={https://arxiv.org/abs/2505.10792},
20 | }
21 | ```
22 |
23 | ## Problem Overview
24 |
25 | When integrating retrieval into LLM workflows, models often rely on external documents to provide factual information in response to user queries. However, if incorrect or irrelevant documents are retrieved, the model may generate incorrect responses by "hallucinating" based on misleading data. This repository addresses the issue by fine-tuning models to:
26 |
27 | - Recognize and ignore fictitious or irrelevant documents.
28 | - Focus on relevant, factually correct context.
29 | - Generate accurate answers, even when faced with conflicting or unreliable data.
30 |
31 | ## Approach
32 |
33 | Our method involves fine-tuning LLMs using carefully designed prompts that provide two types of data:
34 |
35 | 1. **A correct, factually grounded chunk**.
36 | 2. **A fictitious, misleading chunk**.
37 |
38 | The fine-tuning process teaches the model to focus solely on the correct chunk, filtering out the irrelevant or false context. The training labels are the correct answers based on the factual context, guiding the model to avoid using the fictitious information during generation.
39 |
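For illustration, a training example conceptually pairs one factual chunk and one fictitious chunk with an answer grounded only in the factual chunk. The sketch below is purely illustrative; the field names are assumptions, and the actual schema is produced by the scripts under [`prepare_dataset/`](prepare_dataset/):

```python
# Illustrative only: the conceptual shape of one training example.
# Field names are assumptions, not the repository's exact keys.
example = {
    "question": "When was the supply agreement signed?",
    "factual_chunk": {
        "filename": "supply_agreement_2021.pdf",
        "content": "The supply agreement was signed on 3 May 2021 by both parties...",
    },
    "fictitious_chunk": {
        "filename": "supply_agreement_2021_summary.pdf",
        "content": "The supply agreement was signed on 9 September 2019...",
    },
    # Training label: grounded only in the factual chunk.
    "answer": "The supply agreement was signed on 3 May 2021.",
}
```
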
40 | ### Key Steps
41 |
42 | 1. **Data Construction**:
43 | - For each training example, the model is given a question, one chunk of data that contains the correct information, and a second chunk that contains fictitious information.
44 |
45 | 2. **Training Objective**:
46 | - The model is trained to generate a correct answer by leveraging the factual chunk while ignoring the fictitious one.
47 |
48 | 3. **Evaluation**:
49 |    - We evaluate using **Bench-RAG**, our LLM-as-a-judge framework. During evaluation, the model is tested on its ability to provide accurate answers, with performance measured by its capacity to ignore misleading data and produce responses based only on the correct context, as judged by GPT-4o.
50 |
51 | ## Setup
52 |
53 | ### Install conda
54 |
55 | ```bash
56 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
57 | sh Miniconda3-latest-Linux-x86_64.sh
58 | ```
59 |
60 | Source your shell configuration to make sure the `conda` CLI is available:
61 |
62 | ```bash
63 | source ~/.bashrc
64 | ```
65 |
66 | If you still face `conda: command not found`, locate the installation and source it directly:
67 |
68 | `Note: This path assumes the default installation settings. Otherwise, use the path where you installed Miniconda.`
69 |
70 | ```bash
71 | source ~/miniconda3/etc/profile.d/conda.sh
72 | ```
73 |
74 | ## Clone this repo
75 |
76 | ```bash
77 | git clone https://github.com/Pints-AI/Finetune-Bench-RAG.git && \
78 | cd Finetune-Bench-RAG
79 | ```
80 |
81 | ## Create conda env
82 |
83 | ```bash
84 | conda create --prefix ./.conda python=3.10 && \
85 | conda activate ./.conda
86 | ```
87 |
88 | `Note`: Stick to Python 3.10. Python 3.12 breaks a number of dependencies as of writing (23 Feb 2024), and 3.11 has not been tested.
89 |
90 | ## Install CUDA toolkit
91 |
92 | ```bash
93 | conda install nvidia/label/cuda-12.1.1::cuda-toolkit
94 | ```
95 |
96 | ## Install requirements
97 |
98 | ```bash
99 | pip install torch==2.6.0 && \
100 | pip install -r requirements.txt
101 | ```
102 |
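Optionally, you can sanity-check that the CUDA build of PyTorch is visible before moving on to training:

```python
# Optional sanity check: confirm torch sees your GPU(s) before training.
import torch

print(torch.__version__)          # expected: 2.6.0
print(torch.cuda.is_available())  # expected: True on a correctly configured GPU machine
print(torch.cuda.device_count())  # number of visible GPUs
```
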
103 | ## Training
104 |
105 | ### Dataset Preparation
106 |
107 | The dataset used in this project consists of over 1,600 documents manually scraped from a wide range of sources. These documents are organised into categories, including legal documents (e.g., contracts, governance, compliance), research papers from multiple scientific fields, books (both fiction and non-fiction), web content, news articles, parliamentary debates, and government publications. Additionally, the dataset includes industry-specific documents such as technical documentation, patents, market research, and code repositories.
108 |
109 | This diverse dataset ensures that the model is exposed to varied contexts, allowing it to learn how to identify and filter out irrelevant or fictitious information across multiple domains. You can access it from our [Hugging Face dataset page](https://huggingface.co/datasets/pints-ai/Finetune-RAG).
110 |
111 | Before training, download the dataset:
112 |
113 | ```bash
114 | huggingface-cli download --local-dir dataset/ --repo-type dataset pints-ai/Finetune-RAG
115 | ```
116 | Afterwards, process the dataset to suit the training pipeline. Refer to the [`prepare_dataset/`](prepare_dataset/) folder for more information.
117 |
118 | ### Model preparation
119 |
120 | Generally, most huggingface-compatible causal language models should work fine with our codebase, potentially with some adjustments for different tokenizers, etc. Some models may require additional requests to download. E.g., for Llama, please consult [the Hugging Face documentation](https://huggingface.co/docs/transformers/model_doc/llama) for requesting access and converting them to a huggingface-compatible format.
121 |
122 | Notably, we finetune instruct models so that they retain their conversational capabilities.
123 |
124 | ### Prompt Strategy
125 |
126 | We have standardised the way we include the retrieved content in the prompt:
127 |
128 | #### System Message
129 |
130 | For all training, our system message is as follows:
131 |
132 | ```python
133 | SYSTEM_PROMPT = 'Some information is retrieved from the database as provided based on the user’s question. The assistant is to answer the question to the best of his/her ability, using only the information provided. The assistant must not add his/her own knowledge.'
134 | ```
135 |
136 | The goal of SFT is to enhance the model's ability to recognise fictitious content. As such, we do not wish to over-engineer the prompt and influence the model's choice of which content to use when generating its answer.
137 |
138 | #### User Message
139 |
140 | We define two categories of user message: Baseline and XML.
141 |
142 | The Baseline user message provides the content to the model generically. We train the model with one genuine chunk and one fictitious chunk, presented in random order. The format is as follows:
143 |
144 | ```
145 | Filename: {filename1}
146 | Information:
147 | {content1}
148 |
149 | Filename: {filename2}
150 | Information:
151 | {content2}
152 |
153 | Question: {question}
154 | ```
155 |
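As a minimal, independent sketch (not the repository's own formatting code), the Baseline user message can be assembled from a record as follows, with the genuine and fictitious chunks shuffled into random order:

```python
import random


def build_baseline_user_message(question: str, factual: dict, fictitious: dict) -> str:
    """Illustrative only: fill the Baseline template with the two chunks in random order."""
    chunks = [factual, fictitious]
    random.shuffle(chunks)  # the genuine/fictitious order is randomised during training
    blocks = [
        f"Filename: {chunk['filename']}\nInformation:\n{chunk['content']}"
        for chunk in chunks
    ]
    return "\n\n".join(blocks) + f"\n\nQuestion: {question}"
```
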
156 | The XML user message is similar, but with each filename and its content encapsulated in XML tags, along these lines (the tag names here are illustrative):
157 | ```
158 | <documents>
159 | <document>
160 | <filename>{filename1}</filename>
161 | <content>{content1}</content>
162 | </document>
163 | <document>
164 | <filename>{filename2}</filename>
165 | <content>{content2}</content>
166 | </document>
167 | </documents>
168 |
169 | Question: {question}
170 | ```
171 |
172 | We conduct separate SFTs for both formats. Results are below.
173 |
174 | ### Finetuning
175 |
176 | The codebase is `deepspeed`-with-`accelerate` enabled. If the [various existing `deepspeed` configs](configs/deepspeed/) are not what you are looking for, pass the path of your custom config in the execution command.
177 |
178 | Below is an example command to run finetuning on the Llama-3.1-8B-Instruct model:
179 |
180 | ```bash
181 | accelerate launch \
182 | --mixed_precision bf16 \
183 | --num_machines 1 \
184 | --num_processes 1 \
185 | --use_deepspeed \
186 |     --deepspeed_config_file configs/deepspeed/stage2_offloading_accelerate.conf \
187 | finetunerag/finetune.py \
188 | --model_name_or_path ../Llama-3.1-8B-Instruct \
189 | --tokenizer_name_or_path ../Llama-3.1-8B-Instruct \
190 | --use_flash_attn \
191 | --max_seq_length 4096 \
192 | --preprocessing_num_workers 128 \
193 | --per_device_train_batch_size 1 \
194 | --gradient_accumulation_steps 64 \
195 | --learning_rate 2e-5 \
196 | --lr_scheduler_type cosine \
197 | --beta1 0.9 \
198 | --beta2 0.95 \
199 | --warmup_ratio 0.1 \
200 | --weight_decay 0.1 \
201 | --num_train_epochs 1 \
202 | --enable_wandb \
203 | --logging_steps 1 \
204 | --checkpointing_steps 2 \
205 | --prompt_style llama3.1 \
206 | --validation_step 1 \
207 | --wandb_project Finetuning \
208 | --wandb_name Llama-3.1-8B-Instruct-Finetuned \
209 | --train_file dataset/splits/train.jsonl \
210 | --validation_file dataset/splits/validation.jsonl \
211 | --output_dir ../models/Llama-3.1-8B-Instruct-Baseline/
212 | ```
213 |
214 | Make sure to adjust `model_name_or_path`, `tokenizer_name_or_path`, `train_file`, and `output_dir` to your models / data / setting.
215 |
216 | ### Released Checkpoints
217 |
218 | We have finetuned Llama-3.1-8B-Instruct to tackle RAG hallucination. We used 1x H100 GPU, with a micro-batch size of 1 and an effective batch size of 64 per step (64 gradient-accumulation steps). See the sample finetuning command above for more information on the hyperparameters used. Our checkpoints can be found here:
219 |
220 | - [Baseline-tuned (checkpoints: steps_2-10)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_Baseline_tuned-1).
221 | - [Baseline-tuned (checkpoints: steps_12-20)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_Baseline_tuned-2).
222 | - [XML-tuned (checkpoints: steps_2-10)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_XML_tuned-1).
223 | - [XML-tuned (checkpoints: steps_12-20)](https://huggingface.co/pints-ai/Llama-3.1-8B-Instruct-RAG_XML_tuned-2).
224 |
225 | ## Evaluation
226 |
227 | ### Bench-RAG
228 |
229 | We have devised a benchmarking strategy, namely **Bench-RAG**, which uses GPT-4o to judge whether the outputs generated by the model are factually accurate even when fictitious data is provided in its prompt. See the [`benchrag/`](benchrag/) folder for more information on how it works and how to run it.
230 |
231 | Below are the results of the finetuned Llama-3.1-8B-Instruct:
232 |
233 | ![Accuracy by steps](media/accuracy_by_steps.png)
234 | ![Helpfulness by steps](media/helpfulness_by_steps.png)
235 | ![Relevance by steps](media/relevance_by_steps.png)
236 | ![Depth by steps](media/depth_by_steps.png)
237 |
238 | ## Acknowledgements
239 | The structure of our codebase is adapted from the [Open-Instruct repository](https://github.com/allenai/open-instruct).
240 |
241 | ## Licence
242 | This codebase is licensed under Apache 2.0 as given in LICENSE.
243 |
--------------------------------------------------------------------------------
/benchrag/README.md:
--------------------------------------------------------------------------------
1 | # Bench-RAG
2 |
3 | Our evaluation tool is designed to assess whether the outputs generated by a model remain factually accurate, even when presented with fictitious data in its prompt. It uses a structured evaluation system powered by OpenAI models (e.g., GPT-4o). In addition to factual accuracy, it also evaluates other key aspects such as helpfulness, relevance, and depth, using tailored prompts for each metric. The tool provides a detailed assessment, offering a True/False rating for accuracy and scores from 1 to 10 for the other dimensions, along with explanations for each evaluation.
4 |
5 | ## Features
6 |
7 | - **Accuracy Evaluation**: Determines if the response introduces any extra information not found in the provided context.
8 | - **Helpfulness Evaluation**: Rates how useful the response is in answering the user's question.
9 | - **Relevance Evaluation**: Scores the response based on how well it addresses the question.
10 | - **Depth Evaluation**: Measures the level of detail provided in the response.
11 |
12 | ## Inferencing prior to Bench-RAG
13 |
14 | Bench-RAG reads `jsonl` file(s), each containing multiple inferences from the same model variant. Each inference should be a JSON object in the following format:
15 |
16 | ```json
17 | {
18 | "filename": "The unique identifier for the file",
19 |   "content": "The non-fictitious content from which the response was supposed to be generated",
20 |   "question": "The user's question",
21 |   "response": "The fine-tuned LLM's generated response to the question. This response should have been generated with both the non-fictitious and the fictitious content provided."
22 | }
23 | ```
24 |
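If you curate the inferences yourself instead of using `benchrag/run_inference.py`, a minimal sketch of writing records in this format could look like the following (the file name and values are placeholders):

```python
import json

# Placeholder records: one JSON object per line, with the fields described above.
records = [
    {
        "filename": "doc_001.pdf",
        "content": "The factual passage the answer should be grounded in...",
        "question": "What does the document say about...?",
        "response": "The fine-tuned model's generated answer...",
    },
]

with open("my_model_inferences.jsonl", "w", encoding="utf-8") as handle:
    for record in records:
        handle.write(json.dumps(record) + "\n")
```
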
25 | You can run inference on your fine-tuned model using the `benchrag/run_inference.py` script from the root directory.
26 |
27 | - The script supports multi-GPU inferencing. Set the appropriate `num_processes` before executing.
28 | - The data used for inference should follow the same `messages` format used during fine-tuning.
29 |
30 | ```bash
31 | # Run on a directory with multiple models, e.g. the checkpoints saved during fine-tuning.
32 | # Note that the script searches for checkpoints prefixed with 'steps_'. Change this to suit your needs.
33 | accelerate launch -m \
34 | --multi_gpu \ # Only indicate multi_gpu flag if you have more than one gpu, else it will throw an error
35 | --num_processes=2 \
36 | --mixed_precision=bf16 \
37 | benchrag.run_inference \
38 | --checkpoints_directory path/to/multiple/fine-tuned/models \
39 | --data_directory path/to/testing/data \
40 | --tokenizer_directory path/to/tokenizer \
41 | --custom_chat_template llama3.1
42 |
43 | # Run on a specific model directory.
44 | accelerate launch -m \
45 | --multi_gpu \ # Only indicate multi_gpu flag if you have more than one gpu, else it will throw an error
46 | --num_processes=2 \
47 | --mixed_precision=bf16 \
48 | benchrag.run_inference \
49 | --specific_checkpoint_directory path/to/fine-tuned/model \
50 | --data_directory path/to/testing/data \
51 | --tokenizer_directory path/to/tokenizer \
52 | --custom_chat_template llama3.1
53 | ```
54 |
55 | > [!IMPORTANT]
56 | > If you are using DeepSpeed for training, remember to convert the ZeRO checkpoint weights (e.g., with the `zero_to_fp32.py` script that DeepSpeed saves alongside each checkpoint) before inference.
57 |
58 | ## Executing Bench-RAG
59 |
60 | With your curated model inferences or the output from `benchrag/run_inference.py`, run the benchmark either on a directory of jsonl files or on a specific jsonl file. Please refer to the official OpenAI documentation for the models available as evaluators.
61 |
62 | ```bash
63 | # Run benchmark on directory with multiple jsonl files from various checkpoints.
64 | python -m benchrag.run_judging \
65 | --openai_evaluator gpt-4o \
66 | --answers_directory path/to/multiple/jsonl/files \
67 | --output_directory path/to/output/directory
68 |
69 | # Run benchmark on a specific jsonl file.
70 | python -m benchrag.run_judging \
71 | --openai_evaluator gpt-4o \
72 | --answers_file path/to/single/jsonl/file.jsonl \
73 | --output_directory path/to/output/directory
74 | ```
75 |
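Each line of the resulting jsonl holds the judge's verdicts for one inference. As a sketch, the fields consumed downstream by `benchrag/get_results.py` look like this (values are placeholders; any explanation fields returned by the judge are omitted here):

```python
# Placeholder example of one judged record, as read by benchrag/get_results.py.
judged_record = {
    "accuracy": True,   # True means no information outside the provided context was introduced
    "helpfulness": 8,   # score from 1 to 10
    "relevance": 9,     # score from 1 to 10
    "depth": 7,         # score from 1 to 10
    "average": 8.0,     # mean of the numeric scores, added by run_judging.py
    "filename": "doc_001.pdf",
}
```
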
76 | ## Collate results
77 |
78 | You can aggregate the results via the `benchrag/get_results.py` script. Simply pass the path to the directory that contains all the generated jsonl files.
79 |
80 | ```bash
81 | python -m benchrag.get_results --directory_path path/to/benchrag/results
82 | ```
83 |
--------------------------------------------------------------------------------
/benchrag/get_results.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from pathlib import Path
5 |
6 | def read_jsonl(file_path):
7 | """Reads a JSONL file and returns the data as a list of dictionaries."""
8 | data = []
9 | with open(file_path, 'r') as file:
10 | for line in file:
11 | data.append(json.loads(line.strip()))
12 | return data
13 |
14 |
15 | def compute_summary_metrics(data):
16 | """Computes summary metrics from a list of data dictionaries."""
17 | metrics = {
18 | 'total_records': 0,
19 | 'accuracy_true_count': 0,
20 | 'accuracy_false_count': 0,
21 | 'total_helpfulness': 0,
22 | 'total_relevance': 0,
23 | 'total_depth': 0,
24 | 'total_average': 0.0,
25 | }
26 |
27 | # Iterate over each record to update the metrics
28 | for record in data:
29 | metrics['total_records'] += 1
30 | if record['accuracy']:
31 | metrics['accuracy_true_count'] += 1
32 | else:
33 | metrics['accuracy_false_count'] += 1
34 | metrics['total_helpfulness'] += record['helpfulness']
35 | metrics['total_relevance'] += record['relevance']
36 | metrics['total_depth'] += record['depth']
37 | metrics['total_average'] += record['average']
38 |
39 | # Calculate averages if there are any records
40 | if metrics['total_records'] > 0:
41 | metrics['average_helpfulness'] = (
42 | metrics['total_helpfulness'] / metrics['total_records']
43 | )
44 | metrics['average_relevance'] = (
45 | metrics['total_relevance'] / metrics['total_records']
46 | )
47 | metrics['average_depth'] = metrics['total_depth'] / metrics['total_records']
48 | metrics['average_average'] = metrics['total_average'] / metrics['total_records']
49 | else:
50 | metrics['average_helpfulness'] = 0
51 | metrics['average_relevance'] = 0
52 | metrics['average_depth'] = 0
53 | metrics['average_average'] = 0
54 |
55 | return metrics
56 |
57 |
58 | def process_directory(directory_path: Path):
59 | """Processes all JSONL files in the specified directory."""
60 | filenames = sorted(os.listdir(directory_path))
61 | for filename in filenames:
62 | if filename.endswith('.jsonl'):
63 | file_path = os.path.join(directory_path, filename)
64 | data = read_jsonl(file_path)
65 | metrics = compute_summary_metrics(data)
66 | print(f'Summary Metrics for {filename}:')
67 | print(f"Total Records: {metrics['total_records']}")
68 |             accuracy = metrics['accuracy_true_count'] / metrics['total_records'] * 100 if metrics['total_records'] else 0.0
69 | print(f'Accuracy: {accuracy:.2f}%')
70 |             print(f"Average Helpfulness: {metrics['average_helpfulness']:.2f}")
71 | print(f"Average Relevance: {metrics['average_relevance']:.2f}")
72 | print(f"Average Depth: {metrics['average_depth']:.2f}")
73 | print(f"Average of Averages: {metrics['average_average']:.2f}")
74 | print('')
75 |
76 | if __name__ == "__main__":
77 | from jsonargparse import CLI
78 | CLI(process_directory, as_positional=False)
79 |
--------------------------------------------------------------------------------
/benchrag/plot_results.py:
--------------------------------------------------------------------------------
1 | """
2 | Sample usage: python -m benchrag.plot_results --result_directory ragbench/results/Llama-3.1-8B-Instruct-Baseline ragbench/results/Llama-3.1-8B-Instruct-XML ragbench/results/Llama-3.1-8B-Instruct-Enhanced --output_directory ragbench/judge_results/
3 | """
4 |
5 | import json
6 | import re
7 | from argparse import ArgumentParser
8 | from pathlib import Path
9 |
10 | from matplotlib import pyplot
11 |
12 |
13 | def start(result_paths: list[Path], output_directory: Path):
14 | aggregated_results = []
15 | for result_path in result_paths:
16 | aggregated_results.extend(aggregate_scores(result_path))
17 |
18 | plot_scores_by_metric(aggregated_results, output_directory)
19 |
20 |
21 | def aggregate_scores(result_path: Path):
22 | jsonl_file_paths = list(result_path.glob('*.jsonl'))
23 |
24 | results = []
25 | for jsonl_file_path in jsonl_file_paths:
26 | total_accuracy = 0
27 | total_helpfulness = 0
28 | total_relevance = 0
29 | total_depth = 0
30 | count = 0
31 |
32 | with open(jsonl_file_path, 'r') as jsonl_file:
33 | for jsonl_line in jsonl_file:
34 | jsonl_data = json.loads(jsonl_line.strip())
35 |
36 | total_accuracy += jsonl_data.get('accuracy', 0)
37 | total_helpfulness += jsonl_data.get('helpfulness', 0)
38 | total_relevance += jsonl_data.get('relevance', 0)
39 | total_depth += jsonl_data.get('depth', 0)
40 | count += 1
41 |
42 | avg_accuracy = total_accuracy / count
43 | avg_helpfulness = total_helpfulness / count
44 | avg_relevance = total_relevance / count
45 | avg_depth = total_depth / count
46 |
47 | results.append(
48 | {
49 |                 'template_type': result_path.name.rsplit('-', 1)[-1],  # e.g. 'Baseline' or 'XML' from the results directory name
50 | 'file': jsonl_file_path.name,
51 | 'avg_accuracy': avg_accuracy,
52 | 'avg_helpfulness': avg_helpfulness,
53 | 'avg_relevance': avg_relevance,
54 | 'avg_depth': avg_depth,
55 | }
56 | )
57 |
58 | return results
59 |
60 |
61 | def extract_number(filename):
62 | match = re.search(r'steps_(\d+)', filename)
63 | return int(match.group(1)) if match else float('inf')
64 |
65 |
66 | def plot_scores_by_metric(aggregated_results: list[dict], output_directory: Path):
67 | template_types = list(set(result['template_type'] for result in aggregated_results))
68 |
69 | output_directory.mkdir(parents=True, exist_ok=True)
70 |
71 | metrics = ['accuracy', 'helpfulness', 'relevance', 'depth']
72 | metric_labels = ['Accuracy', 'Helpfulness', 'Relevance', 'Depth']
73 |
74 | for i, metric in enumerate(metrics):
75 | fig, ax = pyplot.subplots(figsize=(12, 8))
76 |
77 | for template_type in template_types:
78 | files = [
79 | result['file']
80 | for result in aggregated_results
81 | if result['template_type'] == template_type
82 | ]
83 | values = [
84 | result[f'avg_{metric}']
85 | for result in aggregated_results
86 | if result['template_type'] == template_type
87 | ]
88 |
89 | sorted_pairs = sorted(
90 | zip(files, values), key=lambda x: extract_number(x[0])
91 | )
92 | sorted_files, sorted_values = zip(*sorted_pairs)
93 |             sorted_files = list(map(lambda name: name[:-6], sorted_files))  # strip the '.jsonl' extension for the x-axis labels
94 | ax.plot(sorted_files, sorted_values, marker='o', label=f'{template_type}')
95 |
96 | ax.set_title(f'Average {metric_labels[i]} of Finetuned Llama-3.1-8B-Instruct')
97 | ax.set_xlabel('Checkpoint')
98 | ax.set_ylabel(f'Average {metric_labels[i]}')
99 | ax.legend()
100 | ax.tick_params(axis='x', rotation=45)
101 |
102 | # Save the PNG file
103 | pyplot.tight_layout()
104 | pyplot.savefig(output_directory / f'{metric}_by_steps.png')
105 | pyplot.close()
106 |
107 |
108 | if __name__ == '__main__':
109 | parser = ArgumentParser(description='Plot line graphs of ragbench results.')
110 | parser.add_argument(
111 | '--result_directory',
112 | type=str,
113 | nargs='+',
114 | required=True,
115 | help='Directories containing ragbench results in JSONL files. Multiple directories can be provided.',
116 | )
117 | parser.add_argument(
118 | '--output_directory',
119 | type=str,
120 | required=True,
121 | help='Output directory to save the plots.',
122 | )
123 |
124 | arguments = parser.parse_args()
125 |
126 | result_paths = [
127 | Path(result_directory) for result_directory in arguments.result_directory
128 | ]
129 | output_path = Path(arguments.output_directory)
130 | start(result_paths, output_path)
131 |
--------------------------------------------------------------------------------
/benchrag/run_inference.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 |
4 | from accelerate import Accelerator
5 | from utils.dataset_utils import load_jsonl_file
6 | from dataclasses import dataclass, field
7 | from utils.logger import setup_logger
8 | from prompts.prompt_styles import PromptStyle
9 | from finetunerag.utils import ArgumentParserPlus
10 | from pathlib import Path
11 | from transformers import AutoModelForCausalLM, AutoTokenizer
12 | from typing import Optional
13 |
14 | @dataclass
15 | class InferencingArguments:
16 | """
17 | Full arguments class for inferencing from checkpoints.
18 | """
19 |
20 | checkpoints_directory: Optional[str] = field(
21 | default=None,
22 | metadata={
23 | 'help': 'Checkpoints directory containing the checkpoints of a model. This cannot be provided along with --specific_checkpoint_directory.'
24 | },
25 | )
26 | specific_checkpoint_directory: Optional[str] = field(
27 | default=None,
28 | metadata={'help': 'Path to a specific checkpoint folder. This cannot be provided along with --checkpoints_directory.'},
29 | )
30 | data_directory: Optional[str] = field(
31 | default=None,
32 | metadata={'help': 'Data directory containing the content used to generate the prompts.'},
33 | )
34 | output_directory: str = field(
35 | default='ragbench/inferences/',
36 | metadata={'help': 'Output directory to save the generated answers.'},
37 | )
38 | tokenizer_directory: Optional[str] = field(
39 | default=None,
40 | metadata={'help': 'Path to the tokenizer directory.'},
41 | )
42 | custom_chat_template: Optional[str] = field(
43 | default=None,
44 | metadata={'help': 'Name of custom chat template if the chat template from tokenizer is not desired.'},
45 | )
46 |
47 | def __post_init__(self):
48 |         if self.checkpoints_directory is None and self.specific_checkpoint_directory is None:
49 | raise ValueError("No checkpoints directory or specific checkpoint directory provided.")
50 |
51 |         if self.checkpoints_directory is not None and self.specific_checkpoint_directory is not None:
52 | raise ValueError("Both checkpoints directory and specific checkpoint directory provided. Please only provide either one.")
53 |
54 |         if self.data_directory is None:
55 | raise ValueError('No data directory provided. Unable to generate from model.')
56 |
57 |         if self.tokenizer_directory is None:
58 | raise ValueError('No tokenizer directory provided.')
59 |
60 | self.data_path = Path(self.data_directory)
61 | self.output_path = Path(self.output_directory)
62 | self.tokenizer_path = Path(self.tokenizer_directory)
63 | self.checkpoints_path = Path(self.checkpoints_directory) if self.checkpoints_directory else None
64 | self.specific_checkpoint_path = Path(self.specific_checkpoint_directory) if self.specific_checkpoint_directory else None
65 |
66 | # Global logger
67 | logger = setup_logger(Path(__file__).name)
68 |
69 | def start(args: InferencingArguments):
70 | dataset = load_jsonl_file(args.data_path)
71 | args.output_path.mkdir(parents=True, exist_ok=True)
72 |
73 | if args.checkpoints_path:
74 | # Retrieve all checkpoints available from the path to the checkpoints
75 | checkpoint_paths = list(args.checkpoints_path.glob('steps_*'))
76 | # Sort the checkpoints by their step count
77 |         checkpoint_paths = sorted(checkpoint_paths, key=lambda checkpoint_folder: int(checkpoint_folder.name.rsplit('_', 1)[-1]))
78 | logger.debug(f'These are the checkpoints identified: {checkpoint_paths}')
79 | else:
80 | checkpoint_paths = [Path(args.specific_checkpoint_directory)]
81 |
82 | for checkpoint_path in checkpoint_paths:
83 | generate_responses(checkpoint_path, args.tokenizer_path, args.output_path, dataset, args.custom_chat_template)
84 | logger.info(f"Finished processing {checkpoint_path}.")
85 |
86 | logger.info(f"Inference for all checkpoints complete!")
87 |
88 |
89 | def generate_responses(
90 | checkpoint_path: Path,
91 | tokenizer_path: Path,
92 | output_path: Path,
93 | dataset: list,
94 | custom_chat_template: str
95 | ):
96 | accelerator = Accelerator()
97 |
98 | prompt_styler = PromptStyle.from_name(custom_chat_template)
99 |
100 | model = AutoModelForCausalLM.from_pretrained(
101 | checkpoint_path,
102 | torch_dtype=torch.bfloat16,
103 | attn_implementation='flash_attention_2',
104 | device_map={'': accelerator.process_index},
105 | )
106 | model = torch.compile(model)
107 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
108 |
109 | model, tokenizer = accelerator.prepare(model, tokenizer)
110 | accelerator.wait_for_everyone() # Synchronise all processes to ensure readiness before starting generation
111 |
112 | generated_responses = []
113 |
114 | with accelerator.split_between_processes(dataset) as documents:
115 | for index, datarow in enumerate(documents):
116 | prompt_text = prompt_styler.apply(datarow['messages'], append_assistant_header=True)
117 | inputs = tokenizer(prompt_text, return_tensors="pt").to(accelerator.device)
118 | generation_output = model.generate(
119 | inputs.input_ids,
120 | max_length=6000,
121 | do_sample=False,
122 | temperature=None,
123 | pad_token_id=tokenizer.eos_token_id,
124 | )
125 |
126 |             generated_ids = generation_output[0][inputs.input_ids.shape[1]:-1]  # drop the prompt tokens and the trailing EOS token
127 | response = tokenizer.decode(generated_ids)
128 |
129 | generated_responses.append({
130 | 'filename': datarow['filename'],
131 | 'content': datarow['content'],
132 | 'question': datarow['question'],
133 | 'response': response
134 | })
135 |
136 | gathered_responses = accelerator.gather(generated_responses)
137 |
138 | if accelerator.is_main_process:
139 | response_file_path = output_path / f'{checkpoint_path.name}.jsonl'
140 | with open(response_file_path, 'w') as response_file:
141 | for response_data in gathered_responses:
142 | response_file.write(json.dumps(response_data) + '\n')
143 |
144 | accelerator.wait_for_everyone()
145 |
146 | if __name__ == '__main__':
147 | parser = ArgumentParserPlus((InferencingArguments))
148 | args = parser.parse()
149 | start(args)
150 |
--------------------------------------------------------------------------------
/benchrag/run_judging.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numbers
3 | from dataclasses import dataclass, field
4 | from pathlib import Path
5 | from typing import List, Optional
6 |
7 | from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
8 |
9 | from utils.dataset_utils import (
10 | load_jsonl_file,
11 | )
12 | from finetunerag.utils import ArgumentParserPlus
13 | from utils.openai import call_openai_api
14 | from prompts.judging_prompts import (
15 | OpenAIJudgeResponse,
16 | get_judge_user_prompt,
17 | judge_accuracy_system_prompt,
18 | judge_depth_system_prompt,
19 | judge_helpfulness_system_prompt,
20 | judge_relevance_system_prompt,
21 | )
22 | from utils.logger import setup_logger
23 |
24 |
25 | @dataclass
26 | class InferencingArguments:
27 | """
28 |     Full arguments class for judging model inferences.
29 | """
30 |
31 | answers_directory: Optional[str] = field(
32 | default=None,
33 | metadata={
34 | 'help': 'Path to an answers directory that potentially contains multiple set of answers from a model in the form of jsonl files. This cannot be provided along with --answers_file.'
35 | },
36 | )
37 | answers_file: Optional[str] = field(
38 | default=None,
39 | metadata={
40 | 'help': 'Path to a specific answers file from a model in the form of jsonl. This cannot be provided along with --answers_directory.'
41 | },
42 | )
43 | openai_evaluator: str = field(
44 | default='gpt-4o',
45 | metadata={
46 | 'help': 'The evaluator to use from OpenAI. Refer to their documentation for available models.'
47 | },
48 | )
49 | output_directory: str = field(
50 | default='ragbench/judged_scores',
51 | metadata={'help': 'Output directory to save the judging results.'},
52 | )
53 |
54 | def __post_init__(self):
55 | if self.answers_directory is None and self.answers_file is None:
56 | raise ValueError('No answers directory or answers file provided.')
57 |
58 | if self.answers_directory is not None and self.answers_file is not None:
59 | raise ValueError(
60 | 'Both answers directory and answers file provided. Please only provide either one.'
61 | )
62 |
63 | self.output_path = Path(self.output_directory)
64 | self.answers_files_path = (
65 | Path(self.answers_directory) if self.answers_directory else None
66 | )
67 | self.answers_file_path = Path(self.answers_file) if self.answers_file else None
68 |
69 |
70 | JUDGE_SYSTEM_PROMPTS = [
71 | judge_accuracy_system_prompt,
72 | judge_helpfulness_system_prompt,
73 | judge_relevance_system_prompt,
74 | judge_depth_system_prompt,
75 | ]
76 |
77 | # Global logger
78 | logger = setup_logger(Path(__file__).name)
79 |
80 |
81 | def start(args: InferencingArguments):
82 |     files_to_judge: List[Path] = []
83 | if args.answers_files_path:
84 | for file_path in args.answers_files_path.iterdir():
85 | if file_path.is_file() and file_path.suffix == '.jsonl':
86 | files_to_judge.append(file_path)
87 | else:
88 | files_to_judge.append(args.answers_file_path)
89 |
90 | logger.debug(
91 | f'These are the files identified: {list(map(lambda file_path: file_path.name, files_to_judge))}'
92 | )
93 | for file_to_judge in files_to_judge:
94 | aggregate_file(file_to_judge, args)
95 |
96 |
97 | def aggregate_file(answers_file_path: Path, args: InferencingArguments):
98 | cumulative_stats = {'scores': 0, 'n': 0, 'failed': 0, 'trues': 0, 'falses': 0}
99 |
100 | documents = load_jsonl_file(answers_file_path)
101 |
102 | logger.debug(f'Total of {len(documents)} samples to rate.')
103 |
104 | args.output_path.mkdir(parents=True, exist_ok=True)
105 | output_file_path = args.output_path / answers_file_path.name
106 |
107 | files_done = set()
108 | # resume from checkpoint
109 | if output_file_path.is_file():
110 | output_list = load_jsonl_file(output_file_path)
111 |
112 | for json_row in output_list:
113 | files_done.add(json_row['filename'])
114 |
115 | # Update statistics
116 | if json_row['accuracy']:
117 | cumulative_stats['trues'] += 1
118 | else:
119 | cumulative_stats['falses'] += 1
120 |
121 | cumulative_stats['n'] += 1
122 | if json_row['average'] is not None:
123 | cumulative_stats['scores'] += json_row['average']
124 |
125 | logger.info(f"{cumulative_stats['n']} sample(s) already done.")
126 |
127 | # call openAI
128 | for index, document in enumerate(documents):
129 | if document['filename'] in files_done:
130 | continue
131 |
132 | try:
133 | logger.info(f"Attempting to judge index {index}: {document['filename']}...")
134 | parsed_response: OpenAIJudgeResponse = judge(
135 | document=document, evaluator=args.openai_evaluator
136 | )
137 |
138 | except Exception as error:
139 | logger.error(f"Filename: [{document['filename']}] errored.")
140 | logger.error(error)
141 | cumulative_stats['failed'] += 1
142 | continue
143 |
144 | # Compute the average score
145 | scores: List[numbers.Real] = [
146 | score
147 | for score in parsed_response.values()
148 | if isinstance(score, numbers.Real) and not isinstance(score, bool)
149 | ]
150 | if len(scores) > 0:
151 | parsed_response['average'] = sum(scores) / len(scores)
152 | else:
153 | parsed_response['average'] = None
154 |
155 | # Add the filename into the dataset
156 | parsed_response['filename'] = document['filename']
157 |
158 | parsed_response_str = json.dumps(parsed_response)
159 | with open(output_file_path, 'a', encoding='utf-8') as responses_file:
160 | responses_file.write(parsed_response_str + '\n')
161 |
162 | # Update statistics
163 | cumulative_stats['n'] += 1
164 | logger.info(f"{cumulative_stats['n']}/{len(documents)} inferences done")
165 |
166 | if parsed_response['accuracy']:
167 | cumulative_stats['trues'] += 1
168 | else:
169 | cumulative_stats['falses'] += 1
170 |
171 | if parsed_response['average'] is not None:
172 | cumulative_stats['scores'] += parsed_response['average']
173 |
174 | logger.info(cumulative_stats)
175 | logger.info(f"Final average: {cumulative_stats['scores'] / cumulative_stats['n']}")
176 |
177 |
178 | def judge(
179 | document,
180 | evaluator,
181 | ) -> OpenAIJudgeResponse:
182 | parsed_responses = {}
183 | for judge_system_prompt in JUDGE_SYSTEM_PROMPTS:
184 | messages: List[ChatCompletionMessageParam] = []
185 |
186 | user_prompt = get_judge_user_prompt(document)
187 |
188 | messages.append(judge_system_prompt)
189 | messages.append(user_prompt)
190 |
191 | parsed_response = call_openai_api(messages=messages, model=evaluator, output_as_json=True)
192 | parsed_responses.update(parsed_response)
193 |
194 | return parsed_responses
195 |
196 |
197 | if __name__ == '__main__':
198 | parser = ArgumentParserPlus((InferencingArguments))
199 | args = parser.parse()
200 | start(args)
201 |
--------------------------------------------------------------------------------
/configs/deepspeed/stage2_no_offloading_accelerate.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 2,
7 | "overlap_comm": true,
8 | "contiguous_gradients": true,
9 | "sub_group_size": 1e9,
10 | "reduce_bucket_size": "auto"
11 | },
12 | "gradient_accumulation_steps": "auto",
13 | "gradient_clipping": "auto",
14 | "steps_per_print": 1e5,
15 | "train_batch_size": "auto",
16 | "train_micro_batch_size_per_gpu": "auto",
17 | "wall_clock_breakdown": false
18 | }
--------------------------------------------------------------------------------
/configs/deepspeed/stage2_offloading_accelerate.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 2,
7 | "offload_optimizer": {
8 | "device": "cpu",
9 | "pin_memory": true
10 | },
11 | "overlap_comm": true,
12 | "contiguous_gradients": true,
13 | "sub_group_size": 1e9,
14 | "reduce_bucket_size": "auto"
15 | },
16 | "gradient_accumulation_steps": "auto",
17 | "gradient_clipping": "auto",
18 | "steps_per_print": 1e5,
19 | "train_batch_size": "auto",
20 | "train_micro_batch_size_per_gpu": "auto",
21 | "wall_clock_breakdown": false
22 | }
--------------------------------------------------------------------------------
/configs/deepspeed/stage3_no_offloading.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "overlap_comm": true,
26 | "contiguous_gradients": true,
27 | "sub_group_size": 1e9,
28 | "reduce_bucket_size": "auto",
29 | "stage3_prefetch_bucket_size": "auto",
30 | "stage3_param_persistence_threshold": "auto",
31 | "stage3_max_live_parameters": 1e9,
32 | "stage3_max_reuse_distance": 1e9,
33 | "stage3_gather_16bit_weights_on_model_save": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 1e5,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/configs/deepspeed/stage3_no_offloading_accelerate.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 3,
7 | "overlap_comm": true,
8 | "contiguous_gradients": true,
9 | "sub_group_size": 1e9,
10 | "reduce_bucket_size": "auto",
11 | "stage3_prefetch_bucket_size": "auto",
12 | "stage3_param_persistence_threshold": "auto",
13 | "stage3_max_live_parameters": 1e9,
14 | "stage3_max_reuse_distance": 1e9,
15 | "stage3_gather_16bit_weights_on_model_save": true
16 | },
17 | "gradient_accumulation_steps": "auto",
18 | "gradient_clipping": "auto",
19 | "steps_per_print": 1e5,
20 | "train_batch_size": "auto",
21 | "train_micro_batch_size_per_gpu": "auto",
22 | "wall_clock_breakdown": false
23 | }
--------------------------------------------------------------------------------
/configs/deepspeed/stage3_offloading.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "cpu",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "cpu",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": true
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "steps_per_print": 1e5,
46 | "train_batch_size": "auto",
47 | "train_micro_batch_size_per_gpu": "auto",
48 | "wall_clock_breakdown": false
49 | }
--------------------------------------------------------------------------------
/configs/deepspeed/stage3_offloading_accelerate.conf:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 3,
7 | "offload_optimizer": {
8 | "device": "cpu",
9 | "pin_memory": true
10 | },
11 | "offload_param": {
12 | "device": "cpu",
13 | "pin_memory": true
14 | },
15 | "overlap_comm": true,
16 | "contiguous_gradients": true,
17 | "sub_group_size": 1e9,
18 | "reduce_bucket_size": "auto",
19 | "stage3_prefetch_bucket_size": "auto",
20 | "stage3_param_persistence_threshold": "auto",
21 | "stage3_max_live_parameters": 1e9,
22 | "stage3_max_reuse_distance": 1e9,
23 | "stage3_gather_16bit_weights_on_model_save": true
24 | },
25 | "gradient_accumulation_steps": "auto",
26 | "gradient_clipping": "auto",
27 | "steps_per_print": 1e5,
28 | "train_batch_size": "auto",
29 | "train_micro_batch_size_per_gpu": "auto",
30 | "wall_clock_breakdown": false
31 | }
32 |
--------------------------------------------------------------------------------
/finetunerag/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from dataclasses import dataclass, field
4 | from typing import Optional
5 |
6 |
7 | @dataclass
8 | class FinetuneArguments:
9 | """
10 | Full arguments class for fine-tuning.
11 | """
12 |
13 | exp_name: str = field(
14 | default=os.path.basename(__file__)[: -len('.py')],
15 | metadata={
16 | 'help': (
17 | "The name of this experiment."
18 | )
19 | },
20 | )
21 | model_name_or_path: Optional[str] = field(
22 | default=None,
23 | metadata={
24 | 'help': (
25 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
26 | )
27 | },
28 | )
29 | tokenizer_name_or_path: Optional[str] = field(
30 | default=None,
31 | metadata={
32 | 'help': 'Pretrained tokenizer name or path if not the same as model_name_or_path'
33 | },
34 | )
35 | tokenizer_revision: Optional[str] = field(
36 | default='main',
37 | metadata={
38 |             'help': 'The specific tokenizer version to use (can be a branch name, tag name or commit id).'
39 | },
40 | )
41 | prompt_style: Optional[str] = field(
42 | default='default',
43 | metadata={
44 | 'help': 'The specific prompt template to use (should be registered under one of the custom prompts).'
45 | },
46 | )
47 | use_flash_attn: bool = field(
48 | default=True,
49 | metadata={'help': 'Whether to use flash attention in the model training'},
50 | )
51 | model_revision: str = field(
52 | default='main',
53 | metadata={
54 | 'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
55 | },
56 | )
57 | trust_remote_code: bool = field(
58 | default=False,
59 | metadata={
60 | 'help': (
61 | 'Whether or not to allow for custom models defined on the Hub in their own modeling files. '
62 | 'This option should only be set to `True` for repositories you trust and in which you '
63 | 'have read the code, as it will execute code present on the Hub on your local machine.'
64 | )
65 | },
66 | )
67 | low_cpu_mem_usage: bool = field(
68 | default=False,
69 | metadata={
70 | 'help': (
71 |                 'Option to create the model as an empty shell and only materialize '
72 |                 'its parameters when the pretrained weights are loaded. '
73 |                 'Setting this to True reduces LLM loading time and RAM consumption.'
74 | )
75 | },
76 | )
77 | train_file: Optional[str] = field(
78 | default=None,
79 | metadata={'help': 'The input training data file (a json/jsonl file).'},
80 | )
81 | validation_file: Optional[str] = field(
82 | default=None,
83 | metadata={'help': 'The input validation file (a json/jsonl file).'},
84 | )
85 | max_train_samples: Optional[int] = field(
86 | default=None,
87 | metadata={
88 | 'help': (
89 | 'For debugging purposes or quicker training, truncate the number of training examples to this '
90 | 'value if set.'
91 | )
92 | },
93 | )
94 | preprocessing_num_workers: Optional[int] = field(
95 | default=None,
96 | metadata={'help': 'The number of processes to use for the preprocessing.'},
97 | )
98 | max_seq_length: Optional[int] = field(
99 | default=None,
100 | metadata={
101 | 'help': (
102 | 'The maximum total input sequence length after tokenization. '
103 |                 'Sequences longer than this will be truncated.'
104 | )
105 | },
106 | )
107 | overwrite_cache: bool = field(
108 | default=False,
109 | metadata={'help': 'Overwrite the cached training and evaluation sets'},
110 | )
111 | add_bos: bool = field(
112 | default=False,
113 | metadata={
114 | 'help': 'Forcibly add bos token to the beginning of the input sequence.'
115 | ' Use only when tokenizer does not add bos token by default.'
116 | },
117 | )
118 | clip_grad_norm: float = field(
119 | default=-1,
120 | metadata={
121 | 'help': 'Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead).'
122 | },
123 | )
124 | gradient_accumulation_steps: int = field(
125 | default=1,
126 | metadata={
127 | 'help': 'Number of updates steps to accumulate before performing a backward/update pass.'
128 | },
129 | )
130 | learning_rate: float = field(
131 | default=2e-5,
132 | metadata={'help': 'The initial learning rate for AdamW optimizer.'},
133 | )
134 | beta1: float = field(
135 | default=0.9,
136 | metadata={
137 |             'help': 'The beta1 coefficient of AdamW, i.e. the exponential decay rate for the running average of the gradient (first moment).'
138 | },
139 | )
140 | beta2: float = field(
141 | default=0.95,
142 | metadata={
143 |             'help': 'The beta2 coefficient of AdamW, i.e. the exponential decay rate for the running average of the squared gradient (second moment).'
144 | },
145 | )
146 | logging_steps: Optional[int] = field(
147 | default=None,
148 | metadata={
149 | 'help': 'Log the training loss and learning rate every logging_steps steps.'
150 | },
151 | )
152 | lr_scheduler_type: str = field(
153 | default='linear',
154 | metadata={
155 | 'help': 'The scheduler type to use for learning rate adjustment.',
156 | 'choices': [
157 | 'linear',
158 | 'cosine',
159 | 'cosine_with_restarts',
160 | 'polynomial',
161 | 'constant',
162 | 'constant_with_warmup',
163 | ],
164 | },
165 | )
166 | num_train_epochs: int = field(
167 | default=2,
168 | metadata={'help': 'Total number of training epochs to perform.'},
169 | )
170 | output_dir: str = field(
171 | default='output/',
172 | metadata={
173 | 'help': 'The output directory where the model predictions and checkpoints will be written.'
174 | },
175 | )
176 | per_device_train_batch_size: int = field(
177 | default=8,
178 | metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'},
179 | )
180 | use_8bit_optimizer: bool = field(
181 | default=False,
182 | metadata={
183 | 'help': 'Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed.'
184 | },
185 | )
186 | warmup_ratio: float = field(
187 | default=0.03,
188 | metadata={'help': 'Linear warmup over warmup_ratio fraction of total steps.'},
189 | )
190 | weight_decay: float = field(
191 | default=0.0,
192 | metadata={'help': 'Weight decay for AdamW if we apply some.'},
193 | )
194 | timeout: int = field(
195 | default=1800,
196 | metadata={
197 | 'help': 'Timeout for the training process in seconds.'
198 |             ' Useful if the tokenization process is long. Default is 1800 seconds (30 minutes).'
199 | },
200 | )
201 | reduce_loss: str = field(
202 | default='mean',
203 | metadata={
204 | 'help': "How to reduce loss over tokens. Options are 'mean' or 'sum'."
205 |             " Using 'sum' can improve chat model performance."
206 | },
207 | )
208 | resume_from_checkpoint: Optional[str] = field(
209 | default=None,
210 | metadata={'help': 'If the training should continue from a checkpoint folder.'},
211 | )
212 | enable_wandb: bool = field(
213 | default=False,
214 | metadata={'help': 'Whether to enable wandb for logging.'},
215 | )
216 | wandb_entity: Optional[str] = field(
217 | default=None,
218 | metadata={'help': 'Entity to use for logging to wandb.'},
219 | )
220 | wandb_project: Optional[str] = field(
221 | default='test-runs',
222 | metadata={'help': 'Project name to use when logging to wandb.'},
223 | )
224 | wandb_name: Optional[str] = field(
225 | default='wandb',
226 | metadata={'help': 'Run name to use when logging to wandb.'},
227 | )
228 | gradient_checkpointing: bool = field(
229 | default=False,
230 | metadata={
231 | 'help': 'Turn on gradient checkpointing. Saves memory but slows training.'
232 | },
233 | )
234 | max_train_steps: Optional[int] = field(
235 | default=None,
236 | metadata={
237 | 'help': 'If set, overrides the number of training steps. Otherwise, num_train_epochs is used.'
238 | },
239 | )
240 | seed: int = field(
241 | default=42,
242 | metadata={'help': 'Random seed for initialization and dataset shuffling.'},
243 | )
244 | checkpointing_steps: Optional[str] = field(
245 | default=None,
246 | metadata={
247 | 'help': "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch." # noqa
248 | },
249 | )
250 | validation_steps: Optional[int] = field(
251 | default=None,
252 | metadata={
253 |             'help': 'Compute the loss on the validation data at the end of every n steps.'
254 | },
255 | )
256 | keep_last_n_checkpoints: Optional[int] = field(
257 | default=None,
258 | metadata={
259 | 'help': 'How many checkpoints to keep in the output directory. -1 for all.'
260 | },
261 | )
262 | overwrite_output_dir: bool = field(
263 | default=False,
264 | metadata={
265 | 'help': 'Overwrite the content of the output directory. Means that resumption will always start from scratch.'
266 | },
267 | )
268 | fused_optimizer: bool = field(
269 | default=True,
270 | metadata={
271 | 'help': 'Whether to use fused AdamW or not.',
272 | },
273 | )
274 |
275 | def __post_init__(self):
276 | if self.model_name_or_path is None:
277 | raise ValueError('The Hugging Face model name or path is not indicated.')
278 |
279 | if self.tokenizer_name_or_path is None:
280 | warnings.warn('The tokenizer name or path is not indicated. Defaulting it to model_name_or_path.')
281 | self.tokenizer_name_or_path = self.model_name_or_path
282 |
283 | if self.reduce_loss not in ['mean', 'sum']:
284 | raise ValueError("reduce_loss must be either 'mean' or 'sum'")
285 |
286 | if self.train_file is None:
287 |             raise ValueError("A training file is required. Please provide it with '--train_file'.")
288 | else:
289 | extension = self.train_file.split('.')[-1]
290 | assert extension in [
291 | 'json',
292 | 'jsonl',
293 | ], '`train_file` should be a json or a jsonl file.'
294 |
295 | if self.validation_steps and not self.validation_file:
296 | raise ValueError(
297 | "The number of steps for every validation is indicated. However, the path to the validation dataset is not provided. Please provide it with '--validation_file'."
298 | )
299 |
300 | if self.validation_file and not self.validation_steps:
301 | warnings.warn(
302 | 'The path to the validation dataset is provided. However, the number of steps for every validation is not indicated. Defaulted to 1.'
303 | )
304 | self.validation_steps = 1
--------------------------------------------------------------------------------
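
FinetuneArguments is an ordinary dataclass, so any HfArgumentParser-compatible parser can populate it; finetune.py below does this through the repo's own ArgumentParserPlus wrapper. A minimal sketch of parsing it standalone (the model name and file paths are placeholders, not values from this repo):

from transformers import HfArgumentParser

from finetunerag.arguments import FinetuneArguments

parser = HfArgumentParser(FinetuneArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=[
        '--model_name_or_path', 'some-org/some-base-model',   # placeholder checkpoint
        '--train_file', 'dataset/train.jsonl',                # placeholder path
        '--validation_file', 'dataset/validation.jsonl',      # placeholder path
        '--validation_steps', '50',
        '--output_dir', 'output/',
    ]
)
# __post_init__ has already checked the file extension and defaulted
# tokenizer_name_or_path to model_name_or_path.
print(args.tokenizer_name_or_path, args.learning_rate)
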
/finetunerag/finetune.py:
--------------------------------------------------------------------------------
1 | # Finetuning code referenced from https://github.com/allenai/open-instruct
2 |
3 | import logging
4 | import math
5 | import os
6 | import random
7 | from datetime import timedelta
8 | from functools import partial
9 |
10 | import datasets
11 | import deepspeed
12 | import torch
13 | import transformers
14 | from accelerate import Accelerator
15 | from accelerate.logging import get_logger
16 | from accelerate.utils import InitProcessGroupKwargs, set_seed
17 | from datasets import load_dataset
18 | from torch.utils.data import DataLoader
19 | from tqdm.auto import tqdm
20 | from transformers import (
21 | AutoConfig,
22 | AutoModelForCausalLM,
23 | AutoTokenizer,
24 | DataCollatorForSeq2Seq,
25 | get_scheduler,
26 | )
27 | print("Current Working Directory:", os.getcwd())
28 |
29 | from prompts.prompt_styles import PromptStyle
30 | from finetunerag.arguments import FinetuneArguments
31 | from finetunerag.model_utils import save_with_accelerate
32 | from finetunerag.utils import (
33 | ArgumentParserPlus,
34 | clean_last_n_checkpoints,
35 | get_last_checkpoint_path,
36 | get_wandb_tags,
37 | )
38 |
39 | logger = get_logger(__name__)
40 |
41 | def main(args: FinetuneArguments):
42 |
43 | ##########################
44 | # Initialise Accelerator #
45 | ##########################
46 | accelerator_log_kwargs = {}
47 |
48 | if args.enable_wandb:
49 | accelerator_log_kwargs['log_with'] = 'wandb'
50 | accelerator_log_kwargs['project_dir'] = args.output_dir
51 |
52 | # If you get timeouts (e.g. due to long tokenization) increase this.
53 | timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
54 |
55 | accelerator = Accelerator(
56 | gradient_accumulation_steps=args.gradient_accumulation_steps,
57 | use_seedable_sampler=True,
58 | **accelerator_log_kwargs,
59 | kwargs_handlers=[timeout_kwargs],
60 | )
61 |
62 | if args.seed is not None:
63 | set_seed(args.seed)
64 |
65 | if args.enable_wandb:
66 | experiment_config = vars(args)
67 |
68 | accelerator.init_trackers(
69 | args.wandb_project,
70 | experiment_config,
71 | init_kwargs={
72 | 'wandb': {
73 | 'entity': args.wandb_entity,
74 | 'name': args.wandb_name,
75 | 'tags': [args.exp_name] + get_wandb_tags(),
76 | }
77 | },
78 | )
79 |
80 | #####################
81 | # Configure Logging #
82 | #####################
83 | logging.basicConfig(
84 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
85 | datefmt='%m/%d/%Y %H:%M:%S',
86 | level=logging.INFO,
87 | )
88 | logger.info(accelerator.state, main_process_only=False)
89 | if accelerator.is_local_main_process:
90 | datasets.utils.logging.set_verbosity_warning()
91 | transformers.utils.logging.set_verbosity_info()
92 | else:
93 | datasets.utils.logging.set_verbosity_error()
94 | transformers.utils.logging.set_verbosity_error()
95 |
96 | accelerator.wait_for_everyone()
97 |
98 | ##########################################
99 | # Load Pretrained HF Model and Tokenizer #
100 | ##########################################
101 | config = AutoConfig.from_pretrained(
102 | args.model_name_or_path,
103 | trust_remote_code=args.trust_remote_code,
104 | revision=args.model_revision,
105 | token=os.getenv('HF_TOKEN', None),
106 | )
107 |
108 | tokenizer_revision = (
109 | args.model_revision
110 | if args.tokenizer_revision is None
111 | else args.tokenizer_revision
112 | )
113 |
114 | if tokenizer_revision != args.model_revision:
115 | # Warn user if tokenizer and model use different revisions; this is an unusual use case.
116 |         warning = (f'Requested tokenizer revision `{tokenizer_revision}` is different '
117 |                    f'from the model revision `{args.model_revision}`.')
118 |         logger.warning(warning)
119 |
120 | tokenizer = AutoTokenizer.from_pretrained(
121 | args.tokenizer_name_or_path,
122 | trust_remote_code=args.trust_remote_code,
123 | revision=tokenizer_revision,
124 | token=os.getenv('HF_TOKEN', None),
125 | )
126 |
127 | if args.model_name_or_path:
128 | model = AutoModelForCausalLM.from_pretrained(
129 | args.model_name_or_path,
130 | from_tf=bool('.ckpt' in args.model_name_or_path),
131 | config=config,
132 | trust_remote_code=args.trust_remote_code,
133 | low_cpu_mem_usage=args.low_cpu_mem_usage,
134 | torch_dtype=torch.bfloat16,
135 | attn_implementation='flash_attention_2'
136 | if args.use_flash_attn
137 | else 'eager',
138 | revision=args.model_revision,
139 | token=os.getenv('HF_TOKEN', None),
140 | )
141 | else:
142 | logger.info('Training new model from scratch')
143 | model = AutoModelForCausalLM.from_config(config)
144 |
145 |
146 | ######################
147 | # Embedding Resizing #
148 | ######################
149 |
150 |     # We resize the embeddings only when necessary to avoid index errors.
151 |     # Gather the weights with DeepSpeed first (ZeRO-3 may have partitioned them) to get the 'real' embedding size.
152 | embeddings = model.get_input_embeddings()
153 | with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
154 | embedding_size = embeddings.weight.shape[0]
155 | # resize does its own gather
156 | if len(tokenizer) > embedding_size:
157 | # pad to multiple for tensor cores.
158 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
159 | # update embedding size after resizing for sum loss
160 | embeddings = model.get_input_embeddings()
161 | with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
162 | embedding_size = embeddings.weight.shape[0]
163 |
164 |
165 | #######################################
166 | # Prepare Dataset & Set up Dataloader #
167 | #######################################
168 | data_files = {}
169 | dataset_args = {}
170 |
171 | if args.train_file:
172 | data_files['train'] = args.train_file
173 | if args.validation_file:
174 | data_files['validation'] = args.validation_file
175 |
176 | raw_datasets = load_dataset(
177 | 'json',
178 | data_files=data_files,
179 | **dataset_args,
180 | )
181 |
182 | train_dataset = raw_datasets['train']
183 | validation_dataset = raw_datasets.get('validation', None)
184 |
185 | if 'messages' not in train_dataset.column_names:
186 | raise ValueError("You need to have 'messages' in your training data.")
187 | if validation_dataset and 'messages' not in validation_dataset.column_names:
188 | raise ValueError("You need to have 'messages' in your validation data.")
189 |
190 | # Limit training samples. Used for debugging.
191 | if args.max_train_samples is not None:
192 | max_train_samples = min(len(train_dataset), args.max_train_samples)
193 | logger.info(f'Limiting training samples to {max_train_samples} from {len(train_dataset)}.')
194 | train_dataset = train_dataset.select(range(max_train_samples))
195 |
196 | encode_function = partial(
197 | encode_messages,
198 | tokenizer=tokenizer,
199 | max_seq_length=args.max_seq_length,
200 | prompt_style=args.prompt_style,
201 | add_bos=args.add_bos,
202 | )
203 |
204 | with accelerator.main_process_first():
205 | train_dataset = train_dataset.map(
206 | encode_function,
207 | batched=False,
208 | num_proc=args.preprocessing_num_workers,
209 | load_from_cache_file=not args.overwrite_cache,
210 | remove_columns=[
211 | name
212 | for name in train_dataset.column_names
213 | if name not in ['input_ids', 'labels', 'attention_mask']
214 | ],
215 | desc='Tokenizing and reformatting instruction data',
216 | )
217 |
218 | train_dataset.set_format(type='pt')
219 | train_dataset = train_dataset.filter(
220 | lambda example: (example['labels'] != -100).any()
221 | )
222 |
223 | # Log a few random samples from the training set.
224 | for index in random.sample(range(len(train_dataset)), 3):
225 | logger.info(f'Sample {index} of the training set: {train_dataset[index]}.')
226 |
227 | train_dataloader = DataLoader(
228 | train_dataset,
229 | shuffle=True,
230 | collate_fn=DataCollatorForSeq2Seq(
231 | tokenizer=tokenizer, model=model, padding='longest'
232 | ),
233 | batch_size=args.per_device_train_batch_size,
234 | )
235 |
236 | if validation_dataset:
237 | validation_dataset = validation_dataset.map(
238 | encode_function,
239 | batched=False,
240 | num_proc=args.preprocessing_num_workers,
241 | load_from_cache_file=not args.overwrite_cache,
242 | remove_columns=[
243 | name
244 | for name in validation_dataset.column_names
245 | if name not in ['input_ids', 'labels', 'attention_mask']
246 | ],
247 | desc='Tokenizing and reformatting validation data',
248 | )
249 | validation_dataset.set_format(type='pt')
250 | validation_dataset = validation_dataset.filter(lambda example: (example['labels'] != -100).any())
251 |
252 | validation_dataloader = DataLoader(
253 | validation_dataset,
254 | shuffle=True,
255 | collate_fn=DataCollatorForSeq2Seq(
256 | tokenizer=tokenizer, model=model, padding='longest'
257 | ),
258 | batch_size=args.per_device_train_batch_size,
259 | )
260 |
261 |
262 | ##############################
263 | # Optimizer and LR Scheduler #
264 | ##############################
265 |
266 | # Split weights in two groups, one with weight decay and the other not.
267 | no_decay = ['bias', 'layer_norm.weight']
268 | optimizer_grouped_parameters = [
269 | {
270 | 'params': [
271 | p
272 | for n, p in model.named_parameters()
273 | if not any(nd in n for nd in no_decay)
274 | ],
275 | 'weight_decay': args.weight_decay,
276 | },
277 | {
278 | 'params': [
279 | p
280 | for n, p in model.named_parameters()
281 | if any(nd in n for nd in no_decay)
282 | ],
283 | 'weight_decay': 0.0,
284 | },
285 | ]
286 |
287 | optimizer = torch.optim.AdamW(
288 | optimizer_grouped_parameters,
289 | lr=args.learning_rate,
290 | fused=args.fused_optimizer,
291 | betas=(args.beta1, args.beta2),
292 | )
293 |
294 | # Scheduler and math around the number of training steps.
295 | overrode_max_train_steps = False
296 | num_update_steps_per_epoch = math.ceil(
297 | len(train_dataloader) / args.gradient_accumulation_steps
298 | )
299 | if args.max_train_steps is None:
300 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
301 | overrode_max_train_steps = True
302 |
303 | # Create the learning rate scheduler.
304 |     # Note: accelerator.prepare() wraps the scheduler so that scheduler.step()
305 |     # is called `num_processes` times per optimizer step, because Accelerate
306 |     # assumes the scheduler was created for the entire training set.
307 |     # In data-parallel training each process only sees a 1/num_processes
308 |     # subset of the data, so stepping the scheduler multiple times keeps the
309 |     # total number of LR updates equal to num_training_steps.
310 |     # Therefore we set num_training_steps_for_scheduler to either the step
311 |     # count derived from the entire training set (when max_train_steps is
312 |     # computed from num_train_epochs) or to args.max_train_steps multiplied
313 |     # by num_processes, so that the total number of LR updates matches
314 |     # num_training_steps.
315 | num_training_steps_for_scheduler = (
316 | args.max_train_steps
317 | if overrode_max_train_steps
318 | else args.max_train_steps * accelerator.num_processes
319 | )
320 | lr_scheduler = get_scheduler(
321 | name=args.lr_scheduler_type,
322 | optimizer=optimizer,
323 | num_training_steps=num_training_steps_for_scheduler,
324 | num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
325 | )
326 |
327 | #################
328 | # Miscellaneous #
329 | #################
330 |
331 | if accelerator.is_main_process and args.output_dir is not None:
332 | os.makedirs(args.output_dir, exist_ok=True)
333 |
334 | validation_steps = args.validation_steps
335 | if validation_steps is not None:
336 | validation_steps = int(validation_steps)
337 |
338 | if args.gradient_checkpointing:
339 | model.gradient_checkpointing_enable()
340 |
341 | # Prepare everything with `accelerator`.
342 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
343 | model, optimizer, train_dataloader, lr_scheduler
344 | )
345 |
346 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
347 | num_update_steps_per_epoch = math.ceil(
348 | len(train_dataloader) / args.gradient_accumulation_steps
349 | )
350 | if overrode_max_train_steps:
351 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
352 | # Afterwards we recalculate our number of training epochs
353 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
354 |
355 | # Figure out how many steps we should save the Accelerator states
356 | checkpointing_steps = args.checkpointing_steps
357 | if checkpointing_steps is not None and str(checkpointing_steps).lower() != 'epoch':
358 | checkpointing_steps = int(checkpointing_steps)
359 |
360 |
361 | ##############
362 | # Finetuning #
363 | ##############
364 |
365 | total_batch_size = (
366 | args.per_device_train_batch_size
367 | * accelerator.num_processes
368 | * args.gradient_accumulation_steps
369 | )
370 |
371 | logger.info('***** Running training *****')
372 | logger.info(f' Num examples = {len(train_dataset)}')
373 | logger.info(f' Num Epochs = {args.num_train_epochs}')
374 | logger.info(f' Instantaneous batch size per device = {args.per_device_train_batch_size}')
375 | logger.info(f' Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}')
376 | logger.info(f' Gradient Accumulation steps = {args.gradient_accumulation_steps}')
377 | logger.info(f' Total optimization steps = {args.max_train_steps}')
378 | # Only show the progress bar once on each machine.
379 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
380 | completed_steps = 0
381 | starting_epoch = 0
382 |
383 | # Potentially load in the weights and states from a previous save
384 | last_checkpoint_path = get_last_checkpoint_path(args)
385 | if last_checkpoint_path:
386 | accelerator.print(f'Resumed from checkpoint: {last_checkpoint_path}')
387 | accelerator.load_state(last_checkpoint_path)
388 | # Extract `epoch_{i}` or `step_{i}`
389 | last_checkpoint_path = os.path.basename(last_checkpoint_path)
390 | training_difference = os.path.splitext(last_checkpoint_path)[0]
391 |
392 | if 'epoch' in training_difference:
393 | starting_epoch = int(training_difference.replace('epoch_', '')) + 1
394 | resume_step = None
395 | completed_steps = starting_epoch * num_update_steps_per_epoch
396 | else:
397 | # need to multiply `gradient_accumulation_steps` to reflect real steps
398 | resume_step = (int(training_difference.replace('step_', '')) * args.gradient_accumulation_steps)
399 | starting_epoch = resume_step // len(train_dataloader)
400 | completed_steps = resume_step // args.gradient_accumulation_steps
401 | resume_step -= starting_epoch * len(train_dataloader)
402 |
403 | print(f'Starting from epoch {starting_epoch} and step {completed_steps}.')
404 | # update the progress_bar if load from checkpoint
405 | progress_bar.update(completed_steps)
406 |
407 | for epoch in range(starting_epoch, args.num_train_epochs):
408 | model.train()
409 | train_dataloader.set_epoch(epoch)
410 | total_loss = 0
411 | if last_checkpoint_path and resume_step is not None:
412 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint
413 | active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
414 | else:
415 | active_dataloader = train_dataloader
416 | for batch in active_dataloader:
417 | with accelerator.accumulate(model):
418 | loss = calculate_loss_and_backpropagate(
419 | model,
420 | batch,
421 | accelerator,
422 | optimizer,
423 | lr_scheduler,
424 | embedding_size,
425 | args
426 | )
427 | # We keep track of the loss at each logged step
428 | total_loss += loss.detach().float()
429 |
430 | # Checks if the accelerator has performed an optimization step behind the scenes
431 | if accelerator.sync_gradients:
432 | progress_bar.update(1)
433 | completed_steps += 1
434 |
435 | # Perform Validation (only done by main)
436 | avg_val_loss = None
437 | if (
438 | accelerator.is_local_main_process
439 | and validation_steps
440 | and completed_steps % validation_steps == 0
441 | ):
442 | if completed_steps % validation_steps == 0:
443 | model.eval()
444 | full_val_loss = 0
445 | num_val_batches = 0
446 | with torch.no_grad():
447 | for val_batch in validation_dataloader:
448 | val_batch = {
449 | key: value.to(accelerator.device)
450 | for key, value in val_batch.items()
451 | }
452 | val_loss = calculate_loss(model, val_batch, embedding_size, args)
453 |
454 | full_val_loss += val_loss.detach().float()
455 | num_val_batches += 1
456 |
457 | avg_val_loss = full_val_loss / num_val_batches
458 | model.train()
459 |
460 | if args.logging_steps and completed_steps % args.logging_steps == 0:
461 | avg_loss = (
462 | accelerator.gather(total_loss).mean().item()
463 | / args.gradient_accumulation_steps
464 | / args.logging_steps
465 | )
466 |
467 | val_loss_log = f', Val Loss: {avg_val_loss}' if avg_val_loss else ''
468 | logger.info(f' Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Train Loss: {avg_loss}{val_loss_log}')
469 |
470 | if args.enable_wandb:
471 | log_data = {
472 | 'learning_rate': lr_scheduler.get_last_lr()[0],
473 | 'train_loss': avg_loss,
474 | }
475 | if validation_steps:
476 | log_data['validation_loss'] = avg_val_loss
477 | accelerator.log(log_data, step=completed_steps)
478 |
479 | total_loss = 0
480 |
481 | if isinstance(checkpointing_steps, int):
482 | if completed_steps % checkpointing_steps == 0:
483 | output_dir = f'step_{completed_steps}'
484 | if args.output_dir is not None:
485 | output_dir = os.path.join(args.output_dir, output_dir)
486 | accelerator.save_state(output_dir)
487 | # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints
488 | with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), 'COMPLETED'), 'w') as f:
489 | f.write('COMPLETED')
490 |
491 | if (accelerator.is_local_main_process and args.keep_last_n_checkpoints):
492 | clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints)
493 | accelerator.wait_for_everyone()
494 |
495 | if completed_steps >= args.max_train_steps:
496 | break
497 |
498 | if checkpointing_steps == 'epoch':
499 | output_dir = f'epoch_{epoch}'
500 | if args.output_dir is not None:
501 | output_dir = os.path.join(args.output_dir, output_dir)
502 |             accelerator.save_state(output_dir)
503 |             # use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints
504 |             with open(os.path.join(get_last_checkpoint_path(args, incomplete=True), 'COMPLETED'), 'w') as f:
505 |                 f.write('COMPLETED')
506 | if accelerator.is_local_main_process and args.keep_last_n_checkpoints:
507 | clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints)
508 | accelerator.wait_for_everyone()
509 |
510 | if args.output_dir is not None:
511 | save_with_accelerate(
512 | accelerator,
513 | model,
514 | tokenizer,
515 | args.output_dir,
516 | )
517 |
518 | accelerator.wait_for_everyone()
519 | if args.enable_wandb:
520 | accelerator.end_training()
521 |
522 |
523 | def encode_messages(
524 | example,
525 | tokenizer,
526 | max_seq_length,
527 | prompt_style,
528 | add_bos=False
529 | ):
530 | """
531 | Here we assume each example has a 'messages' field.
532 | 'messages' is a list of messages.
533 | Each message is a dict with 'role' and 'content' fields.
534 | We concatenate all messages with the roles as delimiters and tokenize them together.
535 | """
536 | messages = example['messages']
537 | if len(messages) == 0:
538 | raise ValueError('messages field is empty.')
539 |
540 | style = PromptStyle.from_name(prompt_style)
541 |
542 | # To support multi-turn, we need to mask all non-assistant messages.
543 | # To do this, we compute the length of each tokenized non-assistant messages then mask it
544 | segmented_prompts = []
545 | segmented_prompts_and_responses = []
546 | start_idx, end_idx = 0, 0
547 | while end_idx < len(messages):
548 | while end_idx < len(messages) and messages[end_idx]['role'] != 'assistant':
549 | end_idx += 1
550 | if start_idx <= end_idx:
551 | if start_idx == 0:
552 | # expect system prompt
553 | segmented_prompts.append(
554 | style.apply(
555 | messages[start_idx:end_idx], append_assistant_header=True
556 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx], tokenize=False, add_generation_prompt=True)
557 | )
558 | segmented_prompts_and_responses.append(
559 | style.apply(
560 | messages[start_idx : end_idx + 1], append_assistant_header=False
561 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx + 1], tokenize=False, add_generation_prompt=False)
562 | )
563 | else:
564 | # should not have system prompt for subsequent turns
565 | segmented_prompts.append(
566 | style.apply(
567 | messages[start_idx:end_idx],
568 | no_system=True,
569 | append_assistant_header=True,
570 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx], tokenize=False, add_generation_prompt=True)
571 | )
572 | segmented_prompts_and_responses.append(
573 | style.apply(
574 | messages[start_idx : end_idx + 1],
575 | no_system=True,
576 | append_assistant_header=False,
577 | ) if style else tokenizer.apply_chat_template(messages[start_idx:end_idx + 1], tokenize=False, add_generation_prompt=False)
578 | )
579 | start_idx = end_idx + 1
580 | end_idx += 1 # should be same as start_idx
581 |
582 | if add_bos:
583 | # add bos token to the first prompt
584 | segmented_prompts[0] = tokenizer.bos_token + segmented_prompts[0]
585 | segmented_prompts_and_responses[0] = (
586 | tokenizer.bos_token + segmented_prompts_and_responses[0]
587 | )
588 | encoded_segmented_prompts = list(
589 | map(
590 | lambda prompt: tokenizer(
591 | prompt, return_tensors='pt', max_length=max_seq_length, truncation=True
592 | ).input_ids.flatten(),
593 | segmented_prompts,
594 | )
595 | )
596 | encoded_segmented_prompts_and_responses = list(
597 | map(
598 | lambda prompt_and_response: tokenizer(
599 | prompt_and_response,
600 | return_tensors='pt',
601 | max_length=max_seq_length,
602 | truncation=True,
603 | ).input_ids.flatten(),
604 | segmented_prompts_and_responses,
605 | )
606 | )
607 |
608 | # Achieve the same effect as 'masking' by simply using ignore_index
609 | masked_labels = []
610 | num_split = len(encoded_segmented_prompts)
611 | for i in range(num_split):
612 | encoded_prompt = encoded_segmented_prompts[i]
613 | encoded_prompt_and_response = encoded_segmented_prompts_and_responses[i]
614 | label = encoded_prompt_and_response.clone()
615 |         label[: len(encoded_prompt)] = -100  # CrossEntropyLoss ignore_index, so prompt tokens are excluded from the loss
616 | masked_labels.append(label)
617 |
618 | # concatenate the segments
619 | encoded_prompts_and_responses = torch.cat(encoded_segmented_prompts_and_responses)
620 | labels = torch.cat(masked_labels)
621 | attention_mask = torch.ones_like(encoded_prompts_and_responses)
622 |
623 | return {
624 | 'input_ids': encoded_prompts_and_responses,
625 | 'labels': labels,
626 | 'attention_mask': attention_mask,
627 | }
628 |
629 | def calculate_loss(model, batch, embedding_size, args):
630 | outputs = model(**batch, use_cache=False)
631 | if args.reduce_loss == 'mean':
632 | loss = outputs.loss
633 | else:
634 | # reduce loss is sum
635 | # this ensures that we weight all tokens in the dataset equally,
636 | # rather than weighting each overall example equally when
637 | # using high amounts of gradient accumulation.
638 | # this can result in > 5 point improvements in AlpacaEval
639 | # see https://github.com/huggingface/transformers/issues/24725 for
640 | # more discussion and details.
641 | logits = outputs.logits
642 | labels = batch['labels']
643 | # Shift so that tokens < n predict n
644 | shift_logits = logits[..., :-1, :].contiguous()
645 | shift_labels = labels[..., 1:].contiguous()
646 | # Flatten the tokens
647 | loss_fct = torch.nn.CrossEntropyLoss(reduction='sum')
648 | shift_logits = shift_logits.view(-1, embedding_size)
649 | shift_labels = shift_labels.view(-1)
650 | # Enable model parallelism
651 | shift_labels = shift_labels.to(shift_logits.device)
652 | loss = loss_fct(shift_logits, shift_labels)
653 |
654 | return loss
655 |
656 | def calculate_loss_and_backpropagate(
657 | model,
658 | batch,
659 | accelerator,
660 | optimizer,
661 | lr_scheduler,
662 | embedding_size,
663 | args
664 | ):
665 | loss = calculate_loss(model, batch, embedding_size, args)
666 | accelerator.backward(loss)
667 | # clip gradient norm. don't do this with deepspeed
668 | if accelerator.sync_gradients and args.clip_grad_norm > 0:
669 | accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
670 | optimizer.step()
671 | optimizer.zero_grad()
672 | lr_scheduler.step()
673 |
674 | return loss
675 |
676 | if __name__ == '__main__':
677 | parser = ArgumentParserPlus(FinetuneArguments)
678 | args = parser.parse()
679 | main(args)
680 |
--------------------------------------------------------------------------------
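
finetune.py expects every record in the train/validation JSONL files to carry a 'messages' list in the chat format described in the encode_messages docstring; all non-assistant tokens are masked out of 'labels' so that only assistant turns contribute to the loss. A minimal sketch of calling the encoder standalone, assuming the repo root is on PYTHONPATH, that a 'default' style is registered in prompts/prompt_styles.py, and a placeholder tokenizer checkpoint:

from transformers import AutoTokenizer

from finetunerag.finetune import encode_messages

example = {
    'messages': [
        {'role': 'system', 'content': 'Answer strictly from the provided context.'},
        {'role': 'user', 'content': 'Context: ...\n\nQuestion: ...'},
        {'role': 'assistant', 'content': 'A grounded answer.'},
    ]
}

tokenizer = AutoTokenizer.from_pretrained('some-org/some-chat-model')  # placeholder checkpoint
encoded = encode_messages(
    example,
    tokenizer=tokenizer,
    max_seq_length=2048,
    prompt_style='default',  # assumed to be registered in prompts/prompt_styles.py
    add_bos=False,
)
# 1-D tensors of equal length; prompt positions in 'labels' are masked so that
# only the assistant response is scored.
print(encoded['input_ids'].shape, encoded['labels'].shape)
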
/finetunerag/model_utils.py:
--------------------------------------------------------------------------------
1 | # Taken and modified from https://github.com/huggingface/trl
2 | # Copyright 2022 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import itertools
18 | from contextlib import contextmanager
19 | from dataclasses import dataclass
20 | from typing import List, Literal, Optional, Tuple, Union
21 |
22 | from finetunerag.utils import retry_on_exception
23 |
24 | try:
25 | import deepspeed
26 | from deepspeed.runtime.engine import DeepSpeedEngine
27 | except ImportError:
28 | pass
29 | import pandas as pd
30 | import torch
31 | import transformers
32 | from accelerate import Accelerator
33 | from accelerate.state import AcceleratorState
34 | from huggingface_hub import HfApi
35 | from rich import print as rprint
36 | from rich.console import Console
37 | from rich.table import Table
38 | from rich.text import Text
39 | from torch.nn.parallel.distributed import DistributedDataParallel
40 | from transformers import PreTrainedModel, PreTrainedTokenizer
41 |
42 |
43 | @dataclass
44 | class ModelConfig:
45 | model_name_or_path: Optional[str] = None
46 | """The model checkpoint for weights initialization."""
47 | model_revision: str = "main"
48 | """The specific model version to use (can be a branch name, tag name or commit id)."""
49 | trust_remote_code: bool = False
50 | """Trust remote code when loading a model."""
51 | torch_dtype: Optional[str] = None
52 | """Override the default `torch.dtype` and load the model under this dtype."""
53 | attn_implementation: Optional[Literal["flash_attention_2"]] = None
54 | """Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case
55 | you must install this manually by running `pip install flash-attn --no-build-isolation`"""
56 | use_cache: Optional[bool] = None
57 | """Whether to use cache in the model."""
58 | gradient_checkpointing: Optional[bool] = None
59 | """Whether to use gradient checkpointing in the model."""
60 |
61 | # PEFT-related args
62 | use_peft: bool = False
63 | """Whether to use PEFT or not for training."""
64 | lora_r: Optional[int] = 16
65 | """LoRA R value."""
66 | lora_alpha: Optional[int] = 32
67 | """LoRA alpha."""
68 | lora_dropout: Optional[float] = 0.05
69 | """LoRA dropout."""
70 | lora_target_modules: Optional[List[str]] = None
71 | """LoRA target modules."""
72 | lora_modules_to_save: Optional[List[str]] = None
73 | """Model layers to unfreeze & train"""
74 | lora_task_type: str = "CAUSAL_LM"
75 | """The task_type to pass for LoRA (use SEQ_CLS for reward modeling)"""
76 |
77 | # quantization args
78 | load_in_8bit: bool = False
79 | """use 8 bit precision for the base model - works only with LoRA"""
80 | load_in_4bit: bool = False
81 | """use 4 bit precision for the base model - works only with LoRA"""
82 | bnb_4bit_quant_type: Optional[str] = "nf4"
83 | """precise the quantization type (fp4 or nf4)"""
84 | use_bnb_nested_quant: bool = False
85 | """use nested quantization"""
86 |
87 | def __post_init__(self):
88 | # `use_cache=True` is incompatible with gradient checkpointing.
89 | # https://github.com/huggingface/transformers/blob/d6751d91c8f58cdeb35af6adae182d7dc90aa883/src/transformers/models/llama/modeling_llama.py#L945
90 | if self.gradient_checkpointing:
91 | self.use_cache = False
92 |
93 |
94 | # ----------------------------------------------------------------------------
95 | # Model utilities; reward model stuff
96 | def disable_dropout_in_model(model: torch.nn.Module) -> None:
97 | for module in model.modules():
98 | if isinstance(module, torch.nn.Dropout):
99 | module.p = 0
100 |
101 |
102 | def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor:
103 | """
104 | Finds the index of the first `True` value in each row of a boolean tensor. If no `True` value exists in a row,
105 | it returns the length of the row.
106 |
107 | Args:
108 | bools (torch.Tensor): A boolean tensor of shape (batch_size, sequence_length), where `True` values indicate
109 | the positions of interest.
110 | dtype (torch.dtype): The data type to use for the output indices (default is torch.long).
111 |
112 | Returns:
113 | torch.Tensor: A tensor of shape (batch_size,) containing the index of the first `True` value in each row.
114 | If a row has no `True` value, the index will be the length of the row.
115 | """
116 |
117 | # Get the length of each row (i.e., the number of columns in the last dimension)
118 | # row_len is a scalar representing the length of each sequence (sequence_length)
119 | row_len = bools.size(-1)
120 |
121 | # Calculate the index positions for the first `True` in each row
122 | # ~bools: Invert the boolean values (True becomes False and vice versa)
123 | # ~bools.type(dtype): Convert the inverted boolean tensor to the specified dtype (0 for True, 1 for False)
124 | # row_len * (~bools).type(dtype): For `False` values, this will give `row_len`, for `True` values it gives 0.
125 | # torch.arange(row_len, dtype=dtype, device=bools.device): Generates a tensor with values [0, 1, 2, ..., row_len-1]
126 | # for each row. Shape: (sequence_length,)
127 | # zero_or_index: Shape (batch_size, sequence_length). This tensor contains the indices for `True` values and `row_len`
128 | # for `False` values.
129 | zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
130 |
131 | # Return the minimum value in each row (i.e., the first `True` index or `row_len` if none exist)
132 | # torch.min(zero_or_index, dim=-1).values: This returns the minimum value in each row, which corresponds to the first
133 | # `True` value's index or `row_len` if there is no `True` in that row.
134 | # The returned tensor has shape (batch_size,)
135 | return torch.min(zero_or_index, dim=-1).values
136 |
137 |
138 | def get_reward(
139 | model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
140 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
141 | """
142 | This function computes reward scores for a batch of query responses based on a pre-trained reward model.
143 |
144 | Args:
145 | model (torch.nn.Module): The pre-trained reward model.
146 | query_responses (torch.Tensor): Tensor containing the tokenized responses for which to compute rewards.
147 | Shape: (batch_size, sequence_length)
148 | pad_token_id (int): The ID used for padding tokens in the tokenized sequences.
149 | context_length (int): The length of the prompt or context preceding the completions.
150 |
151 | Returns:
152 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
153 | - reward_logits: The logits output from the model for all tokens in the sequences.
154 | Shape: (batch_size, sequence_length)
155 | - final_scores: The final reward scores, one for each sequence, after adjusting for sequence lengths.
156 | Shape: (batch_size,)
157 | - sequence_lengths: The lengths of each sequence (excluding padding).
158 | Shape: (batch_size,)
159 | """
160 |
161 | # Create an attention mask where tokens that are not padding have a value of 1, and padding tokens have a value of 0
162 | # Shape: (batch_size, sequence_length)
163 | attention_mask = query_responses != pad_token_id
164 |
165 | # Calculate position IDs for each token, considering the cumulative sum of the attention mask (to exclude padding)
166 | # Shape: (batch_size, sequence_length)
167 | position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum
168 |
169 | # Access the LM backbone from the reward model using its base model prefix
170 | lm_backbone = getattr(model, model.base_model_prefix)
171 |
172 | # Replace padding tokens with zeros in the input IDs (so padding tokens won't affect the model's processing)
173 | # Shape: (batch_size, sequence_length)
174 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
175 | output = lm_backbone(
176 | input_ids=input_ids,
177 | attention_mask=attention_mask,
178 | position_ids=position_ids,
179 | return_dict=True,
180 | output_hidden_states=True,
181 | use_cache=False, # otherwise mistral-based RM would error out
182 | )
183 | reward_logits = model.score(output.hidden_states[-1]) # (batch_size, sequence_length)
184 |
185 | # Calculate the length of each sequence by finding the first occurrence of a padding token after the context
186 | # sequence_lengths shape: (batch_size,)
187 | sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
188 | assert (
189 | reward_logits.shape[-1] == 1
190 | ), "Reward model should output a single scalar per token. Check if you added `num_labels=1` when doing `AutoModelForSequenceClassification.from_pretrained(...)`."
191 | # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
192 |
193 | # Return the reward logits for all tokens, the final reward scores for each sequence, and the sequence lengths
194 | return (
195 | # reward_logits shape: (batch_size, sequence_length)
196 | reward_logits,
197 | # final_scores shape: (batch_size,)
198 | reward_logits[
199 | torch.arange(reward_logits.size(0), device=reward_logits.device),
200 | sequence_lengths,
201 | ].squeeze(
202 | -1
203 | ), # Shape: (batch_size,)
204 | sequence_lengths,
205 | )
206 |
207 |
208 | def forward(
209 | model: torch.nn.Module,
210 | query_responses: torch.Tensor,
211 | pad_token_id: int,
212 | ) -> torch.nn.Module:
213 | """
214 | Performs a forward pass through the model with the given query responses and pad token ID.
215 | Args:
216 | model (`torch.nn.Module`):
217 | The model to perform the forward pass.
218 | query_responses (`torch.Tensor`):
219 | The tensor containing the query responses.
220 | pad_token_id (`int`):
221 | The token ID representing the pad token.
222 | Returns:
223 | `torch.nn.Module`:
224 | The output of the model, including hidden states.
225 | """
226 | attention_mask = query_responses != pad_token_id
227 | position_ids = attention_mask.cumsum(1) - attention_mask.long()
228 | input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
229 | return model(
230 | input_ids=input_ids,
231 | attention_mask=attention_mask,
232 | position_ids=position_ids,
233 | return_dict=True,
234 | output_hidden_states=True,
235 | )
236 |
237 |
238 | def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Tensor):
239 | """
240 | Truncates the responses at the first occurrence of the stop token, filling the rest with pad tokens.
241 | Args:
242 | stop_token_id (`int`):
243 | The token ID representing the stop token where truncation occurs.
244 | pad_token_id (`int`):
245 | The token ID representing the pad token used to fill the truncated responses.
246 | responses (`torch.Tensor`):
247 | The tensor containing the responses to be truncated.
248 | Returns:
249 | `torch.Tensor`:
250 | The truncated responses tensor with pad tokens filled after the stop token.
251 | """
252 | trunc_idxs = first_true_indices(responses == stop_token_id).unsqueeze(-1)
253 | new_size = [1] * (len(responses.size()) - 1) + [responses.shape[1]]
254 | idxs = torch.arange(responses.shape[1], device=responses.device).view(*new_size)
255 | postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id)
256 | return postprocessed_responses
257 |
258 |
259 | def generate(
260 | lm_backbone: torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: dict
261 | ) -> Tuple[torch.Tensor, torch.Tensor]:
262 | """
263 | Generates sequences from the language model backbone in a way that does not affect padding tokens.
264 | Args:
265 | lm_backbone (`torch.nn.Module`):
266 | The language model backbone used for generation.
267 | queries (`torch.Tensor`):
268 | The tensor containing the input queries.
269 | pad_token_id (`int`):
270 | The token ID representing the pad token.
271 | generation_config (`dict`):
272 | The configuration dictionary for generation settings.
273 | Returns:
274 | tuple:
275 | - `generated_sequences` (`torch.Tensor`):
276 | The concatenated tensor of input queries and generated sequences.
277 | - `logits` (`torch.Tensor`):
278 | The logits output from the generation process.
279 | """
280 | context_length = queries.shape[1]
281 | attention_mask = queries != pad_token_id
282 | input_ids = torch.masked_fill(queries, ~attention_mask, 0)
283 | output = lm_backbone.generate(
284 | input_ids=input_ids,
285 | attention_mask=attention_mask,
286 | # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # not needed: already adjusted in generations
287 | # https://github.com/huggingface/transformers/blob/ac33aeeeee2a7a89b89c93c2962e6feb90daef0a/src/transformers/models/gpt2/modeling_gpt2.py#L1227-L1250
288 | generation_config=generation_config,
289 | return_dict_in_generate=True,
290 | output_logits=True,
291 | )
292 | logits = torch.stack(output.logits, 1)
293 | return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits
294 |
295 |
296 | @torch.no_grad()
297 | def batch_generation(
298 | model: torch.nn.Module,
299 | queries: torch.Tensor,
300 | local_rollout_forward_batch_size: int,
301 | pad_token_id: int,
302 | generation_config: dict,
303 | ):
304 | query_responses = []
305 | logitss = []
306 | for i in range(0, queries.shape[0], local_rollout_forward_batch_size):
307 | query = queries[i : i + local_rollout_forward_batch_size]
308 | query_response, logits = generate(
309 | model,
310 | query,
311 | pad_token_id,
312 | generation_config,
313 | )
314 | query_responses.append(query_response)
315 | logitss.append(logits)
316 | return torch.cat(query_responses, 0), torch.cat(logitss, 0)
317 |
318 |
319 | def save_with_accelerate(
320 | accelerator: Accelerator,
321 | model: torch.nn.Module,
322 | tokenizer: PreTrainedTokenizer,
323 | output_dir: str,
324 | use_lora: bool = False,
325 | ) -> None:
326 | # set the generation config to an empty setting to be safe.
327 | # we usually do greedy decoding for generation, so this should be okay.
328 | # otherwise, we get an error thrown at save time.
329 | model.generation_config = transformers.GenerationConfig(
330 | temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id
331 | )
332 |
333 | unwrapped_model: PreTrainedModel = accelerator.unwrap_model(model)
334 | # When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
335 | # Otherwise, sometimes the model will be saved with only part of the parameters.
336 | # Also, accelerator needs to use the wrapped model to get the state_dict.
337 | state_dict = accelerator.get_state_dict(model)
338 | if use_lora:
339 | # When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
340 | # and has its own save_pretrained function for only saving lora modules.
341 | # We have to manually specify the is_main_process outside the save_pretrained function.
342 | if accelerator.is_main_process:
343 | unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
344 | else:
345 | # don't use safetensors for saving for now
346 | unwrapped_model.save_pretrained(
347 | output_dir,
348 | is_main_process=accelerator.is_main_process,
349 | save_function=accelerator.save,
350 | state_dict=state_dict,
351 | safe_serialization=False,
352 | )
353 |
354 | if accelerator.is_main_process:
355 | tokenizer.save_pretrained(output_dir)
356 | # customize model card (TODO (Costa): this can be prettier)
357 |
358 |
359 | @retry_on_exception()
360 | def push_folder_to_hub(
361 | accelerator: Accelerator,
362 | output_dir: str,
363 | hf_repo_id: Optional[str] = None,
364 | hf_repo_revision: Optional[str] = None,
365 | private: bool = True,
366 | ):
367 | if accelerator.is_main_process:
368 | hf_repo_url = f"https://huggingface.co/{hf_repo_id}/tree/{hf_repo_revision}"
369 | api = HfApi()
370 | if not api.repo_exists(hf_repo_id):
371 | api.create_repo(hf_repo_id, exist_ok=True, private=private)
372 | if hf_repo_revision is not None:
373 | api.create_branch(repo_id=hf_repo_id, branch=hf_repo_revision, exist_ok=True)
374 | api.upload_folder(
375 | repo_id=hf_repo_id,
376 | revision=hf_repo_revision,
377 | folder_path=output_dir,
378 | commit_message="upload checkpoint",
379 | run_as_future=False,
380 | )
381 | print(f"🔥 pushed to {hf_repo_url}")
382 |
383 |
384 | # ----------------------------------------------------------------------------
385 | # DeepSpeed utilities
386 | def get_all_parameters(sub_module, recurse=False):
387 | return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
388 |
389 |
390 | def iter_params(module, recurse=False):
391 | return [param for _, param in get_all_parameters(module, recurse)]
392 |
393 |
394 | def remove_hooks(model: "DeepSpeedEngine") -> None:
395 | """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
396 | if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
397 | optimizer_offload = model.optimizer.parameter_offload
398 | elif model.optimizer is not None:
399 | optimizer_offload = model.optimizer
400 |
401 | for param in iter_params(optimizer_offload.module, recurse=True):
402 | param.ds_active_sub_modules.clear()
403 |
404 | for hook in optimizer_offload.forward_hooks:
405 | hook.remove()
406 | for hook in optimizer_offload.backward_hooks:
407 | hook.remove()
408 |
409 | optimizer_offload.forward_hooks = []
410 | optimizer_offload.backward_hooks = []
411 |
412 |
413 | def add_hooks(model: "DeepSpeedEngine") -> None:
414 | """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
415 | if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
416 | optimizer_offload = model.optimizer.parameter_offload
417 | elif model.optimizer is not None:
418 | optimizer_offload = model.optimizer
419 | optimizer_offload._register_hooks_recursively(optimizer_offload.module)
420 |
421 |
422 | @contextmanager
423 | def unwrap_model_for_generation(
424 | model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False
425 | ) -> Union["transformers.PreTrainedModel", "DeepSpeedEngine"]:
426 | """Context manager to unwrap a model for generation.
427 | For ZeRO-3 models, we gather the weights once to speed up generation.
428 | """
429 | unwrapped_model = accelerator.unwrap_model(model)
430 | if is_peft_model:
431 | unwrapped_model.pretrained_model.disable_adapter()
432 | if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
433 | with deepspeed.zero.GatheredParameters(model.parameters()):
434 | remove_hooks(model)
435 | yield accelerator.unwrap_model(model)
436 | add_hooks(model)
437 | else:
438 | yield unwrapped_model
439 |
440 |
441 | def prepare_deepspeed(model: torch.nn.Module, per_device_train_batch_size: int, mixed_precision: str):
442 | """
443 | Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings based on the model and
444 | batch size.
445 | Args:
446 | model (`torch.nn.Module`):
447 | The model to be prepared for DeepSpeed training.
448 | per_device_train_batch_size (`int`):
449 | The training batch size per device.
450 | mixed_precision (`str`):
451 | The mixed precision setting to use.
452 | Returns:
453 | `torch.nn.Module`:
454 | The model initialized and configured with DeepSpeed for training.
455 | """
456 | import deepspeed
457 |
458 | deepspeed_plugin = AcceleratorState().deepspeed_plugin
459 | config_kwargs = deepspeed_plugin.deepspeed_config
460 | if config_kwargs["zero_optimization"]["stage"] != 3:
461 | config_kwargs["train_micro_batch_size_per_gpu"] = per_device_train_batch_size
462 | config_kwargs = {
463 | "train_micro_batch_size_per_gpu": config_kwargs["train_micro_batch_size_per_gpu"],
464 | "prescale_gradients": False,
465 | "wall_clock_breakdown": False,
466 | }
467 | if mixed_precision in ["bf16", "fp16"]:
468 | config_kwargs[mixed_precision] = {"enabled": True}
469 | else:
470 | if hasattr(model, "config"):
471 | hidden_size = (
472 | max(model.config.hidden_sizes)
473 | if getattr(model.config, "hidden_sizes", None)
474 | else getattr(model.config, "hidden_size", None)
475 | )
476 | if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
477 | # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
478 | # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
479 | config_kwargs.update(
480 | {
481 | "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
482 | "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
483 | "zero_optimization.stage3_prefetch_bucket_size": 0,
484 | }
485 | )
486 | model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
487 | model.eval()
488 | return model
489 |
490 |
491 | # ----------------------------------------------------------------------------
492 | # Quality of life utilities
493 | def print_rich_table(df: pd.DataFrame) -> None:
494 | console = Console()
495 | table = Table(show_lines=True)
496 | for column in df.columns:
497 | table.add_column(column)
498 | for _, row in df.iterrows():
499 | table.add_row(*row.astype(str).tolist())
500 | console.print(table)
501 |
502 |
503 | def format_value(value):
504 | if isinstance(value, float):
505 | if abs(value) < 1e-5:
506 | return f"{value:.2e}"
507 | return f"{value:.2f}"
508 | return str(value)
509 |
510 |
511 | def print_rich_single_line_metrics(metrics):
512 | formatted_metrics = []
513 | for key, value in metrics.items():
514 | # Shortening the key names
515 | short_key = key.split("/")[-1] if "/" in key else key
516 |
517 | # Create a colored text object
518 | metric_text = Text()
519 | metric_text.append(short_key + ": ", style="bold cyan") # Keys in cyan
520 | metric_text.append(format_value(value), style="yellow") # Values in yellow
521 |
522 | formatted_metrics.append(metric_text)
523 |
524 | rprint(" | ".join(str(metric) for metric in formatted_metrics))
525 |
526 |
527 | def exact_div(a, b, custom_error_message=""):
528 | q = a // b
529 | if a != q * b:
530 | raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}")
531 | return q
--------------------------------------------------------------------------------
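
first_true_indices() is the helper behind both get_reward() and truncate_response(): it turns a boolean mask into a per-row cut-off position, falling back to the row length when a row has no True. A tiny worked example of the behaviour described in its docstring:

import torch

from finetunerag.model_utils import first_true_indices

bools = torch.tensor([
    [False, True, False],    # first True at index 1
    [False, False, False],   # no True -> falls back to the row length, 3
])
print(first_true_indices(bools))  # tensor([1, 3])
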
/finetunerag/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 AllenAI. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import dataclasses
16 | import functools
17 | import json
18 | import logging
19 | import os
20 | import shutil
21 | import subprocess
22 | import sys
23 | import time
24 | from dataclasses import dataclass
25 | from typing import Any, List, NewType, Optional, Tuple, Union
26 |
27 | import requests
28 | from accelerate.logging import get_logger
29 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
30 | from datasets.builder import DatasetGenerationError
31 | from huggingface_hub import HfApi
32 | from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
33 |
34 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
35 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
36 |
37 | logger = get_logger(__name__)
38 |
39 | DataClassType = NewType("DataClassType", Any)
40 |
41 | """
42 | Notes:
43 | Inspired by Alignment Handbook Parser and Dataset Mixer
44 | https://github.com/huggingface/alignment-handbook/blob/main/src/alignment/configs.py
45 | https://github.com/huggingface/alignment-handbook/blob/main/src/alignment/data.py
46 |
47 | Migrated Args from
48 | https://github.com/allenai/open-instruct/blob/98ccfb460ae4fb98140783b6cf54241926160a06/open_instruct/finetune_trainer.py
49 |
50 | Commented out Args not currently used
51 | """
52 |
53 |
54 | # ----------------------------------------------------------------------------
55 | # Dataset utilities
56 | def is_openai_format(messages: Any) -> bool:
57 | """
58 | Check if the input messages are in OpenAI format.
59 | Args:
60 | messages (`Any`):
61 | Messages to check.
62 | Returns:
63 | `bool`: Whether the messages are in OpenAI format.
64 | """
65 | if isinstance(messages, list) and all(isinstance(message, dict) for message in messages):
66 | return all("role" in message and "content" in message for message in messages)
67 | return False
68 |
69 |
70 | # functions for handling different formats of messages
71 | def convert_alpaca_gpt4_to_messages(example):
72 | """
73 | Convert an instruction in inst-output to a list of messages.
74 | e.g. vicgalle/alpaca-gpt4"""
75 | messages = [
76 | {
77 | "role": "user",
78 | "content": (
79 | "Below is an instruction that describes a task, paired with an input that provides "
80 | "further context. Write a response that appropriately completes the request.\n\n"
81 | f"### Instruction:\n{example['instruction']}\n\n"
82 | f"### Input:\n{example['input']}\n\n"
83 | "### Response:"
84 | ),
85 | },
86 | {"role": "assistant", "content": example["output"]},
87 | ]
88 | example["messages"] = messages
89 | return example
90 |
91 |
92 | def convert_codefeedback_single_turn_to_messages(example):
93 | """
94 | Convert a query-answer pair to a list of messages.
95 | e.g. m-a-p/CodeFeedback-Filtered-Instruction"""
96 | messages = [
97 | {"role": "user", "content": example["query"]},
98 | {"role": "assistant", "content": example["answer"]},
99 | ]
100 | example["messages"] = messages
101 | return example
102 |
103 |
104 | def convert_metamath_qa_to_messages(example):
105 | """
106 | Convert a query-response pair to a list of messages.
107 | e.g. meta-math/MetaMathQA"""
108 | messages = [
109 | {"role": "user", "content": example["query"]},
110 | {"role": "assistant", "content": example["response"]},
111 | ]
112 | example["messages"] = messages
113 | return example
114 |
115 |
116 | def convert_code_alpaca_to_messages(example):
117 | """
118 | Convert a prompt-completion pair to a list of messages.
119 | e.g. HuggingFaceH4/CodeAlpaca_20K"""
120 | messages = [
121 | {"role": "user", "content": example["prompt"]},
122 | {"role": "assistant", "content": example["completion"]},
123 | ]
124 | example["messages"] = messages
125 | return example
126 |
127 |
128 | def convert_open_orca_to_messages(example):
129 | """
130 | Convert a question-response pair to a list of messages.
131 | e.g. Open-Orca/OpenOrca"""
132 | messages = [
133 | {"role": "system", "content": example["system_prompt"]},
134 | {"role": "user", "content": example["question"]},
135 | {"role": "assistant", "content": example["response"]},
136 | ]
137 | example["messages"] = messages
138 | return example
139 |
140 |
141 | def conversations_to_messages(example):
142 | """
143 | Convert from conversations format to messages.
144 |
145 | E.g. change "from": "user" to "role": "user"
146 | and "value" to "content"
147 | and "gpt" to "assistant"
148 |
149 | WizardLMTeam/WizardLM_evol_instruct_V2_196k
150 | """
151 | name_mapping = {
152 | "gpt": "assistant",
153 | "Assistant": "assistant",
154 | "assistant": "assistant",
155 | "user": "user",
156 | "User": "user",
157 | "human": "user",
158 | }
159 | messages = [{"role": name_mapping[conv["from"]], "content": conv["value"]} for conv in example["conversations"]]
160 | example["messages"] = messages
161 | return example
162 |
163 |
164 | def convert_rejection_samples_to_messages(example):
165 | """
166 | Convert a rejection sampling dataset to messages.
167 | """
168 | example["messages"] = example["chosen"]
169 | return example
170 |
171 |
172 | def get_datasets(
173 | dataset_mixer: Union[dict, list],
174 | splits: Optional[List[str]] = None,
175 | configs: Optional[List[str]] = None,
176 | columns_to_keep: Optional[List[str]] = None,
177 | shuffle: bool = True,
178 | save_data_dir: Optional[str] = None,
179 | need_columns: Optional[List[str]] = None,
180 | keep_ids: bool = False,
181 | ) -> DatasetDict:
182 | """
183 | Loads and mixes datasets according to proportions specified in `dataset_mixer`.
184 |
185 | Args:
186 | dataset_mixer (`list` or `dict`):
187 | Dictionary or list containing the dataset names and their training proportions.
188 | By default, all test proportions are 1. Lists are formatted as
189 | `key1 value1 key2 value2 ...` If a list is passed in, it will be converted to a dictionary.
190 | splits (Optional[List[str]], *optional*, defaults to `None`):
191 | Dataset splits to load and mix. Assumes the splits exist in
192 | all datasets and have a `train_` or `test_` prefix.
193 | configs (Optional[List[str]], *optional*, defaults to `None`):
194 | List of dataset config names. If given, must be the same length as the 'dataset_mixer' keys.
195 | columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
196 | Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
197 | and for cpt this should be (at least) the text column.
198 | shuffle (`bool`, *optional*, defaults to `True`):
199 | Whether to shuffle the training and testing/validation data.
200 | save_data_dir (Optional[str], *optional*, defaults to `None`):
201 | Optional directory to save training/test mixes on.
202 | need_columns (Optional[List[str]], *optional*, defaults to `None`):
203 | Column names that are required to be in the dataset.
204 | Quick debugging when mixing heterogeneous datasets.
205 | keep_ids (`bool`, *optional*, defaults to `False`):
206 | Whether to keep the ids that are added during mixing for training.
207 | Used primarily in mix_data.py for saving, or when the saved dataset already has IDs.
208 | """
209 | if isinstance(dataset_mixer, list):
210 | assert len(dataset_mixer) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer}"
211 | mixer_dict = {}
212 | i = 0
213 | while i < len(dataset_mixer) - 1:
214 | assert isinstance(dataset_mixer[i], str), f"Invalid type in data mixer: {dataset_mixer}"
215 | if "." in dataset_mixer[i + 1]:
216 | value = float(dataset_mixer[i + 1])
217 | else:
218 | value = int(dataset_mixer[i + 1])
219 | mixer_dict[dataset_mixer[i]] = value
220 | i += 2
221 | dataset_mixer = mixer_dict
222 |
223 | splits = ["train", "test"] if splits is None else splits
224 | configs = [None] * len(dataset_mixer) if not configs else configs
225 | columns_to_keep = [] if columns_to_keep is None else columns_to_keep
226 |
227 | if configs is not None and len(configs) != len(dataset_mixer):
228 | raise ValueError("The number of given dataset config names must be the same as the given number of datasets.")
229 |
230 | # print save location
231 | if save_data_dir:
232 | print(f"Saving mixed dataset to {save_data_dir}")
233 |
234 | raw_datasets = DatasetDict()
235 | raw_train_datasets = []
236 | raw_val_datasets = []
237 | frac_or_sample_list = []
238 | for (ds, frac_or_samples), ds_config in zip(dataset_mixer.items(), configs):
239 | frac_or_sample_list.append(frac_or_samples)
240 | for split in splits:
241 | # if dataset ends with .json or .jsonl, load from file
242 | if ds.endswith(".json") or ds.endswith(".jsonl"):
243 | dataset = load_dataset("json", data_files=ds, split=split)
244 | else:
245 | try:
246 | # Try first if dataset on a Hub repo
247 | dataset = load_dataset(ds, ds_config, split=split)
248 | except DatasetGenerationError:
249 | # If not, check local dataset
250 | dataset = load_from_disk(os.path.join(ds, split))
251 |
252 | # shuffle dataset if set
253 | if shuffle:
254 | dataset = dataset.shuffle(seed=42)
255 |
256 | # assert that needed columns are present
257 | if need_columns:
258 | if not all(col in dataset.column_names for col in need_columns):
259 | raise ValueError(f"Needed column {need_columns} not found in dataset {dataset.column_names}.")
260 |
261 | # handle per-case conversions
262 | # if "instruction" and "output" columns are present and "messages" is not, convert to messages
263 | if (
264 | "instruction" in dataset.column_names
265 | and "output" in dataset.column_names
266 | and "messages" not in dataset.column_names
267 | ):
268 | dataset = dataset.map(convert_alpaca_gpt4_to_messages, num_proc=10)
269 | elif (
270 | "prompt" in dataset.column_names
271 | and "completion" in dataset.column_names
272 | and "messages" not in dataset.column_names
273 | ):
274 | dataset = dataset.map(convert_code_alpaca_to_messages, num_proc=10)
275 | elif "conversations" in dataset.column_names and "messages" not in dataset.column_names:
276 | dataset = dataset.map(conversations_to_messages, num_proc=10)
277 | elif (
278 | "question" in dataset.column_names
279 | and "response" in dataset.column_names
280 | and "messages" not in dataset.column_names
281 | ):
282 | dataset = dataset.map(convert_open_orca_to_messages, num_proc=10)
283 | elif (
284 | "query" in dataset.column_names
285 | and "answer" in dataset.column_names
286 | and "messages" not in dataset.column_names
287 | ):
288 | dataset = dataset.map(convert_codefeedback_single_turn_to_messages, num_proc=10)
289 | elif (
290 | "query" in dataset.column_names
291 | and "response" in dataset.column_names
292 | and "messages" not in dataset.column_names
293 | ):
294 | dataset = dataset.map(convert_metamath_qa_to_messages, num_proc=10)
295 | elif (
296 | "chosen" in dataset.column_names
297 | and "rejected" in dataset.column_names
298 | and "reference_completion" in dataset.column_names
299 | and "messages" not in dataset.column_names
300 | ):
301 | dataset = dataset.map(convert_rejection_samples_to_messages, num_proc=10)
302 |
303 | # if id not in dataset, create it as {ds}_{index}
304 | if "id" not in dataset.column_names:
305 | id_col = [f"{ds}_{i}" for i in range(len(dataset))]
306 | dataset = dataset.add_column("id", id_col)
307 |
308 | # Remove redundant columns to avoid schema conflicts on load
309 | dataset = dataset.remove_columns(
310 | [col for col in dataset.column_names if col not in (columns_to_keep + ["id"])]
311 | )
312 |
313 | # add the dataset to the train or test pool depending on its split
314 | if "train" in split:
315 | raw_train_datasets.append(dataset)
316 | elif "test" in split:
317 | raw_val_datasets.append(dataset)
318 | else:
319 | raise ValueError(f"Split type {split} not recognized as one of test or train.")
320 |
321 | if len(raw_val_datasets) == 0 and len(raw_train_datasets) == 0:
322 | raise ValueError("No datasets loaded.")
323 | elif len(raw_train_datasets) == 0:
324 | # target features are the features of the first dataset post load
325 | target_features = raw_val_datasets[0].features
326 | else:
327 | # target features are the features of the first dataset post load
328 | target_features = raw_train_datasets[0].features
329 |
330 | if any(frac_or_samples < 0 for frac_or_samples in frac_or_sample_list):
331 | raise ValueError("Dataset fractions / lengths cannot be negative.")
332 |
333 | # if any > 1, use count
334 | if any(frac_or_samples > 1 for frac_or_samples in frac_or_sample_list):
335 | is_count = True
336 | # assert that all are integers
337 | if not all(isinstance(frac_or_samples, int) for frac_or_samples in frac_or_sample_list):
338 | raise NotImplementedError("Cannot mix fractions and counts, yet.")
339 | else:
340 | is_count = False
341 |
342 | if len(raw_train_datasets) > 0:
343 | train_subsets = []
344 | # Manage proportions
345 | for dataset, frac_or_samples in zip(raw_train_datasets, frac_or_sample_list):
346 | # cast features (TODO, add more feature regularization)
347 | dataset = dataset.cast(target_features)
348 | # TODO selection can be randomized.
349 | if is_count:
350 | train_subset = dataset.select(range(frac_or_samples))
351 | else:
352 | train_subset = dataset.select(range(int(frac_or_samples * len(dataset))))
353 | train_subsets.append(train_subset)
354 |
355 | raw_datasets["train"] = concatenate_datasets(train_subsets)
356 |
357 | # No subsampling for test datasets to enable fair comparison across models
358 | if len(raw_val_datasets) > 0:
359 | # cast features (TODO, add more feature regularization)
360 | raw_val_datasets = [
361 | dataset.cast(target_features) for dataset in raw_val_datasets
362 | ]
363 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
364 |
365 | if len(raw_datasets) == 0:
366 | raise ValueError(
367 | f"Dataset {dataset_mixer} not recognized with splits {splits}. "
368 | "Check the dataset has been correctly formatted."
369 | )
370 |
371 | # optional save
372 | if save_data_dir:
373 | for split in raw_datasets:
374 | raw_datasets[split].to_json(save_data_dir + f"mixed_ds_{split}.json")
375 |
376 | if not keep_ids:
377 | # remove id column
378 | if len(raw_train_datasets) > 0:
379 | if "id" in raw_datasets["train"].column_names:
380 | raw_datasets["train"] = raw_datasets["train"].remove_columns("id")
381 | if len(raw_val_datasets) > 0:
382 | if "id" in raw_datasets["test"].column_names:
383 | raw_datasets["test"] = raw_datasets["test"].remove_columns("id")
384 |
385 | return raw_datasets
386 |
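# Illustrative sketch of a mixer call (dataset names and proportions below are
# placeholders, not project defaults): take half of one dataset and all of
# another, keeping only the "messages" column.
#
#   raw_datasets = get_datasets(
#       {"vicgalle/alpaca-gpt4": 0.5, "HuggingFaceH4/CodeAlpaca_20K": 1.0},
#       splits=["train", "test"],
#       columns_to_keep=["messages"],
#   )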
387 |
388 | # ----------------------------------------------------------------------------
389 | # Arguments utilities
390 | class ArgumentParserPlus(HfArgumentParser):
391 | def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[DataClassType]:
392 | """
393 | Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
394 |
395 | Args:
396 | yaml_arg (`str`):
397 | The path to the config file used
398 | other_args (`List[str]`, *optional*):
399 | A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
400 |
401 | Returns:
402 | [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
403 | """
404 | arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
405 |
406 | outputs = []
407 | # strip other args list into dict of key-value pairs
408 | other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args}
409 | used_args = {}
410 |
411 | # overwrite the default/loaded value with the value provided to the command line
412 | # noqa adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
413 | for data_yaml, data_class in zip(arg_list, self.dataclass_types):
414 | keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
415 | inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
416 | for arg, val in other_args.items():
417 | # add only if in keys
418 |
419 | if arg in keys:
420 | base_type = data_yaml.__dataclass_fields__[arg].type
421 | inputs[arg] = val
422 |
423 | # cast type for ints, floats (default to strings)
424 | if base_type in [int, float]:
425 | inputs[arg] = base_type(val)
426 |
427 | if base_type == List[str]:
428 | inputs[arg] = [str(v) for v in val.split(",")]
429 |
430 | # bool of a non-empty string is True, so we manually check for bools
431 | if base_type == bool:
432 | if val in ["true", "True"]:
433 | inputs[arg] = True
434 | else:
435 | inputs[arg] = False
436 |
437 | # add to used-args so we can check if double add
438 | if arg not in used_args:
439 | used_args[arg] = val
440 | else:
441 | raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
442 |
443 | obj = data_class(**inputs)
444 | outputs.append(obj)
445 |
446 | return outputs
447 |
448 | def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
449 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
450 | # If we pass only one argument to the script and it's the path to a YAML file,
451 | # let's parse it to get our arguments.
452 | output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
453 | # parse command line args and yaml file
454 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
455 | output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
456 | # parse command line args only
457 | else:
458 | output = self.parse_args_into_dataclasses()
459 |
460 | if len(output) == 1:
461 | output = output[0]
462 | return output
463 |
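# Illustrative invocations (config path and flag names are placeholders): `parse`
# accepts a single YAML config, a YAML config plus CLI overrides, or CLI args only.
#
#   python finetunerag/finetune.py configs/example.yaml
#   python finetunerag/finetune.py configs/example.yaml --learning_rate=1e-5
#   python finetunerag/finetune.py --learning_rate=1e-5 --output_dir=out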
464 |
465 | # ----------------------------------------------------------------------------
466 | # Experiment tracking utilities
467 | def get_git_tag() -> str:
468 | """Try to get the latest Git tag (e.g., `no-tag-404-g98dc659` or `v1.0.0-4-g98dc659`)"""
469 | git_tag = ""
470 | try:
471 | git_tag = (
472 | subprocess.check_output(["git", "describe", "--tags"], stderr=subprocess.DEVNULL).decode("ascii").strip()
473 | )
474 | except subprocess.CalledProcessError as e:
475 | logging.debug(f"Failed to get Git tag: {e}")
476 |
477 | # If no Git tag found, create a custom tag based on commit count and hash
478 | if len(git_tag) == 0:
479 | try:
480 | count = int(
481 | subprocess.check_output(["git", "rev-list", "--count", "HEAD"], stderr=subprocess.DEVNULL)
482 | .decode("ascii")
483 | .strip()
484 | )
485 | hash = (
486 | subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], stderr=subprocess.DEVNULL)
487 | .decode("ascii")
488 | .strip()
489 | )
490 | git_tag = f"no-tag-{count}-g{hash}"
491 | except subprocess.CalledProcessError as e:
492 | logging.debug(f"Failed to get commit count and hash: {e}")
493 |
494 | return git_tag
495 |
496 |
497 | def get_pr_tag() -> str:
498 | """Try to find associated pull request on GitHub (e.g., `pr-123`)"""
499 | pr_tag = ""
500 | try:
501 | git_commit = (
502 | subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"], stderr=subprocess.DEVNULL)
503 | .decode("ascii")
504 | .strip()
505 | )
506 | # try finding the pull request number on github
507 | prs = requests.get(f"https://api.github.com/search/issues?q=repo:allenai/open-instruct+is:pr+{git_commit}")
508 | if prs.status_code == 200:
509 | prs = prs.json()
510 | if len(prs["items"]) > 0:
511 | pr = prs["items"][0]
512 | pr_number = pr["number"]
513 | pr_tag = f"pr-{pr_number}"
514 | except Exception as e:
515 | logging.debug(f"Failed to get PR number: {e}")
516 |
517 | return pr_tag
518 |
519 |
520 | def get_wandb_tags() -> List[str]:
521 | """Get tags for Weights & Biases (e.g., `no-tag-404-g98dc659,pr-123`)"""
522 | existing_wandb_tags = os.environ.get("WANDB_TAGS", "")
523 | git_tag = get_git_tag()
524 | pr_tag = get_pr_tag()
525 | non_empty_tags = [tag for tag in [existing_wandb_tags, git_tag, pr_tag] if len(tag) > 0]
526 | return non_empty_tags
527 |
528 |
529 | # ----------------------------------------------------------------------------
530 | # Checkpointing utilities
531 | def get_last_checkpoint(folder: str, incomplete: bool = False) -> Optional[str]:
532 | content = os.listdir(folder)
533 | checkpoint_steps = [path for path in content if path.startswith("step_")]
534 | checkpoint_epochs = [path for path in content if path.startswith("epoch_")]
535 | if len(checkpoint_steps) > 0 and len(checkpoint_epochs) > 0:
536 | logger.info("Mixed step and epoch checkpoints found. Using step checkpoints.")
537 | checkpoints = checkpoint_steps
538 | elif len(checkpoint_steps) == 0:
539 | checkpoints = checkpoint_epochs
540 | else:
541 | checkpoints = checkpoint_steps
542 | if not incomplete:
543 | checkpoints = [path for path in checkpoints if os.path.exists(os.path.join(folder, path, "COMPLETED"))]
544 | if len(checkpoints) == 0:
545 | return
546 | return os.path.join(folder, max(checkpoints, key=lambda x: int(x.split("_")[-1])))
547 |
548 |
549 | def get_last_checkpoint_path(args, incomplete: bool = False) -> str:
550 | # if output already exists and user does not allow overwriting, resume from there.
551 | # otherwise, resume if the user specifies a checkpoint.
552 | # else, start from scratch.
553 | # if incomplete is true, include folders without a "COMPLETED" file in them.
554 | last_checkpoint_path = None
555 | if args.output_dir and os.path.isdir(args.output_dir) and not args.overwrite_output_dir:
556 | last_checkpoint_path = get_last_checkpoint(args.output_dir, incomplete=incomplete)
557 | if last_checkpoint_path is None:
558 | logger.warning("Output directory exists but no checkpoint found. Starting from scratch.")
559 | elif args.resume_from_checkpoint:
560 | last_checkpoint_path = args.resume_from_checkpoint
561 | return last_checkpoint_path
562 |
563 |
564 | def is_checkpoint_folder(dir: str, folder: str) -> bool:
565 | return (folder.startswith("step_") or folder.startswith("epoch_")) and os.path.isdir(os.path.join(dir, folder))
566 |
567 |
568 | def clean_last_n_checkpoints(output_dir: str, keep_last_n_checkpoints: int) -> None:
569 | # remove all but the most recent `keep_last_n_checkpoints` checkpoints to save space
570 | folders = [f for f in os.listdir(output_dir) if is_checkpoint_folder(output_dir, f)]
571 | # sort checkpoints by their step/epoch number, oldest first
572 | checkpoints = sorted(folders, key=lambda x: int(x.split("_")[-1]))
573 | if len(checkpoints) > keep_last_n_checkpoints:
574 | for checkpoint in checkpoints[: len(checkpoints) - keep_last_n_checkpoints]:
575 | logger.info(f"Removing checkpoint {checkpoint}")
576 | shutil.rmtree(os.path.join(output_dir, checkpoint))
577 | logger.info("Remaining files:" + str(os.listdir(output_dir)))
578 |
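# Illustrative behaviour (folder names assumed): with keep_last_n_checkpoints=2 and
# checkpoint folders ["step_100", "step_200", "step_300"] in output_dir, only
# "step_100" is removed, leaving the two most recent checkpoints on disk.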
579 |
580 | # ----------------------------------------------------------------------------
581 | # Ai2 user utilities
582 | @dataclass
583 | class BeakerRuntimeConfig:
584 | beaker_workload_id: str
585 | beaker_node_hostname: Optional[List[str]] = None
586 | beaker_experiment_url: Optional[List[str]] = None
587 | beaker_dataset_ids: Optional[List[str]] = None
588 | beaker_dataset_id_urls: Optional[List[str]] = None
589 |
590 |
591 | def is_beaker_job() -> bool:
592 | return "BEAKER_JOB_ID" in os.environ
593 |
594 |
595 | def get_beaker_experiment_info(experiment_id: str) -> Optional[dict]:
596 | get_experiment_command = f"beaker experiment get {experiment_id} --format json"
597 | process = subprocess.Popen(["bash", "-c", get_experiment_command], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
598 | stdout, stderr = process.communicate()
599 | if process.returncode != 0:
600 | print(f"Failed to get Beaker experiment: {stderr}")
601 | return None
602 | return json.loads(stdout)[0]
603 |
604 |
605 | def beaker_experiment_succeeded(experiment_id: str) -> bool:
606 | experiment = get_beaker_experiment_info(experiment_id)
607 | if not experiment:
608 | return False
609 | return all(["finalized" in job["status"] and job["status"]["exitCode"] == 0 for job in experiment["jobs"]])
610 |
611 |
612 | def get_beaker_dataset_ids(experiment_id: str) -> Optional[List[str]]:
613 | experiment = get_beaker_experiment_info(experiment_id)
614 | if not experiment:
615 | return None
616 | result_ids = [job["result"]["beaker"] for job in experiment["jobs"]]
617 | dataset_ids = []
618 | for result_id in result_ids:
619 | get_dataset_command = f"beaker dataset get {result_id} --format json"
620 | process = subprocess.Popen(["bash", "-c", get_dataset_command], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
621 | stdout, stderr = process.communicate()
622 | if process.returncode != 0:
623 | print(f"Failed to get Beaker dataset: {stderr}")
624 | return None
625 | datasets = json.loads(stdout)
626 | dataset_ids.extend([dataset["id"] for dataset in datasets])
627 | return dataset_ids
628 |
629 |
630 | def get_beaker_whoami() -> Optional[str]:
631 | get_beaker_whoami_command = "beaker account whoami --format json"
632 | process = subprocess.Popen(
633 | ["bash", "-c", get_beaker_whoami_command], stdout=subprocess.PIPE, stderr=subprocess.PIPE
634 | )
635 | stdout, stderr = process.communicate()
636 | if process.returncode != 0:
637 | print(f"Failed to get Beaker account: {stderr}")
638 | return None
639 | accounts = json.loads(stdout)
640 | return accounts[0]["name"]
641 |
642 |
643 | def maybe_get_beaker_config():
644 | beaker_dataset_ids = get_beaker_dataset_ids(os.environ["BEAKER_WORKLOAD_ID"])
645 | # fix condition on basic interactive jobs
646 | if beaker_dataset_ids is None:
647 | beaker_dataset_id_urls = []
648 | else:
649 | beaker_dataset_id_urls = [f"https://beaker.org/ds/{dataset_id}" for dataset_id in beaker_dataset_ids]
650 | return BeakerRuntimeConfig(
651 | beaker_workload_id=os.environ["BEAKER_WORKLOAD_ID"],
652 | beaker_node_hostname=os.environ["BEAKER_NODE_HOSTNAME"],
653 | beaker_experiment_url=f"https://beaker.org/ex/{os.environ['BEAKER_WORKLOAD_ID']}/",
654 | beaker_dataset_ids=beaker_dataset_ids,
655 | beaker_dataset_id_urls=beaker_dataset_id_urls,
656 | )
657 |
658 |
659 | def retry_on_exception(max_attempts=4, delay=1, backoff=2):
660 | """
661 | Retry a function on exception. Useful for HF API calls that may fail due to
662 | network issues. E.g., https://beaker.org/ex/01J69P87HJQQ7X5DXE1CPWF974
663 | `huggingface_hub.utils._errors.HfHubHTTPError: 429 Client Error`
664 |
665 | We can test it with the following code.
666 | @retry_on_exception(max_attempts=4, delay=1, backoff=2)
667 | def test():
668 | raise Exception("Test exception")
669 |
670 | test()
671 | """
672 |
673 | def decorator(func):
674 | @functools.wraps(func)
675 | def wrapper(*args, **kwargs):
676 | attempts = 0
677 | local_delay = delay
678 | while attempts < max_attempts:
679 | try:
680 | return func(*args, **kwargs)
681 | except Exception as e:
682 | attempts += 1
683 | if attempts == max_attempts:
684 | raise e
685 | print(f"Attempt {attempts} failed. Retrying in {local_delay} seconds...")
686 | time.sleep(local_delay)
687 | local_delay *= backoff
688 | return None
689 |
690 | return wrapper
691 |
692 | return decorator
693 |
694 |
695 | @retry_on_exception()
696 | def maybe_use_ai2_wandb_entity() -> Optional[str]:
697 | """Ai2 internal logic: try to use the ai2-llm team if possible. Should not affect external users."""
698 | import wandb
699 |
700 | wandb.login()
701 | api = wandb.Api()
702 | current_user = api.viewer
703 | teams = current_user.teams
704 | if "ai2-llm" in teams:
705 | return "ai2-llm"
706 | else:
707 | return None
708 |
709 |
710 | @retry_on_exception()
711 | def maybe_use_ai2_hf_entity() -> Optional[str]:
712 | """Ai2 internal logic: try to use the allenai entity if possible. Should not affect external users."""
713 | orgs = HfApi().whoami()
714 | orgs = [item["name"] for item in orgs["orgs"]]
715 | if "allenai" in orgs:
716 | return "allenai"
717 | else:
718 | return None
719 |
720 |
721 | def submit_beaker_eval_jobs(
722 | model_name: str,
723 | location: str,
724 | hf_repo_revision: str = "",
725 | workspace: str = "tulu-3-results",
726 | beaker_image: str = "nathanl/open_instruct_auto",
727 | upload_to_hf: str = "allenai/tulu-3-evals",
728 | run_oe_eval_experiments: bool = False,
729 | ) -> None:
730 | command = f"""
731 | python scripts/submit_eval_jobs.py \
732 | --model_name {model_name} \
733 | --location {location} \
734 | --is_tuned \
735 | --workspace {workspace} \
736 | --preemptible \
737 | --use_hf_tokenizer_template \
738 | --beaker_image {beaker_image} \
739 | """
740 | if len(hf_repo_revision) > 0:
741 | command += f" --hf_revision {hf_repo_revision}"
742 | if len(upload_to_hf) > 0:
743 | command += f" --upload_to_hf {upload_to_hf}"
744 | if run_oe_eval_experiments:
745 | command += " --run_oe_eval_experiments"
746 |
747 | process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
748 | stdout, stderr = process.communicate()
749 |
750 | print(f"Beaker evaluation jobs: Stdout:\n{stdout.decode()}")
751 | print(f"Beaker evaluation jobs: Stderr:\n{stderr.decode()}")
752 | print(f"Beaker evaluation jobs: process return code: {process.returncode}")
753 |
754 |
755 | @retry_on_exception()
756 | def upload_metadata_to_hf(
757 | metadata_dict,
758 | filename,
759 | hf_dataset_name,
760 | hf_dataset_save_dir,
761 | ):
762 | # upload a random dict to HF. Originally for uploading metadata to HF
763 | # about a model for leaderboard displays.
764 | with open("tmp.json", "w") as f:
765 | json.dump(metadata_dict, f)
766 | api = HfApi()
767 | api.upload_file(
768 | path_or_fileobj="tmp.json",
769 | path_in_repo=f"{hf_dataset_save_dir}/{filename}",
770 | repo_id=hf_dataset_name,
771 | repo_type="dataset",
772 | )
773 | os.remove("tmp.json")
774 |
--------------------------------------------------------------------------------
/media/accuracy_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/accuracy_by_steps.png
--------------------------------------------------------------------------------
/media/depth_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/depth_by_steps.png
--------------------------------------------------------------------------------
/media/helpfulness_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/helpfulness_by_steps.png
--------------------------------------------------------------------------------
/media/pints_ai-banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/pints_ai-banner.png
--------------------------------------------------------------------------------
/media/relevance_by_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pints-AI/Finetune-Bench-RAG/a973a17d207caaf323378ded29a3d435edc209d8/media/relevance_by_steps.png
--------------------------------------------------------------------------------
/prepare_dataset/README.md:
--------------------------------------------------------------------------------
1 | # Dataset
2 |
3 | The motivation here is to improve RAG (after the retrieval phase) by fine-tuning a model to select the intended (correct) information and filter out false positives. As such, the dataset is generated with the intended (correct) information alongside some fictitious data (simulating retrieval of incorrect data).
4 |
5 | ## Uploaded Dataset
6 |
7 | We have uploaded the dataset we used for training to HuggingFace [here](https://huggingface.co/datasets/pints-ai/Finetune-RAG). We have curated a total of `1653` documents that are ready for a train-validation-test split before fine-tuning.
8 |
9 | ## Fictitious Data Generation [OPTIONAL]
10 |
11 | If you have your own set of document chunks and would like to curate questions, answers, and fictitious data from them for fine-tuning, you may do so by preparing a JSONL file that contains one document chunk per line. The structure of each line should be as follows:
12 |
13 | ```json
14 | {
15 | "content": "",
16 | "filename": ""
17 | }
18 | ```
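For example, a single line of your JSONL file could look like this (values are purely illustrative):

```json
{"content": "The warranty covers manufacturing defects for 24 months from the date of purchase.", "filename": "acme_warranty_policy.pdf"}
```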
19 |
20 | > [!WARNING]
21 | > With this method, your documents will be passed to GPT-4o for dataset generation. If you intend to proceed, remember to include a `.env` file containing your OpenAI key at the root level. An example is given in `.env.sample`.
22 |
23 | ### Usage
24 |
25 | Suppose your custom JSONL file is at `dataset/custom_chunks.jsonl`; run the scripts as modules from the root level:
26 | ```bash
27 | python -m prepare_dataset.content_generation.generate_question --content_path dataset/custom_chunks.jsonl && \
28 | python -m prepare_dataset.content_generation.generate_answer && \
29 | python -m prepare_dataset.content_generation.generate_fictitious_content
30 | ```
31 |
32 | ## Prepare Dataset For Training
33 |
34 | `prepare_dataset/formatting/generate_training_data.py` is the script that processes the generated data for training. Essentially, it lets you choose between two different dialogue formats used to tune the model.
35 |
36 | Whether you use our prepared dataset from Hugging Face or your own generated dataset, run it through this preparation step before fine-tuning.
37 |
38 | ### Baseline Format
39 |
40 | ```
41 | Filename: {filename1}
42 | Information:
43 | {content1}
44 |
45 | Filename: {filename2}
46 | Information:
47 | {content2}
48 |
49 | Question: {question}
50 | ```
51 |
52 | ### XML Format
53 |
54 | ```
55 | <documents>
56 | <document>
57 | <filename>{filename1}</filename>
58 | <content>{content1}</content>
59 | </document>
60 | <document>
61 | <filename>{filename2}</filename>
62 | <content>{content2}</content>
63 | </document>
64 | </documents>
65 |
66 | Question: {question}
67 | ```
68 |
69 | ### Usage
70 |
71 | ```bash
72 | python -m prepare_dataset.formatting.generate_training_data --content_path dataset/finetunerag_dataset.jsonl --format baseline
73 | ```
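Each output line keeps the original `content`, `question`, and `filename` fields and adds a `messages` list containing the system prompt, the formatted user turn (with the genuine and fictitious documents in random order), and the reference answer as the assistant turn. A single (truncated, illustrative) line is sketched below:

```json
{"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "Filename: ...\nInformation:\n..."}, {"role": "assistant", "content": "..."}], "content": "...", "question": "...", "filename": "..."}
```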
74 |
75 | ## Split the dataset
76 |
77 | Simply run the `prepare_dataset/formatting/split_data.py` script to get your train-validation-test splits after preparing your dataset via the instructions above.
78 |
79 | ```bash
80 | python -m prepare_dataset.formatting.split_data --dataset_path dataset/adjusted_finetunerag_dataset.jsonl
81 | ```
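This produces an approximately 80/10/10 train-validation-test split (after a seeded shuffle), with the final assistant turn stripped from the test split so it can be used for inference:

```
dataset/splits/
├── train.jsonl        # ~80% of the data
├── validation.jsonl   # ~10% of the data
└── test.jsonl         # ~10% of the data, assistant turn removed
```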
82 |
--------------------------------------------------------------------------------
/prepare_dataset/content_generation/generate_answer.py:
--------------------------------------------------------------------------------
1 | from logging import Logger
2 | from pathlib import Path
3 | from typing import List, Dict
4 | import copy
5 |
6 | from prompts.answer_generation_prompt import SYSTEM_MESSAGE, USER_MESSAGE_TEMPLATE
7 | from utils.dataset_utils import load_jsonl_file, write_jsonl_file
8 | from utils.logger import setup_logger
9 | from utils.openai import call_openai_api
10 |
11 | def start(
12 | content_path: Path = "dataset/contents_w_questions.jsonl",
13 | output_path: Path = "dataset/contents_w_qa.jsonl",
14 | ):
15 | assert (
16 | content_path.is_file()
17 | and content_path.suffix == ".jsonl"
18 | ), "Path to content data is not a jsonl file."
19 |
20 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name)
21 | logger.info(f"{content_path} recognized.")
22 |
23 | all_content_qn_data: List[Dict] = load_jsonl_file(content_path)
24 |
25 | # Ensure dataset is ready before generating answers
26 | for index, content_data in enumerate(all_content_qn_data):
27 | missing_keys = [key for key in ["content", "filename", "question"] if key not in content_data]
28 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]"
29 |
30 | logger.info(f"Dataset checked and ready for answer generation. Generating answers...")
31 | content_data_with_answer: List[Dict] = []
32 | for index, content_data in enumerate(all_content_qn_data):
33 | filename: str = content_data["filename"]
34 | content: str = content_data["content"]
35 | question: str = content_data["question"]
36 | generated_answer: str = generate_answer(filename=filename, content=content, question=question)
37 | logger.info(f"Line {index + 1} of jsonl file --> Answer generated for question: {question[:60].encode('unicode_escape').decode()}... Answer: {generated_answer[:60]}...")
38 |
39 | cloned_content_data: Dict = copy.deepcopy(content_data)
40 | cloned_content_data["answer"] = generated_answer
41 | content_data_with_answer.append(cloned_content_data)
42 |
43 | write_jsonl_file(content=content_data_with_answer, output_path=output_path)
44 | logger.info(f"Generation of answers completed. Updated jsonl file saved at {output_path}")
45 | return
46 |
47 | def generate_answer(
48 | filename: str,
49 | content: str,
50 | question: str,
51 | ):
52 | user_message: str = USER_MESSAGE_TEMPLATE.format(filename=filename, content=content, question=question)
53 | messages = [
54 | {"role": "system", "content": SYSTEM_MESSAGE},
55 | {"role": "user", "content": user_message}
56 | ]
57 |
58 | generated_openai_json: Dict = call_openai_api(messages=messages, output_as_json=True)
59 |
60 | assert "answer" in generated_openai_json, "OpenAI model did not generate correct json format"
61 | return generated_openai_json["answer"]
62 |
63 | if __name__ == "__main__":
64 | from jsonargparse import CLI
65 | CLI(start, as_positional=False)
66 | # python -m prepare_dataset.content_generation.generate_answer --content_path dataset/contents_w_questions.jsonl --output_path dataset/contents_w_qa.jsonl
67 |
--------------------------------------------------------------------------------
/prepare_dataset/content_generation/generate_fictitious_content.py:
--------------------------------------------------------------------------------
1 | from logging import Logger
2 | from pathlib import Path
3 | from typing import List, Dict
4 | import copy
5 |
6 | from prompts.fictitious_content_generation_prompt import SYSTEM_MESSAGE, USER_MESSAGE_TEMPLATE
7 | from utils.dataset_utils import load_jsonl_file, write_jsonl_file
8 | from utils.logger import setup_logger
9 | from utils.openai import call_openai_api
10 |
11 | def start(
12 | content_path: Path = "dataset/contents_w_qa.jsonl",
13 | output_path: Path = "dataset/finetunerag_dataset.jsonl",
14 | num_fictitious_content: int = 2
15 | ):
16 | assert (
17 | content_path.is_file()
18 | and content_path.suffix == ".jsonl"
19 | ), "Path to content data is not a jsonl file."
20 |
21 | assert num_fictitious_content > 0, "Number of fictitious content should be more than 0"
22 | assert isinstance(num_fictitious_content, int), "Number of fictitious content should be an integer"
23 |
24 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name)
25 | logger.info(f"{content_path} recognized.")
26 |
27 | all_content_data: List[Dict] = load_jsonl_file(content_path)
28 |
29 | # Ensure dataset is ready before generating fictitious content
30 | for index, content_data in enumerate(all_content_data):
31 | missing_keys = [key for key in ["content", "filename"] if key not in content_data]
32 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]"
33 |
34 | logger.info(f"Dataset checked and ready for fictitious content generation. Generating content...")
35 | fictitious_content_data: List[Dict] = []
36 | for index, content_data in enumerate(all_content_data):
37 | filename: str = content_data["filename"]
38 | content: str = content_data["content"]
39 |
40 | cloned_content_data: Dict = copy.deepcopy(content_data)
41 | for fictitious_content_count in range(1, num_fictitious_content + 1):
42 | fictitious_filename, fictitious_content = generate_fictitious_content(filename=filename, content=content)
43 | logger.info(f"Line {index + 1} of jsonl file --> Fictitious content {fictitious_content_count} generated. Filename: {fictitious_filename[:60]}..., Content: {fictitious_content[:60]}...")
44 |
45 | cloned_content_data[f"fictitious_filename{fictitious_content_count}"] = fictitious_filename
46 | cloned_content_data[f"fictitious_content{fictitious_content_count}"] = fictitious_content
47 |
48 | fictitious_content_data.append(cloned_content_data)
49 |
50 | write_jsonl_file(content=fictitious_content_data, output_path=output_path)
51 | logger.info(f"Generation of fictitious contents completed. Updated jsonl file saved at {output_path}")
52 | return
53 |
54 | def generate_fictitious_content(
55 | filename: str,
56 | content: str,
57 | ):
58 | user_message: str = USER_MESSAGE_TEMPLATE.format(filename=filename, content=content)
59 | messages = [
60 | {"role": "system", "content": SYSTEM_MESSAGE},
61 | {"role": "user", "content": user_message}
62 | ]
63 |
64 | # Set temperature to 0.5 for more spurious content output generation
65 | generated_openai_json: Dict = call_openai_api(messages=messages, output_as_json=True, temperature=0.5)
66 | assert (
67 | "fictitious_filename" in generated_openai_json
68 | and "fictitious_content" in generated_openai_json
69 | ), "OpenAI model did not generate correct JSON format"
70 |
71 | return generated_openai_json["fictitious_filename"], generated_openai_json["fictitious_content"]
72 |
73 | if __name__ == "__main__":
74 | from jsonargparse import CLI
75 | CLI(start, as_positional=False)
76 |
--------------------------------------------------------------------------------
/prepare_dataset/content_generation/generate_question.py:
--------------------------------------------------------------------------------
1 | from logging import Logger
2 | from pathlib import Path
3 | from typing import List, Dict
4 | import copy
5 |
6 | from prompts.question_generation_prompt import SYSTEM_MESSAGE, USER_MESSAGE_TEMPLATE
7 | from utils.dataset_utils import load_jsonl_file, write_jsonl_file
8 | from utils.logger import setup_logger
9 | from utils.openai import call_openai_api
10 |
11 | def start(
12 | content_path: Path = "dataset/contents.jsonl",
13 | output_path: Path = "dataset/contents_w_questions.jsonl",
14 | ):
15 | assert (
16 | content_path.is_file()
17 | and content_path.suffix == ".jsonl"
18 | ), "Path to content data is not a jsonl file."
19 |
20 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name)
21 | logger.info(f"{content_path} recognized.")
22 |
23 | all_content_data: List[Dict] = load_jsonl_file(content_path)
24 |
25 | # Ensure dataset is ready before generating questions
26 | for index, content_data in enumerate(all_content_data):
27 | missing_keys = [key for key in ["content", "filename"] if key not in content_data]
28 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]"
29 |
30 | logger.info(f"Dataset checked and ready for question generation. Generating questions...")
31 | content_data_with_question: List[Dict] = []
32 | for index, content_data in enumerate(all_content_data):
33 | filename: str = content_data["filename"]
34 | content: str = content_data["content"]
35 | generated_question: str = generate_question(filename=filename, content=content)
36 | logger.info(f"Line {index + 1} of jsonl file --> Question generated for content: {content[:60].encode('unicode_escape').decode()}... Question: {generated_question}")
37 |
38 | cloned_content_data: Dict = copy.deepcopy(content_data)
39 | cloned_content_data["question"] = generated_question
40 | content_data_with_question.append(cloned_content_data)
41 |
42 | write_jsonl_file(content=content_data_with_question, output_path=output_path)
43 | logger.info(f"Generation of questions completed. Updated jsonl file saved at {output_path}")
44 | return
45 |
46 | def generate_question(
47 | filename: str,
48 | content: str,
49 | ):
50 | user_message: str = USER_MESSAGE_TEMPLATE.format(filename=filename, content=content)
51 | messages = [
52 | {"role": "system", "content": SYSTEM_MESSAGE},
53 | {"role": "user", "content": user_message}
54 | ]
55 |
56 | generated_openai_json: Dict = call_openai_api(messages=messages, output_as_json=True)
57 |
58 | assert "question" in generated_openai_json, "OpenAI model did not generate correct json format"
59 | return generated_openai_json["question"]
60 |
61 | if __name__ == "__main__":
62 | from jsonargparse import CLI
63 | CLI(start, as_positional=False)
64 |
65 |
--------------------------------------------------------------------------------
/prepare_dataset/formatting/formatter.py:
--------------------------------------------------------------------------------
1 | from typing import Generator
2 |
3 | from prompts.rag_prompt import SYSTEM_MESSAGE, BASELINE_TEMPLATE, XML_TEMPLATE
4 |
5 | templates = {
6 | "baseline": BASELINE_TEMPLATE,
7 | "xml": XML_TEMPLATE,
8 | }
9 |
10 | def formatter(
11 | filename: str,
12 | content: str,
13 | fictitious_filename: str,
14 | fictitious_content: str,
15 | question: str,
16 | answer: str,
17 | decider: Generator,
18 | template_type: str = "baseline",
19 | ):
20 | assert template_type in templates, f"{template_type} template does not exist. Available templates: {templates.keys()}"
21 |
22 | filename1, filename2 = filename, fictitious_filename
23 | content1, content2 = content, fictitious_content
24 |
25 | # Whether to swap the order of the fictitious and non-fictitious documents
26 | flip_content_order = next(decider)
27 | if flip_content_order:
28 | filename2, filename1 = filename1, filename2
29 | content2, content1 = content1, content2
30 |
31 | user_message: str = templates.get(template_type).format(
32 | filename1=filename1,
33 | content1=content1,
34 | filename2=filename2,
35 | content2=content2,
36 | question=question
37 | )
38 |
39 | return {
40 | "messages": [
41 | {"role": "system", "content": SYSTEM_MESSAGE},
42 | {"role": "user", "content": user_message},
43 | {"role": "assistant", "content": answer}
44 | ],
45 | "content": content,
46 | "question": question,
47 | "filename": filename,
48 | }
49 |
--------------------------------------------------------------------------------
/prepare_dataset/formatting/generate_training_data.py:
--------------------------------------------------------------------------------
1 | from logging import Logger
2 | from pathlib import Path
3 | from typing import List, Literal, Dict
4 |
5 | from prepare_dataset.formatting.formatter import formatter
6 | from utils.dataset_utils import get_decider, load_jsonl_file, write_jsonl_file, get_dice
7 | from utils.logger import setup_logger
8 |
9 | def start(
10 | content_path: Path = "dataset/finetunerag_dataset.jsonl",
11 | output_path: Path = "dataset/adjusted_finetunerag_dataset.jsonl",
12 | format: Literal["baseline", "xml"] = "baseline",
13 | ):
14 | assert (
15 | content_path.is_file()
16 | and content_path.suffix == ".jsonl"
17 | ), "Path to content data is not a jsonl file."
18 |
19 | logger: Logger = setup_logger(Path(__file__).stem + "_" + content_path.name)
20 | logger.info(f"{content_path} recognized.")
21 |
22 | all_content_data: List[Dict] = load_jsonl_file(content_path)
23 |
24 | # Ensure dataset is ready before formatting the training data
25 | for index, content_data in enumerate(all_content_data):
26 | missing_keys = [key for key in ["content", "filename", "question", "answer", "fictitious_filename1", "fictitious_content1"] if key not in content_data]
27 | assert not missing_keys, f"Missing key(s) in line {index + 1} of jsonl file: [{', '.join(missing_keys)}]"
28 |
29 | logger.info(f"Dataset checked and ready for formatting. Preparing training data...")
30 |
31 | # It is assumed that every datapoint has the same number of fictitious contents. Refer to the first datapoint to retrieve that number.
32 | num_fictitious_content = sum(1 for key in all_content_data[0] if key.startswith("fictitious_content"))
33 | dice = get_dice(num_choices=num_fictitious_content)
34 |
35 | decider = get_decider()
36 | all_formatted_data = []
37 | for index, content_data in enumerate(all_content_data):
38 | filename: str = content_data["filename"]
39 | content: str = content_data["content"]
40 | question: str = content_data["question"]
41 | answer: str = content_data["answer"]
42 |
43 |
44 | selected_content_index = next(dice)
45 | fictitious_filename = content_data[f"fictitious_filename{selected_content_index}"]
46 | fictitious_content = content_data[f"fictitious_content{selected_content_index}"]
47 |
48 | formatted_data = formatter(
49 | filename=filename,
50 | content=content,
51 | fictitious_filename=fictitious_filename,
52 | fictitious_content=fictitious_content,
53 | question=question,
54 | answer=answer,
55 | decider=decider,
56 | template_type=format,
57 | )
58 | all_formatted_data.append(formatted_data)
59 |
60 | write_jsonl_file(content=all_formatted_data, output_path=output_path)
61 | logger.info(f"Preparation of training data completed. Training data saved at {output_path}")
62 | return
63 |
64 | if __name__ == "__main__":
65 | from jsonargparse import CLI
66 | CLI(start, as_positional=False)
67 |
68 |
--------------------------------------------------------------------------------
/prepare_dataset/formatting/split_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | from pathlib import Path
5 |
6 | def start(
7 | dataset_path: Path = "dataset/adjusted_finetunerag_dataset.jsonl",
8 | output_folder_path: Path = "dataset/splits",
9 | seed: int = 888,
10 | ):
11 | with open(dataset_path, "r") as f:
12 | lines = f.readlines()
13 |
14 | random.seed(seed)
15 | random.shuffle(lines)
16 |
17 | total = len(lines)
18 | test_size = int(0.1 * total)
19 | train_size = total - (test_size * 2)
20 |
21 | train_data = lines[:train_size]
22 | val_data = lines[train_size:train_size + test_size]
23 | test_data = lines[train_size + test_size:]
24 |
25 | trimmed_test_data = []
26 | for line in test_data:
27 | obj = json.loads(line)
28 | if isinstance(obj.get("messages"), list) and obj["messages"]:
29 | obj["messages"] = obj["messages"][:-1]
30 | trimmed_test_data.append(json.dumps(obj) + "\n")
31 |
32 | output_folder_path.mkdir(parents=True, exist_ok=True)
33 |
34 | (output_folder_path / "train.jsonl").write_text("".join(train_data))
35 | (output_folder_path / "validation.jsonl").write_text("".join(val_data))
36 | (output_folder_path / "test.jsonl").write_text("".join(trimmed_test_data))
37 |
38 | print(f"Done! Files created: {(output_folder_path / 'train.jsonl')}, {(output_folder_path / 'validation.jsonl')}, {(output_folder_path / 'test.jsonl')}")
39 |
40 | if __name__ == "__main__":
41 | from jsonargparse import CLI
42 | CLI(start, as_positional=False)
43 |
--------------------------------------------------------------------------------
/prompts/answer_generation_prompt.py:
--------------------------------------------------------------------------------
1 | SYSTEM_MESSAGE = (
2 | "You are an AI Data Scientist who creates high quality datasets that can be used for fine-tuning of Large Language Models."
3 | " Follow the user's instruction closely to create a dataset based on the given context."
4 | )
5 |
6 | USER_MESSAGE_TEMPLATE = """You are to read the following information and answer the question.
7 | Filename: {filename}
8 | Information: {content}
9 |
10 | Now, answer the question, and you may elaborate as necessary. Do not create information that is not found in the information provided.
11 | Question: "{question}"
12 |
13 | You will reply in the following JSON format:
14 | {{
15 | 'answer': ""
16 | }}"""
17 |
--------------------------------------------------------------------------------
/prompts/fictitious_content_generation_prompt.py:
--------------------------------------------------------------------------------
1 | SYSTEM_MESSAGE = (
2 | "You are an AI Data Scientist who creates high quality datasets that can be used for fine-tuning of Large Language Models."
3 | " Follow the user's instruction closely to create a dataset based on the given context."
4 | )
5 |
6 | USER_MESSAGE_TEMPLATE = """You are tasked with creating a fictitious set of information that must be thematically similar to the original user's filename and content. The fictitious file name should be from a different company, different person, or a different place etc. The fictitious content should have some similar parts but also some parts entirely fabricated with a different structure and also random incorrect content. The fictitious content must still match the original content in length, preserving patterns such as spurious use of line breaks (\\n), spewed Unicode, or broken words. In particular, ensure that the frequency and placement of line breaks (\\n) and unicode is similar to the original. The generation of fictitious content should not always be simple rephrasing, but should feel like it comes from a different document. The output should contain two JSON keys: 'fictitious_filename' and 'fictitious_content'.
7 |
8 | Here is the original user's filename and content:
9 | Filename: {filename}
10 | Information: {content}
11 |
12 | Now, generate 1 set of fictitious filename and content.
13 |
14 | You will reply in the following JSON format:
15 | {{
16 | 'fictitious_filename': "",
17 | 'fictitious_content': ""
18 | }}"""
19 |
--------------------------------------------------------------------------------
/prompts/judging_prompts.py:
--------------------------------------------------------------------------------
1 | from typing import TypedDict
2 | from openai.types.chat.chat_completion_message_param import (
3 | ChatCompletionSystemMessageParam,
4 | ChatCompletionUserMessageParam,
5 | )
6 |
7 | class OpenAIJudgeResponse(TypedDict):
8 | accuracy_explanation: str
9 | accuracy: bool
10 | helpfulness_explanation: str
11 | helpfulness: int
12 | relevance_explanation: str
13 | relevance: int
14 | depth_explanation: str
15 | depth: int
16 |
17 | # This prompt has been deprecated. Evaluation of multiple factors with the same call
18 | # can lead to biased results. The result for each criterion might directly
19 | # influence the ones after, resulting in an inaccurate or less objective assessment.
20 | judge_system_prompt: ChatCompletionSystemMessageParam = {
21 | 'role': 'system',
22 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider factors such as the accuracy, helpfulness, relevance, and depth of the response.
23 |
24 | 1. Accuracy - You will check whether the response contains extra details not found in the piece of information provided. If extra details are found, accuracy is false. Otherwise, accuracy is true. Take note that if the response partially addresses the question, but did not provide extra details not found in the piece of information provided, the response will still be considered accurate (hence accuracy = true).
25 | 2. Helpfulness - The helpfulness of the AI assistant in answering the question.
26 | 3. Relevance - Whether the response fully addresses the question.
27 | 4. Depth - The level of detail of the response in answering the question.
28 |
29 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate each factor on a scale of 1 to 10 (with the exception of accuracy, where it is true or false) by strictly following this JSON format:
30 | {
31 | "accuracy_explanation": ,
32 | "accuracy": ,
33 | "helpfulness_explanation": ,
34 | "helpfulness": ,
35 | "relevance_explanation": ,
36 | "relevance": ,
37 | "depth_explanation": ,
38 | "depth":
39 | }
40 | """,
41 | }
42 |
43 | judge_accuracy_system_prompt: ChatCompletionSystemMessageParam = {
44 | 'role': 'system',
45 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the accuracy of the response.
46 |
47 | You will check whether the response contains extra details not found in the piece of information provided. If extra details are found, accuracy is false. Otherwise, accuracy is true. Take note that if the response partially addresses the question, but did not provide extra details not found in the piece of information provided, the response will still be considered accurate (hence accuracy = true).
48 |
49 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the accuracy with true or false by strictly following this JSON format:
50 | {
51 | "accuracy_explanation": ,
52 | "accuracy":
53 | }
54 | """,
55 | }
56 |
57 | judge_helpfulness_system_prompt: ChatCompletionSystemMessageParam = {
58 | 'role': 'system',
59 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the helpfulness of the response.
60 |
61 | You will check whether the AI assistant is helpful in answering the question based on the response.
62 |
63 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the helpfulness on a scale of 1 to 10 by strictly following this JSON format:
64 | {
65 | "helpfulness_explanation": ,
66 | "helpfulness":
67 | }
68 | """,
69 | }
70 |
71 | judge_relevance_system_prompt: ChatCompletionSystemMessageParam = {
72 | 'role': 'system',
73 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the relevance of the response.
74 |
75 | You will check the relevance of the response by evaluating whether the response fully addresses the question.
76 |
77 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the relevance on a scale of 1 to 10 by strictly following this JSON format:
78 | {
79 | "relevance_explanation": ,
80 | "relevance":
81 | }
82 | """,
83 | }
84 |
85 | judge_depth_system_prompt: ChatCompletionSystemMessageParam = {
86 | 'role': 'system',
87 | 'content': """Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below, based solely on a piece of information extracted from a file provided below. Your evaluation should consider the depth of the response.
88 |
89 | You will check the depth of the response by evaluating the level of detail of the response in answering the question.
90 |
91 | Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the depth on a scale of 1 to 10 by strictly following this JSON format:
92 | {
93 | "depth_explanation": ,
94 | "depth":
95 | }
96 | """,
97 | }
98 |
99 |
100 | def get_judge_user_prompt(document) -> ChatCompletionUserMessageParam:
101 | return {
102 | 'role': 'user',
103 | 'content': f"""[The Start of Provided Information Extracted from a File]
104 | Filename: {document['filename']}
105 |
106 | Information: {document['content']}
107 | [The End of Provided Information]
108 |
109 | [Question]
110 | {document['question']}
111 |
112 | [The Start of Assistant's Response]
113 | {document['response']}
114 | [The End of Assistant's Response]""",
115 | }
116 |
--------------------------------------------------------------------------------
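Usage note for prompts/judging_prompts.py (above): a minimal sketch of how the per-criterion system prompts could be combined with get_judge_user_prompt and the call_openai_api helper from utils/openai.py. The actual orchestration lives in benchrag/run_judging.py, so this is illustrative only; the document values are made up, and the repository root is assumed to be on the import path.

    from prompts.judging_prompts import (
        get_judge_user_prompt,
        judge_accuracy_system_prompt,
        judge_depth_system_prompt,
        judge_helpfulness_system_prompt,
        judge_relevance_system_prompt,
    )
    from utils.openai import call_openai_api

    # Made-up example record; the real fields come from the inference outputs.
    document = {
        'filename': 'handbook.md',
        'content': 'The warranty period is 12 months.',
        'question': 'How long is the warranty period?',
        'response': 'The warranty period is 12 months.',
    }

    judgement = {}
    for system_prompt in (
        judge_accuracy_system_prompt,
        judge_helpfulness_system_prompt,
        judge_relevance_system_prompt,
        judge_depth_system_prompt,
    ):
        # One criterion per call, following the deprecation note above.
        result = call_openai_api(
            messages=[system_prompt, get_judge_user_prompt(document)],
            output_as_json=True,
        )
        judgement.update(result)
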
/prompts/prompt_styles.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from abc import abstractmethod
3 | from typing import Dict, List, Optional, Type, Union
4 |
5 |
6 | class PromptStyle:
7 | """Base interface for prompt styles."""
8 |
9 | @abstractmethod
10 |     def apply(self, prompt: Union[str, List[Dict[str, str]]], **kwargs: str) -> str:
11 | raise NotImplementedError('PromptStyle.apply() is an abstract method.')
12 |
13 | @classmethod
14 |     def from_name(cls, name: str) -> Optional['PromptStyle']:
15 | if name not in prompt_styles:
16 | return None
17 | return prompt_styles[name]()
18 |
19 |
20 | # This default style is exported from the original open_instruct repository:
21 | # https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L396-L407
22 | class Default(PromptStyle):
23 | def apply(
24 | self,
25 | prompt: Union[str, List[Dict[str, str]]],
26 | **kwargs: str,
27 | ) -> str:
28 | message_text = ''
29 | for message in prompt:
30 | if message['role'] == 'system':
31 | message_text += '<|system|>\n' + message['content'].strip() + '\n'
32 | elif message['role'] == 'user':
33 | message_text += '<|user|>\n' + message['content'].strip() + '\n'
34 | elif message['role'] == 'assistant':
35 | message_text += (
36 |                     '<|assistant|>\n' + message['content'].strip() + '\n'
37 | )
38 | else:
39 | raise ValueError(f"Invalid role: {message['role']}")
40 | return message_text
41 |
42 |
43 | class Llama3_1(PromptStyle):
44 | DEFAULT_SYSTEM_MESSAGE = 'You are a helpful assistant.'
45 |
46 | def apply(
47 | self,
48 | prompt: Union[str, List[Dict[str, str]]],
49 | no_system: bool = False,
50 | append_assistant_header: bool = False,
51 | **kwargs: str,
52 | ) -> str:
53 | assert isinstance(
54 | prompt, list
55 | ), f'Unsupported prompt type: {type(prompt)}. prompt should be formatted in a list of dict'
56 |
57 | tokens = []
58 | if not no_system and not self.has_system_prompt(prompt):
59 | tokens.extend(
60 | self.encode_message(
61 | {'role': 'system', 'content': Llama3_1.DEFAULT_SYSTEM_MESSAGE}
62 | )
63 | )
64 |
65 | for i, message in enumerate(prompt):
66 | if i != 0 and message['role'] == 'system':
67 | raise ValueError("'system' role is only allowed at the beginning of the conversation list.")
68 | if message['role'] not in ['assistant', 'user', 'system']:
69 | warnings.warn(
70 |                     f"Unknown role: '{message['role']}'. Supported roles are 'assistant', 'user', and 'system'. Proceeding on the assumption that this is intentional.",
71 | UserWarning,
72 | )
73 |
74 | tokens.extend(self.encode_message(message))
75 |
76 | if append_assistant_header:
77 | tokens.extend(self.encode_header('assistant'))
78 |
79 | return ''.join(tokens)
80 |
81 | def encode_header(self, role: str) -> List[str]:
82 | return [f'<|start_header_id|>{role}<|end_header_id|>\n\n']
83 |
84 | def encode_message(self, message: Dict[str, str]) -> List[str]:
85 | tokens = self.encode_header(message['role'])
86 | # NOTE: Meta stripped this. I'm not sure I agree, but who am I to argue?
87 | tokens.append(message['content'].strip())
88 | tokens.append('<|eot_id|>')
89 | return tokens
90 |
91 | def has_system_prompt(self, messages: List[Dict[str, str]]) -> bool:
92 | return messages[0].get('role', '') == 'system' if len(messages) else False
93 |
94 |
95 | class OLMoE(PromptStyle):
96 | def apply(
97 | self,
98 | prompt: Union[str, List[Dict[str, str]]],
99 | add_generation_prompt: bool = False,
100 | bos_token: str = '',
101 | eos_token: str = '',
102 | **kwargs,
103 | ) -> str:
104 | assert isinstance(
105 | prompt, list
106 | ), f"Expected prompt to be a list of messages, got {type(prompt)}"
107 |
108 | result = [bos_token]
109 | for i, message in enumerate(prompt):
110 | role = message.get("role")
111 | content = message.get("content", "").strip()
112 |
113 | if role == "system":
114 | result.append("<|system|>\n" + content)
115 | elif role == "user":
116 | result.append("<|user|>\n" + content)
117 | elif role == "assistant":
118 | result.append("<|assistant|>\n" + content + eos_token)
119 | else:
120 | raise ValueError(f"Unsupported message role: {role}")
121 |
122 | # Append assistant header if it's the last message and generation is expected
123 | if i == len(prompt) - 1 and add_generation_prompt:
124 | result.append("<|assistant|>")
125 |
126 | return "\n".join(result) + "\n"
127 |
128 |
129 | prompt_styles: Dict[str, Type[PromptStyle]] = {
130 | 'default': Default,
131 | 'llama3.1': Llama3_1,
132 | 'olmoe': OLMoE,
133 | }
134 |
--------------------------------------------------------------------------------
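Usage note for prompts/prompt_styles.py (above): a small sketch of the registry in action, assuming the repository root is on the import path. The expected output follows directly from Llama3_1.apply and encode_message.

    from prompts.prompt_styles import PromptStyle

    messages = [{'role': 'user', 'content': 'What is in the retrieved file?'}]

    # 'llama3.1' injects the default system message and wraps each turn in
    # <|start_header_id|>...<|end_header_id|> / <|eot_id|> markers.
    llama_style = PromptStyle.from_name('llama3.1')
    prompt = llama_style.apply(messages, append_assistant_header=True)
    # prompt ==
    #   '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>'
    #   '<|start_header_id|>user<|end_header_id|>\n\nWhat is in the retrieved file?<|eot_id|>'
    #   '<|start_header_id|>assistant<|end_header_id|>\n\n'

    # Unknown names return None, so callers should guard against it.
    assert PromptStyle.from_name('not-a-style') is None
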
/prompts/question_generation_prompt.py:
--------------------------------------------------------------------------------
1 | SYSTEM_MESSAGE = (
2 | "You are an AI Data Scientist who creates high quality datasets that can be used for fine-tuning of Large Language Models."
3 | " Follow the user's instruction closely to create a dataset based on the given context."
4 | )
5 |
6 | USER_MESSAGE_TEMPLATE = """You are to read the following information and generate a question.
7 | Filename: {filename}
8 | Information: {content}
9 |
10 | Now, generate only 1 straightforward, broad, simple and general question.
11 |
12 | You will reply in the following JSON format:
13 | {{
14 | 'question': ""
15 | }}"""
16 |
--------------------------------------------------------------------------------
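Usage note for prompts/question_generation_prompt.py (above): a sketch of how the template could be turned into a chat request with the call_openai_api helper. The real generation loop is in prepare_dataset/content_generation/generate_question.py; the filename and content values here are placeholders.

    from prompts.question_generation_prompt import SYSTEM_MESSAGE, USER_MESSAGE_TEMPLATE
    from utils.openai import call_openai_api

    messages = [
        {'role': 'system', 'content': SYSTEM_MESSAGE},
        {
            'role': 'user',
            'content': USER_MESSAGE_TEMPLATE.format(
                filename='quarterly_report.md',  # placeholder values
                content='Revenue grew 12% quarter over quarter.',
            ),
        },
    ]

    # The template asks for a JSON reply, so json_object mode fits naturally.
    result = call_openai_api(messages, output_as_json=True)
    question = result['question']
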
/prompts/rag_prompt.py:
--------------------------------------------------------------------------------
1 | SYSTEM_MESSAGE = (
2 | "Some information is retrieved from the database as provided based on the user’s question."
3 | " The assistant is to answer the question to the best of his/her ability, using only the information provided."
4 | " The assistant must not add his/her own knowledge."
5 | )
6 |
7 | templates = {}  # registry of prompt templates, keyed by name
8 |
9 | BASELINE_TEMPLATE = """Filename: {filename1}
10 | Information:
11 | {content1}
12 |
13 | Filename: {filename2}
14 | Information:
15 | {content2}
16 |
17 | Question: {question}"""
18 | templates['baseline'] = BASELINE_TEMPLATE
19 |
20 | XML_TEMPLATE = """<documents>
21 | <document>
22 | <filename>{filename1}</filename>
23 | <content>{content1}</content>
24 | </document>
25 | <document>
26 | <filename>{filename2}</filename>
27 | <content>{content2}</content>
28 | </document>
29 | </documents>
30 |
31 | Question: {question}"""
32 | templates['xml'] = XML_TEMPLATE
33 |
34 | def get_template(template: str):
35 | if template not in templates:
36 | raise KeyError(f"Template '{template}' does not exist. Available templates: {list(templates.keys())}")
37 |
38 | return templates[template]
--------------------------------------------------------------------------------
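Usage note for prompts/rag_prompt.py (above): with the templates registered under the names shown ('baseline' and 'xml'), assembling a RAG-style conversation might look like the sketch below. Filenames, contents, and the question are placeholders; any chat-style wrapping (e.g. llama3.1) is applied later by the prompt styles.

    from prompts.rag_prompt import SYSTEM_MESSAGE, get_template

    user_message = get_template('baseline').format(
        filename1='handbook.md',
        content1='Annual leave is 20 days per year.',
        filename2='faq.md',
        content2='Leave requests are submitted via the HR portal.',
        question='How many days of annual leave do employees get?',
    )

    messages = [
        {'role': 'system', 'content': SYSTEM_MESSAGE},
        {'role': 'user', 'content': user_message},
    ]
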
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.6.0 # SHOULD BE DONE OUTSIDE OF REQUIREMENTS.TXT
2 | accelerate==0.31.0
3 | datasets==3.2.0
4 | deepspeed==0.16.3
5 | flash-attn==2.6.3
6 | jsonargparse==4.36.0
7 | openai==1.60.0
8 | protobuf==5.29.3
9 | python-dotenv==1.0.1
10 | rich==13.9.4
11 | sentencepiece==0.2.0
12 | transformers==4.48.3
13 | matplotlib==3.10.1
--------------------------------------------------------------------------------
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | from pathlib import Path
5 | from typing import List, Dict, Generator
6 |
7 | def load_jsonl_file(file_path: Path) -> List[Dict]:
8 | data = []
9 | with open(file_path, 'r', encoding='utf-8') as file:
10 | for line in file:
11 | dict_obj = json.loads(line)
12 | data.append(dict_obj)
13 |
14 | assert len(data) > 0, f'{file_path.name} is empty!'
15 | return data
16 |
17 | def write_jsonl_file(content: List[dict], output_path: Path):
18 | # create all ancestors directory if it doesn't exist
19 | output_path.parent.mkdir(parents=True, exist_ok=True)
20 | with open(output_path, 'w', encoding='utf-8') as f:
21 | for entry in content:
22 | f.write(json.dumps(entry) + '\n')
23 |
24 | def get_dice(num_choices: int, seed: int=69) -> Generator[int, None, None]:
25 |     """Yield an endless stream of uniformly random integers in [1, num_choices]."""
26 |     if seed:
27 |         random.seed(seed)
28 |
29 |     choices = list(range(1, num_choices + 1))
30 |     while True:
31 |         yield random.choice(choices)
32 |
33 | def get_decider(seed: int=69) -> Generator[bool, None, None]:
34 |     """Yield an endless stream of random booleans (a repeated coin flip)."""
35 |     if seed:
36 |         random.seed(seed)
37 |
38 |     choices = [True, False]
39 |     while True:
40 |         yield random.choice(choices)
41 |
--------------------------------------------------------------------------------
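Usage note for utils/dataset_utils.py (above): a short sketch of the helpers; the paths and field names are placeholders.

    from pathlib import Path

    from utils.dataset_utils import (
        get_decider,
        get_dice,
        load_jsonl_file,
        write_jsonl_file,
    )

    records = load_jsonl_file(Path('dataset/train.jsonl'))  # placeholder path

    # get_dice and get_decider are infinite generators; pull values with next().
    dice = get_dice(num_choices=2)  # e.g. choose one of two prompt templates
    decider = get_decider()         # e.g. a random yes/no decision per record

    for record in records:
        record['template'] = next(dice)
        record['keep'] = next(decider)

    write_jsonl_file(records, Path('dataset/train_tagged.jsonl'))  # placeholder path
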
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from typing import Optional
4 | import threading
5 |
6 | BASE_DIR = Path(__file__).resolve().parent.parent
7 | LOGS_DIR = (BASE_DIR / 'logs').resolve()
8 | LOGS_DIR.mkdir(exist_ok=True)
9 |
10 | # Create a lock object for thread safety
11 | log_lock = threading.Lock()
12 |
13 | def setup_logger(
14 | namespace: str, logging_level=logging.DEBUG, logfile_name: Optional[str] = None
15 | ):
16 | with log_lock:
17 | # Create a custom logger
18 | logger = logging.getLogger(namespace)
19 |
20 | # Check if the logger already has handlers to prevent adding multiple
21 | if not logger.hasHandlers():
22 | # Set the logging level
23 | logger.setLevel(logging_level)
24 |
25 | # Create handlers
26 | logfile = f'{logfile_name if logfile_name else namespace}.log'
27 | logfile_path = LOGS_DIR / logfile
28 | file_handler = logging.FileHandler(logfile_path)
29 | console_handler = logging.StreamHandler()
30 |
31 | # Set the logging level for the handlers
32 | file_handler.setLevel(logging_level)
33 | console_handler.setLevel(logging_level)
34 |
35 | # Create a formatter and set it for the handlers
36 | formatter = logging.Formatter(
37 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
38 | )
39 | file_handler.setFormatter(formatter)
40 | console_handler.setFormatter(formatter)
41 |
42 | # Add the handlers to the logger
43 | logger.addHandler(file_handler)
44 | logger.addHandler(console_handler)
45 |
46 | return logger
47 |
--------------------------------------------------------------------------------
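Usage note for utils/logger.py (above): typical usage from a script, assuming only a namespaced logger is needed. The log file lands in logs/ at the repository root; the namespace below is an example.

    import logging

    from utils.logger import setup_logger

    # Writes to logs/run_inference.log and echoes to the console.
    logger = setup_logger('run_inference', logging_level=logging.INFO)
    logger.info('Starting inference run')
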
/utils/openai.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from dotenv import load_dotenv
4 | from openai import OpenAI
5 | from os import environ
6 | from typing import Dict, List
7 |
8 | load_dotenv()
9 | CLIENT = OpenAI(api_key=environ.get('OPENAI_API_KEY'))
10 |
11 | def call_openai_api(
12 | messages: List[Dict[str, str]],
13 | model: str = 'gpt-4o',
14 | output_as_json: bool = False,
15 | temperature: float = 0,
16 | ):
17 |     # We do not check whether the indicated model supports json_object mode.
18 |     # Refer to https://platform.openai.com/docs/guides/json-mode for more information.
19 | response_format = {'type': 'json_object'} if output_as_json else {'type': 'text'}
20 |
21 | try:
22 | chat_completion = CLIENT.chat.completions.create(
23 | messages=messages,
24 | model=model,
25 | n=1,
26 | response_format=response_format,
27 | temperature=temperature,
28 | )
29 | response = chat_completion.choices[0].message.content
30 |     except Exception as e:
31 |         # Re-raise with added context, preserving the original traceback.
32 |         raise Exception(f"An unexpected error occurred: {e}") from e
33 |
34 | if output_as_json:
35 | try:
36 | json_response = json.loads(response, strict=False)
37 | return json_response
38 | except json.JSONDecodeError as error:
39 |             raise ValueError(f"JSON decoding failed: {error}") from error
40 | else:
41 | return response
42 |
--------------------------------------------------------------------------------
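Usage note for utils/openai.py (above): a sketch of both output modes of call_openai_api. The prompts are illustrative; the model defaults to gpt-4o with temperature 0 as defined in the function signature.

    from utils.openai import call_openai_api

    # Plain-text mode returns the raw message content.
    text = call_openai_api(
        messages=[{'role': 'user', 'content': 'Say hello in one word.'}],
    )

    # JSON mode: the prompt should mention JSON, and the return value is already
    # parsed into a dict because json.loads is applied inside call_openai_api.
    data = call_openai_api(
        messages=[{'role': 'user', 'content': 'Reply in JSON: {"greeting": ""}'}],
        output_as_json=True,
    )
    greeting = data['greeting']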