├── version.txt
├── local_data
│   └── .gitignore
├── models
│   └── .gitignore
├── private_gpt
│   ├── components
│   │   ├── __init__.py
│   │   ├── ingest
│   │   │   ├── __init__.py
│   │   │   └── ingest_helper.py
│   │   ├── embedding
│   │   │   ├── __init__.py
│   │   │   ├── custom
│   │   │   │   ├── __init__.py
│   │   │   │   └── sagemaker.py
│   │   │   └── embedding_component.py
│   │   ├── llm
│   │   │   ├── custom
│   │   │   │   ├── __init__.py
│   │   │   │   └── sagemaker.py
│   │   │   ├── __init__.py
│   │   │   ├── llm_component.py
│   │   │   └── prompt_helper.py
│   │   ├── node_store
│   │   │   ├── __init__.py
│   │   │   └── node_store_component.py
│   │   └── vector_store
│   │       ├── __init__.py
│   │       ├── batched_chroma.py
│   │       └── vector_store_component.py
│   ├── server
│   │   ├── chat
│   │   │   ├── __init__.py
│   │   │   ├── chat_router.py
│   │   │   └── chat_service.py
│   │   ├── chunks
│   │   │   ├── __init__.py
│   │   │   ├── chunks_router.py
│   │   │   └── chunks_service.py
│   │   ├── health
│   │   │   ├── __init__.py
│   │   │   └── health_router.py
│   │   ├── ingest
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── ingest_watcher.py
│   │   │   ├── ingest_router.py
│   │   │   └── ingest_service.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── auth.py
│   │   ├── embeddings
│   │   │   ├── __init__.py
│   │   │   ├── embeddings_service.py
│   │   │   └── embeddings_router.py
│   │   ├── __init__.py
│   │   └── completions
│   │       ├── __init__.py
│   │       └── completions_router.py
│   ├── settings
│   │   ├── __init__.py
│   │   ├── yaml.py
│   │   ├── settings_loader.py
│   │   └── settings.py
│   ├── ui
│   │   ├── __init__.py
│   │   ├── avatar-bot.ico
│   │   └── images.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── typing.py
│   ├── open_ai
│   │   ├── __init__.py
│   │   ├── extensions
│   │   │   ├── __init__.py
│   │   │   └── context_filter.py
│   │   └── openai_models.py
│   ├── constants.py
│   ├── main.py
│   ├── __main__.py
│   ├── paths.py
│   ├── di.py
│   ├── __init__.py
│   └── launcher.py
├── tests
│   ├── __init__.py
│   ├── fixtures
│   │   ├── __init__.py
│   │   ├── fast_api_test_client.py
│   │   ├── auto_close_qdrant.py
│   │   ├── ingest_helper.py
│   │   └── mock_injector.py
│   ├── server
│   │   ├── ingest
│   │   │   ├── test.pdf
│   │   │   ├── test_ingest_routes.py
│   │   │   └── test.txt
│   │   ├── chunks
│   │   │   ├── chunk_test.txt
│   │   │   └── test_chunk_routes.py
│   │   ├── utils
│   │   │   ├── test_auth.py
│   │   │   └── test_simple_auth.py
│   │   ├── embeddings
│   │   │   └── test_embedding_routes.py
│   │   └── chat
│   │       └── test_chat_routes.py
│   ├── ui
│   │   └── test_ui.py
│   ├── conftest.py
│   ├── settings
│   │   ├── test_settings.py
│   │   └── test_settings_loader.py
│   └── test_prompt_helper.py
├── scripts
│   ├── __init__.py
│   ├── utils.py
│   ├── extract_openapi.py
│   ├── setup
│   └── ingest_folder.py
├── fern
│   ├── fern.config.json
│   ├── docs
│   │   ├── assets
│   │   │   ├── ui.png
│   │   │   ├── favicon.ico
│   │   │   ├── header.jpeg
│   │   │   ├── logo_dark.png
│   │   │   └── logo_light.png
│   │   └── pages
│   │       ├── manual
│   │       │   ├── ingestion-reset.mdx
│   │       │   ├── vectordb.mdx
│   │       │   ├── ui.mdx
│   │       │   ├── settings.mdx
│   │       │   ├── llms.mdx
│   │       │   └── ingestion.mdx
│   │       ├── api-reference
│   │       │   ├── api-reference.mdx
│   │       │   └── sdks.mdx
│   │       ├── overview
│   │       │   ├── quickstart.mdx
│   │       │   └── welcome.mdx
│   │       └── recipes
│   │           └── list-llm.mdx
│   ├── generators.yml
│   ├── README.md
│   └── docs.yml
├── settings-local.yaml
├── .dockerignore
├── settings-mock.yaml
├── settings-vllm.yaml
├── settings-test.yaml
├── settings-sagemaker.yaml
├── docker-compose.yaml
├── .github
│   └── workflows
│       ├── release-please.yml
│       ├── fern-check.yml
│       ├── publish-docs.yml
│       ├── actions
│       │   └── install_dependencies
│       │       └── action.yml
│       ├── stale.yml
│       ├── docker.yml
│       ├── preview-docs.yml
│       └── tests.yml
├── .gitignore
├── settings-docker.yaml
├── CITATION.cff
├── Dockerfile.external
├── .pre-commit-config.yaml
├── Dockerfile.local
├── settings.yaml
├── Makefile
├── CHANGELOG.md
├── pyproject.toml
└── README.md
/version.txt:
--------------------------------------------------------------------------------
1 | 0.2.0
2 |
--------------------------------------------------------------------------------
/local_data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/private_gpt/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/server/chat/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/server/chunks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/server/health/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/server/ingest/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/server/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests."""
2 |
--------------------------------------------------------------------------------
/private_gpt/components/ingest/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/server/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/components/embedding/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/components/llm/custom/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/components/node_store/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/components/vector_store/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/components/embedding/custom/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/private_gpt/settings/__init__.py:
--------------------------------------------------------------------------------
1 | """Settings."""
2 |
--------------------------------------------------------------------------------
/private_gpt/ui/__init__.py:
--------------------------------------------------------------------------------
1 | """Gradio based UI."""
2 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | """PrivateGPT scripts."""
2 |
--------------------------------------------------------------------------------
/tests/fixtures/__init__.py:
--------------------------------------------------------------------------------
1 | """Global fixtures."""
2 |
--------------------------------------------------------------------------------
/private_gpt/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """general utils."""
2 |
--------------------------------------------------------------------------------
/private_gpt/server/__init__.py:
--------------------------------------------------------------------------------
1 | """private-gpt server."""
2 |
--------------------------------------------------------------------------------
/private_gpt/components/llm/__init__.py:
--------------------------------------------------------------------------------
1 | """LLM implementations."""
2 |
--------------------------------------------------------------------------------
/private_gpt/open_ai/__init__.py:
--------------------------------------------------------------------------------
1 | """OpenAI compatibility utilities."""
2 |
--------------------------------------------------------------------------------
/private_gpt/open_ai/extensions/__init__.py:
--------------------------------------------------------------------------------
1 | """OpenAI API extensions."""
2 |
--------------------------------------------------------------------------------
/fern/fern.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "organization": "privategpt",
3 | "version": "0.15.3"
4 | }
--------------------------------------------------------------------------------
/fern/docs/assets/ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apm/privateGPT/main/fern/docs/assets/ui.png
--------------------------------------------------------------------------------
/private_gpt/server/completions/__init__.py:
--------------------------------------------------------------------------------
1 | """Deprecated Openai compatibility endpoint."""
2 |
--------------------------------------------------------------------------------
/settings-local.yaml:
--------------------------------------------------------------------------------
1 | server:
2 | env_name: ${APP_ENV:local}
3 |
4 | llm:
5 | mode: local
6 |
--------------------------------------------------------------------------------
/fern/docs/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apm/privateGPT/main/fern/docs/assets/favicon.ico
--------------------------------------------------------------------------------
/fern/docs/assets/header.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apm/privateGPT/main/fern/docs/assets/header.jpeg
--------------------------------------------------------------------------------
/private_gpt/ui/avatar-bot.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apm/privateGPT/main/private_gpt/ui/avatar-bot.ico
--------------------------------------------------------------------------------
/tests/server/ingest/test.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apm/privateGPT/main/tests/server/ingest/test.pdf
--------------------------------------------------------------------------------
/fern/docs/assets/logo_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apm/privateGPT/main/fern/docs/assets/logo_dark.png
--------------------------------------------------------------------------------
/fern/docs/assets/logo_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apm/privateGPT/main/fern/docs/assets/logo_light.png
--------------------------------------------------------------------------------
/private_gpt/constants.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
4 |
--------------------------------------------------------------------------------
/private_gpt/utils/typing.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | T = TypeVar("T")
4 | K = TypeVar("K")
5 | V = TypeVar("V")
6 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv
2 | models
3 | .github
4 | .vscode
5 | .DS_Store
6 | .mypy_cache
7 | .ruff_cache
8 | local_data
9 | terraform
10 | tests
11 | Dockerfile
12 | Dockerfile.*
--------------------------------------------------------------------------------
/tests/server/chunks/chunk_test.txt:
--------------------------------------------------------------------------------
1 | e88c1005-637d-4cb4-ae79-9b8eb58cab97
2 |
3 | b483dd15-78c4-4d67-b546-21a0d690bf43
4 |
5 | a8080238-b294-4598-ac9c-7abf4c8e0552
6 |
7 | 14208dac-c600-4a18-872b-5e45354cfff2
--------------------------------------------------------------------------------
/fern/generators.yml:
--------------------------------------------------------------------------------
1 | groups:
2 | public:
3 | generators:
4 | - name: fernapi/fern-python-sdk
5 | version: 0.6.2
6 | output:
7 | location: local-file-system
8 | path: ../../pgpt-sdk/python
9 |
--------------------------------------------------------------------------------
/settings-mock.yaml:
--------------------------------------------------------------------------------
1 | server:
2 | env_name: ${APP_ENV:mock}
3 |
4 | # This configuration allows you to use the GPU for creating embeddings while avoiding loading the LLM into vRAM
5 | llm:
6 | mode: mock
7 | embedding:
8 | mode: local
9 |
--------------------------------------------------------------------------------
/private_gpt/open_ai/extensions/context_filter.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 |
3 |
4 | class ContextFilter(BaseModel):
5 | docs_ids: list[str] | None = Field(
6 | examples=[["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]]
7 | )
8 |
--------------------------------------------------------------------------------
/tests/server/utils/test_auth.py:
--------------------------------------------------------------------------------
1 | from fastapi.testclient import TestClient
2 |
3 |
4 | def test_default_does_not_require_auth(test_client: TestClient) -> None:
5 | response_before = test_client.get("/v1/ingest/list")
6 | assert response_before.status_code == 200
7 |
--------------------------------------------------------------------------------
/settings-vllm.yaml:
--------------------------------------------------------------------------------
1 | llm:
2 | mode: openailike
3 |
4 | embedding:
5 | mode: local
6 | ingest_mode: simple
7 |
8 | local:
9 | embedding_hf_model_name: BAAI/bge-small-en-v1.5
10 |
11 | openai:
12 | api_base: http://localhost:8000/v1
13 | api_key: EMPTY
14 | model: facebook/opt-125m
15 |
--------------------------------------------------------------------------------
/settings-test.yaml:
--------------------------------------------------------------------------------
1 | server:
2 | env_name: test
3 | auth:
4 | enabled: false
5 | # Dummy secrets used for tests
6 | secret: "foo bar; dummy secret"
7 |
8 | data:
9 | local_data_folder: local_data/tests
10 |
11 | qdrant:
12 | path: local_data/tests
13 |
14 | llm:
15 | mode: mock
16 |
17 | ui:
18 | enabled: false
--------------------------------------------------------------------------------
/private_gpt/main.py:
--------------------------------------------------------------------------------
1 | """FastAPI app creation, logger configuration and main API routes."""
2 |
3 | import llama_index
4 |
5 | from private_gpt.di import global_injector
6 | from private_gpt.launcher import create_app
7 |
8 | # Add LlamaIndex simple observability
9 | llama_index.set_global_handler("simple")
10 |
11 | app = create_app(global_injector)
12 |
--------------------------------------------------------------------------------
/settings-sagemaker.yaml:
--------------------------------------------------------------------------------
1 | server:
2 | env_name: ${APP_ENV:prod}
3 | port: ${PORT:8001}
4 |
5 | ui:
6 | enabled: true
7 | path: /
8 |
9 | llm:
10 | mode: sagemaker
11 |
12 | sagemaker:
13 | llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
14 | embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | private-gpt:
3 | build:
4 | dockerfile: Dockerfile.local
5 | volumes:
6 | - ./local_data/:/home/worker/app/local_data
7 | - ./models/:/home/worker/app/models
8 | ports:
9 | - 8001:8080
10 | environment:
11 | PORT: 8080
12 | PGPT_PROFILES: docker
13 | PGPT_MODE: local
14 |
15 |
--------------------------------------------------------------------------------
/tests/ui/test_ui.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from fastapi.testclient import TestClient
3 |
4 |
5 | @pytest.mark.parametrize(
6 | "test_client", [{"ui": {"enabled": True, "path": "/ui"}}], indirect=True
7 | )
8 | def test_ui_starts_in_the_given_endpoint(test_client: TestClient) -> None:
9 | response = test_client.get("/ui")
10 | assert response.status_code == 200
11 |
--------------------------------------------------------------------------------
/.github/workflows/release-please.yml:
--------------------------------------------------------------------------------
1 | name: release-please
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | permissions:
9 | contents: write
10 | pull-requests: write
11 |
12 | jobs:
13 | release-please:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: google-github-actions/release-please-action@v3
17 | with:
18 | release-type: simple
19 | version-file: version.txt
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | .env
3 | venv
4 |
5 | settings-me.yaml
6 |
7 | .ruff_cache
8 | .pytest_cache
9 | .mypy_cache
10 |
11 | # byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 |
15 | # unit tests / coverage reports
16 | /tests-results.xml
17 | /.coverage
18 | /coverage.xml
19 | /htmlcov/
20 |
21 | # pyenv
22 | /.python-version
23 |
24 | # IDE
25 | .idea/
26 | .vscode/
27 | /.run/
28 | .fleet/
29 |
30 | # macOS
31 | .DS_Store
32 |
--------------------------------------------------------------------------------
/.github/workflows/fern-check.yml:
--------------------------------------------------------------------------------
1 | name: fern check
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | paths:
8 | - "fern/**"
9 |
10 | jobs:
11 | fern-check:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - name: Checkout repo
15 | uses: actions/checkout@v4
16 |
17 | - name: Install Fern
18 | run: npm install -g fern-api
19 |
20 | - name: Check Fern API is valid
21 | run: fern check
--------------------------------------------------------------------------------
/fern/docs/pages/manual/ingestion-reset.mdx:
--------------------------------------------------------------------------------
1 | # Reset Local documents database
2 |
3 | When running in a local setup, you can remove all ingested documents by simply
4 | deleting all the contents of the `local_data` folder (except `.gitignore`).
5 |
6 | To simplify this process, you can use the command:
7 | ```bash
8 | make wipe
9 | ```
10 |
11 | # Advanced usage
12 |
13 | You can also delete individual documents from storage by using the
14 | `DELETE` endpoint of the Ingestion API.
--------------------------------------------------------------------------------
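As a complement to `make wipe`, the same cleanup can be driven through the HTTP API. The sketch below is illustrative only: it assumes a PrivateGPT server listening on localhost:8001, uses the `GET /v1/ingest/list` route exercised by the test suite, and assumes the per-document delete route has the form `DELETE /v1/ingest/{doc_id}`; check the Ingestion API reference for the exact paths.

```python
import requests

BASE_URL = "http://localhost:8001"  # assumed local server address

# List every ingested document (route also used by tests/server/utils/test_auth.py).
docs = requests.get(f"{BASE_URL}/v1/ingest/list").json()["data"]

# Delete them one by one through the (assumed) per-document DELETE route.
for doc in docs:
    requests.delete(f"{BASE_URL}/v1/ingest/{doc['doc_id']}")
    print(f"Deleted {doc['doc_id']}")
```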
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from glob import glob
4 |
5 | root_path = pathlib.Path(__file__).parents[1]
6 | # This is to prevent a bug in intellij that uses the wrong working directory
7 | os.chdir(root_path)
8 |
9 |
10 | def _as_module(fixture_path: str) -> str:
11 | return fixture_path.replace("/", ".").replace("\\", ".").replace(".py", "")
12 |
13 |
14 | pytest_plugins = [_as_module(fixture) for fixture in glob("tests/fixtures/[!_]*.py")]
15 |
--------------------------------------------------------------------------------
/private_gpt/__main__.py:
--------------------------------------------------------------------------------
1 | # start a fastapi server with uvicorn
2 |
3 | import uvicorn
4 |
5 | from private_gpt.main import app
6 | from private_gpt.settings.settings import settings
7 |
8 | # Set log_config=None so uvicorn does not use its own logging configuration, and
9 | # ours is used instead. For reference, see below:
10 | # https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
11 | uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)
12 |
--------------------------------------------------------------------------------
/tests/settings/test_settings.py:
--------------------------------------------------------------------------------
1 | from private_gpt.settings.settings import Settings, settings
2 | from tests.fixtures.mock_injector import MockInjector
3 |
4 |
5 | def test_settings_are_loaded_and_merged() -> None:
6 | assert settings().server.env_name == "test"
7 |
8 |
9 | def test_settings_can_be_overriden(injector: MockInjector) -> None:
10 | injector.bind_settings({"server": {"env_name": "overriden"}})
11 | mocked_settings = injector.get(Settings)
12 | assert mocked_settings.server.env_name == "overriden"
13 |
--------------------------------------------------------------------------------
/private_gpt/server/health/health_router.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from fastapi import APIRouter
4 | from pydantic import BaseModel, Field
5 |
6 | # No authentication or authorization is required to get the health status.
7 | health_router = APIRouter()
8 |
9 |
10 | class HealthResponse(BaseModel):
11 | status: Literal["ok"] = Field(default="ok")
12 |
13 |
14 | @health_router.get("/health", tags=["Health"])
15 | def health() -> HealthResponse:
16 | """Return ok if the system is up."""
17 | return HealthResponse(status="ok")
18 |
--------------------------------------------------------------------------------
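A minimal sketch of how this endpoint can be probed without a running server, building the app the same way `private_gpt/main.py` does. Note that creating the app initializes whatever components the active settings profile configures, so a mock profile (e.g. `settings-mock.yaml` or `settings-test.yaml`) keeps this lightweight.

```python
from fastapi.testclient import TestClient

from private_gpt.di import global_injector
from private_gpt.launcher import create_app

# Build the app as main.py does and hit the unauthenticated health route.
client = TestClient(create_app(global_injector))
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
```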
/tests/fixtures/fast_api_test_client.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from fastapi.testclient import TestClient
3 |
4 | from private_gpt.launcher import create_app
5 | from tests.fixtures.mock_injector import MockInjector
6 |
7 |
8 | @pytest.fixture()
9 | def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
10 | if request is not None and hasattr(request, "param"):
11 | injector.bind_settings(request.param or {})
12 |
13 | app_under_test = create_app(injector.test_injector)
14 | return TestClient(app_under_test)
15 |
--------------------------------------------------------------------------------
/private_gpt/paths.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from private_gpt.constants import PROJECT_ROOT_PATH
4 | from private_gpt.settings.settings import settings
5 |
6 |
7 | def _absolute_or_from_project_root(path: str) -> Path:
8 | if path.startswith("/"):
9 | return Path(path)
10 | return PROJECT_ROOT_PATH / path
11 |
12 |
13 | models_path: Path = PROJECT_ROOT_PATH / "models"
14 | models_cache_path: Path = models_path / "cache"
15 | docs_path: Path = PROJECT_ROOT_PATH / "docs"
16 | local_data_path: Path = _absolute_or_from_project_root(
17 | settings().data.local_data_folder
18 | )
19 |
--------------------------------------------------------------------------------
/private_gpt/di.py:
--------------------------------------------------------------------------------
1 | from injector import Injector
2 |
3 | from private_gpt.settings.settings import Settings, unsafe_typed_settings
4 |
5 |
6 | def create_application_injector() -> Injector:
7 | _injector = Injector(auto_bind=True)
8 | _injector.binder.bind(Settings, to=unsafe_typed_settings)
9 | return _injector
10 |
11 |
12 | """
13 | Global injector for the application.
14 |
15 | Avoid using this reference; it will make your code harder to test.
16 |
17 | Instead, use the `request.state.injector` reference, which is bound to every request.
18 | """
19 | global_injector: Injector = create_application_injector()
20 |
--------------------------------------------------------------------------------
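To make the docstring above concrete, here is a minimal sketch of the recommended pattern: routers resolve their services from the per-request `request.state.injector` (as the bundled routers such as `embeddings_router` do) instead of importing `global_injector`. The route path is hypothetical and only for illustration.

```python
from fastapi import APIRouter, Request

from private_gpt.server.embeddings.embeddings_service import EmbeddingsService

example_router = APIRouter()


@example_router.get("/example")  # hypothetical route, for illustration only
def example(request: Request) -> dict[str, int]:
    # Resolve the service from the injector bound to this request by the launcher.
    service = request.state.injector.get(EmbeddingsService)
    embedding = service.texts_embeddings(["hello"])[0]
    return {"dimensions": len(embedding.embedding)}
```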
/.github/workflows/publish-docs.yml:
--------------------------------------------------------------------------------
1 | name: publish docs
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "fern/**"
9 |
10 | jobs:
11 | publish-docs:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - name: Checkout repo
15 | uses: actions/checkout@v4
16 |
17 | - name: Setup node
18 | uses: actions/setup-node@v3
19 |
20 | - name: Download Fern
21 | run: npm install -g fern-api
22 |
23 | - name: Generate and Publish Docs
24 | env:
25 | FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
26 | run: fern generate --docs --log-level debug
27 |
--------------------------------------------------------------------------------
/settings-docker.yaml:
--------------------------------------------------------------------------------
1 | server:
2 | env_name: ${APP_ENV:prod}
3 | port: ${PORT:8080}
4 |
5 | llm:
6 | mode: ${PGPT_MODE:mock}
7 |
8 | embedding:
9 | mode: ${PGPT_MODE:sagemaker}
10 |
11 | local:
12 | llm_hf_repo_id: ${PGPT_HF_REPO_ID:TheBloke/Mistral-7B-Instruct-v0.1-GGUF}
13 | llm_hf_model_file: ${PGPT_HF_MODEL_FILE:mistral-7b-instruct-v0.1.Q4_K_M.gguf}
14 | embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}
15 |
16 | sagemaker:
17 | llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
18 | embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}
19 |
20 | ui:
21 | enabled: true
22 | path: /
23 |
--------------------------------------------------------------------------------
/tests/fixtures/auto_close_qdrant.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from private_gpt.components.vector_store.vector_store_component import (
4 | VectorStoreComponent,
5 | )
6 | from tests.fixtures.mock_injector import MockInjector
7 |
8 |
9 | @pytest.fixture(autouse=True)
10 | def _auto_close_vector_store_client(injector: MockInjector) -> None:
11 | """Auto close VectorStore client after each test.
12 |
13 | The VectorStore client (qdrant/chromadb) opens a connection to the
14 | database that causes issues when running tests too fast,
15 | so close it explicitly after each test.
16 | """
17 | yield
18 | injector.get(VectorStoreComponent).close()
19 |
--------------------------------------------------------------------------------
/tests/server/embeddings/test_embedding_routes.py:
--------------------------------------------------------------------------------
1 | from fastapi.testclient import TestClient
2 |
3 | from private_gpt.server.embeddings.embeddings_router import (
4 | EmbeddingsBody,
5 | EmbeddingsResponse,
6 | )
7 |
8 |
9 | def test_embeddings_generation(test_client: TestClient) -> None:
10 | body = EmbeddingsBody(input="Embed me")
11 | response = test_client.post("/v1/embeddings", json=body.model_dump())
12 |
13 | assert response.status_code == 200
14 | embedding_response = EmbeddingsResponse.model_validate(response.json())
15 | assert len(embedding_response.data) > 0
16 | assert len(embedding_response.data[0].embedding) > 0
17 |
--------------------------------------------------------------------------------
/fern/docs/pages/api-reference/api-reference.mdx:
--------------------------------------------------------------------------------
1 | # API Reference
2 |
3 | The API is divided into two logical blocks:
4 |
5 | 1. High-level API, abstracting all the complexity of a RAG (Retrieval Augmented Generation) pipeline implementation:
6 | - Ingestion of documents: internally managing document parsing, splitting, metadata extraction,
7 | embedding generation and storage.
8 | - Chat & Completions using context from ingested documents: abstracting the retrieval of context, the prompt
9 | engineering and the response generation.
10 |
11 | 2. Low-level API, allowing advanced users to implement their own complex pipelines:
12 | - Embeddings generation: based on a piece of text.
13 | - Contextual chunks retrieval: given a query, returns the most relevant chunks of text from the ingested
14 | documents.
--------------------------------------------------------------------------------
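An illustrative sketch of both API levels using plain HTTP calls. It assumes a server on localhost:8001 and mirrors the request shapes used by the repository's own tests (ingest, chat with context, embeddings, and chunk retrieval); any field or file name not found in those tests is an assumption.

```python
import requests

BASE_URL = "http://localhost:8001"  # assumed local server address

# High-level API: ingest a document, then chat using it as context.
with open("report.pdf", "rb") as f:  # hypothetical file
    requests.post(f"{BASE_URL}/v1/ingest/file", files={"file": f})

chat = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Summarize the report"}],
        "use_context": True,
        "stream": False,
    },
).json()

# Low-level API: raw embeddings and contextual chunk retrieval.
embeddings = requests.post(f"{BASE_URL}/v1/embeddings", json={"input": "Embed me"}).json()
chunks = requests.post(f"{BASE_URL}/v1/chunks", json={"text": "quarterly revenue"}).json()
```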
/tests/fixtures/ingest_helper.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 | from fastapi.testclient import TestClient
5 |
6 | from private_gpt.server.ingest.ingest_router import IngestResponse
7 |
8 |
9 | class IngestHelper:
10 | def __init__(self, test_client: TestClient):
11 | self.test_client = test_client
12 |
13 | def ingest_file(self, path: Path) -> IngestResponse:
14 | files = {"file": (path.name, path.open("rb"))}
15 |
16 | response = self.test_client.post("/v1/ingest/file", files=files)
17 | assert response.status_code == 200
18 | ingest_result = IngestResponse.model_validate(response.json())
19 | return ingest_result
20 |
21 |
22 | @pytest.fixture()
23 | def ingest_helper(test_client: TestClient) -> IngestHelper:
24 | return IngestHelper(test_client)
25 |
--------------------------------------------------------------------------------
/tests/server/chunks/test_chunk_routes.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from fastapi.testclient import TestClient
4 |
5 | from private_gpt.server.chunks.chunks_router import ChunksBody, ChunksResponse
6 | from tests.fixtures.ingest_helper import IngestHelper
7 |
8 |
9 | def test_chunks_retrieval(test_client: TestClient, ingest_helper: IngestHelper) -> None:
10 | # Make sure there is at least some chunk to query in the database
11 | path = Path(__file__).parents[0] / "chunk_test.txt"
12 | ingest_helper.ingest_file(path)
13 |
14 | body = ChunksBody(text="b483dd15-78c4-4d67-b546-21a0d690bf43")
15 | response = test_client.post("/v1/chunks", json=body.model_dump())
16 | assert response.status_code == 200
17 | chunk_response = ChunksResponse.model_validate(response.json())
18 | assert len(chunk_response.data) > 0
19 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # This CITATION.cff file was generated with cffinit.
2 | # Visit https://bit.ly/cffinit to generate yours today!
3 |
4 | cff-version: 1.2.0
5 | title: PrivateGPT
6 | message: >-
7 | If you use this software, please cite it using the
8 | metadata from this file.
9 | type: software
10 | authors:
11 | - given-names: Iván
12 | family-names: Martínez Toro
13 | email: ivanmartit@gmail.com
14 | orcid: 'https://orcid.org/0009-0004-5065-2311'
15 | - family-names: Gallego Vico
16 | given-names: Daniel
17 | email: danielgallegovico@gmail.com
18 | orcid: 'https://orcid.org/0009-0006-8582-4384'
19 | - given-names: Pablo
20 | family-names: Orgaz
21 | email: pabloogc+gh@gmail.com
22 | orcid: 'https://orcid.org/0009-0008-0080-1437'
23 | repository-code: 'https://github.com/imartinez/privateGPT'
24 | license: Apache-2.0
25 | date-released: '2023-05-02'
26 |
--------------------------------------------------------------------------------
/.github/workflows/actions/install_dependencies/action.yml:
--------------------------------------------------------------------------------
1 | name: "Install Dependencies"
2 | description: "Action to build the project dependencies from the main versions"
3 | inputs:
4 | python_version:
5 | required: true
6 | type: string
7 | default: "3.11.4"
8 | poetry_version:
9 | required: true
10 | type: string
11 | default: "1.5.1"
12 |
13 | runs:
14 | using: composite
15 | steps:
16 | - name: Install Poetry
17 | uses: snok/install-poetry@v1
18 | with:
19 | version: ${{ inputs.poetry_version }}
20 | virtualenvs-create: true
21 | virtualenvs-in-project: false
22 | installer-parallel: true
23 | - uses: actions/setup-python@v4
24 | with:
25 | python-version: ${{ inputs.python_version }}
26 | cache: "poetry"
27 | - name: Install Dependencies
28 | run: poetry install --with ui --no-root
29 | shell: bash
30 |
31 |
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
2 | #
3 | # You can adjust the behavior by modifying this file.
4 | # For more information, see:
5 | # https://github.com/actions/stale
6 | name: Mark stale issues and pull requests
7 |
8 | on:
9 | schedule:
10 | - cron: '42 5 * * *'
11 |
12 | jobs:
13 | stale:
14 |
15 | runs-on: ubuntu-latest
16 | permissions:
17 | issues: write
18 | pull-requests: write
19 |
20 | steps:
21 | - uses: actions/stale@v8
22 | with:
23 | repo-token: ${{ secrets.GITHUB_TOKEN }}
24 | days-before-stale: 15
25 | stale-issue-message: 'Stale issue'
26 | stale-pr-message: 'Stale pull request'
27 | stale-issue-label: 'stale'
28 | stale-pr-label: 'stale'
29 | exempt-issue-labels: 'autorelease: pending'
30 | exempt-pr-labels: 'autorelease: pending'
31 |
--------------------------------------------------------------------------------
/private_gpt/__init__.py:
--------------------------------------------------------------------------------
1 | """private-gpt."""
2 | import logging
3 | import os
4 |
5 | # Set to 'DEBUG' to have extensive logging turned on, even for libraries
6 | ROOT_LOG_LEVEL = "INFO"
7 |
8 | PRETTY_LOG_FORMAT = (
9 | "%(asctime)s.%(msecs)03d [%(levelname)-8s] %(name)+25s - %(message)s"
10 | )
11 | logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S")
12 | logging.captureWarnings(True)
13 |
14 | # Disable gradio analytics
15 | # This is done this way because gradio does not solely rely on what values are
16 | # passed to gr.Blocks(enable_analytics=...) but also on the environment
17 | # variable GRADIO_ANALYTICS_ENABLED. `gradio.strings` actually reads this env
18 | # directly, so to fully disable gradio analytics we need to set this env var.
19 | os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
20 |
21 | # Disable chromaDB telemetry
22 | # It is already disabled, see PR#1144
23 | # os.environ["ANONYMIZED_TELEMETRY"] = "False"
24 |
--------------------------------------------------------------------------------
/fern/docs/pages/overview/quickstart.mdx:
--------------------------------------------------------------------------------
1 | ## Local Installation steps
2 |
3 | The steps in the [Installation](/installation) section are better explained and cover more
4 | setup scenarios (macOS, Windows, Linux).
5 | But if you like one-liners, have python3.11 installed, and are running a UNIX (macOS or Linux)
6 | system, you can get up and running on CPU in a few lines:
7 |
8 | ```bash
9 | git clone https://github.com/imartinez/privateGPT && cd privateGPT && \
10 | python3.11 -m venv .venv && source .venv/bin/activate && \
11 | pip install --upgrade pip poetry && poetry install --with ui,local && ./scripts/setup
12 |
13 | # Launch the privateGPT API server **and** the gradio UI
14 | poetry run python3.11 -m private_gpt
15 |
16 | # In another terminal, create a new browser window on your private GPT!
17 | open http://127.0.0.1:8001/
18 | ```
19 |
20 | Is the above not working, or is it too slow? Then maybe **you want to run it on GPU(s)**.
21 | Please check the more detailed [installation guide](/installation).
22 |
--------------------------------------------------------------------------------
/fern/docs/pages/api-reference/sdks.mdx:
--------------------------------------------------------------------------------
1 | We use [Fern](https://buildwithfern.com) to offer API clients for Node.js, Python, Go, and Java.
2 | We recommend using these clients to interact with our endpoints.
3 | The clients are kept up to date automatically, so we encourage you to use the latest version.
4 |
5 | ## SDKs
6 |
7 | *Coming soon!*
8 |
--------------------------------------------------------------------------------
/private_gpt/server/embeddings/embeddings_service.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from injector import inject, singleton
4 | from pydantic import BaseModel, Field
5 |
6 | from private_gpt.components.embedding.embedding_component import EmbeddingComponent
7 |
8 |
9 | class Embedding(BaseModel):
10 | index: int
11 | object: Literal["embedding"]
12 | embedding: list[float] = Field(examples=[[0.0023064255, -0.009327292]])
13 |
14 |
15 | @singleton
16 | class EmbeddingsService:
17 | @inject
18 | def __init__(self, embedding_component: EmbeddingComponent) -> None:
19 | self.embedding_model = embedding_component.embedding_model
20 |
21 | def texts_embeddings(self, texts: list[str]) -> list[Embedding]:
22 | texts_embeddings = self.embedding_model.get_text_embedding_batch(texts)
23 | return [
24 | Embedding(
25 | index=i,
26 | object="embedding",
27 | embedding=embedding,
28 | )
29 | for i, embedding in enumerate(texts_embeddings)
30 | ]
31 |
--------------------------------------------------------------------------------
/private_gpt/server/ingest/model.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal
2 |
3 | from llama_index import Document
4 | from pydantic import BaseModel, Field
5 |
6 |
7 | class IngestedDoc(BaseModel):
8 | object: Literal["ingest.document"]
9 | doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
10 | doc_metadata: dict[str, Any] | None = Field(
11 | examples=[
12 | {
13 | "page_label": "2",
14 | "file_name": "Sales Report Q3 2023.pdf",
15 | }
16 | ]
17 | )
18 |
19 | @staticmethod
20 | def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
21 | """Remove unwanted metadata keys."""
22 | for key in ["doc_id", "window", "original_text"]:
23 | metadata.pop(key, None)
24 | return metadata
25 |
26 | @staticmethod
27 | def from_document(document: Document) -> "IngestedDoc":
28 | return IngestedDoc(
29 | object="ingest.document",
30 | doc_id=document.doc_id,
31 | doc_metadata=IngestedDoc.curate_metadata(document.metadata),
32 | )
33 |
--------------------------------------------------------------------------------
/Dockerfile.external:
--------------------------------------------------------------------------------
1 | FROM python:3.11.6-slim-bookworm as base
2 |
3 | # Install poetry
4 | RUN pip install pipx
5 | RUN python3 -m pipx ensurepath
6 | RUN pipx install poetry
7 | ENV PATH="/root/.local/bin:$PATH"
8 | ENV PATH=".venv/bin/:$PATH"
9 |
10 | # https://python-poetry.org/docs/configuration/#virtualenvsin-project
11 | ENV POETRY_VIRTUALENVS_IN_PROJECT=true
12 |
13 | FROM base as dependencies
14 | WORKDIR /home/worker/app
15 | COPY pyproject.toml poetry.lock ./
16 |
17 | RUN poetry install --with ui
18 |
19 | FROM base as app
20 |
21 | ENV PYTHONUNBUFFERED=1
22 | ENV PORT=8080
23 | EXPOSE 8080
24 |
25 | # Prepare a non-root user
26 | RUN adduser --system worker
27 | WORKDIR /home/worker/app
28 |
29 | RUN mkdir local_data; chown worker local_data
30 | RUN mkdir models; chown worker models
31 | COPY --chown=worker --from=dependencies /home/worker/app/.venv/ .venv
32 | COPY --chown=worker private_gpt/ private_gpt
33 | COPY --chown=worker fern/ fern
34 | COPY --chown=worker *.yaml *.md ./
35 | COPY --chown=worker scripts/ scripts
36 |
37 | ENV PYTHONPATH="$PYTHONPATH:/private_gpt/"
38 |
39 | USER worker
40 | ENTRYPOINT python -m private_gpt
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 |
6 | def wipe():
7 | path = "local_data"
8 | print(f"Wiping {path}...")
9 | all_files = os.listdir(path)
10 |
11 | files_to_remove = [file for file in all_files if file != ".gitignore"]
12 | for file_name in files_to_remove:
13 | file_path = os.path.join(path, file_name)
14 | try:
15 | if os.path.isfile(file_path):
16 | os.remove(file_path)
17 | elif os.path.isdir(file_path):
18 | shutil.rmtree(file_path)
19 | print(f" - Deleted {file_path}")
20 | except PermissionError:
21 | print(
22 | f"PermissionError: Unable to remove {file_path}. It is in use by another process."
23 | )
24 | continue
25 |
26 |
27 | if __name__ == "__main__":
28 | commands = {
29 | "wipe": wipe,
30 | }
31 |
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument(
34 | "mode", help="select a mode to run", choices=list(commands.keys())
35 | )
36 | args = parser.parse_args()
37 | commands[args.mode.lower()]()
38 |
--------------------------------------------------------------------------------
/scripts/extract_openapi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import sys
4 |
5 | import yaml
6 | from uvicorn.importer import import_from_string
7 |
8 | parser = argparse.ArgumentParser(prog="extract_openapi.py")
9 | parser.add_argument("app", help='App import string. Eg. "main:app"', default="main:app")
10 | parser.add_argument("--app-dir", help="Directory containing the app", default=None)
11 | parser.add_argument(
12 | "--out", help="Output file ending in .json or .yaml", default="openapi.yaml"
13 | )
14 |
15 | if __name__ == "__main__":
16 | args = parser.parse_args()
17 |
18 | if args.app_dir is not None:
19 | print(f"adding {args.app_dir} to sys.path")
20 | sys.path.insert(0, args.app_dir)
21 |
22 | print(f"importing app from {args.app}")
23 | app = import_from_string(args.app)
24 | openapi = app.openapi()
25 | version = openapi.get("openapi", "unknown version")
26 |
27 | print(f"writing openapi spec v{version}")
28 | with open(args.out, "w") as f:
29 | if args.out.endswith(".json"):
30 | json.dump(openapi, f, indent=2)
31 | else:
32 | yaml.dump(openapi, f, sort_keys=False)
33 |
34 | print(f"spec written to {args.out}")
35 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_install_hook_types:
2 | # Mandatory to install both pre-commit and pre-push hooks (see https://pre-commit.com/#top_level-default_install_hook_types)
3 | # Add new hook types here to ensure automatic installation when running `pre-commit install`
4 | - pre-commit
5 | - pre-push
6 | repos:
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v4.3.0
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: end-of-file-fixer
12 | - id: check-yaml
13 | - id: check-json
14 | - id: check-added-large-files
15 |
16 | - repo: local
17 | hooks:
18 | - id: black
19 | name: Formatting (black)
20 | entry: black
21 | language: system
22 | types: [python]
23 | stages: [commit]
24 | - id: ruff
25 | name: Linter (ruff)
26 | entry: ruff
27 | language: system
28 | types: [python]
29 | stages: [commit]
30 | - id: mypy
31 | name: Type checking (mypy)
32 | entry: make mypy
33 | pass_filenames: false
34 | language: system
35 | types: [python]
36 | stages: [commit]
37 | - id: test
38 | name: Unit tests (pytest)
39 | entry: make test
40 | pass_filenames: false
41 | language: system
42 | types: [python]
43 | stages: [push]
--------------------------------------------------------------------------------
/private_gpt/components/node_store/node_store_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from injector import inject, singleton
4 | from llama_index.storage.docstore import BaseDocumentStore, SimpleDocumentStore
5 | from llama_index.storage.index_store import SimpleIndexStore
6 | from llama_index.storage.index_store.types import BaseIndexStore
7 |
8 | from private_gpt.paths import local_data_path
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | @singleton
14 | class NodeStoreComponent:
15 | index_store: BaseIndexStore
16 | doc_store: BaseDocumentStore
17 |
18 | @inject
19 | def __init__(self) -> None:
20 | try:
21 | self.index_store = SimpleIndexStore.from_persist_dir(
22 | persist_dir=str(local_data_path)
23 | )
24 | except FileNotFoundError:
25 | logger.debug("Local index store not found, creating a new one")
26 | self.index_store = SimpleIndexStore()
27 |
28 | try:
29 | self.doc_store = SimpleDocumentStore.from_persist_dir(
30 | persist_dir=str(local_data_path)
31 | )
32 | except FileNotFoundError:
33 | logger.debug("Local document store not found, creating a new one")
34 | self.doc_store = SimpleDocumentStore()
35 |
--------------------------------------------------------------------------------
/private_gpt/server/embeddings/embeddings_router.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from fastapi import APIRouter, Depends, Request
4 | from pydantic import BaseModel
5 |
6 | from private_gpt.server.embeddings.embeddings_service import (
7 | Embedding,
8 | EmbeddingsService,
9 | )
10 | from private_gpt.server.utils.auth import authenticated
11 |
12 | embeddings_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
13 |
14 |
15 | class EmbeddingsBody(BaseModel):
16 | input: str | list[str]
17 |
18 |
19 | class EmbeddingsResponse(BaseModel):
20 | object: Literal["list"]
21 | model: Literal["private-gpt"]
22 | data: list[Embedding]
23 |
24 |
25 | @embeddings_router.post("/embeddings", tags=["Embeddings"])
26 | def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
27 | """Get a vector representation of a given input.
28 |
29 | That vector representation can be easily consumed
30 | by machine learning models and algorithms.
31 | """
32 | service = request.state.injector.get(EmbeddingsService)
33 | input_texts = body.input if isinstance(body.input, list) else [body.input]
34 | embeddings = service.texts_embeddings(input_texts)
35 | return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)
36 |
--------------------------------------------------------------------------------
/tests/server/chat/test_chat_routes.py:
--------------------------------------------------------------------------------
1 | from fastapi.testclient import TestClient
2 |
3 | from private_gpt.open_ai.openai_models import OpenAICompletion, OpenAIMessage
4 | from private_gpt.server.chat.chat_router import ChatBody
5 |
6 |
7 | def test_chat_route_produces_a_stream(test_client: TestClient) -> None:
8 | body = ChatBody(
9 | messages=[OpenAIMessage(content="test", role="user")],
10 | use_context=False,
11 | stream=True,
12 | )
13 | response = test_client.post("/v1/chat/completions", json=body.model_dump())
14 |
15 | raw_events = response.text.split("\n\n")
16 | events = [
17 | item.removeprefix("data: ") for item in raw_events if item.startswith("data: ")
18 | ]
19 | assert response.status_code == 200
20 | assert "text/event-stream" in response.headers["content-type"]
21 | assert len(events) > 0
22 | assert events[-1] == "[DONE]"
23 |
24 |
25 | def test_chat_route_produces_a_single_value(test_client: TestClient) -> None:
26 | body = ChatBody(
27 | messages=[OpenAIMessage(content="test", role="user")],
28 | use_context=False,
29 | stream=False,
30 | )
31 | response = test_client.post("/v1/chat/completions", json=body.model_dump())
32 |
33 | # No asserts, if it validates it's good
34 | OpenAICompletion.model_validate(response.json())
35 | assert response.status_code == 200
36 |
--------------------------------------------------------------------------------
/tests/settings/test_settings_loader.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 |
4 | import pytest
5 |
6 | from private_gpt.settings.yaml import load_yaml_with_envvars
7 |
8 |
9 | def test_environment_variables_are_loaded() -> None:
10 | sample_yaml = """
11 | replaced: ${TEST_REPLACE_ME}
12 | """
13 | env = {"TEST_REPLACE_ME": "replaced"}
14 | loaded = load_yaml_with_envvars(io.StringIO(sample_yaml), env)
15 | os.environ.copy()
16 | assert loaded["replaced"] == "replaced"
17 |
18 |
19 | def test_environment_defaults_variables_are_loaded() -> None:
20 | sample_yaml = """
21 | replaced: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}
22 | """
23 | loaded = load_yaml_with_envvars(io.StringIO(sample_yaml), {})
24 | assert loaded["replaced"] == "BAAI/bge-small-en-v1.5"
25 |
26 |
27 | def test_environment_defaults_variables_are_loaded_with_duplicated_delimiters() -> None:
28 | sample_yaml = """
29 | replaced: ${PGPT_EMBEDDING_HF_MODEL_NAME::duped::}
30 | """
31 | loaded = load_yaml_with_envvars(io.StringIO(sample_yaml), {})
32 | assert loaded["replaced"] == ":duped::"
33 |
34 |
35 | def test_environment_without_defaults_fails() -> None:
36 | sample_yaml = """
37 | replaced: ${TEST_REPLACE_ME}
38 | """
39 | with pytest.raises(ValueError) as error:
40 | load_yaml_with_envvars(io.StringIO(sample_yaml), {})
41 | assert error is not None
42 |
--------------------------------------------------------------------------------
/.github/workflows/docker.yml:
--------------------------------------------------------------------------------
1 | name: docker
2 |
3 | on:
4 | release:
5 | types: [ published ]
6 | workflow_dispatch:
7 |
8 | env:
9 | REGISTRY: ghcr.io
10 | IMAGE_NAME: ${{ github.repository }}
11 |
12 | jobs:
13 | build-and-push-image:
14 | runs-on: ubuntu-latest
15 | permissions:
16 | contents: read
17 | packages: write
18 | steps:
19 | - name: Checkout repository
20 | uses: actions/checkout@v4
21 | - name: Log in to the Container registry
22 | uses: docker/login-action@v3
23 | with:
24 | registry: ${{ env.REGISTRY }}
25 | username: ${{ github.actor }}
26 | password: ${{ secrets.GITHUB_TOKEN }}
27 | - name: Extract metadata (tags, labels) for Docker
28 | id: meta
29 | uses: docker/metadata-action@v5
30 | with:
31 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
32 | tags: |
33 | type=ref,event=branch
34 | type=ref,event=pr
35 | type=semver,pattern={{version}}
36 | type=semver,pattern={{major}}.{{minor}}
37 | type=sha
38 | - name: Build and push Docker image
39 | uses: docker/build-push-action@v5
40 | with:
41 | context: .
42 | file: Dockerfile.external
43 | push: true
44 | tags: ${{ steps.meta.outputs.tags }}
45 | labels: ${{ steps.meta.outputs.labels }}
46 |
--------------------------------------------------------------------------------
/Dockerfile.local:
--------------------------------------------------------------------------------
1 | ### IMPORTANT, THIS IMAGE CAN ONLY BE RUN IN LINUX DOCKER
2 | ### You will run into a segfault on macOS
3 | FROM python:3.11.6-slim-bookworm as base
4 |
5 | # Install poetry
6 | RUN pip install pipx
7 | RUN python3 -m pipx ensurepath
8 | RUN pipx install poetry
9 | ENV PATH="/root/.local/bin:$PATH"
10 | ENV PATH=".venv/bin/:$PATH"
11 |
12 | # Dependencies to build llama-cpp
13 | RUN apt update && apt install -y \
14 | libopenblas-dev\
15 | ninja-build\
16 | build-essential\
17 | pkg-config\
18 | wget
19 |
20 | # https://python-poetry.org/docs/configuration/#virtualenvsin-project
21 | ENV POETRY_VIRTUALENVS_IN_PROJECT=true
22 |
23 | FROM base as dependencies
24 | WORKDIR /home/worker/app
25 | COPY pyproject.toml poetry.lock ./
26 |
27 | RUN poetry install --with local
28 | RUN poetry install --with ui
29 |
30 | FROM base as app
31 |
32 | ENV PYTHONUNBUFFERED=1
33 | ENV PORT=8080
34 | EXPOSE 8080
35 |
36 | # Prepare a non-root user
37 | RUN adduser --system worker
38 | WORKDIR /home/worker/app
39 |
40 | RUN mkdir local_data; chown worker local_data
41 | RUN mkdir models; chown worker models
42 | COPY --chown=worker --from=dependencies /home/worker/app/.venv/ .venv
43 | COPY --chown=worker private_gpt/ private_gpt
44 | COPY --chown=worker fern/ fern
45 | COPY --chown=worker *.yaml *.md ./
46 | COPY --chown=worker scripts/ scripts
47 |
48 | ENV PYTHONPATH="$PYTHONPATH:/private_gpt/"
49 |
50 | USER worker
51 | ENTRYPOINT python -m private_gpt
--------------------------------------------------------------------------------
/tests/fixtures/mock_injector.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from typing import Any
3 | from unittest.mock import MagicMock
4 |
5 | import pytest
6 | from injector import Provider, ScopeDecorator, singleton
7 |
8 | from private_gpt.di import create_application_injector
9 | from private_gpt.settings.settings import Settings, unsafe_settings
10 | from private_gpt.settings.settings_loader import merge_settings
11 | from private_gpt.utils.typing import T
12 |
13 |
14 | class MockInjector:
15 | def __init__(self) -> None:
16 | self.test_injector = create_application_injector()
17 |
18 | def bind_mock(
19 | self,
20 | interface: type[T],
21 | mock: (T | (Callable[..., T] | Provider[T])) | None = None,
22 | *,
23 | scope: ScopeDecorator = singleton,
24 | ) -> T:
25 | if mock is None:
26 | mock = MagicMock()
27 | self.test_injector.binder.bind(interface, to=mock, scope=scope)
28 | return mock # type: ignore
29 |
30 | def bind_settings(self, settings: dict[str, Any]) -> Settings:
31 | merged = merge_settings([unsafe_settings, settings])
32 | new_settings = Settings(**merged)
33 | self.test_injector.binder.bind(Settings, new_settings)
34 | return new_settings
35 |
36 | def get(self, interface: type[T]) -> T:
37 | return self.test_injector.get(interface)
38 |
39 |
40 | @pytest.fixture()
41 | def injector() -> MockInjector:
42 | return MockInjector()
43 |
--------------------------------------------------------------------------------
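A hypothetical sketch of how a test can combine the `injector` fixture with `bind_mock` to swap a heavy component for a `MagicMock` before resolving it; `LLMComponent` is used purely for illustration here, and any injectable type would work the same way.

```python
from unittest.mock import MagicMock

from private_gpt.components.llm.llm_component import LLMComponent
from tests.fixtures.mock_injector import MockInjector


def test_component_can_be_mocked(injector: MockInjector) -> None:
    # Bind a mock in place of the real component, then resolve it back.
    mock_llm = injector.bind_mock(LLMComponent, MagicMock(spec=LLMComponent))
    assert injector.get(LLMComponent) is mock_llm
```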
/private_gpt/settings/yaml.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import typing
4 | from typing import Any, TextIO
5 |
6 | from yaml import SafeLoader
7 |
8 | _env_replace_matcher = re.compile(r"\$\{(\w|_)+:?.*}")
9 |
10 |
11 | @typing.no_type_check # pyaml does not have good hints, everything is Any
12 | def load_yaml_with_envvars(
13 | stream: TextIO, environ: dict[str, Any] = os.environ
14 | ) -> dict[str, Any]:
15 | """Load yaml file with environment variable expansion.
16 |
17 | The pattern ${VAR} or ${VAR:default} will be replaced with
18 | the value of the environment variable.
19 | """
20 | loader = SafeLoader(stream)
21 |
22 | def load_env_var(_, node) -> str:
23 | """Extract the matched value, expand env variable, and replace the match."""
24 | value = str(node.value).removeprefix("${").removesuffix("}")
25 | split = value.split(":", 1)
26 | env_var = split[0]
27 | value = environ.get(env_var)
28 | default = None if len(split) == 1 else split[1]
29 | if value is None and default is None:
30 | raise ValueError(
31 | f"Environment variable {env_var} is not set and not default was provided"
32 | )
33 | return value or default
34 |
35 | loader.add_implicit_resolver("env_var_replacer", _env_replace_matcher, None)
36 | loader.add_constructor("env_var_replacer", load_env_var)
37 |
38 | try:
39 | return loader.get_single_data()
40 | finally:
41 | loader.dispose()
42 |
--------------------------------------------------------------------------------
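A minimal usage sketch of `load_yaml_with_envvars`, consistent with the tests in `tests/settings/test_settings_loader.py`: `${VAR}` is replaced from the supplied environment mapping and `${VAR:default}` falls back to the default when the variable is unset (note that substituted values come back as strings).

```python
import io

from private_gpt.settings.yaml import load_yaml_with_envvars

yaml_text = """
server:
  env_name: ${APP_ENV}
  port: ${PORT:8001}
"""
# APP_ENV is taken from the mapping; PORT falls back to its default.
loaded = load_yaml_with_envvars(io.StringIO(yaml_text), {"APP_ENV": "local"})
assert loaded == {"server": {"env_name": "local", "port": "8001"}}
```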
/fern/docs/pages/overview/welcome.mdx:
--------------------------------------------------------------------------------
1 | ## Introduction 👋
2 |
3 | PrivateGPT provides an **API** containing all the building blocks required to
4 | build **private, context-aware AI applications**.
5 | The API follows and extends the OpenAI API standard, and supports both normal and streaming responses.
6 | That means that, if you can use the OpenAI API in one of your tools, you can use your own PrivateGPT API instead,
7 | with no code changes, **and for free** if you are running privateGPT in `local` mode.
8 |
9 | Looking for the installation quickstart? [Quickstart installation guide for Linux and macOS](/overview/welcome/quickstart).
10 |
11 | Do you want to install it on Windows? Or do you want to take full advantage of your hardware for better performance?
12 | The installation guide will help you in the [Installation section](/installation).
13 |
14 |
15 | ## Frequently Visited Resources
16 |
17 |
18 |
23 |
28 |
33 |
34 |
35 |
36 | A working **Gradio UI client** is provided to test the API, together with a set of useful tools such as a bulk
37 | model download script, an ingestion script, a documents folder watcher, etc.
38 |
--------------------------------------------------------------------------------
/fern/README.md:
--------------------------------------------------------------------------------
1 | # Documentation of privateGPT
2 |
3 | The documentation of this project is being rendered thanks to [fern](https://github.com/fern-api/fern).
4 |
5 | Fern transforms your `.md` and `.mdx` files into a static website: your documentation.
6 |
7 | The configuration of your documentation is done in the `./docs.yml` file.
8 | There, you can configure the navbar, tabs, sections and pages being rendered.
9 |
10 | The documentation of fern (and the syntax of its configuration `docs.yml`) is
11 | available at [docs.buildwithfern.com](https://docs.buildwithfern.com/).
12 |
13 | ## How to run fern
14 |
15 | **You cannot render your documentation locally without fern credentials.**
16 |
17 | To see how your documentation looks, you **have to** use the CI/CD of this
18 | repository (by opening a PR, the CI/CD job will be executed, and a preview of
19 | your PR's documentation will be deployed to Vercel automatically, through fern).
20 |
21 | The only thing you can do locally is to run `fern check`, which checks the syntax of
22 | your `docs.yml` file.
23 |
24 | ## How to add a new page
25 | Add in the `docs.yml` a new `page`, with the following syntax:
26 |
27 | ```yml
28 | navigation:
29 | # ...
30 | - tab: my-existing-tab
31 | layout:
32 | # ...
33 | - section: My Existing Section
34 | contents:
35 | # ...
36 | - page: My new page display name
37 | # The path of the page, relative to `fern/`
38 | path: ./docs/pages/my-existing-tab/new-page-content.mdx
39 | ```
--------------------------------------------------------------------------------
/private_gpt/server/ingest/ingest_watcher.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from pathlib import Path
3 | from typing import Any
4 |
5 | from watchdog.events import (
6 | DirCreatedEvent,
7 | DirModifiedEvent,
8 | FileCreatedEvent,
9 | FileModifiedEvent,
10 | FileSystemEventHandler,
11 | )
12 | from watchdog.observers import Observer
13 |
14 |
15 | class IngestWatcher:
16 | def __init__(
17 | self, watch_path: Path, on_file_changed: Callable[[Path], None]
18 | ) -> None:
19 | self.watch_path = watch_path
20 | self.on_file_changed = on_file_changed
21 |
22 | class Handler(FileSystemEventHandler):
23 | def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None:
24 | if isinstance(event, FileModifiedEvent):
25 | on_file_changed(Path(event.src_path))
26 |
27 | def on_created(self, event: DirCreatedEvent | FileCreatedEvent) -> None:
28 | if isinstance(event, FileCreatedEvent):
29 | on_file_changed(Path(event.src_path))
30 |
31 | event_handler = Handler()
32 | observer: Any = Observer()
33 | self._observer = observer
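        # recursive=True: watch the whole tree under watch_path, not just its top level.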
34 | self._observer.schedule(event_handler, str(watch_path), recursive=True)
35 |
36 | def start(self) -> None:
37 | self._observer.start()
38 | while self._observer.is_alive():
39 | try:
40 | self._observer.join(1)
41 | except KeyboardInterrupt:
42 | break
43 |
44 | def stop(self) -> None:
45 | self._observer.stop()
46 | self._observer.join()
47 |
--------------------------------------------------------------------------------
/.github/workflows/preview-docs.yml:
--------------------------------------------------------------------------------
1 | name: deploy preview docs
2 |
3 | on:
4 | pull_request_target:
5 | branches:
6 | - main
7 | paths:
8 | - "fern/**"
9 |
10 | jobs:
11 | preview-docs:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v4
17 | with:
18 | ref: refs/pull/${{ github.event.pull_request.number }}/merge
19 |
20 | - name: Setup Node.js
21 | uses: actions/setup-node@v4
22 | with:
23 | node-version: "18"
24 |
25 | - name: Install Fern
26 | run: npm install -g fern-api
27 |
28 | - name: Generate Documentation Preview with Fern
29 | id: generate_docs
30 | env:
31 | FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
32 | run: |
33 | output=$(fern generate --docs --preview --log-level debug)
34 | echo "$output"
35 | # Extract the URL
36 | preview_url=$(echo "$output" | grep -oP '(?<=Published docs to )https://[^\s]*')
37 | # Set the output for the step
38 | echo "::set-output name=preview_url::$preview_url"
39 | - name: Comment PR with URL using github-actions bot
40 | uses: actions/github-script@v4
41 | if: ${{ steps.generate_docs.outputs.preview_url }}
42 | with:
43 | script: |
44 | const preview_url = '${{ steps.generate_docs.outputs.preview_url }}';
45 | const issue_number = context.issue.number;
46 | github.issues.createComment({
47 | ...context.repo,
48 | issue_number: issue_number,
49 | body: `Published docs preview URL: ${preview_url}`
50 | })
51 |
--------------------------------------------------------------------------------
/scripts/setup:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import argparse
4 |
5 | from huggingface_hub import hf_hub_download, snapshot_download
6 | from transformers import AutoTokenizer
7 |
8 | from private_gpt.paths import models_path, models_cache_path
9 | from private_gpt.settings.settings import settings
10 |
11 | resume_download = True
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser(prog='Setup: Download models from huggingface')
14 | parser.add_argument('--resume', default=True, action=argparse.BooleanOptionalAction, help='Enable/Disable resume_download to restart an interrupted download')
15 | args = parser.parse_args()
16 | resume_download = args.resume
17 |
18 | os.makedirs(models_path, exist_ok=True)
19 |
20 | # Download Embedding model
21 | embedding_path = models_path / "embedding"
22 | print(f"Downloading embedding {settings().local.embedding_hf_model_name}")
23 | snapshot_download(
24 | repo_id=settings().local.embedding_hf_model_name,
25 | cache_dir=models_cache_path,
26 | local_dir=embedding_path,
27 | )
28 | print("Embedding model downloaded!")
29 |
30 | # Download LLM and create a symlink to the model file
31 | print(f"Downloading LLM {settings().local.llm_hf_model_file}")
32 | hf_hub_download(
33 | repo_id=settings().local.llm_hf_repo_id,
34 | filename=settings().local.llm_hf_model_file,
35 | cache_dir=models_cache_path,
36 | local_dir=models_path,
37 | resume_download=resume_download,
38 | )
39 | print("LLM model downloaded!")
40 |
41 | # Download Tokenizer
42 | print(f"Downloading tokenizer {settings().llm.tokenizer}")
43 | AutoTokenizer.from_pretrained(
44 | pretrained_model_name_or_path=settings().llm.tokenizer,
45 | cache_dir=models_cache_path,
46 | )
47 | print("Tokenizer downloaded!")
48 |
49 | print("Setup done")
50 |
--------------------------------------------------------------------------------
/private_gpt/components/embedding/embedding_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from injector import inject, singleton
4 | from llama_index import MockEmbedding
5 | from llama_index.embeddings.base import BaseEmbedding
6 |
7 | from private_gpt.paths import models_cache_path
8 | from private_gpt.settings.settings import Settings
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | @singleton
14 | class EmbeddingComponent:
15 | embedding_model: BaseEmbedding
16 |
17 | @inject
18 | def __init__(self, settings: Settings) -> None:
19 | embedding_mode = settings.embedding.mode
20 | logger.info("Initializing the embedding model in mode=%s", embedding_mode)
21 | match embedding_mode:
22 | case "local":
23 | from llama_index.embeddings import HuggingFaceEmbedding
24 |
25 | self.embedding_model = HuggingFaceEmbedding(
26 | model_name=settings.local.embedding_hf_model_name,
27 | cache_folder=str(models_cache_path),
28 | )
29 | case "sagemaker":
30 |
31 | from private_gpt.components.embedding.custom.sagemaker import (
32 | SagemakerEmbedding,
33 | )
34 |
35 | self.embedding_model = SagemakerEmbedding(
36 | endpoint_name=settings.sagemaker.embedding_endpoint_name,
37 | )
38 | case "openai":
39 | from llama_index import OpenAIEmbedding
40 |
41 | openai_settings = settings.openai.api_key
42 | self.embedding_model = OpenAIEmbedding(api_key=openai_settings)
43 | case "mock":
44 | # Not a random number: 384 is the dimensionality used by
45 | # the default embedding model
46 | self.embedding_model = MockEmbedding(384)
47 |
--------------------------------------------------------------------------------
/private_gpt/settings/settings_loader.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import os
4 | import sys
5 | from collections.abc import Iterable
6 | from pathlib import Path
7 | from typing import Any
8 |
9 | from pydantic.v1.utils import deep_update, unique_list
10 |
11 | from private_gpt.constants import PROJECT_ROOT_PATH
12 | from private_gpt.settings.yaml import load_yaml_with_envvars
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | _settings_folder = os.environ.get("PGPT_SETTINGS_FOLDER", PROJECT_ROOT_PATH)
17 |
18 | # if running in unittest, use the test profile
19 | _test_profile = ["test"] if "tests.fixtures" in sys.modules else []
20 |
21 | active_profiles: list[str] = unique_list(
22 | ["default"]
23 | + [
24 | item.strip()
25 | for item in os.environ.get("PGPT_PROFILES", "").split(",")
26 | if item.strip()
27 | ]
28 | + _test_profile
29 | )
30 |
31 |
32 | def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
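    """Deep-merge the given settings dicts; later entries override earlier ones."""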
33 | return functools.reduce(deep_update, settings, {})
34 |
35 |
36 | def load_settings_from_profile(profile: str) -> dict[str, Any]:
37 | if profile == "default":
38 | profile_file_name = "settings.yaml"
39 | else:
40 | profile_file_name = f"settings-{profile}.yaml"
41 |
42 | path = Path(_settings_folder) / profile_file_name
43 | with Path(path).open("r") as f:
44 | config = load_yaml_with_envvars(f)
45 | if not isinstance(config, dict):
46 | raise TypeError(f"Config file has no top-level mapping: {path}")
47 | return config
48 |
49 |
50 | def load_active_settings() -> dict[str, Any]:
51 | """Load active profiles and merge them."""
52 | logger.info("Starting application with profiles=%s", active_profiles)
53 | loaded_profiles = [
54 | load_settings_from_profile(profile) for profile in active_profiles
55 | ]
56 | merged: dict[str, Any] = merge_settings(loaded_profiles)
57 | return merged
58 |
--------------------------------------------------------------------------------
/tests/server/ingest/test_ingest_routes.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from pathlib import Path
3 |
4 | from fastapi.testclient import TestClient
5 |
6 | from private_gpt.server.ingest.ingest_router import IngestResponse
7 | from tests.fixtures.ingest_helper import IngestHelper
8 |
9 |
10 | def test_ingest_accepts_txt_files(ingest_helper: IngestHelper) -> None:
11 | path = Path(__file__).parents[0] / "test.txt"
12 | ingest_result = ingest_helper.ingest_file(path)
13 | assert len(ingest_result.data) == 1
14 |
15 |
16 | def test_ingest_accepts_pdf_files(ingest_helper: IngestHelper) -> None:
17 | path = Path(__file__).parents[0] / "test.pdf"
18 | ingest_result = ingest_helper.ingest_file(path)
19 | assert len(ingest_result.data) == 1
20 |
21 |
22 | def test_ingest_list_returns_something_after_ingestion(
23 | test_client: TestClient, ingest_helper: IngestHelper
24 | ) -> None:
25 | response_before = test_client.get("/v1/ingest/list")
26 | count_ingest_before = len(response_before.json()["data"])
27 | with tempfile.NamedTemporaryFile("w", suffix=".txt") as test_file:
28 | test_file.write("Foo bar; hello there!")
29 | test_file.flush()
30 | test_file.seek(0)
31 | ingest_result = ingest_helper.ingest_file(Path(test_file.name))
32 | assert len(ingest_result.data) == 1, "The temp doc should have been ingested"
33 | response_after = test_client.get("/v1/ingest/list")
34 | count_ingest_after = len(response_after.json()["data"])
35 | assert (
36 | count_ingest_after == count_ingest_before + 1
37 | ), "The temp doc should be returned"
38 |
39 |
40 | def test_ingest_plain_text(test_client: TestClient) -> None:
41 | response = test_client.post(
42 | "/v1/ingest/text", json={"file_name": "file_name", "text": "text"}
43 | )
44 | assert response.status_code == 200
45 | ingest_result = IngestResponse.model_validate(response.json())
46 | assert len(ingest_result.data) == 1
47 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref || github.ref }}
11 | cancel-in-progress: ${{ github.event_name == 'pull_request' }}
12 |
13 | jobs:
14 | setup:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v3
18 | - uses: ./.github/workflows/actions/install_dependencies
19 |
20 | checks:
21 | needs: setup
22 | runs-on: ubuntu-latest
23 | name: ${{ matrix.quality-command }}
24 | strategy:
25 | matrix:
26 | quality-command:
27 | - black
28 | - ruff
29 | - mypy
30 | steps:
31 | - uses: actions/checkout@v3
32 | - uses: ./.github/workflows/actions/install_dependencies
33 | - name: run ${{ matrix.quality-command }}
34 | run: make ${{ matrix.quality-command }}
35 |
36 | test:
37 | needs: setup
38 | runs-on: ubuntu-latest
39 | name: test
40 | steps:
41 | - uses: actions/checkout@v3
42 | - uses: ./.github/workflows/actions/install_dependencies
43 | - name: run test
44 | run: make test-coverage
45 | # Run even if make test fails for coverage reports
46 | # TODO: select a better xml results displayer
47 | - name: Archive test results coverage results
48 | uses: actions/upload-artifact@v3
49 | if: always()
50 | with:
51 | name: test_results
52 | path: tests-results.xml
53 | - name: Archive code coverage results
54 | uses: actions/upload-artifact@v3
55 | if: always()
56 | with:
57 | name: code-coverage-report
58 | path: htmlcov/
59 |
60 | all_checks_passed:
61 | # Used to easily force requirements checks in GitHub
62 | needs:
63 | - checks
64 | - test
65 | runs-on: ubuntu-latest
66 | steps:
67 | - run: echo "All checks passed"
68 |
--------------------------------------------------------------------------------
/tests/server/utils/test_simple_auth.py:
--------------------------------------------------------------------------------
1 | """Tests to validate that the simple authentication mechanism is working.
2 |
3 | NOTE: We are not testing the switch based on the config in
4 | `private_gpt.server.utils.auth`. This is not done because of the way the code
5 | is currently architected (it is hard to patch the `settings` and the app while
6 | the tests are directly importing them).
7 | """
8 | from typing import Annotated
9 |
10 | import pytest
11 | from fastapi import Depends
12 | from fastapi.testclient import TestClient
13 |
14 | from private_gpt.server.utils.auth import (
15 | NOT_AUTHENTICATED,
16 | _simple_authentication,
17 | authenticated,
18 | )
19 | from private_gpt.settings.settings import settings
20 |
21 |
22 | def _copy_simple_authenticated(
23 | _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
24 | ) -> bool:
25 | """Check if the request is authenticated."""
26 | if not _simple_authentication:
27 | raise NOT_AUTHENTICATED
28 | return True
29 |
30 |
31 | @pytest.fixture(autouse=True)
32 | def _patch_authenticated_dependency(test_client: TestClient):
33 | # Patch the server to use simple authentication
34 |
35 | test_client.app.dependency_overrides[authenticated] = _copy_simple_authenticated
36 |
37 | # Call the actual test
38 | yield
39 |
40 | # Remove the patch for other tests
41 | test_client.app.dependency_overrides = {}
42 |
43 |
44 | def test_default_auth_working_when_enabled_401(test_client: TestClient) -> None:
45 | response = test_client.get("/v1/ingest/list")
46 | assert response.status_code == 401
47 |
48 |
49 | def test_default_auth_working_when_enabled_200(test_client: TestClient) -> None:
50 | response_fail = test_client.get("/v1/ingest/list")
51 | assert response_fail.status_code == 401
52 |
53 | response_success = test_client.get(
54 | "/v1/ingest/list", headers={"Authorization": settings().server.auth.secret}
55 | )
56 | assert response_success.status_code == 200
57 |
--------------------------------------------------------------------------------
/private_gpt/launcher.py:
--------------------------------------------------------------------------------
1 | """FastAPI app creation, logger configuration and main API routes."""
2 | import logging
3 |
4 | from fastapi import Depends, FastAPI, Request
5 | from fastapi.middleware.cors import CORSMiddleware
6 | from injector import Injector
7 |
8 | from private_gpt.server.chat.chat_router import chat_router
9 | from private_gpt.server.chunks.chunks_router import chunks_router
10 | from private_gpt.server.completions.completions_router import completions_router
11 | from private_gpt.server.embeddings.embeddings_router import embeddings_router
12 | from private_gpt.server.health.health_router import health_router
13 | from private_gpt.server.ingest.ingest_router import ingest_router
14 | from private_gpt.settings.settings import Settings
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def create_app(root_injector: Injector) -> FastAPI:
20 |
21 | # Start the API
22 | async def bind_injector_to_request(request: Request) -> None:
23 | request.state.injector = root_injector
24 |
25 | app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
26 |
27 | app.include_router(completions_router)
28 | app.include_router(chat_router)
29 | app.include_router(chunks_router)
30 | app.include_router(ingest_router)
31 | app.include_router(embeddings_router)
32 | app.include_router(health_router)
33 |
34 | settings = root_injector.get(Settings)
35 | if settings.server.cors.enabled:
36 | logger.debug("Setting up CORS middleware")
37 | app.add_middleware(
38 | CORSMiddleware,
39 | allow_credentials=settings.server.cors.allow_credentials,
40 | allow_origins=settings.server.cors.allow_origins,
41 | allow_origin_regex=settings.server.cors.allow_origin_regex,
42 | allow_methods=settings.server.cors.allow_methods,
43 | allow_headers=settings.server.cors.allow_headers,
44 | )
45 |
46 | if settings.ui.enabled:
47 | logger.debug("Importing the UI module")
48 | from private_gpt.ui.ui import PrivateGptUi
49 |
50 | ui = root_injector.get(PrivateGptUi)
51 | ui.mount_in_app(app, settings.ui.path)
52 |
53 | return app
54 |
--------------------------------------------------------------------------------
/settings.yaml:
--------------------------------------------------------------------------------
1 | # The default configuration file.
2 | # More information about configuration can be found in the documentation: https://docs.privategpt.dev/
3 | # Syntax in `private_gpt/settings/settings.py`
4 | server:
5 | env_name: ${APP_ENV:prod}
6 | port: ${PORT:8001}
7 | cors:
8 | enabled: false
9 | allow_origins: ["*"]
10 | allow_methods: ["*"]
11 | allow_headers: ["*"]
12 | auth:
13 | enabled: false
14 | # python -c 'import base64; print("Basic " + base64.b64encode("secret:key".encode()).decode())'
15 | # 'secret' is the username and 'key' is the password for basic auth by default
16 | # If the auth is enabled, this value must be set in the "Authorization" header of the request.
17 | secret: "Basic c2VjcmV0OmtleQ=="
18 |
19 | data:
20 | local_data_folder: local_data/private_gpt
21 |
22 | ui:
23 | enabled: true
24 | path: /
25 | default_chat_system_prompt: >
26 | You are a helpful, respectful and honest assistant.
27 | Always answer as helpfully as possible and follow ALL given instructions.
28 | Do not speculate or make up information.
29 | Do not reference any given instructions or context.
30 | default_query_system_prompt: >
31 | You can only answer questions about the provided context.
32 | If you know the answer but it is not based in the provided context, don't provide
33 | the answer, just state the answer is not in the context provided.
34 |
35 | llm:
36 | mode: local
37 | # Should be matching the selected model
38 | max_new_tokens: 512
39 | context_window: 3900
40 | tokenizer: mistralai/Mistral-7B-Instruct-v0.2
41 |
42 | embedding:
43 | # Should be matching the value above in most cases
44 | mode: local
45 | ingest_mode: simple
46 |
47 | vectorstore:
48 | database: qdrant
49 |
50 | qdrant:
51 | path: local_data/private_gpt/qdrant
52 |
53 | local:
54 | prompt_style: "mistral"
55 | llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
56 | llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
57 | embedding_hf_model_name: BAAI/bge-small-en-v1.5
58 |
59 | sagemaker:
60 | llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
61 | embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
62 |
63 | openai:
64 | api_key: ${OPENAI_API_KEY:}
65 | model: gpt-3.5-turbo
66 |
--------------------------------------------------------------------------------
/private_gpt/server/chunks/chunks_router.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from fastapi import APIRouter, Depends, Request
4 | from pydantic import BaseModel, Field
5 |
6 | from private_gpt.open_ai.extensions.context_filter import ContextFilter
7 | from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
8 | from private_gpt.server.utils.auth import authenticated
9 |
10 | chunks_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
11 |
12 |
13 | class ChunksBody(BaseModel):
14 | text: str = Field(examples=["Q3 2023 sales"])
15 | context_filter: ContextFilter | None = None
16 | limit: int = 10
17 | prev_next_chunks: int = Field(default=0, examples=[2])
18 |
19 |
20 | class ChunksResponse(BaseModel):
21 | object: Literal["list"]
22 | model: Literal["private-gpt"]
23 | data: list[Chunk]
24 |
25 |
26 | @chunks_router.post("/chunks", tags=["Context Chunks"])
27 | def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
28 | """Given a `text`, returns the most relevant chunks from the ingested documents.
29 |
30 | The returned information can be used to generate prompts that can be
31 | passed to `/completions` or `/chat/completions` APIs. Note: it is usually a very
32 | fast API, because only the Embeddings model is involved, not the LLM. The
33 | returned information contains the relevant chunk `text` together with the source
34 | `document` it is coming from. It also contains a score that can be used to
35 | compare different results.
36 |
37 | The max number of chunks to be returned is set using the `limit` param.
38 |
39 | Previous and next chunks (pieces of text that appear right before or after in the
40 | document) can be fetched by using the `prev_next_chunks` field.
41 |
42 | The documents being used can be filtered using the `context_filter` and passing
43 | the document IDs to be used. Ingested documents IDs can be found using
44 | `/ingest/list` endpoint. If you want all ingested documents to be used,
45 | remove `context_filter` altogether.
46 | """
47 | service = request.state.injector.get(ChunksService)
48 | results = service.retrieve_relevant(
49 | body.text, body.context_filter, body.limit, body.prev_next_chunks
50 | )
51 | return ChunksResponse(
52 | object="list",
53 | model="private-gpt",
54 | data=results,
55 | )
56 |
--------------------------------------------------------------------------------
/private_gpt/components/ingest/ingest_helper.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 |
4 | from llama_index import Document
5 | from llama_index.readers import JSONReader, StringIterableReader
6 | from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | # Patching the default file reader to support other file types
11 | FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
12 | FILE_READER_CLS.update(
13 | {
14 | ".json": JSONReader,
15 | }
16 | )
17 |
18 |
19 | class IngestionHelper:
20 | """Helper class to transform a file into a list of documents.
21 |
22 | This class should be used to transform a file into a list of documents.
23 | These methods are thread-safe (and multiprocessing-safe).
24 | """
25 |
26 | @staticmethod
27 | def transform_file_into_documents(
28 | file_name: str, file_data: Path
29 | ) -> list[Document]:
30 | documents = IngestionHelper._load_file_to_documents(file_name, file_data)
31 | for document in documents:
32 | document.metadata["file_name"] = file_name
33 | IngestionHelper._exclude_metadata(documents)
34 | return documents
35 |
36 | @staticmethod
37 | def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
38 | logger.debug("Transforming file_name=%s into documents", file_name)
39 | extension = Path(file_name).suffix
40 | reader_cls = FILE_READER_CLS.get(extension)
41 | if reader_cls is None:
42 | logger.debug(
43 | "No reader found for extension=%s, using default string reader",
44 | extension,
45 | )
46 | # Read as a plain text
47 | string_reader = StringIterableReader()
48 | return string_reader.load_data([file_data.read_text()])
49 |
50 | logger.debug("Specific reader found for extension=%s", extension)
51 | return reader_cls().load_data(file_data)
52 |
53 | @staticmethod
54 | def _exclude_metadata(documents: list[Document]) -> None:
55 | logger.debug("Excluding metadata from count=%s documents", len(documents))
56 | for document in documents:
57 | document.metadata["doc_id"] = document.doc_id
58 | # We don't want the Embeddings search to receive this metadata
59 | document.excluded_embed_metadata_keys = ["doc_id"]
60 | # We don't want the LLM to receive these metadata in the context
61 | document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]
62 |
--------------------------------------------------------------------------------
/fern/docs/pages/manual/vectordb.mdx:
--------------------------------------------------------------------------------
1 | ## Vectorstores
2 | PrivateGPT supports [Qdrant](https://qdrant.tech/) and [Chroma](https://www.trychroma.com/) as vectorstore providers, Qdrant being the default.
3 |
4 | In order to select one or the other, set the `vectorstore.database` property in the `settings.yaml` file to `qdrant` or `chroma`.
5 |
6 | ```yaml
7 | vectorstore:
8 | database: qdrant
9 | ```
10 |
11 | ### Qdrant configuration
12 |
13 | To enable Qdrant, set the `vectorstore.database` property in the `settings.yaml` file to `qdrant`.
14 |
15 | Qdrant settings can be configured by setting values to the `qdrant` property in the `settings.yaml` file.
16 |
17 | The available configuration options are:
18 | | Field | Description |
19 | |--------------|-------------|
20 | | location | If `:memory:` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.|
21 | | url | Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'. Eg. `http://localhost:6333` |
22 | | port | Port of the REST API interface. Default: `6333` |
23 | | grpc_port | Port of the gRPC interface. Default: `6334` |
24 | | prefer_grpc | If `true` - use gRPC interface whenever possible in custom methods. |
25 | | https | If `true` - use HTTPS(SSL) protocol.|
26 | | api_key | API key for authentication in Qdrant Cloud.|
27 | | prefix | If set, add `prefix` to the REST URL path. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API.|
28 | | timeout | Timeout for REST and gRPC API requests. Default: 5.0 seconds for REST and unlimited for gRPC |
29 | | host | Host name of Qdrant service. If url and host are not set, defaults to 'localhost'.|
30 | | path | Persistence path for QdrantLocal. Eg. `local_data/private_gpt/qdrant`|
31 | | force_disable_check_same_thread | Force disable check_same_thread for QdrantLocal sqlite connection, defaults to True.|
32 |
33 | By default, Qdrant tries to connect to a Qdrant server instance at `http://localhost:6333`.
34 |
35 | To obtain a local setup (disk-based database) without running a Qdrant server, configure the `qdrant.path` value in settings.yaml:
36 |
37 | ```yaml
38 | qdrant:
39 | path: local_data/private_gpt/qdrant
40 | ```
41 |
42 | ### Chroma configuration
43 |
44 | To enable Chroma, set the `vectorstore.database` property in the `settings.yaml` file to `chroma` and install the `chroma` extra.
45 |
46 | ```bash
47 | poetry install --extras chroma
48 | ```
49 |
50 | By default, `chroma` will use a disk-based database stored in `local_data_path / "chroma_db"` (with `local_data_path` defined in `settings.yaml`).
--------------------------------------------------------------------------------
/tests/server/ingest/test.txt:
--------------------------------------------------------------------------------
1 | Once upon a time, in a magical forest called Enchantia, lived a young and cheerful deer named Zumi. Zumi was no ordinary deer; she was bright-eyed, intelligent, and had a heart full of curiosity. One sunny morning, as the forest came alive with the sweet melodies of chirping birds and rustling leaves, Zumi eagerly pranced through the woods on her way to school.
2 |
3 | Enchantia Forest School was a unique place, where all the woodland creatures gathered to learn and grow together. The school was nestled in a clearing surrounded by tall, ancient trees. Zumi loved the feeling of anticipation as she approached the school, her hooves barely touching the ground in excitement.
4 |
5 | As she arrived at the school, her dear friend and classmate, Oliver the wise old owl, greeted her with a friendly hoot. "Good morning, Zumi! Are you ready for another day of adventure and learning?"
6 |
7 | Zumi's eyes sparkled with enthusiasm as she nodded, "Absolutely, Oliver! I can't wait to see what we'll discover today."
8 |
9 | In their classroom, Teacher Willow, a gentle and nurturing willow tree, welcomed the students. The classroom was adorned with vibrant leaves and twinkling fireflies, creating a magical and cozy atmosphere. Today's lesson was about the history of the forest and the importance of living harmoniously with nature.
10 |
11 | The students listened attentively as Teacher Willow recounted stories of ancient times when the forest thrived in unity and peace. Zumi was particularly enthralled by the tales of forest guardians and how they protected the magical balance of Enchantia.
12 |
13 | After the lesson, it was time for recess. Zumi joined her friends in a lively game of tag, where they darted and danced playfully among the trees. Zumi's speed and agility made her an excellent tagger, and laughter filled the air as they played.
14 |
15 | Later, they gathered for an art class, where they expressed themselves through painting and sculpting with clay. Zumi chose to paint a mural of the forest, portraying the beauty and magic they were surrounded by every day.
16 |
17 | As the day came to an end, the students sat in a circle to share stories and reflections. Zumi shared her excitement for the day and how she learned to appreciate the interconnectedness of all creatures in the forest.
18 |
19 | As the sun set, casting a golden glow across the forest, Zumi made her way back home, her heart brimming with happiness and newfound knowledge. Each day at Enchantia Forest School was an adventure, and Zumi couldn't wait to learn more and grow with her friends, for the magic of learning was as boundless as the forest itself. And so, under the canopy of stars and the watchful eyes of the forest, Zumi drifted into dreams filled with wonder and anticipation for the adventures that awaited her on the morrow.
--------------------------------------------------------------------------------
/private_gpt/server/utils/auth.py:
--------------------------------------------------------------------------------
1 | """Authentication mechanism for the API.
2 |
3 | Define a simple mechanism to authenticate requests.
4 | More complex authentication mechanisms can be defined here, and be placed in the
5 | `authenticated` method (being a 'bean' injected in fastapi routers).
6 |
7 | Authorization can also be made after the authentication, and depends on
8 | the authentication. Authorization should not be implemented in this file.
9 |
10 | Authorization can be done by following fastapi's guides:
11 | * https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/
12 | * https://fastapi.tiangolo.com/tutorial/security/
13 | * https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/
14 | """
15 | # mypy: ignore-errors
16 | # Disabled mypy error: All conditional function variants must have identical signatures
17 | # We are changing the implementation of the authenticated method, based on
18 | # the config. If the auth is not enabled, we are not defining the complex method
19 | # with its dependencies.
20 | import logging
21 | import secrets
22 | from typing import Annotated
23 |
24 | from fastapi import Depends, Header, HTTPException
25 |
26 | from private_gpt.settings.settings import settings
27 |
28 | # 401 signifies that the request requires authentication.
29 | # 403 signifies that the authenticated user is not authorized to perform the operation.
30 | NOT_AUTHENTICATED = HTTPException(
31 | status_code=401,
32 | detail="Not authenticated",
33 | headers={"WWW-Authenticate": 'Basic realm="All the API", charset="UTF-8"'},
34 | )
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
40 | """Check if the request is authenticated."""
41 | if not secrets.compare_digest(authorization, settings().server.auth.secret):
42 | # If the "Authorization" header is not the expected one, raise an exception.
43 | raise NOT_AUTHENTICATED
44 | return True
45 |
46 |
47 | if not settings().server.auth.enabled:
48 | logger.debug(
49 | "Defining a dummy authentication mechanism for fastapi, always authenticating requests"
50 | )
51 |
52 | # Define a dummy authentication method that always returns True.
53 | def authenticated() -> bool:
54 | """Check if the request is authenticated."""
55 | return True
56 |
57 | else:
58 | logger.info("Defining the given authentication mechanism for the API")
59 |
60 | # Method to be used as a dependency to check if the request is authenticated.
61 | def authenticated(
62 | _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
63 | ) -> bool:
64 | """Check if the request is authenticated."""
65 | assert settings().server.auth.enabled
66 | if not _simple_authentication:
67 | raise NOT_AUTHENTICATED
68 | return True
69 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Any args passed to the make script, use with $(call args, default_value)
2 | args = `arg="$(filter-out $@,$(MAKECMDGOALS))" && echo $${arg:-${1}}`
3 |
4 | ########################################################################################################################
5 | # Quality checks
6 | ########################################################################################################################
7 |
8 | test:
9 | PYTHONPATH=. poetry run pytest tests
10 |
11 | test-coverage:
12 | PYTHONPATH=. poetry run pytest tests --cov private_gpt --cov-report term --cov-report=html --cov-report xml --junit-xml=tests-results.xml
13 |
14 | black:
15 | poetry run black . --check
16 |
17 | ruff:
18 | poetry run ruff check private_gpt tests
19 |
20 | format:
21 | poetry run black .
22 | poetry run ruff check private_gpt tests --fix
23 |
24 | mypy:
25 | poetry run mypy private_gpt
26 |
27 | check:
28 | make format
29 | make mypy
30 |
31 | ########################################################################################################################
32 | # Run
33 | ########################################################################################################################
34 |
35 | run:
36 | poetry run python -m private_gpt
37 |
38 | dev-windows:
39 | (set PGPT_PROFILES=local & poetry run python -m uvicorn private_gpt.main:app --reload --port 8001)
40 |
41 | dev:
42 | PYTHONUNBUFFERED=1 PGPT_PROFILES=local poetry run python -m uvicorn private_gpt.main:app --reload --port 8001
43 |
44 | ########################################################################################################################
45 | # Misc
46 | ########################################################################################################################
47 |
48 | api-docs:
49 | PGPT_PROFILES=mock poetry run python scripts/extract_openapi.py private_gpt.main:app --out fern/openapi/openapi.json
50 |
51 | ingest:
52 | @poetry run python scripts/ingest_folder.py $(call args)
53 |
54 | wipe:
55 | poetry run python scripts/utils.py wipe
56 |
57 | setup:
58 | poetry run python scripts/setup
59 |
60 | list:
61 | @echo "Available commands:"
62 | @echo " test : Run tests using pytest"
63 | @echo " test-coverage : Run tests with coverage report"
64 | @echo " black : Check code format with black"
65 | @echo " ruff : Check code with ruff"
66 | @echo " format : Format code with black and ruff"
67 | @echo " mypy : Run mypy for type checking"
68 | @echo " check : Run format and mypy commands"
69 | @echo " run : Run the application"
70 | @echo " dev-windows : Run the application in development mode on Windows"
71 | @echo " dev : Run the application in development mode"
72 | @echo " api-docs : Generate API documentation"
73 | @echo " ingest : Ingest data using specified script"
74 | @echo " wipe : Wipe data using specified script"
75 | @echo " setup : Setup the application"
76 |
--------------------------------------------------------------------------------
/private_gpt/components/vector_store/batched_chroma.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from llama_index.schema import BaseNode, MetadataMode
4 | from llama_index.vector_stores import ChromaVectorStore
5 | from llama_index.vector_stores.chroma import chunk_list
6 | from llama_index.vector_stores.utils import node_to_metadata_dict
7 |
8 |
9 | class BatchedChromaVectorStore(ChromaVectorStore):
10 | """Chroma vector store, batching additions to avoid reaching the max batch limit.
11 |
12 | In this vector store, embeddings are stored within a ChromaDB collection.
13 |
14 | During query time, the index uses ChromaDB to query for the top
15 | k most similar nodes.
16 |
17 | Args:
18 | chroma_client (from chromadb.api.API):
19 | API instance
20 | chroma_collection (chromadb.api.models.Collection.Collection):
21 | ChromaDB collection instance
22 |
23 | """
24 |
25 | chroma_client: Any | None
26 |
27 | def __init__(
28 | self,
29 | chroma_client: Any,
30 | chroma_collection: Any,
31 | host: str | None = None,
32 | port: str | None = None,
33 | ssl: bool = False,
34 | headers: dict[str, str] | None = None,
35 | collection_kwargs: dict[Any, Any] | None = None,
36 | ) -> None:
37 | super().__init__(
38 | chroma_collection=chroma_collection,
39 | host=host,
40 | port=port,
41 | ssl=ssl,
42 | headers=headers,
43 | collection_kwargs=collection_kwargs or {},
44 | )
45 | self.chroma_client = chroma_client
46 |
47 | def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
48 | """Add nodes to index, batching the insertion to avoid issues.
49 |
50 | Args:
51 | nodes: List[BaseNode]: list of nodes with embeddings
52 | add_kwargs: _
53 | """
54 | if not self.chroma_client:
55 | raise ValueError("Client not initialized")
56 |
57 | if not self._collection:
58 | raise ValueError("Collection not initialized")
59 |
60 | max_chunk_size = self.chroma_client.max_batch_size
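        # Chroma rejects inserts larger than max_batch_size, so split the nodes into smaller chunks.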
61 | node_chunks = chunk_list(nodes, max_chunk_size)
62 |
63 | all_ids = []
64 | for node_chunk in node_chunks:
65 | embeddings = []
66 | metadatas = []
67 | ids = []
68 | documents = []
69 | for node in node_chunk:
70 | embeddings.append(node.get_embedding())
71 | metadatas.append(
72 | node_to_metadata_dict(
73 | node, remove_text=True, flat_metadata=self.flat_metadata
74 | )
75 | )
76 | ids.append(node.node_id)
77 | documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
78 |
79 | self._collection.add(
80 | embeddings=embeddings,
81 | ids=ids,
82 | metadatas=metadatas,
83 | documents=documents,
84 | )
85 | all_ids.extend(ids)
86 |
87 | return all_ids
88 |
--------------------------------------------------------------------------------
/private_gpt/components/embedding/custom/sagemaker.py:
--------------------------------------------------------------------------------
1 | # mypy: ignore-errors
2 | import json
3 | from typing import Any
4 |
5 | import boto3
6 | from llama_index.embeddings.base import BaseEmbedding
7 | from pydantic import Field, PrivateAttr
8 |
9 |
10 | class SagemakerEmbedding(BaseEmbedding):
11 | """Sagemaker Embedding Endpoint.
12 |
13 | To use, you must supply the endpoint name from your deployed
14 | Sagemaker embedding model & the region where it is deployed.
15 |
16 | To authenticate, the AWS client uses the following methods to
17 | automatically load credentials:
18 | https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
19 |
20 | If a specific credential profile should be used, you must pass
21 | the name of the profile from the ~/.aws/credentials file that is to be used.
22 |
23 | Make sure the credentials / roles used have the required policies to
24 | access the Sagemaker endpoint.
25 | See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
26 | """
27 |
28 | endpoint_name: str = Field(description="Name of the deployed Sagemaker embedding endpoint")
29 |
30 | _boto_client: Any = boto3.client(
31 | "sagemaker-runtime",
32 | ) # TODO make it an optional field
33 |
34 | _async_not_implemented_warned: bool = PrivateAttr(default=False)
35 |
36 | @classmethod
37 | def class_name(cls) -> str:
38 | return "SagemakerEmbedding"
39 |
40 | def _async_not_implemented_warn_once(self) -> None:
41 | if not self._async_not_implemented_warned:
42 | print("Async embedding not available, falling back to sync method.")
43 | self._async_not_implemented_warned = True
44 |
45 | def _embed(self, sentences: list[str]) -> list[list[float]]:
46 | request_params = {
47 | "inputs": sentences,
48 | }
49 |
50 | resp = self._boto_client.invoke_endpoint(
51 | EndpointName=self.endpoint_name,
52 | Body=json.dumps(request_params),
53 | ContentType="application/json",
54 | )
55 |
56 | response_body = resp["Body"]
57 | response_str = response_body.read().decode("utf-8")
58 | response_json = json.loads(response_str)
59 |
60 | return response_json["vectors"]
61 |
62 | def _get_query_embedding(self, query: str) -> list[float]:
63 | """Get query embedding."""
64 | return self._embed([query])[0]
65 |
66 | async def _aget_query_embedding(self, query: str) -> list[float]:
67 | # Warn the user that sync is being used
68 | self._async_not_implemented_warn_once()
69 | return self._get_query_embedding(query)
70 |
71 | async def _aget_text_embedding(self, text: str) -> list[float]:
72 | # Warn the user that sync is being used
73 | self._async_not_implemented_warn_once()
74 | return self._get_text_embedding(text)
75 |
76 | def _get_text_embedding(self, text: str) -> list[float]:
77 | """Get text embedding."""
78 | return self._embed([text])[0]
79 |
80 | def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
81 | """Get text embeddings."""
82 | return self._embed(texts)
83 |
--------------------------------------------------------------------------------
/fern/docs/pages/manual/ui.mdx:
--------------------------------------------------------------------------------
1 | ## Gradio UI user manual
2 |
3 | Gradio UI is a ready-to-use way of testing most of the PrivateGPT API functionalities.
4 |
5 | 
6 |
7 | ### Execution Modes
8 |
9 | It has 3 modes of execution (you can select them in the top-left); the equivalent raw API calls are sketched after the list:
10 |
11 | * Query Docs: uses the context from the
12 | ingested documents to answer the questions posted in the chat. It also takes
13 | into account previous chat messages as context.
14 | * Makes use of `/chat/completions` API with `use_context=true` and no
15 | `context_filter`.
16 | * Search in Docs: fast search that returns the 4 most related text
17 | chunks, together with their source document and page.
18 | * Makes use of `/chunks` API with no `context_filter`, `limit=4` and
19 | `prev_next_chunks=0`.
20 | * LLM Chat: simple, non-contextual chat with the LLM. The ingested documents won't
21 | be taken into account, only the previous messages.
22 | * Makes use of `/chat/completions` API with `use_context=false`.
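
For reference, here is a rough, illustrative sketch of the raw API calls behind each mode, assuming the `requests`
package and the default server port `8001` from `settings.yaml` (field names follow the API descriptions above):

```python
import requests

base_url = "http://localhost:8001"

# "Query Docs": chat completion grounded in the ingested documents.
requests.post(f"{base_url}/v1/chat/completions", json={
    "messages": [{"role": "user", "content": "What does the ingested report say about Q3 sales?"}],
    "use_context": True,
})

# "Search in Docs": return the 4 most relevant chunks, with no surrounding chunks.
requests.post(f"{base_url}/v1/chunks", json={"text": "Q3 2023 sales", "limit": 4, "prev_next_chunks": 0})

# "LLM Chat": plain chat with the LLM, ignoring the ingested documents.
requests.post(f"{base_url}/v1/chat/completions", json={
    "messages": [{"role": "user", "content": "Hello!"}],
    "use_context": False,
})
```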
23 |
24 | ### Document Ingestion
25 |
26 | Ingest documents by using the `Upload a File` button. You can check the progress of
27 | the ingestion in the console logs of the server.
28 |
29 | The list of ingested files is shown below the button.
30 |
31 | If you want to delete the ingested documents, refer to the *Reset Local documents
32 | database* section in the documentation.
33 |
34 | ### Chat
35 |
36 | Normal chat interface, self-explanatory ;)
37 |
38 | #### System Prompt
39 | You can view and change the system prompt being passed to the LLM by clicking "Additional Inputs"
40 | in the chat interface. The system prompt is also logged on the server.
41 |
42 | By default, the `Query Docs` mode uses the setting value `ui.default_query_system_prompt`.
43 |
44 | The `LLM Chat` mode attempts to use the optional settings value `ui.default_chat_system_prompt`.
45 |
46 | If no system prompt is entered, the UI will display the default system prompt being used
47 | for the active mode.
48 |
49 | ##### System Prompt Examples:
50 |
51 | The system prompt can effectively give your chat bot specialized roles, and results tailored to the prompt
52 | you have given the model. Examples of system prompts can be found
53 | [here](https://www.w3schools.com/gen_ai/chatgpt-3-5/chatgpt-3-5_roles.php).
54 |
55 | Some interesting examples to try include:
56 |
57 | * You are -X-. You have all the knowledge and personality of -X-. Answer as if you were -X- using
58 | their manner of speaking and vocabulary.
59 | * Example: You are Shakespeare. You have all the knowledge and personality of Shakespeare.
60 | Answer as if you were Shakespeare using their manner of speaking and vocabulary.
61 | * You are an expert (at) -role-. Answer all questions using your expertise on -specific domain topic-.
62 | * Example: You are an expert software engineer. Answer all questions using your expertise on Python.
63 | * You are a -role- bot, respond with -response criteria needed-. If no -response criteria- is needed,
64 | respond with -alternate response-.
65 | * Example: You are a grammar checking bot, respond with any grammatical corrections needed. If no corrections
66 | are needed, respond with "verified".
--------------------------------------------------------------------------------
/private_gpt/components/llm/llm_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from injector import inject, singleton
4 | from llama_index import set_global_tokenizer
5 | from llama_index.llms import MockLLM
6 | from llama_index.llms.base import LLM
7 | from transformers import AutoTokenizer # type: ignore
8 |
9 | from private_gpt.components.llm.prompt_helper import get_prompt_style
10 | from private_gpt.paths import models_cache_path, models_path
11 | from private_gpt.settings.settings import Settings
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | @singleton
17 | class LLMComponent:
18 | llm: LLM
19 |
20 | @inject
21 | def __init__(self, settings: Settings) -> None:
22 | llm_mode = settings.llm.mode
23 | if settings.llm.tokenizer:
24 | set_global_tokenizer(
25 | AutoTokenizer.from_pretrained(
26 | pretrained_model_name_or_path=settings.llm.tokenizer,
27 | cache_dir=str(models_cache_path),
28 | )
29 | )
30 |
31 | logger.info("Initializing the LLM in mode=%s", llm_mode)
32 | match settings.llm.mode:
33 | case "local":
34 | from llama_index.llms import LlamaCPP
35 |
36 | prompt_style = get_prompt_style(settings.local.prompt_style)
37 |
38 | self.llm = LlamaCPP(
39 | model_path=str(models_path / settings.local.llm_hf_model_file),
40 | temperature=0.1,
41 | max_new_tokens=settings.llm.max_new_tokens,
42 | context_window=settings.llm.context_window,
43 | generate_kwargs={},
44 | # All to GPU
45 | model_kwargs={"n_gpu_layers": -1, "offload_kqv": True},
46 | # transform inputs into Llama2 format
47 | messages_to_prompt=prompt_style.messages_to_prompt,
48 | completion_to_prompt=prompt_style.completion_to_prompt,
49 | verbose=True,
50 | )
51 |
52 | case "sagemaker":
53 | from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
54 |
55 | self.llm = SagemakerLLM(
56 | endpoint_name=settings.sagemaker.llm_endpoint_name,
57 | max_new_tokens=settings.llm.max_new_tokens,
58 | context_window=settings.llm.context_window,
59 | )
60 | case "openai":
61 | from llama_index.llms import OpenAI
62 |
63 | openai_settings = settings.openai
64 | self.llm = OpenAI(
65 | api_base=openai_settings.api_base,
66 | api_key=openai_settings.api_key,
67 | model=openai_settings.model,
68 | )
69 | case "openailike":
70 | from llama_index.llms import OpenAILike
71 |
72 | openai_settings = settings.openai
73 | self.llm = OpenAILike(
74 | api_base=openai_settings.api_base,
75 | api_key=openai_settings.api_key,
76 | model=openai_settings.model,
77 | is_chat_model=True,
78 | max_tokens=None,
79 | api_version="",
80 | )
81 | case "mock":
82 | self.llm = MockLLM()
83 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## [0.2.0](https://github.com/imartinez/privateGPT/compare/v0.1.0...v0.2.0) (2023-12-10)
4 |
5 |
6 | ### Features
7 |
8 | * **llm:** drop default_system_prompt ([#1385](https://github.com/imartinez/privateGPT/issues/1385)) ([a3ed14c](https://github.com/imartinez/privateGPT/commit/a3ed14c58f77351dbd5f8f2d7868d1642a44f017))
9 | * **ui:** Allows User to Set System Prompt via "Additional Options" in Chat Interface ([#1353](https://github.com/imartinez/privateGPT/issues/1353)) ([145f3ec](https://github.com/imartinez/privateGPT/commit/145f3ec9f41c4def5abf4065a06fb0786e2d992a))
10 |
11 | ## [0.1.0](https://github.com/imartinez/privateGPT/compare/v0.0.2...v0.1.0) (2023-11-30)
12 |
13 |
14 | ### Features
15 |
16 | * Disable Gradio Analytics ([#1165](https://github.com/imartinez/privateGPT/issues/1165)) ([6583dc8](https://github.com/imartinez/privateGPT/commit/6583dc84c082773443fc3973b1cdf8095fa3fec3))
17 | * Drop loguru and use builtin `logging` ([#1133](https://github.com/imartinez/privateGPT/issues/1133)) ([64c5ae2](https://github.com/imartinez/privateGPT/commit/64c5ae214a9520151c9c2d52ece535867d799367))
18 | * enable resume download for hf_hub_download ([#1249](https://github.com/imartinez/privateGPT/issues/1249)) ([4197ada](https://github.com/imartinez/privateGPT/commit/4197ada6267c822f32c1d7ba2be6e7ce145a3404))
19 | * move torch and transformers to local group ([#1172](https://github.com/imartinez/privateGPT/issues/1172)) ([0d677e1](https://github.com/imartinez/privateGPT/commit/0d677e10b970aec222ec04837d0f08f1631b6d4a))
20 | * Qdrant support ([#1228](https://github.com/imartinez/privateGPT/issues/1228)) ([03d1ae6](https://github.com/imartinez/privateGPT/commit/03d1ae6d70dffdd2411f0d4e92f65080fff5a6e2))
21 |
22 |
23 | ### Bug Fixes
24 |
25 | * Docker and sagemaker setup ([#1118](https://github.com/imartinez/privateGPT/issues/1118)) ([895588b](https://github.com/imartinez/privateGPT/commit/895588b82a06c2bc71a9e22fb840c7f6442a3b5b))
26 | * fix pytorch version to avoid wheel bug ([#1123](https://github.com/imartinez/privateGPT/issues/1123)) ([24cfddd](https://github.com/imartinez/privateGPT/commit/24cfddd60f74aadd2dade4c63f6012a2489938a1))
27 | * Remove global state ([#1216](https://github.com/imartinez/privateGPT/issues/1216)) ([022bd71](https://github.com/imartinez/privateGPT/commit/022bd718e3dfc197027b1e24fb97e5525b186db4))
28 | * sagemaker config and chat methods ([#1142](https://github.com/imartinez/privateGPT/issues/1142)) ([a517a58](https://github.com/imartinez/privateGPT/commit/a517a588c4927aa5c5c2a93e4f82a58f0599d251))
29 | * typo in README.md ([#1091](https://github.com/imartinez/privateGPT/issues/1091)) ([ba23443](https://github.com/imartinez/privateGPT/commit/ba23443a70d323cd4f9a242b33fd9dce1bacd2db))
30 | * Windows 11 failing to auto-delete tmp file ([#1260](https://github.com/imartinez/privateGPT/issues/1260)) ([0d52002](https://github.com/imartinez/privateGPT/commit/0d520026a3d5b08a9b8487be992d3095b21e710c))
31 | * Windows permission error on ingest service tmp files ([#1280](https://github.com/imartinez/privateGPT/issues/1280)) ([f1cbff0](https://github.com/imartinez/privateGPT/commit/f1cbff0fb7059432d9e71473cbdd039032dab60d))
32 |
33 | ## [0.0.2](https://github.com/imartinez/privateGPT/compare/v0.0.1...v0.0.2) (2023-10-20)
34 |
35 |
36 | ### Bug Fixes
37 |
38 | * chromadb max batch size ([#1087](https://github.com/imartinez/privateGPT/issues/1087)) ([f5a9bf4](https://github.com/imartinez/privateGPT/commit/f5a9bf4e374b2d4c76438cf8a97cccf222ec8e6f))
39 |
40 | ## 0.0.1 (2023-10-20)
41 |
42 | ### Miscellaneous Chores
43 |
44 | * Initial version ([490d93f](https://github.com/imartinez/privateGPT/commit/490d93fdc1977443c92f6c42e57a1c585aa59430))
45 |
--------------------------------------------------------------------------------
/fern/docs/pages/manual/settings.mdx:
--------------------------------------------------------------------------------
1 | # Settings and profiles for your private GPT
2 |
3 | The configuration of your private GPT server is done thanks to `settings` files (more precisely `settings.yaml`).
4 | These text files are written using the [YAML](https://en.wikipedia.org/wiki/YAML) syntax.
5 |
6 | While privateGPT ships with safe and universal configuration files, you might want to quickly customize your
7 | privateGPT, and this can be done using the `settings` files.
8 |
9 | This project defines the concept of **profiles** (or configuration profiles).
10 | This mechanism, driven by environment variables, gives you the ability to easily switch between
11 | the configurations you've made.
12 |
13 | A typical use case of profiles is to easily switch between LLM and embedding models.
14 | To be a bit more precise, you can change the language (to French, Spanish, Italian, English, etc) by simply changing
15 | the profile you've selected; no code changes required!
16 |
17 | PrivateGPT is configured through *profiles* that are defined using yaml files, and selected through env variables.
18 | The full list of properties configurable can be found in `settings.yaml`.
19 |
20 | ## How to know which profiles exist
21 | Given that a profile `foo_bar` points to the file `settings-foo_bar.yaml` and vice-versa, you simply have to look
22 | at the files starting with `settings` and ending in `.yaml`.
23 |
24 | ## How to use an existing profile
25 | **Please note that the syntax to set the value of an environment variable depends on your OS**.
26 | You have to set environment variable `PGPT_PROFILES` to the name of the profile you want to use.
27 |
28 | For example, on **linux and macOS**, this gives:
29 | ```bash
30 | export PGPT_PROFILES=my_profile_name_here
31 | ```
32 |
33 | Windows shells have a different syntax; one way of setting the variable is:
34 | ```shell
35 | set PGPT_PROFILES=my_profile_name_here
36 | ```
37 | If the above does not work, you might want to try other ways to set an env variable in your Windows terminal.
38 |
39 | ---
40 |
41 | Once you've set this environment variable to the desired profile, you can simply launch your privateGPT,
42 | and it will run using your profile on top of the default configuration.
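For instance, on Linux or macOS, setting the profile and launching can be combined in a single line (a sketch; `make run` is one of the launch commands used elsewhere in these docs):

```bash
PGPT_PROFILES=my_profile_name_here make run
```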
43 |
44 | ## Reference
45 | Additional details on the profiles are described in this section.
46 |
47 | ### Environment variable `PGPT_SETTINGS_FOLDER`
48 |
49 | The location of the settings folder. Defaults to the root of the project.
50 | Should contain the default `settings.yaml` and any other `settings-{profile}.yaml`.
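For example, to keep all your settings files in a separate directory (a sketch, assuming a POSIX shell and a hypothetical `~/pgpt-settings` folder containing `settings.yaml` plus your `settings-{profile}.yaml` files):

```bash
export PGPT_SETTINGS_FOLDER=~/pgpt-settings
```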
51 |
52 | ### Environment variable `PGPT_PROFILES`
53 |
54 | By default, the profile definition in `settings.yaml` is loaded.
55 | Using this env var you can load additional profiles; the format is a comma-separated list of profile names.
56 | This will merge `settings-{profile}.yaml` on top of the base settings file.
57 |
58 | For example:
59 | `PGPT_PROFILES=local,cuda` will load `settings-local.yaml`
60 | and `settings-cuda.yaml`; their contents will be merged, with
61 | later profiles' properties overriding values of earlier ones, including `settings.yaml`.
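As a sketch (hypothetical contents, for illustration only), a profile only needs to declare the keys it overrides; everything else is inherited from the base `settings.yaml`:

```yaml
# settings.yaml (base, always loaded)
llm:
  mode: local
embedding:
  mode: local

# settings-mock.yaml (activated with PGPT_PROFILES=mock)
llm:
  mode: mock   # overrides llm.mode from the base file; embedding.mode stays "local"
```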
62 |
63 | During testing, the `test` profile will be active along with the default one, therefore a `settings-test.yaml`
64 | file is required.
65 |
66 | ### Environment variables expansion
67 |
68 | Configuration files can contain environment variables,
69 | which will be expanded at runtime.
70 |
71 | Expansion must follow the pattern `${VARIABLE_NAME:default_value}`.
72 |
73 | For example, the following configuration will use the value of the `PORT`
74 | environment variable or `8001` if it's not set.
75 | Missing variables with no default will produce an error.
76 |
77 | ```yaml
78 | server:
79 | port: ${PORT:8001}
80 | ```
--------------------------------------------------------------------------------
/private_gpt/server/completions/completions_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, Depends, Request
2 | from pydantic import BaseModel
3 | from starlette.responses import StreamingResponse
4 |
5 | from private_gpt.open_ai.extensions.context_filter import ContextFilter
6 | from private_gpt.open_ai.openai_models import (
7 | OpenAICompletion,
8 | OpenAIMessage,
9 | )
10 | from private_gpt.server.chat.chat_router import ChatBody, chat_completion
11 | from private_gpt.server.utils.auth import authenticated
12 |
13 | completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
14 |
15 |
16 | class CompletionsBody(BaseModel):
17 | prompt: str
18 | system_prompt: str | None = None
19 | use_context: bool = False
20 | context_filter: ContextFilter | None = None
21 | include_sources: bool = True
22 | stream: bool = False
23 |
24 | model_config = {
25 | "json_schema_extra": {
26 | "examples": [
27 | {
28 | "prompt": "How do you fry an egg?",
29 | "system_prompt": "You are a rapper. Always answer with a rap.",
30 | "stream": False,
31 | "use_context": False,
32 | "include_sources": False,
33 | }
34 | ]
35 | }
36 | }
37 |
38 |
39 | @completions_router.post(
40 | "/completions",
41 | response_model=None,
42 | summary="Completion",
43 | responses={200: {"model": OpenAICompletion}},
44 | tags=["Contextual Completions"],
45 | openapi_extra={
46 | "x-fern-streaming": {
47 | "stream-condition": "stream",
48 | "response": {"$ref": "#/components/schemas/OpenAICompletion"},
49 | "response-stream": {"$ref": "#/components/schemas/OpenAICompletion"},
50 | }
51 | },
52 | )
53 | def prompt_completion(
54 | request: Request, body: CompletionsBody
55 | ) -> OpenAICompletion | StreamingResponse:
56 | """We recommend most users use our Chat completions API.
57 |
58 | Given a prompt, the model will return one predicted completion.
59 |
60 | Optionally include a `system_prompt` to influence the way the LLM answers.
61 |
62 | If `use_context`
63 | is set to `true`, the model will use context coming from the ingested documents
64 | to create the response. The documents being used can be filtered using the
65 | `context_filter` and passing the document IDs to be used. Ingested documents IDs
66 | can be found using `/ingest/list` endpoint. If you want all ingested documents to
67 | be used, remove `context_filter` altogether.
68 |
69 | When using `'include_sources': true`, the API will return the source Chunks used
70 | to create the response, which come from the context provided.
71 |
72 | When using `'stream': true`, the API will return data chunks following [OpenAI's
73 | streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
74 | ```
75 | {"id":"12345","object":"completion.chunk","created":1694268190,
76 | "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
77 | "finish_reason":null}]}
78 | ```
79 | """
80 | messages = [OpenAIMessage(content=body.prompt, role="user")]
81 | # If system prompt is passed, create a fake message with the system prompt.
82 | if body.system_prompt:
83 | messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))
84 |
85 | chat_body = ChatBody(
86 | messages=messages,
87 | use_context=body.use_context,
88 | stream=body.stream,
89 | include_sources=body.include_sources,
90 | context_filter=body.context_filter,
91 | )
92 | return chat_completion(request, chat_body)
93 |
--------------------------------------------------------------------------------
/fern/docs.yml:
--------------------------------------------------------------------------------
1 | # Main Fern configuration file
2 | instances:
3 | - url: privategpt.docs.buildwithfern.com
4 | custom-domain: docs.privategpt.dev
5 |
6 | title: PrivateGPT | Docs
7 |
8 | # The tabs definition, in the top left corner
9 | tabs:
10 | overview:
11 | display-name: Overview
12 | icon: "fa-solid fa-home"
13 | installation:
14 | display-name: Installation
15 | icon: "fa-solid fa-download"
16 | manual:
17 | display-name: Manual
18 | icon: "fa-solid fa-book"
19 | recipes:
20 | display-name: Recipes
21 | icon: "fa-solid fa-flask"
22 | api-reference:
23 | display-name: API Reference
24 | icon: "fa-solid fa-file-contract"
25 |
26 | # Definition of tabs contents, will be displayed on the left side of the page, below all tabs
27 | navigation:
28 | # The default tab
29 | - tab: overview
30 | layout:
31 | - section: Welcome
32 | contents:
33 | - page: Welcome
34 | path: ./docs/pages/overview/welcome.mdx
35 | - page: Quickstart
36 | path: ./docs/pages/overview/quickstart.mdx
37 | # How to install privateGPT, with FAQ and troubleshooting
38 | - tab: installation
39 | layout:
40 | - section: Getting started
41 | contents:
42 | - page: Installation
43 | path: ./docs/pages/installation/installation.mdx
44 | # Manual of privateGPT: how to use it and configure it
45 | - tab: manual
46 | layout:
47 | - section: General configuration
48 | contents:
49 | - page: Configuration
50 | path: ./docs/pages/manual/settings.mdx
51 | - section: Document management
52 | contents:
53 | - page: Ingestion
54 | path: ./docs/pages/manual/ingestion.mdx
55 | - page: Deletion
56 | path: ./docs/pages/manual/ingestion-reset.mdx
57 | - section: Storage
58 | contents:
59 | - page: Vector Stores
60 | path: ./docs/pages/manual/vectordb.mdx
61 | - section: Advanced Setup
62 | contents:
63 | - page: LLM Backends
64 | path: ./docs/pages/manual/llms.mdx
65 | - section: User Interface
66 | contents:
67 | - page: User interface (Gradio) Manual
68 | path: ./docs/pages/manual/ui.mdx
69 | # Small code snippet or example of usage to help users
70 | - tab: recipes
71 | layout:
72 | - section: Choice of LLM
73 | contents:
74 | # TODO: add recipes
75 | - page: List of LLMs
76 | path: ./docs/pages/recipes/list-llm.mdx
77 | # More advanced usage of privateGPT, by API
78 | - tab: api-reference
79 | layout:
80 | - section: Overview
81 | contents:
82 | - page: API Reference overview
83 | path: ./docs/pages/api-reference/api-reference.mdx
84 | - page: SDKs
85 | path: ./docs/pages/api-reference/sdks.mdx
86 | - api: API Reference
87 |
88 | # Definition of the navbar, will be displayed in the top right corner.
89 | # `type:primary` is always displayed at the most right side of the navbar
90 | navbar-links:
91 | - type: secondary
92 | text: GitHub
93 | url: "https://github.com/imartinez/privateGPT"
94 | - type: secondary
95 | text: Contact us
96 | url: "mailto:hello@zylon.ai"
97 | - type: primary
98 | text: Join the Discord
99 | url: https://discord.com/invite/bK6mRVpErU
100 |
101 | colors:
102 | accentPrimary:
103 | dark: "#C6BBFF"
104 | light: "#756E98"
105 |
106 | logo:
107 | dark: ./docs/assets/logo_light.png
108 | light: ./docs/assets/logo_dark.png
109 | height: 50
110 |
111 | favicon: ./docs/assets/favicon.ico
112 |
--------------------------------------------------------------------------------
/fern/docs/pages/manual/llms.mdx:
--------------------------------------------------------------------------------
1 | ## Running the Server
2 |
3 | PrivateGPT supports running with different LLMs & setups.
4 |
5 | ### Local models
6 |
7 | Both the LLM and the Embeddings model will run locally.
8 |
9 | Make sure you have followed the *Local LLM requirements* section before moving on.
10 |
11 | This command will start PrivateGPT using the `settings.yaml` (default profile) together with the `settings-local.yaml`
12 | configuration files. By default, it will enable both the API and the Gradio UI. Run:
13 |
14 | ```bash
15 | PGPT_PROFILES=local make run
16 | ```
17 |
18 | or
19 |
20 | ```bash
21 | PGPT_PROFILES=local poetry run python -m private_gpt
22 | ```
23 |
24 | When the server is started, it will print the log *Application startup complete*.
25 | Navigate to http://localhost:8001 to use the Gradio UI or to http://localhost:8001/docs (API section) to try the API
26 | using Swagger UI.
27 |
28 | ### Using OpenAI
29 |
30 | If you cannot run a local model (because you don't have a GPU, for example) or for testing purposes, you may
31 | decide to run PrivateGPT using OpenAI as the LLM and Embeddings model.
32 |
33 | In order to do so, create a profile `settings-openai.yaml` with the following contents:
34 |
35 | ```yaml
36 | llm:
37 | mode: openai
38 |
39 | openai:
40 | api_base: # Defaults to https://api.openai.com/v1
41 | api_key: # You could skip this configuration and use the OPENAI_API_KEY env var instead
42 | model: # Optional model to use. Default is "gpt-3.5-turbo"
43 | # Note: OpenAI models are listed here: https://platform.openai.com/docs/models
44 | ```
45 |
46 | And run PrivateGPT loading that profile you just created:
47 |
48 | `PGPT_PROFILES=openai make run`
49 |
50 | or
51 |
52 | `PGPT_PROFILES=openai poetry run python -m private_gpt`
53 |
54 | When the server is started, it will print the log *Application startup complete*.
55 | Navigate to http://localhost:8001 to use the Gradio UI or to http://localhost:8001/docs (API section) to try the API.
56 | You'll notice that the speed and quality of the responses are higher, since you are using OpenAI's servers for the heavy
57 | computations.
58 |
59 | ### Using OpenAI compatible API
60 |
61 | Many tools, including [LocalAI](https://localai.io/) and [vLLM](https://docs.vllm.ai/en/latest/),
62 | support serving local models with an OpenAI compatible API. Even when overriding the `api_base`,
63 | using the `openai` mode doesn't allow you to use custom models. Instead, you should use the `openailike` mode:
64 |
65 | ```yaml
66 | llm:
67 | mode: openailike
68 | ```
69 |
70 | This mode uses the same settings as the `openai` mode.
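Since `openailike` reuses the `openai` settings block, a profile for a locally served model could look roughly like this (a sketch only; the `api_base`, `api_key` and `model` values are placeholders to replace with whatever your OpenAI-compatible server exposes):

```yaml
llm:
  mode: openailike

openai:
  api_base: http://localhost:8000/v1            # URL of your OpenAI-compatible server (e.g. vLLM)
  api_key: EMPTY                                # many local servers accept any placeholder key
  model: mistralai/Mistral-7B-Instruct-v0.1     # model name as served by your backend
```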
71 |
72 | As an example, you can follow the [vLLM quickstart guide](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server)
73 | to run an OpenAI compatible server. Then, you can run PrivateGPT using the `settings-vllm.yaml` profile:
74 |
75 | `PGPT_PROFILES=vllm make run`
76 |
77 | ### Using AWS Sagemaker
78 |
79 | For a fully private & performant setup, you can choose to have both your LLM and Embeddings model deployed using Sagemaker.
80 |
81 | Note: how to deploy models on Sagemaker is out of the scope of this documentation.
82 |
83 | In order to do so, create a profile `settings-sagemaker.yaml` with the following contents (remember to
84 | update the values of `llm_endpoint_name` and `embedding_endpoint_name` to your own endpoint names):
85 |
86 | ```yaml
87 | llm:
88 | mode: sagemaker
89 |
90 | sagemaker:
91 | llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
92 | embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
93 | ```
94 |
95 | And run PrivateGPT loading that profile you just created:
96 |
97 | `PGPT_PROFILES=sagemaker make run`
98 |
99 | or
100 |
101 | `PGPT_PROFILES=sagemaker poetry run python -m private_gpt`
102 |
103 | When the server is started, it will print the log *Application startup complete*.
104 | Navigate to http://localhost:8001 to use the Gradio UI or to http://localhost:8001/docs (API section) to try the API.
105 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "private-gpt"
3 | version = "0.2.0"
4 | description = "Private GPT"
5 | authors = ["Zylon "]
6 |
7 | [tool.poetry.dependencies]
8 | python = ">=3.11,<3.12"
9 | fastapi = { extras = ["all"], version = "^0.103.1" }
10 | boto3 = "^1.28.56"
11 | injector = "^0.21.0"
12 | pyyaml = "^6.0.1"
13 | python-multipart = "^0.0.6"
14 | pypdf = "^3.16.2"
15 | llama-index = { extras = ["local_models"], version = "0.9.3" }
16 | watchdog = "^3.0.0"
17 | qdrant-client = "^1.6.9"
18 | chromadb = {version = "^0.4.13", optional = true}
19 |
20 | [tool.poetry.group.dev.dependencies]
21 | black = "^22"
22 | mypy = "^1.2"
23 | pre-commit = "^2"
24 | pytest = "^7"
25 | pytest-cov = "^3"
26 | ruff = "^0"
27 | pytest-asyncio = "^0.21.1"
28 | types-pyyaml = "^6.0.12.12"
29 |
30 | # Dependencies for gradio UI
31 | [tool.poetry.group.ui]
32 | optional = true
33 | [tool.poetry.group.ui.dependencies]
34 | gradio = "^4.4.1"
35 |
36 | [tool.poetry.group.local]
37 | optional = true
38 | [tool.poetry.group.local.dependencies]
39 | llama-cpp-python = "^0.2.23"
40 | numpy = "1.26.0"
41 | sentence-transformers = "^2.2.2"
42 | # https://stackoverflow.com/questions/76327419/valueerror-libcublas-so-0-9-not-found-in-the-system-path
43 | torch = ">=2.0.0, !=2.0.1, !=2.1.0"
44 | transformers = "^4.34.0"
45 |
46 | [tool.poetry.extras]
47 | chroma = ["chromadb"]
48 |
49 | [build-system]
50 | requires = ["poetry-core>=1.0.0"]
51 | build-backend = "poetry.core.masonry.api"
52 |
53 | # Packages configs
54 |
55 | ## coverage
56 |
57 | [tool.coverage.run]
58 | branch = true
59 |
60 | [tool.coverage.report]
61 | skip_empty = true
62 | precision = 2
63 |
64 | ## black
65 |
66 | [tool.black]
67 | target-version = ['py311']
68 |
69 | ## ruff
70 | # Recommended ruff config for now, to be updated as we go along.
71 | [tool.ruff]
72 | target-version = 'py311'
73 |
74 | # See all rules at https://beta.ruff.rs/docs/rules/
75 | select = [
76 | "E", # pycodestyle
77 | "W", # pycodestyle
78 | "F", # Pyflakes
79 | "B", # flake8-bugbear
80 | "C4", # flake8-comprehensions
81 | "D", # pydocstyle
82 | "I", # isort
83 | "SIM", # flake8-simplify
84 | "TCH", # flake8-type-checking
85 | "TID", # flake8-tidy-imports
86 | "Q", # flake8-quotes
87 | "UP", # pyupgrade
88 | "PT", # flake8-pytest-style
89 | "RUF", # Ruff-specific rules
90 | ]
91 |
92 | ignore = [
93 | "E501", # "Line too long"
94 | # -> line length already regulated by black
95 | "PT011", # "pytest.raises() should specify expected exception"
96 | # -> would imply to update tests every time you update exception message
97 | "SIM102", # "Use a single `if` statement instead of nested `if` statements"
98 | # -> too restrictive,
99 | "D100",
100 | "D101",
101 | "D102",
102 | "D103",
103 | "D104",
104 | "D105",
105 | "D106",
106 | "D107"
107 | # -> "Missing docstring in public function too restrictive"
108 | ]
109 |
110 | [tool.ruff.pydocstyle]
111 | # Automatically disable rules that are incompatible with Google docstring convention
112 | convention = "google"
113 |
114 | [tool.ruff.pycodestyle]
115 | max-doc-length = 88
116 |
117 | [tool.ruff.flake8-tidy-imports]
118 | ban-relative-imports = "all"
119 |
120 | [tool.ruff.flake8-type-checking]
121 | strict = true
122 | runtime-evaluated-base-classes = ["pydantic.BaseModel"]
123 | # Pydantic needs to be able to evaluate types at runtime
124 | # see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation
125 | # see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation
126 |
127 | [tool.ruff.per-file-ignores]
128 | # Allow missing docstrings for tests
129 | "tests/**/*.py" = ["D1"]
130 |
131 | ## mypy
132 |
133 | [tool.mypy]
134 | python_version = "3.11"
135 | strict = true
136 | check_untyped_defs = false
137 | explicit_package_bases = true
138 | warn_unused_ignores = false
139 | exclude = ["tests"]
140 |
141 | [tool.pytest.ini_options]
142 | asyncio_mode = "auto"
143 | testpaths = ["tests"]
144 | addopts = [
145 | "--import-mode=importlib",
146 | ]
147 |
--------------------------------------------------------------------------------
/private_gpt/open_ai/openai_models.py:
--------------------------------------------------------------------------------
1 | import time
2 | import uuid
3 | from collections.abc import Iterator
4 | from typing import Literal
5 |
6 | from llama_index.llms import ChatResponse, CompletionResponse
7 | from pydantic import BaseModel, Field
8 |
9 | from private_gpt.server.chunks.chunks_service import Chunk
10 |
11 |
12 | class OpenAIDelta(BaseModel):
13 | """A piece of completion that needs to be concatenated to get the full message."""
14 |
15 | content: str | None
16 |
17 |
18 | class OpenAIMessage(BaseModel):
19 | """Inference result, with the source of the message.
20 |
21 | Role could be the assistant or system
22 | (providing a default response, not AI generated).
23 | """
24 |
25 | role: Literal["assistant", "system", "user"] = Field(default="user")
26 | content: str | None
27 |
28 |
29 | class OpenAIChoice(BaseModel):
30 | """Response from AI.
31 |
32 | Either the delta or the message will be present, but never both.
33 | Sources used will be returned in case context retrieval was enabled.
34 | """
35 |
36 | finish_reason: str | None = Field(examples=["stop"])
37 | delta: OpenAIDelta | None = None
38 | message: OpenAIMessage | None = None
39 | sources: list[Chunk] | None = None
40 | index: int = 0
41 |
42 |
43 | class OpenAICompletion(BaseModel):
44 | """Clone of OpenAI Completion model.
45 |
46 | For more information see: https://platform.openai.com/docs/api-reference/chat/object
47 | """
48 |
49 | id: str
50 | object: Literal["completion", "completion.chunk"] = Field(default="completion")
51 | created: int = Field(..., examples=[1623340000])
52 | model: Literal["private-gpt"]
53 | choices: list[OpenAIChoice]
54 |
55 | @classmethod
56 | def from_text(
57 | cls,
58 | text: str | None,
59 | finish_reason: str | None = None,
60 | sources: list[Chunk] | None = None,
61 | ) -> "OpenAICompletion":
62 | return OpenAICompletion(
63 | id=str(uuid.uuid4()),
64 | object="completion",
65 | created=int(time.time()),
66 | model="private-gpt",
67 | choices=[
68 | OpenAIChoice(
69 | message=OpenAIMessage(role="assistant", content=text),
70 | finish_reason=finish_reason,
71 | sources=sources,
72 | )
73 | ],
74 | )
75 |
76 | @classmethod
77 | def json_from_delta(
78 | cls,
79 | *,
80 | text: str | None,
81 | finish_reason: str | None = None,
82 | sources: list[Chunk] | None = None,
83 | ) -> str:
84 | chunk = OpenAICompletion(
85 | id=str(uuid.uuid4()),
86 | object="completion.chunk",
87 | created=int(time.time()),
88 | model="private-gpt",
89 | choices=[
90 | OpenAIChoice(
91 | delta=OpenAIDelta(content=text),
92 | finish_reason=finish_reason,
93 | sources=sources,
94 | )
95 | ],
96 | )
97 |
98 | return chunk.model_dump_json()
99 |
100 |
101 | def to_openai_response(
102 | response: str | ChatResponse, sources: list[Chunk] | None = None
103 | ) -> OpenAICompletion:
104 | if isinstance(response, ChatResponse):
105 | return OpenAICompletion.from_text(response.delta, finish_reason="stop")
106 | else:
107 | return OpenAICompletion.from_text(
108 | response, finish_reason="stop", sources=sources
109 | )
110 |
111 |
112 | def to_openai_sse_stream(
113 | response_generator: Iterator[str | CompletionResponse | ChatResponse],
114 | sources: list[Chunk] | None = None,
115 | ) -> Iterator[str]:
116 | for response in response_generator:
117 | if isinstance(response, CompletionResponse | ChatResponse):
118 | yield f"data: {OpenAICompletion.json_from_delta(text=response.delta)}\n\n"
119 | else:
120 | yield f"data: {OpenAICompletion.json_from_delta(text=response, sources=sources)}\n\n"
121 | yield f"data: {OpenAICompletion.json_from_delta(text='', finish_reason='stop')}\n\n"
122 | yield "data: [DONE]\n\n"
123 |
--------------------------------------------------------------------------------
/fern/docs/pages/recipes/list-llm.mdx:
--------------------------------------------------------------------------------
1 | # List of working LLMs
2 |
3 | **Do you have any working combination of LLM and embeddings?**
4 | Please open a PR to add it to the list, and come on our Discord to tell us about it!
5 |
6 | ## Prompt style
7 |
8 | LLMs might have been trained with different prompt styles.
9 | The prompt style is the way the prompt is written, and how the system message is injected in the prompt.
10 |
11 | For example, `llama2` looks like this:
12 | ```text
13 | [INST] <<SYS>>
14 | {{ system_prompt }}
15 | <</SYS>>
16 |
17 | {{ user_message }} [/INST]
18 | ```
19 |
20 | While `default` (the `llama_index` default) looks like this:
21 | ```text
22 | system: {{ system_prompt }}
23 | user: {{ user_message }}
24 | assistant: {{ assistant_message }}
25 | ```
26 |
27 | The "`tag`" style looks like this:
28 |
29 | ```text
30 | <|system|>: {{ system_prompt }}
31 | <|user|>: {{ user_message }}
32 | <|assistant|>: {{ assistant_message }}
33 | ```
34 |
35 | The "`mistral`" style looks like this:
36 |
37 | ```text
38 | [INST] You are an AI assistant. [/INST][INST] Hello, how are you doing? [/INST]
39 | ```
40 |
41 | The "`chatml`" style looks like this:
42 | ```text
43 | <|im_start|>system
44 | {{ system_prompt }}<|im_end|>
45 | <|im_start|>user
46 | {{ user_message }}<|im_end|>
47 | <|im_start|>assistant
48 | {{ assistant_message }}
49 | ```
50 |
51 | Some LLMs will not understand these prompt styles, and will not work (returning nothing).
52 | You can try changing the prompt style to `default` (or `tag`) in the settings; this changes
53 | the way the messages are formatted before being passed to the LLM, as sketched below.
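As a minimal sketch (assuming the local setup, where `prompt_style` lives under the `local` section of your active profile, as in the full examples below):

```yml
local:
  prompt_style: "default"   # or "tag", "llama2", "mistral", "chatml"
```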
54 |
55 | ## Example of configuration
56 |
57 | You might want to change the prompt depending on the language and model you are using.
58 |
59 | ### English, with instructions
60 |
61 | `settings-en.yaml`:
62 | ```yml
63 | local:
64 | llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.1-GGUF
65 | llm_hf_model_file: mistral-7b-instruct-v0.1.Q4_K_M.gguf
66 | embedding_hf_model_name: BAAI/bge-small-en-v1.5
67 | prompt_style: "llama2"
68 | ```
69 |
70 | ### French, with instructions
71 |
72 | `settings-fr.yaml`:
73 | ```yml
74 | local:
75 | llm_hf_repo_id: TheBloke/Vigogne-2-7B-Instruct-GGUF
76 | llm_hf_model_file: vigogne-2-7b-instruct.Q4_K_M.gguf
77 | embedding_hf_model_name: dangvantuan/sentence-camembert-base
78 | prompt_style: "default"
79 | # prompt_style: "tag" # also works
80 | # The default system prompt is injected only when the `prompt_style` != default, and there is no system message in the discussion
81 | # default_system_prompt: Vous êtes un assistant IA qui répond à la question posée à la fin en utilisant le contexte suivant. Si vous ne connaissez pas la réponse, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse. Veuillez répondre exclusivement en français.
82 | ```
83 |
84 | You might want to change the prompt as the one above might not directly answer your question.
85 | You can read online about how to write a good prompt, but in a nutshell, make it (extremely) directive.
86 |
87 | You can try to troubleshoot your prompt by writing multiline requests in the UI, spelling out
88 | the interaction you want with the model, for example:
89 |
90 | ```text
91 | Tu es un programmeur senior qui programme en python et utilise le framework fastapi. Ecrit moi un serveur qui retourne "hello world".
92 | ```
93 |
94 | Another example:
95 | ```text
96 | Context: None
97 | Situation: tu es au milieu d'un champ.
98 | Tache: va a la rivière, en bas du champ.
99 | Décrit comment aller a la rivière.
100 | ```
101 |
102 | ### Optimised Models
103 | GodziLLa2-70B LLM (English, rank 2 on the HuggingFace OpenLLM Leaderboard), with the bge-large embedding model (rank 1 on the HuggingFace MTEB Leaderboard).
104 | `settings-optimised.yaml`:
105 | ```yml
106 | local:
107 | llm_hf_repo_id: TheBloke/GodziLLa2-70B-GGUF
108 | llm_hf_model_file: godzilla2-70b.Q4_K_M.gguf
109 | embedding_hf_model_name: BAAI/bge-large-en
110 | prompt_style: "llama2"
111 | ```
112 | ### German speaking model
113 | `settings-de.yaml`:
114 | ```yml
115 | local:
116 | llm_hf_repo_id: TheBloke/em_german_leo_mistral-GGUF
117 | llm_hf_model_file: em_german_leo_mistral.Q4_K_M.gguf
118 | embedding_hf_model_name: T-Systems-onsite/german-roberta-sentence-transformer-v2
119 | # llama2, default or tag
120 | prompt_style: "default"
121 | ```
122 |
--------------------------------------------------------------------------------
/scripts/ingest_folder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import logging
5 | from pathlib import Path
6 |
7 | from private_gpt.di import global_injector
8 | from private_gpt.server.ingest.ingest_service import IngestService
9 | from private_gpt.server.ingest.ingest_watcher import IngestWatcher
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class LocalIngestWorker:
15 | def __init__(self, ingest_service: IngestService) -> None:
16 | self.ingest_service = ingest_service
17 |
18 | self.total_documents = 0
19 | self.current_document_count = 0
20 |
21 | self._files_under_root_folder: list[Path] = list()
22 |
23 | def _find_all_files_in_folder(self, root_path: Path, ignored: list[str]) -> None:
24 | """Search all files under the root folder recursively.
25 | Count them at the same time.
26 | """
27 | for file_path in root_path.iterdir():
28 | if file_path.is_file() and file_path.name not in ignored:
29 | self.total_documents += 1
30 | self._files_under_root_folder.append(file_path)
31 | elif file_path.is_dir() and file_path.name not in ignored:
32 | self._find_all_files_in_folder(file_path, ignored)
33 |
34 | def ingest_folder(self, folder_path: Path, ignored: list[str]) -> None:
35 | # Count total documents before ingestion
36 | self._find_all_files_in_folder(folder_path, ignored)
37 | self._ingest_all(self._files_under_root_folder)
38 |
39 | def _ingest_all(self, files_to_ingest: list[Path]) -> None:
40 | logger.info("Ingesting files=%s", [f.name for f in files_to_ingest])
41 | self.ingest_service.bulk_ingest([(str(p.name), p) for p in files_to_ingest])
42 |
43 | def ingest_on_watch(self, changed_path: Path) -> None:
44 | logger.info("Detected change at path=%s, ingesting", changed_path)
45 | self._do_ingest_one(changed_path)
46 |
47 | def _do_ingest_one(self, changed_path: Path) -> None:
48 | try:
49 | if changed_path.exists():
50 | logger.info(f"Started ingesting file={changed_path}")
51 | self.ingest_service.ingest_file(changed_path.name, changed_path)
52 | logger.info(f"Completed ingesting file={changed_path}")
53 | except Exception:
54 | logger.exception(
55 | f"Failed to ingest document: {changed_path}, find the exception attached"
56 | )
57 |
58 |
59 | parser = argparse.ArgumentParser(prog="ingest_folder.py")
60 | parser.add_argument("folder", help="Folder to ingest")
61 | parser.add_argument(
62 | "--watch",
63 | help="Watch for changes",
64 | action=argparse.BooleanOptionalAction,
65 | default=False,
66 | )
67 | parser.add_argument(
68 | "--ignored",
69 | nargs="*",
70 | help="List of files/directories to ignore",
71 | default=[],
72 | )
73 | parser.add_argument(
74 | "--log-file",
75 | help="Optional path to a log file. If provided, logs will be written to this file.",
76 | type=str,
77 | default=None,
78 | )
79 |
80 | args = parser.parse_args()
81 |
82 | # Set up logging to a file if a path is provided
83 | if args.log_file:
84 | file_handler = logging.FileHandler(args.log_file, mode="a")
85 | file_handler.setFormatter(
86 | logging.Formatter(
87 | "[%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s",
88 | datefmt="%Y-%m-%d %H:%M:%S",
89 | )
90 | )
91 | logger.addHandler(file_handler)
92 |
93 | if __name__ == "__main__":
94 |
95 | root_path = Path(args.folder)
96 | if not root_path.exists():
97 | raise ValueError(f"Path {args.folder} does not exist")
98 |
99 | ingest_service = global_injector.get(IngestService)
100 | worker = LocalIngestWorker(ingest_service)
101 | worker.ingest_folder(root_path, args.ignored)
102 |
103 | if args.ignored:
104 | logger.info(f"Skipping following files and directories: {args.ignored}")
105 |
106 | if args.watch:
107 | logger.info(f"Watching {args.folder} for changes, press Ctrl+C to stop...")
108 | directories_to_watch = [
109 | dir
110 | for dir in root_path.iterdir()
111 | if dir.is_dir() and dir.name not in args.ignored
112 | ]
113 | watcher = IngestWatcher(args.folder, worker.ingest_on_watch)
114 | watcher.start()
115 |
--------------------------------------------------------------------------------
/private_gpt/server/chat/chat_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, Depends, Request
2 | from llama_index.llms import ChatMessage, MessageRole
3 | from pydantic import BaseModel
4 | from starlette.responses import StreamingResponse
5 |
6 | from private_gpt.open_ai.extensions.context_filter import ContextFilter
7 | from private_gpt.open_ai.openai_models import (
8 | OpenAICompletion,
9 | OpenAIMessage,
10 | to_openai_response,
11 | to_openai_sse_stream,
12 | )
13 | from private_gpt.server.chat.chat_service import ChatService
14 | from private_gpt.server.utils.auth import authenticated
15 |
16 | chat_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
17 |
18 |
19 | class ChatBody(BaseModel):
20 | messages: list[OpenAIMessage]
21 | use_context: bool = False
22 | context_filter: ContextFilter | None = None
23 | include_sources: bool = True
24 | stream: bool = False
25 |
26 | model_config = {
27 | "json_schema_extra": {
28 | "examples": [
29 | {
30 | "messages": [
31 | {
32 | "role": "system",
33 | "content": "You are a rapper. Always answer with a rap.",
34 | },
35 | {
36 | "role": "user",
37 | "content": "How do you fry an egg?",
38 | },
39 | ],
40 | "stream": False,
41 | "use_context": True,
42 | "include_sources": True,
43 | "context_filter": {
44 | "docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
45 | },
46 | }
47 | ]
48 | }
49 | }
50 |
51 |
52 | @chat_router.post(
53 | "/chat/completions",
54 | response_model=None,
55 | responses={200: {"model": OpenAICompletion}},
56 | tags=["Contextual Completions"],
57 | openapi_extra={
58 | "x-fern-streaming": {
59 | "stream-condition": "stream",
60 | "response": {"$ref": "#/components/schemas/OpenAICompletion"},
61 | "response-stream": {"$ref": "#/components/schemas/OpenAICompletion"},
62 | }
63 | },
64 | )
65 | def chat_completion(
66 | request: Request, body: ChatBody
67 | ) -> OpenAICompletion | StreamingResponse:
68 | """Given a list of messages comprising a conversation, return a response.
69 |
70 | Optionally include an initial `role: system` message to influence the way
71 | the LLM answers.
72 |
73 | If `use_context` is set to `true`, the model will use context coming
74 | from the ingested documents to create the response. The documents being used can
75 | be filtered using the `context_filter` and passing the document IDs to be used.
76 | Ingested documents IDs can be found using `/ingest/list` endpoint. If you want
77 | all ingested documents to be used, remove `context_filter` altogether.
78 |
79 | When using `'include_sources': true`, the API will return the source Chunks used
80 | to create the response, which come from the context provided.
81 |
82 | When using `'stream': true`, the API will return data chunks following [OpenAI's
83 | streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
84 | ```
85 | {"id":"12345","object":"completion.chunk","created":1694268190,
86 | "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
87 | "finish_reason":null}]}
88 | ```
89 | """
90 | service = request.state.injector.get(ChatService)
91 | all_messages = [
92 | ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
93 | ]
94 | if body.stream:
95 | completion_gen = service.stream_chat(
96 | messages=all_messages,
97 | use_context=body.use_context,
98 | context_filter=body.context_filter,
99 | )
100 | return StreamingResponse(
101 | to_openai_sse_stream(
102 | completion_gen.response,
103 | completion_gen.sources if body.include_sources else None,
104 | ),
105 | media_type="text/event-stream",
106 | )
107 | else:
108 | completion = service.chat(
109 | messages=all_messages,
110 | use_context=body.use_context,
111 | context_filter=body.context_filter,
112 | )
113 | return to_openai_response(
114 | completion.response, completion.sources if body.include_sources else None
115 | )
116 |
--------------------------------------------------------------------------------
/tests/test_prompt_helper.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from llama_index.llms import ChatMessage, MessageRole
3 |
4 | from private_gpt.components.llm.prompt_helper import (
5 | ChatMLPromptStyle,
6 | DefaultPromptStyle,
7 | Llama2PromptStyle,
8 | MistralPromptStyle,
9 | TagPromptStyle,
10 | get_prompt_style,
11 | )
12 |
13 |
14 | @pytest.mark.parametrize(
15 | ("prompt_style", "expected_prompt_style"),
16 | [
17 | ("default", DefaultPromptStyle),
18 | ("llama2", Llama2PromptStyle),
19 | ("tag", TagPromptStyle),
20 | ("mistral", MistralPromptStyle),
21 | ("chatml", ChatMLPromptStyle),
22 | ],
23 | )
24 | def test_get_prompt_style_success(prompt_style, expected_prompt_style):
25 | assert isinstance(get_prompt_style(prompt_style), expected_prompt_style)
26 |
27 |
28 | def test_get_prompt_style_failure():
29 | prompt_style = "unknown"
30 | with pytest.raises(ValueError) as exc_info:
31 | get_prompt_style(prompt_style)
32 | assert str(exc_info.value) == f"Unknown prompt_style='{prompt_style}'"
33 |
34 |
35 | def test_tag_prompt_style_format():
36 | prompt_style = TagPromptStyle()
37 | messages = [
38 | ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
39 | ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
40 | ]
41 |
42 | expected_prompt = (
43 | "<|system|>: You are an AI assistant.\n"
44 | "<|user|>: Hello, how are you doing?\n"
45 | "<|assistant|>: "
46 | )
47 |
48 | assert prompt_style.messages_to_prompt(messages) == expected_prompt
49 |
50 |
51 | def test_tag_prompt_style_format_with_system_prompt():
52 | prompt_style = TagPromptStyle()
53 | messages = [
54 | ChatMessage(
55 | content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
56 | ),
57 | ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
58 | ]
59 |
60 | expected_prompt = (
61 | "<|system|>: FOO BAR Custom sys prompt from messages.\n"
62 | "<|user|>: Hello, how are you doing?\n"
63 | "<|assistant|>: "
64 | )
65 |
66 | assert prompt_style.messages_to_prompt(messages) == expected_prompt
67 |
68 |
69 | def test_mistral_prompt_style_format():
70 | prompt_style = MistralPromptStyle()
71 | messages = [
72 | ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
73 | ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
74 | ]
75 |
76 | expected_prompt = (
77 | "[INST] You are an AI assistant. [/INST]"
78 | "[INST] Hello, how are you doing? [/INST]"
79 | )
80 |
81 | assert prompt_style.messages_to_prompt(messages) == expected_prompt
82 |
83 |
84 | def test_chatml_prompt_style_format():
85 | prompt_style = ChatMLPromptStyle()
86 | messages = [
87 | ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
88 | ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
89 | ]
90 |
91 | expected_prompt = (
92 | "<|im_start|>system\n"
93 | "You are an AI assistant.<|im_end|>\n"
94 | "<|im_start|>user\n"
95 | "Hello, how are you doing?<|im_end|>\n"
96 | "<|im_start|>assistant\n"
97 | )
98 |
99 | assert prompt_style.messages_to_prompt(messages) == expected_prompt
100 |
101 |
102 | def test_llama2_prompt_style_format():
103 | prompt_style = Llama2PromptStyle()
104 | messages = [
105 | ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
106 | ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
107 | ]
108 |
109 | expected_prompt = (
110 | "<s> [INST] <<SYS>>\n"
111 | " You are an AI assistant. \n"
112 | "<</SYS>>\n"
113 | "\n"
114 | " Hello, how are you doing? [/INST]"
115 | )
116 |
117 | assert prompt_style.messages_to_prompt(messages) == expected_prompt
118 |
119 |
120 | def test_llama2_prompt_style_with_system_prompt():
121 | prompt_style = Llama2PromptStyle()
122 | messages = [
123 | ChatMessage(
124 | content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
125 | ),
126 | ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
127 | ]
128 |
129 | expected_prompt = (
130 | " [INST] <>\n"
131 | " FOO BAR Custom sys prompt from messages. \n"
132 | "<>\n"
133 | "\n"
134 | " Hello, how are you doing? [/INST]"
135 | )
136 |
137 | assert prompt_style.messages_to_prompt(messages) == expected_prompt
138 |
--------------------------------------------------------------------------------
/private_gpt/server/ingest/ingest_router.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
4 | from pydantic import BaseModel, Field
5 |
6 | from private_gpt.server.ingest.ingest_service import IngestService
7 | from private_gpt.server.ingest.model import IngestedDoc
8 | from private_gpt.server.utils.auth import authenticated
9 |
10 | ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
11 |
12 |
13 | class IngestTextBody(BaseModel):
14 | file_name: str = Field(examples=["Avatar: The Last Airbender"])
15 | text: str = Field(
16 | examples=[
17 | "Avatar is set in an Asian and Arctic-inspired world in which some "
18 | "people can telekinetically manipulate one of the four elements—water, "
19 | "earth, fire or air—through practices known as 'bending', inspired by "
20 | "Chinese martial arts."
21 | ]
22 | )
23 |
24 |
25 | class IngestResponse(BaseModel):
26 | object: Literal["list"]
27 | model: Literal["private-gpt"]
28 | data: list[IngestedDoc]
29 |
30 |
31 | @ingest_router.post("/ingest", tags=["Ingestion"], deprecated=True)
32 | def ingest(request: Request, file: UploadFile) -> IngestResponse:
33 | """Ingests and processes a file.
34 |
35 | Deprecated. Use ingest/file instead.
36 | """
37 | return ingest_file(request, file)
38 |
39 |
40 | @ingest_router.post("/ingest/file", tags=["Ingestion"])
41 | def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
42 | """Ingests and processes a file, storing its chunks to be used as context.
43 |
44 | The context obtained from files is later used in
45 | `/chat/completions`, `/completions`, and `/chunks` APIs.
46 |
47 | Most common document
48 | formats are supported, but you may be prompted to install an extra dependency to
49 | manage a specific file type.
50 |
51 | A file can generate different Documents (for example a PDF generates one Document
52 | per page). All Documents IDs are returned in the response, together with the
53 | extracted Metadata (which is later used to improve context retrieval). Those IDs
54 | can be used to filter the context used to create responses in
55 | `/chat/completions`, `/completions`, and `/chunks` APIs.
56 | """
57 | service = request.state.injector.get(IngestService)
58 | if file.filename is None:
59 | raise HTTPException(400, "No file name provided")
60 | ingested_documents = service.ingest_bin_data(file.filename, file.file)
61 | return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
62 |
63 |
64 | @ingest_router.post("/ingest/text", tags=["Ingestion"])
65 | def ingest_text(request: Request, body: IngestTextBody) -> IngestResponse:
66 | """Ingests and processes a text, storing its chunks to be used as context.
67 |
68 | The context obtained from files is later used in
69 | `/chat/completions`, `/completions`, and `/chunks` APIs.
70 |
71 | A Document will be generated with the given text. The Document
72 | ID is returned in the response, together with the
73 | extracted Metadata (which is later used to improve context retrieval). That ID
74 | can be used to filter the context used to create responses in
75 | `/chat/completions`, `/completions`, and `/chunks` APIs.
76 | """
77 | service = request.state.injector.get(IngestService)
78 | if len(body.file_name) == 0:
79 | raise HTTPException(400, "No file name provided")
80 | ingested_documents = service.ingest_text(body.file_name, body.text)
81 | return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
82 |
83 |
84 | @ingest_router.get("/ingest/list", tags=["Ingestion"])
85 | def list_ingested(request: Request) -> IngestResponse:
86 | """Lists already ingested Documents including their Document ID and metadata.
87 |
88 | Those IDs can be used to filter the context used to create responses
89 | in `/chat/completions`, `/completions`, and `/chunks` APIs.
90 | """
91 | service = request.state.injector.get(IngestService)
92 | ingested_documents = service.list_ingested()
93 | return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
94 |
95 |
96 | @ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
97 | def delete_ingested(request: Request, doc_id: str) -> None:
98 | """Delete the specified ingested Document.
99 |
100 | The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
101 | The document will be effectively deleted from your storage context.
102 | """
103 | service = request.state.injector.get(IngestService)
104 | service.delete(doc_id)
105 |
--------------------------------------------------------------------------------
/private_gpt/ui/images.py:
--------------------------------------------------------------------------------
1 | logo_svg = "data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iODYxIiBoZWlnaHQ9Ijk4IiB2aWV3Qm94PSIwIDAgODYxIDk4IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgo8cGF0aCBkPSJNNDguMTM0NSAwLjE1NzkxMUMzNi44Mjk5IDEuMDM2NTQgMjYuMTIwNSA1LjU1MzI4IDE3LjYyNTYgMTMuMDI1QzkuMTMwNDYgMjAuNDk2NyAzLjMxMTcgMzAuNTE2OSAxLjA0OTUyIDQxLjU3MDVDLTEuMjEyNzMgNTIuNjIzOCAwLjIwNDQxOSA2NC4xMDk0IDUuMDg2MiA3NC4yOTA1QzkuOTY4NjggODQuNDcxNiAxOC4wNTAzIDkyLjc5NDMgMjguMTA5OCA5OEwzMy43MDI2IDgyLjU5MDdMMzUuNDU0MiA3Ny43NjU2QzI5LjgzODcgNzQuMTY5MiAyNS41NDQ0IDY4Ljg2MDcgMjMuMjE0IDYyLjYzNDRDMjAuODgyMiA1Ni40MDg2IDIwLjYzOSA0OS41OTkxIDIyLjUyMDQgNDMuMjI0M0MyNC40MDI5IDM2Ljg0OTUgMjguMzA5NiAzMS4yNTI1IDMzLjY1NjEgMjcuMjcwNkMzOS4wMDIgMjMuMjg4MyA0NS41MDAzIDIxLjEzNSA1Mi4xNzg5IDIxLjEzM0M1OC44NTczIDIxLjEzMDMgNjUuMzU3MSAyMy4yNzgzIDcwLjcwNjUgMjcuMjU1OEM3Ni4wNTU0IDMxLjIzNCA3OS45NjY0IDM2LjgyNzcgODEuODU0MyA0My4yMDA2QzgzLjc0MjkgNDkuNTczNiA4My41MDYyIDU2LjM4MzYgODEuMTgwMSA2Mi42MTE3Qzc4Ljg1NDUgNjguODM5NiA3NC41NjUgNzQuMTUxNCA2OC45NTI5IDc3Ljc1MjhMNzAuNzA3NCA4Mi41OTA3TDc2LjMwMDIgOTcuOTk3MUM4Ni45Nzg4IDkyLjQ3MDUgOTUuNDA4OCA4My40NDE5IDEwMC4xNjMgNzIuNDQwNEMxMDQuOTE3IDYxLjQzOTQgMTA1LjcwNCA0OS4xNDE3IDEwMi4zODkgMzcuNjNDOTkuMDc0NiAyNi4xMTc5IDkxLjg2MjcgMTYuMDk5MyA4MS45NzQzIDkuMjcwNzlDNzIuMDg2MSAyLjQ0MTkxIDYwLjEyOTEgLTAuNzc3MDg2IDQ4LjEyODYgMC4xNTg5MzRMNDguMTM0NSAwLjE1NzkxMVoiIGZpbGw9IiMxRjFGMjkiLz4KPGcgY2xpcC1wYXRoPSJ1cmwoI2NsaXAwXzVfMTkpIj4KPHBhdGggZD0iTTIyMC43NzIgMTIuNzUyNEgyNTIuNjM5QzI2Ny4yNjMgMTIuNzUyNCAyNzcuNzM5IDIxLjk2NzUgMjc3LjczOSAzNS40MDUyQzI3Ny43MzkgNDYuNzg3IDI2OS44ODEgNTUuMzUwOCAyNTguMzE0IDU3LjQxMDdMMjc4LjgzIDg1LjM3OTRIMjYxLjM3TDI0Mi4wNTQgNTcuOTUzM0gyMzUuNTA2Vjg1LjM3OTRIMjIwLjc3NEwyMjAuNzcyIDEyLjc1MjRaTTIzNS41MDQgMjYuMzAyOFY0NC40MDdIMjUyLjYzMkMyNTguOTYyIDQ0LjQwNyAyNjIuOTk5IDQwLjgyOTggMjYyLjk5OSAzNS40MTAyQzI2Mi45OTkgMjkuODgwOSAyNTguOTYyIDI2LjMwMjggMjUyLjYzMiAyNi4zMDI4SDIzNS41MDRaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik0yOTUuMTc2IDg1LjM4NDRWMTIuNzUyNEgzMDkuOTA5Vjg1LjM4NDRIMjk1LjE3NloiIGZpbGw9IiMxRjFGMjkiLz4KPHBhdGggZD0iTTM2My43OTUgNjUuNzYzTDM4NS42MiAxMi43NTI0SDQwMS40NDRMMzcxLjIxNSA4NS4zODQ0SDM1Ni40ODNMMzI2LjI1NCAxMi43NTI0SDM0Mi4wNzhMMzYzLjc5NSA2NS43NjNaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik00NDguMzI3IDcyLjA1MDRINDE1LjY5OEw0MTAuMjQxIDg1LjM4NDRIMzk0LjQxOEw0MjQuNjQ3IDEyLjc1MjRINDM5LjM3OUw0NjkuNjA4IDg1LjM4NDRINDUzLjc4M0w0NDguMzI3IDcyLjA1MDRaTTQ0Mi43NjEgNTguNUw0MzIuMDY2IDMyLjM3NDhMNDIxLjI2MiA1OC41SDQ0Mi43NjFaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik00NjUuMjIxIDEyLjc1MjRINTMwLjU5MlYyNi4zMDI4SDUwNS4yNzVWODUuMzg0NEg0OTAuNTM5VjI2LjMwMjhINDY1LjIyMVYxMi43NTI0WiIgZmlsbD0iIzFGMUYyOSIvPgo8cGF0aCBkPSJNNTk1LjE5MyAxMi43NTI0VjI2LjMwMjhINTYyLjEyOFY0MS4xNTUxSDU5NS4xOTNWNTQuNzA2NUg1NjIuMTI4VjcxLjgzNEg1OTUuMTkzVjg1LjM4NDRINTQ3LjM5NVYxMi43NTI0SDU5NS4xOTNaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik0xNjcuMjAxIDU3LjQxNThIMTg2LjUzNkMxOTAuODg2IDU3LjQ2NjIgMTk1LjE2OCA1Ni4zMzQ4IDE5OC45MTggNTQuMTQzN0MyMDIuMTc5IDUyLjIxOTkgMjA0Ljg2OSA0OS40NzM2IDIwNi43MTYgNDYuMTgzNUMyMDguNTYyIDQyLjg5MzQgMjA5LjUgMzkuMTc2NiAyMDkuNDMzIDM1LjQxMDJDMjA5LjQzMyAyMS45Njc1IDE5OC45NTggMTIuNzU3NCAxODQuMzM0IDEyLjc1NzRIMTUyLjQ2OFY4NS4zODk0SDE2Ny4yMDFWNTcuNDIwN1Y1Ny40MTU4Wk0xNjcuMjAxIDI2LjMwNThIMTg0LjMyOUMxOTAuNjU4IDI2LjMwNTggMTk0LjY5NiAyOS44ODQgMTk0LjY5NiAzNS40MTMzQzE5NC42OTYgNDAuODMyOSAxOTAuNjU4IDQ0LjQwOTkgMTg0LjMyOSA0NC40MDk5SDE2Ny4yMDFWMjYuMzA1OFoiIGZpbGw9IiMxRjFGMjkiLz4KPHBhdGggZD0iTTc5NC44MzUgMTIuNzUyNEg4NjAuMjA2VjI2LjMwMjhIODM0Ljg4OVY4NS4zODQ0SDgyMC4xNTZWMjYuMzAyOEg3OTQuODM1VjEyLjc1MjRaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik03NDEuOTA3IDU3LjQxNThINzYxLjI0MUM3NjUuNTkyIDU3LjQ2NjEgNzY5Ljg3NCA1Ni4zMzQ3IDc3My42MjQgNTQuMTQzN0M3N
zYuODg0IDUyLjIxOTkgNzc5LjU3NSA0OS40NzM2IDc4MS40MjEgNDYuMTgzNUM3ODMuMjY4IDQyLjg5MzQgNzg0LjIwNiAzOS4xNzY2IDc4NC4xMzkgMzUuNDEwMkM3ODQuMTM5IDIxLjk2NzUgNzczLjY2NCAxMi43NTc0IDc1OS4wMzkgMTIuNzU3NEg3MjcuMTc1Vjg1LjM4OTRINzQxLjkwN1Y1Ny40MjA3VjU3LjQxNThaTTc0MS45MDcgMjYuMzA1OEg3NTkuMDM1Qzc2NS4zNjUgMjYuMzA1OCA3NjkuNDAzIDI5Ljg4NCA3NjkuNDAzIDM1LjQxMzNDNzY5LjQwMyA0MC44MzI5IDc2NS4zNjUgNDQuNDA5OSA3NTkuMDM1IDQ0LjQwOTlINzQxLjkwN1YyNi4zMDU4WiIgZmlsbD0iIzFGMUYyOSIvPgo8cGF0aCBkPSJNNjgxLjA2OSA0Ny4wMTE1VjU5LjAxMjVINjk1LjM3OVY3MS42NzE5QzY5Mi41MjYgNzMuNDM2OCA2ODguNTI0IDc0LjMzMTkgNjgzLjQ3NyA3NC4zMzE5QzY2Ni4wMDMgNzQuMzMxOSA2NTguMDQ1IDYxLjgxMjQgNjU4LjA0NSA1MC4xOEM2NTguMDQ1IDMzLjk2MDUgNjcxLjAwOCAyNS40NzMyIDY4My44MTIgMjUuNDczMkM2OTAuNDI1IDI1LjQ2MjggNjk2LjkwOSAyNy4yODA0IDcwMi41NDEgMzAuNzIyNkw3MDMuMTU3IDMxLjEyNTRMNzA1Ljk1OCAxOC4xODZMNzA1LjY2MyAxNy45OTc3QzcwMC4wNDYgMTQuNDAwNCA2OTEuMjkxIDEyLjI1OSA2ODIuMjUxIDEyLjI1OUM2NjMuMTk3IDEyLjI1OSA2NDIuOTQ5IDI1LjM5NjcgNjQyLjk0OSA0OS43NDVDNjQyLjk0OSA2MS4wODQ1IDY0Ny4yOTMgNzAuNzE3NCA2NTUuNTExIDc3LjYwMjlDNjYzLjIyNCA4My44MjQ1IDY3Mi44NzQgODcuMTg5IDY4Mi44MDkgODcuMTIwMUM2OTQuMzYzIDg3LjEyMDEgNzAzLjA2MSA4NC42NDk1IDcwOS40MDIgNzkuNTY5Mkw3MDkuNTg5IDc5LjQxODFWNDcuMDExNUg2ODEuMDY5WiIgZmlsbD0iIzFGMUYyOSIvPgo8L2c+CjxkZWZzPgo8Y2xpcFBhdGggaWQ9ImNsaXAwXzVfMTkiPgo8cmVjdCB3aWR0aD0iNzA3Ljc3OCIgaGVpZ2h0PSI3NC44NjExIiBmaWxsPSJ3aGl0ZSIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoMTUyLjQ0NCAxMi4yNSkiLz4KPC9jbGlwUGF0aD4KPC9kZWZzPgo8L3N2Zz4K"
2 |
--------------------------------------------------------------------------------
/private_gpt/server/chunks/chunks_service.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Literal
2 |
3 | from injector import inject, singleton
4 | from llama_index import ServiceContext, StorageContext, VectorStoreIndex
5 | from llama_index.schema import NodeWithScore
6 | from pydantic import BaseModel, Field
7 |
8 | from private_gpt.components.embedding.embedding_component import EmbeddingComponent
9 | from private_gpt.components.llm.llm_component import LLMComponent
10 | from private_gpt.components.node_store.node_store_component import NodeStoreComponent
11 | from private_gpt.components.vector_store.vector_store_component import (
12 | VectorStoreComponent,
13 | )
14 | from private_gpt.open_ai.extensions.context_filter import ContextFilter
15 | from private_gpt.server.ingest.model import IngestedDoc
16 |
17 | if TYPE_CHECKING:
18 | from llama_index.schema import RelatedNodeInfo
19 |
20 |
21 | class Chunk(BaseModel):
22 | object: Literal["context.chunk"]
23 | score: float = Field(examples=[0.023])
24 | document: IngestedDoc
25 | text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
26 | previous_texts: list[str] | None = Field(
27 | default=None,
28 | examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
29 | )
30 | next_texts: list[str] | None = Field(
31 | default=None,
32 | examples=[
33 | [
34 | "New leads came from Google Ads campaign.",
35 | "The campaign was run by the Marketing Department",
36 | ]
37 | ],
38 | )
39 |
40 | @classmethod
41 | def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
42 | doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
43 | return cls(
44 | object="context.chunk",
45 | score=node.score or 0.0,
46 | document=IngestedDoc(
47 | object="ingest.document",
48 | doc_id=doc_id,
49 | doc_metadata=node.metadata,
50 | ),
51 | text=node.get_content(),
52 | )
53 |
54 |
55 | @singleton
56 | class ChunksService:
57 | @inject
58 | def __init__(
59 | self,
60 | llm_component: LLMComponent,
61 | vector_store_component: VectorStoreComponent,
62 | embedding_component: EmbeddingComponent,
63 | node_store_component: NodeStoreComponent,
64 | ) -> None:
65 | self.vector_store_component = vector_store_component
66 | self.storage_context = StorageContext.from_defaults(
67 | vector_store=vector_store_component.vector_store,
68 | docstore=node_store_component.doc_store,
69 | index_store=node_store_component.index_store,
70 | )
71 | self.query_service_context = ServiceContext.from_defaults(
72 | llm=llm_component.llm, embed_model=embedding_component.embedding_model
73 | )
74 |
75 | def _get_sibling_nodes_text(
76 | self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
77 | ) -> list[str]:
78 | explored_nodes_texts = []
79 | current_node = node_with_score.node
80 | for _ in range(related_number):
81 | explored_node_info: RelatedNodeInfo | None = (
82 | current_node.next_node if forward else current_node.prev_node
83 | )
84 | if explored_node_info is None:
85 | break
86 |
87 | explored_node = self.storage_context.docstore.get_node(
88 | explored_node_info.node_id
89 | )
90 |
91 | explored_nodes_texts.append(explored_node.get_content())
92 | current_node = explored_node
93 |
94 | return explored_nodes_texts
95 |
96 | def retrieve_relevant(
97 | self,
98 | text: str,
99 | context_filter: ContextFilter | None = None,
100 | limit: int = 10,
101 | prev_next_chunks: int = 0,
102 | ) -> list[Chunk]:
103 | index = VectorStoreIndex.from_vector_store(
104 | self.vector_store_component.vector_store,
105 | storage_context=self.storage_context,
106 | service_context=self.query_service_context,
107 | show_progress=True,
108 | )
109 | vector_index_retriever = self.vector_store_component.get_retriever(
110 | index=index, context_filter=context_filter, similarity_top_k=limit
111 | )
112 | nodes = vector_index_retriever.retrieve(text)
113 | nodes.sort(key=lambda n: n.score or 0.0, reverse=True)
114 |
115 | retrieved_nodes = []
116 | for node in nodes:
117 | chunk = Chunk.from_node(node)
118 | chunk.previous_texts = self._get_sibling_nodes_text(
119 | node, prev_next_chunks, False
120 | )
121 | chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
122 | retrieved_nodes.append(chunk)
123 |
124 | return retrieved_nodes
125 |
--------------------------------------------------------------------------------
/private_gpt/components/vector_store/vector_store_component.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing
3 |
4 | from injector import inject, singleton
5 | from llama_index import VectorStoreIndex
6 | from llama_index.indices.vector_store import VectorIndexRetriever
7 | from llama_index.vector_stores.types import VectorStore
8 |
9 | from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
10 | from private_gpt.open_ai.extensions.context_filter import ContextFilter
11 | from private_gpt.paths import local_data_path
12 | from private_gpt.settings.settings import Settings
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | @typing.no_type_check
18 | def _chromadb_doc_id_metadata_filter(
19 | context_filter: ContextFilter | None,
20 | ) -> dict | None:
21 | if context_filter is None or context_filter.docs_ids is None:
22 | return {} # No filter
23 | elif len(context_filter.docs_ids) < 1:
24 | return {"doc_id": "-"} # Effectively filtering out all docs
25 | else:
26 | doc_filter_items = []
27 | if len(context_filter.docs_ids) > 1:
28 | doc_filter = {"$or": doc_filter_items}
29 | for doc_id in context_filter.docs_ids:
30 | doc_filter_items.append({"doc_id": doc_id})
31 | else:
32 | doc_filter = {"doc_id": context_filter.docs_ids[0]}
33 | return doc_filter
34 |
35 |
36 | @singleton
37 | class VectorStoreComponent:
38 | vector_store: VectorStore
39 |
40 | @inject
41 | def __init__(self, settings: Settings) -> None:
42 | match settings.vectorstore.database:
43 | case "chroma":
44 | try:
45 | import chromadb # type: ignore
46 | from chromadb.config import ( # type: ignore
47 | Settings as ChromaSettings,
48 | )
49 | except ImportError as e:
50 | raise ImportError(
51 | "'chromadb' is not installed. "
52 | "To use PrivateGPT with Chroma, install the 'chroma' extra. "
53 | "`poetry install --extras chroma`"
54 | ) from e
55 |
56 | chroma_settings = ChromaSettings(anonymized_telemetry=False)
57 | chroma_client = chromadb.PersistentClient(
58 | path=str((local_data_path / "chroma_db").absolute()),
59 | settings=chroma_settings,
60 | )
61 | chroma_collection = chroma_client.get_or_create_collection(
62 | "make_this_parameterizable_per_api_call"
63 | ) # TODO
64 |
65 | self.vector_store = typing.cast(
66 | VectorStore,
67 | BatchedChromaVectorStore(
68 | chroma_client=chroma_client, chroma_collection=chroma_collection
69 | ),
70 | )
71 |
72 | case "qdrant":
73 | from llama_index.vector_stores.qdrant import QdrantVectorStore
74 | from qdrant_client import QdrantClient
75 |
76 | if settings.qdrant is None:
77 | logger.info(
78 | "Qdrant config not found. Using default settings. "
79 | "Trying to connect to Qdrant at localhost:6333."
80 | )
81 | client = QdrantClient()
82 | else:
83 | client = QdrantClient(
84 | **settings.qdrant.model_dump(exclude_none=True)
85 | )
86 | self.vector_store = typing.cast(
87 | VectorStore,
88 | QdrantVectorStore(
89 | client=client,
90 | collection_name="make_this_parameterizable_per_api_call",
91 | ), # TODO
92 | )
93 | case _:
94 | # Should be unreachable
95 | # The settings validator should have caught this
96 | raise ValueError(
97 | f"Vectorstore database {settings.vectorstore.database} not supported"
98 | )
99 |
100 | @staticmethod
101 | def get_retriever(
102 | index: VectorStoreIndex,
103 | context_filter: ContextFilter | None = None,
104 | similarity_top_k: int = 2,
105 | ) -> VectorIndexRetriever:
106 | # This way we support qdrant (using doc_ids) and chroma (using where clause)
107 | return VectorIndexRetriever(
108 | index=index,
109 | similarity_top_k=similarity_top_k,
110 | doc_ids=context_filter.docs_ids if context_filter else None,
111 | vector_store_kwargs={
112 | "where": _chromadb_doc_id_metadata_filter(context_filter)
113 | },
114 | )
115 |
116 | def close(self) -> None:
117 | if hasattr(self.vector_store.client, "close"):
118 | self.vector_store.client.close()
119 |
--------------------------------------------------------------------------------
/fern/docs/pages/manual/ingestion.mdx:
--------------------------------------------------------------------------------
1 | # Ingesting & Managing Documents
2 |
3 | The ingestion of documents can be done in different ways:
4 |
5 | * Using the `/ingest` API
6 | * Using the Gradio UI
7 | * Using the Bulk Local Ingestion functionality (check next section)
8 |
9 | ## Bulk Local Ingestion
10 |
11 | When you are running PrivateGPT in a fully local setup, you can ingest a complete folder for convenience (containing
12 | pdf, text files, etc.)
13 | and optionally watch for changes in it with the command:
14 |
15 | ```bash
16 | make ingest /path/to/folder -- --watch
17 | ```
18 |
19 | To log the processed and failed files to an additional file, use:
20 |
21 | ```bash
22 | make ingest /path/to/folder -- --watch --log-file /path/to/log/file.log
23 | ```
24 |
25 | **Note for Windows Users:** Depending on your Windows version and whether you are using PowerShell to execute
26 | PrivateGPT API calls, you may need to include the parameter name before passing the folder path for consumption:
27 |
28 | ```bash
29 | make ingest arg=/path/to/folder -- --watch --log-file /path/to/log/file.log
30 | ```
31 |
32 | After ingestion is complete, you should be able to chat with your documents
33 | by navigating to http://localhost:8001 and using the option `Query documents`,
34 | or using the completions / chat API.
35 |
36 | ## Ingestion troubleshooting
37 |
38 | ### Running out of memory
39 |
40 | To avoid running out of memory, you should ingest your documents without the LLM loaded in your (video) memory.
41 | To do so, you should change your configuration to set `llm.mode: mock`.
42 |
43 | You can also use the existing `PGPT_PROFILES=mock` profile, which will set the following configuration for you:
44 |
45 | ```yaml
46 | llm:
47 | mode: mock
48 | embedding:
49 | mode: local
50 | ```
51 |
52 | This configuration allows you to use hardware acceleration for creating embeddings while avoiding loading the full LLM into (video) memory.
53 |
54 | Once your documents are ingested, you can set the `llm.mode` value back to `local` (or your previous custom value).
55 |
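For example, a typical flow is to run the bulk ingestion with the mock profile active and then restart the server with your regular profile. This is a sketch that assumes the `make ingest` target shown above and a `make run` target for starting the server; use whatever commands you normally use to ingest and launch PrivateGPT:

```bash
# Ingest with the LLM mocked out, so only the embedding model uses (video) memory
PGPT_PROFILES=mock make ingest /path/to/folder

# Once ingestion is done, start the server again with your usual profile (e.g. local)
PGPT_PROFILES=local make run
```
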
56 | ### Ingestion speed
57 |
58 | The ingestion speed depends on the number of documents you are ingesting, and the size of each document.
59 | To speed up the ingestion, you can change the ingestion mode in the configuration.
60 |
61 | The following ingestion modes exist:
62 | * `simple`: historic behavior, ingest one document at a time, sequentially
63 | * `batch`: read, parse, and embed multiple documents in batches (batch read, then batch parse, then batch embed)
64 | * `parallel`: read, parse, and embed multiple documents in parallel. This is the fastest ingestion mode for local setups.
65 | To change the ingestion mode, use the `embedding.ingest_mode` configuration value. The default value is `simple`.
66 |
67 | To configure the number of workers used for parallel or batched ingestion, you can use
68 | the `embedding.count_workers` configuration value. If you set this value too high, you might run out of
69 | memory, so be mindful when setting this value. The default value is `2`.
70 | For `batch` mode, you can usually set this value to the number of threads available on your CPU without
71 | running out of memory. For `parallel` mode, you should be more careful and set it to a lower value.
72 |
73 | The configuration below should be enough for users who want to push their hardware harder:
74 | ```yaml
75 | embedding:
76 | ingest_mode: parallel
77 | count_workers: 4
78 | ```
79 |
80 | If your hardware is powerful enough and you are ingesting heavy documents, you can increase the number of workers.
81 | It is recommended to do your own tests to find the optimal value for your hardware.
82 |
83 | If you have a `bash` shell, you can use this set of commands to run your own benchmark:
84 |
85 | ```bash
86 | # Wipe your local data, to put yourself in a clean state
87 | # This will delete all your ingested documents
88 | make wipe
89 |
90 | time PGPT_PROFILES=mock python ./scripts/ingest_folder.py ~/my-dir/to-ingest/
91 | ```
92 |
93 | ## Supported file formats
94 |
95 | privateGPT by default supports all the file formats that contain plain text (for example, `.txt` files, `.html`, etc.).
96 | However, these text-based file formats are only treated as plain text files, and are not pre-processed in any other way.
97 |
98 | It also supports the following file formats:
99 | * `.hwp`
100 | * `.pdf`
101 | * `.docx`
102 | * `.pptx`
103 | * `.ppt`
104 | * `.pptm`
105 | * `.jpg`
106 | * `.png`
107 | * `.jpeg`
108 | * `.mp3`
109 | * `.mp4`
110 | * `.csv`
111 | * `.epub`
112 | * `.md`
113 | * `.mbox`
114 | * `.ipynb`
115 | * `.json`
116 |
117 | **Please note the following nuance**: while `privateGPT` supports these file formats, it **might** require additional
118 | dependencies to be installed in your Python virtual environment.
119 | For example, if you try to ingest `.epub` files, `privateGPT` might fail to do so, and will instead display an
120 | explanatory error asking you to install the dependencies needed to support this file format.
121 |
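As an illustration, missing dependencies can usually be installed into the Poetry-managed environment. The package names below are an assumption based on typical `.epub` reader requirements; always prefer the exact packages named in the error message:

```bash
# Example only: install extra parsing dependencies inside the project's environment
# (follow the error message for the packages your file type actually needs)
poetry run pip install EbookLib html2text
```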
122 |
123 | **Other file formats might work**, but they will be treated as plain text
124 | files (in other words, they will be ingested as `.txt` files).
--------------------------------------------------------------------------------
/private_gpt/server/ingest/ingest_service.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import tempfile
3 | from pathlib import Path
4 | from typing import AnyStr, BinaryIO
5 |
6 | from injector import inject, singleton
7 | from llama_index import (
8 | ServiceContext,
9 | StorageContext,
10 | )
11 | from llama_index.node_parser import SentenceWindowNodeParser
12 |
13 | from private_gpt.components.embedding.embedding_component import EmbeddingComponent
14 | from private_gpt.components.ingest.ingest_component import get_ingestion_component
15 | from private_gpt.components.llm.llm_component import LLMComponent
16 | from private_gpt.components.node_store.node_store_component import NodeStoreComponent
17 | from private_gpt.components.vector_store.vector_store_component import (
18 | VectorStoreComponent,
19 | )
20 | from private_gpt.server.ingest.model import IngestedDoc
21 | from private_gpt.settings.settings import settings
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | @singleton
27 | class IngestService:
28 | @inject
29 | def __init__(
30 | self,
31 | llm_component: LLMComponent,
32 | vector_store_component: VectorStoreComponent,
33 | embedding_component: EmbeddingComponent,
34 | node_store_component: NodeStoreComponent,
35 | ) -> None:
36 | self.llm_service = llm_component
37 | self.storage_context = StorageContext.from_defaults(
38 | vector_store=vector_store_component.vector_store,
39 | docstore=node_store_component.doc_store,
40 | index_store=node_store_component.index_store,
41 | )
42 | node_parser = SentenceWindowNodeParser.from_defaults()
43 | self.ingest_service_context = ServiceContext.from_defaults(
44 | llm=self.llm_service.llm,
45 | embed_model=embedding_component.embedding_model,
46 | node_parser=node_parser,
47 | # Embeddings done early in the pipeline of node transformations, right
48 | # after the node parsing
49 | transformations=[node_parser, embedding_component.embedding_model],
50 | )
51 |
52 | self.ingest_component = get_ingestion_component(
53 | self.storage_context, self.ingest_service_context, settings=settings()
54 | )
55 |
56 | def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
57 | logger.debug("Got file data of size=%s to ingest", len(file_data))
58 | # llama-index mainly supports reading from files, so
59 | # we have to create a tmp file to read for it to work
60 | # delete=False to avoid a Windows 11 permission error.
61 | with tempfile.NamedTemporaryFile(delete=False) as tmp:
62 | try:
63 | path_to_tmp = Path(tmp.name)
64 | if isinstance(file_data, bytes):
65 | path_to_tmp.write_bytes(file_data)
66 | else:
67 | path_to_tmp.write_text(str(file_data))
68 | return self.ingest_file(file_name, path_to_tmp)
69 | finally:
70 | tmp.close()
71 | path_to_tmp.unlink()
72 |
73 | def ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
74 | logger.info("Ingesting file_name=%s", file_name)
75 | documents = self.ingest_component.ingest(file_name, file_data)
76 | logger.info("Finished ingestion file_name=%s", file_name)
77 | return [IngestedDoc.from_document(document) for document in documents]
78 |
79 | def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]:
80 | logger.debug("Ingesting text data with file_name=%s", file_name)
81 | return self._ingest_data(file_name, text)
82 |
83 | def ingest_bin_data(
84 | self, file_name: str, raw_file_data: BinaryIO
85 | ) -> list[IngestedDoc]:
86 | logger.debug("Ingesting binary data with file_name=%s", file_name)
87 | file_data = raw_file_data.read()
88 | return self._ingest_data(file_name, file_data)
89 |
90 | def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
91 | logger.info("Ingesting file_names=%s", [f[0] for f in files])
92 | documents = self.ingest_component.bulk_ingest(files)
93 |         logger.info("Finished ingestion file_names=%s", [f[0] for f in files])
94 | return [IngestedDoc.from_document(document) for document in documents]
95 |
96 | def list_ingested(self) -> list[IngestedDoc]:
97 | ingested_docs = []
98 | try:
99 | docstore = self.storage_context.docstore
100 | ingested_docs_ids: set[str] = set()
101 |
102 | for node in docstore.docs.values():
103 | if node.ref_doc_id is not None:
104 | ingested_docs_ids.add(node.ref_doc_id)
105 |
106 | for doc_id in ingested_docs_ids:
107 | ref_doc_info = docstore.get_ref_doc_info(ref_doc_id=doc_id)
108 | doc_metadata = None
109 | if ref_doc_info is not None and ref_doc_info.metadata is not None:
110 | doc_metadata = IngestedDoc.curate_metadata(ref_doc_info.metadata)
111 | ingested_docs.append(
112 | IngestedDoc(
113 | object="ingest.document",
114 | doc_id=doc_id,
115 | doc_metadata=doc_metadata,
116 | )
117 | )
118 | except ValueError:
119 | logger.warning("Got an exception when getting list of docs", exc_info=True)
120 | pass
121 | logger.debug("Found count=%s ingested documents", len(ingested_docs))
122 | return ingested_docs
123 |
124 | def delete(self, doc_id: str) -> None:
125 | """Delete an ingested document.
126 |
127 | :raises ValueError: if the document does not exist
128 | """
129 | logger.info(
130 | "Deleting the ingested document=%s in the doc and index store", doc_id
131 | )
132 | self.ingest_component.delete(doc_id)
133 |
--------------------------------------------------------------------------------
/private_gpt/server/chat/chat_service.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from injector import inject, singleton
4 | from llama_index import ServiceContext, StorageContext, VectorStoreIndex
5 | from llama_index.chat_engine import ContextChatEngine, SimpleChatEngine
6 | from llama_index.chat_engine.types import (
7 | BaseChatEngine,
8 | )
9 | from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
10 | from llama_index.llms import ChatMessage, MessageRole
11 | from llama_index.types import TokenGen
12 | from pydantic import BaseModel
13 |
14 | from private_gpt.components.embedding.embedding_component import EmbeddingComponent
15 | from private_gpt.components.llm.llm_component import LLMComponent
16 | from private_gpt.components.node_store.node_store_component import NodeStoreComponent
17 | from private_gpt.components.vector_store.vector_store_component import (
18 | VectorStoreComponent,
19 | )
20 | from private_gpt.open_ai.extensions.context_filter import ContextFilter
21 | from private_gpt.server.chunks.chunks_service import Chunk
22 |
23 |
24 | class Completion(BaseModel):
25 | response: str
26 | sources: list[Chunk] | None = None
27 |
28 |
29 | class CompletionGen(BaseModel):
30 | response: TokenGen
31 | sources: list[Chunk] | None = None
32 |
33 |
34 | @dataclass
35 | class ChatEngineInput:
36 | system_message: ChatMessage | None = None
37 | last_message: ChatMessage | None = None
38 | chat_history: list[ChatMessage] | None = None
39 |
40 | @classmethod
41 | def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
42 | # Detect if there is a system message, extract the last message and chat history
43 | system_message = (
44 | messages[0]
45 | if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
46 | else None
47 | )
48 | last_message = (
49 | messages[-1]
50 | if len(messages) > 0 and messages[-1].role == MessageRole.USER
51 | else None
52 | )
53 | # Remove from messages list the system message and last message,
54 | # if they exist. The rest is the chat history.
55 | if system_message:
56 | messages.pop(0)
57 | if last_message:
58 | messages.pop(-1)
59 | chat_history = messages if len(messages) > 0 else None
60 |
61 | return cls(
62 | system_message=system_message,
63 | last_message=last_message,
64 | chat_history=chat_history,
65 | )
66 |
67 |
68 | @singleton
69 | class ChatService:
70 | @inject
71 | def __init__(
72 | self,
73 | llm_component: LLMComponent,
74 | vector_store_component: VectorStoreComponent,
75 | embedding_component: EmbeddingComponent,
76 | node_store_component: NodeStoreComponent,
77 | ) -> None:
78 | self.llm_service = llm_component
79 | self.vector_store_component = vector_store_component
80 | self.storage_context = StorageContext.from_defaults(
81 | vector_store=vector_store_component.vector_store,
82 | docstore=node_store_component.doc_store,
83 | index_store=node_store_component.index_store,
84 | )
85 | self.service_context = ServiceContext.from_defaults(
86 | llm=llm_component.llm, embed_model=embedding_component.embedding_model
87 | )
88 | self.index = VectorStoreIndex.from_vector_store(
89 | vector_store_component.vector_store,
90 | storage_context=self.storage_context,
91 | service_context=self.service_context,
92 | show_progress=True,
93 | )
94 |
95 | def _chat_engine(
96 | self,
97 | system_prompt: str | None = None,
98 | use_context: bool = False,
99 | context_filter: ContextFilter | None = None,
100 | ) -> BaseChatEngine:
101 | if use_context:
102 | vector_index_retriever = self.vector_store_component.get_retriever(
103 | index=self.index, context_filter=context_filter
104 | )
105 | return ContextChatEngine.from_defaults(
106 | system_prompt=system_prompt,
107 | retriever=vector_index_retriever,
108 | service_context=self.service_context,
109 | node_postprocessors=[
110 | MetadataReplacementPostProcessor(target_metadata_key="window"),
111 | ],
112 | )
113 | else:
114 | return SimpleChatEngine.from_defaults(
115 | system_prompt=system_prompt,
116 | service_context=self.service_context,
117 | )
118 |
119 | def stream_chat(
120 | self,
121 | messages: list[ChatMessage],
122 | use_context: bool = False,
123 | context_filter: ContextFilter | None = None,
124 | ) -> CompletionGen:
125 | chat_engine_input = ChatEngineInput.from_messages(messages)
126 | last_message = (
127 | chat_engine_input.last_message.content
128 | if chat_engine_input.last_message
129 | else None
130 | )
131 | system_prompt = (
132 | chat_engine_input.system_message.content
133 | if chat_engine_input.system_message
134 | else None
135 | )
136 | chat_history = (
137 | chat_engine_input.chat_history if chat_engine_input.chat_history else None
138 | )
139 |
140 | chat_engine = self._chat_engine(
141 | system_prompt=system_prompt,
142 | use_context=use_context,
143 | context_filter=context_filter,
144 | )
145 | streaming_response = chat_engine.stream_chat(
146 | message=last_message if last_message is not None else "",
147 | chat_history=chat_history,
148 | )
149 | sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
150 | completion_gen = CompletionGen(
151 | response=streaming_response.response_gen, sources=sources
152 | )
153 | return completion_gen
154 |
155 | def chat(
156 | self,
157 | messages: list[ChatMessage],
158 | use_context: bool = False,
159 | context_filter: ContextFilter | None = None,
160 | ) -> Completion:
161 | chat_engine_input = ChatEngineInput.from_messages(messages)
162 | last_message = (
163 | chat_engine_input.last_message.content
164 | if chat_engine_input.last_message
165 | else None
166 | )
167 | system_prompt = (
168 | chat_engine_input.system_message.content
169 | if chat_engine_input.system_message
170 | else None
171 | )
172 | chat_history = (
173 | chat_engine_input.chat_history if chat_engine_input.chat_history else None
174 | )
175 |
176 | chat_engine = self._chat_engine(
177 | system_prompt=system_prompt,
178 | use_context=use_context,
179 | context_filter=context_filter,
180 | )
181 | wrapped_response = chat_engine.chat(
182 | message=last_message if last_message is not None else "",
183 | chat_history=chat_history,
184 | )
185 | sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
186 | completion = Completion(response=wrapped_response.response, sources=sources)
187 | return completion
188 |
--------------------------------------------------------------------------------
/private_gpt/components/llm/prompt_helper.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 | from collections.abc import Sequence
4 | from typing import Any, Literal
5 |
6 | from llama_index.llms import ChatMessage, MessageRole
7 | from llama_index.llms.llama_utils import (
8 | completion_to_prompt,
9 | messages_to_prompt,
10 | )
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class AbstractPromptStyle(abc.ABC):
16 | """Abstract class for prompt styles.
17 |
18 | This class is used to format a series of messages into a prompt that can be
19 | understood by the models. A series of messages represents the interaction(s)
20 | between a user and an assistant. This series of messages can be considered as a
21 |     session between a user X and an assistant Y. This session holds, through the
22 | messages, the state of the conversation. This session, to be understood by the
23 | model, needs to be formatted into a prompt (i.e. a string that the models
24 | can understand). Prompts can be formatted in different ways,
25 | depending on the model.
26 |
27 | The implementations of this class represent the different ways to format a
28 | series of messages into a prompt.
29 | """
30 |
31 | def __init__(self, *args: Any, **kwargs: Any) -> None:
32 | logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
33 |
34 | @abc.abstractmethod
35 | def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
36 | pass
37 |
38 | @abc.abstractmethod
39 | def _completion_to_prompt(self, completion: str) -> str:
40 | pass
41 |
42 | def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
43 | prompt = self._messages_to_prompt(messages)
44 | logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
45 | return prompt
46 |
47 | def completion_to_prompt(self, completion: str) -> str:
48 | prompt = self._completion_to_prompt(completion)
49 | logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
50 | return prompt
51 |
52 |
53 | class DefaultPromptStyle(AbstractPromptStyle):
54 | """Default prompt style that uses the defaults from llama_utils.
55 |
56 | It basically passes None to the LLM, indicating it should use
57 | the default functions.
58 | """
59 |
60 | def __init__(self, *args: Any, **kwargs: Any) -> None:
61 | super().__init__(*args, **kwargs)
62 |
63 | # Hacky way to override the functions
64 | # Override the functions to be None, and pass None to the LLM.
65 | self.messages_to_prompt = None # type: ignore[method-assign, assignment]
66 | self.completion_to_prompt = None # type: ignore[method-assign, assignment]
67 |
68 | def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
69 | return ""
70 |
71 | def _completion_to_prompt(self, completion: str) -> str:
72 | return ""
73 |
74 |
75 | class Llama2PromptStyle(AbstractPromptStyle):
76 | """Simple prompt style that just uses the default llama_utils functions.
77 |
78 | It transforms the sequence of messages into a prompt that should look like:
79 | ```text
80 |     <s> [INST] <<SYS>> your system prompt here. <</SYS>>
81 |
82 |     user message here [/INST] assistant (model) response here </s>
83 | ```
84 | """
85 |
86 | def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
87 | return messages_to_prompt(messages)
88 |
89 | def _completion_to_prompt(self, completion: str) -> str:
90 | return completion_to_prompt(completion)
91 |
92 |
93 | class TagPromptStyle(AbstractPromptStyle):
94 | """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
95 |
96 | It transforms the sequence of messages into a prompt that should look like:
97 | ```text
98 | <|system|>: your system prompt here.
99 | <|user|>: user message here
100 | (possibly with context and question)
101 | <|assistant|>: assistant (model) response here.
102 | ```
103 |
104 |     FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
105 | """
106 |
107 | def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
108 | """Format message to prompt with `<|ROLE|>: MSG` style."""
109 | prompt = ""
110 | for message in messages:
111 | role = message.role
112 | content = message.content or ""
113 | message_from_user = f"<|{role.lower()}|>: {content.strip()}"
114 | message_from_user += "\n"
115 | prompt += message_from_user
116 | # we are missing the last <|assistant|> tag that will trigger a completion
117 | prompt += "<|assistant|>: "
118 | return prompt
119 |
120 | def _completion_to_prompt(self, completion: str) -> str:
121 | return self._messages_to_prompt(
122 | [ChatMessage(content=completion, role=MessageRole.USER)]
123 | )
124 |
125 |
126 | class MistralPromptStyle(AbstractPromptStyle):
127 | def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
128 |         prompt = "<s>"
129 | for message in messages:
130 | role = message.role
131 | content = message.content or ""
132 | if role.lower() == "system":
133 | message_from_user = f"[INST] {content.strip()} [/INST]"
134 | prompt += message_from_user
135 | elif role.lower() == "user":
136 |                 prompt += "</s>"
137 | message_from_user = f"[INST] {content.strip()} [/INST]"
138 | prompt += message_from_user
139 | return prompt
140 |
141 | def _completion_to_prompt(self, completion: str) -> str:
142 | return self._messages_to_prompt(
143 | [ChatMessage(content=completion, role=MessageRole.USER)]
144 | )
145 |
146 |
147 | class ChatMLPromptStyle(AbstractPromptStyle):
148 | def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
149 | prompt = "<|im_start|>system\n"
150 | for message in messages:
151 | role = message.role
152 | content = message.content or ""
153 | if role.lower() == "system":
154 | message_from_user = f"{content.strip()}"
155 | prompt += message_from_user
156 | elif role.lower() == "user":
157 | prompt += "<|im_end|>\n<|im_start|>user\n"
158 | message_from_user = f"{content.strip()}<|im_end|>\n"
159 | prompt += message_from_user
160 | prompt += "<|im_start|>assistant\n"
161 | return prompt
162 |
163 | def _completion_to_prompt(self, completion: str) -> str:
164 | return self._messages_to_prompt(
165 | [ChatMessage(content=completion, role=MessageRole.USER)]
166 | )
167 |
168 |
169 | def get_prompt_style(
170 | prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
171 | ) -> AbstractPromptStyle:
172 | """Get the prompt style to use from the given string.
173 |
174 | :param prompt_style: The prompt style to use.
175 | :return: The prompt style to use.
176 | """
177 | if prompt_style is None or prompt_style == "default":
178 | return DefaultPromptStyle()
179 | elif prompt_style == "llama2":
180 | return Llama2PromptStyle()
181 | elif prompt_style == "tag":
182 | return TagPromptStyle()
183 | elif prompt_style == "mistral":
184 | return MistralPromptStyle()
185 | elif prompt_style == "chatml":
186 | return ChatMLPromptStyle()
187 | raise ValueError(f"Unknown prompt_style='{prompt_style}'")
188 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🔒 PrivateGPT 📑
2 |
3 | [](https://github.com/imartinez/privateGPT/actions/workflows/tests.yml?query=branch%3Amain)
4 | [](https://docs.privategpt.dev/)
5 |
6 | [](https://discord.gg/bK6mRVpErU)
7 | [](https://twitter.com/PrivateGPT_AI)
8 |
9 |
10 | > Install & usage docs: https://docs.privategpt.dev/
11 | >
12 | > Join the community: [Twitter](https://twitter.com/PrivateGPT_AI) & [Discord](https://discord.gg/bK6mRVpErU)
13 |
14 | 
15 |
16 | PrivateGPT is a production-ready AI project that allows you to ask questions about your documents using the power
17 | of Large Language Models (LLMs), even in scenarios without an Internet connection. 100% private, no data leaves your
18 | execution environment at any point.
19 |
20 | The project provides an API offering all the primitives required to build private, context-aware AI applications.
21 | It follows and extends the [OpenAI API standard](https://openai.com/blog/openai-api),
22 | and supports both normal and streaming responses.
23 |
24 | The API is divided into two logical blocks:
25 |
26 | **High-level API**, which abstracts all the complexity of a RAG (Retrieval Augmented Generation)
27 | pipeline implementation:
28 | - Ingestion of documents: internally managing document parsing,
29 | splitting, metadata extraction, embedding generation and storage.
30 | - Chat & Completions using context from ingested documents:
31 | abstracting the retrieval of context, the prompt engineering and the response generation (see the example request below).
32 |
33 | **Low-level API**, which allows advanced users to implement their own complex pipelines:
34 | - Embeddings generation: based on a piece of text.
35 | - Contextual chunks retrieval: given a query, returns the most relevant chunks of text from the ingested documents.
36 |
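For instance, a context-aware chat request can be issued with a plain HTTP call. This is a minimal sketch assuming the server listens on the default port `8001` and exposes the OpenAI-style `/v1/chat/completions` route with the `use_context` extension; field names may vary between versions, so check the API reference:

```bash
curl -X POST http://localhost:8001/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "messages": [{"role": "user", "content": "Summarize the ingested documents"}],
        "use_context": true,
        "stream": false
      }'
```

When `use_context` is enabled, the answer is grounded on chunks retrieved from the ingested documents, and the response should also reference those source chunks.
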
37 | In addition to this, a working [Gradio UI](https://www.gradio.app/)
38 | client is provided to test the API, together with a set of useful tools such as bulk model
39 | download script, ingestion script, documents folder watch, etc.
40 |
41 | > 👂 **Need help applying PrivateGPT to your specific use case?**
42 | > [Let us know more about it](https://forms.gle/4cSDmH13RZBHV9at7)
43 | > and we'll try to help! We are refining PrivateGPT through your feedback.
44 |
45 | ## 🎞️ Overview
46 | DISCLAIMER: This README is not updated as frequently as the [documentation](https://docs.privategpt.dev/).
47 | Please check it out for the latest updates!
48 |
49 | ### Motivation behind PrivateGPT
50 | Generative AI is a game changer for our society, but adoption in companies of all sizes and data-sensitive
51 | domains like healthcare or legal is limited by a clear concern: **privacy**.
52 | Not being able to ensure that your data is fully under your control when using third-party AI tools
53 | is a risk those industries cannot take.
54 |
55 | ### Primordial version
56 | The first version of PrivateGPT was launched in May 2023 as a novel approach to address the privacy
57 | concerns by using LLMs in a complete offline way.
58 |
59 | That version, which rapidly became a go-to project for privacy-sensitive setups and served as the seed
60 | for thousands of local-focused generative AI projects, was the foundation of what PrivateGPT is becoming nowadays;
61 | it remains a simpler and more educational implementation for understanding the basic concepts required
62 | to build a fully local, and therefore private, chatGPT-like tool.
63 |
64 | If you want to keep experimenting with it, we have saved it in the
65 | [primordial branch](https://github.com/imartinez/privateGPT/tree/primordial) of the project.
66 |
67 | > It is strongly recommended to do a clean clone and install of this new version of
68 | PrivateGPT if you come from the previous, primordial version.
69 |
70 | ### Present and Future of PrivateGPT
71 | PrivateGPT is now evolving towards becoming a gateway to generative AI models and primitives, including
72 | completions, document ingestion, RAG pipelines and other low-level building blocks.
73 | We want to make it easier for any developer to build AI applications and experiences, as well as provide
74 | a suitable, extensible architecture for the community to keep contributing.
75 |
76 | Stay tuned to our [releases](https://github.com/imartinez/privateGPT/releases) to check out all the new features and changes included.
77 |
78 | ## 📄 Documentation
79 | Full documentation on installation, dependencies, configuration, running the server, deployment options,
80 | ingesting local documents, API details and UI features can be found here: https://docs.privategpt.dev/
81 |
82 | ## 🧩 Architecture
83 | Conceptually, PrivateGPT is an API that wraps a RAG pipeline and exposes its
84 | primitives.
85 | * The API is built using [FastAPI](https://fastapi.tiangolo.com/) and follows
86 | [OpenAI's API scheme](https://platform.openai.com/docs/api-reference).
87 | * The RAG pipeline is based on [LlamaIndex](https://www.llamaindex.ai/).
88 |
89 | The design of PrivateGPT makes it easy to extend and adapt both the API and the
90 | RAG implementation. Some key architectural decisions are:
91 | * Dependency Injection, decoupling the different components and layers.
92 | * Usage of LlamaIndex abstractions such as `LLM`, `BaseEmbedding` or `VectorStore`,
93 | making it straightforward to swap the actual implementations of those abstractions.
94 | * Simplicity, adding as few layers and new abstractions as possible.
95 | * Ready to use, providing a full implementation of the API and RAG
96 | pipeline.
97 |
98 | Main building blocks:
99 | * APIs are defined in `private_gpt:server:`. Each package contains an
100 | `_router.py` (FastAPI layer) and an `_service.py` (the
101 | service implementation). Each *Service* uses LlamaIndex base abstractions instead
102 | of specific implementations,
103 | decoupling the actual implementation from its usage.
104 | * Components are placed in
105 | `private_gpt:components:`. Each *Component* is in charge of providing
106 | actual implementations to the base abstractions used in the Services - for example
107 | `LLMComponent` is in charge of providing an actual implementation of an `LLM`
108 | (for example `LlamaCPP` or `OpenAI`).
109 |
110 | ## 💡 Contributing
111 | Contributions are welcome! To ensure code quality, we have enabled several format and
112 | typing checks; just run `make check` before committing to make sure your code is OK.
113 | Remember to test your code! You'll find a tests folder with helpers, and you can run
114 | tests using the `make test` command.
115 |
116 | Don't know what to contribute? Here is the public
117 | [Project Board](https://github.com/users/imartinez/projects/3) with several ideas.
118 |
119 | Head over to the Discord
120 | #contributors channel and ask for write permissions on that GitHub project.
121 |
122 | ## 💬 Community
123 | Join the conversation around PrivateGPT on our:
124 | - [Twitter (aka X)](https://twitter.com/PrivateGPT_AI)
125 | - [Discord](https://discord.gg/bK6mRVpErU)
126 |
127 | ## 📖 Citation
128 | If you use PrivateGPT in a paper, check out the [Citation file](CITATION.cff) for the correct citation.
129 | You can also use the "Cite this repository" button in this repo to get the citation in different formats.
130 |
131 | Here are a couple of examples:
132 |
133 | #### BibTeX
134 | ```bibtex
135 | @software{Martinez_Toro_PrivateGPT_2023,
136 | author = {Martínez Toro, Iván and Gallego Vico, Daniel and Orgaz, Pablo},
137 | license = {Apache-2.0},
138 | month = may,
139 | title = {{PrivateGPT}},
140 | url = {https://github.com/imartinez/privateGPT},
141 | year = {2023}
142 | }
143 | ```
144 |
145 | #### APA
146 | ```
147 | Martínez Toro, I., Gallego Vico, D., & Orgaz, P. (2023). PrivateGPT [Computer software]. https://github.com/imartinez/privateGPT
148 | ```
149 |
150 | ## 🤗 Partners & Supporters
151 | PrivateGPT is actively supported by the teams behind:
152 | * [Qdrant](https://qdrant.tech/), providing the default vector database
153 | * [Fern](https://buildwithfern.com/), providing Documentation and SDKs
154 | * [LlamaIndex](https://www.llamaindex.ai/), providing the base RAG framework and abstractions
155 |
156 | This project has been strongly influenced and supported by other amazing projects like
157 | [LangChain](https://github.com/hwchase17/langchain),
158 | [GPT4All](https://github.com/nomic-ai/gpt4all),
159 | [LlamaCpp](https://github.com/ggerganov/llama.cpp),
160 | [Chroma](https://www.trychroma.com/)
161 | and [SentenceTransformers](https://www.sbert.net/).
162 |
--------------------------------------------------------------------------------
/private_gpt/components/llm/custom/sagemaker.py:
--------------------------------------------------------------------------------
1 | # mypy: ignore-errors
2 | from __future__ import annotations
3 |
4 | import io
5 | import json
6 | import logging
7 | from typing import TYPE_CHECKING, Any
8 |
9 | import boto3 # type: ignore
10 | from llama_index.bridge.pydantic import Field
11 | from llama_index.llms import (
12 | CompletionResponse,
13 | CustomLLM,
14 | LLMMetadata,
15 | )
16 | from llama_index.llms.base import (
17 | llm_chat_callback,
18 | llm_completion_callback,
19 | )
20 | from llama_index.llms.generic_utils import (
21 | completion_response_to_chat_response,
22 | stream_completion_response_to_chat_response,
23 | )
24 | from llama_index.llms.llama_utils import (
25 | completion_to_prompt as generic_completion_to_prompt,
26 | )
27 | from llama_index.llms.llama_utils import (
28 | messages_to_prompt as generic_messages_to_prompt,
29 | )
30 |
31 | if TYPE_CHECKING:
32 | from collections.abc import Sequence
33 |
34 | from llama_index.callbacks import CallbackManager
35 | from llama_index.llms import (
36 | ChatMessage,
37 | ChatResponse,
38 | ChatResponseGen,
39 | CompletionResponseGen,
40 | )
41 |
42 | logger = logging.getLogger(__name__)
43 |
44 |
45 | class LineIterator:
46 | r"""A helper class for parsing the byte stream input from TGI container.
47 |
48 | The output of the model will be in the following format:
49 | ```
50 | b'data:{"token": {"text": " a"}}\n\n'
51 | b'data:{"token": {"text": " challenging"}}\n\n'
52 | b'data:{"token": {"text": " problem"
53 | b'}}'
54 | ...
55 | ```
56 |
57 | While usually each PayloadPart event from the event stream will contain a byte array
58 | with a full json, this is not guaranteed and some of the json objects may be split
59 | across PayloadPart events. For example:
60 | ```
61 | {'PayloadPart': {'Bytes': b'{"outputs": '}}
62 | {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
63 | ```
64 |
65 |
66 | This class accounts for this by concatenating bytes written via the 'write' function
67 | and then exposing a method which will return lines (ending with a '\n' character)
68 |     within the buffer via the 'scan_lines' function. It maintains the
69 |     last read position to ensure that previous bytes are not exposed again. It will
70 |     also save any pending lines that do not end with a '\n' to make sure truncations
71 |     are concatenated.
72 | """
73 |
74 | def __init__(self, stream: Any) -> None:
75 | """Line iterator initializer."""
76 | self.byte_iterator = iter(stream)
77 | self.buffer = io.BytesIO()
78 | self.read_pos = 0
79 |
80 | def __iter__(self) -> Any:
81 | """Self iterator."""
82 | return self
83 |
84 | def __next__(self) -> Any:
85 | """Next element from iterator."""
86 | while True:
87 | self.buffer.seek(self.read_pos)
88 | line = self.buffer.readline()
89 | if line and line[-1] == ord("\n"):
90 | self.read_pos += len(line)
91 | return line[:-1]
92 | try:
93 | chunk = next(self.byte_iterator)
94 | except StopIteration:
95 | if self.read_pos < self.buffer.getbuffer().nbytes:
96 | continue
97 | raise
98 | if "PayloadPart" not in chunk:
99 | logger.warning("Unknown event type=%s", chunk)
100 | continue
101 | self.buffer.seek(0, io.SEEK_END)
102 | self.buffer.write(chunk["PayloadPart"]["Bytes"])
103 |
104 |
105 | class SagemakerLLM(CustomLLM):
106 | """Sagemaker Inference Endpoint models.
107 |
108 | To use, you must supply the endpoint name from your deployed
109 | Sagemaker model & the region where it is deployed.
110 |
111 | To authenticate, the AWS client uses the following methods to
112 | automatically load credentials:
113 | https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
114 |
115 | If a specific credential profile should be used, you must pass
116 | the name of the profile from the ~/.aws/credentials file that is to be used.
117 |
118 | Make sure the credentials / roles used have the required policies to
119 | access the Sagemaker endpoint.
120 | See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
121 | """
122 |
123 |     endpoint_name: str = Field(description="Name of the deployed Sagemaker inference endpoint.")
124 | temperature: float = Field(description="The temperature to use for sampling.")
125 | max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
126 | context_window: int = Field(
127 | description="The maximum number of context tokens for the model."
128 | )
129 | messages_to_prompt: Any = Field(
130 | description="The function to convert messages to a prompt.", exclude=True
131 | )
132 | completion_to_prompt: Any = Field(
133 | description="The function to convert a completion to a prompt.", exclude=True
134 | )
135 | generate_kwargs: dict[str, Any] = Field(
136 | default_factory=dict, description="Kwargs used for generation."
137 | )
138 | model_kwargs: dict[str, Any] = Field(
139 | default_factory=dict, description="Kwargs used for model initialization."
140 | )
141 | verbose: bool = Field(description="Whether to print verbose output.")
142 |
143 | _boto_client: Any = boto3.client(
144 | "sagemaker-runtime",
145 | ) # TODO make it an optional field
146 |
147 | def __init__(
148 | self,
149 | endpoint_name: str | None = "",
150 | temperature: float = 0.1,
151 | max_new_tokens: int = 512, # to review defaults
152 | context_window: int = 2048, # to review defaults
153 | messages_to_prompt: Any = None,
154 | completion_to_prompt: Any = None,
155 | callback_manager: CallbackManager | None = None,
156 | generate_kwargs: dict[str, Any] | None = None,
157 | model_kwargs: dict[str, Any] | None = None,
158 | verbose: bool = True,
159 | ) -> None:
160 | """SagemakerLLM initializer."""
161 | model_kwargs = model_kwargs or {}
162 | model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
163 |
164 | messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
165 | completion_to_prompt = completion_to_prompt or generic_completion_to_prompt
166 |
167 | generate_kwargs = generate_kwargs or {}
168 | generate_kwargs.update(
169 | {"temperature": temperature, "max_tokens": max_new_tokens}
170 | )
171 |
172 | super().__init__(
173 | endpoint_name=endpoint_name,
174 | temperature=temperature,
175 | context_window=context_window,
176 | max_new_tokens=max_new_tokens,
177 | messages_to_prompt=messages_to_prompt,
178 | completion_to_prompt=completion_to_prompt,
179 | callback_manager=callback_manager,
180 | generate_kwargs=generate_kwargs,
181 | model_kwargs=model_kwargs,
182 | verbose=verbose,
183 | )
184 |
185 | @property
186 | def inference_params(self):
187 | # TODO expose the rest of params
188 | return {
189 | "do_sample": True,
190 | "top_p": 0.7,
191 | "temperature": self.temperature,
192 | "top_k": 50,
193 | "max_new_tokens": self.max_new_tokens,
194 | }
195 |
196 | @property
197 | def metadata(self) -> LLMMetadata:
198 | """Get LLM metadata."""
199 | return LLMMetadata(
200 | context_window=self.context_window,
201 | num_output=self.max_new_tokens,
202 | model_name="Sagemaker LLama 2",
203 | )
204 |
205 | @llm_completion_callback()
206 | def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
207 | self.generate_kwargs.update({"stream": False})
208 |
209 | is_formatted = kwargs.pop("formatted", False)
210 | if not is_formatted:
211 | prompt = self.completion_to_prompt(prompt)
212 |
213 | request_params = {
214 | "inputs": prompt,
215 | "stream": False,
216 | "parameters": self.inference_params,
217 | }
218 |
219 | resp = self._boto_client.invoke_endpoint(
220 | EndpointName=self.endpoint_name,
221 | Body=json.dumps(request_params),
222 | ContentType="application/json",
223 | )
224 |
225 | response_body = resp["Body"]
226 | response_str = response_body.read().decode("utf-8")
227 |         response_dict = json.loads(response_str)  # parse the JSON body instead of eval-ing it
228 |
229 | return CompletionResponse(
230 | text=response_dict[0]["generated_text"][len(prompt) :], raw=resp
231 | )
232 |
233 | @llm_completion_callback()
234 | def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
235 | def get_stream():
236 | text = ""
237 |
238 | request_params = {
239 | "inputs": prompt,
240 | "stream": True,
241 | "parameters": self.inference_params,
242 | }
243 | resp = self._boto_client.invoke_endpoint_with_response_stream(
244 | EndpointName=self.endpoint_name,
245 | Body=json.dumps(request_params),
246 | ContentType="application/json",
247 | )
248 |
249 | event_stream = resp["Body"]
250 | start_json = b"{"
251 | stop_token = "<|endoftext|>"
252 |
253 | for line in LineIterator(event_stream):
254 | if line != b"" and start_json in line:
255 | data = json.loads(line[line.find(start_json) :].decode("utf-8"))
256 | if data["token"]["text"] != stop_token:
257 | delta = data["token"]["text"]
258 | text += delta
259 | yield CompletionResponse(delta=delta, text=text, raw=data)
260 |
261 | return get_stream()
262 |
263 | @llm_chat_callback()
264 | def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
265 | prompt = self.messages_to_prompt(messages)
266 | completion_response = self.complete(prompt, formatted=True, **kwargs)
267 | return completion_response_to_chat_response(completion_response)
268 |
269 | @llm_chat_callback()
270 | def stream_chat(
271 | self, messages: Sequence[ChatMessage], **kwargs: Any
272 | ) -> ChatResponseGen:
273 | prompt = self.messages_to_prompt(messages)
274 | completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
275 | return stream_completion_response_to_chat_response(completion_response)
276 |
--------------------------------------------------------------------------------
/private_gpt/settings/settings.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from pydantic import BaseModel, Field
4 |
5 | from private_gpt.settings.settings_loader import load_active_settings
6 |
7 |
8 | class CorsSettings(BaseModel):
9 | """CORS configuration.
10 |
11 | For more details on the CORS configuration, see:
12 |     * https://fastapi.tiangolo.com/tutorial/cors/
13 |     * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
14 | """
15 |
16 | enabled: bool = Field(
17 |         description="Flag indicating if CORS headers are set or not. "
18 | "If set to True, the CORS headers will be set to allow all origins, methods and headers.",
19 | default=False,
20 | )
21 | allow_credentials: bool = Field(
22 | description="Indicate that cookies should be supported for cross-origin requests",
23 | default=False,
24 | )
25 | allow_origins: list[str] = Field(
26 | description="A list of origins that should be permitted to make cross-origin requests.",
27 | default=[],
28 | )
29 | allow_origin_regex: list[str] = Field(
30 | description="A regex string to match against origins that should be permitted to make cross-origin requests.",
31 | default=None,
32 | )
33 | allow_methods: list[str] = Field(
34 | description="A list of HTTP methods that should be allowed for cross-origin requests.",
35 | default=[
36 | "GET",
37 | ],
38 | )
39 | allow_headers: list[str] = Field(
40 | description="A list of HTTP request headers that should be supported for cross-origin requests.",
41 | default=[],
42 | )
43 |
44 |
45 | class AuthSettings(BaseModel):
46 | """Authentication configuration.
47 |
48 |     The implementation of the authentication strategy is provided in `private_gpt.server.utils.auth`.
49 | """
50 |
51 | enabled: bool = Field(
52 | description="Flag indicating if authentication is enabled or not.",
53 | default=False,
54 | )
55 | secret: str = Field(
56 | description="The secret to be used for authentication. "
57 | "It can be any non-blank string. For HTTP basic authentication, "
58 | "this value should be the whole 'Authorization' header that is expected"
59 | )
60 |
61 |
62 | class ServerSettings(BaseModel):
63 | env_name: str = Field(
64 | description="Name of the environment (prod, staging, local...)"
65 | )
66 | port: int = Field(description="Port of PrivateGPT FastAPI server, defaults to 8001")
67 | cors: CorsSettings = Field(
68 | description="CORS configuration", default=CorsSettings(enabled=False)
69 | )
70 | auth: AuthSettings = Field(
71 | description="Authentication configuration",
72 | default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"),
73 | )
74 |
75 |
76 | class DataSettings(BaseModel):
77 | local_data_folder: str = Field(
78 |         description="Path to local storage. "
79 | "It will be treated as an absolute path if it starts with /"
80 | )
81 |
82 |
83 | class LLMSettings(BaseModel):
84 | mode: Literal["local", "openai", "openailike", "sagemaker", "mock"]
85 | max_new_tokens: int = Field(
86 | 256,
87 |         description="The maximum number of tokens that the LLM is authorized to generate in one completion.",
88 | )
89 | context_window: int = Field(
90 | 3900,
91 | description="The maximum number of context tokens for the model.",
92 | )
93 | tokenizer: str = Field(
94 | None,
95 | description="The model id of a predefined tokenizer hosted inside a model repo on "
96 | "huggingface.co. Valid model ids can be located at the root-level, like "
97 | "`bert-base-uncased`, or namespaced under a user or organization name, "
98 | "like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer matching "
99 | "gpt-3.5-turbo LLM.",
100 | )
101 |
102 |
103 | class VectorstoreSettings(BaseModel):
104 | database: Literal["chroma", "qdrant"]
105 |
106 |
107 | class LocalSettings(BaseModel):
108 | llm_hf_repo_id: str
109 | llm_hf_model_file: str
110 | embedding_hf_model_name: str = Field(
111 | description="Name of the HuggingFace model to use for embeddings"
112 | )
113 | prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
114 | "llama2",
115 | description=(
116 | "The prompt style to use for the chat engine. "
117 | "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
118 |             "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
119 | "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
120 |             "If `mistral` - use the `mistral` prompt style. It should look like <s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST].\n"
121 | "`llama2` is the historic behaviour. `default` might work better with your custom models."
122 | ),
123 | )
124 |
125 |
126 | class EmbeddingSettings(BaseModel):
127 | mode: Literal["local", "openai", "sagemaker", "mock"]
128 | ingest_mode: Literal["simple", "batch", "parallel"] = Field(
129 | "simple",
130 | description=(
131 | "The ingest mode to use for the embedding engine:\n"
132 | "If `simple` - ingest files sequentially and one by one. It is the historic behaviour.\n"
133 | "If `batch` - if multiple files, parse all the files in parallel, "
134 | "and send them in batch to the embedding model.\n"
135 |             "If `parallel` - parse the files in parallel using multiple cores, and embed them in parallel.\n"
136 |             "`parallel` is the fastest mode for a local setup, as it parallelizes IO RW in the index.\n"
137 | "For modes that leverage parallelization, you can specify the number of "
138 | "workers to use with `count_workers`.\n"
139 | ),
140 | )
141 | count_workers: int = Field(
142 | 2,
143 | description=(
144 | "The number of workers to use for file ingestion.\n"
145 | "In `batch` mode, this is the number of workers used to parse the files.\n"
146 | "In `parallel` mode, this is the number of workers used to parse the files and embed them.\n"
147 | "This is only used if `ingest_mode` is not `simple`.\n"
148 | "Do not go too high with this number, as it might cause memory issues. (especially in `parallel` mode)\n"
149 |             "Do not set it higher than the number of threads available on your CPU."
150 | ),
151 | )
152 |
153 |
154 | class SagemakerSettings(BaseModel):
155 | llm_endpoint_name: str
156 | embedding_endpoint_name: str
157 |
158 |
159 | class OpenAISettings(BaseModel):
160 | api_base: str = Field(
161 | None,
162 | description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.",
163 | )
164 | api_key: str
165 | model: str = Field(
166 | "gpt-3.5-turbo",
167 | description="OpenAI Model to use. Example: 'gpt-4'.",
168 | )
169 |
170 |
171 | class UISettings(BaseModel):
172 | enabled: bool
173 | path: str
174 | default_chat_system_prompt: str = Field(
175 | None,
176 | description="The default system prompt to use for the chat mode.",
177 | )
178 | default_query_system_prompt: str = Field(
179 | None, description="The default system prompt to use for the query mode."
180 | )
181 |
182 |
183 | class QdrantSettings(BaseModel):
184 | location: str | None = Field(
185 | None,
186 | description=(
187 | "If `:memory:` - use in-memory Qdrant instance.\n"
188 | "If `str` - use it as a `url` parameter.\n"
189 | ),
190 | )
191 | url: str | None = Field(
192 | None,
193 | description=(
194 | "Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'."
195 | ),
196 | )
197 | port: int | None = Field(6333, description="Port of the REST API interface.")
198 | grpc_port: int | None = Field(6334, description="Port of the gRPC interface.")
199 | prefer_grpc: bool | None = Field(
200 | False,
201 | description="If `true` - use gRPC interface whenever possible in custom methods.",
202 | )
203 | https: bool | None = Field(
204 | None,
205 | description="If `true` - use HTTPS(SSL) protocol.",
206 | )
207 | api_key: str | None = Field(
208 | None,
209 | description="API key for authentication in Qdrant Cloud.",
210 | )
211 | prefix: str | None = Field(
212 | None,
213 | description=(
214 |             "Prefix to add to the REST URL path. "
215 | "Example: `service/v1` will result in "
216 | "'http://localhost:6333/service/v1/{qdrant-endpoint}' for REST API."
217 | ),
218 | )
219 | timeout: float | None = Field(
220 | None,
221 | description="Timeout for REST and gRPC API requests.",
222 | )
223 | host: str | None = Field(
224 | None,
225 | description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
226 | )
227 | path: str | None = Field(None, description="Persistence path for QdrantLocal.")
228 | force_disable_check_same_thread: bool | None = Field(
229 | True,
230 | description=(
231 |             "For QdrantLocal, force disable check_same_thread. Default: `True`. "
232 | "Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient."
233 | ),
234 | )
235 |
236 |
237 | class Settings(BaseModel):
238 | server: ServerSettings
239 | data: DataSettings
240 | ui: UISettings
241 | llm: LLMSettings
242 | embedding: EmbeddingSettings
243 | local: LocalSettings
244 | sagemaker: SagemakerSettings
245 | openai: OpenAISettings
246 | vectorstore: VectorstoreSettings
247 | qdrant: QdrantSettings | None = None
248 |
249 |
250 | """
251 | This is visible just for DI or testing purposes.
252 |
253 | Use dependency injection or `settings()` method instead.
254 | """
255 | unsafe_settings = load_active_settings()
256 |
257 | """
258 | This is visible just for DI or testing purposes.
259 |
260 | Use dependency injection or `settings()` method instead.
261 | """
262 | unsafe_typed_settings = Settings(**unsafe_settings)
263 |
264 |
265 | def settings() -> Settings:
266 | """Get the current loaded settings from the DI container.
267 |
268 | This method exists to keep compatibility with the existing code,
269 | that require global access to the settings.
270 |
271 | For regular components use dependency injection instead.
272 | """
273 | from private_gpt.di import global_injector
274 |
275 | return global_injector.get(Settings)
276 |
--------------------------------------------------------------------------------