├── img ├── example.jpg └── introduce.jpg ├── knowledge ├── data │ ├── operator_type.csv │ └── dict_type.csv ├── content │ └── question_answer.txt └── source_service.py ├── models ├── llm_tongyi.py ├── llm_baichuan.py └── llm_chatglm.py ├── requirements.txt ├── README_en.md ├── README.md ├── configs └── config.py ├── common ├── dict.py ├── log.py ├── llm_output.py └── structured.py ├── query_data ├── query_route.py ├── db.py └── query_execute.py ├── chains └── chatbi_chain.py └── main_webui.py /img/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamiclu/Langchain-ChatBI/HEAD/img/example.jpg -------------------------------------------------------------------------------- /img/introduce.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dynamiclu/Langchain-ChatBI/HEAD/img/introduce.jpg -------------------------------------------------------------------------------- /knowledge/data/operator_type.csv: -------------------------------------------------------------------------------- 1 | sum,求和 2 | sum,汇总 3 | avg,平均值 4 | avg,平均 5 | avg,求平均 6 | detail,明细 7 | detail,明细数据 8 | detail,详细数据 9 | min,最小 10 | min,最小值 11 | max,最大值 12 | max,最大 -------------------------------------------------------------------------------- /knowledge/data/dict_type.csv: -------------------------------------------------------------------------------- 1 | >,大于 2 | >=,大于等于 3 | <,小于 4 | <=,小于等于 5 | =,等于 6 | =,范围 7 | day,天 8 | week,周 9 | month,月 10 | hour,小时 11 | quarter,季度 12 | year,年 13 | pv,pv 14 | pv,曝光 15 | pv,曝光量 16 | pv,访问数 17 | pv,访问次数 18 | uv,uv 19 | uv,访问用户数 20 | uv,用户数 21 | uv,曝光用户数 -------------------------------------------------------------------------------- /models/llm_tongyi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from langchain.chat_models.tongyi import ChatTongyi 3 | from
configs.config import DASHSCOPE_API_KEY 4 | 5 | 6 | os.environ["DASHSCOPE_API_KEY"] = DASHSCOPE_API_KEY 7 | 8 | class LLMTongyi(ChatTongyi): 9 | streaming = True 10 | 11 | def __init__(self): 12 | super().__init__() 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements 2 | langchain==0.0.346 3 | langchain-core==0.0.10 4 | langchain-experimental==0.0.44 5 | modelscope>=1.9.0 6 | fastapi>=0.104.1 7 | python-multipart 8 | nltk~=3.8.1 9 | uvicorn~=0.23.1 10 | starlette~=0.27.0 11 | unstructured[all-docs]>=0.10.12 12 | faiss-cpu 13 | sentence-transformers 14 | requests 15 | websockets==11.0.3 16 | pathlib 17 | pytest 18 | uvicorn 19 | gradio==4.12.0 20 | gradio-client==0.8.0 21 | pytesseract 22 | unstructured_pytesseract 23 | unstructured_inference 24 | cpm_kernels 25 | protobuf==3.20.0 26 | pymysql -------------------------------------------------------------------------------- /knowledge/content/question_answer.txt: -------------------------------------------------------------------------------- 1 | 问题: "微博过去一个月的访问量?" 答案:{"data_indicators": "pv", "operator_type": "求和", "time_type": "月", "dimension": "微博", "filter": "微博","filter_type": "范围", "date_range": "2023-12-01,2023-12-31", "compare_type": "无"} 2 | 问题: "京东2023年2月的访问用户?" 答案: {"data_indicators": "uv", "operator_type": "求和", "time_type": "年", "dimension": "京东", "filter": "京东","filter_type": "范围", "date_range": "2023-02-01,2023-02-28", "compare_type": "无"} 3 | 问题: "淘宝访问数据明细?" 
答案: {"data_indicators": "pv、uv", "operator_type": "明细", "time_type": "天", "dimension": "淘宝", "filter": "淘宝","filter_type": "范围", "date_range": "2023-12-01,2023-12-31", "compare_type": "无"} -------------------------------------------------------------------------------- /models/llm_baichuan.py: -------------------------------------------------------------------------------- 1 | import os 2 | from langchain.chat_models.baichuan import ChatBaichuan 3 | 4 | from configs.config import chat_model_baichuan_dict 5 | 6 | os.environ["DEFAULT_API_BASE"] = chat_model_baichuan_dict["DEFAULT_API_BASE"]  # NOTE(review): confirm ChatBaichuan reads an env var with exactly this name; if it does not, this line has no effect 7 | os.environ["BAICHUAN_API_KEY"] = chat_model_baichuan_dict["BAICHUAN_API_KEY"] 8 | os.environ["BAICHUAN_SECRET_KEY"] = chat_model_baichuan_dict["BAICHUAN_SECRET_KEY"] 9 | 10 | class LLMBaiChuan(ChatBaichuan):  # ChatBaichuan preset to the Baichuan2-53B model with fixed sampling settings 11 | model = "Baichuan2-53B" 12 | """model name of Baichuan, default is `Baichuan2-53B`.""" 13 | temperature: float = 0.3 14 | """What sampling temperature to use.""" 15 | top_k: int = 5 16 | """What search sampling control to use.""" 17 | top_p: float = 0.85  # top-p sampling parameter 18 | 19 | def __init__(self):  # takes no arguments: settings come from the class defaults above and the env vars set at import 20 | super().__init__() 21 | 22 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | # Langchain-ChatBI 2 | 3 | A conversational BI implemented using the Langchain framework and local vector library aims to help users find and understand data knowledge, analyze data, and gain insights into results. Through natural language dialogue, the threshold for data analysis is lowered. 4 | 5 | ![](img/introduce.jpg) 6 | ## Deploy 7 | 8 | ### 1.
Environment setup 9 | 10 | + Make sure Python 3.8-3.11 is installed on your machine: 11 | ``` 12 | $ python3 --version 13 | Python 3.10.0 14 | ``` 15 | 16 | ```shell 17 | 18 | $ git clone https://github.com/dynamiclu/Langchain-ChatBI.git 19 | 20 | $ cd Langchain-ChatBI 21 | 22 | $ pip3 install -r requirements.txt 23 | ``` 24 | ### 2. Model Download 25 | + Embedding model 26 | 27 | ```python 28 | # bge-large-zh-v1.5 (the default embedding model in configs/config.py) 29 | from modelscope import snapshot_download 30 | model_dir = snapshot_download('AI-ModelScope/bge-large-zh-v1.5') 31 | 32 | #text2vec 33 | from modelscope import snapshot_download 34 | model_dir = snapshot_download('Jerry0/text2vec-large-chinese') 35 | ``` 36 | + LLM 37 | ```python 38 | from modelscope import snapshot_download 39 | model_dir = snapshot_download('ZhipuAI/chatglm2-6b-int4') 40 | ``` 41 | 42 | ### 3. Start Gradio 43 | ```shell 44 | # Start Gradio 45 | $ python3 main_webui.py 46 | ``` 47 | ### 4. Example 48 | ![](img/example.jpg) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Langchain-ChatBI 2 | 3 | ## 介绍 4 | [_READ THIS IN ENGLISH_](README_en.md) 5 | 6 | 一种利用 [Langchain](https://github.com/hwchase17/langchain) 框架和本地向量库实现的对话式BI,它的目标是帮助用户寻找、理解数据知识,并能够分析数据、洞察结果,通过自然语言对话,降低数据分析的门槛。 7 | 8 | 该项目可以实现本地化部署,可启动本地大模型([ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) ),也可用HTTP调用百川和通义千问大模型。本项目利用大模型的语言理解能力,识别用户的BI意图,利用指标匹配,解决对话式BI如何确保数据准确的难点。 9 | ![](img/introduce.jpg) 10 | ## 部署 11 | 12 | ### 1. 环境配置 13 | 14 | + 确保你的机器安装了 Python 3.8 - 3.11 15 | ``` 16 | $ python3 --version 17 | Python 3.10.0 18 | ``` 19 | 20 | ```shell 21 | # 拉取仓库 22 | $ git clone https://github.com/dynamiclu/Langchain-ChatBI.git 23 | 24 | # 进入目录 25 | $ cd Langchain-ChatBI 26 | 27 | # 安装全部依赖 28 | $ pip3 install -r requirements.txt 29 | ``` 30 | ### 2.
模型下载 31 | + 向量模型 32 | 33 | ```python 34 | # bge-large-en-v1.5 下载 35 | from modelscope import snapshot_download 36 | model_dir = snapshot_download('AI-ModelScope/bge-large-en-v1.5') 37 | 38 | #text2vec 下载 39 | from modelscope import snapshot_download 40 | model_dir = snapshot_download('Jerry0/text2vec-large-chinese') 41 | ``` 42 | + 大模型 43 | ```python 44 | #chatglm2-6b-int4 下载 45 | from modelscope import snapshot_download 46 | model_dir = snapshot_download('ZhipuAI/chatglm2-6b-int4') 47 | ``` 48 | 49 | ### 3. 启动Gradio 50 | ```shell 51 | # 启动Gradio 52 | $ python3 main_webui.py 53 | ``` 54 | ### 4. 示例 55 | ![](img/example.jpg) -------------------------------------------------------------------------------- /configs/config.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | import torch.backends 3 | 4 | APP_BOOT_PATH = "/Users/PycharmProjects/Langchain-ChatBI" 5 | MODEL_BOOT_PATH = "/Users/PycharmProjects/Langchain-ChatBI/llm/models" 6 | 7 | # 本地chatGLM模型配置 8 | VECTOR_SEARCH_TOP_K = 10 9 | LOCAL_EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 10 | LOCAL_LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 11 | VECTOR_STORE_PATH = APP_BOOT_PATH + "/vector_store" 12 | # 多模型选择,向量模型选择 13 | embedding_model_dict = { 14 | "bge-large-zh": MODEL_BOOT_PATH + "/bge-large-zh-v1.5", 15 | "text2vec": MODEL_BOOT_PATH + "/text2vec-large-chinese", 16 | } 17 | LLM_TOP_K = 6 18 | LLM_HISTORY_LEN = 8 19 | 20 | llm_model_dict = { 21 | "chatglm2-6b-int4": MODEL_BOOT_PATH + "/chatglm2-6b-int4", 22 | "Baichuan2-53B": "", 23 | "qwen-turbo": "", 24 | } 25 | EMBEDDING_MODEL_DEFAULT = "bge-large-zh" 26 | 27 | LLM_MODEL_CHAT_GLM = "chatglm2-6b-int4" 28 | LLM_MODEL_BAICHUAN = "Baichuan2-53B" 29 | LLM_MODEL_QIANWEN = "qwen-turbo" 30 | 31 | """ 32 | 百川公司大模型 33 | """ 34 | chat_model_baichuan_dict = { 35 | "BAICHUAN_API_KEY": 
"####", 36 | "BAICHUAN_SECRET_KEY": "######", 37 | "DEFAULT_API_BASE": "https://api.baichuan-ai.com/v1/chat/completions" 38 | } 39 | 40 | """ 41 | 阿里通义千问大模型key 42 | """ 43 | DASHSCOPE_API_KEY = "#########" 44 | 45 | 46 | WEB_SERVER_NAME = "127.0.0.1" 47 | WEB_SERVER_PORT = 8080 48 | -------------------------------------------------------------------------------- /common/dict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from common.log import * 3 | from configs.config import * 4 | 5 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 6 | 7 | config_dict = { 8 | "operator_type": APP_BOOT_PATH+"/knowledge/data/operator_type.csv", 9 | "dict_type": APP_BOOT_PATH+"/knowledge/data/dict_type.csv" 10 | } 11 | 12 | FILE_OPERATOR_TYPE = "operator_type" 13 | FILE_DICT_TYPE = "dict_type" 14 | 15 | dict_operator_type = {} 16 | dict_type = {} 17 | 18 | def init_dict(file_path: str, dict_name: str, key: int, val: int): 19 | try: 20 | with open(file_path, "r") as f: 21 | for line in f: 22 | line = line.strip().split(",") 23 | if dict_name == FILE_OPERATOR_TYPE: 24 | dict_operator_type[str(line[key])] = str(line[val]) 25 | elif dict_name == FILE_DICT_TYPE: 26 | dict_type[str(line[key])] = str(line[val]) 27 | except FileNotFoundError: 28 | logger.error(" %s File not found !" 
% file_path) 29 | except Exception as e: 30 | logger.error("Error:", e) 31 | 32 | 33 | class Dict: 34 | def __init__(self) -> object: 35 | logger.info("--" * 10 + "Dict init start " + "--" * 10) 36 | self.__init_dict__() 37 | logger.info("--" * 10 + "Dict init end " + "--" * 10) 38 | 39 | @staticmethod 40 | def __init_dict__(): 41 | init_dict(config_dict[FILE_OPERATOR_TYPE], FILE_OPERATOR_TYPE, 1, 0) 42 | init_dict(config_dict[FILE_DICT_TYPE], FILE_DICT_TYPE, 1, 0) 43 | 44 | @staticmethod 45 | def __value__(dict_name: str, val: str): 46 | if dict_name == FILE_OPERATOR_TYPE: 47 | if val in dict_operator_type: 48 | return dict_operator_type[val] 49 | elif dict_name == FILE_DICT_TYPE: 50 | if val in dict_type: 51 | return dict_type[val] 52 | return "" 53 | 54 | 55 | if __name__ == "__main__": 56 | dict_obj = Dict() 57 | -------------------------------------------------------------------------------- /common/log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import datetime 5 | import base64 6 | import logging 7 | import numpy as np 8 | from logging.handlers import RotatingFileHandler 9 | from logging.handlers import TimedRotatingFileHandler 10 | from threading import Lock 11 | 12 | LOG_BASE_PATH = '../log/' 13 | now = datetime.datetime.now() 14 | nowtime = now.strftime("%Y-%m-%d") 15 | LOG_PATH = LOG_BASE_PATH + nowtime + "/" 16 | try: 17 | os.makedirs(LOG_PATH, exist_ok=True) 18 | except Exception as e: 19 | pass 20 | 21 | 22 | class LoggerProject(object): 23 | 24 | def __init__(self, name): 25 | self.mutex = Lock() 26 | self.name = name 27 | self.formatter = '%(asctime)s -<>- %(filename)s -<>- [line]:%(lineno)d -<>- %(levelname)s -<>- %(message)s' 28 | 29 | def _create_logger(self): 30 | _logger = logging.getLogger(self.name + __name__) 31 | _logger.setLevel(level=logging.INFO) 32 | return _logger 33 | 34 | def _file_logger(self): 35 | time_rotate_file = 
TimedRotatingFileHandler(filename=LOG_BASE_PATH + self.name, when='D', interval=1, 36 | backupCount=30) 37 | time_rotate_file.setFormatter(logging.Formatter(self.formatter)) 38 | time_rotate_file.setLevel(logging.INFO) 39 | return time_rotate_file 40 | 41 | def _console_logger(self): 42 | console_handler = logging.StreamHandler() 43 | console_handler.setLevel(level=logging.INFO) 44 | console_handler.setFormatter(logging.Formatter(self.formatter)) 45 | return console_handler 46 | 47 | def pub_logger(self): 48 | logger = self._create_logger() 49 | self.mutex.acquire() 50 | logger.addHandler(self._file_logger()) 51 | logger.addHandler(self._console_logger()) 52 | self.mutex.release() 53 | return logger 54 | 55 | 56 | log_api = LoggerProject('Langchain-ChatBI') 57 | logger = log_api.pub_logger() 58 | -------------------------------------------------------------------------------- /common/llm_output.py: -------------------------------------------------------------------------------- 1 | from common.log import logger 2 | from common.dict import * 3 | 4 | """ 5 | 大模型输出结构化处理 6 | input: 7 | { 8 | "data_indicators": "pv", 9 | "operator_type": "求和", 10 | "time_type": "半年", 11 | "dimension": "一汽红旗", 12 | "filter": "一汽红旗", 13 | "filter_type": "范围", 14 | "date_range": "2023-07-01,2023-12-31", 15 | "compare_type": "无" 16 | } 17 | output: 18 | {"data_indicators": "pv", "operator_type": "101", "time_type": "day", 19 | "dimensions": [{"enName": "name"}, {"enName": "id"}], "filters": [{"enName": "name", "val": "一汽红旗"}], 20 | "filter_type": "eq", "date_range": "2023-02-01,2023-02-25", "compare_type": "无"} 21 | 22 | """ 23 | obj_dict = Dict() 24 | 25 | def out_json_data(info): 26 | out_json = {} 27 | if "data_indicators" in info: 28 | out_json["data_indicators"] = obj_dict.__value__(FILE_DICT_TYPE, str(info["data_indicators"])) 29 | if "operator_type" in info: 30 | out_json["operator_type"] = obj_dict.__value__(FILE_OPERATOR_TYPE, str(info["operator_type"])) 31 | if "time_type" in 
info: 32 | out_json["time_type"] = obj_dict.__value__(FILE_DICT_TYPE, str(info["time_type"])) 33 | if "dimension" in info: 34 | out_json["dimensions"] = [{"enName": "name"}]  # always the `name` column, whatever dimension the LLM named 35 | if "filter" in info: 36 | out_json["filters"] = [{"enName": "name", "val": info["filter"]}]  # the LLM's filter value is applied to the `name` column 37 | if "filter_type" in info: 38 | out_json["filter_type"] = obj_dict.__value__(FILE_DICT_TYPE, str(info["filter_type"])) 39 | if "date_range" in info: 40 | out_json["date_range"] = info["date_range"] 41 | if "compare_type" in info: 42 | out_json["compare_type"] = info["compare_type"] 43 | return out_json 44 | 45 | def dict_to_md(dictionary):  # render a dict as a fenced ```json Markdown block 46 | md = "" 47 | formatted_data = json.dumps(dictionary, indent=4, ensure_ascii=False)  # `json` reaches this module via the star-import chain common.dict -> common.log 48 | md += f"```json\n"+formatted_data+"\n```\n" 49 | return md 50 | 51 | -------------------------------------------------------------------------------- /query_data/query_route.py: -------------------------------------------------------------------------------- 1 | from common.log import logger 2 | from query_data.db import selectMysql 3 | 4 | class QueryRoute: 5 | 6 | def __init__(self): 7 | logger.info("--" * 10 + "queryRoute init " + "--" * 10) 8 | 9 | def verify_query(self, out_dict: dict):  # find the query_route row for the requested indicator + dimensions; returns (query_info, datasource_info, datasource_type) or None 10 | if out_dict is None:  # demo input used by the __main__ block 11 | out_dict = {"data_indicators": "pv", "operator_type": "detail", "time_type": "day", 12 | "dimensions": [{"enName": "name"}, {"enName": "id"}], "filters": [{"enName": "name", "val": "一汽"}], 13 | "filter_type": "eq", "date_range": "2024-01-01,2024-02-01", "compare_type": "无"} 14 | indicators_code = out_dict["data_indicators"] 15 | dim_code_list = out_dict["dimensions"] 16 | dim_code = ""  # comma-joined dimension codes, e.g. "name,id" 17 | index = 1 18 | if dim_code_list: 19 | for line in dim_code_list: 20 | if index == 1: 21 | dim_code = line["enName"] 22 | else: 23 | dim_code += "," + line["enName"] 24 | index = index + 1 25 | 26 | SQL_like = """ 27 | SELECT query_info,datasource_info,datasource_type 28 | FROM query_route 29 | WHERE indicators_code = '%s' 30 | AND dim_code_list like '%s' 31 | AND
dim_query_type = 0 32 | """ % (indicators_code, "%"+dim_code+"%") 33 | 34 | SQL_eq = """ 35 | SELECT query_info,datasource_info,datasource_type 36 | FROM query_route 37 | WHERE indicators_code = '%s' 38 | AND dim_code_list = '%s' 39 | AND dim_query_type = 1 40 | """ % (indicators_code, dim_code) 41 | # print("SQL_like:", SQL_like) 42 | # print("SQL_eq:", SQL_eq) 43 | result_like = selectMysql(SQL_like) 44 | result_eq = selectMysql(SQL_eq) 45 | if (result_like and len(result_like) > 0) or (result_eq and len(result_eq) > 0): 46 | if result_like: 47 | return result_like[0][0], result_like[0][1], result_like[0][2] 48 | elif result_eq: 49 | return result_eq[0][0], result_like[0][1], result_like[0][2] 50 | else: 51 | return None 52 | 53 | 54 | if __name__ == "__main__": 55 | 56 | qr = QueryRoute() 57 | sql = qr.verify_query(None) 58 | print("sql:", sql) 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /models/llm_chatglm.py: -------------------------------------------------------------------------------- 1 | from langchain.llms.base import LLM 2 | from typing import Optional, List 3 | from langchain.llms.utils import enforce_stop_tokens 4 | from transformers import AutoTokenizer, AutoModel 5 | import torch 6 | from configs.config import LOCAL_LLM_DEVICE 7 | 8 | DEVICE = LOCAL_LLM_DEVICE 9 | DEVICE_ID = "0" if torch.cuda.is_available() else None 10 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE 11 | 12 | 13 | def torch_gc(): 14 | if torch.cuda.is_available(): 15 | with torch.cuda.device(CUDA_DEVICE): 16 | torch.cuda.empty_cache() 17 | torch.cuda.ipc_collect() 18 | 19 | class ChatGLM(LLM): 20 | max_token: int = 10000 21 | temperature: float = 0.01 22 | top_p = 0.9 23 | history = [] 24 | tokenizer: object = None 25 | model: object = None 26 | history_len: int = 10 27 | 28 | def __init__(self): 29 | super().__init__() 30 | 31 | @property 32 | def _llm_type(self) -> str: 33 | return "ChatGLM" 34 | 
35 | def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:  # run one chat turn on the loaded model and return the reply text 36 | response, _ = self.model.chat( 37 | self.tokenizer, 38 | prompt, 39 | history=self.history[-self.history_len:] if self.history_len > 0 else [],  # only the most recent history_len turns 40 | max_length=self.max_token, 41 | temperature=self.temperature, 42 | ) 43 | torch_gc() 44 | if stop is not None: 45 | response = enforce_stop_tokens(response, stop) 46 | self.history = self.history + [[None, response]]  # NOTE: the prompt itself is not stored, only [None, reply] 47 | return response 48 | 49 | def load_model(self,  # load tokenizer + model: fp16 on CUDA, fp32 moved to llm_device otherwise; then eval mode 50 | model_name_or_path: str = "THUDM/chatglm-6b", 51 | llm_device=LOCAL_LLM_DEVICE): 52 | self.tokenizer = AutoTokenizer.from_pretrained( 53 | model_name_or_path, 54 | trust_remote_code=True 55 | ) 56 | if torch.cuda.is_available() and llm_device.lower().startswith("cuda"): 57 | self.model = ( 58 | AutoModel.from_pretrained( 59 | model_name_or_path, 60 | trust_remote_code=True) 61 | .half() 62 | .cuda() 63 | ) 64 | else: 65 | self.model = ( 66 | AutoModel.from_pretrained( 67 | model_name_or_path, 68 | trust_remote_code=True) 69 | .float() 70 | .to(llm_device) 71 | ) 72 | self.model = self.model.eval() 73 | -------------------------------------------------------------------------------- /query_data/db.py: -------------------------------------------------------------------------------- 1 | import sys  # NOTE(review): sys and os look unused in this file 2 | import os 3 | import pymysql 4 | 5 | initSQL = """ 6 | use demo; 7 | CREATE TABLE `query_route` ( 8 | `id` bigint NOT NULL AUTO_INCREMENT COMMENT '主键', 9 | `indicators_code` varchar(100) NOT NULL COMMENT '指标编码', 10 | `indicators_name` varchar(200) NOT NULL COMMENT '指标名称', 11 | `dim_code_list` varchar(200) NOT NULL COMMENT '维度编码列表', 12 | `dim_query_type` tinyint DEFAULT '0' COMMENT '维度匹配类型,0:任意组合、1:等匹配', 13 | `indicators_operator_type` varchar(200) DEFAULT '101' COMMENT '指标支持操作类型,101:明细、102:求和、103:平均值、104:最大值、105:最小值', 14 | `query_info` text NOT NULL COMMENT '查询信息', 15 | `datasource_info` varchar(2048) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '数据源信息',
16 | `datasource_type` tinyint DEFAULT '0' COMMENT '数据源类型,0:数据表、1:接口、2: 现成SQL', 17 | PRIMARY KEY (`id`) 18 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ROW_FORMAT=DYNAMIC COMMENT='查询路由表'; 19 | 20 | CREATE TABLE `website_data` ( 21 | `dt` varchar(100) NOT NULL COMMENT '日期', 22 | `id` varchar(100) NOT NULL COMMENT '网站ID', 23 | `name` varchar(100) NOT NULL COMMENT '网站名', 24 | `pv` bigint NOT NULL COMMENT 'PV', 25 | `uv` bigint NOT NULL COMMENT 'UV' 26 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ROW_FORMAT=DYNAMIC COMMENT='网站测试数据'; 27 | 28 | INSERT INTO query_route(id, indicators_code, indicators_name, dim_code_list, dim_query_type, indicators_operator_type, query_info, datasource_info, datasource_type) VALUES(1, 'pv', '曝光量', 'name,id,dt', 0, '101', 'website_data', '{"driverName":"com.mysql.cj.jdbc.Driver","jdbcUrl":"jdbc:mysql://127.0.0.1:3306/demo?allowMultiQueries=true&useUnicode=true&characterEncoding=utf8&autoReconnect=true&zeroDateTimeBehavior=convertToNull&useSSL=false&serverTimezone=Asia/Shanghai","username":"root","password":""}', 0); 29 | INSERT INTO query_route(id, indicators_code, indicators_name, dim_code_list, dim_query_type, indicators_operator_type, query_info, datasource_info, datasource_type) VALUES(2, 'uv', '用户数', 'name,id', 1, '101', 'website_data', '{"driverName":"com.mysql.cj.jdbc.Driver","jdbcUrl":"jdbc:mysql://127.0.0.1:3306/demo?allowMultiQueries=true&useUnicode=true&characterEncoding=utf8&autoReconnect=true&zeroDateTimeBehavior=convertToNull&useSSL=false&serverTimezone=Asia/Shanghai","username":"root","password":""}', 0); 30 | INSERT INTO website_data(dt, id, name, pv, uv) VALUES('2023-02-22', '11', '微博', 333333, 33333); 31 | INSERT INTO website_data(dt, id, name, pv, uv) VALUES('2023-02-22', '12', '京东', 105933, 45533); 32 | INSERT INTO website_data(dt, id, name, pv, uv) VALUES('2023-02-23', '12', '京东', 444444, 34444); 33 | INSERT INTO website_data(dt, id, name, pv, uv) VALUES('2023-02-22', '13', '淘宝', 333555, 55555); 34 | INSERT INTO 
website_data(dt, id, name, pv, uv) VALUES('2023-02-23', '13', '淘宝', 145555, 32355); 35 | INSERT INTO website_data(dt, id, name, pv, uv) VALUES('2023-02-23', '11', '微博', 445333, 32333); 36 | """ 37 | 38 | def selectMysql(sql): 39 | conn = pymysql.connect( 40 | host='127.0.0.1', 41 | port=3306, 42 | database='demo', 43 | user='root', 44 | password='####' 45 | ) 46 | try: 47 | cursor = conn.cursor() 48 | cursor.execute(sql) 49 | 50 | return cursor.fetchall() 51 | finally: 52 | cursor.close() 53 | conn.close() 54 | 55 | -------------------------------------------------------------------------------- /common/structured.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, List 4 | 5 | from langchain_core.output_parsers import BaseOutputParser 6 | from langchain_core.pydantic_v1 import BaseModel 7 | 8 | from langchain.output_parsers.format_instructions import ( 9 | STRUCTURED_FORMAT_INSTRUCTIONS, 10 | STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS, 11 | ) 12 | from langchain.output_parsers.json import parse_and_check_json_markdown 13 | 14 | line_template = '\t"{name}": {type} ' 15 | 16 | 17 | class ResponseSchema(BaseModel): 18 | """A schema for a response from a structured output parser.""" 19 | 20 | name: str 21 | """The name of the schema.""" 22 | description: str 23 | """The description of the schema.""" 24 | type: str = "string" 25 | """The type of the response.""" 26 | 27 | 28 | def _get_sub_string(schema: ResponseSchema) -> str: 29 | return line_template.format( 30 | name=schema.name, description=schema.description, type=schema.type 31 | ) 32 | 33 | 34 | class StructuredOutputParser(BaseOutputParser): 35 | """Parse the output of an LLM call to a structured output.""" 36 | 37 | response_schemas: List[ResponseSchema] 38 | """The schemas for the response.""" 39 | 40 | @classmethod 41 | def from_response_schemas( 42 | cls, response_schemas: List[ResponseSchema] 43 | ) -> 
StructuredOutputParser: 44 | return cls(response_schemas=response_schemas)  # alternate constructor from a list of ResponseSchema 45 | 46 | def get_format_instructions(self, only_json: bool = False) -> str: 47 | """Get format instructions for the output parser. 48 | Each schema is rendered as `"name": type` via line_template; descriptions are not included. 49 | example: 50 | ```python 51 | from common.structured import ( 52 | StructuredOutputParser, ResponseSchema 53 | ) 54 | 55 | response_schemas = [ 56 | ResponseSchema( 57 | name="foo", 58 | description="a list of strings", 59 | type="List[string]" 60 | ), 61 | ResponseSchema( 62 | name="bar", 63 | description="a string", 64 | type="string" 65 | ), 66 | ] 67 | 68 | parser = StructuredOutputParser.from_response_schemas(response_schemas) 69 | 70 | print(parser.get_format_instructions()) 71 | 72 | output: 73 | # The output should be a Markdown code snippet formatted in the following 74 | # schema, including the leading and trailing "```json" and "```": 75 | # 76 | # ```json 77 | # { 78 | # "foo": List[string] 79 | # "bar": string 80 | # } 81 | # ``` 82 | 83 | Args: 84 | only_json (bool): If True, only the json in the Markdown code snippet 85 | will be returned, without the introducing text. Defaults to False.
86 | """ 87 | schema_str = "\n".join( 88 | [_get_sub_string(schema) for schema in self.response_schemas] 89 | ) 90 | if only_json: 91 | return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str) 92 | else: 93 | return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str) 94 | 95 | def parse(self, text: str) -> Any: 96 | expected_keys = [rs.name for rs in self.response_schemas] 97 | return parse_and_check_json_markdown(text, expected_keys) 98 | 99 | @property 100 | def _type(self) -> str: 101 | return "structured" 102 | -------------------------------------------------------------------------------- /knowledge/source_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | from langchain.document_loaders import UnstructuredFileLoader 3 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings 4 | from langchain.vectorstores import FAISS 5 | from langchain.document_loaders.csv_loader import CSVLoader 6 | from common.log import logger 7 | from configs.config import * 8 | import sentence_transformers 9 | from typing import List 10 | import datetime 11 | 12 | """ 13 | 知识库向量化服务 14 | """ 15 | class SourceService: 16 | def __init__(self, 17 | embedding_model: str = EMBEDDING_MODEL_DEFAULT, 18 | embedding_device=LOCAL_EMBEDDING_DEVICE): 19 | self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], ) 20 | self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name, 21 | device=embedding_device) 22 | self.vector_store = None 23 | self.vector_store_path = VECTOR_STORE_PATH 24 | 25 | def init_source_vector(self, docs_path): 26 | """ 27 | 初始化本地知识库向量 28 | :return: 29 | """ 30 | docs = [] 31 | for doc in os.listdir(docs_path): 32 | if doc.endswith('.txt'): 33 | logger.info(doc) 34 | loader = UnstructuredFileLoader(f'{docs_path}/{doc}', mode="elements") 35 | doc = loader.load() 36 | docs.extend(doc) 37 | self.vector_store = 
FAISS.from_documents(docs, self.embeddings) 38 | self.vector_store.save_local(self.vector_store_path) 39 | 40 | def init_knowledge_vector_store(self, 41 | filepath: str or List[str]): 42 | if isinstance(filepath, str): 43 | if not os.path.exists(filepath): 44 | logger.error("路径不存在") 45 | return None 46 | elif os.path.isfile(filepath): 47 | file = os.path.split(filepath)[-1] 48 | try: 49 | loader = UnstructuredFileLoader(filepath, mode="elements") 50 | docs = loader.load() 51 | logger.info(f"{file} 已成功加载") 52 | except Exception as e: 53 | logger.error(f"{file} 未能成功加载", e) 54 | return None 55 | elif os.path.isdir(filepath): 56 | docs = [] 57 | for file in os.listdir(filepath): 58 | fullfilepath = os.path.join(filepath, file) 59 | try: 60 | loader = UnstructuredFileLoader(fullfilepath, mode="elements") 61 | docs += loader.load() 62 | logger.info(f"{file} 已成功加载") 63 | except Exception as e: 64 | logger.error(f"{file} 未能成功加载", e) 65 | else: 66 | docs = [] 67 | for file in filepath: 68 | try: 69 | loader = UnstructuredFileLoader(file, mode="elements") 70 | docs += loader.load() 71 | logger.info(f"{file} 已成功加载") 72 | except Exception as e: 73 | logger.error(f"{file} 未能成功加载", e) 74 | 75 | vector_store = FAISS.from_documents(docs, self.embeddings) 76 | vs_path = f"""{VECTOR_STORE_PATH}/{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""" 77 | vector_store.save_local(vs_path) 78 | return vs_path if len(docs) > 0 else None 79 | 80 | def add_document(self, document_path): 81 | loader = UnstructuredFileLoader(document_path, mode="elements") 82 | doc = loader.load() 83 | self.vector_store.add_documents(doc) 84 | self.vector_store.save_local(self.vector_store_path) 85 | 86 | def load_vector_store(self, path): 87 | if path is None: 88 | self.vector_store = FAISS.load_local(self.vector_store_path, self.embeddings) 89 | else: 90 | self.vector_store = FAISS.load_local(path, self.embeddings) 91 | return self.vector_store 92 | 93 | def add_csv(self, 
document_path): 94 | loader = CSVLoader(file_path=document_path) 95 | doc = loader.load() 96 | logger.info("doc:", doc) 97 | 98 | -------------------------------------------------------------------------------- /query_data/query_execute.py: -------------------------------------------------------------------------------- 1 | from query_data.query_route import QueryRoute 2 | from query_data.db import selectMysql 3 | from common.log import logger 4 | import json 5 | import requests 6 | 7 | query_route = QueryRoute() 8 | 9 | def exe_query(out_dict): 10 | result_data = [] 11 | if out_dict: 12 | out_dict_result, datasource_info, datasource_type = query_route.verify_query(out_dict) 13 | if out_dict_result: 14 | if datasource_type == 0: 15 | out_dict["table_name"] = out_dict_result 16 | sql_query = sql_assemble(out_dict) 17 | list_data = selectMysql(sql_query) 18 | for row in list_data: 19 | result = { 20 | "name": row[0], 21 | "value": int(row[1]) 22 | } 23 | result_data.append(result) 24 | elif datasource_type == 1: 25 | out_dict["url"] = out_dict_result 26 | result_data = url_get_data(out_dict, datasource_info) 27 | 28 | return result_data 29 | 30 | def url_get_data(out_dict, datasource_info): 31 | # req_params_map = { 32 | # 33 | # } 34 | req_params_map = datasource_info 35 | try: 36 | json_data = json.dumps(req_params_map) 37 | res = requests.post( 38 | url=out_dict["url"], 39 | headers={ 40 | "Content-Type": "application/json", 41 | }, 42 | data=json_data, 43 | timeout=60 44 | ) 45 | res_json = json.loads(res.text) 46 | if res.status_code == 200: 47 | return res_json["data"], res.status_code 48 | else: 49 | return res_json["msg"], res.status_code 50 | except Exception as e: 51 | logger.error(e) 52 | return "query fail, wait a second! 
", 500 53 | 54 | def sql_assemble(out_dict: dict): 55 | if out_dict is None: 56 | out_dict = {'data_indicators': 'pv', 'operator_type': 'sum', 'time_type': 'quarter', 'dimensions': [{'enName': 'name'}], 'filters': [{'enName': 'name', 'val': '一汽大众'}], 'filter_type': '=', 'date_range': '2023-04-01,2023-06-30', 'compare_type': '无', 'table_name': 'brand_data'} 57 | data_indicators = out_dict["data_indicators"] 58 | operator_type = out_dict["operator_type"] 59 | time_type = out_dict["time_type"] 60 | dimensions = out_dict["dimensions"] 61 | filters = out_dict["filters"] 62 | filter_type = out_dict["filter_type"] 63 | date_range = out_dict["date_range"] 64 | table_name = out_dict["table_name"] 65 | # compare_type = out_dict["compare_type"] 66 | group_by_sql, dim_sql, dim_group = "", "", "" 67 | condition = "1=1 " 68 | if dimensions: 69 | for line in dimensions: 70 | dim_sql = line["enName"] + " as name" 71 | dim_group = line["enName"] 72 | if filters: 73 | for fi in filters: 74 | key = fi["enName"] 75 | val = fi["val"] 76 | condition += " and " + filters_join(key, val, filter_type) 77 | if date_range: 78 | if "," in date_range: 79 | begin_date = date_range.split(",")[0] 80 | end_date = date_range.split(",")[1] 81 | condition += time_type_format(begin_date, end_date, time_type) 82 | else: 83 | condition += time_type_format_eq(date_range, time_type) 84 | 85 | operator_type_sql = "" 86 | if operator_type: 87 | if operator_type == "sum": 88 | operator_type_sql = "sum(%s) as value" % data_indicators 89 | group_by_sql = "group by %s" % dim_group 90 | elif operator_type == "avg": 91 | operator_type_sql = "avg(%s) as value" % data_indicators 92 | group_by_sql = "group by %s" % dim_group 93 | elif operator_type == "max": 94 | operator_type_sql = "max(%s) as value" % data_indicators 95 | group_by_sql = "group by %s" % dim_group 96 | elif operator_type == "min": 97 | operator_type_sql = "min(%s) as value" % data_indicators 98 | group_by_sql = "group by %s" % dim_group 99 | else: 
100 | operator_type_sql = data_indicators 101 | 102 | SQL = """ 103 | SELECT %s,%s 104 | FROM %s 105 | WHERE %s 106 | %s 107 | """ % (dim_sql, operator_type_sql, table_name, condition, group_by_sql) 108 | return SQL 109 | 110 | 111 | def filters_join(key: str, val: str, filter_type: str): 112 | filter_sql = "" 113 | if filter_type: 114 | if filter_type == "=": 115 | filter_sql = " " + key + " like '%"+val+"%'" 116 | # filter_sql = " %s = '%s' " % (key, val) 117 | if filter_type == ">": 118 | filter_sql = " %s > '%s' " % (key, val) 119 | if filter_type == ">=": 120 | filter_sql = " %s >= '%s' " % (key, val) 121 | if filter_type == "in": 122 | filter_sql = " %s in('%s')" % (key, val) 123 | if filter_type == "like": 124 | filter_sql = " %s like '%s'" % (key, val) 125 | if filter_type == "<": 126 | filter_sql = " %s < '%s'" % (key, val) 127 | if filter_type == "<=": 128 | filter_sql = " %s <= '%s'" % (key, val) 129 | return filter_sql 130 | 131 | 132 | def time_type_format(begin_date: str, end_date: str, time_type: str): 133 | condition = "" 134 | if time_type: 135 | if time_type == "day" or time_type == "quarter" or time_type == "week": 136 | condition = " and dt >= '%s' and dt <= '%s' " % (begin_date, end_date) 137 | elif time_type == "month": 138 | condition = " and DATE_FORMAT(dt, '%Y-%m') >= DATE_FORMAT('" + begin_date + "', '%Y-%m') and DATE_FORMAT(dt, '%Y-%m') <= DATE_FORMAT('" + end_date + "', '%Y-%m') " 139 | return condition 140 | 141 | 142 | def time_type_format_eq(date_range: str, time_type: str): 143 | condition = "" 144 | if time_type: 145 | if time_type == "day": 146 | condition = "dt = '%s' """ % date_range 147 | elif time_type == "month": 148 | condition = " DATE_FORMAT(dt, '%Y-%m') = DATE_FORMAT('" + date_range + "', '%Y-%m') " 149 | return condition 150 | 151 | 152 | if __name__ == "__main__": 153 | sql = sql_assemble(None) 154 | print("sql=", sql) 155 | -------------------------------------------------------------------------------- 
/chains/chatbi_chain.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate 2 | from langchain.output_parsers.json import parse_json_markdown 3 | from langchain.chains import RetrievalQA, LLMChain 4 | from langchain.memory import ConversationBufferWindowMemory 5 | from common.structured import StructuredOutputParser, ResponseSchema 6 | from common.log import logger 7 | from common.llm_output import out_json_data 8 | from configs.config import * 9 | from knowledge.source_service import SourceService 10 | from models.llm_chatglm import ChatGLM 11 | from models.llm_baichuan import LLMBaiChuan 12 | from models.llm_tongyi import LLMTongyi 13 | from query_data.query_execute import exe_query 14 | import datetime 15 | 16 | line_template = '\t"{name}": {type} ' 17 | 18 | time_today = datetime.date.today() 19 | class ChatBiChain: 20 | llm: object = None 21 | service: object = None 22 | memory: object = None 23 | top_k: int = LLM_TOP_K 24 | llm_model: str 25 | his_query: str 26 | 27 | def init_cfg(self, 28 | llm_model: str = LLM_MODEL_CHAT_GLM, 29 | embedding_model: str = EMBEDDING_MODEL_DEFAULT, 30 | llm_history_len=LLM_HISTORY_LEN, 31 | top_k=LLM_TOP_K 32 | ): 33 | self.init_mode(llm_model, llm_history_len) 34 | self.service = SourceService(embedding_model, LOCAL_EMBEDDING_DEVICE) 35 | self.his_query = "" 36 | self.top_k = top_k 37 | logger.info("--" * 30 + "ChatBiChain init " + "--" * 30) 38 | 39 | def init_mode(self, llm_model: str = LLM_MODEL_CHAT_GLM, llm_history_len: str = LLM_HISTORY_LEN): 40 | self.llm_model = llm_model 41 | self.memory = ConversationBufferWindowMemory(k=llm_history_len) 42 | if llm_model == LLM_MODEL_CHAT_GLM: 43 | self.llm = ChatGLM() 44 | self.llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL_CHAT_GLM], 45 | llm_device=LOCAL_LLM_DEVICE) 46 | self.llm.history_len = llm_history_len 47 | elif llm_model == 
LLM_MODEL_BAICHUAN: 48 | self.llm = LLMBaiChuan() 49 | elif llm_model == LLM_MODEL_QIANWEN: 50 | self.llm = LLMTongyi() 51 | 52 | def run_answer(self, query, vs_path, chat_history, top_k=VECTOR_SEARCH_TOP_K): 53 | result_dict = {"data": "sorry,the query is fail"} 54 | out_dict = self.get_intent_identify(query) 55 | out_str = out_dict["info"] 56 | if out_dict["code"] == 200 and "回答:" in out_str: 57 | if "意图:完整" in out_str or "意图: 完整" in out_str: 58 | query = out_str.split("回答:")[1] 59 | # chat_history = chat_history + [[None, query]] 60 | else: 61 | result_dict["data"] = out_str.split("回答:")[1] 62 | return result_dict, chat_history 63 | else: 64 | result_dict["data"] = out_str 65 | return result_dict, chat_history 66 | try: 67 | resp = self.get_answer(query, vs_path, top_k) 68 | res_dict = parse_json_markdown(resp["result"]) 69 | out_json = out_json_data(res_dict) 70 | result_dict["data"] = str(exe_query(out_json)) 71 | except Exception as e: 72 | logger.error(e) 73 | return result_dict, chat_history 74 | 75 | def get_intent_identify(self, query: str): 76 | template = """ 你是智能数据分析助手,根据上下文和Human提问,识别对方数据分析意图('完整'、'缺失'、'闲聊') 77 | ## 背景知识 78 | 完整:对方上下文信息中必须同时包含指标和时间范围,否则是缺失,例如:微博过去一个月的访问量,为完整 79 | 缺失:对方上下文信息不完整,只有时间段或只有指标,例如:微博的访问量量或过去一个月的访问量,都为缺失 80 | 闲聊:跟数据查询无关,如:你是谁 81 | 82 | ## 回答约束 83 | 若数据分析意图为完整,要根据上下文信息总结成一句完整的语句,否则礼貌询问对方需要查询什么 84 | 85 | ## 输出格式 86 | 意图:#,回答:# 87 | 88 | {history} 89 | Human: {human_input} 90 | """ 91 | out_dict = {"code": 500} 92 | prompt = PromptTemplate( 93 | input_variables=["history", "human_input"], 94 | template=template 95 | ) 96 | _chain = LLMChain(llm=self.llm, prompt=prompt, verbose=True, memory=self.memory) 97 | try: 98 | out_dict["info"] = _chain.predict(human_input=query) 99 | out_dict["code"] = 200 100 | except Exception as e: 101 | print(e) 102 | out_dict["info"] = "sorry,LLM model (%s) is fail,wait a minute..." 
% self.llm_model 103 | return out_dict 104 | 105 | def get_answer(self, query: object, vs_path: str = VECTOR_STORE_PATH, top_k=VECTOR_SEARCH_TOP_K): 106 | response_schemas = [ 107 | ResponseSchema(name="data_indicators", description="数据指标: 如 PV、UV"), 108 | ResponseSchema(name="operator_type", description="计算类型: 明细,求和,最大值,最小值,平均值"), 109 | ResponseSchema(name="time_type", description="时间类型: 天、周、月、小时"), 110 | ResponseSchema(name="dimensions", description="维度"), 111 | ResponseSchema(name="filters", description="过滤条件"), 112 | ResponseSchema(name="filter_type", description="过滤条件类型:大于,等于,小于,范围"), 113 | ResponseSchema(name="date_range", description="日期范围,需按当前日期计算,假如当前日期为:2023-12-01,问 过去三个月或近几个月,则输出2023-09-01,2023-11-30;问过去一个月或上个月,则输出2023-11-01,2023-11-30;问八月或8月,则输出2023-08-01,2023-08-31;"), 114 | ResponseSchema(name="compare_type", description="比较类型:无,同比,环比") 115 | ] 116 | output_parser = StructuredOutputParser.from_response_schemas(response_schemas) 117 | format_instructions = output_parser.get_format_instructions(only_json=False) 118 | prompt = ChatPromptTemplate( 119 | messages=[ 120 | HumanMessagePromptTemplate.from_template( 121 | "从问题中抽取准确的信息,若不匹配,返回空,\n{format_instructions},输出时,去掉备注 \n 当前日期:%s \n 已知内容:{context} \n 问题:{question} " % time_today 122 | ) 123 | ], 124 | input_variables=["context", "question"], 125 | partial_variables={"format_instructions": format_instructions} 126 | ) 127 | vector_store = self.service.load_vector_store(vs_path) 128 | knowledge_chain = RetrievalQA.from_llm( 129 | llm=self.llm, 130 | retriever=vector_store.as_retriever(search_kwargs={"k": top_k}), 131 | prompt=prompt 132 | ) 133 | knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate( 134 | input_variables=["page_content"], template="{page_content}" 135 | ) 136 | knowledge_chain.return_source_documents = True 137 | result_dict = {} 138 | try: 139 | result_dict = knowledge_chain({"query": query}) 140 | except Exception as e: 141 | logger.error(e) 142 | 
result_dict["result"] = "sorry,LLM model (%s) is fail,wait a minute..." % self.llm_model 143 | return result_dict 144 | -------------------------------------------------------------------------------- /main_webui.py: -------------------------------------------------------------------------------- 1 | from configs.config import * 2 | from chains.chatbi_chain import ChatBiChain 3 | from common.log import logger 4 | from common.llm_output import dict_to_md 5 | from fastapi import FastAPI 6 | from fastapi.middleware.cors import CORSMiddleware 7 | import gradio as gr 8 | import argparse 9 | import uvicorn 10 | import os 11 | import re 12 | import shutil 13 | 14 | chain = ChatBiChain() 15 | embedding_model_dict_list = list(embedding_model_dict.keys()) 16 | 17 | llm_model_dict_list = list(llm_model_dict.keys()) 18 | 19 | def get_file_list(): 20 | if not os.path.exists("knowledge/content"): 21 | return [] 22 | return [f for f in os.listdir("knowledge/content")] 23 | 24 | 25 | file_list = get_file_list() 26 | 27 | 28 | def upload_file(file): 29 | if not os.path.exists("knowledge/content"): 30 | os.mkdir("knowledge/content") 31 | filename = os.path.basename(file.name) 32 | shutil.move(file.name, "knowledge/content/" + filename) 33 | file_list.insert(0, filename) 34 | return gr.Dropdown(choices=file_list, value=filename) 35 | 36 | 37 | def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history): 38 | try: 39 | chain.init_cfg(llm_model=llm_model, 40 | embedding_model=embedding_model, 41 | llm_history_len=llm_history_len, 42 | top_k=top_k) 43 | model_msg = """The LLM model has been successfully reloaded. 
Please select the file and click the "Load File" button to send the message again""" 44 | except Exception as e: 45 | logger.error(e) 46 | model_msg = """sorry,If the model does not reload successfully, click "Load model" button""" 47 | return history + [[None, model_msg]] 48 | 49 | 50 | def get_answer(query, vs_path, history, top_k): 51 | if vs_path: 52 | history = history + [[query, None]] 53 | result_data, history = chain.run_answer(query=query, vs_path=vs_path, chat_history=history, top_k=top_k) 54 | history = history + [[None, result_data["data"]]] 55 | return history, "" 56 | else: 57 | history = history + [[None, "Please load the file before you ask questions."]] 58 | return history, "" 59 | 60 | 61 | def get_vector_store(filepath, history): 62 | if chain.llm and chain.service: 63 | vs_path = chain.service.init_knowledge_vector_store(["knowledge/content/" + filepath]) 64 | if vs_path: 65 | file_status = "The file has been successfully loaded. Please start asking questions" 66 | else: 67 | file_status = "The file did not load successfully, please upload the file again" 68 | else: 69 | file_status = "The model did not finished loading, please load the model before loading the file" 70 | vs_path = None 71 | return vs_path, history + [[None, file_status]] 72 | 73 | 74 | def init_model(): 75 | try: 76 | chain.init_cfg() 77 | return """The model has been loaded successfully, please select the file and click the "Load file" button""" 78 | except: 79 | return """The model did not load successfully, please click "Load model" button""" 80 | 81 | 82 | block_css = """.importantButton { 83 | background: linear-gradient(45deg, #7e05ff,#5d1c99, #6e00ff) !important; 84 | border: none !important; 85 | } 86 | 87 | .importantButton:hover { 88 | background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important; 89 | border: none !important; 90 | } 91 | 92 | #chat_bi { 93 | height: 100%; 94 | min-height: 455px; 95 | } 96 | """ 97 | 98 | webui_title = """ 99 | # 
Langchain-ChatBI Project 100 | """ 101 | init_message = """Welcome to the ChatBI, click 'Reload the model', if you choose the Embedding model, select or upload the corpus, and then click 'Load the File' """ 102 | 103 | model_status = init_model() 104 | 105 | with gr.Blocks(css=block_css) as demo: 106 | vs_path, file_status, model_status = gr.State(""), gr.State(""), gr.State(model_status) 107 | gr.Markdown(webui_title) 108 | with gr.Row(): 109 | with gr.Column(scale=1): 110 | llm_model = gr.Radio(llm_model_dict_list, 111 | label="LLM Model", 112 | value=LLM_MODEL_CHAT_GLM, 113 | interactive=True) 114 | llm_history_len = gr.Slider(0, 115 | 10, 116 | value=5, 117 | step=1, 118 | label="LLM history len", 119 | interactive=True) 120 | embedding_model = gr.Radio(embedding_model_dict_list, 121 | label="Embedding Model", 122 | value=EMBEDDING_MODEL_DEFAULT, 123 | interactive=True) 124 | top_k = gr.Slider(1, 125 | 20, 126 | value=6, 127 | step=1, 128 | label="top k", 129 | interactive=True) 130 | load_model_button = gr.Button("Reload Model") 131 | 132 | with gr.Tab("select"): 133 | selectFile = gr.Dropdown(file_list, 134 | label="content file", 135 | interactive=True, 136 | value=file_list[0] if len(file_list) > 0 else None) 137 | with gr.Tab("upload"): 138 | file = gr.File(label="content file", 139 | file_types=['.txt', '.md', '.docx', '.pdf'] 140 | ) # .style(height=100) 141 | load_file_button = gr.Button("Load File") 142 | with gr.Column(scale=2): 143 | chatbot = gr.Chatbot(label=init_message, elem_id="chat_bi", show_label=True) 144 | query = gr.Textbox(show_label=True, 145 | placeholder="Please enter the questions and submit them according to the return", 146 | label="Input Field") 147 | send = gr.Button(" Submit") 148 | load_model_button.click(reinit_model, 149 | show_progress=True, 150 | inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot], 151 | outputs=chatbot 152 | ) 153 | # 将上传的文件保存到content文件夹下,并更新下拉框 154 | file.upload(upload_file, 155 | 
inputs=file, 156 | outputs=selectFile) 157 | load_file_button.click(get_vector_store, 158 | show_progress=True, 159 | inputs=[selectFile, chatbot], 160 | outputs=[vs_path, chatbot], 161 | ) 162 | query.submit(get_answer, 163 | show_progress=True, 164 | inputs=[query, vs_path, chatbot, top_k], 165 | outputs=[chatbot, query], 166 | ) 167 | # 发送按钮 提交 168 | send.click(get_answer, 169 | show_progress=True, 170 | inputs=[query, vs_path, chatbot, top_k], 171 | outputs=[chatbot, query], 172 | ) 173 | 174 | app = FastAPI() 175 | app = gr.mount_gradio_app(app, demo, path="/") 176 | 177 | if __name__ == "__main__": 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument("--host", type=str, default=WEB_SERVER_NAME) 180 | parser.add_argument("--port", type=int, default=WEB_SERVER_PORT) 181 | parser.add_argument("--async", type=int, default=0) 182 | args = parser.parse_args() 183 | 184 | app.add_middleware( 185 | CORSMiddleware, 186 | allow_origins=["*"], 187 | allow_credentials=True, 188 | allow_methods=["*"], 189 | allow_headers=["*"], 190 | ) 191 | uvicorn.run(app, host=args.host, port=args.port) 192 | --------------------------------------------------------------------------------