├── common └── __init__.py ├── embeddings ├── __init__.py ├── embedding_keywords.txt └── add_embedding_keywords.py ├── server ├── chat │ ├── __init__.py │ ├── feedback.py │ ├── utils.py │ └── completion.py ├── db │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── base.py │ │ ├── conversation_model.py │ │ ├── knowledge_base_model.py │ │ ├── message_model.py │ │ ├── knowledge_metadata_model.py │ │ └── knowledge_file_model.py │ ├── repository │ │ ├── __init__.py │ │ ├── conversation_repository.py │ │ ├── knowledge_base_repository.py │ │ ├── knowledge_metadata_repository.py │ │ └── message_repository.py │ ├── base.py │ └── session.py ├── knowledge_base │ ├── kb_service │ │ ├── __init__.py │ │ └── default_kb_service.py │ ├── kb_summary │ │ ├── __init__.py │ │ └── base.py │ ├── model │ │ └── kb_document_model.py │ ├── __init__.py │ └── kb_api.py ├── static │ └── favicon.png ├── agent │ ├── __init__.py │ ├── model_contain.py │ ├── tools │ │ ├── shell.py │ │ ├── arxiv.py │ │ ├── search_youtube.py │ │ ├── wolfram.py │ │ ├── __init__.py │ │ ├── weather_check.py │ │ ├── search_knowledgebase_simple.py │ │ ├── calculate.py │ │ └── search_internet.py │ ├── tools_select.py │ └── custom_template.py ├── model_workers │ ├── __init__.py │ ├── SparkApi.py │ └── tiangong.py ├── llm_api_shutdown.py ├── callback_handler │ └── conversation_callback_handler.py ├── minx_chat_openai.py ├── api_allinone_stale.py ├── memory │ └── conversation_db_buffer_memory.py └── webui_allinone_stale.py ├── webui_pages ├── __init__.py ├── dialogue │ └── __init__.py ├── model_config │ ├── __init__.py │ └── model_config.py └── knowledge_base │ └── __init__.py ├── tests ├── kb_vector_db │ ├── __init__.py │ ├── test_pg_db.py │ ├── test_faiss_kb.py │ └── test_milvus_db.py ├── samples │ ├── ocr_test.docx │ ├── ocr_test.jpg │ ├── ocr_test.pdf │ └── ocr_test.pptx ├── document_loader │ ├── test_imgloader.py │ └── test_pdfloader.py ├── api │ ├── test_server_state_api.py │ ├── test_kb_summary_api.py │ ├── test_stream_chat_api_thread.py │ └── test_llm_api.py ├── custom_splitter │ └── test_different_splitter.py └── test_online_api.py ├── img ├── LLM_success.png ├── docker_logs.png ├── qr_code_100.jpg ├── qr_code_101.jpg ├── qr_code_102.jpg ├── qr_code_103.jpg ├── qr_code_104.jpg ├── qr_code_105.jpg ├── qr_code_106.jpg ├── qr_code_90.jpg ├── qr_code_90.png ├── qr_code_91.jpg ├── qr_code_92.jpg ├── qr_code_93.jpg ├── qr_code_94.jpg ├── qr_code_95.jpg ├── qr_code_96.jpg ├── qr_code_97.jpg ├── qr_code_98.jpg ├── qr_code_99.jpg ├── qrcode_90_2.jpg ├── agent_continue.png ├── agent_success.png ├── qr_code_106_2.jpg ├── chatchat-qrcode.jpg ├── fastapi_docs_026.png ├── langchain+chatglm.png ├── init_knowledge_base.jpg ├── langchain+chatglm2.png ├── official_account_qr.png ├── knowledge_base_success.jpg ├── logo-long-chatchat-trans-v2.png ├── official_wechat_mp_account.png └── chatchat_icon_blue_square_v2.png ├── nltk_data ├── tokenizers │ └── punkt │ │ ├── czech.pickle │ │ ├── danish.pickle │ │ ├── dutch.pickle │ │ ├── french.pickle │ │ ├── german.pickle │ │ ├── polish.pickle │ │ ├── estonian.pickle │ │ ├── finnish.pickle │ │ ├── italian.pickle │ │ ├── russian.pickle │ │ ├── slovene.pickle │ │ ├── spanish.pickle │ │ ├── swedish.pickle │ │ ├── turkish.pickle │ │ ├── PY3 │ │ ├── czech.pickle │ │ ├── danish.pickle │ │ ├── dutch.pickle │ │ ├── english.pickle │ │ ├── finnish.pickle │ │ ├── french.pickle │ │ ├── german.pickle │ │ ├── greek.pickle │ │ ├── italian.pickle │ │ ├── polish.pickle │ │ ├── russian.pickle │ │ ├── slovene.pickle │ │ ├── spanish.pickle │ │ ├── swedish.pickle │ │ ├── turkish.pickle │ │ ├── estonian.pickle │ │ ├── malayalam.pickle │ │ ├── norwegian.pickle │ │ └── portuguese.pickle │ │ ├── malayalam.pickle │ │ ├── norwegian.pickle │ │ └── portuguese.pickle └── taggers │ └── averaged_perceptron_tagger │ └── averaged_perceptron_tagger.pickle ├── knowledge_base └── samples │ └── content │ ├── test_files │ ├── langchain.pdf │ ├── langchain-ChatGLM_open.xlsx │ └── langchain-ChatGLM_closed.xlsx │ └── llm │ ├── img │ ├── 分布式训练技术原理-幕布图片-20096-279847.jpg │ ├── 分布式训练技术原理-幕布图片-36114-765327.jpg │ ├── 分布式训练技术原理-幕布图片-42284-124759.jpg │ ├── 分布式训练技术原理-幕布图片-57107-679259.jpg │ ├── 大模型推理优化策略-幕布图片-590671-36787.jpg │ ├── 大模型推理优化策略-幕布图片-923924-83386.jpg │ ├── 分布式训练技术原理-幕布图片-124076-270516.jpg │ ├── 分布式训练技术原理-幕布图片-220157-552735.jpg │ ├── 分布式训练技术原理-幕布图片-392521-261326.jpg │ ├── 分布式训练技术原理-幕布图片-618350-869132.jpg │ ├── 分布式训练技术原理-幕布图片-838373-426344.jpg │ ├── 分布式训练技术原理-幕布图片-906937-836104.jpg │ ├── 大模型应用技术原理-幕布图片-108319-429731.jpg │ ├── 大模型应用技术原理-幕布图片-580318-260070.jpg │ ├── 大模型应用技术原理-幕布图片-793118-735987.jpg │ ├── 大模型应用技术原理-幕布图片-918388-323086.jpg │ ├── 大模型指令对齐训练原理-幕布图片-17565-176537.jpg │ ├── 大模型指令对齐训练原理-幕布图片-95996-523276.jpg │ ├── 大模型推理优化策略-幕布图片-276446-401476.jpg │ ├── 大模型推理优化策略-幕布图片-380552-579242.jpg │ ├── 大模型推理优化策略-幕布图片-699343-219844.jpg │ ├── 大模型推理优化策略-幕布图片-789705-122117.jpg │ ├── 大模型推理优化策略-幕布图片-930255-616209.jpg │ ├── 大模型技术栈-算法与原理-幕布图片-19929-302935.jpg │ ├── 大模型技术栈-算法与原理-幕布图片-299768-254064.jpg │ ├── 大模型技术栈-算法与原理-幕布图片-454007-940199.jpg │ ├── 大模型技术栈-算法与原理-幕布图片-628857-182232.jpg │ ├── 大模型技术栈-算法与原理-幕布图片-729151-372321.jpg │ ├── 大模型技术栈-算法与原理-幕布图片-81470-404273.jpg │ ├── 大模型指令对齐训练原理-幕布图片-349153-657791.jpg │ ├── 大模型指令对齐训练原理-幕布图片-350029-666381.jpg │ ├── 大模型指令对齐训练原理-幕布图片-759487-923925.jpg │ └── 大模型指令对齐训练原理-幕布图片-805089-731888.jpg │ ├── 大模型技术栈-实战与应用.md │ └── 大模型指令对齐训练原理.md ├── .gitmodules ├── shutdown_all.sh ├── configs ├── __init__.py └── basic_config.py.example ├── document_loaders ├── __init__.py ├── ocr.py ├── myimgloader.py ├── mypptloader.py ├── FilteredCSVloader.py └── mydocloader.py ├── text_splitter ├── __init__.py ├── ali_text_splitter.py └── chinese_text_splitter.py ├── requirements_webui.txt ├── .dockerignore ├── copy_config_example.py ├── markdown_docs ├── webui_pages │ └── model_config │ │ └── model_config.md ├── server │ ├── knowledge_base │ │ ├── model │ │ │ └── kb_document_model.md │ │ └── kb_api.md │ ├── chat │ │ ├── feedback.md │ │ ├── completion.md │ │ ├── chat.md │ │ ├── knowledge_base_chat.md │ │ └── agent_chat.md │ ├── db │ │ ├── models │ │ │ ├── base.md │ │ │ ├── knowledge_metadata_model.md │ │ │ ├── conversation_model.md │ │ │ ├── message_model.md │ │ │ └── knowledge_base_model.md │ │ └── repository │ │ │ └── conversation_repository.md │ ├── agent │ │ ├── model_contain.md │ │ └── tools │ │ │ ├── arxiv.md │ │ │ ├── shell.md │ │ │ ├── calculate.md │ │ │ ├── wolfram.md │ │ │ ├── search_youtube.md │ │ │ ├── search_knowledgebase_simple.md │ │ │ ├── weather_check.md │ │ │ └── search_internet.md │ ├── webui_allinone_stale.md │ └── minx_chat_openai.md ├── document_loaders │ ├── ocr.md │ └── myimgloader.md ├── release.md ├── text_splitter │ ├── zh_title_enhance.md │ └── ali_text_splitter.md └── embeddings │ └── add_embedding_keywords.md ├── .github ├── workflows │ └── close-issue.yml └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── chains └── llmchain_with_history.py ├── docs └── ES部署指南.md ├── requirements_lite.txt ├── Dockerfile ├── requirements_api.txt ├── release.py ├── requirements.txt └── webui.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /server/chat/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /server/db/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webui_pages/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /server/db/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/kb_vector_db/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webui_pages/dialogue/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /server/knowledge_base/kb_service/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /server/knowledge_base/kb_summary/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /embeddings/embedding_keywords.txt: -------------------------------------------------------------------------------- 1 | Langchain-Chatchat 2 | 数据科学与大数据技术 3 | 人工智能与先进计算 -------------------------------------------------------------------------------- /webui_pages/model_config/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_config import model_config_page -------------------------------------------------------------------------------- /webui_pages/knowledge_base/__init__.py: -------------------------------------------------------------------------------- 1 | from .knowledge_base import knowledge_base_page 2 | -------------------------------------------------------------------------------- /img/LLM_success.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/LLM_success.png -------------------------------------------------------------------------------- /img/docker_logs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/docker_logs.png -------------------------------------------------------------------------------- /img/qr_code_100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_100.jpg -------------------------------------------------------------------------------- /img/qr_code_101.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_101.jpg -------------------------------------------------------------------------------- /img/qr_code_102.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_102.jpg -------------------------------------------------------------------------------- /img/qr_code_103.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_103.jpg -------------------------------------------------------------------------------- /img/qr_code_104.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_104.jpg -------------------------------------------------------------------------------- /img/qr_code_105.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_105.jpg -------------------------------------------------------------------------------- /img/qr_code_106.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_106.jpg -------------------------------------------------------------------------------- /img/qr_code_90.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_90.jpg -------------------------------------------------------------------------------- /img/qr_code_90.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_90.png -------------------------------------------------------------------------------- /img/qr_code_91.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_91.jpg -------------------------------------------------------------------------------- /img/qr_code_92.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_92.jpg -------------------------------------------------------------------------------- /img/qr_code_93.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_93.jpg -------------------------------------------------------------------------------- /img/qr_code_94.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_94.jpg -------------------------------------------------------------------------------- /img/qr_code_95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_95.jpg -------------------------------------------------------------------------------- /img/qr_code_96.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_96.jpg -------------------------------------------------------------------------------- /img/qr_code_97.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_97.jpg -------------------------------------------------------------------------------- /img/qr_code_98.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_98.jpg -------------------------------------------------------------------------------- /img/qr_code_99.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_99.jpg -------------------------------------------------------------------------------- /img/qrcode_90_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qrcode_90_2.jpg -------------------------------------------------------------------------------- /img/agent_continue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/agent_continue.png -------------------------------------------------------------------------------- /img/agent_success.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/agent_success.png -------------------------------------------------------------------------------- /img/qr_code_106_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/qr_code_106_2.jpg -------------------------------------------------------------------------------- /img/chatchat-qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/chatchat-qrcode.jpg -------------------------------------------------------------------------------- /img/fastapi_docs_026.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/fastapi_docs_026.png -------------------------------------------------------------------------------- /img/langchain+chatglm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/langchain+chatglm.png -------------------------------------------------------------------------------- /server/static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/server/static/favicon.png -------------------------------------------------------------------------------- /img/init_knowledge_base.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/init_knowledge_base.jpg -------------------------------------------------------------------------------- /img/langchain+chatglm2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/langchain+chatglm2.png -------------------------------------------------------------------------------- /img/official_account_qr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/official_account_qr.png -------------------------------------------------------------------------------- /tests/samples/ocr_test.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/tests/samples/ocr_test.docx -------------------------------------------------------------------------------- /tests/samples/ocr_test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/tests/samples/ocr_test.jpg -------------------------------------------------------------------------------- /tests/samples/ocr_test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/tests/samples/ocr_test.pdf -------------------------------------------------------------------------------- /tests/samples/ocr_test.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/tests/samples/ocr_test.pptx -------------------------------------------------------------------------------- /img/knowledge_base_success.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/knowledge_base_success.jpg -------------------------------------------------------------------------------- /img/logo-long-chatchat-trans-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/logo-long-chatchat-trans-v2.png -------------------------------------------------------------------------------- /img/official_wechat_mp_account.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/official_wechat_mp_account.png -------------------------------------------------------------------------------- /img/chatchat_icon_blue_square_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/img/chatchat_icon_blue_square_v2.png -------------------------------------------------------------------------------- /server/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_contain import * 2 | from .callbacks import * 3 | from .custom_template import * 4 | from .tools import * -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/czech.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/czech.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/danish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/danish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/dutch.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/dutch.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/french.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/french.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/german.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/german.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/polish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/polish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/estonian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/estonian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/finnish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/finnish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/italian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/italian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/russian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/russian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/slovene.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/slovene.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/spanish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/spanish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/swedish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/swedish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/turkish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/turkish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/czech.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/czech.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/danish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/danish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/dutch.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/dutch.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/english.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/english.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/finnish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/finnish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/french.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/french.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/german.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/german.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/greek.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/greek.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/italian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/italian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/polish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/polish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/russian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/russian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/slovene.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/slovene.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/spanish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/spanish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/swedish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/swedish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/turkish.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/turkish.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/malayalam.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/malayalam.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/norwegian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/norwegian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/portuguese.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/portuguese.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/estonian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/estonian.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/malayalam.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/malayalam.pickle -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/norwegian.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/norwegian.pickle -------------------------------------------------------------------------------- /webui_pages/model_config/model_config.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from webui_pages.utils import * 3 | 4 | def model_config_page(api: ApiRequest): 5 | pass -------------------------------------------------------------------------------- /nltk_data/tokenizers/punkt/PY3/portuguese.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/tokenizers/punkt/PY3/portuguese.pickle -------------------------------------------------------------------------------- /knowledge_base/samples/content/test_files/langchain.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/test_files/langchain.pdf -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "knowledge_base/samples/content/wiki"] 2 | path = knowledge_base/samples/content/wiki 3 | url = https://github.com/chatchat-space/Langchain-Chatchat.wiki.git 4 | -------------------------------------------------------------------------------- /server/agent/model_contain.py: -------------------------------------------------------------------------------- 1 | class ModelContainer: 2 | def __init__(self): 3 | self.MODEL = None 4 | self.DATABASE = None 5 | 6 | model_container = ModelContainer() 7 | -------------------------------------------------------------------------------- /server/db/repository/__init__.py: -------------------------------------------------------------------------------- 1 | from .conversation_repository import * 2 | from .message_repository import * 3 | from .knowledge_base_repository import * 4 | from .knowledge_file_repository import * -------------------------------------------------------------------------------- /shutdown_all.sh: -------------------------------------------------------------------------------- 1 | # mac设备上的grep命令可能不支持grep -P选项,请使用Homebrew安装;或使用ggrep命令 2 | ps -eo pid,user,cmd|grep -P 'server/api.py|webui.py|fastchat.serve|multiprocessing'|grep -v grep|awk '{print $1}'|xargs kill -9 -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-20096-279847.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-20096-279847.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-36114-765327.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-36114-765327.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-42284-124759.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-42284-124759.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-57107-679259.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-57107-679259.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-590671-36787.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-590671-36787.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-923924-83386.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-923924-83386.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/test_files/langchain-ChatGLM_open.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/test_files/langchain-ChatGLM_open.xlsx -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_config import * 2 | from .model_config import * 3 | from .kb_config import * 4 | from .server_config import * 5 | from .prompt_config import * 6 | 7 | 8 | VERSION = "v0.2.10" 9 | -------------------------------------------------------------------------------- /document_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .mypdfloader import RapidOCRPDFLoader 2 | from .myimgloader import RapidOCRLoader 3 | from .mydocloader import RapidOCRDocLoader 4 | from .mypptloader import RapidOCRPPTLoader 5 | -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-124076-270516.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-124076-270516.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-220157-552735.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-220157-552735.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-392521-261326.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-392521-261326.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-618350-869132.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-618350-869132.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-838373-426344.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-838373-426344.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-906937-836104.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/分布式训练技术原理-幕布图片-906937-836104.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-108319-429731.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-108319-429731.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-580318-260070.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-580318-260070.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-793118-735987.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-793118-735987.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-918388-323086.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型应用技术原理-幕布图片-918388-323086.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-17565-176537.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-17565-176537.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-95996-523276.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-95996-523276.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-276446-401476.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-276446-401476.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-380552-579242.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-380552-579242.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-699343-219844.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-699343-219844.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-789705-122117.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-789705-122117.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-930255-616209.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型推理优化策略-幕布图片-930255-616209.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.xlsx -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-19929-302935.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-19929-302935.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-299768-254064.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-299768-254064.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-454007-940199.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-454007-940199.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-628857-182232.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-628857-182232.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-729151-372321.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-729151-372321.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-81470-404273.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型技术栈-算法与原理-幕布图片-81470-404273.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-349153-657791.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-349153-657791.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-350029-666381.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-350029-666381.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-759487-923925.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-759487-923925.jpg -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-805089-731888.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/knowledge_base/samples/content/llm/img/大模型指令对齐训练原理-幕布图片-805089-731888.jpg -------------------------------------------------------------------------------- /nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChrisKimZHT/Langchain-Chatchat/master/nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle -------------------------------------------------------------------------------- /server/knowledge_base/model/kb_document_model.py: -------------------------------------------------------------------------------- 1 | 2 | from langchain.docstore.document import Document 3 | 4 | 5 | class DocumentWithVSId(Document): 6 | """ 7 | 矢量化后的文档 8 | """ 9 | id: str = None 10 | score: float = 3.0 11 | -------------------------------------------------------------------------------- /server/knowledge_base/__init__.py: -------------------------------------------------------------------------------- 1 | # from .kb_api import list_kbs, create_kb, delete_kb 2 | # from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store 3 | # from .utils import KnowledgeFile, KBServiceFactory 4 | -------------------------------------------------------------------------------- /text_splitter/__init__.py: -------------------------------------------------------------------------------- 1 | from .chinese_text_splitter import ChineseTextSplitter 2 | from .ali_text_splitter import AliTextSplitter 3 | from .zh_title_enhance import zh_title_enhance 4 | from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter -------------------------------------------------------------------------------- /requirements_webui.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.30.0 2 | streamlit-option-menu==0.3.12 3 | streamlit-antd-components==0.3.1 4 | streamlit-chatbox==1.1.11 5 | streamlit-modal==0.1.0 6 | streamlit-aggrid==0.3.4.post3 7 | httpx==0.26.0 8 | httpx_sse==0.4.0 9 | watchdog==3.0.0 10 | -------------------------------------------------------------------------------- /server/agent/tools/shell.py: -------------------------------------------------------------------------------- 1 | # LangChain 的 Shell 工具 2 | from pydantic import BaseModel, Field 3 | from langchain.tools import ShellTool 4 | def shell(query: str): 5 | tool = ShellTool() 6 | return tool.run(tool_input=query) 7 | 8 | class ShellInput(BaseModel): 9 | query: str = Field(description="一个能在Linux命令行运行的Shell命令") -------------------------------------------------------------------------------- /server/agent/tools/arxiv.py: -------------------------------------------------------------------------------- 1 | # LangChain 的 ArxivQueryRun 工具 2 | from pydantic import BaseModel, Field 3 | from langchain.tools.arxiv.tool import ArxivQueryRun 4 | def arxiv(query: str): 5 | tool = ArxivQueryRun() 6 | return tool.run(tool_input=query) 7 | 8 | class ArxivInput(BaseModel): 9 | query: str = Field(description="The search query title") -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .idea 2 | # Langchain-Chatchat 3 | docs 4 | .github 5 | tests 6 | Dockerfile 7 | .dockerignore 8 | .gitignore 9 | .gitmodules 10 | README.md 11 | README_en.md 12 | README_ja.md 13 | LICENSE 14 | requirements_api.txt 15 | requirements_lite.txt 16 | requirements_webui.txt 17 | # bge-large-zh-v1.5 18 | bge-large-zh-v1.5/README.md 19 | # chatglm3-6b -------------------------------------------------------------------------------- /server/agent/tools/search_youtube.py: -------------------------------------------------------------------------------- 1 | # Langchain 自带的 YouTube 搜索工具封装 2 | from langchain.tools import YouTubeSearchTool 3 | from pydantic import BaseModel, Field 4 | def search_youtube(query: str): 5 | tool = YouTubeSearchTool() 6 | return tool.run(tool_input=query) 7 | 8 | class YoutubeInput(BaseModel): 9 | location: str = Field(description="Query for Videos search") -------------------------------------------------------------------------------- /copy_config_example.py: -------------------------------------------------------------------------------- 1 | # 用于批量将configs下的.example文件复制并命名为.py文件 2 | import os 3 | import shutil 4 | 5 | if __name__ == "__main__": 6 | files = os.listdir("configs") 7 | 8 | src_files = [os.path.join("configs", file) for file in files if ".example" in file] 9 | 10 | for src_file in src_files: 11 | tar_file = src_file.replace(".example", "") 12 | shutil.copy(src_file, tar_file) 13 | -------------------------------------------------------------------------------- /server/model_workers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .zhipu import ChatGLMWorker 3 | from .minimax import MiniMaxWorker 4 | from .xinghuo import XingHuoWorker 5 | from .qianfan import QianFanWorker 6 | from .fangzhou import FangZhouWorker 7 | from .qwen import QwenWorker 8 | from .baichuan import BaiChuanWorker 9 | from .azure import AzureWorker 10 | from .tiangong import TianGongWorker 11 | from .gemini import GeminiWorker 12 | from .claude import ClaudeWorker -------------------------------------------------------------------------------- /server/agent/tools/wolfram.py: -------------------------------------------------------------------------------- 1 | # Langchain 自带的 Wolfram Alpha API 封装 2 | from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper 3 | from pydantic import BaseModel, Field 4 | wolfram_alpha_appid = "your key" 5 | def wolfram(query: str): 6 | wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid) 7 | ans = wolfram.run(query) 8 | return ans 9 | 10 | class WolframInput(BaseModel): 11 | location: str = Field(description="需要运算的具体问题") -------------------------------------------------------------------------------- /server/db/base.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta 3 | from sqlalchemy.orm import sessionmaker 4 | 5 | from configs import SQLALCHEMY_DATABASE_URI 6 | import json 7 | 8 | 9 | engine = create_engine( 10 | SQLALCHEMY_DATABASE_URI, 11 | json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), 12 | ) 13 | 14 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 15 | 16 | Base: DeclarativeMeta = declarative_base() 17 | -------------------------------------------------------------------------------- /server/db/repository/conversation_repository.py: -------------------------------------------------------------------------------- 1 | from server.db.session import with_session 2 | import uuid 3 | from server.db.models.conversation_model import ConversationModel 4 | 5 | 6 | @with_session 7 | def add_conversation_to_db(session, chat_type, name="", conversation_id=None): 8 | """ 9 | 新增聊天记录 10 | """ 11 | if not conversation_id: 12 | conversation_id = uuid.uuid4().hex 13 | c = ConversationModel(id=conversation_id, chat_type=chat_type, name=name) 14 | 15 | session.add(c) 16 | return c.id 17 | -------------------------------------------------------------------------------- /server/db/models/base.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from sqlalchemy import Column, DateTime, String, Integer 3 | 4 | 5 | class BaseModel: 6 | """ 7 | 基础模型 8 | """ 9 | id = Column(Integer, primary_key=True, index=True, comment="主键ID") 10 | create_time = Column(DateTime, default=datetime.utcnow, comment="创建时间") 11 | update_time = Column(DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间") 12 | create_by = Column(String, default=None, comment="创建者") 13 | update_by = Column(String, default=None, comment="更新者") 14 | -------------------------------------------------------------------------------- /document_loaders/ocr.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | 4 | if TYPE_CHECKING: 5 | try: 6 | from rapidocr_paddle import RapidOCR 7 | except ImportError: 8 | from rapidocr_onnxruntime import RapidOCR 9 | 10 | 11 | def get_ocr(use_cuda: bool = True) -> "RapidOCR": 12 | try: 13 | from rapidocr_paddle import RapidOCR 14 | ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda) 15 | except ImportError: 16 | from rapidocr_onnxruntime import RapidOCR 17 | ocr = RapidOCR() 18 | return ocr 19 | -------------------------------------------------------------------------------- /tests/document_loader/test_imgloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | root_path = Path(__file__).parent.parent.parent 5 | sys.path.append(str(root_path)) 6 | from pprint import pprint 7 | 8 | test_files = { 9 | "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"), 10 | } 11 | 12 | def test_rapidocrloader(): 13 | img_path = test_files["ocr_test.jpg"] 14 | from document_loaders import RapidOCRLoader 15 | 16 | loader = RapidOCRLoader(img_path) 17 | docs = loader.load() 18 | pprint(docs) 19 | assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) 20 | 21 | 22 | -------------------------------------------------------------------------------- /tests/document_loader/test_pdfloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | root_path = Path(__file__).parent.parent.parent 5 | sys.path.append(str(root_path)) 6 | from pprint import pprint 7 | 8 | test_files = { 9 | "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"), 10 | } 11 | 12 | def test_rapidocrpdfloader(): 13 | pdf_path = test_files["ocr_test.pdf"] 14 | from document_loaders import RapidOCRPDFLoader 15 | 16 | loader = RapidOCRPDFLoader(pdf_path) 17 | docs = loader.load() 18 | pprint(docs) 19 | assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) 20 | 21 | 22 | -------------------------------------------------------------------------------- /markdown_docs/webui_pages/model_config/model_config.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef model_config_page(api) 2 | **model_config_page**: 此函数用于处理模型配置页面的请求。 3 | 4 | **参数**: 5 | - `api`: ApiRequest 类的实例,用于封装与 API 服务器的 HTTP 请求。 6 | 7 | **代码描述**: 8 | 函数 `model_config_page` 是项目中用于处理模型配置页面请求的函数。它接收一个 `ApiRequest` 类型的参数 `api`,该参数是一个封装了 HTTP 请求的对象,简化了与 API 服务器的交互过程。在此函数中,可以通过 `api` 参数调用不同的 API 接口,以实现获取或修改模型配置的功能。当前函数体内部为空,这意味着函数尚未实现具体的逻辑。在实际应用中,开发者需要根据项目需求,在此函数内部添加相应的代码逻辑,以完成对模型配置页面请求的处理。 9 | 10 | **注意**: 11 | - 在使用 `model_config_page` 函数时,需要确保传入的 `api` 参数已正确初始化,并且 `base_url` 属性已指向正确的 API 服务器地址。 12 | - 由于当前函数体为空,开发者在实际使用时需要根据具体需求添加实现代码。 13 | - 此函数是项目中处理模型配置页面请求的关键部分,因此在修改或扩展功能时应保持代码的清晰和稳定性。 14 | -------------------------------------------------------------------------------- /server/agent/tools/__init__.py: -------------------------------------------------------------------------------- 1 | ## 导入所有的工具类 2 | from .search_knowledgebase_simple import search_knowledgebase_simple 3 | from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput 4 | from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput 5 | from .calculate import calculate, CalculatorInput 6 | from .weather_check import weathercheck, WeatherInput 7 | from .shell import shell, ShellInput 8 | from .search_internet import search_internet, SearchInternetInput 9 | from .wolfram import wolfram, WolframInput 10 | from .search_youtube import search_youtube, YoutubeInput 11 | from .arxiv import arxiv, ArxivInput 12 | -------------------------------------------------------------------------------- /server/db/models/conversation_model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, JSON, func 2 | from server.db.base import Base 3 | 4 | 5 | class ConversationModel(Base): 6 | """ 7 | 聊天记录模型 8 | """ 9 | __tablename__ = 'conversation' 10 | id = Column(String(32), primary_key=True, comment='对话框ID') 11 | name = Column(String(50), comment='对话框名称') 12 | # chat/agent_chat等 13 | chat_type = Column(String(50), comment='聊天类型') 14 | create_time = Column(DateTime, default=func.now(), comment='创建时间') 15 | 16 | def __repr__(self): 17 | return f"" 18 | -------------------------------------------------------------------------------- /.github/workflows/close-issue.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "30 21 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v5 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 14 17 | stale-issue-label: "stale" 18 | stale-issue-message: "这个问题已经被标记为 `stale` ,因为它已经超过 30 天没有任何活动。" 19 | close-issue-message: "这个问题已经被自动关闭,因为它被标为 `stale` 后超过 14 天没有任何活动。" 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /chains/llmchain_with_history.py: -------------------------------------------------------------------------------- 1 | from server.utils import get_ChatOpenAI 2 | from configs.model_config import LLM_MODELS, TEMPERATURE 3 | from langchain.chains import LLMChain 4 | from langchain.prompts.chat import ( 5 | ChatPromptTemplate, 6 | HumanMessagePromptTemplate, 7 | ) 8 | 9 | model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE) 10 | 11 | 12 | human_prompt = "{input}" 13 | human_message_template = HumanMessagePromptTemplate.from_template(human_prompt) 14 | 15 | chat_prompt = ChatPromptTemplate.from_messages( 16 | [("human", "我们来玩成语接龙,我先来,生龙活虎"), 17 | ("ai", "虎头虎脑"), 18 | ("human", "{input}")]) 19 | 20 | 21 | chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True) 22 | print(chain({"input": "恼羞成怒"})) -------------------------------------------------------------------------------- /configs/basic_config.py.example: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import langchain 4 | import tempfile 5 | import shutil 6 | 7 | 8 | # 是否显示详细日志 9 | log_verbose = False 10 | langchain.verbose = False 11 | 12 | # 通常情况下不需要更改以下内容 13 | 14 | # 日志格式 15 | LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" 16 | logger = logging.getLogger() 17 | logger.setLevel(logging.INFO) 18 | logging.basicConfig(format=LOG_FORMAT) 19 | 20 | 21 | # 日志存储路径 22 | LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") 23 | if not os.path.exists(LOG_PATH): 24 | os.mkdir(LOG_PATH) 25 | 26 | # 临时文件目录,主要用于文件对话 27 | BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat") 28 | try: 29 | shutil.rmtree(BASE_TEMP_DIR) 30 | except Exception: 31 | pass 32 | os.makedirs(BASE_TEMP_DIR, exist_ok=True) 33 | -------------------------------------------------------------------------------- /server/chat/feedback.py: -------------------------------------------------------------------------------- 1 | from fastapi import Body 2 | from configs import logger, log_verbose 3 | from server.utils import BaseResponse 4 | from server.db.repository import feedback_message_to_db 5 | 6 | def chat_feedback(message_id: str = Body("", max_length=32, description="聊天记录id"), 7 | score: int = Body(0, max=100, description="用户评分,满分100,越大表示评价越高"), 8 | reason: str = Body("", description="用户评分理由,比如不符合事实等") 9 | ): 10 | try: 11 | feedback_message_to_db(message_id, score, reason) 12 | except Exception as e: 13 | msg = f"反馈聊天记录出错: {e}" 14 | logger.error(f'{e.__class__.__name__}: {msg}', 15 | exc_info=e if log_verbose else None) 16 | return BaseResponse(code=500, msg=msg) 17 | 18 | return BaseResponse(code=200, msg=f"已反馈聊天记录 {message_id}") 19 | -------------------------------------------------------------------------------- /markdown_docs/server/knowledge_base/model/kb_document_model.md: -------------------------------------------------------------------------------- 1 | ## ClassDef DocumentWithVSId 2 | **DocumentWithVSId**: DocumentWithVSId 类的功能是表示一个经过向量化处理的文档。 3 | 4 | **属性**: 5 | - `id`: 文档的唯一标识符,类型为字符串。 6 | - `score`: 文档的评分,初始默认值为3.0,类型为浮点数。 7 | 8 | **代码描述**: 9 | DocumentWithVSId 类继承自 Document 类,用于表示一个经过向量化处理的文档。这个类主要用于知识库系统中,对文档进行向量化处理后,通过这个类的实例来表示处理结果。类中定义了两个属性:`id` 和 `score`。`id` 属性用于存储文档的唯一标识符,而 `score` 属性则用于存储文档在某些操作(如搜索或排序)中的评分或相关性度量。 10 | 11 | 在项目中,DocumentWithVSId 类的实例主要用于以下几个场景: 12 | 1. 在搜索知识库文档时,返回的搜索结果会包含一系列 DocumentWithVSId 实例,其中每个实例代表一个搜索到的文档,其 `score` 属性表示该文档与搜索查询的匹配程度。 13 | 2. 在列出知识库文档时,如果需要根据特定的文件名或元数据进行过滤,返回的结果也可能包含 DocumentWithVSId 实例。 14 | 3. 在文档摘要生成过程中,DocumentWithVSId 实例用于表示需要进行摘要处理的文档,其中 `id` 属性用于标识具体的文档。 15 | 16 | **注意**: 17 | - 在使用 DocumentWithVSId 类时,需要注意 `id` 属性的唯一性,确保每个实例能够准确地对应到知识库中的一个具体文档。 18 | - `score` 属性的值可能会根据不同的操作或上下文环境有所变化,因此在使用时应注意其含义和计算方式。 19 | -------------------------------------------------------------------------------- /docs/ES部署指南.md: -------------------------------------------------------------------------------- 1 | 2 | # 实现基于ES的数据插入、检索、删除、更新 3 | ```shell 4 | author: 唐国梁Tommy 5 | e-mail: flytang186@qq.com 6 | 7 | 如果遇到任何问题,可以与我联系,我这边部署后服务是没有问题的。 8 | ``` 9 | 10 | ## 第1步:ES docker部署 11 | ```shell 12 | docker network create elastic 13 | docker run -id --name elasticsearch --net elastic -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" -e "xpack.security.http.ssl.enabled=false" -t docker.elastic.co/elasticsearch/elasticsearch:8.8.2 14 | ``` 15 | 16 | ### 第2步:Kibana docker部署 17 | **注意:Kibana版本与ES保持一致** 18 | ```shell 19 | docker pull docker.elastic.co/kibana/kibana:{version} 20 | docker run --name kibana --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:{version} 21 | ``` 22 | 23 | ### 第3步:核心代码 24 | ```shell 25 | 1. 核心代码路径 26 | server/knowledge_base/kb_service/es_kb_service.py 27 | 28 | 2. 需要在 configs/model_config.py 中 配置 ES参数(IP, PORT)等; 29 | ``` -------------------------------------------------------------------------------- /server/knowledge_base/kb_service/default_kb_service.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain.embeddings.base import Embeddings 4 | from langchain.schema import Document 5 | 6 | from server.knowledge_base.kb_service.base import KBService 7 | 8 | 9 | class DefaultKBService(KBService): 10 | def do_create_kb(self): 11 | pass 12 | 13 | def do_drop_kb(self): 14 | pass 15 | 16 | def do_add_doc(self, docs: List[Document]): 17 | pass 18 | 19 | def do_clear_vs(self): 20 | pass 21 | 22 | def vs_type(self) -> str: 23 | return "default" 24 | 25 | def do_init(self): 26 | pass 27 | 28 | def do_search(self): 29 | pass 30 | 31 | def do_insert_multi_knowledge(self): 32 | pass 33 | 34 | def do_insert_one_knowledge(self): 35 | pass 36 | 37 | def do_delete_doc(self): 38 | pass 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 功能请求 / Feature Request 3 | about: 为项目提出新功能或建议 / Propose new features or suggestions for the project 4 | title: "[FEATURE] 简洁阐述功能 / Concise description of the feature" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **功能描述 / Feature Description** 11 | 用简洁明了的语言描述所需的功能 / Describe the desired feature in a clear and concise manner. 12 | 13 | **解决的问题 / Problem Solved** 14 | 解释此功能如何解决现有问题或改进项目 / Explain how this feature solves existing problems or improves the project. 15 | 16 | **实现建议 / Implementation Suggestions** 17 | 如果可能,请提供关于如何实现此功能的建议 / If possible, provide suggestions on how to implement this feature. 18 | 19 | **替代方案 / Alternative Solutions** 20 | 描述您考虑过的替代方案 / Describe alternative solutions you have considered. 21 | 22 | **其他信息 / Additional Information** 23 | 添加与功能请求相关的任何其他信息 / Add any other information related to the feature request. -------------------------------------------------------------------------------- /server/llm_api_shutdown.py: -------------------------------------------------------------------------------- 1 | """ 2 | 调用示例: 3 | python llm_api_shutdown.py --serve all 4 | 可选"all","controller","model_worker","openai_api_server", all表示停止所有服务 5 | """ 6 | import sys 7 | import os 8 | 9 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 10 | 11 | import subprocess 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--serve", choices=["all", "controller", "model_worker", "openai_api_server"], default="all") 16 | 17 | args = parser.parse_args() 18 | 19 | base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" 20 | 21 | if args.serve == "all": 22 | shell_script = base_shell.format("") 23 | else: 24 | serve = f".{args.serve}" 25 | shell_script = base_shell.format(serve) 26 | 27 | subprocess.run(shell_script, shell=True, check=True) 28 | print(f"llm api sever --{args.serve} has been shutdown!") 29 | -------------------------------------------------------------------------------- /document_loaders/myimgloader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 3 | from document_loaders.ocr import get_ocr 4 | 5 | 6 | class RapidOCRLoader(UnstructuredFileLoader): 7 | def _get_elements(self) -> List: 8 | def img2text(filepath): 9 | resp = "" 10 | ocr = get_ocr() 11 | result, _ = ocr(filepath) 12 | if result: 13 | ocr_result = [line[1] for line in result] 14 | resp += "\n".join(ocr_result) 15 | return resp 16 | 17 | text = img2text(self.file_path) 18 | from unstructured.partition.text import partition_text 19 | return partition_text(text=text, **self.unstructured_kwargs) 20 | 21 | 22 | if __name__ == "__main__": 23 | loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg") 24 | docs = loader.load() 25 | print(docs) 26 | -------------------------------------------------------------------------------- /tests/kb_vector_db/test_pg_db.py: -------------------------------------------------------------------------------- 1 | from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService 2 | from server.knowledge_base.kb_service.pg_kb_service import PGKBService 3 | from server.knowledge_base.migrate import create_tables 4 | from server.knowledge_base.utils import KnowledgeFile 5 | 6 | kbService = PGKBService("test") 7 | 8 | test_kb_name = "test" 9 | test_file_name = "README.md" 10 | testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name) 11 | search_content = "如何启动api服务" 12 | 13 | 14 | def test_init(): 15 | create_tables() 16 | 17 | 18 | def test_create_db(): 19 | assert kbService.create_kb() 20 | 21 | 22 | def test_add_doc(): 23 | assert kbService.add_doc(testKnowledgeFile) 24 | 25 | 26 | def test_search_db(): 27 | result = kbService.search_docs(search_content) 28 | assert len(result) > 0 29 | def test_delete_doc(): 30 | assert kbService.delete_doc(testKnowledgeFile) 31 | 32 | -------------------------------------------------------------------------------- /tests/kb_vector_db/test_faiss_kb.py: -------------------------------------------------------------------------------- 1 | from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService 2 | from server.knowledge_base.migrate import create_tables 3 | from server.knowledge_base.utils import KnowledgeFile 4 | 5 | 6 | kbService = FaissKBService("test") 7 | test_kb_name = "test" 8 | test_file_name = "README.md" 9 | testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name) 10 | search_content = "如何启动api服务" 11 | 12 | 13 | def test_init(): 14 | create_tables() 15 | 16 | 17 | def test_create_db(): 18 | assert kbService.create_kb() 19 | 20 | 21 | def test_add_doc(): 22 | assert kbService.add_doc(testKnowledgeFile) 23 | 24 | 25 | def test_search_db(): 26 | result = kbService.search_docs(search_content) 27 | assert len(result) > 0 28 | 29 | 30 | def test_delete_doc(): 31 | assert kbService.delete_doc(testKnowledgeFile) 32 | 33 | 34 | def test_delete_db(): 35 | assert kbService.drop_kb() 36 | -------------------------------------------------------------------------------- /requirements_lite.txt: -------------------------------------------------------------------------------- 1 | langchain==0.0.354 2 | langchain-experimental==0.0.47 3 | pydantic==1.10.13 4 | fschat~=0.2.35 5 | openai~=1.9.0 6 | fastapi~=0.109.0 7 | sse_starlette~=1.8.2 8 | nltk~=3.8.1 9 | uvicorn>=0.27.0.post1 10 | starlette~=0.35.0 11 | unstructured[all-docs]~=0.12.0 12 | python-magic-bin; sys_platform == 'win32' 13 | SQLAlchemy~=2.0.25 14 | faiss-cpu~=1.7.4 15 | accelerate~=0.24.1 16 | spacy~=3.7.2 17 | PyMuPDF~=1.23.16 18 | rapidocr_onnxruntime~=1.3.8 19 | requests~=2.31.0 20 | pathlib~=1.0.1 21 | pytest~=7.4.3 22 | llama-index==0.9.35 23 | pyjwt==2.8.0 24 | httpx==0.26.0 25 | httpx_sse==0.4.0 26 | 27 | dashscope==1.13.6 28 | arxiv~=2.1.0 29 | youtube-search~=2.1.2 30 | duckduckgo-search~=3.9.9 31 | metaphor-python~=0.1.23 32 | watchdog~=3.0.0 33 | # volcengine>=1.0.134 34 | # pymilvus>=2.3.4 35 | # psycopg2==2.9.9 36 | # pgvector>=0.2.4 37 | # chromadb==0.4.13 38 | 39 | # jq==1.6.0 40 | # beautifulsoup4~=4.12.2 41 | # pysrt~=1.1.2 42 | 43 | -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md: -------------------------------------------------------------------------------- 1 | # 大模型技术栈-实战与应用 2 | - 训练框架 3 | - deepspeed 4 | - megatron-lm 5 | - colossal-ai 6 | - trlx 7 | - 推理框架 8 | - triton 9 | - vllm 10 | - text-generation-inference 11 | - lit-llama 12 | - lightllm 13 | - TensorRT-LLM(原FasterTransformer) 14 | - fastllm 15 | - inferllm 16 | - llama-cpp 17 | - openPPL-LLM 18 | - 压缩框架 19 | - bitsandbytes 20 | - auto-gptq 21 | - deepspeed 22 | - embedding框架 23 | - sentence-transformer 24 | - FlagEmbedding 25 | - 向量数据库 [向量数据库对比]("https://www.jianshu.com/p/43cc19426113") 26 | - faiss 27 | - pgvector 28 | - milvus 29 | - pinecone 30 | - weaviate 31 | - LanceDB 32 | - Chroma 33 | - 应用框架 34 | - Auto-GPT 35 | - langchain 36 | - llama-index 37 | - quivr 38 | - python前端 39 | - streamlit 40 | - gradio 41 | - python API工具 42 | - FastAPI+uvicorn 43 | - flask 44 | - Django -------------------------------------------------------------------------------- /server/db/models/knowledge_base_model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, func 2 | 3 | from server.db.base import Base 4 | 5 | 6 | class KnowledgeBaseModel(Base): 7 | """ 8 | 知识库模型 9 | """ 10 | __tablename__ = 'knowledge_base' 11 | id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID') 12 | kb_name = Column(String(50), comment='知识库名称') 13 | kb_info = Column(String(200), comment='知识库简介(用于Agent)') 14 | vs_type = Column(String(50), comment='向量库类型') 15 | embed_model = Column(String(50), comment='嵌入模型名称') 16 | file_count = Column(Integer, default=0, comment='文件数量') 17 | create_time = Column(DateTime, default=func.now(), comment='创建时间') 18 | 19 | def __repr__(self): 20 | return f"" 21 | -------------------------------------------------------------------------------- /tests/kb_vector_db/test_milvus_db.py: -------------------------------------------------------------------------------- 1 | from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService 2 | from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService 3 | from server.knowledge_base.kb_service.pg_kb_service import PGKBService 4 | from server.knowledge_base.migrate import create_tables 5 | from server.knowledge_base.utils import KnowledgeFile 6 | 7 | kbService = MilvusKBService("test") 8 | 9 | test_kb_name = "test" 10 | test_file_name = "README.md" 11 | testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name) 12 | search_content = "如何启动api服务" 13 | 14 | def test_init(): 15 | create_tables() 16 | 17 | 18 | def test_create_db(): 19 | assert kbService.create_kb() 20 | 21 | 22 | def test_add_doc(): 23 | assert kbService.add_doc(testKnowledgeFile) 24 | 25 | 26 | def test_search_db(): 27 | result = kbService.search_docs(search_content) 28 | assert len(result) > 0 29 | def test_delete_doc(): 30 | assert kbService.delete_doc(testKnowledgeFile) 31 | 32 | -------------------------------------------------------------------------------- /server/agent/tools/weather_check.py: -------------------------------------------------------------------------------- 1 | """ 2 | 更简单的单参数输入工具实现,用于查询现在天气的情况 3 | """ 4 | from pydantic import BaseModel, Field 5 | import requests 6 | from configs.kb_config import SENIVERSE_API_KEY 7 | 8 | 9 | def weather(location: str, api_key: str): 10 | url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c" 11 | response = requests.get(url) 12 | if response.status_code == 200: 13 | data = response.json() 14 | weather = { 15 | "temperature": data["results"][0]["now"]["temperature"], 16 | "description": data["results"][0]["now"]["text"], 17 | } 18 | return weather 19 | else: 20 | raise Exception( 21 | f"Failed to retrieve weather: {response.status_code}") 22 | 23 | 24 | def weathercheck(location: str): 25 | return weather(location, SENIVERSE_API_KEY) 26 | 27 | 28 | class WeatherInput(BaseModel): 29 | location: str = Field(description="City name,include city and county") 30 | -------------------------------------------------------------------------------- /markdown_docs/document_loaders/ocr.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef get_ocr(use_cuda) 2 | **get_ocr**: 此函数的功能是获取一个OCR对象,用于执行图像或PDF中的文字识别。 3 | 4 | **参数**: 5 | - use_cuda: 布尔值,指定是否使用CUDA加速。默认为True。 6 | 7 | **代码描述**: 8 | `get_ocr`函数旨在提供一个灵活的方式来获取文字识别(OCR)的功能对象。它首先尝试从`rapidocr_paddle`模块导入`RapidOCR`类,如果成功,将创建一个`RapidOCR`实例,其中的CUDA加速设置将根据`use_cuda`参数来决定。如果在尝试导入`rapidocr_paddle`时发生`ImportError`异常,表明可能未安装相应的包,函数则会尝试从`rapidocr_onnxruntime`模块导入`RapidOCR`类,并创建一个不指定CUDA加速的`RapidOCR`实例。这种设计使得函数能够在不同的环境配置下灵活工作,即使在缺少某些依赖的情况下也能尽可能地提供OCR服务。 9 | 10 | 在项目中,`get_ocr`函数被用于不同的场景来执行OCR任务。例如,在`document_loaders/myimgloader.py`的`img2text`方法中,它被用来将图片文件中的文字识别出来;而在`document_loaders/mypdfloader.py`的`pdf2text`方法中,它被用于识别PDF文件中的文字以及PDF中嵌入图片的文字。这显示了`get_ocr`函数在项目中的多功能性和重要性,它为处理不同类型的文档提供了统一的OCR解决方案。 11 | 12 | **注意**: 13 | - 在使用`get_ocr`函数时,需要确保至少安装了`rapidocr_paddle`或`rapidocr_onnxruntime`中的一个包,以便函数能够成功返回一个OCR对象。 14 | - 如果计划在没有CUDA支持的环境中使用,应将`use_cuda`参数设置为False,以避免不必要的错误。 15 | 16 | **输出示例**: 17 | 由于`get_ocr`函数返回的是一个`RapidOCR`对象,因此输出示例将依赖于该对象的具体实现。一般而言,可以预期该对象提供了执行OCR任务的方法,如对图片或PDF中的文字进行识别,并返回识别结果。 18 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Base Image 2 | FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 3 | # Labels 4 | LABEL maintainer=chatchat 5 | # Environment Variables 6 | ENV HOME=/Langchain-Chatchat 7 | # Commands 8 | WORKDIR / 9 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ 10 | echo "Asia/Shanghai" > /etc/timezone && \ 11 | apt-get update -y && \ 12 | apt-get install -y --no-install-recommends python3.11 python3-pip curl libgl1 libglib2.0-0 jq && \ 13 | apt-get clean && \ 14 | rm -rf /var/lib/apt/lists/* && \ 15 | rm -f /usr/bin/python3 && \ 16 | ln -s /usr/bin/python3.11 /usr/bin/python3 && \ 17 | mkdir -p $HOME 18 | # Copy the application files 19 | COPY . $HOME 20 | WORKDIR $HOME 21 | # Install dependencies from requirements.txt 22 | RUN pip3 install -r requirements.txt -i https://pypi.org/simple && \ 23 | python3 copy_config_example.py && \ 24 | sed -i 's|MODEL_ROOT_PATH = ""|MODEL_ROOT_PATH = "/Langchain-Chatchat"|' configs/model_config.py && \ 25 | python3 init_database.py --recreate-vs 26 | EXPOSE 22 7861 8501 27 | ENTRYPOINT ["python3", "startup.py", "-a"] -------------------------------------------------------------------------------- /markdown_docs/server/chat/feedback.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef chat_feedback(message_id, score, reason) 2 | **chat_feedback**: 此函数用于处理用户对聊天记录的反馈。 3 | 4 | **参数**: 5 | - `message_id`: 聊天记录的唯一标识ID,用于定位需要反馈的聊天记录。此参数最大长度为32个字符。 6 | - `score`: 用户对聊天记录的评分,满分为100。评分越高表示用户对聊天记录的满意度越高。 7 | - `reason`: 用户提供的评分理由,例如聊天记录不符合事实等。 8 | 9 | **代码描述**: 10 | `chat_feedback`函数首先尝试调用`feedback_message_to_db`函数,将用户的反馈信息(包括聊天记录ID、评分和评分理由)存储到数据库中。如果在执行过程中遇到任何异常,函数将捕获这些异常,并通过`logger.error`记录错误信息,同时返回一个包含错误信息的`BaseResponse`对象,状态码为500,表示服务器内部错误。如果没有发生异常,函数将返回一个状态码为200的`BaseResponse`对象,表示用户反馈已成功处理,并附带消息“已反馈聊天记录 {message_id}”。 11 | 12 | **注意**: 13 | - 在调用此函数之前,确保传入的`message_id`是有效的,并且在数据库中存在对应的聊天记录。 14 | - `score`参数应在0到100之间,以确保评分的有效性。 15 | - 在实际应用中,可能需要对用户的评分理由`reason`进行长度或内容的校验,以避免存储无效或不恰当的信息。 16 | - 此函数通过捕获异常并记录错误信息,提高了代码的健壮性。开发者应关注日志输出,以便及时发现并处理潜在的问题。 17 | 18 | **输出示例**: 19 | 如果用户反馈成功处理,函数可能返回如下的`BaseResponse`对象示例: 20 | ```json 21 | { 22 | "code": 200, 23 | "msg": "已反馈聊天记录 1234567890abcdef" 24 | } 25 | ``` 26 | 如果处理过程中发生异常,函数可能返回如下的`BaseResponse`对象示例: 27 | ```json 28 | { 29 | "code": 500, 30 | "msg": "反馈聊天记录出错:[异常信息]" 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /server/db/session.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from contextlib import contextmanager 3 | from server.db.base import SessionLocal 4 | from sqlalchemy.orm import Session 5 | 6 | 7 | @contextmanager 8 | def session_scope() -> Session: 9 | """上下文管理器用于自动获取 Session, 避免错误""" 10 | session = SessionLocal() 11 | try: 12 | yield session 13 | session.commit() 14 | except: 15 | session.rollback() 16 | raise 17 | finally: 18 | session.close() 19 | 20 | 21 | def with_session(f): 22 | @wraps(f) 23 | def wrapper(*args, **kwargs): 24 | with session_scope() as session: 25 | try: 26 | result = f(session, *args, **kwargs) 27 | session.commit() 28 | return result 29 | except: 30 | session.rollback() 31 | raise 32 | 33 | return wrapper 34 | 35 | 36 | def get_db() -> SessionLocal: 37 | db = SessionLocal() 38 | try: 39 | yield db 40 | finally: 41 | db.close() 42 | 43 | 44 | def get_db0() -> SessionLocal: 45 | db = SessionLocal() 46 | return db 47 | -------------------------------------------------------------------------------- /markdown_docs/server/db/models/base.md: -------------------------------------------------------------------------------- 1 | ## ClassDef BaseModel 2 | **BaseModel**: BaseModel 的功能是提供一个数据库模型的基础结构。 3 | 4 | **属性**: 5 | - `id`: 主键ID,用于唯一标识每个记录。 6 | - `create_time`: 记录的创建时间。 7 | - `update_time`: 记录的最后更新时间。 8 | - `create_by`: 记录的创建者。 9 | - `update_by`: 记录的最后更新者。 10 | 11 | **代码描述**: 12 | BaseModel 类定义了一个数据库模型的基础结构,它包含了几个常见且重要的字段。这些字段包括: 13 | - `id` 字段使用 `Column` 函数定义,其类型为 `Integer`,并且被设置为主键(`primary_key=True`),同时启用索引(`index=True`),以便提高查询效率。此外,该字段还有一个注释(`comment="主键ID"`),用于说明字段的用途。 14 | - `create_time` 字段记录了数据被创建的时间,其类型为 `DateTime`。该字段的默认值通过 `datetime.utcnow` 函数设置,以确保使用的是创建记录时的UTC时间。此字段同样有一个注释(`comment="创建时间"`)。 15 | - `update_time` 字段记录了数据最后一次被更新的时间,类型也是 `DateTime`。不同的是,它的默认值设置为 `None`,并且通过 `onupdate=datetime.utcnow` 参数设置,当记录更新时,此字段会自动更新为当前的UTC时间。该字段也有相应的注释(`comment="更新时间"`)。 16 | - `create_by` 和 `update_by` 字段用于记录数据的创建者和最后更新者的信息,它们的类型都是 `String`。默认值为 `None`,并且各自有对应的注释(`comment="创建者"` 和 `comment="更新者"`),用于说明字段的用途。 17 | 18 | **注意**: 19 | - 使用BaseModel时,需要注意`create_time`和`update_time`字段默认使用的是UTC时间,这意味着如果应用程序在不同的时区运行,可能需要进行相应的时区转换。 20 | - `id`字段被设置为主键和索引,这对于数据库性能优化是非常重要的。确保每个模型都有一个唯一的标识符。 21 | - `create_by` 和 `update_by` 字段的默认值为 `None`,在实际应用中,根据业务需求,可能需要在数据创建或更新时,显式地设置这些字段的值。 22 | -------------------------------------------------------------------------------- /server/callback_handler/conversation_callback_handler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from langchain.callbacks.base import BaseCallbackHandler 4 | from langchain.schema import LLMResult 5 | from server.db.repository import update_message 6 | 7 | 8 | class ConversationCallbackHandler(BaseCallbackHandler): 9 | raise_error: bool = True 10 | 11 | def __init__(self, conversation_id: str, message_id: str, chat_type: str, query: str): 12 | self.conversation_id = conversation_id 13 | self.message_id = message_id 14 | self.chat_type = chat_type 15 | self.query = query 16 | self.start_at = None 17 | 18 | @property 19 | def always_verbose(self) -> bool: 20 | """Whether to call verbose callbacks even if verbose is False.""" 21 | return True 22 | 23 | def on_llm_start( 24 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any 25 | ) -> None: 26 | # 如果想存更多信息,则prompts 也需要持久化 27 | pass 28 | 29 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: 30 | answer = response.generations[0][0].text 31 | update_message(self.message_id, answer) 32 | -------------------------------------------------------------------------------- /markdown_docs/server/db/repository/conversation_repository.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef add_conversation_to_db(session, chat_type, name, conversation_id) 2 | **add_conversation_to_db**: 此函数的功能是向数据库中新增一条聊天记录。 3 | 4 | **参数**: 5 | - `session`: 数据库会话实例,用于执行数据库操作。 6 | - `chat_type`: 字符串,表示聊天的类型(例如普通聊天、客服聊天等)。 7 | - `name`: 字符串,聊天记录的名称,默认为空字符串。 8 | - `conversation_id`: 字符串,聊天记录的唯一标识符,默认为None,若未提供,则会自动生成。 9 | 10 | **代码描述**: 11 | 此函数首先检查是否提供了`conversation_id`参数。如果没有提供,函数将使用`uuid.uuid4().hex`生成一个唯一的标识符。接着,函数创建一个`ConversationModel`实例,其中包含了聊天记录的ID、聊天类型、名称等信息。然后,通过`session.add(c)`将此实例添加到数据库会话中,准备将其保存到数据库。最后,函数返回新创建的聊天记录的ID。 12 | 13 | 此函数与`ConversationModel`类紧密相关,`ConversationModel`类定义了聊天记录的数据模型,包括聊天记录的ID、名称、聊天类型和创建时间等字段。`add_conversation_to_db`函数通过创建`ConversationModel`的实例并将其添加到数据库中,实现了聊天记录的新增功能。这体现了`ConversationModel`在项目中用于处理聊天记录数据的重要作用。 14 | 15 | **注意**: 16 | - 在调用此函数时,需要确保`session`参数是一个有效的数据库会话实例,以便能够正确执行数据库操作。 17 | - `chat_type`参数是必需的,因为它定义了聊天记录的类型,这对于后续的数据处理和查询是非常重要的。 18 | - 如果在调用函数时没有提供`conversation_id`,则会自动生成一个。这意味着每条聊天记录都将拥有一个唯一的标识符,即使在未显式指定ID的情况下也是如此。 19 | 20 | **输出示例**: 21 | 假设调用`add_conversation_to_db`函数,并传入相应的参数,函数可能会返回如下的聊天记录ID: 22 | ``` 23 | "e4eaaaf2-d142-11e1-b3e4-080027620cdd" 24 | ``` 25 | 这个返回值表示新创建的聊天记录的唯一标识符。 26 | -------------------------------------------------------------------------------- /server/db/models/message_model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, JSON, func 2 | 3 | from server.db.base import Base 4 | 5 | 6 | class MessageModel(Base): 7 | """ 8 | 聊天记录模型 9 | """ 10 | __tablename__ = 'message' 11 | id = Column(String(32), primary_key=True, comment='聊天记录ID') 12 | conversation_id = Column(String(32), default=None, index=True, comment='对话框ID') 13 | # chat/agent_chat等 14 | chat_type = Column(String(50), comment='聊天类型') 15 | query = Column(String(4096), comment='用户问题') 16 | response = Column(String(4096), comment='模型回答') 17 | # 记录知识库id等,以便后续扩展 18 | meta_data = Column(JSON, default={}) 19 | # 满分100 越高表示评价越好 20 | feedback_score = Column(Integer, default=-1, comment='用户评分') 21 | feedback_reason = Column(String(255), default="", comment='用户评分理由') 22 | create_time = Column(DateTime, default=func.now(), comment='创建时间') 23 | 24 | def __repr__(self): 25 | return f"" 26 | -------------------------------------------------------------------------------- /tests/api/test_server_state_api.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | root_path = Path(__file__).parent.parent.parent 4 | sys.path.append(str(root_path)) 5 | 6 | from webui_pages.utils import ApiRequest 7 | 8 | import pytest 9 | from pprint import pprint 10 | from typing import List 11 | 12 | 13 | api = ApiRequest() 14 | 15 | 16 | def test_get_default_llm(): 17 | llm = api.get_default_llm_model() 18 | 19 | print(llm) 20 | assert isinstance(llm, tuple) 21 | assert isinstance(llm[0], str) and isinstance(llm[1], bool) 22 | 23 | 24 | def test_server_configs(): 25 | configs = api.get_server_configs() 26 | pprint(configs, depth=2) 27 | 28 | assert isinstance(configs, dict) 29 | assert len(configs) > 0 30 | 31 | 32 | def test_list_search_engines(): 33 | engines = api.list_search_engines() 34 | pprint(engines) 35 | 36 | assert isinstance(engines, list) 37 | assert len(engines) > 0 38 | 39 | 40 | @pytest.mark.parametrize("type", ["llm_chat", "agent_chat"]) 41 | def test_get_prompt_template(type): 42 | print(f"prompt template for: {type}") 43 | template = api.get_prompt_template(type=type) 44 | 45 | print(template) 46 | assert isinstance(template, str) 47 | assert len(template) > 0 48 | -------------------------------------------------------------------------------- /server/db/models/knowledge_metadata_model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func 2 | 3 | from server.db.base import Base 4 | 5 | 6 | class SummaryChunkModel(Base): 7 | """ 8 | chunk summary模型,用于存储file_doc中每个doc_id的chunk 片段, 9 | 数据来源: 10 | 用户输入: 用户上传文件,可填写文件的描述,生成的file_doc中的doc_id,存入summary_chunk中 11 | 程序自动切分 对file_doc表meta_data字段信息中存储的页码信息,按每页的页码切分,自定义prompt生成总结文本,将对应页码关联的doc_id存入summary_chunk中 12 | 后续任务: 13 | 矢量库构建: 对数据库表summary_chunk中summary_context创建索引,构建矢量库,meta_data为矢量库的元数据(doc_ids) 14 | 语义关联: 通过用户输入的描述,自动切分的总结文本,计算 15 | 语义相似度 16 | 17 | """ 18 | __tablename__ = 'summary_chunk' 19 | id = Column(Integer, primary_key=True, autoincrement=True, comment='ID') 20 | kb_name = Column(String(50), comment='知识库名称') 21 | summary_context = Column(String(255), comment='总结文本') 22 | summary_id = Column(String(255), comment='总结矢量id') 23 | doc_ids = Column(String(1024), comment="向量库id关联列表") 24 | meta_data = Column(JSON, default={}) 25 | 26 | def __repr__(self): 27 | return (f"") 29 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/model_contain.md: -------------------------------------------------------------------------------- 1 | ## ClassDef ModelContainer 2 | **ModelContainer**: ModelContainer 类的功能是作为模型和数据库的容器。 3 | 4 | **属性**: 5 | - MODEL: 用于存储模型实例。初始值为 None,表示在创建 ModelContainer 实例时,并没有预设的模型。 6 | - DATABASE: 用于存储数据库连接实例。初始值同样为 None,表示在创建 ModelContainer 实例时,并没有预设的数据库连接。 7 | 8 | **代码描述**: 9 | ModelContainer 类是一个简单的容器类,设计用来存储模型实例和数据库连接实例。这个类通过定义两个属性 `MODEL` 和 `DATABASE` 来实现其功能。这两个属性在类的初始化方法 `__init__` 中被设置为 None,这意味着在创建 ModelContainer 的实例时,这两个属性都不会持有任何值。这种设计允许开发者在创建 ModelContainer 实例后,根据需要将模型实例和数据库连接实例分别赋值给这两个属性。 10 | 11 | **注意**: 12 | - 在使用 ModelContainer 类时,开发者需要注意,`MODEL` 和 `DATABASE` 属性在初始状态下是 None。因此,在尝试访问这些属性或其方法之前,需要确保它们已被正确赋值,以避免遇到 `NoneType` 对象没有该方法的错误。 13 | - ModelContainer 类提供了一种灵活的方式来管理模型和数据库连接,但它本身不提供任何方法来初始化 `MODEL` 和 `DATABASE` 属性。开发者需要根据自己的需求,手动为这两个属性赋值。 14 | - 由于 ModelContainer 类的设计相对简单,它可以根据项目的需要进行扩展,例如添加更多的属性或方法来满足更复杂的需求。 15 | ### FunctionDef __init__(self) 16 | **__init__**: 此函数用于初始化ModelContainer类的实例。 17 | 18 | **参数**: 此函数不接受任何外部参数。 19 | 20 | **代码描述**: 在ModelContainer类的实例被创建时,`__init__`函数会被自动调用。此函数主要完成以下几点初始化操作: 21 | - 将`MODEL`属性设置为`None`。这意味着在实例化后,该属性暂时不关联任何模型,需要后续根据具体需求进行赋值。 22 | - 将`DATABASE`属性也设置为`None`。这表明在实例化的初始阶段,该属性不关联任何数据库,同样需要在后续操作中根据需要进行关联。 23 | 24 | 通过这种方式,`__init__`函数为ModelContainer类的实例提供了一个清晰、干净的初始状态,便于后续的属性赋值和方法调用。 25 | 26 | **注意**: 27 | - 在使用ModelContainer类创建实例后,需要根据实际情况给`MODEL`和`DATABASE`属性赋予具体的模型和数据库实例,以便于进行后续的操作。 28 | - 由于`MODEL`和`DATABASE`在初始化时都被设置为`None`,在对这两个属性进行操作前,建议先检查它们是否已被正确赋值,以避免在使用未初始化的属性时引发错误。 29 | *** 30 | -------------------------------------------------------------------------------- /server/chat/utils.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from langchain.prompts.chat import ChatMessagePromptTemplate 3 | from configs import logger, log_verbose 4 | from typing import List, Tuple, Dict, Union 5 | 6 | 7 | class History(BaseModel): 8 | """ 9 | 对话历史 10 | 可从dict生成,如 11 | h = History(**{"role":"user","content":"你好"}) 12 | 也可转换为tuple,如 13 | h.to_msy_tuple = ("human", "你好") 14 | """ 15 | role: str = Field(...) 16 | content: str = Field(...) 17 | 18 | def to_msg_tuple(self): 19 | return "ai" if self.role=="assistant" else "human", self.content 20 | 21 | def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate: 22 | role_maps = { 23 | "ai": "assistant", 24 | "human": "user", 25 | } 26 | role = role_maps.get(self.role, self.role) 27 | if is_raw: # 当前默认历史消息都是没有input_variable的文本。 28 | content = "{% raw %}" + self.content + "{% endraw %}" 29 | else: 30 | content = self.content 31 | 32 | return ChatMessagePromptTemplate.from_template( 33 | content, 34 | "jinja2", 35 | role=role, 36 | ) 37 | 38 | @classmethod 39 | def from_data(cls, h: Union[List, Tuple, Dict]) -> "History": 40 | if isinstance(h, (list,tuple)) and len(h) >= 2: 41 | h = cls(role=h[0], content=h[1]) 42 | elif isinstance(h, dict): 43 | h = cls(**h) 44 | 45 | return h 46 | -------------------------------------------------------------------------------- /server/agent/tools/search_knowledgebase_simple.py: -------------------------------------------------------------------------------- 1 | from server.chat.knowledge_base_chat import knowledge_base_chat 2 | from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS 3 | import json 4 | import asyncio 5 | from server.agent import model_container 6 | 7 | async def search_knowledge_base_iter(database: str, query: str) -> str: 8 | response = await knowledge_base_chat(query=query, 9 | knowledge_base_name=database, 10 | model_name=model_container.MODEL.model_name, 11 | temperature=0.01, 12 | history=[], 13 | top_k=VECTOR_SEARCH_TOP_K, 14 | max_tokens=MAX_TOKENS, 15 | prompt_name="knowledge_base_chat", 16 | score_threshold=SCORE_THRESHOLD, 17 | stream=False) 18 | 19 | contents = "" 20 | async for data in response.body_iterator: # 这里的data是一个json字符串 21 | data = json.loads(data) 22 | contents = data["answer"] 23 | docs = data["docs"] 24 | return contents 25 | 26 | def search_knowledgebase_simple(query: str): 27 | return asyncio.run(search_knowledge_base_iter(query)) 28 | 29 | 30 | if __name__ == "__main__": 31 | result = search_knowledgebase_simple("大数据男女比例") 32 | print("答案:",result) -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug 报告 / Bug Report 3 | about: 报告项目中的错误或问题 / Report errors or issues in the project 4 | title: "[BUG] 简洁阐述问题 / Concise description of the issue" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **问题描述 / Problem Description** 11 | 用简洁明了的语言描述这个问题 / Describe the problem in a clear and concise manner. 12 | 13 | **复现问题的步骤 / Steps to Reproduce** 14 | 1. 执行 '...' / Run '...' 15 | 2. 点击 '...' / Click '...' 16 | 3. 滚动到 '...' / Scroll to '...' 17 | 4. 问题出现 / Problem occurs 18 | 19 | **预期的结果 / Expected Result** 20 | 描述应该出现的结果 / Describe the expected result. 21 | 22 | **实际结果 / Actual Result** 23 | 描述实际发生的结果 / Describe the actual result. 24 | 25 | **环境信息 / Environment Information** 26 | - langchain-ChatGLM 版本/commit 号:(例如:v2.0.1 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v2.0.1 or commit 123456) 27 | - 是否使用 Docker 部署(是/否):是 / Is Docker deployment used (yes/no): yes 28 | - 使用的模型(ChatGLM2-6B / Qwen-7B 等):ChatGLM-6B / Model used (ChatGLM2-6B / Qwen-7B, etc.): ChatGLM2-6B 29 | - 使用的 Embedding 模型(moka-ai/m3e-base 等):moka-ai/m3e-base / Embedding model used (moka-ai/m3e-base, etc.): moka-ai/m3e-base 30 | - 使用的向量库类型 (faiss / milvus / pg_vector 等): faiss / Vector library used (faiss, milvus, pg_vector, etc.): faiss 31 | - 操作系统及版本 / Operating system and version: 32 | - Python 版本 / Python version: 33 | - 其他相关环境信息 / Other relevant environment information: 34 | 35 | **附加信息 / Additional Information** 36 | 添加与问题相关的任何其他信息 / Add any other information related to the issue. -------------------------------------------------------------------------------- /requirements_api.txt: -------------------------------------------------------------------------------- 1 | torch~=2.1.2 2 | torchvision~=0.16.2 3 | torchaudio~=2.1.2 4 | xformers>=0.0.23.post1 5 | transformers==4.37.2 6 | sentence_transformers==2.2.2 7 | langchain==0.0.354 8 | langchain-experimental==0.0.47 9 | pydantic==1.10.13 10 | fschat==0.2.35 11 | openai~=1.9.0 12 | fastapi~=0.109.0 13 | sse_starlette==1.8.2 14 | nltk>=3.8.1 15 | uvicorn>=0.27.0.post1 16 | starlette~=0.35.0 17 | unstructured[all-docs]==0.11.0 18 | python-magic-bin; sys_platform == 'win32' 19 | SQLAlchemy==2.0.19 20 | faiss-cpu~=1.7.4 21 | accelerate~=0.24.1 22 | spacy~=3.7.2 23 | PyMuPDF~=1.23.8 24 | rapidocr_onnxruntime==1.3.8 25 | requests~=2.31.0 26 | pathlib~=1.0.1 27 | pytest~=7.4.3 28 | numexpr~=2.8.6 29 | strsimpy~=0.2.1 30 | markdownify~=0.11.6 31 | tiktoken~=0.5.2 32 | tqdm>=4.66.1 33 | websockets>=12.0 34 | numpy~=1.24.4 35 | pandas~=2.0.3 36 | einops>=0.7.0 37 | transformers_stream_generator==0.0.4 38 | vllm==0.2.7; sys_platform == "linux" 39 | httpx==0.26.0 40 | httpx_sse==0.4.0 41 | llama-index==0.9.35 42 | pyjwt==2.8.0 43 | 44 | # jq==1.6.0 45 | # beautifulsoup4~=4.12.2 46 | # pysrt~=1.1.2 47 | # dashscope==1.13.6 48 | # arxiv~=2.1.0 49 | # youtube-search~=2.1.2 50 | # duckduckgo-search~=3.9.9 51 | # metaphor-python~=0.1.23 52 | 53 | # volcengine>=1.0.134 54 | # pymilvus==2.3.6 55 | # psycopg2==2.9.9 56 | # pgvector>=0.2.4 57 | # chromadb==0.4.13 58 | 59 | #flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat 60 | #autoawq==0.1.8 # For Int4 61 | #rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files -------------------------------------------------------------------------------- /server/agent/tools/calculate.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import PromptTemplate 2 | from langchain.chains import LLMMathChain 3 | from server.agent import model_container 4 | from pydantic import BaseModel, Field 5 | 6 | _PROMPT_TEMPLATE = """ 7 | 将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。 8 | 问题: ${{包含数学问题的问题。}} 9 | ```text 10 | ${{解决问题的单行数学表达式}} 11 | ``` 12 | ...numexpr.evaluate(query)... 13 | ```output 14 | ${{运行代码的输出}} 15 | ``` 16 | 答案: ${{答案}} 17 | 18 | 这是两个例子: 19 | 20 | 问题: 37593 * 67是多少? 21 | ```text 22 | 37593 * 67 23 | ``` 24 | ...numexpr.evaluate("37593 * 67")... 25 | ```output 26 | 2518731 27 | 28 | 答案: 2518731 29 | 30 | 问题: 37593的五次方根是多少? 31 | ```text 32 | 37593**(1/5) 33 | ``` 34 | ...numexpr.evaluate("37593**(1/5)")... 35 | ```output 36 | 8.222831614237718 37 | 38 | 答案: 8.222831614237718 39 | 40 | 41 | 问题: 2的平方是多少? 42 | ```text 43 | 2 ** 2 44 | ``` 45 | ...numexpr.evaluate("2 ** 2")... 46 | ```output 47 | 4 48 | 49 | 答案: 4 50 | 51 | 52 | 现在,这是我的问题: 53 | 问题: {question} 54 | """ 55 | 56 | PROMPT = PromptTemplate( 57 | input_variables=["question"], 58 | template=_PROMPT_TEMPLATE, 59 | ) 60 | 61 | 62 | class CalculatorInput(BaseModel): 63 | query: str = Field() 64 | 65 | def calculate(query: str): 66 | model = model_container.MODEL 67 | llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT) 68 | ans = llm_math.run(query) 69 | return ans 70 | 71 | if __name__ == "__main__": 72 | result = calculate("2的三次方") 73 | print("答案:",result) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /server/agent/tools/search_internet.py: -------------------------------------------------------------------------------- 1 | import json 2 | from server.chat.search_engine_chat import search_engine_chat 3 | from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS 4 | import asyncio 5 | from server.agent import model_container 6 | from pydantic import BaseModel, Field 7 | 8 | async def search_engine_iter(query: str): 9 | response = await search_engine_chat(query=query, 10 | search_engine_name="bing", # 这里切换搜索引擎 11 | model_name=model_container.MODEL.model_name, 12 | temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01 13 | history=[], 14 | top_k = VECTOR_SEARCH_TOP_K, 15 | max_tokens= MAX_TOKENS, 16 | prompt_name = "default", 17 | stream=False) 18 | 19 | contents = "" 20 | 21 | async for data in response.body_iterator: # 这里的data是一个json字符串 22 | data = json.loads(data) 23 | contents = data["answer"] 24 | docs = data["docs"] 25 | 26 | return contents 27 | 28 | def search_internet(query: str): 29 | return asyncio.run(search_engine_iter(query)) 30 | 31 | class SearchInternetInput(BaseModel): 32 | location: str = Field(description="Query for Internet search") 33 | 34 | 35 | if __name__ == "__main__": 36 | result = search_internet("今天星期几") 37 | print("答案:",result) 38 | -------------------------------------------------------------------------------- /text_splitter/ali_text_splitter.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import CharacterTextSplitter 2 | import re 3 | from typing import List 4 | 5 | 6 | class AliTextSplitter(CharacterTextSplitter): 7 | def __init__(self, pdf: bool = False, **kwargs): 8 | super().__init__(**kwargs) 9 | self.pdf = pdf 10 | 11 | def split_text(self, text: str) -> List[str]: 12 | # use_document_segmentation参数指定是否用语义切分文档,此处采取的文档语义分割模型为达摩院开源的nlp_bert_document-segmentation_chinese-base,论文见https://arxiv.org/abs/2107.09278 13 | # 如果使用模型进行文档语义切分,那么需要安装modelscope[nlp]:pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html 14 | # 考虑到使用了三个模型,可能对于低配置gpu不太友好,因此这里将模型load进cpu计算,有需要的话可以替换device为自己的显卡id 15 | if self.pdf: 16 | text = re.sub(r"\n{3,}", r"\n", text) 17 | text = re.sub('\s', " ", text) 18 | text = re.sub("\n\n", "", text) 19 | try: 20 | from modelscope.pipelines import pipeline 21 | except ImportError: 22 | raise ImportError( 23 | "Could not import modelscope python package. " 24 | "Please install modelscope with `pip install modelscope`. " 25 | ) 26 | 27 | 28 | p = pipeline( 29 | task="document-segmentation", 30 | model='damo/nlp_bert_document-segmentation_chinese-base', 31 | device="cpu") 32 | result = p(documents=text) 33 | sent_list = [i for i in result["text"].split("\n\t") if i] 34 | return sent_list 35 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/arxiv.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef arxiv(query) 2 | **arxiv**: 该函数用于执行对Arxiv的查询操作。 3 | 4 | **参数**: 5 | - **query**: 字符串类型,表示要在Arxiv上执行的查询内容。 6 | 7 | **代码描述**: 8 | `arxiv`函数是一个简单但功能强大的接口,用于在Arxiv数据库中执行查询。它首先创建了一个`ArxivQueryRun`的实例,然后调用该实例的`run`方法来执行查询。查询的具体内容由参数`query`指定,该参数应为一个字符串,表示用户希望在Arxiv上搜索的关键词或查询表达式。 9 | 10 | 在项目结构中,`arxiv`函数位于`server/agent/tools/arxiv.py`路径下,并且是`arxiv.py`模块中定义的核心功能之一。尽管在当前项目的其他部分,如`server/agent/tools/__init__.py`和`server/agent/tools_select.py`中没有直接的调用示例,但可以推断`arxiv`函数设计为被这些模块或其他项目部分调用,以实现对Arxiv数据库的查询功能。 11 | 12 | **注意**: 13 | - 在使用`arxiv`函数时,需要确保传入的查询字符串`query`是有效的,即它应该符合Arxiv的查询语法和要求。 14 | - 该函数的执行结果依赖于`ArxivQueryRun`类的`run`方法的实现,因此需要确保该方法能够正确处理传入的查询字符串,并返回期望的查询结果。 15 | 16 | **输出示例**: 17 | 假设对`arxiv`函数的调用如下: 18 | ```python 19 | result = arxiv("deep learning") 20 | ``` 21 | 则该函数可能返回一个包含查询结果的对象,例如包含多篇关于深度学习的论文的列表。具体的返回值格式将取决于`ArxivQueryRun`类的`run`方法的实现细节。 22 | ## ClassDef ArxivInput 23 | **ArxivInput**: ArxivInput类的功能是定义一个用于搜索查询的输入模型。 24 | 25 | **属性**: 26 | - query: 表示搜索查询标题的字符串。 27 | 28 | **代码描述**: 29 | ArxivInput类继承自BaseModel,这意味着它是一个模型类,用于定义数据结构。在这个类中,定义了一个名为`query`的属性,该属性是一个字符串类型,用于存储用户的搜索查询标题。通过使用`Field`函数,为`query`属性提供了一个描述,即"The search query title",这有助于理解该属性的用途。 30 | 31 | 在项目中,ArxivInput类作为一个数据模型,被用于处理与arXiv相关的搜索查询。尽管在提供的代码调用情况中没有直接的示例,但可以推断,该类可能会被用于在`server/agent/tools`目录下的其他模块中,作为接收用户搜索请求的输入参数。这样的设计使得代码更加模块化,便于维护和扩展。 32 | 33 | **注意**: 34 | - 在使用ArxivInput类时,需要确保传入的`query`参数是一个有效的字符串,因为它将直接影响搜索结果的相关性和准确性。 35 | - 由于ArxivInput继承自BaseModel,可以利用Pydantic库提供的数据验证功能,确保输入数据的合法性。 36 | - 考虑到ArxivInput类可能会被用于网络请求,应当注意处理潜在的安全问题,如SQL注入或跨站脚本攻击(XSS),确保用户输入被适当地清理和验证。 37 | -------------------------------------------------------------------------------- /release.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import re 4 | 5 | def get_latest_tag(): 6 | output = subprocess.check_output(['git', 'tag']) 7 | tags = output.decode('utf-8').split('\n')[:-1] 8 | latest_tag = sorted(tags, key=lambda t: tuple(map(int, re.match(r'v(\d+)\.(\d+)\.(\d+)', t).groups())))[-1] 9 | return latest_tag 10 | 11 | def update_version_number(latest_tag, increment): 12 | major, minor, patch = map(int, re.match(r'v(\d+)\.(\d+)\.(\d+)', latest_tag).groups()) 13 | if increment == 'X': 14 | major += 1 15 | minor, patch = 0, 0 16 | elif increment == 'Y': 17 | minor += 1 18 | patch = 0 19 | elif increment == 'Z': 20 | patch += 1 21 | new_version = f"v{major}.{minor}.{patch}" 22 | return new_version 23 | 24 | def main(): 25 | print("当前最近的Git标签:") 26 | latest_tag = get_latest_tag() 27 | print(latest_tag) 28 | 29 | print("请选择要递增的版本号部分(X, Y, Z):") 30 | increment = input().upper() 31 | 32 | while increment not in ['X', 'Y', 'Z']: 33 | print("输入错误,请输入X, Y或Z:") 34 | increment = input().upper() 35 | 36 | new_version = update_version_number(latest_tag, increment) 37 | print(f"新的版本号为:{new_version}") 38 | 39 | print("确认更新版本号并推送到远程仓库?(y/n)") 40 | confirmation = input().lower() 41 | 42 | if confirmation == 'y': 43 | subprocess.run(['git', 'tag', new_version]) 44 | subprocess.run(['git', 'push', 'origin', new_version]) 45 | print("新版本号已创建并推送到远程仓库。") 46 | else: 47 | print("操作已取消。") 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/shell.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef shell(query) 2 | **shell**: shell函数的功能是执行一个shell查询并返回结果。 3 | 4 | **参数**: 5 | - query: 字符串类型,表示要执行的shell查询命令。 6 | 7 | **代码描述**: 8 | 该shell函数定义在`server/agent/tools/shell.py`文件中,是项目中用于执行shell命令的核心函数。函数接收一个名为`query`的字符串参数,该参数是需要执行的shell命令。函数内部首先创建了一个`ShellTool`类的实例`tool`,然后调用这个实例的`run`方法执行传入的`query`命令。最终,函数返回`run`方法的执行结果。 9 | 10 | 在项目的结构中,虽然`server/agent/tools/__init__.py`和`server/agent/tools_select.py`这两个文件中没有直接的代码示例或文档说明如何调用`shell`函数,但可以推断,`shell`函数作为工具模块中的一部分,可能会被项目中的其他部分调用以执行特定的shell命令。这种设计使得执行shell命令的逻辑被封装在一个单独的函数中,便于维护和重用。 11 | 12 | **注意**: 13 | - 在使用`shell`函数时,需要确保传入的`query`命令是安全的,避免执行恶意代码。 14 | - 该函数的执行结果取决于`ShellTool`类的`run`方法如何实现,因此需要了解`ShellTool`的具体实现细节。 15 | 16 | **输出示例**: 17 | 假设`ShellTool`的`run`方法简单地返回执行命令的输出,如果调用`shell("echo Hello World")`,那么可能的返回值为: 18 | ``` 19 | Hello World 20 | ``` 21 | ## ClassDef ShellInput 22 | **ShellInput**: ShellInput类的功能是定义一个用于封装Shell命令的数据模型。 23 | 24 | **属性**: 25 | - query: 一个字符串类型的属性,用于存储可以在Linux命令行中执行的Shell命令。该属性通过Field方法定义,其中包含一个描述信息,说明这是一个可执行的Shell命令。 26 | 27 | **代码描述**: 28 | ShellInput类继承自BaseModel,这表明它是一个基于Pydantic库的模型,用于数据验证和管理。在这个类中,定义了一个名为`query`的属性,这个属性必须是一个字符串。通过使用Field方法,为`query`属性提供了一个描述,即“一个能在Linux命令行运行的Shell命令”,这有助于理解该属性的用途和功能。 29 | 30 | 在项目的上下文中,虽然当前提供的信息没有直接展示ShellInput类如何被其他对象调用,但可以推断,ShellInput类可能被用于封装用户输入或者其他来源的Shell命令,之后这些封装好的命令可能会在项目的其他部分,如服务器的代理工具中被执行。这样的设计使得Shell命令的处理更加模块化和安全,因为Pydantic模型提供了一层数据验证,确保只有合法和预期的命令才会被执行。 31 | 32 | **注意**: 33 | - 使用ShellInput类时,需要确保传入的`query`字符串是有效且安全的Shell命令。考虑到Shell命令的强大功能和潜在的安全风险,应当避免执行来自不可信源的命令。 34 | - 由于ShellInput类基于Pydantic库,使用该类之前需要确保项目中已经安装了Pydantic。此外,熟悉Pydantic库的基本使用和数据验证机制将有助于更有效地利用ShellInput类。 35 | -------------------------------------------------------------------------------- /tests/api/test_kb_summary_api.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import sys 4 | from pathlib import Path 5 | 6 | root_path = Path(__file__).parent.parent.parent 7 | sys.path.append(str(root_path)) 8 | from server.utils import api_address 9 | 10 | api_base_url = api_address() 11 | 12 | kb = "samples" 13 | file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md" 14 | doc_ids = [ 15 | "357d580f-fdf7-495c-b58b-595a398284e8", 16 | "c7338773-2e83-4671-b237-1ad20335b0f0", 17 | "6da613d1-327d-466f-8c1a-b32e6f461f47" 18 | ] 19 | 20 | 21 | def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"): 22 | url = api_base_url + api 23 | print("\n文件摘要:") 24 | r = requests.post(url, json={"knowledge_base_name": kb, 25 | "file_name": file_name 26 | }, stream=True) 27 | for chunk in r.iter_content(None): 28 | data = json.loads(chunk[6:]) 29 | assert isinstance(data, dict) 30 | assert data["code"] == 200 31 | print(data["msg"]) 32 | 33 | 34 | def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"): 35 | url = api_base_url + api 36 | print("\n文件摘要:") 37 | r = requests.post(url, json={"knowledge_base_name": kb, 38 | "doc_ids": doc_ids 39 | }, stream=True) 40 | for chunk in r.iter_content(None): 41 | data = json.loads(chunk[6:]) 42 | assert isinstance(data, dict) 43 | assert data["code"] == 200 44 | print(data) 45 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/calculate.md: -------------------------------------------------------------------------------- 1 | ## ClassDef CalculatorInput 2 | **CalculatorInput**: CalculatorInput类的功能是定义计算器输入的数据结构。 3 | 4 | **属性**: 5 | - `query`: 表示计算器查询的字符串,是一个必填字段。 6 | 7 | **代码描述**: 8 | CalculatorInput类继承自BaseModel,这表明它是使用Pydantic库创建的,用于数据验证和设置。在这个类中,定义了一个属性`query`,它是一个字符串类型的字段。通过使用`Field()`函数,我们可以为这个字段添加额外的验证或描述信息,虽然在当前的代码示例中没有显示出来。这个类的主要作用是作为计算器服务的输入数据模型,确保传入的查询是有效且符合预期格式的字符串。 9 | 10 | 从项目结构来看,CalculatorInput类位于`server/agent/tools/calculate.py`文件中,但是在提供的项目信息中,并没有直接的代码示例显示这个类是如何被其他对象调用的。然而,基于它的定义和位置,我们可以推断CalculatorInput类可能被用于处理来自于`server/agent/tools`目录下其他模块的计算请求。例如,它可能被用于验证和解析用户输入,然后这些输入将被传递给实际执行计算的逻辑。 11 | 12 | **注意**: 13 | - 使用CalculatorInput类时,需要确保传入的`query`字段是一个有效的字符串,因为这是进行计算前的必要条件。 14 | - 由于CalculatorInput使用了Pydantic库,开发者需要熟悉Pydantic的基本使用方法,以便正确地定义和使用数据模型。 15 | - 虽然当前的CalculatorInput类定义相对简单,但开发者可以根据实际需求,通过添加更多的字段或使用Pydantic提供的更高级的验证功能来扩展它。 16 | ## FunctionDef calculate(query) 17 | **calculate**: 此函数的功能是执行数学计算查询。 18 | 19 | **参数**: 20 | - `query`: 字符串类型,表示需要进行计算的数学查询语句。 21 | 22 | **代码描述**: 23 | `calculate` 函数是一个用于执行数学计算的函数。它首先从`model_container`中获取一个模型实例,该模型被假定为已经加载并准备好处理数学计算查询。接着,使用`LLMMathChain.from_llm`方法创建一个`LLMMathChain`实例,这个实例能够利用提供的模型(`model`)来处理数学计算。在创建`LLMMathChain`实例时,会传入模型和一个标志`verbose=True`以及一个提示`PROMPT`,这表明在执行计算时会有更详细的输出信息。最后,通过调用`LLMMathChain`实例的`run`方法,传入用户的查询(`query`),执行实际的计算,并将计算结果返回。 24 | 25 | 在项目中,尽管`server/agent/tools/__init__.py`和`server/agent/tools_select.py`这两个对象的代码和文档未提供详细信息,但可以推断`calculate`函数可能被设计为一个核心的数学计算工具,供项目中的其他部分调用以执行具体的数学计算任务。这种设计使得数学计算功能模块化,便于在不同的上下文中重用和维护。 26 | 27 | **注意**: 28 | - 确保在调用此函数之前,`model_container.MODEL`已正确加载并初始化,因为这是执行计算的关键。 29 | - 由于函数使用了`verbose=True`,调用时会产生详细的日志输出,这对于调试和分析计算过程很有帮助,但在生产环境中可能需要根据实际情况调整。 30 | 31 | **输出示例**: 32 | 假设传入的`query`为"2 + 2",函数可能返回一个类似于`"4"`的字符串,表示计算结果。实际返回值将依赖于模型的具体实现和处理能力。 33 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/wolfram.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef wolfram(query) 2 | **wolfram**: 此函数的功能是执行对Wolfram Alpha API的查询并返回结果。 3 | 4 | **参数**: 5 | - `query`: 字符串类型,表示要查询的内容。 6 | 7 | **代码描述**: 8 | `wolfram`函数首先创建了一个`WolframAlphaAPIWrapper`对象,该对象是对Wolfram Alpha API的一个封装。在创建这个对象时,需要提供一个`wolfram_alpha_appid`,这是调用Wolfram Alpha服务所需的应用程序ID。随后,函数使用`run`方法执行传入的查询`query`。最后,函数返回查询结果。 9 | 10 | 在项目中,`wolfram`函数作为`server/agent/tools/wolfram.py`模块的一部分,虽然其被调用的具体情况在提供的文档中没有直接说明,但可以推断这个函数可能被设计为一个工具函数,供项目中其他部分调用以获取Wolfram Alpha查询的结果。这样的设计使得项目中的其他模块可以轻松地利用Wolfram Alpha提供的强大计算和知识查询功能,而无需关心API调用的具体细节。 11 | 12 | **注意**: 13 | - 使用此函数前,需要确保已经获得了有效的Wolfram Alpha应用程序ID(`wolfram_alpha_appid`),并且该ID已经正确配置在创建`WolframAlphaAPIWrapper`对象时。 14 | - 查询结果的具体格式和内容将依赖于Wolfram Alpha API的返回值,可能包括文本、图像或其他数据类型。 15 | 16 | **输出示例**: 17 | 假设对Wolfram Alpha进行了一个查询“2+2”,函数可能返回如下的结果: 18 | ``` 19 | 4 20 | ``` 21 | 这只是一个简化的示例,实际返回的结果可能包含更多的信息和数据类型,具体取决于查询的内容和Wolfram Alpha API的响应。 22 | ## ClassDef WolframInput 23 | **WolframInput**: WolframInput类的功能是封装了用于Wolfram语言计算的输入数据。 24 | 25 | **属性**: 26 | - location: 表示需要进行计算的具体问题的字符串。 27 | 28 | **代码描述**: 29 | WolframInput类继承自BaseModel,这表明它是一个用于数据验证和序列化的模型类。在这个类中,定义了一个名为`location`的属性,该属性用于存储一个字符串,这个字符串代表了需要使用Wolfram语言进行计算的具体问题。通过使用Pydantic库中的`Field`函数,为`location`属性提供了一个描述,增强了代码的可读性和易用性。 30 | 31 | 在项目的结构中,WolframInput类位于`server/agent/tools/wolfram.py`文件中,这意味着它是服务端代理工具中的一部分,专门用于处理与Wolfram语言相关的输入数据。尽管在提供的信息中,`server/agent/tools/__init__.py`和`server/agent/tools_select.py`两个文件中没有直接提到WolframInput类的调用情况,但可以推测,WolframInput类可能会被这些模块或其他相关模块中的代码所使用,以便于处理和传递需要用Wolfram语言解决的问题。 32 | 33 | **注意**: 34 | - 在使用WolframInput类时,需要确保`location`属性中的问题描述是准确和有效的,因为这将直接影响到Wolfram语言计算的结果。 35 | - 由于WolframInput类继承自BaseModel,因此可以利用Pydantic提供的数据验证功能来确保输入数据的有效性。在实际应用中,可以根据需要对WolframInput类进行扩展,增加更多的属性和验证逻辑,以满足不同的计算需求。 36 | -------------------------------------------------------------------------------- /server/minx_chat_openai.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | Tuple 5 | ) 6 | import sys 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | if TYPE_CHECKING: 12 | import tiktoken 13 | 14 | 15 | class MinxChatOpenAI: 16 | 17 | @staticmethod 18 | def import_tiktoken() -> Any: 19 | try: 20 | import tiktoken 21 | except ImportError: 22 | raise ValueError( 23 | "Could not import tiktoken python package. " 24 | "This is needed in order to calculate get_token_ids. " 25 | "Please install it with `pip install tiktoken`." 26 | ) 27 | return tiktoken 28 | 29 | @staticmethod 30 | def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]: 31 | tiktoken_ = MinxChatOpenAI.import_tiktoken() 32 | if self.tiktoken_model_name is not None: 33 | model = self.tiktoken_model_name 34 | else: 35 | model = self.model_name 36 | if model == "gpt-3.5-turbo": 37 | # gpt-3.5-turbo may change over time. 38 | # Returning num tokens assuming gpt-3.5-turbo-0301. 39 | model = "gpt-3.5-turbo-0301" 40 | elif model == "gpt-4": 41 | # gpt-4 may change over time. 42 | # Returning num tokens assuming gpt-4-0314. 43 | model = "gpt-4-0314" 44 | # Returns the number of tokens used by a list of messages. 45 | try: 46 | encoding = tiktoken_.encoding_for_model(model) 47 | except Exception as e: 48 | logger.warning("Warning: model not found. Using cl100k_base encoding.") 49 | model = "cl100k_base" 50 | encoding = tiktoken_.get_encoding(model) 51 | return model, encoding 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.2 2 | torchvision==0.16.2 3 | torchaudio==2.1.2 4 | xformers==0.0.23.post1 5 | transformers==4.37.2 6 | sentence_transformers==2.2.2 7 | langchain==0.0.354 8 | langchain-community==0.0.19 9 | langchain-experimental==0.0.47 10 | pydantic==1.10.13 11 | fschat==0.2.35 12 | openai==1.9.0 13 | fastapi==0.109.0 14 | sse_starlette==1.8.2 15 | nltk==3.8.1 16 | uvicorn>=0.27.0.post1 17 | starlette==0.35.0 18 | unstructured[all-docs] # ==0.11.8 19 | python-magic-bin; sys_platform == 'win32' 20 | SQLAlchemy==2.0.25 21 | faiss-cpu==1.7.4 22 | accelerate==0.24.1 23 | spacy==3.7.2 24 | PyMuPDF==1.23.16 25 | rapidocr_onnxruntime==1.3.8 26 | requests==2.31.0 27 | pathlib==1.0.1 28 | pytest==7.4.3 29 | numexpr==2.8.6 30 | strsimpy==0.2.1 31 | markdownify==0.11.6 32 | tiktoken==0.5.2 33 | tqdm==4.66.1 34 | websockets==12.0 35 | numpy==1.24.4 36 | pandas==2.0.3 37 | einops==0.7.0 38 | transformers_stream_generator==0.0.4 39 | vllm==0.2.7; sys_platform == "linux" 40 | llama-index==0.9.35 41 | 42 | # jq==1.6.0 43 | # beautifulsoup4==4.12.2 44 | # pysrt==1.1.2 45 | # dashscope==1.13.6 # qwen 46 | # volcengine==1.0.134 # fangzhou 47 | # uncomment libs if you want to use corresponding vector store 48 | # pymilvus==2.3.6 49 | # psycopg2==2.9.9 50 | # pgvector>=0.2.4 51 | # chromadb==0.4.13 52 | 53 | #flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat 54 | #autoawq==0.1.8 # For Int4 55 | #rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files 56 | 57 | arxiv==2.1.0 58 | youtube-search==2.1.2 59 | duckduckgo-search==3.9.9 60 | metaphor-python==0.1.23 61 | 62 | streamlit==1.30.0 63 | streamlit-option-menu==0.3.12 64 | streamlit-antd-components==0.3.1 65 | streamlit-chatbox==1.1.11 66 | streamlit-modal==0.1.0 67 | streamlit-aggrid==0.3.4.post3 68 | 69 | httpx==0.26.0 70 | httpx_sse==0.4.0 71 | watchdog==3.0.0 72 | pyjwt==2.8.0 73 | -------------------------------------------------------------------------------- /knowledge_base/samples/content/llm/大模型指令对齐训练原理.md: -------------------------------------------------------------------------------- 1 | # 大模型指令对齐训练原理 2 | - RLHF 3 | - SFT 4 | - RM 5 | - PPO 6 | - AIHF-based 7 | - RLAIF 8 | - 核心在于通过AI 模型监督其他 AI 模型,即在SFT阶段,从初始模型中采样,然后生成自我批评和修正,然后根据修正后的反应微调原始模型。在 RL 阶段,从微调模型中采样,使用一个模型来评估生成的样本,并从这个 AI 偏好数据集训练一个偏好模型。然后使用偏好模型作为奖励信号对 RL 进行训练 9 | - ![图片](./img/大模型指令对齐训练原理-幕布图片-17565-176537.jpg) 10 | - ![图片](./img/大模型指令对齐训练原理-幕布图片-95996-523276.jpg) 11 | - ![图片](./img/大模型指令对齐训练原理-幕布图片-349153-657791.jpg) 12 | - RRHF 13 | - RRHF( **R** ank **R** esponse from **H** uman **F** eedback) 不需要强化学习,可以利用不同语言模型生成的回复,包括 ChatGPT、GPT-4 或当前的训练模型。RRHF通过对回复进行评分,并通过排名损失来使回复与人类偏好对齐。RRHF 通过通过排名损失使评分与人类的偏好(或者代理的奖励模型)对齐。RRHF 训练好的模型可以同时作为生成语言模型和奖励模型使用。 14 | - ![图片](./img/大模型指令对齐训练原理-幕布图片-805089-731888.jpg) 15 | - SFT-only 16 | - LIMA 17 | - LIMA(Less Is More for Alignment) 即浅层对齐假说,即一 **个模型的知识和能力几乎完全是在预训练中学习的,而对齐则是教会它与用户交互时如何选择子分布** 。如果假说正确,对齐主要有关于学习方式,那么该假说的一个推论是,人们可以用相当少的样本充分调整预训练的语言模型。因此, **该工作假设,对齐可以是一个简单的过程,模型学习与用户互动的风格或格式,以揭示在预训练中已经获得的知识和能力。** 18 | - LTD Instruction Tuning 19 | - ![图片](./img/大模型指令对齐训练原理-幕布图片-759487-923925.jpg) 20 | - Reward-only 21 | - DPO 22 | - DPO(Direct Preference Optimization) 提出了一种使用二进制交叉熵目标来精确优化LLM的方法,以替代基于 RL HF 的优化目标,从而大大简化偏好学习 pipeline。也就是说,完全可以直接优化语言模型以实现人类的偏好,而不需要明确的奖励模型或强化学习。 23 | - DPO 也依赖于理论上的偏好模型(如 Bradley-Terry 模型),以此衡量给定的奖励函数与经验偏好数据的吻合程度。然而,现有的方法使用偏好模型定义偏好损失来训练奖励模型,然后训练优化所学奖励模型的策略,而 DPO 使用变量的变化来直接定义偏好损失作为策略的一个函数。鉴于人类对模型响应的偏好数据集,DPO 因此可以使用一个简单的二进制交叉熵目标来优化策略,而不需要明确地学习奖励函数或在训练期间从策略中采样。 24 | - RAFT 25 | - ![图片](./img/大模型指令对齐训练原理-幕布图片-350029-666381.jpg) 26 | - 参考文献 27 | - [反思RLHF]("https://mp.weixin.qq.com/s/e3E_XsZTiNMNYqzzi6Pbjw") 28 | - [RLHF笔记]("https://mathpretty.com/16017.html") 29 | - [hf-blog]("https://huggingface.co/blog/zh/rlhf") 30 | - ** [RLHF代码详解]("https://zhuanlan.zhihu.com/p/624589622") -------------------------------------------------------------------------------- /tests/custom_splitter/test_different_splitter.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers import AutoTokenizer 4 | import sys 5 | 6 | sys.path.append("../..") 7 | from configs import ( 8 | CHUNK_SIZE, 9 | OVERLAP_SIZE 10 | ) 11 | 12 | from server.knowledge_base.utils import make_text_splitter 13 | 14 | def text(splitter_name): 15 | from langchain import document_loaders 16 | 17 | # 使用DocumentLoader读取文件 18 | filepath = "../../knowledge_base/samples/content/test.txt" 19 | loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True) 20 | docs = loader.load() 21 | text_splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE) 22 | if splitter_name == "MarkdownHeaderTextSplitter": 23 | docs = text_splitter.split_text(docs[0].page_content) 24 | for doc in docs: 25 | if doc.metadata: 26 | doc.metadata["source"] = os.path.basename(filepath) 27 | else: 28 | docs = text_splitter.split_documents(docs) 29 | for doc in docs: 30 | print(doc) 31 | return docs 32 | 33 | 34 | 35 | 36 | import pytest 37 | from langchain.docstore.document import Document 38 | 39 | @pytest.mark.parametrize("splitter_name", 40 | [ 41 | "ChineseRecursiveTextSplitter", 42 | "SpacyTextSplitter", 43 | "RecursiveCharacterTextSplitter", 44 | "MarkdownHeaderTextSplitter" 45 | ]) 46 | def test_different_splitter(splitter_name): 47 | try: 48 | docs = text(splitter_name) 49 | assert isinstance(docs, list) 50 | if len(docs)>0: 51 | assert isinstance(docs[0], Document) 52 | except Exception as e: 53 | pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}") 54 | -------------------------------------------------------------------------------- /server/agent/tools_select.py: -------------------------------------------------------------------------------- 1 | from langchain.tools import Tool 2 | from server.agent.tools import * 3 | 4 | tools = [ 5 | Tool.from_function( 6 | func=calculate, 7 | name="calculate", 8 | description="Useful for when you need to answer questions about simple calculations", 9 | args_schema=CalculatorInput, 10 | ), 11 | Tool.from_function( 12 | func=arxiv, 13 | name="arxiv", 14 | description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.", 15 | args_schema=ArxivInput, 16 | ), 17 | Tool.from_function( 18 | func=weathercheck, 19 | name="weather_check", 20 | description="", 21 | args_schema=WeatherInput, 22 | ), 23 | Tool.from_function( 24 | func=shell, 25 | name="shell", 26 | description="Use Shell to execute Linux commands", 27 | args_schema=ShellInput, 28 | ), 29 | Tool.from_function( 30 | func=search_knowledgebase_complex, 31 | name="search_knowledgebase_complex", 32 | description="Use Use this tool to search local knowledgebase and get information", 33 | args_schema=KnowledgeSearchInput, 34 | ), 35 | Tool.from_function( 36 | func=search_internet, 37 | name="search_internet", 38 | description="Use this tool to use bing search engine to search the internet", 39 | args_schema=SearchInternetInput, 40 | ), 41 | Tool.from_function( 42 | func=wolfram, 43 | name="Wolfram", 44 | description="Useful for when you need to calculate difficult formulas", 45 | args_schema=WolframInput, 46 | ), 47 | Tool.from_function( 48 | func=search_youtube, 49 | name="search_youtube", 50 | description="use this tools to search youtube videos", 51 | args_schema=YoutubeInput, 52 | ), 53 | ] 54 | 55 | tool_names = [tool.name for tool in tools] 56 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/search_youtube.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef search_youtube(query) 2 | **search_youtube**: 此函数的功能是根据提供的查询字符串搜索YouTube视频。 3 | 4 | **参数**: 5 | - query: 字符串类型,表示要在YouTube上搜索的查询字符串。 6 | 7 | **代码描述**: 8 | `search_youtube`函数接受一个名为`query`的参数,这个参数是一个字符串,代表用户希望在YouTube上进行搜索的关键词或短语。函数内部首先创建了一个`YouTubeSearchTool`的实例,命名为`tool`。然后,它调用了`tool`的`run`方法,将`query`作为输入参数传递给这个方法。最终,函数返回`run`方法的执行结果。 9 | 10 | 此函数是项目中用于与YouTube API交互的一部分,特别是在`server/agent/tools`路径下。它被设计为一个轻量级的接口,允许其他项目部分,如`server/agent/tools_select.py`,通过简单地调用此函数并传递相应的查询字符串,来实现YouTube搜索功能的集成。这种设计使得在不同的项目部分之间共享功能变得简单,同时也保持了代码的模块化和可维护性。 11 | 12 | **注意**: 13 | - 确保在调用此函数之前,`YouTubeSearchTool`类已经正确实现,并且其`run`方法能够接受一个字符串类型的输入参数并返回搜索结果。 14 | - 此函数的性能和返回结果直接依赖于`YouTubeSearchTool`类的实现细节以及YouTube API的响应。 15 | 16 | **输出示例**: 17 | 假设`YouTubeSearchTool`的`run`方法返回的是一个包含搜索结果视频标题和URL的列表,那么`search_youtube`函数的一个可能的返回值示例为: 18 | ```python 19 | [ 20 | {"title": "如何使用Python搜索YouTube", "url": "https://www.youtube.com/watch?v=example1"}, 21 | {"title": "Python YouTube API教程", "url": "https://www.youtube.com/watch?v=example2"} 22 | ] 23 | ``` 24 | 这个返回值展示了一个包含两个搜索结果的列表,每个结果都是一个字典,包含视频的标题和URL。 25 | ## ClassDef YoutubeInput 26 | **YoutubeInput**: YoutubeInput类的功能是定义用于YouTube视频搜索的输入参数模型。 27 | 28 | **属性**: 29 | - location: 用于视频搜索的查询字符串。 30 | 31 | **代码描述**: 32 | YoutubeInput类继承自BaseModel,这表明它是一个模型类,用于定义数据结构。在这个类中,定义了一个名为`location`的属性,该属性用于存储用户进行YouTube视频搜索时输入的查询字符串。通过使用`Field`函数,为`location`属性提供了一个描述,即"Query for Videos search",这有助于理解该属性的用途。 33 | 34 | 在项目中,虽然`server/agent/tools/__init__.py`和`server/agent/tools_select.py`两个文件中并没有直接提到`YoutubeInput`类的使用,但可以推断,`YoutubeInput`类作为一个数据模型,可能会在处理YouTube视频搜索请求的过程中被用到。具体来说,它可能被用于解析和验证用户的搜索请求参数,确保传递给YouTube API的查询字符串是有效和格式正确的。 35 | 36 | **注意**: 37 | - 在使用`YoutubeInput`类时,需要确保传递给`location`属性的值是一个有效的字符串,因为这将直接影响到YouTube视频搜索的结果。 38 | - 由于`YoutubeInput`类继承自`BaseModel`,可以利用Pydantic库提供的数据验证和序列化功能,以简化数据处理流程。 39 | - 虽然当前文档中没有提到`YoutubeInput`类在项目中的具体调用情况,开发者在实际使用时应考虑如何将此类集成到视频搜索功能中,以及如何处理可能出现的数据验证错误。 40 | -------------------------------------------------------------------------------- /server/db/models/knowledge_file_model.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func 2 | 3 | from server.db.base import Base 4 | 5 | 6 | class KnowledgeFileModel(Base): 7 | """ 8 | 知识文件模型 9 | """ 10 | __tablename__ = 'knowledge_file' 11 | id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID') 12 | file_name = Column(String(255), comment='文件名') 13 | file_ext = Column(String(10), comment='文件扩展名') 14 | kb_name = Column(String(50), comment='所属知识库名称') 15 | document_loader_name = Column(String(50), comment='文档加载器名称') 16 | text_splitter_name = Column(String(50), comment='文本分割器名称') 17 | file_version = Column(Integer, default=1, comment='文件版本') 18 | file_mtime = Column(Float, default=0.0, comment="文件修改时间") 19 | file_size = Column(Integer, default=0, comment="文件大小") 20 | custom_docs = Column(Boolean, default=False, comment="是否自定义docs") 21 | docs_count = Column(Integer, default=0, comment="切分文档数量") 22 | create_time = Column(DateTime, default=func.now(), comment='创建时间') 23 | 24 | def __repr__(self): 25 | return f"" 26 | 27 | 28 | class FileDocModel(Base): 29 | """ 30 | 文件-向量库文档模型 31 | """ 32 | __tablename__ = 'file_doc' 33 | id = Column(Integer, primary_key=True, autoincrement=True, comment='ID') 34 | kb_name = Column(String(50), comment='知识库名称') 35 | file_name = Column(String(255), comment='文件名称') 36 | doc_id = Column(String(50), comment="向量库文档ID") 37 | meta_data = Column(JSON, default={}) 38 | 39 | def __repr__(self): 40 | return f"" 41 | -------------------------------------------------------------------------------- /server/api_allinone_stale.py: -------------------------------------------------------------------------------- 1 | """Usage 2 | 调用默认模型: 3 | python server/api_allinone.py 4 | 5 | 加载多个非默认模型: 6 | python server/api_allinone.py --model-path-address model1@host1@port1 model2@host2@port2 7 | 8 | 多卡启动: 9 | python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB 10 | 11 | """ 12 | import sys 13 | import os 14 | 15 | sys.path.append(os.path.dirname(__file__)) 16 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 17 | 18 | from llm_api_stale import launch_all, parser, controller_args, worker_args, server_args 19 | from api import create_app 20 | import uvicorn 21 | 22 | parser.add_argument("--api-host", type=str, default="0.0.0.0") 23 | parser.add_argument("--api-port", type=int, default=7861) 24 | parser.add_argument("--ssl_keyfile", type=str) 25 | parser.add_argument("--ssl_certfile", type=str) 26 | 27 | api_args = ["api-host", "api-port", "ssl_keyfile", "ssl_certfile"] 28 | 29 | 30 | def run_api(host, port, **kwargs): 31 | app = create_app() 32 | if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): 33 | uvicorn.run(app, 34 | host=host, 35 | port=port, 36 | ssl_keyfile=kwargs.get("ssl_keyfile"), 37 | ssl_certfile=kwargs.get("ssl_certfile"), 38 | ) 39 | else: 40 | uvicorn.run(app, host=host, port=port) 41 | 42 | 43 | if __name__ == "__main__": 44 | print("Luanching api_allinone,it would take a while, please be patient...") 45 | print("正在启动api_allinone,LLM服务启动约3-10分钟,请耐心等待...") 46 | # 初始化消息 47 | args = parser.parse_args() 48 | args_dict = vars(args) 49 | launch_all(args=args, controller_args=controller_args, worker_args=worker_args, server_args=server_args) 50 | run_api( 51 | host=args.api_host, 52 | port=args.api_port, 53 | ssl_keyfile=args.ssl_keyfile, 54 | ssl_certfile=args.ssl_certfile, 55 | ) 56 | print("Luanching api_allinone done.") 57 | print("api_allinone启动完毕.") 58 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from webui_pages.utils import * 3 | from streamlit_option_menu import option_menu 4 | from webui_pages.dialogue.dialogue import dialogue_page, chat_box 5 | from webui_pages.knowledge_base.knowledge_base import knowledge_base_page 6 | import os 7 | import sys 8 | from configs import VERSION 9 | from server.utils import api_address 10 | 11 | 12 | api = ApiRequest(base_url=api_address()) 13 | 14 | if __name__ == "__main__": 15 | is_lite = "lite" in sys.argv 16 | 17 | st.set_page_config( 18 | "Langchain-Chatchat WebUI", 19 | os.path.join("img", "chatchat_icon_blue_square_v2.png"), 20 | initial_sidebar_state="expanded", 21 | menu_items={ 22 | 'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat', 23 | 'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues", 24 | 'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!""" 25 | } 26 | ) 27 | 28 | pages = { 29 | "对话": { 30 | "icon": "chat", 31 | "func": dialogue_page, 32 | }, 33 | "知识库管理": { 34 | "icon": "hdd-stack", 35 | "func": knowledge_base_page, 36 | }, 37 | } 38 | 39 | with st.sidebar: 40 | st.image( 41 | os.path.join( 42 | "img", 43 | "logo-long-chatchat-trans-v2.png" 44 | ), 45 | use_column_width=True 46 | ) 47 | st.caption( 48 | f"""

