├── Dockerfile
├── README.md
├── erniebot-openai-api.py
└── requirements.txt

/Dockerfile:
--------------------------------------------------------------------------------
# Official slim Python base image
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Copy the build context into the image
COPY . /app

# Install the pinned dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Default (empty) environment variables; override with `docker run -e ...`
ENV EB_AGENT_ACCESS_TOKEN=""
ENV EB_AGENT_LOGGING_LEVEL=""

# Expose the API port
EXPOSE 8000

# Start the server. NOTE: the application module is erniebot-openai-api.py
# (there is no app.py), so the import string must name it explicitly; uvicorn
# resolves the hyphenated name via importlib. Bind 0.0.0.0 so the published
# port (-p 8000:8000) is reachable from outside the container.
CMD ["uvicorn", "erniebot-openai-api:app", "--host", "0.0.0.0", "--port", "8000"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# erniebot-openai-api
erniebot兼容openai的API调用方式,支持流式,非流式调用 ,支持system提示词

# 快速使用

```bash

conda create -n enbot python=3.10.6

git clone https://github.com/Jun-Howie/erniebot-openai-api.git

cd erniebot-openai-api

pip install -r requirements.txt

python erniebot-openai-api.py

```

# docker

```bash
# 替换YOUR_ACCESS_TOKEN
docker run -e EB_AGENT_ACCESS_TOKEN=YOUR_ACCESS_TOKEN -e EB_AGENT_LOGGING_LEVEL=info -p 8000:8000 amberyu/enbot


阿里云镜像
docker run -d -e EB_AGENT_ACCESS_TOKEN=YOUR_ACCESS_TOKEN -e EB_AGENT_LOGGING_LEVEL=info -p 8000:8000 registry.cn-shanghai.aliyuncs.com/chatpet/enbot

```



# 调用测试
curl --location --request POST 'http://127.0.0.1:8000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--data-raw '{
    "model": "ernie-4.0",
    "messages": [
        {
            "role": "user",
            "content": "百度公关一号位"
        }
    ]
}'


# 测试结果
{
    "model": "ernie-4.0",
    "object": "chat.completion",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content":
"百度公关一号位指的是**百度副总裁璩静**。璩静毕业于外交学院,曾任新华社中央新闻采访中心记者,华为公共及政府事务部副总裁、中国媒体事务部部长。2021年8月入职百度担任公关副总裁(VP),负责集团公众沟通部工作。\n\n近期,璩静开设了名为“我是璩静”的抖音账号,并因发布的内容引发了争议和关注。其中,包括“员工闹分手提离职我秒批”、“为什么要考虑员工的家庭”、“举报信洒满工位”等视频内容在网络上广泛传播。这些视频在短短几天内就吸引了大量粉丝,使璩静成为了互联网媒体公关圈的热议话题。\n\n以上信息仅供参考,建议查阅相关新闻报道获取更多信息。"
            },
            "finish_reason": "stop"
        }
    ],
    "created": 1715152014
}
# 使用飞桨平台调用ernie-4.0 / PS:tokens 比千帆便宜

