├── spladerunner
│   ├── __init__.py
│   ├── Config.py
│   └── Expander.py
├── images
│   └── vocproj.png
├── setup.py
├── .gitignore
└── README.md

--------------------------------------------------------------------------------
/spladerunner/__init__.py:
--------------------------------------------------------------------------------
from .Expander import Expander

--------------------------------------------------------------------------------
/images/vocproj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrithivirajDamodaran/SPLADERunner/HEAD/images/vocproj.png

--------------------------------------------------------------------------------
/spladerunner/Config.py:
--------------------------------------------------------------------------------
MODEL_URL = 'https://huggingface.co/prithivida/flashrank/resolve/main/{}.zip'
DEFAULT_CACHE_DIR = "/tmp"
DEFAULT_MODEL = "Splade_PP_en_v1"
MODEL_FILE_MAP = {
    "Splade_PP_en_v1": "model.onnx"
}

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name='SPLADERunner',
    version='0.1.6',
    packages=find_packages(),
    install_requires=[
        'tokenizers',
        'onnxruntime',
        'numpy',
        'requests',
        'tqdm'
    ],
    author='Prithivi Da',
    author_email='',
    description='Ultra-light and fast wrapper for the independent implementation of SPLADE++ models for your search & retrieval pipelines. Models and library created by Prithivi Da. For PRs and collaboration, check out the README.',
    long_description=open('README.md').read(),
    long_description_content_type='text/markdown',
    url='https://github.com/PrithivirajDamodaran/SPLADERunner',
    license='Apache 2.0',
    classifiers=[
        'Programming Language :: Python :: 3',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.6',
)

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SPLADERunner

## 1. What is it?
> The title is a tribute to the original Blade Runners: Harrison Ford, and Philip K. Dick, the author of "Do Androids Dream of Electric Sheep?".

An ultra-lite & super-fast Python wrapper for the [independent implementation of SPLADE++ models](https://huggingface.co/prithivida/Splade_PP_en_v1) for your search & retrieval pipelines. Based on Naver's paper [From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective](https://arxiv.org/pdf/2205.04733.pdf) and Google's [SparseEmbed](https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/79f16d3b3b948706d191a7fe6dd02abe516f5564.pdf).

- ⚡ **Lightweight**:
    - **No Torch or Transformers** needed.
    - **Runs on CPU** for query or passage expansion.
- **FLOPS & retrieval efficient**: refer to the model card for details.


## 🚀 Installation:

```bash
pip install spladerunner
```

## Usage:
```python
# One-time-only init: pass the model name and max_seq_len
from spladerunner import Expander
expander = Expander('Splade_PP_en_v1', 128)

# Sample passage expansion
sparse_rep = expander.expand("The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.")

# For Solr, Elasticsearch or vanilla Lucene stores, use the "lucene" output format.
sparse_rep = expander.expand("The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.", outformat="lucene")

print(sparse_rep)
```
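To illustrate how the `lucene` output format can be consumed, here is a minimal sketch of indexing the expanded terms into Elasticsearch as a `rank_features` field. The index name, field names, and cluster URL are illustrative assumptions, not part of this library:

```python
# Hypothetical sketch: pushing SPLADE term weights into Elasticsearch.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local cluster

# `rank_features` stores a token -> weight map usable for scoring at query time.
es.indices.create(index="docs", mappings={
    "properties": {
        "text": {"type": "text"},
        "splade": {"type": "rank_features"},
    }
})

doc_text = "The Manhattan Project and its atomic bomb helped bring an end to World War II."
weights = expander.expand(doc_text, outformat="lucene")[0]  # {"term": weight, ...}
es.index(index="docs", document={"text": doc_text, "splade": weights})
```

At query time, each expanded query term can then become a `rank_feature` clause inside a `bool` `should` query.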
(Feel free to skip to section 3 if you are an expert in sparse and dense representations.)

## 2. Why Sparse Representations?

- **Lexical search** with BoW-based sparse vectors is a strong baseline, but it famously suffers from the **vocabulary mismatch** problem, as it can only do exact term matching.

    Pros:

    - ✅ Efficient and cheap.
    - ✅ No need to fine-tune models.
    - ✅ Interpretable.
    - ✅ Exact term matches.

    Cons:

    - ❌ Vocabulary mismatch (you need to remember exact terms).
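    As a toy illustration of vocabulary mismatch (a made-up example, not part of this library), exact term matching finds no overlap for a synonym the user didn't type:

    ```python
    # Hypothetical example: BoW exact matching misses synonyms entirely.
    doc = set("the automobile needs a new battery".split())
    query = set("car battery replacement".split())

    print(query & doc)  # {'battery'} -- "car" never matches "automobile"
    ```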
- **Semantic search** with learned neural / dense retrievers and approximate nearest neighbor search has shown impressive results, but it comes with its own trade-offs.
65 | 66 | Pros 67 | 68 | ✅ Search how humans innately think. 69 | ✅ When finetuned beats sparse by long way. 70 | ✅ Easily works with Multiple modals. 71 | 72 | Cons 73 | 74 | ❌ Suffers token amnesia (misses term matching), 75 | ❌ Resource intensive (both index & retreival), 76 | ❌ Famously hard to interpret. 77 | ❌ Needs fine-tuning for OOD data. 78 | 79 |
- Getting the pros of both kinds of search made sense, and that gave rise to interest in **learning sparse representations** for queries and documents with some interpretability. The sparse representations also double as implicit or explicit (latent, contextualized) expansion mechanisms for both queries and documents. If you are new to query expansion, [learn more here from Daniel Tunkelang](https://queryunderstanding.com/query-expansion-2d68d47cf9c8).

2a. **What do the models learn?**

- The model learns to project its dense representations through an MLM head to give a distribution over the vocabulary.

![Vocabulary projection](./images/vocproj.png)
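Concretely, here is a minimal NumPy sketch of that aggregation, matching what `Expander.expand` does in `spladerunner/Expander.py` below. The logits are random placeholders used only to show shapes and operations; with a trained model most logits are negative, so only a small set of terms survives:

```python
import numpy as np

# Placeholder MLM-head logits for one passage: (seq_len, vocab_size).
mlm_logits = np.random.randn(64, 30522)

# log(1 + ReLU(x)): the saturation SPLADE-style models apply to the logits.
term_weights = np.log1p(np.maximum(mlm_logits, 0))

# Max-pool over the sequence: one weight per vocabulary term.
doc_vector = term_weights.max(axis=0)

# Keep only the non-zero entries as the sparse representation.
indices = np.nonzero(doc_vector)[0]
values = doc_vector[indices]
```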
## 3. 💸 **Why SPLADERunner?**

- **$-conscious**: serverless deployments like AWS Lambda are charged by memory and time per invocation.
- A smaller package size means shorter cold-start times and quicker re-deployments for serverless.

## 4. 🎯 **Models**

- Below is the list of models supported as of now:
    * [`prithivida/Splade_PP_en_v1`](https://huggingface.co/prithivida/Splade_PP_en_v1) (default model)

4a. 💸 **Where and how can you use it?**
- [TBD]

4b. **How (and what) to contribute?**
- [TBD]

## 5. **Criticisms of and competition to SPLADE and learned sparse representations**

- [Wacky Weights in Learned Sparse Representations and the Revenge of Score-at-a-Time Query Evaluation](https://arxiv.org/pdf/2110.11540.pdf)
- [Query2doc: Query Expansion with Large Language Models](https://arxiv.org/pdf/2303.07678.pdf)
  *Note: don't mistake this for docT5query; this is more recent work.*

- *Thanks to [Nils Reimers](https://www.linkedin.com/in/reimersnils/) for:*
    - The trolls :-) and timely inputs around evaluation.
- *Props to the Naver folks, the original authors of the paper, for such robust research.*

--------------------------------------------------------------------------------
/spladerunner/Expander.py:
--------------------------------------------------------------------------------
import json
from pathlib import Path
from tokenizers import AddedToken, Tokenizer
import onnxruntime as ort
import numpy as np
import os
import zipfile
import requests
from tqdm import tqdm
from spladerunner.Config import DEFAULT_MODEL, DEFAULT_CACHE_DIR, MODEL_URL, MODEL_FILE_MAP
import collections


class Expander:

    def __init__(self,
                 model_name=DEFAULT_MODEL,
                 max_length=64,
                 cache_dir=DEFAULT_CACHE_DIR
                 ):

        self.cache_dir = Path(cache_dir)

        if not self.cache_dir.exists():
            print(f"Cache directory {self.cache_dir} not found. Creating it...")
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.model_dir = self.cache_dir / model_name

        if not self.model_dir.exists():
            print(f"Downloading {model_name}...")
            self._download_model_files(model_name)

        model_file = MODEL_FILE_MAP[model_name]

        self.session = ort.InferenceSession(str(self.model_dir / model_file))
        self.tokenizer = self._get_tokenizer(max_length)
        self.reverse_voc = {v: k for k, v in self.tokenizer.get_vocab().items()}

    def _download_model_files(self, model_name):

        # The local file path to which the file should be downloaded
        local_zip_file = self.cache_dir / f"{model_name}.zip"

        formatted_model_url = MODEL_URL.format(model_name)

        with requests.get(formatted_model_url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            with open(local_zip_file, 'wb') as f, tqdm(
                    desc=local_zip_file.name,
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
            ) as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    size = f.write(chunk)
                    bar.update(size)

        # Extract the zip file
        with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
            zip_ref.extractall(self.cache_dir)

        # Remove the zip file after extraction
        os.remove(local_zip_file)

    def _load_vocab(self, vocab_file):

        vocab = collections.OrderedDict()
        with open(vocab_file, "r", encoding="utf-8") as reader:
            tokens = reader.readlines()
        for index, token in enumerate(tokens):
            token = token.rstrip("\n")
            vocab[token] = index
        return vocab

    def _get_tokenizer(self, max_length):

        config_path = self.model_dir / "config.json"
        tokenizer_path = self.model_dir / "tokenizer.json"
        tokenizer_config_path = self.model_dir / "tokenizer_config.json"
        tokens_map_path = self.model_dir / "special_tokens_map.json"

        # Check for file existence
        for path in [config_path, tokenizer_path, tokenizer_config_path, tokens_map_path]:
            if not path.exists():
                raise FileNotFoundError(f"{path.name} missing in {self.model_dir}")

        with open(config_path) as f:
            config = json.load(f)
        with open(tokenizer_config_path) as f:
            tokenizer_config = json.load(f)
        with open(tokens_map_path) as f:
            tokens_map = json.load(f)

        tokenizer = Tokenizer.from_file(str(tokenizer_path))
        tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], max_length))
        tokenizer.enable_padding(pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"])

        for token in tokens_map.values():
            if isinstance(token, str):
                tokenizer.add_special_tokens([token])
            elif isinstance(token, dict):
                tokenizer.add_special_tokens([AddedToken(**token)])

        # vocab_file = self.model_dir / "vocab.txt"
        # if vocab_file.exists():
        #     tokenizer.vocab = self._load_vocab(vocab_file)
        #     tokenizer.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in tokenizer.vocab.items()])

        return tokenizer

    def expand(self, request, outformat="hybrid"):

        # Accept a single string or a list of strings
        if isinstance(request, str):
            plain_input = [request]
        else:
            plain_input = request

        encoded_input = self.tokenizer.encode_batch(plain_input)
        input_ids = np.array([e.ids for e in encoded_input], dtype=np.int64)
        token_type_ids = np.array([e.type_ids for e in encoded_input], dtype=np.int64)
        attention_mask = np.array([e.attention_mask for e in encoded_input], dtype=np.int64)

        onnx_input = {
            "input_ids": input_ids,
            "input_mask": attention_mask,
            "segment_ids": token_type_ids,
        }

        outputs = self.session.run(None, onnx_input)[0]

        # Apply log-saturated ReLU to the MLM logits: log(1 + ReLU(x))
        relu_log = np.log1p(np.maximum(outputs, 0))

        # Zero out padding positions via the attention mask
        weighted_log = relu_log * attention_mask[:, :, np.newaxis]

        # Initialize the list for sparse representations
        sparse_representations = []

        # Iterate over each example in the batch
        for i in range(outputs.shape[0]):
            single_weighted_log = weighted_log[i]

            # Max-pool term weights over the sequence for the current example
            max_val = np.max(single_weighted_log, axis=0)

            # Find non-zero vocabulary entries for the current example
            example_cols = np.nonzero(max_val)[0]
            weights = max_val[example_cols].tolist()

            if outformat == "lucene":
                # Create a term -> weight dictionary sorted by descending weight
                d = dict(zip(example_cols, weights))
                sorted_d = dict(sorted(d.items(), key=lambda item: item[1], reverse=True))

                # Construct the SPLADE BoW representation for the current sentence
                sparse_representation = {self.reverse_voc[k]: round(v, 2) for k, v in sorted_d.items()}
                sparse_representations.append(sparse_representation)
            else:
                sparse_representations.append({"indices": example_cols.tolist(), "values": weights})

        return sparse_representations

--------------------------------------------------------------------------------