├── .gitignore ├── LICENSE ├── readme.md ├── requirements.txt ├── run_beir.py ├── run_sample.py ├── setup.py └── xtr ├── __init__.py ├── configuration_xtr.py ├── modeling_xtr.py ├── retrieval_xtr.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | models 2 | cookiecutter-template-XTR 3 | __pycache__ 4 | datasets 5 | wandb 6 | index -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Jinhyuk Lee and Mujeen Sung 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # XTR-Pytorch 2 | 3 | This repository provides a PyTorch-based reimplementation of XTR (Contextualized Token Retriever) for document retrieval. For this reimplementation, I referred to the [original XTR code](https://github.com/google-deepmind/xtr/blob/main/xtr_evaluation_on_beir_miracl.ipynb) by the author, which uses TensorFlow. 4 | 5 | ## Installation 6 | ```bash 7 | $ git clone git@github.com:mjeensung/xtr-pytorch.git 8 | $ pip install -e . 9 | ``` 10 | 11 | ## Usage 12 | 13 | ### Simple example 14 | To see how this XTR-pytorch works, please see the example snippet below or run `python run_sample.py`. 15 | 16 | ```python 17 | # Create the dataset 18 | sample_doc = "Google LLC (/ˈɡuːɡəl/ (listen)) is an American multinational technology company focusing on online advertising, search engine technology, cloud computing, computer software, quantum computing, e-commerce, artificial intelligence..." 
19 | chunks = [chunk.lower() for chunk in sent_tokenize(sample_doc)] 20 | 21 | # Load the XTR retriever 22 | xtr = XtrRetriever(model_name_or_path="google/xtr-base-en", use_faiss=False, device="cuda") 23 | 24 | # Build the index 25 | xtr.build_index(chunks) 26 | 27 | # Retrieve top-3 documents given the query 28 | query = "Who founded google" 29 | retrieved_docs, metadata = xtr.retrieve_docs([query], document_top_k=3) 30 | for rank, (did, score, doc) in enumerate(retrieved_docs[0]): 31 | print(f"[{rank}] doc={did} ({score:.3f}): {doc}") 32 | 33 | """ 34 | >> [0] doc=0 (0.925): google llc (/ˈɡuːɡəl/ (listen)) is an american multinational technology company focusing on online advertising, search engine technology, cloud computing, computer software, quantum computing, e-commerce, artificial intelligence, and consumer electronics. 35 | >> [1] doc=1 (0.903): it has been referred to as "the most powerful company in the world" and one of the world's most valuable brands due to its market dominance, data collection, and technological advantages in the area of artificial intelligence. 36 | >> [2] doc=2 (0.900): its parent company alphabet is considered one of the big five american information technology companies, alongside amazon, apple, meta, and microsoft. 37 | """ 38 | ``` 39 | 40 | ### BEIR example 41 | 42 | To evaluate XTR-pytorch on the [BEIR benchmark](https://github.com/beir-cellar/beir/), please run `run_beir.py`. 43 | ```bash 44 | $ pip install beir --no-deps 45 | $ python run_beir.py \ 46 | --model_name_or_path google/xtr-base-en \ 47 | --dataset nfcorpus \ 48 | --token_top_k 8000 \ 49 | --use_faiss 50 | ``` 51 | 52 | Below is the comparison of NDCG@10 between the reported scores from [the XTR paper](https://arxiv.org/abs/2304.01982) and the scores from the reimplemented XTR in this repo across four datasets from BEIR. 
53 | 54 | | Dataset | XTR base ([Reported](https://arxiv.org/abs/2304.01982), k=40000) | XTR-pytorch base (This repo, k=8000) | 55 | |:----------------------------------|:--------:|:--------:| 56 | | MSMARCO | 45.0 | 42.9 | 57 | | NQ | 53.0 | 52.0 | 58 | | NFCorpus | 34.0 | 34.1 | 59 | | SciFact | 71.0 | 71.8 | 60 | 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | transformers==4.40.2 3 | faiss-gpu==1.7.2 4 | tqdm 5 | nltk 6 | numpy 7 | pytrec_eval -------------------------------------------------------------------------------- /run_beir.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import re 4 | import os 5 | from beir import util 6 | from enum import Enum 7 | from tqdm import tqdm 8 | import pytrec_eval 9 | import argparse 10 | 11 | from beir.datasets.data_loader import GenericDataLoader 12 | from xtr.retrieval_xtr import XtrRetriever 13 | 14 | logger = logging.getLogger(__name__) 15 | logging.getLogger().setLevel(logging.INFO) 16 | device = "cuda" if torch.cuda.is_available() else "cpu" 17 | torch.set_grad_enabled(False) 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--model_name_or_path", type=str, default="google/xtr-base-en") 21 | parser.add_argument("--dataset", type=str, default="scifact") 22 | parser.add_argument("--use_faiss", action="store_true") 23 | parser.add_argument("--doc_sample_ratio", type=float, default=0.2) 24 | parser.add_argument("--vec_sample_ratio", type=float, default=0.2) 25 | parser.add_argument("--code_size", type=int, default=64) 26 | parser.add_argument("--nprobe", type=int, default=128) 27 | parser.add_argument("--token_top_k", type=int, default=8000) 28 | parser.add_argument("--dataset_dir", type=str, default="datasets") 29 | parser.add_argument("--index_dir", type=str, 
default="index") 30 | parser.add_argument("--load_index", action="store_true") 31 | 32 | args = parser.parse_args() 33 | 34 | ###################################### 35 | print("Step 1 - Load XTR Retriever") 36 | ###################################### 37 | 38 | xtr = XtrRetriever( 39 | model_name_or_path=args.model_name_or_path, 40 | use_faiss=args.use_faiss, 41 | device=device 42 | ) 43 | 44 | ###################################### 45 | print("Step 2 - Load BEIR Datasets") 46 | ###################################### 47 | 48 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(args.dataset) 49 | out_dir = os.path.join(os.getcwd(), "datasets") 50 | data_path = util.download_and_unzip(url, out_dir) 51 | logger.info("Dataset downloaded here: {}".format(data_path)) 52 | 53 | data_path = f"{args.dataset_dir}/{args.dataset}" 54 | 55 | if args.dataset == 'msmarco': 56 | split='dev' 57 | else: 58 | split='test' 59 | corpus, queries, qrels = GenericDataLoader(data_path).load(split=split) 60 | 61 | ###################################### 62 | print("Step 3 - Index BEIR Corpus") 63 | ###################################### 64 | 65 | # For Scifact + XTR-base-en (P100), this should take about 3 minutes. 
66 | all_docs = [] 67 | all_keys = [] 68 | for doc_key, doc in tqdm(corpus.items()): 69 | doc_text = f"{doc['title']} {doc['text']}".lower() 70 | all_keys.append(doc_key) 71 | all_docs.append(doc_text) 72 | 73 | index_dir = f"{args.index_dir}/{args.dataset}" 74 | if args.load_index: 75 | index_num = xtr.load_index( 76 | all_docs, 77 | index_dir=index_dir, 78 | code_size=args.code_size, 79 | nprobe=args.nprobe 80 | ) 81 | else: 82 | index_num = xtr.build_index( 83 | all_docs, 84 | index_dir=index_dir, 85 | doc_sample_ratio=args.doc_sample_ratio, 86 | vec_sample_ratio=args.vec_sample_ratio, 87 | code_size=args.code_size, 88 | nprobe=args.nprobe 89 | ) 90 | print(f"XTR Index Size: {index_num}") 91 | 92 | ###################################### 93 | print("Step 4 - Run BEIR Evaluation") 94 | ###################################### 95 | 96 | # For Scifact, XTR-base-en (P100), this should take about 2 minutes. 97 | 98 | # Evaluation hyperparameters. 99 | TOKEN_TOP_K = args.token_top_k 100 | TREC_TOP_K = 100 101 | 102 | predictions = {} 103 | # Running evaluation per query for a better latency measurement. 104 | for q_idx, (query_key, query) in tqdm(enumerate(queries.items()), total=len(queries)): 105 | ranking, metadata = xtr.retrieve_docs( 106 | [query.lower()], 107 | token_top_k=TOKEN_TOP_K, 108 | return_text=False 109 | ) 110 | ranking = ranking[0] 111 | predictions[query_key] = {str(all_keys[did]): score for did, score in ranking[:TREC_TOP_K]} 112 | 113 | # Run pytrec_eval. 114 | K_VALUES = [5, 10, 50, 100] 115 | METRIC_NAMES = ['ndcg_cut', 'map_cut', 'recall'] 116 | 117 | def eval_metrics(qrels, predictions): 118 | measurements = [] 119 | for metric_name in METRIC_NAMES: 120 | measurements.append( 121 | f"{metric_name}." 
+ ",".join([str(k) for k in K_VALUES]) 122 | ) 123 | evaluator = pytrec_eval.RelevanceEvaluator(qrels, measurements) 124 | final_scores = evaluator.evaluate(predictions) 125 | 126 | final_metrics = dict() 127 | for metric_name in METRIC_NAMES: 128 | for k in K_VALUES: 129 | final_metrics[f"{metric_name}@{k}"] = 0.0 130 | 131 | for query_id in final_scores.keys(): 132 | for metric_name in METRIC_NAMES: 133 | for k in K_VALUES: 134 | final_metrics[f"{metric_name}@{k}"] += final_scores[query_id][ 135 | f"{metric_name}_{k}" 136 | ] 137 | 138 | for metric_name in METRIC_NAMES: 139 | for k in K_VALUES: 140 | final_metrics[f"{metric_name}@{k}"] = round( 141 | final_metrics[f"{metric_name}@{k}"] / len(final_scores), 5 142 | ) 143 | 144 | print("[Result]") 145 | for metric_name, metric_score in final_metrics.items(): 146 | metric_score = round(metric_score*100,2) 147 | print(f"{metric_name}: {metric_score}") 148 | 149 | eval_metrics(qrels, predictions) -------------------------------------------------------------------------------- /run_sample.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import re 4 | import nltk; nltk.download('punkt') 5 | from nltk.tokenize import sent_tokenize 6 | from xtr.retrieval_xtr import XtrRetriever 7 | 8 | logger = logging.getLogger(__name__) 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | torch.set_grad_enabled(False) 11 | 12 | ###################################### 13 | logger.info("Step 1 - Create the dataset") 14 | ###################################### 15 | 16 | # Source: https://en.wikipedia.org/wiki/Google 17 | sample_doc = """Google LLC (/ˈɡuːɡəl/ (listen)) is an American multinational technology company focusing on online advertising, search engine technology, cloud computing, computer software, quantum computing, e-commerce, artificial intelligence,[9] and consumer electronics. 
It has been referred to as "the most powerful company in the world"[10] and one of the world's most valuable brands due to its market dominance, data collection, and technological advantages in the area of artificial intelligence.[11][12][13] Its parent company Alphabet is considered one of the Big Five American information technology companies, alongside Amazon, Apple, Meta, and Microsoft. 18 | Google was founded on September 4, 1998, by computer scientists Larry Page and Sergey Brin while they were PhD students at Stanford University in California. Together they own about 14% of its publicly listed shares and control 56% of its stockholder voting power through super-voting stock. The company went public via an initial public offering (IPO) in 2004. In 2015, Google was reorganized as a wholly owned subsidiary of Alphabet Inc. Google is Alphabet's largest subsidiary and is a holding company for Alphabet's internet properties and interests. Sundar Pichai was appointed CEO of Google on October 24, 2015, replacing Larry Page, who became the CEO of Alphabet. On December 3, 2019, Pichai also became the CEO of Alphabet.[14] 19 | The company has since rapidly grown to offer a multitude of products and services beyond Google Search, many of which hold dominant market positions. These products address a wide range of use cases, including email (Gmail), navigation (Waze & Maps), cloud computing (Cloud), web browsing (Chrome), video sharing (YouTube), productivity (Workspace), operating systems (Android), cloud storage (Drive), language translation (Translate), photo storage (Photos), video calling (Meet), smart home (Nest), smartphones (Pixel), wearable technology (Pixel Watch & Fitbit), music streaming (YouTube Music), video on demand (YouTube TV), artificial intelligence (Google Assistant), machine learning APIs (TensorFlow), AI chips (TPU), and more. 
Discontinued Google products include gaming (Stadia), Glass, Google+, Reader, Play Music, Nexus, Hangouts, and Inbox by Gmail.[15][16] 20 | Google's other ventures outside of Internet services and consumer electronics include quantum computing (Sycamore), self-driving cars (Waymo, formerly the Google Self-Driving Car Project), smart cities (Sidewalk Labs), and transformer models (Google Brain).[17] 21 | Google and YouTube are the two most visited websites worldwide followed by Facebook and Twitter. Google is also the largest search engine, mapping and navigation application, email provider, office suite, video sharing platform, photo and cloud storage provider, mobile operating system, web browser, ML framework, and AI virtual assistant provider in the world as measured by market share. On the list of most valuable brands, Google is ranked second by Forbes[18] and fourth by Interbrand.[19] It has received significant criticism involving issues such as privacy concerns, tax avoidance, censorship, search neutrality, antitrust and abuse of its monopoly position. 
22 | Google began in January 1996 as a research project by Larry Page and Sergey Brin when they were both PhD students at Stanford University in California.[20][21][22] The project initially involved an unofficial "third founder", Scott Hassan, the original lead programmer who wrote much of the code for the original Google Search engine, but he left before Google was officially founded as a company;[23][24] Hassan went on to pursue a career in robotics and founded the company Willow Garage in 2006.[25][26] 23 | While conventional search engines ranked results by counting how many times the search terms appeared on the page, they theorized about a better system that analyzed the relationships among websites.[27] They called this algorithm PageRank; it determined a website's relevance by the number of pages, and the importance of those pages that linked back to the original site.[28][29] Page told his ideas to Hassan, who began writing the code to implement Page's ideas.[23] 24 | Page and Brin originally nicknamed the new search engine "BackRub", because the system checked backlinks to estimate the importance of a site.[20][30][31] Hassan as well as Alan Steremberg were cited by Page and Brin as being critical to the development of Google. Rajeev Motwani and Terry Winograd later co-authored with Page and Brin the first paper about the project, describing PageRank and the initial prototype of the Google search engine, published in 1998. 
Héctor García-Molina and Jeff Ullman were also cited as contributors to the project.[32] PageRank was influenced by a similar page-ranking and site-scoring algorithm earlier used for RankDex, developed by Robin Li in 1996, with Larry Page's PageRank patent including a citation to Li's earlier RankDex patent; Li later went on to create the Chinese search engine Baidu.[33][34] 25 | Eventually, they changed the name to Google; the name of the search engine was a misspelling of the word googol,[20][35][36] a very large number written 10100 (1 followed by 100 zeros), picked to signify that the search engine was intended to provide large quantities of information.[37] 26 | Google was initially funded by an August 1998 investment of $100,000 from Andy Bechtolsheim,[20] co-founder of Sun Microsystems. This initial investment served as a motivation to incorporate the company to be able to use the funds.[39][40] Page and Brin initially approached David Cheriton for advice because he had a nearby office in Stanford, and they knew he had startup experience, having recently sold the company he co-founded, Granite Systems, to Cisco for $220 million. David arranged a meeting with Page and Brin and his Granite co-founder Andy Bechtolsheim. The meeting was set for 8 AM at the front porch of David's home in Palo Alto and it had to be brief because Andy had another meeting at Cisco, where he now worked after the acquisition, at 9 AM. Andy briefly tested a demo of the website, liked what he saw, and then went back to his car to grab the check. David Cheriton later also joined in with a $250,000 investment.[41][42] 27 | Google received money from two other angel investors in 1998: Amazon.com founder Jeff Bezos, and entrepreneur Ram Shriram.[43] Page and Brin had first approached Shriram, who was a venture capitalist, for funding and counsel, and Shriram invested $250,000 in Google in February 1998. 
Shriram knew Bezos because Amazon had acquired Junglee, at which Shriram was the president. It was Shriram who told Bezos about Google. Bezos asked Shriram to meet Google's founders and they met 6 months after Shriram had made his investment when Bezos and his wife were in a vacation trip to the Bay Area. Google's initial funding round had already formally closed but Bezos' status as CEO of Amazon was enough to persuade Page and Brin to extend the round and accept his investment.[44][45] 28 | Between these initial investors, friends, and family Google raised around $1,000,000, which is what allowed them to open up their original shop in Menlo Park, California.[46] Craig Silverstein, a fellow PhD student at Stanford, was hired as the first employee.[22][47][48] 29 | After some additional, small investments through the end of 1998 to early 1999,[43] a new $25 million round of funding was announced on June 7, 1999,[49] with major investors including the venture capital firms Kleiner Perkins and Sequoia Capital.[40] Both firms were initially reticent about investing jointly in Google, as each wanted to retain a larger percentage of control over the company to themselves. Larry and Sergey however insisted in taking investments from both. 
Both venture companies finally agreed to investing jointly $12.5 million each due to their belief in Google's great potential and through mediation of earlier angel investors Ron Conway and Ram Shriram who had contacts in the venture companies.[50] 30 | In March 1999, the company moved its offices to Palo Alto, California,[51] which is home to several prominent Silicon Valley technology start-ups.[52] The next year, Google began selling advertisements associated with search keywords against Page and Brin's initial opposition toward an advertising-funded search engine.[53][22] To maintain an uncluttered page design, advertisements were solely text-based.[54] In June 2000, it was announced that Google would become the default search engine provider for Yahoo!, one of the most popular websites at the time, replacing Inktomi.[55][56] 31 | In 2003, after outgrowing two other locations, the company leased an office complex from Silicon Graphics, at 1600 Amphitheatre Parkway in Mountain View, California.[58] The complex became known as the Googleplex, a play on the word googolplex, the number one followed by a googol zeroes. Three years later, Google bought the property from SGI for $319 million.[59] By that time, the name "Google" had found its way into everyday language, causing the verb "google" to be added to the Merriam-Webster Collegiate Dictionary and the Oxford English Dictionary, denoted as: "to use the Google search engine to obtain information on the Internet".[60][61] The first use of the verb on television appeared in an October 2002 episode of Buffy the Vampire Slayer.[62] 32 | Additionally, in 2001 Google's investors felt the need to have a strong internal management, and they agreed to hire Eric Schmidt as the chairman and CEO of Google.[46] Eric was proposed by John Doerr from Kleiner Perkins. 
He had been trying to find a CEO that Sergey and Larry would accept for several months, but they rejected several candidates because they wanted to retain control over the company. Michael Moritz from Sequoia Capital at one point even menaced requesting Google to immediately pay back Sequoia's $12.5m investment if they did not fulfill their promise to hire a chief executive officer, which had been made verbally during investment negotiations. Eric wasn't initially enthusiastic about joining Google either, as the company's full potential hadn't yet been widely recognized at the time, and as he was occupied with his responsibilities at Novell where he was CEO. As part of him joining, Eric agreed to buy $1 million of Google preferred stocks as a way to show his commitment and to provide funds Google needed.[63] 33 | On August 19, 2004, Google became a public company via an initial public offering. At that time Larry Page, Sergey Brin, and Eric Schmidt agreed to work together at Google for 20 years, until the year 2024.[64] The company offered 19,605,052 shares at a price of $85 per share.[65][66] Shares were sold in an online auction format using a system built by Morgan Stanley and Credit Suisse, underwriters for the deal.[67][68] The sale of $1.67 billion gave Google a market capitalization of more than $23 billion.[69] 34 | On November 13, 2006, Google acquired YouTube for $1.65 billion in Google stock,[70][71][72][73] On March 11, 2008, Google acquired DoubleClick for $3.1 billion, transferring to Google valuable relationships that DoubleClick had with Web publishers and advertising agencies.[74][75] 35 | By 2011, Google was handling approximately 3 billion searches per day. To handle this workload, Google built 11 data centers around the world with several thousand servers in each. 
These data centers allowed Google to handle the ever-changing workload more efficiently.[46]
In May 2011, the number of monthly unique visitors to Google surpassed one billion for the first time.[76][77]
In May 2012, Google acquired Motorola Mobility for $12.5 billion, in its largest acquisition to date.[78][79][80] This purchase was made in part to help Google gain Motorola's considerable patent portfolio on mobile phones and wireless technologies, to help protect Google in its ongoing patent disputes with other companies,[81] mainly Apple and Microsoft,[82] and to allow it to continue to freely offer Android.[83]
"""
# Strip Wikipedia-style citation markers such as "[46]" from the sample text
# so they do not pollute tokenization.
sample_doc = re.sub(r'\[\d+\]', '', sample_doc)

# Single-sentence chunks.
# Each sentence becomes one retrieval unit; lower-cased to match how the
# index is queried below.
chunks = [chunk.lower() for chunk in sent_tokenize(sample_doc)]
# Preview only the first few chunks to keep console output short.
for i, chunk in enumerate(chunks):
    print(f'chunk{i}: {chunk[:150]} \n')
    if i > 3:
        print('...\n')
        break
print('total # of chunks:', len(chunks))

######################################
logger.info("Step 2 - Load XTR and Index the dataset")
######################################

# Load the XTR retriever
# use_faiss=False selects the exact brute-force searcher, appropriate for a
# corpus this small.
model_name_or_path="google/xtr-base-en"
xtr = XtrRetriever(model_name_or_path=model_name_or_path, use_faiss=False, device=device)

# Build the index
xtr.build_index(chunks)

######################################
logger.info("Step 3 - Have fun")
######################################

# Retrieve the top-3 chunks for a single query and print (rank, doc id,
# score, text) for each hit.
query = "Who founded google"
retrieved_docs, metadata = xtr.retrieve_docs([query], document_top_k=3)
print(f"\nQuery: {query}")
for rank, (did, score, doc) in enumerate(retrieved_docs[0]):
    print(f"[{rank}] doc={did} ({score:.3f}): {doc}")

# (setup.py)
import io
from setuptools import setup, find_packages
# Packaging script for the XTR reimplementation.

with open('readme.md', encoding='utf8') as f:
    readme = f.read()

with open('LICENSE', encoding='utf8') as f:
    # Renamed from `license` to avoid shadowing the builtin of the same name.
    license_text = f.read()

with open('requirements.txt', encoding='utf8') as f:
    # Skip blank lines and "#" comments: setuptools rejects empty requirement
    # strings, and comment lines are not valid requirement specifiers.
    reqs = [
        line.strip()
        for line in f
        if line.strip() and not line.strip().startswith('#')
    ]

setup(
    name='xtr',
    version='1.0',
    description='Rethinking the Role of Token Retrieval in Multi-Vector Retrieval',
    long_description=readme,
    # readme.md is Markdown; without declaring the content type, PyPI assumes
    # reStructuredText and renders the project page incorrectly.
    long_description_content_type='text/markdown',
    # NOTE(review): this passes the full LICENSE text; setuptools convention
    # is a short identifier such as 'Apache-2.0' — confirm before publishing.
    license=license_text,
    url='https://github.com/mjeensung/xtr_reimplementation',
    python_requires='>=3.8',
    packages=find_packages(include=['xtr', 'xtr.*']),
    install_requires=reqs,
)

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Package initializer for `xtr`.
#
# Uses the Hugging Face lazy-import pattern: nothing heavy (torch, faiss,
# transformers) is imported until an attribute such as `xtr.XtrRetriever`
# is actually accessed.

from typing import TYPE_CHECKING

# NOTE(review): OptionalDependencyNotAvailable and is_torch_available are
# imported but unused here — presumably kept for parity with the HF template.
from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


# Maps submodule name -> public names it exports; consumed by _LazyModule.
_import_structure = {
    "retrieval_xtr": ["XtrRetriever"],
}

if TYPE_CHECKING:
    # Real import only for static type checkers / IDEs; never executed at runtime.
    from .retrieval_xtr import XtrRetriever

else:
    import sys

    # Replace this module object in sys.modules with a lazy proxy that imports
    # submodules on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
""" XTR model configuration """

from transformers import PretrainedConfig
from transformers.utils import logging
from .utils import load_file_path
import json

logger = logging.get_logger(__name__)

class XtrConfig(PretrainedConfig):
    """Configuration for XTR.

    XTR is a T5 encoder followed by a dense projection (stored as "2_Dense"
    in the released checkpoint). Instead of declaring its own fields, an
    instance of this class turns itself into the loaded T5 config and then
    merges in the projection's config, which supplies the `in_features` /
    `out_features` attributes read by the modeling code.
    """
    model_type = "xtr"

    def __init__(
        self,
        model_name_or_path=None
    ):
        # Local import to avoid a hard module-level dependency cycle.
        from transformers import AutoConfig
        t5_config = AutoConfig.from_pretrained(model_name_or_path)

        # overwrite config: identity-swap this instance into the loaded T5
        # config (both class and attribute dict), so downstream code sees a
        # plain T5 config object. NOTE(review): super().__init__ is
        # intentionally never called — the instance is fully replaced here.
        self.__class__ = t5_config.__class__
        self.__dict__ = t5_config.__dict__

        # load config for linear layer ("2_Dense/config.json"); this adds
        # e.g. in_features / out_features used by XtrModel._load_linear_layer.
        linear_config_path = load_file_path(
            model_name_or_path, "2_Dense/config.json"
        )
        with open(linear_config_path) as f:
            self.__dict__.update(json.load(f))
""" PyTorch XTR model. """


import math
import os

import torch
import torch.utils.checkpoint
from torch import nn
from typing import Optional, Tuple, Union

from transformers.modeling_utils import PreTrainedModel
import torch.nn.functional as F

from .configuration_xtr import XtrConfig
from transformers.utils import logging
from .utils import load_file_path

logger = logging.get_logger(__name__)

class XtrPreTrainedModel(PreTrainedModel):
    """
    A simple interface for downloading and loading pretrained models.
    """

    config_class = XtrConfig
    base_model_prefix = "xtr"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

class XtrModel(XtrPreTrainedModel):
    """T5 encoder plus a linear projection producing L2-normalized,
    attention-masked token embeddings of dimension config.out_features.
    """
    def __init__(self, model_name_or_path=None, config=None, device='cpu'):
        super().__init__(config)
        self.config = config
        self.t5_encoder = self._load_t5_encoder(model_name_or_path, device)
        self.linear_layer = self._load_linear_layer(model_name_or_path, device)

    def _load_t5_encoder(self, model_name_or_path=None, device='cpu'):
        """Load only the encoder stack of the pretrained T5 checkpoint."""
        from transformers import T5EncoderModel
        # The XTR checkpoint ships a full T5; silence warnings about the
        # decoder weights we deliberately do not load.
        T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]

        t5_encoder = T5EncoderModel.from_pretrained(model_name_or_path, use_safetensors=True).to(device=device)

        return t5_encoder

    def _load_linear_layer(self, model_name_or_path=None, device='cpu'):
        """Load the "2_Dense" projection weights into a torch.nn.Linear.

        Fix: map_location='cpu' so a checkpoint saved from a GPU process can
        be loaded on CPU-only machines; the tensor is moved to `device` below.
        """
        linear_path = load_file_path(
            model_name_or_path, "2_Dense/pytorch_model.bin"
        )
        linear_weight = torch.load(linear_path, map_location='cpu')
        linear_layer = torch.nn.Linear(self.config.in_features, self.config.out_features)
        linear_layer.weight = torch.nn.Parameter(linear_weight['linear.weight'].to(device=device))
        # A zero bias is installed rather than loading one; assumes the
        # released 2_Dense layer is bias-free — TODO confirm no
        # 'linear.bias' key exists in the checkpoint.
        linear_layer.bias = torch.nn.Parameter(torch.zeros(self.config.out_features, device=device))
        return linear_layer

    def get_token_embed_dim(self):
        """Dimensionality of the output token embeddings."""
        return self.config.out_features

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
    ):
        """Return (batch, seq, out_features) normalized token embeddings,
        with padding positions zeroed out via the attention mask."""
        def pass_through(model_output, attention_mask):
            token_embeddings = model_output[0] # First element of model_output contains all token embeddings
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return token_embeddings * input_mask_expanded

        model_output = self.t5_encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Perform pooling (here: mask padding tokens, keep per-token vectors)
        embeddings = pass_through(model_output, attention_mask)

        # Apply linear layer
        embeddings = self.linear_layer(embeddings)

        # Normalize embeddings (unit L2 norm per token; masked positions stay zero)
        embeddings = F.normalize(embeddings, p=2, dim=2)

        return embeddings
"""XTR Retriever model implementation."""


from tqdm import tqdm
import numpy as np
import logging
import faiss
import torch
import os
import math
import random

from faiss import write_index, read_index
from typing import List, Dict, Optional, Union, Tuple
from transformers import AutoTokenizer
# NOTE(review): this rebinds the name `logging` (shadowing the stdlib import
# above); kept as-is since `logger` below relies on the transformers logger.
from transformers.utils import logging
from .modeling_xtr import XtrModel
from .configuration_xtr import XtrConfig

logger = logging.get_logger(__name__)

class XtrRetriever(object):
    """End-to-end XTR retriever: encodes documents/queries into token
    embeddings and searches them with either brute force or FAISS."""

    def __init__(self,
                 model_name_or_path: str,
                 cache_dir: Optional[str] = None,  # NOTE(review): currently unused
                 use_faiss=False,
                 device='cpu'
                 ):

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.config = XtrConfig(model_name_or_path)
        self.encoder = XtrModel(model_name_or_path, config=self.config, device=device)
        self.use_faiss = use_faiss
        self.device = device
        self.max_seq_len = self.tokenizer.model_max_length
        self.token_embed_dim = self.encoder.get_token_embed_dim()
        # Token ids in the index encode (doc id, token position) as
        # did * doc_offset + tid; doc_offset is the max token length.
        self.doc_offset = 512  # max token length

    def tokenize(self, text):
        """Return the subword token strings for `text`.

        Fix: the previous implementation called
        `self.tokenizer.id_to_string(id_).numpy()` — the TensorFlow
        SentencePiece API — on the token *strings* that a Hugging Face
        tokenizer's `tokenize()` returns, which raises AttributeError.
        HF's `tokenize()` already yields decoded token strings.
        """
        return self.tokenizer.tokenize(text)

    def get_token_embeddings(self, texts):
        """Encode a list of texts; returns (embeddings, lengths) where
        embeddings is (batch, max_seq_len, dim) and lengths is the number of
        non-padding tokens per text."""
        encoded_inputs = self.tokenizer(texts, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_seq_len).to(device=self.device)

        with torch.no_grad():
            batch_embeds = self.encoder(**encoded_inputs)

        batch_lengths = np.sum(encoded_inputs['attention_mask'].cpu().numpy(), axis=1)

        return batch_embeds.cpu().numpy(), batch_lengths

    def get_flatten_embeddings(self, batch_text, return_last_offset=False):
        """Concatenate non-padding token embeddings of a batch into one
        (total_tokens, dim) array, plus per-text start offsets.

        Fix: replaced repeated `np.append` (quadratic reallocation) with a
        single `np.concatenate` over the per-text slices; output is identical.
        """
        batch_embeddings, batch_lengths = self.get_token_embeddings(batch_text)
        pieces = []
        num_tokens = 0
        offsets = [0]
        for embeddings, length in zip(batch_embeddings, batch_lengths):
            pieces.append(embeddings[:int(length)])
            num_tokens += int(length)
            offsets.append(num_tokens)
        flatten_embeddings = np.concatenate(pieces, axis=0)
        assert num_tokens == flatten_embeddings.shape[0]
        if not return_last_offset:
            # Callers indexing by offset pairs do not need the final boundary.
            offsets = offsets[:-1]
        return flatten_embeddings, offsets

    def build_index(self, documents, batch_size=32, **kwargs):
        """Build the token index over `documents`; dispatches on use_faiss.

        Returns the number of indexed tokens."""
        if self.use_faiss:
            index_num = self.build_index_faiss(documents, batch_size=batch_size, **kwargs)
        else:
            index_num = self.build_index_bruteforce(documents, batch_size=batch_size)

        return index_num

    def build_index_bruteforce(self, documents, index_dir=None, batch_size=32):
        """Exact inner-product search over all token embeddings.

        Used only for small-scale, exact inference; the full (docs *
        max_seq_len, dim) buffer is pre-allocated and trimmed to the tokens
        actually produced. `index_dir` is accepted for signature parity but
        unused (nothing is persisted)."""
        all_token_embeds = np.zeros((len(documents)*self.max_seq_len, self.token_embed_dim), dtype=np.float32)
        num_tokens = 0
        for batch_idx in tqdm(range(0, len(documents), batch_size)):
            batch_docs = documents[batch_idx:batch_idx+batch_size]
            batch_embeds, batch_offsets = self.get_flatten_embeddings(batch_docs)
            num_tokens += len(batch_embeds)
            all_token_embeds[num_tokens-len(batch_embeds):num_tokens] = batch_embeds

        class BruteForceSearcher(object):
            def search_batched(self, query_embeds, final_num_neighbors, **kwargs):
                # Dense score matrix, then take the top-k columns per query row.
                scores = query_embeds.dot(all_token_embeds[:num_tokens].T) # Q x D
                top_ids = scores.argsort(axis=1)[:, ::-1][:,:final_num_neighbors] # Q x top_k
                return top_ids, [q_score[q_top_ids] for q_score, q_top_ids in zip(scores, top_ids)] # (Q x top_k, Q x top_k)
        self.searcher = BruteForceSearcher()
        self.docs = documents
        print("Index Ready!", self.searcher)

        return num_tokens
vec_sample_ratio=0.2, seed=29, index_dir=None, code_size=64, nprobe=4): 114 | # 1. sample token embeddings for train index 115 | random.seed(seed) 116 | np.random.seed(seed) 117 | smpl_vec_len = int(vec_sample_ratio * self.max_seq_len) 118 | smpl_documents = random.sample(documents, int(doc_sample_ratio * len(documents))) 119 | smpl_token_embeds = np.zeros((int(len(documents)*doc_sample_ratio*self.max_seq_len*vec_sample_ratio), self.token_embed_dim), dtype=np.float32) 120 | num_tokens = 0 121 | for batch_idx in tqdm(range(0, len(smpl_documents), batch_size)): 122 | batch_docs = smpl_documents[batch_idx:batch_idx+batch_size] 123 | batch_embeds, _ = self.get_flatten_embeddings(batch_docs) 124 | smpl_batch_idx = np.random.choice(len(batch_embeds), int(vec_sample_ratio * len(batch_embeds))) 125 | smpl_batch_embeds = batch_embeds[smpl_batch_idx] 126 | num_tokens += len(smpl_batch_embeds) 127 | smpl_token_embeds[num_tokens-len(smpl_batch_embeds):num_tokens] = smpl_batch_embeds 128 | 129 | smpl_token_embeds = smpl_token_embeds[:num_tokens] 130 | 131 | # use the square root of total token nums as num_clusters 132 | num_clusters = int(math.sqrt(num_tokens/doc_sample_ratio/vec_sample_ratio)) 133 | 134 | ds = self.token_embed_dim 135 | quantizer = faiss.IndexFlatIP(ds) 136 | opq_matrix = faiss.OPQMatrix(ds, code_size) 137 | opq_matrix.niter = 10 138 | sub_index = faiss.IndexIVFPQ(quantizer, ds, num_clusters, code_size, 8, faiss.METRIC_INNER_PRODUCT) 139 | sub_index.nprobe = nprobe 140 | index = faiss.IndexPreTransform(opq_matrix, sub_index) 141 | # Convert to GPU index 142 | res = faiss.StandardGpuResources() 143 | co = faiss.GpuClonerOptions() 144 | co.useFloat16 = True 145 | gpu_index = faiss.index_cpu_to_gpu(res, 0, index, co) 146 | gpu_index.verbose = False 147 | 148 | # Train on GPU with sampled token embeddings and back to CPU 149 | gpu_index.train(smpl_token_embeds) 150 | index = faiss.index_gpu_to_cpu(gpu_index) 151 | 152 | # 2. 
embed all tokens in batch and add them faiss index 153 | add_size = 128 154 | add_num_tokens = 0 155 | add_count = 0 156 | add_token_ids = [] 157 | add_token_embeds = np.zeros((int(batch_size*add_size*self.max_seq_len), self.token_embed_dim), dtype=np.float32) 158 | all_num_tokens = 0 159 | for batch_idx in tqdm(range(0, len(documents), batch_size)): 160 | batch_docs = documents[batch_idx:batch_idx+batch_size] 161 | batch_embeds, batch_offsets = self.get_flatten_embeddings(batch_docs) 162 | batch_token_len = [batch_offsets[i+1] - batch_offsets[i] for i, offset in enumerate(batch_offsets[:-1])] + [len(batch_embeds) - batch_offsets[-1]] 163 | # batch_token_ids = [f"{did*self.doc_offset + tid}" for did in range(batch_idx,batch_idx+len(batch_docs)) for tid in range(batch_token_len[did - batch_idx])] 164 | batch_token_ids = [did*self.doc_offset + tid for did in range(batch_idx,batch_idx+len(batch_docs)) for tid in range(batch_token_len[did - batch_idx])] 165 | 166 | add_num_tokens += len(batch_embeds) 167 | all_num_tokens += len(batch_embeds) 168 | add_token_embeds[add_num_tokens-len(batch_embeds):add_num_tokens] = batch_embeds 169 | add_token_ids += batch_token_ids 170 | 171 | # add batch embeds with ids to index 172 | add_count += 1 173 | if add_count >= add_size: 174 | add_token_embeds = add_token_embeds[:len(add_token_ids)] 175 | index.add_with_ids(x=add_token_embeds,ids=np.array(add_token_ids)) 176 | 177 | add_num_tokens = 0 178 | add_count = 0 179 | add_token_ids = [] 180 | add_token_embeds = np.zeros((int(batch_size*add_size*self.max_seq_len), self.token_embed_dim), dtype=np.float32) 181 | 182 | if add_count != 0: 183 | add_token_embeds = add_token_embeds[:len(add_token_ids)] 184 | index.add_with_ids(x=add_token_embeds,ids=np.array(add_token_ids)) 185 | 186 | assert all_num_tokens == index.ntotal 187 | 188 | self.save_index(index, index_dir, code_size, nprobe) 189 | 190 | class FaissSearcher(object): 191 | def search_batched(self, query_embeds, 
final_num_neighbors, **kwargs): 192 | scores, top_ids = index.search(query_embeds, final_num_neighbors) 193 | return top_ids, scores 194 | self.searcher = FaissSearcher() 195 | self.docs = documents 196 | 197 | print("Index Ready!", self.searcher) 198 | return index.ntotal 199 | 200 | 201 | def save_index(self, index, index_dir, code_size, nprobe): 202 | os.makedirs(index_dir, exist_ok=True) 203 | index_path = f"{index_dir}/cs{code_size}.index" 204 | write_index(index, index_path) 205 | 206 | def load_index(self, documents, index_dir, code_size, nprobe): 207 | self.docs = documents 208 | 209 | index_path = f"{index_dir}/cs{code_size}.index" 210 | 211 | index = read_index(index_path) 212 | index_ivf = faiss.extract_index_ivf(index) 213 | index_ivf.nprobe = nprobe 214 | 215 | class FaissSearcher(object): 216 | def search_batched(self, query_embeds, final_num_neighbors, **kwargs): 217 | scores, top_ids = index.search(query_embeds, final_num_neighbors) 218 | return top_ids, scores 219 | self.searcher = FaissSearcher() 220 | 221 | return index.ntotal 222 | 223 | def batch_search_tokens(self, batch_query, token_top_k=100): 224 | all_query_encodings, query_offsets = self.get_flatten_embeddings(batch_query, return_last_offset=True) 225 | all_neighbors, all_scores = self.searcher.search_batched(all_query_encodings, final_num_neighbors=token_top_k) 226 | 227 | return [ 228 | ( 229 | [f'q_{i}' for i in range(query_offsets[oid], query_offsets[oid+1])], # query_id 230 | all_neighbors[query_offsets[oid]:query_offsets[oid+1]], # neighbors 231 | all_scores[query_offsets[oid]:query_offsets[oid+1]], # scores 232 | ) 233 | for oid in range(len(query_offsets)-1) 234 | ] 235 | 236 | def estimate_missing_similarity(self, batch_result): 237 | batch_qtoken_to_ems = [dict() for _ in range(len(batch_result))] 238 | for b_idx, (query_tokens, _, distances) in enumerate(batch_result): 239 | for token_idx, qtoken in enumerate(query_tokens): 240 | idx_t = (token_idx, qtoken) 241 | # Use 
    def aggregate_scores(self, batch_result, batch_ems, document_top_k):
        """Aggregates token-level retrieval scores into query-document scores."""

        def get_did2scores(query_tokens, all_neighbors, all_scores):
            # doc id -> {(query token idx, query token id) -> best score seen}
            did2scores = {}
            # |Q| x k'
            for qtoken_idx, (qtoken, neighbors, scores) in enumerate(zip(query_tokens, all_neighbors, all_scores)):
                for _, (doc_token_id, score) in enumerate(zip(neighbors, scores)):
                    if np.isnan(score):
                        continue

                    # Stored ids encode (doc id * doc_offset + token position);
                    # integer division recovers the document id.
                    docid = doc_token_id // self.doc_offset
                    if docid not in did2scores:
                        did2scores[docid] = {}
                    qtoken_with_idx = (qtoken_idx, qtoken)
                    if qtoken_with_idx not in did2scores[docid]:
                        # Only keep the top score for sum-of-max.
                        did2scores[docid][qtoken_with_idx] = score

            return did2scores
        batch_did2scores = [get_did2scores(qtokens, neighbors, scores) for qtokens, neighbors, scores in batch_result]

        def add_ems(did2scores, query_tokens, ems):
            # Fill in missing (query token, document) pairs with the imputed
            # similarity so every document is scored over all query tokens.
            # |Q| x |Q|k' (assuming most docid is unique)
            for qtoken_idx, qtoken in enumerate(query_tokens):
                qtoken_with_idx = (qtoken_idx, qtoken)
                for docid, scores in did2scores.items():
                    if qtoken_with_idx not in scores:
                        scores[qtoken_with_idx] = ems[qtoken_with_idx]
        for did2scores, result, ems in zip(batch_did2scores, batch_result, batch_ems):
            add_ems(did2scores, result[0], ems)

        def get_final_score(did2scores, query_tokens):
            # Sum-of-max normalized by the number of query tokens.
            final_qd_score = {}
            # |Q|k' x |Q|
            for docid, scores in did2scores.items():
                assert len(scores) == len(query_tokens)
                final_qd_score[docid] = sum(scores.values()) / len(scores)
            return final_qd_score

        batch_scores = [get_final_score(did2scores, result[0]) for did2scores, result in zip(batch_did2scores, batch_result)]

        # Rank documents per query by descending score, keeping the top k.
        batch_ranking = [
            sorted([(docid, score) for docid, score in final_qd_score.items()], key=lambda x: x[1], reverse=True)[:document_top_k]
            for final_qd_score in batch_scores
        ]
        return batch_ranking

    def get_document_text(self, batch_ranking):
        # Attach the raw document text to each (doc id, score) ranking entry.
        batch_retrieved_docs = []
        for ranking in batch_ranking:
            retrieved_docs = []
            for did, score in ranking:
                retrieved_docs.append((did, score, self.docs[did]))
            batch_retrieved_docs.append(retrieved_docs)
        return batch_retrieved_docs

    def retrieve_docs(
        self,
        batch_query: List[str],
        token_top_k: int = 100,
        document_top_k: int = 100,
        return_text: bool = True,
    ):
        """Runs XTR retrieval for a query."""
        # Pipeline: token-level search -> impute similarities for missed
        # (token, document) pairs -> aggregate into ranked documents.
        batch_result = self.batch_search_tokens(batch_query, token_top_k=token_top_k)
        batch_mae = self.estimate_missing_similarity(batch_result)
        batch_ranking = self.aggregate_scores(batch_result, batch_mae, document_top_k)
        if return_text:
            return self.get_document_text(batch_ranking), batch_result
        else:
            return batch_ranking, batch_result

# (xtr/utils.py)
import os
from typing import Optional, Union
from huggingface_hub import hf_hub_download

def load_file_path(
    model_name_or_path: str,
    filename: str,
) -> Optional[str]:
    """Resolve `filename` under a local model directory or the HF Hub.

    Returns the local path when it exists; otherwise attempts a Hub
    download. Returns None when neither succeeds.
    """
    # If file is local
    file_path = os.path.join(model_name_or_path, filename)
    if os.path.exists(file_path):
        return file_path

    # If file is remote
    try:
        return hf_hub_download(
            model_name_or_path,
            filename=filename,
        )
    except Exception:
        # Deliberate best-effort: any Hub failure (missing file, no network,
        # auth) yields None; callers must handle a None path.
        return