├── .gitignore ├── requirements.txt ├── README.md └── generate.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | *.swp 3 | *.swo 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | spacy 2 | protobuf==3.19.4 3 | torch==1.12.0 4 | transformers==4.20.1 5 | sentencepiece==0.1.96 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LIQUID 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/liquid-a-framework-for-list-question/question-answering-on-multispanqa)](https://paperswithcode.com/sota/question-answering-on-multispanqa?p=liquid-a-framework-for-list-question)
3 | This is the official repository for the paper "[**LIQUID: A Framework for List Question Answering Dataset Generation**](https://arxiv.org/abs/2302.01691)" (presented at [***AAAI 2023***](https://aaai.org/Conferences/AAAI-23/)). This repository provides the implementation of the LIQUID model, guidelines on how to run the model to synthesize list QA data. Also, you can download generated datasets without having to create them from scratch (see **[here](#data-downloads)**). 4 | 5 | 6 | ## Quick Links 7 | 8 | * [Overview](#overview) 9 | * [Data Downloads](#data-downloads) 10 | * [Requirements](#requirements) 11 | * [Dataset Generation](#dataset-generation) 12 | * [List Question Answering](#list-question-answering) 13 | * [Reference](#reference) 14 | * [Contact](#contact) 15 | 16 | ## Overview 17 | 18 | LIQUID is an automated framework for generating list QA datasets from unlabeled corpora. Generated datasets by LIQUID can be used to improve list QA performance by supplementing insufficient human-labeled data. When training a list QA model using the generated data and then fine-tuning it on the target training data, we achieved a new **state-of-the-art** performance on **[MultiSpanQA](https://multi-span.github.io/)** and outperformed baselines on several benchmakrs including **[Quoref](https://arxiv.org/abs/1908.05803)** and **[BioASQ](http://bioasq.org/)**. 19 | 20 | ![Model-1](https://user-images.githubusercontent.com/72010172/185115620-3dbd69cd-5e37-4da0-acd6-e9dfb9ab0021.png) 21 | 22 | LIQUID comprises the following four stages (please refer to **[our paper](https://arxiv.org/abs/2302.01691)** for details). 23 | 24 | * (1) Answer extraction: the named entities belonging to the same entity type (e.g., organization type) in a summary are extracted by an NER model and used as candidate answers. 25 | * (2) Question generation: the candidate answers and the original passage are fed into a QG model to generate list questions. 26 | * (3) Iterative filtering: incorrect answers (e.g., Hanszen) are iteratively filtered based on the confidence score assigned by a QA model. 27 | * (4) Answer expansion: correct but omitted answers (e.g., Yale) are identified by the QA model. 28 | 29 | ## Data Downloads 30 | 31 | Use the links below to download the synthetic datasets without having to create a dataset from scratch. ✶ indicates they are the same data used in our experiments. Our data format follows that of SQuAD-v1.1. 32 | 33 | | Name | Corpus | Size | Link | 34 | |:----------------------------------|:--------|:--------|:--------| 35 | | liquid-wiki-140k (✶) | Wikipedia | 140k | http://nlp.dmis.korea.edu/projects/liquid-lee-et-al-2023/liquid-wiki-140k.json | 36 | | liquid-pubmed-140k (✶) | PubMed | 140k | http://nlp.dmis.korea.edu/projects/liquid-lee-et-al-2023/liquid-pubmed-140k.json | 37 | 38 | ## Requirements 39 | 40 | Download this repository and set up an environment as follows. 41 | 42 | ```bash 43 | # Clone the repository 44 | git clone https://github.com/sylee0520/LIQUID.git 45 | cd LIQUID 46 | 47 | # Create a conda virtual environment 48 | conda create -n liquid python=3.8 49 | conda activate liquid 50 | 51 | # Install all requirements 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | ### Unlabeled Corpus 56 | 57 | Download an unlabeled source corpus to be annotated and extract/unpack it to the correct directory. Choose either Wikipedia or PubMed depending on your target domain. ✶ indicates they are the same data used in our experiments. 58 | 59 | | Description | Directory | Link | 60 | |:----------------------------------|:--------|:--------| 61 | | 2018-12-20 version of **Wikipedia** (✶) | `./data/unlabeled/wiki/` | http://nlp.dmis.korea.edu/projects/liquid-lee-et-al-2023/wiki181220.zip | 62 | | 2019-01-02 version of **PubMed** (✶) | `./data/unlabeled/pubmed/` | http://nlp.dmis.korea.edu/projects/liquid-lee-et-al-2023/pubmed190102.zip | 63 | 64 | Note that passages in each file have not been shuffled. You will have to randomly sample passages from the entire corpus files (e.g., "0000.json" to "5620.json" for Wikipedia) if you want to use sampled passages. 65 | 66 | ### NER Models 67 | 68 | In LIQUID, two types of NER models are used to extract candidate answers for the *general* and *biomedical* domains, respectively. Please refer to the instructions below to install the NER models. 69 | 70 | * For the *general* domain, run `python -m spacy download en_core_web_sm` to install **spaCy** NER system. 71 | * For the *biomedican* domain, install **BERN2** from the official GitHub repository (**[link](https://github.com/dmis-lab/BERN2)**). After installation is complete, refer to the instructions below and run the model in the background. Note that you need to create a new conda environment for BERN2, instead of reusing the environment for LIQUID. 72 | 73 | ```bash 74 | # Run BERN2 model 75 | export CUDA_VISIBLE_DEVICES=0 76 | conda activate BERN2 77 | cd BERN2/scripts 78 | 79 | # For Linux and MacOS 80 | bash run_bern2.sh 81 | 82 | # For Windows 83 | bash run_bern2_windows.sh 84 | ``` 85 | 86 | ## Dataset Generation 87 | 88 | Once you have installed all the requirements, you are ready to create your list QA datasets. Please see the example script below. 89 | 90 | ```bash 91 | export CUDA_VISIBLE_DEVICES=0 92 | export DATA_FILE=./data/unlabeled/wiki/0000.json 93 | export OUTPUT_FILE=./data/synthetic/wiki/0000.json 94 | python generate.py \ 95 | --data_file ${DATA_FILE} \ 96 | --output_file ${OUTPUT_FILE} \ 97 | --batch_size 8 \ 98 | --summary_min_length 64 \ 99 | --summary_max_length 128 \ 100 | --summary_model_name_or_path facebook/bart-large-cnn \ 101 | --qg_min_length 64 \ 102 | --qg_max_length 128 \ 103 | --qg_model_name_or_path mrm8488/t5-base-finetuned-question-generation-ap \ 104 | --qa_model_name_or_path thatdramebaazguy/roberta-base-squad \ 105 | --do_summary \ 106 | --device 0 107 | ``` 108 | 109 | ### Argument Description 110 | - `batch_size`: Number of passages to process simultaneously in one batch. 111 | - `summary_min_length`, `summary_max_length`, `qg_min_length`, and `qg_max_length`: Minimum and maximum lengths of the output summary and question, respectively. 112 | - `summary_model_name_or_path`, `qg_model_name_or_path`, and `qa_model_name_or_path`: Model path for loading the summarization model, question-generation model, and question-answering model, respectively. For the *biomedical* domain, you can use `dmis-lab/biobert-base-cased-v1.1-squad` as the QA model. 113 | - `is_biomedical`: Use this option when the target domain is biomedicine. 114 | - `do_summary`: (**Recommended**) Use this option if you want to summarize input passages and extract candidate answers from the summaries. 115 | - `device`: Set to `0` if you want to use our framework on GPU; otherwise `-1`. 116 | 117 | ## List Question Answering 118 | 119 | To be updated soon. 120 | 121 | 122 | ## Reference 123 | Please cite our paper if it is helpful or relevant to your work. 124 | 125 | ```bash 126 | @article{lee2023liquid, 127 | title={LIQUID: A Framework for List Question Answering Dataset Generation}, 128 | author={Lee, Seongyun and Kim, Hyunjae and Kang, Jaewoo}, 129 | journal={arXiv preprint arXiv:2302.01691}, 130 | year={2023} 131 | } 132 | ``` 133 | 134 | ## Contact 135 | Feel free to email us (`sy-lee@korea.ac.kr` and `hyunjae-kim@korea.ac.kr`) if you have any! 136 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import json 4 | import os 5 | import spacy 6 | import requests 7 | import transformers 8 | 9 | from transformers import BartTokenizer, BartForConditionalGeneration, pipeline 10 | from transformers import T5Tokenizer, T5ForConditionalGeneration 11 | from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM 12 | 13 | from datetime import datetime 14 | from collections import defaultdict 15 | from tqdm import tqdm 16 | 17 | 18 | transformers.logging.set_verbosity_error() 19 | 20 | def summarization(data, model, batch_size, min_length, max_length): 21 | summary = [] 22 | context = [] 23 | titles = [] 24 | num_paragraphs = [] 25 | summary_context = { 26 | "data": [] 27 | } 28 | 29 | for i in range(len(data)): 30 | title = data[i]['title'] 31 | paragraphs = data[i]['paragraphs'] 32 | titles.append(title) 33 | num_paragraphs.append(len(paragraphs)) 34 | 35 | for text in paragraphs: 36 | text = text['context'].strip() 37 | context.append(text[:768]) 38 | 39 | print("********** Context Summarization starts! **********") 40 | 41 | for i in tqdm(range(0, len(context), batch_size), total=len(context)//batch_size): 42 | if i+batch_size <= len(context): 43 | batch = context[i:i+batch_size] 44 | else: 45 | batch = context[i:] 46 | result = model(batch, min_length=0, max_length=max_length, batch_size=batch_size) 47 | summary += result 48 | 49 | print("********** Context Summarization ends! **********") 50 | 51 | summary = [x['summary_text'] for x in summary] 52 | assert len(summary) == len(context), "Summarization is inconsistency!" 53 | 54 | start_idx = 0 55 | for i in range(len(titles)): 56 | title = titles[i] 57 | num_context = num_paragraphs[i] 58 | summaries = summary[start_idx:start_idx+num_context] 59 | original_text = context[start_idx:start_idx+num_context] 60 | paragraphs = [{"context": s, "original_text": o} for s, o in zip(summaries, original_text)] 61 | summary_context['data'].append({ 62 | "title": title, 63 | "paragraphs": paragraphs 64 | }) 65 | start_idx += num_context 66 | 67 | return summary_context 68 | 69 | 70 | def bern2(text, url="http://localhost:8888/plain"): 71 | return requests.post(url, json={'text': text}).json() 72 | 73 | def answer_extraction(is_biomedical, data, do_summary): 74 | print("********** Answer extraction starts! **********") 75 | temp = [] 76 | for d in data: 77 | title = d['title'] 78 | paragraphs = d['paragraphs'] 79 | for paragraph in paragraphs: 80 | context = paragraph['context'] 81 | if do_summary: 82 | original_text = paragraph['original_text'] 83 | temp.append({ 84 | "title": title, 85 | "context": context, 86 | "original_text": original_text 87 | }) 88 | else: 89 | temp.append({ 90 | "title": title, 91 | "context": context, 92 | }) 93 | 94 | if not is_biomedical: 95 | model = spacy.load('en_core_web_sm') 96 | 97 | ca = { 98 | "data": [] 99 | } 100 | 101 | for i in tqdm(range(len(temp)), total=len(temp)): 102 | context = temp[i]['context'] 103 | original_text = temp[i]["original_text"] if do_summary else temp[i]['context'] 104 | 105 | title = temp[i]['title'] 106 | 107 | if is_biomedical: 108 | ents = bern2(context)['annotations'] 109 | else: 110 | ents = model(context) 111 | ents = ents.ents 112 | 113 | ents_set = set() 114 | type2ent = defaultdict(set) 115 | 116 | for ent in ents: 117 | if is_biomedical: 118 | label = ent['obj'] 119 | ent = str(ent['mention']) 120 | if label != 'species': 121 | type2ent[label].add(ent) 122 | ents_set.add((ent, label)) 123 | 124 | else: 125 | label = str(ent.label_) 126 | if label != 'DATE': 127 | ent = str(ent) 128 | type2ent[label].add(ent) 129 | ents_set.add((ent, label)) 130 | 131 | for ent_type, entity in type2ent.items(): 132 | entity = list(entity) 133 | if 1 < len(entity): 134 | answers = [] 135 | flag = True 136 | 137 | for ent in entity: 138 | answer_start = original_text.find(ent) 139 | if answer_start == -1: 140 | flag = False 141 | answer = { 142 | "answer_text": ent, 143 | "answer_type": ent_type, 144 | "answer_start": answer_start 145 | } 146 | if flag: 147 | answers.append(answer) 148 | if 1 < len(answers): 149 | cas = { 150 | "title": title, 151 | "context": original_text, 152 | "answers": answers, 153 | } 154 | 155 | ca['data'].append(cas) 156 | 157 | print("********** Answer extraction ends! **********") 158 | 159 | return ca 160 | 161 | 162 | def question_generation(model, data, batch_size, min_length, max_length): 163 | 164 | print("********** Question Generation Starts! **********") 165 | ca = [] 166 | for i in range(len(data)): 167 | answers = [x['answer_text'] for x in data[i]['answers']] 168 | answers_string = ", ".join(answers) 169 | context = data[i]['context'] 170 | text = "answers: %s context: %s " % (answers_string, context) 171 | ca.append(text) 172 | 173 | questions = [] 174 | for i in tqdm(range(0, len(ca), batch_size), total=len(ca)//batch_size): 175 | if i+batch_size <= len(ca): 176 | result = model(ca[i:i+batch_size], batch_size=batch_size) 177 | else: 178 | result = model(ca[i:], batch_size=batch_size) 179 | questions += result 180 | 181 | # {'generated_text': 'question: What is AI?'} 182 | q_start_idx = 10 183 | questions = [q['generated_text'][q_start_idx:] for q in questions] 184 | assert len(questions) == len(data), "Question Generation is inconsistency!" 185 | 186 | cqa = { 187 | "data": [] 188 | } 189 | 190 | title2ctx = {} 191 | ctx2qas = defaultdict(list) 192 | titles = [x['title'] for x in data] 193 | 194 | for title in titles: 195 | title2ctx[title] = [] 196 | 197 | for i in range(len(data)): 198 | answers = data[i]['answers'] 199 | question = questions[i] 200 | context = data[i]['context'] 201 | title = data[i]['title'] 202 | 203 | str2hash = title + context + question + "".join([a["answer_text"] for a in answers]) 204 | hash_res = hashlib.md5(str2hash.encode()) 205 | qid = hash_res.hexdigest() 206 | 207 | title2ctx[title].append(context) 208 | ctx2qas[context].append({ 209 | "id": qid, 210 | "question": question, 211 | "answers": answers 212 | }) 213 | 214 | for title, ctx in title2ctx.items(): 215 | temp = { 216 | "title": title, 217 | "paragraphs": [] 218 | } 219 | for c in ctx: 220 | qas = ctx2qas[c] 221 | paragraph = { 222 | "context": c, 223 | "qas": qas 224 | } 225 | if paragraph not in temp['paragraphs']: 226 | temp["paragraphs"].append(paragraph) 227 | 228 | cqa['data'].append(temp) 229 | 230 | print("********** Question Generation Ends! **********") 231 | 232 | return cqa 233 | 234 | def answer_filtering(answers, pseudo_answers, context, filter_count): 235 | 236 | # filtering 237 | filtered_answers = [] 238 | answer_start = [] 239 | min_prob = 10 240 | for answer in answers: 241 | for pseudo_answer in pseudo_answers: 242 | try: 243 | if 0.01 <= pseudo_answer['score']: 244 | if answer in pseudo_answer['answer']: 245 | min_prob = min(pseudo_answer['score'], min_prob) 246 | filtered_answers.append(answer) 247 | answer_start.append(pseudo_answer['start']) 248 | break 249 | else: 250 | if answer == pseudo_answer['answer']: 251 | min_prob = min(pseudo_answer['score'], min_prob) 252 | filtered_answers.append(answer) 253 | answer_start.append(pseudo_answer['start']) 254 | break 255 | except: 256 | pass 257 | 258 | if filter_count != 2: 259 | assert len(answer_start) == len(filtered_answers), f"len answer_text ({len(filtered_answers)}) != len answer_start ({len(answer_start)})" 260 | 261 | if len(filtered_answers) == 0: 262 | context = "" 263 | 264 | return filtered_answers, answer_start, context 265 | 266 | # expansion 267 | else: 268 | if 1 < len(filtered_answers): 269 | for pseudo_answer in pseudo_answers: 270 | if pseudo_answer['score'] < min_prob: 271 | break 272 | try: 273 | is_in = False 274 | for filtered_answer in filtered_answers: 275 | if filtered_answer in pseudo_answer['answer']: 276 | is_in = True 277 | break 278 | if is_in is False and context.find(pseudo_answer['answer']) != -1: 279 | filtered_answers.append(pseudo_answer['answer']) 280 | answer_start.append(pseudo_answer['start']) 281 | except: 282 | pass 283 | 284 | # deduplication 285 | final_filtered_answers = [] 286 | final_answer_start = [] 287 | for i in range(len(filtered_answers)): 288 | flag = False 289 | for j in range(len(filtered_answers)): 290 | if i != j: 291 | filtered_answers_i = filtered_answers[i].lower() 292 | filtered_answers_j = filtered_answers[j].lower() 293 | if filtered_answers_i in filtered_answers_j: 294 | flag = True 295 | break 296 | 297 | if flag is False: 298 | final_filtered_answers.append(filtered_answers[i]) 299 | final_answer_start.append(answer_start[i]) 300 | 301 | 302 | assert len(final_answer_start) == len(final_filtered_answers), f"len answer_text ({len(final_filtered_answers)}) != len answer_start ({len(final_answer_start)})" 303 | 304 | if len(final_filtered_answers) == 0: 305 | context = "" 306 | 307 | return final_filtered_answers, final_answer_start, context 308 | 309 | def iterative_filtering(data, qg_model, qa_model, batch_size): 310 | 311 | print("********** Iterative Filtering Starts! **********") 312 | 313 | q_start_idx = 10 314 | data_list = [] 315 | 316 | for i in range(len(data)): 317 | title = data[i]['title'] 318 | paragraphs = data[i]['paragraphs'] 319 | for paragraph in paragraphs: 320 | context = paragraph['context'] 321 | qas = paragraph['qas'] 322 | for qa in qas: 323 | qid = qa['id'] 324 | question = qa['question'] 325 | answers = qa['answers'] 326 | data_list.append({ 327 | "id": qid, 328 | "title": title, 329 | "context": context, 330 | "question": question, 331 | "answers": answers 332 | }) 333 | 334 | title2ctx = defaultdict(set) 335 | ctx2qas = defaultdict(list) 336 | 337 | for i in tqdm(range(0, len(data_list), batch_size), total=len(data_list)//batch_size): 338 | batch = data_list[i:i+batch_size] if i+batch_size <= len(data_list) else data_list[i:] 339 | context = [x['context'] for x in batch] 340 | question = [x['question'] for x in batch] 341 | answers = [x['answers'] for x in batch] 342 | qid = [x['id'] for x in batch] 343 | title = [x['title'] for x in batch] 344 | 345 | answer_text = [] 346 | for answer in answers: 347 | answer_text.append([x['answer_text'] for x in answer]) 348 | 349 | generated_q = [] 350 | for idx in range(3): 351 | answer_string = [", ".join(x) for x in answer_text] 352 | ca = ["answers: %s context: %s " % (a, c) for a, c in zip(answer_string, context)] 353 | generated_q = qg_model(ca, batch_size=batch_size) 354 | generated_q = [x['generated_text'][q_start_idx:] for x in generated_q] 355 | pseudo_answers = qa_model(question=generated_q, context=context, top_k=30, batch_size=batch_size) 356 | filtered_answer_text = [] 357 | answer_starts = [] 358 | filtered_contexts = [] 359 | filtered_titles = [] 360 | filtered_qids = [] 361 | 362 | for answer, ctx, q, t, pseudo_answer in zip(answer_text, context, qid, title, pseudo_answers): 363 | filtered_answers, answer_start, filtered_context = answer_filtering(answer, pseudo_answer, ctx, idx) 364 | if filtered_context != "": 365 | filtered_contexts.append(filtered_context) 366 | filtered_answer_text.append(filtered_answers) 367 | answer_starts.append(answer_start) 368 | filtered_qids.append(q) 369 | filtered_titles.append(t) 370 | 371 | 372 | answer_text = filtered_answer_text 373 | context = filtered_contexts 374 | qid = filtered_qids 375 | title = filtered_titles 376 | 377 | 378 | assert len(answer_starts) == len(answer_text) == len(context) == len(qid) == len(title), \ 379 | f"len answer_text ({len(answer_text)}) != len answer_start ({len(answer_starts)}) != \ 380 | len context ({len(context)}) != len qid ({len(qid)}) != len({len(title)})" 381 | 382 | 383 | # final question generation 384 | answer_string = [", ".join(x) for x in answer_text] 385 | ca = ["answers: %s context: %s " % (a, c) for a, c in zip(answer_string, context)] 386 | new_generated_q = qg_model(ca, batch_size=batch_size) 387 | new_generated_q = [x['generated_text'][q_start_idx:] for x in new_generated_q] 388 | pseudo_answers = qa_model(question=new_generated_q, context=context, top_k=30, batch_size=batch_size) 389 | 390 | # final filtering 391 | filtered_answer_text = [] 392 | answer_starts = [] 393 | filtered_contexts = [] 394 | filtered_titles = [] 395 | filtered_qids = [] 396 | final_questions = [] 397 | 398 | for answer, ctx, q, t, pseudo_answer, oq, nq in zip(answer_text, context, qid, title, pseudo_answers, generated_q, new_generated_q): 399 | filtered_answers, answer_start, filtered_context = answer_filtering(answer, pseudo_answer, ctx, 0) 400 | if filtered_context != "": 401 | if set(filtered_answers) == set(answer): 402 | final_questions.append(nq) 403 | else: 404 | final_questions.append(oq) 405 | 406 | filtered_contexts.append(filtered_context) 407 | filtered_answer_text.append(filtered_answers) 408 | answer_starts.append(answer_start) 409 | filtered_qids.append(q) 410 | filtered_titles.append(t) 411 | 412 | assert len(final_questions) == len(answer_starts) == len(filtered_answer_text) == len(filtered_contexts), \ 413 | f"len final_questions {len(final_questions)} != len answer_starts {len(answer_starts)} \ 414 | != len answer_text {len(filtered_answer_text)} != len contexts {len(filtered_contexts)}" 415 | 416 | qid = filtered_qids 417 | title = filtered_titles 418 | context = filtered_contexts 419 | generated_q = final_questions 420 | answer_text = filtered_answer_text 421 | 422 | for j in range(len(answer_text)): 423 | filtered_answers = answer_text[j] 424 | if 1 < len(filtered_answers): 425 | answers = [{"answer_text": ans, "answer_start": ans_start} for ans, ans_start in zip(filtered_answers, answer_starts[j])] 426 | title2ctx[title[j]].add(context[j]) 427 | ctx2qas[context[j]].append({ 428 | "id": qid[j], 429 | "question": generated_q[j], 430 | "answers": answers 431 | }) 432 | 433 | filtered_data = { 434 | "data": [] 435 | } 436 | 437 | for title, ctxs in title2ctx.items(): 438 | tp = { 439 | "title": title, 440 | "paragraphs": [] 441 | } 442 | for ctx in ctxs: 443 | cqa = { 444 | "context": ctx, 445 | "qas": ctx2qas[ctx] 446 | } 447 | tp['paragraphs'].append(cqa) 448 | filtered_data['data'].append(tp) 449 | 450 | print("********** Iterative Filtering Ends! **********") 451 | 452 | return filtered_data 453 | 454 | def load_data(path): 455 | with open(path, 'r') as f: 456 | data = json.load(f) 457 | return data 458 | 459 | def save_data(file_path, data): 460 | # extracdt path 461 | dir_name = os.path.dirname(file_path) 462 | if not os.path.exists(dir_name): 463 | os.makedirs(dir_name) 464 | 465 | with open(file_path, "w") as f: 466 | json.dump(data, f, indent=4) 467 | print("Saved the dataset at '{}'.".format(file_path)) 468 | 469 | def main(args): 470 | # count 471 | data = load_data(args.data_file) 472 | 473 | if args.doc_limit: 474 | data = { 475 | "data": data["data"][:args.doc_limit] 476 | } 477 | 478 | # summarization 479 | if args.do_summary: 480 | sum_tokenizer = BartTokenizer.from_pretrained(args.summary_model_name_or_path) 481 | sum_model = BartForConditionalGeneration.from_pretrained(args.summary_model_name_or_path) 482 | sum_pipe = pipeline('summarization', model=sum_model, tokenizer=sum_tokenizer, device=args.device, truncation=True) 483 | 484 | context = summarization(data['data'], sum_pipe, args.batch_size, args.summary_min_length, args.summary_max_length) 485 | #now = datetime.now() 486 | #with open(f'{output_file}/context-summarization-{now}.json', 'w') as f: 487 | # json.dump(context, f, indent="\t") 488 | else: 489 | context = data 490 | 491 | # answer extraction 492 | ca = answer_extraction(args.is_biomedical, context['data'], args.do_summary) 493 | #now = datetime.now() 494 | #with open(f'{output_file}/answer-extraction-{now}.json', 'w') as f: 495 | # json.dump(ca, f, indent="\t") 496 | 497 | # question generation 498 | qg_tokenizer = AutoTokenizer.from_pretrained(args.qg_model_name_or_path) 499 | qg_model = AutoModelForSeq2SeqLM.from_pretrained(args.qg_model_name_or_path, use_cache=True) 500 | qg_pipe = pipeline('text2text-generation', model=qg_model, tokenizer=qg_tokenizer, device=args.device) 501 | 502 | cqa = question_generation(qg_pipe, ca['data'], args.batch_size, args.qg_min_length, args.qg_max_length) 503 | 504 | # iterative filtering 505 | qa_tokenizer = AutoTokenizer.from_pretrained(args.qa_model_name_or_path) 506 | qa_model = AutoModelForQuestionAnswering.from_pretrained(args.qa_model_name_or_path) 507 | qa_pipe = pipeline('question-answering', model=qa_model, tokenizer=qa_tokenizer, device=args.device) 508 | 509 | filtered_cqa = iterative_filtering(cqa['data'], qg_pipe, qa_pipe, args.batch_size) 510 | 511 | # saving filtered dataset 512 | save_data(args.output_file, filtered_cqa) 513 | 514 | if __name__ == "__main__": 515 | parser = argparse.ArgumentParser() 516 | 517 | # paths and general setups 518 | parser.add_argument('--data_file', required=True, type=str) 519 | parser.add_argument('--output_file', required=True, type=str) 520 | parser.add_argument('--batch_size', default=8, type=int) 521 | parser.add_argument('--device', default=0, type=int) 522 | parser.add_argument('--doc_limit', default=-1, type=int, 523 | help="number of documents to process in the input corpus. use -1 if you want to process all documents.") 524 | parser.add_argument('--is_biomedical', default=False, action="store_true") 525 | 526 | # summarization 527 | parser.add_argument('--do_summary', default=False, action="store_true") 528 | parser.add_argument('--summary_min_length', default=64, type=int) 529 | parser.add_argument('--summary_max_length', default=128, type=int) 530 | parser.add_argument('--summary_model_name_or_path', default='facebook/bart-large-cnn', type=str) 531 | 532 | # question generation 533 | parser.add_argument('--qg_min_length', default=64, type=int) 534 | parser.add_argument('--qg_max_length', default=128, type=int) 535 | parser.add_argument('--qg_model_name_or_path', default='mrm8488/t5-base-finetuned-question-generation-ap', type=str) 536 | 537 | # general domain default setting 538 | parser.add_argument('--qa_model_name_or_path', default="thatdramebaazguy/roberta-base-squad", type=str, 539 | help="use 'dmis-lab/biobert-base-cased-v1.1-squad' as the QA model for the biomedical domain") 540 | 541 | args = parser.parse_args() 542 | 543 | main(args) 544 | 545 | --------------------------------------------------------------------------------