├── figure
│   ├── overview.png
│   └── example_result.png
├── dataset_pipeline
│   ├── 1_email_parse.py
│   ├── 2_convert_MCQ_full.py
│   ├── 3_dataset_prepare.py
│   └── 4_convert_natural_question.py
├── difficulty_and_category_assignment
│   └── category_difficulty.py
├── requirements.txt
├── README.md
├── evaluation
│   └── genome-bench_eval.py
└── training
    ├── sft_training.py
    ├── rl_training.py
    └── rl_router_training.py
/figure/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/figure/overview.png
--------------------------------------------------------------------------------
/figure/example_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyin0312/RL4GenomeBench/HEAD/figure/example_result.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | datasets
3 | openai
4 | packaging
5 | pandas
6 | torch
7 | tqdm
8 | transformers
9 | trl
10 | vllm
11 | wandb
--------------------------------------------------------------------------------
/dataset_pipeline/3_dataset_prepare.py:
--------------------------------------------------------------------------------
1 | #######--------------------------
2 | ## This script will use Converted_MCQs_gpt4o_full.json file
3 | ## as the input and convert it to MCQs_formatted.json,
4 | ## where answer key has the format
5 | ## {original expert explanation} {correct option}.
6 | ## It also concatenates 'context' with 'question' in question field.
7 | #######--------------------------
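## Illustrative sketch (hypothetical values, not taken from the repository data):
## an input entry of Converted_MCQs_gpt4o_full.json such as
##   {"Questions with options": "Which enzyme ...? a. ... b. ...", "Answer key": "b",
##    "Original answer": "Person2 explains that ...", "Context": "Person1 asks about ..."}
## is written to MCQs_formatted.json as
##   {"question": "Question context: Person1 asks about ... Question: Which enzyme ...? a. ... b. ...",
##    "answer": "Person2 explains that ... b"}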
8 |
9 |
10 | import json
11 |
12 | # Load the original dataset
13 | with open("Converted_MCQs_gpt4o_full.json", "r") as file:
14 | data = json.load(file)
15 |
16 | # Format the answer as "{original expert explanation} {correct option}"
17 | def format_answer(original_answer, final_answer):
18 | return f"{original_answer} {final_answer}"
19 |
20 | # Convert to GENOME-BENCH format
21 | converted_data = []
22 | for entry in data:
23 | question = entry["Questions with options"]
24 | reasoning = entry.get("Original answer", "The answer is derived from the given choices.") # Placeholder reasoning
25 | final_answer = entry.get("Answer key")
26 | context = entry["Context"]
27 |
28 |
29 | # Ensure proper formatting
30 | formatted_answer = format_answer(reasoning, final_answer)
31 |
32 | converted_data.append({
33 | "question": 'Question context: ' + context + ' Question: ' + question,
34 | "answer": formatted_answer
35 | })
36 |
37 | # Save the new dataset
38 | with open("MCQs_formatted.json", "w") as outfile:
39 | json.dump(converted_data, outfile, indent=4)
40 |
41 | print("Dataset reformatted to GENOME-BENCH format and saved as MCQs_formatted.json")
--------------------------------------------------------------------------------
/dataset_pipeline/4_convert_natural_question.py:
--------------------------------------------------------------------------------
1 | ####-----------------------------------------------------------------
2 | #### Converting "Question + Option" questions to natural questions
3 | ####-----------------------------------------------------------------
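#### Illustrative sketch (hypothetical values, not taken from the repository data):
#### a structured entry of MCQs_formatted.json such as
####   "Question context: Person1 asks about Cas9 delivery. Question: Which delivery method ...? a. ... b. ..."
#### is rewritten by gpt-4o into a single natural-sounding question, e.g.
####   "Which delivery method would you recommend for getting Cas9 into primary T cells?"
#### Only the rewritten question ("natural_question") is kept in MCQs_Genome-Bench.json.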
4 |
5 | import json
6 | import openai
7 | from tqdm import tqdm
8 |
9 | # Set your OpenAI API key
10 | openai.api_key = "xxxx"
11 |
12 | # Load the JSON file
13 | input_file = "MCQs_formatted.json"
14 | output_file = "MCQs_Genome-Bench.json"
15 |
16 | with open(input_file, "r", encoding="utf-8") as f:
17 | questions_data = json.load(f)
18 |
19 | # Function to transform structured question into a natural-sounding question
20 | def generate_natural_question(question_context, question_text):
21 | prompt = (f"Convert the following structured question into a natural-sounding question while preserving the "
22 | f"maximal amount of information:\n\n"
23 | f"Context: {question_context}\n"
24 | f"Question: {question_text}\n\n"
25 | f"The output should be a single natural question that integrates the context smoothly without phrases like "
26 | f"'Person A' or 'Person B'.")
27 |
28 | try:
29 | response = openai.chat.completions.create(
30 | model="gpt-4o",
31 | messages=[{"role": "system", "content": "You are an AI that rewrites structured questions into natural-sounding questions while preserving information."},
32 | {"role": "user", "content": prompt}],
33 | max_tokens=1024,
34 | temperature=0.7
35 | )
36 |
37 | return response.choices[0].message.content
38 | except Exception as e:
39 | print(f"Error generating question: {e}")
40 | return None
41 |
42 |
43 | # Process each question in the JSON file
44 | processed_questions = []
45 | for item in tqdm(questions_data, desc="Processing Questions", unit="question"):
46 | structured_question = item["question"]
47 |
48 | # Extract context and question text
49 | try:
50 | context_start = structured_question.find("Question context: ") + len("Question context: ")
51 |
52 | question_start = structured_question.find("Question: ")
53 |
54 | context_text = structured_question[context_start:question_start].strip()
55 |
56 | question_text = structured_question[question_start + len("Question: "):].strip()
57 |
58 |
59 | # Generate natural question
60 | natural_question = generate_natural_question(context_text, question_text)
61 | if natural_question:
62 | processed_questions.append({"natural_question": natural_question})
63 | except Exception as e:
64 | print(f"Error processing question: {e}")
65 |
66 | # Save the transformed questions to a new JSON file
67 | with open(output_file, "w", encoding="utf-8") as f:
68 | json.dump(processed_questions, f, indent=4, ensure_ascii=False)
69 |
70 | print(f"Processed questions saved to {output_file}")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Toward Scientific Reasoning in LLMs: Training from Expert Discussions via Reinforcement Learning
4 |
5 |
6 | 🧬 A New Benchmark Genome-Bench and RL Fine-Tuning for Scientific Reasoning 📊
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | [📄 Paper (arXiv:2505.19501)](https://arxiv.org/abs/2505.19501)
16 | [💻 Code (GitHub)](https://github.com/mingyin0312/RL4GenomeBench)
17 | [🤗 Dataset (Hugging Face)](https://huggingface.co/datasets/Mingyin0312/Genome-Bench)
18 |
19 |
20 |
21 |
22 | ## Overview
23 |
24 | ![Overview](figure/overview.png)
25 |
26 | We introduce **Genome-Bench**, a novel benchmark for evaluating and improving scientific reasoning in large language models. Genome-Bench consists of over 3,000 multiple-choice and QA items derived from CRISPR-related scientific discussions and forum threads, covering key topics in genome engineering, experimental design, and error analysis.
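Each benchmark record pairs a context-plus-question string (five options, a-e) with an answer that keeps the expert's explanation followed by the correct letter. A minimal sketch of one record, with hypothetical content, mirroring the format written by `dataset_pipeline/3_dataset_prepare.py`:

```json
{
  "question": "Question context: Person1 asks Person2 how to verify knockout clones. Question: Which readout best confirms a successful knockout? a. ... b. ... c. ... d. ... e. ...",
  "answer": "Person2 recommends sequencing the targeted locus and confirming protein loss by Western blot. b"
}
```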
27 |
28 | Our RL training pipeline (based on Group Relative Policy Optimization) improves model performance across expert-labeled evaluation sets. For example, our fine-tuned Qwen2.5-7B model exceeds GPT-4o in accuracy and consistency on multi-hop reasoning tasks.
29 |
30 | ![Example results](figure/example_result.png)
31 |
32 | ---
33 |
34 | ## Getting Started 🎯
35 |
36 | ### Installation
37 |
38 | ```bash
39 | git clone https://github.com/mingyin0312/RL4GenomeBench.git
40 | cd RL4GenomeBench
41 | pip install -r requirements.txt
42 | ```
43 |
44 |
45 | ### Dataset Preparation
46 |
47 | We provide tools to parse .mbox email archives and convert them into standardized MCQ and QA formats.
48 |
49 | ```bash
50 | cd dataset_pipeline
51 | python 1_email_parse.py
52 | python 2_convert_MCQ_full.py
53 | python 3_dataset_prepare.py
54 | python 4_convert_natural_question.py
55 | ```
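Each stage reads the previous stage's output. The intermediate files referenced inside the scripts are:

```text
mbox archive                    -> testemail.csv, new_QA.json         (1_email_parse.py)
new_QA.json + MCQ_updated.json  -> Converted_MCQs_gpt4o_full.json     (2_convert_MCQ_full.py)
Converted_MCQs_gpt4o_full.json  -> MCQs_formatted.json                (3_dataset_prepare.py)
MCQs_formatted.json             -> MCQs_Genome-Bench.json             (4_convert_natural_question.py)
```

Steps 2 and 4 call the OpenAI API (gpt-4o), so set your API key inside those scripts before running them.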
56 |
57 | ## Training
58 |
59 | ### Reinforcement Fine-tuning (GRPO)
60 |
61 | ```bash
62 | python training/rl_training.py
63 | ```
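`training/rl_training.py` scores each sampled completion with two reward functions: a format reward for following the `<think>`/`<answer>` template and a correctness reward for the extracted option letter. A condensed sketch of that reward logic (simplified from the script, not a drop-in replacement):

```python
import re

def extract_xml_answer(text: str) -> str:
    # Take the text after the last <answer> tag and before the following </answer> tag.
    return text.split("<answer>")[-1].split("</answer>")[0].strip()

def format_reward(completion: str) -> float:
    # 1.0 if the completion follows "<think>...</think>\n<answer>...</answer>", else 0.0.
    pattern = r"^<think>.*</think>\n<answer>.*</answer>$"
    return 1.0 if re.match(pattern, completion, flags=re.DOTALL) else 0.0

def correctness_reward(completion: str, gold_letter: str) -> float:
    # 2.0 if the extracted option letter matches the ground-truth letter, else 0.0.
    return 2.0 if extract_xml_answer(completion) == gold_letter else 0.0
```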
64 |
65 | ### Supervised Fine-Tuning (SFT)
66 |
67 | ```bash
68 | python training/sft_training.py
69 | ```
70 |
71 | ### Multi-Agent RL Routing
72 |
73 | ```bash
74 | python training/rl_router_training.py
75 | ```
76 |
77 | ## Evaluation
78 |
79 | To evaluate on the Genome-Bench test data:
80 |
81 | ```bash
82 | python evaluation/genome-bench_eval.py
83 | ```
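The script hard-codes the checkpoint (`Qwen2.5-7B-Instruct`) and the dataset path (`./dataset/Genome-Bench`) near the top, so edit those to evaluate a different model. It prints running accuracy and writes one record per test question to `<model-name>_evaluation_results.json`, with the fields assembled in the evaluation loop (placeholder values shown):

```json
{
  "Question": "...",
  "Generated Answer": "b",
  "Correct Answer": "b",
  "Full Response": "<think> ... </think>\n<answer> b </answer>",
  "Correct": true
}
```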
84 |
85 | ## Citation
86 |
87 | ```bibtex
88 | @article{yin2025genome,
89 | title={Toward Scientific Reasoning in LLMs: Training from Expert Discussions via Reinforcement Learning},
90 |   author={Yin, Ming and Qu, Yuanhao and Yang, Ling and Cong, Le and Wang, Mengdi},
91 | journal={arXiv preprint arXiv:2505.19501},
92 | year={2025}
93 | }
94 | ```
95 |
96 | ## Acknowledgement
97 |
98 | This project leverages the 🤗 [Transformers Reinforcement Learning (TRL)](https://github.com/huggingface/trl) library, which provides powerful tools for fine-tuning large language models with reinforcement learning techniques such as GRPO.
99 |
100 |
--------------------------------------------------------------------------------
/dataset_pipeline/2_convert_MCQ_full.py:
--------------------------------------------------------------------------------
1 | #######--------------------------
2 | ## This script will use new_QA.json file
3 | ## as the input and use gpt-4o to convert it to
4 | ## multiple-choice question data named
5 | ## Converted_MCQs_gpt4o_full.json
6 | #######--------------------------
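## Illustrative sketch (hypothetical values, not taken from the repository data):
## each new_QA.json item has the keys {"question": ..., "answer": ..., "context": ...};
## for every item with a non-empty answer, gpt-4o is prompted with few-shot examples from
## MCQ_updated.json and the result is stored as
##   {"Questions with options": "<question> a. ... b. ... c. ... d. ... e. ...",
##    "Answer key": "b", "Original answer": "<answer text>", "Context": "<context text>"}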
7 |
8 | import json
9 | import openai
10 | from tqdm import tqdm
11 |
12 | # Set up your OpenAI API key
13 | openai.api_key = "xxxx"
14 | # Load the MCQ examples for prompting
15 | with open("MCQ_updated.json", "r") as f:
16 | mcq_examples = json.load(f)["MCQ_Updated"]
17 |
18 | # Load the QA data
19 | with open("new_QA.json", "r") as f:
20 | final_qa_data = json.load(f)
21 |
22 | # Prepare the examples prompt
23 | example_prompts = "\n\n".join(
24 | f"Example MCQ:\n{mcq['Questions with options']}\nCorrect Answer: {mcq['Key']}"
25 | for mcq in mcq_examples
26 | )
27 |
28 | # Initialize a list to store generated MCQs
29 | generated_mcqs = []
30 |
31 | # Process each question in the QA data
32 | for item in tqdm(final_qa_data, total = len(final_qa_data)):
33 |
34 | question = item["question"]
35 | answer = item["answer"]
36 | context = item["context"]
37 |
38 | if answer.strip() == "":
39 | continue
40 | else:
41 | # Prepare the prompt
42 | prompt = f"""
43 | Below are examples of multiple-choice questions (MCQs) with their formats. Use them to generate a new single-choice MCQ based on the provided question and context. The MCQ should include five answer choices (a-e) with one being the correct answer, and the question itself should remain identical to the one provided, following the same format as the examples. The response should not include anything except the generated multiple-choice question itself. Provide only one version, then end the response without repeating.
44 |
45 | {example_prompts}
46 |
47 | New Question:
48 | Question: {question}
49 | Answer: {answer}
50 | Context: {context}
51 |
52 | Do not modify the original question in the new MCQ.
53 |
54 | Generate a new MCQ:
55 | """
56 | #If the above question is not clear, you can also modify the question by using 'context' information for new questions. However, do not use 'answer' information for it!
57 |
58 | # Generate MCQ using the OpenAI API
59 | response = openai.chat.completions.create(
60 | model="gpt-4o",
61 | messages=[
62 | {"role": "system", "content": "You are an gene-editing expert generating formatted single-choice multiple-choice questions."},
63 | {"role": "user", "content": prompt}
64 | ],
65 | max_tokens=1024,
66 | temperature=0.7
67 | )
68 |
69 | # Extract the generated text
70 | generated_text = response.choices[0].message.content
71 |
72 | # Process the response to separate options and answer key
73 | index = generated_text.find("Correct Answer:")
74 | if index != -1:
75 | options = generated_text[:index].strip()
76 | answer_key = generated_text[index+16:].strip()
77 | else:
78 | options = generated_text
79 | answer_key = "N/A" # Default if no answer key is found
80 |
81 | # Store the result
82 | generated_mcqs.append({
83 | "Questions with options": question + " " + options,
84 | "Answer key": answer_key,
85 | "Original answer": answer,
86 | "Context": context
87 | })
88 |
89 |
90 | # Save the generated MCQs to a new JSON file
91 | output_file = "Converted_MCQs_gpt4o_full.json"
92 | with open(output_file, "w") as f:
93 | json.dump(generated_mcqs, f, indent=4)
94 |
95 | print(f"Generated MCQs saved to {output_file}")
--------------------------------------------------------------------------------
/evaluation/genome-bench_eval.py:
--------------------------------------------------------------------------------
1 | ####-----------------------
2 | #### Formal evaluation
3 | ####-----------------------
4 |
5 | import json
6 | import transformers
7 | import torch
8 | from datasets import load_dataset, Dataset, load_from_disk
9 | from transformers import AutoTokenizer
10 | from vllm import LLM, SamplingParams
11 | from tqdm import tqdm
12 | import openai
13 |
14 |
15 | #### Load data
16 | dataset_name = './dataset/Genome-Bench'
17 | dataset_loaded = load_from_disk(dataset_name)['test']
18 | # Convert to the desired list format with the name 'questions'
19 | questions = [{"question": q, "answer": a} for q, a in zip(dataset_loaded["question"], dataset_loaded["answer"])]
20 |
21 |
22 | # System Prompt
23 | R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a single-choice Multiple Choice question, and the Assistant solves it. Please answer the multiple choice question by selecting only one from option a., option b., option c., option d., option e..
24 | The assistant first thinks about the explanation process in the mind and then provides the user
25 | with the answer. The explanation process and answer are enclosed within <think> </think> and <answer> </answer>
26 | tags, respectively, i.e., <think> explanation process here </think>
27 | <answer> answer here </answer>."""
28 |
29 |
30 | TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single letter from a,b,c,d,e."
31 |
32 | def extract_xml_answer(text: str) -> str:
33 | """Extracts the answer from a response using the tag."""
34 | try:
35 | answer = text.split("")[-1].split("")[0].strip()
36 | return answer
37 | except IndexError:
38 | return ""
39 |
40 |
41 | # Initialize the pipeline with Qwen2.5-7B-Instruct
42 | name = 'Qwen2.5-7B-Instruct'
43 | model_id = name # Replace with your model ID
44 |
45 |
46 | llm = LLM(model=model_id, dtype="half", max_model_len=1024, device="cuda:0")
47 |
48 | # Load tokenizer
49 | tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=1024, padding_side='right')
50 |
51 | import random
52 | random_number = random.randint(1, 10000)
53 |
54 | # Set sampling parameters
55 | sampling_params = SamplingParams(
56 | temperature=0.7,
57 | max_tokens=1024,
58 | stop_token_ids=[tokenizer.eos_token_id],
59 | seed = random_number,
60 | )
61 |
62 |
63 | BATCH_SIZE = 8
64 | # Evaluate questions in batches
65 | results = []
66 | correct = 0
67 | total = 0
68 |
69 | # Progress bar
70 | progress_bar = tqdm(total=len(questions), desc="Processing", unit="examples", dynamic_ncols=True)
71 |
72 | for i in range(0, len(questions), BATCH_SIZE):
73 | batch_data = questions[i:i + BATCH_SIZE]
74 |
75 | # Prepare prompts using few-shot learning
76 | prompts = [
77 | [
78 | {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
79 | {"role": "user", "content": q["question"]},
80 | ]
81 |
82 | for q in batch_data
83 | ]
84 |
85 | # Convert prompts to formatted strings
86 | formatted_prompts = [
87 | tokenizer.apply_chat_template(p, tokenize=False, add_generation_prompt=True)
88 | for p in prompts
89 | ]
90 |
91 | # Generate responses using vLLM
92 | outputs = llm.generate(formatted_prompts, sampling_params)
93 |
94 | # Process responses
95 | for j, output in enumerate(outputs):
96 | response = output.outputs[0].text
97 | generated_answer = extract_xml_answer(response)
98 | true_answer = extract_xml_answer(batch_data[j]["answer"])
99 |
100 | # Store the result
101 | result = {
102 | "Question": batch_data[j]["question"],
103 | "Generated Answer": generated_answer,
104 | "Correct Answer": true_answer,
105 | "Full Response": response,
106 | "Correct": generated_answer == true_answer
107 | }
108 | results.append(result)
109 |
110 | if generated_answer == true_answer:
111 | correct += 1
112 | total += 1
113 |
114 | # Update progress bar
115 | progress_bar.update(len(batch_data))
116 | progress_bar.set_postfix({
117 | "Accuracy": f"{(correct / total) * 100:.2f}%",
118 | "Correct": f"{correct}/{total}",
119 | })
120 |
121 | progress_bar.close()
122 |
123 |
124 | # Save the results to a JSON file
125 | output_file = name + '_evaluation_results.json'
126 | with open(output_file, 'w') as file:
127 | json.dump(results, file, indent=4)
128 |
129 | print(f"Evaluation complete. Results saved to {output_file}.")
130 |
--------------------------------------------------------------------------------
/training/sft_training.py:
--------------------------------------------------------------------------------
1 | ####### ----------------
2 | ## SFT Training
3 | ####### ----------------
4 |
5 |
6 | import os
7 | import re
8 | import torch
9 | from datasets import load_from_disk, Dataset
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | from trl import SFTTrainer, SFTConfig
12 | import wandb
13 |
14 | # System prompt and task instructions
15 | R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a single-choice Multiple Choice question, and the Assistant solves it. Please answer the multiple choice question by selecting only one from option a., option b., option c., option d., option e..
16 | The assistant first thinks about the explanation process in the mind and then provides the user
17 | with the answer. The explanation process and answer are enclosed within <think> </think> and <answer> </answer>
18 | tags, respectively, i.e., <think> explanation process here </think>
19 | <answer> answer here </answer>."""
20 |
21 | TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single letter from a, b, c, d, e."
22 |
23 | EXAMPLE = "Question context: PersonA, a novice in the field of selecting mouse ES cells for screening, raises several detailed questions regarding colony selection and verification processes. PersonB provides thorough answers focused on optimizing the selection and genetic verification processes in embryonic stem cells. Question: Would anyone recommendation colony picking versus 96-well dilution for screening colonies? Please choose one of the following options: a. \"Colony picking is preferred for mouse ES cells because they grow as perfect colonies that are easy to pick\" b. \"96-well dilution is more efficient because it allows for high-throughput screening of colonies\" c. \"Colony picking is more labor-intensive and should be avoided when possible\" d. \"96-well dilution is the best method for ensuring genetic consistency across colonies\" e. \"Both methods are equally effective, so the choice depends on available resources\""+"For mouse ES cells there's no reason to do limited dilution, as they grow as perfect colonies that are easy to pick.\na"
24 |
25 | def extract_hash_answer(text: str) -> str:
26 | try:
27 | explanation, answer = text.split("####", 1)
28 | return f" {explanation.strip()} {answer.strip()} "
29 | except ValueError:
30 | return ""
31 |
32 | def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset:
33 | """
34 | Load the dataset from disk and process each batch to generate chat-style prompts.
35 | The resulting dataset will have a "text" field (a string) and an "answer" field.
36 | """
37 | dataset = load_from_disk(dataset_name)[split]
38 |
39 | def process_batch(batch):
40 | chats = [
41 | f"System: {R1_STYLE_SYSTEM_PROMPT}\n{TASK_SPECIFIC_INSTRUCTIONS}\n{EXAMPLE}\nUser: {q.strip()}"
42 | for q in batch['question']
43 | ]
44 | return {
45 | 'text': chats, # Ensure it's a list of strings, not a list of dictionaries
46 | 'answer': [extract_hash_answer(a) for a in batch['answer']]
47 | }
48 |
49 | return dataset.map(process_batch, batched=True, batch_size=chunk_size)
50 |
51 | def main():
52 | # Load and preprocess the dataset
53 | dataset_name = './dataset/Genome-Bench'
54 | dataset = preprocess_dataset(dataset_name, chunk_size=500)
55 |
56 | learning_rate = 1e-5
57 |
58 | epoch = 2
59 |
60 | # Define model and output paths
61 | model_name = "Qwen2.5-7B-Instruct"
62 | output_dir = f"{model_name.split('/')[-1]}-SFT"
63 | run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}"
64 |
65 | # Set memory-related environment variable
66 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
67 |
68 | # Create SFT training configuration
69 | training_args = SFTConfig(
70 | learning_rate = learning_rate,
71 | logging_steps=1,
72 | bf16=True,
73 | per_device_train_batch_size = 4,
74 | gradient_accumulation_steps = 4,
75 | num_train_epochs=epoch,
76 | max_grad_norm=0.1,
77 | report_to="wandb",
78 | output_dir=output_dir,
79 | run_name=run_name,
80 | log_on_each_node=False,
81 | )
82 |
83 | # Load the model
84 | model = AutoModelForCausalLM.from_pretrained(
85 | model_name,
86 | torch_dtype=torch.bfloat16,
87 | device_map="auto",
88 | )
89 |
90 | # Load the tokenizer and set pad token
91 | tokenizer = AutoTokenizer.from_pretrained(
92 | model_name,
93 | model_max_length=512,
94 | )
95 | tokenizer.pad_token = tokenizer.eos_token
96 |
97 | # Initialize the SFT trainer using the tokenizer as the processing class
98 | trainer = SFTTrainer(
99 | model=model,
100 | processing_class=tokenizer,
101 | train_dataset=dataset,
102 | args=training_args,
103 | )
104 |
105 | # Initialize wandb in offline mode for experiment tracking
106 | wandb.init(project="crispr_sft", name=run_name, mode="offline")
107 | trainer.train()
108 | trainer.save_model(training_args.output_dir)
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
113 |
114 |
--------------------------------------------------------------------------------
/difficulty_and_category_assignment/category_difficulty.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from llm import LLMChat
4 |
5 | # Load the test data
6 | def load_test_data(file_path):
7 | with open(file_path, 'r') as f:
8 | return json.load(f)
9 |
10 | # Prompt templates for category and difficulty classification
11 | CATEGORY_PROMPT = """
12 | You are a CRISPR expert assistant. I will give you a CRISPR-related question and answer, and you will assign it to exactly one of the following seven predefined categories based on its core intent. Be precise—choose the category that most closely reflects the primary focus of the question.
13 | Here are the categories:
14 | 1. Gene-editing Enzyme Selection
15 | 2. GuideRNA Design
16 | 3. Cloning & Plasmid Construction
17 | 4. Gene-editing Delivery Methods
18 | 5. CRISPR Screening & Library Workflows
19 | 6. Validation, Troubleshooting & Optimization
20 | 7. Practical Considerations & Lab Logistics
21 |
22 | For each question, provide your category assignment with a brief (1-2 sentence) explanation of why you selected that category.
23 |
24 | Question: {question}
25 | Answer: {answer}
26 |
27 | Please format your response following this response format and make sure it is parsable by JSON:
28 | {{
29 | "category": "<category name>",
30 | "reason": "<brief statement on why you picked this category>"
31 | }}
32 | """
33 |
34 | DIFFICULTY_PROMPT = """
35 | You are a CRISPR expert assistant. I will give you a CRISPR-related question and answer, and you will assign it to exactly one of the following three predefined difficulty levels.
36 |
37 | Here are the categories:
38 | 1. Easy
39 | 2. Medium
40 | 3. Hard
41 |
42 | For each question, provide your difficulty assignment with a brief (1-2 sentence) explanation of why you selected that level.
43 |
44 | Question: {question}
45 | Answer: {answer}
46 |
47 | Please format your response following this response format and make sure it is parsable by JSON:
48 | {{
49 | "difficulty": "<difficulty level>",
50 | "reason": "<brief statement on why you picked this level>"
51 | }}
52 | """
53 |
54 | def main():
55 | # Path to the test data file
56 | test_data_path = "MCQs_Genome-Bench-evaluation.json"
57 |
58 | # Load test data
59 | data = load_test_data(test_data_path)
60 |
61 | # Process each entry
62 | results = []
63 |
64 | # Variables to track stats
65 | category_counts = {}
66 | difficulty_counts = {}
67 | category_mismatches = 0
68 | difficulty_mismatches = 0
69 | entries_with_original_category = 0
70 | entries_with_original_difficulty = 0
71 |
72 | for i, entry in enumerate(data):
73 | print(f"Processing entry {i+1}/{len(data)}: ID {entry['id']}")
74 |
75 | # Extract question and answer
76 | question = entry['question']
77 | answer = entry['answer']
78 | entry_id = entry['id']
79 |
80 | # Store original values if they exist
81 | original_category = entry.get('question type')
82 | original_difficulty = entry.get('difficulty')
83 |
84 | if original_category:
85 | entries_with_original_category += 1
86 | if original_difficulty:
87 | entries_with_original_difficulty += 1
88 |
89 | # Generate new category
90 | try:
91 | category_prompt = CATEGORY_PROMPT.format(question=question, answer=answer)
92 | category_response = LLMChat.chat(category_prompt, model_name="gpt4o")
93 | category = category_response.get('category')
94 | print(f" - Generated category: {category}")
95 |
96 | if original_category and category != original_category:
97 | category_mismatches += 1
98 | print(f" - MISMATCH: Original category was '{original_category}'")
99 | except Exception as e:
100 | print(f" - Error generating category: {str(e)}")
101 | category = None
102 |
103 | # Generate new difficulty
104 | try:
105 | difficulty_prompt = DIFFICULTY_PROMPT.format(question=question, answer=answer)
106 | difficulty_response = LLMChat.chat(difficulty_prompt, model_name="gpt4o")
107 | difficulty = difficulty_response.get('difficulty')
108 | print(f" - Generated difficulty: {difficulty}")
109 |
110 | if original_difficulty and difficulty != original_difficulty:
111 | difficulty_mismatches += 1
112 | print(f" - MISMATCH: Original difficulty was '{original_difficulty}'")
113 | except Exception as e:
114 | print(f" - Error generating difficulty: {str(e)}")
115 | difficulty = None
116 |
117 | # Update counts
118 | if category:
119 | category_counts[category] = category_counts.get(category, 0) + 1
120 | if difficulty:
121 | difficulty_counts[difficulty] = difficulty_counts.get(difficulty, 0) + 1
122 |
123 | # Add to results
124 | results.append({
125 | 'id': entry_id,
126 | 'question': question,
127 | 'answer': answer,
128 | 'category': category,
129 | 'difficulty': difficulty
130 | })
131 |
132 | # Print summary statistics
133 | print("\n=== SUMMARY STATISTICS ===")
134 | print("\nCategory Counts:")
135 | for cat, count in category_counts.items():
136 | print(f" {cat}: {count}")
137 |
138 | print("\nDifficulty Counts:")
139 | for diff, count in difficulty_counts.items():
140 | print(f" {diff}: {count}")
141 |
142 |
143 | # Save results to a new file
144 | output_file = "Genome-Bench-evaluation.json"
145 | with open(output_file, 'w') as f:
146 | json.dump(results, f, indent=2)
147 |
148 | print(f"\nProcessed {len(results)} entries and saved to {output_file}")
149 |
150 | if __name__ == "__main__":
151 | main()
--------------------------------------------------------------------------------
/training/rl_training.py:
--------------------------------------------------------------------------------
1 | ####### ----------------
2 | ## GRPO Training
3 | ####### ----------------
4 |
5 |
6 | import os
7 | import re
8 | import torch
9 | from datasets import load_dataset, Dataset, load_from_disk
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | from trl.trainer import GRPOConfig, GRPOTrainer
12 | import wandb
13 |
14 |
15 | # ========================== SYSTEM PROMPT ==========================
16 | R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a single-choice Multiple Choice question, and the Assistant solves it. Please answer the multiple choice question by selecting only one from option a., option b., option c., option d., option e..
17 | The assistant first thinks about the explanation process in the mind and then provides the user
18 | with the answer. The explanation process and answer are enclosed within <think> </think> and <answer> </answer>
19 | tags, respectively, i.e., <think> explanation process here </think>
20 | <answer> answer here </answer>."""
21 |
22 | TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single letter from a,b,c,d,e."
23 |
24 |
25 | # =====================================================================
26 | # Utility functions
27 | # =====================================================================
28 | def extract_xml_answer(text: str) -> str:
29 | try:
30 | return text.split("<answer>")[-1].split("</answer>")[0].strip()
31 | except IndexError:
32 | return ""
33 |
34 |
35 | def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset:
36 | dataset = load_from_disk(dataset_name)[split]
37 |
38 | def process_batch(batch):
39 | prompts = [
40 | [
41 | {
42 | "role": "system",
43 | "content": R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS,
44 | },
45 | {"role": "user", "content": q.strip()},
46 | ]
47 | for q in batch["question"]
48 | ]
49 |
50 | return {
51 | "prompt": prompts,
52 | "answer": [extract_xml_answer(a) for a in batch["answer"]],
53 | }
54 |
55 | return dataset.map(process_batch, batched=True, batch_size=chunk_size)
56 |
57 |
58 | # ------------------------- Reward functions --------------------------
59 | def format_reward_func(completions, **kwargs) -> list[float]:
60 | """Reward 1 pt if completion matches required XML template."""
61 | pattern = r"^<think>(?:(?!</think>).)*</think>\n<answer>(?:(?!</answer>).)*</answer>$"
62 | responses = [completion[0]["content"] for completion in completions]
63 | return [1.0 if re.match(pattern, r) else 0.0 for r in responses]
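# For example (hypothetical completion), "<think> Option b matches the protocol. </think>\n<answer> b </answer>"
# matches the template and earns 1.0, while a completion missing the tags earns 0.0.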
64 |
65 |
66 | def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
67 | """Reward 2 pt if extracted answer matches ground-truth letter."""
68 | responses = [completion[0]["content"] for completion in completions]
69 | extracted = [extract_xml_answer(r) for r in responses]
70 |
71 | print(
72 | f"\n\n==================== DEBUG ====================\n"
73 | f"User Question:\n{prompts[0][-1]['content']}"
74 | f"\n\nCorrect Answer:\n{answer[0]}\n"
75 | f"\nFirst generated response:\n{responses[0]}"
76 | f"\nExtracted: {extracted[0]}"
77 | f"\nCorrectness flags: {''.join('Y' if r==a else 'N' for r,a in zip(extracted,answer))}"
78 | )
79 |
80 | return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]
81 |
82 |
83 | # =====================================================================
84 | # MAIN
85 | # =====================================================================
86 | def main():
87 | dataset_name = "./dataset/Genome-Bench"
88 | dataset = preprocess_dataset(dataset_name, chunk_size=500)
89 |
90 | model_name = "Qwen2.5-7B-Instruct"
91 | epoch = 2
92 | learning_rate = 1e-5
93 | num_generations = 4
94 |
95 | output_dir = f"./../{model_name.split('/')[-1]}-RL"
96 | run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}"
97 |
98 | # --- memory env ---
99 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
100 |
101 | # ------------------ TRAINING ARGS ------------------
102 | training_args = GRPOConfig(
103 | learning_rate=learning_rate,
104 | beta=0.005,
105 | optim="adamw_torch",
106 | adam_beta1=0.9,
107 | adam_beta2=0.99,
108 | weight_decay=0.1,
109 | warmup_ratio=0.1,
110 | lr_scheduler_type="cosine",
111 | logging_steps=1,
112 | bf16=True,
113 | per_device_train_batch_size=8,
114 | num_generations=num_generations,
115 | gradient_accumulation_steps=4,
116 | max_prompt_length=256,
117 | max_completion_length=512,
118 | num_train_epochs=epoch,
119 | save_steps=100_000,
120 | max_grad_norm=0.1,
121 | report_to="wandb",
122 | output_dir=output_dir,
123 | run_name=run_name,
124 | log_on_each_node=False,
125 | )
126 |
127 | # ------------------ Model / Tokenizer --------------
128 | model = AutoModelForCausalLM.from_pretrained(
129 | model_name, torch_dtype=torch.bfloat16, device_map="balanced"
130 | )
131 |
132 | tokenizer = AutoTokenizer.from_pretrained(
133 | model_name, model_max_length=training_args.max_completion_length
134 | )
135 | tokenizer.pad_token = tokenizer.eos_token
136 |
137 | # ------------------ Trainer ------------------------
138 | trainer = GRPOTrainer(
139 | model=model,
140 | processing_class=tokenizer,
141 | reward_funcs=[format_reward_func, correctness_reward_func],
142 | args=training_args,
143 | train_dataset=dataset,
144 | )
145 |
146 | wandb.init(project="crispr_grpo", name=run_name, mode="offline")
147 | trainer.train()
148 | trainer.save_model(training_args.output_dir)
149 |
150 |
151 | if __name__ == "__main__":
152 | main()
153 |
154 |
--------------------------------------------------------------------------------
/dataset_pipeline/1_email_parse.py:
--------------------------------------------------------------------------------
1 | #######--------------------------
2 | ## This script converts the mbox content into
3 | ## the new_QA.json file, where each item has (question, answer, context)
4 | #######--------------------------
5 |
6 |
7 | import mailbox
8 | import csv
9 | import re
10 |
11 | def remove_quotes(text):
12 | # Regular expression to match lines that start with one or more '>' characters followed by a space
13 | clean_lines = [line for line in text.splitlines() if not re.match(r'^\s*(>+ )', line)]
14 | return '\n'.join(clean_lines)
15 |
16 | def get_email_body(message):
17 | body = None # Initialize body to None or an empty string
18 | if message.is_multipart():
19 | for part in message.walk():
20 | ctype = part.get_content_type()
21 | cdispo = str(part.get('Content-Disposition'))
22 | if (ctype == 'text/plain' and 'attachment' not in cdispo) or (ctype == 'text/html' and 'attachment' not in cdispo):
23 | body = part.get_payload(decode=True)
24 | body = body.decode('utf-8', errors='ignore')
25 | body = remove_quotes(body)
26 | return body # Return the cleaned body once found
27 | else:
28 | body = message.get_payload(decode=True)
29 | body = body.decode('utf-8', errors='ignore')
30 | body = remove_quotes(body)
31 | return body
32 |
33 |
34 | # Load the mbox file
35 | mbox = mailbox.mbox('.../mbox') # your local mbox address
36 |
37 | # Open a CSV file to write to
38 | with open('testemail.csv', 'w', newline='', encoding='utf-8') as f:
39 | writer = csv.writer(f)
40 | writer.writerow(['Message-ID', 'In-Reply-To', 'References', 'Subject', 'From', 'Date', 'To', 'Body'])
41 |
42 | for message in mbox:
43 | message_id = message['message-id']
44 | in_reply_to = message['in-reply-to']
45 | references = message['references']
46 | subject = message['subject']
47 | from_email = message['from']
48 | date = message['date']
49 | to_email = message['to']
50 | body = get_email_body(message)
51 |
52 | writer.writerow([message_id, in_reply_to, references, subject, from_email, date, to_email, body])
53 |
54 | print("Conversion completed with adjusted email body extraction!")
55 |
56 |
57 |
58 | # %%
59 |
60 | import pandas as pd
61 | # Adjust display settings
62 | pd.set_option('display.max_columns', None) # Show all columns
63 | pd.set_option('display.max_rows', 30) # Show all rows
64 | pd.set_option('display.max_colwidth', None) # Show full content of each cell
65 | pd.set_option('display.width', None) # Auto-detect the display width
66 | df = pd.read_csv('/xxxx/testemail.csv')
67 |
68 | df['Body'] = df['Body'].str.split('You received this message because you are subscribed to the Google Groups').str[0]
69 | pattern = r"On\s+(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday),\s+(January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},\s+(2000|2001|2002|2003|2004|2005|2006|2007|2008|2009|2010|2011|2012|2013|2014|2015|2016|2017|2018|2019|2020|2021|2022|2023|2024)\s+at.*?wrote:"
70 |
71 | # Split the 'body' column and keep only the part before the unwanted text
72 | df['Body'] = df['Body'].str.split(pattern, n=1, expand=True)[0]
73 | pattern = r"On\s+(Mon|Tue|Wed|Thu|Fri|Sat|Sun),\s+(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{1,2},\s+\d{4}\s+at\s+\d{1,2}:\d{2}\s+(AM|PM)\s+.*?<.*?>\s*\nwrote:"
74 | df['Body'] = df['Body'].str.split(pattern, n=1, expand=True)[0]
75 | df['Body'] = df['Body'].str.replace('\n', '', regex=False)
76 | df['Message-ID'] = df['Message-ID'].str.replace(' ', '', regex=False)
77 | df['Message-ID'] = df['Message-ID'].str.replace('<', '').str.replace('>', '')
78 | df['In-Reply-To'] = df['In-Reply-To'].str.replace(' ', '', regex=False)
79 | df['In-Reply-To'] = df['In-Reply-To'].str.replace('<', '').str.replace('>', '')
80 | df['References'] = df['References'].str.extractall(r'<(.*?)>').groupby(level=0).agg(list)
81 | df['References'] = df['References'].apply(lambda refs: [ref.replace(' ', '') for ref in refs] if isinstance(refs, list) else [])
82 | df.head(30)
83 |
84 |
85 | # %%
86 |
87 | import pandas as pd
88 | # Assuming 'df' is already loaded and formatted with 'References' as lists of message IDs.
89 |
90 | # Identifying root messages
91 | roots = df[df['In-Reply-To'].isna()]
92 |
93 | # Preparing a list to store threads
94 | threads = []
95 |
96 | # Processing each root to find the thread end and reconstruct the thread
97 | for index, root in roots.iterrows():
98 | thread_id = root['Message-ID']
99 | thread_subject = root['Subject']
100 |
101 | # Find messages with this root ID in their references
102 | thread_messages = df[df['References'].apply(lambda x: thread_id in x if x else False)]
103 | thread_body = '***SUBJECT: {subject}***'.format(subject=root['Subject'])
104 |
105 | if not thread_messages.empty:
106 | # Find the last message by identifying the one with the longest 'References' list
107 | last_message = thread_messages.loc[thread_messages['References'].apply(len).idxmax()]
108 |
109 | # Start the list of message IDs with the root message ID
110 | message_ids = [root['Message-ID']]
111 | # Iterate through the reference IDs to construct the body of the thread
112 | for msg_id in last_message['References']:
113 | # Safely retrieve the message, ensure there's at least one message that matches
114 | message = df[df['Message-ID'] == msg_id].iloc[0] if not df[df['Message-ID'] == msg_id].empty else None
115 | if message is not None:
116 | thread_body += ' ***NEW MESSAGE BY {sender}***: {body}'.format(sender=message['From'], body=message['Body'])
117 | message_ids.append(msg_id) # Track the message ID
118 |
119 | # Append the last message body after all referenced messages
120 | thread_body += ' ***NEW MESSAGE BY {sender}***: {body} '.format(sender=last_message['From'], body=last_message['Body'])
121 | message_ids.append(last_message['Message-ID']) # Include the last message ID
122 |
123 | # Append to the list of threads
124 | threads.append({'threadID': thread_id, 'threadSubject': thread_subject, 'Body': thread_body, 'Message-IDs': message_ids})
125 | else:
126 | # Handle the case where no messages reference the root
127 | thread_body += ' ***NEW MESSAGE BY {sender}***: {body}'.format(sender=root['From'], body=root['Body'])
128 | message_ids = [root['Message-ID']] # Just the root message ID for standalone threads
129 | threads.append({'threadID': thread_id, 'threadSubject': thread_subject, 'Body': thread_body, 'Message-IDs': message_ids})
130 |
131 | # Convert list to DataFrame
132 | thread_df = pd.DataFrame(threads)
133 |
134 | # Displaying the new DataFrame
135 | print(thread_df.head(5))
136 |
137 | # %%
138 |
139 | import pandas as pd
140 | import json
141 | import openai
142 | from openai import OpenAI
143 | import re
144 | from concurrent.futures import ThreadPoolExecutor, as_completed
145 |
146 | # Initialize your OpenAI client
147 | client = OpenAI(api_key='xxxx')
148 |
149 | # Assuming 'df' is your DataFrame and it has 'threadID' and 'Body' columns
150 | df = thread_df
151 |
152 | system_prompt = "The following text represents an entire email thread that's in chronological order. It started with a ***SUBJECT: {subject}***. It then has different emails by different people, segmented by '***NEW MESSAGE BY {name} ***'. I want you to extract as many research/scientific related Q&A pairs as possible. You also want to have a context field that makes it so a new researcher, who is not involved in the Email thread, can read the Q&A and the context field, and then have an understanding of what is going on. It's good for answer to be a solid response to the question based on the email thread, and it'd be nice for for answer to include explanations. The context field is additionally provided to give more macro information about the whole thread. The question and the context field together can help this new researcher understand why this question is taking in place, and has a holistic understanding of the entire Email thread. Your output must be a list of maps (i.e. [{question: 'what is 1+1' answer': '2', context: 'PersonA has been asking PersonB about elementary math', questionFrom: 'daniel, Daniel@gmail.com', answerFrom: 'john, john@gmail.com'}...]. If there's no answer then just keep the answer field as empty string "". FOR Question, Answer, and Context Field, MAKE SURE TO REPLACE REAL NAMES WITH PERSON1, PERSON2... MOST IMPORTANT: REMINDER THAT YOUR REPLY MUST BE A LIST OF MAPS IN THAT FORMAT I DESCRIBED AND NOTHING ELSE."
153 |
154 |
155 | def fix_json_string(json_like_string):
156 | # Step 1: Replace single quotes around keys with double quotes
157 | corrected_json = re.sub(r"\'([a-zA-Z0-9_]+)\'\s*:", r'"\1":', json_like_string)
158 |
159 | # Step 2: Add double quotes around unquoted keys
160 | corrected_json = re.sub(r'(?<![\w"])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'"\1"\2', corrected_json)
161 | return corrected_json
162 |
163 | # ... (the remaining steps of this script, which query gpt-4o per thread and write new_QA.json, are not included here)
--------------------------------------------------------------------------------
/training/rl_router_training.py:
--------------------------------------------------------------------------------
1 | ####### ----------------
2 | ## Multi-Agent RL Router Training (GRPO over pre-generated completions)
3 | ####### ----------------
4 |
5 | import os
6 | import re
7 | import random
8 | from typing import Any, Union
9 |
10 | import torch
11 | import torch.nn as nn
12 | import wandb
13 | from accelerate.utils import gather, gather_object
14 | from datasets import Dataset, load_from_disk
15 | from transformers import AutoModelForCausalLM, AutoTokenizer
16 | from trl import GRPOConfig, GRPOTrainer
17 | from trl.data_utils import maybe_apply_chat_template
18 |
44 | def extract_xml_answer(text: str) -> str:
45 | """Extracts the answer from between ... tags."""
46 | try:
47 | answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
48 | return answer
49 | except IndexError:
50 | return ""
51 |
52 | ##########################################################################
53 | # New Trainer Subclass: PreGeneratedGRPOTrainer
54 | ##########################################################################
55 | class PreGeneratedGRPOTrainer(GRPOTrainer):
56 | """
57 | A custom GRPOTrainer that uses pre‐generated completions instead of on‐the‐fly generation.
58 |
59 | Each training sample is expected to have a "pre_generated" key containing a list of 4 dictionaries.
60 | Each dictionary must include:
61 | - "content": the candidate output from the router assistant (e.g. "2")
62 | - Optionally, a detailed response under a key like "response by model 2"
63 |
64 | For training, one candidate is sampled at random from the pre-generated list for each example. Rewards are
65 | computed from the candidate's detailed response, then grouped and normalized per prompt group (GRPO-style)
66 | to form the advantages.
67 | """
68 |
69 | def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
70 | device = self.accelerator.device
71 | prompts = [x["prompt"] for x in inputs]
72 | prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
73 | prompt_inputs = self.processing_class(
74 | prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
75 | )
76 | #prompt_inputs = super()._prepare_inputs(prompt_inputs)
77 | prompt_ids, prompt_mask = prompt_inputs["input_ids"].to(device), prompt_inputs["attention_mask"].to(device)
78 |
79 | if self.max_prompt_length is not None:
80 | prompt_ids = prompt_ids[:, -self.max_prompt_length :]
81 | prompt_mask = prompt_mask[:, -self.max_prompt_length :]
82 |
83 |
84 | pre_generated_list = [x["pre_generated"] for x in inputs] # length: N; each element is list of 4 dicts
85 | candidate_texts = [] # length: N (one sampled candidate per example)
86 | reward_texts = [] # length: N (detailed response used for reward computation)
87 |
88 | for pg in pre_generated_list:
89 | entry = random.choice(pg) # randomly sample one entry from pg
90 | candidate = entry["content"].strip()
91 | candidate_texts.append(candidate)
92 |
93 | digit_list = re.findall(r"\d", candidate)
94 | if not digit_list:
95 | raise ValueError(f"Candidate text '{candidate}' does not contain a digit.")
96 | digit = digit_list[0]
97 | response_key = f"response by model {digit}"
98 | reward_texts.append(entry.get(response_key, candidate))
99 |
100 | completion_inputs = self.processing_class(
101 | text=candidate_texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
102 | )
103 | completion_ids = completion_inputs["input_ids"].to(device) # This is analogous to completion_ids produced by generation.
104 | completion_mask = completion_inputs["attention_mask"].to(device) # Similarly, candidate_mask.
105 |
106 |
107 | # Construct prompt_completion_ids from pre-generated completions
108 | prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
109 |
110 | # Mask everything after the first EOS token
111 | is_eos = (completion_ids == self.processing_class.eos_token_id).to(device)
112 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
113 |
114 | # Make sure the mask and argmax are on the same device
115 | any_eos = is_eos.any(dim=1)
116 | argmax_eos = is_eos.int().argmax(dim=1)
117 |
118 | eos_idx[any_eos] = argmax_eos[any_eos] # no device mismatch here
119 | sequence_indices = torch.arange(is_eos.size(1), device=is_eos.device).expand(is_eos.size(0), -1)
120 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
121 |
122 |
123 | # Concatenate prompt_mask with completion_mask for logit computation
124 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
125 |
126 | logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
127 |
128 | with torch.inference_mode():
129 | if self.ref_model is not None:
130 | ref_per_token_logps = self._get_per_token_logps(
131 | self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
132 | )
133 | else:
134 | with self.accelerator.unwrap_model(self.model).disable_adapter():
135 | ref_per_token_logps = self._get_per_token_logps(
136 | self.model, prompt_completion_ids, attention_mask, logits_to_keep
137 | )
138 |
139 |
140 | # Decode the generated completions
141 | completions_text = reward_texts
142 | completions = completions_text
143 | rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
144 | for i, (reward_func, reward_processing_class) in enumerate(
145 | zip(self.reward_funcs, self.reward_processing_classes)
146 | ):
147 |
148 | # Pass all input columns except "prompt", "completion", and "pre_generated" to the reward function as extra kwargs
149 | keys = [key for key in inputs[0] if key not in ["prompt", "completion","pre_generated"]]
150 | reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
151 | output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
152 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
153 |
154 | # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
155 | # completions may be distributed across processes
156 | rewards_per_func = gather(rewards_per_func)
157 |
158 |
159 | # Apply weights to each reward function's output and sum
160 | rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
161 |
162 | # Compute grouped-wise rewards
163 | mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
164 | std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
165 |
166 | # Normalize the rewards to compute the advantages
167 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
168 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
169 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
170 |
171 | # Slice to keep only the local part of the data
172 | process_slice = slice(
173 | self.accelerator.process_index * len(prompts),
174 | (self.accelerator.process_index + 1) * len(prompts),
175 | )
176 | advantages = advantages[process_slice]
177 |
178 | # Log the metrics
179 | reward_per_func = rewards_per_func.mean(0)
180 | for i, reward_func in enumerate(self.reward_funcs):
181 | if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
182 | reward_func_name = reward_func.config._name_or_path.split("/")[-1]
183 | else:
184 | reward_func_name = reward_func.__name__
185 | self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
186 |
187 | self._metrics["reward"].append(rewards.mean().item())
188 | self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
189 |
190 | if (
191 | self.log_completions
192 | and self.state.global_step % self.args.logging_steps == 0
193 | and "wandb" in self.args.report_to
194 | ):
195 | import pandas as pd
196 |
197 | # For logging
198 | table = {
199 | "step": [str(self.state.global_step)] * len(rewards),
200 | "prompt": gather_object(prompts_text),
201 | "completion": gather_object(completions_text),
202 | "reward": rewards.tolist(),
203 | }
204 | df = pd.DataFrame(table)
205 |
206 | if wandb.run is not None and self.accelerator.is_main_process:
207 | wandb.log({"completions": wandb.Table(dataframe=df)})
208 |
209 | return {
210 | "prompt_ids": prompt_ids,
211 | "prompt_mask": prompt_mask,
212 | "completion_ids": completion_ids,
213 | "completion_mask": completion_mask,
214 | "ref_per_token_logps": ref_per_token_logps,
215 | "advantages": advantages,
216 | }
217 |
218 |
219 |
220 |
221 | ##########################################################################
222 | # Reward Function: Correctness
223 | ##########################################################################
224 | def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
225 | """
226 | Uses the detailed responses (reward_texts) to compute correctness.
227 | It extracts the answer (e.g. the letter "a") from the detailed response and compares it with the ground truth.
228 | Returns 1.0 if they match, otherwise -1.0.
229 | """
230 | reward_texts = kwargs.get("reward_texts", None)
231 | if reward_texts is not None:
232 | extracted_answers = [extract_xml_answer(rt) for rt in reward_texts]
233 | else:
234 | extracted_answers = [extract_xml_answer(c) for c in completions]
235 |
236 | print("\n\n===============================================================\n"
237 | f"User question (sample): {prompts[0]}\n"
238 | f"Ground truth answer: {answer[0]}\n"
239 | f"Extracted answers (from reward texts): {extracted_answers}\n")
240 | return [1.0 if r == a else -1.0 for r, a in zip(extracted_answers, answer)]
241 |
242 | ##########################################################################
243 | # Dataset Preprocessing
244 | ##########################################################################
245 | R1_STYLE_SYSTEM_PROMPT = """There are four models capable of answering single-choice multiple choice questions. You are the Router Assistant.
246 |
247 | In a conversation between the User and the Router Assistant, the User provides a single-choice multiple choice question. Your job as the Router Assistant is to suggest which of the four models is most likely to provide the best answer.
248 |
249 | You must select only one model from the following options: model 1, model 2, model 3, or model 4.
250 |
251 | Provide your recommendation enclosed within <answer> </answer> tags, for example: <answer>1</answer>
252 | """
253 | TASK_SPECIFIC_INSTRUCTIONS = "The choice must be a single digit: 1, 2, 3, or 4."
254 |
255 | def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset:
256 | dataset = load_from_disk(dataset_name)[split]
257 | def process_batch(batch):
258 | # Build the prompt as a list of two messages:
259 | # System instruction (with task-specific instructions) and the user question.
260 | prompts = [
261 | [
262 | {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
263 | {'role': 'user', 'content': q.strip()}
264 | ]
265 | for q in batch['question']
266 | ]
267 | return {
268 | 'prompt': prompts,
269 | 'answer': [extract_xml_answer(a) for a in batch['answer']],
270 | 'pre_generated': batch['pre_generated']
271 | }
272 | return dataset.map(process_batch, batched=True, batch_size=chunk_size)
273 |
274 | ##########################################################################
275 | # Main Training
276 | ##########################################################################
277 | def main():
278 | dataset_name = './dataset/Genome-Bench-Router' # Genome-Bench-Router dataset contains pre-generated responses from different RL models
279 | dataset = preprocess_dataset(dataset_name, chunk_size=500)
280 |
281 | model_name = "Qwen2.5-7B-Instruct"
282 | epoch = 2
283 | learning_rate = 1e-5
284 | output_dir = f"{model_name.split('/')[-1]}-Router"
285 | run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}"
286 |
287 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
288 |
289 | training_args = GRPOConfig(
290 | learning_rate=learning_rate,
291 | beta=0.005,
292 | optim="adamw_8bit",
293 | adam_beta1=0.9,
294 | adam_beta2=0.99,
295 | weight_decay=0.1,
296 | warmup_ratio=0.1,
297 | lr_scheduler_type='cosine',
298 | logging_steps=1,
299 | bf16=True,
300 | per_device_train_batch_size=8,
301 | num_generations=8, # group size used when normalizing GRPO advantages
302 | gradient_accumulation_steps=4,
303 | max_prompt_length=256,
304 | max_completion_length=512,
305 | num_train_epochs=epoch,
306 | save_steps=100000,
307 | max_grad_norm=0.1,
308 | report_to="wandb",
309 | output_dir=output_dir,
310 | run_name=run_name,
311 | log_on_each_node=False,
312 | )
313 |
314 | model = AutoModelForCausalLM.from_pretrained(
315 | model_name,
316 | torch_dtype=torch.bfloat16,
317 | device_map="auto",
318 | )
319 |
320 | tokenizer = AutoTokenizer.from_pretrained(
321 | model_name,
322 | model_max_length=training_args.max_completion_length,
323 | )
324 | tokenizer.pad_token = tokenizer.eos_token
325 |
326 | # Instantiate the custom trainer.
327 | trainer = PreGeneratedGRPOTrainer(
328 | model=model,
329 | processing_class=tokenizer,
330 | reward_funcs=[correctness_reward_func],
331 | args=training_args,
332 | train_dataset=dataset,
333 | )
334 |
335 | wandb.init(project="crispr_grpo_router", name=run_name, mode="offline")
336 | trainer.train()
337 | trainer.save_model(training_args.output_dir)
338 |
339 | if __name__ == "__main__":
340 | main()
--------------------------------------------------------------------------------