├── .DS_Store ├── figure ├── .DS_Store ├── overview.png └── example_result.png ├── dataset_pipeline ├── .DS_Store ├── 3_dataset_prepare.py ├── 4_convert_natural_question.py ├── 2_convert_MCQ_full.py └── 1_email_parse.py ├── difficulty_and_category_assignment ├── .DS_Store └── category_difficulty.py ├── requirements.txt ├── README.md ├── evaluation └── genome-bench_eval.py └── training ├── sft_training.py ├── rl_training.py └── rl_router_training.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/.DS_Store -------------------------------------------------------------------------------- /figure/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/figure/.DS_Store -------------------------------------------------------------------------------- /figure/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/figure/overview.png -------------------------------------------------------------------------------- /figure/example_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/figure/example_result.png -------------------------------------------------------------------------------- /dataset_pipeline/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/dataset_pipeline/.DS_Store -------------------------------------------------------------------------------- /difficulty_and_category_assignment/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/difficulty_and_category_assignment/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | openai 3 | pandas 4 | torch 5 | tqdm 6 | transformers 7 | wandb 8 | vllm 9 | trl 10 | accelerate 11 | packaging 12 | 13 | 14 | -------------------------------------------------------------------------------- /dataset_pipeline/3_dataset_prepare.py: -------------------------------------------------------------------------------- 1 | #######-------------------------- 2 | ## This script will use Converted_MCQs_gpt4o_full.json file 3 | ## as the input and convert it to MCQs_formatted.json, 4 | ## where answer key has the format 5 | ## {original expert explanation} {correct option}. 6 | ## It also concatenates 'context' with 'question' in question field.
7 | #######-------------------------- 8 | 9 | 10 | import json 11 | 12 | # Load the original dataset 13 | with open("Converted_MCQs_gpt4o_full.json", "r") as file: 14 | data = json.load(file) 15 | 16 | # Function to format the answer with step-by-step reasoning (dummy example) 17 | def format_answer(original_answer, final_answer): 18 | return f"{original_answer} {final_answer}" 19 | 20 | # Convert to GENOME-BENCH format 21 | converted_data = [] 22 | for entry in data: 23 | question = entry["Questions with options"] 24 | reasoning = entry.get("Original answer", "The answer is derived from the given choices.") # Placeholder reasoning 25 | final_answer = entry.get("Answer key"), 26 | context = entry["Context"] 27 | 28 | 29 | # Ensure proper formatting 30 | formatted_answer = format_answer(reasoning, final_answer[0]) 31 | 32 | converted_data.append({ 33 | "question": 'Question context: ' + context + ' Question: ' + question, 34 | "answer": formatted_answer 35 | }) 36 | 37 | # Save the new dataset 38 | with open("MCQs_formatted.json", "w") as outfile: 39 | json.dump(converted_data, outfile, indent=4) 40 | 41 | print("Dataset reformatted to GENOME-BENCH format and saved as MCQs_formatted.json") -------------------------------------------------------------------------------- /dataset_pipeline/4_convert_natural_question.py: -------------------------------------------------------------------------------- 1 | ####----------------------------------------------------------------- 2 | #### Converting "Question + Option" questions to natural questions 3 | ####----------------------------------------------------------------- 4 | 5 | import json 6 | import openai 7 | from tqdm import tqdm 8 | 9 | # Set your OpenAI API key 10 | openai.api_key = "xxxx" 11 | 12 | # Load the JSON file 13 | input_file = "MCQs_formatted.json" 14 | output_file = "MCQs_Genome-Bench.json" 15 | 16 | with open(input_file, "r", encoding="utf-8") as f: 17 | questions_data = json.load(f) 18 | 19 | # Function to 
transform structured question into a natural-sounding question 20 | def generate_natural_question(question_context, question_text): 21 | prompt = (f"Convert the following structured question into a natural-sounding question while preserving the " 22 | f"maximal amount of information:\n\n" 23 | f"Context: {question_context}\n" 24 | f"Question: {question_text}\n\n" 25 | f"The output should be a single natural question that integrates the context smoothly without phrases like " 26 | f"'Person A' or 'Person B'.") 27 | 28 | try: 29 | response = openai.chat.completions.create( 30 | model="gpt-4o", 31 | messages=[{"role": "system", "content": "You are an AI that rewrites structured questions into natural-sounding questions while preserving information."}, 32 | {"role": "user", "content": prompt}], 33 | max_tokens=1024, 34 | temperature=0.7 35 | ) 36 | 37 | return response.choices[0].message.content 38 | except Exception as e: 39 | print(f"Error generating question: {e}") 40 | return None 41 | 42 | 43 | # Process each question in the JSON file 44 | processed_questions = [] 45 | for item in tqdm(questions_data, desc="Processing Questions", unit="question"): 46 | structured_question = item["question"] 47 | 48 | # Extract context and question text 49 | try: 50 | context_start = structured_question.find("Question context: ") + len("Question context: ") 51 | 52 | question_start = structured_question.find("Question: ") 53 | 54 | context_text = structured_question[context_start:question_start].strip() 55 | 56 | question_text = structured_question[question_start + len("Question: "):].strip() 57 | 58 | 59 | # Generate natural question 60 | natural_question = generate_natural_question(context_text, question_text) 61 | if natural_question: 62 | processed_questions.append({"natural_question": natural_question}) 63 | except Exception as e: 64 | print(f"Error processing question: {e}") 65 | 66 | # Save the transformed questions to a new JSON file 67 | with open(output_file, "w", 
encoding="utf-8") as f: 68 | json.dump(processed_questions, f, indent=4, ensure_ascii=False) 69 | 70 | print(f"Processed questions saved to {output_file}") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Toward Scientific Reasoning in LLMs: Training from Expert Discussions via Reinforcement Learning 4 | 5 |
6 | 🧬 A New Benchmark Genome-Bench and RL Fine-Tuning for Scientific Reasoning 📊 7 |
8 |
9 | 10 |
11 |
12 | 13 |
14 | 15 | [![arXiv](https://img.shields.io/badge/arXiv-2505.19501-red?style=for-the-badge&logo=arxiv&logoColor=auto)](https://arxiv.org/abs/2505.19501) 16 | [![GitHub](https://img.shields.io/badge/GitHub-Code-000000?style=for-the-badge&logo=github&logoColor=auto)](https://github.com/mingyin0312/RL4GenomeBench) 17 | [![HuggingFace](https://img.shields.io/badge/HuggingFace-Dataset-ffcc00?style=for-the-badge&logo=huggingface&logoColor=auto)](https://huggingface.co/datasets/Mingyin0312/Genome-Bench) 18 | 19 |
20 |
21 | 22 | ## Overview 23 | 24 | ![](figure/overview.png) 25 | 26 | We introduce **Genome-Bench**, a novel benchmark for evaluating and improving scientific reasoning in large language models. Genome-Bench consists of over 3,000 multiple-choice and QA items derived from CRISPR-related scientific discussions and forum threads, covering key topics in genome engineering, experimental design, and error analysis. 27 | 28 | Our RL training pipeline (based on Group Relative Policy Optimization) improves model performance across expert-labeled evaluation sets. For example, our fine-tuned Qwen2.5-7B model exceeds GPT-4o in accuracy and consistency on multi-hop reasoning tasks. 29 | 30 | ![](figure/example_result.png) 31 | 32 | --- 33 | 34 | ## Getting Started 🎯 35 | 36 | ### Installation 37 | 38 | ```bash 39 | git clone https://github.com/mingyin0312/RL4GenomeBench.git 40 | cd RL4GenomeBench 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | 45 | ### Dataset Preparation 46 | 47 | We provide tools to parse .mbox email archives and convert them into standardized MCQ and QA formats. 
48 | 49 | ```bash 50 | cd dataset_pipeline 51 | python 1_email_parse.py 52 | python 2_convert_MCQ_full.py 53 | python 3_dataset_prepare.py 54 | python 4_convert_natural_question.py 55 | ``` 56 | 57 | ## Training 58 | 59 | ### Reinforcement Fine-Tuning (GRPO) 60 | 61 | ```bash 62 | python training/rl_training.py 63 | ``` 64 | 65 | ### Supervised Fine-Tuning (SFT) 66 | 67 | ```bash 68 | python training/sft_training.py 69 | ``` 70 | 71 | ### Multi-Agent RL Routing 72 | 73 | ```bash 74 | python training/rl_router_training.py 75 | ``` 76 | 77 | ## Evaluation 78 | 79 | To evaluate on the Genome-Bench test split (the training and evaluation scripts load the dataset with `load_from_disk('./dataset/Genome-Bench')`, so run them from the repository root): 80 | 81 | ```bash 82 | python evaluation/genome-bench_eval.py 83 | ``` 84 | 85 | ## Citation 86 | 87 | ```bibtex 88 | @article{yin2025genome, 89 | title={Toward Scientific Reasoning in LLMs: Training from Expert Discussions via Reinforcement Learning}, 90 | author={Yin, Ming and Qu, Yuanhao and Ling, Yang and Cong, Le and Wang, Mengdi}, 91 | journal={arXiv preprint arXiv:2505.19501}, 92 | year={2025} 93 | } 94 | ``` 95 | 96 | ## Acknowledgement 97 | 98 | This project leverages the 🤗 [Transformers Reinforcement Learning (TRL)](https://github.com/huggingface/trl) library, which provides powerful tools for fine-tuning large language models with reinforcement learning techniques such as GRPO.
99 | 100 | -------------------------------------------------------------------------------- /dataset_pipeline/2_convert_MCQ_full.py: -------------------------------------------------------------------------------- 1 | #######-------------------------- 2 | ## This script will use new_QA.json file 3 | ## as the input and use gpt-4o to convert it to 4 | ## multiple-choice question data named 5 | ## Converted_MCQs_gpt4o_full.json 6 | #######-------------------------- 7 | 8 | import json 9 | import openai 10 | from tqdm import tqdm 11 | 12 | # Set up your OpenAI API key 13 | openai.api_key = "xxxx" 14 | # Load the MCQ examples for prompting 15 | with open("MCQ_updated.json", "r") as f: 16 | mcq_examples = json.load(f)["MCQ_Updated"] 17 | 18 | # Load the QA data 19 | with open("new_QA.json", "r") as f: 20 | final_qa_data = json.load(f) 21 | 22 | # Prepare the examples prompt 23 | example_prompts = "\n\n".join( 24 | f"Example MCQ:\n{mcq['Questions with options']}\nCorrect Answer: {mcq['Key']}" 25 | for mcq in mcq_examples# 26 | ) 27 | 28 | # Initialize a list to store generated MCQs 29 | generated_mcqs = [] 30 | 31 | # Process each question in the QA data 32 | for item in tqdm(final_qa_data, total = len(final_qa_data)): 33 | 34 | question = item["question"] 35 | answer = item["answer"] 36 | context = item["context"] 37 | 38 | if answer.strip() == "": 39 | continue 40 | else: 41 | # Prepare the prompt 42 | prompt = f""" 43 | Below are examples of multiple-choice questions (MCQs) with their formats. Use them to generate a new Single-Choice MCQ based on the provided question and context. The MCQ should include five answer choices (a-e) with one being the correct answer, and it should be identical as provided, following the similar format as the examples. The response should not include anything else except for the generated multiple-choice question itself. Just provide one version, then end the response and don't repeat. 
44 | 45 | {example_prompts} 46 | 47 | New Question: 48 | Question: {question} 49 | Answer: {answer} 50 | Context: {context} 51 | 52 | Do not modify the original question in the new MCQ. 53 | 54 | Generate a new MCQ: 55 | """ 56 | #If the above question is not clear, you can also modify the question by using 'context' information for new questions. However, do not use 'answer' information for it! 57 | 58 | # Generate MCQ using the OpenAI API 59 | response = openai.chat.completions.create( 60 | model="gpt-4o", 61 | messages=[ 62 | {"role": "system", "content": "You are an gene-editing expert generating formatted single-choice multiple-choice questions."}, 63 | {"role": "user", "content": prompt} 64 | ], 65 | max_tokens=1024, 66 | temperature=0.7 67 | ) 68 | 69 | # Extract the generated text 70 | generated_text = response.choices[0].message.content 71 | 72 | # Process the response to separate options and answer key 73 | index = generated_text.find("Correct Answer:") 74 | if index != -1: 75 | options = generated_text[:index].strip() 76 | answer_key = generated_text[index+16:].strip() 77 | else: 78 | options = generated_text 79 | answer_key = "N/A" # Default if no answer key is found 80 | 81 | # Store the result 82 | generated_mcqs.append({ 83 | "Questions with options": question + " " + options, 84 | "Answer key": answer_key, 85 | "Original answer": answer, 86 | "Context": context 87 | }) 88 | 89 | 90 | # Save the generated MCQs to a new JSON file 91 | output_file = "Converted_MCQs_gpt4o_full.json" 92 | with open(output_file, "w") as f: 93 | json.dump(generated_mcqs, f, indent=4) 94 | 95 | print(f"Generated MCQs saved to {output_file}") -------------------------------------------------------------------------------- /evaluation/genome-bench_eval.py: -------------------------------------------------------------------------------- 1 | ####----------------------- 2 | #### Formal evaluation 3 | ####----------------------- 4 | 5 | import json 6 | import transformers 7 | 
import torch 8 | from datasets import load_dataset, Dataset, load_from_disk 9 | from transformers import AutoTokenizer 10 | from vllm import LLM, SamplingParams 11 | from tqdm import tqdm 12 | import openai 13 | 14 | 15 | #### Load data 16 | dataset_name = './dataset/Genome-Bench' 17 | dataset_loaded = load_from_disk(dataset_name)['test'] 18 | # Convert to the desired list format with the name 'questions' 19 | questions = [{"question": q, "answer": a} for q, a in zip(dataset_loaded["question"], dataset_loaded["answer"])] 20 | 21 | 22 | # System Prompt 23 | R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a single-choice Multiple Choice question, and the Assistant solves it. Please answer the multiple choice question by selecting only one from optiona a., option b., option c., option d., option e.. 24 | The assistant first thinks about the explanation process in the mind and then provides the user 25 | with the answer. The explanation process and answer are enclosed within and 26 | tags, respectively, i.e., explanation process here 27 | answer here .""" 28 | 29 | 30 | TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single letter from a,b,c,d,e." 
31 | 32 | def extract_xml_answer(text: str) -> str: 33 | """Extracts the answer from a response using the tag.""" 34 | try: 35 | answer = text.split("")[-1].split("")[0].strip() 36 | return answer 37 | except IndexError: 38 | return "" 39 | 40 | 41 | # Initialize the pipeline with Qwen2.5-7B-Instruct 42 | name = 'Qwen2.5-7B-Instruct' 43 | model_id = name # Replace with your model ID 44 | 45 | 46 | llm = LLM(model=model_id, dtype="half", max_model_len=1024, device="cuda:0") 47 | 48 | # Load tokenizer 49 | tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=1024, padding_side='right') 50 | 51 | import random 52 | random_number = random.randint(1, 10000) 53 | 54 | # Set sampling parameters 55 | sampling_params = SamplingParams( 56 | temperature=0.7, 57 | max_tokens=1024, 58 | stop_token_ids=[tokenizer.eos_token_id], 59 | seed = random_number, 60 | ) 61 | 62 | 63 | BATCH_SIZE = 8 64 | # Evaluate questions in batches 65 | results = [] 66 | correct = 0 67 | total = 0 68 | 69 | # Progress bar 70 | progress_bar = tqdm(total=len(questions), desc="Processing", unit="examples", dynamic_ncols=True) 71 | 72 | for i in range(0, len(questions), BATCH_SIZE): 73 | batch_data = questions[i:i + BATCH_SIZE] 74 | 75 | # Prepare prompts using few-shot learning 76 | prompts = [ 77 | [ 78 | {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS}, 79 | {"role": "user", "content": q["question"]}, 80 | ] 81 | 82 | for q in batch_data 83 | ] 84 | 85 | # Convert prompts to formatted strings 86 | formatted_prompts = [ 87 | tokenizer.apply_chat_template(p, tokenize=False, add_generation_prompt=True) 88 | for p in prompts 89 | ] 90 | 91 | # Generate responses using vLLM 92 | outputs = llm.generate(formatted_prompts, sampling_params) 93 | 94 | # Process responses 95 | for j, output in enumerate(outputs): 96 | response = output.outputs[0].text 97 | generated_answer = extract_xml_answer(response) 98 | true_answer = 
extract_xml_answer(batch_data[j]["answer"]) 99 | 100 | # Store the result 101 | result = { 102 | "Question": batch_data[j]["question"], 103 | "Generated Answer": generated_answer, 104 | "Correct Answer": true_answer, 105 | "Full Response": response, 106 | "Correct": generated_answer == true_answer 107 | } 108 | results.append(result) 109 | 110 | if generated_answer == true_answer: 111 | correct += 1 112 | total += 1 113 | 114 | # Update progress bar 115 | progress_bar.update(len(batch_data)) 116 | progress_bar.set_postfix({ 117 | "Accuracy": f"{(correct / total) * 100:.2f}%", 118 | "Correct": f"{correct}/{total}", 119 | }) 120 | 121 | progress_bar.close() 122 | 123 | 124 | # Save the results to a JSON file 125 | output_file = name + '_evaluation_results.json' 126 | with open(output_file, 'w') as file: 127 | json.dump(results, file, indent=4) 128 | 129 | print(f"Evaluation complete. Results saved to {output_file}.") 130 | -------------------------------------------------------------------------------- /training/sft_training.py: -------------------------------------------------------------------------------- 1 | ####### ---------------- 2 | ## SFT Training 3 | ####### ---------------- 4 | 5 | 6 | import os 7 | import re 8 | import torch 9 | from datasets import load_from_disk, Dataset 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from trl import SFTTrainer, SFTConfig 12 | import wandb 13 | 14 | # System prompt and task instructions 15 | R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a single-choice Multiple Choice question, and the Assistant solves it. Please answer the multiple choice question by selecting only one from option a., option b., option c., option d., option e.. 16 | The assistant first thinks about the explanation process in the mind and then provides the user 17 | with the answer. 
The explanation process and answer are enclosed within and 18 | tags, respectively, i.e., explanation process here 19 | answer here .""" 20 | 21 | TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single letter from a, b, c, d, e." 22 | 23 | EXAMPLE = "Question context: PersonA, a novice in the field of selecting mouse ES cells for screening, raises several detailed questions regarding colony selection and verification processes. PersonB provides thorough answers focused on optimizing the selection and genetic verification processes in embryonic stem cells. Question: Would anyone recommendation colony picking versus 96-well dilution for screening colonies? Please choose one of the following options: a. \"Colony picking is preferred for mouse ES cells because they grow as perfect colonies that are easy to pick\" b. \"96-well dilution is more efficient because it allows for high-throughput screening of colonies\" c. \"Colony picking is more labor-intensive and should be avoided when possible\" d. \"96-well dilution is the best method for ensuring genetic consistency across colonies\" e. \"Both methods are equally effective, so the choice depends on available resources\""+"For mouse ES cells there's no reason to do limited dilution, as they grow as perfect colonies that are easy to pick.\na" 24 | 25 | def extract_hash_answer(text: str) -> str: 26 | try: 27 | explanation, answer = text.split("####", 1) 28 | return f" {explanation.strip()} {answer.strip()} " 29 | except ValueError: 30 | return "" 31 | 32 | def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset: 33 | """ 34 | Load the dataset from disk and process each batch to generate chat-style prompts. 35 | The resulting dataset will have a "text" field (a string) and an "answer" field. 
36 | """ 37 | dataset = load_from_disk(dataset_name)[split] 38 | 39 | def process_batch(batch): 40 | chats = [ 41 | f"System: {R1_STYLE_SYSTEM_PROMPT}\n{TASK_SPECIFIC_INSTRUCTIONS}\n{EXAMPLE}\nUser: {q.strip()}" 42 | for q in batch['question'] 43 | ] 44 | return { 45 | 'text': chats, # Ensure it's a list of strings, not a list of dictionaries 46 | 'answer': [extract_hash_answer(a) for a in batch['answer']] 47 | } 48 | 49 | return dataset.map(process_batch, batched=True, batch_size=chunk_size) 50 | 51 | def main(): 52 | # Load and preprocess the dataset 53 | dataset_name = './dataset/Genome-Bench' 54 | dataset = preprocess_dataset(dataset_name, chunk_size=500) 55 | 56 | learning_rate = 1e-5 57 | 58 | epoch = 2 59 | 60 | # Define model and output paths 61 | model_name = "Qwen2.5-7B-Instruct" 62 | output_dir = f"{model_name.split('/')[-1]}-SFT" 63 | run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}" 64 | 65 | # Set memory-related environment variable 66 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' 67 | 68 | # Create SFT training configuration 69 | training_args = SFTConfig( 70 | learning_rate = learning_rate, 71 | logging_steps=1, 72 | bf16=True, 73 | per_device_train_batch_size = 4, 74 | gradient_accumulation_steps = 4, 75 | num_train_epochs=epoch, 76 | max_grad_norm=0.1, 77 | report_to="wandb", 78 | output_dir=output_dir, 79 | run_name=run_name, 80 | log_on_each_node=False, 81 | ) 82 | 83 | # Load the model 84 | model = AutoModelForCausalLM.from_pretrained( 85 | model_name, 86 | torch_dtype=torch.bfloat16, 87 | device_map="auto", 88 | ) 89 | 90 | # Load the tokenizer and set pad token 91 | tokenizer = AutoTokenizer.from_pretrained( 92 | model_name, 93 | model_max_length=512, 94 | ) 95 | tokenizer.pad_token = tokenizer.eos_token 96 | 97 | # Initialize the SFT trainer using the tokenizer as the processing class 98 | trainer = SFTTrainer( 99 | model=model, 100 | processing_class=tokenizer, 101 | train_dataset=dataset, 102 | 
args=training_args, 103 | ) 104 | 105 | # Initialize wandb in offline mode for experiment tracking 106 | wandb.init(project="crispr_sft", name=run_name, mode="offline") 107 | trainer.train() 108 | trainer.save_model(training_args.output_dir) 109 | 110 | if __name__ == "__main__": 111 | main() 112 | 113 | 114 | -------------------------------------------------------------------------------- /difficulty_and_category_assignment/category_difficulty.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from llm import LLMChat 4 | 5 | # Load the test data 6 | def load_test_data(file_path): 7 | with open(file_path, 'r') as f: 8 | return json.load(f) 9 | 10 | # Prompt templates for category and difficulty classification 11 | CATEGORY_PROMPT = """ 12 | You are a CRISPR expert assistant. I will give you a CRISPR-related question and answer, and you will assign it to exactly one of the following seven predefined categories based on its core intent. Be precise—choose the category that most closely reflects the primary focus of the question. 13 | Here are the categories: 14 | 1. Gene-editing Enzyme Selection 15 | 2. GuideRNA Design 16 | 3. Cloning & Plasmid Construction 17 | 4. Gene-editing Delivery Methods 18 | 5. CRISPR Screening & Library Workflows 19 | 6. Validation, Troubleshooting & Optimization 20 | 7. Practical Considerations & Lab Logistics 21 | 22 | For each question, provide your category assignment with a brief (1-2 sentence) explanation of why you selected that category. 23 | 24 | Question: {question} 25 | Answer: {answer} 26 | 27 | Please format your response following this response format and make sure it is parsable by JSON: 28 | { 29 | "category": , # Category name 30 | "reason": # Brief statement on why you picked 31 | } 32 | """ 33 | 34 | DIFFICULTY_PROMPT = """ 35 | You are a CRISPR expert assistant. 
I will give you a CRISPR-related question and answer, and you will assign it to exactly one of the following three predefined difficulty levels. 36 | 37 | Here are the categories: 38 | 1. Easy 39 | 2. Medium 40 | 3. Hard 41 | 42 | For each question, provide your difficulty assignment with a brief (1-2 sentence) explanation of why you selected that level. 43 | 44 | Question: {question} 45 | Answer: {answer} 46 | 47 | Please format your response following this response format and make sure it is parsable by JSON: 48 | { 49 | "difficulty": , # Difficulty name 50 | "reason": # Brief statement on why you picked 51 | } 52 | """ 53 | 54 | def main(): 55 | # Path to the test data file 56 | test_data_path = "MCQs_Genome-Bench-evaluation.json" 57 | 58 | # Load test data 59 | data = load_test_data(test_data_path) 60 | 61 | # Process each entry 62 | results = [] 63 | 64 | # Variables to track stats 65 | category_counts = {} 66 | difficulty_counts = {} 67 | category_mismatches = 0 68 | difficulty_mismatches = 0 69 | entries_with_original_category = 0 70 | entries_with_original_difficulty = 0 71 | 72 | for i, entry in enumerate(data): 73 | print(f"Processing entry {i+1}/{len(data)}: ID {entry['id']}") 74 | 75 | # Extract question and answer 76 | question = entry['question'] 77 | answer = entry['answer'] 78 | entry_id = entry['id'] 79 | 80 | # Store original values if they exist 81 | original_category = entry.get('question type') 82 | original_difficulty = entry.get('difficulty') 83 | 84 | if original_category: 85 | entries_with_original_category += 1 86 | if original_difficulty: 87 | entries_with_original_difficulty += 1 88 | 89 | # Generate new category 90 | try: 91 | category_prompt = CATEGORY_PROMPT.format(question=question, answer=answer) 92 | category_response = LLMChat.chat(category_prompt, model_name="gpt4o") 93 | category = category_response.get('category') 94 | print(f" - Generated category: {category}") 95 | 96 | if original_category and category != original_category: 
97 | category_mismatches += 1 98 | print(f" - MISMATCH: Original category was '{original_category}'") 99 | except Exception as e: 100 | print(f" - Error generating category: {str(e)}") 101 | category = None 102 | 103 | # Generate new difficulty 104 | try: 105 | difficulty_prompt = DIFFICULTY_PROMPT.format(question=question, answer=answer) 106 | difficulty_response = LLMChat.chat(difficulty_prompt, model_name="gpt4o") 107 | difficulty = difficulty_response.get('difficulty') 108 | print(f" - Generated difficulty: {difficulty}") 109 | 110 | if original_difficulty and difficulty != original_difficulty: 111 | difficulty_mismatches += 1 112 | print(f" - MISMATCH: Original difficulty was '{original_difficulty}'") 113 | except Exception as e: 114 | print(f" - Error generating difficulty: {str(e)}") 115 | difficulty = None 116 | 117 | # Update counts 118 | if category: 119 | category_counts[category] = category_counts.get(category, 0) + 1 120 | if difficulty: 121 | difficulty_counts[difficulty] = difficulty_counts.get(difficulty, 0) + 1 122 | 123 | # Add to results 124 | results.append({ 125 | 'id': entry_id, 126 | 'question': question, 127 | 'answer': answer, 128 | 'category': category, 129 | 'difficulty': difficulty 130 | }) 131 | 132 | # Print summary statistics 133 | print("\n=== SUMMARY STATISTICS ===") 134 | print("\nCategory Counts:") 135 | for cat, count in category_counts.items(): 136 | print(f" {cat}: {count}") 137 | 138 | print("\nDifficulty Counts:") 139 | for diff, count in difficulty_counts.items(): 140 | print(f" {diff}: {count}") 141 | 142 | 143 | # Save results to a new file 144 | output_file = "Genome-Bench-evaluation.json" 145 | with open(output_file, 'w') as f: 146 | json.dump(results, f, indent=2) 147 | 148 | print(f"\nProcessed {len(results)} entries and saved to {output_file}") 149 | 150 | if __name__ == "__main__": 151 | main() -------------------------------------------------------------------------------- /training/rl_training.py: 
-------------------------------------------------------------------------------- 1 | ####### ---------------- 2 | ## GRPO Training 3 | ####### ---------------- 4 | 5 | 6 | import os 7 | import re 8 | import torch 9 | from datasets import load_dataset, Dataset, load_from_disk 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from trl.trainer import GRPOConfig, GRPOTrainer 12 | import wandb 13 | 14 | 15 | # ========================== SYSTEM PROMPT ========================== 16 | R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a single-choice Multiple Choice question, and the Assistant solves it. Please answer the multiple choice question by selecting only one from optiona a., option b., option c., option d., option e.. 17 | The assistant first thinks about the explanation process in the mind and then provides the user 18 | with the answer. The explanation process and answer are enclosed within and 19 | tags, respectively, i.e., explanation process here 20 | answer here .""" 21 | 22 | TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single letter from a,b,c,d,e." 
23 | 24 | 25 | # ===================================================================== 26 | # Utility functions 27 | # ===================================================================== 28 | def extract_xml_answer(text: str) -> str: 29 | try: 30 | return text.split("")[-1].split("")[0].strip() 31 | except IndexError: 32 | return "" 33 | 34 | 35 | def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset: 36 | dataset = load_from_disk(dataset_name)[split] 37 | 38 | def process_batch(batch): 39 | prompts = [ 40 | [ 41 | { 42 | "role": "system", 43 | "content": R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS, 44 | }, 45 | {"role": "user", "content": q.strip()}, 46 | ] 47 | for q in batch["question"] 48 | ] 49 | 50 | return { 51 | "prompt": prompts, 52 | "answer": [extract_xml_answer(a) for a in batch["answer"]], 53 | } 54 | 55 | return dataset.map(process_batch, batched=True, batch_size=chunk_size) 56 | 57 | 58 | # ------------------------- Reward functions -------------------------- 59 | def format_reward_func(completions, **kwargs) -> list[float]: 60 | """Reward 1 pt if completion matches required XML template.""" 61 | pattern = r"^(?:(?!).)*\n(?:(?!).)*$" 62 | responses = [completion[0]["content"] for completion in completions] 63 | return [1.0 if re.match(pattern, r) else 0.0 for r in responses] 64 | 65 | 66 | def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]: 67 | """Reward 2 pt if extracted answer matches ground-truth letter.""" 68 | responses = [completion[0]["content"] for completion in completions] 69 | extracted = [extract_xml_answer(r) for r in responses] 70 | 71 | print( 72 | f"\n\n==================== DEBUG ====================\n" 73 | f"User Question:\n{prompts[0][-1]['content']}" 74 | f"\n\nCorrect Answer:\n{answer[0]}\n" 75 | f"\nFirst generated response:\n{responses[0]}" 76 | f"\nExtracted: {extracted[0]}" 77 | f"\nCorrectness flags: {''.join('Y' if r==a else 'N' for r,a in 
zip(extracted,answer))}" 78 | ) 79 | 80 | return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)] 81 | 82 | 83 | # ===================================================================== 84 | # MAIN 85 | # ===================================================================== 86 | def main(): 87 | dataset_name = "./dataset/Genome-Bench" 88 | dataset = preprocess_dataset(dataset_name, chunk_size=500) 89 | 90 | model_name = "Qwen2.5-7B-Instruct" 91 | epoch = 2 92 | learning_rate = 1e-5 93 | num_generations = 4 94 | 95 | output_dir = f"./../{model_name.split('/')[-1]}-RL" 96 | run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}" 97 | 98 | # --- memory env --- 99 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" 100 | 101 | # ------------------ TRAINING ARGS ------------------ 102 | training_args = GRPOConfig( 103 | learning_rate=learning_rate, 104 | beta=0.005, 105 | optim="adamw_torch", 106 | adam_beta1=0.9, 107 | adam_beta2=0.99, 108 | weight_decay=0.1, 109 | warmup_ratio=0.1, 110 | lr_scheduler_type="cosine", 111 | logging_steps=1, 112 | bf16=True, 113 | per_device_train_batch_size=8, 114 | num_generations=num_generations, 115 | gradient_accumulation_steps=4, 116 | max_prompt_length=256, 117 | max_completion_length=512, 118 | num_train_epochs=epoch, 119 | save_steps=100_000, 120 | max_grad_norm=0.1, 121 | report_to="wandb", 122 | output_dir=output_dir, 123 | run_name=run_name, 124 | log_on_each_node=False, 125 | ) 126 | 127 | # ------------------ Model / Tokenizer -------------- 128 | model = AutoModelForCausalLM.from_pretrained( 129 | model_name, torch_dtype=torch.bfloat16, device_map="balanced" 130 | ) 131 | 132 | tokenizer = AutoTokenizer.from_pretrained( 133 | model_name, model_max_length=training_args.max_completion_length 134 | ) 135 | tokenizer.pad_token = tokenizer.eos_token 136 | 137 | # ------------------ Trainer ------------------------ 138 | trainer = GRPOTrainer( 139 | model=model, 140 | 
processing_class=tokenizer, 141 | reward_funcs=[format_reward_func, correctness_reward_func], 142 | args=training_args, 143 | train_dataset=dataset, 144 | ) 145 | 146 | wandb.init(project="crispr_grpo", name=run_name, mode="offline") 147 | trainer.train() 148 | trainer.save_model(training_args.output_dir) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | 154 | -------------------------------------------------------------------------------- /dataset_pipeline/1_email_parse.py: -------------------------------------------------------------------------------- 1 | #######-------------------------- 2 | ## This script will modify the mbox content to 3 | ## the new_QA.json file with each item has (question,answer,context) 4 | #######-------------------------- 5 | 6 | 7 | import mailbox 8 | import csv 9 | import re 10 | 11 | def remove_quotes(text): 12 | # Regular expression to match lines that start with one or more '>' characters followed by a space 13 | clean_lines = [line for line in text.splitlines() if not re.match(r'^\s*(>+ )', line)] 14 | return '\n'.join(clean_lines) 15 | 16 | def get_email_body(message): 17 | body = None # Initialize body to None or an empty string 18 | if message.is_multipart(): 19 | for part in message.walk(): 20 | ctype = part.get_content_type() 21 | cdispo = str(part.get('Content-Disposition')) 22 | if (ctype == 'text/plain' and 'attachment' not in cdispo) or (ctype == 'text/html' and 'attachment' not in cdispo): 23 | body = part.get_payload(decode=True) 24 | body = body.decode('utf-8', errors='ignore') 25 | body = remove_quotes(body) 26 | return body # Return the cleaned body once found 27 | else: 28 | body = message.get_payload(decode=True) 29 | body = body.decode('utf-8', errors='ignore') 30 | body = remove_quotes(body) 31 | return body 32 | 33 | 34 | # Load the mbox file 35 | mbox = mailbox.mbox('.../mbox') # your local mbox address 36 | 37 | # Open a CSV file to write to 38 | with open('testemail.csv', 'w', newline='', 
encoding='utf-8') as f: 39 | writer = csv.writer(f) 40 | writer.writerow(['Message-ID', 'In-Reply-To', 'References', 'Subject', 'From', 'Date', 'To', 'Body']) 41 | 42 | for message in mbox: 43 | message_id = message['message-id'] 44 | in_reply_to = message['in-reply-to'] 45 | references = message['references'] 46 | subject = message['subject'] 47 | from_email = message['from'] 48 | date = message['date'] 49 | to_email = message['to'] 50 | body = get_email_body(message) 51 | 52 | writer.writerow([message_id, in_reply_to, references, subject, from_email, date, to_email, body]) 53 | 54 | print("Conversion completed with adjusted email body extraction!") 55 | 56 | 57 | 58 | # %% 59 | 60 | import pandas as pd 61 | # Adjust display settings 62 | pd.set_option('display.max_columns', None) # Show all columns 63 | pd.set_option('display.max_rows', 30) # Show all rows 64 | pd.set_option('display.max_colwidth', None) # Show full content of each cell 65 | pd.set_option('display.width', None) # Auto-detect the display width 66 | df = pd.read_csv('/xxxx/testemail.csv') 67 | 68 | df['Body'] = df['Body'].str.split('You received this message because you are subscribed to the Google Groups').str[0] 69 | pattern = r"On\s+(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday),\s+(January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},\s+(2000|2001|2002|2003|2004|2005|2006|2007|2008|2009|2010|2011|2012|2013|2014|2015|2016|2017|2018|2019|2020|2021|2022|2023|2024)\s+at.*?wrote:" 70 | 71 | # Split the 'body' column and keep only the part before the unwanted text 72 | df['Body'] = df['Body'].str.split(pattern, n=1, expand=True)[0] 73 | pattern = r"On\s+(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s+(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{1,2},\s+\d{4}\s+at\s+\d{1,2}:\d{2}\s+(AM|PM)\s+.*?<.*?>\s*\nwrote:" 74 | df['Body'] = df['Body'].str.split(pattern, n=1, expand=True)[0] 75 | df['Body'] = df['Body'].str.replace('\n', '', regex=False) 76 | 
df['Message-ID'] = df['Message-ID'].str.replace(' ', '', regex=False) 77 | df['Message-ID'] = df['Message-ID'].str.replace('<', '').str.replace('>', '') 78 | df['In-Reply-To'] = df['In-Reply-To'].str.replace(' ', '', regex=False) 79 | df['In-Reply-To'] = df['In-Reply-To'].str.replace('<', '').str.replace('>', '') 80 | df['References'] = df['References'].str.extractall(r'<(.*?)>').groupby(level=0).agg(list) 81 | df['References'] = df['References'].apply(lambda refs: [ref.replace(' ', '') for ref in refs] if isinstance(refs, list) else []) 82 | df.head(30) 83 | 84 | 85 | # %% 86 | 87 | import pandas as pd 88 | # Assuming 'df' is already loaded and formatted with 'References' as lists of message IDs. 89 | 90 | # Identifying root messages 91 | roots = df[df['In-Reply-To'].isna()] 92 | 93 | # Preparing a list to store threads 94 | threads = [] 95 | 96 | # Processing each root to find the thread end and reconstruct the thread 97 | for index, root in roots.iterrows(): 98 | thread_id = root['Message-ID'] 99 | thread_subject = root['Subject'] 100 | 101 | # Find messages with this root ID in their references 102 | thread_messages = df[df['References'].apply(lambda x: thread_id in x if x else False)] 103 | thread_body = '***SUBJECT: {subject}***'.format(subject=root['Subject']) 104 | 105 | if not thread_messages.empty: 106 | # Find the last message by identifying the one with the longest 'References' list 107 | last_message = thread_messages.loc[thread_messages['References'].apply(len).idxmax()] 108 | 109 | # Initialize the thread body with empty string message_ids = [root['Message-ID']] # Start with the root message ID 110 | 111 | # Iterate through the reference IDs to construct the body of the thread 112 | for msg_id in last_message['References']: 113 | # Safely retrieve the message, ensure there's at least one message that matches 114 | message = df[df['Message-ID'] == msg_id].iloc[0] if not df[df['Message-ID'] == msg_id].empty else None 115 | if message is not None: 116 | 
thread_body += ' ***NEW MESSAGE BY {sender}***: {body}'.format(sender=message['From'], body=message['Body']) 117 | message_ids.append(msg_id) # Track the message ID 118 | 119 | # Append the last message body after all referenced messages 120 | thread_body += ' ***NEW MESSAGE BY {sender}***: {body} '.format(sender=last_message['From'], body=last_message['Body']) 121 | message_ids.append(last_message['Message-ID']) # Include the last message ID 122 | 123 | # Append to the list of threads 124 | threads.append({'threadID': thread_id, 'threadSubject': thread_subject, 'Body': thread_body, 'Message-IDs': message_ids}) 125 | else: 126 | # Handle the case where no messages reference the root 127 | thread_body += ' ***NEW MESSAGE BY {sender}***: {body}'.format(sender=root['From'], body=root['Body']) 128 | message_ids = [root['Message-ID']] # Just the root message ID for standalone threads 129 | threads.append({'threadID': thread_id, 'threadSubject': thread_subject, 'Body': thread_body, 'Message-IDs': message_ids}) 130 | 131 | # Convert list to DataFrame 132 | thread_df = pd.DataFrame(threads) 133 | 134 | # Displaying the new DataFrame 135 | print(thread_df.head(5)) 136 | 137 | # %% 138 | 139 | import pandas as pd 140 | import json 141 | import openai 142 | from openai import OpenAI 143 | import re 144 | from concurrent.futures import ThreadPoolExecutor, as_completed 145 | 146 | # Initialize your OpenAI client 147 | client = OpenAI(api_key='xxxx') 148 | 149 | # Assuming 'df' is your DataFrame and it has 'threadID' and 'Body' columns 150 | df = thread_df 151 | 152 | system_prompt = "The following text represents an entire email thread that's in chronological order. It started with a ***SUBJECT: {subject}***. It then has different emails by different people, segmented by '***NEW MESSAGE BY {name} ***'. I want you to extract as many research/scientific related Q&A pairs as possible. 
You also want to have a context field that makes it so a new researcher, who is not involved in the Email thread, can read the Q&A and the context field, and then have an understanding of what is going on. It's good for answer to be a solid response to the question based on the email thread, and it'd be nice for for answer to include explanations. The context field is additionally provided to give more macro information about the whole thread. The question and the context field together can help this new researcher understand why this question is taking in place, and has a holistic understanding of the entire Email thread. Your output must be a list of maps (i.e. [{question: 'what is 1+1' answer': '2', context: 'PersonA has been asking PersonB about elementary math', questionFrom: 'daniel, Daniel@gmail.com', answerFrom: 'john, john@gmail.com'}...]. If there's no answer then just keep the answer field as empty string "". FOR Question, Answer, and Context Field, MAKE SURE TO REPLACE REAL NAMES WITH PERSON1, PERSON2... MOST IMPORTANT: REMINDER THAT YOUR REPLY MUST BE A LIST OF MAPS IN THAT FORMAT I DESCRIBED AND NOTHING ELSE." 153 | 154 | 155 | def fix_json_string(json_like_string): 156 | # Step 1: Replace single quotes around keys with double quotes 157 | corrected_json = re.sub(r"\'([a-zA-Z0-9_]+)\'\s*:", r'"\1":', json_like_string) 158 | 159 | # Step 2: Add double quotes around unquoted keys 160 | corrected_json = re.sub(r'(? str: 45 | """Extracts the answer from between ... 
tags.""" 46 | try: 47 | answer = text.split("")[-1].split("")[0].strip() 48 | return answer 49 | except IndexError: 50 | return "" 51 | 52 | ########################################################################## 53 | # New Trainer Subclass: PreGeneratedGRPOTrainer 54 | ########################################################################## 55 | class PreGeneratedGRPOTrainer(GRPOTrainer): 56 | """ 57 | A custom GRPOTrainer that uses pre‐generated completions instead of on‐the‐fly generation. 58 | 59 | Each training sample is expected to have a "pre_generated" key containing a list of 4 dictionaries. 60 | Each dictionary must include: 61 | - "content": the candidate output from the router assistant (e.g. "2") 62 | - Optionally, a detailed response under a key like "response by model 2" 63 | 64 | For training, we use all 4 candidates. We replicate the prompt for each candidate so that the effective 65 | batch size is (original batch size × num_generations). The ground‐truth answer is repeated accordingly, 66 | and later the rewards are computed and grouped per sample. 
67 | """ 68 | 69 | def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: 70 | device = self.accelerator.device 71 | prompts = [x["prompt"] for x in inputs] 72 | prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] 73 | prompt_inputs = self.processing_class( 74 | prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False 75 | ) 76 | #prompt_inputs = super()._prepare_inputs(prompt_inputs) 77 | prompt_ids, prompt_mask = prompt_inputs["input_ids"].to(device), prompt_inputs["attention_mask"].to(device) 78 | 79 | if self.max_prompt_length is not None: 80 | prompt_ids = prompt_ids[:, -self.max_prompt_length :] 81 | prompt_mask = prompt_mask[:, -self.max_prompt_length :] 82 | 83 | 84 | pre_generated_list = [x["pre_generated"] for x in inputs] # length: N; each element is list of 4 dicts 85 | candidate_texts = [] # will have length: N * 4 86 | reward_texts = [] # will have length: N * 4 87 | 88 | for pg in pre_generated_list: 89 | entry = random.choice(pg) # randomly sample one entry from pg 90 | candidate = entry["content"].strip() 91 | candidate_texts.append(candidate) 92 | 93 | digit_list = re.findall(r"\d", candidate) 94 | if not digit_list: 95 | raise ValueError(f"Candidate text '{candidate}' does not contain a digit.") 96 | digit = digit_list[0] 97 | response_key = f"response by model {digit}" 98 | reward_texts.append(entry.get(response_key, candidate)) 99 | 100 | completion_inputs = self.processing_class( 101 | text=candidate_texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False 102 | ) 103 | completion_ids = completion_inputs["input_ids"].to(device) # This is analogous to completion_ids produced by generation. 104 | completion_mask = completion_inputs["attention_mask"].to(device) # Similarly, candidate_mask. 
105 | 106 | 107 | # Construct prompt_completion_ids from pre-generated completions 108 | prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) 109 | 110 | # Mask everything after the first EOS token 111 | is_eos = (completion_ids == self.processing_class.eos_token_id).to(device) 112 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) 113 | 114 | # Make sure the mask and argmax are on the same device 115 | any_eos = is_eos.any(dim=1) 116 | argmax_eos = is_eos.int().argmax(dim=1) 117 | 118 | eos_idx[any_eos] = argmax_eos[any_eos] # no device mismatch here 119 | sequence_indices = torch.arange(is_eos.size(1), device=is_eos.device).expand(is_eos.size(0), -1) 120 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() 121 | 122 | 123 | # Concatenate prompt_mask with completion_mask for logit computation 124 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) 125 | 126 | logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens 127 | 128 | with torch.inference_mode(): 129 | if self.ref_model is not None: 130 | ref_per_token_logps = self._get_per_token_logps( 131 | self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep 132 | ) 133 | else: 134 | with self.accelerator.unwrap_model(self.model).disable_adapter(): 135 | ref_per_token_logps = self._get_per_token_logps( 136 | self.model, prompt_completion_ids, attention_mask, logits_to_keep 137 | ) 138 | 139 | 140 | # Decode the generated completions 141 | completions_text = reward_texts 142 | completions = completions_text 143 | rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) 144 | for i, (reward_func, reward_processing_class) in enumerate( 145 | zip(self.reward_funcs, self.reward_processing_classes) 146 | ): 147 | 148 | # Repeat all input columns (but "prompt" and "completion") to match the number of generations 149 | keys = [key for 
key in inputs[0] if key not in ["prompt", "completion","pre_generated"]] 150 | reward_kwargs = {key: [example[key] for example in inputs] for key in keys} 151 | output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) 152 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) 153 | 154 | # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the 155 | # completions may be distributed across processes 156 | rewards_per_func = gather(rewards_per_func) 157 | 158 | 159 | # Apply weights to each reward function's output and sum 160 | rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) 161 | 162 | # Compute grouped-wise rewards 163 | mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) 164 | std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) 165 | 166 | # Normalize the rewards to compute the advantages 167 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) 168 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) 169 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) 170 | 171 | # Slice to keep only the local part of the data 172 | process_slice = slice( 173 | self.accelerator.process_index * len(prompts), 174 | (self.accelerator.process_index + 1) * len(prompts), 175 | ) 176 | advantages = advantages[process_slice] 177 | 178 | # Log the metrics 179 | reward_per_func = rewards_per_func.mean(0) 180 | for i, reward_func in enumerate(self.reward_funcs): 181 | if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models 182 | reward_func_name = reward_func.config._name_or_path.split("/")[-1] 183 | else: 184 | reward_func_name = reward_func.__name__ 185 | 
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) 186 | 187 | self._metrics["reward"].append(rewards.mean().item()) 188 | self._metrics["reward_std"].append(std_grouped_rewards.mean().item()) 189 | 190 | if ( 191 | self.log_completions 192 | and self.state.global_step % self.args.logging_steps == 0 193 | and "wandb" in self.args.report_to 194 | ): 195 | import pandas as pd 196 | 197 | # For logging 198 | table = { 199 | "step": [str(self.state.global_step)] * len(rewards), 200 | "prompt": gather_object(prompts_text), 201 | "completion": gather_object(completions_text), 202 | "reward": rewards.tolist(), 203 | } 204 | df = pd.DataFrame(table) 205 | 206 | if wandb.run is not None and self.accelerator.is_main_process: 207 | wandb.log({"completions": wandb.Table(dataframe=df)}) 208 | 209 | return { 210 | "prompt_ids": prompt_ids, 211 | "prompt_mask": prompt_mask, 212 | "completion_ids": completion_ids, 213 | "completion_mask": completion_mask, 214 | "ref_per_token_logps": ref_per_token_logps, 215 | "advantages": advantages, 216 | } 217 | 218 | 219 | 220 | 221 | ########################################################################## 222 | # Reward Function: Correctness 223 | ########################################################################## 224 | def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]: 225 | """ 226 | Uses the detailed responses (reward_texts) to compute correctness. 227 | It extracts the answer (e.g. the letter "a") from the detailed response and compares it with the ground truth. 228 | Returns 1.0 if they match, otherwise -1.0. 
229 | """ 230 | reward_texts = kwargs.get("reward_texts", None) 231 | if reward_texts is not None: 232 | extracted_answers = [extract_xml_answer(rt) for rt in reward_texts] 233 | else: 234 | extracted_answers = [extract_xml_answer(c) for c in completions] 235 | 236 | print("\n\n===============================================================\n" 237 | f"User question (sample): {prompts[0]}\n" 238 | f"Ground truth answer: {answer[0]}\n" 239 | f"Extracted answers (from reward texts): {extracted_answers}\n") 240 | return [1.0 if r == a else -1.0 for r, a in zip(extracted_answers, answer)] 241 | 242 | ########################################################################## 243 | # Dataset Preprocessing 244 | ########################################################################## 245 | R1_STYLE_SYSTEM_PROMPT = """There are four models capable of answering single-choice multiple choice questions. You are the Router Assistant. 246 | 247 | In a conversation between the User and the Router Assistant, the User provides a single-choice multiple choice question. Your job as the Router Assistant is to suggest which of the four models is most likely to provide the best answer. 248 | 249 | You must select only one model from the following options: model 1, model 2, model 3, or model 4. 250 | 251 | Provide your recommendation enclosed within tags, for example: 1 252 | """ 253 | TASK_SPECIFIC_INSTRUCTIONS = "The choice must be a single digit: 1, 2, 3, or 4." 254 | 255 | def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset: 256 | dataset = load_from_disk(dataset_name)[split] 257 | def process_batch(batch): 258 | # Build the prompt as a list of two messages: 259 | # System instruction (with task-specific instructions) and the user question. 
260 | prompts = [ 261 | [ 262 | {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS}, 263 | {'role': 'user', 'content': q.strip()} 264 | ] 265 | for q in batch['question'] 266 | ] 267 | return { 268 | 'prompt': prompts, 269 | 'answer': [extract_xml_answer(a) for a in batch['answer']], 270 | 'pre_generated': batch['pre_generated'] 271 | } 272 | return dataset.map(process_batch, batched=True, batch_size=chunk_size) 273 | 274 | ########################################################################## 275 | # Main Training 276 | ########################################################################## 277 | def main(): 278 | dataset_name = './dataset/Genome-Bench-Router' # Genome-Bench-Router dataset contains pre-generated responses from different RL models 279 | dataset = preprocess_dataset(dataset_name, chunk_size=500) 280 | 281 | model_name = "Qwen2.5-7B-Instruct" 282 | epoch = 2 283 | learning_rate = 1e-5 284 | output_dir = f"{model_name.split('/')[-1]}-Router" 285 | run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}" 286 | 287 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' 288 | 289 | training_args = GRPOConfig( 290 | learning_rate=learning_rate, 291 | beta=0.005, 292 | optim="adamw_8bit", 293 | adam_beta1=0.9, 294 | adam_beta2=0.99, 295 | weight_decay=0.1, 296 | warmup_ratio=0.1, 297 | lr_scheduler_type='cosine', 298 | logging_steps=1, 299 | bf16=True, 300 | per_device_train_batch_size=8, 301 | num_generations=8, # Use 4 candidate generations per sample 302 | gradient_accumulation_steps=4, 303 | max_prompt_length=256, 304 | max_completion_length=512, 305 | num_train_epochs=epoch, 306 | save_steps=100000, 307 | max_grad_norm=0.1, 308 | report_to="wandb", 309 | output_dir=output_dir, 310 | run_name=run_name, 311 | log_on_each_node=False, 312 | ) 313 | 314 | model = AutoModelForCausalLM.from_pretrained( 315 | model_name, 316 | torch_dtype=torch.bfloat16, 317 | device_map="auto", 318 | ) 
319 | 320 | tokenizer = AutoTokenizer.from_pretrained( 321 | model_name, 322 | model_max_length=training_args.max_completion_length, 323 | ) 324 | tokenizer.pad_token = tokenizer.eos_token 325 | 326 | # Instantiate the custom trainer. 327 | trainer = PreGeneratedGRPOTrainer( 328 | model=model, 329 | processing_class=tokenizer, 330 | reward_funcs=[correctness_reward_func], 331 | args=training_args, 332 | train_dataset=dataset, 333 | ) 334 | 335 | wandb.init(project="crispr_grpo_router", name=run_name, mode="offline") 336 | trainer.train() 337 | trainer.save_model(training_args.output_dir) 338 | 339 | if __name__ == "__main__": 340 | main() --------------------------------------------------------------------------------