├── .gitignore ├── Dockerfile ├── README.md ├── docs └── deploy_triton.md ├── pyproject.toml ├── requirements.txt ├── src ├── fastapi_tritonserver │ ├── __init__.py │ ├── config.py │ ├── constants.py │ ├── ctx.py │ ├── engine │ │ ├── __init__.py │ │ ├── engine.py │ │ └── tritonserver.py │ ├── entrypoints │ │ ├── __init__.py │ │ ├── api_server.py │ │ └── openai_api.py │ ├── logger.py │ ├── models │ │ ├── base_model.py │ │ ├── prompt_template.py │ │ ├── qwen2chat.py │ │ └── qwenvl.py │ ├── protocols │ │ ├── __init__.py │ │ ├── fastapi.py │ │ └── openai.py │ ├── sampling_params.py │ └── utils │ │ ├── __init__.py │ │ ├── generate_cfg.py │ │ └── tools.py └── triton_server_helper │ ├── __init__.py │ └── client.py ├── start_api.sh └── start_openapi.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | *.whl 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.lib 27 | 28 | # Executables 29 | *.exe 30 | *.out 31 | *.app 32 | 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | share/python-wheels/ 56 | *.egg-info/ 57 | .installed.cfg 58 | *.egg 59 | MANIFEST 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .nox/ 75 | .coverage 76 | .coverage.* 77 | .cache 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | *.py,cover 82 | .hypothesis/ 83 | .pytest_cache/ 84 | cover/ 85 | 86 | # Translations 87 | *.mo 88 | *.pot 89 | 90 | # Django stuff: 91 | *.log 92 | local_settings.py 93 | db.sqlite3 94 | db.sqlite3-journal 95 | 96 | # Flask stuff: 97 | instance/ 98 | .webassets-cache 99 | 100 | # Scrapy stuff: 101 | .scrapy 102 | 103 | # Sphinx documentation 104 | docs/_build/ 105 | 106 | # PyBuilder 107 | .pybuilder/ 108 | target/ 109 | 110 | # Jupyter Notebook 111 | .ipynb_checkpoints 112 | 113 | # IPython 114 | profile_default/ 115 | ipython_config.py 116 | 117 | # pyenv 118 | # For a library or package, you might want to ignore these files since the code is 119 | # intended to run in multiple environments; otherwise, check them in: 120 | # .python-version 121 | 122 | # pipenv 123 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 124 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 125 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 126 | # install all needed dependencies. 127 | #Pipfile.lock 128 | 129 | # poetry 130 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 131 | # This is especially recommended for binary packages to ensure reproducibility, and is more 132 | # commonly ignored for libraries. 
133 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
134 | #poetry.lock
135 | 
136 | # pdm
137 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
138 | #pdm.lock
139 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
140 | # in version control.
141 | # https://pdm.fming.dev/#use-with-ide
142 | .pdm.toml
143 | 
144 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
145 | __pypackages__/
146 | 
147 | # Celery stuff
148 | celerybeat-schedule
149 | celerybeat.pid
150 | 
151 | # SageMath parsed files
152 | *.sage.py
153 | 
154 | # Environments
155 | .env
156 | .venv
157 | env/
158 | venv/
159 | ENV/
160 | env.bak/
161 | venv.bak/
162 | 
163 | # Spyder project settings
164 | .spyderproject
165 | .spyproject
166 | 
167 | # Rope project settings
168 | .ropeproject
169 | 
170 | # mkdocs documentation
171 | /site
172 | 
173 | # mypy
174 | .mypy_cache/
175 | .dmypy.json
176 | dmypy.json
177 | 
178 | # Pyre type checker
179 | .pyre/
180 | 
181 | # pytype static type analyzer
182 | .pytype/
183 | 
184 | # Cython debug symbols
185 | cython_debug/
186 | 
187 | # PyCharm
188 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
189 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
190 | # and can be added to the global gitignore or merged into this file. For a more nuclear
191 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
192 | #.idea/
193 | kineto/
194 | .vscode/
195 | *.tar.gz
196 | tmp/
197 | .idea/
198 | */.idea/
199 | *.jpeg
200 | 
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-bullseye as builder
2 | 
3 | # install torch
4 | RUN pip install --upgrade pip \
5 |     && pip install torch
6 | 
7 | FROM python:3.10-bullseye as final
8 | COPY --from=builder /usr/local /usr/local
9 | RUN mkdir /app
10 | WORKDIR /app
11 | COPY . /app
12 | # RUN chmod +x ./start_api.sh
13 | RUN chmod +x ./start_openapi.sh
14 | 
15 | ENV TOKENIZER_PATH=""
16 | ENV TRITON_SERVER_HOST="127.0.0.1"
17 | ENV TRITON_SERVER_PORT="8001"
18 | ENV WORKERS=1
19 | ENV MODEL_TYPE="qwen2-chat"
20 | 
21 | EXPOSE 9900
22 | 
23 | RUN python -m pip install --upgrade pip && \
24 |     pip install .
25 | 
26 | # start the service
27 | CMD ["./start_openapi.sh"]
28 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Project introduction
2 | - A project that uses FastAPI to wrap tritonserver and expose it through the OpenAI API format.
3 | 
4 | # Prerequisites
5 | - A tritonserver for trt-llm service must already be deployed, see the [tutorial](./docs/deploy_triton.md).
6 | - Currently only the qwen1.5 series deployed with tritonserver 24.02 + tensorrt_llm 0.8.0 is supported; other combinations are untested.
7 | 
8 | 
9 | ## Installing this project
10 | - Install from source
11 | ```shell
12 | pip install .
13 | ```
14 | 
15 | - Install from a whl package (not yet provided)
16 | ```
17 | pip install fastapi_tritonserver-0.0.1-py3-none-any.whl
18 | ```
19 | 
20 | - Install from the PyPI repository (not yet available)
21 | 
22 | ## API server startup example (plain HTTP API)
23 | ### Make sure tritonserver is already running and exposed at 127.0.0.1:8001
24 | - Assuming you deployed the qwen1.5-1.8b-chat model, pass either the local offline tokenizer path or the Hugging Face repo id to the `--tokenizer-path` argument.
25 | - `workers` can be set according to the maximum batch_size supported by your tritonserver.
26 | - A simple example:
27 | ```shell
28 | python3 -m fastapi_tritonserver.entrypoints.api_server \
29 |     --port 9900 \
30 |     --host 0.0.0.0 \
31 |     --model-name tensorrt_llm \
32 |     --tokenizer-path Qwen/Qwen1.5-1.8B-Chat \
33 |     --server-url 127.0.0.1:8001 \
34 |     --workers 4 \
35 |     --model_type qwen2-chat
36 | ```
37 | 
38 | 
39 | ## Request example
40 | ```
41 | curl -X POST localhost:9900/generate -d '{
42 |     "prompt": "who are you?"
43 | }'
44 | ```
45 | output:
46 | ```shell
47 | {"text":"I am QianWen, a large language model created by Alibaba Cloud.","id":"89101ccc-d6d0-4cdf-a05c-8cbb7b466d66"}
48 | ```
49 | 
50 | ## Parameter description
51 | ```
52 | - prompt: the prompt used for generation.
53 | - images: only required by vl (vision-language) models.
54 | - max_output_len: maximum number of tokens generated per output sequence.
55 | - num_beams: beam width when beam search is used.
56 | - repetition_penalty
57 | - top_k: integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
58 | - top_p: float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
59 | - temperature: float that controls the randomness of sampling. Lower values make the model more deterministic, higher values make it more random. Zero means greedy sampling.
60 | ```
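The same request can be sent from Python. A minimal sketch using the `requests` package, assuming the api_server started above is reachable on 127.0.0.1:9900; the sampling fields are optional and the values below are only illustrative:

```python
import requests

payload = {
    "prompt": "who are you?",
    # optional sampling parameters, see the list above (illustrative values)
    "max_output_len": 256,
    "top_p": 0.8,
    "temperature": 1.0,
}
resp = requests.post("http://127.0.0.1:9900/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())  # {"text": ..., "id": "<request id>"}
```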
61 | 
62 | ## OpenAI API example (OpenAI-API-compliant calls)
63 | ### Make sure tritonserver is already running and exposed at 127.0.0.1:8001; `workers` can be set according to the maximum batch_size supported by your tritonserver.
64 | ```shell
65 | python3 -m fastapi_tritonserver.entrypoints.openai_api \
66 |     --port 9900 \
67 |     --host 0.0.0.0 \
68 |     --model-name tensorrt_llm \
69 |     --tokenizer-path Qwen/Qwen1.5-1.8B-Chat \
70 |     --server-url 127.0.0.1:8001 \
71 |     --workers 4 \
72 |     --model_type qwen2-chat
73 | ```
74 | 
75 | ## Request example
76 | ```
77 | curl -X POST localhost:9900/v1/chat/completions \
78 |     -H "Content-Type: application/json" \
79 |     -d '{
80 |         "model": "gpt-3.5-turbo",
81 |         "messages": [{"role": "system", "content": "You are a helpful assistant."},
82 |                      {"role": "user", "content": "who you are."}]
83 |     }'
84 | ```
85 | output:
86 | ```shell
87 | {
88 |  "model":"gpt-3.5-turbo",
89 |  "object":"chat.completion",
90 |  "choices":[{"index":0,"message":
91 |    {"role":"assistant","content":"I am QianWen, a large language model created by Alibaba Cloud. I was trained on a vast amount of text data from the web, including books, articles, and other sources, to understand natural language and provide responses to various questions and tasks.\n\nMy primary function is to assist with a wide range of applications, including answering questions, generating text based on input prompts, summarizing long documents, translating languages, and even writing code. I can understand and generate human-like text in multiple languages, including English, Chinese, Spanish, French, German, Italian, Japanese, Korean, Russian, Portuguese, and more.\n\nQianW","function_call":null},
92 |  "finish_reason":"stop"
93 |  }],
94 | "created":1711955133}
95 | ```
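Because the route follows the OpenAI spec, the official `openai` Python client (v1+) can also be pointed at this server. A minimal sketch, assuming the openai_api server above is reachable on 127.0.0.1:9900; the API key is a placeholder because this server does not validate it:

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9900/v1", api_key="not-used")

completion = client.chat.completions.create(
    # the model name is echoed back in the response; the served model is fixed by --model-name
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "who you are."},
    ],
)
print(completion.choices[0].message.content)
```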
96 | 
97 | 
98 | ## qwen-vl example
99 | 
100 | ```shell
101 | # start the API server locally
102 | python3 -m fastapi_tritonserver.entrypoints.api_server --port 9000 --host 0.0.0.0 --model-name qwen-vl-test --tokenizer-path qwenvl_repo/qwen-vl-test/qwen-vl-test-llm/20240220104327/tokenizer/ --server-url localhost:6601 --workers 1 --model_type qwen-vl
103 | 
104 | # start triton server
105 | tritonserver --model-repository=qwenvl_repo/repo/ --strict-model-config=false --log-verbose=0 --metrics-port=6000 --http-port=6609 --grpc-port=6601
106 | 
107 | # request example
108 | curl -X POST localhost:9000/generate -d '{"images": ["https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"], "prompt": "what it is"}'
109 | ```
110 | 
111 | ### Docker deployment
112 | 1. Build the image
113 | ```bash
114 | docker build . -t fastapi_tritonserver
115 | ```
116 | 
117 | 2. Run the container
118 | ```bash
119 | docker run -d --restart=always \
120 |     -e TOKENIZER_PATH="Qwen/Qwen1.5-1.8B-Chat" \
121 |     -e TRITON_SERVER_HOST="192.168.x.x" \
122 |     -e TRITON_SERVER_PORT="8001" \
123 |     -e MODEL_TYPE="qwen2-chat" \
124 |     -e WORKERS=4 \
125 |     --name fastapi_tritonserver \
126 |     -p 9900:9900 \
127 |     fastapi_tritonserver
128 | ```
--------------------------------------------------------------------------------
/docs/deploy_triton.md:
--------------------------------------------------------------------------------
1 | ## Model preparation:
2 | - Qwen1.5-1.8B-Chat is used as the example; other qwen1.5 model sizes follow the same procedure.
3 | ```shell
4 | git clone https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat
5 | ```
6 | 
7 | 
8 | ## Model conversion:
9 | ```shell
10 | git clone https://github.com/Tlntin/Qwen-TensorRT-LLM.git
11 | cd example/qwen2
12 | ```
13 | ```shell
14 | python build.py --hf_model_dir ./qwen/Qwen1.5-1.8B-Chat/ \
15 |     --dtype float16 \
16 |     --remove_input_padding \
17 |     --gpt_attention_plugin float16 \
18 |     --enable_context_fmha \
19 |     --gemm_plugin float16 \
20 |     --paged_kv_cache \
21 |     --output_dir ./tmp/Qwen1.5
22 | ```
23 | ### Verify the engine:
24 | ```shell
25 | python3 run.py \
26 |     --tokenizer_dir ./qwen/Qwen1.5-1.8B-Chat/ \
27 |     --engine_dir=./tmp/Qwen1.5
28 | ```
29 | output:
30 | 
31 | ```shell
32 | [TensorRT-LLM] TensorRT-LLM version: 0.8.0Input [Text 0]: "<|im_start|>system
33 | You are a helpful assistant.<|im_end|>
34 | <|im_start|>user
35 | 你好，请问你叫什么？<|im_end|>
36 | <|im_start|>assistant
37 | "
38 | Output [Text 0 Beam 0]: "你好！我是来自阿里云的大规模语言模型，我叫通义千问。"
39 | ```
40 | Engine verification complete.
41 | 
42 | ## Packaging the Triton service:
43 | Refer to the triton_model_rep directory of the Qwen-TensorRT-LLM project; the main change is the gpt_model_path field in config.pbtxt, which must be set to the actual path of the engine.
44 | Triton launch command, using the image nvcr.io/nvidia/tritonserver:24.02-trtllm-python-py3:
45 | ```shell
46 | CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository=repo --strict-model-config=false --log-verbose=0 --metrics-port=6000
47 | ```
48 | output:
49 | ```shell
50 | I0328 03:16:20.315813 150 server.cc:634]
51 | +-------------+-----------------------------------------------------------------+---------------------------------------------------------------------------------------------------+
52 | | Backend     | Path                                                            | Config                                                                                            |
53 | +-------------+-----------------------------------------------------------------+---------------------------------------------------------------------------------------------------+
54 | | tensorrtllm | /opt/tritonserver/backends/tensorrtllm/libtriton_tensorrtllm.so | {"cmdline":{"auto-complete-config":"true","backend-directory":"/opt/tritonserver/backends","min-c |
55 | |             |                                                                 | ompute-capability":"6.000000","default-max-batch-size":"4"}}                                      |
56 | +-------------+-----------------------------------------------------------------+---------------------------------------------------------------------------------------------------+
57 | 
58 | I0328 03:16:20.315837 150 server.cc:677]
59 | +--------------+---------+--------+
60 | | Model        | Version | Status |
61 | +--------------+---------+--------+
62 | | tensorrt_llm | 1       | READY  |
63 | +--------------+---------+--------+
64 | 
65 | I0328 03:16:20.929471 150 metrics.cc:877] Collecting metrics for GPU 0: NVIDIA A100-SXM4-80GB
66 | I0328 03:16:20.996808 150 metrics.cc:770] Collecting CPU metrics
67 | I0328 03:16:20.997046 150 tritonserver.cc:2508]
68 | +----------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------+
69 | | Option                           | Value                                                                                                                                          |
70 | +----------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------+
71 | | server_id                        | triton                                                                                                                                         |
72 | | server_version                   | 2.43.0                                                                                                                                         |
73 | | server_extensions                | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shared_memory cuda_sha |
74 | |                                  | red_memory binary_tensor_data parameters statistics trace logging                                                                              |
75 | | model_repository_path[0]         | repo                                                                                                                                           |
76 | | model_control_mode               | MODE_NONE                                                                                                                                      |
77 | | strict_model_config              | 0                                                                                                                                              |
78 | | rate_limit                       | OFF                                                                                                                                            |
79 | | pinned_memory_pool_byte_size     | 268435456                                                                                                                                      |
80 | | cuda_memory_pool_byte_size{0}    | 67108864                                                                                                                                       |
81 | | min_supported_compute_capability | 6.0                                                                                                                                            |
82 | | strict_readiness                 | 1                                                                                                                                              |
83 | | exit_timeout                     | 30                                                                                                                                             |
84 | | cache_enabled                    | 0                                                                                                                                              |
85 | +----------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------+
86 | 
87 | I0328 03:16:20.998511 150 grpc_server.cc:2519] Started GRPCInferenceService at 0.0.0.0:8001
88 | I0328 03:16:20.998732 150 http_server.cc:4637] Started HTTPService at 0.0.0.0:8000
89 | I0328 03:16:21.039690 150 http_server.cc:320] Started Metrics Service at 0.0.0.0:6000
90 | ```
91 | The output shows that the model loaded successfully; the gRPC endpoint is 0.0.0.0:8001, which is used by fastapi-tritonserver as the model server URL.
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "fastapi_tritonserver"
3 | version = "0.0.1"
4 | authors = [
5 |   {name="zhaohb", email="zhaohbcloud@126.com"},
6 |   {name="Tlntin", email="tlntindeng01@gmail.com"},
7 | ]
8 | description = "Support openai api format input for tritonserver"
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 |     "Programming Language :: Python :: 3",
13 |     "License :: OSI Approved :: Apache Software License",
14 |     "Operating System :: OS Independent",
15 | ]
16 | dependencies = [
17 |     "fastapi",
18 |     "uvicorn",
19 |     "tritonclient[all]",
20 |     "SentencePiece",
21 |     "numpy",
22 |     "transformers",
23 |     "torch",  # generation.utils needs torch
24 |     "torchvision",  # needed by qwen-vl
25 | ]
26 | 
27 | [project.urls]
28 | "Homepage" = "https://github.com/zhaohb/fastapi_tritonserver_test"
29 | "Bug Tracker" = 
"https://github.com/zhaohb/fastapi_tritonserver_test/issues" 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn 3 | tritonclient[all] 4 | SentencePiece 5 | numpy 6 | transformers 7 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohb/fastapi_tritonserver/b2dd2007b1895c8205739932db9af9e6f27dce92/src/fastapi_tritonserver/__init__.py -------------------------------------------------------------------------------- /src/fastapi_tritonserver/config.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from dataclasses import dataclass 3 | import argparse 4 | import os 5 | 6 | now_dir = os.path.dirname(os.path.abspath(__file__)) 7 | parent_dir = os.path.dirname(now_dir) 8 | 9 | 10 | @dataclass 11 | class ServerConf: 12 | '''arguments''' 13 | server_url: str = '' 14 | model_name: str = '' 15 | tokenizer_path: str = '' 16 | model_type: str = '' 17 | 18 | @staticmethod 19 | def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 20 | parser.add_argument( 21 | '--server-url', 22 | type=str, 23 | default='127.0.0.1:8001', 24 | help='tritonserver url' 25 | ) 26 | parser.add_argument( 27 | '--model-name', 28 | type=str, 29 | default="tensorrt_llm", 30 | help='model name' 31 | ) 32 | tokenizer_dir = os.path.join(parent_dir, "qwen1.5_7b_chat") 33 | parser.add_argument( 34 | '--tokenizer-path', 35 | type=str, 36 | default="Qwen/Qwen1.5-4B-Chat", 37 | help='tokenzier ptah for load') 38 | parser.add_argument( 39 | '--model-type', 40 | type=str, 41 | default='qwen2-chat', 42 | help='the model type for load') 43 | return parser 44 | 45 | @classmethod 46 | def from_cli_args(cls, args: argparse.Namespace) -> 'ServerConf': 47 | # Get the list of attributes of this dataclass. 48 | attrs = [attr.name for attr in dataclasses.fields(cls)] 49 | # Set the attributes from the parsed arguments. 
50 | conf = cls(**{attr: getattr(args, attr) for attr in attrs}) 51 | return conf 52 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/constants.py: -------------------------------------------------------------------------------- 1 | app_constants = { 2 | # constants 3 | "model_dict": {'qwen-vl':'QwenvlModel', 'qwen2-chat': 'Qwen2ChatModel'} 4 | } 5 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/ctx.py: -------------------------------------------------------------------------------- 1 | app_ctx = { 2 | # LLM infer engin 3 | "asyncEngine": None, 4 | "asyncVisualEngine": None, 5 | "model_type": None, 6 | } 7 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohb/fastapi_tritonserver/b2dd2007b1895c8205739932db9af9e6f27dce92/src/fastapi_tritonserver/engine/__init__.py -------------------------------------------------------------------------------- /src/fastapi_tritonserver/engine/engine.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from abc import ABC, abstractmethod 3 | from fastapi_tritonserver.sampling_params import SamplingParams 4 | from transformers import AutoTokenizer 5 | from fastapi_tritonserver.logger import _root_logger 6 | from fastapi_tritonserver.utils.generate_cfg import parse_cfg 7 | 8 | logger = _root_logger 9 | 10 | 11 | class BaseEngine(ABC): 12 | def __init__(self, tokenizer_path): 13 | logger.info("init tokenizer tokenizer_path: [%s]", tokenizer_path) 14 | self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) 15 | self._default_params = parse_cfg(tokenizer_path) 16 | 17 | def merge_default_params(self, params: SamplingParams) -> SamplingParams: 18 | """ 19 | overwrite SamplingParams with default generate config value 20 | """ 21 | if params.end_id is None and "end_id" in self._default_params: 22 | params.end_id = self._default_params["end_id"] 23 | if params.pad_id is None and "pad_id" in self._default_params: 24 | params.pad_id = self._default_params["pad_id"] 25 | if params.top_k is None and "top_k" in self._default_params: 26 | params.top_k = self._default_params["top_k"] 27 | if params.top_p is None and "top_p" in self._default_params: 28 | params.top_p = self._default_params["top_p"] 29 | if params.temperature is None and "temperature" in self._default_params: 30 | params.temperature = self._default_params["temperature"] 31 | if params.len_penalty is None and "len_penalty" in self._default_params: 32 | params.len_penalty = self._default_params["len_penalty"] 33 | if params.repetition_penalty is None and "repetition_penalty" in self._default_params: 34 | params.repetition_penalty = self._default_params["repetition_penalty"] 35 | if ( 36 | params.stop_words is None 37 | or (isinstance(params.stop_words, list) and len(params.stop_words) == 0) 38 | ) and "stop" in self._default_params: 39 | params.stop_words = self._default_params["stop"] 40 | return params 41 | 42 | 43 | class Engine(BaseEngine): 44 | 45 | @abstractmethod 46 | def is_server_live(self) -> bool: 47 | pass 48 | 49 | @abstractmethod 50 | async def generate( 51 | self, 52 | query: str, 53 | system_prompt: str, 54 | history_list: list, 55 | params: SamplingParams, 56 | images: list = [], 57 | 
visual_output=None, 58 | timeout: int = 60000, 59 | request_id: str = '', 60 | only_return_output=False 61 | ): 62 | pass 63 | 64 | @abstractmethod 65 | async def wait_ready(self): 66 | pass 67 | 68 | 69 | class AsyncEngine(BaseEngine): 70 | @abstractmethod 71 | async def is_server_live(self) -> bool: 72 | pass 73 | 74 | @abstractmethod 75 | async def generate( 76 | self, 77 | query: str, 78 | system_prompt: str, 79 | history_list: list, 80 | params: SamplingParams, 81 | images: list = [], 82 | visual_output=None, 83 | timeout: int = 60000, 84 | request_id: str = '', 85 | only_return_output=False 86 | ): 87 | pass 88 | 89 | @abstractmethod 90 | async def wait_ready(self): 91 | pass 92 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/engine/tritonserver.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from triton_server_helper.client import StreamClient, AsyncStreamClient 3 | import numpy as np 4 | import time 5 | import csv 6 | from ..sampling_params import SamplingParams 7 | from .engine import Engine, AsyncEngine 8 | from tritonclient.utils import np_to_triton_dtype 9 | import tritonclient.grpc.aio as grpcclient 10 | from fastapi_tritonserver.logger import _root_logger 11 | from torchvision import transforms 12 | from torchvision.transforms import InterpolationMode 13 | from PIL import Image 14 | import requests 15 | import torch 16 | from typing import Tuple, List, Union 17 | from ..models.base_model import BaseModel 18 | from ..models.qwenvl import QwenvlModel 19 | from ..models.qwen2chat import Qwen2ChatModel 20 | from fastapi_tritonserver.ctx import app_ctx 21 | from fastapi_tritonserver.constants import app_constants 22 | 23 | logger = _root_logger 24 | 25 | 26 | def prepare_tensor(name, input): 27 | client_util = grpcclient 28 | t = client_util.InferInput(name, input.shape, 29 | np_to_triton_dtype(input.dtype)) 30 | t.set_data_from_numpy(input) 31 | return t 32 | 33 | 34 | def to_word_list_format(word_dict: List[List[str]], 35 | tokenizer=None, 36 | add_special_tokens=False): 37 | """ 38 | format of word_dict 39 | len(word_dict) should be same to batch_size 40 | word_dict[i] means the words for batch i 41 | len(word_dict[i]) must be 1, which means it only contains 1 string 42 | This string can contain several sentences and split by ",". 43 | For example, if word_dict[2] = " I am happy, I am sad", then this function will return 44 | the ids for two short sentences " I am happy" and " I am sad". 
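    The returned array is int32 with shape (batch_size, 2, max_len): for each batch
    entry, row 0 holds the flattened, zero-padded stop-word token ids and row 1 holds
    their cumulative end offsets, padded with -1.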
45 | """ 46 | assert tokenizer is not None, "need to set tokenizer" 47 | 48 | flat_ids = [] 49 | offsets = [] 50 | for word_dict_item in word_dict: 51 | item_flat_ids = [] 52 | item_offsets = [] 53 | 54 | if isinstance(word_dict_item[0], bytes): 55 | word_dict_item = [word_dict_item[0].decode()] 56 | 57 | words = list(csv.reader(word_dict_item))[0] 58 | for word in words: 59 | ids = tokenizer.encode(word, add_special_tokens=add_special_tokens) 60 | 61 | if len(ids) == 0: 62 | continue 63 | 64 | item_flat_ids += ids 65 | item_offsets.append(len(ids)) 66 | 67 | flat_ids.append(np.array(item_flat_ids)) 68 | offsets.append(np.cumsum(np.array(item_offsets))) 69 | 70 | pad_to = max(1, max(len(ids) for ids in flat_ids)) 71 | 72 | for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): 73 | flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) 74 | offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) 75 | 76 | return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) 77 | 78 | 79 | def create_inputs( 80 | model, 81 | tokenizer, 82 | query: str, 83 | system_prompt: str, 84 | history_list: list, 85 | params: SamplingParams, 86 | images: List[str], 87 | visual_output, 88 | model_type: str = None, 89 | streaming: bool = False 90 | ): 91 | 92 | if np.size(visual_output) != 0 and model_type.endswith('-vl'): 93 | input_start_ids, input_lens, prompt_table = model.make_context( 94 | tokenizer, query, images, visual_output 95 | ) 96 | input_start_ids = input_start_ids.numpy() 97 | input_lens = input_lens.numpy() 98 | else: 99 | encoded_inputs = model.encode( 100 | tokenizer, 101 | query, 102 | system_prompt=system_prompt, 103 | history_list=history_list 104 | ) 105 | input_start_ids = encoded_inputs['input_ids'] 106 | input_lens = np.sum(encoded_inputs['attention_mask'], axis=-1).reshape(-1, 1) 107 | 108 | inputs_token_lens = len(input_start_ids) 109 | 110 | input_lens = input_lens.astype(np.int32).reshape(-1, 1) 111 | inputs_shape = np.ones_like(input_lens) 112 | output0_len = inputs_shape.astype(np.int32) * params.max_output_len 113 | streaming_data = np.array([[streaming]], dtype=bool) 114 | 115 | inputs = [ 116 | prepare_tensor("input_ids", np.array([input_start_ids]).astype(np.int32)), 117 | prepare_tensor("input_lengths", input_lens), 118 | prepare_tensor("request_output_len", output0_len), 119 | prepare_tensor("streaming", streaming_data), 120 | prepare_tensor("exclude_input_in_output", np.array([[True]], dtype=bool)) 121 | ] 122 | 123 | if np.size(visual_output) != 0 and model_type.endswith('-vl'): 124 | inputs.append(prepare_tensor("prompt_embedding_table", prompt_table.astype(np.float16))) 125 | inputs.append(prepare_tensor("prompt_vocab_size", np.array([[256]]).astype(np.int32))) 126 | inputs.append(prepare_tensor("task", np.array([[0]]).astype(np.int32))) 127 | 128 | if params.beam_width: 129 | beam_width = (inputs_shape * params.beam_width).astype(np.int32) 130 | inputs.append(prepare_tensor("beam_width", beam_width)) 131 | 132 | if params.temperature: 133 | temperature = (inputs_shape * params.temperature).astype(np.float32) 134 | inputs.append(prepare_tensor("temperature", temperature)) 135 | 136 | if params.top_k: 137 | top_k = (inputs_shape * params.top_k).astype(np.int32) 138 | inputs.append(prepare_tensor("runtime_top_k", top_k)) 139 | 140 | if params.top_p: 141 | top_p = (inputs_shape * params.top_p).astype(np.float32) 142 | inputs.append(prepare_tensor("runtime_top_p", top_p)) 143 | 144 | if params.presence_penalty: 145 | presence_penalty 
= (inputs_shape * params.presence_penalty).astype(np.float32) 146 | inputs.append(prepare_tensor("presence_penalty", presence_penalty)) 147 | 148 | if params.repetition_penalty: 149 | repetition_penalty = (inputs_shape * params.repetition_penalty).astype(np.float32) 150 | inputs.append(prepare_tensor("repetition_penalty", repetition_penalty)) 151 | 152 | if params.len_penalty: 153 | len_penalty = (inputs_shape * params.len_penalty).astype(np.float32) 154 | inputs.append(prepare_tensor("len_penalty", len_penalty)) 155 | 156 | if params.random_seed: 157 | random_seed = (inputs_shape * params.random_seed).astype(np.uint64) 158 | inputs.append(prepare_tensor("random_seed", random_seed)) 159 | 160 | if params.end_id: 161 | end_id = (inputs_shape * params.end_id).astype(np.int32) 162 | inputs.append(prepare_tensor("end_id", end_id)) 163 | 164 | if params.stop_words: 165 | stop_words = to_word_list_format([params.stop_words], tokenizer) 166 | inputs.append(prepare_tensor("stop_words_list", stop_words)) 167 | 168 | return inputs, inputs_token_lens 169 | 170 | def decode(tokenizer, output_ids: np.ndarray, cutoff_len=0): 171 | 172 | new_ids = [[]] 173 | for id in output_ids[0]: 174 | new_ids[0].append(id[cutoff_len:]) 175 | new_ids = np.array(new_ids) 176 | return tokenizer.batch_decode(new_ids[0], skip_special_tokens=True) 177 | 178 | 179 | class TritonServerAsyncEngine(AsyncEngine): 180 | def __init__(self, url: str, tokenizer_path: str, model_name: str, model_type: str): 181 | super().__init__(tokenizer_path) 182 | self._model_name = model_name 183 | self._client = AsyncStreamClient(url, [model_name]) 184 | self._model_type = model_type 185 | if self._model_type: 186 | self._model = globals()[app_constants['model_dict'][self._model_type]]() 187 | else: 188 | self._model = BaseModel() 189 | 190 | async def is_server_live(self): 191 | return await self._client.is_server_live() 192 | 193 | async def wait_ready(self): 194 | await self._client.wait_server_ready() 195 | 196 | async def visual_infer(self, images: List[str], timeout: int = 60000, request_id: str = '', 197 | only_return_output=False): 198 | if self._model_type == 'qwen-vl': 199 | image_size = 448 200 | 201 | before_process_time = time.time() 202 | images = self._model.encode(images) 203 | #image_pre_obj = QwenvlPreprocess(image_size) 204 | #images = image_pre_obj.encode(images) 205 | 206 | if torch.numel(images) != 0: 207 | before_create_inputs_time = time.time() 208 | inputs = [ 209 | prepare_tensor("input", images.numpy()), 210 | ] 211 | 212 | before_infer_time = time.time() 213 | response_iterator = self._client.infer({ 214 | 'model_name': self._model_name, 215 | 'inputs': inputs, 216 | # 'request_id': request_id 217 | }, timeout) 218 | 219 | async for response in response_iterator: 220 | result, error = response 221 | if error: 222 | raise Exception(error) 223 | else: 224 | after_infer_time = time.time() 225 | logger.info('[%s] generate elapsed times preprocess_time: [%.4fms], ' 226 | 'infer_time: [%.4fms]', 227 | request_id, 228 | (before_create_inputs_time - before_process_time) * 1000, 229 | (after_infer_time - before_infer_time) * 1000) 230 | return result.as_numpy('output') 231 | else: 232 | return np.array([]) 233 | 234 | async def generate( 235 | self, 236 | query: str, 237 | system_prompt: str, 238 | history_list: list, 239 | params: SamplingParams, 240 | images: List[str] = [], 241 | visual_output=None, 242 | timeout: int = 60000, 243 | request_id: str = '', 244 | only_return_output=False 245 | ): 246 | params = 
self.merge_default_params(params) 247 | logger.info('[%s] req merged generate_params: [%s] timeout: [%s]', request_id, params.to_json(), timeout) 248 | before_create_inputs_time = time.time() 249 | inputs, inputs_token_lens = create_inputs( 250 | self._model, 251 | self._tokenizer, 252 | query=query, 253 | system_prompt=system_prompt, 254 | history_list=history_list, 255 | params=params, 256 | images=images, 257 | visual_output=visual_output, 258 | model_type=self._model_type, 259 | streaming=False 260 | ) 261 | cutoff_len = inputs_token_lens if only_return_output else 0 262 | 263 | before_infer_time = time.time() 264 | response_iterator = self._client.infer({ 265 | 'model_name': self._model_name, 266 | 'inputs': inputs, 267 | # 'request_id': request_id 268 | }, 6000) 269 | 270 | async for response in response_iterator: 271 | result, error = response 272 | if error: 273 | raise Exception(error) 274 | else: 275 | before_decode_time = time.time() 276 | decoded = self._model.decode(self._tokenizer, result.as_numpy('output_ids'), inputs_token_lens, cutoff_len) 277 | #if self._model_type == 'qwen-vl': 278 | # decoded = qwenvl_decode(self._tokenizer, result, inputs_token_lens, cutoff_len) 279 | #else: 280 | # decoded = decode(self._tokenizer, result, cutoff_len) 281 | logger.info('[%s] generate elapsed times create_input: [%.4fms], ' 282 | 'infer_time: [%.3fs], decoded_input: [%.4fms]', 283 | request_id, 284 | (before_infer_time - before_create_inputs_time) * 1000, 285 | (before_decode_time - before_infer_time), 286 | (time.time() - before_decode_time) * 1000) 287 | return decoded 288 | 289 | async def generate_streaming( 290 | self, 291 | query: str, 292 | system_prompt: str, 293 | history: list, 294 | params: SamplingParams, 295 | timeout: int = 60000, 296 | request_id: str = '' 297 | ): 298 | params = self.merge_default_params(params) 299 | logger.info('[%s] req merged generate_params: [%s] timeout: [%s]', request_id, params.to_json(), timeout) 300 | before_create_inputs_time = time.time() 301 | # inputs, inputs_token_lens = create_inputs(self._tokenizer, prompt, params, True) 302 | inputs, inputs_token_lens = create_inputs( 303 | self._model, 304 | self._tokenizer, 305 | query=query, 306 | system_prompt=system_prompt, 307 | history_list=history, 308 | params=params, 309 | images=[], # images, 310 | visual_output=None, # visual_output, 311 | model_type=self._model_type, 312 | streaming=True 313 | ) 314 | 315 | before_infer_time = time.time() 316 | response_iterator = self._client.infer({ 317 | 'model_name': self._model_name, 318 | 'inputs': inputs, 319 | # 'request_id': request_id 320 | }, timeout) 321 | 322 | before_infer_time = time.time() 323 | queue_output_ids = [] 324 | async for response in response_iterator: 325 | result, error = response 326 | if error: 327 | raise Exception(error) 328 | else: 329 | output_ids = result.as_numpy('output_ids') 330 | if len(queue_output_ids) > 0: 331 | queue_output_ids.append(output_ids) 332 | output_ids = np.concatenate(queue_output_ids, axis=-1) 333 | output_texts = decode(self._tokenizer, output_ids) 334 | is_ok = True 335 | for temp_text in output_texts: 336 | if b"\xef\xbf\xbd" in temp_text.encode(): 337 | is_ok = False 338 | if is_ok: 339 | yield output_texts 340 | queue_output_ids = [] 341 | else: 342 | if len(queue_output_ids) == 0: 343 | queue_output_ids.append(output_ids) 344 | 345 | logger.info('[%s] generate elapsed times create_input: [%.4fms], infer_time: [%.3fs]', 346 | request_id, 347 | (before_infer_time - before_create_inputs_time) * 
1000, 348 | time.time() - before_infer_time) 349 | 350 | 351 | class TritonServerEngine(Engine): 352 | def __init__(self, url: str, tokenizer_path: str, model_name: str, model_type: str): 353 | super().__init__(tokenizer_path) 354 | self._model_name = model_name 355 | self._client = StreamClient(url, [model_name]) 356 | self._model_type = model_type 357 | 358 | def is_server_live(self) -> bool: 359 | return self._client.is_server_live() 360 | 361 | def wait_ready(self): 362 | self._client.wait_server_ready() 363 | 364 | def generate(self, prompt: str, params: SamplingParams, timeout: int = 10000, 365 | request_id: str = '', only_return_output=False, streaming: bool = False): 366 | params = self.merge_default_params(params) 367 | logger.info('[%s] req merged generate_params: [%s] timeout: [%s]', request_id, params.to_json(), timeout) 368 | inputs, inputs_token_lens = create_inputs(self._tokenizer, prompt, params, streaming) 369 | result = self._client.infer({ 370 | 'model_name': self._model_name, 371 | 'inputs': inputs, 372 | 'request_id': request_id 373 | }, timeout) 374 | cutoff_len = inputs_token_lens if only_return_output else 0 375 | return decode(self._tokenizer, result, cutoff_len) 376 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/entrypoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohb/fastapi_tritonserver/b2dd2007b1895c8205739932db9af9e6f27dce92/src/fastapi_tritonserver/entrypoints/__init__.py -------------------------------------------------------------------------------- /src/fastapi_tritonserver/entrypoints/api_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import uvicorn 5 | import uuid 6 | import time 7 | import asyncio 8 | import logging 9 | import os 10 | from threading import Thread 11 | from http import HTTPStatus 12 | from contextlib import asynccontextmanager 13 | from fastapi import FastAPI, Request 14 | from fastapi.responses import JSONResponse, Response 15 | from fastapi_tritonserver.logger import _root_logger, get_formatter 16 | from fastapi_tritonserver.config import ServerConf 17 | from fastapi_tritonserver.engine.tritonserver import TritonServerAsyncEngine, TritonServerEngine 18 | from fastapi_tritonserver.protocols.fastapi import GenerateRequest 19 | from fastapi_tritonserver.sampling_params import SamplingParams 20 | from fastapi_tritonserver.ctx import app_ctx 21 | from fastapi_tritonserver.entrypoints import openai_api 22 | from tritonclient.utils import InferenceServerException 23 | 24 | logger = _root_logger 25 | 26 | TIMEOUT_KEEP_ALIVE = 60 # seconds. 27 | 28 | 29 | @asynccontextmanager 30 | async def lifespan(app: FastAPI): 31 | """ 32 | init triton server client 33 | """ 34 | # reset uvicorn.access logger format 35 | uvicorn_logger = logging.getLogger("uvicorn.access") 36 | uvicorn_logger.handlers[0].setFormatter(get_formatter()) 37 | 38 | server_args = parse_args() 39 | server_conf = ServerConf.from_cli_args(server_args) 40 | logger.info("worker initiating engine. 
sever_conf: %s", server_conf) 41 | 42 | app_ctx['model_type'] = server_conf.model_type 43 | if server_conf.model_type.endswith('-vl'): 44 | llm_model_name = server_conf.model_name + '-llm' 45 | visual_model_name = server_conf.model_name + '-visual' 46 | 47 | app_ctx["asyncEngine"] = TritonServerAsyncEngine(server_conf.server_url, server_conf.tokenizer_path, llm_model_name, server_conf.model_type) 48 | app_ctx["asyncVisualEngine"] = TritonServerAsyncEngine(server_conf.server_url, server_conf.tokenizer_path, visual_model_name, server_conf.model_type) 49 | else: 50 | app_ctx["asyncEngine"] = TritonServerAsyncEngine(server_conf.server_url, server_conf.tokenizer_path, 51 | server_conf.model_name, server_conf.model_type) 52 | logger.info("worker waiting engine ready") 53 | await app_ctx["asyncEngine"].wait_ready() 54 | yield 55 | logger.info("worker exited") 56 | 57 | 58 | app = FastAPI(lifespan=lifespan) 59 | app.include_router(openai_api.router) 60 | 61 | 62 | def create_error_response(status_code: HTTPStatus, 63 | message: str, type: str) -> JSONResponse: 64 | return JSONResponse({"message": message, "type": type}, 65 | status_code=status_code.value) 66 | 67 | 68 | @app.exception_handler(InferenceServerException) 69 | async def validation_exception_handler(request, exc): 70 | return create_error_response(HTTPStatus.BAD_REQUEST, str(exc), "infer_err") 71 | 72 | 73 | @app.exception_handler(ValueError) 74 | async def validation_exception_handler(request, exc): 75 | return create_error_response(HTTPStatus.BAD_REQUEST, str(exc), "param_err") 76 | 77 | 78 | def get_request_id(request:GenerateRequest): 79 | if request.uuid is not None and len(request.uuid) > 0: 80 | return request.uuid 81 | else: 82 | return str(uuid.uuid4()) 83 | 84 | 85 | @app.post("/generate") 86 | async def generate(raw_request: Request) -> Response: 87 | """ 88 | Generate completion for the request. 89 | The request should be a JSON object with the following fields: 90 | - prompt: the prompt to use for the generation. 91 | - max_output_len: Maximum number of tokens to generate per output sequence. 92 | - num_beams: the beam width when use beam search. 93 | - repetition_penalty 94 | - top_k: Integer that controls the number of top tokens to consider. 95 | Set to -1 to consider all tokens. 96 | - top_p: Float that controls the cumulative probability of the top tokens 97 | to consider. Must be in (0, 1]. Set to 1 to consider all tokens. 98 | - temperature: Float that controls the randomness of the sampling. Lower 99 | values make the model more deterministic, while higher values make 100 | the model more random. Zero means greedy sampling. 
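    Example request body (values are illustrative):
        {"prompt": "who are you?", "max_output_len": 256, "top_p": 0.8, "temperature": 1.0}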
101 | """ 102 | request_dict = await raw_request.json() 103 | request = GenerateRequest(**request_dict) 104 | request_id = get_request_id(request) 105 | params = SamplingParams(**request_dict) 106 | # Non-streaming response 107 | if not request.stream: 108 | begin = time.time() 109 | try: 110 | if app_ctx['model_type'].endswith("-vl"): 111 | logger.info('[%s] req request: [%s] generate_params: [%s]', request_id, request_dict, params.to_json()) 112 | visual_output = await app_ctx["asyncVisualEngine"].visual_infer(request.images, request.timeout, request_id, request.only_return_output) 113 | text = await app_ctx["asyncEngine"].generate(request.prompt, params, request.images, visual_output, request.timeout, request_id, request.only_return_output) 114 | else: 115 | logger.info('[%s] req request: [%s] generate_params: [%s]', request_id, request_dict, params.to_json()) 116 | text = await app_ctx["asyncEngine"].generate(request.prompt, params, request.timeout, request_id, request.only_return_output) 117 | except Exception as e: 118 | logger.error('[%s] process fail msg: [%s]', request_id, str(e)) 119 | raise e 120 | logger.info('[%s] resp elapsed: [%.4fs] result: [%s]', request_id, time.time() - begin, text) 121 | ret = {"text": text, "id": request_id} 122 | return JSONResponse(ret) 123 | else: 124 | app_ctx["asyncEngine"].generate_streaming(request.prompt, params, request.timeout, request_id) 125 | 126 | 127 | @app.post("/batch_generate") 128 | async def batch_generate(raw_request: Request) -> Response: 129 | request_dict = await raw_request.json() 130 | request = GenerateRequest(**request_dict) 131 | request_id = get_request_id(request) 132 | params = SamplingParams(**request_dict) 133 | begin = time.time() 134 | try: 135 | logger.info('[%s] req request: [%s] generate_params: [%s]', request_id, request_dict, params.to_json()) 136 | futures = [] 137 | for i, p in enumerate(request.prompts): 138 | futures.append( 139 | app_ctx["asyncEngine"].generate(p, params, request.timeout, request_id + '|' + str(i), request.only_return_output)) 140 | 141 | texts = await asyncio.gather(*futures) 142 | except Exception as e: 143 | logger.error('[%s] process fail msg: [%s]', request_id, str(e)) 144 | raise e 145 | logger.info('[%s] resp elapsed: [%.4fs] result: [%s]', request_id, time.time() - begin, texts) 146 | ret = {"texts": texts, "id": request_id} 147 | return JSONResponse(ret) 148 | 149 | 150 | def parse_args(): 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument("--host", type=str, default="localhost") 153 | parser.add_argument("--port", type=int, default=8000) 154 | parser.add_argument("--workers", type=int, default=1) 155 | parser.add_argument("--model_type", type=str, default="qwen-vl") 156 | parser = ServerConf.add_cli_args(parser) 157 | return parser.parse_args() 158 | 159 | 160 | def wait_engine_ready(engine:TritonServerEngine): 161 | engine.wait_ready() 162 | 163 | 164 | def health_check(args, engine:TritonServerEngine): 165 | def server_live_job(): 166 | err_cnt = 0 167 | while True: 168 | try: 169 | time.sleep(2) 170 | if not engine.is_server_live(): 171 | logger.warning('server is not live') 172 | err_cnt = err_cnt + 1 173 | except Exception as e: 174 | logger.warning('sever_live_job err: %s', e) 175 | err_cnt = err_cnt + 1 176 | if err_cnt > 5: 177 | cmd = "lsof -i :" + str(args.port) + " | awk '{print $2}' | grep -v PID| xargs kill -9" 178 | logger.error('server is not live > 5 times, exit!, exec cmd: [%s]', cmd) 179 | os.system(cmd) 180 | 181 | thread = 
Thread(target=server_live_job) 182 | thread.start() 183 | return thread 184 | 185 | 186 | if __name__ == "__main__": 187 | logger.info('fastapi-trt-llm-server start') 188 | args = parse_args() 189 | logger.info(f"args: {args}") 190 | logger.info("initiating engine.") 191 | 192 | server_conf = ServerConf.from_cli_args(args) 193 | model_type = server_conf.model_type 194 | logger.info(f"model_type: {server_conf.model_type}") 195 | 196 | if server_conf.model_type.endswith('-vl'): 197 | llm_model_name = server_conf.model_name + '-llm' 198 | visual_model_name = server_conf.model_name + '-visual' 199 | 200 | llm_engine = TritonServerEngine(server_conf.server_url, server_conf.tokenizer_path, llm_model_name, server_conf.model_type) 201 | wait_engine_ready(llm_engine) 202 | health_task = health_check(args, llm_engine) 203 | visual_engine = TritonServerEngine(server_conf.server_url, server_conf.tokenizer_path, visual_model_name, server_conf.model_type) 204 | wait_engine_ready(visual_engine) 205 | health_task = health_check(args, visual_engine) 206 | else: 207 | engine = TritonServerEngine(server_conf.server_url, server_conf.tokenizer_path, server_conf.model_name, server_conf.model_type) 208 | wait_engine_ready(engine) 209 | health_task = health_check(args, engine) 210 | 211 | logger.info("start http server") 212 | uvicorn.run("__main__:app", 213 | host=args.host, 214 | port=args.port, 215 | workers=args.workers, 216 | use_colors=False, 217 | reload=True, 218 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE) 219 | health_task.join() 220 | 221 | 222 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/entrypoints/openai_api.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import uvicorn 3 | import time 4 | import re 5 | import logging 6 | import os 7 | import json 8 | import copy 9 | from threading import Thread 10 | from http import HTTPStatus 11 | from typing import AsyncGenerator, Optional 12 | from contextlib import asynccontextmanager 13 | from fastapi import FastAPI, Request, HTTPException 14 | from fastapi.responses import JSONResponse, Response, StreamingResponse 15 | from fastapi_tritonserver.logger import _root_logger, get_formatter 16 | from fastapi_tritonserver.config import ServerConf 17 | from fastapi_tritonserver.engine.tritonserver import TritonServerAsyncEngine, TritonServerEngine 18 | from fastapi_tritonserver.protocols.openai import ( 19 | ModelCard, ModelList, ChatMessage, DeltaMessage, 20 | ChatCompletionRequest, ChatCompletionResponseChoice, 21 | ChatCompletionResponseStreamChoice, 22 | ChatCompletionResponse, ChatCompletionStreamResponse, 23 | UsageInfo 24 | ) 25 | 26 | from fastapi_tritonserver.sampling_params import SamplingParams 27 | from fastapi_tritonserver.ctx import app_ctx 28 | from tritonclient.utils import InferenceServerException 29 | from fastapi_tritonserver.utils.tools import random_uuid 30 | from typing import List, Literal, Optional, Union, Dict 31 | from fastapi import APIRouter 32 | 33 | _TEXT_COMPLETION_CMD = object() 34 | 35 | logger = _root_logger 36 | router = APIRouter(prefix="/openai", tags=["OpenAI API"]) 37 | 38 | TIMEOUT_KEEP_ALIVE = 60 # seconds. 
39 | 40 | 41 | @asynccontextmanager 42 | async def lifespan(app: FastAPI): 43 | """ 44 | init triton server client 45 | """ 46 | # reset uvicorn.access logger format 47 | uvicorn_logger = logging.getLogger("uvicorn.access") 48 | uvicorn_logger.handlers[0].setFormatter(get_formatter()) 49 | 50 | server_args = parse_args() 51 | server_conf = ServerConf.from_cli_args(server_args) 52 | logger.info("worker initiating engine. sever_conf: %s", server_conf) 53 | 54 | app_ctx['model_type'] = server_conf.model_type 55 | if server_conf.model_type.endswith('-vl'): 56 | llm_model_name = server_conf.model_name + '-llm' 57 | visual_model_name = server_conf.model_name + '-visual' 58 | 59 | app_ctx["asyncEngine"] = TritonServerAsyncEngine(server_conf.server_url, server_conf.tokenizer_path, llm_model_name, server_conf.model_type) 60 | app_ctx["asyncVisualEngine"] = TritonServerAsyncEngine(server_conf.server_url, server_conf.tokenizer_path, visual_model_name, server_conf.model_type) 61 | else: 62 | app_ctx["asyncEngine"] = TritonServerAsyncEngine(server_conf.server_url, server_conf.tokenizer_path, 63 | server_conf.model_name, server_conf.model_type) 64 | logger.info("worker waiting engine ready") 65 | await app_ctx["asyncEngine"].wait_ready() 66 | yield 67 | logger.info("worker exited") 68 | 69 | 70 | app = FastAPI(lifespan=lifespan) 71 | app.include_router(router) 72 | 73 | 74 | def trim_stop_words(response, stop_words): 75 | if stop_words: 76 | for stop in stop_words: 77 | idx = response.find(stop) 78 | if idx != -1: 79 | response = response[:idx] 80 | return response 81 | 82 | def create_error_response(status_code: HTTPStatus, 83 | message: str, type: str) -> JSONResponse: 84 | return JSONResponse({"message": message, "type": type}, 85 | status_code=status_code.value) 86 | 87 | 88 | @app.exception_handler(InferenceServerException) 89 | async def validation_exception_handler(request, exc): 90 | return create_error_response(HTTPStatus.BAD_REQUEST, str(exc), "infer_err") 91 | 92 | 93 | @app.exception_handler(ValueError) 94 | async def validation_exception_handler(request, exc): 95 | return create_error_response(HTTPStatus.BAD_REQUEST, str(exc), "param_err") 96 | 97 | 98 | TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" 99 | 100 | REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs: 101 | 102 | {tools_text} 103 | 104 | Use the following format: 105 | 106 | Question: the input question you must answer 107 | Thought: you should always think about what to do 108 | Action: the action to take, should be one of [{tools_name_text}] 109 | Action Input: the input to the action 110 | Observation: the result of the action 111 | ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) 112 | Thought: I now know the final answer 113 | Final Answer: the final answer to the original input question 114 | 115 | Begin!""" 116 | 117 | 118 | _TEXT_COMPLETION_CMD = object() 119 | 120 | 121 | def parse_messages(messages, functions): 122 | if all(m.role != "user" for m in messages): 123 | raise HTTPException( 124 | status_code=400, 125 | detail=f"Invalid request: Expecting at least one user message.", 126 | ) 127 | 128 | messages = copy.deepcopy(messages) 129 | if messages[0].role == "system": 130 | system = messages.pop(0).content.lstrip("\n").rstrip() 131 | else: 132 | system = "You are a helpful assistant." 
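    # When functions are supplied, build a ReAct-style tool-use instruction from
    # TOOL_DESC/REACT_INSTRUCTION; it is later prepended to the final user query.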
133 | 134 | if functions: 135 | tools_text = [] 136 | tools_name_text = [] 137 | for func_info in functions: 138 | name = func_info.get("name", "") 139 | name_m = func_info.get("name_for_model", name) 140 | name_h = func_info.get("name_for_human", name) 141 | desc = func_info.get("description", "") 142 | desc_m = func_info.get("description_for_model", desc) 143 | tool = TOOL_DESC.format( 144 | name_for_model=name_m, 145 | name_for_human=name_h, 146 | # Hint: You can add the following format requirements in description: 147 | # "Format the arguments as a JSON object." 148 | # "Enclose the code within triple backticks (`) at the beginning and end of the code." 149 | description_for_model=desc_m, 150 | parameters=json.dumps(func_info["parameters"], ensure_ascii=False), 151 | ) 152 | tools_text.append(tool) 153 | tools_name_text.append(name_m) 154 | tools_text = "\n\n".join(tools_text) 155 | tools_name_text = ", ".join(tools_name_text) 156 | instruction = (REACT_INSTRUCTION.format( 157 | tools_text=tools_text, 158 | tools_name_text=tools_name_text, 159 | ).lstrip('\n').rstrip()) 160 | else: 161 | instruction = '' 162 | 163 | dummy_thought = { 164 | "en": "\nThought: I now know the final answer.\nFinal answer: ", 165 | "zh": "\nThought: 我会作答了。\nFinal answer: ", 166 | } 167 | 168 | _messages = messages 169 | messages = [] 170 | for m_idx, m in enumerate(_messages): 171 | role, content, func_call = m.role, m.content, m.function_call 172 | if content: 173 | content = content.lstrip("\n").rstrip() 174 | if role == "function": 175 | if (len(messages) == 0) or (messages[-1].role != "assistant"): 176 | raise HTTPException( 177 | status_code=400, 178 | detail=f"Invalid request: Expecting role assistant before role function.", 179 | ) 180 | messages[-1].content += f"\nObservation: {content}" 181 | if m_idx == len(_messages) - 1: 182 | messages[-1].content += "\nThought:" 183 | elif role == "assistant": 184 | if len(messages) == 0: 185 | raise HTTPException( 186 | status_code=400, 187 | detail=f"Invalid request: Expecting role user before role assistant.", 188 | ) 189 | last_msg = messages[-1].content 190 | last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 191 | if func_call is None: 192 | if functions: 193 | content = dummy_thought["zh" if last_msg_has_zh else "en"] + content 194 | else: 195 | f_name, f_args = func_call["name"], func_call["arguments"] 196 | if not content: 197 | if last_msg_has_zh: 198 | content = f"Thought: 我可以使用 {f_name} API。" 199 | else: 200 | content = f"Thought: I can use {f_name}." 201 | content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}" 202 | if messages[-1].role == "user": 203 | messages.append( 204 | ChatMessage(role="assistant", content=content.lstrip("\n").rstrip()) 205 | ) 206 | else: 207 | messages[-1].content += content 208 | elif role == "user": 209 | messages.append( 210 | ChatMessage(role="user", content=content.lstrip("\n").rstrip()) 211 | ) 212 | else: 213 | raise HTTPException( 214 | status_code=400, detail=f"Invalid request: Incorrect role {role}." 
215 | ) 216 | 217 | query = _TEXT_COMPLETION_CMD 218 | if messages[-1].role == "user": 219 | query = messages[-1].content 220 | messages = messages[:-1] 221 | 222 | if len(messages) % 2 != 0: 223 | print(376) 224 | raise HTTPException(status_code=400, detail="Invalid request") 225 | 226 | history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] 227 | for i in range(0, len(messages), 2): 228 | if messages[i].role == "user" and messages[i + 1].role == "assistant": 229 | usr_msg = messages[i].content.lstrip("\n").rstrip() 230 | bot_msg = messages[i + 1].content.lstrip("\n").rstrip() 231 | if instruction and (i == len(messages) - 2): 232 | usr_msg = f"{instruction}\n\nQuestion: {usr_msg}" 233 | instruction = "" 234 | for t in dummy_thought.values(): 235 | t = t.lstrip("\n") 236 | if bot_msg.startswith(t) and ("\nAction: " in bot_msg): 237 | bot_msg = bot_msg[len(t):] 238 | history.append([usr_msg, bot_msg]) 239 | else: 240 | raise HTTPException( 241 | status_code=400, 242 | detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", 243 | ) 244 | if instruction: 245 | assert query is not _TEXT_COMPLETION_CMD 246 | query = f"{instruction}\n\nQuestion: {query}" 247 | return query, history, system 248 | 249 | 250 | async def get_gen_prompt(request): 251 | query, history, system = parse_messages(request.messages, request.functions) 252 | return query, history, system 253 | 254 | 255 | def parse_response(response): 256 | func_name, func_args = "", "" 257 | i = response.rfind("\nAction:") 258 | j = response.rfind("\nAction Input:") 259 | k = response.rfind("\nObservation:") 260 | if 0 <= i < j: # If the text has `Action` and `Action input`, 261 | if k < j: # but does not contain `Observation`, 262 | # then it is likely that `Observation` is omitted by the LLM, 263 | # because the output text may have discarded the stop word. 264 | response = response.rstrip() + "\nObservation:" # Add it back. 
265 | k = response.rfind("\nObservation:") 266 | func_name = response[i + len("\nAction:") : j].strip() 267 | func_args = response[j + len("\nAction Input:") : k].strip() 268 | if func_name: 269 | choice_data = ChatCompletionResponseChoice( 270 | index=0, 271 | message=ChatMessage( 272 | role="assistant", 273 | content=response[:i], 274 | function_call={"name": func_name, "arguments": func_args}, 275 | ), 276 | finish_reason="function_call", 277 | ) 278 | return choice_data 279 | z = response.rfind("\nFinal Answer: ") 280 | if z >= 0: 281 | response = response[z + len("\nFinal Answer: ") :] 282 | choice_data = ChatCompletionResponseChoice( 283 | index=0, 284 | message=ChatMessage(role="assistant", content=response), 285 | finish_reason="stop", 286 | ) 287 | return choice_data 288 | 289 | 290 | @app.post("/v1/chat/completions") 291 | async def create_chat_completion(raw_request: ChatCompletionRequest): 292 | 293 | request_dict = raw_request.json() 294 | request_json = json.loads(request_dict) 295 | request = ChatCompletionRequest(**request_json) 296 | 297 | begin = time.time() 298 | # logger.info(f"Received chat completion request: {request}") 299 | 300 | stop_words = [] 301 | if request.stop is not None and len(request.stop) > 0: 302 | if isinstance(request.stop, str): 303 | stop_words.append(request.stop) 304 | else: 305 | stop_words.append(','.join(request.stop)) 306 | query, history, system = parse_messages(request.messages, request.functions) 307 | 308 | # print("query: ", query) 309 | # print("history: ", history) 310 | if request.stream and request.functions: 311 | raise HTTPException( 312 | status_code=400, 313 | detail="Invalid request: Function calling is not yet implemented for stream mode.", 314 | ) 315 | 316 | model_name = request.model 317 | request_id = f"{random_uuid()}" 318 | created_time = int(time.time()) 319 | 320 | if request.messages[-1].role not in ["user", "function"]: 321 | print(454) 322 | raise HTTPException(status_code=400, detail="Invalid request") 323 | # query = request.messages[-1].content 324 | prev_messages = request.messages[:-1] 325 | if len(prev_messages) > 0 and prev_messages[0].role == "system": 326 | system = prev_messages.pop(0).content 327 | else: 328 | system = "You are a helpful assistant." 
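    # For function calling, add "Observation:" as a stop word so generation stops
    # before the model invents the tool result (ReAct-style prompting).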
329 | 330 | if request.functions: 331 | stop_words = stop_words or [] 332 | if "Observation:" not in stop_words: 333 | stop_words.append("Observation:") 334 | 335 | params = SamplingParams( 336 | max_output_len=request.max_tokens, 337 | temperature=request.temperature, 338 | top_p=request.top_p, 339 | top_k=request.top_k, 340 | beam_width=request.n, 341 | presence_penalty=request.presence_penalty, 342 | frequency_penalty=request.frequency_penalty, 343 | stop_words=stop_words, 344 | ) 345 | 346 | def create_stream_response_json( 347 | index: int, 348 | text: str, 349 | finish_reason: Optional[str] = None, 350 | ) -> str: 351 | choice_data = ChatCompletionResponseStreamChoice( 352 | index=index, 353 | delta=DeltaMessage(content=text), 354 | finish_reason=finish_reason, 355 | ) 356 | response = ChatCompletionStreamResponse( 357 | id=request_id, 358 | created=created_time, 359 | model=model_name, 360 | choices=[choice_data], 361 | ) 362 | response_json = response.model_dump_json() 363 | 364 | return response_json 365 | 366 | async def completion_stream_generator() -> AsyncGenerator[str, None]: 367 | texts = [] 368 | 369 | # first chunk do 370 | for i in range(request.n): 371 | texts.append("") 372 | 373 | async for res in app_ctx["asyncEngine"].generate_streaming( 374 | query=query, 375 | system_prompt=system, 376 | history=history, 377 | params=params, 378 | request_id=request_id, 379 | # output_accumulate=request.output_accumulate 380 | ): 381 | for i, output in enumerate(res): 382 | texts[i] += output 383 | response_json = create_stream_response_json( 384 | index=i, 385 | text=output, 386 | ) 387 | yield f"data: {response_json}\n\n" 388 | 389 | # last chunk with role 390 | for i in range(request.n): 391 | choice_data = ChatCompletionResponseStreamChoice( 392 | index=i, 393 | delta=DeltaMessage(role="assistant"), 394 | finish_reason="stop", 395 | ) 396 | chunk = ChatCompletionStreamResponse( 397 | id=request_id, 398 | choices=[choice_data], 399 | model=model_name, 400 | ) 401 | data = chunk.model_dump_json() 402 | yield f"data: {data}\n\n" 403 | 404 | yield "data: [DONE]\n\n" 405 | # logger.info('[%s] resp elapsed: [%.4fs] result: [%s]', request_id, time.time() - begin, texts) 406 | 407 | # Streaming response 408 | if request.stream: 409 | # background_tasks = BackgroundTasks() 410 | # Abort the request if the client disconnects. 
411 | # background_tasks.add_task(abort_request) 412 | return StreamingResponse(completion_stream_generator(), 413 | media_type="text/event-stream") 414 | try: 415 | # logger.info('[%s] req request: [%s] generate_params: [%s]', request_id, request_dict, params.to_json()) 416 | response = await app_ctx["asyncEngine"].generate( 417 | query=query, 418 | system_prompt=system, 419 | history_list=history, 420 | params=params, 421 | timeout=6000, 422 | request_id=request_id 423 | ) 424 | except Exception as e: 425 | logger.error('[%s] process fail msg: [%s]', request_id, str(e)) 426 | raise e 427 | # logger.info( 428 | # '[%s] resp elapsed: [%.4fs] result: [%s]', 429 | # request_id, 430 | # time.time() - begin, 431 | # response 432 | # ) 433 | response = trim_stop_words(response, stop_words) 434 | if request.functions: 435 | choice_data = parse_response(response) 436 | else: 437 | choice_data = ChatCompletionResponseChoice( 438 | index=0, 439 | message=ChatMessage(role="assistant", content=response), 440 | finish_reason="stop", 441 | ) 442 | response = ChatCompletionResponse( 443 | id=request_id, 444 | created=created_time, 445 | model=model_name, 446 | choices=[choice_data], 447 | usage=UsageInfo(), 448 | ) 449 | if request.stream: 450 | resp = response.json(ensure_ascii=False) 451 | 452 | async def fake_stream_generator() -> AsyncGenerator[str, None]: 453 | yield f"data: {resp}\n\n" 454 | yield "data: [DONE]\n\n" 455 | 456 | return StreamingResponse(fake_stream_generator(), 457 | media_type="text/event-stream") 458 | return response 459 | 460 | 461 | def parse_args(): 462 | parser = argparse.ArgumentParser() 463 | parser.add_argument("--host", type=str, default="0.0.0.0") 464 | parser.add_argument("--port", type=int, default=8000) 465 | parser.add_argument("--workers", type=int, default=1) 466 | parser.add_argument("--model_type", type=str, default="qwen2-chat") 467 | parser = ServerConf.add_cli_args(parser) 468 | return parser.parse_args() 469 | 470 | 471 | def wait_engine_ready(engine:TritonServerEngine): 472 | engine.wait_ready() 473 | 474 | 475 | def health_check(args, engine:TritonServerEngine): 476 | def server_live_job(): 477 | err_cnt = 0 478 | while True: 479 | try: 480 | time.sleep(2) 481 | if not engine.is_server_live(): 482 | logger.warning('server is not live') 483 | err_cnt = err_cnt + 1 484 | except Exception as e: 485 | logger.warning('sever_live_job err: %s', e) 486 | err_cnt = err_cnt + 1 487 | if err_cnt > 5: 488 | cmd = "lsof -i :" + str(args.port) + " | awk '{print $2}' | grep -v PID| xargs kill -9" 489 | logger.error('server is not live > 5 times, exit!, exec cmd: [%s]', cmd) 490 | os.system(cmd) 491 | 492 | thread = Thread(target=server_live_job) 493 | thread.start() 494 | return thread 495 | 496 | 497 | if __name__ == "__main__": 498 | logger.info('fastapi-trt-llm-server start') 499 | args = parse_args() 500 | logger.info(f"args: {args}") 501 | logger.info("initiating engine.") 502 | 503 | server_conf = ServerConf.from_cli_args(args) 504 | model_type = server_conf.model_type 505 | logger.info(f"model_type: {server_conf.model_type}") 506 | 507 | if server_conf.model_type.endswith('-vl'): 508 | llm_model_name = server_conf.model_name + '-llm' 509 | visual_model_name = server_conf.model_name + '-visual' 510 | 511 | llm_engine = TritonServerEngine(server_conf.server_url, server_conf.tokenizer_path, llm_model_name, server_conf.model_type) 512 | wait_engine_ready(llm_engine) 513 | health_task = health_check(args, llm_engine) 514 | visual_engine = 
TritonServerEngine(server_conf.server_url, server_conf.tokenizer_path, visual_model_name, server_conf.model_type) 515 | wait_engine_ready(visual_engine) 516 | health_task = health_check(args, visual_engine) 517 | else: 518 | engine = TritonServerEngine(server_conf.server_url, server_conf.tokenizer_path, server_conf.model_name, server_conf.model_type) 519 | wait_engine_ready(engine) 520 | health_task = health_check(args, engine) 521 | 522 | logger.info("start http server") 523 | uvicorn.run("__main__:app", 524 | host=args.host, 525 | port=args.port, 526 | workers=args.workers, 527 | use_colors=False, 528 | reload=True, 529 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE) 530 | health_task.join() 531 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/logger.py: -------------------------------------------------------------------------------- 1 | """Logging configuration for TRT-LLM.""" 2 | import logging 3 | import sys 4 | 5 | _FORMAT = "%(levelname)s %(asctime)s.%(msecs)03d [%(filename)s:%(lineno)d] %(message)s" 6 | _DATE_FORMAT = "%m-%d %H:%M:%S" 7 | 8 | 9 | class NewLineFormatter(logging.Formatter): 10 | """Adds logging prefix to newlines to align multi-line messages.""" 11 | 12 | def __init__(self, fmt, datefmt=None): 13 | logging.Formatter.__init__(self, fmt, datefmt) 14 | 15 | def format(self, record): 16 | msg = logging.Formatter.format(self, record) 17 | if record.message != "": 18 | parts = msg.split(record.message) 19 | msg = msg.replace("\n", "\r\n" + parts[0]) 20 | return msg 21 | 22 | 23 | _root_logger = logging.getLogger("trt-llm-server") 24 | _default_handler = None 25 | 26 | 27 | def get_formatter(): 28 | return NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) 29 | 30 | 31 | def _setup_logger(): 32 | print('setup trt-llm logger') 33 | logging.basicConfig(level = logging.INFO) 34 | _root_logger.setLevel(logging.INFO) 35 | global _default_handler 36 | if _default_handler is None: 37 | _default_handler = logging.StreamHandler(sys.stdout) 38 | _default_handler.flush = sys.stdout.flush # type: ignore 39 | _default_handler.setLevel(logging.INFO) 40 | _root_logger.addHandler(_default_handler) 41 | _default_handler.setFormatter(get_formatter()) 42 | # Setting this will avoid the message 43 | # being propagated to the parent logger. 44 | _root_logger.propagate = False 45 | 46 | 47 | # The logger is initialized when the module is imported. 48 | # This is thread-safe as the module is only imported once, 49 | # guaranteed by the Python GIL. 
50 | _setup_logger() 51 | 52 | 53 | def init_logger(name: str): 54 | return logging.getLogger(name) 55 | 56 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/models/base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class BaseModel(ABC): 6 | def __init__(self): 7 | pass 8 | 9 | def encode(self, tokenizer, query, system_prompt, history_list: list): 10 | pass 11 | 12 | def decode(self, tokenizer, output_ids, inputs_token_lens, cutoff_len=0): 13 | new_ids = [[]] 14 | for id in output_ids[0]: 15 | new_ids[0].append(id[cutoff_len:]) 16 | new_ids = np.array(new_ids) 17 | return tokenizer.batch_decode(new_ids[0], skip_special_tokens=True) 18 | 19 | def make_context(self): 20 | pass 21 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/models/prompt_template.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from fastchat.conversation import register_conv_template 3 | from fastchat.model.model_adapter import register_model_adapter 4 | from fastchat.conversation import Conversation 5 | # from fastapi_tritonserver.models.nanbeige import NabeigeModelAdapter, get_nanbeige_conversation 6 | 7 | 8 | # def register_prompt_conf(): 9 | # register_templates() 10 | # register_adaptors() 11 | 12 | 13 | # def register_templates(): 14 | # register_conv_template(get_nanbeige_conversation()) 15 | # 16 | # 17 | # def register_adaptors(): 18 | # register_model_adapter(NabeigeModelAdapter) 19 | 20 | 21 | def conv_add_openai_messages(conv: Conversation, messages: List[Dict[str, str]]): 22 | for message in messages: 23 | msg_role = message["role"] 24 | if msg_role == "system": 25 | # rewrite system prompt 26 | conv.system_message = message["content"] 27 | elif msg_role == "user": 28 | conv.append_message(conv.roles[0], message["content"]) 29 | elif msg_role == "assistant": 30 | conv.append_message(conv.roles[1], message["content"]) 31 | else: 32 | raise ValueError(f"Unknown role: {msg_role}") 33 | 34 | # Add a blank message for the assistant. 
35 | conv.append_message(conv.roles[1], '') -------------------------------------------------------------------------------- /src/fastapi_tritonserver/models/qwen2chat.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from fastapi_tritonserver.logger import _root_logger 3 | logger = _root_logger 4 | 5 | 6 | DEFAULT_PROMPT_TEMPLATES = { 7 | 'InternLMForCausalLM': 8 | "<|User|>:{input_text}\n<|Bot|>:", 9 | 'qwen': 10 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n", 11 | 'Qwen2ForCausalLM': 12 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n", 13 | } 14 | 15 | 16 | class Qwen2ChatModel(BaseModel): 17 | 18 | def __init__(self): 19 | pass 20 | 21 | def encode( 22 | self, 23 | tokenizer, 24 | query, 25 | system_prompt="You are a helpful assistant.", 26 | history_list=None 27 | ): 28 | # use make_content to generate prompt 29 | # print("input_id_list len", len(input_id_list)) 30 | messages = [ 31 | {"role": "system", "content": system_prompt}, 32 | ] 33 | for (old_query, old_response) in history_list: 34 | messages.append( 35 | {"role": "user", "content": old_query} 36 | ) 37 | messages.append( 38 | {"role": "assistant", "content": old_response} 39 | ) 40 | if isinstance(query, str): 41 | messages.append({"role": "user", "content": query}) 42 | prompt = tokenizer.apply_chat_template( 43 | messages, 44 | tokenize=False, 45 | add_generation_prompt=True 46 | ) 47 | # print("prompt: ", prompt) 48 | # used in function call 49 | if not isinstance(query, str): 50 | im_end = "<|im_end|>" 51 | # right trip 52 | prompt = prompt[: -len("<|im_start|>assistant") - 1] 53 | prompt = prompt.rstrip() 54 | prompt = prompt[: -len(im_end)] 55 | # stop_words.append(im_end) 56 | encoded_outputs = tokenizer( 57 | prompt, 58 | ) 59 | return encoded_outputs 60 | 61 | def decode(self, tokenizer, output_ids, input_lengths, cutoff_len=0): 62 | return tokenizer.decode(output_ids[0][0], skip_special_tokens=True) -------------------------------------------------------------------------------- /src/fastapi_tritonserver/models/qwenvl.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from transformers import AutoTokenizer 3 | from fastapi_tritonserver.logger import _root_logger 4 | from typing import List 5 | from torchvision import transforms 6 | from torchvision.transforms import InterpolationMode 7 | from PIL import Image 8 | import requests 9 | import torch 10 | from typing import Tuple, List, Union 11 | from fastapi_tritonserver.logger import _root_logger 12 | 13 | logger = _root_logger 14 | 15 | class QwenvlModel(BaseModel): 16 | 17 | def __init__(self): 18 | image_size = 448 19 | 20 | mean = (0.48145466, 0.4578275, 0.40821073) 21 | std = (0.26862954, 0.26130258, 0.27577711) 22 | 23 | self.image_transform = transforms.Compose([ 24 | transforms.Resize( 25 | (image_size,image_size), 26 | interpolation = InterpolationMode.BICUBIC 27 | ), 28 | transforms.ToTensor(), 29 | transforms.Normalize(mean=mean,std=std), 30 | 31 | ]) 32 | 33 | def encode(self, image_paths: List[str]): 34 | images = [] 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | if image_paths: 37 | for image_path in image_paths: 38 | if image_path.startswith("http://") or image_path.startswith("https://"): 39 | try: 40 | 
image = requests.get(image_path,stream=True, timeout = 3).raw 41 | except requests.exceptions.Timeout as e: 42 | logger.info(e) 43 | return torch.tensor([]) 44 | except exceptions.MissingSchema as e: 45 | logger.info(e) 46 | return torch.tensor([]) 47 | image = Image.open(image) 48 | else: 49 | image = Image.open(image_path) 50 | image = image.convert("RGB") 51 | images.append(self.image_transform(image)) 52 | images = torch.stack(images, dim=0) 53 | return images 54 | else: 55 | return torch.tensor([]) 56 | 57 | def decode(self, tokenizer, output_ids, input_lengths, cutoff_len=0): 58 | ### Fix me: For now, although the incoming images are in array format, only single images are supported 59 | new_ids = output_ids[0][0, input_lengths:] 60 | return tokenizer.decode(new_ids, skip_special_tokens=True) 61 | 62 | def make_context(self, tokenizer, prompt: str, images: List[str], visual_output): 63 | ## fix me: images is an array, but for now I'll just take the first one 64 | image = images[0] 65 | content_list = [] 66 | content_list.append({'image': image}) 67 | content_list.append({'text': prompt}) 68 | query = tokenizer.from_list_format(content_list) 69 | 70 | def qwenvl_make_context( 71 | tokenizer, 72 | query: str, 73 | history: List[Tuple[str, str]] = None, 74 | system: str = "You are a helpful assistant.", 75 | max_window_size: int = 6144, 76 | chat_format: str = "chatml", 77 | ): 78 | if history is None: 79 | history = [] 80 | 81 | if chat_format == "chatml": 82 | im_start, im_end = "<|im_start|>", "<|im_end|>" 83 | im_start_tokens = [tokenizer.im_start_id]#151644 84 | im_end_tokens = [tokenizer.im_end_id]#[151645] 85 | nl_tokens = tokenizer.encode("\n") 86 | 87 | def _tokenize_str(role, content): 88 | return f"{role}\n{content}", tokenizer.encode( 89 | role, allowed_special=set(tokenizer.IMAGE_ST) 90 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST)) 91 | 92 | system_text, system_tokens_part = _tokenize_str("system", system) 93 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 94 | 95 | raw_text = "" 96 | context_tokens = [] 97 | 98 | for turn_query, turn_response in reversed(history): 99 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 100 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 101 | if turn_response is not None: 102 | response_text, response_tokens_part = _tokenize_str( 103 | "assistant", turn_response 104 | ) 105 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens 106 | 107 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 108 | prev_chat = ( 109 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 110 | ) 111 | else: 112 | next_context_tokens = nl_tokens + query_tokens + nl_tokens 113 | prev_chat = f"\n{im_start}{query_text}{im_end}\n" 114 | 115 | current_context_size = ( 116 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 117 | ) 118 | if current_context_size < max_window_size: 119 | context_tokens = next_context_tokens + context_tokens 120 | raw_text = prev_chat + raw_text 121 | else: 122 | break 123 | 124 | context_tokens = system_tokens + context_tokens 125 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 126 | context_tokens += ( 127 | nl_tokens 128 | + im_start_tokens 129 | + _tokenize_str("user", query)[1] 130 | + im_end_tokens 131 | + nl_tokens 132 | + im_start_tokens 133 | + tokenizer.encode("assistant") 134 | + nl_tokens 135 | ) 136 | raw_text += 
f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 137 | elif chat_format == "raw": 138 | raw_text = query 139 | context_tokens = tokenizer.encode(raw_text) 140 | else: 141 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 142 | 143 | return raw_text, context_tokens 144 | 145 | raw_text, context_tokens = qwenvl_make_context(tokenizer, query,history=None) 146 | 147 | input_ids = torch.tensor([context_tokens]) 148 | bos_pos = torch.where(input_ids == 151857) ## self.config.visual['image_start_id'] 149 | eos_pos = torch.where(input_ids == 151858) ## self.config.visual['image_start_id'] + 1 150 | assert (bos_pos[0] == eos_pos[0]).all() 151 | img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1) 152 | 153 | vocab_size = 151936 ##self.config.vocab_size 154 | fake_prompt_id = torch.arange(vocab_size, 155 | vocab_size + 156 | visual_output.shape[0] * visual_output.shape[1]) 157 | fake_prompt_id = fake_prompt_id.reshape(visual_output.shape[0], 158 | visual_output.shape[1]) 159 | for idx, (i, a, b) in enumerate(img_pos): 160 | input_ids[i][a + 1 : b] = fake_prompt_id[idx] 161 | input_ids = input_ids.contiguous().to(torch.int32) 162 | input_lengths = torch.tensor(input_ids.size(1), dtype=torch.int32) 163 | 164 | visual_output = torch.tensor(visual_output) 165 | prompt_table = visual_output.numpy() 166 | 167 | return input_ids, input_lengths, prompt_table 168 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/protocols/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohb/fastapi_tritonserver/b2dd2007b1895c8205739932db9af9e6f27dce92/src/fastapi_tritonserver/protocols/__init__.py -------------------------------------------------------------------------------- /src/fastapi_tritonserver/protocols/fastapi.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Literal, Optional, Union 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class GenerateRequest(BaseModel): 7 | prompt: Optional[str] = None 8 | prompts: Optional[List[str]] = None 9 | image: Optional[str] = None 10 | images: Optional[List[str]] = None 11 | timeout: Optional[float] = 6000 12 | only_return_output: Optional[bool] = False 13 | uuid: Optional[str] = None 14 | stream: Optional[bool] = False 15 | 16 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/protocols/openai.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Dict, List, Literal, Optional, Union 3 | 4 | from pydantic import BaseModel, Field 5 | from fastapi_tritonserver.utils.tools import random_uuid 6 | 7 | 8 | class ErrorResponse(BaseModel): 9 | object: str = "error" 10 | message: str 11 | type: str 12 | param: Optional[str] = None 13 | code: Optional[str] = None 14 | 15 | 16 | class ModelPermission(BaseModel): 17 | id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") 18 | object: str = "model_permission" 19 | created: int = Field(default_factory=lambda: int(time.time())) 20 | allow_create_engine: bool = False 21 | allow_sampling: bool = True 22 | allow_logprobs: bool = True 23 | allow_search_indices: bool = False 24 | allow_view: bool = True 25 | allow_fine_tuning: bool = False 26 | organization: str = "*" 27 | group: Optional[str] = None 28 | is_blocking: str = False 29 | 30 | 31 | 
class ModelCard(BaseModel): 32 | id: str 33 | object: str = "model" 34 | created: int = Field(default_factory=lambda: int(time.time())) 35 | owned_by: str = "vllm" 36 | root: Optional[str] = None 37 | parent: Optional[str] = None 38 | permission: List[ModelPermission] = Field(default_factory=list) 39 | 40 | 41 | class ModelList(BaseModel): 42 | object: str = "list" 43 | data: List[ModelCard] = Field(default_factory=list) 44 | 45 | 46 | class UsageInfo(BaseModel): 47 | prompt_tokens: int = 0 48 | total_tokens: int = 0 49 | completion_tokens: Optional[int] = 0 50 | 51 | 52 | class ChatMessage(BaseModel): 53 | role: Literal["user", "assistant", "system", "function"] 54 | content: Optional[str] 55 | function_call: Optional[Dict] = None 56 | 57 | 58 | class ChatCompletionRequest(BaseModel): 59 | model: Optional[str] = None 60 | messages: Union[str, List[ChatMessage]] 61 | functions: Optional[List[Dict]] = None 62 | temperature: Optional[float] = 0.7 63 | top_p: Optional[float] = None 64 | n: Optional[int] = 1 65 | max_tokens: Optional[int] = 512 66 | stop: Optional[Union[str, List[str]]] = Field(default_factory=list) 67 | stream: Optional[bool] = False 68 | presence_penalty: Optional[float] = 0.0 69 | frequency_penalty: Optional[float] = 0.0 70 | logit_bias: Optional[Dict[str, float]] = None 71 | user: Optional[str] = None 72 | # Additional parameters supported by vLLM 73 | best_of: Optional[int] = None 74 | top_k: Optional[int] = None 75 | ignore_eos: Optional[bool] = False 76 | use_beam_search: Optional[bool] = False 77 | output_accumulate: Optional[bool] = False 78 | 79 | 80 | class CompletionRequest(BaseModel): 81 | model: Optional[str] = "default" 82 | prompt: Union[str, List[str]] 83 | suffix: Optional[str] = None 84 | max_tokens: Optional[int] = 512 85 | temperature: Optional[float] = None 86 | top_p: Optional[float] = None 87 | n: Optional[int] = 1 88 | stream: Optional[bool] = False 89 | logprobs: Optional[int] = None 90 | echo: Optional[bool] = False 91 | stop: Optional[Union[str, List[str]]] = Field(default_factory=list) 92 | presence_penalty: Optional[float] = 0.0 93 | frequency_penalty: Optional[float] = 0.0 94 | best_of: Optional[int] = None 95 | logit_bias: Optional[Dict[str, float]] = None 96 | user: Optional[str] = None 97 | top_k: Optional[int] = None 98 | ignore_eos: Optional[bool] = False 99 | use_beam_search: Optional[bool] = False 100 | output_accumulate: Optional[bool] = False 101 | 102 | 103 | class LogProbs(BaseModel): 104 | text_offset: List[int] = Field(default_factory=list) 105 | token_logprobs: List[Optional[float]] = Field(default_factory=list) 106 | tokens: List[str] = Field(default_factory=list) 107 | top_logprobs: List[Optional[Dict[str, 108 | float]]] = Field(default_factory=list) 109 | 110 | 111 | class CompletionResponseChoice(BaseModel): 112 | index: int 113 | text: str 114 | logprobs: Optional[LogProbs] = None 115 | finish_reason: Optional[Literal["stop", "length"]] = None 116 | 117 | 118 | class CompletionResponse(BaseModel): 119 | id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") 120 | object: str = "text_completion" 121 | created: int = Field(default_factory=lambda: int(time.time())) 122 | model: str 123 | choices: List[CompletionResponseChoice] 124 | usage: UsageInfo 125 | 126 | 127 | class CompletionResponseStreamChoice(BaseModel): 128 | index: int 129 | text: str 130 | logprobs: Optional[LogProbs] = None 131 | finish_reason: Optional[Literal["stop", "length"]] = None 132 | 133 | 134 | class CompletionStreamResponse(BaseModel): 135 | id: 
str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") 136 | object: str = "text_completion" 137 | created: int = Field(default_factory=lambda: int(time.time())) 138 | model: str 139 | choices: List[CompletionResponseStreamChoice] 140 | 141 | 142 | class ChatCompletionResponseChoice(BaseModel): 143 | index: int 144 | message: ChatMessage 145 | finish_reason: Optional[Literal["stop", "length", "function_call"]] = None 146 | 147 | 148 | class ChatCompletionResponse(BaseModel): 149 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") 150 | object: str = "chat.completion" 151 | created: int = Field(default_factory=lambda: int(time.time())) 152 | model: str 153 | choices: List[ChatCompletionResponseChoice] 154 | usage: UsageInfo 155 | 156 | 157 | class DeltaMessage(BaseModel): 158 | role: Optional[str] = None 159 | content: Optional[str] = None 160 | 161 | 162 | class ChatCompletionResponseStreamChoice(BaseModel): 163 | index: int 164 | delta: DeltaMessage 165 | finish_reason: Optional[Literal["stop", "length"]] = None 166 | 167 | 168 | class ChatCompletionStreamResponse(BaseModel): 169 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") 170 | object: str = "chat.completion.chunk" 171 | created: int = Field(default_factory=lambda: int(time.time())) 172 | model: str 173 | choices: List[ChatCompletionResponseStreamChoice] 174 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/sampling_params.py: -------------------------------------------------------------------------------- 1 | """Sampling parameters for text generation.""" 2 | from typing import List, Optional, Union 3 | import json 4 | 5 | _SAMPLING_EPS = 1e-5 6 | 7 | 8 | class SamplingParams: 9 | """Sampling parameters for text generation. 10 | 11 | Overall, we follow the sampling parameters from the OpenAI text completion 12 | API (https://platform.openai.com/docs/api-reference/completions/create). 13 | In addition, we support beam search, which is not supported by OpenAI. 14 | 15 | Args: 16 | max_output_len 17 | temperature 18 | top_p 19 | top_k 20 | beam_width 21 | repetition_penalty 22 | presence_penalty 23 | len_penalty 24 | min_length 25 | random_seed 26 | end_id 27 | """ 28 | 29 | def __init__( 30 | self, 31 | max_output_len: int = 16, 32 | temperature: float = None, 33 | top_p: float = None, 34 | top_k: int = None, 35 | beam_width: Optional[int] = None, 36 | repetition_penalty: Optional[float] = None, 37 | presence_penalty: Optional[float] = None, 38 | len_penalty: Optional[float] = None, 39 | random_seed: Optional[int] = None, 40 | end_id: Optional[List[int]] = None, 41 | pad_id: Optional[List[int]] = None, 42 | stop_words: Optional[List[str]] = None, 43 | **kwargs: object 44 | ) -> None: 45 | self.max_output_len = max_output_len 46 | self.beam_width = beam_width 47 | self.temperature = temperature 48 | self.top_p = top_p 49 | self.top_k = top_k 50 | self.repetition_penalty = repetition_penalty 51 | self.presence_penalty = presence_penalty 52 | self.len_penalty = len_penalty 53 | self.random_seed = random_seed 54 | self.end_id = end_id 55 | self.pad_id = pad_id 56 | self.stop_words = stop_words 57 | 58 | self._verify_args() 59 | # if self.beam_width: 60 | # self._verity_beam_search() 61 | # elif self.temperature is not None and self.temperature < _SAMPLING_EPS: 62 | # # Zero temperature means greedy sampling. 
63 | # self._verify_greedy_sampling() 64 | 65 | def _verify_args(self) -> None: 66 | if self.repetition_penalty: 67 | if not -2.0 <= self.repetition_penalty <= 2.0: 68 | raise ValueError("repetition_penalty must be in [-2, 2], got " 69 | f"{self.repetition_penalty}.") 70 | if self.presence_penalty: 71 | if not -2.0 <= self.presence_penalty <= 2.0: 72 | raise ValueError("presence_penalty must be in [-2, 2], got " 73 | f"{self.presence_penalty}.") 74 | if self.temperature is not None and self.temperature < 0.0: 75 | raise ValueError( 76 | f"temperature must be non-negative, got {self.temperature}.") 77 | if self.top_p is not None: 78 | if not 0.0 < self.top_p <= 1.0: 79 | raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") 80 | if self.top_k is not None: 81 | if self.top_k < -1 or self.top_k == 0: 82 | raise ValueError(f"top_k must be -1 (disable), or at least 1, " 83 | f"got {self.top_k}.") 84 | if self.max_output_len < 1: 85 | raise ValueError( 86 | f"max_output_len must be at least 1, got {self.max_output_len}.") 87 | 88 | def _verity_beam_search(self) -> None: 89 | if self.temperature > _SAMPLING_EPS: 90 | raise ValueError("temperature must be 0 when using beam search.") 91 | if self.top_p < 1.0 - _SAMPLING_EPS: 92 | raise ValueError("top_p must be 1 when using beam search.") 93 | if self.top_k != -1: 94 | raise ValueError("top_k must be -1 when using beam search.") 95 | 96 | def _verify_greedy_sampling(self) -> None: 97 | if self.top_p is not None and self.top_p < 1.0 - _SAMPLING_EPS: 98 | raise ValueError("top_p must be 1 when using greedy sampling.") 99 | if self.top_p is not None and self.top_k != -1: 100 | raise ValueError("top_k must be -1 when using greedy sampling.") 101 | 102 | 103 | def to_json(self): 104 | return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True) 105 | -------------------------------------------------------------------------------- /src/fastapi_tritonserver/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohb/fastapi_tritonserver/b2dd2007b1895c8205739932db9af9e6f27dce92/src/fastapi_tritonserver/utils/__init__.py -------------------------------------------------------------------------------- /src/fastapi_tritonserver/utils/generate_cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers.generation.utils import GenerationConfig 3 | 4 | 5 | def parse_cfg(path: str): 6 | cfg = GenerationConfig.from_pretrained(path) 7 | if isinstance(cfg.eos_token_id, list): 8 | end_id = cfg.eos_token_id[0] 9 | else: 10 | end_id = cfg.eos_token_id 11 | return { 12 | "end_id": end_id, 13 | "pad_id": end_id, 14 | "top_k": cfg.top_k, 15 | "top_p": cfg.top_p, 16 | "temperature": cfg.temperature, 17 | "len_penalty": 1, 18 | "repetition_penalty": cfg.repetition_penalty, 19 | "stop": ["<|endoftext|>", "<|im_start|>"] 20 | } -------------------------------------------------------------------------------- /src/fastapi_tritonserver/utils/tools.py: -------------------------------------------------------------------------------- 1 | from random import randrange 2 | import uuid 3 | 4 | 5 | # randrange requires integer bounds, so keep the 64-bit upper bound as an exact int 6 | max_range = 2 ** 64 - 1 7 | 8 | def random_uuid() -> str: 9 | return str(uuid.uuid4().hex) 10 | 11 | 12 | def random_int64() -> int: 13 | return randrange(0, max_range) -------------------------------------------------------------------------------- /src/triton_server_helper/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohb/fastapi_tritonserver/b2dd2007b1895c8205739932db9af9e6f27dce92/src/triton_server_helper/__init__.py -------------------------------------------------------------------------------- /src/triton_server_helper/client.py: -------------------------------------------------------------------------------- 1 | import time, asyncio 2 | from typing import List 3 | import queue 4 | import threading 5 | from functools import partial 6 | import tritonclient.grpc as grpcclient 7 | import tritonclient.grpc.aio as asyncgrpcclient 8 | from tritonclient.utils import InferenceServerException 9 | from fastapi_tritonserver.logger import _root_logger 10 | 11 | logger = _root_logger 12 | 13 | 14 | class AsyncStreamClient: 15 | def __init__(self, url: str, readiness_models: List[str] = []): 16 | self.url = url 17 | self._readiness_models = readiness_models 18 | self._client = asyncgrpcclient.InferenceServerClient(url=url) 19 | 20 | async def is_server_live(self) -> bool: 21 | return await self._client.is_server_live() 22 | 23 | async def wait_server_ready(self): 24 | live = await self._client.is_server_live() 25 | while not live: 26 | logger.debug("wait_server_ready live: [%s]", live) 27 | await asyncio.sleep(1) 28 | try: 29 | live = await self._client.is_server_live() 30 | except Exception as e: 31 | logger.warn("wait server ready err: %s", str(e)) 32 | 33 | await self.wait_models_ready() 34 | 35 | async def wait_models_ready(self): 36 | all_ready = False 37 | while not all_ready: 38 | curr_state = True 39 | for model_name in self._readiness_models: 40 | ready = await self._client.is_model_ready(model_name) 41 | if not ready: 42 | curr_state = False 43 | all_ready = curr_state 44 | await asyncio.sleep(1) 45 | 46 | async def get_model_meta(self, model_name): 47 | return await self._client.get_model_metadata(model_name) 48 | 49 | def infer(self, request, timeout: int = 1000): 50 | """ 51 | yield: 52 | - output_ids 53 | - sequence_length 54 | - cum_log_probs 55 | - output_log_probs 56 | """ 57 | async def async_request_iterator(): 58 | yield { 59 | **request, 60 | "sequence_start": True, 61 | "sequence_end": True, 62 | } 63 | 64 | # Start streaming 65 | response_iterator = self._client.stream_infer( 66 | inputs_iterator=async_request_iterator(), 67 | stream_timeout=timeout, 68 | ) 69 | 70 | return response_iterator 71 | 72 | 73 | class UserData: 74 | 75 | def __init__(self): 76 | self._completed_requests = queue.Queue() 77 | 78 | @property 79 | def completed_requests(self): 80 | return self._completed_requests 81 | 82 | 83 | def callback(user_data, result, error): 84 | if error: 85 | user_data.completed_requests.put(error) 86 | else: 87 | user_data.completed_requests.put(result) 88 | 89 | 90 | class StreamClient: 91 | def __init__(self, url: str, readiness_models: List[str] = []): 92 | self.url = url 93 | self._readiness_models = readiness_models 94 | self._client = grpcclient.InferenceServerClient(url=url) 95 | self._lock = threading.Lock() 96 | 97 | def is_server_live(self) -> bool: 98 | return self._client.is_server_live() 99 | 100 | def wait_server_ready(self): 101 | live = False 102 | while not live: 103 | time.sleep(1) 104 | try: 105 | live = self._client.is_server_live() 106 | except Exception as e: 107 | logger.warn("wait server ready err: %s", str(e)) 108 | 109 | self.wait_models_ready() 110 | 111 | def wait_models_ready(self): 112 | all_ready = False 113 | while not all_ready: 114 | 
curr_state = True 115 | for model_name in self._readiness_models: 116 | ready = self._client.is_model_ready(model_name) 117 | if not ready: 118 | curr_state = False 119 | all_ready = curr_state 120 | time.sleep(1) 121 | 122 | def get_model_meta(self, model_name): 123 | return self._client.get_model_metadata(model_name) 124 | 125 | # todo: add thread lock 126 | def infer(self, request, timeout: int = 1000): 127 | user_data = UserData() 128 | with self._lock: 129 | try: 130 | # Establish stream 131 | self._client.start_stream(callback=partial(callback, user_data)) 132 | # Send request 133 | self._client.async_stream_infer(request['model_name'], request['inputs'], timeout=timeout) 134 | finally: 135 | # Wait for server to close the stream 136 | self._client.stop_stream() 137 | 138 | # Parse the responses 139 | while True: 140 | try: 141 | result = user_data.completed_requests.get(block=False) 142 | except Exception: 143 | break 144 | 145 | if type(result) == InferenceServerException: 146 | raise result 147 | else: 148 | return result 149 | -------------------------------------------------------------------------------- /start_api.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 -m fastapi_tritonserver.entrypoints.api_server \ 3 | --port 9900 --host 0.0.0.0 \ 4 | --model-name tensorrt_llm \ 5 | --tokenizer-path ${TOKENIZER_PATH} \ 6 | --server-url ${TRITON_SERVER_HOST}:${TRITON_SERVER_PORT} \ 7 | --model_type ${MODEL_TYPE} -------------------------------------------------------------------------------- /start_openapi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 -m fastapi_tritonserver.entrypoints.openai_api \ 3 | --port 9900 --host 0.0.0.0 \ 4 | --model-name tensorrt_llm \ 5 | --tokenizer-path ${TOKENIZER_PATH} \ 6 | --server-url ${TRITON_SERVER_HOST}:${TRITON_SERVER_PORT} \ 7 | --model_type ${MODEL_TYPE} --------------------------------------------------------------------------------
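
Once the server launched by `start_openapi.sh` is running, it exposes the OpenAI-style `/v1/chat/completions` route defined in `openai_api.py` on port 9900. The sketch below is a minimal, illustrative client, assuming the service is reachable on localhost with the default port and that a Qwen chat model is deployed behind Triton; the model name and prompt are placeholders, not values taken from this repository.

```python
# Minimal client sketch for the OpenAI-compatible endpoint served on port 9900.
# Assumptions: the server from start_openapi.sh runs on localhost, and the
# model name and prompt below are illustrative placeholders.
import requests

payload = {
    "model": "qwen2-chat",  # placeholder model name
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Briefly introduce Triton Inference Server."},
    ],
    "max_tokens": 256,
    "temperature": 0.7,
    "stream": False,
}

resp = requests.post("http://127.0.0.1:9900/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
# The non-streaming path returns a ChatCompletionResponse; print the first choice.
print(resp.json()["choices"][0]["message"]["content"])
```

For streaming, set `stream` to `True` and consume the `text/event-stream` output the handler yields (`data: ...` chunks terminated by `data: [DONE]`).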