├── .gitignore
├── LICENSE.md
├── README.md
├── detect_onnx.py
├── detect_openvino.py
├── inf_requirements.txt
├── media
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── 4.jpg
│   └── soccer1.mp4
├── models
│   └── .gitkeep
└── utils
    ├── detector_utils.py
    └── general.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | # output files
141 | output
142 |
143 | # pytorch weights
144 | *.pt
145 | *.pth
146 |
147 | # onnx weights
148 | *.onnx
149 |
150 | # openvino IR
151 | *.bin
152 | *.xml
153 | *.mapping
154 |
155 | # macOS filesystem files
156 | *.DS_Store
157 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022 SamSamhuns
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
19 | OR OTHER DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YOLOv5 CPU Export and OpenVINO Inference
2 |
3 | [Codacy Dashboard](https://www.codacy.com/gh/SamSamhuns/yolov5_export_cpu/dashboard?utm_source=github.com&utm_medium=referral&utm_content=SamSamhuns/yolov5_export_cpu&utm_campaign=Badge_Grade)
4 |
5 | Documentation on exporting YOLOv5 models for fast CPU inference using Intel's OpenVINO framework (Tested on commits up to June 6, 2022 in docker).
6 |
7 | ## Google Colab Conversion
8 |
9 | Convert a yolov5 model to IR format with Google Colab: [Open in Colab](https://colab.research.google.com/drive/1K8gnZEka47Gbcp1eJbBaSe3GxngJdvio?usp=sharing) (Recommended)
10 |
11 | ## 1. Clone and set up the Official YOLOv5 GitHub repository
12 |
13 |
14 | **Setup**
15 |
16 | All package installations should be done in a virtualenv or conda env to prevent package conflict errors.
17 |
18 | - Install the requirements for ONNX and OpenVINO inference (a quick import check is sketched after this list)
19 |
20 | ```bash
21 | pip install --upgrade pip
22 | pip install -r inf_requirements.txt
23 | ```
24 |
25 | - Clone the yolov5 repository and install its requirements
26 |
27 | ```bash
28 | git clone https://github.com/ultralytics/yolov5 # clone repo
29 | cd yolov5
30 | pip install -r requirements.txt # base requirements
31 | ```
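
To confirm that the inference requirements installed correctly, a quick import check can be run (a minimal sketch; the packages are the ones pinned in `inf_requirements.txt`):

```python
import cv2
import torch
import torchvision
import onnxruntime
import openvino as ov

# print installed versions to confirm the environment matches inf_requirements.txt
print("opencv     :", cv2.__version__)
print("torch      :", torch.__version__)
print("torchvision:", torchvision.__version__)
print("onnxruntime:", onnxruntime.__version__)
print("openvino   :", ov.get_version())
```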
32 |
33 |
34 |
35 | ## 2. Export a Trained YOLOv5 Model as ONNX
36 |
37 |
38 | **Export**
39 |
40 | Export a pre-trained or custom-trained YOLOv5 model to generate the respective ONNX, TorchScript, and CoreML formats. The pre-trained `yolov5s.pt` is the lightest and fastest model for CPU inference. Other slower but more accurate models include `yolov5m.pt`, `yolov5l.pt`, and `yolov5x.pt`. Details for all available models are in the Ultralytics YOLOv5 [README](https://github.com/ultralytics/yolov5#pretrained-checkpoints).
41 |
42 | A custom training checkpoint, e.g. `runs/exp/weights/best.pt`, can be used for conversion as well.
43 |
44 | - Export a pre-trained light yolov5s.pt model at 640x640 with batch size 1
45 |
46 | ```bash
47 | python export.py --weights yolov5s.pt --include onnx --img 640 --batch 1
48 | ```
49 |
50 | - Export a custom checkpoint with a dynamic input shape {BATCH_SIZE, 3, HEIGHT, WIDTH}. Note: for CPU inference, BATCH_SIZE must be set to 1. Install onnx-simplifier to simplify the ONNX export. (A quick shape check for the exported model is sketched after this list.)
51 |
52 | ```bash
53 | pip install onnx-simplifier==0.3.10
54 | python export.py --weights runs/exp/weights/best.pt --include onnx --dynamic --simplify
55 | ```
56 |
57 | - `cd` to the `yolov5_export_cpu` dir and move the ONNX model into the `yolov5_export_cpu/models` directory
58 |
59 | ```bash
60 | mv <ONNX_MODEL_PATH> yolov5_export_cpu/models/
61 | ```
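
Optionally, verify the exported ONNX model by loading it with `onnxruntime` and printing its input/output shapes, the same calls `detect_onnx.py` makes at startup (a minimal sketch; the path assumes the default `models/yolov5s.onnx` location):

```python
import onnxruntime

# load the exported model and inspect the graph inputs/outputs
session = onnxruntime.InferenceSession("models/yolov5s.onnx")  # adjust path for custom exports
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape)   # e.g. [1, 3, 640, 640], or symbolic dims for --dynamic exports
for out in session.get_outputs():
    print("output:", out.name, out.shape)
```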
62 |
63 |
64 |
65 | ## 3. Test YOLOv5 ONNX model inference
66 |
67 |
68 | **ONNX inference**
69 |
70 | ```bash
71 | python detect_onnx.py -m image -i <IMAGE_PATH>
72 | python detect_onnx.py -m video -i <VIDEO_PATH>
73 | # python detect_onnx.py -h for more info
74 | ```
75 |
76 | Optional: To convert all the frames in the `output` directory into an mp4 video with `ffmpeg`, use `ffmpeg -r 25 -start_number 00001 -i output/frame_onnx_%5d.jpg -vcodec libx264 -y -an onnx_result.mp4`
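
If `ffmpeg` is not installed, the saved frames can also be stitched into a video with OpenCV (a minimal sketch, assuming the default `output` directory and 25 FPS; the output filename is arbitrary):

```python
import glob
import cv2

# collect the saved ONNX detection frames in order
frames = sorted(glob.glob("output/frame_onnx_*.jpg"))
height, width = cv2.imread(frames[0]).shape[:2]

# write all frames into an mp4 at 25 FPS
writer = cv2.VideoWriter("onnx_result.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 25, (width, height))
for fpath in frames:
    writer.write(cv2.imread(fpath))
writer.release()
```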
77 |
78 |
79 |
80 | ## 4. Export ONNX to OpenVINO
81 |
82 | **Recommended Option A**
83 |
84 | ### Option A. Use OpenVINO's python dev library
85 |
86 |
87 | **A1. Install OpenVINO python dev library**
88 |
89 | Instructions for setting up OpenVINO are available [here](https://docs.openvino.ai/latest/openvino_docs_install_guides_install_dev_tools.html)
90 |
91 | ```bash
92 | # install required OpenVINO lib to convert ONNX to OpenVINO IR
93 | pip install openvino-dev[onnx]
94 | ```
95 |
96 |
97 |
98 |
99 | **A2. Export ONNX to OpenVINO IR**
100 |
101 | This will create the OpenVINO Intermediate Representation (IR) model files (xml and bin) in the directory `models/yolov5_openvino`.
102 |
103 | **Important Note:** `--input_shape` must be provided and must match the image shape used to export the ONNX model. Batching might not be supported for CPU inference.
104 |
105 | ```bash
106 | # export onnx to OpenVINO IR
107 | mo \
108 | --progress \
109 | --input_shape [1,3,640,640] \
110 | --input_model models/yolov5s.onnx \
111 | --output_dir models/yolov5_openvino
112 | ```
113 |
114 | [Full OpenVINO export options](https://docs.openvinotoolkit.org/latest/openvino_docs_MO_DG_prepare_model_convert_model_Converting_Model_General.html)
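
Alternatively, the same conversion can be run from Python with the `openvino` package pinned in `inf_requirements.txt` (a hedged sketch using `ov.convert_model`/`ov.save_model`, which should be available in openvino 2023.1 and later; the output path matches the defaults expected by `detect_openvino.py`):

```python
import os
import openvino as ov

os.makedirs("models/yolov5_openvino", exist_ok=True)

# convert the ONNX model to an OpenVINO model in memory
# (for a --dynamic ONNX export, the input shape may need to be fixed here; see the convert_model docs)
ov_model = ov.convert_model("models/yolov5s.onnx")

# serialize to IR: creates yolov5s.xml + yolov5s.bin (weights are compressed to FP16 by default)
ov.save_model(ov_model, "models/yolov5_openvino/yolov5s.xml")
```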
115 |
116 |
117 |
118 | ### Option B. Use OpenVINO Docker
119 |
120 |
121 | **B1. Download Docker and OpenVINO Docker Image**
122 |
123 | [Install docker](https://docs.docker.com/get-docker/) in your system if not already installed.
124 |
125 | Run the docker command below in a terminal; it will automatically download the OpenVINO Docker image and start a container. The `models` directory containing the ONNX model must be in the current working directory.
126 |
127 | ```bash
128 | docker run -it --rm \
129 | -v $PWD/models:/home/openvino/models \
130 | openvino/ubuntu18_dev:latest \
131 | /bin/bash -c "cd /home/openvino/; bash"
132 | ```
133 |
134 |
135 |
136 |
137 | **B2. Export ONNX model to an OpenVINO IR representation**
138 |
139 | This will create the OpenVINO Intermediate Representation (IR) model files (xml and bin) in the directory `models/yolov5_openvino`, which will also be available on the host system outside the docker container.
140 |
141 | **Important Note:** `--input_shape` must be provided and must match the image shape used to export the ONNX model. Batching might not be supported for CPU inference.
142 |
143 | ```bash
144 | # inside the OpenVINO docker container
145 | mo \
146 | --progress \
147 | --input_shape [1,3,640,640] \
148 | --input_model models/yolov5s.onnx \
149 | --output_dir models/yolov5_openvino
150 | # exit OpenVINO docker container
151 | exit
152 | ```
153 |
154 | [Full OpenVINO export options](https://docs.openvinotoolkit.org/latest/openvino_docs_MO_DG_prepare_model_convert_model_Converting_Model_General.html)
155 |
156 |
157 |
158 | ## 5. Test YOLOv5 OpenVINO IR model CPU inference
159 |
160 |
161 | **OpenVINO model inference**
162 |
163 | ```bash
164 | python detect_openvino.py -m image -i <IMAGE_PATH>
165 | python detect_openvino.py -m video -i <VIDEO_PATH>
166 | # python detect_openvino.py -h for more info
167 | ```
168 |
169 | Optional: To convert all the frames in the `output` directory into an mp4 video with `ffmpeg`, use `ffmpeg -r 25 -start_number 00001 -i output/frame_openvino_%5d.jpg -vcodec libx264 -y -an openvino_result.mp4`
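
If the IR fails to load or produces unexpected results, a minimal smoke test can help isolate the problem. The sketch below mirrors the calls made in `detect_openvino.py` and feeds one dummy frame through the network (the XML path is the default produced in step 4):

```python
import numpy as np
import openvino as ov

core = ov.Core()
model = core.read_model("models/yolov5_openvino/yolov5s.xml")
compiled = core.compile_model(model, "CPU")

# build a dummy input matching the model's static input shape and run one inference
_, c, h, w = model.inputs[0].shape
dummy = np.zeros((1, int(c), int(h), int(w)), dtype=np.float32)
result = compiled([dummy])[compiled.output(0)]
print("output shape:", result.shape)  # e.g. (1, 25200, 85) for the 80-class yolov5s at 640x640
```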
170 |
171 |
172 |
--------------------------------------------------------------------------------
/detect_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from functools import partial
4 |
5 | import torch
6 | import numpy as np
7 | import onnxruntime
8 |
9 | from utils.general import parse_arguments, DataStreamer
10 | from utils.detector_utils import save_output, preprocess_image, non_max_suppression, w_non_max_suppression
11 |
12 |
13 | @torch.no_grad()
14 | def detect_onnx(src_path: str,
15 | media_type: str,
16 | threshold: float = 0.6,
17 | official: bool = True,
18 | onnx_path: str = "models/yolov5s.onnx",
19 | output_dir: str = "output",
20 | num_classes: int = 80) -> None:
21 | session = onnxruntime.InferenceSession(onnx_path)
22 | model_batch_size = session.get_inputs()[0].shape[0]
23 | model_h = session.get_inputs()[0].shape[2]
24 | model_w = session.get_inputs()[0].shape[3]
25 | in_w = 640 if (model_w is None or isinstance(model_w, str)) else model_w
26 | in_h = 640 if (model_h is None or isinstance(model_h, str)) else model_h
27 | print("Input Layer: ", session.get_inputs()[0].name)
28 | print("Output Layer: ", session.get_outputs()[0].name)
29 | print("Model Input Shape: ", session.get_inputs()[0].shape)
30 | print("Model Output Shape: ", session.get_outputs()[0].shape)
31 |
32 | start_time = time.time()
33 | preprocess_func = partial(preprocess_image, in_size=(in_w, in_h))
34 | data_stream = DataStreamer(src_path, media_type, preprocess_func)
35 | if output_dir is not None:
36 | os.makedirs(output_dir, exist_ok=True)
37 |
38 | for i, (orig_input, model_input) in enumerate(data_stream, start=1):
39 | batch_size = model_input.shape[0] if isinstance(
40 | model_batch_size, str) else model_batch_size
41 | input_name = session.get_inputs()[0].name
42 |
43 | # inference
44 | start = time.time()
45 | outputs = session.run(None, {input_name: model_input})
46 | end = time.time()
47 |
48 | inf_time = end - start
49 | print('Inference Time: {} Seconds Single Image'.format(inf_time))
50 | fps = 1. / (end - start)
51 | print('Estimated Inference FPS: {} FPS Single Image'.format(fps))
52 |
53 | batch_detections = []
54 | # model.model[-1].export = boolean ---> True:3 False:4
55 | if official: # recommended
56 | # model.model[-1].export = False ---> outputs[0] (1, xxxx, 85)
57 | # Use the official code directly
58 | batch_detections = torch.from_numpy(np.array(outputs[0]))
59 | batch_detections = non_max_suppression(
60 | batch_detections, conf_thres=0.4, iou_thres=0.5, agnostic=False)
61 | else:
62 | # model.model[-1].export = False ---> outputs[1]/outputs[2]/outputs[3]
63 | # model.model[-1].export = True ---> outputs
64 | # (1, 3, 20, 20, 85)
65 | # (1, 3, 40, 40, 85)
66 | # (1, 3, 80, 80, 85)
67 | # same anchors for 5s, 5l, 5x
68 | anchors = [[116, 90, 156, 198, 373, 326], [
69 | 30, 61, 62, 45, 59, 119], [10, 13, 16, 30, 33, 23]]
70 |
71 | boxs = []
72 | a = torch.tensor(anchors).float().view(3, -1, 2)
73 | anchor_grid = a.clone().view(3, 1, -1, 1, 1, 2)
74 | if len(outputs) == 4:
75 | outputs = [outputs[1], outputs[2], outputs[3]]
76 | for index, out in enumerate(outputs):
77 | out = torch.from_numpy(out)
78 | # batch = out.shape[1]
79 | feature_w = out.shape[2]
80 | feature_h = out.shape[3]
81 |
82 | # Feature map corresponds to the original image zoom factor
83 | stride_w = int(in_w / feature_w)
84 | stride_h = int(in_h / feature_h)
85 |
86 | grid_x, grid_y = np.meshgrid(
87 | np.arange(feature_w), np.arange(feature_h))
88 |
89 | # cx, cy, w, h
90 | pred_boxes = torch.FloatTensor(out[..., :4].shape)
91 | pred_boxes[..., 0] = (torch.sigmoid(
92 | out[..., 0]) * 2.0 - 0.5 + grid_x) * stride_w # cx
93 | pred_boxes[..., 1] = (torch.sigmoid(
94 | out[..., 1]) * 2.0 - 0.5 + grid_y) * stride_h # cy
95 | pred_boxes[..., 2:4] = (torch.sigmoid(
96 | out[..., 2:4]) * 2) ** 2 * anchor_grid[index] # wh
97 |
98 | conf = torch.sigmoid(out[..., 4])
99 | pred_cls = torch.sigmoid(out[..., 5:])
100 |
101 | output = torch.cat((pred_boxes.view(batch_size, -1, 4),
102 | conf.view(batch_size, -1, 1),
103 | pred_cls.view(batch_size, -1, num_classes)),
104 | -1)
105 | boxs.append(output)
106 |
107 | outputx = torch.cat(boxs, 1)
108 | # NMS
109 | batch_detections = w_non_max_suppression(
110 | outputx, num_classes, conf_thres=0.4, nms_thres=0.3)
111 | if output_dir is not None:
112 | save_path = os.path.join(
113 | output_dir, f"frame_onnx_{str(i).zfill(5)}.jpg")
114 | save_output(batch_detections[0], orig_input, save_path,
115 | threshold=threshold, model_in_HW=(in_h, in_w),
116 | line_thickness=None, text_bg_alpha=0.0)
117 |
118 | elapse_time = time.time() - start_time
119 | print(f'Total Frames: {i}')
120 | print(f'Total Elapsed Time: {elapse_time:.3f} Seconds')
121 | print(f'Final Estimated FPS: {i / (elapse_time):.2f}')
122 |
123 |
124 | if __name__ == '__main__':
125 | args = parse_arguments("YoloV5 onnx demo")
126 | t1 = time.time()
127 | detect_onnx(src_path=args.input_path,
128 | media_type=args.media_type,
129 | threshold=args.threshold,
130 | official=True, # official yolov5 post-processing
131 | onnx_path=args.onnx_path,
132 | output_dir=args.output_dir,
133 | num_classes=args.num_classes)
134 |
--------------------------------------------------------------------------------
/detect_openvino.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | from functools import partial
5 |
6 | import cv2
7 | import torch
8 | import openvino as ov
9 |
10 | from utils.general import DataStreamer
11 | from utils.detector_utils import save_output, non_max_suppression, preprocess_image
12 |
13 |
14 | def parse_arguments(desc: str) -> argparse.Namespace:
15 | parser = argparse.ArgumentParser(description=desc)
16 | parser.add_argument('-i', '--input_path', dest='input_path', required=True, type=str,
17 | help='Path to Input: Video File or Image file')
18 | parser.add_argument('--model_xml', dest='model_xml', default='models/yolov5_openvino/yolov5s.xml',
19 | help='OpenVINO XML File. (default: %(default)s)')
20 | parser.add_argument('--model_bin', dest='model_bin', default='models/yolov5_openvino/yolov5s.bin',
21 | help='OpenVINO BIN File. (default: %(default)s)')
22 | parser.add_argument('-d', '--target_device', dest='target_device', default='CPU', type=str,
23 | help='Target Plugin: CPU, GPU, FPGA, MYRIAD, MULTI:CPU,GPU, HETERO:FPGA,CPU. (default: %(default)s)')
24 | parser.add_argument('-m', '--media_type', dest='media_type', default='image', type=str,
25 | choices=('image', 'video'),
26 | help='Type of Input: image, video. (default: %(default)s)')
27 | parser.add_argument('-o', '--output_dir', dest='output_dir', default='output', type=str,
28 | help='Output directory. (default: %(default)s)')
29 | parser.add_argument('-t', '--threshold', dest='threshold', default=0.6, type=float,
30 | help='Object Detection Accuracy Threshold. (default: %(default)s)')
31 |
32 | return parser.parse_args()
33 |
34 |
35 | def get_openvino_core_net_exec(model_xml_path: str, model_bin_path: str, target_device: str = "CPU"):
36 | # load openvino Core object
37 | core = ov.Core()
38 |
39 | # load CPU extensions if available
40 | lib_ext_path = '/opt/intel/openvino/inference_engine/lib/intel64/libcpu_extension.so'
41 | if 'CPU' in target_device and os.path.exists(lib_ext_path):
42 | print(f"Loading CPU extensions from {lib_ext_path}")
43 | core.add_extension(lib_ext_path)
44 |
45 | # load openVINO network
46 | model = core.read_model(
47 | model=model_xml_path, weights=model_bin_path)
48 |
49 | # create executable network
50 | compiled_model = core.compile_model(
51 | model=model, device_name=target_device)
52 |
53 | return core, model, compiled_model
54 |
55 |
56 | def inference(args: argparse.Namespace) -> None:
57 | """Run Object Detection Application
58 |
59 | args: ArgumentParser Namespace
60 | """
61 | print(f"Running Inference for {args.media_type}: {args.input_path}")
62 | # Load model and executable
63 | core, model, compiled_model = get_openvino_core_net_exec(
64 | args.model_xml, args.model_bin, args.target_device)
65 |
66 | # Get Input, Output Information
67 | print("Available Devices: ", core.available_devices)
68 | print("Input layer names ", model.inputs[0].names)
69 | print("Output layer names ", model.outputs[0].names)
70 | print("Input Layer: ", model.inputs)
71 | print("Output Layer: ", model.outputs)
72 |
73 | output_layer_ir = compiled_model.output(model.outputs[0].names.pop())
74 | if args.output_dir is not None:
75 | os.makedirs(args.output_dir, exist_ok=True)
76 |
77 | start_time = time.time()
78 | _, C, H, W = model.inputs[0].shape
79 | preprocess_func = partial(preprocess_image, in_size=(W, H))
80 | data_stream = DataStreamer(
81 | args.input_path, args.media_type, preprocess_func)
82 |
83 | for i, (orig_input, model_input) in enumerate(data_stream, start=1):
84 | # Inference
85 | start = time.time()
86 | # results = compiled_model.infer(inputs={InputLayer: model_input})
87 | results = compiled_model([model_input])
88 | end = time.time()
89 |
90 | inf_time = end - start
91 | print('Inference Time: {} Seconds Single Image'.format(inf_time))
92 | fps = 1. / (end - start)
93 | print('Estimated Inference FPS: {} FPS Single Image'.format(fps))
94 |
95 | # Write fps and inference time info on image
96 | text = 'FPS: {}, INF: {}'.format(round(fps, 2), round(inf_time, 2))
97 | cv2.putText(orig_input, text, (0, 20), cv2.FONT_HERSHEY_COMPLEX,
98 | 0.6, (0, 125, 255), 1)
99 |
100 | # Print Bounding Boxes on Image
101 | detections = results[output_layer_ir]
102 | detections = torch.from_numpy(detections)
103 | detections = non_max_suppression(
104 | detections, conf_thres=0.4, iou_thres=0.5, agnostic=False)
105 |
106 | save_path = os.path.join(
107 | args.output_dir, f"frame_openvino_{str(i).zfill(5)}.jpg")
108 | save_output(detections[0], orig_input, save_path,
109 | threshold=args.threshold, model_in_HW=(H, W),
110 | line_thickness=None, text_bg_alpha=0.0)
111 |
112 | elapse_time = time.time() - start_time
113 | print(f'Total Frames: {i}')
114 | print(f'Total Elapsed Time: {elapse_time:.3f} Seconds')
115 | print(f'Final Estimated FPS: {i / (elapse_time):.2f}')
116 |
117 |
118 | if __name__ == '__main__':
119 | parsed_args = parse_arguments(
120 | desc="Basic OpenVINO Example for person/object detection")
121 | inference(parsed_args)
122 |
--------------------------------------------------------------------------------
/inf_requirements.txt:
--------------------------------------------------------------------------------
1 | onnxruntime==1.18.1
2 | opencv-python==4.10.0.84
3 | openvino==2024.3.0
4 | torch==2.4.0
5 | torchvision==0.19.0
6 |
--------------------------------------------------------------------------------
/media/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamSamhuns/yolov5_export_cpu/200f209ce6440a8851ea8598c0ae9e915890e3bb/media/1.jpg
--------------------------------------------------------------------------------
/media/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamSamhuns/yolov5_export_cpu/200f209ce6440a8851ea8598c0ae9e915890e3bb/media/2.jpg
--------------------------------------------------------------------------------
/media/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamSamhuns/yolov5_export_cpu/200f209ce6440a8851ea8598c0ae9e915890e3bb/media/3.jpg
--------------------------------------------------------------------------------
/media/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamSamhuns/yolov5_export_cpu/200f209ce6440a8851ea8598c0ae9e915890e3bb/media/4.jpg
--------------------------------------------------------------------------------
/media/soccer1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamSamhuns/yolov5_export_cpu/200f209ce6440a8851ea8598c0ae9e915890e3bb/media/soccer1.mp4
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SamSamhuns/yolov5_export_cpu/200f209ce6440a8851ea8598c0ae9e915890e3bb/models/.gitkeep
--------------------------------------------------------------------------------
/utils/detector_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import time
3 | import torch
4 | import torchvision
5 | import numpy as np
6 | from typing import Tuple, Optional, Union
7 |
8 | from utils.general import CLASS_LABELS
9 |
10 |
11 | def preprocess_image(
12 | cv2_img: np.ndarray,
13 | in_size: Tuple[int, int] = (640, 640)
14 | ) -> np.ndarray:
15 | """preprocesses cv2 image and returns a norm np.ndarray
16 |
17 | cv2_img = cv2 image
18 | in_size: in_width, in_height
19 | """
20 | resized = pad_resize_image(cv2_img, in_size)
21 | img_in = np.transpose(resized, (2, 0, 1)).astype(np.float32) # HWC -> CHW
22 | img_in /= 255.0
23 | return img_in
24 |
25 |
26 | def save_output(
27 | detections,
28 | image_src: np.ndarray,
29 | save_path: str,
30 | threshold: float,
31 | model_in_HW: Tuple[int, int],
32 | line_thickness: Optional[int] = None,
33 | text_bg_alpha: float = 0.0
34 | ) -> None:
35 | image_src = cv2.cvtColor(image_src, cv2.COLOR_RGB2BGR)
36 | labels = detections[..., -1].numpy()
37 | boxs = detections[..., :4].numpy()
38 | confs = detections[..., 4].numpy()
39 |
40 | if isinstance(image_src, str):
41 | image_src = cv2.imread(image_src)
42 | elif isinstance(image_src, np.ndarray):
43 | image_src = image_src
44 |
45 | mh, mw = model_in_HW
46 | h, w = image_src.shape[:2]
47 | boxs[:, :] = scale_coords((mh, mw), boxs[:, :], (h, w)).round()
48 | tl = line_thickness or round(0.002 * (w + h) / 2) + 1
49 | for i, box in enumerate(boxs):
50 | if confs[i] >= threshold:
51 | x1, y1, x2, y2 = map(int, box)
52 | np.random.seed(int(labels[i]) + 2020)
53 | color = [np.random.randint(0, 255), 0, np.random.randint(0, 255)]
54 | cv2.rectangle(image_src, (x1, y1), (x2, y2), color, thickness=max(
55 | int((w + h) / 600), 1), lineType=cv2.LINE_AA)
56 | label = '%s %.2f' % (CLASS_LABELS[int(labels[i])], confs[i])
57 | t_size = cv2.getTextSize(
58 | label, 0, fontScale=tl / 3, thickness=1)[0]
59 | c2 = x1 + t_size[0] + 3, y1 - t_size[1] - 5
60 | if text_bg_alpha == 0.0:
61 | cv2.rectangle(image_src, (x1 - 1, y1), c2,
62 | color, cv2.FILLED, cv2.LINE_AA)
63 | else:
64 | # Transparent text background
65 | alphaReserve = text_bg_alpha # 0: opaque 1: transparent
66 | BChannel, GChannel, RChannel = color
67 | xMin, yMin = int(x1 - 1), int(y1 - t_size[1] - 3)
68 | xMax, yMax = int(x1 + t_size[0]), int(y1)
69 | image_src[yMin:yMax, xMin:xMax, 0] = image_src[yMin:yMax,
70 | xMin:xMax, 0] * alphaReserve + BChannel * (1 - alphaReserve)
71 | image_src[yMin:yMax, xMin:xMax, 1] = image_src[yMin:yMax,
72 | xMin:xMax, 1] * alphaReserve + GChannel * (1 - alphaReserve)
73 | image_src[yMin:yMax, xMin:xMax, 2] = image_src[yMin:yMax,
74 | xMin:xMax, 2] * alphaReserve + RChannel * (1 - alphaReserve)
75 | cv2.putText(image_src, label, (x1 + 3, y1 - 4), 0, tl / 3, [255, 255, 255],
76 | thickness=1, lineType=cv2.LINE_AA)
77 | print("bbox:", box, "conf:", confs[i],
78 | "class:", CLASS_LABELS[int(labels[i])])
79 | cv2.imwrite(save_path, image_src)
80 |
81 |
82 | def w_bbox_iou(
83 | box1: torch.Tensor,
84 | box2: torch.Tensor,
85 | x1y1x2y2: bool = True
86 | ) -> torch.Tensor:
87 | """Calculate IOU
88 |
89 | """
90 | if not x1y1x2y2:
91 | b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
92 | b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
93 | b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
94 | b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
95 | else:
96 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,
97 | 0], box1[:, 1], box1[:, 2], box1[:, 3]
98 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,
99 | 0], box2[:, 1], box2[:, 2], box2[:, 3]
100 |
101 | inter_rect_x1 = torch.max(b1_x1, b2_x1)
102 | inter_rect_y1 = torch.max(b1_y1, b2_y1)
103 | inter_rect_x2 = torch.min(b1_x2, b2_x2)
104 | inter_rect_y2 = torch.min(b1_y2, b2_y2)
105 |
106 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * \
107 | torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)
108 |
109 | b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
110 | b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
111 |
112 | iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
113 |
114 | return iou
115 |
116 |
117 | def w_non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
118 | # Find the upper left and lower right corners
119 | # box_corner = prediction.new(prediction.shape)
120 | box_corner = torch.FloatTensor(prediction.shape)
121 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
122 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
123 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
124 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
125 | prediction[:, :, :4] = box_corner[:, :, :4]
126 |
127 | output = [None for _ in range(len(prediction))]
128 | for image_i, image_pred in enumerate(prediction):
129 | # Use confidence for the first round of screening
130 | conf_mask = (image_pred[:, 4] >= conf_thres).squeeze()
131 | image_pred = image_pred[conf_mask]
132 |
133 | if not image_pred.size(0):
134 | continue
135 |
136 | # Obtain type and its confidence
137 | class_conf, class_pred = torch.max(
138 | image_pred[:, 5:5 + num_classes], 1, keepdim=True)
139 |
140 | # content obtained is (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
141 | detections = torch.cat(
142 | (image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
143 |
144 | # Type of acquisition
145 | unique_labels = detections[:, -1].cpu().unique()
146 |
147 | if prediction.is_cuda:
148 | unique_labels = unique_labels.cuda()
149 |
150 | for c in unique_labels:
151 | # Obtain all prediction results after a certain type of preliminary screening
152 | detections_class = detections[detections[:, -1] == c]
153 | # Sort according to the confidence of the existence of the object
154 | _, conf_sort_index = torch.sort(
155 | detections_class[:, 4], descending=True)
156 | detections_class = detections_class[conf_sort_index]
157 | # Non-maximum suppression
158 | max_detections = []
159 | while detections_class.size(0):
160 | # Take out this category with the highest confidence, judge step by step, and
161 | # judge whether the degree of coincidence is greater than nms_thres, and if so, remove it
162 | max_detections.append(detections_class[0].unsqueeze(0))
163 | if len(detections_class) == 1:
164 | break
165 | ious = w_bbox_iou(max_detections[-1], detections_class[1:])
166 | detections_class = detections_class[1:][ious < nms_thres]
167 | # Stacked
168 | max_detections = torch.cat(max_detections).data
169 | # Add max detections to outputs
170 | output[image_i] = max_detections if output[image_i] is None else torch.cat(
171 | (output[image_i], max_detections))
172 |
173 | return output
174 |
175 |
176 | def box_iou(
177 | box1: torch.Tensor,
178 | box2: torch.Tensor
179 | ) -> torch.Tensor:
180 | """Return intersection-over-union (Jaccard index) of boxes.
181 |
182 | # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
183 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
184 | Arguments:
185 | box1 (Tensor[N, 4])
186 | box2 (Tensor[M, 4])
187 | Returns:
188 | iou (Tensor[N, M]): the NxM matrix containing the pairwise
189 | IoU values for every element in boxes1 and boxes2
190 | """
191 |
192 | def box_area(box):
193 | # box = 4xn
194 | return (box[2] - box[0]) * (box[3] - box[1])
195 |
196 | area1 = box_area(box1.T)
197 | area2 = box_area(box2.T)
198 |
199 | # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
200 | inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
201 | torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
202 | # iou = inter / (area1 + area2 - inter)
203 | return inter / (area1[:, None] + area2 - inter)
204 |
205 |
206 | def non_max_suppression(
207 | prediction: torch.Tensor,
208 | conf_thres: float = 0.25,
209 | iou_thres: float = 0.45,
210 | classes: Optional[torch.Tensor] = None,
211 | agnostic: bool = False,
212 | multi_label: bool = False,
213 | labels: Tuple[str] = ()
214 | ) -> torch.Tensor:
215 | """Runs Non-Maximum Suppression (NMS) on inference results
216 |
217 | Returns:
218 | list of detections, on (n,6) tensor per image [xyxy, conf, cls]
219 | """
220 | nc = prediction.shape[2] - 5 # number of classes
221 | xc = prediction[..., 4] > conf_thres # candidates
222 |
223 | # Checks
224 | assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
225 | assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
226 |
227 | # Settings
228 | # (pixels) maximum and minimum box width and height
229 | max_wh = 4096 # min_wh = 2
230 | max_det = 300 # maximum number of detections per image
231 | max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
232 | time_limit = 10.0 # seconds to quit after
233 | redundant = True # require redundant detections
234 | multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
235 | merge = False # use merge-NMS
236 |
237 | t = time.time()
238 | output = [torch.zeros((0, 6), device=prediction.device)
239 | ] * prediction.shape[0]
240 | for xi, x in enumerate(prediction): # image index, image inference
241 | # Apply constraints
242 | # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
243 | x = x[xc[xi]] # confidence
244 |
245 | # Cat apriori labels if autolabelling
246 | if labels and len(labels[xi]):
247 | lxi = labels[xi]
248 | v = torch.zeros((len(lxi), nc + 5), device=x.device)
249 | v[:, :4] = lxi[:, 1:5] # box
250 | v[:, 4] = 1.0 # conf
251 | v[range(len(lxi)), lxi[:, 0].long() + 5] = 1.0 # cls
252 | x = torch.cat((x, v), 0)
253 |
254 | # If none remain process next image
255 | if not x.shape[0]:
256 | continue
257 |
258 | # Compute conf
259 | x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
260 |
261 | # Box (center x, center y, width, height) to (x1, y1, x2, y2)
262 | box = xywh2xyxy(x[:, :4])
263 |
264 | # Detections matrix nx6 (xyxy, conf, cls)
265 | if multi_label:
266 | i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
267 | x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
268 | else: # best class only
269 | conf, j = x[:, 5:].max(1, keepdim=True)
270 | x = torch.cat((box, conf, j.float()), 1)[
271 | conf.view(-1) > conf_thres]
272 |
273 | # Filter by class
274 | if classes is not None:
275 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
276 |
277 | # Apply finite constraint
278 | # if not torch.isfinite(x).all():
279 | # x = x[torch.isfinite(x).all(1)]
280 |
281 | # Check shape
282 | n = x.shape[0] # number of boxes
283 | if not n: # no boxes
284 | continue
285 | elif n > max_nms: # excess boxes
286 | # sort by confidence
287 | x = x[x[:, 4].argsort(descending=True)[:max_nms]]
288 |
289 | # Batched NMS
290 | c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
291 | # boxes (offset by class), scores
292 | boxes, scores = x[:, :4] + c, x[:, 4]
293 | i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
294 | if i.shape[0] > max_det: # limit detections
295 | i = i[:max_det]
296 | if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
297 | # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
298 | iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
299 | weights = iou * scores[None] # box weights
300 | x[i, :4] = torch.mm(weights, x[:, :4]).float(
301 | ) / weights.sum(1, keepdim=True) # merged boxes
302 | if redundant:
303 | i = i[iou.sum(1) > 1] # require redundancy
304 |
305 | output[xi] = x[i]
306 | if (time.time() - t) > time_limit:
307 | print(f'WARNING: NMS time limit {time_limit}s exceeded')
308 | break # time limit exceeded
309 |
310 | return output
311 |
312 |
313 | def pad_resize_image(
314 | cv2_img: np.ndarray,
315 | new_size: Tuple[int, int] = (640, 480),
316 | color: Tuple[int, int, int] = (125, 125, 125)
317 | ) -> np.ndarray:
318 | """Resize and pad image with color if necessary, maintaining orig scale
319 |
320 | args:
321 | cv2_img: numpy.ndarray = cv2 image
322 | new_size: tuple(int, int) = (width, height)
323 | color: tuple(int, int, int) = (B, G, R)
324 | """
325 | in_h, in_w = cv2_img.shape[:2]
326 | new_w, new_h = new_size
327 | # rescale down
328 | scale = min(new_w / in_w, new_h / in_h)
330 | # get new scaled width and height
330 | scale_new_w, scale_new_h = int(in_w * scale), int(in_h * scale)
331 | resized_img = cv2.resize(cv2_img, (scale_new_w, scale_new_h))
332 | # calculate deltas for padding
333 | d_w = max(new_w - scale_new_w, 0)
334 | d_h = max(new_h - scale_new_h, 0)
335 | # center image with padding on top/bottom or left/right
336 | top, bottom = d_h // 2, d_h - (d_h // 2)
337 | left, right = d_w // 2, d_w - (d_w // 2)
338 | pad_resized_img = cv2.copyMakeBorder(resized_img,
339 | top, bottom, left, right,
340 | cv2.BORDER_CONSTANT,
341 | value=color)
342 | return pad_resized_img
343 |
344 |
345 | def clip_coords(
346 | boxes: Union[torch.Tensor, np.ndarray],
347 | img_shape: Tuple[int, int]
348 | ) -> None:
349 | # Clip xyxy bounding boxes to image shape (height, width)
350 | if isinstance(boxes, torch.Tensor):
351 | boxes[:, 0].clamp_(0, img_shape[1]) # x1
352 | boxes[:, 1].clamp_(0, img_shape[0]) # y1
353 | boxes[:, 2].clamp_(0, img_shape[1]) # x2
354 | boxes[:, 3].clamp_(0, img_shape[0]) # y2
355 | else: # np.array
356 | boxes[:, 0].clip(0, img_shape[1], out=boxes[:, 0]) # x1
357 | boxes[:, 1].clip(0, img_shape[0], out=boxes[:, 1]) # y1
358 | boxes[:, 2].clip(0, img_shape[1], out=boxes[:, 2]) # x2
359 | boxes[:, 3].clip(0, img_shape[0], out=boxes[:, 3]) # y2
360 |
361 |
362 | def scale_coords(img1_shape: Tuple[int, int], coords: np.ndarray, img0_shape: Tuple[int, int], ratio_pad=None):
363 | # Rescale coords (xyxy) from img1_shape to img0_shape
364 | if ratio_pad is None: # calculate from img0_shape
365 | gain = min(img1_shape[0] / img0_shape[0],
366 | img1_shape[1] / img0_shape[1])
367 | pad = (img1_shape[1] - img0_shape[1] * gain) / \
368 | 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
369 | else:
370 | gain = ratio_pad[0][0]
371 | pad = ratio_pad[1]
372 |
373 | coords[:, [0, 2]] -= pad[0] # x padding
374 | coords[:, [1, 3]] -= pad[1] # y padding
375 | coords[:, :4] /= gain
376 | clip_coords(coords, img0_shape)
377 | return coords
378 |
379 |
380 | def xyxy2xywh(x: torch.Tensor) -> torch.Tensor:
381 | # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
382 | y = torch.zeros_like(x) if isinstance(
383 | x, torch.Tensor) else np.zeros_like(x)
384 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
385 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
386 | y[:, 2] = x[:, 2] - x[:, 0] # width
387 | y[:, 3] = x[:, 3] - x[:, 1] # height
388 | return y
389 |
390 |
391 | def xywh2xyxy(x: torch.Tensor) -> torch.Tensor:
392 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
393 | y = torch.zeros_like(x) if isinstance(
394 | x, torch.Tensor) else np.zeros_like(x)
395 | y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
396 | y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
397 | y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
398 | y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
399 | return y
400 |
--------------------------------------------------------------------------------
/utils/general.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import argparse
3 | import os.path as osp
4 | from typing import Callable
5 |
6 | import cv2
7 | import numpy as np
8 |
9 |
10 | # mscoco class names
11 | CLASS_LABELS = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
12 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
13 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
14 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
15 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
16 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
17 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
18 | 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
19 | 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
20 |
21 |
22 | def parse_arguments(desc):
23 | parser = argparse.ArgumentParser(description=desc)
24 | parser.add_argument('-i', '--input_path', dest='input_path', required=True, type=str,
25 | help='Path to Input: Video File or Image file')
26 | parser.add_argument('-m', '--media_type', dest='media_type', default='image', type=str,
27 | choices=('image', 'video'),
28 | help='Type of Input: image, video. (default: %(default)s)')
29 | parser.add_argument('-t', '--threshold', dest='threshold', default=0.6, type=float,
30 | help='Detection Threshold. (default: %(default)s)')
31 | parser.add_argument('--ox', '--onnx_path', dest='onnx_path', default="models/yolov5s.onnx", type=str,
32 | help='Path to ONNX model. (default: %(default)s)')
33 | parser.add_argument('-o', '--output_dir', dest='output_dir', default='output', type=str,
34 | help='Output directory. (default: %(default)s)')
35 | parser.add_argument('-c', '--num_classes', dest='num_classes', default=80, type=int,
36 | help='Num of classes. (default: %(default)s)')
37 |
38 | return parser.parse_args()
39 |
40 |
41 | class DataStreamer(object):
42 |
43 | """Iterable DataStreamer class for generating numpy arr images
44 | Generates orig image and pre-processed image
45 |
46 | For loading data into detectors
47 | """
48 |
49 | def __init__(self, src_path: str, media_type: str = "image", preprocess_func: Callable = None):
50 | """Init DataStreamer Obj
51 |
52 | src_path : str
53 | path to a single image/video or path to directory containing images
54 | media_type : str
55 | inference media_type "image" or "video"
56 | preprocess_func : Callable function
57 | preprocessing function applied to cv2/numpy images
58 | """
59 | if media_type not in {'video', 'image'}:
60 | raise NotImplementedError(
61 | f"{media_type} not supported in streamer. Use video or image")
62 | self.img_path_list = []
63 | self.vid_path_list = []
64 | self.idx = 0
65 | self.media_type = media_type
66 | self.preprocess_func = preprocess_func
67 |
68 | if media_type == "video":
69 | if osp.isfile(src_path):
70 | self.vid_path_list.append(src_path)
71 | self.vcap = cv2.VideoCapture(src_path)
72 | elif osp.isdir(src_path):
73 | raise NotImplementedError(
74 | f"dir iteration supported for video media_type. {src_path} must be a video file")
75 | elif media_type == "image":
76 | if osp.isfile(src_path):
77 | self.img_path_list.append(src_path)
78 | elif osp.isdir(src_path):
79 | img_exts = ['*.png', '*.PNG', '*.jpg', '*.jpeg']
80 | for ext in img_exts:
81 | self.img_path_list.extend(
82 | glob.glob(osp.join(src_path, ext)))
83 |
84 | def __iter__(self):
85 | return self
86 |
87 | def __next__(self):
88 | """Get next image or frame as numpy array
89 |
90 | """
91 | orig_img = None
92 | if self.media_type == 'image':
93 | if self.idx < len(self.img_path_list):
94 | orig_img = cv2.imread(self.img_path_list[self.idx])
95 | orig_img = orig_img[..., ::-1]
96 | self.idx += 1
97 | elif self.media_type == 'video':
98 | if self.idx < len(self.vid_path_list):
99 | ret, frame = self.vcap.read()
100 | if ret:
101 | orig_img = frame[..., ::-1]
102 | else:
103 | self.idx += 1
104 | if orig_img is not None:
105 | proc_img = None
106 | if self.preprocess_func is not None:
107 | proc_img = self.preprocess_func(orig_img)
108 | proc_img = np.expand_dims(proc_img, axis=0)
109 | return np.array(orig_img), proc_img
110 | raise StopIteration
111 |
--------------------------------------------------------------------------------