├── .gitignore ├── README.md ├── client.js ├── coco.names ├── index.html ├── main.py └── yolov7-tiny_480x640.onnx /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # real-time-object-detection-with-webrtc-and-yolo 2 | 3 | A solution code for the real time object detection with WebRTC and YOLO article - https://softwarescalability.com/editorial/real-time-object-detection-with-webrtc-and-yolo 4 | -------------------------------------------------------------------------------- /client.js: -------------------------------------------------------------------------------- 1 | var pc = null; 2 | 3 | function negotiate() { 4 | pc.addTransceiver('video', {direction: 'recvonly'}); 5 | pc.addTransceiver('audio', {direction: 'recvonly'}); 6 | return pc.createOffer().then(function(offer) { 7 | return pc.setLocalDescription(offer); 8 | }).then(function() { 9 | // wait for ICE gathering to complete 10 | return new Promise(function(resolve) { 11 | if (pc.iceGatheringState === 'complete') { 12 | resolve(); 13 | } else { 14 | function checkState() { 15 | if (pc.iceGatheringState === 'complete') { 16 | pc.removeEventListener('icegatheringstatechange', checkState); 17 | resolve(); 18 | } 19 | } 20 | pc.addEventListener('icegatheringstatechange', checkState); 21 | } 22 | }); 23 | }).then(function() { 24 | var offer = pc.localDescription; 25 | return fetch('/offer', { 26 | body: JSON.stringify({ 27 | sdp: offer.sdp, 28 | type: offer.type, 29 | }), 30 | headers: { 31 | 'Content-Type': 'application/json' 32 | }, 33 | method: 'POST' 34 | }); 35 | }).then(function(response) { 36 | return response.json(); 37 | }).then(function(answer) { 38 | return pc.setRemoteDescription(answer); 39 | }).catch(function(e) { 40 | alert(e); 41 | }); 42 | } 43 | 44 | function start() { 45 | var config = { 46 | sdpSemantics: 'unified-plan' 47 | }; 48 | 49 | if (document.getElementById('use-stun').checked) { 50 | config.iceServers = [{urls: ['stun:stun.l.google.com:19302']}]; 51 | } 52 | 53 | pc = new RTCPeerConnection(config); 54 | 55 | // connect audio / video 56 | pc.addEventListener('track', function(evt) { 57 | if (evt.track.kind == 'video') { 58 | document.getElementById('video').srcObject = evt.streams[0]; 59 | } else { 60 | document.getElementById('audio').srcObject = evt.streams[0]; 61 | } 62 | }); 63 | 64 | document.getElementById('start').style.display = 'none'; 65 | negotiate(); 66 | document.getElementById('stop').style.display = 'inline-block'; 67 | } 68 | 69 | function stop() { 70 | document.getElementById('stop').style.display = 'none'; 71 | 72 | // close peer connection 73 | setTimeout(function() { 74 | pc.close(); 75 | }, 500); 76 | } -------------------------------------------------------------------------------- /coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | WebRTC webcam 6 | 23 | 24 | 25 | 26 |
27 | 28 | 29 |
30 | 31 | 32 | 33 |
34 |

Media

