├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── assets ├── 1-shot-example.jpg ├── correct_packing_attention.png ├── cross_contamination.png ├── demo1.png ├── demo2.png ├── demo3.png ├── hotpot_qa_bridge.jpg └── hotpot_qa_compare.jpg ├── eval_hotpot_qa.py ├── extra_data ├── hotpot_dev_distractor_v1_random_500.json └── test_data │ ├── cities │ ├── Albuquerque-_New_Mexico.txt │ ├── Atlanta.txt │ ├── Austin-_Texas.txt │ ├── Bakersfield-_California.txt │ ├── Baltimore.txt │ ├── Boston.txt │ ├── Charlotte-_North_Carolina.txt │ ├── Chicago.txt │ ├── Colorado_Springs-_Colorado.txt │ ├── Columbus-_Ohio.txt │ ├── Dallas.txt │ ├── Denver.txt │ ├── Detroit.txt │ ├── El_Paso-_Texas.txt │ ├── Fort_Worth-_Texas.txt │ ├── Fresno-_California.txt │ ├── Houston.txt │ ├── Indianapolis.txt │ ├── Jacksonville-_Florida.txt │ ├── Kansas_City-_Missouri.txt │ ├── Las_Vegas.txt │ ├── Long_Beach-_California.txt │ ├── Los_Angeles.txt │ ├── Louisville-_Kentucky.txt │ ├── Memphis-_Tennessee.txt │ ├── Mesa-_Arizona.txt │ ├── Miami.txt │ ├── Milwaukee.txt │ ├── Minneapolis.txt │ ├── Nashville-_Tennessee.txt │ ├── New_York_City.txt │ ├── Oakland-_California.txt │ ├── Oklahoma_City.txt │ ├── Omaha-_Nebraska.txt │ ├── Philadelphia.txt │ ├── Phoenix-_Arizona.txt │ ├── Portland-_Oregon.txt │ ├── Raleigh-_North_Carolina.txt │ ├── Sacramento-_California.txt │ ├── San_Antonio.txt │ ├── San_Diego.txt │ ├── San_Francisco.txt │ ├── San_Jose-_California.txt │ ├── Seattle.txt │ ├── Tampa-_Florida.txt │ ├── Tucson-_Arizona.txt │ ├── Tulsa-_Oklahoma.txt │ ├── Virginia_Beach-_Virginia.txt │ ├── Washington-_D.C..txt │ └── Wichita-_Kansas.txt │ └── states │ ├── Alabama.txt │ ├── Alaska.txt │ ├── Arizona.txt │ ├── Arkansas.txt │ ├── California.txt │ ├── Colorado.txt │ ├── Connecticut.txt │ ├── Delaware.txt │ ├── Florida.txt │ ├── Georgia_(U.S._state).txt │ ├── Hawaii.txt │ ├── Idaho.txt │ ├── Illinois.txt │ ├── Indiana.txt │ ├── Iowa.txt │ ├── Kansas.txt │ ├── Kentucky.txt │ ├── Louisiana.txt │ ├── Maine.txt │ ├── Maryland.txt │ ├── Massachusetts.txt │ ├── Michigan.txt │ ├── Minnesota.txt │ ├── Mississippi.txt │ ├── Missouri.txt │ ├── Montana.txt │ ├── Nebraska.txt │ ├── Nevada.txt │ ├── New_Hampshire.txt │ ├── New_Jersey.txt │ ├── New_Mexico.txt │ ├── New_York.txt │ ├── North_Carolina.txt │ ├── North_Dakota.txt │ ├── Ohio.txt │ ├── Oklahoma.txt │ ├── Oregon.txt │ ├── Pennsylvania.txt │ ├── Rhode_Island.txt │ ├── South_Carolina.txt │ ├── South_Dakota.txt │ ├── Tennessee.txt │ ├── Texas.txt │ ├── Utah.txt │ ├── Vermont.txt │ ├── Virginia.txt │ ├── Washington.txt │ ├── West_Virginia.txt │ ├── Wisconsin.txt │ └── Wyoming.txt ├── gen_data ├── README.md ├── gen_answer.py ├── gen_multi_hop_attributes.py ├── gen_multi_hop_entities.py ├── gen_negatives.py ├── gen_sub_categories.py ├── gen_task.py ├── other_files │ ├── category_list.txt │ └── sub_categories.txt ├── prompts │ ├── 2_attributes.txt │ ├── 2_attributes_wo_answer.txt │ ├── 2_entities.txt │ ├── 2_entities_wo_answer.txt │ ├── answer_gen.txt │ ├── final_answer_gen.txt │ ├── gen_new_entity.txt │ ├── gen_paragraph.txt │ └── sub_category_gen.txt ├── template_gen.py └── utility.py ├── mypy.ini ├── qa_expert ├── __init__.py ├── base_inference.py ├── hf_inference.py ├── llama_cpp_inference.py ├── prompt_utils.py ├── server_inference.py └── vllm_inference.py ├── requirements.txt ├── run_example.py ├── run_retrieval_google.py ├── server.py ├── tests ├── __init__.py ├── test_cases.json └── test_prompt.py └── train ├── README.md ├── assert_monkey_patch.py ├── 
custom_datasets.py ├── data_statistics.py ├── ds_config ├── zero2.json └── zero3.json ├── length_statistics.py ├── merge_weight.py ├── mk_patched_mistral.py ├── requirements.txt ├── test.json ├── train_model.py └── upload_model_to_hf.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore=F401,E203,F403,W503,E266,E402 3 | max-line-length = 160 4 | exclude = tests/*,test_fixtures/* 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .DS_Store 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 22.3.0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | args: [--line-length=120] 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v1.2.3 10 | hooks: 11 | - id: flake8 12 | args: [--config=.flake8] 13 | - repo: https://github.com/pre-commit/mirrors-mypy 14 | rev: v0.942 15 | hooks: 16 | - id: mypy 17 | args: ["--ignore-missing-imports", "--namespace-packages", "--explicit-package-bases"] 18 | exclude: 'tests|scripts' 19 | additional_dependencies: ['types-requests'] 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QA Expert: LLM for Multi-hop Question Answering 2 | 3 | QA Expert is a Language Model (LLM) specifically fine-tuned for the task of Question Answering, with a strong emphasis on addressing Multi-hop Question Answering scenarios. 4 | 5 |
6 |
7 |
9 | An example of a 1-shot question (single question) and how the QA Expert LLM handles multi-hop Q&A
10 | 11 |
12 |
13 |
14 |
16 | Examples of 2-shot questions and how the QA Expert LLM handles multi-hop Q&A. The left is an example of a bridging entity and the right is an example of comparing entities
17 | 18 | Multi-hop Question Answering is a task that necessitates the retrieval of multiple contexts, followed by their integration to deduce the answer to the question. 19 | 20 | QA Expert analyzes the question: if it is a single question, it uses the question as the retrieval query and retrieves once. If it is a multi-hop question, it calls the function: `retrieve` multiple times with different queries and finally summarizes the retrieved contexts to generate the final answer. 21 | 22 | 23 | ## News 24 | - [2023/11/12] We released our finetuned model [khaimaitien/qa-expert-7B-V1.0](https://huggingface.co/khaimaitien/qa-expert-7B-V1.0), based on [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1), + our training data: [khaimaitien/qa-expert-multi-hop-qa-V1.0](https://huggingface.co/datasets/khaimaitien/qa-expert-multi-hop-qa-V1.0) 25 | 26 | ## Table of Content 27 | - [QA Expert: LLM for Multi-hop Question Answering](#qa-expert-llm-for-multi-hop-question-answering) 28 | - [News](#news) 29 | - [Table of Content](#table-of-content) 30 | - [Usage](#usage) 31 | - [Model Download](#model-download) 32 | - [Inference](#inference) 33 | - [Demo](#demo) 34 | - [Asking any free-domain question using Google Search API (through SERP API) as retrieval function](#asking-any-free-domain-question-using-google-search-api-through-serp-api-as-retrieval-function) 35 | - [Asking questions within a folder of txt files](#asking-questions-within-a-folder-of-txt-files) 36 | - [Training Data](#training-data) 37 | - [Training](#training) 38 | - [Evaluation](#evaluation) 39 | - [Citation](#citation) 40 | 41 | 42 | ## Usage 43 | ### Model Download 44 | 45 | Our model was finetuned on our **generated data**, created using the OpenAI model **gpt-3.5-turbo-instruct**, with [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as the base model. 
46 | 47 | | Size | Hugging Face Repo | Base Model | 48 | | --- | --- | --- | 49 | | 7B | [khaimaitien/qa-expert-7B-V1.0](https://huggingface.co/khaimaitien/qa-expert-7B-V1.0) | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) | 50 | 51 | You can also find the model in GGUF format (for [Llama.cpp](https://github.com/ggerganov/llama.cpp)): 52 | 53 | | Size | Hugging Face Repo | 54 | | --- | --- | 55 | | 7B | [khaimaitien/qa-expert-7B-V1.0-GGUF](https://huggingface.co/khaimaitien/qa-expert-7B-V1.0-GGUF) | 56 | ### Inference 57 | Currently we support 3 types of inference: 58 | + Using [Huggingface Transformers](https://github.com/huggingface/transformers) 59 | + Using [Vllm](https://github.com/vllm-project/vllm) 60 | + Using [llama.cpp](https://github.com/ggerganov/llama.cpp) 61 | 62 | 63 | First, please install the requirements: 64 | ```shell 65 | pip install -r requirements.txt 66 | ``` 67 | 68 | An example using Hugging Face Transformers: 69 | 70 | ```python 71 | 72 | from qa_expert import get_inference_model, InferenceType 73 | 74 | def retrieve(query: str) -> str: 75 | # You need to implement this retrieval function: the input is a query and the output is a context string 76 | # It can be treated as the function to call in OpenAI-style function calling 77 | return context 78 | 79 | model_inference = get_inference_model(InferenceType.hf, "khaimaitien/qa-expert-7B-V1.0") 80 | answer, messages = model_inference.generate_answer(question, retrieve) 81 | ``` 82 | **For Vllm**, you need to install Vllm (```pip install vllm==0.2.1```) and change the InferenceType to vllm: 83 | ```python 84 | model_inference = get_inference_model(InferenceType.vllm, "khaimaitien/qa-expert-7B-V1.0") 85 | ``` 86 | **For Llama.cpp**, you need to install [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) 87 | 88 | You need to download one of the **gguf files** from here: [khaimaitien/qa-expert-7B-V1.0-GGUF](https://huggingface.co/khaimaitien/qa-expert-7B-V1.0-GGUF/tree/main). For example: 89 | ```shell 90 | wget https://huggingface.co/khaimaitien/qa-expert-7B-V1.0-GGUF/resolve/main/qa-expert-7B-V1.0.q4_0.gguf 91 | ``` 92 | Then pass the downloaded file to ``get_inference_model``: 93 | ```python 94 | # Use q4_0 95 | model_inference = get_inference_model(InferenceType.llama_cpp, "qa-expert-7B-V1.0.q4_0.gguf") 96 | # Use q8_0 97 | model_inference = get_inference_model(InferenceType.llama_cpp, "qa-expert-7B-V1.0.q8_0.gguf") 98 | ``` 99 | ### Demo 100 | 101 | #### Asking any free-domain question using Google Search API (through SERP API) as retrieval function 102 | You can run this using Hugging Face Transformers inference: 103 | ```shell 104 | python run_retrieval_google.py --qa-model khaimaitien/qa-expert-7B-V1.0 --inference-type hf 105 | ``` 106 | Once the model is loaded, you can ask any free-domain question and watch the process of handling the queries: 107 | 108 | + Retrieve Information (green): step to retrieve relevant information 109 | + retrieved context (yellow): the result of the retrieve function 110 | + Thought: the reasoning generated by the model 111 | + Summary: summarizing the retrieved information to form the final answer 112 | + Answer: the final answer to the question. 
113 | 114 | You can also use Llama.cpp as the inference type: first **download the GGUF model**: 115 | ```shell 116 | wget https://huggingface.co/khaimaitien/qa-expert-7B-V1.0-GGUF/resolve/main/qa-expert-7B-V1.0.q4_0.gguf 117 | ``` 118 | Then run: 119 | 120 | ```shell 121 | python run_retrieval_google.py --qa-model qa-expert-7B-V1.0.q4_0.gguf --inference-type llama_cpp 122 | ``` 123 | 124 | The default serper_api_key is ``e9b35305c3b0a79189b7c2dc4c37adbc587d1e65``; this is the API key of **my free account and is limited to 2500 queries**. You can use your own API key by passing: ``--serper-api-key YOUR_KEY`` 125 | 126 |
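For reference, here is a minimal sketch of what a Serper-backed `retrieve` function could look like. This is only an illustration, not the code in `run_retrieval_google.py`; it assumes Serper's `https://google.serper.dev/search` endpoint and the `organic`/`snippet` fields of its JSON response:

```python
import requests

SERPER_API_KEY = "YOUR_KEY"  # use your own key; the bundled default key is shared and rate-limited

def retrieve(query: str) -> str:
    """Use Google search results (via the Serper API) as the retrieval context for one query."""
    response = requests.post(
        "https://google.serper.dev/search",
        headers={"X-API-KEY": SERPER_API_KEY, "Content-Type": "application/json"},
        json={"q": query},
        timeout=30,
    )
    response.raise_for_status()
    organic = response.json().get("organic", [])
    # Concatenate the top snippets into a single context string for the model
    return " ".join(item.get("snippet", "") for item in organic[:3])
```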
127 |
128 |
130 | Example of answering the question: "how is the population of Vietnam compared with Philippines"
131 | 132 |
133 |
134 |
Example of answering the question: "what are some tourist attractions in the biggest city in Japan?"
136 | 137 |
138 |
139 |
Example of answering the question: "what is the second biggest city in Japan and how many people are there in that city?"
141 | 142 | #### Asking questions within a folder of txt files 143 | You can run ```run_example.py```. This example allows you to pass in a folder (**--data-folder**) containing .txt files; it reads all .txt files inside the folder and splits them into paragraphs, then the paragraphs are embedded as vectors by an embedding model (here, I use [intfloat/e5-base-v2](https://huggingface.co/intfloat/e5-base-v2)) and indexed in a vector DB (here we use [Chromadb](https://www.trychroma.com/)). The retrieve function searches over the indexed paragraphs to find the most relevant ones. 144 | 145 | ```shell 146 | python run_example.py --data-folder extra_data/test_data/cities --qa-model khaimaitien/qa-expert-7B-V1.0 --inference-type hf 147 | ``` 148 | Options: 149 | + **--data-folder** (default=extra_data/test_data/cities): The folder containing the .txt files to create indexed paragraphs for retrieval 150 | + **--qa-model**: The Hugging Face model path or a local folder 151 | + **--inference-type**: one of: vllm, hf, llama_cpp. If it is llama_cpp, --qa-model must be a local GGUF file downloaded from: https://huggingface.co/khaimaitien/qa-expert-7B-V1.0-GGUF 152 | + **--num-paragraphs**: number of paragraphs retrieved for each query 153 | 154 | Here I already added 2 folders for testing: 155 | + **extra_data/test_data/cities**: List of 50 cities in the United States, each associated with a .txt file containing text from Wikipedia 156 | + **extra_data/test_data/states**: List of 50 states in the United States, each associated with a .txt file containing text from Wikipedia 157 | 158 | Some results: 159 | 160 | 161 | ## Training Data 162 | The training data was generated using **gpt-3.5-turbo-instruct** from OpenAI. 163 | You can find more details in: [gen_data/README.md](gen_data/README.md). 164 | 165 | + Training data in multi-hop format: [khaimaitien/qa-expert-multi-hop-qa-V1.0](https://huggingface.co/datasets/khaimaitien/qa-expert-multi-hop-qa-V1.0) 166 | + Training data in OpenAI function calling format: [khaimaitien/qa-expert-multi-hop-qa-function-calling-format-V1.0](https://huggingface.co/datasets/khaimaitien/qa-expert-multi-hop-qa-function-calling-format-V1.0) 167 | ## Training 168 | We use **packing inputs without cross-contamination** to speed up the training; the idea is illustrated by the sketch below. 
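The point of packing without cross-contamination is that tokens from one packed example must not attend to tokens from another. A minimal, illustrative sketch of the block-diagonal causal mask this implies (the actual implementation in `train/` patches the model's attention; see the training README for details):

```python
import torch

def packed_causal_mask(lengths) -> torch.Tensor:
    """Causal attention mask for several examples packed into one sequence.

    A token may only attend to earlier tokens of the *same* packed example,
    so examples cannot contaminate each other.
    """
    total = sum(lengths)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in lengths:
        causal = torch.tril(torch.ones(n, n)).bool()  # lower-triangular = causal within one example
        mask[start:start + n, start:start + n] = causal
        start += n
    return mask  # True = attention allowed

# packed_causal_mask([3, 2]) -> a 5x5 mask where the 3-token and 2-token
# examples form independent causal blocks
```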
You can take a look at [train/README.md](train/README.md) 169 | ## Evaluation 170 | Please take a look at the Section **Evaluation** of [train/README.md](train/README.md#evaluation) 171 | ## Citation 172 | If you feel my work is helpful, please kindly cite as: 173 | ```bibtex 174 | @Misc{qa-expert, 175 | title={QA Expert: LLM for Multi-hop Question Answering}, 176 | author={Khai Mai}, 177 | howpublished={\url{https://github.com/khaimt/qa_expert}}, 178 | year={2023}, 179 | } 180 | ``` -------------------------------------------------------------------------------- /assets/1-shot-example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/1-shot-example.jpg -------------------------------------------------------------------------------- /assets/correct_packing_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/correct_packing_attention.png -------------------------------------------------------------------------------- /assets/cross_contamination.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/cross_contamination.png -------------------------------------------------------------------------------- /assets/demo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/demo1.png -------------------------------------------------------------------------------- /assets/demo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/demo2.png -------------------------------------------------------------------------------- /assets/demo3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/demo3.png -------------------------------------------------------------------------------- /assets/hotpot_qa_bridge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/hotpot_qa_bridge.jpg -------------------------------------------------------------------------------- /assets/hotpot_qa_compare.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khaimt/qa_expert/e8a86031b366fa3d1967af989a6378a82890fab3/assets/hotpot_qa_compare.jpg -------------------------------------------------------------------------------- /eval_hotpot_qa.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaTokenizer, AutoModelForCausalLM, GenerationConfig 2 | from gen_data import utility 3 | import json 4 | import typer 5 | from sentence_transformers import SentenceTransformer, util 6 | from qa_expert.prompt_utils import SpecialToken, get_prompt_from_messages, Message, Role 7 | import numpy as np 8 | from qa_expert import get_inference_model, ModelInference, InferenceType 9 | import requests 10 | 
import shutil 11 | import os 12 | import datetime 13 | import string 14 | import re 15 | 16 | 17 | def create_paragraph(title, sens): 18 | return ". ".join(sens + [title]) 19 | 20 | 21 | def download_file(url: str) -> str: 22 | local_filename = url.split("/")[-1] 23 | with requests.get(url, stream=True) as r: 24 | with open(local_filename, "wb") as f: 25 | shutil.copyfileobj(r.raw, f) # type: ignore 26 | 27 | return local_filename 28 | 29 | 30 | def normalize_text(s: str) -> str: 31 | """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps.""" 32 | 33 | def remove_articles(text): 34 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 35 | return re.sub(regex, " ", text) 36 | 37 | def white_space_fix(text): 38 | return " ".join(text.split()) 39 | 40 | def remove_punc(text): 41 | exclude = set(string.punctuation) 42 | return "".join(ch for ch in text if ch not in exclude) 43 | 44 | def lower(text): 45 | return text.lower() 46 | 47 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 48 | 49 | 50 | def compute_recall(prediction: str, truth: str) -> float: 51 | """Compute f1-score based on the individual words in prediction and truth 52 | 53 | Args: 54 | prediction (str): _description_ 55 | truth (str): _description_ 56 | 57 | Returns: 58 | float: _description_ 59 | """ 60 | pred_tokens = normalize_text(prediction).split() 61 | truth_tokens = normalize_text(truth).split() 62 | 63 | # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise 64 | if len(pred_tokens) == 0 or len(truth_tokens) == 0: 65 | return int(pred_tokens == truth_tokens) 66 | 67 | common_tokens = set(pred_tokens) & set(truth_tokens) 68 | 69 | # if there are no common tokens then f1 = 0 70 | if len(common_tokens) == 0: 71 | return 0 72 | 73 | # prec = len(common_tokens) / len(pred_tokens) 74 | rec = len(common_tokens) / len(truth_tokens) 75 | 76 | return rec 77 | 78 | 79 | def compute_containing_acc(prediction: str, truth: str) -> float: 80 | """if prediction (complete answer) contains truth (span answer) --> 1 else 0""" 81 | if normalize_text(truth) in normalize_text(prediction): 82 | return 1 83 | return 0 84 | 85 | 86 | def evaluate_hotpot_qa( 87 | model_path: str = typer.Option(default="khaimaitien/qa-expert-7B-V1.0"), 88 | hotpot_qa_dev_path: str = typer.Option(default="extra_data/hotpot_dev_distractor_v1_random_500.json"), 89 | inference_type: str = typer.Option(default="hf"), 90 | save_path: str = typer.Option(default=""), 91 | ): 92 | """This function is used to run evaluation on hotpot_qa dataset 93 | 94 | Args: 95 | model_path (str, optional): model to evaluate. Default="khaimaitien/qa-expert-7B-V1.0" 96 | hotpot_qa_dev_path (str, optional): hotpot_qa file to eval. Default="eval_data/hotpot_dev_distractor_v1_random_500.json". 97 | inference_type (str, optional): type of inference, you can use Vllm to reduce the evaluation time . Default="hf" 98 | save_path (str, optional): where to save the inference result, if empty, inference result is not saved. 
Default="" 99 | 100 | Returns: 101 | _type_: _description_ 102 | """ 103 | model_inference: ModelInference = get_inference_model(InferenceType(inference_type), model_path) 104 | retriever = SentenceTransformer("intfloat/e5-base-v2") 105 | 106 | examples = utility.read_json(hotpot_qa_dev_path) 107 | print("number of items: ", len(examples)) 108 | 109 | records = [] 110 | t1 = datetime.datetime.now() 111 | acc_time = 0.0 112 | avg_recall_list, avg_acc_list, is_multi_hop_acc_list = [], [], [] 113 | 114 | for index, example in enumerate(examples): 115 | question = example["question"] 116 | answer = example["answer"] 117 | context = example["context"] 118 | 119 | paragraphs = [create_paragraph(p[0], p[1]) for p in context] 120 | prefix_paragraphs = [f"passage: {p}" for p in paragraphs] # intfloat/e5-base-v2 requires to add passages: 121 | para_vectors = retriever.encode(prefix_paragraphs, normalize_embeddings=True) 122 | num_paragraphs = 1 123 | 124 | def retrieve(query: str): 125 | # intfloat/e5-base-v2 requires to add query: 126 | query_vec = retriever.encode([f"query: {query}"], normalize_embeddings=True) 127 | scores = util.cos_sim(query_vec, para_vectors)[0].tolist() 128 | s_indices = np.argsort(scores).tolist() 129 | s_indices.reverse() 130 | contexts = [paragraphs[index] for index in s_indices[:num_paragraphs]] 131 | return " ".join(contexts) 132 | 133 | try: 134 | pred_answer, messages = model_inference.generate_answer( 135 | question, retrieve, verbose=False, temperature=0.00001 136 | ) 137 | except Exception as e: 138 | pred_answer, messages = "", [] 139 | print(f"exception at this question: {question}: {str(e)}") 140 | pred_answer = str(pred_answer) 141 | 142 | t2 = datetime.datetime.now() 143 | acc_time = (t2 - t1).total_seconds() 144 | avg_time = acc_time / (index + 1) 145 | remaining_time = (len(examples) - index - 1) * avg_time 146 | record = { 147 | "question": question, 148 | "span_answer": answer, 149 | "messages": [mess.json(exclude_none=True) for mess in messages], 150 | "pred_answer": pred_answer, 151 | } 152 | if len(messages) > 4: 153 | is_multi_hop_acc_list.append(1) 154 | else: 155 | is_multi_hop_acc_list.append(0) 156 | records.append(record) 157 | recall = compute_recall(pred_answer, answer) 158 | avg_recall_list.append(recall) 159 | 160 | containing_acc = compute_containing_acc(pred_answer, answer) 161 | record["containing"] = containing_acc 162 | avg_acc_list.append(containing_acc) 163 | 164 | avg_is_multi_hop = sum(is_multi_hop_acc_list) / len(is_multi_hop_acc_list) 165 | avg_recall = sum(avg_recall_list) / len(avg_recall_list) 166 | avg_acc = sum(avg_acc_list) / len(avg_acc_list) 167 | print( 168 | ( 169 | f"{index + 1} / {len(examples)}, avg_time: {avg_time}, remaining time: {remaining_time}," 170 | f" Recall={avg_recall}, containing_acc: {avg_acc}, avg_is_multi_hop: {avg_is_multi_hop}" 171 | ) 172 | ) 173 | if len(save_path) > 0: 174 | utility.save_json(records, save_path) 175 | 176 | 177 | if __name__ == "__main__": 178 | typer.run(evaluate_hotpot_qa) 179 | -------------------------------------------------------------------------------- /extra_data/test_data/cities/Mesa-_Arizona.txt: -------------------------------------------------------------------------------- 1 | Mesa ( MAY-sə) is a city in Maricopa County, in the U.S. state of Arizona. It is the most populous city in the East Valley section of the Phoenix metropolitan area. 
It is bordered by Tempe on the west, the Salt River Pima–Maricopa Indian Community on the north, Chandler and Gilbert on the south along with Queen Creek, and Apache Junction on the east. 2 | Mesa is the third-largest city in Arizona after Phoenix and Tucson, the 36th-largest city in the US, and the largest city that is not a county seat. The city is home to 504,258 people as of 2020 according to the Census Bureau. A 2014 study by researchers determined Mesa to be "America's most conservative city".More than 40,000 students are enrolled in more than 10 colleges and universities located in Mesa. 3 | Mesa is home to the largest relief airport in the Phoenix area, Phoenix–Mesa Gateway Airport, located in the southeastern corner of the city. 4 | 5 | History 6 | The history of Mesa dates back at least 2,000 years to the arrival of the Hohokam people. The Hohokam, whose name means "All Used Up" or "The Departed Ones", built the original canal system. The canals were the largest and most sophisticated in the prehistoric New World. Some were up to 90 feet (27 m) wide and 10 feet (3.0 m) deep at their head gates, extending for as far as 16 miles (26 km) across the desert. By AD 1100 water could be delivered to an area over 110,000 acres (450 km2), transforming the Sonoran Desert into an agricultural oasis. By 1450, the Hohokam had constructed hundreds of miles of canals, many of which are still in use.After the disappearance of the Hohokam and before the arrival of the early settlers, little is known; explorers did not venture into this area. By the late 19th century near present-day Mesa, U.S. Army troops relocated the Apache, opening the way for settlement.In March 1877, Mormon pioneer Daniel Webster Jones and Henry Clay Rogers left St. George, Utah. Jones had been asked by Mormon officials to direct a party of people in establishing a settlement in Arizona. They traveled south and settled on the north side of the present Mesa area. This settlement was initially known as Fort Utah and later as Jonesville. It was located near Lehi Road. In 1883 it was named Lehi at the suggestion of Brigham Young, Jr.About this same time, another group dubbed the First Mesa Company arrived from Utah and Idaho. Their leaders were Francis Martin Pomeroy, Charles Crismon, George Warren Sirrine and Charles I. Robson. Rather than accepting an invitation to settle at Jones's Lehi settlement, they moved up onto the mesa that serves as the city's namesake. They dug irrigation canals and used some of the original Hohokam canals. By April 1878, water was flowing through them. The Second Mesa Company arrived in 1879 and settled to the west of the First Mesa Company because of more available farmland. This settlement was originally called Alma and later Stringtown. It was located near where Alma School Road.On July 17, 1878, Mesa City was registered as a 1-square-mile (2.6 km2) townsite. The first school was built in 1879. In 1883, Mesa City was incorporated with a population of 300 people. Dr. A. J. Chandler, who would later go on to found the city of Chandler, worked on widening the Mesa Canal in 1895 to allow for enough flow to build a power plant. In 1917, the city of Mesa purchased this utility company. The revenues from the company provided enough for capital expenditures until the 1960s. During the Great Depression, WPA funds provided paved streets, a new hospital, a new town hall and a library.After the founding of the city the elected official that most impacted the municipality was George Nicholas Goodman. 
He was mayor five different times during three different decades (1938–1942, 1946–1948, 1952–1956) (see: List of mayors of Mesa, Arizona). As mayor he was directly involved in the process of acquiring land for both Falcon Field and Williams Field.With the opening of Falcon Field and Williams Field in the early 1940s, more military personnel began to move into the Mesa area. With the advent of air conditioning and the rise of tourism, population growth exploded in Mesa as well as the rest of the Phoenix area. Industry—especially early aerospace companies—grew in the 1950s and 1960s. As late as 1960, half of the residents of Mesa made a living with agriculture, but that number declined substantially as Mesa's suburban growth continued on track with the rest of the Phoenix metro area. 7 | 8 | Geography 9 | Defining east and west Mesa 10 | Due to Mesa's long east to west travel distance, in excess of 18 miles (29 km) and large land area 133.13 square miles (344.8 km2), locations in Mesa are often referred to as residing within either West Mesa or East Mesa.Mesa employs a grid system for street numbering that is different from that used in Phoenix and other portions of the metropolitan area. Center Street, running north to south, bisects Mesa into eastern and western halves and serves as the east and west numbering point of origin within Mesa. Streets west of Center St., such as W. University Drive or W. Main St. are considered to be in West Mesa, whereas streets east of Center St., such as E. University or E. Main St., are considered to be in East Mesa. 11 | Mesa Drive, running north to south and bisecting Mesa into east and west sections, is located 0.5 miles (800 m) east of Center Street, and serves as the zip code boundary between the 85281, 85201, 85202, and 85210 zip codes of Western Mesa and the 85203, 85204, 85205, 85206, 85207, 85208, 85209, 85212, 85213, 85215, 85220, and 85242 zip codes of Eastern Mesa. 12 | 13 | Climate 14 | Located in the Sonoran Desert, Mesa has a hot desert climate (Köppen: BWh), with mild winters and very hot summers. 15 | 16 | Demographics 17 | According to the 2020 Census, the racial composition of Mesa was: 18 | White: 65.7% (Non-Hispanic Whites: 59.6%) 19 | Hispanic or Latino (of any race): 27.3% 20 | Black or African American: 4.2% 21 | Two or more races: 12.3% 22 | Native American: 2.7% 23 | Asian: 2.6% 24 | Native Hawaiian and Other Pacific Islander: 0.4%According to the 2010 Census, the racial composition of Mesa was: 25 | White: 83.8% (Non-Hispanic Whites: 62.6%) 26 | Hispanic or Latino (of any race): 27.4% 27 | Black or African American: 3.7% 28 | Two or more races: 3.1% 29 | Native American: 2.3% 30 | Asian: 2.0% 31 | Native Hawaiian and Other Pacific Islander: 0.4%As of the census of 2010, there were 439,041 people, 146,643 households, and 99,863 families residing in the city. The population density was 3,171.3 inhabitants per square mile (1,224.4/km2). There were 175,701 housing units at an average density of 1,405.7 per square mile (542.7/km2). 32 | The racial make-up of the city was 81.6% White, 2.4% Black or African American, 2.2% Native American, 2.0% Asian, 0.1% Pacific Islander, 9.3% from other races, and 1.3% from two or more races. 24.0% of the population were Hispanic or Latino of any race. 33 | There were 146,643 households, out of which 33.4% had children under the age of 18 living with them, 52.7% were married couples living together, 10.6% had a female householder with no husband present, and 31.9% were non-families. 
24.2% of all households were made up of individuals, and 9.1% had someone living alone who was 65 years of age or older. The average household size was 2.68 and the average family size was 3.20. 34 | The age distribution was 27.3% under 18, 11.2% from 18 to 24, 29.7% from 25 to 44, 18.4% from 45 to 64, and 13.3% who were 65 or older. The median age was 32 years. For every 100 females, there were 98.2 males. For every 100 females age 18 and over, there were 95.6 males. 35 | The median income for a household in the city was $42,817, and the median income for a family was $49,232. Males had a median income of $35,960 versus $27,005 for females. The per capita income for the city was $19,601. About 6.2% of families and 8.9% of the population were below the poverty line, including 10.7% of those under age 18 and 7.1% of those age 65 or over. 36 | 37 | Political climate 38 | In a 2014 study, academic researchers from MIT and UCLA analyzed over a decade's worth of public opinion surveys. They determined that Mesa was the "most conservative American city of more than 250,000 residents". 39 | 40 | Economy 41 | Top employers 42 | According to the city's 2020 Comprehensive Annual Financial Report, the top employers in the city are: 43 | 44 | Cultural attractions 45 | HoHoKam Park of the Cactus League, home of the Oakland Athletics and former home of the Chicago Cubs during spring training, the WAC baseball tournament and former summer home to the now defunct Mesa Miners professional baseball team of the Golden Baseball League 46 | Sloan Park, opened in 2014 as the new Cactus League spring training home of the Chicago Cubs 47 | Mesa Arts Center 48 | Mesa Amphitheater 49 | Museums 50 | I.d.e.a. Museum 51 | Commemorative Air Force Arizona Wing Aircraft Museum, located at Falcon Field – B-17 Sentimental Journey 52 | Mesa Contemporary Arts Museum, Mesa Arts Center 53 | Mesa Historical Museum 54 | Arizona Museum of Natural History 55 | Archeological sites 56 | Mesa Grande Ruins 57 | Park of the Canals 58 | Public libraries 59 | Main Library (MN) 60 | Dobson Ranch Branch (DR) 61 | Mesa Express Library (MEL) 62 | Red Mountain Branch (RM) 63 | Water parks 64 | Golfland Sunsplash waterpark on U.S. 60 65 | The only highrise in Mesa is the Bank of America (formerly Western Savings) building near Fiesta Mall. 66 | Organ Stop Pizza, containing the world's largest Wurlitzer organ 67 | Bell Bank Park a 320-acre sports and recreation complex 68 | 69 | Historic properties in Mesa 70 | Numerous properties in the city are considered to be historical and have been included either in the National Register of Historic Places or the listings of the Mesa Historic Properties. 71 | 72 | Parks and recreation 73 | Mesa has over 2,280 acres of parkland in the city limits. Its largest is Red Mountain Park which spans 1,146 acres. It includes a lake, playgrounds, a basketball court and a cement volleyball court. 74 | 75 | Golf 76 | Mesa is home to numerous championship golf courses, including the original course in town, Mesa Country Club. This course was founded in the late 1940s by the original leaders of the town, and "Country Club Drive", the most prominent street in Mesa, was at one point the modest entrance to the club. 77 | 78 | West Mesa 79 | The abandoned Fiesta Mall is located in West Mesa and owned by Westcor. Its anchors were Sears and Best Buy. It is located near several shopping centers, Mesa's Bank of America, and other retail stores, banks, and restaurants. 
Though deserted, a refurbishment and expansion of the mall has been planned.Mesa Riverview is a new outdoor destination retail center in the northwestern corner of the city, near Loop 202 and Dobson Road. At build-out the center will include 1,300,000 square feet (120,000 m2) of retail space. Its anchors include Bass Pro Shops, Cinemark Theaters, Wal-Mart, and Home Depot. 80 | 81 | East Mesa 82 | Located in East Mesa is Superstition Springs Business Park. It includes the Superstition Springs Center, a shopping mall owned by Macerich. It features an outdoor amphitheatre and fountain which convert to a stage. Anchor stores at the mall are Dillard's, JCPenney, and Macy's. Mission Community Church, previously known as Superstition Springs Community Church, was initially named after this business park. 83 | 84 | Education 85 | Almost all of the city of Mesa is served by public schools operated by Mesa Public Schools; however, a small southern portion is served by the Gilbert Public Schools and the Queen Creek Unified School District, and a small western portion is served by the Tempe Elementary School District and the Tempe Union High School District. 86 | Pilgrim Lutheran School is a Christian Pre-K-8 school of the Wisconsin Evangelical Lutheran Synod in Mesa.More than 40,000 students are enrolled in more than 10 colleges and universities located in Mesa. Mesa is home to Mesa Community College, the largest of the Maricopa Community Colleges, which enrolls over 24,000 full and part-time students, and Chandler–Gilbert Community College. The Polytechnic campus of Arizona State University lies in southeast Mesa. This satellite campus enrolls over 6,000 undergraduate and graduate students in scientific and engineering fields. A. T. Still University operates an Osteopathic Medical School in Mesa. 87 | Private for-profit institutions include Arizona College, Carrington College, DeVry University, Pima Medical Institute, and CAE Global Academy Phoenix. Arizona State University opened the Media and Immersive eXperience Center in the ASU at Mesa City Center complex in 2022, offering programs from the Herberger Institute for Design and Arts including a film school with media production facilities and a theater.After launching a higher education initiative in 2012, Mesa became home to branch campuses of five private, liberal arts institutions: Albright College, Westminster College, Benedictine University, Upper Iowa University and Wilkes University. Albright College and Westminster College are no longer in the city, and Wilkes University has moved entirely online. 88 | 89 | Transportation 90 | Several area freeways serve the Mesa area, such as U.S. Route 60, locally known as the Superstition Freeway, which runs between Apache Junction and Phoenix. It is also served by SR 87 and bypass loops Loop 101, which skirts the western city limits as the Price Freeway, and Loop 202, which bypasses the city on the north and east. The main east–west arterial road in Mesa is Main Street (former US 60/70/80/89), serving Downtown Mesa. The primary north–south arterials include Country Club Drive, Gilbert Road, and Power Road. 91 | Public transportation in Mesa is provided by Valley Metro via bus and light rail (Valley Metro Rail). The light rail section in Mesa spans about four miles from Sycamore/Main St. in the west of the city, through downtown to Gilbert/Main St. Until July 2008, Mesa was the largest U.S. city with no public transit service on Sundays. 
The city has Sunday service available on Routes 40-Apache/Main, 45-Broadway, 61-Southern, 96-Dobson, 108-Elliot, 112-Country Club/Arizona, 156-Chandler/Williams Field, and 184-Power. Up until to the final years of Southern Pacific passenger railroad service, the Sunset Limited passenger train used to make stops in Mesa.Air service in the city is provided by two airports. Falcon Field, located in the northeastern part of the area, was established as a training field for British RAF pilots during World War II and was transferred to the city at the end of the war. Falcon Field has 605 aircraft based there. Boeing builds the AH-64 Apache attack helicopter at a facility adjoining Falcon Field. Phoenix-Mesa Gateway Airport is located in the far southeastern area of the city and provides alternate but limited air service when compared to Sky Harbor International Airport. Phoenix-Mesa Gateway was formerly Williams Gateway Airport, and before that, Williams Air Force Base, which closed in 1993. Williams Gateway was announced as a new Focus City for Allegiant Air. Service started October 25, 2007. 92 | 93 | Healthcare 94 | The public hospital system, Valleywise Health (formerly Maricopa Integrated Health System), operates Valleywise Community Health Center – Mesa and Valleywise Behavioral Health Center – Mesa. Its sole hospital, Valleywise Health Medical Center, is in Phoenix. 95 | 96 | Notable people 97 | Sister cities 98 | Mesa has five sister cities, as designated by Sister Cities International: 99 | 100 | Burnaby, British Columbia, Canada 101 | Caraz, Peru 102 | Guaymas, Mexico 103 | Kaiping, Guangdong, China 104 | Upper Hutt, New Zealand 105 | 106 | See also 107 | Arizona Commemorative Air Force Museum 108 | The Church of Jesus Christ of Latter-day Saints in Arizona 109 | City of Mesa Cemetery 110 | Life Teen 111 | Mesa Distance Learning Program 112 | Shooting of Daniel Shaver 113 | Tri-City Pavilions 114 | 115 | References 116 | Notes 117 | 118 | Bibliograph 119 | 120 | External links 121 | 122 | Official government website 123 | Mesa Arizona Convention and Visitors Bureau – Tourism 124 | Mesa news, sports and things to do from The Mesa Republic newspaper 125 | Mesa Public Library 126 | Mesa Chamber of Commerce -------------------------------------------------------------------------------- /extra_data/test_data/states/New_York.txt: -------------------------------------------------------------------------------- 1 | New York most commonly refers to: 2 | 3 | New York (state), a state in the northeastern United States 4 | New York City, the most populous city in the United States, located in the state of New YorkNew York may also refer to: 5 | 6 | Film and television 7 | New York (1916 film), a lost American silent comedy drama by George Fitzmaurice 8 | New York (1927 film), an American silent drama by Luther Reed 9 | New York (2009 film), a Bollywood film by Kabir Khan 10 | New York: A Documentary Film, a film by Ric Burns 11 | "New York" (Glee), an episode of Glee 12 | 13 | Literature 14 | New York (Burgess book), a 1976 work of travel and observation by Anthony Burgess 15 | New York (Morand book), a 1930 travel book by Paul Morand 16 | New York (novel), a 2009 historical novel by Edward Rutherfurd 17 | New York (magazine), a bi-weekly magazine founded in 1968 18 | 19 | Music 20 | New York EP, a 2012 EP by Angel Haze 21 | "New York" (Angel Haze song) 22 | New York (album), a 1989 album by Lou Reed 23 | "New York" (Eskimo Joe song) (2007) 24 | "New York" (Ja Rule song) (2004) 25 | "New York" 
(Paloma Faith song) (2009) 26 | "New York" (St. Vincent song) (2017) 27 | "New York" (Snow Patrol song) (2011) 28 | "New York" (U2 song) (2000) 29 | New York, a 2006 album by Antti Tuisku 30 | "New York", a 1977 song by the Sex Pistols from Never Mind the Bollocks, Here's the Sex Pistols 31 | 32 | Places 33 | United Kingdom 34 | New York, Lincolnshire 35 | New York, North Yorkshire 36 | New York, Tyne and Wear 37 | 38 | United States 39 | New York state 40 | New York metropolitan area, the region encompassing New York City and its suburbs 41 | New York County, covering the same area as the New York City borough of Manhattan 42 | New York, the US Postal Service address designating Manhattan 43 | Province of New York, a British colony preceding the state of New York 44 | 45 | Other states 46 | New York, Florida, an unincorporated community in Santa Rosa County 47 | New York, Iowa, a former town in Wayne County 48 | New York, Kentucky, an unincorporated community in Ballard County 49 | New York, Missouri, a ghost town in Scott County 50 | New York, Texas, an unincorporated community in Henderson County 51 | New York Mountain, a mountain in Colorado 52 | New York Mountains, a mountain range in California 53 | 54 | Ukraine 55 | New York, Ukraine, a settlement in Donetsk Oblast 56 | 57 | Ships 58 | Many ships have been named after the city or state of New York. See: 59 | 60 | List of ships named New York 61 | List of ships named City of New York 62 | List of ships named New York City 63 | 64 | Sports 65 | American football 66 | New York Giants, members of the East Division of the National Football Conference of the NFL (1925) 67 | New York Jets, members of the East Division of the American Football Conference of the NFL (1960) 68 | New York (World Series of Football), a professional football team for the World Series of Football (1902–1903) 69 | 70 | Baseball 71 | New York Mets, members of the East Division of the National League of MLB (1962) 72 | New York Yankees, members of the East Division of the American League of MLB (1903) 73 | 74 | Hockey 75 | New York Islanders, members of the Metropolitan Division of the Eastern Conference of the NHL (1972) 76 | New York Rangers, members of the Metropolitan Division of the Eastern Conference of the NHL (1926) 77 | 78 | Soccer 79 | New York City FC, a professional soccer team based in New York City that competes in the Eastern Conference of MLS (2015) 80 | New York Red Bulls, a professional soccer team that competes in the Eastern Conference of MLS (1996) 81 | New York Stadium in South Yorkshire, home ground of Rotherham United FC 82 | 83 | Other sports 84 | New York GAA, a county board of the Gaelic Athletic Association outside Ireland, responsible for Gaelic games in the New York metropolitan area 85 | New York Knicks, a professional basketball team, part of the Atlantic Division of the Eastern Conference in the NBA 86 | 87 | Other uses 88 | New York (pinball), a 1976 pinball machine by Gottlieb 89 | New York (1983 typeface), an Apple font set for original Macintosh computers 90 | New York (2019 typeface), a font set for developing software on Apple platforms 91 | New York Harbor, a waterfront in New York City 92 | Brooklyn Navy Yard, referred to as New York in naval histories 93 | Tiffany Pollard (born 1982), star of the reality TV show I Love New York who is nicknamed New York 94 | 95 | See also 96 | New York City (disambiguation) 97 | New York Cosmos (disambiguation) 98 | New York, New York (disambiguation) 99 | Nova Iorque, Brazilian municipality 
in the state of Maranhão 100 | Nowy Jork, former name of Łagiewniki, Włocławek County, Poland 101 | NY (disambiguation) 102 | SS New York, a list of ships 103 | SS New York City, a list of ships 104 | USS New York, a list of United States Navy ships and submarines 105 | All pages with titles beginning with New York 106 | All pages with titles containing New York -------------------------------------------------------------------------------- /extra_data/test_data/states/Washington.txt: -------------------------------------------------------------------------------- 1 | Washington most commonly refers to: 2 | 3 | George Washington (1732–1799), the first president of the United States 4 | Washington (state), United States 5 | Washington, D.C., the capital of the United States 6 | A metonym for the federal government of the United States 7 | Washington metropolitan area, the metropolitan area centered on Washington, D.C.Washington may also refer to: 8 | 9 | Places 10 | England 11 | Washington, Tyne and Wear, a town in the City of Sunderland metropolitan borough 12 | Washington Old Hall, ancestral home of the family of George Washington 13 | Washington, West Sussex, a village and civil parish 14 | 15 | Greenland 16 | Cape Washington, Greenland 17 | Washington Land 18 | 19 | Philippines 20 | New Washington, Aklan, a municipality 21 | Washington, a barangay in Catarman, Northern Samar 22 | Washington, a barangay in Escalante, Negros Occidental 23 | Washington, a barangay in San Jacinto, Masbate 24 | Washington, a barangay in Surigao City 25 | 26 | United States 27 | Washington, Wisconsin (disambiguation) 28 | Fort Washington (disambiguation) 29 | Lake Washington (disambiguation) 30 | Mount Washington (disambiguation) 31 | Port Washington (disambiguation) 32 | Washington Avenue (disambiguation) 33 | Washington Boulevard (disambiguation) 34 | Washington Bridge (disambiguation) 35 | Washington County (disambiguation) 36 | Washington district (disambiguation) 37 | Washington Island (disambiguation) 38 | Washington Park (disambiguation) 39 | Washington Square (disambiguation) 40 | Washington Street (disambiguation) 41 | Washington Township (disambiguation) 42 | Washington Valley (disambiguation) 43 | 44 | Cities and communities 45 | Washington, Alabama 46 | Washington, Arkansas 47 | Washington, California, in Nevada County 48 | Washington, Yolo County, California 49 | Washington, Connecticut 50 | Washington, Georgia 51 | Washington, Illinois 52 | Washington, Indiana 53 | Washington, Iowa 54 | Washington, Kansas 55 | Washington, Kentucky 56 | Washington, Louisiana 57 | Washington, Maine 58 | Washington, Massachusetts 59 | Washington, Michigan, an unincorporated community in Washington Township 60 | Washington, Mississippi 61 | Washington, Missouri 62 | Washington, Nebraska 63 | Washington, New Hampshire 64 | Washington, New Jersey 65 | Washington, New York 66 | Washington, North Carolina 67 | Washington, Oklahoma 68 | Washington, Pennsylvania 69 | Washington, Rhode Island 70 | Washington, Utah 71 | Washington, Vermont 72 | Washington, Virginia 73 | Washington, West Virginia 74 | Washington Court House, Ohio 75 | Washington-on-the-Brazos, Texas 76 | 77 | Elsewhere 78 | Washington Escarpment, Antarctica 79 | Washington, Ontario, Canada 80 | George Washington, Cuba (also known as Washington) 81 | Washington Island (French Polynesia) 82 | Washington Island (Kiribati) 83 | Washington, Guyana, a community in Mahaica-Berbice, Guyana 84 | 85 | Education 86 | Higher education 87 | In the United States 88 | 
University of Washington, Seattle, Washington 89 | George Washington University, Washington, D.C. 90 | Harold Washington College, Chicago, Illinois 91 | University of Mary Washington, Fredericksburg, Virginia 92 | Washington College, merged with Jefferson College in 1865 to form Washington & Jefferson College 93 | Washington College (California), formerly in Irvington, Fremont, California 94 | Washington College, Chestertown, Maryland 95 | Washington College, Connecticut, the original name of Trinity College 96 | Washington College of Law, at American University, Washington, D.C. 97 | Washington Female Seminary, Washington, Pennsylvania 98 | Washington Medical College, a defunct institution formerly in Baltimore, Maryland 99 | Washington University in St. Louis, Missouri 100 | 101 | Outside of the United States 102 | Washington International University, unaccredited institution in the British Virgin Islands 103 | Washington University of Barbados, named as part of an international medical school scam 104 | 105 | Secondary education 106 | Lake Washington High School, Kirkland, Washington 107 | Washington Academy (disambiguation) 108 | Washington College Academy, Limestone, Tennessee 109 | Booker T. Washington High School (disambiguation) 110 | George Washington High School (disambiguation) 111 | Washington County High School (disambiguation) 112 | Washington High School (disambiguation) 113 | Washington International School, Washington, D.C. 114 | Washington School (disambiguation) 115 | 116 | People 117 | Washington (name), a given name or surname 118 | Washington (musician), the stage name of Australian musician Megan Washington 119 | Washington (footballer, born 1953), Brazilian football forward Washington Luiz de Paula 120 | Washington (footballer, born 1 April 1975), Brazilian football manager and former striker Washington Stecanela Cerqueira 121 | Washington (footballer, born 10 April 1975), Brazilian football striker Washington Luiz Pereira dos Santos 122 | Washington (footballer, born August 1978), Brazilian football forward Washington Luiz Mascarenhas Silva 123 | Washington (footballer, born November 1978), Brazilian football forward Washington Luigi Garcia 124 | Washington (footballer, born 1985), Brazilian football striker Washington Roberto Mariano da Silva 125 | Washington (footballer, born May 1986), Brazilian football striker Washington de Mesquista Ferreira 126 | Washington (footballer, born November 1986), Brazilian football midfielder Cezar Washington Alves Portela 127 | Washington (footballer, born 1989), Brazilian football midfielder Washington Santana da Silva 128 | 129 | Ships 130 | SS Washington (1930), an ocean liner 131 | SS Washington (1941), a cargo ship 132 | USS Washington, several U.S. Navy ships 133 | Washington (steamboat 1851) 134 | 135 | Sports 136 | In the Washington D.C. metropolitan area 137 | Washington Capitals, professional ice hockey team of the National Hockey League 138 | Washington Nationals, professional baseball team of Major League Baseball 139 | Washington Mystics, professional basketball team of the Women's National Basketball Association 140 | Washington Commanders, formerly Washington Redskins, professional American football team of the National Football League 141 | Washington Wizards, professional basketball team of the National Basketball Association 142 | D.C. 
United, professional soccer team of the Major League Soccer 143 | 144 | In Washington (state) 145 | Division I 146 | Eastern Washington Eagles, athletic teams of Eastern Washington University in Cheney, Washington 147 | Washington Huskies, athletic teams of the University of Washington in Seattle, Washington 148 | Washington State Cougars, athletic teams of Washington State University in Pullman, Washington 149 | 150 | Division II 151 | Central Washington Wildcats, athletic teams of Central Washington University in Ellensburg, Washington 152 | Western Washington Vikings, athletic teams of Western Washington University in Bellingham, Washington 153 | 154 | Elsewhere 155 | Washington F.C., a football club based in Washington, Tyne and Wear, England 156 | Washington University Bears, athletic teams of Washington University in St. Louis, Missouri, USA 157 | 158 | Other uses 159 | Washington station (disambiguation) 160 | Washington (tree), a giant sequoia in Sequoia National Park, California, US 161 | Boeing Washington, British designation for the Boeing B-29 Superfortress 162 | 163 | See also 164 | Washingtonian (disambiguation) 165 | All pages with titles beginning with Washington 166 | All pages with titles containing Washington -------------------------------------------------------------------------------- /gen_data/README.md: -------------------------------------------------------------------------------- 1 | # Training Data Creation 2 | Each multi-hop question can be handled by decomposing it into single questions. This datasets contains multi-hop questions and their decomposed questions. We also add single questions to this dataset to make sure that the trained model is able to handle all kinds of questions. 3 | 4 | **You can download our training data from here:** [khaimaitien/qa-expert-multi-hop-qa-V1.0](https://huggingface.co/datasets/khaimaitien/qa-expert-multi-hop-qa-V1.0) 5 | 6 | This dataset contains 25.5k for training and 3.19k for evaluation. 
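If you want to inspect the data programmatically, it can be loaded with the Hugging Face `datasets` library; a small sketch (the split name `train` and the field access below are assumptions based on the Format section that follows):

```python
from datasets import load_dataset

# Load the multi-hop QA dataset from the Hugging Face Hub
ds = load_dataset("khaimaitien/qa-expert-multi-hop-qa-V1.0")
print(ds)  # available splits and their sizes

example = ds["train"][0]  # split name "train" is an assumption; check the dataset card
print(example["question"], example["multihop"])
```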
7 | 8 | - [Training Data Creation](#training-data-creation) 9 | - [Format](#format) 10 | - [Generate New Training Data](#generate-new-training-data) 11 | - [Multi-hop Questions asking about an attribute of 2 entities in a question](#multi-hop-questions-asking-about-an-attribute-of-2-entities-in-a-question) 12 | - [Multi-hop Questions asking about 2 attributes of an entity in a question](#multi-hop-questions-asking-about-2-attributes-of-an-entity-in-a-question) 13 | - [Negative Paragraph Generation](#negative-paragraph-generation) 14 | - [Single Questions](#single-questions) 15 | - [Using available training datasets](#using-available-training-datasets) 16 | - [List of Scripts for generating data](#lisf-of-scripts-for-generating-data) 17 | - [Script for generating sub-category from given category:](#script-for-generating-sub-category-from-given-category) 18 | - [Script for generating multi-hop Questions asking about an attribute of 2 entities](#script-for-generating-multi-hop-questions-asking-about-an-attribute-of-2-entities) 19 | - [Script for generating multi-hop Questions asking about 2 attributes of an entity in a question](#script-for-generating-multi-hop-questions-asking-about-2-attributes-of-an-entity-in-a-question) 20 | - [Script for generating data points with negative paragraphs](#script-for-generating-data-points-with-negative-paragraphs) 21 | - [Script for generating answers to the single questions and final multi-hop questions](#script-for-generating-answers-to-the-single-questions-and-final-multi-hop-questions) 22 | 23 | 24 | ## Format 25 | Each data point is a JSON object with the following fields: 26 | + **question**: the question; it can be a single question or a multi-hop question 27 | + **multihop**: True/False, whether the question is multi-hop or not 28 | + **sub_questions**: List of single questions decomposed from the question. If the question is a single question, ```len(sub_questions) == 1``` 29 | + **question**: single question decomposed from the original multi-hop question 30 | + **paragraph**: the retrieval context for the single question 31 | + **long_answer**: the answer to the single question; the format is: xxx\nAnswer:yyy where xxx is the reasoning (thought) before generating the answer to the question. 32 | + **final_answer**: The final answer to the question. If the question is multihop, this has the form: Summary:xxx\nAnswer:yyy where xxx is the summary of answers from the decomposed single questions before generating the final answer yyy 33 | + **answer**: This field can be ignored 34 | + **meta_info**: contains the information about how the data point was created 35 | + **tag**: the type of data, for example: 36 | + musique-train.json: train data from musique 37 | + entities-neg_train.json: data points from generating questions related to 2 entities with a **negative paragraph**. 38 | + ... 39 | 40 | 41 | ## Generate New Training Data 42 | We found that there is not much publicly available training data for multi-hop Q&A, so we decided to create new training data using **gpt-3.5-turbo-instruct** - an OpenAI model. We create 2 kinds of multi-hop questions: 43 | 44 | ### Multi-hop Questions asking about an attribute of 2 entities in a question 45 | 46 | Here are some examples of these questions; **entities** are highlighted. 47 | 48 | + In which year were the **Seattle Public Library** and **Denver Public Library** built? 49 | + Is **Kailua Beach** more popular than **Waikiki Beach**? 50 | + How do the **Giant Anteater** and the **Lesser Anteater** differ in their reproduction processes? 
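For illustration, a hand-written data point (not taken from the actual dataset) for a question like the second example above might look roughly like this in the format described earlier:

```json
{
  "question": "Is Kailua Beach more popular than Waikiki Beach?",
  "multihop": true,
  "sub_questions": [
    {
      "question": "How popular is Kailua Beach?",
      "paragraph": "<retrieval context about Kailua Beach>",
      "long_answer": "<reasoning>\nAnswer: <answer to this single question>"
    },
    {
      "question": "How popular is Waikiki Beach?",
      "paragraph": "<retrieval context about Waikiki Beach>",
      "long_answer": "<reasoning>\nAnswer: <answer to this single question>"
    }
  ],
  "final_answer": "Summary: <summary of the two answers>\nAnswer: <final answer>",
  "meta_info": "<how the data point was created>",
  "tag": "<source file, e.g. musique-train.json>"
}
```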
51 | 52 | Here is the flow to generate this kind of data: 53 | 54 | + Step 1: choose a random category (from ../gen_data/other_files/sub_categories.txt) 55 | + Step 2: generate 2 entries from this category: **entity 1**, **entity 2** 56 | + Step 3: generate a list of **common attributes** of these 2 entities 57 | + Step 4: select a random attribute from the generated list --> **selected attribute** 58 | + Step 5: Generate **question 1** asking for the **selected attribute** of **entity 1** 59 | + Step 6: Generate **question 2** asking for the **selected attribute** of **entity 2** 60 | + Step 7: Generate a **multi-hop question** that can be decomposed into **question 1** and **question 2**. This, for example, can be a question comparing the **selected attribute** of **entity 1** and **entity 2** 61 | + Step 8: Generate **paragraph 1** containing the information about the **selected attribute** of **entity 1** 62 | + Step 9: Generate the **reasoning 1** (thought) to answer **question 1** based on **paragraph 1** 63 | + Step 10: Generate the complete **answer 1** to **question 1** based on the **reasoning 1** 64 | + Step 11: Generate **paragraph 2** containing the information about the **selected attribute** of **entity 2** 65 | + Step 12: Generate the **reasoning 2** (thought) to answer **question 2** based on **paragraph 2** 66 | + Step 13: Generate the complete **answer 2** to **question 2** based on the **reasoning 2** 67 | + Step 14: summarize the points from **answer 1** and **answer 2** to generate the final answer to the **multi-hop question** 68 | + Step 15: Generate the reasoning (thought) to answer the **multi-hop question** based on the **summary** 69 | + Step 16: Generate the final answer to the **multi-hop question** based on the reasoning 70 | 71 | We implement this flow using the prompt: [2_entities.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/2_entities.txt) with model: **gpt-3.5-turbo-instruct**. To make the generation more creative and diverse, we used temperature=0.7 --> 1. However, we found that with these high temperatures, the steps for generating answers such as steps 10, 13 and 16 would be **vulnerable to hallucination**. So in practice, we split the flow into 2 parts: 72 | 73 | + Part 1: for generating questions and paragraphs (step 1 -> 8, step 11), **temperature=0.7 -> 1**, prompt=[2_entities_wo_answer.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/2_entities_wo_answer.txt) 74 | + Part 2: for generating reasonings and answers (step 9, 10, 12 -> 16). Prompt=[answer_gen.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/answer_gen.txt) for steps 9, 10 and steps 12, 13 (generating reasonings and answers for single question 1 and single question 2). Prompt = [final_answer_gen.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/final_answer_gen.txt) for steps 14, 15, 16 (generating reasonings and answers for the multi-hop question). Using **temperature=0** for this part. 75 | 76 | **Some tricks for generating a diverse dataset:** 77 | + The purpose of choosing a random category is to diversify the training data. At first, we manually prepared a list of [125 general categories](https://github.com/khaimt/qa_expert/blob/main/gen_data/other_files/category_list.txt) derived from [Sekine’s Extended Named Entities](https://nlp.cs.nyu.edu/ene/version7_1_0Beng.html). 
But we found that it was not really diverse enough, so we decided to continue to split these categories into more fine-grained categories. For each category, we used the prompt: [sub_category_gen.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/sub_category_gen.txt) to generate more fine-grained categories. Finally, we attained [3750 smaller categories](https://github.com/khaimt/qa_expert/blob/main/gen_data/other_files/sub_categories.txt) using **gpt-3.5-turbo-instruct**. You can use the script: ```python -m gen_data.gen_sub_categories``` to generate fine-grained categories from the given general categories. 78 | + The reason why we need **step 3** (generate a list of common attributes) before **step 4** (select a random attribute from this list) is to make the training data **more diverse**. With step 4 alone, I found that, for example, if category=City, the model (gpt-3.5-turbo-instruct) would choose attribute=population 90% of the time although the temperature was already **1**. 79 | + While generating data, at **step 2** (generating 2 entities), I also added some randomness: {popularity_1} and {popularity_2}, each a random value from 1 --> 5; you can take a look at [the prompt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/2_entities.txt#L3C8-L3C8): 80 | + ```Entity 1: generate a random entity of this category such that level of popularity on internet is {popularity_1} out of 5``` 81 | 82 | If the prompt was only: ```Entity 1: generate a random entity of this category``` it would only generate the popular ones even when the temperature was > 1. For example, for category=City in Asia, the model usually generates famous ones such as: Tokyo, Shanghai, ... instead of the less popular ones like: Da Nang, Bandung, ... 83 | + In the prompt, I also replaced {question_type} with a random value from: ["wh question", "yes/no question"], because I found that the model was more likely to generate wh-questions than yes/no questions. 84 | 85 | ### Multi-hop Questions asking about 2 attributes of an entity in a question 86 | Here are some examples of these questions: 87 | + Does Jim Gaffigan have a high net worth and is he married? 88 | + **--> Entity=Jim Gaffigan; attributes: Net worth and Spouse** 89 | + Did the 2011 Tohoku earthquake and tsunami have a high magnitude and were there many casualties during it? 90 | + **--> Entity=2011 Tohoku earthquake; attributes: Magnitude & Casualties** 91 | 92 | For this type of question, the flow is almost the same as the flow for generating [Multi-hop Questions asking about an attribute of 2 entities in a question](#multi-hop-questions-asking-about-an-attribute-of-2-entities-in-a-question) 93 | 94 | The prompts I used are: 95 | + Prompt for generating all steps: [2_attributes.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/2_attributes.txt) 96 | + Prompt for generating part 1 (questions and paragraphs): [2_attributes_wo_answer.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/2_attributes_wo_answer.txt). For part 2 (reasonings and answers), I used the same prompts as in the previous section 97 | 98 | ### Negative Paragraph Generation 99 | Besides generating paragraphs that contain the answer to the single questions, I also generated data points where the paragraphs don't contain the answer. I call these **negative paragraphs**. 
I randomly picked 1200 data points from each of the multi-hop entity questions and multi-hop attribute questions to generate data points with **negative paragraphs**. 100 | 101 | Assume that the chosen data point is: x = (e1, a1, q1, p1, e2, a2, q2, p2) where: 102 | + e1: entity 1, a1: attribute 1, q1: single question 1; p1: paragraph of single question 1 containing information about attribute a1 of e1. 103 | + e2: entity 2, a2: attribute 2, q2: single question 2; p2: paragraph of single question 2 containing information about attribute a2 of e2. 104 | 105 | If x is from 2 entities data, a1 = a2 = selected attribute; e1 != e2 106 | If x is from 2 attributes data, e1 = e2 = selected entity, a1 != a2 107 | 108 | To create negative context data, I randomly picked one of 3 options: 109 | + replace p1 with a **negative paragraph** 110 | + replace p2 with a **negative paragraph** 111 | + replace both p1 and p2 with new **negative paragraphs** 112 | 113 | For example, suppose we want to generate a **negative paragraph** from the original paragraph written for entity **e** and attribute **a**. We first create a prompt for generating a paragraph for an entity and an attribute (this prompt is: [gen_paragraph](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/gen_paragraph.txt)), then we use this prompt to generate a negative paragraph by: 114 | + replacing **e** with another new entity of the same category, using prompt: [gen_new_entity.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/gen_new_entity.txt) 115 | + replacing **a** with another new attribute from the attribute list generated during data generation 116 | + or replacing both **e** and **a** with a new entity and a new attribute 117 | 118 | For example, suppose we have the original paragraph for entity="Tokyo", attribute="GDP". We can generate a negative paragraph by using the prompt: [gen_paragraph](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/gen_paragraph.txt) with 119 | + entity="Shanghai", attribute="GDP" --> replacing entity 120 | + entity="Tokyo", attribute="attraction" --> replacing attribute 121 | + entity="Shanghai", attribute="attraction" --> replacing both entity and attribute 122 | 123 | When we replace the original paragraphs with negative paragraphs, we need to update the answers to the single questions and the final answer to the multi-hop question using the prompts: [answer_gen.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/answer_gen.txt) and [final_answer_gen.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/final_answer_gen.txt) as described in Part 2 of the previous section. 124 | 125 | You can see the script for generating negative-paragraph data points in the section [List of Scripts](#list-of-scripts-for-generating-data) 126 | 127 | ### Single Questions 128 | To create data points where the question is a single question instead of a multi-hop question (field ``multihop=False``), we just reused the single questions from the multi-hop questions 129 | 130 | ## Using available training datasets 131 | We found that [Musique](https://github.com/StonyBrookNLP/musique) is the most suitable dataset for multi-hop Q&A with bridging entities, so we made use of it. Here are the steps we took to process this dataset: 132 | + Remove data points containing single questions that are not well-formed (containing: **">>"**), such as: "Stadio Ciro Vigorito >> occupant" 133 | + For each single question, we generated the **complete answers** because this dataset only contains span answers for questions. 
Complete answers of single questions were generated using prompt: [answer_gen.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/answer_gen.txt) and complete answers of final questions were generated using prompt: [final_answer_gen.txt](https://github.com/khaimt/qa_expert/blob/main/gen_data/prompts/final_answer_gen.txt) 134 | + Remove data points whose generated **complete answers** don't contain the span answer 135 | 136 | You can find data points of this dataset by finding ones whose field ``tag`` contains **"musique"** 137 | 138 | 139 | ## List of Scripts for generating data 140 | 141 | ### Script for generating sub-category from given category: 142 | ```shell 143 | python -m gen_data.gen_sub_categories \ 144 | --category-path gen_data/other_files/category_list.txt \ 145 | --save_path Where_to_save.json 146 | ``` 147 | 148 | ### Script for generating multi-hop Questions asking about an attribute of 2 entities 149 | 150 | **Note that to run the script, you need to set the OPENAI_API_KEY first by:** 151 | ```shell 152 | export OPENAI_API_KEY=YOUR_KEY_HERE 153 | ``` 154 | 155 | Example: 156 | ```shell 157 | python -m gen_data.gen_multi_hop_entities \ 158 | --category-path gen_data/other_files/sub_categories.txt \ 159 | --num-items-per-category 1 \ 160 | --output-folder save_folder/entities \ 161 | --multi-qa-prompt gen_data/prompts/2_entities_wo_answer.txt \ 162 | --temperature 0.7 \ 163 | --re-generate-answer 164 | ``` 165 | Please find more information about the arguments in [gen_data/gen_multi_hop_entities.py](https://github.com/khaimt/qa_expert/blob/main/gen_data/gen_multi_hop_entities.py) 166 | 167 | ### Script for generating multi-hop Questions asking about 2 attributes of an entity in a question 168 | Example: 169 | ```shell 170 | python -m gen_data.gen_multi_hop_attributes \ 171 | --category-path gen_data/other_files/sub_categories.txt \ 172 | --num-items-per-category 1 \ 173 | --output-folder save_folder/attributes \ 174 | --multi-qa-prompt gen_data/prompts/2_attributes_wo_answer.txt \ 175 | --temperature 0.7 \ 176 | --re-generate-answer 177 | ``` 178 | ### Script for generating data points with negative paragraphs 179 | Example: 180 | ```shell 181 | python -m gen_data.gen_negatives \ 182 | --input-path multi-hop_data_path \ 183 | --save-path JSON_RESULT_FILE.json \ 184 | --gen_num 1200 185 | ``` 186 | + multi-hop_data_path: the JSON file of the generated data 187 | 188 | ### Script for generating answers to the single questions and final multi-hop questions 189 | This script will fill in ``long_answer`` and ``final_answer`` in the input JSON file 190 | Example: 191 | ```shell 192 | python -m gen_data.gen_answer \ 193 | --input-path input_path \ 194 | --output-path output_path 195 | ``` 196 | + input_path: the JSON file containing data points whose field ``long_answer`` in "sub_questions" is null -------------------------------------------------------------------------------- /gen_data/gen_answer.py: -------------------------------------------------------------------------------- 1 | from gen_data.gen_task import GenAnswer 2 | import os 3 | import typer 4 | 5 | 6 | def main( 7 | input_path: str, 8 | output_path: str, 9 | continue_gen: bool = typer.Option(True, "--no-continue"), 10 | llm: str = typer.Option(default="gpt-3.5-turbo-instruct"), 11 | prompt_type: str = typer.Option(default="openai"), 12 | ): 13 | print("Start to re-generate answers for single questions and final multi-hop questions") 14 | if 
os.path.exists(output_path) and not continue_gen: 15 | os.remove(output_path) 16 | 17 | kwargs = { 18 | "input_path": input_path, 19 | "subquestion_prompt": "gen_data/prompts/answer_gen.txt", 20 | "final_prompt": "gen_data/prompts/final_answer_gen.txt", 21 | "llm": llm, 22 | "prompt_type": prompt_type, 23 | "temperature": 0.0001, 24 | } 25 | answer_task = GenAnswer(output_path, **kwargs) 26 | answer_task.run() 27 | 28 | 29 | if __name__ == "__main__": 30 | typer.run(main) 31 | -------------------------------------------------------------------------------- /gen_data/gen_multi_hop_attributes.py: -------------------------------------------------------------------------------- 1 | from gen_data.gen_task import GenAttributeMerge, GenAnswer 2 | from gen_data.gen_task import utility 3 | import typer 4 | import os 5 | 6 | 7 | def main( 8 | num_items_per_category: int = typer.Option(default=1), 9 | output_folder: str = typer.Option(default="gen_qa"), 10 | re_generate_answer: bool = typer.Option(False, "--re-generate-answer"), 11 | category_path: str = typer.Option(default="gen_data/other_files/sub_categories.txt"), 12 | continue_gen: bool = typer.Option(True, "--no-continue"), 13 | multi_qa_prompt: str = typer.Option(default="gen_data/prompts/2_attributes_wo_answer.txt"), 14 | gen_paragraph_prompt: str = typer.Option(default="gen_data/prompts/gen_long_paragraphs.txt"), 15 | temperature: float = typer.Option(default=0), 16 | llm: str = typer.Option(default="gpt-3.5-turbo-instruct"), 17 | prompt_type: str = typer.Option(default="openai"), 18 | ): 19 | """this function is used to generate multi-hop Q&A 20 | 21 | Args: 22 | num_items_per_category (int, optional): number of generated items for each category. Defaults to typer.Option(default=100). 23 | output_folder (str, optional): where to save the result. Defaults to typer.Option(default="gen_qa"). 24 | re_generate_answer (bool, optional): If we re-generate the answers to single questions and final answer to the multi-hop question or not. 25 | if re-generate, we will use the prompt template for generating the answer + temperature=0 26 | category_path (str, optional): The path to list of categories. Defaults to typer.Option(default="extra_files/categories.txt"). 27 | continue_gen (bool, optional): if we continue to generate from current result or not. Defaults to typer.Option(True, "--no-continue"). 
28 | """ 29 | if not os.path.exists(output_folder): 30 | utility.create_folder(output_folder) 31 | kwargs = { 32 | "category_path": category_path, 33 | "prompt": multi_qa_prompt, 34 | "num_items_per_category": num_items_per_category, 35 | "temperature": temperature, 36 | "llm": llm, 37 | "prompt_type": prompt_type, 38 | "paragraph_prompt": gen_paragraph_prompt, 39 | } 40 | print("kwargs: ", kwargs) 41 | multi_hop_qa_path = os.path.join(output_folder, "raw_multi_hop_qa.json") 42 | if os.path.exists(multi_hop_qa_path) and not continue_gen: 43 | os.remove(multi_hop_qa_path) 44 | print("Start to generate multi-hop QA now") 45 | task = GenAttributeMerge(multi_hop_qa_path, **kwargs) 46 | task.run() 47 | 48 | final_path = os.path.join(output_folder, "final.json") 49 | if re_generate_answer: 50 | print("Start to re-generate answers for single questions and final multi-hop questions") 51 | if os.path.exists(final_path) and not continue_gen: 52 | os.remove(final_path) 53 | 54 | kwargs = { 55 | "input_path": multi_hop_qa_path, 56 | "subquestion_prompt": "gen_data/prompts/answer_gen.txt", 57 | "final_prompt": "gen_data/prompts/final_answer_gen.txt", 58 | "llm": llm, 59 | "prompt_type": prompt_type, 60 | "temperature": 0.0001, 61 | } 62 | answer_task = GenAnswer(final_path, **kwargs) 63 | answer_task.run() 64 | 65 | 66 | if __name__ == "__main__": 67 | typer.run(main) 68 | -------------------------------------------------------------------------------- /gen_data/gen_multi_hop_entities.py: -------------------------------------------------------------------------------- 1 | from gen_data.gen_task import GenEntityComparison, GenAnswer 2 | from gen_data.gen_task import utility 3 | import typer 4 | import os 5 | 6 | 7 | def post_process_entry_text(entry_text: str) -> str: 8 | prefixs = ["generate 2 random entries of Event. The format:", "Generate 2 random entries of Event. The format:"] 9 | for prefix in prefixs: 10 | if entry_text.startswith(prefix): 11 | entry_text = entry_text[len(prefix) :].strip() 12 | return entry_text 13 | 14 | 15 | def main( 16 | num_items_per_category: int = typer.Option(default=1), 17 | output_folder: str = typer.Option(default="gen_qa"), 18 | re_generate_answer: bool = typer.Option(False, "--re-generate-answer"), 19 | category_path: str = typer.Option(default="gen_data/other_files/sub_categories.txt"), 20 | continue_gen: bool = typer.Option(True, "--no-continue"), 21 | multi_qa_prompt: str = typer.Option(default="gen_data/prompts/2_entities_wo_answer.txt"), 22 | gen_paragraph_prompt: str = typer.Option(default="gen_data/prompts/gen_long_paragraphs.txt"), 23 | temperature: float = typer.Option(default=0), 24 | llm: str = typer.Option(default="gpt-3.5-turbo-instruct"), 25 | prompt_type: str = typer.Option(default="openai"), 26 | ): 27 | """this function is used to generate multi-hop Q&A 28 | 29 | Args: 30 | num_items_per_category (int, optional): number of generated items for each category. Defaults to typer.Option(default=100). 31 | output_folder (str, optional): where to save the result. Defaults to typer.Option(default="gen_qa"). 32 | re_generate_answer (bool, optional): If we re-generate the answers to single questions and final answer to the multi-hop question or not. 33 | if re-generate, we will use the prompt template for generating the answer + temperature=0 34 | category_path (str, optional): The path to list of categories. Defaults to typer.Option(default="extra_files/categories.txt"). 
35 | continue_gen (bool, optional): if we continue to generate from current result or not. Defaults to typer.Option(True, "--no-continue"). 36 | temperature: if you have a small number of categories or if you really prefer the diversity set temperature=1 37 | if you are more concerned about quality and the number of categories is big, set temperature=0 38 | """ 39 | if not os.path.exists(output_folder): 40 | utility.create_folder(output_folder) 41 | kwargs = { 42 | "category_path": category_path, 43 | "prompt": multi_qa_prompt, 44 | "num_items_per_category": num_items_per_category, 45 | "temperature": temperature, 46 | "llm": llm, 47 | "prompt_type": prompt_type, 48 | "paragraph_prompt": gen_paragraph_prompt, 49 | } 50 | print("kwargs: ", kwargs) 51 | multi_hop_qa_path = os.path.join(output_folder, "raw_multi_hop_qa.json") 52 | if os.path.exists(multi_hop_qa_path) and not continue_gen: 53 | os.remove(multi_hop_qa_path) 54 | print("Start to generate multi-hop QA now") 55 | task = GenEntityComparison(multi_hop_qa_path, **kwargs) 56 | task.run() 57 | final_path = os.path.join(output_folder, "final.json") 58 | if re_generate_answer: 59 | print("Start to re-generate answers for single questions and final multi-hop questions") 60 | if os.path.exists(final_path) and not continue_gen: 61 | os.remove(final_path) 62 | kwargs = { 63 | "input_path": multi_hop_qa_path, 64 | "subquestion_prompt": "gen_data/prompts/answer_gen.txt", 65 | "final_prompt": "gen_data/prompts/final_answer_gen.txt", 66 | "llm": llm, 67 | "prompt_type": prompt_type, 68 | "temperature": 0.0001, 69 | } 70 | answer_task = GenAnswer(final_path, **kwargs) 71 | answer_task.run() 72 | 73 | 74 | if __name__ == "__main__": 75 | typer.run(main) 76 | -------------------------------------------------------------------------------- /gen_data/gen_negatives.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from gen_data.gen_task import GenNegativeParagraph 3 | 4 | 5 | def main( 6 | input_path: str = typer.Option(""), 7 | save_path: str = typer.Option(""), 8 | gen_num: int = typer.Option(1000), 9 | paragraph_prompt: str = typer.Option("gen_data/prompts/gen_paragraph.txt"), 10 | new_entity_prompt: str = typer.Option("gen_data/prompts/gen_new_entity.txt"), 11 | ): 12 | assert len(input_path) > 0 13 | assert len(save_path) > 0 14 | kwargs = { 15 | "input_path": input_path, 16 | "gen_num": gen_num, 17 | "paragraph_prompt": paragraph_prompt, 18 | "new_entity_prompt": new_entity_prompt, 19 | } 20 | task = GenNegativeParagraph(save_path, **kwargs) 21 | task.run() 22 | 23 | 24 | if __name__ == "__main__": 25 | typer.run(main) 26 | -------------------------------------------------------------------------------- /gen_data/gen_sub_categories.py: -------------------------------------------------------------------------------- 1 | from gen_data import gen_task, utility 2 | import typer 3 | 4 | 5 | def generate_sub_category( 6 | category_path: str, save_path: str, prompt_path: str = typer.Option("gen_data/prompts/sub_category_gen.txt") 7 | ): 8 | task = gen_task.GenSubCategory(save_path, **{"prompt": prompt_path, "category_path": category_path}) 9 | task.run() 10 | 11 | 12 | if __name__ == "__main__": 13 | typer.run(generate_sub_category) 14 | -------------------------------------------------------------------------------- /gen_data/other_files/category_list.txt: -------------------------------------------------------------------------------- 1 | Asisan Food 2 | Newspaper 3 | Bay 4 | Research Institute 
5 | Flower 6 | Aircraft 7 | Political Party 8 | Athlete 9 | Process 10 | Material 11 | Sea 12 | School 13 | President 14 | Sport 15 | Award 16 | Singer 17 | Broadcast Program 18 | Sports Facility 19 | Mountain 20 | Weapon 21 | Conference 22 | Road 23 | Resort 24 | Sport Team 25 | Bird 26 | Ship 27 | Philosopher 28 | Laptop 29 | Port 30 | Restaurant 31 | Comedian 32 | Lake 33 | Zoo 34 | Method 35 | Province 36 | Market 37 | Hospital 38 | Company Group 39 | Canal 40 | Writer 41 | City 42 | Theory 43 | European Food 44 | Inventor 45 | Theater 46 | County 47 | Game Show 48 | Planet 49 | Region 50 | River 51 | Music 52 | Scientist 53 | Constellation 54 | War 55 | Spaceship 56 | Electric Car 57 | Company 58 | Railroad 59 | Public Library 60 | American Food 61 | Flora 62 | Hotel 63 | Law 64 | Smart Phone 65 | Businessman 66 | Drug 67 | Airport 68 | Actress 69 | Concept 70 | Musician 71 | Phenomenon 72 | Amusement Park 73 | Shopping 74 | Military 75 | Physician 76 | Earthquake 77 | Politician 78 | Public Institution 79 | Magazine 80 | Culture 81 | Fungus 82 | Country 83 | Decoration 84 | Sports League 85 | Mathematician 86 | Language 87 | Museum 88 | Policy 89 | Explorer 90 | Natural Phenomenon 91 | Bridge 92 | Dancer 93 | Disease 94 | Mammal 95 | Worship Place 96 | Reptile 97 | Station 98 | Car 99 | Beach 100 | Fish 101 | Historical Figure 102 | Artist 103 | Fashion 104 | Job 105 | Service 106 | Natural Disaster 107 | Book 108 | Insect 109 | Star 110 | Convention 111 | Movie 112 | Park 113 | Game 114 | High School 115 | Show 116 | Train 117 | Composer 118 | Movement 119 | University 120 | Tunnel 121 | Treaty 122 | Religious Festival 123 | Celebrity 124 | Sport Star 125 | Island -------------------------------------------------------------------------------- /gen_data/prompts/2_attributes.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can follow the instruction step-by-step. Please generate the output of the following format: 2 | + Category: Please randomly select one of: Person, Region, Country, Mountain, River, Organization, Event, Process, Method, League, Product, Facility Vehicle, Music, Food, Art, Book, Sport, Plant, Animal, ... 
3 | + Entity: Please generate an entity of this Category whose level of popularity on internet is {popularity} out of 5 4 | + List of Attributes: Please generate a list of attributes of this entity sorted by the level of popularity on internet, no more than 10, separated by comma(",") 5 | + Attribute 1: Select a random attribute from this list 6 | + Attribute 2: Select another random attribute from this list 7 | + Question 1: Generate a {question_type} asking for Attribute 1 only 8 | + Question 2: Generate a {question_type} asking for Attribute 2 only 9 | + Merged Question: Generate a {question_type} containing the points of Question 1 and Question 2 10 | ------------ 11 | + Knowledge 1: Generate a medium-sized paragraph of about 8-9 sentences containing information of the attribute for entity 1 and also related information 12 | + Thought 1: First, extract the relevant information from Knowledge 1 and then generate the reasoning to answer the Question 1 13 | + Answer 1: based on Thought 1, provide the final answer to the Question 1;if the Knowledge 1 doesn't contain the answer or cannot reason to get the answer, please say that you cannot answer this based on your knowledge 14 | ------------- 15 | + Knowledge 2: Generate a medium-sized paragraph of about 8-9 sentences containing information of the attribute for entity 2 and also related information 16 | + Thought 2: First, extract the relevant information from the Knowledge 2 and then generate the reasoning to answer the Question 2 17 | + Answer 2: based on Thought 2, provide the final answer to the Question 2;if the Knowledge 2 doesn't contain the answer or cannot reason to get the answer, please say that you cannot answer this based on your knowledge 18 | ------------ 19 | + Summary: First summarize the points from Answer 1 and Answer 2 20 | + Final Thought: based on the Summary, generate the reasoning to answer the Merged Question 21 | + Final answer: The complete answer to the Merged Question based on Summary and Final Thought 22 | ------------ 23 | Please generate now: 24 | + Category: {category} -------------------------------------------------------------------------------- /gen_data/prompts/2_attributes_wo_answer.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can follow the instruction step-by-step. Please generate the output of the following format: 2 | + Category: Please randomly select one of: Person, Region, Country, Mountain, River, Organization, Event, Process, Method, League, Product, Facility Vehicle, Music, Food, Art, Book, Sport, Plant, Animal, ... 
3 | + Entity: Please generate an entity of this Category whose level of popularity on internet is {popularity} out of 5 4 | + List of Attributes: Please generate a list of attributes of this entity sorted by the level of popularity on internet, no more than 10 5 | + Attribute 1: Select a random attribute from this list 6 | + Attribute 2: Select another random attribute from this list 7 | + Question 1: Generate a {question_type} asking for Attribute 1 only 8 | + Question 2: Generate a {question_type} asking for Attribute 2 only 9 | + Merged Question: Generate a {question_type} containing the points of Question 1 and Question 2 10 | + Knowledge 1: Generate a medium-sized paragraph of about 8-9 sentences containing information of the attribute for entity 1 and also related information 11 | + Knowledge 2: Generate a medium-sized paragraph of about 8-9 sentences containing information of the attribute for entity 2 and also related information 12 | Please generate now: 13 | + Category: {category} -------------------------------------------------------------------------------- /gen_data/prompts/2_entities.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can follow the instruction step-by-step. Please generate the output of the following format: 2 | + Category: Please randomly select one of: Person, Region, Country, Mountain, River, Organization, Event, Process, Method, League, Product, Facility Vehicle, Music, Food, Art, Book, Sport, Plant, Animal, ... 3 | + Entity 1: generate a random entity of this category such that level of popularity on internet is {popularity_1} out of 5 4 | + Entity 2: generate another random entity of this category such that level of popularity on internet is {popularity_2} out of 5 5 | + List of Attributes: generate a list of common attributes of entity 1 and entity 2 to compare, separated by a comma. The attribute can be a number, a type, a method, a description, a definition, a procedure, a child, ... 6 | + Selected Attribute: select a random attribute in the List of Attributes generated above 7 | + Question 1: Question asking for Selected Attribute of entity 1 only, the question must mention entity 1 8 | + Question 2: Question asking for Selected Attribute of entity 2 only, the question must mention entity 2; and the question type should be the same as Question 1 9 | + Question for 2 entities: Generate a {question_type} to ask the Selected Attribute of these 2 entities, the question might be: comparing numbers (larger, equal, the same, smaller, ...), sum/difference of 2 numbers, comparing logic, finding the difference, summarizing the Selected Attribute of 2 entities, finding out how they are related, asking the Selected Attribute of 2 entities in the same question ... The question must mention 2 entities, do not use reference or coreference such as: these, those, this, that,... 
10 | ---------------- 11 | + Knowledge 1: Generate a medium-sized paragraph of about 8-9 sentences containing information of the Selected Attribute for entity 1 and also related information 12 | + Thought 1: extract the relevant information from Knowledge 1 and then generate the reasoning to answer the Question 1 13 | + Answer 1: provide the final answer to the Question 1 based on Thought 1; if the Knowledge 1 doesn't contain the answer or cannot reason to get the answer, please say that you cannot answer this based on your knowledge 14 | ---------------- 15 | + Knowledge 2: Generate a medium-sized paragraph of about 8-9 sentences containing information of the Selected Attribute for entity 2 and also related information 16 | + Thought 2: extract the relevant information from the Knowledge 2 and then generate the reasoning to answer the Question 2 17 | + Answer 2: provide the final answer to the Question 2 based on Thought 2; if the Knowledge 2 doesn't contain the answer or cannot reason to get the answer, please say that you cannot answer this based on your knowledge 18 | ---------------- 19 | + Summary: summarize the points from Answer 1 and Answer 2 20 | + Final Thought: generate the reasoning to answer the Question for 2 entities based on the Summary 21 | + Final answer: The complete answer to Question for 2 entities based on Summary and Thought 22 | ------------ 23 | Please generate now: 24 | + Category: {category} -------------------------------------------------------------------------------- /gen_data/prompts/2_entities_wo_answer.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can follow the instruction step-by-step. Please generate the output of the following format: 2 | + Category: Please randomly select one of: Person, Region, Country, Mountain, River, Organization, Event, Process, Method, League, Product, Facility Vehicle, Music, Food, Art, Book, Sport, Plant, Animal, ... 3 | + Entity 1: generate a random entity of this category such that level of popularity on internet is {popularity_1} out of 5 4 | + Entity 2: generate another random entity of this category such that level of popularity on internet is {popularity_2} out of 5 5 | + List of Attributes: generate a list of common attributes of entity 1 and entity 2 to compare, separated by a comma. The attribute can be a number, a type, a method, a description, a definition, a procedure, a child, ... 
6 | + Selected Attribute: select a random attribute in the List of Attributes generated above 7 | + Question 1: Question asking for Selected Attribute of entity 1 only, the question must mention entity 1 8 | + Question 2: Question asking for Selected Attribute of entity 2 only, the question must mention entity 2; and the question type should be the same as Question 1 9 | + Knowledge 1: Generate a medium-sized paragraph of about 8-9 sentences containing information of the Selected Attribute for entity 1 and also related information 10 | + Knowledge 2: Generate a medium-sized paragraph of about 8-9 sentences containing information of the Selected Attribute for entity 2 and also related information 11 | + Question for 2 entities: Generate a {question_type} to ask the Selected Attribute of these 2 entities, the question might be: comparing numbers (larger, equal, the same, smaller, ...), sum/difference of 2 numbers, comparing logic, finding the difference, summarizing the Selected Attribute of 2 entities, finding out how they are related, asking the Selected Attribute of 2 entities in the same question ... The question must mention 2 entities, do not use reference or coreference such as: these, those, this, that,... 12 | Please generate now: 13 | + Category: {category} -------------------------------------------------------------------------------- /gen_data/prompts/answer_gen.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can generate the answer to a question based only on the provided knowledge. If you don't know the answer or cannot extract the answer from the provided knowledge, just say that you don't know, don't try to make up an answer. Note that the answer must be based on the provided knowledge. 2 | ======= 3 | Here is the knowledge: 4 | {context} 5 | ======= 6 | Here is the question: 7 | {question} 8 | ======= 9 | Now please generate the answer to this question as an assistant following this format: 10 | + Thought: extract the relevant information, detail to the question from the provided knowledge and then generate the reasoning to answer the question 11 | + Answer: based on Thought, provide the complete answer to the question; if the provided knowledge doesn't contain the answer or cannot reason to get the answer, please say that you cannot answer this based on your knowledge 12 | -------------- 13 | Please generate now: 14 | + Thought: -------------------------------------------------------------------------------- /gen_data/prompts/final_answer_gen.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can generate a final answer to a question based on the provided knowledge. The final answer must first connect the facts from knowledge and then infer the conclusion for the answer to the question. 
Note that you only generate information based on the knowledge. 2 | ==== 3 | Here is the provided knowledge: 4 | {facts} 5 | ==== 6 | Here is the question: 7 | {question} 8 | ==== 9 | Now please generate the final answer for the question by the following format: 10 | + Summary: summarize the points from knowledge 11 | + Thought: generate the reasoning to answer the question based on the Summary 12 | + Answer: The complete answer to the question based on Summary and Thought 13 | -------------- 14 | Please generate now: 15 | + Summary: -------------------------------------------------------------------------------- /gen_data/prompts/gen_new_entity.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can generate a new entity that is different from the list of available entities in a category. 2 | Category: {category} 3 | List of available entities: {entities} 4 | Please generate now: 5 | A new entity: -------------------------------------------------------------------------------- /gen_data/prompts/gen_paragraph.txt: -------------------------------------------------------------------------------- 1 | You are an intelligent assistant that can generate a paragraph of 4-6 sentences containing the information of a given attribute of an entity. 2 | Now generate a paragraph about the attribute: "{attribute}" of entity: "{entity}" 3 | Please generate now: 4 | Paragraph: -------------------------------------------------------------------------------- /gen_data/prompts/sub_category_gen.txt: -------------------------------------------------------------------------------- 1 | Given a category of entities, please generate a list of sub-categories of this category. The length of the list is dependent on the scope of the category. For example, for a small category you can generate up to 10, but for a big category you can generate up to 30 sub-categories. Note that sub-categories must be distinct. 2 | The given category is: {category} 3 | Now please generate a list of sub-categories for this category with the following format: 4 | Sub-category 1: {category} in ... 5 | Sub-category 2: {category} in ... 6 | ... 7 | Sub-category n: {category} in ... 8 | Please generate now: -------------------------------------------------------------------------------- /gen_data/template_gen.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any, Optional, Callable, Tuple 2 | import re 3 | import openai 4 | import time 5 | from abc import ABC, abstractmethod 6 | import json 7 | import sys 8 | import requests 9 | import datetime 10 | from openai import OpenAI 11 | import traceback 12 | 13 | 14 | openai_client = OpenAI() 15 | 16 | 17 | class CustomizedPromptGen(ABC): 18 | @abstractmethod 19 | def get_prompt(self, input_dic: Dict) -> str: 20 | raise NotImplementedError 21 | 22 | 23 | class WizardLMPromptGen(CustomizedPromptGen): 24 | def get_prompt(self, input_dic: Dict) -> str: 25 | prompt = input_dic["prompt_input"] 26 | output_prefix = input_dic["output_prefix"] 27 | new_prompt = ( 28 | f"A chat between a curious user and an artificial intelligence assistant. " 29 | f"The assistant gives helpful, detailed, and polite answers to the user's questions. 
" 30 | f"USER: {prompt} ASSISTANT:{output_prefix}" 31 | ) 32 | return new_prompt 33 | 34 | 35 | class OpenAIPromptGen(CustomizedPromptGen): 36 | def get_prompt(self, input_dic: Dict) -> str: 37 | input_prompt = input_dic["prompt_input"] 38 | output_prefix = input_dic["output_prefix"] 39 | result = input_prompt + "\nPlease generate now:" 40 | if len(output_prefix) > 0: 41 | result += f"\n{output_prefix}" 42 | return result 43 | 44 | 45 | class OpenOrcaPrompGen(CustomizedPromptGen): 46 | def get_prompt(self, input_dic: Dict) -> str: 47 | input_prompt = input_dic["prompt_input"] 48 | output_prefix = input_dic["output_prefix"] 49 | 50 | index = input_prompt.find("\n") 51 | sys_prompt = input_prompt[:index].strip() 52 | user_query = input_prompt[index + 1 :].strip() 53 | prefix = "<|im_start|>" 54 | suffix = "<|im_end|>\n" 55 | 56 | def get_text_role(role, content, add_suffix=True): 57 | result = prefix + f"{role}\n" + content 58 | if add_suffix: 59 | result += suffix 60 | return result 61 | 62 | # sys_format = prefix + "system\n" + sys_prompt + suffix 63 | # user_format = prefix + "user\n" + user_query + suffix 64 | # assistant_format = prefix + "assistant\n" 65 | sys_format = get_text_role("system", sys_prompt) 66 | user_format = get_text_role("user", user_query) 67 | assistant_format = get_text_role("assistant", output_prefix, False) 68 | return sys_format + user_format + assistant_format 69 | 70 | 71 | PROMPT_GEN_DIC = {"wizardlm": WizardLMPromptGen(), "openai": OpenAIPromptGen(), "open_orca": OpenOrcaPrompGen()} 72 | 73 | 74 | def get_final_prompt(input_dic: Dict) -> str: 75 | prompt_type = input_dic["prompt_type"] 76 | prompt_gen = PROMPT_GEN_DIC[prompt_type] 77 | return prompt_gen.get_prompt(input_dic) 78 | 79 | 80 | def parse_fields_from_template(prompt_template: str) -> List[str]: 81 | fields = [] 82 | for match in re.finditer(r"(\n|^)\+(?P
26 |
27 |
28 |
30 | Examples of packing 2 input sequences: "good morning my name is John" and "This is a dog". The left is the attention matrix of packing with cross-contamination; the right is the correct attention matrix of packing.
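To make the idea concrete, here is a minimal standalone sketch (an illustration under simplified padding handling, not the repo's exact implementation; see `train/custom_datasets.py` for that) of a correctly packed attention mask: an additive mask that is block-diagonal, with one causal (lower-triangular) block per packed sequence, so tokens of one sequence never attend to tokens of another.

```python
import torch

def packed_causal_mask(lengths, max_len):
    """Additive attention mask for packed sequences: 0 where attention is allowed,
    -inf where it is blocked. Each sequence gets its own causal block, so there is
    no cross-contamination between packed sequences."""
    mask = torch.full((max_len, max_len), float("-inf"))
    offset = 0
    for length in lengths:
        # causal block: position i of this sequence attends only to positions <= i of the same sequence
        block = torch.triu(torch.full((length, length), float("-inf")), diagonal=1)
        mask[offset:offset + length, offset:offset + length] = block
        offset += length
    return mask

# "good morning my name is John" (6 tokens) packed with "This is a dog" (4 tokens)
print(packed_causal_mask([6, 4], max_len=10))
```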
31 | 32 | ## Training script 33 | First, you need to download the training data from: [khaimaitien/qa-expert-multi-hop-qa-V1.0](https://huggingface.co/datasets/khaimaitien/qa-expert-multi-hop-qa-V1.0) and save it to a folder to pass in the arguments ``train_path`` and ``validation_path`` 34 | 35 | In this repo, we support training the **whole model**, **lora** or **qlora**, and currently only 2 types of model: **Mistral or Llama2**. Based on my experience and reports from many people, Mistral-7B can outperform Llama2-13b, so we decided to train on Mistral only. 36 | 37 | Some notable training arguments you should pay attention to: 38 | + **--model_name_or_path**: pretrained model 39 | + **--train_path**: training file, you can download from: [khaimaitien/qa-expert-multi-hop-qa-V1.0](https://huggingface.co/datasets/khaimaitien/qa-expert-multi-hop-qa-V1.0) 40 | + **--validation_path**: validation file, you can download from: [khaimaitien/qa-expert-multi-hop-qa-V1.0](https://huggingface.co/datasets/khaimaitien/qa-expert-multi-hop-qa-V1.0) 41 | + **--model_type**: "mistral" or "llama" 42 | + **--use_lora**: True if using lora, False if training the whole model 43 | + **--qlora**: if use_lora=True, we can choose whether to use qlora; True if using qlora 44 | + **--model_max_length**: the maximum sequence length, default=4096 45 | + **--packing**: True if packing short inputs, False if not. We recommend using packing 46 | 47 | ### Single GPU 48 | Example for training on 1 GPU, not using lora: 49 | 50 | ``` 51 | python -m train.train_model \ 52 | --model_name_or_path Mistral-7B-v0.1 \ 53 | --train_path train_data/train.json \ 54 | --validation_path train_data/validation.json \ 55 | --model_type mistral \ 56 | --use_lora False \ 57 | --qlora False \ 58 | --bf16 True \ 59 | --output_dir models/mistral-qa_full \ 60 | --num_train_epochs 2 \ 61 | --per_device_train_batch_size 3 \ 62 | --per_device_eval_batch_size 4 \ 63 | --gradient_accumulation_steps 10 \ 64 | --eval_accumulation_steps 1 \ 65 | --evaluation_strategy "steps" \ 66 | --eval_steps 40 \ 67 | --save_strategy "epoch" \ 68 | --save_steps 80 \ 69 | --save_total_limit 3 \ 70 | --learning_rate 1.2e-5 \ 71 | --lr_scheduler_type "cosine" \ 72 | --logging_steps 1 \ 73 | --tf32 True \ 74 | --model_max_length 4096 \ 75 | --gradient_checkpointing True \ 76 | --packing True 77 | ``` 78 | 79 | 80 | ### Multiple GPU 81 | To train on multiple GPUs, I suggest using **deepspeed**, as many people reported that training the Mistral model encounters loss instability using FSDP: https://github.com/huggingface/transformers/issues/26498 82 | 83 | Another note: **FSDP doesn't work** for Lora because it requires all the parameters to be uniformly trainable or frozen. 
84 | 85 | Here is an example: 86 | ``` 87 | deepspeed train/train_model.py \ 88 | --model_name_or_path Mistral-7B-v0.1 \ 89 | --train_path train_data/train.json \ 90 | --validation_path train_data/validation.json \ 91 | --model_type mistral \ 92 | --use_lora False \ 93 | --qlora False \ 94 | --bf16 True \ 95 | --output_dir models/mistral-qa_full \ 96 | --num_train_epochs 2 \ 97 | --per_device_train_batch_size 3 \ 98 | --per_device_eval_batch_size 4 \ 99 | --gradient_accumulation_steps 10 \ 100 | --eval_accumulation_steps 1 \ 101 | --evaluation_strategy "steps" \ 102 | --eval_steps 40 \ 103 | --save_strategy "epoch" \ 104 | --save_steps 80 \ 105 | --save_total_limit 3 \ 106 | --learning_rate 1.2e-5 \ 107 | --lr_scheduler_type "cosine" \ 108 | --logging_steps 1 \ 109 | --tf32 True \ 110 | --model_max_length 4096 \ 111 | --gradient_checkpointing True \ 112 | --packing True \ 113 | --deepspeed train/ds_config/zero3_wo_offload.json 114 | ``` 115 | ### Merge Adapter weights in Lora 116 | If you train the model using lora or qlora (--use_lora True), you need to merge the adapter weights into the original weights. You can do it by running this script: 117 | ```shell 118 | python -m train.merge_weight save_folder pretrained_path checkpoint model_type 119 | ``` 120 | Where: 121 | + **save_folder**: where to save the merged model (the final model to use) 122 | + **pretrained_path**: path to the pretrained model used in finetuning 123 | + **checkpoint**: checkpoint folder, containing the adapter weights 124 | + **model_type**: mistral or llama 125 | 126 | ### Some notes about training 127 | Here are some notes about training based on my experience: 128 | 129 | + Mistral-7B outperforms Llama-2-13b considerably although it is much smaller in size 130 | + **Using packing saves a lot of training time**. For example, the number of data points in the training dataset is **25547**; with packing, the number of data points is reduced to **5173** --> the training time is almost only **1/5** of that without packing. 131 | + If you are using A100 to train the model, you should use deepspeed zero3 without offloading [ds_config/zero3_wo_offloading.json](ds_config/zero3_wo_offloading.json). If you are using A6000 to train the model, you should use deepspeed zero3 with offloading [ds_config/zero3.json](ds_config/zero3.json). 132 | + FSDP brings about loss instability in training the Mistral model; more information [here](https://github.com/huggingface/transformers/issues/26498) 133 | + [vast.ai](https://vast.ai/) provides a slightly **cheaper price** than [runpod.io](https://www.runpod.io/) 134 | 135 | ## Evaluation 136 | We use [HotpotQA](https://hotpotqa.github.io/) as the evaluation dataset and measure the metrics: 137 | + **Recall**: compute the recall based on individual words; the reason we use recall instead of F1 is that the ground-truth answers are short and usually spans. 
138 | + **Accuracy of containing ground-truth**: If the ground-truth answer (mostly short span) is exactly in the generated answer --> 1 else 0 139 | 140 | Here is the result: 141 | 142 | |Model|Recall|Accuracy of containing ground-truth| 143 | |---|---|---| 144 | |[khaimaitien/qa-expert-7B-V1.0](https://huggingface.co/khaimaitien/qa-expert-7B-V1.0)|0.73215|0.664| 145 | 146 | 147 | You can run the evaluation script at the root directory of this repo: 148 | 149 | ```shell 150 | python eval_hotpot_qa.py --model-path khaimaitien/qa-expert-7B-V1.0 --inference-type vllm 151 | ``` 152 | Note that using ``vllm`` would be much faster than ``hf`` -------------------------------------------------------------------------------- /train/assert_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaTokenizer 2 | from train.monkey_patched_mistral_packed_attention_mask import MistralForCausalLM 3 | from train import custom_datasets 4 | import torch 5 | import copy 6 | import typer 7 | import os 8 | import math 9 | from qa_expert import prompt_utils 10 | import json 11 | 12 | 13 | def read_raw_data(): 14 | cur_folder = os.path.dirname(os.path.abspath(__file__)) 15 | with open(os.path.join(cur_folder, "test.json"), "r") as f: 16 | return json.loads(f.read()) 17 | 18 | 19 | def prepare_input_dic(input_dic, device): 20 | result = copy.deepcopy(input_dic) 21 | for key in result: 22 | result[key] = torch.unsqueeze(input_dic[key], 0) 23 | result[key] = result[key].to(device) 24 | result["return_dict"] = True 25 | result["loss_reduction"] = "sum" 26 | return result 27 | 28 | 29 | def compute_loss_from_ds(ds, model, device): 30 | total_loss = 0 31 | for i in range(len(ds)): 32 | input_dic = ds[i] 33 | input_dic = prepare_input_dic(input_dic, device) 34 | with torch.no_grad(): 35 | loss = model.forward(**input_dic).loss.item() 36 | total_loss += loss 37 | return total_loss 38 | 39 | 40 | def main(pretrained_path: str, device: str = typer.Option("cuda:0")): 41 | tokenizer = LlamaTokenizer.from_pretrained(pretrained_path, legacy=True, model_max_length=4096) 42 | tokenizer.pad_token = tokenizer.unk_token 43 | tokenizer.add_special_tokens({"additional_special_tokens": prompt_utils.get_additional_tokens()}) 44 | 45 | model = MistralForCausalLM.from_pretrained( 46 | pretrained_path, torch_dtype=torch.bfloat16, device_map=device, use_flash_attention_2=False 47 | ) 48 | model.resize_token_embeddings(len(tokenizer)) 49 | 50 | model.eval() 51 | 52 | raw_data = read_raw_data() 53 | for padding_side in ["left", "right"]: 54 | print("test padding_side: ", padding_side) 55 | tokenizer.padding_side = padding_side 56 | normal_ds = custom_datasets.CustomDataset(raw_data, tokenizer) 57 | packed_ds = custom_datasets.PackedDataset(raw_data, tokenizer) 58 | print("number of data points from normal ds: ", len(normal_ds)) 59 | print("number of data points from packed ds: ", len(packed_ds)) 60 | normal_loss = compute_loss_from_ds(normal_ds, model, device) 61 | mk_loss = compute_loss_from_ds(packed_ds, model, device) 62 | diff = math.fabs(normal_loss - mk_loss) 63 | diff_percent = diff * 100 / max(normal_loss, mk_loss) 64 | print(f"normal_loss: {normal_loss}, mk_loss={mk_loss}, diff_percent={diff_percent}%") 65 | 66 | 67 | if __name__ == "__main__": 68 | typer.run(main) 69 | -------------------------------------------------------------------------------- /train/custom_datasets.py: -------------------------------------------------------------------------------- 1 | 
# This script is created based on: https://github.com/MeetKai/functionary/blob/main/functionary/train/custom_datasets.py 2 | import datetime 3 | import json 4 | import os 5 | import pickle 6 | from abc import ABC, abstractmethod 7 | from typing import Any, Dict, List, Optional, Tuple, Union 8 | 9 | import torch 10 | import transformers 11 | from torch.utils.data import Dataset 12 | from qa_expert import prompt_utils 13 | from gen_data import utility 14 | 15 | 16 | def map_raw_data_to_input_dic(raw_data: List[Dict], tokenizer: Any, padding: str, batch_size: int = 5000) -> List[Dict]: 17 | invalid_count = 0 18 | data_size = len(raw_data) 19 | data_points = [] 20 | t1 = datetime.datetime.now() 21 | for start, end in utility.get_batch_indices(data_size, batch_size): 22 | batch_messages = [prompt_utils.convert_multi_qa_format_to_messages(item) for item in raw_data[start:end]] 23 | batch_result = prompt_utils.preprare_training_inputs_batch(batch_messages, tokenizer, padding) 24 | assert len(batch_result) == len(raw_data[start:end]) 25 | for item in batch_result: 26 | if is_valid_labels(item["labels"]): 27 | data_points.append(item) 28 | else: 29 | print("invalid: ") 30 | invalid_count += 1 31 | t2 = datetime.datetime.now() 32 | avg_time = (t2 - t1).total_seconds() / len(data_points) 33 | remaining_time = avg_time * (data_size - len(data_points)) 34 | print( 35 | f"{len(data_points)}/{data_size}, avg_time per 1000 data points: {avg_time * 1000}, remaining time: {remaining_time}" 36 | ) 37 | if invalid_count > 0: 38 | print(f"*****WARNING: invalid data points: {invalid_count} because of labels=-100 all the time") 39 | assert len(data_points) == data_size - invalid_count 40 | return data_points 41 | 42 | 43 | def merge_data_points_by_length(lengths: List[int], max_length: int) -> List[List[int]]: 44 | """given lengths of data points, we merge them into groups such that the sum of lengths 45 | in each group is less than max_length. This is known as: https://en.wikipedia.org/wiki/Bin_packing_problem 46 | Here is the greedy algorithm 47 | Args: 48 | lengths (List[int]): _description_ 49 | max_length (int): _description_ 50 | 51 | Returns: 52 | _type_: groups of indices: [[index1, index2, ...], [], ...] 
53 | """ 54 | items = [{"length": length, "index": i} for i, length in enumerate(lengths)] 55 | items = sorted(items, key=lambda x: x["index"]) 56 | merges = [] 57 | current_sum = 0 58 | current_list = [] 59 | for i in range(len(items)): 60 | cur_length = items[i]["length"] 61 | if cur_length + current_sum <= max_length: 62 | current_sum += items[i]["length"] 63 | current_list.append(i) 64 | else: 65 | merges.append(current_list) 66 | current_list = [i] 67 | current_sum = cur_length 68 | if len(current_list) > 0: 69 | merges.append(current_list) 70 | result = [] 71 | for merge in merges: 72 | sub_items = [items[index]["index"] for index in merge] 73 | result.append(sub_items) 74 | return result 75 | 76 | 77 | def get_causal_mask(length: int, sliding_window: Optional[int] = None): 78 | """ 79 | Make causal mask used for sliding window attention 80 | """ 81 | tensor = torch.full( 82 | (length, length), 83 | fill_value=1, 84 | ) 85 | mask = torch.tril(tensor, diagonal=0) 86 | # make the mask banded to account for sliding window 87 | if sliding_window is not None: 88 | mask = torch.triu(mask, diagonal=-sliding_window) 89 | mask = torch.log(mask) 90 | return mask 91 | 92 | 93 | def create_mask_padding_right( 94 | lengths: List[int], model_max_length: int, sliding_window: Optional[int] = None 95 | ) -> torch.tensor: 96 | """create attention_mask: N x N where masked value = m_value 97 | Args: 98 | lengths (List[int]): length of data points 99 | tokenizer (Any): _description_ 100 | m_value (float): _description_ 101 | 102 | Returns: 103 | torch.tensor: _description_ 104 | """ 105 | result = torch.full((model_max_length, model_max_length), float("-inf")) 106 | acc_leng = 0 107 | for length in lengths: 108 | # mask for a data point with length 109 | x = get_causal_mask(length, sliding_window) 110 | result[acc_leng : acc_leng + length, acc_leng : acc_leng + length] = x 111 | acc_leng += length 112 | pad_length = model_max_length - sum(lengths) 113 | if pad_length > 0: 114 | result[-pad_length:, :] = 0 115 | result[:, -pad_length:] = float("-inf") 116 | return result 117 | 118 | 119 | def create_mask_padding_left( 120 | lengths: List[int], model_max_length: int, sliding_window: Optional[int] = None 121 | ) -> torch.tensor: 122 | result = torch.full((model_max_length, model_max_length), float("-inf")) 123 | pad_length = model_max_length - sum(lengths) 124 | acc_leng = 0 125 | for length in [pad_length] + lengths: 126 | x = get_causal_mask(length, sliding_window) 127 | result[acc_leng : acc_leng + length, acc_leng : acc_leng + length] = x 128 | acc_leng += length 129 | return result 130 | 131 | 132 | def create_mask_from_lengths(lengths: List[int], tokenizer: Any, sliding_window: Optional[int] = None) -> torch.tensor: 133 | if tokenizer.padding_side == "left": 134 | return create_mask_padding_left(lengths, tokenizer.model_max_length, sliding_window) 135 | return create_mask_padding_right(lengths, tokenizer.model_max_length, sliding_window) 136 | 137 | 138 | def merge_data_points(data_points: List[Dict], tokenizer: Any, sliding_window: Any) -> Dict: 139 | input_ids = [] 140 | lengths = [] 141 | label_ids = [] 142 | for item in data_points: 143 | input_ids += item["input_ids"] 144 | # assert item["labels"][0] == -100 # This is to make sure that the first token won't be included in computing loss 145 | labels = list(item["labels"]) 146 | labels[0] = -100 147 | label_ids += labels 148 | lengths.append(len(item["input_ids"])) 149 | attention_mask = create_mask_from_lengths(lengths, tokenizer, sliding_window) 
150 | pad_leng = tokenizer.model_max_length - len(input_ids) # padding to model_max_length 151 | if tokenizer.padding_side == "right": 152 | input_ids = input_ids + [tokenizer.pad_token_id for _ in range(pad_leng)] 153 | label_ids = label_ids + [-100 for _ in range(pad_leng)] 154 | else: 155 | input_ids = [tokenizer.pad_token_id for _ in range(pad_leng)] + input_ids 156 | label_ids = [-100 for _ in range(pad_leng)] + label_ids 157 | assert len(input_ids) == len(label_ids) == attention_mask.size(0) 158 | return { 159 | "input_ids": torch.tensor(input_ids), 160 | "labels": torch.tensor(label_ids), 161 | "attention_mask": torch.unsqueeze(attention_mask, 0), # unsqueeze to 1 x N x N so that batching yields B x 1 x N x N 162 | } 163 | 164 | 165 | def is_valid_labels(labels: Union[List[int], torch.Tensor]) -> bool: 166 | """When truncating to max_length, a data point's labels can end up all -100, which would make the loss NaN; such items must be dropped 167 | Args: 168 | labels (Union[List[int], torch.Tensor]): token-level labels, where -100 marks positions ignored by the loss 169 | 170 | Returns: 171 | bool: True if at least one label is not -100 172 | """ 173 | if type(labels) is list: 174 | non_mask_count = 0 175 | for label in labels: 176 | if label != -100: 177 | non_mask_count += 1 178 | if non_mask_count == 0: 179 | return False 180 | return True 181 | elif isinstance(labels, torch.Tensor): 182 | if sum(labels + 100) == 0: # all labels are -100 183 | return False 184 | return True 185 | return True 186 | 187 | 188 | def remove_invalid_label_items(data_points: List[Dict]) -> List[Dict]: 189 | """Remove data points where labels are all -100 190 | 191 | Args: 192 | data_points (List[Dict]): tokenized data points, each containing a "labels" field 193 | 194 | Returns: 195 | List[Dict]: the data points whose labels contain at least one non -100 value 196 | """ 197 | result = [] 198 | for dp in data_points: 199 | if is_valid_labels(dp["labels"]): 200 | result.append(dp) 201 | return result 202 | 203 | 204 | class CachedDataset(Dataset): 205 | def __init__(self, tokenizer: Any, cached_folder: Optional[str] = None, ignore_cached: bool = False) -> None: 206 | super().__init__() 207 | self.tokenizer = tokenizer 208 | self.data_points: List[Dict] = [] 209 | self.load_from_cache = False 210 | if cached_folder is not None and not ignore_cached: 211 | data_path = self.get_data_point_path(cached_folder) 212 | if os.path.exists(data_path): 213 | print(f"cache found, loading from cache: {cached_folder}") 214 | self.load(cached_folder) 215 | self.load_from_cache = True 216 | 217 | def __len__(self): 218 | return len(self.data_points) 219 | 220 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 221 | return self.data_points[i] 222 | 223 | def create_meta_info(self): 224 | return {"max_length": self.tokenizer.model_max_length, "size": len(self.data_points)} 225 | 226 | def load(self, folder: str): 227 | t1 = datetime.datetime.now() 228 | with open(self.get_data_point_path(folder), "rb") as file: 229 | self.data_points = pickle.load(file) 230 | t2 = datetime.datetime.now() 231 | print("time for loading cached data: ", (t2 - t1).total_seconds()) 232 | 233 | def get_data_point_path(self, folder: str) -> str: 234 | return os.path.join(folder, "data_points.pkl") 235 | 236 | def get_metainfo_path(self, folder: str) -> str: 237 | return os.path.join(folder, "meta_info.json") 238 | 239 | def dump(self, folder: str): 240 | t1 = datetime.datetime.now() 241 | if not os.path.exists(folder): 242 | os.mkdir(folder) 243 | 244 | with open(self.get_data_point_path(folder), "wb") as file: 245 | pickle.dump(self.data_points, file) 246 | 247 | with open(self.get_metainfo_path(folder), "w") as f: 248 | f.write(json.dumps(self.create_meta_info())) 249 | t2
= datetime.datetime.now() 250 | print("time for dumping data: ", (t2 - t1).total_seconds()) 251 | 252 | def stat(self): 253 | print(json.dumps(self.create_meta_info())) 254 | 255 | 256 | class CustomDataset(CachedDataset): 257 | """Dataset for supervised fine-tuning.""" 258 | 259 | def __init__( 260 | self, 261 | raw_data: List[Dict], 262 | tokenizer: transformers.PreTrainedTokenizer, 263 | cached_folder: Optional[str] = None, 264 | ignore_cached: bool = False, 265 | batch_size: int = 5000, 266 | **kwargs, 267 | ): 268 | super().__init__(tokenizer, cached_folder, ignore_cached) 269 | 270 | if not self.load_from_cache: # if not loaded from cached 271 | self.data_points = map_raw_data_to_input_dic( 272 | raw_data, tokenizer, padding="max_length", batch_size=batch_size 273 | ) 274 | if cached_folder is not None: 275 | print(f"dump data to cached: {cached_folder}") 276 | self.dump(cached_folder) 277 | 278 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 279 | dp = self.data_points[i] 280 | result = {} 281 | for key in dp: 282 | result[key] = torch.tensor(dp[key]) 283 | return result 284 | 285 | 286 | class PackedDataset(CachedDataset): 287 | def __init__( 288 | self, 289 | raw_data: List[Dict], 290 | tokenizer: transformers.PreTrainedTokenizer, 291 | cached_folder: Optional[str] = None, 292 | ignore_cached: bool = False, 293 | batch_size: int = 5000, 294 | **kwargs, 295 | ): 296 | super().__init__(tokenizer, cached_folder, ignore_cached) 297 | self.sliding_window = kwargs.get("sliding_window", None) 298 | if not self.load_from_cache: 299 | self.data_points = map_raw_data_to_input_dic( 300 | raw_data, tokenizer, padding="do_not_pad", batch_size=batch_size 301 | ) 302 | self.update_packing_info() 303 | if cached_folder is not None: 304 | print(f"dump data to cached: {cached_folder}") 305 | self.dump(cached_folder) 306 | else: # update packing 307 | self.update_packing_info() 308 | 309 | def update_packing_info(self): 310 | self.lengths = [len(item["input_ids"]) for item in self.data_points] 311 | self.groups = merge_data_points_by_length(self.lengths, self.tokenizer.model_max_length) 312 | 313 | def __len__(self): 314 | return len(self.groups) 315 | 316 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 317 | group = self.groups[i] 318 | group_data_points = [self.data_points[index] for index in group] 319 | return merge_data_points(group_data_points, self.tokenizer, self.sliding_window) 320 | 321 | def stat(self): 322 | print(f"number of original data points:{len(self.data_points)}; packed to: {len(self.groups)} data points") 323 | original_avg_length = sum(self.lengths) / len(self.lengths) 324 | packed_lengths = [] 325 | for group in self.groups: 326 | lengths = [self.lengths[index] for index in group] 327 | packed_lengths.append(sum(lengths)) 328 | avg_packed_length = sum(packed_lengths) / len(packed_lengths) 329 | print(f"original avg length: {original_avg_length}; avg packed length: {avg_packed_length}") 330 | 331 | 332 | def pack_data_points_FA(data_points: List[Dict], tokenizer: Any) -> Dict: 333 | input_ids = [] 334 | lengths = [] 335 | label_ids = [] 336 | attention_mask = [] 337 | for index, item in enumerate(data_points): 338 | input_ids += item["input_ids"] 339 | # assert item["labels"][0] == -100 # This is to make sure that the first token won't be included in computing loss 340 | labels = list(item["labels"]) 341 | labels[0] = -100 342 | label_ids += labels 343 | lengths.append(len(item["input_ids"])) 344 | attention_mask += [index + 1 for _ in 
range(len(item["input_ids"]))] 345 | 346 | pad_leng = tokenizer.model_max_length - len(input_ids) # padding to model_max_length 347 | if tokenizer.padding_side == "right": 348 | input_ids = input_ids + [tokenizer.pad_token_id for _ in range(pad_leng)] 349 | label_ids = label_ids + [-100 for _ in range(pad_leng)] 350 | attention_mask = attention_mask + [0 for _ in range(pad_leng)] 351 | else: 352 | input_ids = [tokenizer.pad_token_id for _ in range(pad_leng)] + input_ids 353 | label_ids = [-100 for _ in range(pad_leng)] + label_ids 354 | attention_mask = [0 for _ in range(pad_leng)] + attention_mask 355 | 356 | assert len(input_ids) == len(label_ids) == len(attention_mask) 357 | return { 358 | "input_ids": torch.tensor(input_ids), 359 | "labels": torch.tensor(label_ids), 360 | "attention_mask": torch.tensor(attention_mask), # 1-D tensor of per-token sequence ids (tokens of the i-th packed data point share value i + 1, padding = 0); no N x N mask is built here 361 | } 362 | 363 | 364 | class FAPackedDataset(PackedDataset): 365 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 366 | group = self.groups[i] 367 | group_data_points = [self.data_points[index] for index in group] 368 | return pack_data_points_FA(group_data_points, self.tokenizer) 369 | -------------------------------------------------------------------------------- /train/data_statistics.py: -------------------------------------------------------------------------------- 1 | from gen_data import utility 2 | import typer 3 | import os 4 | from typing import Dict 5 | 6 | 7 | def update_dic_count(dic, c_value): 8 | dic[c_value] = dic.get(c_value, 0) + 1 9 | 10 | 11 | def dump_count_dic(dic, path): 12 | result = "" 13 | for k, v in sorted(dic.items(), key=lambda x: -x[1]): 14 | assert len(k) > 0 15 | result += f"{k}, {v}\n" 16 | utility.save_text(result, path) 17 | 18 | 19 | def main(train_path: str, save_folder: str): 20 | utility.create_folder(save_folder) 21 | items = utility.read_json(train_path) 22 | print("total number of items: ", len(items)) 23 | result: Dict[str, Dict] = {"llm": {}, "multihop": {}, "tag": {}, "sub_questions": {}, "negative": {}} 24 | others: Dict[str, Dict] = {"entity": {}, "attribute": {}} 25 | for item in items: 26 | llm = item["meta_info"]["llm"] 27 | update_dic_count(result["llm"], llm) 28 | update_dic_count(result["tag"], item["tag"]) 29 | update_dic_count(result["multihop"], str(item["multihop"])) 30 | update_dic_count(result["sub_questions"], len(item["sub_questions"])) 31 | update_dic_count(result["negative"], str("negatives" in item["meta_info"])) 32 | meta_info = item["meta_info"] 33 | for attr in ["attribute_1", "attribute_2", "comparison_attribute"]: 34 | if attr in meta_info: 35 | update_dic_count(others["attribute"], meta_info[attr]) 36 | for entity in ["entity_1", "entity_2", "entity"]: 37 | if entity in meta_info: 38 | update_dic_count(others["entity"], meta_info[entity]) 39 | 40 | utility.save_json(result, os.path.join(save_folder, "stat.json")) 41 | 42 | dump_count_dic(others["entity"], os.path.join(save_folder, "entity.csv")) 43 | dump_count_dic(others["attribute"], os.path.join(save_folder, "attribute.csv")) 44 | 45 | 46 | if __name__ == "__main__": 47 | typer.run(main) 48 | -------------------------------------------------------------------------------- /train/ds_config/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "offload_optimizer": { 5 | "device": "cpu" 6 | }, 7 | "allgather_partitions": true, 8 | "allgather_bucket_size": 2e8, 9 | "overlap_comm": true, 10 |
"reduce_scatter": true, 11 | "reduce_bucket_size": 2e8, 12 | "contiguous_gradients": true 13 | }, 14 | "bf16": { 15 | "enabled": "auto" 16 | }, 17 | "fp16": { 18 | "enabled": "auto" 19 | }, 20 | "train_micro_batch_size_per_gpu": "auto", 21 | "gradient_accumulation_steps": "auto" 22 | } -------------------------------------------------------------------------------- /train/ds_config/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "offload_optimizer": { 16 | "device": "cpu", 17 | "pin_memory": true 18 | }, 19 | "offload_param": { 20 | "device": "cpu", 21 | "pin_memory": true 22 | }, 23 | "overlap_comm": true, 24 | "contiguous_gradients": true, 25 | "sub_group_size": 1e9, 26 | "reduce_bucket_size": "auto", 27 | "stage3_prefetch_bucket_size": "auto", 28 | "stage3_param_persistence_threshold": "auto", 29 | "stage3_max_live_parameters": 1e9, 30 | "stage3_max_reuse_distance": 1e9, 31 | "stage3_gather_16bit_weights_on_model_save": true 32 | }, 33 | 34 | "gradient_accumulation_steps": "auto", 35 | "gradient_clipping": "auto", 36 | "steps_per_print": 2000, 37 | "train_batch_size": "auto", 38 | "train_micro_batch_size_per_gpu": "auto", 39 | "wall_clock_breakdown": false 40 | } -------------------------------------------------------------------------------- /train/length_statistics.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaTokenizerFast 2 | import os 3 | import json 4 | import sys 5 | from typing import Dict 6 | import typer 7 | import datetime 8 | from qa_expert import prompt_utils 9 | from gen_data import utility 10 | 11 | 12 | def main(pretrained_path: str, train_path: str, save_folder: str, max_length: int): 13 | tokenizer = LlamaTokenizerFast.from_pretrained(pretrained_path, legacy=True) 14 | tokenizer.pad_token = tokenizer.eos_token 15 | tokenizer.add_special_tokens({"additional_special_tokens": prompt_utils.get_additional_tokens()}) 16 | 17 | with open(train_path, "r") as f: 18 | examples = json.loads(f.read()) 19 | print(f"handle: {train_path}, number of examples: {len(examples)}") 20 | 21 | all_prompts = [] 22 | for example in examples: 23 | messages = prompt_utils.convert_multi_qa_format_to_messages(example) 24 | prompt = prompt_utils.get_prompt_from_messages(messages) 25 | all_prompts.append(prompt) 26 | 27 | count_dic: Dict[int, int] = {} 28 | batches = utility.get_batch_indices(len(all_prompts), batch_size=2000) 29 | t1 = datetime.datetime.now() 30 | for index, (start, end) in enumerate(batches): 31 | inputs = tokenizer(all_prompts[start:end])["input_ids"] 32 | for item in inputs: 33 | length = len(item) 34 | count_dic[length] = count_dic.get(length, 0) + 1 35 | t2 = datetime.datetime.now() 36 | acc_time = (t2 - t1).total_seconds() 37 | avg_time = acc_time / (index + 1) 38 | print(f"{index} / {len(batches)}; avg_time: {avg_time}; remaining time: {avg_time * (len(batches) - index -1)}") 39 | 40 | sorted_lengths = sorted(count_dic.items(), key=lambda x: x[0]) 41 | acc_count = 0 42 | pairs = [] 43 | rows = [] 44 | for length, count in sorted_lengths: 45 | acc_count += count 46 | pairs.append((length, acc_count)) 47 | rows.append((str(length), str(count))) 48 | rows.reverse() 49 | 50 | 
utility.save_csv([("length", "count")] + rows, f"{save_folder}/length_dic_count.csv") 51 | total_count = acc_count 52 | assert total_count == len(all_prompts) 53 | pairs.reverse() 54 | acc_rows = [("length", "accumulated_count", "percentage")] 55 | for i in range(len(pairs)): 56 | length, count = pairs[i] 57 | acc_rows.append((str(length), str(count), str(count / total_count))) 58 | utility.save_csv(acc_rows, f"{save_folder}/accumulated.csv") 59 | 60 | lengths = [] 61 | for key in count_dic: 62 | frequency = count_dic[key] 63 | key = min(key, max_length) 64 | lengths.extend([key for _ in range(frequency)]) 65 | assert len(lengths) == len(all_prompts) 66 | groups = utility.merge_data_points_by_length(lengths, max_length) 67 | original_ave_length = sum(lengths) / len(lengths) 68 | packed_lengths = [] 69 | for indices in groups: 70 | packed_lengths.append(sum(lengths[index] for index in indices)) 71 | packed_ave_length = sum(packed_lengths) / len(packed_lengths) 72 | print("number of data points after being packed: ", len(groups)) 73 | print(f"original average length: {original_ave_length}, packed average length: {packed_ave_length}") 74 | 75 | 76 | if __name__ == "__main__": 77 | typer.run(main) 78 | -------------------------------------------------------------------------------- /train/merge_weight.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 5 | from transformers import AutoModelForCausalLM, LlamaTokenizer 6 | from qa_expert.prompt_utils import get_additional_tokens 7 | from peft import PeftModel 8 | import torch 9 | import typer 10 | 11 | 12 | def merge_weight(save_folder: str, pretrained_path: str, checkpoint: str): 13 | tokenizer = LlamaTokenizer.from_pretrained(pretrained_path, legacy=True) 14 | tokenizer.pad_token = tokenizer.eos_token # Llama needs this 15 | added_tokens = get_additional_tokens() 16 | print("added token: ", added_tokens) 17 | tokenizer.add_special_tokens({"additional_special_tokens": added_tokens}) 18 | 19 | model = AutoModelForCausalLM.from_pretrained( 20 | pretrained_path, 21 | device_map="auto", 22 | trust_remote_code=True, 23 | use_flash_attention_2=True, 24 | torch_dtype=torch.bfloat16, 25 | ) 26 | print("model = ", model) 27 | model.resize_token_embeddings(len(tokenizer)) 28 | model.config.pad_token_id = tokenizer.pad_token_id 29 | lora_model = PeftModel.from_pretrained(model, checkpoint, torch_dtype=torch.float16) 30 | lora_model = lora_model.merge_and_unload() 31 | lora_model.save_pretrained(save_folder) 32 | tokenizer.save_pretrained(save_folder) 33 | print("final lora model: ", lora_model) 34 | 35 | 36 | if __name__ == "__main__": 37 | typer.run(merge_weight) 38 | -------------------------------------------------------------------------------- /train/requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes==0.41.1 2 | peft==0.5.0 3 | datasets==2.8.0 4 | transformers==4.35.0 5 | scipy==1.11.3 6 | deepspeed==0.11.1 7 | typer==0.9.0 8 | flash-attn==2.3.2 9 | pydantic~=1.10.13 10 | protobuf==3.20.0 11 | tokenizer==3.4.3 12 | sentencepiece~=0.1.99 13 | accelerate==0.23.0 14 | colorama==0.4.6 -------------------------------------------------------------------------------- /train/test.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question": "When was state of emergency declared in the country the President's 
father is from?", 4 | "answer": "20 October 1952", 5 | "sub_questions": [ 6 | { 7 | "question": "The President's father is from what country?", 8 | "answer": "Kenya", 9 | "paragraph": "Multiracial Americans. Since the late twentieth century, the number of African and Caribbean ethnic African immigrants have increased in the United States. Together with publicity about the ancestry of President Barack Obama, whose father was from Kenya, some black writers have argued that new terms are needed for recent immigrants. They suggest that the term \"African-American\" should refer strictly to the descendants of African slaves and free people of color who survived the slavery era in the United States. They argue that grouping together all ethnic Africans regardless of their unique ancestral circumstances would deny the lingering effects of slavery within the American slave descendant community. They say recent ethnic African immigrants need to recognize their own unique ancestral backgrounds.", 10 | "long_answer": "The provided knowledge mentions that the President's father is from Kenya.\nAnswer: The President's father is from Kenya." 11 | }, 12 | { 13 | "question": "when was state of emergency declared in Kenya", 14 | "answer": "20 October 1952", 15 | "paragraph": "Mau Mau Uprising. On 20 October 1952, Governor Baring signed an order declaring a State of Emergency. Early the next morning, Operation Jock Scott was launched: the British carried out a mass - arrest of Jomo Kenyatta and 180 other alleged Mau Mau leaders within Nairobi. Jock Scott did not decapitate the movement's leadership as hoped, since news of the impending operation was leaked. Thus, while the moderates on the wanted list awaited capture, the real militants, such as Dedan Kimathi and Stanley Mathenge (both later principal leaders of Mau Mau's forest armies), fled to the forests.", 16 | "long_answer": "The provided knowledge mentions that \"Governor Baring signed an order declaring a State of Emergency\" and the operation launched the next morning.\nAnswer: According to the provided knowledge, a State of Emergency was declared in Kenya on October 20, 1952." 17 | } 18 | ], 19 | "final_answer": "Summary: The provided knowledge mentions that the President's father is from Kenya and a State of Emergency was declared in Kenya on October 20, 1952. Based on the knowledge provided, we know that the President's father is from Kenya and a State of Emergency was declared in Kenya on October 20, 1952. Therefore, it can be inferred that a State of Emergency was declared in the country the President's father is from.\nAnswer: According to the provided knowledge, a State of Emergency was declared in the country the President's father is from, specifically Kenya, on October 20, 1952.", 20 | "meta_info": { 21 | "src": "musique", 22 | "llm": "wizard_lm" 23 | }, 24 | "multihop": true, 25 | "tag": "musique-train.json" 26 | }, 27 | { 28 | "meta_info": { 29 | "src": "musique", 30 | "llm": "gpt-3.5-turbo-instruct", 31 | "selected": { 32 | "q_index": 0, 33 | "item_index": 2737, 34 | "src": "musique_train", 35 | "answerable": true 36 | } 37 | }, 38 | "sub_questions": [ 39 | { 40 | "question": "In what place did Buddy Stewart die?", 41 | "answer": "New Mexico", 42 | "paragraph": "Buddy Stewart. Buddy Stewart \"(né\" Albert James Byrne, Jr; 1922 in Derry, New Hampshire — 1 February 1950 Deming, New Mexico) was an American jazz singer. 
His adopted stage surname is standardized in most biographies, including \"The Jazz Discography,\" as \"Stewart;\" but it was sometimes also spelled \"Stuart.\"", 43 | "long_answer": "The knowledge states that Buddy Stewart died on February 1, 1950 in Deming, New Mexico.\nAnswer: Buddy Stewart died in Deming, New Mexico." 44 | } 45 | ], 46 | "multihop": false, 47 | "answer": null, 48 | "final_answer": "The knowledge states that Buddy Stewart died on February 1, 1950 in Deming, New Mexico.\nAnswer: Buddy Stewart died in Deming, New Mexico.", 49 | "question": "In what place did Buddy Stewart die?", 50 | "tag": "musique-single_train.json" 51 | }, 52 | { 53 | "meta_info": { 54 | "src": "musique", 55 | "llm": "wizard_lm", 56 | "selected": { 57 | "q_index": 1, 58 | "item_index": 988, 59 | "src": "musique_train", 60 | "answerable": true 61 | } 62 | }, 63 | "sub_questions": [ 64 | { 65 | "question": "When did Israel and Turkey 's relations take a downturn?", 66 | "answer": "after the 2008–09 Gaza War", 67 | "paragraph": "Israel. Although Turkey and Israel did not establish full diplomatic relations until 1991, Turkey has cooperated with the State since its recognition of Israel in 1949. Turkey's ties to the other Muslim-majority nations in the region have at times resulted in pressure from Arab and Muslim states to temper its relationship with Israel. Relations between Turkey and Israel took a downturn after the 2008–09 Gaza War and Israel's raid of the Gaza flotilla. IHH, which organized the flotilla, is a Turkish charity that has been challenged on ties to Hamas and Al-Qaeda. Relations between Israel and Greece have improved since 1995 due to the decline of Israeli-Turkish relations. The two countries have a defense cooperation agreement and in 2010, the Israeli Air Force hosted Greece’s Hellenic Air Force in a joint exercise at the Uvda base. Israel is the second largest importer of Greek products in the Middle East. The joint Cyprus-Israel oil and gas explorations centered on the Leviathan gas field are an important factor for Greece, given its strong links with Cyprus. Cooperation in the world's longest sub-sea electric power cable, the EuroAsia Interconnector, has strengthened relations between Cyprus and Israel.", 68 | "long_answer": "The knowledge mentions that \"Relations between Turkey and Israel took a downturn after the 2008–09 Gaza War and Israel's raid of the Gaza flotilla.\"\nAnswer: Based on the provided knowledge, Israel and Turkey's relations took a downturn after the 2008-09 Gaza War and Israel's raid of the Gaza flotilla." 
69 | } 70 | ], 71 | "multihop": false, 72 | "answer": null, 73 | "final_answer": "The knowledge mentions that \"Relations between Turkey and Israel took a downturn after the 2008–09 Gaza War and Israel's raid of the Gaza flotilla.\"\nAnswer: Based on the provided knowledge, Israel and Turkey's relations took a downturn after the 2008-09 Gaza War and Israel's raid of the Gaza flotilla.", 74 | "question": "When did Israel and Turkey 's relations take a downturn?", 75 | "tag": "musique-single_train.json" 76 | } 77 | ] -------------------------------------------------------------------------------- /train/train_model.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | # This script is written based on: https://github.com/MeetKai/functionary/blob/main/functionary/train/train_lora.py 3 | import os 4 | import sys 5 | from typing import Dict 6 | from datasets import load_dataset 7 | import json 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | 11 | from qa_expert.prompt_utils import get_additional_tokens 12 | 13 | from train.custom_datasets import FAPackedDataset, CustomDataset 14 | 15 | from peft import ( 16 | LoraConfig, 17 | PeftConfig, 18 | get_peft_model, 19 | prepare_model_for_kbit_training, 20 | ) 21 | import bitsandbytes as bnb 22 | 23 | import transformers 24 | from transformers import ( 25 | LlamaTokenizerFast, 26 | LlamaTokenizer, 27 | BitsAndBytesConfig, 28 | ) 29 | import torch 30 | import math 31 | import os 32 | from dataclasses import dataclass, field 33 | from typing import Any, Optional 34 | import random 35 | from torch.utils.data import DataLoader 36 | from train.mk_patched_mistral import MistralForCausalLM 37 | import deepspeed 38 | 39 | LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) 40 | print("local rank: ", LOCAL_RANK) 41 | 42 | 43 | class DataCollatorForMaskingLabels: 44 | """This data collator is used for dynamic padding. 45 | All the data points will be padded to the max length of the mini-batch instead of the whole dataset 46 | This will reduce the training time considerably when your data points are not uniform in terms of length 47 | """ 48 | 49 | def __init__(self, tokenizer) -> None: 50 | self.tokenizer = tokenizer 51 | self.padding_side = self.tokenizer.padding_side 52 | 53 | def __call__(self, examples, return_tensors=None) -> Any: 54 | input_lengs = [] 55 | for ex in examples: 56 | input_lengs.append(len(ex["input_ids"])) 57 | max_leng = max(input_lengs) 58 | result: Dict[str, Any] = {key: [] for key in examples[0].keys()} 59 | added_pad_dic = {"input_ids": self.tokenizer.pad_token_id, "labels": -100, "attention_mask": 0} 60 | 61 | for example in examples: 62 | pad_leng = max_leng - len(example["input_ids"]) 63 | for key in result: 64 | if self.padding_side == "right": 65 | result[key].append(example[key] + [added_pad_dic[key] for _ in range(pad_leng)]) 66 | else: 67 | result[key].append([added_pad_dic[key] for _ in range(pad_leng)] + example[key]) 68 | 69 | for key in result: 70 | result[key] = torch.tensor(result[key]) 71 | return result 72 | 73 | 74 | def set_seed(seed): 75 | random.seed(seed) 76 | torch.manual_seed(seed) 77 | 78 | 79 | @dataclass 80 | class ModelArguments: 81 | model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf") 82 | model_type: str = field(default="llama") 83 | use_lora: bool = field(default=True) 84 | model_max_length: int = field( 85 | default=4096, 86 | metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 87 | ) 88 | qlora: bool = field(default=False, metadata={"help": "whether to use QLoRA or not"}) 89 | 90 | 91 | @dataclass 92 | class DataArguments: 93 | train_path: str = field(default="", metadata={"help": "Path to the training data."}) 94 | validation_path: str = field(default="", metadata={"help": "Path to the evaluation data"}) 95 | hf_data_path: str = field( 96 | default="khaimaitien/qa-expert-multi-hop-qa-V1.0", metadata={"help": "dataset from HF hub"} 97 | ) 98 | packing: bool = field(default=False, metadata={"help": "Whether to use packing or not"}) 99 | train_ratio: float = field(default=1, metadata={"help": "percentage of training data to use"}) 100 | validation_ratio: float = field(default=1, metadata={"help": "percentage of validation data to use"}) 101 | 102 | 103 | @dataclass 104 | class TrainingArguments(transformers.TrainingArguments): 105 | cache_dir: Optional[str] = field(default=None) 106 | optim: str = field(default="adamw_torch") 107 | 108 | 109 | def create_peft_config(modules): 110 | """ 111 | Create Parameter-Efficient Fine-Tuning config for your model 112 | :param modules: Names of the modules to apply LoRA to 113 | """ 114 | config = LoraConfig( 115 | r=16, # dimension of the updated matrices 116 | lora_alpha=64, # parameter for scaling 117 | target_modules=modules, 118 | lora_dropout=0.1, # dropout probability for layers 119 | bias="none", 120 | task_type="CAUSAL_LM", 121 | modules_to_save=["lm_head", "embed_tokens"], 122 | ) 123 | 124 | return config 125 | 126 | 127 | def create_bnb_config(): 128 | bnb_config = BitsAndBytesConfig( 129 | load_in_4bit=True, 130 | bnb_4bit_use_double_quant=True, 131 | bnb_4bit_quant_type="nf4", 132 | bnb_4bit_compute_dtype=torch.bfloat16, 133 | ) 134 | return bnb_config 135 | 136 | 137 | def find_all_linear_names(model): 138 | lora_module_names = set() 139 | for name, module in model.named_modules(): 140 | if isinstance(module, bnb.nn.Linear4bit) or isinstance(module, torch.nn.Linear): 141 | names = name.split(".") 142 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 143 | 144 | if "lm_head" in lora_module_names: # needed for 16-bit 145 | lora_module_names.remove("lm_head") 146 | return list(lora_module_names) 147 | 148 | 149 | def get_device_map(training_args: TrainingArguments, model_args: ModelArguments) -> Optional[Dict]: 150 | device_map = None 151 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 152 | ddp = world_size != 1 153 | if model_args.qlora: 154 | if ddp and training_args.fsdp: 155 | print("FSDP is incompatible with QLORA") 156 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None 157 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): 158 | print("FSDP and ZeRO3 are both currently incompatible with QLoRA.") 159 | return device_map 160 | 161 | 162 | def load_model(data_args: DataArguments, training_args: TrainingArguments, model_args: ModelArguments, tokenizer: Any): 163 | # Set RoPE scaling factor 164 | config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path) 165 | if model_args.model_type == "llama": 166 | model_max_length = model_args.model_max_length 167 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 168 | print_rank0(f"rope scaling for llama, original context length: {orig_ctx_len}, extended to: {model_max_length}") 169 | if orig_ctx_len and model_max_length > orig_ctx_len: 170 | scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
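# Worked example (hypothetical numbers, not from the original code): with orig_ctx_len = 4096
# and model_max_length = 8192, scaling_factor = ceil(8192 / 4096) = 2.0, so rope_scaling below
# becomes {"type": "linear", "factor": 2.0}; with the default model_max_length = 4096 the
# condition above is False and rope_scaling is left unset.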
171 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 172 | config.use_cache = False 173 | 174 | compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) 175 | 176 | if data_args.packing and model_args.model_type == "mistral": # have to monkey-patch 177 | model_class = MistralForCausalLM 178 | else: 179 | model_class = transformers.AutoModelForCausalLM 180 | 181 | print_rank0("QLORA: ", model_args.qlora) 182 | 183 | model = model_class.from_pretrained( 184 | model_args.model_name_or_path, 185 | config=config, 186 | device_map=get_device_map(training_args, model_args), 187 | trust_remote_code=True, 188 | use_flash_attention_2=True, 189 | torch_dtype=compute_dtype, 190 | quantization_config=BitsAndBytesConfig( 191 | load_in_4bit=True, 192 | bnb_4bit_use_double_quant=True, 193 | bnb_4bit_quant_type="nf4", 194 | # 4-bit NF4 quantization config for QLoRA; flash attention is already enabled via use_flash_attention_2 above 195 | bnb_4bit_compute_dtype=compute_dtype, 196 | ) 197 | if model_args.qlora 198 | else None, 199 | ) 200 | print_rank0("model = ", model) 201 | model.resize_token_embeddings(len(tokenizer)) 202 | model.config.pad_token_id = tokenizer.pad_token_id 203 | model.gradient_checkpointing_enable() 204 | if model_args.qlora and model_args.use_lora: 205 | model = prepare_model_for_kbit_training(model) 206 | if model_args.use_lora: 207 | print("USE LORA TRAINING, START FINDING LINEAR LAYERS NOW") 208 | modules = find_all_linear_names(model) 209 | print_rank0("linear modules: ", modules) # e.g. ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] 210 | model = get_peft_model(model, create_peft_config(modules)) 211 | model.config.use_cache = False 212 | print_trainable_parameters(model) 213 | return model 214 | 215 | 216 | def print_rank0(*arg): 217 | if LOCAL_RANK == 0: 218 | print(*arg) 219 | 220 | 221 | def print_trainable_parameters(model): 222 | """ 223 | Prints the number of trainable parameters in the model.
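    With use_lora=True the trainable parameters are the LoRA adapter matrices plus the full
    lm_head / embed_tokens weights (create_peft_config sets modules_to_save=["lm_head", "embed_tokens"]);
    the two groups are counted separately below.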
224 | """ 225 | lora_param_count = 0 226 | all_param = 0 227 | embedding_lm_head_param_count = 0 228 | for name, param in model.named_parameters(): 229 | num_params = param.numel() 230 | # if using DS Zero 3 and the weights are initialized empty 231 | if num_params == 0 and hasattr(param, "ds_numel"): 232 | num_params = param.ds_numel 233 | 234 | all_param += num_params 235 | if param.requires_grad: 236 | print_rank0(f"trainable: {name}, num_params: {num_params}") 237 | if "lm_head" in name or "embed_tokens" in name: 238 | embedding_lm_head_param_count += num_params 239 | else: 240 | lora_param_count += num_params 241 | trainable_params = embedding_lm_head_param_count + lora_param_count 242 | print_rank0( 243 | f"all params: {all_param:,d} || trainable params: {trainable_params:,d} || trainable%: {100 * trainable_params / all_param}" 244 | ) 245 | print_rank0(f"embedding_lm_head_param_count={embedding_lm_head_param_count}||loara_param={lora_param_count}") 246 | print_rank0( 247 | f"embedding_lm_head_param_count %={embedding_lm_head_param_count * 100 / all_param}||loara_param %={lora_param_count * 100 / all_param}" 248 | ) 249 | 250 | 251 | def print_some_examples(ds, tokenizer): 252 | data_loader = DataLoader(ds, batch_size=3) 253 | count = 0 254 | for batch in data_loader: 255 | if count == 0: 256 | print_rank0("keys in batch: ", batch.keys()) 257 | print_rank0("--------------****Example data point****---------------") 258 | print("device: ", batch["input_ids"].device) 259 | print_rank0("shape of input_ids: ", batch["input_ids"].shape) # B x L 260 | print_rank0("shape of labels: ", batch["labels"].shape) 261 | print_rank0("shape of attention_mask: ", batch["attention_mask"].shape) 262 | # print_rank0('input_ids: ', batch["input_ids"].tolist()) 263 | # print_rank0('labels: ', batch["labels"].tolist()) 264 | print_rank0("attention mask: ", batch["attention_mask"]) 265 | input_ids = batch["input_ids"][0].tolist() 266 | labels = batch["labels"][0].tolist() 267 | for i in range(len(labels)): 268 | if labels[i] == -100: 269 | labels[i] = tokenizer.pad_token_id 270 | print_rank0("++++input_ids: ") 271 | print_rank0(tokenizer.decode(input_ids)) 272 | print_rank0("++++labels: ") 273 | print_rank0(tokenizer.decode(labels)) 274 | count += 1 275 | if count == 3: 276 | break 277 | 278 | 279 | def read_dataset(data_args: DataArguments, training_args: TrainingArguments, tokenizer: Any, ds_type: str): 280 | ds_class = CustomDataset 281 | if data_args.packing: 282 | ds_class = FAPackedDataset # if packing --> Use PackedDataset 283 | 284 | # The way we read dataset is: 285 | # Rank 0 will process the dataset and save the result to cached_folder, other ranks will read from the cached_folder 286 | cached_folder = os.path.join(training_args.output_dir, f"{ds_type}_cached") 287 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 288 | 289 | if training_args.local_rank > 0: # If this is not rank 0, stay here, wait for rank 0 to process the data 290 | print(f"process: {LOCAL_RANK} wait for main process to prepare the training data") 291 | torch.distributed.barrier() 292 | else: # rank 0 process the data and save to cached_folder 293 | if not os.path.exists(training_args.output_dir): 294 | os.mkdir(training_args.output_dir) 295 | if not os.path.exists(cached_folder): 296 | os.mkdir(cached_folder) 297 | 298 | data_path = data_args.train_path if ds_type == "train" else data_args.validation_path 299 | data_ratio = data_args.train_ratio if ds_type == "train" else data_args.validation_ratio 300 | with open(data_path, 
"r") as file: 301 | raw_data = json.loads(file.read()) 302 | if data_ratio < 1: 303 | size = int(len(raw_data) * data_ratio) 304 | raw_data = raw_data[:size] 305 | 306 | print(f"{ds_type} size: : {len(raw_data)}") 307 | # ignore_cached=True to ignore the cached if exist, rank 0 will always process the data 308 | ds = ds_class(raw_data, tokenizer, cached_folder=cached_folder, ignore_cached=True) 309 | print(f"process: {LOCAL_RANK} finish processing data") 310 | if world_size > 1: # only run this if this is training with multiple GPUs 311 | torch.distributed.barrier() # allow other ranks to execute 312 | 313 | # All ranks will read the processed data from cached_path created by rank 0 314 | ds = ds_class(None, tokenizer, cached_folder=cached_folder, ignore_cached=False) 315 | if LOCAL_RANK == 0: 316 | ds.stat() # print some statistics about the dataset 317 | return ds 318 | 319 | 320 | def train(): 321 | set_seed(100) 322 | argument_parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 323 | model_args, data_args, training_args = argument_parser.parse_args_into_dataclasses() 324 | pretrained_model = model_args.model_name_or_path 325 | 326 | # initialize tokenizer 327 | # if model_args.model_type == "llama": 328 | tokenizer = LlamaTokenizerFast.from_pretrained( 329 | pretrained_model, legacy=True, model_max_length=model_args.model_max_length 330 | ) 331 | tokenizer.pad_token = tokenizer.eos_token # Llama needs this 332 | if model_args.model_type == "mistral": 333 | print_rank0("set padding_side = left for Mistral") 334 | tokenizer.padding_side = "left" 335 | added_tokens = get_additional_tokens() 336 | print_rank0("added token: ", added_tokens) 337 | tokenizer.add_special_tokens({"additional_special_tokens": added_tokens}) 338 | print_rank0("total number of tokens: ", len(tokenizer)) 339 | print_rank0("tokenizer: ", tokenizer) 340 | 341 | # read data 342 | train_ds = read_dataset(data_args, training_args, tokenizer, "train") 343 | valid_ds = read_dataset(data_args, training_args, tokenizer, "validation") 344 | print_rank0(f"train_size: {len(train_ds)}; validation_size: {len(valid_ds)}") 345 | 346 | print_some_examples(train_ds, tokenizer) 347 | model = load_model(data_args, training_args, model_args, tokenizer) 348 | 349 | print_rank0("training args: \n", training_args.to_json_string()) 350 | trainer = transformers.Trainer( 351 | model=model, 352 | train_dataset=train_ds, 353 | eval_dataset=valid_ds, 354 | args=training_args, 355 | ) 356 | 357 | print_rank0("Training ...") 358 | trainer.train() 359 | 360 | 361 | if __name__ == "__main__": 362 | train() 363 | -------------------------------------------------------------------------------- /train/upload_model_to_hf.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi 2 | import typer 3 | from huggingface_hub import login 4 | 5 | login() 6 | 7 | api = HfApi() 8 | 9 | 10 | def upload_model(model_folder: str, repo_id: str): 11 | api.upload_folder( 12 | folder_path=model_folder, 13 | repo_id=repo_id, 14 | repo_type="model", 15 | ) 16 | 17 | 18 | if __name__ == "__main__": 19 | typer.run(upload_model) 20 | --------------------------------------------------------------------------------