├── models ├── detection │ ├── 1 │ │ └── read │ └── config.pbtxt ├── post_classification │ ├── 1 │ │ └── read │ └── config.pbtxt ├── classification │ ├── 1 │ │ └── read │ └── config.pbtxt ├── model_pipeline │ ├── 1 │ │ └── read │ └── config.pbtxt └── post_detection │ ├── 1 │ └── model.py │ └── config.pbtxt ├── image1.png ├── image2.png ├── pipeline.png ├── Dockerfile ├── README.md ├── .gitignore └── client.py /models/detection/1/read: -------------------------------------------------------------------------------- 1 | place model there -------------------------------------------------------------------------------- /models/post_classification/config.pbtxt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/classification/1/read: -------------------------------------------------------------------------------- 1 | place model there -------------------------------------------------------------------------------- /models/post_classification/1/read: -------------------------------------------------------------------------------- 1 | place model there -------------------------------------------------------------------------------- /models/model_pipeline/1/read: -------------------------------------------------------------------------------- 1 | this sub folder should be empty -------------------------------------------------------------------------------- /image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bobo-y/triton_ensemble_model_demo/HEAD/image1.png -------------------------------------------------------------------------------- /image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bobo-y/triton_ensemble_model_demo/HEAD/image2.png 
-------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bobo-y/triton_ensemble_model_demo/HEAD/pipeline.png -------------------------------------------------------------------------------- /models/post_detection/config.pbtxt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bobo-y/triton_ensemble_model_demo/HEAD/models/post_detection/config.pbtxt -------------------------------------------------------------------------------- /models/classification/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "classification" 2 | platform: "tensorflow_savedmodel" 3 | input [ 4 | { 5 | name: "input_1" 6 | data_type: TYPE_FP32 7 | dims: [-1, 260, 260, 3 ] 8 | } 9 | ] 10 | output [ 11 | { 12 | name: "dense" 13 | data_type: TYPE_FP32 14 | dims: [-1, 3] 15 | } 16 | ] 17 | -------------------------------------------------------------------------------- /models/detection/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "detection" 2 | platform: "onnxruntime_onnx" 3 | input [ 4 | { 5 | name: "images" 6 | data_type: TYPE_FP32 7 | format: FORMAT_NCHW 8 | dims: [ 3, 640, 640 ] 9 | reshape { shape: [ 1, 3, 640, 640 ] } 10 | } 11 | ] 12 | output [ 13 | { 14 | name: "output" 15 | data_type: TYPE_FP32 16 | dims: [1,25200,7] 17 | } 18 | ] 19 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tritonserver:21.03-py3 2 | 3 | LABEL maintainer="yl305237731@foxmail.com" description="triton serving including models" 4 | 5 | 6 | RUN pip install --upgrade pip && pip install -U opencv-python && apt-get upgrade && apt update && 
apt install -y libsm6 libxext6 ffmpeg libfontconfig1 libxrender1 libgl1-mesa-glx \ 7 | && pip install torch==1.5.0 && pip install torchvision==0.6.0 && pip install numpy 8 | 9 | # Copy all models to docker 10 | COPY ./models /models 11 | 12 | 13 | RUN echo -e '#!/bin/bash \n\n\ 14 | tritonserver --model-repository=/models \ 15 | "$@"' > /usr/bin/triton_serving_entrypoint.sh \ 16 | && chmod +x /usr/bin/triton_serving_entrypoint.sh 17 | 18 | ENTRYPOINT ["/usr/bin/triton_serving_entrypoint.sh"] 19 | -------------------------------------------------------------------------------- /models/model_pipeline/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "model_pipeline" 2 | platform: "ensemble" 3 | input [ 4 | { 5 | name: "IMAGE" 6 | data_type: TYPE_FP32 7 | dims: [ 3, 640, 640 ] 8 | } 9 | ] 10 | output [ 11 | { 12 | name: "BBOXES" 13 | data_type: TYPE_FP32 14 | dims: [-1, 6] 15 | }, 16 | { 17 | name: "CLASSIFICATION" 18 | data_type: TYPE_FP32 19 | dims: [ -1, 3 ] 20 | } 21 | ] 22 | 23 | ensemble_scheduling { 24 | step [ 25 | { 26 | model_name: "detection" 27 | model_version: -1 28 | input_map { 29 | key: "images" 30 | value: "IMAGE" 31 | } 32 | output_map { 33 | key: "output" 34 | value: "DETECTION" 35 | } 36 | }, 37 | { 38 | model_name: "post_detection" 39 | model_version: -1 40 | input_map { 41 | key: "INPUT0" 42 | value: "IMAGE" 43 | } 44 | input_map { 45 | key: "INPUT1" 46 | value: "DETECTION" 47 | } 48 | output_map { 49 | key: "OUTPUT0" 50 | value: "crops" 51 | } 52 | output_map { 53 | key: "OUTPUT1" 54 | value: "BBOXES" 55 | } 56 | }, 57 | { 58 | model_name: "classification" 59 | model_version: -1 60 | input_map { 61 | key: "input_1" 62 | value: "crops" 63 | } 64 | output_map { 65 | key: "dense" 66 | value: "CLASSIFICATION" 67 | } 68 | } 69 | ] 70 | } 71 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # triton_ensemble_model_template 2 | 3 | 4 | ## Pipeline algorithms in the client (left) vs. pipeline algorithms in Triton (right) 5 | 6 | ![](pipeline.png) 7 | 8 | 9 | ## ensemble model 10 | 11 | input: raw image 12 | output: bboxes, classification info. (If there is no object in the image, the result is not empty; a specified placeholder value is returned instead.) 13 | 14 | input image --> detection --> post_detection: filter objects, crop them, and preprocess the crops for classification --> classification --> post_classification 15 | 16 | 17 | In this template I want to detect persons and distinguish whether they are wearing a vest. 18 | 19 | ### detection model 20 | 21 | The detection model is yolov5.pytorch, exported to model.onnx, with input size 640. It detects two classes, person and head; this demo only uses person. See [my yolov5](https://github.com/yl305237731/flexible-yolov5). 22 | 23 | ### post detection model 24 | 25 | The post detection model is written with the Python backend. It first gets the detection bboxes, then crops each person from the original input image according to those bboxes, and formats every cropped image as classification input. 26 | 27 | *Notice: In order to prevent a subsequent classification error when the detection result contains no person, a random image is appended at the end of the crop image list.* 28 | 29 | ### classification model 30 | 31 | The classification model is efficient-net.keras with input size 260. It returns 3 classes: background, wearing a vest, not wearing a vest. 32 | 33 | ### post classification 34 | 35 | The post classification model. This demo only uses one classification model, so this step is not implemented. 
36 | 37 | ### result 38 | 39 | ![](image1.png) 40 | 41 | ![](image2.png) 42 | 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from tritonclient.utils import *
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient

