├── .github
└── workflows
│ ├── ci.yml
│ ├── docker.yml
│ ├── hotfix.yml
│ ├── main.yml
│ └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── Dockerfile_ingestion
├── Makefile
├── README.md
├── README_zh.md
├── docker
├── .env.example
└── compose.yaml
├── docs
├── agentic_rag.md
├── api.md
├── api_v0.2.0.md
├── api_v0.3.0.md
├── api_zh.md
├── config_guide_cn.md
├── config_guide_en.md
├── data_analysis_doc.md
├── data_analysis_doc_250303.md
├── data_analysis_doc_zh.md
├── develop
│ ├── local_develop.md
│ └── local_develop_zh.md
├── docker_build.md
├── eas_deploy.md
├── eas_deploy_deepseek.md
├── figures
│ ├── agent
│ │ ├── agenda_query.png
│ │ ├── agent_tab.png
│ │ ├── date_query.png
│ │ ├── nl2sql_query.png
│ │ ├── rag_knowledge_query.png
│ │ ├── weather_query.png
│ │ └── web_search.png
│ ├── data_analysis
│ │ ├── DBChat_overview.png
│ │ ├── da_db_chat.png
│ │ ├── da_db_config.png
│ │ ├── da_db_enhance.png
│ │ ├── da_db_load.png
│ │ ├── da_db_prompt.png
│ │ ├── da_db_prompt_reset.png
│ │ ├── da_llm_config.png
│ │ ├── da_overview.png
│ │ ├── da_sheet_chat.png
│ │ ├── da_sheet_upload.png
│ │ ├── data_analysis_overview.png
│ │ ├── datafile_chat.png
│ │ ├── datafile_config.png
│ │ ├── db_chat.png
│ │ ├── db_chat_with_memo.png
│ │ ├── db_config.png
│ │ ├── db_config_update.png
│ │ ├── db_enhanced_features.png
│ │ ├── db_info_load.png
│ │ ├── db_query_desc.png
│ │ ├── db_query_no_desc.png
│ │ ├── llm_config.png
│ │ ├── llm_selection.png
│ │ ├── prompt_config.png
│ │ ├── prompt_reset.png
│ │ ├── sheet_data_preview.png
│ │ ├── sheet_upload.png
│ │ └── table_example.png
│ ├── deepseek
│ │ ├── deepseek_eas_api.png
│ │ ├── llm_chat.png
│ │ ├── llm_config.png
│ │ └── rag_chat.png
│ ├── deploy
│ │ └── eas
│ │ │ ├── deploy_json.png
│ │ │ ├── deploy_portal.png
│ │ │ ├── deploy_resources.jpg
│ │ │ ├── deploy_success.jpg
│ │ │ ├── deploy_vpc.jpg
│ │ │ ├── edited_json.jpg
│ │ │ ├── enable_web.jpg
│ │ │ ├── trace_detail.jpg
│ │ │ ├── trace_json.jpg
│ │ │ ├── trace_key.jpg
│ │ │ ├── trace_percent.jpg
│ │ │ └── view_web.jpg
│ ├── elastic
│ │ ├── aliyun_es_ik_hot_update.png
│ │ ├── aliyun_es_instance.png
│ │ ├── aliyun_es_instance_autoindex.png
│ │ ├── aliyun_es_instance_info.png
│ │ ├── aliyun_es_kibana.png
│ │ ├── aliyun_es_kibana_entry.png
│ │ ├── aliyun_es_kibana_index_detail.png
│ │ ├── aliyun_es_kibana_index_management.png
│ │ ├── aliyun_es_kibana_management.png
│ │ ├── aliyun_es_kibana_menu.png
│ │ ├── aliyun_es_kibana_whitelist.png
│ │ ├── aliyun_es_overview.png
│ │ ├── aliyun_es_password.png
│ │ ├── aliyun_es_password_reset.png
│ │ ├── aliyun_es_plugin.png
│ │ ├── aliyun_es_upload_dic.png
│ │ ├── new_word_dict.png
│ │ ├── pairag_es_connect.png
│ │ ├── pairag_es_param_list.png
│ │ ├── pairag_retrieval_mode.png
│ │ └── pairag_select_es.png
│ ├── framework.jpg
│ ├── index
│ │ ├── add_index.png
│ │ ├── config_new_index.png
│ │ └── select_index.png
│ ├── multimodal
│ │ ├── mm_chat.jpg
│ │ ├── mm_settings.jpg
│ │ └── mm_upload.jpg
│ └── quick_start
│ │ ├── query.png
│ │ ├── setting.png
│ │ └── upload.png
├── index
│ └── multi_index.md
├── multimodal_rag.md
├── qca_generation_and_evaluation.md
├── tabular_doc.md
└── vectordb
│ ├── elasticsearch.md
│ ├── opensearch_image.json
│ └── opensearch_text.json
├── example_data
├── cn_clip
│ └── pokemon.jpeg
├── eval_docs_crag_small
│ ├── crag-task1-straightforward-eval-msgs_1.jsonl
│ └── crag-task1-straightforward-eval-msgs_2.jsonl
├── eval_docs_image
│ └── 跑鞋推荐.pdf
├── eval_docs_image_example
│ └── multimodal_eval_dataset_zh_example.jsonl
├── eval_docs_text
│ ├── EasyRec.txt
│ └── PAI.txt
├── function_tools
│ └── api-tool-with-intent-detection-for-travel-assistant
│ │ ├── dataset
│ │ ├── 上海攻略信息.pdf
│ │ └── 北京攻略信息.pdf
│ │ ├── figures
│ │ ├── agent_chat.jpg
│ │ ├── agent_config.jpg
│ │ ├── settings.jpg
│ │ └── upload.jpg
│ │ ├── mock_api
│ │ └── main.py
│ │ ├── settings.toml
│ │ ├── tools.json
│ │ └── tools.py
├── pai_document.pdf
└── paul_graham
│ └── paul_graham_essay.txt
├── integrations
└── pairag-file
│ ├── README.md
│ ├── pairag
│ └── file
│ │ ├── __init__.py
│ │ ├── nodeparsers
│ │ ├── __init__.py
│ │ ├── pai
│ │ │ ├── __init__.py
│ │ │ ├── constants.py
│ │ │ ├── image_caption_tool.py
│ │ │ ├── pai_markdown_parser.py
│ │ │ └── pai_node_parser.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ └── pai_markdown_tree.py
│ │ ├── readers
│ │ ├── __init__.py
│ │ ├── pai
│ │ │ ├── __init__.py
│ │ │ ├── constants.py
│ │ │ ├── file_readers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── pai_csv_reader.py
│ │ │ │ ├── pai_docx_reader.py
│ │ │ │ ├── pai_excel_reader.py
│ │ │ │ ├── pai_html_reader.py
│ │ │ │ ├── pai_image_reader.py
│ │ │ │ ├── pai_jsonl_reader.py
│ │ │ │ ├── pai_markdown_reader.py
│ │ │ │ ├── pai_pdf_reader.py
│ │ │ │ └── pai_pptx_reader.py
│ │ │ ├── pai_data_reader.py
│ │ │ └── utils
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cuda_utils.py
│ │ │ │ ├── image_utils.py
│ │ │ │ ├── magic-pdf.template.json
│ │ │ │ ├── markdown_utils.py
│ │ │ │ └── modelscope_utils.py
│ │ └── readme.md
│ │ └── store
│ │ ├── __init__.py
│ │ ├── oss_store.py
│ │ └── pai_image_store.py
│ ├── poetry.lock
│ ├── pyproject.toml
│ └── tests
│ ├── __init__.py
│ ├── nodeparsers
│ └── test_markdown_mode_parser.py
│ ├── openailike_multimodal.py
│ ├── readers
│ ├── test_csv_reader.py
│ ├── test_docx_reader.py
│ ├── test_excel_reader.py
│ ├── test_html_reader.py
│ ├── test_image_reader.py
│ ├── test_jsonl_reader.py
│ ├── test_pdf_reader.py
│ └── test_post_process_multi_level_headings.py
│ ├── test_image_caption_tool.py
│ ├── test_image_utils.py
│ ├── test_load_data.py
│ └── testdata
│ ├── csv_data
│ ├── 30天留存率_csv_two_header.csv
│ └── titanic_train.csv
│ ├── db_data
│ ├── pets.db
│ └── pets.sqlite
│ ├── docx_data
│ └── 大模型RAG对话系统.docx
│ ├── excel_data
│ └── 30天留存率_excel_two_header.xlsx
│ ├── html_data
│ ├── AIGC Stable Diffusion文生图Lora模型微调实现虚拟上装.html
│ ├── AI写真计费说明.html
│ ├── EAS计费说明.html
│ ├── 多媒体分析计费说明.html
│ └── 聚类模型评估.html
│ ├── image_data
│ └── 11.jpg
│ ├── json_data
│ └── pai_document.json
│ ├── jsonl_data
│ └── price.jsonl
│ ├── md_data
│ └── pai_document.md
│ └── pdf_data
│ └── pai_document.pdf
├── poetry.lock
├── pyproject.toml
├── scripts
├── gunicorn.conf.py
├── load_data.sh
├── load_data_by_step.sh
├── locust.py
└── start.sh
├── setup.py
├── src
└── pairag
│ ├── api
│ ├── api_chat.py
│ ├── api_router_v1.py
│ ├── chat_completions.py
│ ├── exception_handler.py
│ └── middleware.py
│ ├── app
│ ├── app.py
│ └── constants.py
│ ├── chat
│ ├── chat_app.py
│ ├── chat_flow.py
│ ├── models.py
│ └── utils
│ │ └── chat_utils.py
│ ├── config
│ ├── evaluation
│ │ ├── config.yaml
│ │ ├── settings_eval_for_crag_text.toml
│ │ ├── settings_eval_for_image.toml
│ │ └── settings_eval_for_text.toml
│ └── settings.toml
│ ├── core
│ ├── chat_service.py
│ ├── models
│ │ ├── config.py
│ │ ├── container.py
│ │ ├── errors.py
│ │ └── state.py
│ ├── rag_config.py
│ ├── rag_config_manager.py
│ ├── rag_environment.py
│ ├── rag_module.py
│ └── service_daemon.py
│ ├── data_pipeline
│ ├── constants.py
│ ├── datasource
│ │ └── filedelta_datasource.py
│ ├── delta
│ │ ├── list_delta.py
│ │ ├── milvus.py
│ │ └── models.py
│ ├── e2e_config.yaml
│ ├── job
│ │ ├── file_task_executor.py
│ │ └── rag_job_manager.py
│ ├── main.py
│ ├── models
│ │ ├── config
│ │ │ ├── datasource.py
│ │ │ └── operator.py
│ │ ├── file
│ │ │ └── event.py
│ │ └── models.py
│ ├── operators
│ │ ├── base.py
│ │ ├── embedder.py
│ │ ├── parser.py
│ │ ├── sink.py
│ │ └── split.py
│ ├── ray_executor.py
│ ├── readme.md
│ └── utils
│ │ ├── compute_resource_utils.py
│ │ ├── concurrency_utils.py
│ │ ├── cuda_utils.py
│ │ ├── dataset_utils.py
│ │ ├── download_utils.py
│ │ ├── file_ext_utils.py
│ │ ├── filename_utils.py
│ │ ├── memory_utils.py
│ │ ├── node_utils.py
│ │ ├── path_resolver.py
│ │ ├── path_utils.py
│ │ └── vectordb_utils.py
│ ├── evaluation
│ ├── dataset
│ │ ├── crag
│ │ │ ├── crag_data_loader.py
│ │ │ └── crag_jsonl_reader.py
│ │ ├── open_dataset.py
│ │ ├── rag_eval_dataset_refactor.py
│ │ ├── rag_qca_dataset_refactor.py
│ │ └── state_manager.py
│ ├── evaluator
│ │ ├── base_evaluator.py
│ │ └── pai_evaluator.py
│ ├── generator
│ │ └── rag_qca_generator.py
│ ├── metrics
│ │ ├── response
│ │ │ ├── base.py
│ │ │ ├── correctness.py
│ │ │ └── faithfulness.py
│ │ └── retrieval
│ │ │ ├── core.py
│ │ │ ├── hitrate.py
│ │ │ └── mrr.py
│ ├── pipeline
│ │ ├── run_evaluation_pipeline.py
│ │ └── run_multimodal_evaluation_pipeline.py
│ ├── run_evaluation_experiments.py
│ └── utils
│ │ ├── create_components.py
│ │ └── file_utils.py
│ ├── extensions
│ └── news
│ │ ├── miaobi_news.py
│ │ └── news_config.py
│ ├── integrations
│ ├── data_analysis
│ │ ├── data_analysis_config.py
│ │ ├── data_analysis_tool.py
│ │ ├── nl2pandas_retriever.py
│ │ ├── nl2sql
│ │ │ ├── db_connector.py
│ │ │ ├── db_descriptor.py
│ │ │ ├── db_indexer.py
│ │ │ ├── db_loader.py
│ │ │ ├── db_preretriever.py
│ │ │ ├── db_query.py
│ │ │ ├── db_selector.py
│ │ │ ├── db_utils
│ │ │ │ ├── constants.py
│ │ │ │ └── nl2sql_utils.py
│ │ │ ├── nl2sql_prompts.py
│ │ │ ├── query_preprocessor.py
│ │ │ └── sql_generator.py
│ │ ├── nl2sql_retriever.py
│ │ ├── pandas_instruction_parser.py
│ │ ├── test
│ │ │ ├── test_bird_schema_collector.py
│ │ │ ├── test_db_connector.py
│ │ │ ├── test_db_info_collector.py
│ │ │ ├── test_index_retriever.py
│ │ │ └── test_query_processor.py
│ │ └── text2sql
│ │ │ ├── db_connector.py
│ │ │ ├── db_info_collector.py
│ │ │ ├── db_info_index.py
│ │ │ ├── db_info_node.py
│ │ │ ├── db_info_retriever.py
│ │ │ ├── db_info_selector.py
│ │ │ ├── db_loader.py
│ │ │ ├── db_query.py
│ │ │ ├── db_retriever_filter.py
│ │ │ ├── evaluations
│ │ │ ├── base_evaluator.py
│ │ │ ├── bird_evaluator.py
│ │ │ ├── eval_bird
│ │ │ │ └── evaluation.py
│ │ │ ├── eval_spider
│ │ │ │ ├── evaluation.py
│ │ │ │ ├── exec_eval.py
│ │ │ │ ├── parse.py
│ │ │ │ └── process_sql.py
│ │ │ └── spider_evaluator.py
│ │ │ ├── query_processor.py
│ │ │ ├── sql_generator.py
│ │ │ └── utils
│ │ │ ├── constants.py
│ │ │ ├── info_utils.py
│ │ │ ├── prompts.py
│ │ │ └── sql_utils.py
│ ├── embeddings
│ │ ├── pai
│ │ │ ├── embedding_utils.py
│ │ │ ├── pai_embedding.py
│ │ │ └── pai_embedding_config.py
│ │ └── readme.md
│ ├── guardrail
│ │ └── pai_guardrail.py
│ ├── llms
│ │ ├── pai
│ │ │ ├── llm_config.py
│ │ │ ├── llm_utils.py
│ │ │ ├── open_ai_alike_multi_modal.py
│ │ │ ├── pai_llm.py
│ │ │ └── pai_multi_modal_llm.py
│ │ └── readme.md
│ ├── postprocessor
│ │ ├── my_model_based_reranker.py
│ │ └── pai
│ │ │ └── pai_postprocessor.py
│ ├── query_transform
│ │ ├── intent_models.py
│ │ └── pai_query_transform.py
│ ├── search
│ │ ├── aliyun_search.py
│ │ ├── bing_search.py
│ │ ├── bs4_reader.py
│ │ ├── google_search.py
│ │ └── search_config.py
│ ├── synthesizer
│ │ ├── pai_synthesizer.py
│ │ └── prompt_templates.py
│ ├── trace
│ │ ├── base.py
│ │ ├── pai_query_wrapper.py
│ │ ├── reloadable_exporter.py
│ │ └── trace_config.py
│ └── vector_stores
│ │ ├── dashvector
│ │ └── dashvector.py
│ │ ├── elasticsearch
│ │ ├── elasticsearch_utils.py
│ │ ├── my_async_vector_store.py
│ │ └── my_elasticsearch.py
│ │ ├── faiss
│ │ └── my_faiss.py
│ │ ├── hologres
│ │ └── hologres.py
│ │ ├── milvus
│ │ └── my_milvus.py
│ │ ├── postgresql
│ │ └── postgresql.py
│ │ └── tablestore
│ │ └── tablestore.py
│ ├── knowledgebase
│ ├── index
│ │ └── pai
│ │ │ ├── pai_vector_index.py
│ │ │ ├── utils
│ │ │ ├── index_utils.py
│ │ │ ├── sparse_embed_function.py
│ │ │ └── vector_store_utils.py
│ │ │ └── vector_store_config.py
│ ├── models.py
│ ├── rag_knowledgebase.py
│ ├── rag_knowledgebase_helper.py
│ └── utils
│ │ └── knowledgebase_utils.py
│ ├── tools
│ ├── data_analysis
│ │ └── text2sql
│ │ │ ├── bird_eval.py
│ │ │ ├── bird_eval_parallel.py
│ │ │ └── spider_eval.py
│ └── intent_eval
│ │ ├── intent_eval_tool.py
│ │ ├── intent_sample.json
│ │ └── intent_sample_predicted.json
│ ├── utils
│ ├── __init__.py
│ ├── citation_utils.py
│ ├── constants.py
│ ├── cuda_utils.py
│ ├── download_models.py
│ ├── file_utils.py
│ ├── format_logging.py
│ ├── json_parser.py
│ ├── mdoelscope_utils.py
│ ├── prompt_template.py
│ ├── score_utils.py
│ └── time_utils.py
│ └── web
│ ├── element_manager.py
│ ├── event_listeners.py
│ ├── filebrowser
│ ├── constants.py
│ └── request_utils.py
│ ├── index_utils.py
│ ├── rag_local_client.py
│ ├── tabs
│ ├── chat_tab.py
│ ├── data_analysis_tab.py
│ ├── history_tab.py
│ ├── knowledgebase_tab.py
│ ├── model
│ │ └── index_info.py
│ ├── news_extension.py
│ ├── search_web_tab.py
│ ├── settings_tab.py
│ └── vector_db_panel.py
│ ├── ui_constants.py
│ ├── utils.py
│ ├── view_model.py
│ └── webui.py
└── tests
├── __init__.py
├── app
├── test_openai.py
└── test_openai_embedding.py
├── index
├── test_elasticsearch.py
└── test_milvus.py
├── integrations
├── intentdetection
│ ├── multi_intents_sample.json
│ ├── multi_intents_sample_predicted.json
│ └── test_multi_intents.py
├── llm
│ ├── test_function_calling_llm.py
│ └── test_llm.py
├── retriever
│ └── test_es_tokenizer.py
├── test_nl2pandas_retriever.py
└── test_nl2sql_retriever.py
├── news
└── intent_eval
│ ├── intent_sample.json
│ ├── intent_sample_predicted.json
│ └── test_news_intent.py
└── testdata
├── csv_data
└── titanic_train.csv
├── db_data
├── pets.db
└── pets.sqlite
├── pai_document.md
└── paul_graham
└── paul_graham_essay.txt
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: PAI-RAG CI Build
2 |
3 | on:
4 | push:
5 | # Sequence of patterns matched against refs/heads
6 | branches:
7 | - main
8 | - feature
9 | - "releases/**"
10 |
11 | concurrency:
12 | group: pairag-ci-${{ github.head_ref || github.run_id }}
13 | cancel-in-progress: true
14 |
15 | permissions:
16 | contents: read
17 | pull-requests: write
18 |
19 | jobs:
20 | build:
21 | name: Build and Test
22 | runs-on: ubuntu-latest
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 | - name: Set up Python 3.11
27 | # This is the version of the action for setting up Python, not the Python version.
28 | uses: actions/setup-python@v5
29 | with:
30 | # Semantic version range syntax or exact version of a Python version
31 | python-version: "3.11"
32 | # Optional - x64 or x86 architecture, defaults to x64
33 | architecture: "x64"
34 |
35 | - name: Install Dependencies
36 | run: |
37 | python -m pip install --upgrade pip setuptools wheel
38 | pip install poetry
39 | poetry install
40 | poetry run pip install magic-pdf[full]==1.3.10
41 | poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4
42 | env:
43 | POETRY_VIRTUALENVS_CREATE: false
44 |
45 | - name: Install pre-commit
46 | shell: bash
47 | run: poetry run pip install pre-commit
48 |
49 | - name: Run Linter
50 | shell: bash
51 | run: poetry run make lint
52 |
53 | - name: Run Tests
54 | run: |
55 | make coveragetest
56 | env:
57 | DASHSCOPE_API_KEY: ${{ secrets.TESTDASHSCOPEKEY }}
58 | BING_SEARCH_KEY: ${{ secrets.BING_SEARCH_KEY }}
59 | OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}
60 | OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}
61 | PAIRAG_RAG__embedding__source: "huggingface"
62 | PAIRAG_RAG__llm__source: "DashScope"
63 | PAIRAG_RAG__llm__model: "qwen-max"
64 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: PR Build
2 |
3 | on:
4 | pull_request:
5 | # Sequence of patterns matched against refs/heads
6 | branches:
7 | - main
8 | - feature
9 | - "releases/**"
10 |
11 | concurrency:
12 | group: pr-build-${{ github.head_ref || github.run_id }}
13 | cancel-in-progress: true
14 |
15 | permissions:
16 | contents: read
17 | pull-requests: write
18 |
19 | jobs:
20 | build:
21 | name: Build and Test
22 | runs-on: ubuntu-latest
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 | - name: Set up Python 3.11
27 | # This is the version of the action for setting up Python, not the Python version.
28 | uses: actions/setup-python@v5
29 | with:
30 | # Semantic version range syntax or exact version of a Python version
31 | python-version: "3.11"
32 | # Optional - x64 or x86 architecture, defaults to x64
33 | architecture: "x64"
34 |
35 | - name: Install Dependencies
36 | run: |
37 | python -m pip install --upgrade pip setuptools wheel
38 | pip install poetry
39 | poetry install
40 | poetry run pip install magic-pdf[full]==1.3.10
41 | poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4
42 | env:
43 | POETRY_VIRTUALENVS_CREATE: false
44 |
45 | - name: Install pre-commit
46 | shell: bash
47 | run: poetry run pip install pre-commit
48 |
49 | - name: Run Linter
50 | shell: bash
51 | run: poetry run make lint
52 |
53 | - name: Run Tests
54 | run: |
55 | make coveragetest
56 | env:
57 | DASHSCOPE_API_KEY: ${{ secrets.TESTDASHSCOPEKEY }}
58 | PAIRAG_RAG__embedding__source: "huggingface"
59 | PAIRAG_RAG__llm__source: "DashScope"
60 | PAIRAG_RAG__llm__model: "qwen-max"
61 | BING_SEARCH_KEY: ${{ secrets.BING_SEARCH_KEY }}
62 | OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }}
63 | OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }}
64 |
65 | - name: Get Cover
66 | uses: orgoro/coverage@v3.1
67 | with:
68 | coverageFile: localdata/test_output/coverage_report.xml
69 | token: ${{ secrets.GITHUB_TOKEN }}
70 | thresholdAll: 0.4 # Total coverage threshold
71 | #thresholdNew: 0.9 # New files coverage threshold
72 | #thresholdModified: 0.9 # Modified files coverage threshold
73 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release image
2 |
3 | # Configures this workflow to run every time a change is pushed to the branch called `release`.
4 | on:
5 | workflow_dispatch:
6 | push:
7 | branches: ["main", "release_test"]
8 |
9 | # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
10 | env:
11 | REGISTRY: mybigpai-public-registry.cn-beijing.cr.aliyuncs.com
12 |
13 | # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
14 | jobs:
15 | build-and-push-image:
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Checkout repository
19 | uses: actions/checkout@v4
20 |
21 | - uses: actions/setup-python@v4
22 | with:
23 | python-version: "3.11"
24 |
25 | - name: Check disk space
26 | run: df . -h
27 |
28 | - name: Free disk space
29 | run: |
30 | sudo docker rmi $(docker image ls -aq) >/dev/null 2>&1 || true
31 | sudo rm -rf \
32 | /usr/share/dotnet /usr/local/lib/android /opt/ghc \
33 | /usr/local/share/powershell /usr/share/swift /usr/local/.ghcup \
34 | /usr/lib/jvm || true
35 |
36 | - name: Extract version
37 | run: |
38 | pip install poetry
39 | VERSION_TAG=$(poetry version --short)
40 | SPECIFIC_VERSION_TAG="$VERSION_TAG-$(date +'%Y%m%d')"
41 | echo "VERSION_TAG=$VERSION_TAG" >> $GITHUB_ENV
42 | echo "SPECIFIC_VERSION_TAG=$SPECIFIC_VERSION_TAG" >> $GITHUB_ENV
43 | echo -e "version:$SPECIFIC_VERSION_TAG\ncommit_id:$(git rev-parse HEAD)" > __build_version.cfg
44 |
45 | # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
46 | - name: Login to ACR region
47 | uses: docker/login-action@v1
48 | with:
49 | registry: ${{ env.REGISTRY }}
50 | username: ${{ secrets.ACR_USER }}
51 | password: ${{ secrets.ACR_PUBLIC_PASSWORD }}
52 |
53 | - name: Build and push base image
54 | env:
55 | IMAGE_TAG: ${{env.VERSION_TAG}}
56 | SPECIFIC_IMAGE_TAG: ${{env.SPECIFIC_VERSION_TAG}}
57 | run: |
58 | docker build -t ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.IMAGE_TAG }} .
59 | docker push ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.IMAGE_TAG }}
60 | docker tag ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.IMAGE_TAG }} ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.SPECIFIC_IMAGE_TAG }}
61 | docker push ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.SPECIFIC_IMAGE_TAG }}
62 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11 AS builder
2 |
3 | RUN pip3 install poetry
4 |
5 | ENV POETRY_NO_INTERACTION=1 \
6 | POETRY_VIRTUALENVS_IN_PROJECT=1 \
7 | POETRY_VIRTUALENVS_CREATE=1 \
8 | POETRY_CACHE_DIR=/tmp/poetry_cache
9 |
10 | WORKDIR /app
11 | COPY . .
12 |
13 | RUN poetry install \
14 | && poetry run pip install magic-pdf[full]==1.3.10 \
15 | && poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 \
16 | && rm -rf $POETRY_CACHE_DIR
17 |
18 | FROM python:3.11-slim AS prod
19 |
20 | RUN rm -rf /etc/localtime && ln -s /usr/share/zoneinfo/Asia/Harbin /etc/localtime
21 |
22 | RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 libgomp1 curl libgdiplus wget perl build-essential
23 |
24 | ENV VIRTUAL_ENV=/app/.venv \
25 | PATH="/app/.venv/bin:$PATH"
26 |
27 |
28 | ADD https://eas-data.oss-cn-shanghai.aliyuncs.com/3rdparty/sdwebui/filebrowser /bin/filebrowser
29 | RUN chmod u+x /bin/filebrowser
30 |
31 | # setup paddleocr dependencies
32 | RUN mkdir -p /root/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer \
33 | && curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar -o /root/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar \
34 | && tar xvf /root/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar -C /root/.paddleocr/whl/det/ch/
35 |
36 | RUN mkdir -p /root/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer \
37 | && curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar -o /root/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar \
38 | && tar xvf /root/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar -C /root/.paddleocr/whl/rec/ch/
39 |
40 | RUN mkdir -p /root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer \
41 | && curl https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar -o /root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar \
42 | && tar xvf /root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar -C /root/.paddleocr/whl/cls/
43 |
44 | WORKDIR /app
45 |
46 | COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
47 | COPY . .
48 | CMD ["./scripts/start.sh", "-w", "1"]
49 |
--------------------------------------------------------------------------------
/Dockerfile_ingestion:
--------------------------------------------------------------------------------
1 | FROM rayproject/ray:2.46.0.0e19ea-py311-cu124
2 |
3 | RUN pip3 install poetry
4 |
5 | ENV POETRY_NO_INTERACTION=1 \
6 | POETRY_VIRTUALENVS_CREATE=false \
7 | POETRY_CACHE_DIR=/tmp/poetry_cache \
8 | HOME_DIR=/home/ray
9 |
10 | WORKDIR /app
11 | COPY . .
12 |
13 | RUN poetry install
14 | RUN poetry run pip install magic-pdf[full]==1.3.10 \
15 | && poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 \
16 | && rm -rf $POETRY_CACHE_DIR
17 |
18 | RUN sudo rm -rf /etc/localtime && sudo ln -s /usr/share/zoneinfo/Asia/Harbin /etc/localtime
19 | RUN sudo apt-get update && sudo apt-get install -y libgl1 libglib2.0-0 libgomp1 curl libgdiplus wget perl build-essential
20 |
21 | # setup paddleocr dependencies
22 | RUN sudo mkdir -p $HOME_DIR/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer \
23 | && sudo curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar -o $HOME_DIR/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar \
24 | && sudo tar xvf $HOME_DIR/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar -C $HOME_DIR/.paddleocr/whl/det/ch/
25 |
26 | RUN sudo mkdir -p $HOME_DIR/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer \
27 | && sudo curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar -o $HOME_DIR/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar \
28 | && sudo tar xvf $HOME_DIR/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar -C $HOME_DIR/.paddleocr/whl/rec/ch/
29 |
30 | RUN sudo mkdir -p $HOME_DIR/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer \
31 | && sudo curl https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar -o $HOME_DIR/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar \
32 | && sudo tar xvf $HOME_DIR/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar -C $HOME_DIR/.paddleocr/whl/cls/
33 |
34 | CMD ["./scripts/load_data.sh", "--help"]
35 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
2 |
3 | help: ## Show all Makefile targets.
4 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
5 |
6 | format: ## Run code autoformatters (black).
7 | pre-commit install
8 | git ls-files | xargs pre-commit run black --files
9 |
10 | lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
11 | pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
12 |
13 | test: ## Run tests via pytest.
14 | pytest tests -s
15 |
16 | coveragetest: ## Tests with coverage report
17 | pytest --cov-report xml:localdata/test_output/coverage_report.xml --cov=pairag tests -s
18 |
19 | watch-docs: ## Build and watch documentation.
20 | sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
21 |
22 | publish:
23 | poetry publish --build --username __token__ --password $$PYPI_KEY --skip-existing
24 |
--------------------------------------------------------------------------------
/docker/.env.example:
--------------------------------------------------------------------------------
1 | # DASHSCOPE API_KEY
2 | DASHSCOPE_API_KEY=
3 | USE_CUDA=0
4 |
5 | # OSS AK SK
6 | OSS_ACCESS_KEY_ID=
7 | OSS_ACCESS_KEY_SECRET=
8 | OSS_BUCKET=
9 | OSS_ENDPOINT=
10 |
--------------------------------------------------------------------------------
/docker/compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | api:
3 | image: mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/mybigpai/pairag:1.1.0
4 | ports:
5 | - "8680:8680"
6 | restart: always
7 | environment:
8 | DASHSCOPE_API_KEY: ${DASHSCOPE_API_KEY}
9 | OSS_ACCESS_KEY_ID: ${OSS_ACCESS_KEY_ID}
10 | OSS_ACCESS_KEY_SECRET: ${OSS_ACCESS_KEY_SECRET}
11 | PAIRAG_RAG__oss_store__bucket: ${OSS_BUCKET}
12 | PAIRAG_RAG__oss_store__endpoint: ${OSS_ENDPOINT:-oss-cn-hangzhou.aliyuncs.com}
13 | ARMS_APP_NAME: ${ARMS_APP_NAME}
14 | ARMS_REGION_ID: ${ARMS_REGION_ID}
15 | ARMS_LICENSE_KEY: ${ARMS_LICENSE_KEY}
16 | ARMS_IS_PUBLIC: true
17 | ENABLE_FASTAPI: false
18 | ENABLE_REQUESTS: false
19 | ENABLE_AIOHTTPCLIENT: false
20 | ENABLE_HTTPX: false
21 | LD_LIBRARY_PATH: /usr/local/lib
22 |
23 | volumes:
24 | - ../model_repository:/app/model_repository
25 | - ./app_data:/app/localdata
26 | entrypoint: ["./scripts/start.sh", "-w", "1"]
27 | healthcheck:
28 | test: ["CMD", "curl", "-f", "http://localhost:8680/api/v1/health"]
29 | interval: 30s
30 | retries: 40
31 | start_period: 20s
32 |
--------------------------------------------------------------------------------
/docs/data_analysis_doc.md:
--------------------------------------------------------------------------------
1 | # 模型配置
2 |
3 | 在web界面 Settings 中,选择需要的LLM,如果选择DashScope(通义API),推荐使用qwen-max模型;如果选择PaiEas开源部署,推荐使用qwen2-72b-instruct模型
4 |
5 | 点击下方Save更新使用的模型
6 |
7 | 
8 |
9 | 点击web界面上方Data Analysis,进入到数据分析页面,支持两种类型的数据分析:连接数据库(mysql)分析 和 上传表格文件(excel/csv)分析
10 |
11 | 
12 |
13 | # 数据库分析配置
14 |
15 | ## 数据库连接
16 |
17 | 连接数据库,选择左上方数据分析类型为 database,出现数据库连接配置界面,如下图:
18 |
19 | 
20 |
21 | 其中,
22 |
23 | - Dialect为数据库类别,当前支持mysql,默认mysql
24 | - Username和Password分别为用户名和密码
25 | - Host为本地或远程数据库url,Port为端口,默认3306
26 | - Database为需要分析的目标数据库名称
27 | - Tables为需要分析的数据表,格式为:table_A, table_B,... ,默认为空,使用目标数据库中所有数据表
28 | - Descriptions为针对目标数据库中每张表的补充描述,比如对表中字段的进一步解释,可以提升数据分析效果,格式为:{"table_A":"字段a表示xxx,字段b数据的格式为yyy","table_B":"这张表主要用于zzz"},注意:需要使用英文输入法下的字典格式(英文双引号,冒号,逗号),默认为空
29 | - 支持三种prompt template用于生成sql查询,点击选项后可以看到具体的prompt template,也可在custom中自定义
30 |
31 | **注意:** 如自定义prompt,请尽量参考template修改,其中中大括号{}的内容为输入参数,需要保留
32 |
33 | 确认输入无误后,可直接在右侧chatbot中开始提问,回答结果的Reference中可以看到查询的数据库表名称,生成的sql语句,以及该sql语句是否有效执行。**注意:** 这里有效执行是指sql语句语法有效,不代表业务逻辑一定正确
34 |
35 | Reference可以作为查询效果优化的"debug"工具
36 |
37 | 
38 |
39 | ## 查询效果优化
40 |
41 | ### Description
42 |
43 | 针对数据表中字段含义不清晰,或者字段存储内容格式不清晰等问题,可以在Descriptions中增加相应描述,帮助llm更准确提取数据表内容,此处以公开数据集Spider中my_pets数据库为例,其中pets表数据如下:
44 |
45 | 
46 |
47 | 问答效果对比:
48 |
49 | 当描述为空时,对问题“有几只狗”生成的sql查询语句为:SELECT COUNT(\*) FROM pets WHERE PetType = '狗',查询不到
50 |
51 | 
52 |
53 | 增加简单描述后,生成的sql查询语句为:SELECT COUNT(\*) FROM pets WHERE PetType = 'dog',可以准确回答
54 |
55 | 
56 |
57 | 如果查询效果有明显改善,可以将相应的补充描述在数据库中作为相应table或column的comment持久化添加
58 |
59 | ### Prompt
60 |
61 | 此外,还可以通过调整prompt template,优化查询效果
62 |
63 | - Reference观察到生成的sql语句包含了非sql的其他内容,如开头多了"sql"等(部分小模型可能会发生),可以在prompt中增加简单限制
64 | - Reference观察到生成的sql语句不满足某些业务逻辑,可以在prompt中给出示例,通过few-shot learning,也可以很快提升效果
65 |
66 | # 表格文件分析配置
67 |
68 | 表格文件配置相对简单,选择左上方的分析类型为:datafile,出现以下界面
69 |
70 | 
71 |
72 | 点击左侧中部的上传,一次上传一份表格文件(excel或csv格式),上传成功后,左侧下方会出现文件的前几行预览,如下图所示:
73 |
74 | 
75 |
76 | 上传表格文件后可以直接在右侧chatbot中提问,如需更换表格,重新上传所需表格即可
77 |
--------------------------------------------------------------------------------
/docs/develop/local_develop_zh.md:
--------------------------------------------------------------------------------
1 | 如果需要在本地进行开发运行,请参考以下步骤:
2 |
3 | ## 本地启动
4 |
5 | 1. 克隆仓库
6 |
7 | ```bash
8 | git clone git@github.com:aigc-apps/PAI-RAG.git
9 | ```
10 |
11 | 2. 配置开发环境
12 |
13 | 本项目使用poetry进行管理,若在本地环境下使用,建议在安装环境之前先创建一个空环境。为了确保环境一致性并避免因Python版本差异造成的问题,我们指定Python版本为3.11。
14 |
15 | ```bash
16 | conda create -n rag_env python==3.11
17 | conda activate rag_env
18 | ```
19 |
20 | 如果使用macOS且需要处理PPTX文件,需要下载依赖库处理PPTX文件
21 |
22 | ```bash
23 | brew install mono-libgdiplus
24 | ```
25 |
26 | 直接使用poetry安装项目依赖包:
27 |
28 | ```bash
29 | pip install poetry
30 | poetry install
31 | pip install magic-pdf[full]==1.3.10
32 | pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4
33 | ```
34 |
35 | 安装filebrowser
36 |
37 | ```bash
38 | wget https://eas-data.oss-cn-shanghai.aliyuncs.com/3rdparty/sdwebui/filebrowser
39 | mv filebrowser /bin/filebrowser
40 | chmod u+x /bin/filebrowser
41 | ```
42 |
43 | - 常见网络超时问题
44 |
45 | 注:在安装过程中,若遇到网络连接超时的情况,可以添加阿里云或清华的镜像源,在 pyproject.toml 文件末尾追加以下几行:
46 |
47 | ```bash
48 | [[tool.poetry.source]]
49 | name = "mirrors"
50 | url = "http://mirrors.aliyun.com/pypi/simple/" # 阿里云
51 | # url = "https://pypi.tuna.tsinghua.edu.cn/simple/" # 清华
52 | priority = "default"
53 | ```
54 |
55 | 之后,再依次执行以下命令:
56 |
57 | ```bash
58 | poetry lock
59 | poetry install
60 | pip install magic-pdf[full]==1.3.10
61 | pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4
62 | ```
63 |
64 | 3. 启动RAG服务
65 |
66 | 使用DashScope API,需要在命令行引入环境变量:
67 |
68 | ```bash
69 | export DASHSCOPE_API_KEY="xxx"
70 | ```
71 |
72 | 请替换xxx为你自己的DASHSCOPE_API_KEY,DASHSCOPE_API_KEY获取地址为 https://dashscope.console.aliyun.com/apiKey
73 |
74 | 启动:
75 |
76 | ```bash
77 | # 启动,支持自定义hport(默认8680), worker_num(默认1)
78 | # 默认启动时下载模型 [bge-m3, pdf-extract]
79 | # 可使用命令行 "load_model" 下载模型,包括 [bge-m3, pdf-extract, SGPT-125M-weightedmean-nli-bitfit, bge-large-zh-v1.5, bge-reranker-base, bge-reranker-large, paraphrase-multilingual-MiniLM-L12-v2, qwen_1.8b, text2vec-large-chinese]
80 | ./scripts/start.sh [-w WORKER_NUM] [-p PORT]
81 | ```
82 |
83 | ```bash
84 | ./scripts/start.sh
85 | ```
86 |
87 | 你可以打开http://localhost:8680/ 来配置RAG服务以及上传本地数据。
88 |
--------------------------------------------------------------------------------
/docs/docker_build.md:
--------------------------------------------------------------------------------
1 | # Docker build
2 |
3 | ## Server
4 |
5 | ### CPU
6 |
7 | ```bash
8 | docker build -f Dockerfile -t rag_serve:0.1 .
9 | ```
10 |
11 | ### GPU
12 |
13 | ```bash
14 | docker build -f Dockerfile_gpu -t rag_serve:0.1_gpu .
15 | ```
16 |
17 | ## UI
18 |
19 | ```bash
20 | docker build -f Dockerfile_ui -t rag_ui:0.1 .
21 | ```
22 |
23 | ## Nginx
24 |
25 | ```bash
26 | docker build -f Dockerfile_nginx -t rag_nginx:0.1 .
27 | ```
28 |
29 | # 常见问题
30 |
31 | ## docker pull timeout
32 |
33 | 建议更换docker镜像源为阿里云镜像,在阿里云在容器镜像服务 -> 镜像工具 -> 镜像加速器 中可以找到阿里云的专属镜像加速器,按照指示说明修改daemon配置文件来使用加速器即可。
34 |
--------------------------------------------------------------------------------
/docs/eas_deploy.md:
--------------------------------------------------------------------------------
1 | # EAS自定义部署RAG服务
2 |
3 | 模型在线服务EAS(Elastic Algorithm Service)是阿里云PAI产品为实现一站式模型开发部署应用,针对在线推理场景提供的模型在线服务,支持将模型服务部署在公共资源组或专属资源组,实现基于异构硬件(CPU和GPU)的模型加载和数据请求的实时响应。
4 |
5 | 我们支持通过`场景化部署`和`自定义部署`两种方式来一键部署RAG服务,其中,
6 |
7 | - `场景化部署`: 更加方便,只需要配置几个参数即可完成。可参考[场景化部署文档](https://help.aliyun.com/zh/pai/user-guide/deploy-a-rag-based-dialogue-system)。
8 | - `自定义部署`: 可以更灵活地配置服务,比如,部署GPU版本镜像,配置链路追踪服务等等。可参考[自定义部署文档](https://help.aliyun.com/zh/pai/use-cases/custom-deployment-of-rag-service#47e8104831b4f)。
9 |
--------------------------------------------------------------------------------
/docs/figures/agent/agenda_query.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/agenda_query.png
--------------------------------------------------------------------------------
/docs/figures/agent/agent_tab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/agent_tab.png
--------------------------------------------------------------------------------
/docs/figures/agent/date_query.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/date_query.png
--------------------------------------------------------------------------------
/docs/figures/agent/nl2sql_query.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/nl2sql_query.png
--------------------------------------------------------------------------------
/docs/figures/agent/rag_knowledge_query.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/rag_knowledge_query.png
--------------------------------------------------------------------------------
/docs/figures/agent/weather_query.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/weather_query.png
--------------------------------------------------------------------------------
/docs/figures/agent/web_search.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/web_search.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/DBChat_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/DBChat_overview.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_db_chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_chat.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_db_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_config.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_db_enhance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_enhance.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_db_load.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_load.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_db_prompt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_prompt.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_db_prompt_reset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_prompt_reset.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_llm_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_llm_config.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_overview.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_sheet_chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_sheet_chat.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/da_sheet_upload.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_sheet_upload.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/data_analysis_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/data_analysis_overview.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/datafile_chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/datafile_chat.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/datafile_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/datafile_config.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_chat.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_chat_with_memo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_chat_with_memo.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_config.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_config_update.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_config_update.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_enhanced_features.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_enhanced_features.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_info_load.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_info_load.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_query_desc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_query_desc.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/db_query_no_desc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_query_no_desc.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/llm_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/llm_config.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/llm_selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/llm_selection.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/prompt_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/prompt_config.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/prompt_reset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/prompt_reset.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/sheet_data_preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/sheet_data_preview.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/sheet_upload.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/sheet_upload.png
--------------------------------------------------------------------------------
/docs/figures/data_analysis/table_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/table_example.png
--------------------------------------------------------------------------------
/docs/figures/deepseek/deepseek_eas_api.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/deepseek_eas_api.png
--------------------------------------------------------------------------------
/docs/figures/deepseek/llm_chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/llm_chat.png
--------------------------------------------------------------------------------
/docs/figures/deepseek/llm_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/llm_config.png
--------------------------------------------------------------------------------
/docs/figures/deepseek/rag_chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/rag_chat.png
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/deploy_json.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_json.png
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/deploy_portal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_portal.png
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/deploy_resources.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_resources.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/deploy_success.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_success.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/deploy_vpc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_vpc.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/edited_json.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/edited_json.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/enable_web.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/enable_web.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/trace_detail.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_detail.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/trace_json.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_json.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/trace_key.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_key.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/trace_percent.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_percent.jpg
--------------------------------------------------------------------------------
/docs/figures/deploy/eas/view_web.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/view_web.jpg
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_ik_hot_update.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_ik_hot_update.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_instance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_instance.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_instance_autoindex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_instance_autoindex.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_instance_info.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_instance_info.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_kibana.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_kibana_entry.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_entry.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_kibana_index_detail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_index_detail.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_kibana_index_management.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_index_management.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_kibana_management.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_management.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_kibana_menu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_menu.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_kibana_whitelist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_whitelist.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_overview.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_password.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_password.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_password_reset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_password_reset.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_plugin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_plugin.png
--------------------------------------------------------------------------------
/docs/figures/elastic/aliyun_es_upload_dic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_upload_dic.png
--------------------------------------------------------------------------------
/docs/figures/elastic/new_word_dict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/new_word_dict.png
--------------------------------------------------------------------------------
/docs/figures/elastic/pairag_es_connect.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_es_connect.png
--------------------------------------------------------------------------------
/docs/figures/elastic/pairag_es_param_list.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_es_param_list.png
--------------------------------------------------------------------------------
/docs/figures/elastic/pairag_retrieval_mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_retrieval_mode.png
--------------------------------------------------------------------------------
/docs/figures/elastic/pairag_select_es.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_select_es.png
--------------------------------------------------------------------------------
/docs/figures/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/framework.jpg
--------------------------------------------------------------------------------
/docs/figures/index/add_index.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/index/add_index.png
--------------------------------------------------------------------------------
/docs/figures/index/config_new_index.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/index/config_new_index.png
--------------------------------------------------------------------------------
/docs/figures/index/select_index.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/index/select_index.png
--------------------------------------------------------------------------------
/docs/figures/multimodal/mm_chat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/multimodal/mm_chat.jpg
--------------------------------------------------------------------------------
/docs/figures/multimodal/mm_settings.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/multimodal/mm_settings.jpg
--------------------------------------------------------------------------------
/docs/figures/multimodal/mm_upload.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/multimodal/mm_upload.jpg
--------------------------------------------------------------------------------
/docs/figures/quick_start/query.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/quick_start/query.png
--------------------------------------------------------------------------------
/docs/figures/quick_start/setting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/quick_start/setting.png
--------------------------------------------------------------------------------
/docs/figures/quick_start/upload.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/quick_start/upload.png
--------------------------------------------------------------------------------
/docs/index/multi_index.md:
--------------------------------------------------------------------------------
1 | # 多知识库索引支持
2 |
3 | 当需要多个知识库来存储不同类型/业务的知识文件时,可以使用我们的多知识库支持api/ui来管理知识库索引。
4 | 目前支持以下功能:
5 |
6 | 1. 新增一个知识库index
7 | 2. 修改创建好的知识库index
8 | 3. 删除一个知识库index
9 | 4. 检索/查询对应的知识库
10 | 5. 支持不同的向量数据库后端,faiss, elasticsearch等等。
11 |
12 | ## 快速使用
13 |
14 | 可以打开UI页面快速创建、上传、检索知识库。
15 |
16 | ### 创建
17 |
18 | 1. Index下拉列表选择新建"New"
19 |
20 |
21 |
22 | 2. 填写索引名称,选择embedding和向量数据库类型,点击"Add Index"
23 |
24 |
25 |
26 | 创建完成。
27 |
28 | ### 上传知识库和查询
29 |
30 | 可以通过左边的选择器选择对应的index_name进行操作:
31 |
32 |
33 |
34 | ## API 使用
35 |
36 | 注意,目前查询和上传API均可以指定index_name来切换知识库,当index_name参数省略时,默认为**default**知识库。
37 |
38 | 目前,删除index操作仅仅支持通过API删除。
39 |
40 | ### 查询接口(Query & Retrieval)
41 |
42 | **Retrieval**
43 |
44 | ```sh
45 | curl -X POST http://localhost:8680/service/query/retrieval -H "Content-Type: application/json" -d '{"question": "什么是组件化", "index_name": "default"}'
46 | ```
47 |
48 | **Query**
49 |
50 | ```sh
51 | curl -X POST http://localhost:8680/service/query -H "Content-Type: application/json" -d '{"question": "什么是组件化", "index_name": "default"}'
52 | ```
53 |
54 | ### 上传接口(Upload)
55 |
56 | ```sh
57 | curl -X POST http://localhost:8680/service/upload_data -H 'Content-Type: multipart/form-data' -F 'files=@/localpath/PAI.txt' -F "index_name=es_test_1"
58 | ```
59 |
60 | ### List Index
61 |
62 | ```sh
63 | curl -X GET http://localhost:8680/service/indexes
64 | ```
65 |
66 | ### Add Index
67 |
68 | ```sh
69 | curl -X POST http://localhost:8680/service/indexes/index3 -H "Content-Type: application/json" -d '{"index_name":"index3","vector_store_config":{"persist_path":"localdata/storage3","type":"faiss","is_image_store":false},"embedding_config":{"source":"dashscope","embed_batch_size":10}}'
70 | ```
71 |
72 | response:
73 |
74 | ```json
75 | { "msg": "Add index 'index3' successfully." }
76 | ```
77 |
78 | ### Update Index
79 |
80 | ```sh
81 | curl -X PATCH http://localhost:8680/service/indexes/index3 -H "Content-Type: application/json" -d '{"index_name":"index3","vector_store_config":{"persist_path":"localdata/storage4","type":"faiss","is_image_store":false},"embedding_config":{"source":"dashscope","embed_batch_size":10}}'
82 | ```
83 |
84 | response:
85 |
86 | ```json
87 | { "msg": "Update index 'index3' successfully." }
88 | ```
89 |
90 | ### Delete Index
91 |
92 | ```sh
93 | curl -X DELETE http://localhost:8680/service/indexes/index3
94 | ```
95 |
96 | response:
97 |
98 | ```json
99 | { "msg": "Update index 'index3' successfully." }
100 | ```
101 |
--------------------------------------------------------------------------------
/docs/multimodal_rag.md:
--------------------------------------------------------------------------------
1 | # 多模态问答
2 |
3 | 很多时候,知识库的文档中不只是纯文本信息,还包含很多图文交错的pdf、word、markdown等文件,甚至一些海报之类的纯图片文件。
4 | 普通的RAG流程会忽略这些图片输入,仅仅使用文本信息,这样会出现很多信息丢失的情况。
5 | 这里我们通过使用多模态模型来实现图文混合的多模态问答。
6 |
7 | ## 配置多模态LLM和Aliyun OSS存储
8 |
9 | 首先我们需要配置多模态LLM,这里我们推荐使用DASHSCOPE VLLM,或者部署在PAI-EAS的兼容openai协议的VLLM模型,比如开源的qwen2-vl。
10 |
11 | 然后需要添加一个Aliyun的OSS存储,来存储图片文件信息。这样在结果输出时,可以通过图片链接的方式在回复中展示图片。
12 |
13 | 配置示例如图, 配置完保存即可。
14 |
15 |
16 |
17 | ## 上传多模态文件
18 |
19 | 这里支持多种多模态文件格式,包括pdf, markdown, word, ppt, png, jpg等。
20 |
21 | 这里需要勾选`Process with MultiModal` 选项才会处理文件中的图片信息,处理pdf时,建议勾选`Process PDF with OCR` 选项。
22 |
23 |
24 |
25 | ## 问答测试
26 |
27 | 勾选`Display Image`选项,就可以进行多模态问答。
28 |
29 | 同时,你可以调整下方的Multimodal Prompt来优化问答提示词。
30 |
31 |
32 |
--------------------------------------------------------------------------------
/docs/tabular_doc.md:
--------------------------------------------------------------------------------
1 | # Tabular processing with PAI-RAG
2 |
3 | ## PaiCSVReader
4 |
5 | PaiCSVReader(concat_rows=True, row_joiner="\n", csv_config={})
6 |
7 | ### Parameters:
8 |
9 | **concat_rows:** _bool, default=True._
10 | Whether to concatenate rows into one document.
11 |
12 | **row_joiner:** _str, default="\n"._
13 | The separator used to join rows.
14 |
15 | **header:** _None or int, list of int, default 0._
16 | row (0-indexed) to use for the column labels of the parsed DataFrame. If a list of integers is passed those row
17 | positions will be combined into a MultiIndex. Use None if there is no header.
18 |
19 | ### Functions:
20 |
21 | load_data(file: Path, extra_info: Optional[Dict] = None, fs: Optional[AbstractFileSystem] = None)
22 |
23 | ## PaiPandasCSVReader
24 |
25 | PaiPandasCSVReader(concat_rows=True, row_joiner="\n", pandas_config={})
26 |
27 | ### Parameters:
28 |
29 | **concat_rows:** _bool, default=True._
30 | Whether to concatenate rows into one document.
31 |
32 | **row_joiner:** _str, default="\n"._
33 | The separator used to join rows.
34 |
35 | **pandas_config:** _dict, default={}._
36 | The configuration of pandas.read_csv.
37 | Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
38 | Set to empty dict by default, this means pandas will try to figure out the separators, table head, etc. on its own.
39 |
40 | #### one important parameter:
41 |
42 | **header:** _None or int, list of int, default 0._
43 | Row (0-indexed) to use for the column labels of the parsed DataFrame. If a list of integers is passed those row
44 | positions will be combined into a MultiIndex. Use None if there is no header.
45 |
46 | ### Functions:
47 |
48 | load_data(file: Path, extra_info: Optional[Dict] = None, fs: Optional[AbstractFileSystem] = None)
49 |
50 | ## PaiPandasExcelReader
51 |
52 | PaiPandasExcelReader(concat_rows=True, row_joiner="\n", pandas_config={})
53 |
54 | ### Parameters:
55 |
56 | **concat_rows:** _bool, default=True._
57 | Whether to concatenate rows into one document.
58 |
59 | **row_joiner:** _str, default="\n"._
60 | The separator used to join rows.
61 |
62 | **pandas_config:** _dict, default={}._
63 | The configuration of pandas.read_csv.
64 | Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html for more information.
65 | Set to empty dict by default, this means pandas will try to figure out the separators, table head, etc. on its own.
66 |
67 | #### one important parameter:
68 |
69 | **header:** _None or int, list of int, default 0._
70 | Row (0-indexed) to use for the column labels of the parsed DataFrame. If a list of integers is passed those row
71 | positions will be combined into a MultiIndex. Use None if there is no header.
72 |
73 | ### Functions:
74 |
75 | load_data(file: Path, extra_info: Optional[Dict] = None, fs: Optional[AbstractFileSystem] = None)
76 | only process the first sheet
77 |
--------------------------------------------------------------------------------
/example_data/cn_clip/pokemon.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/cn_clip/pokemon.jpeg
--------------------------------------------------------------------------------
/example_data/eval_docs_image/跑鞋推荐.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/eval_docs_image/跑鞋推荐.pdf
--------------------------------------------------------------------------------
/example_data/eval_docs_image_example/multimodal_eval_dataset_zh_example.jsonl:
--------------------------------------------------------------------------------
1 | {"id":"qca_03d5b51c-0308-436d-9a2f-a8389d821001","query":{"query_text":"2023年春夏期间,哪种泛户外运动在社交平台上的讨论声量最高?","source":{"name":"manual","model":"manual"}},"contexts":[{"node_id":"a09e96d9263b02b2b12a927144575af9bf31776e84e78ce35acab5b6619c84ac","type":"TextNode","text":"在赛事和政策的双重推动下,国民运动户外参与意愿高涨,超过六成的受访者表示近一年显著增加了运动户外的频率,各类运动项目正在快速走向“全民化”。新的一年,随着巴黎奥运会、美洲杯等赛事的举办,全民运动热情将进一步被激发。对于品牌而言,这是一个难得的市场机遇,通过精准地选中和锁定与运动相关的目标人群,品牌可以有效地实现用户收割。 \n\n \n\n悦己驱动,运动边界向轻量泛户外持续延伸 \n\n国民参与运动户外活动更多来自“悦己”观念的驱动,近7成的受访者表示他们主要是为了“强身健体/享受大自然”,因此轻量级、易开展的活动项目更受广大普通受众的青睐。近三年,社交平台关于“泛户外运动”的讨论热度持续走高,更是在23年春夏期间迎来一波小高峰:细分到具体的活动项目上,垂钓讨论声量较高;露营也保持较高声量,其经历过22年的大爆发、23年的行业调整,预计24年已经进入更深精细化运营;此外城市骑行热度也在不断上升,成为当下新兴的小众活动。","metadata":{"image_url_list":["https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/2024春夏淘宝天猫运动户外行业趋势白皮书_淘宝/d4e624aceb4043839c924e33c075e388.jpeg", "https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/2024春夏淘宝天猫运动户外行业趋势白皮书_淘宝/52d1353d4577698891e7710ae12e18b1.jpeg", "https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/2024春夏淘宝天猫运动户外行业趋势白皮书_淘宝/4f77ded6421ddadd519ab9ef1601a784.jpeg"]}}],"answer":{"answer_text":"根据给定的材料,2023年春夏期间,垂钓在社交平台上的讨论声量最高。\n\n","answer_image_url_list":null,"source":{"name":"manual","model":"manual"}}}
2 | {"id":"qca_03d5b51c-0308-436d-9a2f-a8389d821002","query":{"query_text":"狄尔泰属于哪个教育学流派?","source":{"name":"manual","model":"manual"}},"contexts":[{"node_id":"a09e96d9263b02b2b12a927144575af9bf31776e84e78ce35acab5b6619c84ab","type":"TextNode","text":"","metadata":{"image_url_list":["https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/教育文档/思维导图.jpeg"]}}],"answer":{"answer_text":"狄尔泰属于文化教育学流派。\n\n","answer_image_url_list":null,"source":{"name":"manual","model":"manual"}}}
3 |
--------------------------------------------------------------------------------
/example_data/eval_docs_text/EasyRec.txt:
--------------------------------------------------------------------------------
1 | EasyRec是一个易于使用的推荐框架¶
2 | EasyRec 实现了常见推荐任务中使用的最先进的机器学习模型:候选生成(匹配)、评分(排名)和多任务学习。 它通过简单的配置和超参数调整(HPO)提高了生成高性能模型的效率。
3 |
4 | EasyRec视频介绍
5 | 为什么选择 EasyRec?¶
6 | 到处运行¶
7 | MaxCompute / 数据科学 / DLC / 本地
8 | TF1.12-1.15 / TF2.x / PAI-TF
9 | 多样的输入数据¶
10 | MaxCompute表
11 | HDFS 文件
12 | 操作系统文件
13 | 卡夫卡流
14 | 本地 CSV
15 |
16 | 配置简单¶
17 | 灵活的功能配置和简单的模型配置
18 | 高效、鲁棒的特征生成[淘宝使用]
19 | 漂亮的网络界面正在开发中
20 |
21 | 它很聪明¶
22 | EarlyStop / 最佳检查站保护程序
23 | 超参数搜索/AutoFeatureCross
24 | 开发中:NAS、知识蒸馏、多模式
25 |
26 | 规模大、部署方便¶
27 | 支持大规模嵌入,增量保存
28 | 许多并行策略:ParameterServer、Mirrored、MultiWorker
29 | 轻松部署到 EAS:自动扩展、轻松监控
30 | 一致性保证:训练和服务
31 |
32 | 多种模型
33 | DSSM / MIND / DropoutNet / CoMetricLearningI2I / PDN
34 | W&D / DeepFM / MultiTower / DCN / DIN / BST
35 | MMoE / ESMM / DBMTL / PLE
36 | CMBF / 联合
37 |
38 | 易于定制¶
39 | 易于实现定制模型
40 | 无需关心数据管道
41 | 快速向量检索¶
42 | 在分布式环境中运行向量的 knn 算法
43 |
44 | 欢迎加入【EasyRec推荐算法交流群】,钉钉群号 : 32260796
45 |
--------------------------------------------------------------------------------
/example_data/eval_docs_text/PAI.txt:
--------------------------------------------------------------------------------
1 | 机器学习PAI(Platform of Artificial Intelligence)是阿里云人工智能平台,提供一站式的机器学习解决方案。本文为您介绍什么是机器学习PAI。
2 |
3 | 什么是机器学习
4 | 机器学习是指机器通过统计学算法,对大量历史数据进行学习,进而利用生成的经验模型指导业务。目前机器学习主要应用在以下场景:
5 | 营销类场景:商品推荐、用户群体画像或广告精准投放。
6 | 金融类场景:贷款发放预测、金融风险控制、股票走势预测或黄金价格预测。
7 | 社交网络服务关系挖掘场景:微博粉丝领袖分析或社交关系链分析。
8 | 文本类场景:新闻分类、关键词提取、文章摘要或文本内容分析。
9 | 非结构化数据处理场景:图片分类或图片文本内容提取。
10 | 其它各类预测场景:降雨预测或足球比赛结果预测。
11 | 机器学习包括传统机器学习和深度学习。传统机器学习分为以下几类:
12 | 有监督学习(Supervised Learning):每个样本都有对应的期望值,通过搭建模型,实现从输入特征向量到目标值的映射。例如解决回归和分类问题。
13 | 无监督学习(Unsupervised Learning):所有样本没有目标值,期望从数据本身发现一些潜在规律。例如解决聚类问题。
14 | 增强学习(Reinforcement Learning):相对比较复杂,系统和外界环境不断交互,根据外界反馈决定自身行为,达到目标最优化。例如阿尔法围棋和无人驾驶。
15 |
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/上海攻略信息.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/上海攻略信息.pdf
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/北京攻略信息.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/北京攻略信息.pdf
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_chat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_chat.jpg
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_config.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_config.jpg
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/settings.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/settings.jpg
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/upload.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/upload.jpg
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/settings.toml:
--------------------------------------------------------------------------------
1 | dynaconf_merge = true
2 |
3 | [rag]
4 | name = "pairag"
5 | version = "0.1.1"
6 |
7 | [rag.agent]
8 | type = "function_calling"
9 |
10 | [rag.agent.custom_config]
11 | agent_file_path = "example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant"
12 |
13 | [rag.agent.intent_detection]
14 | type = "single"
15 |
16 | [rag.agent.tool]
17 | type = "api"
18 |
19 | [rag.chat_store]
20 | type = "Local" # [Local, Aliyun-Redis]
21 | host = "Aliyun-Redis host"
22 | password = "Aliyun-Redis user:pwd"
23 | persist_path = "localdata/storage"
24 |
25 | [rag.data_reader]
26 | type = "SimpleDirectoryReader"
27 |
28 | # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace
29 | # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model
30 | # eg.
31 | # source = "HuggingFace"
32 | # model = "bge-m3"
33 | # embed_batch_size = 10
34 | [rag.embedding]
35 | source = "DashScope"
36 | embed_batch_size = 10
37 |
38 | [rag.embedding.multi_modal]
39 | source = "cnclip"
40 |
41 | [rag.index]
42 | persist_path = "localdata/storage"
43 | vector_store.type = "FAISS"
44 |
45 | # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment
46 | # eg.
47 | # source = "PaiEas"
48 | # endpoint = ""
49 | # token = ""
50 | [rag.llm]
51 | source = "DashScope"
52 | model = "qwen-max"
53 |
54 | [rag.llm.function_calling_llm]
55 | source = "DashScope"
56 | model = "qwen2-7b-instruct"
57 |
58 | [rag.llm.multi_modal]
59 | source = ""
60 |
61 | [rag.node_enhancement]
62 | tree_depth = 3
63 | max_clusters = 52
64 | proba_threshold = 0.10
65 |
66 | [rag.node_parser]
67 | type = "Sentence"
68 | chunk_size = 500
69 | chunk_overlap = 10
70 |
71 | [rag.postprocessor]
72 | reranker_type = "simple-weighted-reranker" # [simple-weighted-reranker, model-based-reranker]
73 | reranker_model = "bge-reranker-base" # [bge-reranker-base, bge-reranker-large]
74 | keyword_weight = 0.3
75 | vector_weight = 0.7
76 | similarity_threshold = 0.5
77 | top_n = 2
78 |
79 | [rag.query_engine]
80 | type = "RetrieverQueryEngine"
81 |
82 | [rag.retriever]
83 | similarity_top_k = 3
84 | retrieval_mode = "hybrid" # [hybrid, embedding, keyword, router]
85 | query_rewrite_n = 1 # set to 1 to disable query generation
86 |
87 | [rag.synthesizer]
88 | type = "SimpleSummarize"
89 | text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n"
90 |
--------------------------------------------------------------------------------
/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import os
3 | from loguru import logger
4 |
5 |
def get_place_weather(city: str) -> str:
    """Get city name and return city weather.

    Queries the OpenWeatherMap forecast API and formats the first forecast
    entry as a human-readable string.

    Args:
        city: City name, passed to the OpenWeatherMap ``q`` query parameter.

    Returns:
        A Chinese-language weather summary, or an error message string when
        the API call fails (agent tools should return text, not raise).
    """
    logger.info(f"[Agent] Checking realtime weather info for {city}")

    # The key can also be assigned directly here; the original config only
    # carried the tool type.
    api_key = os.environ.get("weather_api_key")

    base_url = "http://api.openweathermap.org/data/2.5/forecast?"
    complete_url = f"{base_url}q={city}&appid={api_key}&lang=zh_cn&units=metric"
    logger.info(f"Requesting {complete_url}...")
    try:
        response = requests.get(complete_url, timeout=5)
        weather_data = response.json()
    except requests.RequestException as ex:
        # Network failure / timeout / invalid JSON: report instead of crashing
        # the agent loop.
        logger.error(f"获取天气信息失败,错误信息:{ex}")
        return f"获取天气信息失败,错误信息:{ex}"

    if weather_data["cod"] != "200":
        logger.error(
            f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}"
        )
        return f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}"

    # Only the first (nearest-in-time) forecast entry is reported.
    element = weather_data["list"][0]

    return f"""
    {city}的天气:
    时间: {element['dt_txt']}
    温度: {element['main']['temp']} °C
    天气描述: {element['weather'][0]['description']}
    """
