├── .gitignore ├── LICENSE ├── README.md ├── eval ├── onehop │ └── get_onehop_union.py ├── reranking │ └── rerank.py └── retrieval │ ├── bm25.py │ ├── build_index.py │ ├── e5.py │ ├── evaluate_index.py │ ├── grit.py │ ├── gtr.py │ ├── instructor.py │ └── kv_store.py └── utils ├── openai_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # LitSearch 165 | results/ 166 | figures/ 167 | query_sets/ 168 | corpus/ 169 | retrieval_indices/ 170 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Princeton Natural Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LitSearch 2 | 3 | This repository contains the code and data for paper [LitSearch: A Retrieval Benchmark for Scientific **Lit**erature **Search**](https://arxiv.org/abs/2407.18940). In this paper, we introduce a benchmark consisting of a set of 597 realistic literature search queries about recent ML and NLP papers. We provide the code we used for benchmarking state-of-the-art retrieval models and two LLM-based reranking pipelines. 
4 | 5 | LitSearch 6 | 7 | ## Requirements 8 | Please install the latest versions of PyTorch (`torch`), NumPy (`numpy`), HuggingFace Transformers (`transformers`), HuggingFace Datasets (`datasets`), SentenceTransformers (`sentence-transformers`), InstructorEmbedding (`InstructorEmbedding`), Rank-BM25 (`rank-bm25`), GritLM (`gritlm`) and the OpenAI API package (`openai`). This codebase has been tested with `torch==1.13.1`, `numpy==1.23.5`, `transformers==4.30.2`, `datasets==2.20.0`, `sentence-transformers==2.2.2`, `InstructorEmbedding==1.0.1`, `rank-bm25==0.2.2`, `gritlm==1.0.0` and `openai==1.33.0` under Python 3.10.14. 9 | 10 | Note: We used a standalone environment for GritLM since its dependencies were incompatible with those of the other packages. 11 | 12 | ## Data 13 | We provide the LitSearch query set and retrieval corpus as separate HuggingFace `datasets` configurations under [`princeton-nlp/LitSearch`](https://huggingface.co/datasets/princeton-nlp/LitSearch). We also provide the retrieval corpus in the Semantic Scholar Open Research Corpus (S2ORC) format along with all available metadata to facilitate exploration of retrieval strategies more advanced than the ones we implement in this codebase. The data can be loaded with the `datasets` package as follows: 14 | ```python 15 | from datasets import load_dataset 16 | 17 | query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full") 18 | corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full") 19 | corpus_s2orc_data = load_dataset("princeton-nlp/LitSearch", "corpus_s2orc", split="full") 20 | ``` 21 | 22 | ## Code Structure 23 | * `eval/retrieval/` 24 | * Contains the parent class for retrievers in `kv_store.py` and implementations of five retrieval pipelines: BM25 (`bm25.py`), GTR (`gtr.py`), Instructor (`instructor.py`), E5 (`e5.py`) and GRIT (`grit.py`). 25 | * Contains `build_index.py` for building a retrieval index of the specified type over a given retrieval corpus. 26 | * Contains `evaluate_index.py` for evaluating a retriever on a query set using the associated retrieval index. 27 | * `eval/reranking/rerank.py` contains code for reranking a provided set of retrieval results using GPT-4. This code is adapted from [Rank-GPT](https://github.com/sunnweiwei/RankGPT). 28 | * `eval/onehop/get_onehop_union.py` implements the first stage of the one-hop reranking operation described in Section 3.2 of our paper. Once the union is computed using this script, GPT-4-based reranking is applied as before using `eval/reranking/rerank.py`. 29 | 30 | ## Evaluation 31 | This repository supports running evaluations with the BM25, GTR, Instructor, E5 and GRIT retrievers, reranking with GPT-4, and executing a one-hop reranking strategy.
We provide sample commands for running the corresponding scripts: 32 | 33 | #### Build retrieval index 34 | ```bash 35 | python3 -m eval.retrieval.build_index --index_type bm25 --key title_abstract 36 | ``` 37 | 38 | #### Run retrieval using the built index 39 | ```bash 40 | python3 -m eval.retrieval.evaluate_index --index_name LitSearch.title_abstract.bm25 41 | ``` 42 | 43 | #### Run GPT-4-based reranking 44 | ```bash 45 | python3 -m eval.reranking.rerank --retrieval_results_file results/retrieval/LitSearch.title_abstract.bm25.jsonl 46 | ``` 47 | 48 | #### Run the one-hop strategy (union + reranking) 49 | ```bash 50 | python3 -m eval.onehop.get_onehop_union --input_path results/retrieval/LitSearch.title_abstract.bm25.jsonl 51 | python3 -m eval.reranking.rerank --retrieval_results_file results/onehop/prereranking/LitSearch.title_abstract.bm25.union.jsonl --output_dir results/onehop/postreranking --max_k 200 52 | ``` 53 | 54 | ## Bugs or Questions? 55 | 56 | If you have any questions related to the code or the paper, feel free to email Anirudh (`anirudh.ajith@princeton.edu`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please describe the problem in detail so we can help you better and faster! 57 | 58 | ## Citation 59 | 60 | Please cite our paper if you use LitSearch in your work: 61 | ```bibtex 62 | @inproceedings{ajith2024litsearch, 63 | title={LitSearch: A Retrieval Benchmark for Scientific Literature Search}, 64 | author={Ajith, Anirudh and Xia, Mengzhou and Chevalier, Alexis and Goyal, Tanya and Chen, Danqi and Gao, Tianyu}, 65 | booktitle={Empirical Methods in Natural Language Processing (EMNLP)}, 66 | year={2024} 67 | } 68 | 69 | ``` 70 | -------------------------------------------------------------------------------- /eval/onehop/get_onehop_union.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import argparse 4 | import datasets 5 | from tqdm import tqdm 6 | from utils import utils 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input_path", type=str, required=True) 10 | 11 | parser.add_argument("--output_dir", required=False, default="results/onehop/prereranking") 12 | parser.add_argument("--cutoff_count", type=int, required=False, default=50) 13 | parser.add_argument("--total_count", type=int, required=False, default=200) 14 | parser.add_argument("--dataset_path", required=False, default="princeton-nlp/LitSearch") 15 | args = parser.parse_args() 16 | 17 | original_retrieval_results = utils.read_json(args.input_path) 18 | corpus_data = utils.get_clean_dict(datasets.load_dataset(args.dataset_path, "corpus_clean", split="full")) 19 | 20 | union_retrieval_results = [] 21 | for result in tqdm(original_retrieval_results): 22 | original_retrieved_ids = result["retrieved"][:args.cutoff_count] 23 | union_retrieved_ids = copy.deepcopy(original_retrieved_ids) 24 | citation_counts = [] 25 | for paper_id in original_retrieved_ids: 26 | cited_papers = utils.get_clean_citations(corpus_data[paper_id]) 27 | citation_counts.append(len(cited_papers)) 28 | for cited_paper in cited_papers: 29 | if cited_paper not in union_retrieved_ids: 30 | union_retrieved_ids.append(cited_paper) 31 | if len(union_retrieved_ids) >= args.total_count: 32 | break 33 | 34 | result["original_retrieved"] = result["retrieved"] 35 | result["retrieved"] = union_retrieved_ids[:args.total_count] 36 | union_retrieval_results.append(result) 37 | 38 | os.makedirs(args.output_dir,
exist_ok=True) 39 | output_path = os.path.join(args.output_dir, os.path.basename(args.input_path).replace(".json", ".union.json")) 40 | utils.write_json(union_retrieval_results, output_path) 41 | -------------------------------------------------------------------------------- /eval/reranking/rerank.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import argparse 5 | import datasets 6 | from tqdm import tqdm 7 | from typing import List, Tuple 8 | from utils import utils 9 | from utils.openai_utils import OPENAIBaseEngine 10 | 11 | ###### QUERY CONSTRUCTION FUNCTIONS ###### 12 | def create_prompt_messages(item: dict, rank_start: int, rank_end: int, index_type: str) -> List[dict]: 13 | query = item['query'] 14 | num_docs = len(item['documents'][rank_start:rank_end]) 15 | 16 | if index_type == "title_abstract": 17 | messages = [{'role': 'system', 'content': "You are RankGPT, an intelligent assistant that can rank papers based on their relevancy to a research query."}, 18 | {'role': 'user', 'content': f"I will provide you with the abstracts of {num_docs} papers, each indicated by number identifier []. \nRank the papers based on their relevance to research question: {query}."}, 19 | {'role': 'assistant', 'content': 'Okay, please provide the papers.'}] 20 | max_length = 300 21 | elif index_type == "full_paper": 22 | messages = [{'role': 'system', 'content': "You are RankGPT, an intelligent assistant that can rank papers based on their relevancy to a research query."}, 23 | {'role': 'user', 'content': f"I will provide you with {num_docs} papers, each indicated by number identifier []. \nRank the papers based on their relevance to research question: {query}."}, 24 | {'role': 'assistant', 'content': 'Okay, please provide the papers.'}] 25 | max_length = 10000 26 | else: 27 | raise ValueError(f"Invalid index type: {index_type}") 28 | 29 | for rank, document in enumerate(item['documents'][rank_start: rank_end]): 30 | content = document['content'].replace('Title: Content: ', '').strip() 31 | content = ' '.join(content.split()[:int(max_length)]) 32 | messages.append({'role': 'user', 'content': f"[{rank+1}] {content}"}) 33 | messages.append({'role': 'assistant', 'content': f'Received passage [{rank+1}].'}) 34 | postfix_prompt = f"Search Query: {query}. \nRank the {num_docs} papers above based on their relevance to the research query. The papers should be listed in descending order using identifiers. The most relevant papers should be listed first. The output format should be [] > [], e.g., [1] > [2]. Only respond with the ranking results, do not say any words or explain." 
35 | messages.append({'role': 'user', 'content': postfix_prompt}) 36 | return messages 37 | 38 | ###### RESPONSE PROCESSING FUNCTIONS ###### 39 | def clean_response(response: str): 40 | new_response = '' 41 | for c in response: 42 | new_response += (c if c.isdigit() else ' ') 43 | new_response = new_response.strip() 44 | return new_response 45 | 46 | def remove_duplicate(response): 47 | new_response = [] 48 | for c in response: 49 | if c not in new_response: 50 | new_response.append(c) 51 | return new_response 52 | 53 | def receive_permutation(item, permutation, rank_start, rank_end): 54 | response = clean_response(permutation) 55 | response = [int(x) - 1 for x in response.split()] 56 | response = remove_duplicate(response) 57 | cut_range = copy.deepcopy(item['documents'][rank_start: rank_end]) 58 | original_rank = [tt for tt in range(len(cut_range))] 59 | response = [ss for ss in response if ss in original_rank] 60 | response = response + [tt for tt in original_rank if tt not in response] 61 | for j, x in enumerate(response): 62 | item['documents'][j + rank_start] = copy.deepcopy(cut_range[x]) 63 | if 'rank' in item['documents'][j + rank_start]: 64 | item['documents'][j + rank_start]['rank'] = cut_range[j]['rank'] 65 | if 'score' in item['documents'][j + rank_start]: 66 | item['documents'][j + rank_start]['score'] = cut_range[j]['score'] 67 | return item 68 | 69 | def permutation_pipeline(model: OPENAIBaseEngine, item: dict, rank_start: int, rank_end: int, index_type: str) -> dict: 70 | decrement_rate = (rank_end - rank_start) // 5 71 | min_count = (rank_end - rank_start) // 2 72 | 73 | while rank_end - rank_start >= min_count: 74 | try: 75 | messages = create_prompt_messages(item, rank_start, rank_end, index_type) 76 | permutation = utils.prompt_gpt4_model(model, messages=messages) 77 | return receive_permutation(item, permutation, rank_start, rank_end) 78 | except Exception as e: # the context window might be overflowing; reduce the number of documents and try again; 79 | rank_end -= decrement_rate 80 | print(f"Error: context window overflow; reducing the number of documents to {rank_end - rank_start}") 81 | print(f"Error: unable to rerank the documents. 
Returning the original order.") 82 | return item 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--retrieval_results_file", type=str, required=True) 87 | 88 | parser.add_argument("--model", type=str, help="Simulator LLM", default="gpt-4-1106-preview") 89 | parser.add_argument("--max_k", default=100, type=int, help="Max number of retrieved documents to rerank") 90 | parser.add_argument("--output_dir", type=str, required=False, default="results/reranking/") 91 | parser.add_argument("--dataset_path", required=False, default="princeton-nlp/LitSearch") 92 | args = parser.parse_args() 93 | 94 | corpus_data = datasets.load_dataset(args.dataset_path, "corpus_clean", split="full") 95 | retrieval_results = utils.read_json(args.retrieval_results_file) 96 | model = utils.get_gpt4_model(args.model, azure=True) 97 | 98 | os.makedirs(args.output_dir, exist_ok=True) 99 | output_file = os.path.join(args.output_dir, os.path.basename(args.retrieval_results_file).replace(".json", ".reranked.json")) 100 | 101 | index_type = os.path.basename(args.retrieval_results_file).split(".")[1] 102 | if index_type == "title_abstract": 103 | corpusid_to_text = {utils.get_clean_corpusid(item): utils.get_clean_title_abstract(item) for item in corpus_data} 104 | elif index_type == "full_paper": 105 | corpusid_to_text = {utils.get_clean_corpusid(item): utils.get_clean_full_paper(item) for item in corpus_data} 106 | else: 107 | raise ValueError(f"Invalid index type: {index_type}") 108 | 109 | # truncate retrieval results to max_k 110 | for result in retrieval_results: 111 | result["retrieved"] = result["retrieved"][:args.max_k] 112 | 113 | ###### RERANKING ###### 114 | # put retrieval results into format required by reranking pipeline 115 | reranking_inputs = [] 116 | for query_info in retrieval_results: 117 | reranking_inputs.append({ 118 | "query": query_info["query"], 119 | "documents": [{ 120 | "content": corpusid_to_text[retrieved_corpusid], 121 | "corpusid": retrieved_corpusid 122 | } for retrieved_corpusid in query_info["retrieved"]] 123 | }) 124 | 125 | # rerank 126 | if not os.path.exists(output_file): 127 | reranking_outputs = copy.deepcopy(retrieval_results) 128 | utils.write_json(reranking_outputs, output_file) 129 | 130 | for item_idx, item in enumerate(tqdm(reranking_inputs)): 131 | reranking_outputs = utils.read_json(output_file) 132 | if "pre_reranked" not in reranking_outputs[item_idx]: 133 | reranked_item = permutation_pipeline(model, item, rank_start=0, rank_end=len(item["documents"]), index_type=index_type) 134 | reranking_outputs[item_idx]["pre_reranked"] = reranking_outputs[item_idx]["retrieved"] 135 | reranking_outputs[item_idx]["retrieved"] = [document["corpusid"] for document in reranked_item["documents"]] 136 | utils.write_json(reranking_outputs, output_file, silent=True) # save after each iteration 137 | -------------------------------------------------------------------------------- /eval/retrieval/bm25.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import numpy as np 3 | from tqdm import tqdm 4 | from rank_bm25 import BM25Okapi 5 | from typing import List, Tuple, Any 6 | from eval.retrieval.kv_store import KVStore 7 | from eval.retrieval.kv_store import TextType 8 | 9 | class BM25(KVStore): 10 | def __init__(self, index_name: str): 11 | super().__init__(index_name, 'bm25') 12 | 13 | nltk.download('punkt') 14 | nltk.download('stopwords') 15 | 16 | self._tokenizer = nltk.word_tokenize 17 | 
self._stop_words = set(nltk.corpus.stopwords.words('english')) 18 | self._stemmer = nltk.stem.PorterStemmer().stem 19 | self.index = None # BM25 index 20 | 21 | def _encode_batch(self, texts: str, type: TextType, show_progress_bar: bool = True) -> List[str]: 22 | # lowercase, tokenize, remove stopwords, and stem 23 | tokens_list = [] 24 | for text in tqdm(texts, disable=not show_progress_bar): 25 | tokens = self._tokenizer(text.lower()) 26 | tokens = [token for token in tokens if token not in self._stop_words] 27 | tokens = [self._stemmer(token) for token in tokens] 28 | tokens_list.append(tokens) 29 | return tokens_list 30 | 31 | def _query(self, encoded_query: List[str], n: int) -> List[int]: 32 | top_indices = np.argsort(self.index.get_scores(encoded_query))[::-1][:n].tolist() 33 | return top_indices 34 | 35 | def clear(self) -> None: 36 | super().clear() 37 | self.index = None 38 | 39 | def create_index(self, key_value_pairs: List[Tuple[str, Any]]) -> None: 40 | super().create_index(key_value_pairs) 41 | self.index = BM25Okapi(self.encoded_keys) 42 | 43 | def load(self, dir_name: str) -> None: 44 | super().load(dir_name) 45 | self._tokenizer = nltk.word_tokenize 46 | self._stop_words = set(nltk.corpus.stopwords.words('english')) 47 | self._stemmer = nltk.stem.PorterStemmer().stem 48 | return self -------------------------------------------------------------------------------- /eval/retrieval/build_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datasets 4 | from typing import List 5 | from utils import utils 6 | from eval.retrieval.kv_store import KVStore 7 | 8 | def get_index_name(args: argparse.Namespace) -> str: 9 | return os.path.basename(args.dataset_path) + "." 
+ args.key 10 | 11 | def create_index(args: argparse.Namespace) -> KVStore: 12 | index_name = get_index_name(args) 13 | 14 | if args.index_type == "bm25": 15 | from eval.retrieval.bm25 import BM25 16 | index = BM25(index_name) 17 | elif args.index_type == "instructor": 18 | from eval.retrieval.instructor import Instructor 19 | if args.key == "title_abstract": 20 | query_instruction = "Represent the research question for retrieving relevant research paper abstracts:" 21 | key_instruction = "Represent the title and abstract of the research paper for retrieval:" 22 | elif args.key == "full_paper": 23 | query_instruction = "Represent the research question for retrieving relevant research papers:" 24 | key_instruction = "Represent the research paper for retrieval:" 25 | elif args.key == "paragraphs": 26 | query_instruction = "Represent the research question for retrieving passages from relevant research papers:" 27 | key_instruction = "Represent the passage from the research paper for retrieval:" 28 | else: 29 | raise ValueError("Invalid key") 30 | index = Instructor(index_name, key_instruction, query_instruction) 31 | elif args.index_type == "e5": 32 | from eval.retrieval.e5 import E5 33 | index = E5(index_name) 34 | elif args.index_type == "gtr": 35 | from eval.retrieval.gtr import GTR 36 | index = GTR(index_name) 37 | elif args.index_type == "grit": 38 | from eval.retrieval.grit import GRIT 39 | if args.key == "title_abstract": 40 | raw_instruction = "Given a research query, retrieve the title and abstract of the relevant research paper" 41 | elif args.key == "full_paper": 42 | raw_instruction = "Given a research query, retrieve the relevant research paper" 43 | elif args.key == "paragraphs": 44 | raw_instruction = "Given a research query, retrieve the passage from the relevant research paper" 45 | else: 46 | raise ValueError("Invalid key") 47 | index = GRIT(index_name, raw_instruction) 48 | else: 49 | raise ValueError("Invalid index type") 50 | return index 51 | 52 | def create_kv_pairs(data: List[dict], key: str) -> dict: 53 | if key == "title_abstract": 54 | kv_pairs = {utils.get_clean_title_abstract(record): utils.get_clean_corpusid(record) for record in data} 55 | elif key == "full_paper": 56 | kv_pairs = {utils.get_clean_full_paper(record): utils.get_clean_corpusid(record) for record in data} 57 | elif key == "paragraphs": 58 | kv_pairs = {} 59 | for record in data: 60 | corpusid = utils.get_clean_corpusid(record) 61 | paragraphs = utils.get_clean_paragraphs(record) 62 | for paragraph_idx, paragraph in enumerate(paragraphs): 63 | kv_pairs[paragraph] = (corpusid, paragraph_idx) 64 | else: 65 | raise ValueError("Invalid key") 66 | return kv_pairs 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("--index_type", required=True) # bm25, instructor, e5, gtr, grit 70 | parser.add_argument("--key", required=True) # title_abstract, full_paper, paragraphs 71 | 72 | parser.add_argument("--dataset_path", required=False, default="princeton-nlp/LitSearch") 73 | parser.add_argument("--index_root_dir", required=False, default="retrieval_indices") 74 | args = parser.parse_args() 75 | 76 | corpus_data = datasets.load_dataset(args.dataset_path, "corpus_clean", split="full") 77 | index = create_index(args) 78 | kv_pairs = create_kv_pairs(corpus_data, args.key) 79 | index.create_index(kv_pairs) 80 | 81 | index_name = get_index_name(args) 82 | index.save(args.index_root_dir) 83 | -------------------------------------------------------------------------------- /eval/retrieval/e5.py:
-------------------------------------------------------------------------------- 1 | import sentence_transformers 2 | import numpy as np 3 | from typing import List, Any 4 | from sklearn.metrics.pairwise import cosine_similarity 5 | from eval.retrieval.kv_store import KVStore 6 | from eval.retrieval.kv_store import TextType 7 | from utils import utils 8 | 9 | class E5(KVStore): 10 | def __init__(self, index_name: str, model_path: str = "intfloat/e5-large-v2"): 11 | super().__init__(index_name, 'e5') 12 | self.model_path = model_path 13 | self._model = sentence_transformers.SentenceTransformer(model_path, device="cuda", cache_folder=utils.get_cache_dir()).bfloat16() 14 | 15 | def _format_text(self, text: str, type: TextType) -> str: 16 | if type == TextType.KEY: 17 | text = "passage: " + text 18 | elif type == TextType.QUERY: 19 | text = "query: " + text 20 | else: 21 | raise ValueError("Invalid TextType") 22 | return text 23 | 24 | def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]: 25 | texts = [self._format_text(text, type) for text in texts] 26 | return self._model.encode(texts, batch_size=256, normalize_embeddings=True, show_progress_bar=show_progress_bar).astype(np.float16) 27 | 28 | def _query(self, encoded_query: Any, n: int) -> List[int]: 29 | cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0] 30 | top_indices = cosine_similarities.argsort()[-n:][::-1] 31 | return top_indices 32 | 33 | def load(self, path: str): 34 | super().load(path) 35 | self._model = sentence_transformers.SentenceTransformer(self.model_path, device="cuda", cache_folder=utils.get_cache_dir()) 36 | return self 37 | 38 | -------------------------------------------------------------------------------- /eval/retrieval/evaluate_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datasets 4 | from tqdm import tqdm 5 | from utils import utils 6 | from eval.retrieval.kv_store import KVStore 7 | 8 | def load_index(index_path: str) -> KVStore: 9 | index_type = os.path.basename(index_path).split(".")[-1] 10 | if index_type == "bm25": 11 | from eval.retrieval.bm25 import BM25 12 | index = BM25(None).load(index_path) 13 | elif index_type == "instructor": 14 | from eval.retrieval.instructor import Instructor 15 | index = Instructor(None, None, None).load(index_path) 16 | elif index_type == "e5": 17 | from eval.retrieval.e5 import E5 18 | index = E5(None).load(index_path) 19 | elif index_type == "gtr": 20 | from eval.retrieval.gtr import GTR 21 | index = GTR(None).load(index_path) 22 | elif index_type == "grit": 23 | from eval.retrieval.grit import GRIT 24 | index = GRIT(None, None).load(index_path) 25 | else: 26 | raise ValueError("Invalid index type") 27 | return index 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--index_name", type=str, required=True) 31 | 32 | parser.add_argument("--top_k", type=int, required=False, default=200) 33 | parser.add_argument("--retrieval_results_root_dir", type=str, required=False, default="results/retrieval") 34 | parser.add_argument("--index_root_dir", type=str, required=False, default="retrieval_indices") 35 | parser.add_argument("--dataset_path", required=False, default="princeton-nlp/LitSearch") 36 | args = parser.parse_args() 37 | 38 | index = load_index(os.path.join(args.index_root_dir, args.index_name)) 39 | query_set = [query for query in datasets.load_dataset(args.dataset_path, "query", 
split="full")] 40 | for query in tqdm(query_set): 41 | query_text = query["query"] 42 | top_k = index.query(query_text, args.top_k) 43 | query["retrieved"] = top_k 44 | 45 | os.makedirs(args.retrieval_results_root_dir, exist_ok=True) 46 | output_path = os.path.join(args.retrieval_results_root_dir, f"{args.index_name}.jsonl") 47 | utils.write_json(query_set, output_path) 48 | -------------------------------------------------------------------------------- /eval/retrieval/grit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List, Any 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | from gritlm import GritLM 5 | from eval.retrieval.kv_store import KVStore 6 | from eval.retrieval.kv_store import TextType 7 | 8 | class GRIT(KVStore): 9 | def __init__(self, index_name: str, raw_instruction: str, model_path: str = "GritLM/GritLM-7B"): 10 | super().__init__(index_name, 'grit') 11 | self.model_path = model_path 12 | self.raw_instruction = raw_instruction 13 | self._model = GritLM(model_path, torch_dtype="auto", device_map="auto", mode="embedding") 14 | 15 | def _get_instruction(self, type: TextType) -> str: 16 | if type == TextType.KEY: 17 | return "<|embed|>\n" 18 | elif type == TextType.QUERY: 19 | return "<|user|>\n" + self.raw_instruction + "\n<|embed|>\n" 20 | else: 21 | raise ValueError("Invalid TextType") 22 | 23 | def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]: 24 | return self._model.encode(texts, batch_size=256, instruction=self._get_instruction(type), show_progress_bar=show_progress_bar).astype(np.float16) 25 | 26 | def _query(self, encoded_query: Any, n: int) -> List[int]: 27 | try: 28 | cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0] 29 | except: 30 | for i, encoded_key in enumerate(self.encoded_keys): 31 | if np.any(np.isnan(encoded_key)): 32 | self.encoded_keys[i] = np.zeros_like(encoded_key) 33 | cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0] 34 | top_indices = cosine_similarities.argsort()[-n:][::-1] 35 | return top_indices 36 | 37 | def load(self, path: str): 38 | super().load(path) 39 | self._model = GritLM(self.model_path, torch_dtype="auto", device_map="auto", mode="embedding") 40 | return self 41 | 42 | -------------------------------------------------------------------------------- /eval/retrieval/gtr.py: -------------------------------------------------------------------------------- 1 | import sentence_transformers 2 | import numpy as np 3 | from typing import List, Any 4 | from sklearn.metrics.pairwise import cosine_similarity 5 | from eval.retrieval.kv_store import KVStore 6 | from eval.retrieval.kv_store import TextType 7 | from utils import utils 8 | 9 | class GTR(KVStore): 10 | def __init__(self, index_name: str, model_path: str = "sentence-transformers/gtr-t5-large"): 11 | super().__init__(index_name, 'gtr') 12 | self.model_path = model_path 13 | self._model = sentence_transformers.SentenceTransformer(model_path, device="cuda", cache_folder=utils.get_cache_dir()) 14 | 15 | def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]: 16 | return self._model.encode(texts, batch_size=256, show_progress_bar=show_progress_bar).astype(np.float16) 17 | 18 | def _query(self, encoded_query: Any, n: int) -> List[int]: 19 | cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0] 20 | top_indices = 
cosine_similarities.argsort()[-n:][::-1] 21 | return top_indices 22 | 23 | def load(self, path: str): 24 | super().load(path) 25 | self._model = sentence_transformers.SentenceTransformer(self.model_path, device="cuda", cache_folder=utils.get_cache_dir()) 26 | return self 27 | 28 | -------------------------------------------------------------------------------- /eval/retrieval/instructor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List, Any 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | from InstructorEmbedding import INSTRUCTOR 5 | from utils import utils 6 | from eval.retrieval.kv_store import KVStore 7 | from eval.retrieval.kv_store import TextType 8 | 9 | class Instructor(KVStore): 10 | def __init__(self, index_name: str, key_instruction: str, query_instruction: str, model_path: str = "hkunlp/instructor-xl"): 11 | super().__init__(index_name, 'instructor') 12 | self.model_path = model_path 13 | self.key_instruction = key_instruction 14 | self.query_instruction = query_instruction 15 | self._model = INSTRUCTOR(model_path, device="cuda", cache_folder=utils.get_cache_dir()) 16 | 17 | def _format_text(self, text: str, type: TextType) -> List[str]: 18 | if type == TextType.KEY: 19 | return [self.key_instruction, text] 20 | elif type == TextType.QUERY: 21 | return [self.query_instruction, text] 22 | else: 23 | raise ValueError("Invalid TextType") 24 | 25 | def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]: 26 | texts = [self._format_text(text, type) for text in texts] 27 | return self._model.encode(texts, batch_size=128, normalize_embeddings=True, show_progress_bar=show_progress_bar).astype(np.float16) 28 | 29 | def _query(self, encoded_query: Any, n: int) -> List[int]: 30 | cosine_similarities = cosine_similarity([encoded_query], self.encoded_keys)[0] 31 | top_indices = cosine_similarities.argsort()[-n:][::-1] 32 | return top_indices 33 | 34 | def load(self, path: str): 35 | super().load(path) 36 | self._model = INSTRUCTOR(self.model_path, device="cuda", cache_folder=utils.get_cache_dir()) 37 | return self 38 | 39 | -------------------------------------------------------------------------------- /eval/retrieval/kv_store.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | from enum import Enum 5 | from typing import List, Tuple, Any 6 | 7 | class TextType(Enum): 8 | KEY = 1 9 | QUERY = 2 10 | 11 | class KVStore: 12 | def __init__(self, index_name: str, index_type: str) -> None: 13 | self.index_name = index_name 14 | self.index_type = index_type 15 | 16 | self.keys = [] 17 | self.encoded_keys = [] 18 | self.values = [] 19 | 20 | def __len__(self) -> int: 21 | return len(self.keys) 22 | 23 | def _encode(self, text: str, type: TextType) -> Any: 24 | return self._encode_batch([text], type, show_progress_bar=False)[0] 25 | 26 | def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]: 27 | raise NotImplementedError 28 | 29 | def _query(self, encoded_query: Any, n: int) -> List[int]: 30 | raise NotImplementedError 31 | 32 | def clear(self) -> None: 33 | self.keys = [] 34 | self.encoded_keys = [] 35 | self.values = [] 36 | 37 | def create_index(self, key_value_pairs: List[Tuple[str, Any]]) -> None: 38 | if len(self.keys) > 0: 39 | raise ValueError("Index is not empty. 
Please create a new index or clear the existing one.") 40 | 41 | for key, value in tqdm(key_value_pairs.items(), desc=f"Creating {self.index_name} index"): 42 | self.keys.append(key) 43 | self.values.append(value) 44 | self.encoded_keys = self._encode_batch(self.keys, TextType.KEY) 45 | 46 | def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]: 47 | encoded_query = self._encode(query_text, TextType.QUERY) 48 | indices = self._query(encoded_query, n) 49 | if return_keys: 50 | results = [(self.keys[i], self.values[i]) for i in indices] 51 | else: 52 | results = [self.values[i] for i in indices] 53 | return results 54 | 55 | def save(self, dir_name: str) -> None: 56 | save_dict = {} 57 | for key, value in self.__dict__.items(): 58 | if key[0] != "_": 59 | save_dict[key] = value 60 | 61 | print(f"Saving index to {os.path.join(dir_name, f'{self.index_name}.{self.index_type}')}") 62 | os.makedirs(dir_name, exist_ok=True) 63 | with open(os.path.join(dir_name, f"{self.index_name}.{self.index_type}"), 'wb') as file: 64 | pickle.dump(save_dict, file, protocol=pickle.HIGHEST_PROTOCOL) 65 | 66 | 67 | def load(self, file_path: str) -> None: 68 | if len(self.keys) > 0: 69 | raise ValueError("Index is not empty. Please create a new index or clear the existing one before loading from disk.") 70 | 71 | print(f"Loading index from {file_path}...") 72 | with open(file_path, 'rb') as file: 73 | pickle_data = pickle.load(file) 74 | 75 | for key, value in pickle_data.items(): 76 | setattr(self, key, value) -------------------------------------------------------------------------------- /utils/openai_utils.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from typing import List 3 | 4 | class OPENAIBaseEngine(): 5 | def __init__(self, model_name: str, azure: bool = True): 6 | if not azure: 7 | raise NotImplementedError("Only Azure API is supported") 8 | 9 | self.model_name = model_name 10 | self.client = openai.AzureOpenAI() 11 | 12 | def safe_completion(self, messages: List[dict], max_tokens: int = 2000, temperature: float = 0, top_p: float = 1): 13 | args_dict = { 14 | "max_tokens": max_tokens, 15 | "temperature": temperature, 16 | "top_p": top_p, 17 | } 18 | if top_p == 1.0: 19 | args_dict.pop("top_p") 20 | 21 | response = self.client.chat.completions.create(model=self.model_name, messages=messages, **args_dict).to_dict() 22 | return { 23 | "finish_reason": response["choices"][0]["finish_reason"], 24 | "content": response["choices"][0]["message"]["content"] 25 | } 26 | 27 | def test_api(self): 28 | print("Testing API connection") 29 | messages = [{"role": "user", "content": "Why did the chicken cross the road?"}] 30 | response = self.safe_completion(messages=messages, max_tokens=20, temperature=0, top_p=1.0) 31 | content = response["content"] 32 | 33 | if response["finish_reason"] == 'api_error': 34 | print(f'Error in connecting to API: {response}') 35 | else: 36 | print(f'Successful API connection: {content}') 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Any, Tuple 4 | from datasets import Dataset 5 | from utils.openai_utils import OPENAIBaseEngine 6 | 7 | ##### file reading and writing ##### 8 | 9 | def read_json(filename: str, silent: bool = False) -> List[Any]: 10 | with open(filename, 'r') as file: 11 | 
if filename.endswith(".json"): 12 | data = json.load(file) 13 | elif filename.endswith(".jsonl"): 14 | data = [json.loads(line) for line in file] 15 | else: 16 | raise ValueError("Input file must be either a .json or .jsonl file") 17 | 18 | if not silent: 19 | print(f"Loaded {len(data)} records from {filename}") 20 | return data 21 | 22 | def write_json(data: List[Any], filename: str, silent: bool = False) -> None: 23 | if filename.endswith(".json"): 24 | with open(filename, 'w') as file: 25 | json.dump(data, file, indent=4) 26 | elif filename.endswith(".jsonl"): 27 | with open(filename, 'w') as file: 28 | for item in data: 29 | file.write(json.dumps(item) + "\n") 30 | else: 31 | raise ValueError("Output file must be either a .json or .jsonl file") 32 | 33 | if not silent: 34 | print(f"Saved {len(data)} records to {filename}") 35 | 36 | def read_txt(filename: str) -> str: 37 | with open(filename, 'r') as file: 38 | text = file.read() 39 | return text 40 | 41 | ##### evaluation metrics ##### 42 | 43 | def calculate_recall(retrieved: List[int], relevant_docs: List[int]) -> float: 44 | num_relevant_retrieved = len(set(retrieved).intersection(set(relevant_docs))) 45 | num_relevant = len(relevant_docs) 46 | recall = num_relevant_retrieved / num_relevant if num_relevant > 0 else 0 47 | return recall 48 | 49 | def calculate_ndcg(retrieved: List[int], relevant_docs: List[int]) -> float: 50 | dcg = 0 51 | for idx, docid in enumerate(retrieved): 52 | if docid in relevant_docs: 53 | dcg += 1 / (idx + 1) 54 | idcg = sum([1 / (idx + 1) for idx in range(len(relevant_docs))]) 55 | ndcg = dcg / idcg if idcg > 0 else 0 56 | return ndcg 57 | 58 | def calculate_ngram_overlap(query: str, text: str) -> float: 59 | query_ngrams = set(query.split()) 60 | text_ngrams = set(text.split()) 61 | overlap = len(query_ngrams.intersection(text_ngrams)) / len(query_ngrams) 62 | return overlap 63 | 64 | ##### reading fields from corpus_s2orc ##### 65 | 66 | def get_s2orc_corpusid(item: dict) -> int: 67 | return item['corpusid'] 68 | 69 | def get_s2orc_title(item: dict) -> str: 70 | try: 71 | title_info = json.loads(item['content']['annotations']['title']) 72 | title_start, title_end = title_info[0]['start'], title_info[0]['end'] 73 | return get_s2orc_text(item, title_start, title_end) 74 | except: 75 | return "" 76 | 77 | def get_s2orc_abstract(item: dict) -> str: 78 | try: 79 | abstract_info = json.loads(item['content']['annotations']['abstract']) 80 | abstract_start, abstract_end = abstract_info[0]['start'], abstract_info[0]['end'] 81 | return get_s2orc_text(item, abstract_start, abstract_end) 82 | except: 83 | return "" 84 | 85 | def get_s2orc_title_abstract(item: dict) -> str: 86 | title = get_s2orc_title(item) 87 | abstract = get_s2orc_abstract(item) 88 | return f"Title: {title}\nAbstract: {abstract}" 89 | 90 | def get_s2orc_full_paper(item: dict) -> str: 91 | if "content" in item and "text" in item['content'] and item['content']['text'] is not None: 92 | return item['content']['text'] 93 | else: 94 | return "" 95 | 96 | def get_s2orc_paragraph_indices(item: dict) -> List[Tuple[int, int]]: 97 | text = get_s2orc_full_paper(item) 98 | paragraph_indices = [] 99 | paragraph_start = 0 100 | paragraph_end = 0 101 | while paragraph_start < len(text): 102 | paragraph_end = text.find("\n\n", paragraph_start) 103 | if paragraph_end == -1: 104 | paragraph_end = len(text) 105 | paragraph_indices.append((paragraph_start, paragraph_end)) 106 | paragraph_start = paragraph_end + 2 107 | return paragraph_indices 108 | 109 | def 
get_s2orc_text(item: dict, start_idx: int, end_idx: int) -> str: 110 | assert start_idx >= 0 and end_idx >= 0 111 | assert start_idx <= end_idx 112 | assert end_idx <= len(item['content']['text']) 113 | if "content" in item and "text" in item['content']: 114 | return item['content']['text'][start_idx:end_idx] 115 | else: 116 | return "" 117 | 118 | def get_s2orc_paragraphs(item: dict, min_words: int = 10) -> List[str]: 119 | paragraph_indices = get_s2orc_paragraph_indices(item) 120 | paragraphs = [get_s2orc_text(item, paragraph_start, paragraph_end) for paragraph_start, paragraph_end in paragraph_indices] 121 | paragraphs = [paragraph for paragraph in paragraphs if len(paragraph.split()) >= min_words] 122 | return paragraphs 123 | 124 | def get_s2orc_citations(item: dict, corpus_data: dict = None) -> List[int]: 125 | try: 126 | bibentry_string = item['content']['annotations']['bibentry'] 127 | bibentry_data = json.loads(bibentry_string) 128 | citations = set() 129 | for ref in bibentry_data: 130 | if "attributes" in ref and "matched_paper_id" in ref["attributes"]: 131 | if (corpus_data is None) or (ref["attributes"]["matched_paper_id"] in corpus_data): 132 | citations.add(ref["attributes"]["matched_paper_id"]) 133 | return list(citations) 134 | except: 135 | return [] 136 | 137 | def get_s2orc_dict(data: Dataset) -> dict: 138 | return {get_s2orc_corpusid(item): item for item in data} 139 | 140 | ##### reading fields from corpus_clean ##### 141 | 142 | def get_clean_corpusid(item: dict) -> int: 143 | return item['corpusid'] 144 | 145 | def get_clean_title(item: dict) -> str: 146 | return item['title'] 147 | 148 | def get_clean_abstract(item: dict) -> str: 149 | return item['abstract'] 150 | 151 | def get_clean_title_abstract(item: dict) -> str: 152 | title = get_clean_title(item) 153 | abstract = get_clean_abstract(item) 154 | return f"Title: {title}\nAbstract: {abstract}" 155 | 156 | def get_clean_full_paper(item: dict) -> str: 157 | return item['full_paper'] 158 | 159 | def get_clean_paragraph_indices(item: dict) -> List[Tuple[int, int]]: 160 | text = get_clean_full_paper(item) 161 | paragraph_indices = [] 162 | paragraph_start = 0 163 | paragraph_end = 0 164 | while paragraph_start < len(text): 165 | paragraph_end = text.find("\n\n", paragraph_start) 166 | if paragraph_end == -1: 167 | paragraph_end = len(text) 168 | paragraph_indices.append((paragraph_start, paragraph_end)) 169 | paragraph_start = paragraph_end + 2 170 | return paragraph_indices 171 | 172 | def get_clean_text(item: dict, start_idx: int, end_idx: int) -> str: 173 | text = get_clean_full_paper(item) 174 | assert start_idx >= 0 and end_idx >= 0 175 | assert start_idx <= end_idx 176 | assert end_idx <= len(text) 177 | return text[start_idx:end_idx] 178 | 179 | def get_clean_paragraphs(item: dict, min_words: int = 10) -> List[str]: 180 | paragraph_indices = get_clean_paragraph_indices(item) 181 | paragraphs = [get_clean_text(item, paragraph_start, paragraph_end) for paragraph_start, paragraph_end in paragraph_indices] 182 | paragraphs = [paragraph for paragraph in paragraphs if len(paragraph.split()) >= min_words] 183 | return paragraphs 184 | 185 | def get_clean_citations(item: dict) -> List[int]: 186 | return item['citations'] 187 | 188 | def get_clean_dict(data: Dataset) -> dict: 189 | return {get_clean_corpusid(item): item for item in data} 190 | 191 | ##### openai gpt-4 model ##### 192 | 193 | def get_gpt4_model(model_name: str = "gpt-4-1106-preview", azure: bool = True) -> OPENAIBaseEngine: 194 | model = 
OPENAIBaseEngine(model_name, azure) 195 | model.test_api() 196 | return model 197 | 198 | def prompt_gpt4_model(model: OPENAIBaseEngine, prompt: str = None, messages: List[dict] = None) -> str: 199 | if prompt is not None: 200 | messages = [{"role": "user", "content": prompt}] # a bare prompt string is sent as a user message 201 | elif messages is None: 202 | raise ValueError("Either prompt or messages must be provided") 203 | 204 | response = model.safe_completion(messages) 205 | if response["finish_reason"] != "stop": 206 | print(f"Unexpected stop reason: {response['finish_reason']}") 207 | return response["content"] 208 | 209 | ##### cache directory ##### 210 | 211 | def get_cache_dir() -> str: 212 | return os.environ['HF_HOME'] # assumes the HF_HOME environment variable is set --------------------------------------------------------------------------------
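Note: `eval/retrieval/evaluate_index.py` writes per-query retrieval results, but the metric computation itself is not shown in this dump. Below is a minimal scoring sketch using the helpers defined in `utils/utils.py`; it assumes each query record stores its ground-truth document ids under a field such as `corpusids` (an assumption — the exact field name is not shown here), and the results path is only an example.

```python
# Minimal sketch (not part of the repository): mean recall@k over a retrieval
# results file, using utils.read_json and utils.calculate_recall from utils/utils.py.
# Assumption: each query record exposes its ground-truth ids under "corpusids".
from utils import utils

K = 20
results = utils.read_json("results/retrieval/LitSearch.title_abstract.bm25.jsonl")
recalls = [
    utils.calculate_recall(result["retrieved"][:K], result["corpusids"])
    for result in results
]
print(f"Mean recall@{K}: {sum(recalls) / len(recalls):.4f}")
```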