├── .gitignore
├── README.md
├── client.js
├── coco.names
├── index.html
├── main.py
└── yolov7-tiny_480x640.onnx
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# real-time-object-detection-with-webrtc-and-yolo

Solution code for the article [Real-Time Object Detection with WebRTC and YOLO](https://softwarescalability.com/editorial/real-time-object-detection-with-webrtc-and-yolo).
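
## Running

A minimal way to run the demo locally. The `pip` line below uses the usual PyPI names for the modules imported in `main.py`; adjust to your environment:

```
pip install aiohttp aiortc av opencv-python numpy
python main.py --port 8080
```

Then open http://localhost:8080 in a browser and click *Start*.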
--------------------------------------------------------------------------------
/client.js:
--------------------------------------------------------------------------------
var pc = null;

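// Create an SDP offer, wait for ICE gathering to complete, POST the offer to
// the server's /offer endpoint, and apply the SDP answer it returns.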
function negotiate() {
    pc.addTransceiver('video', {direction: 'recvonly'});
    pc.addTransceiver('audio', {direction: 'recvonly'});
    return pc.createOffer().then(function(offer) {
        return pc.setLocalDescription(offer);
    }).then(function() {
        // wait for ICE gathering to complete
        return new Promise(function(resolve) {
            if (pc.iceGatheringState === 'complete') {
                resolve();
            } else {
                function checkState() {
                    if (pc.iceGatheringState === 'complete') {
                        pc.removeEventListener('icegatheringstatechange', checkState);
                        resolve();
                    }
                }
                pc.addEventListener('icegatheringstatechange', checkState);
            }
        });
    }).then(function() {
        var offer = pc.localDescription;
        return fetch('/offer', {
            body: JSON.stringify({
                sdp: offer.sdp,
                type: offer.type,
            }),
            headers: {
                'Content-Type': 'application/json'
            },
            method: 'POST'
        });
    }).then(function(response) {
        return response.json();
    }).then(function(answer) {
        return pc.setRemoteDescription(answer);
    }).catch(function(e) {
        alert(e);
    });
}

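// Build the RTCPeerConnection (optionally with a public STUN server), wire up
// the incoming media tracks, and kick off negotiation.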
function start() {
    var config = {
        sdpSemantics: 'unified-plan'
    };

    if (document.getElementById('use-stun').checked) {
        config.iceServers = [{urls: ['stun:stun.l.google.com:19302']}];
    }

    pc = new RTCPeerConnection(config);

    // connect audio / video
    pc.addEventListener('track', function(evt) {
        if (evt.track.kind === 'video') {
            document.getElementById('video').srcObject = evt.streams[0];
        } else {
            document.getElementById('audio').srcObject = evt.streams[0];
        }
    });

    document.getElementById('start').style.display = 'none';
    negotiate();
    document.getElementById('stop').style.display = 'inline-block';
}

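// Hide the stop button and close the peer connection after a short delay so
// any in-flight signaling can settle first.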
function stop() {
    document.getElementById('stop').style.display = 'none';

    // close peer connection
    setTimeout(function() {
        pc.close();
    }, 500);
}
--------------------------------------------------------------------------------
/coco.names:
--------------------------------------------------------------------------------
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1"/>
    <title>WebRTC webcam</title>
    <style>
    button {
        padding: 8px 16px;
    }

    video {
        width: 100%;
    }

    .option {
        margin-bottom: 8px;
    }
    </style>
</head>
<body>

<div class="option">
    <input id="use-stun" type="checkbox"/>
    <label for="use-stun">Use STUN server</label>
</div>
<button id="start" onclick="start()">Start</button>
<button id="stop" style="display: none" onclick="stop()">Stop</button>

<div id="media">
    <audio id="audio" autoplay="true"></audio>
    <video id="video" autoplay="true" playsinline="true"></video>
</div>

<script src="client.js"></script>
</body>
</html>
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import asyncio
import json
import logging
import os
import platform

from aiohttp import web
import cv2
import numpy as np
from av import VideoFrame
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
from aiortc.contrib.media import MediaPlayer, MediaRelay

ROOT = os.path.dirname(__file__)

relay = None
webcam = None

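# The ONNX file name encodes the network's expected input resolution as
# <height>x<width>; YOLOVideoStreamTrack.__init__ parses it from the name.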
model = "yolov7-tiny_480x640.onnx"

class YOLOVideoStreamTrack(VideoStreamTrack):
    """
    A video track that returns webcam frames annotated with detected objects.
    """
    def __init__(self, conf_thres=0.7, iou_thres=0.5):
        super().__init__()  # don't forget this!
        self.conf_threshold = conf_thres
        self.iou_threshold = iou_thres

        # Capture frames directly from the default webcam
        self.video = cv2.VideoCapture(0)

        # Initialize the model and derive its input size from the file name
        self.net = cv2.dnn.readNet(model)
        input_shape = os.path.splitext(os.path.basename(model))[0].split('_')[-1].split('x')
        self.input_height = int(input_shape[0])
        self.input_width = int(input_shape[1])

        with open('coco.names', 'r') as f:
            self.class_names = [line.strip() for line in f]
        # One fixed random colour per class, for drawing boxes
        self.colors = np.random.default_rng(3).uniform(0, 255, size=(len(self.class_names), 3))

        self.output_names = self.net.getUnconnectedOutLayersNames()

    async def recv(self):
        pts, time_base = await self.next_timestamp()
        ret, frame = self.video.read()
        if not ret:
            # Camera read failed; surface the error instead of passing None on
            raise RuntimeError("Could not read frame from webcam")
        boxes, scores, class_ids = self.detect(frame)
        frame = self.draw_detections(frame, boxes, scores, class_ids)
        frame = VideoFrame.from_ndarray(frame, format="bgr24")
        frame.pts = pts
        frame.time_base = time_base
        return frame

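    # Resize and colour-convert a captured frame to the network's input format.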
    def prepare_input(self, image):
        self.img_height, self.img_width = image.shape[:2]

        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize to the model's input size; pixel scaling to [0, 1] happens
        # later in blobFromImage
        input_img = cv2.resize(input_img, (self.input_width, self.input_height))

        return input_img

    def detect(self, frame):
        input_img = self.prepare_input(frame)
        blob = cv2.dnn.blobFromImage(input_img, 1 / 255.0)
        # Perform inference on the image
        self.net.setInput(blob)
        # Run the forward pass to get the output of the output layers
        outputs = self.net.forward(self.output_names)

        boxes, scores, class_ids = self.process_output(outputs)
        return boxes, scores, class_ids

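    # Each row of the raw output is [cx, cy, w, h, object_conf, class_conf...],
    # with one class-confidence column per entry in coco.names.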
    def process_output(self, output):
        predictions = np.squeeze(output[0])

        # Filter out object confidence scores below threshold
        obj_conf = predictions[:, 4]
        predictions = predictions[obj_conf > self.conf_threshold]
        obj_conf = obj_conf[obj_conf > self.conf_threshold]

        # Multiply class confidence with bounding box confidence
        predictions[:, 5:] *= obj_conf[:, np.newaxis]

        # Get the scores
        scores = np.max(predictions[:, 5:], axis=1)

        # Filter out the objects with a low score
        valid_scores = scores > self.conf_threshold
        predictions = predictions[valid_scores]
        scores = scores[valid_scores]

        # Get the class with the highest confidence
        class_ids = np.argmax(predictions[:, 5:], axis=1)

        # Get bounding boxes for each object
        boxes = self.extract_boxes(predictions)

        # Apply non-maximum suppression to drop weak, overlapping bounding boxes
        indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), self.conf_threshold, self.iou_threshold)
        if len(indices) == 0:
            # Nothing survived NMS; return empty results instead of
            # accidentally indexing with an empty tuple
            return boxes[:0], scores[:0], class_ids[:0]
        indices = np.array(indices).flatten()

        return boxes[indices], scores[indices], class_ids[indices]

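    # Box coordinates come out in model-input pixels; map them back to the
    # original frame size.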
    def rescale_boxes(self, boxes):
        input_shape = np.array([self.input_width, self.input_height, self.input_width, self.input_height])
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array([self.img_width, self.img_height, self.img_width, self.img_height])
        return boxes

    def extract_boxes(self, predictions):
        # Extract boxes from predictions
        boxes = predictions[:, :4]

        # Scale boxes to original image dimensions
        boxes = self.rescale_boxes(boxes)

        # Convert from centre coordinates to top-left x, y (xywh format)
        boxes_ = np.copy(boxes)
        boxes_[..., 0] = boxes[..., 0] - boxes[..., 2] * 0.5
        boxes_[..., 1] = boxes[..., 1] - boxes[..., 3] * 0.5
        return boxes_

    def draw_detections(self, frame, boxes, scores, class_ids):
        for box, score, class_id in zip(boxes, scores, class_ids):
            x, y, w, h = box.astype(int)
            # OpenCV drawing functions expect plain numeric colour tuples
            color = tuple(int(c) for c in self.colors[class_id])

            # Draw rectangle and label with the confidence percentage
            cv2.rectangle(frame, (x, y), (x + w, y + h), color, thickness=2)
            label = f'{self.class_names[class_id]} {int(score * 100)}%'
            cv2.putText(frame, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, thickness=2)
        return frame


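# Raw (unannotated) webcam source built on aiortc's MediaPlayer/MediaRelay.
# Not used by offer() below: the YOLO track opens the camera itself, and most
# capture devices cannot be opened twice.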
def create_local_tracks():
    global relay, webcam
    options = {"framerate": "30", "video_size": "640x480"}
    if relay is None:
        if platform.system() == "Darwin":
            webcam = MediaPlayer(
                "default:none", format="avfoundation", options=options
            )
        elif platform.system() == "Windows":
            webcam = MediaPlayer(
                "video=Integrated Camera", format="dshow", options=options
            )
        else:
            webcam = MediaPlayer("/dev/video0", format="v4l2", options=options)
        relay = MediaRelay()
    return relay.subscribe(webcam.video)


async def index(request):
    content = open(os.path.join(ROOT, "index.html"), "r").read()
    return web.Response(content_type="text/html", text=content)


async def javascript(request):
    content = open(os.path.join(ROOT, "client.js"), "r").read()
    return web.Response(content_type="application/javascript", text=content)


async def offer(request):
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        print("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    # Serve the annotated track; it opens the webcam itself, so the raw
    # MediaPlayer source from create_local_tracks() must not be opened as well
    pc.addTrack(YOLOVideoStreamTrack())

    await pc.setRemoteDescription(offer)

    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )


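# Active peer connections, tracked so they can all be closed on shutdown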
pcs = set()


async def on_shutdown(app):
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="WebRTC webcam demo")
    parser.add_argument(
        "--host", default="0.0.0.0", help="Host for HTTP server (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port", type=int, default=8080, help="Port for HTTP server (default: 8080)"
    )

    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    app = web.Application()
    app.on_shutdown.append(on_shutdown)
    app.router.add_get("/", index)
    app.router.add_get("/client.js", javascript)
    app.router.add_post("/offer", offer)
    web.run_app(app, host=args.host, port=args.port)
--------------------------------------------------------------------------------
/yolov7-tiny_480x640.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/genert/real-time-object-detection-with-webrtc-and-yolo/745ea481d0c866ecd884edd9b53b6be15d1307c1/yolov7-tiny_480x640.onnx
--------------------------------------------------------------------------------