33 |
--------------------------------------------------------------------------------
/example_data/pai_document.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/pai_document.pdf
--------------------------------------------------------------------------------
/integrations/pairag-file/README.md:
--------------------------------------------------------------------------------
1 | # pairag.file
2 |
3 | This module contains the readers for the different file types and node parsers to chunk parsed files.
4 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/nodeparsers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/nodeparsers/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/nodeparsers/pai/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/nodeparsers/pai/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/nodeparsers/pai/constants.py:
--------------------------------------------------------------------------------
# Default settings for the node parsers (text splitters).

# Parser flavor used when none is configured.
DEFAULT_NODE_PARSER_TYPE = "Sentence"
# paragraph separator for splitter
DEFAULT_PARAGRAPH_SEP = "\n\n"
# Overlap between adjacent sentence chunks.
DEFAULT_SENTENCE_CHUNK_OVERLAP = 200
# Window size for the sentence-window parser.
DEFAULT_SENTENCE_WINDOW_SIZE = 3
# NOTE(review): presumably a percentile breakpoint (0-100) for semantic
# splitting — confirm against the parser that consumes it.
DEFAULT_BREAKPOINT = 95
# NOTE(review): presumably the semantic-splitter buffer size — confirm.
DEFAULT_BUFFER_SIZE = 1
8 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/nodeparsers/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/nodeparsers/utils/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/pai/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/constants.py:
--------------------------------------------------------------------------------
# File extensions the PAI data readers accept for ingestion.
# Set literal instead of set([...]) — same contents, no throwaway list.
ACCEPTABLE_DOC_TYPES = {
    ".html",
    ".htm",
    ".txt",
    ".docx",
    ".pdf",
    ".pptx",
    ".md",
    ".xls",
    ".jsonl",
    ".csv",
    ".xlsx",
    ".jpg",
    ".jpeg",
    ".png",
}
19 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/file_readers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/pai/file_readers/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/file_readers/pai_image_reader.py:
--------------------------------------------------------------------------------
1 | """Tabular parser-CSV parser.
2 |
3 | Contains parsers for tabular data files.
4 |
5 | """
6 |
7 | from pathlib import Path
8 | from typing import Any, Dict, List, Optional
9 | from fsspec import AbstractFileSystem
10 | import os
11 | from llama_index.core.readers.base import BaseReader
12 | from llama_index.core.schema import Document, ImageDocument
13 |
14 | from pairag.file.readers.pai.utils.image_utils import image_from_url
15 | from pairag.file.store.pai_image_store import PaiImageStore
16 |
17 |
class PaiImageReader(BaseReader):
    """Image parser.

    Uploads the image at the given path to the configured image store and
    wraps the resulting URL in an ImageDocument.

    Args:
        image_store: PaiImageStore used to persist images and obtain URLs.
    """

    def __init__(self, image_store: PaiImageStore, *args: Any, **kwargs: Any) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)
        self.image_store = image_store

    def load_data(
        self,
        file_path: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Upload the image at *file_path* and return a single ImageDocument.

        Raises:
            ValueError: when no image store (OSS config) is configured —
                images are referenced by URL, so a store is mandatory.
        """
        if self.image_store is None:
            # ValueError is more specific than bare Exception and remains
            # catchable by any caller handling Exception.
            raise ValueError(
                f"Oss config must be provided for image processing for file {file_path}."
            )

        file_name = os.path.basename(file_path)
        image_url = self.image_store.upload_image(
            image_from_url(file_path), doc_name="image_docs"
        )
        if extra_info is None:
            extra_info = {}
        extra_info["file_path"] = str(file_path)
        extra_info["file_name"] = file_name
        extra_info["image_url"] = image_url
        image_doc = ImageDocument(image_url=image_url, extra_info=extra_info)

        return [image_doc]
55 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/file_readers/pai_jsonl_reader.py:
--------------------------------------------------------------------------------
1 | """Tabular parser-Excel parser.
2 |
3 | Contains parsers for tabular data files.
4 |
5 | """
6 |
7 | import os
8 | from pathlib import Path
9 | from typing import Any, Dict, List, Optional
10 | from fsspec import AbstractFileSystem
11 | from llama_index.core.readers.base import BaseReader
12 | from llama_index.core.schema import Document
13 |
14 |
class PaiJsonLReader(BaseReader):
    """JsonL reader.

    Produces one Document per line of the input file; the stripped raw line
    text becomes the document body.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)

    def load_data(
        self,
        file_path: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Read *file_path* and return one Document per line.

        Each document's metadata carries file_path, file_name, and a 1-based
        row_number.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            json_lines = [line.strip() for line in file]

        base_info = dict(extra_info or {})
        base_info["file_path"] = str(file_path)
        base_info["file_name"] = os.path.basename(file_path)

        docs = []
        for i, text in enumerate(json_lines):
            # Build a fresh dict per row: reusing one shared dict would alias
            # every document's metadata to the same object, leaving all rows
            # with the last row_number.
            metadata = {**base_info, "row_number": i + 1}
            docs.append(Document(text=text, metadata=metadata))
        return docs
41 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/pai/utils/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/utils/cuda_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from loguru import logger
4 |
5 |
6 | USE_CUDA = os.environ.get("USE_CUDA", "false")
7 |
8 |
def should_use_cuda() -> bool:
    """Return True when CUDA is available and enabled via the USE_CUDA env var.

    CUDA is used only when both conditions hold: torch reports a usable CUDA
    device, and the module-level USE_CUDA setting is "true" (case-insensitive)
    or "1".
    """
    # Hardware/driver availability is a hard requirement regardless of config.
    if not torch.cuda.is_available():
        return False

    # "1".lower() == "1", so a single case-insensitive membership test covers
    # both accepted values.
    return USE_CUDA.lower() in ("true", "1")
17 |
18 |
def infer_cuda_device() -> str:
    """Return the torch device string ("cuda" or "cpu") to use for inference."""
    # Guard clause: fall back to CPU unless CUDA is both available and enabled.
    if not should_use_cuda():
        logger.info(
            "Will not use CUDA device acceleration. If you want to use cuda, please set the environment variable USE_CUDA=1."
        )
        return "cpu"

    logger.info("Using cuda device.")
    return "cuda"
28 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from PIL import Image
3 | from io import BytesIO
4 | from loguru import logger
5 | import requests
6 | from urllib.parse import urlparse
7 |
8 |
9 | def is_remote_url(url_or_path: str | Path) -> bool:
10 | result = urlparse(str(url_or_path))
11 | is_remote = result.scheme in ("http", "https", "ftp", "s3", "gs")
12 | return is_remote
13 |
14 |
def image_from_bytes(
    image_blob: bytes,
    image_filename_or_extension: str,
):
    """Decode an in-memory image blob into a PIL Image.

    Args:
        image_blob: Raw image bytes.
        image_filename_or_extension: File name or bare extension, used only to
            detect unsupported formats.

    Returns:
        A PIL Image, or None for Windows metafiles (EMF/WMF) which are
        skipped for now.
    """
    # str.endswith accepts a tuple of suffixes — one call instead of two.
    if image_filename_or_extension.lower().endswith((".emf", ".wmf")):
        # Windows metafiles are not processed for now.
        logger.warning(
            f"Skip processing EMF or WMF image: {image_filename_or_extension}"
        )
        return None

    return Image.open(BytesIO(image_blob))
29 |
30 |
def image_from_url(image_url: str):
    """Load an image from a remote URL or a local file path.

    Args:
        image_url: An http/https/ftp/s3/gs URL or a local filesystem path.

    Returns:
        A PIL Image on success, None when downloading or decoding fails
        (failures are logged, not raised).
    """
    if is_remote_url(image_url):
        try:
            # Bounded timeout so one slow server cannot hang ingestion
            # (consistent with the other HTTP calls in this project).
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()  # check that the request succeeded

            # Decode the binary payload into an image object.
            image = Image.open(BytesIO(response.content))
            return image
        except Exception as ex:
            logger.warning(
                f"Failed to download image from URL: {image_url}. Error: {ex}"
            )
            return None
    else:
        try:
            image = Image.open(image_url)
            return image
        except Exception as ex:
            logger.warning(f"Failed to open image from file: {image_url}. Error: {ex}")
            return None
52 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/pai/utils/magic-pdf.template.json:
--------------------------------------------------------------------------------
1 | {
2 | "bucket_info": {
3 | "bucket-name-1": ["ak", "sk", "endpoint"],
4 | "bucket-name-2": ["ak", "sk", "endpoint"]
5 | },
6 | "models-dir": "model_repository/PDF-Extract-Kit-1.0/models",
7 | "layoutreader-model-dir": "model_repository/PDF-Extract-Kit-1.0/models/layoutreader",
8 | "device-mode": "cpu",
9 | "layout-config": {
10 | "model": "doclayout_yolo"
11 | },
12 | "formula-config": {
13 | "mfd_model": "yolo_v8_mfd",
14 | "mfr_model": "unimernet_small",
15 | "enable": true
16 | },
17 | "table-config": {
18 | "model": "rapid_table",
19 | "sub_model": "slanet_plus",
20 | "enable": true,
21 | "max_time": 400
22 | },
23 | "latex-delimiter-config": {
24 | "display": {
25 | "left": "$$",
26 | "right": "$$"
27 | },
28 | "inline": {
29 | "left": "$",
30 | "right": "$"
31 | }
32 | },
33 | "config_version": "1.2.1"
34 | }
35 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/readers/readme.md:
--------------------------------------------------------------------------------
1 | # Readers integrations
2 |
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/store/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/store/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/pairag/file/store/pai_image_store.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from io import BytesIO
4 | import math
5 | from PIL.PngImagePlugin import PngImageFile
6 | from loguru import logger
7 |
8 | from pairag.file.store.oss_store import PaiOssStore
9 |
# Upper bound on stored image pixel count; larger images are downscaled
# (aspect ratio preserved) before upload.
IMAGE_MAX_PIXELS = 512 * 512
11 |
12 |
class PaiImageStore:
    """Persists PIL images to an OSS-backed store and returns public URLs."""

    def __init__(
        self, oss_store: PaiOssStore = None, save_prefix="pai_oss_images/"
    ) -> None:
        self.oss_store = oss_store
        self.save_prefix = save_prefix

    def upload_image(self, image: PngImageFile, doc_name: str):
        """Upload *image* as a JPEG under *doc_name*; return its URL or None.

        Tiny images (<= 50px on either side) are skipped; oversized images are
        downscaled to fit IMAGE_MAX_PIXELS. Failures are logged, not raised.
        """
        if image is None:
            return None

        if self.oss_store is None:
            logger.warning(
                "oss_store is not properly configured, skipping image upload."
            )
            return None

        try:
            if image.mode != "RGB":
                image = image.convert("RGB")
            if image.width <= 50 or image.height <= 50:
                logger.warning(f"Skipping small image {image}")
                return None

            pixel_count = image.width * image.height

            # Downscale when the total pixel count exceeds the cap,
            # keeping the aspect ratio.
            if pixel_count > IMAGE_MAX_PIXELS:
                shrink = math.sqrt(IMAGE_MAX_PIXELS / pixel_count)
                new_size = (int(image.width * shrink), int(image.height * shrink))
                image = image.resize(new_size, Image.LANCZOS)

            buffer = BytesIO()
            image.save(buffer, format="jpeg")
            buffer.seek(0)
            data = buffer.getvalue()

            image_url = self.oss_store.put_object_if_not_exists(
                data=data,
                file_ext=".jpeg",
                headers={
                    "x-oss-object-acl": "public-read"
                },  # set public read to make image accessible
                path_prefix=os.path.join(self.save_prefix, doc_name.strip()),
            )
            logger.info(
                f"Saved image {image_url} from {doc_name} with width={image.width}, height={image.height}."
            )
            return image_url
        except Exception as e:
            logger.warning(f"处理图片失败 '{image}': {e}")
--------------------------------------------------------------------------------
/integrations/pairag-file/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["poetry-core"]
3 | build-backend = "poetry.core.masonry.api"
4 |
5 | [tool.poetry]
6 | name = "pairag-file"
7 | version = "1.2.2"
8 | description = "PAI-RAG file processing library."
9 | authors = []
10 | readme = "README.md"
11 | packages = [
12 | {include = "pairag/file"}
13 | ]
14 |
15 | [tool.poetry.dependencies]
16 | python = ">=3.11.0,<3.12"
17 | llama-index-core = "^0.12.27"
18 | llama-index-readers-file = "^0.4.3"
19 | docx2txt = "^0.8"
20 | pydantic = "^2.7.0"
21 | oss2 = "^2.18.5"
22 | torch = "2.2.2"
23 | transformers = "4.51.3"
24 | openpyxl = "^3.1.2"
25 | xlrd = "^2.0.1"
26 | markdown = "^3.6"
27 | chardet = "^5.2.0"
28 | modelscope = "^1.16.0"
29 | pre-commit = "^3.8.0"
30 | peft = "^0.12.0"
31 | python-pptx = "^1.0.2"
32 | aspose-slides = "^24.10.0"
33 | datasketch = "^1.6.5"
34 | anyio = "^4.6.2.post1"
35 | mistletoe = "^1.4.0"
36 | html2text = "^2024.2.26"
37 | rapidfuzz = "^3.13.0"
38 | python-docx = "^1.1.2"
39 | numpy = "1.26.4"
40 | loguru = "^0.7.3"
41 | llama-index-multi-modal-llms-openai = "^0.5.0"
42 | pytest = "^8.3.5"
43 | magic-pdf = {version = "1.3.10", extras = ["full"]}
44 |
45 | [tool.pytest.ini_options]
46 | asyncio_mode = "auto"
47 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/__init__.py
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/nodeparsers/test_markdown_mode_parser.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
def test_markdown_parser():
    """Markdown docs split by MarkdownNodeParser must match the golden chunk list."""
    from pairag.file.nodeparsers.pai.pai_markdown_parser import (
        MarkdownNodeParser,
    )
    from pairag.file.readers.pai.pai_data_reader import (
        PaiDataReader,
        DataReaderConfig,
    )

    directory_reader = PaiDataReader(reader_config=DataReaderConfig())
    documents = directory_reader.load_data(
        file_path_or_directory="tests/testdata/md_data"
    )

    parser = MarkdownNodeParser(enable_multimodal=False)
    # Split each document independently and collect all resulting nodes.
    splitted_nodes = [
        node
        for doc_node in documents
        for node in parser.get_nodes_from_documents([doc_node])
    ]

    with open(
        "tests/testdata/json_data/pai_document.json", "r", encoding="utf-8"
    ) as file:
        expected_chunks = json.load(file)

    assert [node.text for node in splitted_nodes] == expected_chunks
    assert len(splitted_nodes) == 10
32 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/openailike_multimodal.py:
--------------------------------------------------------------------------------
1 | # dashscope multimodal llm的achat接口有问题,这里使用openai接口
2 |
3 | from llama_index.multi_modal_llms.openai import OpenAIMultiModal
4 | from typing import Dict, Any
5 |
6 |
class OpenAIAlikeMultiModal(OpenAIMultiModal):
    """OpenAI-compatible multimodal LLM wrapper.

    Used instead of the dashscope multimodal LLM because its `achat`
    endpoint is broken; requests go through an OpenAI-style API instead.
    """

    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Assemble the keyword arguments for a chat/completions request."""
        model_kwargs: Dict[str, Any] = {
            "model": self.model,
            "temperature": self.temperature,
            **kwargs,
        }
        # If max_tokens is None, don't include it in the payload:
        # https://platform.openai.com/docs/api-reference/chat
        # https://platform.openai.com/docs/api-reference/completions
        if self.max_new_tokens is not None:
            model_kwargs["max_tokens"] = self.max_new_tokens
        # additional_kwargs take precedence over the defaults above.
        model_kwargs.update(self.additional_kwargs)
        return model_kwargs
16 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_csv_reader.py:
--------------------------------------------------------------------------------
1 | from pairag.file.readers.pai.pai_data_reader import DataReaderConfig, PaiDataReader
2 | from pairag.file.readers.pai.file_readers.pai_csv_reader import PaiPandasCSVReader
3 |
4 |
def test_pandas_csv_reader():
    """CSV reader with a two-row header and concat_rows=False yields one doc per row."""
    directory_reader = PaiDataReader(reader_config=DataReaderConfig())
    # Treat the first two CSV rows as a multi-level header.
    directory_reader.file_readers[".csv"] = PaiPandasCSVReader(
        concat_rows=False,
        pandas_config={"header": [0, 1]},
    )

    documents = directory_reader.load_data(
        file_path_or_directory="tests/testdata/csv_data"
    )
    for document in documents:
        print(document)
    assert len(documents) == 897
16 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_docx_reader.py:
--------------------------------------------------------------------------------
1 | from pairag.file.readers.pai.pai_data_reader import PaiDataReader, DataReaderConfig
2 | from pairag.file.readers.pai.file_readers.pai_docx_reader import PaiDocxReader
3 |
4 |
def test_pai_docx_reader():
    """The docx reader should extract the deployment-step text from the sample doc."""
    directory_reader = PaiDataReader(reader_config=DataReaderConfig())
    directory_reader.file_readers[".docx"] = PaiDocxReader()

    documents = directory_reader.load_data(
        file_path_or_directory="tests/testdata/docx_data"
    )
    assert "步骤一:部署RAG服务" in str(documents[0].text)
14 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_excel_reader.py:
--------------------------------------------------------------------------------
def test_pandas_excel_reader():
    """Excel reader with two-row headers parses the test workbook into 7 docs."""
    from pairag.file.readers.pai.pai_data_reader import (
        PaiDataReader,
        DataReaderConfig,
    )
    from pairag.file.readers.pai.file_readers.pai_excel_reader import (
        PaiPandasExcelReader,
    )

    reader_config = DataReaderConfig()
    directory_reader = PaiDataReader(reader_config=reader_config)

    # Register the same two-level-header reader for both Excel formats.
    for extension in (".xlsx", ".xls"):
        directory_reader.file_readers[extension] = PaiPandasExcelReader(
            concat_rows=reader_config.concat_csv_rows,
            pandas_config={"header": [0, 1]},
        )

    documents = directory_reader.load_data(
        file_path_or_directory="tests/testdata/excel_data"
    )
    for document in documents:
        print(document)
    assert len(documents) == 7
27 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_html_reader.py:
--------------------------------------------------------------------------------
1 | from pairag.file.readers.pai.pai_data_reader import PaiDataReader, DataReaderConfig
2 | from pairag.file.readers.pai.file_readers.pai_html_reader import PaiHtmlReader
3 |
4 |
def test_pai_html_reader():
    """The HTML test directory should parse into exactly five documents."""
    directory_reader = PaiDataReader(reader_config=DataReaderConfig())
    directory_reader.file_readers[".html"] = PaiHtmlReader()

    documents = directory_reader.load_data(
        file_path_or_directory="tests/testdata/html_data"
    )
    assert len(documents) == 5
13 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_image_reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import requests
4 | from pairag.file.store.oss_store import PaiOssStore
5 | from pairag.file.store.pai_image_store import PaiImageStore
6 | from pairag.file.readers.pai.file_readers.pai_image_reader import PaiImageReader
7 |
# Skip the whole module unless OSS credentials are configured in the env.
if not (
    os.environ.get("OSS_ACCESS_KEY_ID")
    and os.environ.get("OSS_ACCESS_KEY_SECRET")
):
    pytest.skip(
        reason="OSS_ACCESS_KEY_ID or OSS_ACCESS_KEY_SECRET not set",
        allow_module_level=True,
    )
15 |
16 |
@pytest.fixture
def image_store():
    """PaiImageStore backed by the feiyue-test OSS bucket (Hangzhou region)."""
    backing_store = PaiOssStore(
        bucket_name="feiyue-test", endpoint="oss-cn-hangzhou.aliyuncs.com"
    )
    return PaiImageStore(oss_store=backing_store)
23 |
24 |
def test_image_reader(image_store):
    """Loading a local image should upload it and yield a fetchable public URL."""
    reader = PaiImageReader(image_store=image_store)
    documents = reader.load_data(file_path="tests/testdata/image_data/11.jpg")

    image_url = documents[0].metadata.get("image_url")
    assert image_url is not None, "image url should not be None."

    # The uploaded object must be publicly reachable over HTTP.
    response = requests.get(image_url)
    assert response.status_code == 200, "image url should be valid."
34 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_jsonl_reader.py:
--------------------------------------------------------------------------------
def test_jsonl_reader():
    """The JSONL reader should emit one document per record in the test data."""
    from pairag.file.readers.pai.pai_data_reader import (
        PaiDataReader,
        DataReaderConfig,
    )
    from pairag.file.readers.pai.file_readers.pai_jsonl_reader import PaiJsonLReader

    directory_reader = PaiDataReader(reader_config=DataReaderConfig())
    directory_reader.file_readers[".jsonl"] = PaiJsonLReader()

    documents = directory_reader.load_data(
        file_path_or_directory="tests/testdata/jsonl_data"
    )
    for document in documents:
        print(document)
    assert len(documents) == 27
18 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_pdf_reader.py:
--------------------------------------------------------------------------------
def test_pai_pdf_reader():
    """Parsing the sample PDF should produce at least one document."""
    from pairag.file.readers.pai.pai_data_reader import (
        PaiDataReader,
        DataReaderConfig,
    )
    from pairag.file.readers.pai.file_readers.pai_pdf_reader import PaiPDFReader

    directory_reader = PaiDataReader(reader_config=DataReaderConfig())
    directory_reader.file_readers[".pdf"] = PaiPDFReader()

    documents = directory_reader.load_data(
        file_path_or_directory="tests/testdata/pdf_data"
    )
    assert len(documents) > 0
16 |
17 |
def test_is_horizontal_table():
    """Row-oriented tables are accepted; a column-oriented one is rejected."""
    from pairag.file.readers.pai.utils.markdown_utils import is_horizontal_table

    # Header in the first row, one record per subsequent row.
    horizontal_tables = [
        [
            ["Name", "Age", "City"],
            ["Alice", 30, "New York"],
            ["Bob", 25, "San Francisco"],
        ],
        [
            ["Name", "Age", "discount"],
            ["Alice", 30, 0.3],
            ["Bob", 25, 0.4],
        ],
        [
            ["Age", "discount", "amount"],
            [30, 0.3, 3],
            [25, 0.4, 7],
            [34, 0.2, 9],
        ],
    ]

    # Field names in the first column, one record per column.
    vertical_table = [
        ["Field", "Record1", "Record2"],
        ["Name", "Alice", "Bob"],
        ["Age", 30, 25],
        ["City", "New York", "San Francisco"],
    ]

    for table in horizontal_tables:
        assert is_horizontal_table(table)
    assert not is_horizontal_table(vertical_table)
51 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/readers/test_post_process_multi_level_headings.py:
--------------------------------------------------------------------------------
1 | from pairag.file.readers.pai.file_readers.pai_pdf_reader import PaiPDFReader
2 |
3 |
def test_post_process_multi_level_headings():
    """Markdown heading levels are assigned from relative title font sizes."""
    # (title text, font size) pairs; larger fonts become higher-level headings.
    titles_with_sizes = [
        ("title_1", 6),
        ("title_2", 10),
        ("title_3", 8),
        ("title_4", 7),
        ("title_5", 14),
    ]

    reader = PaiPDFReader()
    result = reader.post_process_multi_level_headings(titles_with_sizes, 0, 0)

    expected = [
        "### title_1",
        "## title_2",
        "### title_3",
        "### title_4",
        "# title_5",
    ]
    assert result == expected
22 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/test_image_caption_tool.py:
--------------------------------------------------------------------------------
1 | from pairag.file.nodeparsers.pai.image_caption_tool import ImageCaptionTool
2 | from tests.openailike_multimodal import OpenAIAlikeMultiModal
3 | import os
4 | import pytest
5 |
6 |
# Skip the whole module unless a DashScope API key is configured.
if "DASHSCOPE_API_KEY" not in os.environ:
    pytest.skip(reason="DASHSCOPE_API_KEY not set", allow_module_level=True)
9 |
10 |
@pytest.fixture
def multimodal_llm() -> OpenAIAlikeMultiModal:
    """qwen-vl-max reached through DashScope's OpenAI-compatible endpoint."""
    return OpenAIAlikeMultiModal(
        api_key=os.environ.get("DASHSCOPE_API_KEY"),
        api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
        model="qwen-vl-max",
    )
18 |
19 |
def test_image_caption_tool(multimodal_llm: OpenAIAlikeMultiModal):
    """Captioning a remote image URL should run end to end without errors."""
    image_url = "https://feiyue-test.oss-cn-hangzhou.aliyuncs.com/pai_oss_images/EAS%E8%AE%A1%E8%B4%B9%E8%AF%B4%E6%98%8E/8ea0d46984ba9166e96f8afb58132dc6.jpeg"

    caption_tool = ImageCaptionTool(multimodal_llm=multimodal_llm)
    result = caption_tool.extract_url(image_url)
    print(result)
26 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/test_image_utils.py:
--------------------------------------------------------------------------------
1 | from pairag.file.readers.pai.utils.image_utils import is_remote_url
2 |
3 |
def test_is_remote_url():
    """Only http(s) URLs count as remote; local and relative paths do not."""
    remote_urls = [
        "http://www.baidu.com",
        "https://www.tencent.com/1.jpg",
        "http://123123213123",
        "https://abcabc",
    ]
    local_paths = [
        "/a/b/c/1.jpg",
        "./1.txt",
        "../..",
        ".",
        "",
        "/",
        "data/",
    ]

    for url in remote_urls:
        assert is_remote_url(url) is True
    for path in local_paths:
        assert is_remote_url(path) is False
16 |
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/csv_data/30天留存率_csv_two_header.csv:
--------------------------------------------------------------------------------
1 | time,metric
2 | date,rate
3 | 20240101,0.9375
4 | 20240201,0.9744
5 | 20240301,0.9767
6 | 20240401,0.9375
7 | 20240501,0.9091
8 | 20240601,0.9474
9 | 20240701,0.9667
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/db_data/pets.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/db_data/pets.db
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/db_data/pets.sqlite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/db_data/pets.sqlite
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/docx_data/大模型RAG对话系统.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/docx_data/大模型RAG对话系统.docx
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/excel_data/30天留存率_excel_two_header.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/excel_data/30天留存率_excel_two_header.xlsx
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/image_data/11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/image_data/11.jpg
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/jsonl_data/price.jsonl:
--------------------------------------------------------------------------------
1 | {"question": "335块", "answer": "订单报价"}
2 | {"question": "435元", "answer": "订单报价"}
3 | {"question": "435块", "answer": "订单报价"}
4 | {"question": "535元", "answer": "订单报价"}
5 | {"question": "535块", "answer": "订单报价"}
6 | {"question": "635元", "answer": "订单报价"}
7 | {"question": "635块", "answer": "订单报价"}
8 | {"question": "735元", "answer": "订单报价"}
9 | {"question": "735块", "answer": "订单报价"}
10 | {"question": "835元", "answer": "订单报价"}
11 | {"question": "835块", "answer": "订单报价"}
12 | {"question": "935元", "answer": "订单报价"}
13 | {"question": "935块", "answer": "订单报价"}
14 | {"question": "1357元", "answer": "订单报价"}
15 | {"question": "1357块", "answer": "订单报价"}
16 | {"question": "1100元", "answer": "订单报价"}
17 | {"question": "1179块", "answer": "订单报价"}
18 | {"question": "1279元", "answer": "订单报价"}
19 | {"question": "1279块", "answer": "订单报价"}
20 | {"question": "1379元", "answer": "订单报价"}
21 | {"question": "1379块", "answer": "订单报价"}
22 | {"question": "1479元", "answer": "订单报价"}
23 | {"question": "1479块", "answer": "订单报价"}
24 | {"question": "1579元", "answer": "订单报价"}
25 | {"question": "1579块", "answer": "订单报价"}
26 | {"question": "1679元", "answer": "订单报价"}
27 | {"question": "1679块", "answer": "订单报价"}
--------------------------------------------------------------------------------
/integrations/pairag-file/tests/testdata/pdf_data/pai_document.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/pdf_data/pai_document.pdf
--------------------------------------------------------------------------------
/scripts/gunicorn.conf.py:
--------------------------------------------------------------------------------
# gunicorn.conf.py
# Defaults for serving the FastAPI app; scripts/start.sh also passes
# -w/-b on the command line, which take precedence over these values.
bind = "0.0.0.0:8680"  # listen on all interfaces, port 8680
workers = 1  # single worker process by default
worker_class = "uvicorn.workers.UvicornWorker"  # ASGI worker for FastAPI
timeout = 600  # seconds a worker may be silent before being killed/restarted
6 |
--------------------------------------------------------------------------------
/scripts/load_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e


# Directory containing this script.
SCRIPT_DIR=$(dirname "$0")
# Run from the repository root (parent of the scripts directory).
cd "$SCRIPT_DIR/.."
pwd

# Forward all CLI arguments verbatim. "$@" (unlike the unquoted $* used
# before) preserves quoting, so arguments containing spaces survive intact.
python src/pairag/data_ingestion/main.py "$@"
12 |
--------------------------------------------------------------------------------
/scripts/load_data_by_step.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e


# Change these to the directories you need.
INPUT_PATH="/testdata/testdata/small_txt"
OUTPUT_PATH="/testdata/testdata/output/small0520"


# Directory containing this script.
SCRIPT_DIR=$(dirname "$0")
# Run from the repository root (parent of the scripts directory).
cd "$SCRIPT_DIR/.."
pwd

# Each stage reads from and writes to OUTPUT_PATH, so the pipeline can be
# resumed from any step. Variables are quoted so paths with spaces work,
# and the stray trailing backslashes (which silently glued blank lines
# into the commands) have been removed.

echo "reading data source..."
python src/pairag/data_ingestion/main.py data-source \
    --enable-delta \
    --input-path "$INPUT_PATH" \
    --output-path "$OUTPUT_PATH"

echo "parsing files..."
python src/pairag/data_ingestion/main.py parse \
    --input-path "$OUTPUT_PATH" \
    --output-path "$OUTPUT_PATH"

echo "splitting into chunks..."
python src/pairag/data_ingestion/main.py split \
    --input-path "$OUTPUT_PATH" \
    --output-path "$OUTPUT_PATH"

echo "embedding data..."
python src/pairag/data_ingestion/main.py embed \
    --input-path "$OUTPUT_PATH" \
    --output-path "$OUTPUT_PATH" \
    --num-gpus 0.5 \
    --num-cpus 6 \
    --memory 16 \
    --concurrency 2

echo "writing to data sink..."
python src/pairag/data_ingestion/main.py data-sink \
    --input-path "$OUTPUT_PATH" \
    --output-path "$OUTPUT_PATH"
47 |
--------------------------------------------------------------------------------
/scripts/locust.py:
--------------------------------------------------------------------------------
1 | from locust import HttpUser, task, between
2 | from random import randint
3 |
# Authorization header sent with every request; fill in a token before running.
auth_header = {"Authorization": ""}

# Pool of (Chinese) movie-recommendation queries picked at random per request.
sample_queries = [
    "找一部关于下雪的电影。",
    "找一部适合下雨天看的电影。",
    "心情不好的时候看什么电影?",
    "无聊的时候想看什么电影",
    "压力很大的时候看的电影",
    "好看的中文警匪片",
    "校园爱情电影",
    "金庸小说改编的古装武打剧",
    "好看的仙侠剧",
    "搞笑的电影",
    "评分高的的动画片",
]
19 |
20 |
class SimpleRagUser(HttpUser):
    """Load-test user that exercises the full RAG query endpoint."""

    wait_time = between(0, 1)

    @task
    def qa(self):
        # POST a randomly chosen query from the shared pool.
        query = sample_queries[randint(0, len(sample_queries) - 1)]
        _ = self.client.post(
            "/service/query", headers=auth_header, json={"question": query}
        )
        # sprint(response.content.decode("utf-8"))
33 |
34 |
class SimpleRetrievalUser(HttpUser):
    """Load-test user that exercises the retrieval-only endpoint."""

    wait_time = between(0, 1)

    @task
    def qa(self):
        # POST a randomly chosen query from the shared pool.
        query = sample_queries[randint(0, len(sample_queries) - 1)]
        _ = self.client.post(
            "/service/query/retrieval", headers=auth_header, json={"question": query}
        )
        # print(response.content.decode("utf-8"))
47 |
--------------------------------------------------------------------------------
/scripts/start.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch the PAI-RAG FastAPI app under gunicorn.

helpFunction()
{
   echo ""
   echo "Usage: $0 -w workers -p port"
   echo -e "\t-w number of workers, default 1"
   echo -e "\t-p port, default 8680"
   exit 1 # Exit script after printing help
}

while getopts ":w:p:" opt
do
   case "$opt" in
      w ) workers="$OPTARG" ;;
      p ) port="$OPTARG" ;;
      ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
   esac
