├── README.md
├── colab
│   └── my_openai_api.ipynb
├── my_openai_api.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
1 | # my_openai_api
2 | 
3 | Deploy your own **OpenAI**-format API 😆, built on **flask** and **transformers** (using the **Baichuan2-13B-Chat-4bits** model, which runs on a single Tesla T4 GPU), implementing the following **OpenAI** endpoints:
4 | - **Chat** /v1/chat/completions
5 | - **Models** /v1/models
6 | - **Completions** /v1/completions
7 | 
8 | The corresponding STREAMING mode is implemented for each endpoint as well, so basic calls from **langchain** work out of the box.
9 | 
10 | ## Motivation
11 | 
12 | After int4 quantization, Baichuan2-13B-Chat can run on a single Tesla T4 GPU with decent quality and speed, roughly comparable to gpt-3.5.
13 | - **Baichuan2-13B-Chat-4bits**: [https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat-4bits](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat-4bits)
14 | 
15 | ## Minimum requirements
16 | 
17 | You need 16 GB of GPU memory. If your own machine does not have enough, you could consider ~~the Tencent Cloud promotion: 60 RMB for 15 days on a host with 32 GB of RAM and a T4 GPU~~ (**the promotion has been taken down; consider using Colab instead**), which was a great deal 😝 and is enough to run baichuan2-13b-chat-4bits.
18 | 
19 | 
20 | 
21 | ~~Link: https://cloud.tencent.com/act/pro/gpu-study~~
22 | 
23 | If you want to run it locally, a T4 costs around 5600 RMB; you can also consider a 2080 Ti modded to 22 GB, which goes for only about 2600 RMB on Taobao 🤓️.
24 | 
25 | ## Colab
26 | The free Colab tier gives you 12 GB of RAM and a T4 GPU 🤓️, so you can run this on free Colab combined with ngrok.
27 | 
28 | ### Steps
29 | 1. Open [![Colab](https://img.shields.io/badge/✏️-Colab-important)](https://colab.research.google.com/github/billvsme/my_openai_api/blob/main/colab/my_openai_api.ipynb)
30 | 2. Set NGROK_AUTHTOKEN (note: do not wrap it in double quotes)
31 | 3. Change the runtime: click Edit -> Notebook settings and set the GPU to T4
32 | 4. Run: click Runtime -> Run all (downloading the model takes quite a while, please be patient)
33 | 5. When it finishes, find the ngrok public URL in the output and use that URL for your requests
34 | 
35 | 
36 | ## Installation
37 | 
38 | 1. Clone the code
39 | ```
40 | git clone https://github.com/billvsme/my_openai_api.git
41 | ```
42 | 2. Download the Baichuan2-13B-Chat-4bits model
43 | ```
44 | cd my_openai_api
45 | 
46 | git lfs install  # git-lfs must be installed first
47 | git clone https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat-4bits
48 | ```
49 | 3. Set up a venv environment
50 | ```
51 | mkdir ~/.venv
52 | python -m venv ~/.venv/ai
53 | . ~/.venv/ai/bin/activate
54 | 
55 | pip install -r requirements.txt
56 | ```
57 | 
58 | ## Launch
59 | ```
60 | python my_openai_api.py
61 | # or
62 | gunicorn -b 0.0.0.0:5000 --workers=1 my_openai_api:app
63 | ```
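Once the server is up, you can sanity-check the API before wiring it into anything else. The snippets below are only sketches: they assume the default address 127.0.0.1:5000, and the API key is an arbitrary placeholder because this server declares bearer auth in its Swagger spec but does not actually validate it.

```
# coding: utf-8
# Minimal smoke test with requests (already pulled in as a transformers dependency).
import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```

Because the endpoints follow the OpenAI wire format, the official `openai` Python client (v1+ interface; it is not pinned in requirements.txt, so install it separately) should also work when pointed at the local base URL:

```
# coding: utf-8
from openai import OpenAI

# The key is a placeholder; this server does not check it.
client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="test")

# /v1/chat/completions, non-streaming
resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)

# /v1/chat/completions, streaming (SSE)
for chunk in client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Introduce yourself."}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```

Fields this server does not fill in (for example `usage` on chat completions) come back as `None`, and /v1/completions can be exercised the same way through `client.completions.create(...)`.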
64 | 
65 | ## Documentation
66 | Three OpenAI endpoints are implemented: models, chat and completions.
67 | See https://platform.openai.com/docs/api-reference/chat for the API reference.
68 | ```
69 | Open http://127.0.0.1:5000/apidocs/
70 | ```
71 | 
72 | ![github_my_open_api_002.png](https://s2.loli.net/2023/09/25/o8I5GE3ONfhSaqz.png)
73 | 
74 | 
75 | ## Usage
76 | 
77 | ### langchain
78 | 
79 | Just point `openai_api_base` at your own server:
80 | ```
81 | # coding: utf-8
82 | from langchain.llms import OpenAI
83 | from langchain.chat_models import ChatOpenAI
84 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
85 | from langchain.schema import (
86 |     HumanMessage,
87 | )
88 | 
89 | openai_api_base = "http://127.0.0.1:5000/v1"
90 | openai_api_key = "test"
91 | 
92 | # /v1/chat/completions, streaming response
93 | chat_model = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], openai_api_base=openai_api_base, openai_api_key=openai_api_key)
94 | resp = chat_model([HumanMessage(content="给我一个django admin的demo代码")])
95 | chat_model.predict("你叫什么?")
96 | 
97 | # /v1/chat/completions, non-streaming response
98 | chat_model = ChatOpenAI(openai_api_base=openai_api_base, openai_api_key=openai_api_key)
99 | resp = chat_model.predict("给我一个django admin的demo代码")
100 | print(resp)
101 | 
102 | # /v1/completions, streaming response
103 | llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0, openai_api_base=openai_api_base, openai_api_key=openai_api_key)
104 | llm("登鹳雀楼->王之涣\n夜雨寄北->")
105 | 
106 | # /v1/completions, non-streaming response
107 | llm = OpenAI(openai_api_base=openai_api_base, openai_api_key=openai_api_key)
108 | print(llm("登鹳雀楼->王之涣\n夜雨寄北->"))
109 | ```
110 | 
111 | ### ChatGPT Next
112 | 
113 | In the settings, change the API endpoint to your own IP. If the web page is deployed over https, remember to set "Insecure content" to "Allow" in Chrome's settings.
114 | 
115 | 
116 | 
117 | ### OpenAI Translator
118 | 
119 | In the settings, change the api url to your own IP.
120 | 
121 | 
122 | 
123 | 
--------------------------------------------------------------------------------
/colab/my_openai_api.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "nbformat": 4,
3 |   "nbformat_minor": 0,
4 |   "metadata": {
5 |     "colab": {
6 |       "provenance": [],
7 |       "gpuType": "T4"
8 |     },
9 |     "kernelspec": {
10 |       "name": "python3",
11 |       "display_name": "Python 3"
12 |     },
13 |     "language_info": {
14 |       "name": "python"
15 |     },
16 |     "accelerator": "GPU"
17 |   },
18 |   "cells": [
19 |     {
20 |       "cell_type": "code",
21 |       "execution_count": null,
22 |       "metadata": {
23 |         "id": "-50AbTRLCh5D"
24 |       },
25 |       "outputs": [],
26 |       "source": [
27 |         "\"\"\"Download the code\n",
28 |         "\"\"\"\n",
29 |         "%cd /content/\n",
30 |         "!git clone https://github.com/billvsme/my_openai_api.git"
31 |       ]
32 |     },
33 |     {
34 |       "cell_type": "code",
35 |       "source": [
36 |         "\"\"\"Download the model\n",
37 |         "\"\"\"\n",
38 |         "%cd /content/my_openai_api\n",
39 |         "!git clone --depth=1 https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat-4bits"
40 |       ],
41 |       "metadata": {
42 |         "id": "vtbxBPhlC60S"
43 |       },
44 |       "execution_count": null,
45 |       "outputs": []
46 |     },
47 |     {
48 |       "cell_type": "code",
49 |       "source": [
50 |         "\"\"\"Set up the environment\n",
51 |         "\"\"\"\n",
52 |         "%cd /content/my_openai_api\n",
53 |         "\n",
54 |         "!apt install python3.10-venv\n",
55 |         "!mkdir ~/.venv\n",
56 |         "!python -m venv ~/.venv/ai\n",
57 |         "!~/.venv/ai/bin/pip install -r requirements.txt"
58 |       ],
59 |       "metadata": {
60 |         "id": "33opJSj6DVYr"
61 |       },
62 |       "execution_count": null,
63 |       "outputs": []
64 |     },
65 |     {
"cell_type": "code", 67 | "source": [ 68 | "\"\"\"运行\n", 69 | "\"\"\"\n", 70 | "%cd /content/my_openai_api\n", 71 | "# 获取ngrok的tokn,https://dashboard.ngrok.com/get-started/your-authtoken\n", 72 | "# 注意不要带双引号\n", 73 | "%env NGROK_AUTHTOKEN=yourtoken\n", 74 | "!~/.venv/ai/bin/python my_openai_api.py" 75 | ], 76 | "metadata": { 77 | "id": "NKGQRhxkDddZ" 78 | }, 79 | "execution_count": null, 80 | "outputs": [] 81 | } 82 | ] 83 | } -------------------------------------------------------------------------------- /my_openai_api.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import json 3 | import time 4 | import uuid 5 | from threading import Thread 6 | 7 | import torch 8 | from flask import Flask, current_app, request, Blueprint, stream_with_context 9 | from flask_cors import CORS 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from transformers.generation.utils import GenerationConfig 12 | from transformers.generation.streamers import TextIteratorStreamer 13 | from marshmallow import validate 14 | from flasgger import APISpec, Schema, Swagger, fields 15 | from apispec.ext.marshmallow import MarshmallowPlugin 16 | from apispec_webframeworks.flask import FlaskPlugin 17 | 18 | 19 | class Transformers(): 20 | def __init__(self, app=None, tokenizer=None, model=None): 21 | self.chat = None 22 | if app is not None: 23 | self.init_app(app, tokenizer, model) 24 | 25 | def init_app(self, app, tokenizer=None, model=None, chat=None): 26 | self.tokenizer = tokenizer 27 | self.model = model 28 | if chat is None: 29 | self.chat = model.chat 30 | 31 | 32 | tfs = Transformers() 33 | base_tfs = Transformers() 34 | 35 | 36 | models_bp = Blueprint('Models', __name__, url_prefix='/v1/models') 37 | chat_bp = Blueprint('Chat', __name__, url_prefix='/v1/chat') 38 | completions_bp = Blueprint('Completions', __name__, url_prefix='/v1/completions') 39 | 40 | 41 | def sse(line, field="data"): 42 | return "{}: {}\n\n".format( 43 | field, json.dumps(line, ensure_ascii=False) if isinstance(line, dict) else line) 44 | 45 | 46 | def empty_cache(): 47 | if torch.backends.mps.is_available(): 48 | torch.mps.empty_cache() 49 | 50 | 51 | def create_app(): 52 | app = Flask(__name__) 53 | CORS(app) 54 | app.register_blueprint(models_bp) 55 | app.register_blueprint(chat_bp) 56 | app.register_blueprint(completions_bp) 57 | 58 | @app.after_request 59 | def after_request(resp): 60 | empty_cache() 61 | return resp 62 | 63 | # Init Swagger 64 | spec = APISpec( 65 | title='My OpenAI api', 66 | version='0.0.1', 67 | openapi_version='3.0.2', 68 | plugins=[ 69 | FlaskPlugin(), 70 | MarshmallowPlugin(), 71 | ], 72 | ) 73 | 74 | bearer_scheme = {"type": "http", "scheme": "bearer"} 75 | spec.components.security_scheme("bearer", bearer_scheme) 76 | template = spec.to_flasgger( 77 | app, 78 | paths=[list_models, create_chat_completion, create_completion] 79 | ) 80 | 81 | app.config['SWAGGER'] = {"openapi": "3.0.2"} 82 | Swagger(app, template=template) 83 | 84 | # Init transformers 85 | model_name = "./Baichuan2-13B-Chat-4bits" 86 | tokenizer = AutoTokenizer.from_pretrained( 87 | model_name, use_fast=False, trust_remote_code=True) 88 | model = AutoModelForCausalLM.from_pretrained( 89 | model_name, device_map="auto", trust_remote_code=True) 90 | model.generation_config = GenerationConfig.from_pretrained(model_name) 91 | 92 | tfs.init_app(app, tokenizer, model) 93 | base_tfs.init_app(app, tokenizer, model) 94 | 95 | return app 96 | 97 | 98 | class 
ModelSchema(Schema): 99 | id = fields.Str() 100 | object = fields.Str(dump_default="model", metadata={"example": "model"}) 101 | created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567}) 102 | owned_by = fields.Str(dump_default="owner", metadata={"example": "owner"}) 103 | 104 | 105 | class ModelListSchema(Schema): 106 | object = fields.Str(dump_default="list", metadata={"example": "list"}) 107 | data = fields.List(fields.Nested(ModelSchema), dump_default=[]) 108 | 109 | 110 | class ChatMessageSchema(Schema): 111 | role = fields.Str(required=True, metadata={"example": "system"}) 112 | content = fields.Str(required=True, metadata={"example": "You are a helpful assistant."}) 113 | 114 | 115 | class CreateChatCompletionSchema(Schema): 116 | model = fields.Str(required=True, metadata={"example": "gpt-3.5-turbo"}) 117 | messages = fields.List( 118 | fields.Nested(ChatMessageSchema), required=True, 119 | metadata={"example": [ 120 | ChatMessageSchema().dump({"role": "system", "content": "You are a helpful assistant."}), 121 | ChatMessageSchema().dump({"role": "user", "content": "Hello!"}) 122 | ]} 123 | ) 124 | temperature = fields.Float(load_default=1.0, metadata={"example": 1.0}) 125 | top_p = fields.Float(load_default=1.0, metadata={"example": 1.0}) 126 | n = fields.Int(load_default=1, metadata={"example": 1}) 127 | max_tokens = fields.Int(load_default=None, metadata={"example": None}) 128 | stream = fields.Bool(load_default=False, example=False) 129 | presence_penalty = fields.Float(load_default=0.0, example=0.0) 130 | frequency_penalty = fields.Float(load_default=0.0, example=0.0) 131 | 132 | 133 | class ChatCompletionChoiceSchema(Schema): 134 | index = fields.Int(metadata={"example": 0}) 135 | message = fields.Nested(ChatMessageSchema, metadata={ 136 | "example": ChatMessageSchema().dump( 137 | {"role": "assistant", "content": "\n\nHello there, how may I assist you today?"} 138 | )}) 139 | finish_reason = fields.Str( 140 | validate=validate.OneOf(["stop", "length", "content_filter", "function_call"]), 141 | metadata={"example": "stop"}) 142 | 143 | 144 | class ChatCompletionSchema(Schema): 145 | id = fields.Str( 146 | dump_default=lambda: uuid.uuid4().hex, 147 | metadata={"example": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"}) 148 | object = fields.Constant("chat.completion") 149 | created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567}) 150 | model = fields.Str(metadata={"example": "gpt-3.5-turbo"}) 151 | choices = fields.List(fields.Nested(ChatCompletionChoiceSchema)) 152 | 153 | 154 | class ChatDeltaSchema(Schema): 155 | role = fields.Str(metadata={"example": "assistant"}) 156 | content = fields.Str(required=True, metadata={"example": "Hello"}) 157 | 158 | 159 | class ChatCompletionChunkChoiceSchema(Schema): 160 | index = fields.Int(metadata={"example": 0}) 161 | delta = fields.Nested(ChatDeltaSchema, metadata={"example": ChatDeltaSchema().dump( 162 | {"role": "assistant", "example": "Hello"})}) 163 | finish_reason = fields.Str( 164 | validate=validate.OneOf(["stop", "length", "content_filter", "function_call"]), 165 | metadata={"example": "stop"}) 166 | 167 | 168 | class ChatCompletionChunkShema(Schema): 169 | id = fields.Str( 170 | dump_default=lambda: uuid.uuid4().hex, 171 | metadata={"example": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"}) 172 | object = fields.Constant("chat.completion.chunk") 173 | created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567}) 174 | model = 
fields.Str(metadata={"example": "gpt-3.5-turbo"}) 175 | choices = fields.List(fields.Nested(ChatCompletionChunkChoiceSchema)) 176 | 177 | 178 | class CreateCompletionSchema(Schema): 179 | model = fields.Str(required=True, metadata={"example": "gpt-3.5-turbo"}) 180 | prompt = fields.Raw(metadata={"example": "Say this is a test"}) 181 | max_tokens = fields.Int(load_default=16, metadata={"example": 256}) 182 | temperature = fields.Float(load_default=1.0, metadata={"example": 1.0}) 183 | top_p = fields.Float(load_default=1.0, metadata={"example": 1.0}) 184 | n = fields.Int(load_default=1, metadata={"example": 1}) 185 | stream = fields.Bool(load_default=False, example=False) 186 | logit_bias = fields.Dict(load_default=None, example={}) 187 | presence_penalty = fields.Float(load_default=0.0, example=0.0) 188 | frequency_penalty = fields.Float(load_default=0.0, example=0.0) 189 | 190 | 191 | class CompletionChoiceSchema(Schema): 192 | index = fields.Int(load_default=0, metadata={"example": 0}) 193 | text = fields.Str(required=True, metadata={"example": "登鹳雀楼->王之涣\n夜雨寄北->"}) 194 | logprobs = fields.Dict(load_default=None, metadata={"example": {}}) 195 | finish_reason = fields.Str( 196 | validate=validate.OneOf(["stop", "length", "content_filter", "function_call"]), 197 | metadata={"example": "stop"}) 198 | 199 | 200 | class CompletionUsageSchema(Schema): 201 | prompt_tokens = fields.Int(metadata={"example": 5}) 202 | completion_tokens = fields.Int(metadata={"example": 7}) 203 | total_tokens = fields.Int(metadata={"example": 12}) 204 | 205 | 206 | class CompletionSchema(Schema): 207 | id = fields.Str( 208 | dump_default=lambda: uuid.uuid4().hex, 209 | metadata={"example": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7"}) 210 | object = fields.Constant("text_completion") 211 | created = fields.Int(dump_default=lambda: int(time.time()), metadata={"example": 1695402567}) 212 | model = fields.Str(metadata={"example": "gpt-3.5-turbo"}) 213 | choices = fields.List(fields.Nested(CompletionChoiceSchema)) 214 | usage = fields.Nested(CompletionUsageSchema) 215 | 216 | 217 | @models_bp.route("") 218 | def list_models(): 219 | """ 220 | List models 221 | --- 222 | get: 223 | tags: 224 | - Models 225 | description: Lists the currently available models, \ 226 | and provides basic information about each one such as the owner and availability. 
227 | security: 228 | - bearer: [] 229 | responses: 230 | 200: 231 | description: Models returned 232 | content: 233 | application/json: 234 | schema: ModelListSchema 235 | """ 236 | 237 | model = ModelSchema().dump({"id": "gpt-3.5-turbo"}) 238 | return ModelListSchema().dump({"data": [model]}) 239 | 240 | 241 | @stream_with_context 242 | def stream_chat_generate(messages): 243 | delta = ChatDeltaSchema().dump( 244 | {"role": "assistant"}) 245 | choice = ChatCompletionChunkChoiceSchema().dump( 246 | {"index": 0, "delta": delta, "finish_reason": None}) 247 | 248 | yield sse( 249 | ChatCompletionChunkShema().dump({ 250 | "model": "gpt-3.5-turbo", 251 | "choices": [choice]}) 252 | ) 253 | 254 | position = 0 255 | for response in tfs.chat( 256 | tfs.tokenizer, 257 | messages, 258 | stream=True): 259 | content = response[position:] 260 | if not content: 261 | continue 262 | empty_cache() 263 | delta = ChatDeltaSchema().dump( 264 | {"content": content}) 265 | choice = ChatCompletionChunkChoiceSchema().dump( 266 | {"index": 0, "delta": delta, "finish_reason": None}) 267 | 268 | yield sse( 269 | ChatCompletionChunkShema().dump({ 270 | "model": "gpt-3.5-turbo", 271 | "choices": [choice]}) 272 | ) 273 | position = len(response) 274 | 275 | choice = ChatCompletionChunkChoiceSchema().dump( 276 | {"index": 0, "delta": {}, "finish_reason": "stop"}) 277 | 278 | yield sse( 279 | ChatCompletionChunkShema().dump({ 280 | "model": "gpt-3.5-turbo", 281 | "choices": [choice]}) 282 | ) 283 | 284 | yield sse('[DONE]') 285 | 286 | 287 | @chat_bp.route("/completions", methods=['POST']) 288 | def create_chat_completion(): 289 | """Create chat completion 290 | --- 291 | post: 292 | tags: 293 | - Chat 294 | description: Creates a model response for the given chat conversation. 
295 | requestBody: 296 | request: True 297 | content: 298 | application/json: 299 | schema: CreateChatCompletionSchema 300 | security: 301 | - bearer: [] 302 | responses: 303 | 200: 304 | description: ChatCompletion return 305 | content: 306 | application/json: 307 | schema: 308 | oneOf: 309 | - ChatCompletionSchema 310 | - ChatCompletionChunkShema 311 | """ 312 | 313 | create_chat_completion = CreateChatCompletionSchema().load(request.json) 314 | 315 | if create_chat_completion["stream"]: 316 | return current_app.response_class( 317 | stream_chat_generate(create_chat_completion["messages"]), 318 | mimetype="text/event-stream" 319 | ) 320 | else: 321 | response = tfs.chat(tfs.tokenizer, create_chat_completion["messages"]) 322 | 323 | message = ChatMessageSchema().dump( 324 | {"role": "assistant", "content": response}) 325 | choice = ChatCompletionChoiceSchema().dump( 326 | {"index": 0, "message": message, "finish_reason": "stop"}) 327 | return ChatCompletionSchema().dump({ 328 | "model": "gpt-3.5-turbo", 329 | "choices": [choice]}) 330 | 331 | 332 | @stream_with_context 333 | def stream_generate(prompts, **generate_kwargs): 334 | finish_choices = [] 335 | for index, prompt in enumerate(prompts): 336 | choice = CompletionChoiceSchema().dump( 337 | {"index": index, "text": "\n\n", "logprobs": None, "finish_reason": None}) 338 | 339 | yield sse( 340 | CompletionSchema().dump( 341 | {"model": "gpt-3.5-turbo-instruct", "choices": [choice]}) 342 | ) 343 | 344 | inputs = base_tfs.tokenizer(prompt, padding=True, return_tensors='pt') 345 | inputs = inputs.to(base_tfs.model.device) 346 | streamer = TextIteratorStreamer( 347 | base_tfs.tokenizer, 348 | decode_kwargs={"skip_special_tokens": True}) 349 | Thread( 350 | target=base_tfs.model.generate, kwargs=dict( 351 | inputs, streamer=streamer, 352 | repetition_penalty=1.1, **generate_kwargs) 353 | ).start() 354 | 355 | finish_reason = None 356 | for text in streamer: 357 | if not text: 358 | continue 359 | empty_cache() 360 | if text.endswith(base_tfs.tokenizer.eos_token): 361 | finish_reason = "stop" 362 | break 363 | 364 | choice = CompletionChoiceSchema().dump( 365 | {"index": index, "text": text, "logprobs": None, "finish_reason": None}) 366 | 367 | yield sse( 368 | CompletionSchema().dump( 369 | {"model": "gpt-3.5-turbo-instruct", "choices": [choice]}) 370 | ) 371 | else: 372 | finish_reason = "length" 373 | choice = CompletionChoiceSchema().dump( 374 | {"index": index, "text": text, "logprobs": None, "finish_reason": finish_reason}) 375 | yield sse( 376 | CompletionSchema().dump( 377 | {"model": "gpt-3.5-turbo-instruct", "choices": [choice]}) 378 | ) 379 | 380 | choice = CompletionChoiceSchema().dump( 381 | {"index": index, "text": "", "logprobs": None, "finish_reason": finish_reason}) 382 | finish_choices.append(choice) 383 | 384 | yield sse( 385 | CompletionSchema().dump( 386 | {"model": "gpt-3.5-turbo-instruct", "choices": finish_choices}) 387 | ) 388 | 389 | yield sse('[DONE]') 390 | 391 | 392 | @completions_bp.route("", methods=["POST"]) 393 | def create_completion(): 394 | """Create completion 395 | --- 396 | post: 397 | tags: 398 | - Completions 399 | description: Creates a completion for the provided prompt and parameters. 
400 | requestBody: 401 | request: True 402 | content: 403 | application/json: 404 | schema: CreateCompletionSchema 405 | security: 406 | - bearer: [] 407 | responses: 408 | 200: 409 | description: Completion return 410 | content: 411 | application/json: 412 | schema: 413 | CompletionSchema 414 | """ 415 | create_completion = CreateCompletionSchema().load(request.json) 416 | 417 | prompt = create_completion["prompt"] 418 | prompts = prompt if isinstance(prompt, list) else [prompt] 419 | 420 | if create_completion["stream"]: 421 | return current_app.response_class( 422 | stream_generate(prompts, max_new_tokens=create_completion["max_tokens"]), 423 | mimetype="text/event-stream" 424 | ) 425 | else: 426 | choices = [] 427 | prompt_tokens = 0 428 | completion_tokens = 0 429 | for index, prompt in enumerate(prompts): 430 | inputs = base_tfs.tokenizer(prompt, return_tensors='pt') 431 | inputs = inputs.to(base_tfs.model.device) 432 | prompt_tokens += len(inputs["input_ids"][0]) 433 | pred = base_tfs.model.generate( 434 | **inputs, max_new_tokens=create_completion["max_tokens"], repetition_penalty=1.1) 435 | 436 | completion_tokens += len(pred.cpu()[0]) 437 | resp = base_tfs.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False) 438 | 439 | finish_reason = None 440 | if resp.endswith(base_tfs.tokenizer.eos_token): 441 | finish_reason = "stop" 442 | resp = resp[:-len(base_tfs.tokenizer.eos_token)] 443 | else: 444 | finish_reason = "length" 445 | 446 | choices.append( 447 | CompletionChoiceSchema().dump( 448 | {"index": index, "text": resp, "logprobs": {}, "finish_reason": finish_reason}) 449 | ) 450 | usage = CompletionUsageSchema().dump({ 451 | "prompt_tokens": prompt_tokens, 452 | "completion_tokens": completion_tokens, 453 | "total_tokens": prompt_tokens+completion_tokens}) 454 | 455 | return CompletionSchema().dump( 456 | {"model": "gpt-3.5-turbo-instruct", "choices": choices, "usage": usage}) 457 | 458 | 459 | app = create_app() 460 | 461 | if __name__ == '__main__': 462 | try: 463 | import ngrok 464 | import logging 465 | 466 | logging.basicConfig(level=logging.INFO) 467 | listener = ngrok.werkzeug_develop() 468 | except Exception: 469 | pass 470 | 471 | app.run(debug=False, host="0.0.0.0", port=5000) 472 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==2.3.3 2 | Flask-Cors==4.0.0 3 | gunicorn==21.2.0 4 | transformers==4.30.2 5 | torch==2.0.1 6 | torchvision==0.15.2 7 | xformers==0.0.21 8 | flasgger==0.9.7.1 9 | marshmallow==3.20.1 10 | apispec==6.3.0 11 | apispec-webframeworks==0.5.2 12 | accelerate==0.23.0 13 | colorama==0.4.6 14 | bitsandbytes==0.41.1 15 | sentencepiece==0.1.99 16 | streamlit==1.26.0 17 | scipy==1.11.2 18 | ngrok==0.12.0 19 | --------------------------------------------------------------------------------