├── FPS.py
├── LICENSE.txt
├── MovenetMPOpenvino.py
├── README.md
├── Tracker.py
├── img
│   ├── dance.gif
│   ├── street_192x256.jpg
│   ├── street_256x320.jpg
│   ├── street_480x640.jpg
│   ├── street_736x1280.jpg
│   ├── tracking_iou.gif
│   └── tracking_oks.gif
└── models
    ├── movenet_multipose_lightning_192x192_FP32.bin
    ├── movenet_multipose_lightning_192x192_FP32.xml
    ├── movenet_multipose_lightning_192x256_FP32.bin
    ├── movenet_multipose_lightning_192x256_FP32.xml
    ├── movenet_multipose_lightning_256x256_FP32.bin
    ├── movenet_multipose_lightning_256x256_FP32.xml
    ├── movenet_multipose_lightning_256x320_FP32.bin
    ├── movenet_multipose_lightning_256x320_FP32.xml
    ├── movenet_multipose_lightning_320x320_FP32.bin
    ├── movenet_multipose_lightning_320x320_FP32.xml
    ├── movenet_multipose_lightning_480x640_FP32.bin
    ├── movenet_multipose_lightning_480x640_FP32.xml
    ├── movenet_multipose_lightning_736x1280_FP32.bin
    └── movenet_multipose_lightning_736x1280_FP32.xml
/FPS.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: geaxx
3 | """
4 | import time
5 | import cv2
6 | from collections import deque
7 |
8 | def now():
9 | return time.perf_counter()
10 |
11 | class FPS: # To measure the number of frames per second
12 | def __init__(self, average_of=30):
13 | self.timestamps = deque(maxlen=average_of)
14 | self.nbf = -1
15 |
16 | def update(self):
17 | self.timestamps.append(time.monotonic())
18 | if len(self.timestamps) == 1:
19 | self.start = self.timestamps[0]
20 | self.fps = 0
21 | else:
22 | self.fps = (len(self.timestamps)-1)/(self.timestamps[-1]-self.timestamps[0])
23 | self.nbf+=1
24 |
25 | def get(self):
26 | return self.fps
27 |
28 | def get_global(self):
29 | return self.nbf/(self.timestamps[-1] - self.start), self.nbf+1
30 |
31 | def draw(self, win, orig=(10,30), font=cv2.FONT_HERSHEY_SIMPLEX, size=2, color=(0,255,0), thickness=2):
32 | cv2.putText(win,f"FPS={self.get():.2f}",orig,font,size,color,thickness)
33 |
34 | if __name__ == "__main__":
35 | fps = FPS()
36 | for i in range(50):
37 | fps.update()
38 | print(f"fps = {fps.get()}")
39 | time.sleep(0.1)
40 | global_fps, nb_frames = fps.get_global()
41 | print(f"Global fps : {global_fps} ({nb_frames} frames)")
42 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2021] [geax]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MovenetMPOpenvino.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import namedtuple
3 | import cv2
4 | from pathlib import Path
5 | from FPS import FPS, now
6 | import argparse
7 | import os
8 | from openvino.inference_engine import IENetwork, IECore
9 | from Tracker import TrackerIoU, TrackerOKS, TRACK_COLORS
10 |
11 |
12 | SCRIPT_DIR = Path(__file__).resolve().parent
13 | DEFAULT_MODEL = SCRIPT_DIR / "models/movenet_multipose_lightning_256x256_FP32.xml"
14 |
15 |
16 | # Dictionary that maps from joint names to keypoint indices.
17 | KEYPOINT_DICT = {
18 | 'nose': 0,
19 | 'left_eye': 1,
20 | 'right_eye': 2,
21 | 'left_ear': 3,
22 | 'right_ear': 4,
23 | 'left_shoulder': 5,
24 | 'right_shoulder': 6,
25 | 'left_elbow': 7,
26 | 'right_elbow': 8,
27 | 'left_wrist': 9,
28 | 'right_wrist': 10,
29 | 'left_hip': 11,
30 | 'right_hip': 12,
31 | 'left_knee': 13,
32 | 'right_knee': 14,
33 | 'left_ankle': 15,
34 | 'right_ankle': 16
35 | }
36 |
37 | # LINES_BODY are used when drawing the skeleton onto the source image.
38 | # Each variable is a list of continuous lines.
39 | # Each line is a list of keypoints as defined at https://github.com/tensorflow/tfjs-models/tree/master/pose-detection#keypoint-diagram
40 |
41 | LINES_BODY = [[4,2],[2,0],[0,1],[1,3],
42 | [10,8],[8,6],[6,5],[5,7],[7,9],
43 | [6,12],[12,11],[11,5],
44 | [12,14],[14,16],[11,13],[13,15]]
45 |
46 | class Body:
47 | def __init__(self, score, xmin, ymin, xmax, ymax, keypoints_score, keypoints, keypoints_norm):
48 | self.score = score # global score
49 | # xmin, ymin, xmax, ymax : bounding box
50 | self.xmin = xmin
51 | self.ymin = ymin
52 | self.xmax = xmax
53 | self.ymax = ymax
54 | self.keypoints_score = keypoints_score# scores of the keypoints
55 | self.keypoints_norm = keypoints_norm # keypoints normalized ([0,1]) coordinates (x,y) in the input image
56 | self.keypoints = keypoints # keypoints coordinates (x,y) in pixels in the input image
57 |
58 | def print(self):
59 | attrs = vars(self)
60 | print('\n'.join("%s: %s" % item for item in attrs.items()))
61 |
62 | def str_bbox(self):
63 | return f"xmin={self.xmin} xmax={self.xmax} ymin={self.ymin} ymax={self.ymax}"
64 |
65 | # Padding (all values are in pixel) :
66 | # w (resp. h): horizontal (resp. vertical) padding on the source image to make its aspect ratio the same as the Movenet model input.
67 | # The padding is done on one side (bottom or right) of the image.
68 | # padded_w (resp. padded_h): width (resp. height) of the image after padding
69 | Padding = namedtuple('Padding', ['w', 'h', 'padded_w', 'padded_h'])
70 |
71 | class MovenetMPOpenvino:
72 | def __init__(self, input_src=None,
73 | xml=DEFAULT_MODEL,
74 | device="CPU",
75 | tracking=False,
76 | score_thresh=0.2,
77 | output=None):
78 |
79 | self.score_thresh = score_thresh
80 | self.tracking = tracking
81 | if tracking is None:
82 | self.tracking = False
83 | elif tracking == "iou":
84 | self.tracking = True
85 | self.tracker = TrackerIoU()
86 | elif tracking == "oks":
87 | self.tracking = True
88 | self.tracker = TrackerOKS()
89 |
90 | if input_src.endswith('.jpg') or input_src.endswith('.png') :
91 | self.input_type= "image"
92 | self.img = cv2.imread(input_src)
93 | self.video_fps = 25
94 | self.img_h, self.img_w = self.img.shape[:2]
95 | else:
96 | self.input_type = "video"
97 | if input_src.isdigit():
98 |                 self.input_type = "webcam"
99 | input_src = int(input_src)
100 | self.cap = cv2.VideoCapture(input_src)
101 | self.video_fps = int(self.cap.get(cv2.CAP_PROP_FPS))
102 | self.img_w = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
103 | self.img_h = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
104 | print("Video FPS:", self.video_fps)
105 |
106 | # Load Openvino models
107 | self.load_model(xml, device)
108 |
109 | # Rendering flags
110 | self.show_fps = True
111 | self.show_bounding_box = False
112 |
113 | if output is None:
114 | self.output = None
115 | else:
116 | if self.input_type == "image":
117 |                 # For a source image, we output a single image (not a video) and exit
118 | self.output = output
119 | else:
120 | fourcc = cv2.VideoWriter_fourcc(*"MJPG")
121 | self.output = cv2.VideoWriter(output,fourcc,self.video_fps,(self.img_w, self.img_h))
122 |
123 | # Define the padding
124 | # Note we don't center the source image. The padding is applied
125 |         # on the bottom or right side, which slightly simplifies the calculation
126 |         # when depadding.
127 | if self.img_w / self.img_h > self.pd_w / self.pd_h:
128 | pad_h = int(self.img_w * self.pd_h / self.pd_w - self.img_h)
129 | self.padding = Padding(0, pad_h, self.img_w, self.img_h + pad_h)
130 | else:
131 | pad_w = int(self.img_h * self.pd_w / self.pd_h - self.img_w)
132 | self.padding = Padding(pad_w, 0, self.img_w + pad_w, self.img_h)
133 | print("Padding:", self.padding)
134 |
135 | def load_model(self, xml_path, device):
136 |
137 | print("Loading Inference Engine")
138 | self.ie = IECore()
139 | print("Device info:")
140 | versions = self.ie.get_versions(device)
141 | print("{}{}".format(" "*8, device))
142 | print("{}MKLDNNPlugin version ......... {}.{}".format(" "*8, versions[device].major, versions[device].minor))
143 | print("{}Build ........... {}".format(" "*8, versions[device].build_number))
144 |
145 | name = os.path.splitext(xml_path)[0]
146 | bin_path = name + '.bin'
147 | print("Pose Detection model - Reading network files:\n\t{}\n\t{}".format(xml_path, bin_path))
148 | self.pd_net = self.ie.read_network(model=xml_path, weights=bin_path)
149 | # Input blob: input:0 - shape: [1, 3, 256, 256] (lightning)
150 | # Output blob: Identity - shape: [1, 6, 56]
151 | self.pd_input_blob = next(iter(self.pd_net.input_info))
152 | print(f"Input blob: {self.pd_input_blob} - shape: {self.pd_net.input_info[self.pd_input_blob].input_data.shape}")
153 | _,_,self.pd_h,self.pd_w = self.pd_net.input_info[self.pd_input_blob].input_data.shape
154 | for o in self.pd_net.outputs.keys():
155 | print(f"Output blob: {o} - shape: {self.pd_net.outputs[o].shape}")
156 | self.pd_kps = "Identity"
157 | print("Loading pose detection model into the plugin")
158 | self.pd_exec_net = self.ie.load_network(network=self.pd_net, num_requests=1, device_name=device)
159 |
160 | self.infer_nb = 0
161 | self.infer_time_cumul = 0
162 |
163 | def pad_and_resize(self, frame):
164 | """ Pad and resize the image to prepare for the model input."""
165 |
166 | padded = cv2.copyMakeBorder(frame,
167 | 0,
168 | self.padding.h,
169 | 0,
170 | self.padding.w,
171 | cv2.BORDER_CONSTANT)
172 |
173 | padded = cv2.resize(padded, (self.pd_w, self.pd_h), interpolation=cv2.INTER_AREA)
174 |
175 | return padded
176 |
177 | def pd_postprocess(self, inference):
178 | result = np.squeeze(inference[self.pd_kps]) # 6x56
179 | bodies = []
180 | for i in range(6):
181 | kps = result[i][:51].reshape(17,-1)
182 | bbox = result[i][51:55].reshape(2,2)
183 | score = result[i][55]
184 | if score > self.score_thresh:
185 |                 ymin, xmin, ymax, xmax = (bbox * [self.padding.padded_h, self.padding.padded_w]).flatten().astype(int)
186 |                 kp_xy = kps[:,[1,0]]
187 | keypoints = kp_xy * np.array([self.padding.padded_w, self.padding.padded_h])
188 |
189 | body = Body(score=score, xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax,
190 | keypoints_score = kps[:,2],
191 |                             keypoints = keypoints.astype(int),
192 | keypoints_norm = keypoints / np.array([self.img_w, self.img_h]))
193 | bodies.append(body)
194 | return bodies
195 |
196 |
197 | def pd_render(self, frame, bodies):
198 | thickness = 3
199 | color_skeleton = (255, 180, 90)
200 | color_box = (0,255,255)
201 | for body in bodies:
202 | if self.tracking:
203 | color_skeleton = color_box = TRACK_COLORS[body.track_id % len(TRACK_COLORS)]
204 |
205 | lines = [np.array([body.keypoints[point] for point in line]) for line in LINES_BODY if body.keypoints_score[line[0]] > self.score_thresh and body.keypoints_score[line[1]] > self.score_thresh]
206 | cv2.polylines(frame, lines, False, color_skeleton, 2, cv2.LINE_AA)
207 |
208 | for i,x_y in enumerate(body.keypoints):
209 | if body.keypoints_score[i] > self.score_thresh:
210 | if i % 2 == 1:
211 | color = (0,255,0)
212 | elif i == 0:
213 | color = (0,255,255)
214 | else:
215 | color = (0,0,255)
216 |                     cv2.circle(frame, (x_y[0], x_y[1]), 4, color, -1)
217 |
218 | if self.show_bounding_box:
219 | cv2.rectangle(frame, (body.xmin, body.ymin), (body.xmax, body.ymax), color_box, thickness)
220 |
221 | if self.tracking:
222 | # Display track_id at the center of the bounding box
223 | x = (body.xmin + body.xmax) // 2
224 | y = (body.ymin + body.ymax) // 2
225 | cv2.putText(frame, str(body.track_id), (x,y), cv2.FONT_HERSHEY_PLAIN, 4, color_box, 3)
226 |
227 | def run(self):
228 |
229 | self.fps = FPS()
230 | nb_pd_inferences = 0
231 | glob_pd_rtrip_time = 0
232 |
233 | while True:
234 |
235 | if self.input_type == "image":
236 | frame = self.img.copy()
237 | else:
238 | ok, frame = self.cap.read()
239 | if not ok:
240 | break
241 |
242 | padded = self.pad_and_resize(frame)
243 | # cv2.imshow("Padded", padded)
244 |
245 | frame_nn = cv2.cvtColor(padded, cv2.COLOR_BGR2RGB).transpose(2,0,1).astype(np.float32)[None,]
246 | pd_rtrip_time = now()
247 | inference = self.pd_exec_net.infer(inputs={self.pd_input_blob: frame_nn})
248 | glob_pd_rtrip_time += now() - pd_rtrip_time
249 | bodies = self.pd_postprocess(inference)
250 | if self.tracking:
251 | bodies = self.tracker.apply(bodies, now())
252 | self.pd_render(frame, bodies)
253 | nb_pd_inferences += 1
254 |
255 | self.fps.update()
256 |
257 | if self.show_fps:
258 | self.fps.draw(frame, orig=(50,50), size=1, color=(240,180,100))
259 | cv2.imshow("Movenet", frame)
260 |
261 | if self.output:
262 | if self.input_type == "image":
263 | cv2.imwrite(self.output, frame)
264 | break
265 | else:
266 | self.output.write(frame)
267 |
268 | key = cv2.waitKey(1)
269 | if key == ord('q') or key == 27:
270 | break
271 | elif key == 32:
272 | # Pause on space bar
273 | cv2.waitKey(0)
274 | elif key == ord('f'):
275 | self.show_fps = not self.show_fps
276 | elif key == ord('b'):
277 | self.show_bounding_box = not self.show_bounding_box
278 |
279 | # Print some stats
280 | if nb_pd_inferences > 1:
281 | global_fps, nb_frames = self.fps.get_global()
282 |
283 | print(f"FPS : {global_fps:.1f} f/s (# frames = {nb_frames})")
284 | print(f"# pose detection inferences : {nb_pd_inferences}")
285 | print(f"Pose detection round trip : {glob_pd_rtrip_time/nb_pd_inferences*1000:.1f} ms")
286 |
287 | if self.output and self.input_type != "image":
288 | self.output.release()
289 |
290 |
291 | if __name__ == "__main__":
292 | parser = argparse.ArgumentParser()
293 | parser.add_argument('-i', '--input', type=str, default='0',
294 | help="Path to video or image file to use as input (default=%(default)s)")
295 | # parser.add_argument("-p", "--precision", type=int, choices=[16, 32], default=32,
296 | # help="Precision (default=%(default)i")
297 | parser.add_argument("--xml", type=str,
298 | help="Path to an .xml file for model")
299 | parser.add_argument("-r", "--res", default="256x256", choices=["192x192", "192x256", "256x256", "256x320", "320x320", "480x640", "736x1280"])
300 | # parser.add_argument("-d", "--device", default='CPU', type=str,
301 | # help="Target device to run the model (default=%(default)s)")
302 | parser.add_argument("-t", "--tracking", choices=["iou", "oks"],
303 | help="Enable tracking and specify method")
304 | parser.add_argument("-s", "--score_threshold", default=0.2, type=float,
305 | help="Confidence score (default=%(default)f)")
306 | parser.add_argument("-o","--output",
307 | help="Path to output video file")
308 |
309 | args = parser.parse_args()
310 |
311 |
312 | # if args.device == "MYRIAD" or args.device == "GPU":
313 | # args.precision = 16
314 | if not args.xml:
315 | args.xml = SCRIPT_DIR / f"models/movenet_multipose_lightning_{args.res}_FP32.xml"
316 |
317 | pd = MovenetMPOpenvino(input_src=args.input,
318 | xml=args.xml,
319 | # device=args.device,
320 | tracking=args.tracking,
321 | score_thresh=args.score_threshold,
322 | output=args.output)
323 | pd.run()
324 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MoveNet Multipose Tracking on OpenVINO
2 |
3 | Running Google MoveNet Multipose models on OpenVINO.
4 |
5 | A convolutional neural network model that runs on RGB images and predicts [human joint
6 | locations](https://github.com/tensorflow/tfjs-models/tree/master/pose-detection#coco-keypoints-used-in-movenet-and-posenet) for up to 6 people.
7 |
8 | **WIP: currently only working on CPU (not on GPU or MYRIAD)**
9 |
10 |
11 | 
12 |
13 | *Full video demo [here.](https://youtu.be/ndy18jNcOC0)*
14 |
15 | For MoveNet Single Pose, please visit : [openvino_movenet](https://github.com/geaxgx/openvino_movenet)
16 |
17 |
18 |
19 | ## Install
20 |
21 | You need OpenVINO (tested with 2021.4) and OpenCV installed on your computer, and a clone or download of this repository.
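
As an optional sanity check (a minimal sketch, not part of the demo), you can verify from Python that both libraries are importable and that OpenVINO sees a CPU device:

```
from openvino.inference_engine import IECore
import cv2

print("OpenVINO devices:", IECore().available_devices)  # should include "CPU"
print("OpenCV version:", cv2.__version__)
```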
22 |
23 | ## Run
24 |
25 | **Usage:**
26 |
27 | ```
28 | > python3 MovenetMPOpenvino.py -h
29 | usage: MovenetMPOpenvino.py [-h] [-i INPUT] [--xml XML]
30 | [-r {192x192,192x256,256x256,256x320,320x320,480x640,736x1280}]
31 | [-t {iou,oks}] [-s SCORE_THRESHOLD] [-o OUTPUT]
32 |
33 | optional arguments:
34 | -h, --help show this help message and exit
35 | -i INPUT, --input INPUT
36 | Path to video or image file to use as input
37 | (default=0)
38 | --xml XML Path to an .xml file for model
39 | -r {192x192,192x256,256x256,256x320,320x320,480x640,736x1280}, --res {192x192,192x256,256x256,256x320,320x320,480x640,736x1280}
40 | -t {iou,oks}, --tracking {iou,oks}
41 | Enable tracking and specify method
42 | -s SCORE_THRESHOLD, --score_threshold SCORE_THRESHOLD
43 | Confidence score (default=0.200000)
44 | -o OUTPUT, --output OUTPUT
45 | Path to output video file
46 | ```
47 |
48 | **Examples :**
49 |
50 | - To use the default webcam as input :
51 |
52 | ```python3 MovenetMPOpenvino.py```
53 |
54 | - To specify the model input resolution :
55 |
56 | ```python3 MovenetMPOpenvino.py -r 256x320```
57 |
58 | - To enable tracking based on Object Keypoint Similarity :
59 |
60 | ```python3 MovenetMPOpenvino.py -t oks```
61 |
62 | - To use a file (video or image) as input :
63 |
64 | ```python3 MovenetMPOpenvino.py -i filename```
65 |
66 |
67 | |Keypress|Function|
68 | |-|-|
69 | |*Esc*|Exit|
70 | |*space*|Pause|
71 | |b|Show/hide bounding boxes|
72 | |f|Show/hide FPS|
73 |
74 |
75 | ## Input resolution
76 |
77 | The model input resolution (set with the '-r' or '--res' option) has an impact on the inference speed (the higher the resolution, the slower the inference) and on the size of the people that can be detected (the higher the resolution, the smaller the people that can still be detected).
78 | The test below was run on an Intel Core i7-7700K CPU.
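
For example, a row of the table below can be reproduced on your own image (`street.jpg` is only a placeholder name here) with :

```python3 MovenetMPOpenvino.py -i street.jpg -r 480x640 -o street_480x640.jpg```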
79 |
80 |
81 | |Resolution|FPS |Result|
82 | |-|-|-|
83 | |192x256|58.0|[](img/street_192x256.jpg)|
84 | |256x320|44.1|[](img/street_256x320.jpg)|
85 | |480x640|14.8|[](img/street_480x640.jpg)|
86 | |736x1280|4.5|[](img/street_736x1280.jpg)|
87 |
88 | ## Tracking
89 |
90 | The Javascript MoveNet demo code from Google offers [two tracking methods](https://github.com/tensorflow/tfjs-models/blob/master/pose-detection/src/calculators/tracker.md) as an option. For this repository, I have adapted this tracking code to Python. You can enable tracking with the `--tracking` (or `-t`) argument of the demo, followed by `iou` or `oks`, which specifies how the similarity between detections from consecutive frames is calculated (a minimal usage sketch follows the list below) :
91 | * IoU (Intersection over Union) of pose bounding boxes (option `iou`);
92 | * [Object Keypoint Similarity](https://cocodataset.org/#keypoints-eval) (option `oks`).
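
A minimal sketch of how a tracker from `Tracker.py` assigns a `track_id` to each detection, mirroring what `run()` in `MovenetMPOpenvino.py` does (the dummy `Body` below is fabricated just for illustration; in the demo, `pd_postprocess()` produces them):

```
import numpy as np
from MovenetMPOpenvino import Body
from Tracker import TrackerIoU, TrackerOKS
from FPS import now

tracker = TrackerOKS()   # or TrackerIoU() for bounding-box IoU similarity

# Fabricated detection: 17 confident keypoints near the center of a 640x480 frame
kps_norm = np.full((17, 2), 0.5) + np.random.uniform(-0.05, 0.05, (17, 2))
body = Body(score=0.9, xmin=100, ymin=80, xmax=220, ymax=400,
            keypoints_score=np.full(17, 0.9),
            keypoints=(kps_norm * [640, 480]).astype(int),
            keypoints_norm=kps_norm)

bodies = tracker.apply([body], now())
print(bodies[0].track_id)   # first track created -> 1
```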
93 |
94 | |Tracking|Result|
95 | |-|-|
96 | |IoU Tracking||
97 | |OKS Tracking||
98 |
99 | In the example above, we can notice several track switches in the IoU output, as well as a track replacement (2 replaced by 6). The OKS method does a better job, yet it is not perfect: there is a track switch when body 3 passes in front of body 1.
100 |
101 |
102 |
103 | ## The models
104 | The MoveNet Multipose v1 source model comes from TensorFlow Hub: https://tfhub.dev/google/movenet/multipose/lightning/1
105 |
106 | The model was converted by PINTO into OpenVINO IR format. Unfortunately, the input resolution of an OpenVINO IR MoveNet model cannot be changed dynamically, so an arbitrary list of models has been generated, each one with its own dedicated input resolution. These models and others (other resolutions or precisions) are also available here: https://github.com/PINTO0309/PINTO_model_zoo/tree/main/137_MoveNet_MultiPose
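
For reference, this is essentially how the demo maps the `--res` value to one of these model files (mirroring the argument handling at the bottom of `MovenetMPOpenvino.py`):

```
from pathlib import Path

SCRIPT_DIR = Path(__file__).resolve().parent
res = "256x320"  # any of the resolutions listed in the models directory
xml = SCRIPT_DIR / f"models/movenet_multipose_lightning_{res}_FP32.xml"
```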
107 |
108 |
109 | ## Credits
110 | * [Google Tensorflow Hub](https://tfhub.dev/google/movenet/multipose/lightning/1)
111 | * Katsuya Hyodo a.k.a [Pinto](https://github.com/PINTO0309), the Wizard of Model Conversion !
112 | * The original video : [The Evolution of Dance - 1950 to 2019 - By Ricardo Walker's Crew](https://www.youtube.com/watch?v=p-rSdt0aFuw&ab_channel=RicardoWalker)
113 |
--------------------------------------------------------------------------------
/Tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | TRACK_COLORS = [(230, 25, 75),
4 | (60, 180, 75),
5 | (255, 225, 25),
6 | (0, 130, 200),
7 | (245, 130, 48),
8 | (145, 30, 180),
9 | (70, 240, 240),
10 | (240, 50, 230),
11 | (210, 245, 60),
12 | (250, 190, 212),
13 | (0, 128, 128),
14 | (220, 190, 255),
15 | (170, 110, 40),
16 | (255, 250, 200),
17 | (128, 0, 0),
18 | (170, 255, 195),
19 | (128, 128, 0),
20 | (255, 215, 180),
21 | (0, 0, 128),
22 | (128, 128, 128)]
23 |
24 | class Track:
25 | def __init__(self, pose, timestamp):
26 | self.pose = pose
27 | self.timestamp = timestamp
28 |
29 | """
30 | Tracker: A stateful tracker for associating detections between frames.
31 | https://github.com/tensorflow/tfjs-models/blob/master/pose-detection/src/calculators/tracker.ts
32 | Default parameters values come from: https://github.com/tensorflow/tfjs-models/blob/master/pose-detection/src/movenet/constants.ts
33 | """
34 | class Tracker:
35 | def __init__(self, max_tracks, max_age, min_similarity):
36 | """
37 | max_tracks: int,
38 | The maximum number of tracks that an internal tracker
39 | will maintain. Note that this number should be set
40 | larger than maxPoses. How to set this
41 | number requires experimentation with a given detector,
42 | but a good starting place is about 3 * maxPoses.
43 |         max_age: int,
44 |             The maximum duration (in the same time unit as the timestamps
45 |             passed to apply(); seconds in this demo) that a track can
46 |             exist without being linked with a new detection before it is
47 |             removed. Set this value large if you would like to recover
48 |             people that are not detected for long stretches of time (at
49 |             the cost of potential false re-identifications).
50 | min_similarity: float
51 | New poses will only be linked with tracks if the
52 | similarity score exceeds this threshold.
53 |
54 | """
55 | self.max_tracks = max_tracks
56 | self.max_age = max_age
57 | self.min_similarity = min_similarity
58 | self.tracks = {} # Dict of tracks, key = track_id, value = instance of class Track
59 | self.next_id = 1
60 |
61 | def apply(self, poses, timestamp):
62 | # Filters tracks based on their age.
63 | self.tracks = {id:track for (id, track) in self.tracks.items() if timestamp - track.timestamp < self.max_age}
64 | # Sort poses by their scores from most confident to least confident
65 | poses = sorted(poses, key=lambda body: body.score, reverse=True)
66 | # Performs a greedy optimization to link detections with tracks. If incoming
67 | # detections are not linked with existing tracks, new tracks will be created.
68 | unmatched_track_indices = list(self.tracks.keys())
69 | unmatched_detection_indices = []
70 | for i, pose in enumerate(poses):
71 | if len(unmatched_track_indices) == 0:
72 | unmatched_detection_indices.append(i)
73 | continue
74 | # Assign the detection to the track which produces the highest pairwise
75 | # similarity score, assuming the score exceeds the minimum similarity
76 | # threshold.
77 | max_track_id = -1
78 | max_sim = -1
79 | for track_id in unmatched_track_indices:
80 | sim = self.similarity(pose, self.tracks[track_id])
81 | if sim >= self.min_similarity and sim > max_sim:
82 | max_track_id = track_id
83 | max_sim = sim
84 | if max_track_id >= 0:
85 | pose.track_id = max_track_id
86 | self.update_track(max_track_id, pose, timestamp)
87 | unmatched_track_indices.remove(max_track_id)
88 | else:
89 | unmatched_detection_indices.append(i)
90 |
91 | # Spawn new tracks for all unmatched detections.
92 | for i in unmatched_detection_indices:
93 | track_id = self.create_track(poses[i], timestamp)
94 | poses[i].track_id = track_id
95 |
96 | # If there are too many tracks, we keep only the self.max_tracks freshest tracks
97 | if len(self.tracks) > self.max_tracks:
98 | sorted_dict = sorted(self.tracks.items(), key=lambda key_value: key_value[1].timestamp, reverse=True)[:self.max_tracks]
99 | self.tracks = {k:v for k,v in sorted_dict}
100 |
101 | return poses
102 |
103 | def create_track(self, pose, timestamp):
104 | track_id = self.next_id
105 | self.tracks[track_id] = Track(pose, timestamp)
106 | self.next_id += 1
107 | return track_id
108 |
109 | def update_track(self, track_id, pose, timestamp):
110 | self.tracks[track_id].pose = pose
111 | self.tracks[track_id].timestamp = timestamp
112 |
113 |
114 |
115 |
116 | """
117 | TrackerIoU, which tracks objects based on bounding box similarity,
118 | currently defined as intersection-over-union (IoU)
119 | https://github.com/tensorflow/tfjs-models/blob/master/pose-detection/src/calculators/bounding_box_tracker.ts
120 | """
121 | class TrackerIoU(Tracker):
122 | def __init__(self,
123 | max_tracks = 18,
124 | max_age = 1,
125 | min_similarity = 0.15
126 | ):
127 | """
128 | max_tracks, max_age, min_similarity: see Tracker docstring
129 | """
130 | super().__init__(max_tracks, max_age, min_similarity)
131 |
132 | def similarity(self, pose, track):
133 | """
134 | Computes the intersection-over-union (IoU) between a body bounding box and a track.
135 | Returns The IoU between the bounding box and the track. This number is
136 | between 0 and 1, and larger values indicate more box similarity.
137 | """
138 | xmin = max(pose.xmin, track.pose.xmin)
139 | ymin = max(pose.ymin, track.pose.ymin)
140 | xmax = min(pose.xmax, track.pose.xmax)
141 | ymax = min(pose.ymax, track.pose.ymax)
142 | if xmin >= xmax or ymin >= ymax:
143 | return 0.
144 | intersection = (xmax - xmin) * (ymax - ymin)
145 | area_pose = (pose.xmax - pose.xmin) * (pose.ymax - pose.ymin)
146 | area_track = (track.pose.xmax - track.pose.xmin) * (track.pose.ymax - track.pose.ymin)
147 | return intersection / (area_pose + area_track - intersection)
148 |
149 | """
150 | TrackerOKS, which tracks poses based on Object Keypoint Similarity.
151 | This tracker assumes that keypoints are provided in normalized image coordinates.
152 | https://github.com/tensorflow/tfjs-models/blob/master/pose-detection/src/calculators/keypoint_tracker.ts
153 | """
154 | class TrackerOKS(Tracker):
155 | def __init__(self,
156 | max_tracks = 18,
157 | max_age = 1,
158 | min_similarity = 0.2,
159 | keypoint_thresh = 0.3,
160 | keypoint_falloff = np.array([
161 | 0.026, 0.025, 0.025, 0.035, 0.035,
162 | 0.079, 0.079, 0.072, 0.072, 0.062,
163 | 0.062, 0.107, 0.107, 0.087, 0.087,
164 | 0.089, 0.089
165 | ]),
166 | min_keypoints = 4
167 | ):
168 | """
169 | max_tracks, max_age, min_similarity: see Tracker docstring
170 | keypoint_thresh: float,
171 | The minimum keypoint confidence threshold. A keypoint is only
172 | compared in the similarity calculation if both the new detected
173 | keypoint and the corresponding track keypoint have confidences
174 | above this threshold.
175 | keypoint_falloff: list of floats,
176 | Per-keypoint falloff in similarity calculation.
177 | min_keypoints: int,
178 | The minimum number of keypoints that are
179 | necessary for computing similarity. If the number
180 | of confident keypoints (between a pose and
181 |             track) is under this value, a similarity of 0.0
182 | will be given.
183 | """
184 | super().__init__(max_tracks, max_age, min_similarity)
185 | self.keypoint_thresh = keypoint_thresh
186 | self.keypoint_falloff = keypoint_falloff
187 | self.min_keypoints = min_keypoints
188 |
189 | def similarity(self, pose, track):
190 | """
191 | Computes the Object Keypoint Similarity (OKS) between a pose and track.
192 | This is similar in spirit to the calculation used by COCO keypoint eval:
193 | https://cocodataset.org/#keypoints-eval
194 | In this case, OKS is calculated as:
195 | (1/sum_i d(c_i, c_ti)) * sum_i exp(-d_i^2/(2*a_ti*x_i^2))*d(c_i, c_ti)
196 | where
197 | d(x, y) is an indicator function which only produces 1 if x and y
198 | exceed a given threshold (i.e. keypointThreshold), otherwise 0.
199 | c_i is the confidence of keypoint i from the new pose
200 | c_ti is the confidence of keypoint i from the track
201 | d_i is the Euclidean distance between the pose and track keypoint
202 | a_ti is the area of the track object (the box covering the keypoints)
203 | x_i is a constant that controls falloff in a Gaussian distribution,
204 | computed as 2*keypointFalloff[i].
205 | Returns The OKS score between the pose and the track. This number is
206 | between 0 and 1, and larger values indicate more keypoint similarity.
207 | """
208 | box_area = self.area(track.pose)
209 | if box_area == 0: return 0
210 |
211 | num_valid_kps = 0
212 | valid_kps_filter = np.logical_and(pose.keypoints_score > self.keypoint_thresh, track.pose.keypoints_score > self.keypoint_thresh)
213 | pose_kps = pose.keypoints_norm[valid_kps_filter]
214 | num_valid_kps = len(pose_kps)
215 | if num_valid_kps < self.min_keypoints:
216 | return 0
217 | else:
218 | track_kps = track.pose.keypoints_norm[valid_kps_filter]
219 |             x = 2 * self.keypoint_falloff[valid_kps_filter]
220 |             d_squared = np.sum(np.power(pose_kps - track_kps, 2), axis=1)  # squared Euclidean distance per keypoint
221 | oks_total = np.sum(np.exp(-d_squared / (2 * box_area * x * x)))
222 | return oks_total / num_valid_kps
223 |
224 | def area(self, pose):
225 | """
226 | Computes the area of a bounding box that tightly covers keypoints.
227 | """
228 | kps = pose.keypoints_norm[pose.keypoints_score > self.keypoint_thresh]
229 | if len(kps) == 0: return 0
230 | xmin, ymin = np.min(kps, axis=0)
231 | xmax, ymax = np.max(kps, axis=0)
232 | return (xmax - xmin) * (ymax - ymin)
233 |
234 |
--------------------------------------------------------------------------------
/img/dance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/img/dance.gif
--------------------------------------------------------------------------------
/img/street_192x256.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/img/street_192x256.jpg
--------------------------------------------------------------------------------
/img/street_256x320.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/img/street_256x320.jpg
--------------------------------------------------------------------------------
/img/street_480x640.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/img/street_480x640.jpg
--------------------------------------------------------------------------------
/img/street_736x1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/img/street_736x1280.jpg
--------------------------------------------------------------------------------
/img/tracking_iou.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/img/tracking_iou.gif
--------------------------------------------------------------------------------
/img/tracking_oks.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/img/tracking_oks.gif
--------------------------------------------------------------------------------
/models/movenet_multipose_lightning_192x192_FP32.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/models/movenet_multipose_lightning_192x192_FP32.bin
--------------------------------------------------------------------------------
/models/movenet_multipose_lightning_192x256_FP32.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/models/movenet_multipose_lightning_192x256_FP32.bin
--------------------------------------------------------------------------------
/models/movenet_multipose_lightning_256x256_FP32.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/models/movenet_multipose_lightning_256x256_FP32.bin
--------------------------------------------------------------------------------
/models/movenet_multipose_lightning_256x320_FP32.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/models/movenet_multipose_lightning_256x320_FP32.bin
--------------------------------------------------------------------------------
/models/movenet_multipose_lightning_320x320_FP32.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/models/movenet_multipose_lightning_320x320_FP32.bin
--------------------------------------------------------------------------------
/models/movenet_multipose_lightning_480x640_FP32.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/models/movenet_multipose_lightning_480x640_FP32.bin
--------------------------------------------------------------------------------
/models/movenet_multipose_lightning_736x1280_FP32.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geaxgx/openvino_movenet_multipose/5e9d2ad88141767d4e2ba2382b12026f2b0aa400/models/movenet_multipose_lightning_736x1280_FP32.bin
--------------------------------------------------------------------------------