done

# Fall back to defaults when the flags are omitted.
workers="${workers:-1}"
port="${port:-8680}"

echo "Starting gunicorn with $workers workers on port $port..."

# Quote the expansions so empty or whitespace-containing values cannot
# split into extra arguments; CLI options override gunicorn.conf.py.
gunicorn -w "$workers" -b "0.0.0.0:${port}" -c scripts/gunicorn.conf.py src.pairag.app.app:app --timeout 600
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages


def _read_long_description() -> str:
    """Read README.md for PyPI's long description, closing the file promptly."""
    # The original used a bare open().read(), which leaked the file handle
    # and relied on the platform default encoding.
    with open("README.md", encoding="utf-8") as readme:
        return readme.read()


setup(
    name="pairag.file",  # distribution name
    version="0.1.0",
    author="Your Name",  # TODO: fill in real author metadata before publishing
    author_email="you@example.com",
    description="PAI-RAG file utilities.",
    long_description=_read_long_description(),
    long_description_content_type="text/markdown",
    url="https://github.com/aigc-apps/PAI-RAG",
    # NOTE(review): find_packages(where="src/pairag/file") combined with
    # package_dir={"": "src/pairag"} looks inconsistent — packages discovered
    # under src/pairag/file would resolve relative to src/pairag. Verify the
    # intended layout (likely where="src" with package_dir={"": "src"}).
    packages=find_packages(where="src/pairag/file"),
    package_dir={"": "src/pairag"},
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.11",
    install_requires=[
        "requests>=2.26.0",
        "click>=8.0",
        "llama-index-core==0.12.12",
        "llama-index-readers-file>=0.4.3",
        "docx2txt>=0.8",
        "pydantic>=2.7.0",
        "oss2>=2.18.5",
        "torch==2.2.2",
        "transformers==4.51.3",
        "openpyxl>=3.1.2",
        "xlrd>=2.0.1",
        "markdown>=3.6",
        "chardet>=5.2.0",
        "peft>=0.12.0",
        "python-pptx>=1.0.2",
        "aspose-slides>=24.10.0",
        "datasketch>=1.6.5",
        "mistletoe>=1.4.0",
        "html2text>=2024.2.26",
        "python-docx>=1.1.2",
        "numpy==1.26.4",
        "loguru>=0.7.3",
        "magic-pdf[full]==1.3.10",
    ],
)
46 |
--------------------------------------------------------------------------------
/src/pairag/api/chat_completions.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter
2 | from fastapi.responses import StreamingResponse
3 | from pairag.chat.models import ChatCompletionRequest
4 | from pairag.core.chat_service import chat_service
5 |
# Router exposing the OpenAI-compatible endpoints (/models, /chat/completions).
router_openai = APIRouter()
7 |
8 |
@router_openai.get("/models")
async def get_models():
    """OpenAI-compatible model listing; advertises a single 'default' model."""
    default_model = {
        "id": "default",
        "object": "model",
        "created": 1739298766,
        "owned_by": "pai",
        "permission": [],
    }
    return {"data": [default_model]}
