├── .github
│   └── workflows
│       └── python-check.yml
├── .gitignore
├── Dockerfile
├── Makefile
├── README.md
├── benchmark
│   ├── README.md
│   ├── bert.bin
│   └── results
│       ├── bert.md
│       └── distilbert_serving_benchmark.png
├── pyproject.toml
└── src
    ├── args.py
    ├── bentoml_services
    │   ├── README.md
    │   └── stable_diffusion
    │       ├── client.py
    │       └── service.py
    ├── mlserver_services
    │   ├── README.md
    │   └── stable_diffusion
    │       ├── client.py
    │       ├── model-settings.json
    │       ├── runtime.py
    │       └── settings.json
    ├── mosec_services
    │   ├── bert.py
    │   ├── image_bind.py
    │   ├── llama.py
    │   ├── stable_diffusion.py
    │   └── whisper.py
    ├── potassium_services
    │   └── stable_diffusion.py
    ├── pytriton_services
    │   ├── bert
    │   │   ├── client.py
    │   │   └── service.py
    │   └── stable_diffusion
    │       ├── client.py
    │       └── service.py
    └── triton_services
        ├── client.py
        └── stable_diffusion
            ├── 1
            │   └── model.py
            └── config.pbtxt
/.github/workflows/python-check.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python Lint
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 | workflow_dispatch:
12 | merge_group:
13 |
14 | concurrency:
15 | group: ${{ github.ref }}-${{ github.workflow }}
16 | cancel-in-progress: true
17 |
18 | permissions:
19 | contents: read
20 |
21 | jobs:
22 | lint:
23 |
24 | runs-on: ubuntu-latest
25 |
26 | steps:
27 | - uses: actions/checkout@v3
28 | - name: Set up Python 3.11
29 | uses: actions/setup-python@v3
30 | with:
31 | python-version: "3.11"
32 | - name: Install dependencies
33 | run: |
34 | pip install ruff black
35 | - name: Lint
36 | run: |
37 | make lint
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # Ruff
163 | .ruff_cache
164 |
165 | # semver
166 | src/_version.py
167 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG base=nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
2 |
3 | FROM ${base}
4 |
5 | ENV DEBIAN_FRONTEND=noninteractive LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8
6 | ENV PATH /opt/conda/bin:$PATH
7 |
8 | ARG MOSEC_PORT=8000
9 | ENV MOSEC_PORT=${MOSEC_PORT}
10 |
11 | ARG CONDA_VERSION=py310_23.3.1-0
12 |
13 | RUN apt update && \
14 | apt install -y --no-install-recommends \
15 | wget \
16 | git \
17 | build-essential \
18 | ca-certificates && \
19 | rm -rf /var/lib/apt/lists/*
20 |
21 | RUN set -x && \
22 | UNAME_M="$(uname -m)" && \
23 | if [ "${UNAME_M}" = "x86_64" ]; then \
24 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-x86_64.sh"; \
25 | SHA256SUM="aef279d6baea7f67940f16aad17ebe5f6aac97487c7c03466ff01f4819e5a651"; \
26 | elif [ "${UNAME_M}" = "s390x" ]; then \
27 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-s390x.sh"; \
28 | SHA256SUM="ed4f51afc967e921ff5721151f567a4c43c4288ac93ec2393c6238b8c4891de8"; \
29 | elif [ "${UNAME_M}" = "aarch64" ]; then \
30 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-aarch64.sh"; \
31 | SHA256SUM="6950c7b1f4f65ce9b87ee1a2d684837771ae7b2e6044e0da9e915d1dee6c924c"; \
32 | elif [ "${UNAME_M}" = "ppc64le" ]; then \
33 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION}-Linux-ppc64le.sh"; \
34 | SHA256SUM="b3de538cd542bc4f5a2f2d2a79386288d6e04f0e1459755f3cefe64763e51d16"; \
35 | fi && \
36 | wget "${MINICONDA_URL}" -O miniconda.sh -q && \
37 | echo "${SHA256SUM} miniconda.sh" > shasum && \
38 | if [ "${CONDA_VERSION}" != "latest" ]; then sha256sum --check --status shasum; fi && \
39 | mkdir -p /opt && \
40 | bash miniconda.sh -b -p /opt/conda && \
41 | rm miniconda.sh shasum && \
42 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
43 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
44 | echo "conda activate base" >> ~/.bashrc && \
45 | find /opt/conda/ -follow -type f -name '*.a' -delete && \
46 | find /opt/conda/ -follow -type f -name '*.js.map' -delete && \
47 | /opt/conda/bin/conda clean -afy
48 |
49 | RUN conda create -n envd python=3.10
50 |
51 | ENV ENVD_PREFIX=/opt/conda/envs/envd/bin
52 |
53 | RUN update-alternatives --install /usr/bin/python python ${ENVD_PREFIX}/python 1 && \
54 | update-alternatives --install /usr/bin/python3 python3 ${ENVD_PREFIX}/python3 1 && \
55 | update-alternatives --install /usr/bin/pip pip ${ENVD_PREFIX}/pip 1 && \
56 | update-alternatives --install /usr/bin/pip3 pip3 ${ENVD_PREFIX}/pip3 1
57 |
58 | RUN pip install torch transformers
59 |
60 | RUN mkdir -p /workspace
61 |
62 | COPY . /workspace/
63 | WORKDIR /workspace
64 | RUN pip install -e .
65 |
66 | ENTRYPOINT [ "bash" ]
67 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PY_SOURCE=src benchmark
2 |
3 | lint:
4 | @black --check --diff ${PY_SOURCE}
5 | @ruff check ${PY_SOURCE}
6 |
7 | format:
8 | @black ${PY_SOURCE}
9 | @ruff check --fix ${PY_SOURCE}
10 |
11 | clean:
12 | @-rm -rf dist build __pycache__ src/*.egg-info src/_version.py
13 |
14 | build:
15 | @python -m build
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Inference Benchmark
4 |
5 | Maximize the potential of your models with the inference benchmark tool.
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | # What is it
15 |
16 | Inference benchmark provides a standard way to measure the performance of inference workloads, and serves as a tool to evaluate and optimize your model serving setup.
17 |
18 | # Results
19 |
20 | ## Bert
21 |
22 | We benchmarked [pytriton (triton-inference-server)](https://github.com/triton-inference-server/pytriton) and [mosec](https://github.com/mosecorg/mosec) with a DistilBERT model. Dynamic batching was enabled for both frameworks with a max batch size of 32 and a max wait time of 10ms. Please check out the [result](./benchmark/results/bert.md) for more details.
23 |
24 | ![DistilBERT serving benchmark](./benchmark/results/distilbert_serving_benchmark.png)
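
The two services expose different HTTP APIs. Below is a minimal query sketch (an illustration only, assuming both servers run locally on their default port 8000 as in [benchmark/README.md](./benchmark/README.md)):

```python
# Sketch: query the two benchmarked services with plain `requests`.
import requests

TEXT = "The quick brown fox jumps over the lazy dog"

# mosec: raw text body posted to /inference, plain-text label in the response.
resp = requests.post("http://127.0.0.1:8000/inference", data=TEXT)
print(resp.text)  # e.g. "POSITIVE"

# pytriton: KServe v2 JSON request to /v2/models/distilbert/infer.
payload = {
    "inputs": [
        {"name": "messages", "shape": [1, 1], "datatype": "BYTES", "data": [TEXT]}
    ]
}
resp = requests.post("http://127.0.0.1:8000/v2/models/distilbert/infer", json=payload)
print(resp.json()["outputs"][0]["data"])  # e.g. ["POSITIVE"]
```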
25 |
26 | More [results with different models on different serving frameworks](https://github.com/tensorchord/inference-benchmark/issues/7) are coming soon.
27 |
--------------------------------------------------------------------------------
/benchmark/README.md:
--------------------------------------------------------------------------------
1 | ## HTTP load test tool
2 |
3 | - [hey](https://github.com/rakyll/hey)
4 |
5 | ## Script
6 |
7 | ### Bert
8 |
9 | - mosec
10 | - `hey -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference`
11 | - pytriton
12 | - `hey -m POST -n 50000 -H Inference-Header-Content-Length:175 -D bert.bin http://127.0.0.1:8000/v2/models/distilbert/infer`
13 |
14 | This binary file (`bert.bin`) is generated by the `pytriton` client:
15 |
16 | ```python
17 | b'{"id":"0","inputs":[{"name":"messages","shape":[1,1],"datatype":"BYTES","parameters":{"binary_data_size":47}}],"outputs":[{"name":"labels","parameters":{"binary_data":true}}]}+\x00\x00\x00The quick brown fox jumps over the lazy dog'
18 | ```
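
For reference, such a payload follows Triton's binary tensor extension: a JSON header, immediately followed by each `BYTES` element as a 4-byte little-endian length prefix plus the raw bytes. A minimal sketch of producing an equivalent `bert.bin` by hand (the printed header length is the value to pass as `Inference-Header-Content-Length`, 175 for this prompt):

```python
# Sketch: build a Triton binary-tensor request body equivalent to bert.bin.
import json
import struct

message = b"The quick brown fox jumps over the lazy dog"
header = {
    "id": "0",
    "inputs": [
        {
            "name": "messages",
            "shape": [1, 1],
            "datatype": "BYTES",
            # 4-byte length prefix + raw bytes for the single BYTES element
            "parameters": {"binary_data_size": 4 + len(message)},
        }
    ],
    "outputs": [{"name": "labels", "parameters": {"binary_data": True}}],
}
header_bytes = json.dumps(header, separators=(",", ":")).encode()

with open("bert.bin", "wb") as f:
    f.write(header_bytes)
    f.write(struct.pack("<I", len(message)))
    f.write(message)

print(len(header_bytes))  # value for Inference-Header-Content-Length
```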
19 |
20 | ![DistilBERT serving benchmark](./results/distilbert_serving_benchmark.png)
21 |
22 | Check the [result](./results/bert.md) for more details.
23 |
--------------------------------------------------------------------------------
/benchmark/bert.bin:
--------------------------------------------------------------------------------
1 | {"id":"0","inputs":[{"name":"messages","shape":[1,1],"datatype":"BYTES","parameters":{"binary_data_size":47}}],"outputs":[{"name":"labels","parameters":{"binary_data":true}}]}+ The quick brown fox jumps over the lazy dog
--------------------------------------------------------------------------------
/benchmark/results/bert.md:
--------------------------------------------------------------------------------
1 | ## Environment
2 |
3 | GCP VM with:
4 | - Machine type: n1-standard-8
5 | - CPU: 8 vCPU Intel Haswell
6 | - RAM: 32G
7 | - GPU: 1 x NVIDIA T4
8 | - GPU Driver Version: 510.47.03
9 | - CUDA Version: 11.6
10 |
11 | Python:
12 | - MiniConda env with Python 3.8.17 (`pytriton` requires py38)
13 | - torch==2.0.1
14 | - transformers==4.30.2
15 | - nvidia-pytriton==0.2.0
16 | - mosec==0.7.2
17 |
18 | ## Results
19 |
20 | All the results are collected after the **warmup** load test.
21 |
22 | Dynamic batching:
23 | - max batch size: 32
24 | - max wait time: 10ms
25 |
26 | ### pytriton
27 |
28 | `hey -m POST -n 50000 -H Inference-Header-Content-Length:175 -D bert.bin http://127.0.0.1:8000/v2/models/distilbert/infer`
29 |
30 | Usage:
31 | - GPU Memory: 1117MiB / 15360MiB
32 | - GPU Util: 39%
33 |
34 | ```
35 | Summary:
36 | Total: 55.0552 secs
37 | Slowest: 0.0911 secs
38 | Fastest: 0.0110 secs
39 | Average: 0.0549 secs
40 | Requests/sec: 908.1787
41 |
42 | Response time histogram:
43 | 0.011 [1] |
44 | 0.019 [28] |
45 | 0.027 [18] |
46 | 0.035 [68] |
47 | 0.043 [207] |
48 | 0.051 [31239] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
49 | 0.059 [3274] |■■■■
50 | 0.067 [6956] |■■■■■■■■■
51 | 0.075 [2954] |■■■■
52 | 0.083 [5155] |■■■■■■■
53 | 0.091 [100] |
54 |
55 | Latency distribution:
56 | 10% in 0.0463 secs
57 | 25% in 0.0476 secs
58 | 50% in 0.0494 secs
59 | 75% in 0.0626 secs
60 | 90% in 0.0752 secs
61 | 95% in 0.0770 secs
62 | 99% in 0.0792 secs
63 |
64 | Status code distribution:
65 | [200] 50000 responses
66 | ```
67 |
68 | ### mosec
69 |
70 | ```sh
71 | hey -c 50 -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference
72 | ```
73 |
74 | Usage:
75 | - GPU Memory: 1043MiB / 15360MiB
76 | - GPU Util: 53%
77 |
78 | ```
79 | Summary:
80 | Total: 23.2878 secs
81 | Slowest: 0.0778 secs
82 | Fastest: 0.0119 secs
83 | Average: 0.0230 secs
84 | Requests/sec: 2147.0514
85 |
86 | Response time histogram:
87 | 0.012 [1] |
88 | 0.018 [21773] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
89 | 0.025 [363] |■
90 | 0.032 [26762] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
91 | 0.038 [902] |■
92 | 0.045 [125] |
93 | 0.051 [56] |
94 | 0.058 [17] |
95 | 0.065 [0] |
96 | 0.071 [0] |
97 | 0.078 [1] |
98 |
99 | Latency distribution:
100 | 10% in 0.0142 secs
101 | 25% in 0.0149 secs
102 | 50% in 0.0278 secs
103 | 75% in 0.0293 secs
104 | 90% in 0.0303 secs
105 | 95% in 0.0310 secs
106 | 99% in 0.0324 secs
107 |
108 | Status code distribution:
109 | [200] 50000 responses
110 | ```
111 |
112 | The response time distribution is uneven because the load is not high enough to saturate the service. We can increase the number of concurrent workers from `50` to `80`:
113 |
114 | ```sh
115 | hey -c 80 -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference
116 | ```
117 |
118 | Usage:
119 | - GPU Memory: 1043MiB / 15360MiB
120 | - GPU Util: 78%
121 |
122 | ```
123 | Summary:
124 | Total: 18.0979 secs
125 | Slowest: 0.0720 secs
126 | Fastest: 0.0127 secs
127 | Average: 0.0286 secs
128 | Requests/sec: 2762.7568
129 |
130 | Total data: 400000 bytes
131 | Size/request: 8 bytes
132 |
133 | Response time histogram:
134 | 0.013 [1] |
135 | 0.019 [88] |
136 | 0.025 [24420] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
137 | 0.030 [419] |■
138 | 0.036 [24618] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
139 | 0.042 [153] |
140 | 0.048 [26] |
141 | 0.054 [222] |
142 | 0.060 [37] |
143 | 0.066 [0] |
144 | 0.072 [16] |
145 |
146 | Latency distribution:
147 | 10% in 0.0224 secs
148 | 25% in 0.0229 secs
149 | 50% in 0.0315 secs
150 | 75% in 0.0340 secs
151 | 90% in 0.0347 secs
152 | 95% in 0.0351 secs
153 | 99% in 0.0363 secs
154 |
155 | Status code distribution:
156 | [200] 50000 responses
157 | ```
158 |
159 | If we increase the number of mosec inference workers to 2 (`NUM_INSTANCE = 2` in `src/mosec_services/bert.py`), it is even faster:
160 |
161 | ```sh
162 | hey -c 80 -m POST -n 50000 -d "The quick brown fox jumps over the lazy dog" http://127.0.0.1:8000/inference
163 | ```
164 |
165 | - GPU Memory: 2080MiB / 15360MiB
166 | - GPU Util: 99%
167 |
168 | ```
169 | Summary:
170 | Total: 16.5202 secs
171 | Slowest: 0.1151 secs
172 | Fastest: 0.0135 secs
173 | Average: 0.0259 secs
174 | Requests/sec: 3026.6061
175 |
176 | Response time histogram:
177 | 0.013 [1] |
178 | 0.024 [24159] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
179 | 0.034 [22667] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
180 | 0.044 [3037] |■■■■■
181 | 0.054 [59] |
182 | 0.064 [32] |
183 | 0.074 [11] |
184 | 0.085 [20] |
185 | 0.095 [8] |
186 | 0.105 [0] |
187 | 0.115 [6] |
188 |
189 | Latency distribution:
190 | 10% in 0.0201 secs
191 | 25% in 0.0207 secs
192 | 50% in 0.0253 secs
193 | 75% in 0.0305 secs
194 | 90% in 0.0332 secs
195 | 95% in 0.0342 secs
196 | 99% in 0.0382 secs
197 |
198 | Status code distribution:
199 | [200] 50000 responses
200 | ```
201 |
--------------------------------------------------------------------------------
/benchmark/results/distilbert_serving_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorchord/inference-benchmark/330d7497389a70998170a35d50aa28b6f591c1ab/benchmark/results/distilbert_serving_benchmark.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "inference-benchmark"
3 | dynamic = ["version"]
4 | description = "machine learning model serving benchmark"
5 | authors = [{ name = "TensorChord", email = "modelz@tensorchord.ai" }]
6 | keywords = ["machine learning", "deep learning", "model serving"]
7 | classifiers = [
8 | "Intended Audience :: Developers",
9 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
10 | ]
11 | dependencies = [
12 | "torch",
13 | "transformers",
14 | "numpy",
15 | "soundfile",
16 | "diffusers",
17 | "mosec",
18 | "nvidia-pytriton",
19 | ]
20 | requires-python = ">=3.8"
21 | readme = "README.md"
22 | license = { text = "Apache-2.0" }
23 |
24 | [project.urls]
25 | "Homepage" = "https://github.com/tensorchord/inference-benchmark"
26 |
27 | [project.optional-dependencies]
28 | dev = [
29 | "black",
30 | "ruff",
31 | ]
32 |
33 | [build-system]
34 | requires = ["setuptools", "setuptools_scm>=7.0"]
35 | build-backend = "setuptools.build_meta"
36 |
37 | [tool.setuptools_scm]
38 | write_to = "src/_version.py"
39 |
40 | [tool.ruff]
41 | target-version = "py38"
42 | line-length = 88
43 | select = ["E", "F", "B", "I", "SIM", "TID", "PL"]
44 | [tool.ruff.pylint]
45 | max-branches = 35
46 | max-statements = 100
47 |
48 | [tool.black]
49 | line-length = 88
50 |
--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 |
4 | def build_parser():
5 | parser = ArgumentParser()
6 |     parser.add_argument("--service", "-s", required=True, choices=("mosec",))
7 |
8 | return parser
9 |
10 |
11 | def run():
12 | parser = build_parser()
13 | args = parser.parse_args()
14 | print(args.service)
15 |
--------------------------------------------------------------------------------
/src/bentoml_services/README.md:
--------------------------------------------------------------------------------
1 | ## Usage
2 |
3 | Inside each directory:
4 | - server: `bentoml serve service.py:svc`
5 | - client: `python client.py`
6 |
--------------------------------------------------------------------------------
/src/bentoml_services/stable_diffusion/client.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 | from threading import Thread
3 |
4 | import requests
5 |
6 | endpoint = "http://127.0.0.1:3000/batch_generate"
7 |
8 |
9 | def post():
10 | response = requests.post(
11 | endpoint,
12 | headers={"Content-Type": "application/text"},
13 | data="a happy boy with his toys",
14 | )
15 | assert response.status_code == HTTPStatus.OK
16 |
17 |
18 | for _ in range(10):
19 | # Test sequential requests.
20 | post()
21 |
22 | for _ in range(10):
23 | # Test concurrent requests for adaptive batching.
24 | Thread(target=post).start()
25 |
--------------------------------------------------------------------------------
/src/bentoml_services/stable_diffusion/service.py:
--------------------------------------------------------------------------------
1 | """BentoML service for stable diffusion.
2 |
3 | This implementation is based on the official examples:
4 |
5 | - https://github.com/bentoml/BentoML/tree/main/examples/custom_runner
6 | """
7 | from time import time
8 |
9 | import bentoml
10 | import torch
11 | from bentoml.io import Image, Text
12 | from diffusers import StableDiffusionPipeline # type: ignore
13 |
14 | st_time = time()
15 |
16 |
17 | class StableDiffusionRunnable(bentoml.Runnable):
18 | SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
19 | SUPPORTS_CPU_MULTI_THREADING = True
20 |
21 | def __init__(self):
22 | init_st_time = time()
23 | device = "cuda" if torch.cuda.is_available() else "cpu"
24 | self._model = StableDiffusionPipeline.from_pretrained(
25 | "runwayml/stable-diffusion-v1-5",
26 | torch_dtype=torch.float16 if device == "cuda" else torch.float32,
27 | ).to(device)
28 | print(f"model loading time = {time() - init_st_time}")
29 | print(f"total starting time = {time() - st_time}")
30 |
31 | @bentoml.Runnable.method(batchable=True, batch_dim=0)
32 | def inference(self, prompts):
33 | print(prompts)
34 | pil_images = self._model(prompts).images
35 | return pil_images
36 |
37 |
38 | stable_diffusion_runner = bentoml.Runner(StableDiffusionRunnable, max_batch_size=8)
39 |
40 | svc = bentoml.Service("stable_diffusion", runners=[stable_diffusion_runner])
41 |
42 |
43 | """
44 | At the time this benchmark was created (27 June 2023), BentoML's batching
45 | seemed to have some bugs:
46 | * The async call (`/generate`) fails to batch
47 | * The sync call (`/batch_generate`) allows batching but may lead to the following
48 |   error: bentoml.exceptions.ServiceUnavailable: Service Busy
49 | """
50 |
51 |
52 | @svc.api(input=Text(), output=Image())
53 | async def generate(input_txt):
54 | batch_ret = await stable_diffusion_runner.inference.async_run([input_txt])
55 | return batch_ret[0]
56 |
57 |
58 | @svc.api(input=Text(), output=Image())
59 | def batch_generate(input_txt):
60 | batch_ret = stable_diffusion_runner.inference.run([input_txt])
61 | return batch_ret[0]
62 |
--------------------------------------------------------------------------------
/src/mlserver_services/README.md:
--------------------------------------------------------------------------------
1 | ## Usage
2 |
3 | Inside each directory:
4 | - server: `mlserver start .`
5 | - client: `python client.py`
6 |
--------------------------------------------------------------------------------
/src/mlserver_services/stable_diffusion/client.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 | from threading import Thread
3 |
4 | import requests
5 | from mlserver.codecs import StringCodec
6 | from mlserver.types import InferenceRequest
7 |
8 | endpoint = "http://127.0.0.1:8080/v2/models/stable-diffusion/infer"
9 |
10 | inference_request_dict = InferenceRequest(
11 | inputs=[
12 | StringCodec.encode_input(
13 | name="prompts", payload=["a happy boy with his toys"], use_bytes=False
14 | )
15 | ]
16 | ).dict()
17 |
18 |
19 | def post():
20 | response = requests.post(endpoint, json=inference_request_dict)
21 | assert response.status_code == HTTPStatus.OK
22 |
23 |
24 | for _ in range(10):
25 | # Test sequential requests.
26 | post()
27 |
28 | for _ in range(10):
29 | # Test concurrent requests for adaptive batching.
30 | Thread(target=post).start()
31 |
--------------------------------------------------------------------------------
/src/mlserver_services/stable_diffusion/model-settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "stable-diffusion",
3 | "implementation": "runtime.StableDiffusion",
4 | "warm_workers": true,
5 | "max_batch_size": 8,
6 | "max_batch_time": 0.01
7 | }
--------------------------------------------------------------------------------
/src/mlserver_services/stable_diffusion/runtime.py:
--------------------------------------------------------------------------------
1 | """MLServer service for stable diffusion.
2 |
3 | This implementation is based on the official documentation:
4 |
5 | - https://mlserver.readthedocs.io/en/latest/examples/custom/README.html
6 | """
7 |
8 | import base64
9 | from io import BytesIO
10 | from typing import List
11 |
12 | import torch # type: ignore
13 | from diffusers import StableDiffusionPipeline # type: ignore
14 | from mlserver import MLModel
15 | from mlserver.codecs import decode_args
16 |
17 |
18 | class StableDiffusion(MLModel):
19 | async def load(self) -> bool:
20 | device = "cuda" if torch.cuda.is_available() else "cpu"
21 | self._model = StableDiffusionPipeline.from_pretrained(
22 | "runwayml/stable-diffusion-v1-5",
23 | torch_dtype=torch.float16 if device == "cuda" else torch.float32,
24 | ).to(device)
25 | return True
26 |
27 | @decode_args
28 | async def predict(self, prompts: List[str]) -> List[str]:
29 | images_b64 = []
30 | for pil_im in self._model(prompts).images:
31 | buf = BytesIO()
32 | pil_im.save(buf, format="JPEG")
33 | images_b64.append(base64.b64encode(buf.getvalue()).decode("utf-8"))
34 | return images_b64
35 |
--------------------------------------------------------------------------------
/src/mlserver_services/stable_diffusion/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "debug": "true"
3 | }
--------------------------------------------------------------------------------
/src/mosec_services/bert.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 MOSEC Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Example: Mosec with PyTorch Distil BERT."""
15 |
16 | from typing import Any, List
17 |
18 | import torch # type: ignore
19 | from mosec import Server, Worker, get_logger
20 | from transformers import ( # type: ignore
21 | DistilBertForSequenceClassification,
22 | DistilBertTokenizer,
23 | )
24 |
25 | logger = get_logger()
26 |
27 | MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
28 | NUM_INSTANCE = 1
29 | INFERENCE_BATCH_SIZE = 32
30 |
31 |
32 | class Preprocess(Worker):
33 | """Preprocess BERT on current setup."""
34 |
35 | def __init__(self):
36 | super().__init__()
37 | self.tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
38 |
39 | def deserialize(self, data: bytes) -> str:
40 | # Override `deserialize` for the *first* stage;
41 | # `data` is the raw bytes from the request body
42 | return data.decode()
43 |
44 | def forward(self, data: str) -> Any:
45 | tokens = self.tokenizer.encode(data, add_special_tokens=True)
46 | return tokens
47 |
48 |
49 | class Inference(Worker):
50 | """Pytorch Inference class"""
51 |
52 | resp_mime_type = "text/plain"
53 |
54 | def __init__(self):
55 | super().__init__()
56 | self.device = (
57 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
58 | )
59 | logger.info("using computing device: %s", self.device)
60 | self.model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
61 | self.model.eval()
62 | self.model.to(self.device)
63 |
64 | # Overwrite self.example for warmup
65 | self.example = [
66 | [101, 2023, 2003, 1037, 8403, 4937, 999, 102] * 5 # make sentence longer
67 | ] * INFERENCE_BATCH_SIZE
68 |
69 | def forward(self, data: List[Any]) -> List[str]:
70 | tensors = [torch.tensor(token) for token in data]
71 | with torch.no_grad():
72 | logits = self.model(
73 | torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True).to(
74 | self.device
75 | )
76 | ).logits
77 | label_ids = logits.argmax(dim=1).cpu().tolist()
78 | return [self.model.config.id2label[i] for i in label_ids]
79 |
80 | def serialize(self, data: str) -> bytes:
81 | # Override `serialize` for the *last* stage;
82 | # `data` is the string from the `forward` output
83 | return data.encode()
84 |
85 |
86 | if __name__ == "__main__":
87 | server = Server()
88 | server.append_worker(Preprocess, num=2 * NUM_INSTANCE)
89 | server.append_worker(
90 | Inference,
91 | max_batch_size=INFERENCE_BATCH_SIZE,
92 | max_wait_time=10,
93 | num=NUM_INSTANCE,
94 | )
95 | server.run()
96 |
--------------------------------------------------------------------------------
/src/mosec_services/image_bind.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorchord/inference-benchmark/330d7497389a70998170a35d50aa28b6f591c1ab/src/mosec_services/image_bind.py
--------------------------------------------------------------------------------
/src/mosec_services/llama.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mosec import Server, Worker
3 | from transformers import LlamaForCausalLM, LlamaTokenizer
4 |
5 | MODEL = "decapoda-research/llama-7b-hf"
6 | MAX_LENGTH = 50
7 |
8 |
9 | class TokenEncoder(Worker):
10 | def __init__(self):
11 | self.tokenizer = LlamaTokenizer.from_pretrained(MODEL)
12 |
13 | def deserialize(self, data):
14 | return data.decode()
15 |
16 | def forward(self, data):
17 | tokens = self.tokenizer(data)
18 | print(tokens)
19 | return tokens.input_ids
20 |
21 |
22 | class Inference(Worker):
23 | def __init__(self):
24 | self.model = LlamaForCausalLM.from_pretrained(MODEL)
25 | self.device = (
26 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27 | )
28 | self.model.to(self.device)
29 | self.model.eval()
30 |
31 | def forward(self, data):
32 | inputs = torch.nn.utils.rnn.pad_sequence(
33 | [torch.tensor(tokens) for tokens in data], batch_first=True
34 | ).to(self.device)
35 | outputs = self.model.generate(inputs, max_length=MAX_LENGTH).tolist()
36 | return outputs
37 |
38 |
39 | class TokenDecoder(Worker):
40 | def __init__(self):
41 | self.tokenizer = LlamaTokenizer.from_pretrained(MODEL)
42 |
43 | def forward(self, data):
44 | outputs = self.tokenizer.decode(data, skip_special_tokens=True)
45 | return outputs
46 |
47 |
48 | if __name__ == "__main__":
49 | server = Server()
50 | server.append_worker(TokenEncoder)
51 | server.append_worker(Inference, max_batch_size=4, timeout=30)
52 | server.append_worker(TokenDecoder)
53 | server.run()
54 |
--------------------------------------------------------------------------------
/src/mosec_services/stable_diffusion.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from typing import List
3 |
4 | import torch # type: ignore
5 | from diffusers import StableDiffusionPipeline # type: ignore
6 | from mosec import Server, Worker, get_logger
7 | from mosec.mixin import MsgpackMixin
8 |
9 | logger = get_logger()
10 |
11 |
12 | class StableDiffusion(MsgpackMixin, Worker):
13 | def __init__(self):
14 | device = "cuda" if torch.cuda.is_available() else "cpu"
15 | self.pipe = StableDiffusionPipeline.from_pretrained(
16 | "runwayml/stable-diffusion-v1-5",
17 | torch_dtype=torch.float16 if device == "cuda" else torch.float32,
18 | )
19 | self.example = ["useless example prompt"] * 8
20 | self.pipe = self.pipe.to(device)
21 |
22 | def forward(self, data: List[str]) -> List[memoryview]:
23 | res = self.pipe(data)
24 | images = []
25 | for img in res[0]:
26 | dummy_file = BytesIO()
27 | img.save(dummy_file, format="JPEG")
28 | images.append(dummy_file.getbuffer())
29 | return images
30 |
31 |
32 | if __name__ == "__main__":
33 | server = Server()
34 | server.append_worker(StableDiffusion, num=1, max_batch_size=8, timeout=30)
35 | server.run()
36 |
--------------------------------------------------------------------------------
/src/mosec_services/whisper.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | import numpy as np
4 | import soundfile
5 | import torch
6 | from mosec import Server, Worker
7 | from transformers import WhisperForConditionalGeneration, WhisperProcessor
8 |
9 | STEREO_CHANNEL_NUM = 2
10 |
11 |
12 | class Preprocess(Worker):
13 | def __init__(self):
14 | self.processor = WhisperProcessor.from_pretrained("openai/whisper-base")
15 |
16 |     def deserialize(self, data: bytes) -> dict:
17 | with io.BytesIO(data) as byte_io:
18 | array, sampling_rate = soundfile.read(byte_io)
19 |             if array.ndim > 1 and array.shape[1] == STEREO_CHANNEL_NUM:
20 |                 # combine the two stereo channels into one
21 |                 array = np.mean(array, 1)
22 | return {"array": array, "sampling_rate": sampling_rate}
23 |
24 | def forward(self, data):
25 | res = self.processor(
26 | data["array"], sampling_rate=data["sampling_rate"], return_tensors="pt"
27 | )
28 | return res.input_features
29 |
30 |
31 | class Inference(Worker):
32 | def __init__(self):
33 | self.model = WhisperForConditionalGeneration.from_pretrained(
34 | "openai/whisper-base"
35 | )
36 | self.model.config.forced_decoder_ids = None
37 | self.device = (
38 | torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
39 | )
40 | self.model.to(self.device)
41 |
42 | def forward(self, data):
43 | ids = self.model.generate(torch.cat(data).to(self.device))
44 | return ids.cpu().tolist()
45 |
46 |
47 | class Postprocess(Worker):
48 | def __init__(self):
49 | self.processor = WhisperProcessor.from_pretrained("openai/whisper-base")
50 |
51 | def forward(self, data):
52 | return self.processor.batch_decode(data, skip_special_tokens=True)
53 |
54 | def serialize(self, data: str) -> bytes:
55 | return data.encode("utf-8")
56 |
57 |
58 | if __name__ == "__main__":
59 | server = Server()
60 | server.append_worker(Preprocess, num=2)
61 | server.append_worker(Inference, max_batch_size=16, max_wait_time=10)
62 | server.append_worker(Postprocess, num=2, max_batch_size=8, max_wait_time=5)
63 | server.run()
64 |
--------------------------------------------------------------------------------
/src/potassium_services/stable_diffusion.py:
--------------------------------------------------------------------------------
1 | """Potassium service for stable diffusion.
2 |
3 | This implementation is based on the official README example:
4 |
5 | - https://github.com/bananaml/potassium#or-do-it-yourself
6 | """
7 |
8 | import base64
9 | from io import BytesIO
10 |
11 | import torch
12 | from diffusers import StableDiffusionPipeline
13 | from potassium import Potassium, Request, Response
14 |
15 | app = Potassium(__name__)
16 |
17 |
18 | @app.init
19 | def init():
20 | device = "cuda" if torch.cuda.is_available() else "cpu"
21 | model = StableDiffusionPipeline.from_pretrained(
22 | "runwayml/stable-diffusion-v1-5",
23 | torch_dtype=torch.float16 if device == "cuda" else torch.float32,
24 | )
25 | model.to(device)
26 |
27 | context = {
28 | "model": model,
29 | }
30 | return context
31 |
32 |
33 | @app.handler()
34 | def handler(context: dict, request: Request) -> Response:
35 | model = context["model"]
36 | prompt = request.json.get("prompt")
37 | image = model(prompt).images[0]
38 | buf = BytesIO()
39 | image.save(buf, format="JPEG")
40 | return Response(
41 | json={"image_b64": base64.b64encode(buf.getvalue()).decode("utf-8")},
42 | status=200,
43 | )
44 |
45 |
46 | if __name__ == "__main__":
47 | app.serve()
48 |
--------------------------------------------------------------------------------
/src/pytriton_services/bert/client.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytriton.client import ModelClient
3 |
4 | MODEL_NAME = "distilbert"
5 | MODEL_URL = "127.0.0.1"
6 | TIMEOUT_SECOND = 600
7 |
8 |
9 | def query(url: str = MODEL_URL, name: str = MODEL_NAME):
10 | with ModelClient(url, name, init_timeout_s=TIMEOUT_SECOND) as client:
11 | message = "The quick brown fox jumps over the lazy dog"
12 | request = np.char.encode(np.asarray([[message]]), "utf-8")
13 | resp = client.infer_batch(messages=request)
14 | print(resp)
15 |
16 |
17 | if __name__ == "__main__":
18 | query()
19 |
--------------------------------------------------------------------------------
/src/pytriton_services/bert/service.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytriton.decorators import batch
4 | from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
5 | from pytriton.triton import Triton, TritonConfig
6 | from transformers import ( # type: ignore
7 | DistilBertForSequenceClassification,
8 | DistilBertTokenizer,
9 | )
10 |
11 | MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
12 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13 | NUM_INSTANCE = 1
14 | tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
15 | model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
16 | model.to(DEVICE)
17 | model.eval()
18 |
19 |
20 | @batch
21 | def infer(messages: np.ndarray):
22 | msg = [
23 | np.char.decode(message.astype("bytes"), "utf-8").item() for message in messages
24 | ]
25 | inputs = tokenizer(msg, return_tensors="pt", padding=True)
26 | inputs.to(DEVICE)
27 | logits = model(**inputs).logits
28 | label_ids = logits.argmax(dim=1).cpu().tolist()
29 | return {"labels": np.asarray([model.config.id2label[i] for i in label_ids])}
30 |
31 |
32 | def main():
33 | config = TritonConfig(exit_on_error=True)
34 | with Triton(config=config) as triton:
35 | triton.bind(
36 | model_name="distilbert",
37 | infer_func=[infer] * NUM_INSTANCE,
38 | inputs=[Tensor(name="messages", dtype=np.bytes_, shape=(1,))],
39 | outputs=[Tensor(name="labels", dtype=np.bytes_, shape=(1,))],
40 | config=ModelConfig(
41 | max_batch_size=32,
42 |                 batcher=DynamicBatcher(max_queue_delay_microseconds=10000),  # 10 ms
43 | ),
44 | )
45 | triton.serve()
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/src/pytriton_services/stable_diffusion/client.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytriton.client import ModelClient
3 |
4 | MODEL_NAME = "stable_diffusion"
5 | MODEL_URL = "127.0.0.1"
6 | TIMEOUT_SECOND = 600
7 |
8 |
9 | def query(url: str = MODEL_URL, name: str = MODEL_NAME):
10 | with ModelClient(url, name, init_timeout_s=TIMEOUT_SECOND) as client:
11 | prompt = "A photo of a cat"
12 | prompt_req = np.char.encode(np.asarray([[prompt]]), "utf-8")
13 | client.infer_batch(prompt=prompt_req)
14 |
15 |
16 | if __name__ == "__main__":
17 | query()
18 |
--------------------------------------------------------------------------------
/src/pytriton_services/stable_diffusion/service.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Server for Stable Diffusion 1.5."""
16 |
17 | import base64
18 | import io
19 | import logging
20 |
21 | import numpy as np
22 | import torch # pytype: disable=import-error
23 | from diffusers import StableDiffusionPipeline # pytype: disable=import-error
24 | from pytriton.decorators import batch
25 | from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
26 | from pytriton.triton import Triton, TritonConfig
27 |
28 | LOGGER = logging.getLogger(__file__)
29 |
30 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31 | IMAGE_FORMAT = "JPEG"
32 |
33 | pipe = StableDiffusionPipeline.from_pretrained(
34 | "runwayml/stable-diffusion-v1-5",
35 | torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
36 | )
37 | pipe = pipe.to(DEVICE)
38 |
39 |
40 | def _encode_image_to_base64(image):
41 | raw_bytes = io.BytesIO()
42 | image.save(raw_bytes, IMAGE_FORMAT)
43 | raw_bytes.seek(0) # return to the start of the buffer
44 | return base64.b64encode(raw_bytes.read())
45 |
46 |
47 | @batch
48 | def _infer_fn(prompt: np.ndarray):
49 | prompts = [np.char.decode(p.astype("bytes"), "utf-8").item() for p in prompt]
50 | LOGGER.debug(f"Prompts: {prompts}")
51 |
52 | outputs = []
53 | for idx, image in enumerate(pipe(prompt=prompts).images):
54 | raw_data = _encode_image_to_base64(image)
55 | outputs.append(raw_data)
56 | LOGGER.debug(
57 | f"Generated result for prompt `{prompts[idx]}` with size {len(raw_data)}"
58 | )
59 |
60 | LOGGER.debug(f"Prepared batch response of size: {len(outputs)}")
61 | return {"image": np.array(outputs)}
62 |
63 |
64 | def main():
65 | """Initialize server with model."""
66 | log_level = logging.DEBUG
67 | logging.basicConfig(
68 | level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s"
69 | )
70 | log_verbose = 1
71 | config = TritonConfig(exit_on_error=True, log_verbose=log_verbose)
72 |
73 | with Triton(config=config) as triton:
74 | LOGGER.info("Loading the pipeline")
75 | triton.bind(
76 | model_name="stable_diffusion",
77 | infer_func=_infer_fn,
78 | inputs=[
79 | Tensor(name="prompt", dtype=np.bytes_, shape=(1,)),
80 | ],
81 | outputs=[
82 | Tensor(name="image", dtype=np.bytes_, shape=(1,)),
83 | ],
84 | config=ModelConfig(
85 | max_batch_size=4,
86 | batcher=DynamicBatcher(
87 | max_queue_delay_microseconds=100,
88 | ),
89 | ),
90 | )
91 | triton.serve()
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
--------------------------------------------------------------------------------
/src/triton_services/client.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tritonclient.http as httpclient
3 |
4 | with httpclient.InferenceServerClient("localhost:8000") as client:
5 | prompt = httpclient.InferInput("PROMPT", shape=(1,), datatype="BYTES")
6 | prompt.set_data_from_numpy(np.asarray(["cat"], dtype=object))
7 | images = httpclient.InferRequestedOutput("IMAGE", binary_data=False)
8 | response = client.infer(
9 | model_name="stable_diffusion",
10 | inputs=[prompt],
11 | outputs=[images],
12 | )
13 | content = response.as_numpy("IMAGE")
14 |
--------------------------------------------------------------------------------
/src/triton_services/stable_diffusion/1/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import triton_python_backend_utils as triton_utils
4 | from diffusers import StableDiffusionPipeline
5 |
6 |
7 | class TritonPythonModel:
8 | def initialize(self, args):
9 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
10 | self.model = StableDiffusionPipeline.from_pretrained(
11 | "runwayml/stable-diffusion-v1-5",
12 | torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
13 | )
14 | self.model.to(self.device)
15 |
16 | def execute(self, requests):
17 | responses = []
18 | prompts = []
19 | for request in requests:
20 | prompt = (
21 | triton_utils.get_input_tensor_by_name(request, "PROMPT")
22 | .as_numpy()
23 | .tolist()
24 | )
25 | prompts.append(prompt[0].decode())
26 |
27 | images = self.model(prompts).images
28 | for image in images:
29 | output = triton_utils.Tensor("IMAGE", np.asarray(image))
30 | resp = triton_utils.InferenceResponse(output_tensors=[output])
31 | responses.append(resp)
32 |
33 | return responses
34 |
--------------------------------------------------------------------------------
/src/triton_services/stable_diffusion/config.pbtxt:
--------------------------------------------------------------------------------
1 | name: "stable_diffusion"
2 | backend: "python"
3 | # max_batch_size: 8
4 | # dynamic_batching {
5 | # max_queue_delay_microseconds: 10
6 | # }
7 |
8 | input [
9 | {
10 | name: "PROMPT"
11 | data_type: TYPE_STRING
12 | dims: [ 1 ]
13 | }
14 | ]
15 |
16 | output [
17 | {
18 | name: "IMAGE"
19 | data_type: TYPE_UINT8
20 | dims: [ 512, 512, 3 ]
21 | }
22 | ]
23 |
24 | instance_group [
25 | {
26 | count: 1
27 | }
28 | ]
--------------------------------------------------------------------------------