├── .gitignore ├── LICENSE ├── README.md ├── mcftracker.py ├── test.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 watanika 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-object tracking using min-cost flow 2 | 3 | This is a simple Python implementation of a tracking algorithm based on global data association via network flows [1]. 4 | 5 | Targets are tracked by minimizing the cost of a flow network built from initial detection results. 6 | 7 | ## Dependencies 8 | 9 | - numpy 10 | - OpenCV (for image reading and processing) 11 | - ortools (for solving the min-cost flow problem) 12 | 13 | ## Usage 14 | 15 | Please modify test.py and mcftracker.py to adapt them to your tracking targets.
16 | You can test this implementation by running: 17 | 18 | ```sh 19 | % python test.py 20 | ``` 21 | 22 | To include it in your project, you just need to: 23 | 24 | ```py 25 | 26 | tracker = MinCostFlowTracker(detections, tags, min_thresh, P_enter, P_exit, beta) 27 | tracker.build_network(images) 28 | optimal_flow, optimal_cost = tracker.run() 29 | 30 | ``` 31 | 32 | To reduce computation cost, you can use Fibonacci search over the number of tracks (`tracker.run(fib=True)`) instead of the default brute-force search; in both cases the resulting links are stored in `tracker.flow_dict`. 33 | 34 | ## License 35 | 36 | MIT 37 | 38 | ## References 39 | 40 | [1] L. Zhang et al., "Global data association for multi-object tracking using network flows", CVPR 2008 -------------------------------------------------------------------------------- /mcftracker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 watanika, all rights reserved. 3 | Licensed under the MIT license . This file may 5 | not be copied, modified,or distributed except 6 | according to those terms. 7 | """ 8 | 9 | import math 10 | from ortools.graph import pywrapgraph 11 | import sys 12 | 13 | import tools 14 | 15 | 16 | class MinCostFlowTracker: 17 | """ 18 | Object tracking based on data association via minimum cost flow algorithm 19 | L.
Zhang et al., 20 | "Global data association for multi-object tracking using network flows", 21 | CVPR 2008 22 | """ 23 | 24 | def __init__(self, detections, tags, min_thresh, P_enter, P_exit, beta): 25 | self._detections = detections 26 | self._min_thresh = min_thresh 27 | 28 | self.P_enter = P_enter 29 | self.P_exit = self.P_enter 30 | self.beta = beta 31 | 32 | self._id2name = tools.map_id2name(tags) 33 | self._name2id = tools.map_name2id(tags) 34 | self._id2node = tools.map_id2node(detections) 35 | self._node2id = tools.map_node2id(detections) 36 | self._fib_cache = {0: 0, 1: 1} 37 | 38 | def _fib(self, n): 39 | if n in self._fib_cache: 40 | return self._fib_cache[n] 41 | elif n > 1: 42 | return self._fib_cache.setdefault(n, self._fib(n - 1) + self._fib(n - 2)) 43 | return n 44 | 45 | def _find_nearest_fib(self, num): 46 | for n in range(num): 47 | if num < self._fib(n): 48 | return (n - 1, self._fib(n - 1)) 49 | return (num, self._fib(num)) 50 | 51 | def _calc_cost_enter(self): 52 | return -math.log(self.P_enter) 53 | 54 | def _calc_cost_exit(self): 55 | return -math.log(self.P_exit) 56 | 57 | def _calc_cost_detection(self, beta): 58 | return math.log(beta / (1.0 - beta)) 59 | 60 | def _calc_cost_link(self, rect1, rect2, image1=None, image2=None, eps=1e-7): 61 | prob_iou = tools.calc_overlap(rect1, rect2) 62 | hist1 = tools.calc_HS_histogram(image1, rect1) 63 | hist2 = tools.calc_HS_histogram(image2, rect2) 64 | prob_color = 1.0 - tools.calc_bhattacharyya_distance(hist1, hist2) 65 | 66 | prob_sim = prob_iou * prob_color 67 | return -math.log(prob_sim + eps) 68 | 69 | def build_network(self, images={}, f2i_factor=10000): 70 | self.mcf = pywrapgraph.SimpleMinCostFlow() 71 | 72 | for image_name, rects in sorted(self._detections.items()): 73 | for i, rect in enumerate(rects): 74 | self.mcf.AddArcWithCapacityAndUnitCost(self._node2id["source"], self._node2id[(image_name, i, "u")], 1, int(self._calc_cost_enter() * f2i_factor)) 75 | 
self.mcf.AddArcWithCapacityAndUnitCost(self._node2id[(image_name, i, "u")], self._node2id[(image_name, i, "v")], 1, int(self._calc_cost_detection(rect[4]) * f2i_factor)) 76 | self.mcf.AddArcWithCapacityAndUnitCost(self._node2id[(image_name, i, "v")], self._node2id["sink"], 1, int(self._calc_cost_exit() * f2i_factor)) 77 | 78 | frame_id = self._name2id[image_name] 79 | if frame_id == 0: 80 | continue 81 | prev_image_name = self._id2name[frame_id - 1] 82 | if prev_image_name not in self._detections: 83 | continue 84 | 85 | for i, i_rect in enumerate(self._detections[prev_image_name]): 86 | for j, j_rect in enumerate(rects): 87 | self.mcf.AddArcWithCapacityAndUnitCost(self._node2id[(prev_image_name, i, "v")], self._node2id[(image_name, j, "u")], 1, int(self._calc_cost_link(images[prev_image_name], i_rect, images[image_name], j_rect) * 1000)) 88 | 89 | def _make_flow_dict(self): 90 | self.flow_dict = {} 91 | for i in range(self.mcf.NumArcs()): 92 | if self.mcf.Flow(i) > 0: 93 | tail = self.mcf.Tail(i) 94 | head = self.mcf.Head(i) 95 | if self._id2node[tail] in self.flow_dict: 96 | self.flow_dict[self._id2node[tail]][self._id2node[head]] = 1 97 | else: 98 | self.flow_dict[self._id2node[tail]] = {self._id2node[head]: 1} 99 | 100 | def _fibonacci_search(self, search_range=200): 101 | s = 0 102 | k_max, t = self._find_nearest_fib(self.mcf.NumNodes() // search_range) 103 | cost = {} 104 | 105 | for k in range(k_max, 1, -1): 106 | # s < u < v < t 107 | u = s + self._fib(k - 2) 108 | v = s + self._fib(k - 1) 109 | 110 | if u not in cost: 111 | self.mcf.SetNodeSupply(self._node2id["source"], u) 112 | self.mcf.SetNodeSupply(self._node2id["sink"], -u) 113 | 114 | if self.mcf.Solve() == self.mcf.OPTIMAL: 115 | cost[u] = self.mcf.OptimalCost() 116 | else: 117 | print("There was an issue with the min cost flow input.") 118 | sys.exit() 119 | 120 | if v not in cost: 121 | self.mcf.SetNodeSupply(self._node2id["source"], v) 122 | self.mcf.SetNodeSupply(self._node2id["sink"], -v) 123 | 
124 | if self.mcf.Solve() == self.mcf.OPTIMAL: 125 | cost[v] = self.mcf.OptimalCost() 126 | else: 127 | print("There was an issue with the min cost flow input.") 128 | sys.exit() 129 | 130 | if cost[u] < cost[v]: 131 | t = v 132 | elif cost[u] == cost[v]: 133 | s = u 134 | t = v 135 | else: 136 | s = u 137 | 138 | self.mcf.SetNodeSupply(self._node2id["source"], s) 139 | self.mcf.SetNodeSupply(self._node2id["sink"], -s) 140 | 141 | if self.mcf.Solve() == self.mcf.OPTIMAL: 142 | optimal_cost = self.mcf.OptimalCost() 143 | else: 144 | print("There was an issue with the min cost flow input.") 145 | sys.exit() 146 | self._make_flow_dict() 147 | return (s, optimal_cost) 148 | 149 | def _brute_force(self, search_range=100): 150 | max_flow = self.mcf.NumNodes() // search_range 151 | print("Search: 0 < num_flow <", max_flow) 152 | 153 | optimal_flow = 0 154 | optimal_cost = float("inf") 155 | for flow in range(max_flow): 156 | self.mcf.SetNodeSupply(self._node2id["source"], flow) 157 | self.mcf.SetNodeSupply(self._node2id["sink"], -flow) 158 | 159 | if self.mcf.Solve() == self.mcf.OPTIMAL: 160 | cost = self.mcf.OptimalCost() 161 | else: 162 | print("There was an issue with the min cost flow input.") 163 | sys.exit() 164 | 165 | if cost < optimal_cost: 166 | optimal_flow = flow 167 | optimal_cost = cost 168 | self._make_flow_dict() 169 | return (optimal_flow, optimal_cost) 170 | 171 | def run(self, fib=False): 172 | if fib: 173 | return self._fibonacci_search() 174 | else: 175 | return self._brute_force() 176 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 watanika, all rights reserved. 3 | Licensed under the MIT license . This file may 5 | not be copied, modified,or distributed except 6 | according to those terms. 
7 | """ 8 | 9 | import time 10 | from mcftracker import MinCostFlowTracker 11 | 12 | 13 | # Example usage of mcftracker 14 | def main(): 15 | # Prepare initial detecton results, ground truth, and images 16 | # You need to change below 17 | detections = {"image_name": [x1, y1, x2, y2, score]} 18 | tags = {"image_name": [x1, y1, x2, y2]} 19 | images = {"image_name": numpy_image} 20 | 21 | # Parameters 22 | min_thresh = 0 23 | P_enter = 0.1 24 | P_exit = 0.1 25 | beta = 0.5 26 | fib_search = True 27 | 28 | # Let's track them! 29 | start = time.time() 30 | tracker = MinCostFlowTracker(detections, tags, min_thresh, P_enter, P_exit, beta) 31 | tracker.build_network(images) 32 | optimal_flow, optimal_cost = tracker.run(fib=fib_search) 33 | end = time.time() 34 | 35 | print("Finished: {} sec".format(end - start)) 36 | print("Optimal number of flow: {}".format(optimal_flow)) 37 | print("Optimal cost: {}".format(optimal_cost)) 38 | 39 | print("Optimal flow:") 40 | print(tracker.flow_dict) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2018 watanika, all rights reserved. 3 | Licensed under the MIT license . This file may 5 | not be copied, modified,or distributed except 6 | according to those terms. 
7 | """ 8 | 9 | import cv2 10 | 11 | 12 | def calc_HS_histogram(image, roi): 13 | cropped = image[roi[1]:roi[3], roi[0]:roi[2], :] 14 | hsv = cv2.cvtColor(cropped, cv2.COLOR_BGR2HSV) 15 | 16 | hist = cv2.calcHist([hsv], [0, 1], None, [180, 256], [0, 180, 0, 256]) 17 | cv2.normalize(hist, hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX).flatten() 18 | return hist 19 | 20 | 21 | def calc_bhattacharyya_distance(hist1, hist2): 22 | return cv2.compareHist(hist1, hist2, cv2.HISTCMP_BHATTACHARYYA) 23 | 24 | 25 | def map_node2id(detections): 26 | node2id = {} 27 | node2id["source"] = 0 28 | node2id["sink"] = 1 29 | 30 | nextid = 2 31 | for image_name, rects in sorted(detections.items()): 32 | for i, rect in enumerate(rects): 33 | node2id[(image_name, i, "u")] = nextid 34 | node2id[(image_name, i, "v")] = nextid + 1 35 | nextid += 2 36 | return node2id 37 | 38 | 39 | def map_id2node(detections): 40 | id2node = {} 41 | id2node[0] = "source" 42 | id2node[1] = "sink" 43 | 44 | nextid = 2 45 | for image_name, rects in sorted(detections.items()): 46 | for i, rect in enumerate(rects): 47 | id2node[nextid] = (image_name, i, "u") 48 | id2node[nextid + 1] = (image_name, i, "v") 49 | nextid += 2 50 | return id2node 51 | 52 | 53 | def map_name2id(tags): 54 | name2id = {} 55 | for frame_id, (image_name, rects) in enumerate(sorted(tags.items())): 56 | name2id[image_name] = frame_id 57 | return name2id 58 | 59 | 60 | def map_id2name(tags): 61 | id2name = {} 62 | for frame_id, (image_name, rects) in enumerate(sorted(tags.items())): 63 | id2name[frame_id] = image_name 64 | return id2name 65 | 66 | 67 | def calc_overlap(bb1, bb2): 68 | bi = (max(bb1[0], bb2[0]), max(bb1[1], bb2[1]), min(bb1[2], bb2[2]), min(bb1[3], bb2[3])) 69 | iw = bi[2] - bi[0] + 1 70 | ih = bi[3] - bi[1] + 1 71 | if iw > 0 and ih > 0: 72 | ua = (bb1[2] - bb1[0] + 1) * (bb1[3] - bb1[1] + 1) + (bb2[2] - bb2[0] + 1) * (bb2[3] - bb2[1] + 1) - iw * ih 73 | return iw * ih / ua 74 | else: 75 | return 0.0 76 | 
--------------------------------------------------------------------------------