├── .gitignore ├── LICENSE ├── README.md ├── configs ├── dataset_info.json ├── llama3_full_dpo.yaml ├── llama3_full_sft.yaml ├── qwen2.5_3b_full_dpo.yaml └── qwen2.5_3b_full_sft.yaml ├── datasets.zip ├── index.py ├── llm_reader.py └── slm_decompose.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Ran Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Collab-RAG 2 | 3 | This repository contains the code for our paper Collab-RAG: Boosting Retrieval-Augmented Generation for Complex Question Answering via White-Box and Black-Box LLM Collaboration. 4 | 5 | ## Dataset and Indexing 6 | The raw datasets we use are stored in `datasets.zip` and were prepared following the instructions in [Search-o1](https://github.com/sunnynexus/Search-o1). 7 | The corpus can be found at [this link](https://huggingface.co/datasets/BeIR/hotpotqa/blob/main/corpus.jsonl.gz). Please use the following command to generate the embeddings for each dataset (assuming the corpus file is named `corpus.jsonl`).
8 | ``` 9 | CUDA_VISIBLE_DEVICES=0 python index.py \ 10 | --shard_id={0,1,..,7} \ 11 | --shards=8 \ 12 | --sentence_embedding_model facebook/dragon-plus-context-encoder \ 13 | --total_data 6000000 \ 14 | --sentence_embedding_model_save_name dragon \ 15 | --dataset hpqa 16 | ``` 17 | 18 | ## White Box SLM Decomposition 19 | Please use the following command to have the white-box SLM break down questions (you can later point `--model_path` to your trained model): 20 | ``` 21 | CUDA_VISIBLE_DEVICES=0,1 python slm_decompose.py \ 22 | --model_path meta-llama/Llama-3.1-8B-Instruct \ 23 | --tokenizer meta-llama/Llama-3.1-8B-Instruct \ 24 | --temperature 1.0 \ 25 | --tensor_parallel_size 2 \ 26 | --expname [Your Experiment Name] \ 27 | --save_dir test 28 | ``` 29 | After this step, the decomposed questions will be saved at 30 | `{args.save_dir}/{dataset}/prompts_decompose_test_t{args.temperature}_{args.expname}/generate.jsonl`. 31 | 32 | ## Black Box LLM Reader 33 | Please use the following command to run the black-box LLM reader over the decomposed questions: 34 | ``` 35 | CUDA_VISIBLE_DEVICES=0 python llm_reader.py \ 36 | --llm_model gpt-4o-mini \ 37 | --expname [Your Experiment Name] \ 38 | --temperature 1.0 \ 39 | --save_dir test 40 | ``` 41 | 42 | The final answers will be saved in `{args.save_dir}/output/{args.dataset}/prompts_{args.llm_model}-{args.expname}.jsonl`. 43 | 44 | ## Using LLaMA-Factory for Fine-tuning 45 | Please add the following files directly into the LLaMA-Factory repo: 46 | 47 | - **dataset_info.json** 48 | - **llama3_full_{sft,dpo}.yaml** and **qwen2.5_3b_full_{sft,dpo}.yaml** 49 | 50 | See the `configs` folder for details. The processed data will be available shortly -- stay tuned! 51 | 52 | 53 | 54 | 55 | ## Citation 56 | If you find this repository helpful, please consider citing the corresponding paper. Thanks! 57 | ``` 58 | @article{xu2025collab, 59 | title={Collab-RAG: Boosting Retrieval-Augmented Generation for Complex Question Answering via White-Box and Black-Box LLM Collaboration}, 60 | author={Xu, Ran and Shi, Wenqi and Zhuang, Yuchen and Yu, Yue and Ho, Joyce C and Wang, Haoyu and Yang, Carl}, 61 | journal={arXiv preprint}, 62 | year={2025} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /configs/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "identity": { 3 | "file_name": "identity.json" 4 | }, 5 | "alpaca_en_demo": { 6 | "file_name": "alpaca_en_demo.json" 7 | }, 8 | "alpaca_zh_demo": { 9 | "file_name": "alpaca_zh_demo.json" 10 | }, 11 | "2wiki_3b_sft_r1": { 12 | "file_name": "2wikimultihopqa/2wikimultihopqa_qwen-2.5-3b_sft_r1.json", 13 | "formatting": "sharegpt" 14 | }, 15 | "2wiki_3b_dpo_r1": { 16 | "file_name": "2wikimultihopqa/2wikimultihopqa_qwen-2.5-3b_dpo_r1.json", 17 | "ranking": true, 18 | "formatting": "sharegpt", 19 | "columns": { 20 | "messages": "conversations", 21 | "chosen": "chosen", 22 | "rejected": "rejected" 23 | } 24 | }, 25 | "2wiki_8b_sft_r1": { 26 | "file_name": "2wikimultihopqa/2wikimultihopqa_llama-3.1-8b_sft_r1.json", 27 | "formatting": "sharegpt" 28 | }, 29 | "2wiki_8b_dpo_r1": { 30 | "file_name": "2wikimultihopqa/2wikimultihopqa_llama-3.1-8b_dpo_r1.json", 31 | "ranking": true, 32 | "formatting": "sharegpt", 33 | "columns": { 34 | "messages": "conversations", 35 | "chosen": "chosen", 36 | "rejected": "rejected" 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /configs/llama3_full_dpo.yaml: -------------------------------------------------------------------------------- 1 | # 
llamafactory-cli train examples/train_full/llama3_full_dpo.yaml 2 | 3 | ### model 4 | model_name_or_path: [your checkpoint directory, example: meta-llama/Llama-3.1-8B-Instruct] 5 | trust_remote_code: true 6 | 7 | ### method 8 | stage: dpo 9 | do_train: true 10 | finetuning_type: full 11 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 12 | pref_beta: 0.1 13 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 14 | pref_ftx: 1.0 15 | 16 | ### dataset 17 | dataset: [your dataset name, please define it in dataset_info.json] 18 | template: llama3 19 | cutoff_len: 2048 20 | max_samples: 1000000 21 | overwrite_cache: true 22 | preprocessing_num_workers: 16 23 | 24 | ### output 25 | output_dir: [your output directory for model ckpt. Example: saves/llama3-8b/full/dpo] 26 | logging_steps: 10 27 | save_steps: 200 28 | plot_loss: true 29 | overwrite_output_dir: true 30 | 31 | ### train 32 | per_device_train_batch_size: 4 33 | gradient_accumulation_steps: 2 34 | learning_rate: 8.0e-7 35 | num_train_epochs: 2.0 36 | lr_scheduler_type: cosine 37 | warmup_ratio: 0.1 38 | bf16: true 39 | ddp_timeout: 180000000 40 | 41 | ### eval 42 | val_size: 0.03 43 | per_device_eval_batch_size: 1 44 | eval_strategy: steps 45 | eval_steps: 30 46 | -------------------------------------------------------------------------------- /configs/llama3_full_sft.yaml: -------------------------------------------------------------------------------- 1 | # llamafactory-cli train examples/train_full/llama3_full_sft.yaml 2 | 3 | ### model 4 | model_name_or_path: [your checkpoint directory, example: meta-llama/Llama-3.1-8B-Instruct] 5 | trust_remote_code: true 6 | 7 | ### method 8 | stage: sft 9 | do_train: true 10 | finetuning_type: full 11 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 12 | 13 | ### dataset 14 | dataset: [your dataset name, please define it in dataset_info.json] 15 | template: llama3 16 | cutoff_len: 2304 17 | max_samples: 1000000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 32 20 | 21 | ### output 22 | output_dir: [your output directory for model ckpt. 
Example: saves/llama3-8b/full/sft] 23 | logging_steps: 5 24 | save_steps: 1000 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | save_only_model: true 28 | 29 | ### train 30 | per_device_train_batch_size: 8 31 | gradient_accumulation_steps: 2 32 | learning_rate: 1.0e-6 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.04 41 | per_device_eval_batch_size: 8 42 | eval_strategy: steps 43 | eval_steps: 200 44 | -------------------------------------------------------------------------------- /configs/qwen2.5_3b_full_dpo.yaml: -------------------------------------------------------------------------------- 1 | # llamafactory-cli train examples/train_full/qwen2.5_3b_full_dpo.yaml 2 | 3 | ### model 4 | model_name_or_path: [your checkpoint directory, example: Qwen/Qwen2.5-3B-Instruct] 5 | trust_remote_code: true 6 | 7 | ### method 8 | stage: dpo 9 | do_train: true 10 | finetuning_type: full 11 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 12 | pref_beta: 0.1 13 | pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] 14 | pref_ftx: 0.5 15 | 16 | ### dataset 17 | dataset: [your dataset name, please define it in dataset_info.json] 18 | template: qwen 19 | cutoff_len: 2048 20 | max_samples: 10000 21 | overwrite_cache: true 22 | preprocessing_num_workers: 16 23 | 24 | ### output 25 | output_dir: [your output directory for model ckpt. Example: saves/qwen2.5-3b/full/dpo] 26 | logging_steps: 20 27 | # save_steps: 150 28 | save_strategy: epoch 29 | save_only_model: true 30 | plot_loss: true 31 | overwrite_output_dir: true 32 | 33 | ### train 34 | per_device_train_batch_size: 4 35 | gradient_accumulation_steps: 2 36 | learning_rate: 2.0e-6 37 | num_train_epochs: 2.0 38 | lr_scheduler_type: cosine 39 | warmup_ratio: 0.1 40 | bf16: true 41 | ddp_timeout: 180000000 42 | 43 | ### eval 44 | val_size: 0.03 45 | per_device_eval_batch_size: 2 46 | eval_strategy: steps 47 | eval_steps: 30 48 | -------------------------------------------------------------------------------- /configs/qwen2.5_3b_full_sft.yaml: -------------------------------------------------------------------------------- 1 | # llamafactory-cli train examples/train_full/qwen2.5_3b_full_sft.yaml 2 | 3 | ### model 4 | model_name_or_path: [your checkpoint directory, example: Qwen/Qwen2.5-3B-Instruct] 5 | trust_remote_code: true 6 | 7 | ### method 8 | stage: sft 9 | do_train: true 10 | finetuning_type: full 11 | deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json] 12 | 13 | ### dataset 14 | dataset: [your dataset name, please define it in dataset_info.json] 15 | template: qwen 16 | cutoff_len: 2048 17 | max_samples: 1000000 18 | overwrite_cache: true 19 | preprocessing_num_workers: 16 20 | 21 | ### output 22 | output_dir: [your output directory for model ckpt. 
Example: saves/qwen2.5-3b/full/sft] 23 | logging_steps: 10 24 | save_steps: 200 25 | plot_loss: true 26 | overwrite_output_dir: true 27 | save_only_model: true 28 | 29 | ### train 30 | per_device_train_batch_size: 8 31 | gradient_accumulation_steps: 1 32 | learning_rate: 2.0e-6 33 | num_train_epochs: 1.0 34 | lr_scheduler_type: cosine 35 | warmup_ratio: 0.1 36 | bf16: true 37 | ddp_timeout: 180000000 38 | 39 | ### eval 40 | val_size: 0.03 41 | per_device_eval_batch_size: 4 42 | eval_strategy: steps 43 | eval_steps: 50 44 | -------------------------------------------------------------------------------- /datasets.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ritaranx/Collab-RAG/13021a230eb0dc3a313c9f41c635cc0eea8eb294/datasets.zip -------------------------------------------------------------------------------- /index.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle as pkl 3 | import torch.nn.functional as F 4 | from tqdm import tqdm, trange 5 | import argparse 6 | import csv 7 | from transformers import AutoTokenizer, AutoModel 8 | import torch, os 9 | from torch import Tensor 10 | 11 | 12 | # step 1: download wikipedia corpora from `https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz` 13 | # wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz 14 | # then you will get the file named `psgs_w100.tsv` 15 | # run the command: CUDA_VISIBLE_DEVICES=0 python index.py --shard_id=1 --shards=8 --sentence_embedding_model facebook/dragon-plus-context-encoder 16 | 17 | def average_pool(last_hidden_states: Tensor, 18 | attention_mask: Tensor) -> Tensor: 19 | last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) 20 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 21 | 22 | 23 | def gen_embedding(args, sentences, tokenizer, model, batch_size = 512, normalize = False): 24 | # Tokenize sentences 25 | sentence_embeddings = [] 26 | num_iter = len(sentences)//batch_size if len(sentences) % batch_size == 0 else (len(sentences)//batch_size + 1) 27 | 28 | for i in trange(num_iter): 29 | if args.sentence_embedding_model in ["intfloat/e5-large-v2"]: 30 | passages = ["passage: " + x for x in sentences[i*batch_size:(i+1)*batch_size]] 31 | encoded_input = tokenizer(passages, max_length = 200, padding=True, truncation=True, return_tensors='pt') #.to(f"cuda:{gpuid}") 32 | if i == 0: 33 | print("Example input:\t\t", passages[0]) 34 | else: 35 | encoded_input = tokenizer(sentences[i*batch_size:(i+1)*batch_size], max_length = 200, padding=True, truncation=True, return_tensors='pt') #.to(f"cuda:{gpuid}") 36 | if i == 0: 37 | print("Example input:\t\t", sentences[i*batch_size]) 38 | for x in encoded_input: 39 | encoded_input[x] = encoded_input[x].cuda() 40 | 41 | # Compute token embeddings 42 | if args.sentence_embedding_model in ['bert', 'simcse']: 43 | with torch.no_grad(): 44 | embeddings = model(**encoded_input, output_hidden_states=True, return_dict=True).pooler_output 45 | sentence_embeddings.append(embeddings.detach().cpu()) 46 | elif args.sentence_embedding_model in ['BAAI/bge-large-en-v1.5', "Alibaba-NLP/gte-base-en-v1.5"]: 47 | with torch.no_grad(): 48 | model_output = model(**encoded_input) 49 | # Perform pooling. In this case, cls pooling.
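# (Descriptive note: CLS pooling takes the final hidden state of the first token, last_hidden_state[:, 0],
# as the passage embedding; E5 instead uses the attention-masked mean pooling implemented in average_pool above.)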
50 | sentence_embeddings.append(model_output.last_hidden_state[:, 0].detach().cpu()) 51 | elif args.sentence_embedding_model in ["intfloat/e5-large-v2"]: 52 | with torch.no_grad(): 53 | outputs = model(**encoded_input) 54 | embeddings = average_pool(outputs.last_hidden_state, encoded_input['attention_mask']) 55 | sentence_embeddings.append(embeddings.detach().cpu()) 56 | elif args.sentence_embedding_model in ['facebook/dragon-plus-context-encoder', "OpenMatch/cocodr-base-msmarco", "OpenMatch/cocodr-large-msmarco"]:# dragon 57 | with torch.no_grad(): 58 | embeddings = model(**encoded_input, output_hidden_states=True, return_dict=True).last_hidden_state[:, 0, :] # the embedding of the [CLS] token after the final layer 59 | sentence_embeddings.append(embeddings.detach().cpu()) 60 | 61 | sentence_embeddings = torch.cat(sentence_embeddings, dim = 0) 62 | print("shape:", sentence_embeddings.shape) 63 | if args.sentence_embedding_model in ["intfloat/e5-large-v2", "Alibaba-NLP/gte-base-en-v1.5"]: # set true for E5 and GTE, no need to use this for dragon!! 64 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1).squeeze(1) 65 | print(f"{args.sentence_embedding_model}, {args.sentence_embedding_model_save_name}, {args.dataset}, With Normalization!") 66 | else: 67 | print(f"{args.sentence_embedding_model}, {args.sentence_embedding_model_save_name}, {args.dataset}, NO Normalization!") 68 | print("Sentence embeddings Shape:", sentence_embeddings.shape) 69 | return sentence_embeddings.detach().cpu().numpy() 70 | 71 | # Sentences we want sentence embeddings for 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser("") 74 | parser.add_argument("--shard_id", type=int, default=0) 75 | parser.add_argument("--shards", type=int, default=8) 76 | parser.add_argument("--dataset", type=str, default=0) 77 | parser.add_argument("--total_data", type=int, default=22000000) 78 | parser.add_argument("--sentence_embedding_model", type=str, default="facebook/dragon-plus-context-encoder") #facebook/dragon-plus-context-encoder, "OpenMatch/cocodr-base-msmarco" 79 | parser.add_argument("--sentence_embedding_model_save_name", type=str, default="dragon", choices = ["coco_base", "coco_large", "dragon", "gte-base", "e5-large"]) 80 | args = parser.parse_args() 81 | 82 | # Load model from HuggingFace Hub 83 | tokenizer = AutoTokenizer.from_pretrained(args.sentence_embedding_model, trust_remote_code=True) 84 | model = AutoModel.from_pretrained(args.sentence_embedding_model, trust_remote_code=True).cuda() 85 | model.eval() 86 | cnt = 0 87 | sentences = [] 88 | idxs = [] 89 | print("start", args.total_data // args.shards * args.shard_id, "end", args.total_data // args.shards * (1+args.shard_id)) 90 | if args.dataset in ["wiki"]: 91 | target_file = f"{args.dataset}/psgs_w100.tsv" if args.dataset in ["wiki"] else f"{args.dataset}/corpus.tsv" 92 | with open(target_file, "r") as f: 93 | reader = csv.reader(f, delimiter = '\t') 94 | for lines in tqdm(reader): 95 | if lines[0] == "id": 96 | continue 97 | idx = int(lines[0]) 98 | text = lines[1] 99 | title = lines[2] 100 | if idx <= (args.total_data // args.shards * (args.shard_id)): 101 | continue 102 | elif idx > (args.total_data // args.shards * (1 + args.shard_id)): 103 | break 104 | else: 105 | sentences.append(text) 106 | idxs.append(idx) 107 | elif args.dataset == "hpqa": 108 | with open(f"corpus.jsonl", "r") as f: 109 | idx = 0 110 | for lines in tqdm(f): 111 | data = json.loads(lines) 112 | text = data["text"] 113 | title = data["text"] + " " + 
data["title"] 114 | if idx <= (args.total_data // args.shards * (args.shard_id)): 115 | idx += 1 116 | continue 117 | elif idx > (args.total_data // args.shards * (1 + args.shard_id)): 118 | break 119 | else: 120 | sentences.append(text) 121 | idxs.append(idx) 122 | idx += 1 123 | 124 | embeddings = gen_embedding(args, sentences, tokenizer, model, batch_size = 256) 125 | os.makedirs(f"{args.dataset}/{args.sentence_embedding_model_save_name}/", exist_ok = True) 126 | with open(f"{args.dataset}/{args.sentence_embedding_model_save_name}/embeddings-{args.shard_id}-of-{args.shards}.pkl", "wb") as f: 127 | pkl.dump(embeddings, f) -------------------------------------------------------------------------------- /llm_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import json 5 | import csv 6 | import copy 7 | import pickle as pkl 8 | import argparse 9 | from typing import List, Dict 10 | 11 | import torch 12 | import numpy as np 13 | import torch.nn.functional as F 14 | import faiss 15 | from tqdm import tqdm, trange 16 | from transformers import AutoTokenizer, AutoModel 17 | import openai 18 | 19 | from prompts import prompt 20 | 21 | # ------------------------ 22 | # Configurable constants 23 | # ------------------------ 24 | HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "./hf_cache") 25 | DATA_DIR = os.getenv("DATA_DIR", "./data") 26 | AZURE_API_BASE_GPT4O = os.getenv("AZURE_API_BASE_GPT4O", "https://your-azure-gpt4o-instance.openai.azure.com/") 27 | AZURE_API_BASE_GPT4O_MINI = os.getenv("AZURE_API_BASE_GPT4O_MINI", "https://your-azure-gpt4o-mini-instance.openai.azure.com/") 28 | AZURE_API_KEY_GPT4O = os.getenv("AZURE_API_KEY_GPT4O", "your-gpt4o-key") 29 | AZURE_API_KEY_GPT4O_MINI = os.getenv("AZURE_API_KEY_GPT4O_MINI", "your-gpt4o-mini-key") 30 | ENGINE_NAME = "your-engine-name" 31 | # ------------------------ 32 | # Utility functions 33 | # ------------------------ 34 | def prRed(s): print("\033[91m {}\033[00m".format(s)) 35 | def prPurple(s): print("\033[95m {}\033[00m".format(s)) 36 | def prYellow(s): print("\033[93m {}\033[00m".format(s)) 37 | def prLightPurple(s): print("\033[94m {}\033[00m".format(s)) 38 | 39 | # ------------------------ 40 | # LLM API Wrapper 41 | # ------------------------ 42 | def call_llm(prompt, model) -> str: 43 | if model == "gpt-4o": 44 | openai.api_type = "azure" 45 | openai.api_base = AZURE_API_BASE_GPT4O 46 | openai.api_version = "2024-05-01-preview" 47 | openai.api_key = AZURE_API_KEY_GPT4O 48 | engine = ENGINE_NAME 49 | elif model == "gpt-4o-mini": 50 | openai.api_type = "azure" 51 | openai.api_base = AZURE_API_BASE_GPT4O_MINI 52 | openai.api_version = "2024-02-15-preview" 53 | openai.api_key = AZURE_API_KEY_GPT4O_MINI 54 | engine = ENGINE_NAME 55 | 56 | try_times, success = 0, False 57 | while not success and try_times < 3: 58 | try: 59 | response = openai.ChatCompletion.create( 60 | engine=engine, 61 | messages=[{"role": "user", "content": prompt}], 62 | temperature=0, 63 | max_tokens=1000 64 | ) 65 | return response.choices[0].message["content"].strip() 66 | except (openai.error.Timeout, 67 | openai.error.RateLimitError, 68 | openai.error.APIConnectionError, 69 | openai.error.APIError, 70 | openai.error.InvalidRequestError, 71 | Exception): 72 | print("Retrying due to API error...") 73 | time.sleep(5) 74 | try_times += 1 75 | return "Failed" 76 | 77 | # ------------------------ 78 | # Question Decomposition 79 | # ------------------------ 80 | def 
decompose_question(question: str, prompt: str, llm_model: str = "gpt-4o") -> List[Dict]: 81 | call_prompt = prompt.replace("${question}", question) 82 | response = call_llm(call_prompt, model=llm_model) 83 | decomposed_questions = [] 84 | for line in response.strip().split("\n"): 85 | line = line.strip() 86 | if line.startswith("### Q"): 87 | q_part, ctx_part = line.split("## Need Context? ##") 88 | q_text = q_part.split(":", 1)[1].strip() 89 | needs_context = ctx_part.strip().lower().startswith("yes") 90 | q_label = q_part.split(":")[0].strip("#").strip() 91 | decomposed_questions.append({ 92 | "label": q_label, 93 | "text": q_text, 94 | "needs_context": needs_context 95 | }) 96 | return decomposed_questions 97 | 98 | # ------------------------ 99 | # Embedding + Retrieval 100 | # ------------------------ 101 | def load_embedding(dataset: str): 102 | sentences, passage_embeddings = [], [] 103 | with open(os.path.join(DATA_DIR, dataset, "corpus.tsv"), "r") as f: 104 | reader = csv.reader(f, delimiter='\t') 105 | for lines in tqdm(reader): 106 | if lines[0] != "id": 107 | sentences.append(lines[1]) 108 | for i in trange(4): 109 | with open(os.path.join(DATA_DIR, dataset, f"embeddings-{i}-of-4.pkl"), "rb") as f: 110 | passage_embeddings.append(pkl.load(f)) 111 | return sentences, np.concatenate(passage_embeddings, axis=0) 112 | 113 | def embed_text(text, tokenizer, model, normalize=True) -> List[float]: 114 | encoded = tokenizer(text, max_length=100, padding=True, truncation=True, return_tensors='pt') 115 | encoded = {k: v.cuda() for k, v in encoded.items()} 116 | with torch.no_grad(): 117 | output = model(**encoded).last_hidden_state[:, 0].detach().cpu() 118 | return F.normalize(output, p=2, dim=1).squeeze(1) if normalize else output 119 | 120 | def retrieve_context(query, cpu_index, corpus, tokenizer, model, top_k=3): 121 | q_embedding = embed_text(query, tokenizer, model, False) 122 | _, indices = cpu_index.search(q_embedding, top_k) 123 | return [corpus[i] for i in indices[0]] 124 | 125 | # ------------------------ 126 | # Sub-question Answering 127 | # ------------------------ 128 | def zigzag_visit(lst): 129 | return [lst[i] for i in range(0, len(lst), 2)] + [lst[i] for i in range(len(lst)-1 if len(lst)%2==0 else len(lst)-2, 0, -2)] 130 | 131 | def answer_sub_question(sub_q, passages, model): 132 | context = "\n\n".join(zigzag_visit(passages)) 133 | prompt = f"""You have the following context passages: 134 | {context} 135 | 136 | Question: {sub_q} 137 | 138 | Answer to the original question should be with one or a list of entities with the given context as the reference. 139 | If you do not find an answer in the context, please use your own knowledge to answer it. 140 | Do not give any explanation. Your answer needs to be as short as possible with one or a list of entities.""" 141 | return call_llm(prompt, model=model) 142 | 143 | # ------------------------ 144 | # Final Answer Generation 145 | # ------------------------ 146 | def generate_final_answer(original_q, sub_qs, sub_as, model): 147 | text = "\n".join([f"{k}: {sub_qs[k]}, Answer for {k}: {v}" for k, v in sub_as.items()]) 148 | prompt = f"""Original question: {original_q} 149 | 150 | We have the following decomposed subquestions and sub-answers: 151 | {text} 152 | 153 | Based on these sub-answers, provide the final concise answer to the original question: \"{original_q}\". 154 | Do not give an explanation. 
Answer needs to be as short as possible and should be with one or a list of entities.""" 155 | return call_llm(prompt, model=model) 156 | 157 | def multi_turn_qa(question, sub_questions, gold_answer, tokenizer, embedding_model, llm_model): 158 | subq_dict = {q["label"]: q["text"] for q in sub_questions} 159 | answer_dict = {} 160 | passage_dict = {} 161 | 162 | for subq in sub_questions: 163 | label = subq["label"] 164 | text = subq["text"] 165 | needs_context = subq.get("needs_context", True) 166 | 167 | # Replace #n placeholders with prior answers 168 | resolved_text = replace_placeholders(text, answer_dict) 169 | 170 | passages = [] 171 | if needs_context: 172 | passages = retrieve_context(resolved_text, index, corpus, tokenizer, embedding_model, top_k=9) 173 | 174 | answer = answer_sub_question(resolved_text, passages, llm_model) 175 | answer_dict[label] = answer 176 | passage_dict[label] = passages 177 | 178 | final_answer = generate_final_answer(question, subq_dict, answer_dict, llm_model) 179 | print(f"-------\nquestion: {question}\nsubq: {sub_questions}\nanswer: {answer_dict}\nPred: {final_answer} gold: {gold_answer}") 180 | return final_answer, answer_dict, passage_dict 181 | 182 | def replace_placeholders(text: str, answers_so_far: Dict[str, str]) -> str: 183 | matches = re.findall(r"#(\d+)", text)  # match placeholder references like "#1", "#2" 184 | for m in matches: 185 | placeholder = f"#{m}" 186 | q_key = f"Q{m}" 187 | if q_key in answers_so_far: 188 | text = text.replace(placeholder, answers_so_far[q_key]) 189 | return text 190 | 191 | # ------------------------ 192 | # Main Function 193 | # ------------------------ 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument("--llm_model", type=str, default="gpt-4o") 197 | parser.add_argument("--dataset", type=str, default="hotpotqa") 198 | parser.add_argument("--expname", type=str, default="") 199 | parser.add_argument("--temperature", type=float, default=0.1) 200 | parser.add_argument("--save_dir", type=str, default="test") 201 | args = parser.parse_args() 202 | 203 | sentence_embedding_model = "facebook/dragon-plus-query-encoder" 204 | tokenizer = AutoTokenizer.from_pretrained(sentence_embedding_model, cache_dir=HF_CACHE_DIR, trust_remote_code=True) 205 | embedding_model = AutoModel.from_pretrained(sentence_embedding_model, cache_dir=HF_CACHE_DIR, trust_remote_code=True).cuda() 206 | embedding_model.eval() 207 | 208 | questions_path = os.path.join(args.save_dir, args.dataset, f"prompts_decompose_test_t{args.temperature}_{args.expname}/generate.jsonl") 209 | with open(questions_path, "r") as f: 210 | questions = [json.loads(line) for line in f] 211 | 212 | corpus, embeddings = load_embedding(args.dataset) 213 | index = faiss.IndexFlatIP(embeddings.shape[1]) 214 | faiss.omp_set_num_threads(32) 215 | index.add(embeddings) 216 | 217 | saved = [] 218 | for idx, item in enumerate(tqdm(questions)): 219 | try: 220 | final, inter_ans, inter_pass = multi_turn_qa(item["question"], item["decomposed"], item["answer"], tokenizer, embedding_model, args.llm_model) 221 | item.update({"index": idx, "final_answer": final, "intermediate_answers": inter_ans, "intermediate_passages": inter_pass}) 222 | saved.append(item) 223 | except Exception as e: 224 | print(f"Error for item {idx}: {e}") 225 | 226 | if saved: 227 | out_path = os.path.join(args.save_dir, "output", args.dataset, f"prompts_{args.llm_model}-{args.expname}.jsonl") 228 | os.makedirs(os.path.dirname(out_path), exist_ok=True)  # ensure the output folder exists before writing 229 | with open(out_path, "w") as f: 230 | for ex in saved: 231 | f.write(json.dumps(ex) + '\n') 232 | 
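The prediction file written above stores the gold `answer` spans next to the generated `final_answer`, so it can be scored directly. Below is a minimal scoring sketch; it is not part of this repo, the file path is only illustrative, and the SQuAD-style answer normalization is an assumption about how one might evaluate.
```
# score_predictions.py -- hedged sketch: exact match / token-level F1 over the reader output.
import json, re, string
from collections import Counter

def normalize(s: str) -> str:
    # lowercase, drop punctuation and articles, collapse whitespace
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def token_f1(pred: str, gold: str) -> float:
    p, g = normalize(pred).split(), normalize(gold).split()
    common = Counter(p) & Counter(g)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)

path = "test/output/hotpotqa/prompts_gpt-4o-mini-myexp.jsonl"  # illustrative path, adjust to your run
em_scores, f1_scores = [], []
with open(path) as fin:
    for line in fin:
        ex = json.loads(line)
        golds = ex["answer"] if isinstance(ex["answer"], list) else [ex["answer"]]
        golds = golds or [""]
        pred = ex.get("final_answer", "")
        em_scores.append(max(float(normalize(pred) == normalize(g)) for g in golds))
        f1_scores.append(max(token_f1(pred, g) for g in golds))
print(f"EM: {sum(em_scores)/len(em_scores):.4f}  F1: {sum(f1_scores)/len(f1_scores):.4f}")
```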
-------------------------------------------------------------------------------- /slm_decompose.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from vllm import LLM, SamplingParams 3 | import argparse 4 | import os 5 | from tqdm import trange 6 | import copy 7 | import json 8 | 9 | parser = argparse.ArgumentParser("") 10 | parser.add_argument("--tokenizer", type=str, default="meta-llama/Llama-3.1-8B-Instruct") 11 | parser.add_argument("--model_path", type=str, default="models/llama-3.1-8b-instruct") 12 | parser.add_argument("--expname", type=str, default="") 13 | parser.add_argument("--temperature", type=float, default=0.0) 14 | parser.add_argument("--top_p", type=float, default=0.99) 15 | parser.add_argument("--tensor_parallel_size", type=int, default=1) 16 | parser.add_argument("--save_dir", type=str, default="test") 17 | args = parser.parse_args() 18 | 19 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 20 | 21 | sampling_params = SamplingParams( 22 | temperature=args.temperature, 23 | top_p=args.top_p, 24 | repetition_penalty=1.05, 25 | max_tokens=2048 26 | ) 27 | 28 | llm = LLM( 29 | model=args.model_path, 30 | tensor_parallel_size=args.tensor_parallel_size, 31 | gpu_memory_utilization=0.88, 32 | trust_remote_code=True 33 | ) 34 | 35 | datasets = ["bamboogle", "2wikimultihopqa", "hotpotqa", "musique", "strategyqa"] 36 | model_name = args.expname 37 | 38 | prompt = """Please break down the given question into multiple specific sub-questions that address individual components of the original question. 39 | Please generate the decomposed sub-questions for the below question. The sub-question should be labeled with a reference to previous answers (e.g., #1) when needed. For example, #1 means the answer for decomposed question 1. 40 | The token after the question `## Need Context` stands for whether the decomposed question needs external corpus to answer. 'Yes' means it needs external corpus to answer, 'No' means it can be directly answered without retrieval. 41 | Here are four examples: 42 | 43 | [[Begin of the Example 1]] 44 | ## Question: 45 | What is the average winter daytime temperature in the region containing Richmond, in the state where WXBX is located? 46 | 47 | ## Decomposed Question: 48 | ### Q1: Which state is WXBX located? ## Need Context? ## Yes 49 | ### Q2: In which of #1 's regions is Richmond? ## Need Context? ## Yes 50 | ### Q3: What is the average winter daytime temperature in #2? ## Need Context? ## Yes 51 | [[End of the Example 1]] 52 | 53 | [[Begin of the Example 2]] 54 | ## Question: 55 | How long was the place where the Yongle Emperor greeted the person to whom the edict was addressed the capitol of the area where Guangling District was located? 56 | 57 | ## Decomposed Question: 58 | ### Q1: Who was the edict addressed to? ## Need Context? ## Yes 59 | ### Q2: Where did the Yongle Emperor greet #1 ? ## Need Context? ## Yes 60 | ### Q3: Where does Guangling District locate? ## Need Context? ## Yes 61 | ### Q4: How long had #2 been the capital city of #3 ? ## Need Context? 
## Yes 62 | [[End of the Example 2]] 63 | 64 | Now, decompose the following question: 65 | ## Question: 66 | ${question} 67 | 68 | ## Decomposed Question:""" 69 | 70 | for dataset in datasets: 71 | prompts = [] 72 | contexts = [] 73 | 74 | with open(f"./processed_data/{dataset}/test.jsonl", "r") as f: 75 | for line in f: 76 | example = json.loads(line) 77 | question = example["question_text"] 78 | answer_spans = [span for obj in example["answers_objects"] for span in obj["spans"]] 79 | item = {"question": question, "answer": answer_spans} 80 | prompt_text = prompt.replace("${question}", question) 81 | prompts.append([{"role": "user", "content": prompt_text.strip()}]) 82 | contexts.append(item) 83 | 84 | print(dataset, len(prompts), len(contexts)) 85 | examples = [] 86 | output_dir = f"{args.save_dir}/{dataset}/prompts_decompose_test_t{args.temperature}_{model_name}" 87 | os.makedirs(output_dir, exist_ok=True) 88 | 89 | for i in trange(len(prompts)): 90 | text = tokenizer.apply_chat_template(prompts[i], tokenize=False, add_generation_prompt=True) 91 | print(text) 92 | 93 | N_samples = 1 if args.temperature == 0 else 3 94 | for j in range(N_samples): 95 | ctx = copy.deepcopy(contexts[i]) 96 | outputs = llm.generate([text], sampling_params) 97 | generated_text = outputs[0].outputs[0].text 98 | if j == 0: 99 | print(len(outputs)) 100 | print('======\n', generated_text, '\n======') 101 | 102 | decomposed_questions = [] 103 | for line in generated_text.strip().split("\n"): 104 | line = line.strip() 105 | if line.startswith("### Q"): 106 | try: 107 | question_part, context_part = line.split("## Need Context? ##") 108 | question_text = question_part.split(":", 1)[1].strip() 109 | needs_context = context_part.strip().lower().startswith("yes") 110 | q_label = "Q" + question_part.split(":")[0].split("Q")[-1].strip() 111 | decomposed_questions.append({ 112 | "label": q_label, 113 | "text": question_text, 114 | "needs_context": needs_context 115 | }) 116 | except: 117 | decomposed_questions = "Error" 118 | print(f"Error in decomposing \n\n {generated_text} \n\n") 119 | break 120 | 121 | ctx["question_id"] = i 122 | ctx["decompose_id"] = j 123 | ctx["decomposed"] = decomposed_questions 124 | examples.append(ctx) 125 | 126 | with open(f"{output_dir}/generate.jsonl", "w") as f: 127 | for example in examples: 128 | f.write(json.dumps(example) + "\n") 129 | 130 | --------------------------------------------------------------------------------
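For reference, each line of the `generate.jsonl` file produced by `slm_decompose.py` above (and consumed by `llm_reader.py`) is a JSON object with the keys shown below; the keys follow the code, while the values are invented illustrations.
```
# Illustrative record written by slm_decompose.py and read by llm_reader.py (values are made up).
example_record = {
    "question": "Which state is WXBX located in?",   # original multi-hop question from test.jsonl
    "answer": ["Virginia"],                           # gold answer span(s) collected from answers_objects
    "question_id": 0,                                 # position of the question in the test file
    "decompose_id": 0,                                # sample index (up to 3 samples when temperature > 0)
    "decomposed": [                                   # parsed sub-questions, or the string "Error" on a parse failure
        {"label": "Q1", "text": "Which city is WXBX located in?", "needs_context": True},
        {"label": "Q2", "text": "Which state is #1 in?", "needs_context": True},
    ],
}
```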