[飞桨ai studio星河社区](https://aistudio.baidu.com/)
67 | [ERNIE Bot文档](https://ernie-bot-agent.readthedocs.io/zh-cn/latest/sdk/)
68 | ![cd8dd2724b821c3004e51a1facb0b66](https://github.com/Jun-Howie/erniebot-openai-api/assets/62869005/9c489a0c-2c7f-4045-bc3e-7c35c4cc2721) 69 | 70 | ![image](https://github.com/Jun-Howie/erniebot-openai-api/assets/62869005/b4f1957b-6dd3-4ac6-983f-b31eb088b9e0) 71 | 72 | # 感谢 73 | 感谢[lixiaoxiangzhi](https://github.com/lixiaoxiangzhi) 帮助解决流式异步编程问题
74 | 感谢[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B/blob/main/openai_api.py) 提供原始兼容openai-api的封装思路
75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /erniebot-openai-api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import uvicorn 5 | from erniebot_agent.chat_models import ERNIEBot 6 | from erniebot_agent.memory import HumanMessage, AIMessage, SystemMessage 7 | from fastapi import FastAPI, HTTPException 8 | from pydantic import BaseModel, Field 9 | from typing import List, Literal, Optional, Union 10 | from sse_starlette.sse import EventSourceResponse 11 | 12 | os.environ["EB_AGENT_ACCESS_TOKEN"] = "" 13 | os.environ["EB_AGENT_LOGGING_LEVEL"] = "info" 14 | 15 | app = FastAPI() # 创建 api 对象 16 | 17 | 18 | ##请求入参 19 | class DeltaMessage(BaseModel): 20 | role: Optional[Literal["user", "assistant", "system"]] = None 21 | content: Optional[str] = None 22 | 23 | 24 | class ChatCompletionResponseStreamChoice(BaseModel): 25 | index: int 26 | delta: DeltaMessage 27 | finish_reason: Optional[Literal["stop", "length"]] 28 | 29 | 30 | class ChatMessage(BaseModel): 31 | role: Literal["user", "assistant", "system"] 32 | content: str 33 | 34 | 35 | class ChatCompletionResponseChoice(BaseModel): 36 | index: int 37 | message: ChatMessage 38 | finish_reason: Literal["stop", "length"] 39 | 40 | 41 | # 创建参数对象 42 | class ChatCompletionResponse(BaseModel): 43 | model: str 44 | object: Literal["chat.completion", "chat.completion.chunk"] 45 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] 46 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 47 | 48 | 49 | # 返回体 50 | class ChatCompletionRequest(BaseModel): 51 | model: str 52 | messages: List[ChatMessage] 53 | temperature: Optional[float] = None 54 | top_p: Optional[float] = None 55 | max_length: Optional[int] = None 56 | stream: Optional[bool] = False 57 | 58 | 59 | @app.post("/v1/chat/completions", 
response_model=ChatCompletionResponse) 60 | async def create_chat_completion(request: ChatCompletionRequest): 61 | # 封装对话 62 | messages = [] 63 | # 获取最近一个用户问题 64 | if request.messages[-1].role != "user": 65 | raise HTTPException(status_code=400, detail="Invalid request") 66 | query = request.messages[-1].content 67 | # 获取system 提示词 68 | prev_messages = request.messages[:-1] 69 | if len(prev_messages) > 0 and prev_messages[0].role == "system": 70 | system_message = SystemMessage(content=prev_messages.pop(0).content) 71 | else: 72 | system_message = SystemMessage(content='') 73 | 74 | # 追加历史记录 75 | if len(prev_messages) % 2 == 0: 76 | for i in range(0, len(prev_messages), 2): 77 | if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant": 78 | messages.append(HumanMessage(content=prev_messages[i].content)) 79 | messages.append(AIMessage(content=prev_messages[i + 1].content)) 80 | 81 | # 指定模型 82 | model = ERNIEBot(model=request.model) 83 | 84 | # 添加最新用户问题 85 | messages.append(HumanMessage(content=query)) 86 | # AI回答 请求模型 87 | 88 | # 流式调用 89 | if request.stream: 90 | generate = predict(system_message.content, messages, request.model) 91 | return EventSourceResponse(generate, media_type="text/event-stream") 92 | 93 | # 非流式调用 94 | if system_message.content != '': 95 | ai_message = await model.chat(messages=messages, system=system_message.content) 96 | else: 97 | ai_message = await model.chat(messages=messages) 98 | 99 | choice_data = ChatCompletionResponseChoice( 100 | index=0, 101 | message=ChatMessage(role="assistant", content=ai_message.content), 102 | finish_reason="stop" 103 | ) 104 | return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") 105 | 106 | 107 | async def predict(system: str, messages: List[List[str]], model_id: str): 108 | choice_data = ChatCompletionResponseStreamChoice( 109 | index=0, 110 | delta=DeltaMessage(role="assistant"), 111 | finish_reason=None 112 | ) 113 | chunk = 
ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 114 | # yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) 115 | yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False) 116 | 117 | model = ERNIEBot(model=model_id) 118 | 119 | if system: 120 | ai_message = await model.chat(system=system, messages=messages, stream=True) 121 | else: 122 | ai_message = await model.chat(messages=messages, stream=True) 123 | 124 | async for chunk in ai_message: 125 | # result += chunk.content 126 | choice_data = ChatCompletionResponseStreamChoice( 127 | index=0, 128 | delta=DeltaMessage(content=chunk.content), 129 | finish_reason=None 130 | ) 131 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 132 | print(chunk) 133 | # yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) 134 | yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False) 135 | 136 | choice_data = ChatCompletionResponseStreamChoice( 137 | index=0, 138 | delta=DeltaMessage(), 139 | finish_reason="stop" 140 | ) 141 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 142 | # yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) 143 | yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False) 144 | yield '[DONE]' 145 | 146 | 147 | if __name__ == "__main__": 148 | uvicorn.run(app, host='127.0.0.1', port=8000, workers=1) 149 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | uvicorn==0.29.0 2 | erniebot_agent==0.5.2 3 | fastapi==0.111.0 4 | sse_starlette==2.1.0 --------------------------------------------------------------------------------