├── .github ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── README.rst └── workflows │ ├── integration-build.yml │ ├── mirror_repo_to_gitlab.yml │ ├── publish_and_release.yml │ ├── sdk-build.yml │ └── unit-build.yml ├── .gitignore ├── LICENSE ├── README.rst ├── README_en.rst ├── deployment ├── embedding_service │ └── docker-compose.yml └── sdk-base │ ├── Dockerfile │ └── docker-compose.yml ├── docs ├── source │ └── contribution.rst ├── Описание_программы_ProtoLLM.docx ├── Программа_апробации_ProtoLLM.docx ├── Программа_и_методика испытаний_ProtoLLM.docx ├── Программа_и_методика испытаний_ProtoLLM.docx ├── Протокол приемочных испытаний.docx ├── Руководство_программиста_ProtoLLM.docx └── Текст_программы_ProtoLLM.docx ├── examples ├── __init__.py ├── agent_graph_builder_example.py ├── connector_creator_usage_example.py ├── llama31_usage_example.py ├── metrics_usage_examples.py ├── real_world │ ├── .DS_Store │ ├── __init__.py │ ├── chemical_multi_agent_system │ │ ├── __init__.py │ │ ├── chain.py │ │ ├── prompting.py │ │ └── tools.py │ ├── chemical_pipeline │ │ ├── __init__.py │ │ ├── llama31_chemical_example.py │ │ ├── queries_responses_chemical.xlsx │ │ ├── queries_responses_chemical_large.xlsx │ │ └── validate_tools.py │ └── urbanistics │ │ ├── __init__.py │ │ ├── data │ │ └── sample_data_city_rag.json │ │ ├── rag_example.ipynb │ │ ├── rag_example.py │ │ └── synthetic_rag_query_example.py └── universal_agents │ ├── in_and_translator_example.py │ ├── re_and_planner_example.py │ ├── summary_example.py │ └── supervisor_example.py ├── poetry.lock ├── protollm-synthetic ├── README.md ├── examples │ ├── aspect_summarisation_example.py │ ├── free_query_example.py │ ├── quiz_example.py │ ├── rag_example.py │ └── summarisation_example.py ├── poetry.lock ├── protollm_synthetic │ ├── __init__.py │ ├── synthetic_pipelines │ │ ├── __init__.py │ │ ├── chains.py │ │ ├── genetic_evolver.py │ │ ├── prompts.py │ │ └── schemes.py │ └── utils.py ├── pyproject.toml ├── tests │ └── test_summarization_chain.py └── tmp_data │ ├── data.json │ ├── gen_questions.csv │ ├── gen_questions.json │ ├── rag_generated.pickle │ ├── sample_data_city_rag.json │ ├── sample_data_city_rag_generated.json │ ├── sample_data_free_instruction.json │ ├── sample_data_rag.json │ └── tmp_sample_summarization_dataset.json ├── protollm ├── .DS_Store ├── __init__.py ├── agents │ ├── __init__.py │ ├── agent_prompts.py │ ├── agent_utils │ │ ├── parsers.py │ │ ├── pydantic_models.py │ │ └── states.py │ ├── builder.py │ ├── llama31_agents │ │ ├── __init__.py │ │ └── llama31_agent.py │ └── universal_agents.py ├── connectors │ ├── README.md │ ├── __init__.py │ ├── connector_creator.py │ ├── rest_server.py │ └── utils.py ├── definitions.py ├── ensembles_ma │ └── collect_results.py ├── metrics │ ├── __init__.py │ ├── deepeval_connector.py │ └── evaluation_metrics.py ├── rags │ ├── __init__.py │ ├── configs │ │ ├── chroma.env │ │ ├── docs_processing_config.yaml │ │ └── elastic.env │ ├── jobs.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── docs_processing │ │ │ ├── __init__.py │ │ │ ├── entities.py │ │ │ ├── exceptions.py │ │ │ ├── models.py │ │ │ └── utils.py │ │ └── etl_pipeline.py │ ├── rag_core │ │ ├── __init__.py │ │ ├── planner.py │ │ ├── reranker.py │ │ ├── retriever.py │ │ └── utils.py │ ├── settings │ │ ├── __init__.py │ │ ├── chroma_settings.py │ │ ├── es_settings.py │ │ └── pipeline_settings.py │ └── stores │ │ ├── __init__.py │ │ ├── chroma │ │ ├── __init__.py │ │ ├── chroma_loader.py │ │ └── utils.py │ 
│ └── elasticsearch │ │ ├── __init__.py │ │ ├── configs │ │ ├── index_mappings.json │ │ ├── index_settings.json │ │ ├── query_all_hits.json │ │ └── query_template.json │ │ ├── retrieval_strategies.py │ │ ├── settings.py │ │ └── utilities.py ├── raw_data_processing │ ├── __init__.py │ ├── docs_parsers │ │ ├── __init__.py │ │ ├── loaders │ │ │ ├── __init__.py │ │ │ ├── directory_loader.py │ │ │ ├── doc_loader.py │ │ │ ├── pdf_loader.py │ │ │ └── zip_loader.py │ │ ├── parsers │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── converting │ │ │ │ ├── __init__.py │ │ │ │ ├── converted_file.py │ │ │ │ ├── converting.py │ │ │ │ └── exceptions.py │ │ │ ├── entities.py │ │ │ ├── pdf │ │ │ │ ├── __init__.py │ │ │ │ ├── pdf_parser.py │ │ │ │ └── utilities.py │ │ │ ├── utilities.py │ │ │ └── word_doc │ │ │ │ ├── __init__.py │ │ │ │ ├── docx_parsing.py │ │ │ │ ├── docx_parsing_config.py │ │ │ │ ├── utilities.py │ │ │ │ ├── word_doc_parser.py │ │ │ │ └── xml │ │ │ │ ├── __init__.py │ │ │ │ ├── utilities.py │ │ │ │ ├── xml_processing.py │ │ │ │ ├── xml_tag.py │ │ │ │ └── xsl │ │ │ │ ├── mml2tex │ │ │ │ ├── cmarkup.xsl │ │ │ │ ├── entities.xsl │ │ │ │ ├── glayout.xsl │ │ │ │ ├── mmltex.xsl │ │ │ │ ├── scripts.xsl │ │ │ │ ├── tables.xsl │ │ │ │ └── tokens.xsl │ │ │ │ └── omml2mml │ │ │ │ └── OMML2MML.XSL │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── exceptions.py │ │ │ ├── logger.py │ │ │ └── utilities.py │ └── docs_transformers │ │ ├── __init__.py │ │ ├── chunk_merger.py │ │ ├── key_words_splitter.py │ │ ├── metadata_sentence_splitter.py │ │ ├── recursive_splitter.py │ │ ├── sentences_splitter.py │ │ └── utilities.py ├── templates │ ├── __init__.py │ ├── code_templates │ │ ├── __init__.py │ │ ├── agent_template.py │ │ ├── plugin_template.py │ │ └── rag_template.py │ └── prompt_templates │ │ ├── __init__.py │ │ ├── assistant_prompt_templates.py │ │ ├── metric_evalutation_prompts.py │ │ ├── qa_prompt_templates.py │ │ ├── rag_prompt_templates.py │ │ └── synthetic_data_prompts.py └── tools │ └── web_tools.py ├── protollm_tools ├── llm-agents-api │ ├── .env.example │ ├── README.md │ ├── __init__.py │ ├── docker-compose.yml │ ├── examples │ │ ├── admin-config.yml │ │ ├── main.py │ │ ├── main_local.py │ │ └── pipelines │ │ │ ├── __init__.py │ │ │ ├── mock_background_agent.py │ │ │ ├── rag_agent.py │ │ │ ├── rag_pipeline.py │ │ │ └── utils.py │ ├── poetry.lock │ ├── protollm_agents │ │ ├── __init__.py │ │ ├── api │ │ │ ├── __init__.py │ │ │ ├── rest_agents.py │ │ │ └── socket_agents.py │ │ ├── configs.py │ │ ├── entrypoint.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── models.py │ │ │ ├── requests.py │ │ │ ├── responses.py │ │ │ └── schemas.py │ │ ├── sdk │ │ │ ├── __init__.py │ │ │ ├── agents.py │ │ │ ├── base.py │ │ │ ├── context.py │ │ │ ├── events.py │ │ │ ├── models.py │ │ │ ├── pipelines │ │ │ │ ├── ensemble_router_pipeline.py │ │ │ │ └── router_pipeline.py │ │ │ └── vector_stores.py │ │ └── services │ │ │ ├── __init__.py │ │ │ ├── agents_manager.py │ │ │ ├── cache_client.py │ │ │ ├── db_client.py │ │ │ ├── exceptions.py │ │ │ ├── socket_connector.py │ │ │ └── storage.py │ ├── pyproject.toml │ └── tests │ │ ├── .env.test.example │ │ ├── config.test.yml │ │ ├── conftest.py │ │ ├── docker-compose.test.yml │ │ ├── docs │ │ ├── col_data_ed.json │ │ └── col_data_env.json │ │ └── test_api.py ├── llm-api │ ├── Dockerfile │ ├── README.md │ ├── __init__.py │ ├── docker-compose.yml │ ├── poetry.lock │ ├── protollm_api │ │ ├── __init__.py │ │ ├── backend │ │ │ ├── broker.py │ │ │ ├── endpoints.py │ │ │ └── main.py │ 
│ └── config.py │ ├── pyproject.toml │ ├── requirements.txt │ ├── tests │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── integration │ │ │ ├── test_local_RMQ.py │ │ │ ├── test_local_Redis.py │ │ │ └── test_with_llm.py │ │ └── unit │ │ │ ├── test_brocker.py │ │ │ └── test_endpoints.py │ └── unit_config.json ├── llm-worker │ ├── Dockerfile │ ├── README.md │ ├── __init__.py │ ├── docker-compose.yml │ ├── main.py │ ├── poetry.lock │ ├── protollm_worker │ │ ├── __init__.py │ │ ├── config.py │ │ ├── models │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── cpp_models.py │ │ │ ├── hf_models.py │ │ │ ├── open_api_llm.py │ │ │ └── vllm_models.py │ │ └── services │ │ │ ├── __init__.py │ │ │ └── broker.py │ ├── pyproject.toml │ └── requirements.txt └── sdk │ ├── Dockerfile │ ├── README.md │ ├── docker-compose.yml │ ├── examples │ └── celery.py │ ├── poetry.lock │ ├── protollm_sdk │ ├── __init__.py │ ├── celery │ │ ├── app.py │ │ ├── config.py │ │ ├── constants.py │ │ └── job.py │ ├── config.py │ ├── jobs │ │ ├── __init__.py │ │ ├── job.py │ │ ├── job_context.py │ │ ├── job_invoke.py │ │ ├── llm_api.py │ │ ├── outer_llm_api.py │ │ ├── result_storage.py │ │ ├── text_embedder.py │ │ ├── utility.py │ │ └── vector_db.py │ ├── models │ │ ├── __init__.py │ │ ├── errors.py │ │ └── job_context_models.py │ ├── object_interface │ │ ├── __init__.py │ │ ├── rabbit_mq_wrapper.py │ │ └── redis_wrapper.py │ └── utils │ │ ├── __init__.py │ │ ├── reddis.py │ │ └── singleton.py │ ├── pyproject.toml │ ├── requirements.txt │ └── tests │ ├── __init__.py │ └── protollm_sdk │ ├── __init__.py │ ├── celery │ └── test_app.py │ ├── job │ ├── __init__.py │ ├── test_job.py │ ├── test_job_api.py │ ├── test_job_invoke.py │ ├── test_llm_api.py │ ├── test_outer_llm_api.py │ ├── test_result_storage.py │ ├── test_text_embedder.py │ ├── test_utility.py │ └── test_vector_db.py │ ├── models │ └── test_job_context.py │ ├── object_interface │ ├── integration │ │ ├── __init__.py │ │ ├── test_rabbit_mq_wrapper.py │ │ └── test_redis_wrapper.py │ └── unit │ │ ├── __init__.py │ │ ├── test_rabbit_mq_wrapper.py │ │ └── test_redis_wrapper.py │ └── test_utils.py ├── pyproject.toml ├── requirements.txt └── tests ├── __init__.py ├── mock_chat_model.py ├── test_agent.py ├── test_connector.py ├── test_metrics.py └── validation ├── __init__.py ├── admin_config.yml ├── api_check.py ├── complex_check.py ├── complex_check_ens.py ├── ens_check.py ├── fail_check.py ├── rag_check.py ├── repro_check.py └── testing.md /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Is something not working as expected? 4 | title: "[Bug]: " 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | Hi! Thank you for taking the time to report a bug with ProtoLLM. 11 | 12 | Please consider asking your question at https://t.me/ProtoLLM_helpdesk if one or more of the following is applicable to your situation: 13 | 14 | - You are not sure if the issue is a bug in ProtoLLM. 15 | - The issue is caused by a third-party library. 16 | - This is just a generic usage question. 17 | 18 | Additionally, please note that this platform is meant for bugs in ProtoLLM only. 19 | Issues regarding dependencies and libraries should be reported in their respective repositories. 
20 | 21 | 22 | 23 | ## Expected Behavior 24 | 25 | 26 | ## Current Behavior 27 | 28 | 29 | ## Possible Solution 30 | 31 | 32 | 33 | ## Steps to Reproduce 34 | 35 | 36 | ## Context [OPTIONAL] 37 | 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Want us to add any features to ProtoLLM? 4 | title: 'enh: ' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | 16 | 17 | ## Summary 18 | 19 | 22 | 23 | ## Motivation 24 | 25 | 28 | 29 | ## Guide-level explanation 30 | 31 | 42 | 43 | ## Reference-level explanation 44 | 45 | 55 | 56 | ## Drawbacks 57 | 58 | 61 | 62 | ## Unresolved Questions 63 | 64 | -------------------------------------------------------------------------------- /.github/README.rst: -------------------------------------------------------------------------------- 1 | ../README_en.rst -------------------------------------------------------------------------------- /.github/workflows/integration-build.yml: -------------------------------------------------------------------------------- 1 | name: Integration Build 2 | 3 | on: 4 | schedule: 5 | - cron: '0 12 * * *' 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | workflow_dispatch: 11 | 12 | jobs: 13 | scheduled: 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 95 16 | strategy: 17 | matrix: 18 | python-version: [ '3.10' ] 19 | 20 | services: 21 | redis: 22 | image: redis:latest 23 | ports: 24 | - 6379:6379 25 | rabbitmq: 26 | image: rabbitmq:latest 27 | env: 28 | RABBITMQ_DEFAULT_USER: admin 29 | RABBITMQ_DEFAULT_PASS: admin 30 | ports: 31 | - 5672:5672 32 | - 15672:15672 33 | 34 | steps: 35 | - name: Checkout branch 36 | uses: actions/checkout@v2 37 | - name: Set up Python ${{ matrix.python-version }} 38 | uses: actions/setup-python@v2 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | - name: Install llm-api dependencies 42 | run: | 43 | python -m pip install --upgrade pip 44 | pip install pytest 45 | pip install pytest-asyncio 46 | pip install -r ./protollm_tools/llm-api/requirements.txt 47 | - name: Test llm-api with pytest 48 | run: | 49 | pytest -s ./protollm_tools/llm-api/tests/integration 50 | -------------------------------------------------------------------------------- /.github/workflows/mirror_repo_to_gitlab.yml: -------------------------------------------------------------------------------- 1 | name: Mirror repo to GitLab 2 | 3 | on: [push, pull_request, delete] 4 | 5 | jobs: 6 | call-nss-ops-mirror-workflow: 7 | uses: aimclub/open-source-ops/.github/workflows/mirror-repo.yml@master 8 | with: 9 | GITLAB_URL: 'https://gitlab.actcognitive.org/itmo-sai-code/ProtoLLM.git' 10 | secrets: 11 | GITLAB_USER: ${{ secrets.GITLAB_USER }} 12 | GITLAB_PASSWORD: ${{ secrets.GITLAB_PASSWORD }} -------------------------------------------------------------------------------- /.github/workflows/publish_and_release.yml: -------------------------------------------------------------------------------- 1 | name: Publish package and create release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v[0-9]+.[0-9]+.[0-9]+' 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | 12 | permissions: 13 | contents: write 14 | id-token: write 15 | 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | with: 23 | 
python-version: "3.10" 24 | 25 | - name: Install Poetry 26 | run: | 27 | curl -sSL https://install.python-poetry.org | python3 - 28 | echo "$HOME/.local/bin" >> $GITHUB_PATH 29 | 30 | - name: Install dependencies and build package 31 | run: | 32 | poetry install 33 | poetry build 34 | 35 | - name: Publish to PyPI 36 | env: 37 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} 38 | run: | 39 | poetry config pypi-token.pypi $POETRY_PYPI_TOKEN_PYPI 40 | poetry publish 41 | 42 | - name: Create GitHub Release 43 | env: 44 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 45 | run: >- 46 | gh release create 47 | "${GITHUB_REF_NAME}" 48 | --repo "${GITHUB_REPOSITORY}" 49 | --title "Release ${GITHUB_REF_NAME}" 50 | --generate-notes 51 | 52 | - name: Upload artifact signatures to GitHub Release 53 | env: 54 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 55 | run: >- 56 | gh release upload 57 | "$GITHUB_REF_NAME" dist/** 58 | --repo "$GITHUB_REPOSITORY" -------------------------------------------------------------------------------- /.github/workflows/sdk-build.yml: -------------------------------------------------------------------------------- 1 | name: SDK Build 2 | 3 | on: 4 | schedule: 5 | - cron: '0 12 * * *' 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | workflow_dispatch: 11 | 12 | jobs: 13 | scheduled: 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 95 16 | strategy: 17 | matrix: 18 | python-version: [ '3.10' ] 19 | 20 | services: 21 | redis: 22 | image: redis:latest 23 | ports: 24 | - 6379:6379 25 | rabbitmq: 26 | image: rabbitmq:latest 27 | ports: 28 | - 5672:5672 29 | - 15672:15672 30 | 31 | steps: 32 | - name: Checkout branch 33 | uses: actions/checkout@v2 34 | - name: Set up Python ${{ matrix.python-version }} 35 | uses: actions/setup-python@v2 36 | with: 37 | python-version: ${{ matrix.python-version }} 38 | - name: Install pytest dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | pip install pytest 42 | pip install pytest-asyncio 43 | - name: Change directory and install sdk dependencies 44 | run: | 45 | cd ./protollm_tools/sdk 46 | pip install -r requirements.txt 47 | pip install -e . 48 | - name: Test sdk with pytest 49 | run: | 50 | cd ./protollm_tools/sdk 51 | pytest -s ./tests -m ci 52 | -------------------------------------------------------------------------------- /.github/workflows/unit-build.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Unit Build 5 | 6 | on: 7 | push: 8 | branches: [ main, release ] 9 | pull_request: 10 | branches: [ main, release ] 11 | workflow_dispatch: # manually launch from GitHub actions 12 | 13 | jobs: 14 | build: 15 | 16 | runs-on: ubuntu-latest 17 | timeout-minutes: 15 18 | strategy: 19 | matrix: 20 | python-version: [ '3.10' ] 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install pytest 32 | pip install . 
33 | - name: Test with pytest 34 | run: | 35 | pytest -s tests 36 | - name: Install llm-api dependencies 37 | run: | 38 | pip install pytest-asyncio 39 | pip install -r ./protollm_tools/llm-api/requirements.txt 40 | - name: Test llm-api with pytest 41 | run: | 42 | pytest -s ./protollm_tools/llm-api/tests/unit 43 | 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Natural Systems Simulation Lab 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /deployment/embedding_service/docker-compose.yml: -------------------------------------------------------------------------------- 1 | networks: 2 | net: 3 | driver: bridge 4 | 5 | services: 6 | server: 7 | image: chromadb/chroma:latest 8 | environment: 9 | - IS_PERSISTENT=TRUE 10 | volumes: 11 | # Default configuration for persist_directory in chromadb/config.py 12 | # Currently it's located in "/chroma/chroma/" 13 | - chroma_data:/chroma/chroma/ 14 | ports: 15 | - 9941:8000 16 | networks: 17 | - net 18 | 19 | embedding_server: 20 | image: ${EMBEDDING_IMAGE:-ghcr.io/huggingface/text-embeddings-inference:cpu-0.3.0} #default image with CPU support 21 | # using ${EMBEDDING_IMAGE:-ghcr.io/huggingface/text-embeddings-inference:cuda-1.6} for gpu 22 | command: --model-id ${ST_MODEL:-intfloat/multilingual-e5-large} --revision ${ST_MODEL_REVISION:-main} # configure model and model revision paramters. 23 | # you can choose a embedding model by changing the varaibale ST_MODEL:-intfloat/multilingual-e5-large 24 | # where intfloat/multilingual-e5-large path to huggingface model 25 | ports: 26 | - 9942:80 27 | networks: 28 | - net 29 | volumes: 30 | - embedding_data:/data #by default we create a volume for the models. 
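    # Note (added comment): once this stack is up, Chroma is published on host port 9941 and the
    # text-embeddings-inference server on host port 9942; an illustrative Python client sketch
    # using these ports is included further below, after the docs file listing.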
31 | 32 | volumes: 33 | chroma_data: 34 | driver: local 35 | # device: # enter path to external folder 36 | embedding_data: 37 | driver: local 38 | # device: # enter path to external folder 39 | -------------------------------------------------------------------------------- /deployment/sdk-base/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | # Set the working directory for the container 4 | WORKDIR /app 5 | 6 | # Copy only the necessary files from the protollm_tools/sdk directory 7 | COPY protollm_tools/sdk . 8 | 9 | # Install required dependencies 10 | RUN pip install -r requirements.txt 11 | 12 | # Default command to start Celery worker 13 | CMD ["celery", "-A", "protollm_sdk.celery.app", "worker", "--loglevel=info"] 14 | -------------------------------------------------------------------------------- /docs/Описание_программы_ProtoLLM.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/docs/Описание_программы_ProtoLLM.docx -------------------------------------------------------------------------------- /docs/Программа_апробации_ProtoLLM.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/docs/Программа_апробации_ProtoLLM.docx -------------------------------------------------------------------------------- /docs/Программа_и_методика испытаний_ProtoLLM.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/docs/Программа_и_методика испытаний_ProtoLLM.docx -------------------------------------------------------------------------------- /docs/Программа_и_методика испытаний_ProtoLLM.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/docs/Программа_и_методика испытаний_ProtoLLM.docx -------------------------------------------------------------------------------- /docs/Протокол приемочных испытаний.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/docs/Протокол приемочных испытаний.docx -------------------------------------------------------------------------------- /docs/Руководство_программиста_ProtoLLM.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/docs/Руководство_программиста_ProtoLLM.docx -------------------------------------------------------------------------------- /docs/Текст_программы_ProtoLLM.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/docs/Текст_программы_ProtoLLM.docx -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/__init__.py 
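The following is an illustrative client sketch, not a file from the repository, for the services defined in deployment/embedding_service/docker-compose.yml. The host ports (9941 for Chroma, 9942 for the text-embeddings-inference server) come from that compose file; the collection name, document text, and query are placeholder assumptions, and the client classes mirror those already used in examples/real_world/urbanistics/rag_example.py.

```python
# Hypothetical usage sketch -- assumes the embedding_service compose stack is running locally.
import chromadb
from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings

# The text-embeddings-inference container is published on host port 9942 (see the compose file).
embedder = HuggingFaceHubEmbeddings(model="http://localhost:9942")
vector = embedder.embed_query("What does the strategy say about energy saving?")

# The Chroma server is published on host port 9941.
client = chromadb.HttpClient(host="localhost", port=9941, settings=chromadb.Settings())
collection = client.get_or_create_collection("demo")  # "demo" is a placeholder collection name
collection.add(ids=["doc-1"], embeddings=[vector], documents=["placeholder document text"])
print(collection.count())
```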
-------------------------------------------------------------------------------- /examples/metrics_usage_examples.py: -------------------------------------------------------------------------------- 1 | # You can use correctness_metric from ProtoLLM or import desired directly from deepeval. 2 | # The correctness metric is demonstrative, but has been shown to be good for determining the correctness of an answer. 3 | # You can define a new one, with a different criterion, if necessary. 4 | # 5 | # In the second case, you may also need to import a connector object for deepeval metrics to work. This can be done 6 | # as follows: 7 | # 8 | # `from protollm.metrics import model_for_metrics` 9 | # 10 | # Also make sure that you set the model URL and model_name in the same format as for a normal LLM connector 11 | # (URL;model_name). 12 | # 13 | # Detailed documentation on metrics is available at the following URL: 14 | # https://docs.confident-ai.com/docs/metrics-introduction 15 | 16 | import logging 17 | 18 | from deepeval.metrics import AnswerRelevancyMetric, ToolCorrectnessMetric 19 | from deepeval.test_case import LLMTestCase, ToolCall 20 | 21 | from protollm.metrics import correctness_metric, model_for_metrics 22 | 23 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 24 | 25 | answer_relevancy = AnswerRelevancyMetric(model=model_for_metrics, async_mode=False) 26 | tool_correctness = ToolCorrectnessMetric() 27 | 28 | if __name__ == "__main__": 29 | # ===================================metrics using LLM============================================= 30 | # Create test case for metric 31 | test_case = LLMTestCase( 32 | input="What if these shoes don't fit?", 33 | actual_output="We offer a 30-day full refund at no extra cost.", 34 | expected_output="You are eligible for a 30 day full refund at no extra cost." 
35 | ) 36 | 37 | answer_relevancy.measure(test_case) # Evaluate metric 38 | logging.info(f"Answer relevancy score {answer_relevancy.score}") 39 | logging.info(f"Answer relevancy reason: {answer_relevancy.reason}") 40 | 41 | correctness_metric.measure(test_case) # Evaluate metric 42 | logging.info(f"Correctness score {correctness_metric.score}") 43 | logging.info(f"Correctness reason: {correctness_metric.reason}") 44 | 45 | # ===================================metrics not using LLM========================================= 46 | # Create test case for metric 47 | test_case = LLMTestCase( 48 | input="What if these shoes don't fit?", 49 | actual_output="We offer a 30-day full refund at no extra cost.", 50 | # Replace this with the tools that was actually used by your LLM agent 51 | tools_called=[ToolCall(name="WebSearch", input_parameters={}), ToolCall(name="ToolQuery", input_parameters={})], 52 | expected_tools=[ToolCall(name="WebSearch", input_parameters={})], 53 | ) 54 | 55 | tool_correctness.measure(test_case) 56 | logging.info(f"Tool correctness score {tool_correctness.score}") 57 | logging.info(f"Tool correctness reason: {tool_correctness.reason}") 58 | -------------------------------------------------------------------------------- /examples/real_world/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/real_world/.DS_Store -------------------------------------------------------------------------------- /examples/real_world/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/real_world/__init__.py -------------------------------------------------------------------------------- /examples/real_world/chemical_multi_agent_system/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/real_world/chemical_multi_agent_system/__init__.py -------------------------------------------------------------------------------- /examples/real_world/chemical_multi_agent_system/prompting.py: -------------------------------------------------------------------------------- 1 | system_prompt_conductor = ''' 2 | Respond to the human as helpfully and accurately as possible. You have access to the following tools: 3 | 4 | {tools} 5 | 6 | Use a JSON blob to specify a tool by providing an "action" key (tool name) and an "action_input" key (tool input). 7 | 8 | Valid "action" values: "Final Answer" or {tool_names} 9 | 10 | Provide only ONE action per JSON blob, as shown: 11 | 12 | {{ "action": $TOOL_NAME, "action_input": $INPUT }} 13 | 14 | Follow this format: 15 | 16 | Question: input question to answer 17 | Thought: consider previous and subsequent steps 18 | Action: $JSON_BLOB 19 | 20 | Observation: action result 21 | ... (repeat Thought/Action/Observation N times) 22 | Thought: I know what to respond 23 | Action: {{ "action": "Final Answer", "action_input": "Final response to human" }} 24 | 25 | 26 | Begin! Reminder to ALWAYS respond with a valid JSON blob of a single action. Use tools if necessary. 27 | Respond directly if appropriate. Format is Action:```$JSON_BLOB``` then Observation 28 | In the "Final Answer" you must ALWAYS display all generated molecules!!! 
29 | For example answer must consist table (!): 30 | | Molecules | QED | Synthetic Accessibility | PAINS | SureChEMBL | Glaxo | Brenk | BBB | IC50 | 31 | \n| --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| Fc1ccc2c(c1)CCc1ccccc1-2 | 0.6064732613170888 32 | | 1.721973678244476 | 0 | 0 | 0 | 0 | 1 | 0 |\n| O=C(Nc1ccc(C(=O)c2ccccc2)cc1)c1ccc(F)cc1 | 0.728441789442482 33 | | 1.4782662488060723 | 0 | 0 | 0 | 0 | 1 | 1 |\n| O=C(Nc1ccccc1)c1ccc(NS(=O)(=O)c2ccc3c(c2)CCC3=O)cc1 | 34 | 0.6727786031171711 | 1.9616124655434675 | 0 | 0 | 0 | 0 | 0 | 0 |\n| Cc1ccc(C)c(-n2c(=O)c3ccccc3n(Cc3ccccc3)c2=O)c1 35 | | 0.5601042919484651 | 1.920664623176684 | 0 | 0 | 0 | 0 | 1 | 1 |\n| Cc1ccc2c(c1)N(C(=O)CN1C(=O)NC3(CCCc4ccccc43)C1=O)CC2 36 | | 0.8031696199670261 | 3.3073398307371438 | 0 | 0 | 0 | 1 | 1 | 0 |" 37 | ''' 38 | system_prompt_decomposer = \ 39 | """ 40 | Respond to the human as helpfully and accurately as possible. You must decompose the input questions into tasks. 41 | 42 | Use a JSON to specify a tool by providing an "action" key (tool name) and an "action_input" key (tool input). 43 | Valid "action" values: "Final Answer". Action is always == "Final Answer". 44 | Valid number of tasks: 1-5. 45 | 46 | Follow this format: 47 | Question: input questions to answer 48 | { "action": "Final Answer", "action_input": "[task1, task2, task3...]" } 49 | 50 | Example: 51 | Question: Generate molecule for Alzheimer. Generate 3 molecules for Parkinson 52 | { "action": "Final Answer", "action_input": "['Generate molecule for Alzheimer', 'Generate 3 molecules for Parkinson']" } 53 | 54 | Begin! Reminder to ALWAYS respond with a valid JSON of a single action. 55 | In the "Final Answer" you must ALWAYS display in list! 56 | """ 57 | 58 | human_prompt = '''{input} 59 | {agent_scratchpad} 60 | (Reminder to respond in a JSON blob no matter what)''' -------------------------------------------------------------------------------- /examples/real_world/chemical_pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/real_world/chemical_pipeline/__init__.py -------------------------------------------------------------------------------- /examples/real_world/chemical_pipeline/queries_responses_chemical.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/real_world/chemical_pipeline/queries_responses_chemical.xlsx -------------------------------------------------------------------------------- /examples/real_world/chemical_pipeline/queries_responses_chemical_large.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/real_world/chemical_pipeline/queries_responses_chemical_large.xlsx -------------------------------------------------------------------------------- /examples/real_world/urbanistics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/examples/real_world/urbanistics/__init__.py -------------------------------------------------------------------------------- /examples/real_world/urbanistics/rag_example.py: 
-------------------------------------------------------------------------------- 1 | import chromadb 2 | import os 3 | import uuid 4 | from dotenv import load_dotenv 5 | from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings 6 | 7 | from protollm_sdk.models.job_context_models import PromptModel 8 | from protollm_sdk.jobs.outer_llm_api import OuterLLMAPI 9 | from protollm.rags.rag_core.retriever import DocRetriever, DocsSearcherModels 10 | 11 | from protollm.definitions import CONFIG_PATH 12 | 13 | 14 | def init_chroma_client(): 15 | host, port = os.environ.get("CHROMA_DEFAULT_SETTINGS").split(':') 16 | return chromadb.HttpClient( 17 | host=host, 18 | port=int(port), 19 | settings=chromadb.Settings(), 20 | ) 21 | 22 | 23 | def proto_view( 24 | query: str, 25 | collection: str, 26 | k: int = 1, 27 | embedding_function: HuggingFaceHubEmbeddings = None, 28 | ) -> list: 29 | # Returns k chunks that are closest to the query 30 | embedding_host = os.environ.get("EMBEDDING_HOST") 31 | embedding_function = HuggingFaceHubEmbeddings(model=embedding_host) 32 | chroma_client = init_chroma_client() 33 | 34 | docs_searcher_models = DocsSearcherModels(embedding_model=embedding_function, chroma_client=chroma_client) 35 | retriever = DocRetriever(top_k=k, 36 | docs_searcher_models=docs_searcher_models, 37 | ) 38 | 39 | return retriever.retrieve_top(collection_name=collection, query=query) 40 | 41 | 42 | def outer_llm(question: str, 43 | meta: dict, 44 | key: str): 45 | llmapi = OuterLLMAPI(key) 46 | llm_request = PromptModel(job_id=str(uuid.uuid4()), 47 | meta=meta, 48 | content=question) 49 | res = llmapi.inference(llm_request) 50 | return res.content 51 | 52 | 53 | if __name__ == "__main__": 54 | load_dotenv(CONFIG_PATH) 55 | 56 | # Настройки БЯМ 57 | meta = {"temperature": 0.05, 58 | "tokens_limit": 4096, 59 | "stop_words": None} 60 | key = os.environ.get("VSE_GPT_KEY") 61 | 62 | 63 | # Название коллекции в БД 64 | collection_name = "strategy-spb" 65 | 66 | # Вопрос 67 | question = 'Какие задачи Стратегия ставит в области энергосбережения?' 
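    # Illustrative note (added, kept commented out): proto_view() also accepts a k argument,
    # so several chunks can be retrieved and joined into a larger context when one is not enough:
    #   chunks = proto_view(question, collection_name, k=3)
    #   context = ' '.join(chunk.page_content for chunk in chunks)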
68 | 69 | # Извлечение контекста из БД 70 | context = proto_view(question, collection_name) 71 | context = f'Вопрос: {question} Контекст: {context[0].page_content}' 72 | 73 | # Получение ответа от БЯМ 74 | print(f'Ответ VseGPT LLM: \n {outer_llm(context, meta, key)}') 75 | -------------------------------------------------------------------------------- /examples/real_world/urbanistics/synthetic_rag_query_example.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from samplefactory.synthetic_pipelines.chains import RAGChain 4 | from samplefactory.utils import Dataset, VLLMChatOpenAI 5 | import asyncio 6 | 7 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") 8 | logger = logging.getLogger(__name__) 9 | 10 | # path = 'tmp_data/sample_data_city_rag.json' 11 | path = 'tmp_data/sample_data_rag_spb.json' 12 | dataset = Dataset(data_col='content', path=path) 13 | 14 | qwen_large_api_key = os.environ.get("QWEN_OPENAI_API_KEY") 15 | qwen_large_api_base = os.environ.get("QWEN_OPENAI_API_BASE") 16 | 17 | logger.info("Initializing LLM connection") 18 | 19 | llm=VLLMChatOpenAI( 20 | api_key=qwen_large_api_key, 21 | base_url=qwen_large_api_base, 22 | model="/model", 23 | max_tokens=2048, 24 | ) 25 | 26 | rag_chain = RAGChain(llm=llm) 27 | 28 | logger.info("Starting generating") 29 | asyncio.run(rag_chain.run(dataset, 30 | n_examples=5)) 31 | 32 | logger.info("Saving results") 33 | path = 'tmp_data/sample_data_city_rag_generated.json' 34 | 35 | # An alternative way to save data 36 | # rag_chain.save_chain_output('tmp_data/sample_data_city_rag_generated.json') 37 | 38 | df = rag_chain.data.explode('generated') 39 | df['question'] = df['generated'].apply(lambda x: x['question']) 40 | df['answer'] = df['generated'].apply(lambda x: x['answer']) 41 | df = df[['content', 'question', 'answer']] 42 | 43 | logger.info(f"Writing result to {path}") 44 | df.to_json(path, orient="records") 45 | 46 | logger.info("Generation successfully finished") 47 | -------------------------------------------------------------------------------- /examples/universal_agents/in_and_translator_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of launching the Agent-translator. 3 | A phrase in Italian is fed to the input (it can be fed in any language). 4 | Phrase will be translated into English (for execution within the intended system), and when 5 | passed to the second agent, a response will be returned in the original language. 6 | """ 7 | 8 | from protollm.agents.universal_agents import in_translator_node, re_translator_node 9 | from protollm.connectors import create_llm_connector 10 | 11 | if __name__ == "__main__": 12 | state = {"input": "Ciao! Come stai?"} 13 | 14 | model = create_llm_connector( 15 | "https://api.vsegpt.ru/v1;meta-llama/llama-3.1-70b-instruct" 16 | ) 17 | conf = {"llm": model, "max_retries": 1} 18 | res = in_translator_node(state, conf) 19 | print("Language is: ", res["language"]) 20 | print("Translate: ", res["translation"]) 21 | 22 | res["response"] = "Made up answer..." 
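    # "Made up answer..." stands in for the answer a downstream agent would normally produce;
    # re_translator_node then translates it back into the language detected by in_translator_node
    # (Italian in this example).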
23 | total_res = re_translator_node(res, conf) 24 | 25 | print("Total answer: ") 26 | print(total_res["response"].content) 27 | -------------------------------------------------------------------------------- /examples/universal_agents/re_and_planner_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of launching the Agent-Planner and RePlanner 3 | """ 4 | 5 | from protollm.agents.universal_agents import plan_node, replan_node 6 | from protollm.connectors import create_llm_connector 7 | 8 | if __name__ == "__main__": 9 | tools_description = """name2smiles(mol: typing.Annotated[str, 'Name of a molecule']) 10 | - Use this to convert molecule name to smiles format. 11 | Only use for organic molecules\n 12 | 13 | smiles2name(smiles: typing.Annotated[str, 'SMILES of a molecule']) 14 | - Use this to convert SMILES to IUPAC name of given molecule\n 15 | 16 | smiles2prop(smiles: typing.Annotated[str, 'SMILES of a molecule'], iupac: Optional[str] = None) 17 | - Use this to calculate all available properties of given molecule. 18 | Only use for organic molecules\n 19 | params:\n 20 | smiles: str, smiles of a molecule,\n 21 | iupac: optional, default is None, iupac of molecule\n 22 | 23 | generate_molecule(params: typing.Annotated[str, 'Description of target molecule'], 24 | config: langchain_core.runnables.config.RunnableConfig) 25 | - Use this to generate a molecule with given description. 26 | Returns smiles. Only use for organic molecules""" 27 | 28 | response = """ 29 | Generate molecule with IC50 less then 5. 30 | """ 31 | state = {"input": response, "language": "English"} 32 | 33 | model = create_llm_connector( 34 | "https://api.vsegpt.ru/v1;meta-llama/llama-3.1-70b-instruct" 35 | ) 36 | conf = {"llm": model, "max_retries": 1, "tools_descp": tools_description} 37 | res = plan_node(state, conf) 38 | 39 | fake_plan = [ 40 | "1) Convert [Ca+2].[Cl-].[Cl-] to IUPAC name using smiles2name function with the given SMILES as input." 41 | ] 42 | state["plan"] = fake_plan 43 | state["past_steps"] = [(fake_plan[0], "Calcium chloride")] 44 | replan_res = replan_node(state, conf) 45 | 46 | print("Planner answer: ") 47 | print(res["plan"]) 48 | 49 | print("\nReplanner answer with fake plan: ") 50 | print(replan_res["plan"]) 51 | -------------------------------------------------------------------------------- /examples/universal_agents/summary_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of launching the Agent-Summarizer 3 | """ 4 | 5 | from protollm.agents.universal_agents import summary_node 6 | from protollm.connectors import create_llm_connector 7 | 8 | if __name__ == "__main__": 9 | query = "Make a story about a boy drawing. Recipe for duck in the oven" 10 | response = "I found answers to both questions." 11 | state = { 12 | "input": query, 13 | "response": response, 14 | # past_steps consist of List[Tuple('task', 'answer')] - task is name of func 15 | "past_steps": [("agent_storyteller", """The child sits at the table, holding a crayon tightly in their hand. 16 | They look at the blank sheet of paper, deep in thought, deciding what to 17 | draw. Slowly, they begin making small, careful lines, their concentration 18 | intense. As they gain confidence, the lines become bolder, forming shapes 19 | and patterns. The child adds colors, using bright reds, blues, and yellows, 20 | filling the spaces with excitement. 
They carefully stay within the lines, 21 | but sometimes a splash of color goes outside, adding a playful touch. 22 | Every so often, they stop to admire their work, grinning with pride. 23 | The drawing starts to come together, showing a simple scene—maybe a sun, 24 | a house, or a tree. With each stroke, the child’s imagination comes to 25 | life on the paper. Finally, they put down the crayon, pleased with their 26 | masterpiece, and smile at the colorful creation in front of them. """), 27 | ("web_search", """To cook duck in the oven, first rinse and pat dry the duck, seasoning 28 | it with salt, pepper, and your favorite spices. Stuff the cavity with a 29 | couple of garlic cloves and apples, then preheat the oven to 180°C (350°F). 30 | Place the duck in a roasting pan and roast for 1.5 to 2 hours, basting 31 | it occasionally with the juices. About 15 minutes before the end, brush 32 | the duck with honey or soy sauce for a golden, crispy skin, and serve 33 | it hot with a side of potatoes or vegetables.""")], 34 | } 35 | 36 | model = create_llm_connector( 37 | "https://api.vsegpt.ru/v1;meta-llama/llama-3.1-70b-instruct" 38 | ) 39 | conf = {"llm": model, "max_retries": 1} 40 | res = summary_node(state, conf) 41 | 42 | print("Total answer: ") 43 | print(res["response"]) 44 | -------------------------------------------------------------------------------- /examples/universal_agents/supervisor_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of launching the Supervisor agent 3 | """ 4 | 5 | from langchain_community.tools.tavily_search import TavilySearchResults 6 | from protollm.agents.universal_agents import supervisor_node 7 | 8 | from protollm.connectors import create_llm_connector 9 | 10 | if __name__ == "__main__": 11 | response = """ 12 | Find the molecule was discovered in 2008 by American scientists. 13 | """ 14 | state = { 15 | "input": response, 16 | "language": "English", 17 | "plan": [ 18 | "Use web_search to find the molecule was discovered in 2008 by American scientists." 19 | ], 20 | } 21 | 22 | model = create_llm_connector( 23 | "https://api.vsegpt.ru/v1;meta-llama/llama-3.1-70b-instruct" 24 | ) 25 | conf = { 26 | "llm": model, 27 | "max_retries": 1, 28 | "scenario_agents": ["web_search"], 29 | "tools_for_agents": {"web_serach": [TavilySearchResults]}, 30 | } 31 | res = supervisor_node(state, conf) 32 | print("Next node will be: ", res.update["next"]) 33 | -------------------------------------------------------------------------------- /protollm-synthetic/README.md: -------------------------------------------------------------------------------- 1 | # Protollm-synthetic 2 | 3 | This repository contains a set of tools for synthetic observation generation for fine-tuning LLM-based pipelines. 4 | 5 | Available pipelines: 6 | - Summarisation 7 | - RAG 8 | - Aspect Summarisation 9 | - Quiz generation 10 | - Free-form generation 11 | - Augmentation of existing dataset 12 | 13 | ## Installation 14 | 15 | ```bash 16 | poetry install 17 | ``` 18 | 19 | set OPENAI_API_KEY and OPENAI_API_BASE in your environment (it can be openai model or local model that is set with vllm openai server), e.g. 20 | ```bash 21 | export OPENAI_API_KEY=your_api_key 22 | export OPENAI_API_BASE=your_api_base 23 | ``` 24 | 25 | ## Usage 26 | 27 | Chains are the main building blocks of the library. They are designed to be used in a pipeline. 
28 | 29 | ```python 30 | from protollm_synthetic.synthetic_pipelines.chains import SummarisationChain 31 | ``` 32 | 33 | To run a chain, you need to provide a dataset and a chain. 34 | 35 | ```python 36 | dataset = Dataset(path="data/sample_summarization_dataset.csv", labels=False) 37 | summarisation_chain = SummarisationChain(llm=llm) 38 | summaries = summarisation_chain.run(dataset, n_examples=100) 39 | ``` -------------------------------------------------------------------------------- /protollm-synthetic/examples/aspect_summarisation_example.py: -------------------------------------------------------------------------------- 1 | from protollm_synthetic.synthetic_pipelines.chains import AspectSummarisationChain 2 | from protollm_synthetic.utils import Dataset, VLLMChatOpenAI 3 | import pandas as pd 4 | import os 5 | import asyncio 6 | 7 | texts = [ 8 | "The quick brown fox jumps over the lazy dog. The fox is a cunning animal as some politicians.", 9 | "Artificial intelligence is transforming the world. AI is a powerful technology and influence much on the world politics.", 10 | "Python is a popular programming language." 11 | ] 12 | 13 | df = pd.DataFrame(texts, columns=["content"]) 14 | df.to_json("tmp_data/tmp_sample_summarization_dataset.json", index=False) 15 | 16 | dataset = Dataset(path="tmp_data/tmp_sample_summarization_dataset.json") 17 | # Expected output: a list of summaries 18 | expected_summaries = [ 19 | "The fox jumps over the dog.", 20 | "AI is changing the world.", 21 | "Python is a popular language." 22 | ] 23 | 24 | aspect = "politics" 25 | 26 | qwen_large_api_key = os.environ.get("OPENAI_API_KEY") 27 | qwen_large_api_base = os.environ.get("OPENAI_API_BASE") 28 | 29 | llm=VLLMChatOpenAI( 30 | api_key=qwen_large_api_key, 31 | base_url=qwen_large_api_base, 32 | model="/model", 33 | max_tokens=2048, 34 | # max_concurrency=10 35 | ) 36 | 37 | aspect_summarisation_chain = AspectSummarisationChain(llm=llm) 38 | actual_summaries = asyncio.run(aspect_summarisation_chain.run(dataset, aspect=aspect, n_examples=3)) 39 | print(actual_summaries) -------------------------------------------------------------------------------- /protollm-synthetic/examples/free_query_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | from protollm_synthetic.synthetic_pipelines.chains import FreeQueryChain 3 | from protollm_synthetic.utils import Dataset, VLLMChatOpenAI 4 | import json 5 | import asyncio 6 | 7 | # создаем небольшой датасет с задачей извлечения событий из текста 8 | texts = [ 9 | "Сегодня хорошая погода в Москве", 10 | "Завтра в 11 будет важный созвон с заказчиком из Италии", 11 | "Я записан в бассейн 30.12.2024 в 12 и 31.12.2024 в 13. 
Удивляюсь, как мне это удалось", 12 | "Через неделю будет вечеринка в клубе 'Золотой' в 23:00", 13 | "Сегодня в 10:00 я занялся спортом", 14 | "19.01.2025 я записан к стоматологу" 15 | ] 16 | 17 | solutions = [ 18 | "[{'date': '22.12.2024', 'event': 'хорошая погода в Москве'}]", 19 | "[{'date': '23.12.2024', 'time': '11:00', 'event': 'созвон с заказчиком из Италии'}]", 20 | "[{'date': '30.12.2024', 'time': '12:00', 'event': 'запись в бассейн'}, {'date': '31.12.2024', 'time': '13:00', 'event': 'запись в бассейн'}]", 21 | None, 22 | None, 23 | None, 24 | ] 25 | 26 | data_dict = {'content': texts, 'solution': solutions} 27 | 28 | with open('tmp_data/sample_data_free_instruction.json', 'w', encoding='utf-8') as file: 29 | json.dump(data_dict, file, ensure_ascii=False) 30 | 31 | dataset = Dataset(data_col='content', labels_col='solution', path='tmp_data/sample_data_free_instruction.json') 32 | 33 | qwen_large_api_key = os.environ.get("OPENAI_API_KEY") 34 | qwen_large_api_base = os.environ.get("OPENAI_API_BASE") 35 | 36 | llm=VLLMChatOpenAI( 37 | api_key=qwen_large_api_key, 38 | base_url=qwen_large_api_base, 39 | model="/model", 40 | max_tokens=2048, 41 | # max_concurrency=10 42 | ) 43 | 44 | free_query_chain = FreeQueryChain(llm=llm) 45 | asyncio.run(free_query_chain.run(dataset, n_examples=3)) 46 | 47 | print(free_query_chain.data) 48 | -------------------------------------------------------------------------------- /protollm-synthetic/examples/summarisation_example.py: -------------------------------------------------------------------------------- 1 | from protollm_synthetic.synthetic_pipelines.chains import SummarisationChain 2 | from protollm_synthetic.utils import Dataset, VLLMChatOpenAI 3 | import pandas as pd 4 | import os 5 | import asyncio 6 | 7 | texts = [ 8 | "The quick brown fox jumps over the lazy dog.", 9 | "Artificial intelligence is transforming the world.", 10 | "Python is a popular programming language." 11 | ] 12 | 13 | df = pd.DataFrame(texts, columns=["content"]) 14 | df.to_json("tmp_data/tmp_sample_summarization_dataset.json", index=False) 15 | 16 | dataset = Dataset(path="tmp_data/tmp_sample_summarization_dataset.json") 17 | # Expected output: a list of summaries 18 | expected_summaries = [ 19 | "The fox jumps over the dog.", 20 | "AI is changing the world.", 21 | "Python is a popular language." 
22 | ] 23 | 24 | qwen_large_api_key = os.environ.get("OPENAI_API_KEY") 25 | qwen_large_api_base = os.environ.get("OPENAI_API_BASE") 26 | 27 | llm=VLLMChatOpenAI( 28 | api_key=qwen_large_api_key, 29 | base_url=qwen_large_api_base, 30 | model="/model", 31 | max_tokens=2048, 32 | # max_concurrency=10 33 | ) 34 | 35 | summarisation_chain = SummarisationChain(llm=llm) 36 | actual_summaries = asyncio.run(summarisation_chain.run(dataset, n_examples=3)) 37 | print(actual_summaries) -------------------------------------------------------------------------------- /protollm-synthetic/protollm_synthetic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm-synthetic/protollm_synthetic/__init__.py -------------------------------------------------------------------------------- /protollm-synthetic/protollm_synthetic/synthetic_pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm-synthetic/protollm_synthetic/synthetic_pipelines/__init__.py -------------------------------------------------------------------------------- /protollm-synthetic/protollm_synthetic/synthetic_pipelines/genetic_evolver.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List 3 | 4 | class GeneticEvolver: 5 | def __init__(self, initial_population: List[str], 6 | generations: int = 10, mutation_rate: float = 0.1): 7 | self.initial_population = initial_population 8 | self.generations = generations 9 | self.mutation_rate = mutation_rate 10 | 11 | def evolve(self): 12 | population = self.initial_population 13 | for generation in range(self.generations): 14 | fitness_scores = [self.evaluate_prompt(prompt) for prompt in population] 15 | selected_parents = self.select_parents(population, fitness_scores) 16 | new_population = [] 17 | while len(new_population) < len(population): 18 | parent1, parent2 = random.sample(selected_parents, 2) 19 | child = self.crossover(parent1, parent2) 20 | if random.random() < self.mutation_rate: 21 | child = self.mutate_prompt(child) 22 | new_population.append(child) 23 | 24 | def crossover(self, parent1, parent2): 25 | # Simple crossover: concatenate the first half of parent1 with the second half of parent2 26 | return parent1[:len(parent1)//2] + parent2[len(parent2)//2:] 27 | 28 | def mutate_prompt(self, prompt): 29 | # Example mutation: add a random word 30 | words = ["optimize", "enhance", "improve", "boost"] 31 | return prompt + " " + random.choice(words) 32 | 33 | # Define a function to evaluate the prompt using an LLM-based evaluation function 34 | def evaluate_prompt(self, prompt): 35 | # Placeholder for LLM-based evaluation logic 36 | # Return a score representing the prompt's success 37 | return random.uniform(0, 1) # Example: random score for demonstration 38 | 39 | def select_parents(self, population, fitness_scores): 40 | # Select top 50% based on fitness 41 | sorted_population = [x for _, x in sorted(zip(fitness_scores, population), reverse=True)] 42 | return sorted_population[:len(population)//2] 43 | 44 | 45 | -------------------------------------------------------------------------------- /protollm-synthetic/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = 
"protollm-synthetic" 3 | version = "0.1.0" 4 | description = "Sample generation with LLMs" 5 | authors = ["Your Name "] 6 | license = "BSD 3-Clause License" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.12" 11 | langchain = "^0.3.12" 12 | langchain-core = "^0.3.25" 13 | langchain-openai = "^0.2.12" 14 | langchain-community = "^0.3.12" 15 | pandas = "^2.2.3" 16 | 17 | 18 | [build-system] 19 | requires = ["poetry-core"] 20 | build-backend = "poetry.core.masonry.api" 21 | -------------------------------------------------------------------------------- /protollm-synthetic/tests/test_summarization_chain.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from protollm_synthetic.synthetic_pipelines.chains import SummarisationChain 4 | from protollm_synthetic.utils import VLLMChatOpenAI, Dataset 5 | import pandas as pd 6 | import asyncio 7 | 8 | class TestSummarizationChain(unittest.TestCase): 9 | def test_summarization_chain_on_list_of_texts(self): 10 | # Sample input: a list of texts 11 | texts = [ 12 | "The quick brown fox jumps over the lazy dog.", 13 | "Artificial intelligence is transforming the world.", 14 | "Python is a popular programming language." 15 | ] 16 | 17 | df = pd.DataFrame(texts, columns=["content"]) 18 | df.to_json("tmp_data/tmp_sample_summarization_dataset.json", index=False) 19 | 20 | dataset = Dataset(path="tmp_data/tmp_sample_summarization_dataset.json") 21 | # Expected output: a list of summaries 22 | expected_summaries = [ 23 | "The fox jumps over the dog.", 24 | "AI is changing the world.", 25 | "Python is a popular language." 26 | ] 27 | 28 | qwen2vl_api_key = os.environ.get("QWEN2VL_OPENAI_API_KEY") 29 | qwen2vl_api_base = os.environ.get("QWEN2VL_OPENAI_API_BASE") 30 | 31 | llm=VLLMChatOpenAI( 32 | api_key=qwen2vl_api_key, 33 | base_url=qwen2vl_api_base, 34 | model="/model", 35 | max_tokens=2048, 36 | # max_concurrency=10 37 | ) 38 | 39 | summarisation_chain = SummarisationChain(llm=llm) 40 | actual_summaries = asyncio.run(summarisation_chain.run(dataset, n_examples=3)) 41 | 42 | # Assert that the actual summaries match the expected summaries 43 | self.assertEqual(len(actual_summaries), len(expected_summaries)) 44 | # self.assertEqual(actual_summaries, expected_summaries) 45 | 46 | if __name__ == '__main__': 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /protollm-synthetic/tmp_data/rag_generated.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm-synthetic/tmp_data/rag_generated.pickle -------------------------------------------------------------------------------- /protollm-synthetic/tmp_data/sample_data_free_instruction.json: -------------------------------------------------------------------------------- 1 | {"content": ["Сегодня хорошая погода в Москве", "Завтра в 11 будет важный созвон с заказчиком из Италии", "Я записан в бассейн 30.12.2024 в 12 и 31.12.2024 в 13. 
Удивляюсь, как мне это удалось", "Через неделю будет вечеринка в клубе 'Золотой' в 23:00", "Сегодня в 10:00 я занялся спортом", "19.01.2025 я записан к стоматологу"], "solution": ["[{'date': '22.12.2024', 'event': 'хорошая погода в Москве'}]", "[{'date': '23.12.2024', 'time': '11:00', 'event': 'созвон с заказчиком из Италии'}]", "[{'date': '30.12.2024', 'time': '12:00', 'event': 'запись в бассейн'}, {'date': '31.12.2024', 'time': '13:00', 'event': 'запись в бассейн'}]", null, null, null]} -------------------------------------------------------------------------------- /protollm-synthetic/tmp_data/tmp_sample_summarization_dataset.json: -------------------------------------------------------------------------------- 1 | {"content":{"0":"The quick brown fox jumps over the lazy dog.","1":"Artificial intelligence is transforming the world.","2":"Python is a popular programming language."}} -------------------------------------------------------------------------------- /protollm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/.DS_Store -------------------------------------------------------------------------------- /protollm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/__init__.py -------------------------------------------------------------------------------- /protollm/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.agents.llama31_agents import Llama31ChatModel 2 | -------------------------------------------------------------------------------- /protollm/agents/agent_utils/parsers.py: -------------------------------------------------------------------------------- 1 | from langchain_core.output_parsers import PydanticOutputParser 2 | 3 | from protollm.agents.agent_utils.pydantic_models import ( 4 | Act, 5 | Plan, 6 | Translation, 7 | Worker, 8 | ) 9 | from protollm.agents.agent_utils.pydantic_models import Chat 10 | 11 | chat_parser = PydanticOutputParser(pydantic_object=Chat) 12 | planner_parser = PydanticOutputParser(pydantic_object=Plan) 13 | supervisor_parser = PydanticOutputParser(pydantic_object=Worker) 14 | replanner_parser = PydanticOutputParser(pydantic_object=Act) 15 | translator_parser = PydanticOutputParser(pydantic_object=Translation) 16 | -------------------------------------------------------------------------------- /protollm/agents/agent_utils/pydantic_models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class Response(BaseModel): 7 | """Response to user.""" 8 | 9 | response: str 10 | 11 | 12 | class Plan(BaseModel): 13 | """Plan to follow in future""" 14 | 15 | steps: List[str] = Field( 16 | description="different steps to follow, should be in sorted order" 17 | ) 18 | 19 | 20 | class Act(BaseModel): 21 | """Action to perform.""" 22 | 23 | action: Union[Response, Plan] = Field( 24 | description="Action to perform. If you want to respond to user, use Response. " 25 | "If you need to further use tools to get the answer, use Plan." 
26 | ) 27 | 28 | 29 | class Worker(BaseModel): 30 | """Worker to call in future""" 31 | 32 | next: str = Field(description="Next worker to call") 33 | 34 | 35 | class Chat(BaseModel): 36 | """Action to perform""" 37 | 38 | action: Union[Response, Worker] = Field( 39 | description="Action to perform. If you want to respond to user, use Response. " 40 | "If you need to further use tools to get the answer, use Next." 41 | ) 42 | 43 | last_memory: Optional[str] = Field( 44 | description="last memory of the user, if any", default="" 45 | ) 46 | 47 | 48 | class Translation(BaseModel): 49 | """Action to perform""" 50 | 51 | language: Optional[str] = Field( 52 | description="language to translate", default="English" 53 | ) 54 | translation: Optional[str] = Field( 55 | default=None, description="translation from English" 56 | ) 57 | -------------------------------------------------------------------------------- /protollm/agents/agent_utils/states.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import TypedDict 2 | from typing import Annotated, List, Tuple 3 | import operator 4 | from protollm.agents.universal_agents import store 5 | 6 | 7 | class PlanExecute(TypedDict): 8 | input: str 9 | plan: List[str] 10 | past_steps: Annotated[List[Tuple], operator.add] 11 | next: str 12 | response: str 13 | visualization: str 14 | language: str 15 | translation: str 16 | automl_results: str 17 | nodes_calls: Annotated[List[Tuple], operator.add] 18 | last_memory: str 19 | 20 | 21 | def load_summary(user_id: str) -> str: 22 | namespace = (user_id, "memory") 23 | item = store.get(namespace, "latest-summary") 24 | return item.value.get("summary", "") if item else "" 25 | 26 | 27 | def initialize_state(user_input: str, user_id: str) -> PlanExecute: 28 | memory = load_summary(user_id) 29 | return { 30 | "input": user_input, 31 | "plan": [], 32 | "past_steps": [], 33 | "next": "", 34 | "response": "", 35 | "visualization": "", 36 | "language": "", 37 | "translation": "", 38 | "automl_results": "", 39 | "nodes_calls": [], 40 | "last_memory": memory, 41 | } 42 | -------------------------------------------------------------------------------- /protollm/agents/llama31_agents/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.agents.llama31_agents.llama31_agent import Llama31ChatModel 2 | -------------------------------------------------------------------------------- /protollm/connectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .connector_creator import create_llm_connector, CustomChatOpenAI 2 | from .rest_server import ChatRESTServer -------------------------------------------------------------------------------- /protollm/definitions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | 5 | CONFIG_PATH = os.path.join(ROOT_DIR, 'config.env') 6 | -------------------------------------------------------------------------------- /protollm/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation_metrics import correctness_metric, model_for_metrics 2 | -------------------------------------------------------------------------------- /protollm/metrics/deepeval_connector.py: -------------------------------------------------------------------------------- 
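A minimal usage sketch for the connector defined below; the endpoint value is a placeholder and the exact URL format expected by create_llm_connector is an assumption, not something this file defines:

    import os
    os.environ["DEEPEVAL_LLM_URL"] = "<llm-connection-url>"  # placeholder; resolved by protollm.connectors.create_llm_connector
    from protollm.metrics.deepeval_connector import DeepEvalConnector

    connector = DeepEvalConnector(sys_prompt="You are a strict grader.")
    print(connector.generate("Is 'Paris' a correct answer to 'What is the capital of France?'"))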
1 | import os 2 | 3 | from deepeval.models.base_model import DeepEvalBaseLLM 4 | from langchain_core.language_models.chat_models import BaseChatModel 5 | from pydantic import BaseModel 6 | from openai._types import NOT_GIVEN 7 | 8 | from ..connectors import create_llm_connector 9 | 10 | 11 | class DeepEvalConnector(DeepEvalBaseLLM): 12 | """Implementation of Evaluation agent based on large language model for Assistant's answers evaluation. 13 | 14 | Uses the LangChain's ChatModel to make requests to a compatible API. Must inherit from the base class and 15 | implement a set of methods. 16 | The vsegpt.ru service is used by default, so in the configuration file it is necessary to specify the API key of 17 | this service and the model name, as in the general case. 18 | """ 19 | 20 | def __init__(self, sys_prompt: str = "", *args, **kwargs): 21 | """Initialize instance with evaluation LLM. 22 | 23 | Args: 24 | sys_prompt: predefined rules for model 25 | """ 26 | super().__init__(*args, **kwargs) 27 | self._sys_prompt = sys_prompt 28 | self.model = self.load_model() 29 | 30 | @staticmethod 31 | def load_model() -> BaseChatModel: 32 | """Returns LangChain's ChatModel for requests""" 33 | return create_llm_connector(os.getenv("DEEPEVAL_LLM_URL", "test_model")) 34 | 35 | def generate( 36 | self, 37 | prompt: str, 38 | *args, 39 | **kwargs, 40 | ) -> str | BaseModel: 41 | """Get a response from LLM to given question. 42 | 43 | Args: 44 | prompt (str): Query, the model must answer. 45 | 46 | Returns: 47 | str: Model's response. 48 | """ 49 | messages = [ 50 | {"role": "system", "content": self._sys_prompt}, 51 | {"role": "user", "content": prompt}, 52 | ] 53 | response_format = kwargs.get("schema", NOT_GIVEN) 54 | if response_format == NOT_GIVEN: 55 | return self.model.invoke(messages).content 56 | else: 57 | struct_llm = self.model.with_structured_output(schema=response_format, method="json_mode") 58 | return struct_llm.invoke(messages) 59 | 60 | async def a_generate( 61 | self, 62 | prompt: str, 63 | *args, 64 | **kwargs, 65 | ) -> str: 66 | """Same as synchronous generate method just because it must be implemented""" 67 | return self.generate( 68 | prompt, *args, **kwargs 69 | ) 70 | 71 | def get_model_name(self, *args, **kwargs) -> str: 72 | """Returns a description of what the class is about""" 73 | return "Implementation of custom LLM connector using OpenAI compatible API for evaluation." 74 | -------------------------------------------------------------------------------- /protollm/metrics/evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | from deepeval.metrics import GEval 2 | from deepeval.test_case import LLMTestCaseParams 3 | 4 | from .deepeval_connector import DeepEvalConnector 5 | 6 | model_for_metrics = DeepEvalConnector() 7 | 8 | # Custom metric for evaluating the correctness of an answer 9 | correctness_metric = GEval( 10 | name="Correctness", 11 | criteria=( 12 | "1. Correctness and Relevance:" 13 | "- Compare the actual response against the expected response. Determine the" 14 | " extent to which the actual response captures the key elements and concepts of" 15 | " the expected response." 16 | "- Assign higher scores to actual responses that accurately reflect the core" 17 | " information of the expected response, even if only partial." 18 | "2. Numerical Accuracy and Interpretation:" 19 | "- Pay particular attention to any numerical values present in the expected" 20 | " response. 
Verify that these values are correctly included in the actual" 21 | " response and accurately interpreted within the context." 22 | "- Ensure that units of measurement, scales, and numerical relationships are" 23 | " preserved and correctly conveyed." 24 | "3. Allowance for Partial Information:" 25 | "- Do not heavily penalize the actual response for incompleteness if it covers" 26 | " significant aspects of the expected response. Prioritize the correctness of" 27 | " provided information over total completeness." 28 | "4. Handling of Extraneous Information:" 29 | "- While additional information not present in the expected response should not" 30 | " necessarily reduce score, ensure that such additions do not introduce" 31 | " inaccuracies or deviate from the context of the expected response." 32 | ), 33 | evaluation_params=[ 34 | LLMTestCaseParams.ACTUAL_OUTPUT, 35 | LLMTestCaseParams.EXPECTED_OUTPUT, 36 | ], 37 | model=model_for_metrics, 38 | async_mode=False 39 | ) 40 | -------------------------------------------------------------------------------- /protollm/rags/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/rags/__init__.py -------------------------------------------------------------------------------- /protollm/rags/configs/chroma.env: -------------------------------------------------------------------------------- 1 | # Chroma DB Settings 2 | CHROMA_HOST='any' 3 | CHROMA_PORT=any 4 | ALLOW_RESET=False 5 | 6 | # Documents collection's settings 7 | COLLECTION_NAME='any' 8 | COLLECTION_NAMES_FOR_ADVANCE=['any'] 9 | EMBEDDING_NAME='any' 10 | EMBEDDING_HOST='any' 11 | DISTANCE_FN='cosine' -------------------------------------------------------------------------------- /protollm/rags/configs/docs_processing_config.yaml: -------------------------------------------------------------------------------- 1 | loader: 2 | loader_name: 'PDFLoader' 3 | parsing_params: 4 | parsing_scheme: 'paragraphs' 5 | extract_images: False 6 | extract_tables: False 7 | parse_formulas: False 8 | remove_service_info: True 9 | handle_converting_error: False 10 | 11 | 12 | splitter: 13 | splitter_name: 'hierarchical_merger' 14 | splitter_params: 15 | chunk_size: 510 16 | chunk_overlap: 0 17 | separators: 18 | - '\n\n' 19 | - '\n' 20 | - '. ' 21 | - ', ' 22 | - '.' 
23 | - ',' 24 | - ' ' 25 | - '' 26 | keep_separator: False 27 | add_start_index: False 28 | strip_whitespace: True 29 | apply_chunks_merge: True 30 | 31 | #tokenizer: 'hf-internal-testing/llama-tokenizer' 32 | tokenizer: 'any' -------------------------------------------------------------------------------- /protollm/rags/configs/elastic.env: -------------------------------------------------------------------------------- 1 | ES_HOST=any 2 | ES_PORT=any 3 | ES_USER=any 4 | ES_PASSWORD=any 5 | 6 | ES_INDEX_MAPPINGS: dict = json.loads(Path(CONFIG_PATH, 'index_mappings.json').read_text(encoding="utf-8")) 7 | ES_INDEX_SETTINGS: dict = json.loads(Path(CONFIG_PATH, 'index_settings.json').read_text(encoding="utf-8")) 8 | es_query_template: dict = json.loads(Path(CONFIG_PATH, 'query_template.json').read_text(encoding="utf-8")) 9 | es_query_all_hits: dict = json.loads(Path(CONFIG_PATH, 'query_all_hits.json').read_text(encoding="utf-8")) 10 | 11 | metadata_fields: list[str] = list(es_index_mappings['properties']['metadata']['properties'].keys()) 12 | content_field: str = 'paragraph' -------------------------------------------------------------------------------- /protollm/rags/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.rags.pipeline.etl_pipeline import DocsLoadPipeline, DocsTransformPipeline, DocsExtractPipeline 2 | -------------------------------------------------------------------------------- /protollm/rags/pipeline/docs_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/rags/pipeline/docs_processing/__init__.py -------------------------------------------------------------------------------- /protollm/rags/pipeline/docs_processing/entities.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from json import load 3 | from typing import Iterator 4 | 5 | from langchain_core.document_loaders import BaseLoader 6 | from langchain_core.documents import Document 7 | from langchain_core.load import load as ln_load 8 | 9 | from protollm.raw_data_processing.docs_transformers import ChunkMerger, RecursiveSplitter 10 | from protollm.raw_data_processing.docs_transformers.metadata_sentence_splitter import DivMetadataSentencesSplitter 11 | from protollm.raw_data_processing.docs_transformers.key_words_splitter import MultiMetadataAppender 12 | 13 | 14 | transformer_object_dict = { 15 | 'recursive_character': RecursiveSplitter, 16 | # 'list_hierarchy': ListHierarchySplitter, 17 | 'hierarchical_merger': ChunkMerger, 18 | 'div_sentence_splitter': DivMetadataSentencesSplitter, 19 | 'keyword_appender': MultiMetadataAppender 20 | } 21 | 22 | 23 | class LoaderType(str, Enum): 24 | docx = 'docx' 25 | doc = 'doc' 26 | odt = 'odt' 27 | rtf = 'rtf' 28 | pdf = 'pdf' 29 | directory = 'directory' 30 | zip = 'zip' 31 | json = 'json' 32 | 33 | 34 | class LangChainDocumentLoader(BaseLoader): 35 | def __init__(self, file_path: str): 36 | self.file_path = file_path 37 | 38 | def lazy_load(self) -> Iterator[Document]: 39 | with open(self.file_path, 'r') as f: 40 | for i, doc_dict in load(f).items(): 41 | yield ln_load(doc_dict) 42 | -------------------------------------------------------------------------------- /protollm/rags/pipeline/docs_processing/exceptions.py: 
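The exception classes listed in the next file are the ones the docs-processing pipeline raises; a small sketch of how they typically surface from get_loader (defined further down in docs_processing/utils.py), using a hypothetical file name:

    from protollm.rags.pipeline.docs_processing.exceptions import FileExtensionError, PathIsNotAssigned
    from protollm.rags.pipeline.docs_processing.utils import get_loader

    try:
        loader = get_loader(file_path="notes.txt")  # .txt is not implemented yet
    except (FileExtensionError, PathIsNotAssigned) as err:
        print(f"document skipped: {err}")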
-------------------------------------------------------------------------------- 1 | class PathIsNotAssigned(Exception): 2 | def __init__(self, message): 3 | super().__init__(message) 4 | 5 | 6 | class PipelineError(Exception): 7 | def __init__(self, message): 8 | super().__init__(message) 9 | 10 | 11 | class FileExtensionError(Exception): 12 | def __init__(self, message): 13 | super().__init__(message) 14 | 15 | 16 | class TransformerNameError(Exception): 17 | def __init__(self, message): 18 | super().__init__(message) 19 | 20 | 21 | class LoaderNameError(Exception): 22 | def __init__(self, message): 23 | super().__init__(message) -------------------------------------------------------------------------------- /protollm/rags/pipeline/docs_processing/models.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class ConfigLoader(BaseModel): 7 | file_path: str = '' 8 | save_path: str = '' 9 | loader_name: str 10 | parsing_params: dict[str, Any] = dict() 11 | 12 | 13 | class ConfigSplitter(BaseModel): 14 | splitter_name: str | None = None 15 | splitter_params: dict[str, Any] = dict() 16 | 17 | 18 | class ConfigFile(BaseModel): 19 | loader: ConfigLoader 20 | splitter: List[ConfigSplitter] = [] 21 | tokenizer: str | None = None 22 | -------------------------------------------------------------------------------- /protollm/rags/pipeline/docs_processing/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | from protollm.raw_data_processing.docs_parsers.loaders import PDFLoader, WordDocumentLoader, ZipLoader, \ 6 | RecursiveDirectoryLoader 7 | from langchain_core.document_loaders import BaseLoader 8 | 9 | from protollm.rags.pipeline.docs_processing.entities import LoaderType, LangChainDocumentLoader 10 | from protollm.rags.pipeline.docs_processing.exceptions import FileExtensionError, PathIsNotAssigned 11 | 12 | 13 | def get_loader(**loader_params) -> BaseLoader: 14 | document_path = loader_params.get('file_path', '') 15 | if not isinstance(document_path, (str, Path)) or document_path == '': 16 | raise PathIsNotAssigned('Input file (directory) path is not assigned') 17 | doc_extension = str(document_path).lower().split('.')[-1] 18 | if os.path.isdir(document_path): 19 | doc_extension = LoaderType.directory 20 | match doc_extension: 21 | case LoaderType.pdf: 22 | return PDFLoader(**loader_params) 23 | case LoaderType.json: 24 | return LangChainDocumentLoader(**loader_params) 25 | case LoaderType.docx | LoaderType.doc | LoaderType.rtf | LoaderType.odt: 26 | return WordDocumentLoader(**loader_params) 27 | 28 | parsing_scheme = loader_params.pop('parsing_scheme', 'lines') 29 | extract_images = loader_params.pop('extract_images', False) 30 | extract_tables = loader_params.pop('extract_tables', False) 31 | parse_formulas = loader_params.pop('parse_formulas', False) 32 | remove_service_info = loader_params.pop('remove_service_info', False) 33 | loader_params = dict( 34 | pdf_parsing_scheme=parsing_scheme, 35 | pdf_extract_images=extract_images, 36 | pdf_extract_tables=extract_tables, 37 | pdf_parse_formulas=parse_formulas, 38 | pdf_remove_service_info=remove_service_info, 39 | word_doc_parsing_scheme=parsing_scheme, 40 | word_doc_extract_images=extract_images, 41 | word_doc_extract_tables=extract_tables, 42 | word_doc_parse_formulas=parse_formulas, 43 |
word_doc_remove_service_info=remove_service_info, 44 | **loader_params, 45 | ) 46 | match doc_extension: 47 | case LoaderType.zip: 48 | return ZipLoader(**loader_params) 49 | case LoaderType.directory: 50 | return RecursiveDirectoryLoader(**loader_params) 51 | case _: 52 | raise FileExtensionError(f'File with extension {doc_extension} has not been implemented yet.') 53 | -------------------------------------------------------------------------------- /protollm/rags/rag_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/rags/rag_core/__init__.py -------------------------------------------------------------------------------- /protollm/rags/rag_core/planner.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from langchain_core.language_models import LLM 4 | from langchain_core.prompts import PromptTemplate 5 | 6 | 7 | class Planner: 8 | 9 | def __init__(self, llm: LLM, prompt_template: PromptTemplate) -> None: 10 | self._prompt_template = prompt_template 11 | self._llm = llm 12 | 13 | def generate_answer(self, query: str | list[str]) -> list: 14 | if isinstance(query, str): 15 | query = [query] 16 | inst_query = [self._prompt_template.format(context=i) for i in query] 17 | answer = [self._llm.invoke(prompt) for prompt in inst_query] 18 | good_answers, bad_answers = self._extract_planner_queries(answer) 19 | if len(bad_answers) > 0: 20 | updated_answers = self._regenerate_answer(bad_answers) 21 | good_answers += updated_answers 22 | return good_answers 23 | 24 | def _extract_planner_queries(self, answer: list): 25 | bad_idx = [] 26 | for i, ans in enumerate(answer): 27 | try: 28 | result = ast.literal_eval(ans.content.split('ЗАПРОСЫ:')[1].strip().split(']')[0] + ']') 29 | answer[i] = result 30 | except: 31 | bad_idx.append(i) 32 | return [answer[i] for i in range(len(answer)) if i not in bad_idx], [answer[i] for i in range(len(answer)) if 33 | i in bad_idx] 34 | 35 | def _regenerate_answer(self, query: list, retries: int = 3): 36 | fixed_queries = [] 37 | for trial in range(retries): 38 | result = [self._llm.invoke(prompt) for prompt in query] 39 | good_res, bad_res = self._extract_planner_queries(result) 40 | fixed_queries += good_res 41 | query = bad_res 42 | if len(bad_res) == 0: 43 | return fixed_queries 44 | return fixed_queries 45 | -------------------------------------------------------------------------------- /protollm/rags/settings/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.rags.settings.pipeline_settings import PipelineSettings 2 | from protollm.rags.settings.es_settings import settings as es_settings 3 | from protollm.rags.settings.chroma_settings import settings as default_settings 4 | -------------------------------------------------------------------------------- /protollm/rags/settings/chroma_settings.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname 2 | from pathlib import Path 3 | from pydantic_settings import BaseSettings, SettingsConfigDict 4 | 5 | 6 | class ChromaSettings(BaseSettings): 7 | # Chroma DB settings 8 | chroma_host: str = 'any' 9 | chroma_port: int = 8888 10 | allow_reset: bool = False 11 | 12 | # Documents collection's settings 13 | collection_name: str = 'collection' 14 | collection_names_for_advance: list[str] = 
['collection'] 15 | embedding_name: str = 'intfloat/multilingual-e5-large' 16 | embedding_host: str = '' 17 | distance_fn: str = 'cosine' 18 | 19 | # Documents' processing settings 20 | docs_processing_config: str = str(Path(dirname(dirname(__file__)), 'configs', 'docs_processing_config.yaml')) 21 | docs_collection_path: str = str(Path(dirname(dirname(dirname(__file__))), 'docs', 'example.docx')) 22 | 23 | model_config = SettingsConfigDict( 24 | env_file=Path(dirname(dirname(__file__)), 'configs', 'chroma.env'), 25 | env_file_encoding='utf-8', 26 | extra='ignore', 27 | ) 28 | 29 | 30 | settings = ChromaSettings() 31 | -------------------------------------------------------------------------------- /protollm/rags/settings/es_settings.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import dirname 3 | from pathlib import Path 4 | 5 | from pydantic import computed_field 6 | from pydantic_settings import BaseSettings, SettingsConfigDict 7 | 8 | 9 | CONFIG_PATH = Path(dirname(dirname(__file__)), 'stores', 'elasticsearch', 'configs') 10 | 11 | 12 | class ElasticsearchSettings(BaseSettings): 13 | es_host: str = "elasticDUMB" 14 | es_port: int = 9201 15 | 16 | es_user: str = "elastic" 17 | es_password: str = "admin" 18 | 19 | @computed_field 20 | @property 21 | def es_url(self) -> str: 22 | return f"http://{self.es_host}:{self.es_port}" 23 | 24 | es_index_mappings: dict = json.loads(Path(CONFIG_PATH, 'index_mappings.json').read_text(encoding="utf-8")) 25 | es_index_settings: dict = json.loads(Path(CONFIG_PATH, 'index_settings.json').read_text(encoding="utf-8")) 26 | es_query_template: dict = json.loads(Path(CONFIG_PATH, 'query_template.json').read_text(encoding="utf-8")) 27 | es_query_all_hits: dict = json.loads(Path(CONFIG_PATH, 'query_all_hits.json').read_text(encoding="utf-8")) 28 | 29 | metadata_fields: list[str] = list(es_index_mappings['properties']['metadata']['properties'].keys()) 30 | content_field: str = 'paragraph' 31 | 32 | model_config = SettingsConfigDict( 33 | env_file=Path(Path(__file__).parent.parent, 'configs', 'elastic.env'), 34 | env_file_encoding='utf-8', 35 | extra='ignore', 36 | ) 37 | 38 | 39 | settings = ElasticsearchSettings() 40 | -------------------------------------------------------------------------------- /protollm/rags/stores/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/rags/stores/__init__.py -------------------------------------------------------------------------------- /protollm/rags/stores/chroma/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/rags/stores/chroma/__init__.py -------------------------------------------------------------------------------- /protollm/rags/stores/chroma/chroma_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import logging 3 | 4 | import chromadb 5 | from langchain_community.embeddings.huggingface_hub import HuggingFaceHubEmbeddings 6 | from langchain_community.vectorstores.chroma import Chroma 7 | 8 | from protollm.rags.pipeline.etl_pipeline import DocsExtractPipeline 9 | from protollm.rags.settings.pipeline_settings import PipelineSettings 10 | from
protollm.rags.settings.chroma_settings import ChromaSettings, settings as default_settings 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | def load_documents_to_chroma_db(settings: Optional[ChromaSettings] = None, 17 | processing_batch_size: int = 100, 18 | loading_batch_size: int = 32, 19 | **kwargs) -> None: 20 | 21 | if settings is None: 22 | settings = default_settings 23 | 24 | logger.info( 25 | f'Initializing batch generator with processing_batch_size: {processing_batch_size},' 26 | f' loading_batch_size: {loading_batch_size}' 27 | ) 28 | 29 | pipeline_settings = PipelineSettings.config_from_file(settings.docs_processing_config) 30 | 31 | store = Chroma(collection_name=settings.collection_name, 32 | embedding_function=HuggingFaceHubEmbeddings(model=settings.embedding_host, huggingfacehub_api_token='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'), 33 | client=chromadb.HttpClient(host=settings.chroma_host, port=settings.chroma_port)) 34 | 35 | # Documents loading and processing 36 | DocsExtractPipeline(pipeline_settings) \ 37 | .go_to_next_step(docs_collection_path=settings.docs_collection_path) \ 38 | .update_docs_transformers(**kwargs) \ 39 | .go_to_next_step(batch_size=processing_batch_size) \ 40 | .load(store, loading_batch_size=loading_batch_size) 41 | -------------------------------------------------------------------------------- /protollm/rags/stores/elasticsearch/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.rags.stores.elasticsearch.utilities import (get_index_name, get_elasticsearch_store, 2 | custom_query_for_metadata_mapping) 3 | from protollm.rags.stores.elasticsearch.retrieval_strategies import BM25RetrievalStrategy 4 | -------------------------------------------------------------------------------- /protollm/rags/stores/elasticsearch/configs/index_mappings.json: -------------------------------------------------------------------------------- 1 | { 2 | "properties": { 3 | "paragraph": {"type": "text", "analyzer": "russian_analyzer"}, 4 | "metadata": { 5 | "properties": { 6 | "doc_idx": { 7 | "type": "float" 8 | }, 9 | "paragraph_number": { 10 | "type": "text" 11 | }, 12 | "doc_name": { 13 | "type": "text", 14 | "analyzer": "russian_analyzer" 15 | }, 16 | "title_0": { 17 | "type": "text", 18 | "analyzer": "russian_analyzer" 19 | }, 20 | "title_1": { 21 | "type": "text", 22 | "analyzer": "russian_analyzer" 23 | }, 24 | "title_2": { 25 | "type": "text", 26 | "analyzer": "russian_analyzer" 27 | }, 28 | "title_3": { 29 | "type": "text", 30 | "analyzer": "russian_analyzer" 31 | } 32 | } 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /protollm/rags/stores/elasticsearch/configs/index_settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "analysis": { 3 | "filter": { 4 | "russian_stop": { 5 | "type": "stop", 6 | "stopwords": "_russian_" 7 | }, 8 | "russian_keywords": { 9 | "type": "keyword_marker", 10 | "keywords": [] 11 | }, 12 | "russian_stemmer": { 13 | "type": "stemmer", 14 | "language": "russian" 15 | } 16 | }, 17 | "analyzer": { 18 | "russian_analyzer": { 19 | "type": "custom", 20 | "tokenizer": "standard", 21 | "filter": ["lowercase", 22 | "russian_stop", 23 | "russian_keywords", 24 | "russian_stemmer" 25 | ] 26 | } 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- 
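The two JSON configs above (index_settings.json with the russian_analyzer and index_mappings.json with the metadata mapping) describe the BM25 index. A sketch of applying them directly, assuming the elasticsearch 8.x Python client and a placeholder node URL; in the repo itself they are wired in through BM25RetrievalStrategy and ElasticsearchStore rather than created by hand:

    import json
    from pathlib import Path
    from elasticsearch import Elasticsearch

    cfg = Path("protollm/rags/stores/elasticsearch/configs")
    es = Elasticsearch("http://localhost:9200")  # placeholder URL
    es.indices.create(
        index="index_v1",  # same shape as get_index_name(1) defined later in utilities.py
        settings=json.loads((cfg / "index_settings.json").read_text(encoding="utf-8")),
        mappings=json.loads((cfg / "index_mappings.json").read_text(encoding="utf-8")),
    )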
/protollm/rags/stores/elasticsearch/configs/query_all_hits.json: -------------------------------------------------------------------------------- 1 | { 2 | "query": { 3 | "match_all": {} 4 | } 5 | } -------------------------------------------------------------------------------- /protollm/rags/stores/elasticsearch/configs/query_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "multi_match": { 3 | "query": "", 4 | "type": "most_fields", 5 | "fields": ["doc_name", 6 | "title_0", 7 | "title_1", 8 | "title_2", 9 | "title_3^2", 10 | "paragraph^3"] 11 | } 12 | } -------------------------------------------------------------------------------- /protollm/rags/stores/elasticsearch/retrieval_strategies.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from langchain_elasticsearch._utilities import DistanceStrategy 4 | from langchain_elasticsearch.vectorstores import BaseRetrievalStrategy 5 | 6 | from protollm.rags.settings import es_settings 7 | 8 | 9 | class BM25RetrievalStrategy(BaseRetrievalStrategy): 10 | """BM25 retrieval strategy for ElasticSearch""" 11 | 12 | def index( 13 | self, 14 | dims_length: Union[int, None], 15 | vector_query_field: str, 16 | similarity: Union[DistanceStrategy, None], 17 | ) -> dict: 18 | index_kwargs = {"settings": es_settings.es_index_settings, 19 | "mappings": es_settings.es_index_mappings} 20 | return index_kwargs 21 | 22 | def query( 23 | self, 24 | query_vector: Union[list[float], None], 25 | query: Union[str, None], 26 | *, 27 | k: int, 28 | fetch_k: int, 29 | vector_query_field: str, 30 | text_field: str, 31 | filter: list[dict], 32 | similarity: Union[DistanceStrategy, None], 33 | ) -> dict: 34 | if query is None: 35 | raise ValueError( 36 | "You must provide a query to perform a similarity search." 
37 | ) 38 | new_query = dict(es_settings.es_query_template) 39 | new_query['multi_match']['query'] = query 40 | query_body = {'query': new_query} 41 | return query_body 42 | 43 | def require_inference(self) -> bool: 44 | return False 45 | -------------------------------------------------------------------------------- /protollm/rags/stores/elasticsearch/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pathlib import Path 4 | 5 | 6 | CONFIG_PATH = Path(Path(__file__).parent, 'configs') 7 | 8 | 9 | class ElasticsearchSettings: 10 | es_host: str = os.environ.get("ELASTIC_HOST", "") 11 | es_port: int = os.environ.get("ELASTIC_PORT", 80) 12 | es_url: str = f"http://{es_host}:{es_port}" 13 | es_user: str = os.environ.get("ELASTIC_USER", "") 14 | es_password: str = os.environ.get("ELASTIC_PASSWORD", "") 15 | 16 | es_index_mappings: dict = json.loads(Path(CONFIG_PATH, 'index_mappings.json').read_text(encoding="utf-8")) 17 | es_index_settings: dict = json.loads(Path(CONFIG_PATH, 'index_settings.json').read_text(encoding="utf-8")) 18 | es_query_template: dict = json.loads(Path(CONFIG_PATH, 'query_template.json').read_text(encoding="utf-8")) 19 | es_query_all_hits: dict = json.loads(Path(CONFIG_PATH, 'query_all_hits.json').read_text(encoding="utf-8")) 20 | 21 | metadata_fields: list[str] = list(es_index_mappings['properties']['metadata']['properties'].keys()) 22 | content_field: str = 'paragraph' 23 | 24 | 25 | settings = ElasticsearchSettings() 26 | -------------------------------------------------------------------------------- /protollm/rags/stores/elasticsearch/utilities.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any 2 | 3 | from langchain_elasticsearch import ElasticsearchStore 4 | from langchain_elasticsearch.vectorstores import BaseRetrievalStrategy 5 | 6 | from protollm.rags.stores.elasticsearch.settings import settings 7 | from protollm.rags.stores.elasticsearch.retrieval_strategies import BM25RetrievalStrategy 8 | 9 | 10 | def get_index_name(index: int) -> str: 11 | return f"index_v{index}" 12 | 13 | 14 | def get_elasticsearch_store(index_name: str, 15 | es_url: str = settings.es_url, 16 | es_user: str = settings.es_user, 17 | es_password: str = settings.es_password, 18 | query_field: str = settings.content_field, 19 | strategy: BaseRetrievalStrategy = BM25RetrievalStrategy(), 20 | es_params: Optional[dict[str, Any]] = None) -> ElasticsearchStore: 21 | return ElasticsearchStore(index_name, 22 | es_url=es_url, 23 | es_user=es_user, 24 | es_password=es_password, 25 | query_field=query_field, 26 | strategy=strategy, 27 | es_params=es_params) 28 | 29 | 30 | def custom_query_for_metadata_mapping(query_body: dict, query: str) -> dict: 31 | """Custom query to be used in Elasticsearch with indexes that use langchain Document schema. 32 | This implies that all additional fields are stored in the metadata 33 | Args: 34 | query_body (dict): Elasticsearch query body. 35 | query (str): Query string. 36 | Returns: 37 | dict: Elasticsearch query body. 38 | """ 39 | query = query_body['query'] 40 | if 'multi_match' in query: 41 | if 'fields' in query['multi_match']: 42 | fields = [] 43 | for field in query['multi_match']['fields']: 44 | if not (field.startswith('metadata') or field.startswith(settings.content_field)): 45 | field = 'metadata.' 
+ field 46 | fields.append(field) 47 | 48 | query['multi_match']['fields'] = fields 49 | 50 | return query_body 51 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/raw_data_processing/__init__.py -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/raw_data_processing/docs_parsers/__init__.py -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.raw_data_processing.docs_parsers.loaders.doc_loader import WordDocumentLoader 2 | from protollm.raw_data_processing.docs_parsers.loaders.directory_loader import RecursiveDirectoryLoader 3 | from protollm.raw_data_processing.docs_parsers.loaders.pdf_loader import PDFLoader 4 | from protollm.raw_data_processing.docs_parsers.loaders.zip_loader import ZipLoader 5 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/loaders/pdf_loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Iterator, Union, Any, Optional 3 | 4 | from langchain_core.document_loaders import BaseLoader, Blob 5 | from langchain_core.documents import Document 6 | 7 | from protollm.raw_data_processing.docs_parsers.parsers import PDFParser, ParsingScheme, DocType 8 | from protollm.raw_data_processing.docs_parsers.utils.logger import ParsingLogger 9 | 10 | 11 | class PDFLoader(BaseLoader): 12 | """ 13 | Load PDF into list of documents. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | file_path: Union[str, Path], 19 | byte_content: Optional[bytes] = None, 20 | parsing_scheme: Union[ParsingScheme, str] = ParsingScheme.lines, 21 | extract_images: bool = False, 22 | extract_tables: bool = False, 23 | extract_formulas: bool = False, 24 | remove_headers: bool = False, 25 | parsing_logger: Optional[ParsingLogger] = None, 26 | **kwargs: Any, 27 | ) -> None: 28 | """Initialize with a file path.""" 29 | self.file_path = str(file_path) 30 | doc_type = PDFParser.get_doc_type(self.file_path) 31 | if doc_type is not DocType.pdf: 32 | if doc_type is DocType.unsupported: 33 | raise ValueError("The file type is unsupported") 34 | else: 35 | raise ValueError( 36 | f"The {doc_type} file type does not match the Loader! Use a suitable one."
37 | ) 38 | self.byte_content = byte_content 39 | self._logger = parsing_logger or ParsingLogger(name=__name__) 40 | self.parser = PDFParser( 41 | parsing_scheme, 42 | extract_images, 43 | extract_tables, 44 | extract_formulas, 45 | remove_headers, 46 | ) 47 | 48 | @property 49 | def logs(self): 50 | return self._logger.logs 51 | 52 | def lazy_load( 53 | self, 54 | ) -> Iterator[Document]: 55 | """Lazy load given path""" 56 | if self.byte_content is None: 57 | blob = Blob.from_path(self.file_path) 58 | else: 59 | blob = Blob.from_data( 60 | self.byte_content, path=self.file_path, mime_type=DocType.pdf.value 61 | ) 62 | with self._logger.parsing_info_handler(self.file_path): 63 | yield from self.parser.lazy_parse(blob) 64 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.raw_data_processing.docs_parsers.parsers.base import BaseParser 2 | from protollm.raw_data_processing.docs_parsers.parsers.entities import DocType, ParsingScheme 3 | from protollm.raw_data_processing.docs_parsers.parsers.pdf import PDFParser 4 | from protollm.raw_data_processing.docs_parsers.parsers.word_doc import WordDocumentParser 5 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/base.py: -------------------------------------------------------------------------------- 1 | import mimetypes 2 | from abc import ABC, abstractmethod 3 | from pathlib import Path 4 | from typing import Iterator, List, Union 5 | 6 | from langchain_core.document_loaders import Blob 7 | from langchain_core.documents import Document 8 | 9 | from protollm.raw_data_processing.docs_parsers.parsers.entities import DocType 10 | 11 | 12 | class BaseParser(ABC): 13 | """ 14 | Abstract interface for parsers. 15 | 16 | A parser provides a way to parse raw data into one or more documents. 17 | """ 18 | 19 | @abstractmethod 20 | def lazy_parse(self, blob: Blob) -> Iterator[Document]: 21 | """Lazy parsing interface. 22 | 23 | Subclasses are required to implement this method. 24 | 25 | Args: 26 | blob: representation of raw data from file 27 | 28 | Returns: 29 | Generator of documents 30 | """ 31 | 32 | def parse(self, blob: Blob) -> List[Document]: 33 | """Eagerly parse raw data into a document or documents. 34 | 35 | This is a convenience method for interactive development environment. 36 | 37 | Production applications should favor the lazy_parse method instead. 38 | 39 | Subclasses should generally not over-ride this parse method. 
40 | 41 | Args: 42 | blob: representation of raw data from file 43 | 44 | Returns: 45 | List of documents 46 | """ 47 | return list(self.lazy_parse(blob)) 48 | 49 | @staticmethod 50 | def get_doc_type(file: Union[str, Path]) -> DocType: 51 | mimetype = mimetypes.guess_type(file)[0] 52 | if mimetype is None: 53 | mimetype = Path(file).suffix.replace(".", "") 54 | 55 | match mimetype: 56 | case "application/pdf" | "pdf": 57 | return DocType.pdf 58 | case ( 59 | "application/vnd.openxmlformats-officedocument.wordprocessingml.document" 60 | | "docx" 61 | ): 62 | return DocType.docx 63 | case "application/msword" | "doc": 64 | return DocType.doc 65 | case "application/vnd.oasis.opendocument.text" | "odt": 66 | return DocType.odt 67 | case "application/rtf" | "rtf": 68 | return DocType.rtf 69 | case ( 70 | "application/zip" 71 | | "application/x-zip-compressed" 72 | | "multipart/x-zip" 73 | | "zip" 74 | ): 75 | return DocType.zip 76 | case _: # TODO: add txt, xlsx, pptx support 77 | return DocType.unsupported 78 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/converting/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from protollm.raw_data_processing.docs_parsers.parsers.converting.converted_file import ( 4 | converted_file as __converted_file, 5 | ) 6 | from protollm.raw_data_processing.docs_parsers.parsers.entities import ConvertingDocType 7 | 8 | converted_file_to_docx = partial( 9 | __converted_file, target_doc_type=ConvertingDocType.docx 10 | ) 11 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/converting/converted_file.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from tempfile import TemporaryDirectory, NamedTemporaryFile 3 | from typing import BinaryIO, Generator, Union, Optional 4 | 5 | from protollm.raw_data_processing.docs_parsers.parsers.converting.converting import _convert_with_soffice 6 | from protollm.raw_data_processing.docs_parsers.parsers.entities import ConvertingDocType 7 | 8 | 9 | @contextmanager 10 | def converted_file( 11 | stream, 12 | target_doc_type: Union[str, ConvertingDocType] = ConvertingDocType.docx, 13 | timeout: Optional[int] = None, 14 | ) -> Generator[BinaryIO, None, None]: 15 | if target_doc_type not in ConvertingDocType.__members__: 16 | raise ValueError("Invalid target document type") 17 | target_doc_type = f"{target_doc_type}" 18 | 19 | with TemporaryDirectory() as tmp_dir: 20 | tmp_file = NamedTemporaryFile(delete=False, dir=tmp_dir) 21 | tmp_file.write(stream.read()) 22 | tmp_file.close() 23 | tmp_file_path = tmp_file.name 24 | 25 | _convert_with_soffice( 26 | filename=tmp_file_path, 27 | output_directory=tmp_dir, 28 | target_doc_type=target_doc_type, 29 | timeout=timeout, 30 | ) 31 | converted_file_path = ".".join((tmp_file_path, target_doc_type)) 32 | 33 | with open(converted_file_path, "rb") as f: 34 | yield f 35 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/converting/converting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | from protollm.raw_data_processing.docs_parsers.utils.exceptions 
import ConvertingError 7 | 8 | 9 | def _convert_with_soffice( 10 | filename: Union[Path, str], 11 | output_directory: Union[Path, str], 12 | target_doc_type: str = "docx", 13 | timeout: int = None, 14 | ): 15 | """ 16 | Converts a file to a target format using the libreoffice CLI. 17 | """ 18 | command = [ 19 | "soffice", 20 | "--headless", 21 | "--convert-to", 22 | target_doc_type, 23 | "--outdir", 24 | output_directory, 25 | str(filename), 26 | ] 27 | expected_path = Path( 28 | output_directory, ".".join((Path(filename).stem, target_doc_type)) 29 | ) 30 | try: 31 | result = subprocess.run( 32 | command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout 33 | ) 34 | try: 35 | error = result.stderr.decode().strip() 36 | except UnicodeError: 37 | error = "*** Error message cannot be displayed ***" 38 | if error and not os.path.isfile(expected_path): 39 | raise ConvertingError( 40 | f"Could not convert file to {target_doc_type}\n{error}" 41 | ) 42 | except subprocess.TimeoutExpired: 43 | raise ConvertingError( 44 | f"Converting file to {target_doc_type} hadn't terminated after {timeout} seconds" 45 | ) from None 46 | except FileNotFoundError: 47 | raise ConvertingError( 48 | "soffice command was not found. Please install libreoffice on your system and try again." 49 | ) from None 50 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/converting/exceptions.py: -------------------------------------------------------------------------------- 1 | class ConvertingError(Exception): 2 | def __init__(self, message): 3 | super().__init__(message) 4 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/entities.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class DocType(str, Enum): 5 | docx = "docx" 6 | doc = "doc" 7 | odt = "odt" 8 | rtf = "rtf" 9 | pdf = "pdf" 10 | zip = "zip" 11 | unsupported = "unsupported" # TODO: add txt, xlsx, pptx support 12 | 13 | 14 | class ConvertingDocType(str, Enum): 15 | docx = "docx" # TODO: add xlsx, pptx support 16 | 17 | 18 | class ParsingScheme(str, Enum): 19 | paragraphs = "paragraphs" 20 | lines = "lines" 21 | chapters = "chapters" 22 | full = "full" 23 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/pdf/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.raw_data_processing.docs_parsers.parsers.pdf.pdf_parser import PDFParser 2 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/utilities.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | # HEADINGS 4 | 5 | HEADING_KEYWORDS = [ 6 | "предисловие", 7 | "содержание", 8 | "оглавление", 9 | "введение", 10 | "лист согласований", 11 | ] 12 | CONTENTS_KEYWORDS = ["содержание", "оглавление", "лист согласований"] 13 | HEADING_STOP_LIST = ["утверждаю"] 14 | 15 | FOOTER_KEYWORDS = ["документ создан в электронной форме", "страница создана"] 16 | 17 | 18 | # BULLETS 19 | 20 | UNICODE_BULLETS = [ 21 | "\u0095", 22 | "\u2022", 23 | "\u2023", 24 | "\u2043", 25 | "\u3164", 26 | "\u204C", 27 | "\u204D", 28 | "\u2219", 29 | "\u25CB", 30 | "\u25CF", 31 | "\u25D8", 32 | "\u25E6", 33 | "\u2619", 34 | "\u2765", 
35 | "\u2767", 36 | "\u29BE", 37 | "\u29BF", 38 | "\u002D", 39 | "", 40 | r"\*", 41 | "\x95", 42 | "·", 43 | ] 44 | 45 | BULLETS_PATTERN = "|".join(UNICODE_BULLETS) 46 | UNICODE_BULLETS_RE = re.compile(f"(?:{BULLETS_PATTERN})(?!{BULLETS_PATTERN})") 47 | 48 | 49 | def is_bulleted_text(text: str) -> bool: 50 | """Checks to see if the section of text is part of a bulleted list.""" 51 | return UNICODE_BULLETS_RE.match(text.strip()) is not None 52 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/word_doc/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.raw_data_processing.docs_parsers.parsers.word_doc.word_doc_parser import WordDocumentParser 2 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/word_doc/docx_parsing_config.py: -------------------------------------------------------------------------------- 1 | from docx.document import Document 2 | 3 | from protollm.raw_data_processing.docs_parsers.parsers.word_doc.xml.utilities import ( 4 | _get_omml2mml_transformation, 5 | _get_mml2tex_transformation, 6 | ) 7 | 8 | 9 | class DocxParsingConfig: 10 | def __init__( 11 | self, 12 | document: Document, 13 | extract_images: bool = False, 14 | parse_formulas: bool = False, 15 | ): 16 | self.__rels = document.part.rels 17 | self.__extract_images = extract_images 18 | self.__parse_formulas = parse_formulas 19 | self.__omml2mml = None 20 | self.__mml2tex = None 21 | 22 | @property 23 | def extract_images(self): 24 | return self.__extract_images 25 | 26 | @property 27 | def parse_formulas(self): 28 | return self.__parse_formulas 29 | 30 | @property 31 | def document_relationships(self): 32 | return self.__rels 33 | 34 | @property 35 | def omml2mml_transformation(self): 36 | if self.__omml2mml is None and self.__parse_formulas: 37 | self.__omml2mml = _get_omml2mml_transformation() 38 | 39 | return self.__omml2mml 40 | 41 | @property 42 | def mml2tex_transformation(self): 43 | if self.__mml2tex is None and self.__parse_formulas: 44 | self.__mml2tex = _get_mml2tex_transformation() 45 | 46 | return self.__mml2tex 47 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/word_doc/xml/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.raw_data_processing.docs_parsers.parsers.word_doc.xml.xml_processing import ( 2 | process_paragraph_body, 3 | ) 4 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/word_doc/xml/utilities.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from lxml import etree 4 | 5 | 6 | def _get_omml2mml_transformation() -> etree.XSLT: 7 | omml2mml_file = Path(Path(__file__).parent, "xsl", "omml2mml", "OMML2MML.XSL") 8 | omml2mml = etree.XSLT(etree.parse(omml2mml_file)) 9 | return omml2mml 10 | 11 | 12 | def _get_mml2tex_transformation() -> etree.XSLT: 13 | mml2tex_file = Path(Path(__file__).parent, "xsl", "mml2tex", "mmltex.xsl") 14 | mml2tex = etree.XSLT(etree.parse(mml2tex_file)) 15 | 16 | return mml2tex 17 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/word_doc/xml/xml_tag.py: 
-------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | _namespace_mapping = { 5 | "w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main", 6 | "m": "http://schemas.openxmlformats.org/officeDocument/2006/math", 7 | "a": "http://schemas.openxmlformats.org/drawingml/2006/main", 8 | "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", 9 | } 10 | 11 | 12 | def _get_xml_tag_name(tag: str, namespace: str) -> str: 13 | namespace = _namespace_mapping[namespace] 14 | return f"{{{namespace}}}{tag}" 15 | 16 | 17 | class XMLTag(str, Enum): 18 | raw = _get_xml_tag_name("r", "w") 19 | text = _get_xml_tag_name("t", "w") 20 | image = _get_xml_tag_name("drawing", "w") 21 | blip = _get_xml_tag_name("blip", "a") 22 | embed = _get_xml_tag_name("embed", "r") 23 | math = _get_xml_tag_name("oMath", "m") 24 | math_paragraph = _get_xml_tag_name("oMathPara", "m") 25 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/parsers/word_doc/xml/xsl/mml2tex/mmltex.xsl: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | 9 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | $ 26 | 27 | $ 28 | 29 | 30 | 31 | \[ 32 | 33 | \] 34 | 35 | 36 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/raw_data_processing/docs_parsers/utils/__init__.py -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | class ConvertingError(Exception): 2 | def __init__(self, message): 3 | super().__init__(message) 4 | 5 | 6 | class EncodingError(Exception): 7 | def __init__(self, message): 8 | super().__init__(message) 9 | 10 | 11 | class NoTextLayerError(Exception): 12 | def __init__(self, message): 13 | super().__init__(message) 14 | 15 | 16 | class ChaptersExtractingFailedWarning(Warning): 17 | def __init__(self, message): 18 | super().__init__(message) 19 | 20 | 21 | class ParseImageWarning(Warning): 22 | def __init__(self, message): 23 | super().__init__(message) 24 | 25 | 26 | class TitleExtractingWarning(Warning): 27 | def __init__(self, message): 28 | super().__init__(message) 29 | 30 | 31 | class PageNumbersExtractingWarning(Warning): 32 | def __init__(self, message): 33 | super().__init__(message) 34 | 35 | 36 | class FooterExtractingWarning(Warning): 37 | def __init__(self, message): 38 | super().__init__(message) 39 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from contextlib import contextmanager 4 | from typing import Optional, Generator 5 | 6 | 7 | class ParsingLogger: 8 | def __init__(self, silent_errors: bool = False, name: Optional[str] = None): 9 | name = name or __name__ 10 | self._logger = logging.getLogger(name) 11 | self._logs: dict[str, list[str]] = {} 12 | self._silent_errors = silent_errors 13 | 14 | @property 15 | def logger(self): 
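# (read-only accessor exposing the underlying logging.Logger, e.g. for attaching handlers)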
16 | return self._logger 17 | 18 | @property 19 | def logs(self): 20 | return self._logs 21 | 22 | def info(self, msg: str, *args, **kwargs): 23 | self._logger.info(msg, *args, **kwargs) 24 | 25 | def warning(self, msg: str, *args, **kwargs): 26 | self._logger.warning(msg, *args, **kwargs) 27 | 28 | def error(self, msg: str, *args, **kwargs): 29 | self._logger.error(msg, *args, **kwargs) 30 | 31 | def critical(self, msg: str, *args, **kwargs): 32 | self._logger.critical(msg, *args, **kwargs) 33 | 34 | def exception(self, msg: str, *args, **kwargs): 35 | self._logger.exception(msg, *args, **kwargs) 36 | 37 | def debug(self, msg: str, *args, **kwargs): 38 | self._logger.debug(msg, *args, **kwargs) 39 | 40 | @contextmanager 41 | def parsing_info_handler(self, file_name: str) -> Generator[None, None, None]: 42 | try: 43 | try: 44 | with warnings.catch_warnings(record=True) as record: 45 | warnings.simplefilter("default") 46 | yield 47 | finally: 48 | for warning in record: 49 | warn_msg = f"{warning.message} (in {file_name})" 50 | self.warning(warn_msg) 51 | except Exception as error: 52 | self.logs[file_name] = self.logs.get(file_name, []) 53 | self.logs[file_name].append(str(error)) 54 | err_msg = f"{error} (in {file_name})" 55 | self.error(err_msg) 56 | if not self._silent_errors: 57 | raise 58 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_parsers/utils/utilities.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union 3 | 4 | import chardet 5 | from ftfy import is_bad 6 | 7 | 8 | def is_bad_encoding(lines: list[str]) -> bool: 9 | count_bad = sum([_is_bad(line) for line in lines], start=0) 10 | proportion_bad = count_bad / max(1, len(lines)) 11 | return proportion_bad >= 0.5 12 | 13 | 14 | def _is_bad(text: str) -> bool: 15 | if is_bad(text): 16 | return True 17 | # This verification works for russian docs !!! 18 | try: 19 | text.encode("sloppy-windows-1252") 20 | except UnicodeEncodeError: 21 | return False 22 | return True 23 | 24 | 25 | def correct_path_encoding(path: Union[str, Path]) -> str: 26 | path = Path(path) 27 | path = Path(*[fix_zip_path(part) for part in path.parts]) 28 | return str(path) 29 | 30 | 31 | def fix_zip_path(path: str) -> str: 32 | try: 33 | string_bytes = path.encode("437") 34 | guessed_encoding = chardet.detect(string_bytes)["encoding"] or "cp1252" 35 | path = string_bytes.decode(guessed_encoding, "replace") 36 | except: 37 | pass 38 | return path 39 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm.raw_data_processing.docs_transformers.chunk_merger import ChunkMerger 2 | from protollm.raw_data_processing.docs_transformers.recursive_splitter import RecursiveSplitter 3 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_transformers/recursive_splitter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Iterable, List, Optional 3 | 4 | from langchain_text_splitters.character import RecursiveCharacterTextSplitter 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class RecursiveSplitter(RecursiveCharacterTextSplitter): 10 | """Splitting text by the given sequence of splitters.
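Chunk overlap is forced to 0, and text already shorter than chunk_size is returned as a single chunk.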
11 | """ 12 | 13 | def __init__( 14 | self, 15 | separators: Optional[List[str]] = None, 16 | keep_separator: bool = True, 17 | is_separator_regex: bool = False, 18 | **kwargs: Any, 19 | ) -> None: 20 | """Create a new TextSplitter.""" 21 | kwargs["chunk_overlap"] = 0 22 | super().__init__(keep_separator=keep_separator, **kwargs) 23 | self._separators = separators or [ 24 | "\n\n", 25 | "\n", 26 | ". ", 27 | ";", 28 | ", ", 29 | ".", 30 | ",", 31 | " ", 32 | "", 33 | ] 34 | self._is_separator_regex = is_separator_regex 35 | 36 | def split_text(self, text: str) -> List[str]: 37 | if self._length_function(text) < self._chunk_size: 38 | return [text] 39 | return self._split_text(text, self._separators) 40 | 41 | def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: 42 | # We now want to combine these smaller pieces into medium size 43 | # chunks to send to the LLM. 44 | docs = [] 45 | current_doc = [] 46 | for text in splits: 47 | merged_text = self._join_docs([*current_doc, text], separator) 48 | if ( 49 | merged_text is None 50 | or self._length_function(merged_text) <= self._chunk_size 51 | ): 52 | current_doc.append(text) 53 | continue 54 | doc = self._join_docs(current_doc, separator) 55 | if doc is None and self._length_function(text) > self._chunk_size: 56 | logger.warning( 57 | f"Created a chunk, which is longer than the specified {self._chunk_size}" 58 | ) 59 | else: 60 | docs.append(doc) 61 | current_doc = [text] 62 | doc = self._join_docs(current_doc, separator) 63 | if doc is not None: 64 | docs.append(doc) 65 | return docs 66 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_transformers/sentences_splitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Iterable, Optional 3 | 4 | from langchain_text_splitters import TextSplitter 5 | 6 | from protollm.raw_data_processing.docs_transformers.utilities import fix_list_dots_separators 7 | 8 | 9 | class SentencesSplitter(TextSplitter): 10 | def __init__(self, separators: Optional[Iterable[str]] = None, **kwargs: Any): 11 | super().__init__(**kwargs) 12 | self._separators = separators or [r"\. (? list[str]: 15 | sentences_lst = [ 16 | x.strip() for x in re.split(" | ".join(self._separators), text) 17 | ] 18 | return fix_list_dots_separators(sentences_lst) 19 | -------------------------------------------------------------------------------- /protollm/raw_data_processing/docs_transformers/utilities.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def fix_list_dots_separators(sentences: list[str]) -> list[str]: 5 | """ 6 | Takes list of sentences and combines those of them that are list items 7 | that were incorrectly separated due to the use of a dot separator. 
8 | Returns updated list of sentences 9 | """ 10 | fixed_sentences_lst = [] 11 | sentence_parts_lst = [] 12 | i = 0 13 | while i < len(sentences): 14 | chunk = sentences[i].strip() 15 | if len(chunk) > 0: 16 | # it means that the dot was used to separate list elements, and we should join such sentences 17 | if not chunk[0].isupper() and not chunk[0].isdigit(): 18 | sentence_parts_lst.append(chunk) 19 | else: 20 | if len(sentence_parts_lst) != 0: 21 | fixed_sentences_lst.append("; ".join(sentence_parts_lst)) 22 | sentence_parts_lst = [chunk] 23 | i += 1 24 | 25 | if len(sentence_parts_lst) != 0: 26 | fixed_sentences_lst.append("; ".join(sentence_parts_lst)) 27 | return fixed_sentences_lst 28 | -------------------------------------------------------------------------------- /protollm/templates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/templates/__init__.py -------------------------------------------------------------------------------- /protollm/templates/code_templates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/templates/code_templates/__init__.py -------------------------------------------------------------------------------- /protollm/templates/code_templates/agent_template.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | 4 | from langchain.agents import create_structured_chat_agent, AgentExecutor 5 | from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder 6 | from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate 7 | import pandas as pd 8 | from langchain_core.tools import tool 9 | 10 | from protollm.agents.llama31_agents.llama31_agent import Llama31ChatModel 11 | from examples.real_world.chemical_pipeline.validate_tools import validate_decompose, compute_metrics, validate_conductor 12 | 13 | # Create the system and human prompts 14 | system_prompt = ''' 15 | Respond to the human as helpfully and accurately as possible. 
16 | ''' 17 | 18 | human_prompt = '''{input} 19 | {agent_scratchpad} 20 | (Reminder to respond in a JSON blob no matter what)''' 21 | 22 | system_message = SystemMessagePromptTemplate.from_template( 23 | system_prompt, 24 | input_variables=["tools", "tool_names"], 25 | ) 26 | human_message = HumanMessagePromptTemplate.from_template( 27 | human_prompt, 28 | input_variables=["input", "agent_scratchpad"], 29 | ) 30 | 31 | # Create the ChatPromptTemplate 32 | prompt = ChatPromptTemplate.from_messages( 33 | [ 34 | system_message, 35 | MessagesPlaceholder(variable_name="chat_history", optional=True), 36 | human_message, 37 | ] 38 | ) 39 | 40 | # Initialize the custom LLM 41 | llm = Llama31ChatModel( 42 | api_key='API_KEY_HERE', 43 | base_url="URL_HERE", 44 | model="meta-llama/llama-3.1-70b-instruct", 45 | temperature=0.5, 46 | max_tokens=5000 47 | ) 48 | 49 | # Create the structured chat agent 50 | agent = create_structured_chat_agent( 51 | llm=llm, 52 | prompt=prompt, 53 | stop_sequence=True 54 | ) 55 | 56 | # Create the AgentExecutor 57 | agent_executor = AgentExecutor.from_agent_and_tools( 58 | verbose=True, 59 | return_intermediate_steps=True, 60 | output_keys=["output"], 61 | early_stopping_method="generate" 62 | ) 63 | 64 | response = agent_executor.invoke({ 65 | "input": 'test' 66 | }) 67 | -------------------------------------------------------------------------------- /protollm/templates/prompt_templates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm/templates/prompt_templates/__init__.py -------------------------------------------------------------------------------- /protollm/templates/prompt_templates/assistant_prompt_templates.py: -------------------------------------------------------------------------------- 1 | DOMAIN_SPECIFIC_ASSTISTANT=""" 2 | You are a smart AI assistant. You have high expertise in the field 3 | [...]. 4 | Answer the question following the rules below. 5 | 1. Before giving an answer to the user question, provide an 6 | explanation. Mark the answer with keyword 'ANSWER', and 7 | explanation with 'EXPLANATION'. Both answer and explanation must be 8 | in the English language. 9 | 2. If the question is about complaints, answer about at least 5 10 | complaints topics. 11 | 3. Answer should be five sentences maximum. 12 | 4. In answers you must use only the English language. 13 | """ 14 | 15 | -------------------------------------------------------------------------------- /protollm/templates/prompt_templates/metric_evalutation_prompts.py: -------------------------------------------------------------------------------- 1 | GEVAL_PROMPT = """criteria=( 2 | 1. Correctness and Relevance: 3 | - Compare the actual response against the expected response. 4 | Determine the extent to which the actual response 5 | captures the key elements and concepts of the expected response. 6 | - Assign higher scores to actual responses that accurately reflect 7 | the core information of the expected response, even if only partial. 8 | 2. Numerical Accuracy and Interpretation: 9 | - Pay particular attention to any numerical values present 10 | in the expected response. Verify that these values are 11 | correctly included in the actual response and accurately 12 | interpreted within the context. 13 | - Ensure that units of measurement, scales, and numerical 14 | relationships are preserved and correctly conveyed. 15 | 3. 
Allowance for Partial Information: 16 | - Do not heavily penalize the actual response for incompleteness 17 | if it covers significant aspects of the expected response. 18 | Prioritize the correctness of provided information over 19 | total completeness. 20 | 4. Handling of Extraneous Information: 21 | - While additional information not present in the expected response 22 | should not necessarily reduce score, 23 | ensure that such additions do not introduce inaccuracies 24 | or deviate from the context of the expected response.)""" -------------------------------------------------------------------------------- /protollm/templates/prompt_templates/synthetic_data_prompts.py: -------------------------------------------------------------------------------- 1 | 2 | # Пример заполнения шаблона - objects = 'характеристики различных зданий в Санкт-Петербурге', 3 | # user_field = 'ГИС-системы', N=10 4 | 5 | synthetic_prompt_template_basic = '''В файле даны {objects}. 6 | Проанализируй файл и сформулируй {N} вопросов, которые мог бы задать 7 | пользователь {user_field}.''' 8 | 9 | # data_desc = 'зданий в Санкт-Петербурге', stucture - JSON-описание, role - 'урбанист', 10 | # work = 'анализ различных данных о городе для получения некоторых практических выводов' 11 | synthetic_prompt_template_advanced = ''' 12 | Датасет с данными {data_desc} имеет следующую структуру: {structures}. 13 | Представь, что ты - {role}. В круг твоих задач входит {work}. 14 | Сформулируй {N} вопросов, которые мог бы задать {role}, основываясь на 15 | данных датасета.''' 16 | 17 | prompt_template_enrichment = """Теперь представь, что у тебя есть другие данные о {data_desc}, и 18 | можно ориентироваться не только на структуру датасета. Сгенерируй ещё {N} вопросов по теме {data_desc}.""" 19 | 20 | prompt_template_file_only = """В файле даны характеристики различных {data_desc}. 21 | Проанализируй файл и сформулируй {N} вопросов, которые мог бы 22 | задать пользователь {user_field}. Ответь на них на основе информации из файла.""" 23 | 24 | prompt_template_expert_role = """Представь, что ты - {role}, 25 | который занимается {work}. Вот примеры вопросов, которые задают специалисты в этой области: 26 | {questions}. Сформулируй ещё {N} вопросов. Они могут быть как широкими и обобщёнными, так и узкоспециализированными. 
27 | """ 28 | -------------------------------------------------------------------------------- /protollm/tools/web_tools.py: -------------------------------------------------------------------------------- 1 | from langchain_community.tools.tavily_search import TavilySearchResults 2 | from langchain_community.tools import DuckDuckGoSearchResults 3 | from langchain.tools.render import render_text_description 4 | import os 5 | 6 | 7 | tavily_tool = None 8 | 9 | if os.getenv('TAVILY_API_KEY') is not None: 10 | tavily_tool = TavilySearchResults(max_results=5) 11 | web_tools = [tavily_tool] 12 | else: 13 | web_tools = [DuckDuckGoSearchResults()] 14 | 15 | web_tools_rendered = render_text_description(web_tools).replace('duckduckgo_results_json', 'DuckDuckGoSearchResults') 16 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/.env.example: -------------------------------------------------------------------------------- 1 | POSTGRES_PORT=55530 2 | POSTGRES_HOST=0.0.0.0 3 | POSTGRES_USER=user 4 | POSTGRES_PASSWORD=password 5 | POSTGRES_DB=agents 6 | REDIS_HOST=0.0.0.0 7 | REDIS_PORT=55531 -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/README.md: -------------------------------------------------------------------------------- 1 | # llm-agents-api 2 | 3 | This tool library provides a simple API for creating and running LLM agents and build multi-agent systems. 4 | SDK allows agents creation and management, and provides interface to integrate those using router agent and ensemble agent. 5 | Tool also provides an Entrypoint object which starts uvicorn server for running the API. 6 | 7 | ## Installation 8 | 9 | ```bash 10 | poetry install 11 | ``` 12 | 13 | ## Run example 14 | 1) Copy .env.example to .env and set variables 15 | 2) Run 16 | ```bash 17 | docker compose up -d 18 | ``` 19 | 3) Run example 20 | ```bash 21 | poetry run python examples/main.py 22 | ``` 23 | 24 | 4) Open browser and go to http://:/docs (by default http://0.0.0.0:8080/docs): 25 | - `/` - Agents listing 26 | - `/agents/` - Agent details 27 | 28 | 5) You can use any Websocket client to connect to the agents and send messages. For example, you can use Postman. 29 | - `ws://:/agents` - Websocket connection to the agent 30 | Example query: 31 | ```json 32 | { 33 | "dialogue_id": "2fb8e8f0-bd05-5eca-8e4d-376ede293e53", 34 | "agent_id": "3208a446-d847-45a8-a724-159fa87334b9", 35 | "chat_history":[], 36 | "query":"Какие целевые показатели госпрограмм по образованию и защите окружающей среды?", 37 | "run_params": {} 38 | } 39 | ``` 40 | - `ws://:/router` - Websocket connection to the router agent 41 | Example query: 42 | ```json 43 | { 44 | "dialogue_id": "2fb8e8f0-bd05-5eca-8e4d-376ede293e53", 45 | "chat_history":[], 46 | "query":"Какие целевые показатели госпрограмм по образованию и защите окружающей среды?" 47 | } 48 | ``` 49 | - `ws://:/ensemble` - Websocket connection to the ensemble agent 50 | Example query: 51 | ```json 52 | { 53 | "dialogue_id": "2fb8e8f0-bd05-5eca-8e4d-376ede293e53", 54 | "chat_history":[], 55 | "query":"Какие целевые показатели госпрограмм по образованию и защите окружающей среды?" 56 | } 57 | ``` 58 | 59 | After the message is sent, you will receive a stream of messages from the agent. 
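As an alternative to Postman, the event stream can be read with a small Python client. The sketch below is illustrative only: it assumes the default host and port from the examples above, the `websockets` package from the project dependencies, and the event field names used by the SDK event models (`name`, `result`, `is_eos`); adjust it to your deployment.

```python
# Minimal sketch of a websocket client for the /router endpoint (illustrative, not part of the library).
import asyncio
import json

import websockets


async def stream_router_answer() -> None:
    uri = "ws://0.0.0.0:8080/router"  # assumed default host/port
    query = {
        "dialogue_id": "2fb8e8f0-bd05-5eca-8e4d-376ede293e53",
        "chat_history": [],
        "query": "Какие целевые показатели госпрограмм по образованию и защите окружающей среды?",
    }
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps(query))
        # Events arrive as one JSON object per message until the final event.
        async for raw in ws:
            event = json.loads(raw)
            print(event.get("name"), event.get("result"))
            if event.get("is_eos"):
                break


if __name__ == "__main__":
    asyncio.run(stream_router_answer())
```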
60 | 61 | ## Run tests 62 | 1) Copy .env.test.example to .env.test and set variables 63 | 2) Run 64 | ```bash 65 | poetry run pytest 66 | ``` 67 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-agents-api/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | db: 3 | image: postgres:16 4 | ports: 5 | - 55530:5432 6 | env_file: 7 | - .env 8 | volumes: 9 | - postgres-data:/var/lib/postgresql/data 10 | 11 | cache: 12 | image: redis:7.2-alpine 13 | volumes: 14 | - redis-data:/data 15 | ports: 16 | - 55531:6379 17 | 18 | vectorstore: 19 | image: chromadb/chroma:0.5.11 20 | ports: 21 | - "57777:8000" 22 | volumes: 23 | - vectorstore-data:/chroma/chroma 24 | 25 | volumes: 26 | postgres-data: 27 | redis-data: 28 | vectorstore-data: 29 | external: true -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/examples/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from protollm_agents.entrypoint import Entrypoint 4 | 5 | 6 | logging.basicConfig( 7 | level=logging.DEBUG, 8 | format="%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", 9 | handlers=[logging.StreamHandler()], 10 | ) 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | if __name__ == "__main__": 15 | epoint = Entrypoint( 16 | config_path="./examples/admin-config.yml", 17 | ) 18 | epoint.run() 19 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/examples/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-agents-api/examples/pipelines/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/examples/pipelines/mock_background_agent.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator 2 | from protollm_agents.sdk import BackgroundAgent 3 | 4 | from protollm_agents.sdk.base import Event 5 | from protollm_agents.sdk.events import TextEvent, MultiDictEvent, StatusEvent, DictEvent 6 | from protollm_agents.sdk.context import Context 7 | 8 | class SummaryAgent(BackgroundAgent): 9 | 10 | class Arguments(BackgroundAgent.Arguments): 11 | a: int 12 | 13 | async def invoke(self, *args, **kwargs): 14 | pass 15 | 16 | 17 | async def stream(self, ctx: Context, arguments: Arguments, documents: list[str]) -> AsyncGenerator[Event, None]: 18 | mock_events = [ 19 | StatusEvent(agent_id=self.agent_id, result="Planning", name="Planning event"), 20 | MultiDictEvent(agent_id=self.agent_id, result=[{"id": "1", "page_content": "Mock document 1"}, {"id": "2", "page_content": "Mock document 2"}]), 21 | TextEvent(agent_id=self.agent_id, result="Mock", is_eos=False), 22 | TextEvent(agent_id=self.agent_id, result="Mock answer", is_eos=False), 23 | TextEvent(agent_id=self.agent_id, result="Mock answer to 
question", is_eos=False), 24 | DictEvent(agent_id=self.agent_id, result=dict(a=1), is_eos=True), 25 | ] 26 | for event in mock_events: 27 | yield event 28 | 29 | 30 | async def to_tool(self, *args, **kwargs): 31 | pass 32 | 33 | 34 | async def to_runnable(self): 35 | pass 36 | 37 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-agents-api/protollm_agents/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-agents-api/protollm_agents/api/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-agents-api/protollm_agents/models/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/models/models.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import uuid 3 | 4 | from sqlalchemy import TIMESTAMP, Column, String, types 5 | from sqlalchemy.dialects.postgresql import ARRAY, JSONB 6 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column 7 | from sqlalchemy.sql import func 8 | from typing_extensions import Annotated 9 | 10 | str30 = Annotated[str, 30] 11 | str50 = Annotated[str, 50] 12 | 13 | class Base(DeclarativeBase): 14 | type_annotation_map = { 15 | datetime.datetime: TIMESTAMP(timezone=True), 16 | str30: String(length=30), 17 | str50: String(length=50), 18 | } 19 | 20 | class Agent(Base): 21 | __tablename__ = "agents" 22 | 23 | agent_id: Mapped[uuid.UUID] = mapped_column(types.Uuid, primary_key=True) 24 | name: Mapped[str] 25 | description: Mapped[str] 26 | arguments = Column(JSONB, nullable=True) 27 | created: Mapped[datetime.datetime] = mapped_column(server_default=func.CURRENT_TIMESTAMP(), nullable=False) 28 | module_name: Mapped[str] 29 | class_name: Mapped[str] 30 | agent_type = Column(ARRAY(String), nullable=True) 31 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/models/requests.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | from annotated_types import Len 3 | from pydantic import BaseModel, Field, field_serializer 4 | import uuid 5 | 6 | 7 | class HistoryRecord(BaseModel): 8 | query: str = Field(..., description="The query to send to the agent") 9 | response: str = Field(..., description="The response from the agent") 10 | 11 | class RouterSocketQuery(BaseModel): 12 | dialogue_id: uuid.UUID = Field(..., description="Unique dialogue id") 13 | query: str = Field(..., description="The query to send to the agent") 14 | chat_history: Annotated[list[HistoryRecord], Len(max_length=50)] = Field( 15 | default_factory=list, 
16 | description="Previous queries in the dialogue" 17 | ) 18 | 19 | @field_serializer('dialogue_id') 20 | def serialize_id_dialogue(self, dialogue_id: uuid.UUID, _info): 21 | return str(dialogue_id) 22 | 23 | @property 24 | def history_as_tuple_list(self) -> list[tuple[str, str]]: 25 | chat_history = [] 26 | for record in self.chat_history: 27 | chat_history.extend([("human", record.query), ("ai", record.response)]) 28 | return chat_history 29 | 30 | class AgentSocketQuery(RouterSocketQuery): 31 | agent_id: uuid.UUID = Field(..., description="The ID of the agent to query") 32 | run_params: dict = Field(default_factory=dict, description="The parameters to run the agent with") 33 | 34 | @field_serializer('agent_id') 35 | def serialize_id_agent(self, agent_id: uuid.UUID, _info): 36 | return str(agent_id) 37 | 38 | class AgentRESTQuery(BaseModel): 39 | agent_id: uuid.UUID = Field(..., description="The ID of the agent to query") 40 | run_params: dict = Field(default_factory=dict, description="The parameters to run the agent with") 41 | documents: list[str] = Field(default_factory=list, description="The documents to process") 42 | 43 | @field_serializer('agent_id') 44 | def serialize_id_agent(self, agent_id: uuid.UUID, _info): 45 | return str(agent_id) 46 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/models/responses.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import Any 3 | from pydantic import BaseModel, Field, field_serializer 4 | 5 | class AgentResponse(BaseModel): 6 | agent_id: uuid.UUID = Field(..., description="The agent id") 7 | name: str = Field(..., description="The agent name") 8 | description: str = Field(..., description="The agent description") 9 | arguments: dict = Field(..., description="The agent arguments") 10 | 11 | 12 | class ErrorMessage(BaseModel): 13 | type: str = "error" 14 | detail: str = "Error" 15 | 16 | 17 | class TaskIdResponse(BaseModel): 18 | task_id: uuid.UUID = Field(..., description="The task id") 19 | 20 | @field_serializer('task_id') 21 | def serialize_task_id(self, task_id: uuid.UUID, _info: Any) -> str: 22 | return str(task_id) 23 | 24 | 25 | class TaskResponse(TaskIdResponse): 26 | events: list[dict] = Field(..., description="The events") 27 | 28 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/models/schemas.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import Literal 3 | from pydantic import BaseModel, Field, ConfigDict 4 | import uuid 5 | 6 | from protollm_agents.sdk.base import BaseAgent 7 | 8 | 9 | class Agent(BaseModel): 10 | model_config = ConfigDict(from_attributes=True) 11 | agent_id: uuid.UUID = Field(..., description='ID of the agent') 12 | name: str = Field(..., description='Name of the agent') 13 | description: str = Field(..., description='Description of the agent') 14 | arguments: dict = Field( 15 | ..., 16 | description='Agent parameters' 17 | ) 18 | module_name: str = Field(..., description='Module name of the agent') 19 | class_name: str = Field(..., description='Class name of the agent') 20 | agent_type: list[Literal['background', 'streaming', 'router', 'ensemble']] = Field(default_factory=list, description='Type of the agent') 21 | 22 | @property 23 | def agent_instance(self) -> BaseAgent: 24 | module = 
importlib.import_module(self.module_name) 25 | agent_cls = getattr(module, self.class_name) 26 | return agent_cls( 27 | agent_id=self.agent_id, 28 | name=self.name, 29 | description=self.description, 30 | arguments=agent_cls.get_arguments_class().model_validate(self.arguments), 31 | module_name=self.module_name, 32 | class_name=self.class_name 33 | ) 34 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/sdk/__init__.py: -------------------------------------------------------------------------------- 1 | from .agents import StreamingAgent, BackgroundAgent 2 | from .events import ( 3 | TextEvent, 4 | MultiTextEvent, 5 | DictEvent, 6 | MultiDictEvent, 7 | StatusEvent, 8 | ErrorEvent, 9 | ) 10 | from .models import CompletionModel, ChatModel, TokenizerModel, EmbeddingAPIModel 11 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/sdk/context.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from langchain_core.tools import Tool 4 | from protollm_agents.sdk.base import ModelType, VectorStoreType, AgentType 5 | from protollm_agents.sdk.models import TokenizerModel, CompletionModel, ChatModel, MultimodalModel, EmbeddingAPIModel 6 | 7 | @dataclass 8 | class Context: 9 | embeddings: dict[str, EmbeddingAPIModel] = field(default_factory=dict) 10 | llms: dict[str, CompletionModel | ChatModel] = field(default_factory=dict) 11 | multimodals: dict[str, MultimodalModel] = field(default_factory=dict) 12 | tokenizers: dict[str, TokenizerModel] = field(default_factory=dict) 13 | vector_stores: dict[str, VectorStoreType] = field(default_factory=dict) 14 | agents: dict[str, AgentType] = field(default_factory=dict) 15 | tools: dict[str, Tool] = field(default_factory=dict) 16 | 17 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/sdk/events.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from pydantic import Field 4 | 5 | from protollm_agents.sdk.base import Event 6 | 7 | 8 | class EventType(str, enum.Enum): 9 | text = "text" 10 | multitext = "multitext" 11 | dict = "dict" 12 | multidict = "multidict" 13 | status = "status" 14 | error = "error" 15 | 16 | class TextEvent(Event): 17 | name: str = EventType.text 18 | result: str = Field(..., description='Text of the event') 19 | 20 | class MultiTextEvent(Event): 21 | name: str = EventType.multitext 22 | result: list[str] = Field(..., description='Sequence of texts of the event') 23 | 24 | class DictEvent(Event): 25 | name: str = EventType.dict 26 | result: dict = Field(..., description='Arbitrary event') 27 | 28 | class MultiDictEvent(Event): 29 | name: str = EventType.multidict 30 | result: list[dict] = Field(..., description='Sequence of dictionaries of the event') 31 | 32 | class StatusEvent(Event): 33 | name: str = EventType.status 34 | result: str = Field(..., description='Status of the event') 35 | 36 | class ErrorEvent(Event): 37 | name: str = EventType.error 38 | result: str = Field(..., description='Error of the event') 39 | 40 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/sdk/vector_stores.py: -------------------------------------------------------------------------------- 1 | from 
langchain_community.vectorstores import VectorStore, Chroma 2 | import chromadb 3 | from pydantic import Field 4 | 5 | from langchain_community.vectorstores import VectorStore 6 | 7 | from protollm_agents.sdk.base import BaseVectorStore 8 | 9 | 10 | class ChromaVectorStore(BaseVectorStore): 11 | host: str = Field(..., description="Host of the vector store") 12 | port: int = Field(..., description="Port of the vector store") 13 | collection_name: str = Field(..., description="Collection name of the vector store") 14 | 15 | def to_vector_store(self) -> VectorStore: 16 | if self.embeddings_model is None: 17 | raise ValueError("Embeddings model is not initialized") 18 | return Chroma( 19 | client=chromadb.HttpClient(host=self.host, port=self.port), 20 | collection_name=self.collection_name, 21 | embedding_function=self.embeddings_model 22 | ) 23 | 24 | 25 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/services/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-agents-api/protollm_agents/services/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/services/cache_client.py: -------------------------------------------------------------------------------- 1 | from redis.asyncio import Redis 2 | 3 | cache: Redis | None = None 4 | 5 | 6 | async def get_cache(): 7 | return cache 8 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/services/db_client.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | from typing import Literal 3 | import uuid 4 | from sqlalchemy import select 5 | from sqlalchemy.ext.asyncio import ( 6 | AsyncSession, 7 | AsyncEngine, 8 | ) 9 | from sqlalchemy.exc import NoResultFound 10 | 11 | from protollm_agents.models.schemas import Agent 12 | from protollm_agents.models import models as db_models 13 | from protollm_agents.services.exceptions import AgentNotFound 14 | 15 | 16 | 17 | class DBClient: 18 | def __init__(self, session: AsyncSession): 19 | self.session = session 20 | 21 | def _get_session(self, session: AsyncSession | None = None) -> AsyncSession: 22 | return session or self.session 23 | 24 | @asynccontextmanager 25 | async def execute(self, session: AsyncSession | None = None): 26 | session = self._get_session(session) 27 | try: 28 | yield session 29 | await session.commit() 30 | except Exception: 31 | await session.rollback() 32 | raise 33 | 34 | async def create_agent(self, agent: Agent, session: AsyncSession | None = None) -> Agent: 35 | agent_db = db_models.Agent(**agent.model_dump()) 36 | async with self.execute(session) as sess_: 37 | sess_.add(agent_db) 38 | await sess_.flush() 39 | await sess_.refresh(agent_db) 40 | return Agent.model_validate(agent_db) 41 | 42 | 43 | async def get_agent(self, agent_id: uuid.UUID, session: AsyncSession | None = None) -> Agent: 44 | async with self.execute(session) as sess_: 45 | query = select(db_models.Agent).where(db_models.Agent.agent_id == agent_id) 46 | result = await sess_.execute(query) 47 | try: 48 | agent_db = result.scalar_one() 49 | except NoResultFound: 50 | raise AgentNotFound(f"Agent with id {agent_id} not found") 51 | return 
Agent.model_validate(agent_db) 52 | 53 | 54 | async def get_agents(self, agent_type: Literal['background', 'streaming', 'all', 'router', 'ensemble'] = 'all', session: AsyncSession | None = None) -> list[Agent]: 55 | async with self.execute(session) as sess_: 56 | query = select(db_models.Agent) 57 | if agent_type != 'all': 58 | query = query.where(db_models.Agent.agent_type.contains([agent_type])) 59 | result = await sess_.execute(query) 60 | return [Agent.model_validate(row) for row in result.scalars()] 61 | 62 | 63 | 64 | engine: AsyncEngine | None = None 65 | SessionLocal: AsyncSession | None = None 66 | 67 | 68 | async def init_db(): 69 | async with engine.begin() as conn: 70 | await conn.run_sync(db_models.Base.metadata.drop_all) 71 | await conn.run_sync(db_models.Base.metadata.create_all) 72 | 73 | 74 | 75 | async def get_db_client() -> DBClient: 76 | async with SessionLocal() as session: 77 | yield DBClient(session) 78 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/services/exceptions.py: -------------------------------------------------------------------------------- 1 | 2 | class AgentNotFound(Exception): 3 | pass 4 | 5 | 6 | 7 | class TaskNotFound(Exception): 8 | pass 9 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/services/socket_connector.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | from fastapi import WebSocket 5 | from pydantic import BaseModel 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class SocketConnector: 12 | def __init__(self): 13 | logger.info(f'Socket connector initialized') 14 | self.websocket = None 15 | 16 | async def connect(self, websocket: WebSocket): 17 | logger.info('Connecting to socket') 18 | self.websocket = websocket 19 | await self.websocket.accept() 20 | 21 | async def disconnect(self): 22 | self.websocket = None 23 | 24 | async def send_encoded_model(self, message: BaseModel): 25 | await self.websocket.send_json(message.model_dump()) 26 | 27 | async def get_socket_connector() -> SocketConnector: 28 | return SocketConnector() 29 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/protollm_agents/services/storage.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from protollm_agents.sdk.models import CompletionModel, ChatModel, MultimodalModel, EmbeddingAPIModel, TokenizerModel 3 | from protollm_agents.sdk.vector_stores import BaseVectorStore 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class Storage(BaseModel): 8 | llm_models: dict[str, CompletionModel | ChatModel] = Field(...) 9 | multimodal_models: dict[str, MultimodalModel] = Field(...) 10 | embeddings: dict[str, EmbeddingAPIModel] = Field(...) 11 | tokenizers: dict[str, TokenizerModel] = Field(...) 12 | vector_store_clients: dict[str, BaseVectorStore] = Field(...) 
13 | 14 | 15 | storage: Storage | None = None 16 | 17 | def get_storage() -> Storage: 18 | if storage is None: 19 | raise ValueError("Storage is not initialized") 20 | return storage 21 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "llm-agents-api" 3 | version = "0.1.0" 4 | description = "" 5 | authors = [] 6 | readme = "README.md" 7 | packages = [{include = "protollm_agents"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | fastapi = "^0.115.6" 12 | sqlalchemy = "^2.0.36" 13 | asyncpg = "^0.30.0" 14 | redis = {extras = ["asyncio"], version = "^5.2.1"} 15 | pydantic = "^2.10.3" 16 | pydantic-settings = "^2.6.1" 17 | uvicorn = "^0.32.1" 18 | gunicorn = "^23.0.0" 19 | orjson = "^3.10.12" 20 | websockets = "^14.1" 21 | chromadb = "0.5.11" 22 | langchain = "^0.3.12" 23 | pyyaml = "^6.0.2" 24 | langchain-openai = "^0.2.12" 25 | langchain-community = "^0.3.12" 26 | elasticsearch = {extras = ["async"], version = "^8.17.0"} 27 | transformers = "^4.47.1" 28 | torch = "2.2.2" 29 | langgraph = "^0.2.60" 30 | langfuse = "^2.57.0" 31 | 32 | 33 | [tool.poetry.group.dev.dependencies] 34 | pytest = "^8.3.4" 35 | pytest-docker = "^3.1.1" 36 | psycopg2-binary = "^2.9.10" 37 | pytest-dotenv = "^0.5.2" 38 | 39 | [tool.pytest.ini_options] 40 | testpaths = ["tests"] 41 | pythonpath = ["."] 42 | env_files = ["tests/.env.test"] 43 | filterwarnings = [ 44 | "ignore::DeprecationWarning" 45 | ] 46 | log_cli = true 47 | log_cli_level = "INFO" 48 | log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" 49 | log_cli_date_format = "%Y-%m-%d %H:%M:%S %Z" 50 | 51 | [build-system] 52 | requires = ["poetry-core"] 53 | build-backend = "poetry.core.masonry.api" 54 | -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/tests/.env.test.example: -------------------------------------------------------------------------------- 1 | PYTEST_COMPLETION_MODEL_HOST=localhost 2 | PYTEST_COMPLETION_MODEL_PORT=8001 3 | PYTEST_COMPLETION_MODEL_API_KEY=token 4 | PYTEST_CHAT_MODEL_HOST=localhost 5 | PYTEST_CHAT_MODEL_PORT=8001 6 | PYTEST_CHAT_MODEL_API_KEY=token 7 | PYTEST_EMBEDDING_MODEL_HOST=localhost 8 | PYTEST_EMBEDDING_MODEL_PORT=58891 9 | PYTEST_EMBEDDING_MODEL_API_KEY=token -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/tests/config.test.yml: -------------------------------------------------------------------------------- 1 | app_port: 8080 2 | app_host: 0.0.0.0 3 | redis_host: 0.0.0.0 4 | redis_port: 55531 5 | postgres_host: 0.0.0.0 6 | postgres_port: 55530 7 | postgres_user: test 8 | postgres_password: test 9 | postgres_db: test 10 | agents: 11 | - name: rag_environment 12 | agent_id: 07dd7db1-075a-4391-b537-6fbca4d5a5f6 13 | description: Поиск по базе докуметов компании 14 | class_path: examples.pipelines.rag_agent.RAGAgent 15 | default_params: 16 | max_input_tokens: 6144 17 | max_chat_history_token_length: 24576 18 | retrieving_top_k: 2 19 | generator_context_top_k: 2 20 | include_original_question_in_queries: True 21 | planner_model_name: planner_llm 22 | generator_model_name: generator_llm 23 | tokenizer_name: qwen_2.5 24 | store_name: chroma_environment 25 | - name: rag_education 26 | agent_id: 3208a446-d847-45a8-a724-159fa87334b9 27 | description: Поиск по документам 
политики развития образования 28 | class_path: examples.pipelines.rag_agent.RAGAgent 29 | default_params: 30 | max_input_tokens: 6144 31 | max_chat_history_token_length: 24576 32 | retrieving_top_k: 2 33 | generator_context_top_k: 2 34 | include_original_question_in_queries: True 35 | planner_model_name: planner_llm 36 | generator_model_name: generator_llm 37 | tokenizer_name: qwen_2.5 38 | store_name: chroma_education 39 | - name: rag_union 40 | agent_id: 2fb8e8f0-bd05-5eca-8e4d-376ede293e52 41 | description: Поиск по документам политики развития образования 42 | class_path: examples.pipelines.rag_agent.RAGAgent 43 | default_params: 44 | max_input_tokens: 6144 45 | max_chat_history_token_length: 24576 46 | retrieving_top_k: 2 47 | generator_context_top_k: 2 48 | include_original_question_in_queries: True 49 | planner_model_name: planner_llm 50 | generator_model_name: generator_llm 51 | tokenizer_name: qwen_2.5 52 | store_name: chroma_union 53 | vector_stores: 54 | - type: chroma 55 | params: 56 | name: chroma_education 57 | description: vector_store_description 58 | host: 0.0.0.0 59 | port: 57776 60 | collection_name: test_collection_1 61 | embeddings_model_name: e5-mistral-7b-instruct 62 | - type: chroma 63 | params: 64 | name: chroma_environment 65 | description: vector_store_description 66 | host: 0.0.0.0 67 | port: 57776 68 | collection_name: test_collection_2 69 | embeddings_model_name: e5-mistral-7b-instruct 70 | - type: chroma 71 | params: 72 | name: chroma_union 73 | description: vector_store_description 74 | host: 0.0.0.0 75 | port: 57776 76 | collection_name: test_collection_3 77 | embeddings_model_name: e5-mistral-7b-instruct -------------------------------------------------------------------------------- /protollm_tools/llm-agents-api/tests/docker-compose.test.yml: -------------------------------------------------------------------------------- 1 | services: 2 | db: 3 | image: postgres:16 4 | ports: 5 | - "55530:5432" 6 | environment: 7 | - POSTGRES_PASSWORD=test 8 | - POSTGRES_DB=test 9 | - POSTGRES_USER=test 10 | 11 | cache: 12 | image: redis:7.2-alpine 13 | ports: 14 | - "55531:6379" 15 | 16 | vectorstore: 17 | image: chromadb/chroma:0.5.11 18 | ports: 19 | - "57776:8000" 20 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nginx/unit:1.28.0-python3.10 2 | 3 | COPY protollm_api /app/protollm_api 4 | COPY requirements.txt /app 5 | COPY unit_config.json /docker-entrypoint.d/config.json 6 | WORKDIR /app 7 | 8 | RUN pip install --upgrade pip 9 | RUN pip install -r requirements.txt 10 | RUN apt install git 11 | 12 | COPY unit_config.json /docker-entrypoint.d/config.json -------------------------------------------------------------------------------- /protollm_tools/llm-api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-api/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-api/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.5' 2 | 3 | services: 4 | api: 5 | container_name: llm-api 6 | image: llm-api-image 7 | build: 8 | context: . 
9 | dockerfile: Dockerfile 10 | ports: 11 | - ${API_PORT}:6672 12 | env_file: 13 | - .env 14 | volumes: 15 | - ./unit_config.json:/docker-entrypoint.d/unit_config.json 16 | networks: 17 | - llm_wrap_network 18 | 19 | rabbitmq: 20 | image: "rabbitmq:3-management" 21 | ports: 22 | - ${RABBIT_MQ_PORT}:5672 23 | - ${WEB_RABBIT_MQ}:15672 24 | env_file: 25 | - .env 26 | volumes: 27 | - rabbitmq_data:/var/lib/rabbitmq 28 | networks: 29 | - llm_wrap_network 30 | 31 | redis: 32 | image: "redis:alpine" 33 | ports: 34 | - ${REDIS_PORT}:6379 35 | volumes: 36 | - redis_data:/var/lib/data 37 | networks: 38 | - llm_wrap_network 39 | 40 | networks: 41 | llm_wrap_network: 42 | name: llm_wrap_network 43 | driver: bridge 44 | 45 | volumes: 46 | rabbitmq_data: 47 | redis_data: 48 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/protollm_api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-api/protollm_api/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-api/protollm_api/backend/endpoints.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from fastapi import APIRouter 3 | from protollm_api.backend.broker import send_task, get_result 4 | from protollm_api.config import Config 5 | from protollm_sdk.models.job_context_models import ( 6 | PromptModel, ResponseModel, ChatCompletionModel, 7 | PromptTransactionModel, ChatCompletionTransactionModel, 8 | PromptTypes 9 | ) 10 | from protollm_sdk.object_interface.redis_wrapper import RedisWrapper 11 | from protollm_sdk.object_interface.rabbit_mq_wrapper import RabbitMQWrapper 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | def get_router(config: Config) -> APIRouter: 17 | router = APIRouter( 18 | prefix="", 19 | tags=["root"], 20 | responses={404: {"description": "Not found"}}, 21 | ) 22 | 23 | redis_db = RedisWrapper(config.redis_host, config.redis_port) 24 | rabbitmq = RabbitMQWrapper(config.rabbit_host, config.rabbit_port, config.rabbit_login, config.rabbit_password) 25 | 26 | @router.post('/generate', response_model=ResponseModel) 27 | async def generate(prompt_data: PromptModel, queue_name: str = config.queue_name): 28 | transaction_model = ChatCompletionTransactionModel( 29 | prompt=ChatCompletionModel.from_prompt_model(prompt_data), 30 | prompt_type=PromptTypes.CHAT_COMPLETION.value 31 | ) 32 | await send_task(config, queue_name, transaction_model, rabbitmq) 33 | logger.info(f"Task {prompt_data.job_id} was sent to LLM.") 34 | return await get_result(config, prompt_data.job_id, redis_db) 35 | 36 | @router.post('/chat_completion', response_model=ResponseModel) 37 | async def chat_completion(prompt_data: ChatCompletionModel, queue_name: str = config.queue_name): 38 | transaction_model = ChatCompletionTransactionModel( 39 | prompt=prompt_data, 40 | prompt_type=PromptTypes.CHAT_COMPLETION.value 41 | ) 42 | await send_task(config, queue_name, transaction_model, rabbitmq) 43 | logger.info(f"Task {prompt_data.job_id} was sent to LLM.") 44 | return await get_result(config, prompt_data.job_id, redis_db) 45 | 46 | return router 47 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/protollm_api/backend/main.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from protollm_api.config import Config 3 | from protollm_api.backend.endpoints import get_router 4 | 5 | app = FastAPI() 6 | 7 | config = Config.read_from_env() 8 | 9 | app.include_router(get_router(config)) 10 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/protollm_api/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class Config: 5 | def __init__( 6 | self, 7 | inner_llm_url: str = "localhost:8670", 8 | redis_host: str = "localhost", 9 | redis_port: int = 6379, 10 | redis_prefix: str = "llm-api", 11 | rabbit_host: str = "localhost", 12 | rabbit_port: int = 5672, 13 | rabbit_login: str = "admin", 14 | rabbit_password: str = "admin", 15 | queue_name: str = "llm-api-queue", 16 | queue_durable: bool=True, 17 | base_priority: int=1 18 | ): 19 | self.inner_lln_url = inner_llm_url 20 | self.redis_host = redis_host 21 | self.redis_port = redis_port 22 | self.redis_prefix = redis_prefix 23 | self.rabbit_host = rabbit_host 24 | self.rabbit_port = rabbit_port 25 | self.rabbit_login = rabbit_login 26 | self.rabbit_password = rabbit_password 27 | self.queue_name = queue_name 28 | self.queue_durable = queue_durable 29 | self.base_priority = base_priority 30 | 31 | @classmethod 32 | def read_from_env(cls) -> 'Config': 33 | return Config( 34 | os.environ.get("INNER_LLM_URL"), 35 | os.environ.get("REDIS_HOST"), 36 | int(os.environ.get("REDIS_PORT")), 37 | os.environ.get("REDIS_PREFIX"), 38 | os.environ.get("RABBIT_MQ_HOST"), 39 | int(os.environ.get("RABBIT_MQ_PORT")), 40 | os.environ.get("RABBIT_MQ_LOGIN"), 41 | os.environ.get("RABBIT_MQ_PASSWORD"), 42 | os.environ.get("QUEUE_NAME"), 43 | bool(os.environ.get("QUEUE_DURABLE")), 44 | int(os.getenv("BASE_PRIORITY")) 45 | ) 46 | 47 | @classmethod 48 | def read_from_env_file(cls, path: str) -> 'Config': 49 | with open(path) as file: 50 | lines = file.readlines() 51 | env_vars = {} 52 | for line in lines: 53 | key, value = line.split("=") 54 | env_vars[key] = value 55 | return Config( 56 | env_vars.get("INNER_LLM_URL"), 57 | env_vars.get("REDIS_HOST"), 58 | int(env_vars.get("REDIS_PORT")), 59 | env_vars.get("REDIS_PREFIX"), 60 | env_vars.get("RABBIT_MQ_HOST"), 61 | int(env_vars.get("RABBIT_MQ_PORT")), 62 | env_vars.get("RABBIT_MQ_LOGIN"), 63 | env_vars.get("RABBIT_MQ_PASSWORD"), 64 | env_vars.get("QUEUE_NAME"), 65 | bool(env_vars.get("QUEUE_DURABLE")), 66 | int(env_vars.get("BASE_PRIORITY")) 67 | ) 68 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "protollm-api" 3 | version = "1.0.5" 4 | description = "" 5 | authors = ["aimclub"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | redis = "^5.0.5" 11 | pika = "^1.3.2" 12 | pydantic = "^2.7.4" 13 | flower = "2.0.1" 14 | 15 | protollm_sdk = "^1.1.6" 16 | 17 | [tool.poetry.group.dev.dependencies] 18 | pytest = "^8.2.2" 19 | pytest-asyncio = "^0.24.0" 20 | uvicorn = "^0.34.0" 21 | 22 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/tests/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 
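For local development without the nginx-unit image, the application can also be served directly with uvicorn (listed above as a dev dependency). A minimal sketch, assuming the environment variables read by `Config.read_from_env()` are set before the app module is imported; all values below are placeholders, not project defaults:

```python
# Sketch: running the llm-api FastAPI app with uvicorn (illustrative values only).
import os

# protollm_api.backend.main calls Config.read_from_env() at import time,
# so the variables must be present before uvicorn imports the module.
for key, value in {
    "INNER_LLM_URL": "localhost:8670",
    "REDIS_HOST": "localhost",
    "REDIS_PORT": "6379",
    "REDIS_PREFIX": "llm-api",
    "RABBIT_MQ_HOST": "localhost",
    "RABBIT_MQ_PORT": "5672",
    "RABBIT_MQ_LOGIN": "admin",
    "RABBIT_MQ_PASSWORD": "admin",
    "QUEUE_NAME": "llm-api-queue",
    "QUEUE_DURABLE": "1",
    "BASE_PRIORITY": "1",
}.items():
    os.environ.setdefault(key, value)

import uvicorn

if __name__ == "__main__":
    # 6672 mirrors the listener port in unit_config.json; any free port works.
    uvicorn.run("protollm_api.backend.main:app", host="0.0.0.0", port=6672)
```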
-------------------------------------------------------------------------------- /protollm_tools/llm-api/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from protollm_api.config import Config 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def test_local_config(): 8 | return Config() 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def test_real_config(): 13 | return Config.read_from_env() 14 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/tests/integration/test_local_Redis.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import uuid 3 | 4 | import pytest 5 | from protollm_sdk.models.job_context_models import ResponseModel 6 | from protollm_sdk.object_interface.redis_wrapper import RedisWrapper 7 | 8 | from protollm_api.backend.broker import get_result 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def redis_client(test_local_config): 13 | assert test_local_config.redis_host == "localhost" 14 | client = RedisWrapper(test_local_config.redis_host, test_local_config.redis_port) 15 | return client 16 | 17 | 18 | @pytest.mark.asyncio 19 | async def test_get_result_from_local_redis(test_local_config, redis_client): 20 | task_id = str(uuid.uuid4()) 21 | redis_key = f"{test_local_config.redis_prefix}:{task_id}" 22 | expected_content = {'content': 'success'} 23 | 24 | result_task = asyncio.create_task(get_result(test_local_config, task_id, redis_client)) 25 | 26 | await asyncio.sleep(1) 27 | 28 | redis_client.save_item(redis_key, expected_content) 29 | 30 | response = await result_task 31 | 32 | assert isinstance(response, ResponseModel) 33 | assert response.content == 'success' 34 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/tests/integration/test_with_llm.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | from protollm_sdk.models.job_context_models import ( 5 | ChatCompletionModel, PromptMeta, ChatCompletionUnit, 6 | ChatCompletionTransactionModel, PromptTypes 7 | ) 8 | from protollm_sdk.models.job_context_models import ResponseModel 9 | from protollm_sdk.object_interface.redis_wrapper import RedisWrapper 10 | 11 | from protollm_api.backend.broker import get_result 12 | from protollm_api.backend.broker import send_task 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def redis_client(test_real_config): 17 | assert test_real_config.redis_host != "localhost" 18 | client = RedisWrapper(test_real_config.redis_host, test_real_config.redis_port) 19 | return client 20 | 21 | 22 | @pytest.mark.asyncio 23 | @pytest.mark.skip(reason="Test waits infinitely in GitHub Action") 24 | async def test_task_in_queue(test_real_config, redis_client): 25 | task_id = str(uuid.uuid4()) 26 | prompt = ChatCompletionModel( 27 | job_id=task_id, 28 | meta=PromptMeta(), 29 | messages=[ChatCompletionUnit(role="user", content="Сколько будет 2+2*2?")] 30 | ) 31 | transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value) 32 | 33 | await send_task(test_real_config, test_real_config.queue_name, transaction) 34 | 35 | result = await get_result(test_real_config, task_id, redis_client) 36 | 37 | assert isinstance(result, ResponseModel) 38 | assert result.content != "" 39 | 
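# --- Usage sketch (not part of the test suite) ---------------------------------
# How a client might call the HTTP endpoint exposed by protollm_api/backend/endpoints.py.
# The host/port and the use of `requests` are assumptions for illustration only
# (the nginx-unit config listens on port 6672 inside the container).
import uuid

import requests
from protollm_sdk.models.job_context_models import (
    ChatCompletionModel, ChatCompletionUnit, PromptMeta
)


def ask_llm_api(question: str, base_url: str = "http://localhost:6672") -> str:
    prompt = ChatCompletionModel(
        job_id=str(uuid.uuid4()),
        meta=PromptMeta(),
        messages=[ChatCompletionUnit(role="user", content=question)],
    )
    # /chat_completion publishes the task to RabbitMQ and then waits for the
    # worker's answer to appear in Redis, returning a ResponseModel.
    response = requests.post(
        f"{base_url}/chat_completion",
        data=prompt.model_dump_json(),
        headers={"Content-Type": "application/json"},
        timeout=120,
    )
    response.raise_for_status()
    return response.json()["content"]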
-------------------------------------------------------------------------------- /protollm_tools/llm-api/tests/unit/test_brocker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | from unittest.mock import AsyncMock, patch, MagicMock, ANY 4 | 5 | import pytest 6 | from protollm_sdk.models.job_context_models import ResponseModel, ChatCompletionTransactionModel, ChatCompletionModel, \ 7 | PromptMeta, ChatCompletionUnit, PromptTypes 8 | 9 | from protollm_api.backend.broker import send_task, get_result 10 | 11 | 12 | @pytest.mark.asyncio 13 | async def test_send_task(test_local_config): 14 | prompt = ChatCompletionModel( 15 | job_id=str(uuid.uuid4()), 16 | priority=None, 17 | meta=PromptMeta(), 18 | messages=[ChatCompletionUnit(role="user", content="test request")] 19 | ) 20 | transaction = ChatCompletionTransactionModel(prompt=prompt, prompt_type=PromptTypes.CHAT_COMPLETION.value) 21 | 22 | mock_rabbit = MagicMock() 23 | 24 | await send_task(test_local_config, test_local_config.queue_name, transaction, mock_rabbit) 25 | 26 | mock_rabbit.publish_message.assert_called_once_with(test_local_config.queue_name, ANY, True) 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_get_result(test_local_config): 31 | redis_mock = MagicMock() 32 | redis_mock.wait_item = AsyncMock(return_value=json.dumps({"content": "return test success"}).encode()) 33 | task_id = str(uuid.uuid4()) 34 | 35 | response = await get_result(test_local_config, task_id, redis_mock) 36 | 37 | redis_mock.wait_item.assert_called_once_with(f"{test_local_config.redis_prefix}:{task_id}", timeout=90) 38 | assert response == ResponseModel(content="return test success") 39 | 40 | 41 | @pytest.mark.asyncio 42 | async def test_get_result_with_exception(test_local_config): 43 | redis_mock = MagicMock() 44 | redis_mock.wait_item = AsyncMock( 45 | side_effect=[Exception("Redis error"), json.dumps({"content": "return test success"}).encode()]) 46 | task_id = str(uuid.uuid4()) 47 | 48 | response = await get_result(test_local_config, task_id, redis_mock) 49 | 50 | assert redis_mock.wait_item.call_count == 2 51 | assert response == ResponseModel(content="return test success") 52 | -------------------------------------------------------------------------------- /protollm_tools/llm-api/unit_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "listeners": { 3 | "*:6672": { 4 | "pass": "applications/backend" 5 | } 6 | }, 7 | "applications": { 8 | "backend": { 9 | "type": "python", 10 | "path": ".", 11 | "module": "protollm_api.backend.main", 12 | "callable": "app" 13 | } 14 | } 15 | } -------------------------------------------------------------------------------- /protollm_tools/llm-worker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | RUN apt update && apt upgrade -y 4 | RUN apt install pip -y 5 | RUN apt install python3 -y 6 | RUN apt install git -y 7 | 8 | RUN apt install build-essential -y 9 | RUN apt install gcc-11 g++-11 -y 10 | 11 | COPY requirements.txt ./requirements.txt 12 | 13 | RUN pip install -r requirements.txt 14 | 15 | COPY . . 
16 | 17 | CMD python3 main.py 18 | -------------------------------------------------------------------------------- /protollm_tools/llm-worker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-worker/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-worker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | llm: 5 | container_name: llm-worker 6 | image: llm-core:latest 7 | # runtime: nvidia 8 | # deploy: 9 | # resources: 10 | # limits: 11 | # # cpus: 5 12 | # memory: 100G 13 | build: 14 | context: . 15 | dockerfile: Dockerfile 16 | env_file: .env 17 | # volumes: 18 | # - :/data 19 | ports: 20 | - ${LLM_WORKER_PORT}:8672 21 | networks: 22 | - llm_wrap_network 23 | restart: unless-stopped 24 | 25 | networks: 26 | llm_wrap_network: 27 | name: llm_wrap_network 28 | driver: bridge 29 | -------------------------------------------------------------------------------- /protollm_tools/llm-worker/main.py: -------------------------------------------------------------------------------- 1 | from protollm_worker.models.open_api_llm import OpenAPILLM 2 | from protollm_worker.services.broker import LLMWrap 3 | from protollm_worker.config import Config 4 | 5 | if __name__ == "__main__": 6 | config = Config.read_from_env() 7 | llm_model = OpenAPILLM(model_url="https://api.vsegpt.ru/v1", 8 | token="sk-or-vv-7fcc4ab944ca013feb7608fb7c0f001e5c12c32abf66233aad414183b4191a79", 9 | default_model="openai/gpt-4o-2024-08-06", 10 | # app_tag="test_protollm_worker" 11 | ) 12 | # llm_model = VllMModel(model_path=config.model_path, 13 | # tensor_parallel_size=config.tensor_parallel_size, 14 | # gpu_memory_utilisation=config.gpu_memory_utilisation, 15 | # tokens_len=config.token_len) 16 | llm_wrap = LLMWrap(llm_model=llm_model, 17 | config= config) 18 | llm_wrap.start_connection() 19 | -------------------------------------------------------------------------------- /protollm_tools/llm-worker/protollm_worker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-worker/protollm_worker/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-worker/protollm_worker/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-worker/protollm_worker/models/__init__.py -------------------------------------------------------------------------------- /protollm_tools/llm-worker/protollm_worker/models/cpp_models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from llama_cpp import Llama 4 | from protollm_sdk.models.job_context_models import PromptModel, ChatCompletionModel, PromptTransactionModel, \ 5 | ChatCompletionTransactionModel, PromptTypes 6 | 7 | from protollm_worker.models.base import BaseLLM, LocalLLM 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class CppModel(LocalLLM ,BaseLLM): 14 | def __init__(self, model_path, n_ctx=8192): 15 | 
15 |         super().__init__(model_path)
16 | 
17 |         self.model = Llama(
18 |             model_path=model_path,
19 |             n_ctx=n_ctx * 2,
20 |             verbose=True,
21 |             n_gpu_layers=-1,
22 |         )
23 |         self.handlers = {
24 |             PromptTypes.SINGLE_GENERATION.value: self.generate,
25 |             PromptTypes.CHAT_COMPLETION.value: self.create_completion,
26 |         }
27 | 
28 |     def __call__(self, transaction: PromptTransactionModel | ChatCompletionTransactionModel):
29 |         prompt_type: PromptTypes = transaction.prompt_type
30 |         func = self.handlers[prompt_type]
31 |         return func(transaction.prompt, **transaction.prompt.meta.model_dump())
32 | 
33 |     def generate(
34 |             self,
35 |             prompt: PromptModel,
36 |             tokens_limit=None,
37 |             temperature=None,
38 |             repeat_penalty=1.1,
39 |             stop_words=None,
40 |             **kwargs
41 |     ):
42 |         if temperature is None:
43 |             temperature = 0.5
44 |         if stop_words is None:
45 |             stop_words = []
46 |         logger.info(f"starting single-prompt generation for {prompt.content} with temperature {temperature}")
47 |         generated_text = self.model(
48 |             prompt.content,
49 |             temperature=temperature,
50 |             repeat_penalty=repeat_penalty,
51 |             max_tokens=tokens_limit,
52 |             stop=stop_words,
53 | 
54 |         )
55 |         response = generated_text['choices'][0]['text']
56 |         return response
57 | 
58 |     def create_completion(
59 |             self,
60 |             prompt: ChatCompletionModel,
61 |             tokens_limit=None,
62 |             temperature=None,
63 |             repeat_penalty=1.1,
64 |             stop_words=None,
65 |             **kwargs
66 |     ):
67 |         if temperature is None:
68 |             temperature = 0.5
69 |         if stop_words is None:
70 |             stop_words = []
71 |         logger.info(f"starting chat-completion generation for {prompt.messages}")
72 |         messages = prompt.model_dump()['messages']
73 |         response = self.model.create_chat_completion(
74 |             messages=messages,
75 |             temperature=temperature,
76 |             repeat_penalty=repeat_penalty,
77 |             max_tokens=tokens_limit,
78 |             stop=stop_words,
79 |         )
80 |         return response['choices'][0]['message']['content']
81 | 
--------------------------------------------------------------------------------
/protollm_tools/llm-worker/protollm_worker/models/hf_models.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-worker/protollm_worker/models/hf_models.py
--------------------------------------------------------------------------------
/protollm_tools/llm-worker/protollm_worker/services/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/llm-worker/protollm_worker/services/__init__.py
--------------------------------------------------------------------------------
/protollm_tools/llm-worker/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "protollm-worker"
3 | version = "1.0.5"
4 | description = ""
5 | authors = ["aimclub"]
6 | readme = "README.md"
7 | 
8 | [tool.poetry.dependencies]
9 | python = "^3.10"
10 | redis = "^5.0.5"
11 | pika = "^1.3.2"
12 | pydantic = "^2.7.4"
13 | protollm_sdk = "^1.1.6"
14 | 
15 | [tool.poetry.group.llama-cpp.dependencies]
16 | llama-cpp-python = "^0.3.3"
--------------------------------------------------------------------------------
/protollm_tools/sdk/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10
2 | 
3 | COPY . .
4 | 
5 | WORKDIR .
6 | 7 | RUN pip install -r requirements.txt 8 | -------------------------------------------------------------------------------- /protollm_tools/sdk/README.md: -------------------------------------------------------------------------------- 1 | SDK module -------------------------------------------------------------------------------- /protollm_tools/sdk/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.5' 2 | 3 | services: 4 | rabbitmq: 5 | image: "rabbitmq:3-management" 6 | ports: 7 | - ${RABBIT_PORT}:5672 8 | - ${WEB_RABBIT_MQ}:15672 9 | env_file: 10 | - .env 11 | volumes: 12 | - "rabbitmq_data:/var/lib/rabbitmq" 13 | networks: 14 | - llm_wrap_network 15 | 16 | redis: 17 | image: "redis:alpine" 18 | ports: 19 | - ${REDIS_PORT}:6379 20 | volumes: 21 | - redis_data:/var/lib/data 22 | networks: 23 | - llm_wrap_network 24 | 25 | celery_worker: 26 | build: . 27 | depends_on: 28 | - rabbitmq 29 | - redis 30 | networks: 31 | - llm_wrap_network 32 | env_file: 33 | - .env 34 | command: celery -A protollm_sdk.celery.app worker --loglevel=info 35 | 36 | flower: 37 | build: . 38 | ports: 39 | - ${FLOWER_PORT}:7672 40 | depends_on: 41 | - rabbitmq 42 | - celery_worker 43 | networks: 44 | - llm_wrap_network 45 | env_file: 46 | - .env 47 | command: sh -c "sleep 20 && celery -A protollm_sdk.celery.app flower --broker=${CELERY_BROKER_URL} --port=7672" 48 | 49 | server: 50 | image: chromadb/chroma:latest 51 | env_file: 52 | - .env 53 | environment: 54 | - IS_PERSISTENT=TRUE 55 | ports: 56 | - ${VECTOR_PORT}:8000 57 | networks: 58 | - llm_wrap_network 59 | 60 | embedding_server: 61 | image: ${EMBEDDING_IMAGE} 62 | command: --model-id ${ST_MODEL} --revision ${ST_MODEL_REVISION} 63 | ports: 64 | - ${EMBEDER_PORT}:80 65 | networks: 66 | - llm_wrap_network 67 | 68 | volumes: 69 | rabbitmq_data: 70 | redis_data: 71 | 72 | networks: 73 | llm_wrap_network: 74 | name: llm_wrap_network 75 | driver: bridge 76 | -------------------------------------------------------------------------------- /protollm_tools/sdk/examples/celery.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from protollm_sdk.config import Config 4 | from protollm_sdk.celery.app import task_test 5 | from protollm_sdk.celery.job import TextEmbedderJob, ResultStorageJob, LLMAPIJob, \ 6 | VectorDBJob, OuterLLMAPIJob # , LangchainLLMAPIJob 7 | from protollm_sdk.object_interface import RedisWrapper 8 | 9 | 10 | def embed(): 11 | """An example of using an embedder with celery""" 12 | text_embedder_request = {"job_id": "0", 13 | "inputs": "Какой-то умный текст. 
Или не очень умный.", 14 | "truncate": False} 15 | random_id = uuid.uuid4() 16 | result = task_test.apply_async(args=(TextEmbedderJob.__name__, random_id), kwargs=text_embedder_request) 17 | print(result.get()) 18 | 19 | 20 | def store_results(): 21 | """An example of using a storage with celery""" 22 | random_id = uuid.uuid4() 23 | result_storage = {"job_id": random_id, 24 | "result": {"question": "Очень умный вопрос.", 25 | "answers": "Не очень умный ответ"}} 26 | 27 | result = task_test.apply_async(args=(ResultStorageJob.__name__, random_id), kwargs=result_storage) 28 | print(result.get()) 29 | 30 | 31 | def llm_resp(): 32 | """An example of using a llm with celery""" 33 | meta = {"temperature": 0.5, 34 | "tokens_limit": 10, 35 | "stop_words": None} 36 | llm_request = {"job_id": str(uuid.uuid4()), 37 | "meta": meta, 38 | "content": "Сколько попугаев Какаду в одном метре?"} 39 | result = task_test.apply_async(args=(LLMAPIJob.__name__, llm_request["job_id"]), kwargs=llm_request) 40 | 41 | print(result.get()) 42 | 43 | 44 | async def out_llm_resp(redis_client: RedisWrapper): 45 | """An example of using a outer llm with celery""" 46 | meta = {"temperature": 0.2, 47 | "tokens_limit": 4096, 48 | "stop_words": None} 49 | llm_request = {"job_id": str(uuid.uuid4()), 50 | "meta": meta, 51 | "content": "Монтаж оголовников, Сборка опор/порталов, Подвеска провода, Укладка активного соляного заземления"} 52 | 53 | task_test.apply_async(args=(OuterLLMAPIJob.__name__, llm_request["job_id"]), kwargs=llm_request) 54 | result = await redis_client.wait_item(f"{OuterLLMAPIJob.__name__}:{llm_request['job_id']}", timeout=60) 55 | 56 | def get_dict(key): 57 | rd = RedisWrapper(redis_host=Config.redis_host, 58 | redis_port=Config.redis_port) 59 | resp = rd.get_item(key) 60 | print(resp) 61 | 62 | 63 | def vector_db(): 64 | random_id = uuid.uuid4() 65 | result = task_test.apply_async(args=(VectorDBJob.__name__, random_id), task_id=random_id) 66 | print(result.get()) 67 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/protollm_sdk/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/celery/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from celery import Celery 4 | from kombu.serialization import registry 5 | 6 | from protollm_sdk.celery.config import CeleryConfig 7 | from protollm_sdk.celery.constants import JOB_NAME2CLASS 8 | from protollm_sdk.jobs.job import Job 9 | from protollm_sdk.jobs.utility import construct_job_context 10 | 11 | 12 | def init_celery(celery_config: CeleryConfig) -> Celery: 13 | init_args, init_kwargs = celery_config.init_args 14 | 15 | celery = Celery(*init_args, **init_kwargs) 16 | 17 | if celery_config.conf_update: 18 | celery.conf.update(**celery_config.conf_update) 19 | 20 | if celery_config.formats: 21 | for f in celery_config.formats: 22 | registry.enable(f) 23 | 24 | return celery 25 | 26 | celery_config = CeleryConfig() 27 | celery = init_celery(celery_config) 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | @celery.task(**celery_config.task_kwargs) 34 | def task_test(task_class: str, task_id: str, **kwargs): # noqa 35 | ctx = 
construct_job_context(task_class, abstract_task) 36 | if job_class := JOB_NAME2CLASS.get(task_class): 37 | job = job_class() 38 | else: 39 | msg = f"Error in task '{task_id}'. Unknown job class: '{task_class}'." 40 | logger.error(msg) 41 | raise Exception(msg) 42 | 43 | logger.info(f"Starting task '{task_id}'. Job '{task_class}'.") 44 | 45 | forecast = job.run(task_id=task_id, ctx=ctx, **kwargs) 46 | return forecast 47 | 48 | 49 | @celery.task(**celery_config.task_kwargs) 50 | def abstract_task(task_class: type[Job], task_id: str, **kwargs): 51 | class_name = task_class.__name__ if isinstance(task_class, type) else task_class.__class__.__name__ 52 | ctx = construct_job_context(class_name, abstract_task) 53 | logger.info(f"Starting task '{task_id}'. Job '{class_name}'.") 54 | job_object = task_class() if isinstance(task_class, type) else task_class 55 | job_object.run(task_id=task_id, ctx=ctx, **kwargs) 56 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/celery/constants.py: -------------------------------------------------------------------------------- 1 | from protollm_sdk.celery.job import TextEmbedderJob, LLMAPIJob, ResultStorageJob, VectorDBJob, OuterLLMAPIJob 2 | 3 | JOBS = {TextEmbedderJob, LLMAPIJob, ResultStorageJob, VectorDBJob, OuterLLMAPIJob} 4 | 5 | JOB_NAME2CLASS = {cls.__name__: cls for cls in JOBS} 6 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass(frozen=True) 6 | class Config: 7 | llm_api_host = os.environ.get("LLM_API_HOST", "localhost") 8 | llm_api_port = os.environ.get("LLM_API_PORT", "6672") 9 | 10 | outer_llm_key = os.environ.get("OUTER_LLM_KEY", "sk-or-vv-c49f40fdb086053ec32c6ae2723b8d222cb7767f3b98527e7ae282986e7d33ed") 11 | 12 | redis_host = os.environ.get("REDIS_HOST", "localhost") 13 | redis_port = os.environ.get("REDIS_PORT", "6379") 14 | 15 | rabbit_mq_host = os.environ.get("RABBIT_HOST", "localhost") 16 | rabbit_mq_port = os.environ.get("RABBIT_PORT", "5672") 17 | rabbit_mq_login = os.environ.get("RABBIT_MQ_LOGIN", "admin") 18 | rabbit_mq_password = os.environ.get("RABBIT_MQ_PASSWORD", "admin") 19 | 20 | text_embedder_host = os.environ.get("TEXT_EMB_HOST", "localhost") 21 | text_embedder_port = os.environ.get("TEXT_EMB_PORT", "9942") 22 | 23 | vector_bd_host = os.environ.get("VECTOR_HOST", "localhost") 24 | vector_db_port = os.environ.get("VECTOR_PORT", "9941") 25 | 26 | job_invocation_type = os.environ.get("JOB_INVOCATION_TYPE", "worker") 27 | 28 | celery_queue_name = os.environ.get("CELERY_QUEUE_NAME", "celery") 29 | 30 | @classmethod 31 | def reload_invocation_type(cls): 32 | cls.job_invocation_type = os.environ.get("JOB_INVOCATION_TYPE", "worker") 33 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/jobs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/protollm_sdk/jobs/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/jobs/job.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing 
import TypeVar 3 | 4 | from protollm_sdk.jobs.job_context import JobContext 5 | 6 | 7 | class Job(ABC): 8 | """ 9 | Job interface for integration with outer modules to SDK. 10 | All the required data for job should be passed and parameters should be defined in advance, 11 | this also applies to the run method 12 | """ 13 | 14 | @abstractmethod 15 | def run(self, job_id: str, ctx: JobContext, **kwargs): 16 | """ 17 | Run the job. The job can use a number of functions defined in the module and service functions from ctx. 18 | After that, using the ctx, it saves the result to Redis. 19 | The method should not return any data. If an error occurs, 20 | it is thrown inside the run method via the `raise`, `raise ex` or `raise … from ex` statement. 21 | 22 | :param job_id: job id 23 | :type job_id: str 24 | :param ctx: contextual services 25 | :type ctx: JobContext 26 | :param kwargs: The data and parameters required to complete the task are set via key arguments 27 | :return: None 28 | :raises TypeError: Error 29 | 30 | """ 31 | pass 32 | 33 | 34 | TJob = TypeVar('TJob', bound=Job) 35 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/jobs/job_context.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import logging 4 | 5 | from protollm_sdk.jobs.outer_llm_api import OuterLLMAPI 6 | from protollm_sdk.jobs.job_invoke import JobInvoker 7 | from protollm_sdk.jobs.vector_db import VectorDB 8 | from protollm_sdk.jobs.result_storage import ResultStorage 9 | from protollm_sdk.jobs.text_embedder import TextEmbedder 10 | from protollm_sdk.jobs.llm_api import LLMAPI 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | @dataclass 16 | class JobContext: 17 | """ 18 | The class contains contextual services for executing Job 19 | """ 20 | llm_api: LLMAPI 21 | outer_llm_api: OuterLLMAPI 22 | text_embedder: TextEmbedder 23 | result_storage: ResultStorage 24 | vector_db: VectorDB 25 | job_invoker: JobInvoker 26 | 27 | 28 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/jobs/text_embedder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from urllib.parse import urljoin 3 | 4 | import httpx 5 | 6 | from protollm_sdk.models.job_context_models import TextEmbedderRequest, TextEmbedderResponse, ToEmbed 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class TextEmbedder: 12 | """ 13 | Class provides an object interface to a text embedder. 
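    A minimal usage sketch (assumes a text-embeddings-inference service is reachable,
    e.g. on the default host/port from protollm_sdk.config.Config; TextEmbedderRequest
    comes from protollm_sdk.models.job_context_models):

        embedder = TextEmbedder("localhost", 9942)
        response = embedder.inference(TextEmbedderRequest(job_id="0", inputs="hello", truncate=False))
        print(response.embeddings[:3])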
14 | """ 15 | 16 | def __init__( 17 | self, 18 | text_emb_host: str, 19 | text_emb_port: str | int | None = None, 20 | timeout_sec: int = 10 * 60, 21 | ): 22 | """ 23 | Initialize TextEmbedder 24 | 25 | :param text_emb_host: host of the text embedder 26 | :type text_emb_host: str 27 | :param text_emb_port: port of the text embedder 28 | :type text_emb_port: str | int | None 29 | :param timeout_sec: timeout in seconds 30 | :type timeout_sec: int 31 | """ 32 | self.path = f"http://{text_emb_host}:{text_emb_port}" if text_emb_port is not None else f"http://{text_emb_host}" 33 | self.timeout_sec = timeout_sec 34 | self.client = httpx.Client() 35 | 36 | def inference(self, request: TextEmbedderRequest) -> TextEmbedderResponse: 37 | """ 38 | Create an embedding for the input text 39 | 40 | :param request: request 41 | :type request: TextEmbedderRequest 42 | :return: TextEmbedderResponse 43 | """ 44 | try: 45 | emb = ToEmbed(inputs=request.inputs, truncate=request.truncate) 46 | response = self.client.post( 47 | urljoin(self.path, "/embed"), 48 | headers={"Content-type": "application/json"}, 49 | json=emb.model_dump(), 50 | timeout=self.timeout_sec 51 | ) 52 | if response.status_code == 500: 53 | raise ConnectionError('The LLM server is not available.') 54 | elif response.status_code == 422: 55 | raise ValueError(f'Data model validation error. {response.json()}') 56 | result = response.json()[0] 57 | logger.info("The embedding has been completed successfully.") 58 | text_embedder_result = {"embeddings": result} 59 | return TextEmbedderResponse(**text_embedder_result) 60 | except Exception as ex: 61 | msg = f"The embedding was interrupted. Error: {ex}." 62 | logger.info(msg) 63 | raise Exception(msg) 64 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/jobs/utility.py: -------------------------------------------------------------------------------- 1 | from protollm_sdk.config import Config 2 | from protollm_sdk.jobs.job_context import JobContext 3 | from protollm_sdk.jobs.job_invoke import JobInvoker, InvokeType 4 | from protollm_sdk.jobs.llm_api import LLMAPI 5 | from protollm_sdk.jobs.outer_llm_api import OuterLLMAPI 6 | from protollm_sdk.jobs.result_storage import ResultStorage 7 | from protollm_sdk.jobs.text_embedder import TextEmbedder 8 | from protollm_sdk.jobs.vector_db import VectorDB 9 | 10 | 11 | def construct_job_context(job_name: str, abstract_task=None) -> JobContext: 12 | """ 13 | Create JobContext object with object access to functions and services, based on environment variable values. 
14 | 15 | :param job_name: job name 16 | :type job_name: str 17 | :param abstract_task: optional reference to celery.task for recursive Job calling from Job 18 | :return: JobContext object 19 | """ 20 | 21 | llm_api = LLMAPI(Config.llm_api_host, Config.llm_api_port) 22 | 23 | out_llm_api = OuterLLMAPI(Config.outer_llm_key) 24 | 25 | text_embedder = TextEmbedder(Config.text_embedder_host, Config.text_embedder_port) 26 | 27 | result_storage = ResultStorage( 28 | redis_host=Config.redis_host, 29 | redis_port=Config.redis_port, 30 | prefix=job_name 31 | ) 32 | 33 | invoke_type_str = Config.job_invocation_type.lower() 34 | match invoke_type_str: 35 | case "worker": 36 | invoke_type = InvokeType.Worker 37 | case "blocking": 38 | invoke_type = InvokeType.Blocking 39 | case _: 40 | raise ValueError(f"Found unknown invocation type '{invoke_type_str}'.") 41 | job_invoker = JobInvoker(abstract_task, result_storage, invoke_type) 42 | 43 | vector_db = VectorDB(vector_bd_host=Config.vector_bd_host, vector_db_port=Config.vector_db_port) 44 | 45 | return JobContext( 46 | llm_api=llm_api, 47 | outer_llm_api=out_llm_api, 48 | text_embedder=text_embedder, 49 | result_storage=result_storage, 50 | vector_db=vector_db, 51 | job_invoker=job_invoker, 52 | ) 53 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/jobs/vector_db.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urljoin 2 | 3 | import httpx 4 | 5 | 6 | class VectorDB: 7 | """ 8 | VectorDB client 9 | """ 10 | 11 | def __init__(self, vector_bd_host: str, vector_db_port: str | int | None = None): 12 | """ 13 | Initialize VectorDB 14 | 15 | :param vector_bd_host: host of vector db 16 | :type vector_bd_host: str 17 | :param vector_db_port: port of vector db 18 | :type vector_db_port: str | int | None 19 | """ 20 | self.url = f"http://{vector_bd_host}:{vector_db_port}" if vector_db_port is not None else f"http://{vector_bd_host}" 21 | self.client = httpx.Client() 22 | 23 | def api_v1(self): 24 | """ 25 | Get api v1 26 | 27 | :return: response 28 | """ 29 | response = self.client.get( 30 | urljoin(self.url, "/api/v1"), 31 | headers={"Content-type": "application/json"}, 32 | timeout=10 * 60 33 | ) 34 | return response.json() 35 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/protollm_sdk/models/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/models/errors.py: -------------------------------------------------------------------------------- 1 | class RedisWaitItemTimeoutError(TimeoutError): 2 | pass 3 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/models/job_context_models.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Literal, Union 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class PromptTypes(Enum): 8 | SINGLE_GENERATION: str = "single_generation" 9 | CHAT_COMPLETION: str = "chat_completion" 10 | 11 | 12 | class PromptMeta(BaseModel): 13 | temperature: float | None = 0.2 14 | tokens_limit: int | None = 8096 15 | 
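    # The two optional fields below are passed through to the worker; when model is None
    # the worker falls back to its configured default model (an assumption based on
    # OpenAPILLM.default_model in llm-worker).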
stop_words: list[str] | None = None 16 | model: str | None = Field(default=None, examples=[None]) 17 | 18 | 19 | class PromptModel(BaseModel): 20 | job_id: str 21 | priority: int | None = Field(default=None, examples=[None]) 22 | meta: PromptMeta 23 | content: str 24 | 25 | 26 | class ChatCompletionUnit(BaseModel): 27 | """A model for element of chat completion""" 28 | role: str 29 | content: str 30 | 31 | 32 | class ChatCompletionModel(BaseModel): 33 | """A model for chat completion order""" 34 | job_id: str 35 | priority: int | None = Field(default=None, examples=[None]) 36 | source: str = "local" 37 | meta: PromptMeta 38 | messages: list[ChatCompletionUnit] 39 | 40 | @classmethod 41 | def from_prompt_model(cls, prompt_model: PromptModel) -> 'ChatCompletionModel': 42 | # Создаем первое сообщение из содержимого PromptModel 43 | initial_message = ChatCompletionUnit( 44 | role="user", # Или другой подходящий role 45 | content=prompt_model.content 46 | ) 47 | # Возвращаем новый экземпляр ChatCompletionModel 48 | return cls( 49 | job_id=prompt_model.job_id, 50 | priority=prompt_model.priority, 51 | meta=prompt_model.meta, 52 | messages=[initial_message] 53 | ) 54 | 55 | 56 | class PromptTransactionModel(BaseModel): 57 | prompt: PromptModel 58 | prompt_type: Literal[PromptTypes.SINGLE_GENERATION.value] 59 | 60 | 61 | class ChatCompletionTransactionModel(BaseModel): 62 | prompt: ChatCompletionModel 63 | prompt_type: Literal[PromptTypes.CHAT_COMPLETION.value] 64 | 65 | 66 | class PromptWrapper(BaseModel): 67 | prompt: Union[PromptTransactionModel, ChatCompletionTransactionModel] = Field(..., discriminator='prompt_type') 68 | 69 | 70 | class ResponseModel(BaseModel): 71 | content: str 72 | 73 | 74 | class LLMResponse(BaseModel): 75 | job_id: str 76 | text: str 77 | 78 | 79 | class TextEmbedderRequest(BaseModel): 80 | job_id: str 81 | inputs: str 82 | truncate: bool 83 | 84 | 85 | class ToEmbed(BaseModel): 86 | inputs: str 87 | truncate: bool 88 | 89 | 90 | class TextEmbedderResponse(BaseModel): 91 | embeddings: list[float] 92 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/object_interface/__init__.py: -------------------------------------------------------------------------------- 1 | from protollm_sdk.object_interface.redis_wrapper import RedisWrapper 2 | from protollm_sdk.object_interface.rabbit_mq_wrapper import RabbitMQWrapper -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/protollm_sdk/utils/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/utils/reddis.py: -------------------------------------------------------------------------------- 1 | from protollm_sdk.config import Config 2 | from protollm_sdk.jobs.result_storage import ResultStorage 3 | from protollm_sdk.object_interface import RedisWrapper 4 | 5 | 6 | def get_reddis_wrapper(): 7 | """ 8 | Create RedisWrapper object for working with Redis 9 | 10 | :return: RedisWrapper object 11 | :rtype: RedisWrapper 12 | """ 13 | return RedisWrapper( 14 | redis_host=Config.redis_host, 15 | redis_port=Config.redis_port 16 | ) 17 | 18 | 19 | def load_result(rd: RedisWrapper, job_id: str, prefix: str or None) -> bytes: 20 | """ 21 | Load 
result from Redis by job_id and prefix (job_name). 22 | The code returns bytes. If you want to convert these bytes into a ResponseModel, 23 | first decode the bytes into a string using `.decode()`, then create the model 24 | like this: 25 | model_text = byte_data.decode() 26 | response_model = ResponseModel.model_validate_json(model_text) 27 | 28 | :param rd: RedisWrapper object 29 | :type rd: RedisWrapper 30 | :param job_id: uuid of the job 31 | :type job_id: str 32 | :param prefix: prefix of type of job or jon_name 33 | :type prefix: str or None 34 | :return: result 35 | :rtype: bytes 36 | """ 37 | resp = rd.get_item(ResultStorage.build_key(job_id, prefix)) 38 | return resp 39 | -------------------------------------------------------------------------------- /protollm_tools/sdk/protollm_sdk/utils/singleton.py: -------------------------------------------------------------------------------- 1 | class Singleton(type): 2 | _instances = {} 3 | 4 | def __call__(cls, *args, **kwargs): 5 | if cls not in cls._instances: 6 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 7 | return cls._instances[cls] 8 | -------------------------------------------------------------------------------- /protollm_tools/sdk/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "protollm-sdk" 3 | version = "1.2.0" 4 | description = "" 5 | authors = ["aimclub"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | aioredis = "^2.0.1" 11 | pydantic = "^2.7.4" 12 | celery = "^5.4.0" 13 | kombu = "^5.3.7" 14 | uuid = "^1.30" 15 | redis = "^5.0.6" 16 | flower = "^2.0.1" 17 | pika = "^1.3.2" 18 | urllib3 = "^2.2.2" 19 | requests = "^2.32.3" 20 | fastapi = "^0.111.0" 21 | pydantic-core = "2.23.4" 22 | langchain = "^0.3.4" 23 | httpx = "^0.27.0" 24 | openai = "^1.42.0" 25 | 26 | [tool.poetry.group.dev.dependencies] 27 | pytest = "^8.2.2" 28 | pytest-asyncio = "^0.24.0" 29 | 30 | [tool.pytest.ini_options] 31 | markers = [ 32 | "local: Mark tests as part of the local pipeline (e.g., for Redis/Rabbit/etc)", 33 | "ci: Mark tests as part of the CI pipeline" 34 | ] 35 | 36 | [build-system] 37 | requires = ["poetry-core"] 38 | build-backend = "poetry.core.masonry.api" 39 | -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/tests/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/tests/protollm_sdk/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/celery/test_app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import uuid 3 | 4 | import pytest 5 | 6 | from protollm_sdk.celery.app import task_test, abstract_task 7 | from protollm_sdk.celery.job import ResultStorageJob 8 | from protollm_sdk.jobs.job import Job 9 | from protollm_sdk.jobs.utility import construct_job_context 10 | 11 | 12 | @pytest.fixture 13 | def result_storage(): 14 | 
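    # Sample payload handed to task_test (and on to ResultStorageJob) in
    # test_task_test_known_job_class below.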
return {"question": "What is the ultimate question answer?", 15 | "answers": "42"} 16 | 17 | @pytest.mark.ci 18 | def test_task_test_unknown_job_class(caplog): 19 | task_id = str(uuid.uuid4()) 20 | task_class = "unknown_class" 21 | 22 | with pytest.raises(Exception, match=f"Unknown job class: '{task_class}'"): 23 | task_test(task_class=task_class, task_id=task_id) 24 | 25 | assert f"Error in task '{task_id}'. Unknown job class: '{task_class}'." in caplog.text 26 | 27 | @pytest.mark.local 28 | def test_task_test_known_job_class(caplog, result_storage): 29 | caplog.set_level(logging.INFO) 30 | task_id = str(uuid.uuid4()) 31 | task_class = ResultStorageJob.__name__ 32 | 33 | result = task_test(task_class=task_class, task_id=task_id, kwargs=result_storage) 34 | 35 | assert f"Starting task '{task_id}'. Job '{task_class}'." in caplog.text 36 | assert result is None 37 | 38 | 39 | class DummyJob(Job): 40 | def __init__(self): 41 | self.ran = False 42 | 43 | def run(self, task_id, ctx, **kwargs): 44 | self.ran = True 45 | self.task_id = task_id 46 | self.ctx = ctx 47 | self.kwargs = kwargs 48 | 49 | 50 | @pytest.fixture 51 | def dummy_job(): 52 | return DummyJob() 53 | 54 | @pytest.mark.ci 55 | def test_abstract_task_class_input(caplog, dummy_job): 56 | caplog.set_level("INFO") 57 | 58 | task_id = "test_task_id" 59 | task_class = DummyJob 60 | 61 | abstract_task(task_class=task_class, task_id=task_id, test_arg="value") 62 | 63 | assert f"Starting task '{task_id}'. Job 'DummyJob'." in caplog.text 64 | 65 | assert construct_job_context("DummyJob", abstract_task) is not None 66 | 67 | job_instance = task_class() 68 | job_instance.run(task_id=task_id, ctx=construct_job_context("DummyJob", abstract_task), test_arg="value") 69 | assert job_instance.ran is True 70 | assert job_instance.task_id == task_id 71 | assert job_instance.kwargs == {"test_arg": "value"} 72 | 73 | @pytest.mark.ci 74 | def test_abstract_task_instance_input(caplog, dummy_job): 75 | caplog.set_level("INFO") 76 | 77 | task_id = "test_task_id" 78 | 79 | abstract_task(task_class=dummy_job, task_id=task_id, test_arg="value") 80 | 81 | assert f"Starting task '{task_id}'. Job 'DummyJob'." in caplog.text 82 | 83 | assert construct_job_context("DummyJob", abstract_task) is not None 84 | 85 | assert dummy_job.ran is True 86 | assert dummy_job.task_id == task_id 87 | assert dummy_job.kwargs == {"test_arg": "value"} 88 | -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/job/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/tests/protollm_sdk/job/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/job/test_job.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | 5 | from protollm_sdk.celery.app import task_test 6 | from protollm_sdk.celery.job import ( 7 | LLMAPIJob, TextEmbedderJob, ResultStorageJob, VectorDBJob 8 | ) 9 | from protollm_sdk.models.job_context_models import LLMResponse, TextEmbedderResponse 10 | 11 | 12 | @pytest.fixture 13 | def llm_request(): 14 | random_id = uuid.uuid4() 15 | prompt_msg = "What has a head like cat, feet like a kat, tail like a cat, but isn't a cat?" 
16 | meta = {"temperature": 0.5, 17 | "tokens_limit": 10, 18 | "stop_words": ["Stop"]} 19 | llm_request = {"job_id": str(random_id), 20 | "meta": meta, 21 | "content": prompt_msg} 22 | return llm_request 23 | 24 | 25 | @pytest.fixture 26 | def text_embedder_request(): 27 | return {"job_id": "0", 28 | "inputs": "Everybody steals and throws, they cut each other and hang each other... " 29 | "In general, normal civilized life is going on. McDonald's everywhere. " 30 | "I don't see them here, by the way. That can't be good.", 31 | "truncate": False} 32 | 33 | 34 | @pytest.fixture 35 | def result_storage(): 36 | return {"question": "What is the ultimate question answer?", 37 | "answers": "42"} 38 | 39 | 40 | @pytest.mark.skip(reason="LLM die sometimes because stupid questions") 41 | def test_llm_request(llm_request): 42 | result = task_test.apply_async(args=(LLMAPIJob.__name__, llm_request["job_id"]), kwargs=llm_request) 43 | r = result.get() 44 | res = LLMResponse(job_id=llm_request["job_id"], text=r.content) 45 | assert isinstance(res, LLMResponse) 46 | 47 | @pytest.mark.local 48 | def test_text_embedder_request(text_embedder_request): 49 | random_id = uuid.uuid4() 50 | result = task_test.apply_async(args=(TextEmbedderJob.__name__, random_id), kwargs=text_embedder_request) 51 | assert isinstance(result.get(), TextEmbedderResponse) 52 | 53 | @pytest.mark.local 54 | def test_result_storage(result_storage): 55 | random_id = uuid.uuid4() 56 | task_test.apply_async(args=(ResultStorageJob.__name__, random_id), kwargs=result_storage) 57 | 58 | @pytest.mark.skip(reason="We don't have local vector DB") 59 | def test_ping_vector_db(): 60 | random_id = uuid.uuid4() 61 | result = task_test.apply_async(args=(VectorDBJob.__name__, random_id)) 62 | r = result.get() 63 | assert isinstance(r, dict) 64 | -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/job/test_job_api.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | 5 | from protollm_sdk.jobs.utility import construct_job_context 6 | from protollm_sdk.models.job_context_models import PromptModel, ChatCompletionModel 7 | 8 | @pytest.fixture 9 | def llm_request(): 10 | random_id = str(uuid.uuid4()) 11 | prompt_msg = "What has a head like cat, feet like a kat, tail like a cat, but isn't a cat?" 
12 | meta = {"temperature": 0.5, 13 | "tokens_limit": 1000, 14 | "stop_words": ["Stop"]} 15 | llm_request = {"job_id": random_id, 16 | "meta": meta, 17 | "content": prompt_msg} 18 | request = PromptModel(**llm_request) 19 | return request 20 | 21 | @pytest.fixture 22 | def llm_job_context(llm_request): 23 | jc = construct_job_context(llm_request.job_id) 24 | return jc 25 | 26 | @pytest.mark.local 27 | def test_api_interface(llm_request, llm_job_context): 28 | response = llm_job_context.llm_api.inference(llm_request) 29 | print(response) 30 | 31 | @pytest.mark.local 32 | def test_api_interface_with_queue_name(llm_request, llm_job_context): 33 | response = llm_job_context.llm_api.inference(llm_request, "wq_outer_vsegpt") 34 | print(response) 35 | 36 | @pytest.mark.local 37 | def test_api_chat_completion(llm_request, llm_job_context): 38 | response = llm_job_context.llm_api.chat_completion(ChatCompletionModel.from_prompt_model(llm_request)) 39 | print(response) 40 | 41 | @pytest.mark.local 42 | def test_api_chat_completion_with_queue_name(llm_request, llm_job_context): 43 | response = llm_job_context.llm_api.chat_completion(ChatCompletionModel.from_prompt_model(llm_request), "wq_outer_vsegpt") 44 | print(response) 45 | -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/job/test_utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from protollm_sdk.config import Config 7 | from protollm_sdk.jobs.job_context import JobContext 8 | from protollm_sdk.jobs.job_invoke import InvokeType 9 | from protollm_sdk.jobs.utility import construct_job_context 10 | 11 | @pytest.mark.ci 12 | def test_construct_job_context_real(): 13 | """ 14 | Test construct_job_context function with real environment variables and connections. 15 | This test assumes that all services (Redis, LLM API, Text Embedder, etc.) are properly running. 16 | """ 17 | job_name = "test_job" 18 | job_context = construct_job_context(job_name) 19 | 20 | assert isinstance(job_context, JobContext) 21 | 22 | assert job_context.llm_api is not None 23 | assert job_context.outer_llm_api is not None 24 | assert job_context.text_embedder is not None 25 | assert job_context.result_storage is not None 26 | assert job_context.vector_db is not None 27 | assert job_context.job_invoker is not None 28 | 29 | @pytest.mark.ci 30 | def test_construct_job_context_with_invoke_type_worker(): 31 | """ 32 | Test construct_job_context with a missing environment variable. 33 | """ 34 | with patch.dict(os.environ, {"JOB_INVOCATION_TYPE": "worker"}): 35 | assert os.getenv("JOB_INVOCATION_TYPE") == "worker" 36 | 37 | job_name = "test_job" 38 | job_context = construct_job_context(job_name) 39 | assert job_context.job_invoker._invoke_type == InvokeType.Worker 40 | 41 | @pytest.mark.ci 42 | def test_construct_job_context_with_invoke_type_blocking(): 43 | """ 44 | Test construct_job_context with a missing environment variable. 
45 | """ 46 | with patch.dict(os.environ, {"JOB_INVOCATION_TYPE": "blocking"}): 47 | assert os.getenv("JOB_INVOCATION_TYPE") == "blocking" 48 | Config.reload_invocation_type() 49 | job_name = "test_job" 50 | job_context = construct_job_context(job_name) 51 | assert job_context.job_invoker._invoke_type == InvokeType.Blocking 52 | 53 | Config.reload_invocation_type() 54 | 55 | @pytest.mark.ci 56 | def test_construct_job_context_with_wrong_invoke_type(): 57 | """ 58 | Test construct_job_context with a missing environment variable. 59 | """ 60 | with patch.dict(os.environ, {"JOB_INVOCATION_TYPE": "cringe"}): 61 | assert os.getenv("JOB_INVOCATION_TYPE") == "cringe" 62 | Config.reload_invocation_type() 63 | 64 | with pytest.raises(ValueError, match="Found unknown invocation type 'cringe'."): 65 | job_name = "test_job" 66 | construct_job_context(job_name) 67 | 68 | Config.reload_invocation_type() 69 | -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/job/test_vector_db.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch, MagicMock 2 | from urllib.parse import urljoin 3 | 4 | import httpx 5 | import pytest 6 | 7 | from protollm_sdk.jobs.vector_db import VectorDB 8 | 9 | @pytest.mark.ci 10 | def test_vector_db_initialization_without_port(): 11 | """ 12 | Test that VectorDB is initialized correctly without a port. 13 | """ 14 | vector_db = VectorDB("localhost") 15 | assert vector_db.url == "http://localhost" 16 | assert isinstance(vector_db.client, httpx.Client) 17 | 18 | @pytest.mark.ci 19 | def test_vector_db_initialization_with_port(): 20 | """ 21 | Test that VectorDB is initialized correctly with a port. 22 | """ 23 | vector_db_with_port = VectorDB("localhost", 8080) 24 | assert vector_db_with_port.url == "http://localhost:8080" 25 | assert isinstance(vector_db_with_port.client, httpx.Client) 26 | 27 | @pytest.mark.ci 28 | @patch('httpx.Client.get') 29 | def test_vector_db_api_v1(mock_get): 30 | """ 31 | Test the api_v1 method of VectorDB class. 
32 | """ 33 | mock_response = MagicMock() 34 | mock_response.json.return_value = {"status": "success"} 35 | mock_get.return_value = mock_response 36 | 37 | vector_db = VectorDB("localhost", 8080) 38 | 39 | response = vector_db.api_v1() 40 | 41 | mock_get.assert_called_once_with( 42 | urljoin("http://localhost:8080", "/api/v1"), 43 | headers={"Content-type": "application/json"}, 44 | timeout=10 * 60 45 | ) 46 | 47 | assert response == {"status": "success"} 48 | -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/models/test_job_context.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from protollm_sdk.models.job_context_models import PromptModel, PromptMeta, ChatCompletionModel 4 | 5 | @pytest.mark.ci 6 | def test_from_prompt_model(): 7 | prompt_model = PromptModel( 8 | job_id="test_job_123", 9 | meta=PromptMeta( 10 | temperature=0.5, 11 | tokens_limit=100, 12 | stop_words=["stop", "words"], 13 | model="gpt-3" 14 | ), 15 | content="This is a test prompt" 16 | ) 17 | 18 | chat_completion = ChatCompletionModel.from_prompt_model(prompt_model) 19 | 20 | assert chat_completion.job_id == prompt_model.job_id 21 | assert chat_completion.meta == prompt_model.meta 22 | 23 | assert len(chat_completion.messages) == 1 24 | 25 | assert chat_completion.messages[0].role == "user" 26 | assert chat_completion.messages[0].content == prompt_model.content 27 | 28 | assert chat_completion.meta.temperature == 0.5 29 | assert chat_completion.meta.tokens_limit == 100 30 | assert chat_completion.meta.stop_words == ["stop", "words"] 31 | assert chat_completion.meta.model == "gpt-3" 32 | -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/object_interface/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/tests/protollm_sdk/object_interface/integration/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/object_interface/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/protollm_tools/sdk/tests/protollm_sdk/object_interface/unit/__init__.py -------------------------------------------------------------------------------- /protollm_tools/sdk/tests/protollm_sdk/test_utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from protollm_sdk.models.job_context_models import ResponseModel 4 | from protollm_sdk.object_interface import RedisWrapper 5 | from protollm_sdk.utils.reddis import get_reddis_wrapper, load_result 6 | 7 | 8 | def test_get_reddis_wrapper(): 9 | redis_wrapper = get_reddis_wrapper() 10 | assert isinstance(redis_wrapper, RedisWrapper) 11 | 12 | 13 | def test_load_result(): 14 | job_id = str(uuid.uuid4()) 15 | prefix = None 16 | redis = get_reddis_wrapper() 17 | redis.save_item(job_id, {"content": "value"}) 18 | 19 | result = load_result(redis, job_id, prefix) 20 | 21 | assert ResponseModel.model_validate_json(result.decode()) == ResponseModel(content="value") 22 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "ProtoLLM" 3 | version = "0.1.3" 4 | description = "A library with which to prototype LLM-based applications quickly and easily." 5 | requires-python = ">=3.10,<4.0" 6 | authors = [ 7 | { name = "aimclub", email = "aim.club@itmo.ru"} 8 | ] 9 | readme = "README_en.rst" 10 | license = "BSD-3-Clause" 11 | packages = [{ include = "protollm" }] 12 | dependencies = [ 13 | "aioredis>=2.0.1,<3.0.0", 14 | "celery>=5.4.0,<6.0.0", 15 | "chardet==5.2.0", 16 | "chromadb>=0.5.0,<0.6.0", 17 | "click==8.1.7", 18 | "deepeval==2.3.3; python_version >= '3.10' and python_version < '3.14'", 19 | "fastapi>=0.111.0,<0.112.0", 20 | "flower>=2.0.1,<3.0.0", 21 | "ftfy==6.3.1", 22 | "httpx>=0.27.0,<0.28.0", 23 | "kombu>=5.3.7,<6.0.0", 24 | "langchain>=0.3.4,<0.4.0", 25 | "langchain-chroma==0.1.4", 26 | "langchain-community==0.3.16", 27 | "langchain-core>=0.3.34", 28 | "langchain-elasticsearch==0.3.2", 29 | "langchain-gigachat==0.3.3", 30 | "langchain-openai==0.3.3", 31 | "langchain-text-splitters>=0.3.3,<0.4.0", 32 | "numpy==1.26.4", 33 | "openai>=1.42.0,<2.0.0", 34 | "pandas>=2.2.3,<3.0.0", 35 | "pdf2image==1.17.0", 36 | "pdfplumber==0.11.5", 37 | "pdfminer.six==20231228", 38 | "pillow==11.1.0", 39 | "pika>=1.3.2,<2.0.0", 40 | "protollm-sdk==1.1.6", 41 | "pydantic>=2.7.4,<3.0.0", 42 | "pydantic-core==2.23.4", 43 | "pydantic-settings==2.7.1", 44 | "pypdf2==3.0.1", 45 | "python-docx==1.1.2", 46 | "python-dotenv==1.0.1", 47 | "pytesseract==0.3.13", 48 | "pyyaml==6.0.2", 49 | "redis>=5.0.6,<6.0.0", 50 | "requests>=2.32.3,<3.0.0", 51 | "spacy==3.8.4; python_version >= '3.10' and python_version < '3.13'", 52 | "tabulate==0.9.0", 53 | "tornado>=6.4.1,<7.0.0", 54 | "tqdm==4.67.1", 55 | "transformers==4.48.2", 56 | "urllib3>=2.2.2,<3.0.0", 57 | "uuid>=1.30,<2.0.0", 58 | "websockets==14.1", 59 | "langchain-ollama (==0.3.0)", 60 | "langgraph (>=0.3.24,<0.4.0)" 61 | ] 62 | 63 | [project.urls] 64 | Homepage = "https://github.com/aimclub/ProtoLLM" 65 | Issues = "https://github.com/aimclub/ProtoLLM/issues" 66 | 67 | [project.optional-dependencies] 68 | api-tools = [ 69 | "protollm-api>=1.0.5,<2.0.0", 70 | "protollm-worker>=1.0.5,<2.0.0" 71 | ] 72 | 73 | [tool.poetry.group.dev.dependencies] 74 | pytest = "^8.2.2" 75 | pytest-asyncio = "^0.24.0" 76 | 77 | [build-system] 78 | requires = ["poetry-core"] 79 | build-backend = "poetry.core.masonry.api" 80 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/tests/__init__.py -------------------------------------------------------------------------------- /tests/mock_chat_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Any, Dict 3 | 4 | import requests 5 | from langchain.chat_models.base import BaseChatModel 6 | from langchain.schema import ( 7 | BaseMessage, 8 | AIMessage, 9 | HumanMessage, 10 | SystemMessage, 11 | ChatResult, 12 | ChatGeneration, 13 | ) 14 | from pydantic import PrivateAttr 15 | 16 | 17 | class MockChatModel(BaseChatModel): 18 | api_key: str 19 | base_url: str 20 | model: str 21 | temperature: float = 0.5 22 | max_tokens: int = 3000 23 | logging_level: int = logging.INFO 24 | 25 | _logger: logging.Logger = PrivateAttr() 26 | 27 | class Config: 28 
| arbitrary_types_allowed = True # Allows arbitrary types like the logger 29 | 30 | def __init__(self, **data): 31 | super().__init__(**data) 32 | self._logger = logging.getLogger(__name__) 33 | logging.basicConfig(level=self.logging_level) 34 | 35 | @property 36 | def _llm_type(self) -> str: 37 | return "llama31" 38 | 39 | def _prepare_headers(self) -> Dict[str, str]: 40 | return { 41 | "Authorization": f"Bearer {self.api_key}", 42 | "Content-Type": "application/json", 43 | } 44 | 45 | def _prepare_context(self, messages: List[BaseMessage]) -> List[Dict[str, str]]: 46 | role_map = { 47 | HumanMessage: "user", 48 | AIMessage: "assistant", 49 | SystemMessage: "system" 50 | } 51 | 52 | return [{"role": role_map.get(type(message), "user"), "content": message.content} for message in messages] 53 | 54 | def _prepare_payload( 55 | self, 56 | context: List[Dict[str, str]], 57 | stop: Optional[List[str]] = None, 58 | **kwargs: Any 59 | ) -> Dict[str, Any]: 60 | payload = { 61 | "model": self.model, 62 | "messages": context, 63 | "temperature": kwargs.get("temperature", self.temperature), 64 | "max_tokens": kwargs.get("max_tokens", self.max_tokens), 65 | } 66 | if stop is not None: 67 | payload["stop"] = stop 68 | return payload 69 | 70 | def _generate( 71 | self, 72 | messages: List[BaseMessage], 73 | stop: Optional[List[str]] = None, 74 | **kwargs: Any, 75 | ) -> ChatResult: 76 | ai_message = AIMessage(content='Action:```{ "action": "add_numbers", "action_input": { "a": 15, "b": 27 } }```') 77 | generation = ChatGeneration(message=ai_message) 78 | return ChatResult(generations=[generation]) 79 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from unittest.mock import patch 3 | 4 | from deepeval.test_case import LLMTestCase 5 | from langchain_core.messages import AIMessage 6 | from pydantic import BaseModel, Field 7 | 8 | from protollm.metrics.deepeval_connector import DeepEvalConnector 9 | from protollm.metrics.evaluation_metrics import correctness_metric 10 | 11 | 12 | class Joke(BaseModel): 13 | """Joke to tell user.""" 14 | setup: str = Field(description="The setup of the joke") 15 | punchline: str = Field(description="The punchline to the joke") 16 | rating: Optional[int] = Field( 17 | default=None, description="How funny the joke is, from 1 to 10" 18 | ) 19 | 20 | 21 | def test_metric_connector(): 22 | model = DeepEvalConnector() 23 | mock_response = AIMessage(content="Hello, world!") 24 | with patch.object(model, 'generate', return_value=mock_response): 25 | result = model.generate("Hello") 26 | assert result.content == "Hello, world!" 27 | 28 | 29 | def test_metric_connector_with_schema(): 30 | model = DeepEvalConnector() 31 | mock_response = Joke.model_validate_json('{"setup": "test", "punchline": "test", "score": "7"}') 32 | with patch.object(model, 'generate', return_value=mock_response): 33 | response = model.generate(prompt="Tell me a joke", schema=Joke) 34 | assert issubclass(type(response), BaseModel) 35 | 36 | 37 | def test_correctness_metric(): 38 | test_case = LLMTestCase( 39 | input="The dog chased the cat up the tree, who ran up the tree?", 40 | actual_output="It depends, some might consider the cat, while others might argue the dog.", 41 | expected_output="The cat." 
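        # expected_output is the reference answer the correctness metric compares actual_output against.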
42 | ) 43 | 44 | with ( 45 | patch.object( 46 | correctness_metric, "_generate_evaluation_steps", return_value=["first step", "second step"] 47 | ), 48 | patch.object( 49 | correctness_metric,"evaluate", return_value=(1.0, "all good") 50 | ) as mocked_evaluate, 51 | ): 52 | correctness_metric.measure(test_case) 53 | mocked_evaluate.assert_called_with(test_case) 54 | assert isinstance(correctness_metric.score, float) 55 | assert isinstance(correctness_metric.reason, str) 56 | -------------------------------------------------------------------------------- /tests/validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimclub/ProtoLLM/2b8d610361ddd13020fb742d7d394b37443df89c/tests/validation/__init__.py -------------------------------------------------------------------------------- /tests/validation/admin_config.yml: -------------------------------------------------------------------------------- 1 | app_port: 8000 2 | app_host: 0.0.0.0 3 | agents: 4 | - name: agent0 5 | description: Описание агента 6 | class_path: pipelines.rag_agent.RAGAgent 7 | default_params: 8 | max_input_tokens: 6144 9 | max_chat_history_token_length: 24576 10 | retrieving_top_k: 3 11 | generator_context_top_k: 3 12 | include_original_question_in_queries: True 13 | planner_model_name: planner_llm 14 | generator_model_name: generator_llm 15 | tokenizer_name: qwen_2.5 16 | store_name: specific1_vector_store 17 | - name: agent1 18 | description: Описание агента 19 | class_path: pipelines.rag_agent.RAGAgent 20 | default_params: 21 | max_input_tokens: 6144 22 | max_chat_history_token_length: 24576 23 | retrieving_top_k: 3 24 | generator_context_top_k: 5 25 | include_original_question_in_queries: False 26 | generator_model_name: generator_llm 27 | tokenizer_name: qwen_2.5 28 | store_name: specific2_vector_store 29 | 30 | models: 31 | - type: completion 32 | params: 33 | model: /model 34 | temperature: 0.01 35 | top_p: 0.95 36 | streaming: false 37 | name: planner_llm 38 | url: http://0.0.0.0:8001/v1 39 | api_key: token-api-key 40 | - type: completion 41 | params: 42 | model: /model 43 | temperature: 0.01 44 | top_p: 0.95 45 | streaming: true 46 | name: generator_llm 47 | url: http://0.0.0.0:8001/v1 48 | api_key: token-api-key 49 | - type: chat 50 | params: 51 | model: /model 52 | temperature: 0.01 53 | top_p: 0.95 54 | streaming: true 55 | name: router_llm 56 | url: http://0.0.0.0:8001/v1 57 | api_key: token-api-key 58 | - type: embedding 59 | params: 60 | name: e5-mistral-7b-instruct 61 | url: http://0.0.0.0:58891/v1 62 | api_key: token-api-key 63 | model: /models/e5-mistral-7b-instruct 64 | check_embedding_ctx_length: false 65 | tiktoken_enabled: false 66 | - type: tokenizer 67 | params: 68 | name: qwen_2.5 69 | path_or_repo_id: Qwen/Qwen2.5-7B-Instruct 70 | vector_stores: 71 | - type: chroma 72 | params: 73 | name: specific1_vector_store 74 | description: vector store description 75 | host: 0.0.0.0 76 | port: 57777 77 | collection_name: domain1_specific_collection 78 | embeddings_model_name: e5-mistral-7b-instruct 79 | - type: chroma 80 | params: 81 | name: specific2_vector_store 82 | description: vector store description 83 | host: 0.0.0.0 84 | port: 57777 85 | collection_name: domain2_specific_collection 86 | embeddings_model_name: e5-mistral-7b-instruct 87 | -------------------------------------------------------------------------------- /tests/validation/api_check.py: 
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 | from protollm_api.config import Config
3 | from protollm_api.backend.endpoints import get_router
4 | 
5 | app = FastAPI()
6 | 
7 | config = Config.read_from_env()
8 | 
9 | app.include_router(get_router(config))
10 | 
11 | '''
12 | curl -X POST "http://localhost:8000/generate" -H "Content-Type: application/json" -d '{
13 | "job_id": "12345",
14 | "meta": {
15 | "temperature": 0.5,
16 | "tokens_limit": 1000,
17 | "stop_words": ["stop"],
18 | "model": "gpt-model"
19 | },
20 | "content": "What is AI?"
21 | }'
22 | 
23 | curl -X POST "http://localhost:8000/chat_completion" -H "Content-Type: application/json" -d '{
24 | "job_id": "12345",
25 | "meta": {
26 | "temperature": 0.5,
27 | "tokens_limit": 1000,
28 | "stop_words": ["stop"],
29 | "model": "gpt-model"
30 | },
31 | "messages": [
32 | {"role": "user", "content": "What is AI?"},
33 | {"role": "assistant", "content": "Artificial Intelligence is..."}
34 | ]}'
35 | '''
--------------------------------------------------------------------------------
/tests/validation/complex_check_ens.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from protollm_agents.entrypoint import Entrypoint
4 | 
5 | 
6 | logging.basicConfig(
7 |     level=logging.DEBUG,
8 |     format="%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
9 |     handlers=[logging.StreamHandler()],
10 | )
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     epoint = Entrypoint(config_path="./examples/admin-config.yml")
16 |     epoint.run()
17 | 
--------------------------------------------------------------------------------
/tests/validation/ens_check.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from protollm_agents.entrypoint import Entrypoint
3 | logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", handlers=[logging.StreamHandler()])
4 | logger = logging.getLogger(__name__)
5 | if __name__ == "__main__":
6 |     epoint = Entrypoint(config_path="./examples/admin-config.yml")
7 |     epoint.run()
8 | 
--------------------------------------------------------------------------------
/tests/validation/rag_check.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from protollm_sdk.jobs.utility import construct_job_context
3 | from protollm_sdk.utils.reddis import get_reddis_wrapper, load_result
4 | from protollm.rags.jobs import RAGJob
5 | 
6 | # Step 1. Create a unique job identifier
7 | job_id = str(uuid.uuid4())
8 | # Step 2. Initialise access to the databases and SDK services
9 | job_name = "fast_validation"
10 | ctx = construct_job_context(job_name)
11 | # Step 3. Run the search for relevant documents
12 | RAGJob().run(job_id, ctx, user_prompt='Какой бывает арматура железобетонных конструкций?', use_advanced_rag=False)
13 | # Step 4. Retrieve the model's answer from the result storage
14 | rd = get_reddis_wrapper()
15 | result = load_result(rd, job_id, job_name)
16 | 
--------------------------------------------------------------------------------