import numpy as np

# The request targets the ensemble model; it is the single entry point that
# chains detection, post_detection and classification on the server side.
model_name = "model_pipeline"

with httpclient.InferenceServerClient("localhost:8000") as client:
    # Dummy CHW float32 image matching the ensemble's IMAGE input (3, 640, 640).
    input0_data = np.random.rand(3, 640, 640).astype(np.float32)
    inputs = [
        httpclient.InferInput("IMAGE", input0_data.shape, np_to_triton_dtype(input0_data.dtype))
    ]
    inputs[0].set_data_from_numpy(input0_data)

    outputs = [
        httpclient.InferRequestedOutput("CLASSIFICATION"),
        httpclient.InferRequestedOutput("BBOXES")
    ]

    response = client.infer(model_name,
                            inputs,
                            request_id="1",
                            outputs=outputs)

    # Label each result with the output it actually holds (both used to be
    # printed as "OUTPUT0", which made the two arrays indistinguishable).
    print("CLASSIFICATION ({})".format(response.as_numpy("CLASSIFICATION")))
    print("BBOXES ({})".format(response.as_numpy("BBOXES")))

# --------------------------------------------------------------------------------
# /models/post_detection/1/model.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2020,
NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | import numpy as np 28 | import sys 29 | import json 30 | import cv2 31 | import torch 32 | import torchvision 33 | import time 34 | 35 | # triton_python_backend_utils is available in every Triton Python model. You 36 | # need to use this module to create inference requests and responses. 
# It also contains some utility functions for extracting information from
# model_config and converting Triton input/output types to numpy types.
try:
    import triton_python_backend_utils as pb_utils
except ModuleNotFoundError:
    # The module only exists inside a Triton Python-backend process. Falling
    # back to None keeps the pure helper methods importable elsewhere (e.g.
    # for unit tests); `initialize` refuses to run without it.
    pb_utils = None


