├── tests ├── __init__.py ├── test_data.py ├── tests │ └── test_frontend.py ├── data │ └── sample_documents.jsonl └── test_index.py ├── src └── spacerini │ ├── data │ ├── utils.py │ ├── __main__.py │ ├── __init__.py │ └── load.py │ ├── preprocess │ ├── __main__.py │ ├── __init__.py │ ├── tokenize.py │ └── utils.py │ ├── spacerini_utils │ ├── __init__.py │ ├── index.py │ └── search.py │ ├── prebuilt.py │ ├── __init__.py │ ├── search │ ├── __init__.py │ └── utils.py │ ├── frontend │ ├── __init__.py │ ├── local.py │ ├── space.py │ └── __main__.py │ ├── index │ ├── __init__.py │ ├── utils.py │ ├── encode.py │ └── index.py │ └── cli.py ├── templates ├── streamlit │ ├── cookiecutter.json │ └── {{ cookiecutter.module_slug }} │ │ ├── index │ │ └── .gitkeep │ │ ├── packages.txt │ │ ├── requirements.txt │ │ ├── .gitattributes │ │ ├── README.md │ │ └── app.py ├── gradio │ ├── {{ cookiecutter.local_app }} │ │ ├── data │ │ │ └── .gitkeep │ │ ├── index │ │ │ └── .gitkeep │ │ ├── packages.txt │ │ ├── requirements.txt │ │ ├── .gitattributes │ │ ├── README.md │ │ └── app.py │ └── cookiecutter.json └── gradio_roots_temp │ ├── {{ cookiecutter.local_app }} │ ├── index │ │ └── .gitkeep │ ├── packages.txt │ ├── requirements.txt │ ├── .gitattributes │ ├── README.md │ └── app.py │ └── cookiecutter.json ├── examples ├── configs │ └── xsum.json ├── scripts │ ├── index-push-pull.py │ ├── xsum-demo.py │ └── gradio-demo.py └── notebooks │ └── indexing_tutorial.ipynb ├── docs └── arguments.md ├── pyproject.toml ├── .gitignore ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spacerini/data/utils.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tests/test_frontend.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spacerini/data/__main__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spacerini/preprocess/__main__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /templates/streamlit/cookiecutter.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spacerini/spacerini_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spacerini/prebuilt.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = { 2 | 'xsum', 3 | } -------------------------------------------------------------------------------- /templates/gradio/{{ cookiecutter.local_app }}/data/.gitkeep: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /templates/gradio/{{ cookiecutter.local_app }}/index/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spacerini/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tokenize, utils -------------------------------------------------------------------------------- /templates/streamlit/{{ cookiecutter.module_slug }}/index/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /templates/gradio_roots_temp/{{ cookiecutter.local_app }}/index/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /templates/gradio/{{ cookiecutter.local_app }}/packages.txt: -------------------------------------------------------------------------------- 1 | default-jdk 2 | -------------------------------------------------------------------------------- /src/spacerini/__init__.py: -------------------------------------------------------------------------------- 1 | from . import index, preprocess, frontend, search, data -------------------------------------------------------------------------------- /templates/streamlit/{{ cookiecutter.module_slug }}/packages.txt: -------------------------------------------------------------------------------- 1 | default-jdk 2 | -------------------------------------------------------------------------------- /templates/gradio_roots_temp/{{ cookiecutter.local_app }}/packages.txt: -------------------------------------------------------------------------------- 1 | default-jdk 2 | -------------------------------------------------------------------------------- /templates/streamlit/{{ cookiecutter.module_slug }}/requirements.txt: -------------------------------------------------------------------------------- 1 | pyserini 2 | faiss-cpu 3 | torch -------------------------------------------------------------------------------- /templates/gradio/{{ cookiecutter.local_app }}/requirements.txt: -------------------------------------------------------------------------------- 1 | pyserini 2 | datasets 3 | faiss-cpu 4 | torch -------------------------------------------------------------------------------- /templates/gradio/{{ cookiecutter.local_app }}/.gitattributes: -------------------------------------------------------------------------------- 1 | index/**/* filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /src/spacerini/search/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import utils 2 | from .utils import init_searcher, result_indices, result_page -------------------------------------------------------------------------------- /templates/gradio_roots_temp/{{ cookiecutter.local_app }}/requirements.txt: -------------------------------------------------------------------------------- 1 | pyserini 2 | datasets 3 | faiss-cpu 4 | torch -------------------------------------------------------------------------------- /templates/streamlit/{{ cookiecutter.module_slug }}/.gitattributes: -------------------------------------------------------------------------------- 1 | index/**/* filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /templates/gradio_roots_temp/{{ cookiecutter.local_app }}/.gitattributes: -------------------------------------------------------------------------------- 1 | index/**/* filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /src/spacerini/frontend/__init__.py: -------------------------------------------------------------------------------- 1 | from . import local, space 2 | from .space import create_space_from_local 3 | from .local import create_app -------------------------------------------------------------------------------- /tests/data/sample_documents.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "doc1", "contents": "contents of doc one."} 2 | {"id": "doc2", "contents": "contents of document two."} 3 | {"id": "doc3", "contents": "here's some text in document three."} -------------------------------------------------------------------------------- /src/spacerini/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import load, utils 2 | from .load import load_from_hub, load_from_pandas, load_ir_dataset, load_ir_dataset_low_memory, load_ir_dataset_streaming, load_from_local, load_from_sqlite_table -------------------------------------------------------------------------------- /src/spacerini/preprocess/tokenize.py: -------------------------------------------------------------------------------- 1 | def batch_iterator(hf_dataset, text_field, batch_size=1000): 2 | for i in range(0, len(hf_dataset), batch_size): 3 | yield hf_dataset.select(range(i, i + batch_size))[text_field] -------------------------------------------------------------------------------- /src/spacerini/index/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import index 2 | from .encode import encode_json_dataset 3 | from .index import index_json_shards, index_streaming_dataset 4 | from .index import fetch_index_stats 5 | from .utils import push_index_to_hub, load_index_from_hub 6 | -------------------------------------------------------------------------------- /templates/gradio/cookiecutter.json: -------------------------------------------------------------------------------- 1 | { 2 | "dset_text_field": null, 3 | "metadata_field": null, 4 | "space_title": null, 5 | "local_app": "{{ cookiecutter.space_title|lower|replace(' ', '-') }}", 6 | "space_gradio_sdk_version": "3.29.0", 7 | "space_license": "apache-2.0", 8 | "space_pinned": "false" 9 | } 10 | -------------------------------------------------------------------------------- /templates/gradio_roots_temp/cookiecutter.json: -------------------------------------------------------------------------------- 1 | { 2 | "space_title": null, 3 | "local_app": "{{ cookiecutter.space_title|lower|replace(' ', '-') }}", 4 | "space_gradio_sdk_version": "3.12.0", 5 | "space_license": "apache-2.0", 6 | "space_pinned": "false", 7 | "emoji": "🚀", 8 | "space_description": null, 9 | "private": false, 10 | "dataset_name": null 11 | } -------------------------------------------------------------------------------- /templates/streamlit/{{ cookiecutter.module_slug }}/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: {{cookiecutter.space_title}} 3 | emoji: 🐠 4 | colorFrom: blue 5 | colorTo: blue 6 | sdk: streamlit 7 | sdk_version: 1.25.0 8 | app_file: app.py 9 | pinned: false 10 | license: {{cookiecutter.space_license}} 11 | --- 12 | 13 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference 14 | -------------------------------------------------------------------------------- /examples/configs/xsum.json: -------------------------------------------------------------------------------- 1 | { 2 | "delete_after": true, 3 | "space_name": "xsum", 4 | "space_url_slug": "xsumcli", 5 | "sdk": "gradio", 6 | "organization": null, 7 | "description": null, 8 | "content_column": "document", 9 | "dataset": "xsum", 10 | "docid-column": "id", 11 | "split": "test", 12 | "template": "gradio_roots_temp", 13 | "store_raw": true, 14 | "store_contents": true 15 | } -------------------------------------------------------------------------------- /templates/gradio/{{ cookiecutter.local_app }}/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: {{ cookiecutter.space_title }} 3 | emoji: 🐠 4 | colorFrom: blue 5 | colorTo: blue 6 | sdk: gradio 7 | sdk_version: {{ cookiecutter.space_gradio_sdk_version }} 8 | app_file: app.py 9 | pinned: {{ cookiecutter.space_pinned }} 10 | license: {{ cookiecutter.space_license }} 11 | --- 12 | 13 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference -------------------------------------------------------------------------------- /templates/gradio_roots_temp/{{ cookiecutter.local_app }}/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: {{ cookiecutter.space_title }} 3 | emoji: {{ cookiecutter.emoji }} 4 | colorFrom: blue 5 | colorTo: blue 6 | sdk: gradio 7 | sdk_version: {{ cookiecutter.space_gradio_sdk_version }} 8 | app_file: app.py 9 | pinned: {{ cookiecutter.space_pinned }} 10 | license: {{ cookiecutter.space_license }} 11 | --- 12 | 13 | 
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference -------------------------------------------------------------------------------- /src/spacerini/spacerini_utils/index.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | from typing import Dict 4 | 5 | from pyserini.index.lucene import IndexReader 6 | 7 | 8 | def fetch_index_stats(index_path: str) -> Dict[str, Any]: 9 | """ 10 | Fetch index statistics 11 | index_path : str 12 | Path to index directory 13 | Returns 14 | ------- 15 | Dictionary of index statistics 16 | Dictionary Keys ==> total_terms, documents, unique_terms 17 | """ 18 | assert os.path.exists(index_path), f"Index path {index_path} does not exist" 19 | index_reader = IndexReader(index_path) 20 | return index_reader.stats() 21 | -------------------------------------------------------------------------------- /docs/arguments.md: -------------------------------------------------------------------------------- 1 | # List of Parameters 2 | 3 | ## Indexing (`spacerini.index`) 4 | 5 | - `disable_tqdm` : bool 6 | Disable tqdm output 7 | - `index_path` : str 8 | Directory to store index 9 | - `language` : str 10 | Language of dataset 11 | - `pretokenized` : bool 12 | If True, dataset is already tokenized 13 | - `analyzeWithHuggingFaceTokenizer` : str 14 | If True, use HuggingFace tokenizer to tokenize dataset 15 | - `storePositions` : bool 16 | If True, store positions of tokens in index 17 | - `storeDocvectors` : bool 18 | If True, store document vectors in index 19 | - `storeContents` : bool 20 | If True, store contents of documents in index 21 | - `storeRaw` : bool 22 | If True, store raw contents of documents in index 23 | - `keepStopwords` : bool 24 | If True, keep stopwords in index 25 | - `stopwords` : str 26 | Path to stopwords file 27 | - `stemmer` : str 28 | Stemmer to use for indexing 29 | - `optimize` : bool 30 | If True, optimize index after indexing is complete 31 | - `verbose` : bool 32 | If True, print verbose output 33 | - `quiet` : bool 34 | If True, print no output 35 | - `memory_buffer` : str 36 | Memory buffer size 37 | - `n_threads` : bool 38 | Number of threads to use for indexing -------------------------------------------------------------------------------- /templates/streamlit/{{ cookiecutter.module_slug }}/app.py: -------------------------------------------------------------------------------- 1 | # This currently contains Odunayo's template. We need to adapt this to cookiecutter. 2 | import streamlit as st 3 | from pyserini.search.lucene import LuceneSearcher 4 | import json 5 | import time 6 | 7 | st.set_page_config(page_title="{{title}}", page_icon='', layout="centered") 8 | searcher = LuceneSearcher('{{index_path}}') 9 | 10 | 11 | col1, col2 = st.columns([9, 1]) 12 | with col1: 13 | search_query = st.text_input(label="", placeholder="Search") 14 | 15 | with col2: 16 | st.write('#') 17 | button_clicked = st.button("🔎") 18 | 19 | 20 | if search_query or button_clicked: 21 | num_results = None 22 | 23 | t_0 = time.time() 24 | search_results = searcher.search(search_query, k=100_000) 25 | search_time = time.time() - t_0 26 | 27 | st.write(f'

Retrieved {len(search_results):,.0f} documents in {search_time*1000:.2f} ms
', unsafe_allow_html=True) 28 | for result in search_results[:10]: 29 | result = json.loads(result.raw) 30 | doc = result["contents"] 31 | result_id = result["id"] 32 | try: 33 | st.write(doc[:1000], unsafe_allow_html=True) 34 | st.write(f'
Document ID: {result_id}
', unsafe_allow_html=True) 35 | 36 | except: 37 | pass 38 | 39 | st.write('---') -------------------------------------------------------------------------------- /examples/scripts/index-push-pull.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from spacerini.preprocess.utils import shard_dataset 3 | from spacerini.index.index import index_json_shards 4 | from spacerini.index.utils import push_index_to_hub, load_index_from_hub 5 | from spacerini.search.utils import result_indices 6 | 7 | DSET = "imdb" 8 | SPLIT = "train" 9 | COLUMN_TO_INDEX = "text" 10 | NUM_PROC = 28 11 | SHARDS_PATH = f"{DSET}-json-shards" 12 | TEST_QUERY = "great movie" 13 | NUM_RESULTS = 5 14 | INDEX_PATH = "index" 15 | DATASET_SLUG = "lucene-imdb-train" 16 | 17 | dset = load_dataset( 18 | DSET, 19 | split=SPLIT 20 | ) 21 | 22 | shard_dataset( 23 | hf_dataset=dset, 24 | shard_size="10MB", 25 | column_to_index=COLUMN_TO_INDEX, 26 | shards_paths=SHARDS_PATH 27 | ) 28 | 29 | index_json_shards( 30 | shards_path=SHARDS_PATH, 31 | keep_shards=False, 32 | index_path= INDEX_PATH, 33 | language="en", 34 | n_threads=NUM_PROC 35 | ) 36 | 37 | print( 38 | f"First {NUM_RESULTS} results for query: \"{TEST_QUERY}\"", 39 | result_indices(TEST_QUERY, NUM_RESULTS, INDEX_PATH) 40 | ) 41 | 42 | push_index_to_hub( 43 | dataset_slug=DATASET_SLUG, 44 | index_path="index", 45 | delete_after_push=True 46 | ) 47 | 48 | new_index_path = load_index_from_hub(DATASET_SLUG) 49 | print( 50 | f"First {NUM_RESULTS} results for query: \"{TEST_QUERY}\"", 51 | result_indices(TEST_QUERY, NUM_RESULTS, new_index_path) 52 | ) -------------------------------------------------------------------------------- /src/spacerini/frontend/local.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from shutil import copytree 3 | 4 | from cookiecutter.main import cookiecutter 5 | 6 | default_templates_dir = (Path(__file__).parents[3] / "templates").resolve() 7 | LOCAL_TEMPLATES = ["gradio", "streamlit", "gradio_roots_temp"] 8 | 9 | 10 | def create_app( 11 | template: str, 12 | extra_context_dict: dict, 13 | output_dir: str, 14 | no_input: bool = True, 15 | overwrite_if_exists: bool=True 16 | ) -> None: 17 | """ 18 | Create a new app from a template. 19 | Parameters 20 | ---------- 21 | template : str 22 | The name of the template to use. 23 | extra_context_dict : dict 24 | The extra context to pass to the template. 25 | output_dir : str 26 | The output directory. 27 | no_input : bool, optional 28 | If True, do not prompt for parameters and only use 29 | overwrite_if_exists : bool, optional 30 | If True, overwrite the output directory if it already exists. 
31 | Returns 32 | ------- 33 | None 34 | """ 35 | cookiecutter( 36 | "https://github.com/castorini/hf-spacerini.git/" if template in LOCAL_TEMPLATES else template, 37 | directory="templates/" + template if template in LOCAL_TEMPLATES else None, 38 | no_input=no_input, 39 | extra_context=extra_context_dict, 40 | output_dir=output_dir, 41 | overwrite_if_exists=overwrite_if_exists, 42 | ) 43 | 44 | utils_dir = Path(__file__).parents[1].resolve() / "spacerini_utils" 45 | app_dir = Path(output_dir) / extra_context_dict["local_app"] / "spacerini_utils" 46 | copytree(utils_dir, app_dir, dirs_exist_ok=True) 47 | 48 | return None 49 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64.0.0", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [project] 7 | name = "spacerini" 8 | description = "Hosted Lucene indexes with Pyserini and Hugging Face Spaces" 9 | readme = { file = "README.md", content-type = "text/markdown" } 10 | maintainers = [ 11 | { name = "Christopher Akiki", email = "christopher.akiki@gmail.com" }, 12 | { name = "Ogundepo Odunayo", email = "ogundepoodunayo@gmail.com"}, 13 | { name = "Akintunde Oladipo", email = "akin.o.oladipo@gmail.com"} 14 | ] 15 | requires-python = ">=3.8" 16 | dependencies = [ 17 | 'pyserini', 18 | 'cookiecutter', 19 | 'huggingface_hub', 20 | 'tokenizers', 21 | 'datasets>=2.8.0', 22 | 'ir_datasets', 23 | 'streamlit', 24 | 'gradio', 25 | 'torch', 26 | 'faiss-cpu', 27 | ] 28 | dynamic = [ 29 | "version", 30 | ] 31 | classifiers = [ 32 | 'Development Status :: 3 - Alpha', 33 | 'Intended Audience :: Developers', 34 | 'Intended Audience :: Information Technology', 35 | 'Intended Audience :: Science/Research', 36 | 'License :: OSI Approved :: Apache Software License', 37 | 'Programming Language :: Python', 38 | 'Topic :: Software Development :: Libraries :: Python Modules', 39 | 'Operating System :: OS Independent', 40 | 'Programming Language :: Python :: 3', 41 | 'Programming Language :: Python :: 3.8', 42 | 'Programming Language :: Python :: 3.9', 43 | 'Programming Language :: Python :: 3.10', 44 | ] 45 | license = { text = "Apache-2.0" } 46 | 47 | [project.optional-dependencies] 48 | dev = [ 49 | 'pytest', 50 | ] 51 | 52 | [project.urls] 53 | Homepage = "https://github.com/castorini/hf-spacerini" 54 | 55 | [project.scripts] 56 | spacerini = "spacerini.cli:main" -------------------------------------------------------------------------------- /examples/scripts/xsum-demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains a demo of Spacerini using Train Collection of the XSum dataset. 
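Note: despite the mention of the train collection above, the script as written
indexes the "test" split of XSum (see SPLIT below). A typical invocation, assuming
`huggingface-cli login` has been run and ORGANIZATION is changed to a namespace you
control, is:

    python examples/scripts/xsum-demo.py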
3 | """ 4 | import os 5 | import logging 6 | from spacerini.frontend import create_app, create_space_from_local 7 | from spacerini.index import index_streaming_dataset 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | 11 | DATASET = "xsum" 12 | SPLIT = "test" 13 | SPACE_TITLE = "XSum Train Dataset Search" 14 | COLUMN_TO_INDEX = ["document"] 15 | LOCAL_APP = "xsum-demo" 16 | SDK = "gradio" 17 | TEMPLATE = "gradio_roots_temp" 18 | ORGANIZATION = "ToluClassics" 19 | 20 | 21 | cookiecutter_vars = { 22 | "dset_text_field": COLUMN_TO_INDEX, 23 | "space_title": SPACE_TITLE, 24 | "local_app":LOCAL_APP, 25 | "space_description": "This is a demo of Spacerini using the XSum dataset.", 26 | "dataset_name": "xsum" 27 | } 28 | 29 | 30 | 31 | logging.info(f"Creating local app into {LOCAL_APP} directory") 32 | create_app( 33 | template=TEMPLATE, 34 | extra_context_dict=cookiecutter_vars, 35 | output_dir="apps" 36 | ) 37 | 38 | logging.info(f"Indexing {DATASET} dataset into {os.path.join('apps', LOCAL_APP, 'index')}") 39 | index_streaming_dataset( 40 | dataset_name_or_path=DATASET, 41 | index_path= os.path.join("apps", LOCAL_APP, "index"), 42 | split=SPLIT, 43 | column_to_index=COLUMN_TO_INDEX, 44 | doc_id_column="id", 45 | storeContents=True, 46 | storeRaw=True, 47 | language="en" 48 | ) 49 | 50 | logging.info(f"Creating space {SPACE_TITLE} on {ORGANIZATION}") 51 | create_space_from_local( 52 | space_slug="xsum-test", 53 | organization=ORGANIZATION, 54 | space_sdk=SDK, 55 | local_dir=os.path.join("apps", LOCAL_APP), 56 | delete_after_push=False 57 | ) -------------------------------------------------------------------------------- /src/spacerini/frontend/space.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | 4 | from huggingface_hub import HfApi, create_repo, upload_folder 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def create_space_from_local( 11 | space_slug: str, 12 | space_sdk: str, 13 | local_dir: str, 14 | private: bool=False, 15 | organization: str=None, 16 | delete_after_push: bool=False, 17 | ) -> str: 18 | """ 19 | Create a new space from a local directory. 20 | Parameters 21 | ---------- 22 | space_slug : str 23 | The slug of the space. 24 | space_sdk : str 25 | The SDK of the space, could be either Gradio or Streamlit. 26 | local_dir : str 27 | The local directory where the app is currently stored. 28 | private : bool, optional 29 | If True, the space will be private. 30 | organization : str, optional 31 | The organization to create the space in. 32 | delete_after_push : bool, optional 33 | If True, delete the local directory after pushing it to the Hub. 34 | 35 | Returns 36 | ------- 37 | repo_url: str 38 | The URL of the space. 
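Examples
--------
A minimal sketch with illustrative values (mirroring examples/scripts/gradio-demo.py).
The app in local_dir must already exist, and you must be authenticated via
`huggingface-cli login`.

>>> repo_url = create_space_from_local(
...     space_slug="imdb-search",
...     space_sdk="gradio",
...     local_dir="gradio_app",
...     delete_after_push=False,
... )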
39 | """ 40 | 41 | if organization is None: 42 | hf_api = HfApi() 43 | namespace = hf_api.whoami()["name"] 44 | else: 45 | namespace = organization 46 | repo_id = namespace + "/" + space_slug 47 | try: 48 | repo_url = create_repo(repo_id=repo_id, repo_type="space", space_sdk=space_sdk, private=private, exist_ok=True) 49 | except Exception as ex: 50 | logger.error("Encountered an error while creating the space: ", ex) 51 | raise 52 | 53 | upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="space") 54 | if delete_after_push: 55 | shutil.rmtree(local_dir) 56 | return repo_url 57 | -------------------------------------------------------------------------------- /examples/scripts/gradio-demo.py: -------------------------------------------------------------------------------- 1 | from spacerini.frontend.local import create_app 2 | from spacerini.frontend.space import create_space_from_local 3 | from datasets import load_dataset 4 | from spacerini.preprocess.utils import shard_dataset, get_num_shards 5 | from spacerini.index.index import index_json_shards 6 | 7 | DSET = "imdb" 8 | SPLIT = "train" 9 | SPACE_TITLE = "IMDB search" 10 | COLUMN_TO_INDEX = "text" 11 | METADATA_COLUMNS = ["sentiment", "docid"] 12 | NUM_PROC = 28 13 | SHARDS_PATH = f"{DSET}-json-shards" 14 | LOCAL_APP = "gradio_app" 15 | SDK = "gradio" 16 | ORGANIZATION = "cakiki" 17 | MAX_ARROW_SHARD_SIZE="1GB" 18 | 19 | cookiecutter_vars = { 20 | "dset_text_field": COLUMN_TO_INDEX, 21 | "metadata_field": METADATA_COLUMNS[1], 22 | "space_title": SPACE_TITLE, 23 | "local_app":LOCAL_APP 24 | } 25 | create_app( 26 | template=SDK, 27 | extra_context_dict=cookiecutter_vars, 28 | output_dir="." 29 | ) 30 | 31 | dset = load_dataset( 32 | DSET, 33 | split=SPLIT 34 | ) 35 | 36 | shard_dataset( 37 | hf_dataset=dset, 38 | shard_size="10MB", 39 | column_to_index=COLUMN_TO_INDEX, 40 | shards_paths=SHARDS_PATH 41 | ) 42 | 43 | index_json_shards( 44 | shards_path=SHARDS_PATH, 45 | index_path=LOCAL_APP + "/index", 46 | language="en", 47 | n_threads=NUM_PROC 48 | ) 49 | 50 | dset = dset.add_column("docid", range(len(dset))) 51 | num_shards = get_num_shards(dset.data.nbytes, MAX_ARROW_SHARD_SIZE) 52 | dset.remove_columns([c for c in dset.column_names if not c in [COLUMN_TO_INDEX,*METADATA_COLUMNS]]).save_to_disk( 53 | LOCAL_APP + "/data", 54 | num_shards=num_shards, 55 | num_proc=NUM_PROC 56 | ) 57 | 58 | create_space_from_local( 59 | space_slug="imdb-search", 60 | organization=ORGANIZATION, 61 | space_sdk=SDK, 62 | local_dir=LOCAL_APP, 63 | delete_after_push=False 64 | ) -------------------------------------------------------------------------------- /src/spacerini/index/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import logging 3 | from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def push_index_to_hub( 9 | dataset_slug: str, 10 | index_path: str, 11 | private: bool=False, 12 | organization: str=None, 13 | delete_after_push: bool=False, 14 | ) -> str: 15 | """ 16 | Push an index as a dataset to the Hugging Face Hub. 17 | ---------- 18 | dataset_slug : str 19 | The slug of the space. 20 | index_path : str 21 | The local directory where the app is currently stored. 22 | private : bool, optional 23 | If True, the space will be private. 24 | organization : str, optional 25 | The organization to create the space in. 
26 | delete_after_push : bool, optional 27 | If True, delete the local index after pushing it to the Hub. 28 | 29 | Returns 30 | ------- 31 | repo_url: str 32 | The URL of the dataset. 33 | """ 34 | 35 | if organization is None: 36 | hf_api = HfApi() 37 | namespace = hf_api.whoami()["name"] 38 | else: 39 | namespace = organization 40 | repo_id = namespace + "/" + dataset_slug 41 | try: 42 | repo_url = create_repo(repo_id=repo_id, repo_type="dataset", private=private) 43 | except Exception as ex: 44 | logger.error("Encountered an error while creating the dataset repository: ", ex) 45 | raise 46 | 47 | upload_folder(folder_path=index_path, path_in_repo="index", repo_id=repo_id, repo_type="dataset") 48 | if delete_after_push: 49 | shutil.rmtree(index_path) 50 | return repo_url 51 | 52 | 53 | def load_index_from_hub(dataset_slug: str, organization: str=None) -> str: 54 | if organization is None: 55 | hf_api = HfApi() 56 | namespace = hf_api.whoami()["name"] 57 | else: 58 | namespace = organization 59 | repo_id = namespace + "/" + dataset_slug 60 | 61 | local_path = snapshot_download(repo_id=repo_id, repo_type="dataset") 62 | index_path = local_path + "/index" 63 | return index_path 64 | -------------------------------------------------------------------------------- /src/spacerini/preprocess/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from datasets import Dataset 3 | from datasets.utils.py_utils import convert_file_size_to_int 4 | import os 5 | 6 | 7 | def shard_dataset(hf_dataset: Dataset, shard_size: Union[int, str], shards_paths: str, column_to_index: str) -> None: # TODO break this up into smaller functions 8 | """ 9 | Shard a dataset into multiple files. 10 | Parameters 11 | ---------- 12 | hf_dataset : datasets.Dataset 13 | a Hugging Face datasets object 14 | shard_size : str 15 | The size of each arrow shard that gets written as a JSON file. 16 | shards_paths : str 17 | The path to the directory where the shards will be stored. 18 | column_to_index : str 19 | The column to index mapping. 20 | 21 | Returns 22 | ------- 23 | None 24 | """ 25 | hf_dataset = hf_dataset.remove_columns([c for c in hf_dataset.column_names if c!=column_to_index]) 26 | hf_dataset = hf_dataset.rename_column(column_to_index, "contents") # pyserini only wants a content column and an index column 27 | hf_dataset = hf_dataset.add_column("id", range(len(hf_dataset))) 28 | num_shards = get_num_shards(hf_dataset.data.nbytes, shard_size) 29 | os.makedirs(shards_paths, exist_ok=True) 30 | for shard_index in range(num_shards): 31 | shard = hf_dataset.shard(num_shards=num_shards, index=shard_index, contiguous=True) 32 | shard.to_json(f"{shards_paths}/docs-{shard_index:03d}.jsonl", orient="records", lines=True) 33 | 34 | 35 | def get_num_shards(dataset_size: Union[int, str], max_shard_size: Union[int, str]) -> int: 36 | """ 37 | Returns the number of shards required for a maximum shard size for a datasets.Dataset of a given size. 38 | Parameters 39 | ---------- 40 | dataset_size: int or str 41 | The size of the dataset in either number of bytes or a string such as "10MB" or "6GB" 42 | max_shard_size: int or str 43 | The maximum size for every corresponding arrow shard in either number of bytes or a string such as "10MB" or "6GB". 
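For example, a dataset_size of "25MB" with a max_shard_size of "10MB" yields int(25 / 10) + 1 = 3 shards.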
44 | 45 | Returns 46 | ------- 47 | int 48 | """ 49 | max_shard_size = convert_file_size_to_int(max_shard_size) 50 | dataset_nbytes = convert_file_size_to_int(dataset_size) 51 | num_shards = int(dataset_nbytes / max_shard_size) + 1 52 | return max(num_shards, 1) 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | apps/* 132 | 133 | .vscode/ 134 | apps/* -------------------------------------------------------------------------------- /examples/notebooks/indexing_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from spacerini.index import index_streaming_hf_dataset" 10 | ] 11 | }, 12 | { 13 | "attachments": {}, 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "In this notebook, we demonstrate how to index the Swahili part of the [`Mr. 
TyDi corpus`](https://huggingface.co/datasets/castorini/mr-tydi-corpus) on the fly using the `bert-base-multilingual-uncased` tokenizer available on [`HuggingFace`](https://huggingface.co/bert-base-multilingual-uncased). This should take roughly 3mins on a Mac M1.\n", 18 | "\n", 19 | "For a full understanding of arguments passed to Anserini, see [`io.anserini.index.IndexCollection.Args`](https://github.com/castorini/anserini)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "WARNING: sun.reflect.Reflection.getCallerClass is not supported. This will impact performance.\n", 32 | "2023-02-14 00:31:59,337 INFO [main] index.SimpleIndexer (SimpleIndexer.java:120) - Bert Tokenizer\n" 33 | ] 34 | }, 35 | { 36 | "name": "stderr", 37 | "output_type": "stream", 38 | "text": [ 39 | "136689it [01:36, 1414.10it/s]\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "index_streaming_hf_dataset(\n", 45 | " \"./test-index\",\n", 46 | " \"castorini/mr-tydi-corpus\",\n", 47 | " \"train\",\n", 48 | " \"text\",\n", 49 | " ds_config_name=\"swahili\", # This is used by datasets to load the Swahili split of the mr-tydi-corpus\n", 50 | " language=\"sw\", # This is passed to Anserini and may be used to initialize a Lucene Indexer\n", 51 | " analyzeWithHuggingFaceTokenizer=\"bert-base-multilingual-uncased\"\n", 52 | ")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [] 61 | } 62 | ], 63 | "metadata": { 64 | "kernelspec": { 65 | "display_name": "pyserini-dev", 66 | "language": "python", 67 | "name": "python3" 68 | }, 69 | "language_info": { 70 | "codemirror_mode": { 71 | "name": "ipython", 72 | "version": 3 73 | }, 74 | "file_extension": ".py", 75 | "mimetype": "text/x-python", 76 | "name": "python", 77 | "nbconvert_exporter": "python", 78 | "pygments_lexer": "ipython3", 79 | "version": "3.8.15" 80 | }, 81 | "orig_nbformat": 4, 82 | "vscode": { 83 | "interpreter": { 84 | "hash": "b0d9161c88aeda2dc6f47d29dac86208dc1568afc6233ac80536ab4d91f86f7a" 85 | } 86 | } 87 | }, 88 | "nbformat": 4, 89 | "nbformat_minor": 2 90 | } 91 | -------------------------------------------------------------------------------- /tests/test_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from os import path 4 | import unittest 5 | from spacerini.index import fetch_index_stats, index_streaming_dataset 6 | from pyserini.search.lucene import LuceneSearcher 7 | from typing import List 8 | 9 | 10 | class TestIndex(unittest.TestCase): 11 | def setUp(self): 12 | self.index_path = path.join(path.dirname(__file__), "indexes") 13 | self.dataset_name_or_path = path.join(path.dirname(__file__),"data/sample_documents.jsonl") 14 | os.makedirs(self.index_path, exist_ok=True) 15 | 16 | 17 | def test_index_streaming_hf_dataset_local(self): 18 | """ 19 | Test indexing a local dataset 20 | """ 21 | local_index_path = path.join(self.index_path, "local") 22 | index_streaming_dataset( 23 | index_path=local_index_path, 24 | dataset_name_or_path=self.dataset_name_or_path, 25 | split="train", 26 | column_to_index=["contents"], 27 | doc_id_column="id", 28 | language="en", 29 | storeContents = True, 30 | storeRaw = True, 31 | num_rows=3 32 | ) 33 | 34 | self.assertTrue(os.path.exists(local_index_path)) 35 | searcher = LuceneSearcher(local_index_path) 36 | 37 | hits = 
searcher.search('contents') 38 | self.assertTrue(isinstance(hits, List)) 39 | self.assertEqual(hits[0].docid, 'doc1') 40 | self.assertEqual(hits[0].contents, "contents of doc one.") 41 | 42 | index_stats = fetch_index_stats(local_index_path) 43 | self.assertEqual(index_stats["total_terms"], 11) 44 | self.assertEqual(index_stats["documents"], 3) 45 | self.assertEqual(index_stats["unique_terms"], 9) 46 | 47 | 48 | def test_index_streaming_hf_dataset_huggingface(self): 49 | """ 50 | Test indexing a dataset from HuggingFace Hub 51 | """ 52 | hgf_index_path = path.join(self.index_path, "hgf") 53 | index_streaming_dataset( 54 | index_path=hgf_index_path, 55 | dataset_name_or_path="sciq", 56 | split="test", 57 | column_to_index=["question", "support"], 58 | language="en", 59 | num_rows=1000, 60 | storeContents = True, 61 | storeRaw = True 62 | ) 63 | 64 | self.assertTrue(os.path.exists(hgf_index_path)) 65 | searcher = LuceneSearcher(hgf_index_path) 66 | index_stats = fetch_index_stats(hgf_index_path) 67 | 68 | hits = searcher.search('contents') 69 | self.assertTrue(isinstance(hits, List)) 70 | self.assertEqual(hits[0].docid, '528') 71 | self.assertEqual(index_stats["total_terms"], 54197) 72 | self.assertEqual(index_stats["documents"], 1000) 73 | 74 | def tearDown(self): 75 | shutil.rmtree(self.index_path) 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /src/spacerini/frontend/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | from spacerini.frontend import create_space_from_local, create_app 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def parser(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--template", 13 | type=str, 14 | help="The path to a predefined template to use.", 15 | required=True 16 | ) 17 | parser.add_argument( 18 | "--extra_context_dict", 19 | type=dict, 20 | help="The extra context to pass to the template.", 21 | ) 22 | parser.add_argument( 23 | "--no_input", 24 | type=bool, 25 | help="If True, do not prompt for parameters and only use", 26 | ) 27 | parser.add_argument( 28 | "--overwrite_if_exists", 29 | type=bool, 30 | help="If True, overwrite the output directory if it already exists.", 31 | default=True 32 | ) 33 | parser.add_argument( 34 | "--space_slug", 35 | type=str, 36 | help="The name of the space on huggingface.", 37 | required=True 38 | ) 39 | parser.add_argument( 40 | "--space_sdk", 41 | type=str, 42 | help="The SDK of the space, could be either Gradio or Streamlit.", 43 | choices=["gradio", "streamlit"], 44 | required=True 45 | ) 46 | parser.add_argument( 47 | "--local_dir", 48 | type=str, 49 | help="The local directory where the app should be stored.", 50 | ) 51 | parser.add_argument( 52 | "--private", 53 | type=bool, 54 | help="If True, the space will be private.", 55 | default=False 56 | ) 57 | parser.add_argument( 58 | "--organization", 59 | type=str, 60 | help="The organization to create the space in.", 61 | default=None 62 | ) 63 | parser.add_argument( 64 | "--delete_after_push", 65 | type=bool, 66 | help="If True, delete the local directory after pushing it to the Hub.", 67 | ) 68 | return parser 69 | 70 | def main(): 71 | parser = parser() 72 | args = parser.parse_args() 73 | 74 | logger.info("Validating the input arguments...") 75 | assert os.path.exists(args.template), "The template does not exist." 
76 | assert os.path.exists(args.local_dir), "The local directory does not exist." 77 | 78 | logger.info("Creating the app locally...") 79 | create_app( 80 | template=args.template, 81 | extra_context_dict=args.extra_context_dict, 82 | output_dir=args.local_dir, 83 | no_input=args.no_input, 84 | overwrite_if_exists=args.overwrite_if_exists 85 | ) 86 | 87 | logger.info("Creating the space on huggingface...") 88 | create_space_from_local( 89 | space_slug=args.space_slug, 90 | space_sdk=args.space_sdk, 91 | local_dir=args.local_dir, 92 | private=args.private, 93 | organization=args.organization, 94 | delete_after_push=args.delete_after_push 95 | ) 96 | 97 | 98 | if __name__ == "__main__": 99 | pass -------------------------------------------------------------------------------- /src/spacerini/index/encode.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | from typing import List 3 | from typing import Literal 4 | from typing import Optional 5 | from typing import Protocol 6 | 7 | from pyserini.encode.__main__ import init_encoder 8 | from pyserini.encode import RepresentationWriter 9 | from pyserini.encode import FaissRepresentationWriter 10 | from pyserini.encode import JsonlCollectionIterator 11 | from pyserini.encode import JsonlRepresentationWriter 12 | 13 | EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"] 14 | 15 | 16 | class Encoder(Protocol): 17 | def encode(**kwargs): ... 18 | 19 | 20 | def init_writer( 21 | embedding_dir: str, 22 | embedding_dimension: int = 768, 23 | output_to_faiss: bool = False 24 | ) -> RepresentationWriter: 25 | """ 26 | """ 27 | if output_to_faiss: 28 | writer = FaissRepresentationWriter(embedding_dir, dimension=embedding_dimension) 29 | return writer 30 | 31 | return JsonlRepresentationWriter(embedding_dir) 32 | 33 | 34 | def encode_corpus_or_shard( 35 | encoder: Encoder, 36 | collection_iterator: Iterable[dict], 37 | embedding_writer: RepresentationWriter, 38 | batch_size: int, 39 | shard_id: int, 40 | shard_num: int, 41 | max_length: int = 256, 42 | add_sep: bool = False, 43 | input_fields: List[str] = None, 44 | title_column_to_encode: Optional[str] = None, 45 | text_column_to_encode: Optional[str] = "text", 46 | expand_column_to_encode: Optional[str] = None, 47 | fp16: bool = False 48 | ) -> None: 49 | """ 50 | 51 | """ 52 | with embedding_writer: 53 | for batch_info in collection_iterator(batch_size, shard_id, shard_num): 54 | kwargs = { 55 | 'texts': batch_info[text_column_to_encode], 56 | 'titles': batch_info[title_column_to_encode] if title_column_to_encode else None, 57 | 'expands': batch_info[expand_column_to_encode] if expand_column_to_encode else None, 58 | 'fp16': fp16, 59 | 'max_length': max_length, 60 | 'add_sep': add_sep, 61 | } 62 | embeddings = encoder.encode(**kwargs) 63 | batch_info['vector'] = embeddings 64 | embedding_writer.write(batch_info, input_fields) 65 | 66 | 67 | def encode_json_dataset( 68 | data_path: str, 69 | encoder_name_or_path: str, 70 | encoder_class: EncoderClass, 71 | embedding_dir: str, 72 | batch_size: int, 73 | index_shard_id: int = 0, 74 | num_index_shards: int = 1, 75 | device: str = "cuda:0", 76 | delimiter: str = "\n", 77 | max_length: int = 256, 78 | add_sep: bool = False, 79 | input_fields: List[str] = None, 80 | title_column_to_encode: Optional[str] = None, 81 | text_column_to_encode: Optional[str] = "text", 82 | expand_column_to_encode: Optional[str] = None, 83 | output_to_faiss: bool = False, 84 | 
embedding_dimension: int = 768, 85 | fp16: bool = False 86 | ) -> None: 87 | """ 88 | 89 | """ 90 | if input_fields is None: 91 | input_fields = ["text"] 92 | 93 | encoder = init_encoder(encoder_name_or_path, encoder_class, device=device) 94 | 95 | writer = init_writer( 96 | embedding_dir=embedding_dir, 97 | embedding_dimension=embedding_dimension, 98 | output_to_faiss=output_to_faiss 99 | ) 100 | 101 | collection_iterator = JsonlCollectionIterator(data_path, input_fields, delimiter) 102 | encode_corpus_or_shard( 103 | encoder=encoder, 104 | collection_iterator=collection_iterator, 105 | embedding_writer=writer, 106 | batch_size=batch_size, 107 | shard_id=index_shard_id, 108 | shard_num=num_index_shards, 109 | max_length=max_length, 110 | add_sep=add_sep, 111 | input_fields=input_fields, 112 | title_column_to_encode=title_column_to_encode, 113 | text_column_to_encode=text_column_to_encode, 114 | expand_column_to_encode=expand_column_to_encode, 115 | fp16=fp16 116 | ) 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spacerini 🦄 2 | 3 | Spacerini is a modular framework for seamless building and deployment of interactive search application. It integrates Pyserini and the HuggingFace 🤗 ecosystem to enable facilitate the qualitative analysis of large scale research datasets. 4 | 5 | You can index collections and deploy them as ad-hoc search engines for efficient retrieval of relevant documents. 6 | Spacerini provides a customisable, user-friendly interface for auditing massive datasets. Spacerini also allows Information Retrieval (IR) researchers and Search Engineers to demonstrate the capabilities of their indices easily and interactively. 7 | 8 | Spacerini currently supports the use of Gradio and Streamlit to create these search applications. In the future, we will also support deployment of docker containers. 9 | 10 | ## Installation ⚒️ 11 | 12 | To get started create an access token with write access on huggingface to enable the creation of spaces. [Check here](https://huggingface.co/docs/hub/security-tokens) for documentation on user access tokens on huggingface. 13 | 14 | Run `huggingface-cli login` to register your access token locally. 15 | 16 | ### From Github 17 | 18 | - `pip install git+https://github.com/castorini/hf-spacerini.git` 19 | 20 | ### Development Installation 21 | 22 | You will need a development installation if you are contributing to Spacerini. 23 | 24 | - Create a virual environment - `conda create --name spacerini python=3.9` 25 | - Clone the repository - `git clone https://github.com/castorini/hf-spacerini.git` 26 | - Install in editable mode -`pip install -e ".[dev]"` 27 | 28 | 29 | ## Creating Spaces applications 🔎 30 | 31 | Spacerini provides flexibility. You can customize every step of the `index-create-deploy` process as your project requires. You can provide your own index, built a more interactive web application and deploy changes as necessary. 32 | 33 | Some of the commands that allow this flexibility are: 34 | 35 | * index: Create a sparse, dense or hybrid index from specified dataset. This does not create a space. 36 | * create-space: Create index from dataset and create HuggingFace Space application. 37 | * deploy: Create HuggingFace Space application and deploy! 38 | * deploy-only: Deploy an already created or recently modified Space application. 
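For instance, using the bundled XSum configuration described in the sections below, a create-then-deploy flow looks roughly like this (run `spacerini --help` for the authoritative list of flags and arguments):

```bash
# Build the index and generate the Space application locally.
spacerini --config-file examples/configs/xsum.json create-space

# Push the generated (or locally modified) application to the Hugging Face Hub.
spacerini --config-file examples/configs/xsum.json deploy-only
```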
39 | 40 | If you have an existing index you have built, you can pass the `--index-exists` flag to any of the listed commands. Run `spacerini --help` for a full list of commands and arguments. 41 | 42 | 1. Getting started 43 | 44 | You can deploy a search system for the test set of the Extreme Summatization [XSUM](https://huggingface.co/datasets/xsum) dataset on huggingface using the following command. This system is based on our [gradio_roots_temp](templates/gradio_roots_temp/) template. 45 | 46 | ```bash 47 | spacerini --from-example xsum deploy 48 | ``` 49 | 50 | Voila!, you have successfully deployed a search engine on HuggingFace Spaces 🤩🥳! This command downloads the XSUM dataset from the HuggingFace Hub, builds an interative user interface on top of a sparse index the dataset and deploys the application to HuggingFace Spaces. You can find the configurations for the application in: [examples/configs/xsum.json](examples/configs/xsum.json). 51 | 52 | 2. Building your own custom application 53 | 54 | The easiest way to build your own application is to provide a JSON configuration containing arguments for indexing, creating your space and deploying it. 55 | 56 | ```bash 57 | spacerini --config-file 58 | ``` 59 | where `` may be `index`, `create-space`, `deploy` or `deploy-only`. 60 | 61 | It helps to familiarise yourself with the arguments available for each step by running. 62 | 63 | ```bash 64 | spacerini --help 65 | ``` 66 | 67 | If you are using a custom template you have built, you can pass it using the `template` argument. Once your application has been created locally, you can run it before deploying. 68 | 69 | ```bash 70 | cd apps/ 71 | python app.py 72 | ``` 73 | 74 | After completing all necessary modifications, run the following command to deploy your Spaces application 75 | 76 | ```bash 77 | spacerini --config-file deploy-only 78 | ``` 79 | 80 | ## Contribution 81 | 82 | ## Acknowledgement 83 | -------------------------------------------------------------------------------- /src/spacerini/spacerini_utils/search.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Literal, Protocol, Tuple, TypedDict, Union 3 | 4 | from pyserini.analysis import get_lucene_analyzer 5 | from pyserini.index import IndexReader 6 | from pyserini.search import DenseSearchResult, JLuceneSearcherResult 7 | from pyserini.search.faiss.__main__ import init_query_encoder 8 | from pyserini.search.faiss import FaissSearcher 9 | from pyserini.search.hybrid import HybridSearcher 10 | from pyserini.search.lucene import LuceneSearcher 11 | 12 | EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"] 13 | 14 | 15 | class AnalyzerArgs(TypedDict): 16 | language: str 17 | stemming: bool 18 | stemmer: str 19 | stopwords: bool 20 | huggingFaceTokenizer: str 21 | 22 | 23 | class SearchResult(TypedDict): 24 | docid: str 25 | text: str 26 | score: float 27 | language: str 28 | 29 | 30 | class Searcher(Protocol): 31 | def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]: 32 | ... 
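# Illustrative usage only (the index path is a placeholder): a sparse-only setup
# returns a LuceneSearcher plus a None reader, e.g.
#
#     searcher, reader = init_searcher_and_reader(sparse_index_path="index")
#     results = _search(searcher, reader, "sample query", num_results=10)
#
# where each entry of `results` is a SearchResult dict carrying docid, text and score.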
33 | 34 | 35 | def init_searcher_and_reader( 36 | sparse_index_path: str = None, 37 | bm25_k1: float = None, 38 | bm25_b: float = None, 39 | analyzer_args: AnalyzerArgs = None, 40 | dense_index_path: str = None, 41 | encoder_name_or_path: str = None, 42 | encoder_class: EncoderClass = None, 43 | tokenizer_name: str = None, 44 | device: str = None, 45 | prefix: str = None 46 | ) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], IndexReader]: 47 | """ 48 | Initialize and return an approapriate searcher 49 | 50 | Parameters 51 | ---------- 52 | sparse_index_path: str 53 | Path to sparse index 54 | dense_index_path: str 55 | Path to dense index 56 | encoder_name_or_path: str 57 | Path to query encoder checkpoint or encoder name 58 | encoder_class: str 59 | Query encoder class to use. If None, infer from `encoder` 60 | tokenizer_name: str 61 | Tokenizer name or path 62 | device: str 63 | Device to load Query encoder on. 64 | prefix: str 65 | Query prefix if exists 66 | 67 | Returns 68 | ------- 69 | Searcher: FaissSearcher | HybridSearcher | LuceneSearcher 70 | A sparse, dense or hybrid searcher 71 | """ 72 | reader = None 73 | if sparse_index_path: 74 | ssearcher = LuceneSearcher(sparse_index_path) 75 | if analyzer_args: 76 | analyzer = get_lucene_analyzer(**analyzer_args) 77 | ssearcher.set_analyzer(analyzer) 78 | if bm25_k1 and bm25_b: 79 | ssearcher.set_bm25(bm25_k1, bm25_b) 80 | 81 | if dense_index_path: 82 | encoder = init_query_encoder( 83 | encoder=encoder_name_or_path, 84 | encoder_class=encoder_class, 85 | tokenizer_name=tokenizer_name, 86 | topics_name=None, 87 | encoded_queries=None, 88 | device=device, 89 | prefix=prefix 90 | ) 91 | 92 | reader = IndexReader(sparse_index_path) 93 | dsearcher = FaissSearcher(dense_index_path, encoder) 94 | 95 | if sparse_index_path: 96 | hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher) 97 | return hsearcher, reader 98 | else: 99 | return dsearcher, reader 100 | 101 | return ssearcher, reader 102 | 103 | 104 | def _search(searcher: Searcher, reader: IndexReader, query: str, num_results: int = 10) -> List[SearchResult]: 105 | """ 106 | Parameters: 107 | ----------- 108 | searcher: FaissSearcher | HybridSearcher | LuceneSearcher 109 | A sparse, dense or hybrid searcher 110 | query: str 111 | Query for which to retrieve results 112 | num_results: int 113 | Maximum number of results to retrieve 114 | 115 | Returns: 116 | -------- 117 | Dict: 118 | """ 119 | def _get_dict(r: Union[DenseSearchResult, JLuceneSearcherResult]): 120 | if isinstance(r, JLuceneSearcherResult): 121 | return json.loads(r.raw) 122 | elif isinstance(r, DenseSearchResult): 123 | # Get document from sparse_index using index reader 124 | return json.loads(reader.doc(r.docid).raw()) 125 | 126 | search_results = searcher.search(query, k=num_results) 127 | all_results = [ 128 | SearchResult( 129 | docid=result["id"], 130 | text=result["contents"], 131 | score=search_results[idx].score 132 | ) for idx, result in enumerate(map(lambda r: _get_dict(r), search_results)) 133 | ] 134 | 135 | return all_results 136 | -------------------------------------------------------------------------------- /templates/gradio/{{ cookiecutter.local_app }}/app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from datasets import load_from_disk 3 | from pyserini.search.lucene import LuceneSearcher 4 | 5 | searcher = LuceneSearcher("index") 6 | ds = load_from_disk("data") 7 | NUM_PAGES = 10 # STATIC. 
THIS CAN'T CHANGE BECAUSE GRADIO CAN'T DYNAMICALLY CREATE COMPONENTS. 8 | RESULTS_PER_PAGE = 5 9 | 10 | TEXT_FIELD = "{{ cookiecutter.dset_text_field }}" 11 | METADATA_FIELD = "{{ cookiecutter.metadata_field }}" 12 | 13 | 14 | def result_html(result, meta): 15 | return ( 16 | f"
{meta}" 17 | f"{result[:250]}...{result[250:]}
" 18 | ) 19 | 20 | 21 | def format_results(results): 22 | return "\n".join([result_html(result, meta) for result,meta in zip(results[TEXT_FIELD], results[METADATA_FIELD])]) 23 | 24 | 25 | def page_0(query): 26 | hits = searcher.search(query, k=NUM_PAGES*RESULTS_PER_PAGE) 27 | ix = [int(hit.docid) for hit in hits] 28 | results = ds.select(ix).shard(num_shards=NUM_PAGES, index=0, contiguous=True) # no need to shard. split ix in batches instead. (would make sense if results was cacheable) 29 | results = format_results(results) 30 | return results, [ix], gr.update(visible=True) 31 | 32 | 33 | def page_i(i, ix): 34 | ix = ix[0] 35 | results = ds.select(ix).shard(num_shards=NUM_PAGES, index=i, contiguous=True) 36 | results = format_results(results) 37 | return results, [ix] 38 | 39 | 40 | with gr.Blocks(css="#b {min-width:15px;background:transparent;border:white;box-shadow:none;}") as demo: # 41 | with gr.Row(): 42 | gr.Markdown(value="""##

<p style="text-align: center;"> {{ cookiecutter.space_title }} </p>
""") 43 | with gr.Row(): 44 | with gr.Column(scale=1): 45 | result_list = gr.Dataframe(type="array", visible=False, col_count=1) 46 | with gr.Column(scale=13): 47 | query = gr.Textbox(lines=1, max_lines=1, placeholder="Search…", label="") 48 | with gr.Column(scale=1): 49 | with gr.Row(scale=1): 50 | pass 51 | with gr.Row(scale=1): 52 | submit_btn = gr.Button("🔍", elem_id="b").style(full_width=False) 53 | with gr.Row(scale=1): 54 | pass 55 | 56 | with gr.Row(): 57 | with gr.Column(scale=1): 58 | pass 59 | with gr.Column(scale=13): 60 | c = gr.HTML(label="Results") 61 | with gr.Row(visible=False) as pagination: 62 | # left = gr.Button(value="◀", elem_id="b", visible=False).style(full_width=True) 63 | page_1 = gr.Button(value="1", elem_id="b").style(full_width=True) 64 | page_2 = gr.Button(value="2", elem_id="b").style(full_width=True) 65 | page_3 = gr.Button(value="3", elem_id="b").style(full_width=True) 66 | page_4 = gr.Button(value="4", elem_id="b").style(full_width=True) 67 | page_5 = gr.Button(value="5", elem_id="b").style(full_width=True) 68 | page_6 = gr.Button(value="6", elem_id="b").style(full_width=True) 69 | page_7 = gr.Button(value="7", elem_id="b").style(full_width=True) 70 | page_8 = gr.Button(value="8", elem_id="b").style(full_width=True) 71 | page_9 = gr.Button(value="9", elem_id="b").style(full_width=True) 72 | page_10 = gr.Button(value="10", elem_id="b").style(full_width=True) 73 | # right = gr.Button(value="▶", elem_id="b", visible=False).style(full_width=True) 74 | with gr.Column(scale=1): 75 | pass 76 | 77 | query.submit(fn=page_0, inputs=[query], outputs=[c, result_list, pagination]) 78 | submit_btn.click(page_0, inputs=[query], outputs=[c, result_list, pagination]) 79 | 80 | with gr.Box(visible=False): 81 | nums = [gr.Number(i, visible=False, precision=0) for i in range(NUM_PAGES)] 82 | 83 | page_1.click(fn=page_i, inputs=[nums[0], result_list], outputs=[c, result_list]) 84 | page_2.click(fn=page_i, inputs=[nums[1], result_list], outputs=[c, result_list]) 85 | page_3.click(fn=page_i, inputs=[nums[2], result_list], outputs=[c, result_list]) 86 | page_4.click(fn=page_i, inputs=[nums[3], result_list], outputs=[c, result_list]) 87 | page_5.click(fn=page_i, inputs=[nums[4], result_list], outputs=[c, result_list]) 88 | page_6.click(fn=page_i, inputs=[nums[5], result_list], outputs=[c, result_list]) 89 | page_7.click(fn=page_i, inputs=[nums[6], result_list], outputs=[c, result_list]) 90 | page_8.click(fn=page_i, inputs=[nums[7], result_list], outputs=[c, result_list]) 91 | page_9.click(fn=page_i, inputs=[nums[8], result_list], outputs=[c, result_list]) 92 | page_10.click(fn=page_i, inputs=[nums[9], result_list], outputs=[c, result_list]) 93 | 94 | demo.launch(enable_queue=True, debug=True) 95 | -------------------------------------------------------------------------------- /templates/gradio_roots_temp/{{ cookiecutter.local_app }}/app.py: -------------------------------------------------------------------------------- 1 | from typing import List, NewType, Optional, Union 2 | 3 | import gradio as gr 4 | 5 | from spacerini_utils.index import fetch_index_stats 6 | from spacerini_utils.search import _search, init_searcher, SearchResult 7 | 8 | HTML = NewType('HTML', str) 9 | 10 | searcher = init_searcher(sparse_index_path="sparse_index") 11 | 12 | def get_docid_html(docid: Union[int, str]) -> HTML: 13 | {% if cookiecutter.private -%} 14 | docid_html = ( 15 | f"🔒{{ cookiecutter.dataset_name }}/'+f'{docid}' 20 | ) 21 | {%- else -%} 22 | docid_html = ( 23 | f"🔒{{ 
cookiecutter.dataset_name }}/'+f'{docid}' 29 | ) 30 | {%- endif %} 31 | 32 | return docid_html 33 | 34 | 35 | def process_results(results: List[SearchResult], language: str, highlight_terms: Optional[List[str]] = None) -> HTML: 36 | if len(results) == 0: 37 | return """

38 | <p style="text-align: center;">No results retrieved.</p><br><hr><br>
""" 39 | 40 | results_html = "" 41 | for result in results: 42 | tokens = result["text"].split() 43 | 44 | tokens_html = [] 45 | if highlight_terms: 46 | for token in tokens: 47 | if token in highlight_terms: 48 | tokens_html.append("{}".format(token)) 49 | else: 50 | tokens_html.append(token) 51 | 52 | tokens_html = " ".join(tokens_html) 53 | meta_html = ( 54 | """ 55 |

56 | """ 57 | ) 58 | docid_html = get_docid_html(result["docid"]) 59 | results_html += """{} 60 |

<p>Document ID: {}</p> 61 | <p>Score: {}</p> 62 | <p>Language: {}</p> 63 | <p>{}</p> 64 | <br><hr><br>
65 | """.format( 66 | meta_html, docid_html, result["score"], language, tokens_html 67 | ) 68 | return results_html + "
" 69 | 70 | 71 | def search(query: str, language: str, num_results: int = 10) -> HTML: 72 | results_dict = _search(searcher, query, num_results=num_results) 73 | return process_results(results_dict, language) 74 | 75 | 76 | stats = fetch_index_stats('sparse_index/') 77 | 78 | description = f"""#

<p style="text-align: center;">{{ cookiecutter.emoji }} 🔎 {{ cookiecutter.space_title }} 🔍 {{ cookiecutter.emoji }}</p> 79 | <p style="text-align: center;">{{ cookiecutter.space_description}}</p> 80 | <p style="text-align: center;">Dataset Statistics: Total Number of Documents = {stats["documents"]}, Number of Terms = {stats["total_terms"]}</p>
""" 81 | 82 | demo = gr.Blocks( 83 | css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }" 84 | ) 85 | 86 | with demo: 87 | with gr.Row(): 88 | gr.Markdown(value=description) 89 | with gr.Row(): 90 | query = gr.Textbox(lines=1, max_lines=1, placeholder="Type your query here...", label="Query") 91 | with gr.Row(): 92 | lang = gr.Dropdown( 93 | choices=[ 94 | "en", 95 | "detect_language", 96 | "all", 97 | ], 98 | value="en", 99 | label="Language", 100 | ) 101 | with gr.Row(): 102 | k = gr.Slider(1, 100, value=10, step=1, label="Max Results") 103 | with gr.Row(): 104 | submit_btn = gr.Button("Submit") 105 | with gr.Row(): 106 | results = gr.HTML(label="Results") 107 | 108 | 109 | def submit(query: str, lang: str, k: int): 110 | query = query.strip() 111 | if query is None or query == "": 112 | return "", "" 113 | return { 114 | results: search(query, lang, k), 115 | } 116 | 117 | query.submit(fn=submit, inputs=[query, lang, k], outputs=[results]) 118 | submit_btn.click(submit, inputs=[query, lang, k], outputs=[results]) 119 | 120 | demo.launch(enable_queue=True, debug=True) 121 | -------------------------------------------------------------------------------- /src/spacerini/data/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import Dataset, IterableDataset 3 | from datasets import load_dataset 4 | import pandas as pd 5 | import ir_datasets 6 | from typing import Generator, Dict, List 7 | 8 | def ir_dataset_dict_generator(dataset_name: str) -> Generator[Dict,None,None]: 9 | """ 10 | Generator for streaming datasets from ir_datasets 11 | https://github.com/allenai/ir_datasets 12 | Parameters 13 | ---------- 14 | dataset_name : str 15 | Name of dataset to load 16 | """ 17 | dataset = ir_datasets.load(dataset_name) 18 | for doc in dataset.docs_iter(): 19 | yield { 20 | "contents": doc.text, 21 | "docid": doc.doc_id 22 | } 23 | 24 | def load_ir_dataset(dataset_name: str) -> Dataset: 25 | """ 26 | Load dataset from ir_datasets 27 | 28 | Parameters 29 | ---------- 30 | dataset_name : str 31 | Name of dataset to load 32 | 33 | Returns 34 | ------- 35 | Dataset 36 | """ 37 | dataset = ir_datasets.load(dataset_name) 38 | return Dataset.from_pandas(pd.DataFrame(dataset.docs_iter())) 39 | 40 | def load_ir_dataset_low_memory(dataset_name: str, num_proc: int) -> IterableDataset: 41 | """ 42 | Load dataset from ir_datasets by streaming into a Dataset object. This is slower than first loading the data into a pandas DataFrame but does not require loading the entire ir_dataset into memory. This variant also supports multiprocessing if the data is sharded, and only consumes the generator once then caches it so that future calls are instantaneous. 
43 | 44 | Parameters 45 | ---------- 46 | dataset_name : str 47 | Name of dataset to load 48 | num_proc: int 49 | Number of processes to use 50 | 51 | Returns 52 | ------- 53 | Dataset 54 | """ 55 | return Dataset.from_generator(ir_dataset_dict_generator, gen_kwargs={"dataset_name": dataset_name}, num_proc=num_proc) 56 | 57 | def load_ir_dataset_streaming(dataset_name: str) -> IterableDataset: 58 | """ 59 | Load dataset from ir_datasets 60 | 61 | Parameters 62 | ---------- 63 | dataset_name : str 64 | Name of dataset to load 65 | 66 | Returns 67 | ------- 68 | IterableDataset 69 | """ 70 | return IterableDataset.from_generator(ir_dataset_dict_generator, gen_kwargs={"dataset_name": dataset_name}) 71 | 72 | def load_from_pandas(df: pd.DataFrame) -> Dataset: 73 | """ 74 | Load dataset from pandas DataFrame 75 | 76 | Parameters 77 | ---------- 78 | df : pd.DataFrame 79 | DataFrame to load 80 | 81 | Returns 82 | ------- 83 | Dataset 84 | """ 85 | return Dataset.from_pandas(df) 86 | 87 | def load_from_hub(dataset_name_or_path: str, split: str, config_name: str=None, streaming: bool = True) -> Dataset: 88 | """ 89 | Load dataset from HuggingFace Hub 90 | 91 | Parameters 92 | ---------- 93 | dataset_name_or_path : str 94 | Name of dataset to load 95 | split : str 96 | Split of dataset to load 97 | streaming : bool 98 | Whether to load dataset in streaming mode 99 | 100 | Returns 101 | ------- 102 | Dataset 103 | """ 104 | 105 | return load_dataset(dataset_name_or_path, split=split, streaming=streaming, name=config_name) 106 | 107 | def load_from_local(dataset_name_or_path: str or List[str], split: str, streaming: bool = True) -> Dataset: 108 | """ 109 | Load dataset from local text file. Supports JSON, JSONL, CSV, and TSV. 110 | 111 | Parameters 112 | ---------- 113 | dataset_name_or_path : str 114 | Path to a .json, .jsonl, .csv, or .tsv file 115 | streaming : bool 116 | Whether to load dataset in streaming mode 117 | 118 | Returns 119 | ------- 120 | Dataset 121 | """ 122 | if isinstance(dataset_name_or_path, str): 123 | assert os.path.exists(dataset_name_or_path), f"File {dataset_name_or_path} does not exist" 124 | if dataset_name_or_path.endswith(".jsonl") or dataset_name_or_path.endswith(".json"): 125 | return load_dataset("json", data_files=dataset_name_or_path, split=split, streaming=streaming) 126 | elif dataset_name_or_path.endswith(".csv"): 127 | return load_dataset("csv", data_files=dataset_name_or_path, split=split, streaming=streaming, sep=",") 128 | elif dataset_name_or_path.endswith(".tsv"): 129 | return load_dataset("csv", data_files=dataset_name_or_path, split=split, streaming=streaming, sep="\t") 130 | else: 131 | raise ValueError("Unsupported file type") 132 | 133 | def load_from_sqlite_table(uri_or_con: str, table_or_query: str) -> Dataset: 134 | """ 135 | Load dataset from a sqlite database table 136 | 137 | Parameters 138 | ---------- 139 | uri_or_con : str 140 | URI to a SQLITE database or connection object 141 | table_or_query : str 142 | database table or query that returns a table 143 | 144 | Returns 145 | ------- 146 | Dataset 147 | """ 148 | return Dataset.from_sql(con=uri_or_con, sql=table_or_query) 149 | -------------------------------------------------------------------------------- /src/spacerini/search/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict 3 | from typing import List 4 | from typing import Literal 5 | from typing import Protocol 6 | from typing import 
TypedDict 7 | from typing import Union 8 | 9 | from datasets import Dataset 10 | from pyserini.analysis import get_lucene_analyzer 11 | from pyserini.search import DenseSearchResult 12 | from pyserini.search import JLuceneSearcherResult 13 | from pyserini.search.faiss.__main__ import init_query_encoder 14 | from pyserini.search.faiss import FaissSearcher 15 | from pyserini.search.hybrid import HybridSearcher 16 | from pyserini.search.lucene import LuceneSearcher 17 | 18 | Encoder = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"] 19 | 20 | 21 | class AnalyzerArgs(TypedDict): 22 | language: str 23 | stemming: bool 24 | stemmer: str 25 | stopwords: bool 26 | huggingFaceTokenizer: str 27 | 28 | 29 | class Searcher(Protocol): 30 | def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]: 31 | ... 32 | 33 | 34 | def init_searcher( 35 | sparse_index_path: str = None, 36 | bm25_k1: float = None, 37 | bm25_b: float = None, 38 | analyzer_args: AnalyzerArgs = None, 39 | dense_index_path: str = None, 40 | encoder_name_or_path: str = None, 41 | encoder_class: Encoder = None, 42 | tokenizer_name: str = None, 43 | device: str = None, 44 | prefix: str = None 45 | ) -> Union[FaissSearcher, HybridSearcher, LuceneSearcher]: 46 | """ 47 | Initialize and return an approapriate searcher 48 | 49 | Parameters 50 | ---------- 51 | sparse_index_path: str 52 | Path to sparse index 53 | dense_index_path: str 54 | Path to dense index 55 | encoder_name_or_path: str 56 | Path to query encoder checkpoint or encoder name 57 | encoder_class: str 58 | Query encoder class to use. If None, infer from `encoder` 59 | tokenizer_name: str 60 | Tokenizer name or path 61 | device: str 62 | Device to load Query encoder on. 63 | prefix: str 64 | Query prefix if exists 65 | 66 | Returns 67 | ------- 68 | Searcher: 69 | A sparse, dense or hybrid searcher 70 | """ 71 | if sparse_index_path: 72 | ssearcher = LuceneSearcher(sparse_index_path) 73 | if analyzer_args: 74 | analyzer = get_lucene_analyzer(**analyzer_args) 75 | ssearcher.set_analyzer(analyzer) 76 | if bm25_k1 and bm25_b: 77 | ssearcher.set_bm25(bm25_k1, bm25_b) 78 | 79 | if dense_index_path: 80 | encoder = init_query_encoder( 81 | encoder=encoder_name_or_path, 82 | encoder_class=encoder_class, 83 | tokenizer_name=tokenizer_name, 84 | topics_name=None, 85 | encoded_queries=None, 86 | device=device, 87 | prefix=prefix 88 | ) 89 | 90 | dsearcher = FaissSearcher(dense_index_path, encoder) 91 | 92 | if sparse_index_path: 93 | hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher) 94 | return hsearcher 95 | else: 96 | return dsearcher 97 | 98 | return ssearcher 99 | 100 | 101 | def result_indices( 102 | query: str, 103 | num_results: int, 104 | searcher: Searcher, 105 | index_path: str, 106 | analyzer=None 107 | ) -> list: 108 | """ 109 | Get the indices of the results of a query. 110 | Parameters 111 | ---------- 112 | query : str 113 | The query. 114 | num_results : int 115 | The number of results to return. 116 | index_path : str 117 | The path to the index. 118 | analyzer : str (default=None) 119 | The analyzer to use. 120 | 121 | Returns 122 | ------- 123 | list 124 | The indices of the returned documents. 
125 | """ 126 | # searcher.search() 127 | # searcher = LuceneSearcher(index_path) 128 | # if analyzer is not None: 129 | # searcher.set_analyzer(analyzer) 130 | hits = searcher.search(query, k=num_results) 131 | ix = [int(hit.docid) for hit in hits] 132 | return ix 133 | 134 | 135 | def result_page( 136 | hf_dataset: Dataset, 137 | result_indices: List[int], 138 | page: int = 0, 139 | results_per_page: int=10 140 | ) -> Dataset: 141 | """ 142 | Returns a the ith results page as a datasets.Dataset object. Nothing is loaded into memory. Call `to_pandas()` on the returned Dataset to materialize the table. 143 | ---------- 144 | hf_dataset : datasets.Dataset 145 | a Hugging Face datasets dataset. 146 | result_indices : list of int 147 | The indices of the results. 148 | page: int (default=0) 149 | The result page to return. Returns the first page by default. 150 | results_per_page : int (default=10) 151 | The number of results per page. 152 | 153 | Returns 154 | ------- 155 | datasets.Dataset 156 | A results page. 157 | """ 158 | results = hf_dataset.select(result_indices) 159 | num_result_pages = int(len(results)/results_per_page) + 1 160 | return results.shard(num_result_pages, page, contiguous=True) 161 | -------------------------------------------------------------------------------- /src/spacerini/index/index.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import List 3 | from typing import Literal 4 | 5 | import shutil 6 | from pyserini.index.lucene import LuceneIndexer, IndexReader 7 | from typing import Any, Dict, List 8 | from pyserini.pyclass import autoclass 9 | from tqdm import tqdm 10 | import json 11 | import os 12 | 13 | from spacerini.data import load_from_hub, load_from_local 14 | 15 | 16 | def parse_args( 17 | index_path: str, 18 | shards_path: str = "", 19 | fields: List[str] = None, 20 | language: str = None, 21 | pretokenized: bool = False, 22 | analyzeWithHuggingFaceTokenizer: str = None, 23 | storePositions: bool = True, 24 | storeDocvectors: bool = False, 25 | storeContents: bool = False, 26 | storeRaw: bool = False, 27 | keepStopwords: bool = False, 28 | stopwords: str = None, 29 | stemmer: Literal["porter", "krovetz"] = None, 30 | optimize: bool = False, 31 | verbose: bool = False, 32 | quiet: bool = False, 33 | memory_buffer: str ="4096", 34 | n_threads: int = 5, 35 | for_otf_indexing: bool = False, 36 | **kwargs 37 | ) -> List[str]: 38 | """ 39 | Parse arguments into list for `SimpleIndexer` class in Anserini. 40 | 41 | Parameters 42 | ---------- 43 | for_otf_indexing : bool 44 | If True, `-input` & `-collection` args are safely ignored. 45 | Used when performing on-the-fly indexing with HF Datasets. 
46 | 47 | See [docs](docs/arguments.md) for remaining argument definitions 48 | 49 | Returns 50 | ------- 51 | List of arguments to initialize the `LuceneIndexer` 52 | """ 53 | params = locals() 54 | args = [] 55 | args.extend([ 56 | "-input", shards_path, 57 | "-collection", "JsonCollection", 58 | "-threads", f"{n_threads}" if n_threads!=-1 else f"{os.cpu_count()}", 59 | "-generator", "DefaultLuceneDocumentGenerator", 60 | "-index", index_path, 61 | "-memorybuffer", memory_buffer, 62 | ]) 63 | variables = [ 64 | "analyzeWithHuggingFaceTokenizer", "language", "stemmer" 65 | ] 66 | additional_args = [[f"-{var}", params[var]] for var in variables if params[var]] 67 | args.extend(chain.from_iterable(additional_args)) 68 | 69 | if fields: 70 | args.extend(["-fields", " ".join(fields)]) 71 | 72 | flags = [ 73 | "pretokenized", "storePositions", "storeDocvectors", 74 | "storeContents", "storeRaw", "optimize", "verbose", "quiet" 75 | ] 76 | args.extend([f"-{flag}" for flag in flags if params[flag]]) 77 | 78 | if for_otf_indexing: 79 | args = args[4:] 80 | return args 81 | 82 | 83 | def index_json_shards( 84 | shards_path: str, 85 | index_path: str, 86 | keep_shards: bool = True, 87 | fields: List[str] = None, 88 | language: str = "en", 89 | pretokenized: bool = False, 90 | analyzeWithHuggingFaceTokenizer: str = None, 91 | storePositions: bool = True, 92 | storeDocvectors: bool = False, 93 | storeContents: bool = False, 94 | storeRaw: bool = False, 95 | keepStopwords: bool = False, 96 | stopwords: str = None, 97 | stemmer: Literal["porter", "krovetz"] = "porter", 98 | optimize: bool = True, 99 | verbose: bool = False, 100 | quiet: bool = False, 101 | memory_buffer: str = "4096", 102 | n_threads: bool = 5, 103 | ): 104 | """Index dataset from a directory containing files 105 | 106 | Parameters 107 | ---------- 108 | shards_path : str 109 | Path to dataset to index 110 | index_path : str 111 | Directory to store index 112 | keep_shards : bool 113 | If False, remove dataset after indexing is complete 114 | 115 | See [docs](../../docs/arguments.md) for remaining argument definitions 116 | 117 | Returns 118 | ------- 119 | None 120 | """ 121 | args = parse_args(**locals()) 122 | JIndexCollection = autoclass('io.anserini.index.IndexCollection') 123 | JIndexCollection.main(args) 124 | if not keep_shards: 125 | shutil.rmtree(shards_path) 126 | 127 | return None 128 | 129 | 130 | def index_streaming_dataset( 131 | index_path: str, 132 | dataset_name_or_path: str, 133 | split: str, 134 | column_to_index: List[str], 135 | doc_id_column: str = None, 136 | ds_config_name: str = None, # For HF Dataset 137 | num_rows: int = -1, 138 | disable_tqdm: bool = False, 139 | language: str = "en", 140 | pretokenized: bool = False, 141 | analyzeWithHuggingFaceTokenizer: str = None, 142 | storePositions: bool = True, 143 | storeDocvectors: bool = False, 144 | storeContents: bool = False, 145 | storeRaw: bool = False, 146 | keepStopwords: bool = False, 147 | stopwords: str = None, 148 | stemmer: Literal["porter", "krovetz"] = "porter", 149 | optimize: bool = True, 150 | verbose: bool = False, 151 | quiet: bool = True, 152 | memory_buffer: str = "4096", 153 | n_threads: bool = 5, 154 | ): 155 | """Stream dataset from HuggingFace Hub & index 156 | 157 | Parameters 158 | ---------- 159 | dataset_name_or_path : str 160 | Name of HuggingFace dataset to stream 161 | split : str 162 | Split of dataset to index 163 | column_to_index : List[str] 164 | Column of dataset to index 165 | doc_id_column : str 166 | Column of dataset to 
use as document ID 167 | ds_config_name: str 168 | Dataset configuration to stream. Usually a language name or code 169 | num_rows : int 170 | Number of rows in dataset 171 | 172 | See [docs](../../docs/arguments.md) for remaining argument definitions 173 | 174 | Returns 175 | ------- 176 | None 177 | """ 178 | 179 | args = parse_args(**locals(), for_otf_indexing=True) 180 | if os.path.exists(dataset_name_or_path): 181 | ds = load_from_local(dataset_name_or_path, split=split, streaming=True) 182 | else: 183 | ds = load_from_hub(dataset_name_or_path, split=split,config_name=ds_config_name, streaming=True) 184 | 185 | indexer = LuceneIndexer(args=args) 186 | 187 | for i, row in tqdm(enumerate(ds), total=num_rows, disable=disable_tqdm): 188 | contents = " ".join([row[column] for column in column_to_index]) 189 | indexer.add_doc_raw(json.dumps({"id": i if not doc_id_column else row[doc_id_column] , "contents": contents})) 190 | 191 | indexer.close() 192 | 193 | return None 194 | 195 | 196 | def fetch_index_stats(index_path: str) -> Dict[str, Any]: 197 | """ 198 | Fetch index statistics 199 | index_path : str 200 | Path to index directory 201 | Returns 202 | ------- 203 | Dictionary of index statistics 204 | Dictionary Keys ==> total_terms, documents, unique_terms 205 | """ 206 | assert os.path.exists(index_path), f"Index path {index_path} does not exist" 207 | index_reader = IndexReader(index_path) 208 | return index_reader.stats() 209 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /src/spacerini/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | from pathlib import Path 5 | from shutil import copytree 6 | 7 | from spacerini.frontend import create_app, create_space_from_local 8 | from spacerini.index import index_streaming_dataset 9 | from spacerini.index.encode import encode_json_dataset 10 | from spacerini.prebuilt import EXAMPLES 11 | 12 | 13 | def update_args_from_json(_args: argparse.Namespace, file: str) -> argparse.Namespace: 14 | config = json.load(open(file, "r")) 15 | config = {k.replace("-", "_"): v for k,v in config.items()} 16 | 17 | args_dict = vars(_args) 18 | args_dict.update(config) 19 | return _args 20 | 21 | 22 | # TODO: @theyorubayesian: Switch to HFArgumentParser with post init 23 | # How to use multiple argparse commands with HFArgumentParser 24 | def get_args() -> argparse.Namespace: 25 | parser = argparse.ArgumentParser( 26 | prog="Spacerini", 27 | description="A modular framework for seamless building and deployment of interactive search applications.", 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 29 | epilog="Written by: Akintunde 'theyorubayesian' Oladipo , Christopher Akiki , Odunayo Ogundepo " 30 | ) 31 | parser.add_argument("--space-name", required=False) 32 | parser.add_argument("--config-file", type=str, help="Path to configuration for space") 33 | parser.add_argument("--from-example", type=str, choices=list(EXAMPLES), help="Name of an example spaces applications.") 34 | parser.add_argument("--verbose", type=bool, default=True, help="If True, print verbose output") 35 | parser.add_argument("--index-exists", type=str, help="Path to an existing index to be used for Spaces application") 36 | 37 | # -------- 38 | # Commands 39 | # -------- 40 | sub_parser = parser.add_subparsers(dest="command", title="Commands", description="Valid commands") 41 | index_parser = sub_parser.add_parser("index", help="Index dataset. This will not create the app or deploy it.") 42 | create_parser = sub_parser.add_parser("create-space", help="Create space in local directory. 
This won't deploy the space.") 43 | deploy_only_parser = sub_parser.add_parser("deploy-only", help="Deploy an already created space") 44 | deploy_parser = sub_parser.add_parser("deploy", help="Deploy new space.") 45 | deploy_parser.add_argument("--delete-after", type=bool, default=False, help="If True, delete the local directory after pushing it to the Hub.") 46 | # -------- 47 | 48 | space_args = parser.add_argument_group("Space arguments") 49 | space_args.add_argument("--space-title", required=False, help="Title to show on Space") 50 | space_args.add_argument("--space-url-slug", required=False, help="") 51 | space_args.add_argument("--sdk", default="gradio", choices=["gradio", "streamlit"]) 52 | space_args.add_argument("--template", default="streamlit", help="A directory containing a project template directory, or a URL to a git repository.") 53 | space_args.add_argument("--organization", required=False, help="Organization to deploy new space under.") 54 | space_args.add_argument("--description", type=str, help="Description of new space") 55 | space_args.add_argument("--private", action="store_true", help="Deploy private Spaces application") 56 | 57 | data_args = parser.add_argument_group("Data arguments") 58 | data_args.add_argument("--columns-to-index", default=[], action="store", nargs="*", help="Other columns to index in dataset") 59 | data_args.add_argument("--split", type=str, required=False, default="train", help="Dataset split to index") 60 | data_args.add_argument("--dataset", type=str, required=False, help="Local dataset folder or Huggingface name") 61 | data_args.add_argument("--docid-column", default="id", help="Name of docid column in dataset") 62 | data_args.add_argument("--language", default="en", help="ISO Code for language of dataset") 63 | data_args.add_argument("--title-column", type=str, help="Name of title column in data") 64 | data_args.add_argument("--content-column", type=str, default="content", help="Name of content column in data") 65 | data_args.add_argument("--expand-column", type=str, help="Name of column containing document expansion. 
Used in dense indexes") 66 | 67 | sparse_index_args = parser.add_argument_group("Sparse Index arguments") 68 | sparse_index_args.add_argument("--collection", type=str, help="Collection class") 69 | sparse_index_args.add_argument("--memory-buffer", type=str, help="Memory buffer size") 70 | sparse_index_args.add_argument("--threads", type=int, default=5, help="Number of threads to use for indexing") 71 | sparse_index_args.add_argument("--hf-tokenizer", type=str, default=None, help="HuggingFace tokenizer to tokenize dataset") 72 | sparse_index_args.add_argument("--pretokenized", type=bool, default=False, help="If True, dataset is already tokenized") 73 | sparse_index_args.add_argument("--store-positions", action="store_true", help="If True, store document vectors in index") # TODO: @theyorubayesian 74 | sparse_index_args.add_argument("--store-docvectors", action="store_true", help="If True, store document vectors in index") # TODO: @theyorubayesian 75 | sparse_index_args.add_argument("--store-contents", action="store_true", help="If True, store contents of documents in index") 76 | sparse_index_args.add_argument("--store-raw", action="store_true", help="If True, store raw contents of documents in index") 77 | sparse_index_args.add_argument("--keep-stopwords", action="store_true", help="If True, keep stopwords in index") 78 | sparse_index_args.add_argument("--optimize-index", action="store_true", help="If True, optimize index after indexing is complete") 79 | sparse_index_args.add_argument("--stopwords", type=str, help="Path to stopwords file") 80 | sparse_index_args.add_argument("--stemmer", type=str, nargs=1, choices=["porter", "krovetz"], help="Stemmer to use for indexing") 81 | 82 | dense_index_args = parser.add_argument_group("Dense Index arguments") 83 | dense_index_args.add_argument("--encoder-name-or-path", type=str, help="Encoder name or path") 84 | dense_index_args.add_argument("--encoder-class", default="auto", type=str, choices=["dpr", "bpr", "tct_colbert", "ance", "sentence-transformers", "auto"], help="Encoder to use") 85 | dense_index_args.add_argument("--delimiter", default="\n", type=str, help="Delimiter for the fields in encoded corpus") 86 | dense_index_args.add_argument("--index-shard-id", default=0, type=int, help="Zero-based index shard id") 87 | dense_index_args.add_argument("--n-index-shards", type=int, default=1, help="Number of index shards") 88 | dense_index_args.add_argument("--batch-size", default=64, type=int, help="Batch size for encoding") 89 | dense_index_args.add_argument("--max-length", type=int, default=256, help="Max document length to encode") 90 | dense_index_args.add_argument('--device', default='cuda:0', type=str, help='Device: cpu or cuda [cuda:0, cuda:1...]', required=False) 91 | dense_index_args.add_argument("--dimension", default=768, type=int, help="Dimension for Faiss Index") 92 | dense_index_args.add_argument("--add-sep", action="store_true", help="Pass `title` and `content` columns separately into encode function") 93 | dense_index_args.add_argument("--to-faiss", action="store_true", help="Store embeddings in Faiss Index") 94 | dense_index_args.add_argument("--fp16", action="store_true", help="Use FP 16") 95 | 96 | search_args = parser.add_argument_group("Search arguments") 97 | search_args.add_argument("--bm25_k1", type=float, help="BM25: k1 parameter") 98 | search_args.add_argument("--bm24_b", type=float, help="BM25: b parameter") 99 | 100 | args, _ = parser.parse_known_args() 101 | 102 | if args.from_example in EXAMPLES: 103 | 
example_config_path = Path(__file__).parents[2] / "examples" \ 104 | / "configs" / f"{args.from_example}.json" 105 | args = update_args_from_json(args, example_config_path) 106 | 107 | # For customization, user provided config file supersedes example config file 108 | if args.config_file: 109 | args = update_args_from_json(args, args.config_file) 110 | 111 | args.template = "templates/gradio_roots_temp" 112 | return args 113 | 114 | 115 | def main(): 116 | args = get_args() 117 | 118 | local_app_dir = Path(f"apps/{args.space_name}") 119 | local_app_dir.mkdir(exist_ok=True) 120 | 121 | columns = [args.content_column, *args.columns_to_index] 122 | 123 | if args.command in ["index", "create-space", "deploy"]: 124 | logging.info(f"Indexing {args.dataset} dataset into {str(local_app_dir)}") 125 | 126 | if args.index_exists: 127 | index_dir = Path(args.index_exists) 128 | assert index_dir.exists(), f"No index found at {args.index_exists}" 129 | copytree(index_dir, local_app_dir / index_dir.name, dirs_exist_ok=True) 130 | else: 131 | # TODO: @theyorubayesian 132 | # We always create a sparse index because dense index only contain docid. Does this make sense memory-wise? 133 | # Another option could be to load the dataset from huggingface and filter documents by docids retrieved 134 | # Can documents be stored in dense indexes? Does it make sense to? 135 | index_streaming_dataset( 136 | dataset_name_or_path=args.dataset, 137 | index_path=(local_app_dir / "sparse_index").as_posix(), 138 | split=args.split, 139 | column_to_index=columns, 140 | doc_id_column=args.docid_column, 141 | language=args.language, 142 | storeContents=args.store_contents, 143 | storeRaw=args.store_raw, 144 | analyzeWithHuggingFaceTokenizer=args.hf_tokenizer 145 | ) 146 | 147 | if args.encoder_name_or_path: 148 | encode_json_dataset( 149 | data_path=args.dataset, 150 | encoder_name_or_path=args.encoder_name_or_path, 151 | encoder_class=args.encoder_class, 152 | embedding_dir=(local_app_dir / "dense_index").as_posix(), 153 | batch_size=args.batch_size, 154 | device=args.device, 155 | index_shard_id=args.index_shard_id, 156 | num_index_shards=args.n_index_shards, 157 | delimiter=args.delimiter, 158 | max_length=args.max_length, 159 | add_sep=args.add_sep, 160 | title_column_to_encode=args.title_column, 161 | text_column_to_encode=args.content_column, 162 | expand_column_to_encode=args.expand_column, 163 | output_to_faiss=True, 164 | embedding_dimension=args.dimension, 165 | fp16=args.fp16 166 | ) 167 | 168 | if args.command in ["create-space", "deploy"]: 169 | logging.info(f"Creating local app into {args.space_name} directory") 170 | # TODO: @theyorubayesian - How to make cookiecutter_vars more flexible 171 | cookiecutter_vars = { 172 | "dset_text_field": columns, 173 | "space_title": args.space_title, 174 | "local_app": args.space_name, 175 | "space_description": args.description, 176 | "dataset_name": args.dataset 177 | } 178 | 179 | create_app( 180 | template=args.template, 181 | extra_context_dict=cookiecutter_vars, 182 | output_dir="apps" 183 | ) 184 | 185 | if args.command in ["deploy", "deploy-only"]: 186 | logging.info(f"Creating space {args.space_name} on {args.organization}") 187 | create_space_from_local( 188 | space_slug=args.space_url_slug, 189 | organization=args.organization, 190 | space_sdk=args.sdk, 191 | local_dir=local_app_dir, 192 | delete_after_push=args.delete_after, 193 | private=args.private 194 | ) 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | 
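200 | # Example invocations (illustrative sketch only: the dataset path and column names below are
201 | # hypothetical, and the module is assumed to be importable as `spacerini.cli` once the package
202 | # is installed):
203 | #
204 | #   python -m spacerini.cli --space-name my-search --dataset data/corpus.jsonl \
205 | #       --content-column contents --docid-column docid index
206 | #
207 | #   python -m spacerini.cli --space-name my-search --space-title "My Search" \
208 | #       --space-url-slug my-search --organization my-org --sdk gradio \
209 | #       --dataset data/corpus.jsonl --content-column contents deploy
210 | #
211 | # Options declared on the parent parser come before the subcommand (`index`, `create-space`,
212 | # `deploy`, `deploy-only`): `index` only builds a Lucene index under apps/<space-name>, while
213 | # `deploy` also renders the cookiecutter template and pushes the resulting app to the Hub.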
--------------------------------------------------------------------------------