├── assets
│   └── MedAdapter-overview.png
├── inference
│   ├── __pycache__
│   │   ├── generate.cpython-310.pyc
│   │   ├── generate_vllm.cpython-310.pyc
│   │   ├── generate_openai.cpython-310.pyc
│   │   └── batched_generate_vllm.cpython-310.pyc
│   ├── generate_vllm.py
│   ├── generate.py
│   ├── batched_generate_vllm.py
│   └── generate_openai.py
├── models
│   ├── __pycache__
│   │   └── generator_model.cpython-310.pyc
│   ├── generator_model.py
│   └── reward_model.py
├── generator
│   ├── __pycache__
│   │   ├── vanilla_trainer.cpython-39.pyc
│   │   └── vanilla_trainer.cpython-310.pyc
│   ├── vanilla_trainer.py
│   └── trainer.py
├── utils
│   ├── util.py
│   ├── credentials.py
│   └── loggers.py
├── requirements.txt
├── scripts
│   └── bioasq_exp.sh
├── eval_fn
│   ├── general_eval.py
│   ├── pubmedqa_eval.py
│   ├── mmlu_eval.py
│   ├── medqa_eval.py
│   ├── medmcqa_eval.py
│   └── bioasq_eval.py
├── reward_model
│   └── orm
│       ├── orm_data.py
│       ├── orm_trainer.py
│       └── orm_guide.py
├── evaluate_score.py
├── main-openai.py
├── main.py
├── configs
│   └── bioasq
│       └── gpt-35
│           ├── bioasq-reward.yaml
│           ├── bioasq-guide.yaml
│           ├── bioasq-gen-train.yaml
│           └── bioasq-gen-test.yaml
├── README.md
└── data
    ├── prompt_loader.py
    └── dataset_loader.py

/assets/MedAdapter-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/assets/MedAdapter-overview.png
--------------------------------------------------------------------------------
/inference/__pycache__/generate.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/inference/__pycache__/generate.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/generator_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/models/__pycache__/generator_model.cpython-310.pyc
--------------------------------------------------------------------------------
/generator/__pycache__/vanilla_trainer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/generator/__pycache__/vanilla_trainer.cpython-39.pyc
--------------------------------------------------------------------------------
/inference/__pycache__/generate_vllm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/inference/__pycache__/generate_vllm.cpython-310.pyc
--------------------------------------------------------------------------------
/generator/__pycache__/vanilla_trainer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/generator/__pycache__/vanilla_trainer.cpython-310.pyc
--------------------------------------------------------------------------------
/inference/__pycache__/generate_openai.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/inference/__pycache__/generate_openai.cpython-310.pyc
--------------------------------------------------------------------------------
/inference/__pycache__/batched_generate_vllm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wshi83/MedAdapter/HEAD/inference/__pycache__/batched_generate_vllm.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | 
3 | def load_config(config_path):
4 |     with open(config_path, "r") as f:
5 |         config = yaml.load(f, Loader=yaml.FullLoader)
6 |     return config
7 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.23.0
2 | datasets==2.16.1
3 | deepspeed==0.14.0
4 | huggingface_hub==0.20.2
5 | numpy==1.26.4
6 | peft==0.4.0
7 | PyYAML==6.0.1
8 | tenacity==8.2.3
9 | torch==2.1.2
10 | tqdm==4.66.1
11 | transformers==4.37.0
12 | trl==0.4.7
13 | vllm==0.3.0
14 | 
--------------------------------------------------------------------------------
/scripts/bioasq_exp.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,3 python main-openai.py --debug generation --config configs/bioasq/bioasq-gen-test.yaml
2 | CUDA_VISIBLE_DEVICES=0,3 python main-openai.py --debug generation --config configs/bioasq/bioasq-gen-train.yaml
3 | CUDA_VISIBLE_DEVICES=0 accelerate launch --mixed_precision fp16 --main_process_port 29666 main.py --debug reward --config configs/bioasq/bioasq-reward.yaml
4 | CUDA_VISIBLE_DEVICES=0 accelerate launch --mixed_precision fp16 --main_process_port 29666 main.py --debug reward_guide --config configs/bioasq/bioasq-guide.yaml
--------------------------------------------------------------------------------
/utils/credentials.py:
--------------------------------------------------------------------------------
1 | import openai
2 | 
3 | def api_key_list(api_group):
4 |     if api_group == '':
5 |         api_key_list = [
6 |             {
7 |                 "api_key": "",  # your Azure OpenAI API key
8 |                 "api_version": "",  # your Azure OpenAI API version
9 |                 "azure_endpoint": "",  # your Azure OpenAI endpoint URL
10 |                 "model": ""  # your deployment / model name
11 |             },
12 |             {
13 |                 "api_key": "",  # your Azure OpenAI API key
14 |                 "api_version": "",  # your Azure OpenAI API version
15 |                 "azure_endpoint": "",  # your Azure OpenAI endpoint URL
16 |                 "model": ""  # your deployment / model name
17 |             },
18 |         ]
19 |     return api_key_list
--------------------------------------------------------------------------------
/eval_fn/general_eval.py:
--------------------------------------------------------------------------------
1 | from eval_fn.medmcqa_eval import medmcqa_judge
2 | from eval_fn.mmlu_eval import mmlu_judge
3 | from eval_fn.medqa_eval import medqa_judge
4 | from eval_fn.bioasq_eval import bioasq_judge
5 | from eval_fn.pubmedqa_eval import pubmedqa_judge
6 | 
7 | def judge_router(dataset_name):
8 |     if dataset_name.lower() == 'medmcqa':
9 |         return medmcqa_judge
10 |     elif dataset_name.lower() == 'mmlu':
11 |         return mmlu_judge
12 |     elif dataset_name.lower() == 'medqa':
13 |         return medqa_judge
14 |     elif dataset_name.lower() == 'bioasq':
15 |         return bioasq_judge
16 |     elif dataset_name.lower() == 'pubmedqa':
17 |         return pubmedqa_judge
18 |     else:
19 |         raise NotImplementedError(f"Dataset {dataset_name} judgement not implemented yet.")
--------------------------------------------------------------------------------
/reward_model/orm/orm_data.py:
--------------------------------------------------------------------------------
1 | from generator.vanilla_trainer import train
2 | from utils.util import load_config
3 | from datasets import load_dataset, Dataset
4 | from tqdm import tqdm
5 | from eval_fn.general_eval import judge_router
6 | 
7 | def prepare_orm_data(config):
8 | 
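    # Builds the ORM (outcome reward model) training set from previously generated candidates:
    # each record is rendered as a "Q: {question}\nA: {answer}" string and labeled 1/0 by the
    # task-specific judge from judge_router, so the adapter can be trained as a binary classifier.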
dataset = load_dataset('json', data_files=config['reward_model']['dataset_name'], split=config['reward_model']['split']) 9 | print(dataset[0]) 10 | 11 | samples = [] 12 | labels = [] 13 | template = "Q: {question}\nA: {answer}" 14 | judge = judge_router(config['task']) 15 | for idx in tqdm(range(len(dataset))): 16 | sample = dataset[idx] 17 | prediction, label = judge(sample['answer'], sample['generation']) 18 | samples.append(template.format(question=sample['question'], answer=prediction)) 19 | labels.append(label) 20 | 21 | # convert samples into huggingface dataset with Dataset.from_dict 22 | dataset = Dataset.from_dict({"label": labels, "text": samples}).with_format("torch") 23 | return dataset 24 | -------------------------------------------------------------------------------- /evaluate_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from eval_fn.gsm8k_eval import gsm8k_metrics 3 | from eval_fn.pubmedqa_eval import pubmedqa_metrics 4 | from eval_fn.medmcqa_eval import medmcqa_metrics 5 | from eval_fn.mmlu_eval import mmlu_metrics 6 | from eval_fn.bioasq_eval import bioasq_metrics 7 | from eval_fn.medqa_eval import medqa_metrics 8 | from eval_fn.cord19_eval import cord19_metrics 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--dataset', default='gsm8k', type=str, help='Dataset to evaluate.') 13 | parser.add_argument('--solution_dir', default='data/gsm8k/generation/iter_warmup/prm/0_selected_test.jsonl', type=str, help='Path to the generated solutions.') 14 | parser.add_argument('--split', default='train', type=str, help='Split of the dataset to evaluate.') 15 | parser.add_argument('--style', default='interleave', type=str, help='Style of sampling the generation for eval.') 16 | parser.add_argument('--k', default=1, type=int, help='Sample from every k data.') 17 | args = parser.parse_args() 18 | 19 | if args.dataset == 'gsm8k': 20 | stats = gsm8k_metrics(args) 21 | elif args.dataset == 'pubmedqa': 22 | stats = pubmedqa_metrics(args) 23 | elif args.dataset == 'medmcqa': 24 | stats = medmcqa_metrics(args) 25 | elif args.dataset == 'mmlu': 26 | stats = mmlu_metrics(args) 27 | elif args.dataset == 'medqa': 28 | stats = medqa_metrics(args) 29 | elif args.dataset == 'bioasq': 30 | stats = bioasq_metrics(args) 31 | elif args.dataset == 'cord19': 32 | stats = cord19_metrics(args) 33 | print(stats) 34 | 35 | if __name__ == "__main__": 36 | main() -------------------------------------------------------------------------------- /utils/loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from logging import FileHandler, Formatter 4 | from datetime import datetime 5 | 6 | # Initial log directory 7 | BASE_DIR = f"logs/{datetime.now().strftime('%Y%m%d-%H%M')}" 8 | DIR = BASE_DIR 9 | os.makedirs(DIR, exist_ok=True) 10 | 11 | def get_log_dir(): 12 | return DIR 13 | 14 | def get_base_dir(): 15 | return BASE_DIR 16 | 17 | def update_log_folder(new_dir, process_index): 18 | global DIR 19 | DIR = f"{BASE_DIR}/{new_dir}/gpu_{process_index}" 20 | os.makedirs(DIR, exist_ok=True) 21 | 22 | # Update file handlers for each logger 23 | for logger_name, logger in loggers.items(): 24 | # Remove old file handlers 25 | for handler in logger.handlers[:]: 26 | if isinstance(handler, FileHandler): 27 | logger.removeHandler(handler) 28 | 29 | if logger_name in ["train", "eval"]: 30 | # For 'train' and 'eval' loggers, use the shared 
directory 31 | logger_dir = f"{BASE_DIR}/{new_dir}" 32 | else: 33 | # For other loggers, create a separate directory for each GPU process 34 | logger_dir = DIR 35 | 36 | # Create a new file handler with the updated directory 37 | file_name = logger_file_map[logger.name] 38 | new_file_handler = FileHandler(f"{logger_dir}/{file_name}") 39 | new_file_handler.setFormatter(Formatter('%(message)s')) 40 | logger.addHandler(new_file_handler) 41 | 42 | # Store loggers and their corresponding file names 43 | logger_file_map = { 44 | "api": "api_response.log", 45 | "search": "search_procedure.log", 46 | "train": "training_text.log", 47 | "tensor": "training_tensor.log", 48 | "eval": "eval_details.log", 49 | "error": "error.log", 50 | "adaptor": "adaptor_rating.log", 51 | } 52 | 53 | # Initialize loggers 54 | loggers = {name: logging.getLogger(name) for name in logger_file_map} 55 | for logger in loggers.values(): 56 | logger.setLevel(logging.DEBUG) -------------------------------------------------------------------------------- /main-openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | warnings.simplefilter('ignore') 4 | from transformers import logging 5 | logging.set_verbosity_error() 6 | import torch 7 | import numpy as np 8 | import argparse 9 | from utils.util import load_config 10 | from accelerate.utils import set_seed 11 | from datasets import load_dataset 12 | from data.dataset_loader import get_datasets 13 | from tqdm import tqdm 14 | 15 | from inference.generate import generate 16 | from inference.generate_vllm import generate_vllm 17 | from inference.generate_openai import generate_openai, generate_vllm_openai 18 | 19 | from generator.vanilla_trainer import train 20 | from reward_model.orm.orm_trainer import orm_classification_trainer 21 | from reward_model.prm.prm_trainer import prm_classification_trainer 22 | 23 | def set_seeds(seed): 24 | set_seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | def run(config): 30 | generate(config["generator"]) 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('-c', '--config', default='configs/pubmedqa.yaml', type=str, help='Path to the config file') 35 | parser.add_argument('--debug', default='generation', type=str, help='debug') 36 | args = parser.parse_args() 37 | 38 | config_path = args.config 39 | assert os.path.isfile(config_path), f"Invalid config path: {config_path}" 40 | 41 | config = load_config(config_path) 42 | 43 | # set seeds 44 | set_seeds(config['seed']) 45 | if args.debug == 'generation': 46 | generator = generate_openai(config) 47 | generator.generate() 48 | print(generator.token_usage) 49 | elif args.debug == 'reward': 50 | if 'orm' in config['reward_model']['type']: 51 | orm_classification_trainer(config) 52 | elif 'prm' in config['reward_model']['type']: 53 | prm_classification_trainer(config) 54 | elif args.debug == 'train_generator': 55 | train(config["generator_trainer"]) 56 | generate(config["generator"]) -------------------------------------------------------------------------------- /reward_model/orm/orm_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, TrainingArguments, Trainer 3 | from reward_model.orm.orm_data import prepare_orm_data 4 | 5 | def 
orm_classification_trainer(config): 6 | train_dataset = prepare_orm_data(config) 7 | tokenizer = AutoTokenizer.from_pretrained(config["reward_model"]["model_name"]) 8 | tokenizer.pad_token = tokenizer.eos_token 9 | 10 | def preprocess_function(examples): 11 | return tokenizer(examples["text"], truncation=True) 12 | 13 | train_dataset = train_dataset.map(preprocess_function, batched=True) 14 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 15 | training_args = TrainingArguments( 16 | output_dir=config["reward_model"]["output_dir"], 17 | learning_rate=config["reward_model"]["learning_rate"], 18 | per_device_train_batch_size=config["reward_model"]["per_device_train_batch_size"], 19 | # per_device_eval_batch_size=config["reward_model"]["per_device_eval_batch_size"], 20 | num_train_epochs=config["reward_model"]["num_train_epochs"], 21 | weight_decay=config["reward_model"]["weight_decay"], 22 | # evaluation_strategy=config["reward_model"]["evaluation_strategy"], 23 | save_strategy=config["reward_model"]["save_strategy"], 24 | # load_best_model_at_end=config["reward_model"]["load_best_model_at_end"], 25 | push_to_hub=config["reward_model"]["push_to_hub"], 26 | ) 27 | 28 | model = AutoModelForSequenceClassification.from_pretrained(config["reward_model"]["model_name"], num_labels=2) 29 | 30 | trainer = Trainer( 31 | model=model, 32 | args=training_args, 33 | train_dataset=train_dataset, 34 | tokenizer=tokenizer, 35 | data_collator=data_collator, 36 | ) 37 | 38 | trainer.train() 39 | if not os.path.exists(config["reward_model"]["output_dir"]): 40 | os.makedirs(config["reward_model"]["output_dir"]) 41 | trainer.model.save_pretrained(config["reward_model"]["output_dir"]) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | warnings.simplefilter('ignore') 4 | from transformers import logging 5 | logging.set_verbosity_error() 6 | import torch 7 | import numpy as np 8 | import argparse 9 | from utils.util import load_config 10 | from accelerate.utils import set_seed 11 | from datasets import load_dataset 12 | from data.dataset_loader import get_datasets 13 | from tqdm import tqdm 14 | 15 | from inference.generate import generate 16 | from inference.generate_vllm import generate_vllm 17 | 18 | from generator.vanilla_trainer import train 19 | from reward_model.orm.orm_trainer import orm_classification_trainer 20 | from reward_model.prm.prm_trainer import prm_classification_trainer 21 | from reward_model.orm.orm_guide import orm_guided_generation 22 | from reward_model.prm.prm_guide import prm_guided_generation 23 | 24 | def set_seeds(seed): 25 | set_seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | 30 | def run(config): 31 | generate(config["generator"]) 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('-c', '--config', default='configs/bioasq.yaml', type=str, help='Path to the config file') 36 | parser.add_argument('--debug', default='generation', type=str, help='debug') 37 | args = parser.parse_args() 38 | 39 | config_path = args.config 40 | assert os.path.isfile(config_path), f"Invalid config path: {config_path}" 41 | 42 | config = load_config(config_path) 43 | 44 | # set seeds 45 | set_seeds(config['seed']) 46 | if args.debug == 'generation': 47 | generate(config["generator"]) 48 | elif args.debug == 'reward': 49 | if 'orm' 
in config['reward_model']['type']:
50 |             orm_classification_trainer(config)
51 |         elif 'prm' in config['reward_model']['type']:
52 |             prm_classification_trainer(config)
53 |     elif args.debug == 'train_generator':
54 |         train(config["generator_trainer"])
55 |         generate(config["generator"])
56 |     elif args.debug == 'reward_guide':
57 |         if 'orm' in config['reward_model']['type']:
58 |             generator = orm_guided_generation(config)
59 |         elif 'prm' in config['reward_model']['type']:
60 |             generator = prm_guided_generation(config)
61 |         generator.guide_generation()
62 | 
--------------------------------------------------------------------------------
/eval_fn/pubmedqa_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from generator.vanilla_trainer import train
3 | from utils.util import load_config
4 | from datasets import load_dataset
5 | from tqdm import tqdm
6 | 
7 | def pubmedqa_metrics(args):
8 |     dataset = load_dataset('json', data_files=args.solution_dir, split=args.split)
9 |     print(dataset[0])
10 | 
11 |     statistics = {"correct": 0, "incorrect": 0, "incomplete":0, "total": 0}
12 |     if args.style != 'first':
13 |         if args.style == 'interleave':
14 |             index_list = list(range(0, len(dataset), args.k))
15 |         elif args.style == 'repeat':
16 |             index_list = list(range(0, int(len(dataset)/args.k)))
17 |         for idx in tqdm(index_list):
18 |             sample = dataset[idx]
19 |             prediction, label = pubmedqa_judge(sample['answer'], sample['generation'])
20 |             if label == 1:
21 |                 statistics["correct"] += 1
22 |             else:
23 |                 statistics["incorrect"] += 1
24 |             statistics["total"] += 1
25 |     else:
26 |         question_set = []
27 |         for idx in tqdm(range(len(dataset))):
28 |             sample = dataset[idx]
29 |             if not sample['question'] in question_set:
30 |                 prediction, label = pubmedqa_judge(sample['answer'], sample['generation'])
31 |                 if label == 1:
32 |                     statistics["correct"] += 1
33 |                 else:
34 |                     statistics["incorrect"] += 1
35 |                 statistics["total"] += 1
36 |                 question_set.append(sample['question'])
37 | 
38 |     return statistics
39 | 
40 | def pubmedqa_judge(answer, prediction):
41 |     answer = answer.split('\n#### ')[-1]
42 |     prediction = prediction.split('\n#### ')
43 |     if len(prediction) == 1:
44 |         # incomplete
45 |         return prediction[0], 0
46 |     else:
47 |         pred_answer = prediction[1].split(' ')[0]
48 |         if '.' in pred_answer:
49 |             pred_answer = pred_answer.split('.')[0]
50 |         if '.' in answer:
51 |             answer = answer.split('.')[0]
52 |         pred_answer = pred_answer.replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('{', '').replace('}', '')
53 |         prediction = prediction[0] + '\n#### ' + pred_answer + '.'
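        # The gold answer and the extracted prediction (the token following the '#### ' delimiter,
        # stripped of brackets and trailing punctuation) are compared case-insensitively; an exact
        # match counts as correct.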
54 | if answer.lower() == pred_answer.lower(): 55 | return prediction, 1 56 | else: 57 | return prediction, 0 -------------------------------------------------------------------------------- /eval_fn/mmlu_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from generator.vanilla_trainer import train 3 | from utils.util import load_config 4 | from datasets import load_dataset 5 | from tqdm import tqdm 6 | 7 | def mmlu_metrics(args): 8 | dataset = load_dataset('json', data_files=args.solution_dir, split=args.split) 9 | print(dataset[0]) 10 | 11 | statistics = {"correct": 0, "incorrect": 0, "incomplete":0, "total": 0} 12 | if args.style != 'first': 13 | if args.style == 'interleave': 14 | index_list = list(range(0, len(dataset), args.k)) 15 | elif args.style == 'repeat': 16 | index_list = list(range(0, int(len(dataset)/args.k))) 17 | for idx in tqdm(index_list): 18 | sample = dataset[idx] 19 | prediction, label = mmlu_judge(sample['answer'], sample['generation']) 20 | if label == 1: 21 | statistics["correct"] += 1 22 | else: 23 | statistics["incorrect"] += 1 24 | statistics["total"] += 1 25 | else: 26 | question_set = [] 27 | for idx in tqdm(range(len(dataset))): 28 | sample = dataset[idx] 29 | if not sample['question'] in question_set: 30 | prediction, label = mmlu_judge(sample['answer'], sample['generation']) 31 | if label == 1: 32 | statistics["correct"] += 1 33 | else: 34 | statistics["incorrect"] += 1 35 | statistics["total"] += 1 36 | question_set.append(sample['question']) 37 | return statistics 38 | 39 | def mmlu_judge(answer, prediction): 40 | answer = answer.split('\n#### ')[-1] 41 | prediction = prediction.split('\n#### ') 42 | if len(prediction) == 1: 43 | # incomplete 44 | return prediction[0], 0 45 | else: 46 | pred_answer = prediction[1].split(' ')[0] 47 | if '.' in pred_answer: 48 | pred_answer = pred_answer.split('.')[0] 49 | if '.' in answer: 50 | answer = answer.split('.')[0] 51 | pred_answer = pred_answer.replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('{', '').replace('}', '') 52 | prediction = prediction[0] + '\n#### ' + pred_answer + '.' 
53 | if answer.lower() == pred_answer.lower(): 54 | return prediction, 1 55 | else: 56 | return prediction, 0 57 | 58 | -------------------------------------------------------------------------------- /eval_fn/medqa_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from generator.vanilla_trainer import train 3 | from utils.util import load_config 4 | from datasets import load_dataset 5 | from tqdm import tqdm 6 | 7 | def medqa_metrics(args): 8 | dataset = load_dataset('json', data_files=args.solution_dir, split=args.split) 9 | print(dataset[0]) 10 | 11 | statistics = {"correct": 0, "incorrect": 0, "incomplete":0, "total": 0} 12 | if args.style != 'first': 13 | if args.style == 'interleave': 14 | index_list = list(range(0, len(dataset), args.k)) 15 | elif args.style == 'repeat': 16 | index_list = list(range(0, int(len(dataset)/args.k))) 17 | for idx in tqdm(index_list): 18 | sample = dataset[idx] 19 | prediction, label = medqa_judge(sample['answer'], sample['generation']) 20 | if label == 1: 21 | statistics["correct"] += 1 22 | else: 23 | statistics["incorrect"] += 1 24 | statistics["total"] += 1 25 | else: 26 | question_set = [] 27 | for idx in tqdm(range(len(dataset))): 28 | sample = dataset[idx] 29 | if not sample['question'] in question_set: 30 | prediction, label = medqa_judge(sample['answer'], sample['generation']) 31 | if label == 1: 32 | statistics["correct"] += 1 33 | else: 34 | statistics["incorrect"] += 1 35 | statistics["total"] += 1 36 | question_set.append(sample['question']) 37 | return statistics 38 | 39 | def medqa_judge(answer, prediction): 40 | answer = answer.split('\n#### ')[-1] 41 | prediction = prediction.split('\n#### ') 42 | if len(prediction) == 1: 43 | # incomplete 44 | return prediction[0], 0 45 | else: 46 | pred_answer = prediction[1].split(' ')[0] 47 | if '.' in pred_answer: 48 | pred_answer = pred_answer.split('.')[0] 49 | if '.' in answer: 50 | answer = answer.split('.')[0] 51 | pred_answer = pred_answer.replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('{', '').replace('}', '') 52 | prediction = prediction[0] + '\n#### ' + pred_answer + '.' 
53 | if answer.lower() == pred_answer.lower(): 54 | return prediction, 1 55 | else: 56 | return prediction, 0 57 | 58 | -------------------------------------------------------------------------------- /eval_fn/medmcqa_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from generator.vanilla_trainer import train 3 | from utils.util import load_config 4 | from datasets import load_dataset 5 | from tqdm import tqdm 6 | 7 | def medmcqa_metrics(args): 8 | dataset = load_dataset('json', data_files=args.solution_dir, split=args.split) 9 | print(dataset[0]) 10 | 11 | statistics = {"correct": 0, "incorrect": 0, "incomplete":0, "total": 0} 12 | if args.style != 'first': 13 | if args.style == 'interleave': 14 | index_list = list(range(0, len(dataset), args.k)) 15 | elif args.style == 'repeat': 16 | index_list = list(range(0, int(len(dataset)/args.k))) 17 | for idx in tqdm(index_list): 18 | sample = dataset[idx] 19 | prediction, label = medmcqa_judge(sample['answer'], sample['generation']) 20 | if not '#### ' in prediction: 21 | statistics["incomplete"] += 1 22 | else: 23 | if label == 1: 24 | statistics["correct"] += 1 25 | else: 26 | statistics["incorrect"] += 1 27 | statistics["total"] += 1 28 | else: 29 | question_set = [] 30 | for idx in tqdm(range(len(dataset))): 31 | sample = dataset[idx] 32 | if not sample['question'] in question_set: 33 | prediction, label = medmcqa_judge(sample['answer'], sample['generation']) 34 | if label == 1: 35 | statistics["correct"] += 1 36 | else: 37 | statistics["incorrect"] += 1 38 | statistics["total"] += 1 39 | question_set.append(sample['question']) 40 | return statistics 41 | 42 | def medmcqa_judge(answer, prediction): 43 | answer = answer.split('\n#### ')[-1] 44 | prediction = prediction.split('\n#### ') 45 | if len(prediction) == 1: 46 | # incomplete 47 | return prediction[0], 0 48 | else: 49 | pred_answer = prediction[1].split(' ')[0] 50 | if '.' in pred_answer: 51 | pred_answer = pred_answer.split('.')[0] 52 | if '.' in answer: 53 | answer = answer.split('.')[0] 54 | pred_answer = pred_answer.replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('{', '').replace('}', '') 55 | prediction = prediction[0] + '\n#### ' + pred_answer + '.' 
56 | if answer.lower() == pred_answer.lower(): 57 | return prediction, 1 58 | else: 59 | return prediction, 0 60 | 61 | -------------------------------------------------------------------------------- /configs/bioasq/gpt-35/bioasq-reward.yaml: -------------------------------------------------------------------------------- 1 | task: BioASQ 2 | train_ratio: # determined by dataset 3 | 4 | seed: 42 5 | generator_model: meta-llama/Llama-2-7b-hf 6 | gradient_accumulation_steps: 3 7 | warmup_steps: 50 8 | learning_rate: 2.0e-5 9 | 10 | logs_with_wandb: False 11 | add_special_tokens: True 12 | 13 | generator: 14 | model_name: meta-llama/Llama-2-7b-hf # ./checkpoints 15 | tokenizer_name: meta-llama/Llama-2-7b-hf 16 | max_length: 512 17 | data_frac: 0 18 | frac_len: 0 19 | output_dir: data/bioasq/generation/llama2-7b #/orm 20 | batch_size: 8 21 | input_dir: bioasq # data/bioasq/generation/openai/0_test.jsonl 22 | subset: main 23 | split: train 24 | token: # 25 | seed: 42 26 | temperature: 1 27 | tp_per_worker: 1 28 | num_data_frac: 1 29 | num_return_sequences: 8 30 | frequency_penalty: 0 31 | presence_penalty: 0 32 | stop: None 33 | 34 | generator_trainer: 35 | local_rank: -1 36 | per_device_train_batch_size: 8 37 | per_device_eval_batch_size: 8 38 | gradient_accumulation_steps: 1 39 | learning_rate: 2.0e-4 40 | max_grad_norm: 0.3 41 | weight_decay: 0.001 42 | lora_alpha: 16 43 | lora_dropout: 0.1 44 | lora_r: 64 45 | max_seq_length: 512 46 | model_name: meta-llama/Llama-2-7b-hf 47 | new_model: Llama-2-7b-hf-bioasq-warmup-formal 48 | dataset_name: bioasq 49 | subset: main 50 | split: train[0%:90%] 51 | use_4bit: False 52 | use_nested_quant: False 53 | bnb_4bit_compute_dtype: "float16" 54 | bnb_4bit_quant_type: "nf4" 55 | num_train_epochs: 2 56 | fp16: False 57 | bf16: True 58 | packing: False 59 | gradient_checkpointing: True 60 | optim: "paged_adamw_32bit" 61 | lr_scheduler_type: "cosine" 62 | max_steps: -1 63 | warmup_ratio: 0.03 64 | group_by_length: True 65 | save_steps: 1000 66 | logging_steps: 1 67 | output_dir: "./checkpoints/Llama-2-7b-bioasq-warmup-formal" 68 | device_map: {"": 0} 69 | report_to: "wandb" 70 | tb_log_dir: "./checkpoints/logs" 71 | 72 | reward_model: 73 | type: orm-classification 74 | model_name: allenai/longformer-base-4096 75 | tokenizer_name: allenai/longformer-base-4096 76 | dataset_name: data/bioasq/generation/gpt35/0.jsonl 77 | split: train 78 | accumulate_data: # 79 | output_dir: ./checkpoints/reward_model/LongFormer-orm-bioasq-gpt35 80 | learning_rate: 2.0e-5 81 | per_device_train_batch_size: 8 82 | per_device_eval_batch_size: 16 83 | num_train_epochs: 5 84 | weight_decay: 0.01 85 | # evaluation_strategy: epoch 86 | save_strategy: epoch 87 | # load_best_model_at_end: True 88 | push_to_hub: False 89 | gradient_accumulation_steps: 3 90 | l2_reg_coef: 1.0 91 | energy_temp: 5.0 92 | add_special_tokens: False 93 | -------------------------------------------------------------------------------- /configs/bioasq/gpt-35/bioasq-guide.yaml: -------------------------------------------------------------------------------- 1 | task: BioASQ 2 | train_ratio: # determined by dataset 3 | 4 | seed: 42 5 | generator_model: meta-llama/Llama-2-7b-hf 6 | gradient_accumulation_steps: 3 7 | warmup_steps: 50 8 | learning_rate: 2.0e-5 9 | 10 | logs_with_wandb: False 11 | add_special_tokens: True 12 | 13 | generator: 14 | model_name: meta-llama/Llama-2-7b-hf # ./checkpoints 15 | tokenizer_name: meta-llama/Llama-2-7b-hf 16 | max_length: 512 17 | data_frac: 0 18 | frac_len: 0 19 | output_dir: 
data/bioasq/generation/gpt35/orm #/orm 20 | batch_size: 8 21 | input_dir: data/bioasq/generation/gpt35/0_test.jsonl 22 | subset: main 23 | split: test 24 | token: # 25 | seed: 42 26 | temperature: 1 27 | tp_per_worker: 1 28 | num_data_frac: 1 29 | num_return_sequences: 8 30 | frequency_penalty: 0 31 | presence_penalty: 0 32 | stop: None 33 | 34 | generator_trainer: 35 | local_rank: -1 36 | per_device_train_batch_size: 8 37 | per_device_eval_batch_size: 8 38 | gradient_accumulation_steps: 1 39 | learning_rate: 2.0e-4 40 | max_grad_norm: 0.3 41 | weight_decay: 0.001 42 | lora_alpha: 16 43 | lora_dropout: 0.1 44 | lora_r: 64 45 | max_seq_length: 512 46 | model_name: meta-llama/Llama-2-7b-hf 47 | new_model: Llama-2-7b-hf-bioasq-warmup-formal 48 | dataset_name: bioasq 49 | subset: main 50 | split: train[0%:90%] 51 | use_4bit: False 52 | use_nested_quant: False 53 | bnb_4bit_compute_dtype: "float16" 54 | bnb_4bit_quant_type: "nf4" 55 | num_train_epochs: 2 56 | fp16: False 57 | bf16: True 58 | packing: False 59 | gradient_checkpointing: True 60 | optim: "paged_adamw_32bit" 61 | lr_scheduler_type: "cosine" 62 | max_steps: -1 63 | warmup_ratio: 0.03 64 | group_by_length: True 65 | save_steps: 1000 66 | logging_steps: 1 67 | output_dir: "./checkpoints/Llama-2-7b-bioasq-warmup-formal" 68 | device_map: {"": 0} 69 | report_to: "wandb" 70 | tb_log_dir: "./checkpoints/logs" 71 | 72 | reward_model: 73 | type: orm-classification 74 | model_name: ./checkpoints/reward_model/LongFormer-orm-bioasq-gpt35 75 | tokenizer_name: allenai/longformer-base-4096 76 | dataset_name: data/bioasq/generation/gpt35/0_test.jsonl 77 | split: train 78 | accumulate_data: # 79 | output_dir: ./checkpoints/reward_model/LongFormer-orm-bioasq-gpt35 80 | learning_rate: 2.0e-5 81 | per_device_train_batch_size: 8 82 | per_device_eval_batch_size: 16 83 | num_train_epochs: 5 84 | weight_decay: 0.01 85 | # evaluation_strategy: epoch 86 | save_strategy: epoch 87 | # load_best_model_at_end: True 88 | push_to_hub: False 89 | gradient_accumulation_steps: 3 90 | l2_reg_coef: 1.0 91 | energy_temp: 5.0 92 | add_special_tokens: False 93 | -------------------------------------------------------------------------------- /configs/bioasq/gpt-35/bioasq-gen-train.yaml: -------------------------------------------------------------------------------- 1 | task: BioASQ 2 | train_ratio: # determined by dataset 3 | 4 | seed: 42 5 | generator_model: meta-llama/Llama-2-7b-hf 6 | gradient_accumulation_steps: 3 7 | warmup_steps: 50 8 | learning_rate: 2.0e-5 9 | 10 | logs_with_wandb: False 11 | add_special_tokens: True 12 | 13 | generator: 14 | model_name: meta-llama/Llama-2-7b-hf # ./checkpoints 15 | tokenizer_name: meta-llama/Llama-2-7b-hf 16 | max_length: 512 17 | data_frac: 0 18 | frac_len: 0 19 | output_dir: # data/bioasq/generation/openai #/orm 20 | batch_size: 8 21 | input_dir: bioasq # data/bioasq/generation/openai/0_test.jsonl 22 | subset: main 23 | split: train 24 | token: # 25 | seed: 42 26 | temperature: 1 27 | tp_per_worker: 1 28 | num_data_frac: 1 29 | num_return_sequences: 8 30 | frequency_penalty: 0 31 | presence_penalty: 0 32 | stop: None 33 | openai_credentials: # 34 | 35 | generator_trainer: 36 | local_rank: -1 37 | per_device_train_batch_size: 8 38 | per_device_eval_batch_size: 8 39 | gradient_accumulation_steps: 1 40 | learning_rate: 2.0e-4 41 | max_grad_norm: 0.3 42 | weight_decay: 0.001 43 | lora_alpha: 16 44 | lora_dropout: 0.1 45 | lora_r: 64 46 | max_seq_length: 512 47 | model_name: meta-llama/Llama-2-7b-hf 48 | new_model: 
Llama-2-7b-hf-pubbioasq-warmup-formal 49 | dataset_name: bioasq 50 | subset: main 51 | split: train[0%:90%] 52 | use_4bit: False 53 | use_nested_quant: False 54 | bnb_4bit_compute_dtype: "float16" 55 | bnb_4bit_quant_type: "nf4" 56 | num_train_epochs: 2 57 | fp16: False 58 | bf16: True 59 | packing: False 60 | gradient_checkpointing: True 61 | optim: "paged_adamw_32bit" 62 | lr_scheduler_type: "cosine" 63 | max_steps: -1 64 | warmup_ratio: 0.03 65 | group_by_length: True 66 | save_steps: 1000 67 | logging_steps: 1 68 | output_dir: "./checkpoints/Llama-2-7b-bioasq-warmup-formal" 69 | device_map: {"": 0} 70 | report_to: "wandb" 71 | tb_log_dir: "./checkpoints/logs" 72 | 73 | reward_model: 74 | type: orm-classification 75 | model_name: ./checkpoints/reward_model/LongFormer-orm-bioasq 76 | tokenizer_name: allenai/longformer-base-4096 77 | dataset_name: data/bioasq/generation/openai/0.jsonl 78 | split: train 79 | accumulate_data: # 80 | output_dir: ./checkpoints/reward_model/LongFormer-orm-bioasq 81 | learning_rate: 2.0e-5 82 | per_device_train_batch_size: 16 83 | per_device_eval_batch_size: 32 84 | num_train_epochs: 5 85 | weight_decay: 0.01 86 | # evaluation_strategy: epoch 87 | save_strategy: epoch 88 | # load_best_model_at_end: True 89 | push_to_hub: False 90 | gradient_accumulation_steps: 3 91 | l2_reg_coef: 1.0 92 | energy_temp: 5.0 93 | add_special_tokens: False 94 | -------------------------------------------------------------------------------- /configs/bioasq/gpt-35/bioasq-gen-test.yaml: -------------------------------------------------------------------------------- 1 | task: BioASQ 2 | train_ratio: # determined by dataset 3 | 4 | seed: 42 5 | generator_model: meta-llama/Llama-2-7b-hf 6 | gradient_accumulation_steps: 3 7 | warmup_steps: 50 8 | learning_rate: 2.0e-5 9 | 10 | logs_with_wandb: False 11 | add_special_tokens: True 12 | 13 | generator: 14 | model_name: meta-llama/Llama-2-7b-hf # ./checkpoints 15 | tokenizer_name: meta-llama/Llama-2-7b-hf 16 | max_length: 512 17 | data_frac: 0 18 | frac_len: 0 19 | output_dir: # e.g., data/bioasq/generation/gpt35 20 | batch_size: 8 21 | input_dir: bioasq # data/bioasq/generation/openai/0_test.jsonl 22 | subset: main 23 | split: test 24 | token: # 25 | seed: 42 26 | temperature: 1 27 | tp_per_worker: 1 28 | num_data_frac: 1 29 | num_return_sequences: 8 30 | frequency_penalty: 0 31 | presence_penalty: 0 32 | stop: None 33 | openai_credentials: # 34 | 35 | generator_trainer: 36 | local_rank: -1 37 | per_device_train_batch_size: 8 38 | per_device_eval_batch_size: 8 39 | gradient_accumulation_steps: 1 40 | learning_rate: 2.0e-4 41 | max_grad_norm: 0.3 42 | weight_decay: 0.001 43 | lora_alpha: 16 44 | lora_dropout: 0.1 45 | lora_r: 64 46 | max_seq_length: 512 47 | model_name: meta-llama/Llama-2-7b-hf 48 | new_model: # e.g., Llama-2-7b-hf-bioasq-warmup-formal 49 | dataset_name: bioasq 50 | subset: main 51 | split: train[0%:90%] 52 | use_4bit: False 53 | use_nested_quant: False 54 | bnb_4bit_compute_dtype: "float16" 55 | bnb_4bit_quant_type: "nf4" 56 | num_train_epochs: 2 57 | fp16: False 58 | bf16: True 59 | packing: False 60 | gradient_checkpointing: True 61 | optim: "paged_adamw_32bit" 62 | lr_scheduler_type: "cosine" 63 | max_steps: -1 64 | warmup_ratio: 0.03 65 | group_by_length: True 66 | save_steps: 1000 67 | logging_steps: 1 68 | output_dir: # e.g., "./checkpoints/Llama-2-7b-bioasq-warmup-formal" 69 | device_map: {"": 0} 70 | report_to: "wandb" 71 | tb_log_dir: "./checkpoints/logs" 72 | 73 | reward_model: 74 | type: orm-classification 75 | 
model_name: ./checkpoints/reward_model/LongFormer-orm-bioasq 76 | tokenizer_name: allenai/longformer-base-4096 77 | dataset_name: data/bioasq/generation/openai/0.jsonl 78 | split: train 79 | accumulate_data: # 80 | output_dir: ./checkpoints/reward_model/LongFormer-orm-bioasq 81 | learning_rate: 2.0e-5 82 | per_device_train_batch_size: 16 83 | per_device_eval_batch_size: 32 84 | num_train_epochs: 5 85 | weight_decay: 0.01 86 | # evaluation_strategy: epoch 87 | save_strategy: epoch 88 | # load_best_model_at_end: True 89 | push_to_hub: False 90 | gradient_accumulation_steps: 3 91 | l2_reg_coef: 1.0 92 | energy_temp: 5.0 93 | add_special_tokens: False 94 | -------------------------------------------------------------------------------- /eval_fn/bioasq_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from generator.vanilla_trainer import train 3 | from utils.util import load_config 4 | from datasets import load_dataset 5 | from tqdm import tqdm 6 | 7 | def bioasq_metrics(args): 8 | dataset = load_dataset('json', data_files=args.solution_dir, split=args.split) 9 | print(dataset[0]) 10 | 11 | statistics = {"correct": 0, "incorrect": 0, "incomplete":0, "total": 0} 12 | if args.style != 'first': 13 | if args.style == 'interleave': 14 | index_list = list(range(0, len(dataset), args.k)) 15 | elif args.style == 'repeat': 16 | index_list = list(range(0, int(len(dataset)/args.k))) 17 | for idx in tqdm(index_list): 18 | sample = dataset[idx] 19 | prediction, label = bioasq_judge(sample['answer'], sample['generation']) 20 | if label == 1: 21 | statistics["correct"] += 1 22 | else: 23 | statistics["incorrect"] += 1 24 | statistics["total"] += 1 25 | else: 26 | question_set = [] 27 | for idx in tqdm(range(len(dataset))): 28 | sample = dataset[idx] 29 | if not sample['question'] in question_set: 30 | prediction, label = bioasq_judge(sample['answer'], sample['generation']) 31 | if label == 1: 32 | statistics["correct"] += 1 33 | else: 34 | statistics["incorrect"] += 1 35 | statistics["total"] += 1 36 | question_set.append(sample['question']) 37 | 38 | return statistics 39 | 40 | def bioasq_judge(answer, prediction): 41 | answer = answer.split('\n#### ')[-1] 42 | if '\n#### ' in prediction: 43 | prediction = prediction.split('\n#### ') 44 | else: 45 | prediction = prediction.split('####') 46 | if len(prediction) == 1: 47 | return prediction[0], 0 48 | else: 49 | pred_answer = prediction[1] 50 | while pred_answer[0] == '\n' or pred_answer[0] == ' ': 51 | pred_answer = pred_answer[1:] 52 | if '\n' in pred_answer: 53 | pred_answer = pred_answer.split('\n')[0] 54 | if '.' in pred_answer: 55 | pred_answer = pred_answer.split('.')[0] 56 | if ',' in pred_answer: 57 | pred_answer = pred_answer.split(',')[0] 58 | pred_answer = pred_answer.replace('\n', '').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('{', '').replace('}', '').replace(',', '').replace('.', '') 59 | if '.' in answer: 60 | answer = answer.split('.')[0] 61 | prediction = prediction[0] + '\n#### ' + pred_answer + '.' 
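        # Note: unlike the other *_judge functions, which require an exact match, BioASQ is scored by
        # case-insensitive substring containment, i.e., the prediction counts as correct if it
        # contains the gold answer.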
62 | if answer.lower() in pred_answer.lower(): 63 | return prediction, 1 64 | else: 65 | return prediction, 0 -------------------------------------------------------------------------------- /inference/generate_vllm.py: -------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from datasets import load_dataset 4 | import argparse 5 | import torch, time, json, os 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | from datetime import timedelta 9 | import warnings 10 | from accelerate.utils import InitProcessGroupKwargs 11 | warnings.filterwarnings("ignore") 12 | from data.dataset_loader import get_datasets 13 | import random 14 | 15 | def generate_vllm(config): 16 | model_path = config["model_name"] 17 | world_size = config["world_size"] 18 | data_frac = config["data_frac"] 19 | output_dir = Path(config["output_dir"]) 20 | output_dir.mkdir(parents=True, exist_ok=True) 21 | 22 | # load the base model and tokenizer 23 | tokenizer = AutoTokenizer.from_pretrained(model_path) 24 | tokenizer.pad_token = tokenizer.eos_token 25 | 26 | llm = LLM(model=model_path, tensor_parallel_size=world_size) 27 | sampling_params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=256) 28 | 29 | # load data 30 | train_data, val_data, test_data, prompt = get_datasets(config['seed'], config['input_dir']) 31 | if config["split"] == 'test': 32 | data = test_data 33 | elif config["split"] == 'val': 34 | data = val_data 35 | else: 36 | data = train_data 37 | # random.seed(seed) 38 | random.shuffle(data) 39 | if config['frac_len'] > 0: 40 | sub_len = config['frac_len'] 41 | if sub_len * (data_frac + 1) > len(data): 42 | data = data[sub_len*data_frac:] 43 | else: 44 | data = data[sub_len*data_frac:sub_len*(data_frac+1)] 45 | else: 46 | data = data[:] 47 | 48 | prompts_all = ["Question: " + data[idx]['question'] + "\n\nAnswer: " for idx in range(len(data))] 49 | prompts_old = [data[idx]['question'] for idx in range(len(data))] 50 | corrects_all = [data[idx]['answer'] for idx in range(len(data))] 51 | 52 | start = time.time() 53 | 54 | # run vllm 55 | results_gathered = list(map(lambda x: x.outputs[0].text, llm.generate(prompts_all, sampling_params))) 56 | results = [r.replace(tokenizer.eos_token, "").lstrip() for r in results_gathered] 57 | 58 | timediff = time.time() - start 59 | print(f"time elapsed: {timediff}") 60 | 61 | # collecting data 62 | for idx in tqdm(range(len(corrects_all))): 63 | d = {"question": prompts_old[idx], "answer": corrects_all[idx], "generation": results[idx]} 64 | if config["split"] == 'test': 65 | file_name = f"{config['output_dir']}/{config['data_frac']}_test.jsonl" 66 | else: 67 | file_name = f"{config['output_dir']}/{config['data_frac']}.jsonl" 68 | with open(file_name, 'a') as f: 69 | json.dump(d, f) 70 | f.write('\n') 71 | -------------------------------------------------------------------------------- /models/generator_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import transformers 4 | from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding, get_constant_schedule_with_warmup 5 | from accelerate import Accelerator 6 | from accelerate.state import PartialState 7 | from accelerate.utils import InitProcessGroupKwargs, release_memory 8 | from datetime import timedelta 9 | 10 | class generator_model(): 11 | 
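    # Thin wrapper around a Hugging Face causal LM: sets up the tokenizer, an accelerate
    # Accelerator, optimizer, and LR scheduler, and exposes inference_step() for text generation.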
def __init__(self, config): 12 | self.config = config 13 | self.tokenizer = AutoTokenizer.from_pretrained(config["generator_model"]) 14 | self.tokenizer.pad_token = self.tokenizer.eos_token 15 | self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True 16 | kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=96000)) 17 | self.accelerator = Accelerator( 18 | split_batches=False, 19 | mixed_precision='fp16', 20 | gradient_accumulation_steps=self.config["gradient_accumulation_steps"], 21 | log_with='wandb' if self.config.get("log_with_wandb", False) else None, 22 | project_dir='logs' if self.config.get("log_with_wandb", False) else None, 23 | device_placement=True, 24 | kwargs_handlers=[kwargs] 25 | ) 26 | self.model = AutoModelForCausalLM.from_pretrained(config["generator_model"], trust_remote_code=True) 27 | self.model.config.use_cache = False 28 | self.model.config.pretraining_tp = 1 29 | if self.tokenizer.pad_token is None: 30 | self.acclerator.print("Adding pad token to the tokenizer...") 31 | self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 32 | self.model.resize_token_embeddings(len(self.tokenizer)) 33 | 34 | self.answer_token = self.tokenizer.encode("\nA: ", return_tensors="pt", add_special_tokens=False)[0, 1:] 35 | self.optimizer = AdamW( 36 | self.model.parameters(), 37 | lr=config["learning_rate"] * self.accelerator.gradient_accumulation_steps, 38 | weight_decay=0.01 39 | ) 40 | self.lr_scheduler = get_constant_schedule_with_warmup( 41 | self.optimizer, 42 | num_warmup_steps=config["warmup_steps"] 43 | ) 44 | self.accelerator.print(f"Distributed: {self.accelerator.distributed_type}, Mixed precision: {self.accelerator.mixed_precision}") 45 | 46 | def input_text_process(self, input_texts): 47 | return input_texts 48 | 49 | def initialization_step(self): 50 | self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( 51 | self.model, self.optimizer, self.lr_scheduler 52 | ) 53 | 54 | def inference_step(self, input_texts): 55 | self.model.eval() 56 | input_texts = self.input_text_process(input_texts) 57 | inputs = self.tokenizer( 58 | [input_texts], 59 | return_tensors="pt", 60 | add_special_tokens=self.config["add_special_tokens"], 61 | padding=True, 62 | truncation=True, 63 | ).to(self.accelerator.device) 64 | input_ids = inputs["input_ids"] 65 | attention_mask = inputs["attention_mask"] 66 | print(input_texts) 67 | pipeline = transformers.pipeline( 68 | "text-generation", 69 | model=self.accelerator.unwrap_model(self.model), 70 | tokenizer=self.tokenizer, 71 | device=self.accelerator.device, 72 | # do_sample=True, 73 | # top_k=10, 74 | num_return_sequences=1, 75 | # eos_token_id=self.tokenizer.eos_token_id, 76 | max_length=500 77 | ) 78 | outputs = pipeline(input_texts) 79 | print(outputs) 80 | return outputs[0]["generated_text"] -------------------------------------------------------------------------------- /models/reward_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from accelerate.state import PartialState 4 | from accelerate.utils import release_memory, InitProcessGroupKwargs 5 | import datasets 6 | from datasets import Dataset 7 | datasets.disable_progress_bar() 8 | from datetime import timedelta 9 | import os 10 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 11 | os.environ["WANDB_LOG_MODEL"] = "false" 12 | from tqdm.auto import tqdm 13 | from utils.util import get_answer_start_idx 14 | from utils.loggers import loggers 15 
| from accelerate import Accelerator 16 | 17 | from transformers import ( 18 | AdamW, 19 | AutoModelForCausalLM, 20 | AutoModelForSequenceClassification, 21 | AutoTokenizer, 22 | DataCollatorWithPadding, 23 | get_constant_schedule_with_warmup 24 | ) 25 | 26 | torch.cuda_empty_cache() 27 | torch.set_printoptions(threshold=10_000) 28 | 29 | class reward_model(): 30 | def __init__(self, config): 31 | self.config = config 32 | self.tokenizer = AutoTokenizer.from_pretrained(config["reward_model"], truncation_side="left") 33 | self.tokenizer.pad_token = self.tokenizer.eos_token 34 | self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True 35 | kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=96000)) 36 | self.accelerator = Accelerator( 37 | split_batches=False, 38 | mixed_precision='fp16', 39 | gradient_accumulation_steps=self.config["gradient_accumulation_steps"], 40 | log_with='wandb' if self.config.get("log_with_wandb", False) else None, 41 | project_dir='logs' if self.config.get("log_with_wandb", False) else None, 42 | device_placement=True, 43 | kwargs_handlers=[kwargs] 44 | ) 45 | self.mode = config["reward_mode"] 46 | if self.mode == 'generation': 47 | self.model = AutoModelForCausalLM.from_pretrained(config["reward_model"], trust_remote_code=True) 48 | elif self.mode == 'classification-1': 49 | self.model=AutoModelForSequenceClassification.from_pretrained(config["reward_model"], trust_remote_code=True, num_labels=1) 50 | self.model.config.pad_token_id = self.tokenizer.eos_token_id 51 | elif self.mode == 'classification-2': 52 | self.model=AutoModelForSequenceClassification.from_pretrained(config["reward_model"], trust_remote_code=True, num_labels=2) 53 | self.model.config.pad_token_id = self.tokenizer.eos_token_id 54 | else: 55 | raise NotImplementedError 56 | 57 | self.model.config.use_cache = False if "phi" in self.config["reward_model"].lower() else True 58 | self.model.config.pretraining_tp = 1 59 | 60 | if self.tokenizer.pad_token_id is None: 61 | self.accelerator.print("Adding pad token to the tokenizer...") 62 | self.tokenizer.add_special_tokens({"pad_token": '[PAD]'}) 63 | self.model.resize_token_embeddings(len(self.tokenizer)) 64 | 65 | self.answer_token = self.tokenizer.encode("\nA: ", return_tensors="pt", add_special_tokens=False)[0, 1:] 66 | 67 | self.optimizer = AdamW( 68 | self.model.parameters(), 69 | lr=config["learning_rate"] * self.accelerator.gradient_accumulation_steps, 70 | weight_decay=0.01, 71 | ) 72 | self.lr_scheduler = get_constant_schedule_with_warmup( 73 | self.optimizer, 74 | num_warmup_steps=config["warmup_steps"], 75 | ) 76 | self.accelerator.print(f"Distributed: {self.accelerator.distributed_type}, Mixed precision: {self.accelerator.mixed_precision}") 77 | 78 | def build_dataset(self, positive_texts, negative_texts, save_to): 79 | pos_len, neg_len = len(positive_texts), len(negative_texts) 80 | labels = - torch.ones(pos_len + neg_len) 81 | labels[:pos_len] = 1 82 | -------------------------------------------------------------------------------- /generator/vanilla_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from datasets import load_dataset, Dataset 4 | from transformers import ( 5 | AutoModelForCausalLM, 6 | AutoTokenizer, 7 | BitsAndBytesConfig, 8 | HfArgumentParser, 9 | TrainingArguments, 10 | pipeline, 11 | logging, 12 | ) 13 | from peft import LoraConfig, PeftModel, get_peft_model 14 | from trl import SFTTrainer 15 | 16 | import argparse 17 
| from data.prompt_loader import load_prompts 18 | from data.dataset_loader import get_datasets 19 | from utils.util import load_config 20 | 21 | def load_model(config): 22 | # Load tokenizer and model with QLoRA configuration 23 | compute_dtype = getattr(torch, config["bnb_4bit_compute_dtype"]) 24 | 25 | bnb_config = BitsAndBytesConfig( 26 | load_in_4bit=config["use_4bit"], 27 | bnb_4bit_quant_type=config["bnb_4bit_quant_type"], 28 | bnb_4bit_compute_dtype=compute_dtype, 29 | bnb_4bit_use_double_quant=config["use_nested_quant"], 30 | ) 31 | 32 | if compute_dtype == torch.float16 and config["use_4bit"]: 33 | major, _ = torch.cuda.get_device_capability() 34 | if major >= 8: 35 | print("=" * 80) 36 | print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16") 37 | print("=" * 80) 38 | 39 | model = AutoModelForCausalLM.from_pretrained( 40 | config["model_name"], 41 | device_map="auto", 42 | # quantization_config=bnb_config 43 | ) 44 | 45 | model.config.use_cache = False 46 | model.config.pretraining_tp = 1 47 | 48 | # Load LoRA configuration 49 | peft_config = LoraConfig( 50 | lora_alpha=config["lora_alpha"], 51 | lora_dropout=config["lora_dropout"], 52 | r=config["lora_r"], 53 | bias="none", 54 | task_type="CAUSAL_LM", 55 | ) 56 | 57 | # Load Tokenizer 58 | tokenizer = AutoTokenizer.from_pretrained(config["model_name"], trust_remote_code=True) 59 | tokenizer.pad_token = tokenizer.eos_token 60 | tokenizer.padding_side = "right" 61 | 62 | return model, tokenizer, peft_config 63 | 64 | def train(config): 65 | model, tokenizer, peft_config = load_model(config) 66 | training_arguments = TrainingArguments( 67 | output_dir=config["output_dir"], 68 | per_device_train_batch_size=config["per_device_train_batch_size"], 69 | gradient_accumulation_steps=config["gradient_accumulation_steps"], 70 | optim=config["optim"], 71 | save_steps=config["save_steps"], 72 | logging_steps=config["logging_steps"], 73 | learning_rate=config["learning_rate"], 74 | fp16=config["fp16"], 75 | bf16=config["bf16"], 76 | max_grad_norm=config["max_grad_norm"], 77 | max_steps=config["max_steps"], 78 | warmup_ratio=config["warmup_ratio"], 79 | group_by_length=config["group_by_length"], 80 | lr_scheduler_type=config["lr_scheduler_type"], 81 | report_to=config["report_to"], 82 | ) 83 | 84 | # dataset = load_dataset(config["dataset_name"], config["subset"], split=config["split"]) 85 | # prompt = load_prompts(config["dataset_name"]) 86 | dataset, _, _, prompt = get_datasets(42, config["dataset_name"]) 87 | 88 | train_text_list = [] 89 | for i in range(len(dataset)): 90 | template = " [INST] Q: {question} [/INST] \nA: {answer} " 91 | question = dataset[i]['question'] 92 | answer = dataset[i]['answer'] 93 | if not '####' in answer: 94 | answer = '#### ' + answer 95 | if prompt != None: 96 | train_text_list.append(prompt + '\n\n' + template.format(question=question, answer=answer)) 97 | else: 98 | train_text_list.append(template.format(question=question, answer=answer)) 99 | 100 | temp_dataset = Dataset.from_dict({ 101 | "text": train_text_list, 102 | }).with_format("torch") 103 | 104 | 105 | trainer = SFTTrainer( 106 | model=model, 107 | train_dataset=temp_dataset, 108 | # peft_config=peft_config, 109 | dataset_text_field="text", 110 | max_seq_length=config["max_seq_length"], 111 | tokenizer=tokenizer, 112 | args=training_arguments, 113 | packing=config["packing"], 114 | ) 115 | 116 | trainer.train() 117 | if not os.path.exists(config["output_dir"]): 118 | os.makedirs(config["output_dir"]) 119 | 
trainer.model.save_pretrained(config["output_dir"])
120 | 
121 | def main():
122 |     parser = argparse.ArgumentParser()
123 |     parser.add_argument('-c', '--config', default='configs/gsm8k.yaml', type=str, help='Path to the config file')
124 |     args = parser.parse_args()
125 | 
126 |     config_path = args.config
127 |     assert os.path.isfile(config_path), f"Invalid config path: {config_path}"
128 | 
129 |     config = load_config(config_path)
130 |     train(config)
131 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MedAdapter
2 | This is the code for our paper ["MedAdapter: Efficient Test-Time Adaptation of Large Language Models Towards Medical Reasoning"](https://arxiv.org/abs/2405.03000) in EMNLP 2024.
3 | 
4 | ## Framework
5 | ![MedAdapter](assets/MedAdapter-overview.png)
6 | 
7 | ## Data
8 | You need to first download the raw data from the following sources: [MedMCQA](https://huggingface.co/datasets/openlifescienceai/medmcqa), [MedQA](https://huggingface.co/datasets/bigbio/med_qa), [MMLU-Med](https://huggingface.co/datasets/cais/mmlu), [PubMedQA](https://huggingface.co/datasets/qiaojin/PubMedQA), [BioASQ](http://bioasq.org/), [MedNLI](https://huggingface.co/datasets/bigbio/mednli), [MediQA-RQE](https://huggingface.co/datasets/bigbio/mediqa_rqe), [PubHealth](https://huggingface.co/datasets/bigbio/pubhealth), [MediQA](https://huggingface.co/datasets/bigbio/mediqa_qa), and [CORD19](https://huggingface.co/datasets/allenai/cord19). Please download the data and save it in the ``./data/<dataset_name>/`` directory.
9 | 
10 | ## Experiments
11 | 
12 | ### Candidate Solution Generation
13 | Before training the adapter, we first use the LLM to generate candidate solutions for the subsequent reward model training. Candidates need to be generated on both the training and test sets. For OpenAI commercial models (e.g., gpt-3.5-turbo, gpt-4), we call ``main-openai.py`` to initiate the generation process:
14 | ```python
15 | python main-openai.py --debug generation --config configs/bioasq/bioasq-gen-test.yaml
16 | python main-openai.py --debug generation --config configs/bioasq/bioasq-gen-train.yaml
17 | ```
18 | 
19 | For open-source LLMs, we use [vLLM](https://docs.vllm.ai/en/latest/getting_started/quickstart.html) as the inference engine and run ``main.py`` to initiate the generation process:
20 | ```python
21 | CUDA_VISIBLE_DEVICES=1,2 accelerate launch --mixed_precision fp16 --main_process_port 29650 main.py --debug generation --config configs/bioasq/bioasq-gen-test.yaml
22 | CUDA_VISIBLE_DEVICES=1,2 accelerate launch --mixed_precision fp16 --main_process_port 29650 main.py --debug generation --config configs/bioasq/bioasq-gen-train.yaml
23 | ```
24 | 
25 | The ``--config`` argument specifies the generation configuration used during inference, containing the important hyperparameters. The configuration files should be stored in the ``./configs/<dataset_name>/<model_name>/`` directory.
26 | 
27 | ### Outcome-Supervised Adapter Training
28 | To train the adapter, we run the ``main.py`` entry program:
29 | ```python
30 | CUDA_VISIBLE_DEVICES=0 accelerate launch --mixed_precision fp16 --main_process_port 29666 main.py --debug reward --config configs/bioasq/bioasq-reward.yaml
31 | ```
32 | The configuration file for adapter training should also be saved under the ``./configs/<dataset_name>/<model_name>/`` directory.
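The adapter-specific hyperparameters live in the ``reward_model`` block of that config file. A minimal sketch of this block, abridged from ``configs/bioasq/gpt-35/bioasq-reward.yaml`` in this repository (the dataset and checkpoint paths are examples and should point to your own generated candidates and output directory):
```yaml
reward_model:
  type: orm-classification                            # outcome-supervised reward model
  model_name: allenai/longformer-base-4096             # BERT-sized backbone for the adapter
  tokenizer_name: allenai/longformer-base-4096
  dataset_name: data/bioasq/generation/gpt35/0.jsonl   # candidates generated on the training split
  split: train
  output_dir: ./checkpoints/reward_model/LongFormer-orm-bioasq-gpt35
  learning_rate: 2.0e-5
  per_device_train_batch_size: 8
  num_train_epochs: 5
  weight_decay: 0.01
  save_strategy: epoch
  push_to_hub: False
```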
33 | 34 | ### Best-of-K Inference 35 | As in the previous two stages, we run the entry program ``main.py`` again, this time to perform inference over the previously generated candidates on the test set: 36 | ```bash 37 | CUDA_VISIBLE_DEVICES=0 accelerate launch --mixed_precision fp16 --main_process_port 29666 main.py --debug reward_guide --config configs/bioasq/bioasq-guide.yaml 38 | ``` 39 | The configuration file for best-of-K inference should also be saved under the ``./configs///`` directory. 40 | 41 | ### Running Scripts 42 | We also provide several example run commands in the ``./scripts`` directory. 43 | 44 | ## Citation 45 | If you find this repository valuable for your research, we kindly request that you acknowledge it by citing the following paper. We appreciate your consideration. 46 | 47 | ``` 48 | @inproceedings{shi-etal-2024-medadapter, 49 | title = "{M}ed{A}dapter: Efficient Test-Time Adaptation of Large Language Models Towards Medical Reasoning", 50 | author = "Shi, Wenqi and 51 | Xu, Ran and 52 | Zhuang, Yuchen and 53 | Yu, Yue and 54 | Sun, Haotian and 55 | Wu, Hang and 56 | Yang, Carl and 57 | Wang, May Dongmei", 58 | editor = "Al-Onaizan, Yaser and 59 | Bansal, Mohit and 60 | Chen, Yun-Nung", 61 | booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing", 62 | month = nov, 63 | year = "2024", 64 | address = "Miami, Florida, USA", 65 | publisher = "Association for Computational Linguistics", 66 | url = "https://aclanthology.org/2024.emnlp-main.1244", 67 | doi = "10.18653/v1/2024.emnlp-main.1244", 68 | pages = "22294--22314", 69 | abstract = "Despite their improved capabilities in generation and reasoning, adapting large language models (LLMs) to the biomedical domain remains challenging due to their immense size and privacy concerns. In this study, we propose MedAdapter, a unified post-hoc adapter for test-time adaptation of LLMs towards biomedical applications. Instead of fine-tuning the entire LLM, MedAdapter effectively adapts the original model by fine-tuning only a small BERT-sized adapter to rank candidate solutions generated by LLMs. Experiments on four biomedical tasks across eight datasets demonstrate that MedAdapter effectively adapts both white-box and black-box LLMs in biomedical reasoning, achieving average performance improvements of 18.24{\%} and 10.96{\%}, respectively, without requiring extensive computational resources or sharing data with third parties. MedAdapter also yields enhanced performance when combined with train-time adaptation, highlighting a flexible and complementary solution to existing adaptation methods.
Faced with the challenges of balancing model performance, computational resources, and data privacy, MedAdapter provides an efficient, privacy-preserving, cost-effective, and transparent solution for adapting LLMs to the biomedical domain.", 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /inference/generate.py: -------------------------------------------------------------------------------- 1 | from accelerate import Accelerator 2 | from accelerate.utils import gather_object 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from datasets import load_dataset 5 | import argparse 6 | import torch, time, json, os 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | from datetime import timedelta 10 | from accelerate.utils import InitProcessGroupKwargs 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | 14 | from data.dataset_loader import get_datasets 15 | from utils.util import load_config 16 | from data.prompt_loader import load_prompts 17 | import random 18 | kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=36000)) 19 | accelerator = Accelerator(kwargs_handlers=[kwargs]) 20 | 21 | def prepare_prompts(prompts, tokenizer, batch_size=4): 22 | batches = [prompts[i:i+batch_size] for i in range(0, len(prompts), batch_size)] 23 | batches_tok = [] 24 | tokenizer.padding_side = "left" 25 | for prompt_batch in batches: 26 | batches_tok.append( 27 | tokenizer( 28 | prompt_batch, 29 | return_tensors="pt", 30 | padding='longest', 31 | truncation=False, 32 | pad_to_multiple_of=8, 33 | add_special_tokens=False 34 | ).to("cuda") 35 | ) 36 | tokenizer.padding_side = "right" 37 | return batches_tok 38 | 39 | def generate(config): 40 | model_path = config["model_name"] 41 | tokenizer_path = config["tokenizer_name"] 42 | data_frac = config["data_frac"] 43 | batch_size = config["batch_size"] 44 | output_dir = Path(config["output_dir"]) 45 | output_dir.mkdir(parents=True, exist_ok=True) 46 | 47 | # credentials 48 | token = config["token"] 49 | 50 | # load a base model and tokenizer 51 | model = AutoModelForCausalLM.from_pretrained( 52 | model_path, 53 | device_map={"": accelerator.process_index}, 54 | torch_dtype=torch.bfloat16, 55 | token=token 56 | ) 57 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=token) 58 | tokenizer.pad_token = tokenizer.eos_token 59 | 60 | # load_data 61 | prompt = load_prompts(config["input_dir"]) 62 | seed = config["seed"] 63 | train_dataset, val_dataset, test_dataset, prompt = get_datasets(seed, config["input_dir"]) 64 | if config["split"] == 'train': 65 | data = train_dataset 66 | elif config["split"] == 'val': 67 | data = val_dataset 68 | elif config["split"] == 'test': 69 | data = test_dataset 70 | random.seed(seed) 71 | random.shuffle(data) 72 | if config["frac_len"] > 0: 73 | sub_len = config["frac_len"] 74 | if sub_len * (data_frac + 1) > len(data): 75 | data = data[sub_len*data_frac:] 76 | else: 77 | data = data[sub_len*data_frac:sub_len*(data_frac+1)] 78 | 79 | print(data[0]) 80 | 81 | # modification here 82 | if prompt is not None: 83 | prompts_all = [" [INST] " + prompt + "\n\nQ: " + data[idx]['question'] + " [/INST] \nA: " for idx in range(len(data))] 84 | else: 85 | prompts_all = [" [INST] Q: " + data[idx]['question'] + " [/INST] \nA: " for idx in range(len(data))] 86 | prompts_old = [data[int(idx/config["num_return_sequences"])]['question'] for idx in range(config["num_return_sequences"]*len(data))] 87 | corrects_all = 
[data[int(idx/config["num_return_sequences"])]['answer'] for idx in range(config["num_return_sequences"]*len(data))] 88 | 89 | print(len(prompts_old), len(corrects_all)) 90 | # sync GPUs and start the timer 91 | accelerator.wait_for_everyone() 92 | start = time.time() 93 | 94 | # divide the prompt list onto the avilable GPUs 95 | with accelerator.split_between_processes(prompts_all) as prompts: 96 | results = [] 97 | prompt_batches = prepare_prompts(prompts, tokenizer, batch_size) 98 | for prompts_tokenized in tqdm(prompt_batches): 99 | # set max_new_tokens smaller for faster inference 100 | outputs_tokenized = model.generate(**prompts_tokenized, max_new_tokens=config["max_length"], pad_token_id=tokenizer.eos_token_id, num_return_sequences=config["num_return_sequences"]) 101 | inputs_tokenized = prompts_tokenized["input_ids"].repeat_interleave(config["num_return_sequences"], dim=0) 102 | 103 | # remove prompt from gen. tokens 104 | outputs_tokenized = [tok_out[len(tok_in):] for tok_in, tok_out in zip(inputs_tokenized, outputs_tokenized)] 105 | # decode the generated tokens 106 | outputs = tokenizer.batch_decode(outputs_tokenized) 107 | # print(outputs.shape) 108 | results.extend(outputs) 109 | 110 | # collect results from all the GPUs and remove paddings 111 | results_gathered = gather_object(results) 112 | results = [r.replace(tokenizer.eos_token, "").lstrip() for r in results_gathered] 113 | print(len(results)) 114 | # input() 115 | if accelerator.is_local_main_process: 116 | timediff = time.time() - start 117 | print(f"Time elapsed: {timediff}") 118 | 119 | # collecting data 120 | for idx in range(len(corrects_all)): 121 | d = {"question": prompts_old[idx], "answer": corrects_all[idx], "generation": results[idx]} 122 | if config["split"] == 'test': 123 | if config['num_return_sequences'] == 1: 124 | file_name = f"{config['output_dir']}/{config['data_frac']}_test.jsonl" 125 | else: 126 | file_name = f"{config['output_dir']}/{config['data_frac']}_test_candidates.jsonl" 127 | else: 128 | file_name = f"{config['output_dir']}/{config['data_frac']}.jsonl" 129 | with open(file_name, 'a') as f: 130 | json.dump(d, f) 131 | f.write('\n') 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument('-c', '--config', default='configs/gsm8k.yaml', type=str, help='Path to the config file') 136 | args = parser.parse_args() 137 | 138 | config_path = args.config 139 | assert os.path.isfile(config_path), f"Invalid config path: {config_path}" 140 | 141 | config = load_config(config_path) 142 | generate(config) -------------------------------------------------------------------------------- /inference/batched_generate_vllm.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | import time 4 | from functools import partial 5 | from pathlib import Path 6 | import torch 7 | from datasets import load_dataset 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache 11 | from vllm import LLM, SamplingParams 12 | from data.prompt_loader import load_prompts 13 | import json 14 | 15 | def run_process_on_gpu(config, gpu_queue, data_frac): 16 | gpu_id = gpu_queue.get() 17 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 18 | print(f"Running on GPU: {gpu_id}") 19 | # Assuming the existence of a function that handles te generation process for a single GPU 20 | generate_on_single_gpu(config, data_frac) 21 | 
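# NOTE (added comment): once this data fraction finishes, the worker returns its GPU id to the shared queue below so the next pending fraction can claim that GPU.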
gpu_queue.put(gpu_id) 22 | 23 | def generate_on_single_gpu(config, data_frac): 24 | output_dir = config["output_dir"] 25 | output_dir = Path(output_dir) 26 | output_dir.mkdir(parents=True, exist_ok=True) 27 | print(f"Generating on GPU with data fraction: {data_frac}...") 28 | # load the base model and tokenizer 29 | model_path = config["model_name"] 30 | tokenizer_path = config["tokenizer_name"] 31 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 32 | tokenizer.pad_token = tokenizer.eos_token 33 | world_size = config["tp_per_worker"] 34 | llm = LLM(model=model_path, tokenizer=tokenizer_path, tensor_parallel_size=world_size) 35 | # sampling_params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=256) 36 | sampling_params = SamplingParams( 37 | temperature=config["sampling_params"]["temperature"], 38 | top_p=config["sampling_params"]["top_p"], 39 | max_tokens=config["sampling_params"]["max_tokens"], 40 | n=config["sampling_params"]["n"], 41 | ) 42 | 43 | # load data 44 | prompt = load_prompts(config["input_dir"]) 45 | if 'split' in config and 'subset' in config: 46 | data = load_dataset(config["input_dir"], config["subset"], split=config["split"]) 47 | elif 'subset' not in config: 48 | data = load_dataset(config["input_dir"], split=config["split"]) 49 | seed = config["seed"] 50 | data = data.shuffle(seed=seed) 51 | 52 | if config["frac_len"] > 0: 53 | sub_len = config["frac_len"] 54 | if sub_len * (data_frac + 1) > len(data): 55 | data = data[sub_len*data_frac:] 56 | else: 57 | data = data[sub_len*data_frac:sub_len*(data_frac+1)] 58 | 59 | if prompt is not None: 60 | prompts_all = [' [INST] ' + prompt + "\n\nQ: " + data[idx]['question'] + "[/INST] \nA: " for idx in range(len(data))] 61 | else: 62 | prompts_all = [" [INST] Q: " + data[idx]['question'] + "[/INST] \nA: " for idx in range(len(data))] 63 | prompts_old = [data[idx]['question'] for idx in range(len(data))] 64 | corrects_all = [data[idx]['answer'] for idx in range(len(data))] 65 | 66 | start_time = time.time() 67 | 68 | # run vllm 69 | if config["sampling_params"]["n"] == 1: 70 | results_gathered = list( 71 | map(lambda x: x.outputs[0].text, llm.generate(prompts_all, sampling_params)) 72 | ) 73 | else: 74 | # flatten x.outputs 75 | results_gathered = list( 76 | map(lambda x: [y.text for y in x.outputs], llm.generate(prompts_all, sampling_params)) 77 | ) 78 | results_gathered = [item for sublist in results_gathered for item in sublist] 79 | # print(results_gathered) 80 | # input() 81 | results = [r.replace("", "").lstrip() for r in results_gathered] 82 | print(len(results)) 83 | timediff = time.time() - start_time 84 | print(f"time elapsed: {timediff}") 85 | 86 | # collecting data 87 | for idx in range(len(corrects_all)): 88 | d = {"question": prompts_old[idx], "answer": corrects_all[idx], "generation": results[idx]} 89 | if config["split"] == 'test': 90 | file_name = f"{config['output_dir']}/{config['data_frac']}_test.jsonl" 91 | else: 92 | file_name = f"{config['output_dir']}/{config['data_frac']}.jsonl" 93 | with open(file_name, 'a') as f: 94 | json.dump(d, f) 95 | f.write('\n') 96 | 97 | def generate_on_multiple_gpus(config): 98 | start = time.time() 99 | mp.set_start_method("spawn", force=True) 100 | num_gpus = torch.cuda.device_count() 101 | print(f"Number of GPUs available: {num_gpus}") 102 | 103 | # Check if the model is already downloaded 104 | model_path = config["model_name"] 105 | tokenizer_path = config["tokenizer_name"] 106 | if not model_path.startswith("/"): # hub_path 107 | filepath = 
try_to_load_from_cache(model_path, "config.json") 108 | cache_dir = Path.home() / ".cache" / "huggingface" / "hub" 109 | model_directory = cache_dir / f"models--{model_path.replace('/', '--')}" 110 | 111 | print(f"checking cache results: {filepath}") 112 | if isinstance(filepath, str): 113 | print(f"Model {model_path} is alread downloaded.") 114 | else: 115 | print(f"Model {model_path} is not downloaded yet, will download now.") 116 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 117 | model = AutoModelForCausalLM.from_pretrained(model_path) 118 | print(f"Model {model_path} downloaded.") 119 | del tokenizer 120 | del model 121 | else: 122 | model_directory = model_path 123 | print(f"Model directory: {model_directory}") 124 | 125 | # create a pool of processes. Each process will run on a seperate GPU 126 | with mp.Manager() as manager: 127 | gpu_queue = manager.Queue() 128 | # Add the gpu_id to the queue 129 | for i in range(num_gpus): 130 | gpu_queue.put(i) 131 | 132 | with mp.Pool(processes=num_gpus) as pool: 133 | # Partial function with all arguments except the one that changes per process (data_frac) 134 | func = partial( 135 | run_process_on_gpu, 136 | config, 137 | ) 138 | 139 | # for each data_frac, scheduling one task 140 | res_futs = [] 141 | for data_frac in range(config["num_data_frac"]): 142 | res_futs.append( 143 | pool.apply_async( 144 | func, 145 | ( 146 | gpu_queue, 147 | data_frac, 148 | ) 149 | ) 150 | ) 151 | for res in res_futs: 152 | res.get() 153 | print(f"Total time taken: {time.time() - start}") 154 | -------------------------------------------------------------------------------- /reward_model/orm/orm_guide.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from datasets import Dataset, load_dataset 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | AutoTokenizer, 10 | AutoModelForSequenceClassification, 11 | DataCollatorWithPadding, 12 | ) 13 | from accelerate import Accelerator 14 | from accelerate.utils import InitProcessGroupKwargs 15 | from datetime import timedelta 16 | from tqdm import tqdm 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | from inference.generate import generate 20 | from reward_model.prm.prm_data import decompose_samples 21 | 22 | class orm_guided_generation(): 23 | def __init__(self, config): 24 | self.config = config 25 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | if 'orm-classification' in config["reward_model"]["type"]: 27 | self.reward_model = AutoModelForSequenceClassification.from_pretrained(config["reward_model"]["model_name"], num_labels=2) 28 | self.reward_tokenizer = AutoTokenizer.from_pretrained(config["reward_model"]["tokenizer_name"]) 29 | self.reward_model = self.reward_model.to(self.device) 30 | self.reward_tokenizer = self.reward_tokenizer 31 | self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token 32 | 33 | def generate_solutions(self): 34 | # self.config["generator"]["output_dir"] = self.config["generator"]["output_dir"] + "_temp_candidates" 35 | generate(self.config["generator"]) 36 | if self.config["generator"]["split"] == 'test': 37 | if self.config['generator']['num_return_sequences'] == 1: 38 | solution_file = f"{self.config['generator']['output_dir']}/{self.config['generator']['data_frac']}_test.jsonl" 39 | else: 40 | solution_file = 
f'{self.config["generator"]["output_dir"]}/{self.config["generator"]["data_frac"]}_test_candidates.jsonl' 41 | else: 42 | solution_file = f'{self.config["generator"]["output_dir"]}/{self.config["generator"]["data_frac"]}.jsonl' 43 | # use json to read the list of dictionaries from solution_file 44 | self.solutions = load_dataset('json', data_files=solution_file, split='train') 45 | 46 | def read_solutions(self): 47 | solution_file = self.config["generator"]["input_dir"] 48 | # use json to read the list of dictionaries from solution_file 49 | self.solutions = load_dataset('json', data_files=solution_file, split='train') 50 | 51 | def tokenizer_dataset(self, data): 52 | return self.reward_tokenizer(data["text"], truncation=False) 53 | 54 | def process_solutions(self): 55 | # pass them through the reward model 56 | samples = [] 57 | idxes = [] 58 | template = "Q: {question}\nA: {answer}" 59 | for idx in tqdm(range(len(self.solutions))): 60 | sample = self.solutions[idx] 61 | prediction = sample['generation'].strip().split('\n#### ') 62 | if len(prediction) != 1: 63 | pred_answer = prediction[1].split('\n')[0] 64 | prediction = prediction[0] + '\n#### ' + pred_answer 65 | else: 66 | prediction = prediction[0] 67 | sample_text = template.format(question=sample['question'], answer=prediction) 68 | samples.append(sample_text) 69 | idxes.append(idx) 70 | 71 | # convert samples into huggingface dataset with Dataset.from_dict 72 | self.dataset = Dataset.from_dict({"idxes": idxes, "text": samples}).with_format("torch") 73 | data_collator = DataCollatorWithPadding(tokenizer=self.reward_tokenizer) 74 | def tokenized_dataset(data): 75 | return self.reward_tokenizer(data["text"], truncation=True) 76 | self.tokenized_dataset = self.dataset.map(tokenized_dataset, batched=True) 77 | self.tokenized_dataset = self.tokenized_dataset.remove_columns(["idxes", "text"]) 78 | # self.tokenized_dataset = self.tokenized_dataset.set_format("torch") 79 | self.dataloader = DataLoader(self.tokenized_dataset, batch_size=self.config["reward_model"]["per_device_eval_batch_size"], collate_fn=data_collator, shuffle=False) 80 | 81 | def get_reward_score(self): 82 | # pass the dataset through the reward model 83 | self.solution_scores = [] 84 | num = 0 85 | for batch in self.dataloader: 86 | batch = {k: v.to(self.device) for k, v in batch.items()} 87 | with torch.no_grad(): 88 | outputs = self.reward_model(**batch) 89 | logits = outputs.logits.detach().cpu() # B, 2 90 | # apply softmax on the logits 91 | logits = torch.nn.functional.softmax(logits, dim=-1) 92 | scores = logits[:,1] 93 | self.solution_scores.append(scores) 94 | self.solution_scores = torch.cat(self.solution_scores, dim=0) 95 | print(len(self.solution_scores), len(self.dataset)) 96 | 97 | def select_and_save(self): 98 | # select the solution with the highest score every num_return_sequences solutions 99 | selected_solutions = [] 100 | for idx in range(0, len(self.solutions), self.config["generator"]["num_return_sequences"]): 101 | if idx + self.config["generator"]["num_return_sequences"] < len(self.solutions): 102 | max_idx = np.argmax(self.solution_scores[idx:idx+self.config["generator"]["num_return_sequences"]]) 103 | selected_solutions.append(self.solutions[idx+int(max_idx)]) 104 | else: 105 | max_idx = np.argmax(self.solution_scores[idx:]) 106 | selected_solutions.append(self.solutions[idx+int(max_idx)]) 107 | # for i in range(self.config["generator"]["num_return_sequences"]): 108 | # print(self.solutions[idx+i], ' ---> ', self.solution_scores[idx+i]) 109 | if 
not os.path.exists(self.config["generator"]["output_dir"]): 110 | os.mkdir(self.config["generator"]["output_dir"]) 111 | if self.config["generator"]["split"] == 'test': 112 | solution_file = f'{self.config["generator"]["output_dir"]}/{self.config["generator"]["data_frac"]}_selected_test.jsonl' 113 | else: 114 | solution_file = f'{self.config["generator"]["output_dir"]}/{self.config["generator"]["data_frac"]}_selected.jsonl' 115 | 116 | for sol in selected_solutions: 117 | with open(solution_file, 'a') as f: 118 | json.dump(sol, f) 119 | f.write('\n') 120 | 121 | def guide_generation(self): 122 | # check if the 'input_dir' file exists 123 | if not os.path.exists(self.config["generator"]["input_dir"]): 124 | self.generate_solutions() 125 | else: 126 | self.read_solutions() 127 | self.process_solutions() 128 | self.get_reward_score() 129 | self.select_and_save() -------------------------------------------------------------------------------- /inference/generate_openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from openai import OpenAI, AzureOpenAI 3 | import numpy as np 4 | from utils.credentials import api_key_list 5 | from tenacity import wait_random_exponential, stop_after_attempt, retry, RetryError 6 | from data.dataset_loader import get_datasets 7 | import json 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | 11 | class generate_openai(): 12 | def __init__(self, config=None): 13 | self.api_key_list = api_key_list(config["generator"]["openai_credentials"]) 14 | self.api_idx = 0 15 | self.client = AzureOpenAI( 16 | api_key=self.api_key_list[self.api_idx]["api_key"], 17 | api_version=self.api_key_list[self.api_idx]["api_version"], 18 | azure_endpoint=self.api_key_list[self.api_idx]["azure_endpoint"] 19 | ) 20 | self.model = self.api_key_list[self.api_idx]["model"] 21 | self.config = config 22 | self.token_usage = {"input": 0, "output": 0} 23 | 24 | def switch_api_key(self): 25 | self.api_idx = (self.api_idx + 1) % len(self.api_key_list) 26 | self.client = AzureOpenAI( 27 | api_key=self.api_key_list[self.api_idx]["api_key"], 28 | api_version=self.api_key_list[self.api_idx]["api_version"], 29 | azure_endpoint=self.api_key_list[self.api_idx]["azure_endpoint"] 30 | ) 31 | self.model = self.api_key_list[self.api_idx]["model"] 32 | 33 | def query(self, prompt, temp=None, n=None, stop=None, max_tokens=None,): 34 | prompt_chat = [ 35 | {"role": "user", "content": prompt} 36 | ] 37 | flag = False 38 | num_trials = 0 39 | while not flag: 40 | try: 41 | raw_response = self.client.chat.completions.create( 42 | model=self.model, 43 | messages=prompt_chat, 44 | max_tokens=self.config['generator']['max_length'] if max_tokens is None else max_tokens, 45 | temperature=self.config['generator']['temperature'] if temp is None else temp, 46 | frequency_penalty=self.config['generator']['frequency_penalty'], 47 | presence_penalty=self.config['generator']['presence_penalty'], 48 | stop=self.config['generator']['stop'] if stop is None else stop, 49 | n=self.config['generator']['num_return_sequences'] if n is None else n, 50 | ) 51 | self.token_usage["input"] += raw_response.usage.prompt_tokens 52 | self.token_usage["output"] += raw_response.usage.completion_tokens 53 | 54 | contents = [choice.message.content.strip() for choice in raw_response.choices] 55 | flag = True 56 | if len(contents) == 0: 57 | flag = False 58 | raise RuntimeError("No response from the API") 59 | except: 60 | self.switch_api_key() 61 | flag = False 62 | num_trials += 1 63 
| if num_trials > 3: 64 | flag = True 65 | contents = None 66 | return contents 67 | 68 | def generate(self): 69 | train_data, val_data, test_data, prompt = get_datasets(self.config['seed'], self.config['generator']['input_dir']) 70 | if self.config['generator']["split"] == 'test': 71 | data = test_data 72 | elif self.config['generator']["split"] == 'val': 73 | data = val_data 74 | else: 75 | data = train_data 76 | generation = [] 77 | template = "{prompt}\nQ: {question}\nA:" 78 | for idx in tqdm(range(len(data))): 79 | prompt_msg = template.format(prompt=prompt, question=data[idx]['question'], answer=data[idx]['answer']) 80 | responses = self.query(prompt_msg) 81 | if responses != None: 82 | for res in responses: 83 | generation.append({"question": data[idx]['question'], "answer": data[idx]['answer'], "generation": res}) 84 | 85 | output_dir = Path(self.config['generator']["output_dir"]) 86 | output_dir.mkdir(parents=True, exist_ok=True) 87 | for gen in generation: 88 | if self.config['generator']["split"] == 'test': 89 | file_name = f"{self.config['generator']['output_dir']}/{self.config['generator']['data_frac']}_test.jsonl" 90 | else: 91 | file_name = f"{self.config['generator']['output_dir']}/{self.config['generator']['data_frac']}.jsonl" 92 | with open(file_name, 'a') as f: 93 | json.dump(gen, f) 94 | f.write('\n') 95 | 96 | class generate_vllm_openai(): 97 | def __init__(self, config=None): 98 | self.api_key_list = api_key_list(config["generator"]["openai_credentials"]) 99 | self.api_idx = 0 100 | self.client = OpenAI( 101 | api_key=self.api_key_list[self.api_idx]["api_key"], 102 | base_url=self.api_key_list[self.api_idx]["base_url"] 103 | ) 104 | self.model = self.api_key_list[self.api_idx]["model"] 105 | self.config = config 106 | self.token_usage = {"input": 0, "output": 0} 107 | 108 | def switch_api_key(self): 109 | self.api_idx = (self.api_idx + 1) % len(self.api_key_list) 110 | self.client = OpenAI( 111 | api_key=self.api_key_list[self.api_idx]["api_key"], 112 | base_url=self.api_key_list[self.api_idx]["base_url"] 113 | ) 114 | self.model = self.api_key_list[self.api_idx]["model"] 115 | 116 | def query(self, prompt, temp=None, n=None, stop=None, max_tokens=None,): 117 | prompt_chat = [ 118 | {"role": "user", "content": prompt} 119 | ] 120 | flag = False 121 | num_trials = 0 122 | while not flag: 123 | try: 124 | raw_response = self.client.chat.completions.create( 125 | model=self.model, 126 | messages=prompt_chat, 127 | max_tokens=self.config['generator']['max_length'] if max_tokens is None else max_tokens, 128 | temperature=self.config['generator']['temperature'] if temp is None else temp, 129 | frequency_penalty=self.config['generator']['frequency_penalty'], 130 | presence_penalty=self.config['generator']['presence_penalty'], 131 | stop=self.config['generator']['stop'] if stop is None else stop, 132 | n=self.config['generator']['num_return_sequences'] if n is None else n, 133 | ) 134 | self.token_usage["input"] += raw_response.usage.prompt_tokens 135 | self.token_usage["output"] += raw_response.usage.completion_tokens 136 | 137 | contents = [choice.message.content.strip() for choice in raw_response.choices] 138 | flag = True 139 | if len(contents) == 0: 140 | flag = False 141 | raise RuntimeError("No response from the API") 142 | except: 143 | self.switch_api_key() 144 | flag = False 145 | num_trials += 1 146 | if num_trials > 3: 147 | flag = True 148 | contents = None 149 | return contents 150 | 151 | def generate(self): 152 | train_data, val_data, test_data, prompt = 
get_datasets(self.config['seed'], self.config['generator']['input_dir']) 153 | if self.config['generator']["split"] == 'test': 154 | data = test_data 155 | elif self.config['generator']["split"] == 'val': 156 | data = val_data 157 | else: 158 | data = train_data 159 | generation = [] 160 | template = "{prompt}\nQ: {question}\nA:" 161 | for idx in tqdm(range(len(data))): 162 | prompt_msg = template.format(prompt=prompt, question=data[idx]['question'], answer=data[idx]['answer']) 163 | responses = self.query(prompt_msg) 164 | if responses != None: 165 | for res in responses: 166 | generation.append({"question": data[idx]['question'], "answer": data[idx]['answer'], "generation": res}) 167 | 168 | output_dir = Path(self.config['generator']["output_dir"]) 169 | output_dir.mkdir(parents=True, exist_ok=True) 170 | for gen in generation: 171 | if self.config['generator']["split"] == 'test': 172 | file_name = f"{self.config['generator']['output_dir']}/{self.config['generator']['data_frac']}_test.jsonl" 173 | else: 174 | file_name = f"{self.config['generator']['output_dir']}/{self.config['generator']['data_frac']}.jsonl" 175 | with open(file_name, 'a') as f: 176 | json.dump(gen, f) 177 | f.write('\n') -------------------------------------------------------------------------------- /data/prompt_loader.py: -------------------------------------------------------------------------------- 1 | GSM8K_PROMPT = """Q: Ivan has a bird feeder in his yard that holds two cups of birdseed. Every week, he has to refill the emptied feeder. Each cup of birdseed can feed fourteen birds, but Ivan is constantly chasing away a hungry squirrel that steals half a cup of birdseed from the feeder every week. How many birds does Ivan’s bird feeder feed weekly? 2 | A: Let's think step by step. 3 | The squirrel steals 1/2 cup of birdseed every week, so the birds eat 2 - 1/2 = 1 1/2 cups of birdseed. 4 | Each cup feeds 14 birds, so Ivan's bird feeder feeds 14 * 1 1/2 = 21 birds weekly. 5 | #### The answer is 21 6 | 7 | Q: Samuel took 30 minutes to finish his homework while Sarah took 1.3 hours to finish it. How many minutes faster did Samuel finish his homework than Sarah? 8 | A: Let's think step by step. 9 | Since there are 60 minutes in 1 hour, then 1.3 hours is equal to 1.3 x 60 = 78 minutes. 10 | Thus, Samuel is 78 - 30 = 48 minutes faster than Sarah. 11 | #### The answer is 48 12 | 13 | Q: Julia bought 3 packs of red balls, 10 packs of yellow balls, and 8 packs of green balls. There were 19 balls in each package. How many balls did Julie buy in all? 14 | A: Let's think step by step. 15 | The total number of packages is 3 + 10 + 8 = 21. 16 | Julia bought 21 × 19 = 399 balls. 17 | #### The answer is 399 18 | 19 | Q: Lexi wants to run a total of three and one-fourth miles. One lap on a particular outdoor track measures a quarter of a mile around. How many complete laps must she run? 20 | A: Let's think step by step. 21 | There are 3/ 1/4 = 12 one-fourth miles in 3 miles. 22 | So, Lexi will have to run 12 (from 3 miles) + 1 (from 1/4 mile) = 13 complete laps. 23 | #### The answer is 13 24 | """ 25 | 26 | PUBMEDQA_PROMPT = ''' 27 | Use the step-by-step method as shown in the example to answer the question. You should give the reasoning steps and final answer based on the provided context. 28 | 29 | Example: 30 | Q: Do familiar teammates request and accept more backup? 31 | A: Transactive memory theory extends to high-stress environments in which members' expertise is highly overlapping. 
32 | Teammates' shared mental models about one another increase the likelihood that they will request and accept backup. 33 | #### Yes. 34 | 35 | Here is your question. Please respond to this question based on the context and by adhering to the given format: provide step-by-step reasoning (one sentence per line), then give the final answer (Yes/No/Maybe) after '####'. 36 | '''.strip() 37 | 38 | MEDMCQA_PROMPT = ''' 39 | Use the step-by-step method as shown in the example to answer the question. You should give the explanation steps and final answer based on the provided context. 40 | 41 | Example: 42 | Q: What is the most probable poal of entry of Aspergillus? (A) Puncture wound, (B) Blood, (C) Lungs, (D) Gastrointestinal tract 43 | A: Aspergillus species are widely distributed on decaying plants, producing chains of conidia. 44 | Aspergillus species unlike Candida species do not form the pa of normal flora of humans. 45 | They are ubiquitous in the environment; hence transmission of infection is mostly exogenous. 46 | Aspergillus transmission occurs by inhalation of airborne conidia. 47 | Risk Factors for invasive aspergillosis are: Glucocoicoid use (the most impoant risk factor) Profound neutropenia or Neutrophil dysfunction Underlying pneumonia or COPD, tuberculosis or sarcoidosis Antitumor necrosis factor therapy. 48 | #### C. 49 | 50 | Here is your question. Please respond to this question based on the context and by adhering to the given format: provide step-by-step reasoning (one sentence per line), then give the final answer (A/B/C/D) after '####'. 51 | '''.strip() 52 | 53 | MMLU_PROMPT = ''' 54 | Use the step-by-step method as shown in the example to answer the question. You should give the reasoning steps and final answer based on the provided context. 55 | 56 | Example: 57 | Q: What size of cannula would you use in a patient who needed a rapid blood transfusion (as of 2020 medical knowledge)? (A) 18 gauge, (B) 20 gauge, (C) 22 gauge, (D) 24 gauge. 58 | A: The gauge of a cannula indicates its diameter: the smaller the number, the larger the diameter of the cannula. 59 | A larger diameter cannula allows for the rapid administration of fluids, including blood. 60 | In emergency situations requiring rapid transfusion, a larger cannula is preferred to ensure quick delivery of blood to the patient. 61 | An 18 gauge cannula is larger than the 20, 22, and 24 gauge options and is commonly used for rapid transfusions. 62 | #### A. 63 | 64 | Here is your question. Please respond to this question based on the context and by adhering to the given format: provide step-by-step reasoning (one sentence per line), then give the final answer (A/B/C/D) after '####'. 65 | '''.strip() 66 | 67 | MEDQA_PROMPT = ''' 68 | Use the step-by-step method as shown in the example to answer the question. You should give the reasoning steps and final answer based on the provided context. 69 | 70 | Example: 71 | Q: A 21-year-old sexually active male complains of fever, pain during urination, and inflammation and pain in the right knee. A culture of the joint fluid shows a bacteria that does not ferment maltose and has no polysaccharide capsule. The physician orders antibiotic therapy for the patient. The mechanism of action of action of the medication given blocks cell wall synthesis, which of the following was given? (A) Gentamicin, (B) Ciprofloxacin, (C) Ceftriaxone, (D) Trimethoprim. 
72 | A: The symptoms and culture results suggest a bacterial infection that affects both the urinary tract and joints, indicating a systemic infection. 73 | Bacteria that do not ferment maltose and lack a polysaccharide capsule could indicate a variety of bacteria, but the treatment approach focuses on the mechanism of action of the antibiotic rather than the specific bacteria. 74 | Antibiotics that block cell wall synthesis are typically beta-lactams, which include penicillins and cephalosporins. 75 | Gentamicin is an aminoglycoside antibiotic, which works by inhibiting protein synthesis. 76 | Ciprofloxacin is a fluoroquinolone, which works by inhibiting bacterial DNA gyrase and topoisomerase IV, affecting DNA replication. 77 | Ceftriaxone is a third-generation cephalosporin, which works by inhibiting cell wall synthesis. 78 | Trimethoprim is an antibiotic that inhibits bacterial dihydrofolate reductase, affecting folic acid synthesis. 79 | #### C. 80 | 81 | Here is your question. Please respond to this question based on the context and by adhering to the given format: provide step-by-step reasoning (one sentence per line), then give the final answer (A/B/C/D) after '####'. 82 | '''.strip() 83 | 84 | BIOASQ_PROMPT = ''' 85 | Use the step-by-step method as shown in the example to answer the question. You should give the reasoning steps and final answer based on the provided context. 86 | 87 | Example: 88 | Q: Can losartan reduce brain atrophy in Alzheimer's disease? 89 | A: Losartan is primarily used for hypertension and may indirectly affect factors associated with Alzheimer's disease progression. 90 | Despite potential neuroprotective effects, such as reducing inflammation and oxidative stress, there is limited direct evidence linking losartan to reduced brain atrophy in Alzheimer's disease. 91 | Clinical trials specifically targeting this outcome are necessary to establish a definitive effect. 92 | #### no 93 | 94 | Here is your question. Please respond to this question based on the context and by adhering to the given format: provide step-by-step reasoning (one sentence per line), then give the final answer (yes/no) after '####'. 95 | '''.strip() 96 | 97 | MEDNLI_PROMPT = ''' 98 | What is the relationship between the given two sentences? Please answer from [entailment, neutral, contradiction]. Please give the answer after '####'. 99 | 100 | Example: 101 | Sentence A: Labs were notable for Cr 1.7 (baseline 0.5 per old records) and lactate 2.4. 102 | Sentence B: Patient has elevated Cr 103 | Answer: #### entailment 104 | 105 | Here are the given two sentences. Please then give the final answer (entailment/neutral/contradiction) after '####'. 106 | ''' 107 | 108 | MEDIQA_RQE_PROMPT = ''' 109 | Does the provided solution correctly answer thq question? Please answer from [true, false]. 110 | 111 | Example: 112 | Question: What is High Blood Pressure? 113 | Solution: High Blood Pressure. I know you may not answer this but my blood pressure comes up at night when I am asleep. I take four medicines. I have asked doctors why this happens and no one knows. This morning at four A.M. It was 164 and I took a clonidine to help get it done. It worries me so. 114 | Judge: #### false 115 | 116 | Here is the question and answer. Please then give the final judge (true/false) after '####'. 117 | ''' 118 | 119 | PUBHEALTH_PROMPT = ''' 120 | Use the step-by-step method as shown in the example to answer the question. You should give the thought steps and final answer based on the provided context. 
Please judge whether the claim is true or false. 121 | 122 | Example: 123 | Claim: Annual Mammograms May Have More False-Positives October 18, 2011 124 | Judge: This article reports on the results of a study of nearly 170,000 women who had screening mammograms beginning between age 40-59. The study found that over ten years of screening mammograms, over half of the women will experience a false-positive recall for additional mammography. In addition, 7%-9% of the women will have a biopsy for a suspicious lump which is not cancerous. Both of those percentages decrease if the woman is screened every other year rather than every year. Even with biennial mammography, 41% of women will experience a recall over 10 years of mammography. The study’s Principal Investigator emphasized that “in most cases, a recall doesn’t mean you have cancer.”  She hoped this knowledge would reduce the anxiety of women who are recalled. The story never explained the size of the decrease in the number of false positives between annual (61.3%) and biennial screening (41.6%). Our first two reviewers were a researcher who specializes in health decisions and a breast cancer survivor trained in evidence by the Natiional Breast Cancer Coalition’s Project LEAD. This study is valuable because it helps to quantify and compare the harms of annual and biennial screening, specifically the number of false positives and the number of unnecessary biopsies. Prior to this study, estimates of false positive screening mammography rates varied widely. The critical question is whether you can do less frequent screening, subject women to fewer harms and get similar results in terms of detection of “early stage” cancer. This study’s data seems to suggest that answer is yes. 125 | #### mixture 126 | 127 | Here is the claim. Please then give the final judge (true/false/mixture/unproven) after '####'. 128 | ''' 129 | CORD19_PROMPT = '''Use one sentence to summarize the given paragraph. 130 | 131 | Example: 132 | Paragraph: Cardiovascular disease is the leading cause of death globally. While pharmacological advancements have improved the morbidity and mortality associated with cardiovascular disease, non-adherence to prescribed treatment remains a significant barrier to improved patient outcomes. A variety of strategies to improve medication adherence have been tested in clinical trials, and include the following categories: improving patient education, implementing medication reminders, testing cognitive behavioral interventions, reducing medication costs, utilizing healthcare team members, and streamlining medication dosing regimens. In this review, we describe specific trials within each of these categories and highlight the impact of each on medication adherence. We also examine ongoing trials and future lines of inquiry for improving medication adherence in patients with cardiovascular diseases. 133 | Summary: Medication adherence in cardiovascular medicine. 134 | 135 | Here is the paragraph. Please then give the final one-sentence summary after 'Summary:'. 
136 | ''' 137 | 138 | def load_prompts(dataset_name): 139 | if 'gsm8k' in dataset_name.lower(): 140 | return None #GSM8K_PROMPT 141 | elif 'pubmedqa' in dataset_name.lower(): 142 | return PUBMEDQA_PROMPT 143 | elif 'medmcqa' in dataset_name.lower(): 144 | return MEDMCQA_PROMPT 145 | elif 'mmlu' in dataset_name.lower(): 146 | return MMLU_PROMPT 147 | elif 'medqa' in dataset_name.lower(): 148 | return MEDQA_PROMPT 149 | elif 'bioasq' in dataset_name.lower(): 150 | return BIOASQ_PROMPT 151 | elif 'mednli' in dataset_name.lower(): 152 | return MEDNLI_PROMPT 153 | elif 'mediqa-rqe' in dataset_name.lower(): 154 | return MEDIQA_RQE_PROMPT 155 | elif 'pubhealth' in dataset_name.lower(): 156 | return PUBHEALTH_PROMPT 157 | else: 158 | return None 159 | -------------------------------------------------------------------------------- /generator/trainer.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from collections import defaultdict 4 | from copy import deepcopy 5 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from accelerate.utils import is_deepspeed_available 12 | from datasets import Dataset 13 | from torch.utils.data import DataLoader 14 | 15 | from transformers import ( 16 | AutoModelForCausalLM, 17 | DataCollator, 18 | PreTrainedModel, 19 | PreTrainedTokenizerBase, 20 | Trainer, 21 | TrainingArguments, 22 | ) 23 | 24 | from transformers.trainer_callback import TrainerCallback 25 | from transformers.trainer_utils import EvalLoopOutput 26 | 27 | from trl.import_utils import is_peft_available, is_wandb_available 28 | from trl.models import PreTrainedModelWrapper, create_reference_model 29 | from trl.trainer.utils import disable_dropout_in_model, pad_to_length 30 | 31 | if is_peft_available(): 32 | from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training 33 | 34 | if is_wandb_available(): 35 | import wandb 36 | 37 | if is_deepspeed_available(): 38 | import deepspeed 39 | 40 | class GENTrainer(Trainer): 41 | def __init__( 42 | self, 43 | model: Union[PreTrainedModel, nn.Module, str] = None, 44 | ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, 45 | beta: float = 0.1, 46 | loss_type: Literal["sigmoid", "hinge"] = "sigmoid", 47 | args: TrainingArguments = None, 48 | data_collator: Optional[DataCollator] = None, 49 | label_pad_token_id: int = -100, 50 | padding_value: int = 0, 51 | truncation_mode: str = "keep_end", 52 | train_dataset: Optional[Dataset] = None, 53 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 54 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 55 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 56 | callbacks: Optional[List[TrainerCallback]] = None, 57 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 58 | None, 59 | None, 60 | ), 61 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 62 | max_length: Optional[int] = None, 63 | max_prompt_length: Optional[int] = None, 64 | max_target_length: Optional[int] = None, 65 | peft_config: Optional[Dict] = None, 66 | is_encoder_decoder: Optional[bool] = None, 67 | disable_dropout: bool = True, 68 | generate_during_eval: bool = False, 69 | compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, 70 | model_init_kwargs: Optional[Dict] = None, 71 | ref_model_init_kwargs: 
Optional[Dict] = None, 72 | ): 73 | if model_init_kwargs is None: 74 | model_init_kwargs = {} 75 | elif not isinstance(model, str): 76 | raise ValueError("You passed model_kwargs to the trainer. But your model is already instantiated.") 77 | 78 | if ref_model_init_kwargs is None: 79 | ref_model_init_kwargs = {} 80 | elif not isinstance(ref_model, str): 81 | raise ValueError("") 82 | 83 | if isinstance(model, str): 84 | warnings.warn( 85 | "You passed a model_id to the trainer. This will automatically create models for you." 86 | ) 87 | model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) 88 | 89 | if isinstance(ref_model, str): 90 | warnings.warn( 91 | "You passed a ref model_id to the trainer. This will automatically create a ref model for you." 92 | ) 93 | 94 | if not is_peft_available() and peft_config is not None: 95 | raise ValueError("PEFT is not installed and you passed a 'peft_config' to the trainer, please install it to use the PEFT models.") 96 | elif is_peft_available() and peft_config is not None: 97 | if isinstance(model, PeftModel): 98 | model = model.merge_and_unload() 99 | 100 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): 101 | _support_gc_kwargs = hasattr( 102 | args, "gradient_checkpointing_kwargs" 103 | ) and "gradient_checkpointing_kwargs" in list( 104 | inspect.signature(prepare_model_for_kbit_training).parameters 105 | ) 106 | prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} 107 | if _support_gc_kwargs: 108 | prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs 109 | model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) 110 | elif getattr(model, "gradient_checkpointing", False): 111 | if hasattr(model, "enable_input_require_grads"): 112 | model.enable_input_require_grads() 113 | else: 114 | def make_inputs_require_grad(module, input, output): 115 | output.requires_grad_(True) 116 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 117 | model = get_peft_model(model, peft_config) 118 | elif getattr(args, "gradient_checkpointing", False): 119 | if hasattr(model, "enable_input_require_grads"): 120 | model.enable_input_require_grads() 121 | else: 122 | def make_inputs_require_grad(module, input, output): 123 | output.requires_grad_(True) 124 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 125 | if generate_during_eval and not is_wandb_available(): 126 | raise ValueError("You passed 'generate_during_eval=True' to the trainer, please install wandb to use the generation during eval.") 127 | 128 | if model is not None: 129 | self.is_encoder_decoder = model.config.is_encoder_decoder 130 | elif is_encoder_decoder is None: 131 | raise ValueError("Please pass is_encoder_decoder to the trainer.") 132 | else: 133 | self.is_encoder_decoder = is_encoder_decoder 134 | 135 | self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) 136 | 137 | if ref_model: 138 | self.ref_model = ref_model 139 | elif self.is_peft_model: 140 | self.ref_model = None 141 | else: 142 | self.ref_model = create_reference_model(model) 143 | 144 | if data_collator is None: 145 | if tokenizer is None: 146 | raise ValueError("max_length or a tokenizer must be specified when using the default.") 147 | if max_length is None: 148 | warnings.warn( 149 | "You should set 'max_length' in the trainer", 150 | UserWarning, 151 | ) 152 | max_length = 512 153 | if max_prompt_length is None: 154 | 
warnings.warn( 155 | "When using the default data_collator, you should set 'max_prompt_length' in the trainer", 156 | UserWarning 157 | ) 158 | max_prompt_length = 128 159 | if max_target_length is None and self.is_encoder_decoder: 160 | warnings.warn( 161 | "You should set 'max_target_length' in the trainer when using the default data_collator", 162 | UserWarning, 163 | ) 164 | max_target_length = 128 165 | 166 | data_collator = DataCollatorWithPadding( 167 | tokenizer, 168 | max_length=max_length, 169 | max_prompt_length=max_prompt_length, 170 | label_pad_token_id=label_pad_token_id, 171 | padding_value=padding_value, 172 | truncation_mode=truncation_mode, 173 | is_encoder_decoder=self.is_encoder_decoder, 174 | max_target_length=max_target_length, 175 | ) 176 | 177 | if args.remove_unused_columns: 178 | args.remove_unused_columns = False 179 | warnings.warn( 180 | "You should set 'remove_unused_columns' to False when using the default data_collator", 181 | UserWarning 182 | ) 183 | self.use_data_collator = True 184 | else: 185 | self.use_data_collator = False 186 | 187 | if disable_dropout: 188 | disable_dropout_in_model(model) 189 | if self.ref_model is not None: 190 | disable_dropout_in_model(self.ref_model) 191 | 192 | self.max_length = max_length 193 | self.generate_during_eval = generate_during_eval 194 | self.label_pad_token_id = label_pad_token_id 195 | self.padding_value = padding_value 196 | 197 | self.beta = beta 198 | self.loss_type = loss_type 199 | 200 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 201 | 202 | super().__init__( 203 | model=model, 204 | args=args, 205 | data_collator=data_collator, 206 | train_dataset=train_dataset, 207 | eval_dataset=eval_dataset, 208 | tokenizer=tokenizer, 209 | model_init=model_init, 210 | compute_metrics=compute_metrics, 211 | callbacks=callbacks, 212 | optimizers=optimizers, 213 | preprocess_logits_for_metrics=preprocess_logits_for_metrics 214 | ) 215 | 216 | if not hasattr(self, "accelerator"): 217 | raise AttributeError( 218 | "Your traininger does not have an accelerate object. Consider upgrading 'transformers'." 219 | ) 220 | 221 | if self.ref_model is None: 222 | if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"): 223 | raise ValueError( 224 | "You are using a 'peft' version that does not support the 'disable_adapter'. Please update your 'peft' version to the latest version." 
225 | ) 226 | else: 227 | if self.is_deepspeed_enabled: 228 | self.ref_model = self._prepare_deepspeed(self.ref_model) 229 | else: 230 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 231 | 232 | def _prepare_deepspeed(self, model: PreTrainedModelWrapper): 233 | deepspeed_plugin = self.accelerator.state.deepspeed_plugin 234 | config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) 235 | if model is not None: 236 | if hasattr(model, "config"): 237 | hidden_size = ( 238 | max(model.config.hidden_sizes) 239 | if getattr(model.config, "hidden_sizes", None) 240 | else getattr(model.config, "hidden_size", None) 241 | ) 242 | if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: 243 | config_kwargs.update( 244 | { 245 | "zero_otimization.reduce_bucket_size": hidden_size, 246 | "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, 247 | "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, 248 | } 249 | ) 250 | if config_kwargs["zero_optimization"]["stage"] != 3: 251 | config_kwargs["zero_optimization"]["stage"] = 0 252 | model, *_ = deepspeed.initialize(model=model, config=config_kwargs) 253 | model.eval() 254 | return model 255 | 256 | def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: 257 | concatenated_batch = {} 258 | if self.is_encoder_decoder: 259 | max_length = max(batch["real_labels"].shape[1], batch["generated_labels"].shape[1]) 260 | else: 261 | max_legnth = max(batch["real_input_ids"].shape[1], batch["generated_input_ids"].shape[1]) 262 | 263 | for k in batch: 264 | if k.startswith("real") and isinstance(batch[k], torch.Tensor): 265 | pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value 266 | concatenated_key = k.replace("generated", "concatenated") 267 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) 268 | -------------------------------------------------------------------------------- /data/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import load_dataset 3 | import random 4 | import numpy as np 5 | import torch 6 | import openai 7 | import json 8 | from data.prompt_loader import load_prompts 9 | 10 | def get_datasets(seed, dataset_name): 11 | prompt = None 12 | if dataset_name == 'pubmedqa': 13 | dataset = load_dataset('pubmed_qa', 'pqa_labeled', split='train') 14 | dataset = dataset.shuffle(seed=seed) 15 | train_dataset = [] 16 | for i in range(450): 17 | contexts = dataset[i]['context'] 18 | contexts_final = 'Contexts:\n' 19 | for j in range(len(contexts['contexts'])): 20 | contexts_final += contexts['labels'][j] + ': ' + contexts['contexts'][j] + '\n' 21 | question = contexts_final + dataset[i]['question'] 22 | answer = dataset[i]['long_answer']+'\n#### '+dataset[i]['final_decision'] 23 | train_dataset.append({'question': question, 'answer': answer}) 24 | val_dataset = [] 25 | for i in range(450, 500): 26 | contexts = dataset[i]['context'] 27 | contexts_final = 'Contexts:\n' 28 | for j in range(len(contexts['contexts'])): 29 | contexts_final += contexts['labels'][j] + ': ' + contexts['contexts'][j] + '\n' 30 | question = contexts_final + dataset[i]['question'] 31 | answer = dataset[i]['long_answer']+'\n#### '+dataset[i]['final_decision'] 32 | val_dataset.append({'question': question, 'answer': answer}) 33 | test_dataset = 
[] 34 | for i in range(500, 1000): 35 | contexts = dataset[i]['context'] 36 | contexts_final = 'Contexts:\n' 37 | for j in range(len(contexts['contexts'])): 38 | contexts_final += contexts['labels'][j] + ': ' + contexts['contexts'][j] + '\n' 39 | question = contexts_final + dataset[i]['question'] 40 | answer = dataset[i]['long_answer']+'\n#### '+dataset[i]['final_decision'] 41 | test_dataset.append({'question': question, 'answer': answer}) 42 | print(f"Train: {len(train_dataset)} Val: {len(val_dataset)} Test: {len(test_dataset)}") 43 | prompt = load_prompts(dataset_name) 44 | elif dataset_name == 'medmcqa': 45 | temp_dataset = load_dataset('medmcqa', split='train') 46 | temp_dataset = temp_dataset.shuffle(seed=seed) 47 | train_dataset = [] 48 | num_train = 0 49 | option_dict = {'0': 'A', '1': 'B', '2': 'C', '3': 'D'} 50 | for i in range(len(temp_dataset)): 51 | options = ['(A) '+temp_dataset[i]['opa'], '(B) '+temp_dataset[i]['opb'], '(C) '+temp_dataset[i]['opc'], '(D) '+temp_dataset[i]['opd']] 52 | question = temp_dataset[i]['question'] + ' ' + ', '.join(options) 53 | if temp_dataset[i]['exp'] != None: 54 | answer = str(temp_dataset[i]['exp']) + '\n#### '+ option_dict[str(temp_dataset[i]['cop'])] 55 | train_dataset.append({'question': question, 'answer': answer}) 56 | num_train += 1 57 | if num_train >= 2000: 58 | break 59 | temp_dataset = load_dataset('medmcqa', split='validation') 60 | temp_dataset = temp_dataset.shuffle(seed=seed) 61 | val_dataset = [] 62 | for i in range(len(temp_dataset)): 63 | options = ['(A) '+temp_dataset[i]['opa'], '(B) '+temp_dataset[i]['opb'], '(C) '+temp_dataset[i]['opc'], '(D) '+temp_dataset[i]['opd']] 64 | question = temp_dataset[i]['question'] + ' ' + ', '.join(options) 65 | answer = str(temp_dataset[i]['exp']) + '\n#### '+ option_dict[str(temp_dataset[i]['cop'])] 66 | val_dataset.append({'question': question, 'answer': answer}) 67 | temp_dataset = load_dataset('medmcqa', split='validation') 68 | temp_dataset = temp_dataset.shuffle(seed=seed) 69 | test_dataset = [] 70 | for i in range(len(temp_dataset)): 71 | options = ['(A) '+temp_dataset[i]['opa'], '(B) '+temp_dataset[i]['opb'], '(C) '+temp_dataset[i]['opc'], '(D) '+temp_dataset[i]['opd']] 72 | question = temp_dataset[i]['question'] + ' ' + ', '.join(options) 73 | answer = str(temp_dataset[i]['exp']) + '\n#### '+ option_dict[str(temp_dataset[i]['cop'])] 74 | test_dataset.append({'question': question, 'answer': answer}) 75 | print(f"Train: {len(train_dataset)} Val: {len(val_dataset)} Test: {len(test_dataset)}") 76 | prompt = load_prompts(dataset_name) 77 | elif dataset_name == 'mmlu': 78 | subset_names = ['clinical_knowledge', 'college_biology', 'college_medicine', 'high_school_biology', 'medical_genetics', 'professional_medicine', 'virology'] 79 | total_dataset = [] 80 | option_dict = {'0': 'A', '1': 'B', '2': 'C', '3': 'D'} 81 | for subset in subset_names: 82 | dataset = load_dataset('lukaemon/mmlu', subset) 83 | # merge dataset in total_dataset 84 | for key in dataset.keys(): 85 | for i in range(len(dataset[key])): 86 | total_dataset.append(dataset[key][i]) 87 | random.shuffle(total_dataset) 88 | train_dataset = [] 89 | for i in range(int(0.8 * len(total_dataset))): 90 | question = total_dataset[i]['input'] 91 | options = ['(A) '+total_dataset[i]['A'], '(B) '+total_dataset[i]['B'], '(C) '+total_dataset[i]['C'], '(D) '+total_dataset[i]['D']] 92 | question = question + ' ' + ', '.join(options) 93 | answer = total_dataset[i]['target'] 94 | train_dataset.append({'question': question, 'answer': answer}) 95 | 
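# NOTE (added comment): the shuffled MMLU pool is split 80/10/10 into the train/validation/test lists built above and below.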
        val_dataset = []
        for i in range(int(0.8 * len(total_dataset)), int(0.9 * len(total_dataset))):
            question = total_dataset[i]['input']
            options = ['(A) ' + total_dataset[i]['A'], '(B) ' + total_dataset[i]['B'], '(C) ' + total_dataset[i]['C'], '(D) ' + total_dataset[i]['D']]
            question = question + ' ' + ', '.join(options)
            answer = total_dataset[i]['target']
            val_dataset.append({'question': question, 'answer': answer})
        test_dataset = []
        for i in range(int(0.9 * len(total_dataset)), len(total_dataset)):
            question = total_dataset[i]['input']
            options = ['(A) ' + total_dataset[i]['A'], '(B) ' + total_dataset[i]['B'], '(C) ' + total_dataset[i]['C'], '(D) ' + total_dataset[i]['D']]
            question = question + ' ' + ', '.join(options)
            answer = total_dataset[i]['target']
            test_dataset.append({'question': question, 'answer': answer})
        print(f"Train: {len(train_dataset)} Val: {len(val_dataset)} Test: {len(test_dataset)}")
        prompt = load_prompts(dataset_name)
    elif dataset_name == 'medqa':
        dataset = load_dataset('bigbio/med_qa', 'med_qa_en_4options_bigbio_qa', split='train')
        dataset = dataset.shuffle(seed=seed)
        train_dataset = []
        option_dict = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
        for i in range(len(dataset)):
            question = dataset[i]['question']
            choices = dataset[i]['choices']
            options = ['(A) ' + choices[0], '(B) ' + choices[1], '(C) ' + choices[2], '(D) ' + choices[3]]
            question = question + ' ' + ', '.join(options)
            ground_truth = dataset[i]['answer'][0]
            # find answer in the choices and return the choice id
            for j in range(len(choices)):
                if choices[j] == ground_truth:
                    answer = option_dict[j]
            train_dataset.append({'question': question, 'answer': answer})
        dataset = load_dataset('bigbio/med_qa', 'med_qa_en_4options_bigbio_qa', split='validation')
        dataset = dataset.shuffle(seed=seed)
        val_dataset = []
        for i in range(len(dataset)):
            question = dataset[i]['question']
            choices = dataset[i]['choices']
            options = ['(A) ' + choices[0], '(B) ' + choices[1], '(C) ' + choices[2], '(D) ' + choices[3]]
            question = question + ' ' + ', '.join(options)
            ground_truth = dataset[i]['answer'][0]
            # find answer in the choices and return the choice id
            for j in range(len(choices)):
                if choices[j] == ground_truth:
                    answer = option_dict[j]
            val_dataset.append({'question': question, 'answer': answer})
        dataset = load_dataset('bigbio/med_qa', 'med_qa_en_4options_bigbio_qa', split='test')
        dataset = dataset.shuffle(seed=seed)
        test_dataset = []
        for i in range(len(dataset)):
            question = dataset[i]['question']
            choices = dataset[i]['choices']
            options = ['(A) ' + choices[0], '(B) ' + choices[1], '(C) ' + choices[2], '(D) ' + choices[3]]
            question = question + ' ' + ', '.join(options)
            ground_truth = dataset[i]['answer'][0]
            # find answer in the choices and return the choice id
            for j in range(len(choices)):
                if choices[j] == ground_truth:
                    answer = option_dict[j]
            test_dataset.append({'question': question, 'answer': answer})
        print(f"Train: {len(train_dataset)} Val: {len(val_dataset)} Test: {len(test_dataset)}")
        prompt = load_prompts(dataset_name)
    elif dataset_name == 'bioasq':
        dataset = [
            item
            for i in [11, 10, 9, 8, 7]
            for j in range(len([f for f in os.listdir("./data/bioasq/Task{:d}BGoldenEnriched".format(i)) if f.endswith(".json")]))
            for item in json.load(open("./data/bioasq/Task{:d}BGoldenEnriched/{:d}B{:d}_golden.json".format(i, i, j + 1)))["questions"]
            if item["type"] == "yesno"
        ]
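        # the comprehension above gathers every yes/no question from the locally downloaded BioASQ
        # Task{7-11}BGoldenEnriched golden files (named like 11B1_golden.json)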
        random.shuffle(dataset)
        train_dataset = []
        for i in range(int(0.8 * len(dataset))):
            question = dataset[i]['body']
            answer = dataset[i]['exact_answer']
            contexts = 'Context:\n'
            for j in range(len(dataset[i]['snippets'])):
                contexts += dataset[i]['snippets'][j]['text'] + '\n'
            train_dataset.append({'question': contexts + question, 'answer': answer})
        test_dataset = []
        for i in range(int(0.8 * len(dataset)), len(dataset)):
            question = dataset[i]['body']
            answer = dataset[i]['exact_answer']
            contexts = 'Context:\n'
            for j in range(len(dataset[i]['snippets'])):
                contexts += dataset[i]['snippets'][j]['text'] + '\n'
            test_dataset.append({'question': contexts + question, 'answer': answer})
        # no separate validation set is built for BioASQ
        val_dataset = []
        print(f"Train: {len(train_dataset)} Test: {len(test_dataset)}")
        prompt = load_prompts(dataset_name)
    elif dataset_name == 'gsm8k':
        train_dataset = load_dataset('gsm8k', 'main', split="train[0%:90%]").shuffle(seed=seed)
        val_dataset = load_dataset('gsm8k', 'main', split="train[90%:100%]").shuffle(seed=seed)
        test_dataset = load_dataset('gsm8k', 'main', split="test").shuffle(seed=seed)
        prompt = load_prompts('gsm8k')
    elif dataset_name == 'mednli':
        file_path = "./data/mednli/mednli_train.jsonl"
        train_dataset = []
        with open(file_path, 'r') as f:
            for line in f:
                data = json.loads(line)
                template = "Sentence A: {}\nSentence B: {}"
                question = template.format(data['sent_a'], data['sent_b'])
                train_dataset.append({'question': question, 'answer': data['label']})
        file_path = "./data/mednli/mednli_test.jsonl"
        # the last 10% of the training examples are reused as the validation set
        ratio = int(0.1 * len(train_dataset))
        val_dataset = train_dataset[-ratio:]
        test_dataset = []
        with open(file_path, 'r') as f:
            for line in f:
                data = json.loads(line)
                template = "Sentence A: {}\nSentence B: {}"
                question = template.format(data['sent_a'], data['sent_b'])
                test_dataset.append({'question': question, 'answer': data['label']})
        prompt = load_prompts('mednli')
    elif dataset_name == 'mediqa-rqe':
        data = load_dataset('bigbio/mediqa_rqe', 'mediqa_rqe_bigbio_pairs', split='train')
        train_dataset = []
        for i in range(len(data)):
            template = "Question: {}\nSolution: {}"
            question = template.format(data[i]['text_2'], data[i]['text_1'])
            if data[i]['label'] == True or data[i]['label'] == 'true':
                answer = 'true'
            else:
                answer = 'false'
            train_dataset.append({'question': question, 'answer': answer})
        data = load_dataset('bigbio/mediqa_rqe', 'mediqa_rqe_bigbio_pairs', split='train')
        val_dataset = []
        for i in range(len(data)):
            template = "Question: {}\nSolution: {}"
            question = template.format(data[i]['text_2'], data[i]['text_1'])
            if data[i]['label'] == True or data[i]['label'] == 'true':
                answer = 'true'
            else:
                answer = 'false'
            val_dataset.append({'question': question, 'answer': answer})
        # the same 'train' split is loaded once more and reused as the test set
        data = load_dataset('bigbio/mediqa_rqe', 'mediqa_rqe_bigbio_pairs', split='train')
        test_dataset = []
        for i in range(len(data)):
            template = "Question: {}\nSolution: {}"
            question = template.format(data[i]['text_2'], data[i]['text_1'])
            if data[i]['label'] == True or data[i]['label'] == 'true':
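                # labels may arrive as Python booleans or as the string 'true'; both are normalized to 'true'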
                answer = 'true'
            else:
                answer = 'false'
            test_dataset.append({'question': question, 'answer': answer})
        prompt = load_prompts('mediqa-rqe')
    elif dataset_name == 'pubhealth':
        file_path = "./data/pubhealth/train.tsv"
        # read from the tsv file
        train_dataset = []
        with open(file_path, 'r') as f:
            for line in f:
                data = line.split('\t')
                question = data[2]
                answer = data[-2]
                explanation = data[3]
                question_template = "{}"
                template = "{}\n#### {}"
                train_dataset.append({'question': question_template.format(question), 'answer': template.format(explanation, answer)})
        # drop the tsv header row
        train_dataset = train_dataset[1:]
        file_path = "./data/pubhealth/dev.tsv"
        # read from the tsv file
        val_dataset = []
        with open(file_path, 'r') as f:
            for line in f:
                data = line.split('\t')
                question = data[2]
                answer = data[-2]
                explanation = data[3]
                question_template = "{}"
                template = "{}\n#### {}"
                val_dataset.append({'question': question_template.format(question), 'answer': template.format(explanation, answer)})
        # drop the tsv header row
        val_dataset = val_dataset[1:]
        file_path = "./data/pubhealth/test.tsv"
        # read from the tsv file
        test_dataset = []
        with open(file_path, 'r') as f:
            for line in f:
                data = line.split('\t')
                question = data[2]
                answer = data[-2]
                explanation = data[3]
                question_template = "{}"
                template = "{}\n#### {}"
                test_dataset.append({'question': question_template.format(question), 'answer': template.format(explanation, answer)})
        # drop the tsv header row
        test_dataset = test_dataset[1:]
        prompt = load_prompts('pubhealth')
    elif dataset_name == 'cord19':
        data_list = load_dataset('mystic-leung/medical_cord19', split='test')
        data = []
        for line in data_list:
            data.append({'question': line['input'], 'answer': line['output']})
        random.shuffle(data)
        train_dataset = data[:3000]
        val_dataset = data[3000:3500]
        test_dataset = data[3500:4000]
        prompt = load_prompts('cord19')

    # cache the splits as jsonl files (append mode: existing raw_*.jsonl files are not overwritten)
    for i in range(len(train_dataset)):
        d = {'question': train_dataset[i]['question'], 'answer': train_dataset[i]['answer']}
        file_name = f"./data/{dataset_name}/raw_train.jsonl"
        with open(file_name, 'a') as f:
            json.dump(d, f)
            f.write('\n')

    for i in range(len(val_dataset)):
        d = {'question': val_dataset[i]['question'], 'answer': val_dataset[i]['answer']}
        file_name = f"./data/{dataset_name}/raw_val.jsonl"
        with open(file_name, 'a') as f:
            json.dump(d, f)
            f.write('\n')

    for i in range(len(test_dataset)):
        d = {'question': test_dataset[i]['question'], 'answer': test_dataset[i]['answer']}
        file_name = f"./data/{dataset_name}/raw_test.jsonl"
        with open(file_name, 'a') as f:
            json.dump(d, f)
            f.write('\n')

    return train_dataset, val_dataset, test_dataset, prompt

if __name__ == '__main__':
    get_datasets(42, "gsm8k")
--------------------------------------------------------------------------------
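Note: a minimal, hypothetical usage sketch of the loader above (not part of the repository). It assumes the file is data/dataset_loader.py, importable as data.dataset_loader from the project root, and that the ./data/gsm8k/ directory already exists; the variable names below are illustrative.

import json

from data.dataset_loader import get_datasets  # assumed module path

# Build the GSM8K splits; this also appends each example to
# ./data/gsm8k/raw_{train,val,test}.jsonl, so delete those files before re-running
# to avoid duplicated records.
train_dataset, val_dataset, test_dataset, prompt = get_datasets(42, 'gsm8k')

# Inspect the first cached training example.
with open('./data/gsm8k/raw_train.jsonl', 'r') as f:
    first = json.loads(f.readline())
print(first['question'])
print(first['answer'])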