22 |
23 |
@router_openai.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat endpoint: plain JSON, or SSE when request.stream."""
    if request.stream:
        stream = await chat_service.astream_chat(request)
        return StreamingResponse(
            stream,
            media_type="text/event-stream",
        )
    return await chat_service.achat(request)
35 |
--------------------------------------------------------------------------------
/src/pairag/api/exception_handler.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Request
2 | from fastapi.responses import JSONResponse
3 | from pairag.core.models.errors import UserInputError, ServiceError
4 |
5 |
async def _user_input_exception_handler(request: Request, exception: UserInputError):
    """Map UserInputError to an HTTP 400 with the error message."""
    message = f"Failed to process request input: {exception.msg}"
    return JSONResponse(status_code=400, content={"message": message})
11 |
12 |
async def _service_exception_handler(request: Request, exception: ServiceError):
    """Map ServiceError to an HTTP 500 with the error message."""
    return JSONResponse(
        status_code=500,
        content={"message": f"Oops, {exception.msg}"},
    )
15 |
16 |
def add_exception_handler(app: FastAPI):
    """Register the domain-error-to-HTTP handlers on the FastAPI app."""
    handlers = {
        UserInputError: _user_input_exception_handler,
        ServiceError: _service_exception_handler,
    }
    for error_type, handler in handlers.items():
        app.add_exception_handler(error_type, handler)
20 |
--------------------------------------------------------------------------------
/src/pairag/api/middleware.py:
--------------------------------------------------------------------------------
1 | from fastapi.middleware.cors import CORSMiddleware
2 | from fastapi import FastAPI, Request
3 | from starlette.middleware.base import BaseHTTPMiddleware
4 | from asgi_correlation_id import CorrelationIdMiddleware
5 | import time
6 | from loguru import logger
7 |
8 |
class CustomMiddleWare(BaseHTTPMiddleware):
    """Adds timing/client headers to every response and logs each request.

    Requests whose URL contains "get_upload_state" are polled frequently,
    so they are logged at most once per `log_interval` seconds.
    """

    def __init__(self, app):
        super().__init__(app)
        self.last_log_time = 0
        self.log_interval = 60  # seconds between get_upload_state log lines

    async def dispatch(self, request: Request, call_next):
        started = time.time()
        response = await call_next(request)
        elapsed = time.time() - started

        client_host = request.client.host
        response.headers["X-Process-Time"] = str(elapsed)
        response.headers["X-Client-IP"] = client_host

        log_line = (
            f"Request: {request.method} {request.url} - "
            f"Response Time: {elapsed:.4f} seconds Host {client_host}"
        )
        if "get_upload_state" in str(request.url):
            # Throttle logging for this high-frequency polling endpoint.
            now = time.time()
            if now - self.last_log_time >= self.log_interval:
                logger.info(log_line)
                self.last_log_time = now
        else:
            logger.info(log_line)
        return response
35 |
36 |
def _configure_session_middleware(app):
    """Attach correlation-ID middleware (request IDs via the X-Request-ID header)."""
    app.add_middleware(CorrelationIdMiddleware, header_name="X-Request-ID")
42 |
43 |
def _configure_cors_middleware(app):
    """Allow cross-origin requests from any origin, without credentials."""
    cors_settings = {
        "allow_origins": ["*"],
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_credentials": False,
    }
    app.add_middleware(CORSMiddleware, **cors_settings)
52 |
53 |
def add_middlewares(app: FastAPI):
    """Install CORS, correlation-ID, and timing/logging middleware on the app.

    Clears the built middleware stack first so middleware can still be added
    to an already-constructed app, then rebuilds the stack at the end.
    """
    # reset current middleware to allow modifying user provided list
    app.middleware_stack = None
    _configure_cors_middleware(app)
    _configure_session_middleware(app)
    app.add_middleware(CustomMiddleWare)
    app.build_middleware_stack()  # rebuild middleware stack on-the-fly
61 |
--------------------------------------------------------------------------------
/src/pairag/app/app.py:
--------------------------------------------------------------------------------
# init trace
import os
import asyncio
import threading
from fastapi import FastAPI

# setup models
# NOTE(review): the env var is set before importing download_models —
# presumably ModelScopeDownloader reads PAIRAG_MODEL_DIR at import/run time,
# so keep this statement ordering.

from pairag.utils.constants import DEFAULT_MODEL_DIR
os.environ["PAIRAG_MODEL_DIR"] = DEFAULT_MODEL_DIR
from pairag.utils.download_models import ModelScopeDownloader
ModelScopeDownloader().load_rag_models()


from contextlib import asynccontextmanager
from pairag.utils.format_logging import format_logging
from pairag.core.chat_service import chat_service
from pairag.data_pipeline.job.rag_job_manager import job_manager
from pairag.core.service_daemon import startup_event
from loguru import logger

# Configure loguru formatting before the app starts handling requests.
format_logging()
23 |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: kick off background work, then yield control."""
    logger.info("Application starting up...")
    # Run job execution on a daemon thread so it never blocks shutdown.
    threading.Thread(target=job_manager.execute_job, daemon=True).start()

    asyncio.create_task(startup_event())
    yield

    logger.info("Application shutting down...")
34 |
35 |
def configure(app: FastAPI):
    """Wire routers, chat service, middleware, error handlers, and the web UI."""
    from pairag.api.api_router_v1 import router_v1
    from pairag.api.api_chat import router_openai, router_chat
    from pairag.api.exception_handler import add_exception_handler
    from pairag.api.middleware import add_middlewares
    from pairag.web.webui import configure_webapp

    routers = [
        (router_v1, "/api/v1", "api_v1"),
        (router_openai, "/v1", "chat_completions"),
        (router_chat, "/chat", "chat_api"),
    ]
    for router, prefix, tag in routers:
        app.include_router(router, prefix=prefix, tags=[tag])

    chat_service.initialize()
    add_middlewares(app)
    add_exception_handler(app)
    configure_webapp(app)
51 |
52 |
# Module-level ASGI app; gunicorn loads it as src.pairag.app.app:app.
app = FastAPI(lifespan=lifespan)
configure(app)
55 |
--------------------------------------------------------------------------------
/src/pairag/app/constants.py:
--------------------------------------------------------------------------------
from pathlib import Path
import os

# src/pairag directory (two levels up from this file).
_BASE_DIR = Path(__file__).parent.parent
# Repository root (four levels up from this file).
_ROOT_BASE_DIR = Path(__file__).parent.parent.parent.parent

# Default TOML settings file shipped inside src/pairag/config.
DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml")
# Sample PDF used as example/demo data.
DEFAULT_APPLICATION_EXAMPLE_DATA_FILE = os.path.join(
    _ROOT_BASE_DIR, "example_data/pai_document.pdf"
)
DEFAULT_HOST = "0.0.0.0"  # listen on all interfaces
DEFAULT_PORT = 8001
DEFAULT_RAG_URL = f"http://{DEFAULT_HOST}:{DEFAULT_PORT}/"
14 |
--------------------------------------------------------------------------------
/src/pairag/chat/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Any, List, Dict, Optional, AsyncGenerator, Generator
3 | from llama_index.core.base.llms.types import ChatMessage
4 | from llama_index.core.schema import NodeWithScore
5 | from pairag.integrations.query_transform.pai_query_transform import IntentResult
6 |
7 |
class ContextDoc(BaseModel):
    """A retrieved context document returned alongside an answer."""

    text: str  # document text
    score: float  # document relevance score
    metadata: Dict  # document metadata
    image_url: str | None = None  # image link, if any
13 |
14 |
class RetrievalRequest(BaseModel):
    """Request body for knowledge-base retrieval."""

    knowledgebase_id: Optional[str] = "default"  # knowledge base name (index_name)
    query: str  # query content
    retrieval_settings: Optional[Dict] = None
    # ["retrieval_mode", "similarity_top_k", "vector_weight", "keyword_weight", "reranker_type", "similarity_threshold", "reranker_similarity_threshold", "reranker_model", "reranker_similarity_top_k"]
20 |
21 |
class DocRecord(BaseModel):
    """One retrieval hit from a knowledge-base data source."""

    content: str  # text chunk from the knowledge-base data source
    score: float  # relevance score of the result to the query, range 0~1
    title: str  # document title
    metadata: Dict  # metadata attributes and values of the source document
27 |
28 |
class NewRetrievalResponse(BaseModel):
    """Retrieval response in the newer record-based shape."""

    records: List[DocRecord]
31 |
32 |
class RetrievalResponse(BaseModel):
    """Retrieval response in the original context-doc shape."""

    docs: List[ContextDoc]
35 |
36 |
class ChatCompletionRequest(BaseModel):
    """OpenAI-style chat completion request with PAI-RAG routing flags."""

    model: str  # model name
    messages: List[ChatMessage]  # chat context/history
    stream: Optional[bool] = False  # stream the response
    index_name: Optional[str] = None  # index name
    chat_knowledgebase: Optional[bool] = False  # query the knowledge base
    search_web: Optional[bool] = False  # search the web
    return_reference: Optional[bool] = False  # return references
    chat_llm: Optional[bool] = False  # plain LLM chat
    chat_db: Optional[bool] = False  # query the database
    chat_news: Optional[bool] = False  # use the news tool
    # llm args
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    intent: Optional[IntentResult] = None  # intent

    class Config:
        extra = "allow"  # allow extra fields
55 |
56 |
class ChatResponseWrapper(BaseModel):
    """Wraps a chat response plus its source nodes and intent metadata."""

    response: Any
    additional_kwargs: Dict[str, Any] = {}
    source_nodes: List[NodeWithScore] = []
    intent_result: Optional[IntentResult] = None

    def model_dump_json(self, exclude=None, **kwargs) -> str:
        """Serialize to JSON, dropping streaming responses.

        Generator/async-generator responses cannot be serialized, so they are
        excluded to stay compatible with arize instrumentation.
        """
        if exclude is None:
            excluded = set()
        elif isinstance(exclude, dict):
            # Keep only the keys whose flag is truthy.
            excluded = {key for key, flag in exclude.items() if flag}
        else:
            excluded = exclude

        if isinstance(self.response, (Generator, AsyncGenerator)):
            excluded.add("response")

        return super().model_dump_json(exclude=excluded, **kwargs)
74 |
75 |
class EmbeddingInput(BaseModel):
    """Request body for the embedding endpoint."""

    # Fixed: the default of None requires None in the annotation; the previous
    # `str | List[str] = None` default violated the declared type.
    input: str | List[str] | None = None  # text(s) to embed
    model: str = "bge-m3"  # embedding model name
79 |
--------------------------------------------------------------------------------
/src/pairag/config/evaluation/config.yaml:
--------------------------------------------------------------------------------
1 | experiment:
2 | # [text dataset][pai-eval]
3 | - name: "text_exp1"
4 | eval_data_path: "example_data/eval_docs_text"
5 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_text.toml"
6 | eval_model_llm:
7 | source: "dashscope"
8 | model: "qwen-max"
9 | max_tokens: 1024
10 | use_pai_eval: false
11 | # [custom text dataset][crag]
12 | - name: "text_exp2"
13 | dataset: "crag"
14 | eval_data_path: "example_data/eval_docs_crag_small"
15 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_crag_text.toml"
16 | eval_model_llm:
17 | source: "dashscope"
18 | model: "qwen-max"
19 | max_tokens: 1024
20 | # [multi-modal dataset]
21 | - name: "multi_modal_exp1"
22 | eval_data_path: "example_data/eval_docs_image"
23 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_image.toml"
24 | eval_model_llm:
25 | source: "dashscope"
26 | model: "qwen-vl-max"
27 | max_tokens: 1024
28 | tested_multimodal_llm:
29 | source: "dashscope"
30 | model: "qwen-vl-max"
31 | max_tokens: 1024
32 | # [custom multi-modal dataset]
33 | - name: "multi_modal_exp2"
34 | qca_dataset_path: "example_data/eval_docs_image_example/multimodal_eval_dataset_zh_example.jsonl"
35 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_image.toml"
36 | eval_model_llm:
37 | source: "dashscope"
38 | model: "qwen-vl-max"
39 | max_tokens: 1024
40 | tested_multimodal_llm:
41 | source: "dashscope"
42 | model: "qwen-vl-max"
43 | max_tokens: 1024
44 |
--------------------------------------------------------------------------------
/src/pairag/config/evaluation/settings_eval_for_text.toml:
--------------------------------------------------------------------------------
1 | dynaconf_merge = true
2 |
3 | [rag]
4 | name = "pairag"
5 | version = "0.1.1"
6 |
7 | [rag.agent]
8 | custom_agent_config_file = ""
9 | agent_tool_type = ""
10 |
11 | [rag.chat_store]
12 | type = "Local" # [Local, Aliyun-Redis]
13 | host = "Aliyun-Redis host"
14 | password = "Aliyun-Redis user:pwd"
15 | persist_path = "localdata/eval_exp_data/storage"
16 |
17 | [rag.data_analysis]
18 | type = "pandas"
19 | nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
20 |
21 | [rag.data_reader]
22 | type = "SimpleDirectoryReader"
23 |
24 | # Embedding configurations. Supported API sources: OpenAI, DashScope; local model source: HuggingFace.
25 | # For API sources, set OPENAI_API_KEY or DASHSCOPE_API_KEY in the environment; for HuggingFace, set the model name.
26 | # eg.
27 | # source = "HuggingFace"
28 | # model = "bge-m3"
29 | # embed_batch_size = 10
30 | [rag.embedding]
31 | source = "DashScope"
32 | embed_batch_size = 10
33 |
34 | [rag.index]
35 | persist_path = "localdata/eval_exp_data/storage"
36 | enable_multimodal = false
37 | vector_store.type = "FAISS"
38 |
39 | # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment
40 | # eg.
41 | # source = "PaiEas"
42 | # model = ""
43 | # endpoint = ""
44 | # token = ""
45 | [rag.llm]
46 | source = "DashScope"
47 | model = "qwen-turbo"
48 |
49 | [rag.multimodal_embedding]
50 | source = "cnclip"
51 |
52 | [rag.multimodal_llm]
53 | source = "dashscope"
54 | model = "qwen-vl-plus"
55 |
56 | [rag.node_enhancement]
57 | tree_depth = 3
58 | max_clusters = 52
59 | proba_threshold = 0.10
60 |
61 | [rag.node_parser]
62 | type = "Sentence"
63 | chunk_size = 500
64 | chunk_overlap = 10
65 | enable_multimodal = false
66 |
67 | [rag.oss_store]
68 | bucket = ""
69 | endpoint = "oss-cn-hangzhou.aliyuncs.com"
70 |
71 | [rag.postprocessor]
72 | reranker_type = "no-reranker" # [simple-weighted-reranker, model-based-reranker]
73 | reranker_model = "bge-reranker-base" # [bge-reranker-base, bge-reranker-large]
74 | keyword_weight = 0.3
75 | vector_weight = 0.7
76 | similarity_threshold = 0.5
77 | top_n = 2
78 |
79 | [rag.query_transform]
80 | type = ""
81 |
82 | [rag.retriever]
83 | similarity_top_k = 3
84 | retrieval_mode = "hybrid" # [hybrid, embedding, keyword, router]
85 | query_rewrite_n = 1 # set to 1 to disable query generation
86 |
87 | [rag.search]
88 | source = "quark"
89 |
90 | [rag.synthesizer]
91 | type = "SimpleSummarize"
92 | text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n"
93 |
94 | [rag.trace]
95 | type = "pai_trace"
96 | endpoint = "http://tracing-analysis-dc-hz.aliyuncs.com:8090"
97 | token = ""
98 |
--------------------------------------------------------------------------------
/src/pairag/config/settings.toml:
--------------------------------------------------------------------------------
1 | dynaconf_merge = true
2 |
3 | [rag]
4 | name = "pairag"
5 | version = "0.1.1"
6 |
7 | [rag.agent]
8 | custom_agent_config_file = ""
9 | agent_tool_type = ""
10 |
11 | [rag.chat]
12 |
13 | [rag.chat_store]
14 | type = "Local" # [Local, Aliyun-Redis]
15 | host = "Aliyun-Redis host"
16 | password = "Aliyun-Redis user:pwd"
17 | persist_path = "localdata/storage"
18 |
19 | [rag.data_analysis]
20 | type = "pandas"
21 |
22 | [rag.data_reader]
23 | type = "SimpleDirectoryReader"
24 |
25 | # Embedding configurations. Supported API sources: OpenAI, DashScope; local model source: HuggingFace.
26 | # For API sources, set OPENAI_API_KEY or DASHSCOPE_API_KEY in the environment; for HuggingFace, set the model name.
27 | # eg.
28 | # source = "HuggingFace"
29 | # model = "bge-m3"
30 | # embed_batch_size = 10
31 | [rag.embedding]
32 | source = "huggingface"
33 | embed_batch_size = 10
34 | enable_sparse = false
35 |
36 | [rag.index]
37 | persist_path = "localdata/knowledgebase/default/.index/.faiss"
38 | enable_multimodal = true
39 | vector_store.type = "FAISS"
40 |
41 | [rag.llm]
42 | source = "openai_compatible"
43 |
44 | # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment
45 | # eg.
46 | # source = "PaiEas"
47 | # model = ""
48 | # base_url = ""
49 | # api_key = ""
50 | # vision_support = false
51 | [[rag.llms]]
52 |
53 | [rag.multimodal_embedding]
54 | source = "cnclip"
55 |
56 | [rag.multimodal_llm]
57 | source = "openai_compatible"
58 |
59 | [rag.node_enhancement]
60 | tree_depth = 3
61 | max_clusters = 52
62 | proba_threshold = 0.10
63 |
64 | [rag.node_parser]
65 | type = "Sentence"
66 | chunk_size = 500
67 | chunk_overlap = 10
68 | enable_multimodal = true
69 |
70 | [rag.oss_store]
71 | bucket = ""
72 | endpoint = "oss-cn-hangzhou.aliyuncs.com"
73 |
74 | [rag.postprocessor]
75 | reranker_type = "no-reranker" # [simple-weighted-reranker, model-based-reranker]
76 | reranker_model = "bge-reranker-base" # [bge-reranker-base, bge-reranker-large]
77 | similarity_threshold = 0.5
78 | top_n = 2
79 |
80 | [rag.query_rewrite]
81 | enabled = true
82 |
83 | [rag.query_rewrite.llm]
84 |
85 | [rag.retriever]
86 | similarity_top_k = 5
87 | retrieval_mode = "hybrid" # [hybrid, embedding, keyword, router]
88 | query_rewrite_n = 1 # set to 1 to disable query generation
89 |
90 | [rag.search]
91 | source = "bing"
92 |
93 | [rag.synthesizer]
94 | type = "SimpleSummarize"
95 |
--------------------------------------------------------------------------------
/src/pairag/core/models/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, ConfigDict
2 | from pairag.integrations.llms.pai.llm_config import OpenAICompatibleLlmConfig
3 | from pairag.integrations.synthesizer.prompt_templates import (
4 | DEFAULT_SYSTEM_ROLE_TEMPLATE,
5 | )
6 | from pairag.utils.prompt_template import (
7 | CHAT_LLM_REWRITE_PROMPT_ZH,
8 | KNOWLEDGEBASE_REWRITE_PROMPT_ZH,
9 | NEWS_REWRITE_PROMPT_ZH,
10 | NL2SQL_REWRITE_PROMPT_ZH,
11 | REWRITE_PROMPT_ROLE_ZH,
12 | WEBSEARCH_REWRITE_PROMPT_ZH,
13 | )
14 | from pairag.utils.prompt_template import (
15 | DEFALT_LLM_CHAT_PROMPT_TEMPL,
16 | )
17 |
18 |
class ChatConfig(BaseModel):
    """Chat settings; currently only the selected model id."""

    model_id: str | None = None  # model identifier; None = service default — TODO confirm
    model_config = ConfigDict(coerce_numbers_to_str=True)
