├── .github └── workflows │ └── python-check.yml ├── .gitignore ├── Dockerfile ├── Makefile ├── README.md ├── benchmark ├── README.md ├── bert.bin └── results │ ├── bert.md │ └── distilbert_serving_benchmark.png ├── pyproject.toml └── src ├── args.py ├── bentoml_services ├── README.md └── stable_diffusion │ ├── client.py │ └── service.py ├── mlserver_services ├── README.md └── stable_diffusion │ ├── client.py │ ├── model-settings.json │ ├── runtime.py │ └── settings.json ├── mosec_services ├── bert.py ├── image_bind.py ├── llama.py ├── stable_diffusion.py └── whisper.py ├── potassium_services └── stable_diffusion.py ├── pytriton_services ├── bert │ ├── client.py │ └── service.py └── stable_diffusion │ ├── client.py │ └── service.py └── triton_services ├── client.py └── stable_diffusion ├── 1 └── model.py └── config.pbtxt /.github/workflows/python-check.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python Lint 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | workflow_dispatch: 12 | merge_group: 13 | 14 | concurrency: 15 | group: ${{ github.ref }}-${{ github.workflow }} 16 | cancel-in-progress: true 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | lint: 23 | 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - uses: actions/checkout@v3 28 | - name: Set up Python 3.11 29 | uses: actions/setup-python@v3 30 | with: 31 | python-version: "3.11" 32 | - name: Install dependencies 33 | run: | 34 | pip install ruff black 35 | - name: Lint 36 | run: | 37 | make lint 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # Ruff 163 | .ruff_cache 164 | 165 | # semver 166 | src/_version.py 167 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG base=nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04 2 | 3 | FROM ${base} 4 | 5 | ENV DEBIAN_FRONTEND=noninteractive LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8 6 | ENV PATH /opt/conda/bin:$PATH 7 | 8 | ARG MOSEC_PORT=8000 9 | ENV MOSEC_PORT=${MOSEC_PORT} 10 | 11 | ARG CONDA_VERSION=py310_23.3.1-0 12 | 13 | RUN apt update && \ 14 | apt install -y --no-install-recommends \ 15 | wget \ 16 | git \ 17 | build-essential \ 18 | ca-certificates && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | RUN set -x && \ 22 | UNAME_M="$(uname -m)" && \ 23 | if [ "${UNAME_M}" = "x86_64" ]; then \ 24 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-x86_64.sh"; \ 25 | SHA256SUM="aef279d6baea7f67940f16aad17ebe5f6aac97487c7c03466ff01f4819e5a651"; \ 26 | elif [ "${UNAME_M}" = "s390x" ]; then \ 27 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-s390x.sh"; \ 28 | SHA256SUM="ed4f51afc967e921ff5721151f567a4c43c4288ac93ec2393c6238b8c4891de8"; \ 29 | elif [ "${UNAME_M}" = "aarch64" ]; then \ 30 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-aarch64.sh"; \ 31 | SHA256SUM="6950c7b1f4f65ce9b87ee1a2d684837771ae7b2e6044e0da9e915d1dee6c924c"; \ 32 | elif [ "${UNAME_M}" = "ppc64le" ]; then \ 33 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-ppc64le.sh"; \ 34 | SHA256SUM="b3de538cd542bc4f5a2f2d2a79386288d6e04f0e1459755f3cefe64763e51d16"; \ 35 | fi && \ 36 | wget "${MINICONDA_URL}" -O miniconda.sh -q && \ 37 | echo "${SHA256SUM} miniconda.sh" > shasum && \ 38 | if [ "${CONDA_VERSION}" != "latest" ]; then sha256sum --check --status shasum; fi && \ 39 | mkdir -p /opt && \ 40 | bash miniconda.sh -b -p /opt/conda && \ 41 | rm miniconda.sh shasum && \ 42 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 43 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 44 | echo "conda activate base" >> ~/.bashrc && \ 45 | find /opt/conda/ -follow -type f -name '*.a' -delete && \ 46 | find /opt/conda/ -follow -type f -name '*.js.map' -delete && \ 47 | /opt/conda/bin/conda clean -afy 48 | 49 | RUN conda create -n envd python=3.10 50 | 51 | ENV ENVD_PREFIX=/opt/conda/envs/envd/bin 52 | 53 | RUN update-alternatives --install /usr/bin/python python ${ENVD_PREFIX}/python 1 && \ 54 | update-alternatives --install /usr/bin/python3 python3 ${ENVD_PREFIX}/python3 1 && \ 55 | update-alternatives --install /usr/bin/pip pip ${ENVD_PREFIX}/pip 1 && \ 56 | update-alternatives --install /usr/bin/pip3 pip3 ${ENVD_PREFIX}/pip3 1 57 | 58 | RUN pip install torch transformers 59 | 60 | RUN mkdir -p /workspace 61 | 62 | COPY . workspace/ 63 | WORKDIR /workspace 64 | RUN pip install -e . 
65 | 66 | ENTRYPOINT [ "bash" ] 67 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PY_SOURCE=src benchmark 2 | 3 | lint: 4 | @black --check --diff ${PY_SOURCE} 5 | @ruff check ${PY_SOURCE} 6 | 7 | format: 8 | @black ${PY_SOURCE} 9 | @ruff check --fix ${PY_SOURCE} 10 | 11 | clean: 12 | @-rm -rf dist build __pycache__ src/*.egg-info src/_version.py 13 | 14 | build: 15 | @python -m build 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Inference Benchmark 4 | 5 | Maximize the potential of your models with this inference benchmark tool. 6 | 7 |
8 | 9 |

10 | 11 | 12 |

13 | 14 | # What is it 15 | 16 | Inference benchmark provides a standard way to measure the performance of inference workloads. It is also a tool that allows you to evaluate and optimize the performance of your inference workloads. 17 | 18 | # Results 19 | 20 | ## Bert 21 | 22 | We benchmarked [pytriton (triton-inference-server)](https://github.com/triton-inference-server/pytriton) and [mosec](https://github.com/mosecorg/mosec) with bert. We enabled dynamic batching for both frameworks with max batch size 32 and max wait time 10ms. Please checkout the [result](./benchmark/results/bert.md) for more details. 23 | 24 | ![DistilBert](./benchmark/results/distilbert_serving_benchmark.png) 25 | 26 | More [results with different models on different serving frameworks](https://github.com/tensorchord/inference-benchmark/issues/7) are coming soon. 27 | -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | ## HTTP load test tool 2 | 3 | - [hey](https://github.com/rakyll/hey) 4 | 5 | ## Script 6 | 7 | ### Bert 8 | 9 | - mosec 10 | - `hey -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference` 11 | - pytriton 12 | - `hey -m POST -n 50000 -H Inference-Header-Content-Length:175 -D bert.bin http://127.0.0.1:8000/v2/models/distilbert/infer` 13 | 14 | This binary file is generated by `pytriton` client: 15 | 16 | ```python 17 | b'{"id":"0","inputs":[{"name":"messages","shape":[1,1],"datatype":"BYTES","parameters":{"binary_data_size":47}}],"outputs":[{"name":"labels","parameters":{"binary_data":true}}]}+\x00\x00\x00The quick brown fox jumps over the lazy dog' 18 | ``` 19 | 20 | ![DistilBert](results/distilbert_serving_benchmark.png) 21 | 22 | Check the [result](./results/bert.md) for more details. 23 | -------------------------------------------------------------------------------- /benchmark/bert.bin: -------------------------------------------------------------------------------- 1 | {"id":"0","inputs":[{"name":"messages","shape":[1,1],"datatype":"BYTES","parameters":{"binary_data_size":47}}],"outputs":[{"name":"labels","parameters":{"binary_data":true}}]}+The quick brown fox jumps over the lazy dog -------------------------------------------------------------------------------- /benchmark/results/bert.md: -------------------------------------------------------------------------------- 1 | ## Environment 2 | 3 | GCP VM with: 4 | - Machine type: n1-standard-8 5 | - CPU: 8 vCPU Intel Haswell 6 | - RAM: 32G 7 | - GPU: 1 x NVIDIA T4 8 | - GPU Driver Version: 510.47.03 9 | - CUDA Version: 11.6 10 | 11 | Python: 12 | - MiniConda env with Python 3.8.17 (`pytriton` requires py38) 13 | - torch==2.0.1 14 | - transformers==4.30.2 15 | - nvidia-pytriton==0.2.0 16 | - mosec==0.7.2 17 | 18 | ## Results 19 | 20 | All the results are collected after the **warmup** load test. 
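The exact warmup command is not recorded in this repo; a reasonable sketch (an assumption, not taken from the results below) is simply a shorter run of the same request before collecting numbers, e.g.:

```sh
hey -m POST -n 5000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference
```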
21 | 22 | Dynamic batching: 23 | - max batch size: 32 24 | - max wait time: 10ms 25 | 26 | ### pytriton 27 | 28 | `hey -m POST -n 50000 -H Inference-Header-Content-Length:175 -D bert.bin http://127.0.0.1:8000/v2/models/distilbert/infer` 29 | 30 | Usage: 31 | - GPU Memory: 1117MiB / 15360MiB 32 | - GPU Util: 39% 33 | 34 | ``` 35 | Summary: 36 | Total: 55.0552 secs 37 | Slowest: 0.0911 secs 38 | Fastest: 0.0110 secs 39 | Average: 0.0549 secs 40 | Requests/sec: 908.1787 41 | 42 | Response time histogram: 43 | 0.011 [1] | 44 | 0.019 [28] | 45 | 0.027 [18] | 46 | 0.035 [68] | 47 | 0.043 [207] | 48 | 0.051 [31239] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 49 | 0.059 [3274] |■■■■ 50 | 0.067 [6956] |■■■■■■■■■ 51 | 0.075 [2954] |■■■■ 52 | 0.083 [5155] |■■■■■■■ 53 | 0.091 [100] | 54 | 55 | Latency distribution: 56 | 10% in 0.0463 secs 57 | 25% in 0.0476 secs 58 | 50% in 0.0494 secs 59 | 75% in 0.0626 secs 60 | 90% in 0.0752 secs 61 | 95% in 0.0770 secs 62 | 99% in 0.0792 secs 63 | 64 | Status code distribution: 65 | [200] 50000 responses 66 | ``` 67 | 68 | ### mosec 69 | 70 | ```sh 71 | hey -c 50 -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference 72 | ``` 73 | 74 | Usage: 75 | - GPU Memory: 1043MiB / 15360MiB 76 | - GPU Util: 53% 77 | 78 | ``` 79 | Summary: 80 | Total: 23.2878 secs 81 | Slowest: 0.0778 secs 82 | Fastest: 0.0119 secs 83 | Average: 0.0230 secs 84 | Requests/sec: 2147.0514 85 | 86 | Response time histogram: 87 | 0.012 [1] | 88 | 0.018 [21773] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 89 | 0.025 [363] |■ 90 | 0.032 [26762] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 91 | 0.038 [902] |■ 92 | 0.045 [125] | 93 | 0.051 [56] | 94 | 0.058 [17] | 95 | 0.065 [0] | 96 | 0.071 [0] | 97 | 0.078 [1] | 98 | 99 | Latency distribution: 100 | 10% in 0.0142 secs 101 | 25% in 0.0149 secs 102 | 50% in 0.0278 secs 103 | 75% in 0.0293 secs 104 | 90% in 0.0303 secs 105 | 95% in 0.0310 secs 106 | 99% in 0.0324 secs 107 | 108 | Status code distribution: 109 | [200] 50000 responses 110 | ``` 111 | 112 | The response time is not even due to the request is not enough. 
We can modify the concurrent workers from `50` to `80`: 113 | 114 | ```sh 115 | hey -c 80 -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference 116 | ``` 117 | 118 | Usage: 119 | - GPU Memory: 1043MiB / 15360MiB 120 | - GPU Util: 78% 121 | 122 | ``` 123 | Summary: 124 | Total: 18.0979 secs 125 | Slowest: 0.0720 secs 126 | Fastest: 0.0127 secs 127 | Average: 0.0286 secs 128 | Requests/sec: 2762.7568 129 | 130 | Total data: 400000 bytes 131 | Size/request: 8 bytes 132 | 133 | Response time histogram: 134 | 0.013 [1] | 135 | 0.019 [88] | 136 | 0.025 [24420] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 137 | 0.030 [419] |■ 138 | 0.036 [24618] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 139 | 0.042 [153] | 140 | 0.048 [26] | 141 | 0.054 [222] | 142 | 0.060 [37] | 143 | 0.066 [0] | 144 | 0.072 [16] | 145 | 146 | Latency distribution: 147 | 10% in 0.0224 secs 148 | 25% in 0.0229 secs 149 | 50% in 0.0315 secs 150 | 75% in 0.0340 secs 151 | 90% in 0.0347 secs 152 | 95% in 0.0351 secs 153 | 99% in 0.0363 secs 154 | 155 | Status code distribution: 156 | [200] 50000 responses 157 | ``` 158 | 159 | If we change the mosec inference number to 2, it will be even faster: 160 | 161 | ```sh 162 | hey -c 80 -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference 163 | ``` 164 | 165 | - GPU Memory: 2080MiB / 15360MiB 166 | - GPU Util: 99% 167 | 168 | ``` 169 | Summary: 170 | Total: 16.5202 secs 171 | Slowest: 0.1151 secs 172 | Fastest: 0.0135 secs 173 | Average: 0.0259 secs 174 | Requests/sec: 3026.6061 175 | 176 | Response time histogram: 177 | 0.013 [1] | 178 | 0.024 [24159] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 179 | 0.034 [22667] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 180 | 0.044 [3037] |■■■■■ 181 | 0.054 [59] | 182 | 0.064 [32] | 183 | 0.074 [11] | 184 | 0.085 [20] | 185 | 0.095 [8] | 186 | 0.105 [0] | 187 | 0.115 [6] | 188 | 189 | Latency distribution: 190 | 10% in 0.0201 secs 191 | 25% in 0.0207 secs 192 | 50% in 0.0253 secs 193 | 75% in 0.0305 secs 194 | 90% in 0.0332 secs 195 | 95% in 0.0342 secs 196 | 99% in 0.0382 secs 197 | 198 | Status code distribution: 199 | [200] 50000 responses 200 | ``` 201 | -------------------------------------------------------------------------------- /benchmark/results/distilbert_serving_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorchord/inference-benchmark/330d7497389a70998170a35d50aa28b6f591c1ab/benchmark/results/distilbert_serving_benchmark.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "inference-benchmark" 3 | dynamic = ["version"] 4 | description = "machine learning model serving benchmark" 5 | authors = [{ name = "TensorChord", email = "modelz@tensorchord.ai" }] 6 | keywords = ["machine learning", "deep learning", "model serving"] 7 | classifiers = [ 8 | "Intended Audience :: Developers", 9 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 10 | ] 11 | dependencies = [ 12 | "torch", 13 | "transformers", 14 | "numpy", 15 | "soundfile", 16 | "diffusers", 17 | "mosec", 18 | "nvidia-pytriton", 19 | ] 20 | requires-python = ">=3.8" 21 | readme = "README.md" 22 | license = { text = "Apache-2.0" } 23 | 24 | [project.urls] 25 | "Homepage" = "https://github.com/tensorchord/inference-benchmark" 26 | 27 | 
[project.optional-dependencies] 28 | dev = [ 29 | "black", 30 | "ruff", 31 | ] 32 | 33 | [build-system] 34 | requires = ["setuptools", "setuptools_scm>=7.0"] 35 | build-backend = "setuptools.build_meta" 36 | 37 | [tool.setuptools_scm] 38 | write_to = "src/_version.py" 39 | 40 | [tool.ruff] 41 | target-version = "py38" 42 | line-length = 88 43 | select = ["E", "F", "B", "I", "SIM", "TID", "PL"] 44 | [tool.ruff.pylint] 45 | max-branches = 35 46 | max-statements = 100 47 | 48 | [tool.black] 49 | line-length = 88 50 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def build_parser(): 5 | parser = ArgumentParser() 6 | parser.add_argument("--service", "-s", required=True, choices=("mosec")) 7 | 8 | return parser 9 | 10 | 11 | def run(): 12 | parser = build_parser() 13 | args = parser.parse_args() 14 | print(args.service) 15 | -------------------------------------------------------------------------------- /src/bentoml_services/README.md: -------------------------------------------------------------------------------- 1 | ## Usage 2 | 3 | Inside each directory: 4 | - server: `bentoml serve service.py:svc` 5 | - client: `python client.py` 6 | -------------------------------------------------------------------------------- /src/bentoml_services/stable_diffusion/client.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | from threading import Thread 3 | 4 | import requests 5 | 6 | endpoint = "http://127.0.0.1:3000/batch_generate" 7 | 8 | 9 | def post(): 10 | response = requests.post( 11 | endpoint, 12 | headers={"Content-Type": "application/text"}, 13 | data="a happy boy with his toys", 14 | ) 15 | assert response.status_code == HTTPStatus.OK 16 | 17 | 18 | for _ in range(10): 19 | # Test sequential requests. 20 | post() 21 | 22 | for _ in range(10): 23 | # Test concurrent requests for adaptive batching. 24 | Thread(target=post).start() 25 | -------------------------------------------------------------------------------- /src/bentoml_services/stable_diffusion/service.py: -------------------------------------------------------------------------------- 1 | """BentoML service for stable diffusion. 
2 | 3 | This implementation is based on the official examples: 4 | 5 | - https://github.com/bentoml/BentoML/tree/main/examples/custom_runner 6 | """ 7 | from time import time 8 | 9 | import bentoml 10 | import torch 11 | from bentoml.io import Image, Text 12 | from diffusers import StableDiffusionPipeline # type: ignore 13 | 14 | st_time = time() 15 | 16 | 17 | class StableDiffusionRunnable(bentoml.Runnable): 18 | SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu") 19 | SUPPORTS_CPU_MULTI_THREADING = True 20 | 21 | def __init__(self): 22 | init_st_time = time() 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | self._model = StableDiffusionPipeline.from_pretrained( 25 | "runwayml/stable-diffusion-v1-5", 26 | torch_dtype=torch.float16 if device == "cuda" else torch.float32, 27 | ).to(device) 28 | print(f"model loading time = {time() - init_st_time}") 29 | print(f"total starting time = {time() - st_time}") 30 | 31 | @bentoml.Runnable.method(batchable=True, batch_dim=0) 32 | def inference(self, prompts): 33 | print(prompts) 34 | pil_images = self._model(prompts).images 35 | return pil_images 36 | 37 | 38 | stable_diffusion_runner = bentoml.Runner(StableDiffusionRunnable, max_batch_size=8) 39 | 40 | svc = bentoml.Service("stable_diffusion", runners=[stable_diffusion_runner]) 41 | 42 | 43 | """ 44 | At the time this benchmark was created (27 June 2023), the bentoml batching 45 | seems to have some bugs: 46 | * The async call (`\generate`) failes to batch 47 | * The sync call (`\batch_generate`) allows batch but may lead to the following 48 | error: bentoml.exceptions.ServiceUnavailable: Service Busy 49 | """ 50 | 51 | 52 | @svc.api(input=Text(), output=Image()) 53 | async def generate(input_txt): 54 | batch_ret = await stable_diffusion_runner.inference.async_run([input_txt]) 55 | return batch_ret[0] 56 | 57 | 58 | @svc.api(input=Text(), output=Image()) 59 | def batch_generate(input_txt): 60 | batch_ret = stable_diffusion_runner.inference.run([input_txt]) 61 | return batch_ret[0] 62 | -------------------------------------------------------------------------------- /src/mlserver_services/README.md: -------------------------------------------------------------------------------- 1 | ## Usage 2 | 3 | Inside each directory: 4 | - server: `mlserver start .` 5 | - client: `python client.py` 6 | -------------------------------------------------------------------------------- /src/mlserver_services/stable_diffusion/client.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | from threading import Thread 3 | 4 | import requests 5 | from mlserver.codecs import StringCodec 6 | from mlserver.types import InferenceRequest 7 | 8 | endpoint = "http://127.0.0.1:8080/v2/models/stable-diffusion/infer" 9 | 10 | inference_request_dict = InferenceRequest( 11 | inputs=[ 12 | StringCodec.encode_input( 13 | name="prompts", payload=["a happy boy with his toys"], use_bytes=False 14 | ) 15 | ] 16 | ).dict() 17 | 18 | 19 | def post(): 20 | response = requests.post(endpoint, json=inference_request_dict) 21 | assert response.status_code == HTTPStatus.OK 22 | 23 | 24 | for _ in range(10): 25 | # Test sequential requests. 26 | post() 27 | 28 | for _ in range(10): 29 | # Test concurrent requests for adaptive batching. 
30 | Thread(target=post).start() 31 | -------------------------------------------------------------------------------- /src/mlserver_services/stable_diffusion/model-settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "stable-diffusion", 3 | "implementation": "runtime.StableDiffusion", 4 | "warm_workers": true, 5 | "max_batch_size": 8, 6 | "max_batch_time": 0.01 7 | } -------------------------------------------------------------------------------- /src/mlserver_services/stable_diffusion/runtime.py: -------------------------------------------------------------------------------- 1 | """MLServer service for stable diffusion. 2 | 3 | This implementation is based on the official documentation: 4 | 5 | - https://mlserver.readthedocs.io/en/latest/examples/custom/README.html 6 | """ 7 | 8 | import base64 9 | from io import BytesIO 10 | from typing import List 11 | 12 | import torch # type: ignore 13 | from diffusers import StableDiffusionPipeline # type: ignore 14 | from mlserver import MLModel 15 | from mlserver.codecs import decode_args 16 | 17 | 18 | class StableDiffusion(MLModel): 19 | async def load(self) -> bool: 20 | device = "cuda" if torch.cuda.is_available() else "cpu" 21 | self._model = StableDiffusionPipeline.from_pretrained( 22 | "runwayml/stable-diffusion-v1-5", 23 | torch_dtype=torch.float16 if device == "cuda" else torch.float32, 24 | ).to(device) 25 | return True 26 | 27 | @decode_args 28 | async def predict(self, prompts: List[str]) -> List[str]: 29 | images_b64 = [] 30 | for pil_im in self._model(prompts).images: 31 | buf = BytesIO() 32 | pil_im.save(buf, format="JPEG") 33 | images_b64.append(base64.b64encode(buf.getvalue()).decode("utf-8")) 34 | return images_b64 35 | -------------------------------------------------------------------------------- /src/mlserver_services/stable_diffusion/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "debug": "true" 3 | } -------------------------------------------------------------------------------- /src/mosec_services/bert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 MOSEC Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Example: Mosec with PyTorch Distil BERT.""" 15 | 16 | from typing import Any, List 17 | 18 | import torch # type: ignore 19 | from mosec import Server, Worker, get_logger 20 | from transformers import ( # type: ignore 21 | DistilBertForSequenceClassification, 22 | DistilBertTokenizer, 23 | ) 24 | 25 | logger = get_logger() 26 | 27 | MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" 28 | NUM_INSTANCE = 1 29 | INFERENCE_BATCH_SIZE = 32 30 | 31 | 32 | class Preprocess(Worker): 33 | """Preprocess BERT on current setup.""" 34 | 35 | def __init__(self): 36 | super().__init__() 37 | self.tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME) 38 | 39 | def deserialize(self, data: bytes) -> str: 40 | # Override `deserialize` for the *first* stage; 41 | # `data` is the raw bytes from the request body 42 | return data.decode() 43 | 44 | def forward(self, data: str) -> Any: 45 | tokens = self.tokenizer.encode(data, add_special_tokens=True) 46 | return tokens 47 | 48 | 49 | class Inference(Worker): 50 | """Pytorch Inference class""" 51 | 52 | resp_mime_type = "text/plain" 53 | 54 | def __init__(self): 55 | super().__init__() 56 | self.device = ( 57 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 58 | ) 59 | logger.info("using computing device: %s", self.device) 60 | self.model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME) 61 | self.model.eval() 62 | self.model.to(self.device) 63 | 64 | # Overwrite self.example for warmup 65 | self.example = [ 66 | [101, 2023, 2003, 1037, 8403, 4937, 999, 102] * 5 # make sentence longer 67 | ] * INFERENCE_BATCH_SIZE 68 | 69 | def forward(self, data: List[Any]) -> List[str]: 70 | tensors = [torch.tensor(token) for token in data] 71 | with torch.no_grad(): 72 | logits = self.model( 73 | torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True).to( 74 | self.device 75 | ) 76 | ).logits 77 | label_ids = logits.argmax(dim=1).cpu().tolist() 78 | return [self.model.config.id2label[i] for i in label_ids] 79 | 80 | def serialize(self, data: str) -> bytes: 81 | # Override `serialize` for the *last* stage; 82 | # `data` is the string from the `forward` output 83 | return data.encode() 84 | 85 | 86 | if __name__ == "__main__": 87 | server = Server() 88 | server.append_worker(Preprocess, num=2 * NUM_INSTANCE) 89 | server.append_worker( 90 | Inference, 91 | max_batch_size=INFERENCE_BATCH_SIZE, 92 | max_wait_time=10, 93 | num=NUM_INSTANCE, 94 | ) 95 | server.run() 96 | -------------------------------------------------------------------------------- /src/mosec_services/image_bind.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorchord/inference-benchmark/330d7497389a70998170a35d50aa28b6f591c1ab/src/mosec_services/image_bind.py -------------------------------------------------------------------------------- /src/mosec_services/llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mosec import Server, Worker 3 | from transformers import LlamaForCausalLM, LlamaTokenizer 4 | 5 | MODEL = "decapoda-research/llama-7b-hf" 6 | MAX_LENGTH = 50 7 | 8 | 9 | class TokenEncoder(Worker): 10 | def __init__(self): 11 | self.tokenizer = LlamaTokenizer.from_pretrained(MODEL) 12 | 13 | def deserialize(self, data): 14 | return data.decode() 15 | 16 | def forward(self, data): 17 | tokens = self.tokenizer(data) 18 | print(tokens) 19 | return tokens.input_ids 20 | 21 | 22 | class Inference(Worker): 
23 | def __init__(self): 24 | self.model = LlamaForCausalLM.from_pretrained(MODEL) 25 | self.device = ( 26 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 27 | ) 28 | self.model.to(self.device) 29 | self.model.eval() 30 | 31 | def forward(self, data): 32 | inputs = torch.nn.utils.rnn.pad_sequence( 33 | [torch.tensor(tokens) for tokens in data], batch_first=True 34 | ).to(self.device) 35 | outputs = self.model.generate(inputs, max_length=MAX_LENGTH).tolist() 36 | return outputs 37 | 38 | 39 | class TokenDecoder(Worker): 40 | def __init__(self): 41 | self.tokenizer = LlamaTokenizer.from_pretrained(MODEL) 42 | 43 | def forward(self, data): 44 | outputs = self.tokenizer.decode(data, skip_special_tokens=True) 45 | return outputs 46 | 47 | 48 | if __name__ == "__main__": 49 | server = Server() 50 | server.append_worker(TokenEncoder) 51 | server.append_worker(Inference, max_batch_size=4, timeout=30) 52 | server.append_worker(TokenDecoder) 53 | server.run() 54 | -------------------------------------------------------------------------------- /src/mosec_services/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from typing import List 3 | 4 | import torch # type: ignore 5 | from diffusers import StableDiffusionPipeline # type: ignore 6 | from mosec import Server, Worker, get_logger 7 | from mosec.mixin import MsgpackMixin 8 | 9 | logger = get_logger() 10 | 11 | 12 | class StableDiffusion(MsgpackMixin, Worker): 13 | def __init__(self): 14 | device = "cuda" if torch.cuda.is_available() else "cpu" 15 | self.pipe = StableDiffusionPipeline.from_pretrained( 16 | "runwayml/stable-diffusion-v1-5", 17 | torch_dtype=torch.float16 if device == "cuda" else torch.float32, 18 | ) 19 | self.example = ["useless example prompt"] * 8 20 | self.pipe = self.pipe.to(device) 21 | 22 | def forward(self, data: List[str]) -> List[memoryview]: 23 | res = self.pipe(data) 24 | images = [] 25 | for img in res[0]: 26 | dummy_file = BytesIO() 27 | img.save(dummy_file, format="JPEG") 28 | images.append(dummy_file.getbuffer()) 29 | return images 30 | 31 | 32 | if __name__ == "__main__": 33 | server = Server() 34 | server.append_worker(StableDiffusion, num=1, max_batch_size=8, timeout=30) 35 | server.run() 36 | -------------------------------------------------------------------------------- /src/mosec_services/whisper.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import numpy as np 4 | import soundfile 5 | import torch 6 | from mosec import Server, Worker 7 | from transformers import WhisperForConditionalGeneration, WhisperProcessor 8 | 9 | STEREO_CHANNEL_NUM = 2 10 | 11 | 12 | class Preprocess(Worker): 13 | def __init__(self): 14 | self.processor = WhisperProcessor.from_pretrained("openai/whisper-base") 15 | 16 | def deserialize(self, data: bytes) -> any: 17 | with io.BytesIO(data) as byte_io: 18 | array, sampling_rate = soundfile.read(byte_io) 19 | if array.shape[1] == STEREO_CHANNEL_NUM: 20 | # conbime the channel 21 | array = np.mean(array, 1) 22 | return {"array": array, "sampling_rate": sampling_rate} 23 | 24 | def forward(self, data): 25 | res = self.processor( 26 | data["array"], sampling_rate=data["sampling_rate"], return_tensors="pt" 27 | ) 28 | return res.input_features 29 | 30 | 31 | class Inference(Worker): 32 | def __init__(self): 33 | self.model = WhisperForConditionalGeneration.from_pretrained( 34 | "openai/whisper-base" 35 | ) 36 | 
self.model.config.forced_decoder_ids = None 37 | self.device = ( 38 | torch.cuda.current_device() if torch.cuda.is_available() else "cpu" 39 | ) 40 | self.model.to(self.device) 41 | 42 | def forward(self, data): 43 | ids = self.model.generate(torch.cat(data).to(self.device)) 44 | return ids.cpu().tolist() 45 | 46 | 47 | class Postprocess(Worker): 48 | def __init__(self): 49 | self.processor = WhisperProcessor.from_pretrained("openai/whisper-base") 50 | 51 | def forward(self, data): 52 | return self.processor.batch_decode(data, skip_special_tokens=True) 53 | 54 | def serialize(self, data: str) -> bytes: 55 | return data.encode("utf-8") 56 | 57 | 58 | if __name__ == "__main__": 59 | server = Server() 60 | server.append_worker(Preprocess, num=2) 61 | server.append_worker(Inference, max_batch_size=16, max_wait_time=10) 62 | server.append_worker(Postprocess, num=2, max_batch_size=8, max_wait_time=5) 63 | server.run() 64 | -------------------------------------------------------------------------------- /src/potassium_services/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | """Potassium service for stable diffusion. 2 | 3 | This implementation is based on the official README example: 4 | 5 | - https://github.com/bananaml/potassium#or-do-it-yourself 6 | """ 7 | 8 | import base64 9 | from io import BytesIO 10 | 11 | import torch 12 | from diffusers import StableDiffusionPipeline 13 | from potassium import Potassium, Request, Response 14 | 15 | app = Potassium(__name__) 16 | 17 | 18 | @app.init 19 | def init(): 20 | device = "cuda" if torch.cuda.is_available() else "cpu" 21 | model = StableDiffusionPipeline.from_pretrained( 22 | "runwayml/stable-diffusion-v1-5", 23 | torch_dtype=torch.float16 if device == "cuda" else torch.float32, 24 | ) 25 | model.to(device) 26 | 27 | context = { 28 | "model": model, 29 | } 30 | return context 31 | 32 | 33 | @app.handler() 34 | def handler(context: dict, request: Request) -> Response: 35 | model = context["model"] 36 | prompt = request.json.get("prompt") 37 | image = model(prompt).images[0] 38 | buf = BytesIO() 39 | image.save(buf, format="JPEG") 40 | return Response( 41 | json={"image_b64": base64.b64encode(buf.getvalue()).decode("utf-8")}, 42 | status=200, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | app.serve() 48 | -------------------------------------------------------------------------------- /src/pytriton_services/bert/client.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytriton.client import ModelClient 3 | 4 | MODEL_NAME = "distilbert" 5 | MODEL_URL = "127.0.0.1" 6 | TIMEOUT_SECOND = 600 7 | 8 | 9 | def query(url: str = MODEL_URL, name: str = MODEL_NAME): 10 | with ModelClient(url, name, init_timeout_s=TIMEOUT_SECOND) as client: 11 | message = "The quick brown fox jumps over the lazy dog" 12 | request = np.char.encode(np.asarray([[message]]), "utf-8") 13 | resp = client.infer_batch(messages=request) 14 | print(resp) 15 | 16 | 17 | if __name__ == "__main__": 18 | query() 19 | -------------------------------------------------------------------------------- /src/pytriton_services/bert/service.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytriton.decorators import batch 4 | from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor 5 | from pytriton.triton import Triton, TritonConfig 6 | from transformers import ( # type: 
ignore 7 | DistilBertForSequenceClassification, 8 | DistilBertTokenizer, 9 | ) 10 | 11 | MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" 12 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 13 | NUM_INSTANCE = 1 14 | tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME) 15 | model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME) 16 | model.to(DEVICE) 17 | model.eval() 18 | 19 | 20 | @batch 21 | def infer(messages: np.ndarray): 22 | msg = [ 23 | np.char.decode(message.astype("bytes"), "utf-8").item() for message in messages 24 | ] 25 | inputs = tokenizer(msg, return_tensors="pt", padding=True) 26 | inputs.to(DEVICE) 27 | logits = model(**inputs).logits 28 | label_ids = logits.argmax(dim=1).cpu().tolist() 29 | return {"labels": np.asarray([model.config.id2label[i] for i in label_ids])} 30 | 31 | 32 | def main(): 33 | config = TritonConfig(exit_on_error=True) 34 | with Triton(config=config) as triton: 35 | triton.bind( 36 | model_name="distilbert", 37 | infer_func=[infer] * NUM_INSTANCE, 38 | inputs=[Tensor(name="messages", dtype=np.bytes_, shape=(1,))], 39 | outputs=[Tensor(name="labels", dtype=np.bytes_, shape=(1,))], 40 | config=ModelConfig( 41 | max_batch_size=32, 42 | batcher=DynamicBatcher(max_queue_delay_microseconds=10), 43 | ), 44 | ) 45 | triton.serve() 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /src/pytriton_services/stable_diffusion/client.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytriton.client import ModelClient 3 | 4 | MODEL_NAME = "stable_diffusion" 5 | MODEL_URL = "127.0.0.1" 6 | TIMEOUT_SECOND = 600 7 | 8 | 9 | def query(url: str = MODEL_URL, name: str = MODEL_NAME): 10 | with ModelClient(url, name, init_timeout_s=TIMEOUT_SECOND) as client: 11 | prompt = "A photo of a cat" 12 | prompt_req = np.char.encode(np.asarray([[prompt]]), "utf-8") 13 | client.infer_batch(prompt=prompt_req) 14 | 15 | 16 | if __name__ == "__main__": 17 | query() 18 | -------------------------------------------------------------------------------- /src/pytriton_services/stable_diffusion/service.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Server for Stable Diffusion 1.5.""" 16 | 17 | import base64 18 | import io 19 | import logging 20 | 21 | import numpy as np 22 | import torch # pytype: disable=import-error 23 | from diffusers import StableDiffusionPipeline # pytype: disable=import-error 24 | from pytriton.decorators import batch 25 | from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor 26 | from pytriton.triton import Triton, TritonConfig 27 | 28 | LOGGER = logging.getLogger(__file__) 29 | 30 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 31 | IMAGE_FORMAT = "JPEG" 32 | 33 | pipe = pipe = StableDiffusionPipeline.from_pretrained( 34 | "runwayml/stable-diffusion-v1-5", 35 | torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, 36 | ) 37 | pipe = pipe.to(DEVICE) 38 | 39 | 40 | def _encode_image_to_base64(image): 41 | raw_bytes = io.BytesIO() 42 | image.save(raw_bytes, IMAGE_FORMAT) 43 | raw_bytes.seek(0) # return to the start of the buffer 44 | return base64.b64encode(raw_bytes.read()) 45 | 46 | 47 | @batch 48 | def _infer_fn(prompt: np.ndarray): 49 | prompts = [np.char.decode(p.astype("bytes"), "utf-8").item() for p in prompt] 50 | LOGGER.debug(f"Prompts: {prompts}") 51 | 52 | outputs = [] 53 | for idx, image in enumerate(pipe(prompt=prompts).images): 54 | raw_data = _encode_image_to_base64(image) 55 | outputs.append(raw_data) 56 | LOGGER.debug( 57 | f"Generated result for prompt `{prompts[idx]}` with size {len(raw_data)}" 58 | ) 59 | 60 | LOGGER.debug(f"Prepared batch response of size: {len(outputs)}") 61 | return {"image": np.array(outputs)} 62 | 63 | 64 | def main(): 65 | """Initialize server with model.""" 66 | log_level = logging.DEBUG 67 | logging.basicConfig( 68 | level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s" 69 | ) 70 | log_verbose = 1 71 | config = TritonConfig(exit_on_error=True, log_verbose=log_verbose) 72 | 73 | with Triton(config=config) as triton: 74 | LOGGER.info("Loading the pipeline") 75 | triton.bind( 76 | model_name="stable_diffusion", 77 | infer_func=_infer_fn, 78 | inputs=[ 79 | Tensor(name="prompt", dtype=np.bytes_, shape=(1,)), 80 | ], 81 | outputs=[ 82 | Tensor(name="image", dtype=np.bytes_, shape=(1,)), 83 | ], 84 | config=ModelConfig( 85 | max_batch_size=4, 86 | batcher=DynamicBatcher( 87 | max_queue_delay_microseconds=100, 88 | ), 89 | ), 90 | ) 91 | triton.serve() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /src/triton_services/client.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tritonclient.http as httpclient 3 | 4 | with httpclient.InferenceServerClient("localhost:8000") as client: 5 | prompt = httpclient.InferInput("PROMPT", shape=(1,), datatype="BYTES") 6 | prompt.set_data_from_numpy(np.asarray(["cat"], dtype=object)) 7 | images = httpclient.InferRequestedOutput("IMAGE", binary_data=False) 8 | response = client.infer( 9 | model_name="stable_diffusion", 10 | inputs=[prompt], 11 | outputs=[images], 12 | ) 13 | content = response.as_numpy("IMAGE") 14 | -------------------------------------------------------------------------------- /src/triton_services/stable_diffusion/1/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import triton_python_backend_utils as triton_utils 4 | from diffusers import StableDiffusionPipeline 5 | 6 | 7 | class TritonPythonModel: 8 | def 
initialize(self, args): 9 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 10 | self.model = StableDiffusionPipeline.from_pretrained( 11 | "runwayml/stable-diffusion-v1-5", 12 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, 13 | ) 14 | self.model.to(self.device) 15 | 16 | def execute(self, requests): 17 | responses = [] 18 | prompts = [] 19 | for request in requests: 20 | prompt = ( 21 | triton_utils.get_input_tensor_by_name(request, "PROMPT") 22 | .as_numpy() 23 | .tolist() 24 | ) 25 | prompts.append(prompt[0].decode()) 26 | 27 | images = self.model(prompts).images 28 | for image in images: 29 | output = triton_utils.Tensor("IMAGE", np.asarray(image)) 30 | resp = triton_utils.InferenceResponse(output_tensors=[output]) 31 | responses.append(resp) 32 | 33 | return responses 34 | -------------------------------------------------------------------------------- /src/triton_services/stable_diffusion/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "stable_diffusion" 2 | backend: "python" 3 | # max_batch_size: 8 4 | # dynamic_batching { 5 | # max_queue_delay_microseconds: 10 6 | # } 7 | 8 | input [ 9 | { 10 | name: "PROMPT" 11 | data_type: TYPE_STRING 12 | dims: [ 1 ] 13 | } 14 | ] 15 | 16 | output [ 17 | { 18 | name: "IMAGE" 19 | data_type: TYPE_UINT8 20 | dims: [ 512, 512, 3 ] 21 | } 22 | ] 23 | 24 | instance_group [ 25 | { 26 | count: 1 27 | } 28 | ] --------------------------------------------------------------------------------
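A note on consuming the `IMAGE` output above: `config.pbtxt` declares it as a raw `TYPE_UINT8` tensor with dims `[512, 512, 3]`, so the array fetched by `src/triton_services/client.py` can be turned back into a picture. A minimal sketch of doing so (the Pillow dependency and the output filename are assumptions, not part of this repo):

```python
import numpy as np
import tritonclient.http as httpclient
from PIL import Image  # assumption: Pillow is available in the client environment

with httpclient.InferenceServerClient("localhost:8000") as client:
    # Same request as src/triton_services/client.py.
    prompt = httpclient.InferInput("PROMPT", shape=(1,), datatype="BYTES")
    prompt.set_data_from_numpy(np.asarray(["cat"], dtype=object))
    images = httpclient.InferRequestedOutput("IMAGE", binary_data=False)
    response = client.infer(
        model_name="stable_diffusion",
        inputs=[prompt],
        outputs=[images],
    )
    # The model returns the generated picture as a (512, 512, 3) uint8 array,
    # matching the dims declared in config.pbtxt.
    content = response.as_numpy("IMAGE")
    Image.fromarray(content.astype(np.uint8)).save("stable_diffusion_output.jpg")
```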