├── README.md
├── openai_api.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# ChatGLM3-OpenAI-API

Based on the openai_api.py that ships with ChatGLM2, modified to support ChatGLM3.

Usage:

1. pip install -r requirements.txt
2. python openai_api.py -m <path or name of the chatglm3-6b model>
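
Once the server is running (port 8000 by default; override it with -p), it exposes an OpenAI-style /v1/chat/completions endpoint. The snippet below is a minimal sketch of a client call using the requests library (not part of requirements.txt, so install it separately); the model name in the payload is only echoed back by this server, so any string will do:

```python
import requests

# Ask the local ChatGLM3 server (started with `python openai_api.py -m ...`)
# for a single, non-streaming chat completion.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "chatglm3-6b",  # echoed back in the response, not used to pick a model
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "stream": False,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```

Setting `stream` to `true` in the request instead returns the reply incrementally as server-sent events (`chat.completion.chunk` objects, terminated by `[DONE]`).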

--------------------------------------------------------------------------------
/openai_api.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Implements an API for ChatGLM3-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py -m <model path> [-p <port>]
# Visit http://localhost:8000/docs for the interactive API documentation.


import getopt
import sys
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

model = None
tokenizer = None


@asynccontextmanager
async def lifespan(app: FastAPI):  # releases GPU memory on shutdown
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice,
                        ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    prev_messages = request.messages[:-1]
    # Prepend a leading system prompt, if any, to the current query.
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    # Rebuild the conversation history as alternating user/assistant turns.
    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
                history.append(
                    {"role": prev_messages[i].role, "content": prev_messages[i].content})
                history.append(
                    {"role": prev_messages[i + 1].role, "content": prev_messages[i + 1].content})

    if request.stream:
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history)

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


async def predict(query: str, history: Optional[List[dict]], model_id: str):
    global model, tokenizer

    if history is None:
        history = []

    # The first chunk only carries the assistant role, mirroring OpenAI's stream format.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[
                                   choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True))

    current_length = 0

    for new_response, _ in model.stream_chat(tokenizer, query, history):
        if len(new_response) == current_length:
            continue

        # Emit only the text generated since the previous chunk.
        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[
                                       choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True))

    # A final empty delta with finish_reason="stop" closes the stream.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[
                                   choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True))
    yield '[DONE]'


def main(argv):
    global model, tokenizer

    models = ""
    port = 8000

    opts, args = getopt.getopt(argv, "m:p:")
    for opt, arg in opts:
        if opt == "-m":
            models = arg
        elif opt == "-p":
            port = int(arg)

    if models == "":
        print("usage:", sys.argv[0], "-m <model path> [-p <port>]")
        sys.exit(1)

    tokenizer = AutoTokenizer.from_pretrained(models, trust_remote_code=True)
    model = AutoModel.from_pretrained(models, trust_remote_code=True).cuda()
    # Multi-GPU support: use the two lines below instead of the line above,
    # and set num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model.eval()

    uvicorn.run(app, host='0.0.0.0', port=port, workers=1)


if __name__ == "__main__":
    main(sys.argv[1:])

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
protobuf
transformers==4.30.2
cpm_kernels
torch>=2.0
gradio==3.39
mdtex2html
sentencepiece
accelerate
sse-starlette
fastapi
uvicorn

--------------------------------------------------------------------------------