├── tests
│   ├── __init__.py
│   ├── test_data.py
│   ├── tests
│   │   └── test_frontend.py
│   ├── data
│   │   └── sample_documents.jsonl
│   └── test_index.py
├── src
│   └── spacerini
│       ├── data
│       │   ├── utils.py
│       │   ├── __main__.py
│       │   ├── __init__.py
│       │   └── load.py
│       ├── preprocess
│       │   ├── __main__.py
│       │   ├── __init__.py
│       │   ├── tokenize.py
│       │   └── utils.py
│       ├── spacerini_utils
│       │   ├── __init__.py
│       │   ├── index.py
│       │   └── search.py
│       ├── prebuilt.py
│       ├── __init__.py
│       ├── search
│       │   ├── __init__.py
│       │   └── utils.py
│       ├── frontend
│       │   ├── __init__.py
│       │   ├── local.py
│       │   ├── space.py
│       │   └── __main__.py
│       ├── index
│       │   ├── __init__.py
│       │   ├── utils.py
│       │   ├── encode.py
│       │   └── index.py
│       └── cli.py
├── templates
│   ├── streamlit
│   │   ├── cookiecutter.json
│   │   └── {{ cookiecutter.module_slug }}
│   │       ├── index
│   │       │   └── .gitkeep
│   │       ├── packages.txt
│   │       ├── requirements.txt
│   │       ├── .gitattributes
│   │       ├── README.md
│   │       └── app.py
│   ├── gradio
│   │   ├── {{ cookiecutter.local_app }}
│   │   │   ├── data
│   │   │   │   └── .gitkeep
│   │   │   ├── index
│   │   │   │   └── .gitkeep
│   │   │   ├── packages.txt
│   │   │   ├── requirements.txt
│   │   │   ├── .gitattributes
│   │   │   ├── README.md
│   │   │   └── app.py
│   │   └── cookiecutter.json
│   └── gradio_roots_temp
│       ├── {{ cookiecutter.local_app }}
│       │   ├── index
│       │   │   └── .gitkeep
│       │   ├── packages.txt
│       │   ├── requirements.txt
│       │   ├── .gitattributes
│       │   ├── README.md
│       │   └── app.py
│       └── cookiecutter.json
├── examples
│   ├── configs
│   │   └── xsum.json
│   ├── scripts
│   │   ├── index-push-pull.py
│   │   ├── xsum-demo.py
│   │   └── gradio-demo.py
│   └── notebooks
│       └── indexing_tutorial.ipynb
├── docs
│   └── arguments.md
├── pyproject.toml
├── .gitignore
├── README.md
└── LICENSE
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spacerini/data/utils.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/tests/test_frontend.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spacerini/data/__main__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spacerini/preprocess/__main__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/templates/streamlit/cookiecutter.json:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spacerini/spacerini_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spacerini/prebuilt.py:
--------------------------------------------------------------------------------
1 | EXAMPLES = {
2 | 'xsum',
3 | }
--------------------------------------------------------------------------------
/templates/gradio/{{ cookiecutter.local_app }}/data/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/templates/gradio/{{ cookiecutter.local_app }}/index/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spacerini/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | from . import tokenize, utils
--------------------------------------------------------------------------------
/templates/streamlit/{{ cookiecutter.module_slug }}/index/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/templates/gradio_roots_temp/{{ cookiecutter.local_app }}/index/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/templates/gradio/{{ cookiecutter.local_app }}/packages.txt:
--------------------------------------------------------------------------------
1 | default-jdk
2 |
--------------------------------------------------------------------------------
/src/spacerini/__init__.py:
--------------------------------------------------------------------------------
1 | from . import index, preprocess, frontend, search, data
--------------------------------------------------------------------------------
/templates/streamlit/{{ cookiecutter.module_slug }}/packages.txt:
--------------------------------------------------------------------------------
1 | default-jdk
2 |
--------------------------------------------------------------------------------
/templates/gradio_roots_temp/{{ cookiecutter.local_app }}/packages.txt:
--------------------------------------------------------------------------------
1 | default-jdk
2 |
--------------------------------------------------------------------------------
/templates/streamlit/{{ cookiecutter.module_slug }}/requirements.txt:
--------------------------------------------------------------------------------
1 | pyserini
2 | faiss-cpu
3 | torch
--------------------------------------------------------------------------------
/templates/gradio/{{ cookiecutter.local_app }}/requirements.txt:
--------------------------------------------------------------------------------
1 | pyserini
2 | datasets
3 | faiss-cpu
4 | torch
--------------------------------------------------------------------------------
/templates/gradio/{{ cookiecutter.local_app }}/.gitattributes:
--------------------------------------------------------------------------------
1 | index/**/* filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/src/spacerini/search/__init__.py:
--------------------------------------------------------------------------------
1 | from . import utils
2 | from .utils import init_searcher, result_indices, result_page
--------------------------------------------------------------------------------
/templates/gradio_roots_temp/{{ cookiecutter.local_app }}/requirements.txt:
--------------------------------------------------------------------------------
1 | pyserini
2 | datasets
3 | faiss-cpu
4 | torch
--------------------------------------------------------------------------------
/templates/streamlit/{{ cookiecutter.module_slug }}/.gitattributes:
--------------------------------------------------------------------------------
1 | index/**/* filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/templates/gradio_roots_temp/{{ cookiecutter.local_app }}/.gitattributes:
--------------------------------------------------------------------------------
1 | index/**/* filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/src/spacerini/frontend/__init__.py:
--------------------------------------------------------------------------------
1 | from . import local, space
2 | from .space import create_space_from_local
3 | from .local import create_app
--------------------------------------------------------------------------------
/tests/data/sample_documents.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "doc1", "contents": "contents of doc one."}
2 | {"id": "doc2", "contents": "contents of document two."}
3 | {"id": "doc3", "contents": "here's some text in document three."}
--------------------------------------------------------------------------------
/src/spacerini/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import load, utils
2 | from .load import load_from_hub, load_from_pandas, load_ir_dataset, load_ir_dataset_low_memory, load_ir_dataset_streaming, load_from_local, load_from_sqlite_table
--------------------------------------------------------------------------------
/src/spacerini/preprocess/tokenize.py:
--------------------------------------------------------------------------------
1 | def batch_iterator(hf_dataset, text_field, batch_size=1000):
2 | for i in range(0, len(hf_dataset), batch_size):
3 | yield hf_dataset.select(range(i, i + batch_size))[text_field]
--------------------------------------------------------------------------------
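`batch_iterator` streams batches of raw text out of a Hugging Face dataset, which is the shape expected by the `tokenizers` training API. A minimal sketch (not part of the repository) of how it could be used to train a WordPiece tokenizer; the dataset, column name, and vocabulary size are illustrative:

```python
from datasets import load_dataset
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

from spacerini.preprocess.tokenize import batch_iterator

dset = load_dataset("imdb", split="train")  # any datasets.Dataset with a text column

tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
trainer = trainers.WordPieceTrainer(vocab_size=30_000, special_tokens=["[UNK]"])

# batch_iterator yields lists of raw strings, batch_size rows at a time
tokenizer.train_from_iterator(batch_iterator(dset, text_field="text"), trainer=trainer)
```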
/src/spacerini/index/__init__.py:
--------------------------------------------------------------------------------
1 | from . import index
2 | from .encode import encode_json_dataset
3 | from .index import index_json_shards, index_streaming_dataset
4 | from .index import fetch_index_stats
5 | from .utils import push_index_to_hub, load_index_from_hub
6 |
--------------------------------------------------------------------------------
/templates/gradio/cookiecutter.json:
--------------------------------------------------------------------------------
1 | {
2 | "dset_text_field": null,
3 | "metadata_field": null,
4 | "space_title": null,
5 | "local_app": "{{ cookiecutter.space_title|lower|replace(' ', '-') }}",
6 | "space_gradio_sdk_version": "3.29.0",
7 | "space_license": "apache-2.0",
8 | "space_pinned": "false"
9 | }
10 |
--------------------------------------------------------------------------------
/templates/gradio_roots_temp/cookiecutter.json:
--------------------------------------------------------------------------------
1 | {
2 | "space_title": null,
3 | "local_app": "{{ cookiecutter.space_title|lower|replace(' ', '-') }}",
4 | "space_gradio_sdk_version": "3.12.0",
5 | "space_license": "apache-2.0",
6 | "space_pinned": "false",
7 | "emoji": "🚀",
8 | "space_description": null,
9 | "private": false,
10 | "dataset_name": null
11 | }
--------------------------------------------------------------------------------
/templates/streamlit/{{ cookiecutter.module_slug }}/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: {{cookiecutter.space_title}}
3 | emoji: 🐠
4 | colorFrom: blue
5 | colorTo: blue
6 | sdk: streamlit
7 | sdk_version: 1.25.0
8 | app_file: app.py
9 | pinned: false
10 | license: {{cookiecutter.space_license}}
11 | ---
12 |
13 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14 |
--------------------------------------------------------------------------------
/examples/configs/xsum.json:
--------------------------------------------------------------------------------
1 | {
2 | "delete_after": true,
3 | "space_name": "xsum",
4 | "space_url_slug": "xsumcli",
5 | "sdk": "gradio",
6 | "organization": null,
7 | "description": null,
8 | "content_column": "document",
9 | "dataset": "xsum",
10 | "docid-column": "id",
11 | "split": "test",
12 | "template": "gradio_roots_temp",
13 | "store_raw": true,
14 | "store_contents": true
15 | }
--------------------------------------------------------------------------------
/templates/gradio/{{ cookiecutter.local_app }}/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: {{ cookiecutter.space_title }}
3 | emoji: 🐠
4 | colorFrom: blue
5 | colorTo: blue
6 | sdk: gradio
7 | sdk_version: {{ cookiecutter.space_gradio_sdk_version }}
8 | app_file: app.py
9 | pinned: {{ cookiecutter.space_pinned }}
10 | license: {{ cookiecutter.space_license }}
11 | ---
12 |
13 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
--------------------------------------------------------------------------------
/templates/gradio_roots_temp/{{ cookiecutter.local_app }}/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: {{ cookiecutter.space_title }}
3 | emoji: {{ cookiecutter.emoji }}
4 | colorFrom: blue
5 | colorTo: blue
6 | sdk: gradio
7 | sdk_version: {{ cookiecutter.space_gradio_sdk_version }}
8 | app_file: app.py
9 | pinned: {{ cookiecutter.space_pinned }}
10 | license: {{ cookiecutter.space_license }}
11 | ---
12 |
13 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
--------------------------------------------------------------------------------
/src/spacerini/spacerini_utils/index.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any
3 | from typing import Dict
4 |
5 | from pyserini.index.lucene import IndexReader
6 |
7 |
8 | def fetch_index_stats(index_path: str) -> Dict[str, Any]:
9 | """
10 | Fetch index statistics
11 | index_path : str
12 | Path to index directory
13 | Returns
14 | -------
15 | Dictionary of index statistics
16 | Dictionary Keys ==> total_terms, documents, unique_terms
17 | """
18 | assert os.path.exists(index_path), f"Index path {index_path} does not exist"
19 | index_reader = IndexReader(index_path)
20 | return index_reader.stats()
21 |
--------------------------------------------------------------------------------
/docs/arguments.md:
--------------------------------------------------------------------------------
1 | # List of Parameters
2 |
3 | ## Indexing (`spacerini.index`)
4 |
5 | - `disable_tqdm` : bool
6 | Disable tqdm output
7 | - `index_path` : str
8 | Directory to store index
9 | - `language` : str
10 | Language of dataset
11 | - `pretokenized` : bool
12 | If True, dataset is already tokenized
13 | - `analyzeWithHuggingFaceTokenizer` : str
14 | Name or path of the Hugging Face tokenizer used to analyze (tokenize) the dataset
15 | - `storePositions` : bool
16 | If True, store positions of tokens in index
17 | - `storeDocvectors` : bool
18 | If True, store document vectors in index
19 | - `storeContents` : bool
20 | If True, store contents of documents in index
21 | - `storeRaw` : bool
22 | If True, store raw contents of documents in index
23 | - `keepStopwords` : bool
24 | If True, keep stopwords in index
25 | - `stopwords` : str
26 | Path to stopwords file
27 | - `stemmer` : str
28 | Stemmer to use for indexing
29 | - `optimize` : bool
30 | If True, optimize index after indexing is complete
31 | - `verbose` : bool
32 | If True, print verbose output
33 | - `quiet` : bool
34 | If True, print no output
35 | - `memory_buffer` : str
36 | Memory buffer size
37 | - `n_threads` : int
38 | Number of threads to use for indexing
--------------------------------------------------------------------------------
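For orientation, here is a hypothetical sketch (not from the repository) of how a subset of these parameters is supplied in practice, mirroring the `index_streaming_dataset` call used in `tests/test_index.py`; the dataset name and paths are placeholders:

```python
from spacerini.index import index_streaming_dataset

index_streaming_dataset(
    dataset_name_or_path="imdb",      # Hugging Face dataset name or local JSONL path
    index_path="indexes/imdb-train",  # index_path: directory to store the index
    split="train",
    column_to_index=["text"],
    language="en",                    # language of the dataset
    storeContents=True,               # store contents of documents in the index
    storeRaw=True,                    # store raw contents of documents in the index
)
```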
/templates/streamlit/{{ cookiecutter.module_slug }}/app.py:
--------------------------------------------------------------------------------
1 | # This currently contains Odunayo's template. We need to adapt this to cookiecutter.
2 | import streamlit as st
3 | from pyserini.search.lucene import LuceneSearcher
4 | import json
5 | import time
6 |
7 | st.set_page_config(page_title="{{title}}", page_icon='', layout="centered")
8 | searcher = LuceneSearcher('{{index_path}}')
9 |
10 |
11 | col1, col2 = st.columns([9, 1])
12 | with col1:
13 | search_query = st.text_input(label="", placeholder="Search")
14 |
15 | with col2:
16 | st.write('#')
17 | button_clicked = st.button("🔎")
18 |
19 |
20 | if search_query or button_clicked:
21 | num_results = None
22 |
23 | t_0 = time.time()
24 | search_results = searcher.search(search_query, k=100_000)
25 | search_time = time.time() - t_0
26 |
27 | st.write(f'Retrieved {len(search_results):,.0f} documents in {search_time*1000:.2f} ms', unsafe_allow_html=True)
28 | for result in search_results[:10]:
29 | result = json.loads(result.raw)
30 | doc = result["contents"]
31 | result_id = result["id"]
32 | try:
33 | st.write(doc[:1000], unsafe_allow_html=True)
34 | st.write(f'Document ID: {result_id}', unsafe_allow_html=True)
35 |
36 | except:
37 | pass
38 |
39 | st.write('---')
--------------------------------------------------------------------------------
/examples/scripts/index-push-pull.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from spacerini.preprocess.utils import shard_dataset
3 | from spacerini.index.index import index_json_shards
4 | from spacerini.index.utils import push_index_to_hub, load_index_from_hub
5 | from spacerini.search.utils import result_indices
6 |
7 | DSET = "imdb"
8 | SPLIT = "train"
9 | COLUMN_TO_INDEX = "text"
10 | NUM_PROC = 28
11 | SHARDS_PATH = f"{DSET}-json-shards"
12 | TEST_QUERY = "great movie"
13 | NUM_RESULTS = 5
14 | INDEX_PATH = "index"
15 | DATASET_SLUG = "lucene-imdb-train"
16 |
17 | dset = load_dataset(
18 | DSET,
19 | split=SPLIT
20 | )
21 |
22 | shard_dataset(
23 | hf_dataset=dset,
24 | shard_size="10MB",
25 | column_to_index=COLUMN_TO_INDEX,
26 | shards_paths=SHARDS_PATH
27 | )
28 |
29 | index_json_shards(
30 | shards_path=SHARDS_PATH,
31 | keep_shards=False,
32 | index_path= INDEX_PATH,
33 | language="en",
34 | n_threads=NUM_PROC
35 | )
36 |
37 | print(
38 | f"First {NUM_RESULTS} results for query: \"{TEST_QUERY}\"",
39 | result_indices(TEST_QUERY, NUM_RESULTS, INDEX_PATH)
40 | )
41 |
42 | push_index_to_hub(
43 | dataset_slug=DATASET_SLUG,
44 | index_path="index",
45 | delete_after_push=True
46 | )
47 |
48 | new_index_path = load_index_from_hub(DATASET_SLUG)
49 | print(
50 | f"First {NUM_RESULTS} results for query: \"{TEST_QUERY}\"",
51 | result_indices(TEST_QUERY, NUM_RESULTS, new_index_path)
52 | )
--------------------------------------------------------------------------------
/src/spacerini/frontend/local.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from shutil import copytree
3 |
4 | from cookiecutter.main import cookiecutter
5 |
6 | default_templates_dir = (Path(__file__).parents[3] / "templates").resolve()
7 | LOCAL_TEMPLATES = ["gradio", "streamlit", "gradio_roots_temp"]
8 |
9 |
10 | def create_app(
11 | template: str,
12 | extra_context_dict: dict,
13 | output_dir: str,
14 | no_input: bool = True,
15 | overwrite_if_exists: bool=True
16 | ) -> None:
17 | """
18 | Create a new app from a template.
19 | Parameters
20 | ----------
21 | template : str
22 | The name of the template to use.
23 | extra_context_dict : dict
24 | The extra context to pass to the template.
25 | output_dir : str
26 | The output directory.
27 | no_input : bool, optional
28 | If True, do not prompt for parameters and only use the template defaults and `extra_context_dict`.
29 | overwrite_if_exists : bool, optional
30 | If True, overwrite the output directory if it already exists.
31 | Returns
32 | -------
33 | None
34 | """
35 | cookiecutter(
36 | "https://github.com/castorini/hf-spacerini.git/" if template in LOCAL_TEMPLATES else template,
37 | directory="templates/" + template if template in LOCAL_TEMPLATES else None,
38 | no_input=no_input,
39 | extra_context=extra_context_dict,
40 | output_dir=output_dir,
41 | overwrite_if_exists=overwrite_if_exists,
42 | )
43 |
44 | utils_dir = Path(__file__).parents[1].resolve() / "spacerini_utils"
45 | app_dir = Path(output_dir) / extra_context_dict["local_app"] / "spacerini_utils"
46 | copytree(utils_dir, app_dir, dirs_exist_ok=True)
47 |
48 | return None
49 |
--------------------------------------------------------------------------------
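A minimal usage sketch of `create_app`, mirroring `examples/scripts/gradio-demo.py`; the context values and output directory below are placeholders:

```python
from spacerini.frontend.local import create_app

create_app(
    template="gradio",  # one of the bundled templates, or a path/URL to a custom one
    extra_context_dict={
        "dset_text_field": "text",
        "metadata_field": "docid",
        "space_title": "My Corpus Search",
        "local_app": "my-corpus-search",
    },
    output_dir="apps",  # rendered to apps/my-corpus-search, with spacerini_utils copied in
)
```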
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=64.0.0", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 |
6 | [project]
7 | name = "spacerini"
8 | description = "Hosted Lucene indexes with Pyserini and Hugging Face Spaces"
9 | readme = { file = "README.md", content-type = "text/markdown" }
10 | maintainers = [
11 | { name = "Christopher Akiki", email = "christopher.akiki@gmail.com" },
12 | { name = "Ogundepo Odunayo", email = "ogundepoodunayo@gmail.com"},
13 | { name = "Akintunde Oladipo", email = "akin.o.oladipo@gmail.com"}
14 | ]
15 | requires-python = ">=3.8"
16 | dependencies = [
17 | 'pyserini',
18 | 'cookiecutter',
19 | 'huggingface_hub',
20 | 'tokenizers',
21 | 'datasets>=2.8.0',
22 | 'ir_datasets',
23 | 'streamlit',
24 | 'gradio',
25 | 'torch',
26 | 'faiss-cpu',
27 | ]
28 | dynamic = [
29 | "version",
30 | ]
31 | classifiers = [
32 | 'Development Status :: 3 - Alpha',
33 | 'Intended Audience :: Developers',
34 | 'Intended Audience :: Information Technology',
35 | 'Intended Audience :: Science/Research',
36 | 'License :: OSI Approved :: Apache Software License',
37 | 'Programming Language :: Python',
38 | 'Topic :: Software Development :: Libraries :: Python Modules',
39 | 'Operating System :: OS Independent',
40 | 'Programming Language :: Python :: 3',
41 | 'Programming Language :: Python :: 3.8',
42 | 'Programming Language :: Python :: 3.9',
43 | 'Programming Language :: Python :: 3.10',
44 | ]
45 | license = { text = "Apache-2.0" }
46 |
47 | [project.optional-dependencies]
48 | dev = [
49 | 'pytest',
50 | ]
51 |
52 | [project.urls]
53 | Homepage = "https://github.com/castorini/hf-spacerini"
54 |
55 | [project.scripts]
56 | spacerini = "spacerini.cli:main"
--------------------------------------------------------------------------------
/examples/scripts/xsum-demo.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains a demo of Spacerini using the test split of the XSum dataset.
3 | """
4 | import os
5 | import logging
6 | from spacerini.frontend import create_app, create_space_from_local
7 | from spacerini.index import index_streaming_dataset
8 |
9 | logging.basicConfig(level=logging.INFO)
10 |
11 | DATASET = "xsum"
12 | SPLIT = "test"
13 | SPACE_TITLE = "XSum Train Dataset Search"
14 | COLUMN_TO_INDEX = ["document"]
15 | LOCAL_APP = "xsum-demo"
16 | SDK = "gradio"
17 | TEMPLATE = "gradio_roots_temp"
18 | ORGANIZATION = "ToluClassics"
19 |
20 |
21 | cookiecutter_vars = {
22 | "dset_text_field": COLUMN_TO_INDEX,
23 | "space_title": SPACE_TITLE,
24 | "local_app":LOCAL_APP,
25 | "space_description": "This is a demo of Spacerini using the XSum dataset.",
26 | "dataset_name": "xsum"
27 | }
28 |
29 |
30 |
31 | logging.info(f"Creating local app into {LOCAL_APP} directory")
32 | create_app(
33 | template=TEMPLATE,
34 | extra_context_dict=cookiecutter_vars,
35 | output_dir="apps"
36 | )
37 |
38 | logging.info(f"Indexing {DATASET} dataset into {os.path.join('apps', LOCAL_APP, 'index')}")
39 | index_streaming_dataset(
40 | dataset_name_or_path=DATASET,
41 | index_path= os.path.join("apps", LOCAL_APP, "index"),
42 | split=SPLIT,
43 | column_to_index=COLUMN_TO_INDEX,
44 | doc_id_column="id",
45 | storeContents=True,
46 | storeRaw=True,
47 | language="en"
48 | )
49 |
50 | logging.info(f"Creating space {SPACE_TITLE} on {ORGANIZATION}")
51 | create_space_from_local(
52 | space_slug="xsum-test",
53 | organization=ORGANIZATION,
54 | space_sdk=SDK,
55 | local_dir=os.path.join("apps", LOCAL_APP),
56 | delete_after_push=False
57 | )
--------------------------------------------------------------------------------
/src/spacerini/frontend/space.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import shutil
3 |
4 | from huggingface_hub import HfApi, create_repo, upload_folder
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def create_space_from_local(
11 | space_slug: str,
12 | space_sdk: str,
13 | local_dir: str,
14 | private: bool=False,
15 | organization: str=None,
16 | delete_after_push: bool=False,
17 | ) -> str:
18 | """
19 | Create a new space from a local directory.
20 | Parameters
21 | ----------
22 | space_slug : str
23 | The slug of the space.
24 | space_sdk : str
25 | The SDK of the space, could be either Gradio or Streamlit.
26 | local_dir : str
27 | The local directory where the app is currently stored.
28 | private : bool, optional
29 | If True, the space will be private.
30 | organization : str, optional
31 | The organization to create the space in.
32 | delete_after_push : bool, optional
33 | If True, delete the local directory after pushing it to the Hub.
34 |
35 | Returns
36 | -------
37 | repo_url: str
38 | The URL of the space.
39 | """
40 |
41 | if organization is None:
42 | hf_api = HfApi()
43 | namespace = hf_api.whoami()["name"]
44 | else:
45 | namespace = organization
46 | repo_id = namespace + "/" + space_slug
47 | try:
48 | repo_url = create_repo(repo_id=repo_id, repo_type="space", space_sdk=space_sdk, private=private, exist_ok=True)
49 | except Exception as ex:
50 | logger.error("Encountered an error while creating the space: ", ex)
51 | raise
52 |
53 | upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="space")
54 | if delete_after_push:
55 | shutil.rmtree(local_dir)
56 | return repo_url
57 |
--------------------------------------------------------------------------------
/examples/scripts/gradio-demo.py:
--------------------------------------------------------------------------------
1 | from spacerini.frontend.local import create_app
2 | from spacerini.frontend.space import create_space_from_local
3 | from datasets import load_dataset
4 | from spacerini.preprocess.utils import shard_dataset, get_num_shards
5 | from spacerini.index.index import index_json_shards
6 |
7 | DSET = "imdb"
8 | SPLIT = "train"
9 | SPACE_TITLE = "IMDB search"
10 | COLUMN_TO_INDEX = "text"
11 | METADATA_COLUMNS = ["sentiment", "docid"]
12 | NUM_PROC = 28
13 | SHARDS_PATH = f"{DSET}-json-shards"
14 | LOCAL_APP = "gradio_app"
15 | SDK = "gradio"
16 | ORGANIZATION = "cakiki"
17 | MAX_ARROW_SHARD_SIZE="1GB"
18 |
19 | cookiecutter_vars = {
20 | "dset_text_field": COLUMN_TO_INDEX,
21 | "metadata_field": METADATA_COLUMNS[1],
22 | "space_title": SPACE_TITLE,
23 | "local_app":LOCAL_APP
24 | }
25 | create_app(
26 | template=SDK,
27 | extra_context_dict=cookiecutter_vars,
28 | output_dir="."
29 | )
30 |
31 | dset = load_dataset(
32 | DSET,
33 | split=SPLIT
34 | )
35 |
36 | shard_dataset(
37 | hf_dataset=dset,
38 | shard_size="10MB",
39 | column_to_index=COLUMN_TO_INDEX,
40 | shards_paths=SHARDS_PATH
41 | )
42 |
43 | index_json_shards(
44 | shards_path=SHARDS_PATH,
45 | index_path=LOCAL_APP + "/index",
46 | language="en",
47 | n_threads=NUM_PROC
48 | )
49 |
50 | dset = dset.add_column("docid", range(len(dset)))
51 | num_shards = get_num_shards(dset.data.nbytes, MAX_ARROW_SHARD_SIZE)
52 | dset.remove_columns([c for c in dset.column_names if not c in [COLUMN_TO_INDEX,*METADATA_COLUMNS]]).save_to_disk(
53 | LOCAL_APP + "/data",
54 | num_shards=num_shards,
55 | num_proc=NUM_PROC
56 | )
57 |
58 | create_space_from_local(
59 | space_slug="imdb-search",
60 | organization=ORGANIZATION,
61 | space_sdk=SDK,
62 | local_dir=LOCAL_APP,
63 | delete_after_push=False
64 | )
--------------------------------------------------------------------------------
/src/spacerini/index/utils.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import logging
3 | from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | def push_index_to_hub(
9 | dataset_slug: str,
10 | index_path: str,
11 | private: bool=False,
12 | organization: str=None,
13 | delete_after_push: bool=False,
14 | ) -> str:
15 | """
16 | Push an index as a dataset to the Hugging Face Hub.
17 | ----------
18 | dataset_slug : str
19 | The slug of the space.
20 | index_path : str
21 | The local directory where the app is currently stored.
22 | private : bool, optional
23 | If True, the space will be private.
24 | organization : str, optional
25 | The organization to create the space in.
26 | delete_after_push : bool, optional
27 | If True, delete the local index after pushing it to the Hub.
28 |
29 | Returns
30 | -------
31 | repo_url: str
32 | The URL of the dataset.
33 | """
34 |
35 | if organization is None:
36 | hf_api = HfApi()
37 | namespace = hf_api.whoami()["name"]
38 | else:
39 | namespace = organization
40 | repo_id = namespace + "/" + dataset_slug
41 | try:
42 | repo_url = create_repo(repo_id=repo_id, repo_type="dataset", private=private)
43 | except Exception as ex:
44 | logger.error("Encountered an error while creating the dataset repository: ", ex)
45 | raise
46 |
47 | upload_folder(folder_path=index_path, path_in_repo="index", repo_id=repo_id, repo_type="dataset")
48 | if delete_after_push:
49 | shutil.rmtree(index_path)
50 | return repo_url
51 |
52 |
53 | def load_index_from_hub(dataset_slug: str, organization: str=None) -> str:
54 | if organization is None:
55 | hf_api = HfApi()
56 | namespace = hf_api.whoami()["name"]
57 | else:
58 | namespace = organization
59 | repo_id = namespace + "/" + dataset_slug
60 |
61 | local_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
62 | index_path = local_path + "/index"
63 | return index_path
64 |
--------------------------------------------------------------------------------
/src/spacerini/preprocess/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from datasets import Dataset
3 | from datasets.utils.py_utils import convert_file_size_to_int
4 | import os
5 |
6 |
7 | def shard_dataset(hf_dataset: Dataset, shard_size: Union[int, str], shards_paths: str, column_to_index: str) -> None: # TODO break this up into smaller functions
8 | """
9 | Shard a dataset into multiple files.
10 | Parameters
11 | ----------
12 | hf_dataset : datasets.Dataset
13 | a Hugging Face datasets object
14 | shard_size : str
15 | The size of each arrow shard that gets written as a JSON file.
16 | shards_paths : str
17 | The path to the directory where the shards will be stored.
18 | column_to_index : str
19 | The name of the dataset column to index.
20 |
21 | Returns
22 | -------
23 | None
24 | """
25 | hf_dataset = hf_dataset.remove_columns([c for c in hf_dataset.column_names if c!=column_to_index])
26 | hf_dataset = hf_dataset.rename_column(column_to_index, "contents") # pyserini only wants a content column and an index column
27 | hf_dataset = hf_dataset.add_column("id", range(len(hf_dataset)))
28 | num_shards = get_num_shards(hf_dataset.data.nbytes, shard_size)
29 | os.makedirs(shards_paths, exist_ok=True)
30 | for shard_index in range(num_shards):
31 | shard = hf_dataset.shard(num_shards=num_shards, index=shard_index, contiguous=True)
32 | shard.to_json(f"{shards_paths}/docs-{shard_index:03d}.jsonl", orient="records", lines=True)
33 |
34 |
35 | def get_num_shards(dataset_size: Union[int, str], max_shard_size: Union[int, str]) -> int:
36 | """
37 | Returns the number of shards required for a maximum shard size for a datasets.Dataset of a given size.
38 | Parameters
39 | ----------
40 | dataset_size: int or str
41 | The size of the dataset in either number of bytes or a string such as "10MB" or "6GB"
42 | max_shard_size: int or str
43 | The maximum size for every corresponding arrow shard in either number of bytes or a string such as "10MB" or "6GB".
44 |
45 | Returns
46 | -------
47 | int
48 | """
49 | max_shard_size = convert_file_size_to_int(max_shard_size)
50 | dataset_nbytes = convert_file_size_to_int(dataset_size)
51 | num_shards = int(dataset_nbytes / max_shard_size) + 1
52 | return max(num_shards, 1)
53 |
--------------------------------------------------------------------------------
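The shard-count arithmetic is easy to sanity-check; a small, self-contained sketch (sizes and paths are illustrative): a 1 GB dataset with 100 MB shards gives `int(1e9 / 1e8) + 1 = 11` shards.

```python
from datasets import Dataset
from spacerini.preprocess.utils import get_num_shards, shard_dataset

assert get_num_shards("1GB", "100MB") == 11  # int(1e9 / 1e8) + 1

# shard_dataset renames the text column to "contents", adds an "id" column,
# and writes Pyserini-ready JSONL shards as shards/docs-000.jsonl, docs-001.jsonl, ...
ds = Dataset.from_dict({"text": ["a tiny document", "another tiny document"]})
shard_dataset(hf_dataset=ds, shard_size="10MB", shards_paths="shards", column_to_index="text")
```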
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | apps/*
132 |
133 | .vscode/
--------------------------------------------------------------------------------
/examples/notebooks/indexing_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from spacerini.index import index_streaming_hf_dataset"
10 | ]
11 | },
12 | {
13 | "attachments": {},
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "In this notebook, we demonstrate how to index the Swahili part of the [`Mr. TyDi corpus`](https://huggingface.co/datasets/castorini/mr-tydi-corpus) on the fly using the `bert-base-multilingual-uncased` tokenizer available on [`HuggingFace`](https://huggingface.co/bert-base-multilingual-uncased). This should take roughly 3mins on a Mac M1.\n",
18 | "\n",
19 | "For a full understanding of arguments passed to Anserini, see [`io.anserini.index.IndexCollection.Args`](https://github.com/castorini/anserini)"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stdout",
29 | "output_type": "stream",
30 | "text": [
31 | "WARNING: sun.reflect.Reflection.getCallerClass is not supported. This will impact performance.\n",
32 | "2023-02-14 00:31:59,337 INFO [main] index.SimpleIndexer (SimpleIndexer.java:120) - Bert Tokenizer\n"
33 | ]
34 | },
35 | {
36 | "name": "stderr",
37 | "output_type": "stream",
38 | "text": [
39 | "136689it [01:36, 1414.10it/s]\n"
40 | ]
41 | }
42 | ],
43 | "source": [
44 | "index_streaming_hf_dataset(\n",
45 | " \"./test-index\",\n",
46 | " \"castorini/mr-tydi-corpus\",\n",
47 | " \"train\",\n",
48 | " \"text\",\n",
49 | " ds_config_name=\"swahili\", # This is used by datasets to load the Swahili split of the mr-tydi-corpus\n",
50 | " language=\"sw\", # This is passed to Anserini and may be used to initialize a Lucene Indexer\n",
51 | " analyzeWithHuggingFaceTokenizer=\"bert-base-multilingual-uncased\"\n",
52 | ")"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": []
61 | }
62 | ],
63 | "metadata": {
64 | "kernelspec": {
65 | "display_name": "pyserini-dev",
66 | "language": "python",
67 | "name": "python3"
68 | },
69 | "language_info": {
70 | "codemirror_mode": {
71 | "name": "ipython",
72 | "version": 3
73 | },
74 | "file_extension": ".py",
75 | "mimetype": "text/x-python",
76 | "name": "python",
77 | "nbconvert_exporter": "python",
78 | "pygments_lexer": "ipython3",
79 | "version": "3.8.15"
80 | },
81 | "orig_nbformat": 4,
82 | "vscode": {
83 | "interpreter": {
84 | "hash": "b0d9161c88aeda2dc6f47d29dac86208dc1568afc6233ac80536ab4d91f86f7a"
85 | }
86 | }
87 | },
88 | "nbformat": 4,
89 | "nbformat_minor": 2
90 | }
91 |
--------------------------------------------------------------------------------
/tests/test_index.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from os import path
4 | import unittest
5 | from spacerini.index import fetch_index_stats, index_streaming_dataset
6 | from pyserini.search.lucene import LuceneSearcher
7 | from typing import List
8 |
9 |
10 | class TestIndex(unittest.TestCase):
11 | def setUp(self):
12 | self.index_path = path.join(path.dirname(__file__), "indexes")
13 | self.dataset_name_or_path = path.join(path.dirname(__file__),"data/sample_documents.jsonl")
14 | os.makedirs(self.index_path, exist_ok=True)
15 |
16 |
17 | def test_index_streaming_hf_dataset_local(self):
18 | """
19 | Test indexing a local dataset
20 | """
21 | local_index_path = path.join(self.index_path, "local")
22 | index_streaming_dataset(
23 | index_path=local_index_path,
24 | dataset_name_or_path=self.dataset_name_or_path,
25 | split="train",
26 | column_to_index=["contents"],
27 | doc_id_column="id",
28 | language="en",
29 | storeContents = True,
30 | storeRaw = True,
31 | num_rows=3
32 | )
33 |
34 | self.assertTrue(os.path.exists(local_index_path))
35 | searcher = LuceneSearcher(local_index_path)
36 |
37 | hits = searcher.search('contents')
38 | self.assertTrue(isinstance(hits, List))
39 | self.assertEqual(hits[0].docid, 'doc1')
40 | self.assertEqual(hits[0].contents, "contents of doc one.")
41 |
42 | index_stats = fetch_index_stats(local_index_path)
43 | self.assertEqual(index_stats["total_terms"], 11)
44 | self.assertEqual(index_stats["documents"], 3)
45 | self.assertEqual(index_stats["unique_terms"], 9)
46 |
47 |
48 | def test_index_streaming_hf_dataset_huggingface(self):
49 | """
50 | Test indexing a dataset from HuggingFace Hub
51 | """
52 | hgf_index_path = path.join(self.index_path, "hgf")
53 | index_streaming_dataset(
54 | index_path=hgf_index_path,
55 | dataset_name_or_path="sciq",
56 | split="test",
57 | column_to_index=["question", "support"],
58 | language="en",
59 | num_rows=1000,
60 | storeContents = True,
61 | storeRaw = True
62 | )
63 |
64 | self.assertTrue(os.path.exists(hgf_index_path))
65 | searcher = LuceneSearcher(hgf_index_path)
66 | index_stats = fetch_index_stats(hgf_index_path)
67 |
68 | hits = searcher.search('contents')
69 | self.assertTrue(isinstance(hits, List))
70 | self.assertEqual(hits[0].docid, '528')
71 | self.assertEqual(index_stats["total_terms"], 54197)
72 | self.assertEqual(index_stats["documents"], 1000)
73 |
74 | def tearDown(self):
75 | shutil.rmtree(self.index_path)
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/src/spacerini/frontend/__main__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | from spacerini.frontend import create_space_from_local, create_app
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | def parser():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument(
12 | "--template",
13 | type=str,
14 | help="The path to a predefined template to use.",
15 | required=True
16 | )
17 | parser.add_argument(
18 | "--extra_context_dict",
19 | type=dict,
20 | help="The extra context to pass to the template.",
21 | )
22 | parser.add_argument(
23 | "--no_input",
24 | type=bool,
25 | help="If True, do not prompt for parameters and only use",
26 | )
27 | parser.add_argument(
28 | "--overwrite_if_exists",
29 | type=bool,
30 | help="If True, overwrite the output directory if it already exists.",
31 | default=True
32 | )
33 | parser.add_argument(
34 | "--space_slug",
35 | type=str,
36 | help="The name of the space on huggingface.",
37 | required=True
38 | )
39 | parser.add_argument(
40 | "--space_sdk",
41 | type=str,
42 | help="The SDK of the space, could be either Gradio or Streamlit.",
43 | choices=["gradio", "streamlit"],
44 | required=True
45 | )
46 | parser.add_argument(
47 | "--local_dir",
48 | type=str,
49 | help="The local directory where the app should be stored.",
50 | )
51 | parser.add_argument(
52 | "--private",
53 | type=bool,
54 | help="If True, the space will be private.",
55 | default=False
56 | )
57 | parser.add_argument(
58 | "--organization",
59 | type=str,
60 | help="The organization to create the space in.",
61 | default=None
62 | )
63 | parser.add_argument(
64 | "--delete_after_push",
65 | type=bool,
66 | help="If True, delete the local directory after pushing it to the Hub.",
67 | )
68 | return parser
69 |
70 | def main():
71 | arg_parser = parser()
72 | args = arg_parser.parse_args()
73 |
74 | logger.info("Validating the input arguments...")
75 | assert os.path.exists(args.template), "The template does not exist."
76 | assert os.path.exists(args.local_dir), "The local directory does not exist."
77 |
78 | logger.info("Creating the app locally...")
79 | create_app(
80 | template=args.template,
81 | extra_context_dict=args.extra_context_dict,
82 | output_dir=args.local_dir,
83 | no_input=args.no_input,
84 | overwrite_if_exists=args.overwrite_if_exists
85 | )
86 |
87 | logger.info("Creating the space on huggingface...")
88 | create_space_from_local(
89 | space_slug=args.space_slug,
90 | space_sdk=args.space_sdk,
91 | local_dir=args.local_dir,
92 | private=args.private,
93 | organization=args.organization,
94 | delete_after_push=args.delete_after_push
95 | )
96 |
97 |
98 | if __name__ == "__main__":
99 | main()
--------------------------------------------------------------------------------
/src/spacerini/index/encode.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable
2 | from typing import List
3 | from typing import Literal
4 | from typing import Optional
5 | from typing import Protocol
6 |
7 | from pyserini.encode.__main__ import init_encoder
8 | from pyserini.encode import RepresentationWriter
9 | from pyserini.encode import FaissRepresentationWriter
10 | from pyserini.encode import JsonlCollectionIterator
11 | from pyserini.encode import JsonlRepresentationWriter
12 |
13 | EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"]
14 |
15 |
16 | class Encoder(Protocol):
17 | def encode(**kwargs): ...
18 |
19 |
20 | def init_writer(
21 | embedding_dir: str,
22 | embedding_dimension: int = 768,
23 | output_to_faiss: bool = False
24 | ) -> RepresentationWriter:
25 | """
26 | """
27 | if output_to_faiss:
28 | writer = FaissRepresentationWriter(embedding_dir, dimension=embedding_dimension)
29 | return writer
30 |
31 | return JsonlRepresentationWriter(embedding_dir)
32 |
33 |
34 | def encode_corpus_or_shard(
35 | encoder: Encoder,
36 | collection_iterator: Iterable[dict],
37 | embedding_writer: RepresentationWriter,
38 | batch_size: int,
39 | shard_id: int,
40 | shard_num: int,
41 | max_length: int = 256,
42 | add_sep: bool = False,
43 | input_fields: List[str] = None,
44 | title_column_to_encode: Optional[str] = None,
45 | text_column_to_encode: Optional[str] = "text",
46 | expand_column_to_encode: Optional[str] = None,
47 | fp16: bool = False
48 | ) -> None:
49 | """
50 |
51 | """
52 | with embedding_writer:
53 | for batch_info in collection_iterator(batch_size, shard_id, shard_num):
54 | kwargs = {
55 | 'texts': batch_info[text_column_to_encode],
56 | 'titles': batch_info[title_column_to_encode] if title_column_to_encode else None,
57 | 'expands': batch_info[expand_column_to_encode] if expand_column_to_encode else None,
58 | 'fp16': fp16,
59 | 'max_length': max_length,
60 | 'add_sep': add_sep,
61 | }
62 | embeddings = encoder.encode(**kwargs)
63 | batch_info['vector'] = embeddings
64 | embedding_writer.write(batch_info, input_fields)
65 |
66 |
67 | def encode_json_dataset(
68 | data_path: str,
69 | encoder_name_or_path: str,
70 | encoder_class: EncoderClass,
71 | embedding_dir: str,
72 | batch_size: int,
73 | index_shard_id: int = 0,
74 | num_index_shards: int = 1,
75 | device: str = "cuda:0",
76 | delimiter: str = "\n",
77 | max_length: int = 256,
78 | add_sep: bool = False,
79 | input_fields: List[str] = None,
80 | title_column_to_encode: Optional[str] = None,
81 | text_column_to_encode: Optional[str] = "text",
82 | expand_column_to_encode: Optional[str] = None,
83 | output_to_faiss: bool = False,
84 | embedding_dimension: int = 768,
85 | fp16: bool = False
86 | ) -> None:
87 | """
88 |
89 | """
90 | if input_fields is None:
91 | input_fields = ["text"]
92 |
93 | encoder = init_encoder(encoder_name_or_path, encoder_class, device=device)
94 |
95 | writer = init_writer(
96 | embedding_dir=embedding_dir,
97 | embedding_dimension=embedding_dimension,
98 | output_to_faiss=output_to_faiss
99 | )
100 |
101 | collection_iterator = JsonlCollectionIterator(data_path, input_fields, delimiter)
102 | encode_corpus_or_shard(
103 | encoder=encoder,
104 | collection_iterator=collection_iterator,
105 | embedding_writer=writer,
106 | batch_size=batch_size,
107 | shard_id=index_shard_id,
108 | shard_num=num_index_shards,
109 | max_length=max_length,
110 | add_sep=add_sep,
111 | input_fields=input_fields,
112 | title_column_to_encode=title_column_to_encode,
113 | text_column_to_encode=text_column_to_encode,
114 | expand_column_to_encode=expand_column_to_encode,
115 | fp16=fp16
116 | )
117 |
--------------------------------------------------------------------------------
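A hypothetical invocation of `encode_json_dataset` that writes a FAISS index of dense vectors; the corpus path, encoder checkpoint, and output directory are placeholders:

```python
from spacerini.index.encode import encode_json_dataset

encode_json_dataset(
    data_path="shards",                                              # directory of JSONL documents
    encoder_name_or_path="facebook/dpr-ctx_encoder-single-nq-base",  # any supported encoder checkpoint
    encoder_class="dpr",
    embedding_dir="embeddings",
    batch_size=32,
    input_fields=["text"],
    text_column_to_encode="text",
    output_to_faiss=True,          # write a FAISS index instead of JSONL vectors
    embedding_dimension=768,
    device="cpu",                  # or "cuda:0" when a GPU is available
)
```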
/README.md:
--------------------------------------------------------------------------------
1 | # Spacerini 🦄
2 |
3 | Spacerini is a modular framework for the seamless building and deployment of interactive search applications. It integrates Pyserini and the Hugging Face 🤗 ecosystem to facilitate the qualitative analysis of large-scale research datasets.
4 |
5 | You can index collections and deploy them as ad-hoc search engines for efficient retrieval of relevant documents.
6 | Spacerini provides a customisable, user-friendly interface for auditing massive datasets. Spacerini also allows Information Retrieval (IR) researchers and Search Engineers to demonstrate the capabilities of their indices easily and interactively.
7 |
8 | Spacerini currently supports the use of Gradio and Streamlit to create these search applications. In the future, we will also support deployment of docker containers.
9 |
10 | ## Installation ⚒️
11 |
12 | To get started, create an access token with write access on Hugging Face to enable the creation of Spaces. [Check here](https://huggingface.co/docs/hub/security-tokens) for documentation on user access tokens.
13 |
14 | Run `huggingface-cli login` to register your access token locally.
15 |
16 | ### From Github
17 |
18 | - `pip install git+https://github.com/castorini/hf-spacerini.git`
19 |
20 | ### Development Installation
21 |
22 | You will need a development installation if you are contributing to Spacerini.
23 |
24 | - Create a virtual environment - `conda create --name spacerini python=3.9`
25 | - Clone the repository - `git clone https://github.com/castorini/hf-spacerini.git`
26 | - Install in editable mode - `pip install -e ".[dev]"`
27 |
28 |
29 | ## Creating Spaces applications 🔎
30 |
31 | Spacerini provides flexibility. You can customize every step of the `index-create-deploy` process as your project requires: provide your own index, build a more interactive web application, and deploy changes as necessary.
32 |
33 | Some of the commands that allow this flexibility are:
34 |
35 | * index: Create a sparse, dense or hybrid index from the specified dataset. This does not create a Space.
36 | * create-space: Create an index from the dataset and create a Hugging Face Space application.
37 | * deploy: Create the Hugging Face Space application and deploy it!
38 | * deploy-only: Deploy an already created or recently modified Space application.
39 |
40 | If you have already built an index of your own, you can pass the `--index-exists` flag to any of the listed commands. Run `spacerini --help` for a full list of commands and arguments.
41 |
42 | 1. Getting started
43 |
44 | You can deploy a search system for the test set of the Extreme Summarization [XSUM](https://huggingface.co/datasets/xsum) dataset on Hugging Face using the following command. This system is based on our [gradio_roots_temp](templates/gradio_roots_temp/) template.
45 |
46 | ```bash
47 | spacerini --from-example xsum deploy
48 | ```
49 |
50 | Voilà! You have successfully deployed a search engine on Hugging Face Spaces 🤩🥳! This command downloads the XSUM dataset from the Hugging Face Hub, builds an interactive user interface on top of a sparse index of the dataset, and deploys the application to Hugging Face Spaces. You can find the configuration for the application in [examples/configs/xsum.json](examples/configs/xsum.json).
51 |
52 | 2. Building your own custom application
53 |
54 | The easiest way to build your own application is to provide a JSON configuration containing arguments for indexing, creating your space and deploying it.
55 |
56 | ```bash
57 | spacerini --config-file <config-file> <command>
58 | ```
59 | where `<command>` may be `index`, `create-space`, `deploy` or `deploy-only`.
60 |
61 | It helps to familiarise yourself with the arguments available for each step by running:
62 |
63 | ```bash
64 | spacerini --help
65 | ```
66 |
67 | If you are using a custom template you have built, you can pass it using the `template` argument. Once your application has been created locally, you can run it before deploying.
68 |
69 | ```bash
70 | cd apps/<your-app>
71 | python app.py
72 | ```
73 |
74 | After completing all necessary modifications, run the following command to deploy your Spaces application
75 |
76 | ```bash
77 | spacerini --config-file <config-file> deploy-only
78 | ```
79 |
80 | ## Contribution
81 |
82 | ## Acknowledgement
83 |
--------------------------------------------------------------------------------
/src/spacerini/spacerini_utils/search.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import List, Literal, Protocol, Tuple, TypedDict, Union
3 |
4 | from pyserini.analysis import get_lucene_analyzer
5 | from pyserini.index import IndexReader
6 | from pyserini.search import DenseSearchResult, JLuceneSearcherResult
7 | from pyserini.search.faiss.__main__ import init_query_encoder
8 | from pyserini.search.faiss import FaissSearcher
9 | from pyserini.search.hybrid import HybridSearcher
10 | from pyserini.search.lucene import LuceneSearcher
11 |
12 | EncoderClass = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"]
13 |
14 |
15 | class AnalyzerArgs(TypedDict):
16 | language: str
17 | stemming: bool
18 | stemmer: str
19 | stopwords: bool
20 | huggingFaceTokenizer: str
21 |
22 |
23 | class SearchResult(TypedDict):
24 | docid: str
25 | text: str
26 | score: float
27 | language: str
28 |
29 |
30 | class Searcher(Protocol):
31 | def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
32 | ...
33 |
34 |
35 | def init_searcher_and_reader(
36 | sparse_index_path: str = None,
37 | bm25_k1: float = None,
38 | bm25_b: float = None,
39 | analyzer_args: AnalyzerArgs = None,
40 | dense_index_path: str = None,
41 | encoder_name_or_path: str = None,
42 | encoder_class: EncoderClass = None,
43 | tokenizer_name: str = None,
44 | device: str = None,
45 | prefix: str = None
46 | ) -> Tuple[Union[FaissSearcher, HybridSearcher, LuceneSearcher], IndexReader]:
47 | """
48 | Initialize and return an appropriate searcher
49 |
50 | Parameters
51 | ----------
52 | sparse_index_path: str
53 | Path to sparse index
54 | dense_index_path: str
55 | Path to dense index
56 | encoder_name_or_path: str
57 | Path to query encoder checkpoint or encoder name
58 | encoder_class: str
59 | Query encoder class to use. If None, infer from `encoder`
60 | tokenizer_name: str
61 | Tokenizer name or path
62 | device: str
63 | Device to load Query encoder on.
64 | prefix: str
65 | Query prefix if exists
66 |
67 | Returns
68 | -------
69 | Searcher: FaissSearcher | HybridSearcher | LuceneSearcher
70 | A sparse, dense or hybrid searcher
71 | """
72 | reader = None
73 | if sparse_index_path:
74 | ssearcher = LuceneSearcher(sparse_index_path)
75 | if analyzer_args:
76 | analyzer = get_lucene_analyzer(**analyzer_args)
77 | ssearcher.set_analyzer(analyzer)
78 | if bm25_k1 and bm25_b:
79 | ssearcher.set_bm25(bm25_k1, bm25_b)
80 |
81 | if dense_index_path:
82 | encoder = init_query_encoder(
83 | encoder=encoder_name_or_path,
84 | encoder_class=encoder_class,
85 | tokenizer_name=tokenizer_name,
86 | topics_name=None,
87 | encoded_queries=None,
88 | device=device,
89 | prefix=prefix
90 | )
91 |
92 | reader = IndexReader(sparse_index_path)
93 | dsearcher = FaissSearcher(dense_index_path, encoder)
94 |
95 | if sparse_index_path:
96 | hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher)
97 | return hsearcher, reader
98 | else:
99 | return dsearcher, reader
100 |
101 | return ssearcher, reader
102 |
103 |
104 | def _search(searcher: Searcher, reader: IndexReader, query: str, num_results: int = 10) -> List[SearchResult]:
105 | """
106 | Parameters:
107 | -----------
108 | searcher: FaissSearcher | HybridSearcher | LuceneSearcher
109 | A sparse, dense or hybrid searcher
110 | query: str
111 | Query for which to retrieve results
112 | num_results: int
113 | Maximum number of results to retrieve
114 |
115 | Returns:
116 | --------
117 | Dict:
118 | """
119 | def _get_dict(r: Union[DenseSearchResult, JLuceneSearcherResult]):
120 | if isinstance(r, JLuceneSearcherResult):
121 | return json.loads(r.raw)
122 | elif isinstance(r, DenseSearchResult):
123 | # Get document from sparse_index using index reader
124 | return json.loads(reader.doc(r.docid).raw())
125 |
126 | search_results = searcher.search(query, k=num_results)
127 | all_results = [
128 | SearchResult(
129 | docid=result["id"],
130 | text=result["contents"],
131 | score=search_results[idx].score
132 | ) for idx, result in enumerate(map(lambda r: _get_dict(r), search_results))
133 | ]
134 |
135 | return all_results
136 |
--------------------------------------------------------------------------------
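A usage sketch of the helpers above as they might appear in a rendered app (where `spacerini_utils` is copied next to `app.py`); the index path and query are placeholders, and dense or hybrid search would additionally pass `dense_index_path` and encoder arguments:

```python
from spacerini_utils.search import _search, init_searcher_and_reader

searcher, reader = init_searcher_and_reader(sparse_index_path="index")  # sparse-only: reader is None

for hit in _search(searcher, reader, query="sample query", num_results=5):
    print(hit["docid"], round(hit["score"], 2), hit["text"][:80])
```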
/templates/gradio/{{ cookiecutter.local_app }}/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from datasets import load_from_disk
3 | from pyserini.search.lucene import LuceneSearcher
4 |
5 | searcher = LuceneSearcher("index")
6 | ds = load_from_disk("data")
7 | NUM_PAGES = 10 # STATIC. THIS CAN'T CHANGE BECAUSE GRADIO CAN'T DYNAMICALLY CREATE COMPONENTS.
8 | RESULTS_PER_PAGE = 5
9 |
10 | TEXT_FIELD = "{{ cookiecutter.dset_text_field }}"
11 | METADATA_FIELD = "{{ cookiecutter.metadata_field }}"
12 |
13 |
14 | def result_html(result, meta):
15 | return (
16 | f"{meta}
"
17 | f"{result[:250]}...
{result[250:]}
"
18 | )
19 |
20 |
21 | def format_results(results):
22 | return "\n".join([result_html(result, meta) for result,meta in zip(results[TEXT_FIELD], results[METADATA_FIELD])])
23 |
24 |
25 | def page_0(query):
26 | hits = searcher.search(query, k=NUM_PAGES*RESULTS_PER_PAGE)
27 | ix = [int(hit.docid) for hit in hits]
28 | results = ds.select(ix).shard(num_shards=NUM_PAGES, index=0, contiguous=True) # no need to shard. split ix in batches instead. (would make sense if results was cacheable)
29 | results = format_results(results)
30 | return results, [ix], gr.update(visible=True)
31 |
32 |
33 | def page_i(i, ix):
34 | ix = ix[0]
35 | results = ds.select(ix).shard(num_shards=NUM_PAGES, index=i, contiguous=True)
36 | results = format_results(results)
37 | return results, [ix]
38 |
39 |
40 | with gr.Blocks(css="#b {min-width:15px;background:transparent;border:white;box-shadow:none;}") as demo: #
41 | with gr.Row():
42 | gr.Markdown(value="""## {{ cookiecutter.space_title }}
""")
43 | with gr.Row():
44 | with gr.Column(scale=1):
45 | result_list = gr.Dataframe(type="array", visible=False, col_count=1)
46 | with gr.Column(scale=13):
47 | query = gr.Textbox(lines=1, max_lines=1, placeholder="Search…", label="")
48 | with gr.Column(scale=1):
49 | with gr.Row(scale=1):
50 | pass
51 | with gr.Row(scale=1):
52 | submit_btn = gr.Button("🔍", elem_id="b").style(full_width=False)
53 | with gr.Row(scale=1):
54 | pass
55 |
56 | with gr.Row():
57 | with gr.Column(scale=1):
58 | pass
59 | with gr.Column(scale=13):
60 | c = gr.HTML(label="Results")
61 | with gr.Row(visible=False) as pagination:
62 | # left = gr.Button(value="◀", elem_id="b", visible=False).style(full_width=True)
63 | page_1 = gr.Button(value="1", elem_id="b").style(full_width=True)
64 | page_2 = gr.Button(value="2", elem_id="b").style(full_width=True)
65 | page_3 = gr.Button(value="3", elem_id="b").style(full_width=True)
66 | page_4 = gr.Button(value="4", elem_id="b").style(full_width=True)
67 | page_5 = gr.Button(value="5", elem_id="b").style(full_width=True)
68 | page_6 = gr.Button(value="6", elem_id="b").style(full_width=True)
69 | page_7 = gr.Button(value="7", elem_id="b").style(full_width=True)
70 | page_8 = gr.Button(value="8", elem_id="b").style(full_width=True)
71 | page_9 = gr.Button(value="9", elem_id="b").style(full_width=True)
72 | page_10 = gr.Button(value="10", elem_id="b").style(full_width=True)
73 | # right = gr.Button(value="▶", elem_id="b", visible=False).style(full_width=True)
74 | with gr.Column(scale=1):
75 | pass
76 |
77 | query.submit(fn=page_0, inputs=[query], outputs=[c, result_list, pagination])
78 | submit_btn.click(page_0, inputs=[query], outputs=[c, result_list, pagination])
79 |
80 | with gr.Box(visible=False):
81 | nums = [gr.Number(i, visible=False, precision=0) for i in range(NUM_PAGES)]
82 |
83 | page_1.click(fn=page_i, inputs=[nums[0], result_list], outputs=[c, result_list])
84 | page_2.click(fn=page_i, inputs=[nums[1], result_list], outputs=[c, result_list])
85 | page_3.click(fn=page_i, inputs=[nums[2], result_list], outputs=[c, result_list])
86 | page_4.click(fn=page_i, inputs=[nums[3], result_list], outputs=[c, result_list])
87 | page_5.click(fn=page_i, inputs=[nums[4], result_list], outputs=[c, result_list])
88 | page_6.click(fn=page_i, inputs=[nums[5], result_list], outputs=[c, result_list])
89 | page_7.click(fn=page_i, inputs=[nums[6], result_list], outputs=[c, result_list])
90 | page_8.click(fn=page_i, inputs=[nums[7], result_list], outputs=[c, result_list])
91 | page_9.click(fn=page_i, inputs=[nums[8], result_list], outputs=[c, result_list])
92 | page_10.click(fn=page_i, inputs=[nums[9], result_list], outputs=[c, result_list])
93 |
94 | demo.launch(enable_queue=True, debug=True)
95 |
--------------------------------------------------------------------------------
/templates/gradio_roots_temp/{{ cookiecutter.local_app }}/app.py:
--------------------------------------------------------------------------------
1 | from typing import List, NewType, Optional, Union
2 |
3 | import gradio as gr
4 |
5 | from spacerini_utils.index import fetch_index_stats
6 | from spacerini_utils.search import _search, init_searcher, SearchResult
7 |
8 | HTML = NewType('HTML', str)
9 |
10 | searcher = init_searcher(sparse_index_path="sparse_index")
11 |
12 | def get_docid_html(docid: Union[int, str]) -> HTML:
13 | {% if cookiecutter.private -%}
14 | docid_html = (
15 | f"🔒{{ cookiecutter.dataset_name }}/'+f'{docid}'
20 | )
21 | {%- else -%}
22 | docid_html = (
23 | f"🔒{{ cookiecutter.dataset_name }}/'+f'{docid}'
29 | )
30 | {%- endif %}
31 |
32 | return docid_html
33 |
34 |
35 | def process_results(results: List[SearchResult], language: str, highlight_terms: Optional[List[str]] = None) -> HTML:
36 | if len(results) == 0:
37 | return """
38 | No results retrieved.
"""
39 |
40 | results_html = ""
41 | for result in results:
42 | tokens = result["text"].split()
43 |
44 | tokens_html = []
45 | if highlight_terms:
46 | for token in tokens:
47 | if token in highlight_terms:
48 | tokens_html.append("{}".format(token))
49 | else:
50 | tokens_html.append(token)
51 |
52 | tokens_html = " ".join(tokens_html)
53 | meta_html = (
54 | """
55 |
56 | """
57 | )
58 | docid_html = get_docid_html(result["docid"])
59 | results_html += """{}
60 |
Document ID: {}
61 | Score: {}
62 | Language: {}
63 | {}
64 |
65 | """.format(
66 | meta_html, docid_html, result["score"], language, tokens_html
67 | )
68 | return results_html + "
"
69 |
70 |
71 | def search(query: str, language: str, num_results: int = 10) -> HTML:
72 | results_dict = _search(searcher, query, num_results=num_results)
73 | return process_results(results_dict, language)
74 |
75 |
76 | stats = fetch_index_stats('sparse_index/')
77 |
78 | description = f"""# {{ cookiecutter.emoji }} 🔎 {{ cookiecutter.space_title }} 🔍 {{ cookiecutter.emoji }}
79 | {{ cookiecutter.space_description}}
80 | Dataset Statistics: Total Number of Documents = {stats["documents"]}, Number of Terms = {stats["total_terms"]}
"""
81 |
82 | demo = gr.Blocks(
83 | css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }"
84 | )
85 |
86 | with demo:
87 | with gr.Row():
88 | gr.Markdown(value=description)
89 | with gr.Row():
90 | query = gr.Textbox(lines=1, max_lines=1, placeholder="Type your query here...", label="Query")
91 | with gr.Row():
92 | lang = gr.Dropdown(
93 | choices=[
94 | "en",
95 | "detect_language",
96 | "all",
97 | ],
98 | value="en",
99 | label="Language",
100 | )
101 | with gr.Row():
102 | k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
103 | with gr.Row():
104 | submit_btn = gr.Button("Submit")
105 | with gr.Row():
106 | results = gr.HTML(label="Results")
107 |
108 |
109 | def submit(query: str, lang: str, k: int):
110 | query = query.strip()
111 | if query is None or query == "":
112 | return "", ""
113 | return {
114 | results: search(query, lang, k),
115 | }
116 |
117 | query.submit(fn=submit, inputs=[query, lang, k], outputs=[results])
118 | submit_btn.click(submit, inputs=[query, lang, k], outputs=[results])
119 |
120 | demo.launch(enable_queue=True, debug=True)
121 |
--------------------------------------------------------------------------------
/src/spacerini/data/load.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datasets import Dataset, IterableDataset
3 | from datasets import load_dataset
4 | import pandas as pd
5 | import ir_datasets
6 | from typing import Generator, Dict, List
7 |
8 | def ir_dataset_dict_generator(dataset_name: str) -> Generator[Dict,None,None]:
9 | """
10 | Generator for streaming datasets from ir_datasets
11 | https://github.com/allenai/ir_datasets
12 | Parameters
13 | ----------
14 | dataset_name : str
15 | Name of dataset to load
16 | """
17 | dataset = ir_datasets.load(dataset_name)
18 | for doc in dataset.docs_iter():
19 | yield {
20 | "contents": doc.text,
21 | "docid": doc.doc_id
22 | }
23 |
24 | def load_ir_dataset(dataset_name: str) -> Dataset:
25 | """
26 | Load dataset from ir_datasets
27 |
28 | Parameters
29 | ----------
30 | dataset_name : str
31 | Name of dataset to load
32 |
33 | Returns
34 | -------
35 | Dataset
36 | """
37 | dataset = ir_datasets.load(dataset_name)
38 | return Dataset.from_pandas(pd.DataFrame(dataset.docs_iter()))
39 |
40 | def load_ir_dataset_low_memory(dataset_name: str, num_proc: int) -> IterableDataset:
41 | """
42 | Load dataset from ir_datasets by streaming into a Dataset object. This is slower than first loading the data into a pandas DataFrame but does not require loading the entire ir_dataset into memory. This variant also supports multiprocessing if the data is sharded, and only consumes the generator once then caches it so that future calls are instantaneous.
43 |
44 | Parameters
45 | ----------
46 | dataset_name : str
47 | Name of dataset to load
48 | num_proc: int
49 | Number of processes to use
50 |
51 | Returns
52 | -------
53 | Dataset
54 | """
55 | return Dataset.from_generator(ir_dataset_dict_generator, gen_kwargs={"dataset_name": dataset_name}, num_proc=num_proc)
56 |
57 | def load_ir_dataset_streaming(dataset_name: str) -> IterableDataset:
58 | """
59 | Load dataset from ir_datasets
60 |
61 | Parameters
62 | ----------
63 | dataset_name : str
64 | Name of dataset to load
65 |
66 | Returns
67 | -------
68 | IterableDataset
69 | """
70 | return IterableDataset.from_generator(ir_dataset_dict_generator, gen_kwargs={"dataset_name": dataset_name})
71 |
72 | def load_from_pandas(df: pd.DataFrame) -> Dataset:
73 | """
74 | Load dataset from pandas DataFrame
75 |
76 | Parameters
77 | ----------
78 | df : pd.DataFrame
79 | DataFrame to load
80 |
81 | Returns
82 | -------
83 | Dataset
84 | """
85 | return Dataset.from_pandas(df)
86 |
87 | def load_from_hub(dataset_name_or_path: str, split: str, config_name: str=None, streaming: bool = True) -> Dataset:
88 | """
89 | Load dataset from HuggingFace Hub
90 |
91 | Parameters
92 | ----------
93 | dataset_name_or_path : str
94 | Name of dataset to load
95 | split : str
96 | Split of dataset to load
97 | streaming : bool
98 | Whether to load dataset in streaming mode
99 |
100 | Returns
101 | -------
102 | Dataset
103 | """
104 |
105 | return load_dataset(dataset_name_or_path, split=split, streaming=streaming, name=config_name)
106 |
107 | def load_from_local(dataset_name_or_path: str or List[str], split: str, streaming: bool = True) -> Dataset:
108 | """
109 | Load dataset from local text file. Supports JSON, JSONL, CSV, and TSV.
110 |
111 | Parameters
112 | ----------
113 | dataset_name_or_path : str
114 | Path to a .json, .jsonl, .csv, or .tsv file
115 | streaming : bool
116 | Whether to load dataset in streaming mode
117 |
118 | Returns
119 | -------
120 | Dataset
121 | """
122 | if isinstance(dataset_name_or_path, str):
123 | assert os.path.exists(dataset_name_or_path), f"File {dataset_name_or_path} does not exist"
124 | if dataset_name_or_path.endswith(".jsonl") or dataset_name_or_path.endswith(".json"):
125 | return load_dataset("json", data_files=dataset_name_or_path, split=split, streaming=streaming)
126 | elif dataset_name_or_path.endswith(".csv"):
127 | return load_dataset("csv", data_files=dataset_name_or_path, split=split, streaming=streaming, sep=",")
128 | elif dataset_name_or_path.endswith(".tsv"):
129 | return load_dataset("csv", data_files=dataset_name_or_path, split=split, streaming=streaming, sep="\t")
130 | else:
131 | raise ValueError("Unsupported file type")
132 |
133 | def load_from_sqlite_table(uri_or_con: str, table_or_query: str) -> Dataset:
134 | """
135 | Load dataset from a sqlite database table
136 |
137 | Parameters
138 | ----------
139 | uri_or_con : str
140 | URI to a SQLITE database or connection object
141 | table_or_query : str
142 | database table or query that returns a table
143 |
144 | Returns
145 | -------
146 | Dataset
147 | """
148 | return Dataset.from_sql(con=uri_or_con, sql=table_or_query)
149 |
--------------------------------------------------------------------------------
/src/spacerini/search/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Dict
3 | from typing import List
4 | from typing import Literal
5 | from typing import Protocol
6 | from typing import TypedDict
7 | from typing import Union
8 |
9 | from datasets import Dataset
10 | from pyserini.analysis import get_lucene_analyzer
11 | from pyserini.search import DenseSearchResult
12 | from pyserini.search import JLuceneSearcherResult
13 | from pyserini.search.faiss.__main__ import init_query_encoder
14 | from pyserini.search.faiss import FaissSearcher
15 | from pyserini.search.hybrid import HybridSearcher
16 | from pyserini.search.lucene import LuceneSearcher
17 |
18 | Encoder = Literal["dkrr", "dpr", "tct_colbert", "ance", "sentence", "contriever", "auto"]
19 |
20 |
21 | class AnalyzerArgs(TypedDict):
22 | language: str
23 | stemming: bool
24 | stemmer: str
25 | stopwords: bool
26 | huggingFaceTokenizer: str
27 |
28 |
29 | class Searcher(Protocol):
30 | def search(self, query: str, **kwargs) -> List[Union[DenseSearchResult, JLuceneSearcherResult]]:
31 | ...
32 |
33 |
34 | def init_searcher(
35 | sparse_index_path: str = None,
36 | bm25_k1: float = None,
37 | bm25_b: float = None,
38 | analyzer_args: AnalyzerArgs = None,
39 | dense_index_path: str = None,
40 | encoder_name_or_path: str = None,
41 | encoder_class: Encoder = None,
42 | tokenizer_name: str = None,
43 | device: str = None,
44 | prefix: str = None
45 | ) -> Union[FaissSearcher, HybridSearcher, LuceneSearcher]:
46 | """
47 | Initialize and return an approapriate searcher
48 |
49 | Parameters
50 | ----------
51 | sparse_index_path: str
52 | Path to sparse index
53 | dense_index_path: str
54 | Path to dense index
55 | encoder_name_or_path: str
56 | Path to query encoder checkpoint or encoder name
57 | encoder_class: str
58 | Query encoder class to use. If None, infer from `encoder`
59 | tokenizer_name: str
60 | Tokenizer name or path
61 | device: str
62 | Device to load Query encoder on.
63 | prefix: str
64 | Query prefix if exists
65 |
66 | Returns
67 | -------
68 | Searcher:
69 | A sparse, dense or hybrid searcher
70 | """
71 | if sparse_index_path:
72 | ssearcher = LuceneSearcher(sparse_index_path)
73 | if analyzer_args:
74 | analyzer = get_lucene_analyzer(**analyzer_args)
75 | ssearcher.set_analyzer(analyzer)
76 | if bm25_k1 and bm25_b:
77 | ssearcher.set_bm25(bm25_k1, bm25_b)
78 |
79 | if dense_index_path:
80 | encoder = init_query_encoder(
81 | encoder=encoder_name_or_path,
82 | encoder_class=encoder_class,
83 | tokenizer_name=tokenizer_name,
84 | topics_name=None,
85 | encoded_queries=None,
86 | device=device,
87 | prefix=prefix
88 | )
89 |
90 | dsearcher = FaissSearcher(dense_index_path, encoder)
91 |
92 | if sparse_index_path:
93 | hsearcher = HybridSearcher(dense_searcher=dsearcher, sparse_searcher=ssearcher)
94 | return hsearcher
95 | else:
96 | return dsearcher
97 |
98 | return ssearcher
99 |
100 |
101 | def result_indices(
102 | query: str,
103 | num_results: int,
104 | searcher: Searcher,
105 | index_path: str,
106 | analyzer=None
107 | ) -> list:
108 | """
109 | Get the indices of the results of a query.
110 | Parameters
111 | ----------
112 | query : str
113 | The query.
114 | num_results : int
115 | The number of results to return.
116 | index_path : str
117 | The path to the index.
118 | analyzer : str (default=None)
119 | The analyzer to use.
120 |
121 | Returns
122 | -------
123 | list
124 | The indices of the returned documents.
125 | """
126 | # searcher.search()
127 | # searcher = LuceneSearcher(index_path)
128 | # if analyzer is not None:
129 | # searcher.set_analyzer(analyzer)
130 | hits = searcher.search(query, k=num_results)
131 | ix = [int(hit.docid) for hit in hits]
132 | return ix
133 |
134 |
135 | def result_page(
136 | hf_dataset: Dataset,
137 | result_indices: List[int],
138 | page: int = 0,
139 | results_per_page: int=10
140 | ) -> Dataset:
141 | """
142 | Returns a the ith results page as a datasets.Dataset object. Nothing is loaded into memory. Call `to_pandas()` on the returned Dataset to materialize the table.
143 | ----------
144 | hf_dataset : datasets.Dataset
145 | a Hugging Face datasets dataset.
146 | result_indices : list of int
147 | The indices of the results.
148 | page: int (default=0)
149 | The result page to return. Returns the first page by default.
150 | results_per_page : int (default=10)
151 | The number of results per page.
152 |
153 | Returns
154 | -------
155 | datasets.Dataset
156 | A results page.
157 | """
158 | results = hf_dataset.select(result_indices)
159 | num_result_pages = int(len(results)/results_per_page) + 1
160 | return results.shard(num_result_pages, page, contiguous=True)
161 |
--------------------------------------------------------------------------------
/src/spacerini/index/index.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import List
3 | from typing import Literal
4 |
5 | import shutil
6 | from pyserini.index.lucene import LuceneIndexer, IndexReader
7 | from typing import Any, Dict, List
8 | from pyserini.pyclass import autoclass
9 | from tqdm import tqdm
10 | import json
11 | import os
12 |
13 | from spacerini.data import load_from_hub, load_from_local
14 |
15 |
16 | def parse_args(
17 | index_path: str,
18 | shards_path: str = "",
19 | fields: List[str] = None,
20 | language: str = None,
21 | pretokenized: bool = False,
22 | analyzeWithHuggingFaceTokenizer: str = None,
23 | storePositions: bool = True,
24 | storeDocvectors: bool = False,
25 | storeContents: bool = False,
26 | storeRaw: bool = False,
27 | keepStopwords: bool = False,
28 | stopwords: str = None,
29 | stemmer: Literal["porter", "krovetz"] = None,
30 | optimize: bool = False,
31 | verbose: bool = False,
32 | quiet: bool = False,
33 | memory_buffer: str ="4096",
34 | n_threads: int = 5,
35 | for_otf_indexing: bool = False,
36 | **kwargs
37 | ) -> List[str]:
38 | """
39 | Parse arguments into list for `SimpleIndexer` class in Anserini.
40 |
41 | Parameters
42 | ----------
43 | for_otf_indexing : bool
44 | If True, `-input` & `-collection` args are safely ignored.
45 | Used when performing on-the-fly indexing with HF Datasets.
46 |
47 | See [docs](docs/arguments.md) for remaining argument definitions
48 |
49 | Returns
50 | -------
51 | List of arguments to initialize the `LuceneIndexer`
52 | """
53 | params = locals()
54 | args = []
55 | args.extend([
56 | "-input", shards_path,
57 | "-collection", "JsonCollection",
58 | "-threads", f"{n_threads}" if n_threads!=-1 else f"{os.cpu_count()}",
59 | "-generator", "DefaultLuceneDocumentGenerator",
60 | "-index", index_path,
61 | "-memorybuffer", memory_buffer,
62 | ])
63 | variables = [
64 | "analyzeWithHuggingFaceTokenizer", "language", "stemmer"
65 | ]
66 | additional_args = [[f"-{var}", params[var]] for var in variables if params[var]]
67 | args.extend(chain.from_iterable(additional_args))
68 |
69 | if fields:
70 | args.extend(["-fields", " ".join(fields)])
71 |
72 | flags = [
73 | "pretokenized", "storePositions", "storeDocvectors",
74 | "storeContents", "storeRaw", "optimize", "verbose", "quiet"
75 | ]
76 | args.extend([f"-{flag}" for flag in flags if params[flag]])
77 |
78 | if for_otf_indexing:
79 | args = args[4:]
80 | return args
81 |
82 |
83 | def index_json_shards(
84 | shards_path: str,
85 | index_path: str,
86 | keep_shards: bool = True,
87 | fields: List[str] = None,
88 | language: str = "en",
89 | pretokenized: bool = False,
90 | analyzeWithHuggingFaceTokenizer: str = None,
91 | storePositions: bool = True,
92 | storeDocvectors: bool = False,
93 | storeContents: bool = False,
94 | storeRaw: bool = False,
95 | keepStopwords: bool = False,
96 | stopwords: str = None,
97 | stemmer: Literal["porter", "krovetz"] = "porter",
98 | optimize: bool = True,
99 | verbose: bool = False,
100 | quiet: bool = False,
101 | memory_buffer: str = "4096",
102 | n_threads: bool = 5,
103 | ):
104 | """Index dataset from a directory containing files
105 |
106 | Parameters
107 | ----------
108 | shards_path : str
109 | Path to dataset to index
110 | index_path : str
111 | Directory to store index
112 | keep_shards : bool
113 | If False, remove dataset after indexing is complete
114 |
115 | See [docs](../../docs/arguments.md) for remaining argument definitions
116 |
117 | Returns
118 | -------
119 | None
120 | """
121 | args = parse_args(**locals())
122 | JIndexCollection = autoclass('io.anserini.index.IndexCollection')
123 | JIndexCollection.main(args)
124 | if not keep_shards:
125 | shutil.rmtree(shards_path)
126 |
127 | return None
128 |
129 |
130 | def index_streaming_dataset(
131 | index_path: str,
132 | dataset_name_or_path: str,
133 | split: str,
134 | column_to_index: List[str],
135 | doc_id_column: str = None,
136 | ds_config_name: str = None, # For HF Dataset
137 | num_rows: int = -1,
138 | disable_tqdm: bool = False,
139 | language: str = "en",
140 | pretokenized: bool = False,
141 | analyzeWithHuggingFaceTokenizer: str = None,
142 | storePositions: bool = True,
143 | storeDocvectors: bool = False,
144 | storeContents: bool = False,
145 | storeRaw: bool = False,
146 | keepStopwords: bool = False,
147 | stopwords: str = None,
148 | stemmer: Literal["porter", "krovetz"] = "porter",
149 | optimize: bool = True,
150 | verbose: bool = False,
151 | quiet: bool = True,
152 | memory_buffer: str = "4096",
153 | n_threads: bool = 5,
154 | ):
155 | """Stream dataset from HuggingFace Hub & index
156 |
157 | Parameters
158 | ----------
159 | dataset_name_or_path : str
160 | Name of HuggingFace dataset to stream
161 | split : str
162 | Split of dataset to index
163 | column_to_index : List[str]
164 | Column of dataset to index
165 | doc_id_column : str
166 | Column of dataset to use as document ID
167 | ds_config_name: str
168 | Dataset configuration to stream. Usually a language name or code
169 | num_rows : int
170 | Number of rows in dataset
171 |
172 | See [docs](../../docs/arguments.md) for remaining argument definitions
173 |
174 | Returns
175 | -------
176 | None
177 | """
178 |
179 | args = parse_args(**locals(), for_otf_indexing=True)
180 | if os.path.exists(dataset_name_or_path):
181 | ds = load_from_local(dataset_name_or_path, split=split, streaming=True)
182 | else:
183 | ds = load_from_hub(dataset_name_or_path, split=split,config_name=ds_config_name, streaming=True)
184 |
185 | indexer = LuceneIndexer(args=args)
186 |
187 | for i, row in tqdm(enumerate(ds), total=num_rows, disable=disable_tqdm):
188 | contents = " ".join([row[column] for column in column_to_index])
189 | indexer.add_doc_raw(json.dumps({"id": i if not doc_id_column else row[doc_id_column] , "contents": contents}))
190 |
191 | indexer.close()
192 |
193 | return None
194 |
195 |
196 | def fetch_index_stats(index_path: str) -> Dict[str, Any]:
197 | """
198 | Fetch index statistics
199 | index_path : str
200 | Path to index directory
201 | Returns
202 | -------
203 | Dictionary of index statistics
204 | Dictionary Keys ==> total_terms, documents, unique_terms
205 | """
206 | assert os.path.exists(index_path), f"Index path {index_path} does not exist"
207 | index_reader = IndexReader(index_path)
208 | return index_reader.stats()
209 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/src/spacerini/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | from pathlib import Path
5 | from shutil import copytree
6 |
7 | from spacerini.frontend import create_app, create_space_from_local
8 | from spacerini.index import index_streaming_dataset
9 | from spacerini.index.encode import encode_json_dataset
10 | from spacerini.prebuilt import EXAMPLES
11 |
12 |
13 | def update_args_from_json(_args: argparse.Namespace, file: str) -> argparse.Namespace:
14 | config = json.load(open(file, "r"))
15 | config = {k.replace("-", "_"): v for k,v in config.items()}
16 |
17 | args_dict = vars(_args)
18 | args_dict.update(config)
19 | return _args
20 |
21 |
22 | # TODO: @theyorubayesian: Switch to HFArgumentParser with post init
23 | # How to use multiple argparse commands with HFArgumentParser
24 | def get_args() -> argparse.Namespace:
25 | parser = argparse.ArgumentParser(
26 | prog="Spacerini",
27 | description="A modular framework for seamless building and deployment of interactive search applications.",
28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
29 | epilog="Written by: Akintunde 'theyorubayesian' Oladipo , Christopher Akiki , Odunayo Ogundepo "
30 | )
31 | parser.add_argument("--space-name", required=False)
32 | parser.add_argument("--config-file", type=str, help="Path to configuration for space")
33 | parser.add_argument("--from-example", type=str, choices=list(EXAMPLES), help="Name of an example spaces applications.")
34 | parser.add_argument("--verbose", type=bool, default=True, help="If True, print verbose output")
35 | parser.add_argument("--index-exists", type=str, help="Path to an existing index to be used for Spaces application")
36 |
37 | # --------
38 | # Commands
39 | # --------
40 | sub_parser = parser.add_subparsers(dest="command", title="Commands", description="Valid commands")
41 | index_parser = sub_parser.add_parser("index", help="Index dataset. This will not create the app or deploy it.")
42 | create_parser = sub_parser.add_parser("create-space", help="Create space in local directory. This won't deploy the space.")
43 | deploy_only_parser = sub_parser.add_parser("deploy-only", help="Deploy an already created space")
44 | deploy_parser = sub_parser.add_parser("deploy", help="Deploy new space.")
45 | deploy_parser.add_argument("--delete-after", type=bool, default=False, help="If True, delete the local directory after pushing it to the Hub.")
46 | # --------
47 |
48 | space_args = parser.add_argument_group("Space arguments")
49 | space_args.add_argument("--space-title", required=False, help="Title to show on Space")
50 | space_args.add_argument("--space-url-slug", required=False, help="")
51 | space_args.add_argument("--sdk", default="gradio", choices=["gradio", "streamlit"])
52 | space_args.add_argument("--template", default="streamlit", help="A directory containing a project template directory, or a URL to a git repository.")
53 | space_args.add_argument("--organization", required=False, help="Organization to deploy new space under.")
54 | space_args.add_argument("--description", type=str, help="Description of new space")
55 | space_args.add_argument("--private", action="store_true", help="Deploy private Spaces application")
56 |
57 | data_args = parser.add_argument_group("Data arguments")
58 | data_args.add_argument("--columns-to-index", default=[], action="store", nargs="*", help="Other columns to index in dataset")
59 | data_args.add_argument("--split", type=str, required=False, default="train", help="Dataset split to index")
60 | data_args.add_argument("--dataset", type=str, required=False, help="Local dataset folder or Huggingface name")
61 | data_args.add_argument("--docid-column", default="id", help="Name of docid column in dataset")
62 | data_args.add_argument("--language", default="en", help="ISO Code for language of dataset")
63 | data_args.add_argument("--title-column", type=str, help="Name of title column in data")
64 | data_args.add_argument("--content-column", type=str, default="content", help="Name of content column in data")
65 | data_args.add_argument("--expand-column", type=str, help="Name of column containing document expansion. Used in dense indexes")
66 |
67 | sparse_index_args = parser.add_argument_group("Sparse Index arguments")
68 | sparse_index_args.add_argument("--collection", type=str, help="Collection class")
69 | sparse_index_args.add_argument("--memory-buffer", type=str, help="Memory buffer size")
70 | sparse_index_args.add_argument("--threads", type=int, default=5, help="Number of threads to use for indexing")
71 | sparse_index_args.add_argument("--hf-tokenizer", type=str, default=None, help="HuggingFace tokenizer to tokenize dataset")
72 | sparse_index_args.add_argument("--pretokenized", type=bool, default=False, help="If True, dataset is already tokenized")
73 | sparse_index_args.add_argument("--store-positions", action="store_true", help="If True, store document vectors in index") # TODO: @theyorubayesian
74 | sparse_index_args.add_argument("--store-docvectors", action="store_true", help="If True, store document vectors in index") # TODO: @theyorubayesian
75 | sparse_index_args.add_argument("--store-contents", action="store_true", help="If True, store contents of documents in index")
76 | sparse_index_args.add_argument("--store-raw", action="store_true", help="If True, store raw contents of documents in index")
77 | sparse_index_args.add_argument("--keep-stopwords", action="store_true", help="If True, keep stopwords in index")
78 | sparse_index_args.add_argument("--optimize-index", action="store_true", help="If True, optimize index after indexing is complete")
79 | sparse_index_args.add_argument("--stopwords", type=str, help="Path to stopwords file")
80 | sparse_index_args.add_argument("--stemmer", type=str, nargs=1, choices=["porter", "krovetz"], help="Stemmer to use for indexing")
81 |
82 | dense_index_args = parser.add_argument_group("Dense Index arguments")
83 | dense_index_args.add_argument("--encoder-name-or-path", type=str, help="Encoder name or path")
84 | dense_index_args.add_argument("--encoder-class", default="auto", type=str, choices=["dpr", "bpr", "tct_colbert", "ance", "sentence-transformers", "auto"], help="Encoder to use")
85 | dense_index_args.add_argument("--delimiter", default="\n", type=str, help="Delimiter for the fields in encoded corpus")
86 | dense_index_args.add_argument("--index-shard-id", default=0, type=int, help="Zero-based index shard id")
87 | dense_index_args.add_argument("--n-index-shards", type=int, default=1, help="Number of index shards")
88 | dense_index_args.add_argument("--batch-size", default=64, type=int, help="Batch size for encoding")
89 | dense_index_args.add_argument("--max-length", type=int, default=256, help="Max document length to encode")
90 | dense_index_args.add_argument('--device', default='cuda:0', type=str, help='Device: cpu or cuda [cuda:0, cuda:1...]', required=False)
91 | dense_index_args.add_argument("--dimension", default=768, type=int, help="Dimension for Faiss Index")
92 | dense_index_args.add_argument("--add-sep", action="store_true", help="Pass `title` and `content` columns separately into encode function")
93 | dense_index_args.add_argument("--to-faiss", action="store_true", help="Store embeddings in Faiss Index")
94 | dense_index_args.add_argument("--fp16", action="store_true", help="Use FP 16")
95 |
96 | search_args = parser.add_argument_group("Search arguments")
97 | search_args.add_argument("--bm25_k1", type=float, help="BM25: k1 parameter")
98 | search_args.add_argument("--bm24_b", type=float, help="BM25: b parameter")
99 |
100 | args, _ = parser.parse_known_args()
101 |
102 | if args.from_example in EXAMPLES:
103 | example_config_path = Path(__file__).parents[2] / "examples" \
104 | / "configs" / f"{args.from_example}.json"
105 | args = update_args_from_json(args, example_config_path)
106 |
107 | # For customization, user provided config file supersedes example config file
108 | if args.config_file:
109 | args = update_args_from_json(args, args.config_file)
110 |
111 | args.template = "templates/gradio_roots_temp"
112 | return args
113 |
114 |
115 | def main():
116 | args = get_args()
117 |
118 | local_app_dir = Path(f"apps/{args.space_name}")
119 | local_app_dir.mkdir(exist_ok=True)
120 |
121 | columns = [args.content_column, *args.columns_to_index]
122 |
123 | if args.command in ["index", "create-space", "deploy"]:
124 | logging.info(f"Indexing {args.dataset} dataset into {str(local_app_dir)}")
125 |
126 | if args.index_exists:
127 | index_dir = Path(args.index_exists)
128 | assert index_dir.exists(), f"No index found at {args.index_exists}"
129 | copytree(index_dir, local_app_dir / index_dir.name, dirs_exist_ok=True)
130 | else:
131 | # TODO: @theyorubayesian
132 | # We always create a sparse index because dense index only contain docid. Does this make sense memory-wise?
133 | # Another option could be to load the dataset from huggingface and filter documents by docids retrieved
134 | # Can documents be stored in dense indexes? Does it make sense to?
135 | index_streaming_dataset(
136 | dataset_name_or_path=args.dataset,
137 | index_path=(local_app_dir / "sparse_index").as_posix(),
138 | split=args.split,
139 | column_to_index=columns,
140 | doc_id_column=args.docid_column,
141 | language=args.language,
142 | storeContents=args.store_contents,
143 | storeRaw=args.store_raw,
144 | analyzeWithHuggingFaceTokenizer=args.hf_tokenizer
145 | )
146 |
147 | if args.encoder_name_or_path:
148 | encode_json_dataset(
149 | data_path=args.dataset,
150 | encoder_name_or_path=args.encoder_name_or_path,
151 | encoder_class=args.encoder_class,
152 | embedding_dir=(local_app_dir / "dense_index").as_posix(),
153 | batch_size=args.batch_size,
154 | device=args.device,
155 | index_shard_id=args.index_shard_id,
156 | num_index_shards=args.n_index_shards,
157 | delimiter=args.delimiter,
158 | max_length=args.max_length,
159 | add_sep=args.add_sep,
160 | title_column_to_encode=args.title_column,
161 | text_column_to_encode=args.content_column,
162 | expand_column_to_encode=args.expand_column,
163 | output_to_faiss=True,
164 | embedding_dimension=args.dimension,
165 | fp16=args.fp16
166 | )
167 |
168 | if args.command in ["create-space", "deploy"]:
169 | logging.info(f"Creating local app into {args.space_name} directory")
170 | # TODO: @theyorubayesian - How to make cookiecutter_vars more flexible
171 | cookiecutter_vars = {
172 | "dset_text_field": columns,
173 | "space_title": args.space_title,
174 | "local_app": args.space_name,
175 | "space_description": args.description,
176 | "dataset_name": args.dataset
177 | }
178 |
179 | create_app(
180 | template=args.template,
181 | extra_context_dict=cookiecutter_vars,
182 | output_dir="apps"
183 | )
184 |
185 | if args.command in ["deploy", "deploy-only"]:
186 | logging.info(f"Creating space {args.space_name} on {args.organization}")
187 | create_space_from_local(
188 | space_slug=args.space_url_slug,
189 | organization=args.organization,
190 | space_sdk=args.sdk,
191 | local_dir=local_app_dir,
192 | delete_after_push=args.delete_after,
193 | private=args.private
194 | )
195 |
196 |
197 | if __name__ == "__main__":
198 | main()
199 |
--------------------------------------------------------------------------------