├── .gitignore
├── .travis.yml
├── LICENSE
├── MANIFEST.in
├── README.md
├── chainer_sort
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── mot
│   │       ├── __init__.py
│   │       ├── mot_dataset.py
│   │       └── mot_utils.py
│   ├── models
│   │   ├── __init__.py
│   │   └── sort_multi_object_tracking.py
│   ├── trackers
│   │   ├── __init__.py
│   │   ├── kalman_bbox_tracker.py
│   │   └── sort_multi_bbox_tracker.py
│   ├── utils.py
│   └── visualizations
│       ├── __init__.py
│       └── vis_tracking_bbox.py
├── examples
│   └── mot
│       └── demo.py
├── requirements.txt
├── setup.py
└── static
    └── sort_faster_rcnn_example.gif
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | cache:
4 |   - ccache
5 |   - pip
6 |
7 | jobs:
8 |   include:
9 |     - os: linux
10 |       python: 2.7
11 |     - os: linux
12 |       python: 3.6
13 |     - os: linux
14 |       python: 3.7
15 |     - os: linux
16 |       python: 3.8
17 |   allow_failures:
18 |     - os: linux
19 |       python: 2.7
20 |
21 | before_install:
22 |   - wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh
23 |   - bash miniconda.sh -b -p $HOME/miniconda
24 |   - export PATH="$HOME/miniconda/bin:$PATH"
25 |   - conda config --set always_yes yes --set changeps1 no
26 |   - conda update -q conda
27 |   - conda info -a
28 |
29 | install:
30 |   - conda create --name=chainer-sort python=$TRAVIS_PYTHON_VERSION -q -y
31 |   - source activate chainer-sort
32 |   - pip install -r requirements.txt
33 |   - pip install -e .
34 |
35 | before_script:
36 |   - pip install flake8
37 |   - pip install hacking
38 |
39 | script:
40 |   - flake8 .
41 |
42 | notifications:
43 |   email: false
44 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Shingo Kitagawa
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt README.md
2 | recursive-include static *
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | chainer-sort - SORT
2 | ===================
3 | 
4 |
5 | This is an implementation of [Simple Online and Realtime Tracking](https://arxiv.org/abs/1602.00763) (SORT) for Chainer and ChainerCV.
6 |
7 | This repository provides a MOT dataset loader, a SORT tracker class, and SORT examples that use Faster R-CNN and SSD detectors.
8 |
9 | [\[arXiv\]](https://arxiv.org/abs/1602.00763), [\[Original repo\]](https://github.com/abewley/sort)
10 |
11 |
12 |
13 | Notes
14 | -----
15 |
16 | - This repository implements [SORT](https://arxiv.org/abs/1602.00763), not [DeepSORT](https://arxiv.org/abs/1703.07402).
17 | - SORT itself is built on a Kalman filter and the Hungarian algorithm and does not use deep learning.
18 | - Here, deep learning is used only for the object detection stage (Faster R-CNN and SSD).
19 |
20 | Requirements
21 | ------------
22 |
23 | - [Chainer](https://github.com/chainer/chainer)
24 | - [ChainerCV](https://github.com/chainer/chainercv)
25 | - [FilterPy](https://github.com/rlabbe/filterpy)
26 |
27 | Installation
28 | ------------
29 |
30 | We recommend using [Anaconda](https://anaconda.org/).
31 |
32 | ```bash
33 | # Requirement installation
34 | conda create -n chainer-sort python=2.7
35 | source activate chainer-sort
36 |
37 | git clone https://github.com/knorth55/chainer-sort
38 | cd chainer-sort/
39 | pip install -e .
40 | ```
41 |
42 | Demo
43 | ----
44 |
45 | ```bash
46 | cd examples/mot/
47 | python demo.py
48 | ```
49 |
--------------------------------------------------------------------------------
/chainer_sort/__init__.py:
--------------------------------------------------------------------------------
1 | import pkg_resources
2 |
3 | from chainer_sort import datasets # NOQA
4 | from chainer_sort import models # NOQA
5 | from chainer_sort import trackers # NOQA
6 | from chainer_sort import utils # NOQA
7 | from chainer_sort import visualizations # NOQA
8 |
9 |
10 | __version__ = pkg_resources.get_distribution('chainer_sort').version
11 |
--------------------------------------------------------------------------------
/chainer_sort/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from chainer_sort.datasets.mot.mot_dataset import MOTDataset # NOQA
2 | from chainer_sort.datasets.mot.mot_utils import mot_map_names # NOQA
3 | from chainer_sort.datasets.mot.mot_utils import mot_sequence_names # NOQA
4 |
--------------------------------------------------------------------------------
/chainer_sort/datasets/mot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knorth55/chainer-sort/a72d4b7f6b88b8e06f9cdfe61e12133b46a75d26/chainer_sort/datasets/mot/__init__.py
--------------------------------------------------------------------------------
/chainer_sort/datasets/mot/mot_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import chainer
4 | from chainercv.utils import read_image
5 |
6 | from chainer_sort.datasets.mot import mot_utils
7 | from chainer_sort.datasets.mot.mot_utils import mot_map_names
8 | from chainer_sort.datasets.mot.mot_utils import mot_sequence_names
9 |
10 |
11 | class MOTDataset(chainer.dataset.DatasetMixin):
12 |
13 |     def __init__(
14 |             self, data_dir='auto', year='2015', split='train', sequence='c2',
15 |     ):
16 |         if split not in ['train', 'val', 'trainval']:
17 |             raise ValueError(
18 |                 'please pick split from \'train\', \'trainval\', \'val\'')
19 |
20 |         if data_dir == 'auto':
21 |             data_dir = mot_utils.get_mot(year, split)
22 |
23 |         id_list_file = os.path.join(
24 |             data_dir, 'annotations/{0}.txt'.format(split))
25 |         with open(id_list_file) as f:
26 |             ids = [id_.strip() for id_ in f]
27 |         if sequence in mot_map_names:
28 |             if (year == '2015' and sequence not in ['c2', 'c3', 'c4']) \
29 |                     or (year == '2016' and sequence != 'c5') \
30 |                     or (year == '2017' and sequence != 'c9'):
31 |                 raise ValueError(
32 |                     'sequence map {0} is not available for year {1}'.format(
33 |                         sequence, year))
34 |             sequences = mot_utils.get_sequences(split, sequence)
35 |             self.ids = [id_ for id_ in ids if id_.split('_')[1] in sequences]
36 |         elif sequence.startswith(tuple(mot_sequence_names[year])):
37 |             self.ids = [
38 |                 id_ for id_ in ids if id_.split('_')[1].startswith(sequence)
39 |             ]
40 |         else:
41 |             raise ValueError('unknown sequence: {}'.format(sequence))
42 |
43 |         self.data_dir = data_dir
44 |
45 |         self.id2inst_id, self.id2bbox = None, None
46 |         if split != 'val':
47 |             self.id2inst_id, self.id2bbox = mot_utils.load_gt(
48 |                 self.data_dir, self.ids)
49 |
50 |     def __len__(self):
51 |         return len(self.ids)
52 |
53 |     def get_example(self, i):
54 |         data_id = self.ids[i]
55 |         split_d, seq_d, frame = data_id.split('_')
56 |         img_file = os.path.join(
57 |             self.data_dir, split_d, seq_d, 'img1/{}.jpg'.format(frame))
58 |         img = read_image(img_file, color=True)
59 |         if self.id2inst_id is None and self.id2bbox is None:
60 |             inst_id, bbox = None, None
61 |         else:
62 |             inst_id = self.id2inst_id[data_id]
63 |             bbox = self.id2bbox[data_id]
64 |         return img, bbox, inst_id
65 |
--------------------------------------------------------------------------------
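`MOTDataset` identifies every frame by a data id of the form `{split}_{sequence}_{frame}`, which `get_mot` writes to `annotations/{split}.txt` and `get_example` splits back apart. The sketch below restates that convention in plain Python. The helper names `make_data_id`, `parse_data_id` and `select_ids` are hypothetical, and, like the dataset code, the sketch assumes sequence names never contain an underscore.

```python
def make_data_id(split_dir, seq_dir, img_name):
    """Build an id such as 'train_TUD-Campus_000001' from its parts."""
    # Drop the file extension, mirroring img_name.split('.')[0] in get_mot.
    return '{0}_{1}_{2}'.format(split_dir, seq_dir, img_name.split('.')[0])


def parse_data_id(data_id):
    """Split an id back into (split_dir, seq_dir, frame)."""
    split_dir, seq_dir, frame = data_id.split('_')
    return split_dir, seq_dir, frame


def select_ids(ids, sequence_prefix):
    """Keep ids whose sequence name starts with the given prefix."""
    return [id_ for id_ in ids
            if id_.split('_')[1].startswith(sequence_prefix)]
```

Because the filter is a prefix match, a name such as `MOT17-01` also selects `MOT17-01-DPM`, which is what the `c9` sequence map relies on.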
/chainer_sort/datasets/mot/mot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | from chainer.dataset import download
5 | from chainercv import utils
6 |
7 |
8 | root = 'pfnet/chainercv/mot'
9 | dev_urls = 'http://motchallenge.net/data/devkit.zip'
10 | urls = {
11 |     '2015': 'http://motchallenge.net/data/2DMOT2015.zip',
12 |     '2016': 'http://motchallenge.net/data/MOT16.zip',
13 |     '2017': 'http://motchallenge.net/data/MOT17.zip',
14 | }
15 |
16 | mot_map_names = ['c2', 'c3', 'c4', 'c5', 'c9']
17 | mot_sequence_names = {
18 |     '2015': [
19 |         'TUD-Stadtmitte',
20 |         'TUD-Campus',
21 |         'PETS09-S2L1',
22 |         'ETH-Bahnhof',
23 |         'ETH-Sunnyday',
24 |         'ETH-Pedcross2',
25 |         'ADL-Rundle-6',
26 |         'ADL-Rundle-8',
27 |         'KITTI-13',
28 |         'KITTI-17',
29 |         'Venice-2',
30 |         'TUD-Crossing',
31 |         'PETS09-S2L2',
32 |         'ETH-Jelmoli',
33 |         'ETH-Linthescher',
34 |         'ETH-Crossing',
35 |         'AVG-TownCentre',
36 |         'ADL-Rundle-1',
37 |         'ADL-Rundle-3',
38 |         'KITTI-16',
39 |         'KITTI-19',
40 |         'Venice-1',
41 |     ]
42 | }
43 |
44 | mot_sequence_names['2016'] = ['MOT16-{0:02d}'.format(i) for i in range(1, 15)]
45 | mot_sequence_names['2017'] = ['MOT17-{0:02d}'.format(i) for i in range(1, 15)]
46 |
47 |
48 | def load_gt(data_dir, data_ids):
49 |     gt_dict = {}
50 |     id2bbox = {}
51 |     id2inst_id = {}
52 |
53 |     for data_id in data_ids:
54 |         split_d, seq_d, img_name = data_id.split('_')
55 |         frame_id = int(img_name)
56 |         if split_d != 'train':
57 |             id2inst_id[data_id] = None
58 |             id2bbox[data_id] = None
59 |             continue
60 |         if seq_d not in gt_dict:
61 |             gt_dict[seq_d] = _load_gt(data_dir, split_d, seq_d)
62 |         if frame_id in gt_dict[seq_d]:
63 |             inst_id, bbox = gt_dict[seq_d][frame_id]
64 |         else:
65 |             inst_id = np.empty((0, ), dtype=np.int32)
66 |             bbox = np.empty((0, 4), dtype=np.float32)
67 |         id2inst_id[data_id] = inst_id
68 |         id2bbox[data_id] = bbox
69 |     return id2inst_id, id2bbox
70 |
71 |
72 | def _load_gt(data_dir, split_d, seq_d):
73 |     gt_path = os.path.join(data_dir, split_d, seq_d, 'gt/gt.txt')
74 |
75 |     gt_dict = {}
76 |     # columns: frame, id, bb_left, bb_top, bb_width, bb_height, ...
77 |     with open(gt_path) as f:
78 |         gt_data = [[float(v) for v in x.split(',')[:6]]
79 |                    for x in f if x.strip()]
80 |     gt_data = np.array(gt_data)
81 |     for frame_id in np.unique(gt_data[:, 0].astype(np.int32)):
82 |         gt_d = gt_data[gt_data[:, 0] == frame_id]
83 |         inst_id = gt_d[:, 1].astype(np.int32)
84 |         bbox = gt_d[:, 2:].astype(np.float32)
85 |         # (x_min, y_min, w, h) -> (x_min, y_min, x_max, y_max)
86 |         bbox[:, 2:4] += bbox[:, :2]
87 |         # -> (y_min, x_min, y_max, x_max), the ChainerCV box order
88 |         bbox = bbox[:, [1, 0, 3, 2]]
89 |         gt_dict[frame_id] = (inst_id, bbox)
90 |     return gt_dict
91 |
92 |
93 | def get_sequences(split, map_name):
94 |     if split == 'train':
95 |         splits = ['train']
96 |     elif split == 'val':
97 |         splits = ['test']
98 |     elif split == 'trainval':
99 |         if map_name in ['c2', 'c3', 'c4']:
100 |             splits = ['train', 'test']
101 |         else:
102 |             splits = ['all']
103 |     else:
104 |         raise ValueError('unknown split: {}'.format(split))
105 |
106 |     seq_map = []
107 |     data_root = download.get_dataset_directory(root)
108 |     seq_path = os.path.join(
109 |         data_root, 'motchallenge-devkit/motchallenge/seqmaps')
110 |     for sp in splits:
111 |         seqmap_path = os.path.join(
112 |             seq_path, '{0}-{1}.txt'.format(map_name, sp))
113 |         with open(seqmap_path, 'r') as f:
114 |             seq_m = f.read().split('\n')
115 |         seq_map.extend(seq_m[1:-1])
116 |     if map_name == 'c9':
117 |         seq_map = ['{}-DPM'.format(x) for x in seq_map]
118 |     return seq_map
119 |
120 |
121 | def get_mot(year, split):
122 |     if year not in urls:
123 |         raise ValueError('unknown year: {}'.format(year))
124 |
125 |     data_root = download.get_dataset_directory(root)
126 |     if year == '2015':
127 |         mot_dirname = '2DMOT{}'.format(year)
128 |     else:
129 |         mot_dirname = 'MOT{}'.format(year[2:])
130 |     base_path = os.path.join(data_root, mot_dirname)
131 |     anno_path = os.path.join(base_path, 'annotations')
132 |     anno_txt_path = os.path.join(anno_path, '{}.txt'.format(split))
133 |
134 |     if not os.path.exists(base_path):
135 |         download_file_path = utils.cached_download(urls[year])
136 |         ext = os.path.splitext(urls[year])[1]
137 |         utils.extractall(download_file_path, data_root, ext)
138 |
139 |     if not os.path.exists(os.path.join(data_root, 'motchallenge-devkit')):
140 |         download_devfile_path = utils.cached_download(dev_urls)
141 |         dev_ext = os.path.splitext(dev_urls)[1]
142 |         utils.extractall(download_devfile_path, data_root, dev_ext)
143 |
144 |     if not os.path.exists(anno_path):
145 |         os.mkdir(anno_path)
146 |
147 |     # write the id list of each split on first use, even if the
148 |     # annotations directory was already created for another split
149 |     if not os.path.exists(anno_txt_path):
150 |         if split == 'train':
151 |             split_dirs = ['train']
152 |         elif split == 'val':
153 |             split_dirs = ['test']
154 |         elif split == 'trainval':
155 |             split_dirs = ['train', 'test']
156 |         else:
157 |             raise ValueError('unknown split: {}'.format(split))
158 |
159 |         data_ids = []
160 |         for split_d in split_dirs:
161 |             seq_dirs = sorted(os.listdir(os.path.join(base_path, split_d)))
162 |             for seq_d in seq_dirs:
163 |                 img_dir = os.path.join(base_path, split_d, seq_d, 'img1')
164 |                 img_names = sorted(os.listdir(img_dir))
165 |                 for img_name in img_names:
166 |                     data_id = '{0}_{1}_{2}'.format(
167 |                         split_d, seq_d, img_name.split('.')[0])
168 |                     data_ids.append(data_id)
169 |
170 |         with open(anno_txt_path, 'w') as anno_f:
171 |             anno_f.write('\n'.join(data_ids))
172 |
173 |     return base_path
174 |
--------------------------------------------------------------------------------
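`_load_gt` reads the first six comma-separated columns of `gt/gt.txt` (frame, instance id, left, top, width, height in the MOTChallenge layout) and converts each box to ChainerCV's `(y_min, x_min, y_max, x_max)` order. A single-row version in plain Python makes that column shuffle explicit; `parse_gt_row` is a hypothetical helper, not part of the package.

```python
def parse_gt_row(line):
    """Parse one gt.txt row into (frame_id, inst_id, bbox).

    Assumes the MOTChallenge column order used by _load_gt:
    frame, id, bb_left, bb_top, bb_width, bb_height, ...
    """
    frame, inst_id, left, top, width, height = [
        float(v) for v in line.strip().split(',')[:6]]
    # (left, top, width, height) -> (y_min, x_min, y_max, x_max)
    bbox = (top, left, top + height, left + width)
    return int(frame), int(inst_id), bbox
```

This is the same arithmetic as `bbox[:, 2:4] += bbox[:, :2]` followed by the column reorder `[1, 0, 3, 2]`, applied to one row at a time.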
/chainer_sort/models/__init__.py:
--------------------------------------------------------------------------------
1 | from chainer_sort.models.sort_multi_object_tracking import SORTMultiObjectTracking # NOQA
2 |
--------------------------------------------------------------------------------
/chainer_sort/models/sort_multi_object_tracking.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import chainer
4 |
5 | from chainer_sort.trackers import SORTMultiBboxTracker
6 |
7 |
8 | class SORTMultiObjectTracking(object):
9 |
10 |     def __init__(self, detector,
11 |                  detector_label_names, tracking_label_names=None):
12 |         self.bbox_tracker = SORTMultiBboxTracker()
13 |         self.detector = detector
14 |         self.detector_label_names = detector_label_names
15 |         # None means "track every class the detector can output"
16 |         if tracking_label_names is None:
17 |             tracking_label_names = detector_label_names
18 |         self.tracking_label_names = tracking_label_names
19 |
20 |     def predict(self, img):
21 |         if len(img) > 1:
22 |             raise ValueError('expected a batch with a single image')
23 |         bboxes, labels, scores = self.detector.predict(img)
24 |         bbox, label, score = bboxes[0], labels[0], scores[0]
25 |         bbox = chainer.cuda.to_cpu(bbox)
26 |         label = chainer.cuda.to_cpu(label)
27 |         score = chainer.cuda.to_cpu(score)
28 |
29 |         # keep only detections whose class should be tracked
30 |         det_bbox = []
31 |         det_label = []
32 |         det_score = []
33 |         for bb, lbl, sc in zip(bbox, label, score):
34 |             if self.detector_label_names[lbl] in self.tracking_label_names:
35 |                 det_bbox.append(bb[None])
36 |                 det_label.append(lbl)
37 |                 det_score.append(sc)
38 |         if len(det_bbox) > 0:
39 |             det_bbox = np.concatenate(det_bbox)
40 |         else:
41 |             det_bbox = np.array(det_bbox).reshape((0, 4))
42 |         det_label = np.array(det_label, dtype=np.int32)
43 |         det_score = np.array(det_score, dtype=np.float32)
44 |
45 |         trk_indices, trk_bbox, trk_inst_id = self.bbox_tracker.update(det_bbox)
46 |         trk_label = det_label[trk_indices]
47 |         trk_score = det_score[trk_indices]
48 |         return trk_bbox[None], trk_label[None], \
49 |             trk_score[None], trk_inst_id[None]
50 |
--------------------------------------------------------------------------------
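`SORTMultiObjectTracking.predict` runs the detector, keeps only the detections whose class name appears in `tracking_label_names`, and hands their boxes to the tracker; the tracker answers with indices into that filtered list, which are then used to look up each track's label and score. A plain-Python sketch of those two steps follows. The helpers `filter_by_class` and `gather` are hypothetical, and the two-class label list in the check stands in for `voc_bbox_label_names`.

```python
def filter_by_class(bboxes, labels, scores, label_names,
                    tracking_label_names):
    """Keep detections whose class name is in tracking_label_names."""
    keep = [i for i, lbl in enumerate(labels)
            if label_names[lbl] in tracking_label_names]
    return ([bboxes[i] for i in keep],
            [labels[i] for i in keep],
            [scores[i] for i in keep])


def gather(values, indices):
    """List equivalent of det_label[trk_indices] on NumPy arrays."""
    return [values[i] for i in indices]
```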
/chainer_sort/trackers/__init__.py:
--------------------------------------------------------------------------------
1 | from chainer_sort.trackers.kalman_bbox_tracker import KalmanBboxTracker # NOQA
2 | from chainer_sort.trackers.sort_multi_bbox_tracker import SORTMultiBboxTracker # NOQA
3 |
--------------------------------------------------------------------------------
/chainer_sort/trackers/kalman_bbox_tracker.py:
--------------------------------------------------------------------------------
1 | # Modified work by Shingo Kitagawa (@knorth55)
2 | #
3 | # Original works of SORT from https://github.com/abewley/sort
4 | # ---------------------------------------------------------------------
5 | # Copyright (C) 2016 Alex Bewley alex@dynamicdetection.com
6 | #
7 | # This program is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # This program is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 | # ---------------------------------------------------------------------
20 |
21 |
22 | import numpy as np
23 |
24 | from filterpy.kalman import KalmanFilter
25 |
26 | from chainer_sort.utils import bbox2z_bbox
27 | from chainer_sort.utils import z_bbox2bbox
28 |
29 |
30 | class KalmanBboxTracker(object):
31 |
32 |     def __init__(self, bbox):
33 |         if len(bbox) > 1:
34 |             raise ValueError
35 |
36 |         self.filter = KalmanFilter(dim_x=7, dim_z=4)
37 |         self.filter.F = np.array(
38 |             [[1, 0, 0, 0, 1, 0, 0],
39 |              [0, 1, 0, 0, 0, 1, 0],
40 |              [0, 0, 1, 0, 0, 0, 1],
41 |              [0, 0, 0, 1, 0, 0, 0],
42 |              [0, 0, 0, 0, 1, 0, 0],
43 |              [0, 0, 0, 0, 0, 1, 0],
44 |              [0, 0, 0, 0, 0, 0, 1]])
45 |         self.filter.H = np.array(
46 |             [[1, 0, 0, 0, 0, 0, 0],
47 |              [0, 1, 0, 0, 0, 0, 0],
48 |              [0, 0, 1, 0, 0, 0, 0],
49 |              [0, 0, 0, 1, 0, 0, 0]])
50 |
51 |         self.filter.R[2:, 2:] *= 10.
52 |         self.filter.P[4:, 4:] *= 1000.
53 |         self.filter.P *= 10.
54 |         self.filter.Q[-1, -1] *= 0.01
55 |         self.filter.Q[4:, 4:] *= 0.01
56 |         self.filter.x[:4] = bbox2z_bbox(bbox).T
57 |
58 |         self.time_since_update = 0
59 |         self.hit_streak = 0
60 |
61 |     def get_state(self):
62 |         bbox = z_bbox2bbox(self.filter.x[:4].T)
63 |         return bbox
64 |
65 |     def update(self, bbox):
66 |         if len(bbox) > 1:
67 |             raise ValueError
68 |
69 |         self.time_since_update = 0
70 |         self.hit_streak += 1
71 |         self.filter.update(bbox2z_bbox(bbox).T)
72 |
73 |     def predict(self):
74 |         if (self.filter.x[2] + self.filter.x[6]) <= 0:
75 |             self.filter.x[6] = 0.
76 |         self.filter.predict()
77 |         if self.time_since_update > 0:
78 |             self.hit_streak = 0
79 |         self.time_since_update += 1
80 |         bbox = self.get_state()
81 |         return bbox
82 |
--------------------------------------------------------------------------------
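`KalmanBboxTracker` runs the filter on a 7-dimensional state `[y_c, x_c, s, r, v_y, v_x, v_s]`: box center, area `s`, aspect ratio `r = h / w` (as computed by `bbox2z_bbox`), and velocities for the first three entries; only the first four are measured. The sketch below, in plain Python without FilterPy, shows the measurement conversion and how the transition matrix `F` advances the state mean by one frame. It covers the mean only (no covariance, no update step), and the helper names are hypothetical.

```python
import math

# Constant-velocity transition: rows 0-2 add their velocity (columns 4-6);
# every other component carries over. Equal to F in KalmanBboxTracker.
F = [[1 if (j == i or j == i + 4) else 0 for j in range(7)] for i in range(7)]


def bbox_to_z(bbox):
    """(y_min, x_min, y_max, x_max) -> [y_c, x_c, area, h / w]."""
    y_min, x_min, y_max, x_max = bbox
    h, w = y_max - y_min, x_max - x_min
    return [(y_min + y_max) / 2.0, (x_min + x_max) / 2.0, h * w, h / float(w)]


def z_to_bbox(z):
    """Inverse of bbox_to_z: h = sqrt(s * r), w = s / h."""
    y_c, x_c, s, r = z[:4]
    h = math.sqrt(s * r)
    w = s / h
    return [y_c - h / 2.0, x_c - w / 2.0, y_c + h / 2.0, x_c + w / 2.0]


def predict_mean(state):
    """One prediction step of the state mean, x' = F x."""
    x = list(state)
    # Same guard as KalmanBboxTracker.predict: never predict a negative area.
    if x[2] + x[6] <= 0:
        x[6] = 0.0
    return [sum(F[i][j] * x[j] for j in range(7)) for i in range(7)]
```

Storing area and aspect ratio instead of height and width follows the original SORT design: the aspect ratio has no velocity term, so it is modeled as constant while the area may grow or shrink.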
/chainer_sort/trackers/sort_multi_bbox_tracker.py:
--------------------------------------------------------------------------------
1 | # Modified work by Shingo Kitagawa (@knorth55)
2 | #
3 | # Original works of SORT from https://github.com/abewley/sort
4 | # ---------------------------------------------------------------------
5 | # Copyright (C) 2016 Alex Bewley alex@dynamicdetection.com
6 | #
7 | # This program is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # This program is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 | # ---------------------------------------------------------------------
20 |
21 |
22 | import numpy as np
23 |
24 | from chainer_sort.trackers import KalmanBboxTracker
25 | from chainer_sort.utils import iou_linear_assignment
26 |
27 |
28 | class SORTMultiBboxTracker(object):
29 |
30 |     def __init__(self, max_age=1, min_hit_streak=3):
31 |         self.max_age = max_age
32 |         self.min_hit_streak = min_hit_streak
33 |         self.trackers = []
34 |         self.frame_count = 0
35 |         self.tracker_num = 0
36 |         self.inst_ids = []
37 |
38 |     def update(self, det_bboxes):
39 |         self.frame_count += 1
40 |         pred_bboxes = []
41 |         valid_trackers = []
42 |         for tracker in self.trackers:
43 |             pred_bbox = tracker.predict()
44 |             # drop trackers whose state diverged (NaN or inf box)
45 |             if np.all(np.isfinite(pred_bbox)):
46 |                 valid_trackers.append(tracker)
47 |                 pred_bboxes.append(pred_bbox)
48 |         self.trackers = valid_trackers
49 |
50 |         if len(pred_bboxes) > 0:
51 |             pred_bboxes = np.concatenate(pred_bboxes)
52 |         else:
53 |             pred_bboxes = np.array(pred_bboxes).reshape((-1, 4))
54 |         assert len(self.trackers) == len(pred_bboxes)
55 |
56 |         matched_det_indices, matched_pred_indices = iou_linear_assignment(
57 |             det_bboxes, pred_bboxes)
58 |
59 |         # update matched trackers
60 |         # create new trackers for unmatched detections
61 |         new_trackers = []
62 |         trk_det_indices = []
63 |         trk_bboxes = []
64 |         trk_inst_ids = []
65 |         for det_index, det_bbox in enumerate(det_bboxes):
66 |             # matched
67 |             if det_index in matched_det_indices:
68 |                 pred_index = matched_pred_indices[
69 |                     matched_det_indices == det_index][0]
70 |                 tracker = self.trackers[int(pred_index)]
71 |                 tracker.update(det_bbox[None])
72 |             # not matched
73 |             else:
74 |                 tracker = KalmanBboxTracker(det_bbox[None])
75 |                 tracker.id = self.tracker_num
76 |                 self.tracker_num += 1
77 |
78 |             if tracker.time_since_update < 1 \
79 |                     and (tracker.hit_streak >= self.min_hit_streak
80 |                          or self.frame_count <= self.min_hit_streak):
81 |                 trk_det_indices.append(det_index)
82 |                 trk_bbox = tracker.get_state()[0]
83 |                 trk_bboxes.append(trk_bbox)
84 |                 # 0-based instance id that stays fixed for the whole track
85 |                 if tracker.id not in self.inst_ids:
86 |                     self.inst_ids.append(tracker.id)
87 |                 trk_inst_id = self.inst_ids.index(tracker.id)
88 |                 trk_inst_ids.append(trk_inst_id)
89 |
90 |             if tracker.time_since_update <= self.max_age:
91 |                 new_trackers.append(tracker)
92 |
93 |         # keep unmatched trackers alive until they exceed max_age
94 |         for pred_index, tracker in enumerate(self.trackers):
95 |             if pred_index not in matched_pred_indices \
96 |                     and tracker.time_since_update <= self.max_age:
97 |                 new_trackers.append(tracker)
98 |         self.trackers = new_trackers
99 |
100 |         trk_det_indices = np.array(trk_det_indices, dtype=int)
101 |         trk_bboxes = np.array(trk_bboxes, dtype=float).reshape((-1, 4))
102 |         trk_inst_ids = np.array(trk_inst_ids, dtype=int)
103 |         return trk_det_indices, trk_bboxes, trk_inst_ids
104 |
--------------------------------------------------------------------------------
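Inside `SORTMultiBboxTracker.update`, a tracker is reported only if it was matched to a detection in the current frame and it has either built up `min_hit_streak` consecutive hits or the sequence is still within its first `min_hit_streak` frames. The first function below restates that rule on plain integers. The second piece, a hypothetical `InstIdRegistry`, is one way to hand out compact instance ids that stay identical on every frame of a track; its dict lookup stands in for the list bookkeeping done with `self.inst_ids`.

```python
def is_reported(time_since_update, hit_streak, frame_count,
                min_hit_streak=3):
    """Output rule for tracked boxes in SORTMultiBboxTracker.update."""
    if time_since_update >= 1:
        # Not matched in this frame: the tracker only coasts.
        return False
    # Confirmed track, or still inside the warm-up frames of the sequence.
    return hit_streak >= min_hit_streak or frame_count <= min_hit_streak


class InstIdRegistry(object):
    """Hypothetical helper: tracker id -> stable, 0-based instance id."""

    def __init__(self):
        self._ids = {}

    def get(self, tracker_id):
        if tracker_id not in self._ids:
            # First report of this tracker: assign the next free id.
            self._ids[tracker_id] = len(self._ids)
        return self._ids[tracker_id]
```

The warm-up clause lets boxes appear on the very first frames of a sequence, before any tracker could have accumulated `min_hit_streak` hits.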
/chainer_sort/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | from chainercv.utils import bbox_iou
4 | import numpy as np
5 | from scipy.optimize import linear_sum_assignment
6 |
7 |
8 | def bbox2z_bbox(bbox):
9 |     # (y_min, x_min, y_max, x_max) -> (y_center, x_center, area, h / w)
10 |     if bbox.shape[1] != 4:
11 |         raise ValueError
12 |
13 |     y_center = (bbox[:, 2] + bbox[:, 0]) / 2
14 |     x_center = (bbox[:, 3] + bbox[:, 1]) / 2
15 |     size = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])
16 |     ratio = (bbox[:, 2] - bbox[:, 0]) / (bbox[:, 3] - bbox[:, 1])
17 |     z_bbox = np.concatenate(
18 |         (y_center[:, None], x_center[:, None],
19 |          size[:, None], ratio[:, None]),
20 |         axis=1)
21 |     return z_bbox
22 |
23 |
24 | def z_bbox2bbox(z_bbox):
25 |     # (y_center, x_center, area, h / w) -> (y_min, x_min, y_max, x_max)
26 |     if z_bbox.shape[1] != 4:
27 |         raise ValueError
28 |
29 |     height = np.sqrt(z_bbox[:, 2] * z_bbox[:, 3])
30 |     width = z_bbox[:, 2] / height
31 |     y_min = z_bbox[:, 0] - (height / 2)
32 |     x_min = z_bbox[:, 1] - (width / 2)
33 |     y_max = z_bbox[:, 0] + (height / 2)
34 |     x_max = z_bbox[:, 1] + (width / 2)
35 |     bbox = np.concatenate(
36 |         (y_min[:, None], x_min[:, None],
37 |          y_max[:, None], x_max[:, None]),
38 |         axis=1)
39 |     return bbox
40 |
41 |
42 | def iou_linear_assignment(bbox_a, bbox_b, iou_thresh=0.3):
43 |     iou = bbox_iou(bbox_a, bbox_b)
44 |     # Hungarian algorithm: one-to-one matching with maximal total IoU
45 |     indices_a, indices_b = linear_sum_assignment(-iou)
46 |     # reject pairs that overlap too little (IOU_min in the SORT paper)
47 |     keep = iou[indices_a, indices_b] >= iou_thresh
48 |     return indices_a[keep], indices_b[keep]
49 |
--------------------------------------------------------------------------------
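`iou_linear_assignment` pairs detections with predicted boxes so that the total IoU is as large as possible (the Hungarian step of SORT), and the SORT paper additionally rejects pairs whose IoU falls below a minimum, `IOU_min`. The sketch below illustrates both steps in plain Python. It swaps the Hungarian solver for an exhaustive search over permutations, which reaches the same maximal total IoU but is only practical for a handful of boxes; `match_by_iou` and its `iou_thresh=0.3` default are assumptions for illustration.

```python
from itertools import permutations


def iou(a, b):
    """IoU of two boxes given as (y_min, x_min, y_max, x_max)."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def match_by_iou(dets, preds, iou_thresh=0.3):
    """Return (det_index, pred_index) pairs of a max-total-IoU matching.

    Exhaustive search: fine for a few boxes, exponential in general.
    """
    n_det, n_pred = len(dets), len(preds)
    if n_det <= n_pred:
        candidates = (list(zip(range(n_det), perm))
                      for perm in permutations(range(n_pred), n_det))
    else:
        candidates = (list(zip(perm, range(n_pred)))
                      for perm in permutations(range(n_det), n_pred))
    best_total, best_pairs = -1.0, []
    for pairs in candidates:
        total = sum(iou(dets[d], preds[p]) for d, p in pairs)
        if total > best_total:
            best_total, best_pairs = total, pairs
    # Gate: drop pairs that barely overlap (IOU_min in the SORT paper).
    return [(d, p) for d, p in best_pairs
            if iou(dets[d], preds[p]) >= iou_thresh]
```

Without the final gate, a single remaining tracker would be paired with any detection at all, even one on the other side of the frame; the gate turns such a pair into an unmatched detection that starts a new track.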
/chainer_sort/visualizations/__init__.py:
--------------------------------------------------------------------------------
1 | from chainer_sort.visualizations.vis_tracking_bbox import vis_tracking_bbox # NOQA
2 |
--------------------------------------------------------------------------------
/chainer_sort/visualizations/vis_tracking_bbox.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import numpy as np
4 |
5 | from chainercv.visualizations.vis_image import vis_image
6 |
7 |
8 | def _default_cmap(label):
9 |     """Color map used in PASCAL VOC"""
10 |     r, g, b = 0, 0, 0
11 |     i = label
12 |     for j in range(8):
13 |         if i & (1 << 0):
14 |             r |= 1 << (7 - j)
15 |         if i & (1 << 1):
16 |             g |= 1 << (7 - j)
17 |         if i & (1 << 2):
18 |             b |= 1 << (7 - j)
19 |         i >>= 3
20 |     return r, g, b
21 |
22 |
23 | def vis_tracking_bbox(
24 |         img, bbox, inst_id, label=None, score=None,
25 |         label_names=None, alpha=1.0, ax=None):
26 |
27 |     from matplotlib import pyplot as plot
28 |     ax = vis_image(img, ax=ax)
29 |
30 |     assert len(bbox) == len(inst_id)
31 |     if len(bbox) == 0:
32 |         return ax
33 |
34 |     for i, (bb, inst_i) in enumerate(zip(bbox, inst_id)):
35 |         bb = np.round(bb).astype(np.int32)
36 |         y_min, x_min, y_max, x_max = bb
37 |         color = np.array(_default_cmap(inst_i + 1)) / 255.
38 |
39 |         ax.add_patch(plot.Rectangle(
40 |             (x_min, y_min), x_max - x_min, y_max - y_min,
41 |             fill=False, edgecolor=color, linewidth=3))
42 |
43 |         caption = []
44 |         caption.append('{}'.format(inst_i))
45 |
46 |         if label is not None and label_names is not None:
47 |             lb = label[i]
48 |             if not (0 <= lb < len(label_names)):
49 |                 raise ValueError('No corresponding name is given')
50 |             caption.append(label_names[lb])
51 |         if score is not None:
52 |             sc = score[i]
53 |             caption.append('{:.2f}'.format(sc))
54 |
55 |         ax.text((x_max + x_min) / 2, y_min,
56 |                 ': '.join(caption),
57 |                 style='italic',
58 |                 bbox={'facecolor': color, 'alpha': alpha},
59 |                 fontsize=8, color='white')
60 |
61 |     return ax
62 |
--------------------------------------------------------------------------------
/examples/mot/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import time
5 |
6 | import chainer
7 | from chainercv.datasets import voc_bbox_label_names
8 | from chainercv.links import FasterRCNNVGG16
9 | from chainercv.links import SSD300
10 | from chainercv.links import SSD512
11 |
12 | from chainer_sort.datasets.mot.mot_utils import get_mot
13 | from chainer_sort.datasets.mot.mot_utils import get_sequences
14 | from chainer_sort.datasets import MOTDataset
15 | from chainer_sort.models import SORTMultiObjectTracking
16 | from chainer_sort.visualizations import vis_tracking_bbox
17 |
18 |
19 | def main():
20 |     parser = argparse.ArgumentParser()
21 |     parser.add_argument('--no-display', action='store_true')
22 |     parser.add_argument(
23 |         '--model', choices=('ssd300', 'ssd512', 'faster-rcnn-vgg16'),
24 |         default='faster-rcnn-vgg16')
25 |     parser.add_argument('--gpu', type=int, default=0)
26 |     parser.add_argument(
27 |         '--pretrained_model', choices=('voc0712', 'voc07'), default='voc0712')
28 |     parser.add_argument(
29 |         '--year', '-y', choices=('2015', '2016', '2017'), default='2015')
30 |     parser.add_argument('--sequence-map', '-s', default=None)
31 |     args = parser.parse_args()
32 |
33 |     if args.sequence_map is None:
34 |         if args.year == '2015':
35 |             seqmap_name = 'c2-test'
36 |         elif args.year == '2016':
37 |             seqmap_name = 'c5-test'
38 |         elif args.year == '2017':
39 |             seqmap_name = 'c9-test'
40 |     else:
41 |         seqmap_name = args.sequence_map
42 |
43 |     map_name, split = seqmap_name.split('-')
44 |     if split == 'test':
45 |         split = 'val'
46 |     # get_sequences reads the devkit seqmaps, which get_mot downloads
47 |     get_mot(args.year, split)
48 |     sequences = get_sequences(split, map_name)
49 |
50 |     if args.model == 'ssd300':
51 |         detector = SSD300(
52 |             n_fg_class=len(voc_bbox_label_names),
53 |             pretrained_model=args.pretrained_model)
54 |     elif args.model == 'ssd512':
55 |         detector = SSD512(
56 |             n_fg_class=len(voc_bbox_label_names),
57 |             pretrained_model=args.pretrained_model)
58 |     elif args.model == 'faster-rcnn-vgg16':
59 |         detector = FasterRCNNVGG16(
60 |             n_fg_class=len(voc_bbox_label_names),
61 |             pretrained_model=args.pretrained_model)
62 |     detector.use_preset('evaluate')
63 |     detector.score_thresh = 0.5
64 |
65 |     if args.gpu >= 0:
66 |         chainer.cuda.get_device_from_id(args.gpu).use()
67 |         detector.to_gpu()
68 |
69 |     sort_label_names = ['person']
70 |
71 |     if not args.no_display:
72 |         plt.ion()
73 |         fig = plt.figure()
74 |
75 |     for seq in sequences:
76 |         if not args.no_display:
77 |             ax = fig.add_subplot(111, aspect='equal')
78 |
79 |         dataset = MOTDataset(
80 |             year=args.year, split=split, sequence=seq)
81 |
82 |         model = SORTMultiObjectTracking(
83 |             detector, voc_bbox_label_names,
84 |             sort_label_names)
85 |
86 |         print('Sequence: {}'.format(seq))
87 |         cycle_times = []
88 |         for i in range(len(dataset)):
89 |             img, _, _ = dataset[i]
90 |             start_time = time.time()
91 |             bboxes, labels, scores, inst_ids = model.predict([img])
92 |             cycle_time = time.time() - start_time
93 |             cycle_times.append(cycle_time)
94 |             if not args.no_display:
95 |                 bbox = bboxes[0]
96 |                 inst_id = inst_ids[0]
97 |                 label = labels[0]
98 |                 score = scores[0]
99 |                 vis_tracking_bbox(
100 |                     img, bbox, inst_id, label, score,
101 |                     label_names=voc_bbox_label_names, ax=ax)
102 |                 fig.canvas.flush_events()
103 |                 plt.draw()
104 |                 ax.cla()
105 |
106 |         cycle_times = np.array(cycle_times)
107 |         print('total time: {}'.format(np.sum(cycle_times)))
108 |         print('average time: {}'.format(np.average(cycle_times)))
109 |
110 |
111 | if __name__ == '__main__':
112 |     main()
113 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | chainer
2 | chainercv
3 | filterpy
4 | numpy
5 | scipy
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages
2 | from setuptools import setup
3 |
4 |
5 | version = '0.1.0'
6 |
7 |
8 | setup(
9 |     name='chainer_sort',
10 |     version=version,
11 |     packages=find_packages(),
12 |     install_requires=open('requirements.txt').readlines(),
13 |     description='SORT for chainercv',
14 |     long_description=open('README.md').read(),
15 |     author='Shingo Kitagawa',
16 |     author_email='shingogo.5511@gmail.com',
17 |     url='https://github.com/knorth55/chainer-sort',
18 |     license='MIT',
19 |     keywords='machine-learning',
20 | )
21 |
--------------------------------------------------------------------------------
/static/sort_faster_rcnn_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/knorth55/chainer-sort/a72d4b7f6b88b8e06f9cdfe61e12133b46a75d26/static/sort_faster_rcnn_example.gif
--------------------------------------------------------------------------------