├── .gitignore ├── LICENSE ├── README.md ├── data_prep ├── 1_collect_data.py ├── 2_add_negatives.py ├── 3_add_continuous_labels.py └── 4_make_cont_bin_training_data.py ├── run_bier.ipynb └── training ├── rev_train_config.yaml ├── run_rev_train.sh ├── run_train.sh └── train_config.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lightblue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LB Reranker v1.0 2 | 3 | image 4 | 5 | 6 | The LB Reranker has been trained to determine the relatedness of a given query to a piece of text, therefore allowing it to be used as a ranker or reranker in various retrieval-based tasks. 7 | 8 | This model is fine-tuned from a [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) model checkpoint and was trained for roughly 5.5 hours using the 8 x L20 instance ([ecs.gn8is-8x.32xlarge](https://www.alibabacloud.com/help/en/ecs/user-guide/gpu-accelerated-compute-optimized-and-vgpu-accelerated-instance-families-1)) on [Alibaba Cloud](https://www.alibabacloud.com/). 9 | 10 | The training data for this model can be found at [lightblue/reranker_continuous_filt_max7_train](https://huggingface.co/datasets/lightblue/reranker_continuous_filt_max7_train). 11 | The code for generating this data is listed in our Github repo [here](https://github.com/lightblue-tech/lb-reranker/tree/main/data_prep) and the code for training of the model can be found [here](https://github.com/lightblue-tech/lb-reranker/tree/main/training). Note that the training is conducted using [Llama Factory](https://github.com/hiyouga/LLaMA-Factory), which is installed at `/root/LLaMA-Factory`. You may need to change some of the training code to match your Llama Factory set-up to replicate our training. 12 | 13 | Trained on data in over 95 languages, this model is applicable to a broad range of use cases. 14 | 15 | This model has three main benefits over comparable rerankers. 16 | 1. It has shown slightly higher performance on evaluation benchmarks. 17 | 2. It has been trained on more languages than any previous model. 18 | 3. It is a simple Causal LM model trained to output a string between "1" and "7". 19 | 20 | This last point means that this model can be used natively with many widely available inference packages, including vLLM and LMDeploy. 21 | This in turns allows our reranker to benefit from improvements to inference as and when these packages release them. 22 | 23 | # How to use 24 | 25 | The model was trained to expect an input such as: 26 | 27 | ``` 28 | <<>> 29 | {your_query_here} 30 | 31 | <<>> 32 | {your_context_here} 33 | ``` 34 | 35 | And to output a string of a number between 1-7. 36 | 37 | In order to make a continuous score that can be used for reranking query-context pairs (i.e. a method with few ties), we calculate the expectation value of the scores. 38 | 39 | We include scripts to do this in both vLLM and LMDeploy: 40 | 41 | #### vLLM 42 | 43 | Install [vLLM](https://github.com/vllm-project/vllm/) using `pip install vllm`. 44 | 45 | ```python 46 | from vllm import LLM, SamplingParams 47 | import numpy as np 48 | 49 | def make_reranker_input(t, q): 50 | return f"<<>>\n{q}\n\n<<>>\n{t}" 51 | 52 | def make_reranker_training_datum(context, question): 53 | system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related." 54 | 55 | return [ 56 | {"role": "system", "content": system_message}, 57 | {"role": "user", "content": make_reranker_input(context, question)}, 58 | ] 59 | 60 | def get_prob(logprob_dict, tok_id): 61 | return np.exp(logprob_dict[tok_id].logprob) if tok_id in logprob_dict.keys() else 0 62 | 63 | llm = LLM("lightblue/lb-reranker-v1.0") 64 | sampling_params = SamplingParams(temperature=0.0, logprobs=14, max_tokens=1) 65 | tok = llm.llm_engine.tokenizer.tokenizer 66 | idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)] 67 | 68 | query_texts = [ 69 | ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."), 70 | ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."), 71 | ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."), 72 | ] 73 | 74 | chats = [make_reranker_training_datum(c, q) for q, c in query_texts] 75 | responses = llm.chat(chats, sampling_params) 76 | probs = np.array([[get_prob(r.outputs[0].logprobs[0], y) for y in idx_tokens] for r in responses]) 77 | 78 | N = probs.shape[1] 79 | M = probs.shape[0] 80 | idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N) 81 | 82 | expected_vals = (probs * idxs).sum(axis=1) 83 | print(expected_vals) 84 | # [6.66570732 1.86686378 1.01102923] 85 | ``` 86 | 87 | #### LMDeploy 88 | 89 | Install [LMDeploy](https://github.com/InternLM/lmdeploy) using `pip install lmdeploy`. 90 | 91 | ```python 92 | # Un-comment this if running in a Jupyter notebook, Colab etc. 93 | # import nest_asyncio 94 | # nest_asyncio.apply() 95 | 96 | from lmdeploy import GenerationConfig, ChatTemplateConfig, pipeline 97 | import numpy as np 98 | 99 | def make_reranker_input(t, q): 100 | return f"<<>>\n{q}\n\n<<>>\n{t}" 101 | 102 | def make_reranker_training_datum(context, question): 103 | system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related." 104 | 105 | return [ 106 | {"role": "system", "content": system_message}, 107 | {"role": "user", "content": make_reranker_input(context, question)}, 108 | ] 109 | 110 | def get_prob(logprob_dict, tok_id): 111 | return np.exp(logprob_dict[tok_id]) if tok_id in logprob_dict.keys() else 0 112 | 113 | pipe = pipeline( 114 | "lightblue/lb-reranker-v1.0", 115 | chat_template_config=ChatTemplateConfig( 116 | model_name='qwen2d5', 117 | capability='chat' 118 | ) 119 | ) 120 | tok = pipe.tokenizer.model 121 | idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)] 122 | 123 | query_texts = [ 124 | ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."), 125 | ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."), 126 | ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."), 127 | ] 128 | 129 | chats = [make_reranker_training_datum(c, q) for q, c in query_texts] 130 | responses = pipe( 131 | chats, 132 | gen_config=GenerationConfig(temperature=1.0, logprobs=14, max_new_tokens=1, do_sample=True) 133 | ) 134 | probs = np.array([[get_prob(r.logprobs[0], y) for y in idx_tokens] for r in responses]) 135 | 136 | N = probs.shape[1] 137 | M = probs.shape[0] 138 | idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N) 139 | 140 | expected_vals = (probs * idxs).sum(axis=1) 141 | print(expected_vals) 142 | # [6.66415229 1.84342025 1.01133205] 143 | ``` 144 | 145 | # Evaluation 146 | 147 | We perform an evaluation on 9 datasets from the [BEIR benchmark](https://github.com/beir-cellar/beir) that none of the evaluated models have been trained upon (to our knowledge). 148 | 149 | * Arguana 150 | * Dbpedia-entity 151 | * Fiqa 152 | * NFcorpus 153 | * Scidocs 154 | * Scifact 155 | * Trec-covid-v2 156 | * Vihealthqa 157 | * Webis-touche2020 158 | 159 | We evaluate on a subset of all queries (the first 250) to save evaluation time. 160 | 161 | We find that our model performs similarly or better than many of the state-of-the-art reranker models in our evaluation, without compromising on inference speed. 162 | 163 | We make our evaluation code and results available [on our Github](https://github.com/lightblue-tech/lb-reranker/blob/main/run_bier.ipynb). 164 | 165 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/xkNzCABFUmU7UmDXUduiz.png) 166 | 167 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/P-XCA3TGHqDSX8k6c4hCE.png) 168 | 169 | As we can see, this reranker attains greater IR evaluation metrics compared to the two benchmarks we include for all positions apart from @1. 170 | 171 | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/puhhWseBOcIyOEdW4L-B0.png) 172 | 173 | We also show that our model is, on average, faster than the BGE reranker v2. 174 | 175 | # License 176 | 177 | We share this model under an Apache 2.0 license. 178 | 179 | # Developed by 180 | 181 | 182 | Lightblue technology logo 183 | 184 | 185 | This model was trained by Peter Devine ([ptrdvn](https://huggingface.co/ptrdvn)) for Lightblue 186 | -------------------------------------------------------------------------------- /data_prep/1_collect_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, concatenate_datasets, Dataset 2 | from datasets.features.features import Features, Value, Sequence 3 | import re 4 | from tqdm.auto import tqdm 5 | import pandas as pd 6 | import ast 7 | from cryptography.fernet import Fernet 8 | import os 9 | import kaggle 10 | from kaggle.api.kaggle_api_extended import KaggleApi 11 | import zipfile 12 | from io import StringIO 13 | from functools import partial 14 | from huggingface_hub import hf_hub_download 15 | tqdm.pandas() 16 | 17 | def prepare_hotpotqa(): 18 | # Concat all relevant contexts together as our one answerable context 19 | ds = load_dataset("hotpotqa/hotpot_qa", "fullwiki", split="train", trust_remote_code=True) 20 | 21 | ds = ds.map(lambda x: { 22 | "positives": ["\n".join([t] + s) for t, s in zip(x["context"]["title"], x["context"]["sentences"]) if t in x["supporting_facts"]["title"]], 23 | "negatives": ["\n".join([t] + s) for t, s in zip(x["context"]["title"], x["context"]["sentences"]) if t not in x["supporting_facts"]["title"]], 24 | "dataset_name": "hotpot_qa", 25 | "language": "en", 26 | "doc_id": None, 27 | }) 28 | 29 | # add all Hotpot positive contexts together as questions require all contexts to be answered fully 30 | ds = ds.map(lambda x: { 31 | "positives": ["\n".join(x["positives"])], 32 | }, num_proc=32) 33 | 34 | return ds 35 | 36 | def get_trivia_qa_contexts(row): 37 | contexts = [] 38 | if len(row["entity_pages"]["wiki_context"]) > 0: 39 | for filename, title, context in zip(row["entity_pages"]["filename"], row["entity_pages"]["title"], row["entity_pages"]["wiki_context"]): 40 | contexts.append(f"{filename}\n{title}\n{context}") 41 | 42 | if len(row["search_results"]["search_context"]) > 0: 43 | for title, description, context in zip(row["search_results"]["title"], row["search_results"]["description"], row["search_results"]["search_context"]): 44 | contexts.append(f"{title}\n{description}\n{context}") 45 | return contexts 46 | 47 | def prepare_triviaqa(): 48 | 49 | ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="train") 50 | 51 | ds = ds.map(lambda x: { 52 | "answer": x["answer"]["value"], 53 | "positives": get_trivia_qa_contexts(x), 54 | "negatives": None, 55 | "dataset_name": "trivia_qa", 56 | "language": "en", 57 | "doc_id": None 58 | }, num_proc=16) 59 | 60 | return ds 61 | 62 | # This dataset is included in our MLQA implementation 63 | # def prepare_squad(): 64 | 65 | # ds = load_dataset("rajpurkar/squad", split="train") 66 | 67 | # ds = ds.map(lambda x: { 68 | # "answer": x["answers"]["text"][0], 69 | # "positives": [x["title"] + "\n" + x["context"]], 70 | # "negatives": None, 71 | # "dataset_name": "squad", 72 | # "language": "en", 73 | # "doc_id": set([x["title"]]), 74 | # }, num_proc=16) 75 | 76 | # return ds 77 | 78 | def prepare_pubmedqa(): 79 | 80 | ds = load_dataset("qiaojin/PubMedQA", "pqa_unlabeled", split="train") 81 | 82 | ds = ds.map(lambda x: { 83 | "question": x["question"], 84 | "answer": x["long_answer"], 85 | "positives": ["\n".join(x["context"]["contexts"])], 86 | "negatives": None, 87 | "dataset_name": "pubmedqa", 88 | "language": "en", 89 | "doc_id": None, 90 | }, num_proc=16) 91 | 92 | return ds 93 | 94 | def get_mldr_single_lang(lang): 95 | return load_dataset("Shitao/MLDR", lang, split="train", trust_remote_code=True).map(lambda x: { 96 | "question": x["query"], 97 | "answer": None, 98 | "positives": [y["text"] for y in x["positive_passages"]], 99 | "negatives": [y["text"] for y in x["negative_passages"]], 100 | "dataset_name": "mldr", 101 | "language": lang, 102 | "doc_id": None, 103 | }, num_proc=16) 104 | 105 | def prepare_mldr(): 106 | 107 | langs = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh'] 108 | 109 | ds = concatenate_datasets([get_mldr_single_lang(l) for l in tqdm(langs)]) 110 | 111 | return ds 112 | 113 | def get_scandi_qa_single_lang(lang): 114 | ds = load_dataset("alexandrainst/scandi-qa", lang, split="train") 115 | df = ds.to_pandas() 116 | grouped_df = df.groupby("question").apply(lambda x: { 117 | "answer": x["answers"].apply(lambda y: y["text"][0]).tolist()[0], 118 | "positives": x[ 119 | x["answers"].apply(lambda y: y["answer_start"][0] != -1) 120 | ]["context"].tolist(), 121 | "doc_id": set(x[ 122 | x["answers"].apply(lambda y: y["answer_start"][0] != -1) 123 | ]["title_en"].tolist()), 124 | "negatives": None} 125 | ) 126 | 127 | joined_df = pd.DataFrame(grouped_df.tolist()) 128 | joined_df["question"] = grouped_df.index 129 | joined_df = joined_df[["question", "answer", "positives", "negatives", "doc_id"]] 130 | joined_df["answer"] = joined_df["answer"].apply(lambda x: x if len(x) > 0 else None) 131 | joined_df = joined_df[~joined_df["answer"].isna()] 132 | joined_df["dataset_name"] = "scandi_qa" 133 | joined_df["language"] = lang 134 | 135 | return Dataset.from_pandas(joined_df) 136 | 137 | def prepare_scandiqa(): 138 | langs = ['da', 'no', 'sv'] 139 | 140 | ds = concatenate_datasets([get_scandi_qa_single_lang(l) for l in tqdm(langs)]) 141 | 142 | return ds 143 | 144 | def prepare_logqa(): 145 | ds = load_dataset( 146 | "json", 147 | data_files={ 148 | "train": "https://raw.githubusercontent.com/LogQA-dataset/LogQA/refs/heads/main/data/HDFS/qa.json.train" 149 | }, 150 | split="train" 151 | ) 152 | 153 | ds = ds.map(lambda x: { 154 | "question": x["Question"], 155 | "answer": x["Answer"], 156 | "positives": [x["RawLog"]], 157 | "negatives": None, 158 | "dataset_name": "logqa", 159 | "language": "en", 160 | "doc_id": None, 161 | }, num_proc=16) 162 | 163 | return ds 164 | 165 | def prepare_cpgqa(): 166 | ds = load_dataset( 167 | "json", 168 | data_files={ 169 | "train": "https://raw.githubusercontent.com/mmahbub/cpgQA/refs/heads/main/dataset/cpgQA-v1.0.json" 170 | }, 171 | split="train" 172 | ) 173 | 174 | ds = ds.map(lambda x: { 175 | "question": x["data"]["paragraphs"]["qas"][0]["question"], 176 | "answer": x["data"]["paragraphs"]["qas"][0]["answers"][0]["text"], 177 | "positives": [x["data"]["paragraphs"]["context"]], 178 | "negatives": None, 179 | "dataset_name": "cpgqa", 180 | "language": "en", 181 | "doc_id": None, 182 | }, num_proc=16) 183 | 184 | return ds 185 | 186 | def prepare_sleepqa(): 187 | ds = load_dataset( 188 | "json", 189 | data_files={ 190 | "train": "https://raw.githubusercontent.com/IvaBojic/SleepQA/refs/heads/main/data/training/sleep-train.json" 191 | }, 192 | split="train" 193 | ) 194 | 195 | ds = ds.map(lambda x: { 196 | "question": x["question"], 197 | "answer": x["answers"][0], 198 | "positives": [y["title"] + "\n" + y["text"] for y in x["positive_ctxs"]], 199 | "negatives": [y["title"] + "\n" + y["text"] for y in x["negative_ctxs"]], 200 | "dataset_name": "sleepqa", 201 | "language": "en", 202 | "doc_id": set([y["title"] for y in x["positive_ctxs"]]), 203 | }, num_proc=16) 204 | 205 | return ds 206 | 207 | def prepare_jqara(): 208 | ds = load_dataset("hotchpotch/JQaRA", split="dev") 209 | 210 | df = ds.to_pandas() 211 | df["text"] = df["title"] + "\n" + df["text"] 212 | 213 | grouped_series = df.groupby("question").apply(lambda x: { 214 | "answer": x["answers"].tolist()[0][0], 215 | "positives": x[x["label"] == 1]["text"].tolist(), 216 | "negatives": x[x["label"] == 0]["text"].tolist(), 217 | "doc_id": set(x[x["label"] == 1]["title"].tolist()), 218 | }) 219 | 220 | joined_df = pd.DataFrame(grouped_series.tolist()) 221 | joined_df["question"] = grouped_series.index 222 | joined_df["dataset_name"] = "jqara" 223 | joined_df["language"] = "ja" 224 | 225 | ds = Dataset.from_pandas(joined_df) 226 | 227 | return ds 228 | 229 | def get_indicqa_single_lang(lang): 230 | ds = load_dataset("ai4bharat/IndicQA", lang, split="test", trust_remote_code=True).filter(lambda x: len(x["answers"]["text"][0])) 231 | 232 | ds = ds.map(lambda x: { 233 | "question": x["question"], 234 | "answer": x["answers"]["text"][0], 235 | "positives": [x["context"]], 236 | "negatives": None, 237 | "dataset_name": "indicqa", 238 | "language": lang.split(".")[-1], 239 | "doc_id": None, 240 | }, num_proc=16) 241 | 242 | return ds 243 | 244 | def prepare_indicqa(): 245 | langs = ['indicqa.as', 'indicqa.bn', 'indicqa.gu', 'indicqa.hi', 'indicqa.kn', 'indicqa.ml', 'indicqa.mr', 'indicqa.or', 'indicqa.pa', 'indicqa.ta', 'indicqa.te'] 246 | 247 | ds = concatenate_datasets([get_indicqa_single_lang(l) for l in tqdm(langs)]) 248 | 249 | return ds 250 | 251 | def prepare_qasports(): 252 | ds = load_dataset("PedroCJardim/QASports", "all", split="train").filter( 253 | lambda x: isinstance(x["answer"], str) and len(x["answer"]) > 0, 254 | num_proc=16 255 | ) 256 | 257 | ds = ds.map(lambda x: { 258 | "question": x["question"], 259 | "answer": ast.literal_eval(x["answer"])["text"], 260 | "positives": [x["context"]], 261 | "negatives": None, 262 | "dataset_name": "qasports", 263 | "language": "en", 264 | "doc_id": set([x["context_id"]]), 265 | }, num_proc=16) 266 | 267 | return ds 268 | 269 | def prepare_lsat(): 270 | # Add multiple choice answers to question 271 | ds = concatenate_datasets([load_dataset( 272 | "json", 273 | data_files={ 274 | "train": f"https://raw.githubusercontent.com/zhongwanjun/AR-LSAT/refs/heads/main/complete_lsat_data/train_{x}.json" 275 | }, 276 | split="train" 277 | ) for x in ["ar", "lr", "rc"]]) 278 | 279 | ds = ds.map(lambda x: { 280 | "question": "\n".join([x["question"]] + x["answers"]), 281 | "answer": x["answers"][x["label"]], 282 | "positives": [x["context"]], 283 | "negatives": None, 284 | "dataset_name": "lsat", 285 | "language": "en", 286 | "doc_id": set([x["context"]]), 287 | }, num_proc=16) 288 | 289 | return ds 290 | 291 | def parse_squad(row): 292 | return { 293 | "positives": [row["context"].strip()], 294 | "question": row["question"].strip(), 295 | "answer": row["answers"]["text"][0].strip() 296 | } 297 | 298 | def prepare_m2qa(): 299 | 300 | lang_dict = { 301 | "chinese": "zh", 302 | "german": "de", 303 | "turkish": "tr", 304 | } 305 | 306 | domains = [ 307 | "creative_writing", 308 | "news", 309 | "product_reviews" 310 | ] 311 | 312 | ds_list = [] 313 | 314 | for lang in tqdm(lang_dict): 315 | ds = concatenate_datasets([ 316 | load_dataset( 317 | "UKPLab/m2qa", 318 | f"m2qa.{lang}.{x}", 319 | split="validation", 320 | trust_remote_code=True 321 | ) for x in domains]) 322 | 323 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0, num_proc=16) 324 | 325 | # Decrypt it 326 | fernet = Fernet(b"aRY0LZZb_rPnXWDSiSJn9krCYezQMOBbGII2eGkN5jo=") 327 | 328 | def decrypt(example): 329 | example["question"] = fernet.decrypt(example["question"].encode()).decode() 330 | example["context"] = fernet.decrypt(example["context"].encode()).decode() 331 | example["answers"]["text"] = [fernet.decrypt(answer.encode()).decode() for answer in example["answers"]["text"]] 332 | return example 333 | 334 | ds = ds.map(decrypt) 335 | ds = ds.map(parse_squad) 336 | ds = ds.map(lambda x: { 337 | "negatives": None, 338 | "dataset_name": "m2qa", 339 | "language": lang_dict[lang], 340 | "doc_id": set(["_".join(x["id"].split("_")[:-1])]), 341 | }) 342 | 343 | ds_list.append(ds) 344 | 345 | return concatenate_datasets(ds_list) 346 | 347 | def get_mlqa_dataset_list(): 348 | dataset_list = [ 349 | load_dataset("rajpurkar/squad", split="train").map(lambda x: {"language": "en"}, num_proc=16) 350 | ] 351 | 352 | langs = ["ar", "de", "es", "hi", "vi", "zh"] 353 | 354 | dataset_list = dataset_list + [ 355 | load_dataset( 356 | "facebook/mlqa", 357 | f"mlqa-translate-train.{l}", 358 | split="train", 359 | trust_remote_code=True 360 | ).map(lambda x: {"language": l}, num_proc=16) for l in langs 361 | ] 362 | 363 | return dataset_list 364 | 365 | def match_crossling(dataset_list, dataset_name, title_column="title"): 366 | dataset_dicts = [ 367 | { 368 | x["id"]: x for x in d 369 | } for d in tqdm(dataset_list) 370 | ] 371 | 372 | id_set = set() 373 | 374 | for d in dataset_list: 375 | id_set.update(set(d["id"])) 376 | 377 | cross_rows = [] 378 | 379 | for row_id in tqdm(id_set): 380 | rows = [x[row_id] for x in dataset_dicts if row_id in x] 381 | 382 | title = [x[title_column] for x in rows if x["language"] == "en"][0] 383 | contexts = [x["context"] for x in rows] 384 | 385 | for row in rows: 386 | cross_rows.append({ 387 | "question": row["question"], 388 | "answer": row["answers"]["text"][0], 389 | "positives": contexts, 390 | "negatives": None, 391 | "dataset_name": dataset_name, 392 | "language": row["language"], 393 | "doc_id": set([title]), 394 | }) 395 | 396 | return cross_rows 397 | 398 | def prepare_mlqa(): 399 | 400 | dataset_list = get_mlqa_dataset_list() 401 | 402 | cross_rows = match_crossling(dataset_list, "mlqa", title_column="title") 403 | 404 | return Dataset.from_pandas(pd.DataFrame(cross_rows)) 405 | 406 | def prepare_xquad(): 407 | 408 | dataset_dict = {} 409 | 410 | langs = ['ar', 'de', 'el', 'en', 'es', 'hi', 'ro', 'ru', 'th', 'tr', 'vi', 'zh'] 411 | 412 | dataset_list = [ 413 | load_dataset( 414 | "google/xquad", 415 | f"xquad.{l}", 416 | split="validation", 417 | trust_remote_code=True 418 | ).map(lambda x: {"language": l}, num_proc=16) for l in langs 419 | ] 420 | 421 | cross_rows = match_crossling(dataset_list, "xquad", title_column="context") 422 | 423 | return Dataset.from_pandas(pd.DataFrame(cross_rows)) 424 | 425 | def parse_tydi_from_bytes(text, start, end): 426 | try: 427 | return text.encode("utf-8")[start:end].decode("utf-8") 428 | except: 429 | return None 430 | 431 | def prepare_tydiqa_goldp(): 432 | 433 | ds = load_dataset("google-research-datasets/tydiqa", "primary_task", split="train").filter( 434 | lambda x: bool(x["annotations"]["minimal_answers_start_byte"][0] != -1), 435 | num_proc=16 436 | ) 437 | 438 | ds = ds.map(lambda x: { 439 | "contexts": [ 440 | parse_tydi_from_bytes(x["document_plaintext"], s, e) for s, e in zip( 441 | x["passage_answer_candidates"]["plaintext_start_byte"], 442 | x["passage_answer_candidates"]["plaintext_end_byte"] 443 | )], 444 | "answer": parse_tydi_from_bytes( 445 | x["document_plaintext"], 446 | x["annotations"]["minimal_answers_start_byte"][0], 447 | x["annotations"]["minimal_answers_end_byte"][0]), 448 | "question": x["question_text"], 449 | }, num_proc=16) 450 | 451 | ds = ds.map(lambda x: { 452 | "positives": [x["contexts"][x["annotations"]["passage_answer_candidate_index"][0]]], 453 | "negatives": [x["contexts"][i] for i in range(len(x["contexts"])) if i != x["annotations"]["passage_answer_candidate_index"][0]], 454 | }, num_proc=16) 455 | 456 | language_code_dict = { 457 | 'arabic': 'ar', 458 | 'bengali': 'bn', 459 | 'english': 'en', 460 | 'finnish': 'fi', 461 | 'indonesian': 'id', 462 | 'japanese': 'ja', 463 | 'korean': 'ko', 464 | 'russian': 'ru', 465 | 'swahili': 'sw', 466 | 'telugu': 'te', 467 | 'thai': 'th' 468 | } 469 | 470 | ds = ds.map(lambda x: { 471 | "dataset_name": "tydi", 472 | "language": language_code_dict[x["language"]], 473 | "doc_id": set([x["document_title"]]), 474 | }, num_proc=16) 475 | 476 | return ds 477 | 478 | def prepare_skquad(): 479 | ds = load_dataset("TUKE-DeutscheTelekom/skquad", split="train") 480 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16) 481 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16) 482 | ds = ds.map(parse_squad, num_proc=16) 483 | ds = ds.map(lambda x: { 484 | "negatives": None, 485 | "dataset_name": "skquad", 486 | "language": "sk", 487 | "doc_id": set([x["title"]]), 488 | }, num_proc=16) 489 | return ds 490 | 491 | def prepare_arcd(): 492 | ds = load_dataset("hsseinmz/arcd", split="train") 493 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16) 494 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16) 495 | ds = ds.map(parse_squad, num_proc=16) 496 | ds = ds.map(lambda x: { 497 | "negatives": None, 498 | "dataset_name": "arcd", 499 | "language": "ar", 500 | "doc_id": set([x["title"]]), 501 | }, num_proc=16) 502 | return ds 503 | 504 | def prepare_persianqa(): 505 | ds = load_dataset("SajjadAyoubi/persian_qa", split="train") 506 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16) 507 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16) 508 | ds = ds.map(parse_squad, num_proc=16) 509 | ds = ds.map(lambda x: { 510 | "negatives": None, 511 | "dataset_name": "persianqa", 512 | "language": "fa", 513 | "doc_id": set([x["title"]]), 514 | }, num_proc=16) 515 | return ds 516 | 517 | def prepare_amharicqa(): 518 | df = pd.read_json("https://raw.githubusercontent.com/semantic-systems/amharic-qa/main/train_data.json") 519 | df = pd.DataFrame(pd.DataFrame(df.data.tolist()).explode("paragraphs").paragraphs.tolist()).explode("qas") 520 | df["question"] = df.qas.apply(lambda x: x["question"]) 521 | df["answer"] = df.qas.apply(lambda x: x["answers"][0]["text"]) 522 | df["positives"] = df.context.apply(lambda x: [x]) 523 | ds = Dataset.from_pandas(df) 524 | ds = ds.map(lambda x: { 525 | "negatives": None, 526 | "dataset_name": "amharicqa", 527 | "language": "am", 528 | "doc_id": set([x["document_id"]]), 529 | }, num_proc=16) 530 | return ds 531 | 532 | def prepare_chaii(): 533 | api = KaggleApi() 534 | api.authenticate() 535 | api.competition_download_files('chaii-hindi-and-tamil-question-answering', path='.') 536 | 537 | zip_path = './chaii-hindi-and-tamil-question-answering.zip' 538 | 539 | with zipfile.ZipFile(zip_path, 'r') as zip_ref: 540 | with zip_ref.open("train.csv") as file: 541 | content = file.read().decode('utf-8') 542 | df = pd.read_csv(StringIO(content)) 543 | 544 | ds = Dataset.from_pandas(df) 545 | 546 | language_map = { 547 | "tamil": "ta", 548 | "hindi": "hi", 549 | } 550 | 551 | ds = ds.map(lambda x: { 552 | "question": x["question"], 553 | "answer": x["answer_text"], 554 | "positives": [x["context"]], 555 | "negatives": None, 556 | "dataset_name": "chaii", 557 | "language": language_map[x["language"]], 558 | "doc_id": set([x["id"]]), 559 | }) 560 | 561 | return ds 562 | 563 | def prepare_sberquad(): 564 | ds = load_dataset("kuznetsoffandrey/sberquad", split="train", trust_remote_code=True) 565 | ds = ds.filter(lambda x: bool(len(x["answers"]["text"]) > 0) and bool(len(x["answers"]["text"][0]) > 0), num_proc=16) 566 | ds = ds.map(parse_squad, num_proc=16) 567 | ds = ds.map(lambda x: { 568 | "negatives": None, 569 | "dataset_name": "sberquad", 570 | "language": "ru", 571 | "doc_id": set([x["id"]]), 572 | }, num_proc=16) 573 | return ds 574 | 575 | def prepare_pira(): 576 | ds = load_dataset("paulopirozelli/pira", "default", split="train") 577 | 578 | en_ds = ds.map(lambda x: { 579 | "positives": [x["abstract"].strip()], 580 | "negatives": None, 581 | "question": x["question_en_origin"].strip(), 582 | "answer": x["answer_en_origin"].strip(), 583 | "dataset_name": "pira", 584 | "language": "en", 585 | "doc_id": set([x["id_qa"]]), 586 | }, num_proc=16) 587 | pt_ds = ds.map(lambda x: { 588 | "positives": [x["abstract_translated_pt"].strip()], 589 | "negatives": None, 590 | "question": x["question_pt_origin"].strip(), 591 | "answer": x["answer_pt_origin"].strip(), 592 | "dataset_name": "pira", 593 | "language": "pt", 594 | "doc_id": set([x["id_qa"]]), 595 | }, num_proc=16) 596 | return concatenate_datasets([en_ds, pt_ds]) 597 | 598 | def parse_jsquad(row): 599 | row_data = [] 600 | title = row["title"] 601 | paragraphs = row["paragraphs"] 602 | 603 | for p in paragraphs: 604 | context = p["context"].replace("[SEP]", "\n") 605 | questions = p["qas"] 606 | 607 | for q in questions: 608 | is_impossible = q["is_impossible"] 609 | if is_impossible: 610 | continue 611 | question = q["question"] 612 | answer = q["answers"][0]["text"] 613 | 614 | row_data.append({ 615 | "question": question, 616 | "answer": answer, 617 | "positives": [context], 618 | "negatives": None, 619 | "dataset_name": "jsquad", 620 | "language": "ja", 621 | "doc_id": set([title]), 622 | }) 623 | 624 | return row_data 625 | 626 | def prepare_jsquad(): 627 | df = pd.read_json( 628 | "https://github.com/yahoojapan/JGLUE/raw/refs/heads/main/datasets/jsquad-v1.1/train-v1.1.json" 629 | ) 630 | 631 | df = pd.DataFrame(df.data.apply(parse_jsquad).explode().tolist()) 632 | ds = Dataset.from_pandas(df) 633 | ds = ds.filter(lambda x: bool( 634 | len(x["question"].strip()) > 0 635 | ) and bool( 636 | len(x["answer"].strip()) > 0 637 | ) and bool( 638 | len(x["positives"][0].strip()) > 0 639 | ), num_proc=16) 640 | return ds 641 | 642 | def prepare_korquad(): 643 | ds = load_dataset("KorQuAD/squad_kor_v1", split="train") 644 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16) 645 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16) 646 | ds = ds.map(parse_squad, num_proc=16) 647 | ds = ds.map(lambda x: { 648 | "negatives": None, 649 | "dataset_name": "korquad", 650 | "language": "ko", 651 | "doc_id": set([x["title"]]), 652 | }, num_proc=16) 653 | return ds 654 | 655 | def parse_nested(df): 656 | df = pd.DataFrame(df.data.apply(lambda x: [dict(**y, title=x["title"]) for y in x["paragraphs"]]).explode()) 657 | df = df[~df.data.isna()] 658 | df["positives"] = df.data.apply(lambda x: [x["title"] + "\n" + x["context"]]) 659 | df["data"] = df.data.apply(lambda x: [dict(**y, title=x["title"]) for y in x["qas"]]) 660 | df = df.explode("data") 661 | df["title"] = df["data"].apply(lambda x: x["title"] if isinstance(x, dict) else None) 662 | df["question"] = df["data"].apply(lambda x: x["question"] if isinstance(x, dict) else None) 663 | df["answer"] = df["data"].apply(lambda x: x["answers"][0]["text"] if isinstance(x, dict) else None) 664 | df = df.dropna() 665 | ds = Dataset.from_pandas(df) 666 | return ds 667 | 668 | def prepare_tquad(): 669 | df = pd.read_json("https://raw.githubusercontent.com/TQuad/turkish-nlp-qa-dataset/master/train-v0.1.json") 670 | ds = parse_nested(df) 671 | ds = ds.map(lambda x: { 672 | "negatives": None, 673 | "dataset_name": "tquad", 674 | "language": "tr", 675 | "doc_id": set([x["title"]]), 676 | }, num_proc=16) 677 | return ds 678 | 679 | def prepare_sqac(): 680 | df = pd.read_json("https://huggingface.co/datasets/PlanTL-GOB-ES/SQAC/resolve/main/train.json") 681 | ds = parse_nested(df) 682 | ds = ds.map(lambda x: { 683 | "negatives": None, 684 | "dataset_name": "sqac", 685 | "language": "es", 686 | "doc_id": set([x["title"]]), 687 | }, num_proc=16) 688 | return ds 689 | 690 | def prepare_germanquad(): 691 | ds = load_dataset("deepset/germanquad", split="train", trust_remote_code=True) 692 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16) 693 | ds = ds.map(parse_squad, num_proc=16) 694 | ds = ds.map(lambda x: { 695 | "negatives": None, 696 | "dataset_name": "germanquad", 697 | "language": "de", 698 | "doc_id": set([x["id"]]), 699 | }, num_proc=16) 700 | return ds 701 | 702 | def prepare_kenswquad(): 703 | # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct") 704 | # max_tok_size = 1_500 705 | 706 | ds = load_dataset("lightblue/KenSwQuAD", split="train") 707 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16) 708 | ds = ds.map(parse_squad, num_proc=16) 709 | # ds = ds.filter(lambda x: len(tokenizer.encode(x["context"])) < max_tok_size, num_proc=16) 710 | 711 | ds = ds.map(lambda x: { 712 | "negatives": None, 713 | "dataset_name": "kenswquad", 714 | "language": "sw", 715 | "doc_id": set([x["Story_ID"]]), 716 | }, num_proc=16) 717 | return ds 718 | 719 | def prepare_drcd(): 720 | ds = load_dataset("voidful/DRCD", split="train") 721 | ds = parse_nested(pd.DataFrame({"data": ds.to_list()})) 722 | ds = ds.map(lambda x: { 723 | "negatives": None, 724 | "dataset_name": "drcd", 725 | "language": "zh", 726 | "doc_id": set([x["title"]]), 727 | }, num_proc=16) 728 | return ds 729 | 730 | def prepare_narrativeqa(): 731 | 732 | ds = load_dataset("deepmind/narrativeqa", split="train") 733 | 734 | ds = ds.map( 735 | lambda x: { 736 | "positives": [x["document"]["summary"]["text"].strip()], 737 | "negatives": None, 738 | "question": x["question"]["text"].strip(), 739 | "answer": x["answers"][0]["text"], 740 | "dataset_name": "narrativeqa", 741 | "language": "en", 742 | "doc_id": set([x["document"]["summary"]["title"].strip()]), 743 | }, num_proc=16 744 | ) 745 | 746 | return ds 747 | 748 | def get_lb_rewording(text): 749 | pattern = r"### Reworded Text\n(.*)" 750 | match = re.search(pattern, text, re.DOTALL) 751 | if match: 752 | selected_text = match.group(1) 753 | return selected_text 754 | else: 755 | return None 756 | 757 | def get_lb_positives(row): 758 | positives = [] 759 | 760 | positives.append(row["selected_chunk"]) 761 | 762 | if row["rewording_finish_reason"] == "stop": 763 | reworded_context = get_lb_rewording(row["rewording_response"]) 764 | if reworded_context is not None: 765 | positives.append(reworded_context) 766 | 767 | if row["otherlang_rewording_finish_reason"] == "stop": 768 | otherlang_reworded_context = get_lb_rewording(row["otherlang_rewording_response"]) 769 | if otherlang_reworded_context is not None: 770 | positives.append(otherlang_reworded_context) 771 | 772 | return positives 773 | 774 | def prepare_lb_rag(): 775 | 776 | language_map = {'Amharic': 'am', 'Arabic': 'ar', 'Bulgarian': 'bg', 'Bengali': 'bn', 'Czech': 'cs', 'Danish': 'da', 'German': 'de', 'Greek': 'el', 'English': 'en', 'Spanish': 'es', 'Persian': 'fa', 'Finnish': 'fi', 'French': 'fr', 'Gujarati': 'gu', 'Hausa': 'ha', 'Hindi': 'hi', 'Hungarian': 'hu', 'Indonesian': 'id', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jv', 'Kannada': 'kn', 'Korean': 'ko', 'Lithuanian': 'lt', 'Marathi': 'mr', 'Dutch': 'nl', 'Norwegian': 'no', 'Polish': 'pl', 'Portuguese': 'pt', 'Romanian': 'ro', 'Russian': 'ru', 'Slovak': 'sk', 'Swedish': 'sv', 'Swahili': 'sw', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Tagalog': 'tl', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Vietnamese': 'vi', 'Yoruba': 'yo', 'Chinese': 'zh'} 777 | 778 | ds_list = [] 779 | 780 | for lang in sorted(language_map.values()): 781 | print(lang) 782 | ds = load_dataset( 783 | "lightblue/rag_multilingual_training_negatives", lang, split="train" 784 | ) 785 | 786 | # Multilingual 787 | muling_ds = ds.filter(lambda x: bool(len(x["otherlang_question"]) > 0) and bool(len(x["otherlang_answer"]) > 0) and bool(x["otherlang_qa_finish_reason"] == "stop"), num_proc=16) 788 | 789 | muling_ds = muling_ds.map( 790 | lambda x: { 791 | "question": x["otherlang_question"], 792 | "answer": x["otherlang_answer"], 793 | "positives": get_lb_positives(x), 794 | "negatives": x["multilingual_negatives"], 795 | "dataset_name": "lb_rag_multilingual", 796 | "language": language_map[x["other_qa_lang"]], 797 | "doc_id": None, 798 | }, num_proc=16 799 | ) 800 | 801 | # Monolingual 802 | moling_ds = ds.filter(lambda x: bool(len(x["question"]) > 0) and bool(len(x["answer"]) > 0) and bool(x["raw_qa_finish_reason"] == "stop"), num_proc=16) 803 | 804 | moling_ds = moling_ds.map( 805 | lambda x: { 806 | "question": x["question"], 807 | "answer": x["answer"], 808 | "positives": get_lb_positives(x), 809 | "negatives": x["monolingual_negatives"], 810 | "dataset_name": "lb_rag_monolingual", 811 | "language": x["language"], 812 | "doc_id": None, 813 | }, num_proc=16 814 | ) 815 | 816 | ds_list.append(muling_ds) 817 | ds_list.append(moling_ds) 818 | 819 | return concatenate_datasets(ds_list) 820 | 821 | def parse_mqa_text(name, text): 822 | name = "" if name is None else name 823 | text = "" if text is None else text 824 | namelower = name.lower().strip() 825 | textlower = text.lower().strip() 826 | 827 | question_text = "" 828 | question_text += name 829 | if namelower != textlower: 830 | question_text += "\n" + text 831 | 832 | question_text = re.sub(r"[\=\-\#]{3,}", "", question_text) 833 | return question_text.strip() 834 | 835 | def process_mqa(lang, data_type): 836 | answer_features = [{'downvote_count': Value(dtype='int64', id=None), 837 | 'is_accepted': Value(dtype='bool', id=None), 838 | 'name': Value(dtype='string', id=None), 839 | 'text': Value(dtype='string', id=None), 840 | 'upvote_count': Value(dtype='int64', id=None)}] 841 | 842 | question_features = {'answers': answer_features, 843 | 'comment_count': Value(dtype='int64', id=None), 844 | 'data_type': Value(dtype='string', id=None), 845 | 'downvote_count': Value(dtype='int64', id=None), 846 | 'hash': Value(dtype='string', id=None), 847 | 'name': Value(dtype='string', id=None), 848 | 'text': Value(dtype='string', id=None), 849 | 'upvote_count': Value(dtype='int64', id=None)} 850 | 851 | load_features = {'bucket': Value(dtype='float64', id=None), 852 | 'sub_bucket': Value(dtype='string', id=None), 853 | 'language': Value(dtype='string', id=None), 854 | 'hreflang_alternates': [{'href': Value(dtype='string', id=None), 855 | 'hreflang': Value(dtype='string', id=None)}], 856 | 'questions': [question_features], 857 | 'page_hash': Value(dtype='string', id=None), 858 | 'fasttext_language': Value(dtype='string', id=None), 859 | 'domain': Value(dtype='string', id=None)} 860 | 861 | filename = hf_hub_download(repo_id="clips/mqa", filename=f"data/data.{lang}.{data_type}.json.gz", repo_type="dataset") 862 | ds = load_dataset("json", data_files={"train": filename}, split="train", features=Features(load_features)) 863 | 864 | # Randomly sample at maximum 100K rows to make this processing tractable 865 | max_rows = 100_000 866 | ds = ds.shuffle().select(range(min(max_rows, len(ds)))) 867 | 868 | load_features["questions"] = question_features 869 | explode_features = Features(load_features) 870 | 871 | explode_mqa = lambda x: pd.DataFrame(dict(x)).explode("questions").to_dict(orient="list") 872 | 873 | ds = ds.map(explode_mqa, 874 | batched=True, 875 | batch_size=1000, 876 | num_proc=32, 877 | remove_columns=ds.column_names, 878 | features=explode_features) 879 | 880 | ds = ds.filter(lambda x: bool( 881 | isinstance(x["language"], str) 882 | ) and bool( 883 | isinstance(x["fasttext_language"], str) 884 | ) and bool( 885 | lang in x["language"].lower() 886 | ) and bool( 887 | lang in x["fasttext_language"].lower() 888 | ), 889 | num_proc=32 890 | ) 891 | 892 | load_features["accepted_answer"] = answer_features 893 | accepted_features = Features(load_features) 894 | 895 | ds = ds.map(lambda x: { 896 | "accepted_answer": [y for y in x["questions"]["answers"] if y["is_accepted"]], 897 | }, num_proc=32, features=accepted_features) 898 | 899 | ds = ds.filter(lambda x: len(x["accepted_answer"]) > 0, num_proc=32) 900 | 901 | ds = ds.map(lambda x: { 902 | "question": parse_mqa_text(x["questions"]["name"], x["questions"]["text"]), 903 | "answer": None, 904 | "positives": [parse_mqa_text(x["accepted_answer"][0]["name"], x["accepted_answer"][0]["text"])], 905 | "negatives": None, 906 | "dataset_name": f"mqa_{data_type}", 907 | "language": lang, 908 | "doc_id": set([x["domain"]]), 909 | }, num_proc=32) 910 | 911 | return ds 912 | 913 | def prepare_mqa(data_type): 914 | langs = [ 915 | 'af', 'als', 'am', 'an', 'ar', 'arz', 'as', 'ast', 'av', 'az', 'azb', 'ba', 'bar', 'bcl', 'be', 'bg', 'bh', 'bn', 'bo', 'bpy', 'br', 'bs', 'bxr', 'ca', 'cbk', 'ce', 'ceb', 'ckb', 'cs', 'cv', 'cy', 'da', 'de', 'diq', 'dsb', 'dty', 'dv', 'el', 'eml', 'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fr', 'frr', 'fy', 'ga', 'gd', 'gl', 'gn', 'gom', 'gu', 'gv', 'he', 'hi', 'hif', 'hr', 'hsb', 'ht', 'hu', 'hy', 'ia', 'id', 'ie', 'ilo', 'io', 'is', 'it', 'ja', 'jbo', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'krc', 'ku', 'kv', 'kw', 'ky', 'la', 'lb', 'lez', 'li', 'lmo', 'lo', 'lrc', 'lt', 'lv', 'mai', 'mg', 'mhr', 'min', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'mwl', 'my', 'myv', 'mzn', 'nah', 'nap', 'nds', 'ne', 'new', 'nl', 'nn', 'no', 'oc', 'or', 'os', 'pa', 'pam', 'pfl', 'pl', 'pms', 'pnb', 'ps', 'pt', 'qu', 'rm', 'ro', 'ru', 'sa', 'sah', 'sc', 'scn', 'sco', 'sd', 'sh', 'si', 'sk', 'sl', 'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk', 'tl', 'tr', 'tt', 'tyv', 'ug', 'uk', 'ur', 'uz', 'vec', 'vep', 'vi', 'vo', 'wa', 'war', 'wuu', 'xal', 'yi', 'yo', 'yue', 'zh' 916 | ] 917 | 918 | ds_list = [] 919 | 920 | for lang in langs: 921 | print(f">>> Starting {lang} - {data_type}") 922 | try: 923 | ds = process_mqa(lang, data_type) 924 | except Exception as e: 925 | print(e) 926 | print(f"### Skipping {lang} - {data_type}") 927 | continue 928 | ds_list.append(ds) 929 | 930 | cat_ds = concatenate_datasets(ds_list) 931 | 932 | return cat_ds 933 | 934 | def prepare_mqa_cqa(): 935 | return prepare_mqa("cqa") 936 | 937 | def prepare_mqa_faq(): 938 | return prepare_mqa("faq") 939 | 940 | if __name__ == "__main__": 941 | 942 | dataset_func_list = [ 943 | prepare_amharicqa, 944 | prepare_arcd, 945 | prepare_chaii, 946 | prepare_cpgqa, 947 | prepare_drcd, 948 | prepare_germanquad, 949 | prepare_hotpotqa, 950 | prepare_indicqa, 951 | prepare_jsquad, 952 | prepare_jqara, 953 | prepare_kenswquad, 954 | prepare_korquad, 955 | prepare_lb_rag, 956 | prepare_logqa, 957 | prepare_lsat, 958 | prepare_m2qa, 959 | prepare_mldr, 960 | prepare_mlqa, 961 | prepare_mqa_cqa, 962 | prepare_mqa_faq, 963 | prepare_narrativeqa, 964 | prepare_persianqa, 965 | prepare_pira, 966 | prepare_pubmedqa, 967 | prepare_qasports, 968 | prepare_sberquad, 969 | prepare_scandiqa, 970 | prepare_skquad, 971 | prepare_sleepqa, 972 | prepare_sqac, 973 | prepare_tquad, 974 | prepare_triviaqa, 975 | prepare_tydiqa_goldp, 976 | prepare_xquad, 977 | ] 978 | 979 | def write_name_to_file(name): 980 | with open("temp.txt", "a+") as f: 981 | f.write(str(name)) 982 | f.write("\n") 983 | return True 984 | 985 | final_features = { 986 | 'question': Value(dtype='string', id=None), 987 | 'answer': Value(dtype='string', id=None), 988 | 'positives': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 989 | 'negatives': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 990 | 'dataset_name': Value(dtype='string', id=None), 991 | 'language': Value(dtype='string', id=None), 992 | 'doc_id': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None) 993 | } 994 | 995 | required_cols = list(final_features.keys()) 996 | dataset_list = [] 997 | 998 | for x in tqdm(dataset_func_list): 999 | print(x) 1000 | write_name_to_file(x) 1001 | ds = x().select_columns(required_cols).map( 1002 | lambda x: {k:v for k, v in x.items()}, 1003 | features=Features(final_features), 1004 | num_proc=16 1005 | ) 1006 | ds.to_parquet("./data/" + str(x).split()[1] + ".parquet") 1007 | dataset_list.append(ds) 1008 | 1009 | ds = concatenate_datasets(dataset_list) 1010 | 1011 | ds.push_to_hub("lightblue/rag_datasets_collection", private=True) -------------------------------------------------------------------------------- /data_prep/2_add_negatives.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | import torch 3 | from FlagEmbedding import BGEM3FlagModel 4 | from datasets import load_dataset, concatenate_datasets 5 | from datasets.features.features import Value, Sequence 6 | import numpy as np 7 | 8 | flatten_list = lambda ll: [x for l in ll for x in l] 9 | 10 | def select_negatives(q_embedding, c_embeddings, context_df, doc_id_set, num_negs=10): 11 | sim = c_embeddings @ q_embedding 12 | 13 | context_df["sim"] = sim.cpu().numpy() 14 | 15 | context_df["is_pos"] = context_df["doc_id"].apply(lambda x: bool(set(x) & doc_id_set)) 16 | 17 | sorted_context_df = context_df.sort_values("sim", ascending=False) 18 | 19 | negatives = sorted_context_df[~sorted_context_df["is_pos"]].iloc[:num_negs].positives.tolist() 20 | 21 | return negatives 22 | 23 | def embed_text(text_list, model, max_len): 24 | return model.encode( 25 | text_list, 26 | max_length=max_len 27 | )['dense_vecs'] 28 | 29 | def send_array_to_gpu(array): 30 | return torch.Tensor(array).to(torch.device("cuda")) 31 | 32 | def mine_negatives(ds, model): 33 | 34 | if ds[0]["negatives"] is None: 35 | ds = ds.add_column("added_neg", [True] * len(ds)) 36 | else: 37 | ds = ds.add_column("added_neg", [False] * len(ds)) 38 | print("No need to mine negatives") 39 | return ds 40 | 41 | if ds[0]["doc_id"] is None: 42 | doc_ids = [set([i]) for i in range(len(ds))] 43 | ds = ds.remove_columns(["doc_id"]).add_column("doc_id", doc_ids) 44 | ds = ds.add_column("added_doc_id", [True] * len(ds)) 45 | else: 46 | ds = ds.add_column("added_doc_id", [False] * len(ds)) 47 | 48 | context_df = ds.select_columns( 49 | ["positives", "doc_id"] 50 | ).to_pandas().explode("positives").groupby("positives").doc_id.apply( 51 | lambda x: set(flatten_list(x)) 52 | ).reset_index(drop=False) 53 | 54 | context_df = context_df[~context_df.positives.isna()] 55 | context_df = context_df[context_df.positives.str.strip().str.len() > 0] 56 | 57 | if context_df.shape[0] < 1: 58 | print("Skipping because of no context") 59 | return None 60 | 61 | context_df["pos_len"] = context_df.positives.str.strip().str.len() 62 | context_df = context_df.sort_values("pos_len", ascending=False) 63 | 64 | c_embeddings = embed_text(context_df["positives"].tolist(), model, 8192) 65 | q_embeddings = embed_text(ds["question"], model, 8192) 66 | 67 | c_embeddings = send_array_to_gpu(c_embeddings) 68 | q_embeddings = send_array_to_gpu(q_embeddings) 69 | 70 | negatives_list = [] 71 | num_negs = 10 72 | 73 | for q_embedding, doc_id in tqdm(zip(q_embeddings, ds["doc_id"]), total=len(ds)): 74 | negatives = select_negatives(q_embedding, c_embeddings, context_df, set(doc_id), num_negs=num_negs) 75 | negatives_list.append(negatives) 76 | 77 | ds = ds.remove_columns(["negatives"]).add_column( 78 | "negatives", 79 | negatives_list, 80 | feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None) 81 | ) 82 | 83 | return ds 84 | 85 | def sample_dataset(ds): 86 | 87 | MAX_DATASET_LANG_ROWS = 25_000 88 | MAX_MQA_LANG_ROWS = 5_000 89 | 90 | ds = ds.filter(lambda x: isinstance(x["question"], str) and bool(len(x["question"].strip()) > 0) and bool(len(x["positives"]) > 0), num_proc=32) 91 | 92 | max_rows = MAX_MQA_LANG_ROWS if "mqa" in ds[0]["dataset_name"] else MAX_DATASET_LANG_ROWS 93 | ds = ds.shuffle().select(range(min(max_rows, len(ds)))) 94 | 95 | return ds 96 | 97 | def run_get_negatives(ds, model): 98 | 99 | lang_list = ds["language"] 100 | langs = sorted(set(lang_list)) 101 | 102 | ds_list = [] 103 | if len(langs) <= 1: 104 | lang_ds = sample_dataset(ds) 105 | ds_list = [mine_negatives(lang_ds, model)] 106 | else: 107 | lang_arr = np.array(lang_list) 108 | for lang in langs: 109 | print(lang) 110 | lang_idxs = np.where(lang == lang_arr)[0].tolist() 111 | lang_ds = ds.select(lang_idxs) 112 | lang_ds = sample_dataset(lang_ds) 113 | 114 | print(f"Length of {lang} is {len(lang_ds)}") 115 | 116 | ds_list.append(mine_negatives(lang_ds, model)) 117 | 118 | ds_list = [x for x in ds_list if x is not None] 119 | 120 | if len(ds_list) < 1: 121 | return None 122 | 123 | return concatenate_datasets(ds_list) 124 | 125 | if __name__ == "__main__": 126 | 127 | model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True) 128 | 129 | original_ds = load_dataset("lightblue/rag_datasets_collection", split="train") 130 | 131 | dataset_names = sorted(set(original_ds["dataset_name"])) 132 | 133 | ds_list = [] 134 | for dataset_name in dataset_names: 135 | 136 | print(dataset_name) 137 | single_ds = original_ds.filter(lambda x: x["dataset_name"] == dataset_name, num_proc=32) 138 | 139 | ds_w_negs = run_get_negatives(single_ds, model) 140 | ds_w_negs = ds_w_negs.map(lambda x: { 141 | "negatives": [y for y in x["negatives"] if y not in x["positives"]] 142 | }, num_proc=8) 143 | ds_w_negs = ds_w_negs.filter(lambda x: len(x["negatives"]) > 0, num_proc=32) 144 | 145 | if ds_w_negs is None: 146 | print(f"None dataset at {dataset_name}") 147 | continue 148 | 149 | ds_w_negs.to_parquet(f"./negatives/{dataset_name}.parquet") 150 | 151 | ds_list.append(ds_w_negs) 152 | 153 | concatenate_datasets(ds_list).push_to_hub("lightblue/rag_datasets_selected", private=True) 154 | -------------------------------------------------------------------------------- /data_prep/3_add_continuous_labels.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, concatenate_datasets, Dataset 2 | import random 3 | from transformers import AutoTokenizer 4 | from vllm import LLM, SamplingParams 5 | import os 6 | from multiprocessing import Pool 7 | import math 8 | import numpy as np 9 | 10 | def get_hash(example): 11 | """Get hash of question field.""" 12 | return {"q_hash": hash(example["question"])} 13 | 14 | def check_uniques(example, uniques): 15 | """Check if current hash is still in set of unique hashes and remove if true.""" 16 | if example["q_hash"] in uniques: 17 | uniques.remove(example["q_hash"]) 18 | return True 19 | else: 20 | return False 21 | 22 | def remove_duplicates(ds): 23 | ds = ds.map(get_hash, num_proc=32) 24 | uniques = set(ds.unique("q_hash")) 25 | return ds.filter(check_uniques, fn_kwargs={"uniques": uniques}) 26 | 27 | def get_ds(model_name): 28 | tokenizer = AutoTokenizer.from_pretrained(model_name) 29 | MAX_LEN_MARGIN = 512 30 | MAX_LEN = 8192 - MAX_LEN_MARGIN 31 | 32 | ds = load_dataset("lightblue/rag_datasets_selected", split="train") 33 | ds = ds.add_column("row_id", list(range(len(ds)))) 34 | ds = ds.shuffle() 35 | 36 | print("Deduplicating") 37 | ds = remove_duplicates(ds) 38 | 39 | selected_columns = ['question', 'answer', 'dataset_name', 'language', 'added_neg', 'doc_id', 'added_doc_id', 'row_id'] 40 | added_columns = ['context', 'label'] 41 | 42 | ds = ds.map(lambda x: { 43 | "positives": [p for p in x["positives"] if len(tokenizer.encode(p)) < MAX_LEN], 44 | "negatives": [n for n in x["negatives"] if len(tokenizer.encode(n)) < MAX_LEN], 45 | }, num_proc=32) 46 | 47 | ds = ds.filter(lambda x: bool(len(x["positives"]) > 0) and bool(len(x["negatives"]) > 0), num_proc=32) 48 | 49 | pos_ds = ds.select_columns(selected_columns + ["positives"]).map(lambda x: { 50 | "context": random.sample(x["positives"], k=1)[0], 51 | "label": True, 52 | }, num_proc=32).select_columns(selected_columns + added_columns) 53 | 54 | neg_ds = ds.select_columns(selected_columns + ["negatives"]).map(lambda x: { 55 | "context": random.sample(x["negatives"], k=1)[0], 56 | "label": False, 57 | }, num_proc=32).select_columns(selected_columns + added_columns) 58 | 59 | new_ds = concatenate_datasets([pos_ds, neg_ds]) 60 | 61 | return new_ds 62 | 63 | def get_prob(outputs, tok_id): 64 | if bool(len(outputs) < 1) or bool(len(outputs[0].logprobs) < 1): 65 | return 0 66 | 67 | logprob_dict = outputs[0].logprobs[0] 68 | if tok_id in logprob_dict.keys(): 69 | return np.exp(logprob_dict[tok_id].logprob) 70 | else: 71 | return 0 72 | 73 | def generate_responses(inputs): 74 | text_list, model_name, gpu_id, reverse_context_query = inputs 75 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 76 | 77 | llm = LLM(model=model_name) 78 | 79 | sampling_params = SamplingParams(temperature=0.0, max_tokens=1, logprobs=10) 80 | 81 | text_query_sys_msg_snippet_1 = "query and a piece of text" if reverse_context_query else "piece of text and a query" 82 | text_query_sys_msg_snippet_2 = "text relates to the query" if reverse_context_query else "query relates to the text" 83 | system_message = f"You are a relatedness rating assistant. Given a {text_query_sys_msg_snippet_1}, output the level that the {text_query_sys_msg_snippet_2}. Your output should be single number between 1-5, with 1 meaning completely unrelated, 2 meaning mostly unrelated, 3 meaning unsure as to whether it is related or not, 4 meaning mostly related, and 5 meaning completely related." 84 | 85 | chats = [[ 86 | {"role": "system", "content": system_message}, 87 | {"role": "user", "content": f"<<>>\n{q}\n\n\n<<>>\n{t}" if reverse_context_query else f"<<>>\n{t}\n\n\n<<>>\n{q}"}, 88 | ] for t, q in text_list] 89 | 90 | responses = llm.chat(chats, sampling_params) 91 | 92 | tok = llm.llm_engine.tokenizer.tokenizer 93 | idx_tokens = [tok.encode(str(i))[0] for i in range(1, 6)] # Get token IDs for the tokens "1", "2", "3", "4", and "5" 94 | probs = np.array([[get_prob(r.outputs, y) for y in idx_tokens] for r in responses]) 95 | return probs.tolist() 96 | 97 | def get_scores(all_texts, reverse_context_query): 98 | # Modify this to suit your number of GPUs 99 | num_gpus = 8 100 | batch_size = int(math.ceil(len(all_texts) / num_gpus)) 101 | 102 | split_texts_w_idx = [] 103 | 104 | for i in range(num_gpus): 105 | start_idx = i*batch_size 106 | end_idx = start_idx + batch_size 107 | text_batch = all_texts[start_idx:end_idx] 108 | split_texts_w_idx.append((text_batch, model_name, i, reverse_context_query)) 109 | 110 | with Pool(num_gpus) as p: 111 | scores_split = p.map(generate_responses, split_texts_w_idx) 112 | 113 | scores = [] 114 | 115 | for score_split in scores_split: 116 | scores.extend(score_split) 117 | 118 | return scores 119 | 120 | if __name__ == '__main__': 121 | model_name = "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4" 122 | new_ds = get_ds(model_name) 123 | all_texts = list(zip(new_ds["context"], new_ds["question"])) 124 | 125 | scores = get_scores(all_texts, reverse_context_query=False) 126 | new_ds = new_ds.add_column("32B_score_probs", scores) 127 | 128 | scores = get_scores(all_texts, reverse_context_query=True) 129 | new_ds = new_ds.add_column("32B_score_probs_rev", scores) 130 | 131 | new_ds = new_ds.sort("row_id") 132 | new_ds.to_parquet("lightblue__rag_datasets_selected_32B4scored_probs.parquet") 133 | new_ds.push_to_hub("lightblue/rag_datasets_selected_32B4scored_probs") 134 | -------------------------------------------------------------------------------- /data_prep/4_make_cont_bin_training_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import AutoTokenizer 3 | import numpy as np 4 | 5 | score_system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related." 6 | rev_score_system_message = "Given a piece of text and a query, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related." 7 | 8 | def filter_for_correct_scores_only(x): 9 | return bool( 10 | bool( 11 | x["label"] and bool(x["mean_exp_val_max7"] >= 4) 12 | ) or bool( 13 | bool(not x["label"]) and bool(x["mean_exp_val_max7"] <= 4) 14 | ) 15 | ) 16 | 17 | format_text_query = lambda t, q: f"<<>>\n{q}\n\n<<>>\n{t}" 18 | format_query_text = lambda t, q: f"<<>>\n{t}\n\n<<>>\n{q}" 19 | 20 | def make_continuous_data(x): 21 | return { 22 | "conversations": [ 23 | { "from": "system", "value": score_system_message }, 24 | { "from": "human", "value":format_text_query(x["context"], x["question"])}, 25 | { "from": "gpt", "value": str(int(x["mean_exp_val_max7_round"])) } ] 26 | } 27 | 28 | def make_rev_continuous_data(x): 29 | return { 30 | "rev_conversations": [ 31 | { "from": "system", "value": rev_score_system_message }, 32 | { "from": "human", "value":format_query_text(x["context"], x["question"])}, 33 | { "from": "gpt", "value": str(int(x["mean_exp_val_max7_round"])) } ] 34 | } 35 | 36 | calc_exp_val = lambda probs: sum([(i+1) * (p / sum(probs)) for i, p in enumerate(probs)]) 37 | 38 | def main(): 39 | ds = load_dataset("lightblue/rag_datasets_selected_32B4scored_probs", split="train") 40 | 41 | ds = ds.filter(lambda x: bool(sum(x["32B_score_probs"]) > 0) and bool(sum(x["32B_score_probs_rev"]) > 0), num_proc=32) 42 | 43 | ds = ds.map( 44 | lambda x: { 45 | "prob_exp_val": calc_exp_val(x["32B_score_probs"]), 46 | "rev_prob_exp_val": calc_exp_val(x["32B_score_probs_rev"]), 47 | }, num_proc=32 48 | ) 49 | 50 | ds = ds.map( 51 | lambda x: { 52 | "mean_exp_val": np.array([x["prob_exp_val"], x["rev_prob_exp_val"]]).mean(), 53 | }, num_proc=32 54 | ) 55 | 56 | ds = ds.map( 57 | lambda x: { 58 | "mean_exp_val_max7": ((x["mean_exp_val"] - 1) * (6/4)) + 1 59 | }, num_proc=32 60 | ) 61 | 62 | ds = ds.map( 63 | lambda x: { 64 | "mean_exp_val_max7_round": round(x["mean_exp_val_max7"]) 65 | }, num_proc=32 66 | ) 67 | 68 | ds = ds.map(make_continuous_data, num_proc=32) 69 | ds = ds.map(make_rev_continuous_data, num_proc=32) 70 | 71 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") 72 | 73 | MAX_LEN = 8192 74 | 75 | ds = ds.filter( 76 | lambda x: len( 77 | tokenizer.encode("\n".join([y["value"] for y in x["conversations"]])) 78 | ) < MAX_LEN, num_proc=32 79 | ) 80 | 81 | ds = ds.shuffle() 82 | 83 | ds.filter(filter_for_correct_scores_only, num_proc=32).push_to_hub( 84 | "lightblue/reranker_continuous_filt_max7_train_extra", private=True 85 | ) 86 | 87 | if __name__ == '__main__': 88 | main() -------------------------------------------------------------------------------- /training/rev_train_config.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | deepspeed: /root/LLaMA-Factory/examples/deepspeed/ds_z2_config.json 9 | 10 | ### dataset 11 | dataset: reranker_continuous_filt_max7_rev_train 12 | template: qwen 13 | cutoff_len: 18000 # NOTE - this is larger than the original reranker which was set to 8192 14 | overwrite_cache: true 15 | preprocessing_num_workers: 16 16 | packing: true 17 | 18 | ### output 19 | output_dir: /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_rev_train 20 | logging_steps: 1 21 | save_steps: 0.99999 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 1 28 | learning_rate: 1.0e-5 29 | num_train_epochs: 1.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.01 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.01 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 0.1 -------------------------------------------------------------------------------- /training/run_rev_train.sh: -------------------------------------------------------------------------------- 1 | echo '{ 2 | "reranker_continuous_filt_max7_rev_train": { 3 | "hf_hub_url": "lightblue/reranker_continuous_filt_max7_train_extra", 4 | "formatting": "sharegpt", 5 | "columns": { 6 | "messages": "rev_conversations" 7 | } 8 | } 9 | }' > /root/LLaMA-Factory/data/dataset_info.json 10 | 11 | cd /root/LLaMA-Factory && llamafactory-cli train /root/lb-reranker/training/rev_train_config.yaml 12 | 13 | rm -r /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_rev_train/checkpoint* 14 | huggingface-cli upload lightblue/reranker_0.5_cont_filt_7max_rev /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_rev_train 15 | -------------------------------------------------------------------------------- /training/run_train.sh: -------------------------------------------------------------------------------- 1 | echo '{ 2 | "reranker_continuous_filt_max7_train": { 3 | "hf_hub_url": "lightblue/reranker_continuous_filt_max7_train", 4 | "formatting": "sharegpt" 5 | } 6 | }' > /root/LLaMA-Factory/data/dataset_info.json 7 | 8 | cd /root/LLaMA-Factory && llamafactory-cli train /root/lb-reranker/training/train_config.yaml 9 | 10 | rm -r /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_train/checkpoint* 11 | huggingface-cli upload lightblue/reranker_0.5_cont_filt_7max /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_train 12 | -------------------------------------------------------------------------------- /training/train_config.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct 3 | 4 | ### method 5 | stage: sft 6 | do_train: true 7 | finetuning_type: full 8 | deepspeed: /root/LLaMA-Factory/examples/deepspeed/ds_z2_config.json 9 | 10 | ### dataset 11 | dataset: reranker_continuous_filt_max7_train 12 | template: qwen 13 | cutoff_len: 8192 14 | overwrite_cache: true 15 | preprocessing_num_workers: 16 16 | packing: true 17 | 18 | ### output 19 | output_dir: /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_train 20 | logging_steps: 1 21 | save_steps: 0.99999 22 | plot_loss: true 23 | overwrite_output_dir: true 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 1 28 | learning_rate: 1.0e-5 29 | num_train_epochs: 1.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.01 32 | bf16: true 33 | ddp_timeout: 180000000 34 | 35 | ### eval 36 | val_size: 0.01 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 0.1 --------------------------------------------------------------------------------