├── .gitignore
├── README.md
├── assets
│   ├── bus.png
│   ├── horse.gif
│   ├── horse.png
│   ├── street.gif
│   └── zidane.png
├── images
│   ├── bus.jpg
│   ├── horse.jpg
│   ├── street.jpg
│   └── zidane.jpg
├── models
│   └── yolov5s-seg.onnx
├── requirements.txt
├── segment.py
└── src
    ├── draw.py
    ├── general.py
    ├── models.py
    └── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#  For a library or package, you might want to ignore these files since the code is
#  intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#  According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#  However, in case of collaboration, if having platform-specific dependencies or dependencies
#  having no cross-platform support, pipenv may install dependencies that don't work, or not
#  install all needed dependencies.
#Pipfile.lock

# poetry
#  Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#  This is especially recommended for binary packages to ensure reproducibility, and is more
#  commonly ignored for libraries.
#  https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#  Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#  pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#  in version control.
#  https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g.
#  github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file. For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# YOLOv5 Segmentation Python

![horse](./assets/horse.png)

---

Run the YOLOv5 segmentation model on _onnxruntime_ or _opencv dnn_ **without torch**!

## Usage

Everything you need is in `segment.py`, which provides a CLI to run a yolov5-seg ONNX model.

```bash
python segment.py --help
```

### Image

|                                |                          |
| :----------------------------: | :----------------------: |
| ![zidane](./assets/zidane.png) | ![bus](./assets/bus.png) |

```bash
python segment.py -m <model-path> -i <image-path>
```

### Video

|                              |                                |
| :--------------------------: | :----------------------------: |
| ![horse](./assets/horse.gif) | ![street](./assets/street.gif) |

```bash
python segment.py -m <model-path> -v 0            # webcam
python segment.py -m <model-path> -v <video-path> # local video
```

**Note** : Press `q` to stop video processing.

## OpenCV DNN

Use OpenCV DNN as the backend with the `--dnn` argument.

```bash
python segment.py -m <model-path> -v 0 --dnn
```

## Run on GPU

The model runs on the GPU automatically when a supported device is available.

- `onnxruntime` needs `onnxruntime-gpu` to be installed.

  **Note** : `onnxruntime-gpu` must be installed with the same version as `onnxruntime` to be able to use the GPU.

  ```bash
  pip install onnxruntime-gpu==<onnxruntime-version>
  ```

- `opencv-dnn` needs a custom build of OpenCV with CUDA support.
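
To sanity-check which execution provider `onnxruntime` actually selected, a minimal sketch (the model path below is the repo's bundled model):

```python
import onnxruntime as ort

# "GPU" when onnxruntime-gpu is installed and a usable device is visible
print(ort.get_device())

session = ort.InferenceSession(
    "models/yolov5s-seg.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# CUDAExecutionProvider should come first when inference runs on the GPU
print(session.get_providers())
```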

## Reference

- https://github.com/ultralytics/yolov5
- https://github.com/UNeedCryDear/yolov5-seg-opencv-dnn-cpp
--------------------------------------------------------------------------------
/assets/bus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/assets/bus.png
--------------------------------------------------------------------------------
/assets/horse.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/assets/horse.gif
--------------------------------------------------------------------------------
/assets/horse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/assets/horse.png
--------------------------------------------------------------------------------
/assets/street.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/assets/street.gif
--------------------------------------------------------------------------------
/assets/zidane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/assets/zidane.png
--------------------------------------------------------------------------------
/images/bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/images/bus.jpg
--------------------------------------------------------------------------------
/images/horse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/images/horse.jpg
--------------------------------------------------------------------------------
/images/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/images/street.jpg
--------------------------------------------------------------------------------
/images/zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/images/zidane.jpg
--------------------------------------------------------------------------------
/models/yolov5s-seg.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/models/yolov5s-seg.onnx
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hyuto/yolov5-seg-python/4d033a64c07b9d5234664f0453f197029a0ae62b/requirements.txt
--------------------------------------------------------------------------------
/segment.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import cv2

from src.models import ORTModelLoader, DNNModelLoader
from src.general import run_yolov5_seg
from src.utils import check_file


def parse_opt():
    parser = argparse.ArgumentParser(description="Detect using YOLOv5 Segmentation model")
    required = parser.add_argument_group("required arguments")
    required.add_argument(
        "-m", "--model", type=str, required=True, help="YOLOv5 Segmentation onnx model path"
    )
    source = parser.add_argument_group("source arguments")
    source.add_argument("-i", "--image", type=str, help="Image source")
    source.add_argument("-v", "--video", type=str, help="Video source")

    parser.add_argument(
        "--topk",
        type=int,
        default=100,
        help="Maximum number of boxes kept by NMS",
    )
    parser.add_argument(
        "--conf-thresh",
        type=float,
        default=0.2,
        help="Confidence threshold for removing low-confidence boxes",
    )
    parser.add_argument(
        "--iou-thresh",
        type=float,
        default=0.45,
        help="IoU threshold for deciding whether boxes overlap too much during NMS",
    )
    parser.add_argument(
        "--score-thresh",
        type=float,
        default=0.25,
        help="Score threshold for deciding whether to render a box",
    )
    parser.add_argument(
        "--mask-thresh",
        type=float,
        default=0.5,
        help="Threshold for binarizing the mask area",
    )
    parser.add_argument(
        "--mask-alpha",
        type=float,
        default=0.4,
        help="Opacity of the mask overlay",
    )
    parser.add_argument(
        "--dnn",
        action="store_true",
        help="Use the OpenCV DNN backend (onnxruntime is used when omitted)",
    )

    opt = parser.parse_args()
    # parser.error() prints usage and exits; argparse.ArgumentError cannot be
    # raised with a message alone (it also requires the offending argument object)
    if opt.image is None and opt.video is None:
        parser.error("Please specify an image or video source!")
    elif opt.image and opt.video:
        parser.error("Please specify either an image or a video source, not both!")
    return opt


def main(opt) -> None:
    if opt.dnn:
        model = DNNModelLoader(opt.model)  # use OpenCV DNN module
    else:
        model = ORTModelLoader(opt.model)  # use onnxruntime

    # warm up the model on a random image; numpy images are (height, width, channels)
    _ = run_yolov5_seg(
        model,
        (np.random.rand(model.height, model.width, 3) * 255).astype(np.uint8),
        opt.conf_thresh,
        opt.iou_thresh,
        opt.score_thresh,
        opt.topk,
        opt.mask_thresh,
        opt.mask_alpha,
    )

    if opt.image:
        check_file(opt.image, "Image file not found!")

        # Image processing
        img = cv2.imread(opt.image)
        img = run_yolov5_seg(
            model,
            img,
            opt.conf_thresh,
            opt.iou_thresh,
            opt.score_thresh,
            opt.topk,
            opt.mask_thresh,
            opt.mask_alpha,
        )

        cv2.imshow("output", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    elif opt.video:
        # Video processing
        vid_source = 0 if opt.video == "0" else opt.video
        cap = cv2.VideoCapture(vid_source)

        while cap.isOpened():
            ret, frame = cap.read()

            if not ret:
                break

            frame = run_yolov5_seg(
                model,
                frame,
                opt.conf_thresh,
                opt.iou_thresh,
                opt.score_thresh,
                opt.topk,
                opt.mask_thresh,
                opt.mask_alpha,
            )

            cv2.imshow("output", frame)

            if cv2.waitKey(1) == ord("q"):
                break

        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
--------------------------------------------------------------------------------
/src/draw.py:
--------------------------------------------------------------------------------
from typing import Tuple

import cv2
import numpy as np
import numpy.typing as npt


def draw_boxes(
    source: npt.NDArray[np.uint8],
    box: npt.NDArray[np.int32],
    label: str,
    score: float,
    color: Tuple[int, int, int],
) -> None:
    """Draw a labeled box on an image

    Args:
        source (npt.NDArray[np.uint8]): image array
        box (npt.NDArray[np.int32]): box to draw [left, top, width, height]
        label (str): box label
        score (float): box score
        color (Tuple[int, int, int]): box color in BGR order (OpenCV convention)
    """
    cv2.rectangle(source, box, color, 2)  # draw box
    (label_width, label_height), _ = cv2.getTextSize(
        f"{label} - {round(score, 2)}",
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        1,
    )
    cv2.rectangle(  # filled background behind the label text
        source,
        (box[0] - 1, box[1] - label_height - 6),
        (box[0] + label_width + 1, box[1]),
        color,
        -1,
    )
    cv2.putText(
        source,
        f"{label} - {round(score, 2)}",
        (box[0], box[1] - 5),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        [255, 255, 255],
        1,
    )


class Colors:
    """Ultralytics color palette https://ultralytics.com/"""

    def __init__(self) -> None:
        hexs = (
            "FF3838",
            "FF9D97",
            "FF701F",
            "FFB21D",
            "CFD231",
            "48F90A",
            "92CC17",
            "3DDB86",
            "1A9334",
            "00D4BB",
            "2C99A8",
            "00C2FF",
            "344593",
            "6473FF",
            "0018EC",
            "8438FF",
            "520085",
            "CB38FF",
            "FF95C8",
            "FF37C7",
        )
        self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
        self.n = len(self.palette)

    def __call__(self, i: int, bgr: bool = False) -> Tuple[int, int, int]:
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h: str) -> Tuple[int, int, int]:  # rgb order (PIL)
        return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()
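
A quick check of the palette helper above (a standalone sketch, not one of the repo files):

```python
from src.draw import colors

print(colors(0))        # (255, 56, 56) -> first palette color in RGB
print(colors(0, True))  # (56, 56, 255) -> same color in BGR for OpenCV drawing
```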
--------------------------------------------------------------------------------
/src/general.py:
--------------------------------------------------------------------------------
from typing import Union

import cv2
import numpy as np
import numpy.typing as npt

from .models import ORTModelLoader, DNNModelLoader
from .draw import draw_boxes, colors
from .utils import get_divisible_size, handle_overflow_box


def run_yolov5_seg(
    model: Union[ORTModelLoader, DNNModelLoader],
    source: npt.NDArray[np.uint8],
    conf_thresh: float,
    iou_thresh: float,
    score_thresh: float,
    topk: int,
    mask_thresh: float,
    mask_alpha: float,
) -> npt.NDArray[np.uint8]:
    """Run YOLOv5 Segmentation model

    Args:
        model (Union[ORTModelLoader, DNNModelLoader]): Model loader
        source (npt.NDArray[np.uint8]): Source image array
        conf_thresh (float): Confidence threshold
        iou_thresh (float): IoU (NMS) threshold
        score_thresh (float): Score threshold
        topk (int): Maximum number of boxes kept by NMS
        mask_thresh (float): Mask binarization threshold
        mask_alpha (float): Mask opacity on the overlay

    Returns:
        npt.NDArray[np.uint8]: Image with rendered boxes and masks
    """
    source_height, source_width, _ = source.shape

    ## resize to a size divisible by the model stride
    source_width, source_height = get_divisible_size([source_width, source_height], model.stride)
    source = cv2.resize(source, [source_width, source_height])

    ## pad image to a square
    max_size = max(source_width, source_height)  # get max size
    source_padded = np.zeros((max_size, max_size, 3), dtype=np.uint8)  # initial zeros mat
    source_padded[:source_height, :source_width] = source.copy()  # place original image
    overlay = source_padded.copy()  # make overlay mat

    ## ratios between the padded image and the model input
    x_ratio = max_size / model.width
    y_ratio = max_size / model.height

    # run model
    input_img = cv2.dnn.blobFromImage(
        source_padded,
        1 / 255.0,
        (model.width, model.height),
        swapRB=False,
        crop=False,
    )  # normalize and resize: [h, w, 3] => [1, 3, h, w]
    result = model.forward(input_img)

    # box preprocessing: [center_x, center_y, w, h] -> [left, top, w, h], scaled to the source
    result[0][0, :, 0] = (result[0][0, :, 0] - 0.5 * result[0][0, :, 2]) * x_ratio
    result[0][0, :, 1] = (result[0][0, :, 1] - 0.5 * result[0][0, :, 3]) * y_ratio
    result[0][0, :, 2] *= x_ratio
    result[0][0, :, 3] *= y_ratio

    # get boxes, conf, score, and mask
    boxes = result[0][0, :, :4]
    confidences = result[0][0, :, 4]
    scores = confidences.reshape(-1, 1) * result[0][0, :, 5 : len(model.labels) + 5]
    masks = result[0][0, :, len(model.labels) + 5 :]

    # NMS
    selected = cv2.dnn.NMSBoxes(boxes, confidences, conf_thresh, iou_thresh, top_k=topk)

    boxes_to_draw = []  # boxes to draw

    for i in selected:  # loop through selected
        box = boxes[i].round().astype(np.int32)  # to int
        box = handle_overflow_box(box, [max_size, max_size])  # clip overflowing boxes

        _, score, _, label = cv2.minMaxLoc(scores[i])  # get score and classId
        if score >= score_thresh:  # filter by score_thresh
            color = colors(label[1], True)  # get color (BGR)

            # save box to draw later (masks are applied first)
            boxes_to_draw.append([box, model.labels[label[1]], score, color])

            # crop the box region from the prototype map
            x = int(round(box[0] * model.seg_width / max_size))
            y = int(round(box[1] * model.seg_height / max_size))
            w = int(round(box[2] * model.seg_width / max_size))
            h = int(round(box[3] * model.seg_height / max_size))

            # process protos
            protos = result[1][0, :, y : y + h, x : x + w].reshape(model.seg_channels, -1)
            protos = np.expand_dims(masks[i], 0) @ protos  # matmul with mask coefficients
            protos = 1 / (1 + np.exp(-protos))  # sigmoid
            protos = protos.reshape(h, w)  # reshape
            mask = cv2.resize(protos, (box[2], box[3]))  # resize mask to box size
            mask = mask >= mask_thresh  # binarize mask by mask_thresh

            # add mask to overlay layer
            to_mask = overlay[box[1] : box[1] + box[3], box[0] : box[0] + box[2]]  # get box roi
            mask = mask[: to_mask.shape[0], : to_mask.shape[1]]  # crop mask
            to_mask[mask] = color  # apply mask

    # combine image and overlay
    source_padded = cv2.addWeighted(source_padded, 1 - mask_alpha, overlay, mask_alpha, 0)

    for draw_box in boxes_to_draw:  # draw boxes
        draw_boxes(source_padded, *draw_box)

    source = source_padded[:source_height, :source_width]  # crop padding

    return source
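
The proto/coefficient mask decoding above in isolation, as a toy-shape sketch (the 32-channel 160×160 prototype shape matches the default 640-input yolov5s-seg export, but verify against your model):

```python
import numpy as np

rng = np.random.default_rng(0)
protos = rng.standard_normal((32, 160 * 160))  # [seg_channels, h * w] prototype masks
coeffs = rng.standard_normal((1, 32))          # one detection's mask coefficients
mask = 1 / (1 + np.exp(-(coeffs @ protos)))    # matmul + sigmoid, as in run_yolov5_seg
binary = mask.reshape(160, 160) >= 0.5         # threshold into a binary mask
print(binary.shape, binary.mean())             # (160, 160), fraction of "on" pixels
```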
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
import ast
from typing import List

import cv2
import numpy as np
import numpy.typing as npt
import onnxruntime as ort

from .utils import check_file, labels


class ORTModelLoader:
    """ONNXRUNTIME model handler"""

    def __init__(self, path: str) -> None:
        self._load_model(path)
        self._get_metadata()

    def _load_model(self, model_path: str) -> None:
        """Load model and get model input and output information

        Args:
            model_path (str): Model path
        """
        check_file(model_path, "Model does not exist!")  # check model existence

        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]  # use CUDA if a GPU is available
            if ort.get_device() == "GPU"
            else ["CPUExecutionProvider"]
        )  # get providers
        self.model = ort.InferenceSession(model_path, providers=providers)  # load session

        model_input = self.model.get_inputs()[0]  # get input info
        self.input_name = model_input.name
        _, _, self.height, self.width = model_input.shape  # ONNX input layout is NCHW

        output, seg_output = self.model.get_outputs()  # get output info
        self.output_names = [output.name, seg_output.name]
        _, self.seg_channels, self.seg_height, self.seg_width = seg_output.shape  # NCHW as well

    def _get_metadata(self, default_labels: List[str] = labels, default_stride: int = 32) -> None:
        """Get model metadata

        Args:
            default_labels (List[str], optional): Labels to fall back on when the model
                metadata doesn't contain "names" (utils.labels).
            default_stride (int, optional): Stride to fall back on when the model
                metadata doesn't contain "stride" (32).
        """
        metadata = self.model.get_modelmeta().custom_metadata_map
        self.labels = ast.literal_eval(metadata["names"]) if "names" in metadata else default_labels
        self.stride = (
            ast.literal_eval(metadata["stride"]) if "stride" in metadata else default_stride
        )

    def forward(self, input: npt.NDArray[np.float32]) -> List[npt.NDArray[np.float32]]:
        """Get model prediction

        Args:
            input (npt.NDArray[np.float32]): Input image.

        Returns:
            List[npt.NDArray[np.float32]]: Model outputs
        """
        return self.model.run(self.output_names, {self.input_name: input})


class DNNModelLoader(ORTModelLoader):
    """OpenCV DNN model handler"""

    def __init__(self, path) -> None:
        super().__init__(path)  # reuse the ORT loader to read input/output info and metadata

        self.model = cv2.dnn.readNet(path)  # then override the ORT session with a DNN net

        if cv2.cuda.getCudaEnabledDeviceCount():  # use CUDA if available
            self.model.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
            self.model.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
        else:  # use CPU
            self.model.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
            self.model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)

    def forward(self, input: npt.NDArray[np.float32]) -> List[npt.NDArray[np.float32]]:
        """Get model prediction

        Args:
            input (npt.NDArray[np.float32]): Input image.

        Returns:
            List[npt.NDArray[np.float32]]: Model outputs
        """
        self.model.setInput(input, self.input_name)
        return self.model.forward(self.output_names)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
from pathlib import Path
from typing import Tuple, Iterable

import numpy as np
import numpy.typing as npt


def check_file(path: str, message: str) -> None:
    """Check whether a file exists.

    Args:
        path (str): File path
        message (str): Error message used when the file does not exist

    Raises:
        FileNotFoundError: If the file does not exist
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(message)


def get_divisible_size(imgsz: Iterable[int], stride: int) -> Iterable[int]:
    """Round an image size to the nearest multiple of the model stride

    Args:
        imgsz (Iterable[int]): Current image size [width, height]
        stride (int): Model stride

    Returns:
        Image size divisible by the model stride
    """
    for i in range(len(imgsz)):
        div, mod = divmod(imgsz[i], stride)
        if mod > stride / 2:  # round up when past the halfway point
            div += 1
        imgsz[i] = div * stride
    return imgsz


def handle_overflow_box(
    box: npt.NDArray[np.int32], imgsz: Tuple[int, int]
) -> npt.NDArray[np.int32]:
    """Clip a box whose coordinates overflow the image size

    Args:
        box (npt.NDArray[np.int32]): box to clip [left, top, width, height]
        imgsz (Tuple[int, int]): Current image size [width, height]

    Returns:
        Non-overflowing box
    """
    if box[0] < 0:
        box[0] = 0
    elif box[0] >= imgsz[0]:
        box[0] = imgsz[0] - 1
    if box[1] < 0:
        box[1] = 0
    elif box[1] >= imgsz[1]:
        box[1] = imgsz[1] - 1
    box[2] = box[2] if box[0] + box[2] <= imgsz[0] else imgsz[0] - box[0]
    # clamp the height against the image bottom (the original computed box[3] - box[1],
    # which could leave the box overflowing)
    box[3] = box[3] if box[1] + box[3] <= imgsz[1] else imgsz[1] - box[1]
    return box


# fmt: off
labels = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
    "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
    "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
    "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
    "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]
# fmt: on
--------------------------------------------------------------------------------
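
A minimal end-to-end sketch of using the modules above programmatically, assuming the repo layout shown in the tree (paths are the repo's bundled files; the threshold values mirror the CLI defaults):

```python
import cv2

from src.models import ORTModelLoader
from src.general import run_yolov5_seg

model = ORTModelLoader("models/yolov5s-seg.onnx")  # onnxruntime backend
img = cv2.imread("images/zidane.jpg")

out = run_yolov5_seg(
    model,
    img,
    0.2,   # conf_thresh
    0.45,  # iou_thresh
    0.25,  # score_thresh
    100,   # topk
    0.5,   # mask_thresh
    0.4,   # mask_alpha
)
cv2.imwrite("zidane-seg.png", out)
```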