├── .gitignore
├── Pre_Experiment
│   ├── customize_model.py
│   ├── data
│   │   ├── preprocess_casehold.py
│   │   ├── preprocess_finfact.py
│   │   └── preprocess_pubhealth.py
│   ├── model
│   │   └── .gitkeep
│   ├── pre_experiment.py
│   ├── pre_experiment.sh
│   └── result
│       └── .gitkeep
├── RL_KTO
│   ├── data
│   │   └── dataset_info.json
│   ├── ds_z3_offload_config.json
│   ├── process_kto.py
│   └── train_kto.sh
├── data
│   └── dataset_info.json
├── demo
│   ├── llama_pubmedqa_rare.sh
│   └── qwenvl_mmrait_rare.sh
├── eval
│   └── eval.py
├── image
│   ├── Case_Study.png
│   ├── Overview_RARE.png
│   ├── benchmark.png
│   └── logo.png
├── inference
│   ├── api_infer_post.py
│   ├── vllm_infer_mm.py
│   ├── vllm_infer_text.py
│   └── vllm_infer_text_reject_sampling.py
├── license
├── process
│   ├── process_casehold.py
│   ├── process_finfact.py
│   ├── process_medqa.py
│   ├── process_mmrait.py
│   ├── process_pubhealth.py
│   ├── process_pubmed.py
│   └── select_true.py
├── readme.md
├── requirements.txt
└── train
    ├── accelerate_config.yaml
    ├── fsdp_config_llama.json
    ├── fsdp_config_mistral.json
    ├── fsdp_config_qwen.json
    ├── merge_lora.yaml
    ├── sft.py
    ├── sft.sh
    ├── train.py
    ├── training_args.yaml
    ├── training_args_lora.yaml
    └── training_args_mm.yaml
/.gitignore: -------------------------------------------------------------------------------- 1 | ***__pycache__*** 2 | **/__pycache__ 3 | **/__pycache__/ 4 | .DS_Store -------------------------------------------------------------------------------- /Pre_Experiment/customize_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import argparse 3 | 4 | def customize_model(model_name, output_dir): 5 | # Load original model and tokenizer 6 | tokenizer = AutoTokenizer.from_pretrained(model_name) 7 | model = AutoModelForCausalLM.from_pretrained(model_name) 8 | 9 | special_tokens = ["[RETRIEVAL]", "[ENTITIES]", "[SEP]", "[REASONING]"] 10 | 11 | tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) 12 | model.resize_token_embeddings(len(tokenizer)) 13 | 14 | # Save customized version 15 | tokenizer.save_pretrained(f"{output_dir}/custom_tokenizer") 16 | model.save_pretrained(f"{output_dir}/custom_model") 17 | 18 | print(f"Customized model and tokenizer saved to {output_dir}") 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser(description='Customize LLM with special tokens') 22 | parser.add_argument('--model_name_or_path', type=str, required=True, 23 | help='Path to original model') 24 | parser.add_argument('--output_dir', type=str, required=True, 25 | help='Directory to save customized model') 26 | 27 | args = parser.parse_args() 28 | customize_model(args.model_name_or_path, args.output_dir) -------------------------------------------------------------------------------- /Pre_Experiment/data/preprocess_casehold.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | def convert_json(json_path, output_path): 5 | with open(json_path, "r", encoding="utf-8") as f: 6 | data = json.load(f) 7 | 8 | new_data = [] 9 | for item in data: 10 | new_item = { 11 | "id": item["example_id"], 12 | "x": "\nA." + item["holding_0"] + "\nB." + item["holding_1"] + "\nC." + item["holding_2"] + 13 | "\nD." + item["holding_3"] + "\nE."
+ item["holding_4"], 14 | "R_x": item["citing_prompt"], 15 | "r": item["predict"], 16 | "output": item["output"], 17 | } 18 | new_data.append(new_item) 19 | 20 | with open(output_path, "w", encoding="utf-8") as f: 21 | json.dump(new_data, f, indent=4, ensure_ascii=False) 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description='Convert CaseHold JSON format') 25 | parser.add_argument('--input', type=str, required=True, help='Path to input JSON file') 26 | parser.add_argument('--output', type=str, required=True, help='Path to output JSON file') 27 | 28 | args = parser.parse_args() 29 | convert_json(args.input, args.output) -------------------------------------------------------------------------------- /Pre_Experiment/data/preprocess_finfact.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | OPTION="""\nA. true\nB. false\nC. NEI""" 4 | def convert_json(json_path, output_path): 5 | with open(json_path, "r", encoding="utf-8") as f: 6 | data = json.load(f) 7 | 8 | new_data = [] 9 | for item in data: 10 | new_item = { 11 | "id": item["url"], 12 | "x": item["claim"]+OPTION , 13 | "R_x": item["documents"], 14 | "r": item["predict"], 15 | "output": item["output"], 16 | } 17 | new_data.append(new_item) 18 | 19 | with open(output_path, "w", encoding="utf-8") as f: 20 | json.dump(new_data, f, indent=4, ensure_ascii=False) 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser(description='Convert finfact JSON format') 24 | parser.add_argument('--input', type=str, required=True, help='Path to input JSON file') 25 | parser.add_argument('--output', type=str, required=True, help='Path to output JSON file') 26 | 27 | args = parser.parse_args() 28 | convert_json(args.input, args.output) -------------------------------------------------------------------------------- /Pre_Experiment/data/preprocess_pubhealth.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | OPTION="""\nA. true - The statement is entirely accurate and supported by solid evidence.\nB. false - The statement is completely untrue and contradicted by strong evidence.\nC. mixture - The statement is partially true but contains some inaccuracies or misleading elements.\nD. 
unproven - There is insufficient evidence to confirm or refute the statement.""" 5 | def convert_json(json_path, output_path): 6 | with open(json_path, "r", encoding="utf-8") as f: 7 | data = json.load(f) 8 | 9 | new_data = [] 10 | for item in data: 11 | new_item = { 12 | "id": item["id"], 13 | "x":item["text_1"]+OPTION, 14 | "R_x": item["text_2"], 15 | "r": item["predict"], 16 | "output": item["output"], 17 | } 18 | new_data.append(new_item) 19 | 20 | with open(output_path, "w", encoding="utf-8") as f: 21 | json.dump(new_data, f, indent=4, ensure_ascii=False) 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description='Convert pubhealth JSON format') 25 | parser.add_argument('--input', type=str, required=True, help='Path to input JSON file') 26 | parser.add_argument('--output', type=str, required=True, help='Path to output JSON file') 27 | 28 | args = parser.parse_args() 29 | convert_json(args.input, args.output) -------------------------------------------------------------------------------- /Pre_Experiment/model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-DataFlow/RARE/a1c61533e64ca8f9905281c41ad4ce199be1353d/Pre_Experiment/model/.gitkeep -------------------------------------------------------------------------------- /Pre_Experiment/pre_experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | from collections import Counter 5 | import spacy 6 | from vllm import LLM, SamplingParams 7 | 8 | # Constants 9 | SPECIAL_TOKENS = { 10 | "retrieval_start": "[RETRIEVAL]", 11 | "entity_start": "[ENTITIES]", 12 | "context_sep": "[SEP]", 13 | "reasoning": "[REASONING]" 14 | } 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='Run RARE experiment with configurable parameters') 18 | parser.add_argument('--model_path_or_name', type=str, default='Pre_Experiment/model/custom_model', help='Path to custom model') 19 | parser.add_argument('--tokenizer_path', type=str, default='Pre_Experiment/model/custom_tokenizer', help='Path to custom tokenizer') 20 | parser.add_argument('--dataset_path', type=str, required=True, help='Path to dataset JSON file') 21 | parser.add_argument('--extractor_path', type=str, default='Pre_Experiment/model/en_core_web_sm', help='Path to spaCy extractor model') 22 | parser.add_argument('--dataset_name', type=str, required=True, help='Name of dataset for output file') 23 | parser.add_argument('--retrieval_ratio', type=int, choices=[0,1,2,3,4], required=True, help='Ratio of retrieval content to use (0-4)') 24 | 25 | return parser.parse_args() 26 | 27 | class VLLMLossCalculator: 28 | def __init__(self, model_path, tokenizer_path): 29 | self.model = LLM( 30 | model=model_path, 31 | tokenizer=tokenizer_path, 32 | tensor_parallel_size=8, 33 | gpu_memory_utilization=0.7, 34 | trust_remote_code=True, 35 | max_model_len=8000, 36 | enable_chunked_prefill=True, 37 | max_num_seqs=1, 38 | enforce_eager=True 39 | ) 40 | self.tokenizer = self.model.get_tokenizer() 41 | self.tokenizer.pad_token = self.tokenizer.eos_token 42 | self.sampling_params = SamplingParams( 43 | temperature=0.6, 44 | top_p=1.0, 45 | max_tokens=1, 46 | prompt_logprobs=1, 47 | ) 48 | 49 | def compute_losses(self, samples, extractor, ratio): 50 | prompts = [self.build_augmented_input(s, ratio) for s in samples] 51 | return self._get_logprobs(prompts, [s['r'] for s in samples], [s['R_x'] 
for s in samples], extractor) 52 | 53 | def build_augmented_input(self, sample, ratio): 54 | if ratio == 0: 55 | return sample['x'] 56 | n = len(sample['R_x']) 57 | return f"{sample['x']}\n{sample['R_x'][:n * ratio // 4]}" 58 | 59 | def _get_logprobs(self, prompts, target_texts, R_xs, extractor): 60 | full_texts = [p + SPECIAL_TOKENS["reasoning"] + t for p, t in zip(prompts, target_texts)] 61 | outputs = self.model.generate(full_texts, self.sampling_params, use_tqdm=True) 62 | 63 | knowledge_losses, reasoning_losses = [], [] 64 | for i, output in enumerate(outputs): 65 | k = self.extract_entities(R_xs[i], extractor) 66 | token_list = list(set(tuple(self.tokenizer.convert_tokens_to_ids( 67 | self.tokenizer.tokenize(word))) for word in k)) 68 | token_list = [x for sublist in token_list for x in sublist] 69 | 70 | knowledge, reasoning = [], [] 71 | start_extract = False 72 | for entry in list(output.prompt_logprobs): 73 | if isinstance(entry, dict): 74 | first_token_id = next(iter(entry)) 75 | if first_token_id == 128259: 76 | start_extract = True 77 | if start_extract: 78 | (knowledge if first_token_id in token_list else reasoning).append( 79 | entry[first_token_id].logprob) 80 | 81 | knowledge_losses.append(-np.mean(knowledge) if knowledge else 0.0) 82 | reasoning_losses.append(-np.mean(reasoning) if reasoning else 0.0) 83 | 84 | return knowledge_losses, reasoning_losses 85 | 86 | def extract_entities(self, text, nlp): 87 | doc = nlp(text) 88 | ents = [ent.text for ent in doc.ents] 89 | return [ent for ent, _ in Counter(ents).most_common(100)] + [" "+ent for ent in ents] 90 | 91 | def run_experiment(args): 92 | nlp = spacy.load(args.extractor_path) 93 | calculator = VLLMLossCalculator(args.model_path_or_name, args.tokenizer_path) 94 | 95 | with open(args.dataset_path) as f: 96 | test_samples = json.load(f) 97 | 98 | knowledge_losses, reasoning_losses = calculator.compute_losses( 99 | test_samples, nlp, args.retrieval_ratio) 100 | 101 | output_data = { 102 | "samples": [{ 103 | "sample_id": i+1, 104 | "L_A": k_loss, 105 | "L_B": r_loss, 106 | "delta_L": r_loss-k_loss 107 | } for i, (k_loss, r_loss) in enumerate(zip(knowledge_losses, reasoning_losses))], 108 | "average_LA": round(np.mean(knowledge_losses), 4), 109 | "average_LB": round(np.mean(reasoning_losses), 4), 110 | "average_LB/LA": round(np.mean(reasoning_losses)/np.mean(knowledge_losses), 4) 111 | } 112 | 113 | output_path = f"Pre_Experiment/result/pre_experiment_{args.dataset_name}_{args.retrieval_ratio}_4.json" 114 | with open(output_path, "w", encoding="utf-8") as f: 115 | json.dump(output_data, f, ensure_ascii=False, indent=4) 116 | 117 | print(f"Results saved to {output_path}") 118 | 119 | if __name__ == "__main__": 120 | args = parse_args() 121 | run_experiment(args) -------------------------------------------------------------------------------- /Pre_Experiment/pre_experiment.sh: -------------------------------------------------------------------------------- 1 | # Run pre-experiment with parameters: 2 | # --retrieval_ratio: Controls how much retrieval content to use (ratio/4 of document): 3 | # 0 = None 4 | # 1 = 1/4 5 | # 2 = 1/2 6 | # 3 = 3/4 7 | # 4 = Full content 8 | 9 | #For pubhealth 10 | python Pre_Experiment/pre_experiment.py \ 11 | --dataset_path Pre_Experiment/data/pre_pubhealth.json\ 12 | --dataset_name pubhealth \ 13 | --retrieval_ratio 1 \ 14 | 15 | #For casehold 16 | python Pre_Experiment/pre_experiment.py \ 17 | --dataset_path Pre_Experiment/data/pre_casehold.json\ 18 | --dataset_name casehold \ 19 | --retrieval_ratio 2 \ 20 | 21
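# Note: the runs in this script assume the customized model, tokenizer and spaCy
# extractor already exist at the default paths expected by pre_experiment.py
# (Pre_Experiment/model/custom_model, Pre_Experiment/model/custom_tokenizer,
# Pre_Experiment/model/en_core_web_sm). A minimal setup sketch; the base model named
# here is only an example, substitute your own:
# python Pre_Experiment/customize_model.py \
#     --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
#     --output_dir Pre_Experiment/model
# python -m spacy download en_core_web_sm   # or place the extractor under Pre_Experiment/model/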
| #For finfact 22 | python Pre_Experiment/pre_experiment.py \ 23 | --dataset_path Pre_Experiment/data/pre_finfact.json\ 24 | --dataset_name finfact \ 25 | --retrieval_ratio 3 \ -------------------------------------------------------------------------------- /Pre_Experiment/result/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-DataFlow/RARE/a1c61533e64ca8f9905281c41ad4ce199be1353d/Pre_Experiment/result/.gitkeep -------------------------------------------------------------------------------- /RL_KTO/data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "pubmed_kto":{ 4 | "file_name": "RL_KTO/data/train_pubmed_kto.json", 5 | "formatting": "sharegpt", 6 | "columns": { 7 | "messages": "conversations", 8 | "kto_tag": "kto_tag" 9 | } 10 | }, 11 | "pubhealth_kto":{ 12 | "file_name": "RL_KTO/data/train_pubhealth_kto.json", 13 | "formatting": "sharegpt", 14 | "columns": { 15 | "messages": "conversations", 16 | "kto_tag": "kto_tag" 17 | } 18 | }, 19 | "bioasq_kto":{ 20 | "file_name": "RL_KTO/data/train_bioasq_kto.json", 21 | "formatting": "sharegpt", 22 | "columns": { 23 | "messages": "conversations", 24 | "kto_tag": "kto_tag" 25 | } 26 | }, 27 | "covert_kto":{ 28 | "file_name": "RL_KTO/data/train_covert_kto.json", 29 | "formatting": "sharegpt", 30 | "columns": { 31 | "messages": "conversations", 32 | "kto_tag": "kto_tag" 33 | } 34 | }, 35 | "medqa_kto":{ 36 | "file_name": "RL_KTO/data/train_medqa_kto.json", 37 | "formatting": "sharegpt", 38 | "columns": { 39 | "messages": "conversations", 40 | "kto_tag": "kto_tag" 41 | } 42 | }, 43 | "finfact_kto":{ 44 | "file_name": "RL_KTO/data/train_finfact.json", 45 | "formatting": "sharegpt", 46 | "columns": { 47 | "messages": "conversations", 48 | "kto_tag": "kto_tag" 49 | } 50 | }, 51 | "casehold_kto":{ 52 | "file_name": "RL_KTO/data/train_casehold_kto.json", 53 | "formatting": "sharegpt", 54 | "columns": { 55 | "messages": "conversations", 56 | "kto_tag": "kto_tag" 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /RL_KTO/ds_z3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "sub_group_size": 1000000000.0, 23 | "reduce_bucket_size": "auto", 24 | "stage3_prefetch_bucket_size": "auto", 25 | "stage3_param_persistence_threshold": "auto", 26 | "stage3_max_live_parameters": 1000000000.0, 27 | "stage3_max_reuse_distance": 1000000000.0, 28 | "stage3_gather_16bit_weights_on_model_save": true, 29 | "offload_optimizer": { 30 | "device": "cpu", 31 | "pin_memory": false 32 | }, 33 | "offload_param": { 34 | "device": "cpu", 35 | "pin_memory": false 36 | } 37 | } 38 | } -------------------------------------------------------------------------------- /RL_KTO/process_kto.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 
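# Usage sketch (file paths here are assumptions, adjust to your run): convert a
# prediction file written by inference/vllm_infer_text.py into KTO training data and
# register the output under the matching entry in RL_KTO/data/dataset_info.json, e.g.
#   python RL_KTO/process_kto.py \
#       --input_path data/train_pubmed.json \
#       --output_path RL_KTO/data/train_pubmed_kto.json
# Each record written below follows the sharegpt/KTO layout declared in dataset_info.json:
#   {"conversations": [{"from": "human", ...}, {"from": "gpt", ...}], "kto_tag": true|false}
# where kto_tag records whether the option extracted from "predict" matches "output".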
import argparse 4 | 5 | def extract_option(pred): 6 | for pattern in [ 7 | r"(.*?)", 8 | r"(.*?)", 9 | r"^([A-Z])[.,:]", 10 | r"Answer:\s*([A-Z])\s*", 11 | ]: 12 | match = re.search(pattern, pred, re.DOTALL) 13 | if match is not None: 14 | pred = match.group(1) 15 | return pred.replace("<", "").replace(">", "").strip() 16 | 17 | def convert_to_kto_format(input_json_path, output_json_path): 18 | with open(input_json_path, "r", encoding="utf-8") as f: 19 | data = json.load(f) 20 | 21 | kto_data = [] 22 | for item in data: 23 | kto_tag = extract_option(item["predict"]) == extract_option(item["output"]) 24 | new_item = { 25 | "conversations": [ 26 | {"from": "human", "value": item["instruction"]}, 27 | {"from": "gpt", "value": item["predict"]} 28 | ], 29 | "kto_tag": kto_tag 30 | } 31 | kto_data.append(new_item) 32 | 33 | with open(output_json_path, "w", encoding="utf-8") as f: 34 | json.dump(kto_data, f, indent=4, ensure_ascii=False) 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description="Convert data to kto format") 38 | parser.add_argument("--input_path", type=str, required=True, help="Path to input JSON file") 39 | parser.add_argument("--output_path", type=str, required=True, help="Path to output JSON file") 40 | args = parser.parse_args() 41 | convert_to_kto_format(args.input_path, args.output_path) 42 | -------------------------------------------------------------------------------- /RL_KTO/train_kto.sh: -------------------------------------------------------------------------------- 1 | # Model name or path specifying the pretrained model to use 2 | model_name_or_path="meta-llama/Llama-3.1-8B-Instruct" 3 | 4 | # Variable to select which model type to use, making it easy to switch models 5 | temp=llama3 6 | # Other example models you can switch to: 7 | # temp=qwen 8 | # temp=mistral 9 | 10 | # Dataset name indicating which dataset is being used, for more details you can check RL_KTO/data/dataset_info.json 11 | dataset=covert_kto 12 | 13 | # Output directory for saving results 14 | OUTPUT_DIR=saves/RL_KTO/${temp}_${dataset} 15 | 16 | llamafactory-cli train \ 17 | --stage kto \ 18 | --do_train True \ 19 | --model_name_or_path ${model_path} \ 20 | --preprocessing_num_workers 16 \ 21 | --finetuning_type full \ 22 | --template ${temp} \ 23 | --flash_attn auto\ 24 | --dataset_dir RL_KTO/data \ 25 | --dataset ${dataset} \ 26 | --cutoff_len ${len} \ 27 | --learning_rate 5e-06 \ 28 | --num_train_epochs 3 \ 29 | --max_samples 100000 \ 30 | --per_device_train_batch_size 1 \ 31 | --gradient_accumulation_steps 8 \ 32 | --lr_scheduler_type cosine \ 33 | --max_grad_norm 1.0 \ 34 | --logging_steps 1 \ 35 | --save_steps 100 \ 36 | --warmup_steps 0 \ 37 | --packing False \ 38 | --report_to none \ 39 | --output_dir ${OUTPUT_DIR} \ 40 | --bf16 True \ 41 | --plot_loss True \ 42 | --trust_remote_code True \ 43 | --ddp_timeout 180000000 \ 44 | --include_num_input_tokens_seen True \ 45 | --optim adamw_torch \ 46 | --pref_beta 0.1 \ 47 | --pref_ftx 0 \ 48 | --pref_loss sigmoid \ 49 | --deepspeed RL_KTO/ds_z3_offload_config.json \ 50 | --gradient_checkpointing True -------------------------------------------------------------------------------- /data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "medqa": { 3 | "file_name": "train_medqa.json", 4 | "columns": { 5 | "prompt": "instruction" 6 | } 7 | }, 8 | "pubmed": { 9 | "file_name": "train_pubmed.json", 10 | "columns": { 11 | "prompt": "instruction" 12 | } 13 | }, 14 | 
"pubhealth": { 15 | "file_name": "train_pubhealth.json", 16 | "columns": { 17 | "prompt": "instruction" 18 | } 19 | }, 20 | "casehold": { 21 | "file_name": "train_casehold.json", 22 | "columns": { 23 | "prompt": "instruction" 24 | } 25 | }, 26 | "mmrait": { 27 | "file_name": "train_mmrait_true.json", 28 | "formatting": "sharegpt", 29 | "columns": { 30 | "messages": "messages", 31 | "images": "images" 32 | }, 33 | "tags": { 34 | "role_tag": "role", 35 | "content_tag": "content", 36 | "user_tag": "user", 37 | "assistant_tag": "assistant" 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /demo/llama_pubmedqa_rare.sh: -------------------------------------------------------------------------------- 1 | # preprocess data 2 | # download 3 | huggingface-cli download --repo-type dataset --resume-download yuhkalhic/rare_share --local-dir process/rare_share 4 | unzip process/rare_share/system=planner_addret,dataset=all_dev,debug=False.zip -d process/rare_share/system=planner_addret,dataset=all_dev,debug=False 5 | unzip process/rare_share/system=planner_addret,dataset=all_train,debug=False.zip -d process/rare_share/system=planner_addret,dataset=all_train,debug=False 6 | unzip process/rare_share/system=planner_addret,dataset=all_test,debug=False.zip -d process/rare_share/system=planner_addret,dataset=all_test,debug=False 7 | python process/process_pubmed.py 8 | 9 | modelscope download --model Qwen/QwQ-32B --local_dir saves/QwQ-32B 10 | 11 | # distill 12 | python inference/vllm_infer_text.py \ 13 | --model_name_or_path saves/QwQ-32B \ 14 | --dataset_path data/train_pubmed.json \ 15 | --template qwen 16 | 17 | # train 18 | # To achieve the best effect, please use 8 or more A100 as much as possible. 19 | torchrun --nproc-per-node 8 --master_port 12345 \ 20 | train/sft.py \ 21 | --block_size=32768 \ 22 | --per_device_train_batch_size=1 \ 23 | --per_device_eval_batch_size=1 \ 24 | --gradient_accumulation_steps=8 \ 25 | --num_train_epochs=5 \ 26 | --train_file_path="data/train_pubmed.json" \ 27 | --model_name_or_path="meta-llama/Llama-3.1-8B-Instruct" \ 28 | --warmup_ratio=0.05 \ 29 | --fsdp="full_shard auto_wrap" \ 30 | --fsdp_config="train/fsdp_config_llama.json" \ 31 | --bf16=True \ 32 | --eval_strategy="no" \ 33 | --logging_steps=1 \ 34 | --save_strategy="no" \ 35 | --lr_scheduler_type="cosine" \ 36 | --learning_rate=1e-5 \ 37 | --weight_decay=1e-4 \ 38 | --adam_beta1=0.9 \ 39 | --adam_beta2=0.95 \ 40 | --output_dir="saves/pubmed-llama" \ 41 | --push_to_hub=false \ 42 | --save_only_model=True \ 43 | --gradient_checkpointing=True \ 44 | --report_to="none" 45 | 46 | # inference 47 | python inference/vllm_infer_text.py \ 48 | --model_name_or_path saves/pubmed-llama \ 49 | --dataset_path data/test_pubmed.json \ 50 | --template llama \ 51 | --prediction_key llm_predict_rare_llama \ 52 | --tensor_parallel_size 8 53 | 54 | # eval 55 | python eval/eval.py \ 56 | --file data/test_pubmed.json \ 57 | --prediction_key llm_predict_rare_llama -------------------------------------------------------------------------------- /demo/qwenvl_mmrait_rare.sh: -------------------------------------------------------------------------------- 1 | # preprocess data 2 | # download 3 | huggingface-cli download whalezzz/M2RAG --repo-type dataset --local-dir process --include "fact_verify/*" 4 | python process/process_mmrait.py 5 | 6 | modelscope download \ 7 | --model Qwen/Qwen2.5-VL-32B-Instruct \ 8 | --local_dir saves/Qwen2.5-VL-32B-Instruct 9 | 10 | # distill 11 | python 
inference/vllm_infer_mm.py \ 12 | --model_name_or_path saves/Qwen2.5-VL-32B-Instruct \ 13 | --dataset_path data/train_mmrait.json 14 | 15 | # select only true 16 | python process/select_true.py data/train_mmrait.json --mm 17 | 18 | # train 19 | accelerate launch \ 20 | --config_file train/accelerate_config_mm.yaml train/train.py train/training_args_mm.yaml 21 | 22 | # Transfer checkpoint 23 | python saves/mmrait-qwenvl/zero_to_fp32.py saves/mmrait-qwenvl/ --safe_serialization 24 | 25 | # inference 26 | python inference/vllm_infer_mm.py \ 27 | --model_name_or_path saves/mmrait-qwenvl \ 28 | --dataset_path data/test_mmrait.json \ 29 | --prediction_key llm_predict_rare_qwen2vl \ 30 | --tensor_parallel_size 4 31 | 32 | # eval 33 | python eval/eval.py \ 34 | --file data/test_mmarit.json \ 35 | --prediction_key llm_predict_rare_qwen2vl -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import argparse 4 | 5 | 6 | def extract_option(pred): 7 | # 1. get A/B/C/D 8 | for pattern in [ 9 | r"(.*?)", 10 | r"(.*?)", 11 | r"^([A-Z])[.,:]", 12 | r"Answer:\s*([A-Z])\s*", 13 | ]: 14 | match = re.search(pattern, pred, re.DOTALL) 15 | if match is not None: 16 | pred = match.group(1) 17 | 18 | # 2. remove <> 19 | pred = pred.replace("<", "").replace(">", "") 20 | pred = pred.strip() 21 | 22 | return pred 23 | 24 | 25 | def calculate_accuracy(data, prediction_key): 26 | correct = 0 27 | total = len(data) 28 | 29 | valid_options = ["A", "B", "C", "D", "E"] 30 | valid_extractions = 0 31 | 32 | for item in data: 33 | 34 | output = extract_option(item["output"]) 35 | 36 | predict = extract_option(item[prediction_key]) 37 | 38 | if predict in valid_options: 39 | valid_extractions += 1 40 | 41 | if output == predict: 42 | correct += 1 43 | # else: 44 | # print(f"Incorrect: {predict} vs {output}") 45 | 46 | accuracy = correct / total if total > 0 else 0 47 | 48 | return accuracy, correct, total, valid_extractions 49 | 50 | 51 | def main(): 52 | parser = argparse.ArgumentParser(description="Calculate accuracy") 53 | parser.add_argument("--file", type=str, required=True, help="path") 54 | parser.add_argument("--prediction_key", type=str, help="name of key") 55 | 56 | args = parser.parse_args() 57 | 58 | with open(args.file, "r") as f: 59 | data = json.load(f) 60 | 61 | 62 | accuracy, correct, total, valid_extractions = calculate_accuracy(data, args.prediction_key) 63 | 64 | print(f"accuracy: {accuracy * 100:.2f}%, {correct}, {total}") 65 | print(f"Valid extraction rate: {valid_extractions}/{total}") 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /image/Case_Study.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-DataFlow/RARE/a1c61533e64ca8f9905281c41ad4ce199be1353d/image/Case_Study.png -------------------------------------------------------------------------------- /image/Overview_RARE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-DataFlow/RARE/a1c61533e64ca8f9905281c41ad4ce199be1353d/image/Overview_RARE.png -------------------------------------------------------------------------------- /image/benchmark.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Open-DataFlow/RARE/a1c61533e64ca8f9905281c41ad4ce199be1353d/image/benchmark.png -------------------------------------------------------------------------------- /image/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open-DataFlow/RARE/a1c61533e64ca8f9905281c41ad4ce199be1353d/image/logo.png -------------------------------------------------------------------------------- /inference/api_infer_post.py: -------------------------------------------------------------------------------- 1 | import json 2 | import aiohttp 3 | import asyncio 4 | import aiofiles 5 | from pathlib import Path 6 | import time 7 | import os 8 | import argparse 9 | import shutil 10 | from tqdm.asyncio import tqdm 11 | 12 | TEMP_DIR="temp" 13 | MAX_RETRIES = 20 14 | INITIAL_TIMEOUT = 400 15 | 16 | async def main(api_url, api_key, input_file, concurrency,model_name): 17 | start_time = time.time() 18 | try: 19 | print("1. Initializing...") 20 | os.makedirs(TEMP_DIR, exist_ok=True) 21 | 22 | print("2. Preparing queue...") 23 | queue = await prepare_queue(input_file) 24 | 25 | print(f"3. Launching {concurrency} worker coroutines...") 26 | progress_bar = tqdm(total=queue.qsize(), desc="Processing", unit="task") 27 | workers = [ 28 | asyncio.create_task(process_item(queue, api_url, api_key, input_file, model_name, progress_bar)) 29 | for _ in range(concurrency) 30 | ] 31 | await queue.join() 32 | print("4. Queue processing completed.") 33 | merged_data = [] 34 | 35 | # Combination 36 | for filename in os.listdir(TEMP_DIR): 37 | if filename.endswith(".json"): 38 | file_path = os.path.join(TEMP_DIR, filename) 39 | with open(file_path, 'r', encoding='utf-8') as f: 40 | data = json.load(f) 41 | merged_data.append(data) 42 | 43 | with open(f"{input_file}", 'w', encoding='utf-8') as f: 44 | json.dump(merged_data, f, ensure_ascii=False, indent=4) 45 | if os.path.exists(TEMP_DIR): 46 | shutil.rmtree(TEMP_DIR) 47 | 48 | for worker in workers: 49 | worker.cancel() 50 | await asyncio.gather(*workers, return_exceptions=True) 51 | 52 | except Exception as e: 53 | print(f"Main error: {str(e)}") 54 | 55 | end_time = time.time() 56 | elapsed_time = end_time - start_time 57 | print(f"Total runtime: {elapsed_time:.2f} seconds") 58 | 59 | 60 | async def check_input_file(input_file): 61 | """Check if input file exists and contains valid data""" 62 | if not Path(input_file).exists(): 63 | return False 64 | 65 | async with aiofiles.open(input_file, "r") as f: 66 | first_line = await f.readline() 67 | return bool(first_line.strip()) 68 | 69 | 70 | async def prepare_queue(input_file): 71 | """Prepare processing queue""" 72 | queue = asyncio.Queue() 73 | 74 | files = os.listdir(TEMP_DIR) 75 | 76 | # Extract UUIDs from existing output files 77 | ids_in_files = set() 78 | for file_name in files: 79 | if file_name.endswith(".json"): 80 | try: 81 | file_id = str(file_name[:-5]) 82 | ids_in_files.add(file_id) 83 | except ValueError: 84 | continue 85 | print(ids_in_files) 86 | 87 | async with aiofiles.open(input_file, "r") as f: 88 | try: 89 | content = await f.read() 90 | data_list = json.loads(content) 91 | 92 | for data in data_list: 93 | if "id" in data and (data["id"] not in ids_in_files): 94 | await queue.put(data["id"]) 95 | print(f"Added to queue: {data['id']}") 96 | except json.JSONDecodeError: 97 | print("Failed to decode JSON.") 98 | 99 | print(f"Queue size: {queue.qsize()}") 100 | return queue 101 | 102 | 103 | async def 
process_item(queue, api_url, api_key, input_file, model_name, progress_bar): 104 | """Process a single item from queue""" 105 | async with aiohttp.ClientSession() as session: 106 | while not queue.empty(): 107 | id = await queue.get() 108 | #print(f"Processing: {id}") 109 | await handle_api_request(session, id, api_url, api_key, input_file, model_name, progress_bar) 110 | queue.task_done() 111 | 112 | 113 | async def handle_api_request(session, id, api_url, api_key, input_file, model_name, progress_bar): 114 | """Handle API request and response""" 115 | original_data = await load_data(input_file,id) 116 | #print(original_data["instruction"]) 117 | payload = { 118 | "model": model_name, 119 | "messages": [{"role": "user", "content": original_data["instruction"]}], 120 | "max_tokens": 12000, 121 | "temperature": 0.6, 122 | } 123 | 124 | for retry in range(MAX_RETRIES): 125 | try: 126 | start_time1 = time.time() 127 | 128 | async with session.post( 129 | api_url, 130 | json=payload, 131 | headers={"Authorization": f"Bearer {api_key}"}, 132 | timeout=INITIAL_TIMEOUT + retry * 10 133 | ) as response: 134 | if response.status == 504: 135 | raise aiohttp.ClientError("Gateway Timeout") 136 | response.raise_for_status() 137 | 138 | result = await response.json() 139 | end_time1 = time.time() 140 | elapsed_time1 = end_time1 - start_time1 141 | #print(f"Request time: {elapsed_time1:.2f} seconds") 142 | await update_data(original_data, result, model_name) 143 | progress_bar.update(1) 144 | return 145 | 146 | except (aiohttp.ClientError, asyncio.TimeoutError) as e: 147 | print(f"Attempt {retry + 1}/{MAX_RETRIES} failed: {str(e)}") 148 | await asyncio.sleep(2 ** retry) 149 | 150 | print(f"Failed to process {id} after all retries.") 151 | 152 | 153 | async def load_data(input_file,id): 154 | """Load data""" 155 | async with aiofiles.open(input_file, "r") as f: 156 | content = await f.read() 157 | data_list = json.loads(content) 158 | for data in data_list: 159 | if data.get("id")==id: 160 | return data 161 | return None 162 | 163 | 164 | async def update_data(original, result, model_name): 165 | """Update original data with response and save""" 166 | updated = original.copy() 167 | #print(result) 168 | if "deepseek-r1" in model_name.lower() or "o3-mini" in model_name.lower(): 169 | updated.update({ 170 | f"{model_name}": result["choices"][0]["message"]["content"], 171 | f"{model_name}_reasoning": result["choices"][0]["message"]["reasoning_content"], 172 | f"{model_name}_predict_token": result["usage"]["completion_tokens"], 173 | f"{model_name}_input_token": result["usage"]["prompt_tokens"], 174 | }) 175 | else: 176 | updated.update({ 177 | f"{model_name}": result["choices"][0]["message"]["content"], 178 | f"{model_name}_predict_token": result["usage"]["completion_tokens"], 179 | f"{model_name}_input_token": result["usage"]["prompt_tokens"], 180 | }) 181 | 182 | file_name = f"{TEMP_DIR}/{updated['id']}.json" 183 | #print(file_name) 184 | await write_to_file(file_name, updated) 185 | 186 | 187 | async def write_to_file(file_path, data): 188 | """Write updated data to file""" 189 | async with aiofiles.open(file_path, 'w') as f: 190 | await f.write(json.dumps(data) + "\n") 191 | #print("File written successfully.") 192 | 193 | 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument("--api_url", type=str,required=True, help="API endpoint URL") 197 | parser.add_argument("--api_key", type=str,required=True, help="API key") 198 | parser.add_argument("--model_name", 
type=str,required=True, help="Model name") 199 | parser.add_argument("--dataset_path", type=str,required=True, help="Path to input file") 200 | parser.add_argument("--concurrency", type=int, default=20, help="Number of concurrent workers") 201 | args = parser.parse_args() 202 | 203 | asyncio.run(main(args.api_url, args.api_key, args.dataset_path, args.concurrency,args.model_name)) 204 | -------------------------------------------------------------------------------- /inference/vllm_infer_mm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gc 3 | import os 4 | import argparse 5 | import torch 6 | from vllm import LLM, SamplingParams 7 | from vllm.distributed.parallel_state import ( 8 | destroy_distributed_environment, 9 | destroy_model_parallel, 10 | ) 11 | 12 | from transformers import AutoProcessor 13 | from qwen_vl_utils import process_vision_info 14 | 15 | 16 | # def clean_up(): 17 | # """only for npu""" 18 | # destroy_model_parallel() 19 | # destroy_distributed_environment() 20 | # gc.collect() 21 | # torch.npu.empty_cache() 22 | 23 | 24 | def format_prompt(instruction, template): 25 | """Format instruction based on model template""" 26 | if template.lower() == "qwen": 27 | return f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" 28 | elif template.lower() == "llama": 29 | return f"<|start_header_id|>user<|end_header_id|>\n{instruction}\n<|start_header_id|>assistant<|end_header_id|>\n\n" 30 | elif template.lower() == "mistral": 31 | return f"[INST] {instruction}[/INST] " 32 | else: 33 | # Generic fallback 34 | return f"USER: {instruction}\nASSISTANT: " 35 | 36 | 37 | def get_stop_tokens(template): 38 | """Get stop tokens based on template""" 39 | if template.lower() == "qwen": 40 | return ["<|im_end|>"] 41 | elif template.lower() == "llama": 42 | return ["<|end_header_id|>"] 43 | elif template.lower() == "mistral": 44 | return ["[INST]"] 45 | else: 46 | return ["USER:"] 47 | 48 | 49 | def load_dataset(dataset_path): 50 | """Load dataset from JSON or JSONL file""" 51 | data = [] 52 | file_extension = os.path.splitext(dataset_path)[1].lower() 53 | 54 | try: 55 | if file_extension == ".json": 56 | with open(dataset_path, "r", encoding="utf-8") as f: 57 | data = json.load(f) 58 | # Handle both list and dict formats 59 | if isinstance(data, dict): 60 | data = [data] 61 | elif file_extension == ".jsonl": 62 | with open(dataset_path, "r", encoding="utf-8") as f: 63 | for line in f: 64 | if line.strip(): 65 | data.append(json.loads(line)) 66 | else: 67 | # Try both formats if extension doesn't match 68 | try: 69 | with open(dataset_path, "r", encoding="utf-8") as f: 70 | data = json.load(f) 71 | # Handle both list and dict formats 72 | if isinstance(data, dict): 73 | data = [data] 74 | except json.JSONDecodeError: 75 | with open(dataset_path, "r", encoding="utf-8") as f: 76 | for line in f: 77 | if line.strip(): 78 | data.append(json.loads(line)) 79 | 80 | return data 81 | except Exception as e: 82 | print(f"Error loading dataset: {e}") 83 | return [] 84 | 85 | 86 | def prepare_multimodal_inputs(item): 87 | """Prepare multimodal inputs from dataset item without modifying original format""" 88 | # Create a deep copy to avoid modifying the original data 89 | messages = [] 90 | for msg in item.get("messages", []): 91 | # Create a copy of each message 92 | messages.append(dict(msg)) 93 | 94 | images = item.get("images", []) 95 | 96 | if not messages and "instruction" in item: 97 | messages = [{"role": "user", 
"content": item["instruction"]}] 98 | 99 | if images and messages: 100 | temp_messages = [] 101 | for msg in messages: 102 | msg_copy = dict(msg) 103 | 104 | # Only modify user messages that contain text content 105 | if msg_copy["role"] == "user" and isinstance(msg_copy["content"], str): 106 | new_content = [] 107 | 108 | # Add images 109 | for image_path in images: 110 | new_content.append( 111 | { 112 | "type": "image", 113 | "image": image_path, 114 | "min_pixels": 224 * 224, 115 | "max_pixels": 1280 * 28 * 28, 116 | } 117 | ) 118 | 119 | # Add the original text 120 | new_content.append({"type": "text", "text": msg_copy["content"]}) 121 | 122 | # Update temporary message 123 | msg_copy["content"] = new_content 124 | 125 | temp_messages.append(msg_copy) 126 | 127 | return temp_messages 128 | 129 | return messages 130 | 131 | 132 | def vllm_infer( 133 | model_name_or_path: str, 134 | dataset_path: str, 135 | template: str = "qwen", 136 | temperature: float = 0.95, 137 | top_p: float = 0.7, 138 | top_k: int = 50, 139 | max_new_tokens: int = 8192, 140 | repetition_penalty: float = 1.0, 141 | tensor_parallel_size: int = 4, 142 | max_model_len: int = 10240, 143 | prediction_key: str = "predict", 144 | ): 145 | """ 146 | With multimodal support always enabled: python vllm_infer_mm.py --model_name_or_path Qwen2.5-VL-32B-Instruct --dataset_path data/test_mmrait.json 147 | """ 148 | print(f"Loading model: {model_name_or_path}") 149 | print(f"Using template: {template}") 150 | print(f"Multimodal mode: Enabled") 151 | 152 | # Load dataset 153 | dataset = load_dataset(dataset_path) 154 | if not dataset: 155 | print(f"Failed to load dataset from {dataset_path} or dataset is empty.") 156 | return 157 | 158 | print(f"Loaded {len(dataset)} examples from dataset.") 159 | 160 | # Load the processor for multimodal models 161 | processor = AutoProcessor.from_pretrained(model_name_or_path) 162 | print(f"Loaded AutoProcessor for multimodal model") 163 | 164 | prompts = [] 165 | original_data = [] 166 | 167 | for item in dataset: 168 | # Handle multimodal input 169 | messages = prepare_multimodal_inputs(item) 170 | 171 | if not messages: 172 | print(f"Warning: No valid messages found in item: {item}") 173 | continue 174 | 175 | # Create a temporary copy for processing 176 | temp_messages = messages.copy() 177 | 178 | # Process the prompt using the model's processor 179 | prompt = processor.apply_chat_template( 180 | temp_messages, 181 | tokenize=False, 182 | add_generation_prompt=True, 183 | ) 184 | 185 | # Process vision information 186 | image_inputs, video_inputs, video_kwargs = process_vision_info( 187 | temp_messages, return_video_kwargs=True 188 | ) 189 | 190 | # Prepare multimodal data 191 | mm_data = {} 192 | if image_inputs: 193 | mm_data["image"] = image_inputs 194 | 195 | # Prepare LLM input 196 | llm_input = { 197 | "prompt": prompt, 198 | "multi_modal_data": mm_data, 199 | } 200 | 201 | prompts.append(llm_input) 202 | original_data.append(item) 203 | 204 | if not prompts: 205 | print("No valid prompts found in dataset. 
Exiting.") 206 | return 207 | 208 | sampling_params = SamplingParams( 209 | temperature=temperature, 210 | top_p=top_p, 211 | top_k=top_k, 212 | max_tokens=max_new_tokens, 213 | repetition_penalty=repetition_penalty, 214 | ) 215 | 216 | print(f"Initializing LLM with tensor_parallel_size={tensor_parallel_size}") 217 | 218 | llm_kwargs = { 219 | "model": model_name_or_path, 220 | "tensor_parallel_size": tensor_parallel_size, 221 | "distributed_executor_backend": "mp", 222 | "max_model_len": max_model_len, 223 | "trust_remote_code": True, 224 | "limit_mm_per_prompt": {"image": 10}, 225 | } 226 | 227 | llm = LLM(**llm_kwargs) 228 | 229 | print("Starting batch generation...") 230 | outputs = llm.generate(prompts, sampling_params) 231 | 232 | print(f"Using prediction key: {prediction_key}") 233 | 234 | for i, output in enumerate(outputs): 235 | generated_text = output.outputs[0].text.strip() 236 | 237 | # Add assistant response as a new message if messages array exists 238 | if "messages" in original_data[i]: 239 | original_data[i]["messages"].append( 240 | {"role": "assistant", "content": generated_text} 241 | ) 242 | 243 | # Also store in the prediction key for compatibility 244 | original_data[i][prediction_key] = generated_text 245 | 246 | file_extension = os.path.splitext(dataset_path)[1].lower() 247 | 248 | # Create backup of original file 249 | backup_path = dataset_path + ".bak" 250 | if not os.path.exists(backup_path): 251 | try: 252 | with open(dataset_path, "rb") as src, open(backup_path, "wb") as dst: 253 | dst.write(src.read()) 254 | print(f"Created backup of original dataset at {backup_path}") 255 | except Exception as e: 256 | print(f"Warning: Failed to create backup: {e}") 257 | 258 | try: 259 | if file_extension == ".json": 260 | with open(dataset_path, "w", encoding="utf-8") as f: 261 | json.dump(original_data, f, ensure_ascii=False, indent=2) 262 | else: # Use JSONL format by default 263 | with open(dataset_path, "w", encoding="utf-8") as f: 264 | for item in original_data: 265 | f.write(json.dumps(item, ensure_ascii=False) + "\n") 266 | 267 | print("*" * 70) 268 | print( 269 | f"{len(original_data)} records updated with predictions using key '{prediction_key}'" 270 | ) 271 | print(f"Updated dataset saved to {dataset_path}") 272 | print("*" * 70) 273 | except Exception as e: 274 | print(f"Error saving results: {e}") 275 | print("Please check the backup file if needed.") 276 | 277 | del llm 278 | # clean_up() 279 | 280 | 281 | def parse_args(): 282 | parser = argparse.ArgumentParser( 283 | description="vLLM Inference for NPU with Multimodal Support" 284 | ) 285 | parser.add_argument( 286 | "--model_name_or_path", 287 | type=str, 288 | required=True, 289 | help="Path to pretrained model or model identifier from huggingface.co/models", 290 | ) 291 | parser.add_argument( 292 | "--dataset_path", 293 | type=str, 294 | required=True, 295 | help="Path to dataset file (JSON or JSONL)", 296 | ) 297 | parser.add_argument( 298 | "--template", 299 | type=str, 300 | default="qwen", 301 | choices=["qwen", "llama", "mistral"], 302 | help="Prompt template to use", 303 | ) 304 | parser.add_argument( 305 | "--temperature", type=float, default=0.6, help="Sampling temperature" 306 | ) 307 | parser.add_argument( 308 | "--top_p", type=float, default=0.7, help="Top-p sampling parameter" 309 | ) 310 | parser.add_argument( 311 | "--top_k", type=int, default=50, help="Top-k sampling parameter" 312 | ) 313 | parser.add_argument( 314 | "--max_new_tokens", 315 | type=int, 316 | default=8192, 317 | 
help="Maximum number of tokens to generate", 318 | ) 319 | parser.add_argument( 320 | "--repetition_penalty", 321 | type=float, 322 | default=1.0, 323 | help="Repetition penalty parameter", 324 | ) 325 | parser.add_argument( 326 | "--tensor_parallel_size", 327 | type=int, 328 | default=4, 329 | help="Tensor parallel size for distributed inference", 330 | ) 331 | parser.add_argument( 332 | "--max_model_len", type=int, default=10240, help="Maximum model sequence length" 333 | ) 334 | parser.add_argument( 335 | "--prediction_key", 336 | type=str, 337 | default="predict", 338 | help="Key to use when storing model predictions in the dataset", 339 | ) 340 | 341 | return parser.parse_args() 342 | 343 | 344 | if __name__ == "__main__": 345 | args = parse_args() 346 | vllm_infer( 347 | model_name_or_path=args.model_name_or_path, 348 | dataset_path=args.dataset_path, 349 | template=args.template, 350 | temperature=args.temperature, 351 | top_p=args.top_p, 352 | top_k=args.top_k, 353 | max_new_tokens=args.max_new_tokens, 354 | repetition_penalty=args.repetition_penalty, 355 | tensor_parallel_size=args.tensor_parallel_size, 356 | max_model_len=args.max_model_len, 357 | prediction_key=args.prediction_key, 358 | ) 359 | -------------------------------------------------------------------------------- /inference/vllm_infer_text.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gc 3 | import os 4 | import argparse 5 | import torch 6 | from vllm import LLM, SamplingParams 7 | from vllm.distributed.parallel_state import ( 8 | destroy_distributed_environment, 9 | destroy_model_parallel, 10 | ) 11 | 12 | 13 | # def clean_up(): 14 | # """only for npu""" 15 | # destroy_model_parallel() 16 | # destroy_distributed_environment() 17 | # gc.collect() 18 | # torch.npu.empty_cache() 19 | 20 | 21 | def format_prompt(instruction, template): 22 | """Format instruction based on model template""" 23 | if template.lower() == "qwen": 24 | return f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" 25 | elif template.lower() == "llama": 26 | return f"<|start_header_id|>user<|end_header_id|>\n{instruction}\n<|start_header_id|>assistant<|end_header_id|>\n\n" 27 | elif template.lower() == "mistral": 28 | return f"[INST] {instruction}[/INST] " 29 | elif template.lower() == "deepseek": 30 | return f"<|User|>{instruction}<|Assistant|>" 31 | else: 32 | # Generic fallback 33 | return f"USER: {instruction}\nASSISTANT: " 34 | 35 | 36 | def get_stop_tokens(template): 37 | """Get stop tokens based on template""" 38 | if template.lower() == "qwen": 39 | return ["<|im_end|>"] 40 | elif template.lower() == "llama": 41 | return ["<|end_header_id|>"] 42 | elif template.lower() == "mistral": 43 | return ["[INST]"] 44 | elif template.lower() == "deepseek": 45 | return ["<|User|>"] 46 | else: 47 | return ["USER:"] 48 | 49 | 50 | def load_dataset(dataset_path): 51 | """Load dataset from JSON or JSONL file""" 52 | data = [] 53 | file_extension = os.path.splitext(dataset_path)[1].lower() 54 | 55 | try: 56 | if file_extension == ".json": 57 | with open(dataset_path, "r", encoding="utf-8") as f: 58 | data = json.load(f) 59 | elif file_extension == ".jsonl": 60 | with open(dataset_path, "r", encoding="utf-8") as f: 61 | for line in f: 62 | if line.strip(): 63 | data.append(json.loads(line)) 64 | else: 65 | # Try both formats if extension doesn't match 66 | try: 67 | with open(dataset_path, "r", encoding="utf-8") as f: 68 | data = json.load(f) 69 | except 
json.JSONDecodeError: 70 | with open(dataset_path, "r", encoding="utf-8") as f: 71 | for line in f: 72 | if line.strip(): 73 | data.append(json.loads(line)) 74 | 75 | return data 76 | except Exception as e: 77 | print(f"Error loading dataset: {e}") 78 | return [] 79 | 80 | 81 | def vllm_infer( 82 | model_name_or_path: str, 83 | dataset_path: str, 84 | template: str = "qwen", 85 | temperature: float = 0.95, 86 | top_p: float = 0.7, 87 | top_k: int = 50, 88 | max_new_tokens: int = 8192, 89 | repetition_penalty: float = 1.0, 90 | tensor_parallel_size: int = 4, 91 | max_model_len: int = 10240, 92 | prediction_key: str = "predict", 93 | ): 94 | """ 95 | Usage: python vllm_infer_text.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset_path data/test_pubmed.json 96 | """ 97 | print(f"Loading model: {model_name_or_path}") 98 | print(f"Using template: {template}") 99 | 100 | # Load dataset 101 | dataset = load_dataset(dataset_path) 102 | if not dataset: 103 | print(f"Failed to load dataset from {dataset_path} or dataset is empty.") 104 | return 105 | 106 | print(f"Loaded {len(dataset)} examples from dataset.") 107 | 108 | prompts = [] 109 | original_data = [] 110 | for item in dataset: 111 | instruction = item.get( 112 | "instruction", item.get("input", item.get("prompt", item.get("query", ""))) 113 | ) 114 | 115 | if not instruction: 116 | print(f"Warning: Couldn't find instruction in item: {item}") 117 | continue 118 | 119 | formatted_prompt = format_prompt(instruction, template) 120 | prompts.append(formatted_prompt) 121 | original_data.append(item) 122 | 123 | if not prompts: 124 | print("No valid prompts found in dataset. Exiting.") 125 | return 126 | 127 | sampling_params = SamplingParams( 128 | temperature=temperature, 129 | top_p=top_p, 130 | top_k=top_k, 131 | max_tokens=max_new_tokens, 132 | repetition_penalty=repetition_penalty, 133 | stop=get_stop_tokens(template), 134 | ) 135 | 136 | print(f"Initializing LLM with tensor_parallel_size={tensor_parallel_size}") 137 | llm = LLM( 138 | model=model_name_or_path, 139 | tensor_parallel_size=tensor_parallel_size, 140 | distributed_executor_backend="mp", 141 | max_model_len=max_model_len, 142 | trust_remote_code=True, 143 | ) 144 | 145 | print("Starting batch generation...") 146 | outputs = llm.generate(prompts, sampling_params) 147 | 148 | print(f"Using prediction key: {prediction_key}") 149 | 150 | for i, output in enumerate(outputs): 151 | generated_text = output.outputs[0].text.strip() 152 | 153 | original_data[i][prediction_key] = generated_text 154 | 155 | file_extension = os.path.splitext(dataset_path)[1].lower() 156 | 157 | # Create backup of original file 158 | backup_path = dataset_path + ".bak" 159 | if not os.path.exists(backup_path): 160 | try: 161 | with open(dataset_path, "rb") as src, open(backup_path, "wb") as dst: 162 | dst.write(src.read()) 163 | print(f"Created backup of original dataset at {backup_path}") 164 | except Exception as e: 165 | print(f"Warning: Failed to create backup: {e}") 166 | 167 | try: 168 | if file_extension == ".json": 169 | with open(dataset_path, "w", encoding="utf-8") as f: 170 | json.dump(original_data, f, ensure_ascii=False, indent=2) 171 | else: # Use JSONL format by default 172 | with open(dataset_path, "w", encoding="utf-8") as f: 173 | for item in original_data: 174 | f.write(json.dumps(item, ensure_ascii=False) + "\n") 175 | 176 | print("*" * 70) 177 | print( 178 | f"{len(original_data)} records updated with predictions using key '{prediction_key}'" 179 | ) 180 | 
print(f"Updated dataset saved to {dataset_path}") 181 | print("*" * 70) 182 | except Exception as e: 183 | print(f"Error saving results: {e}") 184 | print("Please check the backup file if needed.") 185 | 186 | # Clean up 187 | del llm 188 | # clean_up() 189 | 190 | 191 | def parse_args(): 192 | parser = argparse.ArgumentParser(description="vLLM Inference for NPU") 193 | parser.add_argument( 194 | "--model_name_or_path", 195 | type=str, 196 | required=True, 197 | help="Path to pretrained model or model identifier from huggingface.co/models", 198 | ) 199 | parser.add_argument( 200 | "--dataset_path", 201 | type=str, 202 | required=True, 203 | help="Path to dataset file (JSON or JSONL)", 204 | ) 205 | parser.add_argument( 206 | "--template", 207 | type=str, 208 | default="qwen", 209 | choices=["qwen", "llama", "mistral", "deepseek"], 210 | help="Prompt template to use", 211 | ) 212 | parser.add_argument( 213 | "--temperature", type=float, default=0.95, help="Sampling temperature" 214 | ) 215 | parser.add_argument( 216 | "--top_p", type=float, default=0.7, help="Top-p sampling parameter" 217 | ) 218 | parser.add_argument( 219 | "--top_k", type=int, default=50, help="Top-k sampling parameter" 220 | ) 221 | parser.add_argument( 222 | "--max_new_tokens", 223 | type=int, 224 | default=10240, 225 | help="Maximum number of tokens to generate", 226 | ) 227 | parser.add_argument( 228 | "--repetition_penalty", 229 | type=float, 230 | default=1.0, 231 | help="Repetition penalty parameter", 232 | ) 233 | parser.add_argument( 234 | "--tensor_parallel_size", 235 | type=int, 236 | default=8, 237 | help="Tensor parallel size for distributed inference", 238 | ) 239 | parser.add_argument( 240 | "--max_model_len", type=int, default=20480, help="Maximum model sequence length" 241 | ) 242 | parser.add_argument( 243 | "--prediction_key", 244 | type=str, 245 | default="predict", 246 | help="Key to use when storing model predictions in the dataset", 247 | ) 248 | 249 | return parser.parse_args() 250 | 251 | 252 | if __name__ == "__main__": 253 | args = parse_args() 254 | vllm_infer( 255 | model_name_or_path=args.model_name_or_path, 256 | dataset_path=args.dataset_path, 257 | template=args.template, 258 | temperature=args.temperature, 259 | top_p=args.top_p, 260 | top_k=args.top_k, 261 | max_new_tokens=args.max_new_tokens, 262 | repetition_penalty=args.repetition_penalty, 263 | tensor_parallel_size=args.tensor_parallel_size, 264 | max_model_len=args.max_model_len, 265 | prediction_key=args.prediction_key, 266 | ) 267 | -------------------------------------------------------------------------------- /inference/vllm_infer_text_reject_sampling.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gc 3 | import os 4 | import re 5 | import argparse 6 | import torch 7 | from vllm import LLM, SamplingParams 8 | from vllm.distributed.parallel_state import ( 9 | destroy_distributed_environment, 10 | destroy_model_parallel, 11 | ) 12 | 13 | 14 | def clean_up(): 15 | """Clean up resources - support both NPU and GPU""" 16 | destroy_model_parallel() 17 | destroy_distributed_environment() 18 | gc.collect() 19 | if torch.cuda.is_available(): 20 | torch.cuda.empty_cache() 21 | elif hasattr(torch, 'npu') and torch.npu.is_available(): 22 | torch.npu.empty_cache() 23 | 24 | 25 | def format_prompt(instruction, template): 26 | """Format instruction based on model template""" 27 | if template.lower() == "qwen": 28 | return 
f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" 29 | elif template.lower() == "llama": 30 | return f"<|start_header_id|>user<|end_header_id|>\n{instruction}\n<|start_header_id|>assistant<|end_header_id|>\n\n" 31 | elif template.lower() == "mistral": 32 | return f"[INST] {instruction}[/INST] " 33 | elif template.lower() == "deepseek": 34 | return f"<|User|>{instruction}<|Assistant|>" 35 | else: 36 | # Generic fallback 37 | return f"USER: {instruction}\nASSISTANT: " 38 | 39 | 40 | def get_stop_tokens(template): 41 | """Get stop tokens based on template""" 42 | if template.lower() == "qwen": 43 | return ["<|im_end|>"] 44 | elif template.lower() == "llama": 45 | return ["<|end_header_id|>"] 46 | elif template.lower() == "mistral": 47 | return ["[INST]"] 48 | elif template.lower() == "deepseek": 49 | return ["<|User|>"] 50 | else: 51 | return ["USER:"] 52 | 53 | 54 | def load_dataset(dataset_path): 55 | """Load dataset from JSON or JSONL file""" 56 | data = [] 57 | file_extension = os.path.splitext(dataset_path)[1].lower() 58 | 59 | try: 60 | if file_extension == ".json": 61 | with open(dataset_path, "r", encoding="utf-8") as f: 62 | data = json.load(f) 63 | elif file_extension == ".jsonl": 64 | with open(dataset_path, "r", encoding="utf-8") as f: 65 | for line in f: 66 | if line.strip(): 67 | data.append(json.loads(line)) 68 | else: 69 | # Try both formats if extension doesn't match 70 | try: 71 | with open(dataset_path, "r", encoding="utf-8") as f: 72 | data = json.load(f) 73 | except json.JSONDecodeError: 74 | with open(dataset_path, "r", encoding="utf-8") as f: 75 | for line in f: 76 | if line.strip(): 77 | data.append(json.loads(line)) 78 | 79 | return data 80 | except Exception as e: 81 | print(f"Error loading dataset: {e}") 82 | return [] 83 | 84 | 85 | def extract_option(pred): 86 | """Extract answer option from model output""" 87 | # 1. get A/B/C/D 88 | for pattern in [ 89 | r"(.*?)", 90 | r"(.*?)", 91 | r"^([A-Z])[.,:]", 92 | r"Answer:\s*([A-Z])\s*", 93 | ]: 94 | match = re.search(pattern, pred, re.DOTALL) 95 | if match is not None: 96 | pred = match.group(1) 97 | 98 | # 2. remove <> 99 | pred = pred.replace("<", "").replace(">", "") 100 | pred = pred.strip() 101 | 102 | # 3. 
Only keep first character if it's a valid option 103 | if pred and pred[0] in "ABCDE": 104 | return pred[0] 105 | return pred 106 | 107 | 108 | def check_answer(prediction, correct_answer): 109 | """Check if the prediction matches the correct answer""" 110 | extracted_pred = extract_option(prediction) 111 | extracted_answer = extract_option(correct_answer) 112 | return extracted_pred == extracted_answer 113 | 114 | 115 | def resampling_inference( 116 | model_name_or_path: str, 117 | dataset_path: str, 118 | template: str = "qwen", 119 | temperature: float = 0.95, 120 | top_p: float = 0.7, 121 | top_k: int = 50, 122 | max_new_tokens: int = 8192, 123 | repetition_penalty: float = 1.0, 124 | tensor_parallel_size: int = 4, 125 | max_model_len: int = 10240, 126 | prediction_key: str = "qwq_sft", 127 | max_attempts: int = 8, 128 | output_path: str = None, 129 | ): 130 | """ 131 | Perform resampling inference: retry incorrect answers up to max_attempts times 132 | """ 133 | print(f"Loading model: {model_name_or_path}") 134 | print(f"Using template: {template}") 135 | 136 | # Load dataset 137 | dataset = load_dataset(dataset_path) 138 | if not dataset: 139 | print(f"Failed to load dataset from {dataset_path} or dataset is empty.") 140 | return 141 | 142 | print(f"Loaded {len(dataset)} examples from dataset.") 143 | 144 | # Initialize LLM 145 | sampling_params = SamplingParams( 146 | temperature=temperature, 147 | top_p=top_p, 148 | top_k=top_k, 149 | max_tokens=max_new_tokens, 150 | repetition_penalty=repetition_penalty, 151 | stop=get_stop_tokens(template), 152 | ) 153 | 154 | print(f"Initializing LLM with tensor_parallel_size={tensor_parallel_size}") 155 | llm = LLM( 156 | model=model_name_or_path, 157 | tensor_parallel_size=tensor_parallel_size, 158 | distributed_executor_backend="mp", 159 | max_model_len=max_model_len, 160 | trust_remote_code=True, 161 | ) 162 | 163 | # Initialize tracking variables 164 | items_to_resample = list(range(len(dataset))) 165 | attempt_counts = [0] * len(dataset) 166 | correct_items = set() 167 | 168 | # Store original results 169 | original_results = [None] * len(dataset) 170 | 171 | # Track metrics 172 | total_correct = 0 173 | initial_correct = 0 174 | 175 | # Resampling loop - continue until all items correct or max attempts reached 176 | while items_to_resample and max(attempt_counts) < max_attempts: 177 | current_batch_indices = items_to_resample.copy() 178 | items_to_resample = [] 179 | 180 | # Prepare prompts for current batch 181 | prompts = [] 182 | for idx in current_batch_indices: 183 | item = dataset[idx] 184 | instruction = item.get( 185 | "instruction", item.get("input", item.get("prompt", item.get("query", ""))) 186 | ) 187 | if not instruction: 188 | print(f"Warning: Couldn't find instruction in item: {item}") 189 | continue 190 | 191 | formatted_prompt = format_prompt(instruction, template) 192 | prompts.append(formatted_prompt) 193 | attempt_counts[idx] += 1 194 | 195 | if not prompts: 196 | break 197 | 198 | print(f"Starting batch generation for {len(prompts)} items...") 199 | outputs = llm.generate(prompts, sampling_params) 200 | 201 | # Process outputs and determine which items need resampling 202 | for batch_idx, (idx, output) in enumerate(zip(current_batch_indices, outputs)): 203 | item = dataset[idx] 204 | generated_text = output.outputs[0].text.strip() 205 | 206 | # Save first attempt result 207 | if attempt_counts[idx] == 1: 208 | original_results[idx] = generated_text 209 | item[f"{prediction_key}_original"] = generated_text 210 | 211 | 
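            # Rejection-sampling bookkeeping: the generated answer is compared against the gold label below;
            # correct items are frozen, incorrect ones are re-queued for another sampling pass until
            # max_attempts is reached, after which the first-attempt output is restored.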
# Check if answer is correct 212 | is_correct = False 213 | if "output" in item: 214 | is_correct = check_answer(generated_text, item["output"]) 215 | 216 | # Track first attempt accuracy 217 | if attempt_counts[idx] == 1 and is_correct: 218 | initial_correct += 1 219 | 220 | # Update item with latest prediction 221 | item[prediction_key] = generated_text 222 | 223 | if is_correct: 224 | correct_items.add(idx) 225 | total_correct += 1 226 | print(f"Item {idx} correct on attempt {attempt_counts[idx]}") 227 | else: 228 | # If not correct and under max attempts, add to resample list 229 | if attempt_counts[idx] < max_attempts: 230 | items_to_resample.append(idx) 231 | 232 | # For items reaching max attempts without success, restore original result 233 | if attempt_counts[idx] == max_attempts - 1: 234 | print(f"Item {idx} failed after {max_attempts} attempts, reverting to original result") 235 | 236 | # For items that reached max attempts without success, restore original result 237 | for idx in range(len(dataset)): 238 | if idx not in correct_items and attempt_counts[idx] >= max_attempts: 239 | dataset[idx][prediction_key] = original_results[idx] 240 | 241 | # Calculate accuracy 242 | final_correct = sum(1 for idx in range(len(dataset)) 243 | if "output" in dataset[idx] and 244 | check_answer(dataset[idx][prediction_key], dataset[idx]["output"])) 245 | 246 | print("\n" + "="*50) 247 | print(f"Resampling results:") 248 | print(f"Total examples: {len(dataset)}") 249 | print(f"Initial accuracy: {initial_correct/len(dataset)*100:.2f}% ({initial_correct}/{len(dataset)})") 250 | print(f"Final accuracy: {final_correct/len(dataset)*100:.2f}% ({final_correct}/{len(dataset)})") 251 | print(f"Improvement: {(final_correct-initial_correct)/len(dataset)*100:.2f}%") 252 | 253 | # Attempt distribution 254 | attempts_hist = {} 255 | for count in attempt_counts: 256 | attempts_hist[count] = attempts_hist.get(count, 0) + 1 257 | 258 | print("\nAttempt distribution:") 259 | for attempt, count in sorted(attempts_hist.items()): 260 | print(f" {attempt} attempt(s): {count} example(s)") 261 | print("="*50) 262 | 263 | # Determine output path 264 | if output_path is None: 265 | filename, ext = os.path.splitext(dataset_path) 266 | output_path = f"{filename}_resampled{ext}" 267 | 268 | # Save results 269 | try: 270 | file_extension = os.path.splitext(output_path)[1].lower() 271 | if file_extension == ".json": 272 | with open(output_path, "w", encoding="utf-8") as f: 273 | json.dump(dataset, f, ensure_ascii=False, indent=2) 274 | else: # Use JSONL format by default 275 | with open(output_path, "w", encoding="utf-8") as f: 276 | for item in dataset: 277 | f.write(json.dumps(item, ensure_ascii=False) + "\n") 278 | 279 | print(f"Results saved to {output_path}") 280 | except Exception as e: 281 | print(f"Error saving results: {e}") 282 | 283 | if output_path: 284 | base_path, ext = os.path.splitext(output_path) 285 | true_output_path = f"{base_path}_true{ext}" 286 | else: 287 | base_path, ext = os.path.splitext(dataset_path) 288 | true_output_path = f"{base_path}_resampled_true{ext}" 289 | 290 | # 筛选正确答案 291 | correct_data = [] 292 | for item in dataset: 293 | if "output" in item and check_answer(item[prediction_key], item["output"]): 294 | correct_data.append(item) 295 | 296 | # 保存正确答案数据集 297 | try: 298 | if os.path.splitext(true_output_path)[1].lower() == ".json": 299 | with open(true_output_path, "w", encoding="utf-8") as f: 300 | json.dump(correct_data, f, ensure_ascii=False, indent=2) 301 | else: 302 | with 
open(true_output_path, "w", encoding="utf-8") as f: 303 | for item in correct_data: 304 | f.write(json.dumps(item, ensure_ascii=False) + "\n") 305 | print(f"Correct results saved to {true_output_path}") 306 | print(f"Correct samples count: {len(correct_data)}/{len(dataset)} " 307 | f"({len(correct_data)/len(dataset)*100:.2f}%)") 308 | except Exception as e: 309 | print(f"Error saving correct results: {e}") 310 | 311 | # Clean up resources 312 | del llm 313 | clean_up() 314 | 315 | 316 | def parse_args(): 317 | parser = argparse.ArgumentParser(description="vLLM Resampling Inference") 318 | parser.add_argument( 319 | "--model_name_or_path", 320 | type=str, 321 | required=True, 322 | help="Path to pretrained model or model identifier from huggingface.co/models", 323 | ) 324 | parser.add_argument( 325 | "--dataset_path", 326 | type=str, 327 | required=True, 328 | help="Path to dataset file (JSON or JSONL)", 329 | ) 330 | parser.add_argument( 331 | "--template", 332 | type=str, 333 | default="qwen", 334 | choices=["qwen", "llama", "mistral", "deepseek"], 335 | help="Prompt template to use", 336 | ) 337 | parser.add_argument( 338 | "--temperature", type=float, default=0.95, help="Sampling temperature" 339 | ) 340 | parser.add_argument( 341 | "--top_p", type=float, default=0.7, help="Top-p sampling parameter" 342 | ) 343 | parser.add_argument( 344 | "--top_k", type=int, default=50, help="Top-k sampling parameter" 345 | ) 346 | parser.add_argument( 347 | "--max_new_tokens", 348 | type=int, 349 | default=10240, 350 | help="Maximum number of tokens to generate", 351 | ) 352 | parser.add_argument( 353 | "--repetition_penalty", 354 | type=float, 355 | default=1.0, 356 | help="Repetition penalty parameter", 357 | ) 358 | parser.add_argument( 359 | "--tensor_parallel_size", 360 | type=int, 361 | default=8, 362 | help="Tensor parallel size for distributed inference", 363 | ) 364 | parser.add_argument( 365 | "--max_model_len", 366 | type=int, 367 | default=20480, 368 | help="Maximum model sequence length" 369 | ) 370 | parser.add_argument( 371 | "--prediction_key", 372 | type=str, 373 | default="qwq_sft", 374 | help="Key to use when storing model predictions in the dataset", 375 | ) 376 | parser.add_argument( 377 | "--max_attempts", 378 | type=int, 379 | default=8, 380 | help="Maximum number of resampling attempts for each incorrect answer", 381 | ) 382 | parser.add_argument( 383 | "--output_path", 384 | type=str, 385 | default=None, 386 | help="Path to save the resampled dataset (defaults to dataset_path with '_resampled' suffix)", 387 | ) 388 | 389 | return parser.parse_args() 390 | 391 | 392 | if __name__ == "__main__": 393 | args = parse_args() 394 | resampling_inference( 395 | model_name_or_path=args.model_name_or_path, 396 | dataset_path=args.dataset_path, 397 | template=args.template, 398 | temperature=args.temperature, 399 | top_p=args.top_p, 400 | top_k=args.top_k, 401 | max_new_tokens=args.max_new_tokens, 402 | repetition_penalty=args.repetition_penalty, 403 | tensor_parallel_size=args.tensor_parallel_size, 404 | max_model_len=args.max_model_len, 405 | prediction_key=args.prediction_key, 406 | max_attempts=args.max_attempts, 407 | output_path=args.output_path, 408 | ) -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 
1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /process/process_casehold.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datasets import load_dataset 4 | 5 | template = """ 6 | You are a professional legal expert specializing in case law analysis, skilled in analyzing determine which of the following holding # Statements correctly. Please first think step-by-step using the # Retrieved Documents and then identify the correct holding # Statement by using your own knowledge. Your responses will be used for research purposes only, so please have a definite answer. 7 | You should respond in the format: 8 | 9 | ... 10 | 11 | A/B/C/D/E (only one option can be chosen) 12 | 13 | # Retrieved Documents 14 | {citing_prompt} 15 | 16 | # Statements 17 | A. {holding_0} 18 | B. {holding_1} 19 | C. {holding_2} 20 | D. {holding_3} 21 | E. 
{holding_4}""" 22 | 23 | 24 | def process_casehold_dataset(): 25 | data_dir = os.path.join(".", "data") 26 | 27 | label_to_output = { 28 | "0": "A", 29 | "1": "B", 30 | "2": "C", 31 | "3": "D", 32 | "4": "E", 33 | } 34 | 35 | dataset = load_dataset("casehold/casehold", "all", trust_remote_code=True) 36 | 37 | for split_name in ["train", "test"]: 38 | print(f"processing {split_name} split...") 39 | split_data = dataset[split_name] 40 | 41 | processed_data = [] 42 | 43 | count = 0 44 | 45 | for item in split_data: 46 | processed_item = {} 47 | 48 | for key in item: 49 | processed_item[key] = item[key] 50 | 51 | label = str(item.get("label")) 52 | if label in label_to_output: 53 | processed_item["output"] = label_to_output[label] 54 | count += 1 55 | else: 56 | print(f"warning, unknown label: {label}") 57 | 58 | citing_prompt = item.get("citing_prompt", "") 59 | holding_0 = item.get("holding_0", "") 60 | holding_1 = item.get("holding_1", "") 61 | holding_2 = item.get("holding_2", "") 62 | holding_3 = item.get("holding_3", "") 63 | holding_4 = item.get("holding_4", "") 64 | 65 | processed_item["instruction"] = template.format( 66 | citing_prompt=citing_prompt, 67 | holding_0=holding_0, 68 | holding_1=holding_1, 69 | holding_2=holding_2, 70 | holding_3=holding_3, 71 | holding_4=holding_4, 72 | ) 73 | 74 | processed_data.append(processed_item) 75 | 76 | output_file = os.path.join(data_dir, f"{split_name}_casehold.json") 77 | with open(output_file, "w", encoding="utf-8") as f: 78 | json.dump(processed_data, f, ensure_ascii=False, indent=2) 79 | 80 | print(f"Have processed {count}, save to: {output_file}") 81 | 82 | 83 | if __name__ == "__main__": 84 | process_casehold_dataset() 85 | -------------------------------------------------------------------------------- /process/process_finfact.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import os 4 | import math 5 | from datasets import load_dataset 6 | 7 | output_train_path = "data/train_finfact.json" 8 | output_test_path = "data/test_finfact.json" 9 | 10 | # template 11 | template = """You are a professional financial expert in fact-checking, skilled in analyzing the accuracy of # Statement. Please first think step-by-step using the # Retrieved Documents and then check # Statement by using your own knowledge. Your responses will be used for research purposes only, so please have a definite answer. 12 | 13 | You should respond in the format: 14 | 15 | ... 16 | 17 | A/B/C (only one option can be chosen) 18 | 19 | # Retrieved Documents 20 | {documents} 21 | 22 | # Statement 23 | {claim}\nA. true\nB. false\nC. 
NEI""" 24 | 25 | ds = load_dataset("amanrangapur/Fin-Fact") 26 | 27 | data = [] 28 | for item in ds["train"]: 29 | data_item = {} 30 | for key in item: 31 | data_item[key] = item[key] 32 | 33 | 34 | for item in data: 35 | if item["label"] == "true": 36 | item["output"] = "A" 37 | elif item["label"] == "false": 38 | item["output"] = "B" 39 | elif item["label"] == "NEI": 40 | item["output"] = "C" 41 | 42 | documents_text = "" 43 | if "evidence" in item and isinstance(item["evidence"], list): 44 | for evidence in item["evidence"]: 45 | 46 | if evidence["sentence"] is None or (isinstance(evidence["sentence"], float) and math.isnan(evidence["sentence"])): 47 | continue 48 | 49 | if isinstance(evidence["sentence"], str): 50 | documents_text += evidence["sentence"] + "\n" 51 | else: 52 | try: 53 | documents_text += str(evidence["sentence"]) + "\n" 54 | except: 55 | pass 56 | 57 | item["documents"] = documents_text.strip() 58 | 59 | item["instruction"] = template.format( 60 | documents=item["documents"], 61 | claim=item["claim"] 62 | ) 63 | 64 | # Shuffle the data randomly 65 | random.seed(42) 66 | random.shuffle(data) 67 | 68 | # Split into training set (80%) and test set (20%) 69 | split_idx = int(len(data) * 0.8) 70 | train_data = data[:split_idx] 71 | test_data = data[split_idx:] 72 | 73 | # Save as JSON files 74 | with open(output_train_path, "w", encoding="utf-8") as f: 75 | json.dump(train_data, f, ensure_ascii=False, indent=2) 76 | 77 | with open(output_test_path, "w", encoding="utf-8") as f: 78 | json.dump(test_data, f, ensure_ascii=False, indent=2) 79 | 80 | print(f"Processing complete!") 81 | print(f"Training set size: {len(train_data)} items, saved to {output_train_path}") 82 | print(f"Test set size: {len(test_data)} items, saved to {output_test_path}") -------------------------------------------------------------------------------- /process/process_medqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from os.path import join as pjoin 5 | import re 6 | from typing import List, Dict, Tuple 7 | import logging 8 | import random 9 | 10 | import torch 11 | import numpy as np 12 | from transformers import AutoModel, AutoTokenizer 13 | from tqdm import tqdm 14 | from qdrant_client import QdrantClient 15 | from qdrant_client.http import models 16 | 17 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") 18 | 19 | template = """You are a professional medical expert to answer the # Question. Please first think step-by-step using the # Retrieved Documents and then answer the question. Your responses will be used for research purposes only, so please have a definite answer. 20 | 21 | The format should be like: 22 | 23 | ... 
24 | 25 | A/B/C/D (only one option can be chosen) 26 | 27 | # Retrieved Documents 28 | {documents} 29 | 30 | # Question 31 | {question}""" 32 | 33 | 34 | class DocumentProcessor: 35 | def __init__(self, model_name: str = "BAAI/llm-embedder"): 36 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | logging.info(f"Using device: {self.device}") 38 | 39 | logging.info(f"Loading model: {model_name}") 40 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 41 | self.model = AutoModel.from_pretrained(model_name).to(self.device) 42 | self.model.eval() 43 | 44 | self.qdrant = QdrantClient(":memory:") 45 | 46 | def extract_document_blocks(self, text: str) -> List[Dict[str, str]]: 47 | blocks = [] 48 | lines = [] 49 | current_source = None 50 | 51 | for line in text.split("\n"): 52 | if line.startswith("## source:"): 53 | if current_source and lines: 54 | blocks.append( 55 | {"source": current_source, "content": "\n".join(lines)} 56 | ) 57 | lines = [] 58 | current_source = line 59 | lines = [line] 60 | elif current_source is not None: 61 | lines.append(line) 62 | 63 | if current_source and lines: 64 | blocks.append({"source": current_source, "content": "\n".join(lines)}) 65 | 66 | return blocks 67 | 68 | def get_embedding(self, text: str) -> np.ndarray: 69 | with torch.no_grad(): 70 | inputs = self.tokenizer( 71 | text, return_tensors="pt", max_length=512, truncation=True, padding=True 72 | ) 73 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs) 75 | embeddings = outputs.last_hidden_state.mean(dim=1) 76 | return embeddings.cpu().numpy() 77 | 78 | def rank_blocks(self, blocks: List[Dict[str, str]], query: str) -> List[str]: 79 | if not blocks: 80 | return [] 81 | 82 | block_contents = [block["content"] for block in blocks] 83 | doc_embeddings = np.vstack( 84 | [self.get_embedding(content) for content in block_contents] 85 | ) 86 | query_embedding = self.get_embedding(query) 87 | 88 | vector_size = doc_embeddings.shape[1] 89 | 90 | collection_name = "temp_collection" 91 | self.qdrant.recreate_collection( 92 | collection_name=collection_name, 93 | vectors_config=models.VectorParams( 94 | size=vector_size, distance=models.Distance.COSINE 95 | ), 96 | ) 97 | 98 | self.qdrant.upload_points( 99 | collection_name=collection_name, 100 | points=[ 101 | models.PointStruct( 102 | id=idx, vector=embedding.tolist(), payload={"content": content} 103 | ) 104 | for idx, (embedding, content) in enumerate( 105 | zip(doc_embeddings, block_contents) 106 | ) 107 | ], 108 | ) 109 | 110 | k = min(len(blocks), 3) 111 | 112 | results = self.qdrant.search( 113 | collection_name=collection_name, 114 | query_vector=query_embedding[0].tolist(), 115 | limit=k, 116 | ) 117 | relevant_blocks = [hit.payload["content"] for hit in results] 118 | 119 | self.qdrant.delete_collection(collection_name) 120 | 121 | return relevant_blocks 122 | 123 | def process_example(self, content: str, question: str) -> Dict: 124 | blocks = self.extract_document_blocks(content) 125 | if not blocks: 126 | documents_str = "" 127 | else: 128 | relevant_blocks = self.rank_blocks(blocks, question) 129 | documents_str = "\n".join(relevant_blocks) 130 | 131 | instruction = template.format(documents=documents_str, question=question) 132 | 133 | return { 134 | "output": "", 135 | "documents": documents_str, 136 | "question": question, 137 | "instruction": instruction, 138 | } 139 | 140 | 141 | def load_dataset(dataset_name: str) -> List[Dict]: 142 | plan_name = 
f"system=planner_addret,dataset={dataset_name},debug=False" 143 | output_all_path = pjoin("process", "rare_share", plan_name, "output_all.json") 144 | logging.info(f"Loading data from {output_all_path}") 145 | 146 | try: 147 | with open(output_all_path, "r", encoding="utf-8") as f: 148 | return json.load(f) 149 | except Exception as e: 150 | logging.error(f"Error loading dataset {dataset_name}: {str(e)}") 151 | return [] 152 | 153 | 154 | def process_dataset(processor, dataset_items, dataset_type, target_names): 155 | filtered_data = {name: [] for name in target_names} 156 | 157 | logging.info(f"Processing {len(dataset_items)} items for {dataset_type} dataset") 158 | 159 | for item in tqdm(dataset_items): 160 | try: 161 | if "pred" not in item or "doc_path" not in item["pred"]: 162 | logging.warning( 163 | f"Missing pred or doc_path in item: {item.get('id', 'NO_ID')}" 164 | ) 165 | continue 166 | 167 | item_name = item.get("name", "") 168 | if item_name not in target_names: 169 | continue 170 | 171 | doc_path = item["pred"]["doc_path"] 172 | 173 | doc_path = doc_path.replace("\\", "/") 174 | 175 | if not os.path.exists(doc_path): 176 | if doc_path.startswith("process/"): 177 | alternative_path = doc_path[8:] 178 | if os.path.exists(alternative_path): 179 | doc_path = alternative_path 180 | else: 181 | logging.warning(f"Document file not found: {doc_path}") 182 | continue 183 | 184 | try: 185 | with open(doc_path, "r", encoding="utf-8") as f: 186 | content = f.read() 187 | except FileNotFoundError: 188 | logging.warning(f"Document file not found: {doc_path}") 189 | continue 190 | except Exception as e: 191 | logging.warning(f"Error reading document {doc_path}: {str(e)}") 192 | continue 193 | 194 | question = item.get("question", "") 195 | processed_item = processor.process_example(content, question) 196 | processed_item["output"] = f"{item.get('gold', '')}" 197 | processed_item["name"] = item_name 198 | processed_item["id"] = item.get("id", "") 199 | 200 | filtered_data[item_name].append(processed_item) 201 | 202 | except Exception as e: 203 | logging.error(f"Error processing item {item.get('id', 'NO_ID')}: {str(e)}") 204 | continue 205 | 206 | return filtered_data 207 | 208 | 209 | def save_results(filtered_data, dataset_type, target_names): 210 | data_dir = os.path.join(".", "data") 211 | os.makedirs(data_dir, exist_ok=True) 212 | 213 | for name in target_names: 214 | items = filtered_data[name] 215 | if items: 216 | output_filename = os.path.join(data_dir, f"{dataset_type}_{name}.json") 217 | logging.info(f"Saving {len(items)} {name} items to {output_filename}") 218 | 219 | with open(output_filename, "w", encoding="utf-8") as f: 220 | json.dump(items, f, ensure_ascii=False, indent=2) 221 | else: 222 | logging.warning( 223 | f"No {name} items found, skipping file creation for {dataset_type}_{name}.json" 224 | ) 225 | 226 | total_processed = sum(len(items) for items in filtered_data.values()) 227 | logging.info( 228 | f"Successfully processed and saved {total_processed} examples for {dataset_type} dataset" 229 | ) 230 | 231 | 232 | def main(): 233 | processor = DocumentProcessor() 234 | target_names = ["medqa"] 235 | 236 | train_datasets = ["all_train", "all_dev"] 237 | train_items = [] 238 | for dataset in train_datasets: 239 | train_items.extend(load_dataset(dataset)) 240 | 241 | train_filtered_data = process_dataset(processor, train_items, "train", target_names) 242 | save_results(train_filtered_data, "train", target_names) 243 | 244 | test_items = load_dataset("all_test") 245 | 
test_filtered_data = process_dataset(processor, test_items, "test", target_names) 246 | save_results(test_filtered_data, "test", target_names) 247 | 248 | logging.info("All processing completed successfully!") 249 | 250 | 251 | if __name__ == "__main__": 252 | main() -------------------------------------------------------------------------------- /process/process_mmrait.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | 5 | input_file1 = "process/fact_verify/train_data.jsonl" 6 | output_file1 = "data/train_mmrait.json" 7 | 8 | input_file2 = "process/fact_verify/val_data.jsonl" 9 | output_file2 = "data/test_mmrait.json" 10 | 11 | source_train_img_dir = "process/fact_verify/train_document_images" 12 | target_train_img_dir = "data/train_document_images" 13 | 14 | source_val_img_dir = "process/fact_verify/val_document_images" 15 | target_val_img_dir = "data/val_document_images" 16 | 17 | instruction_template = """You are a professional medical expert in fact-checking, skilled in analyzing the accuracy of # Statement. Please first think step-by-step using the # Retrieved Documents and # Image related and then check # Statement by using your own knowledge. Your responses will be used for research purposes only, so please have a definite answer. 18 | 19 | You should respond in the format: 20 | 21 | ... 22 | 23 | A/B/C (only one option can be chosen) 24 | 25 | # Retrieved Documents 26 | {documents} 27 | 28 | # Image 29 | 30 | 31 | # Statement 32 | {claim} 33 | A. Support 34 | B. Refute 35 | C. Insufficient""" 36 | 37 | def process_line(line, is_train=True): 38 | data = json.loads(line) 39 | 40 | if data["category"] == "Support": 41 | data["output"] = "A" 42 | elif data["category"] == "Refute": 43 | data["output"] = "B" 44 | elif data["category"] == "Insufficient": 45 | data["output"] = "C" 46 | 47 | data["instruction"] = instruction_template.format( 48 | documents=data["document"], 49 | claim=data["claim"] 50 | ) 51 | 52 | data["messages"] = [ 53 | { 54 | "role": "user", 55 | "content": data["instruction"] 56 | } 57 | ] 58 | 59 | old_img_path = data["document_image"] 60 | img_filename = os.path.basename(old_img_path) 61 | 62 | if is_train: 63 | new_img_path = f"data/train_document_images/{img_filename}" 64 | else: 65 | new_img_path = f"data/val_document_images/{img_filename}" 66 | 67 | data["images"] = [new_img_path] 68 | 69 | return data 70 | 71 | def copy_image_directory(source_dir, target_dir): 72 | if not os.path.exists(target_dir): 73 | os.makedirs(target_dir) 74 | print(f"Create a directory: {target_dir}") 75 | 76 | if os.path.exists(source_dir): 77 | for filename in os.listdir(source_dir): 78 | source_file = os.path.join(source_dir, filename) 79 | target_file = os.path.join(target_dir, filename) 80 | 81 | if os.path.isfile(source_file): 82 | shutil.copy2(source_file, target_file) 83 | 84 | print(f"Copied image from {source_dir} to {target_dir}") 85 | else: 86 | print(f"Warning: Source directory does not exist {source_dir}") 87 | 88 | def main(): 89 | if not os.path.exists("data"): 90 | os.makedirs("data") 91 | print("Create a data directory") 92 | 93 | copy_image_directory(source_train_img_dir, target_train_img_dir) 94 | copy_image_directory(source_val_img_dir, target_val_img_dir) 95 | 96 | train_processed_data = [] 97 | with open(input_file1, 'r', encoding='utf-8') as f: 98 | for line in f: 99 | if line.strip(): 100 | processed_item = process_line(line, is_train=True) 101 | 
train_processed_data.append(processed_item) 102 | 103 | with open(output_file1, 'w', encoding='utf-8') as f: 104 | json.dump(train_processed_data, f, ensure_ascii=False, indent=2) 105 | 106 | print(f"Processed {len(train_processed_data)} training data and saved to {output_file1}") 107 | 108 | val_processed_data = [] 109 | with open(input_file2, 'r', encoding='utf-8') as f: 110 | for line in f: 111 | if line.strip(): 112 | processed_item = process_line(line, is_train=False) 113 | val_processed_data.append(processed_item) 114 | 115 | with open(output_file2, 'w', encoding='utf-8') as f: 116 | json.dump(val_processed_data, f, ensure_ascii=False, indent=2) 117 | 118 | print(f"Processed {len(val_processed_data)} test data and saved to {output_file2}") 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /process/process_pubhealth.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import json 3 | import os 4 | import logging 5 | 6 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") 7 | 8 | template = """You are a professional medical expert in fact-checking, skilled in analyzing the accuracy of # Statement. Please first think step-by-step using the # Retrieved Documents and then check # Statement by using your own knowledge. Your responses will be used for research purposes only, so please have a definite answer. 9 | You should respond in the format: 10 | 11 | ... 12 | 13 | A/B/C/D (only one option can be chosen) 14 | 15 | # Retrieved Documents 16 | {text_2} 17 | 18 | # Statement 19 | {text_1} 20 | A. true - The statement is entirely accurate and supported by solid evidence. 21 | B. false - The statement is completely untrue and contradicted by strong evidence. 22 | C. mixture - The statement is partially true but contains some inaccuracies or misleading elements. 23 | D. 
unproven - There is insufficient evidence to confirm or refute the statement.""" 24 | 25 | 26 | def download_and_process_pubhealth(): 27 | logging.info("Starting download of bigbio/pubhealth dataset...") 28 | 29 | data_dir = os.path.join(".", "data") 30 | os.makedirs(data_dir, exist_ok=True) 31 | 32 | dataset = load_dataset("bigbio/pubhealth", "pubhealth_bigbio_pairs", trust_remote_code=True) 33 | logging.info(f"Dataset loaded with splits: {dataset.keys()}") 34 | 35 | label_to_output = { 36 | "true": "A", 37 | "false": "B", 38 | "mixture": "C", 39 | "unproven": "D", 40 | } 41 | 42 | train_data = [] 43 | 44 | for split in ["train", "validation"]: 45 | if split in dataset: 46 | split_data = dataset[split] 47 | logging.info(f"Processing {split} split with {len(split_data)} items") 48 | 49 | for item in split_data: 50 | processed_item = process_item(item, label_to_output, template) 51 | if processed_item: 52 | train_data.append(processed_item) 53 | 54 | train_output_file = os.path.join(data_dir, "train_pubhealth.json") 55 | save_json(train_data, train_output_file) 56 | logging.info(f"Saved train data to {train_output_file} ({len(train_data)} items)") 57 | 58 | test_data = [] 59 | if "test" in dataset: 60 | test_split = dataset["test"] 61 | logging.info(f"Processing test split with {len(test_split)} items") 62 | 63 | for item in test_split: 64 | processed_item = process_item(item, label_to_output, template) 65 | if processed_item: 66 | test_data.append(processed_item) 67 | 68 | test_output_file = os.path.join(data_dir, "test_pubhealth.json") 69 | save_json(test_data, test_output_file) 70 | logging.info(f"Saved test data to {test_output_file} ({len(test_data)} items)") 71 | 72 | logging.info("Processing complete!") 73 | 74 | 75 | def process_item(item, label_to_output, instruction_template): 76 | 77 | label = item.get("label") 78 | if label not in label_to_output: 79 | logging.warning(f"Warning: Unknown label encountered: {label}") 80 | return None 81 | 82 | text_1 = item.get("text_1", "") 83 | text_2 = item.get("text_2", "") 84 | 85 | processed_item = { 86 | "output": label_to_output[label], 87 | "documents": text_2, 88 | "question": text_1, 89 | "instruction": instruction_template.format(text_1=text_1, text_2=text_2), 90 | "id": item.get("id", ""), 91 | "name": "pubhealth", 92 | } 93 | 94 | return processed_item 95 | 96 | 97 | def save_json(data, output_file): 98 | with open(output_file, "w", encoding="utf-8") as f: 99 | json.dump(data, f, ensure_ascii=False, indent=2) 100 | 101 | 102 | if __name__ == "__main__": 103 | download_and_process_pubhealth() 104 | -------------------------------------------------------------------------------- /process/process_pubmed.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from os.path import join as pjoin 5 | import re 6 | from typing import List, Dict, Tuple 7 | import logging 8 | import random 9 | 10 | import torch 11 | import numpy as np 12 | from transformers import AutoModel, AutoTokenizer 13 | from tqdm import tqdm 14 | from qdrant_client import QdrantClient 15 | from qdrant_client.http import models 16 | 17 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") 18 | 19 | template = """You are a professional medical expert to answer the # Question. Please first think step-by-step using the # Retrieved Documents and then answer the question. Your responses will be used for research purposes only, so please have a definite answer. 
20 | 21 | The format should be like: 22 | 23 | ... 24 | 25 | A/B/C/D (only one option can be chosen) 26 | 27 | # Retrieved Documents 28 | {documents} 29 | 30 | # Question 31 | {question}""" 32 | 33 | 34 | class DocumentProcessor: 35 | def __init__(self, model_name: str = "BAAI/llm-embedder"): 36 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | logging.info(f"Using device: {self.device}") 38 | 39 | logging.info(f"Loading model: {model_name}") 40 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 41 | self.model = AutoModel.from_pretrained(model_name).to(self.device) 42 | self.model.eval() 43 | 44 | self.qdrant = QdrantClient(":memory:") 45 | 46 | def extract_document_blocks(self, text: str) -> List[Dict[str, str]]: 47 | blocks = [] 48 | lines = [] 49 | current_source = None 50 | 51 | for line in text.split("\n"): 52 | if line.startswith("## source:"): 53 | if current_source and lines: 54 | blocks.append( 55 | {"source": current_source, "content": "\n".join(lines)} 56 | ) 57 | lines = [] 58 | current_source = line 59 | lines = [line] 60 | elif current_source is not None: 61 | lines.append(line) 62 | 63 | if current_source and lines: 64 | blocks.append({"source": current_source, "content": "\n".join(lines)}) 65 | 66 | return blocks 67 | 68 | def get_embedding(self, text: str) -> np.ndarray: 69 | with torch.no_grad(): 70 | inputs = self.tokenizer( 71 | text, return_tensors="pt", max_length=512, truncation=True, padding=True 72 | ) 73 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 74 | outputs = self.model(**inputs) 75 | embeddings = outputs.last_hidden_state.mean(dim=1) 76 | return embeddings.cpu().numpy() 77 | 78 | def rank_blocks(self, blocks: List[Dict[str, str]], query: str) -> List[str]: 79 | if not blocks: 80 | return [] 81 | 82 | block_contents = [block["content"] for block in blocks] 83 | doc_embeddings = np.vstack( 84 | [self.get_embedding(content) for content in block_contents] 85 | ) 86 | query_embedding = self.get_embedding(query) 87 | 88 | vector_size = doc_embeddings.shape[1] 89 | 90 | collection_name = "temp_collection" 91 | self.qdrant.recreate_collection( 92 | collection_name=collection_name, 93 | vectors_config=models.VectorParams( 94 | size=vector_size, distance=models.Distance.COSINE 95 | ), 96 | ) 97 | 98 | self.qdrant.upload_points( 99 | collection_name=collection_name, 100 | points=[ 101 | models.PointStruct( 102 | id=idx, vector=embedding.tolist(), payload={"content": content} 103 | ) 104 | for idx, (embedding, content) in enumerate( 105 | zip(doc_embeddings, block_contents) 106 | ) 107 | ], 108 | ) 109 | 110 | k = min(len(blocks), 3) 111 | 112 | results = self.qdrant.search( 113 | collection_name=collection_name, 114 | query_vector=query_embedding[0].tolist(), 115 | limit=k, 116 | ) 117 | relevant_blocks = [hit.payload["content"] for hit in results] 118 | 119 | self.qdrant.delete_collection(collection_name) 120 | 121 | return relevant_blocks 122 | 123 | def process_example(self, content: str, question: str) -> Dict: 124 | blocks = self.extract_document_blocks(content) 125 | if not blocks: 126 | documents_str = "" 127 | else: 128 | relevant_blocks = self.rank_blocks(blocks, question) 129 | documents_str = "\n".join(relevant_blocks) 130 | 131 | instruction = template.format(documents=documents_str, question=question) 132 | 133 | return { 134 | "output": "", 135 | "documents": documents_str, 136 | "question": question, 137 | "instruction": instruction, 138 | } 139 | 140 | 141 | def load_dataset(dataset_name: str) -> 
List[Dict]: 142 | plan_name = f"system=planner_addret,dataset={dataset_name},debug=False" 143 | output_all_path = pjoin("process", "rare_share", plan_name, "output_all.json") 144 | logging.info(f"Loading data from {output_all_path}") 145 | 146 | try: 147 | with open(output_all_path, "r", encoding="utf-8") as f: 148 | return json.load(f) 149 | except Exception as e: 150 | logging.error(f"Error loading dataset {dataset_name}: {str(e)}") 151 | return [] 152 | 153 | 154 | def process_dataset(processor, dataset_items, dataset_type, target_names): 155 | filtered_data = {name: [] for name in target_names} 156 | 157 | logging.info(f"Processing {len(dataset_items)} items for {dataset_type} dataset") 158 | 159 | for item in tqdm(dataset_items): 160 | try: 161 | if "pred" not in item or "doc_path" not in item["pred"]: 162 | logging.warning( 163 | f"Missing pred or doc_path in item: {item.get('id', 'NO_ID')}" 164 | ) 165 | continue 166 | 167 | item_name = item.get("name", "") 168 | if item_name not in target_names: 169 | continue 170 | 171 | # 使用更新后的doc_path 172 | doc_path = item["pred"]["doc_path"] 173 | 174 | doc_path = doc_path.replace("\\", "/") 175 | 176 | # 确保文件路径存在 177 | if not os.path.exists(doc_path): 178 | # 尝试从工作目录解析相对路径 179 | if doc_path.startswith("process/"): 180 | # 移除前缀"process/"以使其相对于当前目录 181 | alternative_path = doc_path[8:] 182 | if os.path.exists(alternative_path): 183 | doc_path = alternative_path 184 | else: 185 | logging.warning(f"Document file not found: {doc_path}") 186 | continue 187 | 188 | try: 189 | with open(doc_path, "r", encoding="utf-8") as f: 190 | content = f.read() 191 | except FileNotFoundError: 192 | logging.warning(f"Document file not found: {doc_path}") 193 | continue 194 | except Exception as e: 195 | logging.warning(f"Error reading document {doc_path}: {str(e)}") 196 | continue 197 | 198 | question = item.get("question", "") 199 | processed_item = processor.process_example(content, question) 200 | processed_item["output"] = f"{item.get('gold', '')}" 201 | processed_item["name"] = item_name 202 | processed_item["id"] = item.get("id", "") 203 | 204 | filtered_data[item_name].append(processed_item) 205 | 206 | except Exception as e: 207 | logging.error(f"Error processing item {item.get('id', 'NO_ID')}: {str(e)}") 208 | continue 209 | 210 | return filtered_data 211 | 212 | 213 | def save_results(filtered_data, dataset_type, target_names): 214 | data_dir = os.path.join(".", "data") 215 | os.makedirs(data_dir, exist_ok=True) 216 | 217 | for name in target_names: 218 | items = filtered_data[name] 219 | if items: 220 | output_filename = os.path.join(data_dir, f"{dataset_type}_pubmed.json") 221 | logging.info(f"Saving {len(items)} {name} items to {output_filename}") 222 | 223 | with open(output_filename, "w", encoding="utf-8") as f: 224 | json.dump(items, f, ensure_ascii=False, indent=2) 225 | else: 226 | logging.warning( 227 | f"No {name} items found, skipping file creation for {dataset_type}_pubmed.json" 228 | ) 229 | 230 | total_processed = sum(len(items) for items in filtered_data.values()) 231 | logging.info( 232 | f"Successfully processed and saved {total_processed} examples for {dataset_type} dataset" 233 | ) 234 | 235 | 236 | def main(): 237 | processor = DocumentProcessor() 238 | target_names = ["pubmedqa"] 239 | 240 | train_datasets = ["all_train", "all_dev"] 241 | train_items = [] 242 | for dataset in train_datasets: 243 | train_items.extend(load_dataset(dataset)) 244 | 245 | train_filtered_data = process_dataset(processor, train_items, "train", 
target_names) 246 | save_results(train_filtered_data, "train", target_names) 247 | 248 | test_items = load_dataset("all_test") 249 | test_filtered_data = process_dataset(processor, test_items, "test", target_names) 250 | save_results(test_filtered_data, "test", target_names) 251 | 252 | logging.info("All processing completed successfully!") 253 | 254 | 255 | if __name__ == "__main__": 256 | main() -------------------------------------------------------------------------------- /process/select_true.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import argparse 4 | import os 5 | import shutil 6 | from pathlib import Path 7 | 8 | def extract_option(pred): 9 | # 1. get A/B/C/D 10 | for pattern in [ 11 | r"(.*?)", 12 | r"(.*?)", 13 | r"^([A-Z])[.,:]", 14 | r"Answer:\s*([A-Z])\s*", 15 | ]: 16 | match = re.search(pattern, pred, re.DOTALL) 17 | if match is not None: 18 | pred = match.group(1) 19 | # 2. remove <> 20 | pred = pred.replace("<", "").replace(">", "") 21 | pred = pred.strip() 22 | return pred 23 | 24 | def copy_images(image_paths, src_dir, dest_dir): 25 | 26 | os.makedirs(dest_dir, exist_ok=True) 27 | 28 | copied_images = {} 29 | 30 | for img_path in image_paths: 31 | if os.path.isabs(img_path): 32 | img_file = os.path.basename(img_path) 33 | src_img = img_path 34 | else: 35 | img_file = os.path.basename(img_path) 36 | src_img = os.path.join(src_dir, img_file) 37 | 38 | dest_img = os.path.join(dest_dir, img_file) 39 | 40 | if os.path.exists(src_img): 41 | shutil.copy2(src_img, dest_img) 42 | copied_images[img_path] = os.path.join(os.path.basename(dest_dir), img_file) 43 | else: 44 | print(f"Warning: Image not found: {src_img}") 45 | copied_images[img_path] = img_path 46 | 47 | return copied_images 48 | 49 | def filter_correct_predictions(file_path, is_mm_mode=False): 50 | with open(file_path, "r") as f: 51 | data = json.load(f) 52 | 53 | correct_data = [] 54 | image_paths = [] 55 | 56 | for item in data: 57 | output = extract_option(item["output"]) 58 | if is_mm_mode: 59 | if ( 60 | "messages" in item 61 | and len(item["messages"]) > 1 62 | and "content" in item["messages"][1] 63 | ): 64 | predict = extract_option(item["messages"][1]["content"]) 65 | else: 66 | continue 67 | else: 68 | predict = extract_option(item["predict"]) 69 | 70 | if output == predict: 71 | correct_data.append(item) 72 | 73 | if "images" in item and isinstance(item["images"], list): 74 | image_paths.extend(item["images"]) 75 | 76 | file_path_obj = Path(file_path) 77 | output_file = str(file_path_obj.with_stem(file_path_obj.stem + "_true")) 78 | 79 | if is_mm_mode and image_paths: 80 | 81 | parent_dir = "data" 82 | 83 | if "train" in file_path: 84 | src_img_dir = os.path.join(parent_dir, "train_document_images") 85 | dest_img_dir = os.path.join(parent_dir, "train_document_images_true") 86 | elif any(x in file_path for x in ["test", "val"]): 87 | src_img_dir = os.path.join(parent_dir, "val_document_images") 88 | dest_img_dir = os.path.join(parent_dir, "val_document_images_true") 89 | else: 90 | src_img_dir = os.path.join(parent_dir, "document_images") 91 | dest_img_dir = os.path.join(parent_dir, "document_images_true") 92 | 93 | path_mapping = copy_images(image_paths, src_img_dir, dest_img_dir) 94 | 95 | for item in correct_data: 96 | if "images" in item and isinstance(item["images"], list): 97 | item["images"] = [path_mapping.get(img_path, img_path) for img_path in item["images"]] 98 | 99 | print(f"Copied {len(path_mapping)} unique images to 
{dest_img_dir}") 100 | 101 | with open(output_file, "w") as f: 102 | json.dump(correct_data, f, indent=2) 103 | 104 | print(f"Filtered {len(correct_data)} correct predictions out of {len(data)} total") 105 | print(f"Saved to: {output_file}") 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser( 109 | description="Filter correct predictions from a dataset" 110 | ) 111 | parser.add_argument("file_path", help="Path to the JSON dataset file") 112 | parser.add_argument( 113 | "--mm", 114 | action="store_true", 115 | help="Use messages[1].content instead of predict and handle image paths" 116 | ) 117 | args = parser.parse_args() 118 | filter_correct_predictions(args.file_path, args.mm) 119 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |

# RARE: Retrieval-Augmented Reasoning Modeling


License: apache-2.0 · GitHub Stars

9 | 10 |
If you like our project, please give us a star ⭐ on GitHub to follow the latest updates.
11 |
12 | ## 💡 Overview
13 |
14 | We propose the **RARE** framework, a novel paradigm that decouples knowledge storage from reasoning modeling. This framework accelerates reasoning modeling by bypassing rote memorization of lower-level knowledge. *All progress will be openly shared and continuously updated in this repository!*
15 |
22 | Performance of RARE versus baselines on diverse benchmarks (medical, legal, financial, and more). 23 |
33 | Motivation of RARE. Left: A pyramid-shaped Bloom’s Taxonomy, illustrating the cognitive hierarchy from basic "Remember" to advanced "Evaluate" and "Create" levels. Right: The correspondence between Domain Knowledge and Domain Thinking with Bloom’s cognitive hierarchy (example related to government bond yields). In contrast to domain knowledge, domain thinking corresponds to the higher-order cognitive process—although relatively rare, it plays a crucial role. 34 |
44 | *Figure: Demonstration of RARE with real medical case studies. Compared to RAG (only with domain knowledge), RARE (combining domain knowledge and thinking) enables LLMs to reason more deeply and accurately. RAG depends only on surface indicators, hastily concluding that the patient requires immediate glucose-lowering intervention, leading to an incorrect answer. In contrast, RARE integrates both clinical indicators and the effectiveness of prior treatment, carefully reasoning that the patient needs second-line therapy while providing an individualized treatment plan—ultimately arriving at the correct answer.*
48 | 49 | 50 | 51 | ## 🔧 Installation 52 | 53 | Complete the environment deployment for this project through the following methods. 54 | 55 | ``` 56 | git clone https://github.com/Open-DataFlow/RARE 57 | cd RARE 58 | 59 | conda create -n rare python=3.10 60 | conda activate rare 61 | 62 | pip install -r requirements.txt 63 | ``` 64 | 65 | ## 🏃 Quick Start 66 | 67 | We provide two complete examples, one for pure language and one for visual-language tasks. 68 | 69 | - Train with [PubMedQA](https://arxiv.org/abs/1909.06146) (text-only dataset) 70 | ``` 71 | bash demo/llama_pubmedqa_rare.sh 72 | ``` 73 | 74 | - Train with [MM-RAIT](https://arxiv.org/abs/2502.17297) (multi-modal dataset) 75 | ``` 76 | bash demo/qwenvl_mmrait_rare.sh 77 | ``` 78 | 79 | ## ✨ Main Experiments 80 | 81 | ### 📋 Data Preparation 82 | 83 | We provide methods for preprocessing data from different sources. 84 | 85 | 1. Format Original Data 86 | 87 | - process medqa and pubmed 88 | 89 | ``` 90 | huggingface-cli download --repo-type dataset --resume-download yuhkalhic/rare_share --local-dir process/rare_share 91 | unzip process/rare_share/system=planner_addret,dataset=all_dev,debug=False.zip -d process/rare_share/system=planner_addret,dataset=all_dev,debug=False 92 | unzip process/rare_share/system=planner_addret,dataset=all_train,debug=False.zip -d process/rare_share/system=planner_addret,dataset=all_train,debug=False 93 | unzip process/rare_share/system=planner_addret,dataset=all_test,debug=False.zip -d process/rare_share/system=planner_addret,dataset=all_test,debug=False 94 | python process/process_medqa.py 95 | python process/process_pubmed.py 96 | ``` 97 | - process pubhealth 98 | ``` 99 | python process/process_pubhealth.py 100 | ``` 101 | - process casehold 102 | ``` 103 | python process/process_casehold.py 104 | ``` 105 | - process finfact 106 | ``` 107 | python process/process_finfact.py 108 | ``` 109 | - process mmrait 110 | ``` 111 | huggingface-cli download whalezzz/M2RAG --repo-type dataset --local-dir process --include "fact_verify/*" 112 | python process/process_mmrait.py 113 | ``` 114 | Through the above steps, the construction of prompts and answers for the dataset has been completed. 115 | 116 | 117 | 2. Distill Reasoning Model 118 | ``` 119 | # For medqa, pubmed, pubhealth, casehold, finfact, these steps should be done 120 | modelscope download --model Qwen/QwQ-32B --local_dir saves/QwQ-32B 121 | python inference/vllm_infer_text.py --model_name_or_path saves/QwQ-32B --dataset_path data/train_medqa.json --template qwen 122 | python process/select_true.py data/train_medqa.json # Only for medqa, casehold and finfact 123 | 124 | # For mmrait, these steps should be done 125 | modelscope download --model Qwen/Qwen2.5-VL-32B-Instruct --local_dir saves/Qwen2.5-VL-32B-Instruct 126 | python scripts/vllm_infer_mm.py --model_name_or_path saves/Qwen2.5-VL-32B-Instruct --dataset_path data/train_mmrait.json 127 | python process/select_true.py data/train_mmrait.json --mm 128 | ``` 129 | 130 | > [!TIP] 131 | > To achieve better results, as described in our paper, we recommend using rejection sampling to improve the quality of distillation data. Use vllm_infer_text_reject_sampling.py instead of inference/vllm_infer_text.py for text data distillation, with exactly the same usage. 
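As a concrete example, a rejection-sampling distillation run for MedQA can reuse the exact flags from the `inference/vllm_infer_text.py` command above, only swapping the script name (a minimal sketch, relying on the "exactly the same usage" note; adjust model and dataset paths to your setup):

```
python inference/vllm_infer_text_reject_sampling.py --model_name_or_path saves/QwQ-32B --dataset_path data/train_medqa.json --template qwen
```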
132 | 133 | After preprocessing, the data should include at least the following keys, where 'instruction', 'predict', 'id', and 'output' represent the prompt, the teaching model's thought process and answer, the data identification number, and the standard answer to the question, respectively. 134 | 135 | ``` 136 | { 137 | "instruction": str, 138 | "predict": str, 139 | "id": str, 140 | "output": str 141 | } 142 | 143 | # or multimodal data: 144 | { 145 | "messages": list, 146 | "images": list, 147 | "id": str, 148 | "output": str 149 | } 150 | ``` 151 | 152 | 153 | ### 🏋️‍♂️ Model Training 154 | 155 | Our training code supports the following types of models; here are some examples of their specific names. 156 | 157 | - meta-llama/Llama-3.1-8B-Instruct 158 | - Qwen/Qwen2.5-7B-Instruct 159 | - mistralai/Mistral-7B-Instruct-v0.3 160 | 161 | You need to set the fsdp_config parameter to the configuration file that matches your model (see the sketch right after the training command below). If you wish to use other models, you can either start training with LLaMA-Factory or modify the code in train/sft.py. 162 | 163 | - Training using only text datasets (medqa, pubmed, pubhealth, casehold, finfact) 164 | ``` 165 | bash train/sft.sh 166 | ``` 167 | 168 |
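To make the fsdp_config note above concrete, here is a minimal sketch of the values you would change in `train/sft.sh` to train Qwen2.5-7B-Instruct instead of the default Llama model (the output directory name is a hypothetical example; the rest of the script stays unchanged):

```
# train/sft.sh (excerpt), hypothetical Qwen variant
base_model="Qwen/Qwen2.5-7B-Instruct"
...
    --fsdp_config="train/fsdp_config_qwen.json" \
    --output_dir="saves/pubmed-qwen" \
```

The `train/fsdp_config_*.json` files differ only in the decoder-layer class that FSDP wraps (e.g. `Qwen2DecoderLayer` for Qwen models), so each base model should be paired with its matching file.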
169 | 170 | Additional Note 171 | 172 | --------- 173 | 174 | Sometimes when using this training approach, inference outputs may exhibit repetitive sentence generation. Through our testing, we've found that this does not affect the final results, but it significantly increases inference time. If this issue occurs, we recommend using the following training approach: 175 | 176 | ``` 177 | accelerate launch --config_file train/accelerate_config.yaml train/train.py train/training_args.yaml 178 | ``` 179 | 180 | --------- 181 | 182 |
183 | 184 | 185 | 186 | 187 | - Training with multimodal datasets (mmrait) 188 | ``` 189 | accelerate launch --config_file train/accelerate_config_mm.yaml train/train.py train/training_args_mm.yaml 190 | 191 | python saves/mmrait-qwenvl/zero_to_fp32.py saves/mmrait-qwenvl --safe_serialization # Convert deepspeed checkpoints 192 | ``` 193 | 194 | ### 🔮 Model Inference 195 | 196 | Our inference script supports five types of models, and here are some examples of their specific names. 197 | 198 | - meta-llama/Llama-3.1-8B-Instruct 199 | - Qwen/Qwen2.5-7B-Instruct 200 | - mistralai/Mistral-7B-Instruct-v0.3 201 | - deepseek-ai/DeepSeek-R1-Distill-Llama-8B 202 | - Qwen/Qwen2.5-VL-7B-Instruct 203 | 204 | For the test set, at least [step 1](#step1) of preprocessing should be completed, including questions and standard answers. 205 | Parameters that are strongly recommended to adjust include model_name_or_path, dataset_path, template, prediction_key, and tensor_parallel_size, which represent the model path, dataset path, prompt template (corresponding to the pre-trained model), the key name where inference results are saved in the dataset, and the number of parallel processes (corresponding to the number of your GPUs). 206 | 207 | ``` 208 | python inference/vllm_infer_text.py --model_name_or_path saves/medqa-llama --dataset_path data/test_medqa.json --template llama --prediction_key llm_predict_rare_llama --tensor_parallel_size 8 209 | 210 | # multimodal 211 | python inference/vllm_infer_mm.py --model_name_or_path saves/mmrait-qwenvl --dataset_path data/test_mmrait.json --prediction_key llm_predict_rare_qwen2vl --tensor_parallel_size 4 212 | ``` 213 | 214 |
215 | 216 | API Inference 217 | 218 | --------- 219 | 220 | You can also use API calls to test closed-source models for baseline methods (e.g., RAG). Below is an example that uses the POST method to call the API for inference. You need to specify the model name, your API URL and key, the dataset path, and the number of concurrent workers. 221 | ``` 222 | python inference/api_infer_post.py --model_name 'your_model_name' --api_url 'your_api_url' --api_key 'your_api_key' --dataset_path data/test_medqa.json --concurrency 30 223 | ``` 224 | Then, you can use the exact same method for evaluation. Note that you need to set `--prediction_key` to the name of the model you used. 225 | 226 | ``` 227 | python eval/eval.py --file data/test_medqa.json --prediction_key 'your_model_name' 228 | ``` 229 | --------- 230 | 231 |
232 | 233 | ### 📊 Output Evaluation 234 | 235 | The evaluation script extracts the keys to be evaluated through regular expressions and compares them with the standard answers, ultimately calculating the accuracy. 236 | 237 | ``` 238 | python eval/eval.py --file data/test_medqa.json --prediction_key llm_predict_rare_llama 239 | ``` 240 | 241 | ## 🧠 Analysis and Discussion 242 | 243 | ### 🧪 Preliminary Experiment 244 | 1. Obtain Pre-experimental Data 245 | 246 | You first need to prepare the data required for the pre-experiment. The pre-experiment uses three datasets in total (PubHealth, CaseHOLD, FinFact), which need to be processed with different scripts. 247 | - pre-process pubhealth, casehold, finfact 248 | ``` 249 | python Pre_Experiment/data/preprocess_pubhealth.py --input data/train_pubhealth.json --output Pre_Experiment/data/pre_pubhealth.json 250 | 251 | python Pre_Experiment/data/preprocess_casehold.py --input data/train_casehold.json --output Pre_Experiment/data/pre_casehold.json 252 | 253 | python Pre_Experiment/data/preprocess_finfact.py --input data/train_finfact.json --output Pre_Experiment/data/pre_finfact.json 254 | ``` 255 | 2. Get Models Required 256 | 257 | Add special tokens to the model to better extract the loss values corresponding to these tokens. 258 | ``` 259 | python Pre_Experiment/customize_model.py --model_name_or_path meta-llama/Llama-3.1-8B-Instruct --output_dir Pre_Experiment/model 260 | ``` 261 | Download the spaCy English small model (`en_core_web_sm`) for document key information extraction and text preprocessing. 262 | ``` 263 | huggingface-cli download --resume-download spacy/en_core_web_sm --local-dir Pre_Experiment/model 264 | ``` 265 | 3. Conduct Preliminary Experiments 266 | 267 | Then you can use the following bash script to conduct pre-experiments: 268 | ``` 269 | bash Pre_Experiment/pre_experiment.sh 270 | ``` 271 | 272 |
273 | 274 | Experiment Results 275 | 276 | --------- 277 | Experiment Results are saved in: 278 | `Pre_Experiment/result/pre_experiment_{dataset_name}_{retrieval_ratio}_4.json` 279 | 280 | Where: 281 | - `{dataset_name}`: Name of dataset (e.g., `pubhealth`, `casehold`, `finfact`) 282 | - `{retrieval_ratio}`: Ratio used (0-4) from `--retrieval_ratio` parameter 283 | 284 | Example files: 285 | - `pre_experiment_pubhealth_1_4.json` (used 1/4 of retrieval content) 286 | - `pre_experiment_casehold_4_4.json` (used full retrieval content) 287 | 288 | --------- 289 | 290 |
291 | 292 | ### 🧩 PEFT and DEFT 293 | 294 | 1. PEFT 295 | 296 | In section 4.2, we explored the feasibility of using parameter-efficient fine-tuning with RARE as a replacement for full parameter fine-tuning. We found that using rank=64 or 128 could achieve good results. Thanks to LLaMA-Factory, we completed this exploration based on their project. Here, we provide examples of the training scripts we used, divided into a training script and a script for merging the LoRA adapter with the original model. 297 | 298 | - Integrating deepspeed and LLaMA-Factory to train using the RARE strategy 299 | 300 | ``` 301 | accelerate launch --config_file train/accelerate_config.yaml train/train.py train/training_args_lora.yaml 302 | ``` 303 | 304 | - Merging the trained LoRA Adapter with the original model for inference 305 | 306 | ``` 307 | llamafactory-cli export train/merge_lora.yaml 308 | ``` 309 | 310 | - These two commands can replace the commands marked with "# train" in demo/llama_pubmedqa_rare.sh to perform a complete trial 311 | 312 | 2. DEFT 313 | 314 | In appendix A.2, we examined the data efficiency of the RARE strategy. This section does not require additional scripts; you only need to extract subsets from the training data obtained in the previous [Data Preparation](#Data-Preparation) step according to certain percentages. Then, following the example workflow in the demo, train the base model with different sized subsets, complete the inference and evaluation processes to verify data efficiency. 315 | 316 | 317 | ### 🎯 Reinforcement Learning 318 | You first need to use `RL_KTO/process_kto.py` to process the data into the format required by KTO. 319 | ``` 320 | # You can perform similar operations on all seven datasets. 321 | python RL_KTO/process_kto.py --input_path data/train_covert.json --output_path RL_KTO/data/train_covert_kto.json 322 | ``` 323 | Then you can use the following bash script to train KTO: 324 | ``` 325 | bash RL_KTO/train_kto.sh 326 | ``` 327 | 328 | ## ✏️ TODO List 329 | - [x] A preliminary experiment to demonstrate the dynamics and effectiveness of RARE. 330 | - [x] More results on parameter-efficient fine-tuning and data-efficient fine-tuning. 331 | - [x] More results on reinforcement learning. 332 | - [ ] More results on multi-task and cross-task learning. 333 | - [ ] More results on diverse model sizes alongside the backbones. 334 | - [ ] Releasing RARE 2.0 (Stay Tuned!) 335 | 336 | 337 | ## 📖 Citation 338 | 339 | If you find this work helpful, please cite our paper: 340 | ```bibtex 341 | @article{wang2025rare, 342 | title={RARE: Retrieval-Augmented Reasoning Modeling}, 343 | author={Wang, Zhengren and Yu, Jiayang and Ma, Dongsheng and Chen, Zhe and Wang, Yu and Li, Zhiyu and Xiong, Feiyu and Wang, Yanfeng and Tang, Linpeng and Zhang, Wentao and others}, 344 | journal={arXiv preprint arXiv:2503.23513}, 345 | year={2025} 346 | } 347 | ``` 348 | 349 | ## ❤️ Acknowledgement 350 | 351 | This repo benefits from: [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), [s1](https://github.com/simplescaling/s1). Thanks for wonderful works. 352 | 353 | 356 | 357 | 358 | 359 | ## 📞 Contact 360 | 361 | For any questions or feedback, please reach out to us at [wzr@stu.pku.edu.cn](wzr@stu.pku.edu.cn). 
362 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | llamafactory==0.9.2 2 | vllm==0.7.3 3 | wandb==0.19.6 4 | qdrant_client 5 | qwen_vl_utils 6 | deepspeed==0.16.4 7 | modelscope==1.24.0 8 | flash-attn==2.7.2 -------------------------------------------------------------------------------- /train/accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: true 3 | deepspeed_config: 4 | gradient_accumulation_steps: 8 5 | offload_optimizer_device: cpu 6 | offload_param_device: cpu 7 | zero3_init_flag: false 8 | zero3_save_16bit_model: false 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | enable_cpu_affinity: false 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 8 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /train/fsdp_config_llama.json: -------------------------------------------------------------------------------- 1 | {"transformer_layer_cls_to_wrap": "LlamaDecoderLayer"} 2 | -------------------------------------------------------------------------------- /train/fsdp_config_mistral.json: -------------------------------------------------------------------------------- 1 | {"transformer_layer_cls_to_wrap": "MistralDecoderLayer"} 2 | -------------------------------------------------------------------------------- /train/fsdp_config_qwen.json: -------------------------------------------------------------------------------- 1 | {"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer"} 2 | -------------------------------------------------------------------------------- /train/merge_lora.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 3 | adapter_name_or_path: saves/llama_pubmed_rare_lora_64 # The location where the training results are saved 4 | template: llama3 5 | finetuning_type: lora 6 | trust_remote_code: true 7 | 8 | ### export 9 | export_dir: saves/llama_pubmed_rare_lora 10 | export_size: 2 11 | export_device: cpu 12 | export_legacy_format: false -------------------------------------------------------------------------------- /train/sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field, asdict 3 | from typing import Optional 4 | import warnings 5 | 6 | warnings.filterwarnings("ignore", category=FutureWarning) 7 | import logging 8 | 9 | logging.basicConfig( 10 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 11 | ) 12 | from datasets import Dataset 13 | import transformers 14 | import trl 15 | import json 16 | 17 | 18 | @dataclass 19 | class TrainingConfig: 20 | model_name_or_path: str = field(default="meta-llama/Llama-3.1-8B-Instruct") 21 | block_size: int = field(default=32768) 22 | # wandb_project: Optional[str] = field(default="RARE") 23 | # wandb_entity: Optional[str] = field(default="1111") 24 | train_file_path: Optional[str] = field( 25 | default="data/train_pubmed.json" 26 | ) 27 | dagger: bool = field(default=False) 28 | 29 | # def 
__post_init__(self): 30 | # os.environ["WANDB_PROJECT"] = self.wandb_project 31 | # os.environ["WANDB_ENTITY"] = self.wandb_entity 32 | 33 | 34 | def train(): 35 | # parsing input 36 | parser = transformers.HfArgumentParser((TrainingConfig, trl.SFTConfig)) 37 | config, args = parser.parse_args_into_dataclasses() 38 | log_config = {**asdict(config), **asdict(args)} 39 | logging.info(f"Training config: {log_config}") 40 | 41 | # loading model 42 | kwargs = {} 43 | if "70B" in config.model_name_or_path: 44 | kwargs = { 45 | "device_map": "auto", 46 | "torch_dtype": "auto", 47 | "attn_implementation": "flash_attention_2", 48 | "use_cache": False, 49 | } 50 | model = transformers.AutoModelForCausalLM.from_pretrained( 51 | config.model_name_or_path, **kwargs 52 | ) 53 | else: 54 | model = transformers.AutoModelForCausalLM.from_pretrained(config.model_name_or_path) 55 | 56 | # Load and process your custom dataset 57 | with open(config.train_file_path, "r") as f: 58 | data = json.load(f) 59 | 60 | # Prepare the dataset in the format expected by the trainer 61 | # Combine instruction and output into a single text field 62 | processed_data = [] 63 | 64 | for item in data: 65 | # Check if we need to handle specific formats or templates 66 | if "Qwen" in config.model_name_or_path: 67 | # Format for Qwen models 68 | text = f"<|im_start|>user\n{item['instruction']}<|im_end|>\n<|im_start|>assistant\n{item['predict']}<|im_end|>" 69 | elif "Llama" in config.model_name_or_path: 70 | # Format for Llama models 71 | text = f"<|start_header_id|>user<|end_header_id|>\n{item['instruction']}\n<|start_header_id|>assistant<|end_header_id|>\n\n{item['predict']}" 72 | elif "Mistral" in config.model_name_or_path: 73 | text = f"[INST] {item['instruction']}[/INST] {item['predict']}" 74 | else: 75 | # Generic format for other models 76 | text = f"USER: {item['instruction']}\nASSISTANT: {item['predict']}" 77 | processed_data.append({"text": text}) 78 | 79 | # Create a Hugging Face dataset from the processed data 80 | dataset = Dataset.from_list(processed_data) 81 | 82 | train_dataset = dataset 83 | 84 | # If you want a validation set, uncomment these lines: 85 | # train_size = int(0.95 * len(dataset)) 86 | # train_dataset = dataset.select(range(train_size)) 87 | # eval_dataset = dataset.select(range(train_size, len(dataset))) 88 | 89 | tokenizer = transformers.AutoTokenizer.from_pretrained( 90 | config.model_name_or_path, use_fast=True 91 | ) 92 | if "Llama" in config.model_name_or_path: 93 | instruction_template = "<|start_header_id|>user<|end_header_id|>" 94 | response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n" 95 | # Use a token that is never used 96 | tokenizer.pad_token = "<|reserved_special_token_5|>" 97 | elif "Qwen" in config.model_name_or_path: 98 | instruction_template = "<|im_start|>user" 99 | response_template = "<|im_start|>assistant\n" 100 | # Use a token that is never used 101 | tokenizer.pad_token = "<|fim_pad|>" 102 | elif "Mistral" in config.model_name_or_path: 103 | instruction_template = tokenizer.encode("[INST]", add_special_tokens=False) 104 | response_template_tokens = tokenizer.encode("[/INST]", add_special_tokens=False) 105 | tokenizer.pad_token = "[control_766]" 106 | 107 | collator = None 108 | if "Mistral" in config.model_name_or_path: 109 | response_template_tokens = tokenizer.encode( 110 | "\n[/INST]", add_special_tokens=False 111 | )[2:] 112 | collator = trl.DataCollatorForCompletionOnlyLM( 113 | response_template=response_template_tokens, tokenizer=tokenizer, mlm=False 
114 | ) 115 | else: 116 | collator = trl.DataCollatorForCompletionOnlyLM( 117 | instruction_template=instruction_template, 118 | response_template=response_template, 119 | tokenizer=tokenizer, 120 | mlm=False, 121 | ) 122 | 123 | args.dataset_text_field = "text" 124 | args.max_seq_length = config.block_size 125 | trainer = trl.SFTTrainer( 126 | model, 127 | train_dataset=train_dataset, 128 | # Use the whole dataset for training, no separate eval dataset 129 | eval_dataset=None, 130 | args=args, 131 | data_collator=collator, 132 | ) 133 | 134 | trainer.train() 135 | trainer.save_model(output_dir=args.output_dir) 136 | tokenizer.save_pretrained(args.output_dir) 137 | trainer.accelerator.wait_for_everyone() 138 | 139 | 140 | if __name__ == "__main__": 141 | train() 142 | -------------------------------------------------------------------------------- /train/sft.sh: -------------------------------------------------------------------------------- 1 | # Reference Running: bash train/sft.sh 2 | # {'train_runtime': 5268.8407, 'train_samples_per_second': 0.949, 'train_steps_per_second': 0.119, 'train_loss': 0.1172730620391667, 'epoch': 5.0} 3 | # export HCCL_CONNECT_TIMEOUT=2000 # for NPU 4 | uid="$(date +%Y%m%d_%H%M%S)" 5 | base_model="meta-llama/Llama-3.1-8B-Instruct" 6 | lr=1e-5 7 | min_lr=0 8 | epochs=5 9 | weight_decay=1e-4 # -> the same training pipe as slurm_training 10 | micro_batch_size=1 # -> batch_size will be 16 if 16 gpus 11 | gradient_accumulation_steps=8 # requires more GPU memory 12 | max_steps=-1 13 | gpu_count=8 14 | push_to_hub=false 15 | 16 | torchrun --nproc-per-node ${gpu_count} --master_port 12345 \ 17 | train/sft.py \ 18 | --block_size=32768 \ 19 | --train_file_path="data/train_pubmed" \ 20 | --per_device_train_batch_size=${micro_batch_size} \ 21 | --per_device_eval_batch_size=${micro_batch_size} \ 22 | --gradient_accumulation_steps=${gradient_accumulation_steps} \ 23 | --num_train_epochs=${epochs} \ 24 | --model_name_or_path=${base_model} \ 25 | --warmup_ratio=0.05 \ 26 | --fsdp="full_shard auto_wrap" \ 27 | --fsdp_config="train/fsdp_config_llama.json" \ 28 | --bf16=True \ 29 | --eval_strategy="no" \ 30 | --logging_steps=1 \ 31 | --save_strategy="no" \ 32 | --lr_scheduler_type="cosine" \ 33 | --learning_rate=${lr} \ 34 | --weight_decay=${weight_decay} \ 35 | --adam_beta1=0.9 \ 36 | --adam_beta2=0.95 \ 37 | --output_dir="saves/pubmed-llama" \ 38 | --push_to_hub=${push_to_hub} \ 39 | --save_only_model=True \ 40 | --gradient_checkpointing=True \ 41 | --report_to="none" # remove this parameter to use wandb 42 | # --accelerator_config='{"gradient_accumulation_kwargs": {"sync_each_batch": true}}' 43 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | from llamafactory.train.tuner import run_exp 2 | 3 | 4 | def main(): 5 | run_exp() 6 | 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /train/training_args.yaml: -------------------------------------------------------------------------------- 1 | bf16: true 2 | cutoff_len: 32768 3 | dataset: pubmed 4 | dataset_dir: data 5 | ddp_timeout: 180000000 6 | do_train: true 7 | finetuning_type: full 8 | flash_attn: auto 9 | gradient_accumulation_steps: 8 10 | include_num_input_tokens_seen: true 11 | learning_rate: 1.0e-05 12 | logging_steps: 1 13 | lr_scheduler_type: cosine 14 | max_grad_norm: 1.0 15 | max_samples: 100000 16 | 
model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 17 | num_train_epochs: 5.0 18 | optim: adamw_torch 19 | output_dir: saves/pubmed-llama 20 | packing: false 21 | per_device_train_batch_size: 1 22 | plot_loss: true 23 | preprocessing_num_workers: 16 24 | report_to: none 25 | save_steps: 5000 26 | stage: sft 27 | template: llama3 28 | trust_remote_code: true 29 | warmup_steps: 0 30 | do_eval: false -------------------------------------------------------------------------------- /train/training_args_lora.yaml: -------------------------------------------------------------------------------- 1 | bf16: true 2 | cutoff_len: 32768 3 | dataset: pubmed 4 | dataset_dir: data 5 | ddp_timeout: 180000000 6 | do_train: true 7 | finetuning_type: lora 8 | flash_attn: auto 9 | gradient_accumulation_steps: 8 10 | include_num_input_tokens_seen: true 11 | learning_rate: 1.0e-05 12 | logging_steps: 1 13 | lora_alpha: 128 14 | lora_dropout: 0 15 | lora_rank: 64 16 | lora_target: all 17 | loraplus_lr_ratio: 16 18 | lr_scheduler_type: cosine 19 | max_grad_norm: 1.0 20 | max_samples: 100000 21 | model_name_or_path: meta-llama/Llama-3.1-8B-Instruct 22 | num_train_epochs: 5.0 23 | optim: adamw_torch 24 | output_dir: saves/llama_pubmed_rare_lora_64 25 | packing: false 26 | per_device_train_batch_size: 1 27 | plot_loss: true 28 | preprocessing_num_workers: 16 29 | report_to: none 30 | save_steps: 5000 31 | stage: sft 32 | template: llama3 33 | trust_remote_code: true 34 | warmup_steps: 0 35 | do_eval: false -------------------------------------------------------------------------------- /train/training_args_mm.yaml: -------------------------------------------------------------------------------- 1 | bf16: true 2 | cutoff_len: 32768 3 | dataset: mmrait 4 | dataset_dir: data 5 | ddp_timeout: 180000000 6 | do_train: true 7 | finetuning_type: full 8 | flash_attn: auto 9 | gradient_accumulation_steps: 8 10 | include_num_input_tokens_seen: true 11 | learning_rate: 5.0e-05 12 | logging_steps: 1 13 | lr_scheduler_type: cosine 14 | max_grad_norm: 1.0 15 | max_samples: 100000 16 | model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct 17 | num_train_epochs: 5.0 18 | optim: adamw_torch 19 | output_dir: saves//mmrait-qwenvl 20 | packing: false 21 | per_device_train_batch_size: 1 22 | plot_loss: true 23 | preprocessing_num_workers: 16 24 | report_to: none 25 | save_steps: 2150 26 | stage: sft 27 | template: qwen2_vl 28 | trust_remote_code: true 29 | warmup_steps: 0 30 | --------------------------------------------------------------------------------