├── .gitignore ├── CITATION.bib ├── LICENSE ├── Makefile ├── README.md ├── cherche ├── __init__.py ├── __version__.py ├── compose │ ├── __init__.py │ ├── base.py │ ├── intersection_union_vote.py │ ├── pipeline.py │ ├── test_compose.py │ ├── test_union_inter.py │ └── test_vote.py ├── data │ ├── __init__.py │ ├── semanlink.py │ ├── semanlink │ │ ├── arxiv.json │ │ ├── docs.json │ │ └── tags.json │ ├── towns.json │ └── towns.py ├── evaluate │ ├── __init__.py │ └── evaluate.py ├── index │ ├── __init__.py │ └── faiss_index.py ├── qa │ ├── __init__.py │ └── qa.py ├── query │ ├── __init__.py │ ├── base.py │ ├── norvig.py │ └── prf.py ├── rank │ ├── __init__.py │ ├── base.py │ ├── cross_encoder.py │ ├── dpr.py │ ├── embedding.py │ ├── encoder.py │ └── test_rank.py ├── retrieve │ ├── __init__.py │ ├── base.py │ ├── bm25.py │ ├── dpr.py │ ├── embedding.py │ ├── encoder.py │ ├── flash.py │ ├── fuzz.py │ ├── lunr.py │ ├── test_retrieve.py │ └── tfidf.py └── utils │ ├── __init__.py │ ├── batch.py │ ├── quantize.py │ └── topk.py ├── docs ├── .pages ├── CNAME ├── api │ ├── .pages │ ├── compose │ │ ├── .pages │ │ ├── Intersection.md │ │ ├── Pipeline.md │ │ ├── Union.md │ │ └── Vote.md │ ├── data │ │ ├── .pages │ │ ├── arxiv-tags.md │ │ └── load-towns.md │ ├── evaluate │ │ ├── .pages │ │ └── evaluation.md │ ├── index │ │ ├── .pages │ │ └── Faiss.md │ ├── overview.md │ ├── qa │ │ ├── .pages │ │ └── QA.md │ ├── query │ │ ├── .pages │ │ ├── Norvig.md │ │ ├── PRF.md │ │ └── Query.md │ ├── rank │ │ ├── .pages │ │ ├── CrossEncoder.md │ │ ├── DPR.md │ │ ├── Embedding.md │ │ ├── Encoder.md │ │ └── Ranker.md │ ├── retrieve │ │ ├── .pages │ │ ├── DPR.md │ │ ├── Embedding.md │ │ ├── Encoder.md │ │ ├── Flash.md │ │ ├── Fuzz.md │ │ ├── Lunr.md │ │ ├── Retriever.md │ │ └── TfIdf.md │ └── utils │ │ ├── .pages │ │ ├── TopK.md │ │ ├── quantize.md │ │ ├── yield-batch-single.md │ │ └── yield-batch.md ├── css │ └── version-select.css ├── documents │ ├── .pages │ ├── documents.md │ └── towns.md ├── examples │ ├── .pages │ ├── encoder_retriever.ipynb │ ├── eval_pipeline.ipynb │ ├── retriever_ranker.ipynb │ ├── retriever_ranker_qa.ipynb │ ├── union_intersection_rankers.ipynb │ └── voting.ipynb ├── img │ ├── doc.png │ ├── explain.excalidraw │ ├── explain.png │ ├── explain.svg │ ├── logo.png │ └── renault.jpg ├── index.md ├── javascripts │ └── config.js ├── js │ └── version-select.js ├── pipeline │ ├── .pages │ └── pipeline.md ├── qa │ ├── .pages │ └── qa.md ├── rank │ ├── .pages │ ├── crossencoder.md │ ├── dpr.md │ ├── embedding.md │ ├── encoder.md │ └── rank.md ├── retrieve │ ├── .pages │ ├── bm25.md │ ├── dpr.md │ ├── embedding.md │ ├── encoder.md │ ├── flash.md │ ├── fuzz.md │ ├── lunr.md │ ├── retrieve.md │ └── tfidf.md ├── scripts │ └── index.py ├── serialize │ ├── .pages │ └── serialize.md └── stylesheets │ └── extra.css ├── mkdocs.yml ├── pytest.ini ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | *.test 7 | *.onx 8 | *.qonx 9 | *.DS_Store 10 | *.pyc 11 | *.ipynb_checkpoints 12 | *.pickle 13 | *.pkl 14 | *.icloud 15 | cache/ 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | 
MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{Sourty2022sigir, 2 | author = {Raphael Sourty and Jose G. Moreno and Lynda Tamine and Francois-Paul Servant}, 3 | title = {CHERCHE: A new tool to rapidly implement pipelines in information retrieval}, 4 | booktitle = {Proceedings of SIGIR 2022}, 5 | year = {2022} 6 | } 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Raphael Sourty 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | COMMIT_HASH := $(shell eval git rev-parse HEAD) 2 | 3 | execute-notebooks: 4 | jupyter nbconvert --execute --to notebook --inplace docs/*/*.ipynb --ExecutePreprocessor.timeout=-1 5 | 6 | render-notebooks: 7 | 8 | livedoc: 9 | mkdocs build --clean 10 | mkdocs serve --dirtyreload 11 | 12 | deploydoc: 13 | mkdocs gh-deploy --force 14 | 15 | .PHONY: bench 16 | bench: 17 | asv run ${COMMIT_HASH} --config benchmarks/asv.conf.json --steps 1 18 | asv run master --config benchmarks/asv.conf.json --steps 1 19 | asv compare the-merge ${COMMIT_HASH} --config benchmarks/asv.conf.json -------------------------------------------------------------------------------- /cherche/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "compose", 3 | "data", 4 | "evaluate", 5 | "index", 6 | "qa", 7 | "query", 8 | "rank", 9 | "retrieve", 10 | "utils", 11 | ] 12 | -------------------------------------------------------------------------------- /cherche/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (2, 2, 1) 2 | 3 | __version__ = ".".join(map(str, VERSION)) 4 | -------------------------------------------------------------------------------- /cherche/compose/__init__.py: -------------------------------------------------------------------------------- 1 | from .intersection_union_vote import Intersection, Union, Vote 2 | from .pipeline import Pipeline 3 | 4 | __all__ = ["Pipeline", "Intersection", "Union", "Vote"] 5 | -------------------------------------------------------------------------------- /cherche/compose/test_compose.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .. 
import rank, retrieve 4 | 5 | 6 | def cherche_retrievers(key: str, on: str): 7 | """List of retrievers available in cherche.""" 8 | yield from [ 9 | retrieve.TfIdf(key=key, on=on, documents=documents()), 10 | retrieve.Lunr(key=key, on=on, documents=documents()), 11 | ] 12 | 13 | 14 | def cherche_rankers(key: str, on: str): 15 | """List of rankers available in cherche.""" 16 | from sentence_transformers import CrossEncoder, SentenceTransformer 17 | 18 | yield from [ 19 | rank.DPR( 20 | key=key, 21 | on=on, 22 | encoder=SentenceTransformer( 23 | "facebook-dpr-ctx_encoder-single-nq-base" 24 | ).encode, 25 | query_encoder=SentenceTransformer( 26 | "facebook-dpr-question_encoder-single-nq-base" 27 | ).encode, 28 | ), 29 | rank.Encoder( 30 | key=key, 31 | on="title", 32 | encoder=SentenceTransformer( 33 | "sentence-transformers/all-mpnet-base-v2" 34 | ).encode, 35 | ), 36 | rank.CrossEncoder( 37 | on=on, 38 | encoder=CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1").predict, 39 | ), 40 | ] 41 | 42 | 43 | def documents(): 44 | return [ 45 | { 46 | "url": "ckb/github.com", 47 | "title": "Github library with Pytorch and Transformers.", 48 | "date": "10-11-2021", 49 | }, 50 | { 51 | "url": "mkb/github.com", 52 | "title": "Github Library with PyTorch.", 53 | "date": "22-11-2021", 54 | }, 55 | { 56 | "url": "blp/github.com", 57 | "title": "Github library with Pytorch and Transformers.", 58 | "date": "22-11-2020", 59 | }, 60 | ] 61 | 62 | 63 | def tags(): 64 | return [ 65 | { 66 | "tags": ["Github", "git"], 67 | "title": "Github is a great tool.", 68 | "uri": "tag:github", 69 | }, 70 | { 71 | "tags": ["cherche", "tool"], 72 | "title": "Cherche is a tool to retrieve informations.", 73 | "uri": "tag:cherche", 74 | }, 75 | { 76 | "tags": ["python", "programming"], 77 | "title": "Python is a programming Language", 78 | "uri": "tag:python", 79 | }, 80 | ] 81 | 82 | 83 | @pytest.mark.parametrize( 84 | "search, documents, k", 85 | [ 86 | pytest.param( 87 | retriever + ranker + documents() 88 | if not isinstance(ranker, rank.CrossEncoder) 89 | else retriever + documents() + ranker, 90 | documents(), 91 | k, 92 | id=f"retriever: {retriever.__class__.__name__}, ranker: {ranker.__class__.__name__}, k: {k}", 93 | ) 94 | for k in [None, 2, 4] 95 | for ranker in cherche_rankers(key="url", on="title") 96 | for retriever in cherche_retrievers(key="url", on="title") 97 | ], 98 | ) 99 | def test_retriever_ranker(search, documents: list, k: int): 100 | """Test retriever ranker pipeline. Test if the number of retrieved documents is coherent. 101 | Check if the number of documents asked is higher than the actual number of documents retrieved. 102 | Check if the retriever do not find any document. 103 | """ 104 | search = search.add(documents) 105 | 106 | answers = search(q="Github library with PyTorch and Transformers", k=k) 107 | for index in range(len(documents) if k is None else k): 108 | if index in [0, 1]: 109 | assert ( 110 | answers[index]["title"] 111 | == "Github library with Pytorch and Transformers." 112 | ) 113 | elif index in [2]: 114 | assert answers[index]["title"] == "Github Library with PyTorch." 
115 | 116 | answers = search(q="Github", k=k) 117 | if k is None: 118 | assert len(answers) == len(documents) 119 | else: 120 | assert len(answers) == min(k, len(documents)) 121 | 122 | for sample in answers: 123 | for key in ["url", "title", "date"]: 124 | assert key in sample 125 | 126 | 127 | @pytest.mark.parametrize( 128 | "search, documents, k", 129 | [ 130 | pytest.param( 131 | retrieve.Flash(key="uri", on="tags") + ranker + tags() 132 | if not isinstance(ranker, rank.CrossEncoder) 133 | else retrieve.Flash(key="uri", on="tags") + tags() + ranker, 134 | tags(), 135 | k, 136 | id=f"retriever: Flash, ranker: {ranker.__class__.__name__}, k: {k}", 137 | ) 138 | for k in [None, 0, 2, 4] 139 | for ranker in cherche_rankers(key="uri", on="title") 140 | ], 141 | ) 142 | def test_flash_ranker(search, documents: list, k: int): 143 | search = search.add(documents) 144 | answers = search(q="Github ( git ) is a great tool", k=k) 145 | if k is None: 146 | assert len(answers) == 2 147 | else: 148 | assert len(answers) == min(k, 2) 149 | for sample in answers: 150 | for key in ["tags", "title", "uri"]: 151 | assert key in sample 152 | -------------------------------------------------------------------------------- /cherche/compose/test_vote.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .. import rank, retrieve 4 | 5 | 6 | def cherche_retrievers(key: str, on: str): 7 | """List of retrievers available in cherche.""" 8 | for retriever in [ 9 | retrieve.TfIdf, 10 | retrieve.Lunr, 11 | ]: 12 | yield retriever(key=key, on=on, documents=documents()) 13 | 14 | 15 | def documents(): 16 | return [ 17 | { 18 | "id": 0, 19 | "title": "Paris", 20 | "article": "This town is the capital of France", 21 | "author": "Wikipedia", 22 | }, 23 | { 24 | "id": 1, 25 | "title": "Eiffel tower", 26 | "article": "Eiffel tower is based in Paris", 27 | "author": "Wikipedia", 28 | }, 29 | { 30 | "id": 2, 31 | "title": "Montreal", 32 | "article": "Montreal is in Canada.", 33 | "author": "Wikipedia", 34 | }, 35 | ] 36 | 37 | 38 | @pytest.mark.parametrize( 39 | "search, documents, k", 40 | [ 41 | pytest.param( 42 | (retriever_a * retriever_b * retriever_c) + documents(), 43 | documents(), 44 | k, 45 | id=f"Union retrievers: {retriever_a.__class__.__name__} | {retriever_b.__class__.__name__} | {retriever_c.__class__.__name__} k: {k}", 46 | ) 47 | for k in [None, 3, 4] 48 | for retriever_c in cherche_retrievers(key="id", on="title") 49 | for retriever_b in cherche_retrievers(key="id", on="article") 50 | for retriever_a in cherche_retrievers(key="id", on="author") 51 | ], 52 | ) 53 | def test_retriever_union(search, documents: list, k: int): 54 | """Test retriever union operator.""" 55 | # Empty documents 56 | search = search.add(documents) 57 | 58 | answers = search(q="France", k=k) 59 | assert len(answers) == min(k, 1) if k is not None else 1 60 | 61 | for sample in answers: 62 | for key in ["title", "article", "author"]: 63 | assert key in sample 64 | 65 | answers = search(q="Wikipedia", k=k) 66 | assert len(answers) == min(k, len(documents)) if k is not None else len(documents) 67 | 68 | answers = search(q="Unknown", k=k) 69 | assert len(answers) == 0 70 | -------------------------------------------------------------------------------- /cherche/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .semanlink import arxiv_tags 2 | from .towns import load_towns 3 | 4 | __all__ = ["arxiv_tags", "load_towns"] 5 | 
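# Usage note (illustration only, not part of this __init__.py): the two loaders exported
# above can be called directly; the record fields are documented in semanlink.py and towns.py below.
from cherche import data

towns = data.load_towns()                     # list of dicts with "id", "title", "url", "article" fields
documents, query_answers = data.arxiv_tags()  # tag documents plus (query, answers) pairs for evaluation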
-------------------------------------------------------------------------------- /cherche/data/semanlink.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import pathlib 4 | 5 | __all__ = ["arxiv_tags"] 6 | 7 | 8 | def arxiv_tags( 9 | arxiv_title: bool = True, 10 | arxiv_summary: bool = True, 11 | comment: bool = False, 12 | broader_prefLabel_text: bool = True, 13 | broader_altLabel_text: bool = True, 14 | prefLabel_text: bool = True, 15 | altLabel_text: bool = True, 16 | ) -> tuple: 17 | """Semanlink tags arXiv documents. The objective of this dataset is to evaluate a neural 18 | search pipeline for automatic tagging of arXiv documents. This function returns the set of tags 19 | and the pairs arXiv documents and tags. 20 | 21 | Parameters 22 | ---------- 23 | arxiv_title 24 | Include title of the arxiv paper inside the query. 25 | arxiv_summary 26 | Include summary of the arxiv paper inside the query. 27 | comment 28 | Include comment of the arxiv paper inside the query. 29 | broader_prefLabel_text 30 | Include broader_prefLabel as a text field. 31 | broader_altLabel_text 32 | Include broader_altLabel_text as a text field. 33 | prefLabel_text 34 | Include prefLabel_text as a text field. 35 | altLabel_text 36 | Include altLabel_text as a text field. 37 | 38 | Examples 39 | -------- 40 | 41 | >>> from pprint import pprint as print 42 | >>> from cherche import data 43 | 44 | >>> documents, query_answers = data.arxiv_tags() 45 | 46 | >>> print(list(documents[0].keys())) 47 | ['prefLabel', 48 | 'type', 49 | 'broader', 50 | 'creationTime', 51 | 'creationDate', 52 | 'comment', 53 | 'uri', 54 | 'broader_prefLabel', 55 | 'broader_related', 56 | 'broader_prefLabel_text', 57 | 'prefLabel_text'] 58 | 59 | """ 60 | with open( 61 | pathlib.Path(__file__).parent.joinpath("semanlink/arxiv.json"), "r" 62 | ) as input_file: 63 | docs = json.load(input_file) 64 | 65 | with open( 66 | pathlib.Path(__file__).parent.joinpath("semanlink/tags.json"), "r" 67 | ) as input_file: 68 | tags = json.load(input_file) 69 | 70 | # Filter arxiv tags 71 | counter = collections.defaultdict(int) 72 | 73 | query_answers = [] 74 | for doc in docs: 75 | query = "" 76 | answers = [] 77 | for field, include in [ 78 | ("arxiv_title", arxiv_title), 79 | ("arxiv_summary", arxiv_summary), 80 | ("comment", comment), 81 | ]: 82 | if include: 83 | query = f"{query} {doc[field]}" 84 | 85 | for tag in doc["tag"]: 86 | answers.append({"uri": tags[tag]["uri"]}) 87 | counter[tag] += 1 88 | 89 | query_answers.append((query, answers)) 90 | 91 | # Filter arxiv tags 92 | documents = [] 93 | for tag in counter: 94 | documents.append( 95 | {key: value for key, value in tags[tag].items() if len(value) >= 1} 96 | ) 97 | 98 | for tag in documents: 99 | for field, include in [ 100 | ("broader_prefLabel", broader_prefLabel_text), 101 | ("broader_altLabel", broader_altLabel_text), 102 | ("prefLabel", prefLabel_text), 103 | ("altLabel", altLabel_text), 104 | ]: 105 | if include and len(tag.get(field, "")) >= 1: 106 | tag[f"{field}_text"] = " ".join(tag[field]) 107 | 108 | return documents, query_answers 109 | -------------------------------------------------------------------------------- /cherche/data/towns.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | 4 | __all__ = ["load_towns"] 5 | 6 | 7 | def load_towns(): 8 | """Sample of Wikipedia dataset that contains informations about Toulouse, Paris, Lyon 
and 9 | Bordeaux. 10 | 11 | Examples 12 | -------- 13 | 14 | >>> from pprint import pprint as print 15 | >>> from cherche import data 16 | 17 | >>> towns = data.load_towns() 18 | 19 | >>> print(towns[:3]) 20 | [{'article': 'Paris (French pronunciation: \u200b[paʁi] (listen)) is the ' 21 | 'capital and most populous city of France, with an estimated ' 22 | 'population of 2,175,601 residents as of 2018, in an area of more ' 23 | 'than 105 square kilometres (41 square miles).', 24 | 'id': 0, 25 | 'title': 'Paris', 26 | 'url': 'https://en.wikipedia.org/wiki/Paris'}, 27 | {'article': "Since the 17th century, Paris has been one of Europe's major " 28 | 'centres of finance, diplomacy, commerce, fashion, gastronomy, ' 29 | 'science, and arts.', 30 | 'id': 1, 31 | 'title': 'Paris', 32 | 'url': 'https://en.wikipedia.org/wiki/Paris'}, 33 | {'article': 'The City of Paris is the centre and seat of government of the ' 34 | 'region and province of Île-de-France, or Paris Region, which has ' 35 | 'an estimated population of 12,174,880, or about 18 percent of ' 36 | 'the population of France as of 2017.', 37 | 'id': 2, 38 | 'title': 'Paris', 39 | 'url': 'https://en.wikipedia.org/wiki/Paris'}] 40 | 41 | 42 | """ 43 | with open(pathlib.Path(__file__).parent.joinpath("towns.json"), "r") as towns_json: 44 | return json.load(towns_json) 45 | -------------------------------------------------------------------------------- /cherche/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluate import evaluation 2 | 3 | __all__ = ["evaluation"] 4 | -------------------------------------------------------------------------------- /cherche/evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | __all__ = ["evaluation"] 2 | 3 | import collections 4 | import typing 5 | 6 | __all__ = ["evaluation"] 7 | 8 | 9 | class Mean: 10 | """Online running mean. 11 | 12 | Reference 13 | --------- 14 | 1. River [https://github.com/online-ml/river/blob/main/river/stats/mean.py] 15 | """ 16 | 17 | def __init__(self): 18 | self.n = 0 19 | self._mean = 0.0 20 | 21 | def update(self, x, w=1.0): 22 | self.n += w 23 | self._mean += (w / self.n) * (x - self._mean) 24 | return self 25 | 26 | def get(self): 27 | return self._mean 28 | 29 | 30 | def evaluation( 31 | search, 32 | query_answers: typing.List[typing.Tuple[str, typing.List[typing.Dict[str, str]]]], 33 | hits_k: range = range(1, 6, 1), 34 | batch_size: typing.Optional[int] = None, 35 | k: typing.Optional[int] = None, 36 | ): 37 | """Evaluation function 38 | 39 | 40 | Parameters 41 | ---------- 42 | search 43 | Search function. 44 | query_answers 45 | List of tuples (query, answers). 46 | hits_k 47 | List of k to compute precision, recall and F1. 48 | k 49 | Number of documents to retrieve. 50 | batch_size 51 | Batch size. 52 | 53 | Examples 54 | -------- 55 | >>> from pprint import pprint as print 56 | >>> from cherche import data, evaluate, retrieve 57 | >>> from lenlp import sparse 58 | 59 | >>> documents, query_answers = data.arxiv_tags( 60 | ... arxiv_title=True, arxiv_summary=False, comment=False 61 | ... ) 62 | 63 | >>> search = retrieve.TfIdf( 64 | ... key="uri", 65 | ... on=["prefLabel_text", "altLabel_text"], 66 | ... documents=documents, 67 | ... tfidf=sparse.TfidfVectorizer(normalize=True, ngram_range=(3, 7), analyzer="char"), 68 | ... 
) + documents 69 | 70 | >>> scores = evaluate.evaluation(search=search, query_answers=query_answers, k=10) 71 | 72 | >>> print(scores) 73 | {'F1@1': '26.52%', 74 | 'F1@2': '29.41%', 75 | 'F1@3': '28.65%', 76 | 'F1@4': '26.85%', 77 | 'F1@5': '25.19%', 78 | 'Precision@1': '63.06%', 79 | 'Precision@2': '43.47%', 80 | 'Precision@3': '33.12%', 81 | 'Precision@4': '26.67%', 82 | 'Precision@5': '22.55%', 83 | 'R-Precision': '26.95%', 84 | 'Recall@1': '16.79%', 85 | 'Recall@2': '22.22%', 86 | 'Recall@3': '25.25%', 87 | 'Recall@4': '27.03%', 88 | 'Recall@5': '28.54%'} 89 | 90 | """ 91 | precision = collections.defaultdict(lambda: Mean()) 92 | recall = collections.defaultdict(lambda: Mean()) 93 | f1 = collections.defaultdict(lambda: Mean()) 94 | r_precision = Mean() 95 | 96 | answers = search( 97 | **{ 98 | "q": [q for q, _ in query_answers], 99 | "batch_size": batch_size, 100 | "k": k, 101 | } 102 | ) 103 | 104 | for (q, golds), candidates in zip(query_answers, answers): 105 | candidates = [candidate[search.key] for candidate in candidates] 106 | golds = {gold[search.key]: True for gold in golds} 107 | 108 | # Precision @ k 109 | for k in hits_k: 110 | for candidate in candidates[:k]: 111 | precision[k].update(1) if candidate in golds else precision[k].update(0) 112 | 113 | # Recall @ k 114 | for k in hits_k: 115 | if k == 0: 116 | continue 117 | positives = 0 118 | for candidate in candidates[:k]: 119 | if candidate in golds: 120 | positives += 1 121 | recall[k].update(positives / len(golds)) if positives > 0 else recall[ 122 | k 123 | ].update(0) 124 | 125 | # R-Precision 126 | relevant = 0 127 | for candidate in candidates[: len(golds)]: 128 | if candidate in golds: 129 | relevant += 1 130 | r_precision.update(relevant / len(golds) if relevant > 0 else 0) 131 | 132 | # F1 @ k 133 | for k in hits_k: 134 | if k == 0: 135 | continue 136 | f1[k] = ( 137 | (2 * precision[k].get() * recall[k].get()) 138 | / (precision[k].get() + recall[k].get()) 139 | if (precision[k].get() + recall[k].get()) > 0 140 | else 0 141 | ) 142 | 143 | metrics = { 144 | f"Precision@{k}": f"{metric.get():.2%}" for k, metric in precision.items() 145 | } 146 | metrics.update( 147 | {f"Recall@{k}": f"{metric.get():.2%}" for k, metric in recall.items()} 148 | ) 149 | metrics.update({f"F1@{k}": f"{metric:.2%}" for k, metric in f1.items()}) 150 | metrics.update({"R-Precision": f"{r_precision.get():.2%}"}) 151 | return metrics 152 | -------------------------------------------------------------------------------- /cherche/index/__init__.py: -------------------------------------------------------------------------------- 1 | from .faiss_index import Faiss 2 | 3 | __all__ = ["Faiss"] 4 | -------------------------------------------------------------------------------- /cherche/qa/__init__.py: -------------------------------------------------------------------------------- 1 | from .qa import QA 2 | 3 | __all__ = ["QA"] 4 | -------------------------------------------------------------------------------- /cherche/query/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Query 2 | from .norvig import Norvig 3 | from .prf import PRF 4 | 5 | __all__ = ["Query", "Norvig", "PRF"] 6 | -------------------------------------------------------------------------------- /cherche/query/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing 3 | 4 | from ..compose import Intersection, Pipeline, Union 5 | 6 | __all__ = 
["Query"] 7 | 8 | 9 | class Query(abc.ABC): 10 | """Abstract class for models working on a query.""" 11 | 12 | def __init__(self, on: typing.Union[str, list]): 13 | self.on = on if isinstance(on, list) else [on] 14 | 15 | @property 16 | def type(self) -> str: 17 | return "query" 18 | 19 | def __repr__(self) -> str: 20 | repr = f"Query {self.__class__.__name__}" 21 | return repr 22 | 23 | @abc.abstractmethod 24 | def __call__( 25 | self, q: typing.Union[typing.List[str], str], **kwargs 26 | ) -> typing.Union[typing.List[str], str]: 27 | return [] 28 | 29 | def __add__(self, other) -> Pipeline: 30 | """Pipeline operator.""" 31 | if isinstance(other, Pipeline): 32 | return Pipeline(models=[self] + other.models) 33 | return Pipeline(models=[self, other]) 34 | 35 | def __or__(self, other) -> Union: 36 | """Union operator.""" 37 | raise NotImplementedError("Union not working with a Query model") 38 | 39 | def __and__(self, other) -> Intersection: 40 | """Intersection operator.""" 41 | raise NotImplementedError("Intersection not working with a Query model") 42 | -------------------------------------------------------------------------------- /cherche/query/norvig.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import pathlib 3 | import re 4 | import string 5 | import typing 6 | 7 | from ..utils import yield_batch_single 8 | from .base import Query 9 | 10 | __all__ = ["Norvig"] 11 | 12 | 13 | class Norvig(Query): 14 | """Spelling corrector written by Peter Norvig: 15 | [How to Write a Spelling Corrector](https://norvig.com/spell-correct.html) 16 | 17 | Parameters 18 | ---------- 19 | on 20 | Fields to use for fitting the spelling corrector on. 21 | 22 | Examples 23 | -------- 24 | 25 | >>> from cherche import query, data 26 | 27 | >>> documents = data.load_towns() 28 | 29 | >>> corrector = query.Norvig(on = ["title", "article"], lower=True) 30 | 31 | >>> corrector.add(documents) 32 | Query Norvig 33 | Vocabulary: 967 34 | 35 | >>> corrector(q="tha citi af Parisa is in Fronce") 36 | 'the city of paris is in france' 37 | 38 | >>> corrector(q=["tha citi af Parisa is in Fronce", "parisa"]) 39 | ['the city of paris is in france', 'paris'] 40 | 41 | References 42 | ---------- 43 | 1. 
[How to Write a Spelling Corrector](https://norvig.com/spell-correct.html) 44 | 45 | """ 46 | 47 | def __init__( 48 | self, 49 | on: typing.Union[str, typing.List], 50 | lower: bool = True, 51 | ) -> None: 52 | super().__init__(on=on) 53 | 54 | self.occurrences = collections.Counter() 55 | self.lower = lower 56 | 57 | def __repr__(self) -> str: 58 | repr = super().__repr__() 59 | repr += f"\n\t Vocabulary: {len(self.occurrences)}" 60 | return repr 61 | 62 | def __call__( 63 | self, q: typing.Union[typing.List[str], str], **kwargs 64 | ) -> typing.Union[typing.List[str], str]: 65 | """Correct spelling errors in a given query.""" 66 | queries = [] 67 | for batch in yield_batch_single(q, desc="Spelling-correction"): 68 | if len(self.occurrences) == 0: 69 | queries.append(batch) 70 | else: 71 | queries.append(" ".join(map(self.correct, batch.split(" ")))) 72 | return queries[0] if isinstance(q, str) else queries 73 | 74 | def correct(self, word: str) -> float: 75 | """Most probable spelling correction for word.""" 76 | return max( 77 | self._candidates(word), 78 | key=lambda w: self._probability(w.lower() if self.lower else w), 79 | ) 80 | 81 | def _probability(self, word: str) -> float: 82 | """Probability of `word`.""" 83 | return self.occurrences[word] / sum(self.occurrences.values()) 84 | 85 | def _candidates(self, word: str) -> set: 86 | """Generate possible spelling corrections for word.""" 87 | return ( 88 | self._known([word]) 89 | or self._known(self._edits1(word)) 90 | or self._known(self._edits2(word)) 91 | or [word] 92 | ) 93 | 94 | def _known(self, words: str) -> set: 95 | """The subset of `words` that appear in the dictionary.""" 96 | return set(w for w in words if w in self.occurrences) 97 | 98 | def _edits1(self, word: str) -> set: 99 | """All edits that are one edit away from `word`.""" 100 | letters = string.ascii_lowercase 101 | splits = [(word[:i], word[i:]) for i in range(len(word) + 1)] 102 | deletes = [L + R[1:] for L, R in splits if R] 103 | transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1] 104 | replaces = [L + c + R[1:] for L, R in splits if R for c in letters] 105 | inserts = [L + c + R for L, R in splits for c in letters] 106 | return set(deletes + transposes + replaces + inserts) 107 | 108 | def _edits2(self, word: str) -> set: 109 | """All edits that are two edits away from `word`. 
s""" 110 | return (e2 for e1 in self._edits1(word) for e2 in self._edits1(e1)) 111 | 112 | def add(self, documents: typing.Union[typing.List[typing.Dict], str]) -> "Norvig": 113 | """Fit Nervig spelling corrector.""" 114 | documents = ( 115 | documents 116 | if isinstance(documents, str) 117 | else " ".join( 118 | [ 119 | " ".join([document.get(field, "") for field in self.on]) 120 | for document in documents 121 | ] 122 | ) 123 | ) 124 | 125 | if self.lower: 126 | documents = documents.lower() 127 | 128 | self.occurrences.update(documents.split(" ")) 129 | return self 130 | 131 | def _update_from_file(self, path_file: str) -> "Norvig": 132 | """Update dictionary from all words fetched from a raw text file.""" 133 | with open(path_file, "r") as fp: 134 | self.occurrences.update(re.findall(r"\w+", fp.read().lower())) 135 | return self 136 | 137 | def reset(self) -> "Norvig": 138 | """Wipe dictionary.""" 139 | self.occurrences = collections.Counter() 140 | return self 141 | -------------------------------------------------------------------------------- /cherche/rank/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Ranker 2 | from .cross_encoder import CrossEncoder 3 | from .dpr import DPR 4 | from .embedding import Embedding 5 | from .encoder import Encoder 6 | 7 | __all__ = ["Ranker", "CrossEncoder", "DPR", "Embedding", "Encoder"] 8 | -------------------------------------------------------------------------------- /cherche/rank/dpr.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DPR"] 2 | 3 | 4 | import typing 5 | 6 | from .base import MemoryStore, Ranker 7 | 8 | 9 | class DPR(Ranker): 10 | """Dual Sentence Transformer as a ranker. This ranker is compatible with any 11 | SentenceTransformer. DPR is a dual encoder model, it uses two SentenceTransformer, 12 | one for encoding documents and one for encoding queries. 13 | 14 | Parameters 15 | ---------- 16 | key 17 | Field identifier of each document. 18 | on 19 | Fields on wich encoder will perform similarity matching. 20 | encoder 21 | Encoding function dedicated documents. 22 | query_encoder 23 | Encoding function dedicated to queries. 24 | normalize 25 | If set to True, the similarity measure is cosine similarity, if set to False, similarity 26 | measure is dot product. 27 | 28 | Examples 29 | -------- 30 | >>> from pprint import pprint as print 31 | >>> from cherche import rank 32 | >>> from sentence_transformers import SentenceTransformer 33 | 34 | >>> ranker = rank.DPR( 35 | ... key = "id", 36 | ... on = ["title", "article"], 37 | ... encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').encode, 38 | ... query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').encode, 39 | ... normalize = True, 40 | ... ) 41 | 42 | >>> documents = [ 43 | ... {"id": 0, "title": "Paris France"}, 44 | ... {"id": 1, "title": "Madrid Spain"}, 45 | ... {"id": 2, "title": "Montreal Canada"} 46 | ... ] 47 | 48 | >>> ranker.add(documents=documents) 49 | DPR ranker 50 | key : id 51 | on : title, article 52 | normalize : True 53 | embeddings: 3 54 | 55 | >>> match = ranker( 56 | ... q="Paris", 57 | ... documents=documents 58 | ... ) 59 | 60 | >>> print(match) 61 | [{'id': 0, 'similarity': 7.806636, 'title': 'Paris France'}, 62 | {'id': 1, 'similarity': 6.239272, 'title': 'Madrid Spain'}, 63 | {'id': 2, 'similarity': 6.168748, 'title': 'Montreal Canada'}] 64 | 65 | >>> match = ranker( 66 | ... 
q=["Paris", "Madrid"], 67 | ... documents=[documents + [{"id": 3, "title": "Paris"}]] * 2, 68 | ... k=2, 69 | ... ) 70 | 71 | >>> print(match) 72 | [[{'id': 3, 'similarity': 7.906666, 'title': 'Paris'}, 73 | {'id': 0, 'similarity': 7.806636, 'title': 'Paris France'}], 74 | [{'id': 1, 'similarity': 8.07025, 'title': 'Madrid Spain'}, 75 | {'id': 0, 'similarity': 6.1131663, 'title': 'Paris France'}]] 76 | 77 | """ 78 | 79 | def __init__( 80 | self, 81 | on: typing.Union[str, typing.List[str]], 82 | key: str, 83 | encoder, 84 | query_encoder, 85 | normalize: bool = True, 86 | k: typing.Optional[int] = None, 87 | batch_size: int = 64, 88 | ) -> None: 89 | super().__init__( 90 | key=key, 91 | on=on, 92 | encoder=encoder, 93 | normalize=normalize, 94 | k=k, 95 | batch_size=batch_size, 96 | ) 97 | self.query_encoder = query_encoder 98 | 99 | def __call__( 100 | self, 101 | q: typing.Union[typing.List[str], str], 102 | documents: typing.Union[ 103 | typing.List[typing.List[typing.Dict[str, str]]], 104 | typing.List[typing.Dict[str, str]], 105 | ], 106 | k: int = None, 107 | batch_size: typing.Optional[int] = None, 108 | **kwargs, 109 | ) -> typing.Union[ 110 | typing.List[typing.List[typing.Dict[str, str]]], 111 | typing.List[typing.Dict[str, str]], 112 | ]: 113 | """Encode input query and ranks documents based on the similarity between the query and 114 | the selected field of the documents. 115 | 116 | Parameters 117 | ---------- 118 | q 119 | Input query. 120 | documents 121 | List of documents to rank. 122 | 123 | """ 124 | if k is None: 125 | k = self.k 126 | 127 | if k is None: 128 | k = len(self) 129 | 130 | if not documents and isinstance(q, str): 131 | return [] 132 | 133 | if not documents and isinstance(q, list): 134 | return [[]] 135 | 136 | rank = self.encode_rank( 137 | embeddings_queries=self.query_encoder([q] if isinstance(q, str) else q), 138 | documents=[documents] if isinstance(q, str) else documents, 139 | k=k, 140 | batch_size=batch_size if batch_size is not None else self.batch_size, 141 | ) 142 | 143 | return rank[0] if isinstance(q, str) else rank 144 | -------------------------------------------------------------------------------- /cherche/rank/encoder.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Encoder"] 2 | 3 | import typing 4 | 5 | from .base import MemoryStore, Ranker 6 | 7 | 8 | class Encoder(Ranker): 9 | """Sentence Transformer as a ranker. This ranker is compatible with any SentenceTransformer. 10 | 11 | Parameters 12 | ---------- 13 | key 14 | Field identifier of each document. 15 | on 16 | Fields on wich encoder will perform similarity matching. 17 | encoder 18 | Encoding function dedicated to both documents and queries. 19 | normalize 20 | If set to True, the similarity measure is cosine similarity, if set to False, similarity 21 | measure is dot product. 22 | 23 | Examples 24 | -------- 25 | >>> from pprint import pprint as print 26 | >>> from cherche import rank 27 | >>> from sentence_transformers import SentenceTransformer 28 | 29 | >>> documents = [ 30 | ... {"id": 0, "title": "Paris France"}, 31 | ... {"id": 1, "title": "Madrid Spain"}, 32 | ... {"id": 2, "title": "Montreal Canada"} 33 | ... ] 34 | 35 | >>> ranker = rank.Encoder( 36 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 37 | ... key = "id", 38 | ... on = ["title"], 39 | ... 
) 40 | 41 | >>> ranker.add(documents=documents) 42 | Encoder ranker 43 | key : id 44 | on : title 45 | normalize : True 46 | embeddings: 3 47 | 48 | >>> match = ranker( 49 | ... q="Paris", 50 | ... documents=documents 51 | ... ) 52 | 53 | >>> print(match) 54 | [{'id': 0, 'similarity': 0.7127624, 'title': 'Paris France'}, 55 | {'id': 1, 'similarity': 0.5497405, 'title': 'Madrid Spain'}, 56 | {'id': 2, 'similarity': 0.50252455, 'title': 'Montreal Canada'}] 57 | 58 | >>> match = ranker( 59 | ... q=["Paris France", "Madrid Spain"], 60 | ... documents=[documents + [{"id": 3, "title": "Paris"}]] * 2, 61 | ... k=2, 62 | ... ) 63 | 64 | >>> print(match) 65 | [[{'id': 0, 'similarity': 0.99999994, 'title': 'Paris France'}, 66 | {'id': 1, 'similarity': 0.856435, 'title': 'Madrid Spain'}], 67 | [{'id': 1, 'similarity': 1.0, 'title': 'Madrid Spain'}, 68 | {'id': 0, 'similarity': 0.856435, 'title': 'Paris France'}]] 69 | 70 | """ 71 | 72 | def __init__( 73 | self, 74 | on: typing.Union[str, typing.List[str]], 75 | key: str, 76 | encoder, 77 | normalize: bool = True, 78 | k: typing.Optional[int] = None, 79 | batch_size: int = 64, 80 | ) -> None: 81 | super().__init__( 82 | key=key, 83 | on=on, 84 | encoder=encoder, 85 | normalize=normalize, 86 | k=k, 87 | batch_size=batch_size, 88 | ) 89 | 90 | def __call__( 91 | self, 92 | q: typing.Union[typing.List[str], str], 93 | documents: typing.Union[ 94 | typing.List[typing.List[typing.Dict[str, str]]], 95 | typing.List[typing.Dict[str, str]], 96 | ], 97 | k: typing.Optional[int] = None, 98 | batch_size: typing.Optional[int] = None, 99 | **kwargs 100 | ) -> typing.Union[ 101 | typing.List[typing.List[typing.Dict[str, str]]], 102 | typing.List[typing.Dict[str, str]], 103 | ]: 104 | """Encode input query and ranks documents based on the similarity between the query and 105 | the selected field of the documents. 106 | 107 | Parameters 108 | ---------- 109 | q 110 | Input query. 111 | documents 112 | List of documents to rank. 113 | 114 | """ 115 | if k is None: 116 | k = self.k 117 | 118 | if k is None: 119 | k = len(self) 120 | 121 | if not documents and isinstance(q, str): 122 | return [] 123 | 124 | if not documents and isinstance(q, list): 125 | return [[]] 126 | 127 | rank = self.encode_rank( 128 | embeddings_queries=self.encoder([q] if isinstance(q, str) else q), 129 | documents=[documents] if isinstance(q, str) else documents, 130 | k=k, 131 | batch_size=batch_size if batch_size is not None else self.batch_size, 132 | ) 133 | 134 | return rank[0] if isinstance(q, str) else rank 135 | -------------------------------------------------------------------------------- /cherche/rank/test_rank.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .. 
import rank 4 | 5 | 6 | def cherche_rankers(key: str, on: str): 7 | """List of rankers available in cherche.""" 8 | from sentence_transformers import CrossEncoder, SentenceTransformer 9 | 10 | yield from [ 11 | rank.DPR( 12 | key=key, 13 | on=on, 14 | encoder=SentenceTransformer( 15 | "facebook-dpr-ctx_encoder-single-nq-base" 16 | ).encode, 17 | query_encoder=SentenceTransformer( 18 | "facebook-dpr-question_encoder-single-nq-base" 19 | ).encode, 20 | ), 21 | rank.Encoder( 22 | key=key, 23 | on=on, 24 | encoder=SentenceTransformer( 25 | "sentence-transformers/all-mpnet-base-v2" 26 | ).encode, 27 | ), 28 | rank.CrossEncoder( 29 | on=on, 30 | encoder=CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1").predict, 31 | ), 32 | ] 33 | 34 | 35 | def documents(): 36 | return [ 37 | { 38 | "title": "Paris", 39 | "article": "This town is the capital of France", 40 | "author": "Wikipedia", 41 | }, 42 | { 43 | "title": "Eiffel tower", 44 | "article": "Eiffel tower is based in Paris", 45 | "author": "Wikipedia", 46 | }, 47 | { 48 | "title": "Montreal", 49 | "article": "Montreal is in Canada.", 50 | "author": "Wikipedia", 51 | }, 52 | ] 53 | 54 | 55 | def missing_documents(): 56 | return [] 57 | 58 | 59 | @pytest.mark.parametrize( 60 | "ranker, documents, key, k", 61 | [ 62 | pytest.param( 63 | ranker, 64 | documents(), 65 | "title", 66 | k, 67 | id=f"Ranker: {ranker.__class__.__name__}, k: {k}", 68 | ) 69 | for k in [None, 2, 4] 70 | for ranker in cherche_rankers(key="title", on="article") 71 | ], 72 | ) 73 | def test_ranker(ranker, documents: list, key: str, k: int): 74 | """Test ranker. Test if the number of ranked documents is coherent. 75 | Check for empty retrieved documents should returns an empty list. 76 | """ 77 | if not isinstance(ranker, rank.CrossEncoder): 78 | ranker.add(documents) 79 | 80 | # CrossEncoder shot needs all the fields 81 | if not isinstance(ranker, rank.CrossEncoder): 82 | ranker += documents 83 | # Convert inputs document to a list of id [{"id": 0}, {"id": 1}, {"id": 2}] 84 | documents = [{key: document[key]} for document in documents] 85 | 86 | answers = ranker(q="Eiffel tower France", documents=documents, k=k) 87 | 88 | if k is not None: 89 | assert len(answers) == min(k, len(documents)) 90 | else: 91 | assert len(answers) == len(documents) 92 | 93 | for index, sample in enumerate(answers): 94 | for key in ["title", "article", "author"]: 95 | assert key in sample 96 | 97 | if index == 0: 98 | assert sample["title"] == "Eiffel tower" 99 | 100 | answers = ranker(q="Canada", documents=documents, k=k) 101 | 102 | if k is None: 103 | assert answers[0]["title"] == "Montreal" 104 | elif k >= 1: 105 | assert answers[0]["title"] == "Montreal" 106 | else: 107 | assert len(answers) == 0 108 | 109 | # Unknown token. 
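    # An empty list of candidate documents should yield an empty ranking.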
110 | answers = ranker(q="Paris", documents=[], k=k) 111 | assert len(answers) == 0 112 | 113 | 114 | @pytest.mark.parametrize( 115 | "ranker, documents, key, k", 116 | [ 117 | pytest.param( 118 | ranker, 119 | missing_documents(), 120 | "title", 121 | 5, 122 | id=f"Ranker: {ranker.__class__.__name__}, missing documents, k: 5", 123 | ) 124 | for ranker in cherche_rankers(key="title", on="article") 125 | ], 126 | ) 127 | def test_ranker_missing_documents(ranker, documents: list, key: str, k: int): 128 | """Test ranker when retriever do not returns any documents.""" 129 | answers = ranker(q="Eiffel tower France", documents=documents, k=k) 130 | assert len(answers) == 0 131 | 132 | answers = ranker( 133 | q=["Eiffel tower France", "Montreal Canada"], 134 | documents=[documents, documents], 135 | k=k, 136 | ) 137 | assert len(answers) == 2 and len(answers[0]) == 0 and len(answers[1]) == 0 138 | -------------------------------------------------------------------------------- /cherche/retrieve/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Retriever 2 | from .bm25 import BM25 3 | from .dpr import DPR 4 | from .embedding import Embedding 5 | from .encoder import Encoder 6 | from .flash import Flash 7 | from .fuzz import Fuzz 8 | from .lunr import Lunr 9 | from .tfidf import TfIdf 10 | 11 | __all__ = [ 12 | "Retriever", 13 | "BM25", 14 | "DPR", 15 | "Embedding", 16 | "Encoder", 17 | "Flash", 18 | "Fuzz", 19 | "Lunr", 20 | "TfIdf", 21 | ] 22 | -------------------------------------------------------------------------------- /cherche/retrieve/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing 3 | 4 | from ..compose import Intersection, Pipeline, Union, Vote 5 | 6 | __all__ = ["Retriever"] 7 | 8 | 9 | class Retriever(abc.ABC): 10 | """Retriever base class. 11 | 12 | Parameters 13 | ---------- 14 | key 15 | Field identifier of each document. 16 | on 17 | Fields to use to match the query to the documents. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | key: str, 23 | on: typing.Union[str, list], 24 | k: typing.Optional[int], 25 | batch_size: int, 26 | ) -> None: 27 | super().__init__() 28 | self.key = key 29 | self.on = on if isinstance(on, list) else [on] 30 | self.documents = None 31 | self.k = k 32 | self.batch_size = batch_size 33 | 34 | def __repr__(self) -> str: 35 | repr = f"{self.__class__.__name__} retriever" 36 | repr += f"\n\tkey : {self.key}" 37 | repr += f"\n\ton : {', '.join(self.on)}" 38 | repr += f"\n\tdocuments: {len(self)}" 39 | return repr 40 | 41 | @abc.abstractclassmethod 42 | def __call__( 43 | self, 44 | q: typing.Union[typing.List[str], str], 45 | k: typing.Optional[int], 46 | batch_size: typing.Optional[int], 47 | **kwargs, 48 | ) -> typing.Union[ 49 | typing.List[typing.List[typing.Dict[str, str]]], 50 | typing.List[typing.Dict[str, str]], 51 | ]: 52 | """Retrieve documents from the index.""" 53 | return [] 54 | 55 | def __len__(self): 56 | return len(self.documents) if self.documents is not None else 0 57 | 58 | def __add__(self, other) -> Pipeline: 59 | """Pipeline operator.""" 60 | if isinstance(other, Pipeline): 61 | return Pipeline(self, other.models) 62 | elif isinstance(other, list): 63 | # Documents are part of the pipeline. 
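            # They are stored as a {key: document} mapping so the pipeline can map retrieved keys back to full documents.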
64 | return Pipeline( 65 | [self, {document[self.key]: document for document in other}] 66 | ) 67 | return Pipeline([self, other]) 68 | 69 | def __or__(self, other) -> Union: 70 | """Union operator.""" 71 | if isinstance(other, Union): 72 | return Union([self] + other.models) 73 | return Union([self, other]) 74 | 75 | def __and__(self, other) -> Intersection: 76 | """Intersection operator.""" 77 | if isinstance(other, Intersection): 78 | return Intersection([self] + other.models) 79 | return Intersection([self, other]) 80 | 81 | def __mul__(self, other) -> Vote: 82 | """Voting operator.""" 83 | if isinstance(other, Vote): 84 | return Vote([self] + other.models) 85 | return Vote([self, other]) 86 | -------------------------------------------------------------------------------- /cherche/retrieve/bm25.py: -------------------------------------------------------------------------------- 1 | __all__ = ["BM25"] 2 | 3 | import typing 4 | 5 | from lenlp import sparse 6 | 7 | from .tfidf import TfIdf 8 | 9 | 10 | class BM25(TfIdf): 11 | """TfIdf retriever based on cosine similarities. 12 | 13 | Parameters 14 | ---------- 15 | key 16 | Field identifier of each document. 17 | on 18 | Fields to use to match the query to the documents. 19 | documents 20 | Documents in TFIdf retriever are static. The retriever must be reseted to index new 21 | documents. 22 | k 23 | Number of documents to retrieve. Default is `None`, i.e all documents that match the query 24 | will be retrieved. 25 | tfidf 26 | TfidfVectorizer class of Sklearn to create a custom TfIdf retriever. 27 | 28 | Examples 29 | -------- 30 | 31 | >>> from pprint import pprint as print 32 | >>> from cherche import retrieve 33 | 34 | >>> documents = [ 35 | ... {"id": 0, "title": "Paris", "article": "Eiffel tower"}, 36 | ... {"id": 1, "title": "Montreal", "article": "Montreal is in Canada."}, 37 | ... {"id": 2, "title": "Paris", "article": "Eiffel tower"}, 38 | ... {"id": 3, "title": "Montreal", "article": "Montreal is in Canada."}, 39 | ... ] 40 | 41 | >>> retriever = retrieve.BM25( 42 | ... key="id", 43 | ... on=["title", "article"], 44 | ... documents=documents, 45 | ... ) 46 | 47 | >>> documents = [ 48 | ... {"id": 4, "title": "Paris", "article": "Eiffel tower"}, 49 | ... {"id": 5, "title": "Montreal", "article": "Montreal is in Canada."}, 50 | ... {"id": 6, "title": "Paris", "article": "Eiffel tower"}, 51 | ... {"id": 7, "title": "Montreal", "article": "Montreal is in Canada."}, 52 | ... ] 53 | 54 | >>> retriever = retriever.add(documents) 55 | 56 | >>> print(retriever(q=["paris", "canada"], k=4)) 57 | [[{'id': 6, 'similarity': 0.5404109029445249}, 58 | {'id': 0, 'similarity': 0.5404109029445249}, 59 | {'id': 2, 'similarity': 0.5404109029445249}, 60 | {'id': 4, 'similarity': 0.5404109029445249}], 61 | [{'id': 7, 'similarity': 0.3157669764669935}, 62 | {'id': 5, 'similarity': 0.3157669764669935}, 63 | {'id': 3, 'similarity': 0.3157669764669935}, 64 | {'id': 1, 'similarity': 0.3157669764669935}]] 65 | 66 | >>> print(retriever(["unknown", "montreal paris"], k=2)) 67 | [[], 68 | [{'id': 7, 'similarity': 0.7391866872635209}, 69 | {'id': 5, 'similarity': 0.7391866872635209}]] 70 | 71 | 72 | >>> print(retriever(q="paris")) 73 | [{'id': 6, 'similarity': 0.5404109029445249}, 74 | {'id': 0, 'similarity': 0.5404109029445249}, 75 | {'id': 2, 'similarity': 0.5404109029445249}, 76 | {'id': 4, 'similarity': 0.5404109029445249}] 77 | 78 | References 79 | ---------- 80 | 1. 
[sklearn.feature_extraction.text.TfidfVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html) 81 | 2. [Python: tf-idf-cosine: to find document similarity](https://stackoverflow.com/questions/12118720/python-tf-idf-cosine-to-find-document-similarity) 82 | 83 | """ 84 | 85 | def __init__( 86 | self, 87 | key: str, 88 | on: typing.Union[str, list], 89 | documents: typing.List[typing.Dict[str, str]] = None, 90 | count_vectorizer: sparse.BM25Vectorizer = None, 91 | k: typing.Optional[int] = None, 92 | batch_size: int = 1024, 93 | fit: bool = True, 94 | ) -> None: 95 | count_vectorizer = ( 96 | sparse.BM25Vectorizer( 97 | normalize=True, ngram_range=(3, 5), analyzer="char_wb" 98 | ) 99 | if count_vectorizer is None 100 | else count_vectorizer 101 | ) 102 | 103 | super().__init__( 104 | key=key, 105 | on=on, 106 | documents=documents, 107 | tfidf=count_vectorizer, 108 | k=k, 109 | batch_size=batch_size, 110 | fit=fit, 111 | ) 112 | -------------------------------------------------------------------------------- /cherche/retrieve/dpr.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DPR"] 2 | 3 | import typing 4 | 5 | import tqdm 6 | 7 | from ..index import Faiss 8 | from ..utils import yield_batch 9 | from .base import Retriever 10 | 11 | 12 | class DPR(Retriever): 13 | """DPR as a retriever using Faiss Index. 14 | 15 | Parameters 16 | ---------- 17 | key 18 | Field identifier of each document. 19 | on 20 | Field to use to retrieve documents. 21 | index 22 | Faiss index that will store the embeddings and perform the similarity search. 23 | normalize 24 | Whether to normalize the embeddings before adding them to the index in order to measure 25 | cosine similarity. 26 | 27 | Examples 28 | -------- 29 | >>> from pprint import pprint as print 30 | >>> from cherche import retrieve 31 | >>> from sentence_transformers import SentenceTransformer 32 | 33 | >>> documents = [ 34 | ... {"id": 0, "title": "Paris France"}, 35 | ... {"id": 1, "title": "Madrid Spain"}, 36 | ... {"id": 2, "title": "Montreal Canada"} 37 | ... ] 38 | 39 | >>> retriever = retrieve.DPR( 40 | ... key = "id", 41 | ... on = ["title"], 42 | ... encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').encode, 43 | ... query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').encode, 44 | ... normalize = True, 45 | ... 
) 46 | 47 | >>> retriever.add(documents) 48 | DPR retriever 49 | key : id 50 | on : title 51 | documents: 3 52 | 53 | >>> print(retriever("Spain", k=2)) 54 | [{'id': 1, 'similarity': 0.5534179127892946}, 55 | {'id': 0, 'similarity': 0.48604427456660426}] 56 | 57 | >>> print(retriever(["Spain", "Montreal"], k=2)) 58 | [[{'id': 1, 'similarity': 0.5534179492996913}, 59 | {'id': 0, 'similarity': 0.4860442182428353}], 60 | [{'id': 2, 'similarity': 0.5451990410703741}, 61 | {'id': 0, 'similarity': 0.47405722260691213}]] 62 | 63 | """ 64 | 65 | def __init__( 66 | self, 67 | key: str, 68 | on: typing.Union[str, list], 69 | encoder, 70 | query_encoder, 71 | normalize: bool = True, 72 | k: typing.Optional[int] = None, 73 | batch_size: int = 64, 74 | index=None, 75 | ) -> None: 76 | super().__init__(key=key, on=on, k=k, batch_size=batch_size) 77 | self.encoder = encoder 78 | self.query_encoder = query_encoder 79 | 80 | if index is None: 81 | self.index = Faiss(key=self.key, normalize=normalize) 82 | else: 83 | self.index = Faiss(key=self.key, index=index, normalize=normalize) 84 | 85 | def __len__(self) -> int: 86 | return len(self.index) 87 | 88 | def add( 89 | self, 90 | documents: typing.List[typing.Dict[str, str]], 91 | batch_size: int = 64, 92 | tqdm_bar: bool = True, 93 | **kwargs, 94 | ) -> "DPR": 95 | """Add documents to the index. 96 | 97 | Parameters 98 | ---------- 99 | documents 100 | List of documents to add the index. 101 | batch_size 102 | Number of documents to encode at once. 103 | """ 104 | 105 | for batch in yield_batch( 106 | array=documents, 107 | batch_size=batch_size, 108 | desc=f"{self.__class__.__name__} index creation", 109 | tqdm_bar=tqdm_bar, 110 | ): 111 | self.index.add( 112 | documents=batch, 113 | embeddings=self.encoder( 114 | [ 115 | " ".join([document.get(field, "") for field in self.on]) 116 | for document in batch 117 | ] 118 | ), 119 | ) 120 | 121 | self.k = len(self.index) 122 | return self 123 | 124 | def __call__( 125 | self, 126 | q: typing.Union[typing.List[str], str], 127 | k: typing.Optional[int] = None, 128 | batch_size: typing.Optional[int] = None, 129 | tqdm_bar: bool = True, 130 | **kwargs, 131 | ) -> typing.Union[ 132 | typing.List[typing.List[typing.Dict[str, str]]], 133 | typing.List[typing.Dict[str, str]], 134 | ]: 135 | """Retrieve documents from the index. 136 | 137 | Parameters 138 | ---------- 139 | q 140 | Either a single query or a list of queries. 141 | k 142 | Number of documents to retrieve. Default is `None`, i.e all documents that match the 143 | query will be retrieved. 144 | batch_size 145 | Number of queries to encode at once. 
146 | """ 147 | k = k if k is not None else len(self) 148 | 149 | rank = [] 150 | 151 | for batch in yield_batch( 152 | array=q, 153 | batch_size=batch_size if batch_size is not None else self.batch_size, 154 | desc=f"{self.__class__.__name__} retriever", 155 | tqdm_bar=tqdm_bar, 156 | ): 157 | rank.extend( 158 | self.index( 159 | embeddings=self.query_encoder(batch), 160 | k=k, 161 | ) 162 | ) 163 | 164 | return rank[0] if isinstance(q, str) else rank 165 | -------------------------------------------------------------------------------- /cherche/retrieve/embedding.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | 5 | from ..index import Faiss 6 | from ..utils import yield_batch 7 | from .base import Retriever 8 | 9 | __all__ = ["Embedding"] 10 | 11 | 12 | class Embedding(Retriever): 13 | """The Embedding retriever is dedicated to perform IR on embeddings calculated by the user 14 | rather than Cherche. 15 | 16 | Parameters 17 | ---------- 18 | key 19 | Field identifier of each document. 20 | index 21 | Faiss index that will store the embeddings and perform the similarity search. 22 | normalize 23 | Whether to normalize the embeddings before adding them to the index in order to measure 24 | cosine similarity. 25 | 26 | Examples 27 | -------- 28 | >>> from pprint import pprint as print 29 | >>> from cherche import retrieve 30 | >>> from sentence_transformers import SentenceTransformer 31 | 32 | >>> recommend = retrieve.Embedding( 33 | ... key="id", 34 | ... ) 35 | 36 | >>> documents = [ 37 | ... {"id": "a", "title": "Paris", "author": "Paris"}, 38 | ... {"id": "b", "title": "Madrid", "author": "Madrid"}, 39 | ... {"id": "c", "title": "Montreal", "author": "Montreal"}, 40 | ... ] 41 | 42 | >>> encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") 43 | >>> embeddings_documents = encoder.encode([ 44 | ... document["title"] for document in documents 45 | ... ]) 46 | 47 | >>> recommend.add( 48 | ... documents=documents, 49 | ... embeddings_documents=embeddings_documents, 50 | ... ) 51 | Embedding retriever 52 | key : id 53 | documents: 3 54 | 55 | >>> queries = [ 56 | ... "Paris", 57 | ... "Madrid", 58 | ... "Montreal" 59 | ... 
] 60 | 61 | >>> embeddings_queries = encoder.encode(queries) 62 | >>> print(recommend(embeddings_queries, k=2)) 63 | [[{'id': 'a', 'similarity': 1.0}, 64 | {'id': 'c', 'similarity': 0.5385907831761005}], 65 | [{'id': 'b', 'similarity': 1.0}, 66 | {'id': 'a', 'similarity': 0.4990788711758875}], 67 | [{'id': 'c', 'similarity': 1.0}, 68 | {'id': 'a', 'similarity': 0.5385907831761005}]] 69 | 70 | >>> embeddings_queries = encoder.encode("Paris") 71 | >>> print(recommend(embeddings_queries, k=2)) 72 | [{'id': 'a', 'similarity': 0.9999999999989104}, 73 | {'id': 'c', 'similarity': 0.5385907485958683}] 74 | 75 | """ 76 | 77 | def __init__( 78 | self, 79 | key: str, 80 | index=None, 81 | normalize: bool = True, 82 | k: typing.Optional[int] = None, 83 | batch_size: int = 1024, 84 | ) -> None: 85 | super().__init__(key=key, on="", k=k, batch_size=batch_size) 86 | 87 | if index is None: 88 | self.index = Faiss(key=self.key, normalize=normalize) 89 | else: 90 | self.index = Faiss(key=self.key, index=index, normalize=normalize) 91 | 92 | def __repr__(self) -> str: 93 | repr = f"{self.__class__.__name__} retriever" 94 | repr += f"\n\tkey : {self.key}" 95 | repr += f"\n\tdocuments: {len(self)}" 96 | return repr 97 | 98 | def __len__(self) -> int: 99 | return len(self.index) 100 | 101 | def add( 102 | self, 103 | documents: list, 104 | embeddings_documents: np.ndarray, 105 | **kwargs, 106 | ) -> "Embedding": 107 | """Add embeddings both documents and users. 108 | 109 | Parameters 110 | ---------- 111 | documents 112 | List of documents to add to the index. 113 | 114 | embeddings_documents 115 | Embeddings of the documents ordered as the list of documents. 116 | """ 117 | self.index.add( 118 | documents=documents, 119 | embeddings=embeddings_documents, 120 | ) 121 | return self 122 | 123 | def __call__( 124 | self, 125 | q: np.ndarray, 126 | k: typing.Optional[int] = None, 127 | batch_size: typing.Optional[int] = None, 128 | tqdm_bar: bool = True, 129 | **kwargs, 130 | ) -> typing.Union[ 131 | typing.List[typing.List[typing.Dict[str, str]]], 132 | typing.List[typing.Dict[str, str]], 133 | ]: 134 | """Retrieve documents from the index. 135 | 136 | Parameters 137 | ---------- 138 | q 139 | Either a single query or a list of queries. 140 | k 141 | Number of documents to retrieve. Default is `None`, i.e all documents that match the 142 | query will be retrieved. 143 | batch_size 144 | Number of queries to encode at once. 145 | """ 146 | k = k if k is not None else len(self) 147 | 148 | if len(q.shape) == 1: 149 | q = q.reshape(1, -1) 150 | 151 | rank = [] 152 | for batch in yield_batch( 153 | array=q, 154 | batch_size=batch_size if batch_size is not None else self.batch_size, 155 | desc=f"{self.__class__.__name__} retriever", 156 | tqdm_bar=tqdm_bar, 157 | ): 158 | rank.extend( 159 | self.index( 160 | embeddings=batch, 161 | k=k, 162 | ) 163 | ) 164 | 165 | return rank[0] if len(q) == 1 else rank 166 | -------------------------------------------------------------------------------- /cherche/retrieve/encoder.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Encoder"] 2 | 3 | import typing 4 | 5 | import tqdm 6 | 7 | from ..index import Faiss 8 | from ..utils import yield_batch 9 | from .base import Retriever 10 | 11 | 12 | class Encoder(Retriever): 13 | """Encoder as a retriever using Faiss Index. 14 | 15 | Parameters 16 | ---------- 17 | key 18 | Field identifier of each document. 19 | on 20 | Field to use to retrieve documents. 
21 | index 22 | Faiss index that will store the embeddings and perform the similarity search. 23 | normalize 24 | Whether to normalize the embeddings before adding them to the index in order to measure 25 | cosine similarity. 26 | 27 | Examples 28 | -------- 29 | >>> from pprint import pprint as print 30 | >>> from cherche import retrieve 31 | >>> from sentence_transformers import SentenceTransformer 32 | 33 | >>> documents = [ 34 | ... {"id": 0, "title": "Paris France"}, 35 | ... {"id": 1, "title": "Madrid Spain"}, 36 | ... {"id": 2, "title": "Montreal Canada"} 37 | ... ] 38 | 39 | >>> retriever = retrieve.Encoder( 40 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 41 | ... key = "id", 42 | ... on = ["title"], 43 | ... ) 44 | 45 | >>> retriever.add(documents, batch_size=1) 46 | Encoder retriever 47 | key : id 48 | on : title 49 | documents: 3 50 | 51 | >>> print(retriever("Spain", k=2)) 52 | [{'id': 1, 'similarity': 0.6544566453117681}, 53 | {'id': 0, 'similarity': 0.5405465419981407}] 54 | 55 | >>> print(retriever(["Spain", "Montreal"], k=2)) 56 | [[{'id': 1, 'similarity': 0.6544566453117681}, 57 | {'id': 0, 'similarity': 0.54054659424589}], 58 | [{'id': 2, 'similarity': 0.7372165680578416}, 59 | {'id': 0, 'similarity': 0.5185645704259234}]] 60 | 61 | """ 62 | 63 | def __init__( 64 | self, 65 | encoder, 66 | key: str, 67 | on: typing.Union[str, list], 68 | normalize: bool = True, 69 | k: typing.Optional[int] = None, 70 | batch_size: int = 64, 71 | index=None, 72 | ) -> None: 73 | super().__init__( 74 | key=key, 75 | on=on, 76 | k=k, 77 | batch_size=batch_size, 78 | ) 79 | self.encoder = encoder 80 | 81 | if index is None: 82 | self.index = Faiss(key=self.key, normalize=normalize) 83 | else: 84 | self.index = Faiss(key=self.key, index=index, normalize=normalize) 85 | 86 | def __len__(self) -> int: 87 | return len(self.index) 88 | 89 | def add( 90 | self, 91 | documents: typing.List[typing.Dict[str, str]], 92 | batch_size: int = 64, 93 | tqdm_bar: bool = True, 94 | **kwargs, 95 | ) -> "Encoder": 96 | """Add documents to the index. 97 | 98 | Parameters 99 | ---------- 100 | documents 101 | List of documents to add to the index. 102 | batch_size 103 | Number of documents to encode at once. 104 | """ 105 | 106 | for batch in yield_batch( 107 | array=documents, 108 | batch_size=batch_size, 109 | desc=f"{self.__class__.__name__} index creation", 110 | tqdm_bar=tqdm_bar, 111 | ): 112 | self.index.add( 113 | documents=batch, 114 | embeddings=self.encoder( 115 | [ 116 | " ".join([document.get(field, "") for field in self.on]) 117 | for document in batch 118 | ] 119 | ), 120 | ) 121 | 122 | return self 123 | 124 | def __call__( 125 | self, 126 | q: typing.Union[typing.List[str], str], 127 | k: typing.Optional[int] = None, 128 | batch_size: typing.Optional[int] = None, 129 | tqdm_bar: bool = True, 130 | **kwargs, 131 | ) -> typing.Union[ 132 | typing.List[typing.List[typing.Dict[str, str]]], 133 | typing.List[typing.Dict[str, str]], 134 | ]: 135 | """Retrieve documents from the index. 136 | 137 | Parameters 138 | ---------- 139 | q 140 | Either a single query or a list of queries. 141 | k 142 | Number of documents to retrieve. Default is `None`, i.e all documents that match the 143 | query will be retrieved. 144 | batch_size 145 | Number of queries to encode at once. 
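# ---------------------------------------------------------------------------
# Hedged sketch (illustration only): the ``index`` argument lets the Encoder
# retriever wrap a user-supplied Faiss index instead of the default one. The
# example assumes a flat inner-product index and that all-mpnet-base-v2
# returns 768-dimensional embeddings; with ``normalize=True`` the reported
# similarities behave like cosine similarities.
import faiss
from cherche import retrieve
from sentence_transformers import SentenceTransformer

documents = [
    {"id": 0, "title": "Paris France"},
    {"id": 1, "title": "Madrid Spain"},
    {"id": 2, "title": "Montreal Canada"},
]

retriever = retrieve.Encoder(
    encoder=SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode,
    key="id",
    on=["title"],
    index=faiss.IndexFlatIP(768),  # assumed embedding dimension of the encoder
    normalize=True,
)

retriever = retriever.add(documents, batch_size=2)

# Batched querying without a progress bar: one ranked list per query.
ranked = retriever(["Spain", "Montreal"], k=2, batch_size=2, tqdm_bar=False)
# ---------------------------------------------------------------------------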
146 | """ 147 | k = k if k is not None else len(self) 148 | 149 | rank = [] 150 | for batch in yield_batch( 151 | array=q, 152 | batch_size=batch_size if batch_size is not None else self.batch_size, 153 | desc=f"{self.__class__.__name__} retriever", 154 | tqdm_bar=tqdm_bar, 155 | ): 156 | rank.extend( 157 | self.index( 158 | embeddings=self.encoder(batch), 159 | k=k, 160 | ) 161 | ) 162 | 163 | return rank[0] if isinstance(q, str) else rank 164 | -------------------------------------------------------------------------------- /cherche/retrieve/flash.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Flash"] 2 | 3 | import collections 4 | import typing 5 | from itertools import chain 6 | 7 | from flashtext import KeywordProcessor 8 | 9 | from ..utils import yield_batch_single 10 | from .base import Retriever 11 | 12 | 13 | class Flash(Retriever): 14 | """FlashText Retriever. Flash aims to find documents that contain keywords such as a list of 15 | tags for example. 16 | 17 | Parameters 18 | ---------- 19 | key 20 | Field identifier of each document. 21 | on 22 | Fields to use to match the query to the documents. 23 | keywords 24 | Keywords extractor from [FlashText](https://github.com/vi3k6i5/flashtext). If set to None, 25 | a default one is created. 26 | 27 | Examples 28 | -------- 29 | >>> from pprint import pprint as print 30 | >>> from cherche import retrieve 31 | 32 | >>> documents = [ 33 | ... {"id": 0, "title": "paris", "article": "eiffel tower"}, 34 | ... {"id": 1, "title": "paris", "article": "paris"}, 35 | ... {"id": 2, "title": "montreal", "article": "montreal is in canada"}, 36 | ... ] 37 | 38 | >>> retriever = retrieve.Flash(key="id", on=["title", "article"]) 39 | 40 | >>> retriever.add(documents=documents) 41 | Flash retriever 42 | key : id 43 | on : title, article 44 | documents: 4 45 | 46 | >>> print(retriever(q="paris", k=2)) 47 | [{'id': 1, 'similarity': 0.6666666666666666}, 48 | {'id': 0, 'similarity': 0.3333333333333333}] 49 | 50 | [{'id': 0, 'similarity': 1}, {'id': 1, 'similarity': 1}] 51 | 52 | >>> print(retriever(q=["paris", "montreal"])) 53 | [[{'id': 1, 'similarity': 0.6666666666666666}, 54 | {'id': 0, 'similarity': 0.3333333333333333}], 55 | [{'id': 2, 'similarity': 1.0}]] 56 | 57 | References 58 | ---------- 59 | 1. [FlashText](https://github.com/vi3k6i5/flashtext) 60 | 2. [Replace or Retrieve Keywords In Documents at Scale](https://arxiv.org/abs/1711.00046) 61 | 62 | """ 63 | 64 | def __init__( 65 | self, 66 | key: str, 67 | on: typing.Union[str, list], 68 | keywords: KeywordProcessor = None, 69 | lowercase: bool = True, 70 | k: typing.Optional[int] = None, 71 | ) -> None: 72 | super().__init__(key=key, on=on, k=k, batch_size=1) 73 | self.documents = collections.defaultdict(list) 74 | self.keywords = KeywordProcessor() if keywords is None else keywords 75 | self.lowercase = lowercase 76 | 77 | def add(self, documents: typing.List[typing.Dict[str, str]], **kwargs) -> "Flash": 78 | """Add keywords to the retriever. 79 | 80 | Parameters 81 | ---------- 82 | documents 83 | List of documents to add to the retriever. 
84 | 85 | """ 86 | for document in documents: 87 | for field in self.on: 88 | if field not in document: 89 | continue 90 | 91 | if isinstance(document[field], str): 92 | words = document[field] 93 | if self.lowercase: 94 | words = words.lower() 95 | self.documents[words].append({self.key: document[self.key]}) 96 | self.keywords.add_keyword(words) 97 | 98 | elif isinstance(document[field], list): 99 | words = document[field] 100 | if self.lowercase: 101 | words = [word.lower() for word in words] 102 | 103 | for word in words: 104 | self.documents[word].append({self.key: document[self.key]}) 105 | self.keywords.add_keywords_from_list(words) 106 | 107 | return self 108 | 109 | def __call__( 110 | self, 111 | q: typing.Union[typing.List[str], str], 112 | k: typing.Optional[int] = None, 113 | tqdm_bar: bool = True, 114 | **kwargs, 115 | ) -> list: 116 | """Retrieve documents from the index. 117 | 118 | Parameters 119 | ---------- 120 | q 121 | Either a single query or a list of queries. 122 | k 123 | Number of documents to retrieve. Default is `None`, i.e all documents that match the 124 | query will be retrieved. 125 | """ 126 | rank = [] 127 | 128 | for batch in yield_batch_single( 129 | q, desc=f"{self.__class__.__name__} retriever", tqdm_bar=tqdm_bar 130 | ): 131 | if self.lowercase: 132 | batch = batch.lower() 133 | 134 | match = list( 135 | chain.from_iterable( 136 | [ 137 | self.documents[tag] 138 | for tag in self.keywords.extract_keywords(batch) 139 | ] 140 | ) 141 | ) 142 | 143 | scores = collections.defaultdict(int) 144 | for document in match: 145 | scores[document[self.key]] += 1 146 | 147 | total = len(match) 148 | 149 | documents = [ 150 | {self.key: key, "similarity": scores[key] / total} 151 | for key in sorted(scores, key=scores.get, reverse=True) 152 | ] 153 | 154 | documents = documents[:k] if k is not None else documents 155 | rank.append(documents) 156 | 157 | return rank[0] if isinstance(q, str) else rank 158 | -------------------------------------------------------------------------------- /cherche/retrieve/lunr.py: -------------------------------------------------------------------------------- 1 | __all__ = ["Lunr"] 2 | 3 | import re 4 | import typing 5 | 6 | from lunr import lunr 7 | 8 | from ..utils import yield_batch_single 9 | from .base import Retriever 10 | 11 | 12 | class Lunr(Retriever): 13 | """Lunr is a Python implementation of Lunr.js by Oliver Nightingale. Lunr is a retriever 14 | dedicated for small and middle size corpus. 15 | 16 | Parameters 17 | ---------- 18 | key 19 | Field identifier of each document. 20 | on 21 | Fields to use to match the query to the documents. 22 | documents 23 | Documents in Lunr retriever are static. The retriever must be reseted to index new 24 | documents. 25 | 26 | Examples 27 | -------- 28 | >>> from pprint import pprint as print 29 | >>> from cherche import retrieve 30 | 31 | >>> documents = [ 32 | ... {"id": 0, "title": "Paris", "article": "Eiffel tower"}, 33 | ... {"id": 1, "title": "Paris", "article": "Paris is in France."}, 34 | ... {"id": 2, "title": "Montreal", "article": "Montreal is in Canada."}, 35 | ... ] 36 | 37 | >>> retriever = retrieve.Lunr( 38 | ... key="id", 39 | ... on=["title", "article"], 40 | ... documents=documents, 41 | ... 
) 42 | 43 | >>> retriever 44 | Lunr retriever 45 | key : id 46 | on : title, article 47 | documents: 3 48 | 49 | >>> print(retriever(q="paris", k=2)) 50 | [{'id': 1, 'similarity': 0.268}, {'id': 0, 'similarity': 0.134}] 51 | 52 | >>> print(retriever(q=["paris", "montreal"], k=2)) 53 | [[{'id': 1, 'similarity': 0.268}, {'id': 0, 'similarity': 0.134}], 54 | [{'id': 2, 'similarity': 0.94}]] 55 | 56 | 57 | References 58 | ---------- 59 | 1. [Lunr.py](https://github.com/yeraydiazdiaz/lunr.py) 60 | 2. [Lunr.js](https://lunrjs.com) 61 | 2. [Solr](https://solr.apache.org) 62 | 63 | """ 64 | 65 | def __init__( 66 | self, 67 | key: str, 68 | on: typing.Union[str, list], 69 | documents: list, 70 | k: typing.Optional[int] = None, 71 | ) -> None: 72 | super().__init__(key=key, on=on, k=k, batch_size=1) 73 | 74 | self.documents = { 75 | str(document[self.key]): {self.key: document[self.key]} 76 | for document in documents 77 | } 78 | 79 | self.idx = lunr( 80 | ref=self.key, 81 | fields=tuple(self.on), 82 | documents=[ 83 | {field: document.get(field, "") for field in [self.key] + self.on} 84 | for document in documents 85 | ], 86 | ) 87 | 88 | def __call__( 89 | self, 90 | q: typing.Union[str, typing.List[str]], 91 | k: typing.Optional[int] = None, 92 | tqdm_bar: bool = True, 93 | **kwargs, 94 | ) -> typing.Union[ 95 | typing.List[typing.List[typing.Dict[str, str]]], 96 | typing.List[typing.Dict[str, str]], 97 | ]: 98 | """Retrieve documents from the index. 99 | 100 | Parameters 101 | ---------- 102 | q 103 | Either a single query or a list of queries. 104 | k 105 | Number of documents to retrieve. Default is `None`, i.e all documents that match the 106 | query will be retrieved. 107 | """ 108 | rank = [] 109 | 110 | for batch in yield_batch_single( 111 | array=q, 112 | desc=f"{self.__class__.__name__} retriever", 113 | tqdm_bar=tqdm_bar, 114 | ): 115 | batch = re.sub("[^a-zA-Z0-9 \n\.]", " ", batch) 116 | documents = [ 117 | {**self.documents[match["ref"]], "similarity": match["score"]} 118 | for match in self.idx.search(batch) 119 | ] 120 | documents = documents[:k] if k is not None else documents 121 | rank.append(documents) 122 | 123 | return rank[0] if isinstance(q, str) else rank 124 | -------------------------------------------------------------------------------- /cherche/retrieve/test_retrieve.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .. import retrieve 4 | 5 | 6 | def cherche_retrievers(on: str): 7 | """List of retrievers available in cherche.""" 8 | yield from [ 9 | retrieve.TfIdf(key="title", on=on, documents=documents()), 10 | retrieve.Lunr(key="title", on=on, documents=documents()), 11 | ] 12 | 13 | 14 | def documents(): 15 | return [ 16 | { 17 | "title": "Paris", 18 | "article": "This town is the capital of France", 19 | "author": "Wikipedia", 20 | }, 21 | { 22 | "title": "Eiffel tower", 23 | "article": "Eiffel tower is based in Paris", 24 | "author": "Wikipedia", 25 | }, 26 | { 27 | "title": "Montreal", 28 | "article": "Montreal is in Canada.", 29 | "author": "Wikipedia", 30 | }, 31 | ] 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "retriever, documents, k", 36 | [ 37 | pytest.param( 38 | retriever, 39 | documents(), 40 | k, 41 | id=f"retriever: {retriever.__class__.__name__}, k: {k}", 42 | ) 43 | for k in [None, 2, 4] 44 | for retriever in cherche_retrievers(on="article") 45 | ], 46 | ) 47 | def test_retriever(retriever, documents: list, k: int): 48 | """Test retriever. 
Test if the number of retrieved documents is coherent. 49 | Check for unknown tokens in the corpus, should returns an empty list. 50 | """ 51 | retriever = retriever + documents 52 | retriever.add(documents) 53 | 54 | # A single document contains town. 55 | answers = retriever(q="town", k=k) 56 | if k is None or k >= 1: 57 | assert len(answers) >= 1 58 | else: 59 | assert len(answers) == 0 60 | 61 | for sample in answers: 62 | for key in ["title", "article", "author"]: 63 | assert key in sample 64 | 65 | # Unknown token. 66 | answers = retriever(q="un", k=k) 67 | assert len(answers) == 0 68 | 69 | # All documents contains "Montreal Eiffel France" 70 | answers = retriever(q="Montreal Eiffel France", k=k) 71 | if k is None or k >= len(documents): 72 | assert len(answers) == len(documents) 73 | else: 74 | assert len(answers) == k 75 | 76 | 77 | @pytest.mark.parametrize( 78 | "retriever, documents, k", 79 | [ 80 | pytest.param( 81 | retriever, 82 | documents(), 83 | k, 84 | id=f"Multiple fields retriever: {retriever.__class__.__name__}, k: {k}", 85 | ) 86 | for k in [None, 2, 4] 87 | for retriever in cherche_retrievers(on=["article", "title", "author"]) 88 | ], 89 | ) 90 | def test_fields_retriever(retriever, documents: list, k: int): 91 | """Test retriever when providing multiples fields.""" 92 | retriever = retriever + documents 93 | 94 | # All documents have Wikipedia as author. 95 | answers = retriever(q="Wikipedia", k=k) 96 | if k is None or k >= len(documents): 97 | assert len(answers) == len(documents) 98 | else: 99 | assert len(answers) == max(k, 0) 100 | 101 | for sample in answers: 102 | for key in ["title", "article", "author"]: 103 | assert key in sample 104 | 105 | # Unknown token. 106 | answers = retriever(q="un") 107 | assert len(answers) == 0 108 | 109 | # Two documents contains paris 110 | answers = retriever(q="Paris", k=k) 111 | 112 | if k is None or k >= 2: 113 | assert len(answers) == 2 114 | else: 115 | assert len(answers) == max(k, 0) 116 | 117 | 118 | @pytest.mark.parametrize( 119 | "documents, k", 120 | [ 121 | pytest.param( 122 | documents(), 123 | k, 124 | id=f"retriever: Flash, k: {k}", 125 | ) 126 | for k in [None, 2, 4] 127 | ], 128 | ) 129 | def test_flash(documents: list, k: int): 130 | """Test Flash retriever.""" 131 | # Reset retriever 132 | retriever = retrieve.Flash(key="title", on="title") + documents 133 | retriever.add(documents) 134 | 135 | # A single document contains town. 136 | answers = retriever(q="paris", k=k) 137 | if k is None or k >= 1: 138 | assert len(answers) == 1 139 | else: 140 | assert len(answers) == 0 141 | 142 | for sample in answers: 143 | for key in ["title", "article", "author"]: 144 | assert key in sample 145 | 146 | # Unknown token. 
147 | answers = retriever(q="Unknown", k=k) 148 | assert len(answers) == 0 149 | 150 | # All documents contains is 151 | answers = retriever(q="Paris Eiffel tower Montreal", k=k) 152 | 153 | if k is None or k >= len(documents): 154 | assert len(answers) == len(documents) 155 | else: 156 | assert len(answers) == k 157 | -------------------------------------------------------------------------------- /cherche/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch import yield_batch, yield_batch_single 2 | from .quantize import quantize 3 | from .topk import TopK 4 | 5 | __all__ = ["quantize", "yield_batch", "yield_batch_single", "TopK"] 6 | -------------------------------------------------------------------------------- /cherche/utils/batch.py: -------------------------------------------------------------------------------- 1 | __all__ = ["yield_batch", "yield_batch_single"] 2 | 3 | import typing 4 | 5 | import numpy as np 6 | import tqdm 7 | 8 | 9 | def yield_batch_single( 10 | array: typing.Union[ 11 | typing.Union[typing.List[str], str], 12 | typing.List[typing.Dict[str, typing.Any]], 13 | ], 14 | desc: str, 15 | tqdm_bar: bool = True, 16 | ): 17 | """Yield successive n-sized chunks from array.""" 18 | if isinstance(array, str): 19 | yield array 20 | elif tqdm_bar: 21 | for batch in tqdm.tqdm( 22 | array, 23 | position=0, 24 | desc=desc, 25 | total=len(array), 26 | ): 27 | yield batch 28 | else: 29 | for batch in array: 30 | yield batch 31 | 32 | 33 | def yield_batch( 34 | array: typing.Union[ 35 | typing.Union[ 36 | typing.Union[typing.List[str], str], 37 | typing.List[typing.Dict[str, typing.Any]], 38 | ], 39 | np.ndarray, 40 | ], 41 | batch_size: int, 42 | desc: str, 43 | tqdm_bar: bool = True, 44 | ) -> typing.Generator: 45 | """Yield successive n-sized chunks from array.""" 46 | if isinstance(array, str): 47 | yield [array] 48 | elif tqdm_bar: 49 | for batch in tqdm.tqdm( 50 | [array[pos : pos + batch_size] for pos in range(0, len(array), batch_size)], 51 | position=0, 52 | desc=desc, 53 | total=1 + len(array) // batch_size, 54 | ): 55 | yield batch 56 | else: 57 | for batch in [ 58 | array[pos : pos + batch_size] for pos in range(0, len(array), batch_size) 59 | ]: 60 | yield batch 61 | -------------------------------------------------------------------------------- /cherche/utils/quantize.py: -------------------------------------------------------------------------------- 1 | __all__ = ["quantize"] 2 | 3 | 4 | def quantize(model, dtype=None, layers=None, engine="qnnpack"): 5 | """Quantize model to speedup inference. May reduce accuracy. 6 | 7 | Parameters 8 | ---------- 9 | model 10 | Transformer model to quantize. 11 | dtype 12 | Dtype to apply to selected layers. 13 | layers 14 | Layers to quantize. 15 | engine 16 | The qengine specifies which backend is to be used for execution. 17 | 18 | Examples 19 | -------- 20 | >>> from pprint import pprint as print 21 | >>> from cherche import utils, retrieve 22 | >>> from sentence_transformers import SentenceTransformer 23 | 24 | >>> documents = [ 25 | ... {"id": 0, "title": "Paris France"}, 26 | ... {"id": 1, "title": "Madrid Spain"}, 27 | ... {"id": 2, "title": "Montreal Canada"} 28 | ... ] 29 | 30 | >>> encoder = utils.quantize(SentenceTransformer("sentence-transformers/all-mpnet-base-v2")) 31 | 32 | >>> retriever = retrieve.Encoder( 33 | ... encoder = encoder.encode, 34 | ... key = "id", 35 | ... on = ["title"], 36 | ... 
) 37 | 38 | >>> retriever = retriever.add(documents) 39 | 40 | >>> print(retriever("paris")) 41 | [{'id': 0, 'similarity': 0.6361529519968355}, 42 | {'id': 2, 'similarity': 0.42750324298964354}, 43 | {'id': 1, 'similarity': 0.42645383885361576}] 44 | 45 | References 46 | ---------- 47 | 1. [PyTorch Quantization](https://pytorch.org/docs/stable/quantization.html) 48 | 49 | """ 50 | try: 51 | import torch 52 | except ImportError: 53 | raise ImportError( 54 | "Run pip install cherche[cpu] or pip install cherche[gpu] to use quantize." 55 | ) 56 | 57 | if dtype is None: 58 | dtype = torch.qint8 59 | 60 | if layers is None: 61 | layers = {torch.nn.Linear} 62 | 63 | torch.backends.quantized.engine = engine 64 | return torch.quantization.quantize_dynamic(model, layers, dtype) 65 | -------------------------------------------------------------------------------- /cherche/utils/topk.py: -------------------------------------------------------------------------------- 1 | __all__ = ["TopK"] 2 | 3 | import typing 4 | 5 | 6 | class TopK: 7 | """Filter top k documents in pipeline. 8 | 9 | Parameters 10 | ---------- 11 | k 12 | Number of documents to keep. 13 | 14 | Examples 15 | -------- 16 | 17 | >>> from pprint import pprint as print 18 | >>> from cherche import retrieve, rank, utils 19 | >>> from sentence_transformers import SentenceTransformer 20 | 21 | >>> documents = [ 22 | ... {"id": 0, "title": "Paris France"}, 23 | ... {"id": 1, "title": "Madrid Spain"}, 24 | ... {"id": 2, "title": "Montreal Canada"} 25 | ... ] 26 | 27 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 28 | 29 | >>> ranker = rank.Encoder( 30 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 31 | ... key = "id", 32 | ... on = ["title"], 33 | ... 
) 34 | 35 | >>> pipeline = retriever + ranker + utils.TopK(k=2) 36 | >>> pipeline.add(documents=documents) 37 | TfIdf retriever 38 | key : id 39 | on : title, article 40 | documents: 3 41 | Encoder ranker 42 | key : id 43 | on : title 44 | normalize : True 45 | embeddings: 3 46 | Filter TopK 47 | k: 2 48 | 49 | >>> print(pipeline(q="Paris Madrid Montreal", k=2)) 50 | [{'id': 0, 'similarity': 0.62922895}, {'id': 2, 'similarity': 0.61419094}] 51 | 52 | """ 53 | 54 | def __init__(self, k: int): 55 | self.k = k 56 | 57 | def __repr__(self) -> str: 58 | repr = f"Filter {self.__class__.__name__}" 59 | repr += f"\n\tk: {self.k}" 60 | return repr 61 | 62 | def __call__( 63 | self, 64 | documents: typing.Union[typing.List[typing.List[typing.Dict[str, str]]]], 65 | **kwargs, 66 | ) -> typing.Union[ 67 | typing.List[typing.List[typing.Dict[str, str]]], 68 | typing.List[typing.Dict[str, str]], 69 | ]: 70 | """Filter top k documents in pipeline.""" 71 | if not documents: 72 | return [] 73 | 74 | if isinstance(documents[0], list): 75 | return [document[: self.k] for document in documents] 76 | 77 | return documents[: self.k] 78 | -------------------------------------------------------------------------------- /docs/.pages: -------------------------------------------------------------------------------- 1 | nav: 2 | - Home: index.md 3 | - documents 4 | - retrieve 5 | - rank 6 | - qa 7 | - pipeline 8 | - api 9 | - examples 10 | - serialize 11 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | raphaelsty.github.io/cherche/ -------------------------------------------------------------------------------- /docs/api/.pages: -------------------------------------------------------------------------------- 1 | title: API reference 2 | arrange: 3 | - overview.md 4 | - ... 5 | -------------------------------------------------------------------------------- /docs/api/compose/.pages: -------------------------------------------------------------------------------- 1 | title: compose -------------------------------------------------------------------------------- /docs/api/compose/Intersection.md: -------------------------------------------------------------------------------- 1 | # Intersection 2 | 3 | Intersection gathers retrieved documents from multiples retrievers and ranked documents from multiples rankers only if they are proposed by all models of the intersection pipeline. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **models** (*list*) 10 | 11 | List of models of the union. 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from pprint import pprint as print 19 | >>> from cherche import retrieve 20 | 21 | >>> documents = [ 22 | ... {"id": 0, "town": "Paris", "country": "France", "continent": "Europe"}, 23 | ... {"id": 1, "town": "Montreal", "country": "Canada", "continent": "North America"}, 24 | ... {"id": 2, "town": "Madrid", "country": "Spain", "continent": "Europe"}, 25 | ... ] 26 | 27 | >>> search = ( 28 | ... retrieve.TfIdf(key="id", on="town", documents=documents) & 29 | ... retrieve.TfIdf(key="id", on="country", documents=documents) & 30 | ... retrieve.Flash(key="id", on="continent") 31 | ... 
) 32 | 33 | >>> search = search.add(documents) 34 | 35 | >>> print(search("Paris")) 36 | [] 37 | 38 | >>> print(search(["Paris", "Europe"])) 39 | [[], []] 40 | 41 | >>> print(search(["Paris", "Europe", "Paris Madrid Europe France Spain"])) 42 | [[], 43 | [], 44 | [{'id': 2, 'similarity': 4.25}, {'id': 0, 'similarity': 3.0999999999999996}]] 45 | ``` 46 | 47 | ## Methods 48 | 49 | ???- note "__call__" 50 | 51 | Call self as a function. 52 | 53 | **Parameters** 54 | 55 | - **q** (*Union[List[List[Dict[str, str]]], List[Dict[str, str]]]*) 56 | - **batch_size** (*Optional[int]*) – defaults to `None` 57 | - **k** (*Optional[int]*) – defaults to `None` 58 | - **documents** (*Optional[List[Dict[str, str]]]*) – defaults to `None` 59 | - **kwargs** 60 | 61 | ???- note "add" 62 | 63 | Add new documents. 64 | 65 | **Parameters** 66 | 67 | - **documents** (*list*) 68 | - **kwargs** 69 | 70 | ???- note "reset" 71 | 72 | -------------------------------------------------------------------------------- /docs/api/compose/Pipeline.md: -------------------------------------------------------------------------------- 1 | # Pipeline 2 | 3 | Neurals search pipeline. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **models** (*list*) 10 | 11 | List of models of the pipeline. 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from pprint import pprint as print 19 | >>> from cherche import retrieve, rank 20 | >>> from sentence_transformers import SentenceTransformer 21 | 22 | >>> documents = [ 23 | ... {"id": 0, "town": "Paris", "country": "France", "continent": "Europe"}, 24 | ... {"id": 1, "town": "Montreal", "country": "Canada", "continent": "North America"}, 25 | ... {"id": 2, "town": "Madrid", "country": "Spain", "continent": "Europe"}, 26 | ... ] 27 | 28 | >>> retriever = retrieve.TfIdf( 29 | ... key="id", on=["town", "country", "continent"], documents=documents) 30 | 31 | >>> ranker = rank.Encoder( 32 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 33 | ... key = "id", 34 | ... on = ["town", "country", "continent"], 35 | ... 
) 36 | 37 | >>> pipeline = retriever + ranker 38 | 39 | >>> pipeline = pipeline.add(documents) 40 | 41 | >>> print(pipeline("Paris Europe")) 42 | [{'id': 0, 'similarity': 0.9149576}, {'id': 2, 'similarity': 0.8091332}] 43 | 44 | >>> print(pipeline(["Paris", "Europe", "Paris Madrid Europe France Spain"])) 45 | [[{'id': 0, 'similarity': 0.69523287}], 46 | [{'id': 0, 'similarity': 0.7381397}, {'id': 2, 'similarity': 0.6488539}], 47 | [{'id': 0, 'similarity': 0.8582063}, {'id': 2, 'similarity': 0.8200009}]] 48 | 49 | >>> pipeline = retriever + ranker + documents 50 | 51 | >>> print(pipeline("Paris Europe")) 52 | [{'continent': 'Europe', 53 | 'country': 'France', 54 | 'id': 0, 55 | 'similarity': 0.9149576, 56 | 'town': 'Paris'}, 57 | {'continent': 'Europe', 58 | 'country': 'Spain', 59 | 'id': 2, 60 | 'similarity': 0.8091332, 61 | 'town': 'Madrid'}] 62 | 63 | >>> print(pipeline(["Paris", "Europe", "Paris Madrid Europe France Spain"])) 64 | [[{'continent': 'Europe', 65 | 'country': 'France', 66 | 'id': 0, 67 | 'similarity': 0.69523287, 68 | 'town': 'Paris'}], 69 | [{'continent': 'Europe', 70 | 'country': 'France', 71 | 'id': 0, 72 | 'similarity': 0.7381397, 73 | 'town': 'Paris'}, 74 | {'continent': 'Europe', 75 | 'country': 'Spain', 76 | 'id': 2, 77 | 'similarity': 0.6488539, 78 | 'town': 'Madrid'}], 79 | [{'continent': 'Europe', 80 | 'country': 'France', 81 | 'id': 0, 82 | 'similarity': 0.8582063, 83 | 'town': 'Paris'}, 84 | {'continent': 'Europe', 85 | 'country': 'Spain', 86 | 'id': 2, 87 | 'similarity': 0.8200009, 88 | 'town': 'Madrid'}]] 89 | ``` 90 | 91 | ## Methods 92 | 93 | ???- note "__call__" 94 | 95 | Pipeline main method. It takes a query and returns a list of documents. If the query is a list of queries, it returns a list of list of documents. If the batch_size_ranker, or batch_size_retriever it takes precedence over the batch_size. If the k_ranker, or k_retriever it takes precedence over the k parameter. 96 | 97 | **Parameters** 98 | 99 | - **q** (*Union[List[str], str]*) 100 | - **k** (*Optional[int]*) – defaults to `None` 101 | - **batch_size** (*Optional[int]*) – defaults to `None` 102 | - **documents** (*Optional[List[Dict[str, str]]]*) – defaults to `None` 103 | - **kwargs** 104 | 105 | ???- note "add" 106 | 107 | Add new documents. 108 | 109 | **Parameters** 110 | 111 | - **documents** (*list*) 112 | - **kwargs** 113 | 114 | ???- note "reset" 115 | 116 | -------------------------------------------------------------------------------- /docs/api/compose/Union.md: -------------------------------------------------------------------------------- 1 | # Union 2 | 3 | Union gathers retrieved documents from multiples retrievers and ranked documents from multiples rankers. The union operator concat results with respect of the orders of the models in the union. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **models** (*list*) 10 | 11 | List of models of the union. 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from pprint import pprint as print 19 | >>> from cherche import retrieve 20 | 21 | >>> documents = [ 22 | ... {"id": 0, "town": "Paris", "country": "France", "continent": "Europe"}, 23 | ... {"id": 1, "town": "Montreal", "country": "Canada", "continent": "North America"}, 24 | ... {"id": 2, "town": "Madrid", "country": "Spain", "continent": "Europe"}, 25 | ... ] 26 | 27 | >>> search = ( 28 | ... retrieve.TfIdf(key="id", on="town", documents=documents) | 29 | ... retrieve.TfIdf(key="id", on="country", documents=documents) | 30 | ... 
retrieve.Flash(key="id", on="continent") 31 | ... ) 32 | 33 | >>> search = search.add(documents) 34 | 35 | >>> print(search("Paris")) 36 | [{'id': 0, 'similarity': 1.0}] 37 | 38 | >>> print(search(["Paris", "Europe"])) 39 | [[{'id': 0, 'similarity': 1.0}], 40 | [{'id': 0, 'similarity': 1.0}, {'id': 2, 'similarity': 0.5}]] 41 | ``` 42 | 43 | ## Methods 44 | 45 | ???- note "__call__" 46 | 47 | 48 | 49 | **Parameters** 50 | 51 | - **q** (*Union[List[List[Dict[str, str]]], List[Dict[str, str]]]*) 52 | - **batch_size** (*Optional[int]*) – defaults to `None` 53 | - **k** (*Optional[int]*) – defaults to `None` 54 | - **documents** (*Optional[List[Dict[str, str]]]*) – defaults to `None` 55 | - **kwargs** 56 | 57 | ???- note "add" 58 | 59 | Add new documents. 60 | 61 | **Parameters** 62 | 63 | - **documents** (*list*) 64 | - **kwargs** 65 | 66 | ???- note "reset" 67 | 68 | -------------------------------------------------------------------------------- /docs/api/compose/Vote.md: -------------------------------------------------------------------------------- 1 | # Vote 2 | 3 | Voting operator. Computes the score for each document based on it's number of occurences and based on documents ranks: $nb_occurences * sum_{rank \in ranks} 1 / rank$. The higher the score, the higher the document is ranked in output of the vote. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **models** (*list*) 10 | 11 | List of models of the vote. 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from pprint import pprint as print 19 | >>> from cherche import retrieve, rank 20 | >>> from sentence_transformers import SentenceTransformer 21 | 22 | >>> documents = [ 23 | ... {"id": 0, "town": "Paris", "country": "France", "continent": "Europe"}, 24 | ... {"id": 1, "town": "Montreal", "country": "Canada", "continent": "North America"}, 25 | ... {"id": 2, "town": "Madrid", "country": "Spain", "continent": "Europe"}, 26 | ... ] 27 | 28 | >>> search = ( 29 | ... retrieve.TfIdf(key="id", on="town", documents=documents) * 30 | ... retrieve.TfIdf(key="id", on="country", documents=documents) * 31 | ... retrieve.Flash(key="id", on="continent") 32 | ... ) 33 | 34 | >>> search = search.add(documents) 35 | 36 | >>> retriever = retrieve.TfIdf(key="id", on=["town", "country", "continent"], documents=documents) 37 | 38 | >>> ranker = rank.Encoder( 39 | ... key="id", 40 | ... on=["town", "country", "continent"], 41 | ... encoder=SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 42 | ... ) * rank.Encoder( 43 | ... key="id", 44 | ... on=["town", "country", "continent"], 45 | ... encoder=SentenceTransformer( 46 | ... "sentence-transformers/multi-qa-mpnet-base-cos-v1" 47 | ... ).encode, 48 | ... ) 49 | 50 | >>> search = retriever + ranker 51 | 52 | >>> search = search.add(documents) 53 | 54 | >>> print(search("What is the capital of Canada ? Is it paris, montreal or madrid ?")) 55 | [{'id': 1, 'similarity': 2.5}, 56 | {'id': 0, 'similarity': 1.4}, 57 | {'id': 2, 'similarity': 1.0}] 58 | ``` 59 | 60 | ## Methods 61 | 62 | ???- note "__call__" 63 | 64 | Call self as a function. 65 | 66 | **Parameters** 67 | 68 | - **q** (*Union[List[List[Dict[str, str]]], List[Dict[str, str]]]*) 69 | - **batch_size** (*Optional[int]*) – defaults to `None` 70 | - **k** (*Optional[int]*) – defaults to `None` 71 | - **documents** (*Optional[List[Dict[str, str]]]*) – defaults to `None` 72 | - **kwargs** 73 | 74 | ???- note "add" 75 | 76 | Add new documents. 
77 | 78 | **Parameters** 79 | 80 | - **documents** (*list*) 81 | - **kwargs** 82 | 83 | ???- note "reset" 84 | 85 | -------------------------------------------------------------------------------- /docs/api/data/.pages: -------------------------------------------------------------------------------- 1 | title: data -------------------------------------------------------------------------------- /docs/api/data/arxiv-tags.md: -------------------------------------------------------------------------------- 1 | # arxiv_tags 2 | 3 | Semanlink tags arXiv documents. The objective of this dataset is to evaluate a neural search pipeline for automatic tagging of arXiv documents. This function returns the set of tags and the pairs arXiv documents and tags. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **arxiv_title** (*bool*) – defaults to `True` 10 | 11 | Include title of the arxiv paper inside the query. 12 | 13 | - **arxiv_summary** (*bool*) – defaults to `True` 14 | 15 | Include summary of the arxiv paper inside the query. 16 | 17 | - **comment** (*bool*) – defaults to `False` 18 | 19 | Include comment of the arxiv paper inside the query. 20 | 21 | - **broader_prefLabel_text** (*bool*) – defaults to `True` 22 | 23 | Include broader_prefLabel as a text field. 24 | 25 | - **broader_altLabel_text** (*bool*) – defaults to `True` 26 | 27 | Include broader_altLabel_text as a text field. 28 | 29 | - **prefLabel_text** (*bool*) – defaults to `True` 30 | 31 | Include prefLabel_text as a text field. 32 | 33 | - **altLabel_text** (*bool*) – defaults to `True` 34 | 35 | Include altLabel_text as a text field. 36 | 37 | 38 | 39 | ## Examples 40 | 41 | ```python 42 | >>> from pprint import pprint as print 43 | >>> from cherche import data 44 | 45 | >>> documents, query_answers = data.arxiv_tags() 46 | 47 | >>> print(list(documents[0].keys())) 48 | ['prefLabel', 49 | 'type', 50 | 'broader', 51 | 'creationTime', 52 | 'creationDate', 53 | 'comment', 54 | 'uri', 55 | 'broader_prefLabel', 56 | 'broader_related', 57 | 'broader_prefLabel_text', 58 | 'prefLabel_text'] 59 | ``` 60 | 61 | -------------------------------------------------------------------------------- /docs/api/data/load-towns.md: -------------------------------------------------------------------------------- 1 | # load_towns 2 | 3 | Sample of Wikipedia dataset that contains informations about Toulouse, Paris, Lyon and Bordeaux. 
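As a hedged sketch (separate from the documented example below), the towns corpus can be plugged straight into a keyword retriever, since each record carries an `id`, a `title` and an `article` field:

```python
from cherche import data, retrieve

# Load the sample Wikipedia corpus and index the title and article fields.
documents = data.load_towns()
retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents)

# Top 3 documents mentioning Bordeaux.
candidates = retriever(q="Bordeaux", k=3)
```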
4 | 5 | 6 | 7 | 8 | 9 | ## Examples 10 | 11 | ```python 12 | >>> from pprint import pprint as print 13 | >>> from cherche import data 14 | 15 | >>> towns = data.load_towns() 16 | 17 | >>> print(towns[:3]) 18 | [{'article': 'Paris (French pronunciation: ​[paʁi] (listen)) is the ' 19 | 'capital and most populous city of France, with an estimated ' 20 | 'population of 2,175,601 residents as of 2018, in an area of more ' 21 | 'than 105 square kilometres (41 square miles).', 22 | 'id': 0, 23 | 'title': 'Paris', 24 | 'url': 'https://en.wikipedia.org/wiki/Paris'}, 25 | {'article': "Since the 17th century, Paris has been one of Europe's major " 26 | 'centres of finance, diplomacy, commerce, fashion, gastronomy, ' 27 | 'science, and arts.', 28 | 'id': 1, 29 | 'title': 'Paris', 30 | 'url': 'https://en.wikipedia.org/wiki/Paris'}, 31 | {'article': 'The City of Paris is the centre and seat of government of the ' 32 | 'region and province of Île-de-France, or Paris Region, which has ' 33 | 'an estimated population of 12,174,880, or about 18 percent of ' 34 | 'the population of France as of 2017.', 35 | 'id': 2, 36 | 'title': 'Paris', 37 | 'url': 'https://en.wikipedia.org/wiki/Paris'}] 38 | ``` 39 | 40 | -------------------------------------------------------------------------------- /docs/api/evaluate/.pages: -------------------------------------------------------------------------------- 1 | title: evaluate -------------------------------------------------------------------------------- /docs/api/evaluate/evaluation.md: -------------------------------------------------------------------------------- 1 | # evaluation 2 | 3 | Evaluation function 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **search** 10 | 11 | Search function. 12 | 13 | - **query_answers** (*List[Tuple[str, List[Dict[str, str]]]]*) 14 | 15 | List of tuples (query, answers). 16 | 17 | - **hits_k** (*range*) – defaults to `range(1, 6)` 18 | 19 | List of k to compute precision, recall and F1. 20 | 21 | - **batch_size** (*Optional[int]*) – defaults to `None` 22 | 23 | Batch size. 24 | 25 | - **k** (*Optional[int]*) – defaults to `None` 26 | 27 | Number of documents to retrieve. 28 | 29 | 30 | 31 | ## Examples 32 | 33 | ```python 34 | >>> from pprint import pprint as print 35 | >>> from cherche import data, evaluate, retrieve 36 | >>> from lenlp import sparse 37 | 38 | >>> documents, query_answers = data.arxiv_tags( 39 | ... arxiv_title=True, arxiv_summary=False, comment=False 40 | ... ) 41 | 42 | >>> search = retrieve.TfIdf( 43 | ... key="uri", 44 | ... on=["prefLabel_text", "altLabel_text"], 45 | ... documents=documents, 46 | ... tfidf=sparse.TfidfVectorizer(normalize=True, ngram_range=(3, 7), analyzer="char"), 47 | ... 
) + documents 48 | 49 | >>> scores = evaluate.evaluation(search=search, query_answers=query_answers, k=10) 50 | 51 | >>> print(scores) 52 | {'F1@1': '26.52%', 53 | 'F1@2': '29.41%', 54 | 'F1@3': '28.65%', 55 | 'F1@4': '26.85%', 56 | 'F1@5': '25.19%', 57 | 'Precision@1': '63.06%', 58 | 'Precision@2': '43.47%', 59 | 'Precision@3': '33.12%', 60 | 'Precision@4': '26.67%', 61 | 'Precision@5': '22.55%', 62 | 'R-Precision': '26.95%', 63 | 'Recall@1': '16.79%', 64 | 'Recall@2': '22.22%', 65 | 'Recall@3': '25.25%', 66 | 'Recall@4': '27.03%', 67 | 'Recall@5': '28.54%'} 68 | ``` 69 | 70 | -------------------------------------------------------------------------------- /docs/api/index/.pages: -------------------------------------------------------------------------------- 1 | title: index -------------------------------------------------------------------------------- /docs/api/index/Faiss.md: -------------------------------------------------------------------------------- 1 | # Faiss 2 | 3 | Faiss index dedicated to vector search. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** 10 | 11 | Identifier field for each document. 12 | 13 | - **index** – defaults to `None` 14 | 15 | Faiss index to use. 16 | 17 | - **normalize** (*bool*) – defaults to `True` 18 | 19 | 20 | 21 | ## Examples 22 | 23 | ```python 24 | >>> from pprint import pprint as print 25 | >>> from cherche import index 26 | >>> from sentence_transformers import SentenceTransformer 27 | 28 | >>> documents = [ 29 | ... {"id": 0, "title": "Paris France"}, 30 | ... {"id": 1, "title": "Madrid Spain"}, 31 | ... {"id": 2, "title": "Montreal Canada"} 32 | ... ] 33 | 34 | >>> encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") 35 | 36 | >>> faiss_index = index.Faiss(key="id") 37 | >>> faiss_index = faiss_index.add( 38 | ... documents = documents, 39 | ... embeddings = encoder.encode([document["title"] for document in documents]), 40 | ... ) 41 | 42 | >>> print(faiss_index(embeddings=encoder.encode(["Spain", "Montreal"]))) 43 | [[{'id': 1, 'similarity': 0.6544566197822951}, 44 | {'id': 0, 'similarity': 0.5405466290777285}, 45 | {'id': 2, 'similarity': 0.48717489472604614}], 46 | [{'id': 2, 'similarity': 0.7372165680578416}, 47 | {'id': 0, 'similarity': 0.5185646665953703}, 48 | {'id': 1, 'similarity': 0.4834444940712032}]] 49 | 50 | >>> documents = [ 51 | ... {"id": 3, "title": "Paris France"}, 52 | ... {"id": 4, "title": "Madrid Spain"}, 53 | ... {"id": 5, "title": "Montreal Canada"} 54 | ... ] 55 | 56 | >>> faiss_index = faiss_index.add( 57 | ... documents = documents, 58 | ... embeddings = encoder.encode([document["title"] for document in documents]), 59 | ... ) 60 | 61 | >>> print(faiss_index(embeddings=encoder.encode(["Spain", "Montreal"]), k=4)) 62 | [[{'id': 1, 'similarity': 0.6544566197822951}, 63 | {'id': 4, 'similarity': 0.6544566197822951}, 64 | {'id': 0, 'similarity': 0.5405466290777285}, 65 | {'id': 3, 'similarity': 0.5405466290777285}], 66 | [{'id': 2, 'similarity': 0.7372165680578416}, 67 | {'id': 5, 'similarity': 0.7372165680578416}, 68 | {'id': 0, 'similarity': 0.5185646665953703}, 69 | {'id': 3, 'similarity': 0.5185646665953703}]] 70 | ``` 71 | 72 | ## Methods 73 | 74 | ???- note "__call__" 75 | 76 | Call self as a function. 77 | 78 | **Parameters** 79 | 80 | - **embeddings** (*numpy.ndarray*) 81 | - **k** (*Optional[int]*) – defaults to `None` 82 | 83 | ???- note "add" 84 | 85 | Add documents to the faiss index and export embeddings if the path is provided. Streaming friendly. 
86 | 87 | **Parameters** 88 | 89 | - **documents** (*list*) 90 | - **embeddings** (*numpy.ndarray*) 91 | 92 | ## References 93 | 94 | 1. [Faiss](https://github.com/facebookresearch/faiss) 95 | 96 | -------------------------------------------------------------------------------- /docs/api/overview.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | ## compose 4 | 5 | - [Intersection](../compose/Intersection) 6 | - [Pipeline](../compose/Pipeline) 7 | - [Union](../compose/Union) 8 | - [Vote](../compose/Vote) 9 | 10 | ## data 11 | 12 | - [arxiv_tags](../data/arxiv-tags) 13 | - [load_towns](../data/load-towns) 14 | 15 | ## evaluate 16 | 17 | - [evaluation](../evaluate/evaluation) 18 | 19 | ## index 20 | 21 | - [Faiss](../index/Faiss) 22 | 23 | ## qa 24 | 25 | - [QA](../qa/QA) 26 | 27 | ## query 28 | 29 | - [Norvig](../query/Norvig) 30 | - [PRF](../query/PRF) 31 | - [Query](../query/Query) 32 | 33 | ## rank 34 | 35 | - [CrossEncoder](../rank/CrossEncoder) 36 | - [DPR](../rank/DPR) 37 | - [Embedding](../rank/Embedding) 38 | - [Encoder](../rank/Encoder) 39 | - [Ranker](../rank/Ranker) 40 | 41 | ## retrieve 42 | 43 | - [DPR](../retrieve/DPR) 44 | - [Embedding](../retrieve/Embedding) 45 | - [Encoder](../retrieve/Encoder) 46 | - [Flash](../retrieve/Flash) 47 | - [Fuzz](../retrieve/Fuzz) 48 | - [Lunr](../retrieve/Lunr) 49 | - [Retriever](../retrieve/Retriever) 50 | - [TfIdf](../retrieve/TfIdf) 51 | 52 | ## utils 53 | 54 | 55 | **Classes** 56 | 57 | - [TopK](../utils/TopK) 58 | 59 | **Functions** 60 | 61 | - [quantize](../utils/quantize) 62 | - [yield_batch](../utils/yield-batch) 63 | - [yield_batch_single](../utils/yield-batch-single) 64 | 65 | -------------------------------------------------------------------------------- /docs/api/qa/.pages: -------------------------------------------------------------------------------- 1 | title: qa -------------------------------------------------------------------------------- /docs/api/qa/QA.md: -------------------------------------------------------------------------------- 1 | # QA 2 | 3 | Question Answering model. QA models needs input documents contents to run. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **on** (*Union[str, list]*) 10 | 11 | Fields to use to answer to the question. 12 | 13 | - **model** 14 | 15 | Hugging Face question answering model available [here](https://huggingface.co/models?pipeline_tag=question-answering). 16 | 17 | - **batch_size** (*int*) – defaults to `32` 18 | 19 | 20 | 21 | ## Examples 22 | 23 | ```python 24 | >>> from pprint import pprint as print 25 | >>> from cherche import retrieve, qa 26 | >>> from transformers import pipeline 27 | 28 | >>> documents = [ 29 | ... {"id": 0, "title": "Paris France"}, 30 | ... {"id": 1, "title": "Madrid Spain"}, 31 | ... {"id": 2, "title": "Montreal Canada"} 32 | ... ] 33 | 34 | >>> retriever = retrieve.TfIdf(key="id", on=["title"], documents=documents) 35 | 36 | >>> qa_model = qa.QA( 37 | ... model = pipeline("question-answering", model = "deepset/roberta-base-squad2", tokenizer = "deepset/roberta-base-squad2"), 38 | ... on = ["title"], 39 | ... 
) 40 | 41 | >>> pipeline = retriever + documents + qa_model 42 | 43 | >>> pipeline 44 | TfIdf retriever 45 | key : id 46 | on : title 47 | documents: 3 48 | Mapping to documents 49 | Question Answering 50 | on: title 51 | 52 | >>> print(pipeline(q="what is the capital of france?")) 53 | [{'answer': 'Paris', 54 | 'end': 5, 55 | 'id': 0, 56 | 'question': 'what is the capital of france?', 57 | 'score': 0.05615315958857536, 58 | 'similarity': 0.5962847939999439, 59 | 'start': 0, 60 | 'title': 'Paris France'}, 61 | {'answer': 'Montreal', 62 | 'end': 8, 63 | 'id': 2, 64 | 'question': 'what is the capital of france?', 65 | 'score': 0.01080897357314825, 66 | 'similarity': 0.0635641726163728, 67 | 'start': 0, 68 | 'title': 'Montreal Canada'}] 69 | 70 | >>> print(pipeline(["what is the capital of France?", "what is the capital of Canada?"])) 71 | [[{'answer': 'Paris', 72 | 'end': 5, 73 | 'id': 0, 74 | 'question': 'what is the capital of France?', 75 | 'score': 0.1554129421710968, 76 | 'similarity': 0.5962847939999439, 77 | 'start': 0, 78 | 'title': 'Paris France'}, 79 | {'answer': 'Montreal', 80 | 'end': 8, 81 | 'id': 2, 82 | 'question': 'what is the capital of France?', 83 | 'score': 1.2884755960840266e-05, 84 | 'similarity': 0.0635641726163728, 85 | 'start': 0, 86 | 'title': 'Montreal Canada'}], 87 | [{'answer': 'Montreal', 88 | 'end': 8, 89 | 'id': 2, 90 | 'question': 'what is the capital of Canada?', 91 | 'score': 0.05316793918609619, 92 | 'similarity': 0.5125692857821978, 93 | 'start': 0, 94 | 'title': 'Montreal Canada'}, 95 | {'answer': 'Paris France', 96 | 'end': 12, 97 | 'id': 0, 98 | 'question': 'what is the capital of Canada?', 99 | 'score': 4.7594025431862974e-07, 100 | 'similarity': 0.035355339059327376, 101 | 'start': 0, 102 | 'title': 'Paris France'}]] 103 | ``` 104 | 105 | ## Methods 106 | 107 | ???- note "__call__" 108 | 109 | Question answering main method. 110 | 111 | **Parameters** 112 | 113 | - **q** (*Union[str, List[str]]*) 114 | - **documents** (*Union[List[List[Dict[str, str]]], List[Dict[str, str]]]*) 115 | - **batch_size** (*Optional[int]*) – defaults to `None` 116 | - **kwargs** 117 | 118 | ???- note "get_question_context" 119 | 120 | -------------------------------------------------------------------------------- /docs/api/query/.pages: -------------------------------------------------------------------------------- 1 | title: query -------------------------------------------------------------------------------- /docs/api/query/Norvig.md: -------------------------------------------------------------------------------- 1 | # Norvig 2 | 3 | Spelling corrector written by Peter Norvig: [How to Write a Spelling Corrector](https://norvig.com/spell-correct.html) 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **on** (*Union[str, List]*) 10 | 11 | Fields to use for fitting the spelling corrector on. 
12 | 13 | - **lower** (*bool*) – defaults to `True` 14 | 15 | 16 | ## Attributes 17 | 18 | - **type** 19 | 20 | 21 | ## Examples 22 | 23 | ```python 24 | >>> from cherche import query, data 25 | 26 | >>> documents = data.load_towns() 27 | 28 | >>> corrector = query.Norvig(on = ["title", "article"], lower=True) 29 | 30 | >>> corrector.add(documents) 31 | Query Norvig 32 | Vocabulary: 967 33 | 34 | >>> corrector(q="tha citi af Parisa is in Fronce") 35 | 'the city of paris is in france' 36 | 37 | >>> corrector(q=["tha citi af Parisa is in Fronce", "parisa"]) 38 | ['the city of paris is in france', 'paris'] 39 | ``` 40 | 41 | ## Methods 42 | 43 | ???- note "__call__" 44 | 45 | Correct spelling errors in a given query. 46 | 47 | **Parameters** 48 | 49 | - **q** (*Union[List[str], str]*) 50 | - **kwargs** 51 | 52 | ???- note "add" 53 | 54 | Fit Nervig spelling corrector. 55 | 56 | **Parameters** 57 | 58 | - **documents** (*Union[List[Dict], str]*) 59 | 60 | ???- note "correct" 61 | 62 | Most probable spelling correction for word. 63 | 64 | **Parameters** 65 | 66 | - **word** (*str*) 67 | 68 | ???- note "reset" 69 | 70 | Wipe dictionary. 71 | 72 | 73 | ## References 74 | 75 | 1. [How to Write a Spelling Corrector](https://norvig.com/spell-correct.html) 76 | 77 | -------------------------------------------------------------------------------- /docs/api/query/PRF.md: -------------------------------------------------------------------------------- 1 | # PRF 2 | 3 | Pseudo (or blind) Relevance-Feedback module. The Query-Augmentation method applies a fast document retrieving method and then extracts keywords from relevant documents. Thus, we have to retrieve top words from relevant documents to give a proper augmentation of a given query. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **on** (*Union[str, list]*) 10 | 11 | Fields to use for fitting the spelling corrector on. 12 | 13 | - **documents** (*list*) 14 | 15 | - **tf** (*sklearn.feature_extraction.text.CountVectorizer*) – defaults to `sparse.TfidfVectorizer()` 16 | 17 | defaults to sklearn.feature_extraction.text.sparse.TfidfVectorizer. If you want to implement your own tf, it needs to follow the sklearn base API and provides the `transform` `fit_transform` and `get_feature_names_out` methods. See sklearn documentation for more information. 18 | 19 | - **nb_docs** (*int*) – defaults to `5` 20 | 21 | Number of documents from which to retrieve top-terms. 22 | 23 | - **nb_terms_per_doc** (*int*) – defaults to `3` 24 | 25 | Number of terms to extract from each top documents retrieved. 26 | 27 | 28 | ## Attributes 29 | 30 | - **type** 31 | 32 | 33 | ## Examples 34 | 35 | ```python 36 | >>> from cherche import query, data 37 | 38 | >>> documents = data.load_towns() 39 | 40 | >>> prf = query.PRF( 41 | ... on=["title", "article"], 42 | ... nb_docs=8, nb_terms_per_doc=1, 43 | ... documents=documents 44 | ... ) 45 | 46 | >>> prf 47 | Query PRF 48 | on : title, article 49 | documents: 8 50 | terms : 1 51 | 52 | >>> prf(q="Europe") 53 | 'Europe art metro space science bordeaux paris university significance' 54 | 55 | >>> prf(q=["Europe", "Paris"]) 56 | ['Europe art metro space science bordeaux paris university significance', 'Paris received paris club subway billion source tour tournament'] 57 | ``` 58 | 59 | ## Methods 60 | 61 | ???- note "__call__" 62 | 63 | Augment a given query with new terms. 64 | 65 | **Parameters** 66 | 67 | - **q** (*Union[List[str], str]*) 68 | - **kwargs** 69 | 70 | ## References 71 | 72 | 1. 
[Relevance feedback and pseudo relevance feedback](https://nlp.stanford.edu/IR-book/html/htmledition/relevance-feedback-and-pseudo-relevance-feedback-1.html) 73 | 2. [Blind Feedback](https://en.wikipedia.org/wiki/Relevance_feedback#Blind_feedback) 74 | 75 | -------------------------------------------------------------------------------- /docs/api/query/Query.md: -------------------------------------------------------------------------------- 1 | # Query 2 | 3 | Abstract class for models working on a query. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **on** (*Union[str, list]*) 10 | 11 | 12 | ## Attributes 13 | 14 | - **type** 15 | 16 | 17 | 18 | ## Methods 19 | 20 | ???- note "__call__" 21 | 22 | Call self as a function. 23 | 24 | **Parameters** 25 | 26 | - **q** (*Union[List[str], str]*) 27 | - **kwargs** 28 | 29 | -------------------------------------------------------------------------------- /docs/api/rank/.pages: -------------------------------------------------------------------------------- 1 | title: rank -------------------------------------------------------------------------------- /docs/api/rank/CrossEncoder.md: -------------------------------------------------------------------------------- 1 | # CrossEncoder 2 | 3 | Cross-Encoder as a ranker. CrossEncoder takes both the query and the document as input and outputs a score. The score is a similarity score between the query and the document. The CrossEncoder cannot pre-compute the embeddings of the documents since it need both the query and the document. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **on** (*Union[List[str], str]*) 10 | 11 | Fields to use to match the query to the documents. 12 | 13 | - **encoder** 14 | 15 | Sentence Transformer cross-encoder. 16 | 17 | - **k** (*Optional[int]*) – defaults to `None` 18 | 19 | - **batch_size** (*int*) – defaults to `64` 20 | 21 | 22 | 23 | ## Examples 24 | 25 | ```python 26 | >>> from pprint import pprint as print 27 | >>> from cherche import retrieve, rank, evaluate, data 28 | >>> from sentence_transformers import CrossEncoder 29 | 30 | >>> documents, query_answers = data.arxiv_tags( 31 | ... arxiv_title=True, arxiv_summary=False, comment=False 32 | ... ) 33 | 34 | >>> retriever = retrieve.TfIdf( 35 | ... key="uri", 36 | ... on=["prefLabel_text", "altLabel_text"], 37 | ... documents=documents, 38 | ... k=100, 39 | ... ) 40 | 41 | >>> ranker = rank.CrossEncoder( 42 | ... on = ["prefLabel_text", "altLabel_text"], 43 | ... encoder = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1").predict, 44 | ... ) 45 | 46 | >>> pipeline = retriever + documents + ranker 47 | 48 | >>> match = pipeline("graph neural network", k=5) 49 | 50 | >>> for m in match: 51 | ... print(m.get("uri", "")) 52 | 'http://www.semanlink.net/tag/graph_neural_networks' 53 | 'http://www.semanlink.net/tag/artificial_neural_network' 54 | 'http://www.semanlink.net/tag/dans_deep_averaging_neural_networks' 55 | 'http://www.semanlink.net/tag/recurrent_neural_network' 56 | 'http://www.semanlink.net/tag/convolutional_neural_network' 57 | ``` 58 | 59 | ## Methods 60 | 61 | ???- note "__call__" 62 | 63 | Rank inputs documents based on query. 64 | 65 | **Parameters** 66 | 67 | - **q** (*str*) 68 | - **documents** (*list*) 69 | - **batch_size** (*Optional[int]*) – defaults to `None` 70 | - **k** (*Optional[int]*) – defaults to `None` 71 | - **kwargs** 72 | 73 | ## References 74 | 75 | 1. [Sentence Transformers Cross-Encoders](https://www.sbert.net/examples/applications/cross-encoder/README.html) 76 | 2. 
[Cross-Encoders Hub](https://huggingface.co/cross-encoder) 77 | 78 | -------------------------------------------------------------------------------- /docs/api/rank/DPR.md: -------------------------------------------------------------------------------- 1 | # DPR 2 | 3 | Dual Sentence Transformer as a ranker. This ranker is compatible with any SentenceTransformer. DPR is a dual encoder model, it uses two SentenceTransformer, one for encoding documents and one for encoding queries. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **on** (*Union[str, List[str]]*) 10 | 11 | Fields on wich encoder will perform similarity matching. 12 | 13 | - **key** (*str*) 14 | 15 | Field identifier of each document. 16 | 17 | - **encoder** 18 | 19 | Encoding function dedicated documents. 20 | 21 | - **query_encoder** 22 | 23 | Encoding function dedicated to queries. 24 | 25 | - **normalize** (*bool*) – defaults to `True` 26 | 27 | If set to True, the similarity measure is cosine similarity, if set to False, similarity measure is dot product. 28 | 29 | - **k** (*Optional[int]*) – defaults to `None` 30 | 31 | - **batch_size** (*int*) – defaults to `64` 32 | 33 | 34 | 35 | ## Examples 36 | 37 | ```python 38 | >>> from pprint import pprint as print 39 | >>> from cherche import rank 40 | >>> from sentence_transformers import SentenceTransformer 41 | 42 | >>> ranker = rank.DPR( 43 | ... key = "id", 44 | ... on = ["title", "article"], 45 | ... encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').encode, 46 | ... query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').encode, 47 | ... normalize = True, 48 | ... ) 49 | 50 | >>> documents = [ 51 | ... {"id": 0, "title": "Paris France"}, 52 | ... {"id": 1, "title": "Madrid Spain"}, 53 | ... {"id": 2, "title": "Montreal Canada"} 54 | ... ] 55 | 56 | >>> ranker.add(documents=documents) 57 | DPR ranker 58 | key : id 59 | on : title, article 60 | normalize : True 61 | embeddings: 3 62 | 63 | >>> match = ranker( 64 | ... q="Paris", 65 | ... documents=documents 66 | ... ) 67 | 68 | >>> print(match) 69 | [{'id': 0, 'similarity': 7.806636, 'title': 'Paris France'}, 70 | {'id': 1, 'similarity': 6.239272, 'title': 'Madrid Spain'}, 71 | {'id': 2, 'similarity': 6.168748, 'title': 'Montreal Canada'}] 72 | 73 | >>> match = ranker( 74 | ... q=["Paris", "Madrid"], 75 | ... documents=[documents + [{"id": 3, "title": "Paris"}]] * 2, 76 | ... k=2, 77 | ... ) 78 | 79 | >>> print(match) 80 | [[{'id': 3, 'similarity': 7.906666, 'title': 'Paris'}, 81 | {'id': 0, 'similarity': 7.806636, 'title': 'Paris France'}], 82 | [{'id': 1, 'similarity': 8.07025, 'title': 'Madrid Spain'}, 83 | {'id': 0, 'similarity': 6.1131663, 'title': 'Paris France'}]] 84 | ``` 85 | 86 | ## Methods 87 | 88 | ???- note "__call__" 89 | 90 | Encode input query and ranks documents based on the similarity between the query and the selected field of the documents. 91 | 92 | **Parameters** 93 | 94 | - **q** (*Union[List[str], str]*) 95 | - **documents** (*Union[List[List[Dict[str, str]]], List[Dict[str, str]]]*) 96 | - **k** (*int*) – defaults to `None` 97 | - **batch_size** (*Optional[int]*) – defaults to `None` 98 | - **kwargs** 99 | 100 | ???- note "add" 101 | 102 | Pre-compute embeddings and store them at the selected path. 103 | 104 | **Parameters** 105 | 106 | - **documents** (*List[Dict[str, str]]*) 107 | - **batch_size** (*int*) – defaults to `64` 108 | 109 | ???- note "encode_rank" 110 | 111 | Encode documents and rank them according to the query. 
112 | 113 | **Parameters** 114 | 115 | - **embeddings_queries** (*numpy.ndarray*) 116 | - **documents** (*List[List[Dict[str, str]]]*) 117 | - **k** (*int*) 118 | - **batch_size** (*Optional[int]*) – defaults to `None` 119 | 120 | ???- note "rank" 121 | 122 | Rank inputs documents ordered by relevance among the top k. 123 | 124 | **Parameters** 125 | 126 | - **embeddings_documents** (*Dict[str, numpy.ndarray]*) 127 | - **embeddings_queries** (*numpy.ndarray*) 128 | - **documents** (*List[List[Dict[str, str]]]*) 129 | - **k** (*int*) 130 | - **batch_size** (*Optional[int]*) – defaults to `None` 131 | 132 | -------------------------------------------------------------------------------- /docs/api/rank/Embedding.md: -------------------------------------------------------------------------------- 1 | # Embedding 2 | 3 | Collaborative filtering as a ranker. Recommend is compatible with the library [Implicit](https://github.com/benfred/implicit). 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*'str'*) 10 | 11 | Field identifier of each document. 12 | 13 | - **normalize** (*'bool'*) – defaults to `True` 14 | 15 | If set to True, the similarity measure is cosine similarity, if set to False, similarity measure is dot product. 16 | 17 | - **k** (*'typing.Optional[int]'*) – defaults to `None` 18 | 19 | - **batch_size** (*'int'*) – defaults to `1024` 20 | 21 | 22 | 23 | ## Examples 24 | 25 | ```python 26 | >>> from pprint import pprint as print 27 | >>> from cherche import rank 28 | >>> from sentence_transformers import SentenceTransformer 29 | 30 | >>> documents = [ 31 | ... {"id": "a", "title": "Paris"}, 32 | ... {"id": "b", "title": "Madrid"}, 33 | ... {"id": "c", "title": "Montreal"}, 34 | ... ] 35 | 36 | >>> encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") 37 | >>> embeddings_documents = encoder.encode([ 38 | ... document["title"] for document in documents 39 | ... ]) 40 | 41 | >>> recommend = rank.Embedding( 42 | ... key="id", 43 | ... ) 44 | 45 | >>> recommend.add( 46 | ... documents=documents, 47 | ... embeddings_documents=embeddings_documents, 48 | ... ) 49 | Embedding ranker 50 | key : id 51 | documents: 3 52 | normalize: True 53 | 54 | >>> match = recommend( 55 | ... q=encoder.encode("Paris"), 56 | ... documents=documents, 57 | ... k=2 58 | ... ) 59 | 60 | >>> print(match) 61 | [{'id': 'a', 'similarity': 1.0, 'title': 'Paris'}, 62 | {'id': 'c', 'similarity': 0.57165134, 'title': 'Montreal'}] 63 | 64 | >>> queries = [ 65 | ... "Paris", 66 | ... "Madrid", 67 | ... "Montreal" 68 | ... ] 69 | 70 | >>> match = recommend( 71 | ... q=encoder.encode(queries), 72 | ... documents=[documents] * 3, 73 | ... k=2 74 | ... ) 75 | 76 | >>> print(match) 77 | [[{'id': 'a', 'similarity': 1.0, 'title': 'Paris'}, 78 | {'id': 'c', 'similarity': 0.57165134, 'title': 'Montreal'}], 79 | [{'id': 'b', 'similarity': 1.0, 'title': 'Madrid'}, 80 | {'id': 'a', 'similarity': 0.49815434, 'title': 'Paris'}], 81 | [{'id': 'c', 'similarity': 0.9999999, 'title': 'Montreal'}, 82 | {'id': 'a', 'similarity': 0.5716514, 'title': 'Paris'}]] 83 | ``` 84 | 85 | ## Methods 86 | 87 | ???- note "__call__" 88 | 89 | Retrieve documents from user id. 
90 | 91 | **Parameters** 92 | 93 | - **q** (*'np.ndarray'*) 94 | - **documents** (*'typing.Union[typing.List[typing.List[typing.Dict[str, str]]], typing.List[typing.Dict[str, str]]]'*) 95 | - **k** (*'typing.Optional[int]'*) – defaults to `None` 96 | - **batch_size** (*'typing.Optional[int]'*) – defaults to `None` 97 | - **kwargs** 98 | 99 | ???- note "add" 100 | 101 | Add embeddings both documents and users. 102 | 103 | **Parameters** 104 | 105 | - **documents** (*'list'*) 106 | - **embeddings_documents** (*'typing.List[np.ndarray]'*) 107 | - **kwargs** 108 | 109 | ???- note "encode_rank" 110 | 111 | Encode documents and rank them according to the query. 112 | 113 | **Parameters** 114 | 115 | - **embeddings_queries** (*numpy.ndarray*) 116 | - **documents** (*List[List[Dict[str, str]]]*) 117 | - **k** (*int*) 118 | - **batch_size** (*Optional[int]*) – defaults to `None` 119 | 120 | ???- note "rank" 121 | 122 | Rank inputs documents ordered by relevance among the top k. 123 | 124 | **Parameters** 125 | 126 | - **embeddings_documents** (*Dict[str, numpy.ndarray]*) 127 | - **embeddings_queries** (*numpy.ndarray*) 128 | - **documents** (*List[List[Dict[str, str]]]*) 129 | - **k** (*int*) 130 | - **batch_size** (*Optional[int]*) – defaults to `None` 131 | 132 | -------------------------------------------------------------------------------- /docs/api/rank/Encoder.md: -------------------------------------------------------------------------------- 1 | # Encoder 2 | 3 | Sentence Transformer as a ranker. This ranker is compatible with any SentenceTransformer. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **on** (*Union[str, List[str]]*) 10 | 11 | Fields on wich encoder will perform similarity matching. 12 | 13 | - **key** (*str*) 14 | 15 | Field identifier of each document. 16 | 17 | - **encoder** 18 | 19 | Encoding function dedicated to both documents and queries. 20 | 21 | - **normalize** (*bool*) – defaults to `True` 22 | 23 | If set to True, the similarity measure is cosine similarity, if set to False, similarity measure is dot product. 24 | 25 | - **k** (*Optional[int]*) – defaults to `None` 26 | 27 | - **batch_size** (*int*) – defaults to `64` 28 | 29 | 30 | 31 | ## Examples 32 | 33 | ```python 34 | >>> from pprint import pprint as print 35 | >>> from cherche import rank 36 | >>> from sentence_transformers import SentenceTransformer 37 | 38 | >>> documents = [ 39 | ... {"id": 0, "title": "Paris France"}, 40 | ... {"id": 1, "title": "Madrid Spain"}, 41 | ... {"id": 2, "title": "Montreal Canada"} 42 | ... ] 43 | 44 | >>> ranker = rank.Encoder( 45 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 46 | ... key = "id", 47 | ... on = ["title"], 48 | ... ) 49 | 50 | >>> ranker.add(documents=documents) 51 | Encoder ranker 52 | key : id 53 | on : title 54 | normalize : True 55 | embeddings: 3 56 | 57 | >>> match = ranker( 58 | ... q="Paris", 59 | ... documents=documents 60 | ... ) 61 | 62 | >>> print(match) 63 | [{'id': 0, 'similarity': 0.7127624, 'title': 'Paris France'}, 64 | {'id': 1, 'similarity': 0.5497405, 'title': 'Madrid Spain'}, 65 | {'id': 2, 'similarity': 0.50252455, 'title': 'Montreal Canada'}] 66 | 67 | >>> match = ranker( 68 | ... q=["Paris France", "Madrid Spain"], 69 | ... documents=[documents + [{"id": 3, "title": "Paris"}]] * 2, 70 | ... k=2, 71 | ... 
) 72 | 73 | >>> print(match) 74 | [[{'id': 0, 'similarity': 0.99999994, 'title': 'Paris France'}, 75 | {'id': 1, 'similarity': 0.856435, 'title': 'Madrid Spain'}], 76 | [{'id': 1, 'similarity': 1.0, 'title': 'Madrid Spain'}, 77 | {'id': 0, 'similarity': 0.856435, 'title': 'Paris France'}]] 78 | ``` 79 | 80 | ## Methods 81 | 82 | ???- note "__call__" 83 | 84 | Encode input query and ranks documents based on the similarity between the query and the selected field of the documents. 85 | 86 | **Parameters** 87 | 88 | - **q** (*Union[List[str], str]*) 89 | - **documents** (*Union[List[List[Dict[str, str]]], List[Dict[str, str]]]*) 90 | - **k** (*Optional[int]*) – defaults to `None` 91 | - **batch_size** (*Optional[int]*) – defaults to `None` 92 | - **kwargs** 93 | 94 | ???- note "add" 95 | 96 | Pre-compute embeddings and store them at the selected path. 97 | 98 | **Parameters** 99 | 100 | - **documents** (*List[Dict[str, str]]*) 101 | - **batch_size** (*int*) – defaults to `64` 102 | 103 | ???- note "encode_rank" 104 | 105 | Encode documents and rank them according to the query. 106 | 107 | **Parameters** 108 | 109 | - **embeddings_queries** (*numpy.ndarray*) 110 | - **documents** (*List[List[Dict[str, str]]]*) 111 | - **k** (*int*) 112 | - **batch_size** (*Optional[int]*) – defaults to `None` 113 | 114 | ???- note "rank" 115 | 116 | Rank inputs documents ordered by relevance among the top k. 117 | 118 | **Parameters** 119 | 120 | - **embeddings_documents** (*Dict[str, numpy.ndarray]*) 121 | - **embeddings_queries** (*numpy.ndarray*) 122 | - **documents** (*List[List[Dict[str, str]]]*) 123 | - **k** (*int*) 124 | - **batch_size** (*Optional[int]*) – defaults to `None` 125 | 126 | -------------------------------------------------------------------------------- /docs/api/rank/Ranker.md: -------------------------------------------------------------------------------- 1 | # Ranker 2 | 3 | Abstract class for ranking models. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **on** (*Union[str, List[str]]*) 14 | 15 | Fields of the documents to use for ranking. 16 | 17 | - **encoder** 18 | 19 | Encoding function to computes embeddings of the documents. 20 | 21 | - **normalize** (*bool*) 22 | 23 | Normalize the embeddings in order to measure cosine similarity if set to True, dot product if set to False. 24 | 25 | - **batch_size** (*int*) 26 | 27 | - **k** (*Optional[int]*) – defaults to `None` 28 | 29 | 30 | 31 | 32 | ## Methods 33 | 34 | ???- note "__call__" 35 | 36 | Rank documents according to the query. 37 | 38 | **Parameters** 39 | 40 | - **q** (*Union[List[str], str]*) 41 | - **documents** (*Union[List[List[Dict[str, str]]], List[Dict[str, str]]]*) 42 | - **k** (*int*) 43 | - **batch_size** (*Optional[int]*) – defaults to `None` 44 | - **kwargs** 45 | 46 | ???- note "add" 47 | 48 | Pre-compute embeddings and store them at the selected path. 49 | 50 | **Parameters** 51 | 52 | - **documents** (*List[Dict[str, str]]*) 53 | - **batch_size** (*int*) – defaults to `64` 54 | 55 | ???- note "encode_rank" 56 | 57 | Encode documents and rank them according to the query. 58 | 59 | **Parameters** 60 | 61 | - **embeddings_queries** (*numpy.ndarray*) 62 | - **documents** (*List[List[Dict[str, str]]]*) 63 | - **k** (*int*) 64 | - **batch_size** (*Optional[int]*) – defaults to `None` 65 | 66 | ???- note "rank" 67 | 68 | Rank inputs documents ordered by relevance among the top k. 
69 | 70 | **Parameters** 71 | 72 | - **embeddings_documents** (*Dict[str, numpy.ndarray]*) 73 | - **embeddings_queries** (*numpy.ndarray*) 74 | - **documents** (*List[List[Dict[str, str]]]*) 75 | - **k** (*int*) 76 | - **batch_size** (*Optional[int]*) – defaults to `None` 77 | 78 | -------------------------------------------------------------------------------- /docs/api/retrieve/.pages: -------------------------------------------------------------------------------- 1 | title: retrieve -------------------------------------------------------------------------------- /docs/api/retrieve/DPR.md: -------------------------------------------------------------------------------- 1 | # DPR 2 | 3 | DPR as a retriever using Faiss Index. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **on** (*Union[str, list]*) 14 | 15 | Field to use to retrieve documents. 16 | 17 | - **encoder** 18 | 19 | - **query_encoder** 20 | 21 | - **normalize** (*bool*) – defaults to `True` 22 | 23 | Whether to normalize the embeddings before adding them to the index in order to measure cosine similarity. 24 | 25 | - **k** (*Optional[int]*) – defaults to `None` 26 | 27 | - **batch_size** (*int*) – defaults to `64` 28 | 29 | - **index** – defaults to `None` 30 | 31 | Faiss index that will store the embeddings and perform the similarity search. 32 | 33 | 34 | 35 | ## Examples 36 | 37 | ```python 38 | >>> from pprint import pprint as print 39 | >>> from cherche import retrieve 40 | >>> from sentence_transformers import SentenceTransformer 41 | 42 | >>> documents = [ 43 | ... {"id": 0, "title": "Paris France"}, 44 | ... {"id": 1, "title": "Madrid Spain"}, 45 | ... {"id": 2, "title": "Montreal Canada"} 46 | ... ] 47 | 48 | >>> retriever = retrieve.DPR( 49 | ... key = "id", 50 | ... on = ["title"], 51 | ... encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').encode, 52 | ... query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').encode, 53 | ... normalize = True, 54 | ... ) 55 | 56 | >>> retriever.add(documents) 57 | DPR retriever 58 | key : id 59 | on : title 60 | documents: 3 61 | 62 | >>> print(retriever("Spain", k=2)) 63 | [{'id': 1, 'similarity': 0.5534179127892946}, 64 | {'id': 0, 'similarity': 0.48604427456660426}] 65 | 66 | >>> print(retriever(["Spain", "Montreal"], k=2)) 67 | [[{'id': 1, 'similarity': 0.5534179492996913}, 68 | {'id': 0, 'similarity': 0.4860442182428353}], 69 | [{'id': 2, 'similarity': 0.5451990410703741}, 70 | {'id': 0, 'similarity': 0.47405722260691213}]] 71 | ``` 72 | 73 | ## Methods 74 | 75 | ???- note "__call__" 76 | 77 | Retrieve documents from the index. 78 | 79 | **Parameters** 80 | 81 | - **q** (*Union[List[str], str]*) 82 | - **k** (*Optional[int]*) – defaults to `None` 83 | - **batch_size** (*Optional[int]*) – defaults to `None` 84 | - **tqdm_bar** (*bool*) – defaults to `True` 85 | - **kwargs** 86 | 87 | ???- note "add" 88 | 89 | Add documents to the index. 90 | 91 | **Parameters** 92 | 93 | - **documents** (*List[Dict[str, str]]*) 94 | - **batch_size** (*int*) – defaults to `64` 95 | - **tqdm_bar** (*bool*) – defaults to `True` 96 | - **kwargs** 97 | 98 | -------------------------------------------------------------------------------- /docs/api/retrieve/Embedding.md: -------------------------------------------------------------------------------- 1 | # Embedding 2 | 3 | The Embedding retriever is dedicated to perform IR on embeddings calculated by the user rather than Cherche. 
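Any source of vectors works here: the sketch below feeds random NumPy arrays in place of real embeddings, simply to stress that Cherche never calls an encoder for this retriever. The data is a hypothetical toy example, and it assumes Faiss is available (for instance through the `cherche[cpu]` extra).

```python
import numpy as np
from cherche import retrieve

documents = [{"id": "a"}, {"id": "b"}, {"id": "c"}]

# Embeddings computed outside of Cherche, one row per document.
embeddings_documents = np.random.randn(3, 4).astype("float32")

retriever = retrieve.Embedding(key="id")
retriever.add(documents=documents, embeddings_documents=embeddings_documents)

# Queries are embeddings as well, not raw strings.
embeddings_queries = np.random.randn(2, 4).astype("float32")
candidates = retriever(embeddings_queries, k=2)
```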
4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **index** – defaults to `None` 14 | 15 | Faiss index that will store the embeddings and perform the similarity search. 16 | 17 | - **normalize** (*bool*) – defaults to `True` 18 | 19 | Whether to normalize the embeddings before adding them to the index in order to measure cosine similarity. 20 | 21 | - **k** (*Optional[int]*) – defaults to `None` 22 | 23 | - **batch_size** (*int*) – defaults to `1024` 24 | 25 | 26 | 27 | ## Examples 28 | 29 | ```python 30 | >>> from pprint import pprint as print 31 | >>> from cherche import retrieve 32 | >>> from sentence_transformers import SentenceTransformer 33 | 34 | >>> recommend = retrieve.Embedding( 35 | ... key="id", 36 | ... ) 37 | 38 | >>> documents = [ 39 | ... {"id": "a", "title": "Paris", "author": "Paris"}, 40 | ... {"id": "b", "title": "Madrid", "author": "Madrid"}, 41 | ... {"id": "c", "title": "Montreal", "author": "Montreal"}, 42 | ... ] 43 | 44 | >>> encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") 45 | >>> embeddings_documents = encoder.encode([ 46 | ... document["title"] for document in documents 47 | ... ]) 48 | 49 | >>> recommend.add( 50 | ... documents=documents, 51 | ... embeddings_documents=embeddings_documents, 52 | ... ) 53 | Embedding retriever 54 | key : id 55 | documents: 3 56 | 57 | >>> queries = [ 58 | ... "Paris", 59 | ... "Madrid", 60 | ... "Montreal" 61 | ... ] 62 | 63 | >>> embeddings_queries = encoder.encode(queries) 64 | >>> print(recommend(embeddings_queries, k=2)) 65 | [[{'id': 'a', 'similarity': 1.0}, 66 | {'id': 'c', 'similarity': 0.5385907831761005}], 67 | [{'id': 'b', 'similarity': 1.0}, 68 | {'id': 'a', 'similarity': 0.4990788711758875}], 69 | [{'id': 'c', 'similarity': 1.0}, 70 | {'id': 'a', 'similarity': 0.5385907831761005}]] 71 | 72 | >>> embeddings_queries = encoder.encode("Paris") 73 | >>> print(recommend(embeddings_queries, k=2)) 74 | [{'id': 'a', 'similarity': 0.9999999999989104}, 75 | {'id': 'c', 'similarity': 0.5385907485958683}] 76 | ``` 77 | 78 | ## Methods 79 | 80 | ???- note "__call__" 81 | 82 | Retrieve documents from the index. 83 | 84 | **Parameters** 85 | 86 | - **q** (*numpy.ndarray*) 87 | - **k** (*Optional[int]*) – defaults to `None` 88 | - **batch_size** (*Optional[int]*) – defaults to `None` 89 | - **tqdm_bar** (*bool*) – defaults to `True` 90 | - **kwargs** 91 | 92 | ???- note "add" 93 | 94 | Add embeddings both documents and users. 95 | 96 | **Parameters** 97 | 98 | - **documents** (*list*) 99 | - **embeddings_documents** (*numpy.ndarray*) 100 | - **kwargs** 101 | 102 | -------------------------------------------------------------------------------- /docs/api/retrieve/Encoder.md: -------------------------------------------------------------------------------- 1 | # Encoder 2 | 3 | Encoder as a retriever using Faiss Index. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **encoder** 10 | 11 | - **key** (*str*) 12 | 13 | Field identifier of each document. 14 | 15 | - **on** (*Union[str, list]*) 16 | 17 | Field to use to retrieve documents. 18 | 19 | - **normalize** (*bool*) – defaults to `True` 20 | 21 | Whether to normalize the embeddings before adding them to the index in order to measure cosine similarity. 22 | 23 | - **k** (*Optional[int]*) – defaults to `None` 24 | 25 | - **batch_size** (*int*) – defaults to `64` 26 | 27 | - **index** – defaults to `None` 28 | 29 | Faiss index that will store the embeddings and perform the similarity search. 
30 | 31 | 32 | 33 | ## Examples 34 | 35 | ```python 36 | >>> from pprint import pprint as print 37 | >>> from cherche import retrieve 38 | >>> from sentence_transformers import SentenceTransformer 39 | 40 | >>> documents = [ 41 | ... {"id": 0, "title": "Paris France"}, 42 | ... {"id": 1, "title": "Madrid Spain"}, 43 | ... {"id": 2, "title": "Montreal Canada"} 44 | ... ] 45 | 46 | >>> retriever = retrieve.Encoder( 47 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 48 | ... key = "id", 49 | ... on = ["title"], 50 | ... ) 51 | 52 | >>> retriever.add(documents, batch_size=1) 53 | Encoder retriever 54 | key : id 55 | on : title 56 | documents: 3 57 | 58 | >>> print(retriever("Spain", k=2)) 59 | [{'id': 1, 'similarity': 0.6544566453117681}, 60 | {'id': 0, 'similarity': 0.5405465419981407}] 61 | 62 | >>> print(retriever(["Spain", "Montreal"], k=2)) 63 | [[{'id': 1, 'similarity': 0.6544566453117681}, 64 | {'id': 0, 'similarity': 0.54054659424589}], 65 | [{'id': 2, 'similarity': 0.7372165680578416}, 66 | {'id': 0, 'similarity': 0.5185645704259234}]] 67 | ``` 68 | 69 | ## Methods 70 | 71 | ???- note "__call__" 72 | 73 | Retrieve documents from the index. 74 | 75 | **Parameters** 76 | 77 | - **q** (*Union[List[str], str]*) 78 | - **k** (*Optional[int]*) – defaults to `None` 79 | - **batch_size** (*Optional[int]*) – defaults to `None` 80 | - **tqdm_bar** (*bool*) – defaults to `True` 81 | - **kwargs** 82 | 83 | ???- note "add" 84 | 85 | Add documents to the index. 86 | 87 | **Parameters** 88 | 89 | - **documents** (*List[Dict[str, str]]*) 90 | - **batch_size** (*int*) – defaults to `64` 91 | - **tqdm_bar** (*bool*) – defaults to `True` 92 | - **kwargs** 93 | 94 | -------------------------------------------------------------------------------- /docs/api/retrieve/Flash.md: -------------------------------------------------------------------------------- 1 | # Flash 2 | 3 | FlashText Retriever. Flash aims to find documents that contain keywords such as a list of tags for example. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **on** (*Union[str, list]*) 14 | 15 | Fields to use to match the query to the documents. 16 | 17 | - **keywords** (*flashtext.keyword.KeywordProcessor*) – defaults to `None` 18 | 19 | Keywords extractor from [FlashText](https://github.com/vi3k6i5/flashtext). If set to None, a default one is created. 20 | 21 | - **lowercase** (*bool*) – defaults to `True` 22 | 23 | - **k** (*Optional[int]*) – defaults to `None` 24 | 25 | 26 | 27 | ## Examples 28 | 29 | ```python 30 | >>> from pprint import pprint as print 31 | >>> from cherche import retrieve 32 | 33 | >>> documents = [ 34 | ... {"id": 0, "title": "paris", "article": "eiffel tower"}, 35 | ... {"id": 1, "title": "paris", "article": "paris"}, 36 | ... {"id": 2, "title": "montreal", "article": "montreal is in canada"}, 37 | ... 
] 38 | 39 | >>> retriever = retrieve.Flash(key="id", on=["title", "article"]) 40 | 41 | >>> retriever.add(documents=documents) 42 | Flash retriever 43 | key : id 44 | on : title, article 45 | documents: 4 46 | 47 | >>> print(retriever(q="paris", k=2)) 48 | [{'id': 1, 'similarity': 0.6666666666666666}, 49 | {'id': 0, 'similarity': 0.3333333333333333}] 50 | 51 | ``` 52 | 53 | 54 | 55 | ```python 56 | >>> print(retriever(q=["paris", "montreal"])) 57 | [[{'id': 1, 'similarity': 0.6666666666666666}, 58 | {'id': 0, 'similarity': 0.3333333333333333}], 59 | [{'id': 2, 'similarity': 1.0}]] 60 | ``` 61 | 62 | ## Methods 63 | 64 | ???- note "__call__" 65 | 66 | Retrieve documents from the index. 67 | 68 | **Parameters** 69 | 70 | - **q** (*Union[List[str], str]*) 71 | - **k** (*Optional[int]*) – defaults to `None` 72 | - **tqdm_bar** (*bool*) – defaults to `True` 73 | - **kwargs** 74 | 75 | ???- note "add" 76 | 77 | Add keywords to the retriever. 78 | 79 | **Parameters** 80 | 81 | - **documents** (*List[Dict[str, str]]*) 82 | - **kwargs** 83 | 84 | ## References 85 | 86 | 1. [FlashText](https://github.com/vi3k6i5/flashtext) 87 | 2. [Replace or Retrieve Keywords In Documents at Scale](https://arxiv.org/abs/1711.00046) 88 | 89 | -------------------------------------------------------------------------------- /docs/api/retrieve/Fuzz.md: -------------------------------------------------------------------------------- 1 | # Fuzz 2 | 3 | [RapidFuzz](https://github.com/maxbachmann/RapidFuzz) wrapper. Rapid fuzzy string matching in Python and C++ using the Levenshtein Distance. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **on** (*Union[str, list]*) 14 | 15 | Fields to use to match the query to the documents. 16 | 17 | - **fuzzer** – defaults to `` 18 | 19 | [RapidFuzz scorer](https://maxbachmann.github.io/RapidFuzz/Usage/fuzz.html): fuzz.ratio, fuzz.partial_ratio, fuzz.token_set_ratio, fuzz.partial_token_set_ratio, fuzz.token_sort_ratio, fuzz.partial_token_sort_ratio, fuzz.token_ratio, fuzz.partial_token_ratio, fuzz.WRatio, fuzz.QRatio, string_metric.levenshtein, string_metric.normalized_levenshtein 20 | 21 | - **default_process** (*bool*) – defaults to `True` 22 | 23 | Pre-processing step. If set to True, documents are processed with the [RapidFuzz default process](https://maxbachmann.github.io/RapidFuzz/Usage/utils.html). 24 | 25 | - **k** (*Optional[int]*) – defaults to `None` 26 | 27 | 28 | 29 | ## Examples 30 | 31 | ```python 32 | >>> from pprint import pprint as print 33 | >>> from cherche import retrieve 34 | >>> from rapidfuzz import fuzz 35 | 36 | >>> documents = [ 37 | ... {"id": 0, "title": "Paris", "article": "Eiffel tower"}, 38 | ... {"id": 1, "title": "Paris", "article": "Paris is in France."}, 39 | ... {"id": 2, "title": "Montreal", "article": "Montreal is in Canada."}, 40 | ... ] 41 | 42 | >>> retriever = retrieve.Fuzz( 43 | ... key = "id", 44 | ... on = ["title", "article"], 45 | ... fuzzer = fuzz.partial_ratio, 46 | ... 
) 47 | 48 | >>> retriever.add(documents=documents) 49 | Fuzz retriever 50 | key : id 51 | on : title, article 52 | documents: 3 53 | 54 | >>> print(retriever(q="paris", k=2)) 55 | [{'id': 0, 'similarity': 100.0}, {'id': 1, 'similarity': 100.0}] 56 | 57 | >>> print(retriever(q=["paris", "montreal"], k=2)) 58 | [[{'id': 0, 'similarity': 100.0}, {'id': 1, 'similarity': 100.0}], 59 | [{'id': 2, 'similarity': 100.0}, {'id': 1, 'similarity': 37.5}]] 60 | 61 | >>> print(retriever(q=["unknown", "montreal"], k=2)) 62 | [[{'id': 2, 'similarity': 40.0}, {'id': 0, 'similarity': 36.36363636363637}], 63 | [{'id': 2, 'similarity': 100.0}, {'id': 1, 'similarity': 37.5}]] 64 | ``` 65 | 66 | ## Methods 67 | 68 | ???- note "__call__" 69 | 70 | Retrieve documents from the index. 71 | 72 | **Parameters** 73 | 74 | - **q** (*Union[List[str], str]*) 75 | - **k** (*Optional[int]*) – defaults to `None` 76 | - **tqdm_bar** (*bool*) – defaults to `True` 77 | - **kwargs** 78 | 79 | ???- note "add" 80 | 81 | Fuzz is streaming friendly. 82 | 83 | **Parameters** 84 | 85 | - **documents** (*List[Dict[str, str]]*) 86 | - **kwargs** 87 | 88 | ## References 89 | 90 | 1. [RapidFuzz](https://github.com/maxbachmann/RapidFuzz) 91 | 92 | -------------------------------------------------------------------------------- /docs/api/retrieve/Lunr.md: -------------------------------------------------------------------------------- 1 | # Lunr 2 | 3 | Lunr is a Python implementation of Lunr.js by Oliver Nightingale. Lunr is a retriever dedicated for small and middle size corpus. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **on** (*Union[str, list]*) 14 | 15 | Fields to use to match the query to the documents. 16 | 17 | - **documents** (*list*) 18 | 19 | Documents in Lunr retriever are static. The retriever must be reseted to index new documents. 20 | 21 | - **k** (*Optional[int]*) – defaults to `None` 22 | 23 | 24 | 25 | ## Examples 26 | 27 | ```python 28 | >>> from pprint import pprint as print 29 | >>> from cherche import retrieve 30 | 31 | >>> documents = [ 32 | ... {"id": 0, "title": "Paris", "article": "Eiffel tower"}, 33 | ... {"id": 1, "title": "Paris", "article": "Paris is in France."}, 34 | ... {"id": 2, "title": "Montreal", "article": "Montreal is in Canada."}, 35 | ... ] 36 | 37 | >>> retriever = retrieve.Lunr( 38 | ... key="id", 39 | ... on=["title", "article"], 40 | ... documents=documents, 41 | ... ) 42 | 43 | >>> retriever 44 | Lunr retriever 45 | key : id 46 | on : title, article 47 | documents: 3 48 | 49 | >>> print(retriever(q="paris", k=2)) 50 | [{'id': 1, 'similarity': 0.268}, {'id': 0, 'similarity': 0.134}] 51 | 52 | >>> print(retriever(q=["paris", "montreal"], k=2)) 53 | [[{'id': 1, 'similarity': 0.268}, {'id': 0, 'similarity': 0.134}], 54 | [{'id': 2, 'similarity': 0.94}]] 55 | ``` 56 | 57 | ## Methods 58 | 59 | ???- note "__call__" 60 | 61 | Retrieve documents from the index. 62 | 63 | **Parameters** 64 | 65 | - **q** (*Union[str, List[str]]*) 66 | - **k** (*Optional[int]*) – defaults to `None` 67 | - **tqdm_bar** (*bool*) – defaults to `True` 68 | - **kwargs** 69 | 70 | ## References 71 | 72 | 1. [Lunr.py](https://github.com/yeraydiazdiaz/lunr.py) 73 | 2. [Lunr.js](https://lunrjs.com) 74 | 2. 
[Solr](https://solr.apache.org) 75 | 76 | -------------------------------------------------------------------------------- /docs/api/retrieve/Retriever.md: -------------------------------------------------------------------------------- 1 | # Retriever 2 | 3 | Retriever base class. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **on** (*Union[str, list]*) 14 | 15 | Fields to use to match the query to the documents. 16 | 17 | - **k** (*Optional[int]*) 18 | 19 | - **batch_size** (*int*) 20 | 21 | 22 | 23 | 24 | ## Methods 25 | 26 | ???- note "__call__" 27 | 28 | Retrieve documents from the index. 29 | 30 | **Parameters** 31 | 32 | - **q** (*Union[List[str], str]*) 33 | - **k** (*Optional[int]*) 34 | - **batch_size** (*Optional[int]*) 35 | - **kwargs** 36 | 37 | -------------------------------------------------------------------------------- /docs/api/retrieve/TfIdf.md: -------------------------------------------------------------------------------- 1 | # TfIdf 2 | 3 | TfIdf retriever based on cosine similarities. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **key** (*str*) 10 | 11 | Field identifier of each document. 12 | 13 | - **on** (*Union[str, list]*) 14 | 15 | Fields to use to match the query to the documents. 16 | 17 | - **documents** (*List[Dict[str, str]]*) – defaults to `None` 18 | 19 | Documents in TFIdf retriever are static. The retriever must be reseted to index new documents. 20 | 21 | - **tfidf** (*sklearn.feature_extraction.text.sparse.TfidfVectorizer*) – defaults to `None` 22 | 23 | sparse.TfidfVectorizer class of Sklearn to create a custom TfIdf retriever. 24 | 25 | - **k** (*Optional[int]*) – defaults to `None` 26 | 27 | Number of documents to retrieve. Default is `None`, i.e all documents that match the query will be retrieved. 28 | 29 | - **batch_size** (*int*) – defaults to `1024` 30 | 31 | - **fit** (*bool*) – defaults to `True` 32 | 33 | 34 | 35 | ## Examples 36 | 37 | ```python 38 | >>> from pprint import pprint as print 39 | >>> from cherche import retrieve 40 | >>> from lenlp import sparse 41 | 42 | >>> documents = [ 43 | ... {"id": 0, "title": "Paris", "article": "Eiffel tower"}, 44 | ... {"id": 1, "title": "Montreal", "article": "Montreal is in Canada."}, 45 | ... {"id": 2, "title": "Paris", "article": "Eiffel tower"}, 46 | ... {"id": 3, "title": "Montreal", "article": "Montreal is in Canada."}, 47 | ... ] 48 | 49 | >>> retriever = retrieve.TfIdf( 50 | ... key="id", 51 | ... on=["title", "article"], 52 | ... documents=documents, 53 | ... ) 54 | 55 | >>> documents = [ 56 | ... {"id": 4, "title": "Paris", "article": "Eiffel tower"}, 57 | ... {"id": 5, "title": "Montreal", "article": "Montreal is in Canada."}, 58 | ... {"id": 6, "title": "Paris", "article": "Eiffel tower"}, 59 | ... {"id": 7, "title": "Montreal", "article": "Montreal is in Canada."}, 60 | ... 
] 61 | 62 | >>> retriever = retriever.add(documents) 63 | 64 | >>> print(retriever(q=["paris", "canada"], k=4)) 65 | [[{'id': 6, 'similarity': 0.5404109029445249}, 66 | {'id': 0, 'similarity': 0.5404109029445249}, 67 | {'id': 2, 'similarity': 0.5404109029445249}, 68 | {'id': 4, 'similarity': 0.5404109029445249}], 69 | [{'id': 7, 'similarity': 0.3157669764669935}, 70 | {'id': 5, 'similarity': 0.3157669764669935}, 71 | {'id': 3, 'similarity': 0.3157669764669935}, 72 | {'id': 1, 'similarity': 0.3157669764669935}]] 73 | 74 | >>> print(retriever(["unknown", "montreal paris"], k=2)) 75 | [[], 76 | [{'id': 7, 'similarity': 0.7391866872635209}, 77 | {'id': 5, 'similarity': 0.7391866872635209}]] 78 | 79 | >>> print(retriever(q="paris")) 80 | [{'id': 6, 'similarity': 0.5404109029445249}, 81 | {'id': 0, 'similarity': 0.5404109029445249}, 82 | {'id': 2, 'similarity': 0.5404109029445249}, 83 | {'id': 4, 'similarity': 0.5404109029445249}] 84 | ``` 85 | 86 | ## Methods 87 | 88 | ???- note "__call__" 89 | 90 | Retrieve documents from batch of queries. 91 | 92 | **Parameters** 93 | 94 | - **q** (*Union[str, List[str]]*) 95 | - **k** (*Optional[int]*) – defaults to `None` 96 | Number of documents to retrieve. Default is `None`, i.e all documents that match the query will be retrieved. 97 | - **batch_size** (*Optional[int]*) – defaults to `None` 98 | - **tqdm_bar** (*bool*) – defaults to `True` 99 | - **kwargs** 100 | 101 | ???- note "add" 102 | 103 | Add new documents to the TFIDF retriever. The tfidf won't be refitted. 104 | 105 | **Parameters** 106 | 107 | - **documents** (*list*) 108 | Documents in TFIdf retriever are static. The retriever must be reseted to index new documents. 109 | - **batch_size** (*int*) – defaults to `100000` 110 | - **tqdm_bar** (*bool*) – defaults to `False` 111 | - **kwargs** 112 | 113 | ???- note "top_k" 114 | 115 | Return the top k documents for each query. 116 | 117 | **Parameters** 118 | 119 | - **similarities** (*scipy.sparse._csc.csc_matrix*) 120 | - **k** (*int*) 121 | Number of documents to retrieve. Default is `None`, i.e all documents that match the query will be retrieved. 122 | 123 | ## References 124 | 125 | 1. [sklearn.feature_extraction.text.sparse.TfidfVectorizer](https://github.com/raphaelsty/LeNLP) 126 | 2. [Python: tf-idf-cosine: to find document similarity](https://stackoverflow.com/questions/12118720/python-tf-idf-cosine-to-find-document-similarity) 127 | 128 | -------------------------------------------------------------------------------- /docs/api/utils/.pages: -------------------------------------------------------------------------------- 1 | title: utils -------------------------------------------------------------------------------- /docs/api/utils/TopK.md: -------------------------------------------------------------------------------- 1 | # TopK 2 | 3 | Filter top k documents in pipeline. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **k** (*int*) 10 | 11 | Number of documents to keep. 12 | 13 | 14 | 15 | ## Examples 16 | 17 | ```python 18 | >>> from pprint import pprint as print 19 | >>> from cherche import retrieve, rank, utils 20 | >>> from sentence_transformers import SentenceTransformer 21 | 22 | >>> documents = [ 23 | ... {"id": 0, "title": "Paris France"}, 24 | ... {"id": 1, "title": "Madrid Spain"}, 25 | ... {"id": 2, "title": "Montreal Canada"} 26 | ... ] 27 | 28 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 29 | 30 | >>> ranker = rank.Encoder( 31 | ... 
encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 32 | ... key = "id", 33 | ... on = ["title"], 34 | ... ) 35 | 36 | >>> pipeline = retriever + ranker + utils.TopK(k=2) 37 | >>> pipeline.add(documents=documents) 38 | TfIdf retriever 39 | key : id 40 | on : title, article 41 | documents: 3 42 | Encoder ranker 43 | key : id 44 | on : title 45 | normalize : True 46 | embeddings: 3 47 | Filter TopK 48 | k: 2 49 | 50 | >>> print(pipeline(q="Paris Madrid Montreal", k=2)) 51 | [{'id': 0, 'similarity': 0.62922895}, {'id': 2, 'similarity': 0.61419094}] 52 | ``` 53 | 54 | ## Methods 55 | 56 | ???- note "__call__" 57 | 58 | Filter top k documents in pipeline. 59 | 60 | **Parameters** 61 | 62 | - **documents** (*List[List[Dict[str, str]]]*) 63 | - **kwargs** 64 | 65 | -------------------------------------------------------------------------------- /docs/api/utils/quantize.md: -------------------------------------------------------------------------------- 1 | # quantize 2 | 3 | Quantize model to speedup inference. May reduce accuracy. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **model** 10 | 11 | Transformer model to quantize. 12 | 13 | - **dtype** – defaults to `None` 14 | 15 | Dtype to apply to selected layers. 16 | 17 | - **layers** – defaults to `None` 18 | 19 | Layers to quantize. 20 | 21 | - **engine** – defaults to `qnnpack` 22 | 23 | The qengine specifies which backend is to be used for execution. 24 | 25 | 26 | 27 | ## Examples 28 | 29 | ```python 30 | >>> from pprint import pprint as print 31 | >>> from cherche import utils, retrieve 32 | >>> from sentence_transformers import SentenceTransformer 33 | 34 | >>> documents = [ 35 | ... {"id": 0, "title": "Paris France"}, 36 | ... {"id": 1, "title": "Madrid Spain"}, 37 | ... {"id": 2, "title": "Montreal Canada"} 38 | ... ] 39 | 40 | >>> encoder = utils.quantize(SentenceTransformer("sentence-transformers/all-mpnet-base-v2")) 41 | 42 | >>> retriever = retrieve.Encoder( 43 | ... encoder = encoder.encode, 44 | ... key = "id", 45 | ... on = ["title"], 46 | ... ) 47 | 48 | >>> retriever = retriever.add(documents) 49 | 50 | >>> print(retriever("paris")) 51 | [{'id': 0, 'similarity': 0.6361529519968355}, 52 | {'id': 2, 'similarity': 0.42750324298964354}, 53 | {'id': 1, 'similarity': 0.42645383885361576}] 54 | ``` 55 | 56 | ## References 57 | 58 | 1. [PyTorch Quantization](https://pytorch.org/docs/stable/quantization.html) 59 | 60 | -------------------------------------------------------------------------------- /docs/api/utils/yield-batch-single.md: -------------------------------------------------------------------------------- 1 | # yield_batch_single 2 | 3 | Yield successive n-sized chunks from array. 4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **array** (*Union[List[str], str, List[Dict[str, Any]]]*) 10 | 11 | - **desc** (*str*) 12 | 13 | - **tqdm_bar** (*bool*) – defaults to `True` 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /docs/api/utils/yield-batch.md: -------------------------------------------------------------------------------- 1 | # yield_batch 2 | 3 | Yield successive n-sized chunks from array. 
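A minimal usage sketch is shown below. It assumes `yield_batch` is importable from `cherche.utils`, like the other utilities documented in this section, and that each iteration yields one chunk of the input array.

```python
from cherche import utils

documents = [{"id": i, "title": f"document {i}"} for i in range(10)]

# Iterate over the documents four at a time; `desc` labels the optional tqdm progress bar.
for batch in utils.yield_batch(array=documents, batch_size=4, desc="Indexing", tqdm_bar=False):
    # With ten documents and a batch size of four, the chunks would hold 4, 4 and 2 items
    # if the last partial batch is yielded as-is.
    print(len(batch))
```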
4 | 5 | 6 | 7 | ## Parameters 8 | 9 | - **array** (*Union[List[str], str, List[Dict[str, Any]], numpy.ndarray]*) 10 | 11 | - **batch_size** (*int*) 12 | 13 | - **desc** (*str*) 14 | 15 | - **tqdm_bar** (*bool*) – defaults to `True` 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /docs/css/version-select.css: -------------------------------------------------------------------------------- 1 | @media only screen and (max-width:76.1875em) { 2 | #version-selector { 3 | padding: .6rem .8rem; 4 | } 5 | } -------------------------------------------------------------------------------- /docs/documents/.pages: -------------------------------------------------------------------------------- 1 | title: Documents 2 | nav: 3 | - documents.md 4 | - towns.md 5 | -------------------------------------------------------------------------------- /docs/documents/documents.md: -------------------------------------------------------------------------------- 1 | # Documents 2 | 3 | When using Cherche, we must define a document as a Python dictionary. A set of documents is simply a list of dictionaries. The name of the fields of the documents does not matter. We can choose the field(s) of your choice to perform neural search. However, it is mandatory to have a unique identifier for each document. Also, the name of this identifier does not matter. In the example below, the identifier is the `id` field. 4 | 5 | It can happen that not all documents have the same fields. Therefore, we do not need to standardize or fill all the fields (except the identifier). 6 | 7 | ```python 8 | [ 9 | { 10 | "id": 0, 11 | "title": "Paris", 12 | "url": "https://en.wikipedia.org/wiki/Paris", 13 | "article": "Paris (French pronunciation: \u200b[paʁi] (listen)) is the capital and most populous city of France, with an estimated population of 2,175,601 residents as of 2018, in an area of more than 105 square kilometres (41 square miles).", 14 | }, 15 | { 16 | "id": 1, 17 | "title": "Paris", 18 | "url": "https://en.wikipedia.org/wiki/Paris", 19 | "article": "Since the 17th century, Paris has been one of Europe's major centres of finance, diplomacy, commerce, fashion, gastronomy, science, and arts.", 20 | }, 21 | { 22 | "id": 2, 23 | "title": "Paris", 24 | "url": "https://en.wikipedia.org/wiki/Paris", 25 | "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France, or Paris Region, which has an estimated population of 12,174,880, or about 18 percent of the population of France as of 2017.", 26 | }, 27 | ] 28 | ``` 29 | -------------------------------------------------------------------------------- /docs/documents/towns.md: -------------------------------------------------------------------------------- 1 | # Towns 2 | 3 | Cherche provides a dummy dataset made of sentences from Wikipedia that describes towns such as Toulouse, Paris, Bordeaux and Lyon. This dataset is intended to easily test Cherche. It contains ~200 documents. 
4 | 5 | ```python 6 | >>> from cherche import data 7 | >>> documents = data.load_towns() 8 | >>> documents[:3] 9 | ``` 10 | 11 | ```python 12 | [ 13 | { 14 | "id": 0, 15 | "title": "Paris", 16 | "url": "https://en.wikipedia.org/wiki/Paris", 17 | "article": "Paris (French pronunciation: \u200b[paʁi] (listen)) is the capital and most populous city of France, with an estimated population of 2,175,601 residents as of 2018, in an area of more than 105 square kilometres (41 square miles).", 18 | }, 19 | { 20 | "id": 1, 21 | "title": "Paris", 22 | "url": "https://en.wikipedia.org/wiki/Paris", 23 | "article": "Since the 17th century, Paris has been one of Europe's major centres of finance, diplomacy, commerce, fashion, gastronomy, science, and arts.", 24 | }, 25 | { 26 | "id": 2, 27 | "title": "Paris", 28 | "url": "https://en.wikipedia.org/wiki/Paris", 29 | "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France, or Paris Region, which has an estimated population of 12,174,880, or about 18 percent of the population of France as of 2017.", 30 | }, 31 | ] 32 | ``` 33 | -------------------------------------------------------------------------------- /docs/examples/.pages: -------------------------------------------------------------------------------- 1 | nav: 2 | - retriever_ranker.ipynb 3 | - retriever_ranker_qa.ipynb 4 | - union_intersection_rankers.ipynb 5 | - voting.ipynb 6 | - encoder_retriever.ipynb 7 | - eval_pipeline.ipynb -------------------------------------------------------------------------------- /docs/img/doc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raphaelsty/cherche/b640571a33b774a5157a07046e0aecb313960f14/docs/img/doc.png -------------------------------------------------------------------------------- /docs/img/explain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raphaelsty/cherche/b640571a33b774a5157a07046e0aecb313960f14/docs/img/explain.png -------------------------------------------------------------------------------- /docs/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raphaelsty/cherche/b640571a33b774a5157a07046e0aecb313960f14/docs/img/logo.png -------------------------------------------------------------------------------- /docs/img/renault.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raphaelsty/cherche/b640571a33b774a5157a07046e0aecb313960f14/docs/img/renault.jpg -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |
2 | Cherche 3 | Neural search 4 | 5 | 6 | 7 | documentation · Demo · license 13 | 
14 | 15 | ## Installation 🤖 16 | 17 | To install Cherche for use with a simple retriever on CPU, such as TfIdf, Flash or Lunr, use the following command: 18 | 19 | ```sh 20 | pip install cherche 21 | ``` 22 | 23 | To install Cherche for use with any semantic retriever or ranker on CPU, use the following command: 24 | 25 | ```sh 26 | pip install "cherche[cpu]" 27 | ``` 28 | 29 | Finally, if you plan to use any semantic retriever or ranker on GPU, use the following command: 30 | 31 | ```sh 32 | pip install "cherche[gpu]" 33 | ``` 34 | 35 | By following these installation instructions, you will be able to use Cherche with the appropriate requirements for your needs. 36 | 37 | Links to the documentation: 38 | 39 | - [Retriever](https://raphaelsty.github.io/cherche/retrieve/retrieve/) 40 | 41 | - [Ranker](https://raphaelsty.github.io/cherche/rank/rank/) 42 | 43 | - [Question answering](https://raphaelsty.github.io/cherche/qa/qa/) 44 | 45 | - [Pipeline](https://raphaelsty.github.io/cherche/pipeline/pipeline/) 46 | 47 | - [Examples](https://raphaelsty.github.io/cherche/examples/retriever_ranker/) 48 | 49 | - [API reference](https://raphaelsty.github.io/cherche/api/overview/) -------------------------------------------------------------------------------- /docs/javascripts/config.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; -------------------------------------------------------------------------------- /docs/js/version-select.js: -------------------------------------------------------------------------------- 1 | window.addEventListener("DOMContentLoaded", function () { 2 | // This is a bit hacky. Figure out the base URL from a known CSS file the 3 | // template refers to... 
4 | var ex = new RegExp("/?css/version-select.css$"); 5 | var sheet = document.querySelector('link[href$="version-select.css"]'); 6 | 7 | var ABS_BASE_URL = sheet.href.replace(ex, ""); 8 | var CURRENT_VERSION = ABS_BASE_URL.split("/").pop(); 9 | 10 | function makeSelect(options, selected) { 11 | var select = document.createElement("select"); 12 | select.classList.add("form-control"); 13 | 14 | options.forEach(function (i) { 15 | var option = new Option(i.text, i.value, undefined, 16 | i.value === selected); 17 | select.add(option); 18 | }); 19 | 20 | return select; 21 | } 22 | 23 | var xhr = new XMLHttpRequest(); 24 | xhr.open("GET", ABS_BASE_URL + "/../versions.json"); 25 | xhr.onload = function () { 26 | var versions = JSON.parse(this.responseText); 27 | 28 | var realVersion = versions.find(function (i) { 29 | return i.version === CURRENT_VERSION || 30 | i.aliases.includes(CURRENT_VERSION); 31 | }).version; 32 | 33 | var select = makeSelect(versions.map(function (i) { 34 | return { text: i.title, value: i.version }; 35 | }), realVersion); 36 | select.addEventListener("change", function (event) { 37 | window.location.href = ABS_BASE_URL + "/../" + this.value; 38 | }); 39 | 40 | var container = document.createElement("div"); 41 | container.id = "version-selector"; 42 | container.className = "md-nav__item"; 43 | container.appendChild(select); 44 | 45 | var sidebar = document.querySelector(".md-nav--primary > .md-nav__list"); 46 | sidebar.parentNode.insertBefore(container, sidebar); 47 | }; 48 | xhr.send(); 49 | }); -------------------------------------------------------------------------------- /docs/pipeline/.pages: -------------------------------------------------------------------------------- 1 | title: Pipeline 2 | nav: 3 | - pipeline.md 4 | -------------------------------------------------------------------------------- /docs/qa/.pages: -------------------------------------------------------------------------------- 1 | title: QA 2 | nav: 3 | - qa.md 4 | -------------------------------------------------------------------------------- /docs/qa/qa.md: -------------------------------------------------------------------------------- 1 | # Question Answering 2 | 3 | The `qa.QA` module is a crucial component of our neural search pipeline, integrating an extractive question answering model that is compatible with [Hugging Face](https://huggingface.co/models?pipeline_tag=question-answering). This model efficiently extracts the most likely answer spans from a list of documents in response to user queries. To further expedite the search process, our neural search pipeline filters the entire corpus and narrows down the search to a few relevant documents, resulting in faster response times for top answers. However, it's worth noting that even with corpus filtering, question answering models can be slow when using a CPU and typically require a GPU to achieve optimal performance. 4 | 5 | ## Documents 6 | 7 | The pipeline must provide the documents and not only the identifiers to the question answering model such as: 8 | 9 | ```python 10 | search = pipeline + documents + question_answering 11 | ``` 12 | 13 | ## Tutorial 14 | 15 | ```python 16 | >>> from cherche import data, rank, retrieve, qa 17 | >>> from sentence_transformers import SentenceTransformer 18 | >>> from transformers import pipeline 19 | 20 | >>> documents = data.load_towns() 21 | 22 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents, k=100) 23 | 24 | >>> ranker = rank.Encoder( 25 | ... 
key = "id", 26 | ... on = "article", 27 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 28 | ... k = 30, 29 | ... ) 30 | 31 | >>> question_answering = qa.QA( 32 | ... model = pipeline("question-answering", 33 | ... model = "deepset/roberta-base-squad2", 34 | ... tokenizer = "deepset/roberta-base-squad2" 35 | ... ), 36 | ... on = "article", 37 | ... ) 38 | 39 | >>> search = retriever + ranker + documents + question_answering 40 | >>> search.add(documents) 41 | >>> answers = search( 42 | ... q=[ 43 | ... "What is the name of the football club of Paris?", 44 | ... "What is the speciality of Lyon?" 45 | ... ] 46 | ... ) 47 | 48 | # The answer is Paris Saint-Germain 49 | >>> answers[0][0] 50 | {'id': 20, 51 | 'title': 'Paris', 52 | 'url': 'https://en.wikipedia.org/wiki/Paris', 53 | 'article': 'The football club Paris Saint-Germain and the rugby union club Stade Français are based in Paris.', 54 | 'similarity': 0.6905894, 55 | 'score': 0.9848365783691406, 56 | 'start': 18, 57 | 'end': 37, 58 | 'answer': 'Paris Saint-Germain', 59 | 'question': 'What is the name of the football club of Paris?'} 60 | 61 | 62 | >>> answers[1][0] 63 | {'id': 52, 64 | 'title': 'Lyon', 65 | 'url': 'https://en.wikipedia.org/wiki/Lyon', 66 | 'article': 'Economically, Lyon is a major centre for banking, as well as for the chemical, pharmaceutical and biotech industries.', 67 | 'similarity': 0.64728546, 68 | 'score': 0.6952874660491943, 69 | 'start': 41, 70 | 'end': 48, 71 | 'answer': 'banking', 72 | 'question': 'What is the speciality of Lyon?'} 73 | ``` 74 | -------------------------------------------------------------------------------- /docs/rank/.pages: -------------------------------------------------------------------------------- 1 | title: Rank 2 | nav: 3 | - rank.md 4 | - encoder.md 5 | - dpr.md 6 | - crossencoder.md 7 | - embedding.md 8 | -------------------------------------------------------------------------------- /docs/rank/dpr.md: -------------------------------------------------------------------------------- 1 | # rank.DPR 2 | 3 | The `rank.DPR` model re-ranks documents in ouput of the retriever. `rank.DPR` is dedicated to the [Dense Passage Retrieval](https://arxiv.org/abs/2004.04906) models which aims to use two distinct neural networks, one that encodes the query and the other one that encodes the documents. 4 | 5 | The `rank.DPR` can pre-compute the set of document embeddings to speed up search and avoiding computing embeddings twice using method `.add`. A GPU will significantly reduce pre-computing time dedicated to document embeddings. 6 | 7 | ## Tutorial 8 | 9 | To use the DPR ranker we will need to install "cherche[cpu]" 10 | 11 | ```sh 12 | pip install "cherche[cpu]" 13 | ``` 14 | 15 | or on GPU: 16 | 17 | ```sh 18 | pip install "cherche[gpu]" 19 | ``` 20 | 21 | ## Tutorial 22 | 23 | ```python 24 | >>> from cherche import retrieve, rank 25 | >>> from sentence_transformers import SentenceTransformer 26 | 27 | >>> documents = [ 28 | ... { 29 | ... "id": 0, 30 | ... "article": "Paris is the capital and most populous city of France", 31 | ... "title": "Paris", 32 | ... "url": "https://en.wikipedia.org/wiki/Paris" 33 | ... }, 34 | ... { 35 | ... "id": 1, 36 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 37 | ... "title": "Paris", 38 | ... "url": "https://en.wikipedia.org/wiki/Paris" 39 | ... }, 40 | ... { 41 | ... "id": 2, 42 | ... 
"article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 43 | ... "title": "Paris", 44 | ... "url": "https://en.wikipedia.org/wiki/Paris" 45 | ... } 46 | ... ] 47 | 48 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 49 | 50 | >>> ranker = rank.DPR( 51 | ... key = "id", 52 | ... on = ["title", "article"], 53 | ... encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').encode, 54 | ... query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').encode, 55 | ... normalize=True, 56 | ... ) 57 | 58 | >>> ranker.add(documents, batch_size=64) 59 | 60 | >>> match = retriever(["paris", "art", "fashion"], k=100) 61 | 62 | # Re-rank output of retriever 63 | >>> ranker(["paris", "art", "fashion"], documents=match, k=30) 64 | [[{'id': 0, 'similarity': 8.163156}, # Query 1 65 | {'id': 1, 'similarity': 8.021494}, 66 | {'id': 2, 'similarity': 7.8683443}], 67 | [{'id': 1, 'similarity': 5.4577255}], # Query 2 68 | [{'id': 1, 'similarity': 6.8593264}, {'id': 2, 'similarity': 6.1895266}]] # Query 3 69 | ``` 70 | 71 | ## Ranker in pipeline 72 | 73 | ```python 74 | >>> from cherche import retrieve, rank 75 | >>> from sentence_transformers import SentenceTransformer 76 | 77 | >>> documents = [ 78 | ... { 79 | ... "id": 0, 80 | ... "article": "Paris is the capital and most populous city of France", 81 | ... "title": "Paris", 82 | ... "url": "https://en.wikipedia.org/wiki/Paris" 83 | ... }, 84 | ... { 85 | ... "id": 1, 86 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 87 | ... "title": "Paris", 88 | ... "url": "https://en.wikipedia.org/wiki/Paris" 89 | ... }, 90 | ... { 91 | ... "id": 2, 92 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 93 | ... "title": "Paris", 94 | ... "url": "https://en.wikipedia.org/wiki/Paris" 95 | ... } 96 | ... ] 97 | 98 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents, k=100) 99 | 100 | >>> ranker = rank.DPR( 101 | ... key = "id", 102 | ... on = ["title", "article"], 103 | ... encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').encode, 104 | ... query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').encode, 105 | ... k = 30, 106 | ... ) 107 | 108 | >>> search = retriever + ranker 109 | >>> search.add(documents, batch_size=64) 110 | >>> search(q=["paris", "arts", "fashion"]) 111 | [[{'id': 0, 'similarity': 8.163156}, # Query 1 112 | {'id': 1, 'similarity': 8.021494}, 113 | {'id': 2, 'similarity': 7.8683443}], 114 | [{'id': 1, 'similarity': 5.4577255}], # Query 2 115 | [{'id': 1, 'similarity': 6.8593264}, {'id': 2, 'similarity': 6.1895266}]] # Query 3 116 | ``` 117 | 118 | ## Map index to documents 119 | 120 | We can map the documents to the ids retrieved by the pipeline. 121 | 122 | ```python 123 | >>> search += documents 124 | >>> search(q="arts") 125 | [{'id': 1, 126 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 127 | 'title': 'Paris', 128 | 'url': 'https://en.wikipedia.org/wiki/Paris', 129 | 'similarity': 0.21684971}] 130 | ``` 131 | 132 | ## Pre-trained models 133 | 134 | Here is the list of models provided by [SentenceTransformers](https://www.sbert.net/docs/pretrained_models.html). 
This list of models is not exhaustive; there is a wide range of models available with [Hugging Face](https://huggingface.co/models?pipeline_tag=sentence-similarity&sort=downloads) and in many languages. 135 | -------------------------------------------------------------------------------- /docs/rank/embedding.md: -------------------------------------------------------------------------------- 1 | # rank.Embedding 2 | 3 | The `rank.Embedding` model utilizes pre-computed embeddings to re-rank documents within the output of the retriever. If you have a custom model that produces its own embeddings and want to re-rank documents accordingly, `rank.Embedding` is the ideal tool for the job. 4 | 5 | ## Tutorial 6 | 7 | ```python 8 | >>> from cherche import retrieve, rank 9 | >>> from sentence_transformers import SentenceTransformer 10 | 11 | >>> documents = [ 12 | ... { 13 | ... "id": 0, 14 | ... "article": "Paris is the capital and most populous city of France", 15 | ... "title": "Paris", 16 | ... "url": "https://en.wikipedia.org/wiki/Paris" 17 | ... }, 18 | ... { 19 | ... "id": 1, 20 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 21 | ... "title": "Paris", 22 | ... "url": "https://en.wikipedia.org/wiki/Paris" 23 | ... }, 24 | ... { 25 | ... "id": 2, 26 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 27 | ... "title": "Paris", 28 | ... "url": "https://en.wikipedia.org/wiki/Paris" 29 | ... } 30 | ... ] 31 | 32 | # Let's use a custom encoder and create our documents embeddings of shape (n_documents, dim_embeddings) 33 | >>> encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") 34 | >>> embeddings_documents = encoder.encode([ 35 | ... document["article"] for document in documents 36 | ... ]) 37 | 38 | >>> queries = ["paris", "art", "fashion"] 39 | 40 | # Queries embeddings of shape (n_queries, dim_embeddings) 41 | >>> embeddings_queries = encoder.encode(queries) 42 | 43 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 44 | 45 | >>> ranker = rank.Embedding( 46 | ... key = "id", 47 | ... normalize = True, 48 | ... ) 49 | 50 | >>> ranker = ranker.add( 51 | ... documents=documents, 52 | ... embeddings_documents=embeddings_documents, 53 | ... ) 54 | 55 | >>> match = retriever(queries, k=100) 56 | 57 | # Re-rank output of retriever 58 | >>> ranker(q=embeddings_queries, documents=match, k=30) 59 | [[{'id': 0, 'similarity': 0.6560695}, # Query 1 60 | {'id': 1, 'similarity': 0.58203197}, 61 | {'id': 2, 'similarity': 0.5283624}], 62 | [{'id': 1, 'similarity': 0.1115652}], # Query 2 63 | [{'id': 1, 'similarity': 0.2555524}, {'id': 2, 'similarity': 0.06398084}]] # Query 3 64 | ``` 65 | 66 | ## Map index to documents 67 | 68 | We can map the documents to the ids retrieved by the pipeline. 
69 | 70 | ```python 71 | >>> ranker += documents 72 | >>> match = retriever(queries, k=100) 73 | >>> ranker(q=embeddings_queries, documents=match, k=30) 74 | [[{'id': 0, 75 | 'article': 'Paris is the capital and most populous city of France', 76 | 'title': 'Paris', 77 | 'url': 'https://en.wikipedia.org/wiki/Paris', 78 | 'similarity': 0.6560695}, 79 | {'id': 1, 80 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 81 | 'title': 'Paris', 82 | 'url': 'https://en.wikipedia.org/wiki/Paris', 83 | 'similarity': 0.58203197}, 84 | {'id': 2, 85 | 'article': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France .', 86 | 'title': 'Paris', 87 | 'url': 'https://en.wikipedia.org/wiki/Paris', 88 | 'similarity': 0.5283624}], 89 | [{'id': 1, 90 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 91 | 'title': 'Paris', 92 | 'url': 'https://en.wikipedia.org/wiki/Paris', 93 | 'similarity': 0.1115652}], 94 | [{'id': 1, 95 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 96 | 'title': 'Paris', 97 | 'url': 'https://en.wikipedia.org/wiki/Paris', 98 | 'similarity': 0.2555524}, 99 | {'id': 2, 100 | 'article': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France .', 101 | 'title': 'Paris', 102 | 'url': 'https://en.wikipedia.org/wiki/Paris', 103 | 'similarity': 0.06398084}]] 104 | ``` 105 | -------------------------------------------------------------------------------- /docs/rank/encoder.md: -------------------------------------------------------------------------------- 1 | # rank.Encoder 2 | 3 | The `rank.Encoder` model re-ranks documents in the output of the retriever using a pre-trained [SentenceTransformers](https://www.sbert.net/docs/pretrained_models.html) model. 4 | 5 | The `rank.Encoder` ranker can pre-compute the document embeddings with the `.add` method to speed up search and avoid computing embeddings twice. A GPU will significantly reduce the time dedicated to pre-computing document embeddings. 6 | 7 | ## Requirements 8 | 9 | To use the Encoder ranker we will need to install "cherche[cpu]" 10 | 11 | ```sh 12 | pip install "cherche[cpu]" 13 | ``` 14 | 15 | or on GPU: 16 | 17 | ```sh 18 | pip install "cherche[gpu]" 19 | ``` 20 | 21 | ## Tutorial 22 | 23 | ```python 24 | >>> from cherche import retrieve, rank 25 | >>> from sentence_transformers import SentenceTransformer 26 | 27 | >>> documents = [ 28 | ... { 29 | ... "id": 0, 30 | ... "article": "Paris is the capital and most populous city of France", 31 | ... "title": "Paris", 32 | ... "url": "https://en.wikipedia.org/wiki/Paris" 33 | ... }, 34 | ... { 35 | ... "id": 1, 36 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 37 | ... "title": "Paris", 38 | ... "url": "https://en.wikipedia.org/wiki/Paris" 39 | ... }, 40 | ... { 41 | ... "id": 2, 42 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 43 | ... "title": "Paris", 44 | ... "url": "https://en.wikipedia.org/wiki/Paris" 45 | ... } 46 | ... ] 47 | 48 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 49 | 50 | >>> ranker = rank.Encoder( 51 | ... key = "id", 52 | ...
on = ["title", "article"], 53 | ... encoder = SentenceTransformer(f"sentence-transformers/all-mpnet-base-v2").encode, 54 | ... normalize=True, 55 | ... ) 56 | 57 | >>> ranker.add(documents, batch_size=64) 58 | 59 | >>> match = retriever(["paris", "art", "fashion"], k=100) 60 | 61 | # Re-rank output of retriever 62 | >>> ranker(["paris", "art", "fashion"], documents=match, k=30) 63 | [[{'id': 0, 'similarity': 0.6638489}, # Query 1 64 | {'id': 2, 'similarity': 0.602515}, 65 | {'id': 1, 'similarity': 0.60133684}], 66 | [{'id': 1, 'similarity': 0.10321068}], # Query 2 67 | [{'id': 1, 'similarity': 0.26405674}, {'id': 2, 'similarity': 0.096046045}]] # Query 3 68 | ``` 69 | 70 | ## Ranker in pipeline 71 | 72 | ```python 73 | >>> from cherche import retrieve, rank 74 | >>> from sentence_transformers import SentenceTransformer 75 | 76 | >>> documents = [ 77 | ... { 78 | ... "id": 0, 79 | ... "article": "Paris is the capital and most populous city of France", 80 | ... "title": "Paris", 81 | ... "url": "https://en.wikipedia.org/wiki/Paris" 82 | ... }, 83 | ... { 84 | ... "id": 1, 85 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 86 | ... "title": "Paris", 87 | ... "url": "https://en.wikipedia.org/wiki/Paris" 88 | ... }, 89 | ... { 90 | ... "id": 2, 91 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 92 | ... "title": "Paris", 93 | ... "url": "https://en.wikipedia.org/wiki/Paris" 94 | ... } 95 | ... ] 96 | 97 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents, k=100) 98 | 99 | >>> ranker = rank.Encoder( 100 | ... key = "id", 101 | ... on = ["title", "article"], 102 | ... encoder = SentenceTransformer(f"sentence-transformers/all-mpnet-base-v2").encode, 103 | ... k = 30, 104 | ... normalize=True, 105 | ... ) 106 | 107 | >>> search = retriever + ranker 108 | >>> search.add(documents) 109 | >>> search(q=["paris", "arts", "fashion"]) 110 | [[{'id': 0, 'similarity': 0.6638489}, # Query 1 111 | {'id': 2, 'similarity': 0.602515}, 112 | {'id': 1, 'similarity': 0.60133684}], 113 | [{'id': 1, 'similarity': 0.21684976}], # Query 2 114 | [{'id': 1, 'similarity': 0.26405674}, {'id': 2, 'similarity': 0.096046045}]] # Query 3 115 | ``` 116 | 117 | ## Map index to documents 118 | 119 | We can map the documents to the ids retrieved by the pipeline. 120 | 121 | ```python 122 | >>> search += documents 123 | >>> search(q="arts") 124 | [{'id': 1, 125 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 126 | 'title': 'Paris', 127 | 'url': 'https://en.wikipedia.org/wiki/Paris', 128 | 'similarity': 0.21684971}] 129 | ``` 130 | 131 | ## Pre-trained encoders 132 | 133 | Here is the list of models provided by [SentenceTransformers](https://www.sbert.net/docs/pretrained_models.html). This list of models is not exhaustive; there is a wide range of models available with [Hugging Face](https://huggingface.co/models?pipeline_tag=sentence-similarity&sort=downloads) and in many languages. 134 | -------------------------------------------------------------------------------- /docs/rank/rank.md: -------------------------------------------------------------------------------- 1 | # Rank 2 | 3 | Rankers are models that measure the semantic similarity between a document and a query. 
They filter and re-order the documents returned by a retriever according to their semantic similarity to the query, and they are compatible with all the retrievers. 4 | 5 | 6 | | Ranker | Precomputing | GPU | 7 | |:---------------:|:------------:|:---------------------------------------------------------------------------------------------------------------------:| 8 | | ranker.Encoder | ✅ | Highly recommended when precomputing embeddings if the corpus is large. Not needed anymore after precomputing | 9 | | ranker.DPR | ✅ | Highly recommended when precomputing embeddings if the corpus is large. Not needed anymore after precomputing | 10 | | ranker.CrossEncoder | ❌ | Highly recommended since ranker.CrossEncoder cannot precompute
embeddings | 11 | | ranker.Embedding | ❌ | Not needed | 12 | 13 | 14 | The `rank.Encoder` and `rank.DPR` rankers pre-compute the document embeddings once for all with the `add` method. This step can be time-consuming if we do not have a GPU. The embeddings are pre-computed so that the model can then rank the retriever documents at lightning speed. 15 | 16 | ## Requirements 17 | 18 | To use the Encoder ranker we will need to install "cherche[cpu]" 19 | 20 | ```sh 21 | pip install "cherche[cpu]" 22 | ``` 23 | 24 | or on GPU: 25 | 26 | ```sh 27 | pip install "cherche[gpu]" 28 | ``` 29 | 30 | ## Tutorial 31 | 32 | ```python 33 | >>> from cherche import retrieve, rank 34 | >>> from sentence_transformers import SentenceTransformer 35 | 36 | >>> documents = [ 37 | ... { 38 | ... "id": 0, 39 | ... "article": "Paris is the capital and most populous city of France", 40 | ... "title": "Paris", 41 | ... "url": "https://en.wikipedia.org/wiki/Paris" 42 | ... }, 43 | ... { 44 | ... "id": 1, 45 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 46 | ... "title": "Paris", 47 | ... "url": "https://en.wikipedia.org/wiki/Paris" 48 | ... }, 49 | ... { 50 | ... "id": 2, 51 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 52 | ... "title": "Paris", 53 | ... "url": "https://en.wikipedia.org/wiki/Paris" 54 | ... } 55 | ... ] 56 | 57 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 58 | 59 | >>> ranker = rank.Encoder( 60 | ... key = "id", 61 | ... on = ["title", "article"], 62 | ... encoder = SentenceTransformer(f"sentence-transformers/all-mpnet-base-v2").encode, 63 | ... ) 64 | 65 | # Pre-compute embeddings 66 | >>> ranker.add(documents, batch_size=64) 67 | 68 | >>> match = retriever(["paris", "art", "fashion"], k=100) 69 | 70 | # Re-rank output of retriever 71 | >>> ranker(["paris", "art", "fashion"], documents=match, k=30) 72 | [[{'id': 0, 'similarity': 0.6638489}, 73 | {'id': 2, 'similarity': 0.602515}, 74 | {'id': 1, 'similarity': 0.60133684}], 75 | [{'id': 1, 'similarity': 0.10321068}], 76 | [{'id': 1, 'similarity': 0.26405674}, {'id': 2, 'similarity': 0.096046045}]] 77 | ``` 78 | -------------------------------------------------------------------------------- /docs/retrieve/.pages: -------------------------------------------------------------------------------- 1 | title: Retrieve 2 | nav: 3 | - retrieve.md 4 | - bm25.md 5 | - tfidf.md 6 | - flash.md 7 | - lunr.md 8 | - fuzz.md 9 | - encoder.md 10 | - dpr.md 11 | - embedding.md -------------------------------------------------------------------------------- /docs/retrieve/bm25.md: -------------------------------------------------------------------------------- 1 | # BM25 2 | 3 | Our BM25 retriever relies on the [sparse.BM25Vectorizer](https://github.com/raphaelsty/LeNLP) of LeNLP. 4 | 5 | ```python 6 | >>> from cherche import retrieve 7 | 8 | >>> documents = [ 9 | ... { 10 | ... "id": 0, 11 | ... "article": "Paris is the capital and most populous city of France", 12 | ... "title": "Paris", 13 | ... "url": "https://en.wikipedia.org/wiki/Paris" 14 | ... }, 15 | ... { 16 | ... "id": 1, 17 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 18 | ... "title": "Paris", 19 | ... "url": "https://en.wikipedia.org/wiki/Paris" 20 | ... }, 21 | ... { 22 | ... "id": 2, 23 | ... 
"article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 24 | ... "title": "Paris", 25 | ... "url": "https://en.wikipedia.org/wiki/Paris" 26 | ... } 27 | ... ] 28 | 29 | >>> retriever = retrieve.BM25(key="id", on=["title", "article"], documents=documents, k=30) 30 | 31 | >>> retriever("france") 32 | [{'id': 0, 'similarity': 0.1236413097778466}, 33 | {'id': 2, 'similarity': 0.08907655343363269}, 34 | {'id': 1, 'similarity': 0.0031730868527342104}] 35 | ``` 36 | 37 | We can also initialize the retriever with a custom [sparse.BM25Vectorizer](https://github.com/raphaelsty/LeNLP). 38 | 39 | 40 | 41 | ```python 42 | >>> from cherche import retrieve 43 | >>> from lenlp import sparse 44 | 45 | >>> documents = [ 46 | ... { 47 | ... "id": 0, 48 | ... "article": "Paris is the capital and most populous city of France", 49 | ... "title": "Paris", 50 | ... "url": "https://en.wikipedia.org/wiki/Paris" 51 | ... }, 52 | ... { 53 | ... "id": 1, 54 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 55 | ... "title": "Paris", 56 | ... "url": "https://en.wikipedia.org/wiki/Paris" 57 | ... }, 58 | ... { 59 | ... "id": 2, 60 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 61 | ... "title": "Paris", 62 | ... "url": "https://en.wikipedia.org/wiki/Paris" 63 | ... } 64 | ... ] 65 | 66 | >>> count_vectorizer = sparse.BM25Vectorizer( 67 | ... normalize=True, ngram_range=(3, 7), analyzer="char_wb") 68 | 69 | >>> retriever = retrieve.BM25Vectorizer( 70 | ... key="id", on=["title", "article"], documents=documents, count_vectorizer=count_vectorizer) 71 | 72 | >>> retriever("fra", k=3) 73 | [{'id': 0, 'similarity': 0.15055477454160002}, 74 | {'id': 2, 'similarity': 0.022883459495904895}] 75 | ``` 76 | 77 | ## Batch retrieval 78 | 79 | If we have several queries for which we want to retrieve the top k documents then we can 80 | pass a list of queries to the retriever. This is much faster for multiple queries. In batch-mode, 81 | retriever returns a list of list of documents instead of a list of documents. 82 | 83 | ```python 84 | >>> retriever(["fra", "arts", "capital"], k=3) 85 | [[{'id': 0, 'similarity': 0.051000705070125066}, # Match query 1 86 | {'id': 2, 'similarity': 0.03415513704304113}], 87 | [{'id': 1, 'similarity': 0.07021399356970497}], # Match query 2 88 | [{'id': 0, 'similarity': 0.25972148184421534}]] # Match query 3 89 | ``` 90 | 91 | ## Map keys to documents 92 | 93 | We can map documents to retrieved keys. 94 | 95 | ```python 96 | >>> retriever += documents 97 | >>> retriever("fra") 98 | [{'id': 0, 99 | 'article': 'Paris is the capital and most populous city of France', 100 | 'title': 'Paris', 101 | 'url': 'https://en.wikipedia.org/wiki/Paris', 102 | 'similarity': 0.15055477454160002}, 103 | {'id': 2, 104 | 'article': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France .', 105 | 'title': 'Paris', 106 | 'url': 'https://en.wikipedia.org/wiki/Paris', 107 | 'similarity': 0.022883459495904895}] 108 | ``` 109 | -------------------------------------------------------------------------------- /docs/retrieve/flash.md: -------------------------------------------------------------------------------- 1 | # Flash 2 | 3 | Flash is a wrapper of [FlashText](https://github.com/vi3k6i5/flashtext). 
This great algorithm can 4 | retrieve keywords in documents faster than anything else. We can find more information about Flash in [Replace or Retrieve Keywords In Documents At Scale](https://arxiv.org/pdf/1711.00046.pdf). 5 | 6 | We can use Flash to find documents from a field that contains a keyword or a list of keywords. 7 | Flash will find documents that contain the keyword or keywords specified in the query. 8 | 9 | We can update the Flash retriever with new documents using mini-batch via the `add` method. 10 | 11 | ```python 12 | >>> from cherche import retrieve 13 | 14 | >>> documents = [ 15 | ... { 16 | ... "id": 0, 17 | ... "article": "Paris is the capital and most populous city of France", 18 | ... "title": "Paris", 19 | ... "url": "https://en.wikipedia.org/wiki/Paris", 20 | ... "tags": ["paris", "france", "capital"] 21 | ... }, 22 | ... { 23 | ... "id": 1, 24 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 25 | ... "title": "Paris", 26 | ... "url": "https://en.wikipedia.org/wiki/Paris", 27 | ... "tags": ["paris", "france", "capital", "fashion"] 28 | ... }, 29 | ... { 30 | ... "id": 2, 31 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 32 | ... "title": "Paris", 33 | ... "url": "https://en.wikipedia.org/wiki/Paris", 34 | ... "tags": "paris" 35 | ... } 36 | ... ] 37 | 38 | >>> retriever = retrieve.Flash(key="id", on="tags") 39 | 40 | >>> retriever.add(documents=documents) 41 | 42 | >>> retriever("fashion") 43 | [{'id': 1}] 44 | ``` 45 | 46 | ## Map keys to documents 47 | 48 | ```python 49 | >>> retriever += documents 50 | >>> retriever("fashion") 51 | [{'id': 1, 52 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 53 | 'title': 'Paris', 54 | 'url': 'https://en.wikipedia.org/wiki/Paris', 55 | 'tags': ['paris', 'france', 'capital', 'fashion']}] 56 | ``` 57 | -------------------------------------------------------------------------------- /docs/retrieve/fuzz.md: -------------------------------------------------------------------------------- 1 | # Fuzz 2 | 3 | `retrieve.Fuzz` is a wrapper of [RapidFuzz](https://github.com/maxbachmann/RapidFuzz). It is a blazing fast library dedicated to fuzzy string matching. Documents can be indexed online with this retriever using the `add` method. 4 | 5 | [RapidFuzz](https://github.com/maxbachmann/RapidFuzz) provides more scoring functions for the fuzzy string matching task. We can select the most suitable method for our dataset with the `fuzzer` parameter. The default scoring function is `fuzz.partial_ratio`. 6 | 7 | ```python 8 | >>> from cherche import retrieve 9 | >>> from rapidfuzz import fuzz 10 | 11 | >>> documents = [ 12 | ... { 13 | ... "id": 0, 14 | ... "article": "Paris is the capital and most populous city of France", 15 | ... "title": "Paris", 16 | ... "url": "https://en.wikipedia.org/wiki/Paris" 17 | ... }, 18 | ... { 19 | ... "id": 1, 20 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 21 | ... "title": "Paris", 22 | ... "url": "https://en.wikipedia.org/wiki/Paris" 23 | ... }, 24 | ... { 25 | ... "id": 2, 26 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 27 | ... "title": "Paris", 28 | ... 
"url": "https://en.wikipedia.org/wiki/Paris" 29 | ... } 30 | ... ] 31 | 32 | # List of available scoring function 33 | >>> scoring = [ 34 | ... fuzz.ratio, 35 | ... fuzz.partial_ratio, 36 | ... fuzz.token_set_ratio, 37 | ... fuzz.partial_token_set_ratio, 38 | ... fuzz.token_sort_ratio, 39 | ... fuzz.partial_token_sort_ratio, 40 | ... fuzz.token_ratio, 41 | ... fuzz.partial_token_ratio, 42 | ... fuzz.WRatio, 43 | ... fuzz.QRatio, 44 | ... ] 45 | 46 | >>> retriever = retrieve.Fuzz( 47 | ... key = "id", 48 | ... on = ["title", "article"], 49 | ... fuzzer = fuzz.partial_ratio, # Choose the scoring function. 50 | ... ) 51 | 52 | # Index documents 53 | >>> retriever.add(documents) 54 | 55 | >>> retriever("fashion", k=2) 56 | [{'id': 1, 'similarity': 100.0}, {'id': 0, 'similarity': 46.15384615384615}] 57 | ``` 58 | 59 | ## Batch retrieval 60 | 61 | If we have several queries for which we want to retrieve the top k documents then we can 62 | pass a list of queries to the retriever. In batch-mode, retriever returns a list of list of 63 | documents instead of a list of documents. 64 | 65 | ```python 66 | >>> retriever(["france", "arts", "capital"], k=30) 67 | [[{'id': 0, 'similarity': 100.0}, # Match query 1 68 | {'id': 2, 'similarity': 100.0}, 69 | {'id': 1, 'similarity': 66.66666666666667}], 70 | [{'id': 1, 'similarity': 100.0}, # Match query 2 71 | {'id': 0, 'similarity': 75.0}, 72 | {'id': 2, 'similarity': 75.0}], 73 | [{'id': 0, 'similarity': 100.0}, # Match query 3 74 | {'id': 1, 'similarity': 44.44444444444444}, 75 | {'id': 2, 'similarity': 44.44444444444444}]] 76 | ``` 77 | 78 | ## Map keys to documents 79 | 80 | We can map documents to retrieved keys. 81 | 82 | ```python 83 | >>> retriever += documents 84 | >>> retriever("fashion", k=30) 85 | [{'id': 1, 86 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 87 | 'title': 'Paris', 88 | 'url': 'https://en.wikipedia.org/wiki/Paris', 89 | 'similarity': 100.0}, 90 | {'id': 0, 91 | 'article': 'Paris is the capital and most populous city of France', 92 | 'title': 'Paris', 93 | 'url': 'https://en.wikipedia.org/wiki/Paris', 94 | 'similarity': 46.15384615384615}, 95 | {'id': 2, 96 | 'article': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France .', 97 | 'title': 'Paris', 98 | 'url': 'https://en.wikipedia.org/wiki/Paris', 99 | 'similarity': 46.15384615384615}] 100 | ``` 101 | -------------------------------------------------------------------------------- /docs/retrieve/lunr.md: -------------------------------------------------------------------------------- 1 | # Lunr 2 | 3 | `retrieve.Lunr` is a wrapper of [Lunr.py](https://github.com/yeraydiazdiaz/lunr.py). It is a powerful and practical solution for searching inside a corpus of documents without using a retriever such as Elasticsearch when it is not needed. Lunr stores an inverted index in memory. 4 | 5 | ```python 6 | >>> from cherche import retrieve 7 | 8 | >>> documents = [ 9 | ... { 10 | ... "id": 0, 11 | ... "article": "Paris is the capital and most populous city of France", 12 | ... "title": "Paris", 13 | ... "url": "https://en.wikipedia.org/wiki/Paris" 14 | ... }, 15 | ... { 16 | ... "id": 1, 17 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 18 | ... "title": "Paris", 19 | ... "url": "https://en.wikipedia.org/wiki/Paris" 20 | ... }, 21 | ... { 22 | ... 
"id": 2, 23 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 24 | ... "title": "Paris", 25 | ... "url": "https://en.wikipedia.org/wiki/Paris" 26 | ... } 27 | ... ] 28 | 29 | >>> retriever = retrieve.Lunr(key="id", on=["title", "article"], documents=documents) 30 | 31 | >>> retriever("france", k=30) 32 | [{'id': 0, 'similarity': 0.605}, {'id': 2, 'similarity': 0.47}] 33 | ``` 34 | 35 | ## Batch retrieval 36 | 37 | If we have several queries for which we want to retrieve the top k documents then we can 38 | pass a list of queries to the retriever. In batch-mode, retriever returns a list of list of 39 | documents instead of a list of documents. 40 | 41 | ```python 42 | >>> retriever(["france", "arts", "capital"], k=30) 43 | [[{'id': 0, 'similarity': 0.605}, {'id': 2, 'similarity': 0.47}], # Match query 1 44 | [{'id': 1, 'similarity': 0.802}], # Match query 2 45 | [{'id': 0, 'similarity': 1.263}]] # Match query 3 46 | ``` 47 | 48 | ## Map keys to documents 49 | 50 | ```python 51 | >>> retriever += documents 52 | >>> retriever("arts") 53 | [{'id': 1, 54 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 55 | 'title': 'Paris', 56 | 'url': 'https://en.wikipedia.org/wiki/Paris', 57 | 'similarity': 0.802}] 58 | ``` 59 | -------------------------------------------------------------------------------- /docs/retrieve/retrieve.md: -------------------------------------------------------------------------------- 1 | # Retrieve 2 | 3 | Retrievers speed up the neural search pipeline by filtering out the majority of documents that are not relevant. Rankers (slower) will then pull up the most relevant documents based on semantic 4 | similarity. 5 | 6 | `retrieve.Encoder`, `retrieve.DPR` and `retrieve.Embedding` retrievers rely on semantic similarity, unlike the other retrievers, which match exact words. 7 | 8 | ## Retrievers 9 | 10 | Here is the list of available retrievers using Cherche: 11 | 12 | - `retrieve.TfIdf` 13 | - `retrieve.Lunr` 14 | - `retrieve.Flash` 15 | - `retrieve.Fuzz` 16 | - `retrieve.Encoder` 17 | - `retrieve.DPR` 18 | - `retrieve.Embedding` 19 | 20 | To use `retrieve.Encoder`, `retrieve.DPR` or `retrieve.Embedding` we will need to install cherche using: 21 | 22 | ```sh 23 | pip install "cherche[cpu]" 24 | ``` 25 | 26 | If we want to run semantic retrievers on GPU: 27 | 28 | ```sh 29 | pip install "cherche[gpu]" 30 | ``` 31 | 32 | ## Tutorial 33 | 34 | The main parameter of retrievers is `on`; it is the field(s) on which the retriever will perform the search. If multiple fields are specified, the retriever will concatenate all fields in the order provided. The `key` 35 | parameter is the name of the field that contain an unique identifier for the document. 36 | 37 | ```python 38 | >>> from cherche import retrieve 39 | 40 | >>> documents = [ 41 | ... { 42 | ... "id": 0, 43 | ... "article": "Paris is the capital and most populous city of France", 44 | ... "title": "Paris", 45 | ... }, 46 | ... { 47 | ... "id": 1, 48 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 49 | ... "title": "Paris", 50 | ... }, 51 | ... ] 52 | 53 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 54 | ``` 55 | 56 | Calling a retriever with a single query output a list of documents. 
57 | ```python 58 | retriever("paris", k=30) 59 | [{'id': 0, 'similarity': 0.21638007903488998}, 60 | {'id': 1, 'similarity': 0.13897776006154242}] 61 | ``` 62 | 63 | Calling a retriever with a list of queries output a list of list of documents. If we want to call 64 | retriever on multiples queries, we should opt for the following code: 65 | 66 | ```python 67 | retriever(["paris", "art", "finance"], k=30) 68 | [[{'id': 0, 'similarity': 0.21638007903488998}, # Query 1 69 | {'id': 1, 'similarity': 0.13897776006154242}], 70 | [{'id': 1, 'similarity': 0.03987124117278171}], # Query 2 71 | [{'id': 1, 'similarity': 0.15208763286878763}, # Query 3 72 | {'id': 0, 'similarity': 0.02564158475123616}]] 73 | ``` 74 | 75 | ## Parameters 76 | 77 | | Retriever | Add | Semantic | Batch optmized | 78 | |:------------------:|:---------:|:---------:|:-----------:| 79 | | retrieve.Encoder | ✅ | ✅ | ✅ | 80 | | retrieve.DPR | ✅ | ✅ | ✅ | 81 | | retrieve.Embedding | ✅ | ✅ | ✅ | 82 | | retrieve.Flash | ✅ | ❌ | ❌ | 83 | | retrieve.Fuzz | ✅ | ❌ | ❌ | 84 | | retrieve.TfIdf | ❌ | ❌ | ✅ | 85 | | retrieve.Lunr | ❌ | ❌ | ❌ | 86 | 87 | - Add: Retriever has a `.add(documents)` method to index new documents along the way. 88 | - Semantic: The Retriever is powered by a language model, enabling semantic similarity-based document retrieval. 89 | - Batch-Optimized: The Retriever is optimized for batch processing, with a batch_size parameter that can be adjusted to handle multiple queries efficiently. 90 | 91 | We can call retrievers with a k-parameter, which enables the selection of the number of documents to be retrieved. By default, the value of k is set to None, meaning that the retrievers will retrieve all documents that match the query. However, if a specific value for k is chosen, the retriever will only retrieve the top k documents that are most likely to match the query. 92 | 93 | ```python 94 | >>> retriever(["paris", "art"], k=3) 95 | [[{'id': 0, 'similarity': 0.21638007903488998}, 96 | {'id': 1, 'similarity': 0.13897776006154242}], 97 | [{'id': 1, 'similarity': 0.03987124117278171}]] 98 | ``` 99 | 100 | ## Matching indexes to documents 101 | 102 | It is possible to directly retrieve the content of the documents using the `+` operator between retriever and documents. Documents mapping is helpful if we want to plug our retriever on a `rank.CrossEncoder`. 103 | 104 | ```python 105 | >>> retriever += documents 106 | >>> retriever("Paris") 107 | [{'id': 0, 108 | 'article': 'Paris is the capital and most populous city of France', 109 | 'title': 'Paris', 110 | 'url': 'https://en.wikipedia.org/wiki/Paris'}, 111 | {'id': 1, 112 | 'article': 'Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.', 113 | 'title': 'Paris', 114 | 'url': 'https://en.wikipedia.org/wiki/Paris'}, 115 | {'id': 2, 116 | 'article': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France .', 117 | 'title': 'Paris', 118 | 'url': 'https://en.wikipedia.org/wiki/Paris'}] 119 | ``` 120 | -------------------------------------------------------------------------------- /docs/retrieve/tfidf.md: -------------------------------------------------------------------------------- 1 | # TfIdf 2 | 3 | Our TF-IDF retriever relies on the [sparse.TfidfVectorizer](https://github.com/raphaelsty/LeNLP) of Sklearn. It computes the dot product between the query TF-IDF vector and the documents TF-IDF matrix and retrieves the highest match. 
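As a rough sketch of that matching step, written here with scikit-learn's `TfidfVectorizer` purely for illustration (the names below are not Cherche's internals):

```python
>>> from sklearn.feature_extraction.text import TfidfVectorizer
>>> import numpy as np

>>> corpus = [
...     "Paris is the capital and most populous city of France",
...     "The City of Paris is the centre and seat of government of the region and province of Île-de-France .",
... ]

# Documents TF-IDF matrix of shape (n_documents, n_features).
>>> vectorizer = TfidfVectorizer()
>>> matrix = vectorizer.fit_transform(corpus)

# Query TF-IDF vector; a single dot product scores every document.
>>> query = vectorizer.transform(["france"])
>>> scores = (query @ matrix.T).toarray()[0]

# Indices of the documents sorted from the highest to the lowest score.
>>> ranking = np.argsort(-scores)
```
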
TfIdf retriever stores a sparse matrix and an index that links the rows of the matrix to document identifiers. 4 | 5 | ```python 6 | >>> from cherche import retrieve 7 | 8 | >>> documents = [ 9 | ... { 10 | ... "id": 0, 11 | ... "article": "Paris is the capital and most populous city of France", 12 | ... "title": "Paris", 13 | ... "url": "https://en.wikipedia.org/wiki/Paris" 14 | ... }, 15 | ... { 16 | ... "id": 1, 17 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 18 | ... "title": "Paris", 19 | ... "url": "https://en.wikipedia.org/wiki/Paris" 20 | ... }, 21 | ... { 22 | ... "id": 2, 23 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 24 | ... "title": "Paris", 25 | ... "url": "https://en.wikipedia.org/wiki/Paris" 26 | ... } 27 | ... ] 28 | 29 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents, k=30) 30 | 31 | >>> retriever("france") 32 | [{'id': 0, 'similarity': 0.15137222675009282}, 33 | {'id': 2, 'similarity': 0.10831402366399025}, 34 | {'id': 1, 'similarity': 0.02505818772920329}] 35 | ``` 36 | 37 | We can also initialize the retriever with a custom [sparse.TfidfVectorizer](https://github.com/raphaelsty/LeNLP). 38 | 39 | ```python 40 | >>> from cherche import retrieve 41 | >>> from lenlp import sparse 42 | 43 | >>> documents = [ 44 | ... { 45 | ... "id": 0, 46 | ... "article": "Paris is the capital and most populous city of France", 47 | ... "title": "Paris", 48 | ... "url": "https://en.wikipedia.org/wiki/Paris" 49 | ... }, 50 | ... { 51 | ... "id": 1, 52 | ... "article": "Paris has been one of Europe major centres of finance, diplomacy , commerce , fashion , gastronomy , science , and arts.", 53 | ... "title": "Paris", 54 | ... "url": "https://en.wikipedia.org/wiki/Paris" 55 | ... }, 56 | ... { 57 | ... "id": 2, 58 | ... "article": "The City of Paris is the centre and seat of government of the region and province of Île-de-France .", 59 | ... "title": "Paris", 60 | ... "url": "https://en.wikipedia.org/wiki/Paris" 61 | ... } 62 | ... ] 63 | 64 | >>> tfidf = sparse.TfidfVectorizer( 65 | ... normalize=True, ngram_range=(3, 7), analyzer="char_wb") 66 | 67 | >>> retriever = retrieve.TfIdf( 68 | ... key="id", on=["title", "article"], documents=documents, tfidf=tfidf) 69 | 70 | >>> retriever("fra", k=3) 71 | [{'id': 0, 'similarity': 0.15055477454160002}, 72 | {'id': 2, 'similarity': 0.022883459495904895}] 73 | ``` 74 | 75 | ## Batch retrieval 76 | 77 | If we have several queries for which we want to retrieve the top k documents then we can 78 | pass a list of queries to the retriever. This is much faster for multiple queries. In batch-mode, 79 | retriever returns a list of list of documents instead of a list of documents. 80 | 81 | ```python 82 | >>> retriever(["fra", "arts", "capital"], k=3) 83 | [[{'id': 0, 'similarity': 0.051000705070125066}, # Match query 1 84 | {'id': 2, 'similarity': 0.03415513704304113}], 85 | [{'id': 1, 'similarity': 0.07021399356970497}], # Match query 2 86 | [{'id': 0, 'similarity': 0.25972148184421534}]] # Match query 3 87 | ``` 88 | 89 | ## Map keys to documents 90 | 91 | We can map documents to retrieved keys. 
92 | 93 | ```python 94 | >>> retriever += documents 95 | >>> retriever("fra") 96 | [{'id': 0, 97 | 'article': 'Paris is the capital and most populous city of France', 98 | 'title': 'Paris', 99 | 'url': 'https://en.wikipedia.org/wiki/Paris', 100 | 'similarity': 0.15055477454160002}, 101 | {'id': 2, 102 | 'article': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France .', 103 | 'title': 'Paris', 104 | 'url': 'https://en.wikipedia.org/wiki/Paris', 105 | 'similarity': 0.022883459495904895}] 106 | ``` 107 | -------------------------------------------------------------------------------- /docs/serialize/.pages: -------------------------------------------------------------------------------- 1 | title: Save & Load 2 | nav: 3 | - serialize.md 4 | -------------------------------------------------------------------------------- /docs/serialize/serialize.md: -------------------------------------------------------------------------------- 1 | # Save & Load 2 | 3 | Serialization in Python saves an object on the disk to reload it during a new session. Using Cherche, we could prototype a neural search pipeline in a notebook before deploying it on an API. We can also save a neural search pipeline to avoid recomputing embeddings of the ranker. 4 | 5 | We must ensure that the package versions are identical in both environments (dumping and loading). 6 | 7 | ## Saving and loading on same environment 8 | 9 | ### Saving 10 | 11 | We will initialize and save our pipeline in a `search.pkl` file 12 | 13 | ```python 14 | >>> from cherche import data, retrieve, rank 15 | >>> from sentence_transformers import SentenceTransformer 16 | >>> import pickle 17 | 18 | >>> documents = data.load_towns() 19 | 20 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 21 | 22 | >>> ranker = rank.Encoder( 23 | ... key = "id", 24 | ... on = ["title", "article"], 25 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode, 26 | ... ) 27 | 28 | >>> search = retriever + ranker 29 | # Pre-compute embeddings of the ranker 30 | >>> search.add(documents=documents) 31 | 32 | # Dump our pipeline using pickle. 33 | # The file search.pkl contains our pipeline 34 | >>> with open("search.pkl", "wb") as search_file: 35 | ... pickle.dump(search, search_file) 36 | 37 | ``` 38 | 39 | ### Loading 40 | 41 | After saving our pipeline in the file `search.pkl`, we can reload it using Pickle. 42 | 43 | ```python 44 | >>> import pickle 45 | 46 | >>> with open("search.pkl", "rb") as search_file: 47 | ... search = pickle.load(search_file) 48 | 49 | >>> search("bordeaux", k=10) 50 | [{'id': 57, 'similarity': 0.69513476}, 51 | {'id': 63, 'similarity': 0.6214991}, 52 | {'id': 65, 'similarity': 0.61809057}, 53 | {'id': 59, 'similarity': 0.61285114}, 54 | {'id': 71, 'similarity': 0.5893674}, 55 | {'id': 67, 'similarity': 0.5893066}, 56 | {'id': 74, 'similarity': 0.58757037}, 57 | {'id': 61, 'similarity': 0.58593774}, 58 | {'id': 70, 'similarity': 0.5854107}, 59 | {'id': 66, 'similarity': 0.56525207}] 60 | ``` 61 | 62 | ## Saving on GPU, loading on CPU 63 | 64 | Typically, we could pre-compute the document integration on google collab with a GPU before 65 | deploying our neural search pipeline on a CPU-based instance. 66 | 67 | When transferring the pipeline that runs on the GPU to a machine that will run it on the CPU, it will be necessary to avoid serializing the `retrieve.Encoder`, `retrieve.DPR`, `rank.DPR` and `rank.Encoder`. 
These retrievers and rankers would not be compatible if we initialized them on GPU. We will have to replace the models on GPU to put them on CPU. We must ensure that the package versions are strictly identical in both environments (GPU and CPU). 68 | 69 | ### Saving on GPU 70 | 71 | ```python 72 | >>> from cherche import data, retrieve, rank 73 | >>> from sentence_transformers import SentenceTransformer 74 | >>> import pickle 75 | 76 | >>> documents = data.load_towns() 77 | 78 | >>> retriever = retrieve.TfIdf(key="id", on=["title", "article"], documents=documents) 79 | 80 | >>> ranker = rank.Encoder( 81 | ... key = "id", 82 | ... on = ["title", "article"], 83 | ... encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cuda").encode, 84 | ... ) 85 | 86 | >>> search = retriever + ranker 87 | # Pre-compute embeddings of the ranker 88 | >>> search.add(documents=documents) 89 | 90 | # Replace the GPU-based encoder with a CPU-based encoder. 91 | >>> ranker.encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").encode 92 | 93 | with open("search.pkl", "wb") as search_file: 94 | pickle.dump(search, search_file) 95 | ``` 96 | 97 | ### Loading on CPU 98 | 99 | We can load our neural search pipeline using `pickle.load` in a new session. 100 | 101 | ```python 102 | >>> import pickle 103 | 104 | >>> with open("search.pkl", "rb") as search_file: 105 | ... search = pickle.load(search_file) 106 | 107 | >>> search("bordeaux", k=10) 108 | [{'id': 57, 'similarity': 0.69513476}, 109 | {'id': 63, 'similarity': 0.6214991}, 110 | {'id': 65, 'similarity': 0.61809057}, 111 | {'id': 59, 'similarity': 0.61285114}, 112 | {'id': 71, 'similarity': 0.5893674}, 113 | {'id': 67, 'similarity': 0.5893066}, 114 | {'id': 74, 'similarity': 0.58757037}, 115 | {'id': 61, 'similarity': 0.58593774}, 116 | {'id': 70, 'similarity': 0.5854107}, 117 | {'id': 66, 'similarity': 0.56525207}] 118 | ``` 119 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | .md-typeset h2 { 2 | margin: 1.5em 0; 3 | padding-bottom: .4rem; 4 | border-bottom: .04rem solid var(--md-default-fg-color--lighter); 5 | } 6 | 7 | .md-footer { 8 | margin-top: 2em; 9 | } 10 | 11 | .md-typeset pre > code { 12 | border-radius: 0.5em; 13 | } -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | # Project information 2 | site_name: cherche 3 | site_description: Neural Search 4 | site_author: Raphael Sourty 5 | site_url: https://raphaelsty.github.io/cherche 6 | 7 | # Repository 8 | repo_name: raphaelsty/cherche 9 | repo_url: https://github.com/raphaelsty/cherche 10 | edit_uri: "" 11 | 12 | # Copyright 13 | copyright: Copyright © 2020 - 2021 14 | 15 | # Configuration 16 | theme: 17 | name: material 18 | language: en 19 | palette: 20 | primary: indigo 21 | accent: indigo 22 | font: 23 | text: Roboto 24 | code: Roboto Mono 25 | favicon: img/favicon.ico 26 | features: 27 | - navigation.tabs 28 | - navigation.instant 29 | 30 | # Extras 31 | extra: 32 | social: 33 | - icon: fontawesome/brands/github-alt 34 | link: https://github.com/raphaelsty/cherche 35 | analytics: 36 | provider: google 37 | property: G-NHQCHCD6L6 38 | 39 | # Extensions 40 | markdown_extensions: 41 | - admonition 42 | - footnotes 43 | - toc: 44 | permalink: true 45 | toc_depth: "1-3" 46 | - 
pymdownx.details 47 | - pymdownx.arithmatex: 48 | generic: true 49 | - pymdownx.highlight 50 | - pymdownx.superfences 51 | 52 | plugins: 53 | - search 54 | - awesome-pages 55 | - mkdocs-jupyter 56 | 57 | extra_javascript: 58 | - javascripts/config.js 59 | - https://polyfill.io/v3/polyfill.min.js?features=es6 60 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 61 | - js/version-select.js 62 | 63 | extra_css: 64 | - stylesheets/extra.css 65 | - css/version-select.css 66 | 67 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::DeprecationWarning 4 | ignore::RuntimeWarning 5 | ignore::UserWarning 6 | addopts = 7 | --doctest-modules 8 | --verbose 9 | -ra 10 | --cov-config=.coveragerc 11 | -m "not web and not slow" 12 | doctest_optionflags = NORMALIZE_WHITESPACE NUMBER 13 | norecursedirs = 14 | build 15 | docs 16 | node_modules 17 | markers = 18 | web: tests that require using the Internet 19 | slow: tests that take a long time to run -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Inside of setup.cfg 2 | [metadata] 3 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | from cherche.__version__ import __version__ 4 | 5 | with open("README.md", "r", encoding="utf-8") as fh: 6 | long_description = fh.read() 7 | 8 | base_packages = [ 9 | "numpy >= 1.23.5", 10 | "scikit-learn >= 1.5.0", 11 | "lunr >= 0.6.2", 12 | "rapidfuzz >= 3.0.0", 13 | "flashtext >= 2.7", 14 | "tqdm >= 4.62.3", 15 | "scipy >= 1.7.3", 16 | "lenlp == 1.1.0", 17 | "sentence-transformers >= 3.0.0", 18 | ] 19 | 20 | cpu = ["sentence-transformers >= 3.0.0", "faiss-cpu >= 1.7.4"] 21 | gpu = ["sentence-transformers >= 3.0.0", "faiss-gpu >= 1.7.4"] 22 | dev = [ 23 | "numpydoc >= 1.4.0", 24 | "mkdocs_material >= 8.3.5", 25 | "mkdocs-awesome-pages-plugin >= 2.7.0", 26 | "mkdocs-jupyter >= 0.21.0", 27 | "pytest-cov >= 4.0.0", 28 | "pytest >= 7.3.1", 29 | "isort >= 5.12.0", 30 | "ipywidgets >= 8.0.6", 31 | ] 32 | 33 | setuptools.setup( 34 | name="cherche", 35 | version=f"{__version__}", 36 | license="MIT", 37 | author="Raphael Sourty", 38 | author_email="raphael.sourty@gmail.com", 39 | description="Neural Search", 40 | long_description=long_description, 41 | long_description_content_type="text/markdown", 42 | url="https://github.com/raphaelsty/cherche", 43 | download_url="https://github.com/user/cherche/archive/v_01.tar.gz", 44 | keywords=[ 45 | "neural search", 46 | "information retrieval", 47 | "question answering", 48 | "semantic search", 49 | ], 50 | packages=setuptools.find_packages(), 51 | install_requires=base_packages, 52 | extras_require={ 53 | "cpu": base_packages + cpu, 54 | "gpu": base_packages + gpu, 55 | "dev": base_packages + cpu + dev, 56 | }, 57 | package_data={"cherche": ["data/towns.json", "data/semanlink/*.json"]}, 58 | classifiers=[ 59 | "Programming Language :: Python :: 3", 60 | "License :: OSI Approved :: MIT License", 61 | "Operating System :: OS Independent", 62 | ], 63 | python_requires=">=3.6", 64 | ) 65 | --------------------------------------------------------------------------------