class TritonPythonModel:
    """Post-detection step: turn raw detector output into boxes and crops.

    Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to intialize any state associated with this model.
        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        if pb_utils is None:
            raise ImportError(
                "triton_python_backend_utils is only available inside Triton")

        # You must parse model_config. JSON string is not parsed here
        self.model_config = model_config = json.loads(args['model_config'])

        # Resolve the numpy dtypes of both outputs from their configuration.
        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0")
        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config['data_type'])

        output1_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT1")
        self.output1_dtype = pb_utils.triton_string_to_numpy(output1_config['data_type'])

        # Only detections of `class_id` scoring above `confThreshold` are kept.
        self.confThreshold = 0.4
        self.class_id = 0

    def xywh2xyxy(self, x):
        """Convert boxes from centre format to corner format.

        Args:
            x: (n, 4) torch.Tensor or numpy array of
                (center_x, center_y, width, height) rows.

        Returns:
            A new array of the same type and shape as `x` holding
            (x1, y1, x2, y2) rows; `x` itself is not modified.
        """
        y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
        y[:, 0] = x[:, 0] - x[:, 2] / 2
        y[:, 1] = x[:, 1] - x[:, 3] / 2
        y[:, 2] = x[:, 0] + x[:, 2] / 2
        y[:, 3] = x[:, 1] + x[:, 3] / 2
        return y

    def non_max_suppression(self, pred, conf_thres=0.4, iou_thres=0.5, classes=0, agnostic=False):
        """Performs Non-Maximum Suppression (NMS) on inference results

        Args:
            pred: numpy array indexed as (batch, candidate, 5 + num_classes),
                each row being (cx, cy, w, h, objectness, class scores...).
            conf_thres: minimum objectness and minimum objectness * class
                score for a candidate to survive.
            iou_thres: IoU threshold handed to torchvision's NMS.
            classes: class index (or sequence of indices) to keep; pass
                None to keep every class. The default 0 keeps class 0 only.
            agnostic: when True, boxes of different classes suppress each
                other.

        Returns:
            A list with one entry per batch item: either None (nothing
            survived, or the time limit was hit before this item) or a
            tensor of detections with shape nx6 (x1, y1, x2, y2, conf, cls).
        """
        prediction = torch.from_numpy(pred.astype(np.float32))
        num_classes = prediction.shape[2] - 5
        candidates = prediction[..., 4] > conf_thres
        max_wh = 4096  # per-class coordinate offset, keeps classes apart in NMS
        max_det = 100
        time_limit = 10.0
        multi_label = num_classes > 1
        output = [None] * prediction.shape[0]
        started = time.time()
        for xi, x in enumerate(prediction):
            x = x[candidates[xi]]
            if not x.shape[0]:
                continue
            # Final confidence = objectness * class score.
            x[:, 5:] *= x[:, 4:5]
            box = self.xywh2xyxy(x[:, :4])
            if multi_label:
                i, j = (x[:, 5:] > conf_thres).nonzero().t()
                x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
            else:
                conf, j = x[:, 5:].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
            # `is not None` rather than truthiness: class index 0 is a valid
            # filter and used to be silently ignored.
            if classes is not None:
                x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
            if not x.shape[0]:
                continue
            # Shift boxes by a class-dependent offset so that NMS only
            # compares boxes of the same class (unless agnostic).
            offsets = x[:, 5:6] * (0 if agnostic else max_wh)
            boxes, scores = x[:, :4] + offsets, x[:, 4]
            keep = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
            if keep.shape[0] > max_det:
                keep = keep[:max_det]
            output[xi] = x[keep]
            if (time.time() - started) > time_limit:
                break
        return output

    def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
        """Map boxes from the network input frame back to the original image.

        Args:
            img1_shape: (height, width) of the image the boxes refer to.
            coords: (n, 4) torch.Tensor of (x1, y1, x2, y2); modified in place.
            img0_shape: shape of the target image, (height, width, ...).
            ratio_pad: optional ((gain, ...), (pad_x, pad_y)) overriding the
                gain and padding otherwise derived from the two shapes.

        Returns:
            `coords`, rescaled and clipped to the bounds of `img0_shape`.
        """
        if ratio_pad is None:
            gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
            pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
        else:
            gain = ratio_pad[0][0]
            pad = ratio_pad[1]
        coords[:, [0, 2]] -= pad[0]
        coords[:, [1, 3]] -= pad[1]
        coords[:, :4] /= gain
        self.clip_coords(coords, img0_shape)
        return coords

    def clip_coords(self, boxes, img_shape):
        """Clamp (x1, y1, x2, y2) boxes to the image bounds, in place.

        Args:
            boxes: (n, 4) torch.Tensor; the in-place `clamp_` requires a tensor.
            img_shape: (height, width, ...) of the image.

        Returns:
            None; `boxes` is modified in place.
        """
        boxes[:, 0].clamp_(0, img_shape[1])
        boxes[:, 1].clamp_(0, img_shape[0])
        boxes[:, 2].clamp_(0, img_shape[1])
        boxes[:, 3].clamp_(0, img_shape[0])

    def get_bbox(self, image, detections_bs):
        """Run NMS on raw detector output and cut the kept objects out.

        Args:
            image: HWC numpy image the detections belong to.
            detections_bs: raw detector output, see `non_max_suppression`.

        Returns:
            (outputs, crops) as described in `_collect_detections`.
        """
        boxes = self.non_max_suppression(detections_bs)
        return self._collect_detections(image, boxes)

    def _collect_detections(self, image, boxes):
        """Select the wanted detections and crop them from `image`.

        Args:
            image: HWC numpy image.
            boxes: per-image NMS results (None or an nx6 tensor each), in
                the 640x640 network input frame.

        Returns:
            outputs: list of [class_id, score, x_min, y_min, x_max, y_max]
                rows with coordinates normalised by the image size. The
                first row is always the placeholder [-1, 0, 0, 0, 0, 0],
                so the list is never empty.
            crops: the matching image regions as numpy views, one per
                non-placeholder row and in the same order.
        """
        img_h, img_w = image.shape[0], image.shape[1]
        outputs = [[-1, 0, 0, 0, 0, 0]]
        crops = []
        for det in boxes:
            if det is None or not len(det):
                continue
            det[:, :4] = self.scale_coords((640, 640), det[:, :4], image.shape).round()

            for *xyxy, score, cls in det.tolist():
                class_id = int(cls)
                if class_id != self.class_id or score <= self.confThreshold:
                    continue
                # Slice with integer pixel coordinates; the normalised
                # values are only meant for the reported box.
                x1, y1, x2, y2 = (int(v) for v in xyxy)
                if x2 <= x1 or y2 <= y1:
                    # A zero-area crop cannot be resized for classification.
                    continue
                outputs.append([class_id, score,
                                x1 / float(img_w), y1 / float(img_h),
                                x2 / float(img_w), y2 / float(img_h)])
                crops.append(image[y1:y2, x1:x2])
        return outputs, crops

    def resize_image(self, input_image, target_size=416, mode=None):
        """Resize input to a square of `target_size` pixels.

        Args:
            input_image: a ndarray, HWC image data.
            target_size: an integer, side length of the output.
            mode: interpolation to use: None for the cv2.resize default,
                a cv2 flag (e.g. cv2.INTER_AREA) or its name as a string
                in any case (e.g. 'inter_area').

        Return:
            img: a ndarray, the resized image data.
            scale: a list of two elements, [col_scale, row_scale], indicates the ratio of resized length / original length.

        Raises:
            ValueError: if `mode` is a string that names no cv2 INTER_* flag.
        """
        # The copy also makes the data contiguous; crops arrive as views
        # of a transposed array.
        img = input_image.copy()
        rows, cols = img.shape[0], img.shape[1]
        size = (int(target_size), int(target_size))

        if mode is None or mode == '':
            img = cv2.resize(img, size)
        else:
            if isinstance(mode, str):
                flag_name = mode.upper()
                if not flag_name.startswith('INTER_') or not hasattr(cv2, flag_name):
                    raise ValueError("unknown interpolation mode: {!r}".format(mode))
                interpolation = getattr(cv2, flag_name)
            else:
                interpolation = int(mode)
            # Must be passed by keyword: the third positional argument of
            # cv2.resize is `dst`, not the interpolation flag.
            img = cv2.resize(img, size, interpolation=interpolation)

        scale = [float(target_size) / cols, float(target_size) / rows]

        return img, scale

    def process_classfi_data(self, crops):
        """Preprocess cropped images into one classifier input batch.

        Each crop is resized to 260x260 and mapped to [-1, 1] under the
        assumption that pixel values are in [0, 255] (TODO confirm against
        what the client actually sends).

        Args:
            crops: list of HWC numpy images.

        Returns:
            ndarray of shape (len(crops) + 1, 260, 260, 3).
        """
        new_crops = []
        for crop in crops:
            # resize_image returns (image, scale); only the image is needed.
            resized, _ = self.resize_image(crop, 260, 'inter_area')
            resized = resized.astype(np.float32)
            resized /= 255.
            resized -= 0.5
            resized *= 2
            new_crops.append(resized)
        # There may be no object at all, but the pipeline still has to run
        # to the end, so a random image is always appended to keep the
        # classifier from failing on an empty input.
        new_crops.append(np.random.rand(260, 260, 3))
        return np.asarray(new_crops)

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse
        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest
        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        output0_dtype = self.output0_dtype
        output1_dtype = self.output1_dtype
        responses = []

        for request in requests:
            # INPUT0 is the CHW image; the helpers work on HWC.
            image = pb_utils.get_input_tensor_by_name(request, "INPUT0").as_numpy()
            image = np.transpose(image, (1, 2, 0))
            # INPUT1 is the raw detector output.
            detections = pb_utils.get_input_tensor_by_name(request, "INPUT1").as_numpy()

            bboxs, crops = self.get_bbox(image, detections)
            crops = self.process_classfi_data(crops)
            out_tensor_0 = pb_utils.Tensor("OUTPUT0", crops.astype(output0_dtype))
            out_tensor_1 = pb_utils.Tensor("OUTPUT1", np.asarray(bboxs).astype(output1_dtype))
            responses.append(
                pb_utils.InferenceResponse(output_tensors=[out_tensor_0, out_tensor_1]))
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is OPTIONAL. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')