├── tests
│   ├── __init__.py
│   ├── test_scorer.py
│   └── test_basic.py
├── MANIFEST.in
├── pyt_splade
│   ├── .DS_Store
│   ├── __init__.py
│   ├── pt_docs
│   │   ├── api.rst
│   │   └── index.rst
│   ├── _utils.py
│   ├── _scorer.py
│   ├── _encoder.py
│   └── _model.py
├── setup.py
├── .github
│   └── workflows
│       └── push.yml
├── README.md
├── .gitignore
├── trec-covid.ipynb
└── msmarco-psg-v1.ipynb
/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include pyt_splade *.rst 2 | -------------------------------------------------------------------------------- /pyt_splade/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmacdonald/pyt_splade/HEAD/pyt_splade/.DS_Store -------------------------------------------------------------------------------- /pyt_splade/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.2' 2 | 3 | from pyt_splade._model import Splade 4 | from pyt_splade._encoder import SpladeEncoder 5 | from pyt_splade._scorer import SpladeScorer 6 | from pyt_splade._utils import Toks2Doc 7 | 8 | SpladeFactory = Splade # backward compatible name 9 | toks2doc = Toks2Doc # backward compatible name 10 | 11 | __all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc'] 12 | -------------------------------------------------------------------------------- /pyt_splade/pt_docs/api.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ========================================== 3 | 4 | :class:`~pyt_splade.Splade` is the primary way to interact with this package: 5 | 6 | .. autoclass:: pyt_splade.Splade 7 | :members: 8 | 9 | Utils / Internals 10 | ------------------------------------------ 11 | 12 | .. autoclass:: pyt_splade.Toks2Doc 13 | :members: 14 | 15 | .. autoclass:: pyt_splade.SpladeEncoder 16 | :members: 17 | 18 | .. autoclass:: pyt_splade.SpladeScorer 19 | :members: 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md') as f: 4 | readme = f.read() 5 | 6 | setup( 7 | name='pyt_splade', 8 | version='0.0.2', 9 | description='PyT wrapper for SPLADE', 10 | url='https://github.com/cmacdonald/pyt_splade', 11 | classifiers=[ 12 | 'Intended Audience :: Science/Research', 13 | 'Programming Language :: Python :: 3.7', 14 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 15 | ], 16 | packages=['pyt_splade'] + ['pyt_splade.'
+ i for i in find_packages('pyt_splade')], 17 | # as per splade 18 | include_package_data=True, 19 | license="Creative Commons Attribution-NonCommercial-ShareAlike", 20 | long_description=readme, 21 | install_requires=[ 22 | 'torch>=2.6.0', 'transformers', 'python-terrier>=0.11.0', 'pyterrier_alpha', 23 | ], 24 | ) 25 | -------------------------------------------------------------------------------- /pyt_splade/_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pyterrier as pt 3 | 4 | 5 | class Toks2Doc(pt.Transformer): 6 | """Converts a toks field into a text field, by scaling the weights by ``mult`` and repeating them.""" 7 | def __init__(self, mult: float = 100.): 8 | """Initializes the transformer. 9 | 10 | Args: 11 | mult: the multiplier to apply to the term frequencies 12 | """ 13 | self.mult = mult 14 | 15 | def transform(self, inp: pd.DataFrame) -> pd.DataFrame: 16 | """Converts the toks field into a text field.""" 17 | res = inp.assign(text=inp['toks'].apply(self._dict_tf2text)) 18 | res.drop(columns=['toks'], inplace=True) 19 | return res 20 | 21 | def _dict_tf2text(self, tfdict): 22 | rtr = "" 23 | for t in tfdict: 24 | for i in range(int(self.mult * tfdict[t])): 25 | rtr += t + " " 26 | return rtr 27 | -------------------------------------------------------------------------------- /.github/workflows/push.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | strategy: 13 | matrix: 14 | python-version: ['3.9', '3.12'] 15 | java: [11] 16 | os: ['ubuntu-latest'] 17 | architecture: ['x64'] 18 | terrier: ['snapshot'] 19 | 20 | runs-on: ${{ matrix.os }} 21 | steps: 22 | - uses: actions/checkout@v4 23 | 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | 29 | - name: Setup java 30 | uses: actions/setup-java@v4 31 | with: 32 | java-version: ${{ matrix.java }} 33 | distribution: 'zulu' 34 | 35 | - name: Install Python dependencies 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install --timeout=120 . 
torch pytest 39 | 40 | - name: All unit tests 41 | env: 42 | TERRIER_VERSION: ${{ matrix.terrier }} 43 | run: | 44 | pytest 45 | -------------------------------------------------------------------------------- /tests/test_scorer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pandas as pd 3 | import pyt_splade 4 | 5 | class TestScorer(unittest.TestCase): 6 | 7 | def setUp(self): 8 | self.splade = pyt_splade.Splade(device='cpu') 9 | 10 | def test_scorer(self): 11 | df = self.splade.scorer()(pd.DataFrame([ 12 | {'qid': '0', 'query': 'chemical reactions', 'docno' : 'd1', 'text' : 'hello there'}, 13 | {'qid': '0', 'query': 'chemical reactions', 'docno' : 'd2', 'text' : 'chemistry society'}, 14 | {'qid': '1', 'query': 'hello', 'docno' : 'd1', 'text' : 'hello there'}, 15 | ])) 16 | self.assertAlmostEqual(0., df['score'][0]) 17 | self.assertAlmostEqual(11.133593, df['score'][1], places=4) 18 | self.assertAlmostEqual(17.566324, df['score'][2], places=3) 19 | self.assertEqual('0', df['qid'][0]) 20 | self.assertEqual('0', df['qid'][1]) 21 | self.assertEqual('1', df['qid'][2]) 22 | self.assertEqual('d1', df['docno'][0]) 23 | self.assertEqual('d2', df['docno'][1]) 24 | self.assertEqual('d1', df['docno'][2]) 25 | self.assertEqual(1, df['rank'][0]) 26 | self.assertEqual(0, df['rank'][1]) 27 | self.assertEqual(0, df['rank'][2]) 28 | -------------------------------------------------------------------------------- /pyt_splade/_scorer.py: -------------------------------------------------------------------------------- 1 | import more_itertools 2 | import pandas as pd 3 | import numpy as np 4 | import pyterrier as pt 5 | import pyterrier_alpha as pta 6 | 7 | class SpladeScorer(pt.Transformer): 8 | """Scores (re-ranks) documents against queries using a SPLADE model.""" 9 | def __init__(self, splade, text_field, batch_size=100, verbose=False): 10 | """Initializes the SPLADE scorer. 
11 | 12 | Args: 13 | splade: :class:`pyt_splade.Splade` instance 14 | text_field: the text field to score 15 | batch_size: the batch size to use when scoring 16 | verbose: if True, show a progress bar 17 | """ 18 | self.splade = splade 19 | self.text_field = text_field 20 | self.batch_size = batch_size 21 | self.verbose = verbose 22 | 23 | def transform(self, df: pd.DataFrame) -> pd.DataFrame: 24 | """Scores (re-ranks) the documents against the queries in the input DataFrame.""" 25 | pta.validate.result_frame(df, ['query', self.text_field]) 26 | it = df.groupby('query') 27 | if self.verbose: 28 | it = pt.tqdm(it, unit='query') 29 | res = [] 30 | for query, df in it: 31 | query_enc = self.splade.encode([query], 'q', 'torch') 32 | scores = [] 33 | for batch in more_itertools.chunked(df[self.text_field], self.batch_size): 34 | doc_enc = self.splade.encode(batch, 'd', 'torch') 35 | scores.append((query_enc @ doc_enc.T).flatten().cpu().numpy()) 36 | res.append(df.assign(score=np.concatenate(scores))) 37 | res = pd.concat(res) 38 | from pyterrier.model import add_ranks 39 | res = add_ranks(res) 40 | return res 41 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pandas as pd 3 | from unittest.mock import MagicMock 4 | import pyt_splade 5 | import pyterrier as pt 6 | 7 | class TestBasic(unittest.TestCase): 8 | 9 | def setUp(self): 10 | self.splade = pyt_splade.Splade(device='cpu') 11 | 12 | def test_transformer_indexing(self): 13 | df = (self.splade.doc_encoder() >> pyt_splade.toks2doc())(pd.DataFrame([{'docno' : 'd1', 'text' : 'hello there'}])) 14 | self.assertTrue('there there' in df.iloc[0].text) 15 | df = self.splade.doc_encoder()([ 16 | {'docno' : 'd1', 'text' : 'hello there'}, 17 | {'docno' : 'd1', 'text' : ''}, #empty 18 | {'docno' : 'd1', 'text' : 'hello hello hello hello hello there'}]) 19 | 20 | def test_transformer_querying(self): 21 | q = self.splade.query_encoder() 22 | df = q(pd.DataFrame([{'qid' : 'q1', 'query' : 'chemical reactions'}])) 23 | self.assertTrue('query_toks' in df.columns) 24 | 25 | def test_transformer_empty_query(self): 26 | q = self.splade.query_encoder() 27 | self.assertEqual([["qid", "query"]], pt.inspect.transformer_inputs(q)) 28 | self.assertEqual(["qid", "query", "query_toks"], pt.inspect.transformer_outputs(q, ["qid", "query"])) 29 | res = q(pd.DataFrame([], columns=['qid', 'query'])) 30 | self.assertEqual(['qid', 'query', 'query_toks'], list(res.columns)) 31 | 32 | def test_transformer_empty_doc(self): 33 | d = self.splade.doc_encoder() 34 | self.assertEqual([["docno", "text"]], pt.inspect.transformer_inputs(d)) 35 | self.assertEqual(["docno", "text", "toks"], pt.inspect.transformer_outputs(d, ["docno", "text"])) 36 | res = d(pd.DataFrame([], columns=['docno', 'text'])) 37 | self.assertEqual(['docno', 'text', 'toks'], list(res.columns)) 38 | -------------------------------------------------------------------------------- /pyt_splade/_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | import more_itertools 3 | import pandas as pd 4 | import pyterrier as pt 5 | import pyterrier_alpha as pta 6 | import pyt_splade 7 | 8 | 9 | class SpladeEncoder(pt.Transformer): 10 | """Encodes a text field using a SPLADE model. 
The output is a dense or sparse representation of the text field.""" 11 | 12 | def __init__( 13 | self, 14 | splade: pyt_splade.Splade, 15 | text_field: str, 16 | out_field: str, 17 | rep: Literal['q', 'd'], 18 | sparse: bool = True, 19 | batch_size: int = 100, 20 | verbose: bool = False, 21 | scale: float = 1., 22 | ): 23 | """Initializes the SPLADE encoder. 24 | 25 | Args: 26 | splade: :class:`pyt_splade.Splade` instance 27 | text_field: the input text field to encode 28 | out_field: the output field to store the encoded representation 29 | rep: 'q' for query, 'd' for document 30 | sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector 31 | batch_size: the batch size to use when encoding 32 | verbose: if True, show a progress bar 33 | scale: the scale to apply to the term frequencies 34 | """ 35 | self.splade = splade 36 | self.text_field = text_field 37 | self.out_field = out_field 38 | self.rep = rep 39 | self.sparse = sparse 40 | self.batch_size = batch_size 41 | self.verbose = verbose 42 | self.scale = scale 43 | 44 | def transform(self, df: pd.DataFrame) -> pd.DataFrame: 45 | """Encodes the text field in the input DataFrame.""" 46 | if self.text_field == 'query': 47 | pta.validate.query_frame(df, extra_columns=[self.text_field]) 48 | else: 49 | pta.validate.document_frame(df, extra_columns=[self.text_field]) 50 | it = iter(df[self.text_field]) 51 | if self.verbose: 52 | it = pt.tqdm(it, total=len(df), unit=self.text_field) 53 | res = [] 54 | for batch in more_itertools.chunked(it, self.batch_size): 55 | res.extend(self.splade.encode(batch, self.rep, format='dict' if self.sparse else 'np', scale=self.scale)) 56 | return df.assign(**{self.out_field: res}) 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pyterrier_splade 2 | 3 | An example of SPLADE learned sparse indexing and retrieval using PyTerrier transformers. 4 | 5 | # Installation 6 | 7 | ```python 8 | %pip install -q git+https://github.com/cmacdonald/pyt_splade.git 9 | ``` 10 | 11 | # Indexing 12 | 13 | Indexing takes place as a pipeline: we apply a SPLADE transformation to the documents, which maps raw text into a dictionary of BERT WordPiece tokens and corresponding weights. The underlying indexer, Terrier, is configured to handle arbitrary word counts without further tokenisation (`pretokenised=True`). 14 | 15 | The Terrier indexer then indexes these tokens and weights unchanged. 16 | 17 | ```python 18 | 19 | import pyterrier as pt 20 | 21 | import pyt_splade 22 | splade = pyt_splade.Splade() 23 | indexer = pt.IterDictIndexer('./msmarco_psg', pretokenised=True) 24 | dataset = pt.get_dataset('irds:msmarco-passage') 25 | indxr_pipe = splade.doc_encoder() >> indexer 26 | index_ref = indxr_pipe.index(dataset.get_corpus_iter(), batch_size=128) 27 | 28 | ``` 29 | 30 | # Retrieval 31 | 32 | Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights. 33 | We apply this as a query encoding transformer. 34 | 35 | ```python 36 | 37 | splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf') 38 | 39 | ``` 40 | 41 | # Scoring 42 | 43 | SPLADE can also be used as a text scoring function. 44 | 45 | ```python 46 | 47 | first_stage = ... # e.g., BM25, dense retrieval, etc.
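# for example (illustrative only), a BM25 first stage over the index built above: first_stage = pt.terrier.Retriever('./msmarco_psg', wmodel='BM25')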
48 | splade_scorer = first_stage >> pt.text.get_text(dataset, 'text') >> splade.scorer() 49 | 50 | ``` 51 | 52 | # PISA 53 | 54 | For faster retrieval with SPLADE, you can use the PISA retrieval backend provided by [PyTerrier_PISA](https://github.com/terrierteam/pyterrier_pisa): 55 | 56 | ```python 57 | import pyterrier as pt 58 | import pyt_splade 59 | from pyterrier_pisa import PisaIndex 60 | splade = pyt_splade.Splade() 61 | dataset = pt.get_dataset('irds:msmarco-passage') 62 | index = PisaIndex('./msmarco-passage-splade', stemmer='none') 63 | 64 | # indexing 65 | idx_pipeline = splade.doc_encoder() >> index.toks_indexer() 66 | idx_pipeline.index(dataset.get_corpus_iter()) 67 | # retrieval 68 | retr_pipeline = splade.query_encoder() >> index.quantized() 69 | ``` 70 | 71 | # Demo 72 | 73 | We have a demo of PyTerrier_SPLADE at https://huggingface.co/spaces/terrierteam/splade 74 | 75 | # Credits 76 | 77 | - Craig Macdonald 78 | - Sean MacAvaney 79 | - Nicola Tonellotto -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | .vscode/ 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ -------------------------------------------------------------------------------- /pyt_splade/pt_docs/index.rst: -------------------------------------------------------------------------------- 1 | SPLADE + PyTerrier 2 | ========================================== 3 | 4 | SPLADE indexing and retrieval using PyTerrier transformers. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | API Documentation <api> 10 | 11 | Installation 12 | --------------------------------------------- 13 | 14 | We use `Naver's SPLADE repository <https://github.com/naver/splade>`__ as a dependency: 15 | 16 | .. code-block:: console 17 | :caption: Install PyTerrier SPLADE 18 | 19 | $ pip install -q python-terrier 20 | $ pip install -q git+https://github.com/naver/splade.git git+https://github.com/cmacdonald/pyt_splade.git 21 | 22 | 23 | Indexing 24 | --------------------------------------------- 25 | 26 | Indexing takes place as a pipeline: we apply a SPLADE transformation to the documents, which maps raw text into a dictionary of BERT WordPiece tokens and corresponding weights. The underlying indexer, Terrier, is configured to handle arbitrary word counts without further tokenisation (``pretokenised=True``). 27 | 28 | The Terrier indexer then indexes these tokens and weights unchanged. 29 | 30 | .. code-block:: python 31 | :caption: Indexing with PyTerrier SPLADE 32 | 33 | import pyterrier as pt 34 | 35 | import pyt_splade 36 | splade = pyt_splade.Splade() 37 | indexer = pt.IterDictIndexer('./msmarco_psg', pretokenised=True) 38 | dataset = pt.get_dataset('irds:msmarco-passage') 39 | indxr_pipe = splade.doc_encoder() >> indexer 40 | index_ref = indxr_pipe.index(dataset.get_corpus_iter(), batch_size=128) 41 | 42 | 43 | Retrieval 44 | --------------------------------------------- 45 | 46 | Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights. 47 | We apply this as a query encoding transformer. 48 | 49 | .. code-block:: python 50 | 51 | splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf') 52 | 53 | 54 | Scoring 55 | --------------------------------------------- 56 | 57 | SPLADE can also be used as a text scoring function. 58 | 59 | .. code-block:: python 60 | 61 | first_stage = ... # e.g., BM25, dense retrieval, etc. 62 | splade_scorer = first_stage >> pt.text.get_text(dataset, 'text') >> splade.scorer() 63 | 64 | 65 | PISA 66 | --------------------------------------------- 67 | 68 | For faster retrieval with SPLADE, you can use the PISA retrieval backend provided by `PyTerrier_PISA <https://github.com/terrierteam/pyterrier_pisa>`__: 69 | 70 | ..
code-block:: python 71 | 72 | import pyterrier as pt 73 | import pyt_splade 74 | from pyterrier_pisa import PisaIndex 75 | splade = pyt_splade.Splade() 76 | dataset = pt.get_dataset('irds:msmarco-passage') 77 | index = PisaIndex('./msmarco-passage-splade', stemmer='none') 78 | # indexing 79 | idx_pipeline = splade.doc_encoder() >> index.toks_indexer() 80 | idx_pipeline.index(dataset.get_corpus_iter()) 81 | # retrieval 82 | retr_pipeline = splade.query_encoder() >> index.quantized() 83 | 84 | 85 | Demo 86 | --------------------------------------------- 87 | 88 | We have a demo of PyTerrier_SPLADE at https://huggingface.co/spaces/terrierteam/splade 89 | 90 | 91 | Credits 92 | --------------------------------------------- 93 | 94 | - Craig Macdonald 95 | - Sean MacAvaney 96 | -------------------------------------------------------------------------------- /pyt_splade/_model.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Literal, Dict 2 | import torch 3 | import numpy as np 4 | import pyterrier as pt 5 | import pyt_splade 6 | 7 | class Splade: 8 | """A SPLADE model, which provides transformers for sparse encoding documents and queries, and scoring documents.""" 9 | 10 | def __init__( 11 | self, 12 | model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil", 13 | tokenizer=None, 14 | agg='max', 15 | max_length=256, 16 | device=None 17 | ): 18 | """Initializes the SPLADE model. 19 | 20 | Args: 21 | model: the SPLADE model to use, either a PyTorch model or a string to load from HuggingFace 22 | tokenizer: the tokenizer to use, if not included in the model 23 | agg: the aggregation function to use for the SPLADE model 24 | max_length: the maximum length of the input sequences 25 | device: the device to use, e.g. 'cuda' or 'cpu' 26 | """ 27 | self.max_length = max_length 28 | self.model = model 29 | self.tokenizer = tokenizer 30 | if device is None: 31 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 32 | else: 33 | self.device = torch.device(device) 34 | if isinstance(model, str): 35 | from transformers import AutoModelForMaskedLM 36 | if self.tokenizer is None: 37 | from transformers import AutoTokenizer 38 | self.tokenizer = AutoTokenizer.from_pretrained(model) 39 | self.model = AutoModelForMaskedLM.from_pretrained(model) 40 | self.agg = agg 41 | self.model.output_dim = self.model.config.vocab_size 42 | self.model.eval() 43 | self.model = self.model.to(self.device) 44 | else: 45 | if self.tokenizer is None: 46 | raise ValueError("you must specify tokenizer if passing a model") 47 | 48 | self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()} 49 | 50 | def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer: 51 | """Returns a transformer that encodes a text field into a document representation.
52 | 53 | Args: 54 | text_field: the text field to encode 55 | batch_size: the batch size to use when encoding 56 | sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector 57 | verbose: if True, show a progress bar 58 | scale: the scale to apply to the term frequencies 59 | """ 60 | out_field = 'toks' if sparse else 'doc_vec' 61 | return pyt_splade.SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale) 62 | 63 | indexing = doc_encoder # backward compatible name 64 | 65 | def query_encoder(self, batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer: 66 | """Returns a transformer that encodes a query field into a query representation. 67 | 68 | Args: 69 | batch_size: the batch size to use when encoding 70 | sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector 71 | verbose: if True, show a progress bar 72 | scale: the scale to apply to the term frequencies 73 | """ 74 | out_field = 'query_toks' if sparse else 'query_vec' 75 | res = pyt_splade.SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale) 76 | return res 77 | 78 | query = query_encoder # backward compatible name 79 | 80 | def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer: 81 | """Returns a transformer that scores documents against queries. 82 | 83 | Args: 84 | text_field: the text field to score 85 | batch_size: the batch size to use when scoring 86 | verbose: if True, show a progress bar 87 | """ 88 | return pyt_splade.SpladeScorer(self, text_field, batch_size, verbose) 89 | 90 | def encode( 91 | self, 92 | texts: List[str], 93 | rep: Literal['d', 'q'] = 'd', 94 | format: Literal['dict', 'np', 'torch'] ='dict', 95 | scale: float = 1., 96 | ) -> Union[List[Dict[str, float]], List[np.ndarray], torch.Tensor]: 97 | """Encodes a batch of texts into their SPLADE representations. 98 | 99 | Args: 100 | texts: the list of texts to encode 101 | rep: 'q' for query, 'd' for document 102 | format: 'dict' for a dict of term frequencies, 'np' for a list of numpy arrays, 'torch' for a torch tensor 103 | scale: the scale to apply to the term frequencies 104 | """ 105 | rtr = [] 106 | with torch.no_grad(): 107 | inputs = self.tokenizer( 108 | texts, 109 | add_special_tokens=True, 110 | return_tensors="pt", 111 | padding="longest", 112 | truncation="longest_first", # truncates to max model length, 113 | max_length=self.max_length, 114 | ).to(self.device) 115 | 116 | reps = self.model(**inputs).logits 117 | if self.agg == "sum": 118 | reps = torch.sum(torch.log(1 + torch.relu(reps)) * inputs["attention_mask"].unsqueeze(-1), dim=1) 119 | else: 120 | reps, _ = torch.max(torch.log(1 + torch.relu(reps)) * inputs["attention_mask"].unsqueeze(-1), dim=1) 121 | reps = reps * scale 122 | if format == 'dict': 123 | reps = reps.cpu() 124 | for i in range(reps.shape[0]): 125 | # get the number of non-zero dimensions in the rep: 126 | col = torch.nonzero(reps[i]).squeeze(1).tolist() 127 | # now let's create the bow representation as a dictionary 128 | weights = reps[i, col].cpu().tolist() 129 | # if document cast to int to make the weights ready for terrier indexing 130 | if rep == "d": 131 | weights = list(map(int, weights)) 132 | sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0])) 133 | # create the dict removing the weights less than 1, i.e. 
0, that are not helpful 134 | d = {self.reverse_voc[k]: v for k, v in sorted_weights if v > 0} 135 | rtr.append(d) 136 | elif format == 'np': 137 | reps = reps.cpu().numpy() 138 | for i in range(reps.shape[0]): 139 | rtr.append(reps[i]) 140 | elif format == 'torch': 141 | rtr = reps 142 | return rtr 143 | -------------------------------------------------------------------------------- /trec-covid.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "private_outputs": true, 7 | "provenance": [], 8 | "gpuType": "T4" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "# SPLADE on TREC COVID Corpus using PyTerrier\n", 24 | "\n", 25 | "This notebook demonstrates the creation of a SPLADE index using PyTerrier.\n", 26 | "\n", 27 | "## Installation\n", 28 | "\n", 29 | "Install using pip:" 30 | ], 31 | "metadata": { 32 | "id": "aT_ldO9xB70Y" 33 | } 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": { 39 | "id": "RzJ6orDcAVaw" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "!pip install -q git+https://github.com/tonellotto/pyt_splade@naverless-branch" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "source": [ 49 | "## Setup\n", 50 | "\n", 51 | "We create a factory object `splade` that gives us access to the appropriate transformers to use SPLADE." 52 | ], 53 | "metadata": { 54 | "id": "Muc5TsTpCKYj" 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "source": [ 60 | "import pyterrier as pt\n", 61 | "import pyt_splade\n", 62 | "\n", 63 | "splade = pyt_splade.Splade(device='cuda:0')\n", 64 | "doc_encoder = splade.doc_encoder()" 65 | ], 66 | "metadata": { 67 | "id": "ZsbO1m39BOAc" 68 | }, 69 | "execution_count": null, 70 | "outputs": [] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "source": [ 75 | "## Indexing demo\n", 76 | "\n", 77 | "Let's see what terms are generated by the SPLADE model during indexing." 78 | ], 79 | "metadata": { 80 | "id": "gtgZfso7CTiw" 81 | } 82 | }, 83 | { 84 | "cell_type": "code", 85 | "source": [ 86 | "df = doc_encoder([{'docno' : 'd1', 'text' : 'ww2'}])\n", 87 | "df[0]['toks']" 88 | ], 89 | "metadata": { 90 | "id": "7QMhQucBBbXV" 91 | }, 92 | "execution_count": null, 93 | "outputs": [] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "source": [ 98 | "## Indexing TREC COVID\n", 99 | "\n", 100 | "Let's go and create an index for the TREC COVID corpus. The following will provide access to the dataset:" 101 | ], 102 | "metadata": { 103 | "id": "uVjoexiJCat7" 104 | } 105 | }, 106 | { 107 | "cell_type": "code", 108 | "source": [ 109 | "dataset = pt.get_dataset('irds:beir/trec-covid')" 110 | ], 111 | "metadata": { 112 | "id": "8k-WT0CDBmxS" 113 | }, 114 | "execution_count": null, 115 | "outputs": [] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "source": [ 120 | "This is the actual indexing code. We use the SPLADE model to transform the passages into tokens and weights. The following code took approx. 1 hour to run on Google Colab."
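, "\n", "Note the `pretokenised=True` option below: Terrier then indexes the SPLADE tokens and their integer weights as-is, with no further tokenisation or stemming applied."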
121 | ], 122 | "metadata": { 123 | "id": "Z1WtTORtClki" 124 | } 125 | }, 126 | { 127 | "cell_type": "code", 128 | "source": [ 129 | "import os\n", 130 | "\n", 131 | "if not os.path.exists('./trec_covid'): # skip if already created\n", 132 | " indexer = pt.IterDictIndexer('./trec_covid', pretokenised=True)\n", 133 | " indexer.setProperty(\"termpipelines\", \"\")\n", 134 | " indexer.setProperty(\"tokeniser\", \"WhitespaceTokeniser\")\n", 135 | "\n", 136 | " indexer_pipe = doc_encoder >> indexer\n", 137 | " index_ref = indexer_pipe.index(dataset.get_corpus_iter())" 138 | ], 139 | "metadata": { 140 | "id": "d2rQGu8XChNi" 141 | }, 142 | "execution_count": null, 143 | "outputs": [] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "source": [ 148 | "## Retrieval\n", 149 | "\n", 150 | "We can now conduct retrieval using PyTerrier." 151 | ], 152 | "metadata": { 153 | "id": "YVMmGpkkcTtT" 154 | } 155 | }, 156 | { 157 | "cell_type": "code", 158 | "source": [ 159 | "retr = pt.terrier.Retriever('./trec_covid', wmodel='Tf', verbose=True)\n", 160 | "\n", 161 | "retr_pipe = splade.query_encoder() >> retr" 162 | ], 163 | "metadata": { 164 | "id": "sB8PgyuscQdf" 165 | }, 166 | "execution_count": null, 167 | "outputs": [] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "source": [ 172 | "Let's check that retrieval works, and we can see the generated query." 173 | ], 174 | "metadata": { 175 | "id": "kk02JvMOccrP" 176 | } 177 | }, 178 | { 179 | "cell_type": "code", 180 | "source": [ 181 | "retr_pipe.search('chemical reactions')" 182 | ], 183 | "metadata": { 184 | "id": "9HRA3jUkPy22" 185 | }, 186 | "execution_count": null, 187 | "outputs": [] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "source": [ 192 | "Finally, let's run the experiment and see the resulting performance." 193 | ], 194 | "metadata": { 195 | "id": "FdDJih3Fc0kR" 196 | } 197 | }, 198 | { 199 | "cell_type": "code", 200 | "source": [ 201 | "from pyterrier.measures import *\n", 202 | "pt.Experiment(\n", 203 | " [retr_pipe],\n", 204 | " dataset.get_topics(),\n", 205 | " dataset.get_qrels(),\n", 206 | " eval_metrics=[RR(rel=2), nDCG@10, nDCG@100, AP(rel=2)],\n", 207 | " names=['splade']\n", 208 | ")" 209 | ], 210 | "metadata": { 211 | "id": "0sZCAnSLczrh" 212 | }, 213 | "execution_count": null, 214 | "outputs": [] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "source": [ 219 | "## Exploring the Index" 220 | ], 221 | "metadata": { 222 | "id": "g5BROW5HdXUR" 223 | } 224 | }, 225 | { 226 | "cell_type": "code", 227 | "source": [ 228 | "index = pt.java.cast(\"org.terrier.querying.LocalManager\", retr.manager).index" 229 | ], 230 | "metadata": { 231 | "id": "pyDI0ETgcg7c" 232 | }, 233 | "execution_count": null, 234 | "outputs": [] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "source": [ 239 | "Let's explore the lexicon - what tokens were used?
(First 100)" 240 | ], 241 | "metadata": { 242 | "id": "jEJT77AcdjX4" 243 | } 244 | }, 245 | { 246 | "cell_type": "code", 247 | "source": [ 248 | "for i, entry in enumerate(index.getLexicon()):\n", 249 | " if i == 100:\n", 250 | " break\n", 251 | " print(entry.getKey() + \" \" + entry.getValue().toString())" 252 | ], 253 | "metadata": { 254 | "id": "eq33vuL9ddx3" 255 | }, 256 | "execution_count": null, 257 | "outputs": [] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "source": [ 262 | "print(index.getCollectionStatistics().toString())" 263 | ], 264 | "metadata": { 265 | "id": "yiXTJYvjdmCZ" 266 | }, 267 | "execution_count": null, 268 | "outputs": [] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "source": [ 273 | "We can even look into particular document in the index." 274 | ], 275 | "metadata": { 276 | "id": "_u6SWmbydy7o" 277 | } 278 | }, 279 | { 280 | "cell_type": "code", 281 | "source": [ 282 | "di = index.getDirectIndex()\n", 283 | "doi = index.getDocumentIndex()\n", 284 | "lex = index.getLexicon()\n", 285 | "docid = 77_000 #docids are 0-based\n", 286 | "#NB: postings will be null if the document is empty\n", 287 | "dictrep = {}\n", 288 | "for posting in di.getPostings(doi.getDocumentEntry(docid)):\n", 289 | " termid = posting.getId()\n", 290 | " lee = lex.getLexiconEntry(termid)\n", 291 | " dictrep[lee.getKey()] = posting.getFrequency()\n", 292 | "\n", 293 | "for k in sorted(dictrep.keys()):\n", 294 | " print(k, dictrep[k])" 295 | ], 296 | "metadata": { 297 | "id": "EuF9PQIydto2" 298 | }, 299 | "execution_count": null, 300 | "outputs": [] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "source": [], 305 | "metadata": { 306 | "id": "cHZ4GfxFd1br" 307 | }, 308 | "execution_count": null, 309 | "outputs": [] 310 | } 311 | ] 312 | } -------------------------------------------------------------------------------- /msmarco-psg-v1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "aT_ldO9xB70Y" 7 | }, 8 | "source": [ 9 | "# SPLADE on MSMARCO v1 Passage Corpus using PyTerrier\n", 10 | "\n", 11 | "This notebook demonstrates the creation of a SPLADE index using PyTerrier.\n", 12 | "\n", 13 | "## Installation\n", 14 | "\n", 15 | "Install using pip:" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "id": "RzJ6orDcAVaw" 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "Collecting git+https://github.com/tonellotto/pyt_splade@naverless-branch\n", 30 | " Cloning https://github.com/tonellotto/pyt_splade (to revision naverless-branch) to /tmp/pip-req-build-zcd3emia\n", 31 | " Running command git clone --filter=blob:none --quiet https://github.com/tonellotto/pyt_splade /tmp/pip-req-build-zcd3emia\n", 32 | " Running command git checkout -b naverless-branch --track origin/naverless-branch\n", 33 | " Switched to a new branch 'naverless-branch'\n", 34 | " Branch 'naverless-branch' set up to track remote branch 'naverless-branch' from 'origin'.\n", 35 | " Resolved https://github.com/tonellotto/pyt_splade to commit 77a2ab7964e1c3297c5412bc60e3c34c87b1faae\n", 36 | " Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", 37 | "\u001b[?25hRequirement already satisfied: torch>=2.6.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from pyt_splade==0.0.2) (2.7.1)\n", 38 | "Requirement already satisfied: transformers in /home/nicola/miniforge3/lib/python3.12/site-packages (from pyt_splade==0.0.2) (4.52.4)\n", 39 | "Requirement already satisfied: python-terrier>=0.11.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from pyt_splade==0.0.2) (0.13.0)\n", 40 | "Requirement already satisfied: pyterrier_alpha in /home/nicola/miniforge3/lib/python3.12/site-packages (from pyt_splade==0.0.2) (0.14.0)\n", 41 | "Requirement already satisfied: numpy in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (1.26.4)\n", 42 | "Requirement already satisfied: pandas in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (2.1.4)\n", 43 | "Requirement already satisfied: more-itertools in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (10.7.0)\n", 44 | "Requirement already satisfied: tqdm in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (4.67.1)\n", 45 | "Requirement already satisfied: requests in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (2.32.3)\n", 46 | "Requirement already satisfied: ir-datasets>=0.3.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (0.5.10)\n", 47 | "Requirement already satisfied: wget in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (3.2)\n", 48 | "Requirement already satisfied: pyjnius>=1.4.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (1.6.1)\n", 49 | "Requirement already satisfied: deprecated in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (1.2.18)\n", 50 | "Requirement already satisfied: scipy in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (1.14.1)\n", 51 | "Requirement already satisfied: ir-measures>=0.3.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (0.3.7)\n", 52 | "Requirement already satisfied: pytrec-eval-terrier>=0.5.3 in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (0.5.7)\n", 53 | "Requirement already satisfied: jinja2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (3.1.4)\n", 54 | "Requirement already satisfied: statsmodels in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (0.14.4)\n", 55 | "Requirement already satisfied: dill in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (0.4.0)\n", 56 | "Requirement already satisfied: joblib in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (1.5.1)\n", 57 | "Requirement already satisfied: chest in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-terrier>=0.11.0->pyt_splade==0.0.2) (0.2.3)\n", 58 | "Requirement already satisfied: lz4 in /home/nicola/miniforge3/lib/python3.12/site-packages (from 
python-terrier>=0.11.0->pyt_splade==0.0.2) (4.4.4)\n", 59 | "Requirement already satisfied: filelock in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (3.16.1)\n", 60 | "Requirement already satisfied: typing-extensions>=4.10.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (4.12.2)\n", 61 | "Requirement already satisfied: setuptools in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (75.6.0)\n", 62 | "Requirement already satisfied: sympy>=1.13.3 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (1.14.0)\n", 63 | "Requirement already satisfied: networkx in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (3.4.2)\n", 64 | "Requirement already satisfied: fsspec in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (2024.10.0)\n", 65 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (12.6.77)\n", 66 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (12.6.77)\n", 67 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (12.6.80)\n", 68 | "Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (9.5.1.17)\n", 69 | "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (12.6.4.1)\n", 70 | "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (11.3.0.4)\n", 71 | "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (10.3.7.77)\n", 72 | "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (11.7.1.2)\n", 73 | "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (12.5.4.2)\n", 74 | "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (0.6.3)\n", 75 | "Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (2.26.2)\n", 76 | "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (12.6.77)\n", 77 | "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (12.6.85)\n", 78 | "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (1.11.1.6)\n", 79 | "Requirement already satisfied: triton==3.3.1 in 
/home/nicola/miniforge3/lib/python3.12/site-packages (from torch>=2.6.0->pyt_splade==0.0.2) (3.3.1)\n", 80 | "Requirement already satisfied: colaburl in /home/nicola/miniforge3/lib/python3.12/site-packages (from pyterrier_alpha->pyt_splade==0.0.2) (0.1.0)\n", 81 | "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from transformers->pyt_splade==0.0.2) (0.32.4)\n", 82 | "Requirement already satisfied: packaging>=20.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from transformers->pyt_splade==0.0.2) (24.2)\n", 83 | "Requirement already satisfied: pyyaml>=5.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from transformers->pyt_splade==0.0.2) (6.0.2)\n", 84 | "Requirement already satisfied: regex!=2019.12.17 in /home/nicola/miniforge3/lib/python3.12/site-packages (from transformers->pyt_splade==0.0.2) (2024.11.6)\n", 85 | "Requirement already satisfied: tokenizers<0.22,>=0.21 in /home/nicola/miniforge3/lib/python3.12/site-packages (from transformers->pyt_splade==0.0.2) (0.21.1)\n", 86 | "Requirement already satisfied: safetensors>=0.4.3 in /home/nicola/miniforge3/lib/python3.12/site-packages (from transformers->pyt_splade==0.0.2) (0.5.3)\n", 87 | "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.30.0->transformers->pyt_splade==0.0.2) (1.1.3)\n", 88 | "Requirement already satisfied: beautifulsoup4>=4.4.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (4.12.3)\n", 89 | "Requirement already satisfied: inscriptis>=2.2.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (2.6.0)\n", 90 | "Requirement already satisfied: lxml>=4.5.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (5.4.0)\n", 91 | "Requirement already satisfied: trec-car-tools>=2.5.4 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (2.6)\n", 92 | "Requirement already satisfied: warc3-wet>=0.2.3 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (0.2.5)\n", 93 | "Requirement already satisfied: warc3-wet-clueweb09>=0.2.5 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (0.2.5)\n", 94 | "Requirement already satisfied: zlib-state>=0.1.3 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (0.1.9)\n", 95 | "Requirement already satisfied: ijson>=3.1.3 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (3.4.0)\n", 96 | "Requirement already satisfied: unlzw3>=0.2.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (0.2.3)\n", 97 | "Requirement already satisfied: pyarrow>=16.1.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (20.0.0)\n", 98 | "Requirement already satisfied: charset_normalizer<4,>=2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from requests->python-terrier>=0.11.0->pyt_splade==0.0.2) (3.4.0)\n", 99 | 
"Requirement already satisfied: idna<4,>=2.5 in /home/nicola/miniforge3/lib/python3.12/site-packages (from requests->python-terrier>=0.11.0->pyt_splade==0.0.2) (3.10)\n", 100 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from requests->python-terrier>=0.11.0->pyt_splade==0.0.2) (2.2.3)\n", 101 | "Requirement already satisfied: certifi>=2017.4.17 in /home/nicola/miniforge3/lib/python3.12/site-packages (from requests->python-terrier>=0.11.0->pyt_splade==0.0.2) (2024.8.30)\n", 102 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from sympy>=1.13.3->torch>=2.6.0->pyt_splade==0.0.2) (1.3.0)\n", 103 | "Requirement already satisfied: heapdict in /home/nicola/miniforge3/lib/python3.12/site-packages (from chest->python-terrier>=0.11.0->pyt_splade==0.0.2) (1.0.1)\n", 104 | "Requirement already satisfied: wrapt<2,>=1.10 in /home/nicola/miniforge3/lib/python3.12/site-packages (from deprecated->python-terrier>=0.11.0->pyt_splade==0.0.2) (1.17.2)\n", 105 | "Requirement already satisfied: MarkupSafe>=2.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from jinja2->python-terrier>=0.11.0->pyt_splade==0.0.2) (3.0.2)\n", 106 | "Requirement already satisfied: python-dateutil>=2.8.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from pandas->python-terrier>=0.11.0->pyt_splade==0.0.2) (2.9.0.post0)\n", 107 | "Requirement already satisfied: pytz>=2020.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from pandas->python-terrier>=0.11.0->pyt_splade==0.0.2) (2024.2)\n", 108 | "Requirement already satisfied: tzdata>=2022.1 in /home/nicola/miniforge3/lib/python3.12/site-packages (from pandas->python-terrier>=0.11.0->pyt_splade==0.0.2) (2024.2)\n", 109 | "Requirement already satisfied: patsy>=0.5.6 in /home/nicola/miniforge3/lib/python3.12/site-packages (from statsmodels->python-terrier>=0.11.0->pyt_splade==0.0.2) (1.0.1)\n", 110 | "Requirement already satisfied: soupsieve>1.2 in /home/nicola/miniforge3/lib/python3.12/site-packages (from beautifulsoup4>=4.4.1->ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (2.5)\n", 111 | "Requirement already satisfied: six>=1.5 in /home/nicola/miniforge3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->python-terrier>=0.11.0->pyt_splade==0.0.2) (1.16.0)\n", 112 | "Requirement already satisfied: cbor>=1.0.0 in /home/nicola/miniforge3/lib/python3.12/site-packages (from trec-car-tools>=2.5.4->ir-datasets>=0.3.2->python-terrier>=0.11.0->pyt_splade==0.0.2) (1.0.0)\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "!pip install git+https://github.com/tonellotto/pyt_splade@naverless-branch" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": { 123 | "id": "Muc5TsTpCKYj" 124 | }, 125 | "source": [ 126 | "## Setup\n", 127 | "\n", 128 | "We create a factory object `splade` that gives us access to the appropriate transformers to use SPLADE." 
129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 1, 134 | "metadata": { 135 | "id": "ZsbO1m39BOAc" 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "import pyterrier as pt\n", 140 | "import pyt_splade\n", 141 | "\n", 142 | "splade = pyt_splade.Splade(device='cuda:1')\n", 143 | "doc_encoder = splade.doc_encoder()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "gtgZfso7CTiw" 150 | }, 151 | "source": [ 152 | "## Indexing demo\n", 153 | "\n", 154 | "Let's see what terms are generated by the SPLADE model during indexing." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 2, 160 | "metadata": { 161 | "id": "7QMhQucBBbXV" 162 | }, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "{'w': 199,\n", 168 | " '##2': 193,\n", 169 | " 'war': 167,\n", 170 | " 'wwii': 150,\n", 171 | " '##w': 130,\n", 172 | " 'ii': 110,\n", 173 | " '2': 94,\n", 174 | " 'germany': 86,\n", 175 | " 'army': 76,\n", 176 | " 'battle': 70,\n", 177 | " 'was': 66,\n", 178 | " 'bomb': 48,\n", 179 | " 'event': 43,\n", 180 | " 'wilson': 43,\n", 181 | " 'conflict': 38,\n", 182 | " 'marshall': 33,\n", 183 | " 'allied': 23,\n", 184 | " 'surrender': 22,\n", 185 | " 'peace': 16,\n", 186 | " 'military': 12,\n", 187 | " 'era': 10,\n", 188 | " 'alliance': 10,\n", 189 | " 'weapon': 10,\n", 190 | " 'wars': 8,\n", 191 | " 'camp': 7,\n", 192 | " 'were': 6,\n", 193 | " 'france': 6,\n", 194 | " 'invasion': 6,\n", 195 | " 'nazi': 4,\n", 196 | " 'zombie': 2,\n", 197 | " 'german': 1,\n", 198 | " 'japan': 1,\n", 199 | " 'patton': 1}" 200 | ] 201 | }, 202 | "execution_count": 2, 203 | "metadata": {}, 204 | "output_type": "execute_result" 205 | } 206 | ], 207 | "source": [ 208 | "df = doc_encoder([{'docno' : 'd1', 'text' : 'ww2'}])\n", 209 | "df[0]['toks']" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": { 215 | "id": "uVjoexiJCat7" 216 | }, 217 | "source": [ 218 | "## Indexing MSMARCO\n", 219 | "\n", 220 | "Let's go and create an index for the MSMARCO v1 passage corpus. The following will provide access to the dataset:" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 3, 226 | "metadata": { 227 | "id": "8k-WT0CDBmxS" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "dataset = pt.get_dataset('irds:msmarco-passage')" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": { 237 | "id": "Z1WtTORtClki" 238 | }, 239 | "source": [ 240 | "This is the actual indexing code. We use the SPLADE model to transform the passages into tokens and weights. It took around 4 hours to run on an RTX 4090."
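, "\n", "(If GPU memory is limited, a smaller encoder batch size is one option, e.g. `splade.doc_encoder(batch_size=32)`; the default is 100.)"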
241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 4, 246 | "metadata": { 247 | "id": "d2rQGu8XChNi" 248 | }, 249 | "outputs": [ 250 | { 251 | "name": "stderr", 252 | "output_type": "stream", 253 | "text": [ 254 | "Java started (triggered by TerrierIndexer.__init__) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]\n", 255 | "msmarco-passage documents: 100%|████████████████████████████████████████████████████████████████████████████████████████| 8841823/8841823 [3:57:10<00:00, 621.34it/s]\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "import os\n", 261 | "\n", 262 | "if not os.path.exists('./msmarco_psg'): # skip if already created\n", 263 | " indexer = pt.IterDictIndexer('./msmarco_psg', pretokenised=True)\n", 264 | " indexer.setProperty(\"termpipelines\", \"\")\n", 265 | " indexer.setProperty(\"tokeniser\", \"WhitespaceTokeniser\")\n", 266 | "\n", 267 | " indexer_pipe = doc_encoder >> indexer\n", 268 | " index_ref = indexer_pipe.index(dataset.get_corpus_iter())" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "id": "qvK6c1vjCqDn" 275 | }, 276 | "source": [ 277 | "## Retrieval\n", 278 | "\n", 279 | "We can now conduct retrieval using PyTerrier." 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 5, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "retr = pt.terrier.Retriever('./msmarco_psg', wmodel='Tf', verbose=True)\n", 289 | "\n", 290 | "retr_pipe = splade.query_encoder() >> retr" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": {}, 296 | "source": [ 297 | "Let's check that retrieval works, and we can see the generated query." 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 6, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stderr", 307 | "output_type": "stream", 308 | "text": [ 309 | "TerrierRetr(Tf): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 1.48s/q]\n" 310 | ] 311 | }, 312 | { 313 | "data": { 314 | "text/html": [ 315 | "<div>
\n", 316 | "\n", 329 | "\n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | "
qiddociddocnorankscorequeryquery_toks
017582847582840759.949764chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
11591379459137941758.851035chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
217422067422062757.623774chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
31857219185721913750.944456chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
411299011299014748.763385chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
........................
995150946055094605995543.317372chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
996122264672226467996543.290851chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
997178669697866969997543.282898chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
998137324153732415998543.242431chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
999111928171192817999543.160156chemical reactions{'reactions': 269.4256896972656, 'reaction': 2...
\n", 455 | "

1000 rows × 7 columns

\n", 456 | "
" 457 | ], 458 | "text/plain": [ 459 | " qid docid docno rank score query \\\n", 460 | "0 1 758284 758284 0 759.949764 chemical reactions \n", 461 | "1 1 5913794 5913794 1 758.851035 chemical reactions \n", 462 | "2 1 742206 742206 2 757.623774 chemical reactions \n", 463 | "3 1 8572191 8572191 3 750.944456 chemical reactions \n", 464 | "4 1 129901 129901 4 748.763385 chemical reactions \n", 465 | ".. .. ... ... ... ... ... \n", 466 | "995 1 5094605 5094605 995 543.317372 chemical reactions \n", 467 | "996 1 2226467 2226467 996 543.290851 chemical reactions \n", 468 | "997 1 7866969 7866969 997 543.282898 chemical reactions \n", 469 | "998 1 3732415 3732415 998 543.242431 chemical reactions \n", 470 | "999 1 1192817 1192817 999 543.160156 chemical reactions \n", 471 | "\n", 472 | " query_toks \n", 473 | "0 {'reactions': 269.4256896972656, 'reaction': 2... \n", 474 | "1 {'reactions': 269.4256896972656, 'reaction': 2... \n", 475 | "2 {'reactions': 269.4256896972656, 'reaction': 2... \n", 476 | "3 {'reactions': 269.4256896972656, 'reaction': 2... \n", 477 | "4 {'reactions': 269.4256896972656, 'reaction': 2... \n", 478 | ".. ... \n", 479 | "995 {'reactions': 269.4256896972656, 'reaction': 2... \n", 480 | "996 {'reactions': 269.4256896972656, 'reaction': 2... \n", 481 | "997 {'reactions': 269.4256896972656, 'reaction': 2... \n", 482 | "998 {'reactions': 269.4256896972656, 'reaction': 2... \n", 483 | "999 {'reactions': 269.4256896972656, 'reaction': 2... \n", 484 | "\n", 485 | "[1000 rows x 7 columns]" 486 | ] 487 | }, 488 | "execution_count": 6, 489 | "metadata": {}, 490 | "output_type": "execute_result" 491 | } 492 | ], 493 | "source": [ 494 | "retr_pipe.search('chemical reactions')" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": {}, 500 | "source": [ 501 | "Finally, lets run the experiment and see the resulting performance." 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 7, 507 | "metadata": {}, 508 | "outputs": [ 509 | { 510 | "name": "stderr", 511 | "output_type": "stream", 512 | "text": [ 513 | "TerrierRetr(Tf): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43/43 [00:43<00:00, 1.01s/q]\n" 514 | ] 515 | }, 516 | { 517 | "data": { 518 | "text/html": [ 519 | "
\n", 520 | "\n", 533 | "\n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | "
nameRR(rel=2)nDCG@10nDCG@100AP(rel=2)
0splade0.9186050.7310910.6725690.504771
\n", 555 | "
" 556 | ], 557 | "text/plain": [ 558 | " name RR(rel=2) nDCG@10 nDCG@100 AP(rel=2)\n", 559 | "0 splade 0.918605 0.731091 0.672569 0.504771" 560 | ] 561 | }, 562 | "execution_count": 7, 563 | "metadata": {}, 564 | "output_type": "execute_result" 565 | } 566 | ], 567 | "source": [ 568 | "from pyterrier.measures import *\n", 569 | "\n", 570 | "pt.Experiment(\n", 571 | " [retr_pipe],\n", 572 | " pt.get_dataset('irds:msmarco-passage/trec-dl-2019/judged').get_topics(),\n", 573 | " pt.get_dataset('irds:msmarco-passage/trec-dl-2019/judged').get_qrels(),\n", 574 | " eval_metrics=[RR(rel=2), nDCG@10, nDCG@100, AP(rel=2)],\n", 575 | " names=['splade']\n", 576 | ") " 577 | ] 578 | }, 579 | { 580 | "cell_type": "markdown", 581 | "metadata": {}, 582 | "source": [ 583 | "## Exploring the Index" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 8, 589 | "metadata": {}, 590 | "outputs": [], 591 | "source": [ 592 | "index = pt.java.cast(\"org.terrier.querying.LocalManager\", retr.manager).index" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": {}, 598 | "source": [ 599 | "Lets explore the lexicon - what tokens were used? (First 100)" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 9, 605 | "metadata": {}, 606 | "outputs": [ 607 | { 608 | "name": "stdout", 609 | "output_type": "stream", 610 | "text": [ 611 | "! term8768 Nt=35737 TF=1147883 maxTF=2147483647 @{0 0 0}\n", 612 | "\" term908 Nt=861221 TF=35398179 maxTF=2147483647 @{0 194287 4}\n", 613 | "# term5228 Nt=69467 TF=3962679 maxTF=2147483647 @{0 5071126 4}\n", 614 | "##0 term5242 Nt=68206 TF=4326264 maxTF=2147483647 @{0 5658804 6}\n", 615 | "##00 term19501 Nt=14675 TF=972519 maxTF=2147483647 @{0 6285192 0}\n", 616 | "##01 term12382 Nt=7860 TF=535730 maxTF=2147483647 @{0 6431703 6}\n", 617 | "##0s term26590 Nt=390 TF=27064 maxTF=2147483647 @{0 6513717 2}\n", 618 | "##1 term5497 Nt=105146 TF=5876691 maxTF=2147483647 @{0 6518228 4}\n", 619 | "##10 term17384 Nt=21166 TF=1574624 maxTF=2147483647 @{0 7370202 5}\n", 620 | "##100 term12383 Nt=13688 TF=807672 maxTF=2147483647 @{0 7601374 3}\n", 621 | "##11 term9506 Nt=17113 TF=1083192 maxTF=2147483647 @{0 7725680 3}\n", 622 | "##12 term8684 Nt=8396 TF=695098 maxTF=2147483647 @{0 7889352 4}\n", 623 | "##13 term12419 Nt=13620 TF=714544 maxTF=2147483647 @{0 7992063 4}\n", 624 | "##14 term17856 Nt=6107 TF=427526 maxTF=2147483647 @{0 8105092 2}\n", 625 | "##15 term13479 Nt=14618 TF=926535 maxTF=2147483647 @{0 8170719 5}\n", 626 | "##16 term8683 Nt=6142 TF=441899 maxTF=2147483647 @{0 8311817 6}\n", 627 | "##17 term24326 Nt=2636 TF=230550 maxTF=2147483647 @{0 8379185 3}\n", 628 | "##18 term16762 Nt=3902 TF=290819 maxTF=2147483647 @{0 8414039 1}\n", 629 | "##19 term10528 Nt=4095 TF=245959 maxTF=2147483647 @{0 8458778 0}\n", 630 | "##2 term8649 Nt=105439 TF=9245710 maxTF=2147483647 @{0 8498403 0}\n", 631 | "##20 term18609 Nt=7490 TF=659423 maxTF=2147483647 @{0 9769284 5}\n", 632 | "##200 term25235 Nt=2397 TF=231221 maxTF=2147483647 @{0 9866351 2}\n", 633 | "##21 term23253 Nt=4141 TF=373411 maxTF=2147483647 @{0 9900874 2}\n", 634 | "##22 term18176 Nt=8760 TF=662213 maxTF=2147483647 @{0 9956130 6}\n", 635 | "##23 term17557 Nt=3049 TF=281636 maxTF=2147483647 @{0 10055332 7}\n", 636 | "##24 term17721 Nt=3379 TF=278187 maxTF=2147483647 @{0 10097175 4}\n", 637 | "##25 term21614 Nt=5200 TF=400535 maxTF=2147483647 @{0 10139290 2}\n", 638 | "##26 term5510 Nt=9445 TF=451434 maxTF=2147483647 @{0 10199755 3}\n", 639 | "##27 term21852 Nt=4382 TF=308694 
maxTF=2147483647 @{0 10273912 0}\n", 640 | "##28 term17720 Nt=3061 TF=203768 maxTF=2147483647 @{0 10321616 6}\n", 641 | "##29 term18147 Nt=3255 TF=219114 maxTF=2147483647 @{0 10353826 3}\n", 642 | "##3 term9089 Nt=100372 TF=6159820 maxTF=2147483647 @{0 10388279 0}\n", 643 | "##30 term8701 Nt=10309 TF=677044 maxTF=2147483647 @{0 11265301 4}\n", 644 | "##31 term8106 Nt=11104 TF=522069 maxTF=2147483647 @{0 11368653 5}\n", 645 | "##32 term22740 Nt=6042 TF=705050 maxTF=2147483647 @{0 11454090 6}\n", 646 | "##33 term12015 Nt=3192 TF=240934 maxTF=2147483647 @{0 11553508 0}\n", 647 | "##34 term5509 Nt=4688 TF=350059 maxTF=2147483647 @{0 11590549 0}\n", 648 | "##35 term15227 Nt=3599 TF=374293 maxTF=2147483647 @{0 11643931 5}\n", 649 | "##36 term8753 Nt=2978 TF=301242 maxTF=2147483647 @{0 11698166 5}\n", 650 | "##37 term8749 Nt=5202 TF=305868 maxTF=2147483647 @{0 11742357 7}\n", 651 | "##38 term8737 Nt=10590 TF=566862 maxTF=2147483647 @{0 11790941 3}\n", 652 | "##39 term21619 Nt=2371 TF=191157 maxTF=2147483647 @{0 11880762 3}\n", 653 | "##3d term23449 Nt=822 TF=71283 maxTF=2147483647 @{0 11909954 7}\n", 654 | "##4 term5220 Nt=73404 TF=5355513 maxTF=2147483647 @{0 11920945 0}\n", 655 | "##40 term8697 Nt=6192 TF=513181 maxTF=2147483647 @{0 12679028 7}\n", 656 | "##400 term23344 Nt=870 TF=104038 maxTF=2147483647 @{0 12755176 6}\n", 657 | "##41 term8696 Nt=7573 TF=366795 maxTF=2147483647 @{0 12770387 0}\n", 658 | "##42 term10282 Nt=2447 TF=223318 maxTF=2147483647 @{0 12830891 6}\n", 659 | "##43 term17558 Nt=3716 TF=264761 maxTF=2147483647 @{0 12864358 1}\n", 660 | "##44 term5555 Nt=18231 TF=710853 maxTF=2147483647 @{0 12905117 4}\n", 661 | "##45 term14730 Nt=5035 TF=467397 maxTF=2147483647 @{0 13022857 2}\n", 662 | "##46 term5493 Nt=3747 TF=273426 maxTF=2147483647 @{0 13091575 4}\n", 663 | "##47 term5492 Nt=5011 TF=332006 maxTF=2147483647 @{0 13133731 5}\n", 664 | "##48 term20854 Nt=1245 TF=110963 maxTF=2147483647 @{0 13185318 6}\n", 665 | "##49 term5491 Nt=3088 TF=209541 maxTF=2147483647 @{0 13202237 2}\n", 666 | "##4th term26677 Nt=99 TF=7835 maxTF=2147483647 @{0 13235127 3}\n", 667 | "##5 term9686 Nt=59334 TF=3827392 maxTF=2147483647 @{0 13236458 7}\n", 668 | "##50 term21092 Nt=4266 TF=387504 maxTF=2147483647 @{0 13792330 7}\n", 669 | "##500 term23143 Nt=1114 TF=149198 maxTF=2147483647 @{0 13849412 3}\n", 670 | "##51 term23511 Nt=2678 TF=236336 maxTF=2147483647 @{0 13870734 7}\n", 671 | "##52 term24258 Nt=3479 TF=262717 maxTF=2147483647 @{0 13906198 5}\n", 672 | "##53 term14968 Nt=2714 TF=246442 maxTF=2147483647 @{0 13946261 1}\n", 673 | "##54 term18391 Nt=2611 TF=223373 maxTF=2147483647 @{0 13982901 7}\n", 674 | "##55 term19085 Nt=3263 TF=244740 maxTF=2147483647 @{0 14016676 3}\n", 675 | "##56 term19084 Nt=972 TF=82606 maxTF=2147483647 @{0 14054242 6}\n", 676 | "##57 term19083 Nt=1951 TF=174945 maxTF=2147483647 @{0 14067004 0}\n", 677 | "##58 term23812 Nt=1394 TF=111265 maxTF=2147483647 @{0 14093424 2}\n", 678 | "##59 term12641 Nt=4058 TF=251090 maxTF=2147483647 @{0 14110712 3}\n", 679 | "##6 term8739 Nt=61302 TF=4087465 maxTF=2147483647 @{0 14150718 1}\n", 680 | "##60 term21093 Nt=6597 TF=498553 maxTF=2147483647 @{0 14741180 6}\n", 681 | "##64 term14966 Nt=3582 TF=328066 maxTF=2147483647 @{0 14816034 4}\n", 682 | "##65 term8771 Nt=1822 TF=186637 maxTF=2147483647 @{0 14864515 0}\n", 683 | "##66 term20389 Nt=1647 TF=165595 maxTF=2147483647 @{0 14892094 3}\n", 684 | "##6th term21817 Nt=166 TF=11349 maxTF=2147483647 @{0 14916523 7}\n", 685 | "##7 term5894 Nt=47443 TF=3062574 maxTF=2147483647 @{0 
14918471 0}\n",
686 | "##70 term375 Nt=5428 TF=419113 maxTF=2147483647 @{0 15365343 3}\n",
687 | "##75 term12402 Nt=2318 TF=219361 maxTF=2147483647 @{0 15428436 6}\n",
688 | "##7th term22523 Nt=133 TF=12469 maxTF=2147483647 @{0 15461212 5}\n",
689 | "##8 term5241 Nt=62717 TF=3416232 maxTF=2147483647 @{0 15463151 7}\n",
690 | "##80 term19502 Nt=4460 TF=354300 maxTF=2147483647 @{0 15970484 6}\n",
691 | "##85 term24748 Nt=1705 TF=183624 maxTF=2147483647 @{0 16023806 4}\n",
692 | "##86 term20582 Nt=4168 TF=387621 maxTF=2147483647 @{0 16050672 7}\n",
693 | "##8th term22522 Nt=209 TF=9778 maxTF=2147483647 @{0 16107617 4}\n",
694 | "##9 term8770 Nt=37759 TF=2438344 maxTF=2147483647 @{0 16109490 7}\n",
695 | "##90 term12381 Nt=7156 TF=505946 maxTF=2147483647 @{0 16467771 0}\n",
696 | "##9th term27302 Nt=129 TF=10540 maxTF=2147483647 @{0 16544757 6}\n",
697 | "##a term1820 Nt=346290 TF=22052483 maxTF=2147483647 @{0 16546498 5}\n",
698 | "##a1 term25024 Nt=765 TF=67588 maxTF=2147483647 @{0 19560053 6}\n",
699 | "##aa term13733 Nt=10358 TF=1261542 maxTF=2147483647 @{0 19570364 1}\n",
700 | "##aan term19777 Nt=3122 TF=290479 maxTF=2147483647 @{0 19745293 5}\n",
701 | "##aar term23539 Nt=435 TF=78243 maxTF=2147483647 @{0 19787334 0}\n",
702 | "##ab term11707 Nt=31971 TF=2738108 maxTF=2147483647 @{0 19798187 0}\n",
703 | "##aba term19537 Nt=2464 TF=400298 maxTF=2147483647 @{0 20184194 5}\n",
704 | "##abad term26450 Nt=349 TF=47909 maxTF=2147483647 @{0 20238825 7}\n",
705 | "##abas term27079 Nt=510 TF=117890 maxTF=2147483647 @{0 20245643 3}\n",
706 | "##abe term21930 Nt=3433 TF=480493 maxTF=2147483647 @{0 20261558 1}\n",
707 | "##abi term18984 Nt=9264 TF=883815 maxTF=2147483647 @{0 20327719 5}\n",
708 | "##abia term28035 Nt=193 TF=35788 maxTF=2147483647 @{0 20452903 0}\n",
709 | "##ability term12376 Nt=18904 TF=1336255 maxTF=2147483647 @{0 20457881 1}\n",
710 | "##able term3683 Nt=180469 TF=12316532 maxTF=2147483647 @{0 20657259 0}\n"
711 | ]
712 | }
713 | ],
714 | "source": [
715 | "for i, entry in enumerate(index.getLexicon()):\n",
716 | " if i == 100:\n",
717 | "  break\n",
718 | " print(entry.getKey() + \" \" + entry.getValue().toString())"
719 | ]
720 | },
721 | {
722 | "cell_type": "code",
723 | "execution_count": 10,
724 | "metadata": {},
725 | "outputs": [
726 | {
727 | "name": "stdout",
728 | "output_type": "stream",
729 | "text": [
730 | "Number of documents: 8841823\n",
731 | "Number of terms: 28679\n",
732 | "Number of postings: 1038252055\n",
733 | "Number of fields: 0\n",
734 | "Number of tokens: 51858349642\n",
735 | "Field names: []\n",
736 | "Positions: false\n",
737 | "\n"
738 | ]
739 | }
740 | ],
741 | "source": [
742 | "print(index.getCollectionStatistics().toString())"
743 | ]
744 | },
745 | {
746 | "cell_type": "markdown",
747 | "metadata": {},
748 | "source": [
749 | "We can even look into a particular document in the index: the direct index lets us recover the term weights that SPLADE assigned to it."
750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 11, 755 | "metadata": {}, 756 | "outputs": [ 757 | { 758 | "name": "stdout", 759 | "output_type": "stream", 760 | "text": [ 761 | "\" 13\n", 762 | "##uation 37\n", 763 | "000 59\n", 764 | "25 43\n", 765 | "30 70\n", 766 | "35 72\n", 767 | "40 35\n", 768 | "accountant 39\n", 769 | "accounting 70\n", 770 | "advice 11\n", 771 | "amount 71\n", 772 | "amounts 88\n", 773 | "applicants 12\n", 774 | "ask 108\n", 775 | "asked 86\n", 776 | "asking 33\n", 777 | "asks 6\n", 778 | "assessment 19\n", 779 | "average 38\n", 780 | "bargaining 48\n", 781 | "bart 20\n", 782 | "bottom 94\n", 783 | "briggs 20\n", 784 | "burke 10\n", 785 | "business 13\n", 786 | "businesses 28\n", 787 | "calculate 62\n", 788 | "calculated 2\n", 789 | "candidacy 30\n", 790 | "candidate 110\n", 791 | "candidates 115\n", 792 | "chart 26\n", 793 | "companies 43\n", 794 | "company 15\n", 795 | "considered 36\n", 796 | "corporate 10\n", 797 | "dave 17\n", 798 | "davis 5\n", 799 | "desk 7\n", 800 | "difference 1\n", 801 | "diversity 2\n", 802 | "employee 95\n", 803 | "employees 17\n", 804 | "employment 40\n", 805 | "engineer 20\n", 806 | "example 107\n", 807 | "examples 98\n", 808 | "excel 14\n", 809 | "executive 5\n", 810 | "finance 43\n", 811 | "fisher 8\n", 812 | "flat 134\n", 813 | "flex 50\n", 814 | "flexibility 129\n", 815 | "flexible 92\n", 816 | "gage 6\n", 817 | "give 98\n", 818 | "given 17\n", 819 | "giving 29\n", 820 | "highest 1\n", 821 | "hr 62\n", 822 | "improvisation 9\n", 823 | "include 57\n", 824 | "included 28\n", 825 | "income 15\n", 826 | "interview 19\n", 827 | "job 80\n", 828 | "jobs 33\n", 829 | "kelly 1\n", 830 | "letter 9\n", 831 | "low 76\n", 832 | "management 4\n", 833 | "marketing 14\n", 834 | "matching 23\n", 835 | "math 56\n", 836 | "max 44\n", 837 | "maximum 84\n", 838 | "median 101\n", 839 | "mid 102\n", 840 | "middle 32\n", 841 | "money 32\n", 842 | "murray 21\n", 843 | "negotiate 98\n", 844 | "negotiating 49\n", 845 | "negotiation 82\n", 846 | "negotiations 59\n", 847 | "normal 13\n", 848 | "numbers 9\n", 849 | "often 19\n", 850 | "pay 139\n", 851 | "payroll 7\n", 852 | "point 59\n", 853 | "points 50\n", 854 | "post 50\n", 855 | "posted 72\n", 856 | "posting 95\n", 857 | "practical 5\n", 858 | "price 123\n", 859 | "provide 54\n", 860 | "provided 78\n", 861 | "qualification 14\n", 862 | "range 233\n", 863 | "ranges 149\n", 864 | "rather 52\n", 865 | "recruit 36\n", 866 | "required 70\n", 867 | "requirements 125\n", 868 | "respect 4\n", 869 | "resume 38\n", 870 | "salaries 22\n", 871 | "salary 224\n", 872 | "say 93\n", 873 | "smith 43\n", 874 | "statistics 1\n", 875 | "sum 14\n", 876 | "survey 36\n", 877 | "thirty 24\n", 878 | "thousand 11\n", 879 | "wage 40\n", 880 | "when 61\n" 881 | ] 882 | } 883 | ], 884 | "source": [ 885 | "di = index.getDirectIndex()\n", 886 | "doi = index.getDocumentIndex()\n", 887 | "lex = index.getLexicon()\n", 888 | "docid = 7700000 #docids are 0-based\n", 889 | "#NB: postings will be null if the document is empty\n", 890 | "dictrep = {}\n", 891 | "for posting in di.getPostings(doi.getDocumentEntry(docid)):\n", 892 | " termid = posting.getId()\n", 893 | " lee = lex.getLexiconEntry(termid)\n", 894 | " dictrep[lee.getKey()] = posting.getFrequency()\n", 895 | "\n", 896 | "for k in sorted(dictrep.keys()):\n", 897 | " print(k, dictrep[k])" 898 | ] 899 | }, 900 | { 901 | "cell_type": "code", 902 | "execution_count": null, 903 | "metadata": {}, 904 | "outputs": [], 905 | "source": [] 906 | } 907 | ], 908 | 
"metadata": { 909 | "accelerator": "GPU", 910 | "colab": { 911 | "gpuType": "T4", 912 | "private_outputs": true, 913 | "provenance": [] 914 | }, 915 | "kernelspec": { 916 | "display_name": "Python 3 (ipykernel)", 917 | "language": "python", 918 | "name": "python3" 919 | }, 920 | "language_info": { 921 | "codemirror_mode": { 922 | "name": "ipython", 923 | "version": 3 924 | }, 925 | "file_extension": ".py", 926 | "mimetype": "text/x-python", 927 | "name": "python", 928 | "nbconvert_exporter": "python", 929 | "pygments_lexer": "ipython3", 930 | "version": "3.12.7" 931 | } 932 | }, 933 | "nbformat": 4, 934 | "nbformat_minor": 4 935 | } 936 | --------------------------------------------------------------------------------