├── .gitignore
├── DEPLOY.md
├── README.md
├── api
│   ├── embedding.py
│   ├── llm.py
│   └── rerank.py
├── config-exp.ini
├── direct
│   └── direct_request.py
├── llm
│   ├── adaptor
│   │   └── chat2llm.py
│   └── llm_loader.py
├── main.py
├── modal
│   └── openai_api_modal.py
├── patching
│   └── langchain_patch.py
├── prompt
│   └── function_call.prompt
├── requirements.txt
└── utils
    └── general_utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .conda
2 | .venv
3 | 
4 | # Logs
5 | logs
6 | *.log
7 | npm-debug.log*
8 | yarn-debug.log*
9 | yarn-error.log*
10 | pnpm-debug.log*
11 | lerna-debug.log*
12 | 
13 | node_modules
14 | .DS_Store
15 | dist
16 | dist-ssr
17 | coverage
18 | *.local
19 | 
20 | /cypress/videos/
21 | /cypress/screenshots/
22 | 
23 | # Editor directories and files
24 | .vscode/*
25 | !.vscode/extensions.json
26 | .idea
27 | *.suo
28 | *.ntvs*
29 | *.njsproj
30 | *.sln
31 | *.sw?
32 | 
33 | __pycache__
34 | 
35 | config.ini
36 | 
37 | temp
38 | web_ui/version.txt
39 | config-colab.ini
40 | config-model.ini
41 | .streamlit/config.toml
42 | .streamlit/credentials.toml
43 | 
--------------------------------------------------------------------------------
/DEPLOY.md:
--------------------------------------------------------------------------------
1 | * **Download the embedding and rerank models**
2 | 
3 | ```
4 | # Download and install git-lfs: https://github.com/git-lfs/git-lfs/releases
5 | git lfs install
6 | 
7 | mkdir -p modal
8 | cd modal
9 | 
10 | git clone https://www.modelscope.cn/quietnight/bge-reranker-large.git
11 | git clone https://www.modelscope.cn/AI-ModelScope/bge-large-zh-v1.5.git
12 | ```
13 | 
14 | - **Configure the LLM: copy config-exp.ini**
15 | 
16 | ```
17 | cp config-exp.ini config.ini
18 | ```
19 | 
20 | - **Configure the core fields (minimal setup: apart from the fields marked for replacement, leave everything else untouched)**
21 | 
22 | ```
23 | # Copy this file and name it config.ini
24 | 
25 | [llm]
26 | 
27 | 
28 | # https://console.xfyun.cn/services/cbm
29 | # iFlytek Spark app id
30 | xh_app_id =
31 | # iFlytek Spark api secret
32 | xh_api_secret =
33 | # iFlytek Spark api key
34 | xh_api_key =
35 | 
36 | 
37 | 
38 | [embedding]
39 | embedding_path =
40 | 
41 | [rerank]
42 | bge_reranker_path =
43 | 
44 | ```
45 | 
46 | * **Clone the project and install the dependencies**
47 | 
48 | ```
49 | # enter the project root
50 | cd lang2openai
51 | # create a virtual environment
52 | python -m venv venv
53 | # activate it on Windows
54 | venv\Scripts\activate
55 | # activate it on Linux
56 | source venv/bin/activate
57 | # install the backend dependencies
58 | pip install -r requirements.txt
59 | ```
60 | 
61 | * **Start the service**
62 | 
63 | ```
64 | # requires Python 3.10 or later
65 | python main.py
66 | ```
67 | 
68 | - **Call the API**
69 | 
70 | ```
71 | curl --location --request POST 'http://127.0.0.1:8778/v1/completions' \
72 | --header 'User-Agent: Apifox/1.0.0 (https://apifox.com)' \
73 | --header 'Content-Type: application/json' \
74 | --header 'Accept: */*' \
75 | --header 'Host: 127.0.0.1:8778' \
76 | --header 'Connection: keep-alive' \
77 | --data-raw '{
78 |     "model": "spark-3.1",
79 |     "prompt": "你能做什么?",
80 |     "stream":false
81 | }'
82 | ```
83 | 
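84 | - **Call it from Python (optional)**
85 | 
86 | Since the service speaks the OpenAI wire format, any OpenAI-compatible client should work. Below is a minimal sketch using the official `openai` SDK (already pinned in requirements.txt); the base URL matches the port in main.py, and the key is an arbitrary placeholder because the server does not validate it.
87 | 
88 | ```
89 | from openai import OpenAI
90 | 
91 | # Assumes the server started by `python main.py` is listening on port 8778.
92 | client = OpenAI(base_url="http://127.0.0.1:8778/v1", api_key="not-checked")
93 | resp = client.chat.completions.create(
94 |     model="spark-3.1",
95 |     messages=[{"role": "user", "content": "你能做什么?"}],
96 | )
97 | print(resp.choices[0].message.content)
98 | ```
99 | 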
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### 1. What does this project do?
2 | 
3 | It takes LLM, embedding, and rerank implementations adapted through LangChain's standard interfaces and, via an adaptation layer, exposes them as services behind the unified OpenAI API contract.
4 | 
5 | ### 2. With so many adapter-layer projects out there, why use this one?
6 | 
7 | There are indeed plenty of unified-API projects on the market, such as one-api. This project has the following advantages:
8 | 
9 | - Simple: it only does protocol conversion, so the project stays very small, with no middleware or database dependencies whatsoever — grab the code, configure your keys, and it is ready to use.
10 | - RAG: almost every unified-API project on the market only adapts the chat protocol; this project additionally adapts and implements the embedding and rerank protocols for RAG workloads — download the model files to a local directory as described and they work directly.
11 | - It can grow: because external calls are ultimately made through LangChain's standardized interfaces, every model that LangChain itself or a vendor has adapted is supported here, and as the hottest framework for LLM application development, the set of models it supports is bound to keep expanding. The currently supported ones are listed at [Chat models | 🦜️🔗 LangChain](https://python.langchain.com/docs/integrations/chat/?ref=blog.langchain.dev)
12 | - Function calling: with more and more agent and tool-calling use cases emerging, this project implements the OpenAI function-calling API through prompt engineering; it has been tested with spark-3.1 as the underlying model, with good results.
13 | 
14 | ### 3. How to use it
15 | 
16 | See [lang2openai/DEPLOY.md](https://github.com/q2wxec/lang2openai/blob/master/DEPLOY.md)
17 | 
18 | ### 4. API overview
19 | 
20 | - /v1/completions
21 | 
22 | - request
23 | 
24 | ```
25 | curl --location --request POST 'http://127.0.0.1:8778/v1/completions' \
26 | --header 'User-Agent: Apifox/1.0.0 (https://apifox.com)' \
27 | --header 'Content-Type: application/json' \
28 | --header 'Accept: */*' \
29 | --header 'Host: 127.0.0.1:8778' \
30 | --header 'Connection: keep-alive' \
31 | --data-raw '{
32 |     "model": "spark-3.1",
33 |     "prompt": "你能做什么?",
34 |     "stream":false
35 | }'
36 | ```
37 | 
38 | - response
39 | 
40 | ```
41 | {
42 |     "choices": [
43 |         {
44 |             "finish_reason": "length",
45 |             "index": 0,
46 |             "logprobs": "",
47 |             "text": "作为一个认知智能模型,我可以回答你的问题、提供信息和建议、进行自然语言处理和生成文本等。以下是一些我可以做的事情:\n * 回答各种问题,包括常见问题、学术问题、技术问题等。\n * 提供各种领域的知识和信息,例如历史、科学、文化、艺术等。\n * 给出建议和指导,例如如何学习一门新技能、如何解决某个问题等。\n * 进行自然语言处理和生成文本,例如翻译、摘要、对话等。\n当然,我的能力也是有限的,有些问题可能超出了我的能力范围或者我没有足够的信息来回答。如果您有任何疑问或需要帮助,请随时告诉我。"
48 |         }
49 |     ],
50 |     "created": 1713772029.4484026,
51 |     "id": "513a06f4bccc4bb3a563c494bab952ab",
52 |     "model": "spark-3.1",
53 |     "object": "text_completion",
54 |     "usage": {
55 |         "completion_tokens": 244,
56 |         "prompt_tokens": 8,
57 |         "total_tokens": 252
58 |     }
59 | }
60 | ```
61 | 
62 | 
63 | 
64 | - /v1/chat/completions
65 | 
66 | - request
67 | 
68 | ```
69 | curl --location --request POST 'http://127.0.0.1:8778/v1/chat/completions' \
70 | --header 'User-Agent: Apifox/1.0.0 (https://apifox.com)' \
71 | --header 'Content-Type: application/json' \
72 | --header 'Accept: */*' \
73 | --header 'Host: 127.0.0.1:8778' \
74 | --header 'Connection: keep-alive' \
75 | --data-raw '{
76 |     "model": "spark-3.1",
77 |     "messages": [
78 |         {
79 |             "role": "system",
80 |             "content": "You are a helpful assistant."
81 |         },
82 |         {
83 |             "role": "user",
84 |             "content": "Hello!"
85 |         }
86 |     ]
87 | }'
88 | ```
89 | 
90 | - response
91 | 
92 | ```
93 | {
94 |     "id": "33e10d6cb5b34c149b38178dccfa9e05",
95 |     "object": "chat.completion",
96 |     "created": 1713772263.898054,
97 |     "model": "spark-3.1",
98 |     "system_fingerprint": "",
99 |     "choices": [
100 |         {
101 |             "index": 0,
102 |             "message": {
103 |                 "role": "assistant",
104 |                 "content": "Hello! How can I assist you today?",
105 |                 "tool_calls": ""
106 |             },
107 |             "logprobs": "",
108 |             "finish_reason": "stop"
109 |         }
110 |     ],
111 |     "usage": {
112 |         "prompt_tokens": 8,
113 |         "completion_tokens": 9,
114 |         "total_tokens": 17
115 |     }
116 | }
117 | ```
118 | 
119 | 
120 | 
121 | - /v1/chat/completions(function call)
122 | 
123 | - request
124 | 
125 | ```
126 | {
127 |     "model": "spark-3.1",
128 |     "messages": [
129 |         {
130 |             "role": "user",
131 |             "content": "What's the weather like in Boston today?"
132 |         }
133 |     ],
134 |     "tools": [
135 |         {
136 |             "type": "function",
137 |             "function": {
138 |                 "name": "get_current_weather",
139 |                 "description": "Get the current weather in a given location",
140 |                 "parameters": {
141 |                     "type": "object",
142 |                     "properties": {
143 |                         "location": {
144 |                             "type": "string",
145 |                             "description": "The city and state, e.g. San Francisco, CA"
146 |                         },
147 |                         "unit": {
148 |                             "type": "string",
149 |                             "enum": [
150 |                                 "celsius",
151 |                                 "fahrenheit"
152 |                             ]
153 |                         }
154 |                     },
155 |                     "required": [
156 |                         "location"
157 |                     ]
158 |                 }
159 |             }
160 |         }
161 |     ],
162 |     "tool_choice": "auto"
163 | }
164 | ```
165 | 
166 | - response
167 | 
168 | ```
169 | {
170 |     "id": "d121afe229164872b20d3123a41b3d48",
171 |     "object": "chat.completion",
172 |     "created": 1713772368.84041,
173 |     "model": "spark-3.1",
174 |     "system_fingerprint": "",
175 |     "choices": [
176 |         {
177 |             "index": 0,
178 |             "message": {
179 |                 "role": "assistant",
180 |                 "content": "",
181 |                 "tool_calls": [
182 |                     {
183 |                         "id": "a0366c69f7c24c0d92bfc598457fca48",
184 |                         "type": "function",
185 |                         "function": {
186 |                             "name": "get_current_weather",
187 |                             "arguments": "{\"location\": \"Boston, MA\"}"
188 |                         }
189 |                     }
190 |                 ]
191 |             },
192 |             "logprobs": "",
193 |             "finish_reason": "stop"
194 |         }
195 |     ],
196 |     "usage": {
197 |         "prompt_tokens": 9,
198 |         "completion_tokens": 28,
199 |         "total_tokens": 37
200 |     }
201 | }
202 | ```
203 | 
204 | 
205 | 
206 | - /v1/embeddings
207 | 
208 | - request
209 | 
210 | ```
211 | curl --location --request POST 'http://127.0.0.1:8778/v1/embeddings' \
212 | --header 'User-Agent: Apifox/1.0.0 (https://apifox.com)' \
213 | --header 'Content-Type: application/json' \
214 | --header 'Accept: */*' \
215 | --header 'Host: 127.0.0.1:8778' \
216 | --header 'Connection: keep-alive' \
217 | --data-raw '{
218 |     "input": "Your text string goes here",
219 |     "model": "bge-large-zh-v1.5"
220 | }'
221 | ```
222 | 
223 | - response
224 | 
225 | ```
226 | {
227 |     "object": "list",
228 |     "data": [
229 |         {
230 |             "object": "embedding",
231 |             "index": 0,
232 |             "embedding": [
233 |                 -0.024587785825133324,
234 |                 -0.018740979954600334,
235 |                 ...
236 |                 0.016061244532465935
237 |             ]
238 |         }
239 |     ],
240 |     "model": "bge-large-zh-v1.5",
241 |     "usage": {
242 |         "prompt_tokens": 5,
243 |         "total_tokens": 5
244 |     }
245 | }
246 | ```
247 | 
248 | 
249 | 
250 | - /v1/rerank
251 | 
252 | - request
253 | 
254 | ```
255 | curl --location --request POST 'http://127.0.0.1:8778/v1/rerank' \
256 | --header 'User-Agent: Apifox/1.0.0 (https://apifox.com)' \
257 | --header 'Content-Type: application/json' \
258 | --header 'Accept: */*' \
259 | --header 'Host: 127.0.0.1:8778' \
260 | --header 'Connection: keep-alive' \
261 | --data-raw '{
262 |     "model": "bge-reranker-large",
263 |     "query": "A man is eating pasta.",
264 |     "documents": [
265 |         "A man is eating food.",
266 |         "A man is eating a piece of bread.",
267 |         "The girl is carrying a baby.",
268 |         "A man is riding a horse.",
269 |         "A woman is playing violin.",
270 |         "A man is eating pasta.",
271 |         "A man is eating pasta",
272 |         "A man is eating "
273 |     ]
274 | }'
275 | ```
276 | 
277 | - response
278 | 
279 | ```
280 | {
281 |     "id": "9a5b0cc3947a45348cef53bf3ff6d449",
282 |     "results": [
283 |         {
284 |             "index": 0,
285 |             "relevance_score": 0.9754678249359131,
286 |             "document": "A man is eating food."
287 |         },
288 |         {
289 |             "index": 1,
290 |             "relevance_score": 0.3509412884712219,
291 |             "document": "A man is eating a piece of bread."
292 |         },
293 |         {
294 |             "index": 2,
295 |             "relevance_score": 0.02600417137145996,
296 |             "document": "The girl is carrying a baby."
297 |         },
298 |         {
299 |             "index": 3,
300 |             "relevance_score": 0.02587881088256836,
301 |             "document": "A man is riding a horse."
302 |         },
303 |         {
304 |             "index": 4,
305 |             "relevance_score": 0.026003408432006835,
306 |             "document": "A woman is playing violin."
307 |         },
308 |         {
309 |             "index": 5,
310 |             "relevance_score": 0.9765342235565185,
311 |             "document": "A man is eating pasta."
312 |         },
313 |         {
314 |             "index": 6,
315 |             "relevance_score": 0.9765145778656006,
316 |             "document": "A man is eating pasta"
317 |         },
318 |         {
319 |             "index": 7,
320 |             "relevance_score": 0.9709453105926513,
321 |             "document": "A man is eating "
322 |         }
323 |     ]
324 | }
325 | ```
326 | 
327 | 
328 | ### LangChain usage
329 | 
330 | ```
331 | chat = ChatOpenAI(model="spark-3.1", temperature=0,openai_api_base="http://localhost:8778/v1",openai_api_key="123").bind_tools(tools)
332 | ```
333 | 
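334 | The one-liner above points LangChain's ChatOpenAI client at this service and binds OpenAI-style tools to the proxied model. A slightly fuller sketch (it assumes the `langchain-openai` package is installed, and reuses the hypothetical `get_current_weather` schema from section 4; the API key is arbitrary because the server does not check it):
335 | 
336 | ```
337 | from langchain_openai import ChatOpenAI
338 | 
339 | tools = [{
340 |     "type": "function",
341 |     "function": {
342 |         "name": "get_current_weather",
343 |         "description": "Get the current weather in a given location",
344 |         "parameters": {
345 |             "type": "object",
346 |             "properties": {"location": {"type": "string"}},
347 |             "required": ["location"]
348 |         }
349 |     }
350 | }]
351 | 
352 | chat = ChatOpenAI(
353 |     model="spark-3.1",
354 |     temperature=0,
355 |     openai_api_base="http://localhost:8778/v1",
356 |     openai_api_key="123"  # not validated by the server
357 | ).bind_tools(tools)
358 | 
359 | msg = chat.invoke("What's the weather like in Boston today?")
360 | print(msg.additional_kwargs.get("tool_calls"))
361 | ```
362 | 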
--------------------------------------------------------------------------------
/api/embedding.py:
--------------------------------------------------------------------------------
1 | from sanic.response import json as sanic_json
2 | from sanic import request
3 | import os
4 | 
5 | 
6 | from utils.general_utils import *
7 | 
8 | def get_embeddings_dict()->dict:
9 |     from langchain.embeddings.huggingface import HuggingFaceEmbeddings
10 |     result = {}
11 |     # if get_config('embedding','bge_embedding_path'):
12 |     #     result["bge-large-zh-v1.5"] = HuggingFaceBgeEmbeddings(model_name=get_config('embedding','bge_embedding_path'))
13 |     if get_config('embedding','embedding_path'):
14 |         embedding_paths = get_config('embedding','embedding_path')
15 |         embedding_path_list = embedding_paths.split(',')
16 |         for embedding_path in embedding_path_list:
17 |             model = os.path.basename(embedding_path)
18 |             result[model] = HuggingFaceEmbeddings(model_name=embedding_path)
19 |     return result
20 | 
21 | async def embeddings(req: request):
22 |     models = req.app.ctx.embedding_models
23 |     inputs = safe_get(req, 'input', [])
24 |     if not isinstance(inputs, list):
25 |         inputs = [inputs]
26 |     model = safe_get(req, 'model')
27 |     # fall back to the default model when none is given or the name is unknown
28 |     if model is None or model not in models:
29 |         model = "bge-large-zh-v1.5"
30 |     embeddings = models[model]
31 |     embed_datas = embeddings.embed_documents(inputs)
32 |     data = []
33 |     num_tokens = cal_tokens(inputs, "text-embedding-ada-002")
34 |     # convert each raw vector into the OpenAI embedding-object shape
35 |     for i, embed_data in enumerate(embed_datas):
36 |         data.append({
37 |             "object": "embedding",
38 |             "index": i,
39 |             "embedding": embed_data
40 |         })
41 |     resp = {
42 |         "object": "list",
43 |         "data": data,
44 |         "model": model,
45 |         "usage": {
46 |             "prompt_tokens": num_tokens,
47 |             "total_tokens": num_tokens
48 |         }
49 |     }
50 |     return sanic_json(resp)
--------------------------------------------------------------------------------
/api/llm.py:
--------------------------------------------------------------------------------
1 | from sanic.response import json as sanic_json
2 | from sanic.response import ResponseStream
3 | from sanic import request
4 | import json
5 | import re
6 | 
7 | import asyncio
8 | import os
9 | from langchain_core.language_models.chat_models import BaseChatModel
10 | from langchain.llms.base import LLM
11 | from langchain_core.messages import HumanMessage, SystemMessage,AIMessage
12 | # from llm.adaptor.chat2llm import Chat2LLM
13 | from utils.general_utils import *
14 | from llm.llm_loader import getLLM,getChat,modal_type_dict
15 | from modal.openai_api_modal import *
16 | from direct.direct_request import pre_router
17 | 
18 | def get_function_prompt(question,functions)->str:
19 |     BASE_DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))
20 |     with open(BASE_DIR+'/prompt/function_call.prompt', 'r',encoding='utf-8') as f:
21 |         function_prompt = f.read()
22 |     # substitute {question} and {functions} in the prompt template
23 |     result = function_prompt.replace("{question}", question).replace("{functions}", functions)
24 |     return result
25 | 
26 | 
27 | async def chat(req: request):
28 |     # models = req.app.ctx.chat_models
29 |     model = safe_get(req, 'model')
30 |     if model not in modal_type_dict:
31 |         model = "glm-4"
32 |     # the glm API is OpenAI-compatible, so its responses can be passed through directly
33 |     stream = safe_get(req, 'stream', False)
34 |     messages = safe_get(req, 'messages', [])
35 |     tools = safe_get(req, 'tools', [])
36 |     temperature = safe_get(req, 'temperature', 0.01)
37 |     functions = safe_get(req, 'functions', [])
38 |     if functions and not tools:
39 |         for function in functions:
40 |             tool = {
41 |                 "type": "function",
42 |                 "function": function
43 |             }
44 |             tools.append(tool)
45 |     resp = pre_router(req,'chat')
46 |     if resp:
47 |         return resp
48 |     #print('messages:'+str(messages))
49 |     #print('tools:'+str(tools))
50 |     # spark-3.1 requires the latest message to come from the user
51 |     if model == "spark-3.1" and messages[-1]['role'] != 'user':
52 |         messages[-1]['role'] = 'user'
53 |     chat : BaseChatModel = getChat(model,temperature)
54 |     chat_messages = []
55 |     for message in messages:
56 |         if message['content'] is None:
57 |             message['content']=''
58 |         if message['role'] == 'system':
59 |             chat_message = SystemMessage(content=message['content'])
60 |         elif message['role'] == 'user':
61 |             chat_message = HumanMessage(content=message['content'])
62 |         elif message['role'] == 'assistant':
63 |             if message.get('tool_calls') and len(message['tool_calls']) > 0:
64 |                 # tool_calls = []
65 |                 # for tool_call in message['tool_calls']:
66 |                 #     tc = ToolCall(name=tool_call['function']['name'],args=json.loads(tool_call['function']['arguments']),id = tool_call['id'])
67 |                 #     tool_calls.append(tc)
68 |                 # chat_message = AIMessage(content=message['content'],tool_calls=tool_calls)
69 |                 chat_message = AIMessage(content=message['content']+'\n工具调用情况如下:'+str(message['tool_calls']))
70 |             else :
71 |                 chat_message = AIMessage(content=message['content'])
72 |         elif message['role'] == 'tool':
73 |             # chat_message = ToolMessage(content=message['content'],tool_call_id = message['tool_calls'])
74 |             chat_message = AIMessage(content='工具调用结果如下,tool_call_id:'+message['tool_call_id']+' ,调用结果result:'+message['content'])
75 |         chat_messages.append(chat_message)
76 |     # a non-empty tools list means this is a tool-calling request
77 |     is_function_call = (tools and messages[-1]['role'] == 'user')
78 |     if is_function_call:
79 |         stream = False
80 |         question = chat_messages.pop().content
81 |         # serialize the tools to a string
82 |         functions = json.dumps(tools)
83 |         function_prompt = get_function_prompt(question,functions)
84 |         chat_message = HumanMessage(content=function_prompt)
85 |         chat_messages.append(chat_message)
86 | 
87 |     prompt_tokens = sum(cal_tokens(s['content'], 'gpt-3.5-turbo') for s in messages)
88 |     if stream:
89 |         stream_resp = get_chat_stream_resp(model)
90 |         async def generate_answer(response:ResponseStream):
91 |             completion_tokens = 0
92 |             for chunk in chat.stream(chat_messages):
93 |                 #logger.info(resp)
94 |                 resp_content = chunk.content
95 |                 completion_tokens += cal_tokens(resp_content, 'gpt-3.5-turbo')
96 |                 stream_resp["choices"][0]['delta']['content'] = resp_content
97 |                 await response.write(f"data: {json.dumps(stream_resp, ensure_ascii=False)}\n\n")
98 |                 # ensure the streamed output is flushed instead of being buffered
99 |                 await asyncio.sleep(0.001)
100 |             stream_resp["choices"][0]['finish_reason'] = "stop"
101 |             stream_resp["usage"]["prompt_tokens"]= prompt_tokens
102 |             stream_resp["usage"]["completion_tokens"]= completion_tokens
103 |             stream_resp["usage"]["total_tokens"]= completion_tokens+prompt_tokens
104 |             stream_resp["choices"][0]['delta'] = {}
105 |             await response.write(f"data: {json.dumps(stream_resp, ensure_ascii=False)}\n\n")
106 |             # ensure the streamed output is flushed instead of being buffered
107 |             await asyncio.sleep(0.001)
108 |         return ResponseStream(generate_answer, content_type='text/event-stream')
109 |     else:
110 |         resp = get_chat_resp(model)
111 |         content = chat.invoke(chat_messages).content
112 |         completion_tokens = cal_tokens(content, 'gpt-3.5-turbo')
113 |         resp["usage"]["prompt_tokens"]= prompt_tokens
114 |         resp["usage"]["completion_tokens"]= completion_tokens
115 |         resp["usage"]["total_tokens"]= completion_tokens+prompt_tokens
116 |         if is_function_call:
117 |             tool_calls = []
118 |             if is_valid_json_array(content):
119 |                 tool_array = json.loads(content)
120 |             else:
121 |                 match = re.search(r'\[.*\]', content)
122 |                 if match:
123 |                     tool_array = json.loads(match.group(0))
124 |                 else:
125 |                     raise Exception('Malformed function-call output from the model, please check')
126 |             for tool in tool_array:
127 |                 tool_resp = {
128 |                     "id": uuid.uuid4().hex,
129 |                     "type": "function",
130 |                     "function": tool
131 |                 }
132 |                 tool_calls.append(tool_resp)
133 |             resp["choices"][0]['message']['tool_calls'] = tool_calls
134 |         else:
135 |             resp["choices"][0]['message']['content'] = content
136 |         #print('resp:'+str(resp))
137 |         return sanic_json(resp)
138 | 
139 | 
140 | async def completions(req: request):
141 |     # models = req.app.ctx.llm_models
142 |     resp = pre_router(req,'completions')
143 |     if resp:
144 |         return resp
145 |     prompt = safe_get(req, 'prompt')
146 |     stream = safe_get(req, 'stream', False)
147 |     model = safe_get(req, 'model')
148 |     temperature = safe_get(req, 'temperature', 0.01)
149 |     llm : LLM = getLLM(model, temperature)
150 |     prompt_tokens = cal_tokens(prompt, 'gpt-3.5-turbo')
151 |     if stream:
152 |         stream_resp = get_completions_stream_resp(model)
153 |         async def generate_answer(response:ResponseStream):
154 |             completion_tokens = 0
155 |             for chunk in llm.stream(prompt):
156 |                 #logger.info(resp)
157 |                 resp_content = chunk
158 |                 completion_tokens += cal_tokens(resp_content, 'gpt-3.5-turbo')
159 |                 stream_resp["choices"][0]['delta']['content'] = resp_content
160 |                 await response.write(f"data: {json.dumps(stream_resp, ensure_ascii=False)}\n\n")
161 |                 # ensure the streamed output is flushed instead of being buffered
162 |                 await asyncio.sleep(0.001)
163 |             stream_resp["choices"][0]['finish_reason'] = "length"
164 |             stream_resp["usage"]["prompt_tokens"]= prompt_tokens
165 |             stream_resp["usage"]["completion_tokens"]= completion_tokens
166 |             stream_resp["usage"]["total_tokens"]= completion_tokens+prompt_tokens
167 |             stream_resp["choices"][0]['delta'] = {}
168 |             await response.write(f"data: {json.dumps(stream_resp, ensure_ascii=False)}\n\n")
169 |             # ensure the streamed output is flushed instead of being buffered
170 |             await asyncio.sleep(0.001)
171 |         return ResponseStream(generate_answer, content_type='text/event-stream')
172 |     else:
173 |         resp = get_completions_resp(model)
174 |         content = llm.invoke(prompt)
175 |         resp["choices"][0]['text'] = content
176 |         completion_tokens = cal_tokens(content, 'gpt-3.5-turbo')
177 |         resp["usage"]["prompt_tokens"]= prompt_tokens
178 |         resp["usage"]["completion_tokens"]= completion_tokens
179 |         resp["usage"]["total_tokens"]= completion_tokens+prompt_tokens
180 |         return sanic_json(resp)
181 | 
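182 | 
183 | # A quick local check of the prompt rendering with hypothetical inputs; run it
184 | # from the repo root (e.g. `python -m api.llm`) so the imports resolve.
185 | if __name__ == "__main__":
186 |     demo_tools = json.dumps([{"type": "function", "function": {"name": "get_current_weather"}}])
187 |     print(get_function_prompt("What's the weather like in Boston today?", demo_tools))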
resp["usage"]["total_tokens"]= completion_tokens+prompt_tokens 180 | return sanic_json(resp) 181 | 182 | 183 | -------------------------------------------------------------------------------- /api/rerank.py: -------------------------------------------------------------------------------- 1 | from sanic.response import json as sanic_json 2 | from sanic import request 3 | import uuid 4 | 5 | from utils.general_utils import * 6 | 7 | def get_rerank_dict()->dict: 8 | from FlagEmbedding import FlagReranker 9 | result = {} 10 | if get_config('rerank','bge_reranker_path'): 11 | result['bge-reranker-large'] = FlagReranker(model_name_or_path=get_config('rerank','bge_reranker_path'), use_fp16=True) 12 | return result 13 | 14 | async def rerank(req: request): 15 | models = req.app.ctx.reranke_models 16 | query = safe_get(req, 'query') 17 | documents = safe_get(req, 'documents', []) 18 | model = safe_get(req, 'model') 19 | # 如果model为none或者不存在于model_paths中,则使用默认模型 20 | if model is None or model not in models: 21 | model = "bge-reranker-large" 22 | reranker = models[model] 23 | sentence_pairs = [[query, passage] for passage in documents] 24 | scores = reranker.compute_score(sentence_pairs) 25 | results = [] 26 | # 循环scores,documents,生成result并放入results中 27 | 28 | for i, (score, document) in enumerate(zip(scores, documents)): 29 | # 生成并添加 result 到 results 列表中 30 | results.append({ 31 | "index": i, 32 | "relevance_score": (score+10)/20, 33 | "document": document 34 | }) 35 | resp={ 36 | "id": uuid.uuid4().hex, 37 | "results": results 38 | } 39 | return sanic_json(resp) 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /config-exp.ini: -------------------------------------------------------------------------------- 1 | # 复制本文件并命名 config.ini 2 | 3 | [llm] 4 | # https://platform.openai.com/api-keys 5 | # openai api key 6 | api_key = 7 | # openai api url 8 | api_url = 9 | 10 | #https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application 11 | # 百度千帆 ak 12 | qf_ak= 13 | # 百度千帆 sk 14 | qf_sk= 15 | 16 | 17 | # https://console.xfyun.cn/services/cbm 18 | # 讯飞星火 app id 19 | xh_app_id = 20 | # 讯飞星火 api secret 21 | xh_api_secret = 22 | # 讯飞星火 api key 23 | xh_api_key = 24 | 25 | # https://ai.google.dev 26 | # gemini pro api key 27 | google_api_key = 28 | 29 | # https://dashscope.console.aliyun.com/apiKey 30 | # 通义千问 api key 31 | ty_api_key = 32 | 33 | # 商汤日日新 34 | st_ak= 35 | 36 | st_sk= 37 | # 智谱 38 | # https://open.bigmodel.cn/partner 39 | zhipu_key = 40 | 41 | [embedding] 42 | # git clone https://www.modelscope.cn/AI-ModelScope/bge-large-zh-v1.5.git 43 | # 配置多个用,将路径隔开如C:/Users/fxx/codes/ai/modal/bge-large-zh-v1.5,C:/Users/fxx/codes/ai/modal/m3e-large 44 | embedding_path = 45 | 46 | [rerank] 47 | # git clone https://www.modelscope.cn/quietnight/bge-reranker-large.git 48 | bge_reranker_path = 49 | -------------------------------------------------------------------------------- /direct/direct_request.py: -------------------------------------------------------------------------------- 1 | from utils.general_utils import * 2 | from sanic import request 3 | from sanic.response import json as sanic_json 4 | from sanic.response import ResponseStream 5 | import asyncio 6 | import json 7 | 8 | from llm.llm_loader import modal_type_dict 9 | from modal.openai_api_modal import * 10 | 11 | def pre_router(req: request,req_type:str = 'chat'): 12 | model = safe_get(req, 'model') 13 | if model : 14 | type = modal_type_dict[model] 15 | if type == 'tongyi': 16 | 
--------------------------------------------------------------------------------
/config-exp.ini:
--------------------------------------------------------------------------------
1 | # Copy this file and name it config.ini
2 | 
3 | [llm]
4 | # https://platform.openai.com/api-keys
5 | # openai api key
6 | api_key =
7 | # openai api url
8 | api_url =
9 | 
10 | #https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application
11 | # Baidu Qianfan ak
12 | qf_ak=
13 | # Baidu Qianfan sk
14 | qf_sk=
15 | 
16 | 
17 | # https://console.xfyun.cn/services/cbm
18 | # iFlytek Spark app id
19 | xh_app_id =
20 | # iFlytek Spark api secret
21 | xh_api_secret =
22 | # iFlytek Spark api key
23 | xh_api_key =
24 | 
25 | # https://ai.google.dev
26 | # gemini pro api key
27 | google_api_key =
28 | 
29 | # https://dashscope.console.aliyun.com/apiKey
30 | # Tongyi Qianwen api key
31 | ty_api_key =
32 | 
33 | # SenseTime SenseNova
34 | st_ak=
35 | 
36 | st_sk=
37 | # Zhipu
38 | # https://open.bigmodel.cn/partner
39 | zhipu_key =
40 | 
41 | # OpenAI-compatible proxy used by the gpt-* models (read by direct_request.py)
42 | proxy_key =
43 | proxy_url =
44 | 
45 | # Moonshot api key (read by direct_request.py)
46 | moonshot_key =
47 | 
48 | [embedding]
49 | # git clone https://www.modelscope.cn/AI-ModelScope/bge-large-zh-v1.5.git
50 | # to configure several models, separate the paths with commas, e.g. C:/Users/fxx/codes/ai/modal/bge-large-zh-v1.5,C:/Users/fxx/codes/ai/modal/m3e-large
51 | embedding_path =
52 | 
53 | [rerank]
54 | # git clone https://www.modelscope.cn/quietnight/bge-reranker-large.git
55 | bge_reranker_path =
--------------------------------------------------------------------------------
/direct/direct_request.py:
--------------------------------------------------------------------------------
1 | from utils.general_utils import *
2 | from sanic import request
3 | from sanic.response import json as sanic_json
4 | from sanic.response import ResponseStream
5 | import asyncio
6 | import json
7 | 
8 | from llm.llm_loader import modal_type_dict
9 | from modal.openai_api_modal import *
10 | 
11 | def pre_router(req: request,req_type:str = 'chat'):
12 |     model = safe_get(req, 'model')
13 |     if model:
14 |         # unknown model names fall through to the LangChain route
15 |         type = modal_type_dict.get(model)
16 |         if type == 'tongyi':
17 |             #return tongyi_chat(req)
18 |             api_key = get_config('llm','ty_api_key')
19 |             url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
20 |         elif type == 'zhipu':
21 |             #return glm_chat(req)
22 |             api_key = get_config('llm','zhipu_key')
23 |             url = "https://open.bigmodel.cn/api/paas/v4"
24 |         elif type == 'proxy':
25 |             api_key = get_config('llm','proxy_key')
26 |             url = get_config('llm','proxy_url')
27 |         elif type == 'moonshot':
28 |             api_key = get_config('llm','moonshot_key')
29 |             url = "https://api.moonshot.cn/v1"
30 |         elif type == 'spark':
31 |             api_key = get_config('llm','xh_api_key')+':'+ get_config('llm','xh_api_secret')
32 |             url = "https://spark-api-open.xf-yun.com/v1"
33 |         else:
34 |             # e.g. qianfan has no OpenAI-compatible endpoint and is served via LangChain
35 |             return None
36 |         return gpt_chat(req,url,api_key,req_type)
37 |     return None
38 | 
39 | 
40 | def gpt_chat(req: request,url:str,api_key:str,req_type:str):
41 |     import httpx
42 |     if req_type == 'chat':
43 |         url = url+'/chat/completions'
44 |     else:
45 |         url = url+'/completions'
46 |     # the request payload is forwarded upstream as-is
47 |     data = req.json
48 |     stream = safe_get(req, 'stream', False)
49 | 
50 |     print('request:'+str(data))
51 |     # request headers
52 |     headers = {
53 |         'Authorization': f'Bearer {api_key}',
54 |         'Content-Type': 'application/json'
55 |     }
56 | 
57 |     if not stream:
58 |         resp = httpx.post(url, json=data, headers=headers, timeout=120.0)
59 |         print('resp:'+str(resp))
60 |         return sanic_json(resp.json())
61 |     else:
62 |         async def generate_answer(response:ResponseStream):
63 |             # open a streaming request so events are forwarded as they arrive,
64 |             # instead of buffering the whole upstream body first
65 |             with httpx.stream('POST', url, json=data, headers=headers, timeout=120.0) as resp:
66 |                 for chunk in resp.iter_lines():
67 |                     if not chunk:
68 |                         continue
69 |                     # strip the leading "data:" before re-wrapping as SSE
70 |                     if chunk.startswith("data:"):
71 |                         chunk = chunk[5:].strip()
72 |                     #logger.info(resp)
73 |                     await response.write(f"data: {chunk}\n\n")
74 |                     # ensure the streamed output is flushed instead of being buffered
75 |                     await asyncio.sleep(0.001)
76 |         return ResponseStream(generate_answer, content_type='text/event-stream')
77 | 
78 | # def glm_chat(req: request):
79 | #     from zhipuai import ZhipuAI
80 | #     stream = safe_get(req, 'stream', False)
81 | #     model = safe_get(req, 'model')
82 | #     messages = safe_get(req, 'messages', [])
83 | #     tools = safe_get(req, 'tools', [])
84 | #     tool_choice = safe_get(req, 'tool_choice', 'auto')
85 | #     temperature = safe_get(req, 'temperature', 0)
86 | #     client = ZhipuAI(api_key=get_config('llm','zhipu_key')) # fill in your own API key
87 | #     req={
88 | #         'model':model, # name of the model to call
89 | #         'messages':messages,
90 | #         'stream':stream,
91 | #         'temperature':temperature,
92 | #     }
93 | #     if tools:
94 | #         req['tools'] = tools
95 | #         req['tool_choice'] = tool_choice
96 | #     #print(req)
97 | #     resp = client.chat.completions.create(**req)
98 | #     if not stream:
99 | #         return sanic_json(json.loads(resp.model_dump_json(exclude_none = True)))
100 | #     else:
101 | #         async def generate_answer(response:ResponseStream):
102 | #             for chunk in resp:
103 | #                 #logger.info(resp)
104 | #                 await response.write(f"data: {chunk.model_dump_json(exclude_none = True)}\n\n")
105 | #                 # ensure the streamed output is flushed instead of being buffered
106 | #                 await asyncio.sleep(0.001)
107 | #         return ResponseStream(generate_answer, content_type='text/event-stream')
108 | 
109 | 
110 | # def tongyi_chat(req: request):
111 | #     from openai import OpenAI
112 | #     stream = safe_get(req, 'stream', False)
113 | #     model = safe_get(req, 'model')
114 | #     messages = safe_get(req, 'messages', [])
115 | #     for message in messages:
116 | #         if not message['content']:
117 | #             message['content']='工具调用'
118 | #     tools = safe_get(req, 'tools', [])
119 | #     tool_choice = safe_get(req, 'tool_choice', '')
120 | #     temperature = safe_get(req, 'temperature', 0)
121 | #     client = OpenAI(
122 | #         api_key=get_config('llm','ty_api_key'), # replace with your real DashScope API_KEY
123 | #         base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", # the DashScope service endpoint
124 | #     )
125 | #     req={
126 | #         'model':model, # name of the model to call
127 | #         'messages':messages,
128 | #         'stream':stream,
129 | #         'temperature':temperature,
130 | #         #'tool_choice':tool_choice,
131 | #     }
132 | #     if tools:
133 | #         req['tools'] = tools
134 | #     resp = client.chat.completions.create(**req)
135 | #     if not stream:
136 | #         return sanic_json(json.loads(resp.to_json(exclude_none = True)))
137 | #     else:
138 | #         async def generate_answer(response:ResponseStream):
139 | #             for chunk in resp:
140 | #                 #logger.info(resp)
141 | #                 await response.write(f"data: {chunk.to_json(exclude_none = True, indent=None)}\n\n")
142 | #                 # ensure the streamed output is flushed instead of being buffered
143 | #                 await asyncio.sleep(0.001)
144 | #         return ResponseStream(generate_answer, content_type='text/event-stream')
--------------------------------------------------------------------------------
/llm/adaptor/chat2llm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Any, List, Mapping, Optional,Iterator
3 | from langchain_core.outputs import GenerationChunk
4 | from langchain.callbacks.manager import CallbackManagerForLLMRun
5 | from langchain.llms.base import LLM
6 | from langchain_core.language_models.chat_models import BaseChatModel
7 | from langchain_core.messages import HumanMessage
8 | 
9 | 
10 | class Chat2LLM(LLM):
11 |     """Expose a chat model through the plain string-in/string-out LLM interface."""
12 |     chat:BaseChatModel
13 | 
14 |     @property
15 |     def _llm_type(self) -> str:
16 |         return "Chat2LLM"
17 | 
18 |     @property
19 |     def _identifying_params(self) -> Mapping[str, Any]:
20 |         """Get the identifying parameters."""
21 |         return {"chat": self.chat}
22 | 
23 |     def _call(
24 |         self,
25 |         prompt: str,
26 |         stop: Optional[List[str]] = None,
27 |         run_manager: Optional[CallbackManagerForLLMRun] = None,
28 |         **kwargs: Any,
29 |     ) -> str:
30 |         chat_message = HumanMessage(content=prompt)
31 |         return self.chat.invoke([chat_message]).content
32 | 
33 | 
34 |     def _stream(
35 |         self,
36 |         prompt: str,
37 |         stop: Optional[List[str]] = None,
38 |         run_manager: Optional[CallbackManagerForLLMRun] = None,
39 |         **kwargs: Any,
40 |     ) -> Iterator[GenerationChunk]:
41 |         chat_message = HumanMessage(content=prompt)
42 |         resp = self.chat.stream([chat_message])
43 |         for part in resp:
44 |             chunk = GenerationChunk(text=part.content)
45 |             yield chunk
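46 | 
47 | 
48 | # Usage sketch (hypothetical model and key): any LangChain chat model can be
49 | # served through the string-based /v1/completions path this way.
50 | #
51 | #   from langchain_community.chat_models import ChatZhipuAI
52 | #   llm = Chat2LLM(chat=ChatZhipuAI(model="glm-4", api_key="..."))
53 | #   print(llm.invoke("Hello"))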
"qianfan":["ERNIE-Bot-4","ERNIE-Speed-128K","ERNIE-Speed-8K","ERNIE-Lite-8K"], 15 | "tongyi":["qwen-plus","qwen-turbo","qwen2.5-7b-instruct","qwen2.5-72b-instruct"], 16 | "zhipu":["glm-4","glm-4-plus","glm-4-flash"], 17 | "spark":["general","generalv3","pro-128k","generalv3.5","4.0Ultra"], 18 | "proxy":["gpt-3.5-turbo","gpt-4","gpt-4o","gpt-4o-2024-05-13","gpt-4o-2024-08-06","gpt-4o-mini-2024-07-18"], 19 | "moonshot":["moonshot-v1-8k","moonshot-v1-32k","moonshot-v1-128k"], 20 | } 21 | 22 | modal_type_dict = {item: key for key, sublist in modal_list.items() for item in sublist} 23 | 24 | 25 | # 大模型定义 26 | 27 | 28 | def getLLM(model,temperature=0.1)->LLM: 29 | type = modal_type_dict[model] 30 | if type == "qianfan": 31 | if get_config('llm','qf_ak'): 32 | return QianfanLLMEndpoint(qianfan_ak=get_config('llm','qf_ak'),qianfan_sk=get_config('llm','qf_sk'),model=model,temperature=temperature) 33 | else: 34 | raise Exception("qianfan_ak not set .Please check config.ini") 35 | 36 | def getChat(model,temperature=0.1)->BaseChatModel: 37 | type = modal_type_dict[model] 38 | if type == "qianfan": 39 | if get_config('llm','qf_ak'): 40 | return QianfanChatEndpoint(qianfan_ak=get_config('llm','qf_ak'),qianfan_sk=get_config('llm','qf_sk'),model=model,temperature=temperature) 41 | else: 42 | raise Exception("qianfan_ak not set .Please check config.ini") 43 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from patching import langchain_patch 4 | langchain_patch.mk() 5 | 6 | # 获取当前脚本的绝对路径 7 | current_script_path = os.path.abspath(__file__) 8 | 9 | # 获取当前脚本的父目录的路径 10 | current_dir = os.path.dirname(current_script_path) 11 | 12 | # 获取父目录 13 | parent_dir = os.path.dirname(current_dir) 14 | 15 | # 获取根目录: 16 | root_dir = os.path.dirname(parent_dir) 17 | 18 | # 将项目根目录添加到sys.path 19 | sys.path.append(root_dir) 20 | 21 | 22 | from sanic import Sanic 23 | from sanic import response as sanic_response 24 | 25 | import os 26 | 27 | from sanic.worker.manager import WorkerManager 28 | import argparse 29 | 30 | from api.embedding import embeddings 31 | from api.rerank import rerank 32 | from api.llm import chat,completions 33 | 34 | 35 | WorkerManager.THRESHOLD = 6000 36 | # 接收外部参数mode 37 | parser = argparse.ArgumentParser() 38 | # mode必须是local或online 39 | parser.add_argument('--mode', type=str, default='local', help='local or online') 40 | # 检查是否是local或online,不是则报错 41 | args = parser.parse_args() 42 | if args.mode not in ['local', 'online']: 43 | raise ValueError('mode must be local or online') 44 | 45 | app = Sanic("Lang2OpenAI") 46 | # 设置请求体最大为 10MB 47 | app.config.REQUEST_MAX_SIZE = 400 * 1024 * 1024 48 | 49 | # CORS中间件,用于在每个响应中添加必要的头信息 50 | @app.middleware("response") 51 | async def add_cors_headers(request, response): 52 | # response.headers["Access-Control-Allow-Origin"] = "http://10.234.10.144:5052" 53 | response.headers["Access-Control-Allow-Origin"] = "*" 54 | response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" 55 | response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" 56 | response.headers["Access-Control-Allow-Credentials"] = "true" # 如果需要的话 57 | 58 | @app.middleware("request") 59 | async def handle_options_request(request): 60 | if request.method == "OPTIONS": 61 | headers = { 62 | # "Access-Control-Allow-Origin": "http://10.234.10.144:5052", 63 | "Access-Control-Allow-Origin": "*", 64 | 
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", 65 | "Access-Control-Allow-Headers": "Content-Type, Authorization", 66 | "Access-Control-Allow-Credentials": "true" # 如果需要的话 67 | } 68 | return sanic_response.text("", headers=headers) 69 | 70 | @app.before_server_start 71 | async def init_modal(app, loop): 72 | # from api.llm import get_llm_dict,get_chat_dict 73 | from api.rerank import get_rerank_dict 74 | from api.embedding import get_embeddings_dict 75 | app.ctx.embedding_models = get_embeddings_dict() 76 | app.ctx.reranke_models = get_rerank_dict() 77 | # app.ctx.llm_models = get_llm_dict() 78 | # app.ctx.chat_models = get_chat_dict() 79 | 80 | 81 | 82 | app.add_route(embeddings, "/v1/embeddings", methods=['POST']) # tags=["embeddings"] 83 | app.add_route(rerank, "/v1/rerank", methods=['POST']) # tags=["rerank"] 84 | app.add_route(completions, "/v1/completions", methods=['POST']) # tags=["completions"] 85 | app.add_route(chat, "/v1/chat/completions", methods=['POST']) # tags=["chat"] 86 | 87 | if __name__ == "__main__": 88 | app.run(host='0.0.0.0', port=8778, workers=4) 89 | -------------------------------------------------------------------------------- /modal/openai_api_modal.py: -------------------------------------------------------------------------------- 1 | import time 2 | import uuid 3 | def get_chat_resp(model): 4 | return { 5 | "id": uuid.uuid4().hex, 6 | "object": "chat.completion", 7 | "created": time.time(), 8 | "model": model, 9 | "system_fingerprint": "", 10 | "choices": [{ 11 | "index": 0, 12 | "message": { 13 | "role": "assistant", 14 | "content": "", 15 | "tool_calls":"", 16 | }, 17 | "logprobs": '', 18 | "finish_reason": "stop" 19 | }], 20 | "usage": { 21 | "prompt_tokens": 0, 22 | "completion_tokens": 0, 23 | "total_tokens": 0 24 | } 25 | } 26 | def get_chat_stream_resp(model): 27 | return { 28 | "id": uuid.uuid4().hex, 29 | "object": "chat.completion.chunk", 30 | "created": time.time(), 31 | "model": model, 32 | "system_fingerprint": "", 33 | "choices": [ 34 | { 35 | "index": 0, 36 | "delta": { 37 | "role": "assistant", 38 | "content": "" 39 | }, 40 | "logprobs": '', 41 | "finish_reason": "" 42 | } 43 | ], 44 | "usage": { 45 | "completion_tokens": 0, 46 | "prompt_tokens": 0, 47 | "total_tokens": 0 48 | } 49 | } 50 | def get_completions_resp(model): 51 | return { 52 | "choices": [ 53 | { 54 | "finish_reason": "length", 55 | "index": 0, 56 | "logprobs": '', 57 | "text": "" 58 | } 59 | ], 60 | "created": time.time(), 61 | "id": uuid.uuid4().hex, 62 | "model": model, 63 | "object": "text_completion", 64 | "usage": { 65 | "completion_tokens": 0, 66 | "prompt_tokens": 0, 67 | "total_tokens": 0 68 | } 69 | } 70 | 71 | def get_completions_stream_resp(model): 72 | return { 73 | "choices": [ 74 | { 75 | "finish_reason": "", 76 | "index": 0, 77 | "logprobs": '', 78 | "delta": { 79 | "content": "" 80 | }, 81 | } 82 | ], 83 | "created": time.time(), 84 | "id": uuid.uuid4().hex, 85 | "model": model, 86 | "object": "text_completion", 87 | "usage": { 88 | "completion_tokens": 0, 89 | "prompt_tokens": 0, 90 | "total_tokens": 0 91 | } 92 | } -------------------------------------------------------------------------------- /patching/langchain_patch.py: -------------------------------------------------------------------------------- 1 | from langchain_core.utils._merge import merge_dicts 2 | from langchain_core.utils import _merge 3 | from typing import Any, Dict 4 | 5 | 6 | 7 | def do_merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: 8 | 
"""Merge two dicts, handling specific scenarios where a key exists in both 9 | dictionaries but has a value of None in 'left'. In such cases, the method uses the 10 | value from 'right' for that key in the merged dictionary. 11 | 12 | Example: 13 | If left = {"function_call": {"arguments": None}} and 14 | right = {"function_call": {"arguments": "{\n"}} 15 | then, after merging, for the key "function_call", 16 | the value from 'right' is used, 17 | resulting in merged = {"function_call": {"arguments": "{\n"}}. 18 | """ 19 | merged = left.copy() 20 | for k, v in right.items(): 21 | if k not in merged: 22 | merged[k] = v 23 | elif v is not None and merged[k] is None: 24 | merged[k] = v 25 | elif v is None or merged[k] == v: 26 | continue 27 | elif type(merged[k]) != type(v): 28 | raise TypeError( 29 | f'additional_kwargs["{k}"] already exists in this message,' 30 | " but with a different type." 31 | ) 32 | elif isinstance(merged[k], str): 33 | merged[k] += v 34 | elif isinstance(merged[k], int): 35 | merged[k] += v 36 | elif isinstance(merged[k], dict): 37 | merged[k] = do_merge_dicts(merged[k], v) 38 | elif isinstance(merged[k], list): 39 | merged[k] = merged[k] + v 40 | else: 41 | raise TypeError( 42 | f"Additional kwargs key {k} already exists in left dict and value has " 43 | f"unsupported type {type(merged[k])}." 44 | ) 45 | return merged 46 | 47 | 48 | def mk(): 49 | _merge.merge_dicts = do_merge_dicts 50 | merge_dicts = do_merge_dicts -------------------------------------------------------------------------------- /prompt/function_call.prompt: -------------------------------------------------------------------------------- 1 | # role 2 | 你是一个函数调用的助手,你需要根据用户的问题,以及用户提供的工具函数列表,选择调用的函数并基于用户问题生成调用函数的参数。 3 | 4 | # user question 5 | {question} 6 | 7 | # avialable tools 8 | {functions} 9 | 10 | # output format 11 | 你需要使用 JSONArray 格式输出,其中的每一个元素都包含函数名和调用参数,涉及到多少个方法的调用就输出多少个子对象,请根据用户的问题来判断需要调用方法列表中的哪些方法 12 | 13 | # constrants 14 | 输出的子对象除了JSON格式只允许包含"name","arguments"两个key,不要做任何解释,只输出JSON格式的结果,具体格式如下 15 | 16 | [ 17 | { 18 | "name": "get_current_weather", 19 | "arguments": "{\"location\": \"Boston, MA\"}" 20 | }, 21 | { 22 | "name": "get_current_country", 23 | "arguments": "{\"location\": \"Boston, MA\"}" 24 | } 25 | ] 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # langchain 2 | langchain==0.1.16 3 | langchain-core==0.1.44 4 | langchain-community==0.0.33 5 | 6 | 7 | # embedding 8 | sentence-transformers==2.2.2 9 | BCEmbedding==0.1.3 10 | FlagEmbedding==1.2.5 11 | tiktoken==0.5.2 12 | protobuf==4.25.3 13 | #pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 14 | 15 | # web 16 | python-multipart==0.0.6 17 | sanic==23.6.0 18 | sanic_ext==23.6.0 19 | concurrent-log-handler==0.9.25 20 | websocket==0.2.1 21 | websocket-client==1.7.0 22 | httpx==0.27.0 23 | httpx_sse==0.4.0 24 | 25 | # llm 26 | dashscope==1.14.1 27 | qianfan==0.3.7.1 28 | zhipuai==2.0.1.20240427 29 | openai==1.25.0 -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | from sanic.request import Request 2 | from sanic.exceptions import BadRequest 3 | import traceback 4 | from urllib.parse import urlparse 5 | import time 6 | import os 7 | import logging 8 | import tiktoken 9 | import re 10 | import configparser 11 | 12 | 
--------------------------------------------------------------------------------
/prompt/function_call.prompt:
--------------------------------------------------------------------------------
1 | # role
2 | You are a function-calling assistant. Based on the user's question and the list of tool functions the user provides, choose which functions to call and generate the call arguments from the question.
3 | 
4 | # user question
5 | {question}
6 | 
7 | # available tools
8 | {functions}
9 | 
10 | # output format
11 | Output a JSON array in which every element contains a function name and its call arguments; output one element per method that needs to be called, judging from the user's question which methods in the list are required.
12 | 
13 | # constraints
14 | Apart from the JSON structure, each element may only contain the two keys "name" and "arguments". Do not add any explanation; output only the JSON result, in exactly the following format:
15 | 
16 | [
17 |     {
18 |         "name": "get_current_weather",
19 |         "arguments": "{\"location\": \"Boston, MA\"}"
20 |     },
21 |     {
22 |         "name": "get_current_country",
23 |         "arguments": "{\"location\": \"Boston, MA\"}"
24 |     }
25 | ]
26 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # langchain
2 | langchain==0.1.16
3 | langchain-core==0.1.44
4 | langchain-community==0.0.33
5 | 
6 | 
7 | # embedding
8 | sentence-transformers==2.2.2
9 | BCEmbedding==0.1.3
10 | FlagEmbedding==1.2.5
11 | tiktoken==0.5.2
12 | protobuf==4.25.3
13 | #pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
14 | 
15 | # web
16 | python-multipart==0.0.6
17 | sanic==23.6.0
18 | sanic_ext==23.6.0
19 | concurrent-log-handler==0.9.25
20 | websocket==0.2.1
21 | websocket-client==1.7.0
22 | httpx==0.27.0
23 | httpx_sse==0.4.0
24 | 
25 | # llm
26 | dashscope==1.14.1
27 | qianfan==0.3.7.1
28 | zhipuai==2.0.1.20240427
29 | openai==1.25.0
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | from sanic.request import Request
2 | from sanic.exceptions import BadRequest
3 | import traceback
4 | from urllib.parse import urlparse
5 | import time
6 | import os
7 | import logging
8 | import tiktoken
9 | import re
10 | import configparser
11 | 
12 | 
13 | __all__ = ['write_check_file', 'isURL', 'format_source_documents', 'get_time', 'safe_get', 'truncate_filename', 'read_files_with_extensions', 'validate_user_id', 'get_invalid_user_id_msg','cal_tokens','get_config','is_valid_json_array']
14 | 
15 | def get_invalid_user_id_msg(user_id):
16 |     return "fail, Invalid user_id: {}. user_id must contain only letters, digits and underscores, and start with a letter".format(user_id)
17 | 
18 | def write_check_file(filepath, docs):
19 |     folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
20 |     if not os.path.exists(folder_path):
21 |         os.makedirs(folder_path)
22 |     fp = os.path.join(folder_path, 'load_file.txt')
23 |     with open(fp, 'a+', encoding='utf-8') as fout:
24 |         fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
25 |         fout.write('\n')
26 |         for i in docs:
27 |             fout.write(str(i))
28 |             fout.write('\n')
29 |         fout.close()
30 | 
31 | 
32 | def isURL(string):
33 |     result = urlparse(string)
34 |     return result.scheme != '' and result.netloc != ''
35 | 
36 | 
37 | def format_source_documents(ori_source_documents):
38 |     source_documents = []
39 |     for inum, doc in enumerate(ori_source_documents):
40 |         #for inum, doc in enumerate(answer_source_documents):
41 |         #doc_source = doc.metadata['source']
42 |         file_id = doc.metadata['file_id']
43 |         file_name = doc.metadata['file_name']
44 |         #source_str = doc_source if isURL(doc_source) else os.path.split(doc_source)[-1]
45 |         source_info = {'file_id': doc.metadata['file_id'],
46 |                        'file_name': doc.metadata['file_name'],
47 |                        'content': doc.page_content,
48 |                        'retrieval_query': doc.metadata['retrieval_query'],
49 |                        #'kernel': doc.metadata['kernel'],
50 |                        'score': str(doc.metadata['score']),
51 |                        #'embed_version': doc.metadata['embed_version']
52 |                        }
53 |         source_documents.append(source_info)
54 |     return source_documents
55 | 
56 | def get_time(func):
57 |     def inner(*arg, **kwargs):
58 |         s_time = time.time()
59 |         res = func(*arg, **kwargs)
60 |         e_time = time.time()
61 |         print('function {} took {} seconds'.format(func.__name__, e_time - s_time))
62 |         return res
63 |     return inner
64 | 
65 | def safe_get(req: Request, attr: str, default=None):
66 |     try:
67 |         if attr in req.form:
68 |             return req.form.getlist(attr)[0]
69 |         if attr in req.args:
70 |             return req.args[attr]
71 |         if attr in req.json:
72 |             return req.json[attr]
73 |         # if value := req.form.get(attr):
74 |         #     return value
75 |         # if value := req.args.get(attr):
76 |         #     return value
77 |         # """req.json does not validate content-type, so the body may not parse as JSON"""
78 |         # if value := req.json.get(attr):
79 |         #     return value
80 |     except BadRequest:
81 |         logging.warning(f"missing {attr} in request")
82 |     except Exception as e:
83 |         logging.warning(f"get {attr} from request failed:")
84 |         logging.warning(traceback.format_exc())
85 |     return default
86 | 
87 | def truncate_filename(filename, max_length=200):
88 |     # get the file extension
89 |     file_ext = os.path.splitext(filename)[1]
90 | 
91 |     # get the filename without the extension
92 |     file_name_no_ext = os.path.splitext(filename)[0]
93 | 
94 |     # compute the filename length in bytes (multi-byte characters count as several)
95 |     filename_length = len(filename.encode('utf-8'))
96 | 
97 |     # if the filename exceeds the maximum length
98 |     if filename_length > max_length:
99 |         # generate a timestamp marker
100 |         timestamp = str(int(time.time()))
101 |         # truncate the filename
102 |         while filename_length > max_length:
103 |             file_name_no_ext = file_name_no_ext[:-4]
104 |             new_filename = file_name_no_ext + "_" + timestamp + file_ext
105 |             filename_length = len(new_filename.encode('utf-8'))
106 |     else:
107 |         new_filename = filename
108 | 
109 |     return new_filename
110 | 
111 | def read_files_with_extensions():
112 |     # path of the current script file
113 |     current_file = os.path.abspath(__file__)
114 | 
115 |     # directory containing the current script file
116 |     current_dir = os.path.dirname(current_file)
117 | 
118 |     # project root directory
119 |     project_dir = os.path.dirname(current_dir)
120 | 
121 |     directory = project_dir + '/data'
122 |     print(f'now reading {directory}')
123 |     extensions = ['.md', '.txt', '.pdf', '.jpg', '.docx', '.xlsx', '.eml', '.csv']
124 |     for root, dirs, files in os.walk(directory):
125 |         for file in files:
126 |             if file.endswith(tuple(extensions)):
127 |                 file_path = os.path.join(root, file)
128 |                 yield file_path
129 | 
130 | def validate_user_id(user_id):
131 |     # regular-expression pattern
132 |     pattern = r'^[A-Za-z][A-Za-z0-9_]*$'
133 |     # check for a match
134 |     if isinstance(user_id, str) and re.match(pattern, user_id):
135 |         return True
136 |     else:
137 |         return False
138 | 
139 | def cal_tokens(inputs,model):
140 |     if isinstance(inputs, str):
141 |         inputs = [inputs]
142 |     encoding = tiktoken.encoding_for_model(model)
143 |     num_tokens = 0
144 |     for text in inputs:
145 |         num_tokens += len(encoding.encode(text))
146 |     return num_tokens
147 | 
148 | def get_config(section, option, fallback=''):
149 |     BASE_DIR = os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))
150 |     config = configparser.ConfigParser()
151 |     config.read(os.path.join(BASE_DIR,'config.ini'))
152 |     return config.get(section, option, fallback=fallback)
153 | 
154 | import json
155 | 
156 | def is_valid_json_array(json_string):
157 |     """
158 |     Return whether the given string is a valid JSON array.
159 | 
160 |     Args:
161 |         json_string (str): the string to validate.
162 | 
163 |     Returns:
164 |         bool: True if the string is a valid JSON array, otherwise False.
165 |     """
166 |     try:
167 |         # try to parse the string as JSON
168 |         data = json.loads(json_string)
169 | 
170 |         # check whether the parsed result is a list
171 |         if isinstance(data, list):
172 |             return True
173 |         else:
174 |             return False
175 |     except json.JSONDecodeError:
176 |         # parsing failed, so it is not valid JSON
177 |         return False
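178 | 
179 | 
180 | # Small runnable check of the two helpers above (tiktoken is pinned in
181 | # requirements.txt; it fetches the encoding on first use).
182 | if __name__ == "__main__":
183 |     print(cal_tokens("Your text string goes here", "gpt-3.5-turbo"))
184 |     print(is_valid_json_array('[{"name": "get_current_weather"}]'))  # True
--------------------------------------------------------------------------------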