35 | 36 | 37 | 38 |
39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import json 4 | import logging 5 | import os 6 | import platform 7 | 8 | from aiohttp import web 9 | import cv2 10 | import numpy as np 11 | from av import VideoFrame 12 | from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack 13 | from aiortc.contrib.media import MediaPlayer, MediaRelay 14 | 15 | ROOT = os.path.dirname(__file__) 16 | 17 | relay = None 18 | webcam = None 19 | 20 | model = "yolov7-tiny_480x640.onnx" 21 | 22 | class YOLOVideoStreamTrack(VideoStreamTrack): 23 | """ 24 | A video track thats returns camera track with annotated detected objects. 25 | """ 26 | def __init__(self, conf_thres=0.7, iou_thres=0.5): 27 | super().__init__() # don't forget this! 28 | self.conf_threshold = conf_thres 29 | self.iou_threshold = iou_thres 30 | 31 | video = cv2.VideoCapture(0) 32 | self.video = video 33 | 34 | # Initialize model 35 | self.net = cv2.dnn.readNet(model) 36 | input_shape = os.path.splitext(os.path.basename(model))[0].split('_')[-1].split('x') 37 | self.input_height = int(input_shape[0]) 38 | self.input_width = int(input_shape[1]) 39 | 40 | self.class_names = list(map(lambda x: x.strip(), open('coco.names', 'r').readlines())) 41 | self.colors = np.random.default_rng(3).uniform(0, 255, size=(len(self.class_names), 3)) 42 | 43 | self.output_names = self.net.getUnconnectedOutLayersNames() 44 | 45 | async def recv(self): 46 | pts, time_base = await self.next_timestamp() 47 | _, frame = self.video.read() 48 | boxes, scores, class_ids = self.detect(frame) 49 | frame = self.draw_detections(frame, boxes, scores, class_ids) 50 | frame = VideoFrame.from_ndarray(frame, format="bgr24") 51 | frame.pts = pts 52 | frame.time_base = time_base 53 | return frame 54 | 55 | 56 | def prepare_input(self, image): 57 | self.img_height, self.img_width = image.shape[:2] 58 | 59 | input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 60 | 61 | # Resize input image 62 | input_img = cv2.resize(input_img, (self.input_width, self.input_height)) 63 | 64 | # Scale input pixel values to 0 to 1 65 | return input_img 66 | 67 | def detect(self, frame): 68 | input_img = self.prepare_input(frame) 69 | blob = cv2.dnn.blobFromImage(input_img, 1 / 255.0) 70 | # Perform inference on the image 71 | self.net.setInput(blob) 72 | # Runs the forward pass to get output of the output layers 73 | outputs = self.net.forward(self.output_names) 74 | 75 | boxes, scores, class_ids = self.process_output(outputs) 76 | return boxes, scores, class_ids 77 | 78 | def process_output(self, output): 79 | predictions = np.squeeze(output[0]) 80 | 81 | # Filter out object confidence scores below threshold 82 | obj_conf = predictions[:, 4] 83 | predictions = predictions[obj_conf > self.conf_threshold] 84 | obj_conf = obj_conf[obj_conf > self.conf_threshold] 85 | 86 | # Multiply class confidence with bounding box confidence 87 | predictions[:, 5:] *= obj_conf[:, np.newaxis] 88 | 89 | # Get the scores 90 | scores = np.max(predictions[:, 5:], axis=1) 91 | 92 | # Filter out the objects with a low score 93 | valid_scores = scores > self.conf_threshold 94 | predictions = predictions[valid_scores] 95 | scores = scores[valid_scores] 96 | 97 | # Get the class with the highest confidence 98 | class_ids = np.argmax(predictions[:, 5:], axis=1) 99 | 100 | # Get bounding boxes for each object 101 | boxes = self.extract_boxes(predictions) 102 | 103 | # Apply non-maxima suppression to suppress weak, overlapping bounding boxes 104 | indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), self.conf_threshold, self.iou_threshold) 105 | if len(indices) > 0: 106 | indices = indices.flatten() 107 | 108 | return boxes[indices], scores[indices], class_ids[indices] 109 | 110 | def rescale_boxes(self, boxes): 111 | input_shape = np.array([self.input_width, self.input_height, self.input_width, self.input_height]) 112 | boxes = np.divide(boxes, input_shape, dtype=np.float32) 113 | boxes *= np.array([self.img_width, self.img_height, self.img_width, self.img_height]) 114 | return boxes 115 | 116 | def extract_boxes(self, predictions): 117 | # Extract boxes from predictions 118 | boxes = predictions[:, :4] 119 | 120 | # Scale boxes to original image dimensions 121 | boxes = self.rescale_boxes(boxes) 122 | 123 | # Convert boxes to xywh format 124 | boxes_ = np.copy(boxes) 125 | boxes_[..., 0] = boxes[..., 0] - boxes[..., 2] * 0.5 126 | boxes_[..., 1] = boxes[..., 1] - boxes[..., 3] * 0.5 127 | return boxes_ 128 | 129 | def draw_detections(self, frame, boxes, scores, class_ids): 130 | for box, score, class_id in zip(boxes, scores, class_ids): 131 | x, y, w, h = box.astype(int) 132 | color = self.colors[class_id] 133 | 134 | # Draw rectangle 135 | cv2.rectangle(frame, (x, y), (x+w, y+h), color, thickness=2) 136 | label = self.class_names[class_id] 137 | label = f'{label} {int(score * 100)}%' 138 | cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) 139 | cv2.putText(frame, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, thickness=2) 140 | return frame 141 | 142 | 143 | def create_local_tracks(): 144 | global relay, webcam 145 | options = {"framerate": "30", "video_size": "640x480"} 146 | if relay is None: 147 | if platform.system() == "Darwin": 148 | webcam = MediaPlayer( 149 | "default:none", format="avfoundation", options=options 150 | ) 151 | elif platform.system() == "Windows": 152 | webcam = MediaPlayer( 153 | "video=Integrated Camera", format="dshow", options=options 154 | ) 155 | else: 156 | webcam = MediaPlayer("/dev/video0", format="v4l2", options=options) 157 | relay = MediaRelay() 158 | return relay.subscribe(webcam.video) 159 | 160 | 161 | async def index(request): 162 | content = open(os.path.join(ROOT, "index.html"), "r").read() 163 | return web.Response(content_type="text/html", text=content) 164 | 165 | 166 | async def javascript(request): 167 | content = open(os.path.join(ROOT, "client.js"), "r").read() 168 | return web.Response(content_type="application/javascript", text=content) 169 | 170 | 171 | async def offer(request): 172 | params = await request.json() 173 | offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) 174 | 175 | pc = RTCPeerConnection() 176 | pcs.add(pc) 177 | 178 | @pc.on("connectionstatechange") 179 | async def on_connectionstatechange(): 180 | print("Connection state is %s" % pc.connectionState) 181 | if pc.connectionState == "failed": 182 | await pc.close() 183 | pcs.discard(pc) 184 | 185 | # open media source 186 | video = create_local_tracks() 187 | if video: 188 | pc.addTrack(YOLOVideoStreamTrack()) 189 | 190 | await pc.setRemoteDescription(offer) 191 | 192 | answer = await pc.createAnswer() 193 | await pc.setLocalDescription(answer) 194 | 195 | return web.Response( 196 | content_type="application/json", 197 | text=json.dumps( 198 | {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} 199 | ), 200 | ) 201 | 202 | 203 | pcs = set() 204 | 205 | async def on_shutdown(app): 206 | # close peer connections 207 | coros = [pc.close() for pc in pcs] 208 | await asyncio.gather(*coros) 209 | pcs.clear() 210 | 211 | 212 | if __name__ == "__main__": 213 | parser = argparse.ArgumentParser(description="WebRTC webcam demo") 214 | parser.add_argument( 215 | "--host", default="0.0.0.0", help="Host for HTTP server (default: 0.0.0.0)" 216 | ) 217 | parser.add_argument( 218 | "--port", type=int, default=8080, help="Port for HTTP server (default: 8080)" 219 | ) 220 | 221 | args = parser.parse_args() 222 | logging.basicConfig(level=logging.INFO) 223 | 224 | app = web.Application() 225 | app.on_shutdown.append(on_shutdown) 226 | app.router.add_get("/", index) 227 | app.router.add_get("/client.js", javascript) 228 | app.router.add_post("/offer", offer) 229 | web.run_app(app, host=args.host, port=args.port) -------------------------------------------------------------------------------- /yolov7-tiny_480x640.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genert/real-time-object-detection-with-webrtc-and-yolo/745ea481d0c866ecd884edd9b53b6be15d1307c1/yolov7-tiny_480x640.onnx --------------------------------------------------------------------------------