├── .coveragerc ├── .github └── workflows │ ├── build.yml │ └── docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── CITATION.cff ├── LICENSE ├── Makefile ├── README.md ├── apps.jpg ├── demo.gif ├── docker ├── api │ └── Dockerfile ├── aws │ ├── Dockerfile │ ├── api.py │ └── workflow.py ├── base │ └── Dockerfile ├── schedule │ └── Dockerfile └── workflow │ └── Dockerfile ├── docs ├── agent │ ├── configuration.md │ ├── index.md │ └── methods.md ├── api │ ├── cluster.md │ ├── configuration.md │ ├── customization.md │ ├── index.md │ ├── mcp.md │ ├── methods.md │ ├── openai.md │ └── security.md ├── cloud.md ├── embeddings │ ├── configuration │ │ ├── ann.md │ │ ├── cloud.md │ │ ├── database.md │ │ ├── general.md │ │ ├── graph.md │ │ ├── index.md │ │ ├── scoring.md │ │ └── vectors.md │ ├── format.md │ ├── index.md │ ├── indexing.md │ ├── methods.md │ └── query.md ├── examples.md ├── faq.md ├── further.md ├── images │ ├── agent.excalidraw │ ├── agent.png │ ├── api-dark.png │ ├── api.excalidraw │ ├── api.png │ ├── architecture-dark.png │ ├── architecture.excalidraw │ ├── architecture.png │ ├── cloud-dark.png │ ├── cloud.excalidraw │ ├── cloud.png │ ├── embeddings-dark.png │ ├── embeddings.excalidraw │ ├── embeddings.png │ ├── examples-dark.png │ ├── examples.excalidraw │ ├── examples.png │ ├── faq.excalidraw │ ├── faq.png │ ├── flows-dark.png │ ├── flows.excalidraw │ ├── flows.png │ ├── format-dark.png │ ├── format.excalidraw │ ├── format.png │ ├── further-dark.png │ ├── further-ghdark.png │ ├── further.excalidraw │ ├── further.png │ ├── indexing-dark.png │ ├── indexing.excalidraw │ ├── indexing.png │ ├── install-dark.png │ ├── install.excalidraw │ ├── install.png │ ├── llm.excalidraw │ ├── llm.png │ ├── logo.png │ ├── models.excalidraw │ ├── models.png │ ├── pipeline-dark.png │ ├── pipeline.excalidraw │ ├── pipeline.png │ ├── query-dark.png │ ├── query.excalidraw │ ├── query.png │ ├── rag-dark.png │ ├── rag.excalidraw │ ├── rag.png │ ├── schedule-dark.png │ ├── schedule.excalidraw │ ├── schedule.png │ ├── search-dark.png │ ├── search.excalidraw │ ├── search.png │ ├── task-dark.png │ ├── task.excalidraw │ ├── task.png │ ├── why-dark.png │ ├── why.excalidraw │ ├── why.png │ ├── workflow-dark.png │ ├── workflow.excalidraw │ └── workflow.png ├── index.md ├── install.md ├── models.md ├── observability.md ├── overrides │ └── main.html ├── pipeline │ ├── audio │ │ ├── audiomixer.md │ │ ├── audiostream.md │ │ ├── microphone.md │ │ ├── texttoaudio.md │ │ ├── texttospeech.md │ │ └── transcription.md │ ├── data │ │ ├── filetohtml.md │ │ ├── htmltomd.md │ │ ├── segmentation.md │ │ ├── tabular.md │ │ └── textractor.md │ ├── image │ │ ├── caption.md │ │ ├── imagehash.md │ │ └── objects.md │ ├── index.md │ ├── text │ │ ├── entity.md │ │ ├── labels.md │ │ ├── llm.md │ │ ├── rag.md │ │ ├── similarity.md │ │ ├── summary.md │ │ └── translation.md │ └── train │ │ ├── hfonnx.md │ │ ├── mlonnx.md │ │ └── trainer.md ├── poweredby.md ├── usecases.md ├── why.md └── workflow │ ├── index.md │ ├── schedule.md │ └── task │ ├── console.md │ ├── export.md │ ├── file.md │ ├── image.md │ ├── index.md │ ├── retrieve.md │ ├── service.md │ ├── storage.md │ ├── template.md │ ├── url.md │ └── workflow.md ├── examples ├── 01_Introducing_txtai.ipynb ├── 02_Build_an_Embeddings_index_with_Hugging_Face_Datasets.ipynb ├── 03_Build_an_Embeddings_index_from_a_data_source.ipynb ├── 04_Add_semantic_search_to_Elasticsearch.ipynb ├── 05_Extractive_QA_with_txtai.ipynb ├── 06_Extractive_QA_with_Elasticsearch.ipynb ├── 
07_Apply_labels_with_zero_shot_classification.ipynb ├── 08_API_Gallery.ipynb ├── 09_Building_abstractive_text_summaries.ipynb ├── 10_Extract_text_from_documents.ipynb ├── 11_Transcribe_audio_to_text.ipynb ├── 12_Translate_text_between_languages.ipynb ├── 13_Similarity_search_with_images.ipynb ├── 14_Run_pipeline_workflows.ipynb ├── 15_Distributed_embeddings_cluster.ipynb ├── 16_Train_a_text_labeler.ipynb ├── 17_Train_without_labels.ipynb ├── 18_Export_and_run_models_with_ONNX.ipynb ├── 19_Train_a_QA_model.ipynb ├── 20_Extractive_QA_to_build_structured_data.ipynb ├── 21_Export_and_run_other_machine_learning_models.ipynb ├── 22_Transform_tabular_data_with_composable_workflows.ipynb ├── 23_Tensor_workflows.ipynb ├── 24_Whats_new_in_txtai_4_0.ipynb ├── 25_Generate_image_captions_and_detect_objects.ipynb ├── 26_Entity_extraction_workflows.ipynb ├── 27_Workflow_scheduling.ipynb ├── 28_Push_notifications_with_workflows.ipynb ├── 29_Anatomy_of_a_txtai_index.ipynb ├── 30_Embeddings_SQL_custom_functions.ipynb ├── 31_Near_duplicate_image_detection.ipynb ├── 32_Model_explainability.ipynb ├── 33_Query_translation.ipynb ├── 34_Build_a_QA_database.ipynb ├── 35_Pictures_are_worth_a_thousand_words.ipynb ├── 36_Run_txtai_in_native_code.ipynb ├── 37_Embeddings_index_components.ipynb ├── 38_Introducing_the_Semantic_Graph.ipynb ├── 39_Classic_Topic_Modeling_with_BM25.ipynb ├── 40_Text_to_Speech_Generation.ipynb ├── 41_Train_a_language_model_from_scratch.ipynb ├── 42_Prompt_driven_search_with_LLMs.ipynb ├── 43_Embeddings_in_the_Cloud.ipynb ├── 44_Prompt_templates_and_task_chains.ipynb ├── 45_Customize_your_own_embeddings_database.ipynb ├── 46_Whats_new_in_txtai_6_0.ipynb ├── 47_Building_an_efficient_sparse_keyword_index_in_Python.ipynb ├── 48_Benefits_of_hybrid_search.ipynb ├── 49_External_database_integration.ipynb ├── 50_All_about_vector_quantization.ipynb ├── 51_Custom_API_Endpoints.ipynb ├── 52_Build_RAG_pipelines_with_txtai.ipynb ├── 53_Integrate_LLM_Frameworks.ipynb ├── 54_API_Authorization_and_Authentication.ipynb ├── 55_Generate_knowledge_with_Semantic_Graphs_and_RAG.ipynb ├── 56_External_vectorization.ipynb ├── 57_Build_knowledge_graphs_with_LLM_driven_entity_extraction.ipynb ├── 58_Advanced_RAG_with_graph_path_traversal.ipynb ├── 59_Whats_new_in_txtai_7_0.ipynb ├── 60_Advanced_RAG_with_guided_generation.ipynb ├── 61_Integrate_txtai_with_Postgres.ipynb ├── 62_RAG_with_llama_cpp_and_external_API_services.ipynb ├── 63_How_RAG_with_txtai_works.ipynb ├── 64_Embeddings_index_format_for_open_data_access.ipynb ├── 65_Speech_to_Speech_RAG.ipynb ├── 66_Generative_Audio.ipynb ├── 67_Whats_new_in_txtai_8_0.ipynb ├── 68_Analyzing_Hugging_Face_Posts_with_Graphs_and_Agents.ipynb ├── 69_Granting_autonomy_to_agents.ipynb ├── 70_Getting_started_with_LLM_APIs.ipynb ├── 71_Analyzing_LinkedIn_Company_Posts_with_Graphs_and_Agents.ipynb ├── 72_Parsing_the_stars_with_txtai.ipynb ├── 73_Chunking_your_data_for_RAG.ipynb ├── 74_OpenAI_Compatible_API.ipynb ├── article.py ├── baseball.py ├── benchmarks.py ├── books.py ├── images.py ├── similarity.py ├── wiki.py └── workflows.py ├── logo.png ├── mkdocs.yml ├── pyproject.toml ├── setup.py ├── src └── python │ └── txtai │ ├── __init__.py │ ├── agent │ ├── __init__.py │ ├── base.py │ ├── factory.py │ ├── model.py │ ├── placeholder.py │ └── tool │ │ ├── __init__.py │ │ ├── embeddings.py │ │ ├── factory.py │ │ └── function.py │ ├── ann │ ├── __init__.py │ ├── annoy.py │ ├── base.py │ ├── factory.py │ ├── faiss.py │ ├── hnsw.py │ ├── numpy.py │ ├── pgvector.py │ ├── sqlite.py │ └── 
torch.py │ ├── api │ ├── __init__.py │ ├── application.py │ ├── authorization.py │ ├── base.py │ ├── cluster.py │ ├── extension.py │ ├── factory.py │ ├── responses │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── json.py │ │ └── messagepack.py │ ├── route.py │ └── routers │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── caption.py │ │ ├── embeddings.py │ │ ├── entity.py │ │ ├── extractor.py │ │ ├── labels.py │ │ ├── llm.py │ │ ├── objects.py │ │ ├── openai.py │ │ ├── rag.py │ │ ├── segmentation.py │ │ ├── similarity.py │ │ ├── summary.py │ │ ├── tabular.py │ │ ├── textractor.py │ │ ├── texttospeech.py │ │ ├── transcription.py │ │ ├── translation.py │ │ ├── upload.py │ │ └── workflow.py │ ├── app │ ├── __init__.py │ └── base.py │ ├── archive │ ├── __init__.py │ ├── base.py │ ├── compress.py │ ├── factory.py │ ├── tar.py │ └── zip.py │ ├── cloud │ ├── __init__.py │ ├── base.py │ ├── factory.py │ ├── hub.py │ └── storage.py │ ├── console │ ├── __init__.py │ ├── __main__.py │ └── base.py │ ├── data │ ├── __init__.py │ ├── base.py │ ├── labels.py │ ├── questions.py │ ├── sequences.py │ ├── texts.py │ └── tokens.py │ ├── database │ ├── __init__.py │ ├── base.py │ ├── client.py │ ├── duckdb.py │ ├── embedded.py │ ├── encoder │ │ ├── __init__.py │ │ ├── base.py │ │ ├── factory.py │ │ ├── image.py │ │ └── serialize.py │ ├── factory.py │ ├── rdbms.py │ ├── schema │ │ ├── __init__.py │ │ ├── orm.py │ │ └── statement.py │ ├── sql │ │ ├── __init__.py │ │ ├── aggregate.py │ │ ├── base.py │ │ ├── expression.py │ │ └── token.py │ └── sqlite.py │ ├── embeddings │ ├── __init__.py │ ├── base.py │ ├── index │ │ ├── __init__.py │ │ ├── action.py │ │ ├── autoid.py │ │ ├── configuration.py │ │ ├── documents.py │ │ ├── functions.py │ │ ├── indexes.py │ │ ├── indexids.py │ │ ├── reducer.py │ │ ├── stream.py │ │ └── transform.py │ └── search │ │ ├── __init__.py │ │ ├── base.py │ │ ├── errors.py │ │ ├── explain.py │ │ ├── ids.py │ │ ├── query.py │ │ ├── scan.py │ │ └── terms.py │ ├── graph │ ├── __init__.py │ ├── base.py │ ├── factory.py │ ├── networkx.py │ ├── query.py │ ├── rdbms.py │ └── topics.py │ ├── models │ ├── __init__.py │ ├── models.py │ ├── onnx.py │ ├── pooling │ │ ├── __init__.py │ │ ├── base.py │ │ ├── cls.py │ │ ├── factory.py │ │ └── mean.py │ ├── registry.py │ └── tokendetection.py │ ├── pipeline │ ├── __init__.py │ ├── audio │ │ ├── __init__.py │ │ ├── audiomixer.py │ │ ├── audiostream.py │ │ ├── microphone.py │ │ ├── signal.py │ │ ├── texttoaudio.py │ │ ├── texttospeech.py │ │ └── transcription.py │ ├── base.py │ ├── data │ │ ├── __init__.py │ │ ├── filetohtml.py │ │ ├── htmltomd.py │ │ ├── segmentation.py │ │ ├── tabular.py │ │ ├── textractor.py │ │ └── tokenizer.py │ ├── factory.py │ ├── hfmodel.py │ ├── hfpipeline.py │ ├── image │ │ ├── __init__.py │ │ ├── caption.py │ │ ├── imagehash.py │ │ └── objects.py │ ├── llm │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── generation.py │ │ ├── huggingface.py │ │ ├── litellm.py │ │ ├── llama.py │ │ ├── llm.py │ │ └── rag.py │ ├── nop.py │ ├── tensors.py │ ├── text │ │ ├── __init__.py │ │ ├── crossencoder.py │ │ ├── entity.py │ │ ├── labels.py │ │ ├── questions.py │ │ ├── similarity.py │ │ ├── summary.py │ │ └── translation.py │ └── train │ │ ├── __init__.py │ │ ├── hfonnx.py │ │ ├── hftrainer.py │ │ └── mlonnx.py │ ├── scoring │ ├── __init__.py │ ├── base.py │ ├── bm25.py │ ├── factory.py │ ├── pgtext.py │ ├── sif.py │ ├── terms.py │ └── tfidf.py │ ├── serialize │ ├── __init__.py │ ├── base.py │ ├── errors.py │ ├── factory.py │ ├── messagepack.py │ ├── 
pickle.py │ └── serializer.py │ ├── util │ ├── __init__.py │ ├── resolver.py │ └── template.py │ ├── vectors │ ├── __init__.py │ ├── base.py │ ├── external.py │ ├── factory.py │ ├── huggingface.py │ ├── litellm.py │ ├── llama.py │ ├── m2v.py │ ├── recovery.py │ ├── sbert.py │ └── words.py │ ├── version.py │ └── workflow │ ├── __init__.py │ ├── base.py │ ├── execute.py │ ├── factory.py │ └── task │ ├── __init__.py │ ├── base.py │ ├── console.py │ ├── export.py │ ├── factory.py │ ├── file.py │ ├── image.py │ ├── retrieve.py │ ├── service.py │ ├── storage.py │ ├── stream.py │ ├── template.py │ ├── url.py │ └── workflow.py └── test └── python ├── testagent.py ├── testann.py ├── testapi ├── __init__.py ├── testapiagent.py ├── testapiembeddings.py ├── testapipipeline.py ├── testapiworkflow.py ├── testauthorization.py ├── testcluster.py ├── testencoding.py ├── testextension.py ├── testmcp.py └── testopenai.py ├── testapp.py ├── testarchive.py ├── testcloud.py ├── testconsole.py ├── testdatabase ├── __init__.py ├── testclient.py ├── testcustom.py ├── testdatabase.py ├── testduckdb.py ├── testencoder.py ├── testrdbms.py ├── testsql.py └── testsqlite.py ├── testembeddings.py ├── testgraph.py ├── testmodels ├── __init__.py ├── testmodels.py └── testpooling.py ├── testoptional.py ├── testpipeline ├── __init__.py ├── testaudio │ ├── __init__.py │ ├── testaudiomixer.py │ ├── testaudiostream.py │ ├── testmicrophone.py │ ├── testtexttoaudio.py │ ├── testtexttospeech.py │ └── testtranscription.py ├── testdata │ ├── __init__.py │ ├── testfiletohtml.py │ ├── testtabular.py │ ├── testtextractor.py │ └── testtokenizer.py ├── testimage │ ├── __init__.py │ ├── testcaption.py │ ├── testimagehash.py │ └── testobjects.py ├── testllm │ ├── __init__.py │ ├── testgenerator.py │ ├── testlitellm.py │ ├── testllama.py │ ├── testllm.py │ ├── testrag.py │ └── testsequences.py ├── testtext │ ├── __init__.py │ ├── testentity.py │ ├── testlabels.py │ ├── testsummary.py │ └── testtranslation.py └── testtrain │ ├── __init__.py │ ├── testonnx.py │ ├── testquantization.py │ └── testtrainer.py ├── testscoring.py ├── testserialize.py ├── testvectors ├── __init__.py ├── testcustom.py ├── testexternal.py ├── testhuggingface.py ├── testlitellm.py ├── testllama.py ├── testm2v.py ├── testsbert.py ├── testvectors.py └── testwordvectors.py ├── testworkflow.py └── utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = src/python 3 | concurrency = multiprocessing,thread 4 | disable_warnings = no-data-collected 5 | omit = **/__main__.py 6 | 7 | [combine] 8 | disable_warnings = no-data-collected 9 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: 5 | - master 6 | jobs: 7 | deploy: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - uses: actions/setup-python@v4 12 | with: 13 | python-version: "3.10" 14 | - run: | 15 | pip install -U pip wheel 16 | pip install .[all,dev] 17 | - run: mkdocs gh-deploy --force 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | docker/**/*.yml 4 | htmlcov/ 5 | *egg-info/ 6 | __pycache__/ 7 | .coverage 8 | .coverage.* 9 | *.pyc 10 | .vscode/ 11 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pycqa/pylint 3 | rev: v3.3.1 4 | hooks: 5 | - id: pylint 6 | args: 7 | - -d import-error 8 | - -d duplicate-code 9 | - -d too-many-positional-arguments 10 | - repo: https://github.com/ambv/black 11 | rev: 24.10.0 12 | hooks: 13 | - id: black 14 | language_version: python3 15 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [BASIC] 2 | module-rgx=[a-z_][a-zA-Z0-9_]{2,30}$ 3 | method-rgx=[a-z_][a-zA-Z0-9_]{2,30}$ 4 | function-rgx=[a-z_][a-zA-Z0-9_]{2,30}$ 5 | argument-rgx=[a-z_][a-zA-Z0-9_]{0,30}$ 6 | variable-rgx=[a-z_][a-zA-Z0-9_]{0,30}$ 7 | attr-rgx=[a-z_][a-zA-Z0-9_]{0,30}$ 8 | 9 | [DESIGN] 10 | max-args=10 11 | max-locals=40 12 | max-returns=10 13 | max-attributes=20 14 | min-public-methods=0 15 | 16 | [FORMAT] 17 | max-line-length=150 18 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | date-released: 2020-08-11 3 | message: "If you use this software, please cite it as below." 4 | title: "txtai: the all-in-one AI framework" 5 | abstract: "txtai is an all-in-one open-source AI framework for semantic search, LLM orchestration and language model workflows" 6 | url: "https://github.com/neuml/txtai" 7 | authors: 8 | - family-names: "Mezzetti" 9 | given-names: "David" 10 | affiliation: NeuML 11 | license: Apache-2.0 12 | -------------------------------------------------------------------------------- /apps.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/apps.jpg -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/demo.gif -------------------------------------------------------------------------------- /docker/api/Dockerfile: -------------------------------------------------------------------------------- 1 | # Set base image 2 | ARG BASE_IMAGE=neuml/txtai-cpu 3 | FROM $BASE_IMAGE 4 | 5 | # Copy configuration 6 | COPY config.yml . 7 | 8 | # Run local API instance to cache models in container 9 | RUN python -c "from txtai.api import API; API('config.yml', False)" 10 | 11 | # Start server and listen on all interfaces 12 | ENV CONFIG "config.yml" 13 | ENTRYPOINT ["uvicorn", "--host", "0.0.0.0", "txtai.api:app"] 14 | -------------------------------------------------------------------------------- /docker/aws/Dockerfile: -------------------------------------------------------------------------------- 1 | # Set base image 2 | ARG BASE_IMAGE=neuml/txtai-cpu 3 | FROM $BASE_IMAGE 4 | 5 | # Application script to copy into image 6 | ARG APP=api.py 7 | 8 | # Install Lambda Runtime Interface Client and Mangum ASGI bindings 9 | RUN pip install awslambdaric mangum 10 | 11 | # Copy configuration 12 | COPY config.yml . 
13 | 14 | # Run local API instance to cache models in container 15 | RUN python -c "from txtai.api import API; API('config.yml', False)" 16 | 17 | # Copy application 18 | COPY $APP ./app.py 19 | 20 | # Start runtime client using default application handler 21 | ENV CONFIG "config.yml" 22 | ENTRYPOINT ["python", "-m", "awslambdaric"] 23 | CMD ["app.handler"] 24 | -------------------------------------------------------------------------------- /docker/aws/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lambda handler for a txtai API instance 3 | """ 4 | 5 | from mangum import Mangum 6 | 7 | from txtai.api import app, start 8 | 9 | # pylint: disable=C0103 10 | # Create FastAPI application instance wrapped by Mangum 11 | handler = None 12 | if not handler: 13 | # Start application 14 | start() 15 | 16 | # Create handler 17 | handler = Mangum(app, lifespan="off") 18 | -------------------------------------------------------------------------------- /docker/aws/workflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lambda handler for txtai workflows 3 | """ 4 | 5 | import json 6 | 7 | from txtai.api import API 8 | 9 | APP = None 10 | 11 | 12 | # pylint: disable=W0603,W0613 13 | def handler(event, context): 14 | """ 15 | Runs a workflow using input event parameters. 16 | 17 | Args: 18 | event: input event 19 | context: input context 20 | 21 | Returns: 22 | Workflow results 23 | """ 24 | 25 | # Create (or get) global app instance 26 | global APP 27 | APP = APP if APP else API("config.yml") 28 | 29 | # Get parameters from event body 30 | event = json.loads(event["body"]) 31 | 32 | # Run workflow and return results 33 | return {"statusCode": 200, "headers": {"Content-Type": "application/json"}, "body": list(APP.workflow(event["name"], event["elements"]))} 34 | -------------------------------------------------------------------------------- /docker/base/Dockerfile: -------------------------------------------------------------------------------- 1 | # Set base image 2 | ARG BASE_IMAGE=python:3.10-slim 3 | FROM $BASE_IMAGE 4 | 5 | # Install GPU-enabled version of PyTorch if set 6 | ARG GPU 7 | 8 | # Target CPU architecture 9 | ARG TARGETARCH 10 | 11 | # Set Python version (i.e. 
3, 3.10) 12 | ARG PYTHON_VERSION=3 13 | 14 | # List of txtai components to install 15 | ARG COMPONENTS=[all] 16 | 17 | # Locale environment variables 18 | ENV LC_ALL=C.UTF-8 19 | ENV LANG=C.UTF-8 20 | 21 | RUN \ 22 | # Install required packages 23 | apt-get update && \ 24 | apt-get -y --no-install-recommends install libgomp1 libportaudio2 libsndfile1 gcc g++ python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python3-pip && \ 25 | rm -rf /var/lib/apt/lists && \ 26 | \ 27 | # Install txtai project and dependencies 28 | ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python && \ 29 | python -m pip install --no-cache-dir -U pip wheel setuptools && \ 30 | if [ -z ${GPU} ] && { [ -z ${TARGETARCH} ] || [ ${TARGETARCH} = "amd64" ] ;}; then pip install --no-cache-dir torch==2.6.0+cpu torchvision==0.21.0+cpu -f https://download.pytorch.org/whl/torch -f https://download.pytorch.org/whl/torchvision; fi && \ 31 | python -m pip install --no-cache-dir txtai${COMPONENTS} && \ 32 | python -c "import sys, importlib.util as util; 1 if util.find_spec('nltk') else sys.exit(); import nltk; nltk.download(['punkt', 'punkt_tab', 'averaged_perceptron_tagger_eng'])" && \ 33 | \ 34 | # Cleanup build packages 35 | apt-get -y purge gcc g++ python${PYTHON_VERSION}-dev && apt-get -y autoremove 36 | 37 | # Set default working directory 38 | WORKDIR /app 39 | -------------------------------------------------------------------------------- /docker/schedule/Dockerfile: -------------------------------------------------------------------------------- 1 | # Set base image 2 | ARG BASE_IMAGE=neuml/txtai-cpu 3 | FROM $BASE_IMAGE 4 | 5 | # Copy configuration 6 | COPY config.yml . 7 | 8 | # Run local API instance to cache models in container 9 | RUN python -c "from txtai.api import API; API('config.yml', False)" 10 | 11 | # Start application and wait for completion. Scheduled workflows can run indefinitely. 12 | ENTRYPOINT ["python", "-c", "from txtai.api import API; API('config.yml').wait()"] 13 | -------------------------------------------------------------------------------- /docker/workflow/Dockerfile: -------------------------------------------------------------------------------- 1 | # Set base image 2 | ARG BASE_IMAGE=neuml/txtai-cpu 3 | FROM $BASE_IMAGE 4 | 5 | # Copy configuration 6 | COPY config.yml . 7 | 8 | # Run local API instance to cache models in container 9 | RUN python -c "from txtai.api import API; API('config.yml', False)" 10 | 11 | # Run workflow. Requires two command line arguments: name of workflow and input elements 12 | ENTRYPOINT ["python", "-c", "import sys; from txtai.api import API\nfor _ in API('config.yml').workflow(sys.argv[1], sys.argv[2:]): pass"] 13 | CMD ["workflow"] 14 | -------------------------------------------------------------------------------- /docs/agent/methods.md: -------------------------------------------------------------------------------- 1 | # Methods 2 | 3 | ## ::: txtai.agent.base.Agent.__init__ 4 | ## ::: txtai.agent.base.Agent.__call__ 5 | -------------------------------------------------------------------------------- /docs/api/cluster.md: -------------------------------------------------------------------------------- 1 | # Distributed embeddings clusters 2 | 3 | The API supports combining multiple API instances into a single logical embeddings index. An example configuration is shown below. 
4 | 5 | ```yaml 6 | cluster: 7 | shards: 8 | - http://127.0.0.1:8002 9 | - http://127.0.0.1:8003 10 | ``` 11 | 12 | This configuration aggregates the API instances above as index shards. Data is split evenly across the shards at index time. Queries are run in parallel against each shard and the results are joined together. This method allows horizontal scaling and supports very large index clusters. 13 | 14 | This method is only recommended for datasets with 1 billion+ records. The ANN libraries can easily handle smaller datasets, where this method is not worth the additional complexity. At this time, new shards cannot be added after building the initial index. 15 | 16 | See the link below for a detailed example covering distributed embeddings clusters. 17 | 18 | | Notebook | Description | | 19 | |:----------|:-------------|------:| 20 | | [Distributed embeddings cluster](https://github.com/neuml/txtai/blob/master/examples/15_Distributed_embeddings_cluster.ipynb) | Distribute an embeddings index across multiple data nodes | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuml/txtai/blob/master/examples/15_Distributed_embeddings_cluster.ipynb) | 21 | -------------------------------------------------------------------------------- /docs/api/customization.md: -------------------------------------------------------------------------------- 1 | # Customization 2 | 3 | The txtai API has a number of features out of the box designed to help users get started quickly. API services can also be augmented with custom code and functionality. The two main ways to do this are with extensions and dependencies. 4 | 5 | Extensions add a custom endpoint. Dependencies add middleware that executes with each request. See the sections below for more. 6 | 7 | ## Extensions 8 | 9 | While the API is extremely flexible and complex logic can be executed through YAML-driven workflows, some may prefer to create an endpoint in Python. API extensions define custom Python endpoints that interact with txtai applications. 10 | 11 | See the link below for a detailed example. 12 | 13 | | Notebook | Description | | 14 | |:----------|:-------------|------:| 15 | | [Custom API Endpoints](https://github.com/neuml/txtai/blob/master/examples/51_Custom_API_Endpoints.ipynb) | Extend the API with custom endpoints | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuml/txtai/blob/master/examples/51_Custom_API_Endpoints.ipynb) | 16 | 17 | ## Dependencies 18 | 19 | txtai has a default API token authorization method that works well in many cases. Dependencies can also add custom logic that runs with each request. This could be an additional authorization step and/or an authentication method. 20 | 21 | See the link below for a detailed example. 
22 | 23 | | Notebook | Description | | 24 | |:----------|:-------------|------:| 25 | | [API Authorization and Authentication](https://github.com/neuml/txtai/blob/master/examples/54_API_Authorization_and_Authentication.ipynb) | Add authorization, authentication and middleware dependencies to the API | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuml/txtai/blob/master/examples/54_API_Authorization_and_Authentication.ipynb) | 26 | -------------------------------------------------------------------------------- /docs/api/mcp.md: -------------------------------------------------------------------------------- 1 | # Model Context Protocol 2 | 3 | The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is an open standard that enables developers to build secure, two-way connections between their data sources and AI-powered tools. 4 | 5 | The API can be configured to handle MCP requests. All enabled endpoints set in the API configuration are automatically added as MCP tools. 6 | 7 | ```yaml 8 | mcp: True 9 | ``` 10 | 11 | Once this configuration option is added, a new route is added to the application `/mcp`. 12 | 13 | The [Model Context Protocol Inspector tool](https://www.npmjs.com/package/@modelcontextprotocol/inspector) is a quick way to explore how the MCP tools are exported through this interface. 14 | 15 | Run the following and go to the local URL specified. 16 | 17 | ``` 18 | npx @modelcontextprotocol/inspector node build/index.js 19 | ``` 20 | 21 | Enter `http://localhost:8000/mcp` to see the full list of tools available. 22 | -------------------------------------------------------------------------------- /docs/api/methods.md: -------------------------------------------------------------------------------- 1 | # Methods 2 | 3 | ::: txtai.api.API 4 | options: 5 | inherited_members: true 6 | filters: 7 | - "!__del__" 8 | - "!flows" 9 | - "!function" 10 | - "!indexes" 11 | - "!limit" 12 | - "!pipes" 13 | - "!read" 14 | - "!resolve" 15 | - "!weights" 16 | -------------------------------------------------------------------------------- /docs/api/openai.md: -------------------------------------------------------------------------------- 1 | # OpenAI-compatible API 2 | 3 | The API can be configured to serve an OpenAI-compatible API as shown below. 4 | 5 | ```yaml 6 | openai: True 7 | ``` 8 | 9 | See the link below for a detailed example. 10 | 11 | | Notebook | Description | | 12 | |:----------|:-------------|------:| 13 | | [OpenAI Compatible API](https://github.com/neuml/txtai/blob/master/examples/74_OpenAI_Compatible_API.ipynb) | Connect to txtai with a standard OpenAI client library | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuml/txtai/blob/master/examples/74_OpenAI_Compatible_API.ipynb) | 14 | -------------------------------------------------------------------------------- /docs/api/security.md: -------------------------------------------------------------------------------- 1 | # Security 2 | 3 | The default implementation of an API service runs via HTTP and is fully open. If the service is being run as a prototype on an internal network, that may be fine. In most scenarios, the connection should at least be encrypted. Authorization is another built-in feature that requires a valid API token with each request. See below for more. 
4 | 5 | ## HTTPS 6 | 7 | The default API service command starts a Uvicorn server as an HTTP service on port 8000. To run an HTTPS service, consider the following options. 8 | 9 | - [TLS Proxy Server](https://fastapi.tiangolo.com/deployment/https/). *Recommended choice*. With this configuration, the txtai API service runs as an HTTP service that is only accessible on the localhost/local network. The proxy server handles all encryption and redirects requests to local services. See this [example configuration](https://www.uvicorn.org/deployment/#running-behind-nginx) for more. 10 | 11 | - [Uvicorn SSL Certificate](https://www.uvicorn.org/deployment/). Another option is setting the SSL certificate on the Uvicorn service. This works in simple situations but gets complex when hosting multiple txtai or other related services. 12 | 13 | ## Authorization 14 | 15 | Authorization requires a valid API token with each API request. This token is sent as an HTTP `Authorization` header. 16 | 17 | *Server* 18 | ```bash 19 | CONFIG=config.yml TOKEN=<token> uvicorn "txtai.api:app" 20 | ``` 21 | 22 | *Client* 23 | ```bash 24 | curl \ 25 | -X POST "http://localhost:8000/workflow" \ 26 | -H "Content-Type: application/json" \ 27 | -H "Authorization: Bearer <token>" \ 28 | -d '{"name":"sumfrench", "elements": ["https://github.com/neuml/txtai"]}' 29 | ``` 30 | 31 | It's important to note that HTTPS **must** be enabled using one of the methods mentioned above. Otherwise, tokens will be exchanged as clear text. 32 | 33 | Authentication and authorization can be fully customized. See the [dependencies](../customization#dependencies) section for more. 34 | -------------------------------------------------------------------------------- /docs/embeddings/configuration/general.md: -------------------------------------------------------------------------------- 1 | # General 2 | 3 | General configuration options that don't fit elsewhere. 4 | 5 | ## keyword 6 | ```yaml 7 | keyword: boolean 8 | ``` 9 | 10 | Enables sparse keyword indexing for this embeddings instance. 11 | 12 | When enabled, this parameter creates a BM25 index for full-text search. It also implicitly disables the [defaults](../vectors/#defaults) setting for vector search. 13 | 14 | ## hybrid 15 | ```yaml 16 | hybrid: boolean 17 | ``` 18 | 19 | Enables hybrid (sparse + dense) indexing for this embeddings instance. 20 | 21 | When enabled, this parameter creates a BM25 index for full-text search. It has no effect on the [defaults](../vectors/#defaults) or [path](../vectors/#path) settings. 22 | 23 | ## indexes 24 | ```yaml 25 | indexes: dict 26 | ``` 27 | 28 | Key-value pairs defining subindexes for this embeddings instance. Each key is the index name and the value is the full configuration. A subindex can use any of the configuration options available to a standard embeddings instance. 29 | 30 | ## autoid 31 | ```yaml 32 | autoid: int|uuid function 33 | ``` 34 | 35 | Sets the auto id generation method. When this is not set, an autogenerated numeric sequence is used. This also supports [UUID generation functions](https://docs.python.org/3/library/uuid.html#uuid.uuid1). For example, setting this value to `uuid4` will generate random UUIDs. Setting this to `uuid5` will generate deterministic UUIDs for each input data row. 36 | 37 | ## columns 38 | ```yaml 39 | columns: 40 | text: name of the text column 41 | object: name of the object column 42 | ``` 43 | 44 | Sets the `text` and `object` column names. Defaults to `text` and `object` if not provided. 
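For example, a minimal sketch that remaps both columns - the `content` and `media` names below are illustrative choices, not defaults:

```yaml
columns:
  text: content
  object: media
```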
45 | 46 | ## format 47 | ```yaml 48 | format: json|pickle 49 | ``` 50 | 51 | Sets the configuration storage format. Defaults to `json`. 52 | -------------------------------------------------------------------------------- /docs/embeddings/configuration/index.md: -------------------------------------------------------------------------------- 1 | # Configuration 2 | 3 | The following describes available embeddings configuration. These parameters are set in the [Embeddings constructor](../methods#txtai.embeddings.base.Embeddings.__init__) via either the `config` parameter or as keyword arguments. 4 | 5 | Configuration is designed to be optional and set only when needed. Out of the box, sensible defaults are picked to get up and running fast. For example: 6 | 7 | ```python 8 | from txtai import Embeddings 9 | 10 | embeddings = Embeddings() 11 | ``` 12 | 13 | Creates a new embeddings instance, using [all-MiniLM-L6-v2](https://hf.co/sentence-transformers/all-MiniLM-L6-v2) as the vector model, [Faiss](https://faiss.ai/) as the ANN index backend and content storage disabled. 14 | 15 | ```python 16 | from txtai import Embeddings 17 | 18 | embeddings = Embeddings(content=True) 19 | ``` 20 | 21 | Is the same as above except it adds [SQLite](https://www.sqlite.org/index.html) for content storage. 22 | 23 | The following sections link to all the available configuration options. 24 | 25 | ## [ANN](./ann) 26 | 27 | The default vector index backend is Faiss. 28 | 29 | ## [Cloud](./cloud) 30 | 31 | Embeddings databases can optionally be synced with cloud storage. 32 | 33 | ## [Database](./database) 34 | 35 | Content storage is disabled by default. When enabled, SQLite is the default storage engine. 36 | 37 | ## [General](./general) 38 | 39 | General configuration that doesn't fit elsewhere. 40 | 41 | ## [Graph](./graph) 42 | 43 | An accompanying graph index can be created with an embeddings database. This enables topic modeling, path traversal and more. [NetworkX](https://github.com/networkx/networkx) is the default graph index. 44 | 45 | ## [Scoring](./scoring) 46 | 47 | Sparse keyword indexing and word vector term weighting. 48 | 49 | ## [Vectors](./vectors) 50 | 51 | Vector search is enabled by converting text and other binary data into embeddings vectors. These vectors are then stored in an ANN index. The vector model is optional and a default model is used when not provided. 52 | -------------------------------------------------------------------------------- /docs/embeddings/configuration/scoring.md: -------------------------------------------------------------------------------- 1 | # Scoring 2 | 3 | Enable scoring support via the `scoring` parameter. 4 | 5 | This scoring instance can serve two purposes, depending on the settings. 6 | 7 | One use case is building sparse/keyword indexes. This occurs when the `terms` parameter is set to `True`. 8 | 9 | The other use case is word vector term weighting. This feature has been available since the initial version but isn't quite as common anymore. 10 | 11 | The following covers the available options. 12 | 13 | ## method 14 | ```yaml 15 | method: bm25|tfidf|sif|pgtext|custom 16 | ``` 17 | 18 | Sets the scoring method. Add custom scoring by setting this parameter to a fully resolvable class string. 19 | 20 | ### pgtext 21 | ```yaml 22 | schema: database schema to store keyword index - defaults to being 23 | determined by the database 24 | ``` 25 | 26 | Additional settings for Postgres full-text keyword indexes. 
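As a sketch, a scoring configuration for a Postgres full-text keyword index might look like the following - the `txtai` schema name here is an illustrative assumption:

```yaml
scoring:
  method: pgtext
  schema: txtai
```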
27 | 28 | ## terms 29 | ```yaml 30 | terms: boolean|dict 31 | ``` 32 | 33 | Enables term frequency sparse arrays for a scoring instance. This is the backend for sparse keyword indexes. 34 | 35 | Supports a `dict` with the parameters `cachelimit` and `cutoff`. 36 | 37 | `cachelimit` is the maximum amount of resident memory in bytes to use during indexing before flushing to disk. This parameter is an `int`. 38 | 39 | `cutoff` is used during search to determine what constitutes a common term. This parameter is a `float`, e.g. 0.1 for a cutoff of 10%. 40 | 41 | When `terms` is set to `True`, default parameters are used for `cachelimit` and `cutoff`. Normally, these defaults are sufficient. 42 | 43 | ## normalize 44 | ```yaml 45 | normalize: boolean 46 | ``` 47 | 48 | Enables normalized scoring (ranging from 0 to 1). When enabled, statistics from the index are used to calculate normalized scores. 49 | -------------------------------------------------------------------------------- /docs/embeddings/methods.md: -------------------------------------------------------------------------------- 1 | # Methods 2 | 3 | ::: txtai.embeddings.Embeddings 4 | options: 5 | filters: 6 | - "!columns" 7 | - "!createann" 8 | - "!createcloud" 9 | - "!createdatabase" 10 | - "!creategraph" 11 | - "!createids" 12 | - "!createindexes" 13 | - "!createscoring" 14 | - "!checkarchive" 15 | - "!configure" 16 | - "!defaultallowed" 17 | - "!defaults" 18 | - "!initindex" 19 | - "!loadquery" 20 | - "!loadvectors" 21 | -------------------------------------------------------------------------------- /docs/further.md: -------------------------------------------------------------------------------- 1 | # Further reading 2 | 3 | ![further](images/further.png#only-light) 4 | ![further](images/further-dark.png#only-dark) 5 | 6 | - [Introducing txtai, the all-in-one AI framework](https://medium.com/neuml/introducing-txtai-the-all-in-one-ai-framework-0660ecfc39d7) 7 | - [Tutorial series on Hashnode](https://neuml.hashnode.dev/series/txtai-tutorial) | [dev.to](https://dev.to/neuml/tutorial-series-on-txtai-ibg) 8 | - [What's new in txtai 8.0](https://medium.com/neuml/whats-new-in-txtai-8-0-2d7d0ab4506b) | [7.0](https://medium.com/neuml/whats-new-in-txtai-7-0-855ad6a55440) | [6.0](https://medium.com/neuml/whats-new-in-txtai-6-0-7d93eeedf804) | [5.0](https://medium.com/neuml/whats-new-in-txtai-5-0-e5c75a13b101) | [4.0](https://medium.com/neuml/whats-new-in-txtai-4-0-bbc3a65c3d1c) 9 | - [Getting started with semantic search](https://medium.com/neuml/getting-started-with-semantic-search-a9fd9d8a48cf) | [workflows](https://medium.com/neuml/getting-started-with-semantic-workflows-2fefda6165d9) | [rag](https://medium.com/neuml/getting-started-with-rag-9a0cca75f748) 10 | - [Running txtai at scale](https://medium.com/neuml/running-at-scale-with-txtai-71196cdd99f9) 11 | - [Vector search & RAG Landscape: A review with txtai](https://medium.com/neuml/vector-search-rag-landscape-a-review-with-txtai-a7f37ad0e187) 12 | -------------------------------------------------------------------------------- /docs/images/agent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/agent.png -------------------------------------------------------------------------------- /docs/images/api-dark.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/api-dark.png -------------------------------------------------------------------------------- /docs/images/api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/api.png -------------------------------------------------------------------------------- /docs/images/architecture-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/architecture-dark.png -------------------------------------------------------------------------------- /docs/images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/architecture.png -------------------------------------------------------------------------------- /docs/images/cloud-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/cloud-dark.png -------------------------------------------------------------------------------- /docs/images/cloud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/cloud.png -------------------------------------------------------------------------------- /docs/images/embeddings-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/embeddings-dark.png -------------------------------------------------------------------------------- /docs/images/embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/embeddings.png -------------------------------------------------------------------------------- /docs/images/examples-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/examples-dark.png -------------------------------------------------------------------------------- /docs/images/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/examples.png -------------------------------------------------------------------------------- /docs/images/faq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/faq.png -------------------------------------------------------------------------------- /docs/images/flows-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/flows-dark.png -------------------------------------------------------------------------------- 
/docs/images/flows.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/flows.png -------------------------------------------------------------------------------- /docs/images/format-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/format-dark.png -------------------------------------------------------------------------------- /docs/images/format.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/format.png -------------------------------------------------------------------------------- /docs/images/further-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/further-dark.png -------------------------------------------------------------------------------- /docs/images/further-ghdark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/further-ghdark.png -------------------------------------------------------------------------------- /docs/images/further.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/further.png -------------------------------------------------------------------------------- /docs/images/indexing-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/indexing-dark.png -------------------------------------------------------------------------------- /docs/images/indexing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/indexing.png -------------------------------------------------------------------------------- /docs/images/install-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/install-dark.png -------------------------------------------------------------------------------- /docs/images/install.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/install.png -------------------------------------------------------------------------------- /docs/images/llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/llm.png -------------------------------------------------------------------------------- /docs/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/logo.png 
-------------------------------------------------------------------------------- /docs/images/models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/models.png -------------------------------------------------------------------------------- /docs/images/pipeline-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/pipeline-dark.png -------------------------------------------------------------------------------- /docs/images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/pipeline.png -------------------------------------------------------------------------------- /docs/images/query-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/query-dark.png -------------------------------------------------------------------------------- /docs/images/query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/query.png -------------------------------------------------------------------------------- /docs/images/rag-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/rag-dark.png -------------------------------------------------------------------------------- /docs/images/rag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/rag.png -------------------------------------------------------------------------------- /docs/images/schedule-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/schedule-dark.png -------------------------------------------------------------------------------- /docs/images/schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/schedule.png -------------------------------------------------------------------------------- /docs/images/search-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/search-dark.png -------------------------------------------------------------------------------- /docs/images/search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/search.png -------------------------------------------------------------------------------- /docs/images/task-dark.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/task-dark.png -------------------------------------------------------------------------------- /docs/images/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/task.png -------------------------------------------------------------------------------- /docs/images/why-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/why-dark.png -------------------------------------------------------------------------------- /docs/images/why.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/why.png -------------------------------------------------------------------------------- /docs/images/workflow-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/workflow-dark.png -------------------------------------------------------------------------------- /docs/images/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/docs/images/workflow.png -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block extrahead %} 4 | {% set title = config.site_name %} 5 | {% if page and page.meta and page.meta.title %} 6 | {% set title = title ~ " - " ~ page.meta.title %} 7 | {% elif page and page.title and not page.is_homepage %} 8 | {% set title = title ~ " - " ~ page.title %} 9 | {% endif %} 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | {% endblock %} 25 | -------------------------------------------------------------------------------- /docs/pipeline/train/hfonnx.md: -------------------------------------------------------------------------------- 1 | # HFOnnx 2 | 3 | ![pipeline](../../images/pipeline.png#only-light) 4 | ![pipeline](../../images/pipeline-dark.png#only-dark) 5 | 6 | Exports a Hugging Face Transformer model to ONNX. Currently, this works best with classification/pooling/qa models. Work is ongoing for sequence to 7 | sequence models (summarization, transcription, translation). 8 | 9 | ## Example 10 | 11 | The following shows a simple example using this pipeline. 12 | 13 | ```python 14 | from txtai.pipeline import HFOnnx, Labels 15 | 16 | # Model path 17 | path = "distilbert-base-uncased-finetuned-sst-2-english" 18 | 19 | # Export model to ONNX 20 | onnx = HFOnnx() 21 | model = onnx(path, "text-classification", "model.onnx", True) 22 | 23 | # Run inference and validate 24 | labels = Labels((model, path), dynamic=False) 25 | labels("I am happy") 26 | ``` 27 | 28 | See the link below for a more detailed example. 
29 | 30 | | Notebook | Description | | 31 | |:----------|:-------------|------:| 32 | | [Export and run models with ONNX](https://github.com/neuml/txtai/blob/master/examples/18_Export_and_run_models_with_ONNX.ipynb) | Export models with ONNX, run natively in JavaScript, Java and Rust | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuml/txtai/blob/master/examples/18_Export_and_run_models_with_ONNX.ipynb) | 33 | 34 | ## Methods 35 | 36 | Python documentation for the pipeline. 37 | 38 | ### ::: txtai.pipeline.HFOnnx.__call__ 39 | -------------------------------------------------------------------------------- /docs/pipeline/train/mlonnx.md: -------------------------------------------------------------------------------- 1 | # MLOnnx 2 | 3 | ![pipeline](../../images/pipeline.png#only-light) 4 | ![pipeline](../../images/pipeline-dark.png#only-dark) 5 | 6 | Exports a traditional machine learning model (e.g. scikit-learn) to ONNX. 7 | 8 | ## Example 9 | 10 | See the link below for a detailed example. 11 | 12 | | Notebook | Description | | 13 | |:----------|:-------------|------:| 14 | | [Export and run other machine learning models](https://github.com/neuml/txtai/blob/master/examples/21_Export_and_run_other_machine_learning_models.ipynb) | Export and run models from scikit-learn, PyTorch and more | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuml/txtai/blob/master/examples/21_Export_and_run_other_machine_learning_models.ipynb) | 15 | 16 | ## Methods 17 | 18 | Python documentation for the pipeline. 19 | 20 | ### ::: txtai.pipeline.MLOnnx.__call__ 21 | -------------------------------------------------------------------------------- /docs/poweredby.md: -------------------------------------------------------------------------------- 1 | # Powered by txtai 2 | 3 | The following applications are powered by txtai. 4 | 5 | ![apps](https://raw.githubusercontent.com/neuml/txtai/master/apps.jpg) 6 | 7 | | Application | Description | 8 | |:------------ |:-------------| 9 | | [rag](https://github.com/neuml/rag) | Retrieval Augmented Generation (RAG) application | 10 | | [ragdata](https://github.com/neuml/ragdata) | Build knowledge bases for RAG | 11 | | [paperai](https://github.com/neuml/paperai) | Semantic search and workflows for medical/scientific papers | 12 | | [annotateai](https://github.com/neuml/annotateai) | Automatically annotate papers with LLMs | 13 | 14 | In addition to this list, there are also many other [open-source projects](https://github.com/neuml/txtai/network/dependents), [published research](https://scholar.google.com/scholar?q=txtai&hl=en&as_ylo=2022) and closed proprietary/commercial projects that have built on txtai in production. 15 | -------------------------------------------------------------------------------- /docs/why.md: -------------------------------------------------------------------------------- 1 | # Why txtai? 2 | 3 | ![why](images/why.png#only-light) 4 | ![why](images/why-dark.png#only-dark) 5 | 6 | New vector databases, LLM frameworks and everything in between are sprouting up daily. Why build with txtai? 
7 | 8 | - Up and running in minutes with [pip](../install/) or [Docker](../cloud/) 9 | ```python 10 | # Get started in a couple lines 11 | import txtai 12 | 13 | embeddings = txtai.Embeddings() 14 | embeddings.index(["Correct", "Not what we hoped"]) 15 | embeddings.search("positive", 1) 16 | #[(0, 0.29862046241760254)] 17 | ``` 18 | - Built-in API makes it easy to develop applications using your programming language of choice 19 | ```yaml 20 | # app.yml 21 | embeddings: 22 | path: sentence-transformers/all-MiniLM-L6-v2 23 | ``` 24 | ```bash 25 | CONFIG=app.yml uvicorn "txtai.api:app" 26 | curl -X GET "http://localhost:8000/search?query=positive" 27 | ``` 28 | - Run local - no need to ship data off to disparate remote services 29 | - Work with micromodels all the way up to large language models (LLMs) 30 | - Low footprint - install additional dependencies and scale up when needed 31 | - [Learn by example](../examples) - notebooks cover all available functionality 32 | -------------------------------------------------------------------------------- /docs/workflow/task/console.md: -------------------------------------------------------------------------------- 1 | # Console Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Console Task prints task inputs and outputs to standard output. This task is mainly used for debugging and can be added at any point in a workflow. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow. 11 | 12 | ```python 13 | from txtai.workflow import ConsoleTask, Workflow 14 | 15 | workflow = Workflow([ConsoleTask()]) 16 | workflow(["Input 1", "Input 2"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: console 27 | ``` 28 | 29 | ## Methods 30 | 31 | Python documentation for the task. 32 | 33 | ### ::: txtai.workflow.ConsoleTask.__init__ 34 | -------------------------------------------------------------------------------- /docs/workflow/task/export.md: -------------------------------------------------------------------------------- 1 | # Export Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Export Task exports task outputs to CSV or Excel. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow. 11 | 12 | ```python 13 | from txtai.workflow import ExportTask, Workflow 14 | 15 | workflow = Workflow([ExportTask()]) 16 | workflow(["Input 1", "Input 2"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: export 27 | ``` 28 | 29 | ## Methods 30 | 31 | Python documentation for the task. 32 | 33 | ### ::: txtai.workflow.ExportTask.__init__ 34 | ### ::: txtai.workflow.ExportTask.register 35 | -------------------------------------------------------------------------------- /docs/workflow/task/file.md: -------------------------------------------------------------------------------- 1 | # File Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The File Task validates that a file exists. It handles both file paths and local file URLs. Note that this task _only_ works with local files.
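File URLs are normalized to plain paths before validation. The following is a minimal sketch of that general idea, shown for illustration only; it is not the FileTask implementation.

```python
# Minimal sketch of file URL handling, assuming only the Python standard library.
# Illustration only - not the actual FileTask implementation.
import os

from urllib.parse import urlparse
from urllib.request import url2pathname

def localpath(element):
    # Convert file:// URLs to a local path, pass plain paths through
    url = urlparse(element)
    path = url2pathname(url.path) if url.scheme == "file" else element

    # Validate the file exists
    if not os.path.isfile(path):
        raise FileNotFoundError(path)

    return path
```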
7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow. 11 | 12 | ```python 13 | from txtai.workflow import FileTask, Workflow 14 | 15 | workflow = Workflow([FileTask()]) 16 | workflow(["/path/to/file", "file:///path/to/file"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: file 27 | ``` 28 | 29 | ## Methods 30 | 31 | Python documentation for the task. 32 | 33 | ### ::: txtai.workflow.FileTask.__init__ 34 | -------------------------------------------------------------------------------- /docs/workflow/task/image.md: -------------------------------------------------------------------------------- 1 | # Image Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Image Task reads file paths, checks that each file is an image and opens it as an Image object. Note that this task _only_ works with local files. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow. 11 | 12 | ```python 13 | from txtai.workflow import ImageTask, Workflow 14 | 15 | workflow = Workflow([ImageTask()]) 16 | workflow(["image.jpg", "image.gif"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: image 27 | ``` 28 | 29 | ## Methods 30 | 31 | Python documentation for the task. 32 | 33 | ### ::: txtai.workflow.ImageTask.__init__ 34 | -------------------------------------------------------------------------------- /docs/workflow/task/retrieve.md: -------------------------------------------------------------------------------- 1 | # Retrieve Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Retrieve Task connects to a URL and downloads the content locally. This task is helpful when working with actions that require data to be available locally. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow. 11 | 12 | ```python 13 | from txtai.workflow import RetrieveTask, Workflow 14 | 15 | workflow = Workflow([RetrieveTask(directory="/tmp")]) 16 | workflow(["https://file.to.download", "/local/file/to/copy"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: retrieve 27 | directory: /tmp 28 | ``` 29 | 30 | ## Methods 31 | 32 | Python documentation for the task. 33 | 34 | ### ::: txtai.workflow.RetrieveTask.__init__ 35 | ### ::: txtai.workflow.RetrieveTask.register 36 | -------------------------------------------------------------------------------- /docs/workflow/task/service.md: -------------------------------------------------------------------------------- 1 | # Service Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Service Task extracts content from an HTTP service. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow.
11 | 12 | ```python 13 | from txtai.workflow import ServiceTask, Workflow 14 | 15 | workflow = Workflow([ServiceTask(url="https://service.url/action")]) 16 | workflow(["parameter"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: service 27 | url: https://service.url/action 28 | ``` 29 | 30 | ## Methods 31 | 32 | Python documentation for the task. 33 | 34 | ### ::: txtai.workflow.ServiceTask.__init__ 35 | ### ::: txtai.workflow.ServiceTask.register 36 | -------------------------------------------------------------------------------- /docs/workflow/task/storage.md: -------------------------------------------------------------------------------- 1 | # Storage Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Storage Task expands a local directory or cloud storage bucket into a list of URLs to process. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow. 11 | 12 | ```python 13 | from txtai.workflow import StorageTask, Workflow 14 | 15 | workflow = Workflow([StorageTask()]) 16 | workflow(["s3://path/to/bucket", "local://local/directory"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: storage 27 | ``` 28 | 29 | ## Methods 30 | 31 | Python documentation for the task. 32 | 33 | ### ::: txtai.workflow.StorageTask.__init__ 34 | -------------------------------------------------------------------------------- /docs/workflow/task/template.md: -------------------------------------------------------------------------------- 1 | # Template Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Template Task generates text from a template and task inputs. Templates can be used to prepare data for a number of tasks including generating large 7 | language model (LLM) prompts. 8 | 9 | ## Example 10 | 11 | The following shows a simple example using this task as part of a workflow. 12 | 13 | ```python 14 | from txtai.workflow import TemplateTask, Workflow 15 | 16 | workflow = Workflow([TemplateTask(template="This is a {text} task")]) 17 | workflow([{"text": "template"}]) 18 | ``` 19 | 20 | ## Configuration-driven example 21 | 22 | This task can also be created with workflow configuration. 23 | 24 | ```yaml 25 | workflow: 26 | tasks: 27 | - task: template 28 | template: This is a {text} task 29 | ``` 30 | 31 | ## Methods 32 | 33 | Python documentation for the task. 34 | 35 | ### ::: txtai.workflow.TemplateTask.__init__ 36 | -------------------------------------------------------------------------------- /docs/workflow/task/url.md: -------------------------------------------------------------------------------- 1 | # Url Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Url Task validates that inputs start with a URL prefix. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow.
11 | 12 | ```python 13 | from txtai.workflow import UrlTask, Workflow 14 | 15 | workflow = Workflow([UrlTask()]) 16 | workflow(["https://file.to.download", "file:///local/file/to/copy"]) 17 | ``` 18 | 19 | ## Configuration-driven example 20 | 21 | This task can also be created with workflow configuration. 22 | 23 | ```yaml 24 | workflow: 25 | tasks: 26 | - task: url 27 | ``` 28 | 29 | ## Methods 30 | 31 | Python documentation for the task. 32 | 33 | ### ::: txtai.workflow.UrlTask.__init__ 34 | -------------------------------------------------------------------------------- /docs/workflow/task/workflow.md: -------------------------------------------------------------------------------- 1 | # Workflow Task 2 | 3 | ![task](../../images/task.png#only-light) 4 | ![task](../../images/task-dark.png#only-dark) 5 | 6 | The Workflow Task runs a workflow. This allows creating workflows of workflows. 7 | 8 | ## Example 9 | 10 | The following shows a simple example using this task as part of a workflow. 11 | 12 | ```python 13 | from txtai.workflow import WorkflowTask, Workflow 14 | 15 | workflow = Workflow([WorkflowTask(otherworkflow)]) 16 | workflow(["input data"]) 17 | ``` 18 | 19 | ## Methods 20 | 21 | Python documentation for the task. 22 | 23 | ### ::: txtai.workflow.WorkflowTask.__init__ 24 | -------------------------------------------------------------------------------- /examples/article.py: -------------------------------------------------------------------------------- 1 | """ 2 | Application that builds a summary of an article. 3 | 4 | Requires streamlit to be installed. 5 | pip install streamlit 6 | """ 7 | 8 | import os 9 | 10 | import streamlit as st 11 | 12 | from txtai.pipeline import Summary, Textractor 13 | from txtai.workflow import UrlTask, Task, Workflow 14 | 15 | 16 | class Application: 17 | """ 18 | Main application. 19 | """ 20 | 21 | def __init__(self): 22 | """ 23 | Creates a new application. 24 | """ 25 | 26 | textract = Textractor(paragraphs=True, minlength=100, join=True) 27 | summary = Summary("sshleifer/distilbart-cnn-12-6") 28 | 29 | self.workflow = Workflow([UrlTask(textract), Task(summary)]) 30 | 31 | def run(self): 32 | """ 33 | Runs a Streamlit application. 34 | """ 35 | 36 | st.title("Article Summary") 37 | st.markdown("This application builds a summary of an article.") 38 | 39 | url = st.text_input("URL") 40 | if url: 41 | # Run workflow and get summary 42 | summary = list(self.workflow([url]))[0] 43 | 44 | # Write results 45 | st.write(summary) 46 | st.markdown("*Source: " + url + "*") 47 | 48 | 49 | @st.cache(allow_output_mutation=True) 50 | def create(): 51 | """ 52 | Creates and caches a Streamlit application.
53 | 54 | Returns: 55 | Application 56 | """ 57 | 58 | return Application() 59 | 60 | 61 | if __name__ == "__main__": 62 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 63 | 64 | # Create and run application 65 | app = create() 66 | app.run() 67 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/logo.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 150 3 | -------------------------------------------------------------------------------- /src/python/txtai/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base imports 3 | """ 4 | 5 | import logging 6 | 7 | # Top-level imports 8 | from .agent import Agent 9 | from .app import Application 10 | from .embeddings import Embeddings 11 | from .pipeline import LLM, RAG 12 | from .workflow import Workflow 13 | 14 | # Configure logging per standard Python library recommendations 15 | logger = logging.getLogger(__name__) 16 | logger.addHandler(logging.NullHandler()) 17 | -------------------------------------------------------------------------------- /src/python/txtai/agent/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Agent imports 3 | """ 4 | 5 | # Conditional import 6 | try: 7 | from .base import Agent 8 | from .factory import ProcessFactory 9 | from .model import PipelineModel 10 | from .tool import * 11 | except ImportError: 12 | from .placeholder import Agent 13 | -------------------------------------------------------------------------------- /src/python/txtai/agent/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Agent module 3 | """ 4 | 5 | from .factory import ProcessFactory 6 | 7 | 8 | class Agent: 9 | """ 10 | An agent automatically creates workflows to answer multi-faceted user requests. Agents iteratively prompt and/or interface with tools to 11 | step through a process and ultimately come to an answer for a request. 12 | 13 | Agents excel at complex tasks where multiple tools and/or methods are required. They incorporate a level of randomness similar to different 14 | people working on the same task. When the request is simple and/or there is a rule-based process, other methods such as RAG and Workflows 15 | should be explored. 16 | """ 17 | 18 | def __init__(self, **kwargs): 19 | """ 20 | Creates a new Agent. 21 | 22 | Args: 23 | kwargs: arguments to pass to the underlying Agent backend and LLM pipeline instance 24 | """ 25 | 26 | # Ensure backwards compatibility 27 | if "max_iterations" in kwargs: 28 | kwargs["max_steps"] = kwargs.pop("max_iterations") 29 | 30 | # Create agent process runner 31 | self.process = ProcessFactory.create(kwargs) 32 | 33 | # Tools dictionary 34 | self.tools = self.process.tools 35 | 36 | def __call__(self, text, maxlength=8192, stream=False, **kwargs): 37 | """ 38 | Runs an agent loop. 
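        A minimal usage sketch follows. The tool configuration and LLM path are
        illustrative assumptions, not defaults:

            agent = Agent(
                tools=[{"name": "search", "description": "Searches a txtai embeddings database", "target": embeddings}],
                llm="Qwen/Qwen2.5-7B-Instruct"
            )
            agent("Find the most popular city in the index")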
39 | 40 | Args: 41 | text: instructions to run 42 | maxlength: maximum sequence length 43 | stream: stream response if True, defaults to False 44 | kwargs: additional keyword arguments 45 | 46 | Returns: 47 | result 48 | """ 49 | 50 | # Process parameters 51 | self.process.model.parameters(maxlength) 52 | 53 | # Run agent loop 54 | return self.process.run(text, stream=stream, **kwargs) 55 | -------------------------------------------------------------------------------- /src/python/txtai/agent/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from smolagents import CodeAgent, ToolCallingAgent 6 | 7 | from .model import PipelineModel 8 | from .tool import ToolFactory 9 | 10 | 11 | class ProcessFactory: 12 | """ 13 | Methods to create agent processes. 14 | """ 15 | 16 | @staticmethod 17 | def create(config): 18 | """ 19 | Create an agent process runner. The agent process runner takes a list of tools and an LLM 20 | and executes an agent process flow. 21 | 22 | Args: 23 | config: agent configuration 24 | 25 | Returns: 26 | agent process runner 27 | """ 28 | 29 | constructor = ToolCallingAgent 30 | method = config.pop("method", None) 31 | if method == "code": 32 | constructor = CodeAgent 33 | 34 | # Create model backed by LLM pipeline 35 | model = config.pop("model", config.pop("llm", None)) 36 | model = PipelineModel(**model) if isinstance(model, dict) else PipelineModel(model) 37 | 38 | # Create the agent process 39 | return constructor(tools=ToolFactory.create(config), model=model, **config) 40 | -------------------------------------------------------------------------------- /src/python/txtai/agent/placeholder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Placeholder module 3 | """ 4 | 5 | 6 | class Agent: 7 | """ 8 | Agent placeholder stub for when smolagents isn't installed 9 | """ 10 | 11 | def __init__(self, *args, **kwargs): 12 | """ 13 | Raises an exception that smolagents isn't installed. 14 | """ 15 | 16 | raise ImportError('smolagents is not available - install "agent" extra to enable') 17 | -------------------------------------------------------------------------------- /src/python/txtai/agent/tool/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tool imports 3 | """ 4 | 5 | from .embeddings import EmbeddingsTool 6 | from .factory import ToolFactory 7 | from .function import FunctionTool 8 | -------------------------------------------------------------------------------- /src/python/txtai/agent/tool/embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Embeddings module 3 | """ 4 | 5 | from smolagents import Tool 6 | 7 | from ...embeddings import Embeddings 8 | 9 | 10 | class EmbeddingsTool(Tool): 11 | """ 12 | Tool to execute an Embeddings search. 13 | """ 14 | 15 | def __init__(self, config): 16 | """ 17 | Creates a new EmbeddingsTool. 18 | 19 | Args: 20 | config: embeddings tool configuration 21 | """ 22 | 23 | # Tool parameters 24 | self.name = config["name"] 25 | self.description = f"""{config['description']}. Results are returned as a list of dict elements. 
26 |         Each result has keys 'id', 'text', 'score'.""" 27 | 28 |         # Input and output descriptions 29 |         self.inputs = {"query": {"type": "string", "description": "The search query to perform."}} 30 |         self.output_type = "any" 31 | 32 |         # Load embeddings instance 33 |         self.embeddings = self.load(config) 34 | 35 |         # Validate parameters and initialize tool 36 |         super().__init__() 37 | 38 |     # pylint: disable=W0221 39 |     def forward(self, query): 40 |         """ 41 |         Runs a search. 42 | 43 |         Args: 44 |             query: input query 45 | 46 |         Returns: 47 |             search results 48 |         """ 49 | 50 |         return self.embeddings.search(query, 5) 51 | 52 |     def load(self, config): 53 |         """ 54 |         Loads an embeddings instance from config. 55 | 56 |         Args: 57 |             config: embeddings tool configuration 58 | 59 |         Returns: 60 |             Embeddings 61 |         """ 62 | 63 |         if "target" in config: 64 |             return config["target"] 65 | 66 |         embeddings = Embeddings() 67 |         embeddings.load(**config) 68 | 69 |         return embeddings 70 | -------------------------------------------------------------------------------- /src/python/txtai/agent/tool/function.py: -------------------------------------------------------------------------------- 1 | """ 2 | Function module 3 | """ 4 | 5 | from smolagents import Tool 6 | 7 | 8 | class FunctionTool(Tool): 9 | """ 10 | Creates a FunctionTool. A FunctionTool takes descriptive configuration and injects it along with a target function 11 | into an LLM prompt. 12 | """ 13 | 14 |     # pylint: disable=W0231 15 |     def __init__(self, config): 16 |         """ 17 |         Creates a FunctionTool. 18 | 19 |         Args: 20 |             config: `name`, `description`, `inputs`, `output` and `target` configuration 21 |         """ 22 | 23 |         # Tool parameters 24 |         self.name = config["name"] 25 |         self.description = config["description"] 26 |         self.inputs = config["inputs"] 27 |         self.output_type = config.get("output", config.get("output_type", "any")) 28 |         self.target = config["target"] 29 | 30 |         # pylint: disable=C0103 31 |         # Skip forward signature validation 32 |         self.skip_forward_signature_validation = True 33 | 34 |         # Validate parameters and initialize tool 35 |         super().__init__() 36 | 37 |     def forward(self, *args, **kwargs): 38 |         """ 39 |         Runs target function. 40 | 41 |         Args: 42 |             args: positional args 43 |             kwargs: keyword args 44 | 45 |         Returns: 46 |             result 47 |         """ 48 | 49 |         return self.target(*args, **kwargs) 50 | -------------------------------------------------------------------------------- /src/python/txtai/ann/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ANN imports 3 | """ 4 | 5 | from .annoy import Annoy 6 | from .base import ANN 7 | from .factory import ANNFactory 8 | from .faiss import Faiss 9 | from .hnsw import HNSW 10 | from .numpy import NumPy 11 | from .pgvector import PGVector 12 | from .sqlite import SQLite 13 | from .torch import Torch 14 | -------------------------------------------------------------------------------- /src/python/txtai/ann/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from ..util import Resolver 6 | 7 | from .annoy import Annoy 8 | from .faiss import Faiss 9 | from .hnsw import HNSW 10 | from .numpy import NumPy 11 | from .pgvector import PGVector 12 | from .sqlite import SQLite 13 | from .torch import Torch 14 | 15 | 16 | class ANNFactory: 17 | """ 18 | Methods to create ANN indexes. 19 | """ 20 | 21 | @staticmethod 22 | def create(config): 23 | """ 24 | Create an ANN.
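        The backend defaults to "faiss". A hypothetical example (config keys shown
        are illustrative):

            ann = ANNFactory.create({"backend": "hnsw"})

        A custom backend can also be given as a full class path, for example
        "mymodule.MyANN", which is resolved via Resolver below.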
25 | 26 | Args: 27 | config: index configuration parameters 28 | 29 | Returns: 30 | ANN 31 | """ 32 | 33 | # ANN instance 34 | ann = None 35 | backend = config.get("backend", "faiss") 36 | 37 | # Create ANN instance 38 | if backend == "annoy": 39 | ann = Annoy(config) 40 | elif backend == "faiss": 41 | ann = Faiss(config) 42 | elif backend == "hnsw": 43 | ann = HNSW(config) 44 | elif backend == "numpy": 45 | ann = NumPy(config) 46 | elif backend == "pgvector": 47 | ann = PGVector(config) 48 | elif backend == "sqlite": 49 | ann = SQLite(config) 50 | elif backend == "torch": 51 | ann = Torch(config) 52 | else: 53 | ann = ANNFactory.resolve(backend, config) 54 | 55 | # Store config back 56 | config["backend"] = backend 57 | 58 | return ann 59 | 60 | @staticmethod 61 | def resolve(backend, config): 62 | """ 63 | Attempt to resolve a custom backend. 64 | 65 | Args: 66 | backend: backend class 67 | config: index configuration parameters 68 | 69 | Returns: 70 | ANN 71 | """ 72 | 73 | try: 74 | return Resolver()(backend)(config) 75 | except Exception as e: 76 | raise ImportError(f"Unable to resolve ann backend: '{backend}'") from e 77 | -------------------------------------------------------------------------------- /src/python/txtai/ann/torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch module 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from .numpy import NumPy 9 | 10 | 11 | class Torch(NumPy): 12 | """ 13 | Builds an ANN index backed by a PyTorch array. 14 | """ 15 | 16 | def __init__(self, config): 17 | super().__init__(config) 18 | 19 | # Define array functions 20 | self.all, self.cat, self.dot, self.zeros = torch.all, torch.cat, torch.mm, torch.zeros 21 | self.argsort, self.xor, self.clip = torch.argsort, torch.bitwise_xor, torch.clip 22 | 23 | def tensor(self, array): 24 | # Convert array to Tensor 25 | if isinstance(array, np.ndarray): 26 | array = torch.from_numpy(array) 27 | 28 | # Load to GPU device, if available 29 | return array.cuda() if torch.cuda.is_available() else array 30 | 31 | def numpy(self, array): 32 | return array.cpu().numpy() 33 | 34 | def totype(self, array, dtype): 35 | return array.long() if dtype == np.int64 else array 36 | 37 | def settings(self): 38 | return {"torch": torch.__version__} 39 | -------------------------------------------------------------------------------- /src/python/txtai/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | API imports 3 | """ 4 | 5 | # Conditional import 6 | try: 7 | from .authorization import Authorization 8 | from .application import app, start 9 | from .base import API 10 | from .cluster import Cluster 11 | from .extension import Extension 12 | from .factory import APIFactory 13 | from .responses import * 14 | from .routers import * 15 | from .route import EncodingAPIRoute 16 | except ImportError as missing: 17 | # pylint: disable=W0707 18 | raise ImportError('API is not available - install "api" extra to enable') from missing 19 | -------------------------------------------------------------------------------- /src/python/txtai/api/authorization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authorization module 3 | """ 4 | 5 | import hashlib 6 | import os 7 | 8 | from fastapi import Header, HTTPException 9 | 10 | 11 | class Authorization: 12 | """ 13 | Basic token authorization. 
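    The expected token is read from the constructor or the TOKEN environment
    variable and is compared as a SHA-256 hash. For example, a hash for a raw
    token can be generated as follows (illustrative):

        python -c "import hashlib; print(hashlib.sha256('token'.encode()).hexdigest())"

    Clients then pass the raw token in the Authorization header, optionally
    with a "Bearer " prefix.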
14 | """ 15 | 16 | def __init__(self, token=None): 17 | """ 18 | Creates a new Authorization instance. 19 | 20 | Args: 21 | token: SHA-256 hash of token to check 22 | """ 23 | 24 | self.token = token if token else os.environ.get("TOKEN") 25 | 26 | def __call__(self, authorization: str = Header(default=None)): 27 | """ 28 | Validates authorization header is present and equal to current token. 29 | 30 | Args: 31 | authorization: authorization header 32 | """ 33 | 34 | if not authorization or self.token != self.digest(authorization): 35 | raise HTTPException(status_code=401, detail="Invalid Authorization Token") 36 | 37 | def digest(self, authorization): 38 | """ 39 | Computes a SHA-256 hash for input authorization token. 40 | 41 | Args: 42 | authorization: authorization header 43 | 44 | Returns: 45 | SHA-256 hash of authorization token 46 | """ 47 | 48 | # Replace Bearer prefix 49 | prefix = "Bearer " 50 | token = authorization[len(prefix) :] if authorization.startswith(prefix) else authorization 51 | 52 | # Compute SHA-256 hash 53 | return hashlib.sha256(token.encode("utf-8")).hexdigest() 54 | -------------------------------------------------------------------------------- /src/python/txtai/api/extension.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extension module 3 | """ 4 | 5 | 6 | class Extension: 7 | """ 8 | Defines an API extension. API extensions can expose custom pipelines or other custom logic. 9 | """ 10 | 11 | def __call__(self, app): 12 | """ 13 | Hook to register custom routing logic and/or modify the FastAPI instance. 14 | 15 | Args: 16 | app: FastAPI application instance 17 | """ 18 | 19 | return 20 | -------------------------------------------------------------------------------- /src/python/txtai/api/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | API factory module 3 | """ 4 | 5 | from ..util import Resolver 6 | 7 | 8 | class APIFactory: 9 | """ 10 | API factory. Creates new API instances. 11 | """ 12 | 13 | @staticmethod 14 | def get(api): 15 | """ 16 | Gets a new instance of api class. 17 | 18 | Args: 19 | api: API instance class 20 | 21 | Returns: 22 | API 23 | """ 24 | 25 | return Resolver()(api) 26 | 27 | @staticmethod 28 | def create(config, api): 29 | """ 30 | Creates a new API instance. 31 | 32 | Args: 33 | config: API configuration 34 | api: API instance class 35 | 36 | Returns: 37 | API instance 38 | """ 39 | 40 | return APIFactory.get(api)(config) 41 | -------------------------------------------------------------------------------- /src/python/txtai/api/responses/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Responses imports 3 | """ 4 | 5 | from .factory import ResponseFactory 6 | from .json import JSONEncoder, JSONResponse 7 | from .messagepack import MessagePackResponse 8 | -------------------------------------------------------------------------------- /src/python/txtai/api/responses/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from .json import JSONResponse 6 | from .messagepack import MessagePackResponse 7 | 8 | 9 | class ResponseFactory: 10 | """ 11 | Methods to create Response classes. 12 | """ 13 | 14 | @staticmethod 15 | def create(request): 16 | """ 17 | Gets a response class for request using the Accept header. 
18 | 19 | Args: 20 | request: request 21 | 22 | Returns: 23 | response class 24 | """ 25 | 26 | # Get Accept header 27 | accept = request.headers.get("Accept") 28 | 29 | # Get response class 30 | return MessagePackResponse if accept == MessagePackResponse.media_type else JSONResponse 31 | -------------------------------------------------------------------------------- /src/python/txtai/api/responses/json.py: -------------------------------------------------------------------------------- 1 | """ 2 | JSON module 3 | """ 4 | 5 | import base64 6 | import json 7 | 8 | from io import BytesIO 9 | from typing import Any 10 | 11 | import fastapi.responses 12 | 13 | from PIL.Image import Image 14 | 15 | 16 | class JSONEncoder(json.JSONEncoder): 17 | """ 18 | Extended JSONEncoder that serializes images and byte streams as base64 encoded text. 19 | """ 20 | 21 | def default(self, o): 22 | # Convert Image to BytesIO 23 | if isinstance(o, Image): 24 | buffered = BytesIO() 25 | o.save(buffered, format=o.format, quality="keep") 26 | o = buffered 27 | 28 | # Unpack bytes from BytesIO 29 | if isinstance(o, BytesIO): 30 | o = o.getvalue() 31 | 32 | # Base64 encode bytes instances 33 | if isinstance(o, bytes): 34 | return base64.b64encode(o).decode("utf-8") 35 | 36 | # Default handler 37 | return super().default(o) 38 | 39 | 40 | class JSONResponse(fastapi.responses.JSONResponse): 41 | """ 42 | Extended JSONResponse that serializes images and byte streams as base64 encoded text. 43 | """ 44 | 45 | def render(self, content: Any) -> bytes: 46 | """ 47 | Renders content to the response. 48 | 49 | Args: 50 | content: input content 51 | 52 | Returns: 53 | rendered content as bytes 54 | """ 55 | 56 | return json.dumps(content, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":"), cls=JSONEncoder).encode("utf-8") 57 | -------------------------------------------------------------------------------- /src/python/txtai/api/responses/messagepack.py: -------------------------------------------------------------------------------- 1 | """ 2 | MessagePack module 3 | """ 4 | 5 | from io import BytesIO 6 | from typing import Any 7 | 8 | import msgpack 9 | 10 | from fastapi import Response 11 | from PIL.Image import Image 12 | 13 | 14 | class MessagePackResponse(Response): 15 | """ 16 | Encodes responses with MessagePack. 17 | """ 18 | 19 | media_type = "application/msgpack" 20 | 21 | def render(self, content: Any) -> bytes: 22 | """ 23 | Renders content to the response. 24 | 25 | Args: 26 | content: input content 27 | 28 | Returns: 29 | rendered content as bytes 30 | """ 31 | 32 | return msgpack.packb(content, default=MessagePackEncoder()) 33 | 34 | 35 | class MessagePackEncoder: 36 | """ 37 | Extended MessagePack encoder that converts images to bytes. 
38 | """ 39 | 40 | def __call__(self, o): 41 | # Convert Image to bytes 42 | if isinstance(o, Image): 43 | buffered = BytesIO() 44 | o.save(buffered, format=o.format, quality="keep") 45 | o = buffered 46 | 47 | # Get bytes from BytesIO 48 | if isinstance(o, BytesIO): 49 | o = o.getvalue() 50 | 51 | return o 52 | -------------------------------------------------------------------------------- /src/python/txtai/api/route.py: -------------------------------------------------------------------------------- 1 | """ 2 | Route module 3 | """ 4 | 5 | from fastapi.routing import APIRoute, get_request_handler 6 | 7 | from .responses import ResponseFactory 8 | 9 | 10 | class EncodingAPIRoute(APIRoute): 11 | """ 12 | Extended APIRoute that encodes responses based on HTTP Accept header. 13 | """ 14 | 15 | def get_route_handler(self): 16 | """ 17 | Resolves a response class based on the HTTP Accept header. 18 | 19 | Returns: 20 | route handler function 21 | """ 22 | 23 | async def handler(request): 24 | route = get_request_handler( 25 | dependant=self.dependant, 26 | body_field=self.body_field, 27 | status_code=self.status_code, 28 | response_class=ResponseFactory.create(request), 29 | response_field=self.secure_cloned_response_field, 30 | response_model_include=self.response_model_include, 31 | response_model_exclude=self.response_model_exclude, 32 | response_model_by_alias=self.response_model_by_alias, 33 | response_model_exclude_unset=self.response_model_exclude_unset, 34 | response_model_exclude_defaults=self.response_model_exclude_defaults, 35 | response_model_exclude_none=self.response_model_exclude_none, 36 | dependency_overrides_provider=self.dependency_overrides_provider, 37 | ) 38 | 39 | return await route(request) 40 | 41 | return handler 42 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Router imports 3 | """ 4 | 5 | from . import agent 6 | from . import caption 7 | from . import embeddings 8 | from . import entity 9 | from . import extractor 10 | from . import labels 11 | from . import llm 12 | from . import objects 13 | from . import openai 14 | from . import rag 15 | from . import segmentation 16 | from . import similarity 17 | from . import summary 18 | from . import tabular 19 | from . import textractor 20 | from . import texttospeech 21 | from . import transcription 22 | from . import translation 23 | from . import workflow 24 | from . import upload 25 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for agent endpoints. 3 | """ 4 | 5 | from typing import Optional 6 | 7 | from fastapi import APIRouter, Body 8 | from fastapi.responses import StreamingResponse 9 | 10 | from .. import application 11 | from ..route import EncodingAPIRoute 12 | 13 | router = APIRouter(route_class=EncodingAPIRoute) 14 | 15 | 16 | @router.post("/agent") 17 | def agent(name: str = Body(...), text: str = Body(...), maxlength: Optional[int] = Body(default=None), stream: Optional[bool] = Body(default=None)): 18 | """ 19 | Executes a named agent for input text. 
20 | 21 | Args: 22 | name: agent name 23 | text: instructions to run 24 | maxlength: maximum sequence length 25 | stream: stream response if True, defaults to False 26 | 27 | Returns: 28 | response text 29 | """ 30 | 31 | # Build keyword arguments 32 | kwargs = {key: value for key, value in [("stream", stream), ("maxlength", maxlength)] if value} 33 | 34 | # Run agent 35 | result = application.get().agent(name, text, **kwargs) 36 | 37 | # Handle both standard and streaming responses 38 | return StreamingResponse(result) if stream else result 39 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/caption.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for caption endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/caption") 16 | def caption(file: str): 17 | """ 18 | Builds captions for images. 19 | 20 | Args: 21 | file: file to process 22 | 23 | Returns: 24 | list of captions 25 | """ 26 | 27 | return application.get().pipeline("caption", (file,)) 28 | 29 | 30 | @router.post("/batchcaption") 31 | def batchcaption(files: List[str] = Body(...)): 32 | """ 33 | Builds captions for images. 34 | 35 | Args: 36 | files: list of files to process 37 | 38 | Returns: 39 | list of captions 40 | """ 41 | 42 | return application.get().pipeline("caption", (files,)) 43 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/entity.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for entity endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/entity") 16 | def entity(text: str): 17 | """ 18 | Applies a token classifier to text. 19 | 20 | Args: 21 | text: input text 22 | 23 | Returns: 24 | list of (entity, entity type, score) per text element 25 | """ 26 | 27 | return application.get().pipeline("entity", (text,)) 28 | 29 | 30 | @router.post("/batchentity") 31 | def batchentity(texts: List[str] = Body(...)): 32 | """ 33 | Applies a token classifier to text. 34 | 35 | Args: 36 | texts: list of text 37 | 38 | Returns: 39 | list of (entity, entity type, score) per text element 40 | """ 41 | 42 | return application.get().pipeline("entity", (texts,)) 43 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for extractor endpoints. 3 | """ 4 | 5 | from typing import List, Optional 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.post("/extract") 16 | def extract(queue: List[dict] = Body(...), texts: Optional[List[str]] = Body(default=None)): 17 | """ 18 | Extracts answers to input questions. 
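    A hypothetical queue entry (field values are illustrative):

        [{"name": "year", "query": "what year", "question": "What year did this occur?", "snippet": False}]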
19 | 20 |     Args: 21 |         queue: list of {name: value, query: value, question: value, snippet: value} 22 |         texts: optional list of text 23 | 24 |     Returns: 25 |         list of {name: value, answer: value} 26 |     """ 27 | 28 |     return application.get().extract(queue, texts) 29 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for labels endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.post("/label") 16 | def label(text: str = Body(...), labels: List[str] = Body(...)): 17 |     """ 18 |     Applies a zero shot classifier to text using a list of labels. Returns a list of 19 |     {id: value, score: value} sorted by highest score, where id is the index in labels. 20 | 21 |     Args: 22 |         text: input text 23 |         labels: list of labels 24 | 25 |     Returns: 26 |         list of {id: value, score: value} per text element 27 |     """ 28 | 29 |     return application.get().label(text, labels) 30 | 31 | 32 | @router.post("/batchlabel") 33 | def batchlabel(texts: List[str] = Body(...), labels: List[str] = Body(...)): 34 |     """ 35 |     Applies a zero shot classifier to a list of text using a list of labels. Returns a list of 36 |     {id: value, score: value} sorted by highest score, where id is the index in labels per 37 |     text element. 38 | 39 |     Args: 40 |         texts: list of text 41 |         labels: list of labels 42 | 43 |     Returns: 44 |         list of {id: value, score: value} per text element 45 |     """ 46 | 47 |     return application.get().label(texts, labels) 48 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for llm endpoints. 3 | """ 4 | 5 | from typing import List, Optional 6 | 7 | from fastapi import APIRouter, Body 8 | from fastapi.responses import StreamingResponse 9 | 10 | from .. import application 11 | from ..route import EncodingAPIRoute 12 | 13 | router = APIRouter(route_class=EncodingAPIRoute) 14 | 15 | 16 | @router.get("/llm") 17 | def llm(text: str, maxlength: Optional[int] = None, stream: Optional[bool] = False): 18 |     """ 19 |     Runs an LLM pipeline for the input text. 20 | 21 |     Args: 22 |         text: input text 23 |         maxlength: optional response max length 24 |         stream: streams response if True 25 | 26 |     Returns: 27 |         response text 28 |     """ 29 | 30 |     # Build keyword arguments 31 |     kwargs = {key: value for key, value in [("stream", stream), ("maxlength", maxlength)] if value} 32 | 33 |     # Run pipeline 34 |     result = application.get().pipeline("llm", text, **kwargs) 35 | 36 |     # Handle both standard and streaming responses 37 |     return StreamingResponse(result) if stream else result 38 | 39 | 40 | @router.post("/batchllm") 41 | def batchllm(texts: List[str] = Body(...), maxlength: Optional[int] = Body(default=None), stream: Optional[bool] = Body(default=False)): 42 |     """ 43 |     Runs an LLM pipeline for the input texts.
44 | 45 | Args: 46 | texts: input texts 47 | maxlength: optional response max length 48 | stream: streams response if True 49 | 50 | Returns: 51 | response texts 52 | """ 53 | 54 | # Build keyword arguments 55 | kwargs = {key: value for key, value in [("stream", stream), ("maxlength", maxlength)] if value} 56 | 57 | # Run pipeline 58 | result = application.get().pipeline("llm", texts, **kwargs) 59 | 60 | # Handle both standard and streaming responses 61 | return StreamingResponse(result) if stream else result 62 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/objects.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for objects endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/objects") 16 | def objects(file: str): 17 | """ 18 | Applies object detection/image classification models to images. 19 | 20 | Args: 21 | file: file to process 22 | 23 | Returns: 24 | list of (label, score) elements 25 | """ 26 | 27 | return application.get().pipeline("objects", (file,)) 28 | 29 | 30 | @router.post("/batchobjects") 31 | def batchobjects(files: List[str] = Body(...)): 32 | """ 33 | Applies object detection/image classification models to images. 34 | 35 | Args: 36 | files: list of files to process 37 | 38 | Returns: 39 | list of (label, score) elements 40 | """ 41 | 42 | return application.get().pipeline("objects", (files,)) 43 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/rag.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for rag endpoints. 3 | """ 4 | 5 | from typing import List, Optional 6 | 7 | from fastapi import APIRouter, Body 8 | from fastapi.responses import StreamingResponse 9 | 10 | from .. import application 11 | from ..route import EncodingAPIRoute 12 | 13 | router = APIRouter(route_class=EncodingAPIRoute) 14 | 15 | 16 | @router.get("/rag") 17 | def rag(query: str, maxlength: Optional[int] = None, stream: Optional[bool] = False): 18 | """ 19 | Runs a RAG pipeline for the input query. 20 | 21 | Args: 22 | query: input RAG query 23 | maxlength: optional response max length 24 | stream: streams response if True 25 | 26 | Returns: 27 | answer 28 | """ 29 | 30 | # Build keyword arguments 31 | kwargs = {key: value for key, value in [("stream", stream), ("maxlength", maxlength)] if value} 32 | 33 | # Run pipeline 34 | result = application.get().pipeline("rag", query, **kwargs) 35 | 36 | # Handle both standard and streaming responses 37 | return StreamingResponse(result) if stream else result 38 | 39 | 40 | @router.post("/batchrag") 41 | def batchrag(queries: List[str] = Body(...), maxlength: Optional[int] = Body(default=None), stream: Optional[bool] = Body(default=False)): 42 | """ 43 | Runs a RAG pipeline for the input queries. 
44 | 45 |     Args: 46 |         queries: input RAG queries 47 |         maxlength: optional response max length 48 |         stream: streams response if True 49 | 50 |     Returns: 51 |         answers 52 |     """ 53 | 54 |     # Build keyword arguments 55 |     kwargs = {key: value for key, value in [("stream", stream), ("maxlength", maxlength)] if value} 56 | 57 |     # Run pipeline 58 |     result = application.get().pipeline("rag", queries, **kwargs) 59 | 60 |     # Handle both standard and streaming responses 61 |     return StreamingResponse(result) if stream else result 62 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/segmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for segmentation endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/segment") 16 | def segment(text: str): 17 |     """ 18 |     Segments text into semantic units. 19 | 20 |     Args: 21 |         text: input text 22 | 23 |     Returns: 24 |         segmented text 25 |     """ 26 | 27 |     return application.get().pipeline("segmentation", (text,)) 28 | 29 | 30 | @router.post("/batchsegment") 31 | def batchsegment(texts: List[str] = Body(...)): 32 |     """ 33 |     Segments text into semantic units. 34 | 35 |     Args: 36 |         texts: list of texts to segment 37 | 38 |     Returns: 39 |         list of segmented text 40 |     """ 41 | 42 |     return application.get().pipeline("segmentation", (texts,)) 43 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/similarity.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for similarity endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.post("/similarity") 16 | def similarity(query: str = Body(...), texts: List[str] = Body(...)): 17 |     """ 18 |     Computes the similarity between query and list of text. Returns a list of 19 |     {id: value, score: value} sorted by highest score, where id is the index 20 |     in texts. 21 | 22 |     Args: 23 |         query: query text 24 |         texts: list of text 25 | 26 |     Returns: 27 |         list of {id: value, score: value} 28 |     """ 29 | 30 |     return application.get().similarity(query, texts) 31 | 32 | 33 | @router.post("/batchsimilarity") 34 | def batchsimilarity(queries: List[str] = Body(...), texts: List[str] = Body(...)): 35 |     """ 36 |     Computes the similarity between list of queries and list of text. Returns a list 37 |     of {id: value, score: value} sorted by highest score per query, where id is the 38 |     index in texts. 39 | 40 |     Args: 41 |         queries: queries text 42 |         texts: list of text 43 | 44 |     Returns: 45 |         list of {id: value, score: value} per query 46 |     """ 47 | 48 |     return application.get().batchsimilarity(queries, texts) 49 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/summary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for summary endpoints. 3 | """ 4 | 5 | from typing import List, Optional 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/summary") 16 | def summary(text: str, minlength: Optional[int] = None, maxlength: Optional[int] = None): 17 |     """ 18 |     Runs a summarization model against a block of text. 19 | 20 |     Args: 21 |         text: text to summarize 22 |         minlength: minimum length for summary 23 |         maxlength: maximum length for summary 24 | 25 |     Returns: 26 |         summary text 27 |     """ 28 | 29 |     return application.get().pipeline("summary", (text, minlength, maxlength)) 30 | 31 | 32 | @router.post("/batchsummary") 33 | def batchsummary(texts: List[str] = Body(...), minlength: Optional[int] = Body(default=None), maxlength: Optional[int] = Body(default=None)): 34 |     """ 35 |     Runs a summarization model against a block of text. 36 | 37 |     Args: 38 |         texts: list of text to summarize 39 |         minlength: minimum length for summary 40 |         maxlength: maximum length for summary 41 | 42 |     Returns: 43 |         list of summary text 44 |     """ 45 | 46 |     return application.get().pipeline("summary", (texts, minlength, maxlength)) 47 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/tabular.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for tabular endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/tabular") 16 | def tabular(file: str): 17 |     """ 18 |     Splits tabular data into rows and columns. 19 | 20 |     Args: 21 |         file: file to process 22 | 23 |     Returns: 24 |         list of (id, text, tag) elements 25 |     """ 26 | 27 |     return application.get().pipeline("tabular", (file,)) 28 | 29 | 30 | @router.post("/batchtabular") 31 | def batchtabular(files: List[str] = Body(...)): 32 |     """ 33 |     Splits tabular data into rows and columns. 34 | 35 |     Args: 36 |         files: list of files to process 37 | 38 |     Returns: 39 |         list of (id, text, tag) elements 40 |     """ 41 | 42 |     return application.get().pipeline("tabular", (files,)) 43 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/textractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for textractor endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/textract") 16 | def textract(file: str): 17 |     """ 18 |     Extracts text from a file at path. 19 | 20 |     Args: 21 |         file: file to extract text 22 | 23 |     Returns: 24 |         extracted text 25 |     """ 26 | 27 |     return application.get().pipeline("textractor", (file,)) 28 | 29 | 30 | @router.post("/batchtextract") 31 | def batchtextract(files: List[str] = Body(...)): 32 |     """ 33 |     Extracts text from a file at path.
34 | 35 | Args: 36 | files: list of files to extract text 37 | 38 | Returns: 39 | list of extracted text 40 | """ 41 | 42 | return application.get().pipeline("textractor", (files,)) 43 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/texttospeech.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for TTS endpoints 3 | """ 4 | 5 | from typing import Optional 6 | 7 | from fastapi import APIRouter, Response 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/texttospeech") 16 | def texttospeech(text: str, speaker: Optional[str] = None, encoding: Optional[str] = "mp3"): 17 | """ 18 | Generates speech from text. 19 | 20 | Args: 21 | text: text 22 | speaker: speaker id, defaults to 1 23 | encoding: optional audio encoding format 24 | 25 | Returns: 26 | Audio data 27 | """ 28 | 29 | # Convert to audio 30 | audio = application.get().pipeline("texttospeech", text, speaker=speaker, encoding=encoding) 31 | 32 | # Write audio 33 | return Response(audio, headers={"Content-Disposition": f"attachment;filename=speech.{encoding.lower()}"}) 34 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/transcription.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for transcription endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/transcribe") 16 | def transcribe(file: str): 17 | """ 18 | Transcribes audio files to text. 19 | 20 | Args: 21 | file: file to transcribe 22 | 23 | Returns: 24 | transcribed text 25 | """ 26 | 27 | return application.get().pipeline("transcription", (file,)) 28 | 29 | 30 | @router.post("/batchtranscribe") 31 | def batchtranscribe(files: List[str] = Body(...)): 32 | """ 33 | Transcribes audio files to text. 34 | 35 | Args: 36 | files: list of files to transcribe 37 | 38 | Returns: 39 | list of transcribed text 40 | """ 41 | 42 | return application.get().pipeline("transcription", (files,)) 43 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/translation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for translation endpoints. 3 | """ 4 | 5 | from typing import List, Optional 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.get("/translate") 16 | def translate(text: str, target: Optional[str] = "en", source: Optional[str] = None): 17 | """ 18 | Translates text from source language into target language. 
19 | 20 |     Args: 21 |         text: text to translate 22 |         target: target language code, defaults to "en" 23 |         source: source language code, detects language if not provided 24 | 25 |     Returns: 26 |         translated text 27 |     """ 28 | 29 |     return application.get().pipeline("translation", (text, target, source)) 30 | 31 | 32 | @router.post("/batchtranslate") 33 | def batchtranslate(texts: List[str] = Body(...), target: Optional[str] = Body(default="en"), source: Optional[str] = Body(default=None)): 34 |     """ 35 |     Translates text from source language into target language. 36 | 37 |     Args: 38 |         texts: list of text to translate 39 |         target: target language code, defaults to "en" 40 |         source: source language code, detects language if not provided 41 | 42 |     Returns: 43 |         list of translated text 44 |     """ 45 | 46 |     return application.get().pipeline("translation", (texts, target, source)) 47 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/upload.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for upload endpoints. 3 | """ 4 | 5 | import shutil 6 | import tempfile 7 | 8 | from typing import List 9 | 10 | from fastapi import APIRouter, File, Form, UploadFile 11 | 12 | from ..route import EncodingAPIRoute 13 | 14 | 15 | router = APIRouter(route_class=EncodingAPIRoute) 16 | 17 | 18 | @router.post("/upload") 19 | def upload(files: List[UploadFile] = File(), suffix: str = Form(default=None)): 20 |     """ 21 |     Uploads files for local server processing. 22 | 23 |     Args: 24 |         files: list of files to upload 25 |         suffix: optional file suffix for stored files 26 | 27 |     Returns: 28 |         list of server paths 29 |     """ 30 | 31 |     paths = [] 32 |     for f in files: 33 |         with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=suffix) as tmp: 34 |             shutil.copyfileobj(f.file, tmp) 35 |             paths.append(tmp.name) 36 | 37 |     return paths 38 | -------------------------------------------------------------------------------- /src/python/txtai/api/routers/workflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines API paths for workflow endpoints. 3 | """ 4 | 5 | from typing import List 6 | 7 | from fastapi import APIRouter, Body 8 | 9 | from .. import application 10 | from ..route import EncodingAPIRoute 11 | 12 | router = APIRouter(route_class=EncodingAPIRoute) 13 | 14 | 15 | @router.post("/workflow") 16 | def workflow(name: str = Body(...), elements: List = Body(...)): 17 |     """ 18 |     Executes a named workflow using elements as input.
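    A hypothetical request (workflow name, URL and elements are illustrative):

        curl -X POST "http://localhost:8000/workflow" \
             -H "Content-Type: application/json" \
             -d '{"name": "sumfrench", "elements": ["https://github.com/neuml/txtai"]}'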
19 | 20 | Args: 21 | name: workflow name 22 | elements: list of elements to run through workflow 23 | 24 | Returns: 25 | list of processed elements 26 | """ 27 | 28 | return application.get().workflow(name, elements) 29 | -------------------------------------------------------------------------------- /src/python/txtai/app/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | App imports 3 | """ 4 | 5 | from .base import Application, ReadOnlyError 6 | -------------------------------------------------------------------------------- /src/python/txtai/archive/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Archive imports 3 | """ 4 | 5 | from .base import Archive 6 | from .compress import Compress 7 | from .factory import ArchiveFactory 8 | from .tar import Tar 9 | from .zip import Zip 10 | -------------------------------------------------------------------------------- /src/python/txtai/archive/compress.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compress module 3 | """ 4 | 5 | import os 6 | 7 | 8 | class Compress: 9 | """ 10 | Base class for Compress instances. 11 | """ 12 | 13 | def pack(self, path, output): 14 | """ 15 | Compresses files in directory path to file output. 16 | 17 | Args: 18 | path: input directory path 19 | output: output file 20 | """ 21 | 22 | raise NotImplementedError 23 | 24 | def unpack(self, path, output): 25 | """ 26 | Extracts all files in path to output. 27 | 28 | Args: 29 | path: input file path 30 | output: output directory 31 | """ 32 | 33 | raise NotImplementedError 34 | 35 | def validate(self, directory, path): 36 | """ 37 | Validates path is under directory. 38 | 39 | Args: 40 | directory: base directory 41 | path: path to validate 42 | 43 | Returns: 44 | True if path is under directory, False otherwise 45 | """ 46 | 47 | directory = os.path.abspath(directory) 48 | path = os.path.abspath(path) 49 | prefix = os.path.commonprefix([directory, path]) 50 | 51 | return prefix == directory 52 | -------------------------------------------------------------------------------- /src/python/txtai/archive/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from .base import Archive 6 | 7 | 8 | class ArchiveFactory: 9 | """ 10 | Methods to create Archive instances. 11 | """ 12 | 13 | @staticmethod 14 | def create(directory=None): 15 | """ 16 | Create a new Archive instance. 
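Example (a minimal sketch; the working directory is illustrative):

    archive = ArchiveFactory.create("/tmp/work")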
17 | 18 | Args: 19 | directory: optional default working directory, otherwise uses a temporary directory 20 | 21 | Returns: 22 | Archive 23 | """ 24 | 25 | return Archive(directory) 26 | -------------------------------------------------------------------------------- /src/python/txtai/archive/tar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tar module 3 | """ 4 | 5 | import os 6 | import tarfile 7 | 8 | from .compress import Compress 9 | 10 | 11 | class Tar(Compress): 12 | """ 13 | Tar compression 14 | """ 15 | 16 | def pack(self, path, output): 17 | # Infer compression type 18 | compression = self.compression(output) 19 | 20 | with tarfile.open(output, f"w:{compression}" if compression else "w") as tar: 21 | tar.add(path, arcname=".") 22 | 23 | def unpack(self, path, output): 24 | # Infer compression type 25 | compression = self.compression(path) 26 | 27 | with tarfile.open(path, f"r:{compression}" if compression else "r") as tar: 28 | # Validate paths 29 | for member in tar.getmembers(): 30 | fullpath = os.path.join(path, member.name) 31 | if not self.validate(path, fullpath): 32 | raise IOError(f"Invalid tar entry: {member.name}") 33 | 34 | tar.extractall(output) 35 | 36 | def compression(self, path): 37 | """ 38 | Gets compression type for path. 39 | 40 | Args: 41 | path: path to file 42 | 43 | Returns: 44 | compression type 45 | """ 46 | 47 | # Infer compression type from last path component. Limit to supported types. 48 | compression = path.lower().split(".")[-1] 49 | return compression if compression in ("bz2", "gz", "xz") else None 50 | -------------------------------------------------------------------------------- /src/python/txtai/archive/zip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Zip module 3 | """ 4 | 5 | import os 6 | 7 | from zipfile import ZipFile, ZIP_DEFLATED 8 | 9 | from .compress import Compress 10 | 11 | 12 | class Zip(Compress): 13 | """ 14 | Zip compression 15 | """ 16 | 17 | def pack(self, path, output): 18 | with ZipFile(output, "w", ZIP_DEFLATED) as zfile: 19 | for root, _, files in sorted(os.walk(path)): 20 | for f in files: 21 | # Generate archive name with relative path, if necessary 22 | name = os.path.join(os.path.relpath(root, path), f) 23 | 24 | # Write file to zip 25 | zfile.write(os.path.join(root, f), arcname=name) 26 | 27 | def unpack(self, path, output): 28 | with ZipFile(path, "r") as zfile: 29 | # Validate path if directory specified 30 | for fullpath in zfile.namelist(): 31 | fullpath = os.path.join(path, fullpath) 32 | if os.path.dirname(fullpath) and not self.validate(path, fullpath): 33 | raise IOError(f"Invalid zip entry: {fullpath}") 34 | 35 | zfile.extractall(output) 36 | -------------------------------------------------------------------------------- /src/python/txtai/cloud/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cloud imports 3 | """ 4 | 5 | from .base import Cloud 6 | from .factory import CloudFactory 7 | from .hub import HuggingFaceHub 8 | from .storage import ObjectStorage 9 | -------------------------------------------------------------------------------- /src/python/txtai/cloud/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from ..util import Resolver 6 | 7 | from .hub import HuggingFaceHub 8 | from .storage import ObjectStorage, LIBCLOUD 9 | 10 | 11 | class CloudFactory: 12 | 
""" 13 | Methods to create Cloud instances. 14 | """ 15 | 16 | @staticmethod 17 | def create(config): 18 | """ 19 | Creates a Cloud instance. 20 | 21 | Args: 22 | config: cloud configuration 23 | 24 | Returns: 25 | Cloud 26 | """ 27 | 28 | # Cloud instance 29 | cloud = None 30 | 31 | provider = config.get("provider", "") 32 | 33 | # Hugging Face Hub 34 | if provider.lower() == "huggingface-hub": 35 | cloud = HuggingFaceHub(config) 36 | 37 | # Cloud object storage 38 | elif ObjectStorage.isprovider(provider): 39 | cloud = ObjectStorage(config) 40 | 41 | # External provider 42 | elif provider: 43 | cloud = CloudFactory.resolve(provider, config) 44 | 45 | return cloud 46 | 47 | @staticmethod 48 | def resolve(backend, config): 49 | """ 50 | Attempt to resolve a custom cloud backend. 51 | 52 | Args: 53 | backend: backend class 54 | config: configuration parameters 55 | 56 | Returns: 57 | Cloud 58 | """ 59 | 60 | try: 61 | return Resolver()(backend)(config) 62 | 63 | except Exception as e: 64 | # Failure message 65 | message = f'Unable to resolve cloud backend: "{backend}".' 66 | 67 | # Append message if LIBCLOUD is not installed 68 | message += ' Cloud storage is not available - install "cloud" extra to enable' if not LIBCLOUD else "" 69 | 70 | raise ImportError(message) from e 71 | -------------------------------------------------------------------------------- /src/python/txtai/console/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Console imports 3 | """ 4 | 5 | from .base import Console 6 | -------------------------------------------------------------------------------- /src/python/txtai/console/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main module. 3 | """ 4 | 5 | import sys 6 | 7 | from .base import Console 8 | 9 | 10 | def main(path=None): 11 | """ 12 | Console execution loop. 13 | 14 | Args: 15 | path: model path 16 | """ 17 | 18 | Console(path).cmdloop() 19 | 20 | 21 | if __name__ == "__main__": 22 | main(sys.argv[1] if len(sys.argv) > 1 else None) 23 | -------------------------------------------------------------------------------- /src/python/txtai/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data imports 3 | """ 4 | 5 | from .base import Data 6 | from .labels import Labels 7 | from .questions import Questions 8 | from .sequences import Sequences 9 | from .texts import Texts 10 | from .tokens import Tokens 11 | -------------------------------------------------------------------------------- /src/python/txtai/data/labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Labels module 3 | """ 4 | 5 | from .base import Data 6 | 7 | 8 | class Labels(Data): 9 | """ 10 | Tokenizes text-classification datasets as input for training text-classification models. 11 | """ 12 | 13 | def __init__(self, tokenizer, columns, maxlength): 14 | """ 15 | Creates a new instance for tokenizing Labels training data. 
16 | 17 | Args: 18 | tokenizer: model tokenizer 19 | columns: tuple of columns to use for text/label 20 | maxlength: maximum sequence length 21 | """ 22 | 23 | super().__init__(tokenizer, columns, maxlength) 24 | 25 | # Standardize columns 26 | if not self.columns: 27 | self.columns = ("text", None, "label") 28 | elif len(columns) < 3: 29 | self.columns = (self.columns[0], None, self.columns[-1]) 30 | 31 | def process(self, data): 32 | # Column keys 33 | text1, text2, label = self.columns 34 | 35 | # Tokenizer inputs can be single string or string pair, depending on task 36 | text = (data[text1], data[text2]) if text2 else (data[text1],) 37 | 38 | # Tokenize text and add label 39 | inputs = self.tokenizer(*text, max_length=self.maxlength, padding=True, truncation=True) 40 | inputs[label] = data[label] 41 | 42 | return inputs 43 | -------------------------------------------------------------------------------- /src/python/txtai/data/sequences.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sequences module 3 | """ 4 | 5 | from .base import Data 6 | 7 | 8 | class Sequences(Data): 9 | """ 10 | Tokenizes sequence-sequence datasets as input for training sequence-sequence models 11 | """ 12 | 13 | def __init__(self, tokenizer, columns, maxlength, prefix): 14 | """ 15 | Creates a new instance for tokenizing Sequences training data. 16 | 17 | Args: 18 | tokenizer: model tokenizer 19 | columns: tuple of columns to use for text/label 20 | maxlength: maximum sequence length 21 | prefix: source prefix 22 | """ 23 | 24 | super().__init__(tokenizer, columns, maxlength) 25 | 26 | # Standardize columns 27 | if not self.columns: 28 | self.columns = ("source", "target") 29 | 30 | # Save source prefix 31 | self.prefix = prefix 32 | 33 | def process(self, data): 34 | # Column keys 35 | source, target = self.columns 36 | 37 | # Tokenize source 38 | source = [self.prefix + x if self.prefix else x for x in data[source]] 39 | inputs = self.tokenizer(source, max_length=self.maxlength, padding=False, truncation=True) 40 | 41 | # Tokenize target 42 | with self.tokenizer.as_target_tokenizer(): 43 | targets = self.tokenizer(data[target], max_length=self.maxlength, padding=False, truncation=True) 44 | 45 | # Combine inputs 46 | inputs["labels"] = targets["input_ids"] 47 | 48 | return inputs 49 | -------------------------------------------------------------------------------- /src/python/txtai/data/tokens.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tokens module 3 | """ 4 | 5 | import torch 6 | 7 | 8 | class Tokens(torch.utils.data.Dataset): 9 | """ 10 | Default dataset used to hold tokenized data. 
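A minimal sketch of the column-to-row mapping (values are illustrative):

    dataset = Tokens({"input_ids": [[1, 2], [3, 4]], "label": [0, 1]})
    dataset[0]  # {"input_ids": [1, 2], "label": 0}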
11 | """ 12 | 13 | def __init__(self, columns): 14 | self.data = [] 15 | 16 | # Map column-oriented data to rows 17 | for column in columns: 18 | for x, value in enumerate(columns[column]): 19 | if len(self.data) <= x: 20 | self.data.append({}) 21 | 22 | self.data[x][column] = value 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, index): 28 | return self.data[index] 29 | -------------------------------------------------------------------------------- /src/python/txtai/database/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Database imports 3 | """ 4 | 5 | from .base import Database 6 | from .client import Client 7 | from .duckdb import DuckDB 8 | from .embedded import Embedded 9 | from .encoder import * 10 | from .factory import DatabaseFactory 11 | from .rdbms import RDBMS 12 | from .schema import * 13 | from .sqlite import SQLite 14 | from .sql import * 15 | -------------------------------------------------------------------------------- /src/python/txtai/database/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoder imports 3 | """ 4 | 5 | from .base import Encoder 6 | from .factory import EncoderFactory 7 | from .image import ImageEncoder 8 | from .serialize import SerializeEncoder 9 | -------------------------------------------------------------------------------- /src/python/txtai/database/encoder/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoder module 3 | """ 4 | 5 | from io import BytesIO 6 | 7 | 8 | class Encoder: 9 | """ 10 | Encodes and decodes object content. The base encoder works only with byte arrays. It can be extended to encode different datatypes. 11 | """ 12 | 13 | def encode(self, obj): 14 | """ 15 | Encodes an object to a byte array using the encoder. 16 | 17 | Args: 18 | obj: object to encode 19 | 20 | Returns: 21 | encoded object as a byte array 22 | """ 23 | 24 | return obj 25 | 26 | def decode(self, data): 27 | """ 28 | Decodes input byte array into an object using this encoder. 29 | 30 | Args: 31 | data: encoded data 32 | 33 | Returns: 34 | decoded object 35 | """ 36 | 37 | return BytesIO(data) if data else None 38 | -------------------------------------------------------------------------------- /src/python/txtai/database/encoder/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoder factory module 3 | """ 4 | 5 | from ...util import Resolver 6 | 7 | from .base import Encoder 8 | from .serialize import SerializeEncoder 9 | 10 | 11 | class EncoderFactory: 12 | """ 13 | Encoder factory. Creates new Encoder instances. 14 | """ 15 | 16 | @staticmethod 17 | def get(encoder): 18 | """ 19 | Gets a new instance of encoder class. 20 | 21 | Args: 22 | encoder: Encoder instance class 23 | 24 | Returns: 25 | Encoder class 26 | """ 27 | 28 | # Local task if no package 29 | if "." not in encoder: 30 | # Get parent package 31 | encoder = ".".join(__name__.split(".")[:-1]) + "." + encoder.capitalize() + "Encoder" 32 | 33 | return Resolver()(encoder) 34 | 35 | @staticmethod 36 | def create(encoder): 37 | """ 38 | Creates a new Encoder instance. 
39 | 40 | Args: 41 | encoder: Encoder instance class 42 | 43 | Returns: 44 | Encoder 45 | """ 46 | 47 | # Return default encoder 48 | if encoder is True: 49 | return Encoder() 50 | 51 | # Supported serialization methods 52 | if encoder in ["messagepack", "pickle"]: 53 | return SerializeEncoder(encoder) 54 | 55 | # Get Encoder instance 56 | return EncoderFactory.get(encoder)() 57 | -------------------------------------------------------------------------------- /src/python/txtai/database/encoder/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | ImageEncoder module 3 | """ 4 | 5 | from io import BytesIO 6 | 7 | # Conditional import 8 | try: 9 | from PIL import Image 10 | 11 | PIL = True 12 | except ImportError: 13 | PIL = False 14 | 15 | from .base import Encoder 16 | 17 | 18 | class ImageEncoder(Encoder): 19 | """ 20 | Encodes and decodes Image objects as compressed binary content, using the original image's algorithm. 21 | """ 22 | 23 | def __init__(self): 24 | """ 25 | Creates a new ImageEncoder. 26 | """ 27 | 28 | if not PIL: 29 | raise ImportError('ImageEncoder is not available - install "database" extra to enable') 30 | 31 | def encode(self, obj): 32 | # Create byte stream 33 | output = BytesIO() 34 | 35 | # Write image to byte stream 36 | obj.save(output, format=obj.format, quality="keep") 37 | 38 | # Return byte array 39 | return output.getvalue() 40 | 41 | def decode(self, data): 42 | # Return a PIL image 43 | return Image.open(BytesIO(data)) if data else None 44 | -------------------------------------------------------------------------------- /src/python/txtai/database/encoder/serialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | SerializeEncoder module 3 | """ 4 | 5 | from ...serialize import SerializeFactory 6 | 7 | from .base import Encoder 8 | 9 | 10 | class SerializeEncoder(Encoder): 11 | """ 12 | Encodes and decodes objects using the internal serialize package. 13 | """ 14 | 15 | def __init__(self, method): 16 | # Parent constructor 17 | super().__init__() 18 | 19 | # Pickle serialization 20 | self.serializer = SerializeFactory.create(method) 21 | 22 | def encode(self, obj): 23 | # Pickle object 24 | return self.serializer.savebytes(obj) 25 | 26 | def decode(self, data): 27 | # Unpickle to object 28 | return self.serializer.loadbytes(data) 29 | -------------------------------------------------------------------------------- /src/python/txtai/database/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from urllib.parse import urlparse 6 | 7 | from ..util import Resolver 8 | 9 | from .client import Client 10 | from .duckdb import DuckDB 11 | from .sqlite import SQLite 12 | 13 | 14 | class DatabaseFactory: 15 | """ 16 | Methods to create document databases. 17 | """ 18 | 19 | @staticmethod 20 | def create(config): 21 | """ 22 | Create a Database. 
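Example (a minimal sketch; "sqlite" is one of the content values handled below):

    database = DatabaseFactory.create({"content": "sqlite"})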
23 | 24 | Args: 25 | config: database configuration parameters 26 | 27 | Returns: 28 | Database 29 | """ 30 | 31 | # Database instance 32 | database = None 33 | 34 | # Enables document database 35 | content = config.get("content") 36 | 37 | # Standardize content name 38 | if content is True: 39 | content = "sqlite" 40 | 41 | # Create document database instance 42 | if content == "duckdb": 43 | database = DuckDB(config) 44 | elif content == "sqlite": 45 | database = SQLite(config) 46 | elif content: 47 | # Check if content is a URL 48 | url = urlparse(content) 49 | if content == "client" or url.scheme: 50 | # Connect to database server URL 51 | database = Client(config) 52 | else: 53 | # Resolve custom database if content is not a URL 54 | database = DatabaseFactory.resolve(content, config) 55 | 56 | # Store config back 57 | config["content"] = content 58 | 59 | return database 60 | 61 | @staticmethod 62 | def resolve(backend, config): 63 | """ 64 | Attempt to resolve a custom backend. 65 | 66 | Args: 67 | backend: backend class 68 | config: index configuration parameters 69 | 70 | Returns: 71 | Database 72 | """ 73 | 74 | try: 75 | return Resolver()(backend)(config) 76 | except Exception as e: 77 | raise ImportError(f"Unable to resolve database backend: '{backend}'") from e 78 | -------------------------------------------------------------------------------- /src/python/txtai/database/schema/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Schema imports 3 | """ 4 | 5 | from .orm import * 6 | from .statement import Statement 7 | -------------------------------------------------------------------------------- /src/python/txtai/database/sql/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | SQL imports 3 | """ 4 | 5 | from .aggregate import Aggregate 6 | from .base import SQL, SQLError 7 | from .expression import Expression 8 | from .token import Token 9 | -------------------------------------------------------------------------------- /src/python/txtai/database/sqlite.py: -------------------------------------------------------------------------------- 1 | """ 2 | SQLite module 3 | """ 4 | 5 | import os 6 | import sqlite3 7 | 8 | from .embedded import Embedded 9 | 10 | 11 | class SQLite(Embedded): 12 | """ 13 | Database instance backed by SQLite. 14 | """ 15 | 16 | def connect(self, path=""): 17 | # Create connection 18 | connection = sqlite3.connect(path, check_same_thread=False) 19 | 20 | # Enable WAL mode, if necessary 21 | if self.setting("wal"): 22 | connection.execute("PRAGMA journal_mode=WAL") 23 | 24 | return connection 25 | 26 | def getcursor(self): 27 | return self.connection.cursor() 28 | 29 | def rows(self): 30 | return self.cursor 31 | 32 | def addfunctions(self): 33 | if self.connection and self.functions: 34 | # Enable callback tracebacks to show user-defined function errors 35 | sqlite3.enable_callback_tracebacks(True) 36 | 37 | for name, argcount, fn in self.functions: 38 | self.connection.create_function(name, argcount, fn) 39 | 40 | def copy(self, path): 41 | # Delete existing file, if necessary 42 | if os.path.exists(path): 43 | os.remove(path) 44 | 45 | # Create database. Thread locking must be handled externally. 
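# Note: connect() applies this instance's settings, such as WAL journal mode, to the new connection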
46 | connection = self.connect(path) 47 | 48 | if self.connection.in_transaction: 49 | # The backup call will hang if there are uncommitted changes, need to copy over 50 | # with iterdump (which is much slower) 51 | for sql in self.connection.iterdump(): 52 | connection.execute(sql) 53 | else: 54 | # Database is up to date, can do a more efficient copy with SQLite C API 55 | self.connection.backup(connection) 56 | 57 | return connection 58 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Embeddings imports 3 | """ 4 | 5 | from .base import Embeddings 6 | from .index import * 7 | from .search import * 8 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/index/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Index imports 3 | """ 4 | 5 | from .action import Action 6 | from .autoid import AutoId 7 | from .configuration import Configuration 8 | from .documents import Documents 9 | from .functions import Functions 10 | from .indexes import Indexes 11 | from .indexids import IndexIds 12 | from .reducer import Reducer 13 | from .stream import Stream 14 | from .transform import Transform 15 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/index/action.py: -------------------------------------------------------------------------------- 1 | """ 2 | Action module 3 | """ 4 | 5 | from enum import Enum 6 | 7 | 8 | class Action(Enum): 9 | """ 10 | Index action types 11 | """ 12 | 13 | INDEX = 1 14 | UPSERT = 2 15 | REINDEX = 3 16 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/index/indexids.py: -------------------------------------------------------------------------------- 1 | """ 2 | IndexIds module 3 | """ 4 | 5 | from ...serialize import Serializer 6 | 7 | 8 | class IndexIds: 9 | """ 10 | Stores index ids when content is disabled. 11 | """ 12 | 13 | def __init__(self, embeddings, ids=None): 14 | """ 15 | Creates an IndexIds instance. 16 | 17 | Args: 18 | embeddings: embeddings instance 19 | ids: ids to store 20 | """ 21 | 22 | self.config = embeddings.config 23 | self.ids = ids 24 | 25 | def __iter__(self): 26 | yield from self.ids 27 | 28 | def __getitem__(self, index): 29 | return self.ids[index] 30 | 31 | def __setitem__(self, index, value): 32 | self.ids[index] = value 33 | 34 | def __add__(self, ids): 35 | return self.ids + ids 36 | 37 | def load(self, path): 38 | """ 39 | Loads IndexIds from path. 40 | 41 | Args: 42 | path: path to load 43 | """ 44 | 45 | if "ids" in self.config: 46 | # Legacy ids format 47 | self.ids = self.config.pop("ids") 48 | else: 49 | # Standard ids format 50 | self.ids = Serializer.load(path) 51 | 52 | def save(self, path): 53 | """ 54 | Saves IndexIds to path. 
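Example (a hedged sketch; the embeddings instance and path are illustrative):

    ids = IndexIds(embeddings, ["doc1", "doc2"])
    ids.save("/tmp/ids")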
55 | 56 | Args: 57 | path: path to save 58 | """ 59 | 60 | Serializer.save(self.ids, path) 61 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/search/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Search imports 3 | """ 4 | 5 | from .base import Search 6 | from .errors import * 7 | from .explain import Explain 8 | from .ids import Ids 9 | from .query import Query 10 | from .scan import Scan 11 | from .terms import Terms 12 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/search/errors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Errors module 3 | """ 4 | 5 | 6 | class IndexNotFoundError(Exception): 7 | """ 8 | Raised when an embeddings query fails to locate an index 9 | """ 10 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/search/ids.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ids module 3 | """ 4 | 5 | 6 | class Ids: 7 | """ 8 | Resolves internal ids for lists of ids. 9 | """ 10 | 11 | def __init__(self, embeddings): 12 | """ 13 | Create a new ids action. 14 | 15 | Args: 16 | embeddings: embeddings instance 17 | """ 18 | 19 | self.database = embeddings.database 20 | self.ids = embeddings.ids 21 | 22 | def __call__(self, ids): 23 | """ 24 | Resolve internal ids. 25 | 26 | Args: 27 | ids: ids 28 | 29 | Returns: 30 | internal ids 31 | """ 32 | 33 | # Resolve ids using database if available, otherwise fallback to embeddings ids 34 | results = self.database.ids(ids) if self.database else self.scan(ids) 35 | 36 | # Create dict of id: [iids] given there is a one to many relationship 37 | ids = {} 38 | for iid, uid in results: 39 | if uid not in ids: 40 | ids[uid] = [] 41 | ids[uid].append(iid) 42 | 43 | return ids 44 | 45 | def scan(self, ids): 46 | """ 47 | Scans embeddings ids array for matches when content is disabled. 48 | 49 | Args: 50 | ids: search ids 51 | 52 | Returns: 53 | internal ids 54 | """ 55 | 56 | # Find existing ids 57 | indices = [] 58 | for uid in ids: 59 | indices.extend([(index, value) for index, value in enumerate(self.ids) if uid == value]) 60 | 61 | return indices 62 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/search/query.py: -------------------------------------------------------------------------------- 1 | """ 2 | Query module 3 | """ 4 | 5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration 6 | 7 | 8 | class Query: 9 | """ 10 | Query translation model. 11 | """ 12 | 13 | def __init__(self, path, prefix=None, maxlength=512): 14 | """ 15 | Creates a query translation model. 16 | 17 | Args: 18 | path: path to query model 19 | prefix: text prefix 20 | maxlength: max sequence length to generate 21 | """ 22 | 23 | self.tokenizer = AutoTokenizer.from_pretrained(path) 24 | self.model = AutoModelForSeq2SeqLM.from_pretrained(path) 25 | 26 | # Default prefix if not provided for T5 models 27 | if not prefix and isinstance(self.model, T5ForConditionalGeneration): 28 | prefix = "translate English to SQL: " 29 | 30 | self.prefix = prefix 31 | self.maxlength = maxlength 32 | 33 | def __call__(self, query): 34 | """ 35 | Runs query translation model. 
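Example (a hedged sketch; the model path is illustrative):

    query = Query("NeuML/t5-small-txtsql")
    query("find documents about machine learning")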
36 | 37 | Args: 38 | query: input query 39 | 40 | Returns: 41 | transformed query 42 | """ 43 | 44 | # Add prefix, if necessary 45 | if self.prefix: 46 | query = f"{self.prefix}{query}" 47 | 48 | # Tokenize and generate text using model 49 | features = self.tokenizer([query], return_tensors="pt") 50 | output = self.model.generate(input_ids=features["input_ids"], attention_mask=features["attention_mask"], max_length=self.maxlength) 51 | 52 | # Decode tokens to text 53 | result = self.tokenizer.decode(output[0], skip_special_tokens=True) 54 | 55 | # Clean and return generated text 56 | return self.clean(result) 57 | 58 | def clean(self, text): 59 | """ 60 | Applies a series of rules to clean generated text. 61 | 62 | Args: 63 | text: input text 64 | 65 | Returns: 66 | clean text 67 | """ 68 | 69 | return text.replace("$=", "<=") 70 | -------------------------------------------------------------------------------- /src/python/txtai/embeddings/search/terms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Terms module 3 | """ 4 | 5 | 6 | class Terms: 7 | """ 8 | Reduces a query statement down to keyword terms. This method extracts the query text from similar clauses if it's a SQL statement. 9 | Otherwise, the original query is returned. 10 | """ 11 | 12 | def __init__(self, embeddings): 13 | """ 14 | Create a new terms action. 15 | 16 | Args: 17 | embeddings: embeddings instance 18 | """ 19 | 20 | self.database = embeddings.database 21 | 22 | def __call__(self, queries): 23 | """ 24 | Extracts keyword terms from a list of queries. 25 | 26 | Args: 27 | queries: list of queries 28 | 29 | Returns: 30 | list of queries reduced down to keyword term strings 31 | """ 32 | 33 | # Parse queries and extract keyword terms for each query 34 | if self.database: 35 | terms = [] 36 | for query in queries: 37 | # Parse query 38 | parse = self.database.parse(query) 39 | 40 | # Join terms from similar clauses 41 | terms.append(" ".join(" ".join(s) for s in parse["similar"])) 42 | 43 | return terms 44 | 45 | # Return original query when database is None 46 | return queries 47 | -------------------------------------------------------------------------------- /src/python/txtai/graph/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Graph imports 3 | """ 4 | 5 | from .base import Graph 6 | from .factory import GraphFactory 7 | from .networkx import NetworkX 8 | from .query import Query 9 | from .rdbms import RDBMS 10 | from .topics import Topics 11 | -------------------------------------------------------------------------------- /src/python/txtai/graph/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from ..util import Resolver 6 | 7 | from .networkx import NetworkX 8 | from .rdbms import RDBMS 9 | 10 | 11 | class GraphFactory: 12 | """ 13 | Methods to create graphs. 14 | """ 15 | 16 | @staticmethod 17 | def create(config): 18 | """ 19 | Create a Graph. 
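Example (a minimal sketch; "networkx" is the default backend below):

    graph = GraphFactory.create({"backend": "networkx"})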
20 | 21 | Args: 22 | config: graph configuration 23 | 24 | Returns: 25 | Graph 26 | """ 27 | 28 | # Graph instance 29 | graph = None 30 | backend = config.get("backend", "networkx") 31 | 32 | # Create graph instance 33 | if backend == "networkx": 34 | graph = NetworkX(config) 35 | elif backend == "rdbms": 36 | graph = RDBMS(config) 37 | else: 38 | graph = GraphFactory.resolve(backend, config) 39 | 40 | # Store config back 41 | config["backend"] = backend 42 | 43 | return graph 44 | 45 | @staticmethod 46 | def resolve(backend, config): 47 | """ 48 | Attempt to resolve a custom backend. 49 | 50 | Args: 51 | backend: backend class 52 | config: index configuration parameters 53 | 54 | Returns: 55 | Graph 56 | """ 57 | 58 | try: 59 | return Resolver()(backend)(config) 60 | except Exception as e: 61 | raise ImportError(f"Unable to resolve graph backend: '{backend}'") from e 62 | -------------------------------------------------------------------------------- /src/python/txtai/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models imports 3 | """ 4 | 5 | from .models import Models 6 | from .onnx import OnnxModel 7 | from .pooling import * 8 | from .registry import Registry 9 | from .tokendetection import TokenDetection 10 | -------------------------------------------------------------------------------- /src/python/txtai/models/pooling/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pooling imports 3 | """ 4 | 5 | from .base import Pooling 6 | from .cls import ClsPooling 7 | from .factory import PoolingFactory 8 | from .mean import MeanPooling 9 | -------------------------------------------------------------------------------- /src/python/txtai/models/pooling/cls.py: -------------------------------------------------------------------------------- 1 | """ 2 | CLS module 3 | """ 4 | 5 | from .base import Pooling 6 | 7 | 8 | class ClsPooling(Pooling): 9 | """ 10 | Builds CLS pooled vectors using outputs from a transformers model. 11 | """ 12 | 13 | def forward(self, **inputs): 14 | """ 15 | Runs CLS pooling on token embeddings. 16 | 17 | Args: 18 | inputs: model inputs 19 | 20 | Returns: 21 | CLS pooled embeddings using output token embeddings (i.e. last hidden state) 22 | """ 23 | 24 | # Run through transformers model 25 | tokens = super().forward(**inputs) 26 | 27 | # CLS token pooling 28 | return tokens[:, 0] 29 | -------------------------------------------------------------------------------- /src/python/txtai/models/pooling/mean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mean module 3 | """ 4 | 5 | import torch 6 | 7 | from .base import Pooling 8 | 9 | 10 | class MeanPooling(Pooling): 11 | """ 12 | Builds mean pooled vectors using outputs from a transformers model. 13 | """ 14 | 15 | def forward(self, **inputs): 16 | """ 17 | Runs mean pooling on token embeddings taking the input mask into account. 18 | 19 | Args: 20 | inputs: model inputs 21 | 22 | Returns: 23 | mean pooled embeddings using output token embeddings (i.e.
last hidden state) 24 | """ 25 | 26 | # Run through transformers model 27 | tokens = super().forward(**inputs) 28 | mask = inputs["attention_mask"] 29 | 30 | # Mean pooling 31 | # pylint: disable=E1101 32 | mask = mask.unsqueeze(-1).expand(tokens.size()).float() 33 | return torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9) 34 | -------------------------------------------------------------------------------- /src/python/txtai/models/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Registry module 3 | """ 4 | 5 | from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification 6 | from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING 7 | 8 | 9 | class Registry: 10 | """ 11 | Methods to register models and fully support pipelines. 12 | """ 13 | 14 | @staticmethod 15 | def register(model, config=None): 16 | """ 17 | Registers a model with auto model and tokenizer configuration to fully support pipelines. 18 | 19 | Args: 20 | model: model to register 21 | config: config class name 22 | """ 23 | 24 | # Default config class to model class if not provided 25 | config = config if config else model.__class__ 26 | 27 | # Default model config_class if empty 28 | if hasattr(model.__class__, "config_class") and not model.__class__.config_class: 29 | model.__class__.config_class = config 30 | 31 | # Add references for this class to supported AutoModel classes 32 | for mapping in [AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification]: 33 | mapping.register(config, model.__class__) 34 | 35 | # Add references for this class to support pipeline AutoTokenizers 36 | if hasattr(model, "config") and type(model.config) not in TOKENIZER_MAPPING: 37 | TOKENIZER_MAPPING.register(type(model.config), type(model.config).__name__) 38 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pipeline imports 3 | """ 4 | 5 | from .audio import * 6 | from .base import Pipeline 7 | from .data import * 8 | from .factory import PipelineFactory 9 | from .hfmodel import HFModel 10 | from .hfpipeline import HFPipeline 11 | from .image import * 12 | from .llm import * 13 | from .llm import RAG as Extractor 14 | from .nop import Nop 15 | from .text import * 16 | from .tensors import Tensors 17 | from .train import * 18 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/audio/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Audio imports 3 | """ 4 | 5 | from .audiomixer import AudioMixer 6 | from .audiostream import AudioStream 7 | from .microphone import Microphone 8 | from .signal import Signal 9 | from .texttoaudio import TextToAudio 10 | from .texttospeech import TextToSpeech 11 | from .transcription import Transcription 12 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/audio/audiomixer.py: -------------------------------------------------------------------------------- 1 | """ 2 | AudioMixer module 3 | """ 4 | 5 | from ..base import Pipeline 6 | from .signal import Signal, SCIPY 7 | 8 | 9 | class AudioMixer(Pipeline): 10 | """ 11 | Mixes multiple audio streams into a single stream. 
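Example (a hedged sketch; the audio arrays and sample rates are illustrative):

    mixer = AudioMixer(rate=22050)
    audio, rate = mixer(((audio1, 22050), (audio2, 44100)))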
12 | """ 13 | 14 | def __init__(self, rate=None): 15 | """ 16 | Creates an AudioMixer pipeline. 17 | 18 | Args: 19 | rate: optional target sample rate, otherwise uses input target rate with each audio segment 20 | """ 21 | 22 | if not SCIPY: 23 | raise ImportError('AudioMixer pipeline is not available - install "pipeline" extra to enable.') 24 | 25 | # Target sample rate 26 | self.rate = rate 27 | 28 | def __call__(self, segment, scale1=1, scale2=1): 29 | """ 30 | Mixes multiple audio streams into a single stream. 31 | 32 | Args: 33 | segment: ((audio1, sample rate), (audio2, sample rate))|list 34 | scale1: optional scaling factor for segment1 35 | scale2: optional scaling factor for segment2 36 | 37 | Returns: 38 | list of (audio, sample rate) 39 | """ 40 | 41 | # Convert single element to list 42 | segments = [segment] if isinstance(segment, tuple) else segment 43 | 44 | results = [] 45 | for segment1, segment2 in segments: 46 | audio1, rate1 = segment1 47 | audio2, rate2 = segment2 48 | 49 | # Resample audio, as necessary 50 | target = self.rate if self.rate else rate1 51 | audio1 = Signal.resample(audio1, rate1, target) 52 | audio2 = Signal.resample(audio2, rate2, target) 53 | 54 | # Mix audio into single segment 55 | results.append((Signal.mix(audio1, audio2, scale1, scale2), target)) 56 | 57 | # Return single element if single element passed in 58 | return results[0] if isinstance(segment, tuple) else results 59 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/audio/texttoaudio.py: -------------------------------------------------------------------------------- 1 | """ 2 | TextToAudio module 3 | """ 4 | 5 | from ..hfpipeline import HFPipeline 6 | from .signal import Signal, SCIPY 7 | 8 | 9 | class TextToAudio(HFPipeline): 10 | """ 11 | Generates audio from text. 12 | """ 13 | 14 | def __init__(self, path=None, quantize=False, gpu=True, model=None, rate=None, **kwargs): 15 | if not SCIPY: 16 | raise ImportError('TextToAudio pipeline is not available - install "pipeline" extra to enable.') 17 | 18 | # Call parent constructor 19 | super().__init__("text-to-audio", path, quantize, gpu, model, **kwargs) 20 | 21 | # Target sample rate, defaults to model sample rate 22 | self.rate = rate 23 | 24 | def __call__(self, text, maxlength=512): 25 | """ 26 | Generates audio from text. 27 | 28 | This method supports text as a string or a list. If the input is a string, 29 | the return type is a single audio output. If text is a list, the return type is a list. 30 | 31 | Args: 32 | text: text|list 33 | maxlength: maximum audio length to generate 34 | 35 | Returns: 36 | list of (audio, sample rate) 37 | """ 38 | 39 | # Format inputs 40 | texts = [text] if isinstance(text, str) else text 41 | 42 | # Run pipeline 43 | results = [self.convert(x) for x in self.pipeline(texts, forward_params={"max_new_tokens": maxlength})] 44 | 45 | # Extract results 46 | return results[0] if isinstance(text, str) else results 47 | 48 | def convert(self, result): 49 | """ 50 | Converts audio result to target sample rate for this pipeline, if set. 
51 | 52 | Args: 53 | result: dict with audio samples and sample rate 54 | 55 | Returns: 56 | (audio, sample rate) 57 | """ 58 | 59 | audio, rate = result["audio"].squeeze(), result["sampling_rate"] 60 | return (Signal.resample(audio, rate, self.rate), self.rate) if self.rate else (audio, rate) 61 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pipeline module 3 | """ 4 | 5 | 6 | class Pipeline: 7 | """ 8 | Base class for all Pipelines. The only interface requirement is to define a __call__ method. 9 | """ 10 | 11 | def batch(self, data, size): 12 | """ 13 | Splits data into separate batch sizes specified by size. 14 | 15 | Args: 16 | data: data elements 17 | size: batch size 18 | 19 | Returns: 20 | list of evenly sized batches with the last batch having the remaining elements 21 | """ 22 | 23 | return [data[x : x + size] for x in range(0, len(data), size)] 24 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Segment imports 3 | """ 4 | 5 | from .filetohtml import FileToHTML 6 | from .htmltomd import HTMLToMarkdown 7 | from .segmentation import Segmentation 8 | from .tabular import Tabular 9 | from .textractor import Textractor 10 | from .tokenizer import Tokenizer 11 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pipeline factory module 3 | """ 4 | 5 | import inspect 6 | import sys 7 | import types 8 | 9 | from ..util import Resolver 10 | 11 | from .base import Pipeline 12 | 13 | 14 | class PipelineFactory: 15 | """ 16 | Pipeline factory. Creates new Pipeline instances. 17 | """ 18 | 19 | @staticmethod 20 | def get(pipeline): 21 | """ 22 | Gets a new instance of pipeline class. 23 | 24 | Args: 25 | pipeline: Pipeline instance class 26 | 27 | Returns: 28 | Pipeline class 29 | """ 30 | 31 | # Local pipeline if no package 32 | if "." not in pipeline: 33 | return PipelineFactory.list()[pipeline] 34 | 35 | # Attempt to load custom pipeline 36 | return Resolver()(pipeline) 37 | 38 | @staticmethod 39 | def create(config, pipeline): 40 | """ 41 | Creates a new Pipeline instance. 42 | 43 | Args: 44 | config: Pipeline configuration 45 | pipeline: Pipeline instance class 46 | 47 | Returns: 48 | Pipeline 49 | """ 50 | 51 | # Resolve pipeline 52 | pipeline = PipelineFactory.get(pipeline) 53 | 54 | # Return functions directly, otherwise create pipeline instance 55 | return pipeline if isinstance(pipeline, types.FunctionType) else pipeline(**config) 56 | 57 | @staticmethod 58 | def list(): 59 | """ 60 | Lists callable pipelines.
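Example (a minimal sketch; the entries shown are illustrative):

    pipelines = PipelineFactory.list()
    # e.g. {"caption": Caption, "nop": Nop, "translation": Translation, ...}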
61 | 62 | Returns: 63 | {short name: pipeline class} 64 | """ 65 | 66 | pipelines = {} 67 | 68 | # Get handle to pipeline module 69 | pipeline = sys.modules[".".join(__name__.split(".")[:-1])] 70 | 71 | # Get list of callable pipelines 72 | for x in inspect.getmembers(pipeline, inspect.isclass): 73 | if issubclass(x[1], Pipeline) and [y for y, _ in inspect.getmembers(x[1], inspect.isfunction) if y == "__call__"]: 74 | # short name: pipeline class 75 | pipelines[x[0].lower()] = x[1] 76 | 77 | return pipelines 78 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/image/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image imports 3 | """ 4 | 5 | from .caption import Caption 6 | from .imagehash import ImageHash 7 | from .objects import Objects 8 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/image/caption.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caption module 3 | """ 4 | 5 | # Conditional import 6 | try: 7 | from PIL import Image 8 | 9 | PIL = True 10 | except ImportError: 11 | PIL = False 12 | 13 | from ..hfpipeline import HFPipeline 14 | 15 | 16 | class Caption(HFPipeline): 17 | """ 18 | Constructs captions for images. 19 | """ 20 | 21 | def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs): 22 | if not PIL: 23 | raise ImportError('Captions pipeline is not available - install "pipeline" extra to enable') 24 | 25 | # Call parent constructor 26 | super().__init__("image-to-text", path, quantize, gpu, model, **kwargs) 27 | 28 | def __call__(self, images): 29 | """ 30 | Builds captions for images. 31 | 32 | This method supports a single image or a list of images. If the input is an image, the return 33 | type is a string. 
If the input is a list, a list of strings is returned. 34 | 35 | Args: 36 | images: image|list 37 | 38 | Returns: 39 | list of captions 40 | """ 41 | 42 | # Convert single element to list 43 | values = [images] if not isinstance(images, list) else images 44 | 45 | # Open images if file strings 46 | values = [Image.open(image) if isinstance(image, str) else image for image in values] 47 | 48 | # Get and clean captions 49 | captions = [] 50 | for result in self.pipeline(values): 51 | text = " ".join([r["generated_text"] for r in result]).strip() 52 | captions.append(text) 53 | 54 | # Return single element if single element passed in 55 | return captions[0] if not isinstance(images, list) else captions 56 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/llm/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | LLM imports 3 | """ 4 | 5 | from .factory import GenerationFactory 6 | from .generation import Generation 7 | from .huggingface import * 8 | from .litellm import LiteLLM 9 | from .llama import LlamaCpp 10 | from .llm import LLM 11 | from .rag import RAG 12 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/nop.py: -------------------------------------------------------------------------------- 1 | """ 2 | No-Op module 3 | """ 4 | 5 | from .base import Pipeline 6 | 7 | 8 | class Nop(Pipeline): 9 | """ 10 | Simple no-op pipeline that returns inputs 11 | """ 12 | 13 | def __call__(self, inputs): 14 | return inputs 15 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/tensors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tensor processing framework module 3 | """ 4 | 5 | import torch 6 | 7 | from .base import Pipeline 8 | 9 | 10 | class Tensors(Pipeline): 11 | """ 12 | Pipeline backed by a tensor processing framework. Currently supports PyTorch. 13 | """ 14 | 15 | def quantize(self, model): 16 | """ 17 | Quantizes the input model and returns it. This is only supported on CPU devices. 18 | 19 | Args: 20 | model: torch model 21 | 22 | Returns: 23 | quantized torch model 24 | """ 25 | 26 | # pylint: disable=E1101 27 | return torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) 28 | 29 | def tensor(self, data): 30 | """ 31 | Creates a tensor array. 32 | 33 | Args: 34 | data: input data 35 | 36 | Returns: 37 | tensor 38 | """ 39 | 40 | # pylint: disable=E1102 41 | return torch.tensor(data) 42 | 43 | def context(self): 44 | """ 45 | Defines a context used to wrap processing with the tensor processing framework.
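A minimal usage sketch (hedged; model and inputs are illustrative):

    with self.context():
        outputs = model(inputs)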
46 | 47 | Returns: 48 | processing context 49 | """ 50 | 51 | # pylint: disable=E1101 52 | return torch.no_grad() 53 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Text imports 3 | """ 4 | 5 | from .crossencoder import CrossEncoder 6 | from .entity import Entity 7 | from .labels import Labels 8 | from .questions import Questions 9 | from .similarity import Similarity 10 | from .summary import Summary 11 | from .translation import Translation 12 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/text/questions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Questions module 3 | """ 4 | 5 | from ..hfpipeline import HFPipeline 6 | 7 | 8 | class Questions(HFPipeline): 9 | """ 10 | Runs extractive QA for a series of questions and contexts. 11 | """ 12 | 13 | def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs): 14 | super().__init__("question-answering", path, quantize, gpu, model, **kwargs) 15 | 16 | def __call__(self, questions, contexts, workers=0): 17 | """ 18 | Runs an extractive question-answering model against each question-context pair, finding the best answers. 19 | 20 | Args: 21 | questions: list of questions 22 | contexts: list of contexts to pull answers from 23 | workers: number of concurrent workers to use for processing data, defaults to 0 24 | 25 | Returns: 26 | list of answers 27 | """ 28 | 29 | answers = [] 30 | 31 | for x, question in enumerate(questions): 32 | if question and contexts[x]: 33 | # Run the QA pipeline 34 | result = self.pipeline(question=question, context=contexts[x], num_workers=workers) 35 | 36 | # Get answer and score 37 | answer, score = result["answer"], result["score"] 38 | 39 | # Require score to be at least 0.05 40 | if score < 0.05: 41 | answer = None 42 | 43 | # Add answer 44 | answers.append(answer) 45 | else: 46 | answers.append(None) 47 | 48 | return answers 49 | -------------------------------------------------------------------------------- /src/python/txtai/pipeline/train/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train imports 3 | """ 4 | 5 | from .hfonnx import HFOnnx 6 | from .hftrainer import HFTrainer 7 | from .mlonnx import MLOnnx 8 | -------------------------------------------------------------------------------- /src/python/txtai/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scoring imports 3 | """ 4 | 5 | from .base import Scoring 6 | from .bm25 import BM25 7 | from .factory import ScoringFactory 8 | from .pgtext import PGText 9 | from .sif import SIF 10 | from .terms import Terms 11 | from .tfidf import TFIDF 12 | -------------------------------------------------------------------------------- /src/python/txtai/scoring/bm25.py: -------------------------------------------------------------------------------- 1 | """ 2 | BM25 module 3 | """ 4 | 5 | import numpy as np 6 | 7 | from .tfidf import TFIDF 8 | 9 | 10 | class BM25(TFIDF): 11 | """ 12 | Best matching (BM25) scoring.
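Scores mirror the standard BM25 formulation implemented by computeidf and score below:

    idf(t) = log(1 + (N - df(t) + 0.5) / (df(t) + 0.5))
    score(t, d) = idf(t) * f(t, d) * (k1 + 1) / (f(t, d) + k1 * (1 - b + b * |d| / avgdl))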
13 | """ 14 | 15 | def __init__(self, config=None): 16 | super().__init__(config) 17 | 18 | # BM25 configurable parameters 19 | self.k1 = self.config.get("k1", 1.2) 20 | self.b = self.config.get("b", 0.75) 21 | 22 | def computeidf(self, freq): 23 | # Calculate BM25 IDF score 24 | return np.log(1 + (self.total - freq + 0.5) / (freq + 0.5)) 25 | 26 | def score(self, freq, idf, length): 27 | # Calculate BM25 score 28 | k = self.k1 * ((1 - self.b) + self.b * length / self.avgdl) 29 | return idf * (freq * (self.k1 + 1)) / (freq + k) 30 | -------------------------------------------------------------------------------- /src/python/txtai/scoring/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from ..util import Resolver 6 | 7 | from .bm25 import BM25 8 | from .sif import SIF 9 | from .pgtext import PGText 10 | from .tfidf import TFIDF 11 | 12 | 13 | class ScoringFactory: 14 | """ 15 | Methods to create Scoring indexes. 16 | """ 17 | 18 | @staticmethod 19 | def create(config): 20 | """ 21 | Factory method to construct a Scoring instance. 22 | 23 | Args: 24 | config: scoring configuration parameters - supports bm25, sif, tfidf 25 | 26 | Returns: 27 | Scoring 28 | """ 29 | 30 | # Scoring instance 31 | scoring = None 32 | 33 | # Support string and dict configuration 34 | if isinstance(config, str): 35 | config = {"method": config} 36 | 37 | # Get scoring method 38 | method = config.get("method", "bm25") 39 | 40 | if method == "bm25": 41 | scoring = BM25(config) 42 | elif method == "sif": 43 | scoring = SIF(config) 44 | elif method == "pgtext": 45 | scoring = PGText(config) 46 | elif method == "tfidf": 47 | scoring = TFIDF(config) 48 | else: 49 | # Resolve custom method 50 | scoring = ScoringFactory.resolve(method, config) 51 | 52 | # Store config back 53 | config["method"] = method 54 | 55 | return scoring 56 | 57 | @staticmethod 58 | def resolve(backend, config): 59 | """ 60 | Attempt to resolve a custom backend. 61 | 62 | Args: 63 | backend: backend class 64 | config: index configuration parameters 65 | 66 | Returns: 67 | Scoring 68 | """ 69 | 70 | try: 71 | return Resolver()(backend)(config) 72 | except Exception as e: 73 | raise ImportError(f"Unable to resolve scoring backend: '{backend}'") from e 74 | -------------------------------------------------------------------------------- /src/python/txtai/scoring/sif.py: -------------------------------------------------------------------------------- 1 | """ 2 | SIF module 3 | """ 4 | 5 | import numpy as np 6 | 7 | from .tfidf import TFIDF 8 | 9 | 10 | class SIF(TFIDF): 11 | """ 12 | Smooth Inverse Frequency (SIF) scoring. 
13 | """ 14 | 15 | def __init__(self, config=None): 16 | super().__init__(config) 17 | 18 | # SIF configurable parameters 19 | self.a = self.config.get("a", 1e-3) 20 | 21 | def computefreq(self, tokens): 22 | # Default method computes frequency for a single entry 23 | # SIF uses word frequencies across entire index 24 | return {token: self.wordfreq[token] for token in tokens} 25 | 26 | def score(self, freq, idf, length): 27 | # Set freq to word frequencies across entire index when freq and idf shape don't match 28 | if isinstance(freq, np.ndarray) and freq.shape != idf.shape: 29 | freq.fill(freq.sum()) 30 | 31 | # Calculate SIF score 32 | return self.a / (self.a + freq / self.tokens) 33 | -------------------------------------------------------------------------------- /src/python/txtai/serialize/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Serialize imports 3 | """ 4 | 5 | from .base import Serialize 6 | from .errors import SerializeError 7 | from .factory import SerializeFactory 8 | from .messagepack import MessagePack 9 | from .pickle import Pickle 10 | from .serializer import Serializer 11 | -------------------------------------------------------------------------------- /src/python/txtai/serialize/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Serialize module 3 | """ 4 | 5 | 6 | class Serialize: 7 | """ 8 | Base class for Serialize instances. This class serializes data to files, streams and bytes. 9 | """ 10 | 11 | def load(self, path): 12 | """ 13 | Loads data from path. 14 | 15 | Args: 16 | path: input path 17 | 18 | Returns: 19 | deserialized data 20 | """ 21 | 22 | with open(path, "rb") as handle: 23 | return self.loadstream(handle) 24 | 25 | def save(self, data, path): 26 | """ 27 | Saves data to path. 28 | 29 | Args: 30 | data: data to save 31 | path: output path 32 | """ 33 | 34 | with open(path, "wb") as handle: 35 | self.savestream(data, handle) 36 | 37 | def loadstream(self, stream): 38 | """ 39 | Loads data from stream. 40 | 41 | Args: 42 | stream: input stream 43 | 44 | Returns: 45 | deserialized data 46 | """ 47 | 48 | raise NotImplementedError 49 | 50 | def savestream(self, data, stream): 51 | """ 52 | Saves data to stream. 53 | 54 | Args: 55 | data: data to save 56 | stream: output stream 57 | """ 58 | 59 | raise NotImplementedError 60 | 61 | def loadbytes(self, data): 62 | """ 63 | Loads data from bytes. 64 | 65 | Args: 66 | data: input bytes 67 | 68 | Returns: 69 | deserialized data 70 | """ 71 | 72 | raise NotImplementedError 73 | 74 | def savebytes(self, data): 75 | """ 76 | Saves data as bytes. 77 | 78 | Args: 79 | data: data to save 80 | 81 | Returns: 82 | serialized data 83 | """ 84 | 85 | raise NotImplementedError 86 | -------------------------------------------------------------------------------- /src/python/txtai/serialize/errors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Errors module 3 | """ 4 | 5 | 6 | class SerializeError(Exception): 7 | """ 8 | Raised when data serialization fails 9 | """ 10 | -------------------------------------------------------------------------------- /src/python/txtai/serialize/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factory module 3 | """ 4 | 5 | from .messagepack import MessagePack 6 | from .pickle import Pickle 7 | 8 | 9 | class SerializeFactory: 10 | """ 11 | Methods to create data serializers. 
12 | """ 13 | 14 | @staticmethod 15 | def create(method=None, **kwargs): 16 | """ 17 | Creates a new Serialize instance. 18 | 19 | Args: 20 | method: serialization method 21 | kwargs: additional keyword arguments to pass to serialize instance 22 | """ 23 | 24 | # Pickle serialization 25 | if method == "pickle": 26 | return Pickle(**kwargs) 27 | 28 | # Default serialization 29 | return MessagePack(**kwargs) 30 | -------------------------------------------------------------------------------- /src/python/txtai/serialize/messagepack.py: -------------------------------------------------------------------------------- 1 | """ 2 | MessagePack module 3 | """ 4 | 5 | import msgpack 6 | from msgpack import Unpacker 7 | from msgpack.exceptions import ExtraData 8 | 9 | from .base import Serialize 10 | from .errors import SerializeError 11 | 12 | 13 | class MessagePack(Serialize): 14 | """ 15 | MessagePack serialization. 16 | """ 17 | 18 | def __init__(self, streaming=False): 19 | # Parent constructor 20 | super().__init__() 21 | 22 | self.streaming = streaming 23 | 24 | def loadstream(self, stream): 25 | try: 26 | # Support both streaming and non-streaming unpacking of data 27 | return Unpacker(stream) if self.streaming else msgpack.unpack(stream) 28 | except ExtraData as e: 29 | raise SerializeError(e) from e 30 | 31 | def savestream(self, data, stream): 32 | msgpack.pack(data, stream) 33 | 34 | def loadbytes(self, data): 35 | return msgpack.unpackb(data) 36 | 37 | def savebytes(self, data): 38 | return msgpack.packb(data) 39 | -------------------------------------------------------------------------------- /src/python/txtai/serialize/serializer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Serializer module 3 | """ 4 | 5 | from .errors import SerializeError 6 | from .factory import SerializeFactory 7 | 8 | 9 | class Serializer: 10 | """ 11 | Methods to serialize and deserialize data. 12 | """ 13 | 14 | @staticmethod 15 | def load(path): 16 | """ 17 | Loads data from path. This method first tries to load the default serialization format. 18 | If that fails, it will fall back to pickle format for backwards-compatibility purposes. 19 | 20 | Note that loading pickle files requires the env variable `ALLOW_PICKLE=True`. 21 | 22 | Args: 23 | path: input path 24 | 25 | Returns: 26 | data 27 | """ 28 | 29 | try: 30 | return SerializeFactory.create().load(path) 31 | except SerializeError: 32 | # Backwards compatible check for pickled data 33 | return SerializeFactory.create("pickle").load(path) 34 | 35 | @staticmethod 36 | def save(data, path): 37 | """ 38 | Saves data to path. 39 | 40 | Args: 41 | data: data to save 42 | path: output path 43 | """ 44 | 45 | # Save using default serialization method 46 | SerializeFactory.create().save(data, path) 47 | -------------------------------------------------------------------------------- /src/python/txtai/util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility imports 3 | """ 4 | 5 | from .resolver import Resolver 6 | from .template import TemplateFormatter 7 | -------------------------------------------------------------------------------- /src/python/txtai/util/resolver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Resolver module 3 | """ 4 | 5 | 6 | class Resolver: 7 | """ 8 | Resolves a Python class path 9 | """ 10 | 11 | def __call__(self, path): 12 | """ 13 | Resolves path to a class.
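# Byte-level round trip mirroring the savebytes/loadbytes methods above and the
# assertions in test/python/testserialize.py; a sketch with illustrative data.
from txtai.serialize import SerializeFactory

serializer = SerializeFactory.create()
data = serializer.savebytes({"id": 1, "text": "hello"})
assert serializer.loadbytes(data) == {"id": 1, "text": "hello"}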
14 | 15 | Args: 16 | path: path to class 17 | 18 | Returns: 19 | resolved class 20 | """ 21 | 22 | # Split into path components 23 | parts = path.split(".") 24 | 25 | # Resolve each path component 26 | module = ".".join(parts[:-1]) 27 | m = __import__(module) 28 | for comp in parts[1:]: 29 | m = getattr(m, comp) 30 | 31 | # Return resolved class 32 | return m 33 | -------------------------------------------------------------------------------- /src/python/txtai/util/template.py: -------------------------------------------------------------------------------- 1 | """ 2 | Template module 3 | """ 4 | 5 | from string import Formatter 6 | 7 | 8 | class TemplateFormatter(Formatter): 9 | """ 10 | Custom Formatter that requires each argument to be consumed. 11 | """ 12 | 13 | def check_unused_args(self, used_args, args, kwargs): 14 | difference = set(kwargs).difference(used_args) 15 | if difference: 16 | raise KeyError(difference) 17 | -------------------------------------------------------------------------------- /src/python/txtai/vectors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vectors imports 3 | """ 4 | 5 | from .base import Vectors 6 | from .external import External 7 | from .factory import VectorsFactory 8 | from .huggingface import HFVectors 9 | from .litellm import LiteLLM 10 | from .llama import LlamaCpp 11 | from .m2v import Model2Vec 12 | from .recovery import Recovery 13 | from .sbert import STVectors 14 | from .words import WordVectors 15 | -------------------------------------------------------------------------------- /src/python/txtai/vectors/external.py: -------------------------------------------------------------------------------- 1 | """ 2 | External module 3 | """ 4 | 5 | import types 6 | 7 | import numpy as np 8 | 9 | from ..util import Resolver 10 | 11 | from .base import Vectors 12 | 13 | 14 | class External(Vectors): 15 | """ 16 | Builds vectors using an external method. This can be a local function or an external API call. 17 | """ 18 | 19 | def __init__(self, config, scoring, models): 20 | super().__init__(config, scoring, models) 21 | 22 | # Lookup and resolve transform function 23 | self.transform = self.resolve(config.get("transform")) 24 | 25 | def loadmodel(self, path): 26 | return None 27 | 28 | def encode(self, data): 29 | # Call external transform function, if available and data not already an array 30 | # Batching is handled by the external transform function 31 | if self.transform and data and not isinstance(data[0], np.ndarray): 32 | data = self.transform(data) 33 | 34 | # Cast to float32 35 | return data.astype(np.float32) if isinstance(data, np.ndarray) else np.array(data, dtype=np.float32) 36 | 37 | def resolve(self, transform): 38 | """ 39 | Resolves a transform function.
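# Sketch of the two utilities above: Resolver walks a dotted path down to the
# class object and TemplateFormatter raises a KeyError when a keyword argument
# goes unused; the path and template strings are illustrative.
from txtai.util import Resolver, TemplateFormatter

# Returns the BM25 class itself (not an instance); callers instantiate it
bm25 = Resolver()("txtai.scoring.BM25")

formatter = TemplateFormatter()
print(formatter.format("{text}", text="hello"))

try:
    formatter.format("no placeholders", text="hello")
except KeyError as unused:
    print("Unused arguments:", unused)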
40 | 41 | Args: 42 | transform: transform function 43 | 44 | Returns: 45 | resolved transform function 46 | """ 47 | 48 | if transform: 49 | # Resolve transform instance, if necessary 50 | transform = Resolver()(transform) if isinstance(transform, str) else transform 51 | 52 | # Get function or callable instance 53 | transform = transform if isinstance(transform, types.FunctionType) else transform() 54 | 55 | return transform 56 | -------------------------------------------------------------------------------- /src/python/txtai/vectors/huggingface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hugging Face module 3 | """ 4 | 5 | from ..models import Models, PoolingFactory 6 | 7 | from .base import Vectors 8 | 9 | 10 | class HFVectors(Vectors): 11 | """ 12 | Builds vectors using the Hugging Face transformers library. 13 | """ 14 | 15 | @staticmethod 16 | def ismethod(method): 17 | """ 18 | Checks if this method uses local transformers-based models. 19 | 20 | Args: 21 | method: input method 22 | 23 | Returns: 24 | True if this is a local transformers-based model, False otherwise 25 | """ 26 | 27 | return method in ("transformers", "pooling", "clspooling", "meanpooling") 28 | 29 | def loadmodel(self, path): 30 | # Build embeddings with transformers pooling 31 | return PoolingFactory.create( 32 | { 33 | "method": self.config.get("method"), 34 | "path": path, 35 | "device": Models.deviceid(self.config.get("gpu", True)), 36 | "tokenizer": self.config.get("tokenizer"), 37 | "maxlength": self.config.get("maxlength"), 38 | "modelargs": self.config.get("vectors", {}), 39 | } 40 | ) 41 | 42 | def encode(self, data): 43 | # Encode data using vectors model 44 | return self.model.encode(data, batch=self.encodebatch) 45 | -------------------------------------------------------------------------------- /src/python/txtai/vectors/litellm.py: -------------------------------------------------------------------------------- 1 | """ 2 | LiteLLM module 3 | """ 4 | 5 | import numpy as np 6 | 7 | # Conditional import 8 | try: 9 | import litellm as api 10 | 11 | LITELLM = True 12 | except ImportError: 13 | LITELLM = False 14 | 15 | from .base import Vectors 16 | 17 | 18 | class LiteLLM(Vectors): 19 | """ 20 | Builds vectors using an external embeddings API via LiteLLM. 21 | """ 22 | 23 | @staticmethod 24 | def ismodel(path): 25 | """ 26 | Checks if path is a LiteLLM model.
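# External vectors sketch mirroring test/python/testvectors/testexternal.py:
# a "transform" entry in the configuration routes VectorsFactory to the
# External backend above; the toy transform returns fixed-size vectors.
import numpy as np

from txtai.vectors import VectorsFactory

def transform(data):
    # One 4-dimensional vector per input element
    return np.random.rand(len(data), 4).astype(np.float32)

model = VectorsFactory.create({"transform": transform}, None)
print(model.encode(["first", "second"]).shape)  # (2, 4)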
27 | 28 | Args: 29 | path: input path 30 | 31 | Returns: 32 | True if this is a LiteLLM model, False otherwise 33 | """ 34 | 35 | # pylint: disable=W0702 36 | if isinstance(path, str) and LITELLM: 37 | debug = api.suppress_debug_info 38 | try: 39 | # Suppress debug messages for this test 40 | api.suppress_debug_info = True 41 | return api.get_llm_provider(path) 42 | except: 43 | return False 44 | finally: 45 | # Restore debug info value to original value 46 | api.suppress_debug_info = debug 47 | 48 | return False 49 | 50 | def __init__(self, config, scoring, models): 51 | # Check before parent constructor since it calls loadmodel 52 | if not LITELLM: 53 | raise ImportError('LiteLLM is not available - install "vectors" extra to enable') 54 | 55 | super().__init__(config, scoring, models) 56 | 57 | def loadmodel(self, path): 58 | return None 59 | 60 | def encode(self, data): 61 | # Call external embeddings API using LiteLLM 62 | # Batching is handled server-side 63 | response = api.embedding(model=self.config.get("path"), input=data, **self.config.get("vectors", {})) 64 | 65 | # Read response into a NumPy array 66 | return np.array([x["embedding"] for x in response.data], dtype=np.float32) 67 | -------------------------------------------------------------------------------- /src/python/txtai/vectors/m2v.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model2Vec module 3 | """ 4 | 5 | import json 6 | 7 | from huggingface_hub.errors import HFValidationError 8 | from transformers.utils import cached_file 9 | 10 | # Conditional import 11 | try: 12 | from model2vec import StaticModel 13 | 14 | MODEL2VEC = True 15 | except ImportError: 16 | MODEL2VEC = False 17 | 18 | from .base import Vectors 19 | 20 | 21 | class Model2Vec(Vectors): 22 | """ 23 | Builds vectors using Model2Vec. 24 | """ 25 | 26 | @staticmethod 27 | def ismodel(path): 28 | """ 29 | Checks if path is a Model2Vec model. 30 | 31 | Args: 32 | path: input path 33 | 34 | Returns: 35 | True if this is a Model2Vec model, False otherwise 36 | """ 37 | 38 | try: 39 | # Download file and parse JSON 40 | path = cached_file(path_or_repo_id=path, filename="config.json") 41 | if path: 42 | with open(path, encoding="utf-8") as f: 43 | config = json.load(f) 44 | return config.get("model_type") == "model2vec" 45 | 46 | # Ignore this error - invalid repo or directory 47 | except (HFValidationError, OSError): 48 | pass 49 | 50 | return False 51 | 52 | def __init__(self, config, scoring, models): 53 | # Check before parent constructor since it calls loadmodel 54 | if not MODEL2VEC: 55 | raise ImportError('Model2Vec is not available - install "vectors" extra to enable') 56 | 57 | super().__init__(config, scoring, models) 58 | 59 | def loadmodel(self, path): 60 | return StaticModel.from_pretrained(path) 61 | 62 | def encode(self, data): 63 | # Additional model arguments 64 | modelargs = self.config.get("vectors", {}) 65 | 66 | # Encode data 67 | return self.model.encode(data, batch_size=self.encodebatch, **modelargs) 68 | -------------------------------------------------------------------------------- /src/python/txtai/vectors/recovery.py: -------------------------------------------------------------------------------- 1 | """ 2 | Recovery module 3 | """ 4 | 5 | import os 6 | import shutil 7 | 8 | import numpy as np 9 | 10 | 11 | class Recovery: 12 | """ 13 | Vector embeddings recovery. This class handles streaming embeddings from a vector checkpoint file. 
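# Model detection sketch: Model2Vec.ismodel above reads config.json and checks
# model_type, so a plain path is enough for the factory to select this backend.
# Mirrors test/python/testvectors/testm2v.py; downloads the model on first use.
from txtai.vectors import VectorsFactory

model = VectorsFactory.create({"path": "minishlab/potion-base-8M"}, None)
print(model.encode(["hello world"]).shape)  # (1, 256)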
14 | """ 15 | 16 | def __init__(self, checkpoint, vectorsid): 17 | """ 18 | Creates a Recovery instance. 19 | 20 | Args: 21 | checkpoint: checkpoint directory 22 | vectorsid: vectors uid for current configuration 23 | """ 24 | 25 | self.spool, self.path = None, None 26 | 27 | # Get unique file id 28 | path = f"{checkpoint}/{vectorsid}" 29 | if os.path.exists(path): 30 | # Generate recovery path 31 | self.path = f"{checkpoint}/recovery" 32 | 33 | # Copy current checkpoint to recovery 34 | shutil.copyfile(path, self.path) 35 | 36 | # Open spool file handle 37 | # pylint: disable=R1732 38 | self.spool = open(self.path, "rb") 39 | 40 | def __call__(self): 41 | """ 42 | Reads and returns the next batch of embeddings. 43 | 44 | Returns: 45 | batch of embeddings 46 | """ 47 | 48 | try: 49 | return np.load(self.spool) if self.spool else None 50 | except EOFError: 51 | # End of spool file, cleanup 52 | self.spool.close() 53 | os.remove(self.path) 54 | 55 | # Clear parameters 56 | self.spool, self.path = None, None 57 | 58 | return None 59 | -------------------------------------------------------------------------------- /src/python/txtai/version.py: -------------------------------------------------------------------------------- 1 | """ 2 | Version strings 3 | """ 4 | 5 | # Current version tag 6 | __version__ = "8.6.0" 7 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Workflow imports 3 | """ 4 | 5 | from .base import Workflow 6 | from .execute import Execute 7 | from .factory import WorkflowFactory 8 | from .task import * 9 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Workflow factory module 3 | """ 4 | 5 | from .base import Workflow 6 | from .task import TaskFactory 7 | 8 | 9 | class WorkflowFactory: 10 | """ 11 | Workflow factory. Creates new Workflow instances. 12 | """ 13 | 14 | @staticmethod 15 | def create(config, name): 16 | """ 17 | Creates a new Workflow instance.
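# Recovery sketch: seed a checkpoint file, then stream batches back until the
# EOFError path above cleans up and returns None. The "vid" file name stands in
# for a real vectors uid.
import os
import tempfile

import numpy as np

from txtai.vectors import Recovery

checkpoint = tempfile.mkdtemp()
with open(os.path.join(checkpoint, "vid"), "wb") as f:
    np.save(f, np.random.rand(2, 4).astype(np.float32))

recovery = Recovery(checkpoint, "vid")
print(recovery().shape)  # (2, 4) - first saved batch
print(recovery())        # None - end of spool, recovery file removed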
18 | 19 | Args: 20 | config: Workflow configuration 21 | name: Workflow name 22 | 23 | Returns: 24 | Workflow 25 | """ 26 | 27 | # Resolve workflow tasks 28 | tasks = [] 29 | for tconfig in config["tasks"]: 30 | task = tconfig.pop("task") if "task" in tconfig else "" 31 | tasks.append(TaskFactory.create(tconfig, task)) 32 | 33 | config["tasks"] = tasks 34 | 35 | if "stream" in config: 36 | sconfig = config["stream"] 37 | task = sconfig.pop("task") if "task" in sconfig else "stream" 38 | 39 | config["stream"] = TaskFactory.create(sconfig, task) 40 | 41 | # Create workflow 42 | return Workflow(**config, name=name) 43 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Task imports 3 | """ 4 | 5 | from .base import Task 6 | from .console import ConsoleTask 7 | from .export import ExportTask 8 | from .factory import TaskFactory 9 | from .file import FileTask 10 | from .image import ImageTask 11 | from .retrieve import RetrieveTask 12 | from .service import ServiceTask 13 | from .storage import StorageTask 14 | from .stream import StreamTask 15 | from .template import * 16 | from .template import RagTask as ExtractorTask 17 | from .url import UrlTask 18 | from .workflow import WorkflowTask 19 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/console.py: -------------------------------------------------------------------------------- 1 | """ 2 | ConsoleTask module 3 | """ 4 | 5 | import json 6 | 7 | from .base import Task 8 | 9 | 10 | class ConsoleTask(Task): 11 | """ 12 | Task that prints task elements to the console. 13 | """ 14 | 15 | def __call__(self, elements, executor=None): 16 | # Run task 17 | outputs = super().__call__(elements, executor) 18 | 19 | # Print inputs and outputs to console 20 | print("Inputs:", json.dumps(elements, indent=2)) 21 | print("Outputs:", json.dumps(outputs, indent=2)) 22 | 23 | # Return results 24 | return outputs 25 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/export.py: -------------------------------------------------------------------------------- 1 | """ 2 | ExportTask module 3 | """ 4 | 5 | import datetime 6 | import os 7 | 8 | # Conditional import 9 | try: 10 | import pandas as pd 11 | 12 | PANDAS = True 13 | except ImportError: 14 | PANDAS = False 15 | 16 | from .base import Task 17 | 18 | 19 | class ExportTask(Task): 20 | """ 21 | Task that exports task elements using Pandas. 22 | """ 23 | 24 | def register(self, output=None, timestamp=None): 25 | """ 26 | Add export parameters to task. Checks if required dependencies are installed. 
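# Workflow configuration sketch: each task dict names its type via the "task"
# key and remaining keys become task parameters. Assumes TaskFactory resolves
# "console" to the ConsoleTask shown above, matching txtai workflow YAML usage.
from txtai.workflow import WorkflowFactory

config = {"tasks": [{"task": "console"}]}
workflow = WorkflowFactory.create(config, "echo")

# Runs elements through the workflow, printing inputs and outputs along the way
print(list(workflow(["hello"])))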
27 | 28 | Args: 29 | output: output file path 30 | timestamp: true if output file should be timestamped 31 | """ 32 | 33 | if not PANDAS: 34 | raise ImportError('ExportTask is not available - install "workflow" extra to enable') 35 | 36 | # pylint: disable=W0201 37 | self.output = output 38 | self.timestamp = timestamp 39 | 40 | def __call__(self, elements, executor=None): 41 | # Run task 42 | outputs = super().__call__(elements, executor) 43 | 44 | # Get output file extension 45 | output = self.output 46 | parts = list(os.path.splitext(output)) 47 | extension = parts[-1].lower() 48 | 49 | # Add timestamp to filename 50 | if self.timestamp: 51 | timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ") 52 | parts[-1] = timestamp + parts[-1] 53 | 54 | # Create full path to output file - only rebuild when timestamped, otherwise the join would insert an extra dot 55 | output = ".".join(parts) if self.timestamp else output 56 | 57 | # Write output 58 | if extension == ".xlsx": 59 | pd.DataFrame(outputs).to_excel(output, index=False) 60 | else: 61 | pd.DataFrame(outputs).to_csv(output, index=False) 62 | 63 | # Return results 64 | return outputs 65 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/file.py: -------------------------------------------------------------------------------- 1 | """ 2 | FileTask module 3 | """ 4 | 5 | import os 6 | import re 7 | 8 | from .base import Task 9 | 10 | 11 | class FileTask(Task): 12 | """ 13 | Task that processes file paths 14 | """ 15 | 16 | # File prefix 17 | FILE = r"file:\/\/" 18 | 19 | def accept(self, element): 20 | # Replace file prefixes 21 | element = re.sub(FileTask.FILE, "", element) 22 | 23 | # Only accept file paths that exist 24 | return super().accept(element) and isinstance(element, str) and os.path.exists(element) 25 | 26 | def prepare(self, element): 27 | # Replace file prefixes 28 | return re.sub(FileTask.FILE, "", element) 29 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | ImageTask module 3 | """ 4 | 5 | import re 6 | 7 | # Conditional import 8 | try: 9 | from PIL import Image 10 | 11 | PIL = True 12 | except ImportError: 13 | PIL = False 14 | 15 | from .file import FileTask 16 | 17 | 18 | class ImageTask(FileTask): 19 | """ 20 | Task that processes image file urls 21 | """ 22 | 23 | def register(self): 24 | """ 25 | Checks if required dependencies are installed. 26 | """ 27 | 28 | if not PIL: 29 | raise ImportError('ImageTask is not available - install "workflow" extra to enable') 30 | 31 | def accept(self, element): 32 | # Only accept image files 33 | return super().accept(element) and re.search(r"\.(gif|bmp|jpg|jpeg|png|webp)$", element.lower()) 34 | 35 | def prepare(self, element): 36 | return Image.open(super().prepare(element)) 37 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/retrieve.py: -------------------------------------------------------------------------------- 1 | """ 2 | RetrieveTask module 3 | """ 4 | 5 | import os 6 | import tempfile 7 | 8 | from urllib.request import urlretrieve 9 | from urllib.parse import urlparse 10 | 11 | from .url import UrlTask 12 | 13 | 14 | class RetrieveTask(UrlTask): 15 | """ 16 | Task that retrieves urls (local or remote) to a local directory. 17 | """ 18 | 19 | def register(self, directory=None, flatten=True): 20 | """ 21 | Adds retrieve parameters to task.
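# Worked example of the ExportTask filename logic above: os.path.splitext keeps
# the dot with the extension, so the "." join is only applied on the timestamped
# path; the timestamp value is illustrative.
import os

parts = list(os.path.splitext("results.csv"))  # ["results", ".csv"]
parts[-1] = "20240101T000000Z" + parts[-1]     # "20240101T000000Z.csv"
print(".".join(parts))                         # results.20240101T000000Z.csv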
22 | 23 | Args: 24 | directory: local directory used to store retrieved files 25 | flatten: flatten input directory structure, defaults to True 26 | """ 27 | 28 | # pylint: disable=W0201 29 | # Create default temporary directory if not specified 30 | if not directory: 31 | # Save tempdir to prevent content from being deleted until this task is out of scope 32 | # pylint: disable=R1732 33 | self.tempdir = tempfile.TemporaryDirectory() 34 | directory = self.tempdir.name 35 | 36 | # Create output directory if necessary 37 | os.makedirs(directory, exist_ok=True) 38 | 39 | self.directory = directory 40 | self.flatten = flatten 41 | 42 | def prepare(self, element): 43 | # Extract file path from URL 44 | path = urlparse(element).path 45 | 46 | if self.flatten: 47 | # Flatten directory structure (default) 48 | path = os.path.join(self.directory, os.path.basename(path)) 49 | else: 50 | # Derive output path 51 | path = os.path.join(self.directory, os.path.normpath(path.lstrip("/"))) 52 | directory = os.path.dirname(path) 53 | 54 | # Create local directory, if necessary 55 | os.makedirs(directory, exist_ok=True) 56 | 57 | # Retrieve URL 58 | urlretrieve(element, path) 59 | 60 | # Return new file path 61 | return path 62 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/stream.py: -------------------------------------------------------------------------------- 1 | """ 2 | StreamTask module 3 | """ 4 | 5 | from .base import Task 6 | 7 | 8 | class StreamTask(Task): 9 | """ 10 | Task that calls a task action and yields results. 11 | """ 12 | 13 | def register(self, batch=False): 14 | """ 15 | Adds stream parameters to task. 16 | 17 | Args: 18 | batch: all elements are passed to a single action call if True, otherwise an action call is executed per element, defaults to False 19 | """ 20 | 21 | # pylint: disable=W0201 22 | # All elements are passed to a single action call if True, otherwise an action call is executed per element, defaults to False 23 | self.batch = batch 24 | 25 | def __call__(self, elements, executor=None): 26 | for action in self.action: 27 | if self.batch: 28 | # Single batch call 29 | yield from action(elements) 30 | else: 31 | # Call action for each element 32 | for x in elements: 33 | yield from action(x) 34 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/url.py: -------------------------------------------------------------------------------- 1 | """ 2 | UrlTask module 3 | """ 4 | 5 | import re 6 | 7 | from .base import Task 8 | 9 | 10 | class UrlTask(Task): 11 | """ 12 | Task that processes urls 13 | """ 14 | 15 | # URL prefix 16 | PREFIX = r"\w+:\/\/" 17 | 18 | def accept(self, element): 19 | # Only accept elements that start with a url prefix 20 | return super().accept(element) and re.match(UrlTask.PREFIX, element.lower()) 21 | -------------------------------------------------------------------------------- /src/python/txtai/workflow/task/workflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | WorkflowTask module 3 | """ 4 | 5 | from .base import Task 6 | 7 | 8 | class WorkflowTask(Task): 9 | """ 10 | Task that executes a separate Workflow 11 | """ 12 | 13 | def process(self, action, inputs): 14 | return list(super().process(action, inputs)) 15 | -------------------------------------------------------------------------------- /test/python/testapi/__init__.py: 
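# StreamTask sketch, assuming the Task base constructor forwards keyword
# arguments to register and normalizes a single action into a list, consistent
# with the register pattern and the "for action in self.action" loop above.
from txtai.workflow import StreamTask

# batch=False (the default): the action runs once per element
task = StreamTask(action=lambda x: [x.upper()])
print(list(task(["a", "b"])))  # ["A", "B"]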
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testapi/__init__.py -------------------------------------------------------------------------------- /test/python/testapi/testauthorization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authorization module tests 3 | """ 4 | 5 | import hashlib 6 | import os 7 | import tempfile 8 | import unittest 9 | 10 | from unittest.mock import patch 11 | 12 | from fastapi.testclient import TestClient 13 | 14 | from txtai.api import application 15 | 16 | 17 | class TestAuthorization(unittest.TestCase): 18 | """ 19 | API tests for token authorization. 20 | """ 21 | 22 | @staticmethod 23 | @patch.dict( 24 | os.environ, 25 | { 26 | "CONFIG": os.path.join(tempfile.gettempdir(), "testapi.yml"), 27 | "DEPENDENCIES": "txtai.api.Authorization", 28 | "TOKEN": hashlib.sha256("token".encode("utf-8")).hexdigest(), 29 | }, 30 | ) 31 | def start(): 32 | """ 33 | Starts a mock FastAPI client. 34 | """ 35 | 36 | config = os.path.join(tempfile.gettempdir(), "testapi.yml") 37 | 38 | with open(config, "w", encoding="utf-8") as output: 39 | output.write("embeddings:\n") 40 | 41 | # Create new application and set on client 42 | application.app = application.create() 43 | client = TestClient(application.app) 44 | application.start() 45 | 46 | return client 47 | 48 | @classmethod 49 | def setUpClass(cls): 50 | """ 51 | Create API client on creation of class. 52 | """ 53 | 54 | cls.client = TestAuthorization.start() 55 | 56 | def testInvalid(self): 57 | """ 58 | Test invalid authorization 59 | """ 60 | 61 | response = self.client.get("search?query=test") 62 | self.assertEqual(response.status_code, 401) 63 | 64 | response = self.client.get("search?query=test", headers={"Authorization": "Bearer invalid"}) 65 | self.assertEqual(response.status_code, 401) 66 | 67 | def testValid(self): 68 | """ 69 | Test valid authorization 70 | """ 71 | 72 | results = self.client.get("search?query=test", headers={"Authorization": "Bearer token"}).json() 73 | self.assertEqual(results, []) 74 | -------------------------------------------------------------------------------- /test/python/testapi/testmcp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model context protocol (MCP) API module tests 3 | """ 4 | 5 | import os 6 | import tempfile 7 | import unittest 8 | 9 | from unittest.mock import patch 10 | 11 | from fastapi.testclient import TestClient 12 | 13 | from txtai.api import application 14 | 15 | # Configuration for MCP 16 | MCP = """ 17 | mcp: True 18 | """ 19 | 20 | 21 | # pylint: disable=R0904 22 | class TestMCP(unittest.TestCase): 23 | """ 24 | API tests for model context protocol (MCP) 25 | """ 26 | 27 | @staticmethod 28 | @patch.dict(os.environ, {"CONFIG": os.path.join(tempfile.gettempdir(), "testapi.yml"), "API_CLASS": "txtai.api.API"}) 29 | def start(): 30 | """ 31 | Starts a mock FastAPI client. 32 | """ 33 | 34 | config = os.path.join(tempfile.gettempdir(), "testapi.yml") 35 | 36 | with open(config, "w", encoding="utf-8") as output: 37 | output.write(MCP) 38 | 39 | # Create new application and set on client 40 | application.app = application.create() 41 | client = TestClient(application.app) 42 | application.start() 43 | 44 | return client 45 | 46 | @classmethod 47 | def setUpClass(cls): 48 | """ 49 | Create API client on creation of class.
50 | """ 51 | 52 | cls.client = TestMCP.start() 53 | 54 | def testMCP(self): 55 | """ 56 | Test that the application registers a /mcp route 57 | """ 58 | 59 | self.assertTrue(any(route.path == "/mcp" for route in self.client.app.routes)) 60 | -------------------------------------------------------------------------------- /test/python/testdatabase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testdatabase/__init__.py -------------------------------------------------------------------------------- /test/python/testdatabase/testcustom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom database tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.database import DatabaseFactory 8 | 9 | 10 | class TestCustom(unittest.TestCase): 11 | """ 12 | Custom database backend tests. 13 | """ 14 | 15 | def testCustomBackend(self): 16 | """ 17 | Test resolving a custom backend 18 | """ 19 | 20 | database = DatabaseFactory.create({"content": "txtai.database.SQLite"}) 21 | self.assertIsNotNone(database) 22 | 23 | def testCustomBackendNotFound(self): 24 | """ 25 | Test resolving an unresolvable backend 26 | """ 27 | 28 | with self.assertRaises(ImportError): 29 | DatabaseFactory.create({"content": "notfound.database"}) 30 | -------------------------------------------------------------------------------- /test/python/testdatabase/testdatabase.py: -------------------------------------------------------------------------------- 1 | """ 2 | Database tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.database import Database 8 | 9 | 10 | class TestDatabase(unittest.TestCase): 11 | """ 12 | Base database tests. 13 | """ 14 | 15 | def testNotImplemented(self): 16 | """ 17 | Test exceptions for non-implemented methods 18 | """ 19 | 20 | database = Database({}) 21 | 22 | self.assertRaises(NotImplementedError, database.load, None) 23 | self.assertRaises(NotImplementedError, database.insert, None) 24 | self.assertRaises(NotImplementedError, database.delete, None) 25 | self.assertRaises(NotImplementedError, database.reindex, None) 26 | self.assertRaises(NotImplementedError, database.save, None) 27 | self.assertRaises(NotImplementedError, database.close) 28 | self.assertRaises(NotImplementedError, database.ids, None) 29 | self.assertRaises(NotImplementedError, database.count) 30 | self.assertRaises(NotImplementedError, database.resolve, None, None) 31 | self.assertRaises(NotImplementedError, database.embed, None, None) 32 | self.assertRaises(NotImplementedError, database.query, None, None, None, None) 33 | -------------------------------------------------------------------------------- /test/python/testdatabase/testduckdb.py: -------------------------------------------------------------------------------- 1 | """ 2 | DuckDB module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | from txtai.embeddings import Embeddings 9 | 10 | from .testrdbms import Common 11 | 12 | 13 | # pylint: disable=R0904 14 | class TestDuckDB(Common.TestRDBMS): 15 | """ 16 | Embeddings with content stored in DuckDB. 17 | """ 18 | 19 | @classmethod 20 | def setUpClass(cls): 21 | """ 22 | Initialize test data.
23 | """ 24 | 25 | cls.data = [ 26 | "US tops 5 million confirmed virus cases", 27 | "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", 28 | "Beijing mobilises invasion craft along coast as Taiwan tensions escalate", 29 | "The National Park Service warns against sacrificing slower friends in a bear attack", 30 | "Maine man wins $1M from $25 lottery ticket", 31 | "Make huge profits without work, earn up to $100,000 a day", 32 | ] 33 | 34 | # Content backend 35 | cls.backend = "duckdb" 36 | 37 | # Create embeddings model, backed by sentence-transformers & transformers 38 | cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": cls.backend}) 39 | 40 | @classmethod 41 | def tearDownClass(cls): 42 | """ 43 | Cleanup data. 44 | """ 45 | 46 | if cls.embeddings: 47 | cls.embeddings.close() 48 | 49 | @unittest.skipIf(os.name == "nt", "testArchive skipped on Windows") 50 | def testArchive(self): 51 | """ 52 | Test embeddings index archiving 53 | """ 54 | 55 | super().testArchive() 56 | -------------------------------------------------------------------------------- /test/python/testmodels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testmodels/__init__.py -------------------------------------------------------------------------------- /test/python/testmodels/testmodels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from unittest.mock import patch 8 | 9 | import torch 10 | 11 | from txtai.models import Models 12 | 13 | 14 | class TestModels(unittest.TestCase): 15 | """ 16 | Models tests. 
17 | """ 18 | 19 | @patch("torch.cuda.is_available") 20 | def testDeviceid(self, cuda): 21 | """ 22 | Test the deviceid method 23 | """ 24 | 25 | cuda.return_value = True 26 | self.assertEqual(Models.deviceid(True), 0) 27 | self.assertEqual(Models.deviceid(False), -1) 28 | self.assertEqual(Models.deviceid(0), 0) 29 | self.assertEqual(Models.deviceid(1), 1) 30 | 31 | # Test direct torch device 32 | # pylint: disable=E1101 33 | self.assertEqual(Models.deviceid(torch.device("cpu")), torch.device("cpu")) 34 | 35 | cuda.return_value = False 36 | self.assertEqual(Models.deviceid(True), -1) 37 | self.assertEqual(Models.deviceid(False), -1) 38 | self.assertEqual(Models.deviceid(0), -1) 39 | self.assertEqual(Models.deviceid(1), -1) 40 | 41 | def testDevice(self): 42 | """ 43 | Test the device method 44 | """ 45 | 46 | # pylint: disable=E1101 47 | self.assertEqual(Models.device("cpu"), torch.device("cpu")) 48 | self.assertEqual(Models.device(torch.device("cpu")), torch.device("cpu")) 49 | -------------------------------------------------------------------------------- /test/python/testpipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testpipeline/__init__.py -------------------------------------------------------------------------------- /test/python/testpipeline/testaudio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testpipeline/testaudio/__init__.py -------------------------------------------------------------------------------- /test/python/testpipeline/testaudio/testaudiomixer.py: -------------------------------------------------------------------------------- 1 | """ 2 | AudioMixer module tests 3 | """ 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | 9 | from txtai.pipeline import AudioMixer 10 | 11 | 12 | class TestAudioMixer(unittest.TestCase): 13 | """ 14 | AudioMixer tests. 15 | """ 16 | 17 | def testAudioMixer(self): 18 | """ 19 | Test mixing audio streams 20 | """ 21 | 22 | audio1 = np.random.rand(2, 5000), 100 23 | audio2 = np.random.rand(2, 5000), 100 24 | 25 | mixer = AudioMixer() 26 | audio, rate = mixer((audio1, audio2)) 27 | 28 | self.assertEqual(audio.shape, (2, 5000)) 29 | self.assertEqual(rate, 100) 30 | -------------------------------------------------------------------------------- /test/python/testpipeline/testaudio/testaudiostream.py: -------------------------------------------------------------------------------- 1 | """ 2 | AudioStream module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from unittest.mock import patch 8 | 9 | import soundfile as sf 10 | 11 | from txtai.pipeline import AudioStream 12 | 13 | # pylint: disable = C0411 14 | from utils import Utils 15 | 16 | 17 | class TestAudioStream(unittest.TestCase): 18 | """ 19 | AudioStream tests.
20 | """ 21 | 22 | @patch("sounddevice.play") 23 | def testAudioStream(self, play): 24 | """ 25 | Test playing audio 26 | """ 27 | 28 | play.return_value = True 29 | 30 | # Read audio data 31 | audio, rate = sf.read(Utils.PATH + "/Make_huge_profits.wav") 32 | 33 | stream = AudioStream() 34 | self.assertIsNotNone(stream([(audio, rate), AudioStream.COMPLETE])) 35 | 36 | # Wait for completion 37 | stream.wait() 38 | -------------------------------------------------------------------------------- /test/python/testpipeline/testaudio/testtexttoaudio.py: -------------------------------------------------------------------------------- 1 | """ 2 | TextToAudio module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.pipeline import TextToAudio 8 | 9 | 10 | class TestTextToAudio(unittest.TestCase): 11 | """ 12 | TextToAudio tests. 13 | """ 14 | 15 | def testTextToAudio(self): 16 | """ 17 | Test generating audio for text 18 | """ 19 | 20 | tta = TextToAudio("hf-internal-testing/tiny-random-MusicgenForConditionalGeneration") 21 | 22 | # Check that data is generated 23 | audio, rate = tta("This is a test") 24 | 25 | self.assertGreater(len(audio), 0) 26 | self.assertEqual(rate, 24000) 27 | -------------------------------------------------------------------------------- /test/python/testpipeline/testdata/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testpipeline/testdata/__init__.py -------------------------------------------------------------------------------- /test/python/testpipeline/testdata/testfiletohtml.py: -------------------------------------------------------------------------------- 1 | """ 2 | FileToHTML module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | from unittest.mock import patch 9 | 10 | from txtai.pipeline.data.filetohtml import Tika 11 | 12 | 13 | class TestFileToHTML(unittest.TestCase): 14 | """ 15 | FileToHTML tests. 16 | """ 17 | 18 | @patch.dict(os.environ, {"TIKA_JAVA": "1112444abc"}) 19 | def testTika(self): 20 | """ 21 | Test that Tika.available returns False when Java is not available 22 | """ 23 | 24 | self.assertFalse(Tika.available()) 25 | -------------------------------------------------------------------------------- /test/python/testpipeline/testdata/testtokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tokenizer module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.pipeline import Tokenizer 8 | 9 | 10 | class TestTokenizer(unittest.TestCase): 11 | """ 12 | Tokenizer tests.
13 | """ 14 | 15 | def testAlphanumTokenize(self): 16 | """ 17 | Test alphanumeric tokenization 18 | """ 19 | 20 | # Alphanumeric tokenization through backwards compatible static method 21 | self.assertEqual(Tokenizer.tokenize("Y this is a test!"), ["test"]) 22 | self.assertEqual(Tokenizer.tokenize("abc123 ABC 123"), ["abc123", "abc"]) 23 | 24 | def testEmptyTokenize(self): 25 | """ 26 | Test handling empty and None inputs 27 | """ 28 | 29 | # Test that parser can handle empty or None strings 30 | self.assertEqual(Tokenizer.tokenize(""), []) 31 | self.assertEqual(Tokenizer.tokenize(None), None) 32 | 33 | def testStandardTokenize(self): 34 | """ 35 | Test standard tokenization 36 | """ 37 | 38 | # Default standard tokenizer parameters 39 | tokenizer = Tokenizer() 40 | 41 | # Define token tests 42 | tests = [ 43 | ("Y this is a test!", ["y", "this", "is", "a", "test"]), 44 | ("abc123 ABC 123", ["abc123", "abc", "123"]), 45 | ("Testing hy-phenated words", ["testing", "hy", "phenated", "words"]), 46 | ("111-111-1111", ["111", "111", "1111"]), 47 | ("Test.1234", ["test", "1234"]), 48 | ] 49 | 50 | # Run through tests 51 | for test, result in tests: 52 | # Unicode Text Segmentation per Unicode Annex #29 53 | self.assertEqual(tokenizer(test), result) 54 | -------------------------------------------------------------------------------- /test/python/testpipeline/testimage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testpipeline/testimage/__init__.py -------------------------------------------------------------------------------- /test/python/testpipeline/testimage/testcaption.py: -------------------------------------------------------------------------------- 1 | """ 2 | Caption module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from PIL import Image 8 | 9 | from txtai.pipeline import Caption 10 | 11 | # pylint: disable = C0411 12 | from utils import Utils 13 | 14 | 15 | class TestCaption(unittest.TestCase): 16 | """ 17 | Caption tests. 18 | """ 19 | 20 | def testCaption(self): 21 | """ 22 | Test captions 23 | """ 24 | 25 | caption = Caption() 26 | self.assertEqual(caption(Image.open(Utils.PATH + "/books.jpg")), "a book shelf filled with books and a stack of books") 27 | -------------------------------------------------------------------------------- /test/python/testpipeline/testimage/testimagehash.py: -------------------------------------------------------------------------------- 1 | """ 2 | ImageHash module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from PIL import Image 8 | 9 | from txtai.pipeline import ImageHash 10 | 11 | # pylint: disable = C0411 12 | from utils import Utils 13 | 14 | 15 | class TestImageHash(unittest.TestCase): 16 | """ 17 | ImageHash tests. 
18 | """ 19 | 20 | @classmethod 21 | def setUpClass(cls): 22 | """ 23 | Caches an image to hash 24 | """ 25 | 26 | cls.image = Image.open(Utils.PATH + "/books.jpg") 27 | 28 | def testArray(self): 29 | """ 30 | Test numpy return type 31 | """ 32 | 33 | ihash = ImageHash(strings=False) 34 | self.assertEqual(ihash(self.image).shape, (64,)) 35 | 36 | def testAverage(self): 37 | """ 38 | Test average hash 39 | """ 40 | 41 | ihash = ImageHash("average") 42 | self.assertIn(ihash(self.image), ["0859dd04bfbfbf00", "0859dd04ffbfbf00"]) 43 | 44 | def testColor(self): 45 | """ 46 | Test color hash 47 | """ 48 | 49 | ihash = ImageHash("color") 50 | self.assertIn(ihash(self.image), ["1ffffe02000e000c0e0000070000", "1ff8fe03000e00070e0000070000"]) 51 | 52 | def testDifference(self): 53 | """ 54 | Test difference hash 55 | """ 56 | 57 | ihash = ImageHash("difference") 58 | self.assertEqual(ihash(self.image), "d291996d6969686a") 59 | 60 | def testPerceptual(self): 61 | """ 62 | Test perceptual hash 63 | """ 64 | 65 | ihash = ImageHash("perceptual") 66 | self.assertEqual(ihash(self.image), "8be8418577b331b9") 67 | 68 | def testWavelet(self): 69 | """ 70 | Test wavelet hash 71 | """ 72 | 73 | ihash = ImageHash("wavelet") 74 | self.assertEqual(ihash(Utils.PATH + "/books.jpg"), "68015d85bfbf3f00") 75 | -------------------------------------------------------------------------------- /test/python/testpipeline/testimage/testobjects.py: -------------------------------------------------------------------------------- 1 | """ 2 | Objects module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.pipeline import Objects 8 | 9 | # pylint: disable = C0411 10 | from utils import Utils 11 | 12 | 13 | class TestObjects(unittest.TestCase): 14 | """ 15 | Object detection tests. 16 | """ 17 | 18 | def testClassification(self): 19 | """ 20 | Test object detection using an image classification model 21 | """ 22 | 23 | objects = Objects(classification=True, threshold=0.3) 24 | self.assertEqual(objects(Utils.PATH + "/books.jpg")[0][0], "library") 25 | 26 | def testDetection(self): 27 | """ 28 | Test object detection using an object detection model 29 | """ 30 | 31 | objects = Objects() 32 | self.assertEqual(objects(Utils.PATH + "/books.jpg")[0][0], "book") 33 | 34 | def testFlatten(self): 35 | """ 36 | Test object detection using an object detection model, flatten to return only objects 37 | """ 38 | 39 | objects = Objects() 40 | self.assertEqual(objects(Utils.PATH + "/books.jpg", flatten=True)[0], "book") 41 | -------------------------------------------------------------------------------- /test/python/testpipeline/testllm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testpipeline/testllm/__init__.py -------------------------------------------------------------------------------- /test/python/testpipeline/testllm/testgenerator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generator module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.pipeline import Generator 8 | 9 | 10 | class TestGenerator(unittest.TestCase): 11 | """ 12 | Generator tests.
13 | """ 14 | 15 | def testGeneration(self): 16 | """ 17 | Test text pipeline generation 18 | """ 19 | 20 | model = Generator("hf-internal-testing/tiny-random-gpt2") 21 | start = "Hello, how are" 22 | 23 | # Test that text is generated 24 | self.assertIsNotNone(model(start)) 25 | -------------------------------------------------------------------------------- /test/python/testpipeline/testllm/testsequences.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sequences module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.pipeline import Sequences 8 | 9 | 10 | class TestSequences(unittest.TestCase): 11 | """ 12 | Sequences tests. 13 | """ 14 | 15 | def testGeneration(self): 16 | """ 17 | Test text2text pipeline generation 18 | """ 19 | 20 | model = Sequences("t5-small") 21 | self.assertEqual(model("Testing the model", prefix="translate English to German: "), "Das Modell zu testen") 22 | -------------------------------------------------------------------------------- /test/python/testpipeline/testtext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testpipeline/testtext/__init__.py -------------------------------------------------------------------------------- /test/python/testpipeline/testtext/testentity.py: -------------------------------------------------------------------------------- 1 | """ 2 | Entity module tests 3 | """ 4 | 5 | import unittest 6 | 7 | from txtai.pipeline import Entity 8 | 9 | 10 | class TestEntity(unittest.TestCase): 11 | """ 12 | Entity tests. 13 | """ 14 | 15 | @classmethod 16 | def setUpClass(cls): 17 | """ 18 | Create entity instance. 
19 | """ 20 | 21 | cls.entity = Entity("dslim/bert-base-NER") 22 | 23 | def testEntity(self): 24 | """ 25 | Test entity 26 | """ 27 | 28 | # Run entity extraction 29 | entities = self.entity("Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg") 30 | self.assertEqual([e[0] for e in entities], ["Canada", "Manhattan"]) 31 | 32 | def testEntityFlatten(self): 33 | """ 34 | Test entity with flattened output 35 | """ 36 | 37 | # Test flatten 38 | entities = self.entity("Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", flatten=True) 39 | self.assertEqual(entities, ["Canada", "Manhattan"]) 40 | 41 | # Test flatten with join 42 | entities = self.entity( 43 | "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", flatten=True, join=True 44 | ) 45 | self.assertEqual(entities, "Canada Manhattan") 46 | 47 | def testEntityTypes(self): 48 | """ 49 | Test entity type filtering 50 | """ 51 | 52 | # Run entity extraction 53 | entities = self.entity("Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", labels=["PER"]) 54 | self.assertFalse(entities) 55 | 56 | def testGliner(self): 57 | """ 58 | Test entity pipeline with a GLiNER model 59 | """ 60 | 61 | entity = Entity("neuml/gliner-bert-tiny") 62 | entities = entity("My name is John Smith.", flatten=True) 63 | self.assertEqual(entities, ["John Smith"]) 64 | -------------------------------------------------------------------------------- /test/python/testpipeline/testtrain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testpipeline/testtrain/__init__.py -------------------------------------------------------------------------------- /test/python/testpipeline/testtrain/testquantization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Quantization module tests 3 | """ 4 | 5 | import platform 6 | import unittest 7 | 8 | from transformers import AutoModel 9 | 10 | from txtai.pipeline import HFModel, HFPipeline 11 | 12 | 13 | class TestQuantization(unittest.TestCase): 14 | """ 15 | Quantization tests. 16 | """ 17 | 18 | @unittest.skipIf(platform.system() == "Darwin", "Quantized models not supported on macOS") 19 | def testModel(self): 20 | """ 21 | Test quantizing a model through HFModel. 22 | """ 23 | 24 | model = HFModel(quantize=True, gpu=False) 25 | model = model.prepare(AutoModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2")) 26 | self.assertIsNotNone(model) 27 | 28 | @unittest.skipIf(platform.system() == "Darwin", "Quantized models not supported on macOS") 29 | def testPipeline(self): 30 | """ 31 | Test quantizing a model through HFPipeline. 32 | """ 33 | 34 | pipeline = HFPipeline("text-classification", "google/bert_uncased_L-2_H-128_A-2", True, False) 35 | self.assertIsNotNone(pipeline) 36 | -------------------------------------------------------------------------------- /test/python/testserialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Serialize module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | from unittest.mock import patch 9 | 10 | from txtai.serialize import Serialize, SerializeFactory 11 | 12 | 13 | class TestSerialize(unittest.TestCase): 14 | """ 15 | Serialize tests. 
16 | """ 17 | 18 | def testNotImplemented(self): 19 | """ 20 | Test exceptions for non-implemented methods 21 | """ 22 | 23 | serialize = Serialize() 24 | 25 | self.assertRaises(NotImplementedError, serialize.loadstream, None) 26 | self.assertRaises(NotImplementedError, serialize.savestream, None, None) 27 | self.assertRaises(NotImplementedError, serialize.loadbytes, None) 28 | self.assertRaises(NotImplementedError, serialize.savebytes, None) 29 | 30 | def testMessagePack(self): 31 | """ 32 | Test MessagePack encoder 33 | """ 34 | 35 | serializer = SerializeFactory.create() 36 | self.assertEqual(serializer.loadbytes(serializer.savebytes("test")), "test") 37 | 38 | @patch.dict(os.environ, {"ALLOW_PICKLE": "False"}) 39 | def testPickleDisabled(self): 40 | """ 41 | Test disabled pickle serialization 42 | """ 43 | 44 | # Validate an error is raised 45 | with self.assertRaises(ValueError): 46 | serializer = SerializeFactory.create("pickle") 47 | data = serializer.savebytes("Test") 48 | serializer.loadbytes(data) 49 | -------------------------------------------------------------------------------- /test/python/testvectors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuml/txtai/7fffd5e435233d59f4aa4e9d673d2add845081db/test/python/testvectors/__init__.py -------------------------------------------------------------------------------- /test/python/testvectors/testcustom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | import numpy as np 9 | 10 | from txtai.vectors import VectorsFactory 11 | 12 | 13 | class TestCustom(unittest.TestCase): 14 | """ 15 | Custom vectors tests 16 | """ 17 | 18 | @classmethod 19 | def setUpClass(cls): 20 | """ 21 | Create custom vectors instance. 22 | """ 23 | 24 | cls.model = VectorsFactory.create({"method": "txtai.vectors.HFVectors", "path": "sentence-transformers/nli-mpnet-base-v2"}, None) 25 | 26 | def testIndex(self): 27 | """ 28 | Test transformers indexing 29 | """ 30 | 31 | # Generate enough volume to test batching 32 | documents = [(x, "This is a test", None) for x in range(1000)] 33 | 34 | ids, dimension, batches, stream = self.model.index(documents) 35 | 36 | self.assertEqual(len(ids), 1000) 37 | self.assertEqual(dimension, 768) 38 | self.assertEqual(batches, 2) 39 | self.assertIsNotNone(os.path.exists(stream)) 40 | 41 | # Test shape of serialized embeddings 42 | with open(stream, "rb") as queue: 43 | self.assertEqual(np.load(queue).shape, (500, 768)) 44 | 45 | def testNotFound(self): 46 | """ 47 | Test unresolvable vector backend 48 | """ 49 | 50 | with self.assertRaises(ImportError): 51 | VectorsFactory.create({"method": "notfound.vectors"}) 52 | -------------------------------------------------------------------------------- /test/python/testvectors/testexternal.py: -------------------------------------------------------------------------------- 1 | """ 2 | External module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | import numpy as np 9 | 10 | from txtai.vectors import External, VectorsFactory 11 | 12 | 13 | class TestExternal(unittest.TestCase): 14 | """ 15 | External vectors tests 16 | """ 17 | 18 | @classmethod 19 | def setUpClass(cls): 20 | """ 21 | Create External vectors instance. 
22 | """ 23 | 24 | cls.model = VectorsFactory.create({"method": "external"}, None) 25 | 26 | def testIndex(self): 27 | """ 28 | Test indexing with external vectors 29 | """ 30 | 31 | # Generate dummy data 32 | data = np.random.rand(1000, 768).astype(np.float32) 33 | 34 | # Generate enough volume to test batching 35 | documents = [(x, data[x], None) for x in range(1000)] 36 | 37 | ids, dimension, batches, stream = self.model.index(documents) 38 | 39 | self.assertEqual(len(ids), 1000) 40 | self.assertEqual(dimension, 768) 41 | self.assertEqual(batches, 2) 42 | self.assertIsNotNone(os.path.exists(stream)) 43 | 44 | # Test shape of serialized embeddings 45 | with open(stream, "rb") as queue: 46 | self.assertEqual(np.load(queue).shape, (500, 768)) 47 | 48 | def testMethod(self): 49 | """ 50 | Test method is derived when transform function passed 51 | """ 52 | 53 | model = VectorsFactory.create({"transform": lambda x: x}, None) 54 | self.assertTrue(isinstance(model, External)) 55 | -------------------------------------------------------------------------------- /test/python/testvectors/testllama.py: -------------------------------------------------------------------------------- 1 | """ 2 | Llama module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | import numpy as np 9 | 10 | from txtai.vectors import VectorsFactory 11 | 12 | 13 | class TestLlamaCpp(unittest.TestCase): 14 | """ 15 | llama.cpp vectors tests 16 | """ 17 | 18 | @classmethod 19 | def setUpClass(cls): 20 | """ 21 | Create LlamaCpp instance. 22 | """ 23 | 24 | cls.model = VectorsFactory.create({"path": "nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.Q2_K.gguf"}, None) 25 | 26 | def testIndex(self): 27 | """ 28 | Test indexing with LlamaCpp vectors 29 | """ 30 | 31 | ids, dimension, batches, stream = self.model.index([(0, "test", None)]) 32 | 33 | self.assertEqual(len(ids), 1) 34 | self.assertEqual(dimension, 768) 35 | self.assertEqual(batches, 1) 36 | self.assertIsNotNone(os.path.exists(stream)) 37 | 38 | # Test shape of serialized embeddings 39 | with open(stream, "rb") as queue: 40 | self.assertEqual(np.load(queue).shape, (1, 768)) 41 | -------------------------------------------------------------------------------- /test/python/testvectors/testm2v.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model2Vec module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | import numpy as np 9 | 10 | from txtai.vectors import VectorsFactory 11 | 12 | 13 | class TestModel2Vec(unittest.TestCase): 14 | """ 15 | Model2vec vectors tests 16 | """ 17 | 18 | @classmethod 19 | def setUpClass(cls): 20 | """ 21 | Create Model2Vec instance. 
22 | """ 23 | 24 | cls.model = VectorsFactory.create({"path": "minishlab/potion-base-8M"}, None) 25 | 26 | def testIndex(self): 27 | """ 28 | Test indexing with Model2Vec vectors 29 | """ 30 | 31 | ids, dimension, batches, stream = self.model.index([(0, "test", None)]) 32 | 33 | self.assertEqual(len(ids), 1) 34 | self.assertEqual(dimension, 256) 35 | self.assertEqual(batches, 1) 36 | self.assertIsNotNone(os.path.exists(stream)) 37 | 38 | # Test shape of serialized embeddings 39 | with open(stream, "rb") as queue: 40 | self.assertEqual(np.load(queue).shape, (1, 256)) 41 | -------------------------------------------------------------------------------- /test/python/testvectors/testsbert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sentence Transformers module tests 3 | """ 4 | 5 | import os 6 | import unittest 7 | 8 | from unittest.mock import patch 9 | 10 | import numpy as np 11 | 12 | from txtai.vectors import VectorsFactory 13 | 14 | 15 | class TestSTVectors(unittest.TestCase): 16 | """ 17 | STVectors tests 18 | """ 19 | 20 | def testIndex(self): 21 | """ 22 | Test indexing with sentence-transformers vectors 23 | """ 24 | 25 | model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2"}, None) 26 | ids, dimension, batches, stream = model.index([(0, "test", None)]) 27 | 28 | self.assertEqual(len(ids), 1) 29 | self.assertEqual(dimension, 384) 30 | self.assertEqual(batches, 1) 31 | self.assertIsNotNone(os.path.exists(stream)) 32 | 33 | # Test shape of serialized embeddings 34 | with open(stream, "rb") as queue: 35 | self.assertEqual(np.load(queue).shape, (1, 384)) 36 | 37 | @patch("torch.cuda.device_count") 38 | def testMultiGPU(self, count): 39 | """ 40 | Test multiple gpu encoding 41 | """ 42 | 43 | # Mock accelerator count 44 | count.return_value = 2 45 | 46 | model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2", "gpu": "all"}, None) 47 | ids, dimension, batches, stream = model.index([(0, "test", None)]) 48 | 49 | self.assertEqual(len(ids), 1) 50 | self.assertEqual(dimension, 384) 51 | self.assertEqual(batches, 1) 52 | self.assertIsNotNone(os.path.exists(stream)) 53 | 54 | # Test shape of serialized embeddings 55 | with open(stream, "rb") as queue: 56 | self.assertEqual(np.load(queue).shape, (1, 384)) 57 | 58 | # Close the multiprocessing pool 59 | model.close() 60 | -------------------------------------------------------------------------------- /test/python/testvectors/testvectors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vectors module tests 3 | """ 4 | 5 | import os 6 | import tempfile 7 | import unittest 8 | 9 | import numpy as np 10 | 11 | from txtai.vectors import Vectors, Recovery 12 | 13 | 14 | class TestVectors(unittest.TestCase): 15 | """ 16 | Vectors tests. 
17 | """ 18 | 19 | def testNotImplemented(self): 20 | """ 21 | Test exceptions for non-implemented methods 22 | """ 23 | 24 | vectors = Vectors(None, None, None) 25 | 26 | self.assertRaises(NotImplementedError, vectors.load, None) 27 | self.assertRaises(NotImplementedError, vectors.encode, None) 28 | 29 | def testNormalize(self): 30 | """ 31 | Test batch normalize and single input normalize are equal 32 | """ 33 | 34 | vectors = Vectors(None, None, None) 35 | 36 | # Generate data 37 | data1 = np.random.rand(5, 5).astype(np.float32) 38 | data2 = data1.copy() 39 | 40 | # Keep original data to ensure it changed 41 | original = data1.copy() 42 | 43 | # Normalize data 44 | vectors.normalize(data1) 45 | for x in data2: 46 | vectors.normalize(x) 47 | 48 | # Test both data arrays are the same and changed from original 49 | self.assertTrue(np.allclose(data1, data2)) 50 | self.assertFalse(np.allclose(data1, original)) 51 | 52 | def testRecovery(self): 53 | """ 54 | Test vectors recovery failure 55 | """ 56 | 57 | # Checkpoint directory 58 | checkpoint = os.path.join(tempfile.gettempdir(), "recovery") 59 | os.makedirs(checkpoint, exist_ok=True) 60 | 61 | # Create empty file 62 | # pylint: disable=R1732 63 | f = open(os.path.join(checkpoint, "id"), "w", encoding="utf-8") 64 | f.close() 65 | 66 | # Create the recovery instance with an empty checkpoint file 67 | recovery = Recovery(checkpoint, "id") 68 | self.assertIsNone(recovery()) 69 | -------------------------------------------------------------------------------- /test/python/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils module 3 | """ 4 | 5 | 6 | class Utils: 7 | """ 8 | Utility constants and methods 9 | """ 10 | 11 | PATH = "/tmp/txtai" 12 | --------------------------------------------------------------------------------