├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── chatglm ├── __init__.py ├── chat.py └── chatglm.py ├── cloudflared.py ├── config.toml.example ├── context.py ├── embeddings.py ├── img ├── 2023-04-22-08-42-06.png ├── 2023-04-22-08-48-57.png └── 2023-04-22-09-07-18.png ├── main.py ├── phoenix ├── __init__.py ├── chat.py ├── conversation.py └── phoenix.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
 95 | __pypackages__/
 96 |
 97 | # Celery stuff
 98 | celerybeat-schedule
 99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .vscode/
131 |
132 | config.toml
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 |
 3 | Copyright (c) 2023 Tao Yang
 4 |
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # chatglm-openai-api
 2 |
 3 | Provides an OpenAI-style API for ChatGLM-6B/ChatGLM2-6B and Chinese embeddings models.
 4 |
 5 | ## Changelog
 6 |
 7 | - 2023-04-26: Added support for the `FreedomIntelligence/phoenix-inst-chat-7b` model
 8 |   - Load it with `--llm_model phoenix-inst-chat-7b/phoenix-inst-chat-7b-int4`
 9 |
10 | ## Notes
11 |
12 | - The models are hosted on Hugging Face, so you need reliable access to the international internet.
13 | - By default the service runs on GPU + CUDA.
14 |
15 | ## Running in Colab
16 |
17 | ```python
18 | # You must first switch the runtime to a GPU runtime
19 | !git clone https://github.com/ninehills/chatglm-openai-api.git
20 | !cd chatglm-openai-api && cp config.toml.example config.toml
21 | !cd chatglm-openai-api && pip install -r requirements.txt
22 | !cd chatglm-openai-api && python3 main.py --llm_model="chatglm2-6b-int4" --tunnel=ngrok --port 8100
23 | ```
24 |
25 | ## Advanced features
26 |
27 | ### 1. Loading from a local path
28 |
29 | In `config.toml`, set the model's `path` to a local directory:
30 |
31 | ```toml
32 | [models.llm."chatglm-6b-int4"]
33 | type = "chatglm"
34 | path = "{checkpoint_path}"
35 | ```
36 |
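Under the hood, the configured `path` is handed straight to `transformers`, so a local directory works as long as it contains the files downloaded from Hugging Face. The sketch below shows roughly what `chatglm/chatglm.py` does with it on a single GPU; the directory name is a made-up example, not part of this repo.

```python
from transformers import AutoModel, AutoTokenizer

# Hypothetical local checkpoint directory containing the Hugging Face files
# (config.json, tokenizer files, quantized weights, ...).
checkpoint_path = "/data/models/chatglm-6b-int4"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True).half().cuda()
model.eval()
```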
 37 | ### 2. Multi-GPU inference
 38 |
 39 | Use the `CUDA_VISIBLE_DEVICES` environment variable to choose which GPU cards to run on, and set how many GPUs to use (currently this only applies to the LLM model), for example:
 40 |
 41 | ```bash
 42 | CUDA_VISIBLE_DEVICES=0,1 python main.py --port 8080 --llm_model chatglm-6b-int4 --tunnel ngrok --gpus 2
 43 | ```
 44 |
 45 | ## Running locally (ngrok tunnel, for testing)
 46 |
 47 | > Note: a free ngrok tunnel cannot use a custom domain, only a dynamically assigned one, so it is suitable for demos only.
 48 | > Configure the ngrok token and subdomain in config.toml.
 49 |
 50 | ```bash
 51 | # First create a virtual environment
 52 | python3 -m venv .venv
 53 | source .venv/bin/activate
 54 |
 55 | # Install dependencies
 56 | pip install -r requirements.txt
 57 |
 58 | # Copy the config file
 59 | cp config.toml.example config.toml
 60 |
 61 | # Use CUDA_VISIBLE_DEVICES to choose which GPU to run on
 62 | # llm_model supports chatglm-6b, chatglm-6b-int8 and chatglm-6b-int4, in decreasing order of VRAM usage.
 63 | CUDA_VISIBLE_DEVICES=0 python main.py --port 8080 --llm_model chatglm-6b-int4 --tunnel ngrok
 64 |
 65 | # To also serve an embeddings model, add the --embeddings_model argument
 66 | CUDA_VISIBLE_DEVICES=0 python main.py --port 8080 --llm_model chatglm-6b-int4 --embeddings_model text2vec-large-chinese --tunnel ngrok
 67 |
 68 | # To keep the API running in the background, use nohup
 69 | CUDA_VISIBLE_DEVICES=0 nohup python main.py > nohup.out 2>&1 &
 70 | ```
 71 |
 72 | Once it is running, open the ngrok tunnel URL that is printed and the API is ready to use; the root endpoint returns `{"Hello": "World!"}` by default, and the endpoints are compatible with the OpenAI API.
 73 |
 74 | ```bash
 75 | # https://platform.openai.com/docs/api-reference/chat/create
 76 | export CHATGLM_API_KEY=token1 # API keys are configured in config.toml
 77 | curl https://<your-tunnel-domain>/v1/chat/completions \
 78 |   -H "Content-Type: application/json" \
 79 |   -H "Authorization: Bearer $CHATGLM_API_KEY" \
 80 |   -d '{
 81 |     "model": "gpt-3.5-turbo",
 82 |     "messages": [{"role": "user", "content": "Hello!"}]
 83 |   }'
 84 | ```
 85 |
 86 | ## Running locally (Cloudflare tunnel, recommended)
 87 |
 88 | Prerequisite: you already have a domain bound to Cloudflare, with DNS resolution configured.
 89 |
 90 | First install the Cloudflare tunnel client:
 91 |
 92 | ```bash
 93 | # https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/local/
 94 |
 95 | # Assume cloudflared is already installed at the path `.cloudflared`
 96 | # First log in to Cloudflare
 97 | ./cloudflared tunnel login
 98 | # Here you need to choose the domain the tunnel will be bound to
 99 | ./cloudflared tunnel create chatglm-openai-api
100 | # Bind the tunnel to a subdomain of your custom domain; chatglm-openai-api.ninehills.tech here is the subdomain you chose, and it is the URL you will visit later.
101 | ./cloudflared tunnel route dns chatglm-openai-api chatglm-openai-api.ninehills.tech
102 | ```
103 |
104 | Then run the API:
105 |
106 | ```bash
107 | CUDA_VISIBLE_DEVICES=0 python main.py --port 8080 --llm_model chatglm-6b-int4 --embeddings_model text2vec-large-chinese --tunnel cloudflare
108 | ```
109 |
110 | Now you can use the API at `https://chatglm-openai-api.ninehills.tech`.
111 |
112 | ## Common client configuration
113 |
114 | ### OpenCat
115 |
116 | ![](img/2023-04-22-08-42-06.png)
117 |
118 |
119 | ### Chatbot-UI
120 |
121 | 1. Fork `https://github.com/ninehills/chatbot-ui` (a fork with the system prompt removed) to your own account
122 | 2. Sign up for a `https://vercel.com/` account
123 | 3. `Add new` - `Project` - `Import Git Repository`, then select your fork
124 | 4. In the environment variables section, fill in
125 |    - `OPENAI_API_KEY=token1`, where token1 is your API key
126 |    - `OPENAI_API_HOST=https://chatglm-openai-api.ninehills.tech`, where chatglm-openai-api.ninehills.tech is your domain
127 |    - ![](img/2023-04-22-08-48-57.png)
128 | 5. Click `Deploy` to deploy
129 | 6. Once the deployment finishes, click `Visit` and you can start using it.
130 |
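Because the endpoints mirror the OpenAI API, the service can also be called from the `openai` Python package. The snippet below is a minimal sketch using the 0.x client interface; the domain and key are the placeholders used earlier in this README, so substitute your own values.

```python
import openai

# Point the client at your tunnel domain and the token from config.toml.
openai.api_key = "token1"
openai.api_base = "https://chatglm-openai-api.ninehills.tech/v1"

resp = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # the server answers with whichever LLM it loaded, regardless of this name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp["choices"][0]["message"]["content"])
```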
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | # -*- coding: utf-8 -*-
  3 | import json
  4 | from typing import List, Optional, Any
  5 |
  6 | from fastapi import FastAPI, HTTPException, Request, status, BackgroundTasks
  7 | from fastapi.middleware.cors import CORSMiddleware
  8 | from fastapi.responses import JSONResponse
  9 | from pydantic import BaseModel
 10 | from sse_starlette.sse import EventSourceResponse
 11 |
 12 | from context import context
 13 | from utils import torch_gc
 14 |
 15 | app = FastAPI()
 16 |
 17 | app.add_middleware(
 18 |     CORSMiddleware,
 19 |     allow_origins=['*'],
 20 |     allow_credentials=True,
 21 |     allow_methods=['*'],
 22 |     allow_headers=['*'],
 23 | )
 24 |
 25 |
 26 | class Message(BaseModel):
 27 |     role: str
 28 |     content: str
 29 |
 30 |
 31 | class ChatBody(BaseModel):
 32 |     messages: List[Message]
 33 |     model: str
 34 |     stream: Optional[bool] = False
 35 |     max_tokens: Optional[int]
 36 |     temperature: Optional[float]
 37 |     top_p: Optional[float]
 38 |
 39 |
 40 | class CompletionBody(BaseModel):
 41 |     prompt: str
 42 |     model: str
 43 |     stream: Optional[bool] = False
 44 |     max_tokens: Optional[int]
 45 |     temperature: Optional[float]
 46 |     top_p: Optional[float]
 47 |
 48 |
 49 | class EmbeddingsBody(BaseModel):
 50 |     # Python 3.8 does not support str | List[str]
 51 |     input: Any
 52 |     model: Optional[str]
 53 |
 54 |
 55 | @app.get("/")
 56 | def read_root():
 57 |     return {"Hello": "World!"}
 58 |
 59 |
 60 | @app.get("/v1/models")
 61 | def get_models():
 62 |     ret = {"data": [], "object": "list"}
 63 |
 64 |     if context.model:
 65 |         ret['data'].append({
 66 |             "created": 1677610602,
 67 |             "id": "gpt-3.5-turbo",
 68 |             "object": "model",
 69 |             "owned_by": "openai",
 70 |             "permission": [
 71 |                 {
 72 |                     "created": 1680818747,
 73 |                     "id": "modelperm-fTUZTbzFp7uLLTeMSo9ks6oT",
 74 |                     "object": "model_permission",
 75 |                     "allow_create_engine": False,
 76 |                     "allow_sampling": True,
 77 |                     "allow_logprobs": True,
 78 |                     "allow_search_indices": False,
 79 |                     "allow_view": True,
 80 |                     "allow_fine_tuning": False,
 81 |                     "organization": "*",
 82 |                     "group": None,
 83 |                     "is_blocking": False
 84 |                 }
 85 |             ],
 86 |             "root": "gpt-3.5-turbo",
 87 |             "parent": None,
 88 |         })
 89 |     if context.embeddings_model:
 90 |         ret['data'].append({
 91 |             "created": 1671217299,
 92 |             "id": "text-embedding-ada-002",
 93 |             "object": "model",
 94 |             "owned_by": "openai-internal",
 95 |             "permission": [
 96 |                 {
 97 |                     "created": 1678892857,
 98 |                     "id": "modelperm-Dbv2FOgMdlDjO8py8vEjD5Mi",
 99 |                     "object": "model_permission",
100 |                     "allow_create_engine": False,
101 |                     "allow_sampling": True,
102 |                     "allow_logprobs": True,
103 |                     "allow_search_indices": True,
104 |                     "allow_view": True,
105 |                     "allow_fine_tuning": False,
106 |                     "organization": "*",
107 |                     "group": None,
108 |                     "is_blocking": False
109 |                 }
110 |             ],
111 |             "root": "text-embedding-ada-002",
112 |             "parent": ""
113 |         })
114 |
115 |     return ret
116 |
117 |
118 | def generate_response(content: str, chat: bool = True):
119 |     if chat:
120 |         return {
121 |             "id": "chatcmpl-77PZm95TtxE0oYLRx3cxa6HtIDI7s",
122 |             "object": "chat.completion",
123 |             "created": 1682000966,
124 |             "model": "gpt-3.5-turbo-0301",
125 |             "usage": {
126 |                 "prompt_tokens": 0,
127 |                 "completion_tokens": 0,
128 |                 "total_tokens": 0,
129 |             },
130 |             "choices": [{
131 |                 "message": {"role": "assistant", "content": content},
132 |                 "finish_reason": "stop", "index": 0}
133 |             ]
134 |         }
135 |     else:
136 |         return {
137 | 
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", 138 | "object": "text_completion", 139 | "created": 1589478378, 140 | "model": "text-davinci-003", 141 | "choices": [ 142 | { 143 | "text": content, 144 | "index": 0, 145 | "logprobs": None, 146 | "finish_reason": "stop" 147 | } 148 | ], 149 | "usage": { 150 | "prompt_tokens": 0, 151 | "completion_tokens": 0, 152 | "total_tokens": 0 153 | } 154 | } 155 | 156 | 157 | def generate_stream_response_start(): 158 | return { 159 | "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB", 160 | "object": "chat.completion.chunk", "created": 1682004627, 161 | "model": "gpt-3.5-turbo-0301", 162 | "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}] 163 | } 164 | 165 | 166 | 167 | def generate_stream_response(content: str, chat: bool = True): 168 | if chat: 169 | return { 170 | "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB", 171 | "object": "chat.completion.chunk", 172 | "created": 1682004627, 173 | "model": "gpt-3.5-turbo-0301", 174 | "choices": [{"delta": {"content": content}, "index": 0, "finish_reason": None} 175 | ]} 176 | else: 177 | return { 178 | "id":"cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj", 179 | "object":"text_completion", 180 | "created":1684208299, 181 | "choices":[ 182 | { 183 | "text": content, 184 | "index": 0, 185 | "logprobs": None, 186 | "finish_reason": None, 187 | } 188 | ], 189 | "model": "text-davinci-003" 190 | } 191 | 192 | 193 | def generate_stream_response_stop(chat: bool = True): 194 | if chat: 195 | return {"id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB", 196 | "object": "chat.completion.chunk", "created": 1682004627, 197 | "model": "gpt-3.5-turbo-0301", 198 | "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}] 199 | } 200 | else: 201 | return { 202 | "id":"cmpl-7GfnvmcsDmmTVbPHmTBcNqlMtaEVj", 203 | "object":"text_completion", 204 | "created":1684208299, 205 | "choices":[ 206 | {"text":"","index":0,"logprobs":None,"finish_reason":"stop"}], 207 | "model":"text-davinci-003", 208 | } 209 | 210 | @app.post("/v1/embeddings") 211 | async def embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks): 212 | return do_embeddings(body, request, background_tasks) 213 | 214 | 215 | def do_embeddings(body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks): 216 | background_tasks.add_task(torch_gc) 217 | if request.headers.get("Authorization").split(" ")[1] not in context.tokens: 218 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!") 219 | 220 | if not context.embeddings_model: 221 | raise HTTPException(status.HTTP_404_NOT_FOUND, "Embeddings model not found!") 222 | 223 | embeddings = context.embeddings_model.encode(body.input) 224 | data = [] 225 | if isinstance(body.input, str): 226 | data.append({ 227 | "object": "embedding", 228 | "index": 0, 229 | "embedding": embeddings.tolist(), 230 | }) 231 | else: 232 | for i, embed in enumerate(embeddings): 233 | data.append({ 234 | "object": "embedding", 235 | "index": i, 236 | "embedding": embed.tolist(), 237 | }) 238 | content = { 239 | "object": "list", 240 | "data": data, 241 | "model": "text-embedding-ada-002-v2", 242 | "usage": { 243 | "prompt_tokens": 0, 244 | "total_tokens": 0 245 | } 246 | } 247 | return JSONResponse(status_code=200, content=content) 248 | 249 | 250 | @app.post("/v1/engines/{engine}/embeddings") 251 | async def engines_embeddings(engine: str, body: EmbeddingsBody, request: Request, background_tasks: BackgroundTasks): 252 | return do_embeddings(body, request, background_tasks) 253 | 254 | 
255 | @app.post("/v1/chat/completions") 256 | async def chat_completions(body: ChatBody, request: Request, background_tasks: BackgroundTasks): 257 | background_tasks.add_task(torch_gc) 258 | if request.headers.get("Authorization").split(" ")[1] not in context.tokens: 259 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!") 260 | 261 | if not context.model: 262 | raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!") 263 | question = body.messages[-1] 264 | if question.role == 'user': 265 | question = question.content 266 | else: 267 | raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found") 268 | 269 | history = [] 270 | user_question = '' 271 | for message in body.messages: 272 | if message.role == 'system': 273 | history.append((message.content, "OK")) 274 | if message.role == 'user': 275 | user_question = message.content 276 | elif message.role == 'assistant': 277 | assistant_answer = message.content 278 | history.append((user_question, assistant_answer)) 279 | 280 | print(f"question = {question}, history = {history}") 281 | 282 | if body.stream: 283 | async def eval_llm(): 284 | first = True 285 | for response in context.model.do_chat_stream( 286 | context.model, context.tokenizer, question, history, { 287 | "temperature": body.temperature, 288 | "top_p": body.top_p, 289 | "max_tokens": body.max_tokens, 290 | }): 291 | if first: 292 | first = False 293 | yield json.dumps(generate_stream_response_start(), 294 | ensure_ascii=False) 295 | yield json.dumps(generate_stream_response(response), ensure_ascii=False) 296 | yield json.dumps(generate_stream_response_stop(), ensure_ascii=False) 297 | yield "[DONE]" 298 | return EventSourceResponse(eval_llm(), ping=10000) 299 | else: 300 | response = context.model.do_chat(context.model, context.tokenizer, question, history, { 301 | "temperature": body.temperature, 302 | "top_p": body.top_p, 303 | "max_tokens": body.max_tokens, 304 | }) 305 | return JSONResponse(content=generate_response(response)) 306 | 307 | 308 | @app.post("/v1/completions") 309 | async def completions(body: CompletionBody, request: Request, background_tasks: BackgroundTasks): 310 | background_tasks.add_task(torch_gc) 311 | if request.headers.get("Authorization").split(" ")[1] not in context.tokens: 312 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!") 313 | 314 | if not context.model: 315 | raise HTTPException(status.HTTP_404_NOT_FOUND, "LLM model not found!") 316 | question = body.prompt 317 | 318 | print(f"question = {question}") 319 | 320 | if body.stream: 321 | async def eval_llm(): 322 | for response in context.model.do_chat_stream( 323 | context.model, context.tokenizer, question, [], { 324 | "temperature": body.temperature, 325 | "top_p": body.top_p, 326 | "max_tokens": body.max_tokens, 327 | }): 328 | yield json.dumps(generate_stream_response(response, chat=False), ensure_ascii=False) 329 | yield json.dumps(generate_stream_response_stop(chat=False), ensure_ascii=False) 330 | yield "[DONE]" 331 | return EventSourceResponse(eval_llm(), ping=10000) 332 | else: 333 | response = context.model.do_chat(context.model, context.tokenizer, question, [], { 334 | "temperature": body.temperature, 335 | "top_p": body.top_p, 336 | "max_tokens": body.max_tokens, 337 | }) 338 | return JSONResponse(content=generate_response(response, chat=False)) 339 | -------------------------------------------------------------------------------- /chatglm/__init__.py: 
--------------------------------------------------------------------------------
 1 | from .chatglm import init_chatglm
 2 |
--------------------------------------------------------------------------------
/chatglm/chat.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python3
 2 | # -*- coding: utf-8 -*-
 3 |
 4 |
 5 | def init_model_args(model_args=None):
 6 |     if model_args is None:
 7 |         model_args = {}
 8 |     model_args['temperature'] = model_args['temperature'] if model_args.get('temperature') is not None else 0.95
 9 |     if model_args['temperature'] <= 0:
10 |         model_args['temperature'] = 0.1
11 |     if model_args['temperature'] > 1:
12 |         model_args['temperature'] = 1
13 |     model_args['top_p'] = model_args['top_p'] if model_args.get('top_p') else 0.7
14 |     model_args['max_tokens'] = model_args['max_tokens'] if model_args.get('max_tokens') is not None else 512
15 |
16 |     return model_args
17 |
18 | def do_chat_stream(model, tokenizer, question, history, model_args=None):
19 |     model_args = init_model_args(model_args)
20 |     sends = 0
21 |     for response, _ in model.stream_chat(
22 |             tokenizer, question, history,
23 |             temperature=model_args['temperature'],
24 |             top_p=model_args['top_p'],
25 |             max_length=max(2048, model_args['max_tokens'])):
26 |         ret = response[sends:]
27 |         # https://github.com/THUDM/ChatGLM-6B/issues/478
28 |         # Work around incomplete emoji output: skip chunks that end in the U+FFFD replacement character
29 |         if "\uFFFD" == ret[-1:]:
30 |             continue
31 |         sends = len(response)
32 |
33 |         yield ret
34 |
35 |
36 | def do_chat(model, tokenizer, question, history, model_args=None):
37 |     model_args = init_model_args(model_args)
38 |     response, _ = model.chat(
39 |         tokenizer, question, history,
40 |         temperature=model_args['temperature'],
41 |         top_p=model_args['top_p'],
42 |         max_length=max(2048, model_args['max_tokens']))
43 |     return response
44 |
--------------------------------------------------------------------------------
/chatglm/chatglm.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | # coding=utf-8
 3 | ## From: https://github.com/THUDM/ChatGLM-6B
 4 | import torch
 5 | import os
 6 | from typing import Dict, Union, Optional
 7 |
 8 | from torch.nn import Module
 9 | from transformers import AutoModel, AutoTokenizer
10 |
11 | from .chat import do_chat, do_chat_stream
12 |
13 | def init_chatglm(model_path: str, running_device: str, gpus: int):
14 |     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
15 |
16 |     if running_device.upper() == "GPU":
17 |         model = load_model_on_gpus(model_path, gpus)
18 |     else:
19 |         model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
20 |         model = model.float()
21 |
22 |     model.eval()
23 |     model.do_chat = do_chat
24 |     model.do_chat_stream = do_chat_stream
25 |     return tokenizer, model
26 |
27 |
28 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
29 |     # transformer.word_embeddings counts as 1 layer
30 |     # transformer.final_layernorm and lm_head count as 1 layer
31 |     # transformer.layers contributes 28 layers
32 |     # 30 layers in total are distributed across num_gpus cards
33 |     num_trans_layers = 28
34 |     per_gpu_layers = 30 / num_gpus
35 |
36 |     # bugfix: on Linux, torch.embedding can receive a weight and input on different devices, raising RuntimeError
37 |     # on Windows, model.device is set to transformer.word_embeddings.device
38 |     # on Linux, model.device is set to lm_head.device
39 |     # when chat or stream_chat is called, input_ids is moved onto model.device
40 |     # if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised
41 |     # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first card
42 |     device_map = 
{'transformer.word_embeddings': 0, 43 | 'transformer.final_layernorm': 0, 'lm_head': 0} 44 | 45 | used = 2 46 | gpu_target = 0 47 | for i in range(num_trans_layers): 48 | if used >= per_gpu_layers: 49 | gpu_target += 1 50 | used = 0 51 | assert gpu_target < num_gpus 52 | device_map[f'transformer.layers.{i}'] = gpu_target 53 | used += 1 54 | 55 | return device_map 56 | 57 | 58 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2, 59 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module: 60 | if num_gpus < 2 and device_map is None: 61 | model = AutoModel.from_pretrained( 62 | checkpoint_path, trust_remote_code=True, **kwargs).half().cuda() 63 | else: 64 | if num_gpus > torch.cuda.device_count(): 65 | raise Exception(f"need {num_gpus} GPU, but only has {torch.cuda.device_count()}") 66 | 67 | from accelerate import dispatch_model 68 | 69 | model = AutoModel.from_pretrained( 70 | checkpoint_path, trust_remote_code=True, **kwargs).half() 71 | 72 | if device_map is None: 73 | device_map = auto_configure_device_map(num_gpus) 74 | 75 | model = dispatch_model(model, device_map=device_map) 76 | print(f"Device Map: {model.hf_device_map}\n") 77 | 78 | return model 79 | -------------------------------------------------------------------------------- /cloudflared.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import atexit 4 | import subprocess 5 | 6 | from threading import Timer 7 | 8 | 9 | def start_cloudflared(command, name, port): 10 | cloudflared = subprocess.Popen( 11 | [command, 'tunnel', '--url', 'http://127.0.0.1:' + 12 | str(port) + '/.', 'run', name], 13 | stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) 14 | atexit.register(cloudflared.terminate) 15 | 16 | 17 | def run(command, name, port): 18 | # Starting the Cloudflared tunnel in a separate thread. 19 | thread = Timer(2, start_cloudflared, args=(command, name, port,)) 20 | thread.setDaemon(True) 21 | thread.start() 22 | -------------------------------------------------------------------------------- /config.toml.example: -------------------------------------------------------------------------------- 1 | [models] 2 | [models.llm] 3 | [models.llm."chatglm-6b"] 4 | type = "chatglm" 5 | path = "THUDM/chatglm-6b" 6 | [models.llm."chatglm-6b-int8"] 7 | type = "chatglm" 8 | path = "THUDM/chatglm-6b-int8" 9 | [models.llm."chatglm-6b-int4"] 10 | type = "chatglm" 11 | path = "THUDM/chatglm-6b-int4" 12 | [models.llm."chatglm2-6b"] 13 | type = "chatglm" 14 | path = "THUDM/chatglm2-6b" 15 | [models.llm."chatglm2-6b-int8"] 16 | type = "chatglm" 17 | path = "THUDM/chatglm2-6b-int8" 18 | [models.llm."chatglm2-6b-int4"] 19 | type = "chatglm" 20 | path = "THUDM/chatglm2-6b-int4" 21 | [models.llm."phoenix-inst-chat-7b"] 22 | type = "phoenix" 23 | path = "FreedomIntelligence/phoenix-inst-chat-7b" 24 | [models.llm."phoenix-inst-chat-7b-int4"] 25 | type = "phoenix" 26 | path = "FreedomIntelligence/phoenix-inst-chat-7b-int4" 27 | 28 | [models.embeddings] 29 | [models.embeddings."text2vec-large-chinese"] 30 | type = "default" 31 | path = "GanymedeNil/text2vec-large-chinese" 32 | 33 | [auth] 34 | tokens = ["token1"] 35 | 36 | [tunnel] 37 | [tunnel.ngrok] 38 | token = "" 39 | region = "jp" 40 | # Binding custom subdomains is a feature for paid accounts. 
41 | subdomain = "" 42 | [tunnel.cloudflare] 43 | # first need init cloudflare tunnel, see README.md 44 | cloudflared_path = "/usr/local/bin/cloudflared" 45 | # tunnel name, see README.md 46 | name = "chatglm-openai-api" 47 | -------------------------------------------------------------------------------- /context.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from dataclasses import dataclass 5 | from typing import List 6 | 7 | @dataclass 8 | class Context: 9 | llm_model_type: str 10 | model: any 11 | tokenizer: any 12 | embeddings_model: any 13 | 14 | tokens: List[str] 15 | 16 | 17 | context = Context(None, None, None, None, []) 18 | 19 | -------------------------------------------------------------------------------- /embeddings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from text2vec import SentenceModel 4 | 5 | def load_embeddings_model(model_path: str, device: str): 6 | if device == "gpu": 7 | device = "cuda" 8 | model = SentenceModel(model_path, max_seq_length=1024, device=device) 9 | return model -------------------------------------------------------------------------------- /img/2023-04-22-08-42-06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ninehills/chatglm-openai-api/f9e068c252aca60588e002f53375af0226bd9ae0/img/2023-04-22-08-42-06.png -------------------------------------------------------------------------------- /img/2023-04-22-08-48-57.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ninehills/chatglm-openai-api/f9e068c252aca60588e002f53375af0226bd9ae0/img/2023-04-22-08-48-57.png -------------------------------------------------------------------------------- /img/2023-04-22-09-07-18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ninehills/chatglm-openai-api/f9e068c252aca60588e002f53375af0226bd9ae0/img/2023-04-22-09-07-18.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import argparse 4 | import os 5 | import sys 6 | 7 | import toml 8 | import uvicorn 9 | 10 | from context import context 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser( 15 | description='Start LLM and Embeddings models as a service.') 16 | 17 | parser.add_argument('--config', type=str, help='Path to the config file', 18 | default='config.toml') 19 | parser.add_argument('--llm_model', type=str, help='Choosed LLM model', 20 | default='chatglm-6b-int4') 21 | parser.add_argument('--embeddings_model', type=str, 22 | help='Choosed embeddings model, can be empty', 23 | default='') 24 | parser.add_argument('--device', type=str, 25 | help='Device to run the service, gpu/cpu/mps', 26 | default='gpu') 27 | parser.add_argument('--gpus', type=int, help='Use how many gpus, default 1', 28 | default=1) 29 | parser.add_argument('--port', type=int, help='Port number to run the service', 30 | default=8080) 31 | parser.add_argument('--tunnel', type=str, help='Remote tunnel for public visit, default not set', 32 | default="") 33 | 34 | args = parser.parse_args() 35 | 36 | print("> Load config and arguments...") 37 | print(f"Config file: 
{args.config}") 38 | print(f"Language Model: {args.llm_model}") 39 | print(f"Embeddings Model: {args.embeddings_model}") 40 | print(f"Device: {args.device}") 41 | print(f"GPUs: {args.gpus}") 42 | print(f"Port: {args.port}") 43 | print(f"Tunneling: {args.tunnel}") 44 | 45 | with open(args.config) as f: 46 | config = toml.load(f) 47 | print(f"Config: \n{config}") 48 | context.tokens = config['auth']['tokens'] 49 | 50 | if args.llm_model: 51 | print(f"> Start LLM model {args.llm_model}") 52 | if args.llm_model not in config['models']['llm']: 53 | print(f"LLM model {args.llm_model} not found in config file") 54 | sys.exit(1) 55 | 56 | llm = config['models']['llm'][args.llm_model] 57 | context.llm_model_type = llm['type'] 58 | if llm['type'] == 'chatglm': 59 | print(f">> Use chatglm llm model {llm['path']}") 60 | from chatglm import init_chatglm 61 | context.tokenizer, context.model = init_chatglm( 62 | llm['path'], args.device, args.gpus) 63 | elif llm['type'] == 'phoenix': 64 | print(f">> Use phoenix llm model {llm['path']}") 65 | from phoenix import init_phoenix 66 | context.tokenizer, context.model = init_phoenix( 67 | llm['path'], args.device, args.gpus) 68 | else: 69 | print(f"Unsupported LLM model type {llm['type']}") 70 | sys.exit(1) 71 | 72 | if args.embeddings_model: 73 | print(f"> Start Embeddings model {args.embeddings_model}") 74 | if args.embeddings_model not in config['models']['embeddings']: 75 | print( 76 | f"Embeddings model {args.embeddings_model} not found in config file") 77 | sys.exit(1) 78 | 79 | embeddings = config['models']['embeddings'][args.embeddings_model] 80 | if embeddings['type'] == 'default': 81 | print(f">> Use default embeddings model {embeddings['path']}") 82 | from embeddings import load_embeddings_model 83 | context.embeddings_model = load_embeddings_model( 84 | embeddings['path'], args.device) 85 | else: 86 | print(f"Unsupported Embeddings model type {embeddings['type']}") 87 | sys.exit(1) 88 | 89 | print("> Start API server...") 90 | if args.tunnel: 91 | print(">> Enable remote tunneling...") 92 | if args.tunnel not in config['tunnel']: 93 | print(f"Tunneling {args.tunnel} not found in config file") 94 | sys.exit(1) 95 | if args.tunnel == "ngrok": 96 | print(">>> Start ngrok tunneling...") 97 | from pyngrok import ngrok, conf 98 | conf.get_default().region = config['tunnel']['ngrok']['region'] 99 | if config['tunnel']['ngrok']['token']: 100 | ngrok.set_auth_token(config['tunnel']['ngrok']['token']) 101 | subdomain = config['tunnel']['ngrok']['subdomain'] or None 102 | http_tunnel = ngrok.connect(args.port, subdomain=subdomain) 103 | print(f">> Public URL: {http_tunnel.public_url}") 104 | if args.tunnel == "cloudflare": 105 | print(">>> Start cloudflare tunnel..") 106 | from cloudflared import run 107 | command = config['tunnel']['cloudflare']['cloudflared_path'] \ 108 | or "cloudflared" 109 | run(command, config['tunnel']['cloudflare']['name'], args.port) 110 | 111 | from app import app 112 | uvicorn.run(app, host="0.0.0.0", port=args.port) 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /phoenix/__init__.py: -------------------------------------------------------------------------------- 1 | from .phoenix import init_phoenix 2 | -------------------------------------------------------------------------------- /phoenix/chat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | from 
.conversation import get_default_conv_template, SeparatorStyle 4 | 5 | def init_model_args(model_args = None): 6 | if model_args is None: 7 | model_args = {} 8 | model_args['temperature'] = model_args['temperature'] if model_args.get('temperature') != None else 0.7 9 | model_args['max_tokens'] = model_args['max_tokens'] if model_args.get('max_tokens') != None else 512 10 | 11 | return model_args 12 | 13 | 14 | def do_chat(model, tokenizer, question, history, model_args = None): 15 | ret = "" 16 | for char in do_chat_stream(model, tokenizer, question, history, model_args): 17 | ret += char 18 | return ret 19 | 20 | 21 | def do_chat_stream(model, tokenizer, question, history, model_args = None): 22 | model_args = init_model_args(model_args) 23 | conv = get_default_conv_template().copy() 24 | 25 | for (human, ai) in history: 26 | conv.append_message(conv.roles[0], human) 27 | # NOTE: strip is important to align with the training data. 28 | conv.append_message(conv.roles[1], ai.strip()) 29 | conv.append_message(conv.roles[0], question) 30 | conv.append_message(conv.roles[1], None) 31 | 32 | generate_stream_func = generate_stream 33 | prompt = conv.get_prompt() 34 | 35 | params = { 36 | "model": model, 37 | "prompt": prompt, 38 | "temperature": model_args['temperature'], 39 | "max_new_tokens": model_args['max_tokens'], 40 | "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else None, 41 | } 42 | 43 | output_stream = generate_stream_func(model, tokenizer, params, model.running_device) 44 | 45 | pre = 0 46 | for outputs in output_stream: 47 | now = len(outputs) - 1 48 | if now > pre: 49 | yield(outputs[pre:now]) 50 | pre = now 51 | yield(outputs[pre:]) 52 | 53 | 54 | @torch.inference_mode() 55 | def generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2): 56 | prompt = params["prompt"] 57 | temperature = float(params.get("temperature", 1.0)) 58 | max_new_tokens = int(params.get("max_new_tokens", 256)) 59 | stop_str = params.get("stop", None) 60 | stop_token_ids = params.get("stop_ids", [tokenizer.eos_token_id]) 61 | 62 | input_ids = tokenizer(prompt).input_ids 63 | output_ids = list(input_ids) 64 | 65 | l_prompt = len(tokenizer.decode(input_ids, skip_special_tokens=False)) 66 | 67 | max_src_len = context_len - max_new_tokens - 8 68 | input_ids = input_ids[-max_src_len:] 69 | 70 | for i in range(max_new_tokens): 71 | if i == 0: 72 | if model.config.is_encoder_decoder: 73 | encoder_outputs = model.encoder( 74 | input_ids=torch.as_tensor([input_ids], device=device) 75 | ) 76 | out = model( 77 | torch.as_tensor([input_ids], device=device), 78 | decoder_input_ids=torch.as_tensor( 79 | [[model.generation_config.decoder_start_token_id]], 80 | device=device, 81 | ), 82 | encoder_outputs=encoder_outputs, 83 | use_cache=True, 84 | ) 85 | logits = out.logits 86 | past_key_values = out.past_key_values 87 | else: 88 | out = model(torch.as_tensor([input_ids], device=device), use_cache=True) 89 | logits = out.logits 90 | past_key_values = out.past_key_values 91 | else: 92 | if model.config.is_encoder_decoder: 93 | out = model( 94 | input_ids=torch.as_tensor([input_ids], device=device), 95 | use_cache=True, 96 | encoder_outputs=encoder_outputs, 97 | decoder_input_ids=torch.as_tensor([[token]], device=device), 98 | past_key_values=past_key_values, 99 | ) 100 | logits = out.logits 101 | past_key_values = out.past_key_values 102 | else: 103 | out = model( 104 | input_ids=torch.as_tensor([[token]], device=device), 105 | use_cache=True, 106 | past_key_values=past_key_values, 
107 | ) 108 | logits = out.logits 109 | past_key_values = out.past_key_values 110 | 111 | last_token_logits = logits[0][-1] 112 | 113 | if device == "mps": 114 | # Switch to CPU by avoiding some bugs in mps backend. 115 | last_token_logits = last_token_logits.float().to("cpu") 116 | 117 | if temperature < 1e-4: 118 | token = int(torch.argmax(last_token_logits)) 119 | else: 120 | probs = torch.softmax(last_token_logits / temperature, dim=-1) 121 | token = int(torch.multinomial(probs, num_samples=1)) 122 | 123 | output_ids.append(token) 124 | 125 | if token in stop_token_ids: 126 | stopped = True 127 | else: 128 | stopped = False 129 | 130 | if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: 131 | output = tokenizer.decode(output_ids, skip_special_tokens=False) 132 | if stop_str: 133 | pos = output.rfind(stop_str, l_prompt) 134 | if pos != -1: 135 | output = output[l_prompt:pos] 136 | stopped = True 137 | else: 138 | output = output[l_prompt:] 139 | yield output 140 | else: 141 | raise NotImplementedError 142 | 143 | if stopped: 144 | break 145 | 146 | del past_key_values -------------------------------------------------------------------------------- /phoenix/conversation.py: -------------------------------------------------------------------------------- 1 | # https://raw.githubusercontent.com/FreedomIntelligence/LLMZoo/main/llmzoo/utils.py 2 | import dataclasses 3 | from enum import auto, Enum 4 | from typing import List 5 | 6 | import transformers 7 | 8 | 9 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, 10 | output_dir: str): 11 | """Collects the state dict and dump to disk.""" 12 | state_dict = trainer.model.state_dict() 13 | if trainer.args.should_save: 14 | cpu_state_dict = { 15 | key: value.cpu() 16 | for key, value in state_dict.items() 17 | } 18 | del state_dict 19 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 20 | 21 | 22 | class SeparatorStyle(Enum): 23 | """Different separator style.""" 24 | SINGLE = auto() 25 | TWO = auto() 26 | 27 | 28 | @dataclasses.dataclass 29 | class Conversation: 30 | """A class that keeps all conversation history.""" 31 | system: str 32 | roles: List[str] 33 | messages: List[List[str]] 34 | offset: int 35 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 36 | sep: str = "" 37 | 38 | skip_next: bool = False 39 | 40 | def get_prompt(self): 41 | if self.sep_style == SeparatorStyle.SINGLE: 42 | ret = self.system 43 | for role, message in self.messages: 44 | if message: 45 | ret += role + ": " + "" + message + "" 46 | else: 47 | ret += role + ": " + "" 48 | return ret 49 | else: 50 | raise ValueError(f"Invalid style: {self.sep_style}") 51 | 52 | def append_message(self, role, message): 53 | self.messages.append([role, message]) 54 | 55 | def to_gradio_chatbot(self): 56 | ret = [] 57 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 58 | if i % 2 == 0: 59 | ret.append([msg, None]) 60 | else: 61 | ret[-1][-1] = msg 62 | return ret 63 | 64 | def copy(self): 65 | return Conversation( 66 | system=self.system, 67 | roles=self.roles, 68 | messages=[[x, y] for x, y in self.messages], 69 | offset=self.offset, 70 | sep_style=self.sep_style, 71 | sep=self.sep) 72 | 73 | def dict(self): 74 | return { 75 | "system": self.system, 76 | "roles": self.roles, 77 | "messages": self.messages, 78 | "offset": self.offset, 79 | "sep": self.sep 80 | } 81 | 82 | 83 | def get_default_conv_template(model_name=None): 84 | if model_name is None: 85 | return default_conversation 86 | model_name = model_name.lower() 87 | 
if "phoenix" in model_name:
 88 |         return default_conversation
 89 |     else:
 90 |         raise NotImplementedError
 91 |
 92 |
 93 | conv = Conversation(
 94 |     system="A chat between a curious human and an artificial intelligence assistant. "
 95 |            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
 96 |     roles=("Human", "Assistant"),
 97 |     messages=(),
 98 |     offset=0,
 99 |     sep_style=SeparatorStyle.SINGLE,
100 |     sep="",
101 | )
102 |
103 | default_conversation = conv
104 | conv_templates = {"default": conv}
105 |
106 | if __name__ == "__main__":
107 |     print(default_conversation.get_prompt())
--------------------------------------------------------------------------------
/phoenix/phoenix.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python3
 2 | # -*- coding: utf-8 -*-
 3 | # From: https://github.com/FreedomIntelligence/LLMZoo
 4 |
 5 | import torch
 6 | from transformers import AutoTokenizer, AutoModelForCausalLM
 7 |
 8 | from .chat import do_chat, do_chat_stream
 9 |
10 |
11 | def init_phoenix(model_path: str, device: str, num_gpus: int):
12 |     if device == "cpu":
13 |         kwargs = {}
14 |     elif device == "gpu":
15 |         kwargs = {"torch_dtype": torch.float16}
16 |         kwargs["device_map"] = "sequential"  # important when the GPUs do not all have the same amount of VRAM
17 |     else:
18 |         raise ValueError(f"Invalid device: {device}")
19 |
20 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
21 |     model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
22 |
23 |     model.running_device = "cuda" if device == "gpu" else "cpu"
24 |     model.do_chat = do_chat
25 |     model.do_chat_stream = do_chat_stream
26 |     return tokenizer, model
27 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | protobuf>=3.20.0
 2 | transformers>=4.27.1
 3 | icetk
 4 | cpm_kernels
 5 | torch
 6 | fastapi
 7 | pydantic==1.10.11
 8 | uvicorn
 9 | sse_starlette
10 | pyngrok
11 | toml
12 | # for notebook
13 | nest-asyncio
14 | # only needed by the embeddings model
15 | text2vec
16 |
17 | # for multi-gpu
18 | accelerate
19 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | import torch
 3 |
 4 | last_gc = 0
 5 |
 6 |
 7 | def torch_gc():
 8 |     # Use last_gc to rate-limit cache cleanup to at most once per minute
 9 |     global last_gc
10 |     if time.time() - last_gc > 60:
11 |         last_gc = time.time()
12 |         if torch.cuda.is_available():
13 |             device = torch.cuda.current_device()
14 |             print(f"Emptying gpu cache {device}...")
15 |             with torch.cuda.device(device):
16 |                 torch.cuda.empty_cache()
17 |                 torch.cuda.ipc_collect()
18 |
--------------------------------------------------------------------------------
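For a quick end-to-end check of the streaming path, the `/v1/chat/completions` endpoint can be consumed as server-sent events. The sketch below is illustrative only: it assumes the server is running locally on port 8080 with the example token `token1` from `config.toml.example`, and it parses the `data:` lines emitted by `EventSourceResponse` in `app.py`.

```python
import json
import requests

# Request a streamed chat completion from a locally running instance.
resp = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    headers={"Authorization": "Bearer token1", "Content-Type": "application/json"},
    json={"model": "gpt-3.5-turbo", "stream": True,
          "messages": [{"role": "user", "content": "Hello!"}]},
    stream=True,
)

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip blank lines and keep-alive pings
    data = line[len("data: "):]
    if data == "[DONE]":
        break
    chunk = json.loads(data)
    delta = chunk["choices"][0].get("delta", {})
    print(delta.get("content", ""), end="", flush=True)
print()
```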