22 |
23 |
class QueryRewriteConfig(BaseModel):
    """Prompt templates and model selection for per-tool query rewriting."""

    enabled: bool = True
    # Prompt templates (Chinese defaults) used to rewrite the user query
    # for each downstream tool.
    base_prompt_template_str: str = REWRITE_PROMPT_ROLE_ZH
    llm_tool_prompt_str: str = CHAT_LLM_REWRITE_PROMPT_ZH
    knowledge_tool_prompt_str: str = KNOWLEDGEBASE_REWRITE_PROMPT_ZH
    websearch_tool_prompt_str: str = WEBSEARCH_REWRITE_PROMPT_ZH
    db_tool_prompt_str: str = NL2SQL_REWRITE_PROMPT_ZH
    news_tool_prompt_str: str = NEWS_REWRITE_PROMPT_ZH

    model_id: str | None = None  # model used for rewriting; None presumably falls back to the default LLM — verify
    llm: OpenAICompatibleLlmConfig | None = OpenAICompatibleLlmConfig()
    model_config = ConfigDict(coerce_numbers_to_str=True)
36 |
37 |
class AliyunTextModerationPlusConfig(BaseModel):
    """Connection settings for the Aliyun Text Moderation Plus service."""

    endpoint: str | None = None
    region: str | None = None
    access_key_id: str | None = None
    access_key_secret: str | None = None
    custom_advice: str | None = None

    def is_enabled(self) -> bool:
        """Moderation is enabled only when every connection field is non-empty."""
        required_fields = (
            self.access_key_id,
            self.access_key_secret,
            self.endpoint,
            self.region,
        )
        return all(value is not None and len(value) > 0 for value in required_fields)
56 |
57 |
class NodeEnhancementConfig(BaseModel):
    """Parameters for the node enhancement step."""

    tree_depth: int = 3
    max_clusters: int = 52
    proba_threshold: float = 0.10
62 |
63 |
class OssStoreConfig(BaseModel):
    """Aliyun OSS storage connection settings."""

    bucket: str | None = None
    endpoint: str = "oss-cn-hangzhou.aliyuncs.com"
    ak: str | None = None  # presumably the access key id — verify against callers
    sk: str | None = None  # presumably the access key secret — verify against callers
    model_config = ConfigDict(coerce_numbers_to_str=True)
70 |
71 |
class SynthesizerConfig(BaseModel):
    """Settings for the answer-synthesis step."""

    use_multimodal_llm: bool = False
    system_role_template: str = DEFAULT_SYSTEM_ROLE_TEMPLATE
    # NOTE: the imported name is spelled DEFALT_* upstream; kept as-is.
    custom_prompt_template: str = DEFALT_LLM_CHAT_PROMPT_TEMPL
76 |
--------------------------------------------------------------------------------
/src/pairag/core/models/errors.py:
--------------------------------------------------------------------------------
class UserInputError(Exception):
    """Raised when user-provided input is invalid."""

    def __init__(self, msg: str):
        # Pass msg to Exception so str(exc) and tracebacks show the message
        # (previously str(exc) was empty because super().__init__ was skipped).
        super().__init__(msg)
        self.msg = msg
4 |
5 |
class ServiceError(Exception):
    """Raised for failures internal to the service."""

    def __init__(self, msg: str):
        # Pass msg to Exception so str(exc) and tracebacks show the message
        # (previously str(exc) was empty because super().__init__ was skipped).
        super().__init__(msg)
        self.msg = msg
9 |
--------------------------------------------------------------------------------
/src/pairag/core/models/state.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
class FileServiceState:
    """Tracks the modification time of a file (e.g. on an OSS mount path).

    `state_value` holds the last observed mtime (-1 sentinel before the first
    check so the initial check always reports a change; 0 when the file does
    not exist or is unchanged).
    """

    def __init__(self, key):
        self.state_key = key  # path of the file being watched
        self.state_value = -1
        self.state_value = self.check_state()

    def check_state(self):
        """Return the file's mtime if it differs from the stored value, else 0.

        Returns 0 when the file does not exist.
        """
        if not os.path.exists(self.state_key):
            return 0

        # Dummy read to refresh the file's metadata from the OSS mount path.
        # Fixed: close the file handle (the original leaked it).
        with open(self.state_key, "r") as f:
            f.readline()

        mtime = os.path.getmtime(self.state_key)
        if mtime != self.state_value:
            return mtime
        return 0

    def update_state(self, new_value):
        """Record *new_value* as the last observed state."""
        self.state_value = new_value
24 |
--------------------------------------------------------------------------------
/src/pairag/core/rag_environment.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
class RagServiceEnvironment:
    """Derives service-mode flags from the DEPLOY_MODE environment variable."""

    def __init__(self):
        deploy_mode = os.getenv("DEPLOY_MODE", "web")
        # An "api" instance serves only the API; any other mode also starts
        # the web UI.
        self.IS_API_INSTANCE = deploy_mode.lower() == "api"
        self.SHOULD_START_WEB = not self.IS_API_INSTANCE


# Module-level singleton shared by the service.
service_environment = RagServiceEnvironment()
11 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/constants.py:
--------------------------------------------------------------------------------
# Metadata field names stored on each node record in the vector store.
DEFAULT_NODE_SOURCE_FIELD = "source"  # source document path of a node
DEFAULT_MODIFIED_AT_FIELD = "modified_at"  # source document modification time
DEFAULT_MD5_FIELD = "document_md5"  # presumably the source document's MD5 hash — verify
4 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/delta/list_delta.py:
--------------------------------------------------------------------------------
1 | # 通过向量数据库中的record比对确定delta信息
2 | from pairag.data_pipeline.delta.milvus import (
3 | list_docs_in_milvus_collection,
4 | )
5 | from pairag.data_pipeline.utils.vectordb_utils import get_vector_store
6 | from llama_index.vector_stores.milvus import MilvusVectorStore
7 |
8 | from loguru import logger
9 |
10 |
11 | """
12 | 获取PAI-RAG服务索引中的知识库文件列表
13 | """
14 |
15 |
def list_files_from_rag_service(
    rag_endpoint: str,
    rag_api_key: str,
    knowledgebase: str,
    embed_dims: int,
    oss_path_prefix: str,
):
    """List knowledge-base files recorded in the PAI-RAG service index.

    Connects to the service's vector store and returns the documents found
    under ``oss_path_prefix``. Only Milvus-backed stores are supported.
    """
    store = get_vector_store(
        rag_endpoint=rag_endpoint,
        rag_api_key=rag_api_key,
        knowledgebase=knowledgebase,
        embed_dims=embed_dims,
    )

    # Guard clause: anything but Milvus is unsupported for now.
    if not isinstance(store, MilvusVectorStore):
        logger.error("Only support Milvus vector store for now.")
        raise NotImplementedError

    return list_docs_in_milvus_collection(
        collection=store._collection,
        oss_path_prefix=oss_path_prefix,
    )
37 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/delta/milvus.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from pymilvus import Collection
3 |
4 | from pairag.data_pipeline.constants import (
5 | DEFAULT_MODIFIED_AT_FIELD,
6 | DEFAULT_NODE_SOURCE_FIELD,
7 | )
8 | from pairag.data_pipeline.delta.models import DocItem
9 |
10 |
def list_docs_in_milvus_collection(
    collection: Collection,
    oss_path_prefix: str,
) -> Dict[str, DocItem]:
    """Group the nodes of a Milvus collection by their source document.

    Returns a mapping of source path -> DocItem with that document's node ids
    and a modified time. An empty ``oss_path_prefix`` returns all documents.
    """
    # only return the documents under the oss_path_prefix
    if oss_path_prefix:
        expr = f"{DEFAULT_NODE_SOURCE_FIELD} like '{oss_path_prefix}%'"
    else:
        expr = None

    iterator = collection.query_iterator(
        batch_size=10,
        output_fields=[DEFAULT_NODE_SOURCE_FIELD, DEFAULT_MODIFIED_AT_FIELD],
        expr=expr,
    )
    results: Dict[str, DocItem] = {}
    while True:
        fetch_data = iterator.next()
        if not fetch_data:
            iterator.close()
            break

        # Scan the fetched batch, collecting node ids and a modified time
        # per source document.
        # NOTE(review): min() below keeps the EARLIEST modified time per
        # document, but the original comment said "latest" — confirm which
        # is intended.
        for record in fetch_data:
            source = record[DEFAULT_NODE_SOURCE_FIELD]
            if source not in results:
                results[source] = DocItem(
                    doc_path=source,
                    modified_time=record[DEFAULT_MODIFIED_AT_FIELD],
                    node_ids=[record["id"]],
                )
            else:
                results[source].modified_time = min(
                    record[DEFAULT_MODIFIED_AT_FIELD],
                    results[source].modified_time,
                )
                results[source].node_ids.append(record["id"])

    return results
50 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/delta/models.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from pydantic import BaseModel
3 |
4 |
class DocItem(BaseModel):
    """A source document and the vector-store nodes derived from it."""

    doc_path: str  # source document path
    node_ids: List[str]  # ids of the nodes belonging to this document
    modified_time: float  # document modification timestamp
9 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/e2e_config.yaml:
--------------------------------------------------------------------------------
1 | # process schedule
2 | # a list of several pairag offline process operators with their arguments
3 |
4 | operators:
5 | - name: "data_source"
6 | enable_delta: true
7 | supported_file_types_str: "pdf,txt,csv,xlsx,xls,docx,md,html,htm,jsonl"
8 | pairag_token:
9 | pairag_endpoint:
10 | pairag_knowledgebase:
11 | pairag_embed_dims: 1024
12 |
13 | - name: "parse"
14 | num_cpus: 2
15 | memory: 8
16 | num_gpus: 0
17 | enable_pdf_ocr: false
18 | concat_sheet_rows: false
19 | concurrency: 6
20 |
21 | - name: "split"
22 | node_parser_type: "Sentence"
23 | chunk_size: 1024
24 | chunk_overlap: 20
25 | paragraph_separator: "\n\n"
26 | num_cpus: 1
27 | memory: 4
28 | concurrency: 10
29 |
30 | - name: "embed"
31 | model: "bge-m3"
32 | source: "huggingface"
33 | batch_size: 300
34 | num_cpus: 4
35 | num_gpus: 1
36 | memory: 10
37 | concurrency: 1
38 |
39 | - name: "data_sink"
40 | pairag_token:
41 | pairag_endpoint:
42 | pairag_knowledgebase:
43 | pairag_embed_dims: 1024
44 | num_cpus: 1
45 | memory: 2
46 | batch_size: 500
47 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/models/config/datasource.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | from pydantic import BaseModel
3 |
4 |
class DataSourceConfig(BaseModel):
    """Config for the data source operator."""

    input_path: str
    output_path: str
    enable_delta: bool = False  # incremental (delta) ingestion — see data_pipeline.delta
    file_extensions: Optional[List[str]] = None  # e.g. [".pdf", ".txt"]; None = no filter

    # connect to rag service
    pairag_token: Optional[str] = None
    pairag_endpoint: Optional[str] = None
    pairag_knowledgebase: Optional[str] = "default"
    pairag_embed_dims: Optional[int] = 1024  # embedding vector dimension
16 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/models/config/operator.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Optional
3 | from pydantic import BaseModel
4 | from llama_index.core.constants import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
5 |
6 | from pairag.file.nodeparsers.pai.constants import (
7 | DEFAULT_NODE_PARSER_TYPE,
8 | DEFAULT_PARAGRAPH_SEP,
9 | )
10 |
11 |
class OperatorName(str, Enum):
    """Names of the pipeline operators, matching the e2e config `name` keys."""

    DATA_SOURCE = "data_source"
    PARSER = "parse"
    SPLITTER = "split"
    EMBEDDER = "embed"
    DATA_SINK = "data_sink"
18 |
19 |
class BaseOperatorConfig(BaseModel):
    """
    Base class for operator configs.

    Holds the resource requirements (CPUs, GPUs, memory in GB) and the I/O
    paths shared by all pipeline operators.
    """

    name: OperatorName
    num_cpus: float = 1
    num_gpus: float = 0
    memory: float = 2  # GB

    batch_size: int = 10
    input_path: str
    output_path: str
    # Fixed: a default of None should be typed Optional[str], not bare str.
    model_dir: Optional[str] = None
    concurrency: int = 1
35 |
36 |
class ParserConfig(BaseOperatorConfig):
    """
    Config for the parse operator.
    """

    name: OperatorName = OperatorName.PARSER
    enable_pdf_ocr: bool = False  # run OCR when parsing PDFs
    concat_sheet_rows: bool = False  # concatenate spreadsheet rows — TODO confirm semantics
    recursive: bool = True  # presumably recurse into subdirectories — verify against reader
46 |
47 |
class SplitterConfig(BaseOperatorConfig):
    """
    Config for the split operator.
    """

    name: OperatorName = OperatorName.SPLITTER
    paragraph_separator: str = DEFAULT_PARAGRAPH_SEP
    node_parser_type: str = DEFAULT_NODE_PARSER_TYPE
    chunk_size: int = DEFAULT_CHUNK_SIZE
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP
58 |
59 |
class EmbedderConfig(BaseOperatorConfig):
    """
    Config for the embed operator.
    """

    name: OperatorName = OperatorName.EMBEDDER
    model: str = "bge-m3"  # embedding model name
    connection_name: Optional[str] = None
    workspace_id: Optional[str] = None
    enable_sparse: bool = False  # also produce sparse embeddings — TODO confirm
    source: str = "huggingface"  # embedding source (see settings.toml [rag.embedding])
    batch_size: int = 32
72 |
73 |
class SinkConfig(BaseOperatorConfig):
    """
    Config for the data sink operator.
    """

    name: OperatorName = OperatorName.DATA_SINK
    # Connection parameters for the target PAI-RAG service.
    pairag_endpoint: str
    pairag_token: str
    pairag_knowledgebase: str
    pairag_embed_dims: int
84 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/models/file/event.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class NodeOperationType(str, Enum):
    """Operation applied to a vector-store node."""

    ADD = "add"
    DELETE = "delete"
7 |
8 |
class FileChangeType(str, Enum):
    """Type of change detected for a source file."""

    ADD = "add_file"
    MODIFY = "modify_file"
    DELETE = "delete_file"
13 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/models/models.py:
--------------------------------------------------------------------------------
1 | import time
2 | from pydantic import BaseModel, Field
3 | from typing import Dict
4 | from enum import Enum
5 |
6 | from pairag.utils.time_utils import get_current_time_str
7 |
8 |
class FileOperationType(int, Enum):
    """Numeric code for the operation requested on a file."""

    ADD = 1
    UPDATE = 2
    DELETE = 3
13 |
14 |
class FileChange(BaseModel):
    """A single file-level change event to be processed by the pipeline."""

    task_id: str
    file_name: str
    file_hash: str
    operation: FileOperationType
    knowledgebase: str  # target knowledge base name
21 |
22 |
class FileProcessStatus(str, Enum):
    """Lifecycle states of a file moving through the ingestion pipeline."""

    # NOTE(review): member naming is inconsistent (PENDING vs Parsing…);
    # renaming would break existing references, so it is kept as-is.
    PENDING = "pending"
    Parsing = "parsing"
    Chunking = "chunking"
    Embedding = "embedding"
    Persisting = "persisting"
    Done = "done"
    Failed = "failed"
31 |
32 |
class FileItem(FileChange):
    """A tracked file change plus its processing status and bookkeeping fields."""

    status: FileProcessStatus
    # Idiom fix: pass the callables directly; the lambda wrappers added nothing.
    last_modified_time: str = Field(default_factory=get_current_time_str)
    timestamp: float = Field(default_factory=time.time)
    failed_reason: str | None = None  # populated when status is Failed
38 |
39 |
class FileProcessResult(BaseModel):
    """Outcome of processing one file."""

    status: FileProcessStatus
    message: str | None = None  # optional detail, e.g. a failure reason
43 |
44 |
class TaskInfo(BaseModel):
    """Per-knowledgebase ingestion task bookkeeping."""

    knowledgebase: str
    task_map: Dict[str, FileItem] = {}  # presumably keyed by file/task id — verify against callers
    # Idiom fix: pass the callable directly; the lambda wrapper added nothing.
    last_modified_time: str = Field(default_factory=get_current_time_str)
49 |
50 |
class JobStatus(BaseModel):
    """Aggregated status of all ingestion tasks — key semantics TODO confirm."""

    task_statuses: Dict[str, TaskInfo] = {}
