├── .github └── workflows │ ├── ci.yml │ ├── docker.yml │ ├── hotfix.yml │ ├── main.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── Dockerfile_ingestion ├── Makefile ├── README.md ├── README_zh.md ├── docker ├── .env.example └── compose.yaml ├── docs ├── agentic_rag.md ├── api.md ├── api_v0.2.0.md ├── api_v0.3.0.md ├── api_zh.md ├── config_guide_cn.md ├── config_guide_en.md ├── data_analysis_doc.md ├── data_analysis_doc_250303.md ├── data_analysis_doc_zh.md ├── develop │ ├── local_develop.md │ └── local_develop_zh.md ├── docker_build.md ├── eas_deploy.md ├── eas_deploy_deepseek.md ├── figures │ ├── agent │ │ ├── agenda_query.png │ │ ├── agent_tab.png │ │ ├── date_query.png │ │ ├── nl2sql_query.png │ │ ├── rag_knowledge_query.png │ │ ├── weather_query.png │ │ └── web_search.png │ ├── data_analysis │ │ ├── DBChat_overview.png │ │ ├── da_db_chat.png │ │ ├── da_db_config.png │ │ ├── da_db_enhance.png │ │ ├── da_db_load.png │ │ ├── da_db_prompt.png │ │ ├── da_db_prompt_reset.png │ │ ├── da_llm_config.png │ │ ├── da_overview.png │ │ ├── da_sheet_chat.png │ │ ├── da_sheet_upload.png │ │ ├── data_analysis_overview.png │ │ ├── datafile_chat.png │ │ ├── datafile_config.png │ │ ├── db_chat.png │ │ ├── db_chat_with_memo.png │ │ ├── db_config.png │ │ ├── db_config_update.png │ │ ├── db_enhanced_features.png │ │ ├── db_info_load.png │ │ ├── db_query_desc.png │ │ ├── db_query_no_desc.png │ │ ├── llm_config.png │ │ ├── llm_selection.png │ │ ├── prompt_config.png │ │ ├── prompt_reset.png │ │ ├── sheet_data_preview.png │ │ ├── sheet_upload.png │ │ └── table_example.png │ ├── deepseek │ │ ├── deepseek_eas_api.png │ │ ├── llm_chat.png │ │ ├── llm_config.png │ │ └── rag_chat.png │ ├── deploy │ │ └── eas │ │ │ ├── deploy_json.png │ │ │ ├── deploy_portal.png │ │ │ ├── deploy_resources.jpg │ │ │ ├── deploy_success.jpg │ │ │ ├── deploy_vpc.jpg │ │ │ ├── edited_json.jpg │ │ │ ├── enable_web.jpg │ │ │ ├── trace_detail.jpg │ │ │ ├── trace_json.jpg │ │ │ ├── trace_key.jpg │ │ │ ├── trace_percent.jpg │ │ │ └── view_web.jpg │ ├── elastic │ │ ├── aliyun_es_ik_hot_update.png │ │ ├── aliyun_es_instance.png │ │ ├── aliyun_es_instance_autoindex.png │ │ ├── aliyun_es_instance_info.png │ │ ├── aliyun_es_kibana.png │ │ ├── aliyun_es_kibana_entry.png │ │ ├── aliyun_es_kibana_index_detail.png │ │ ├── aliyun_es_kibana_index_management.png │ │ ├── aliyun_es_kibana_management.png │ │ ├── aliyun_es_kibana_menu.png │ │ ├── aliyun_es_kibana_whitelist.png │ │ ├── aliyun_es_overview.png │ │ ├── aliyun_es_password.png │ │ ├── aliyun_es_password_reset.png │ │ ├── aliyun_es_plugin.png │ │ ├── aliyun_es_upload_dic.png │ │ ├── new_word_dict.png │ │ ├── pairag_es_connect.png │ │ ├── pairag_es_param_list.png │ │ ├── pairag_retrieval_mode.png │ │ └── pairag_select_es.png │ ├── framework.jpg │ ├── index │ │ ├── add_index.png │ │ ├── config_new_index.png │ │ └── select_index.png │ ├── multimodal │ │ ├── mm_chat.jpg │ │ ├── mm_settings.jpg │ │ └── mm_upload.jpg │ └── quick_start │ │ ├── query.png │ │ ├── setting.png │ │ └── upload.png ├── index │ └── multi_index.md ├── multimodal_rag.md ├── qca_generation_and_evaluation.md ├── tabular_doc.md └── vectordb │ ├── elasticsearch.md │ ├── opensearch_image.json │ └── opensearch_text.json ├── example_data ├── cn_clip │ └── pokemon.jpeg ├── eval_docs_crag_small │ ├── crag-task1-straightforward-eval-msgs_1.jsonl │ └── crag-task1-straightforward-eval-msgs_2.jsonl ├── eval_docs_image │ └── 跑鞋推荐.pdf ├── eval_docs_image_example │ └── multimodal_eval_dataset_zh_example.jsonl 
├── eval_docs_text │ ├── EasyRec.txt │ └── PAI.txt ├── function_tools │ └── api-tool-with-intent-detection-for-travel-assistant │ │ ├── dataset │ │ ├── 上海攻略信息.pdf │ │ └── 北京攻略信息.pdf │ │ ├── figures │ │ ├── agent_chat.jpg │ │ ├── agent_config.jpg │ │ ├── settings.jpg │ │ └── upload.jpg │ │ ├── mock_api │ │ └── main.py │ │ ├── settings.toml │ │ ├── tools.json │ │ └── tools.py ├── pai_document.pdf └── paul_graham │ └── paul_graham_essay.txt ├── integrations └── pairag-file │ ├── README.md │ ├── pairag │ └── file │ │ ├── __init__.py │ │ ├── nodeparsers │ │ ├── __init__.py │ │ ├── pai │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── image_caption_tool.py │ │ │ ├── pai_markdown_parser.py │ │ │ └── pai_node_parser.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── pai_markdown_tree.py │ │ ├── readers │ │ ├── __init__.py │ │ ├── pai │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── file_readers │ │ │ │ ├── __init__.py │ │ │ │ ├── pai_csv_reader.py │ │ │ │ ├── pai_docx_reader.py │ │ │ │ ├── pai_excel_reader.py │ │ │ │ ├── pai_html_reader.py │ │ │ │ ├── pai_image_reader.py │ │ │ │ ├── pai_jsonl_reader.py │ │ │ │ ├── pai_markdown_reader.py │ │ │ │ ├── pai_pdf_reader.py │ │ │ │ └── pai_pptx_reader.py │ │ │ ├── pai_data_reader.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── cuda_utils.py │ │ │ │ ├── image_utils.py │ │ │ │ ├── magic-pdf.template.json │ │ │ │ ├── markdown_utils.py │ │ │ │ └── modelscope_utils.py │ │ └── readme.md │ │ └── store │ │ ├── __init__.py │ │ ├── oss_store.py │ │ └── pai_image_store.py │ ├── poetry.lock │ ├── pyproject.toml │ └── tests │ ├── __init__.py │ ├── nodeparsers │ └── test_markdown_mode_parser.py │ ├── openailike_multimodal.py │ ├── readers │ ├── test_csv_reader.py │ ├── test_docx_reader.py │ ├── test_excel_reader.py │ ├── test_html_reader.py │ ├── test_image_reader.py │ ├── test_jsonl_reader.py │ ├── test_pdf_reader.py │ └── test_post_process_multi_level_headings.py │ ├── test_image_caption_tool.py │ ├── test_image_utils.py │ ├── test_load_data.py │ └── testdata │ ├── csv_data │ ├── 30天留存率_csv_two_header.csv │ └── titanic_train.csv │ ├── db_data │ ├── pets.db │ └── pets.sqlite │ ├── docx_data │ └── 大模型RAG对话系统.docx │ ├── excel_data │ └── 30天留存率_excel_two_header.xlsx │ ├── html_data │ ├── AIGC Stable Diffusion文生图Lora模型微调实现虚拟上装.html │ ├── AI写真计费说明.html │ ├── EAS计费说明.html │ ├── 多媒体分析计费说明.html │ └── 聚类模型评估.html │ ├── image_data │ └── 11.jpg │ ├── json_data │ └── pai_document.json │ ├── jsonl_data │ └── price.jsonl │ ├── md_data │ └── pai_document.md │ └── pdf_data │ └── pai_document.pdf ├── poetry.lock ├── pyproject.toml ├── scripts ├── gunicorn.conf.py ├── load_data.sh ├── load_data_by_step.sh ├── locust.py └── start.sh ├── setup.py ├── src └── pairag │ ├── api │ ├── api_chat.py │ ├── api_router_v1.py │ ├── chat_completions.py │ ├── exception_handler.py │ └── middleware.py │ ├── app │ ├── app.py │ └── constants.py │ ├── chat │ ├── chat_app.py │ ├── chat_flow.py │ ├── models.py │ └── utils │ │ └── chat_utils.py │ ├── config │ ├── evaluation │ │ ├── config.yaml │ │ ├── settings_eval_for_crag_text.toml │ │ ├── settings_eval_for_image.toml │ │ └── settings_eval_for_text.toml │ └── settings.toml │ ├── core │ ├── chat_service.py │ ├── models │ │ ├── config.py │ │ ├── container.py │ │ ├── errors.py │ │ └── state.py │ ├── rag_config.py │ ├── rag_config_manager.py │ ├── rag_environment.py │ ├── rag_module.py │ └── service_daemon.py │ ├── data_pipeline │ ├── constants.py │ ├── datasource │ │ └── filedelta_datasource.py │ ├── delta │ │ ├── list_delta.py │ │ ├── 
milvus.py │ │ └── models.py │ ├── e2e_config.yaml │ ├── job │ │ ├── file_task_executor.py │ │ └── rag_job_manager.py │ ├── main.py │ ├── models │ │ ├── config │ │ │ ├── datasource.py │ │ │ └── operator.py │ │ ├── file │ │ │ └── event.py │ │ └── models.py │ ├── operators │ │ ├── base.py │ │ ├── embedder.py │ │ ├── parser.py │ │ ├── sink.py │ │ └── split.py │ ├── ray_executor.py │ ├── readme.md │ └── utils │ │ ├── compute_resource_utils.py │ │ ├── concurrency_utils.py │ │ ├── cuda_utils.py │ │ ├── dataset_utils.py │ │ ├── download_utils.py │ │ ├── file_ext_utils.py │ │ ├── filename_utils.py │ │ ├── memory_utils.py │ │ ├── node_utils.py │ │ ├── path_resolver.py │ │ ├── path_utils.py │ │ └── vectordb_utils.py │ ├── evaluation │ ├── dataset │ │ ├── crag │ │ │ ├── crag_data_loader.py │ │ │ └── crag_jsonl_reader.py │ │ ├── open_dataset.py │ │ ├── rag_eval_dataset_refactor.py │ │ ├── rag_qca_dataset_refactor.py │ │ └── state_manager.py │ ├── evaluator │ │ ├── base_evaluator.py │ │ └── pai_evaluator.py │ ├── generator │ │ └── rag_qca_generator.py │ ├── metrics │ │ ├── response │ │ │ ├── base.py │ │ │ ├── correctness.py │ │ │ └── faithfulness.py │ │ └── retrieval │ │ │ ├── core.py │ │ │ ├── hitrate.py │ │ │ └── mrr.py │ ├── pipeline │ │ ├── run_evaluation_pipeline.py │ │ └── run_multimodal_evaluation_pipeline.py │ ├── run_evaluation_experiments.py │ └── utils │ │ ├── create_components.py │ │ └── file_utils.py │ ├── extensions │ └── news │ │ ├── miaobi_news.py │ │ └── news_config.py │ ├── integrations │ ├── data_analysis │ │ ├── data_analysis_config.py │ │ ├── data_analysis_tool.py │ │ ├── nl2pandas_retriever.py │ │ ├── nl2sql │ │ │ ├── db_connector.py │ │ │ ├── db_descriptor.py │ │ │ ├── db_indexer.py │ │ │ ├── db_loader.py │ │ │ ├── db_preretriever.py │ │ │ ├── db_query.py │ │ │ ├── db_selector.py │ │ │ ├── db_utils │ │ │ │ ├── constants.py │ │ │ │ └── nl2sql_utils.py │ │ │ ├── nl2sql_prompts.py │ │ │ ├── query_preprocessor.py │ │ │ └── sql_generator.py │ │ ├── nl2sql_retriever.py │ │ ├── pandas_instruction_parser.py │ │ ├── test │ │ │ ├── test_bird_schema_collector.py │ │ │ ├── test_db_connector.py │ │ │ ├── test_db_info_collector.py │ │ │ ├── test_index_retriever.py │ │ │ └── test_query_processor.py │ │ └── text2sql │ │ │ ├── db_connector.py │ │ │ ├── db_info_collector.py │ │ │ ├── db_info_index.py │ │ │ ├── db_info_node.py │ │ │ ├── db_info_retriever.py │ │ │ ├── db_info_selector.py │ │ │ ├── db_loader.py │ │ │ ├── db_query.py │ │ │ ├── db_retriever_filter.py │ │ │ ├── evaluations │ │ │ ├── base_evaluator.py │ │ │ ├── bird_evaluator.py │ │ │ ├── eval_bird │ │ │ │ └── evaluation.py │ │ │ ├── eval_spider │ │ │ │ ├── evaluation.py │ │ │ │ ├── exec_eval.py │ │ │ │ ├── parse.py │ │ │ │ └── process_sql.py │ │ │ └── spider_evaluator.py │ │ │ ├── query_processor.py │ │ │ ├── sql_generator.py │ │ │ └── utils │ │ │ ├── constants.py │ │ │ ├── info_utils.py │ │ │ ├── prompts.py │ │ │ └── sql_utils.py │ ├── embeddings │ │ ├── pai │ │ │ ├── embedding_utils.py │ │ │ ├── pai_embedding.py │ │ │ └── pai_embedding_config.py │ │ └── readme.md │ ├── guardrail │ │ └── pai_guardrail.py │ ├── llms │ │ ├── pai │ │ │ ├── llm_config.py │ │ │ ├── llm_utils.py │ │ │ ├── open_ai_alike_multi_modal.py │ │ │ ├── pai_llm.py │ │ │ └── pai_multi_modal_llm.py │ │ └── readme.md │ ├── postprocessor │ │ ├── my_model_based_reranker.py │ │ └── pai │ │ │ └── pai_postprocessor.py │ ├── query_transform │ │ ├── intent_models.py │ │ └── pai_query_transform.py │ ├── search │ │ ├── aliyun_search.py │ │ ├── bing_search.py │ │ ├── bs4_reader.py 
│ │ ├── google_search.py │ │ └── search_config.py │ ├── synthesizer │ │ ├── pai_synthesizer.py │ │ └── prompt_templates.py │ ├── trace │ │ ├── base.py │ │ ├── pai_query_wrapper.py │ │ ├── reloadable_exporter.py │ │ └── trace_config.py │ └── vector_stores │ │ ├── dashvector │ │ └── dashvector.py │ │ ├── elasticsearch │ │ ├── elasticsearch_utils.py │ │ ├── my_async_vector_store.py │ │ └── my_elasticsearch.py │ │ ├── faiss │ │ └── my_faiss.py │ │ ├── hologres │ │ └── hologres.py │ │ ├── milvus │ │ └── my_milvus.py │ │ ├── postgresql │ │ └── postgresql.py │ │ └── tablestore │ │ └── tablestore.py │ ├── knowledgebase │ ├── index │ │ └── pai │ │ │ ├── pai_vector_index.py │ │ │ ├── utils │ │ │ ├── index_utils.py │ │ │ ├── sparse_embed_function.py │ │ │ └── vector_store_utils.py │ │ │ └── vector_store_config.py │ ├── models.py │ ├── rag_knowledgebase.py │ ├── rag_knowledgebase_helper.py │ └── utils │ │ └── knowledgebase_utils.py │ ├── tools │ ├── data_analysis │ │ └── text2sql │ │ │ ├── bird_eval.py │ │ │ ├── bird_eval_parallel.py │ │ │ └── spider_eval.py │ └── intent_eval │ │ ├── intent_eval_tool.py │ │ ├── intent_sample.json │ │ └── intent_sample_predicted.json │ ├── utils │ ├── __init__.py │ ├── citation_utils.py │ ├── constants.py │ ├── cuda_utils.py │ ├── download_models.py │ ├── file_utils.py │ ├── format_logging.py │ ├── json_parser.py │ ├── mdoelscope_utils.py │ ├── prompt_template.py │ ├── score_utils.py │ └── time_utils.py │ └── web │ ├── element_manager.py │ ├── event_listeners.py │ ├── filebrowser │ ├── constants.py │ └── request_utils.py │ ├── index_utils.py │ ├── rag_local_client.py │ ├── tabs │ ├── chat_tab.py │ ├── data_analysis_tab.py │ ├── history_tab.py │ ├── knowledgebase_tab.py │ ├── model │ │ └── index_info.py │ ├── news_extension.py │ ├── search_web_tab.py │ ├── settings_tab.py │ └── vector_db_panel.py │ ├── ui_constants.py │ ├── utils.py │ ├── view_model.py │ └── webui.py └── tests ├── __init__.py ├── app ├── test_openai.py └── test_openai_embedding.py ├── index ├── test_elasticsearch.py └── test_milvus.py ├── integrations ├── intentdetection │ ├── multi_intents_sample.json │ ├── multi_intents_sample_predicted.json │ └── test_multi_intents.py ├── llm │ ├── test_function_calling_llm.py │ └── test_llm.py ├── retriever │ └── test_es_tokenizer.py ├── test_nl2pandas_retriever.py └── test_nl2sql_retriever.py ├── news └── intent_eval │ ├── intent_sample.json │ ├── intent_sample_predicted.json │ └── test_news_intent.py └── testdata ├── csv_data └── titanic_train.csv ├── db_data ├── pets.db └── pets.sqlite ├── pai_document.md └── paul_graham └── paul_graham_essay.txt /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: PAI-RAG CI Build 2 | 3 | on: 4 | push: 5 | # Sequence of patterns matched against refs/heads 6 | branches: 7 | - main 8 | - feature 9 | - "releases/**" 10 | 11 | concurrency: 12 | group: pairag-ci-${{ github.head_ref || github.run_id }} 13 | cancel-in-progress: true 14 | 15 | permissions: 16 | contents: read 17 | pull-requests: write 18 | 19 | jobs: 20 | build: 21 | name: Build and Test 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python 3.11 27 | # This is the version of the action for setting up Python, not the Python version. 
28 | uses: actions/setup-python@v5 29 | with: 30 | # Semantic version range syntax or exact version of a Python version 31 | python-version: "3.11" 32 | # Optional - x64 or x86 architecture, defaults to x64 33 | architecture: "x64" 34 | 35 | - name: Install Dependencies 36 | run: | 37 | python -m pip install --upgrade pip setuptools wheel 38 | pip install poetry 39 | poetry install 40 | poetry run pip install magic-pdf[full]==1.3.10 41 | poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 42 | env: 43 | POETRY_VIRTUALENVS_CREATE: false 44 | 45 | - name: Install pre-commit 46 | shell: bash 47 | run: poetry run pip install pre-commit 48 | 49 | - name: Run Linter 50 | shell: bash 51 | run: poetry run make lint 52 | 53 | - name: Run Tests 54 | run: | 55 | make coveragetest 56 | env: 57 | DASHSCOPE_API_KEY: ${{ secrets.TESTDASHSCOPEKEY }} 58 | BING_SEARCH_KEY: ${{ secrets.BING_SEARCH_KEY }} 59 | OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} 60 | OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} 61 | PAIRAG_RAG__embedding__source: "huggingface" 62 | PAIRAG_RAG__llm__source: "DashScope" 63 | PAIRAG_RAG__llm__model: "qwen-max" 64 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: PR Build 2 | 3 | on: 4 | pull_request: 5 | # Sequence of patterns matched against refs/heads 6 | branches: 7 | - main 8 | - feature 9 | - "releases/**" 10 | 11 | concurrency: 12 | group: pr-build-${{ github.head_ref || github.run_id }} 13 | cancel-in-progress: true 14 | 15 | permissions: 16 | contents: read 17 | pull-requests: write 18 | 19 | jobs: 20 | build: 21 | name: Build and Test 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python 3.11 27 | # This is the version of the action for setting up Python, not the Python version. 
28 | uses: actions/setup-python@v5 29 | with: 30 | # Semantic version range syntax or exact version of a Python version 31 | python-version: "3.11" 32 | # Optional - x64 or x86 architecture, defaults to x64 33 | architecture: "x64" 34 | 35 | - name: Install Dependencies 36 | run: | 37 | python -m pip install --upgrade pip setuptools wheel 38 | pip install poetry 39 | poetry install 40 | poetry run pip install magic-pdf[full]==1.3.10 41 | poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 42 | env: 43 | POETRY_VIRTUALENVS_CREATE: false 44 | 45 | - name: Install pre-commit 46 | shell: bash 47 | run: poetry run pip install pre-commit 48 | 49 | - name: Run Linter 50 | shell: bash 51 | run: poetry run make lint 52 | 53 | - name: Run Tests 54 | run: | 55 | make coveragetest 56 | env: 57 | DASHSCOPE_API_KEY: ${{ secrets.TESTDASHSCOPEKEY }} 58 | PAIRAG_RAG__embedding__source: "huggingface" 59 | PAIRAG_RAG__llm__source: "DashScope" 60 | PAIRAG_RAG__llm__model: "qwen-max" 61 | BING_SEARCH_KEY: ${{ secrets.BING_SEARCH_KEY }} 62 | OSS_ACCESS_KEY_ID: ${{ secrets.OSS_ACCESS_KEY_ID }} 63 | OSS_ACCESS_KEY_SECRET: ${{ secrets.OSS_ACCESS_KEY_SECRET }} 64 | 65 | - name: Get Cover 66 | uses: orgoro/coverage@v3.1 67 | with: 68 | coverageFile: localdata/test_output/coverage_report.xml 69 | token: ${{ secrets.GITHUB_TOKEN }} 70 | thresholdAll: 0.4 # Total coverage threshold 71 | #thresholdNew: 0.9 # New files coverage threshold 72 | #thresholdModified: 0.9 # Modified files coverage threshold 73 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release image 2 | 3 | # Configures this workflow to run every time a change is pushed to the branch called `release`. 4 | on: 5 | workflow_dispatch: 6 | push: 7 | branches: ["main", "release_test"] 8 | 9 | # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds. 10 | env: 11 | REGISTRY: mybigpai-public-registry.cn-beijing.cr.aliyuncs.com 12 | 13 | # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu. 14 | jobs: 15 | build-and-push-image: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v4 20 | 21 | - uses: actions/setup-python@v4 22 | with: 23 | python-version: "3.11" 24 | 25 | - name: Check disk space 26 | run: df . -h 27 | 28 | - name: Free disk space 29 | run: | 30 | sudo docker rmi $(docker image ls -aq) >/dev/null 2>&1 || true 31 | sudo rm -rf \ 32 | /usr/share/dotnet /usr/local/lib/android /opt/ghc \ 33 | /usr/local/share/powershell /usr/share/swift /usr/local/.ghcup \ 34 | /usr/lib/jvm || true 35 | 36 | - name: Extract version 37 | run: | 38 | pip install poetry 39 | VERSION_TAG=$(poetry version --short) 40 | SPECIFIC_VERSION_TAG="$VERSION_TAG-$(date +'%Y%m%d')" 41 | echo "VERSION_TAG=$VERSION_TAG" >> $GITHUB_ENV 42 | echo "SPECIFIC_VERSION_TAG=$SPECIFIC_VERSION_TAG" >> $GITHUB_ENV 43 | echo "version:$SPECIFIC_VERSION_TAG\ncommit_id:$(git rev-parse HEAD)" > __build_version.cfg 44 | 45 | # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. 
46 | - name: Login to ACR region 47 | uses: docker/login-action@v1 48 | with: 49 | registry: ${{ env.REGISTRY }} 50 | username: ${{ secrets.ACR_USER }} 51 | password: ${{ secrets.ACR_PUBLIC_PASSWORD }} 52 | 53 | - name: Build and push base image 54 | env: 55 | IMAGE_TAG: ${{env.VERSION_TAG}} 56 | SPECIFIC_IMAGE_TAG: ${{env.SPECIFIC_VERSION_TAG}} 57 | run: | 58 | docker build -t ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.IMAGE_TAG }} . 59 | docker push ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.IMAGE_TAG }} 60 | docker tag ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.IMAGE_TAG }} ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.SPECIFIC_IMAGE_TAG }} 61 | docker push ${{ env.REGISTRY }}/mybigpai/pairag:${{ env.SPECIFIC_IMAGE_TAG }} 62 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11 AS builder 2 | 3 | RUN pip3 install poetry 4 | 5 | ENV POETRY_NO_INTERACTION=1 \ 6 | POETRY_VIRTUALENVS_IN_PROJECT=1 \ 7 | POETRY_VIRTUALENVS_CREATE=1 \ 8 | POETRY_CACHE_DIR=/tmp/poetry_cache 9 | 10 | WORKDIR /app 11 | COPY . . 12 | 13 | RUN poetry install \ 14 | && poetry run pip install magic-pdf[full]==1.3.10 \ 15 | && poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 \ 16 | && rm -rf $POETRY_CACHE_DIR 17 | 18 | FROM python:3.11-slim AS prod 19 | 20 | RUN rm -rf /etc/localtime && ln -s /usr/share/zoneinfo/Asia/Harbin /etc/localtime 21 | 22 | RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 libgomp1 curl libgdiplus wget perl build-essential 23 | 24 | ENV VIRTUAL_ENV=/app/.venv \ 25 | PATH="/app/.venv/bin:$PATH" 26 | 27 | 28 | ADD https://eas-data.oss-cn-shanghai.aliyuncs.com/3rdparty/sdwebui/filebrowser /bin/filebrowser 29 | RUN chmod u+x /bin/filebrowser 30 | 31 | # setup paddleocr dependencies 32 | RUN mkdir -p /root/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer \ 33 | && curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar -o /root/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar \ 34 | && tar xvf /root/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar -C /root/.paddleocr/whl/det/ch/ 35 | 36 | RUN mkdir -p /root/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer \ 37 | && curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar -o /root/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar \ 38 | && tar xvf /root/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar -C /root/.paddleocr/whl/rec/ch/ 39 | 40 | RUN mkdir -p /root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer \ 41 | && curl https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar -o /root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar \ 42 | && tar xvf /root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar -C /root/.paddleocr/whl/cls/ 43 | 44 | WORKDIR /app 45 | 46 | COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} 47 | COPY . . 
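# Start the RAG API server via start.sh with a single worker; the -w (worker count, default 1) and -p (port, default 8680) flags can be adjusted as needed.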
48 | CMD ["./scripts/start.sh", "-w", "1"] 49 | -------------------------------------------------------------------------------- /Dockerfile_ingestion: -------------------------------------------------------------------------------- 1 | FROM rayproject/ray:2.46.0.0e19ea-py311-cu124 2 | 3 | RUN pip3 install poetry 4 | 5 | ENV POETRY_NO_INTERACTION=1 \ 6 | POETRY_VIRTUALENVS_CREATE=false \ 7 | POETRY_CACHE_DIR=/tmp/poetry_cache \ 8 | HOME_DIR=/home/ray 9 | 10 | WORKDIR /app 11 | COPY . . 12 | 13 | RUN poetry install 14 | RUN poetry run pip install magic-pdf[full]==1.3.10 \ 15 | && poetry run pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 \ 16 | && rm -rf $POETRY_CACHE_DIR 17 | 18 | RUN sudo rm -rf /etc/localtime && sudo ln -s /usr/share/zoneinfo/Asia/Harbin /etc/localtime 19 | RUN sudo apt-get update && sudo apt-get install -y libgl1 libglib2.0-0 libgomp1 curl libgdiplus wget perl build-essential 20 | 21 | # setup paddleocr dependencies 22 | RUN sudo mkdir -p $HOME_DIR/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer \ 23 | && sudo curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar -o $HOME_DIR/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar \ 24 | && sudo tar xvf $HOME_DIR/.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer/ch_PP-OCRv4_det_infer.tar -C $HOME_DIR/.paddleocr/whl/det/ch/ 25 | 26 | RUN sudo mkdir -p $HOME_DIR/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer \ 27 | && sudo curl https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_rec_infer.tar -o $HOME_DIR/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar \ 28 | && sudo tar xvf $HOME_DIR/.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer/ch_PP-OCRv4_rec_infer.tar -C $HOME_DIR/.paddleocr/whl/rec/ch/ 29 | 30 | RUN sudo mkdir -p $HOME_DIR/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer \ 31 | && sudo curl https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar -o $HOME_DIR/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar \ 32 | && sudo tar xvf $HOME_DIR/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar -C $HOME_DIR/.paddleocr/whl/cls/ 33 | 34 | CMD ["./scripts/load_data.sh", "--help"] 35 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GIT_ROOT ?= $(shell git rev-parse --show-toplevel) 2 | 3 | help: ## Show all Makefile targets. 4 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' 5 | 6 | format: ## Run code autoformatters (black). 7 | pre-commit install 8 | git ls-files | xargs pre-commit run black --files 9 | 10 | lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy 11 | pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files 12 | 13 | test: ## Run tests via pytest. 14 | pytest tests -s 15 | 16 | coveragetest: ## Tests with coverage report 17 | pytest --cov-report xml:localdata/test_output/coverage_report.xml --cov=pairag tests -s 18 | 19 | watch-docs: ## Build and watch documentation. 
20 | sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ 21 | 22 | publish: 23 | poetry publish --build --username __token__ --password $$PYPI_KEY --skip-existing 24 | -------------------------------------------------------------------------------- /docker/.env.example: -------------------------------------------------------------------------------- 1 | # DASHSCOPE API_KEY 2 | DASHSCOPE_API_KEY= 3 | USE_CUDA=0 4 | 5 | # OSS AK SK 6 | OSS_ACCESS_KEY_ID= 7 | OSS_ACCESS_KEY_SECRET= 8 | OSS_BUCKET= 9 | OSS_ENDPOINT= 10 | -------------------------------------------------------------------------------- /docker/compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | api: 3 | image: mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/mybigpai/pairag:1.1.0 4 | ports: 5 | - "8680:8680" 6 | restart: always 7 | environment: 8 | DASHSCOPE_API_KEY: ${DASHSCOPE_API_KEY} 9 | OSS_ACCESS_KEY_ID: ${OSS_ACCESS_KEY_ID} 10 | OSS_ACCESS_KEY_SECRET: ${OSS_ACCESS_KEY_SECRET} 11 | PAIRAG_RAG__oss_store__bucket: ${OSS_BUCKET} 12 | PAIRAG_RAG__oss_store__endpoint: ${OSS_ENDPOINT:-oss-cn-hangzhou.aliyuncs.com} 13 | ARMS_APP_NAME: ${ARMS_APP_NAME} 14 | ARMS_REGION_ID: ${ARMS_REGION_ID} 15 | ARMS_LICENSE_KEY: ${ARMS_LICENSE_KEY} 16 | ARMS_IS_PUBLIC: true 17 | ENABLE_FASTAPI: false 18 | ENABLE_REQUESTS: false 19 | ENABLE_AIOHTTPCLIENT: false 20 | ENABLE_HTTPX: false 21 | LD_LIBRARY_PATH: /usr/local/lib 22 | 23 | volumes: 24 | - ../model_repository:/app/model_repository 25 | - ./app_data:/app/localdata 26 | entrypoint: ["./scripts/start.sh", "-w", "1"] 27 | healthcheck: 28 | test: ["CMD", "curl", "-f", "http://localhost:8680/api/v1/health"] 29 | interval: 30s 30 | retries: 40 31 | start_period: 20s 32 | -------------------------------------------------------------------------------- /docs/data_analysis_doc.md: -------------------------------------------------------------------------------- 1 | # 模型配置 2 | 3 | 在web界面 Settings 中,选择需要的LLM,如果选择DashScope(通义API),推荐使用qwen-max模型;如果选择PaiEas开源部署,推荐使用qwen2-72b-instruct模型 4 | 5 | 点击下方Save更新使用的模型 6 | 7 | ![llm_selection](/docs/figures/data_analysis/llm_selection.png) 8 | 9 | 点击web界面上方Data Analysis,进入到数据分析页面,支持两种类型的数据分析:连接数据库(mysql)分析 和 上传表格文件(excel/csv)分析 10 | 11 | ![data_analysis_overview](/docs/figures/data_analysis/data_analysis_overview.png) 12 | 13 | # 数据库分析配置 14 | 15 | ## 数据库连接 16 | 17 | 连接数据库,选择左上方数据分析类型为 database,出现数据库连接配置界面,如下图: 18 | 19 | ![db_config](/docs/figures/data_analysis/db_config.png) 20 | 21 | 其中, 22 | 23 | - Dialect为数据库类别,当前支持mysql,默认mysql 24 | - Username和Password分别为用户名和密码 25 | - Host为本地或远程数据库url,Port为端口,默认3306 26 | - Database为需要分析的目标数据库名称 27 | - Tables为需要分析的数据表,格式为:table_A, table_B,... 
,默认为空,使用目标数据库中所有数据表 28 | - Descriptions为针对目标数据库中每张表的补充描述,比如对表中字段的进一步解释,可以提升数据分析效果,格式为:{"table_A":"字段a表示xxx,字段b数据的格式为yyy","table_B":"这张表主要用于zzz"},注意:需要使用英文输入法下的字典格式(英文双引号,冒号,逗号),默认为空 29 | - 支持三种prompt template用于生成sql查询,点击选项后可以看到具体的prompt template,也可在custom中自定义 30 | 31 | **注意:** 如自定义prompt,请尽量参考template修改,其中中大括号{}的内容为输入参数,需要保留 32 | 33 | 确认输入无误后,可直接在右侧chatbot中开始提问,回答结果的Reference中可以看到查询的数据库表名称,生成的sql语句,以及该sql语句是否有效执行。**注意:** 这里有效执行是指sql语句语法有效,不代表业务逻辑一定正确 34 | 35 | Reference可以作为查询效果优化的"debug"工具 36 | 37 | ![db_chat](/docs/figures/data_analysis/db_chat.png) 38 | 39 | ## 查询效果优化 40 | 41 | ### Description 42 | 43 | 针对数据表中字段含义不清晰,或者字段存储内容格式不清晰等问题,可以在Descriptions中增加相应描述,帮助llm更准确提取数据表内容,此处以公开数据集Spider中my_pets数据库为例,其中pets表数据如下: 44 | 45 | ![table_example](/docs/figures/data_analysis/table_example.png) 46 | 47 | 问答效果对比: 48 | 49 | 当描述为空时,对问题“有几只狗”生成的sql查询语句为:SELECT COUNT(\*) FROM pets WHERE PetType = '狗',查询不到 50 | 51 | ![db_query_no_desc](/docs/figures/data_analysis/db_query_no_desc.png) 52 | 53 | 增加简单描述后,生成的sql查询语句为:SELECT COUNT(\*) FROM pets WHERE PetType = 'dog',可以准确回答 54 | 55 | ![db_query_desc](/docs/figures/data_analysis/db_query_desc.png) 56 | 57 | 如果查询效果有明显改善,可以将相应的补充描述在数据库中作为相应table或column的comment持久化添加 58 | 59 | ### Prompt 60 | 61 | 此外,还可以通过调整prompt template,优化查询效果 62 | 63 | - Reference观察到生成的sql语句包含了非sql的其他内容,如开头多了"sql"等(部分小模型可能会发生),可以在prompt中增加简单限制 64 | - Reference观察到生成的sql语句不满足某些业务逻辑,可以在prompt中给出示例,通过few-shot learning,也可以很快提升效果 65 | 66 | # 表格文件分析配置 67 | 68 | 表格文件配置相对简单,选择左上方的分析类型为:datafile,出现以下界面 69 | 70 | ![sheet_upload](/docs/figures/data_analysis/sheet_upload.png) 71 | 72 | 点击左侧中部的上传,一次上传一份表格文件(excel或csv格式),上传成功后,左侧下方会出现文件的前几行预览,如下图所示: 73 | 74 | ![sheet_data_preview](/docs/figures/data_analysis/sheet_data_preview.png) 75 | 76 | 上传表格文件后可以直接在右侧chatbot中提问,如需更换表格,重新上传所需表格即可 77 | -------------------------------------------------------------------------------- /docs/develop/local_develop_zh.md: -------------------------------------------------------------------------------- 1 | 如果需要在本地进行开发运行,请参考以下步骤: 2 | 3 | ## 本地启动 4 | 5 | 1. 克隆仓库 6 | 7 | ```bash 8 | git clone git@github.com:aigc-apps/PAI-RAG.git 9 | ``` 10 | 11 | 2. 配置开发环境 12 | 13 | 本项目使用poetry进行管理,若在本地环境下使用,建议在安装环境之前先创建一个空环境。为了确保环境一致性并避免因Python版本差异造成的问题,我们指定Python版本为3.11。 14 | 15 | ```bash 16 | conda create -n rag_env python==3.11 17 | conda activate rag_env 18 | ``` 19 | 20 | 如果使用macOS且需要处理PPTX文件,需要下载依赖库处理PPTX文件 21 | 22 | ```bash 23 | brew install mono-libgdiplus 24 | ``` 25 | 26 | 直接使用poetry安装项目依赖包: 27 | 28 | ```bash 29 | pip install poetry 30 | poetry install 31 | pip install magic-pdf[full]==1.3.10 32 | pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 33 | ``` 34 | 35 | 安装filebrowser 36 | 37 | ```bash 38 | wget https://eas-data.oss-cn-shanghai.aliyuncs.com/3rdparty/sdwebui/filebrowser 39 | mv filebrowser /bin/filebrowser 40 | chmod u+x /bin/filebrowser 41 | ``` 42 | 43 | - 常见网络超时问题 44 | 45 | 注:在安装过程中,若遇到网络连接超时的情况,可以添加阿里云或清华的镜像源,在 pyproject.toml 文件末尾追加以下几行: 46 | 47 | ```bash 48 | [[tool.poetry.source]] 49 | name = "mirrors" 50 | url = "http://mirrors.aliyun.com/pypi/simple/" # 阿里云 51 | # url = "https://pypi.tuna.tsinghua.edu.cn/simple/" # 清华 52 | priority = "default" 53 | ``` 54 | 55 | 之后,再依次执行以下命令: 56 | 57 | ```bash 58 | poetry lock 59 | poetry install 60 | pip install magic-pdf[full]==1.3.10 61 | pip install opentelemetry-exporter-otlp-proto-grpc protobuf==5.27.4 62 | ``` 63 | 64 | 3. 
启动RAG服务 65 | 66 | 使用DashScope API,需要在命令行引入环境变量: 67 | 68 | ```bash 69 | export DASHSCOPE_API_KEY="xxx" 70 | ``` 71 | 72 | 请替换xxx为你自己的DASHSCOPE_API_KEY,DASHSCOPE_API_KEY获取地址为 https://dashscope.console.aliyun.com/apiKey 73 | 74 | 启动: 75 | 76 | ```bash 77 | # 启动,支持自定义hport(默认8680), worker_num(默认1) 78 | # 默认启动时下载模型 [bge-m3, pdf-extract] 79 | # 可使用命令行 "load_model" 下载模型 including [bge-m3, pdf-extract, SGPT-125M-weightedmean-nli-bitfit, bge-large-zh-v1.5, bge-reranker-base, bge-reranker-large, paraphrase-multilingual-MiniLM-L12-v2, qwen_1.8b, text2vec-large-chinese] 80 | ./scripts/start.sh [-w WORKER_NUM] [-p PORT] 81 | ``` 82 | 83 | ```bash 84 | ./scripts/start.sh 85 | ``` 86 | 87 | 你可以打开http://localhost:8680/ 来配置RAG服务以及上传本地数据。 88 | -------------------------------------------------------------------------------- /docs/docker_build.md: -------------------------------------------------------------------------------- 1 | # Docker build 2 | 3 | ## Server 4 | 5 | ### CPU 6 | 7 | ```bash 8 | docker build -f Dockerfile -t rag_serve:0.1 . 9 | ``` 10 | 11 | ### GPU 12 | 13 | ```bash 14 | docker build -f Dockerfile_gpu -t rag_serve:0.1_gpu . 15 | ``` 16 | 17 | ## UI 18 | 19 | ```bash 20 | docker build -f Dockerfile_ui -t rag_ui:0.1 . 21 | ``` 22 | 23 | ## Nginx 24 | 25 | ```bash 26 | docker build -f Dockerfile_nginx -t rag_nginx:0.1 . 27 | ``` 28 | 29 | # 常见问题 30 | 31 | ## docker pull timeout 32 | 33 | 建议更换docker镜像源为阿里云镜像,在阿里云在容器镜像服务 -> 镜像工具 -> 镜像加速器 中可以找到阿里云的专属镜像加速器,按照指示说明修改daemon配置文件来使用加速器即可。 34 | -------------------------------------------------------------------------------- /docs/eas_deploy.md: -------------------------------------------------------------------------------- 1 | # EAS自定义部署RAG服务 2 | 3 | 模型在线服务EAS(Elastic Algorithm Service)是阿里云PAI产品为实现一站式模型开发部署应用,针对在线推理场景提供的模型在线服务,支持将模型服务部署在公共资源组或专属资源组,实现基于异构硬件(CPU和GPU)的模型加载和数据请求的实时响应。 4 | 5 | 我们支持通过`场景化部署`和`自定义部署`两种方式来一键部署RAG服务,其中, 6 | 7 | - `场景化部署`: 更加方便,只需要配置几个参数即可完成。可参考[场景化部署文档](https://help.aliyun.com/zh/pai/user-guide/deploy-a-rag-based-dialogue-system)。 8 | - `自定义部署`可以更灵活地配置服务,比如,部署GPU版本镜像,配置链路追踪服务等等。可参考[自定义部署文档](https://help.aliyun.com/zh/pai/use-cases/custom-deployment-of-rag-service#47e8104831b4f) 9 | -------------------------------------------------------------------------------- /docs/figures/agent/agenda_query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/agenda_query.png -------------------------------------------------------------------------------- /docs/figures/agent/agent_tab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/agent_tab.png -------------------------------------------------------------------------------- /docs/figures/agent/date_query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/date_query.png -------------------------------------------------------------------------------- /docs/figures/agent/nl2sql_query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/nl2sql_query.png 
-------------------------------------------------------------------------------- /docs/figures/agent/rag_knowledge_query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/rag_knowledge_query.png -------------------------------------------------------------------------------- /docs/figures/agent/weather_query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/weather_query.png -------------------------------------------------------------------------------- /docs/figures/agent/web_search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/agent/web_search.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/DBChat_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/DBChat_overview.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_db_chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_chat.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_db_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_config.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_db_enhance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_enhance.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_db_load.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_load.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_db_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_prompt.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_db_prompt_reset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_db_prompt_reset.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_llm_config.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_llm_config.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_overview.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_sheet_chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_sheet_chat.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/da_sheet_upload.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/da_sheet_upload.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/data_analysis_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/data_analysis_overview.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/datafile_chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/datafile_chat.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/datafile_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/datafile_config.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_chat.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_chat_with_memo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_chat_with_memo.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_config.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_config_update.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_config_update.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_enhanced_features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_enhanced_features.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_info_load.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_info_load.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_query_desc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_query_desc.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/db_query_no_desc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/db_query_no_desc.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/llm_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/llm_config.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/llm_selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/llm_selection.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/prompt_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/prompt_config.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/prompt_reset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/prompt_reset.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/sheet_data_preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/sheet_data_preview.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/sheet_upload.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/sheet_upload.png -------------------------------------------------------------------------------- /docs/figures/data_analysis/table_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/data_analysis/table_example.png -------------------------------------------------------------------------------- /docs/figures/deepseek/deepseek_eas_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/deepseek_eas_api.png -------------------------------------------------------------------------------- /docs/figures/deepseek/llm_chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/llm_chat.png -------------------------------------------------------------------------------- /docs/figures/deepseek/llm_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/llm_config.png -------------------------------------------------------------------------------- /docs/figures/deepseek/rag_chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deepseek/rag_chat.png -------------------------------------------------------------------------------- /docs/figures/deploy/eas/deploy_json.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_json.png -------------------------------------------------------------------------------- /docs/figures/deploy/eas/deploy_portal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_portal.png -------------------------------------------------------------------------------- /docs/figures/deploy/eas/deploy_resources.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_resources.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/deploy_success.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_success.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/deploy_vpc.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/deploy_vpc.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/edited_json.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/edited_json.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/enable_web.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/enable_web.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/trace_detail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_detail.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/trace_json.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_json.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/trace_key.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_key.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/trace_percent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/trace_percent.jpg -------------------------------------------------------------------------------- /docs/figures/deploy/eas/view_web.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/deploy/eas/view_web.jpg -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_ik_hot_update.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_ik_hot_update.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_instance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_instance.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_instance_autoindex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_instance_autoindex.png 
-------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_instance_info.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_instance_info.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_kibana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_kibana_entry.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_entry.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_kibana_index_detail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_index_detail.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_kibana_index_management.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_index_management.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_kibana_management.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_management.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_kibana_menu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_menu.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_kibana_whitelist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_kibana_whitelist.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_overview.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_password.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_password.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_password_reset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_password_reset.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_plugin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_plugin.png -------------------------------------------------------------------------------- /docs/figures/elastic/aliyun_es_upload_dic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/aliyun_es_upload_dic.png -------------------------------------------------------------------------------- /docs/figures/elastic/new_word_dict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/new_word_dict.png -------------------------------------------------------------------------------- /docs/figures/elastic/pairag_es_connect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_es_connect.png -------------------------------------------------------------------------------- /docs/figures/elastic/pairag_es_param_list.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_es_param_list.png -------------------------------------------------------------------------------- /docs/figures/elastic/pairag_retrieval_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_retrieval_mode.png -------------------------------------------------------------------------------- /docs/figures/elastic/pairag_select_es.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/elastic/pairag_select_es.png -------------------------------------------------------------------------------- /docs/figures/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/framework.jpg -------------------------------------------------------------------------------- /docs/figures/index/add_index.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/index/add_index.png 
-------------------------------------------------------------------------------- /docs/figures/index/config_new_index.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/index/config_new_index.png -------------------------------------------------------------------------------- /docs/figures/index/select_index.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/index/select_index.png -------------------------------------------------------------------------------- /docs/figures/multimodal/mm_chat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/multimodal/mm_chat.jpg -------------------------------------------------------------------------------- /docs/figures/multimodal/mm_settings.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/multimodal/mm_settings.jpg -------------------------------------------------------------------------------- /docs/figures/multimodal/mm_upload.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/multimodal/mm_upload.jpg -------------------------------------------------------------------------------- /docs/figures/quick_start/query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/quick_start/query.png -------------------------------------------------------------------------------- /docs/figures/quick_start/setting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/quick_start/setting.png -------------------------------------------------------------------------------- /docs/figures/quick_start/upload.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/docs/figures/quick_start/upload.png -------------------------------------------------------------------------------- /docs/index/multi_index.md: -------------------------------------------------------------------------------- 1 | # 多知识库索引支持 2 | 3 | 当需要多个知识库来存储不同类型/业务的知识文件时,可以使用我们的多知识库支持api/ui来管理知识库索引。 4 | 目前支持以下功能: 5 | 6 | 1. 新增一个知识库index 7 | 2. 修改创建好的知识库index 8 | 3. 删除一个知识库index 9 | 4. 检索/查询对应的知识库 10 | 5. 支持不同的向量数据库后端,faiss, elasticsearch等等。 11 | 12 | ## 快速使用 13 | 14 | 可以打开UI页面快速创建、上传、检索知识库。 15 | 16 | ### 创建 17 | 18 | 1. Index下拉列表选择新建"New" 19 | 20 | Description 21 | 22 | 2. 
填写索引名称,选择embedding和向量数据库类型,点击"Add Index" 23 | 24 | Description 25 | 26 | 创建完成。 27 | 28 | ### 上传知识库和查询 29 | 30 | 可以通过左边的选择器选择对应的index_name进行操作: 31 | 32 | Description 33 | 34 | ## API 使用 35 | 36 | 注意:目前查询和上传API均可以通过index_name参数指定要操作的知识库;当省略index_name参数时,默认使用**default**知识库。 37 | 38 | 目前,删除index操作仅支持通过API完成。 39 | 40 | ### 查询接口(Query & Retrieval) 41 | 42 | **Retrieval** 43 | 44 | ```sh 45 | curl -X POST http://localhost:8680/service/query/retrieval -H "Content-Type: application/json" -d '{"question": "什么是组件化", "index_name": "default"}' 46 | ``` 47 | 48 | **Query** 49 | 50 | ```sh 51 | curl -X POST http://localhost:8680/service/query -H "Content-Type: application/json" -d '{"question": "什么是组件化", "index_name": "default"}' 52 | ``` 53 | 54 | ### 上传接口(Upload) 55 | 56 | ```sh 57 | curl -X POST http://localhost:8680/service/upload_data -H 'Content-Type: multipart/form-data' -F 'files=@/localpath/PAI.txt' -F "index_name=es_test_1" 58 | ``` 59 | 60 | ### List Index 61 | 62 | ```sh 63 | curl -X GET http://localhost:8680/service/indexes 64 | ``` 65 | 66 | ### Add Index 67 | 68 | ```sh 69 | curl -X POST http://localhost:8680/service/indexes/index3 -H "Content-Type: application/json" -d '{"index_name":"index3","vector_store_config":{"persist_path":"localdata/storage3","type":"faiss","is_image_store":false},"embedding_config":{"source":"dashscope","embed_batch_size":10}}' 70 | ``` 71 | 72 | response: 73 | 74 | ```json 75 | { "msg": "Add index 'index3' successfully." } 76 | ``` 77 | 78 | ### Update Index 79 | 80 | ```sh 81 | curl -X PATCH http://localhost:8680/service/indexes/index3 -H "Content-Type: application/json" -d '{"index_name":"index3","vector_store_config":{"persist_path":"localdata/storage4","type":"faiss","is_image_store":false},"embedding_config":{"source":"dashscope","embed_batch_size":10}}' 82 | ``` 83 | 84 | response: 85 | 86 | ```json 87 | { "msg": "Update index 'index3' successfully." } 88 | ``` 89 | 90 | ### Delete Index 91 | 92 | ```sh 93 | curl -X DELETE http://localhost:8680/service/indexes/index3 94 | ``` 95 | 96 | response: 97 | 98 | ```json 99 | { "msg": "Delete index 'index3' successfully." } 100 | ``` 101 | -------------------------------------------------------------------------------- /docs/multimodal_rag.md: -------------------------------------------------------------------------------- 1 | # 多模态问答 2 | 3 | 很多时候,知识库的文档中不只是纯文本信息,还包含很多图文交错的pdf、word、markdown等文件,甚至一些海报之类的纯图片文件。 4 | 普通的RAG流程会忽略这些图片,仅使用文本信息,从而造成大量信息丢失。 5 | 这里我们通过使用多模态模型来实现图文混合的多模态问答。 6 | 7 | ## 配置多模态LLM和Aliyun OSS存储 8 | 9 | 首先我们需要配置多模态LLM,这里推荐使用DashScope的多模态(VL)模型,或者部署在PAI-EAS上、兼容OpenAI协议的多模态模型,比如开源的qwen2-vl。 10 | 11 | 然后需要添加一个Aliyun的OSS存储,用来存储图片文件。这样在输出结果时,可以通过图片链接的方式在回复中展示图片。 12 | 13 | 配置示例如图,配置完成后保存即可。 14 | 15 | 16 | 17 | ## 上传多模态文件 18 | 19 | 这里支持多种多模态文件格式,包括pdf, markdown, word, ppt, png, jpg等。 20 | 21 | 需要勾选`Process with MultiModal`选项才会处理文件中的图片信息;处理pdf时,建议同时勾选`Process PDF with OCR`选项。 22 | 23 | 24 | 25 | ## 问答测试 26 | 27 | 勾选`Display Image`选项,就可以进行多模态问答。 28 | 29 | 同时,你可以调整下方的Multimodal Prompt来优化问答提示词。 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/tabular_doc.md: -------------------------------------------------------------------------------- 1 | # Tabular processing with PAI-RAG 2 | 3 | ## PaiCSVReader 4 | 5 | PaiCSVReader(concat_rows=True, row_joiner="\n", csv_config={}) 6 | 7 | ### Parameters: 8 | 9 | **concat_rows:** _bool, default=True._ 10 | Whether to concatenate rows into one document.
11 | 12 | **row_joiner:** _str, default="\n"._ 13 | The separator used to join rows. 14 | 15 | **header:** _None or int, list of int, default 0._ 16 | Row (0-indexed) to use for the column labels of the parsed DataFrame. If a list of integers is passed, those row 17 | positions will be combined into a MultiIndex. Use None if there is no header. 18 | 19 | ### Functions: 20 | 21 | load_data(file: Path, extra_info: Optional[Dict] = None, fs: Optional[AbstractFileSystem] = None) 22 | 23 | ## PaiPandasCSVReader 24 | 25 | PaiPandasCSVReader(concat_rows=True, row_joiner="\n", pandas_config={}) 26 | 27 | ### Parameters: 28 | 29 | **concat_rows:** _bool, default=True._ 30 | Whether to concatenate rows into one document. 31 | 32 | **row_joiner:** _str, default="\n"._ 33 | The separator used to join rows. 34 | 35 | **pandas_config:** _dict, default={}._ 36 | The configuration of pandas.read_csv. 37 | Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information. 38 | Defaults to an empty dict, which means pandas will try to infer the separator, header row, etc. on its own. 39 | 40 | #### One important parameter: 41 | 42 | **header:** _None or int, list of int, default 0._ 43 | Row (0-indexed) to use for the column labels of the parsed DataFrame. If a list of integers is passed, those row 44 | positions will be combined into a MultiIndex. Use None if there is no header. 45 | 46 | ### Functions: 47 | 48 | load_data(file: Path, extra_info: Optional[Dict] = None, fs: Optional[AbstractFileSystem] = None) 49 | 50 | ## PaiPandasExcelReader 51 | 52 | PaiPandasExcelReader(concat_rows=True, row_joiner="\n", pandas_config={}) 53 | 54 | ### Parameters: 55 | 56 | **concat_rows:** _bool, default=True._ 57 | Whether to concatenate rows into one document. 58 | 59 | **row_joiner:** _str, default="\n"._ 60 | The separator used to join rows. 61 | 62 | **pandas_config:** _dict, default={}._ 63 | The configuration of pandas.read_excel. 64 | Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html for more information. 65 | Defaults to an empty dict, which means pandas will try to infer the header row, data types, etc. on its own. 66 | 67 | #### One important parameter: 68 | 69 | **header:** _None or int, list of int, default 0._ 70 | Row (0-indexed) to use for the column labels of the parsed DataFrame. If a list of integers is passed, those row 71 | positions will be combined into a MultiIndex. Use None if there is no header.
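For reference, here is a minimal usage sketch adapted from the integration's test suite (`tests/readers/test_excel_reader.py`); the input directory and the two-row header are illustrative, and `PaiDataReader` dispatches to the registered reader by file extension:

```python
from pairag.file.readers.pai.pai_data_reader import PaiDataReader, DataReaderConfig
from pairag.file.readers.pai.file_readers.pai_excel_reader import PaiPandasExcelReader

reader_config = DataReaderConfig()
directory_reader = PaiDataReader(reader_config=reader_config)

# Register an Excel reader that treats the first two rows as a MultiIndex header.
directory_reader.file_readers[".xlsx"] = PaiPandasExcelReader(
    concat_rows=reader_config.concat_csv_rows,
    pandas_config={"header": [0, 1]},
)

# Each parsed row (or one concatenated document when concat_rows=True) becomes a Document.
documents = directory_reader.load_data(file_path_or_directory="path/to/excel_files")
```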
72 | 73 | ### Functions: 74 | 75 | load_data(file: Path, extra_info: Optional[Dict] = None, fs: Optional[AbstractFileSystem] = None) 76 | only process the first sheet 77 | -------------------------------------------------------------------------------- /example_data/cn_clip/pokemon.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/cn_clip/pokemon.jpeg -------------------------------------------------------------------------------- /example_data/eval_docs_image/跑鞋推荐.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/eval_docs_image/跑鞋推荐.pdf -------------------------------------------------------------------------------- /example_data/eval_docs_image_example/multimodal_eval_dataset_zh_example.jsonl: -------------------------------------------------------------------------------- 1 | {"id":"qca_03d5b51c-0308-436d-9a2f-a8389d821001","query":{"query_text":"2023年春夏期间,哪种泛户外运动在社交平台上的讨论声量最高?","source":{"name":"manual","model":"manual"}},"contexts":[{"node_id":"a09e96d9263b02b2b12a927144575af9bf31776e84e78ce35acab5b6619c84ac","type":"TextNode","text":"在赛事和政策的双重推动下,国民运动户外参与意愿高涨,超过六成的受访者表示近一年显著增加了运动户外的频率,各类运动项目正在快速走向“全民化”。新的一年,随着巴黎奥运会、美洲杯等赛事的举办,全民运动热情将进一步被激发。对于品牌而言,这是一个难得的市场机遇,通过精准地选中和锁定与运动相关的目标人群,品牌可以有效地实现用户收割。 \n\n \n\n悦己驱动,运动边界向轻量泛户外持续延伸 \n\n国民参与运动户外活动更多来自“悦己”观念的驱动,近7成的受访者表示他们主要是为了“强身健体/享受大自然”,因此轻量级、易开展的活动项目更受广大普通受众的青睐。近三年,社交平台关于“泛户外运动”的讨论热度持续走高,更是在23年春夏期间迎来一波小高峰:细分到具体的活动项目上,垂钓讨论声量较高;露营也保持较高声量,其经历过22年的大爆发、23年的行业调整,预计24年已经进入更深精细化运营;此外城市骑行热度也在不断上升,成为当下新兴的小众活动。","metadata":{"image_url_list":["https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/2024春夏淘宝天猫运动户外行业趋势白皮书_淘宝/d4e624aceb4043839c924e33c075e388.jpeg", "https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/2024春夏淘宝天猫运动户外行业趋势白皮书_淘宝/52d1353d4577698891e7710ae12e18b1.jpeg", "https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/2024春夏淘宝天猫运动户外行业趋势白皮书_淘宝/4f77ded6421ddadd519ab9ef1601a784.jpeg"]}}],"answer":{"answer_text":"根据给定的材料,2023年春夏期间,垂钓在社交平台上的讨论声量最高。\n\n![](https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/2024春夏淘宝天猫运动户外行业趋势白皮书_淘宝/d4e624aceb4043839c924e33c075e388.jpeg)","answer_image_url_list":null,"source":{"name":"manual","model":"manual"}}} 2 | {"id":"qca_03d5b51c-0308-436d-9a2f-a8389d821002","query":{"query_text":"狄尔泰属于哪个教育学流派?","source":{"name":"manual","model":"manual"}},"contexts":[{"node_id":"a09e96d9263b02b2b12a927144575af9bf31776e84e78ce35acab5b6619c84ab","type":"TextNode","text":"","metadata":{"image_url_list":["https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/教育文档/思维导图.jpeg"]}}],"answer":{"answer_text":"狄尔泰属于文化教育学流派。\n\n![](https://pai-rag.oss-cn-hangzhou.aliyuncs.com/pairag/doc_images/教育文档/思维导图.jpeg)","answer_image_url_list":null,"source":{"name":"manual","model":"manual"}}} 3 | -------------------------------------------------------------------------------- /example_data/eval_docs_text/EasyRec.txt: -------------------------------------------------------------------------------- 1 | EasyRec是一个易于使用的推荐框架¶ 2 | EasyRec 实现了常见推荐任务中使用的最先进的机器学习模型:候选生成(匹配)、评分(排名)和多任务学习。 它通过简单的配置和超参数调整(HPO)提高了生成高性能模型的效率。 3 | 4 | EasyRec视频介绍 5 | 为什么选择 EasyRec?¶ 6 | 到处运行¶ 7 | MaxCompute / 数据科学 / DLC / 本地 8 | TF1.12-1.15 / TF2.x / PAI-TF 9 | 多样的输入数据¶ 10 | MaxCompute表 11 | HDFS 文件 12 | 操作系统文件 13 | 卡夫卡流 14 | 本地 CSV 15 | 16 | 配置简单¶ 
17 | 灵活的功能配置和简单的模型配置 18 | 高效、鲁棒的特征生成[淘宝使用] 19 | 漂亮的网络界面正在开发中 20 | 21 | 它很聪明¶ 22 | EarlyStop / 最佳检查站保护程序 23 | 超参数搜索/AutoFeatureCross 24 | 开发中:NAS、知识蒸馏、多模式 25 | 26 | 规模大、部署方便¶ 27 | 支持大规模嵌入,增量保存 28 | 许多并行策略:ParameterServer、Mirrored、MultiWorker 29 | 轻松部署到 EAS:自动扩展、轻松监控 30 | 一致性保证:训练和服务 31 | 32 | 多种模型 33 | DSSM / MIND / DropoutNet / CoMetricLearningI2I / PDN 34 | W&D / DeepFM / MultiTower / DCN / DIN / BST 35 | MMoE / ESMM / DBMTL / PLE 36 | CMBF / 联合 37 | 38 | 易于定制¶ 39 | 易于实现定制模型 40 | 无需关心数据管道 41 | 快速向量检索¶ 42 | 在分布式环境中运行向量的 knn 算法 43 | 44 | 欢迎加入【EasyRec推荐算法交流群】,钉钉群号 : 32260796 45 | -------------------------------------------------------------------------------- /example_data/eval_docs_text/PAI.txt: -------------------------------------------------------------------------------- 1 | 机器学习PAI(Platform of Artificial Intelligence)是阿里云人工智能平台,提供一站式的机器学习解决方案。本文为您介绍什么是机器学习PAI。 2 | 3 | 什么是机器学习 4 | 机器学习是指机器通过统计学算法,对大量历史数据进行学习,进而利用生成的经验模型指导业务。目前机器学习主要应用在以下场景: 5 | 营销类场景:商品推荐、用户群体画像或广告精准投放。 6 | 金融类场景:贷款发放预测、金融风险控制、股票走势预测或黄金价格预测。 7 | 社交网络服务关系挖掘场景:微博粉丝领袖分析或社交关系链分析。 8 | 文本类场景:新闻分类、关键词提取、文章摘要或文本内容分析。 9 | 非结构化数据处理场景:图片分类或图片文本内容提取。 10 | 其它各类预测场景:降雨预测或足球比赛结果预测。 11 | 机器学习包括传统机器学习和深度学习。传统机器学习分为以下几类: 12 | 有监督学习(Supervised Learning):每个样本都有对应的期望值,通过搭建模型,实现从输入特征向量到目标值的映射。例如解决回归和分类问题。 13 | 无监督学习(Unsupervised Learning):所有样本没有目标值,期望从数据本身发现一些潜在规律。例如解决聚类问题。 14 | 增强学习(Reinforcement Learning):相对比较复杂,系统和外界环境不断交互,根据外界反馈决定自身行为,达到目标最优化。例如阿尔法围棋和无人驾驶。 15 | -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/上海攻略信息.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/上海攻略信息.pdf -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/北京攻略信息.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/dataset/北京攻略信息.pdf -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_chat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_chat.jpg -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_config.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/agent_config.jpg -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/settings.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/settings.jpg -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/upload.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/figures/upload.jpg -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/settings.toml: -------------------------------------------------------------------------------- 1 | dynaconf_merge = true 2 | 3 | [rag] 4 | name = "pairag" 5 | version = "0.1.1" 6 | 7 | [rag.agent] 8 | type = "function_calling" 9 | 10 | [rag.agent.custom_config] 11 | agent_file_path = "example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant" 12 | 13 | [rag.agent.intent_detection] 14 | type = "single" 15 | 16 | [rag.agent.tool] 17 | type = "api" 18 | 19 | [rag.chat_store] 20 | type = "Local" # [Local, Aliyun-Redis] 21 | host = "Aliyun-Redis host" 22 | password = "Aliyun-Redis user:pwd" 23 | persist_path = "localdata/storage" 24 | 25 | [rag.data_reader] 26 | type = "SimpleDirectoryReader" 27 | 28 | # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace 29 | # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model 30 | # eg. 31 | # source = "HuggingFace" 32 | # model = "bge-m3" 33 | # embed_batch_size = 10 34 | [rag.embedding] 35 | source = "DashScope" 36 | embed_batch_size = 10 37 | 38 | [rag.embedding.multi_modal] 39 | source = "cnclip" 40 | 41 | [rag.index] 42 | persist_path = "localdata/storage" 43 | vector_store.type = "FAISS" 44 | 45 | # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment 46 | # eg. 
47 | # source = "PaiEas" 48 | # endpoint = "" 49 | # token = "" 50 | [rag.llm] 51 | source = "DashScope" 52 | model = "qwen-max" 53 | 54 | [rag.llm.function_calling_llm] 55 | source = "DashScope" 56 | model = "qwen2-7b-instruct" 57 | 58 | [rag.llm.multi_modal] 59 | source = "" 60 | 61 | [rag.node_enhancement] 62 | tree_depth = 3 63 | max_clusters = 52 64 | proba_threshold = 0.10 65 | 66 | [rag.node_parser] 67 | type = "Sentence" 68 | chunk_size = 500 69 | chunk_overlap = 10 70 | 71 | [rag.postprocessor] 72 | reranker_type = "simple-weighted-reranker" # [simple-weighted-reranker, model-based-reranker] 73 | reranker_model = "bge-reranker-base" # [bge-reranker-base, bge-reranker-large] 74 | keyword_weight = 0.3 75 | vector_weight = 0.7 76 | similarity_threshold = 0.5 77 | top_n = 2 78 | 79 | [rag.query_engine] 80 | type = "RetrieverQueryEngine" 81 | 82 | [rag.retriever] 83 | similarity_top_k = 3 84 | retrieval_mode = "hybrid" # [hybrid, embedding, keyword, router] 85 | query_rewrite_n = 1 # set to 1 to disable query generation 86 | 87 | [rag.synthesizer] 88 | type = "SimpleSummarize" 89 | text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n" 90 | -------------------------------------------------------------------------------- /example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/tools.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | from loguru import logger 4 | 5 | 6 | def get_place_weather(city: str) -> str: 7 | logger.info(f"[Agent] Checking realtime weather info for {city}") 8 | 9 | """Get city name and return city weather""" 10 | api_key = os.environ.get("weather_api_key") 11 | 12 | # 可以直接赋值给api_key,原始代码的config只有type类型。 13 | base_url = "http://api.openweathermap.org/data/2.5/forecast?" 14 | complete_url = f"{base_url}q={city}&appid={api_key}&lang=zh_cn&units=metric" 15 | logger.info(f"Requesting {complete_url}...") 16 | response = requests.get(complete_url, timeout=5) 17 | weather_data = response.json() 18 | 19 | if weather_data["cod"] != "200": 20 | logger.error( 21 | f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}" 22 | ) 23 | return f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}" 24 | 25 | element = weather_data["list"][0] 26 | 27 | return f""" 28 | {city}的天气: 29 | 时间: {element['dt_txt']} 30 | 温度: {element['main']['temp']} °C 31 | 天气描述: {element['weather'][0]['description']} 32 | """ 33 | -------------------------------------------------------------------------------- /example_data/pai_document.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/example_data/pai_document.pdf -------------------------------------------------------------------------------- /integrations/pairag-file/README.md: -------------------------------------------------------------------------------- 1 | # pairag.file 2 | 3 | This module contains the readers for the different file types and node parsers to chunk parsed files. 
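A minimal usage sketch, mirroring `tests/nodeparsers/test_markdown_mode_parser.py` (the input directory below is a placeholder):

```python
from pairag.file.readers.pai.pai_data_reader import PaiDataReader, DataReaderConfig
from pairag.file.nodeparsers.pai.pai_markdown_parser import MarkdownNodeParser

# Read a directory of supported files into llama-index Documents.
reader = PaiDataReader(reader_config=DataReaderConfig())
documents = reader.load_data(file_path_or_directory="path/to/your/docs")

# Chunk the parsed (markdown-style) documents into nodes.
parser = MarkdownNodeParser(enable_multimodal=False)
nodes = []
for doc in documents:
    nodes.extend(parser.get_nodes_from_documents([doc]))
```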
4 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/nodeparsers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/nodeparsers/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/nodeparsers/pai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/nodeparsers/pai/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/nodeparsers/pai/constants.py: -------------------------------------------------------------------------------- 1 | # paragraph separator for splitter 2 | DEFAULT_NODE_PARSER_TYPE = "Sentence" 3 | DEFAULT_PARAGRAPH_SEP = "\n\n" 4 | DEFAULT_SENTENCE_CHUNK_OVERLAP = 200 5 | DEFAULT_SENTENCE_WINDOW_SIZE = 3 6 | DEFAULT_BREAKPOINT = 95 7 | DEFAULT_BUFFER_SIZE = 1 8 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/nodeparsers/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/nodeparsers/utils/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/pai/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/constants.py: -------------------------------------------------------------------------------- 1 | ACCEPTABLE_DOC_TYPES = set( 2 | [ 3 | ".html", 4 | ".htm", 5 | ".txt", 6 | ".docx", 7 | ".pdf", 8 | ".pptx", 9 | ".md", 10 | ".xls", 11 | ".jsonl", 12 | ".csv", 13 | ".xlsx", 14 | ".jpg", 15 | ".jpeg", 16 | ".png", 17 | ] 18 | ) 19 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/file_readers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/pai/file_readers/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/file_readers/pai_image_reader.py: -------------------------------------------------------------------------------- 1 | """Image file reader. 2 | 3 | Uploads image files to the configured image store and wraps them as ImageDocument objects. 4 | 5 | """ 6 | 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional 9 | from fsspec import AbstractFileSystem 10 | import os 11 | from llama_index.core.readers.base import BaseReader 12 | from llama_index.core.schema import Document, ImageDocument 13 | 14 | from pairag.file.readers.pai.utils.image_utils import image_from_url 15 | from pairag.file.store.pai_image_store import PaiImageStore 16 | 17 | 18 | class PaiImageReader(BaseReader): 19 | """Image reader. 20 | 21 | Args: 22 | image_store (PaiImageStore): store used to upload images and build accessible URLs. 23 | 24 | """ 25 | 26 | def __init__(self, image_store: PaiImageStore, *args: Any, **kwargs: Any) -> None: 27 | """Init params.""" 28 | super().__init__(*args, **kwargs) 29 | self.image_store = image_store 30 | 31 | def load_data( 32 | self, 33 | file_path: Path, 34 | extra_info: Optional[Dict] = None, 35 | fs: Optional[AbstractFileSystem] = None, 36 | ) -> List[Document]: 37 | if self.image_store is None: 38 | raise Exception( 39 | f"Oss config must be provided for image processing for file {file_path}." 40 | ) 41 | 42 | file_name = os.path.basename(file_path) 43 | image_url = self.image_store.upload_image( 44 | image_from_url(file_path), doc_name="image_docs" 45 | ) 46 | if extra_info is None: 47 | extra_info = {} 48 | extra_info["file_path"] = str(file_path) 49 | extra_info["file_name"] = file_name 50 | extra_info["image_url"] = image_url 51 | image_doc = ImageDocument(image_url=image_url, extra_info=extra_info) 52 | 53 | docs = [image_doc] 54 | return docs 55 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/file_readers/pai_jsonl_reader.py: -------------------------------------------------------------------------------- 1 | """JSONL file reader. 2 | 3 | Reads JSON Lines (.jsonl) files, producing one Document per line.
4 | 5 | """ 6 | 7 | import os 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Optional 10 | from fsspec import AbstractFileSystem 11 | from llama_index.core.readers.base import BaseReader 12 | from llama_index.core.schema import Document 13 | 14 | 15 | class PaiJsonLReader(BaseReader): 16 | """JsonL reader.""" 17 | 18 | def __init__(self, *args: Any, **kwargs: Any) -> None: 19 | """Init params.""" 20 | super().__init__(*args, **kwargs) 21 | 22 | def load_data( 23 | self, 24 | file_path: Path, 25 | extra_info: Optional[Dict] = None, 26 | fs: Optional[AbstractFileSystem] = None, 27 | ) -> List[Document]: 28 | with open(file_path, "r", encoding="utf-8") as file: 29 | json_lines = [line.strip() for line in file.readlines()] 30 | 31 | file_name = os.path.basename(file_path) 32 | extra_info = extra_info or {} 33 | extra_info["file_path"] = str(file_path) 34 | extra_info["file_name"] = file_name 35 | 36 | docs = [] 37 | for i, text in enumerate(json_lines): 38 | extra_info["row_number"] = i + 1 39 | docs.append(Document(text=text, metadata=extra_info)) 40 | return docs 41 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/readers/pai/utils/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/utils/cuda_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from loguru import logger 4 | 5 | 6 | USE_CUDA = os.environ.get("USE_CUDA", "false") 7 | 8 | 9 | def should_use_cuda(): 10 | if not torch.cuda.is_available(): 11 | return False 12 | 13 | if USE_CUDA.lower() == "true" or USE_CUDA == "1": 14 | return True 15 | else: 16 | return False 17 | 18 | 19 | def infer_cuda_device() -> str: 20 | if should_use_cuda(): 21 | logger.info("Using cuda device.") 22 | return "cuda" 23 | else: 24 | logger.info( 25 | "Will not use CUDA device acceleration. If you want to use cuda, please set the environment variable USE_CUDA=1." 
26 | ) 27 | return "cpu" 28 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | from io import BytesIO 4 | from loguru import logger 5 | import requests 6 | from urllib.parse import urlparse 7 | 8 | 9 | def is_remote_url(url_or_path: str | Path) -> bool: 10 | result = urlparse(str(url_or_path)) 11 | is_remote = result.scheme in ("http", "https", "ftp", "s3", "gs") 12 | return is_remote 13 | 14 | 15 | def image_from_bytes( 16 | image_blob: bytes, 17 | image_filename_or_extension: str, 18 | ): 19 | if image_filename_or_extension.lower().endswith( 20 | ".emf" 21 | ) or image_filename_or_extension.lower().endswith(".wmf"): 22 | # 暂时不处理Windows图元文件 23 | logger.warning( 24 | f"Skip processing EMF or WMF image: {image_filename_or_extension}" 25 | ) 26 | return None 27 | 28 | return Image.open(BytesIO(image_blob)) 29 | 30 | 31 | def image_from_url(image_url: str): 32 | if is_remote_url(image_url): 33 | try: 34 | response = requests.get(image_url) 35 | response.raise_for_status() # 检查请求是否成功 36 | 37 | # 将二进制数据转换为图像对象 38 | image = Image.open(BytesIO(response.content)) 39 | return image 40 | except Exception as ex: 41 | logger.warning( 42 | f"Failed to download image from URL: {image_url}. Error: {ex}" 43 | ) 44 | return None 45 | else: 46 | try: 47 | image = Image.open(image_url) 48 | return image 49 | except Exception as ex: 50 | logger.warning(f"Failed to open image from file: {image_url}. Error: {ex}") 51 | return None 52 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/pai/utils/magic-pdf.template.json: -------------------------------------------------------------------------------- 1 | { 2 | "bucket_info": { 3 | "bucket-name-1": ["ak", "sk", "endpoint"], 4 | "bucket-name-2": ["ak", "sk", "endpoint"] 5 | }, 6 | "models-dir": "model_repository/PDF-Extract-Kit-1.0/models", 7 | "layoutreader-model-dir": "model_repository/PDF-Extract-Kit-1.0/models/layoutreader", 8 | "device-mode": "cpu", 9 | "layout-config": { 10 | "model": "doclayout_yolo" 11 | }, 12 | "formula-config": { 13 | "mfd_model": "yolo_v8_mfd", 14 | "mfr_model": "unimernet_small", 15 | "enable": true 16 | }, 17 | "table-config": { 18 | "model": "rapid_table", 19 | "sub_model": "slanet_plus", 20 | "enable": true, 21 | "max_time": 400 22 | }, 23 | "latex-delimiter-config": { 24 | "display": { 25 | "left": "$$", 26 | "right": "$$" 27 | }, 28 | "inline": { 29 | "left": "$", 30 | "right": "$" 31 | } 32 | }, 33 | "config_version": "1.2.1" 34 | } 35 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/readers/readme.md: -------------------------------------------------------------------------------- 1 | # Readers integrations 2 | -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/store/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/pairag/file/store/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/pairag/file/store/pai_image_store.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from io import BytesIO 4 | import math 5 | from PIL.PngImagePlugin import PngImageFile 6 | from loguru import logger 7 | 8 | from pairag.file.store.oss_store import PaiOssStore 9 | 10 | IMAGE_MAX_PIXELS = 512 * 512 11 | 12 | 13 | class PaiImageStore: 14 | def __init__( 15 | self, oss_store: PaiOssStore = None, save_prefix="pai_oss_images/" 16 | ) -> None: 17 | self.oss_store = oss_store 18 | self.save_prefix = save_prefix 19 | 20 | def upload_image(self, image: PngImageFile, doc_name: str): 21 | if image is None: 22 | return None 23 | 24 | if self.oss_store is None: 25 | logger.warning( 26 | "oss_store is not properly configured, skipping image upload." 27 | ) 28 | return None 29 | 30 | try: 31 | if image.mode != "RGB": 32 | image = image.convert("RGB") 33 | if image.width <= 50 or image.height <= 50: 34 | logger.warning(f"Skipping small image {image}") 35 | return None 36 | 37 | current_pixels = image.width * image.height 38 | 39 | # 检查像素总数是否超过限制 40 | if current_pixels > IMAGE_MAX_PIXELS: 41 | # 计算缩放比例以适应最大像素数 42 | scale = math.sqrt(IMAGE_MAX_PIXELS / current_pixels) 43 | new_width = int(image.width * scale) 44 | new_height = int(image.height * scale) 45 | 46 | # 调整图片大小 47 | image = image.resize((new_width, new_height), Image.LANCZOS) 48 | 49 | image_stream = BytesIO() 50 | image.save(image_stream, format="jpeg") 51 | 52 | image_stream.seek(0) 53 | data = image_stream.getvalue() 54 | 55 | image_url = self.oss_store.put_object_if_not_exists( 56 | data=data, 57 | file_ext=".jpeg", 58 | headers={ 59 | "x-oss-object-acl": "public-read" 60 | }, # set public read to make image accessible 61 | path_prefix=os.path.join(self.save_prefix, doc_name.strip()), 62 | ) 63 | logger.info( 64 | f"Saved image {image_url} from {doc_name} with width={image.width}, height={image.height}." 65 | ) 66 | return image_url 67 | except Exception as e: 68 | logger.warning(f"处理图片失败 '{image}': {e}") 69 | -------------------------------------------------------------------------------- /integrations/pairag-file/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "pairag-file" 7 | version = "1.2.2" 8 | description = "PAI-RAG file processing library." 
9 | authors = [] 10 | readme = "README.md" 11 | packages = [ 12 | {include = "pairag/file"} 13 | ] 14 | 15 | [tool.poetry.dependencies] 16 | python = ">=3.11.0,<3.12" 17 | llama-index-core = "^0.12.27" 18 | llama-index-readers-file = "^0.4.3" 19 | docx2txt = "^0.8" 20 | pydantic = "^2.7.0" 21 | oss2 = "^2.18.5" 22 | torch = "2.2.2" 23 | transformers = "4.51.3" 24 | openpyxl = "^3.1.2" 25 | xlrd = "^2.0.1" 26 | markdown = "^3.6" 27 | chardet = "^5.2.0" 28 | modelscope = "^1.16.0" 29 | pre-commit = "^3.8.0" 30 | peft = "^0.12.0" 31 | python-pptx = "^1.0.2" 32 | aspose-slides = "^24.10.0" 33 | datasketch = "^1.6.5" 34 | anyio = "^4.6.2.post1" 35 | mistletoe = "^1.4.0" 36 | html2text = "^2024.2.26" 37 | rapidfuzz = "^3.13.0" 38 | python-docx = "^1.1.2" 39 | numpy = "1.26.4" 40 | loguru = "^0.7.3" 41 | llama-index-multi-modal-llms-openai = "^0.5.0" 42 | pytest = "^8.3.5" 43 | magic-pdf = {version = "1.3.10", extras = ["full"]} 44 | 45 | [tool.pytest.ini_options] 46 | asyncio_mode = "auto" 47 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/__init__.py -------------------------------------------------------------------------------- /integrations/pairag-file/tests/nodeparsers/test_markdown_mode_parser.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def test_markdown_parser(): 5 | from pairag.file.nodeparsers.pai.pai_markdown_parser import ( 6 | MarkdownNodeParser, 7 | ) 8 | from pairag.file.readers.pai.pai_data_reader import ( 9 | PaiDataReader, 10 | DataReaderConfig, 11 | ) 12 | 13 | reader_config = DataReaderConfig() 14 | directory_reader = PaiDataReader(reader_config=reader_config) 15 | 16 | input_dir = "tests/testdata/md_data" 17 | documents = directory_reader.load_data(file_path_or_directory=input_dir) 18 | md_node_parser = MarkdownNodeParser(enable_multimodal=False) 19 | splitted_nodes = [] 20 | for doc_node in documents: 21 | splitted_nodes.extend(md_node_parser.get_nodes_from_documents([doc_node])) 22 | 23 | text_list = [node.text for node in splitted_nodes] 24 | 25 | with open( 26 | "tests/testdata/json_data/pai_document.json", "r", encoding="utf-8" 27 | ) as file: 28 | chunk_text = json.load(file) 29 | 30 | assert text_list == chunk_text 31 | assert len(splitted_nodes) == 10 32 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/openailike_multimodal.py: -------------------------------------------------------------------------------- 1 | # dashscope multimodal llm的achat接口有问题,这里使用openai接口 2 | 3 | from llama_index.multi_modal_llms.openai import OpenAIMultiModal 4 | from typing import Dict, Any 5 | 6 | 7 | class OpenAIAlikeMultiModal(OpenAIMultiModal): 8 | def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]: 9 | base_kwargs = {"model": self.model, "temperature": self.temperature, **kwargs} 10 | if self.max_new_tokens is not None: 11 | # If max_tokens is None, don't include in the payload: 12 | # https://platform.openai.com/docs/api-reference/chat 13 | # https://platform.openai.com/docs/api-reference/completions 14 | base_kwargs["max_tokens"] = self.max_new_tokens 15 | return {**base_kwargs, **self.additional_kwargs} 16 | 
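# Usage sketch (mirrors tests/test_image_caption_tool.py): the wrapper is constructed
# like the stock OpenAIMultiModal, but pointed at an OpenAI-compatible endpoint.
# The endpoint and model below are the ones used in the tests and are illustrative only:
#
#   import os
#   llm = OpenAIAlikeMultiModal(
#       model="qwen-vl-max",
#       api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
#       api_key=os.environ.get("DASHSCOPE_API_KEY"),
#   )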
-------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_csv_reader.py: -------------------------------------------------------------------------------- 1 | from pairag.file.readers.pai.pai_data_reader import DataReaderConfig, PaiDataReader 2 | from pairag.file.readers.pai.file_readers.pai_csv_reader import PaiPandasCSVReader 3 | 4 | 5 | def test_pandas_csv_reader(): 6 | directory_reader = PaiDataReader(reader_config=DataReaderConfig()) 7 | input_dir = "tests/testdata/csv_data" 8 | directory_reader.file_readers[".csv"] = PaiPandasCSVReader( 9 | concat_rows=False, 10 | pandas_config={"header": [0, 1]}, 11 | ) 12 | documents = directory_reader.load_data(file_path_or_directory=input_dir) 13 | for doc in documents: 14 | print(doc) 15 | assert len(documents) == 897 16 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_docx_reader.py: -------------------------------------------------------------------------------- 1 | from pairag.file.readers.pai.pai_data_reader import PaiDataReader, DataReaderConfig 2 | from pairag.file.readers.pai.file_readers.pai_docx_reader import PaiDocxReader 3 | 4 | 5 | def test_pai_docx_reader(): 6 | reader_config = DataReaderConfig() 7 | directory_reader = PaiDataReader(reader_config=reader_config) 8 | input_dir = "tests/testdata/docx_data" 9 | 10 | directory_reader.file_readers[".docx"] = PaiDocxReader() 11 | 12 | documents = directory_reader.load_data(file_path_or_directory=input_dir) 13 | assert "步骤一:部署RAG服务" in str(documents[0].text) 14 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_excel_reader.py: -------------------------------------------------------------------------------- 1 | def test_pandas_excel_reader(): 2 | from pairag.file.readers.pai.pai_data_reader import ( 3 | PaiDataReader, 4 | DataReaderConfig, 5 | ) 6 | from pairag.file.readers.pai.file_readers.pai_excel_reader import ( 7 | PaiPandasExcelReader, 8 | ) 9 | 10 | reader_config = DataReaderConfig() 11 | directory_reader = PaiDataReader(reader_config=reader_config) 12 | input_dir = "tests/testdata/excel_data" 13 | directory_reader.file_readers[".xlsx"] = PaiPandasExcelReader( 14 | concat_rows=reader_config.concat_csv_rows, 15 | pandas_config={"header": [0, 1]}, 16 | ) 17 | directory_reader.file_readers[".xls"] = PaiPandasExcelReader( 18 | concat_rows=reader_config.concat_csv_rows, 19 | pandas_config={"header": [0, 1]}, 20 | ) 21 | 22 | documents = directory_reader.load_data(file_path_or_directory=input_dir) 23 | 24 | for doc in documents: 25 | print(doc) 26 | assert len(documents) == 7 27 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_html_reader.py: -------------------------------------------------------------------------------- 1 | from pairag.file.readers.pai.pai_data_reader import PaiDataReader, DataReaderConfig 2 | from pairag.file.readers.pai.file_readers.pai_html_reader import PaiHtmlReader 3 | 4 | 5 | def test_pai_html_reader(): 6 | directory_reader = PaiDataReader(reader_config=DataReaderConfig()) 7 | input_dir = "tests/testdata/html_data" 8 | 9 | directory_reader.file_readers[".html"] = PaiHtmlReader() 10 | 11 | documents = directory_reader.load_data(file_path_or_directory=input_dir) 12 | assert len(documents) == 5 13 | 
-------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_image_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import requests 4 | from pairag.file.store.oss_store import PaiOssStore 5 | from pairag.file.store.pai_image_store import PaiImageStore 6 | from pairag.file.readers.pai.file_readers.pai_image_reader import PaiImageReader 7 | 8 | if not os.environ.get("OSS_ACCESS_KEY_ID") or not os.environ.get( 9 | "OSS_ACCESS_KEY_SECRET" 10 | ): 11 | pytest.skip( 12 | reason="OSS_ACCESS_KEY_ID or OSS_ACCESS_KEY_SECRET not set", 13 | allow_module_level=True, 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def image_store(): 19 | oss_store = PaiOssStore( 20 | bucket_name="feiyue-test", endpoint="oss-cn-hangzhou.aliyuncs.com" 21 | ) 22 | return PaiImageStore(oss_store=oss_store) 23 | 24 | 25 | def test_image_reader(image_store): 26 | image_reader = PaiImageReader(image_store=image_store) 27 | test_image_path = "tests/testdata/image_data/11.jpg" 28 | image_doc = image_reader.load_data(file_path=test_image_path)[0] 29 | image_url = image_doc.metadata.get("image_url") 30 | assert image_url is not None, "image url should not be None." 31 | 32 | image_response = requests.get(image_url) 33 | assert image_response.status_code == 200, "image url should be valid." 34 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_jsonl_reader.py: -------------------------------------------------------------------------------- 1 | def test_jsonl_reader(): 2 | from pairag.file.readers.pai.pai_data_reader import ( 3 | PaiDataReader, 4 | DataReaderConfig, 5 | ) 6 | from pairag.file.readers.pai.file_readers.pai_jsonl_reader import PaiJsonLReader 7 | 8 | reader_config = DataReaderConfig() 9 | directory_reader = PaiDataReader(reader_config=reader_config) 10 | 11 | input_dir = "tests/testdata/jsonl_data" 12 | directory_reader.file_readers[".jsonl"] = PaiJsonLReader() 13 | 14 | documents = directory_reader.load_data(file_path_or_directory=input_dir) 15 | for doc in documents: 16 | print(doc) 17 | assert len(documents) == 27 18 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_pdf_reader.py: -------------------------------------------------------------------------------- 1 | def test_pai_pdf_reader(): 2 | from pairag.file.readers.pai.pai_data_reader import ( 3 | PaiDataReader, 4 | DataReaderConfig, 5 | ) 6 | from pairag.file.readers.pai.file_readers.pai_pdf_reader import PaiPDFReader 7 | 8 | reader_config = DataReaderConfig() 9 | directory_reader = PaiDataReader(reader_config=reader_config) 10 | 11 | input_dir = "tests/testdata/pdf_data" 12 | directory_reader.file_readers[".pdf"] = PaiPDFReader() 13 | 14 | documents = directory_reader.load_data(file_path_or_directory=input_dir) 15 | assert len(documents) > 0 16 | 17 | 18 | def test_is_horizontal_table(): 19 | from pairag.file.readers.pai.utils.markdown_utils import is_horizontal_table 20 | 21 | # example data 22 | horizontal_table_1 = [ 23 | ["Name", "Age", "City"], 24 | ["Alice", 30, "New York"], 25 | ["Bob", 25, "San Francisco"], 26 | ] 27 | 28 | horizontal_table_2 = [ 29 | ["Name", "Age", "discount"], 30 | ["Alice", 30, 0.3], 31 | ["Bob", 25, 0.4], 32 | ] 33 | 34 | horizontal_table_3 = [ 35 | ["Age", "discount", "amount"], 36 | [30, 0.3, 3], 37 | [25, 0.4, 7], 38 | [34, 0.2, 9], 39 | ] 
40 | 41 | vertical_table = [ 42 | ["Field", "Record1", "Record2"], 43 | ["Name", "Alice", "Bob"], 44 | ["Age", 30, 25], 45 | ["City", "New York", "San Francisco"], 46 | ] 47 | assert is_horizontal_table(horizontal_table_1) 48 | assert is_horizontal_table(horizontal_table_2) 49 | assert is_horizontal_table(horizontal_table_3) 50 | assert not is_horizontal_table(vertical_table) 51 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/readers/test_post_process_multi_level_headings.py: -------------------------------------------------------------------------------- 1 | from pairag.file.readers.pai.file_readers.pai_pdf_reader import PaiPDFReader 2 | 3 | 4 | def test_post_process_multi_level_headings(): 5 | title_list = [ 6 | ("title_1", 6), 7 | ("title_2", 10), 8 | ("title_3", 8), 9 | ("title_4", 7), 10 | ("title_5", 14), 11 | ] 12 | 13 | pdf_process = PaiPDFReader() 14 | new_title_list = pdf_process.post_process_multi_level_headings(title_list, 0, 0) 15 | assert new_title_list == [ 16 | "### title_1", 17 | "## title_2", 18 | "### title_3", 19 | "### title_4", 20 | "# title_5", 21 | ] 22 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/test_image_caption_tool.py: -------------------------------------------------------------------------------- 1 | from pairag.file.nodeparsers.pai.image_caption_tool import ImageCaptionTool 2 | from tests.openailike_multimodal import OpenAIAlikeMultiModal 3 | import os 4 | import pytest 5 | 6 | 7 | if os.environ.get("DASHSCOPE_API_KEY") is None: 8 | pytest.skip(reason="DASHSCOPE_API_KEY not set", allow_module_level=True) 9 | 10 | 11 | @pytest.fixture 12 | def multimodal_llm() -> OpenAIAlikeMultiModal: 13 | return OpenAIAlikeMultiModal( 14 | model="qwen-vl-max", 15 | api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", 16 | api_key=os.environ.get("DASHSCOPE_API_KEY"), 17 | ) 18 | 19 | 20 | def test_image_caption_tool(multimodal_llm: OpenAIAlikeMultiModal): 21 | test_image_url = "https://feiyue-test.oss-cn-hangzhou.aliyuncs.com/pai_oss_images/EAS%E8%AE%A1%E8%B4%B9%E8%AF%B4%E6%98%8E/8ea0d46984ba9166e96f8afb58132dc6.jpeg" 22 | 23 | image_caption_tool = ImageCaptionTool(multimodal_llm=multimodal_llm) 24 | caption = image_caption_tool.extract_url(test_image_url) 25 | print(caption) 26 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/test_image_utils.py: -------------------------------------------------------------------------------- 1 | from pairag.file.readers.pai.utils.image_utils import is_remote_url 2 | 3 | 4 | def test_is_remote_url(): 5 | assert is_remote_url("http://www.baidu.com") is True 6 | assert is_remote_url("https://www.tencent.com/1.jpg") is True 7 | assert is_remote_url("http://123123213123") is True 8 | assert is_remote_url("https://abcabc") is True 9 | assert is_remote_url("/a/b/c/1.jpg") is False 10 | assert is_remote_url("./1.txt") is False 11 | assert is_remote_url("../..") is False 12 | assert is_remote_url(".") is False 13 | assert is_remote_url("") is False 14 | assert is_remote_url("/") is False 15 | assert is_remote_url("data/") is False 16 | -------------------------------------------------------------------------------- /integrations/pairag-file/tests/testdata/csv_data/30天留存率_csv_two_header.csv: -------------------------------------------------------------------------------- 1 | time,metric 2 | date,rate 3 | 20240101,0.9375 4 | 20240201,0.9744 
5 | 20240301,0.9767 6 | 20240401,0.9375 7 | 20240501,0.9091 8 | 20240601,0.9474 9 | 20240701,0.9667 -------------------------------------------------------------------------------- /integrations/pairag-file/tests/testdata/db_data/pets.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/db_data/pets.db -------------------------------------------------------------------------------- /integrations/pairag-file/tests/testdata/db_data/pets.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/db_data/pets.sqlite -------------------------------------------------------------------------------- /integrations/pairag-file/tests/testdata/docx_data/大模型RAG对话系统.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/docx_data/大模型RAG对话系统.docx -------------------------------------------------------------------------------- /integrations/pairag-file/tests/testdata/excel_data/30天留存率_excel_two_header.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/excel_data/30天留存率_excel_two_header.xlsx -------------------------------------------------------------------------------- /integrations/pairag-file/tests/testdata/image_data/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/image_data/11.jpg -------------------------------------------------------------------------------- /integrations/pairag-file/tests/testdata/jsonl_data/price.jsonl: -------------------------------------------------------------------------------- 1 | {"question": "335块", "answer": "订单报价"} 2 | {"question": "435元", "answer": "订单报价"} 3 | {"question": "435块", "answer": "订单报价"} 4 | {"question": "535元", "answer": "订单报价"} 5 | {"question": "535块", "answer": "订单报价"} 6 | {"question": "635元", "answer": "订单报价"} 7 | {"question": "635块", "answer": "订单报价"} 8 | {"question": "735元", "answer": "订单报价"} 9 | {"question": "735块", "answer": "订单报价"} 10 | {"question": "835元", "answer": "订单报价"} 11 | {"question": "835块", "answer": "订单报价"} 12 | {"question": "935元", "answer": "订单报价"} 13 | {"question": "935块", "answer": "订单报价"} 14 | {"question": "1357元", "answer": "订单报价"} 15 | {"question": "1357块", "answer": "订单报价"} 16 | {"question": "1100元", "answer": "订单报价"} 17 | {"question": "1179块", "answer": "订单报价"} 18 | {"question": "1279元", "answer": "订单报价"} 19 | {"question": "1279块", "answer": "订单报价"} 20 | {"question": "1379元", "answer": "订单报价"} 21 | {"question": "1379块", "answer": "订单报价"} 22 | {"question": "1479元", "answer": "订单报价"} 23 | {"question": "1479块", "answer": "订单报价"} 24 | {"question": "1579元", "answer": "订单报价"} 25 | {"question": "1579块", "answer": "订单报价"} 26 | {"question": "1679元", "answer": "订单报价"} 27 | {"question": "1679块", "answer": "订单报价"} -------------------------------------------------------------------------------- 
/integrations/pairag-file/tests/testdata/pdf_data/pai_document.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/integrations/pairag-file/tests/testdata/pdf_data/pai_document.pdf -------------------------------------------------------------------------------- /scripts/gunicorn.conf.py: -------------------------------------------------------------------------------- 1 | # gunicorn.conf.py 2 | bind = "0.0.0.0:8680" 3 | workers = 1 4 | worker_class = "uvicorn.workers.UvicornWorker" 5 | timeout = 600 6 | -------------------------------------------------------------------------------- /scripts/load_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | 5 | # 获取脚本所在的目录 6 | SCRIPT_DIR=$(dirname "$0") 7 | # 切换到脚本所在目录的上级目录 8 | cd "$SCRIPT_DIR/.." 9 | pwd 10 | 11 | python src/pairag/data_ingestion/main.py $* 12 | -------------------------------------------------------------------------------- /scripts/load_data_by_step.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | 5 | # 修改为你需要的目录 6 | INPUT_PATH="/testdata/testdata/small_txt" 7 | OUTPUT_PATH="/testdata/testdata/output/small0520" 8 | 9 | 10 | # 获取脚本所在的目录 11 | SCRIPT_DIR=$(dirname "$0") 12 | # 切换到脚本所在目录的上级目录 13 | cd "$SCRIPT_DIR/.." 14 | pwd 15 | 16 | echo "reading data source..." 17 | python src/pairag/data_ingestion/main.py data-source \ 18 | --enable-delta \ 19 | --input-path $INPUT_PATH \ 20 | --output-path $OUTPUT_PATH 21 | 22 | echo "parsing files..." 23 | python src/pairag/data_ingestion/main.py parse \ 24 | --input-path $OUTPUT_PATH \ 25 | --output-path $OUTPUT_PATH 26 | 27 | 28 | echo "splitting into chunks..." 29 | python src/pairag/data_ingestion/main.py split \ 30 | --input-path $OUTPUT_PATH \ 31 | --output-path $OUTPUT_PATH \ 32 | 33 | echo "embedding data..." 34 | python src/pairag/data_ingestion/main.py embed \ 35 | --input-path $OUTPUT_PATH \ 36 | --output-path $OUTPUT_PATH \ 37 | --num-gpus 0.5 \ 38 | --num-cpus 6 \ 39 | --memory 16 \ 40 | --concurrency 2 41 | 42 | 43 | echo "writing to data sink..." 
44 | python src/pairag/data_ingestion/main.py data-sink \ 45 | --input-path $OUTPUT_PATH \ 46 | --output-path $OUTPUT_PATH \ 47 | -------------------------------------------------------------------------------- /scripts/locust.py: -------------------------------------------------------------------------------- 1 | from locust import HttpUser, task, between 2 | from random import randint 3 | 4 | auth_header = {"Authorization": ""} 5 | 6 | sample_queries = [ 7 | "找一部关于下雪的电影。", 8 | "找一部适合下雨天看的电影。", 9 | "心情不好的时候看什么电影?", 10 | "无聊的时候想看什么电影", 11 | "压力很大的时候看的电影", 12 | "好看的中文警匪片", 13 | "校园爱情电影", 14 | "金庸小说改编的古装武打剧", 15 | "好看的仙侠剧", 16 | "搞笑的电影", 17 | "评分高的的动画片", 18 | ] 19 | 20 | 21 | class SimpleRagUser(HttpUser): 22 | wait_time = between(0, 1) 23 | 24 | @task 25 | def qa(self): 26 | q_i = randint(0, len(sample_queries) - 1) 27 | query = sample_queries[q_i] 28 | 29 | _ = self.client.post( 30 | "/service/query", headers=auth_header, json={"question": query} 31 | ) 32 | # sprint(response.content.decode("utf-8")) 33 | 34 | 35 | class SimpleRetrievalUser(HttpUser): 36 | wait_time = between(0, 1) 37 | 38 | @task 39 | def qa(self): 40 | q_i = randint(0, len(sample_queries) - 1) 41 | query = sample_queries[q_i] 42 | 43 | _ = self.client.post( 44 | "/service/query/retrieval", headers=auth_header, json={"question": query} 45 | ) 46 | # print(response.content.decode("utf-8")) 47 | -------------------------------------------------------------------------------- /scripts/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | helpFunction() 4 | { 5 | echo "" 6 | echo "Usage: $0 -w workers -p port" 7 | echo -e "\t-w number of workers, default 1" 8 | echo -e "\t-p port, default 8680" 9 | exit 1 # Exit script after printing help 10 | } 11 | 12 | while getopts ":w:p:" opt 13 | do 14 | case "$opt" in 15 | w ) workers="$OPTARG" ;; 16 | p ) port="$OPTARG" ;; 17 | ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent 18 | esac 19 | done 20 | 21 | workers="${workers:-1}" 22 | port="${port:-8680}" 23 | 24 | echo "Starting gunicorn with $workers workers on port $port..." 
25 | 26 | gunicorn -w $workers -b "0.0.0.0:${port}" -c scripts/gunicorn.conf.py src.pairag.app.app:app --timeout 600 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="pairag.file", # 包名 5 | version="0.1.0", 6 | author="Your Name", 7 | author_email="you@example.com", 8 | description="PAI-RAG file utilities.", 9 | long_description=open("README.md").read(), 10 | long_description_content_type="text/markdown", 11 | url="https://github.com/aigc-apps/PAI-RAG", 12 | packages=find_packages(where="src/pairag/file"), # 从 src/ 下查找包 13 | package_dir={"": "src/pairag"}, # 指定源码根目录 14 | classifiers=[ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ], 19 | python_requires=">=3.11", 20 | install_requires=[ 21 | "requests>=2.26.0", 22 | "click>=8.0", 23 | "llama-index-core==0.12.12", 24 | "llama-index-readers-file>=0.4.3", 25 | "docx2txt>=0.8", 26 | "pydantic>=2.7.0", 27 | "oss2>=2.18.5", 28 | "torch==2.2.2", 29 | "transformers==4.51.3", 30 | "openpyxl>=3.1.2", 31 | "xlrd>=2.0.1", 32 | "markdown>=3.6", 33 | "chardet>=5.2.0", 34 | "peft>=0.12.0", 35 | "python-pptx>=1.0.2", 36 | "aspose-slides>=24.10.0", 37 | "datasketch>=1.6.5", 38 | "mistletoe>=1.4.0", 39 | "html2text>=2024.2.26", 40 | "python-docx>=1.1.2", 41 | "numpy==1.26.4", 42 | "loguru>=0.7.3", 43 | "magic-pdf[full]==1.3.10", 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /src/pairag/api/chat_completions.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from fastapi.responses import StreamingResponse 3 | from pairag.chat.models import ChatCompletionRequest 4 | from pairag.core.chat_service import chat_service 5 | 6 | router_openai = APIRouter() 7 | 8 | 9 | @router_openai.get("/models") 10 | async def get_models(): 11 | return { 12 | "data": [ 13 | { 14 | "id": "default", 15 | "object": "model", 16 | "created": 1739298766, 17 | "owned_by": "pai", 18 | "permission": [], 19 | } 20 | ] 21 | } 22 | 23 | 24 | @router_openai.post("/chat/completions") 25 | async def chat_completions(request: ChatCompletionRequest): 26 | if not request.stream: 27 | response = await chat_service.achat(request) 28 | return response 29 | else: 30 | response = await chat_service.astream_chat(request) 31 | return StreamingResponse( 32 | response, 33 | media_type="text/event-stream", 34 | ) 35 | -------------------------------------------------------------------------------- /src/pairag/api/exception_handler.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request 2 | from fastapi.responses import JSONResponse 3 | from pairag.core.models.errors import UserInputError, ServiceError 4 | 5 | 6 | async def _user_input_exception_handler(request: Request, exception: UserInputError): 7 | return JSONResponse( 8 | status_code=400, 9 | content={"message": f"Failed to process request input: {exception.msg}"}, 10 | ) 11 | 12 | 13 | async def _service_exception_handler(request: Request, exception: ServiceError): 14 | return JSONResponse(status_code=500, content={"message": f"Oops, {exception.msg}"}) 15 | 16 | 17 | def add_exception_handler(app: FastAPI): 18 | app.add_exception_handler(UserInputError, 
_user_input_exception_handler) 19 | app.add_exception_handler(ServiceError, _service_exception_handler) 20 | -------------------------------------------------------------------------------- /src/pairag/api/middleware.py: -------------------------------------------------------------------------------- 1 | from fastapi.middleware.cors import CORSMiddleware 2 | from fastapi import FastAPI, Request 3 | from starlette.middleware.base import BaseHTTPMiddleware 4 | from asgi_correlation_id import CorrelationIdMiddleware 5 | import time 6 | from loguru import logger 7 | 8 | 9 | class CustomMiddleWare(BaseHTTPMiddleware): 10 | def __init__(self, app): 11 | super().__init__(app) 12 | self.last_log_time = 0 13 | self.log_interval = 60 # seconds 14 | 15 | async def dispatch(self, request: Request, call_next): 16 | start_time = time.time() 17 | response = await call_next(request) 18 | process_time = time.time() - start_time 19 | host = request.client.host 20 | response.headers["X-Process-Time"] = str(process_time) 21 | response.headers["X-Client-IP"] = host 22 | 23 | if "get_upload_state" in str(request.url): 24 | current_time = time.time() 25 | if current_time - self.last_log_time >= self.log_interval: 26 | logger.info( 27 | f"Request: {request.method} {request.url} - Response Time: {process_time:.4f} seconds Host {host}" 28 | ) 29 | self.last_log_time = current_time 30 | else: 31 | logger.info( 32 | f"Request: {request.method} {request.url} - Response Time: {process_time:.4f} seconds Host {host}" 33 | ) 34 | return response 35 | 36 | 37 | def _configure_session_middleware(app): 38 | app.add_middleware( 39 | CorrelationIdMiddleware, 40 | header_name="X-Request-ID", 41 | ) 42 | 43 | 44 | def _configure_cors_middleware(app): 45 | app.add_middleware( 46 | CORSMiddleware, 47 | allow_origins=["*"], 48 | allow_methods=["*"], 49 | allow_headers=["*"], 50 | allow_credentials=False, 51 | ) 52 | 53 | 54 | def add_middlewares(app: FastAPI): 55 | # reset current middleware to allow modifying user provided list 56 | app.middleware_stack = None 57 | _configure_cors_middleware(app) 58 | _configure_session_middleware(app) 59 | app.add_middleware(CustomMiddleWare) 60 | app.build_middleware_stack() # rebuild middleware stack on-the-fly 61 | -------------------------------------------------------------------------------- /src/pairag/app/app.py: -------------------------------------------------------------------------------- 1 | # init trace 2 | import os 3 | import asyncio 4 | import threading 5 | from fastapi import FastAPI 6 | 7 | # setup models 8 | 9 | from pairag.utils.constants import DEFAULT_MODEL_DIR 10 | os.environ["PAIRAG_MODEL_DIR"] = DEFAULT_MODEL_DIR 11 | from pairag.utils.download_models import ModelScopeDownloader 12 | ModelScopeDownloader().load_rag_models() 13 | 14 | 15 | from contextlib import asynccontextmanager 16 | from pairag.utils.format_logging import format_logging 17 | from pairag.core.chat_service import chat_service 18 | from pairag.data_pipeline.job.rag_job_manager import job_manager 19 | from pairag.core.service_daemon import startup_event 20 | from loguru import logger 21 | 22 | format_logging() 23 | 24 | @asynccontextmanager 25 | async def lifespan(app: FastAPI): 26 | logger.info("Application starting up...") 27 | daemon_thread = threading.Thread(target=job_manager.execute_job, daemon=True) 28 | daemon_thread.start() 29 | 30 | asyncio.create_task(startup_event()) 31 | yield 32 | 33 | logger.info("Application shutting down...") 34 | 35 | 36 | def configure(app: FastAPI): 37 | from 
pairag.api.api_router_v1 import router_v1 38 | from pairag.api.api_chat import router_openai, router_chat 39 | from pairag.api.exception_handler import add_exception_handler 40 | from pairag.api.middleware import add_middlewares 41 | from pairag.web.webui import configure_webapp 42 | 43 | app.include_router(router_v1, prefix="/api/v1", tags=["api_v1"]) 44 | app.include_router(router_openai, prefix="/v1", tags=["chat_completions"]) 45 | app.include_router(router_chat, prefix="/chat", tags=["chat_api"]) 46 | 47 | chat_service.initialize() 48 | add_middlewares(app) 49 | add_exception_handler(app) 50 | configure_webapp(app) 51 | 52 | 53 | app = FastAPI(lifespan=lifespan) 54 | configure(app) 55 | -------------------------------------------------------------------------------- /src/pairag/app/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | 4 | _BASE_DIR = Path(__file__).parent.parent 5 | _ROOT_BASE_DIR = Path(__file__).parent.parent.parent.parent 6 | 7 | DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml") 8 | DEFAULT_APPLICATION_EXAMPLE_DATA_FILE = os.path.join( 9 | _ROOT_BASE_DIR, "example_data/pai_document.pdf" 10 | ) 11 | DEFAULT_HOST = "0.0.0.0" 12 | DEFAULT_PORT = 8001 13 | DEFAULT_RAG_URL = f"http://{DEFAULT_HOST}:{DEFAULT_PORT}/" 14 | -------------------------------------------------------------------------------- /src/pairag/chat/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Any, List, Dict, Optional, AsyncGenerator, Generator 3 | from llama_index.core.base.llms.types import ChatMessage 4 | from llama_index.core.schema import NodeWithScore 5 | from pairag.integrations.query_transform.pai_query_transform import IntentResult 6 | 7 | 8 | class ContextDoc(BaseModel): 9 | text: str # 文档文本 10 | score: float # 文档得分 11 | metadata: Dict # 文档元数据 12 | image_url: str | None = None # 图片链接 13 | 14 | 15 | class RetrievalRequest(BaseModel): 16 | knowledgebase_id: Optional[str] = "default" # 知识库名称(index_name) 17 | query: str # 查询内容 18 | retrieval_settings: Optional[Dict] = None 19 | # ["retrieval_mode", "similarity_top_k", "vector_weight", "keyword_weight", "reranker_type", "similarity_threshold", "reranker_similarity_threshold", "reranker_model", "reranker_similarity_top_k"] 20 | 21 | 22 | class DocRecord(BaseModel): 23 | content: str # 包含知识库中数据源的文本块 24 | score: float # 结果与查询的相关性分数,范围:0~1 25 | title: str # 文档标题 26 | metadata: Dict # 包含数据源中文档的元数据属性及其值 27 | 28 | 29 | class NewRetrievalResponse(BaseModel): 30 | records: List[DocRecord] 31 | 32 | 33 | class RetrievalResponse(BaseModel): 34 | docs: List[ContextDoc] 35 | 36 | 37 | class ChatCompletionRequest(BaseModel): 38 | model: str # 模型名称 39 | messages: List[ChatMessage] # 上下文聊天 40 | stream: Optional[bool] = False # 流式输出 41 | index_name: Optional[str] = None # 索引名称 42 | chat_knowledgebase: Optional[bool] = False # 查询知识库 43 | search_web: Optional[bool] = False # 搜索网络 44 | return_reference: Optional[bool] = False # 返回参考 45 | chat_llm: Optional[bool] = False # llm聊天 46 | chat_db: Optional[bool] = False # 查询数据库 47 | chat_news: Optional[bool] = False # 使用新闻工具 48 | # llm args 49 | temperature: Optional[float] = None 50 | max_tokens: Optional[int] = None 51 | intent: Optional[IntentResult] = None # 意图 52 | 53 | class Config: 54 | extra = "allow" # allow extra fields 55 | 56 | 57 | class ChatResponseWrapper(BaseModel): 58 | response: Any 
59 | additional_kwargs: Dict[str, Any] = {} 60 | source_nodes: List[NodeWithScore] = [] 61 | intent_result: Optional[IntentResult] = None 62 | 63 | def model_dump_json(self, exclude=None, **kwargs) -> str: 64 | if exclude is None: 65 | exclude = set() 66 | elif isinstance(exclude, dict): 67 | exclude = {k for k, v in exclude.items() if v} 68 | 69 | # to compatible with arize instrumentation 70 | if isinstance(self.response, (Generator, AsyncGenerator)): 71 | exclude.add("response") 72 | 73 | return super().model_dump_json(exclude=exclude, **kwargs) 74 | 75 | 76 | class EmbeddingInput(BaseModel): 77 | input: str | List[str] = None 78 | model: str = "bge-m3" 79 | -------------------------------------------------------------------------------- /src/pairag/config/evaluation/config.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | # [text dataset][pai-eval] 3 | - name: "text_exp1" 4 | eval_data_path: "example_data/eval_docs_text" 5 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_text.toml" 6 | eval_model_llm: 7 | source: "dashscope" 8 | model: "qwen-max" 9 | max_tokens: 1024 10 | use_pai_eval: false 11 | # [custom text dataset][crag] 12 | - name: "text_exp2" 13 | dataset: "crag" 14 | eval_data_path: "example_data/eval_docs_crag_small" 15 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_crag_text.toml" 16 | eval_model_llm: 17 | source: "dashscope" 18 | model: "qwen-max" 19 | max_tokens: 1024 20 | # [multi-modal dataset] 21 | - name: "multi_modal_exp1" 22 | eval_data_path: "example_data/eval_docs_image" 23 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_image.toml" 24 | eval_model_llm: 25 | source: "dashscope" 26 | model: "qwen-vl-max" 27 | max_tokens: 1024 28 | tested_multimodal_llm: 29 | source: "dashscope" 30 | model: "qwen-vl-max" 31 | max_tokens: 1024 32 | # [custom multi-modal dataset] 33 | - name: "multi_modal_exp2" 34 | qca_dataset_path: "example_data/eval_docs_image_example/multimodal_eval_dataset_zh_example.jsonl" 35 | rag_setting_file: "src/pairag/config/evaluation/settings_eval_for_image.toml" 36 | eval_model_llm: 37 | source: "dashscope" 38 | model: "qwen-vl-max" 39 | max_tokens: 1024 40 | tested_multimodal_llm: 41 | source: "dashscope" 42 | model: "qwen-vl-max" 43 | max_tokens: 1024 44 | -------------------------------------------------------------------------------- /src/pairag/config/evaluation/settings_eval_for_text.toml: -------------------------------------------------------------------------------- 1 | dynaconf_merge = true 2 | 3 | [rag] 4 | name = "pairag" 5 | version = "0.1.1" 6 | 7 | [rag.agent] 8 | custom_agent_config_file = "" 9 | agent_tool_type = "" 10 | 11 | [rag.chat_store] 12 | type = "Local" # [Local, Aliyun-Redis] 13 | host = "Aliyun-Redis host" 14 | password = "Aliyun-Redis user:pwd" 15 | persist_path = "localdata/eval_exp_data/storage" 16 | 17 | [rag.data_analysis] 18 | type = "pandas" 19 | nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: " 20 | 21 | [rag.data_reader] 22 | type = "SimpleDirectoryReader" 23 | 24 | # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace 25 | # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in 
ENV, If HuggingFace, need set model 26 | # eg. 27 | # source = "HuggingFace" 28 | # model = "bge-m3" 29 | # embed_batch_size = 10 30 | [rag.embedding] 31 | source = "DashScope" 32 | embed_batch_size = 10 33 | 34 | [rag.index] 35 | persist_path = "localdata/eval_exp_data/storage" 36 | enable_multimodal = false 37 | vector_store.type = "FAISS" 38 | 39 | # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment 40 | # eg. 41 | # source = "PaiEas" 42 | # model = "" 43 | # endpoint = "" 44 | # token = "" 45 | [rag.llm] 46 | source = "DashScope" 47 | model = "qwen-turbo" 48 | 49 | [rag.multimodal_embedding] 50 | source = "cnclip" 51 | 52 | [rag.multimodal_llm] 53 | source = "dashscope" 54 | model = "qwen-vl-plus" 55 | 56 | [rag.node_enhancement] 57 | tree_depth = 3 58 | max_clusters = 52 59 | proba_threshold = 0.10 60 | 61 | [rag.node_parser] 62 | type = "Sentence" 63 | chunk_size = 500 64 | chunk_overlap = 10 65 | enable_multimodal = false 66 | 67 | [rag.oss_store] 68 | bucket = "" 69 | endpoint = "oss-cn-hangzhou.aliyuncs.com" 70 | 71 | [rag.postprocessor] 72 | reranker_type = "no-reranker" # [simple-weighted-reranker, model-based-reranker] 73 | reranker_model = "bge-reranker-base" # [bge-reranker-base, bge-reranker-large] 74 | keyword_weight = 0.3 75 | vector_weight = 0.7 76 | similarity_threshold = 0.5 77 | top_n = 2 78 | 79 | [rag.query_transform] 80 | type = "" 81 | 82 | [rag.retriever] 83 | similarity_top_k = 3 84 | retrieval_mode = "hybrid" # [hybrid, embedding, keyword, router] 85 | query_rewrite_n = 1 # set to 1 to disable query generation 86 | 87 | [rag.search] 88 | source = "quark" 89 | 90 | [rag.synthesizer] 91 | type = "SimpleSummarize" 92 | text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n" 93 | 94 | [rag.trace] 95 | type = "pai_trace" 96 | endpoint = "http://tracing-analysis-dc-hz.aliyuncs.com:8090" 97 | token = "" 98 | -------------------------------------------------------------------------------- /src/pairag/config/settings.toml: -------------------------------------------------------------------------------- 1 | dynaconf_merge = true 2 | 3 | [rag] 4 | name = "pairag" 5 | version = "0.1.1" 6 | 7 | [rag.agent] 8 | custom_agent_config_file = "" 9 | agent_tool_type = "" 10 | 11 | [rag.chat] 12 | 13 | [rag.chat_store] 14 | type = "Local" # [Local, Aliyun-Redis] 15 | host = "Aliyun-Redis host" 16 | password = "Aliyun-Redis user:pwd" 17 | persist_path = "localdata/storage" 18 | 19 | [rag.data_analysis] 20 | type = "pandas" 21 | 22 | [rag.data_reader] 23 | type = "SimpleDirectoryReader" 24 | 25 | # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace 26 | # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model 27 | # eg. 28 | # source = "HuggingFace" 29 | # model = "bge-m3" 30 | # embed_batch_size = 10 31 | [rag.embedding] 32 | source = "huggingface" 33 | embed_batch_size = 10 34 | enable_sparse = false 35 | 36 | [rag.index] 37 | persist_path = "localdata/knowledgebase/default/.index/.faiss" 38 | enable_multimodal = true 39 | vector_store.type = "FAISS" 40 | 41 | [rag.llm] 42 | source = "openai_compatible" 43 | 44 | # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment 45 | # eg. 
46 | # source = "PaiEas" 47 | # model = "" 48 | # base_url = "" 49 | # api_key = "" 50 | # vision_support = false 51 | [[rag.llms]] 52 | 53 | [rag.multimodal_embedding] 54 | source = "cnclip" 55 | 56 | [rag.multimodal_llm] 57 | source = "openai_compatible" 58 | 59 | [rag.node_enhancement] 60 | tree_depth = 3 61 | max_clusters = 52 62 | proba_threshold = 0.10 63 | 64 | [rag.node_parser] 65 | type = "Sentence" 66 | chunk_size = 500 67 | chunk_overlap = 10 68 | enable_multimodal = true 69 | 70 | [rag.oss_store] 71 | bucket = "" 72 | endpoint = "oss-cn-hangzhou.aliyuncs.com" 73 | 74 | [rag.postprocessor] 75 | reranker_type = "no-reranker" # [simple-weighted-reranker, model-based-reranker] 76 | reranker_model = "bge-reranker-base" # [bge-reranker-base, bge-reranker-large] 77 | similarity_threshold = 0.5 78 | top_n = 2 79 | 80 | [rag.query_rewrite] 81 | enabled = true 82 | 83 | [rag.query_rewrite.llm] 84 | 85 | [rag.retriever] 86 | similarity_top_k = 5 87 | retrieval_mode = "hybrid" # [hybrid, embedding, keyword, router] 88 | query_rewrite_n = 1 # set to 1 to disable query generation 89 | 90 | [rag.search] 91 | source = "bing" 92 | 93 | [rag.synthesizer] 94 | type = "SimpleSummarize" 95 | -------------------------------------------------------------------------------- /src/pairag/core/models/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | from pairag.integrations.llms.pai.llm_config import OpenAICompatibleLlmConfig 3 | from pairag.integrations.synthesizer.prompt_templates import ( 4 | DEFAULT_SYSTEM_ROLE_TEMPLATE, 5 | ) 6 | from pairag.utils.prompt_template import ( 7 | CHAT_LLM_REWRITE_PROMPT_ZH, 8 | KNOWLEDGEBASE_REWRITE_PROMPT_ZH, 9 | NEWS_REWRITE_PROMPT_ZH, 10 | NL2SQL_REWRITE_PROMPT_ZH, 11 | REWRITE_PROMPT_ROLE_ZH, 12 | WEBSEARCH_REWRITE_PROMPT_ZH, 13 | ) 14 | from pairag.utils.prompt_template import ( 15 | DEFALT_LLM_CHAT_PROMPT_TEMPL, 16 | ) 17 | 18 | 19 | class ChatConfig(BaseModel): 20 | model_id: str | None = None 21 | model_config = ConfigDict(coerce_numbers_to_str=True) 22 | 23 | 24 | class QueryRewriteConfig(BaseModel): 25 | enabled: bool = True 26 | base_prompt_template_str: str = REWRITE_PROMPT_ROLE_ZH 27 | llm_tool_prompt_str: str = CHAT_LLM_REWRITE_PROMPT_ZH 28 | knowledge_tool_prompt_str: str = KNOWLEDGEBASE_REWRITE_PROMPT_ZH 29 | websearch_tool_prompt_str: str = WEBSEARCH_REWRITE_PROMPT_ZH 30 | db_tool_prompt_str: str = NL2SQL_REWRITE_PROMPT_ZH 31 | news_tool_prompt_str: str = NEWS_REWRITE_PROMPT_ZH 32 | 33 | model_id: str | None = None 34 | llm: OpenAICompatibleLlmConfig | None = OpenAICompatibleLlmConfig() 35 | model_config = ConfigDict(coerce_numbers_to_str=True) 36 | 37 | 38 | class AliyunTextModerationPlusConfig(BaseModel): 39 | endpoint: str | None = None 40 | region: str | None = None 41 | access_key_id: str | None = None 42 | access_key_secret: str | None = None 43 | custom_advice: str | None = None 44 | 45 | def is_enabled(self) -> bool: 46 | return ( 47 | self.access_key_id is not None 48 | and self.access_key_secret is not None 49 | and len(self.access_key_id) > 0 50 | and len(self.access_key_secret) > 0 51 | and self.endpoint is not None 52 | and self.endpoint != "" 53 | and self.region is not None 54 | and self.region != "" 55 | ) 56 | 57 | 58 | class NodeEnhancementConfig(BaseModel): 59 | tree_depth: int = 3 60 | max_clusters: int = 52 61 | proba_threshold: float = 0.10 62 | 63 | 64 | class OssStoreConfig(BaseModel): 65 | bucket: str | None = None 66 | endpoint: str = 
"oss-cn-hangzhou.aliyuncs.com" 67 | ak: str | None = None 68 | sk: str | None = None 69 | model_config = ConfigDict(coerce_numbers_to_str=True) 70 | 71 | 72 | class SynthesizerConfig(BaseModel): 73 | use_multimodal_llm: bool = False 74 | system_role_template: str = DEFAULT_SYSTEM_ROLE_TEMPLATE 75 | custom_prompt_template: str = DEFALT_LLM_CHAT_PROMPT_TEMPL 76 | -------------------------------------------------------------------------------- /src/pairag/core/models/errors.py: -------------------------------------------------------------------------------- 1 | class UserInputError(Exception): 2 | def __init__(self, msg: str): 3 | self.msg = msg 4 | 5 | 6 | class ServiceError(Exception): 7 | def __init__(self, msg: str): 8 | self.msg = msg 9 | -------------------------------------------------------------------------------- /src/pairag/core/models/state.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class FileServiceState: 5 | def __init__(self, key): 6 | self.state_key = key 7 | self.state_value = -1 8 | self.state_value = self.check_state() 9 | 10 | def check_state(self): 11 | if not os.path.exists(self.state_key): 12 | return 0 13 | 14 | # dummpy read to update file state from oss mount path 15 | _ = open(self.state_key, "r").readline() 16 | 17 | mtime = os.path.getmtime(self.state_key) 18 | if mtime != self.state_value: 19 | return mtime 20 | return 0 21 | 22 | def update_state(self, new_value): 23 | self.state_value = new_value 24 | -------------------------------------------------------------------------------- /src/pairag/core/rag_environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class RagServiceEnvironment: 5 | def __init__(self): 6 | self.IS_API_INSTANCE = os.getenv("DEPLOY_MODE", "web").upper() == "API" 7 | self.SHOULD_START_WEB = not self.IS_API_INSTANCE 8 | 9 | 10 | service_environment = RagServiceEnvironment() 11 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_NODE_SOURCE_FIELD = "source" 2 | DEFAULT_MODIFIED_AT_FIELD = "modified_at" 3 | DEFAULT_MD5_FIELD = "document_md5" 4 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/delta/list_delta.py: -------------------------------------------------------------------------------- 1 | # 通过向量数据库中的record比对确定delta信息 2 | from pairag.data_pipeline.delta.milvus import ( 3 | list_docs_in_milvus_collection, 4 | ) 5 | from pairag.data_pipeline.utils.vectordb_utils import get_vector_store 6 | from llama_index.vector_stores.milvus import MilvusVectorStore 7 | 8 | from loguru import logger 9 | 10 | 11 | """ 12 | 获取PAI-RAG服务索引中的知识库文件列表 13 | """ 14 | 15 | 16 | def list_files_from_rag_service( 17 | rag_endpoint: str, 18 | rag_api_key: str, 19 | knowledgebase: str, 20 | embed_dims: int, 21 | oss_path_prefix: str, 22 | ): 23 | vector_store = get_vector_store( 24 | rag_endpoint=rag_endpoint, 25 | rag_api_key=rag_api_key, 26 | knowledgebase=knowledgebase, 27 | embed_dims=embed_dims, 28 | ) 29 | if isinstance(vector_store, MilvusVectorStore): 30 | return list_docs_in_milvus_collection( 31 | collection=vector_store._collection, 32 | oss_path_prefix=oss_path_prefix, 33 | ) 34 | else: 35 | logger.error("Only support Milvus vector store for now.") 36 | raise NotImplementedError 37 | 
-------------------------------------------------------------------------------- /src/pairag/data_pipeline/delta/milvus.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from pymilvus import Collection 3 | 4 | from pairag.data_pipeline.constants import ( 5 | DEFAULT_MODIFIED_AT_FIELD, 6 | DEFAULT_NODE_SOURCE_FIELD, 7 | ) 8 | from pairag.data_pipeline.delta.models import DocItem 9 | 10 | 11 | def list_docs_in_milvus_collection( 12 | collection: Collection, 13 | oss_path_prefix: str, 14 | ) -> Dict[str, DocItem]: 15 | # only return the documents under the oss_path_prefix 16 | if oss_path_prefix: 17 | expr = f"{DEFAULT_NODE_SOURCE_FIELD} like '{oss_path_prefix}%'" 18 | else: 19 | expr = None 20 | 21 | iterator = collection.query_iterator( 22 | batch_size=10, 23 | output_fields=[DEFAULT_NODE_SOURCE_FIELD, DEFAULT_MODIFIED_AT_FIELD], 24 | expr=expr, 25 | ) 26 | results: Dict[str, DocItem] = {} 27 | while True: 28 | fetch_data = iterator.next() 29 | if not fetch_data: 30 | iterator.close() 31 | break 32 | 33 | # scan the fetch data, get the latest modified time and the node ids 34 | for record in fetch_data: 35 | source = record[DEFAULT_NODE_SOURCE_FIELD] 36 | if source not in results: 37 | results[source] = DocItem( 38 | doc_path=source, 39 | modified_time=record[DEFAULT_MODIFIED_AT_FIELD], 40 | node_ids=[record["id"]], 41 | ) 42 | else: 43 | results[source].modified_time = min( 44 | record[DEFAULT_MODIFIED_AT_FIELD], 45 | results[source].modified_time, 46 | ) 47 | results[source].node_ids.append(record["id"]) 48 | 49 | return results 50 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/delta/models.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pydantic import BaseModel 3 | 4 | 5 | class DocItem(BaseModel): 6 | doc_path: str 7 | node_ids: List[str] 8 | modified_time: float 9 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/e2e_config.yaml: -------------------------------------------------------------------------------- 1 | # process schedule 2 | # a list of several pairag offline process operators with their arguments 3 | 4 | operators: 5 | - name: "data_source" 6 | enable_delta: true 7 | supported_file_types_str: "pdf,txt,csv,xlsx,xls,docx,md,html,htm,jsonl" 8 | pairag_token: 9 | pairag_endpoint: 10 | pairag_knowledgebase: 11 | pairag_embed_dims: 1024 12 | 13 | - name: "parse" 14 | num_cpus: 2 15 | memory: 8 16 | num_gpus: 0 17 | enable_pdf_ocr: false 18 | concat_sheet_rows: false 19 | concurrency: 6 20 | 21 | - name: "split" 22 | node_parser_type: "Sentence" 23 | chunk_size: 1024 24 | chunk_overlap: 20 25 | paragraph_separator: "\n\n" 26 | num_cpus: 1 27 | memory: 4 28 | concurrency: 10 29 | 30 | - name: "embed" 31 | model: "bge-m3" 32 | source: "huggingface" 33 | batch_size: 300 34 | num_cpus: 4 35 | num_gpus: 1 36 | memory: 10 37 | concurrency: 1 38 | 39 | - name: "data_sink" 40 | pairag_token: 41 | pairag_endpoint: 42 | pairag_knowledgebase: 43 | pairag_embed_dims: 1024 44 | num_cpus: 1 45 | memory: 2 46 | batch_size: 500 47 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/models/config/datasource.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from pydantic import BaseModel 3 | 4 | 5 | class 
DataSourceConfig(BaseModel):
6 |     input_path: str
7 |     output_path: str
8 |     enable_delta: bool = False
9 |     file_extensions: Optional[List[str]] = None
10 | 
11 |     # connect to rag service
12 |     pairag_token: Optional[str] = None
13 |     pairag_endpoint: Optional[str] = None
14 |     pairag_knowledgebase: Optional[str] = "default"
15 |     pairag_embed_dims: Optional[int] = 1024
16 | 
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/models/config/operator.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Optional
3 | from pydantic import BaseModel
4 | from llama_index.core.constants import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
5 | 
6 | from pairag.file.nodeparsers.pai.constants import (
7 |     DEFAULT_NODE_PARSER_TYPE,
8 |     DEFAULT_PARAGRAPH_SEP,
9 | )
10 | 
11 | 
12 | class OperatorName(str, Enum):
13 |     DATA_SOURCE = "data_source"
14 |     PARSER = "parse"
15 |     SPLITTER = "split"
16 |     EMBEDDER = "embed"
17 |     DATA_SINK = "data_sink"
18 | 
19 | 
20 | class BaseOperatorConfig(BaseModel):
21 |     """
22 |     Base class for operator configs.
23 |     """
24 | 
25 |     name: OperatorName
26 |     num_cpus: float = 1
27 |     num_gpus: float = 0
28 |     memory: float = 2
29 | 
30 |     batch_size: int = 10
31 |     input_path: str
32 |     output_path: str
33 |     model_dir: str = None
34 |     concurrency: int = 1
35 | 
36 | 
37 | class ParserConfig(BaseOperatorConfig):
38 |     """
39 |     Config for parse operator.
40 |     """
41 | 
42 |     name: OperatorName = OperatorName.PARSER
43 |     enable_pdf_ocr: bool = False
44 |     concat_sheet_rows: bool = False
45 |     recursive: bool = True
46 | 
47 | 
48 | class SplitterConfig(BaseOperatorConfig):
49 |     """
50 |     Config for split operator.
51 |     """
52 | 
53 |     name: OperatorName = OperatorName.SPLITTER
54 |     paragraph_separator: str = DEFAULT_PARAGRAPH_SEP
55 |     node_parser_type: str = DEFAULT_NODE_PARSER_TYPE
56 |     chunk_size: int = DEFAULT_CHUNK_SIZE
57 |     chunk_overlap: int = DEFAULT_CHUNK_OVERLAP
58 | 
59 | 
60 | class EmbedderConfig(BaseOperatorConfig):
61 |     """
62 |     Config for embed operator.
63 |     """
64 | 
65 |     name: OperatorName = OperatorName.EMBEDDER
66 |     model: str = "bge-m3"
67 |     connection_name: Optional[str] = None
68 |     workspace_id: Optional[str] = None
69 |     enable_sparse: bool = False
70 |     source: str = "huggingface"
71 |     batch_size: int = 32
72 | 
73 | 
74 | class SinkConfig(BaseOperatorConfig):
75 |     """
76 |     Config for data sink operator.
77 | """ 78 | 79 | name: OperatorName = OperatorName.DATA_SINK 80 | pairag_endpoint: str 81 | pairag_token: str 82 | pairag_knowledgebase: str 83 | pairag_embed_dims: int 84 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/models/file/event.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class NodeOperationType(str, Enum): 5 | ADD = "add" 6 | DELETE = "delete" 7 | 8 | 9 | class FileChangeType(str, Enum): 10 | ADD = "add_file" 11 | MODIFY = "modify_file" 12 | DELETE = "delete_file" 13 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/models/models.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pydantic import BaseModel, Field 3 | from typing import Dict 4 | from enum import Enum 5 | 6 | from pairag.utils.time_utils import get_current_time_str 7 | 8 | 9 | class FileOperationType(int, Enum): 10 | ADD = 1 11 | UPDATE = 2 12 | DELETE = 3 13 | 14 | 15 | class FileChange(BaseModel): 16 | task_id: str 17 | file_name: str 18 | file_hash: str 19 | operation: FileOperationType 20 | knowledgebase: str 21 | 22 | 23 | class FileProcessStatus(str, Enum): 24 | PENDING = "pending" 25 | Parsing = "parsing" 26 | Chunking = "chunking" 27 | Embedding = "embedding" 28 | Persisting = "persisting" 29 | Done = "done" 30 | Failed = "failed" 31 | 32 | 33 | class FileItem(FileChange): 34 | status: FileProcessStatus 35 | last_modified_time: str = Field(default_factory=lambda: get_current_time_str()) 36 | timestamp: float = Field(default_factory=lambda: time.time()) 37 | failed_reason: str | None = None 38 | 39 | 40 | class FileProcessResult(BaseModel): 41 | status: FileProcessStatus 42 | message: str | None = None 43 | 44 | 45 | class TaskInfo(BaseModel): 46 | knowledgebase: str 47 | task_map: Dict[str, FileItem] = {} 48 | last_modified_time: str = Field(default_factory=lambda: get_current_time_str()) 49 | 50 | 51 | class JobStatus(BaseModel): 52 | task_statuses: Dict[str, TaskInfo] = {} 53 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/operators/base.py: -------------------------------------------------------------------------------- 1 | from pairag.data_pipeline.utils.cuda_utils import is_cuda_available 2 | 3 | 4 | OUTPUT_BATCH_SIZE = 50000 5 | 6 | 7 | class BaseOperator: 8 | def __init__( 9 | self, 10 | name: str = "default", 11 | num_cpus: float = 1, 12 | num_gpus: float = 0, 13 | model_dir: str = None, 14 | **kwargs, 15 | ): 16 | self.name = name 17 | self.num_cpus = num_cpus 18 | self.num_gpus = num_gpus 19 | self.model_dir = model_dir 20 | self.kwargs = kwargs 21 | 22 | def process(self, *args, **kwargs): 23 | raise NotImplementedError 24 | 25 | def use_cuda(self): 26 | return self.num_gpus > 0 and is_cuda_available() 27 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/operators/split.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | from llama_index.core.schema import TextNode 3 | from pairag.data_pipeline.models.config.operator import SplitterConfig 4 | from pairag.data_pipeline.models.file.event import NodeOperationType 5 | from pairag.data_pipeline.operators.base import BaseOperator 6 | from pairag.data_pipeline.utils.node_utils import ( 7 | 
metadata_dict_to_node_v2, 8 | node_to_metadata_dict_v2, 9 | ) 10 | from pairag.file.nodeparsers.pai.pai_node_parser import ( 11 | NodeParserConfig, 12 | PaiNodeParser, 13 | ) 14 | from loguru import logger 15 | 16 | 17 | class Splitter(BaseOperator): 18 | def __init__(self, config: SplitterConfig): 19 | super().__init__( 20 | name=config.name, 21 | num_cpus=config.num_cpus, 22 | num_gpus=config.num_gpus, 23 | model_dir=config.model_dir, 24 | ) 25 | 26 | self.node_parser_config = NodeParserConfig( 27 | type=config.node_parser_type, 28 | chunk_size=config.chunk_size, 29 | chunk_overlap=config.chunk_overlap, 30 | ) 31 | self.node_parser = PaiNodeParser( 32 | parser_config=self.node_parser_config, 33 | ) 34 | logger.info( 35 | f"""SplitterActor init finished with following parameters: {config}""" 36 | ) 37 | 38 | def process_delete(self, row: Dict[str, Any]) -> List[Dict[str, Any]]: 39 | return [row] 40 | 41 | def process_add(self, row: Dict[str, Any]) -> List[Dict[str, Any]]: 42 | doc = metadata_dict_to_node_v2(row) 43 | nodes = [] 44 | 45 | splitted_nodes = self.node_parser.get_nodes_from_documents([doc]) 46 | for node in splitted_nodes: 47 | # 去掉文本内容为空的分块 48 | if isinstance(node, TextNode) and not node.text: 49 | continue 50 | node_dict = node_to_metadata_dict_v2(node) 51 | node_dict["operation"] = row.get("operation") 52 | node_dict["operation_reason"] = row.get("operation_reason") 53 | nodes.append(node_dict) 54 | 55 | logger.info(f"Split {len(splitted_nodes)} nodes from {doc.node_id}.") 56 | return nodes 57 | 58 | def __call__(self, row: Dict[str, Any]) -> List[Dict[str, Any]]: 59 | if row.get("operation") == NodeOperationType.DELETE: 60 | return self.process_delete(row) 61 | else: 62 | return self.process_add(row) 63 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/readme.md: -------------------------------------------------------------------------------- 1 | ## 分步骤执行示例 2 | 3 | 设置output_path,每次执行选择一个新路径,做好数据隔离 4 | 5 | ```shell 6 | export OUTPUT_PATH=/testdata/output0428 7 | ``` 8 | 9 | 1. Read命令 10 | 11 | ```shell 12 | python src/pairag/data_ingestion/main.py read --input-path /testdata/testdata/txt --output-path $OUTPUT_PATH/read --enable-delta 13 | 14 | ``` 15 | 16 | 2. Parse命令 17 | 18 | ```shell 19 | python src/pairag/data_ingestion/main.py parse --input-path $OUTPUT_PATH/read --output-path $OUTPUT_PATH/parse \ 20 | --num-cpus 2 --memory 8 21 | ``` 22 | 23 | 3. Split命令 24 | 25 | ```shell 26 | python src/pairag/data_ingestion/main.py split --input-path $OUTPUT_PATH/parse --output-path $OUTPUT_PATH/split \ 27 | --num-cpus 1 --memory 2 28 | ``` 29 | 30 | 4. Embed命令 31 | 32 | ```shell 33 | python src/pairag/data_ingestion/main.py embed --input-path $OUTPUT_PATH/split --output-path $OUTPUT_PATH/embed \ 34 | --num-cpus 5 --memory 10 --num-gpus 1 --source huggingface --model bge-m3 35 | ``` 36 | 37 | 5. 
Write命令,存入向量数据库
38 | 
39 | ```shell
40 | python src/pairag/data_ingestion/main.py write --input-path $OUTPUT_PATH/embed --num-cpus 10 --memory 20
41 | ```
42 | 
43 | ## E2E执行示例
44 | 
45 | 设置output_path,每次执行选择一个新路径,做好数据隔离
46 | 
47 | ```shell
48 | export OUTPUT_PATH=/testdata/output0505
49 | export pairag_ENDPOINT=
50 | export pairag_TOKEN=
51 | 
52 | python src/pairag/data_pipeline/main.py e2e --input-path xx --output-path $OUTPUT_PATH
53 | ```
54 | 
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/concurrency_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math
3 | import psutil
4 | from loguru import logger
5 | 
6 | 
7 | def compute_concurrency_count(
8 |     num_cpus: int,
9 |     memory: int,
10 |     num_gpus: int = 0,
11 | ):
12 |     concurrency = sys.maxsize
13 | 
14 |     if num_gpus > 0:
15 |         import torch
16 | 
17 |         cuda_device_count = torch.cuda.device_count()
18 |         if cuda_device_count > 0:
19 |             concurrency = math.floor(min(concurrency, cuda_device_count // num_gpus))
20 |             logger.info(
21 |                 f"Available CUDA devices: {cuda_device_count}, required gpus {num_gpus}, updated concurrency {concurrency}."
22 |             )
23 |         else:
24 |             logger.error("No CUDA devices found.")
25 |             raise ValueError("No CUDA devices found.")
26 | 
27 |     cpu_available = psutil.cpu_count() - 4
28 |     mem_available = psutil.virtual_memory().available
29 |     mem_available = mem_available / 1024**3
30 | 
31 |     concurrency = math.floor(min(concurrency, cpu_available // num_cpus))
32 |     logger.info(
33 |         f"Available CPUs: {cpu_available}, required cpus {num_cpus}, updated concurrency {concurrency}."
34 |     )
35 | 
36 |     concurrency = math.floor(min(concurrency, mem_available // memory))
37 |     logger.info(
38 |         f"Available memory: {mem_available}GB, required memory {memory}GB, updated concurrency {concurrency}."
39 |     )
40 | 
41 |     return concurrency
42 | 
--------------------------------------------------------------------------------
/src/pairag/data_pipeline/utils/cuda_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | from loguru import logger
5 | from PIL import ImageFile
6 | import importlib.metadata
7 | import importlib.util
8 | from typing import Tuple, Union
9 | 
10 | ImageFile.LOAD_TRUNCATED_IMAGES = True
11 | 
12 | # For now, only INFO will be shown. Later the severity level will be changed
13 | # when setup_logger is called to initialize the logger.
14 | logger.remove() 15 | logger.add(sys.stderr, level="INFO") 16 | 17 | 18 | def _is_package_available( 19 | pkg_name: str, return_version: bool = False 20 | ) -> Union[Tuple[bool, str], bool]: 21 | # Check we're not importing a "pkg_name" directory somewhere 22 | # but the actual library by trying to grab the version 23 | package_exists = importlib.util.find_spec(pkg_name) is not None 24 | package_version = "N/A" 25 | if package_exists: 26 | try: 27 | package_version = importlib.metadata.version(pkg_name) 28 | package_exists = True 29 | except importlib.metadata.PackageNotFoundError: 30 | package_exists = False 31 | logger.debug(f"Detected {pkg_name} version {package_version}") 32 | if return_version: 33 | return package_exists, package_version 34 | else: 35 | return package_exists 36 | 37 | 38 | def _cuda_device_count(): 39 | _torch_available = _is_package_available("torch") 40 | 41 | if _torch_available: 42 | import torch 43 | 44 | return torch.cuda.device_count() 45 | 46 | try: 47 | nvidia_smi_output = subprocess.check_output(["nvidia-smi", "-L"], text=True) 48 | all_devices = nvidia_smi_output.strip().split("\n") 49 | 50 | cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") 51 | if cuda_visible_devices is not None: 52 | logger.warning( 53 | "CUDA_VISIBLE_DEVICES is ignored when torch is unavailable. " 54 | "All detected GPUs will be used." 55 | ) 56 | 57 | return len(all_devices) 58 | except Exception: 59 | # nvidia-smi not found or other error 60 | return 0 61 | 62 | 63 | _CUDA_DEVICE_COUNT = _cuda_device_count() 64 | 65 | 66 | def cuda_device_count(): 67 | return _CUDA_DEVICE_COUNT 68 | 69 | 70 | def is_cuda_available(): 71 | return _CUDA_DEVICE_COUNT > 0 72 | 73 | 74 | def get_num_gpus(use_cuda, op_proc): 75 | if not use_cuda: 76 | return 0 77 | proc_per_gpu = op_proc / cuda_device_count() 78 | return 1.0 / proc_per_gpu 79 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/utils/download_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import fcntl 4 | from loguru import logger 5 | from pairag.utils.download_models import ModelScopeDownloader 6 | 7 | 8 | def download_models_via_lock(model_dir, model_name, use_cuda: bool = False): 9 | model_dir = model_dir or os.getenv("PAIRAG_MODEL_DIR", "./model_repository") 10 | model_path = os.path.join(model_dir, model_name) 11 | lock_file_path = model_name + ".lock" 12 | # 创建或打开一个锁文件 13 | with open(lock_file_path, "w") as lock_file: 14 | while True: 15 | try: 16 | # 尝试获取文件锁 17 | fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB) 18 | logger.info(f"进程 {os.getpid()} 获得锁") 19 | 20 | # 检查模型文件是否已经下载 21 | if os.path.exists(model_path): 22 | logger.info(f"进程 {os.getpid()} 检查到: 模型已下载完成,use_cuda: {use_cuda}") 23 | else: 24 | logger.info(f"进程 {os.getpid()} 开始下载模型,环境: use_cuda: {use_cuda}。") 25 | ModelScopeDownloader( 26 | fetch_config=True, 27 | download_directory_path=model_dir, 28 | ).load_model(model=model_name) 29 | 30 | # 释放锁并结束循环 31 | fcntl.flock(lock_file, fcntl.LOCK_UN) 32 | logger.info(f"进程 {os.getpid()} 下载模型完成,use_cuda: {use_cuda}。") 33 | break 34 | 35 | except IOError as ex: 36 | logger.info(f"进程 {os.getpid()} 等待锁中... 
{ex}") 37 | time.sleep(1) # 等待后重试 38 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/utils/file_ext_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def parse_file_extensions(file_extensions_str: str) -> List[str]: 5 | """ 6 | Parse the file extensions string into a list of file extensions. 7 | 8 | Args: 9 | file_extensions_str (str): The file extensions string. 10 | 11 | Returns: 12 | List[str]: The list of file extensions. 13 | """ 14 | file_extensions = file_extensions_str.split(",") 15 | extensions = [f".{ext.strip()}" for ext in file_extensions if ext.strip()] 16 | return extensions 17 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/utils/filename_utils.py: -------------------------------------------------------------------------------- 1 | from ray.data.datasource import FilenameProvider 2 | 3 | 4 | class BlockFileNameProvider(FilenameProvider): 5 | def __init__(self, run_label: str, file_format: str): 6 | self._run_label = run_label 7 | self._file_format = file_format 8 | 9 | def get_filename_for_block(self, block, task_index, block_index): 10 | return f"{self._run_label}_{task_index:06}_{block_index:06}.{self._file_format}" 11 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | # Others 2 | def size_to_bytes(size): 3 | alphabets_list = [char for char in size if char.isalpha()] 4 | numbers_list = [char for char in size if char.isdigit()] 5 | 6 | if len(numbers_list) == 0: 7 | raise ValueError(f"Your input `size` does not contain numbers: {size}") 8 | 9 | size_numbers = int(float("".join(numbers_list))) 10 | 11 | if len(alphabets_list) == 0: 12 | # by default, if users do not specify the units, the number will be 13 | # regarded as in bytes 14 | return size_numbers 15 | 16 | suffix = "".join(alphabets_list).lower() 17 | 18 | if suffix == "kb" or suffix == "kib": 19 | return size_numbers << 10 20 | elif suffix == "mb" or suffix == "mib": 21 | return size_numbers << 20 22 | elif suffix == "gb" or suffix == "gib": 23 | return size_numbers << 30 24 | elif suffix == "tb" or suffix == "tib": 25 | return size_numbers << 40 26 | elif suffix == "pb" or suffix == "pib": 27 | return size_numbers << 50 28 | elif suffix == "eb" or suffix == "eib": 29 | return size_numbers << 60 30 | elif suffix == "zb" or suffix == "zib": 31 | return size_numbers << 70 32 | elif suffix == "yb" or suffix == "yib": 33 | return size_numbers << 80 34 | else: 35 | raise ValueError( 36 | f"You specified unidentifiable unit: {suffix}, " 37 | f"expected in [KB, MB, GB, TB, PB, EB, ZB, YB, " 38 | f"KiB, MiB, GiB, TiB, PiB, EiB, ZiB, YiB], " 39 | f"(case insensitive, counted by *Bytes*)." 40 | ) 41 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/utils/path_resolver.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class MountPathResolver: 5 | @abstractmethod 6 | def resolve_destination_path(self, uri: str) -> str: 7 | """ 8 | Resolve the path to the local file system. 
9 | """ 10 | raise NotImplementedError 11 | 12 | @abstractmethod 13 | def resolve_source_url(self, path: str) -> str: 14 | """ 15 | Resolve the path to the local file system. 16 | """ 17 | raise NotImplementedError 18 | 19 | 20 | # 本地运行,文件路径即为uri 21 | class LocalPathResolver(MountPathResolver): 22 | def resolve_destination_path(self, uri: str) -> str: 23 | """ 24 | Resolve the path to the local file system. 25 | """ 26 | return uri 27 | 28 | def resolve_source_url(self, path: str) -> str: 29 | """ 30 | Resolve the path to the local file system. 31 | """ 32 | return path 33 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from loguru import logger 4 | 5 | 6 | def clear_folder(folder_path: str): 7 | if not os.path.exists(folder_path) or not os.path.isdir(folder_path): 8 | logger.warning( 9 | f"Fail to clear path {folder_path} because it is not a directory or does not exist." 10 | ) 11 | return 12 | 13 | shutil.rmtree(folder_path) # 删除整个文件夹 14 | os.makedirs(folder_path) # 创建新的空文件夹 15 | logger.info(f"Folder {folder_path} cleared successfully.") 16 | 17 | return 18 | -------------------------------------------------------------------------------- /src/pairag/data_pipeline/utils/vectordb_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import requests 3 | from pairag.knowledgebase.index.pai.utils.vector_store_utils import create_vector_store 4 | from pairag.knowledgebase.index.pai.vector_store_config import ( 5 | BaseVectorStoreConfig, 6 | SupportedVectorStoreType, 7 | ) 8 | from pairag.knowledgebase.models import KnowledgeBase 9 | from loguru import logger 10 | 11 | 12 | def get_vector_store_config( 13 | rag_endpoint: str, rag_api_key: str, knowledgebase: str 14 | ) -> BaseVectorStoreConfig: 15 | try: 16 | response = requests.get( 17 | f"{rag_endpoint}/api/v1/knowledgebases/{knowledgebase}", 18 | headers={"Authorization": f"Bearer {rag_api_key}"}, 19 | ) 20 | knowledgebase: KnowledgeBase = KnowledgeBase.model_validate(response.json()) 21 | return knowledgebase.vector_store_config 22 | except Exception as e: 23 | logger.error(f"Failed to get vector store config: {e}") 24 | raise e 25 | 26 | 27 | def get_vector_store( 28 | rag_endpoint: str, 29 | rag_api_key: str, 30 | knowledgebase: str, 31 | embed_dims: int, 32 | ): 33 | vector_store_config = get_vector_store_config( 34 | rag_endpoint=rag_endpoint, rag_api_key=rag_api_key, knowledgebase=knowledgebase 35 | ) 36 | 37 | logger.info(f"Creating vector store from config {vector_store_config}.") 38 | assert ( 39 | vector_store_config.type != SupportedVectorStoreType.faiss 40 | ), "FAISS is not supported." 
41 | 42 | asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) 43 | 44 | vector_store = create_vector_store( 45 | vectordb_config=vector_store_config, 46 | embed_dims=embed_dims, 47 | ) 48 | 49 | logger.info( 50 | f"""[PaiVectorStore] init finished with following parameters: 51 | config: {vector_store_config} 52 | embed_dims: {embed_dims} 53 | """ 54 | ) 55 | 56 | return vector_store 57 | -------------------------------------------------------------------------------- /src/pairag/evaluation/dataset/crag/crag_data_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from llama_index.core.indices import VectorStoreIndex 3 | from llama_index.core.ingestion import IngestionPipeline 4 | from pairag.file.readers.pai.pai_data_reader import PaiDataReader 5 | from loguru import logger 6 | 7 | 8 | class CragDataLoader: 9 | def __init__( 10 | self, 11 | data_reader: PaiDataReader, 12 | embed_model: Any = None, 13 | vector_index: VectorStoreIndex = None, 14 | ): 15 | self._data_reader = data_reader 16 | self._embed_model = embed_model 17 | self._vector_index = vector_index 18 | 19 | def load_data( 20 | self, 21 | file_path_or_directory: str, 22 | from_oss: bool = False, 23 | oss_path: str = None, 24 | filter_pattern: str = None, 25 | enable_raptor: str = False, 26 | ): 27 | """Load data from a file or directory.""" 28 | documents = self._data_reader.load_data( 29 | file_path_or_directory=file_path_or_directory, 30 | filter_pattern=filter_pattern, 31 | oss_path=oss_path, 32 | from_oss=from_oss, 33 | ) 34 | if from_oss: 35 | logger.info(f"Loaded {len(documents)} documents from {oss_path}") 36 | else: 37 | logger.info( 38 | f"Loaded {len(documents)} documents from {file_path_or_directory}" 39 | ) 40 | 41 | transformations = [ 42 | self._embed_model, 43 | ] 44 | 45 | ingestion_pipeline = IngestionPipeline(transformations=transformations) 46 | 47 | nodes = ingestion_pipeline.run(documents=documents) 48 | logger.info( 49 | f"[DataLoader] parsed {len(documents)} documents into {len(nodes)} nodes." 50 | ) 51 | 52 | self._vector_index.insert_nodes(nodes) 53 | logger.info(f"[DataLoader] Inserted {len(nodes)} nodes.") 54 | logger.info("[DataLoader] Ingestion Completed!") 55 | -------------------------------------------------------------------------------- /src/pairag/evaluation/dataset/crag/crag_jsonl_reader.py: -------------------------------------------------------------------------------- 1 | """Tabular parser-Excel parser. 2 | 3 | Contains parsers for tabular data files. 
4 | 5 | """ 6 | 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional 9 | from fsspec import AbstractFileSystem 10 | from llama_index.core.readers.base import BaseReader 11 | from llama_index.core.schema import Document 12 | import json 13 | 14 | 15 | class CragJsonLReader(BaseReader): 16 | """JsonL reader.""" 17 | 18 | def __init__(self, *args: Any, **kwargs: Any) -> None: 19 | """Init params.""" 20 | super().__init__(*args, **kwargs) 21 | 22 | def load_data( 23 | self, 24 | file_path: Path, 25 | extra_info: Optional[Dict] = None, 26 | fs: Optional[AbstractFileSystem] = None, 27 | ) -> List[Document]: 28 | with open(file_path, "r", encoding="utf-8") as file: 29 | json_lines = [line.strip() for line in file.readlines()] 30 | 31 | docs = [] 32 | for i, text in enumerate(json_lines): 33 | json_data = json.loads(text) 34 | search_results = json_data["search_results"] 35 | for j, search_result in enumerate(search_results): 36 | extra_info["row_number"] = i + 1 37 | extra_info["dataset_source"] = "crag" 38 | docs.append( 39 | Document( 40 | doc_id=f"{json_data['interaction_id']}__{j}", 41 | text=search_result["page_snippet"], 42 | metadata=extra_info, 43 | ) 44 | ) 45 | return docs 46 | -------------------------------------------------------------------------------- /src/pairag/evaluation/dataset/rag_eval_dataset_refactor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | from typing import List, Optional, Dict, Any 4 | from pydantic import BaseModel, Field, ValidationError 5 | from pairag.evaluation.dataset.rag_qca_dataset_refactor import Source, QcapSample 6 | from loguru import logger 7 | 8 | 9 | def generate_eval_uuid() -> str: 10 | return f"eval_{str(uuid.uuid4())}" 11 | 12 | 13 | class EvalResults(BaseModel): 14 | results: Dict[str, Any] 15 | source: Optional[Source] = Source() 16 | 17 | 18 | class QcapEvalSample(BaseModel): 19 | """ 20 | 主模型,表示完整的 Evaluated QCAP Sample,包含 id, qcap 和 eval_results 等字段。 21 | """ 22 | 23 | id: str = Field(default_factory=generate_eval_uuid) 24 | qcap: QcapSample 25 | eval_results: EvalResults 26 | 27 | 28 | class QcapEvalDataset(BaseModel): 29 | """ 30 | 评估数据集模型,包含多个 QcapEvalSample 31 | """ 32 | 33 | samples: List[QcapEvalSample] 34 | 35 | def save_json(self, file_path: str): 36 | """ 37 | 将数据集保存为 jsonl 文件,每一行是一个 QcapEvalSample 的 JSON 表示。 38 | """ 39 | try: 40 | with open(file_path, "w", encoding="utf-8") as f: 41 | for sample in self.samples: 42 | f.write(sample.model_dump_json()) 43 | f.write("\n") 44 | logger.info(f"数据集已成功保存到 {file_path}") 45 | except Exception as e: 46 | logger.info(f"保存数据集时出错: {e}") 47 | 48 | @classmethod 49 | def from_json(cls, file_path: str) -> "QcapEvalDataset": 50 | """ 51 | 从 jsonl 文件中读取数据,生成一个 QcapEvalDataset 实例。 52 | """ 53 | samples = [] 54 | try: 55 | with open(file_path, "r", encoding="utf-8") as f: 56 | for line_number, line in enumerate(f, start=1): 57 | line = line.strip() 58 | if not line: 59 | continue # 跳过空行 60 | try: 61 | sample_dict = json.loads(line) 62 | sample = QcapEvalSample(**sample_dict) 63 | samples.append(sample) 64 | except (json.JSONDecodeError, ValidationError) as e: 65 | logger.info(f"在第 {line_number} 行读取样本时出错: {e}") 66 | logger.info(f"从 {file_path} 成功加载了 {len(samples)} 个样本") 67 | return cls(samples=samples) 68 | except FileNotFoundError: 69 | logger.info(f"文件 {file_path} 未找到。") 70 | return cls(samples=[]) 71 | except Exception as e: 72 | logger.info(f"读取数据集时出错: {e}") 73 | return cls(samples=[]) 74 | 
-------------------------------------------------------------------------------- /src/pairag/evaluation/dataset/state_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from enum import Enum 4 | 5 | 6 | class DatasetState(Enum): 7 | QCA = "qca" 8 | QCAP = "qcap" 9 | RETRIEVAL = "retrieval_evaluation" 10 | RESPONSE = "response_evaluation" 11 | E2E = "end2end_evaluation" 12 | 13 | 14 | class StateManager: 15 | def __init__(self, state_file="state.json"): 16 | self.state_file = state_file 17 | # 使用字典来保存每个状态的完成情况 18 | self.states = {state: False for state in DatasetState} 19 | self.load_state() 20 | 21 | def load_state(self): 22 | """从JSON文件加载所有状态的完成情况""" 23 | if os.path.exists(self.state_file): 24 | try: 25 | with open(self.state_file, "r", encoding="utf-8") as f: 26 | data = json.load(f) 27 | for state in DatasetState: 28 | # 更新状态的完成情况,如果JSON中有对应的键 29 | if state.value in data: 30 | self.states[state] = data[state.value] 31 | print(f"加载的状态: {self.states}") 32 | except Exception as e: 33 | print(f"加载状态时出错: {e}") 34 | # 如果出错,保持所有状态为未完成 35 | else: 36 | print("状态文件不存在,初始化所有状态为未完成") 37 | 38 | def save_state(self): 39 | """将所有状态的完成情况保存到JSON文件""" 40 | try: 41 | with open(self.state_file, "w", encoding="utf-8") as f: 42 | # 将状态字典转换为以状态值为键的字典 43 | data = { 44 | state.value: completed for state, completed in self.states.items() 45 | } 46 | json.dump(data, f, ensure_ascii=False, indent=4) 47 | print(f"状态已保存为: {data}") 48 | except Exception as e: 49 | print(f"保存状态时出错: {e}") 50 | 51 | def mark_completed(self, new_state): 52 | """标记某个状态为完成并保存""" 53 | if isinstance(new_state, DatasetState): 54 | self.states[new_state] = True 55 | self.save_state() 56 | print(f"状态已标记为完成: {new_state}") 57 | else: 58 | raise ValueError("new_state 必须是 DatasetState 的实例") 59 | 60 | def is_completed(self, state): 61 | """判断某个状态是否已经完成""" 62 | if isinstance(state, DatasetState): 63 | return self.states.get(state, False) 64 | else: 65 | raise ValueError("state 必须是 DatasetState 的实例") 66 | 67 | def get_completed_states(self): 68 | """获取所有已完成的状态""" 69 | return [state for state, completed in self.states.items() if completed] 70 | 71 | def reset_state(self): 72 | """重置所有状态为未完成""" 73 | for state in self.states: 74 | self.states[state] = False 75 | self.save_state() 76 | print("所有状态已重置为未完成") 77 | -------------------------------------------------------------------------------- /src/pairag/evaluation/metrics/response/base.py: -------------------------------------------------------------------------------- 1 | """Llm metric for response evaluation.""" 2 | from abc import abstractmethod 3 | from typing import Any, Optional, Sequence 4 | 5 | from llama_index.core.evaluation.base import EvaluationResult 6 | from llama_index.core.llms.llm import LLM 7 | from llama_index.core.prompts.mixin import PromptDictType 8 | from llama_index.core.prompts.mixin import PromptMixin, PromptMixinType 9 | 10 | 11 | class LlmMetric(PromptMixin): 12 | """ 13 | Llm Metric. 
14 | """ 15 | 16 | metric_name: str = "base" 17 | 18 | def __init__( 19 | self, 20 | llm: Optional[LLM] = None, 21 | raise_error: bool = False, 22 | ) -> None: 23 | """Init params.""" 24 | self._llm = llm 25 | self._raise_error = raise_error 26 | 27 | def _get_prompts(self) -> PromptDictType: 28 | """Get prompts.""" 29 | return { 30 | "eval_template": self._eval_template, 31 | } 32 | 33 | def _update_prompts(self, prompts: PromptDictType) -> None: 34 | """Update prompts.""" 35 | if "eval_template" in prompts: 36 | self._eval_template = prompts["eval_template"] 37 | 38 | @abstractmethod 39 | async def parse_eval_result(self, eval_result: str) -> float: 40 | """Parse eval_result.""" 41 | raise NotImplementedError 42 | 43 | @abstractmethod 44 | async def aevaluate( 45 | self, 46 | query: str | None = None, 47 | reference_answer: str | None = None, 48 | contexts: Sequence[str] | None = None, 49 | response_answer: str | None = None, 50 | **kwargs: Any, 51 | ) -> EvaluationResult: 52 | """Run evaluation with query string, retrieved contexts, 53 | and generated response string. 54 | 55 | Subclasses can override this method to provide custom evaluation logic and 56 | take in additional arguments. 57 | """ 58 | raise NotImplementedError 59 | 60 | def _get_prompt_modules(self) -> PromptMixinType: 61 | """Get prompt modules.""" 62 | return {} 63 | -------------------------------------------------------------------------------- /src/pairag/evaluation/metrics/retrieval/core.py: -------------------------------------------------------------------------------- 1 | def default_hit_rate(expected_ids, retrieved_ids): 2 | """Default HitRate calculation: Check if there is a single hit""" 3 | is_hit = any(id in expected_ids for id in retrieved_ids) 4 | score = 1.0 if is_hit else 0.0 5 | return score 6 | 7 | 8 | def granular_hit_rate(expected_ids, retrieved_ids): 9 | """Granular HitRate calculation: Calculate all hits and divide by the number of expected docs""" 10 | expected_set = set(expected_ids) 11 | hits = sum(1 for doc_id in retrieved_ids if doc_id in expected_set) 12 | score = hits / len(expected_ids) if expected_ids else 0.0 13 | return score 14 | 15 | 16 | def default_mrr(expected_ids, retrieved_ids): 17 | """Default MRR calculation: Reciprocal rank of the first relevant document retrieved""" 18 | for i, id in enumerate(retrieved_ids): 19 | if id in expected_ids: 20 | return 1.0 / (i + 1) 21 | return 0.0 22 | 23 | 24 | def granular_mrr(expected_ids, retrieved_ids): 25 | """Granular MRR calculation: All relevant retrieved docs have their reciprocal ranks summed and averaged.""" 26 | expected_set = set(expected_ids) 27 | reciprocal_rank_sum = 0.0 28 | relevant_docs_count = 0 29 | for index, doc_id in enumerate(retrieved_ids): 30 | if doc_id in expected_set: 31 | relevant_docs_count += 1 32 | reciprocal_rank_sum += 1.0 / (index + 1) 33 | score = ( 34 | reciprocal_rank_sum / relevant_docs_count if relevant_docs_count > 0 else 0.0 35 | ) 36 | return score 37 | -------------------------------------------------------------------------------- /src/pairag/evaluation/metrics/retrieval/hitrate.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from pairag.evaluation.metrics.retrieval.core import ( 3 | granular_hit_rate, 4 | default_hit_rate, 5 | ) 6 | 7 | 8 | class HitRate: 9 | """Hit rate metric: Compute hit rate with two calculation options. 
10 | 11 | - The default method checks for a single match between any of the retrieved docs and expected docs. 12 | - The more granular method checks for all potential matches between retrieved docs and expected docs. 13 | 14 | Attributes: 15 | metric_name (str): The name of the metric. 16 | use_granular_hit_rate (bool): Determines whether to use the granular method for calculation. 17 | """ 18 | 19 | def __init__( 20 | self, metric_name: str = "hitrate", use_granular_hit_rate: bool = False 21 | ): 22 | self.metric_name = metric_name 23 | self.use_granular_hit_rate = use_granular_hit_rate 24 | 25 | def compute( 26 | self, 27 | expected_ids: Optional[List[str]] = None, 28 | retrieved_ids: Optional[List[str]] = None, 29 | ): 30 | """Compute metric based on the provided inputs. 31 | 32 | Parameters: 33 | expected_ids (Optional[List[str]]): Expected document IDs. 34 | retrieved_ids (Optional[List[str]]): Retrieved document IDs. 35 | 36 | Raises: 37 | ValueError: If the necessary IDs are not provided. 38 | 39 | Returns: 40 | RetrievalMetricResult: The result with the computed hit rate score. 41 | """ 42 | # Checking for the required arguments 43 | if ( 44 | retrieved_ids is None 45 | or expected_ids is None 46 | or not retrieved_ids 47 | or not expected_ids 48 | ): 49 | raise ValueError("Retrieved ids and expected ids must be provided") 50 | 51 | if self.use_granular_hit_rate: 52 | return granular_hit_rate(expected_ids, retrieved_ids) 53 | else: 54 | return default_hit_rate(expected_ids, retrieved_ids) 55 | -------------------------------------------------------------------------------- /src/pairag/evaluation/metrics/retrieval/mrr.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from pairag.evaluation.metrics.retrieval.core import ( 3 | default_mrr, 4 | granular_mrr, 5 | ) 6 | 7 | 8 | class MRR: 9 | """MRR (Mean Reciprocal Rank) metric with two calculation options. 10 | 11 | - The default method calculates the reciprocal rank of the first relevant retrieved document. 12 | - The more granular method sums the reciprocal ranks of all relevant retrieved documents and divides by the count of relevant documents. 13 | 14 | Attributes: 15 | metric_name (str): The name of the metric. 16 | use_granular_mrr (bool): Determines whether to use the granular method for calculation. 17 | """ 18 | 19 | def __init__(self, metric_name: str = "mrr", use_granular_mrr: bool = False): 20 | self.metric_name = metric_name 21 | self.use_granular_mrr = use_granular_mrr 22 | 23 | def compute( 24 | self, 25 | expected_ids: Optional[List[str]] = None, 26 | retrieved_ids: Optional[List[str]] = None, 27 | ): 28 | """Compute MRR based on the provided inputs and selected method. 29 | 30 | Parameters: 31 | expected_ids (Optional[List[str]]): Expected document IDs. 32 | retrieved_ids (Optional[List[str]]): Retrieved document IDs. 33 | 34 | Raises: 35 | ValueError: If the necessary IDs are not provided. 36 | 37 | Returns: 38 | RetrievalMetricResult: The result with the computed MRR score. 
39 | """ 40 | # Checking for the required arguments 41 | if ( 42 | retrieved_ids is None 43 | or expected_ids is None 44 | or not retrieved_ids 45 | or not expected_ids 46 | ): 47 | raise ValueError("Retrieved ids and expected ids must be provided") 48 | 49 | if self.use_granular_mrr: 50 | return granular_mrr(expected_ids, retrieved_ids) 51 | else: 52 | return default_mrr(expected_ids, retrieved_ids) 53 | -------------------------------------------------------------------------------- /src/pairag/evaluation/pipeline/run_evaluation_pipeline.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pairag.evaluation.utils.create_components import ( 3 | get_rag_components, 4 | get_rag_config_and_mode, 5 | get_eval_components, 6 | ) 7 | 8 | 9 | def run_rag_evaluation_pipeline( 10 | config_file=None, 11 | oss_path=None, 12 | data_path=None, 13 | pattern=None, 14 | exp_name="default", 15 | eval_model_llm_config=None, 16 | dataset=None, 17 | use_pai_eval=False, 18 | ): 19 | assert (oss_path is not None) or ( 20 | data_path is not None 21 | ), "Must provide either local path or oss path." 22 | assert (oss_path is None) or ( 23 | data_path is None 24 | ), f"Can not provide both local path '{data_path}' and oss path '{oss_path}'." 25 | 26 | config, mode, exist_flag = get_rag_config_and_mode(config_file, exp_name) 27 | data_loader, vector_index, query_engine = get_rag_components(config, dataset) 28 | if not exist_flag: 29 | data_loader.load_data( 30 | file_path_or_directory=data_path, 31 | filter_pattern=pattern, 32 | oss_path=oss_path, 33 | from_oss=oss_path is not None, 34 | enable_raptor=False, 35 | ) 36 | 37 | qca_generator, evaluator = get_eval_components( 38 | config, 39 | vector_index, 40 | query_engine, 41 | mode, 42 | eval_model_llm_config, 43 | use_pai_eval, 44 | ) 45 | 46 | _ = asyncio.run( 47 | qca_generator.agenerate_all_dataset(dataset=dataset, dataset_path=data_path) 48 | ) 49 | asyncio.run(evaluator.aevaluation_for_retrieval()) 50 | asyncio.run(evaluator.aevaluation_for_response()) 51 | asyncio.run(evaluator.aevaluation_for_all()) 52 | -------------------------------------------------------------------------------- /src/pairag/evaluation/pipeline/run_multimodal_evaluation_pipeline.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pairag.evaluation.utils.create_components import ( 3 | get_rag_components, 4 | get_rag_config_and_mode, 5 | get_multimodal_eval_components, 6 | ) 7 | from pairag.evaluation.dataset.state_manager import DatasetState 8 | 9 | 10 | def run_multimodal_evaluation_pipeline( 11 | config_file=None, 12 | oss_path=None, 13 | qca_dataset_path=None, 14 | data_path=None, 15 | pattern=None, 16 | exp_name="default", 17 | eval_model_llm_config=None, 18 | tested_multimodal_llm_config=None, 19 | ): 20 | config, mode, exist_flag = get_rag_config_and_mode(config_file, exp_name) 21 | assert mode == "image" 22 | data_loader, vector_index, query_engine = get_rag_components(config) 23 | multimodal_qca_generator, evaluator = get_multimodal_eval_components( 24 | config, 25 | exp_name, 26 | vector_index, 27 | query_engine, 28 | eval_model_llm_config, 29 | tested_multimodal_llm_config, 30 | qca_dataset_path, 31 | ) 32 | if qca_dataset_path: 33 | multimodal_qca_generator.state_manager.mark_completed(DatasetState.QCA) 34 | _ = asyncio.run( 35 | multimodal_qca_generator.agenerate_predicted_multimodal_dataset_only_via_vlm() 36 | ) 37 | 
asyncio.run(evaluator.aevaluation_for_response()) 38 | return 39 | 40 | assert (oss_path is not None) or ( 41 | data_path is not None 42 | ), "Must provide either local path or oss path." 43 | assert (oss_path is None) or ( 44 | data_path is None 45 | ), f"Can not provide both local path '{data_path}' and oss path '{oss_path}'." 46 | 47 | if not exist_flag: 48 | data_loader.load_data( 49 | file_path_or_directory=data_path, 50 | filter_pattern=pattern, 51 | oss_path=oss_path, 52 | from_oss=oss_path is not None, 53 | enable_raptor=False, 54 | ) 55 | 56 | _ = asyncio.run(multimodal_qca_generator.agenerate_all_dataset()) 57 | asyncio.run(evaluator.aevaluation_for_retrieval()) 58 | asyncio.run(evaluator.aevaluation_for_response()) 59 | -------------------------------------------------------------------------------- /src/pairag/evaluation/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def list_files_in_directory(directory_path): 5 | full_paths = [] 6 | for dirpath, _, filenames in os.walk(directory_path): 7 | for filename in filenames: 8 | full_path = os.path.join(dirpath, filename) 9 | full_paths.append(full_path) 10 | return full_paths 11 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/data_analysis_config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict, List, Literal 3 | from pydantic import BaseModel 4 | 5 | from pairag.web.ui_constants import ( 6 | NL2SQL_GENERAL_PROMPTS, 7 | SYN_GENERAL_PROMPTS, 8 | DA_SYSTEM_ROLE_PROMPT, 9 | ) 10 | 11 | 12 | class DataAnalysisType(str, Enum): 13 | """Data Analysis types.""" 14 | 15 | pandas = "pandas" 16 | sqlite = "sqlite" 17 | mysql = "mysql" 18 | 19 | 20 | class BaseAnalysisConfig(BaseModel): 21 | """Base class for data analysis config.""" 22 | 23 | type: DataAnalysisType 24 | # llm: OpenAICompatibleLlmConfig | None = OpenAICompatibleLlmConfig() 25 | model_id: str = "default" 26 | nl2sql_prompt: str = NL2SQL_GENERAL_PROMPTS 27 | synthesizer_prompt: str = SYN_GENERAL_PROMPTS 28 | system_role_prompt: str = DA_SYSTEM_ROLE_PROMPT 29 | 30 | 31 | class PandasAnalysisConfig(BaseAnalysisConfig): 32 | type: Literal[DataAnalysisType.pandas] = DataAnalysisType.pandas 33 | file_path: str = "./localdata/data_analysis/" 34 | 35 | 36 | class SqlAnalysisConfig(BaseAnalysisConfig): 37 | database: str 38 | tables: List[str] = [] 39 | descriptions: Dict[str, str] = {} 40 | # offline 41 | enable_enhanced_description: bool = False 42 | enable_db_history: bool = False 43 | enable_db_embedding: bool = False 44 | max_col_num: int = 100 45 | max_val_num: int = 10000 46 | # online 47 | enable_query_preprocessor: bool = False 48 | enable_db_preretriever: bool = False 49 | enable_db_selector: bool = False 50 | 51 | 52 | class SqliteAnalysisConfig(SqlAnalysisConfig): 53 | type: Literal[DataAnalysisType.sqlite] = DataAnalysisType.sqlite 54 | db_path: str 55 | 56 | 57 | class MysqlAnalysisConfig(SqlAnalysisConfig): 58 | type: Literal[DataAnalysisType.mysql] = DataAnalysisType.mysql 59 | user: str 60 | password: str 61 | host: str 62 | port: int 63 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/nl2sql/db_utils/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_DB_DESCRIPTION_PATH = ( 2 | 
"./localdata/data_analysis/nl2sql/description/db_structured_description.json" 3 | ) 4 | DEFAULT_DB_HISTORY_PATH = ( 5 | "./localdata/data_analysis/nl2sql/history/db_query_history.json" 6 | ) 7 | 8 | DESCRIPTION_STORAGE_PATH = "./localdata/data_analysis/nl2sql/storage/description_index" 9 | HISTORY_STORAGE_PATH = "./localdata/data_analysis/nl2sql/storage/history_index" 10 | VALUE_STORAGE_PATH = "./localdata/data_analysis/nl2sql/storage/value_index" 11 | VALUE_LSH_PATH = "./localdata/data_analysis/nl2sql/storage/value_lsh" 12 | 13 | EMBEDDING_DIM_DICT = {"bge-large-zh-v1.5": 1024, "bge-m3": 1024} 14 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/nl2sql/query_preprocessor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from pydantic import BaseModel 3 | 4 | from llama_index.core.llms.llm import LLM 5 | from llama_index.core import Settings 6 | from llama_index.core import BasePromptTemplate 7 | from llama_index.core.schema import QueryBundle 8 | 9 | from pairag.integrations.data_analysis.nl2sql.nl2sql_prompts import ( 10 | DEFAULT_KEYWORD_EXTRACTION_PROMPT, 11 | ) 12 | 13 | 14 | class QueryPreprocessor: 15 | """ 16 | 预处理自然语言查询,目前主要考虑关键词提取,query改写待定; 17 | """ 18 | 19 | def __init__( 20 | self, 21 | llm: Optional[LLM] = None, 22 | keyword_extraction_prompt: Optional[BasePromptTemplate] = None, 23 | ) -> None: 24 | self._llm = llm or Settings.llm 25 | self._keyword_extraction_prompt = ( 26 | keyword_extraction_prompt or DEFAULT_KEYWORD_EXTRACTION_PROMPT 27 | ) 28 | 29 | def extract_keywords(self, nl_query: QueryBundle) -> List[str]: 30 | keyword_list_obj = self._llm.structured_predict( 31 | output_cls=KeywordList, 32 | prompt=self._keyword_extraction_prompt, 33 | llm_kwargs={ 34 | "tool_choice": {"type": "function", "function": {"name": "KeywordList"}} 35 | }, 36 | query_str=nl_query.query_str, 37 | fewshot_examples="", 38 | ) 39 | # text_complection = LLMTextCompletionProgram.from_defaults( 40 | # output_cls=KeywordList, 41 | # prompt=self._keyword_extraction_prompt, 42 | # ) 43 | # keyword_list_obj = text_complection(query_str=nl_query.query_str, fewshot_examples="") 44 | 45 | keywords = keyword_list_obj.Keywords 46 | # later check if parser needed 47 | # keywords = parse(self, keywords) 48 | # logger.info(f"keyword_list: {keywords} extracted.") 49 | return keywords 50 | 51 | async def aextract_keywords(self, nl_query: QueryBundle) -> List[str]: 52 | keyword_list_obj = await self._llm.astructured_predict( 53 | output_cls=KeywordList, 54 | prompt=self._keyword_extraction_prompt, 55 | llm_kwargs={ 56 | "tool_choice": {"type": "function", "function": {"name": "KeywordList"}} 57 | }, 58 | query_str=nl_query.query_str, 59 | fewshot_examples="", 60 | ) 61 | keywords = keyword_list_obj.Keywords 62 | # later check if parser needed 63 | # keywords = parse(self, keywords) 64 | # logger.info(f"keyword_list: {keywords} extracted.") 65 | return keywords 66 | 67 | def transform_query(self, nl_query: QueryBundle) -> List[str]: 68 | # 考虑历史对话的query改写 69 | pass 70 | 71 | 72 | class KeywordList(BaseModel): 73 | """Data model for KeywordList.""" 74 | 75 | Keywords: List[str] 76 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/test/test_bird_schema_collector.py: -------------------------------------------------------------------------------- 1 | from 
pairag.integrations.data_analysis.data_analysis_config import ( 2 | SqliteAnalysisConfig, 3 | ) 4 | from pairag.integrations.data_analysis.text2sql.db_connector import ( 5 | SqliteConnector, 6 | ) 7 | from pairag.integrations.data_analysis.text2sql.db_info_collector import ( 8 | BirdSchemaCollector, 9 | ) 10 | 11 | 12 | sqlite_config = SqliteAnalysisConfig( 13 | db_path="/Users/chuyu/Documents/datasets/BIRD/dev_20240627/dev_databases/california_schools/", 14 | database="california_schools.sqlite", 15 | ) 16 | 17 | connector = SqliteConnector(sqlite_config) 18 | sql_databse = connector.connect() 19 | 20 | bird_schema_collector = BirdSchemaCollector( 21 | db_name="california_schools", 22 | sql_database=sql_databse, 23 | database_file_path="/Users/chuyu/Documents/datasets/BIRD/dev_20240627/dev_databases", 24 | ) 25 | 26 | bird_schema_collector.collect() 27 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/test/test_db_connector.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | 4 | from llama_index.core import SQLDatabase 5 | 6 | from pairag.integrations.data_analysis.data_analysis_config import ( 7 | SqliteAnalysisConfig, 8 | MysqlAnalysisConfig, 9 | ) 10 | from pairag.integrations.data_analysis.text2sql.db_connector import ( 11 | MysqlConnector, 12 | SqliteConnector, 13 | ) 14 | 15 | # 加载 .env 文件中的环境变量 16 | load_dotenv() 17 | 18 | # 获取环境变量中的 API 密钥 19 | host = os.getenv("host") 20 | port = os.getenv("port") 21 | user = os.getenv("user") 22 | password = os.getenv("password") 23 | database = os.getenv("database") 24 | 25 | mysql_config = MysqlAnalysisConfig( 26 | host=host, 27 | port=port, 28 | user=user, 29 | password=password, 30 | database=database, 31 | tables=["pets"], 32 | ) 33 | 34 | sqlite_config = SqliteAnalysisConfig( 35 | db_path="./tests/testdata/db_data/", 36 | database="pets.sqlite", 37 | ) 38 | 39 | 40 | def test_mysql_connector(): 41 | connector = MysqlConnector(mysql_config) 42 | assert connector._db_config.type == "mysql" 43 | sql_databse = connector.connect() 44 | assert isinstance(sql_databse, SQLDatabase) 45 | 46 | 47 | def test_sqlite_connector(): 48 | connector = SqliteConnector(sqlite_config) 49 | assert connector._db_config.type == "sqlite" 50 | sql_databse = connector.connect() 51 | assert isinstance(sql_databse, SQLDatabase) 52 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/test/test_db_info_collector.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | 4 | 5 | from pairag.integrations.data_analysis.data_analysis_config import ( 6 | MysqlAnalysisConfig, 7 | ) 8 | from pairag.integrations.data_analysis.text2sql.db_connector import ( 9 | MysqlConnector, 10 | ) 11 | from pairag.integrations.data_analysis.text2sql.db_info_collector import ( 12 | SchemaCollector, 13 | HistoryCollector, 14 | ValueCollector, 15 | ) 16 | 17 | 18 | # 加载 .env 文件中的环境变量 19 | load_dotenv() 20 | 21 | # 获取环境变量中的 API 密钥 22 | host = os.getenv("host") 23 | port = os.getenv("port") 24 | user = os.getenv("user") 25 | password = os.getenv("password") 26 | database = os.getenv("database") 27 | 28 | mysql_config = MysqlAnalysisConfig( 29 | host=host, 30 | port=port, 31 | user=user, 32 | password=password, 33 | database=database, 34 | # tables=["pets"], 35 | ) 36 | 37 | connector = 
MysqlConnector(mysql_config) 38 | sql_database = connector.connect() 39 | print("connector_info:", sql_database) 40 | 41 | 42 | def test_schema_processor(): 43 | schema_collector = SchemaCollector( 44 | db_name=mysql_config.database, sql_database=sql_database 45 | ) 46 | schema_description = schema_collector.collect() 47 | assert isinstance(schema_description, dict) 48 | 49 | 50 | def test_history_processor(): 51 | history_collector = HistoryCollector(db_name=mysql_config.database) 52 | query_history = history_collector.collect() 53 | assert isinstance(query_history, list) 54 | 55 | 56 | def test_value_processor(): 57 | value_collector = ValueCollector( 58 | db_name=mysql_config.database, sql_database=sql_database 59 | ) 60 | unique_values = value_collector.collect() 61 | assert isinstance(unique_values, dict) 62 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/test/test_index_retriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | import hashlib 4 | 5 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding 6 | from llama_index.core.schema import TextNode 7 | from llama_index.core.base.embeddings.base import BaseEmbedding 8 | 9 | from pairag.integrations.data_analysis.text2sql.db_info_retriever import ( 10 | SchemaRetriever, 11 | ) 12 | 13 | 14 | if os.path.exists("./model_repository/bge-m3"): 15 | embed_model_bge = HuggingFaceEmbedding( 16 | model_name="./model_repository/bge-m3", embed_batch_size=20 17 | ) 18 | else: 19 | embed_model_bge = None 20 | 21 | mock_nodes = [ 22 | TextNode( 23 | text="This is mock node 0.", 24 | metadata={"table_name": "table0", "column_name": "column0"}, 25 | ), 26 | TextNode( 27 | text="This is mock node 1.", 28 | metadata={"table_name": "table0", "column_name": "column1"}, 29 | ), 30 | TextNode( 31 | text="This is mock node 2.", 32 | metadata={"table_name": "table0", "column_name": "column2"}, 33 | ), 34 | ] 35 | 36 | 37 | def get_nodes_with_embeddings(embed_model: BaseEmbedding, nodes: List[TextNode]): 38 | # get embeddings 39 | embeddings = embed_model.get_text_embedding_batch( 40 | [node.get_content(metadata_mode="embed") for node in nodes] 41 | ) 42 | # update nodes embedding 43 | for node, embedding in zip(nodes, embeddings): 44 | node.embedding = embedding 45 | node_info_str = node.get_metadata_str() + node.get_text() 46 | node.id_ = hashlib.sha256(node_info_str.encode()).hexdigest() 47 | 48 | return nodes 49 | 50 | 51 | mock_nodes_with_embeddings = get_nodes_with_embeddings(embed_model_bge, mock_nodes) 52 | 53 | # 初始化检索器 54 | mock_retriever = SchemaRetriever( 55 | db_name="mock_db", 56 | embed_model=embed_model_bge, 57 | similarity_top_k=1, 58 | ) 59 | 60 | # 插入/更新nodes 61 | mock_retriever.get_index(mock_nodes_with_embeddings) 62 | 63 | 64 | mock_nodes_update = [ 65 | TextNode( 66 | text="This is mock node 0.", 67 | metadata={"table_name": "table0", "column_name": "column0"}, 68 | ), 69 | TextNode( 70 | text="This is mock node 01test.", 71 | metadata={"table_name": "table0", "column_name": "column1"}, 72 | ), 73 | TextNode( 74 | text="This is mock node 2.", 75 | metadata={"table_name": "table0", "column_name": "column2"}, 76 | ), 77 | ] 78 | 79 | mock_nodes_with_embeddings = get_nodes_with_embeddings( 80 | embed_model_bge, mock_nodes_update 81 | ) 82 | 83 | # 插入/更新nodes 84 | mock_retriever.get_index(mock_nodes_with_embeddings) 85 | # new_retriever = 
mock_retriever._schema_index.as_retriever() 86 | 87 | res = mock_retriever.retrieve_nodes(query="what is mock node 2?") 88 | 89 | print("retrieve result:", res) 90 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/test/test_query_processor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | 4 | from llama_index.llms.openai_like import OpenAILike 5 | from llama_index.core.schema import QueryBundle 6 | from llama_index.llms.dashscope import DashScope, DashScopeGenerationModels 7 | from llama_index.core import Settings 8 | 9 | from pairag.integrations.data_analysis.text2sql.query_processor import KeywordExtractor 10 | from pairag.integrations.llms.pai.llm_config import ( 11 | OpenAICompatibleLlmConfig, 12 | ) 13 | 14 | 15 | # 加载 .env 文件中的环境变量 16 | load_dotenv() 17 | 18 | llm_ds = DashScope( 19 | model_name=DashScopeGenerationModels.QWEN_MAX, 20 | api_key=os.getenv("DASHSCOPE_API_KEY"), 21 | temperature=0.1, 22 | max_tokens=2048, 23 | ) 24 | print("DashScope:", llm_ds.metadata.is_function_calling_model) 25 | 26 | llm_config = OpenAICompatibleLlmConfig(model="qwen-max") 27 | 28 | llm_ol = OpenAILike( 29 | model=llm_config.model, 30 | api_base=llm_config.base_url, 31 | temperature=llm_config.temperature, 32 | system_prompt=llm_config.system_prompt, 33 | is_chat_model=True, 34 | api_key=llm_config.api_key or os.environ.get("DASHSCOPE_API_KEY"), 35 | max_tokens=llm_config.max_tokens, 36 | reuse_client=False, 37 | is_function_calling_model=True, 38 | ) 39 | 40 | Settings.llm = llm_ol 41 | query = "有猫的学生有多少?" 42 | qp = KeywordExtractor() 43 | result = qp.process(QueryBundle(query)) 44 | print(result) 45 | 46 | 47 | # # def test_query_processor(): 48 | # # query = "有猫的学生有多少?" 
49 | # # qp = KeywordExtractor() 50 | # # keywords = qp.process(QueryBundle(query)) 51 | # # assert isinstance(keywords, list) 52 | # # assert len(keywords) > 0 53 | # # assert "猫" in keywords 54 | # # assert "学生" in keywords 55 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/text2sql/evaluations/base_evaluator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class SQLEvaluator(ABC): 5 | """生成SQL评估接口""" 6 | 7 | @abstractmethod 8 | async def abatch_loader( 9 | self, 10 | ): 11 | pass 12 | 13 | @abstractmethod 14 | async def abatch_query(self, nums: int): 15 | pass 16 | 17 | @abstractmethod 18 | def batch_evaluate( 19 | self, gold_file: str, predicted_file: str, evaluation_type: str, **args 20 | ): 21 | pass 22 | -------------------------------------------------------------------------------- /src/pairag/integrations/data_analysis/text2sql/utils/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_DB_DESCRIPTION_PATH = "./localdata/data_analysis/text2sql/description" 2 | DEFAULT_DB_DESCRIPTION_NAME = "db_structured_description.json" 3 | DEFAULT_DESCRIPTION_FOLDER_PATH = "./localdata/data_analysis/text2sql/input_description" 4 | 5 | DEFAULT_DB_HISTORY_PATH = "./localdata/data_analysis/text2sql/history" 6 | DEFAULT_DB_HISTORY_NAME = "db_query_history.json" 7 | 8 | DEFAULT_DB_VALUE_PATH = "./localdata/data_analysis/text2sql/value" 9 | DEFAULT_DB_VALUE_NAME = "db_unique_value.json" 10 | 11 | DEFAULT_TABLE_COMMENT_PATH = "./localdata/data_analysis/text2sql" 12 | DEFAULT_TABLE_COMMENT_NAME = "table_comment.json" 13 | 14 | DESCRIPTION_STORAGE_PATH = ( 15 | "./localdata/data_analysis/text2sql/storage/description_index" 16 | ) 17 | HISTORY_STORAGE_PATH = "./localdata/data_analysis/text2sql/storage/history_index" 18 | VALUE_STORAGE_PATH = "./localdata/data_analysis/text2sql/storage/value_index" 19 | VALUE_LSH_PATH = "./localdata/data_analysis/text2sql/storage/value_lsh" 20 | 21 | EMBEDDING_DIM_DICT = {"bge-large-zh-v1.5": 1024, "bge-m3": 1024} 22 | -------------------------------------------------------------------------------- /src/pairag/integrations/embeddings/pai/pai_embedding_config.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from pydantic import BaseModel, Field 3 | from enum import Enum 4 | from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE 5 | import os 6 | 7 | DEFAULT_HF_EMBED_MODEL = "bge-m3" 8 | 9 | 10 | class SupportedEmbedType(str, Enum): 11 | dashscope = "dashscope" 12 | openai = "openai" 13 | huggingface = "huggingface" 14 | 15 | 16 | class PaiBaseEmbeddingConfig(BaseModel): 17 | source: Literal[SupportedEmbedType.huggingface] = SupportedEmbedType.huggingface 18 | model: str 19 | embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE 20 | enable_sparse: bool = False 21 | 22 | class Config: 23 | frozen = True 24 | 25 | @classmethod 26 | def get_subclasses(cls): 27 | return tuple(cls.__subclasses__()) 28 | 29 | @classmethod 30 | def get_type(cls): 31 | return cls.model_fields["source"].default 32 | 33 | 34 | class DashScopeEmbeddingConfig(PaiBaseEmbeddingConfig): 35 | source: Literal[SupportedEmbedType.dashscope] = SupportedEmbedType.dashscope 36 | model: str | None = "text-embedding-v2" # use default 37 | api_key: str | None = Field(default=os.getenv("DASHSCOPE_API_KEY")) # use default 38 | 39 
| 40 | class OpenAIEmbeddingConfig(PaiBaseEmbeddingConfig): 41 | source: Literal[SupportedEmbedType.openai] = SupportedEmbedType.openai 42 | model: str | None = None # use default 43 | api_key: str | None = None # use default 44 | api_base: str | None = None # use default 45 | 46 | 47 | class HuggingFaceEmbeddingConfig(PaiBaseEmbeddingConfig): 48 | source: Literal[SupportedEmbedType.huggingface] = SupportedEmbedType.huggingface 49 | model: str | None = DEFAULT_HF_EMBED_MODEL 50 | 51 | 52 | SupporttedEmbeddingClsMap = { 53 | cls.get_type(): cls for cls in PaiBaseEmbeddingConfig.get_subclasses() 54 | } 55 | 56 | 57 | def parse_embed_config(config_data): 58 | if "source" not in config_data: 59 | raise ValueError("Embedding config must contain 'source' field") 60 | 61 | embedding_cls = SupporttedEmbeddingClsMap.get(config_data["source"].lower()) 62 | if embedding_cls is None: 63 | raise ValueError(f"Unsupported embedding source: {config_data['source']}") 64 | 65 | return embedding_cls(**config_data) 66 | 67 | 68 | if __name__ == "__main__": 69 | embedding_config_data = {"source": "Openai", "model": "gpt-1", "api_key": None} 70 | 71 | print(parse_embed_config(embedding_config_data)) 72 | -------------------------------------------------------------------------------- /src/pairag/integrations/embeddings/readme.md: -------------------------------------------------------------------------------- 1 | # Custom embedding implementations 2 | 3 | Stay tuned. 4 | -------------------------------------------------------------------------------- /src/pairag/integrations/llms/pai/open_ai_alike_multi_modal.py: -------------------------------------------------------------------------------- 1 | from llama_index.multi_modal_llms.openai import OpenAIMultiModal 2 | from typing import Dict, Any 3 | 4 | 5 | class OpenAIAlikeMultiModal(OpenAIMultiModal): 6 | def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]: 7 | base_kwargs = {"model": self.model, "temperature": self.temperature, **kwargs} 8 | if self.max_new_tokens is not None: 9 | # If max_tokens is None, don't include in the payload: 10 | # https://platform.openai.com/docs/api-reference/chat 11 | # https://platform.openai.com/docs/api-reference/completions 12 | base_kwargs["max_tokens"] = self.max_new_tokens 13 | return {**base_kwargs, **self.additional_kwargs} 14 | -------------------------------------------------------------------------------- /src/pairag/integrations/llms/readme.md: -------------------------------------------------------------------------------- 1 | # Custom LLM implementations 2 | -------------------------------------------------------------------------------- /src/pairag/integrations/query_transform/intent_models.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import List, Optional 3 | from pydantic import BaseModel 4 | from openai.types.completion_usage import CompletionUsage 5 | 6 | 7 | class ChatToolType(str, Enum): 8 | SEARCH_WEB = "search_web" 9 | CHAT_NEWS = "chat_news" 10 | CHAT_KNOWLEDGEBASE = "chat_knowledgebase" 11 | CHAT_DB = "chat_db" 12 | CHAT_LLM = "chat_llm" 13 | 14 | 15 | class ChatIntentType(str, Enum): 16 | SEARCH_WEB = "search_web" # search web 17 | CHAT_LLM = "chat_llm" # llm chat 18 | LIST_NEWS = "list_news" # list news 19 | CHAT_NEWS = "chat_news" # chat news 20 | CHAT_NEWS_LLM = "chat_news_llm" # chat news only by llm 21 | CHAT_KNOWLEDGEBASE = "chat_knowledgebase" 22 | CHAT_DB = "chat_db" # chat sql 23 | 24 | 25 | class 
IntentResult(BaseModel): 26 | intent: ChatIntentType = ChatIntentType.CHAT_LLM 27 | query_str: str = None 28 | news_topics: Optional[List[str]] = None 29 | token_usage: CompletionUsage = CompletionUsage( 30 | completion_tokens=0, prompt_tokens=0, total_tokens=0 31 | ) 32 | -------------------------------------------------------------------------------- /src/pairag/integrations/search/search_config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from enum import Enum 3 | from typing import Literal 4 | 5 | DEFAULT_ALIYUN_SEARCH_ENDPOINT = "iqs.cn-zhangjiakou.aliyuncs.com" 6 | DEFAULT_GOOGLE_SEARCH_ENDPOINT = "https://serpapi.com/search" 7 | DEFAULT_SEARCH_COUNT = 10 8 | DEFAULT_SEARCH_QA_PROMPT_TEMPLATE = """ 9 | 你的目标是根据搜索结果提供准确、有用且易于理解的信息。 10 | # 任务要求: 11 | - 请严格根据提供的参考内容回答问题,并非所有参考内容都与用户的问题密切相关,你需要结合问题,对参考内容进行甄别、筛选。仅参考与问题相关的内容并忽略所有不相关的信息。 12 | - 如果参考内容中没有相关信息或与问题无关,请基于你的已有知识进行回答。 13 | - 确保答案准确、简洁,并且使用与用户提问相同的语种。 14 | - 在回答过程中,请避免使用“从参考内容得出”、“从材料得出”、“根据参考内容”等措辞。 15 | - 保持回答的专业性和友好性。 16 | - 如果需要更多信息来更好地回答问题,请礼貌地询问。 17 | - 对于复杂的问题,尽量简化解释,使信息易于理解。如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。 18 | - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。 19 | - 除非用户要求,否则请保持输出语种与用户输入问题语种的一致性。 20 | - 对于涉及不安全/不道德/敏感/色情/暴力/赌博/违法等行为的问题,请明确拒绝提供所要求的信息,并简单解释为什么这样的请求不能被满足。 21 | - 你知道今天的日期是{current_datetime},但你不会主动在回复开头提到日期信息。 22 | 23 | # 以下内容是基于用户发送的消息的搜索/查询结果: 24 | {context_str} 25 | 26 | # 以下内容是用户问答历史记录: 27 | {history_str} 28 | 29 | # 以下内容是用户消息: 30 | {query_str} 31 | """ 32 | 33 | 34 | class SupportedSearchType(str, Enum): 35 | bing = "bing" 36 | aliyun = "aliyun" 37 | google = "google" 38 | 39 | 40 | class BaseSearchConfig(BaseModel): 41 | source: SupportedSearchType 42 | search_count: int = DEFAULT_SEARCH_COUNT 43 | search_qa_prompt_template: str = DEFAULT_SEARCH_QA_PROMPT_TEMPLATE 44 | 45 | class Config: 46 | frozen = True 47 | 48 | @classmethod 49 | def get_subclasses(cls): 50 | return tuple(cls.__subclasses__()) 51 | 52 | @classmethod 53 | def get_type(cls): 54 | return cls.model_fields["source"].default 55 | 56 | 57 | class BingSearchConfig(BaseSearchConfig): 58 | source: Literal[SupportedSearchType.bing] = SupportedSearchType.bing 59 | search_api_key: str | None = None 60 | search_lang: str = "zh-CN" 61 | 62 | 63 | class AliyunSearchConfig(BaseSearchConfig): 64 | source: Literal[SupportedSearchType.aliyun] = SupportedSearchType.aliyun 65 | endpoint: str = DEFAULT_ALIYUN_SEARCH_ENDPOINT 66 | access_key_id: str | None = None 67 | access_key_secret: str | None = None 68 | 69 | 70 | class GoogleSearchConfig(BaseSearchConfig): 71 | source: Literal[SupportedSearchType.google] = SupportedSearchType.google 72 | serpapi_key: str | None = None 73 | search_lang: str = "zh-CN" 74 | -------------------------------------------------------------------------------- /src/pairag/integrations/trace/trace_config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class TraceConfig(BaseModel): 5 | service_name: str | None = None 6 | token: str | None = None 7 | endpoint: str | None = None 8 | 9 | def is_enabled(self) -> bool: 10 | return self.service_name and self.token and self.endpoint 11 | -------------------------------------------------------------------------------- /src/pairag/integrations/vector_stores/elasticsearch/elasticsearch_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 
| 3 | from elasticsearch import AsyncElasticsearch, Elasticsearch 4 | 5 | 6 | def get_user_agent() -> str: 7 | """Get user agent for Elasticsearch client.""" 8 | import llama_index.core 9 | 10 | version = getattr(llama_index.core, "__version__", "") 11 | return f"llama_index-py-vs/{version}" 12 | 13 | 14 | def get_elasticsearch_client( 15 | url: Optional[str] = None, 16 | cloud_id: Optional[str] = None, 17 | api_key: Optional[str] = None, 18 | username: Optional[str] = None, 19 | password: Optional[str] = None, 20 | ) -> AsyncElasticsearch: 21 | if url and cloud_id: 22 | raise ValueError( 23 | "Both es_url and cloud_id are defined. Please provide only one." 24 | ) 25 | 26 | connection_params: Dict[str, Any] = {} 27 | 28 | if url: 29 | connection_params["hosts"] = [url] 30 | elif cloud_id: 31 | connection_params["cloud_id"] = cloud_id 32 | else: 33 | raise ValueError("Please provide either elasticsearch_url or cloud_id.") 34 | 35 | if api_key: 36 | connection_params["api_key"] = api_key 37 | elif username and password: 38 | connection_params["basic_auth"] = (username, password) 39 | 40 | sync_es_client = Elasticsearch( 41 | **connection_params, headers={"user-agent": get_user_agent()} 42 | ) 43 | async_es_client = AsyncElasticsearch( 44 | **connection_params, 45 | headers={"user-agent": get_user_agent()}, 46 | request_timeout=60, 47 | retry_on_timeout=True, 48 | max_retries=2, 49 | ) 50 | 51 | sync_es_client.info() # use sync client so don't have to 'await' to just get info 52 | 53 | return async_es_client 54 | -------------------------------------------------------------------------------- /src/pairag/knowledgebase/index/pai/utils/sparse_embed_function.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | from typing import List, Optional, Dict 4 | from pairag.utils.constants import DEFAULT_MODEL_DIR 5 | 6 | from loguru import logger 7 | 8 | MODEL_NAME = "bge-m3" 9 | 10 | 11 | class BaseSparseEmbeddingFunction(ABC): 12 | @abstractmethod 13 | def encode_queries(self, queries: List[str]) -> List[Dict[int, float]]: 14 | pass 15 | 16 | @abstractmethod 17 | def encode_documents(self, documents: List[str]) -> List[Dict[int, float]]: 18 | pass 19 | 20 | 21 | class BGEM3SparseEmbeddingFunction(BaseSparseEmbeddingFunction): 22 | def __init__(self, model_name_or_path: Optional[str] = None) -> None: 23 | try: 24 | from FlagEmbedding import BGEM3FlagModel 25 | 26 | model_dir = os.getenv("PAIRAG_MODEL_DIR", DEFAULT_MODEL_DIR) 27 | 28 | self.model = BGEM3FlagModel( 29 | model_name_or_path=os.path.join( 30 | model_name_or_path or model_dir, MODEL_NAME 31 | ), 32 | use_fp16=False, 33 | ) 34 | except Exception: 35 | error_info = ( 36 | "Cannot import BGEM3FlagModel from FlagEmbedding. It seems it is not installed. 
" 37 | "Please install it using:\n" 38 | "pip install FlagEmbedding\n", 39 | "error_info", 40 | ) 41 | 42 | logger.error(error_info) 43 | raise 44 | 45 | def encode_queries(self, queries: List[str]): 46 | outputs = self.model.encode( 47 | queries, return_dense=False, return_sparse=True, return_colbert_vecs=False 48 | )["lexical_weights"] 49 | return [self._to_standard_dict(output) for output in outputs] 50 | 51 | def encode_documents(self, documents: List[str]): 52 | outputs = self.model.encode( 53 | documents, return_dense=False, return_sparse=True, return_colbert_vecs=False 54 | )["lexical_weights"] 55 | return [self._to_standard_dict(output) for output in outputs] 56 | 57 | def _to_standard_dict(self, raw_output): 58 | result = {} 59 | for k in raw_output: 60 | result[int(k)] = raw_output[k] 61 | return result 62 | 63 | 64 | def get_default_sparse_embedding_function() -> BGEM3SparseEmbeddingFunction: 65 | return BGEM3SparseEmbeddingFunction() 66 | -------------------------------------------------------------------------------- /src/pairag/knowledgebase/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, model_validator 2 | from typing import Annotated, Dict, Union 3 | 4 | from pairag.integrations.embeddings.pai.pai_embedding_config import ( 5 | PaiBaseEmbeddingConfig, 6 | ) 7 | from pairag.knowledgebase.index.pai.vector_store_config import BaseVectorStoreConfig 8 | from pairag.file.nodeparsers.pai.pai_node_parser import NodeParserConfig 9 | from pairag.integrations.synthesizer.prompt_templates import ( 10 | DEFAULT_CUSTOM_PROMPT_TEMPLATE, 11 | DEFAULT_SYSTEM_ROLE_TEMPLATE, 12 | ) 13 | from pairag.utils.constants import DEFAULT_KNOWLEDGEBASE_NAME 14 | 15 | from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K 16 | from pairag.integrations.postprocessor.pai.pai_postprocessor import ( 17 | DEFAULT_RERANK_MODEL, 18 | DEFAULT_RERANK_SIMILARITY_THRESHOLD, 19 | DEFAULT_RERANK_TOP_N, 20 | DEFAULT_SIMILARITY_THRESHOLD, 21 | ) 22 | 23 | 24 | class KnowledgeBase(BaseModel): 25 | name: str = Field( 26 | default=DEFAULT_KNOWLEDGEBASE_NAME, 27 | description="Knowledgebase name.", 28 | pattern=r"^[0-9a-zA-Z_-]{3, 20}$", 29 | ) 30 | 31 | vector_store_config: Annotated[ 32 | Union[BaseVectorStoreConfig.get_subclasses()], Field(discriminator="type") 33 | ] 34 | node_parser_config: NodeParserConfig = Field(default_factory=NodeParserConfig) 35 | embedding_config: Annotated[ 36 | Union[PaiBaseEmbeddingConfig.get_subclasses()], Field(discriminator="source") 37 | ] 38 | retrieval_settings: Dict = Field(default_factory=dict) 39 | qa_prompt_templates: Dict = { 40 | "system_prompt_template": DEFAULT_SYSTEM_ROLE_TEMPLATE, 41 | "task_prompt_template": DEFAULT_CUSTOM_PROMPT_TEMPLATE, 42 | } 43 | 44 | @model_validator(mode="before") 45 | def preprocess(cls, values: Dict) -> Dict: 46 | if "index_name" in values: 47 | values["name"] = values["index_name"] 48 | return values 49 | 50 | def model_post_init(self, context): 51 | # 修改retrieval_settings的key 52 | default_retrieval_settings = { 53 | "retrieval_mode": "default", 54 | "similarity_top_k": DEFAULT_SIMILARITY_TOP_K, 55 | "reranker_type": "no-reranker", 56 | "reranker_model": DEFAULT_RERANK_MODEL, 57 | "similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD, 58 | "reranker_similarity_threshold": DEFAULT_RERANK_SIMILARITY_THRESHOLD, 59 | "reranker_similarity_top_k": DEFAULT_RERANK_TOP_N, 60 | } 61 | default_retrieval_settings.update(self.retrieval_settings) 62 | 
self.retrieval_settings = default_retrieval_settings 63 | -------------------------------------------------------------------------------- /src/pairag/knowledgebase/utils/knowledgebase_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from loguru import logger 4 | from pairag.utils.constants import DEFAULT_KNOWLEDGEBASE_PATH 5 | 6 | 7 | def delete_file(file_path): 8 | try: 9 | os.remove(file_path) 10 | logger.info(f"File {file_path} successfully deleted.") 11 | except FileNotFoundError: 12 | logger.error(f"File {file_path} does not exist.") 13 | except Exception as e: 14 | logger.error(f"Error deleting file {file_path}: {e}") 15 | 16 | 17 | def delete_dir(folder_path): 18 | try: 19 | shutil.rmtree(folder_path) 20 | logger.info(f"Folder {folder_path} and its contents successfully deleted.") 21 | except FileNotFoundError: 22 | logger.error(f"Folder {folder_path} does not exist.") 23 | except Exception as e: 24 | logger.error(f"Error deleting folder {folder_path}: {e}") 25 | 26 | 27 | def delete_knowledgebase_dir(index_name): 28 | destination_folder = os.path.join(DEFAULT_KNOWLEDGEBASE_PATH, index_name) 29 | delete_dir(destination_folder) 30 | 31 | 32 | def delete_default_knowledgebase_dir(): 33 | destination_folder = os.path.join(DEFAULT_KNOWLEDGEBASE_PATH, "default") 34 | delete_dir(destination_folder) 35 | 36 | 37 | def write_markdown_to_parse_dir(md_content, file_name, parse_dir): 38 | destination_md_file = os.path.join(parse_dir, f"{file_name}.md") 39 | try: 40 | with open(destination_md_file, "w", encoding="utf-8") as md_file: 41 | md_file.write(md_content) 42 | print(f"成功写入{destination_md_file}") 43 | except IOError as e: 44 | print(f"写入文件时出错: {e}") 45 | 46 | 47 | def copy_original_files_to_parse_dir(file_path, parse_dir): 48 | try: 49 | if os.path.isfile(file_path): 50 | shutil.copy(file_path, parse_dir) 51 | print(f"已复制: {file_path} 到 {parse_dir}") 52 | else: 53 | print(f"源文件不存在: {file_path}") 54 | except Exception as e: 55 | print(f"复制文件时出错: {file_path} -> {e}") 56 | -------------------------------------------------------------------------------- /src/pairag/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/src/pairag/utils/__init__.py -------------------------------------------------------------------------------- /src/pairag/utils/constants.py: -------------------------------------------------------------------------------- 1 | """Set of constants of modules.""" 2 | 3 | import os 4 | 5 | 6 | def try_get_int_env(key, default_value=None): 7 | """ 8 | Retrieves an integer from an environment variable. 
9 | """ 10 | 11 | value_str = os.getenv(key) 12 | if value_str is None: 13 | if default_value is not None: 14 | return default_value 15 | else: 16 | return None 17 | try: 18 | return int(value_str) 19 | except ValueError: 20 | return None 21 | 22 | 23 | EAS_DEFAULT_MODEL_DIR = "/huggingface/pai_rag_model_repository_01" 24 | if not os.path.exists(EAS_DEFAULT_MODEL_DIR): 25 | DEFAULT_MODEL_DIR = "./model_repository" 26 | else: 27 | DEFAULT_MODEL_DIR = EAS_DEFAULT_MODEL_DIR 28 | 29 | OSS_URL = "https://pai-rag-bj.oss-cn-beijing.aliyuncs.com/model_repository/model_config_1.1.0.json" 30 | 31 | DEFAULT_DATAFILE_DIR = "./data" 32 | 33 | DEFAULT_DASHSCOPE_EMBEDDING_MODEL = "text-embedding-v2" 34 | 35 | 36 | DEFAULT_TASK_FILE = "localdata/ingestion__task__summary.json" 37 | 38 | DEFAULT_INDEX_FILE = "localdata/default__rag__index.json" 39 | DEFAULT_INDEX_NAME = "default" 40 | DEFAULT_INDEX_NAME_OLD = "default_index" 41 | 42 | DEFAULT_KNOWLEDGEBASE_PATH = "localdata/knowledgebase" 43 | DEFAULT_KNOWLEDGEBASE_NAME = "default" 44 | DEFAULT_KNOWLEDGEBASE_NAME_OLD = "default_index" 45 | DEFAULT_KNOWLEDGEBASE_FILE = "localdata/default__rag__index.json" 46 | DEFAULT_DOC_STORE_NAME = "default__knowledge__docs.json" 47 | DEFAULT_MAX_KNOWLEDGEBASE_COUNT = try_get_int_env( 48 | "DEFAULT_MAX_KNOWLEDGEBASE_COUNT", 3000 49 | ) 50 | DEFAULT_MAX_FILE_TASK_COUNT = try_get_int_env("DEFAULT_MAX_FILE_TASK_COUNT", 10000) 51 | -------------------------------------------------------------------------------- /src/pairag/utils/cuda_utils.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import torch 3 | import os 4 | 5 | 6 | USE_CUDA = os.environ.get("USE_CUDA", "false") 7 | 8 | 9 | def should_use_cuda(): 10 | if not torch.cuda.is_available(): 11 | return False 12 | 13 | if USE_CUDA.lower() == "true" or USE_CUDA == "1": 14 | return True 15 | else: 16 | return False 17 | 18 | 19 | def infer_cuda_device() -> str: 20 | if should_use_cuda(): 21 | logger.info("Using cuda device.") 22 | return "cuda" 23 | else: 24 | logger.info( 25 | "Will not use CUDA device acceleration. If you want to use cuda, please set the environment variable USE_CUDA=1." 
26 | ) 27 | return "cpu" 28 | -------------------------------------------------------------------------------- /src/pairag/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import logging 3 | from loguru import logger 4 | from tenacity import ( 5 | before_sleep_log, 6 | retry, 7 | stop_after_attempt, 8 | wait_fixed, 9 | retry_if_exception_type, 10 | ) 11 | import os 12 | 13 | 14 | def get_modified_time(file_path): 15 | state = os.stat(file_path) 16 | return state.st_mtime 17 | 18 | 19 | def generate_text_md5(text): 20 | text_md5 = hashlib.md5() # Create an MD5 hash object 21 | # Encode the file path string to bytes and update the hash 22 | text_md5.update(text.encode("utf-8")) 23 | return text_md5.hexdigest() 24 | 25 | 26 | # 读取文件的retry机制 27 | @retry( 28 | wait=wait_fixed(1), 29 | stop=stop_after_attempt(10), 30 | retry=retry_if_exception_type(OSError), 31 | before_sleep=before_sleep_log(logger, logging.INFO), 32 | ) 33 | def generate_file_md5(file_path): 34 | with open(file_path, "rb") as file: 35 | file_content_md5 = hashlib.md5() # Create an MD5 hash object 36 | 37 | while chunk := file.read(8192): # Read the file in 8 KB chunks 38 | file_content_md5.update(chunk) # Update the hash with the chunk 39 | 40 | return file_content_md5.hexdigest() 41 | 42 | 43 | def generate_md5(file_path): 44 | """Generate MD5 hash of the content of the specified file.""" 45 | try: 46 | return ( 47 | generate_text_md5(file_path), 48 | generate_file_md5(file_path), 49 | ) # Return the hexadecimal representation of the hash 50 | except Exception as ex: 51 | logger.error(f"Error generating md5 for file '{file_path}'.") 52 | raise ex 53 | -------------------------------------------------------------------------------- /src/pairag/utils/format_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from loguru import logger 4 | 5 | 6 | class InterceptHandler(logging.Handler): 7 | def emit(self, record: logging.LogRecord) -> None: 8 | level: str | int 9 | # 尝试获取与标准 logging 等级相对应的 Loguru 日志等级 10 | try: 11 | level = logger.level(record.levelname).name 12 | except ValueError: 13 | # 如果找不到对应的 Loguru 等级,则使用原始的数字等级 14 | level = record.levelno 15 | 16 | # 探测调用日志的代码位置 17 | frame, depth = logging.currentframe(), 2 18 | while frame.f_code.co_filename == logging.__file__: 19 | frame = frame.f_back 20 | depth += 1 21 | # 使用 Loguru 记录日志信息,保持调用栈的深度和异常信息 22 | logger.opt( 23 | depth=depth, 24 | exception=record.exc_info, 25 | ).log(level, record.getMessage()) 26 | 27 | 28 | def format_logging(): 29 | logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO, force=True) 30 | logger.remove(0) 31 | logger.add( 32 | sys.stderr, 33 | format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {process} | {level: <8} | {name}:{function}:{line} - {message}", 34 | ) 35 | -------------------------------------------------------------------------------- /src/pairag/utils/json_parser.py: -------------------------------------------------------------------------------- 1 | import json 2 | from loguru import logger 3 | 4 | 5 | def parse_json_from_code_block_str(input_str): 6 | start = input_str.find("{") 7 | end = input_str.find("}", start + 1) 8 | if start != -1 and end != -1: 9 | content = input_str[start : end + 1] 10 | try: 11 | data = json.loads(content) 12 | logger.debug(f"解析后的 JSON 对象:{data}") 13 | return data 14 | except json.JSONDecodeError as e: 15 | logger.debug("JSON 解码错误:", e) 16 | return 
json.loads('{ "queries": [] }') 17 | else: 18 | logger.debug("未找到有效的JSON对象。") 19 | return json.loads('{ "queries": [] }') 20 | -------------------------------------------------------------------------------- /src/pairag/utils/score_utils.py: -------------------------------------------------------------------------------- 1 | def normalize_cosine_similarity_score(sim_score): 2 | return round((1 + sim_score) / 2, 6) 3 | -------------------------------------------------------------------------------- /src/pairag/utils/time_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | def get_current_time_str() -> str: 5 | return datetime.now().strftime("%Y-%m-%d %H:%M:%S") 6 | 7 | 8 | def get_prompt_current_time_str() -> str: 9 | return datetime.now().strftime("%Y年%m月%d日 %H:%M:%S") 10 | -------------------------------------------------------------------------------- /src/pairag/web/element_manager.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict, Generator, List, Tuple 2 | 3 | if TYPE_CHECKING: 4 | from gradio.components import Component 5 | 6 | 7 | class ElementManager: 8 | def __init__(self) -> None: 9 | self._id_to_elem: Dict[str, "Component"] = {} 10 | self._elem_to_id: Dict["Component", str] = {} 11 | 12 | def add_elems(self, elem_dict: Dict[str, "Component"]) -> None: 13 | r""" 14 | Adds elements to manager. 15 | """ 16 | for elem_id, elem in elem_dict.items(): 17 | self._id_to_elem[elem_id] = elem 18 | self._elem_to_id[elem] = elem_id 19 | 20 | def get_elem_list(self) -> List["Component"]: 21 | r""" 22 | Returns the list of all elements. 23 | """ 24 | return list(self._id_to_elem.values()) 25 | 26 | def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]: 27 | r""" 28 | Returns an iterator over all elements with their names. 29 | """ 30 | for elem_id, elem in self._id_to_elem.items(): 31 | yield elem_id.split(".")[-1], elem 32 | 33 | def get_elem_by_id(self, elem_id: str) -> "Component": 34 | r""" 35 | Gets element by id. 36 | 37 | Example: top.lang, train.dataset 38 | """ 39 | return self._id_to_elem[elem_id] 40 | 41 | def get_id_by_elem(self, elem: "Component") -> str: 42 | r""" 43 | Gets id by element. 
44 | """ 45 | return self._elem_to_id[elem] 46 | 47 | 48 | elem_manager = ElementManager() 49 | -------------------------------------------------------------------------------- /src/pairag/web/filebrowser/constants.py: -------------------------------------------------------------------------------- 1 | from pairag.utils.constants import try_get_int_env 2 | 3 | DEFAULT_FILE_BROWSER_PORT = try_get_int_env("DEFAULT_FILE_BROWSER_PORT", 8012) 4 | 5 | FILEBROWSER_PREFIX = "/filebrowser/api/resources" 6 | FILEBROWSER_PREFIX_LEN = len(FILEBROWSER_PREFIX) 7 | -------------------------------------------------------------------------------- /src/pairag/web/filebrowser/request_utils.py: -------------------------------------------------------------------------------- 1 | from fastapi import Request, Response 2 | from pairag.web.filebrowser.constants import ( 3 | DEFAULT_FILE_BROWSER_PORT, 4 | ) 5 | import aiohttp 6 | from loguru import logger 7 | 8 | 9 | def clean_headers(headers: dict, keys): 10 | for k in keys: 11 | headers.pop(k, None) 12 | headers.pop(str.lower(k), None) 13 | return headers 14 | 15 | 16 | async def sender_data(req: Request): 17 | async for chunk in req.stream(): 18 | yield chunk 19 | 20 | 21 | async def postprocess_middleware_to_filebrowser(session, request, url): 22 | async with session.request( 23 | request.method, 24 | str(url), 25 | headers=clean_headers(dict(request.headers), ["Transfer-Encoding"]), 26 | params=str(request.path_params), 27 | data=sender_data(request), 28 | allow_redirects=False, 29 | ) as resp: 30 | content = await resp.content.read() 31 | return Response( 32 | content=content, 33 | headers=clean_headers( 34 | dict(resp.headers), ["Content-Encoding", "Content-Length"] 35 | ), 36 | status_code=resp.status, 37 | ) 38 | 39 | 40 | async def postprocess_middleware(request, call_next): 41 | logger.debug(f"request_path: {request.url.path} , method: {request.method}") 42 | if "/filebrowser" in request.url.path: 43 | url = request.url.replace( 44 | scheme="http", hostname="localhost", port=DEFAULT_FILE_BROWSER_PORT 45 | ) 46 | async with aiohttp.ClientSession( 47 | timeout=aiohttp.ClientTimeout( 48 | total=500 * 60, 49 | connect=500 * 60, 50 | ) 51 | ) as session: 52 | return await postprocess_middleware_to_filebrowser(session, request, url) 53 | else: 54 | response = await call_next(request) 55 | return response 56 | -------------------------------------------------------------------------------- /src/pairag/web/tabs/history_tab.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import pandas as pd 3 | from pairag.web.rag_local_client import rag_client 4 | 5 | 6 | def refresh_upload_history(knowledgebase_name): 7 | upload_jobs = rag_client.get_upload_history(knowledgebase_name) 8 | if upload_jobs: 9 | history_data = pd.DataFrame(upload_jobs) 10 | history_data["文件名"] = history_data["file_name"] 11 | history_data["操作"] = history_data["operation"] 12 | history_data["上传状态"] = history_data["status"] 13 | history_data["更新时间"] = history_data["last_modified_time"] 14 | history_data["错误原因"] = history_data["message"] 15 | 16 | history_data = history_data[["文件名", "操作", "上传状态", "更新时间", "错误原因"]].sort_values( 17 | by="更新时间", ascending=False 18 | ) 19 | else: 20 | history_data = pd.DataFrame(columns=["文件名", "操作", "上传状态", "更新时间", "错误原因"]) 21 | summary = f"累计上传{len(upload_jobs)}个文件。" 22 | if len(upload_jobs) > 0: 23 | finished = [job for job in upload_jobs if job["status"] == "done"] 24 | failed = [job for job in 
upload_jobs if job["status"] == "failed"] 25 | summary += f"成功上传{len(finished)}个文件,失败{len(failed)}个文件。" 26 | 27 | return [ 28 | gr.update(value=summary), 29 | gr.update(value=history_data), 30 | ] 31 | 32 | 33 | def create_upload_history(): 34 | with gr.Row(): 35 | history_index = gr.Dropdown( 36 | choices=[], 37 | value="", 38 | label="\N{bookmark} 知识库名称", 39 | elem_id="history_index", 40 | allow_custom_value=True, 41 | ) 42 | 43 | upload_summary = gr.HTML(value="累计上传0个文件", elem_id="upload_summary") 44 | refresh_button = gr.Button( 45 | value="刷新", 46 | elem_id="refresh_button", 47 | variant="primary", 48 | ) 49 | upload_history = gr.DataFrame( 50 | label="上传历史", 51 | visible=True, 52 | elem_id="upload_history", 53 | headers=["文件名", "上传状态", "更新时间", "错误原因"], 54 | ) 55 | 56 | history_index.change( 57 | fn=refresh_upload_history, 58 | inputs=[history_index], 59 | outputs=[upload_summary, upload_history], 60 | ) 61 | 62 | refresh_button.click( 63 | fn=refresh_upload_history, 64 | inputs=[history_index], 65 | outputs=[upload_summary, upload_history], 66 | ) 67 | 68 | return { 69 | history_index.elem_id: history_index, 70 | upload_summary.elem_id: upload_summary, 71 | upload_history.elem_id: upload_history, 72 | } 73 | -------------------------------------------------------------------------------- /src/pairag/web/tabs/model/index_info.py: -------------------------------------------------------------------------------- 1 | from pairag.web.rag_local_client import rag_client 2 | 3 | 4 | def get_index_map(): 5 | index_map = rag_client.list_indexes() 6 | return index_map 7 | -------------------------------------------------------------------------------- /src/pairag/web/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict 2 | 3 | 4 | def components_to_dict(components: List[Any]) -> Dict[str, Any]: 5 | return {c.elem_id: c for c in components} 6 | 7 | 8 | def check_variables_in_string(text, variables): 9 | missing_variables = [var for var in variables if f"{{{var}}}" not in text] 10 | if missing_variables: 11 | raise ValueError(f"以下变量名缺失: {', '.join(missing_variables)}") 12 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/tests/__init__.py -------------------------------------------------------------------------------- /tests/app/test_openai_embedding.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from httpx import ASGITransport, AsyncClient 3 | from pairag.app.app import app 4 | 5 | 6 | @pytest.mark.asyncio(scope="session") 7 | async def test_embedding(): 8 | import openai 9 | 10 | async with AsyncClient( 11 | transport=ASGITransport(app=app), base_url="http://test" 12 | ) as client: 13 | client = openai.AsyncClient( 14 | base_url="http://test/v1", api_key="123", http_client=client 15 | ) 16 | embedding_result = await client.embeddings.create( 17 | input="hello world", 18 | model="bge-m3", 19 | ) 20 | 21 | assert len(embedding_result.data[0].embedding) == 1024 22 | assert embedding_result.data[0].index == 0 23 | 24 | embedding_result = await client.embeddings.create( 25 | input="", 26 | model="bge-m3", 27 | ) 28 | 29 | assert len(embedding_result.data[0].embedding) == 1024 30 | assert embedding_result.data[0].index == 0 31 | 32 | 
        embedding_result = await client.embeddings.create(
            input=["", "hi", "你在干什么"],
            model="bge-m3",
        )

        assert len(embedding_result.data[0].embedding) == 1024
        assert len(embedding_result.data[1].embedding) == 1024
        assert len(embedding_result.data[2].embedding) == 1024
        assert embedding_result.data[0].index == 0
        assert embedding_result.data[1].index == 1
        assert embedding_result.data[2].index == 2

        # A None input is invalid and must raise instead of returning a result.
        with pytest.raises(Exception):
            await client.embeddings.create(
                input=None,
                model="bge-m3",
            )

--------------------------------------------------------------------------------
/tests/integrations/llm/test_function_calling_llm.py:
--------------------------------------------------------------------------------
from llama_index.core.tools import FunctionTool
from pairag.integrations.llms.pai.pai_llm import PaiLlm
from pairag.integrations.llms.pai.llm_config import OpenAICompatibleLlmConfig
import os


fc_llm_config = OpenAICompatibleLlmConfig(
    model="qwen-max", api_key=os.environ.get("DASHSCOPE_API_KEY")
)

fc_llm = PaiLlm(fc_llm_config)


def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result integer"""
    return a * b


multiply_tool = FunctionTool.from_defaults(fn=multiply)


def add(a: int, b: int) -> int:
    """Add two integers and return the result integer"""
    return a + b


add_tool = FunctionTool.from_defaults(fn=add)

tools = [multiply_tool, add_tool]


def test_fc_llm_chat_with_tools():
    response = fc_llm.chat_with_tools(tools=tools, user_msg="What is (121 + 2) * 5?")
    tool_calls = fc_llm.get_tool_calls_from_response(
        response, error_on_no_tool_call=False
    )
    assert len(tool_calls) > 0
    for tool_call in tool_calls:
        if tool_call.tool_name == "add":
            assert tool_call.tool_kwargs["a"] == 121
            assert tool_call.tool_kwargs["b"] == 2
        if tool_call.tool_name == "multiply":
            assert tool_call.tool_kwargs["a"] == 123
            assert tool_call.tool_kwargs["b"] == 5

--------------------------------------------------------------------------------
/tests/integrations/test_nl2pandas_retriever.py:
--------------------------------------------------------------------------------
import os
import pytest
import pandas as pd

from llama_index.llms.dashscope import DashScope
from llama_index.embeddings.dashscope import DashScopeEmbedding
from llama_index.core import Settings

from pairag.integrations.data_analysis.nl2pandas_retriever import PandasQueryRetriever


dashscope_key = os.environ.get("DASHSCOPE_API_KEY")
llm = DashScope(model_name="qwen-max", temperature=0.1, api_key=dashscope_key)
embed_model = DashScopeEmbedding(embed_batch_size=10, api_key=dashscope_key)
Settings.llm = llm
Settings.embed_model = embed_model


@pytest.mark.skipif(
    os.getenv("DASHSCOPE_API_KEY") is None, reason="no llm api key provided"
)
def test_pandas_query_retriever():
    file_path = "./tests/testdata/csv_data/titanic_train.csv"
    df = pd.read_csv(file_path)
    data_analysis_retriever = PandasQueryRetriever(df)
    query = "What is the correlation between survival and age?"
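    # retrieve() has the LLM translate the question into pandas code, runs it against df,
    # and stores both the generated code and its output in the returned node's metadata.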

    retrieved_res = data_analysis_retriever.retrieve(query)

    assert (
        retrieved_res[0].metadata["query_code_instruction"]
        == "df['survived'].corr(df['age'])"
    )

    assert eval(retrieved_res[0].metadata["query_output"]) < 0

--------------------------------------------------------------------------------
/tests/testdata/db_data/pets.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/tests/testdata/db_data/pets.db
--------------------------------------------------------------------------------
/tests/testdata/db_data/pets.sqlite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/PAI-RAG/a27b0b981106f363248cee7f20b51a8e0dbd3f45/tests/testdata/db_data/pets.sqlite
--------------------------------------------------------------------------------
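The two pets databases above are binary fixtures, so this listing only links to their raw files. As a rough sketch of how such a fixture can be inspected locally (the path comes from the repository tree; the schema and table names are not part of this listing and are therefore unknown here), the standard sqlite3 module is enough:

import sqlite3

# Illustrative only: list the tables inside the pets.sqlite fixture.
conn = sqlite3.connect("tests/testdata/db_data/pets.sqlite")
tables = [row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")]
print(tables)
conn.close()
--------------------------------------------------------------------------------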