├── evaluation ├── bright │ ├── configs │ │ ├── bm25 │ │ │ ├── aops.json │ │ │ ├── pony.json │ │ │ ├── biology.json │ │ │ ├── economics.json │ │ │ ├── leetcode.json │ │ │ ├── robotics.json │ │ │ ├── theoremqa.json │ │ │ ├── earth_science.json │ │ │ ├── psychology.json │ │ │ ├── stackoverflow.json │ │ │ ├── sustainable_living.json │ │ │ ├── theoremqa_questions.json │ │ │ └── theoremqa_theorems.json │ │ └── reasonir │ │ │ ├── aops.json │ │ │ ├── theoremqa.json │ │ │ ├── leetcode.json │ │ │ ├── theoremqa_questions.json │ │ │ ├── theoremqa_theorems.json │ │ │ ├── biology.json │ │ │ ├── robotics.json │ │ │ ├── earth_science.json │ │ │ ├── economics.json │ │ │ ├── psychology.json │ │ │ ├── stackoverflow.json │ │ │ ├── sustainable_living.json │ │ │ └── pony.json │ ├── other_requirements.txt │ ├── requirements.txt │ ├── script.sh │ ├── reranker_script.sh │ ├── prompts.py │ ├── reranker.py │ └── run.py ├── rag │ ├── mmlu_cot │ │ ├── cot_prompt_lib │ │ │ └── initial_prompt.txt │ │ ├── scripts │ │ │ └── eval_llama_3_8b_mmlu_rag.sh │ │ ├── extract_mmlu_group.py │ │ ├── utils │ │ │ └── extract_cot_as_queries.py │ │ ├── compute_accuracy.py │ │ └── evaluate_from_local_mmlu.py │ ├── gpqa │ │ ├── src │ │ │ ├── conf │ │ │ │ └── naive_rag_default.yaml │ │ │ ├── data │ │ │ │ └── datasets.py │ │ │ ├── main.py │ │ │ ├── utils │ │ │ │ ├── cache_utils.py │ │ │ │ ├── math_equivalence.py │ │ │ │ └── hydra_runner.py │ │ │ └── workflow │ │ │ │ └── naive_rag.py │ │ ├── scripts │ │ │ └── evaluate_naive_rag.sh │ │ └── apis │ │ │ ├── base.py │ │ │ └── offline_massiveds_search_api.py │ └── datastore │ │ ├── build_datastore.sh │ │ └── construct_datastore_corpus.py └── README.md ├── synthetic_data_generation ├── document_filters │ ├── __init__.py │ ├── basic_filters.py │ └── fineweb_edu_filter.py ├── requirements.txt ├── setup_java.sh ├── script.sh ├── script_batch.sh ├── README.md ├── generate_reasoning.py ├── data_gen_prompts.py ├── hard_negative_mining.py ├── lm_helper.py ├── batch_api_helper.py ├── generate_reasoning_batch.py ├── supplement_negative_passage.py ├── doc_to_query.py ├── doc_to_query_batch.py └── gen_utils.py ├── training ├── README.md ├── config_128gpusfsdp_llama.yml └── train.sh ├── test_time_techniques ├── README.md └── query_rewriting.py ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── README.md └── LICENSE /evaluation/bright/configs/bm25/aops.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/pony.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/other_requirements.txt: -------------------------------------------------------------------------------- 1 | cohere==4.36 2 | voyageai 3 | vertexai 4 | openai 5 | gritlm -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/biology.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/economics.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/leetcode.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/robotics.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/theoremqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/earth_science.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/psychology.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/stackoverflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/sustainable_living.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/theoremqa_questions.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/bm25/theoremqa_theorems.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": {}, 3 | "instructions_long": {} 4 | } -------------------------------------------------------------------------------- /synthetic_data_generation/document_filters/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_filters import white_space_length_filter 2 | from .fineweb_edu_filter import fineweb_quality_filter -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/aops.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a Math problem, retrieve relevant examples that help answer the problem\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | } 6 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/theoremqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": 
"<|user|>\nGiven a Math problem, retrieve relevant examples that help answer the problem\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | } 6 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/leetcode.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a coding problem, retrieve relevant examples that help answer the problem\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | } 6 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/theoremqa_questions.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a Math problem, retrieve relevant examples that help answer the problem\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | } 6 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/theoremqa_theorems.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a Math problem, retrieve relevant theorems that help answer the problem\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | } 6 | } -------------------------------------------------------------------------------- /evaluation/rag/mmlu_cot/cot_prompt_lib/initial_prompt.txt: -------------------------------------------------------------------------------- 1 | The following are multiple choice questions (with answers) about {$}. Think step by step and then finish your answer with "the answer is (X)" where X is the correct letter choice. 2 | 3 | 4 | -------------------------------------------------------------------------------- /synthetic_data_generation/requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes 2 | pytrec-eval 3 | tqdm 4 | openai 5 | transformers==4.47.0 6 | tiktoken 7 | scipy 8 | torchmetrics 9 | pyserini 10 | gensim 11 | einops 12 | vllm==0.6.3 13 | fasttext 14 | nltk 15 | # cohere==4.36 16 | # voyageai 17 | # vertexai 18 | # gritlm -------------------------------------------------------------------------------- /evaluation/bright/requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes 2 | pytrec-eval 3 | tqdm 4 | transformers==4.47.0 5 | tiktoken 6 | scipy 7 | torchmetrics 8 | pyserini 9 | gensim 10 | einops 11 | vllm 12 | fasttext 13 | nltk 14 | sentence-transformers 15 | hf_transfer 16 | # cohere==4.36 17 | # voyageai 18 | # vertexai 19 | # openai 20 | # gritlm -------------------------------------------------------------------------------- /synthetic_data_generation/setup_java.sh: -------------------------------------------------------------------------------- 1 | echo "Setting up Java 23.0.1" 2 | wget https://download.oracle.com/java/23/archive/jdk-23.0.1_linux-x64_bin.tar.gz 3 | tar -xvf jdk-23.0.1_linux-x64_bin.tar.gz 4 | rm jdk-23.0.1_linux-x64_bin.tar.gz 5 | echo "export JAVA_HOME=~/jdk-23.0.1" >> ~/.bashrc 6 | echo "export JVM_PATH=$JAVA_HOME/lib/server/libjvm.so" >> ~/.bashrc 7 | source ~/.bashrc -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/biology.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | 
"query": "<|user|>\nGiven a {task} post, retrieve relevant passages that help answer the post\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} post, retrieve relevant documents that help answer the post\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/robotics.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a {task} post, retrieve relevant passages that help answer the post\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} post, retrieve relevant documents that help answer the post\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/earth_science.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a {task} post, retrieve relevant passages that help answer the post\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} post, retrieve relevant documents that help answer the post\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/economics.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a {task} post, retrieve relevant passages that help answer the post\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} post, retrieve relevant documents that help answer the post\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/psychology.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a {task} post, retrieve relevant passages that help answer the post\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} post, retrieve relevant documents that help answer the post\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/stackoverflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a {task} post, retrieve relevant passages that help answer the post\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} post, retrieve relevant documents that help answer the post\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/sustainable_living.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a {task} post, retrieve relevant passages that help 
answer the post\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} post, retrieve relevant documents that help answer the post\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/configs/reasonir/pony.json: -------------------------------------------------------------------------------- 1 | { 2 | "instructions": { 3 | "query": "<|user|>\nGiven a {task} question, retrieve relevant passages that help answer the question\n<|embed|>\n", 4 | "document": "<|embed|>\n" 5 | }, 6 | "instructions_long": { 7 | "query": "<|user|>\nGiven a {task} question, retrieve relevant documents that help answer the question\n<|embed|>\n", 8 | "document": "<|embed|>\n" 9 | } 10 | } -------------------------------------------------------------------------------- /evaluation/bright/script.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | cd evaluation/bright 4 | 5 | MODEL=reasonir 6 | REASONING=gpt4 7 | BS=-1 8 | for TASK in biology earth_science economics psychology robotics stackoverflow sustainable_living leetcode pony aops theoremqa_theorems theoremqa_questions; do 9 | python run.py --task $TASK --model $MODEL --output_dir output/${MODEL}_${REASONING}_reasoning --cache_dir cache --reasoning $REASONING --encode_batch_size $BS 10 | done -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | # ReasonIR: Training 2 | 3 | This directory contains scripts for contrastive learning to fine-tune Large Language Models (LLMs) for information retrieval tasks. 4 | 5 | We use [gritlm](https://github.com/ContextualAI/gritlm) for contrastive training. We provide our script for training in `training/train.sh`. 6 | 7 | Note: to convert llama to embedding model, you need to replace the `modeling_llama.py` in your transformers package with `training/modeling_llama.py` to support non-causal attention mask. -------------------------------------------------------------------------------- /evaluation/rag/gpqa/src/conf/naive_rag_default.yaml: -------------------------------------------------------------------------------- 1 | name: naive_rag 2 | 3 | 4 | dataset_name: gpqa 5 | split: diamond 6 | subset_num: null 7 | 8 | model_path: ??? 9 | llm_endpoint: ??? 10 | search_llm_endpoint: None 11 | temperature: 0.7 12 | top_p: 0.8 13 | repetition_penalty: 1.0 14 | max_tokens: 10000 # make sure max_tokens + max_doc_len < max_model_context_length 15 | 16 | 17 | workflow_type: naive_rag 18 | top_k: 3 19 | max_doc_len: 10000 20 | search_engine: offline_massiveds 21 | use_query_rewriting: false -------------------------------------------------------------------------------- /evaluation/rag/gpqa/scripts/evaluate_naive_rag.sh: -------------------------------------------------------------------------------- 1 | # First, launch 2 | python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --disable-cuda-graph --tp 1 --host 0.0.0.0 3 | 4 | 5 | 6 | PYTHONPATH=. 
python src/main.py \ 7 | --config-name naive_rag_default \ 8 | model_path=Qwen/Qwen2.5-7B-Instruct \ 9 | llm_endpoint=http://rulin@a100-st-p4de24xlarge-435:30000/v1 \ 10 | top_k=5 \ 11 | search_engine=offline_massiveds \ 12 | use_query_rewriting=false \ 13 | dataset_name=gpqa \ 14 | split=diamond -------------------------------------------------------------------------------- /evaluation/rag/gpqa/apis/base.py: -------------------------------------------------------------------------------- 1 | from apis.offline_massiveds_search_api import ( 2 | search_offline_cached_massiveds, 3 | format_offline_document_string, 4 | ) 5 | 6 | def search_api(search_engine, question, client=None, model_name=None, use_query_rewriting=False, cache=None): 7 | if search_engine == 'offline_massiveds': 8 | return search_offline_cached_massiveds(), None 9 | else: 10 | raise NotImplementedError 11 | 12 | 13 | def format_document_string(search_engine, results, top_k, max_doc_len=None): 14 | if search_engine == 'offline_massiveds': 15 | return format_offline_document_string(results, top_k, max_doc_len) 16 | else: 17 | raise NotImplementedError -------------------------------------------------------------------------------- /evaluation/rag/mmlu_cot/scripts/eval_llama_3_8b_mmlu_rag.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | save_dir="eval_results/" 4 | global_record_file="eval_results/eval_record_collection.csv" 5 | model="meta-llama/Meta-Llama-3.1-8B-Instruct" 6 | selected_subjects="all" 7 | gpu_util=0.95 8 | 9 | 10 | CUDA_VISIBLE_DEVICES=0 python evaluate_from_local_mmlu.py \ 11 | --selected_subjects $selected_subjects \ 12 | --save_dir $save_dir \ 13 | --model $model \ 14 | --global_record_file $global_record_file \ 15 | --gpu_util $gpu_util \ 16 | --retrieval_file $retrieval_file \ 17 | --raw_query_file $raw_query_file \ 18 | --concat_k 3 19 | -------------------------------------------------------------------------------- /evaluation/rag/datastore/build_datastore.sh: -------------------------------------------------------------------------------- 1 | conda activate scaling 2 | cd retrieval-scaling 3 | 4 | 5 | datastore_domain=mmlu_reasonir 6 | checkpoint=reasonir/ReasonIR-8B 7 | 8 | PYTHONPATH=. python ric/main_ric.py \ 9 | --config-name default \ 10 | tasks.datastore.embedding=True \ 11 | tasks.datastore.index=True \ 12 | model.datastore_encoder=$checkpoint \ 13 | model.query_encoder=$checkpoint \ 14 | datastore.domain=$datastore_domain \ 15 | datastore.embedding.passage_maxlength=2048 \ 16 | datastore.embedding.per_gpu_batch_size=8 \ 17 | datastore.index.projection_size=4096 \ 18 | datastore.embedding.num_shards=1 \ 19 | datastore.embedding.shard_ids=[0] \ 20 | datastore.index.index_shard_ids=[0] -------------------------------------------------------------------------------- /synthetic_data_generation/document_filters/basic_filters.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic filtering 3 | References: 4 | 1. https://github.com/mlfoundations/dclm/blob/main/baselines/mappers/filters/content_filters.py 5 | 2. https://github.com/mlfoundations/dclm/blob/main/baselines/mappers/modifiers.py 6 | 7 | Model-based filtering 8 | References: 9 | 1. 
https://github.com/mlfoundations/dclm/tree/main/baselines#fasttext-filtering 10 | 11 | """ 12 | 13 | import os 14 | import fasttext 15 | import urllib.request 16 | import pdb 17 | 18 | 19 | ################################################## 20 | # RULE-BASED 21 | ################################################## 22 | def white_space_length_filter(doc, min_words=20): 23 | return len(doc.split(' ')) >= min_words -------------------------------------------------------------------------------- /evaluation/rag/gpqa/src/data/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/sunnynexus/Search-o1/blob/main/scripts/run_naive_rag.py 3 | """ 4 | import json 5 | 6 | 7 | def load_datasets(cfg): 8 | # Paths to datasets 9 | if cfg.dataset_name == 'livecode': 10 | data_path = f'./data/LiveCodeBench/{cfg.split}.json' 11 | elif cfg.dataset_name in ['math500', 'gpqa', 'aime', 'amc']: 12 | data_path = f'./data/{cfg.dataset_name.upper()}/{cfg.split}.json' 13 | else: 14 | data_path = f'./data/QA_Datasets/{cfg.dataset_name}.json' 15 | 16 | 17 | # ---------------------- Data Loading ---------------------- 18 | with open(data_path, 'r', encoding='utf-8') as json_file: 19 | data = json.load(json_file) 20 | if cfg.subset_num is not None: 21 | data = data[:cfg.subset_num] 22 | 23 | return data -------------------------------------------------------------------------------- /training/config_128gpusfsdp_llama.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | fsdp_config: 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch: NO_PREFETCH 8 | fsdp_cpu_ram_efficient_loading: false 9 | fsdp_forward_prefetch: false 10 | fsdp_offload_params: false 11 | fsdp_sharding_strategy: HYBRID_SHARD 12 | fsdp_state_dict_type: FULL_STATE_DICT 13 | fsdp_sync_module_states: true 14 | fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer 15 | fsdp_use_orig_params: true 16 | machine_rank: 0 17 | main_process_ip: merlion-ultra-ghpc-15 18 | main_process_port: 8000 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 16 22 | num_processes: 128 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false 29 | -------------------------------------------------------------------------------- /test_time_techniques/README.md: -------------------------------------------------------------------------------- 1 | # ReasonIR: Test-Time Techniques---Query Rewriting and LLM Reranking 2 | 3 | This directory contains scripts for the two techniques we used to further enhance ReasonIR's performance---query rewriting and LLM reranking. 4 | 5 | To run query rewriting with token limit, run 6 | ```bash 7 | tasks=( 8 | biology 9 | earth_science 10 | economics 11 | psychology 12 | robotics 13 | stackoverflow 14 | sustainable_living 15 | leetcode 16 | pony 17 | aops 18 | theoremqa_theorems 19 | theoremqa_questions 20 | ) 21 | 22 | MODEL=gpt-4o-mini 23 | for TOKEN_LIMIT in 2048; do 24 | for TASK in "${tasks[@]}" 25 | do 26 | echo $TASK 27 | echo $TOKEN_LIMIT 28 | PYTHONPATH=. 
python /home/rulin/bright-dev/reason.py \ 29 | --task $TASK \ 30 | --llm $MODEL \ 31 | --output_token_limit $TOKEN_LIMIT 32 | done 33 | done 34 | ``` 35 | -------------------------------------------------------------------------------- /evaluation/rag/mmlu_cot/extract_mmlu_group.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | 5 | def get_mmlu_group_subjects(): 6 | task_dir = '/home/rulin/RAG-evaluation-harnesses/lm_eval/tasks/mmlu/default' 7 | 8 | groups = { 9 | 'humanities': [], 10 | 'other': [], 11 | 'social_sciences': [], 12 | 'stem': [], 13 | } 14 | 15 | subject_to_group = {} 16 | 17 | for filename in os.listdir(task_dir): 18 | if filename.endswith(".yaml") and filename.startswith('mmlu'): 19 | filepath = os.path.join(task_dir, filename) 20 | with open(filepath, 'r') as fin: 21 | data = yaml.safe_load(fin) 22 | 23 | group = data['group'].replace('mmlu_', '') 24 | task = data['task'].replace('mmlu_', '') 25 | 26 | groups[group].append(task) 27 | 28 | subject_to_group[task] = group 29 | 30 | print(groups) 31 | 32 | return groups, subject_to_group 33 | -------------------------------------------------------------------------------- /evaluation/rag/mmlu_cot/utils/extract_cot_as_queries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pdb 4 | 5 | 6 | def load_json(path): 7 | with open(path, 'r') as file: 8 | data = json.load(file) 9 | return data 10 | 11 | def write_jsonl(data, path): 12 | with open(path, 'w') as fout: 13 | for ex in data: 14 | fout.write(json.dumps(ex)+'\n') 15 | 16 | 17 | def format_query(input_dir, output_file): 18 | all_filenames = os.listdir(input_dir) 19 | all_queries = [] 20 | for filename in all_filenames: 21 | if not filename.endswith('.json'): 22 | continue 23 | data = load_json(os.path.join(input_dir, filename)) 24 | for ex in data: 25 | query = ex['model_outputs'].split('\n\nThe answer is')[0] 26 | all_queries.append({ 27 | 'query': query, 28 | 'question': ex['question'], 29 | }) 30 | write_jsonl(all_queries, output_file) 31 | -------------------------------------------------------------------------------- /synthetic_data_generation/script.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | ## BRIGHT tasks ## 4 | tasks=( 5 | biology 6 | earth_science 7 | economics 8 | psychology 9 | robotics 10 | stackoverflow 11 | sustainable_living 12 | leetcode 13 | pony 14 | aops 15 | theoremqa_theorems 16 | theoremqa_questions 17 | ) 18 | 19 | 20 | MODEL=meta-llama/Meta-Llama-3.1-70B-Instruct 21 | queries_per_doc=1 22 | num_docs=10 23 | prompt_id=hq_gen 24 | output_dir=synthetic_data/$MODEL/hq 25 | 26 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 27 | 28 | # Loop through each task and run the script 29 | for TASK in "${tasks[@]}"; do 30 | python -m doc_to_query --model_id $MODEL --queries_per_doc $queries_per_doc --num_docs $num_docs --subject $TASK --output_dir $output_dir --filter fineweb --prompt_id $prompt_id 31 | python -m generate_reasoning --model_id $MODEL --num_docs $num_docs --subject $TASK --base_dir $output_dir --prompt_id $prompt_id 32 | done -------------------------------------------------------------------------------- /synthetic_data_generation/document_filters/fineweb_edu_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 3 | import pdb 4 | 5 | 6 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 7 | 8 | def fineweb_quality_filter(passages): 9 | tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/fineweb-edu-classifier") 10 | model = AutoModelForSequenceClassification.from_pretrained("HuggingFaceTB/fineweb-edu-classifier").to(device) 11 | 12 | if isinstance(passages, str): 13 | passages = [passages] 14 | 15 | scores = [] 16 | for text in passages: 17 | inputs = tokenizer(text, return_tensors="pt", padding="longest", truncation=True) 18 | inputs = {k: v.to(device) for k, v in inputs.items()} 19 | with torch.no_grad(): 20 | outputs = model(**inputs) 21 | logits = outputs.logits.squeeze(-1).float().detach().cpu().numpy() 22 | score = logits.item() 23 | result = { 24 | "text": text, 25 | "score": score, 26 | "int_score": int(round(max(0, min(score, 5)))), 27 | } 28 | scores.append(score) 29 | 30 | return scores 31 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ReasonIR 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## License 30 | By contributing to ReasonIR, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /synthetic_data_generation/script_batch.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | ## BRIGHT tasks ## 4 | tasks=( 5 | biology 6 | earth_science 7 | economics 8 | psychology 9 | robotics 10 | stackoverflow 11 | sustainable_living 12 | leetcode 13 | pony 14 | aops 15 | theoremqa_theorems 16 | theoremqa_questions 17 | ) 18 | 19 | 20 | # Loop through each task and run the script by adding 21 | # for TASK in "${tasks[@]}"; do 22 | # done 23 | 24 | # This is a demo script that only runs for documents in the datastore of the biology task 25 | TASK=biology 26 | MODEL=gpt-4o 27 | queries_per_doc=1 28 | num_docs=10 29 | prompt_id=hq_gen 30 | output_dir=synthetic_data/$MODEL/hq 31 | 32 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 33 | python -m doc_to_query_batch --model_id $MODEL --queries_per_doc $queries_per_doc --num_docs $num_docs --subject $TASK --output_dir $output_dir --filter fineweb --prompt_id $prompt_id 34 | python -m doc_to_query_batch --model_id $MODEL --queries_per_doc $queries_per_doc --num_docs $num_docs --subject $TASK --output_dir $output_dir --filter fineweb --prompt_id $prompt_id --gather_results 35 | python -m generate_reasoning_batch --model_id $MODEL --num_docs $num_docs --subject $TASK --base_dir $output_dir --prompt_id $prompt_id 36 | python -m generate_reasoning_batch --model_id $MODEL --num_docs $num_docs --subject $TASK --base_dir $output_dir --prompt_id $prompt_id --gather_results -------------------------------------------------------------------------------- /evaluation/bright/reranker_script.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | cd evaluation/bright 4 | 5 | tasks=("biology" "earth_science" "economics" "psychology" "robotics" "stackoverflow" "sustainable_living" "leetcode" "pony" "aops" "theoremqa_theorems" "theoremqa_questions") 6 | 7 | # use this block to run reasonir and store the reranker scores 8 | MODEL=reasonir 9 | REASONING=gpt4 10 | for TASK in "${tasks[@]}"; do 11 | python run.py --task $TASK --model $MODEL --output_dir output/${MODEL} --cache_dir cache --reasoning $REASONING 12 | done 13 | 14 | # use this block to run bm25 and store the bm25 scores for combining with the reranker scores 15 | # If you want to interpolate with retriever scores, you don't need to run this block 16 | # Loop through tasks and run the command 17 | # for task in "${tasks[@]}"; do 18 | # echo "Running task: $task" 19 | # python run.py --task "$task" --model bm25 --output_dir output/bm25/ --reasoning gpt4 --store_all_scores 20 | # # --reasoning gpt4 21 | # done 22 | 23 | # use this block to run the reranker 24 | # If you want to combine with bm25 scores run the commented block above and 25 | # add --bm25_score_file "output/bm25/${task}_bm25_long_False/${REASONING}_score.json" to the python command below 26 | for task in "${tasks[@]}"; do 27 | echo "Running task: $task" 28 | python reranker.py --task "$task" --retriever_score_file "output/${MODEL}/${task}_${MODEL}_long_False/${REASONING}_score.json" --input_k 100 --k 100 --output_dir "output/reranker/${task}_${MODEL}_long_False/" --reasoning $REASONING 29 | done -------------------------------------------------------------------------------- /synthetic_data_generation/README.md: -------------------------------------------------------------------------------- 1 | # ReasonIR: Synthetic Data Generation 2 | 3 | This directory contains scripts for synthetic data generation. 4 | 5 | ## Setup 6 | ```bash 7 | conda create -n reasonir python=3.10 8 | conda activate reasonir 9 | pip install -r requirements.txt 10 | bash setup_java.sh 11 | ``` 12 | 13 | ## Synthetic Data Generation 14 | 15 | A simple example is provided in `script.sh`. In detail: 16 | ### Generate the queries based on the documents from the datastore of BRIGHT 17 | ```bash 18 | python -m doc_to_query --model_id $MODEL --queries_per_doc $queries_per_doc \ 19 | --num_docs $num_docs --subject $TASK --output_dir $output_dir \ 20 | --filter fineweb --prompt_id $prompt_id 21 | ``` 22 | 23 | ### Generate the rewritten queries with reasoning given the queries 24 | ```bash 25 | python -m generate_reasoning --model_id $MODEL --num_docs $num_docs --subject $TASK \ 26 | --base_dir $output_dir --prompt_id $prompt_id 27 | ``` 28 | 29 | ### Batched version 30 | When generating data using the APIs, it is generally cheaper and faster to use the batch API. We provide a helper in `batch_api_helper.py` and scripts in `script_batch.sh`. 
In particular, the data synthesis steps can be performed via: 31 | ```bash 32 | python -m doc_to_query_batch --model_id $MODEL --queries_per_doc $queries_per_doc \ 33 | --num_docs $num_docs --subject $TASK --output_dir $output_dir \ 34 | --filter fineweb --prompt_id $prompt_id 35 | 36 | python -m doc_to_query_batch --model_id $MODEL --queries_per_doc $queries_per_doc \ 37 | --num_docs $num_docs --subject $TASK --output_dir $output_dir \ 38 | --filter fineweb --prompt_id $prompt_id --gather_results 39 | 40 | python -m generate_reasoning_batch --model_id $MODEL --num_docs $num_docs --subject $TASK --base_dir $output_dir \ 41 | --prompt_id $prompt_id 42 | 43 | python -m generate_reasoning_batch --model_id $MODEL --num_docs $num_docs --subject $TASK --base_dir $output_dir \ 44 | --prompt_id $prompt_id --gather_results 45 | ``` -------------------------------------------------------------------------------- /evaluation/rag/gpqa/src/main.py: -------------------------------------------------------------------------------- 1 | # run_naive_rag.py 2 | import os 3 | import logging 4 | from omegaconf.omegaconf import OmegaConf, open_dict 5 | 6 | 7 | 8 | from openai import OpenAI 9 | 10 | from src.eval.evaluate import run_evaluation, extract_answer 11 | from src.data.datasets import load_datasets 12 | from src.workflow.naive_rag import NaiveRAG 13 | from src.utils.hydra_runner import hydra_runner 14 | 15 | 16 | 17 | @hydra_runner(config_path="conf", config_name="default") 18 | def main(cfg): 19 | logging.info("\n\n************** Experiment configuration ***********") 20 | logging.info(f'\n{OmegaConf.to_yaml(cfg)}') 21 | 22 | # ---------------------- Model Loading ---------------------- 23 | llm = OpenAI(base_url=cfg.llm_endpoint, api_key="") 24 | if cfg.search_llm_endpoint is not None: 25 | search_llm = OpenAI(base_url=cfg.search_llm_endpoint, api_key="") 26 | else: 27 | search_llm = None 28 | # ---------------------- Dataset Loading ---------------------- 29 | data = load_datasets(cfg) 30 | 31 | # ---------------------- Run Workflow ---------------------- 32 | workflow = NaiveRAG(llm, cfg, search_llm=search_llm) 33 | outputs = workflow.run(data) 34 | 35 | # ---------------------- Evaluation ---------------------- 36 | print("Evaluating generated answers...") 37 | # Define output directory based on model and dataset 38 | model_short_name = cfg.model_path.split('/')[-1].lower() 39 | output_dir = f'./outputs/runs.naive_rag/{cfg.dataset_name}/{model_short_name}.{cfg.search_engine}' 40 | os.makedirs(output_dir, exist_ok=True) 41 | 42 | run_evaluation( 43 | filtered_data=data, 44 | input_list=outputs.input_prompts, 45 | output_list=outputs.output_list, 46 | dataset_name=cfg.dataset_name, 47 | output_dir=output_dir, 48 | total_time=outputs.total_time, 49 | split=cfg.split, 50 | ) 51 | 52 | print("Process completed.") 53 | 54 | if __name__ == "__main__": 55 | main() -------------------------------------------------------------------------------- /evaluation/rag/mmlu_cot/compute_accuracy.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import sys 3 | import json 4 | import re 5 | import random 6 | 7 | assert len(sys.argv) > 1, 'You need to pass the directory' 8 | path = sys.argv[1] 9 | 10 | 11 | def extract_answer(text, level): 12 | if level == 'l1': 13 | pattern = r"answer is \(?([A-J])\)?" 
14 | match = re.search(pattern, text) 15 | if match: 16 | return match.group(1) 17 | else: 18 | return None 19 | elif level == 'l2': 20 | pattern = r"answer is \(?([A-J])\)?" 21 | match = re.search(pattern, text) 22 | if match: 23 | return match.group(1) 24 | else: 25 | return extract_again(text) 26 | 27 | 28 | def extract_again(text): 29 | match = re.search(r'.*[aA]nswer:\s*([A-J])', text) 30 | if match: 31 | return match.group(1) 32 | else: 33 | return extract_final(text) 34 | 35 | 36 | def extract_final(text): 37 | pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)" 38 | match = re.search(pattern, text, re.DOTALL) 39 | if match: 40 | return match.group(0) 41 | else: 42 | return None 43 | 44 | 45 | for name in glob.glob(path + '/*'): 46 | print('Level 1 regex' + '==' * 20) 47 | succ, fail = 0, 0 48 | with open(name, 'r') as f: 49 | entries = json.load(f) 50 | for e in entries: 51 | pred = extract_answer(e['model_outputs'], 'l1') 52 | if pred is None: 53 | random.seed(12345) 54 | pred = random.choice(["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]) 55 | # Remove the None cases 56 | if pred == e['answer']: 57 | succ += 1 58 | else: 59 | fail += 1 60 | print(name, succ / (succ + fail)) 61 | 62 | print('Level 2 regex' + '==' * 20) 63 | succ, fail = 0, 0 64 | with open(name, 'r') as f: 65 | entries = json.load(f) 66 | for e in entries: 67 | pred = extract_answer(e['model_outputs'], 'l2') 68 | if pred is None: 69 | random.seed(12345) 70 | pred = random.choice(["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]) 71 | # Remove the None cases 72 | if pred == e['answer']: 73 | succ += 1 74 | else: 75 | fail += 1 76 | print(name, succ / (succ + fail)) 77 | 78 | print() -------------------------------------------------------------------------------- /evaluation/rag/gpqa/src/utils/cache_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | class Cache: 6 | def __init__(self, cfg): 7 | self.cfg = cfg 8 | # ---------------------- Caching Mechanism ---------------------- 9 | # Define cache directories and file paths 10 | search_cache_dir = self.get_search_cache_dirname() 11 | self.search_cache_path = os.path.join(search_cache_dir, 'search_cache.json') 12 | self.url_cache_path = os.path.join(search_cache_dir, 'url_cache.json') 13 | os.makedirs(search_cache_dir, exist_ok=True) 14 | 15 | if self.cfg.use_query_rewriting: 16 | query_cache_dir = self.get_cot_query_cache_dirname() 17 | self.cot_query_path = os.path.join(query_cache_dir, f'query_cache.json') 18 | os.makedirs(query_cache_dir, exist_ok=True) 19 | 20 | # Load existing caches or initialize empty dictionaries 21 | if os.path.exists(self.search_cache_path): 22 | with open(self.search_cache_path, 'r', encoding='utf-8') as f: 23 | self.search_cache = json.load(f) 24 | else: 25 | self.search_cache = {} 26 | 27 | if os.path.exists(self.url_cache_path): 28 | with open(self.url_cache_path, 'r', encoding='utf-8') as f: 29 | self.url_cache = json.load(f) 30 | else: 31 | self.url_cache = {} 32 | 33 | if self.cfg.use_query_rewriting and os.path.exists(self.cot_query_path): 34 | with open(self.cot_query_path, 'r', encoding='utf-8') as f: 35 | self.cot_query_cache = json.load(f) 36 | elif self.cfg.use_query_rewriting: 37 | self.cot_query_cache = {} 38 | 39 | def get_search_cache_dirname(self,): 40 | cache_dir = f'./cache/{self.cfg.search_engine}' 41 | if self.cfg.use_query_rewriting: 42 | query_writer_name = self.cfg.model_path.split('/')[-1].replace('-', '_').replace('.', '_') 43 | cache_dir += 
f'/{query_writer_name}_cot_query' 44 | return cache_dir 45 | 46 | def get_cot_query_cache_dirname(self,): 47 | query_writer_name = self.cfg.model_path.split('/')[-1].replace('-', '_').replace('.', '_') 48 | cache_dir = f'./cache/cot_query/{self.cfg.search_engine}_with_query_rewrite_{query_writer_name}' 49 | return cache_dir 50 | 51 | # Function to save caches 52 | def save_caches(self,): 53 | with open(self.search_cache_path, 'w', encoding='utf-8') as f: 54 | json.dump(self.search_cache, f, ensure_ascii=False, indent=2) 55 | with open(self.url_cache_path, 'w', encoding='utf-8') as f: 56 | json.dump(self.url_cache, f, ensure_ascii=False, indent=2) 57 | if self.cfg.use_query_rewriting: 58 | with open(self.cot_query_path, 'w', encoding='utf-8') as f: 59 | json.dump(self.cot_query_cache, f, ensure_ascii=False, indent=2) -------------------------------------------------------------------------------- /synthetic_data_generation/generate_reasoning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | from hard_negative_mining import BM25_Miner 8 | from data_gen_prompts import * 9 | from gen_utils import * 10 | from data_gen_prompts import * 11 | from lm_helper import MultiTurnOpenAILM, OpenAILM, HFLM 12 | import re 13 | 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--model_id', type=str, default='gpt-4o', help='model id') 18 | parser.add_argument('--prompt_id', type=str, default=None, help='prompt id') 19 | parser.add_argument('--debug', action='store_true', help='debug mode') 20 | parser.add_argument('--base_dir', type=str, default='synthetic_data/top100', help='base directory to save the generated data') 21 | parser.add_argument('--num_docs', type=int, default=None, help='number of samples to collect') 22 | parser.add_argument('--dataset', type=str, default='bright', help='dataset to collect the data for') 23 | parser.add_argument('--subject', type=str, default='biology', help='subject to collect the data for') 24 | parser.add_argument("--temperature", type=float, default=0) 25 | parser.add_argument("--top_p", type=float, default=0) 26 | args = parser.parse_args() 27 | print(args) 28 | 29 | # load the queries 30 | model_id = args.model_id 31 | model_id_str = model_id.split('/')[-1] 32 | num_docs = args.num_docs 33 | if args.prompt_id is not None: 34 | train_data_path = os.path.join(args.base_dir, 'all_docs_train_data', args.prompt_id, model_id_str) 35 | output_path = f'{args.base_dir}/reasoning_data/{args.prompt_id}/{model_id_str}' 36 | else: 37 | train_data_path = os.path.join(args.base_dir, 'all_docs_train_data', model_id_str) 38 | output_path = f'{args.base_dir}/reasoning_data/{model_id_str}' 39 | train_data_path = os.path.expanduser(train_data_path) 40 | from gen_utils import load_training_data 41 | subject = args.dataset if args.dataset in ['msmarco'] else args.subject 42 | data = load_training_data(train_data_path, num_docs=num_docs, subject=subject) 43 | 44 | # Initialize the model 45 | if 'gpt' in model_id: 46 | model = OpenAILM(model_id, temperature=args.temperature, top_p=args.top_p, seed=0) 47 | else: 48 | model = HFLM(model_id, temperature=args.temperature, top_p=args.top_p) 49 | 50 | system_prompt = PROMPT_COT_BRIGHT 51 | 52 | print(f'System prompt:\n{system_prompt}') 53 | os.makedirs(output_path, exist_ok=True) 54 | for subject in data.keys(): 55 | subject_filename = 
f'{output_path}/{subject}_{num_docs}.jsonl' 56 | subject_data = data[subject] 57 | for datum in tqdm(subject_data): 58 | query = datum['query'] 59 | reasoning = model.generate(query, system_prompt=system_prompt) 60 | datum['reasoning'] = reasoning 61 | if args.debug: 62 | break 63 | write_jsonl(subject_data, subject_filename) 64 | -------------------------------------------------------------------------------- /evaluation/rag/gpqa/apis/offline_massiveds_search_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | original_query_retrieved_file = os.getenv("RETRIEVED_FILE") 6 | retrieved_file = os.getenv("RETRIEVED_FILE") # when evaluating query rewriting, replace it with the path to the retrieved file with query rewriting 7 | 8 | def load_original_queries(): 9 | queries = [] 10 | with open(original_query_retrieved_file, 'r') as fin: 11 | for line in fin: 12 | result = json.loads(line) 13 | query = result['query'] 14 | queries.append(query) 15 | return queries 16 | 17 | def load_offline_searched_results(top_k=5): 18 | if "reasoning" in retrieved_file: 19 | queries = load_original_queries() 20 | query2docs = {} 21 | with open(retrieved_file, 'r') as fin: 22 | for idx, line in enumerate(fin): 23 | result = json.loads(line) 24 | if 'searched_results' in retrieved_file: 25 | # search results from API calls 26 | passages = result['results']['results']['passages'][0][:top_k] 27 | passages = [{"retrieval text": psg} for psg in passages] 28 | query = queries[idx] 29 | query2docs[query] = passages 30 | else: 31 | query = queries[idx] 32 | passages = result['ctxs'] 33 | query2docs[query] = passages 34 | else: 35 | query2docs = {} 36 | with open(retrieved_file, 'r') as fin: 37 | for line in fin: 38 | result = json.loads(line) 39 | if 'searched_results' in retrieved_file: 40 | # search results from API calls 41 | passages = result['results']['results']['passages'][0][:top_k] 42 | passages = [{"retrieval text": psg} for psg in passages] 43 | query = result['results']['query'] 44 | query2docs[query] = passages 45 | else: 46 | query = result['query'] 47 | passages = result['ctxs'] 48 | query2docs[query] = passages 49 | return query2docs 50 | 51 | def search_offline_cached_massiveds(data, top_k=5): 52 | questions = [item['Question'] for item in data] 53 | query2docs = load_offline_searched_results(top_k) 54 | results = [] 55 | for question in questions: 56 | if question in query2docs: 57 | results.append(query2docs[question]) 58 | else: 59 | results.append(None) 60 | print(f"Question {question} not found in offline cached massiveds.") 61 | return results 62 | 63 | 64 | def format_offline_document_string(results, top_k=5, max_doc_len=None): 65 | passage_str = "" 66 | for i in range(min(top_k, len(results))): 67 | passage_str += f"Passage {i+1}:\n{results[i]['retrieval text']}\n\n" 68 | print(len(passage_str.split(" "))) 69 | return passage_str.strip() 70 | 71 | if __name__ == '__main__': 72 | load_offline_searched_results(5) -------------------------------------------------------------------------------- /evaluation/bright/prompts.py: -------------------------------------------------------------------------------- 1 | bright_aops = """We want to find different but similar math problems to the following problem: 2 | {} 3 | A document is relevant if it uses the same class of functions and shares **any** overlapping techniques. 4 | Document: {} 5 | Score the document above. The answer should be 'Relevance score: X.' 
where X is a number from 0-5. 6 | 0 means completely irrelevant, 5 means highly relevant and completely addresses the query. Don't output anything else. 7 | """ 8 | 9 | bright_theoremqa_questions = """We want to find a document which uses the same mathematical process as this one: 10 | {} 11 | A document is relevant if it uses the same mathematical process as the query. 12 | Document: {} 13 | Score the document above. The answer should be 'Relevance score: X.' where X is a number from 0-5. 14 | 0 means completely irrelevant, 5 means highly relevant and completely addresses the query. Don't output anything else. 15 | """ 16 | 17 | bright_leetcode = """I am looking to find different problems that share similar data structures 18 | (of any kind) or algorithms (e.g. DFS, DP, sorting, traversals, etc.). I am looking for problems that share one or both of these similarities to 19 | this: 20 | {} 21 | Does the passage below share any similarities? e.g. if there was a textbook on leetcode problems, this would be in the same book even though it could be in a different chapter. 22 | Passage: {} 23 | Please rate the passage above. The answer should be 'Relevance score: X.' where X is a number from 0-5. 24 | 0 means completely irrelevant, 5 means highly relevant and completely addresses the query. Don't output anything else. 25 | """ 26 | 27 | bright_pony = """I will use the programming language pony. 28 | Problem: {} 29 | But to solve the problem above, I need to know things about pony. A passage is relevant if it contains docs that match any part (even basic parts) of the code I will have to write for the above program. 30 | Passage: {} 31 | Please rate the passage above. The answer should be 'Relevance score: X.' where X is a number from 0-5. 32 | 0 means completely irrelevant, 5 means highly relevant and completely addresses the query. Don't output anything else. 33 | """ 34 | 35 | bright_theoremqa_theorems = """"We want to find a document which uses the same mathematical process as this one: 36 | {} 37 | A document is relevant if it uses the same mathematical process as the query. 38 | Document: {} 39 | Score the document above. The answer should be 'Relevance score: X.' where X is a number from 0-5. 40 | 0 means completely irrelevant, 5 means highly relevant and completely addresses the query. Don't output anything else. 41 | """ 42 | 43 | bright_general = """A document is relevant if it contains information that helps answer or address the query. 44 | A document is not relevant if it doesn't contain information that helps answer the query, even if it mentions similar topics. 45 | Is the document below relevant to answering the query below? 46 | The answer should be 'Relevance score: X.' where X is a number from 0-5. 47 | 0 means completely irrelevant, 5 means highly relevant and completely addresses the query. Don't output anything else. 48 | Here is the query: 49 | 50 | {} 51 | 52 | Here is the document: 53 | 54 | {} 55 | 56 | """ 57 | -------------------------------------------------------------------------------- /training/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llama 3 | #SBATCH --nodes=16 4 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 
5 | #SBATCH --hint=nomultithread # we get physical cores not logical 6 | #SBATCH --account fair_amaia_cw_explore 7 | #SBATCH --qos explore 8 | #SBATCH --mem 1000G 9 | #SBATCH --gres=gpu:8 # number of gpus 10 | #SBATCH --time 120:00:00 # maximum execution time (HH:MM:SS) 11 | #SBATCH --requeue 12 | #SBATCH --chdir=/home/rulin/gritlm-dev/gritlm 13 | #SBATCH --output=/checkpoint/amaia/explore/rulin/gritlm/slurm_cache/slurm-%A_%a.out 14 | #SBATCH --array=0 15 | 16 | ###################### 17 | ### Set enviroment ### 18 | ###################### 19 | cd /home/rulin/gritlm-dev/gritlm 20 | source /home/rulin/miniconda3/bin/activate 21 | conda activate grit 22 | export HF_HOME=/checkpoint/amaia/explore/rulin/cache/.cache/huggingface 23 | #NCCL_ASYNC_ERROR_HANDLING=1 24 | export WANDB_PROJECT="grit" 25 | # Training setup 26 | GPUS_PER_NODE=8 27 | # so processes know who to talk to 28 | MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 29 | MASTER_PORT=6000 30 | NNODES=$SLURM_NNODES 31 | NODE_RANK=$SLURM_PROCID 32 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) 33 | ###################### 34 | 35 | ###################### 36 | #### Set network ##### 37 | ###################### 38 | head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 39 | ###################### 40 | 41 | 42 | 43 | LAUNCHER="accelerate launch \ 44 | --config_file /home/rulin/gritlm-dev/scripts/configs/config_128gpusfsdp_llama.yml \ 45 | --num_machines $NNODES \ 46 | --num_processes $WORLD_SIZE \ 47 | --main_process_ip "$MASTER_ADDR" \ 48 | --main_process_port $MASTER_PORT \ 49 | --num_processes $WORLD_SIZE \ 50 | --machine_rank \$SLURM_PROCID \ 51 | --role $SLURMD_NODENAME: \ 52 | --rdzv_conf rdzv_backend=c10d \ 53 | --max_restarts 0 \ 54 | --tee 3 \ 55 | " 56 | 57 | 58 | TRAIN_DATA=data/ # replace with the directory of your training data 59 | 60 | # please refer to https://github.com/ContextualAI/gritlm/blob/main/gritlm/training/run.py for training scripts (e.g., run.py) 61 | export CMD=" \ 62 | -m training.run \ 63 | --output_dir checkpoints/$(date "+%Y-%m-%d-%H_%M_%S")/ \ 64 | --model_name_or_path meta-llama/Llama-3.1-8B \ 65 | --train_data $TRAIN_DATA \ 66 | --learning_rate 2e-5 \ 67 | --lr_scheduler_type constant_with_warmup \ 68 | --warmup_ratio 0.06 \ 69 | --max_steps 1000 \ 70 | --per_device_train_batch_size 4 \ 71 | --gradient_accumulation_steps 4 \ 72 | --dataloader_drop_last \ 73 | --normalized \ 74 | --temperature 0.02 \ 75 | --train_group_size 2 \ 76 | --negatives_cross_device \ 77 | --query_max_len 2048 \ 78 | --passage_max_len 2048 \ 79 | --mode embedding \ 80 | --logging_steps 1 \ 81 | --bf16 \ 82 | --pooling_method mean \ 83 | --use_unique_indices \ 84 | --loss_gen_factor 2 \ 85 | --attn bbcc \ 86 | --gradient_checkpointing \ 87 | --attn_implementation sdpa \ 88 | --split_emb \ 89 | --save_steps 500 90 | " 91 | 92 | SRUN_ARGS=" \ 93 | --wait=60 \ 94 | --kill-on-bad-exit=1 \ 95 | " 96 | 97 | clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 98 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # ReasonIR: Evaluation 2 | 3 | This directory contains evaluation scripts for Information Retrieval (IR) and Retrieval-Augmented Generation (RAG) tasks as described in our research paper. 
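Throughout these evaluations, queries and documents are embedded asymmetrically: queries are prefixed with a task-specific instruction from `evaluation/bright/configs/reasonir/*.json`, while documents only receive the bare `<|embed|>` prefix. The snippet below is a minimal illustrative sketch of that setup, not the code path used by `evaluation/bright/run.py`; it assumes the `reasonir/ReasonIR-8B` checkpoint (the one referenced in `evaluation/rag/datastore/build_datastore.sh`) loads through Hugging Face `AutoModel` with `trust_remote_code=True` (the non-causal Llama variant noted in `training/README.md`), and it uses mean pooling with L2 normalization to mirror `--pooling_method mean` and `--normalized` in `training/train.sh`. The example strings are placeholders.

```python
# Minimal sketch (not the evaluation code itself) of how ReasonIR embeds queries
# and documents. Assumptions: the checkpoint loads via AutoModel with
# trust_remote_code, and embeddings are mean-pooled and L2-normalized,
# mirroring --pooling_method mean / --normalized in training/train.sh.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_name = "reasonir/ReasonIR-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers often ship without a pad token
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

# Instruction templates follow evaluation/bright/configs/reasonir/*.json:
# queries get a task-specific prefix, documents only get the embed tag.
query_prefix = ("<|user|>\nGiven a biology post, retrieve relevant passages "
                "that help answer the post\n<|embed|>\n")
doc_prefix = "<|embed|>\n"

def encode(texts, prefix):
    batch = tokenizer([prefix + t for t in texts], padding=True, truncation=True,
                      max_length=2048, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state          # [batch, seq, dim]
    mask = batch["attention_mask"].unsqueeze(-1)           # [batch, seq, 1]
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # mean pooling over valid tokens
    return F.normalize(pooled, dim=-1)                     # unit-norm embeddings

q_emb = encode(["Why do some plants close their leaves at night?"], query_prefix)
d_emb = encode(["Nyctinasty is the circadian movement of leaves in response to darkness."], doc_prefix)
scores = q_emb @ d_emb.T  # higher score = more relevant
```

Relevance is then just the inner product of the normalized embeddings; the full evaluation logic (caching, batching, rerankers) lives in the task-specific scripts described below.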
4 | 5 | 6 | ## BRIGHT 7 | 8 | Setup (same as for synthetic data generation): 9 | ```bash 10 | conda create -n reasonir python=3.10 11 | conda activate reasonir 12 | pip install -r evaluation/bright/requirements.txt 13 | 14 | # to evaluate with BM25, you need to download and install Java 15 | wget https://download.oracle.com/java/22/latest/jdk-22_linux-x64_bin.deb 16 | sudo dpkg -i jdk-22_linux-x64_bin.deb 17 | ``` 18 | 19 | To evaluate ReasonIR on BRIGHT, run 20 | ```bash 21 | bash evaluation/bright/script.sh 22 | ``` 23 | 24 | To evaluate ReasonIR on BRIGHT with a reranker (QwenRerank), run 25 | ```bash 26 | bash evaluation/bright/reranker_script.sh 27 | ``` 28 | Note that this script first runs the retriever and then the reranker. It produces two results files: `reranker_results.json` and `reranker_retriever_results.json`. The former contains the results obtained using just the reranker scores, and the latter contains the results obtained when the reranker scores are interpolated with the retriever scores. 29 | If you want to combine the reranker scores with BM25 scores instead, see the comments in `reranker_script.sh` for instructions. 30 | 31 | To reproduce the results for some other baselines (such as Cohere and Voyage embeddings), please install the additional required packages via `pip install -r evaluation/bright/other_requirements.txt`. 32 | 33 | 34 | ## Downstream RAG evaluation 35 | 36 | In order to reduce the cost of datastore construction, we first retrieve the top-1000 documents from the original [MassiveDS-1.4T](https://arxiv.org/abs/2407.12854) datastore (built with Contriever) for each benchmark. We then merge the retrieved documents into a new, smaller datastore pool for our experiments. To merge and deduplicate these documents, we use the script in `datastore/construct_datastore_corpus.py`. 37 | 38 | 39 | To embed and index the filtered data with our retriever, run 40 | ```bash 41 | git clone https://github.com/RulinShao/retrieval-scaling.git 42 | bash evaluation/rag/datastore/build_datastore.sh 43 | ``` 44 | We then use the [MassiveDS codebase](https://github.com/RulinShao/retrieval-scaling) to search for the queries, following the instructions in the [Datastore](#datastore) section. 45 | 46 | 47 | 48 | ### MMLU 49 | 50 | To evaluate ReasonIR on MMLU, replace the data directories and run 51 | ```bash 52 | export retrieval_file=$YOUR_RETRIEVAL_FILE # refer to the README for more details 53 | export raw_query_file=mmlu.jsonl # the original MMLU questions used for retrieval 54 | bash evaluation/rag/mmlu_cot/scripts/eval_llama_3_8b_mmlu_rag.sh 55 | ``` 56 | 57 | ### GPQA 58 | 59 | First, launch the LLM with SGLang to obtain a local serving API: 60 | ```bash 61 | python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --disable-cuda-graph --tp 1 --host 0.0.0.0 62 | ``` 63 | 64 | Then, run the evaluation: 65 | ```bash 66 | cd evaluation/rag/gpqa 67 | export RETRIEVED_FILE=YOUR_RETRIEVED_FILE 68 | PYTHONPATH=. 
python src/main.py \ 69 | --config-name naive_rag_default \ 70 | model_path=Qwen/Qwen2.5-7B-Instructt \ 71 | llm_endpoint=http://${VLLM_ENDPOINT}-${VLLM_PORT}:30000/v1 \ 72 | top_k=5 \ 73 | search_engine=offline_massiveds \ 74 | use_query_rewriting=false \ 75 | dataset_name=gpqa \ 76 | split=diamond 77 | ``` 78 | -------------------------------------------------------------------------------- /synthetic_data_generation/data_gen_prompts.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | 4 | DOC2QUERY_BASELINE = '''Given a document, generate {num_questions} questions for which the document is relevant and useful to provide the answer. Format the generated questions in JSON with key "questions": 5 | ```json 6 | {{ 7 | "questions": [ "question 1", ...] 8 | }} 9 | ``` 10 | ''' 11 | 12 | 13 | DOC2HARD_QUERY = '''# Context 14 | You are tasked with generating {num_questions} reasoning-intensive questions with scenarios based on a given document. These questions must be standalone (meaningful without the document) while being answerable using information from the document as supporting evidence. The questions should specifically engage with core concepts and principles from the document's domain. 15 | 16 | # Question Requirements 17 | 1. Each question MUST: 18 | - Present a complete scenario or context within itself 19 | - Be answerable through logical reasoning and critical thinking 20 | - Remain valid and meaningful even if the source document didn't exist 21 | - Target higher-order thinking skills (analysis, evaluation, synthesis) 22 | - Be domain-relevant but not document-specific 23 | - Incorporate key concepts, terminology, and principles from the document's field 24 | - Challenge understanding of domain-specific problem-solving approaches 25 | 26 | 2. Each question MUST NOT: 27 | - Directly reference the document or its contents 28 | - Be answerable through simple fact recall 29 | - Require specific knowledge only found in the document 30 | - Be a reading comprehension question 31 | - Stray from the core subject matter of the document's domain 32 | 33 | # Domain Alignment Guidelines 34 | Before generating questions: 35 | 1. Identify the primary domain (e.g., programming, medicine, economics) 36 | 2. Extract key concepts and principles from the document 37 | 3. List common problem-solving patterns in this domain 38 | 39 | When crafting questions: 40 | 1. Frame scenarios using domain-specific contexts 41 | 2. Incorporate relevant technical terminology naturally 42 | 3. Focus on problem-solving approaches typical to the field 43 | 4. Connect theoretical concepts to practical applications within the domain 44 | 45 | After generating the questions step by step, reformat all questions including the corresponding scenarios in JSON with key "hard_query": 46 | ```json 47 | {{ 48 | "hard_query": [ Q1, Q2, Q3, ...] 49 | }} 50 | ``` 51 | ''' 52 | 53 | 54 | PROMPT_COT_BRIGHT = '''(1) Identify the essential problem in the post. 55 | (2) Think step by step to reason about what should be included in the relevant documents. 56 | (3) Draft an answer.''' 57 | 58 | 59 | def get_user_prompt_cot_bright(query, output_token_limit=128): 60 | cur_post = query.replace('\n', ' ') 61 | prompt = (f'{cur_post}\n\n' 62 | f'Instructions:\n' 63 | f'1. Identify the essential problem.\n' 64 | f'2. 
Think step by step to reason and describe what information could be relevant and helpful to address the questions in detail.\n' 65 | # f'Your answer must be written within {output_token_limit} tokens.' 66 | f'3. Draft an answer with as many thoughts as you have.\n' 67 | ) 68 | return prompt 69 | 70 | 71 | def fill_sys_prompt(prompt, queries_per_doc=1): 72 | return prompt.format(num_questions=queries_per_doc) 73 | 74 | 75 | def fill_user_prompt(doc): 76 | user_prompt = '''The document is given below: 77 | 78 | 79 | {document} 80 | 81 | 82 | Please start generating the questions.''' 83 | return user_prompt.format(document=doc) 84 | 85 | prompt_registry = { 86 | 'baseline': DOC2QUERY_BASELINE, 87 | 'hq_gen': DOC2HARD_QUERY, 88 | 'cot_bright': PROMPT_COT_BRIGHT, 89 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReasonIR 2 | 3 | ReasonIR-8B is the first retriever specifically trained for general reasoning tasks, achieving the state-of-the-art retrieval performance on BRIGHT (reasoning-intensive retrieval). 4 | When employed for retrieval-augmented generation (RAG), ReasonIR-8B also brings substantial gains on MMLU and GPQA. 5 | 6 | - Paper: https://arxiv.org/abs/2504.20595 7 | - Model: https://huggingface.co/reasonir/ReasonIR-8B 8 | - Data: https://huggingface.co/datasets/reasonir/reasonir-data 9 | 10 | 11 | ## General Usage 12 | Make sure to install `transformers>=4.47.0` first! 13 | 14 | ### Transformers 15 | 16 | ```python 17 | from transformers import AutoModel, AutoTokenizer 18 | model = AutoModel.from_pretrained("reasonir/ReasonIR-8B", torch_dtype="auto", trust_remote_code=True) 19 | 20 | query = "The quick brown fox jumps over the lazy dog." 21 | document = "The quick brown fox jumps over the lazy dog." 22 | query_instruction = "" 23 | doc_instruction = "" 24 | model = model.to("cuda") 25 | model.eval() 26 | query_emb = model.encode(query, instruction=query_instruction) 27 | doc_emb = model.encode(document, instruction=doc_instruction) 28 | sim = query_emb @ doc_emb.T 29 | ``` 30 | 31 | When using `AutoModel`, it is important to: 32 | 33 | 1. Include `trust_remote_code=True` to make sure our custom bidirectional encoding architecture is used. 34 | 2. Use `torch_dtype="auto"` so that `bf16` is activated (by default torch will use `fp32`). 35 | 36 | ### Sentence Transformers 37 | 38 | Ordinary retrieval models that use mean pooling can automatically be used with SentenceTransformer after being published on Huggingface. 39 | 40 | ```python 41 | from sentence_transformers import SentenceTransformer 42 | model_kwargs = {"torch_dtype": "auto"} 43 | model = SentenceTransformer("reasonir/ReasonIR-8B", trust_remote_code=True, model_kwargs=model_kwargs) 44 | model.set_pooling_include_prompt(include_prompt=False) # exclude the prompt during pooling 45 | 46 | query = "The quick brown fox jumps over the lazy dog." 
47 | document = "The quick brown fox jumps over the lazy dog." 48 | query_instruction = "" 49 | doc_instruction = "" 50 | query_emb = model.encode(query, instruction=query_instruction) 51 | doc_emb = model.encode(document, instruction=doc_instruction) 52 | sim = query_emb @ doc_emb.T 53 | ``` 54 | 55 | It is important to also include `trust_remote_code=True` and `torch_dtype="auto"` as discussed earlier. 56 | 57 | NOTE: there seems to be some very slight floating point discrepancy when using the SentenceTransformer (because it does not support bf16 precision), though it should not affect the results in general. 58 | 59 | ## Evaluations 60 | Please refer to the instructions in [`evaluation/`](https://github.com/facebookresearch/ReasonIR/tree/main/evaluation). 61 | 62 | ## Synthetic Data Generation 63 | Please refer to the instructions in [`synthetic_data_generation/`](https://github.com/facebookresearch/ReasonIR/tree/main/synthetic_data_generation). 64 | 65 | ## Test Time Scaling Techniques 66 | Please refer to the instructions in [`test_time_techniques/`](https://github.com/facebookresearch/ReasonIR/tree/main/test_time_techniques). 67 | 68 | ## Retriever Training 69 | Please refer to the instructions in [`training/`](https://github.com/facebookresearch/ReasonIR/tree/main/training). 70 | 71 | ## Citation 72 | ``` 73 | @article{shao2025reasonir, 74 | title={ReasonIR: Training Retrievers for Reasoning Tasks}, 75 | author={Rulin Shao and Rui Qiao and Varsha Kishore and Niklas Muennighoff and Xi Victoria Lin and Daniela Rus and Bryan Kian Hsiang Low and Sewon Min and Wen-tau Yih and Pang Wei Koh and Luke Zettlemoyer}, 76 | year={2025}, 77 | journal={arXiv preprint arXiv:2504.20595}, 78 | url={https://arxiv.org/abs/2504.20595}, 79 | } 80 | ``` 81 | 82 | ## License 83 | ReasonIR is FAIR Noncommercial Research License licensed, as found in the LICENSE file. 84 | 85 | ## Acknowledgments 86 | We thank the following great open-source repositories: 87 | - [BRIGHT](https://github.com/xlang-ai/BRIGHT) 88 | - [GritLM](https://github.com/ContextualAI/gritlm) 89 | - [MassiveDS](https://github.com/RulinShao/retrieval-scaling) 90 | -------------------------------------------------------------------------------- /evaluation/rag/datastore/construct_datastore_corpus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pdb 4 | 5 | 6 | 7 | def load_jsonl(path): 8 | data = [] 9 | with open(path, 'r') as fin: 10 | for line in fin: 11 | data.append(json.loads(line)) 12 | return data 13 | 14 | def write_to_jsonl(data, path): 15 | with open(path, 'w') as fout: 16 | for ex in data: 17 | fout.write(json.dumps(ex) + '\n') 18 | 19 | 20 | def construct_datastore_pool(input_dir, output_dir, top_k=float('inf')): 21 | """ 22 | (Do not use this function which will save duplicated documents.) 23 | Construct a new corpus of the retrieved documents by Contriever for MMLU. 24 | Use this data pool to test out complex retrieval pipeline such as CoT, IRCoT. 
25 | """ 26 | 27 | os.makedirs(output_dir, exist_ok=True) 28 | 29 | for i, filename in enumerate(os.listdir(input_dir)): 30 | input_file = os.path.join(input_dir, filename) 31 | output_file = os.path.join(output_dir, filename.replace('_retrieved_results.jsonl', '.jsonl')) 32 | print(f"{i}: Processing {output_file}") 33 | 34 | if os.path.exists(output_file): 35 | continue 36 | 37 | data = load_jsonl(input_file) 38 | 39 | all_retrieved_documents = [] 40 | for idx, ex in enumerate(data): 41 | K = min(top_k, len(ex['ctxs'])) 42 | retrieved_documents = [{'id': idx, 'source': ctx['source'], 'text': ctx['retrieval text']} for ctx in ex['ctxs'][:K]] 43 | all_retrieved_documents.extend(retrieved_documents) 44 | 45 | write_to_jsonl(all_retrieved_documents, output_file) 46 | 47 | 48 | def construct_deduplicated_datastore_pool(input_dir, output_dir, top_k=float('inf')): 49 | """ 50 | Construct a new corpus of the retrieved documents by Contriever for MMLU. 51 | Use this data pool to test out complex retrieval pipeline such as CoT, IRCoT. 52 | Conduct deduplication on the raw data. 53 | """ 54 | 55 | os.makedirs(output_dir, exist_ok=True) 56 | output_file = os.path.join(output_dir, 'mmlu_datastore_pool_dedupped.jsonl') 57 | if os.path.exists(output_file): 58 | print(f"Found prebuilt {output_file}") 59 | return 60 | 61 | unique_documents = set() 62 | all_retrieved_documents = [] 63 | idx = 0 64 | 65 | for i, filename in enumerate(os.listdir(input_dir)): 66 | input_file = os.path.join(input_dir, filename) 67 | 68 | data = load_jsonl(input_file) 69 | 70 | for ex in data: 71 | K = min(top_k, len(ex['ctxs'])) 72 | for ctx in ex['ctxs'][:K]: 73 | retrieval_text = ctx['retrieval text'] 74 | if retrieval_text not in unique_documents: 75 | all_retrieved_documents.append({'id': idx, 'source': ctx['source'], 'text': retrieval_text}) 76 | unique_documents.add(retrieval_text) 77 | idx += 1 78 | else: 79 | print(f"The document is in the pool. Skipped.") 80 | 81 | print(f"{i}: Proccessed {input_file}, {idx} passages added.") 82 | 83 | write_to_jsonl(all_retrieved_documents, output_file) 84 | 85 | 86 | def construct_deduplicated_datastore_pool_from_api_searched_results(input_dir, output_dir): 87 | """ 88 | Construct a new corpus of the retrieved documents by Contriever API for GPQA. 89 | Use this data pool to test out complex retrieval pipeline such as CoT, IRCoT. 90 | Conduct deduplication on the raw data. 91 | """ 92 | 93 | os.makedirs(output_dir, exist_ok=True) 94 | output_file = os.path.join(output_dir, 'gpqa_datastore_pool_dedupped.jsonl') 95 | if os.path.exists(output_file): 96 | print(f"Found prebuilt {output_file}") 97 | return 98 | 99 | unique_documents = set() 100 | all_retrieved_documents = [] 101 | idx = 0 102 | 103 | for i, filename in enumerate(os.listdir(input_dir)): 104 | input_file = os.path.join(input_dir, filename) 105 | 106 | data = load_jsonl(input_file) 107 | 108 | for ex in data: 109 | results = ex['results']['results'] 110 | for retrieval_text in results["passages"][0]: 111 | if retrieval_text not in unique_documents: 112 | assert isinstance(retrieval_text, str) 113 | all_retrieved_documents.append({'id': idx, 'source': 'MassiveDS API', 'text': retrieval_text}) 114 | unique_documents.add(retrieval_text) 115 | idx += 1 116 | else: 117 | print(f"The document is in the pool. 
Skipped.") 118 | 119 | print(f"{i}: Proccessed {input_file}, {idx} passages added.") 120 | 121 | write_to_jsonl(all_retrieved_documents, output_file) 122 | -------------------------------------------------------------------------------- /synthetic_data_generation/hard_negative_mining.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from pyserini import analysis 3 | from gensim.corpora import Dictionary 4 | from gensim.models import LuceneBM25Model 5 | from gensim.similarities import SparseMatrixSimilarity 6 | import pdb 7 | import json 8 | 9 | 10 | class BM25_Miner(): 11 | def __init__(self, documents=None, doc_ids=None, task=None, long_context=False, cache_dir='cache', dataset='bright', data_path=None): 12 | assert (documents is not None and doc_ids is not None) or (dataset in ['bright'] and task is not None) or dataset in ['msmarco'] 13 | if documents is None: 14 | if dataset == 'msmarco': 15 | documents, doc_ids = self.get_ms_marco_documents(cache_dir) 16 | elif dataset == 'bright': 17 | documents, doc_ids = self.get_bright_documents(task, long_context, cache_dir) 18 | else: 19 | raise ValueError("Invalid dataset") 20 | self.dataset = dataset 21 | self.task = task if dataset in ['bright'] else None 22 | self.documents = documents 23 | self.doc_ids = doc_ids 24 | self.hashed_documents = self.get_hashed_documents(documents, doc_ids) 25 | self.analyzer = analysis.Analyzer(analysis.get_lucene_analyzer()) 26 | corpus = [self.analyzer.analyze(x) for x in documents] 27 | self.dictionary = Dictionary(corpus) 28 | self.model = LuceneBM25Model(dictionary=self.dictionary, k1=0.9, b=0.4) 29 | bm25_corpus = self.model[list(map(self.dictionary.doc2bow, corpus))] 30 | self.bm25_index = SparseMatrixSimilarity(bm25_corpus, num_docs=len(corpus), num_terms=len(self.dictionary), 31 | normalize_queries=False, normalize_documents=False) 32 | 33 | def search(self, query): 34 | query = self.analyzer.analyze(query) 35 | bm25_query = self.model[self.dictionary.doc2bow(query)] 36 | similarities = self.bm25_index[bm25_query].tolist() 37 | all_scores = {} 38 | for did, s in zip(self.doc_ids, similarities): 39 | all_scores[did] = s 40 | cur_scores = sorted(all_scores.items(),key=lambda x:x[1],reverse=True)[:1000] 41 | all_scores = {} 42 | for pair in cur_scores: 43 | all_scores[pair[0]] = pair[1] 44 | return all_scores 45 | 46 | def get_ms_marco_documents(self, cache_dir): 47 | dataset = load_dataset("microsoft/ms_marco", "v1.1") 48 | # dataset = load_dataset("microsoft/ms_marco", "v2.1") 49 | doc_ids = [] 50 | documents = [] 51 | max_length = 0 52 | for dp in dataset['train']: 53 | passages = dp['passages']['passage_text'] 54 | doc = ' '.join(passages) 55 | # chunk the document into 2000 words 56 | doc_split = doc.split() 57 | max_length = max(max_length, len(doc_split)) 58 | doc = ' '.join(doc_split[:2000]) 59 | documents.append(doc) 60 | doc_ids.append(str(dp['query_id'])) 61 | print(f"Max length of all documents: {max_length}") 62 | return documents, doc_ids 63 | 64 | def get_bright_documents(self, task, long_context, cache_dir): 65 | if long_context: 66 | doc_pairs = load_dataset('xlangai/BRIGHT', 'long_documents', cache_dir=cache_dir)[task] 67 | else: 68 | doc_pairs = load_dataset('xlangai/BRIGHT', 'documents', cache_dir=cache_dir)[task] 69 | 70 | doc_ids = [] 71 | documents = [] 72 | for dp in doc_pairs: 73 | doc_ids.append(str(dp['id'])) 74 | documents.append(dp['content']) 75 | return documents, doc_ids 76 | 77 | def 
get_hashed_documents(self, documents, doc_ids): 78 | hashed_documents = {} 79 | for docid, doc in zip(doc_ids, documents): 80 | hashed_documents[docid] = doc 81 | return hashed_documents 82 | 83 | 84 | def get_documents_text(self, docids): 85 | return [self.hashed_documents[docid] for docid in docids] 86 | 87 | def select_hard_negatives(self, query, gold_doc, num_neg=1, hard_neg_start_index=20): 88 | scores = self.search(query) 89 | 90 | num_added = 0 91 | hard_negatives_ids = [] 92 | sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) 93 | for i, (doc_id, score) in enumerate(sorted_scores): 94 | if i >= hard_neg_start_index: 95 | # avoid selecting false negative 96 | if self.hashed_documents[doc_id] != gold_doc: 97 | hard_negatives_ids.append(doc_id) 98 | num_added += 1 99 | if num_added == num_neg: 100 | break 101 | 102 | hard_negative_documents = self.get_documents_text(hard_negatives_ids) 103 | return hard_negative_documents 104 | -------------------------------------------------------------------------------- /evaluation/rag/gpqa/src/utils/math_equivalence.py: -------------------------------------------------------------------------------- 1 | def _fix_fracs(string): 2 | substrs = string.split("\\frac") 3 | new_str = substrs[0] 4 | if len(substrs) > 1: 5 | substrs = substrs[1:] 6 | for substr in substrs: 7 | new_str += "\\frac" 8 | if substr[0] == "{": 9 | new_str += substr 10 | else: 11 | try: 12 | assert len(substr) >= 2 13 | except: 14 | return string 15 | a = substr[0] 16 | b = substr[1] 17 | if b != "{": 18 | if len(substr) > 2: 19 | post_substr = substr[2:] 20 | new_str += "{" + a + "}{" + b + "}" + post_substr 21 | else: 22 | new_str += "{" + a + "}{" + b + "}" 23 | else: 24 | if len(substr) > 2: 25 | post_substr = substr[2:] 26 | new_str += "{" + a + "}" + b + post_substr 27 | else: 28 | new_str += "{" + a + "}" + b 29 | string = new_str 30 | return string 31 | 32 | def _fix_a_slash_b(string): 33 | if len(string.split("/")) != 2: 34 | return string 35 | a = string.split("/")[0] 36 | b = string.split("/")[1] 37 | try: 38 | a = int(a) 39 | b = int(b) 40 | assert string == "{}/{}".format(a, b) 41 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 42 | return new_string 43 | except: 44 | return string 45 | 46 | def _remove_right_units(string): 47 | # "\\text{ " only ever occurs (at least in the val set) when describing units 48 | if "\\text{ " in string: 49 | splits = string.split("\\text{ ") 50 | assert len(splits) == 2 51 | return splits[0] 52 | else: 53 | return string 54 | 55 | def _fix_sqrt(string): 56 | if "\\sqrt" not in string: 57 | return string 58 | splits = string.split("\\sqrt") 59 | new_string = splits[0] 60 | for split in splits[1:]: 61 | if split[0] != "{": 62 | a = split[0] 63 | new_substr = "\\sqrt{" + a + "}" + split[1:] 64 | else: 65 | new_substr = "\\sqrt" + split 66 | new_string += new_substr 67 | return new_string 68 | 69 | def _strip_string(string): 70 | # linebreaks 71 | string = string.replace("\n", "") 72 | #print(string) 73 | 74 | # remove inverse spaces 75 | string = string.replace("\\!", "") 76 | #print(string) 77 | 78 | # replace \\ with \ 79 | string = string.replace("\\\\", "\\") 80 | #print(string) 81 | 82 | # replace tfrac and dfrac with frac 83 | string = string.replace("tfrac", "frac") 84 | string = string.replace("dfrac", "frac") 85 | #print(string) 86 | 87 | # remove \left and \right 88 | string = string.replace("\\left", "") 89 | string = string.replace("\\right", "") 90 | #print(string) 91 | 92 | # Remove circ 
(degrees) 93 | string = string.replace("^{\\circ}", "") 94 | string = string.replace("^\\circ", "") 95 | 96 | # remove dollar signs 97 | string = string.replace("\\$", "") 98 | 99 | # remove units (on the right) 100 | string = _remove_right_units(string) 101 | 102 | # remove percentage 103 | string = string.replace("\\%", "") 104 | string = string.replace("\%", "") 105 | 106 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 107 | string = string.replace(" .", " 0.") 108 | string = string.replace("{.", "{0.") 109 | # if empty, return empty string 110 | if len(string) == 0: 111 | return string 112 | if string[0] == ".": 113 | string = "0" + string 114 | 115 | # to consider: get rid of e.g. "k = " or "q = " at beginning 116 | if len(string.split("=")) == 2: 117 | if len(string.split("=")[0]) <= 2: 118 | string = string.split("=")[1] 119 | 120 | # fix sqrt3 --> sqrt{3} 121 | string = _fix_sqrt(string) 122 | 123 | # remove spaces 124 | string = string.replace(" ", "") 125 | 126 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 127 | string = _fix_fracs(string) 128 | 129 | # manually change 0.5 --> \frac{1}{2} 130 | if string == "0.5": 131 | string = "\\frac{1}{2}" 132 | 133 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 134 | string = _fix_a_slash_b(string) 135 | 136 | return string 137 | 138 | def is_equiv(str1, str2, verbose=False): 139 | if str1 is None and str2 is None: 140 | print("WARNING: Both None") 141 | return True 142 | if str1 is None or str2 is None: 143 | return False 144 | 145 | try: 146 | ss1 = _strip_string(str1) 147 | ss2 = _strip_string(str2) 148 | if verbose: 149 | print(ss1, ss2) 150 | return ss1 == ss2 151 | except: 152 | return str1 == str2 -------------------------------------------------------------------------------- /evaluation/rag/gpqa/src/utils/hydra_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import os 4 | import sys 5 | from typing import Any, Callable, Optional 6 | 7 | from hydra._internal.utils import _run_hydra, get_args_parser 8 | from hydra.core.config_store import ConfigStore 9 | from hydra.types import TaskFunction 10 | from omegaconf import DictConfig, OmegaConf 11 | 12 | 13 | def _get_gpu_name(): 14 | try: 15 | import pynvml 16 | except (ImportError, ModuleNotFoundError): 17 | return None 18 | 19 | pynvml.nvmlInit() 20 | handle = pynvml.nvmlDeviceGetHandleByIndex(0) 21 | cuda_capability, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle) 22 | pynvml.nvmlShutdown() 23 | if cuda_capability == 8: 24 | return "a100" 25 | elif cuda_capability == 9: 26 | return "h100" 27 | else: 28 | return None 29 | 30 | 31 | OmegaConf.register_new_resolver("gpu_name", _get_gpu_name) 32 | 33 | # multiple interpolated values in the config 34 | OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) 35 | 36 | 37 | def hydra_runner( 38 | config_path: Optional[str] = ".", config_name: Optional[str] = None, schema: Optional[Any] = None 39 | ) -> Callable[[TaskFunction], Any]: 40 | """ 41 | Decorator used for passing the Config paths to main function. 42 | Optionally registers a schema used for validation/providing default values. 43 | 44 | Args: 45 | config_path: Optional path that will be added to config search directory. 
46 | NOTE: The default value of `config_path` has changed between Hydra 1.0 and Hydra 1.1+. 47 | Please refer to https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path/ 48 | for details. 49 | config_name: Pathname of the config file. 50 | schema: Structured config type representing the schema used for validation/providing default values. 51 | """ 52 | 53 | def decorator(task_function: TaskFunction) -> Callable[[], None]: 54 | @functools.wraps(task_function) 55 | def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any: 56 | # Check it config was passed. 57 | if cfg_passthrough is not None: 58 | return task_function(cfg_passthrough) 59 | else: 60 | args = get_args_parser() 61 | 62 | # Parse arguments in order to retrieve overrides 63 | parsed_args = args.parse_args() # type: argparse.Namespace 64 | 65 | # Get overriding args in dot string format 66 | overrides = parsed_args.overrides # type: list 67 | 68 | # Disable the creation of .hydra subdir 69 | # https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory 70 | overrides.append("hydra.output_subdir=null") 71 | # Hydra logging outputs only to stdout (no log file). 72 | # https://hydra.cc/docs/configure_hydra/logging 73 | overrides.append("hydra/job_logging=stdout") 74 | 75 | # Set run.dir ONLY for ExpManager "compatibility" - to be removed. 76 | overrides.append("hydra.run.dir=.") 77 | 78 | # Check if user set the schema. 79 | if schema is not None: 80 | # Create config store. 81 | cs = ConfigStore.instance() 82 | 83 | # Get the correct ConfigStore "path name" to "inject" the schema. 84 | if parsed_args.config_name is not None: 85 | path, name = os.path.split(parsed_args.config_name) 86 | # Make sure the path is not set - as this will disable validation scheme. 87 | if path != '': 88 | sys.stderr.write( 89 | f"ERROR Cannot set config file path using `--config-name` when " 90 | "using schema. Please set path using `--config-path` and file name using " 91 | "`--config-name` separately.\n" 92 | ) 93 | sys.exit(1) 94 | else: 95 | name = config_name 96 | 97 | # Register the configuration as a node under the name in the group. 98 | cs.store(name=name, node=schema) # group=group, 99 | 100 | # Wrap a callable object with name `parse_args` 101 | # This is to mimic the ArgParser.parse_args() API. 
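# Returning the already-parsed namespace below means Hydra's subsequent call to
# args.parse_args() does not re-parse sys.argv, so the overrides appended above
# (hydra.output_subdir, hydra/job_logging, hydra.run.dir) are preserved.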
102 | def parse_args(self, args=None, namespace=None): 103 | return parsed_args 104 | 105 | parsed_args.parse_args = parse_args 106 | 107 | # no return value from run_hydra() as it may sometime actually run the task_function 108 | # multiple times (--multirun) 109 | # argparse_wrapper = _argparse_wrapper(args) 110 | argparse_wrapper = parsed_args 111 | 112 | _run_hydra( 113 | args=argparse_wrapper, 114 | args_parser=args, 115 | task_function=task_function, 116 | config_path=config_path, 117 | config_name=config_name, 118 | ) 119 | 120 | return wrapper 121 | 122 | return decorator -------------------------------------------------------------------------------- /evaluation/rag/gpqa/src/workflow/naive_rag.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | import re 4 | from dataclasses import dataclass, field 5 | 6 | from src.utils.cache_utils import Cache 7 | from apis.base import ( 8 | search_api, 9 | format_document_string, 10 | ) 11 | from apis.offline_cached_massiveds_searched_results import search_offline_cached_massiveds 12 | from src.prompts.task_instructions import ( 13 | get_task_user_prompt 14 | ) 15 | 16 | 17 | @dataclass 18 | class NaiveRAGOutput: 19 | all_relevant_info: None 20 | input_prompts: None 21 | output_list: None 22 | total_time: None 23 | 24 | 25 | class NaiveRAG: 26 | def __init__(self, llm, cfg, search_llm=None): 27 | self.cfg = cfg 28 | self.llm = llm 29 | self.search_llm = search_llm 30 | self.cache = Cache(cfg) 31 | 32 | def run_search(self, data): 33 | print("Performing Bing Web Searches for all questions...") 34 | 35 | # Initialize a list to hold relevant information for each question 36 | all_relevant_info = [] 37 | 38 | if self.cfg.search_engine == 'offline_massiveds': 39 | all_relevant_info = search_offline_cached_massiveds(data, self.cfg.top_k) 40 | else: 41 | for item in tqdm(data, desc="Searching"): 42 | question = item['Question'] 43 | # Check if the question has already been searched and cached 44 | if self.cache and question in self.cache.search_cache: 45 | results = self.cache.search_cache[question] 46 | else: 47 | search_question = question 48 | results, new_cache = search_api( 49 | self.cfg.search_engine, 50 | search_question, 51 | self.llm if self.search_llm is None else self.search_llm, 52 | self.cfg.model_path, 53 | self.cfg.use_query_rewriting, 54 | self.cache, 55 | ) 56 | self.cache = new_cache 57 | # self.cache.search_cache[question] = results['searched_results'] 58 | all_relevant_info.append(results) 59 | 60 | # Save search cache after retrieval 61 | self.cache.save_caches() 62 | print("Search cache saved.") 63 | return all_relevant_info 64 | 65 | def get_naive_rag_instruction(self, question, documents): 66 | return ( 67 | "You are a knowledgeable assistant that uses the provided documents to answer the user's question.\n\n" 68 | "Question:\n" 69 | f"{question}\n" 70 | "Documents:\n" 71 | f"{documents}\n" 72 | ) 73 | 74 | def prepare_full_context(self, data, all_relevant_info): 75 | print("Constructing prompts for generation...") 76 | input_prompts = [] 77 | 78 | # Set max input tokens to leave reasonable room for output while maximizing context usage 79 | # Model context length is 32,768 tokens 80 | # Reserve ~8k tokens for output, leaving ~24k for input 81 | max_input_tokens = 24000 82 | 83 | for idx, item in enumerate(tqdm(data, desc="Constructing Prompts")): 84 | question = item['Question'] 85 | formatted_documents = 
format_document_string(self.cfg.search_engine, all_relevant_info[idx], self.cfg.top_k, max_doc_len=max_input_tokens) # Increased document length 86 | 87 | instruction = self.get_naive_rag_instruction(question, formatted_documents) 88 | user_prompt = get_task_user_prompt(self.cfg.dataset_name, self.cfg.model_path, question) 89 | full_prompt = instruction + "\n\n" + user_prompt 90 | 91 | input_prompts.append(full_prompt) 92 | 93 | return input_prompts 94 | 95 | def llm_generate(self, input_prompts): 96 | # Generate model outputs 97 | print("Generating answers with LLM...") 98 | output_list = [] 99 | for input_prompt in tqdm(input_prompts): 100 | try: 101 | chat_completion = self.llm.chat.completions.create( 102 | model=self.cfg.model_path, 103 | messages=[{"role": "system", "content": "You are a helpful assistant."}, 104 | {"role": "user", "content": input_prompt}], 105 | max_tokens=8000, # Increased to allow for longer outputs while staying within context limit 106 | temperature=self.cfg.temperature, 107 | top_p=self.cfg.top_p, 108 | frequency_penalty=self.cfg.repetition_penalty, 109 | ) 110 | output_list.append(chat_completion.choices[0].message.content) 111 | except Exception as e: 112 | print(f"Error during generation: {str(e)}") 113 | output_list.append("Error: Failed to generate response due to context length.") 114 | continue 115 | 116 | return output_list 117 | 118 | def run(self, data): 119 | 120 | all_relevant_info = self.run_search(data) 121 | input_prompts = self.prepare_full_context(data, all_relevant_info) 122 | 123 | start_time = time.time() 124 | output_list = self.llm_generate(input_prompts) 125 | total_time = time.time() - start_time 126 | 127 | return NaiveRAGOutput( 128 | all_relevant_info=all_relevant_info, 129 | input_prompts=input_prompts, 130 | output_list=output_list, 131 | total_time=total_time, 132 | ) -------------------------------------------------------------------------------- /synthetic_data_generation/lm_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | class OpenAILM(): 5 | def __init__(self, model_id, temperature=1, top_p=1, seed=None, base_url=None, api_key=None) -> None: 6 | from openai import OpenAI 7 | if api_key is None: 8 | OPENAI_KEY = os.environ["OPENAI_API_KEY"] 9 | else: 10 | OPENAI_KEY = api_key 11 | self.client = OpenAI(api_key=OPENAI_KEY, base_url=base_url) 12 | self.model_id = model_id 13 | self.temperature = temperature 14 | self.top_p = top_p 15 | self.seed = seed 16 | 17 | def generate(self, prompt, system_prompt=None): 18 | messages = self.apply_chat_template(prompt, system_prompt) 19 | output = self.client.chat.completions.create( 20 | model=self.model_id, # e.g., "gpt-4o-mini" with 128k context length 21 | messages=messages, 22 | temperature=self.temperature, 23 | top_p=self.top_p, 24 | seed=self.seed, 25 | ) 26 | self.log_output(output) 27 | response = output.choices[0].message.content 28 | return response 29 | 30 | def generate_batch(self, prompts, system_prompt=None): 31 | responses = [] 32 | for prompt in prompts: 33 | resp = self.generate(prompt, system_prompt) 34 | responses.append(resp) 35 | return responses 36 | 37 | def apply_chat_template(self, user_prompt, system_prompt): 38 | if system_prompt is not None: 39 | messages=[ 40 | {"role": "system", "content": system_prompt}, 41 | {"role": "user", "content": user_prompt}, 42 | ] 43 | else: 44 | messages = [ 45 | {"role": "system", "content": ""}, 46 | {"role": "user", "content": user_prompt}, 47 | ] 
48 | return messages 49 | 50 | def log_output(self, output): 51 | print(output.choices[0].message.content) 52 | print(output.usage) 53 | 54 | 55 | class MultiTurnOpenAILM(OpenAILM): 56 | def __init__(self, model_id, temperature=1, top_p=1, seed=None) -> None: 57 | super().__init__(model_id, temperature=temperature, top_p=top_p, seed=seed) 58 | 59 | def generate(self, messages): 60 | output = self.client.chat.completions.create( 61 | model=self.model_id, # e.g., "gpt-4o-mini" with 128k context length 62 | messages=messages, 63 | temperature=self.temperature, 64 | top_p=self.top_p, 65 | seed=self.seed, 66 | ) 67 | self.log_output(output) 68 | response = output.choices[0].message.content 69 | return response 70 | 71 | def apply_chat_template(self, user_prompt, system_prompt=None): 72 | if system_prompt is None: 73 | system_prompt = "" 74 | messages=[ 75 | self.format_prompt(system_prompt, role="system"), 76 | self.format_prompt(user_prompt, role="user"), 77 | ] 78 | return messages 79 | 80 | def append_history(self, messages, prompt, role="user"): 81 | messages.append(self.format_prompt(prompt, role)) 82 | return messages 83 | 84 | def format_prompt(self, prompt, role="user"): 85 | if role == "user": 86 | return {"role": "user", "content": prompt} 87 | elif role == "system": 88 | return {"role": "system", "content": prompt} 89 | elif role == "assistant": 90 | return {"role": "assistant", "content": prompt} 91 | else: 92 | raise ValueError(f"Invalid role: {role}") 93 | 94 | 95 | class HFLM(): 96 | def __init__(self, model_id, temperature=0, top_p=1e-9, deterministic=False) -> None: 97 | import torch 98 | from vllm import LLM, SamplingParams 99 | from transformers import AutoTokenizer 100 | 101 | self.model = LLM( 102 | model_id, 103 | tensor_parallel_size=torch.cuda.device_count(), 104 | enable_chunked_prefill=False, 105 | gpu_memory_utilization=0.96, 106 | max_model_len=4096, 107 | ) 108 | 109 | if deterministic: 110 | self.sampling_params = SamplingParams( 111 | max_tokens=1536, 112 | temperature=0, 113 | top_p=1e-9, 114 | do_sample=False, 115 | num_beams=1, 116 | ) 117 | else: 118 | self.sampling_params = SamplingParams( 119 | max_tokens=1536, 120 | temperature=temperature, 121 | top_p=max(top_p, 1e-9), 122 | ) 123 | 124 | self.tokenizer = AutoTokenizer.from_pretrained(model_id) 125 | 126 | def generate(self, prompt, system_prompt=None): 127 | _prompt = prompt 128 | if system_prompt is not None: 129 | _prompt = self.apply_chat_template(system_prompt, prompt) 130 | output = self.model.generate(_prompt, self.sampling_params) 131 | 132 | self.log_output(output) 133 | return output[0].outputs[0].text 134 | 135 | def generate_batch(self, prompts, system_prompt=None): 136 | responses = [] 137 | for prompt in prompts: 138 | response = self.generate(prompt, system_prompt) 139 | responses.append(response) 140 | return responses 141 | 142 | def apply_chat_template(self, system_prompt, user_prompt): 143 | message = [ 144 | {"role": "system", "content": system_prompt}, 145 | {"role": "user", "content": user_prompt}, 146 | ] 147 | 148 | prompt = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False) 149 | return prompt 150 | 151 | def log_output(self, output): 152 | print(output[0].outputs[0].text) 153 | print(f"Number of input tokens: {len(output[0].prompt_token_ids)}") 154 | print(f"Number of output tokens: {len(output[0].outputs[0].token_ids)}") 155 | 156 | -------------------------------------------------------------------------------- 
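The generation scripts in this directory (for example, `generate_reasoning_batch.py` below) combine these LM helpers with the prompts in `data_gen_prompts.py`. The following is a minimal, illustrative sketch of that pattern for hard-query generation rather than the exact production pipeline; it assumes `OPENAI_API_KEY` is set in the environment, and the model id and document text are placeholders.

```python
# Illustrative doc-to-query sketch using the helpers defined above.
# Assumes OPENAI_API_KEY is set; the model id and document are placeholders.
from lm_helper import OpenAILM
from data_gen_prompts import DOC2HARD_QUERY, fill_sys_prompt, fill_user_prompt

model = OpenAILM("gpt-4o", temperature=1.0, top_p=1.0, seed=0)

document = "Photosynthesis converts light energy into chemical energy stored as glucose ..."
system_prompt = fill_sys_prompt(DOC2HARD_QUERY, queries_per_doc=2)  # request 2 questions
user_prompt = fill_user_prompt(document)

# The prompt asks the model to finish with a JSON block keyed by "hard_query";
# downstream code would parse that block out of the raw response text.
response = model.generate(user_prompt, system_prompt=system_prompt)
print(response)
```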
/synthetic_data_generation/batch_api_helper.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import json 3 | from openai import OpenAI 4 | 5 | 6 | def write_jsonl(data, path): 7 | with open(path, 'w') as fout: 8 | for ex in data: 9 | fout.write(json.dumps(ex)+'\n') 10 | 11 | 12 | def format_request_json(messages, model="gpt-4o", custom_id="request-1"): 13 | return {"custom_id": custom_id, 14 | "method": "POST", 15 | "url": "/v1/chat/completions", 16 | "body": { 17 | "model": model, 18 | "messages": messages}} 19 | 20 | 21 | def process_batch_response(response): 22 | outputs = [] 23 | json_data = response.content.decode('utf-8') 24 | for line in json_data.splitlines(): 25 | # Parse the JSON record (line) to validate it 26 | json_record = json.loads(line) 27 | 28 | custom_id = json_record.get("custom_id") 29 | 30 | # Navigate to the 'choices' key within the 'response' -> 'body' 31 | choices = json_record.get("response", {}).get("body", {}).get("choices", []) 32 | 33 | # Loop through the choices to find messages with the 'assistant' role 34 | for choice in choices: 35 | message = choice.get("message", {}) 36 | if message.get("role") == "assistant": 37 | assistant_content = message.get("content") 38 | outputs.append({"id": custom_id, "response": assistant_content}) 39 | break 40 | return outputs 41 | 42 | 43 | class BatchAPIHelper(object): 44 | '''Helper class for calling OpenAI API for batch completions 45 | Usage: 46 | batch_helper = BatchAPIHelper(model_id, task_filename) 47 | batch_request_data = [] 48 | batch_request_ids = [] 49 | batch_helper.batch_request(batch_request_data, batch_request_ids) 50 | # gather the results after the batch job is completed. It will return None if the job is not completed yet. 
51 | outputs = batch_helper.gather_results() 52 | batch_helper.clean_up() 53 | ''' 54 | def __init__(self, model_id, task_filename, batch_id=None): 55 | self.model_id = model_id 56 | self.client = OpenAI() 57 | 58 | self.task_filename = task_filename 59 | self.batch_request_filename = task_filename.replace('.jsonl', '_batch.jsonl') 60 | self.batch_id_filename = self.batch_request_filename.replace('.jsonl', '_batch_id.txt') 61 | self.batch_base_filename = self.batch_request_filename.replace('.jsonl', '_base.jsonl') 62 | 63 | self.batch_id = batch_id 64 | 65 | def gather_results(self): 66 | '''Gather the results from the batch job 67 | Return the outputs if the batch job is completed, otherwise return None 68 | ''' 69 | if self.batch_id is not None: 70 | batch_id = self.batch_id 71 | else: 72 | if not os.path.exists(self.batch_id_filename): 73 | print('Batch id file does not exist') 74 | raise FileNotFoundError 75 | if os.path.exists(self.task_filename): 76 | print('Warning: Results are already collected as the task file already exists') 77 | with open(self.batch_id_filename, 'r') as f: 78 | batch_id = f.read().strip() 79 | response = self.client.batches.retrieve(batch_id) 80 | if response.status == 'completed': 81 | response = self.client.files.content(response.output_file_id) 82 | outputs = process_batch_response(response) 83 | return outputs 84 | else: 85 | print('Batch job is not completed yet') 86 | print('Response status:', response.status) 87 | return None 88 | 89 | def batch_request(self, batch_request_data, batch_request_ids=None): 90 | '''the batch request data should be a list of messages 91 | 92 | Example: 93 | batch_request_data = [ 94 | [{"role": "system", "content": "Hello, how are you?"}, 95 | {"role": "user", "content": "I am doing well, thank you."}] 96 | ] 97 | 98 | Provide the batch_request_ids if you want to keep track of the request ids 99 | ''' 100 | if batch_request_ids is not None: 101 | assert len(batch_request_data) == len(batch_request_ids) 102 | batch_request_data = [format_request_json(messages, model=self.model_id, custom_id=custom_id) for messages, custom_id in zip(batch_request_data, batch_request_ids)] 103 | else: 104 | batch_request_data = [format_request_json(messages, model=self.model_id) for messages in batch_request_data] 105 | batch_request_filename = self.batch_request_filename 106 | write_jsonl(batch_request_data, batch_request_filename) 107 | response = self.client.files.create( 108 | file=open(batch_request_filename, "rb"), 109 | purpose="batch" 110 | ) 111 | file_id = response.id 112 | response = self.client.batches.create( 113 | input_file_id=file_id, 114 | endpoint="/v1/chat/completions", 115 | completion_window="24h" 116 | ) 117 | batch_id = response.id 118 | with open(self.batch_id_filename, 'w') as f: 119 | f.write(batch_id) 120 | 121 | def batch_save_base(self, base_data): 122 | '''Save the base data for the batch job 123 | The base data contains information that will be combined with the batch results 124 | ''' 125 | write_jsonl(base_data, self.batch_base_filename) 126 | 127 | def batch_load_base(self): 128 | '''Load the base data for the batch job 129 | The base data contains information that will be combined with the batch results 130 | ''' 131 | with open(self.batch_base_filename, 'r') as f: 132 | base_data = [json.loads(line) for line in f] 133 | return base_data 134 | 135 | def clean_up(self): 136 | if os.path.exists(self.batch_request_filename): 137 | os.remove(self.batch_request_filename) 138 | if 
os.path.exists(self.batch_id_filename): 139 | os.remove(self.batch_id_filename) 140 | if os.path.exists(self.batch_base_filename): 141 | os.remove(self.batch_base_filename) -------------------------------------------------------------------------------- /synthetic_data_generation/generate_reasoning_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | import time 7 | 8 | from hard_negative_mining import BM25_Miner 9 | from data_gen_prompts import * 10 | from gen_utils import * 11 | from data_gen_prompts import * 12 | from lm_helper import MultiTurnOpenAILM, OpenAILM, HFLM 13 | from batch_api_helper import BatchAPIHelper 14 | import re 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--model_id', type=str, default='gpt-4o', help='model id') 20 | parser.add_argument('--prompt_id', type=str, default=None, help='prompt id') 21 | parser.add_argument('--debug', action='store_true', help='debug mode') 22 | parser.add_argument('--base_dir', type=str, default='synthetic_data/top100', help='base directory to save the generated data') 23 | parser.add_argument('--num_docs', type=int, default=None, help='number of samples to collect') 24 | parser.add_argument('--dataset', type=str, default='bright', help='dataset to collect the data for') 25 | parser.add_argument('--subject', type=str, default='biology', help='subject to collect the data for') 26 | parser.add_argument("--temperature", type=float, default=0) 27 | parser.add_argument("--top_p", type=float, default=0) 28 | parser.add_argument("--no_batch", action='store_true', help='disable batch generation') 29 | parser.add_argument("--gather_results", action='store_true', help='gather batch results') 30 | args = parser.parse_args() 31 | print(args) 32 | 33 | # load the queries 34 | model_id = args.model_id 35 | model_id_str = model_id.split('/')[-1] 36 | num_docs = args.num_docs 37 | if args.prompt_id is not None: 38 | train_data_path = os.path.join(args.base_dir, 'all_docs_train_data', args.prompt_id, model_id_str) 39 | output_path = f'{args.base_dir}/reasoning_data/{args.prompt_id}/{model_id_str}' 40 | else: 41 | train_data_path = os.path.join(args.base_dir, 'all_docs_train_data', model_id_str) 42 | output_path = f'{args.base_dir}/reasoning_data/{model_id_str}' 43 | train_data_path = os.path.expanduser(train_data_path) 44 | from gen_utils import load_training_data 45 | subject = args.dataset if args.dataset in ['msmarco'] else args.subject 46 | data = load_training_data(train_data_path, num_docs=num_docs, subject=subject) 47 | 48 | # Initialize the model 49 | if 'gpt' in model_id: 50 | model = OpenAILM(model_id, temperature=args.temperature, top_p=args.top_p, seed=0) 51 | else: 52 | model = HFLM(model_id, temperature=args.temperature, top_p=args.top_p) 53 | 54 | system_prompt = PROMPT_COT_BRIGHT 55 | 56 | print(f'System prompt:\n{system_prompt}') 57 | os.makedirs(output_path, exist_ok=True) 58 | 59 | 60 | if not args.no_batch or args.gather_results: 61 | for subject in data.keys(): 62 | subject_filename = f'{output_path}/{subject}_{num_docs}.jsonl' 63 | subject_data = data[subject] 64 | final_training_data = [] 65 | 66 | batch_helper = BatchAPIHelper(model_id, subject_filename) 67 | if args.gather_results: 68 | outputs = None 69 | while outputs is None: 70 | outputs = batch_helper.gather_results() 71 | if outputs is None: 72 | print("Waiting for the batch job 
to finish...") 73 | time.sleep(60) 74 | 75 | if len(outputs) == len(subject_data): 76 | for output, datum in zip(outputs, subject_data): 77 | response = output['response'] 78 | datum['reasoning'] = response 79 | write_jsonl(subject_data, subject_filename) 80 | else: 81 | print("Warning: Some outputs are missing.") 82 | print("Collecting results using the ids") 83 | outputs_dict = {output['id']: output for output in outputs} 84 | base_data = batch_helper.batch_load_base() 85 | for base_datum in base_data: 86 | query_id = base_datum['id'] 87 | query = base_datum['query'] 88 | if query_id in outputs_dict: 89 | response = outputs_dict[query_id]['response'] 90 | base_datum['reasoning'] = response 91 | else: 92 | print(f"Warning: Missing response for query id: {query_id}") 93 | try: 94 | print("Regenerating reasoning...") 95 | reasoning = model.generate(query, system_prompt=system_prompt) 96 | base_datum['reasoning'] = reasoning 97 | except Exception as e: 98 | print(f"Error: {e}") 99 | base_datum['reasoning'] = None 100 | 101 | base_data_dict = {datum['query']: datum for datum in base_data if datum['reasoning'] is not None} 102 | for datum in subject_data: 103 | query = datum['query'] 104 | if query in base_data_dict: 105 | datum['reasoning'] = base_data_dict[query]['reasoning'] 106 | final_training_data.append(datum) 107 | write_jsonl(final_training_data, subject_filename) 108 | batch_helper.clean_up() 109 | continue 110 | 111 | batch_request_messages = [] 112 | batch_request_ids = [] 113 | batch_base_data = [] 114 | 115 | for i, datum in enumerate(tqdm(subject_data)): 116 | query = datum['query'] 117 | query_id = str(i) 118 | messages = model.apply_chat_template(query, system_prompt=system_prompt) 119 | batch_request_messages.append(messages) 120 | batch_request_ids.append(query_id) 121 | batch_base_data.append({'id': query_id, 'query': query, 'prompt': system_prompt}) 122 | batch_helper.batch_save_base(batch_base_data) 123 | batch_helper.batch_request(batch_request_messages, batch_request_ids) 124 | else: 125 | for subject in data.keys(): 126 | subject_filename = f'{output_path}/{subject}_{num_docs}.jsonl' 127 | subject_data = data[subject] 128 | for datum in tqdm(subject_data): 129 | query = datum['query'] 130 | reasoning = model.generate(query, system_prompt=system_prompt) 131 | datum['reasoning'] = reasoning 132 | if args.debug: 133 | break 134 | write_jsonl(subject_data, subject_filename) 135 | -------------------------------------------------------------------------------- /synthetic_data_generation/supplement_negative_passage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | import json 5 | import argparse 6 | from tqdm import tqdm 7 | import torch.distributed as dist 8 | from lm_helper import HFLM 9 | 10 | 11 | sample_generation_prompt = """ 12 | You have been assigned a passage generation task: 13 | 14 | You will be provided an incomplete data with the below information 15 | - "input": a string, a random input specified by one task. 16 | - "positive_document": a string, a relevant document for the "input" according to the task. 17 | 18 | Your task is to generate a "hard_negative_document" in a JSON format: 19 | - The "hard_negative_document" contains some relevant information with superficial lexical overlapping, but it should be not helpful to address the question in the input and is less relevant to the input compared with the "positive_document". 
20 | 21 | Please adhere to the following guidelines: 22 | - The values of "hard_negative_document" should be in {language}. 23 | - The "hard_negative_document" should be long documents (at least 300 words), avoid substantial word overlaps, otherwise the task would be too easy. 24 | - The "input", "positive_document", and "hard_negative_document" should be independent of each other. 25 | 26 | Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative! 27 | 28 | Now process the below data following the above instruction: 29 | {input_str} 30 | 31 | Your response: 32 | """ 33 | 34 | 35 | language_options = ['English'] 36 | 37 | 38 | def sample_negative_passage_prompt(input_str): 39 | sample_prompt = sample_generation_prompt.format( 40 | language=random.choice(language_options), 41 | input_str=input_str, 42 | ) 43 | return sample_prompt 44 | 45 | 46 | class NegativePassageGenerator: 47 | def __init__(self, input_file, output_file, num_workers=None, worker_id=None) -> None: 48 | self.num_workers = num_workers 49 | self.worker_id = worker_id 50 | self.model = HFLM() 51 | self.data = self.load_file(input_file) 52 | self.output_file = output_file 53 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 54 | self.hashed_queries = set() 55 | if os.path.exists(output_file): 56 | with open(output_file, 'r') as fin: 57 | for line in fin: 58 | query = json.loads(line)['query'][-1] 59 | if query not in self.hashed_queries: 60 | self.hashed_queries.add(query) 61 | 62 | def load_file(self, input_file): 63 | data = [] 64 | with open(input_file, 'r') as fin: 65 | for line in fin: 66 | data.append(json.loads(line)) 67 | 68 | if self.num_workers is not None and self.worker_id is not None: 69 | shard_size = len(data) // self.num_workers 70 | remainder = len(data) % self.num_workers 71 | start_idx = self.worker_id * shard_size + min(self.worker_id, remainder) 72 | end_idx = start_idx + shard_size + (1 if self.worker_id < remainder else 0) 73 | return data[start_idx:end_idx] 74 | return data 75 | 76 | def prepare_input_string(self, orig_ex): 77 | query_str = orig_ex['query'][0]+orig_ex['query'][1] 78 | pos_str = orig_ex['pos'][0][0]+orig_ex['pos'][0][1] 79 | input_str = "" 80 | input_str += f"'input': {query_str}" 81 | input_str += f"'positive_document': {pos_str}" 82 | return input_str 83 | 84 | def generate_sample(self, orig_ex): 85 | input_str = self.prepare_input_string(orig_ex) 86 | sample_instruction = sample_negative_passage_prompt(input_str) 87 | outputs = self.model.generate(sample_instruction, '') 88 | return outputs 89 | 90 | def parse_generated_sample(self, outputs): 91 | outputs = outputs.replace('```json', '').replace('```', '').strip().replace('\n', '') 92 | outputs = re.sub(r'(\w)"s', r'\1\'s', outputs) 93 | outputs = re.sub(r'"\s+"hard_negative_document"', '", "hard_negative_document"', outputs) 94 | try: 95 | sample = json.loads(outputs) 96 | except: 97 | outputs = outputs.replace("'", '"').replace('\\"', '"').replace("\\'", "'") 98 | sample = json.loads(outputs) 99 | print(sample) 100 | return sample 101 | 102 | def generate(self,): 103 | new_samples = [] 104 | for orig_ex in tqdm(self.data): 105 | if orig_ex['query'][-1] in self.hashed_queries: 106 | print("Skipping one example that has been processed.") 107 | continue 108 | max_attempts = 3 109 | num_attempts = 0 110 | while num_attempts < max_attempts: 111 | num_attempts += 1 112 | try: 113 | sample = self.generate_sample(orig_ex) 114 | sample = self.parse_generated_sample(sample) 115 | 
assert sample['hard_negative_document'] and isinstance(sample['hard_negative_document'], str) 116 | orig_ex['neg'] = [['', sample['hard_negative_document']]] 117 | new_samples.append(orig_ex) 118 | break 119 | except Exception as e: 120 | print(f"An error occured: {e}") 121 | print(sample) 122 | if num_attempts > max_attempts: 123 | break 124 | 125 | if len(new_samples) % 10 == 0: 126 | with open(self.output_file, 'a+') as fout: 127 | for ex in new_samples: 128 | fout.write(json.dumps(ex) + '\n') 129 | new_samples = [] 130 | 131 | if len(new_samples) > 0: 132 | with open(self.output_file, 'a+') as fout: 133 | for ex in outputs: 134 | fout.write(json.dumps(ex) + '\n') 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('--input_file', type=str, required=True) 140 | parser.add_argument('--num_workers', type=int, default=None, help="Total number of workers that will be used to generate negative documents.") 141 | parser.add_argument('--worker_id', type=int, default=None, help="Used to separate the output file to avoid I/O error when writing to the same file at the same time.") 142 | args = parser.parse_args() 143 | 144 | if args.worker_id is None or args.num_workers is None: 145 | output_file = args.input_file.replace('.jsonl', '_generated_negative.jsonl') 146 | else: 147 | output_file = args.input_file.replace('.jsonl', f'_generated_negative_{args.worker_id}_of_{args.num_workers}.jsonl') 148 | 149 | generator = NegativePassageGenerator(args.input_file, output_file, args.num_workers, args.worker_id) 150 | outputs = generator.generate() 151 | 152 | -------------------------------------------------------------------------------- /evaluation/bright/reranker.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | import argparse 6 | from vllm import LLM, SamplingParams 7 | from datasets import load_dataset 8 | import torch 9 | import prompts 10 | 11 | class Reranker: 12 | def __init__(self, task): 13 | model_name = "Qwen/Qwen2.5-32B-Instruct" 14 | self.sampling_params = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=32, logprobs=10) 15 | self.model = LLM(model=model_name, dtype="bfloat16", tensor_parallel_size=torch.cuda.device_count(), max_model_len=16384) 16 | 17 | retrieval_dict = { 18 | "aops": prompts.bright_aops, 19 | "theoremqa_questions": prompts.bright_theoremqa_questions, 20 | "leetcode": prompts.bright_leetcode, 21 | "pony": prompts.bright_pony, 22 | "theoremqa_theorems": prompts.bright_theoremqa_theorems 23 | } 24 | if args.task in retrieval_dict: 25 | self.prompt = retrieval_dict[task] 26 | else: 27 | self.prompt = prompts.bright_general 28 | 29 | # set random seed 30 | torch.manual_seed(42) 31 | 32 | def rerank(self, docs, query, topk): 33 | scores = [] 34 | 35 | batch_size = 50 36 | for i in range(0, len(docs), batch_size): 37 | list_docs = docs[i:i+batch_size] 38 | 39 | doc_prompts = [self.prompt.format(query, doc["text"]) for doc in list_docs] 40 | 41 | output = self.model.generate(doc_prompts, self.sampling_params) 42 | doc_prompt_outputs = [o.outputs[0].text for o in output] 43 | 44 | for j in range(len(doc_prompt_outputs)): 45 | pos_score = doc_prompt_outputs[j].rfind("Relevance score:") 46 | if pos_score != -1: 47 | try: 48 | # get the score from the prompt output 49 | score = float(doc_prompt_outputs[j][pos_score+16:pos_score+18])/5 50 | scores.append(score) 51 | except: 52 | print("In exception!! 
{}".format(doc_prompt_outputs[j])) 53 | scores.append(-1) 54 | else: 55 | scores.append(0) 56 | 57 | 58 | ranking = {doc["id"]: score for doc, score in zip(docs, scores)} 59 | ranking = dict(sorted(ranking.items(), key=lambda item: item[1], reverse=True)[:topk]) 60 | return ranking 61 | 62 | 63 | if __name__=='__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--task', type=str, required=True, 66 | choices=['biology','earth_science','economics','pony','psychology','robotics','theoremqa_questions', "theoremqa_theorems", 67 | 'stackoverflow','sustainable_living','aops','leetcode']) 68 | parser.add_argument('--long_context', action='store_true') 69 | parser.add_argument('--retriever_score_file', type=str, default=None) 70 | parser.add_argument('--input_k', type=int) 71 | parser.add_argument('--k', type=int) 72 | parser.add_argument('--cache_dir', type=str, default='cache') 73 | parser.add_argument('--reasoning', type=str, default=None) 74 | parser.add_argument('--bm25_score_file', type=str, default=None) 75 | parser.add_argument('--output_dir', type=str, default=None) 76 | args = parser.parse_args() 77 | 78 | if args.reasoning is not None: 79 | raw_examples = load_dataset('xlangai/bright', f"{args.reasoning}_reason", cache_dir=args.cache_dir)[args.task] 80 | else: 81 | raw_examples = load_dataset('xlangai/bright', 'examples',cache_dir=args.cache_dir)[args.task] 82 | 83 | # raw_examples = load_dataset('xlangai/bright', 'examples', cache_dir=args.cache_dir)[args.task] 84 | examples = {} 85 | for e in raw_examples: 86 | examples[e['id']] = e 87 | if args.long_context: 88 | doc_pairs = load_dataset('xlangai/bright', 'long_documents', cache_dir=args.cache_dir)[args.task] 89 | else: 90 | doc_pairs = load_dataset('xlangai/bright', 'documents', cache_dir=args.cache_dir)[args.task] 91 | documents = {} 92 | for d in doc_pairs: 93 | documents[d['id']] = d['content'] 94 | 95 | with open(args.retriever_score_file) as f: 96 | all_scores = json.load(f) 97 | 98 | outputs_path = args.output_dir 99 | score_file_path = os.path.join(outputs_path, f"{args.reasoning}_score.json") 100 | 101 | if not os.path.isfile(score_file_path): 102 | new_scores = copy.deepcopy(all_scores) 103 | 104 | model = Reranker(args.task) 105 | 106 | for qid,scores in tqdm(all_scores.items()): 107 | docs = [] 108 | sorted_scores = sorted(scores.items(),key=lambda x:x[1],reverse=True)[:args.input_k] 109 | for did, _ in sorted_scores: 110 | docs.append([did, documents[did]]) 111 | 112 | ctxs = [{'id': did, 'text': documents[did]} for did, _ in sorted_scores] 113 | 114 | cur_score = model.rerank(query=examples[qid]['query'], docs=ctxs, topk=args.k) 115 | 116 | assert len(cur_score) == len(sorted_scores) 117 | 118 | new_scores[qid] = cur_score 119 | 120 | os.makedirs(outputs_path, exist_ok=True) 121 | with open(score_file_path, 'w') as f: 122 | json.dump(new_scores, f, indent=2) 123 | else: 124 | with open(score_file_path) as f: 125 | new_scores = json.load(f) 126 | print(score_file_path,'exists') 127 | 128 | if args.long_context: 129 | key = 'gold_ids_long' 130 | else: 131 | key = 'gold_ids' 132 | ground_truth = {} 133 | for e in raw_examples: 134 | ground_truth[e['id']] = {} 135 | for gid in e[key]: 136 | ground_truth[e['id']][gid] = 1 137 | for i in e["excluded_ids"]: 138 | if i in documents: 139 | ground_truth[e['id']][i] = 0 140 | 141 | # The import is here to prevent environment issues 142 | from retrievers import calculate_retrieval_metrics 143 | 144 | results = calculate_retrieval_metrics(results=new_scores, 
qrels=ground_truth) 145 | with open(os.path.join(outputs_path, "reranker_results.json"), 'w') as f: 146 | json.dump(results, f, indent=2) 147 | 148 | # break ties by interpolating with the retriever scores 149 | retriever_interpolated_scores = {} 150 | for qid in new_scores: 151 | retriever_interpolated_scores[qid] = {} 152 | for did in new_scores[qid]: 153 | retriever_interpolated_scores[qid][did] = (0.5 * new_scores[qid][did]) + (0.5 * all_scores[qid][did]) 154 | results = calculate_retrieval_metrics(results=retriever_interpolated_scores, qrels=ground_truth) 155 | with open(os.path.join(outputs_path, f"reranker_retriever_results.json"), 'w') as f: 156 | json.dump(results, f, indent=2) 157 | 158 | # break ties by combining with the bm25 scores 159 | if args.bm25_score_file is not None: 160 | bm25_interpolated_scores = {} 161 | with open(args.bm25_score_file) as f: 162 | bm25_scores = json.load(f) 163 | for qid in new_scores: 164 | bm25_interpolated_scores[qid] = {} 165 | for did in new_scores[qid]: 166 | bm25_interpolated_scores[qid][did] = (100 * new_scores[qid][did]) + bm25_scores[qid][did] 167 | results = calculate_retrieval_metrics(results=bm25_interpolated_scores, qrels=ground_truth) 168 | with open(os.path.join(outputs_path, f"reranker_bm25_results.json"), 'w') as f: 169 | json.dump(results, f, indent=2) 170 | -------------------------------------------------------------------------------- /test_time_techniques/query_rewriting.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import time 4 | import os 5 | from datasets import load_dataset 6 | from tqdm import tqdm 7 | import functools 8 | 9 | import logging 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | datefmt='%m/%d/%Y %H:%M:%S') 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | def call_api(func): 17 | count = 0 18 | while True: 19 | try: 20 | count += 1 21 | output = func() 22 | break 23 | except Exception as e: 24 | logger.info(f"Exception while using api: {e}") 25 | if "rate limit" in str(e).lower() or "rate_limit" in str(e).lower(): 26 | logger.info("Rate limit exceeded, waiting 60 secs and retrying...") 27 | time.sleep(60) 28 | # elif count < 5: 29 | # logger.info("Encountered error, retrying...") 30 | # time.sleep(5) 31 | else: 32 | raise ValueError 33 | # logger.info("Skipping generation due to unknown error after 5 retries.") 34 | # output = None 35 | # break 36 | return output 37 | 38 | 39 | def format_chat(message, include_system=True, system_message="You are a helpful assistant."): 40 | if include_system: 41 | chat = [{"role": "system", "content": system_message}, {"role": "user", "content": message}] 42 | else: 43 | chat = [{"role": "user", "content": message}] 44 | print(chat) 45 | return chat 46 | 47 | 48 | class ClaudeModel: 49 | 50 | def __init__(self, version): 51 | from anthropic import AnthropicVertex 52 | PROJECT_ID = "xxx" # @param 53 | LOCATION = "xxx" # @param 54 | self.model = AnthropicVertex(region=LOCATION, project_id=PROJECT_ID) 55 | self.version = version 56 | 57 | def generate(self, prompt): 58 | inputs = format_chat(prompt, include_system=False) 59 | func = functools.partial( 60 | self.model.messages.create, 61 | max_tokens=2048, 62 | messages=inputs, 63 | model=self.version, 64 | temperature=0.8, 65 | top_p=0.8 66 | ) 67 | message = call_api(func) 68 | if message is not None: 69 | response = json.loads(message.model_dump_json(indent=2)) 70 
| return response['content'][0]['text'] 71 | return None 72 | 73 | 74 | class OpenAIModel: 75 | def __init__(self, model_name, temperature=0.8, top_p=0.8, max_tokens=2048): 76 | import openai 77 | if "azure" in model_name: 78 | # env var: AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and OPENAI_API_VERSION 79 | self.model = openai.AzureOpenAI() 80 | model_name = model_name[model_name.index("/")+1:] 81 | else: 82 | # make sure to set the OPENAI_API_KEY environment variable 83 | self.model = openai.OpenAI() 84 | self.model_name = model_name 85 | self.temperature = temperature 86 | self.top_p = top_p 87 | self.max_tokens = max_tokens 88 | 89 | def generate(self, prompt, system_message="You are a helpful assistant", **kwargs): 90 | # kwargs can be used to pass additional parameters to the model: max_tokens, stop, etc. 91 | inputs = format_chat(prompt, system_message=system_message) 92 | func = functools.partial( 93 | self.model.chat.completions.create, 94 | model=self.model_name, 95 | messages=inputs, 96 | max_tokens=self.max_tokens, 97 | temperature=self.temperature, 98 | top_p=self.top_p, 99 | **kwargs, 100 | ) 101 | output = call_api(func) 102 | if output is not None: 103 | return output.choices[0].message.content 104 | return None 105 | 106 | 107 | class GeminiModel: 108 | def __init__(self, model_name, temperature=0.8, top_p=0.8, max_tokens=2048): 109 | import google.generativeai as genai 110 | api_key=os.environ["GEMINI_API_KEY"] 111 | genai.configure(api_key=api_key) 112 | self.model = genai.GenerativeModel(model_name) 113 | self.temperature = temperature 114 | self.top_p = top_p 115 | self.max_tokens = max_tokens 116 | self.generation_config = genai.types.GenerationConfig( 117 | max_output_tokens=self.max_tokens, 118 | temperature=self.temperature, 119 | top_p=self.top_p 120 | ) 121 | 122 | def generate(self, prompt, system_message="You are a helpful assistant", **kwargs): 123 | # kwargs can be used to pass additional parameters to the model: max_tokens, stop, etc. 124 | output = self.model.generate_content( 125 | prompt, 126 | generation_config=self.generation_config) 127 | # time.sleep(1) 128 | if output is not None: 129 | try: 130 | return output.text 131 | except: 132 | # import pdb; pdb.set_trace() 133 | return prompt.split("\n\nInstructions")[0] 134 | return None 135 | 136 | 137 | class HFModel: 138 | def __init__(self, model_name, temperature, top_p, max_tokens=2048): 139 | import torch 140 | from transformers import AutoModelForCausalLM, AutoTokenizer 141 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 142 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 143 | self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device) 144 | self.temperature = temperature 145 | self.top_p = top_p 146 | self.max_tokens = max_tokens 147 | 148 | def generate(self, message, **kwargs): 149 | inputs = self.tokenizer([message], return_tensors="pt").to(self.device) 150 | outputs = self.model.generate( 151 | **inputs, 152 | # max_length=1024, 153 | max_new_tokens=512, # added for llama3.1-7B, because max_length runs into errors. 
154 | temperature=self.temperature, 155 | top_p=self.top_p, 156 | **kwargs, 157 | ) 158 | text = self.tokenizer.decode(outputs[0, inputs.input_ids.size(1):], skip_special_tokens=True) 159 | return text 160 | 161 | 162 | if __name__ == '__main__': 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument('--task', type=str, required=True) 165 | parser.add_argument('--example_file', type=str, default=None) 166 | parser.add_argument('--output_dir', type=str, default="cache/reasoning") 167 | parser.add_argument('--model', type=str, default="gemini-1.5-flash") 168 | parser.add_argument('--output_token_limit', type=int, default=None) 169 | parser.add_argument('--sweep_output_dir', type=str, default=None) 170 | args = parser.parse_args() 171 | 172 | if args.example_file is not None: 173 | # supports json and jsonl files 174 | examples = load_dataset("json", data_files=args.example_file)["train"] 175 | else: 176 | examples = load_dataset('xlangai/BRIGHT', 'examples')[args.task] 177 | 178 | os.makedirs(args.output_dir, exist_ok=True) 179 | model_name = args.model.split("/")[-1] 180 | output_file = os.path.join(args.output_dir, f"{args.task}_{model_name}_{args.output_token_limit}.json") 181 | 182 | if os.path.exists(output_file): 183 | print(f"{output_file} exists, skipping") 184 | 185 | else: 186 | 187 | if 'claude' in args.model: 188 | model = ClaudeModel(version=args.model) 189 | elif 'gpt' in args.model: 190 | model = OpenAIModel(model_name=args.model, max_tokens=args.output_token_limit) 191 | elif 'gemini' in args.model: 192 | model = GeminiModel(model_name=args.model, max_tokens=args.output_token_limit) 193 | else: 194 | logger.info(f"Assuming Hugging Face model: {args.model}") 195 | model = HFModel(model_name=args.model, temperature=1e-9, top_p=1e-9) 196 | 197 | rewritten_examples = [] 198 | for e in tqdm(examples): 199 | cur_post = e["query"].replace('\n', ' ') 200 | prompt = (f'{cur_post}\n\n' 201 | f'Instructions:\n' 202 | f'1. Identify the essential problem.\n' 203 | f'2. Think step by step to reason and describe what information could be relevant and helpful to address the questions in detail.\n' 204 | f'3. Draft an answer with as many thoughts as you have.\n' 205 | ) 206 | if args.output_token_limit is not None: 207 | prompt += f'Your answer must be written within {args.output_token_limit} tokens.' 
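# The rewritten draft produced below replaces e['query'], so downstream retrieval scores
# documents against the model's step-by-step reasoning rather than the raw post. The
# token-limit sentence above is only a soft instruction; the hard cap is the generation
# limit configured on each model wrapper (max_tokens / max_new_tokens).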
208 | output = model.generate(prompt) 209 | if output is not None: 210 | e['query'] = output 211 | rewritten_examples.append(e) 212 | 213 | logger.info(f"Saving rewritten examples to {output_file}") 214 | with open(output_file, 'w') as f: 215 | json.dump(rewritten_examples, f, indent=2) 216 | 217 | # track successful completion of the run 218 | if args.sweep_output_dir: 219 | with open(os.path.join(args.sweep_output_dir, 'done'), 'w') as f: 220 | f.write('done') -------------------------------------------------------------------------------- /synthetic_data_generation/doc_to_query.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | import pdb 7 | from collections import defaultdict 8 | import numpy as np 9 | 10 | from hard_negative_mining import BM25_Miner 11 | from data_gen_prompts import * 12 | from gen_utils import * 13 | from data_gen_prompts import * 14 | from lm_helper import OpenAILM, HFLM 15 | import re 16 | 17 | 18 | def doc2query(bm25_miner, model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", num_docs=100, queries_per_doc=1, filter_name=None, 19 | output_dir='synthetic_data', cache_dir='cache', prompt_id='hq_gen', diverse_prompt=False, num_prompts=1, temperature=0, top_p=0): 20 | 21 | def hash_existed_documents(path): 22 | if not os.path.exists(path): 23 | return {} 24 | data = load_jsonl(path) 25 | hashed_data = {} 26 | for ex in data: 27 | query, document = ex['query'], ex['document'] 28 | key = document 29 | if 'prompt' in ex: 30 | key += ex['prompt'] 31 | hashed_data[key] = query 32 | return hashed_data 33 | 34 | def check_sample_existence(document, hashed_data, prompt=None): 35 | key = document 36 | if prompt: 37 | key += prompt 38 | if key in hashed_data: 39 | return True, hashed_data[key] 40 | else: 41 | return False, None 42 | 43 | def format_example(query, document): 44 | decoder = json.JSONDecoder() 45 | try: 46 | json_data, _ = decoder.raw_decode(query[query.find('{'):]) 47 | if 'hard_query' in json_data: 48 | queries = json_data['hard_query'] 49 | else: 50 | queries = json_data['questions'] 51 | if len(queries) > queries_per_doc: 52 | queries = queries[:queries_per_doc] 53 | except Exception as e: 54 | print(e) 55 | print(f"Skipping query: {query}") 56 | return None 57 | items = [] 58 | for _query in queries: 59 | pos = document 60 | try: 61 | if isinstance(_query, dict): 62 | _question = _query['question'] 63 | _scenario = _query['scenario'] 64 | _question = f"{_scenario} {_question}" 65 | else: 66 | _question = _query 67 | negs = bm25_miner.select_hard_negatives(_question, pos, 1) 68 | except Exception as e: 69 | print(e) 70 | print(f"Skipping query with type {type(_query)}: {_query}") 71 | continue 72 | item = { 73 | 'query': _question, 74 | 'pos': [pos], 75 | 'neg': negs, 76 | } 77 | items.append(item) 78 | return items 79 | 80 | prompt = prompt_registry[prompt_id] 81 | dataset = bm25_miner.dataset 82 | subject = bm25_miner.task if bm25_miner.task else dataset 83 | 84 | if bm25_miner.dataset == 'BRIGHT' and bm25_miner.task: 85 | examples = load_dataset('xlangai/BRIGHT', 'examples', cache_dir='cache')[subject] 86 | else: 87 | print("No task specified, using all examples.") 88 | tasks = ['biology', 'earth_science', 'economics', 'psychology', 'robotics', 89 | 'stackoverflow', 'sustainable_living', 'leetcode', 'pony', 90 | 'aops', 'theoremqa_theorems', 'theoremqa_questions'] 91 | all_examples = load_dataset('xlangai/BRIGHT', 'examples', 
cache_dir='cache') 92 | examples = [] 93 | for task in tasks: 94 | examples.extend(all_examples[task]) 95 | 96 | documents, doc_ids = bm25_miner.documents, bm25_miner.doc_ids 97 | doc_dicts = [{'doc_id': doc_id, 'doc': doc} for doc_id, doc in zip(doc_ids, documents)] 98 | 99 | total_num_docs = len(doc_dicts) 100 | num_docs_sample_pool = min(num_docs*1, total_num_docs) # document pool to sample num_oversample_docs docs 101 | num_oversample_docs = min(num_docs*2, total_num_docs) # obtain more documents than expected to make the final output matches the expectation 102 | filter_cache_dir = f'cache/{subject}' 103 | os.makedirs(filter_cache_dir, exist_ok=True) 104 | print(f"Filtering documents based on {filter_name}...") 105 | 106 | if dataset == 'msmarco': 107 | doc_dicts, doc_ids = document_filter(doc_dicts, doc_ids, filter_name=filter_name, num_docs=num_docs, cache_dir=filter_cache_dir) 108 | else: 109 | doc_dicts, doc_ids = document_filter(doc_dicts, doc_ids, filter_name=filter_name, num_docs=num_docs, cache_dir=filter_cache_dir) 110 | 111 | num_filtered_docs = len(doc_dicts) 112 | print(f"Total number of documents: {total_num_docs}, number of filtered documents with oversampling: {num_filtered_docs}") 113 | 114 | model_id_str = model_id.split('/')[-1] 115 | # path to save intermediate model generated results, will check if the same document has been used before to avoid repetitive generation. 116 | final_output_path = os.path.join(output_dir, f'all_docs_train_data/{prompt_id}/{model_id_str}/{subject}_{num_docs}_train_data.jsonl') 117 | final_output_path = os.path.expanduser(final_output_path) 118 | os.makedirs(os.path.dirname(final_output_path), exist_ok=True) 119 | 120 | output_path = os.path.join(output_dir, f'all_docs/{prompt_id}/{model_id_str}/{subject}_train_data.jsonl') 121 | output_path = os.path.expanduser(output_path) 122 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 123 | hashed_data = hash_existed_documents(output_path) 124 | 125 | if 'gpt' in model_id: 126 | model = OpenAILM(model_id, temperature=temperature, top_p=top_p, seed=0) 127 | else: 128 | model = HFLM(model_id, temperature=temperature, top_p=top_p) 129 | 130 | system_prompt = fill_sys_prompt(prompt, queries_per_doc=queries_per_doc) 131 | system_prompts = [system_prompt] 132 | 133 | final_training_data = [] 134 | print("Sampling from system prompts:", system_prompts) 135 | with open(output_path, 'a+', buffering=1) as fout: 136 | for doc_dict in tqdm(doc_dicts): 137 | document = doc_dict['doc'] 138 | formatted_query = format_query_doc(document) 139 | 140 | if len(system_prompts) > num_prompts: 141 | sampled_prompts = random.choices(system_prompts, k=num_prompts) 142 | else: 143 | sampled_prompts = system_prompts 144 | 145 | for system_prompt in sampled_prompts: # generate examples for each prompt 146 | has_generated, query = check_sample_existence(document, hashed_data, prompt=system_prompt) 147 | if not has_generated or temperature > 0: 148 | max_num_attempt_per_doc = 3 149 | num_attempt = 0 150 | succeed = False 151 | while not succeed and num_attempt < max_num_attempt_per_doc: 152 | query = model.generate(formatted_query, system_prompt=system_prompt) 153 | print(f"Generated query: {query}") 154 | items = {'query': query, 'document': document, 'prompt': system_prompt} 155 | items = format_example(query, document) 156 | num_attempt += 1 157 | if items is not None: 158 | succeed = True 159 | else: 160 | print(f"Using existing query. 
Skipping.") 161 | 162 | if not items: 163 | print(f"Skipping document: {document}") 164 | continue 165 | if isinstance(items, list): 166 | if len(items) < queries_per_doc: 167 | continue 168 | final_training_data.extend(items[:queries_per_doc]) 169 | if len(final_training_data) >= num_docs * queries_per_doc: 170 | break 171 | 172 | write_jsonl(final_training_data, final_output_path) 173 | 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument('--mode', type=str, default=None, help='mode') 178 | parser.add_argument('--model_id', type=str, default='gpt-4o', help='model id') 179 | parser.add_argument('--dataset', type=str, default='bright', help='dataset') 180 | parser.add_argument('--subject', type=str, default=None, help='subject') 181 | parser.add_argument('--queries_per_doc', type=int, default=3, help='number of generated samples per document') 182 | parser.add_argument('--num_docs', type=int, default=None, help='number of documents to sample for each subject') 183 | parser.add_argument('--debug', action='store_true', help='debug mode') 184 | parser.add_argument('--filter', type=str, default=None, help='the default filter is the length filter', choices=['length', 'fineweb', 'dclm']) 185 | parser.add_argument('--data_path', type=str, default='~/data/chunks/mathpile_wiki_chunks.jsonl', help='data path') 186 | parser.add_argument('--output_dir', type=str, default='data/synthetic_questions', help='base directory to save the generated data') 187 | parser.add_argument('--cache_dir', type=str, default='cache/', help='cache directory to save cached data during document filtering.') 188 | parser.add_argument('--prompt_id', type=str, default='hq_gen', help='prompt to use') 189 | parser.add_argument("--temperature", type=float, default=0) 190 | parser.add_argument("--top_p", type=float, default=0) 191 | args = parser.parse_args() 192 | args.output_dir = os.path.expanduser(args.output_dir) 193 | args.data_path = os.path.expanduser(args.data_path) 194 | os.makedirs(args.cache_dir, exist_ok=True) 195 | print(args) 196 | 197 | print(f"Dataset: {args.dataset}, Task: {args.subject}") 198 | bm25_miner = BM25_Miner(dataset=args.dataset, task=args.subject, data_path=args.data_path) 199 | 200 | model_id = args.model_id 201 | doc2query(bm25_miner, model_id=model_id, num_docs=args.num_docs, 202 | filter_name=args.filter, queries_per_doc=args.queries_per_doc, output_dir=args.output_dir, cache_dir=args.cache_dir, 203 | prompt_id=args.prompt_id, temperature=args.temperature, top_p=args.top_p) 204 | 205 | -------------------------------------------------------------------------------- /evaluation/bright/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | from tqdm import tqdm 5 | from retrievers import RETRIEVAL_FUNCS,calculate_retrieval_metrics 6 | from datasets import load_dataset 7 | 8 | 9 | if __name__=='__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--task', type=str, required=True, 12 | choices=['biology','earth_science','economics','pony','psychology','robotics', 13 | 'stackoverflow','sustainable_living','aops','leetcode','theoremqa_theorems', 14 | 'theoremqa_questions']) 15 | parser.add_argument('--model', type=str, required=True, 16 | choices=['bm25','cohere','e5','google','grit','inst-l','inst-xl', 17 | 'openai','qwen','qwen2','sbert','sf','voyage','bge', 18 | 'bge_ce', 'nomic', 'm2', 'contriever', 'reasonir']) 19 | 
parser.add_argument('--model_id', type=str, default=None, help='(Optional) Pass a different model ID for cache and output path naming.') 20 | parser.add_argument('--long_context', action='store_true') 21 | parser.add_argument('--document_expansion', default=None, type=str, choices=[None, 'gold', 'full'], 22 | help="Set to None to use the original documents provided by BRIGHT; set to `gold` to use documents with only the gold ones expanded; set to `full` to use all expanded documents.") 23 | parser.add_argument('--global_summary', default=None, choices=[None, 'concat']) 24 | parser.add_argument('--query_max_length', type=int, default=-1) 25 | parser.add_argument('--doc_max_length', type=int, default=-1) 26 | parser.add_argument('--encode_batch_size', type=int, default=-1) 27 | parser.add_argument('--output_dir', type=str, default='outputs') 28 | parser.add_argument('--cache_dir', type=str, default='cache') 29 | parser.add_argument('--config_dir', type=str, default='configs') 30 | parser.add_argument('--checkpoint', type=str, default=None) 31 | parser.add_argument('--key', type=str, default=None) 32 | parser.add_argument('--input_file', type=str, default=None) 33 | parser.add_argument('--reasoning', type=str, default=None) 34 | parser.add_argument('--reasoning_id', type=str, default=None) 35 | parser.add_argument('--reasoning_length_limit', type=int, default=None) 36 | parser.add_argument('--separate_reasoning', action='store_true', help='Append the reasoning after the original query, separated by a blank line.') 37 | parser.add_argument('--debug', action='store_true') 38 | parser.add_argument('--ignore_cache', action='store_true') 39 | parser.add_argument('--no_log', action='store_true', help="Disable logging to Google Sheets.") 40 | parser.add_argument('--sweep_output_dir', type=str, default=None) 41 | parser.add_argument('--skip_doc_emb', action='store_true', help="Skip document embedding.") 42 | parser.add_argument('--store_all_scores', action='store_true', help="The default is to store the top 1000 scores. This option will store all scores.") 43 | args = parser.parse_args() 44 | if args.model_id is None: 45 | args.output_dir = os.path.join(args.output_dir,f"{args.task}_{args.model}_long_{args.long_context}") 46 | else: 47 | args.output_dir = os.path.join(args.output_dir,f"{args.task}_{args.model_id}_long_{args.long_context}") 48 | if not os.path.isdir(args.output_dir): 49 | os.makedirs(args.output_dir) 50 | if args.reasoning is not None: 51 | score_file_path = os.path.join(args.output_dir,f'{args.reasoning}_score.json') 52 | else: 53 | score_file_path = os.path.join(args.output_dir,f'score.json') 54 | 55 | assert args.document_expansion is None or args.global_summary is None, "Cannot use expansion and summary together!" 56 | if args.global_summary: 57 | assert not args.long_context, "Global summary is supposed to enhance short-context retrieval!" 
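# The flags validated above are mapped to a dataset suffix next: 'gold' and 'full' switch to
# the expanded corpus variants, 'concat' switches to the summary-concatenated documents, and
# the default keeps the original xlangai/BRIGHT documents.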
58 | 59 | document_postfix = '' 60 | if args.document_expansion == 'gold': 61 | document_postfix = '_expanded_gold_only' 62 | dataset_source = 'rulins/bright-expanded' 63 | elif args.document_expansion == 'full': 64 | document_postfix = '_expanded' 65 | dataset_source = 'rulins/bright-expanded' 66 | elif args.global_summary == 'concat': 67 | document_postfix = '_concat_with_summary' 68 | dataset_source = 'rulins/bright-expanded' 69 | else: 70 | dataset_source = 'xlangai/BRIGHT' 71 | 72 | if args.reasoning is not None and 'llama3-8b' in args.reasoning: 73 | reasoning_source = 'dreamorg/BRIGHT' 74 | else: 75 | reasoning_source = 'xlangai/BRIGHT' 76 | print(f"Dataset source: {dataset_source}") 77 | print(f"Reasoning source: {reasoning_source}") 78 | 79 | if args.input_file is not None: 80 | with open(args.input_file) as f: 81 | examples = json.load(f) 82 | elif args.reasoning is not None and args.reasoning in ['promptriever', 'cot_v2']: 83 | reasoning_file = f"cache/{args.reasoning}/{args.reasoning_id}_{args.task}.jsonl" 84 | from synthetic_data_gen.gen_utils import load_jsonl 85 | examples = load_jsonl(reasoning_file) 86 | for e in examples: 87 | e['query'] = e['query'] + '\n' + e['rewritten_query'] 88 | elif args.reasoning is not None and args.separate_reasoning: 89 | examples = load_dataset(dataset_source, 'examples', cache_dir=args.cache_dir)[args.task] 90 | reasoning_examples = load_dataset(reasoning_source, f"{args.reasoning}_reason", cache_dir=args.cache_dir)[args.task] 91 | elif args.reasoning is not None and args.reasoning_length_limit is None: 92 | examples = load_dataset(reasoning_source, f"{args.reasoning}_reason", cache_dir=args.cache_dir)[args.task] 93 | elif args.reasoning is not None and args.reasoning_length_limit: 94 | reasoning_file = f"cache/reasoning/{args.task}_{args.reasoning}_{args.reasoning_length_limit}" 95 | with open(reasoning_file, 'r') as f: 96 | examples = json.load(f) 97 | else: 98 | examples = load_dataset(dataset_source, 'examples',cache_dir=args.cache_dir)[args.task] 99 | 100 | if args.long_context: 101 | doc_pairs = load_dataset(dataset_source, 'long_documents'+document_postfix, cache_dir=args.cache_dir)[args.task] 102 | else: 103 | doc_pairs = load_dataset(dataset_source, 'documents'+document_postfix, cache_dir=args.cache_dir)[args.task] 104 | 105 | doc_ids = [] 106 | documents = [] 107 | for dp in doc_pairs: 108 | doc_ids.append(dp['id']) 109 | documents.append(dp['content']) 110 | 111 | if not os.path.isfile(score_file_path): 112 | print("The scores file does not exist, start retrieving...") 113 | 114 | with open(os.path.join(args.config_dir,args.model.split('_ckpt')[0].split('_bilevel')[0],f"{args.task}.json")) as f: 115 | config = json.load(f) 116 | if not os.path.isdir(args.output_dir): 117 | os.makedirs(args.output_dir) 118 | 119 | queries = [] 120 | query_ids = [] 121 | excluded_ids = {} 122 | for qid, e in enumerate(examples): 123 | if args.separate_reasoning: 124 | new_query = f"{e['query']}\n\n{reasoning_examples[qid]['query']}" 125 | queries.append(new_query) 126 | else: 127 | queries.append(e["query"]) 128 | query_ids.append(e['id']) 129 | excluded_ids[e['id']] = e['excluded_ids'] 130 | overlap = set(e['excluded_ids']).intersection(set(e['gold_ids'])) 131 | assert len(overlap)==0 132 | assert len(queries)==len(query_ids), f"{len(queries)}, {len(query_ids)}" 133 | if not os.path.isdir(os.path.join(args.cache_dir, 'doc_ids')): 134 | os.makedirs(os.path.join(args.cache_dir, 'doc_ids')) 135 | if 
os.path.isfile(os.path.join(args.cache_dir,'doc_ids',f"{args.task}_{args.long_context}.json")): 136 | try: 137 | with open(os.path.join(args.cache_dir,'doc_ids',f"{args.task}_{args.long_context}.json")) as f: 138 | cached_doc_ids = json.load(f) 139 | for id1,id2 in zip(cached_doc_ids,doc_ids): 140 | assert id2 in cached_doc_ids 141 | except: 142 | print("Document IDs mismatche with the cached version!") 143 | else: 144 | with open(os.path.join(args.cache_dir,'doc_ids',f"{args.task}_{args.long_context}.json"),'w') as f: 145 | json.dump(doc_ids,f,indent=2) 146 | assert len(doc_ids)==len(documents), f"{len(doc_ids)}, {len(documents)}" 147 | 148 | print(f"{len(queries)} queries") 149 | print(f"{len(documents)} documents") 150 | if args.debug: 151 | documents = documents[:30] 152 | doc_paths = doc_ids[:30] 153 | kwargs = {} 154 | if args.query_max_length>0: 155 | kwargs = {'query_max_length': args.query_max_length} 156 | if args.doc_max_length>0: 157 | kwargs.update({'doc_max_length': args.doc_max_length}) 158 | if args.encode_batch_size>0: 159 | kwargs.update({'batch_size': args.encode_batch_size}) 160 | if args.key is not None: 161 | kwargs.update({'key': args.key}) 162 | if args.ignore_cache: 163 | kwargs.update({'ignore_cache': args.ignore_cache}) 164 | if args.skip_doc_emb: 165 | kwargs.update({'skip_doc_emb': args.skip_doc_emb}) 166 | if args.store_all_scores: 167 | kwargs.update({'store_all_scores': args.store_all_scores}) 168 | model_id = args.model_id if args.model_id is not None else args.model 169 | scores = RETRIEVAL_FUNCS[args.model](queries=queries,query_ids=query_ids,documents=documents,excluded_ids=excluded_ids, 170 | instructions=config['instructions_long'] if args.long_context else config['instructions'], 171 | doc_ids=doc_ids,task=args.task,cache_dir=args.cache_dir,long_context=args.long_context, 172 | model_id=model_id,checkpoint=args.checkpoint,**kwargs) 173 | with open(score_file_path,'w') as f: 174 | json.dump(scores,f,indent=2) 175 | else: 176 | with open(score_file_path) as f: 177 | scores = json.load(f) 178 | print(score_file_path,'exists') 179 | if args.long_context: 180 | key = 'gold_ids_long' 181 | else: 182 | key = 'gold_ids' 183 | ground_truth = {} 184 | for e in tqdm(examples): 185 | ground_truth[e['id']] = {} 186 | for gid in e[key]: 187 | ground_truth[e['id']][gid] = 1 188 | for did in e['excluded_ids']: 189 | assert not did in scores[e['id']] 190 | assert not did in ground_truth[e['id']] 191 | 192 | print(args.output_dir) 193 | results = calculate_retrieval_metrics(results=scores, qrels=ground_truth) 194 | with open(os.path.join(args.output_dir, 'results.json'), 'w') as f: 195 | json.dump(results, f, indent=2) 196 | 197 | 198 | # track successful completion of the run 199 | if args.sweep_output_dir: 200 | with open(os.path.join(args.sweep_output_dir, 'done'), 'w') as f: 201 | f.write('done') -------------------------------------------------------------------------------- /synthetic_data_generation/doc_to_query_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | from tqdm import tqdm 6 | import pdb 7 | from collections import defaultdict 8 | import numpy as np 9 | 10 | from hard_negative_mining import BM25_Miner 11 | from data_gen_prompts import * 12 | from gen_utils import * 13 | from lm_helper import OpenAILM, HFLM 14 | from batch_api_helper import BatchAPIHelper 15 | import re 16 | import time 17 | 18 | 19 | def base_data_to_dict(base_data): 20 | 
base_dict = {} 21 | for item in base_data: 22 | doc_id = item['doc_id'] 23 | base_dict[doc_id] = item 24 | return base_dict 25 | 26 | 27 | def format_example(bm25_miner, query, document, queries_per_doc=1): 28 | decoder = json.JSONDecoder() 29 | try: 30 | json_data, _ = decoder.raw_decode(query[query.find('{'):]) 31 | if 'hard_query' in json_data: 32 | queries = json_data['hard_query'] 33 | else: 34 | queries = json_data['questions'] 35 | if len(queries) > queries_per_doc: 36 | queries = queries[:queries_per_doc] 37 | except Exception as e: 38 | print(e) 39 | print(f"Skipping query: {query}") 40 | return None 41 | items = [] 42 | for _query in queries: 43 | pos = document 44 | try: 45 | _question = _query 46 | negs = bm25_miner.select_hard_negatives(_question, pos, 1) 47 | except Exception as e: 48 | print(e) 49 | print(f"Skipping query with type {type(_query)}: {_query}") 50 | continue 51 | item = { 52 | 'query': _question, 53 | 'pos': [pos], 54 | 'neg': negs, 55 | } 56 | items.append(item) 57 | return items 58 | 59 | def doc2query(bm25_miner, model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", num_docs=100, queries_per_doc=1, filter_name=None, 60 | output_dir='synthetic_data', cache_dir='cache', prompt_id='hq_gen', diverse_prompt=False, num_prompts=1, temperature=0, top_p=0, batch_helper=None): 61 | 62 | prompt = prompt_registry[prompt_id] 63 | dataset = bm25_miner.dataset 64 | subject = bm25_miner.task if bm25_miner.task else dataset 65 | 66 | if bm25_miner.dataset == 'BRIGHT' and bm25_miner.task: 67 | examples = load_dataset('xlangai/BRIGHT', 'examples', cache_dir='cache')[subject] 68 | else: 69 | print("No task specified, using all examples.") 70 | tasks = ['biology', 'earth_science', 'economics', 'psychology', 'robotics', 71 | 'stackoverflow', 'sustainable_living', 'leetcode', 'pony', 72 | 'aops', 'theoremqa_theorems', 'theoremqa_questions'] 73 | all_examples = load_dataset('xlangai/BRIGHT', 'examples', cache_dir='cache') 74 | examples = [] 75 | for task in tasks: 76 | examples.extend(all_examples[task]) 77 | 78 | documents, doc_ids = bm25_miner.documents, bm25_miner.doc_ids 79 | doc_dicts = [{'doc_id': doc_id, 'doc': doc} for doc_id, doc in zip(doc_ids, documents)] 80 | 81 | total_num_docs = len(doc_dicts) 82 | num_docs_sample_pool = min(num_docs*1, total_num_docs) # document pool to sample num_oversample_docs docs 83 | num_oversample_docs = min(num_docs*2, total_num_docs) # obtain more documents than expected to make the final output matches the expectation 84 | filter_cache_dir = f'cache/{subject}' 85 | os.makedirs(filter_cache_dir, exist_ok=True) 86 | print(f"Filtering documents based on {filter_name}...") 87 | 88 | if dataset == 'msmarco': 89 | doc_dicts, doc_ids = document_filter(doc_dicts, doc_ids, filter_name=filter_name, num_docs=num_docs, cache_dir=filter_cache_dir) 90 | else: 91 | doc_dicts, doc_ids = document_filter(doc_dicts, doc_ids, filter_name=filter_name, num_docs=num_docs, cache_dir=filter_cache_dir) 92 | 93 | num_filtered_docs = len(doc_dicts) 94 | print(f"Total number of documents: {total_num_docs}, number of filtered documents with oversampling: {num_filtered_docs}") 95 | 96 | model_id_str = model_id.split('/')[-1] 97 | # path to save intermediate model generated results, will check if the same document has been used before to avoid repetitive generation. 
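# In batch mode (batch_helper is not None) no generation happens in this function:
# chat-formatted requests are accumulated per document and submitted through BatchAPIHelper
# at the end of the loop, and hard-negative mining is deferred until --gather_results is run.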
98 | final_output_path = os.path.join(output_dir, f'all_docs_train_data/{prompt_id}/{model_id_str}/{subject}_{num_docs}_train_data.jsonl') 99 | final_output_path = os.path.expanduser(final_output_path) 100 | os.makedirs(os.path.dirname(final_output_path), exist_ok=True) 101 | 102 | if 'gpt' in model_id: 103 | model = OpenAILM(model_id, temperature=temperature, top_p=top_p, seed=0) 104 | else: 105 | model = HFLM(model_id, temperature=temperature, top_p=top_p) 106 | 107 | system_prompt = fill_sys_prompt(prompt, queries_per_doc=queries_per_doc) 108 | system_prompts = [system_prompt] 109 | 110 | batch_request_messages = [] 111 | batch_request_ids = [] 112 | batch_base_data = [] 113 | 114 | final_training_data = [] 115 | print("Sampling from system prompts:", system_prompts) 116 | for doc_dict in tqdm(doc_dicts): 117 | document = doc_dict['doc'] 118 | formatted_query = format_query_doc(document) 119 | 120 | if len(system_prompts) > num_prompts: 121 | sampled_prompts = random.choices(system_prompts, k=num_prompts) 122 | else: 123 | sampled_prompts = system_prompts 124 | 125 | for system_prompt in sampled_prompts: # generate examples for each prompt 126 | max_num_attempt_per_doc = 3 127 | num_attempt = 0 128 | has_generated = False 129 | while not has_generated and num_attempt < max_num_attempt_per_doc: 130 | if batch_helper is not None: 131 | doc_id = doc_dict['doc_id'] 132 | messages = model.apply_chat_template(formatted_query, system_prompt=system_prompt) 133 | batch_request_messages.append(messages) 134 | batch_request_ids.append(doc_id) 135 | batch_base_data.append({'doc_id': doc_id, 'document': document, 'prompt': system_prompt}) 136 | items = None 137 | has_generated = True 138 | else: 139 | query = model.generate(formatted_query, system_prompt=system_prompt) 140 | print(f"Generated query: {query}") 141 | items = {'query': query, 'document': document, 'prompt': system_prompt} 142 | items = format_example(bm25_miner, query, document, queries_per_doc) 143 | num_attempt += 1 144 | if items is not None: 145 | has_generated = True 146 | 147 | if not items: 148 | print(f"Skipping document: {document}") 149 | continue 150 | if isinstance(items, list): 151 | if len(items) < queries_per_doc: 152 | continue 153 | final_training_data.extend(items[:queries_per_doc]) 154 | if len(final_training_data) >= num_docs * queries_per_doc: 155 | break 156 | 157 | if batch_helper is not None: 158 | batch_helper.batch_save_base(batch_base_data) 159 | batch_helper.batch_request(batch_request_messages, batch_request_ids) 160 | else: 161 | write_jsonl(final_training_data, final_output_path) 162 | 163 | 164 | def main(): 165 | pass 166 | 167 | 168 | if __name__ == '__main__': 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument('--mode', type=str, default=None, help='mode') 171 | parser.add_argument('--model_id', type=str, default='gpt-4o', help='model id') 172 | parser.add_argument('--dataset', type=str, default='bright', help='dataset') 173 | parser.add_argument('--subject', type=str, default=None, help='subject') 174 | parser.add_argument('--queries_per_doc', type=int, default=3, help='number of generated samples per document') 175 | parser.add_argument('--num_docs', type=int, default=None, help='number of documents to sample for each subject') 176 | parser.add_argument('--debug', action='store_true', help='debug mode') 177 | parser.add_argument('--filter', type=str, default=None, help='the default filter is the length filter', choices=['length', 'fineweb', 'dclm']) 178 | 
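# 'length' keeps passages with at least 20 whitespace-delimited tokens, while 'fineweb'
# scores passages with the fineweb-edu quality classifier and keeps the top-scoring ones
# above a threshold; see document_filter() in gen_utils.py.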
parser.add_argument('--data_path', type=str, default='~/data/chunks/mathpile_wiki_chunks.jsonl', help='data path') 179 | parser.add_argument('--output_dir', type=str, default='data/synthetic_questions', help='base directory to save the generated data') 180 | parser.add_argument('--cache_dir', type=str, default='cache/', help='cache directory to save cached data during document filtering.') 181 | parser.add_argument('--prompt_id', type=str, default='hq_gen', help='prompt to use') 182 | parser.add_argument("--num_prompts", type=int, default=1, help='number of prompts to use for diversity') 183 | parser.add_argument("--temperature", type=float, default=0) 184 | parser.add_argument("--top_p", type=float, default=0) 185 | parser.add_argument("--no_batch", action='store_true', help='disable batch mode') 186 | parser.add_argument('--gather_results', action='store_true', help='check batch job and process') 187 | args = parser.parse_args() 188 | args.output_dir = os.path.expanduser(args.output_dir) 189 | args.data_path = os.path.expanduser(args.data_path) 190 | os.makedirs(args.cache_dir, exist_ok=True) 191 | print(args) 192 | 193 | 194 | print(f"Dataset: {args.dataset}, Task: {args.subject}") 195 | bm25_miner = BM25_Miner(dataset=args.dataset, task=args.subject, data_path=args.data_path) 196 | 197 | model_id = args.model_id 198 | prompt_id = args.prompt_id 199 | num_docs = args.num_docs 200 | queries_per_doc = args.queries_per_doc 201 | dataset = bm25_miner.dataset 202 | subject = bm25_miner.task if bm25_miner.task else dataset 203 | output_dir = args.output_dir 204 | 205 | model_id_str = model_id.replace('/', '_') 206 | task_filename = f'{model_id_str}_{bm25_miner.dataset}_{bm25_miner.task}.jsonl' 207 | task_filename = os.path.join(args.output_dir, task_filename) 208 | batch_helper = None 209 | if not args.no_batch or args.gather_results: 210 | batch_helper = BatchAPIHelper(model_id, task_filename) 211 | 212 | if args.gather_results: 213 | base_data = batch_helper.batch_load_base() 214 | base_data_dict = base_data_to_dict(base_data) 215 | outputs = None 216 | while outputs is None: 217 | outputs = batch_helper.gather_results() 218 | if outputs is None: 219 | print("Waiting for the batch job to finish...") 220 | time.sleep(60) 221 | 222 | # process the outputs and cache the results 223 | model_id_str = model_id.split('/')[-1] 224 | # path to save intermediate model generated results, will check if the same document has been used before to avoid repetitive generation. 
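# Once the batch job finishes, each response is joined back to its source document via the
# doc_id stored in the base file, parsed with format_example (which also mines BM25 hard
# negatives), and written to the same final_output_path used by the non-batch script.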
225 | final_output_path = os.path.join(output_dir, f'all_docs_train_data/{prompt_id}/{model_id_str}/{subject}_{num_docs}_train_data.jsonl') 226 | final_output_path = os.path.expanduser(final_output_path) 227 | os.makedirs(os.path.dirname(final_output_path), exist_ok=True) 228 | 229 | final_training_data = [] 230 | for output in outputs: 231 | query_id = output['id'] 232 | response = output['response'] 233 | document = base_data_dict[query_id]['document'] 234 | items = format_example(bm25_miner, response, document, queries_per_doc) 235 | 236 | if not items: 237 | print(f"Skipping document: {document}") 238 | continue 239 | if isinstance(items, list): 240 | if len(items) < queries_per_doc: 241 | continue 242 | final_training_data.extend(items[:queries_per_doc]) 243 | write_jsonl(final_training_data, final_output_path) 244 | if len(final_training_data) > 0: 245 | print(f"Generated {len(final_training_data)} samples.") 246 | batch_helper.clean_up() 247 | exit() 248 | 249 | doc2query(bm25_miner, model_id=model_id, num_docs=args.num_docs, 250 | filter_name=args.filter, queries_per_doc=args.queries_per_doc, output_dir=args.output_dir, cache_dir=args.cache_dir, 251 | prompt_id=args.prompt_id, temperature=args.temperature, top_p=args.top_p, batch_helper=batch_helper) 252 | 253 | -------------------------------------------------------------------------------- /synthetic_data_generation/gen_utils.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import os 3 | import json 4 | import re 5 | import random 6 | 7 | 8 | def format_query_doc(doc, example_query=None): 9 | user_prompt = '''The document is given below: 10 | 11 | 12 | \n'''.replace('', doc) 13 | if example_query is not None: 14 | user_prompt += f''' 15 | \nThe example question is given below: 16 | 17 | 18 | \n'''.replace('', example_query) 19 | user_prompt += '\nPlease start generating the questions.' 20 | return user_prompt 21 | 22 | 23 | def process_batch_response(response): 24 | outputs = [] 25 | json_data = response.content.decode('utf-8') 26 | for line in json_data.splitlines(): 27 | # Parse the JSON record (line) to validate it 28 | json_record = json.loads(line) 29 | 30 | custom_id = json_record.get("custom_id") 31 | 32 | # Navigate to the 'choices' key within the 'response' -> 'body' 33 | choices = json_record.get("response", {}).get("body", {}).get("choices", []) 34 | 35 | # Loop through the choices to find messages with the 'assistant' role 36 | for choice in choices: 37 | message = choice.get("message", {}) 38 | if message.get("role") == "assistant": 39 | assistant_content = message.get("content") 40 | outputs.append({"id": custom_id, "response": assistant_content}) 41 | break 42 | return outputs 43 | 44 | 45 | def format_request_json(messages, model="gpt-4o", custom_id="request-1"): 46 | return {"custom_id": custom_id, 47 | "method": "POST", 48 | "url": "/v1/chat/completions", 49 | "body": { 50 | "model": model, 51 | "messages": messages}} 52 | 53 | 54 | def extract_json_from_text(text): 55 | decoder = json.JSONDecoder() 56 | try: 57 | json_data, _ = decoder.raw_decode(text[text.find('{'):]) 58 | except json.decoder.JSONDecodeError as e: 59 | # text = text.replace('\\', '\\\\') 60 | text = re.sub(r'(? 
0 and isinstance(documents[0], dict): 151 | # check if the documents are in the form of dictionaries 152 | passages = [documents[i]['doc'] for i in range(len(documents))] 153 | else: 154 | passages = [i for i in documents] 155 | 156 | if filter_name is None: 157 | print("No filter specified. Using the default length filter.") 158 | filter_name = 'length' 159 | if filter_name == 'length' or filter_name == 'chunk': 160 | if filter_name == 'length': 161 | ids = length_filter(passages) 162 | elif filter_name == 'chunk': 163 | ids = chunk_filter(passages) 164 | filtered_docs = [documents[i] for i in ids] 165 | filtered_doc_ids = [doc_ids[i] for i in ids] 166 | if num_docs is not None and num_sample_pool is not None: 167 | # Sample num_docs documents from a larger pool for diversity 168 | # and avoid conjection chunks with similar quality scores being chosen too much (might lead to false negative) 169 | num_sample_pool = max(num_docs, num_sample_pool) 170 | filtered_docs = filtered_docs[:num_sample_pool] 171 | filtered_doc_ids = filtered_doc_ids[:num_sample_pool] 172 | indices = random.sample(range(num_sample_pool), num_docs) 173 | filtered_docs = [filtered_docs[i] for i in indices] 174 | filtered_doc_ids = [filtered_doc_ids[i] for i in indices] 175 | elif num_docs is not None: 176 | filtered_docs = filtered_docs[:num_docs] 177 | filtered_doc_ids = filtered_doc_ids[:num_docs] 178 | if return_scores: 179 | return filtered_docs, filtered_doc_ids, None 180 | return filtered_docs, filtered_doc_ids 181 | 182 | if filter_name == 'fineweb': 183 | # NOTE: the threshold is set to 1 because the fineweb filter returns a score between 0 and 5 184 | print(f"Using fineweb_edu filter: {filter_name}") 185 | from document_filters.fineweb_edu_filter import fineweb_quality_filter 186 | quality_filter = fineweb_quality_filter 187 | cache_file = 'fineweb_scores.json' 188 | else: 189 | raise ValueError(f"Unknown filter name: {filter_name}") 190 | 191 | 192 | if cache_dir is not None: 193 | cache_dir = os.path.expanduser(cache_dir) 194 | # load scores from cache if available 195 | cache_path = os.path.join(cache_dir, cache_file) 196 | use_cache = True 197 | else: 198 | use_cache = False 199 | 200 | if use_cache and os.path.exists(cache_path): 201 | # NOTE: disabled caching since the filename is shared by different tasks 202 | with open(cache_path, 'r') as fin: 203 | score_dict = json.load(fin) 204 | scores = [score_dict[doc_id] for doc_id in doc_ids] 205 | else: 206 | scores = quality_filter(passages) 207 | score_dict = {} 208 | for i, score in enumerate(scores): 209 | score_dict[doc_ids[i]] = score 210 | if use_cache: 211 | with open(cache_path, 'w') as fout: 212 | json.dump(score_dict, fout) 213 | 214 | if debug: 215 | from copy import deepcopy 216 | all_docs = deepcopy(documents) 217 | for item in all_docs: 218 | if isinstance(item, dict): 219 | item['score'] = score_dict[item['doc_id']] 220 | # save the list of dict items into excel 221 | import pandas as pd 222 | df = pd.DataFrame(all_docs) 223 | output_filename = cache_path.replace('.json', '.pkl') 224 | # df.to_csv(excel_filename, index=False) 225 | df.to_pickle(output_filename) 226 | 227 | filtered_docs = [] 228 | filtered_doc_ids = [] 229 | if num_docs is not None: 230 | num_docs = min(num_docs, len(passages)) 231 | sorted_ids = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) 232 | # top10_scores = [scores[i] for i in sorted_ids[:10]] 233 | # print(f"Top 10 scores: {top10_scores}") 234 | 235 | for i in sorted_ids: 236 | if scores[i] >= 
threshold: 237 | filtered_docs.append(documents[i]) 238 | filtered_doc_ids.append(doc_ids[i]) 239 | if len(filtered_docs) >= num_docs: 240 | break 241 | if return_scores: 242 | return filtered_docs, filtered_doc_ids, scores 243 | return filtered_docs, filtered_doc_ids 244 | 245 | 246 | def get_long_doc_id(doc_id): 247 | # The doc id is in the format of 'long_doc_id_%d_%d' or 'long_doc_id_%d' 248 | # We need to extract the "long_doc_id" to get the actual document id 249 | doc_id = doc_id.split('_')[:-1] 250 | doc_id = '_'.join(doc_id) 251 | doc_id = re.sub(r'_\d+$', '', doc_id) 252 | return doc_id + '.txt' 253 | 254 | 255 | def random_sample_documents(documents, doc_ids, num_docs, seed=42): 256 | assert len(documents) == len(doc_ids) 257 | random.seed(seed) 258 | indices = list(range(len(documents))) 259 | sampled_indices = random.sample(indices, num_docs) 260 | sampled_documents = [documents[i] for i in sampled_indices] 261 | sampled_doc_ids = [doc_ids[i] for i in sampled_indices] 262 | return sampled_documents, sampled_doc_ids 263 | 264 | # filter the documents with less than 1024 tokens 265 | def chunk_filter(documents, chunk_size=1024): 266 | filtered_doc_ids = [] 267 | for doc_id, doc in enumerate(documents): 268 | if len(doc.split(' ')) < 1024: 269 | continue 270 | filtered_doc_ids.append(doc_id) 271 | return filtered_doc_ids 272 | 273 | 274 | # filter the documents with less than 20 tokens 275 | def length_filter(documents): 276 | filtered_doc_ids = [] 277 | for doc_id, doc in enumerate(documents): 278 | if len(doc.split(' ')) < 20: 279 | continue 280 | filtered_doc_ids.append(doc_id) 281 | return filtered_doc_ids 282 | 283 | 284 | def deduplicate_docs(documents): 285 | unique_docs = set() 286 | unique_doc_ids = [] 287 | for doc_id, doc in enumerate(documents): 288 | if isinstance(doc, dict): 289 | doc = doc['doc'] 290 | if doc not in unique_docs: 291 | unique_docs.add(doc) 292 | unique_doc_ids.append(doc_id) 293 | unique_docs = [documents[i] for i in unique_doc_ids] 294 | return unique_docs, unique_doc_ids 295 | 296 | 297 | def retrieve_gold_docs(subject, cache_dir='cache'): 298 | examples = get_examples(subject, cache_dir) 299 | 300 | docs = get_docs(subject, cache_dir=cache_dir) 301 | doc_dict = {} 302 | for d in docs: 303 | doc_dict[d['id']] = d['content'] 304 | 305 | results = [] 306 | for ex in examples: 307 | gold_ids = ex['gold_ids'] 308 | for gold_id in gold_ids: 309 | result_dict = { 310 | 'query': ex['query'], 311 | 'reasoning': ex['reasoning'], 312 | 'doc_id': gold_id, 313 | 'doc': doc_dict[gold_id] 314 | } 315 | results.append(result_dict) 316 | return results 317 | 318 | 319 | import tiktoken 320 | def count_tokens(messages, model_name="gpt-4o"): 321 | enc = tiktoken.encoding_for_model(model_name) 322 | n_tokens = 0 323 | for message in messages: 324 | text = message['content'] 325 | n_tokens += len(enc.encode(text)) 326 | return n_tokens 327 | 328 | 329 | if __name__ == "__main__": 330 | folder_name = 'top100' 331 | model_id = 'gpt-4o' 332 | path = f'synthetic_data/{folder_name}/all_docs_train_data/{model_id}' 333 | path = os.path.expanduser(path) 334 | from gen_utils import load_training_data 335 | data = load_training_data(path) 336 | # the data is structured as follows: 337 | # {subject: [{'query': ..., 'pos': ..., 'neg': ...}, ...]} -------------------------------------------------------------------------------- /evaluation/rag/mmlu_cot/evaluate_from_local_mmlu.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 
import json 3 | import argparse 4 | import os 5 | import torch 6 | import random 7 | import transformers 8 | import time 9 | import re 10 | from vllm import LLM, SamplingParams 11 | from tqdm import tqdm 12 | import logging 13 | import sys 14 | from datasets import load_dataset 15 | import pdb 16 | from extract_mmlu_group import get_mmlu_group_subjects 17 | 18 | choices = ["A", "B", "C", "D"] 19 | ids_to_letters = {0: "A", 1: "B", 2: "C", 3: "D"} 20 | max_model_length = 4096 21 | max_new_tokens = 2048 22 | 23 | 24 | def load_mmlu(): 25 | dataset = load_dataset("cais/mmlu", 'all') 26 | test_df, val_df = dataset["test"], dataset["validation"] 27 | test_df = preprocess(test_df) 28 | val_df = preprocess(val_df) 29 | return test_df, val_df 30 | 31 | 32 | def load_model(): 33 | # TODO: support data parallel 34 | # refer to https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/vllm_causallms.py#L240 35 | llm = LLM(model=args.model, gpu_memory_utilization=float(args.gpu_util), 36 | tensor_parallel_size=torch.cuda.device_count(), 37 | max_model_len=max_model_length, 38 | trust_remote_code=True) 39 | sampling_params = SamplingParams(temperature=0, max_tokens=max_new_tokens, 40 | stop=["Question:"]) 41 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) 42 | return (llm, sampling_params), tokenizer 43 | 44 | 45 | def preprocess(test_df): 46 | res_df = [] 47 | for each in test_df: 48 | options = [] 49 | for opt in each["choices"]: 50 | if opt == "N/A": 51 | continue 52 | options.append(opt) 53 | each["choices"] = options 54 | res_df.append(each) 55 | return res_df 56 | 57 | 58 | def args_generate_path(input_args): 59 | scoring_method = "CoT" 60 | model_name = input_args.model.split("/")[-1] 61 | subjects = args.selected_subjects.replace(",", "-").replace(" ", "_") 62 | concat_k = f'Concat_{args.concat_k}' 63 | return [model_name, scoring_method, subjects, concat_k] 64 | 65 | 66 | def select_by_category(df, subject): 67 | res = [] 68 | for each in df: 69 | if each["subject"] == subject: 70 | res.append(each) 71 | return res 72 | 73 | 74 | def format_cot_example(example, including_answer=True): 75 | prompt = "Question:\n" 76 | question = example["question"] 77 | options = example["choices"] 78 | prompt += question + "\n" 79 | prompt += "Options:\n" 80 | for i, opt in enumerate(options): 81 | prompt += "{}. {}\n".format(choices[i], opt) 82 | if including_answer: 83 | cot_content = example["cot_content"].replace("A: Let's think step by step.", 84 | "Answer: Let's think step by step.") 85 | prompt += cot_content + "\n\n" 86 | else: 87 | prompt += "Answer: Let's think step by step." 
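# With including_answer=True the example is rendered as a complete worked exemplar
# (question, options, chain of thought, answer); with False the prompt stops at the
# "Let's think step by step." cue so the model completes the reasoning itself.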
88 | return prompt 89 | 90 | 91 | def generate_cot_prompt(val_df, curr, k): 92 | prompt = "" 93 | with open(f"cot_prompt_lib/initial_prompt.txt", "r") as fi: 94 | for line in fi.readlines(): 95 | prompt += line 96 | subject = curr["subject"] 97 | val_df = select_by_category(val_df, subject) 98 | val_df = val_df[: k] 99 | prompt = prompt.replace("{$}", subject) + "\n" 100 | for example in val_df: 101 | prompt += format_cot_example(example, including_answer=True) 102 | prompt += format_cot_example(curr, including_answer=False) 103 | return prompt 104 | 105 | 106 | def generate_rag_cot_prompt(val_df, curr, k, hashed_retrieval_results): 107 | question = curr["question"] 108 | if question in hashed_retrieval_results.keys(): 109 | formatted_question = question 110 | else: 111 | choices = curr["choices"] 112 | subject = curr["subject"] 113 | instruction = f"The following are multiple choice questions (with answers) about {' '.join(subject.split('_'))}.\n\n" 114 | formatted_question = f"{instruction}{question.strip()}\nA. {choices[0]}\nB. {choices[1]}\nC. {choices[2]}\nD. {choices[3]}\nAnswer:" 115 | 116 | if formatted_question not in hashed_retrieval_results.keys(): 117 | if curr['subject'] == 'business_ethics': 118 | pass 119 | elif curr['subject'] == 'college_medicine': 120 | pass 121 | elif curr['subject'] == 'formal_logic': 122 | pass 123 | else: 124 | pdb.set_trace() 125 | rag_context = "" 126 | else: 127 | rag_context = "Related background:\n" 128 | for ctx in hashed_retrieval_results[formatted_question]: 129 | rag_context += ctx['retrieval text'] + '\n\n' 130 | 131 | prompt = "" 132 | with open(f"cot_prompt_lib/initial_prompt.txt", "r") as fi: 133 | for line in fi.readlines(): 134 | prompt += line 135 | subject = curr["subject"] 136 | val_df = select_by_category(val_df, subject) 137 | val_df = val_df[: k] 138 | prompt = prompt.replace("{$}", subject) + "\n" 139 | for example in val_df: 140 | prompt += format_cot_example(example, including_answer=True) 141 | prompt += format_cot_example(curr, including_answer=False) 142 | return rag_context + prompt 143 | 144 | 145 | def extract_answer(text): 146 | pattern = r"answer is \(?([A-D])\)?" 
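# Answer extraction cascades through three patterns: an explicit "answer is (X)" phrase,
# then a trailing "Answer: X", then the last standalone A-D letter in the output; if all
# three fail, save_res falls back to a seeded random guess over the choices.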
147 | match = re.search(pattern, text) 148 | if match: 149 | return match.group(1) 150 | else: 151 | print("1st answer extract failed\n" + text) 152 | return extract_again(text) 153 | 154 | 155 | def extract_again(text): 156 | match = re.search(r'.*[aA]nswer:\s*([A-D])', text) 157 | if match: 158 | return match.group(1) 159 | else: 160 | return extract_final(text) 161 | 162 | 163 | def extract_final(text): 164 | pattern = r"\b[A-D]\b(?!.*\b[A-D]\b)" 165 | match = re.search(pattern, text, re.DOTALL) 166 | if match: 167 | return match.group(0) 168 | else: 169 | return None 170 | 171 | 172 | def batch_inference(llm, sampling_params, inference_batch): 173 | start = time.time() 174 | outputs = llm.generate(inference_batch, sampling_params) 175 | logging.info(str(len(inference_batch)) + "size batch costing time: " + str(time.time() - start)) 176 | response_batch = [] 177 | pred_batch = [] 178 | for output in outputs: 179 | generated_text = output.outputs[0].text 180 | response_batch.append(generated_text) 181 | pred = extract_answer(generated_text) 182 | pred_batch.append(pred) 183 | return pred_batch, response_batch 184 | 185 | 186 | def save_res(res, output_path): 187 | accu, corr, wrong = 0.0, 0.0, 0.0 188 | with open(output_path, "w") as fo: 189 | fo.write(json.dumps(res)) 190 | for each in res: 191 | if not each["pred"]: 192 | random.seed(12345) 193 | x = random.randint(0, len(each["choices"]) - 1) 194 | if x == each["answer"]: 195 | corr += 1 196 | # print("random hit.") 197 | else: 198 | wrong += 1 199 | elif each["pred"] == ids_to_letters[each["answer"]]: 200 | corr += 1 201 | else: 202 | wrong += 1 203 | if corr + wrong == 0: 204 | return 0.0, 0.0, 0.0 205 | accu = corr / (corr + wrong) 206 | return accu, corr, wrong 207 | 208 | 209 | def get_hashed_retrieval_results(retrieval_path, k=3, raw_query_file=None): 210 | if retrieval_path is None or k == 0: 211 | return {} 212 | 213 | data = [] 214 | with open(retrieval_path, 'r') as fin: 215 | for line in fin: 216 | data.append(json.loads(line)) 217 | 218 | if raw_query_file is not None: 219 | print(f"Using raw queries provided in {raw_query_file} to create hashes") 220 | raw_queries = [] 221 | with open(raw_query_file, 'r') as fin: 222 | for line in fin: 223 | raw_queries.append(json.loads(line)) 224 | 225 | assert len(raw_queries) == len(data), f"Raw queries size mismatched: expecting {len(data)}, got {len(raw_queries)}." 226 | 227 | hashed_retrieval_results = {} 228 | for ex, raw_query in zip(data, raw_queries): 229 | query = raw_query['query'] 230 | topk_ctxs = ex['ctxs'][:min(k, len(ex['ctxs']))] 231 | key = query 232 | print(key) 233 | hashed_retrieval_results[key] = topk_ctxs 234 | 235 | else: 236 | hashed_retrieval_results = {} 237 | for ex in data: 238 | query = ex['question'] if 'question' in ex else ex['raw_query'] 239 | topk_ctxs = ex['ctxs'][:min(k, len(ex['ctxs']))] 240 | 241 | key = query 242 | print(key) 243 | hashed_retrieval_results[key] = topk_ctxs 244 | 245 | return hashed_retrieval_results 246 | 247 | 248 | @torch.no_grad() 249 | def eval_cot(subject, model, tokenizer, val_df, test_df, output_path): 250 | llm, sampling_params = model 251 | global choices 252 | logging.info("evaluating " + subject) 253 | inference_batches = [] 254 | 255 | for i in tqdm(range(len(test_df))): 256 | k = args.ntrain 257 | assert k == 0, "Error: MMLU does not have CoT fewshot examples." 
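# With k == 0 (enforced by the assert above, since MMLU validation examples carry no "cot_content" field), generate_cot_prompt only prepends the template from cot_prompt_lib/initial_prompt.txt before the current question.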
258 | curr = test_df[i] 259 | prompt = generate_cot_prompt(val_df, curr, k) 260 | message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] 261 | prompt = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False) 262 | inference_batches.append(prompt) 263 | 264 | pred_batch, response_batch = batch_inference(llm, sampling_params, inference_batches) 265 | res = [] 266 | for j, curr in enumerate(test_df): 267 | curr["pred"] = pred_batch[j] 268 | curr["model_outputs"] = response_batch[j] 269 | res.append(curr) 270 | accu, corr, wrong = save_res(res, output_path) 271 | logging.info("this batch accu is: {}, corr: {}, wrong: {}\n".format(str(accu), str(corr), str(wrong))) 272 | 273 | accu, corr, wrong = save_res(res, output_path) 274 | return accu, corr, wrong 275 | 276 | 277 | @torch.no_grad() 278 | def eval_rag_cot(subject, model, tokenizer, val_df, test_df, output_path, hashed_retrieval_results): 279 | llm, sampling_params = model 280 | global choices 281 | logging.info("evaluating " + subject) 282 | inference_batches = [] 283 | 284 | for i in tqdm(range(len(test_df))): 285 | k = args.ntrain 286 | assert k == 0, "Error: MMLU does not have CoT fewshot examples." 287 | curr = test_df[i] 288 | prompt = generate_rag_cot_prompt(val_df, curr, k, hashed_retrieval_results) 289 | message = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] 290 | prompt = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False) 291 | inference_batches.append(prompt) 292 | 293 | pred_batch, response_batch = batch_inference(llm, sampling_params, inference_batches) 294 | res = [] 295 | for j, curr in enumerate(test_df): 296 | curr["pred"] = pred_batch[j] 297 | curr["model_outputs"] = response_batch[j] 298 | res.append(curr) 299 | accu, corr, wrong = save_res(res, output_path) 300 | logging.info("this batch accu is: {}, corr: {}, wrong: {}\n".format(str(accu), str(corr), str(wrong))) 301 | 302 | accu, corr, wrong = save_res(res, output_path) 303 | return accu, corr, wrong 304 | 305 | 306 | def main(): 307 | model, tokenizer = load_model() 308 | if not os.path.exists(save_result_dir): 309 | os.makedirs(save_result_dir) 310 | 311 | full_test_df, full_val_df = load_mmlu() 312 | all_subjects = [] 313 | for each in full_test_df: 314 | if each["subject"] not in all_subjects: 315 | all_subjects.append(each["subject"]) 316 | if args.selected_subjects == "all": 317 | selected_subjects = all_subjects 318 | else: 319 | selected_subjects = [] 320 | args_selected = args.selected_subjects.split(",") 321 | for sub in all_subjects: 322 | for each in args_selected: 323 | if each.replace(" ", "_") in sub.replace(" ", "_"): 324 | selected_subjects.append(sub) 325 | 326 | # Check every subject is assigned a group properly 327 | groups, subject_to_group = get_mmlu_group_subjects() 328 | for subject in selected_subjects: 329 | assert subject in subject_to_group.keys(), subject 330 | print(subject_to_group[subject]) 331 | 332 | # Load retrieval file 333 | use_rag = args.retrieval_file is not None and args.concat_k > 0 334 | if use_rag: 335 | hashed_retrieval_results = get_hashed_retrieval_results(args.retrieval_file, args.concat_k, args.raw_query_file) 336 | 337 | # Start evaluation 338 | logging.info("selected subjects:\n" + "\n".join(selected_subjects)) 339 | print("selected subjects:\n" + "\n".join(selected_subjects)) 340 | sta_dict = {} 341 | selected_subjects = sorted(selected_subjects) 342
| with open(os.path.join(summary_path), 'a') as f: 343 | f.write("\n------category level sta------\n") 344 | for subject in selected_subjects: 345 | if subject not in sta_dict: 346 | sta_dict[subject] = {"corr": 0.0, "wrong": 0.0, "accu": 0.0} 347 | test_df = select_by_category(full_test_df, subject) 348 | val_df = select_by_category(full_val_df, subject) 349 | output_path = os.path.join(save_result_dir, "{}.json".format(subject)) 350 | if use_rag: 351 | acc, corr_count, wrong_count = eval_rag_cot(subject, model, tokenizer, val_df, test_df, output_path, hashed_retrieval_results) 352 | else: 353 | acc, corr_count, wrong_count = eval_cot(subject, model, tokenizer, val_df, test_df, output_path) 354 | sta_dict[subject]["corr"] = corr_count 355 | sta_dict[subject]["wrong"] = wrong_count 356 | sta_dict[subject]["accu"] = acc 357 | with open(os.path.join(summary_path), 'a') as f: 358 | f.write("Average accuracy {:.4f} - {}\n".format(sta_dict[subject]["accu"], subject)) 359 | 360 | # Compute group-wise accuracy 361 | group_weighted_acc, group_names = [], [] 362 | for group_name, group in groups.items(): 363 | total_corr, total_wrong = 0.0, 0.0 364 | for k, v in sta_dict.items(): 365 | if k in group: 366 | total_corr += v["corr"] 367 | total_wrong += v["wrong"] 368 | total_accu = total_corr / (total_corr + total_wrong + 0.000001) 369 | # sta_dict[group_name] = {"corr": total_corr, "wrong": total_wrong, "accu": total_accu} 370 | group_names.append(group_name) 371 | group_weighted_acc.append(total_accu) 372 | 373 | # Compute total accuracy 374 | total_corr, total_wrong = 0.0, 0.0 375 | for k, v in sta_dict.items(): 376 | total_corr += v["corr"] 377 | total_wrong += v["wrong"] 378 | total_accu = total_corr / (total_corr + total_wrong + 0.000001) 379 | sta_dict["total"] = {"corr": total_corr, "wrong": total_wrong, "accu": total_accu} 380 | 381 | # Write out results 382 | with open(os.path.join(summary_path), 'a') as f: 383 | f.write("\n------group-wise average acc sta------\n") 384 | for group_name, weighted_acc in zip(group_names, group_weighted_acc): 385 | f.write("Average accuracy ({}): {:.4f}\n".format(group_name, weighted_acc)) 386 | f.write("\n------average acc sta------\n") 387 | weighted_acc = total_accu 388 | f.write("Average accuracy: {:.4f}\n".format(weighted_acc)) 389 | with open(global_record_file, 'a', newline='') as file: 390 | writer = csv.writer(file) 391 | record = args_generate_path(args) + [time_str, weighted_acc] + group_weighted_acc 392 | writer.writerow(record) 393 | 394 | 395 | if __name__ == "__main__": 396 | parser = argparse.ArgumentParser() 397 | parser.add_argument("--ntrain", "-k", type=int, default=0) 398 | parser.add_argument("--selected_subjects", "-sub", type=str, default="all") 399 | parser.add_argument("--save_dir", "-s", type=str, default="results") 400 | parser.add_argument("--global_record_file", "-grf", type=str, 401 | default="eval_record_collection.csv") 402 | parser.add_argument("--gpu_util", "-gu", type=str, default="0.8") 403 | parser.add_argument("--model", "-m", type=str, default="meta-llama/Llama-2-7b-hf") 404 | parser.add_argument("--retrieval_file", type=str, default=None) 405 | parser.add_argument("--raw_query_file", type=str, default=None, 406 | help="Pass the original query to this argument, where the order of the query matches the reasoning query in the retrieval_file.") 407 | parser.add_argument("--concat_k", type=int, default=0) 408 | 409 | args = parser.parse_args() 410 | os.makedirs(args.save_dir, exist_ok=True) 411 | global_record_file = 
args.global_record_file 412 | save_result_dir = os.path.join( 413 | args.save_dir, "/".join(args_generate_path(args)) 414 | ) 415 | file_prefix = "-".join(args_generate_path(args)) 416 | timestamp = time.time() 417 | time_str = time.strftime('%m-%d_%H-%M', time.localtime(timestamp)) 418 | file_name = f"{file_prefix}_{time_str}_summary.txt" 419 | summary_path = os.path.join(args.save_dir, "summary", file_name) 420 | os.makedirs(os.path.join(args.save_dir, "summary"), exist_ok=True) 421 | os.makedirs(save_result_dir, exist_ok=True) 422 | save_log_dir = os.path.join(args.save_dir, "log") 423 | os.makedirs(save_log_dir, exist_ok=True) 424 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s', 425 | handlers=[logging.FileHandler(os.path.join(save_log_dir, 426 | file_name.replace("_summary.txt", 427 | "_logfile.log"))), 428 | logging.StreamHandler(sys.stdout)]) 429 | 430 | main() 431 | 432 | 433 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. 
Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. 
Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. 
If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. 
automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 
401 | --------------------------------------------------------------------------------