53 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/operators/base.py:
--------------------------------------------------------------------------------
1 | from pairag.data_pipeline.utils.cuda_utils import is_cuda_available
2 |
3 |
4 | OUTPUT_BATCH_SIZE = 50000
5 |
6 |
7 | class BaseOperator:
8 | def __init__(
9 | self,
10 | name: str = "default",
11 | num_cpus: float = 1,
12 | num_gpus: float = 0,
13 | model_dir: str = None,
14 | **kwargs,
15 | ):
16 | self.name = name
17 | self.num_cpus = num_cpus
18 | self.num_gpus = num_gpus
19 | self.model_dir = model_dir
20 | self.kwargs = kwargs
21 |
22 | def process(self, *args, **kwargs):
23 | raise NotImplementedError
24 |
25 | def use_cuda(self):
26 | return self.num_gpus > 0 and is_cuda_available()
27 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/operators/split.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 | from llama_index.core.schema import TextNode
3 | from pairag.data_pipeline.models.config.operator import SplitterConfig
4 | from pairag.data_pipeline.models.file.event import NodeOperationType
5 | from pairag.data_pipeline.operators.base import BaseOperator
6 | from pairag.data_pipeline.utils.node_utils import (
7 | metadata_dict_to_node_v2,
8 | node_to_metadata_dict_v2,
9 | )
10 | from pairag.file.nodeparsers.pai.pai_node_parser import (
11 | NodeParserConfig,
12 | PaiNodeParser,
13 | )
14 | from loguru import logger
15 |
16 |
class Splitter(BaseOperator):
    """Operator that splits parsed documents into text chunks (nodes)."""

    def __init__(self, config: SplitterConfig):
        super().__init__(
            name=config.name,
            num_cpus=config.num_cpus,
            num_gpus=config.num_gpus,
            model_dir=config.model_dir,
        )

        parser_config = NodeParserConfig(
            type=config.node_parser_type,
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
        )
        self.node_parser_config = parser_config
        self.node_parser = PaiNodeParser(parser_config=parser_config)
        logger.info(
            f"""SplitterActor init finished with following parameters: {config}"""
        )

    def process_delete(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Delete events carry nothing to split; pass the row through unchanged.
        return [row]

    def process_add(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
        doc = metadata_dict_to_node_v2(row)
        splitted_nodes = self.node_parser.get_nodes_from_documents([doc])

        out_rows: List[Dict[str, Any]] = []
        for node in splitted_nodes:
            # Drop chunks whose text content is empty.
            if isinstance(node, TextNode) and not node.text:
                continue
            record = node_to_metadata_dict_v2(node)
            record["operation"] = row.get("operation")
            record["operation_reason"] = row.get("operation_reason")
            out_rows.append(record)

        logger.info(f"Split {len(splitted_nodes)} nodes from {doc.node_id}.")
        return out_rows

    def __call__(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Route delete events separately; everything else is treated as an add.
        if row.get("operation") == NodeOperationType.DELETE:
            return self.process_delete(row)
        return self.process_add(row)
63 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/readme.md:
--------------------------------------------------------------------------------
1 | ## 分步骤执行示例
2 |
3 | 设置output_path,每次执行选择一个新路径,做好数据隔离
4 |
5 | ```shell
6 | export OUTPUT_PATH=/testdata/output0428
7 | ```
8 |
9 | 1. Read命令
10 |
11 | ```shell
12 | python src/pairag/data_ingestion/main.py read --input-path /testdata/testdata/txt --output-path $OUTPUT_PATH/read --enable-delta
13 |
14 | ```
15 |
16 | 2. Parse命令
17 |
18 | ```shell
19 | python src/pairag/data_ingestion/main.py parse --input-path $OUTPUT_PATH/read --output-path $OUTPUT_PATH/parse \
20 | --num-cpus 2 --memory 8
21 | ```
22 |
23 | 3. Split命令
24 |
25 | ```shell
26 | python src/pairag/data_ingestion/main.py split --input-path $OUTPUT_PATH/parse --output-path $OUTPUT_PATH/split \
27 | --num-cpus 1 --memory 2
28 | ```
29 |
30 | 4. Embed命令
31 |
32 | ```shell
33 | python src/pairag/data_ingestion/main.py embed --input-path $OUTPUT_PATH/split --output-path $OUTPUT_PATH/embed \
34 | --num-cpus 5 --memory 10 --num-gpus 1 --source huggingface --model bge-m3
35 | ```
36 |
37 | 5. Write命令,存入向量数据库
38 |
39 | ```shell
40 | python src/pairag/data_ingestion/main.py write --input-path $OUTPUT_PATH/embed --num-cpus 10 --memory 20
41 | ```
42 |
43 | ## E2E执行示例
44 |
45 | 设置output_path,每次执行选择一个新路径,做好数据隔离
46 |
47 | ```shell
48 | export OUTPUT_PATH=/testdata/output0505
49 | export pairag_ENDPOINT=
50 | export pairag_TOKEN=
51 |
52 | python src/pairag/data_pipeline/main.py e2e --input-path xx --output-path $OUTPUT_PATH
53 | ```
54 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/concurrency_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import psutil
4 | from loguru import logger
5 |
6 |
def compute_concurrency_count(
    num_cpus: int,
    memory: int,
    num_gpus: int = 0,
):
    """Compute how many operator replicas fit on this machine.

    The result is the minimum of what the available GPUs, CPUs and memory
    allow for a replica that needs ``num_cpus`` CPUs, ``memory`` GB of RAM
    and ``num_gpus`` GPUs.

    Raises:
        ValueError: if ``num_gpus`` > 0 but no CUDA device is found.
    """
    concurrency = sys.maxsize

    if num_gpus > 0:
        # Lazy import so CPU-only deployments never need torch.
        import torch

        cuda_device_count = torch.cuda.device_count()
        if cuda_device_count > 0:
            concurrency = math.floor(min(concurrency, cuda_device_count // num_gpus))
            # Fixed typo in log message: "conccurency" -> "concurrency".
            logger.info(
                f"Available CUDA devices: {cuda_device_count}, required gpus {num_gpus}, updated concurrency {concurrency}."
            )
        else:
            logger.error("No CUDA devices found.")
            raise ValueError("No CUDA devices found.")

    # Reserve 4 CPUs for system / driver processes.
    cpu_available = psutil.cpu_count() - 4
    mem_available = psutil.virtual_memory().available
    mem_available = mem_available / 1024**3  # bytes -> GB

    concurrency = math.floor(min(concurrency, cpu_available // num_cpus))
    logger.info(
        f"Available CPUs: {cpu_available}, required cpus {num_cpus}, updated concurrency {concurrency}."
    )

    concurrency = math.floor(min(concurrency, mem_available // memory))
    logger.info(
        f"Available memory: {mem_available}GB, required memory {memory}GB, updated concurrency {concurrency}."
    )

    return concurrency
42 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/cuda_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | from loguru import logger
5 | from PIL import ImageFile
6 | import importlib.metadata
7 | import importlib.util
8 | from typing import Tuple, Union
9 |
# Let PIL load truncated images instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# For now, only INFO and above will be shown. The severity level is changed
# later when setup_logger is called to initialize the logger.
logger.remove()
logger.add(sys.stderr, level="INFO")
16 |
17 |
def _is_package_available(
    pkg_name: str, return_version: bool = False
) -> Union[Tuple[bool, str], bool]:
    """Report whether *pkg_name* is installed as a real distribution.

    A resolvable module spec alone is not enough (a stray directory named
    like the package would match), so the package metadata version must
    also resolve.
    """
    package_version = "N/A"
    package_exists = importlib.util.find_spec(pkg_name) is not None
    if package_exists:
        try:
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            package_exists = False
    logger.debug(f"Detected {pkg_name} version {package_version}")
    return (package_exists, package_version) if return_version else package_exists
36 |
37 |
def _cuda_device_count():
    """Count CUDA devices, preferring torch, falling back to `nvidia-smi -L`."""
    if _is_package_available("torch"):
        import torch

        return torch.cuda.device_count()

    try:
        smi_listing = subprocess.check_output(["nvidia-smi", "-L"], text=True)
        device_lines = smi_listing.strip().split("\n")

        if os.getenv("CUDA_VISIBLE_DEVICES") is not None:
            logger.warning(
                "CUDA_VISIBLE_DEVICES is ignored when torch is unavailable. "
                "All detected GPUs will be used."
            )

        return len(device_lines)
    except Exception:
        # nvidia-smi not found or other error
        return 0
61 |
62 |
63 | _CUDA_DEVICE_COUNT = _cuda_device_count()
64 |
65 |
def cuda_device_count():
    """Return the CUDA device count detected at import time."""
    return _CUDA_DEVICE_COUNT
68 |
69 |
def is_cuda_available():
    """True when at least one CUDA device was detected at import time."""
    return _CUDA_DEVICE_COUNT > 0
72 |
73 |
def get_num_gpus(use_cuda, op_proc):
    """Return the fractional GPU allocation per operator process.

    Args:
        use_cuda: whether GPUs should be used at all.
        op_proc: number of operator processes sharing the GPUs.

    Returns:
        0 when CUDA is disabled; otherwise the per-process GPU share
        (total CUDA devices divided by the process count).
    """
    if not use_cuda:
        return 0
    # Simplified from 1.0 / (op_proc / cuda_device_count()); also avoids a
    # ZeroDivisionError when use_cuda is True but no device is detected.
    return cuda_device_count() / op_proc
79 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/download_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import fcntl
4 | from loguru import logger
5 | from pairag.utils.download_models import ModelScopeDownloader
6 |
7 |
def download_models_via_lock(model_dir, model_name, use_cuda: bool = False):
    """Download *model_name* into *model_dir*, serialized by a file lock.

    Multiple worker processes may race to fetch the same model; an exclusive
    fcntl lock ensures only one actually downloads while the others wait and
    then skip the download once the files exist.
    """
    model_dir = model_dir or os.getenv("PAIRAG_MODEL_DIR", "./model_repository")
    model_path = os.path.join(model_dir, model_name)
    lock_file_path = model_name + ".lock"
    # Create (or open) the lock file.
    with open(lock_file_path, "w") as lock_file:
        while True:
            try:
                # Try to acquire the file lock without blocking.
                fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB)
                logger.info(f"进程 {os.getpid()} 获得锁")

                # Skip the download when the model files already exist.
                if os.path.exists(model_path):
                    logger.info(f"进程 {os.getpid()} 检查到: 模型已下载完成,use_cuda: {use_cuda}")
                else:
                    logger.info(f"进程 {os.getpid()} 开始下载模型,环境: use_cuda: {use_cuda}。")
                    ModelScopeDownloader(
                        fetch_config=True,
                        download_directory_path=model_dir,
                    ).load_model(model=model_name)

                # Release the lock and stop retrying.
                fcntl.flock(lock_file, fcntl.LOCK_UN)
                logger.info(f"进程 {os.getpid()} 下载模型完成,use_cuda: {use_cuda}。")
                break

            except IOError as ex:
                logger.info(f"进程 {os.getpid()} 等待锁中... {ex}")
                time.sleep(1)  # wait, then retry acquiring the lock
38 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/file_ext_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
def parse_file_extensions(file_extensions_str: str) -> List[str]:
    """Turn a comma-separated extension string into a list of dotted extensions.

    Whitespace around each entry is stripped, and blank entries (for example
    from doubled or trailing commas) are dropped.

    Args:
        file_extensions_str (str): Comma-separated extensions, e.g. "pdf, txt".

    Returns:
        List[str]: Extensions with a leading dot, e.g. [".pdf", ".txt"].
    """
    extensions = []
    for raw_entry in file_extensions_str.split(","):
        cleaned = raw_entry.strip()
        if cleaned:
            extensions.append(f".{cleaned}")
    return extensions
17 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/filename_utils.py:
--------------------------------------------------------------------------------
1 | from ray.data.datasource import FilenameProvider
2 |
3 |
class BlockFileNameProvider(FilenameProvider):
    """Names Ray Dataset output blocks as "<run_label>_<task>_<block>.<format>"."""

    def __init__(self, run_label: str, file_format: str):
        # Label prefixing every filename of this run.
        self._run_label = run_label
        # Extension appended to every filename (without the dot).
        self._file_format = file_format

    def get_filename_for_block(self, block, task_index, block_index):
        """Return the zero-padded filename for one written block."""
        task_part = f"{task_index:06}"
        block_part = f"{block_index:06}"
        return f"{self._run_label}_{task_part}_{block_part}.{self._file_format}"
11 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/memory_utils.py:
--------------------------------------------------------------------------------
1 | # Others
2 | def size_to_bytes(size):
3 | alphabets_list = [char for char in size if char.isalpha()]
4 | numbers_list = [char for char in size if char.isdigit()]
5 |
6 | if len(numbers_list) == 0:
7 | raise ValueError(f"Your input `size` does not contain numbers: {size}")
8 |
9 | size_numbers = int(float("".join(numbers_list)))
10 |
11 | if len(alphabets_list) == 0:
12 | # by default, if users do not specify the units, the number will be
13 | # regarded as in bytes
14 | return size_numbers
15 |
16 | suffix = "".join(alphabets_list).lower()
17 |
18 | if suffix == "kb" or suffix == "kib":
19 | return size_numbers << 10
20 | elif suffix == "mb" or suffix == "mib":
21 | return size_numbers << 20
22 | elif suffix == "gb" or suffix == "gib":
23 | return size_numbers << 30
24 | elif suffix == "tb" or suffix == "tib":
25 | return size_numbers << 40
26 | elif suffix == "pb" or suffix == "pib":
27 | return size_numbers << 50
28 | elif suffix == "eb" or suffix == "eib":
29 | return size_numbers << 60
30 | elif suffix == "zb" or suffix == "zib":
31 | return size_numbers << 70
32 | elif suffix == "yb" or suffix == "yib":
33 | return size_numbers << 80
34 | else:
35 | raise ValueError(
36 | f"You specified unidentifiable unit: {suffix}, "
37 | f"expected in [KB, MB, GB, TB, PB, EB, ZB, YB, "
38 | f"KiB, MiB, GiB, TiB, PiB, EiB, ZiB, YiB], "
39 | f"(case insensitive, counted by *Bytes*)."
40 | )
41 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/path_resolver.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 |
class MountPathResolver:
    """Interface mapping between source URIs and locally mounted paths.

    Subclasses implement both directions of the mapping; the base methods
    only raise NotImplementedError. (The class does not use ABCMeta, so
    instantiation itself is not blocked — only calling the methods fails.)
    """

    @abstractmethod
    def resolve_destination_path(self, uri: str) -> str:
        """Map a source URI to the corresponding local file-system path."""
        raise NotImplementedError

    @abstractmethod
    def resolve_source_url(self, path: str) -> str:
        """Map a local file-system path back to its source URI."""
        raise NotImplementedError
18 |
19 |
# Local execution: the file path is simply the URI itself.
class LocalPathResolver(MountPathResolver):
    """Identity resolver for local runs where URIs already are local paths."""

    def resolve_destination_path(self, uri: str) -> str:
        """Return `uri` unchanged; no remapping is needed locally."""
        return uri

    def resolve_source_url(self, path: str) -> str:
        """Return `path` unchanged; no remapping is needed locally."""
        return path
33 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/path_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from loguru import logger
4 |
5 |
def clear_folder(folder_path: str):
    """Empty `folder_path` by deleting the directory tree and recreating it.

    Logs a warning and does nothing when the path is missing or not a
    directory.
    """
    if not (os.path.exists(folder_path) and os.path.isdir(folder_path)):
        logger.warning(
            f"Fail to clear path {folder_path} because it is not a directory or does not exist."
        )
        return

    # Remove the whole tree, then recreate an empty directory in its place.
    shutil.rmtree(folder_path)
    os.makedirs(folder_path)
    logger.info(f"Folder {folder_path} cleared successfully.")
18 |
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/vectordb_utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import requests
3 | from pairag.knowledgebase.index.pai.utils.vector_store_utils import create_vector_store
4 | from pairag.knowledgebase.index.pai.vector_store_config import (
5 | BaseVectorStoreConfig,
6 | SupportedVectorStoreType,
7 | )
8 | from pairag.knowledgebase.models import KnowledgeBase
9 | from loguru import logger
10 |
11 |
def get_vector_store_config(
    rag_endpoint: str, rag_api_key: str, knowledgebase: str
) -> BaseVectorStoreConfig:
    """Fetch the vector-store config of a knowledge base from the RAG service.

    Args:
        rag_endpoint: Base URL of the RAG service.
        rag_api_key: Bearer token used for authorization.
        knowledgebase: Name/ID of the knowledge base to look up.

    Returns:
        BaseVectorStoreConfig: The knowledge base's vector-store configuration.

    Raises:
        Exception: Re-raised after logging when the request fails, the server
            returns an HTTP error status, or response validation fails.
    """
    try:
        response = requests.get(
            f"{rag_endpoint}/api/v1/knowledgebases/{knowledgebase}",
            headers={"Authorization": f"Bearer {rag_api_key}"},
            timeout=60,  # avoid hanging forever on an unresponsive service
        )
        # Fail fast with a clear HTTP error instead of validating an error payload.
        response.raise_for_status()
        # Use a distinct local name instead of shadowing the `knowledgebase` param.
        kb: KnowledgeBase = KnowledgeBase.model_validate(response.json())
        return kb.vector_store_config
    except Exception as e:
        logger.error(f"Failed to get vector store config: {e}")
        # Bare raise preserves the original traceback.
        raise
25 |
26 |
def get_vector_store(
    rag_endpoint: str,
    rag_api_key: str,
    knowledgebase: str,
    embed_dims: int,
):
    """Create a vector-store client for the given knowledge base.

    Fetches the knowledge base's vector-store configuration from the RAG
    service and instantiates the corresponding store.

    Args:
        rag_endpoint: Base URL of the RAG service.
        rag_api_key: Bearer token used for authorization.
        knowledgebase: Name/ID of the knowledge base.
        embed_dims: Embedding dimensionality for the store.

    Raises:
        ValueError: If the knowledge base is configured with FAISS, which is
            not supported here.
    """
    vector_store_config = get_vector_store_config(
        rag_endpoint=rag_endpoint, rag_api_key=rag_api_key, knowledgebase=knowledgebase
    )

    logger.info(f"Creating vector store from config {vector_store_config}.")
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if vector_store_config.type == SupportedVectorStoreType.faiss:
        raise ValueError("FAISS is not supported.")

    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())

    vector_store = create_vector_store(
        vectordb_config=vector_store_config,
        embed_dims=embed_dims,
    )

    logger.info(
        f"""[PaiVectorStore] init finished with following parameters:
                config: {vector_store_config}
                embed_dims: {embed_dims}
            """
    )

    return vector_store
57 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/dataset/crag/crag_data_loader.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from llama_index.core.indices import VectorStoreIndex
3 | from llama_index.core.ingestion import IngestionPipeline
4 | from pairag.file.readers.pai.pai_data_reader import PaiDataReader
5 | from loguru import logger
6 |
7 |
class CragDataLoader:
    """Loads CRAG corpus files, embeds them, and inserts nodes into a vector index."""

    def __init__(
        self,
        data_reader: PaiDataReader,
        embed_model: Any = None,
        vector_index: VectorStoreIndex = None,
    ):
        # Reader that turns files into Documents.
        self._data_reader = data_reader
        # Embedding transform applied by the ingestion pipeline.
        self._embed_model = embed_model
        # Target index receiving the embedded nodes.
        self._vector_index = vector_index

    def load_data(
        self,
        file_path_or_directory: str,
        from_oss: bool = False,
        oss_path: str = None,
        filter_pattern: str = None,
        enable_raptor: bool = False,  # accepted for interface parity; unused here
    ):
        """Load data from a file or directory.

        Reads documents (locally or from OSS), runs them through an ingestion
        pipeline whose only transformation is the embedding model, and inserts
        the resulting nodes into the vector index.
        """
        documents = self._data_reader.load_data(
            file_path_or_directory=file_path_or_directory,
            filter_pattern=filter_pattern,
            oss_path=oss_path,
            from_oss=from_oss,
        )
        if from_oss:
            logger.info(f"Loaded {len(documents)} documents from {oss_path}")
        else:
            logger.info(
                f"Loaded {len(documents)} documents from {file_path_or_directory}"
            )

        # The pipeline's only step is embedding; no chunking happens here.
        transformations = [
            self._embed_model,
        ]

        ingestion_pipeline = IngestionPipeline(transformations=transformations)

        nodes = ingestion_pipeline.run(documents=documents)
        logger.info(
            f"[DataLoader] parsed {len(documents)} documents into {len(nodes)} nodes."
        )

        self._vector_index.insert_nodes(nodes)
        logger.info(f"[DataLoader] Inserted {len(nodes)} nodes.")
        logger.info("[DataLoader] Ingestion Completed!")
55 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/dataset/crag/crag_jsonl_reader.py:
--------------------------------------------------------------------------------
1 | """Tabular parser-Excel parser.
2 |
3 | Contains parsers for tabular data files.
4 |
5 | """
6 |
7 | from pathlib import Path
8 | from typing import Any, Dict, List, Optional
9 | from fsspec import AbstractFileSystem
10 | from llama_index.core.readers.base import BaseReader
11 | from llama_index.core.schema import Document
12 | import json
13 |
14 |
class CragJsonLReader(BaseReader):
    """JsonL reader for the CRAG dataset.

    Each line of the input file is a JSON object with an `interaction_id` and
    a list of `search_results`; every search result becomes one Document
    whose text is its `page_snippet`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)

    def load_data(
        self,
        file_path: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse `file_path` and return one Document per search-result snippet.

        Args:
            file_path: Path to the .jsonl file.
            extra_info: Optional base metadata merged into every Document.
            fs: Unused; accepted for BaseReader interface compatibility.

        Returns:
            List[Document]: One Document per search result, with doc ids of
            the form "<interaction_id>__<result_index>".
        """
        with open(file_path, "r", encoding="utf-8") as file:
            json_lines = [line.strip() for line in file.readlines()]

        docs = []
        for i, text in enumerate(json_lines):
            if not text:
                continue  # tolerate blank/trailing lines instead of crashing
            json_data = json.loads(text)
            for j, search_result in enumerate(json_data["search_results"]):
                # Build a fresh metadata dict per Document. The previous code
                # mutated and shared one `extra_info` dict across ALL
                # Documents (so every metadata ended up identical) and raised
                # TypeError when `extra_info` was None.
                metadata = dict(extra_info or {})
                metadata["row_number"] = i + 1
                metadata["dataset_source"] = "crag"
                docs.append(
                    Document(
                        doc_id=f"{json_data['interaction_id']}__{j}",
                        text=search_result["page_snippet"],
                        metadata=metadata,
                    )
                )
        return docs
46 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/dataset/rag_eval_dataset_refactor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | from typing import List, Optional, Dict, Any
4 | from pydantic import BaseModel, Field, ValidationError
5 | from pairag.evaluation.dataset.rag_qca_dataset_refactor import Source, QcapSample
6 | from loguru import logger
7 |
8 |
def generate_eval_uuid() -> str:
    """Return a fresh evaluation-sample id of the form "eval_<uuid4>"."""
    return "eval_" + str(uuid.uuid4())
11 |
12 |
class EvalResults(BaseModel):
    """Metric results for one evaluated sample, plus their provenance."""

    # Arbitrary metric-name -> score/detail mapping.
    results: Dict[str, Any]
    # Where the evaluation came from; defaults to an empty Source.
    source: Optional[Source] = Source()
16 |
17 |
class QcapEvalSample(BaseModel):
    """
    A complete evaluated QCAP sample: its id, the QCAP record, and the
    evaluation results.
    """

    # Unique id, auto-generated as "eval_<uuid4>".
    id: str = Field(default_factory=generate_eval_uuid)
    # The question/context/answer/prediction record being evaluated.
    qcap: QcapSample
    # Metric results for this sample.
    eval_results: EvalResults
26 |
27 |
class QcapEvalDataset(BaseModel):
    """
    Evaluation dataset holding multiple QcapEvalSample records.
    """

    samples: List[QcapEvalSample]

    def save_json(self, file_path: str):
        """
        Save the dataset as JSONL: one QcapEvalSample JSON object per line.
        """
        try:
            with open(file_path, "w", encoding="utf-8") as f:
                for sample in self.samples:
                    f.write(sample.model_dump_json())
                    f.write("\n")
            logger.info(f"数据集已成功保存到 {file_path}")
        except Exception as e:
            # Failure path: log at error level (was logger.info).
            logger.error(f"保存数据集时出错: {e}")

    @classmethod
    def from_json(cls, file_path: str) -> "QcapEvalDataset":
        """
        Load a QcapEvalDataset from a JSONL file.

        Malformed lines are skipped (and logged); a missing or unreadable
        file yields an empty dataset rather than raising.
        """
        samples = []
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line_number, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue  # skip blank lines
                    try:
                        sample_dict = json.loads(line)
                        sample = QcapEvalSample(**sample_dict)
                        samples.append(sample)
                    except (json.JSONDecodeError, ValidationError) as e:
                        # Bad line: warn and keep loading the rest.
                        logger.warning(f"在第 {line_number} 行读取样本时出错: {e}")
            logger.info(f"从 {file_path} 成功加载了 {len(samples)} 个样本")
            return cls(samples=samples)
        except FileNotFoundError:
            logger.warning(f"文件 {file_path} 未找到。")
            return cls(samples=[])
        except Exception as e:
            logger.error(f"读取数据集时出错: {e}")
            return cls(samples=[])
74 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/dataset/state_manager.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from enum import Enum
4 |
5 |
class DatasetState(Enum):
    """Pipeline stages whose completion is tracked by StateManager."""

    QCA = "qca"
    QCAP = "qcap"
    RETRIEVAL = "retrieval_evaluation"
    RESPONSE = "response_evaluation"
    E2E = "end2end_evaluation"
12 |
13 |
class StateManager:
    """Tracks which DatasetState stages have completed, persisted as JSON."""

    def __init__(self, state_file="state.json"):
        self.state_file = state_file
        # One completion flag per stage, all initially not completed.
        self.states = dict.fromkeys(DatasetState, False)
        self.load_state()

    def load_state(self):
        """Load completion flags from the JSON state file, if present."""
        if not os.path.exists(self.state_file):
            print("状态文件不存在,初始化所有状态为未完成")
            return
        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                data = json.load(f)
                for member in DatasetState:
                    # Only override flags that the JSON actually contains.
                    if member.value in data:
                        self.states[member] = data[member.value]
            print(f"加载的状态: {self.states}")
        except Exception as e:
            # On any error, leave every flag as not-completed.
            print(f"加载状态时出错: {e}")

    def save_state(self):
        """Persist the completion flags to the JSON state file."""
        try:
            data = {member.value: done for member, done in self.states.items()}
            with open(self.state_file, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=4)
            print(f"状态已保存为: {data}")
        except Exception as e:
            print(f"保存状态时出错: {e}")

    def mark_completed(self, new_state):
        """Mark `new_state` as completed and persist immediately."""
        if not isinstance(new_state, DatasetState):
            raise ValueError("new_state 必须是 DatasetState 的实例")
        self.states[new_state] = True
        self.save_state()
        print(f"状态已标记为完成: {new_state}")

    def is_completed(self, state):
        """Return True if `state` has been marked completed."""
        if not isinstance(state, DatasetState):
            raise ValueError("state 必须是 DatasetState 的实例")
        return self.states.get(state, False)

    def get_completed_states(self):
        """Return the list of stages currently marked completed."""
        return [member for member, done in self.states.items() if done]

    def reset_state(self):
        """Set every stage back to not-completed and persist."""
        for member in self.states:
            self.states[member] = False
        self.save_state()
        print("所有状态已重置为未完成")
77 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/metrics/response/base.py:
--------------------------------------------------------------------------------
1 | """Llm metric for response evaluation."""
2 | from abc import abstractmethod
3 | from typing import Any, Optional, Sequence
4 |
5 | from llama_index.core.evaluation.base import EvaluationResult
6 | from llama_index.core.llms.llm import LLM
7 | from llama_index.core.prompts.mixin import PromptDictType
8 | from llama_index.core.prompts.mixin import PromptMixin, PromptMixinType
9 |
10 |
class LlmMetric(PromptMixin):
    """
    Base class for LLM-judged response-evaluation metrics.

    Subclasses implement `aevaluate` (run the judge LLM) and
    `parse_eval_result` (turn the raw LLM output into a score).
    """

    # Identifier of the metric; subclasses override this.
    metric_name: str = "base"

    def __init__(
        self,
        llm: Optional[LLM] = None,
        raise_error: bool = False,
    ) -> None:
        """Init params.

        Args:
            llm: Judge LLM used by subclasses to score responses.
            raise_error: Whether evaluation failures should raise (consumed
                by subclasses).
        """
        self._llm = llm
        self._raise_error = raise_error

    def _get_prompts(self) -> PromptDictType:
        """Get prompts.

        NOTE(review): `_eval_template` is never set in __init__; it must be
        defined by the subclass or injected via `_update_prompts` before this
        is called, otherwise AttributeError — confirm subclass contracts.
        """
        return {
            "eval_template": self._eval_template,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "eval_template" in prompts:
            self._eval_template = prompts["eval_template"]

    @abstractmethod
    async def parse_eval_result(self, eval_result: str) -> float:
        """Parse the raw LLM judgement string into a numeric score."""
        raise NotImplementedError

    @abstractmethod
    async def aevaluate(
        self,
        query: str | None = None,
        reference_answer: str | None = None,
        contexts: Sequence[str] | None = None,
        response_answer: str | None = None,
        **kwargs: Any,
    ) -> EvaluationResult:
        """Run evaluation with query string, retrieved contexts,
        and generated response string.

        Subclasses can override this method to provide custom evaluation logic and
        take in additional arguments.
        """
        raise NotImplementedError

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules."""
        return {}
--------------------------------------------------------------------------------
/src/pairag/evaluation/metrics/retrieval/core.py:
--------------------------------------------------------------------------------
def default_hit_rate(expected_ids, retrieved_ids):
    """Default HitRate: 1.0 if any retrieved id appears in expected_ids, else 0.0."""
    for retrieved in retrieved_ids:
        if retrieved in expected_ids:
            return 1.0
    return 0.0
6 |
7 |
def granular_hit_rate(expected_ids, retrieved_ids):
    """Granular HitRate: retrieved hits divided by the number of expected docs."""
    if not expected_ids:
        return 0.0
    expected_set = set(expected_ids)
    # Booleans sum as 0/1, counting the retrieved docs that are expected.
    hit_count = sum(doc_id in expected_set for doc_id in retrieved_ids)
    return hit_count / len(expected_ids)
14 |
15 |
def default_mrr(expected_ids, retrieved_ids):
    """Default MRR: reciprocal rank of the first retrieved id found in expected_ids."""
    # 1-based rank; avoids shadowing the builtin `id` used in the original.
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id in expected_ids:
            return 1.0 / rank
    return 0.0
22 |
23 |
def granular_mrr(expected_ids, retrieved_ids):
    """Granular MRR: average of reciprocal ranks over all relevant retrieved docs."""
    expected_set = set(expected_ids)
    reciprocal_ranks = [
        1.0 / (position + 1)
        for position, doc_id in enumerate(retrieved_ids)
        if doc_id in expected_set
    ]
    if not reciprocal_ranks:
        return 0.0
    return sum(reciprocal_ranks) / len(reciprocal_ranks)
37 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/metrics/retrieval/hitrate.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | from pairag.evaluation.metrics.retrieval.core import (
3 | granular_hit_rate,
4 | default_hit_rate,
5 | )
6 |
7 |
class HitRate:
    """Hit rate metric: Compute hit rate with two calculation options.

    - The default mode reports a hit as soon as any retrieved doc matches an expected doc.
    - The granular mode scores all matches relative to the number of expected docs.

    Attributes:
        metric_name (str): The name of the metric.
        use_granular_hit_rate (bool): Whether the granular calculation is used.
    """

    def __init__(
        self, metric_name: str = "hitrate", use_granular_hit_rate: bool = False
    ):
        self.metric_name = metric_name
        self.use_granular_hit_rate = use_granular_hit_rate

    def compute(
        self,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
    ):
        """Compute the hit-rate score for one query.

        Parameters:
            expected_ids (Optional[List[str]]): Expected document IDs.
            retrieved_ids (Optional[List[str]]): Retrieved document IDs.

        Raises:
            ValueError: When either ID list is missing or empty.

        Returns:
            The computed hit-rate score.
        """
        # `not x` rejects both None and empty lists in one check.
        if not expected_ids or not retrieved_ids:
            raise ValueError("Retrieved ids and expected ids must be provided")

        scorer = granular_hit_rate if self.use_granular_hit_rate else default_hit_rate
        return scorer(expected_ids, retrieved_ids)
55 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/metrics/retrieval/mrr.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | from pairag.evaluation.metrics.retrieval.core import (
3 | default_mrr,
4 | granular_mrr,
5 | )
6 |
7 |
class MRR:
    """MRR (Mean Reciprocal Rank) metric with two calculation options.

    - The default mode uses the reciprocal rank of the first relevant retrieved document.
    - The granular mode averages the reciprocal ranks of all relevant retrieved documents.

    Attributes:
        metric_name (str): The name of the metric.
        use_granular_mrr (bool): Whether the granular calculation is used.
    """

    def __init__(self, metric_name: str = "mrr", use_granular_mrr: bool = False):
        self.metric_name = metric_name
        self.use_granular_mrr = use_granular_mrr

    def compute(
        self,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
    ):
        """Compute the MRR score for one query.

        Parameters:
            expected_ids (Optional[List[str]]): Expected document IDs.
            retrieved_ids (Optional[List[str]]): Retrieved document IDs.

        Raises:
            ValueError: When either ID list is missing or empty.

        Returns:
            The computed MRR score.
        """
        # `not x` rejects both None and empty lists in one check.
        if not expected_ids or not retrieved_ids:
            raise ValueError("Retrieved ids and expected ids must be provided")

        scorer = granular_mrr if self.use_granular_mrr else default_mrr
        return scorer(expected_ids, retrieved_ids)
53 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/pipeline/run_evaluation_pipeline.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from pairag.evaluation.utils.create_components import (
3 | get_rag_components,
4 | get_rag_config_and_mode,
5 | get_eval_components,
6 | )
7 |
8 |
def run_rag_evaluation_pipeline(
    config_file=None,
    oss_path=None,
    data_path=None,
    pattern=None,
    exp_name="default",
    eval_model_llm_config=None,
    dataset=None,
    use_pai_eval=False,
):
    """Run the full RAG evaluation pipeline.

    Ingests the corpus (only when the experiment has no existing index),
    generates the QCA dataset, then runs retrieval, response, and end-to-end
    evaluation.

    Args:
        config_file: Path to the RAG config file.
        oss_path: OSS location of the corpus (mutually exclusive with data_path).
        data_path: Local location of the corpus (mutually exclusive with oss_path).
        pattern: Optional filename filter pattern for ingestion.
        exp_name: Experiment name used to namespace results.
        eval_model_llm_config: LLM config for the evaluation model.
        dataset: Optional named dataset to evaluate against.
        use_pai_eval: Whether to use PAI's evaluation service.

    Raises:
        ValueError: If neither or both of `data_path` and `oss_path` are given.
    """
    # Validate with explicit raises (not `assert`) so the checks are not
    # stripped when Python runs with -O.
    if oss_path is None and data_path is None:
        raise ValueError("Must provide either local path or oss path.")
    if oss_path is not None and data_path is not None:
        raise ValueError(
            f"Can not provide both local path '{data_path}' and oss path '{oss_path}'."
        )

    config, mode, exist_flag = get_rag_config_and_mode(config_file, exp_name)
    data_loader, vector_index, query_engine = get_rag_components(config, dataset)
    # Only ingest when this experiment has no previously-built index.
    if not exist_flag:
        data_loader.load_data(
            file_path_or_directory=data_path,
            filter_pattern=pattern,
            oss_path=oss_path,
            from_oss=oss_path is not None,
            enable_raptor=False,
        )

    qca_generator, evaluator = get_eval_components(
        config,
        vector_index,
        query_engine,
        mode,
        eval_model_llm_config,
        use_pai_eval,
    )

    # Generate the QCA dataset first, then run each evaluation stage.
    _ = asyncio.run(
        qca_generator.agenerate_all_dataset(dataset=dataset, dataset_path=data_path)
    )
    asyncio.run(evaluator.aevaluation_for_retrieval())
    asyncio.run(evaluator.aevaluation_for_response())
    asyncio.run(evaluator.aevaluation_for_all())
52 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/pipeline/run_multimodal_evaluation_pipeline.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from pairag.evaluation.utils.create_components import (
3 | get_rag_components,
4 | get_rag_config_and_mode,
5 | get_multimodal_eval_components,
6 | )
7 | from pairag.evaluation.dataset.state_manager import DatasetState
8 |
9 |
def run_multimodal_evaluation_pipeline(
    config_file=None,
    oss_path=None,
    qca_dataset_path=None,
    data_path=None,
    pattern=None,
    exp_name="default",
    eval_model_llm_config=None,
    tested_multimodal_llm_config=None,
):
    """Run the multimodal (image-mode) RAG evaluation pipeline.

    When `qca_dataset_path` is given, the QCA stage is skipped: predictions
    are generated directly via the VLM and only the response evaluation runs.
    Otherwise the corpus is ingested (if needed) and the full retrieval +
    response evaluation runs.

    Raises:
        ValueError: If the experiment config is not in "image" mode, or if
            (when no QCA dataset is supplied) neither or both of
            `data_path`/`oss_path` are given.
    """
    config, mode, exist_flag = get_rag_config_and_mode(config_file, exp_name)
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if mode != "image":
        raise ValueError(f"Multimodal evaluation requires mode 'image', got '{mode}'.")
    data_loader, vector_index, query_engine = get_rag_components(config)
    multimodal_qca_generator, evaluator = get_multimodal_eval_components(
        config,
        exp_name,
        vector_index,
        query_engine,
        eval_model_llm_config,
        tested_multimodal_llm_config,
        qca_dataset_path,
    )
    if qca_dataset_path:
        # A ready-made QCA dataset was supplied: mark that stage done and go
        # straight to VLM prediction plus response evaluation.
        multimodal_qca_generator.state_manager.mark_completed(DatasetState.QCA)
        _ = asyncio.run(
            multimodal_qca_generator.agenerate_predicted_multimodal_dataset_only_via_vlm()
        )
        asyncio.run(evaluator.aevaluation_for_response())
        return

    if oss_path is None and data_path is None:
        raise ValueError("Must provide either local path or oss path.")
    if oss_path is not None and data_path is not None:
        raise ValueError(
            f"Can not provide both local path '{data_path}' and oss path '{oss_path}'."
        )

    if not exist_flag:
        data_loader.load_data(
            file_path_or_directory=data_path,
            filter_pattern=pattern,
            oss_path=oss_path,
            from_oss=oss_path is not None,
            enable_raptor=False,
        )

    _ = asyncio.run(multimodal_qca_generator.agenerate_all_dataset())
    asyncio.run(evaluator.aevaluation_for_retrieval())
    asyncio.run(evaluator.aevaluation_for_response())
59 |
--------------------------------------------------------------------------------
/src/pairag/evaluation/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
def list_files_in_directory(directory_path):
    """Recursively collect the full paths of all files under `directory_path`."""
    return [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(directory_path)
        for filename in filenames
    ]
11 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/data_analysis_config.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Dict, List, Literal
3 | from pydantic import BaseModel
4 |
5 | from pairag.web.ui_constants import (
6 | NL2SQL_GENERAL_PROMPTS,
7 | SYN_GENERAL_PROMPTS,
8 | DA_SYSTEM_ROLE_PROMPT,
9 | )
10 |
11 |
class DataAnalysisType(str, Enum):
    """Data Analysis types."""

    pandas = "pandas"  # analyze a tabular file via pandas
    sqlite = "sqlite"  # analyze a SQLite database file
    mysql = "mysql"  # analyze a MySQL database over a connection
19 |
class BaseAnalysisConfig(BaseModel):
    """Base class for data analysis config."""

    # Discriminator identifying which analysis backend this config targets.
    type: DataAnalysisType
    # llm: OpenAICompatibleLlmConfig | None = OpenAICompatibleLlmConfig()
    # Identifier of the LLM used for analysis.
    model_id: str = "default"
    # Prompt used for NL-to-SQL generation.
    nl2sql_prompt: str = NL2SQL_GENERAL_PROMPTS
    # Prompt used to synthesize the final natural-language answer.
    synthesizer_prompt: str = SYN_GENERAL_PROMPTS
    # System-role prompt framing the assistant's behavior.
    system_role_prompt: str = DA_SYSTEM_ROLE_PROMPT
29 |
30 |
class PandasAnalysisConfig(BaseAnalysisConfig):
    """Config for analyzing uploaded tabular files with pandas."""

    type: Literal[DataAnalysisType.pandas] = DataAnalysisType.pandas
    # Directory where uploaded data files are stored.
    file_path: str = "./localdata/data_analysis/"
34 |
35 |
class SqlAnalysisConfig(BaseAnalysisConfig):
    """Shared config for SQL-backed analysis (SQLite and MySQL)."""

    # Database name to analyze.
    database: str
    # Optional whitelist of tables; empty means all tables.
    tables: List[str] = []
    # Optional per-table human-written descriptions.
    descriptions: Dict[str, str] = {}
    # offline
    enable_enhanced_description: bool = False
    enable_db_history: bool = False
    enable_db_embedding: bool = False
    # Caps applied when collecting schema/value information offline.
    max_col_num: int = 100
    max_val_num: int = 10000
    # online
    enable_query_preprocessor: bool = False
    enable_db_preretriever: bool = False
    enable_db_selector: bool = False
50 |
51 |
class SqliteAnalysisConfig(SqlAnalysisConfig):
    """SQL analysis config for a local SQLite database file."""

    type: Literal[DataAnalysisType.sqlite] = DataAnalysisType.sqlite
    # Directory (or path) containing the SQLite database file.
    db_path: str
55 |
56 |
class MysqlAnalysisConfig(SqlAnalysisConfig):
    """SQL analysis config for a MySQL server connection."""

    type: Literal[DataAnalysisType.mysql] = DataAnalysisType.mysql
    # Connection credentials and endpoint for the MySQL server.
    user: str
    password: str
    host: str
    port: int
63 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/nl2sql/db_utils/constants.py:
--------------------------------------------------------------------------------
# Default location of the generated structured DB description JSON.
DEFAULT_DB_DESCRIPTION_PATH = (
    "./localdata/data_analysis/nl2sql/description/db_structured_description.json"
)
# Default location of the recorded DB query-history JSON.
DEFAULT_DB_HISTORY_PATH = (
    "./localdata/data_analysis/nl2sql/history/db_query_history.json"
)

# Persisted index/storage directories for NL2SQL retrieval artifacts.
DESCRIPTION_STORAGE_PATH = "./localdata/data_analysis/nl2sql/storage/description_index"
HISTORY_STORAGE_PATH = "./localdata/data_analysis/nl2sql/storage/history_index"
VALUE_STORAGE_PATH = "./localdata/data_analysis/nl2sql/storage/value_index"
VALUE_LSH_PATH = "./localdata/data_analysis/nl2sql/storage/value_lsh"

# Embedding dimensionality per supported embedding model.
EMBEDDING_DIM_DICT = {"bge-large-zh-v1.5": 1024, "bge-m3": 1024}
14 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/nl2sql/query_preprocessor.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | from pydantic import BaseModel
3 |
4 | from llama_index.core.llms.llm import LLM
5 | from llama_index.core import Settings
6 | from llama_index.core import BasePromptTemplate
7 | from llama_index.core.schema import QueryBundle
8 |
9 | from pairag.integrations.data_analysis.nl2sql.nl2sql_prompts import (
10 | DEFAULT_KEYWORD_EXTRACTION_PROMPT,
11 | )
12 |
13 |
class QueryPreprocessor:
    """
    Preprocesses natural-language queries.

    Currently supports keyword extraction; history-aware query rewriting is
    still a stub.
    """

    def __init__(
        self,
        llm: Optional[LLM] = None,
        keyword_extraction_prompt: Optional[BasePromptTemplate] = None,
    ) -> None:
        """Init with an LLM (defaults to Settings.llm) and an extraction prompt."""
        self._llm = llm or Settings.llm
        self._keyword_extraction_prompt = (
            keyword_extraction_prompt or DEFAULT_KEYWORD_EXTRACTION_PROMPT
        )

    def extract_keywords(self, nl_query: QueryBundle) -> List[str]:
        """Extract keywords from `nl_query` via structured LLM prediction."""
        keyword_list_obj = self._llm.structured_predict(
            output_cls=KeywordList,
            prompt=self._keyword_extraction_prompt,
            # Force the function-calling tool so the LLM must emit a KeywordList.
            llm_kwargs={
                "tool_choice": {"type": "function", "function": {"name": "KeywordList"}}
            },
            query_str=nl_query.query_str,
            fewshot_examples="",
        )
        return keyword_list_obj.Keywords

    async def aextract_keywords(self, nl_query: QueryBundle) -> List[str]:
        """Async variant of `extract_keywords`."""
        keyword_list_obj = await self._llm.astructured_predict(
            output_cls=KeywordList,
            prompt=self._keyword_extraction_prompt,
            llm_kwargs={
                "tool_choice": {"type": "function", "function": {"name": "KeywordList"}}
            },
            query_str=nl_query.query_str,
            fewshot_examples="",
        )
        return keyword_list_obj.Keywords

    def transform_query(self, nl_query: QueryBundle) -> List[str]:
        """Rewrite the query using conversation history.

        TODO: not implemented yet; currently returns None despite the
        annotated return type.
        """
        pass
70 |
71 |
class KeywordList(BaseModel):
    """Data model for KeywordList."""

    # Keywords extracted from the user query; the capitalized field name
    # matches the LLM function-calling schema used above.
    Keywords: List[str]
76 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/test/test_bird_schema_collector.py:
--------------------------------------------------------------------------------
1 | from pairag.integrations.data_analysis.data_analysis_config import (
2 | SqliteAnalysisConfig,
3 | )
4 | from pairag.integrations.data_analysis.text2sql.db_connector import (
5 | SqliteConnector,
6 | )
7 | from pairag.integrations.data_analysis.text2sql.db_info_collector import (
8 | BirdSchemaCollector,
9 | )
10 |
11 |
# NOTE: hard-coded local BIRD dataset paths; this manual test script only
# runs on a machine where that dataset exists.
sqlite_config = SqliteAnalysisConfig(
    db_path="/Users/chuyu/Documents/datasets/BIRD/dev_20240627/dev_databases/california_schools/",
    database="california_schools.sqlite",
)

# Open the SQLite database and build a schema collector for it.
connector = SqliteConnector(sqlite_config)
sql_database = connector.connect()

bird_schema_collector = BirdSchemaCollector(
    db_name="california_schools",
    sql_database=sql_database,
    database_file_path="/Users/chuyu/Documents/datasets/BIRD/dev_20240627/dev_databases",
)

bird_schema_collector.collect()
27 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/test/test_db_connector.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dotenv import load_dotenv
3 |
4 | from llama_index.core import SQLDatabase
5 |
6 | from pairag.integrations.data_analysis.data_analysis_config import (
7 | SqliteAnalysisConfig,
8 | MysqlAnalysisConfig,
9 | )
10 | from pairag.integrations.data_analysis.text2sql.db_connector import (
11 | MysqlConnector,
12 | SqliteConnector,
13 | )
14 |
# Load environment variables from the local .env file.
load_dotenv()

# Read database connection settings from the environment.
host = os.getenv("host")
port = os.getenv("port")
user = os.getenv("user")
password = os.getenv("password")
database = os.getenv("database")

mysql_config = MysqlAnalysisConfig(
    host=host,
    port=port,
    user=user,
    password=password,
    database=database,
    tables=["pets"],
)

# Local SQLite fixture checked into the repo.
sqlite_config = SqliteAnalysisConfig(
    db_path="./tests/testdata/db_data/",
    database="pets.sqlite",
)
38 |
39 |
def test_mysql_connector():
    """MysqlConnector reports type 'mysql' and connects to a SQLDatabase."""
    mysql_connector = MysqlConnector(mysql_config)
    assert mysql_connector._db_config.type == "mysql"
    database_obj = mysql_connector.connect()
    assert isinstance(database_obj, SQLDatabase)
45 |
46 |
def test_sqlite_connector():
    """SqliteConnector reports type 'sqlite' and connects to a SQLDatabase."""
    sqlite_connector = SqliteConnector(sqlite_config)
    assert sqlite_connector._db_config.type == "sqlite"
    database_obj = sqlite_connector.connect()
    assert isinstance(database_obj, SQLDatabase)
52 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/test/test_db_info_collector.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dotenv import load_dotenv
3 |
4 |
5 | from pairag.integrations.data_analysis.data_analysis_config import (
6 | MysqlAnalysisConfig,
7 | )
8 | from pairag.integrations.data_analysis.text2sql.db_connector import (
9 | MysqlConnector,
10 | )
11 | from pairag.integrations.data_analysis.text2sql.db_info_collector import (
12 | SchemaCollector,
13 | HistoryCollector,
14 | ValueCollector,
15 | )
16 |
17 |
# Load environment variables from the local .env file.
load_dotenv()

# Read MySQL connection settings from the environment.
host = os.getenv("host")
port = os.getenv("port")
user = os.getenv("user")
password = os.getenv("password")
database = os.getenv("database")

mysql_config = MysqlAnalysisConfig(
    host=host,
    port=port,
    user=user,
    password=password,
    database=database,
    # tables=["pets"],
)

# Shared connection used by all tests below (established at import time).
connector = MysqlConnector(mysql_config)
sql_database = connector.connect()
print("connector_info:", sql_database)
40 |
41 |
def test_schema_processor():
    """SchemaCollector.collect should return a dict of schema descriptions."""
    collector = SchemaCollector(
        db_name=mysql_config.database, sql_database=sql_database
    )
    result = collector.collect()
    assert isinstance(result, dict)
48 |
49 |
def test_history_processor():
    """HistoryCollector.collect should return a list of query history items."""
    collector = HistoryCollector(db_name=mysql_config.database)
    result = collector.collect()
    assert isinstance(result, list)
54 |
55 |
def test_value_processor():
    """ValueCollector.collect should return a dict of unique column values."""
    collector = ValueCollector(
        db_name=mysql_config.database, sql_database=sql_database
    )
    result = collector.collect()
    assert isinstance(result, dict)
62 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/test/test_index_retriever.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | import hashlib
4 |
5 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding
6 | from llama_index.core.schema import TextNode
7 | from llama_index.core.base.embeddings.base import BaseEmbedding
8 |
9 | from pairag.integrations.data_analysis.text2sql.db_info_retriever import (
10 | SchemaRetriever,
11 | )
12 |
13 |
14 | if os.path.exists("./model_repository/bge-m3"):
15 | embed_model_bge = HuggingFaceEmbedding(
16 | model_name="./model_repository/bge-m3", embed_batch_size=20
17 | )
18 | else:
19 | embed_model_bge = None
20 |
21 | mock_nodes = [
22 | TextNode(
23 | text="This is mock node 0.",
24 | metadata={"table_name": "table0", "column_name": "column0"},
25 | ),
26 | TextNode(
27 | text="This is mock node 1.",
28 | metadata={"table_name": "table0", "column_name": "column1"},
29 | ),
30 | TextNode(
31 | text="This is mock node 2.",
32 | metadata={"table_name": "table0", "column_name": "column2"},
33 | ),
34 | ]
35 |
36 |
def get_nodes_with_embeddings(embed_model: BaseEmbedding, nodes: List[TextNode]):
    """Embed each node's content and give it a content-derived stable id.

    The id is the SHA-256 of metadata + text, so re-running on unchanged
    content yields the same node ids.
    """
    texts = [node.get_content(metadata_mode="embed") for node in nodes]
    vectors = embed_model.get_text_embedding_batch(texts)

    for node, vector in zip(nodes, vectors):
        node.embedding = vector
        fingerprint = node.get_metadata_str() + node.get_text()
        node.id_ = hashlib.sha256(fingerprint.encode()).hexdigest()

    return nodes
49 |
50 |
mock_nodes_with_embeddings = get_nodes_with_embeddings(embed_model_bge, mock_nodes)

# Initialize the schema retriever under test.
mock_retriever = SchemaRetriever(
    db_name="mock_db",
    embed_model=embed_model_bge,
    similarity_top_k=1,
)

# Insert/update the nodes in the index.
mock_retriever.get_index(mock_nodes_with_embeddings)


# Same nodes except node 1's text changed — exercises re-indexing, since node
# ids are derived from content in get_nodes_with_embeddings.
mock_nodes_update = [
    TextNode(
        text="This is mock node 0.",
        metadata={"table_name": "table0", "column_name": "column0"},
    ),
    TextNode(
        text="This is mock node 01test.",
        metadata={"table_name": "table0", "column_name": "column1"},
    ),
    TextNode(
        text="This is mock node 2.",
        metadata={"table_name": "table0", "column_name": "column2"},
    ),
]

mock_nodes_with_embeddings = get_nodes_with_embeddings(
    embed_model_bge, mock_nodes_update
)

# Insert/update the nodes again.
mock_retriever.get_index(mock_nodes_with_embeddings)
# new_retriever = mock_retriever._schema_index.as_retriever()

res = mock_retriever.retrieve_nodes(query="what is mock node 2?")

print("retrieve result:", res)
90 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/test/test_query_processor.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dotenv import load_dotenv
3 |
4 | from llama_index.llms.openai_like import OpenAILike
5 | from llama_index.core.schema import QueryBundle
6 | from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels
7 | from llama_index.core import Settings
8 |
9 | from pairag.integrations.data_analysis.text2sql.query_processor import KeywordExtractor
10 | from pairag.integrations.llms.pai.llm_config import (
11 | OpenAICompatibleLlmConfig,
12 | )
13 |
14 |
# Load environment variables (DASHSCOPE_API_KEY) from the local .env file.
load_dotenv()

llm_ds = DashScope(
    model_name=DashScopeGenerationModels.QWEN_MAX,
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    temperature=0.1,
    max_tokens=2048,
)
print("DashScope:", llm_ds.metadata.is_function_calling_model)

llm_config = OpenAICompatibleLlmConfig(model="qwen-max")

# OpenAI-compatible client pointed at DashScope; function calling is enabled
# so structured prediction can use tool calls.
llm_ol = OpenAILike(
    model=llm_config.model,
    api_base=llm_config.base_url,
    temperature=llm_config.temperature,
    system_prompt=llm_config.system_prompt,
    is_chat_model=True,
    api_key=llm_config.api_key or os.environ.get("DASHSCOPE_API_KEY"),
    max_tokens=llm_config.max_tokens,
    reuse_client=False,
    is_function_calling_model=True,
)

Settings.llm = llm_ol
query = "有猫的学生有多少?"
qp = KeywordExtractor()
# NOTE(review): KeywordExtractor defines extract_keywords/aextract_keywords;
# confirm a process() wrapper exists before relying on this script.
result = qp.process(QueryBundle(query))
print(result)


# # def test_query_processor():
# # query = "有猫的学生有多少?"
# # qp = KeywordExtractor()
# # keywords = qp.process(QueryBundle(query))
# # assert isinstance(keywords, list)
# # assert len(keywords) > 0
# # assert "猫" in keywords
# # assert "学生" in keywords
55 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/text2sql/evaluations/base_evaluator.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
class SQLEvaluator(ABC):
    """Interface for evaluating generated SQL."""

    @abstractmethod
    async def abatch_loader(
        self,
    ):
        """Asynchronously load the evaluation data in batches."""
        pass

    @abstractmethod
    async def abatch_query(self, nums: int):
        """Asynchronously run queries for `nums` evaluation examples."""
        pass

    @abstractmethod
    def batch_evaluate(
        self, gold_file: str, predicted_file: str, evaluation_type: str, **args
    ):
        """Compare predicted SQL against the gold file using `evaluation_type`."""
        pass
22 |
--------------------------------------------------------------------------------
/src/pairag/integrations/data_analysis/text2sql/utils/constants.py:
--------------------------------------------------------------------------------
# Structured schema description of the database (collector output).
DEFAULT_DB_DESCRIPTION_PATH = "./localdata/data_analysis/text2sql/description"
DEFAULT_DB_DESCRIPTION_NAME = "db_structured_description.json"
DEFAULT_DESCRIPTION_FOLDER_PATH = "./localdata/data_analysis/text2sql/input_description"

# Query-history collection output.
DEFAULT_DB_HISTORY_PATH = "./localdata/data_analysis/text2sql/history"
DEFAULT_DB_HISTORY_NAME = "db_query_history.json"

# Unique-value collection output.
DEFAULT_DB_VALUE_PATH = "./localdata/data_analysis/text2sql/value"
DEFAULT_DB_VALUE_NAME = "db_unique_value.json"

# Table comments supplied alongside the database.
DEFAULT_TABLE_COMMENT_PATH = "./localdata/data_analysis/text2sql"
DEFAULT_TABLE_COMMENT_NAME = "table_comment.json"

# Persisted index/LSH storage locations.
DESCRIPTION_STORAGE_PATH = (
    "./localdata/data_analysis/text2sql/storage/description_index"
)
HISTORY_STORAGE_PATH = "./localdata/data_analysis/text2sql/storage/history_index"
VALUE_STORAGE_PATH = "./localdata/data_analysis/text2sql/storage/value_index"
VALUE_LSH_PATH = "./localdata/data_analysis/text2sql/storage/value_lsh"

# Embedding dimensionality per supported model name.
EMBEDDING_DIM_DICT = {"bge-large-zh-v1.5": 1024, "bge-m3": 1024}
22 |
--------------------------------------------------------------------------------
/src/pairag/integrations/embeddings/pai/pai_embedding_config.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 | from pydantic import BaseModel, Field
3 | from enum import Enum
4 | from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE
5 | import os
6 |
7 | DEFAULT_HF_EMBED_MODEL = "bge-m3"
8 |
9 |
class SupportedEmbedType(str, Enum):
    """Embedding providers supported by the config parser."""

    dashscope = "dashscope"
    openai = "openai"
    huggingface = "huggingface"
14 |
15 |
class PaiBaseEmbeddingConfig(BaseModel):
    """Base class for embedding configs; each provider subclasses it.

    NOTE(review): the base declares `source` with the same
    Literal[huggingface] as HuggingFaceEmbeddingConfig — presumably intended
    as a default; confirm it does not interfere with discriminated-union
    parsing elsewhere.
    """

    source: Literal[SupportedEmbedType.huggingface] = SupportedEmbedType.huggingface
    model: str
    embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
    enable_sparse: bool = False

    class Config:
        # Immutable (and therefore hashable) config instances.
        frozen = True

    @classmethod
    def get_subclasses(cls):
        # Direct subclasses only; used to build the source-discriminated union.
        return tuple(cls.__subclasses__())

    @classmethod
    def get_type(cls):
        # The default of `source` identifies the concrete config type.
        return cls.model_fields["source"].default
32 |
33 |
class DashScopeEmbeddingConfig(PaiBaseEmbeddingConfig):
    """Config for the DashScope embedding service.

    Fix: `api_key` previously used `Field(default=os.getenv(...))`, which
    froze whatever value was in the environment at module-import time. A
    default_factory reads DASHSCOPE_API_KEY when the config is instantiated.
    """

    source: Literal[SupportedEmbedType.dashscope] = SupportedEmbedType.dashscope
    model: str | None = "text-embedding-v2"  # use default
    api_key: str | None = Field(
        default_factory=lambda: os.getenv("DASHSCOPE_API_KEY")
    )  # read at instantiation time
38 |
39 |
class OpenAIEmbeddingConfig(PaiBaseEmbeddingConfig):
    """Config for the OpenAI embedding API; None fields use client defaults."""

    source: Literal[SupportedEmbedType.openai] = SupportedEmbedType.openai
    model: str | None = None  # use default
    api_key: str | None = None  # use default
    api_base: str | None = None  # use default
45 |
46 |
class HuggingFaceEmbeddingConfig(PaiBaseEmbeddingConfig):
    """Config for a locally hosted HuggingFace embedding model."""

    source: Literal[SupportedEmbedType.huggingface] = SupportedEmbedType.huggingface
    model: str | None = DEFAULT_HF_EMBED_MODEL
50 |
51 |
# Maps each config's `source` default (e.g. "openai") to its config class.
# NOTE(review): the name keeps the existing "Supportted" typo because other
# modules may import it by this spelling.
SupporttedEmbeddingClsMap = {
    cls.get_type(): cls for cls in PaiBaseEmbeddingConfig.get_subclasses()
}
55 |
56 |
def parse_embed_config(config_data):
    """Instantiate the concrete embedding config class for `config_data`.

    The 'source' value is matched case-insensitively against the registered
    config classes.
    """
    if "source" not in config_data:
        raise ValueError("Embedding config must contain 'source' field")

    source_key = config_data["source"].lower()
    config_cls = SupporttedEmbeddingClsMap.get(source_key)
    if config_cls is None:
        raise ValueError(f"Unsupported embedding source: {config_data['source']}")

    return config_cls(**config_data)
66 |
67 |
if __name__ == "__main__":
    # Ad-hoc smoke test: "Openai" is lower-cased before the class lookup.
    embedding_config_data = {"source": "Openai", "model": "gpt-1", "api_key": None}

    print(parse_embed_config(embedding_config_data))
72 |
--------------------------------------------------------------------------------
/src/pairag/integrations/embeddings/readme.md:
--------------------------------------------------------------------------------
1 | # Custom embedding implementations
2 |
3 | Stay tuned.
4 |
--------------------------------------------------------------------------------
/src/pairag/integrations/llms/pai/open_ai_alike_multi_modal.py:
--------------------------------------------------------------------------------
1 | from llama_index.multi_modal_llms.openai import OpenAIMultiModal
2 | from typing import Dict, Any
3 |
4 |
class OpenAIAlikeMultiModal(OpenAIMultiModal):
    """OpenAI-compatible multi-modal LLM that only sends max_tokens when set."""

    def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        # Model identity and temperature first; caller kwargs override them.
        payload: Dict[str, Any] = {"model": self.model, "temperature": self.temperature}
        payload.update(kwargs)
        if self.max_new_tokens is not None:
            # When max_tokens is unset, omit it from the payload entirely:
            # https://platform.openai.com/docs/api-reference/chat
            # https://platform.openai.com/docs/api-reference/completions
            payload["max_tokens"] = self.max_new_tokens
        payload.update(self.additional_kwargs)
        return payload
14 |
--------------------------------------------------------------------------------
/src/pairag/integrations/llms/readme.md:
--------------------------------------------------------------------------------
1 | # Custom LLM implementations
2 |
--------------------------------------------------------------------------------
/src/pairag/integrations/query_transform/intent_models.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import List, Optional
3 | from pydantic import BaseModel
4 | from openai.types.completion_usage import CompletionUsage
5 |
6 |
class ChatToolType(str, Enum):
    """Tools the chat router can dispatch a query to."""

    SEARCH_WEB = "search_web"
    CHAT_NEWS = "chat_news"
    CHAT_KNOWLEDGEBASE = "chat_knowledgebase"
    CHAT_DB = "chat_db"
    CHAT_LLM = "chat_llm"
13 |
14 |
class ChatIntentType(str, Enum):
    """Recognized user intents."""

    SEARCH_WEB = "search_web"  # search web
    CHAT_LLM = "chat_llm"  # llm chat
    LIST_NEWS = "list_news"  # list news
    CHAT_NEWS = "chat_news"  # chat news
    CHAT_NEWS_LLM = "chat_news_llm"  # chat news only by llm
    CHAT_KNOWLEDGEBASE = "chat_knowledgebase"
    CHAT_DB = "chat_db"  # chat sql
23 |
24 |
class IntentResult(BaseModel):
    """Result of intent recognition over a user query."""

    intent: ChatIntentType = ChatIntentType.CHAT_LLM
    # Fix: was `query_str: str = None` — the default contradicted the str
    # annotation and an explicit None failed validation; Optional makes the
    # contract explicit and is backward-compatible.
    query_str: Optional[str] = None
    news_topics: Optional[List[str]] = None
    # Token accounting for the intent-detection LLM call (zeroed by default).
    token_usage: CompletionUsage = CompletionUsage(
        completion_tokens=0, prompt_tokens=0, total_tokens=0
    )
32 |
--------------------------------------------------------------------------------
/src/pairag/integrations/search/search_config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from enum import Enum
3 | from typing import Literal
4 |
DEFAULT_ALIYUN_SEARCH_ENDPOINT = "iqs.cn-zhangjiakou.aliyuncs.com"
DEFAULT_GOOGLE_SEARCH_ENDPOINT = "https://serpapi.com/search"
DEFAULT_SEARCH_COUNT = 10
# Chinese QA prompt for answering from web-search results. Placeholders:
# {current_datetime}, {context_str}, {history_str}, {query_str}.
DEFAULT_SEARCH_QA_PROMPT_TEMPLATE = """
你的目标是根据搜索结果提供准确、有用且易于理解的信息。
# 任务要求:
- 请严格根据提供的参考内容回答问题,并非所有参考内容都与用户的问题密切相关,你需要结合问题,对参考内容进行甄别、筛选。仅参考与问题相关的内容并忽略所有不相关的信息。
- 如果参考内容中没有相关信息或与问题无关,请基于你的已有知识进行回答。
- 确保答案准确、简洁,并且使用与用户提问相同的语种。
- 在回答过程中,请避免使用“从参考内容得出”、“从材料得出”、“根据参考内容”等措辞。
- 保持回答的专业性和友好性。
- 如果需要更多信息来更好地回答问题,请礼貌地询问。
- 对于复杂的问题,尽量简化解释,使信息易于理解。如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 除非用户要求,否则请保持输出语种与用户输入问题语种的一致性。
- 对于涉及不安全/不道德/敏感/色情/暴力/赌博/违法等行为的问题,请明确拒绝提供所要求的信息,并简单解释为什么这样的请求不能被满足。
- 你知道今天的日期是{current_datetime},但你不会主动在回复开头提到日期信息。

# 以下内容是基于用户发送的消息的搜索/查询结果:
{context_str}

# 以下内容是用户问答历史记录:
{history_str}

# 以下内容是用户消息:
{query_str}
"""
32 |
33 |
class SupportedSearchType(str, Enum):
    """Web-search backends supported by the search configs below."""

    bing = "bing"
    aliyun = "aliyun"
    google = "google"
38 |
39 |
class BaseSearchConfig(BaseModel):
    """Base class for search-backend configs; each backend subclasses it."""

    source: SupportedSearchType
    search_count: int = DEFAULT_SEARCH_COUNT
    search_qa_prompt_template: str = DEFAULT_SEARCH_QA_PROMPT_TEMPLATE

    class Config:
        # Immutable (and therefore hashable) config instances.
        frozen = True

    @classmethod
    def get_subclasses(cls):
        # Direct subclasses only; used to build the source-discriminated union.
        return tuple(cls.__subclasses__())

    @classmethod
    def get_type(cls):
        # The default of `source` identifies the concrete config type.
        return cls.model_fields["source"].default
55 |
56 |
class BingSearchConfig(BaseSearchConfig):
    """Bing web-search settings."""

    source: Literal[SupportedSearchType.bing] = SupportedSearchType.bing
    search_api_key: str | None = None
    search_lang: str = "zh-CN"


class AliyunSearchConfig(BaseSearchConfig):
    """Aliyun IQS web-search settings."""

    source: Literal[SupportedSearchType.aliyun] = SupportedSearchType.aliyun
    endpoint: str = DEFAULT_ALIYUN_SEARCH_ENDPOINT
    access_key_id: str | None = None
    access_key_secret: str | None = None


class GoogleSearchConfig(BaseSearchConfig):
    """Google (SerpAPI) web-search settings."""

    source: Literal[SupportedSearchType.google] = SupportedSearchType.google
    serpapi_key: str | None = None
    search_lang: str = "zh-CN"
74 |
--------------------------------------------------------------------------------
/src/pairag/integrations/trace/trace_config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
class TraceConfig(BaseModel):
    """Tracing (observability) settings; enabled only when fully configured."""

    service_name: str | None = None
    token: str | None = None
    endpoint: str | None = None

    def is_enabled(self) -> bool:
        """Return True when all three settings are present and non-empty.

        Fix: the previous `a and b and c` returned a str or None, violating
        the declared bool return type; bool() makes it honest.
        """
        return bool(self.service_name and self.token and self.endpoint)
11 |
--------------------------------------------------------------------------------
/src/pairag/integrations/vector_stores/elasticsearch/elasticsearch_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | from elasticsearch import AsyncElasticsearch, Elasticsearch
4 |
5 |
def get_user_agent() -> str:
    """Build the user-agent string sent with Elasticsearch requests."""
    # Imported lazily so this module does not hard-depend on llama_index
    # being importable at load time.
    import llama_index.core

    core_version = getattr(llama_index.core, "__version__", "")
    return f"llama_index-py-vs/{core_version}"
12 |
13 |
def get_elasticsearch_client(
    url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> AsyncElasticsearch:
    """Create an AsyncElasticsearch client after probing connectivity.

    Exactly one of `url` / `cloud_id` must be given. Auth is either an API
    key or basic auth (username + password).

    Raises:
        ValueError: if both or neither of url/cloud_id are provided.
    """
    if url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )

    connection_params: Dict[str, Any] = {}

    if url:
        connection_params["hosts"] = [url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either elasticsearch_url or cloud_id.")

    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    async_es_client = AsyncElasticsearch(
        **connection_params,
        headers={"user-agent": get_user_agent()},
        request_timeout=60,
        retry_on_timeout=True,
        max_retries=2,
    )

    # Probe the cluster with a short-lived sync client so we don't have to
    # 'await' just to get info. Fix: the sync client is now closed instead of
    # leaking its connection pool.
    sync_es_client = Elasticsearch(
        **connection_params, headers={"user-agent": get_user_agent()}
    )
    try:
        sync_es_client.info()
    finally:
        sync_es_client.close()

    return async_es_client
54 |
--------------------------------------------------------------------------------
/src/pairag/knowledgebase/index/pai/utils/sparse_embed_function.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import os
3 | from typing import List, Optional, Dict
4 | from pairag.utils.constants import DEFAULT_MODEL_DIR
5 |
6 | from loguru import logger
7 |
8 | MODEL_NAME = "bge-m3"
9 |
10 |
class BaseSparseEmbeddingFunction(ABC):
    """Interface for sparse embedders returning {token_id: weight} dicts."""

    @abstractmethod
    def encode_queries(self, queries: List[str]) -> List[Dict[int, float]]:
        """Encode search queries into sparse lexical-weight vectors."""
        pass

    @abstractmethod
    def encode_documents(self, documents: List[str]) -> List[Dict[int, float]]:
        """Encode documents into sparse lexical-weight vectors."""
        pass
19 |
20 |
class BGEM3SparseEmbeddingFunction(BaseSparseEmbeddingFunction):
    """Sparse (lexical-weight) embeddings backed by the BGE-M3 model."""

    def __init__(self, model_name_or_path: Optional[str] = None) -> None:
        """Load the BGE-M3 model from `model_name_or_path` or the default dir.

        Raises:
            Exception: re-raised after logging when FlagEmbedding is missing
                or the model fails to load.
        """
        try:
            from FlagEmbedding import BGEM3FlagModel

            model_dir = os.getenv("PAIRAG_MODEL_DIR", DEFAULT_MODEL_DIR)

            self.model = BGEM3FlagModel(
                model_name_or_path=os.path.join(
                    model_name_or_path or model_dir, MODEL_NAME
                ),
                use_fp16=False,
            )
        except Exception:
            # Fix: trailing commas previously turned this message into a
            # tuple ("…", "error_info"), so the log line was garbage.
            error_info = (
                "Cannot import BGEM3FlagModel from FlagEmbedding. It seems it is not installed. "
                "Please install it using:\n"
                "pip install FlagEmbedding"
            )
            logger.error(error_info)
            raise

    def encode_queries(self, queries: List[str]):
        """Return one {token_id: weight} dict per query."""
        outputs = self.model.encode(
            queries, return_dense=False, return_sparse=True, return_colbert_vecs=False
        )["lexical_weights"]
        return [self._to_standard_dict(output) for output in outputs]

    def encode_documents(self, documents: List[str]):
        """Return one {token_id: weight} dict per document."""
        outputs = self.model.encode(
            documents, return_dense=False, return_sparse=True, return_colbert_vecs=False
        )["lexical_weights"]
        return [self._to_standard_dict(output) for output in outputs]

    def _to_standard_dict(self, raw_output):
        # Normalize keys to plain ints (raw keys may arrive as strings).
        return {int(key): raw_output[key] for key in raw_output}
62 |
63 |
def get_default_sparse_embedding_function() -> BGEM3SparseEmbeddingFunction:
    """Return a BGE-M3 sparse embedder loaded from the default model dir."""
    return BGEM3SparseEmbeddingFunction()
66 |
--------------------------------------------------------------------------------
/src/pairag/knowledgebase/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field, model_validator
2 | from typing import Annotated, Dict, Union
3 |
4 | from pairag.integrations.embeddings.pai.pai_embedding_config import (
5 | PaiBaseEmbeddingConfig,
6 | )
7 | from pairag.knowledgebase.index.pai.vector_store_config import BaseVectorStoreConfig
8 | from pairag.file.nodeparsers.pai.pai_node_parser import NodeParserConfig
9 | from pairag.integrations.synthesizer.prompt_templates import (
10 | DEFAULT_CUSTOM_PROMPT_TEMPLATE,
11 | DEFAULT_SYSTEM_ROLE_TEMPLATE,
12 | )
13 | from pairag.utils.constants import DEFAULT_KNOWLEDGEBASE_NAME
14 |
15 | from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
16 | from pairag.integrations.postprocessor.pai.pai_postprocessor import (
17 | DEFAULT_RERANK_MODEL,
18 | DEFAULT_RERANK_SIMILARITY_THRESHOLD,
19 | DEFAULT_RERANK_TOP_N,
20 | DEFAULT_SIMILARITY_THRESHOLD,
21 | )
22 |
23 |
class KnowledgeBase(BaseModel):
    """A knowledgebase: vector store + node parser + embedding + retrieval settings."""

    name: str = Field(
        default=DEFAULT_KNOWLEDGEBASE_NAME,
        description="Knowledgebase name.",
        # Fix: the quantifier was written "{3, 20}" — with the space, re
        # treats it as literal characters, so the 3-20 length rule was never
        # applied as a quantifier.
        pattern=r"^[0-9a-zA-Z_-]{3,20}$",
    )

    # Concrete config type is selected by its discriminator field.
    vector_store_config: Annotated[
        Union[BaseVectorStoreConfig.get_subclasses()], Field(discriminator="type")
    ]
    node_parser_config: NodeParserConfig = Field(default_factory=NodeParserConfig)
    embedding_config: Annotated[
        Union[PaiBaseEmbeddingConfig.get_subclasses()], Field(discriminator="source")
    ]
    retrieval_settings: Dict = Field(default_factory=dict)
    qa_prompt_templates: Dict = {
        "system_prompt_template": DEFAULT_SYSTEM_ROLE_TEMPLATE,
        "task_prompt_template": DEFAULT_CUSTOM_PROMPT_TEMPLATE,
    }

    @model_validator(mode="before")
    def preprocess(cls, values: Dict) -> Dict:
        # Backward compatibility: older payloads used "index_name".
        if "index_name" in values:
            values["name"] = values["index_name"]
        return values

    def model_post_init(self, context):
        """Fill retrieval_settings with defaults for any missing keys."""
        default_retrieval_settings = {
            "retrieval_mode": "default",
            "similarity_top_k": DEFAULT_SIMILARITY_TOP_K,
            "reranker_type": "no-reranker",
            "reranker_model": DEFAULT_RERANK_MODEL,
            "similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD,
            "reranker_similarity_threshold": DEFAULT_RERANK_SIMILARITY_THRESHOLD,
            "reranker_similarity_top_k": DEFAULT_RERANK_TOP_N,
        }
        # User-supplied values win over the defaults.
        default_retrieval_settings.update(self.retrieval_settings)
        self.retrieval_settings = default_retrieval_settings
63 |
--------------------------------------------------------------------------------
/src/pairag/knowledgebase/utils/knowledgebase_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from loguru import logger
4 | from pairag.utils.constants import DEFAULT_KNOWLEDGEBASE_PATH
5 |
6 |
def delete_file(file_path):
    """Delete a single file, logging the outcome instead of raising."""
    try:
        os.remove(file_path)
    except FileNotFoundError:
        logger.error(f"File {file_path} does not exist.")
    except Exception as e:
        logger.error(f"Error deleting file {file_path}: {e}")
    else:
        logger.info(f"File {file_path} successfully deleted.")
15 |
16 |
def delete_dir(folder_path):
    """Recursively delete a folder, logging the outcome instead of raising."""
    try:
        shutil.rmtree(folder_path)
    except FileNotFoundError:
        logger.error(f"Folder {folder_path} does not exist.")
    except Exception as e:
        logger.error(f"Error deleting folder {folder_path}: {e}")
    else:
        logger.info(f"Folder {folder_path} and its contents successfully deleted.")
25 |
26 |
def delete_knowledgebase_dir(index_name):
    """Delete the on-disk folder of the knowledgebase named `index_name`."""
    destination_folder = os.path.join(DEFAULT_KNOWLEDGEBASE_PATH, index_name)
    delete_dir(destination_folder)


def delete_default_knowledgebase_dir():
    """Delete the on-disk folder of the "default" knowledgebase."""
    destination_folder = os.path.join(DEFAULT_KNOWLEDGEBASE_PATH, "default")
    delete_dir(destination_folder)
35 |
36 |
def write_markdown_to_parse_dir(md_content, file_name, parse_dir):
    """Write `md_content` to `<parse_dir>/<file_name>.md` (UTF-8).

    NOTE(review): reports success/failure via print() while the rest of this
    module uses loguru — consider unifying the logging style.
    """
    destination_md_file = os.path.join(parse_dir, f"{file_name}.md")
    try:
        with open(destination_md_file, "w", encoding="utf-8") as md_file:
            md_file.write(md_content)
        print(f"成功写入{destination_md_file}")
    except IOError as e:
        print(f"写入文件时出错: {e}")
45 |
46 |
def copy_original_files_to_parse_dir(file_path, parse_dir):
    """Copy the source file into `parse_dir`, reporting the result via print.

    NOTE(review): uses print() instead of loguru like the delete helpers —
    consider unifying the logging style.
    """
    try:
        if os.path.isfile(file_path):
            shutil.copy(file_path, parse_dir)
            print(f"已复制: {file_path} 到 {parse_dir}")
        else:
            print(f"源文件不存在: {file_path}")
    except Exception as e:
        print(f"复制文件时出错: {file_path} -> {e}")
56 |
--------------------------------------------------------------------------------
/src/pairag/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/src/pairag/utils/__init__.py
--------------------------------------------------------------------------------
/src/pairag/utils/constants.py:
--------------------------------------------------------------------------------
1 | """Set of constants of modules."""
2 |
3 | import os
4 |
5 |
def try_get_int_env(key, default_value=None):
    """Read an integer from environment variable `key`.

    Returns `default_value` when the variable is unset, and None when it is
    set but not a valid integer (the default is intentionally ignored then).
    """
    raw = os.getenv(key)
    if raw is None:
        return default_value
    try:
        return int(raw)
    except ValueError:
        return None
21 |
22 |
# Prefer the EAS-mounted model repository when it exists; otherwise use the
# local ./model_repository checkout.
EAS_DEFAULT_MODEL_DIR = "/huggingface/pai_rag_model_repository_01"
if not os.path.exists(EAS_DEFAULT_MODEL_DIR):
    DEFAULT_MODEL_DIR = "./model_repository"
else:
    DEFAULT_MODEL_DIR = EAS_DEFAULT_MODEL_DIR

# Remote manifest describing downloadable models.
OSS_URL = "https://pai-rag-bj.oss-cn-beijing.aliyuncs.com/model_repository/model_config_1.1.0.json"

DEFAULT_DATAFILE_DIR = "./data"

DEFAULT_DASHSCOPE_EMBEDDING_MODEL = "text-embedding-v2"


# Ingestion task bookkeeping.
DEFAULT_TASK_FILE = "localdata/ingestion__task__summary.json"

# Index naming (the *_OLD names cover data written by earlier versions).
DEFAULT_INDEX_FILE = "localdata/default__rag__index.json"
DEFAULT_INDEX_NAME = "default"
DEFAULT_INDEX_NAME_OLD = "default_index"

# Knowledgebase storage layout and limits (limits overridable via env vars).
DEFAULT_KNOWLEDGEBASE_PATH = "localdata/knowledgebase"
DEFAULT_KNOWLEDGEBASE_NAME = "default"
DEFAULT_KNOWLEDGEBASE_NAME_OLD = "default_index"
DEFAULT_KNOWLEDGEBASE_FILE = "localdata/default__rag__index.json"
DEFAULT_DOC_STORE_NAME = "default__knowledge__docs.json"
DEFAULT_MAX_KNOWLEDGEBASE_COUNT = try_get_int_env(
    "DEFAULT_MAX_KNOWLEDGEBASE_COUNT", 3000
)
DEFAULT_MAX_FILE_TASK_COUNT = try_get_int_env("DEFAULT_MAX_FILE_TASK_COUNT", 10000)
51 |
--------------------------------------------------------------------------------
/src/pairag/utils/cuda_utils.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | import torch
3 | import os
4 |
5 |
# Opt-in flag read once at import time; accepted truthy spellings: "true"/"1".
USE_CUDA = os.environ.get("USE_CUDA", "false")


def should_use_cuda():
    """Return True when CUDA is available AND USE_CUDA opts in ('true'/'1')."""
    if not torch.cuda.is_available():
        return False
    return USE_CUDA.lower() in ("true", "1")
17 |
18 |
def infer_cuda_device() -> str:
    """Return "cuda" when CUDA use is enabled, else "cpu" (with a log hint)."""
    if not should_use_cuda():
        logger.info(
            "Will not use CUDA device acceleration. If you want to use cuda, please set the environment variable USE_CUDA=1."
        )
        return "cpu"
    logger.info("Using cuda device.")
    return "cuda"
28 |
--------------------------------------------------------------------------------
/src/pairag/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import logging
3 | from loguru import logger
4 | from tenacity import (
5 | before_sleep_log,
6 | retry,
7 | stop_after_attempt,
8 | wait_fixed,
9 | retry_if_exception_type,
10 | )
11 | import os
12 |
13 |
def get_modified_time(file_path):
    """Return the file's last-modification timestamp (st_mtime)."""
    return os.stat(file_path).st_mtime
17 |
18 |
def generate_text_md5(text):
    """Return the hex MD5 digest of `text` encoded as UTF-8."""
    return hashlib.md5(text.encode("utf-8")).hexdigest()
24 |
25 |
# Retry transient filesystem errors while reading the file.
@retry(
    wait=wait_fixed(1),
    stop=stop_after_attempt(10),
    retry=retry_if_exception_type(OSError),
    before_sleep=before_sleep_log(logger, logging.INFO),
)
def generate_file_md5(file_path):
    """Return the hex MD5 digest of the file's content, read in 8 KB chunks."""
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        while block := fh.read(8192):
            digest.update(block)
    return digest.hexdigest()
41 |
42 |
def generate_md5(file_path):
    """Return a pair of hexadecimal MD5 digests for *file_path*:
    (digest of the path string itself, digest of the file's content).

    Raises whatever ``generate_file_md5`` raises (after retries), logging first.
    """
    try:
        return (
            generate_text_md5(file_path),
            generate_file_md5(file_path),
        )
    except Exception:
        logger.error(f"Error generating md5 for file '{file_path}'.")
        # Bare raise preserves the original exception and traceback;
        # the original `raise ex` added a redundant frame.
        raise
53 |
--------------------------------------------------------------------------------
/src/pairag/utils/format_logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from loguru import logger
4 |
5 |
class InterceptHandler(logging.Handler):
    """Bridge stdlib ``logging`` records into loguru.

    Installed via ``logging.basicConfig(handlers=[InterceptHandler()])`` so log
    records emitted through the standard library are forwarded to loguru.
    """

    def emit(self, record: logging.LogRecord) -> None:
        level: str | int
        # Map the stdlib level name to the corresponding loguru level.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            # No matching loguru level: fall back to the raw numeric level.
            level = record.levelno

        # Walk up the stack past logging-module frames so loguru reports the
        # real caller's file/function/line instead of logging internals.
        # Guard against f_back returning None (the original would crash with
        # AttributeError if the walk ran off the top of the stack).
        frame, depth = logging.currentframe(), 2
        while frame is not None and frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1
        # Log via loguru, keeping the caller depth and any exception info.
        logger.opt(
            depth=depth,
            exception=record.exc_info,
        ).log(level, record.getMessage())
26 |
27 |
def format_logging():
    """Route stdlib logging through loguru and install a uniform stderr format."""
    # Replace all stdlib handlers with the loguru bridge.
    logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO, force=True)
    # Drop loguru's default sink (id 0) and add one with the custom format.
    logger.remove(0)
    fmt = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {process} | {level: <8} | {name}:{function}:{line} - {message}"
    logger.add(sys.stderr, format=fmt)
35 |
--------------------------------------------------------------------------------
/src/pairag/utils/json_parser.py:
--------------------------------------------------------------------------------
1 | import json
2 | from loguru import logger
3 |
4 |
def parse_json_from_code_block_str(input_str):
    """Extract and parse the JSON object embedded in *input_str*.

    Returns the parsed object, or ``{"queries": []}`` when no valid JSON
    object can be found.
    """
    start = input_str.find("{")
    # Use the LAST closing brace so nested objects survive intact; the
    # original find("}") truncated e.g. '{"a": {"b": 1}}' at the inner brace.
    end = input_str.rfind("}")
    if start != -1 and end > start:
        content = input_str[start : end + 1]
        try:
            data = json.loads(content)
            logger.debug(f"解析后的 JSON 对象:{data}")
            return data
        except json.JSONDecodeError as e:
            # The original passed *e* as a stray positional arg, which loguru
            # dropped silently; use loguru's {} formatting to include it.
            logger.debug("JSON 解码错误:{}", e)
            return {"queries": []}
    else:
        logger.debug("未找到有效的JSON对象。")
        return {"queries": []}
20 |
--------------------------------------------------------------------------------
/src/pairag/utils/score_utils.py:
--------------------------------------------------------------------------------
def normalize_cosine_similarity_score(sim_score):
    """Map a cosine similarity from [-1, 1] onto [0, 1], rounded to 6 places."""
    shifted = sim_score + 1
    return round(shifted / 2, 6)
3 |
--------------------------------------------------------------------------------
/src/pairag/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 |
def get_current_time_str() -> str:
    """Current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    now = datetime.now()
    return now.strftime("%Y-%m-%d %H:%M:%S")
6 |
7 |
def get_prompt_current_time_str() -> str:
    """Current local time in a Chinese, prompt-friendly format."""
    now = datetime.now()
    return now.strftime("%Y年%m月%d日 %H:%M:%S")
10 |
--------------------------------------------------------------------------------
/src/pairag/web/element_manager.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Dict, Generator, List, Tuple
2 |
3 | if TYPE_CHECKING:
4 | from gradio.components import Component
5 |
6 |
class ElementManager:
    """Bidirectional registry that maps element ids to gradio components."""

    def __init__(self) -> None:
        self._id_to_elem: Dict[str, "Component"] = {}
        self._elem_to_id: Dict["Component", str] = {}

    def add_elems(self, elem_dict: Dict[str, "Component"]) -> None:
        r"""
        Register components, keeping both directions of the mapping in sync.
        """
        self._id_to_elem.update(elem_dict)
        self._elem_to_id.update(
            {component: elem_id for elem_id, component in elem_dict.items()}
        )

    def get_elem_list(self) -> List["Component"]:
        r"""
        All registered components, in registration order.
        """
        return [*self._id_to_elem.values()]

    def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
        r"""
        Yield (short name, component) pairs; the short name is the id's last
        dot-separated segment.
        """
        for elem_id, component in self._id_to_elem.items():
            yield elem_id.rsplit(".", 1)[-1], component

    def get_elem_by_id(self, elem_id: str) -> "Component":
        r"""
        Look up a component by its full id.

        Example: top.lang, train.dataset
        """
        return self._id_to_elem[elem_id]

    def get_id_by_elem(self, elem: "Component") -> str:
        r"""
        Look up the full id of a registered component.
        """
        return self._elem_to_id[elem]


# Module-level singleton shared by the web UI.
elem_manager = ElementManager()
49 |
--------------------------------------------------------------------------------
/src/pairag/web/filebrowser/constants.py:
--------------------------------------------------------------------------------
1 | from pairag.utils.constants import try_get_int_env
2 |
# Port the embedded filebrowser service listens on; overridable via env var.
DEFAULT_FILE_BROWSER_PORT = try_get_int_env("DEFAULT_FILE_BROWSER_PORT", 8012)

# URL prefix under which filebrowser resource requests are proxied.
FILEBROWSER_PREFIX = "/filebrowser/api/resources"
FILEBROWSER_PREFIX_LEN = len(FILEBROWSER_PREFIX)
7 |
--------------------------------------------------------------------------------
/src/pairag/web/filebrowser/request_utils.py:
--------------------------------------------------------------------------------
1 | from fastapi import Request, Response
2 | from pairag.web.filebrowser.constants import (
3 | DEFAULT_FILE_BROWSER_PORT,
4 | )
5 | import aiohttp
6 | from loguru import logger
7 |
8 |
def clean_headers(headers: dict, keys):
    """Remove *keys* (both as given and lower-cased) from *headers*, in place.

    Returns the same dict for chaining convenience.
    """
    for key in keys:
        headers.pop(key, None)
        headers.pop(key.lower(), None)
    return headers
14 |
15 |
async def sender_data(req: Request):
    """Async generator that forwards the raw request body chunk by chunk."""
    async for body_chunk in req.stream():
        yield body_chunk
19 |
20 |
async def postprocess_middleware_to_filebrowser(session, request, url):
    """Proxy *request* to the filebrowser backend at *url* and relay the reply.

    The upstream response is fully buffered before being returned.
    """
    async with session.request(
        request.method,
        str(url),
        # Transfer-Encoding is hop-by-hop and must not be forwarded.
        headers=clean_headers(dict(request.headers), ["Transfer-Encoding"]),
        # NOTE(review): this stringifies the path-params dict and sends it as
        # query params — looks like request.query_params may have been
        # intended; confirm against callers before changing.
        params=str(request.path_params),
        data=sender_data(request),
        allow_redirects=False,
    ) as resp:
        content = await resp.content.read()
        return Response(
            content=content,
            # Body is already decoded and buffered, so the upstream
            # Content-Encoding/Content-Length would be wrong if kept.
            headers=clean_headers(
                dict(resp.headers), ["Content-Encoding", "Content-Length"]
            ),
            status_code=resp.status,
        )
38 |
39 |
async def postprocess_middleware(request, call_next):
    """Middleware: proxy /filebrowser traffic locally, pass everything else on."""
    logger.debug(f"request_path: {request.url.path} , method: {request.method}")
    if "/filebrowser" not in request.url.path:
        return await call_next(request)

    # Rewrite the URL to hit the local filebrowser service.
    target = request.url.replace(
        scheme="http", hostname="localhost", port=DEFAULT_FILE_BROWSER_PORT
    )
    # Generous timeouts: large file transfers can take a long time.
    timeout = aiohttp.ClientTimeout(
        total=500 * 60,
        connect=500 * 60,
    )
    async with aiohttp.ClientSession(timeout=timeout) as session:
        return await postprocess_middleware_to_filebrowser(session, request, target)
56 |
--------------------------------------------------------------------------------
/src/pairag/web/tabs/history_tab.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import pandas as pd
3 | from pairag.web.rag_local_client import rag_client
4 |
5 |
def refresh_upload_history(knowledgebase_name):
    """Fetch the upload-job history for a knowledge base and build UI updates.

    Returns [summary_update, dataframe_update] matching the gradio outputs.
    """
    # Coalesce to a list: the original crashed with TypeError on len(None)
    # when the client returned no history object at all.
    upload_jobs = rag_client.get_upload_history(knowledgebase_name) or []
    display_columns = ["文件名", "操作", "上传状态", "更新时间", "错误原因"]
    if upload_jobs:
        history_data = pd.DataFrame(upload_jobs)
        history_data["文件名"] = history_data["file_name"]
        history_data["操作"] = history_data["operation"]
        history_data["上传状态"] = history_data["status"]
        history_data["更新时间"] = history_data["last_modified_time"]
        history_data["错误原因"] = history_data["message"]

        # Newest entries first.
        history_data = history_data[display_columns].sort_values(
            by="更新时间", ascending=False
        )
    else:
        history_data = pd.DataFrame(columns=display_columns)

    summary = f"累计上传{len(upload_jobs)}个文件。"
    if len(upload_jobs) > 0:
        finished = [job for job in upload_jobs if job["status"] == "done"]
        failed = [job for job in upload_jobs if job["status"] == "failed"]
        summary += f"成功上传{len(finished)}个文件,失败{len(failed)}个文件。"

    return [
        gr.update(value=summary),
        gr.update(value=history_data),
    ]
31 |
32 |
def create_upload_history():
    """Build the upload-history tab UI and return its elements keyed by elem_id."""
    with gr.Row():
        history_index = gr.Dropdown(
            choices=[],
            value="",
            label="\N{bookmark} 知识库名称",
            elem_id="history_index",
            allow_custom_value=True,
        )

        upload_summary = gr.HTML(value="累计上传0个文件", elem_id="upload_summary")
        refresh_button = gr.Button(
            value="刷新",
            elem_id="refresh_button",
            variant="primary",
        )
    upload_history = gr.DataFrame(
        label="上传历史",
        visible=True,
        elem_id="upload_history",
        # Kept in sync with the columns produced by refresh_upload_history;
        # the original list omitted the "操作" column.
        headers=["文件名", "操作", "上传状态", "更新时间", "错误原因"],
    )

    # Refresh on knowledge-base change and on explicit button click.
    history_index.change(
        fn=refresh_upload_history,
        inputs=[history_index],
        outputs=[upload_summary, upload_history],
    )

    refresh_button.click(
        fn=refresh_upload_history,
        inputs=[history_index],
        outputs=[upload_summary, upload_history],
    )

    return {
        history_index.elem_id: history_index,
        upload_summary.elem_id: upload_summary,
        upload_history.elem_id: upload_history,
    }
73 |
--------------------------------------------------------------------------------
/src/pairag/web/tabs/model/index_info.py:
--------------------------------------------------------------------------------
1 | from pairag.web.rag_local_client import rag_client
2 |
3 |
def get_index_map():
    """Return the mapping of available indexes reported by the RAG service."""
    return rag_client.list_indexes()
7 |
--------------------------------------------------------------------------------
/src/pairag/web/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Any, Dict
2 |
3 |
def components_to_dict(components: List[Any]) -> Dict[str, Any]:
    """Index gradio components by their ``elem_id`` attribute."""
    return {component.elem_id: component for component in components}
6 |
7 |
def check_variables_in_string(text, variables):
    """Raise ValueError unless every ``{var}`` placeholder appears in *text*."""
    missing = [name for name in variables if f"{{{name}}}" not in text]
    if not missing:
        return
    raise ValueError(f"以下变量名缺失: {', '.join(missing)}")
12 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/tests/__init__.py
--------------------------------------------------------------------------------
/tests/app/test_openai_embedding.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from httpx import ASGITransport, AsyncClient
3 | from pairag.app.app import app
4 |
5 |
@pytest.mark.asyncio(scope="session")
async def test_embedding():
    """End-to-end check of the OpenAI-compatible /v1 embeddings endpoint."""
    import openai

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http_client:
        client = openai.AsyncClient(
            base_url="http://test/v1", api_key="123", http_client=http_client
        )
        # Single non-empty input.
        embedding_result = await client.embeddings.create(
            input="hello world",
            model="bge-m3",
        )

        assert len(embedding_result.data[0].embedding) == 1024
        assert embedding_result.data[0].index == 0

        # An empty string must still produce a full-size vector.
        embedding_result = await client.embeddings.create(
            input="",
            model="bge-m3",
        )

        assert len(embedding_result.data[0].embedding) == 1024
        assert embedding_result.data[0].index == 0

        # Batched input keeps per-item order via the index field.
        embedding_result = await client.embeddings.create(
            input=["", "hi", "你在干什么"],
            model="bge-m3",
        )

        for i in range(3):
            assert len(embedding_result.data[i].embedding) == 1024
            assert embedding_result.data[i].index == i

        # input=None must fail. The original raised "should not reach here."
        # INSIDE its own try block and then swallowed it in `except
        # Exception`, so the check could never actually fail the test.
        with pytest.raises(Exception):
            await client.embeddings.create(
                input=None,
                model="bge-m3",
            )
52 |
--------------------------------------------------------------------------------
/tests/integrations/llm/test_function_calling_llm.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.tools import FunctionTool
2 | from pairag.integrations.llms.pai.pai_llm import PaiLlm
3 | from pairag.integrations.llms.pai.llm_config import OpenAICompatibleLlmConfig
4 | import os
5 |
6 |
# Function-calling LLM under test; requires DASHSCOPE_API_KEY in the environment.
fc_llm_config = OpenAICompatibleLlmConfig(
    model="qwen-max", api_key=os.environ.get("DASHSCOPE_API_KEY")
)

fc_llm = PaiLlm(fc_llm_config)
12 |
13 |
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the resulting integer."""
    # Docstring typo fixed ("Multiple" -> "Multiply"); it doubles as the
    # tool description handed to the LLM via FunctionTool.from_defaults.
    return a * b
17 |
18 |
19 | multiply_tool = FunctionTool.from_defaults(fn=multiply)
20 |
21 |
def add(a: int, b: int) -> int:
    """Add two integers and return the resulting integer."""
    return a + b
25 |
26 |
add_tool = FunctionTool.from_defaults(fn=add)

# Tool set handed to the LLM in the function-calling test below.
tools = [multiply_tool, add_tool]
30 |
31 |
def test_fc_llm_chat_with_tools():
    """The model should decompose (121 + 2) * 5 into add and multiply tool calls."""
    response = fc_llm.chat_with_tools(tools=tools, user_msg="What is (121 + 2) * 5?")
    tool_calls = fc_llm.get_tool_calls_from_response(
        response, error_on_no_tool_call=False
    )
    assert len(tool_calls) > 0
    # Iterate directly: the original used enumerate() and discarded the index.
    for tool_call in tool_calls:
        if tool_call.tool_name == "add":
            assert tool_call.tool_kwargs["a"] == 121
            assert tool_call.tool_kwargs["b"] == 2
        if tool_call.tool_name == "multiply":
            assert tool_call.tool_kwargs["a"] == 123
            assert tool_call.tool_kwargs["b"] == 5
45 |
--------------------------------------------------------------------------------
/tests/integrations/test_nl2pandas_retriever.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import pandas as pd
4 |
5 | from llama_index.llms.dashscope import DashScope
6 | from llama_index.embeddings.dashscope import DashScopeEmbedding
7 | from llama_index.core import Settings
8 |
9 | from pairag.integrations.data_analysis.nl2pandas_retriever import PandasQueryRetriever
10 |
11 |
# Configure DashScope LLM + embeddings globally for the retriever under test.
dashscope_key = os.environ.get("DASHSCOPE_API_KEY")
llm = DashScope(model_name="qwen-max", temperature=0.1, api_key=dashscope_key)
embed_model = DashScopeEmbedding(embed_batch_size=10, api_key=dashscope_key)
Settings.llm = llm
Settings.embed_model = embed_model
17 |
18 |
@pytest.mark.skipif(
    os.getenv("DASHSCOPE_API_KEY") is None, reason="no llm api key provided"
)
def test_pandas_query_retriever():
    """A natural-language query over the Titanic CSV should yield the
    expected generated pandas instruction and a negative correlation."""
    file_path = "./tests/testdata/csv_data/titanic_train.csv"
    df = pd.read_csv(file_path)
    data_analysis_retriever = PandasQueryRetriever(df)
    query = "What is the correlation between survival and age?"

    retrieved_res = data_analysis_retriever.retrieve(query)

    # Brittle by design: pins the exact pandas code the LLM is expected to emit.
    assert (
        retrieved_res[0].metadata["query_code_instruction"]
        == "df['survived'].corr(df['age'])"
    )

    # NOTE(review): eval() on LLM-derived output is unsafe in general; here it
    # parses what is presumably a numeric string — consider float() instead.
    assert eval(retrieved_res[0].metadata["query_output"]) < 0
36 |
--------------------------------------------------------------------------------
/tests/testdata/db_data/pets.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/tests/testdata/db_data/pets.db
--------------------------------------------------------------------------------
/tests/testdata/db_data/pets.sqlite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/tests/testdata/db_data/pets.sqlite
--------------------------------------------------------------------------------