├── .gitignore
├── LICENSE
├── README.md
├── data_prep
│   ├── 1_collect_data.py
│   ├── 2_add_negatives.py
│   ├── 3_add_continuous_labels.py
│   └── 4_make_cont_bin_training_data.py
├── run_bier.ipynb
└── training
    ├── rev_train_config.yaml
    ├── run_rev_train.sh
    ├── run_train.sh
    └── train_config.yaml
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Lightblue
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LB Reranker v1.0
2 |
3 |
4 |
5 |
6 | The LB Reranker has been trained to determine the relatedness of a given query to a piece of text, allowing it to be used as a ranker or reranker in various retrieval-based tasks.
7 |
8 | This model is fine-tuned from a [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) model checkpoint and was trained for roughly 5.5 hours on an 8 x L20 instance ([ecs.gn8is-8x.32xlarge](https://www.alibabacloud.com/help/en/ecs/user-guide/gpu-accelerated-compute-optimized-and-vgpu-accelerated-instance-families-1)) on [Alibaba Cloud](https://www.alibabacloud.com/).
9 |
10 | The training data for this model can be found at [lightblue/reranker_continuous_filt_max7_train](https://huggingface.co/datasets/lightblue/reranker_continuous_filt_max7_train).
11 | The code for generating this data can be found in our GitHub repo [here](https://github.com/lightblue-tech/lb-reranker/tree/main/data_prep), and the code for training the model [here](https://github.com/lightblue-tech/lb-reranker/tree/main/training). Note that training was conducted using [Llama Factory](https://github.com/hiyouga/LLaMA-Factory), installed at `/root/LLaMA-Factory`; you may need to adapt the training code to your own Llama Factory set-up to replicate our training.
12 |
13 | Trained on data in over 95 languages, this model is applicable to a broad range of use cases.
14 |
15 | This model has three main benefits over comparable rerankers.
16 | 1. It has shown slightly higher performance on evaluation benchmarks.
17 | 2. It has been trained on more languages than any previous model.
18 | 3. It is a simple Causal LM model trained to output a number from "1" to "7" as a string.
19 |
20 | This last point means that this model can be used natively with many widely available inference packages, including vLLM and LMDeploy.
21 | This in turn allows our reranker to benefit from improvements to inference as and when these packages release them.
22 |
23 | # How to use
24 |
25 | The model was trained to expect an input such as:
26 |
27 | ```
28 | <<<Query>>>
29 | {your_query_here}
30 |
31 | <<<Context>>>
32 | {your_context_here}
33 | ```
34 |
35 | It then outputs a string containing a number between 1 and 7.
36 |
37 | In order to produce a continuous score that can be used for reranking query-context pairs (i.e. a score with few ties), we calculate the expected value of the scores.
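   | 
   | Concretely, if p_i is the probability that the model assigns to the output token "i" (for i = 1 to 7), the reranker score is the expectation E[score] = 1*p_1 + 2*p_2 + ... + 7*p_7, computed from the logprobs of the first generated token.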
38 |
39 | We include scripts to do this in both vLLM and LMDeploy:
40 |
41 | #### vLLM
42 |
43 | Install [vLLM](https://github.com/vllm-project/vllm/) using `pip install vllm`.
44 |
45 | ```python
46 | from vllm import LLM, SamplingParams
47 | import numpy as np
48 |
49 | def make_reranker_input(t, q):
50 |     return f"<<<Query>>>\n{q}\n\n<<<Context>>>\n{t}"
51 |
52 | def make_reranker_training_datum(context, question):
53 | system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
54 |
55 | return [
56 | {"role": "system", "content": system_message},
57 | {"role": "user", "content": make_reranker_input(context, question)},
58 | ]
59 |
60 | def get_prob(logprob_dict, tok_id):
61 | return np.exp(logprob_dict[tok_id].logprob) if tok_id in logprob_dict.keys() else 0
62 |
63 | llm = LLM("lightblue/lb-reranker-v1.0")
64 | sampling_params = SamplingParams(temperature=0.0, logprobs=14, max_tokens=1)
65 | tok = llm.llm_engine.tokenizer.tokenizer
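   | # Token IDs for the strings "1" through "7"; we read these tokens' probabilities from the returned logprobs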
66 | idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)]
67 |
68 | query_texts = [
69 | ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
70 | ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
71 | ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
72 | ]
73 |
74 | chats = [make_reranker_training_datum(c, q) for q, c in query_texts]
75 | responses = llm.chat(chats, sampling_params)
76 | probs = np.array([[get_prob(r.outputs[0].logprobs[0], y) for y in idx_tokens] for r in responses])
77 |
78 | N = probs.shape[1]
79 | M = probs.shape[0]
80 | idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N)
81 |
82 | expected_vals = (probs * idxs).sum(axis=1)
83 | print(expected_vals)
84 | # [6.66570732 1.86686378 1.01102923]
85 | ```
86 |
87 | #### LMDeploy
88 |
89 | Install [LMDeploy](https://github.com/InternLM/lmdeploy) using `pip install lmdeploy`.
90 |
91 | ```python
92 | # Un-comment this if running in a Jupyter notebook, Colab etc.
93 | # import nest_asyncio
94 | # nest_asyncio.apply()
95 |
96 | from lmdeploy import GenerationConfig, ChatTemplateConfig, pipeline
97 | import numpy as np
98 |
99 | def make_reranker_input(t, q):
100 |     return f"<<<Query>>>\n{q}\n\n<<<Context>>>\n{t}"
101 |
102 | def make_reranker_training_datum(context, question):
103 | system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
104 |
105 | return [
106 | {"role": "system", "content": system_message},
107 | {"role": "user", "content": make_reranker_input(context, question)},
108 | ]
109 |
110 | def get_prob(logprob_dict, tok_id):
111 | return np.exp(logprob_dict[tok_id]) if tok_id in logprob_dict.keys() else 0
112 |
113 | pipe = pipeline(
114 | "lightblue/lb-reranker-v1.0",
115 | chat_template_config=ChatTemplateConfig(
116 | model_name='qwen2d5',
117 | capability='chat'
118 | )
119 | )
120 | tok = pipe.tokenizer.model
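   | # As above, collect the token IDs for the strings "1" through "7"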
121 | idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)]
122 |
123 | query_texts = [
124 | ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
125 | ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
126 | ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
127 | ]
128 |
129 | chats = [make_reranker_training_datum(c, q) for q, c in query_texts]
130 | responses = pipe(
131 | chats,
132 | gen_config=GenerationConfig(temperature=1.0, logprobs=14, max_new_tokens=1, do_sample=True)
133 | )
134 | probs = np.array([[get_prob(r.logprobs[0], y) for y in idx_tokens] for r in responses])
135 |
136 | N = probs.shape[1]
137 | M = probs.shape[0]
138 | idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N)
139 |
140 | expected_vals = (probs * idxs).sum(axis=1)
141 | print(expected_vals)
142 | # [6.66415229 1.84342025 1.01133205]
143 | ```
144 |
145 | # Evaluation
146 |
147 | We perform an evaluation on 9 datasets from the [BEIR benchmark](https://github.com/beir-cellar/beir) that none of the evaluated models have been trained upon (to our knowledge).
148 |
149 | * Arguana
150 | * Dbpedia-entity
151 | * Fiqa
152 | * NFcorpus
153 | * Scidocs
154 | * Scifact
155 | * Trec-covid-v2
156 | * Vihealthqa
157 | * Webis-touche2020
158 |
159 | We evaluate on a subset of all queries (the first 250) to save evaluation time.
160 |
161 | We find that our model performs similarly to or better than many of the state-of-the-art reranker models in our evaluation, without compromising on inference speed.
162 |
163 | We make our evaluation code and results available [on our Github](https://github.com/lightblue-tech/lb-reranker/blob/main/run_bier.ipynb).
164 |
165 | *(Figure: IR evaluation metrics for the LB Reranker and comparison rerankers)*
166 | 
167 | *(Figure: IR evaluation metrics by rank cutoff position)*
168 |
169 | As the figures show, our reranker attains higher IR evaluation metrics than the two comparison rerankers at all cutoff positions apart from @1.
170 |
171 | *(Figure: average inference speed of the LB Reranker compared with the BGE reranker v2)*
172 |
173 | We also show that our model is, on average, faster than the BGE reranker v2.
174 |
175 | # License
176 |
177 | We share this model under an Apache 2.0 license.
178 |
179 | # Developed by
180 |
181 |
182 |
183 |
184 |
185 | This model was trained by Peter Devine ([ptrdvn](https://huggingface.co/ptrdvn)) for Lightblue.
186 |
--------------------------------------------------------------------------------
/data_prep/1_collect_data.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, concatenate_datasets, Dataset
2 | from datasets.features.features import Features, Value, Sequence
3 | import re
4 | from tqdm.auto import tqdm
5 | import pandas as pd
6 | import ast
7 | from cryptography.fernet import Fernet
8 | import os
9 | import kaggle
10 | from kaggle.api.kaggle_api_extended import KaggleApi
11 | import zipfile
12 | from io import StringIO
13 | from functools import partial
14 | from huggingface_hub import hf_hub_download
15 | tqdm.pandas()
16 |
17 | def prepare_hotpotqa():
18 | # Concat all relevant contexts together as our one answerable context
19 | ds = load_dataset("hotpotqa/hotpot_qa", "fullwiki", split="train", trust_remote_code=True)
20 |
21 | ds = ds.map(lambda x: {
22 | "positives": ["\n".join([t] + s) for t, s in zip(x["context"]["title"], x["context"]["sentences"]) if t in x["supporting_facts"]["title"]],
23 | "negatives": ["\n".join([t] + s) for t, s in zip(x["context"]["title"], x["context"]["sentences"]) if t not in x["supporting_facts"]["title"]],
24 | "dataset_name": "hotpot_qa",
25 | "language": "en",
26 | "doc_id": None,
27 | })
28 |
29 |     # Join all HotpotQA positive contexts into one, since its questions need every supporting context to be answered fully
30 | ds = ds.map(lambda x: {
31 | "positives": ["\n".join(x["positives"])],
32 | }, num_proc=32)
33 |
34 | return ds
35 |
36 | def get_trivia_qa_contexts(row):
37 | contexts = []
38 | if len(row["entity_pages"]["wiki_context"]) > 0:
39 | for filename, title, context in zip(row["entity_pages"]["filename"], row["entity_pages"]["title"], row["entity_pages"]["wiki_context"]):
40 | contexts.append(f"{filename}\n{title}\n{context}")
41 |
42 | if len(row["search_results"]["search_context"]) > 0:
43 | for title, description, context in zip(row["search_results"]["title"], row["search_results"]["description"], row["search_results"]["search_context"]):
44 | contexts.append(f"{title}\n{description}\n{context}")
45 | return contexts
46 |
47 | def prepare_triviaqa():
48 |
49 | ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
50 |
51 | ds = ds.map(lambda x: {
52 | "answer": x["answer"]["value"],
53 | "positives": get_trivia_qa_contexts(x),
54 | "negatives": None,
55 | "dataset_name": "trivia_qa",
56 | "language": "en",
57 | "doc_id": None
58 | }, num_proc=16)
59 |
60 | return ds
61 |
62 | # This dataset is included in our MLQA implementation
63 | # def prepare_squad():
64 |
65 | # ds = load_dataset("rajpurkar/squad", split="train")
66 |
67 | # ds = ds.map(lambda x: {
68 | # "answer": x["answers"]["text"][0],
69 | # "positives": [x["title"] + "\n" + x["context"]],
70 | # "negatives": None,
71 | # "dataset_name": "squad",
72 | # "language": "en",
73 | # "doc_id": set([x["title"]]),
74 | # }, num_proc=16)
75 |
76 | # return ds
77 |
78 | def prepare_pubmedqa():
79 |
80 | ds = load_dataset("qiaojin/PubMedQA", "pqa_unlabeled", split="train")
81 |
82 | ds = ds.map(lambda x: {
83 | "question": x["question"],
84 | "answer": x["long_answer"],
85 | "positives": ["\n".join(x["context"]["contexts"])],
86 | "negatives": None,
87 | "dataset_name": "pubmedqa",
88 | "language": "en",
89 | "doc_id": None,
90 | }, num_proc=16)
91 |
92 | return ds
93 |
94 | def get_mldr_single_lang(lang):
95 | return load_dataset("Shitao/MLDR", lang, split="train", trust_remote_code=True).map(lambda x: {
96 | "question": x["query"],
97 | "answer": None,
98 | "positives": [y["text"] for y in x["positive_passages"]],
99 | "negatives": [y["text"] for y in x["negative_passages"]],
100 | "dataset_name": "mldr",
101 | "language": lang,
102 | "doc_id": None,
103 | }, num_proc=16)
104 |
105 | def prepare_mldr():
106 |
107 | langs = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
108 |
109 | ds = concatenate_datasets([get_mldr_single_lang(l) for l in tqdm(langs)])
110 |
111 | return ds
112 |
113 | def get_scandi_qa_single_lang(lang):
114 | ds = load_dataset("alexandrainst/scandi-qa", lang, split="train")
115 | df = ds.to_pandas()
116 | grouped_df = df.groupby("question").apply(lambda x: {
117 | "answer": x["answers"].apply(lambda y: y["text"][0]).tolist()[0],
118 | "positives": x[
119 | x["answers"].apply(lambda y: y["answer_start"][0] != -1)
120 | ]["context"].tolist(),
121 | "doc_id": set(x[
122 | x["answers"].apply(lambda y: y["answer_start"][0] != -1)
123 | ]["title_en"].tolist()),
124 | "negatives": None}
125 | )
126 |
127 | joined_df = pd.DataFrame(grouped_df.tolist())
128 | joined_df["question"] = grouped_df.index
129 | joined_df = joined_df[["question", "answer", "positives", "negatives", "doc_id"]]
130 | joined_df["answer"] = joined_df["answer"].apply(lambda x: x if len(x) > 0 else None)
131 | joined_df = joined_df[~joined_df["answer"].isna()]
132 | joined_df["dataset_name"] = "scandi_qa"
133 | joined_df["language"] = lang
134 |
135 | return Dataset.from_pandas(joined_df)
136 |
137 | def prepare_scandiqa():
138 | langs = ['da', 'no', 'sv']
139 |
140 | ds = concatenate_datasets([get_scandi_qa_single_lang(l) for l in tqdm(langs)])
141 |
142 | return ds
143 |
144 | def prepare_logqa():
145 | ds = load_dataset(
146 | "json",
147 | data_files={
148 | "train": "https://raw.githubusercontent.com/LogQA-dataset/LogQA/refs/heads/main/data/HDFS/qa.json.train"
149 | },
150 | split="train"
151 | )
152 |
153 | ds = ds.map(lambda x: {
154 | "question": x["Question"],
155 | "answer": x["Answer"],
156 | "positives": [x["RawLog"]],
157 | "negatives": None,
158 | "dataset_name": "logqa",
159 | "language": "en",
160 | "doc_id": None,
161 | }, num_proc=16)
162 |
163 | return ds
164 |
165 | def prepare_cpgqa():
166 | ds = load_dataset(
167 | "json",
168 | data_files={
169 | "train": "https://raw.githubusercontent.com/mmahbub/cpgQA/refs/heads/main/dataset/cpgQA-v1.0.json"
170 | },
171 | split="train"
172 | )
173 |
174 | ds = ds.map(lambda x: {
175 | "question": x["data"]["paragraphs"]["qas"][0]["question"],
176 | "answer": x["data"]["paragraphs"]["qas"][0]["answers"][0]["text"],
177 | "positives": [x["data"]["paragraphs"]["context"]],
178 | "negatives": None,
179 | "dataset_name": "cpgqa",
180 | "language": "en",
181 | "doc_id": None,
182 | }, num_proc=16)
183 |
184 | return ds
185 |
186 | def prepare_sleepqa():
187 | ds = load_dataset(
188 | "json",
189 | data_files={
190 | "train": "https://raw.githubusercontent.com/IvaBojic/SleepQA/refs/heads/main/data/training/sleep-train.json"
191 | },
192 | split="train"
193 | )
194 |
195 | ds = ds.map(lambda x: {
196 | "question": x["question"],
197 | "answer": x["answers"][0],
198 | "positives": [y["title"] + "\n" + y["text"] for y in x["positive_ctxs"]],
199 | "negatives": [y["title"] + "\n" + y["text"] for y in x["negative_ctxs"]],
200 | "dataset_name": "sleepqa",
201 | "language": "en",
202 | "doc_id": set([y["title"] for y in x["positive_ctxs"]]),
203 | }, num_proc=16)
204 |
205 | return ds
206 |
207 | def prepare_jqara():
208 | ds = load_dataset("hotchpotch/JQaRA", split="dev")
209 |
210 | df = ds.to_pandas()
211 | df["text"] = df["title"] + "\n" + df["text"]
212 |
213 | grouped_series = df.groupby("question").apply(lambda x: {
214 | "answer": x["answers"].tolist()[0][0],
215 | "positives": x[x["label"] == 1]["text"].tolist(),
216 | "negatives": x[x["label"] == 0]["text"].tolist(),
217 | "doc_id": set(x[x["label"] == 1]["title"].tolist()),
218 | })
219 |
220 | joined_df = pd.DataFrame(grouped_series.tolist())
221 | joined_df["question"] = grouped_series.index
222 | joined_df["dataset_name"] = "jqara"
223 | joined_df["language"] = "ja"
224 |
225 | ds = Dataset.from_pandas(joined_df)
226 |
227 | return ds
228 |
229 | def get_indicqa_single_lang(lang):
230 | ds = load_dataset("ai4bharat/IndicQA", lang, split="test", trust_remote_code=True).filter(lambda x: len(x["answers"]["text"][0]))
231 |
232 | ds = ds.map(lambda x: {
233 | "question": x["question"],
234 | "answer": x["answers"]["text"][0],
235 | "positives": [x["context"]],
236 | "negatives": None,
237 | "dataset_name": "indicqa",
238 | "language": lang.split(".")[-1],
239 | "doc_id": None,
240 | }, num_proc=16)
241 |
242 | return ds
243 |
244 | def prepare_indicqa():
245 | langs = ['indicqa.as', 'indicqa.bn', 'indicqa.gu', 'indicqa.hi', 'indicqa.kn', 'indicqa.ml', 'indicqa.mr', 'indicqa.or', 'indicqa.pa', 'indicqa.ta', 'indicqa.te']
246 |
247 | ds = concatenate_datasets([get_indicqa_single_lang(l) for l in tqdm(langs)])
248 |
249 | return ds
250 |
251 | def prepare_qasports():
252 | ds = load_dataset("PedroCJardim/QASports", "all", split="train").filter(
253 | lambda x: isinstance(x["answer"], str) and len(x["answer"]) > 0,
254 | num_proc=16
255 | )
256 |
257 | ds = ds.map(lambda x: {
258 | "question": x["question"],
259 | "answer": ast.literal_eval(x["answer"])["text"],
260 | "positives": [x["context"]],
261 | "negatives": None,
262 | "dataset_name": "qasports",
263 | "language": "en",
264 | "doc_id": set([x["context_id"]]),
265 | }, num_proc=16)
266 |
267 | return ds
268 |
269 | def prepare_lsat():
270 | # Add multiple choice answers to question
271 | ds = concatenate_datasets([load_dataset(
272 | "json",
273 | data_files={
274 | "train": f"https://raw.githubusercontent.com/zhongwanjun/AR-LSAT/refs/heads/main/complete_lsat_data/train_{x}.json"
275 | },
276 | split="train"
277 | ) for x in ["ar", "lr", "rc"]])
278 |
279 | ds = ds.map(lambda x: {
280 | "question": "\n".join([x["question"]] + x["answers"]),
281 | "answer": x["answers"][x["label"]],
282 | "positives": [x["context"]],
283 | "negatives": None,
284 | "dataset_name": "lsat",
285 | "language": "en",
286 | "doc_id": set([x["context"]]),
287 | }, num_proc=16)
288 |
289 | return ds
290 |
291 | def parse_squad(row):
292 | return {
293 | "positives": [row["context"].strip()],
294 | "question": row["question"].strip(),
295 | "answer": row["answers"]["text"][0].strip()
296 | }
297 |
298 | def prepare_m2qa():
299 |
300 | lang_dict = {
301 | "chinese": "zh",
302 | "german": "de",
303 | "turkish": "tr",
304 | }
305 |
306 | domains = [
307 | "creative_writing",
308 | "news",
309 | "product_reviews"
310 | ]
311 |
312 | ds_list = []
313 |
314 | for lang in tqdm(lang_dict):
315 | ds = concatenate_datasets([
316 | load_dataset(
317 | "UKPLab/m2qa",
318 | f"m2qa.{lang}.{x}",
319 | split="validation",
320 | trust_remote_code=True
321 | ) for x in domains])
322 |
323 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0, num_proc=16)
324 |
325 | # Decrypt it
326 | fernet = Fernet(b"aRY0LZZb_rPnXWDSiSJn9krCYezQMOBbGII2eGkN5jo=")
327 |
328 | def decrypt(example):
329 | example["question"] = fernet.decrypt(example["question"].encode()).decode()
330 | example["context"] = fernet.decrypt(example["context"].encode()).decode()
331 | example["answers"]["text"] = [fernet.decrypt(answer.encode()).decode() for answer in example["answers"]["text"]]
332 | return example
333 |
334 | ds = ds.map(decrypt)
335 | ds = ds.map(parse_squad)
336 | ds = ds.map(lambda x: {
337 | "negatives": None,
338 | "dataset_name": "m2qa",
339 | "language": lang_dict[lang],
340 | "doc_id": set(["_".join(x["id"].split("_")[:-1])]),
341 | })
342 |
343 | ds_list.append(ds)
344 |
345 | return concatenate_datasets(ds_list)
346 |
347 | def get_mlqa_dataset_list():
348 | dataset_list = [
349 | load_dataset("rajpurkar/squad", split="train").map(lambda x: {"language": "en"}, num_proc=16)
350 | ]
351 |
352 | langs = ["ar", "de", "es", "hi", "vi", "zh"]
353 |
354 | dataset_list = dataset_list + [
355 | load_dataset(
356 | "facebook/mlqa",
357 | f"mlqa-translate-train.{l}",
358 | split="train",
359 | trust_remote_code=True
360 | ).map(lambda x: {"language": l}, num_proc=16) for l in langs
361 | ]
362 |
363 | return dataset_list
364 |
365 | def match_crossling(dataset_list, dataset_name, title_column="title"):
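   |     # Match rows across languages by their shared "id": the contexts from every language
   |     # are pooled as positives for each question, with the English title used as the doc_id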
366 | dataset_dicts = [
367 | {
368 | x["id"]: x for x in d
369 | } for d in tqdm(dataset_list)
370 | ]
371 |
372 | id_set = set()
373 |
374 | for d in dataset_list:
375 | id_set.update(set(d["id"]))
376 |
377 | cross_rows = []
378 |
379 | for row_id in tqdm(id_set):
380 | rows = [x[row_id] for x in dataset_dicts if row_id in x]
381 |
382 | title = [x[title_column] for x in rows if x["language"] == "en"][0]
383 | contexts = [x["context"] for x in rows]
384 |
385 | for row in rows:
386 | cross_rows.append({
387 | "question": row["question"],
388 | "answer": row["answers"]["text"][0],
389 | "positives": contexts,
390 | "negatives": None,
391 | "dataset_name": dataset_name,
392 | "language": row["language"],
393 | "doc_id": set([title]),
394 | })
395 |
396 | return cross_rows
397 |
398 | def prepare_mlqa():
399 |
400 | dataset_list = get_mlqa_dataset_list()
401 |
402 | cross_rows = match_crossling(dataset_list, "mlqa", title_column="title")
403 |
404 | return Dataset.from_pandas(pd.DataFrame(cross_rows))
405 |
406 | def prepare_xquad():
407 |
408 | dataset_dict = {}
409 |
410 | langs = ['ar', 'de', 'el', 'en', 'es', 'hi', 'ro', 'ru', 'th', 'tr', 'vi', 'zh']
411 |
412 | dataset_list = [
413 | load_dataset(
414 | "google/xquad",
415 | f"xquad.{l}",
416 | split="validation",
417 | trust_remote_code=True
418 | ).map(lambda x: {"language": l}, num_proc=16) for l in langs
419 | ]
420 |
421 | cross_rows = match_crossling(dataset_list, "xquad", title_column="context")
422 |
423 | return Dataset.from_pandas(pd.DataFrame(cross_rows))
424 |
425 | def parse_tydi_from_bytes(text, start, end):
426 | try:
427 | return text.encode("utf-8")[start:end].decode("utf-8")
428 |     except Exception:  # decoding fails when the byte offsets split a multi-byte character
429 | return None
430 |
431 | def prepare_tydiqa_goldp():
432 |
433 | ds = load_dataset("google-research-datasets/tydiqa", "primary_task", split="train").filter(
434 | lambda x: bool(x["annotations"]["minimal_answers_start_byte"][0] != -1),
435 | num_proc=16
436 | )
437 |
438 | ds = ds.map(lambda x: {
439 | "contexts": [
440 | parse_tydi_from_bytes(x["document_plaintext"], s, e) for s, e in zip(
441 | x["passage_answer_candidates"]["plaintext_start_byte"],
442 | x["passage_answer_candidates"]["plaintext_end_byte"]
443 | )],
444 | "answer": parse_tydi_from_bytes(
445 | x["document_plaintext"],
446 | x["annotations"]["minimal_answers_start_byte"][0],
447 | x["annotations"]["minimal_answers_end_byte"][0]),
448 | "question": x["question_text"],
449 | }, num_proc=16)
450 |
451 | ds = ds.map(lambda x: {
452 | "positives": [x["contexts"][x["annotations"]["passage_answer_candidate_index"][0]]],
453 | "negatives": [x["contexts"][i] for i in range(len(x["contexts"])) if i != x["annotations"]["passage_answer_candidate_index"][0]],
454 | }, num_proc=16)
455 |
456 | language_code_dict = {
457 | 'arabic': 'ar',
458 | 'bengali': 'bn',
459 | 'english': 'en',
460 | 'finnish': 'fi',
461 | 'indonesian': 'id',
462 | 'japanese': 'ja',
463 | 'korean': 'ko',
464 | 'russian': 'ru',
465 | 'swahili': 'sw',
466 | 'telugu': 'te',
467 | 'thai': 'th'
468 | }
469 |
470 | ds = ds.map(lambda x: {
471 | "dataset_name": "tydi",
472 | "language": language_code_dict[x["language"]],
473 | "doc_id": set([x["document_title"]]),
474 | }, num_proc=16)
475 |
476 | return ds
477 |
478 | def prepare_skquad():
479 | ds = load_dataset("TUKE-DeutscheTelekom/skquad", split="train")
480 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16)
481 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16)
482 | ds = ds.map(parse_squad, num_proc=16)
483 | ds = ds.map(lambda x: {
484 | "negatives": None,
485 | "dataset_name": "skquad",
486 | "language": "sk",
487 | "doc_id": set([x["title"]]),
488 | }, num_proc=16)
489 | return ds
490 |
491 | def prepare_arcd():
492 | ds = load_dataset("hsseinmz/arcd", split="train")
493 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16)
494 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16)
495 | ds = ds.map(parse_squad, num_proc=16)
496 | ds = ds.map(lambda x: {
497 | "negatives": None,
498 | "dataset_name": "arcd",
499 | "language": "ar",
500 | "doc_id": set([x["title"]]),
501 | }, num_proc=16)
502 | return ds
503 |
504 | def prepare_persianqa():
505 | ds = load_dataset("SajjadAyoubi/persian_qa", split="train")
506 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16)
507 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16)
508 | ds = ds.map(parse_squad, num_proc=16)
509 | ds = ds.map(lambda x: {
510 | "negatives": None,
511 | "dataset_name": "persianqa",
512 | "language": "fa",
513 | "doc_id": set([x["title"]]),
514 | }, num_proc=16)
515 | return ds
516 |
517 | def prepare_amharicqa():
518 | df = pd.read_json("https://raw.githubusercontent.com/semantic-systems/amharic-qa/main/train_data.json")
519 | df = pd.DataFrame(pd.DataFrame(df.data.tolist()).explode("paragraphs").paragraphs.tolist()).explode("qas")
520 | df["question"] = df.qas.apply(lambda x: x["question"])
521 | df["answer"] = df.qas.apply(lambda x: x["answers"][0]["text"])
522 | df["positives"] = df.context.apply(lambda x: [x])
523 | ds = Dataset.from_pandas(df)
524 | ds = ds.map(lambda x: {
525 | "negatives": None,
526 | "dataset_name": "amharicqa",
527 | "language": "am",
528 | "doc_id": set([x["document_id"]]),
529 | }, num_proc=16)
530 | return ds
531 |
532 | def prepare_chaii():
533 | api = KaggleApi()
534 | api.authenticate()
535 | api.competition_download_files('chaii-hindi-and-tamil-question-answering', path='.')
536 |
537 | zip_path = './chaii-hindi-and-tamil-question-answering.zip'
538 |
539 | with zipfile.ZipFile(zip_path, 'r') as zip_ref:
540 | with zip_ref.open("train.csv") as file:
541 | content = file.read().decode('utf-8')
542 | df = pd.read_csv(StringIO(content))
543 |
544 | ds = Dataset.from_pandas(df)
545 |
546 | language_map = {
547 | "tamil": "ta",
548 | "hindi": "hi",
549 | }
550 |
551 | ds = ds.map(lambda x: {
552 | "question": x["question"],
553 | "answer": x["answer_text"],
554 | "positives": [x["context"]],
555 | "negatives": None,
556 | "dataset_name": "chaii",
557 | "language": language_map[x["language"]],
558 | "doc_id": set([x["id"]]),
559 | })
560 |
561 | return ds
562 |
563 | def prepare_sberquad():
564 | ds = load_dataset("kuznetsoffandrey/sberquad", split="train", trust_remote_code=True)
565 | ds = ds.filter(lambda x: bool(len(x["answers"]["text"]) > 0) and bool(len(x["answers"]["text"][0]) > 0), num_proc=16)
566 | ds = ds.map(parse_squad, num_proc=16)
567 | ds = ds.map(lambda x: {
568 | "negatives": None,
569 | "dataset_name": "sberquad",
570 | "language": "ru",
571 | "doc_id": set([x["id"]]),
572 | }, num_proc=16)
573 | return ds
574 |
575 | def prepare_pira():
576 | ds = load_dataset("paulopirozelli/pira", "default", split="train")
577 |
578 | en_ds = ds.map(lambda x: {
579 | "positives": [x["abstract"].strip()],
580 | "negatives": None,
581 | "question": x["question_en_origin"].strip(),
582 | "answer": x["answer_en_origin"].strip(),
583 | "dataset_name": "pira",
584 | "language": "en",
585 | "doc_id": set([x["id_qa"]]),
586 | }, num_proc=16)
587 | pt_ds = ds.map(lambda x: {
588 | "positives": [x["abstract_translated_pt"].strip()],
589 | "negatives": None,
590 | "question": x["question_pt_origin"].strip(),
591 | "answer": x["answer_pt_origin"].strip(),
592 | "dataset_name": "pira",
593 | "language": "pt",
594 | "doc_id": set([x["id_qa"]]),
595 | }, num_proc=16)
596 | return concatenate_datasets([en_ds, pt_ds])
597 |
598 | def parse_jsquad(row):
599 | row_data = []
600 | title = row["title"]
601 | paragraphs = row["paragraphs"]
602 |
603 | for p in paragraphs:
604 | context = p["context"].replace("[SEP]", "\n")
605 | questions = p["qas"]
606 |
607 | for q in questions:
608 | is_impossible = q["is_impossible"]
609 | if is_impossible:
610 | continue
611 | question = q["question"]
612 | answer = q["answers"][0]["text"]
613 |
614 | row_data.append({
615 | "question": question,
616 | "answer": answer,
617 | "positives": [context],
618 | "negatives": None,
619 | "dataset_name": "jsquad",
620 | "language": "ja",
621 | "doc_id": set([title]),
622 | })
623 |
624 | return row_data
625 |
626 | def prepare_jsquad():
627 | df = pd.read_json(
628 | "https://github.com/yahoojapan/JGLUE/raw/refs/heads/main/datasets/jsquad-v1.1/train-v1.1.json"
629 | )
630 |
631 | df = pd.DataFrame(df.data.apply(parse_jsquad).explode().tolist())
632 | ds = Dataset.from_pandas(df)
633 | ds = ds.filter(lambda x: bool(
634 | len(x["question"].strip()) > 0
635 | ) and bool(
636 | len(x["answer"].strip()) > 0
637 | ) and bool(
638 | len(x["positives"][0].strip()) > 0
639 | ), num_proc=16)
640 | return ds
641 |
642 | def prepare_korquad():
643 | ds = load_dataset("KorQuAD/squad_kor_v1", split="train")
644 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16)
645 | ds = ds.map(lambda x: {"context": x["title"] + "\n" + x["context"]}, num_proc=16)
646 | ds = ds.map(parse_squad, num_proc=16)
647 | ds = ds.map(lambda x: {
648 | "negatives": None,
649 | "dataset_name": "korquad",
650 | "language": "ko",
651 | "doc_id": set([x["title"]]),
652 | }, num_proc=16)
653 | return ds
654 |
655 | def parse_nested(df):
656 | df = pd.DataFrame(df.data.apply(lambda x: [dict(**y, title=x["title"]) for y in x["paragraphs"]]).explode())
657 | df = df[~df.data.isna()]
658 | df["positives"] = df.data.apply(lambda x: [x["title"] + "\n" + x["context"]])
659 | df["data"] = df.data.apply(lambda x: [dict(**y, title=x["title"]) for y in x["qas"]])
660 | df = df.explode("data")
661 | df["title"] = df["data"].apply(lambda x: x["title"] if isinstance(x, dict) else None)
662 | df["question"] = df["data"].apply(lambda x: x["question"] if isinstance(x, dict) else None)
663 | df["answer"] = df["data"].apply(lambda x: x["answers"][0]["text"] if isinstance(x, dict) else None)
664 | df = df.dropna()
665 | ds = Dataset.from_pandas(df)
666 | return ds
667 |
668 | def prepare_tquad():
669 | df = pd.read_json("https://raw.githubusercontent.com/TQuad/turkish-nlp-qa-dataset/master/train-v0.1.json")
670 | ds = parse_nested(df)
671 | ds = ds.map(lambda x: {
672 | "negatives": None,
673 | "dataset_name": "tquad",
674 | "language": "tr",
675 | "doc_id": set([x["title"]]),
676 | }, num_proc=16)
677 | return ds
678 |
679 | def prepare_sqac():
680 | df = pd.read_json("https://huggingface.co/datasets/PlanTL-GOB-ES/SQAC/resolve/main/train.json")
681 | ds = parse_nested(df)
682 | ds = ds.map(lambda x: {
683 | "negatives": None,
684 | "dataset_name": "sqac",
685 | "language": "es",
686 | "doc_id": set([x["title"]]),
687 | }, num_proc=16)
688 | return ds
689 |
690 | def prepare_germanquad():
691 | ds = load_dataset("deepset/germanquad", split="train", trust_remote_code=True)
692 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16)
693 | ds = ds.map(parse_squad, num_proc=16)
694 | ds = ds.map(lambda x: {
695 | "negatives": None,
696 | "dataset_name": "germanquad",
697 | "language": "de",
698 | "doc_id": set([x["id"]]),
699 | }, num_proc=16)
700 | return ds
701 |
702 | def prepare_kenswquad():
703 | # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
704 | # max_tok_size = 1_500
705 |
706 | ds = load_dataset("lightblue/KenSwQuAD", split="train")
707 | ds = ds.filter(lambda x: len(x["answers"]["text"]) > 0 and len(x["answers"]["text"][0].strip()) > 0, num_proc=16)
708 | ds = ds.map(parse_squad, num_proc=16)
709 | # ds = ds.filter(lambda x: len(tokenizer.encode(x["context"])) < max_tok_size, num_proc=16)
710 |
711 | ds = ds.map(lambda x: {
712 | "negatives": None,
713 | "dataset_name": "kenswquad",
714 | "language": "sw",
715 | "doc_id": set([x["Story_ID"]]),
716 | }, num_proc=16)
717 | return ds
718 |
719 | def prepare_drcd():
720 | ds = load_dataset("voidful/DRCD", split="train")
721 | ds = parse_nested(pd.DataFrame({"data": ds.to_list()}))
722 | ds = ds.map(lambda x: {
723 | "negatives": None,
724 | "dataset_name": "drcd",
725 | "language": "zh",
726 | "doc_id": set([x["title"]]),
727 | }, num_proc=16)
728 | return ds
729 |
730 | def prepare_narrativeqa():
731 |
732 | ds = load_dataset("deepmind/narrativeqa", split="train")
733 |
734 | ds = ds.map(
735 | lambda x: {
736 | "positives": [x["document"]["summary"]["text"].strip()],
737 | "negatives": None,
738 | "question": x["question"]["text"].strip(),
739 | "answer": x["answers"][0]["text"],
740 | "dataset_name": "narrativeqa",
741 | "language": "en",
742 | "doc_id": set([x["document"]["summary"]["title"].strip()]),
743 | }, num_proc=16
744 | )
745 |
746 | return ds
747 |
748 | def get_lb_rewording(text):
749 | pattern = r"### Reworded Text\n(.*)"
750 | match = re.search(pattern, text, re.DOTALL)
751 | if match:
752 | selected_text = match.group(1)
753 | return selected_text
754 | else:
755 | return None
756 |
757 | def get_lb_positives(row):
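   |     # Positives are the originally selected chunk plus any successfully generated
   |     # rewordings of it (same-language and other-language)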
758 | positives = []
759 |
760 | positives.append(row["selected_chunk"])
761 |
762 | if row["rewording_finish_reason"] == "stop":
763 | reworded_context = get_lb_rewording(row["rewording_response"])
764 | if reworded_context is not None:
765 | positives.append(reworded_context)
766 |
767 | if row["otherlang_rewording_finish_reason"] == "stop":
768 | otherlang_reworded_context = get_lb_rewording(row["otherlang_rewording_response"])
769 | if otherlang_reworded_context is not None:
770 | positives.append(otherlang_reworded_context)
771 |
772 | return positives
773 |
774 | def prepare_lb_rag():
775 |
776 | language_map = {'Amharic': 'am', 'Arabic': 'ar', 'Bulgarian': 'bg', 'Bengali': 'bn', 'Czech': 'cs', 'Danish': 'da', 'German': 'de', 'Greek': 'el', 'English': 'en', 'Spanish': 'es', 'Persian': 'fa', 'Finnish': 'fi', 'French': 'fr', 'Gujarati': 'gu', 'Hausa': 'ha', 'Hindi': 'hi', 'Hungarian': 'hu', 'Indonesian': 'id', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jv', 'Kannada': 'kn', 'Korean': 'ko', 'Lithuanian': 'lt', 'Marathi': 'mr', 'Dutch': 'nl', 'Norwegian': 'no', 'Polish': 'pl', 'Portuguese': 'pt', 'Romanian': 'ro', 'Russian': 'ru', 'Slovak': 'sk', 'Swedish': 'sv', 'Swahili': 'sw', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Tagalog': 'tl', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Vietnamese': 'vi', 'Yoruba': 'yo', 'Chinese': 'zh'}
777 |
778 | ds_list = []
779 |
780 | for lang in sorted(language_map.values()):
781 | print(lang)
782 | ds = load_dataset(
783 | "lightblue/rag_multilingual_training_negatives", lang, split="train"
784 | )
785 |
786 | # Multilingual
787 | muling_ds = ds.filter(lambda x: bool(len(x["otherlang_question"]) > 0) and bool(len(x["otherlang_answer"]) > 0) and bool(x["otherlang_qa_finish_reason"] == "stop"), num_proc=16)
788 |
789 | muling_ds = muling_ds.map(
790 | lambda x: {
791 | "question": x["otherlang_question"],
792 | "answer": x["otherlang_answer"],
793 | "positives": get_lb_positives(x),
794 | "negatives": x["multilingual_negatives"],
795 | "dataset_name": "lb_rag_multilingual",
796 | "language": language_map[x["other_qa_lang"]],
797 | "doc_id": None,
798 | }, num_proc=16
799 | )
800 |
801 | # Monolingual
802 | moling_ds = ds.filter(lambda x: bool(len(x["question"]) > 0) and bool(len(x["answer"]) > 0) and bool(x["raw_qa_finish_reason"] == "stop"), num_proc=16)
803 |
804 | moling_ds = moling_ds.map(
805 | lambda x: {
806 | "question": x["question"],
807 | "answer": x["answer"],
808 | "positives": get_lb_positives(x),
809 | "negatives": x["monolingual_negatives"],
810 | "dataset_name": "lb_rag_monolingual",
811 | "language": x["language"],
812 | "doc_id": None,
813 | }, num_proc=16
814 | )
815 |
816 | ds_list.append(muling_ds)
817 | ds_list.append(moling_ds)
818 |
819 | return concatenate_datasets(ds_list)
820 |
821 | def parse_mqa_text(name, text):
822 | name = "" if name is None else name
823 | text = "" if text is None else text
824 | namelower = name.lower().strip()
825 | textlower = text.lower().strip()
826 |
827 | question_text = ""
828 | question_text += name
829 | if namelower != textlower:
830 | question_text += "\n" + text
831 |
832 | question_text = re.sub(r"[\=\-\#]{3,}", "", question_text)
833 | return question_text.strip()
834 |
835 | def process_mqa(lang, data_type):
836 | answer_features = [{'downvote_count': Value(dtype='int64', id=None),
837 | 'is_accepted': Value(dtype='bool', id=None),
838 | 'name': Value(dtype='string', id=None),
839 | 'text': Value(dtype='string', id=None),
840 | 'upvote_count': Value(dtype='int64', id=None)}]
841 |
842 | question_features = {'answers': answer_features,
843 | 'comment_count': Value(dtype='int64', id=None),
844 | 'data_type': Value(dtype='string', id=None),
845 | 'downvote_count': Value(dtype='int64', id=None),
846 | 'hash': Value(dtype='string', id=None),
847 | 'name': Value(dtype='string', id=None),
848 | 'text': Value(dtype='string', id=None),
849 | 'upvote_count': Value(dtype='int64', id=None)}
850 |
851 | load_features = {'bucket': Value(dtype='float64', id=None),
852 | 'sub_bucket': Value(dtype='string', id=None),
853 | 'language': Value(dtype='string', id=None),
854 | 'hreflang_alternates': [{'href': Value(dtype='string', id=None),
855 | 'hreflang': Value(dtype='string', id=None)}],
856 | 'questions': [question_features],
857 | 'page_hash': Value(dtype='string', id=None),
858 | 'fasttext_language': Value(dtype='string', id=None),
859 | 'domain': Value(dtype='string', id=None)}
860 |
861 | filename = hf_hub_download(repo_id="clips/mqa", filename=f"data/data.{lang}.{data_type}.json.gz", repo_type="dataset")
862 | ds = load_dataset("json", data_files={"train": filename}, split="train", features=Features(load_features))
863 |
864 |     # Randomly sample at most 100K rows to make this processing tractable
865 | max_rows = 100_000
866 | ds = ds.shuffle().select(range(min(max_rows, len(ds))))
867 |
868 | load_features["questions"] = question_features
869 | explode_features = Features(load_features)
870 |
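   |     # Explode each page's list of questions so that each row holds a single question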
871 | explode_mqa = lambda x: pd.DataFrame(dict(x)).explode("questions").to_dict(orient="list")
872 |
873 | ds = ds.map(explode_mqa,
874 | batched=True,
875 | batch_size=1000,
876 | num_proc=32,
877 | remove_columns=ds.column_names,
878 | features=explode_features)
879 |
880 | ds = ds.filter(lambda x: bool(
881 | isinstance(x["language"], str)
882 | ) and bool(
883 | isinstance(x["fasttext_language"], str)
884 | ) and bool(
885 | lang in x["language"].lower()
886 | ) and bool(
887 | lang in x["fasttext_language"].lower()
888 | ),
889 | num_proc=32
890 | )
891 |
892 | load_features["accepted_answer"] = answer_features
893 | accepted_features = Features(load_features)
894 |
895 | ds = ds.map(lambda x: {
896 | "accepted_answer": [y for y in x["questions"]["answers"] if y["is_accepted"]],
897 | }, num_proc=32, features=accepted_features)
898 |
899 | ds = ds.filter(lambda x: len(x["accepted_answer"]) > 0, num_proc=32)
900 |
901 | ds = ds.map(lambda x: {
902 | "question": parse_mqa_text(x["questions"]["name"], x["questions"]["text"]),
903 | "answer": None,
904 | "positives": [parse_mqa_text(x["accepted_answer"][0]["name"], x["accepted_answer"][0]["text"])],
905 | "negatives": None,
906 | "dataset_name": f"mqa_{data_type}",
907 | "language": lang,
908 | "doc_id": set([x["domain"]]),
909 | }, num_proc=32)
910 |
911 | return ds
912 |
913 | def prepare_mqa(data_type):
914 | langs = [
915 | 'af', 'als', 'am', 'an', 'ar', 'arz', 'as', 'ast', 'av', 'az', 'azb', 'ba', 'bar', 'bcl', 'be', 'bg', 'bh', 'bn', 'bo', 'bpy', 'br', 'bs', 'bxr', 'ca', 'cbk', 'ce', 'ceb', 'ckb', 'cs', 'cv', 'cy', 'da', 'de', 'diq', 'dsb', 'dty', 'dv', 'el', 'eml', 'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fr', 'frr', 'fy', 'ga', 'gd', 'gl', 'gn', 'gom', 'gu', 'gv', 'he', 'hi', 'hif', 'hr', 'hsb', 'ht', 'hu', 'hy', 'ia', 'id', 'ie', 'ilo', 'io', 'is', 'it', 'ja', 'jbo', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'krc', 'ku', 'kv', 'kw', 'ky', 'la', 'lb', 'lez', 'li', 'lmo', 'lo', 'lrc', 'lt', 'lv', 'mai', 'mg', 'mhr', 'min', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'mwl', 'my', 'myv', 'mzn', 'nah', 'nap', 'nds', 'ne', 'new', 'nl', 'nn', 'no', 'oc', 'or', 'os', 'pa', 'pam', 'pfl', 'pl', 'pms', 'pnb', 'ps', 'pt', 'qu', 'rm', 'ro', 'ru', 'sa', 'sah', 'sc', 'scn', 'sco', 'sd', 'sh', 'si', 'sk', 'sl', 'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk', 'tl', 'tr', 'tt', 'tyv', 'ug', 'uk', 'ur', 'uz', 'vec', 'vep', 'vi', 'vo', 'wa', 'war', 'wuu', 'xal', 'yi', 'yo', 'yue', 'zh'
916 | ]
917 |
918 | ds_list = []
919 |
920 | for lang in langs:
921 | print(f">>> Starting {lang} - {data_type}")
922 | try:
923 | ds = process_mqa(lang, data_type)
924 | except Exception as e:
925 | print(e)
926 | print(f"### Skipping {lang} - {data_type}")
927 | continue
928 | ds_list.append(ds)
929 |
930 | cat_ds = concatenate_datasets(ds_list)
931 |
932 | return cat_ds
933 |
934 | def prepare_mqa_cqa():
935 | return prepare_mqa("cqa")
936 |
937 | def prepare_mqa_faq():
938 | return prepare_mqa("faq")
939 |
940 | if __name__ == "__main__":
941 |
942 | dataset_func_list = [
943 | prepare_amharicqa,
944 | prepare_arcd,
945 | prepare_chaii,
946 | prepare_cpgqa,
947 | prepare_drcd,
948 | prepare_germanquad,
949 | prepare_hotpotqa,
950 | prepare_indicqa,
951 | prepare_jsquad,
952 | prepare_jqara,
953 | prepare_kenswquad,
954 | prepare_korquad,
955 | prepare_lb_rag,
956 | prepare_logqa,
957 | prepare_lsat,
958 | prepare_m2qa,
959 | prepare_mldr,
960 | prepare_mlqa,
961 | prepare_mqa_cqa,
962 | prepare_mqa_faq,
963 | prepare_narrativeqa,
964 | prepare_persianqa,
965 | prepare_pira,
966 | prepare_pubmedqa,
967 | prepare_qasports,
968 | prepare_sberquad,
969 | prepare_scandiqa,
970 | prepare_skquad,
971 | prepare_sleepqa,
972 | prepare_sqac,
973 | prepare_tquad,
974 | prepare_triviaqa,
975 | prepare_tydiqa_goldp,
976 | prepare_xquad,
977 | ]
978 |
979 | def write_name_to_file(name):
980 | with open("temp.txt", "a+") as f:
981 | f.write(str(name))
982 | f.write("\n")
983 | return True
984 |
985 | final_features = {
986 | 'question': Value(dtype='string', id=None),
987 | 'answer': Value(dtype='string', id=None),
988 | 'positives': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
989 | 'negatives': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
990 | 'dataset_name': Value(dtype='string', id=None),
991 | 'language': Value(dtype='string', id=None),
992 | 'doc_id': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)
993 | }
994 |
995 | required_cols = list(final_features.keys())
996 | dataset_list = []
997 |     os.makedirs("./data", exist_ok=True)  # ensure the output directory exists before writing parquet files
998 | for x in tqdm(dataset_func_list):
999 | print(x)
1000 | write_name_to_file(x)
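   |         # Cast each dataset to the shared schema so that all datasets can be concatenated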
1001 | ds = x().select_columns(required_cols).map(
1002 | lambda x: {k:v for k, v in x.items()},
1003 | features=Features(final_features),
1004 | num_proc=16
1005 | )
1006 | ds.to_parquet("./data/" + str(x).split()[1] + ".parquet")
1007 | dataset_list.append(ds)
1008 |
1009 | ds = concatenate_datasets(dataset_list)
1010 |
1011 | ds.push_to_hub("lightblue/rag_datasets_collection", private=True)
--------------------------------------------------------------------------------
/data_prep/2_add_negatives.py:
--------------------------------------------------------------------------------
1 | from tqdm.auto import tqdm
2 | import torch
3 | from FlagEmbedding import BGEM3FlagModel
4 | from datasets import load_dataset, concatenate_datasets
5 | from datasets.features.features import Value, Sequence
6 | import numpy as np
7 | import os
8 | flatten_list = lambda ll: [x for l in ll for x in l]
9 |
10 | def select_negatives(q_embedding, c_embeddings, context_df, doc_id_set, num_negs=10):
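   |     # Rank all candidate contexts by embedding similarity to the query, exclude any context
   |     # sharing a doc_id with the question's positives, and keep the top-N most similar as hard negatives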
11 | sim = c_embeddings @ q_embedding
12 |
13 | context_df["sim"] = sim.cpu().numpy()
14 |
15 | context_df["is_pos"] = context_df["doc_id"].apply(lambda x: bool(set(x) & doc_id_set))
16 |
17 | sorted_context_df = context_df.sort_values("sim", ascending=False)
18 |
19 | negatives = sorted_context_df[~sorted_context_df["is_pos"]].iloc[:num_negs].positives.tolist()
20 |
21 | return negatives
22 |
23 | def embed_text(text_list, model, max_len):
24 | return model.encode(
25 | text_list,
26 | max_length=max_len
27 | )['dense_vecs']
28 |
29 | def send_array_to_gpu(array):
30 | return torch.Tensor(array).to(torch.device("cuda"))
31 |
32 | def mine_negatives(ds, model):
33 |
34 | if ds[0]["negatives"] is None:
35 | ds = ds.add_column("added_neg", [True] * len(ds))
36 | else:
37 | ds = ds.add_column("added_neg", [False] * len(ds))
38 | print("No need to mine negatives")
39 | return ds
40 |
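   |     # If the dataset has no document IDs, assign each row a unique one so that
   |     # only a question's own positives are excluded when mining its negatives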
41 | if ds[0]["doc_id"] is None:
42 | doc_ids = [set([i]) for i in range(len(ds))]
43 | ds = ds.remove_columns(["doc_id"]).add_column("doc_id", doc_ids)
44 | ds = ds.add_column("added_doc_id", [True] * len(ds))
45 | else:
46 | ds = ds.add_column("added_doc_id", [False] * len(ds))
47 |
48 | context_df = ds.select_columns(
49 | ["positives", "doc_id"]
50 | ).to_pandas().explode("positives").groupby("positives").doc_id.apply(
51 | lambda x: set(flatten_list(x))
52 | ).reset_index(drop=False)
53 |
54 | context_df = context_df[~context_df.positives.isna()]
55 | context_df = context_df[context_df.positives.str.strip().str.len() > 0]
56 |
57 | if context_df.shape[0] < 1:
58 | print("Skipping because of no context")
59 | return None
60 |
61 | context_df["pos_len"] = context_df.positives.str.strip().str.len()
62 | context_df = context_df.sort_values("pos_len", ascending=False)
63 |
64 | c_embeddings = embed_text(context_df["positives"].tolist(), model, 8192)
65 | q_embeddings = embed_text(ds["question"], model, 8192)
66 |
67 | c_embeddings = send_array_to_gpu(c_embeddings)
68 | q_embeddings = send_array_to_gpu(q_embeddings)
69 |
70 | negatives_list = []
71 | num_negs = 10
72 |
73 | for q_embedding, doc_id in tqdm(zip(q_embeddings, ds["doc_id"]), total=len(ds)):
74 | negatives = select_negatives(q_embedding, c_embeddings, context_df, set(doc_id), num_negs=num_negs)
75 | negatives_list.append(negatives)
76 |
77 | ds = ds.remove_columns(["negatives"]).add_column(
78 | "negatives",
79 | negatives_list,
80 | feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)
81 | )
82 |
83 | return ds
84 |
85 | def sample_dataset(ds):
86 |
87 | MAX_DATASET_LANG_ROWS = 25_000
88 | MAX_MQA_LANG_ROWS = 5_000
89 |
90 | ds = ds.filter(lambda x: isinstance(x["question"], str) and bool(len(x["question"].strip()) > 0) and bool(len(x["positives"]) > 0), num_proc=32)
91 |
92 | max_rows = MAX_MQA_LANG_ROWS if "mqa" in ds[0]["dataset_name"] else MAX_DATASET_LANG_ROWS
93 | ds = ds.shuffle().select(range(min(max_rows, len(ds))))
94 |
95 | return ds
96 |
97 | def run_get_negatives(ds, model):
98 |
99 | lang_list = ds["language"]
100 | langs = sorted(set(lang_list))
101 |
102 | ds_list = []
103 | if len(langs) <= 1:
104 | lang_ds = sample_dataset(ds)
105 | ds_list = [mine_negatives(lang_ds, model)]
106 | else:
107 | lang_arr = np.array(lang_list)
108 | for lang in langs:
109 | print(lang)
110 | lang_idxs = np.where(lang == lang_arr)[0].tolist()
111 | lang_ds = ds.select(lang_idxs)
112 | lang_ds = sample_dataset(lang_ds)
113 |
114 | print(f"Length of {lang} is {len(lang_ds)}")
115 |
116 | ds_list.append(mine_negatives(lang_ds, model))
117 |
118 | ds_list = [x for x in ds_list if x is not None]
119 |
120 | if len(ds_list) < 1:
121 | return None
122 |
123 | return concatenate_datasets(ds_list)
124 |
125 | if __name__ == "__main__":
126 |     os.makedirs("./negatives", exist_ok=True)  # output directory for the per-dataset parquet files
127 | model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
128 |
129 | original_ds = load_dataset("lightblue/rag_datasets_collection", split="train")
130 |
131 | dataset_names = sorted(set(original_ds["dataset_name"]))
132 |
133 | ds_list = []
134 | for dataset_name in dataset_names:
135 |
136 | print(dataset_name)
137 | single_ds = original_ds.filter(lambda x: x["dataset_name"] == dataset_name, num_proc=32)
138 |
139 |         ds_w_negs = run_get_negatives(single_ds, model)
140 |         if ds_w_negs is None:
141 |             print(f"None dataset at {dataset_name}")
142 |             continue
143 | 
144 |         ds_w_negs = ds_w_negs.map(lambda x: {
145 |             "negatives": [y for y in x["negatives"] if y not in x["positives"]]
146 |         }, num_proc=8)
147 |         ds_w_negs = ds_w_negs.filter(lambda x: len(x["negatives"]) > 0, num_proc=32)
148 |
149 | ds_w_negs.to_parquet(f"./negatives/{dataset_name}.parquet")
150 |
151 | ds_list.append(ds_w_negs)
152 |
153 | concatenate_datasets(ds_list).push_to_hub("lightblue/rag_datasets_selected", private=True)
154 |
--------------------------------------------------------------------------------
/data_prep/3_add_continuous_labels.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, concatenate_datasets, Dataset
2 | import random
3 | from transformers import AutoTokenizer
4 | from vllm import LLM, SamplingParams
5 | import os
6 | from multiprocessing import Pool
7 | import math
8 | import numpy as np
9 |
10 | def get_hash(example):
11 | """Get hash of question field."""
12 | return {"q_hash": hash(example["question"])}
13 |
14 | def check_uniques(example, uniques):
15 | """Check if current hash is still in set of unique hashes and remove if true."""
16 | if example["q_hash"] in uniques:
17 | uniques.remove(example["q_hash"])
18 | return True
19 | else:
20 | return False
21 |
22 | def remove_duplicates(ds):
23 | ds = ds.map(get_hash, num_proc=32)
24 | uniques = set(ds.unique("q_hash"))
25 | return ds.filter(check_uniques, fn_kwargs={"uniques": uniques})
26 |
27 | def get_ds(model_name):
28 | tokenizer = AutoTokenizer.from_pretrained(model_name)
29 | MAX_LEN_MARGIN = 512
30 | MAX_LEN = 8192 - MAX_LEN_MARGIN
31 |
32 | ds = load_dataset("lightblue/rag_datasets_selected", split="train")
33 | ds = ds.add_column("row_id", list(range(len(ds))))
34 | ds = ds.shuffle()
35 |
36 | print("Deduplicating")
37 | ds = remove_duplicates(ds)
38 |
39 | selected_columns = ['question', 'answer', 'dataset_name', 'language', 'added_neg', 'doc_id', 'added_doc_id', 'row_id']
40 | added_columns = ['context', 'label']
41 |
42 | ds = ds.map(lambda x: {
43 | "positives": [p for p in x["positives"] if len(tokenizer.encode(p)) < MAX_LEN],
44 | "negatives": [n for n in x["negatives"] if len(tokenizer.encode(n)) < MAX_LEN],
45 | }, num_proc=32)
46 |
47 | ds = ds.filter(lambda x: bool(len(x["positives"]) > 0) and bool(len(x["negatives"]) > 0), num_proc=32)
48 |
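   |     # Build one positive-labelled row (random positive context) and one
   |     # negative-labelled row (random negative context) per question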
49 | pos_ds = ds.select_columns(selected_columns + ["positives"]).map(lambda x: {
50 | "context": random.sample(x["positives"], k=1)[0],
51 | "label": True,
52 | }, num_proc=32).select_columns(selected_columns + added_columns)
53 |
54 | neg_ds = ds.select_columns(selected_columns + ["negatives"]).map(lambda x: {
55 | "context": random.sample(x["negatives"], k=1)[0],
56 | "label": False,
57 | }, num_proc=32).select_columns(selected_columns + added_columns)
58 |
59 | new_ds = concatenate_datasets([pos_ds, neg_ds])
60 |
61 | return new_ds
62 |
63 | def get_prob(outputs, tok_id):
64 |     """Return the probability of tok_id at the first generated position, or 0 if absent."""
65 |     if len(outputs) < 1 or len(outputs[0].logprobs) < 1:
66 |         return 0
67 | 
68 |     logprob_dict = outputs[0].logprobs[0]
69 |     if tok_id in logprob_dict:
70 |         return np.exp(logprob_dict[tok_id].logprob)
71 |     return 0
72 |
73 | def generate_responses(inputs):
74 | text_list, model_name, gpu_id, reverse_context_query = inputs
75 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
76 |
77 | llm = LLM(model=model_name)
78 |
79 | sampling_params = SamplingParams(temperature=0.0, max_tokens=1, logprobs=10)
80 |
81 | text_query_sys_msg_snippet_1 = "query and a piece of text" if reverse_context_query else "piece of text and a query"
82 | text_query_sys_msg_snippet_2 = "text relates to the query" if reverse_context_query else "query relates to the text"
83 |     system_message = f"You are a relatedness rating assistant. Given a {text_query_sys_msg_snippet_1}, output the level to which the {text_query_sys_msg_snippet_2}. Your output should be a single number between 1 and 5, with 1 meaning completely unrelated, 2 meaning mostly unrelated, 3 meaning unsure as to whether it is related or not, 4 meaning mostly related, and 5 meaning completely related."
84 |
85 | chats = [[
86 | {"role": "system", "content": system_message},
87 | {"role": "user", "content": f"<<>>\n{q}\n\n\n<<>>\n{t}" if reverse_context_query else f"<<>>\n{t}\n\n\n<<>>\n{q}"},
88 | ] for t, q in text_list]
89 |
90 | responses = llm.chat(chats, sampling_params)
91 |
92 | tok = llm.llm_engine.tokenizer.tokenizer
93 | idx_tokens = [tok.encode(str(i))[0] for i in range(1, 6)] # Get token IDs for the tokens "1", "2", "3", "4", and "5"
94 | probs = np.array([[get_prob(r.outputs, y) for y in idx_tokens] for r in responses])
95 | return probs.tolist()
96 |
97 | def get_scores(all_texts, model_name, reverse_context_query):
98 | # Modify this to suit your number of GPUs
99 | num_gpus = 8
100 | batch_size = int(math.ceil(len(all_texts) / num_gpus))
101 |
102 | split_texts_w_idx = []
103 |
104 | for i in range(num_gpus):
105 | start_idx = i*batch_size
106 | end_idx = start_idx + batch_size
107 | text_batch = all_texts[start_idx:end_idx]
108 | split_texts_w_idx.append((text_batch, model_name, i, reverse_context_query))
109 |
110 | with Pool(num_gpus) as p:
111 | scores_split = p.map(generate_responses, split_texts_w_idx)
112 |
113 | scores = []
114 |
115 | for score_split in scores_split:
116 | scores.extend(score_split)
117 |
118 | return scores
119 |
120 | if __name__ == '__main__':
121 | model_name = "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4"
122 | new_ds = get_ds(model_name)
123 | all_texts = list(zip(new_ds["context"], new_ds["question"]))
124 |
125 |     scores = get_scores(all_texts, model_name, reverse_context_query=False)
126 | new_ds = new_ds.add_column("32B_score_probs", scores)
127 |
128 |     scores = get_scores(all_texts, model_name, reverse_context_query=True)
129 | new_ds = new_ds.add_column("32B_score_probs_rev", scores)
130 |
131 | new_ds = new_ds.sort("row_id")
132 | new_ds.to_parquet("lightblue__rag_datasets_selected_32B4scored_probs.parquet")
133 | new_ds.push_to_hub("lightblue/rag_datasets_selected_32B4scored_probs")
134 |
--------------------------------------------------------------------------------
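How the scoring in generate_responses works: the model is asked for exactly one token with logprobs=10, and get_prob reads off the probability mass placed on each digit token "1"-"5"; any digit outside the returned logprobs contributes zero. A minimal sketch of that readout with a stubbed logprobs dict (the token IDs and the SimpleNamespace stand-in are invented; vLLM's real entries expose the same .logprob attribute):

    import numpy as np
    from types import SimpleNamespace

    # Stand-in for one vLLM logprobs entry: {token_id: object with a .logprob attribute}
    fake_logprobs = {
        101: SimpleNamespace(logprob=np.log(0.1)),  # token "1"
        102: SimpleNamespace(logprob=np.log(0.2)),  # token "2"
        104: SimpleNamespace(logprob=np.log(0.6)),  # token "4"
    }
    idx_tokens = [101, 102, 103, 104, 105]  # invented IDs for the tokens "1".."5"

    probs = [np.exp(fake_logprobs[t].logprob) if t in fake_logprobs else 0.0
             for t in idx_tokens]
    print([round(p, 2) for p in probs])  # [0.1, 0.2, 0.0, 0.6, 0.0]
    # "3" and "5" fell outside the returned logprobs, so they get zero mass;
    # rows where every digit gets zero are filtered out downstream, and the
    # remaining mass is renormalized inside calc_exp_val.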
/data_prep/4_make_cont_bin_training_data.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from transformers import AutoTokenizer
3 | import numpy as np
4 |
5 | score_system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
6 | rev_score_system_message = "Given a piece of text and a query, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
7 |
8 | def filter_for_correct_scores_only(x):
9 |     """Keep positives the scorer rated >= 4 and negatives it rated <= 4 (1-7 scale)."""
10 |     return (
11 |         x["label"] and x["mean_exp_val_max7"] >= 4
12 |     ) or (
13 |         not x["label"] and x["mean_exp_val_max7"] <= 4
14 |     )
15 | 
16 |
17 | format_text_query = lambda t, q: f"<<<Query>>>\n{q}\n\n<<<Context>>>\n{t}"
18 | format_query_text = lambda t, q: f"<<<Context>>>\n{t}\n\n<<<Query>>>\n{q}"
19 |
20 | def make_continuous_data(x):
21 | return {
22 | "conversations": [
23 | { "from": "system", "value": score_system_message },
24 | { "from": "human", "value":format_text_query(x["context"], x["question"])},
25 | { "from": "gpt", "value": str(int(x["mean_exp_val_max7_round"])) } ]
26 | }
27 |
28 | def make_rev_continuous_data(x):
29 | return {
30 | "rev_conversations": [
31 | { "from": "system", "value": rev_score_system_message },
32 | { "from": "human", "value":format_query_text(x["context"], x["question"])},
33 | { "from": "gpt", "value": str(int(x["mean_exp_val_max7_round"])) } ]
34 | }
35 |
36 | calc_exp_val = lambda probs: sum((i + 1) * (p / sum(probs)) for i, p in enumerate(probs))  # expected score: weights 1-5 by normalized probabilities
37 |
38 | def main():
39 | ds = load_dataset("lightblue/rag_datasets_selected_32B4scored_probs", split="train")
40 |
41 |     ds = ds.filter(lambda x: sum(x["32B_score_probs"]) > 0 and sum(x["32B_score_probs_rev"]) > 0, num_proc=32)
42 |
43 | ds = ds.map(
44 | lambda x: {
45 | "prob_exp_val": calc_exp_val(x["32B_score_probs"]),
46 | "rev_prob_exp_val": calc_exp_val(x["32B_score_probs_rev"]),
47 | }, num_proc=32
48 | )
49 |
50 | ds = ds.map(
51 | lambda x: {
52 | "mean_exp_val": np.array([x["prob_exp_val"], x["rev_prob_exp_val"]]).mean(),
53 | }, num_proc=32
54 | )
55 |
56 | ds = ds.map(
57 | lambda x: {
58 | "mean_exp_val_max7": ((x["mean_exp_val"] - 1) * (6/4)) + 1
59 | }, num_proc=32
60 | )
61 |
62 | ds = ds.map(
63 | lambda x: {
64 | "mean_exp_val_max7_round": round(x["mean_exp_val_max7"])
65 | }, num_proc=32
66 | )
67 |
68 | ds = ds.map(make_continuous_data, num_proc=32)
69 | ds = ds.map(make_rev_continuous_data, num_proc=32)
70 |
71 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
72 |
73 | MAX_LEN = 8192
74 |
75 | ds = ds.filter(
76 | lambda x: len(
77 | tokenizer.encode("\n".join([y["value"] for y in x["conversations"]]))
78 | ) < MAX_LEN, num_proc=32
79 | )
80 |
81 | ds = ds.shuffle()
82 |
83 | ds.filter(filter_for_correct_scores_only, num_proc=32).push_to_hub(
84 | "lightblue/reranker_continuous_filt_max7_train_extra", private=True
85 | )
86 |
87 | if __name__ == '__main__':
88 | main()
--------------------------------------------------------------------------------
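Tracing the label arithmetic in the script above: calc_exp_val converts each 5-way probability vector into an expected score on the 1-5 scale, the forward and reverse expectations are averaged, ((x - 1) * (6/4)) + 1 stretches 1-5 linearly onto 1-7, and rounding yields the single digit written into the "gpt" turn. A worked example with an invented probability vector:

    probs = [0.05, 0.05, 0.10, 0.30, 0.50]  # invented 32B probabilities for scores 1..5

    exp_val = sum((i + 1) * (p / sum(probs)) for i, p in enumerate(probs))
    print(exp_val)       # ~4.15 on the 1-5 scale

    max7 = ((exp_val - 1) * (6 / 4)) + 1  # linear map: 1 -> 1, 3 -> 4, 5 -> 7
    print(max7)          # ~5.725 on the 1-7 scale
    print(round(max7))   # 6 -- the digit the reranker is trained to emit

With label=True this row survives filter_for_correct_scores_only (5.725 >= 4); with label=False it would be dropped as a negative that the scorer actually considered related.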
/training/rev_train_config.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: full
8 | deepspeed: /root/LLaMA-Factory/examples/deepspeed/ds_z2_config.json
9 |
10 | ### dataset
11 | dataset: reranker_continuous_filt_max7_rev_train
12 | template: qwen
13 | cutoff_len: 18000 # NOTE - this is larger than the original reranker which was set to 8192
14 | overwrite_cache: true
15 | preprocessing_num_workers: 16
16 | packing: true
17 |
18 | ### output
19 | output_dir: /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_rev_train
20 | logging_steps: 1
21 | save_steps: 0.99999
22 | plot_loss: true
23 | overwrite_output_dir: true
24 |
25 | ### train
26 | per_device_train_batch_size: 1
27 | gradient_accumulation_steps: 1
28 | learning_rate: 1.0e-5
29 | num_train_epochs: 1.0
30 | lr_scheduler_type: cosine
31 | warmup_ratio: 0.01
32 | bf16: true
33 | ddp_timeout: 180000000
34 |
35 | ### eval
36 | val_size: 0.01
37 | per_device_eval_batch_size: 1
38 | eval_strategy: steps
39 | eval_steps: 0.1
--------------------------------------------------------------------------------
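A note on the fractional step values in this config: LLaMA-Factory passes them through to transformers' TrainingArguments, where save_steps and eval_steps smaller than 1 are interpreted as a ratio of total training steps. So save_steps: 0.99999 saves a single checkpoint only at the very end of the run (which run_rev_train.sh then uploads), and eval_steps: 0.1 evaluates roughly every 10% of training.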
/training/run_rev_train.sh:
--------------------------------------------------------------------------------
1 | echo '{
2 | "reranker_continuous_filt_max7_rev_train": {
3 | "hf_hub_url": "lightblue/reranker_continuous_filt_max7_train_extra",
4 | "formatting": "sharegpt",
5 | "columns": {
6 | "messages": "rev_conversations"
7 | }
8 | }
9 | }' > /root/LLaMA-Factory/data/dataset_info.json
10 |
11 | cd /root/LLaMA-Factory && llamafactory-cli train /root/lb-reranker/training/rev_train_config.yaml
12 |
13 | rm -r /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_rev_train/checkpoint*
14 | huggingface-cli upload lightblue/reranker_0.5_cont_filt_7max_rev /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_rev_train
15 |
--------------------------------------------------------------------------------
/training/run_train.sh:
--------------------------------------------------------------------------------
1 | echo '{
2 | "reranker_continuous_filt_max7_train": {
3 | "hf_hub_url": "lightblue/reranker_continuous_filt_max7_train",
4 | "formatting": "sharegpt"
5 | }
6 | }' > /root/LLaMA-Factory/data/dataset_info.json
7 |
8 | cd /root/LLaMA-Factory && llamafactory-cli train /root/lb-reranker/training/train_config.yaml
9 |
10 | rm -r /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_train/checkpoint*
11 | huggingface-cli upload lightblue/reranker_0.5_cont_filt_7max /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_train
12 |
--------------------------------------------------------------------------------
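Both shell scripts above overwrite /root/LLaMA-Factory/data/dataset_info.json, so running one clobbers the other's registration. If both datasets should be available at once, a sketch of an equivalent combined registry (entries copied from the two echo blocks; writing it from Python rather than a heredoc is just a convenience, not part of the repo):

    import json

    # Mirrors the registrations in run_train.sh and run_rev_train.sh
    dataset_info = {
        "reranker_continuous_filt_max7_train": {
            "hf_hub_url": "lightblue/reranker_continuous_filt_max7_train",
            "formatting": "sharegpt",
        },
        "reranker_continuous_filt_max7_rev_train": {
            "hf_hub_url": "lightblue/reranker_continuous_filt_max7_train_extra",
            "formatting": "sharegpt",
            # The reverse run trains on the same rows but reads the
            # query-first conversations from the "rev_conversations" column.
            "columns": {"messages": "rev_conversations"},
        },
    }

    with open("/root/LLaMA-Factory/data/dataset_info.json", "w") as f:
        json.dump(dataset_info, f, indent=2)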
/training/train_config.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
3 |
4 | ### method
5 | stage: sft
6 | do_train: true
7 | finetuning_type: full
8 | deepspeed: /root/LLaMA-Factory/examples/deepspeed/ds_z2_config.json
9 |
10 | ### dataset
11 | dataset: reranker_continuous_filt_max7_train
12 | template: qwen
13 | cutoff_len: 8192
14 | overwrite_cache: true
15 | preprocessing_num_workers: 16
16 | packing: true
17 |
18 | ### output
19 | output_dir: /root/train_outputs/Qwen2.5-0.5B-Instruct/reranker_continuous_filt_max7_train
20 | logging_steps: 1
21 | save_steps: 0.99999
22 | plot_loss: true
23 | overwrite_output_dir: true
24 |
25 | ### train
26 | per_device_train_batch_size: 1
27 | gradient_accumulation_steps: 1
28 | learning_rate: 1.0e-5
29 | num_train_epochs: 1.0
30 | lr_scheduler_type: cosine
31 | warmup_ratio: 0.01
32 | bf16: true
33 | ddp_timeout: 180000000
34 |
35 | ### eval
36 | val_size: 0.01
37 | per_device_eval_batch_size: 1
38 | eval_strategy: steps
39 | eval_steps: 0.1
--------------------------------------------------------------------------------
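For completeness, a minimal sketch of querying a trained checkpoint the same way the data-prep scripts query the 32B scorer: reuse the 1-7 system message and <<<Query>>>/<<<Context>>> template from 4_make_cont_bin_training_data.py, then read an expected score out of the digit logprobs. The checkpoint name comes from run_train.sh's upload step; the expected-value readout mirrors 3_add_continuous_labels.py and is an assumption about how the model is meant to be consumed, not a documented API:

    import numpy as np
    from vllm import LLM, SamplingParams

    llm = LLM(model="lightblue/reranker_0.5_cont_filt_7max")  # uploaded by run_train.sh
    sampling_params = SamplingParams(temperature=0.0, max_tokens=1, logprobs=10)

    system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
    query = "What is the capital of France?"
    context = "Paris is the capital and most populous city of France."

    responses = llm.chat([[
        {"role": "system", "content": system_message},
        {"role": "user", "content": f"<<<Query>>>\n{query}\n\n<<<Context>>>\n{context}"},
    ]], sampling_params)

    # Expected score over the digit tokens "1".."7", as in the data-prep scripts
    tok = llm.llm_engine.tokenizer.tokenizer
    idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)]
    logprob_dict = responses[0].outputs[0].logprobs[0]
    probs = np.array([np.exp(logprob_dict[t].logprob) if t in logprob_dict else 0.0
                      for t in idx_tokens])
    score = float((np.arange(1, 8) * probs).sum() / probs.sum())
    print(score)  # 1 (unrelated) to 7 (highly related)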