当前版本:{VERSION}

""", 49 | unsafe_allow_html=True, 50 | ) 51 | options = list(pages) 52 | icons = [x["icon"] for x in pages.values()] 53 | 54 | default_index = 0 55 | selected_page = option_menu( 56 | "", 57 | options=options, 58 | icons=icons, 59 | # menu_icon="chat-quote", 60 | default_index=default_index, 61 | ) 62 | 63 | if selected_page in pages: 64 | pages[selected_page]["func"](api=api, is_lite=is_lite) 65 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/search_knowledgebase_simple.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef search_knowledge_base_iter(database, query) 2 | **search_knowledge_base_iter**: 该函数用于异步迭代地搜索知识库并获取回答。 3 | 4 | **参数**: 5 | - `database`: 知识库的名称,类型为字符串。 6 | - `query`: 用户的查询内容,类型为字符串。 7 | 8 | **代码描述**: 9 | `search_knowledge_base_iter` 函数是一个异步函数,它接受两个参数:`database` 和 `query`。这个函数首先调用 `knowledge_base_chat` 函数,传入相应的参数,包括知识库名称、查询内容、模型名称、温度参数、历史记录、向量搜索的 top_k 值、最大 token 数量、prompt 名称、分数阈值以及是否流式输出的标志。`knowledge_base_chat` 函数负责处理用户的查询请求,并与知识库进行交互,返回一个响应对象。 10 | 11 | 函数内部通过异步迭代 `response.body_iterator` 来处理响应体中的数据。每次迭代得到的 `data` 是一个 JSON 字符串,表示一部分的查询结果。函数使用 `json.loads` 方法将 JSON 字符串解析为字典对象,然后从中提取出答案和相关文档信息。最终,函数返回最后一次迭代得到的答案内容。 12 | 13 | **注意**: 14 | - 由于 `search_knowledge_base_iter` 是一个异步函数,因此在调用时需要使用 `await` 关键字。 15 | - 函数返回的是最后一次迭代得到的答案内容,如果需要处理每一次迭代得到的数据,需要在迭代过程中添加相应的处理逻辑。 16 | - 确保传入的知识库名称在系统中已经存在,否则可能无法正确处理查询请求。 17 | 18 | **输出示例**: 19 | 调用 `search_knowledge_base_iter` 函数可能返回的示例: 20 | ```json 21 | "这是根据您的查询生成的回答。" 22 | ``` 23 | 此输出示例仅表示函数可能返回的答案内容的格式,实际返回的内容将根据查询内容和知识库中的数据而有所不同。 24 | ## FunctionDef search_knowledgebase_simple(query) 25 | **search_knowledgebase_simple**: 此函数用于简化地搜索知识库并获取回答。 26 | 27 | **参数**: 28 | - `query`: 用户的查询内容,类型为字符串。 29 | 30 | **代码描述**: 31 | `search_knowledgebase_simple` 函数是一个简化的接口,用于对知识库进行搜索。它接受一个参数 `query`,即用户的查询内容。函数内部通过调用 `search_knowledge_base_iter` 函数来实现对知识库的搜索。`search_knowledge_base_iter` 是一个异步函数,负责异步迭代地搜索知识库并获取回答。`search_knowledgebase_simple` 函数通过使用 `asyncio.run` 方法来运行异步的 `search_knowledge_base_iter` 函数,从而实现同步调用的效果。 32 | 33 | 由于 `search_knowledge_base_iter` 函数需要数据库名称和查询内容作为参数,但在 `search_knowledgebase_simple` 函数中只提供了查询内容 `query`,这意味着在 `search_knowledge_base_iter` 函数的实现中,数据库名称可能是预设的或通过其他方式获取。 34 | 35 | **注意**: 36 | - `search_knowledgebase_simple` 函数提供了一个简化的接口,使得开发者可以不必直接处理异步编程的复杂性,而是通过一个简单的同步函数调用来搜索知识库。 37 | - 由于内部调用了异步函数 `search_knowledge_base_iter`,确保在使用此函数时,相关的异步环境和配置已正确设置。 38 | - 考虑到 `search_knowledge_base_iter` 函数的异步特性和可能的迭代处理,调用 `search_knowledgebase_simple` 函数时应注意可能的延迟或异步执行的影响。 39 | 40 | **输出示例**: 41 | 调用 `search_knowledgebase_simple` 函数可能返回的示例: 42 | ``` 43 | "这是根据您的查询生成的回答。" 44 | ``` 45 | 此输出示例表示函数可能返回的答案内容的格式,实际返回的内容将根据查询内容和知识库中的数据而有所不同。 46 | -------------------------------------------------------------------------------- /tests/api/test_stream_chat_api_thread.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import sys 4 | from pathlib import Path 5 | 6 | sys.path.append(str(Path(__file__).parent.parent.parent)) 7 | from configs import BING_SUBSCRIPTION_KEY 8 | from server.utils import api_address 9 | 10 | from pprint import pprint 11 | from concurrent.futures import ThreadPoolExecutor, as_completed 12 | import time 13 | 14 | 15 | api_base_url = api_address() 16 | 17 | 18 | def dump_input(d, title): 19 | print("\n") 20 | print("=" * 30 + title + " input " + "="*30) 21 | pprint(d) 22 | 23 | 24 | def dump_output(r, title): 25 | print("\n") 26 | print("=" * 30 + title + " output" + "="*30) 27 | for line in r.iter_content(None, decode_unicode=True): 28 | print(line, end="", flush=True) 29 | 30 | 31 | headers = { 32 | 'accept': 'application/json', 33 | 'Content-Type': 'application/json', 34 | } 35 | 36 | 37 | def knowledge_chat(api="/chat/knowledge_base_chat"): 38 | url = f"{api_base_url}{api}" 39 | data = { 40 | "query": "如何提问以获得高质量答案", 41 | "knowledge_base_name": "samples", 42 | "history": [ 43 | { 44 | "role": "user", 45 | "content": "你好" 46 | }, 47 | { 48 | "role": "assistant", 49 | "content": "你好,我是 ChatGLM" 50 | } 51 | ], 52 | "stream": True 53 | } 54 | result = [] 55 | response = requests.post(url, headers=headers, json=data, stream=True) 56 | 57 | for line in response.iter_content(None, decode_unicode=True): 58 | data = json.loads(line[6:]) 59 | result.append(data) 60 | 61 | return result 62 | 63 | 64 | def test_thread(): 65 | threads = [] 66 | times = [] 67 | pool = ThreadPoolExecutor() 68 | start = time.time() 69 | for i in range(10): 70 | t = pool.submit(knowledge_chat) 71 | threads.append(t) 72 | 73 | for r in as_completed(threads): 74 | end = time.time() 75 | times.append(end - start) 76 | print("\nResult:\n") 77 | pprint(r.result()) 78 | 79 | print("\nTime used:\n") 80 | for x in times: 81 | print(f"{x}") 82 | -------------------------------------------------------------------------------- /markdown_docs/server/db/models/knowledge_metadata_model.md: -------------------------------------------------------------------------------- 1 | ## ClassDef SummaryChunkModel 2 | **SummaryChunkModel**: SummaryChunkModel 类的功能是用于存储和管理文档中每个文档标识符(doc_id)的摘要信息。 3 | 4 | **属性**: 5 | - `id`: 唯一标识符,用于标识每个摘要信息的ID。 6 | - `kb_name`: 知识库名称,表示该摘要信息属于哪个知识库。 7 | - `summary_context`: 总结文本,存储自动生成或用户输入的文档摘要。 8 | - `summary_id`: 总结矢量id,用于后续的矢量库构建和语义关联。 9 | - `doc_ids`: 向量库id关联列表,存储与该摘要相关的文档标识符列表。 10 | - `meta_data`: 元数据,以JSON格式存储额外的信息,如页码信息等。 11 | 12 | **代码描述**: 13 | SummaryChunkModel 类定义了一个用于存储文档摘要信息的数据模型。该模型包括了文档的基本信息如知识库名称、摘要文本、摘要矢量ID、相关文档ID列表以及额外的元数据。这些信息主要来源于用户上传文件时的描述或程序自动切分文档生成的摘要。此外,该模型还支持后续的矢量库构建和语义关联任务,通过对summary_context创建索引和计算语义相似度来实现。 14 | 15 | 在项目中,SummaryChunkModel 被 knowledge_metadata_repository.py 文件中的多个函数调用,包括添加、删除、列出和统计知识库中的摘要信息。这些函数通过操作 SummaryChunkModel 实例来实现对数据库中摘要信息的管理,如添加新的摘要信息、删除特定知识库的摘要信息、根据知识库名称列出摘要信息以及统计特定知识库的摘要数量。 16 | 17 | **注意**: 18 | - 在使用 SummaryChunkModel 进行数据库操作时,需要确保传入的参数类型和格式正确,特别是 `meta_data` 字段,它应以正确的JSON格式存储。 19 | - 在进行矢量库构建和语义关联任务时,应注意 `summary_id` 和 `doc_ids` 字段的正确使用和关联。 20 | 21 | **输出示例**: 22 | 假设数据库中有一个摘要信息实例,其可能的表示如下: 23 | ``` 24 | 25 | ``` 26 | 这表示一个ID为1的摘要信息,属于“技术文档”知识库,摘要文本为“这是一个关于AI技术的摘要”,关联的文档标识符为doc1和doc2,没有额外的元数据信息。 27 | ### FunctionDef __repr__(self) 28 | **__repr__**: 此函数的功能是生成并返回一个代表对象状态的字符串。 29 | 30 | **参数**: 此函数不接受除`self`之外的任何参数。 31 | 32 | **代码描述**: `__repr__`函数是`SummaryChunkModel`类的一个特殊方法,用于创建一个代表该对象实例状态的字符串。这个字符串包含了`SummaryChunkModel`实例的几个关键属性:`id`、`kb_name`、`summary_context`、`doc_ids`以及`metadata`。这些属性通过访问实例的相应属性并将它们格式化为一个特定格式的字符串来展示。这个字符串格式遵循``的形式,其中每个`...`会被实例相应属性的实际值替换。这种表示方式便于开发者在调试过程中快速识别对象的状态。 33 | 34 | **注意**: `__repr__`方法通常用于调试和日志记录,它应该返回一个明确且易于理解的对象状态描述。返回的字符串应该尽可能地反映出对象的关键属性。此外,虽然`__repr__`的主要目的不是被终端用户直接看到,但它的设计应确保在需要时能够提供足够的信息来识别对象的具体状态。 35 | 36 | **输出示例**: 假设有一个`SummaryChunkModel`实例,其`id`为`123`,`kb_name`为`"KnowledgeBase1"`,`summary_context`为`"Context1"`,`doc_ids`为`"doc1, doc2"`,`metadata`为`"{'author': 'John Doe'}"`。调用此实例的`__repr__`方法将返回以下字符串: 37 | 38 | ``` 39 | 40 | ``` 41 | *** 42 | -------------------------------------------------------------------------------- /markdown_docs/server/db/models/conversation_model.md: -------------------------------------------------------------------------------- 1 | ## ClassDef ConversationModel 2 | **ConversationModel**: ConversationModel类的功能是定义一个聊天记录模型,用于数据库中存储聊天会话的详细信息。 3 | 4 | **属性**: 5 | - `id`: 对话框ID,是每个对话框的唯一标识符,使用String类型。 6 | - `name`: 对话框名称,存储对话框的名称,使用String类型。 7 | - `chat_type`: 聊天类型,标识聊天的类型(如普通聊天、客服聊天等),使用String类型。 8 | - `create_time`: 创建时间,记录对话框创建的时间,使用DateTime类型,默认值为当前时间。 9 | 10 | **代码描述**: 11 | ConversationModel类继承自Base类,是一个ORM模型,用于映射数据库中的`conversation`表。该模型包含四个字段:`id`、`name`、`chat_type`和`create_time`,分别用于存储对话框的唯一标识符、名称、聊天类型和创建时间。其中,`id`字段被设置为主键。此外,该类还重写了`__repr__`方法,以便在打印实例时能够清晰地显示出实例的主要信息。 12 | 13 | 在项目中,ConversationModel类被用于创建和管理聊天记录的数据。例如,在`server/db/repository/conversation_repository.py`中的`add_conversation_to_db`函数中,通过创建ConversationModel的实例并将其添加到数据库会话中,实现了聊天记录的新增功能。这显示了ConversationModel类在项目中用于处理聊天记录数据的重要角色。 14 | 15 | **注意**: 16 | - 在使用ConversationModel进行数据库操作时,需要确保传入的参数类型与字段定义相匹配,避免类型不匹配的错误。 17 | - 创建ConversationModel实例时,`id`字段可以不传入,由数据库自动生成唯一标识符,但在`add_conversation_to_db`函数中,如果没有提供`conversation_id`,则会使用`uuid.uuid4().hex`生成一个。 18 | 19 | **输出示例**: 20 | 假设创建了一个ConversationModel实例,其属性值如下: 21 | - id: "1234567890abcdef" 22 | - name: "客服对话" 23 | - chat_type: "agent_chat" 24 | - create_time: "2023-04-01 12:00:00" 25 | 26 | 则该实例的`__repr__`方法输出可能如下: 27 | ``` 28 | 29 | ``` 30 | ### FunctionDef __repr__(self) 31 | **__repr__**: 该函数的功能是生成并返回一个代表会话对象的字符串。 32 | 33 | **参数**: 此函数不接受任何外部参数。 34 | 35 | **代码描述**: `__repr__` 方法是一个特殊方法,用于定义一个对象的“官方”字符串表示。在这个场景中,`__repr__` 被定义在 `ConversationModel` 类中,目的是为了提供一个清晰且易于理解的会话对象表示。当调用此方法时,它会返回一个格式化的字符串,其中包含了会话对象的几个关键属性:`id`、`name`、`chat_type` 和 `create_time`。这些属性通过 `self` 关键字访问,表示它们属于当前的会话实例。字符串使用 f-string 格式化,这是 Python 3.6 及以上版本中引入的一种字符串格式化机制,允许将表达式的值直接嵌入到字符串常量中。 36 | 37 | **注意**: 使用 `__repr__` 方法的一个重要原则是,其返回的字符串应尽可能地反映出对象的重要信息,且最好能够通过执行这个字符串(假设环境中有正确的上下文)来重新创建出该对象。虽然在许多实际情况下,直接执行 `__repr__` 返回的字符串来复制对象并不是必需的,但这一原则仍然是一个很好的指导思想。此外,当你在调试过程中打印对象或在交互式环境中查看对象时,`__repr__` 方法返回的字符串将会被显示,这有助于快速识别对象的状态。 38 | 39 | **输出示例**: 假设有一个会话对象,其 `id` 为 "123",`name` 为 "Test Conversation",`chat_type` 为 "group",`create_time` 为 "2023-04-01",则调用 `__repr__` 方法将返回如下字符串: 40 | `""` 41 | *** 42 | -------------------------------------------------------------------------------- /tests/api/test_llm_api.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import sys 4 | from pathlib import Path 5 | 6 | root_path = Path(__file__).parent.parent.parent 7 | sys.path.append(str(root_path)) 8 | from configs.server_config import FSCHAT_MODEL_WORKERS 9 | from server.utils import api_address, get_model_worker_config 10 | 11 | from pprint import pprint 12 | import random 13 | from typing import List 14 | 15 | 16 | def get_configured_models() -> List[str]: 17 | model_workers = list(FSCHAT_MODEL_WORKERS) 18 | if "default" in model_workers: 19 | model_workers.remove("default") 20 | return model_workers 21 | 22 | 23 | api_base_url = api_address() 24 | 25 | 26 | def get_running_models(api="/llm_model/list_models"): 27 | url = api_base_url + api 28 | r = requests.post(url) 29 | if r.status_code == 200: 30 | return r.json()["data"] 31 | return [] 32 | 33 | 34 | def test_running_models(api="/llm_model/list_running_models"): 35 | url = api_base_url + api 36 | r = requests.post(url) 37 | assert r.status_code == 200 38 | print("\n获取当前正在运行的模型列表:") 39 | pprint(r.json()) 40 | assert isinstance(r.json()["data"], list) 41 | assert len(r.json()["data"]) > 0 42 | 43 | 44 | # 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动 45 | # def test_stop_model(api="/llm_model/stop"): 46 | # url = api_base_url + api 47 | # r = requests.post(url, json={""}) 48 | 49 | 50 | def test_change_model(api="/llm_model/change_model"): 51 | url = api_base_url + api 52 | 53 | running_models = get_running_models() 54 | assert len(running_models) > 0 55 | 56 | model_workers = get_configured_models() 57 | 58 | availabel_new_models = list(set(model_workers) - set(running_models)) 59 | assert len(availabel_new_models) > 0 60 | print(availabel_new_models) 61 | 62 | local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")] 63 | model_name = random.choice(local_models) 64 | new_model_name = random.choice(availabel_new_models) 65 | print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}") 66 | r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name}) 67 | assert r.status_code == 200 68 | 69 | running_models = get_running_models() 70 | assert new_model_name in running_models 71 | -------------------------------------------------------------------------------- /markdown_docs/server/webui_allinone_stale.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef launch_api(args, args_list, log_name) 2 | **launch_api**: 此函数的功能是启动API服务。 3 | 4 | **参数**: 5 | - args: 包含API启动所需参数的对象,此对象应具备访问如api_host和api_port等属性的能力。 6 | - args_list: 一个字符串列表,默认值为api_args,指定了需要转换为命令行参数的键名。 7 | - log_name: 日志文件的名称。如果未提供,则根据API的主机名和端口动态生成。 8 | 9 | **代码描述**: 10 | `launch_api` 函数首先打印出启动API服务的提示信息,包括中英文两种语言。接着,如果没有提供`log_name`参数,函数会根据API服务的主机名和端口号生成日志文件的名称,并将其存储在预定义的日志路径下。然后,函数通过调用`string_args`函数,将`args`对象中的参数转换成命令行可接受的字符串格式。`string_args`函数的详细功能和使用方法已在相关文档中描述。 11 | 12 | 之后,`launch_api`函数构建了一个用于启动API服务的shell命令字符串,该字符串包含了启动脚本的名称(`api.py`)、转换后的参数字符串以及日志文件的路径。最后,使用`subprocess.run`方法执行构建的shell命令,以在后台启动API服务,并将标准输出和标准错误重定向到日志文件中。 13 | 14 | 在整个过程中,`launch_api`函数还会打印出日志文件的位置信息,以便于在API服务启动异常时,用户可以轻松地找到并查看日志文件。 15 | 16 | **在项目中的调用关系**: 17 | `launch_api` 函数在项目中负责启动API服务的核心功能。它通过调用`string_args`函数来处理命令行参数的转换,这显示了`launch_api`与`string_args`之间的直接依赖关系。`string_args`函数为`launch_api`提供了参数字符串化的能力,使得`launch_api`能够有效地构建用于启动API服务的shell命令。 18 | 19 | **注意**: 20 | - 确保传递给`launch_api`函数的`args`对象包含了所有必要的API启动参数,如`api_host`和`api_port`。 21 | - 如果`log_name`参数未提供,日志文件的命名将依赖于API服务的主机名和端口号,因此请确保这些信息的准确性。 22 | - 在使用`launch_api`函数时,应确保相关的API启动脚本(`api.py`)存在于预期的路径下,并且能够正确处理通过命令行传递的参数。 23 | ## FunctionDef launch_webui(args, args_list, log_name) 24 | **launch_webui**: 此函数的功能是启动webui服务。 25 | 26 | **参数**: 27 | - args: 包含启动webui所需参数的对象。此对象应具备访问各参数值的能力。 28 | - args_list: 参数列表,默认值为web_args,用于指定哪些参数需要被包含在最终生成的命令行字符串中。 29 | - log_name: 日志文件的名称。如果未提供,则默认使用LOG_PATH路径下的webui作为日志文件名。 30 | 31 | **代码描述**: 32 | `launch_webui` 函数主要负责启动webui服务。首先,函数打印出启动webui的提示信息,既包括英文也包括中文,以确保用户了解当前操作。接着,函数检查是否提供了`log_name`参数,如果没有提供,则使用默认的日志文件名。 33 | 34 | 接下来,函数调用`string_args`函数,将`args`对象中的参数转换为命令行可接受的字符串格式。这一步骤是通过检查`args`对象中的参数与`args_list`列表中指定的参数键名,生成最终的参数字符串。 35 | 36 | 根据`args`对象中的`nohup`参数值,`launch_webui`函数决定是否以后台模式启动webui服务。如果`nohup`为真,则构造一个命令行字符串,该字符串将webui服务的输出重定向到指定的日志文件,并在后台运行。否则,直接构造一个命令行字符串以前台模式运行webui服务。 37 | 38 | 最后,使用`subprocess.run`方法执行构造好的命令行字符串,启动webui服务。函数在webui服务启动后打印出完成提示信息。 39 | 40 | **在项目中的调用关系**: 41 | `launch_webui` 函数在项目中负责启动webui服务的任务。它依赖于`string_args`函数来处理命令行参数的生成。`string_args`函数根据提供的参数对象和参数列表,生成适用于命令行的参数字符串。这种设计使得`launch_webui`函数能够灵活地处理不同的启动参数,同时保持命令行参数生成逻辑的集中和一致性。 42 | 43 | **注意**: 44 | - 确保传递给`launch_webui`函数的`args`对象中包含了所有必要的参数,特别是`nohup`参数,因为它决定了webui服务是以前台模式还是后台模式运行。 45 | - 如果在后台模式下运行webui服务,务必检查指定的日志文件,以便于排查可能出现的启动异常。 46 | -------------------------------------------------------------------------------- /markdown_docs/server/db/models/message_model.md: -------------------------------------------------------------------------------- 1 | ## ClassDef MessageModel 2 | **MessageModel**: MessageModel类的功能是定义聊天记录的数据模型。 3 | 4 | **属性**: 5 | - `id`: 聊天记录的唯一标识ID。 6 | - `conversation_id`: 对话框ID,用于标识一次会话。 7 | - `chat_type`: 聊天类型,如普通聊天、客服聊天等。 8 | - `query`: 用户的提问或输入。 9 | - `response`: 系统或模型的回答。 10 | - `meta_data`: 存储额外信息的JSON字段,如知识库ID等,便于后续扩展。 11 | - `feedback_score`: 用户对聊天回答的评分,满分为100。 12 | - `feedback_reason`: 用户评分的理由。 13 | - `create_time`: 记录的创建时间。 14 | 15 | **代码描述**: 16 | MessageModel类继承自Base,用于定义聊天记录的数据结构。它包含了聊天记录的基本信息,如聊天ID、会话ID、聊天类型、用户问题、模型回答、元数据、用户反馈等。此类通过定义SQLAlchemy的Column字段来映射数据库中的`message`表结构。其中,`__tablename__`属性指定了数据库中对应的表名为`message`。每个属性都通过Column实例来定义,其中包括数据类型、是否为主键、默认值、索引创建、注释等信息。 17 | 18 | 在项目中,MessageModel类被用于server/db/repository/message_repository.py文件中的几个函数调用中,主要涉及到聊天记录的增加、查询和反馈。例如,`add_message_to_db`函数用于新增聊天记录,它创建了一个MessageModel实例并将其添加到数据库中。`get_message_by_id`函数通过聊天记录ID查询聊天记录。`feedback_message_to_db`函数用于更新聊天记录的用户反馈信息。`filter_message`函数则是根据对话框ID过滤聊天记录,并返回最近的几条记录。 19 | 20 | **注意**: 21 | - 在使用MessageModel进行数据库操作时,需要确保传入的参数类型与定义的字段类型相匹配。 22 | - 对于`meta_data`字段,虽然默认值为一个空字典,但在实际使用中可以根据需要存储任意结构的JSON数据。 23 | - 在进行数据库操作如添加、查询、更新记录时,应确保操作在正确的数据库会话(session)上下文中执行。 24 | 25 | **输出示例**: 26 | 由于MessageModel是一个数据模型类,它本身不直接产生输出。但是,当它被实例化并用于数据库操作时,例如通过`add_message_to_db`函数添加一条新的聊天记录,可能会返回如下的聊天记录ID: 27 | ``` 28 | '1234567890abcdef1234567890abcdef' 29 | ``` 30 | ### FunctionDef __repr__(self) 31 | **__repr__**: 此函数的功能是生成并返回一个代表消息对象的字符串。 32 | 33 | **参数**: 此函数没有参数。 34 | 35 | **代码描述**: `__repr__` 方法是一个特殊方法,用于定义对象的“官方”字符串表示。在这个具体的实现中,它返回一个格式化的字符串,该字符串包含了消息对象的多个属性,包括:`id`, `conversation_id`, `chat_type`, `query`, `response`, `meta_data`, `feedback_score`, `feedback_reason`, 以及 `create_time`。这些属性通过使用 `self` 关键字访问,表示它们是对象的实例变量。字符串使用了 f-string 格式化,这是 Python 3.6 及以上版本中引入的一种字符串格式化机制,允许将表达式的值直接嵌入到字符串常量中。 36 | 37 | **注意**: `__repr__` 方法的返回值应该尽可能地返回一个明确的对象表示,以便于调试和日志记录。返回的字符串应该尽量遵循 Python 对象表示的惯例,即 `` 的格式。此外,虽然这个方法主要用于调试和开发,但它也可以被用于日志记录或其他需要对象字符串表示的场景。 38 | 39 | **输出示例**: 假设有一个消息对象,其属性值如下:`id=1`, `conversation_id=2`, `chat_type='group'`, `query='天气如何'`, `response='晴朗'`, `meta_data='{}'`, `feedback_score=5`, `feedback_reason='准确'`, `create_time='2023-04-01 12:00:00'`。调用此对象的 `__repr__` 方法将返回以下字符串: 40 | 41 | ``` 42 | 43 | ``` 44 | *** 45 | -------------------------------------------------------------------------------- /server/db/repository/knowledge_base_repository.py: -------------------------------------------------------------------------------- 1 | from server.db.models.knowledge_base_model import KnowledgeBaseModel 2 | from server.db.session import with_session 3 | 4 | 5 | @with_session 6 | def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): 7 | # 创建知识库实例 8 | kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() 9 | if not kb: 10 | kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model) 11 | session.add(kb) 12 | else: # update kb with new vs_type and embed_model 13 | kb.kb_info = kb_info 14 | kb.vs_type = vs_type 15 | kb.embed_model = embed_model 16 | return True 17 | 18 | 19 | @with_session 20 | def list_kbs_from_db(session, min_file_count: int = -1): 21 | kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all() 22 | kbs = [kb[0] for kb in kbs] 23 | return kbs 24 | 25 | 26 | @with_session 27 | def kb_exists(session, kb_name): 28 | kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() 29 | status = True if kb else False 30 | return status 31 | 32 | 33 | @with_session 34 | def load_kb_from_db(session, kb_name): 35 | kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() 36 | if kb: 37 | kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model 38 | else: 39 | kb_name, vs_type, embed_model = None, None, None 40 | return kb_name, vs_type, embed_model 41 | 42 | 43 | @with_session 44 | def delete_kb_from_db(session, kb_name): 45 | kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() 46 | if kb: 47 | session.delete(kb) 48 | return True 49 | 50 | 51 | @with_session 52 | def get_kb_detail(session, kb_name: str) -> dict: 53 | kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() 54 | if kb: 55 | return { 56 | "kb_name": kb.kb_name, 57 | "kb_info": kb.kb_info, 58 | "vs_type": kb.vs_type, 59 | "embed_model": kb.embed_model, 60 | "file_count": kb.file_count, 61 | "create_time": kb.create_time, 62 | } 63 | else: 64 | return {} 65 | -------------------------------------------------------------------------------- /markdown_docs/server/chat/completion.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef completion(query, stream, echo, model_name, temperature, max_tokens, prompt_name) 2 | **completion**: 此函数用于处理用户输入,生成并返回由LLM模型补全的文本。 3 | 4 | **参数**: 5 | - `query`: 用户输入的文本。 6 | - `stream`: 是否以流式输出结果。 7 | - `echo`: 是否在输出结果中回显输入文本。 8 | - `model_name`: 使用的LLM模型名称。 9 | - `temperature`: LLM模型采样温度,用于控制生成文本的随机性。 10 | - `max_tokens`: 限制LLM模型生成的Token数量。 11 | - `prompt_name`: 使用的prompt模板名称。 12 | 13 | **代码描述**: 14 | 此函数首先定义了一个异步生成器`completion_iterator`,该生成器负责实际的文本生成逻辑。它使用`get_OpenAI`函数初始化一个LLM模型,并根据提供的参数配置模型。然后,它使用`get_prompt_template`函数获取指定的prompt模板,并将用户输入的`query`传递给LLM模型进行处理。根据`stream`参数的值,此函数可以以流式方式逐个返回生成的Token,或者等待所有Token生成完成后一次性返回结果。最后,使用`EventSourceResponse`包装`completion_iterator`生成器,以适合流式输出的HTTP响应格式返回结果。 15 | 16 | 在项目中,`completion`函数通过`/other/completion`路由在`server/api.py`文件中被注册为一个POST请求的处理器。这表明它设计用于处理来自客户端的文本补全请求,客户端可以通过发送POST请求到此路由,并在请求体中提供相应的参数,来获取LLM模型基于用户输入生成的补全文本。 17 | 18 | **注意**: 19 | - 确保`model_name`参数对应的LLM模型已正确配置且可用。 20 | - `temperature`参数应在0.0到1.0之间,以控制生成文本的随机性。 21 | - 如果`max_tokens`设为负数或0,将不会限制Token数量。 22 | 23 | **输出示例**: 24 | 如果`stream`参数为`False`,并且用户输入为"今天天气如何",一个可能的返回值示例为: 25 | ``` 26 | "今天天气晴朗,适合外出。" 27 | ``` 28 | 如果`stream`参数为`True`,则可能逐个返回上述文本中的每个字或词。 29 | ### FunctionDef completion_iterator(query, model_name, prompt_name, echo) 30 | **completion_iterator**: 此函数的功能是异步迭代生成基于给定查询的完成文本。 31 | 32 | **参数**: 33 | - `query`: 字符串类型,用户的查询输入。 34 | - `model_name`: 字符串类型,默认为LLM_MODELS列表中的第一个模型,指定使用的语言模型名称。 35 | - `prompt_name`: 字符串类型,指定使用的提示模板名称。 36 | - `echo`: 布尔类型,指示是否回显输入。 37 | 38 | **代码描述**: 39 | `completion_iterator`函数是一个异步生成器,用于根据用户的查询输入生成文本。首先,函数检查`max_tokens`参数是否为整数且小于等于0,如果是,则将其设置为None。接着,通过`get_OpenAI`函数初始化一个配置好的OpenAI模型实例,其中包括模型名称、温度、最大令牌数、回调函数列表以及是否回显输入等参数。然后,使用`get_prompt_template`函数加载指定类型和名称的提示模板,并通过`PromptTemplate.from_template`方法创建一个`PromptTemplate`实例。之后,创建一个`LLMChain`实例,将提示模板和语言模型作为参数传入。 40 | 41 | 函数接下来创建一个异步任务,使用`asyncio.create_task`方法将`chain.acall`方法的调用包装起来,并通过`wrap_done`函数与一个回调函数关联,以便在任务完成时进行通知。根据`stream`变量的值,函数将以不同的方式生成文本。如果`stream`为真,则通过`callback.aiter()`异步迭代每个生成的令牌,并使用服务器发送事件(server-sent-events)来流式传输响应。如果`stream`为假,则将所有生成的令牌累加到一个字符串中,然后一次性生成整个答案。 42 | 43 | 最后,函数等待之前创建的异步任务完成,确保所有生成的文本都已处理完毕。 44 | 45 | **注意**: 46 | - 使用此函数时,需要确保`query`参数正确无误,因为它直接影响生成文本的内容。 47 | - `model_name`和`prompt_name`参数应根据需要选择合适的模型和提示模板,以获得最佳的文本生成效果。 48 | - 在使用流式传输功能时,应考虑客户端如何处理流式数据,以确保用户体验。 49 | - 此函数依赖于`get_OpenAI`和`get_prompt_template`等函数,因此在使用前应确保相关配置和模板已正确设置。 50 | *** 51 | -------------------------------------------------------------------------------- /server/db/repository/knowledge_metadata_repository.py: -------------------------------------------------------------------------------- 1 | from server.db.models.knowledge_metadata_model import SummaryChunkModel 2 | from server.db.session import with_session 3 | from typing import List, Dict 4 | 5 | 6 | @with_session 7 | def list_summary_from_db(session, 8 | kb_name: str, 9 | metadata: Dict = {}, 10 | ) -> List[Dict]: 11 | ''' 12 | 列出某知识库chunk summary。 13 | 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] 14 | ''' 15 | docs = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)) 16 | 17 | for k, v in metadata.items(): 18 | docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v)) 19 | 20 | return [{"id": x.id, 21 | "summary_context": x.summary_context, 22 | "summary_id": x.summary_id, 23 | "doc_ids": x.doc_ids, 24 | "metadata": x.metadata} for x in docs.all()] 25 | 26 | 27 | @with_session 28 | def delete_summary_from_db(session, 29 | kb_name: str 30 | ) -> List[Dict]: 31 | ''' 32 | 删除知识库chunk summary,并返回被删除的Dchunk summary。 33 | 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] 34 | ''' 35 | docs = list_summary_from_db(kb_name=kb_name) 36 | query = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)) 37 | query.delete(synchronize_session=False) 38 | session.commit() 39 | return docs 40 | 41 | 42 | @with_session 43 | def add_summary_to_db(session, 44 | kb_name: str, 45 | summary_infos: List[Dict]): 46 | ''' 47 | 将总结信息添加到数据库。 48 | summary_infos形式:[{"summary_context": str, "doc_ids": str}, ...] 49 | ''' 50 | for summary in summary_infos: 51 | obj = SummaryChunkModel( 52 | kb_name=kb_name, 53 | summary_context=summary["summary_context"], 54 | summary_id=summary["summary_id"], 55 | doc_ids=summary["doc_ids"], 56 | meta_data=summary["metadata"], 57 | ) 58 | session.add(obj) 59 | 60 | session.commit() 61 | return True 62 | 63 | 64 | @with_session 65 | def count_summary_from_db(session, kb_name: str) -> int: 66 | return session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)).count() 67 | -------------------------------------------------------------------------------- /server/db/repository/message_repository.py: -------------------------------------------------------------------------------- 1 | from server.db.session import with_session 2 | from typing import Dict, List 3 | import uuid 4 | from server.db.models.message_model import MessageModel 5 | 6 | 7 | @with_session 8 | def add_message_to_db(session, conversation_id: str, chat_type, query, response="", message_id=None, 9 | metadata: Dict = {}): 10 | """ 11 | 新增聊天记录 12 | """ 13 | if not message_id: 14 | message_id = uuid.uuid4().hex 15 | m = MessageModel(id=message_id, chat_type=chat_type, query=query, response=response, 16 | conversation_id=conversation_id, 17 | meta_data=metadata) 18 | session.add(m) 19 | session.commit() 20 | return m.id 21 | 22 | 23 | @with_session 24 | def update_message(session, message_id, response: str = None, metadata: Dict = None): 25 | """ 26 | 更新已有的聊天记录 27 | """ 28 | m = get_message_by_id(message_id) 29 | if m is not None: 30 | if response is not None: 31 | m.response = response 32 | if isinstance(metadata, dict): 33 | m.meta_data = metadata 34 | session.add(m) 35 | session.commit() 36 | return m.id 37 | 38 | 39 | @with_session 40 | def get_message_by_id(session, message_id) -> MessageModel: 41 | """ 42 | 查询聊天记录 43 | """ 44 | m = session.query(MessageModel).filter_by(id=message_id).first() 45 | return m 46 | 47 | 48 | @with_session 49 | def feedback_message_to_db(session, message_id, feedback_score, feedback_reason): 50 | """ 51 | 反馈聊天记录 52 | """ 53 | m = session.query(MessageModel).filter_by(id=message_id).first() 54 | if m: 55 | m.feedback_score = feedback_score 56 | m.feedback_reason = feedback_reason 57 | session.commit() 58 | return m.id 59 | 60 | 61 | @with_session 62 | def filter_message(session, conversation_id: str, limit: int = 10): 63 | messages = (session.query(MessageModel).filter_by(conversation_id=conversation_id). 64 | # 用户最新的query 也会插入到db,忽略这个message record 65 | filter(MessageModel.response != ''). 66 | # 返回最近的limit 条记录 67 | order_by(MessageModel.create_time.desc()).limit(limit).all()) 68 | # 直接返回 List[MessageModel] 报错 69 | data = [] 70 | for m in messages: 71 | data.append({"query": m.query, "response": m.response}) 72 | return data 73 | -------------------------------------------------------------------------------- /markdown_docs/server/db/models/knowledge_base_model.md: -------------------------------------------------------------------------------- 1 | ## ClassDef KnowledgeBaseModel 2 | **KnowledgeBaseModel**: KnowledgeBaseModel 类的功能是定义知识库的数据模型,用于在数据库中存储和管理知识库的相关信息。 3 | 4 | **属性**: 5 | - `id`: 知识库ID,是每个知识库的唯一标识。 6 | - `kb_name`: 知识库名称,用于标识和检索特定的知识库。 7 | - `kb_info`: 知识库简介,提供关于知识库的基本信息,用于Agent。 8 | - `vs_type`: 向量库类型,指定知识库使用的向量库的类型。 9 | - `embed_model`: 嵌入模型名称,指定用于知识库的嵌入模型。 10 | - `file_count`: 文件数量,记录知识库中包含的文件数目。 11 | - `create_time`: 创建时间,记录知识库被创建的时间。 12 | 13 | **代码描述**: 14 | KnowledgeBaseModel 类继承自 Base 类,是一个ORM模型,用于映射数据库中的 `knowledge_base` 表。该类定义了知识库的基本属性,包括知识库ID、名称、简介、向量库类型、嵌入模型名称、文件数量和创建时间。通过这些属性,可以在数据库中有效地存储和管理知识库的相关信息。 15 | 16 | 在项目中,KnowledgeBaseModel 类被多个函数调用,以实现对知识库的增删查改操作。例如,在 `add_kb_to_db` 函数中,使用KnowledgeBaseModel 来创建新的知识库实例或更新现有知识库的信息。在 `list_kbs_from_db` 函数中,通过查询KnowledgeBaseModel 来获取满足特定条件的知识库列表。此外,`kb_exists`、`load_kb_from_db`、`delete_kb_from_db` 和 `get_kb_detail` 等函数也都涉及到对KnowledgeBaseModel 类的操作,以实现检查知识库是否存在、加载知识库信息、删除知识库和获取知识库详细信息等功能。 17 | 18 | **注意**: 19 | 在使用KnowledgeBaseModel 类进行数据库操作时,需要注意确保传入的参数类型和值符合定义的属性类型和业务逻辑要求,以避免数据类型错误或逻辑错误。 20 | 21 | **输出示例**: 22 | 假设数据库中有一个知识库实例,其属性值如下: 23 | ``` 24 | 25 | ``` 26 | 这表示有一个ID为1的知识库,名称为“技术文档库”,简介为“存储技术相关文档”,使用的向量库类型为ElasticSearch,嵌入模型为BERT,包含100个文件,创建时间为2023年4月1日12点。 27 | ### FunctionDef __repr__(self) 28 | **__repr__**: __repr__函数的功能是提供KnowledgeBaseModel对象的官方字符串表示。 29 | 30 | **参数**: 此函数没有接受额外参数,它仅使用self来访问对象的属性。 31 | 32 | **代码描述**: 33 | `__repr__`方法定义在KnowledgeBaseModel类中,用于生成该对象的官方字符串表示。这个字符串表示包含了对象的关键信息,使得开发者和调试者能够更容易地识别对象。具体来说,它返回一个格式化的字符串,其中包含了KnowledgeBaseModel对象的多个属性值,包括: 34 | - `id`:对象的唯一标识符。 35 | - `kb_name`:知识库的名称。 36 | - `kb_info`:知识库的简介。 37 | - `vs_type`:知识库的版本类型。 38 | - `embed_model`:嵌入模型的名称。 39 | - `file_count`:知识库中文件的数量。 40 | - `create_time`:知识库创建的时间。 41 | 42 | 这个方法通过f-string格式化字符串的方式,将对象属性嵌入到预定义的字符串模板中,从而生成易于阅读和理解的表示形式。 43 | 44 | **注意**: 45 | - `__repr__`方法通常用于调试和日志记录,它应该返回一个明确且无歧义的对象表示。 46 | - 在Python中,当你尝试将对象转换为字符串时(例如使用`str()`函数或在打印时),如果没有定义`__str__`方法,Python会回退到使用`__repr__`方法。 47 | - 保证`__repr__`方法返回的字符串包含足够的信息,可以用来识别对象中的关键信息。 48 | 49 | **输出示例**: 50 | ```python 51 | 52 | ``` 53 | 此示例展示了一个KnowledgeBaseModel对象的`__repr__`方法返回值的可能形式,其中包含了对象的id, kb_name, kb_intro, vs_type, embed_model, file_count, 和 create_time属性的值。 54 | *** 55 | -------------------------------------------------------------------------------- /tests/test_online_api.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | root_path = Path(__file__).parent.parent 4 | sys.path.append(str(root_path)) 5 | 6 | from configs import ONLINE_LLM_MODEL 7 | from server.model_workers.base import * 8 | from server.utils import get_model_worker_config, list_config_llm_models 9 | from pprint import pprint 10 | import pytest 11 | 12 | 13 | workers = [] 14 | for x in list_config_llm_models()["online"]: 15 | if x in ONLINE_LLM_MODEL and x not in workers: 16 | workers.append(x) 17 | print(f"all workers to test: {workers}") 18 | 19 | # workers = ["fangzhou-api"] 20 | 21 | 22 | @pytest.mark.parametrize("worker", workers) 23 | def test_chat(worker): 24 | params = ApiChatParams( 25 | messages = [ 26 | {"role": "user", "content": "你是谁"}, 27 | ], 28 | ) 29 | print(f"\nchat with {worker} \n") 30 | 31 | if worker_class := get_model_worker_config(worker).get("worker_class"): 32 | for x in worker_class().do_chat(params): 33 | pprint(x) 34 | assert isinstance(x, dict) 35 | assert x["error_code"] == 0 36 | 37 | 38 | @pytest.mark.parametrize("worker", workers) 39 | def test_embeddings(worker): 40 | params = ApiEmbeddingsParams( 41 | texts = [ 42 | "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。", 43 | "一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。", 44 | ] 45 | ) 46 | 47 | if worker_class := get_model_worker_config(worker).get("worker_class"): 48 | if worker_class.can_embedding(): 49 | print(f"\embeddings with {worker} \n") 50 | resp = worker_class().do_embeddings(params) 51 | 52 | pprint(resp, depth=2) 53 | assert resp["code"] == 200 54 | assert "data" in resp 55 | embeddings = resp["data"] 56 | assert isinstance(embeddings, list) and len(embeddings) > 0 57 | assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0 58 | assert isinstance(embeddings[0][0], float) 59 | print("向量长度:", len(embeddings[0])) 60 | 61 | 62 | # @pytest.mark.parametrize("worker", workers) 63 | # def test_completion(worker): 64 | # params = ApiCompletionParams(prompt="五十六个民族") 65 | 66 | # print(f"\completion with {worker} \n") 67 | 68 | # worker_class = get_model_worker_config(worker)["worker_class"] 69 | # resp = worker_class().do_completion(params) 70 | # pprint(resp) 71 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/weather_check.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef weather(location, api_key) 2 | **weather**: 此函数的功能是获取指定地点的当前天气信息。 3 | 4 | **参数**: 5 | - location: 字符串类型,表示查询天气信息的地点。 6 | - api_key: 字符串类型,用于访问天气API的密钥。 7 | 8 | **代码描述**: 9 | `weather` 函数通过构造一个请求URL,使用 `requests.get` 方法向 `seniverse.com` 的天气API发送请求,以获取指定地点的当前天气信息。此URL包含了API密钥(api_key)、地点(location)、语言设置(默认为简体中文)和温度单位(摄氏度)。如果请求成功(HTTP状态码为200),函数将解析响应的JSON数据,提取温度和天气描述信息,然后以字典形式返回这些信息。如果请求失败,则抛出异常,异常信息包含了失败的HTTP状态码。 10 | 11 | 在项目中,`weather` 函数被 `weathercheck` 函数调用。`weathercheck` 函数接受一个地点作为参数,并使用项目中预定义的 `SENIVERSE_API_KEY` 作为API密钥来调用 `weather` 函数。这表明 `weather` 函数是项目中用于获取天气信息的核心功能,而 `weathercheck` 函数提供了一个更简便的接口,使得其他部分的代码无需直接处理API密钥即可请求天气信息。 12 | 13 | **注意**: 14 | - 确保提供的 `api_key` 是有效的,否则请求将失败。 15 | - 由于网络请求的性质,此函数的执行时间可能受到网络状况的影响。 16 | 17 | **输出示例**: 18 | ```python 19 | { 20 | "temperature": "22", 21 | "description": "多云" 22 | } 23 | ``` 24 | 此示例展示了函数返回值的可能外观,其中包含了温度和天气描述信息。 25 | ## FunctionDef weathercheck(location) 26 | **weathercheck**: 此函数的功能是使用预定义的API密钥获取指定地点的当前天气信息。 27 | 28 | **参数**: 29 | - location: 字符串类型,表示查询天气信息的地点。 30 | 31 | **代码描述**: 32 | `weathercheck` 函数是一个简化的接口,用于获取指定地点的天气信息。它接受一个地点名称作为参数,并内部调用了 `weather` 函数,后者是实际执行天气信息获取操作的函数。在调用 `weather` 函数时,`weathercheck` 使用了预定义的 `SENIVERSE_API_KEY` 作为API密钥参数。这意味着使用 `weathercheck` 函数时,用户无需直接处理API密钥,从而简化了获取天气信息的过程。 33 | 34 | `weather` 函数负责构造请求URL,并通过HTTP GET请求向 `seniverse.com` 的天气API发送请求。如果请求成功,它将解析响应的JSON数据,并提取出温度和天气描述信息,然后以字典形式返回这些信息。如果请求失败,`weather` 函数将抛出异常,包含失败的HTTP状态码。 35 | 36 | **注意**: 37 | - 使用 `weathercheck` 函数时,确保预定义的 `SENIVERSE_API_KEY` 是有效的。无效的API密钥将导致请求失败。 38 | - 获取天气信息的过程涉及网络请求,因此执行时间可能受到网络状况的影响。在网络状况不佳的情况下,响应时间可能会较长。 39 | 40 | **输出示例**: 41 | 由于 `weathercheck` 函数内部调用了 `weather` 函数并直接返回其结果,因此输出示例与 `weather` 函数的输出示例相同。以下是一个可能的返回值示例: 42 | ```python 43 | { 44 | "temperature": "22", 45 | "description": "多云" 46 | } 47 | ``` 48 | 此示例展示了函数返回值的可能外观,其中包含了温度和天气描述信息。 49 | ## ClassDef WeatherInput 50 | **WeatherInput**: WeatherInput类的功能是定义一个用于天气查询的输入模型。 51 | 52 | **属性**: 53 | - location: 表示查询天气的城市名称,包括城市和县。 54 | 55 | **代码描述**: 56 | WeatherInput类继承自BaseModel,这是一个常见的做法,用于创建具有类型注解的数据模型。在这个类中,定义了一个名为`location`的属性,该属性用于存储用户希望查询天气的城市名称。通过使用`Field`函数,为`location`属性提供了额外的描述信息,即"City name, include city and county",这有助于理解该属性的用途和预期的值格式。 57 | 58 | 在项目的上下文中,尽管具体的调用情况未在提供的信息中明确,但可以推断WeatherInput类被设计为在天气查询功能中使用。它可能被用于从用户那里接收输入,然后这些输入将被用于查询特定城市的天气信息。这种设计允许天气查询功能以一种结构化和类型安全的方式处理用户输入。 59 | 60 | **注意**: 61 | - 在使用WeatherInput类时,需要确保传递给`location`属性的值是一个格式正确的字符串,即包含城市和县的名称。这是因为该模型可能会被用于向天气API发送请求,而这些API通常要求准确的地理位置信息以返回正确的天气数据。 62 | - 由于WeatherInput类继承自BaseModel,因此可以利用Pydantic库提供的各种功能,如数据验证、序列化和反序列化等。这使得处理和转换用户输入变得更加容易和安全。 63 | -------------------------------------------------------------------------------- /markdown_docs/server/chat/chat.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef chat(query, conversation_id, history_len, history, stream, model_name, temperature, max_tokens, prompt_name) 2 | **chat**: 此函数用于实现与LLM模型的对话功能。 3 | 4 | **参数**: 5 | - `query`: 用户输入的查询字符串。 6 | - `conversation_id`: 对话框ID,用于标识一个对话会话。 7 | - `history_len`: 从数据库中取历史消息的数量。 8 | - `history`: 历史对话记录,可以是一个整数或一个历史记录列表。 9 | - `stream`: 是否以流式输出的方式返回对话结果。 10 | - `model_name`: LLM模型名称。 11 | - `temperature`: LLM采样温度,用于控制生成文本的随机性。 12 | - `max_tokens`: 限制LLM生成Token的数量。 13 | - `prompt_name`: 使用的prompt模板名称。 14 | 15 | **代码描述**: 16 | 此函数主要负责处理用户与LLM模型之间的对话。它首先通过`add_message_to_db`函数将用户的查询和对话信息保存到数据库中。然后,根据传入的参数,如历史对话记录、模型名称、温度等,构建一个适合LLM模型的输入提示(prompt)。接着,使用`LLMChain`对象发起对话请求,并通过`AsyncIteratorCallbackHandler`处理模型的异步响应。如果启用了流式输出,函数将逐个Token地返回响应结果;否则,会等待所有响应完成后,一次性返回整个对话结果。最后,通过`EventSourceResponse`将结果以服务器发送事件(SSE)的形式返回给客户端。 17 | 18 | **注意**: 19 | - 在使用`history`参数时,可以直接传入历史对话记录的列表,或者传入一个整数,函数会从数据库中读取指定数量的历史消息。 20 | - `stream`参数控制输出模式,当设置为True时,对话结果将以流式输出,适用于需要实时显示对话过程的场景。 21 | - `max_tokens`参数用于限制LLM模型生成的Token数量,有助于控制生成文本的长度。 22 | 23 | **输出示例**: 24 | 假设函数以非流式模式运行,并且返回了一条简单的对话响应: 25 | ```json 26 | { 27 | "text": "好的,我明白了。", 28 | "message_id": "123456" 29 | } 30 | ``` 31 | 这表示LLM模型对用户的查询给出了回复“好的,我明白了。”,并且该对话消息在数据库中的ID为"123456"。 32 | ### FunctionDef chat_iterator 33 | **chat_iterator**: 此函数的功能是异步迭代聊天过程,生成并流式传输聊天响应。 34 | 35 | **参数**: 36 | - 无参数直接传递给此函数,但函数内部使用了多个外部定义的变量和对象。 37 | 38 | **代码描述**: 39 | `chat_iterator`是一个异步生成器函数,用于处理聊天会话,生成聊天响应并以流的形式传输。函数首先定义了一个`callback`对象,该对象是`AsyncIteratorCallbackHandler`的实例,用于处理异步迭代的回调。接着,创建了一个回调列表`callbacks`,并将`callback`对象添加到其中。 40 | 41 | 函数通过调用`add_message_to_db`函数,将聊天请求添加到数据库中,并创建了一个`conversation_callback`对象,该对象是`ConversationCallbackHandler`的实例,用于处理聊天过程中的回调,并将其添加到回调列表中。 42 | 43 | 根据`max_tokens`的值调整生成文本的最大token数量,如果`max_tokens`为非正整数,则将其设置为`None`。 44 | 45 | 接下来,通过调用`get_ChatOpenAI`函数,初始化一个聊天模型`model`,并根据聊天历史(如果有)或会话ID(如果指定)来构建聊天提示`chat_prompt`。如果没有提供历史或会话ID,则使用默认的提示模板。 46 | 47 | 然后,创建了一个`LLMChain`对象`chain`,它负责将聊天提示传递给聊天模型,并开始一个异步任务,该任务使用`wrap_done`函数包装了`chain.acall`的调用,以便在任务完成时通过`callback.done`方法进行通知。 48 | 49 | 最后,根据`stream`变量的值决定是流式传输每个生成的token,还是等待所有token生成后一次性返回。在流式传输模式下,使用`json.dumps`将生成的token和消息ID封装成JSON格式并逐个yield返回;在非流式传输模式下,将所有生成的token拼接后一次性yield返回。 50 | 51 | **注意**: 52 | - `chat_iterator`函数是异步的,因此在调用时需要使用`await`关键字或在其他异步函数中调用。 53 | - 函数内部使用了多个外部定义的变量和对象,如`history`、`max_tokens`等,这要求在调用`chat_iterator`之前,这些变量和对象必须已经被正确初始化和配置。 54 | - 函数依赖于多个外部定义的函数和类,如`AsyncIteratorCallbackHandler`、`add_message_to_db`、`ConversationCallbackHandler`、`get_ChatOpenAI`等,确保这些依赖项在项目中已正确实现。 55 | - 在处理聊天响应时,函数考虑了多种情况,包括有无聊天历史、是否从数据库获取历史消息等,这要求调用者根据实际情况提供正确的参数和配置。 56 | - 使用此函数时,应注意异常处理和资源管理,确保在聊天会话结束时释放所有资源。 57 | *** 58 | -------------------------------------------------------------------------------- /document_loaders/mypptloader.py: -------------------------------------------------------------------------------- 1 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 2 | from typing import List 3 | import tqdm 4 | 5 | 6 | class RapidOCRPPTLoader(UnstructuredFileLoader): 7 | def _get_elements(self) -> List: 8 | def ppt2text(filepath): 9 | from pptx import Presentation 10 | from PIL import Image 11 | import numpy as np 12 | from io import BytesIO 13 | from rapidocr_onnxruntime import RapidOCR 14 | ocr = RapidOCR() 15 | prs = Presentation(filepath) 16 | resp = "" 17 | 18 | def extract_text(shape): 19 | nonlocal resp 20 | if shape.has_text_frame: 21 | resp += shape.text.strip() + "\n" 22 | if shape.has_table: 23 | for row in shape.table.rows: 24 | for cell in row.cells: 25 | for paragraph in cell.text_frame.paragraphs: 26 | resp += paragraph.text.strip() + "\n" 27 | if shape.shape_type == 13: # 13 表示图片 28 | image = Image.open(BytesIO(shape.image.blob)) 29 | result, _ = ocr(np.array(image)) 30 | if result: 31 | ocr_result = [line[1] for line in result] 32 | resp += "\n".join(ocr_result) 33 | elif shape.shape_type == 6: # 6 表示组合 34 | for child_shape in shape.shapes: 35 | extract_text(child_shape) 36 | 37 | b_unit = tqdm.tqdm(total=len(prs.slides), 38 | desc="RapidOCRPPTLoader slide index: 1") 39 | # 遍历所有幻灯片 40 | for slide_number, slide in enumerate(prs.slides, start=1): 41 | b_unit.set_description( 42 | "RapidOCRPPTLoader slide index: {}".format(slide_number)) 43 | b_unit.refresh() 44 | sorted_shapes = sorted(slide.shapes, 45 | key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历 46 | for shape in sorted_shapes: 47 | extract_text(shape) 48 | b_unit.update(1) 49 | return resp 50 | 51 | text = ppt2text(self.file_path) 52 | from unstructured.partition.text import partition_text 53 | return partition_text(text=text, **self.unstructured_kwargs) 54 | 55 | 56 | if __name__ == '__main__': 57 | loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx") 58 | docs = loader.load() 59 | print(docs) 60 | -------------------------------------------------------------------------------- /markdown_docs/release.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef get_latest_tag 2 | **get_latest_tag**: 此函数的功能是获取Git仓库中最新的标签。 3 | 4 | **参数**: 此函数不接受任何参数。 5 | 6 | **代码描述**: `get_latest_tag` 函数首先使用 `subprocess.check_output` 方法执行 `git tag` 命令,以获取当前Git仓库中所有的标签。然后,通过对输出结果进行解码(UTF-8)和分割,将其转换成一个标签列表。接下来,使用 `sorted` 函数和一个自定义的排序键,基于标签的版本号(假设遵循 `v主版本号.次版本号.修订号` 的格式)对标签列表进行排序。排序键通过正则表达式 `re.match` 匹配每个标签的版本号,并将其转换为整数元组,以便进行比较。最后,函数返回排序后的最后一个元素,即最新的标签。 7 | 8 | 在项目中,`get_latest_tag` 函数被 `main` 函数调用,用于获取当前Git仓库中的最新标签,并在终端中显示。此外,`main` 函数还根据用户的输入决定如何递增版本号,并创建新的标签推送到远程仓库。因此,`get_latest_tag` 函数在自动化版本控制和发布流程中起着关键作用,它确保了版本号的正确递增和新版本标签的生成。 9 | 10 | **注意**: 使用此函数时,需要确保当前环境已安装Git,并且函数调用是在一个Git仓库的根目录下进行的。此外,此函数假定标签遵循 `v主版本号.次版本号.修订号` 的命名约定,如果标签不遵循此格式,可能无法正确排序和识别最新标签。 11 | 12 | **输出示例**: 假设Git仓库中的最新标签为 `v1.2.3`,则函数调用 `get_latest_tag()` 将返回字符串 `"v1.2.3"`。 13 | ## FunctionDef update_version_number(latest_tag, increment) 14 | **update_version_number**: 此函数用于根据最新的Git标签和用户指定的版本号递增规则来更新版本号。 15 | 16 | **参数**: 17 | - `latest_tag`: 最新的Git标签,字符串格式,预期为`vX.Y.Z`的形式,其中X、Y、Z分别代表主版本号、次版本号和修订号。 18 | - `increment`: 用户指定的版本号递增规则,接受的值为`'X'`、`'Y'`或`'Z'`,分别代表递增主版本号、次版本号或修订号。 19 | 20 | **代码描述**: 21 | 函数首先通过正则表达式从`latest_tag`中提取出当前的主版本号、次版本号和修订号,并将它们转换为整数。根据`increment`参数的值,函数将相应的版本号部分递增。如果`increment`为`'X'`,则主版本号加一,次版本号和修订号重置为0。如果`increment`为`'Y'`,则次版本号加一,修订号重置为0。如果`increment`为`'Z'`,则修订号加一。最后,函数将更新后的版本号拼接成`vX.Y.Z`的格式并返回。 22 | 23 | 此函数在项目中被`main`函数调用。在`main`函数中,首先获取当前最新的Git标签,然后询问用户希望递增哪部分版本号(主版本号、次版本号或修订号)。用户输入后,`update_version_number`函数被调用以生成新的版本号。根据用户的确认,新的版本号可能会被用来创建Git标签并推送到远程仓库。 24 | 25 | **注意**: 26 | - 输入的`latest_tag`必须严格遵循`vX.Y.Z`的格式,否则正则表达式匹配将失败,函数将无法正确执行。 27 | - `increment`参数仅接受`'X'`、`'Y'`、`'Z'`三个值,任何其他输入都将导致函数无法按预期递增版本号。 28 | 29 | **输出示例**: 30 | 如果`latest_tag`为`v1.2.3`且`increment`为`'Y'`,则函数将返回`v1.3.0`。 31 | ## FunctionDef main 32 | **main**: 此函数的功能是自动化Git版本控制流程,包括获取最新Git标签,递增版本号,并根据用户确认将新版本号作为标签推送到远程仓库。 33 | 34 | **参数**: 此函数不接受任何参数。 35 | 36 | **代码描述**: `main` 函数首先通过调用 `get_latest_tag` 函数获取当前Git仓库中的最新标签,并将其打印出来。接着,函数提示用户选择要递增的版本号部分(主版本号X、次版本号Y或修订号Z)。用户的选择通过标准输入接收,并转换为大写字母以便后续处理。如果用户输入的不是X、Y或Z中的任何一个,系统会提示错误并要求用户重新输入,直到输入正确为止。 37 | 38 | 一旦获得有效输入,`main` 函数将调用 `update_version_number` 函数,传入最新的Git标签和用户选择的递增部分,以生成新的版本号。新版本号随后被打印出来,询问用户是否确认更新版本号并推送到远程仓库。用户的确认通过标准输入接收,并转换为小写字母进行判断。 39 | 40 | 如果用户确认(输入'y'),则使用 `subprocess.run` 方法执行Git命令,首先创建新的版本标签,然后将该标签推送到远程仓库。操作完成后,打印出相应的提示信息。如果用户不确认(输入'n'),则打印出操作已取消的信息。 41 | 42 | **注意**: 43 | - 在使用此函数之前,需要确保当前环境已安装Git,并且函数调用是在一个Git仓库的根目录下进行的。 44 | - 用户输入的处理是大小写不敏感的,即输入'X'、'x'均被视为有效输入,并且都会被转换为大写进行处理。 45 | - 在推送新标签到远程仓库之前,函数会要求用户进行确认。这是一个安全措施,以防止意外修改远程仓库。 46 | - 此函数依赖于`get_latest_tag`和`update_version_number`两个函数。`get_latest_tag`用于获取最新的Git标签,而`update_version_number`根据用户指定的递增规则更新版本号。这两个函数的正确执行是`main`函数能够正确工作的基础。 47 | -------------------------------------------------------------------------------- /server/model_workers/SparkApi.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import datetime 3 | import hashlib 4 | import hmac 5 | from urllib.parse import urlparse 6 | from datetime import datetime 7 | from time import mktime 8 | from urllib.parse import urlencode 9 | from wsgiref.handlers import format_date_time 10 | 11 | 12 | class Ws_Param(object): 13 | # 初始化 14 | def __init__(self, APPID, APIKey, APISecret, Spark_url): 15 | self.APPID = APPID 16 | self.APIKey = APIKey 17 | self.APISecret = APISecret 18 | self.host = urlparse(Spark_url).netloc 19 | self.path = urlparse(Spark_url).path 20 | self.Spark_url = Spark_url 21 | 22 | # 生成url 23 | def create_url(self): 24 | # 生成RFC1123格式的时间戳 25 | now = datetime.now() 26 | date = format_date_time(mktime(now.timetuple())) 27 | 28 | # 拼接字符串 29 | signature_origin = "host: " + self.host + "\n" 30 | signature_origin += "date: " + date + "\n" 31 | signature_origin += "GET " + self.path + " HTTP/1.1" 32 | 33 | # 进行hmac-sha256进行加密 34 | signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), 35 | digestmod=hashlib.sha256).digest() 36 | 37 | signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') 38 | 39 | authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' 40 | 41 | authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') 42 | 43 | # 将请求的鉴权参数组合为字典 44 | v = { 45 | "authorization": authorization, 46 | "date": date, 47 | "host": self.host 48 | } 49 | # 拼接鉴权参数,生成url 50 | url = self.Spark_url + '?' + urlencode(v) 51 | # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 52 | return url 53 | 54 | 55 | def gen_params(appid, domain, question, temperature, max_token): 56 | """ 57 | 通过appid和用户的提问来生成请参数 58 | """ 59 | data = { 60 | "header": { 61 | "app_id": appid, 62 | "uid": "1234" 63 | }, 64 | "parameter": { 65 | "chat": { 66 | "domain": domain, 67 | "random_threshold": 0.5, 68 | "max_tokens": max_token, 69 | "auditing": "default", 70 | "temperature": temperature, 71 | } 72 | }, 73 | "payload": { 74 | "message": { 75 | "text": question 76 | } 77 | } 78 | } 79 | return data 80 | -------------------------------------------------------------------------------- /markdown_docs/server/agent/tools/search_internet.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef search_engine_iter(query) 2 | **search_engine_iter**: 该函数用于通过指定的搜索引擎异步检索查询内容,并生成相关的回答。 3 | 4 | **参数**: 5 | - `query`: 用户输入的查询内容,类型为字符串。 6 | 7 | **代码描述**: 8 | `search_engine_iter`函数是一个异步函数,主要用于处理用户的查询请求。它首先调用`search_engine_chat`函数,向指定的搜索引擎(本例中为Bing)发送查询请求,并设置了一系列参数,包括模型名称、温度值、历史记录、返回结果的数量、最大Token数、提示名称以及是否以流式传输的方式返回结果。这些参数的设置旨在优化搜索结果的相关性和质量。 9 | 10 | 在调用`search_engine_chat`后,函数通过异步迭代器`response.body_iterator`遍历响应体。每次迭代返回的数据是一个JSON字符串,包含了搜索引擎返回的答案和相关文档。函数解析这些JSON字符串,提取出答案和文档信息,并将答案内容累加到`contents`变量中。 11 | 12 | 最终,函数返回累加后的`contents`变量,即包含了所有相关答案的字符串。 13 | 14 | **注意**: 15 | - 该函数是异步的,因此在调用时需要使用`await`关键字或在异步环境中调用。 16 | - 函数的执行依赖于外部的搜索引擎服务和LLM模型,因此执行时间可能受到网络状况和服务响应时间的影响。 17 | - 在使用该函数之前,需要确保已经配置了相应的搜索引擎API密钥和LLM模型。 18 | 19 | **输出示例**: 20 | ```json 21 | "根据您的查询,这里是生成的回答。" 22 | ``` 23 | 该输出示例展示了函数可能返回的答案内容。实际返回的内容将根据查询内容和搜索引擎返回的结果而有所不同。 24 | ## FunctionDef search_internet(query) 25 | **search_internet**: 该函数用于通过异步方式调用搜索引擎,检索用户查询的内容。 26 | 27 | **参数**: 28 | - `query`: 用户输入的查询内容,类型为字符串。 29 | 30 | **代码描述**: 31 | `search_internet`函数是一个简洁的接口,用于触发对指定查询内容的互联网搜索。它通过调用`search_engine_iter`函数实现,后者是一个异步函数,负责具体的搜索操作和处理逻辑。在`search_internet`函数中,使用`asyncio.run`方法来运行`search_engine_iter`函数,这允许同步代码中方便地调用异步函数,并等待其结果。 32 | 33 | `search_engine_iter`函数详细描述了搜索过程,包括向搜索引擎发送请求、处理返回的数据,并最终将累加的答案内容作为字符串返回。这个过程涉及到异步编程的知识,特别是在处理网络请求和响应时的异步迭代。 34 | 35 | **注意**: 36 | - 由于`search_internet`函数内部使用了`asyncio.run`,它不应该被用在已经运行的异步函数或事件循环中,以避免抛出异常。 37 | - 函数的执行效率和结果质量依赖于外部搜索引擎的响应速度和准确性,因此在网络状况不佳或搜索引擎服务不稳定时,可能会影响使用体验。 38 | - 在使用之前,确保相关的搜索引擎API密钥和配置已经正确设置,以保证搜索功能的正常工作。 39 | 40 | **输出示例**: 41 | 假设用户查询的内容为“Python 异步编程”,函数可能返回的字符串示例为: 42 | ``` 43 | "Python异步编程是一种编程范式,旨在提高程序的并发性和性能。这里是一些关于Python异步编程的基础知识和实践指南。" 44 | ``` 45 | 该示例展示了函数可能返回的答案内容,实际返回的内容将根据查询内容和搜索引擎返回的结果而有所不同。 46 | ## ClassDef SearchInternetInput 47 | **SearchInternetInput**: SearchInternetInput类的功能是定义一个用于互联网搜索的输入模型。 48 | 49 | **属性**: 50 | - location: 用于互联网搜索的查询字符串。 51 | 52 | **代码描述**: 53 | SearchInternetInput类继承自BaseModel,这意味着它是一个模型类,通常用于处理数据的验证、序列化和反序列化。在这个类中,定义了一个名为`location`的属性,该属性用于存储用户希望进行搜索的查询字符串。通过使用Pydantic库中的`Field`函数,为`location`属性提供了一个描述性文本,即"Query for Internet search",这有助于理解该属性的用途。 54 | 55 | 该类在项目中的作用是作为搜索互联网功能的输入数据模型。它的设计允许开发者在调用搜索互联网相关功能时,能够以结构化的方式提供必要的输入信息,即用户想要搜索的内容。这种方式提高了代码的可读性和易用性,同时也便于后续的数据验证和处理。 56 | 57 | 从项目结构来看,虽然`server/agent/tools/__init__.py`和`server/agent/tools_select.py`两个文件中没有直接提到SearchInternetInput类的使用,但可以推断,SearchInternetInput类可能会被项目中负责处理搜索请求的部分调用。具体来说,开发者可能会在处理搜索请求的函数或方法中,实例化SearchInternetInput类,然后根据用户的输入构造location属性,最后使用这个实例来执行搜索操作。 58 | 59 | **注意**: 60 | - 在使用SearchInternetInput类时,开发者需要确保提供的`location`值是有效的搜索查询字符串,因为这将直接影响搜索结果的相关性和准确性。 61 | - 考虑到数据验证的需求,开发者在使用此类时应当熟悉Pydantic库的基本用法,以便充分利用模型验证等功能。 62 | -------------------------------------------------------------------------------- /server/knowledge_base/kb_api.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | from server.utils import BaseResponse, ListResponse 3 | from server.knowledge_base.utils import validate_kb_name 4 | from server.knowledge_base.kb_service.base import KBServiceFactory 5 | from server.db.repository.knowledge_base_repository import list_kbs_from_db 6 | from configs import EMBEDDING_MODEL, logger, log_verbose 7 | from fastapi import Body 8 | 9 | 10 | def list_kbs(): 11 | # Get List of Knowledge Base 12 | return ListResponse(data=list_kbs_from_db()) 13 | 14 | 15 | def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), 16 | vector_store_type: str = Body("faiss"), 17 | embed_model: str = Body(EMBEDDING_MODEL), 18 | ) -> BaseResponse: 19 | # Create selected knowledge base 20 | if not validate_kb_name(knowledge_base_name): 21 | return BaseResponse(code=403, msg="Don't attack me") 22 | if knowledge_base_name is None or knowledge_base_name.strip() == "": 23 | return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") 24 | 25 | kb = KBServiceFactory.get_service_by_name(knowledge_base_name) 26 | if kb is not None: 27 | return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") 28 | 29 | kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) 30 | try: 31 | kb.create_kb() 32 | except Exception as e: 33 | msg = f"创建知识库出错: {e}" 34 | logger.error(f'{e.__class__.__name__}: {msg}', 35 | exc_info=e if log_verbose else None) 36 | return BaseResponse(code=500, msg=msg) 37 | 38 | return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") 39 | 40 | 41 | def delete_kb( 42 | knowledge_base_name: str = Body(..., examples=["samples"]) 43 | ) -> BaseResponse: 44 | # Delete selected knowledge base 45 | if not validate_kb_name(knowledge_base_name): 46 | return BaseResponse(code=403, msg="Don't attack me") 47 | knowledge_base_name = urllib.parse.unquote(knowledge_base_name) 48 | 49 | kb = KBServiceFactory.get_service_by_name(knowledge_base_name) 50 | 51 | if kb is None: 52 | return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") 53 | 54 | try: 55 | status = kb.clear_vs() 56 | status = kb.drop_kb() 57 | if status: 58 | return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") 59 | except Exception as e: 60 | msg = f"删除知识库时出现意外: {e}" 61 | logger.error(f'{e.__class__.__name__}: {msg}', 62 | exc_info=e if log_verbose else None) 63 | return BaseResponse(code=500, msg=msg) 64 | 65 | return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") 66 | -------------------------------------------------------------------------------- /server/knowledge_base/kb_summary/base.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from configs import ( 4 | EMBEDDING_MODEL, 5 | KB_ROOT_PATH) 6 | 7 | from abc import ABC, abstractmethod 8 | from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss 9 | import os 10 | import shutil 11 | from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db 12 | 13 | from langchain.docstore.document import Document 14 | 15 | 16 | class KBSummaryService(ABC): 17 | kb_name: str 18 | embed_model: str 19 | vs_path: str 20 | kb_path: str 21 | 22 | def __init__(self, 23 | knowledge_base_name: str, 24 | embed_model: str = EMBEDDING_MODEL 25 | ): 26 | self.kb_name = knowledge_base_name 27 | self.embed_model = embed_model 28 | 29 | self.kb_path = self.get_kb_path() 30 | self.vs_path = self.get_vs_path() 31 | 32 | if not os.path.exists(self.vs_path): 33 | os.makedirs(self.vs_path) 34 | 35 | 36 | def get_vs_path(self): 37 | return os.path.join(self.get_kb_path(), "summary_vector_store") 38 | 39 | def get_kb_path(self): 40 | return os.path.join(KB_ROOT_PATH, self.kb_name) 41 | 42 | def load_vector_store(self) -> ThreadSafeFaiss: 43 | return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, 44 | vector_name="summary_vector_store", 45 | embed_model=self.embed_model, 46 | create=True) 47 | 48 | def add_kb_summary(self, summary_combine_docs: List[Document]): 49 | with self.load_vector_store().acquire() as vs: 50 | ids = vs.add_documents(documents=summary_combine_docs) 51 | vs.save_local(self.vs_path) 52 | 53 | summary_infos = [{"summary_context": doc.page_content, 54 | "summary_id": id, 55 | "doc_ids": doc.metadata.get('doc_ids'), 56 | "metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)] 57 | status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos) 58 | return status 59 | 60 | def create_kb_summary(self): 61 | """ 62 | 创建知识库chunk summary 63 | :return: 64 | """ 65 | 66 | if not os.path.exists(self.vs_path): 67 | os.makedirs(self.vs_path) 68 | 69 | def drop_kb_summary(self): 70 | """ 71 | 删除知识库chunk summary 72 | :param kb_name: 73 | :return: 74 | """ 75 | with kb_faiss_pool.atomic: 76 | kb_faiss_pool.pop(self.kb_name) 77 | shutil.rmtree(self.vs_path) 78 | delete_summary_from_db(kb_name=self.kb_name) 79 | -------------------------------------------------------------------------------- /markdown_docs/text_splitter/zh_title_enhance.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef under_non_alpha_ratio(text, threshold) 2 | **under_non_alpha_ratio**: 此函数用于检查文本片段中非字母字符的比例是否超过给定阈值。 3 | 4 | **参数**: 5 | - text: 需要测试的输入字符串。 6 | - threshold: 如果非字母字符的比例超过此阈值,则函数返回False。 7 | 8 | **代码描述**: 9 | `under_non_alpha_ratio`函数主要用于过滤掉那些可能被错误标记为标题或叙述文本的字符串,例如包含大量非字母字符(如"-----------BREAK---------")的字符串。该函数通过计算输入文本中非空格且为字母的字符占非空格字符总数的比例,来判断该比例是否低于给定的阈值。如果是,则认为文本中非字母字符的比例过高,函数返回False。值得注意的是,空格字符在计算总字符数时被忽略。 10 | 11 | 在项目中,`under_non_alpha_ratio`函数被`is_possible_title`函数调用,用于判断一个文本是否可能是一个有效的标题。`is_possible_title`函数通过一系列规则(如文本长度、文本末尾是否有标点符号、文本中非字母字符的比例等)来判断文本是否可能是标题。在这个过程中,`under_non_alpha_ratio`函数负责检查文本中非字母字符的比例是否超过了设定的阈值(默认为0.5),这是判断文本是否可能是标题的重要条件之一。 12 | 13 | **注意**: 14 | - 如果输入的文本为空,或者在计算比例时发生任何异常(例如除以零的情况),函数将返回False。 15 | - 函数的阈值参数是可配置的,可以根据实际情况调整,默认值为0.5。 16 | 17 | **输出示例**: 18 | 假设有一个文本`"Hello, World!"`,调用`under_non_alpha_ratio("Hello, World!")`将返回False,因为该文本中字母字符的比例高于默认阈值0.5。而对于文本`"-----BREAK-----"`,调用`under_non_alpha_ratio("-----BREAK-----")`则可能返回True,因为非字母字符的比例超过了阈值。 19 | ## FunctionDef is_possible_title(text, title_max_word_length, non_alpha_threshold) 20 | **is_possible_title**: 此函数用于检查文本是否符合作为有效标题的所有条件。 21 | 22 | **参数**: 23 | - text: 要检查的输入文本。 24 | - title_max_word_length: 标题可以包含的最大单词数,默认为20。 25 | - non_alpha_threshold: 文本被认为是标题所需的最小字母字符比例,默认为0.5。 26 | 27 | **代码描述**: 28 | `is_possible_title`函数通过一系列条件来判断给定的文本是否可能是一个有效的标题。首先,如果文本长度为0,即文本为空,则直接返回False,表示这不是一个标题。其次,如果文本以标点符号结束,也被认为不是标题。此外,如果文本的长度超过了设定的最大单词数(默认为20),或者文本中非字母字符的比例超过了设定的阈值(通过调用`under_non_alpha_ratio`函数检查),则同样认为不是标题。函数还会检查文本是否以逗号、句号结束,或者文本是否全为数字,这些情况下文本也不会被认为是标题。最后,函数检查文本开头的5个字符中是否包含数字,如果不包含,则认为这不是一个标题。 29 | 30 | **注意**: 31 | - 函数中使用了正则表达式来检查文本是否以标点符号结束,这是判断文本是否可能是标题的一个条件。 32 | - 在判断文本长度是否超过最大单词数时,简单地基于空格进行分割,而没有使用复杂的词语分词方法,这是出于性能考虑。 33 | - `under_non_alpha_ratio`函数被用于计算文本中非字母字符的比例,以帮助判断文本是否可能是标题。 34 | 35 | **输出示例**: 36 | 假设有一个文本`"这是一个可能的标题"`,调用`is_possible_title("这是一个可能的标题")`将返回True,因为该文本满足所有作为标题的条件。而对于文本`"这不是标题。"`,调用`is_possible_title("这不是标题。")`则会返回False,因为它以标点符号结束。 37 | ## FunctionDef zh_title_enhance(docs) 38 | **zh_title_enhance**: 此函数的功能是增强文档集中的标题,并对后续文档内容进行相应的标注。 39 | 40 | **参数**: 41 | - docs: 一个Document对象,代表需要处理的文档集。 42 | 43 | **代码描述**: 44 | `zh_title_enhance`函数首先检查传入的文档集`docs`是否为空。如果不为空,它遍历每个文档,使用`is_possible_title`函数来判断当前文档的`page_content`是否可能是一个有效的标题。如果是,它会将当前文档的`metadata`中的`category`设置为`'cn_Title'`,并将该文档的`page_content`作为标题保存。对于随后的文档,如果已经找到了标题,它会在这些文档的`page_content`前添加一段文本,说明这部分内容与之前找到的标题有关。如果传入的文档集为空,则会打印出“文件不存在”的提示。 45 | 46 | **注意**: 47 | - 此函数依赖于`is_possible_title`函数来判断一个文档内容是否可以作为标题。`is_possible_title`函数根据文本的特征(如长度、标点符号结束、数字比例等)来判断文本是否可能是标题。 48 | - 函数修改了传入的Document对象,为可能的标题文档添加了元数据标记,并且修改了后续文档的内容以反映它们与找到的标题的关系。 49 | - 如果文档集为空,函数不会执行任何操作,只会打印提示信息。 50 | 51 | **输出示例**: 52 | 假设传入的文档集包含两个文档,第一个文档的`page_content`是一个有效的标题,第二个文档是正文内容。处理后,第一个文档的`metadata`将包含`{'category': 'cn_Title'}`,而第二个文档的`page_content`将被修改为“下文与(有效标题)有关。原始正文内容”。 53 | -------------------------------------------------------------------------------- /markdown_docs/server/chat/knowledge_base_chat.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef knowledge_base_chat(query, knowledge_base_name, top_k, score_threshold, history, stream, model_name, temperature, max_tokens, prompt_name, request) 2 | **knowledge_base_chat**: 该函数用于处理用户与知识库的交互对话。 3 | 4 | **参数**: 5 | - `query`: 用户的输入查询,类型为字符串。 6 | - `knowledge_base_name`: 知识库的名称,类型为字符串。 7 | - `top_k`: 匹配向量数,类型为整数。 8 | - `score_threshold`: 知识库匹配相关度阈值,取值范围在0-1之间,类型为浮点数。 9 | - `history`: 历史对话列表,每个元素为一个`History`对象。 10 | - `stream`: 是否以流式输出,类型为布尔值。 11 | - `model_name`: LLM模型名称,类型为字符串。 12 | - `temperature`: LLM采样温度,类型为浮点数。 13 | - `max_tokens`: 限制LLM生成Token数量,类型为整数或None。 14 | - `prompt_name`: 使用的prompt模板名称,类型为字符串。 15 | - `request`: 当前的请求对象,类型为`Request`。 16 | 17 | **代码描述**: 18 | 函数首先通过`KBServiceFactory.get_service_by_name`方法获取指定名称的知识库服务实例。如果未找到对应的知识库,将返回404状态码的响应。然后,将传入的历史对话数据转换为`History`对象列表。接下来,定义了一个异步生成器`knowledge_base_chat_iterator`,用于处理知识库查询和LLM模型生成回答的逻辑。在这个生成器中,首先根据条件调整`max_tokens`的值,然后创建LLM模型实例,并执行文档搜索。如果启用了重排序(reranker),则对搜索结果进行重排序处理。根据搜索结果构建上下文,并生成LLM模型的输入提示。最后,使用LLM模型生成回答,并根据`stream`参数决定是以流式输出还是一次性输出所有结果。 19 | 20 | **注意**: 21 | - 在调用此函数时,需要确保传入的知识库名称在系统中已经存在,否则会返回404错误。 22 | - `history`参数允许传入空列表,表示没有历史对话。 23 | - `stream`参数控制输出模式,当设置为True时,将以流式输出回答和文档信息;否则,将一次性返回所有内容。 24 | - 函数内部使用了多个异步操作,因此在调用时需要使用`await`关键字。 25 | 26 | **输出示例**: 27 | 调用`knowledge_base_chat`函数可能返回的JSON格式示例: 28 | ```json 29 | { 30 | "answer": "这是根据您的查询生成的回答。", 31 | "docs": [ 32 | "出处 [1] [文档名称](文档链接) \n\n文档内容\n\n", 33 | "未找到相关文档,该回答为大模型自身能力解答!" 34 | ] 35 | } 36 | ``` 37 | 如果启用了流式输出,每个生成的回答片段和文档信息将作为独立的JSON对象逐个发送。 38 | ### FunctionDef knowledge_base_chat_iterator(query, top_k, history, model_name, prompt_name) 39 | **knowledge_base_chat_iterator**: 此函数的功能是异步迭代生成基于知识库的聊天回答。 40 | 41 | **参数**: 42 | - `query`: 字符串类型,用户的查询内容。 43 | - `top_k`: 整型,指定返回的最相关文档数量。 44 | - `history`: 可选的历史记录列表,每个历史记录是一个`History`对象。 45 | - `model_name`: 字符串类型,默认为`model_name`,指定使用的模型名称。 46 | - `prompt_name`: 字符串类型,默认为`prompt_name`,指定使用的提示模板名称。 47 | 48 | **代码描述**: 49 | `knowledge_base_chat_iterator`函数是一个异步生成器,用于处理用户的查询请求,并基于知识库内容异步生成聊天回答。首先,函数检查`max_tokens`的有效性,并根据需要调整其值。接着,使用`get_ChatOpenAI`函数初始化一个聊天模型实例,该模型配置了模型名称、温度、最大token数和回调函数。 50 | 51 | 函数通过`run_in_threadpool`异步运行`search_docs`函数,根据用户的查询内容在知识库中搜索相关文档。如果启用了重排序功能(`USE_RERANKER`),则使用`LangchainReranker`类对搜索结果进行重排序,以提高结果的相关性。 52 | 53 | 根据搜索到的文档数量,函数选择相应的提示模板。如果没有找到相关文档,使用“empty”模板;否则,使用指定的`prompt_name`模板。然后,将历史记录和用户的查询请求转换为聊天提示模板。 54 | 55 | 使用`LLMChain`类创建一个聊天链,并通过`wrap_done`函数包装异步任务,以便在任务完成时进行回调处理。函数还生成了源文档的信息,包括文档的出处和内容。 56 | 57 | 最后,根据是否启用流式传输(`stream`),函数以不同方式异步生成聊天回答。如果启用流式传输,使用服务器发送事件(server-sent-events)逐个发送回答的token;否则,将所有token拼接后一次性返回。 58 | 59 | **注意**: 60 | - 在使用此函数时,需要确保提供的`model_name`和`prompt_name`在系统中已配置且有效。 61 | - 当启用重排序功能时,需要确保`LangchainReranker`类的配置正确,包括模型路径和设备类型。 62 | - 函数的异步特性要求调用者使用`async`和`await`关键字进行调用,以确保异步操作的正确执行。 63 | - 在处理大量查询请求时,合理配置`top_k`和重排序参数可以有效提高处理效率和回答质量。 64 | *** 65 | -------------------------------------------------------------------------------- /server/agent/custom_template.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from langchain.agents import Tool, AgentOutputParser 3 | from langchain.prompts import StringPromptTemplate 4 | from typing import List 5 | from langchain.schema import AgentAction, AgentFinish 6 | 7 | from configs import SUPPORT_AGENT_MODEL 8 | from server.agent import model_container 9 | class CustomPromptTemplate(StringPromptTemplate): 10 | template: str 11 | tools: List[Tool] 12 | 13 | def format(self, **kwargs) -> str: 14 | intermediate_steps = kwargs.pop("intermediate_steps") 15 | thoughts = "" 16 | for action, observation in intermediate_steps: 17 | thoughts += action.log 18 | thoughts += f"\nObservation: {observation}\nThought: " 19 | kwargs["agent_scratchpad"] = thoughts 20 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools]) 21 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) 22 | return self.template.format(**kwargs) 23 | 24 | class CustomOutputParser(AgentOutputParser): 25 | begin: bool = False 26 | def __init__(self): 27 | super().__init__() 28 | self.begin = True 29 | 30 | def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction: 31 | if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin: 32 | self.begin = False 33 | stop_words = ["Observation:"] 34 | min_index = len(llm_output) 35 | for stop_word in stop_words: 36 | index = llm_output.find(stop_word) 37 | if index != -1 and index < min_index: 38 | min_index = index 39 | llm_output = llm_output[:min_index] 40 | 41 | if "Final Answer:" in llm_output: 42 | self.begin = True 43 | return AgentFinish( 44 | return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()}, 45 | log=llm_output, 46 | ) 47 | parts = llm_output.split("Action:") 48 | if len(parts) < 2: 49 | return AgentFinish( 50 | return_values={"output": f"调用agent工具失败,该回答为大模型自身能力的回答:\n\n `{llm_output}`"}, 51 | log=llm_output, 52 | ) 53 | 54 | action = parts[1].split("Action Input:")[0].strip() 55 | action_input = parts[1].split("Action Input:")[1].strip() 56 | try: 57 | ans = AgentAction( 58 | tool=action, 59 | tool_input=action_input.strip(" ").strip('"'), 60 | log=llm_output 61 | ) 62 | return ans 63 | except: 64 | return AgentFinish( 65 | return_values={"output": f"调用agent失败: `{llm_output}`"}, 66 | log=llm_output, 67 | ) 68 | -------------------------------------------------------------------------------- /markdown_docs/server/chat/agent_chat.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef agent_chat(query, history, stream, model_name, temperature, max_tokens, prompt_name) 2 | **agent_chat**: 该函数用于处理与代理的异步聊天对话。 3 | 4 | **参数**: 5 | - `query`: 用户输入的查询字符串,必填参数。 6 | - `history`: 历史对话列表,每个元素为一个`History`对象。 7 | - `stream`: 是否以流式输出的方式返回数据,默认为`False`。 8 | - `model_name`: 使用的LLM模型名称,默认为`LLM_MODELS`列表中的第一个模型。 9 | - `temperature`: LLM采样温度,用于调整生成文本的随机性,默认值由`TEMPERATURE`常量决定,取值范围为0.0到1.0。 10 | - `max_tokens`: 限制LLM生成Token的数量,默认为`None`,代表使用模型的最大值。 11 | - `prompt_name`: 使用的prompt模板名称,默认为"default"。 12 | 13 | **代码描述**: 14 | `agent_chat`函数是一个异步函数,主要负责处理用户与代理的聊天对话。它首先将传入的历史对话列表`history`转换为`History`对象列表。然后,定义了一个异步迭代器`agent_chat_iterator`,用于生成聊天回复。在`agent_chat_iterator`中,根据传入的参数和配置,初始化相应的LLM模型和工具,处理历史对话记录,并根据用户的查询生成回复。 15 | 16 | 如果设置了`stream`参数为`True`,则函数以流式输出的方式返回数据,适用于需要实时更新聊天内容的场景。在流式输出模式下,函数会根据不同的状态(如工具调用开始、完成、错误等)生成不同的JSON格式数据块,并通过`yield`语句异步返回给调用者。 17 | 18 | 在非流式输出模式下,函数会收集所有生成的聊天回复,并在最终将它们整合为一个JSON格式的响应体返回。 19 | 20 | **注意**: 21 | - 在使用`agent_chat`函数时,需要确保传入的`history`参数格式正确,即每个元素都应为`History`对象或能够转换为`History`对象的数据结构。 22 | - `stream`参数的设置会影响函数的返回方式,根据实际应用场景选择合适的模式。 23 | - 函数依赖于配置好的LLM模型和prompt模板,确保在调用前已正确配置这些依赖项。 24 | 25 | **输出示例**: 26 | 在非流式输出模式下,假设用户的查询得到了一系列的聊天回复,函数可能返回如下格式的JSON数据: 27 | ```json 28 | { 29 | "answer": "这是聊天过程中生成的回复文本。", 30 | "final_answer": "这是最终的回复文本。" 31 | } 32 | ``` 33 | 在流式输出模式下,函数会逐块返回数据,每块数据可能如下所示: 34 | ```json 35 | { 36 | "tools": [ 37 | "工具名称: 天气查询", 38 | "工具状态: 调用成功", 39 | "工具输入: 北京今天天气", 40 | "工具输出: 北京今天多云,10-14摄氏度" 41 | ] 42 | } 43 | ``` 44 | 或者在得到最终回复时: 45 | ```json 46 | { 47 | "final_answer": "这是最终的回复文本。" 48 | } 49 | ``` 50 | ### FunctionDef agent_chat_iterator(query, history, model_name, prompt_name) 51 | **agent_chat_iterator**: 此函数的功能是异步迭代生成代理聊天的响应。 52 | 53 | **参数**: 54 | - `query`: 字符串类型,用户的查询或输入。 55 | - `history`: 可选的`List[History]`类型,表示对话历史记录。 56 | - `model_name`: 字符串类型,默认为`LLM_MODELS`列表中的第一个模型,用于指定使用的语言模型。 57 | - `prompt_name`: 字符串类型,用于指定提示模板的名称。 58 | 59 | **代码描述**: 60 | `agent_chat_iterator`函数是一个异步生成器,用于处理代理聊天的逻辑。首先,函数检查`max_tokens`是否为整数且小于等于0,如果是,则将其设置为`None`。接着,使用`get_ChatOpenAI`函数初始化一个聊天模型实例,并通过`get_kb_details`函数获取知识库列表,将其存储在模型容器的数据库中。如果存在`Agent_MODEL`,则使用该模型初始化另一个聊天模型实例,并将其存储在模型容器中;否则,使用之前初始化的模型。 61 | 62 | 函数通过`get_prompt_template`函数获取指定的提示模板,并使用`CustomPromptTemplate`类创建一个自定义提示模板实例。然后,使用`CustomOutputParser`类创建一个输出解析器实例,并根据模型名称决定使用`initialize_glm3_agent`函数初始化GLM3代理执行器,或者使用`LLMSingleActionAgent`和`AgentExecutor`创建一个代理执行器。 63 | 64 | 在异步循环中,函数尝试创建一个任务,使用`wrap_done`函数包装代理执行器的调用,并在完成时通过回调通知。如果设置了`stream`参数,则函数会异步迭代回调处理器的输出,并根据状态生成不同的响应数据,最终以JSON格式产生输出。如果未设置`stream`参数,则会收集所有输出数据,并在最后生成一个包含答案和最终答案的JSON对象。 65 | 66 | **注意**: 67 | - 在使用此函数时,需要确保提供的`history`参数格式正确,且每个历史记录项都应为`History`类的实例。 68 | - `model_name`和`prompt_name`参数应根据实际需要选择合适的模型和提示模板。 69 | - 函数内部使用了多个异步操作和自定义类,如`CustomAsyncIteratorCallbackHandler`、`CustomPromptTemplate`和`CustomOutputParser`,需要确保这些组件的正确实现和配置。 70 | - 此函数设计为与前端实现实时或异步的聊天交互,因此在集成到聊天系统时,应考虑其异步特性和对外部回调的处理方式。 71 | *** 72 | -------------------------------------------------------------------------------- /server/memory/conversation_db_buffer_memory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, List, Dict 3 | 4 | from langchain.memory.chat_memory import BaseChatMemory 5 | from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage 6 | from langchain.schema.language_model import BaseLanguageModel 7 | from server.db.repository.message_repository import filter_message 8 | from server.db.models.message_model import MessageModel 9 | 10 | 11 | class ConversationBufferDBMemory(BaseChatMemory): 12 | conversation_id: str 13 | human_prefix: str = "Human" 14 | ai_prefix: str = "Assistant" 15 | llm: BaseLanguageModel 16 | memory_key: str = "history" 17 | max_token_limit: int = 2000 18 | message_limit: int = 10 19 | 20 | @property 21 | def buffer(self) -> List[BaseMessage]: 22 | """String buffer of memory.""" 23 | # fetch limited messages desc, and return reversed 24 | 25 | messages = filter_message(conversation_id=self.conversation_id, limit=self.message_limit) 26 | # 返回的记录按时间倒序,转为正序 27 | messages = list(reversed(messages)) 28 | chat_messages: List[BaseMessage] = [] 29 | for message in messages: 30 | chat_messages.append(HumanMessage(content=message["query"])) 31 | chat_messages.append(AIMessage(content=message["response"])) 32 | 33 | if not chat_messages: 34 | return [] 35 | 36 | # prune the chat message if it exceeds the max token limit 37 | curr_buffer_length = self.llm.get_num_tokens(get_buffer_string(chat_messages)) 38 | if curr_buffer_length > self.max_token_limit: 39 | pruned_memory = [] 40 | while curr_buffer_length > self.max_token_limit and chat_messages: 41 | pruned_memory.append(chat_messages.pop(0)) 42 | curr_buffer_length = self.llm.get_num_tokens(get_buffer_string(chat_messages)) 43 | 44 | return chat_messages 45 | 46 | @property 47 | def memory_variables(self) -> List[str]: 48 | """Will always return list of memory variables. 49 | 50 | :meta private: 51 | """ 52 | return [self.memory_key] 53 | 54 | def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: 55 | """Return history buffer.""" 56 | buffer: Any = self.buffer 57 | if self.return_messages: 58 | final_buffer: Any = buffer 59 | else: 60 | final_buffer = get_buffer_string( 61 | buffer, 62 | human_prefix=self.human_prefix, 63 | ai_prefix=self.ai_prefix, 64 | ) 65 | return {self.memory_key: final_buffer} 66 | 67 | def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: 68 | """Nothing should be saved or changed""" 69 | pass 70 | 71 | def clear(self) -> None: 72 | """Nothing to clear, got a memory like a vault.""" 73 | pass -------------------------------------------------------------------------------- /markdown_docs/text_splitter/ali_text_splitter.md: -------------------------------------------------------------------------------- 1 | ## ClassDef AliTextSplitter 2 | **AliTextSplitter**: AliTextSplitter类的功能是对文本进行分割,特别是针对PDF文档或其他文本,可以选择是否使用文档语义分割模型进行更加精确的文本分割。 3 | 4 | **属性**: 5 | - `pdf`: 布尔值,指示是否对PDF文档进行特殊处理,默认为False。 6 | - `**kwargs`: 接收可变数量的关键字参数,这些参数将传递给父类CharacterTextSplitter的构造函数。 7 | 8 | **代码描述**: 9 | AliTextSplitter类继承自CharacterTextSplitter类,提供了对文本进行分割的功能。在初始化时,可以通过`pdf`参数指定是否对PDF文档进行特殊处理。如果`pdf`为True,会对文本进行预处理,包括合并多余的换行符和空格,以及移除连续的换行符,以便于后续的文本分割处理。 10 | 11 | 在`split_text`方法中,首先根据`pdf`参数的值对文本进行预处理。然后尝试导入`modelscope.pipelines`模块,如果导入失败,会抛出`ImportError`异常,提示用户需要安装`modelscope`包。 12 | 13 | 使用`modelscope.pipelines`的`pipeline`函数创建一个文档分割任务,模型选择为`damo/nlp_bert_document-segmentation_chinese-base`,并指定设备为CPU。通过调用`pipeline`对象的方法对文本进行分割,得到的结果是一个包含分割后文本的列表。 14 | 15 | **注意**: 16 | - 使用此类之前,需要确保已安装`modelscope`包,特别是如果要进行文档语义分割,需要安装`modelscope[nlp]`。 17 | - 文档语义分割模型`damo/nlp_bert_document-segmentation_chinese-base`是基于BERT的中文文档分割模型,对于中文文本有较好的分割效果。 18 | - 在低配置的GPU环境下,由于模型较大,建议将设备设置为CPU进行文本分割处理,以避免可能的性能问题。 19 | 20 | **输出示例**: 21 | ```python 22 | ['这是第一段文本。', '这是第二段文本,包含多个句子。', '这是第三段文本。'] 23 | ``` 24 | 此输出示例展示了`split_text`方法返回的分割后的文本列表,每个元素代表文档中的一段文本。 25 | ### FunctionDef __init__(self, pdf) 26 | **__init__**: 此函数的功能是初始化AliTextSplitter类的实例。 27 | 28 | **参数**: 29 | - `pdf`: 一个布尔值,用于指定是否处理PDF文件,默认值为False。 30 | - `**kwargs`: 接收一个可变数量的关键字参数,这些参数将传递给父类的初始化方法。 31 | 32 | **代码描述**: 33 | 此初始化函数是`AliTextSplitter`类的构造函数,用于创建类的实例时设置初始状态。它接受一个名为`pdf`的参数和多个关键字参数`**kwargs`。`pdf`参数用于指示`AliTextSplitter`实例是否将用于处理PDF文件,其默认值为False,表示默认不处理PDF文件。如果需要处理PDF文件,则在创建`AliTextSplitter`实例时将此参数设置为True。 34 | 35 | 此外,通过`**kwargs`参数,此函数支持接收额外的关键字参数,这些参数不在函数定义中直接声明。这些额外的参数通过`super().__init__(**kwargs)`语句传递给父类的初始化方法。这种设计允许`AliTextSplitter`类在不修改其构造函数签名的情况下,灵活地扩展或修改其父类的行为。 36 | 37 | **注意**: 38 | - 在使用`AliTextSplitter`类时,应根据实际需求决定是否将`pdf`参数设置为True。如果您的应用场景中需要处理PDF文件,则应将此参数设置为True。 39 | - 通过`**kwargs`传递给父类的参数应确保与父类的初始化方法兼容,避免传递无效或不相关的参数,以免引发错误。 40 | *** 41 | ### FunctionDef split_text(self, text) 42 | **split_text**: 该函数的功能是对文本进行语义分割。 43 | 44 | **参数**: 45 | - text: 需要进行分割的文本,数据类型为字符串(str)。 46 | 47 | **代码描述**: 48 | `split_text`函数主要用于对给定的文本进行语义上的分割。它首先检查是否存在`self.pdf`属性,如果存在,会对文本进行预处理,包括合并过多的换行符、将所有空白字符替换为单个空格以及删除连续的换行符。这一步骤旨在清理PDF文档中常见的格式问题,以便于后续的文档分割。 49 | 50 | 接下来,函数尝试导入`modelscope.pipelines`模块,该模块提供了一个`pipeline`函数,用于加载并执行特定的NLP任务。如果导入失败,会抛出`ImportError`异常,提示用户需要安装`modelscope`包。 51 | 52 | 在成功导入`modelscope.pipelines`后,函数使用`pipeline`函数创建一个文档分割任务,指定使用的模型为`damo/nlp_bert_document-segmentation_chinese-base`,并将计算设备设置为CPU。这个模型基于BERT,由阿里巴巴达摩院开源,专门用于中文文档的语义分割。 53 | 54 | 最后,函数将输入文本传递给模型进行分割,并将分割结果(一个包含分割后文本的列表)返回。分割结果是通过将模型输出的文本按`\n\t`分割,并过滤掉空字符串后得到的。 55 | 56 | **注意**: 57 | - 使用该函数前,需要确保已经安装了`modelscope[nlp]`包。可以通过执行`pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html`来安装。 58 | - 由于使用了基于BERT的模型进行文档分割,对计算资源有一定要求。默认情况下,模型会在CPU上运行,但如果有足够的GPU资源,可以通过修改`device`参数来加速计算。 59 | 60 | **输出示例**: 61 | ```python 62 | ['欢迎使用文档分割功能', '这是第二段文本', '这是第三段文本'] 63 | ``` 64 | 此输出示例展示了`split_text`函数处理后的结果,其中输入文本被分割成了三段,每段文本作为列表的一个元素返回。 65 | *** 66 | -------------------------------------------------------------------------------- /markdown_docs/embeddings/add_embedding_keywords.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef get_keyword_embedding(bert_model, tokenizer, key_words) 2 | **get_keyword_embedding**: 该函数的功能是获取关键词的嵌入表示。 3 | 4 | **参数**: 5 | - **bert_model**: 一个预训练的BERT模型,用于生成嵌入表示。 6 | - **tokenizer**: 与BERT模型相匹配的分词器,用于将关键词转换为模型能理解的格式。 7 | - **key_words**: 一个字符串列表,包含需要获取嵌入表示的关键词。 8 | 9 | **代码描述**: 10 | `get_keyword_embedding`函数首先使用传入的`tokenizer`将关键词列表`key_words`转换为模型能够处理的输入格式。这一步骤包括将关键词转换为对应的输入ID,并对输入进行填充和截断以满足模型的要求。随后,函数从`tokenizer`的输出中提取`input_ids`,并去除每个序列的首尾特殊标记,因为这些标记对于关键词的嵌入表示不是必需的。 11 | 12 | 接着,函数利用`bert_model`的`embeddings.word_embeddings`属性,根据`input_ids`获取对应的嵌入表示。由于可能传入多个关键词,函数对所有关键词的嵌入表示进行平均,以获得一个统一的表示形式。 13 | 14 | 在项目中,`get_keyword_embedding`函数被`add_keyword_to_model`函数调用,用于将自定义关键词的嵌入表示添加到预训练的BERT模型中。这一过程涉及读取关键词文件,生成关键词的嵌入表示,扩展模型的嵌入层以包含这些新的关键词,最后将修改后的模型保存到指定路径。这使得模型能够理解并有效处理这些新增的关键词,从而提高模型在特定任务上的性能。 15 | 16 | **注意**: 17 | - 确保传入的`bert_model`和`tokenizer`是匹配的,即它们来源于同一个预训练模型。 18 | - 关键词列表`key_words`应该是经过精心挑选的,因为这些关键词将直接影响模型的理解能力和性能。 19 | - 在调用此函数之前,应该已经准备好关键词文件,并确保其格式正确。 20 | 21 | **输出示例**: 22 | 假设传入了两个关键词`["AI", "机器学习"]`,函数可能返回一个形状为`(2, embedding_size)`的张量,其中`embedding_size`是模型嵌入层的维度,表示这两个关键词的平均嵌入表示。 23 | ## FunctionDef add_keyword_to_model(model_name, keyword_file, output_model_path) 24 | **add_keyword_to_model**: 该函数的功能是将自定义关键词添加到预训练的嵌入模型中。 25 | 26 | **参数**: 27 | - **model_name**: 字符串类型,默认为`EMBEDDING_MODEL`。指定要使用的预训练嵌入模型的名称。 28 | - **keyword_file**: 字符串类型,默认为空字符串。指定包含自定义关键词的文件路径。 29 | - **output_model_path**: 字符串类型,可为`None`。指定添加了关键词后的模型保存路径。 30 | 31 | **代码描述**: 32 | 首先,函数通过读取`keyword_file`文件,将文件中的每一行作为一个关键词添加到`key_words`列表中。接着,使用指定的`model_name`加载一个句子转换器模型(SentenceTransformer模型),并从中提取第一个模块作为词嵌入模型。通过这个词嵌入模型,可以获取到BERT模型及其分词器。 33 | 34 | 然后,函数调用`get_keyword_embedding`函数,传入BERT模型、分词器和关键词列表,以获取这些关键词的嵌入表示。接下来,函数将这些新的关键词嵌入添加到BERT模型的嵌入层中。这一步骤包括扩展分词器以包含新的关键词,调整BERT模型的嵌入层大小以适应新增的关键词,并将关键词的嵌入表示直接赋值给模型的嵌入层权重。 35 | 36 | 最后,如果提供了`output_model_path`参数,则函数会在该路径下创建必要的目录,并将更新后的词嵌入模型以及BERT模型保存到指定位置。这一过程确保了模型能够在后续的使用中,理解并有效处理这些新增的关键词。 37 | 38 | **注意**: 39 | - 确保`keyword_file`文件存在且格式正确,每行应包含一个关键词。 40 | - 由于模型的嵌入层大小会根据新增关键词进行调整,因此在添加关键词后,模型的大小可能会增加。 41 | - 在保存模型时,会使用`safetensors`格式保存BERT模型,确保模型的兼容性和安全性。 42 | - 添加关键词到模型是一个影响模型性能的操作,因此应谨慎选择关键词,并考虑到这些关键词在特定任务上的实际应用价值。 43 | ## FunctionDef add_keyword_to_embedding_model(path) 44 | **add_keyword_to_embedding_model**: 该函数的功能是将自定义关键词添加到指定的嵌入模型中。 45 | 46 | **参数**: 47 | - **path**: 字符串类型,默认为`EMBEDDING_KEYWORD_FILE`。指定包含自定义关键词的文件路径。 48 | 49 | **代码描述**: 50 | 此函数首先通过`os.path.join(path)`获取关键词文件的完整路径。然后,它从配置中读取模型的名称和路径,这些配置通过`MODEL_PATH["embed_model"][EMBEDDING_MODEL]`获得。接着,函数计算模型所在的父目录,并生成一个包含当前时间戳的新模型名称,格式为`EMBEDDING_MODEL_Merge_Keywords_当前时间`,以确保输出模型名称的唯一性。 51 | 52 | 接下来,函数调用`add_keyword_to_model`,这是一个重要的调用关系,因为`add_keyword_to_model`负责实际将关键词添加到嵌入模型中。在调用`add_keyword_to_model`时,传入当前使用的模型名称、关键词文件路径以及新模型的保存路径。这一步骤完成了将自定义关键词集成到预训练嵌入模型中的核心功能。 53 | 54 | **注意**: 55 | - 确保传入的`path`参数指向一个有效的关键词文件,且该文件格式正确,每行包含一个要添加的关键词。 56 | - 该函数依赖于`add_keyword_to_model`函数,后者负责实际的关键词添加逻辑,包括读取关键词、更新模型的嵌入层以及保存更新后的模型。因此,了解`add_keyword_to_model`的具体实现对于理解整个关键词添加过程是非常重要的。 57 | - 生成的新模型名称包含时间戳,这有助于区分不同时间点生成的模型版本。 58 | - 在使用此函数时,应考虑到模型大小可能会因为添加新的关键词而增加,这可能会对模型加载和运行时的性能产生影响。 59 | -------------------------------------------------------------------------------- /text_splitter/chinese_text_splitter.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import CharacterTextSplitter 2 | import re 3 | from typing import List 4 | 5 | 6 | class ChineseTextSplitter(CharacterTextSplitter): 7 | def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): 8 | super().__init__(**kwargs) 9 | self.pdf = pdf 10 | self.sentence_size = sentence_size 11 | 12 | def split_text1(self, text: str) -> List[str]: 13 | if self.pdf: 14 | text = re.sub(r"\n{3,}", "\n", text) 15 | text = re.sub('\s', ' ', text) 16 | text = text.replace("\n\n", "") 17 | sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :; 18 | sent_list = [] 19 | for ele in sent_sep_pattern.split(text): 20 | if sent_sep_pattern.match(ele) and sent_list: 21 | sent_list[-1] += ele 22 | elif ele: 23 | sent_list.append(ele) 24 | return sent_list 25 | 26 | def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 27 | if self.pdf: 28 | text = re.sub(r"\n{3,}", r"\n", text) 29 | text = re.sub('\s', " ", text) 30 | text = re.sub("\n\n", "", text) 31 | 32 | text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # 单字符断句符 33 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 34 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 35 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) 36 | # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 37 | text = text.rstrip() # 段尾如果有多余的\n就去掉它 38 | # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 39 | ls = [i for i in text.split("\n") if i] 40 | for ele in ls: 41 | if len(ele) > self.sentence_size: 42 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) 43 | ele1_ls = ele1.split("\n") 44 | for ele_ele1 in ele1_ls: 45 | if len(ele_ele1) > self.sentence_size: 46 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) 47 | ele2_ls = ele_ele2.split("\n") 48 | for ele_ele2 in ele2_ls: 49 | if len(ele_ele2) > self.sentence_size: 50 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) 51 | ele2_id = ele2_ls.index(ele_ele2) 52 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ 53 | ele2_id + 1:] 54 | ele_id = ele1_ls.index(ele_ele1) 55 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] 56 | 57 | id = ls.index(ele) 58 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] 59 | return ls 60 | -------------------------------------------------------------------------------- /markdown_docs/document_loaders/myimgloader.md: -------------------------------------------------------------------------------- 1 | ## ClassDef RapidOCRLoader 2 | **RapidOCRLoader**: RapidOCRLoader的功能是将图像文件中的文本通过OCR技术提取出来,并将提取的文本进行结构化处理。 3 | 4 | **属性**: 5 | - 无特定公开属性,继承自UnstructuredFileLoader的属性。 6 | 7 | **代码描述**: 8 | RapidOCRLoader是一个继承自UnstructuredFileLoader的类,专门用于处理图像文件中的文本提取。它通过定义一个内部函数`img2text`来实现OCR(光学字符识别)功能。`img2text`函数接受一个文件路径作为输入,使用`get_ocr`函数获取OCR处理器,然后对指定的图像文件进行文本识别。识别结果是一个列表,其中每个元素包含识别的文本行。这些文本行随后被连接成一个字符串,作为函数的返回值。 9 | 10 | 在`_get_elements`方法中,调用了`img2text`函数处理类初始化时指定的文件路径,将图像文件中的文本提取出来。提取出的文本随后通过`partition_text`函数进行结构化处理,这个函数根据提供的参数(通过`self.unstructured_kwargs`传递)对文本进行分区,最终返回一个文本分区列表。 11 | 12 | 在项目中,RapidOCRLoader类被用于测试模块`test_imgloader.py`中,通过`test_rapidocrloader`函数进行测试。测试函数创建了一个RapidOCRLoader实例,传入了一个OCR测试用的图像文件路径,然后调用`load`方法加载处理结果。测试验证了RapidOCRLoader能够成功提取图像中的文本,并且返回值是一个包含至少一个元素的列表,列表中的每个元素都是一个包含提取文本的对象。 13 | 14 | **注意**: 15 | - 使用RapidOCRLoader之前,需要确保OCR处理器(通过`get_ocr`函数获取)已正确配置并可用。 16 | - 该类主要用于处理图像文件中的文本提取,不适用于非图像文件。 17 | 18 | **输出示例**: 19 | ```python 20 | [ 21 | { 22 | "page_content": "这是通过OCR技术提取的文本内容。" 23 | } 24 | ] 25 | ``` 26 | 此输出示例展示了RapidOCRLoader处理图像文件并通过OCR技术提取文本后的可能返回值。返回值是一个列表,列表中的每个元素都是一个字典,其中`page_content`键对应的值是提取的文本内容。 27 | ### FunctionDef _get_elements(self) 28 | **_get_elements**: 该函数的功能是将图片文件中的文本内容提取出来,并根据给定的参数对提取出的文本进行分段处理。 29 | 30 | **参数**: 31 | - 无直接参数,但函数通过`self.file_path`访问图片路径,通过`self.unstructured_kwargs`访问用于文本分段的参数。 32 | 33 | **代码描述**: 34 | 该函数首先定义了一个内部函数`img2text`,用于将指定路径的图片文件转换为文本。`img2text`函数通过调用`get_ocr()`函数获取OCR(光学字符识别)服务的实例,然后使用此实例对图片文件进行识别,将识别结果中的文本内容提取出来并返回。 35 | 36 | 在`_get_elements`函数的主体中,首先调用`img2text`函数,将`self.file_path`指定的图片文件转换为文本。然后,使用`partition_text`函数对提取的文本进行分段处理。`partition_text`函数接受一个文本字符串和一组分段参数(通过`self.unstructured_kwargs`提供),并返回分段后的文本列表。 37 | 38 | **注意**: 39 | - 确保`self.file_path`正确指向了需要处理的图片文件。 40 | - `self.unstructured_kwargs`应包含适用于`partition_text`函数的所有必要参数,以确保文本可以按预期进行分段处理。 41 | - OCR识别的准确性可能受到图片质量和内容复杂度的影响,因此在处理极其复杂或低质量的图片时可能会遇到识别准确度下降的问题。 42 | 43 | **输出示例**: 44 | 假设图片中包含以下文本内容:“Hello World! Welcome to OCR processing.”,并且`partition_text`函数的参数设置为按句子分段,那么该函数可能返回如下列表: 45 | ```python 46 | ["Hello World!", "Welcome to OCR processing."] 47 | ``` 48 | #### FunctionDef img2text(filepath) 49 | **img2text**: 此函数的功能是将图片文件中的文字通过OCR技术识别出来,并以字符串形式返回。 50 | 51 | **参数**: 52 | - filepath: 字符串类型,指定需要进行文字识别的图片文件路径。 53 | 54 | **代码描述**: 55 | `img2text`函数是一个用于图像文字识别的高级封装。它首先调用`get_ocr`函数获取一个OCR对象,该对象是根据系统配置(是否使用CUDA加速)动态选择的OCR实现。随后,使用该OCR对象对传入的图片文件路径`filepath`指向的图片进行文字识别。识别结果是一个列表,其中每个元素是一个包含识别区域坐标和识别出的文字的元组。函数进一步处理这个列表,提取出所有识别到的文字,并将它们连接成一个单一的字符串,每行文字之间用换行符`\n`分隔。最后,返回这个字符串。 56 | 57 | 从功能角度看,`img2text`与其调用的`get_ocr`函数紧密相关。`get_ocr`负责提供OCR服务的对象,而`img2text`则利用这个对象完成具体的图像文字识别任务。这种设计使得`img2text`能够灵活适应不同的OCR技术实现,同时也便于在项目中重用OCR服务。 58 | 59 | **注意**: 60 | - 确保传入的`filepath`是有效的图片文件路径,且文件存在。否则,OCR识别过程可能失败。 61 | - OCR识别的准确性受到多种因素的影响,包括图片质量、文字清晰度和字体大小等,因此在使用时应考虑这些因素可能对识别结果的影响。 62 | - 根据`get_ocr`函数的说明,如果系统中未安装支持CUDA的OCR包或在不支持CUDA的环境中运行,应确保`get_ocr`函数的`use_cuda`参数被设置为False,以避免运行时错误。 63 | 64 | **输出示例**: 65 | ``` 66 | 这是一个OCR识别的示例文本。 67 | 第二行文字。 68 | ``` 69 | 此输出示例展示了`img2text`函数处理后的可能输出,其中包含了从图片中识别出的文字,每行文字之间用换行符分隔。实际输出将根据输入图片中的文字内容而有所不同。 70 | *** 71 | *** 72 | -------------------------------------------------------------------------------- /server/model_workers/tiangong.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import hashlib 4 | 5 | from fastchat.conversation import Conversation 6 | from server.model_workers.base import * 7 | from server.utils import get_httpx_client 8 | from fastchat import conversation as conv 9 | import json 10 | from typing import List, Literal, Dict 11 | import requests 12 | 13 | 14 | class TianGongWorker(ApiModelWorker): 15 | def __init__( 16 | self, 17 | *, 18 | controller_addr: str = None, 19 | worker_addr: str = None, 20 | model_names: List[str] = ["tiangong-api"], 21 | version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse", 22 | **kwargs, 23 | ): 24 | kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) 25 | kwargs.setdefault("context_len", 32768) 26 | super().__init__(**kwargs) 27 | self.version = version 28 | 29 | def do_chat(self, params: ApiChatParams) -> Dict: 30 | params.load_config(self.model_names[0]) 31 | 32 | url = 'https://sky-api.singularity-ai.com/saas/api/v4/generate' 33 | data = { 34 | "messages": params.messages, 35 | "model": "SkyChat-MegaVerse" 36 | } 37 | timestamp = str(int(time.time())) 38 | sign_content = params.api_key + params.secret_key + timestamp 39 | sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest() 40 | headers = { 41 | "app_key": params.api_key, 42 | "timestamp": timestamp, 43 | "sign": sign_result, 44 | "Content-Type": "application/json", 45 | "stream": "true" # or change to "false" 不处理流式返回内容 46 | } 47 | 48 | # 发起请求并获取响应 49 | response = requests.post(url, headers=headers, json=data, stream=True) 50 | 51 | text = "" 52 | # 处理响应流 53 | for line in response.iter_lines(chunk_size=None, decode_unicode=True): 54 | if line: 55 | # 处理接收到的数据 56 | # print(line.decode('utf-8')) 57 | resp = json.loads(line) 58 | if resp["code"] == 200: 59 | text += resp['resp_data']['reply'] 60 | yield { 61 | "error_code": 0, 62 | "text": text 63 | } 64 | else: 65 | data = { 66 | "error_code": resp["code"], 67 | "text": resp["code_msg"] 68 | } 69 | self.logger.error(f"请求天工 API 时出错:{data}") 70 | yield data 71 | 72 | def get_embeddings(self, params): 73 | print("embedding") 74 | print(params) 75 | 76 | def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: 77 | return conv.Conversation( 78 | name=self.model_names[0], 79 | system_message="", 80 | messages=[], 81 | roles=["user", "system"], 82 | sep="\n### ", 83 | stop_str="###", 84 | ) 85 | -------------------------------------------------------------------------------- /server/chat/completion.py: -------------------------------------------------------------------------------- 1 | from fastapi import Body 2 | from sse_starlette.sse import EventSourceResponse 3 | from configs import LLM_MODELS, TEMPERATURE 4 | from server.utils import wrap_done, get_OpenAI 5 | from langchain.chains import LLMChain 6 | from langchain.callbacks import AsyncIteratorCallbackHandler 7 | from typing import AsyncIterable, Optional 8 | import asyncio 9 | from langchain.prompts import PromptTemplate 10 | 11 | from server.utils import get_prompt_template 12 | 13 | 14 | async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), 15 | stream: bool = Body(False, description="流式输出"), 16 | echo: bool = Body(False, description="除了输出之外,还回显输入"), 17 | model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), 18 | temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), 19 | max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"), 20 | # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), 21 | prompt_name: str = Body("default", 22 | description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), 23 | ): 24 | 25 | #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 26 | async def completion_iterator(query: str, 27 | model_name: str = LLM_MODELS[0], 28 | prompt_name: str = prompt_name, 29 | echo: bool = echo, 30 | ) -> AsyncIterable[str]: 31 | nonlocal max_tokens 32 | callback = AsyncIteratorCallbackHandler() 33 | if isinstance(max_tokens, int) and max_tokens <= 0: 34 | max_tokens = None 35 | 36 | model = get_OpenAI( 37 | model_name=model_name, 38 | temperature=temperature, 39 | max_tokens=max_tokens, 40 | callbacks=[callback], 41 | echo=echo 42 | ) 43 | 44 | prompt_template = get_prompt_template("completion", prompt_name) 45 | prompt = PromptTemplate.from_template(prompt_template) 46 | chain = LLMChain(prompt=prompt, llm=model) 47 | 48 | # Begin a task that runs in the background. 49 | task = asyncio.create_task(wrap_done( 50 | chain.acall({"input": query}), 51 | callback.done), 52 | ) 53 | 54 | if stream: 55 | async for token in callback.aiter(): 56 | # Use server-sent-events to stream the response 57 | yield token 58 | else: 59 | answer = "" 60 | async for token in callback.aiter(): 61 | answer += token 62 | yield answer 63 | 64 | await task 65 | 66 | return EventSourceResponse(completion_iterator(query=query, 67 | model_name=model_name, 68 | prompt_name=prompt_name), 69 | ) 70 | -------------------------------------------------------------------------------- /markdown_docs/server/minx_chat_openai.md: -------------------------------------------------------------------------------- 1 | ## ClassDef MinxChatOpenAI 2 | **MinxChatOpenAI**: MinxChatOpenAI类的功能是提供与tiktoken库交互的方法,用于导入tiktoken库和获取编码模型。 3 | 4 | **属性**: 5 | 此类主要通过静态方法实现功能,不直接使用属性存储数据。 6 | 7 | **代码描述**: 8 | MinxChatOpenAI类包含两个静态方法:`import_tiktoken`和`get_encoding_model`。 9 | 10 | - `import_tiktoken`方法尝试导入`tiktoken`包,如果导入失败,则抛出`ValueError`异常,提示用户需要安装`tiktoken`。这是为了确保后续操作可以使用`tiktoken`包提供的功能。 11 | 12 | - `get_encoding_model`方法负责根据模型名称获取相应的编码模型。它首先尝试从`tiktoken`库中获取指定模型的编码信息。如果模型名称是`gpt-3.5-turbo`或`gpt-4`,方法会自动调整为对应的具体版本,以适应模型可能的更新。如果指定的模型在`tiktoken`库中找不到,将使用默认的`cl100k_base`编码模型,并记录一条警告信息。 13 | 14 | 在项目中,`MinxChatOpenAI`类的`get_encoding_model`方法被`get_ChatOpenAI`函数调用,以配置和初始化`ChatOpenAI`实例。这表明`MinxChatOpenAI`类提供的功能是为`ChatOpenAI`实例获取正确的编码模型,这对于处理和理解聊天内容至关重要。 15 | 16 | **注意**: 17 | - 使用`MinxChatOpenAI`类之前,请确保已经安装了`tiktoken`包,否则将无法成功导入和使用。 18 | - 在调用`get_encoding_model`方法时,需要注意传入的模型名称是否正确,以及是否准备好处理可能的异常和警告。 19 | 20 | **输出示例**: 21 | 调用`get_encoding_model`方法可能返回的示例输出为: 22 | ```python 23 | ("gpt-3.5-turbo-0301", ) 24 | ``` 25 | 这表示方法返回了模型名称和对应的编码对象。 26 | ### FunctionDef import_tiktoken 27 | **import_tiktoken**: 该函数的功能是导入tiktoken库。 28 | 29 | **参数**: 此函数没有参数。 30 | 31 | **代码描述**: `import_tiktoken` 函数尝试导入 `tiktoken` Python包。如果导入失败,即 `tiktoken` 包未安装在环境中,函数将抛出一个 `ImportError` 异常。为了向用户提供清晰的错误信息,函数捕获了这个异常并抛出一个新的 `ValueError`,提示用户需要安装 `tiktoken` 包以计算 `get_token_ids`。这个函数是 `MinxChatOpenAI` 类的一部分,主要用于在需要使用 `tiktoken` 功能时确保该库已被导入。在项目中,`import_tiktoken` 被 `get_encoding_model` 方法调用,用于获取特定模型的编码信息。这表明 `tiktoken` 库在处理模型编码方面起着关键作用。 32 | 33 | 在 `get_encoding_model` 方法中,首先通过调用 `import_tiktoken` 函数来确保 `tiktoken` 库可用。然后,根据模型名称(`self.tiktoken_model_name` 或 `self.model_name`)获取相应的编码信息。如果指定的模型名称不被支持,将使用默认的编码模型。这个过程展示了 `import_tiktoken` 在项目中的实际应用,即作为获取模型编码前的必要步骤。 34 | 35 | **注意**: 使用此函数前,请确保已经安装了 `tiktoken` 包。如果未安装,可以通过运行 `pip install tiktoken` 来安装。此外,当 `tiktoken` 包导入失败时,函数将抛出一个 `ValueError`,提示需要安装该包。开发者应当注意捕获并妥善处理这一异常,以避免程序在未安装 `tiktoken` 包时崩溃。 36 | 37 | **输出示例**: 由于此函数的目的是导入 `tiktoken` 包,因此它不直接返回数据。成功执行后,它将返回 `tiktoken` 模块对象,允许后续代码调用 `tiktoken` 的功能。例如,成功导入后,可以使用 `tiktoken.encoding_for_model(model_name)` 来获取指定模型的编码信息。 38 | *** 39 | ### FunctionDef get_encoding_model(self) 40 | **get_encoding_model**: 该函数的功能是获取指定模型的编码信息。 41 | 42 | **参数**: 此函数没有参数。 43 | 44 | **代码描述**: `get_encoding_model` 方法首先尝试通过调用 `import_tiktoken` 函数来导入 `tiktoken` 库,确保后续操作可以使用 `tiktoken` 提供的功能。接着,根据实例变量 `self.tiktoken_model_name` 或 `self.model_name` 来确定需要获取编码信息的模型名称。如果 `self.tiktoken_model_name` 不为 `None`,则直接使用该值;否则,使用 `self.model_name`。对于特定的模型名称,如 "gpt-3.5-turbo" 或 "gpt-4",方法内部会将其转换为具体的版本名称,以适应模型可能随时间更新的情况。之后,尝试使用 `tiktoken_.encoding_for_model(model)` 获取指定模型的编码信息。如果在此过程中发生异常(例如模型名称不被支持),则会捕获异常并记录警告信息,同时使用默认的编码模型 "cl100k_base"。最后,方法返回一个包含模型名称和编码信息的元组。 45 | 46 | **注意**: 在使用 `get_encoding_model` 方法之前,确保已经安装了 `tiktoken` 包。如果在尝试导入 `tiktoken` 时遇到问题,会抛出 `ValueError` 异常,提示需要安装 `tiktoken` 包。此外,当指定的模型名称不被支持时,方法会默认使用 "cl100k_base" 编码模型,并记录一条警告信息。 47 | 48 | **输出示例**: 假设调用 `get_encoding_model` 方法并且指定的模型名称被正确识别,可能的返回值为: 49 | 50 | ```python 51 | ("gpt-3.5-turbo-0301", ) 52 | ``` 53 | 54 | 其中,返回的第一个元素是模型名称,第二个元素是该模型对应的编码信息对象。如果模型名称不被支持,返回值可能为: 55 | 56 | ```python 57 | ("cl100k_base", ) 58 | ``` 59 | 60 | 这表明方法使用了默认的编码模型 "cl100k_base"。 61 | *** 62 | -------------------------------------------------------------------------------- /embeddings/add_embedding_keywords.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 该功能是为了将关键词加入到embedding模型中,以便于在embedding模型中进行关键词的embedding 3 | 该功能的实现是通过修改embedding模型的tokenizer来实现的 4 | 该功能仅仅对EMBEDDING_MODEL参数对应的的模型有效,输出后的模型保存在原本模型 5 | 感谢@CharlesJu1和@charlesyju的贡献提出了想法和最基础的PR 6 | 7 | 保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳 8 | ''' 9 | import sys 10 | 11 | sys.path.append("..") 12 | import os 13 | import torch 14 | 15 | from datetime import datetime 16 | from configs import ( 17 | MODEL_PATH, 18 | EMBEDDING_MODEL, 19 | EMBEDDING_KEYWORD_FILE, 20 | ) 21 | 22 | from safetensors.torch import save_model 23 | from sentence_transformers import SentenceTransformer 24 | from langchain_core._api import deprecated 25 | 26 | 27 | @deprecated( 28 | since="0.3.0", 29 | message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃", 30 | removal="0.3.0" 31 | ) 32 | def get_keyword_embedding(bert_model, tokenizer, key_words): 33 | tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True) 34 | input_ids = tokenizer_output['input_ids'] 35 | input_ids = input_ids[:, 1:-1] 36 | 37 | keyword_embedding = bert_model.embeddings.word_embeddings(input_ids) 38 | keyword_embedding = torch.mean(keyword_embedding, 1) 39 | return keyword_embedding 40 | 41 | 42 | def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None): 43 | key_words = [] 44 | with open(keyword_file, "r") as f: 45 | for line in f: 46 | key_words.append(line.strip()) 47 | 48 | st_model = SentenceTransformer(model_name) 49 | key_words_len = len(key_words) 50 | word_embedding_model = st_model._first_module() 51 | bert_model = word_embedding_model.auto_model 52 | tokenizer = word_embedding_model.tokenizer 53 | key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words) 54 | 55 | embedding_weight = bert_model.embeddings.word_embeddings.weight 56 | embedding_weight_len = len(embedding_weight) 57 | tokenizer.add_tokens(key_words) 58 | bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) 59 | embedding_weight = bert_model.embeddings.word_embeddings.weight 60 | with torch.no_grad(): 61 | embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding 62 | 63 | if output_model_path: 64 | os.makedirs(output_model_path, exist_ok=True) 65 | word_embedding_model.save(output_model_path) 66 | safetensors_file = os.path.join(output_model_path, "model.safetensors") 67 | metadata = {'format': 'pt'} 68 | save_model(bert_model, safetensors_file, metadata) 69 | print("save model to {}".format(output_model_path)) 70 | 71 | 72 | def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE): 73 | keyword_file = os.path.join(path) 74 | model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL] 75 | model_parent_directory = os.path.dirname(model_name) 76 | current_time = datetime.now().strftime('%Y%m%d_%H%M%S') 77 | output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time) 78 | output_model_path = os.path.join(model_parent_directory, output_model_name) 79 | add_keyword_to_model(model_name, keyword_file, output_model_path) 80 | -------------------------------------------------------------------------------- /document_loaders/FilteredCSVloader.py: -------------------------------------------------------------------------------- 1 | ## 指定制定列的csv文件加载器 2 | 3 | from langchain.document_loaders import CSVLoader 4 | import csv 5 | from io import TextIOWrapper 6 | from typing import Dict, List, Optional 7 | from langchain.docstore.document import Document 8 | from langchain.document_loaders.helpers import detect_file_encodings 9 | 10 | 11 | class FilteredCSVLoader(CSVLoader): 12 | def __init__( 13 | self, 14 | file_path: str, 15 | columns_to_read: List[str], 16 | source_column: Optional[str] = None, 17 | metadata_columns: List[str] = [], 18 | csv_args: Optional[Dict] = None, 19 | encoding: Optional[str] = None, 20 | autodetect_encoding: bool = False, 21 | ): 22 | super().__init__( 23 | file_path=file_path, 24 | source_column=source_column, 25 | metadata_columns=metadata_columns, 26 | csv_args=csv_args, 27 | encoding=encoding, 28 | autodetect_encoding=autodetect_encoding, 29 | ) 30 | self.columns_to_read = columns_to_read 31 | 32 | def load(self) -> List[Document]: 33 | """Load data into document objects.""" 34 | 35 | docs = [] 36 | try: 37 | with open(self.file_path, newline="", encoding=self.encoding) as csvfile: 38 | docs = self.__read_file(csvfile) 39 | except UnicodeDecodeError as e: 40 | if self.autodetect_encoding: 41 | detected_encodings = detect_file_encodings(self.file_path) 42 | for encoding in detected_encodings: 43 | try: 44 | with open( 45 | self.file_path, newline="", encoding=encoding.encoding 46 | ) as csvfile: 47 | docs = self.__read_file(csvfile) 48 | break 49 | except UnicodeDecodeError: 50 | continue 51 | else: 52 | raise RuntimeError(f"Error loading {self.file_path}") from e 53 | except Exception as e: 54 | raise RuntimeError(f"Error loading {self.file_path}") from e 55 | 56 | return docs 57 | 58 | def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: 59 | docs = [] 60 | csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore 61 | for i, row in enumerate(csv_reader): 62 | content = [] 63 | for col in self.columns_to_read: 64 | if col in row: 65 | content.append(f'{col}:{str(row[col])}') 66 | else: 67 | raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.") 68 | content = '\n'.join(content) 69 | # Extract the source if available 70 | source = ( 71 | row.get(self.source_column, None) 72 | if self.source_column is not None 73 | else self.file_path 74 | ) 75 | metadata = {"source": source, "row": i} 76 | 77 | for col in self.metadata_columns: 78 | if col in row: 79 | metadata[col] = row[col] 80 | 81 | doc = Document(page_content=content, metadata=metadata) 82 | docs.append(doc) 83 | 84 | return docs 85 | -------------------------------------------------------------------------------- /document_loaders/mydocloader.py: -------------------------------------------------------------------------------- 1 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 2 | from typing import List 3 | import tqdm 4 | 5 | 6 | class RapidOCRDocLoader(UnstructuredFileLoader): 7 | def _get_elements(self) -> List: 8 | def doc2text(filepath): 9 | from docx.table import _Cell, Table 10 | from docx.oxml.table import CT_Tbl 11 | from docx.oxml.text.paragraph import CT_P 12 | from docx.text.paragraph import Paragraph 13 | from docx import Document, ImagePart 14 | from PIL import Image 15 | from io import BytesIO 16 | import numpy as np 17 | from rapidocr_onnxruntime import RapidOCR 18 | ocr = RapidOCR() 19 | doc = Document(filepath) 20 | resp = "" 21 | 22 | def iter_block_items(parent): 23 | from docx.document import Document 24 | if isinstance(parent, Document): 25 | parent_elm = parent.element.body 26 | elif isinstance(parent, _Cell): 27 | parent_elm = parent._tc 28 | else: 29 | raise ValueError("RapidOCRDocLoader parse fail") 30 | 31 | for child in parent_elm.iterchildren(): 32 | if isinstance(child, CT_P): 33 | yield Paragraph(child, parent) 34 | elif isinstance(child, CT_Tbl): 35 | yield Table(child, parent) 36 | 37 | b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables), 38 | desc="RapidOCRDocLoader block index: 0") 39 | for i, block in enumerate(iter_block_items(doc)): 40 | b_unit.set_description( 41 | "RapidOCRDocLoader block index: {}".format(i)) 42 | b_unit.refresh() 43 | if isinstance(block, Paragraph): 44 | resp += block.text.strip() + "\n" 45 | images = block._element.xpath('.//pic:pic') # 获取所有图片 46 | for image in images: 47 | for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id 48 | part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片 49 | if isinstance(part, ImagePart): 50 | image = Image.open(BytesIO(part._blob)) 51 | result, _ = ocr(np.array(image)) 52 | if result: 53 | ocr_result = [line[1] for line in result] 54 | resp += "\n".join(ocr_result) 55 | elif isinstance(block, Table): 56 | for row in block.rows: 57 | for cell in row.cells: 58 | for paragraph in cell.paragraphs: 59 | resp += paragraph.text.strip() + "\n" 60 | b_unit.update(1) 61 | return resp 62 | 63 | text = doc2text(self.file_path) 64 | from unstructured.partition.text import partition_text 65 | return partition_text(text=text, **self.unstructured_kwargs) 66 | 67 | 68 | if __name__ == '__main__': 69 | loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx") 70 | docs = loader.load() 71 | print(docs) 72 | -------------------------------------------------------------------------------- /markdown_docs/server/knowledge_base/kb_api.md: -------------------------------------------------------------------------------- 1 | ## FunctionDef list_kbs 2 | **list_kbs**: 此函数的功能是获取知识库列表。 3 | 4 | **参数**: 此函数不接受任何参数。 5 | 6 | **代码描述**: `list_kbs` 函数是一个无参数函数,用于从数据库中获取知识库的列表。它通过调用 `list_kbs_from_db` 函数来实现这一功能。`list_kbs_from_db` 函数从数据库中查询满足特定条件的知识库名称列表,并返回这些名称。然后,`list_kbs` 函数将这些名称封装在 `ListResponse` 类的实例中返回。`ListResponse` 类是专门用于封装列表数据响应的类,它继承自 `BaseResponse` 类,能够提供状态码、状态消息以及数据列表。这样的设计使得 API 的响应格式保持一致,便于前端开发者理解和使用。 7 | 8 | **注意**: 9 | - `list_kbs` 函数依赖于 `list_kbs_from_db` 函数正确地从数据库中获取知识库名称列表。因此,确保数据库连接和查询逻辑正确是使用此函数的前提。 10 | - 返回的 `ListResponse` 实例中包含的数据列表应正确反映数据库中的知识库情况。这要求 `list_kbs_from_db` 函数能准确地执行其查询逻辑。 11 | - 在实际部署和使用时,应注意数据库的性能和响应时间,尤其是在知识库数量较多的情况下,以保证良好的用户体验。 12 | 13 | **输出示例**: 14 | 假设数据库中存在三个知识库,名称分别为 "知识库A", "知识库B", "知识库C",则函数可能返回的 `ListResponse` 实例如下所示: 15 | ``` 16 | { 17 | "code": 200, 18 | "msg": "success", 19 | "data": ["知识库A", "知识库B", "知识库C"] 20 | } 21 | ``` 22 | 这表示 API 调用成功,且返回了包含三个知识库名称的列表。 23 | ## FunctionDef create_kb(knowledge_base_name, vector_store_type, embed_model) 24 | **create_kb**: 此函数用于创建一个新的知识库。 25 | 26 | **参数**: 27 | - `knowledge_base_name`: 知识库的名称,类型为字符串。通过示例参数可以提供默认示例值。 28 | - `vector_store_type`: 向量存储类型,类型为字符串,默认值为"faiss"。 29 | - `embed_model`: 嵌入模型的名称,类型为字符串,默认使用项目配置的嵌入模型。 30 | 31 | **代码描述**: 32 | 此函数首先通过调用`validate_kb_name`函数验证知识库名称的合法性。如果名称不合法或为空,则分别返回403和404状态码的`BaseResponse`对象,提示错误信息。接下来,使用`KBServiceFactory.get_service_by_name`方法检查是否已存在同名的知识库,如果存在,则返回404状态码的`BaseResponse`对象,提示知识库已存在。如果验证通过,函数将通过`KBServiceFactory.get_service`方法获取对应的知识库服务实例,并调用该实例的`create_kb`方法创建知识库。如果在创建过程中发生异常,将记录错误信息并返回500状态码的`BaseResponse`对象。成功创建知识库后,返回200状态码的`BaseResponse`对象,提示已新增知识库。 33 | 34 | **注意**: 35 | - 在调用此函数创建知识库之前,需要确保知识库名称不为空且不包含非法字符,以避免安全风险。 36 | - 向量存储类型和嵌入模型应根据项目需求和配置进行选择,以确保知识库的正确创建和后续操作的有效性。 37 | - 在处理异常时,应注意记录详细的错误信息,以便于问题的定位和解决。 38 | 39 | **输出示例**: 40 | 如果成功创建名为"技术文档库"的知识库,函数将返回以下`BaseResponse`对象: 41 | ``` 42 | { 43 | "code": 200, 44 | "msg": "已新增知识库 技术文档库" 45 | } 46 | ``` 47 | 如果尝试创建一个已存在的知识库,例如名为"技术文档库",函数将返回: 48 | ``` 49 | { 50 | "code": 404, 51 | "msg": "已存在同名知识库 技术文档库" 52 | } 53 | ``` 54 | 如果知识库名称不合法,将返回: 55 | ``` 56 | { 57 | "code": 403, 58 | "msg": "Don't attack me" 59 | } 60 | ``` 61 | ## FunctionDef delete_kb(knowledge_base_name) 62 | **delete_kb**: 此函数的功能是删除指定的知识库。 63 | 64 | **参数**: 65 | - `knowledge_base_name`: 字符串类型,表示要删除的知识库的名称。此参数通过请求体传入,且提供了示例值 "samples"。 66 | 67 | **代码描述**: 68 | `delete_kb` 函数首先验证知识库名称的合法性。如果名称不合法,即不通过 `validate_kb_name` 函数的验证,将返回一个状态码为403的 `BaseResponse` 对象,消息内容为 "Don't attack me",表示请求被拒绝。接着,函数对知识库名称进行URL解码,以确保名称的正确性。 69 | 70 | 通过 `KBServiceFactory.get_service_by_name` 方法,根据知识库名称获取对应的知识库服务实例。如果实例为 `None`,即知识库不存在,将返回一个状态码为404的 `BaseResponse` 对象,消息内容为 "未找到知识库 {knowledge_base_name}"。 71 | 72 | 若知识库服务实例获取成功,函数尝试调用知识库服务实例的 `clear_vs` 方法来清除知识库中的向量数据,然后调用 `drop_kb` 方法删除知识库。如果删除操作成功,将返回一个状态码为200的 `BaseResponse` 对象,消息内容为 "成功删除知识库 {knowledge_base_name}"。 73 | 74 | 如果在删除过程中发生异常,将捕获异常并记录错误日志,然后返回一个状态码为500的 `BaseResponse` 对象,消息内容为 "删除知识库时出现意外: {e}",其中 `{e}` 是异常信息。 75 | 76 | **注意**: 77 | - 在调用此函数之前,确保传入的知识库名称是经过URL编码的。 78 | - 此函数依赖于 `validate_kb_name` 函数来验证知识库名称的合法性,以防止潜在的安全风险。 79 | - 删除知识库是一个不可逆的操作,一旦执行,知识库中的所有数据将被永久删除。 80 | 81 | **输出示例**: 82 | 如果尝试删除一个不存在的知识库 "unknown_kb",函数可能返回的 `BaseResponse` 对象如下: 83 | ``` 84 | { 85 | "code": 404, 86 | "msg": "未找到知识库 unknown_kb" 87 | } 88 | ``` 89 | 如果成功删除名为 "samples" 的知识库,函数可能返回的 `BaseResponse` 对象如下: 90 | ``` 91 | { 92 | "code": 200, 93 | "msg": "成功删除知识库 samples" 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /server/webui_allinone_stale.py: -------------------------------------------------------------------------------- 1 | """Usage 2 | 加载本地模型: 3 | python webui_allinone.py 4 | 5 | 调用远程api服务: 6 | python webui_allinone.py --use-remote-api 7 | 8 | 后台运行webui服务: 9 | python webui_allinone.py --nohup 10 | 11 | 加载多个非默认模型: 12 | python webui_allinone.py --model-path-address model1@host1@port1 model2@host2@port2 13 | 14 | 多卡启动: 15 | python webui_alline.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB 16 | 17 | """ 18 | import streamlit as st 19 | from webui_pages.utils import * 20 | from streamlit_option_menu import option_menu 21 | from webui_pages import * 22 | import os 23 | from server.llm_api_stale import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH 24 | 25 | from server.api_allinone_stale import parser, api_args 26 | import subprocess 27 | 28 | parser.add_argument("--use-remote-api",action="store_true") 29 | parser.add_argument("--nohup",action="store_true") 30 | parser.add_argument("--server.port",type=int,default=8501) 31 | parser.add_argument("--theme.base",type=str,default='"light"') 32 | parser.add_argument("--theme.primaryColor",type=str,default='"#165dff"') 33 | parser.add_argument("--theme.secondaryBackgroundColor",type=str,default='"#f5f5f5"') 34 | parser.add_argument("--theme.textColor",type=str,default='"#000000"') 35 | web_args = ["server.port","theme.base","theme.primaryColor","theme.secondaryBackgroundColor","theme.textColor"] 36 | 37 | 38 | def launch_api(args,args_list=api_args,log_name=None): 39 | print("Launching api ...") 40 | print("启动API服务...") 41 | if not log_name: 42 | log_name = f"{LOG_PATH}api_{args.api_host}_{args.api_port}" 43 | print(f"logs on api are written in {log_name}") 44 | print(f"API日志位于{log_name}下,如启动异常请查看日志") 45 | args_str = string_args(args,args_list) 46 | api_sh = "python server/{script} {args_str} >{log_name}.log 2>&1 &".format( 47 | script="api.py",args_str=args_str,log_name=log_name) 48 | subprocess.run(api_sh, shell=True, check=True) 49 | print("launch api done!") 50 | print("启动API服务完毕.") 51 | 52 | def launch_webui(args,args_list=web_args,log_name=None): 53 | print("Launching webui...") 54 | print("启动webui服务...") 55 | if not log_name: 56 | log_name = f"{LOG_PATH}webui" 57 | 58 | args_str = string_args(args,args_list) 59 | if args.nohup: 60 | print(f"logs on api are written in {log_name}") 61 | print(f"webui服务日志位于{log_name}下,如启动异常请查看日志") 62 | webui_sh = "streamlit run webui.py {args_str} >{log_name}.log 2>&1 &".format( 63 | args_str=args_str,log_name=log_name) 64 | else: 65 | webui_sh = "streamlit run webui.py {args_str}".format( 66 | args_str=args_str) 67 | subprocess.run(webui_sh, shell=True, check=True) 68 | print("launch webui done!") 69 | print("启动webui服务完毕.") 70 | 71 | 72 | if __name__ == "__main__": 73 | print("Starting webui_allineone.py, it would take a while, please be patient....") 74 | print(f"开始启动webui_allinone,启动LLM服务需要约3-10分钟,请耐心等待,如长时间未启动,请到{LOG_PATH}下查看日志...") 75 | args = parser.parse_args() 76 | 77 | print("*"*80) 78 | if not args.use_remote_api: 79 | launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args) 80 | launch_api(args=args,args_list=api_args) 81 | launch_webui(args=args,args_list=web_args) 82 | print("Start webui_allinone.py done!") 83 | print("感谢耐心等待,启动webui_allinone完毕。") --------------------------------------------------------------------------------