├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── chainer_sort ├── __init__.py ├── datasets │ ├── __init__.py │ └── mot │ │ ├── __init__.py │ │ ├── mot_dataset.py │ │ └── mot_utils.py ├── models │ ├── __init__.py │ └── sort_multi_object_tracking.py ├── trackers │ ├── __init__.py │ ├── kalman_bbox_tracker.py │ └── sort_multi_bbox_tracker.py ├── utils.py └── visualizations │ ├── __init__.py │ └── vis_tracking_bbox.py ├── examples └── mot │ └── demo.py ├── requirements.txt ├── setup.py └── static └── sort_faster_rcnn_example.gif /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | cache: 4 | - ccache 5 | - pip 6 | 7 | jobs: 8 | include: 9 | - os: linux 10 | python: 2.7 11 | - os: linux 12 | python: 3.6 13 | - os: linux 14 | python: 3.7 15 | - os: linux 16 | python: 3.8 17 | allow_failures: 18 | - os: linux 19 | python: 2.7 20 | 21 | before_install: 22 | - wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh 23 | - bash miniconda.sh -b -p $HOME/miniconda 24 | - export PATH="$HOME/miniconda/bin:$PATH" 25 | - conda config --set always_yes yes --set changeps1 no 26 | - conda update -q conda 27 | - conda info -a 28 | 29 | install: 30 | - conda create --name=chainer-sort python=$TRAVIS_PYTHON_VERSION -q -y 31 | - source activate chainer-sort 32 | - pip 
install -r requirements.txt 33 | - pip install -e . 34 | 35 | before_script: 36 | - pip install flake8 37 | - pip install hacking 38 | 39 | script: 40 | - flake8 . 41 | 42 | notifications: 43 | email: false 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Shingo Kitagawa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt README.md 2 | recursive-include static * 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | chainer-sort - SORT 2 | =================== 3 | ![Build Status](https://travis-ci.org/knorth55/chainer-sort.svg?branch=master) 4 | 5 | This is an implementation of [Simple Online and Realtime Tracking (SORT)](https://arxiv.org/abs/1602.00763) for Chainer and ChainerCV. 6 | 7 | This repository provides a MOT dataset loader, a SORT tracker class, and SORT examples with Faster R-CNN and SSD detectors. 8 | 9 | [\[arXiv\]](https://arxiv.org/abs/1602.00763), [\[Original repo\]](https://github.com/abewley/sort) 10 | 11 | 12 | 13 | Notes 14 | ----- 15 | 16 | - This repository implements [SORT](https://arxiv.org/abs/1602.00763), not [DeepSORT](https://arxiv.org/abs/1703.07402). 17 | - SORT itself is based on a Kalman filter and the Hungarian algorithm, and does not use deep learning. 18 | - In this repository, deep learning detectors (Faster R-CNN and SSD) are used for the object detection step. 19 | 20 | Requirements 21 | ------------ 22 | 23 | - [Chainer](https://github.com/chainer/chainer) 24 | - [ChainerCV](https://github.com/chainer/chainercv) 25 | - [FilterPy](https://github.com/rlabbe/filterpy) 26 | 27 | Installation 28 | ------------ 29 | 30 | We recommend using [Anaconda](https://anaconda.org/). 31 | 32 | ```bash 33 | # Install the requirements 34 | conda create -n chainer-sort python=2.7 35 | source activate chainer-sort 36 | 37 | git clone https://github.com/knorth55/chainer-sort 38 | cd chainer-sort/ 39 | pip install -e .
40 | ``` 41 | 42 | Demo 43 | ---- 44 | 45 | ```bash 46 | cd examples/mot/ 47 | python demo.py 48 | ``` 49 | -------------------------------------------------------------------------------- /chainer_sort/__init__.py: -------------------------------------------------------------------------------- 1 | import pkg_resources 2 | 3 | from chainer_sort import datasets # NOQA 4 | from chainer_sort import models # NOQA 5 | from chainer_sort import trackers # NOQA 6 | from chainer_sort import utils # NOQA 7 | from chainer_sort import visualizations # NOQA 8 | 9 | 10 | __version__ = pkg_resources.get_distribution('chainer_sort').version 11 | -------------------------------------------------------------------------------- /chainer_sort/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from chainer_sort.datasets.mot.mot_dataset import MOTDataset # NOQA 2 | from chainer_sort.datasets.mot.mot_utils import mot_map_names # NOQA 3 | from chainer_sort.datasets.mot.mot_utils import mot_sequence_names # NOQA 4 | -------------------------------------------------------------------------------- /chainer_sort/datasets/mot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knorth55/chainer-sort/a72d4b7f6b88b8e06f9cdfe61e12133b46a75d26/chainer_sort/datasets/mot/__init__.py -------------------------------------------------------------------------------- /chainer_sort/datasets/mot/mot_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import chainer 4 | from chainercv.utils import read_image 5 | 6 | from chainer_sort.datasets.mot import mot_utils 7 | from chainer_sort.datasets.mot.mot_utils import mot_map_names 8 | from chainer_sort.datasets.mot.mot_utils import mot_sequence_names 9 | 10 | 11 | class MOTDataset(chainer.dataset.DatasetMixin): 12 | 13 | def __init__( 14 | self, data_dir='auto', 
year='2015', split='train', sequence='c2', 15 | ): 16 | if split not in ['train', 'val', 'trainval']: 17 | raise ValueError( 18 | 'please pick split from \'train\', \'trainval\', \'val\'') 19 | 20 | if data_dir == 'auto': 21 | data_dir = mot_utils.get_mot(year, split) 22 | 23 | id_list_file = os.path.join( 24 | data_dir, 'annotations/{0}.txt'.format(split)) 25 | ids = [id_.strip() for id_ in open(id_list_file)] 26 | if sequence in mot_map_names: 27 | if (year == '2015' and sequence not in ['c2', 'c3', 'c4']) \ 28 | or (year == '2016' and sequence != 'c5') \ 29 | or (year == '2017' and sequence != 'c9'): 30 | raise ValueError 31 | sequences = mot_utils.get_sequences(split, sequence) 32 | self.ids = [id_ for id_ in ids if id_.split('_')[1] in sequences] 33 | elif sequence.startswith(tuple(mot_sequence_names[year])): 34 | self.ids = [ 35 | id_ for id_ in ids if id_.split('_')[1].startswith(sequence) 36 | ] 37 | else: 38 | raise ValueError 39 | 40 | self.data_dir = data_dir 41 | 42 | self.id2inst_id, self.id2bbox = None, None 43 | if split != 'val': 44 | self.id2inst_id, self.id2bbox = mot_utils.load_gt( 45 | self.data_dir, self.ids) 46 | 47 | def __len__(self): 48 | return len(self.ids) 49 | 50 | def get_example(self, i): 51 | data_id = self.ids[i] 52 | split_d, seq_d, frame = data_id.split('_') 53 | img_file = os.path.join( 54 | self.data_dir, split_d, seq_d, 'img1/{}.jpg'.format(frame)) 55 | img = read_image(img_file, color=True) 56 | if self.id2inst_id is None and self.id2bbox is None: 57 | inst_id, bbox = None, None 58 | else: 59 | inst_id = self.id2inst_id[data_id] 60 | bbox = self.id2bbox[data_id] 61 | return img, bbox, inst_id 62 | -------------------------------------------------------------------------------- /chainer_sort/datasets/mot/mot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from chainer.dataset import download 5 | from chainercv import utils 6 | 7 | 8 | root = 
'pfnet/chainercv/mot' 9 | dev_urls = 'http://motchallenge.net/data/devkit.zip' 10 | urls = { 11 | '2015': 'http://motchallenge.net/data/2DMOT2015.zip', 12 | '2016': 'http://motchallenge.net/data/MOT16.zip', 13 | '2017': 'http://motchallenge.net/data/MOT17.zip', 14 | } 15 | 16 | mot_map_names = ['c2', 'c3', 'c4', 'c5', 'c9'] 17 | mot_sequence_names = { 18 | '2015': [ 19 | 'TUD-Stadtmitte', 20 | 'TUD-Campus', 21 | 'PETS09-S2L1', 22 | 'ETH-Bahnhof', 23 | 'ETH-Sunnyday', 24 | 'ETH-Pedcross2', 25 | 'ADL-Rundle-6', 26 | 'ADL-Rundle-8', 27 | 'KITTI-13', 28 | 'KITTI-17', 29 | 'Venice-2', 30 | 'TUD-Crossing', 31 | 'PETS09-S2L2', 32 | 'ETH-Jelmoli', 33 | 'ETH-Linthescher', 34 | 'ETH-Crossing', 35 | 'AVG-TownCentre', 36 | 'ADL-Rundle-1', 37 | 'ADL-Rundle-3', 38 | 'KITTI-16', 39 | 'KITTI-19', 40 | 'Venice-1', 41 | ] 42 | } 43 | 44 | mot_sequence_names['2016'] = ['MOT16-{0:02d}'.format(i) for i in range(1, 15)] 45 | mot_sequence_names['2017'] = ['MOT17-{0:02d}'.format(i) for i in range(1, 15)] 46 | 47 | 48 | def load_gt(data_dir, data_ids): 49 | gt_dict = {} 50 | id2bbox = {} 51 | id2inst_id = {} 52 | 53 | for data_id in data_ids: 54 | split_d, seq_d, img_name = data_id.split('_') 55 | frame_id = int(img_name) 56 | if split_d != 'train': 57 | id2inst_id[data_id] = None 58 | id2bbox[data_id] = None 59 | continue 60 | if seq_d not in gt_dict: 61 | gt_dict[seq_d] = _load_gt(data_dir, split_d, seq_d) 62 | if frame_id in gt_dict[seq_d]: 63 | inst_id, bbox = gt_dict[seq_d][frame_id] 64 | else: 65 | inst_id, bbox = np.empty((0, )), np.empty((0, 4)) 66 | id2inst_id[data_id] = inst_id 67 | id2bbox[data_id] = bbox 68 | return id2inst_id, id2bbox 69 | 70 | 71 | def _load_gt(data_dir, split_d, seq_d): 72 | gt_path = os.path.join(data_dir, split_d, seq_d, 'gt/gt.txt') 73 | 74 | gt_dict = {} 75 | gt_data = [map(float, x.split(',')[:6]) for x in open(gt_path).readlines()] 76 | gt_data = np.array(gt_data) 77 | for frame_id in np.unique(gt_data[:, 0].astype(np.int32)): 78 | gt_d = 
gt_data[gt_data[:, 0] == frame_id] 79 | inst_id = gt_d[:, 1].astype(np.int32) 80 | bbox = gt_d[:, 2:].astype(np.float32) 81 | bbox[:, 2:4] += bbox[:, :2] 82 | bbox = bbox[:, [1, 0, 3, 2]] 83 | gt_dict[frame_id] = (inst_id, bbox) 84 | return gt_dict 85 | 86 | 87 | def get_sequences(split, map_name): 88 | if split == 'train': 89 | splits = ['train'] 90 | elif split == 'val': 91 | splits = ['test'] 92 | elif split == 'trainval': 93 | if map_name in ['c2', 'c3', 'c4']: 94 | splits = ['train', 'test'] 95 | else: 96 | splits = ['all'] 97 | else: 98 | raise ValueError 99 | 100 | seq_map = [] 101 | data_root = download.get_dataset_directory(root) 102 | seq_path = os.path.join( 103 | data_root, 'motchallenge-devkit/motchallenge/seqmaps') 104 | for sp in splits: 105 | seqmap_path = os.path.join( 106 | seq_path, '{0}-{1}.txt'.format(map_name, sp)) 107 | with open(seqmap_path, 'r') as f: 108 | seq_m = f.read().split('\n') 109 | seq_map.extend(seq_m[1:-1]) 110 | if map_name == 'c9': 111 | seq_map = ['{}-DPM'.format(x) for x in seq_map] 112 | return seq_map 113 | 114 | 115 | def get_mot(year, split): 116 | if year not in urls: 117 | raise ValueError 118 | 119 | data_root = download.get_dataset_directory(root) 120 | if year == '2015': 121 | mot_dirname = '2DMOT{}'.format(year) 122 | else: 123 | mot_dirname = 'MOT{}'.format(year[2:]) 124 | base_path = os.path.join(data_root, mot_dirname) 125 | anno_path = os.path.join(base_path, 'annotations') 126 | anno_txt_path = os.path.join(anno_path, '{}.txt'.format(split)) 127 | 128 | if not os.path.exists(base_path): 129 | download_file_path = utils.cached_download(urls[year]) 130 | ext = os.path.splitext(urls[year])[1] 131 | utils.extractall(download_file_path, data_root, ext) 132 | 133 | if not os.path.exists(os.path.join(data_root, 'motchallenge-devkit')): 134 | download_devfile_path = utils.cached_download(dev_urls) 135 | dev_ext = os.path.splitext(dev_urls)[1] 136 | utils.extractall(download_devfile_path, data_root, dev_ext) 137 | 138 
| if not os.path.exists(anno_path): 139 | os.mkdir(anno_path) 140 | if split == 'train': 141 | split_dirs = ['train'] 142 | elif split == 'val': 143 | split_dirs = ['test'] 144 | elif split == 'trainval': 145 | split_dirs = ['train', 'test'] 146 | else: 147 | raise ValueError 148 | 149 | data_ids = [] 150 | for split_d in split_dirs: 151 | seq_dirs = sorted(os.listdir(os.path.join(base_path, split_d))) 152 | for seq_d in seq_dirs: 153 | img_dir = os.path.join(base_path, split_d, seq_d, 'img1') 154 | img_names = sorted(os.listdir(img_dir)) 155 | for img_name in img_names: 156 | data_id = '{0}_{1}_{2}'.format( 157 | split_d, seq_d, img_name.split('.')[0]) 158 | data_ids.append(data_id) 159 | 160 | with open(anno_txt_path, 'w') as anno_f: 161 | anno_f.write('\n'.join(data_ids)) 162 | 163 | return base_path 164 | -------------------------------------------------------------------------------- /chainer_sort/models/__init__.py: -------------------------------------------------------------------------------- 1 | from chainer_sort.models.sort_multi_object_tracking import SORTMultiObjectTracking # NOQA 2 | -------------------------------------------------------------------------------- /chainer_sort/models/sort_multi_object_tracking.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import chainer 4 | 5 | from chainer_sort.trackers import SORTMultiBboxTracker 6 | 7 | 8 | class SORTMultiObjectTracking(object): 9 | 10 | def __init__(self, dectector, 11 | detector_label_names, tracking_label_names=None): 12 | self.bbox_tracker = SORTMultiBboxTracker() 13 | self.detector = dectector 14 | self.detector_label_names = detector_label_names 15 | self.tracking_label_names = tracking_label_names 16 | 17 | def predict(self, img): 18 | if len(img) > 1: 19 | raise ValueError 20 | bboxes, labels, scores = self.detector.predict(img) 21 | bbox, label, score = bboxes[0], labels[0], scores[0] 22 | bbox = 
chainer.cuda.to_cpu(bbox) 23 | label = chainer.cuda.to_cpu(label) 24 | score = chainer.cuda.to_cpu(score) 25 | 26 | det_bbox = [] 27 | det_label = [] 28 | det_score = [] 29 | for bb, lbl, sc in zip(bbox, label, score): 30 | if self.detector_label_names[lbl] in self.tracking_label_names: 31 | det_bbox.append(bb[None]) 32 | det_label.append(lbl) 33 | det_score.append(sc) 34 | if len(det_bbox) > 0: 35 | det_bbox = np.concatenate(det_bbox) 36 | else: 37 | det_bbox = np.array(det_bbox).reshape((0, 4)) 38 | det_label = np.array(det_label) 39 | det_score = np.array(det_score) 40 | 41 | trk_indices, trk_bbox, trk_inst_id = self.bbox_tracker.update(det_bbox) 42 | trk_label = det_label[trk_indices] 43 | trk_score = det_score[trk_indices] 44 | return trk_bbox[None], trk_label[None], \ 45 | trk_score[None], trk_inst_id[None] 46 | -------------------------------------------------------------------------------- /chainer_sort/trackers/__init__.py: -------------------------------------------------------------------------------- 1 | from chainer_sort.trackers.kalman_bbox_tracker import KalmanBboxTracker # NOQA 2 | from chainer_sort.trackers.sort_multi_bbox_tracker import SORTMultiBboxTracker # NOQA 3 | -------------------------------------------------------------------------------- /chainer_sort/trackers/kalman_bbox_tracker.py: -------------------------------------------------------------------------------- 1 | # Modified work by Shingo Kitagawa (@knorth55) 2 | # 3 | # Original works of SORT from https://github.com/abewley/sort 4 | # --------------------------------------------------------------------- 5 | # Copyright (C) 2016 Alex Bewley alex@dynamicdetection.com 6 | # 7 | # This program is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 
11 | # 12 | # This program is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with this program. If not, see . 19 | # --------------------------------------------------------------------- 20 | 21 | 22 | import numpy as np 23 | 24 | from filterpy.kalman import KalmanFilter 25 | 26 | from chainer_sort.utils import bbox2z_bbox 27 | from chainer_sort.utils import z_bbox2bbox 28 | 29 | 30 | class KalmanBboxTracker(object): 31 | 32 | def __init__(self, bbox): 33 | if len(bbox) > 1: 34 | raise ValueError 35 | 36 | self.filter = KalmanFilter(dim_x=7, dim_z=4) 37 | self.filter.F = np.array( 38 | [[1, 0, 0, 0, 1, 0, 0], 39 | [0, 1, 0, 0, 0, 1, 0], 40 | [0, 0, 1, 0, 0, 0, 1], 41 | [0, 0, 0, 1, 0, 0, 0], 42 | [0, 0, 0, 0, 1, 0, 0], 43 | [0, 0, 0, 0, 0, 1, 0], 44 | [0, 0, 0, 0, 0, 0, 1]]) 45 | self.filter.H = np.array( 46 | [[1, 0, 0, 0, 0, 0, 0], 47 | [0, 1, 0, 0, 0, 0, 0], 48 | [0, 0, 1, 0, 0, 0, 0], 49 | [0, 0, 0, 1, 0, 0, 0]]) 50 | 51 | self.filter.R[2:, 2:] *= 10. 52 | self.filter.P[4:, 4:] *= 1000. 53 | self.filter.P *= 10. 54 | self.filter.Q[-1, -1] *= 0.01 55 | self.filter.Q[4:, 4:] *= 0.01 56 | self.filter.x[:4] = bbox2z_bbox(bbox).T 57 | 58 | self.time_since_update = 0 59 | self.hit_streak = 0 60 | 61 | def get_state(self): 62 | bbox = z_bbox2bbox(self.filter.x[:4].T) 63 | return bbox 64 | 65 | def update(self, bbox): 66 | if len(bbox) > 1: 67 | raise ValueError 68 | 69 | self.time_since_update = 0 70 | self.hit_streak += 1 71 | self.filter.update(bbox2z_bbox(bbox).T) 72 | 73 | def predict(self): 74 | if (self.filter.x[2] + self.filter.x[6]) <= 0: 75 | self.filter.x[6] = 0. 
76 | self.filter.predict() 77 | if self.time_since_update > 0: 78 | self.hit_streak = 0 79 | self.time_since_update += 1 80 | bbox = self.get_state() 81 | return bbox 82 | -------------------------------------------------------------------------------- /chainer_sort/trackers/sort_multi_bbox_tracker.py: -------------------------------------------------------------------------------- 1 | # Modified work by Shingo Kitagawa (@knorth55) 2 | # 3 | # Original works of SORT from https://github.com/abewley/sort 4 | # --------------------------------------------------------------------- 5 | # Copyright (C) 2016 Alex Bewley alex@dynamicdetection.com 6 | # 7 | # This program is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # This program is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with this program. If not, see . 
19 | # --------------------------------------------------------------------- 20 | 21 | 22 | import numpy as np 23 | 24 | from chainer_sort.trackers import KalmanBboxTracker 25 | from chainer_sort.utils import iou_linear_assignment 26 | 27 | 28 | class SORTMultiBboxTracker(object): 29 | 30 | def __init__(self, max_age=1, min_hit_streak=3): 31 | self.max_age = max_age 32 | self.min_hit_streak = min_hit_streak 33 | self.trackers = [] 34 | self.frame_count = 0 35 | self.tracker_num = 0 36 | self.inst_ids = [] 37 | 38 | def update(self, det_bboxes): 39 | self.frame_count += 1 40 | pred_bboxes = [] 41 | valid_trackers = [] 42 | for i, tracker in enumerate(self.trackers): 43 | pred_bbox = tracker.predict() 44 | if np.all(~np.isnan(pred_bbox)) and np.all(~np.isinf(pred_bbox)): 45 | valid_trackers.append(tracker) 46 | pred_bboxes.append(pred_bbox) 47 | self.trackers = valid_trackers 48 | 49 | if len(pred_bboxes) > 0: 50 | pred_bboxes = np.concatenate(pred_bboxes) 51 | else: 52 | pred_bboxes = np.array(pred_bboxes).reshape((-1, 4)) 53 | assert len(self.trackers) == len(pred_bboxes) 54 | 55 | matched_det_indices, matched_pred_indices = iou_linear_assignment( 56 | det_bboxes, pred_bboxes) 57 | 58 | # update matched trackers 59 | # create new trackers for unmatched detections 60 | new_trackers = [] 61 | trk_det_indices = [] 62 | trk_bboxes = [] 63 | trk_inst_ids = [] 64 | for det_index, det_bbox in enumerate(det_bboxes): 65 | # matched 66 | if det_index in matched_det_indices: 67 | pred_index = matched_pred_indices[ 68 | matched_det_indices == det_index] 69 | tracker = self.trackers[int(pred_index)] 70 | tracker.update(det_bbox[None]) 71 | # not matched 72 | else: 73 | tracker = KalmanBboxTracker(det_bbox[None]) 74 | tracker.id = self.tracker_num 75 | self.tracker_num += 1 76 | 77 | if tracker.time_since_update < 1 \ 78 | and (tracker.hit_streak >= self.min_hit_streak 79 | or self.frame_count <= self.min_hit_streak): 80 | trk_det_indices.append(det_index) 81 | trk_bbox = 
tracker.get_state()[0] 82 | trk_bboxes.append(trk_bbox) 83 | if tracker.id in self.inst_ids: 84 | trk_inst_id = self.inst_ids.index(tracker.id) 85 | else: 86 | self.inst_ids.append(tracker.id) 87 | trk_inst_id = len(self.inst_ids) 88 | trk_inst_ids.append(trk_inst_id) 89 | 90 | if tracker.time_since_update <= self.max_age: 91 | new_trackers.append(tracker) 92 | self.trackers = new_trackers 93 | 94 | trk_det_indices = np.array(trk_det_indices, dtype=int) 95 | trk_bboxes = np.array(trk_bboxes, dtype=float).reshape((-1, 4)) 96 | trk_inst_ids = np.array(trk_inst_ids, dtype=int) 97 | return trk_det_indices, trk_bboxes, trk_inst_ids 98 | -------------------------------------------------------------------------------- /chainer_sort/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from chainercv.utils import bbox_iou 4 | import numpy as np 5 | from sklearn.utils.linear_assignment_ import linear_assignment 6 | 7 | 8 | def bbox2z_bbox(bbox): 9 | if bbox.shape[1] != 4: 10 | raise ValueError 11 | 12 | y_center = (bbox[:, 2] + bbox[:, 0]) / 2 13 | x_center = (bbox[:, 3] + bbox[:, 1]) / 2 14 | size = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1]) 15 | ratio = (bbox[:, 2] - bbox[:, 0]) / (bbox[:, 3] - bbox[:, 1]) 16 | z_bbox = np.concatenate( 17 | (y_center[:, None], x_center[:, None], 18 | size[:, None], ratio[:, None]), 19 | axis=1) 20 | return z_bbox 21 | 22 | 23 | def z_bbox2bbox(z_bbox): 24 | if z_bbox.shape[1] != 4: 25 | raise ValueError 26 | 27 | height = np.sqrt(z_bbox[:, 2] * z_bbox[:, 3]) 28 | width = z_bbox[:, 2] / height 29 | y_min = z_bbox[:, 0] - (height / 2) 30 | x_min = z_bbox[:, 1] - (width / 2) 31 | y_max = z_bbox[:, 0] + (height / 2) 32 | x_max = z_bbox[:, 1] + (width / 2) 33 | bbox = np.concatenate( 34 | (y_min[:, None], x_min[:, None], 35 | y_max[:, None], x_max[:, None]), 36 | axis=1) 37 | return bbox 38 | 39 | 40 | def iou_linear_assignment(bbox_a, bbox_b): 41 | iou 
= bbox_iou(bbox_a, bbox_b) 42 | indices = linear_assignment(-iou) 43 | return indices[:, 0], indices[:, 1] 44 | -------------------------------------------------------------------------------- /chainer_sort/visualizations/__init__.py: -------------------------------------------------------------------------------- 1 | from chainer_sort.visualizations.vis_tracking_bbox import vis_tracking_bbox # NOQA 2 | -------------------------------------------------------------------------------- /chainer_sort/visualizations/vis_tracking_bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import numpy as np 4 | 5 | from chainercv.visualizations.vis_image import vis_image 6 | 7 | 8 | def _default_cmap(label): 9 | """Color map used in PASCAL VOC""" 10 | r, g, b = 0, 0, 0 11 | i = label 12 | for j in range(8): 13 | if i & (1 << 0): 14 | r |= 1 << (7 - j) 15 | if i & (1 << 1): 16 | g |= 1 << (7 - j) 17 | if i & (1 << 2): 18 | b |= 1 << (7 - j) 19 | i >>= 3 20 | return r, g, b 21 | 22 | 23 | def vis_tracking_bbox( 24 | img, bbox, inst_id, label=None, score=None, 25 | label_names=None, alpha=1.0, ax=None): 26 | 27 | from matplotlib import pyplot as plot 28 | ax = vis_image(img, ax=ax) 29 | 30 | assert len(bbox) == len(inst_id) 31 | if len(bbox) == 0: 32 | return ax 33 | 34 | for i, (bb, inst_i) in enumerate(zip(bbox, inst_id)): 35 | bb = np.round(bb).astype(np.int32) 36 | y_min, x_min, y_max, x_max = bb 37 | color = np.array(_default_cmap(inst_i + 1)) / 255. 
38 | 39 | ax.add_patch(plot.Rectangle( 40 | (x_min, y_min), x_max - x_min, y_max - y_min, 41 | fill=False, edgecolor=color, linewidth=3)) 42 | 43 | caption = [] 44 | caption.append('{}'.format(inst_i)) 45 | 46 | if label is not None and label_names is not None: 47 | lb = label[i] 48 | if not (0 <= lb < len(label_names)): 49 | raise ValueError('No corresponding name is given') 50 | caption.append(label_names[lb]) 51 | if score is not None: 52 | sc = score[i] 53 | caption.append('{:.2f}'.format(sc)) 54 | 55 | ax.text((x_max + x_min) / 2, y_min, 56 | ': '.join(caption), 57 | style='italic', 58 | bbox={'facecolor': color, 'alpha': alpha}, 59 | fontsize=8, color='white') 60 | -------------------------------------------------------------------------------- /examples/mot/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import time 5 | 6 | import chainer 7 | from chainercv.datasets import voc_bbox_label_names 8 | from chainercv.links import FasterRCNNVGG16 9 | from chainercv.links import SSD300 10 | from chainercv.links import SSD512 11 | 12 | from chainer_sort.datasets.mot.mot_utils import get_sequences 13 | from chainer_sort.datasets import MOTDataset 14 | from chainer_sort.models import SORTMultiObjectTracking 15 | from chainer_sort.visualizations import vis_tracking_bbox 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--no-display', action='store_true') 21 | parser.add_argument( 22 | '--model', choices=('ssd300', 'ssd512', 'faster-rcnn-vgg16'), 23 | default='faster-rcnn-vgg16') 24 | parser.add_argument('--gpu', type=int, default=0) 25 | parser.add_argument( 26 | '--pretrained_model', choices=('voc0712', 'voc07'), default='voc0712') 27 | parser.add_argument( 28 | '--year', '-y', choices=('2015', '2016', '2017'), default='2015') 29 | parser.add_argument('--sequence-map', '-s', default=None) 30 | args = 
parser.parse_args() 31 | 32 | if args.sequence_map is None: 33 | if args.year == '2015': 34 | seqmap_name = 'c2-test' 35 | elif args.year == '2016': 36 | seqmap_name = 'c5-test' 37 | elif args.year == '2017': 38 | seqmap_name = 'c9-test' 39 | 40 | map_name, split = seqmap_name.split('-') 41 | if split == 'test': 42 | split = 'val' 43 | sequences = get_sequences(split, map_name) 44 | 45 | if args.model == 'ssd300': 46 | detector = SSD300( 47 | n_fg_class=len(voc_bbox_label_names), 48 | pretrained_model=args.pretrained_model) 49 | elif args.model == 'ssd512': 50 | detector = SSD512( 51 | n_fg_class=len(voc_bbox_label_names), 52 | pretrained_model=args.pretrained_model) 53 | elif args.model == 'faster-rcnn-vgg16': 54 | detector = FasterRCNNVGG16( 55 | n_fg_class=len(voc_bbox_label_names), 56 | pretrained_model=args.pretrained_model) 57 | detector.use_preset('evaluate') 58 | detector.score_thresh = 0.5 59 | 60 | if args.gpu >= 0: 61 | chainer.cuda.get_device_from_id(args.gpu).use() 62 | detector.to_gpu() 63 | 64 | sort_label_names = ['person'] 65 | 66 | if not args.no_display: 67 | plt.ion() 68 | fig = plt.figure() 69 | 70 | for seq in sequences: 71 | if args.no_display: 72 | ax = fig.add_subplot(111, aspect='equal') 73 | 74 | dataset = MOTDataset( 75 | year=args.year, split=split, sequence=seq) 76 | 77 | model = SORTMultiObjectTracking( 78 | detector, voc_bbox_label_names, 79 | sort_label_names) 80 | 81 | print('Sequence: {}'.format(seq)) 82 | cycle_times = [] 83 | for i in range(len(dataset)): 84 | img, _, _ = dataset[i] 85 | start_time = time.time() 86 | bboxes, labels, scores, inst_ids = model.predict([img]) 87 | cycle_time = time.time() - start_time 88 | cycle_times.append(cycle_time) 89 | if args.no_display: 90 | bbox = bboxes[0] 91 | inst_id = inst_ids[0] 92 | label = labels[0] 93 | score = scores[0] 94 | vis_tracking_bbox( 95 | img, bbox, inst_id, label, score, 96 | label_names=voc_bbox_label_names, ax=ax) 97 | fig.canvas.flush_events() 98 | plt.draw() 99 | 
ax.cla() 100 | 101 | cycle_times = np.array(cycle_times) 102 | print('total time: {}'.format(np.sum(cycle_times))) 103 | print('average time: {}'.format(np.average(cycle_times))) 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chainer 2 | chainercv 3 | filterpy 4 | numpy 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | 5 | version = '0.1.0' 6 | 7 | 8 | setup( 9 | name='chainer_sort', 10 | version=version, 11 | packages=find_packages(), 12 | install_requires=open('requirements.txt').readlines(), 13 | description='SORT for chainercv', 14 | long_description=open('README.md').read(), 15 | author='Shingo Kitagawa', 16 | author_email='shingogo.5511@gmail.com', 17 | url='https://github.com/knorth55/chainer-sort', 18 | license='MIT', 19 | keywords='machine-learning', 20 | ) 21 | -------------------------------------------------------------------------------- /static/sort_faster_rcnn_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knorth55/chainer-sort/a72d4b7f6b88b8e06f9cdfe61e12133b46a75d26/static/sort_faster_rcnn_example.gif --------------------------------------------------------------------------------