├── README.md
├── classifier
│   ├── data
│   │   ├── medqa
│   │   │   └── llama3_cot
│   │   │       └── 5%-train.json
│   │   └── preprocess.py
│   ├── model
│   │   └── token_add.ipynb
│   ├── run
│   │   └── run_large_train_xl_000.sh
│   ├── run_classifier.py
│   └── utils.py
├── main.py
└── retriever
    ├── README.md
    ├── articles
    │   ├── cpg
    │   │   └── .gitkeep
    │   ├── pmc
    │   │   └── .gitkeep
    │   ├── pubmed
    │   │   └── .gitkeep
    │   └── textbook
    │       └── .gitkeep
    ├── embeddings
    │   ├── cpg
    │   │   └── .gitkeep
    │   ├── pmc
    │   │   └── .gitkeep
    │   ├── pubmed
    │   │   └── .gitkeep
    │   └── textbook
    │       └── .gitkeep
    ├── input
    │   └── .gitkeep
    ├── main.py
    ├── output
    │   └── .gitkeep
    ├── query_encode.py
    ├── rerank.py
    └── retriever.yml
/README.md: -------------------------------------------------------------------------------- 1 | # [NAACL 2025] Rationale-Guided Retrieval Augmented Generation for Medical Question Answering 2 | 3 | **Paper** | [Rationale-Guided Retrieval Augmented Generation for Medical Question Answering](https://arxiv.org/abs/2411.00300) 4 | 5 | **Authors**: Jiwoong Sohn, Yein Park, Chanwoong Yoon, Sihyeon Park, Hyeon Hwang, Mujeen Sung, Hyunjae Kim, Jaewoo Kang 6 | 7 | **Abstract**: Large language models (LLMs) hold significant potential for applications in biomedicine, but they struggle with hallucinations and outdated knowledge. While retrieval-augmented generation (RAG) is generally employed to address these issues, it also has its own set of challenges: (1) LLMs are vulnerable to irrelevant or incorrect context, (2) medical queries are often not well-targeted for helpful information, and (3) retrievers are prone to bias toward the specific source corpus they were trained on. In this study, we present RAG² (RAtionale-Guided RAG), a new framework for enhancing the reliability of RAG in biomedical contexts. RAG² incorporates three key innovations: a small filtering model trained on perplexity-based labels of rationales, which selectively augments informative snippets of documents while filtering out distractors; LLM-generated rationales as queries to improve the utility of retrieved snippets; a structure designed to retrieve snippets evenly from a comprehensive set of four biomedical corpora, effectively mitigating retriever bias. Our experiments demonstrate that RAG² improves the state-of-the-art LLMs of varying sizes, with improvements of up to 6.1%, and it outperforms the previous best medical RAG model by up to 5.6% across three medical question-answering benchmarks. 8 | 9 | **Repository Overview** 10 | 11 | This repository contains the implementation of **Rationale-Guided Retrieval-Augmented Generation (RAG²)**. It includes code for training the filtering model, setting up the retriever, and running inference. The repository is organized as shown in the directory tree above. 12 | 13 | ## Getting Started 14 | 15 | ### 1. Training Dataset Preparation 16 | - Generate Chain-of-Thought (CoT) rationales using LLMs 17 | - Calculate perplexity scores for each rationale 18 | - Create training labels based on perplexity thresholds 19 | - Process and format the training data 20 | 21 | ### 2. Retriever Setup 22 | - Index setup for multiple biomedical corpora 23 | - Configuration for balanced retrieval across corpora 24 | - Embedding model initialization 25 | - Retrieval parameter settings 26 | 27 | ### 3. Filtering Model Training 28 | The filtering model training code is based on [Adaptive-RAG](https://github.com/starsuzi/Adaptive-RAG). 29 | - Model architecture and configuration 30 | - Training with perplexity-based labels 31 | - Validation and model selection 32 | - Checkpoint saving
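A minimal sketch of the labeling step, for orientation: the filtering model's targets come from whether a candidate snippet lowers the perplexity of a gold rationale. The model name, prompt concatenation, and `margin` threshold below are illustrative assumptions rather than the repository's exact configuration.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"  # assumption: any causal LM can score rationales

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
model.eval()

@torch.no_grad()
def rationale_perplexity(context: str, rationale: str) -> float:
    """Perplexity of the rationale given the context, scoring rationale tokens only."""
    ctx_len = tokenizer(context, return_tensors="pt").input_ids.shape[1]
    ids = tokenizer(context + rationale, return_tensors="pt").input_ids.to(model.device)
    labels = ids.clone()
    labels[:, :ctx_len] = -100  # mask context tokens out of the loss
    loss = model(ids, labels=labels).loss  # mean NLL over the rationale tokens
    return torch.exp(loss).item()

def label_snippet(question: str, snippet: str, rationale: str, margin: float = 0.0) -> str:
    """[HELPFUL] if the snippet lowers the rationale's perplexity by more than `margin`."""
    ppl_without = rationale_perplexity(question, rationale)
    ppl_with = rationale_perplexity(snippet + "\n" + question, rationale)
    return "[HELPFUL]" if ppl_without - ppl_with > margin else "[NOT_HELPFUL]"
```

### 4.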
Inference Pipeline 35 | - Initial CoT generation for query enhancement 36 | - Multi-corpus retrieval 37 | - Filtering retrieved passages 38 | - Final response generation 39 | 40 | ## Usage 41 | Detailed instructions for each component will be provided soon. 42 | 43 | ### Citation 44 | If you use this work, please cite our paper: 45 | 46 | ``` 47 | @article{sohn2024rag, 48 | title={Rationale-Guided Retrieval Augmented Generation for Medical Question Answering}, 49 | author={Jiwoong Sohn and Yein Park and Chanwoong Yoon and Sihyeon Park and Hyeon Hwang and Mujeen Sung and Hyunjae Kim and Jaewoo Kang}, 50 | journal={arXiv preprint arXiv:2411.00300}, 51 | year={2024} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /classifier/data/medqa/llama3_cot/5%-train.json: -------------------------------------------------------------------------------- 1 | [{"id": "llama3_5%_23600", "answer": "[NOT_HELPFUL]", "dataset_name": "llama3_5%", "question": "Given the following evidence, determine whether it helps answer the provided question.\n\nEvidence: 1 specimen returned, 33 were positive for human metapneumovirus (6.1%) and 18 for human coronavirus NL63 (3.3%). Of all of the viruses for which we tested, human metapneumovirus and human coronavirus NL63 were most strongly linked to child care attendance, occurring in 82% and 78% of infected children, respectively. Picornaviruses were the most commonly identified virus group (269 [49.5%]). Influenza virus and adenovirus illnesses had the greatest impact, with fever in more than three quarters and requiring, on average, > 1 local doctor visit per illness. CONCLUSIONS: Recently identified human metapneumovirus and human coronavirus NL63 are important pathogens in community-based illness in children, particularly in those who attend child care. Picornaviruses were detected in half of the nose-throat swabs collected during acute respiratory illness in children but resulted\n\nQuestion: A 5-year-old boy presents with bilateral conjunctivitis and pharyngitis. The patient\u2019s mother says that symptoms acutely onset 3 days ago and include itchy red eyes, a low-grade fever, and a sore throat. She says that the patient recently attended a camp where other kids were also ill and were completely healthy before going. No significant past medical history. Which of the following is the most likely cause of this patient\u2019s symptoms? A) Metapneumovirus B) Influenza virus C) Rhinovirus D) Adenovirus"}, {"id": "llama3_5%_15069", "answer": "[NOT_HELPFUL]", "dataset_name": "llama3_5%", "question": "Given the following evidence, determine whether it helps answer the provided question.\n\nEvidence: poor intake. He was then admitted. A more comprehensive laboratory evaluation was initiated. During this hospital course, the patient's physical examination changed when he developed head and neck edema, and certain laboratory trends became clearer. With the assistance of several specialists, the team was able to reach a more definitive diagnosis and initiate treatment to appropriately manage his condition.\n\nQuestion: A 48-year-old man comes to the physician because of severe joint pain and swelling involving different joints for 3 months. He has also been having loose stools and episodes of epigastric pain for 6 months. He reports a 10-kg (22-lb) weight loss during this period. He has type 2 diabetes mellitus. He does not smoke or drink alcohol. His medications include insulin and metformin. 
His vital signs are within normal limits. Examination shows pale conjunctivae, angular cheilitis, and glossitis. Axillary and cervical lymphadenopathy is present. A grade 2/6 pansystolic murmur is heard best at the apex. The right knee is swollen and tender; range of motion is limited. The sacroiliac joints are tender. Test of the stool for occult blood is negative. Laboratory studies show:\nHemoglobin 9.2 g/dL\nMean corpuscular volume 90 \u03bcm3\nLeukocyte count 4,800/mm3\nSerum\nNa+ 134 mEq/L\nCl- 96 mEq/L\nK+ 3.3 mEq/L\nGlucose 143 mg/dL\nCreatinine 1.2 mg/dL\nA small intestine biopsy shows periodic acid-Schiff-positive (PAS-positive) macrophages in the lamina propria. Which of the following is the most appropriate next step in management?\" A) Oral doxycycline B) Gluten-free diet C) Oral rifampin D) Intravenous ceftriaxone"}, {"id": "llama3_5%_21569", "answer": "[HELPFUL]", "dataset_name": "llama3_5%", "question": "Given the following evidence, determine whether it helps answer the provided question.\n\nEvidence: right arm. At that time, her laboratory work was normal, and she was treated symptomatically. Finally, she was brought to the ED by her spouse due to lethargy, confusion, severe pain, and bloody urine. History taking was limited time as she had altered mental status and was writhing in pain. She was found to have abnormal laboratory test results: WBC 103.5 \u00d7 103/\u00b5L, hemoglobin 3.3 mg/dL, platelets 194 000/\u00b5L, international normalized ratio 1.35, D-dimer 2.22 mg/L, aspartate transaminase 3335 units/L and alanine transaminase 860 units/L, total bilirubin 9.8 mg/dL, creatinine 1.8 mg/dL, lactate dehydrogenase 3260 units/L, haptoglobin less than 10 mg/dL, fibrinogen 384 mg/dL, and lactic acid 20 mmol/L. Direct Coombs\u2019 (antiglobulin) test was positive for IgG (immunoglobulin G) and complement, and creatine kinase was 1200 U/L with\n\nQuestion: Following a recent myocardial infarction, a 60-year-old woman has been started on multiple medications at the time of discharge from the hospital. After 10 days of discharge, she presents to the emergency department with a history of fever, headache, and dark colored urine for 2 days. Her husband mentions that she has not passed urine for the last 24 hours. Her physical examination shows significant pallor, and multiple petechiae are present all over her limbs. Her vital signs include: temperature 38.9\u00b0C (102.0\u00b0F), pulse rate 94/min, blood pressure 124/82 mm Hg, and respiratory rate 16/min. Her sensorium is altered with the absence of spontaneous speech and spontaneous movements. She responds inappropriately to verbal stimuli. Her laboratory results show the presence of anemia and thrombocytopenia. Examination of peripheral blood smear shows the presence of schistocytes. Serum creatinine is 2 mg/dL. Serum levels of fibrinogen, fibrin monomers, fibrin degradation products and D-dimers are normal. Prothrombin time (PT) and activated partial thromboplastin time (aPTT) are normal. Which is the most likely treatment for this patient\u2019s condition? A) Renal dialysis B) Plasma exchange C) Intravenous immunoglobulin D) Rehydration"}, {"id": "llama3_5%_5692", "answer": "[HELPFUL]", "dataset_name": "llama3_5%", "question": "Given the following evidence, determine whether it helps answer the provided question.\n\nEvidence: hours after delivery. There is impaired motility of gallbladder, with delayed emptying; this leads to the development of biliary sludge and gallstones. 
Thus, pregnant women have progressive increase in residual volume of gallbladder throughout the pregnancy, which returns to normal volume shortly after delivery. Contrary to the popular belief, there is high incidence of biliary colic in pregnant women, which gets better with conservative management. If intervention is required, however, it is best done in the second trimester and is generally well tolerated.\n\nQuestion: A previously healthy 37-year-old woman, gravida 3, para 2, at 29 weeks' gestation comes to the physician because of colicky postprandial abdominal pain. Her vital signs are within normal limits. Physical examination shows a uterus consistent in size with a 29-week gestation. Ultrasonography of the abdomen shows multiple 5-mm hyperechoic masses within the gallbladder lumen. Which of the following processes is most likely involved in the pathogenesis of this patient's condition? A) Accelerated gallbladder emptying B) Increased secretion of bile acids C) Overproduction of bilirubin D) Increased secretion of cholesterol\n\""}, {"id": "llama3_5%_19869", "answer": "[NOT_HELPFUL]", "dataset_name": "llama3_5%", "question": "Given the following evidence, determine whether it helps answer the provided question.\n\nEvidence: She did not report any chest pain, shortness of breath, peripheral edema, palpitations, vision changes, or pain in the extremities. She had not experienced any recent trauma, upper respiratory or gastrointestinal tract infections. The patient had never experienced symptoms of neuropathy previously. She denied alcohol, tobacco, and illicit drug use. The patient\u2019s past medical history is significant for hypertension, hyperlipidemia, and obstructive sleep apnea. At the time of this encounter, she was taking 850 mg of metformin twice per day, 64 units of insulin degludec injected once daily, atorvastatin 80 mg once per day, and lisinopril-hydrochlorothiazide 10-12.5 mg once daily. This medication regimen had brought her blood sugar level to around 100 mg/dL on average. On physical exam, the patient was alert and oriented with a normal thought\n\nQuestion: A 54-year-old man with a past medical history significant for hypertension, type 2 diabetes, and chronic obstructive pulmonary disease presents with complaints of nausea and abdominal pain for the past month. The pain is located in the epigastric region and is described as \u201cburning\u201d in quality, often following food intake. The patient denies any changes in bowel movements, fever, or significant weight loss. Medications include metformin, lisinopril, hydrochlorothiazide, albuterol inhaler, and fluconazole for a recent fungal infection. Physical examination was unremarkable except for a mildly distended abdomen that is diffusely tender to palpation and decreased sensation at lower extremities bilaterally. A medication was started for the symptoms. Two days later, the patient reports heart palpitations. An EKG is shown below. Which of the following is the medication most likely prescribed? 
A) Erythromycin B) Metformin C) Omeprazole D) Ranitidine"}] -------------------------------------------------------------------------------- /classifier/data/preprocess.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/classifier/data/preprocess.py -------------------------------------------------------------------------------- /classifier/model/token_add.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "application/vnd.jupyter.widget-view+json": { 11 | "model_id": "db7cac56b52149a38278f66aefa5acaf", 12 | "version_major": 2, 13 | "version_minor": 0 14 | }, 15 | "text/plain": [ 16 | "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]" -------------------------------------------------------------------------------- /classifier/run_classifier.py: -------------------------------------------------------------------------------- require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") 76 | 77 | # You should update this to your particular problem to have better documentation of `model_type` 78 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 79 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 80 | 81 | ## 82 | try: 83 | nltk.data.find("tokenizers/punkt") 84 | except (LookupError, OSError): 85 | if is_offline_mode(): 86 | raise LookupError( 87 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 88 | ) 89 | with FileLock(".lock") as lock: 90 | nltk.download("punkt", quiet=True) 91 | 92 | option_to_label = { 93 | 'A': 0, 94 | 'B': 1, 95 | 'C': 2, 96 | } 97 | 98 | label_to_option = { 99 | 0: 'A', 100 | 1: 'B', 101 | 2: 'C', 102 | } 103 | 104 | def parse_args(): 105 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a QA task") 106 | parser.add_argument( 107 | "--dataset_name", 108 | type=str, 109 | default=None, 110 | help="The name of the dataset to use (via the datasets library).", 111 | ) 112 | parser.add_argument( 113 | "--dataset_config_name", 114 | type=str, 115 | default=None, 116 | help="The configuration name of the dataset to use (via the datasets library).", 117 | ) 118 | parser.add_argument( 119 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 120 | ) 121 | parser.add_argument( 122 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 123 | ) 124 | parser.add_argument( 125 | "--ignore_pad_token_for_loss", 126 | type=bool, 127 | default=True, 128 | help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.", 129 | ) 130 | parser.add_argument( 131 | "--max_seq_length", 132 | type=int, 133 | default=384, 134 | help=( 136 | "The maximum total input sequence length after tokenization. Sequences longer " 137 | "than this will be truncated, sequences shorter will be padded."
138 | ), 139 | ) 140 | parser.add_argument( 141 | "--source_prefix", 142 | type=str, 143 | default=None, 144 | help="A prefix to add before every source text (useful for T5 models).", 145 | ) 146 | 147 | parser.add_argument( 148 | "--preprocessing_num_workers", 149 | type=int, 150 | default=None, 151 | help="The number of processes to use for the preprocessing.", 152 | ) 153 | parser.add_argument("--do_eval", action="store_true", help="To do eval on the question answering model") 154 | parser.add_argument("--do_train", action="store_true", help="To do train on the question answering model") 155 | # data col 156 | parser.add_argument( 157 | "--train_column", 158 | type=str, 159 | default='train', 160 | help="The name of the train column in the datasets.", 161 | ) 162 | parser.add_argument( 163 | "--val_column", 164 | type=str, 165 | default='validation', 166 | help="The name of the validation column in the datasets.", 167 | ) 168 | parser.add_argument( 169 | "--test_column", 170 | type=str, 171 | default='test', 172 | help="The name of the test column in the datasets.", 173 | ) 174 | parser.add_argument( 175 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 176 | ) 177 | 178 | parser.add_argument( 179 | "--max_answer_length", 180 | type=int, 181 | default=30, 182 | help=( 183 | "The maximum length of an answer that can be generated. This is needed because the start " 184 | "and end predictions are not conditioned on one another." 185 | ), 186 | ) 187 | parser.add_argument( 188 | "--val_max_answer_length", 189 | type=int, 190 | default=None, 191 | help=( 192 | "The maximum total sequence length for validation " 193 | "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be " 194 | "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` " 195 | "param of ``model.generate``, which is used during ``evaluate`` and ``predict``." 196 | ), 197 | ) 198 | parser.add_argument( 199 | "--max_train_samples", 200 | type=int, 201 | default=None, 202 | help=( 203 | "For debugging purposes or quicker training, truncate the number of training examples to this " 204 | "value if set." 205 | ), 206 | ) 207 | parser.add_argument( 208 | "--max_eval_samples", 209 | type=int, 210 | default=None, 211 | help=( 212 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 213 | "value if set." 214 | ), 215 | ) 216 | 217 | parser.add_argument( 218 | "--num_beams", 219 | type=int, 220 | default=None, 221 | help=( 222 | "Number of beams to use for evaluation. This argument will be " 223 | "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``." 224 | ), 225 | ) 226 | parser.add_argument( 227 | "--pad_to_max_length", 228 | action="store_true", 229 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 230 | ) 231 | parser.add_argument( 232 | "--model_name_or_path", 233 | type=str, 234 | help="Path to pretrained model or model identifier from huggingface.co/models.", 235 | required=False, 236 | ) 237 | parser.add_argument( 238 | "--config_name", 239 | type=str, 240 | default=None, 241 | help="Pretrained config name or path if not the same as model_name", 242 | ) 243 | parser.add_argument( 244 | "--tokenizer_name", 245 | type=str, 246 | default=None, 247 | help="Pretrained tokenizer name or path if not the same as model_name", 248 | ) 249 | parser.add_argument( 250 | "--question_column", 251 | type=str, 252 | default='question', 253 | help="The name of the column in the datasets containing the questions (for question answering).", 254 | ) 255 | parser.add_argument( 256 | "--answer_column", 257 | type=str, 258 | default='answers', 259 | help="The name of the column in the datasets containing the answers (for question answering).", 260 | ) 261 | 262 | parser.add_argument( 263 | "--use_slow_tokenizer", 264 | action="store_true", 265 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 266 | ) 267 | parser.add_argument( 268 | "--per_device_train_batch_size", 269 | type=int, 270 | default=8, 271 | help="Batch size (per device) for the training dataloader.", 272 | ) 273 | parser.add_argument( 274 | "--per_device_eval_batch_size", 275 | type=int, 276 | default=8, 277 | help="Batch size (per device) for the evaluation dataloader.", 278 | ) 279 | parser.add_argument( 280 | "--learning_rate", 281 | type=float, 282 | default=5e-5, 283 | help="Initial learning rate (after the potential warmup period) to use.", 284 | ) 285 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 286 | parser.add_argument("--num_train_epochs", type=int, default=2, help="Total number of training epochs to perform.") 287 | parser.add_argument( 288 | "--max_train_steps", 289 | type=int, 290 | default=None, 291 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 292 | ) 293 | 294 | parser.add_argument( 295 | "--gradient_accumulation_steps", 296 | type=int, 297 | default=1, 298 | help="Number of updates steps to accumulate before performing a backward/update pass.", 299 | ) 300 | parser.add_argument( 301 | "--lr_scheduler_type", 302 | type=SchedulerType, 303 | default="linear", 304 | help="The scheduler type to use.", 305 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 306 | ) 307 | parser.add_argument( 308 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 309 | ) 310 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 311 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 312 | parser.add_argument( 313 | "--model_type", 314 | type=str, 315 | default=None, 316 | help="Model type to use if training from scratch.", 317 | choices=MODEL_TYPES, 318 | ) 319 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 320 | parser.add_argument( 321 | "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 
322 | ) 323 | parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") 324 | parser.add_argument( 325 | "--checkpointing_steps", 326 | type=str, 327 | default=None, 328 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 329 | ) 330 | parser.add_argument( 331 | "--resume_from_checkpoint", 332 | type=str, 333 | default=None, 334 | help="If the training should continue from a checkpoint folder.", 335 | ) 336 | parser.add_argument( 337 | "--with_tracking", 338 | action="store_true", 339 | help="Whether to enable experiment trackers for logging.", 340 | ) 341 | parser.add_argument( 342 | "--report_to", 343 | type=str, 344 | default="all", 345 | help=( 346 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 347 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' 348 | "Only applicable when `--with_tracking` is passed." 349 | ), 350 | ) 351 | parser.add_argument( 352 | "--doc_stride", 353 | type=int, 354 | default=128, 355 | help="When splitting up a long document into chunks how much stride to take between chunks.", 356 | ) 357 | args = parser.parse_args() 358 | 359 | return args 360 | 361 | 362 | def main(): 363 | args = parse_args() 364 | 365 | 366 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 367 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers 368 | # in the environment 369 | accelerator_log_kwargs = {} 370 | 371 | if args.with_tracking: 372 | accelerator_log_kwargs["log_with"] = args.report_to 373 | accelerator_log_kwargs["logging_dir"] = args.output_dir 374 | 375 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) 376 | 377 | device = accelerator.device 378 | 379 | if args.source_prefix is None and args.model_name_or_path in [ 380 | "t5-small", 381 | "t5-base", 382 | "t5-large", 383 | "t5-3b", 384 | "t5-11b", 385 | ]: 386 | logger.warning( 387 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 388 | "`--source_prefix 'summarize: ' `" 389 | ) 390 | 391 | 392 | # Make one log on every process with the configuration for debugging. 393 | # TODO 394 | # Setup logging 395 | logging.basicConfig( 396 | filename=args.output_dir+'/logs.log', # 397 | filemode='w', 398 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 399 | datefmt="%m/%d/%Y %H:%M:%S", 400 | level=logging.INFO, 401 | force=True 402 | ) 403 | 404 | #logger.info(accelerator.state, main_process_only=False) 405 | logger.info(accelerator.state) 406 | if accelerator.is_local_main_process: 407 | datasets.utils.logging.set_verbosity_warning() 408 | transformers.utils.logging.set_verbosity_info() 409 | else: 410 | datasets.utils.logging.set_verbosity_error() 411 | transformers.utils.logging.set_verbosity_error() 412 | 413 | # If passed along, set the training seed now. 414 | if args.seed is not None: 415 | set_seed(args.seed) 416 | 417 | logger.info(args) 418 | 419 | 420 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 421 | # download the dataset. 422 | if args.dataset_name is not None: 423 | # Downloading and loading a dataset from the hub. 
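When --dataset_name is given the split is pulled from the Hub; otherwise the local csv/json files passed via --train_file / --validation_file are loaded below, with the file extension selecting the datasets loader.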
424 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) 425 | else: 426 | data_files = {} 427 | if args.train_file is not None: 428 | data_files["train"] = args.train_file 429 | if args.validation_file is not None: 430 | data_files["validation"] = args.validation_file 431 | if args.do_eval: 432 | extension = args.validation_file.split(".")[-1] 433 | else: 434 | extension = args.train_file.split(".")[-1] 435 | raw_datasets = load_dataset(extension, data_files=data_files) 436 | 437 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 438 | # https://huggingface.co/docs/datasets/loading_datasets.html. 439 | 440 | 441 | # load model and tokenizer 442 | model, tokenizer = load_model(args) 443 | 444 | 445 | if args.do_train: 446 | if args.train_column not in raw_datasets: 447 | raise ValueError("--do_train requires a train dataset") 448 | train_dataset = raw_datasets[args.train_column] 449 | 450 | if args.max_train_samples is not None: 451 | # We will select sample from whole data if agument is specified 452 | train_dataset = train_dataset.select(range(args.max_train_samples)) 453 | 454 | 455 | # Create train feature from dataset 456 | with accelerator.main_process_first(): 457 | train_dataset = train_dataset.map( 458 | preprocess_features_function, 459 | fn_kwargs={'args':args, 'raw_datasets':raw_datasets, 'tokenizer': tokenizer}, 460 | batched=True, 461 | num_proc=args.preprocessing_num_workers, 462 | remove_columns=train_dataset.column_names, 463 | load_from_cache_file=not args.overwrite_cache, 464 | desc="Running tokenizer on train dataset", 465 | ) 466 | if args.max_train_samples is not None: 467 | # Number of samples might increase during Feature Creation, We select only specified max samples 468 | train_dataset = train_dataset.select(range(args.max_train_samples)) 469 | 470 | 471 | if args.do_eval: 472 | if args.val_column not in raw_datasets: 473 | raise ValueError("--do_eval requires a validation dataset") 474 | eval_examples = raw_datasets[args.val_column] 475 | 476 | if args.max_eval_samples is not None: 477 | # We will select sample from whole data 478 | eval_examples = eval_examples.select(range(args.max_eval_samples)) 479 | # Validation Feature Creation 480 | with accelerator.main_process_first(): 481 | eval_dataset = eval_examples.map( 482 | preprocess_features_function, 483 | fn_kwargs={'args':args, 'raw_datasets':raw_datasets, 'tokenizer': tokenizer}, 484 | batched=True, 485 | num_proc=args.preprocessing_num_workers, 486 | remove_columns=eval_examples.column_names, 487 | load_from_cache_file=not args.overwrite_cache, 488 | desc="Running tokenizer on validation dataset", 489 | ) 490 | 491 | if args.max_eval_samples is not None: 492 | # During Feature creation dataset samples might increase, we will select required samples again 493 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 494 | 495 | label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id 496 | 497 | data_collator = DataCollatorForSeq2Seq( 498 | tokenizer, 499 | model=model, 500 | label_pad_token_id=label_pad_token_id, 501 | pad_to_multiple_of=8 if accelerator.use_fp16 else None, 502 | ) 503 | 504 | if args.do_train: 505 | train_dataset_for_model = train_dataset.remove_columns(["example_id", "offset_mapping"]) 506 | train_dataloader = DataLoader( 507 | train_dataset_for_model, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 508 | ) 509 | 510 | if 
args.do_eval: 511 | eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"]) 512 | eval_dataloader = DataLoader( 513 | eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 514 | ) 515 | 516 | 517 | # Optimizer 518 | # Split weights in two groups, one with weight decay and the other not. 519 | no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"] 520 | optimizer_grouped_parameters = [ 521 | { 522 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 523 | "weight_decay": args.weight_decay, 524 | }, 525 | { 526 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 527 | "weight_decay": 0.0, 528 | }, 529 | ] 530 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 531 | 532 | 533 | # Prepare everything with our `accelerator`. 534 | model, optimizer = accelerator.prepare( 535 | model, optimizer 536 | ) 537 | 538 | if args.do_train: 539 | train_dataloader = accelerator.prepare( 540 | train_dataloader 541 | ) 542 | 543 | if args.do_eval: 544 | eval_dataloader = accelerator.prepare( 545 | eval_dataloader 546 | ) 547 | 548 | # Figure out how many steps we should save the Accelerator states 549 | checkpointing_steps = args.checkpointing_steps 550 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 551 | checkpointing_steps = int(checkpointing_steps) 552 | 553 | # We need to initialize the trackers we use, and also store our configuration. 554 | # The trackers initializes automatically on the main process. 555 | if args.with_tracking: 556 | experiment_config = vars(args) 557 | # TensorBoard cannot log Enums, need the raw value 558 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 559 | accelerator.init_trackers("no_trainer", experiment_config) 560 | 561 | # Train! 562 | if args.do_train: 563 | 564 | args.max_train_steps, args.num_train_epochs, lr_scheduler_train = prepare_scheduler(args, accelerator, train_dataloader, optimizer, args.max_train_steps, args.num_train_epochs) 565 | 566 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 567 | 568 | logger.info("***** Running training *****") 569 | logger.info(f" Num examples = {len(train_dataset)}") 570 | logger.info(f" Num Epochs = {args.num_train_epochs}") 571 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 572 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 573 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 574 | logger.info(f" Total optimization steps = {args.max_train_steps}") 575 | # Only show the progress bar once on each machine. 
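The bar advances once per optimizer step (see the accelerator.sync_gradients check below), so with gradient accumulation its total matches max_train_steps rather than the number of batches.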
576 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 577 | completed_steps = 0 578 | starting_epoch = 0 579 | 580 | # Potentially load in the weights and states from a previous save 581 | if args.resume_from_checkpoint: 582 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 583 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 584 | accelerator.load_state(args.resume_from_checkpoint) 585 | path = os.path.basename(args.resume_from_checkpoint) 586 | else: 587 | # Get the most recent checkpoint 588 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 589 | dirs.sort(key=os.path.getctime) 590 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 591 | # Extract `epoch_{i}` or `step_{i}` 592 | training_difference = os.path.splitext(path)[0] 593 | 594 | if "epoch" in training_difference: 595 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 596 | resume_step = None 597 | else: 598 | resume_step = int(training_difference.replace("step_", "")) 599 | starting_epoch = resume_step // len(train_dataloader) 600 | resume_step -= starting_epoch * len(train_dataloader) 601 | 602 | for epoch in range(starting_epoch, args.num_train_epochs): 603 | model.train() 604 | total_loss = 0 605 | for step, batch in enumerate(train_dataloader): 606 | # We need to skip steps until we reach the resumed step 607 | if args.resume_from_checkpoint and epoch == starting_epoch: 608 | if resume_step is not None and step < resume_step: 609 | completed_steps += 1 610 | continue 611 | 612 | with accelerator.accumulate(model): 613 | outputs = model(**batch) 614 | loss = outputs.loss 615 | accelerator.backward(loss) 616 | optimizer.step() 617 | lr_scheduler_train.step() 618 | optimizer.zero_grad() 619 | 620 | # logger.info("Loss:{} ".format(loss)) 621 | 622 | # We keep track of the loss at each epoch 623 | total_loss = total_loss + loss.cpu().detach().float() 624 | 625 | logger.info(tokenizer.batch_decode(batch["input_ids"][:1], skip_special_tokens=True)) 626 | 627 | # Checks if the accelerator has performed an optimization step behind the scenes 628 | if accelerator.sync_gradients: 629 | progress_bar.update(1) 630 | completed_steps += 1 631 | 632 | if isinstance(checkpointing_steps, int): 633 | if completed_steps % checkpointing_steps == 0: 634 | output_dir = f"step_{completed_steps }" 635 | if args.output_dir is not None: 636 | output_dir = os.path.join(args.output_dir, output_dir) 637 | accelerator.save_state(output_dir) 638 | 639 | if completed_steps >= args.max_train_steps: 640 | break 641 | 642 | logger.info("Epoch %d Loss:{} ".format(total_loss / len(train_dataloader)), epoch) 643 | 644 | if args.checkpointing_steps == "epoch": 645 | output_dir = f"epoch_{epoch}" 646 | if args.output_dir is not None: 647 | output_dir = os.path.join(args.output_dir, output_dir) 648 | accelerator.save_state(output_dir) 649 | 650 | if args.push_to_hub and epoch < args.num_train_epochs - 1: 651 | accelerator.wait_for_everyone() 652 | unwrapped_model = accelerator.unwrap_model(model) 653 | unwrapped_model.save_pretrained( 654 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 655 | ) 656 | 657 | if accelerator.is_main_process: 658 | tokenizer.save_pretrained(args.output_dir) 659 | repo.push_to_hub( 660 | commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True 661 | ) 662 | 663 | 664 | if 
args.output_dir is not None: 665 | accelerator.wait_for_everyone() 666 | unwrapped_model = accelerator.unwrap_model(model) 667 | unwrapped_model.save_pretrained( 668 | args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 669 | ) 670 | if accelerator.is_main_process: 671 | tokenizer.save_pretrained(args.output_dir) 672 | if args.push_to_hub: 673 | repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) 674 | 675 | 676 | 677 | # Validation 678 | if args.do_eval: 679 | logger.info("***** Running Validation *****") 680 | logger.info(f" Num examples = {len(eval_dataset)}") 681 | logger.info(f" Batch size = {args.per_device_eval_batch_size}") 682 | 683 | if args.val_max_answer_length is None: 684 | args.val_max_answer_length = args.max_answer_length 685 | 686 | gen_kwargs = { 687 | "max_length": args.val_max_answer_length, 688 | #'no_repeat_ngram_size':2 689 | #"num_beams": args.num_beams, 690 | } 691 | 692 | # inference 693 | model.eval() 694 | predictions = [] 695 | for step, batch in enumerate(eval_dataloader): 696 | with torch.no_grad(): 697 | 698 | scores = accelerator.unwrap_model(model).generate( 699 | input_ids=batch["input_ids"], 700 | attention_mask=batch["attention_mask"], 701 | return_dict_in_generate=True, 702 | output_scores=True, 703 | **gen_kwargs, 704 | ).scores[0] 705 | 706 | probs = ( 707 | torch.nn.functional.softmax( 708 | torch.stack([ 709 | scores[:, tokenizer('A').input_ids[0]], 710 | scores[:, tokenizer('B').input_ids[0]], 711 | scores[:, tokenizer('C').input_ids[0]], 712 | ]), dim=0, 713 | ).detach().cpu().numpy() 714 | ) 715 | 716 | preds_labels = np.argmax(probs, 0) 717 | preds = [label_to_option[pred] for pred in preds_labels] 718 | 719 | labels = batch["labels"] 720 | labels = accelerator.gather_for_metrics(labels) 721 | labels = labels.cpu().numpy() 722 | 723 | predictions = predictions + preds 724 | 725 | if args.ignore_pad_token_for_loss: 726 | # Replace -100 in the labels as we can't decode them. 
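(-100 is the ignore_index used for padded label positions in the loss, so it is mapped back to pad_token_id before tokenizer.batch_decode.)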
727 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 728 | 729 | 730 | logger.info('==========================================') 731 | logger.info(tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)) 732 | logger.info('Prediction : ') 733 | logger.info(preds) 734 | logger.info('Answer : ') 735 | logger.info(tokenizer.batch_decode(labels, skip_special_tokens=False)) 736 | 737 | 738 | gold_answers = eval_examples['answer'] 739 | 740 | dict_id_pred_results = {qid : {'prediction': pred, 'answer' : ans, 'dataset_name' : data} for qid, pred, ans, data in zip(eval_examples['id'], predictions, gold_answers, eval_examples['dataset_name'])} 741 | with open(os.path.join(args.output_dir, "dict_id_pred_results.json"), "w") as f: 742 | json.dump(dict_id_pred_results, f, indent=4) 743 | 744 | assert len(gold_answers) == len(predictions) 745 | 746 | 747 | final_acc_score = calculate_accuracy(gold_answers, predictions) 748 | final_eval_results = {'final_acc_score' : final_acc_score} 749 | 750 | logger.info(f"Evaluation metrics: {final_eval_results}") 751 | print(final_eval_results) 752 | 753 | with open(os.path.join(args.output_dir, "final_eval_results.json"), "w") as f: 754 | json.dump(final_eval_results, f) 755 | 756 | # Acc per class 757 | final_eval_results_perClass = calculate_accuracy_perClass(gold_answers, predictions) 758 | 759 | logger.info(f"Evaluation metrics per class: {final_eval_results_perClass}") 760 | print(final_eval_results_perClass) 761 | 762 | with open(os.path.join(args.output_dir, "final_eval_results_perClass.json"), "w") as f: 763 | json.dump(final_eval_results_perClass, f, indent=4) 764 | 765 | 766 | if __name__ == "__main__": 767 | main() 768 | 769 | 770 | 771 | 772 | 773 | -------------------------------------------------------------------------------- /classifier/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from transformers import ( 4 | CONFIG_MAPPING, 5 | MODEL_MAPPING, 6 | AutoConfig, 7 | AutoModelForSeq2SeqLM, 8 | AutoTokenizer, 9 | DataCollatorForSeq2Seq, 10 | SchedulerType, 11 | get_scheduler, 12 | ) 13 | import datasets 14 | import numpy as np 15 | import math 16 | 17 | from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint 18 | 19 | 20 | def load_model(args): 21 | # Load pretrained model and tokenizer 22 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 23 | # download model & vocab. 24 | if args.config_name: 25 | config = AutoConfig.from_pretrained(args.config_name) 26 | elif args.model_name_or_path: 27 | config = AutoConfig.from_pretrained(args.model_name_or_path) 28 | else: 29 | config = CONFIG_MAPPING[args.model_type]() 30 | logger.warning("You are instantiating a new config instance from scratch.") 31 | 32 | if args.tokenizer_name: 33 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) 34 | elif args.model_name_or_path: 35 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) 36 | else: 37 | raise ValueError( 38 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 39 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
40 | ) 41 | 42 | if args.model_name_or_path: 43 | model = AutoModelForSeq2SeqLM.from_pretrained( 44 | args.model_name_or_path, 45 | from_tf=bool(".ckpt" in args.model_name_or_path), 46 | config=config, 47 | ) 48 | else: 49 | logger.info("Training new model from scratch") 50 | model = AutoModelForSeq2SeqLM.from_config(config) 51 | 52 | return model, tokenizer 53 | 54 | 55 | def preprocess_dataset(args, raw_datasets): 56 | # Preprocessing the datasets. 57 | # First we tokenize all the texts. 58 | if args.do_eval: 59 | column_names = raw_datasets[args.val_column].column_names 60 | else : 61 | column_names = raw_datasets[args.train_column].column_names 62 | 63 | # Get the column names for input/target. 64 | question_column = args.question_column 65 | if question_column not in column_names: 66 | raise ValueError( 67 | f"--question_column' value '{args.question_column}' needs to be one of: {', '.join(column_names)}" 68 | ) 69 | 70 | answer_column = args.answer_column 71 | if answer_column not in column_names: 72 | raise ValueError( 73 | f"--answer_column' value '{args.answer_column}' needs to be one of: {', '.join(column_names)}" 74 | ) 75 | 76 | return question_column, answer_column 77 | 78 | 79 | def preprocess_features_function(examples, args, raw_datasets, tokenizer): 80 | question_column, answer_column = preprocess_dataset(args, raw_datasets) 81 | 82 | # Temporarily set max_answer_length for training. 83 | max_answer_length = args.max_answer_length 84 | padding = "max_length" if args.pad_to_max_length else False 85 | max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) 86 | 87 | 88 | 89 | # Some of the questions have lots of whitespace on the left, which is not useful and will make the 90 | # truncation of the context fail (the tokenized question will take a lots of space). So we remove that 91 | # left whitespace 92 | 93 | examples[question_column] = ['{}'.format(q.strip()) for q in examples[question_column]] 94 | 95 | 96 | # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results 97 | # in one example possible giving several features when a context is long, each of those features having a 98 | # context that overlaps a bit the context of the previous feature. 99 | model_inputs = tokenizer( 100 | examples[question_column], 101 | truncation=True, 102 | max_length=max_seq_length, 103 | stride=args.doc_stride, 104 | return_overflowing_tokens=True, 105 | return_offsets_mapping=True, 106 | padding=padding, 107 | ) 108 | 109 | targets = examples[answer_column] 110 | 111 | # Tokenize targets with the `text_target` keyword argument 112 | labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True) 113 | 114 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 115 | # padding in the loss. 116 | if padding == "max_length" and args.ignore_pad_token_for_loss: 117 | labels["input_ids"] = [ 118 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 119 | ] 120 | 121 | # Since one example might give us several features if it has a long context, we need a map from a feature to 122 | # its corresponding example. This key gives us just that. 123 | sample_mapping = model_inputs.pop("overflow_to_sample_mapping") 124 | 125 | # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the 126 | # corresponding example_id and we will store the offset mappings. 
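A long input can overflow into several tokenized features; overflow_to_sample_mapping (popped above) records which original example each feature came from, so the example id and label are simply repeated for every feature of that example.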
127 | model_inputs["example_id"] = [] 128 | # Augment the overflowing tokens to the labels 129 | labels_out = [] 130 | 131 | for i in range(len(model_inputs["input_ids"])): 132 | # One example can give several spans, this is the index of the example containing this span of text. 133 | sample_index = sample_mapping[i] 134 | model_inputs["example_id"].append(examples["id"][sample_index]) 135 | labels_out.append(labels["input_ids"][sample_index]) 136 | 137 | model_inputs["labels"] = labels_out 138 | return model_inputs 139 | 140 | 141 | # Post-processing: 142 | def post_processing_function( 143 | tokenizer, args, raw_datasets, examples: datasets.Dataset, features: datasets.Dataset, outputs, stage="eval" 144 | ): 145 | # Decode the predicted tokens. 146 | preds = outputs 147 | if isinstance(preds, tuple): 148 | preds = preds[0] 149 | # Replace -100s used for padding as we can't decode them 150 | preds = np.where(preds != -100, preds, tokenizer.pad_token_id) 151 | 152 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 153 | return decoded_preds 154 | 155 | 156 | 157 | # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor 158 | def create_and_fill_np_array(all_gen_tokens, dataset, max_len): 159 | """ 160 | Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor 161 | Args: 162 | all_gen_tokens(:obj:`tensor`): 163 | This is the output predictions of the model. 164 | eval_dataset: Evaluation dataset 165 | max_len(:obj:`int`): 166 | The maximum length of the output tensor. ( See the model.eval() part for more details ) 167 | """ 168 | 169 | step = 0 170 | # create a numpy array and fill it with -100. 171 | gen_toks_concat = np.full((len(dataset), max_len), -100) 172 | # Now since we have create an array now we will populate it with the outputs gathered using accelerator.gather_for_metrics 173 | for i, gen_tok in enumerate(all_gen_tokens): # populate columns 174 | # We have to fill it such that we have to take the whole tensor and replace it on the newly created array 175 | # And after every iteration we have to change the step 176 | #import pdb; pdb.set_trace() 177 | batch_size = gen_tok.shape[0] 178 | cols = gen_tok.shape[1] 179 | 180 | if step + batch_size < len(dataset): 181 | gen_toks_concat[step : step + batch_size, :cols] = gen_tok 182 | else: 183 | gen_toks_concat[step:, :cols] = gen_tok[: len(dataset) - step] 184 | 185 | step += batch_size 186 | 187 | return gen_toks_concat 188 | 189 | 190 | def prepare_scheduler(args, accelerator, dataloader, optimizer, max_train_steps, train_epoch): 191 | overrode_max_train_steps = False 192 | 193 | num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps) 194 | 195 | if max_train_steps is None: 196 | max_train_steps = train_epoch * num_update_steps_per_epoch 197 | overrode_max_train_steps = True 198 | 199 | lr_scheduler = get_scheduler( 200 | name=args.lr_scheduler_type, 201 | optimizer=optimizer, 202 | num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, 203 | num_training_steps=max_train_steps * args.gradient_accumulation_steps, 204 | ) 205 | 206 | lr_scheduler = accelerator.prepare(lr_scheduler) 207 | 208 | num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps) 209 | 210 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
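accelerator.prepare may shard the dataloader across processes, so num_update_steps_per_epoch is recomputed here and max_train_steps / train_epoch are rederived from the post-preparation length.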
211 | if overrode_max_train_steps: 212 | max_train_steps = train_epoch * num_update_steps_per_epoch 213 | # Afterwards we recalculate our number of training epochs 214 | train_epoch = math.ceil(max_train_steps / num_update_steps_per_epoch) 215 | 216 | return max_train_steps, train_epoch, lr_scheduler 217 | 218 | 219 | def get_gold_answers(example): 220 | """helper function that retrieves all possible true answers from a squad2.0 example""" 221 | 222 | gold_answers = [answer["text"] for answer in example.answers if answer["text"]] 223 | 224 | # if gold_answers doesn't exist it's because this is a negative example - 225 | # the only correct answer is an empty string 226 | if not gold_answers: 227 | gold_answers = [""] 228 | 229 | return gold_answers 230 | 231 | def calculate_accuracy(gold_answers, predictions): 232 | total_acc_score = 0 233 | for (gold_answer, prediction) in zip(gold_answers, predictions): 234 | acc_score = int(gold_answer == prediction) 235 | total_acc_score = total_acc_score + acc_score 236 | 237 | final_acc_score = (total_acc_score / len(gold_answers)) * 100 238 | return final_acc_score 239 | 240 | def calculate_accuracy_perClass(gold_answers, predictions): 241 | a_total_acc_score = 0 242 | b_total_acc_score = 0 243 | c_total_acc_score = 0 244 | 245 | a_gold_num = len([i for i in gold_answers if i == 'A']) 246 | b_gold_num = len([i for i in gold_answers if i == 'B']) 247 | c_gold_num = len([i for i in gold_answers if i == 'C']) 248 | 249 | a_pred_num = len([i for i in predictions if i == 'A']) 250 | b_pred_num = len([i for i in predictions if i == 'B']) 251 | c_pred_num = len([i for i in predictions if i == 'C']) 252 | 253 | for (gold_answer, prediction) in zip(gold_answers, predictions): 254 | # a 255 | a_acc_score = int(gold_answer == prediction == 'A') 256 | a_total_acc_score = a_total_acc_score + a_acc_score 257 | # b 258 | b_acc_score = int(gold_answer == prediction == 'B') 259 | b_total_acc_score = b_total_acc_score + b_acc_score 260 | # c 261 | c_acc_score = int(gold_answer == prediction == 'C') 262 | c_total_acc_score = c_total_acc_score + c_acc_score 263 | 264 | 265 | a_final_acc_score = (a_total_acc_score / a_gold_num) * 100 if a_gold_num != 0 else -1 266 | b_final_acc_score = (b_total_acc_score / b_gold_num) * 100 if b_gold_num != 0 else -1 267 | c_final_acc_score = (c_total_acc_score / c_gold_num) * 100 if c_gold_num != 0 else -1 268 | 269 | dict_final = {'A (zero) acc' : a_final_acc_score, 'B (single) acc' : b_final_acc_score, 'C (multi) acc' : c_final_acc_score, 270 | 'A (zero) pred num' : a_pred_num, 'B (single) pred num' : b_pred_num, 'C (multi) pred num' : c_pred_num, 271 | 'A (zero) gold num' : a_gold_num, 'B (single) gold num' : b_gold_num, 'C (multi) gold num' : c_gold_num} 272 | return dict_final 273 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/main.py -------------------------------------------------------------------------------- /retriever/README.md: -------------------------------------------------------------------------------- 1 | 1. Save embeddings and articles into their corresponding folders: 2 | ex) for pmc embeddings, save PMC_Abs_Embeds.npy, PMC_Main_Embeds.npy in embeddings/pmc 3 | 4 | 2. Navigate to the directory where main.py is located. 5 | 6 | 3. 
Execute the Python script main.py using the following command: 7 | ``` 8 | python main.py 9 | ``` 10 | 11 | 4. For PubMed, we grouped the 38 embedding chunks into four subgroups of 10, 10, 10, and 8 chunks. 12 | We retrieved 10 evidence snippets from each subgroup using MIPS, for a total of 40. 13 | For PMC, CPG, and textbook, we retrieved 10 snippets from each corpus. 14 | The resulting 70 snippets were then reranked to obtain the final 10. 15 | The number of PubMed chunks grouped per step can be adjusted using the --pubmed_group_num argument. 16 | 17 | 5. Following MedCPT (Jin et al., 2023), we used the SciSpacy en_core_sci_scibert model to insert [SEP] tokens between sentences when encoding queries and articles for MIPS, but not for the reranker. 18 | This functionality is optional and can be enabled using the --use_spacy argument; a sketch of the overall retrieval flow is shown below. 19 |
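20 | 6. For orientation, the balanced retrieval in main.py boils down to a per-corpus MIPS search followed by pooling and reranking. The sketch below is illustrative only: the `search_corpus` / `balanced_retrieve` helpers and the `indexes` dict are assumptions, not the repository's exact code.
```
import faiss
import numpy as np

def search_corpus(index: faiss.Index, xq: np.ndarray, k: int = 10) -> np.ndarray:
    """MIPS search over one corpus: top-k snippet ids per query."""
    _, I = index.search(xq, k)  # inner-product (MIPS) search
    return I

def balanced_retrieve(indexes: dict, xq: np.ndarray, k: int = 10) -> list:
    """Retrieve k snippets per corpus so no single corpus dominates the pool."""
    per_query = [[] for _ in range(len(xq))]
    for name, index in indexes.items():  # e.g. {"pubmed": ..., "pmc": ..., "cpg": ..., "textbook": ...}
        for qi, ids in enumerate(search_corpus(index, xq, k)):
            per_query[qi].extend((name, int(i)) for i in ids)
    return per_query  # len(corpora) * k candidates per query; reranked afterwards
```
-------------------------------------------------------------------------------- /retriever/articles/cpg/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/articles/cpg/.gitkeep -------------------------------------------------------------------------------- /retriever/articles/pmc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/articles/pmc/.gitkeep -------------------------------------------------------------------------------- /retriever/articles/pubmed/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/articles/pubmed/.gitkeep -------------------------------------------------------------------------------- /retriever/articles/textbook/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/articles/textbook/.gitkeep -------------------------------------------------------------------------------- /retriever/embeddings/cpg/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/embeddings/cpg/.gitkeep -------------------------------------------------------------------------------- /retriever/embeddings/pmc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/embeddings/pmc/.gitkeep -------------------------------------------------------------------------------- /retriever/embeddings/pubmed/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/embeddings/pubmed/.gitkeep -------------------------------------------------------------------------------- /retriever/embeddings/textbook/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/embeddings/textbook/.gitkeep -------------------------------------------------------------------------------- /retriever/input/.gitkeep: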
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/input/.gitkeep -------------------------------------------------------------------------------- /retriever/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import faiss 5 | import argparse 6 | import scispacy 7 | import spacy 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModel 10 | import numpy as np 11 | import query_encode as qe 12 | import retrieve as rt 13 | import rerank as rr 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument('-e', '--embeddings_dir', help='embeddings directory', default='embeddings') 20 | parser.add_argument('-a', '--articles_dir', help='articles directory', default='articles') 21 | parser.add_argument('-i', '--input_path', help='input file path', default='/retriever/input/medqa/medqa_llama_cot.json') 22 | parser.add_argument('-c', '--corpus', help='corpus to use', default=['cpg', 'textbook', 'pmc', 'pubmed']) 23 | parser.add_argument('-k', '--top_k', help='number of retrieved documents', default=100, type=int) 24 | parser.add_argument('-inst', '--instruction_preprocess', help='preprocess the query to retrieve documents for the instruction (training) set', default='False') 25 | parser.add_argument('-o', '--output_path', help='output file path', default='/retriever/output/medqa/evidence_medqa_llama_cot.json') 26 | parser.add_argument('-spc', '--use_spacy', help='use scispacy to insert [SEP] token between sentences', default='False') 27 | parser.add_argument('-pmdn', '--pubmed_group_num', help='number of chunks of pubmed to concatenate for each step', default=38, type=int) 28 | 29 | args = parser.parse_args() 30 | 31 | embeddings_dir = args.embeddings_dir 32 | articles_dir = args.articles_dir 33 | input_path = args.input_path 34 | output_path = args.output_path 35 | corpus = args.corpus 36 | top_k = args.top_k 37 | inst = args.instruction_preprocess.lower() == 'true' # argparse passes these flags in as strings; normalize to bool 38 | use_spacy = args.use_spacy.lower() == 'true' 39 | pubmed_group_num = args.pubmed_group_num 40 | 41 | if inst: 42 | # query preprocess for instruction_set 43 | input_list = qe.query_preprocess_instruction(input_path, use_spacy = use_spacy) 44 | else: 45 | with open(input_path, 'r') as input_file: 46 | input_list = json.load(input_file) 47 | xq = qe.query_encode(input_list) 48 | 49 | # pubmed mips 50 | pubmed_I_array = [] 51 | for start_index in range(0, 38, pubmed_group_num): 52 | pubmed_index = rt.pubmed_index_create(pubmed_embeddings_dir=os.path.join(embeddings_dir, "pubmed"), start_index=start_index, pubmed_group_num=pubmed_group_num) 53 | pubmed_I_array_temp = [] 54 | splits = [i for i in range(0, len(xq), 1024)] # query batches of 1024; reused for the other corpora below 55 | 56 | for split_start in tqdm(splits, desc=f"PubMed FAISS MIPS {start_index}:"): 57 | D, I = pubmed_index.search(xq[split_start:split_start+1024], top_k) 58 | pubmed_I_array_temp.extend(I) 59 | pubmed_I_array.append(pubmed_I_array_temp) 60 | del pubmed_index 61 | print(len(pubmed_I_array), "x", len(pubmed_I_array[0])) 62 | # pubmed mips index save 63 | # np.save("PubMed_I_array.npy", pubmed_I_array) 64 | 65 | # pubmed decode 66 | pubmed_evidences = rt.pubmed_decode(pubmed_I_array, pubmed_articles_dir= os.path.join(articles_dir, "pubmed"), pubmed_group_num=pubmed_group_num) 67 | print(len(pubmed_evidences), "x", len(pubmed_evidences[0])) 68 | 69 | 70 | # pmc mips 71 | pmc_index
67 |     # pmc mips
68 |     pmc_index = rt.pmc_index_create(pmc_embeddings_dir=os.path.join(embeddings_dir, "pmc"))
69 |     pmc_I_array = []
70 |     for i in tqdm(splits, desc="PMC FAISS MIPS"):
71 |         D, I = pmc_index.search(xq[i:i+1024], top_k)
72 |         pmc_I_array.extend(I)
73 |     del pmc_index
74 | 
75 |     # pmc mips index save
76 |     # np.save("PMC_I_array.npy", pmc_I_array)
77 | 
78 |     # decode pmc
79 |     pmc_evidences = rt.pmc_decode(pmc_I_array, pmc_articles_dir=os.path.join(articles_dir, "pmc"))
80 | 
81 |     # cpg mips
82 |     cpg_index = rt.cpg_index_create(cpg_embeddings_dir=os.path.join(embeddings_dir, "cpg"))
83 |     cpg_I_array = []
84 |     for i in tqdm(splits, desc="CPG FAISS MIPS"):
85 |         D, I = cpg_index.search(xq[i:i+1024], top_k)
86 |         cpg_I_array.extend(I)
87 |     del cpg_index
88 | 
89 |     # cpg mips index save
90 |     # np.save("CPG_I_array.npy", cpg_I_array)
91 | 
92 |     # decode cpg
93 |     cpg_evidences = rt.cpg_decode(cpg_I_array, cpg_articles_dir=os.path.join(articles_dir, "cpg"))
94 | 
95 |     # textbook mips
96 |     textbook_index = rt.textbook_index_create(textbook_embeddings_dir=os.path.join(embeddings_dir, "textbook"))
97 |     textbook_I_array = []
98 |     for i in tqdm(splits, desc="textbook FAISS MIPS"):
99 |         D, I = textbook_index.search(xq[i:i+1024], top_k)
100 |         textbook_I_array.extend(I)
101 |     del textbook_index
102 | 
103 |     # textbook mips index save
104 |     # np.save("Textbook_I_array.npy", textbook_I_array)
105 | 
106 |     # decode textbook
107 |     textbook_evidences = rt.textbook_decode(textbook_I_array, textbook_articles_dir=os.path.join(articles_dir, "textbook"))
108 | 
109 |     # rerank evidences from the 4 corpora
110 |     query_evidences, evidences = rr.combine_query_evidence(input_list, pubmed_evidences, pmc_evidences, cpg_evidences, textbook_evidences)
111 | 
112 |     # save the top_k reranked evidences
113 |     reranked_evidences = rr.rerank(query_evidences, evidences, top_k)
114 |     with open(output_path, 'w') as jsfile:
115 |         json.dump(reranked_evidences, jsfile)
116 | 
117 | 
118 | if __name__ == "__main__":
119 |     main()
--------------------------------------------------------------------------------
/retriever/output/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmis-lab/RAG2/911af780456469d2483cf09f38e6c8c94539d646/retriever/output/.gitkeep
--------------------------------------------------------------------------------
/retriever/query_encode.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModel
3 | from tqdm import tqdm
4 | import numpy as np
5 | import scispacy
6 | import spacy
7 | import json
8 | 
9 | def query_preprocess_instruction(input_path, use_spacy=True):
10 |     # Using spaCy to split the query into sentences and insert [SEP] tokens
11 |     # between them is time-intensive, so it is left as a user design choice.
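    # Example (editor's illustration with made-up values): a record such as
    #   {"instruction": "Answer the question. Choose one option.",
    #    "input": "A 5-year-old boy presents with ..."}
    # becomes, with use_spacy=True, a single query string like
    #   "Answer the question. [SEP] Choose one option. [SEP] A 5-year-old boy presents with ..."
    # and, with use_spacy=False, simply instruction + ' ' + input.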
12 |     with open(input_path, 'r') as jsfile:
13 |         input_data = json.load(jsfile)
14 |     input_instruction = [i['instruction'] for i in input_data]
15 |     input_input = [i['input'] for i in input_data]
16 | 
17 |     if use_spacy:
18 |         nlp = spacy.load("en_core_sci_scibert")
19 | 
20 |         split_data_instruction = []
21 |         for instruction in tqdm(input_instruction):
22 |             split_data_instruction.append(nlp(instruction))
23 | 
24 |         split_data_input = []
25 |         for text in tqdm(input_input):
26 |             split_data_input.append(nlp(text))
27 | 
28 |         query_list = []
29 |         for inst_idx, inst in enumerate(split_data_instruction):
30 |             query = ""
31 |             for inst_text in split_data_instruction[inst_idx].sents:
32 |                 if len(inst_text.text) == 1:  # skip stray single-character "sentences"
33 |                     continue
34 |                 query += inst_text.text + " [SEP] "
35 |             for input_idx, input_text in enumerate(split_data_input[inst_idx].sents):
36 |                 if len(input_text.text) == 1:
37 |                     continue
38 |                 elif input_idx == len(list(split_data_input[inst_idx].sents)) - 1:
39 |                     query += input_text.text  # no trailing [SEP] after the last sentence
40 |                 else:
41 |                     query += input_text.text + " [SEP] "
42 |             query_list.append(query)
43 |     else:
44 |         query_list = []
45 |         for inst_idx, inst in enumerate(input_instruction):
46 |             query = inst + ' ' + input_input[inst_idx]
47 |             query_list.append(query)
48 |     return query_list
49 | 
50 | 
51 | def query_encode(input_list):
52 |     model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
53 |     model.eval()
54 |     # use the default CUDA device when available instead of a hard-coded gpu id
55 |     device = "cuda" if torch.cuda.is_available() else "cpu"
56 |     model = model.to(device)
57 |     tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
58 | 
59 |     queries = []
60 |     # queries are encoded one at a time; raise the step size here to batch them
61 |     splits = list(range(0, len(input_list), 1))
62 |     for i in tqdm(splits, desc="query encoding"):
63 |         split_queries = input_list[i:i+1]
64 |         with torch.no_grad():
65 |             encoded = tokenizer(
66 |                 split_queries,
67 |                 truncation=True,
68 |                 padding=True,
69 |                 return_tensors='pt',
70 |                 max_length=512,
71 |             )
72 |             encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
73 |             embeds = model(**encoded).last_hidden_state[:, 0, :]  # [CLS] embedding
74 |             query_embeddings = embeds.detach().cpu().numpy()
75 |             queries.extend(query_embeddings)
76 |     xq = np.vstack(queries)
77 |     return xq
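# Usage sketch (editor's note; the path is illustrative):
#
#   queries = query_preprocess_instruction("input/medqa/train.json", use_spacy=False)
#   xq = query_encode(queries)   # float32 array of shape (len(queries), 768)
#
# The 768-d [CLS] vectors match the MedCPT article embeddings that the FAISS
# indices in retrieve.py are built over.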
--------------------------------------------------------------------------------
/retriever/rerank.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
4 | 
5 | def combine_query_evidence(queries, *evidence_lists):
6 |     # accept any number of per-corpus evidence lists (main.py passes four);
7 |     # the previous fixed five-list signature did not match that call site
8 |     evidences = [sum(per_corpus, []) for per_corpus in zip(*evidence_lists)]
9 |     q_a_list = []
10 |     for ith, q in tqdm(enumerate(queries)):
11 |         q_a_list.append([[q, a] for a in evidences[ith]])
12 |     return q_a_list, evidences
13 | 
14 | def rerank(q_a_list, evidences, top_k):
15 |     # fall back to CPU so the module does not crash on machines without CUDA
16 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |     tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Cross-Encoder")
18 |     model = AutoModelForSequenceClassification.from_pretrained("ncbi/MedCPT-Cross-Encoder")
19 |     model.eval()
20 |     model = model.to(device)
21 | 
22 |     logits_list = []
23 |     for q_a_pairs in tqdm(q_a_list):
24 |         with torch.no_grad():
25 |             encoded = tokenizer(
26 |                 q_a_pairs,
27 |                 truncation=True,
28 |                 padding=True,
29 |                 return_tensors="pt",
30 |                 max_length=512,
31 |             )
32 |             encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
33 |             logits = model(**encoded).logits.squeeze(dim=1)
34 |             logits_list.append(logits.detach().cpu())
35 | 
36 |     # rank every query's pooled evidences by cross-encoder score, keep the top_k
37 |     sorted_indices = [sorted(range(len(logits)), key=lambda k: logits[k], reverse=True) for logits in logits_list]
38 |     top_k_indices = [indices[:top_k] for indices in sorted_indices]
39 |     sorted_evidence_list = []
40 |     for index, data in enumerate(evidences):
41 |         sorted_evidence_list.append([data[i] for i in top_k_indices[index]])
42 | 
43 |     return sorted_evidence_list
--------------------------------------------------------------------------------
/retriever/retrieve.py:
--------------------------------------------------------------------------------
1 | import faiss
2 | import json
3 | import numpy as np
4 | import os
5 | from tqdm import tqdm
6 | 
7 | def pubmed_index_create(pubmed_embeddings_dir, start_index, pubmed_group_num):
8 |     pubmed_index = faiss.IndexFlatIP(768)
9 |     for i in tqdm(range(start_index, min(38, start_index+pubmed_group_num)), desc="pubmed load and add", dynamic_ncols=True):
10 |         embeds_chunk_path = f"{pubmed_embeddings_dir}/PubMed_Embeds_{i}.npy"
11 |         embeds_chunk = np.load(embeds_chunk_path)
12 |         embeds_chunk = embeds_chunk.astype(np.float32)
13 |         pubmed_index.add(embeds_chunk)
14 |         del embeds_chunk
15 |     return pubmed_index
16 | 
17 | def pmc_index_create(pmc_embeddings_dir):
18 |     pmc_filename = ["PMC_Main_Embeds.npy", "PMC_Abs_Embeds.npy"]
19 |     pmc_index = faiss.IndexFlatIP(768)
20 |     for i in tqdm(pmc_filename, desc="pmc load and add", dynamic_ncols=True):
21 |         embeddings = np.load(os.path.join(pmc_embeddings_dir, i))
22 |         embeddings = embeddings.astype(np.float32)
23 |         pmc_index.add(embeddings)
24 |         del embeddings
25 |     return pmc_index
26 | 
27 | def cpg_index_create(cpg_embeddings_dir):
28 |     cpg_index = faiss.IndexFlatIP(768)
29 |     with tqdm(total=1, desc="cpg load and add", dynamic_ncols=True) as pbar:
30 |         embeddings = np.load(os.path.join(cpg_embeddings_dir, "CPG_Total_Embeds.npy"))
31 |         embeddings = embeddings.astype(np.float32)
32 |         cpg_index.add(embeddings)
33 |         del embeddings
34 |         pbar.update(1)
35 |     return cpg_index
36 | 
37 | def textbook_index_create(textbook_embeddings_dir):
38 |     textbook_index = faiss.IndexFlatIP(768)
39 |     with tqdm(total=1, desc="textbook load and add", dynamic_ncols=True) as pbar:
40 |         embeddings = np.load(os.path.join(textbook_embeddings_dir, "Textbook_Total_Embeds.npy"))
41 |         embeddings = embeddings.astype(np.float32)
42 |         textbook_index.add(embeddings)
43 |         del embeddings
44 |         pbar.update(1)
45 |     return textbook_index
46 | 
47 | 
48 | def statpearls_index_create(statpearls_embeddings_dir):
49 |     statpearls_index = faiss.IndexFlatIP(768)
50 |     with tqdm(total=1, desc="statpearls load and add", dynamic_ncols=True) as pbar:
51 |         embeddings = np.load(os.path.join(statpearls_embeddings_dir, "Statpearls_Total_Embeds.npy"))
52 |         embeddings = embeddings.astype(np.float32)
53 |         statpearls_index.add(embeddings)
54 |         del embeddings
55 |         pbar.update(1)
56 |     return statpearls_index
57 | 
58 | 
59 | def find_value_by_index(articles, target_index):
60 |     return articles[target_index]
61 | 
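# Usage sketch (editor's note; directory names follow the patterns above):
#
#   index = cpg_index_create("embeddings/cpg")   # IndexFlatIP over 768-d vectors
#   D, I = index.search(xq[:1024], 100)          # inner-product top-100 per query
#
# Inner-product search over the MedCPT embeddings is the "MIPS" step that
# main.py refers to; the integer ids in I are positions in the concatenated
# article lists, which the *_decode helpers below map back to text.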
62 | def pubmed_decode(pubmed_I_array, pubmed_articles_dir, pubmed_group_num):
63 |     def combine_articles(pubmed_articles_dir, start_index, pubmed_group_num):
64 |         pubmed_articles = []
65 |         for i in tqdm(range(start_index, min(38, start_index+pubmed_group_num)), desc="articles load and add", dynamic_ncols=True):
66 |             with open(pubmed_articles_dir + f"/PubMed_Articles_{i}.json", 'r') as article_chunk:
67 |                 pubmed_articles.extend(json.load(article_chunk))
68 |         return pubmed_articles
69 | 
70 |     pubmed_evidences = []
71 |     for start_index in range(0, 38, pubmed_group_num):
72 |         # reload the same chunk range that built the FAISS index for this pass,
73 |         # so the returned ids line up with the concatenated article list
74 |         pubmed_articles = combine_articles(pubmed_articles_dir, start_index, pubmed_group_num)
75 |         pubmed_evidences_temp = []
76 |         for indices in tqdm(pubmed_I_array[start_index//pubmed_group_num], desc="decode and add", dynamic_ncols=True):
77 |             evidence_list = [find_value_by_index(pubmed_articles, target_index) for target_index in indices]
78 |             pubmed_evidences_temp.append(evidence_list)
79 |         pubmed_evidences.append(pubmed_evidences_temp)
80 | 
81 |     # flatten the per-pass results back into one evidence list per query
82 |     pubmed_evidences_flat = []
83 |     for subtuple in zip(*pubmed_evidences):
84 |         group = []
85 |         for sublist in subtuple:
86 |             group.extend(sublist)
87 |         pubmed_evidences_flat.append(group)
88 | 
89 |     return pubmed_evidences_flat
90 | 
91 | def pmc_decode(pmc_I_array, pmc_articles_dir):
92 |     def load_articles(pmc_articles_dir):
93 |         pmc_articles = []
94 |         for i in ["PMC_Main_Articles.json", "PMC_Abs_Articles.json"]:
95 |             with open(os.path.join(pmc_articles_dir, i), 'r') as jsfile:
96 |                 pmc_articles.extend(json.load(jsfile))
97 |         return pmc_articles
98 | 
99 |     pmc_articles = load_articles(pmc_articles_dir)
100 |     pmc_evidences = []
101 |     for indices in tqdm(pmc_I_array, desc="decode and add", dynamic_ncols=True):
102 |         evidence_list = [find_value_by_index(pmc_articles, j) for j in indices]
103 |         pmc_evidences.append(evidence_list)
104 |     return pmc_evidences
105 | 
106 | 
107 | def cpg_decode(cpg_I_array, cpg_articles_dir):
108 |     def load_articles(cpg_articles_dir):
109 |         with open(os.path.join(cpg_articles_dir, 'CPG_Total_Articles.json'), 'r') as jsfile:
110 |             cpg_articles = json.load(jsfile)
111 |         return cpg_articles
112 | 
113 |     cpg_articles = load_articles(cpg_articles_dir)
114 |     cpg_evidences = []
115 |     for indices in tqdm(cpg_I_array, desc="decode and add", dynamic_ncols=True):
116 |         evidence_list = [find_value_by_index(cpg_articles, j) for j in indices]
117 |         cpg_evidences.append(evidence_list)
118 |     return cpg_evidences
119 | 
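# Note (editor's): pmc_index_create adds the Main and Abs embedding files in a
# fixed order, and pmc_decode concatenates the article files in the same order,
# so a FAISS id past the end of the Main articles falls through to the Abs
# section. Keeping these two orderings in sync is what makes the id-to-article
# lookup in find_value_by_index correct.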
120 | def textbook_decode(textbook_I_array, textbook_articles_dir):
121 |     def load_articles(textbook_articles_dir):
122 |         with open(os.path.join(textbook_articles_dir, "Textbook_Total_Articles.json"), 'r') as jsfile:
123 |             textbook_articles = json.load(jsfile)
124 |         return textbook_articles
125 | 
126 |     textbook_articles = load_articles(textbook_articles_dir)
127 |     textbook_evidences = []
128 |     for indices in tqdm(textbook_I_array, desc="decode and add", dynamic_ncols=True):
129 |         evidence_list = [find_value_by_index(textbook_articles, j) for j in indices]
130 |         textbook_evidences.append(evidence_list)
131 |     return textbook_evidences
132 | 
133 | 
134 | def statpearls_decode(statpearls_I_array, statpearls_articles_dir):
135 |     def load_articles(statpearls_articles_dir):
136 |         with open(os.path.join(statpearls_articles_dir, "Statpearls_Total_Articles.json"), 'r') as jsfile:
137 |             statpearls_articles = json.load(jsfile)
138 |         return statpearls_articles
139 | 
140 |     statpearls_articles = load_articles(statpearls_articles_dir)
141 |     statpearls_evidences = []
142 |     for indices in tqdm(statpearls_I_array, desc="decode and add", dynamic_ncols=True):
143 |         evidence_list = [find_value_by_index(statpearls_articles, j) for j in indices]
144 |         statpearls_evidences.append(evidence_list)
145 |     return statpearls_evidences
--------------------------------------------------------------------------------
/retriever/retriever.yml:
--------------------------------------------------------------------------------
1 | name: retriever
2 | channels:
3 |   - pytorch
4 |   - nvidia
5 |   - defaults
6 | dependencies:
7 |   - _libgcc_mutex=0.1=main
8 |   - _openmp_mutex=5.1=1_gnu
9 |   - blas=1.0=mkl
10 |   - bzip2=1.0.8=h7b6447c_0
11 |   - ca-certificates=2023.12.12=h06a4308_0
12 |   - cuda-cudart=12.1.105=0
13 |   - cuda-cupti=12.1.105=0
14 |   - cuda-libraries=12.1.0=0
15 |   - cuda-nvrtc=12.1.105=0
16 |   - cuda-nvtx=12.1.105=0
17 |   - cuda-opencl=12.3.52=0
18 |   - cuda-runtime=12.1.0=0
19 |   - cudatoolkit=11.4.1=h8ab8bb3_9
20 |   - filelock=3.9.0=py310h06a4308_0
21 |   - gmp=6.2.1=h295c915_3
22 |   - gmpy2=2.1.2=py310heeb90bb_0
23 |   - intel-openmp=2021.4.0=h06a4308_3561
24 |   - jinja2=3.1.2=py310h06a4308_0
25 |   - ld_impl_linux-64=2.38=h1181459_1
26 |   - libcublas=12.1.0.26=0
27 |   - libcufft=11.0.2.4=0
28 |   - libcufile=1.8.0.34=0
29 |   - libcurand=10.3.4.52=0
30 |   - libcusolver=11.4.4.55=0
31 |   - libcusparse=12.0.2.55=0
32 |   - libfaiss=1.7.4=h13c3c6d_0_cuda11.4
33 |   - libffi=3.4.4=h6a678d5_0
34 |   - libgcc-ng=11.2.0=h1234567_1
35 |   - libgomp=11.2.0=h1234567_1
36 |   - libnpp=12.0.2.50=0
37 |   - libnvjitlink=12.1.105=0
38 |   - libnvjpeg=12.1.1.14=0
39 |   - libstdcxx-ng=11.2.0=h1234567_1
40 |   - libuuid=1.41.5=h5eee18b_0
41 |   - llvm-openmp=14.0.6=h9e868ea_0
42 |   - mkl=2021.4.0=h06a4308_640
43 |   - mkl-service=2.4.0=py310h7f8727e_0
44 |   - mkl_fft=1.3.1=py310hd6ae3a3_0
45 |   - mkl_random=1.2.2=py310h00e6091_0
46 |   - mpc=1.1.0=h10f8cd9_1
47 |   - mpfr=4.0.2=hb69a4c5_1
48 |   - mpmath=1.3.0=py310h06a4308_0
49 |   - ncurses=6.4=h6a678d5_0
50 |   - networkx=3.1=py310h06a4308_0
51 |   - numpy=1.24.3=py310hd5efca6_0
52 |   - numpy-base=1.24.3=py310h8e6c178_0
53 |   - openssl=3.0.12=h7f8727e_0
54 |   - pip=23.3=py310h06a4308_0
55 |   - python=3.10.13=h955ad1f_0
56 |   - pytorch=2.1.0=py3.10_cuda12.1_cudnn8.9.2_0
57 |   - pytorch-cuda=12.1=ha16c6d3_5
58 |   - pytorch-mutex=1.0=cuda
59 |   - pyyaml=6.0.1=py310h5eee18b_0
60 |   - readline=8.2=h5eee18b_0
61 |   - setuptools=68.0.0=py310h06a4308_0
62 |   - six=1.16.0=pyhd3eb1b0_1
63 |   - sqlite=3.41.2=h5eee18b_0
64 |   - sympy=1.11.1=py310h06a4308_0
65 |   - tk=8.6.12=h1ccaba5_0
66 |   - torchtriton=2.1.0=py310
67 |   - typing_extensions=4.7.1=py310h06a4308_0
68 |   - wheel=0.41.2=py310h06a4308_0
69 |   - xz=5.4.2=h5eee18b_0
70 |   - yaml=0.2.5=h7b6447c_0
71 |   - zlib=1.2.13=h5eee18b_0
72 |   - pip:
73 |     - aiohttp==3.9.1
74 |     - aiosignal==1.3.1
75 |     - annotated-types==0.6.0
76 |     - antlr4-python3-runtime==4.9.3
77 |     - anyio==4.0.0
78 |     - argon2-cffi==23.1.0
79 |     - argon2-cffi-bindings==21.2.0
80 |     - arrow==1.3.0
81 |     - asttokens==2.4.1
82 |     - async-lru==2.0.4
83 |     - async-timeout==4.0.3
84 |     - attrs==23.1.0
85 |     - babel==2.13.1
86 |     - backoff==2.2.1
87 |     - beautifulsoup4==4.12.2
88 |     - bleach==6.1.0
89 |     - blis==0.7.11
90 |     - catalogue==2.0.10
91 |     - certifi==2023.7.22
92 |     - cffi==1.16.0
93 |     - chardet==5.2.0
94 |     - charset-normalizer==3.3.2
95 |     - click==7.1.2
96 |     - coloredlogs==15.0.1
97 |     - comm==0.2.0
98 |     - confection==0.1.4
99 |     - conllu==4.5.3
100 |     - contourpy==1.2.0
101 |     - cryptography==41.0.7
102 |     - cupy==12.3.0
103 |     - cycler==0.12.1
104 |     - cymem==2.0.8
105 |     - dataclasses-json==0.6.3
106 |     - datasets==2.15.0
107 |     - debugpy==1.8.0
108 |     - decorator==5.1.1
109 |     - defusedxml==0.7.1
110 |     - deprecated==1.2.14
111 |     - dill==0.3.7
112 |     - effdet==0.4.1
113 |     - emoji==2.9.0
114 |     - en-core-sci-scibert @ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.3/en_core_sci_scibert-0.5.3.tar.gz
115 |     - et-xmlfile==1.1.0
116 |     - exceptiongroup==1.1.3
117 |     - executing==2.0.1
118 |     - faiss-gpu==1.7.2
119 |     - fastjsonschema==2.18.1
120 |     - fastrlock==0.8.2
121 |     - filetype==1.2.0
122 |     - flatbuffers==23.5.26
123 |     - fonttools==4.44.0
124 |     - fqdn==1.5.1
125 |     - frozenlist==1.4.0
126 |     - fsspec==2023.10.0
127 |     - greenlet==3.0.3
128 |     - huggingface-hub==0.20.3
129 |     - humanfriendly==10.0
130 |     - idna==3.4
131 |     - iopath==0.1.10
132 |     - ipykernel==6.26.0
133 |     - ipython==8.17.2
134 |     - isoduration==20.11.0
135 |     - jedi==0.19.1
136 |     - joblib==1.3.2
137 |     - json5==0.9.14
138 |     - jsonlines==4.0.0
139 |     - jsonpatch==1.33
140 |     - jsonpath-python==1.0.6
141 |     - jsonpointer==2.4
142 |     - jsonschema==4.19.2
143 |     - jsonschema-specifications==2023.7.1
144 |     - jupyter-client==8.6.0
145 |     - jupyter-core==5.5.0
146 |     - jupyter-events==0.9.0
147 |     - jupyter-lsp==2.2.0
148 |     - jupyter-server==2.10.0
149 |     - jupyter-server-terminals==0.4.4
150 |     - jupyterlab==4.0.8
151 |     - jupyterlab-pygments==0.2.2
152 |     - jupyterlab-server==2.25.1
153 |     - kiwisolver==1.4.5
154 |     - langchain==0.0.354
155 |     - langchain-community==0.0.8
156 |     - langchain-core==0.1.5
157 |     - langchainhub==0.1.14
158 |     - langcodes==3.3.0
159 |     - langdetect==1.0.9
160 |     - langsmith==0.0.77
161 |     - layoutparser==0.3.4
162 |     - lxml==5.0.1
163 |     - markdown==3.5.1
164 |     - markupsafe==2.1.3
165 |     - marshmallow==3.20.1
166 |     - matplotlib==3.8.1
167 |     - matplotlib-inline==0.1.6
168 |     - mistune==3.0.2
169 |     - msg-parser==1.2.0
170 |     - multidict==6.0.4
171 |     - multiprocess==0.70.15
172 |     - murmurhash==1.0.10
173 |     - mypy-extensions==1.0.0
174 |     - nbclient==0.9.0
175 |     - nbconvert==7.11.0
176 |     - nbformat==5.9.2
177 |     - nest-asyncio==1.5.8
178 |     - nltk==3.8.1
179 |     - nmslib==2.1.1
180 |     - notebook==7.0.6
181 |     - notebook-shim==0.2.3
182 |     - olefile==0.47
183 |     - omegaconf==2.3.0
184 |     - onnx==1.15.0
185 |     - onnxruntime==1.15.1
186 |     - opencv-python==4.9.0.80
187 |     - openpyxl==3.1.2
188 |     - overrides==7.4.0
189 |     - packaging==23.2
190 |     - pandas==2.1.3
191 |     - pandocfilters==1.5.0
192 |     - parso==0.8.3
193 |     - pathy==0.10.3
194 |     - pdf2image==1.17.0
195 |     - pdfminer-six==20221105
196 |     - pdfplumber==0.10.3
197 |     - pexpect==4.8.0
198 |     - pikepdf==8.11.2
199 |     - pillow==10.1.0
200 |     - platformdirs==4.0.0
201 |     - plotly==5.18.0
202 |     - portalocker==2.8.2
203 |     - preshed==3.0.9
204 |     - prometheus-client==0.18.0
205 |     - prompt-toolkit==3.0.40
206 |     - protobuf==4.25.1
207 |     - psutil==5.9.6
208 |     - ptyprocess==0.7.0
209 |     - pure-eval==0.2.2
210 |     - pyarrow==14.0.1
211 |     - pyarrow-hotfix==0.6
212 |     - pybind11==2.6.1
213 |     - pycocotools==2.0.7
214 |     - pycparser==2.21
215 |     - pydantic==2.5.3
216 |     - pydantic-core==2.14.6
217 |     - pygments==2.16.1
218 |     - pypandoc==1.12
219 |     - pyparsing==3.1.1
220 |     - pypd==1.1.0
221 |     - pypdf==3.17.4
222 |     - pypdfium2==4.25.0
223 |     - pysbd==0.3.4
224 |     - pytesseract==0.3.10
225 |     - python-dateutil==2.8.2
226 |     - python-docx==1.1.0
227 |     - python-iso639==2024.1.2
228 |     - python-json-logger==2.0.7
229 |     - python-magic==0.4.27
230 |     - python-multipart==0.0.6
231 |     - python-pptx==0.6.23
232 |     - pytz==2023.3.post1
233 |     - pyzmq==25.1.1
234 |     - rapidfuzz==3.6.1
235 |     - referencing==0.30.2
236 |     - regex==2023.10.3
237 |     - requests==2.31.0
238 |     - rfc3339-validator==0.1.4
239 |     - rfc3986-validator==0.1.1
240 |     - rpds-py==0.12.0
241 |     - safetensors==0.4.0
242 |     - scikit-learn==1.3.2
243 |     - scipy==1.10.1
244 |     - scispacy==0.5.3
245 |     - seaborn==0.13.1
246 |     - send2trash==1.8.2
247 |     - sentence-transformers==2.2.2
248 |     - sentencepiece==0.1.99
249 |     - smart-open==6.4.0
250 |     - sniffio==1.3.0
251 |     - soupsieve==2.5
252 |     - spacy==3.6.1
253 |     - spacy-alignments==0.9.1
254 |     - spacy-legacy==3.0.12
255 |     - spacy-loggers==1.0.5
256 |     - spacy-transformers==1.3.4
257 |     - sqlalchemy==2.0.25
258 |     - srsly==2.4.8
259 |     - stack-data==0.6.3
260 |     - tabulate==0.9.0
261 |     - tenacity==8.2.3
262 |     - terminado==0.18.0
263 |     - thinc==8.1.12
264 |     - threadpoolctl==3.2.0
265 |     - timm==0.9.12
266 |     - tinycss2==1.2.1
267 |     - tokenizers==0.15.1
268 |     - tomli==2.0.1
269 |     - torchvision==0.16.0
270 |     - tornado==6.3.3
271 |     - tqdm==4.66.1
272 |     - traitlets==5.13.0
273 |     - transformers==4.36.2
274 |     - typer==0.3.2
275 |     - types-python-dateutil==2.8.19.14
276 |     - types-requests==2.31.0.20240106
277 |     - typing-extensions==4.8.0
278 |     - typing-inspect==0.9.0
279 |     - tzdata==2023.3
280 |     - unstructured==0.11.8
281 |     - unstructured-client==0.15.1
282 |     - unstructured-inference==0.7.18
283 |     - unstructured-pytesseract==0.3.12
284 |     - uri-template==1.3.0
285 |     - urllib3==2.0.7
286 |     - wasabi==0.10.1
287 |     - wcwidth==0.2.9
288 |     - webcolors==1.13
289 |     - webencodings==0.5.1
290 |     - websocket-client==1.6.4
291 |     - wrapt==1.16.0
292 |     - xlrd==2.0.1
293 |     - xlsxwriter==3.1.9
294 |     - xxhash==3.4.1
295 |     - yarl==1.9.3
296 | prefix: /home/jiwoong/anaconda3/envs/medcpt
297 | 
--------------------------------------------------------------------------------
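A minimal end-to-end usage sketch (editor's note; the paths are illustrative, and the flags mirror the argparse options in `retriever/main.py`):

```
# create and activate the retrieval environment from the exported spec
conda env create -f retriever/retriever.yml
conda activate retriever

# encode queries, search the four corpora, and write the reranked evidence
python retriever/main.py \
    -i retriever/input/medqa/medqa_llama_cot.json \
    -o retriever/output/medqa/evidence_medqa_llama_cot.json \
    -k 100 -inst False -spc False -pmdn 38
```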