├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── eval_mot.py
├── gifs
│   ├── cars_out.gif
│   ├── newyork_out.gif
│   └── test_out.gif
├── requirements.txt
├── track.py
└── tracking
    ├── __init__.py
    ├── clip
    │   ├── __init__.py
    │   ├── clip.py
    │   └── model.py
    ├── dino
    │   ├── __init__.py
    │   ├── dino.py
    │   ├── vit.py
    │   └── xcit.py
    ├── sort
    │   ├── __init__.py
    │   ├── detection.py
    │   ├── kalman_filter.py
    │   ├── matching.py
    │   ├── track.py
    │   └── tracker.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Repo-specific GitIgnore ----------------------------------------------------------------------------------------------
2 | *.jpg
3 | *.jpeg
4 | *.png
5 | *.bmp
6 | *.tif
7 | *.tiff
8 | *.heic
9 | *.JPG
10 | *.JPEG
11 | *.PNG
12 | *.BMP
13 | *.TIF
14 | *.TIFF
15 | *.HEIC
16 | *.mp4
17 | *.mov
18 | *.MOV
19 | *.avi
20 | *.data
21 | *.json
22 |
23 | *.cfg
24 | !cfg/yolov3*.cfg
25 | */tracktor/*
26 | storage.googleapis.com
27 | runs/*
28 | data/*
29 | !data/images/zidane.jpg
30 | !data/images/bus.jpg
31 | !data/coco.names
32 | !data/coco_paper.names
33 | !data/coco.data
34 | !data/coco_*.data
35 | !data/coco_*.txt
36 | !data/trainvalno5k.shapes
37 | !data/*.sh
38 |
39 | test.py
40 | test_imgs/
41 |
42 | pycocotools/*
43 | results*.txt
44 | gcp_test*.sh
45 |
46 | checkpoints/
47 | output/
48 | assests/*/
49 |
50 | # Datasets -------------------------------------------------------------------------------------------------------------
51 | coco/
52 | coco128/
53 | VOC/
54 |
55 | # MATLAB GitIgnore -----------------------------------------------------------------------------------------------------
56 | *.m~
57 | *.mat
58 | !targets*.mat
59 |
60 | # Neural Network weights -----------------------------------------------------------------------------------------------
61 | *.weights
62 | *.pt
63 | *.onnx
64 | *.mlmodel
65 | *.torchscript
66 | darknet53.conv.74
67 | yolov3-tiny.conv.15
68 |
69 | # GitHub Python GitIgnore ----------------------------------------------------------------------------------------------
70 | # Byte-compiled / optimized / DLL files
71 | __pycache__/
72 | *.py[cod]
73 | *$py.class
74 |
75 | # C extensions
76 | *.so
77 |
78 | # Distribution / packaging
79 | .Python
80 | env/
81 | build/
82 | develop-eggs/
83 | dist/
84 | downloads/
85 | eggs/
86 | .eggs/
87 | lib/
88 | lib64/
89 | parts/
90 | sdist/
91 | var/
92 | wheels/
93 | *.egg-info/
94 | wandb/
95 | .installed.cfg
96 | *.egg
97 |
98 |
99 | # PyInstaller
100 | # Usually these files are written by a python script from a template
101 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
102 | *.manifest
103 | *.spec
104 |
105 | # Installer logs
106 | pip-log.txt
107 | pip-delete-this-directory.txt
108 |
109 | # Unit test / coverage reports
110 | htmlcov/
111 | .tox/
112 | .coverage
113 | .coverage.*
114 | .cache
115 | nosetests.xml
116 | coverage.xml
117 | *.cover
118 | .hypothesis/
119 |
120 | # Translations
121 | *.mo
122 | *.pot
123 |
124 | # Django stuff:
125 | *.log
126 | local_settings.py
127 |
128 | # Flask stuff:
129 | instance/
130 | .webassets-cache
131 |
132 | # Scrapy stuff:
133 | .scrapy
134 |
135 | # Sphinx documentation
136 | docs/_build/
137 |
138 | # PyBuilder
139 | target/
140 |
141 | # Jupyter Notebook
142 | .ipynb_checkpoints
143 |
144 | # pyenv
145 | .python-version
146 |
147 | # celery beat schedule file
148 | celerybeat-schedule
149 |
150 | # SageMath parsed files
151 | *.sage.py
152 |
153 | # dotenv
154 | .env
155 |
156 | # virtualenv
157 | .venv*
158 | venv*/
159 | ENV*/
160 |
161 | # Spyder project settings
162 | .spyderproject
163 | .spyproject
164 |
165 | # Rope project settings
166 | .ropeproject
167 |
168 | # mkdocs documentation
169 | /site
170 |
171 | # mypy
172 | .mypy_cache/
173 |
174 |
175 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore -----------------------------------------------
176 |
177 | # General
178 | .DS_Store
179 | .AppleDouble
180 | .LSOverride
181 |
182 | # Icon must end with two \r
183 | Icon
184 | Icon?
185 |
186 | # Thumbnails
187 | ._*
188 |
189 | # Files that might appear in the root of a volume
190 | .DocumentRevisions-V100
191 | .fseventsd
192 | .Spotlight-V100
193 | .TemporaryItems
194 | .Trashes
195 | .VolumeIcon.icns
196 | .com.apple.timemachine.donotpresent
197 |
198 | # Directories potentially created on remote AFP share
199 | .AppleDB
200 | .AppleDesktop
201 | Network Trash Folder
202 | Temporary Items
203 | .apdisk
204 |
205 |
206 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore
207 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
208 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
209 |
210 | # User-specific stuff:
211 | .idea/*
212 | .idea/**/workspace.xml
213 | .idea/**/tasks.xml
214 | .idea/dictionaries
215 | .html # Bokeh Plots
216 | .pg # TensorFlow Frozen Graphs
217 | .avi # videos
218 |
219 | # Sensitive or high-churn files:
220 | .idea/**/dataSources/
221 | .idea/**/dataSources.ids
222 | .idea/**/dataSources.local.xml
223 | .idea/**/sqlDataSources.xml
224 | .idea/**/dynamic.xml
225 | .idea/**/uiDesigner.xml
226 |
227 | # Gradle:
228 | .idea/**/gradle.xml
229 | .idea/**/libraries
230 |
231 | # CMake
232 | cmake-build-debug/
233 | cmake-build-release/
234 |
235 | # Mongo Explorer plugin:
236 | .idea/**/mongoSettings.xml
237 |
238 | ## File-based project format:
239 | *.iws
240 |
241 | ## Plugin-specific files:
242 |
243 | # IntelliJ
244 | out/
245 |
246 | # mpeltonen/sbt-idea plugin
247 | .idea_modules/
248 |
249 | # JIRA plugin
250 | atlassian-ide-plugin.xml
251 |
252 | # Cursive Clojure plugin
253 | .idea/replstate.xml
254 |
255 | # Crashlytics plugin (for Android Studio and IntelliJ)
256 | com_crashlytics_export_strings.xml
257 | crashlytics.properties
258 | crashlytics-build.properties
259 | fabric.properties
260 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "yolov5"]
2 | path = yolov5
3 | url = https://github.com/ultralytics/yolov5
4 | [submodule "TrackEval"]
5 | path = TrackEval
6 | url = https://github.com/JonathonLuiten/TrackEval
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 sithu3
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simple Object Tracking
2 |
3 | Multi-Object Tracking with YOLOv5, CLIP, DINO and DeepSORT
4 |
5 | Demo GIFs: `gifs/cars_out.gif`, `gifs/newyork_out.gif`, `gifs/test_out.gif`
6 |
7 |
8 |
9 |
10 | ## Introduction
11 |
12 | This is a simple two-stage multi-object tracker that combines [YOLOv5](https://github.com/ultralytics/yolov5) and [DeepSORT](https://arxiv.org/abs/1703.07402) with zero-shot or self-supervised feature extractors.
13 |
14 | Normally, in DeepSORT, the deep part of the model is trained on a person re-identification dataset like [Market1501](https://www.kaggle.com/pengcw1/market-1501/data). Here, that model is replaced with a zero-shot or self-supervised feature extractor, which makes the tracker ready to track any class without re-training.
15 |
16 | SOTA models like [CLIP](https://arxiv.org/abs/2103.00020) (zero-shot) and [DINO](https://arxiv.org/abs/2104.14294v2) (SSL) are the ones currently experimented with. If better models come out, I will consider adding them.
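
As a rough illustration of the idea (a minimal sketch paraphrasing `tracking/__init__.py` and `track.py`; the frame path and boxes below are placeholders), each detection crop is embedded with the chosen feature extractor, and those embeddings are what DeepSORT associates on:

```python
import torch
from PIL import Image
from tracking import load_feature_extractor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, transform = load_feature_extractor('CLIP-RN50', device)   # zero-shot feature extractor

frame = Image.open('frame.jpg')                                  # placeholder frame
boxes = [(10, 20, 110, 220), (150, 30, 260, 240)]                # xyxy boxes from the detector

# embed each detection crop; these vectors replace the Market1501-trained ReID features
patches = torch.stack([transform(frame.crop(box)) for box in boxes]).to(device)
with torch.no_grad():
    features = model.encode_image(patches).cpu().numpy()         # one appearance vector per box
# the boxes, class ids and `features` are then passed to DeepSORT's tracker.update()
```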
17 |
18 | ## Requirements
19 |
20 | * torch >= 1.8.1
21 | * torchvision >= 0.9.1
22 |
23 | Other requirements can be installed with `pip install -r requirements.txt`.
24 |
25 | Clone the repository recursively:
26 |
27 | ```bash
28 | $ git clone --recursive https://github.com/sithu31296/simple-object-tracking.git
29 | ```
30 |
31 | Then download a YOLOv5 model's weights from [YOLOv5](https://github.com/ultralytics/yolov5) and place them in `checkpoints`.
32 |
33 | ## Tracking
34 |
35 | Track all classes:
36 |
37 | ```bash
38 | ## webcam
39 | $ python track.py --source 0 --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50
40 |
41 | ## video
42 | $ python track.py --source VIDEO_PATH --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50
43 | ```
44 |
45 | Track only specified classes:
46 |
47 | ```bash
48 | ## track only person class
49 | $ python track.py --source 0 --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50 --filter-class 0
50 |
51 | ## track person and car classes
52 | $ python track.py --source 0 --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50 --filter-class 0 2
53 | ```
54 |
55 | Available ReID models (Feature Extractors):
56 | * **CLIP**: `CLIP-RN50`, `CLIP-ViT-B/32`
57 | * **DINO**: `DINO-XciT-S12/16`, `DINO-XciT-M24/16`, `DINO-ViT-S/16`, `DINO-ViT-B/16`
58 |
59 | Check [here](tracking/utils.py#L14) to get COCO class index for your class.
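
If you prefer to look the indices up programmatically, the class names stored in the YOLOv5 checkpoint can also be printed directly (a small sketch; it assumes the same `yolov5` submodule and `checkpoints/yolov5s.pt` weights used above):

```python
import sys
sys.path.insert(0, 'yolov5')
from yolov5.models.experimental import attempt_load

# class index -> name mapping stored in the checkpoint, e.g. 0: 'person', 2: 'car'
names = attempt_load('checkpoints/yolov5s.pt', map_location='cpu').names
print(dict(enumerate(names)))
```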
60 |
61 | ## Evaluate on MOT16
62 |
63 | * Download MOT16 dataset from [here](https://motchallenge.net/data/MOT16.zip) and unzip it.
64 | * Download mot-challenge ground-truth [data](https://omnomnom.vision.rwth-aachen.de/data/TrackEval/data.zip) for evaluating with TrackEval. Then, unzip it under the project directory.
65 | * Save the tracking results of MOT16 with the following command:
66 |
67 | ```bash
68 | $ python eval_mot.py --root MOT16_ROOT_DIR --yolo-model checkpoints/yolov5m.pt --reid-model CLIP-RN50
69 | ```
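
This writes one text file per sequence under `data/trackers/mot_challenge/MOT16-train/mot_det/data/`, with one line per confirmed track per frame in the MOTChallenge format produced by `eval_mot.py`:

```
<frame>,<track_id>,<x>,<y>,<w>,<h>,-1,-1,-1,-1
```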
70 |
71 | * Evaluate with TrackEval:
72 |
73 | ```bash
74 | $ python TrackEval/scripts/run_mot_challenge.py \
75 | --BENCHMARK MOT16 \
76 | --GT_FOLDER PROJECT_ROOT/data/gt/mot_challenge/ \
77 | --TRACKERS_FOLDER PROJECT_ROOT/data/trackers/mot_challenge/ \
78 | --TRACKERS_TO_EVAL mot_det \
79 | --SPLIT_TO_EVAL train \
80 | --USE_PARALLEL True \
81 | --NUM_PARALLEL_CORES 4 \
82 | --PRINT_ONLY_COMBINED True
83 | ```
84 |
85 | > Note: the `FOLDER` parameters passed to `run_mot_challenge.py` must be absolute paths.
86 |
87 | For tracking persons, a model trained on a multi-person dataset will give better accuracy than a COCO-pretrained one. You can download a YOLOv5m model trained on the [CrowdHuman](https://www.crowdhuman.org/) dataset from [here](https://drive.google.com/file/d/1gglIwqxaH2iTvy6lZlXuAcMpd_U0GCUb/view?usp=sharing). The weights are from [deepakcrk/yolov5-crowdhuman](https://github.com/deepakcrk/yolov5-crowdhuman). The model has 2 classes, 'person' and 'head', so it can be used for both person and head tracking; see the example command below.
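
A suggested command for this setup (it assumes the weights are saved as `checkpoints/crowdhuman_yolov5m.pt`, the default name used by `eval_mot.py`, and that class index 0 corresponds to 'person' in this checkpoint):

```bash
## track only persons with the CrowdHuman-trained detector
$ python track.py --source VIDEO_PATH --yolo-model checkpoints/crowdhuman_yolov5m.pt --reid-model DINO-XciT-M24/16 --filter-class 0
```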
88 |
89 | ## Results
90 |
91 | **MOT16 Evaluation Results**
92 |
93 | Detector | Feature Extractor | MOTA↑ | HOTA↑ | IDF1↑ | IDsw↓ | MT↑ | ML↓ | FP↓ | FN↓ | FPS<br>(GTX1660ti)
94 | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | ---
95 | YOLOv5m<br>(COCO) | CLIP<br>(RN50) | 35.42 | 35.37 | 39.42 | **486** | 115 | 192 | **6880** | 63931 | 7
96 | YOLOv5m<br>(CrowdHuman) | CLIP<br>(RN50) | 53.25 | 43.25 | 52.12 | 912 | 196 | **89** | 14076 | 36625 | 6
97 | YOLOv5m<br>(CrowdHuman) | CLIP<br>(ViT-B/32) | 53.35 | 43.03 | 51.25 | 896 | **199** | 91 | 14035 | **36575** | 4
98 | ||
99 | YOLOv5m<br>(CrowdHuman) | DINO<br>(XciT-S12/16) | 54.41 | 47.44 | 59.01 | 511 | 184 | 101 | 12265 | 37555 | 8
100 | YOLOv5m<br>(CrowdHuman) | DINO<br>(ViT-S/16) | 54.56 | 47.61 | 58.94 | 519 | 189 | 97 | 12346 | 37308 | 8
101 | YOLOv5m<br>(CrowdHuman) | DINO<br>(XciT-M24/16) | 54.56 | **47.71** | **59.77** | 504 | 187 | 96 | 12364 | 37306 | 5
102 | YOLOv5m<br>(CrowdHuman) | DINO<br>(ViT-B/16) | **54.58** | 47.55 | 58.89 | 507 | 184 | 97 | 12017 | 37621 | 5
103 |
104 | **FPS Results**
105 |
106 | Detector | Feature Extractor | GPU | Precision | Image Size | Detection/Frame | FPS
107 | --- | --- | --- | --- | --- | --- | ---
108 | YOLOv5s | CLIP-RN50 | GTX-1660ti | FP32 | 480x640 | 1 | 38
109 | YOLOv5s | CLIP-ViT-B/32 | GTX-1660ti | FP32 | 480x640 | 1 | 30
110 | ||
111 | YOLOv5s | DINO-XciT-S12/16 | GTX-1660ti | FP32 | 480x640 | 1 | 36
112 | YOLOv5s | DINO-ViT-B/16 | GTX-1660ti | FP32 | 480x640 | 1 | 30
113 | YOLOv5s | DINO-XciT-M24/16 | GTX-1660ti | FP32 | 480x640 | 1 | 25
114 |
115 |
116 | ## References
117 |
118 | * https://github.com/ultralytics/yolov5
119 | * https://github.com/JonathonLuiten/TrackEval
120 |
121 | ## Citations
122 |
123 | ```
124 | @inproceedings{caron2021emerging,
125 | title={Emerging Properties in Self-Supervised Vision Transformers},
126 | author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
127 | booktitle={Proceedings of the International Conference on Computer Vision (ICCV)},
128 | year={2021}
129 | }
130 |
131 | @article{el2021xcit,
132 | title={XCiT: Cross-Covariance Image Transformers},
133 | author={El-Nouby, Alaaeldin and Touvron, Hugo and Caron, Mathilde and Bojanowski, Piotr and Douze, Matthijs and Joulin, Armand and Laptev, Ivan and Neverova, Natalia and Synnaeve, Gabriel and Verbeek, Jakob and others},
134 | journal={arXiv preprint arXiv:2106.09681},
135 | year={2021}
136 | }
137 |
138 | @misc{radford2021learning,
139 | title={Learning Transferable Visual Models From Natural Language Supervision},
140 | author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
141 | year={2021},
142 | eprint={2103.00020},
143 | archivePrefix={arXiv},
144 | primaryClass={cs.CV}
145 | }
146 |
147 | @inproceedings{Wojke2017simple,
148 | title={Simple Online and Realtime Tracking with a Deep Association Metric},
149 | author={Wojke, Nicolai and Bewley, Alex and Paulus, Dietrich},
150 | booktitle={2017 IEEE International Conference on Image Processing (ICIP)},
151 | year={2017},
152 | pages={3645--3649},
153 | organization={IEEE},
154 | doi={10.1109/ICIP.2017.8296962}
155 | }
156 |
157 | @inproceedings{Wojke2018deep,
158 | title={Deep Cosine Metric Learning for Person Re-identification},
159 | author={Wojke, Nicolai and Bewley, Alex},
160 | booktitle={2018 IEEE Winter Conference on Applications of Computer Vision (WACV)},
161 | year={2018},
162 | pages={748--756},
163 | organization={IEEE},
164 | doi={10.1109/WACV.2018.00087}
165 | }
166 | ```
--------------------------------------------------------------------------------
/eval_mot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import shutil
4 | from pathlib import Path
5 | from tqdm import tqdm
6 | from tracking.utils import *
7 |
8 | from track import Tracking
9 |
10 |
11 | class EvalTracking(Tracking):
12 | def __init__(self, yolo_model, reid_model, img_size, filter_class, conf_thres, iou_thres, max_cosine_dist, max_iou_dist, nn_budget, max_age, n_init) -> None:
13 | super().__init__(yolo_model, reid_model, img_size=img_size, filter_class=filter_class, conf_thres=conf_thres, iou_thres=iou_thres, max_cosine_dist=max_cosine_dist, max_iou_dist=max_iou_dist, nn_budget=nn_budget, max_age=max_age, n_init=n_init)
14 |
15 | def postprocess(self, pred, img1, img0, txt_path, frame_idx):
16 | pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=self.filter_class)
17 |
18 | for det in pred:
19 | if len(det):
20 | boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu()
21 | features = self.extract_features(boxes, img0)
22 |
23 | self.tracker.predict()
24 | self.tracker.update(boxes, det[:, 5], features)
25 |
26 | for track in self.tracker.tracks:
27 | if not track.is_confirmed() or track.time_since_update > 1: continue
28 |
29 | x1, y1, x2, y2 = track.to_tlbr()
30 | w, h = x2 - x1, y2 - y1
31 |
32 | with open(txt_path, 'a') as f:
33 | f.write(f"{frame_idx+1},{track.track_id},{x1:.4f},{y1:.4f},{w:.4f},{h:.4f},-1,-1,-1,-1\n")
34 | else:
35 | self.tracker.increment_ages()
36 |
37 | @torch.no_grad()
38 | def predict(self, image, txt_path, frame_idx):
39 | img = self.preprocess(image)
40 | pred = self.model(img)[0]
41 | self.postprocess(pred, img, image, txt_path, frame_idx)
42 |
43 |
44 | def argument_parser():
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('--root', type=str, default='/home/sithu/datasets/MOT16')
47 | parser.add_argument('--yolo-model', type=str, default='checkpoints/crowdhuman_yolov5m.pt')
48 | parser.add_argument('--reid-model', type=str, default='CLIP-RN50')
49 | parser.add_argument('--img-size', type=int, default=640)
50 | parser.add_argument('--filter-class', nargs='+', type=int, default=0)
51 | parser.add_argument('--conf-thres', type=float, default=0.4)
52 | parser.add_argument('--iou-thres', type=float, default=0.5)
53 | parser.add_argument('--max-cosine-dist', type=float, default=0.2)
54 |     parser.add_argument('--max-iou-dist', type=float, default=0.7)
55 | parser.add_argument('--nn-budget', type=int, default=100)
56 | parser.add_argument('--max-age', type=int, default=70)
57 | parser.add_argument('--n-init', type=int, default=3)
58 | return parser.parse_args()
59 |
60 |
61 | if __name__ == '__main__':
62 | args = argument_parser()
63 | tracking = EvalTracking(
64 | args.yolo_model,
65 | args.reid_model,
66 | args.img_size,
67 | args.filter_class,
68 | args.conf_thres,
69 | args.iou_thres,
70 | args.max_cosine_dist,
71 | args.max_iou_dist,
72 | args.nn_budget,
73 | args.max_age,
74 | args.n_init
75 | )
76 |
77 | save_path = Path('data') / 'trackers' / 'mot_challenge' / 'MOT16-train' / 'mot_det' / 'data'
78 | if save_path.exists():
79 | shutil.rmtree(save_path)
80 | save_path.mkdir(parents=True)
81 |
82 | root = Path(args.root) / 'train'
83 | folders = root.iterdir()
84 |
85 | total_fps = []
86 |
87 | for folder in folders:
88 | tracking.tracker.reset()
89 | reader = SequenceStream(folder / 'img1')
90 | txt_path = save_path / f"{folder.stem}.txt"
91 | fps = FPS(len(reader.frames))
92 |
93 | for i, frame in tqdm(enumerate(reader), total=len(reader)):
94 | fps.start()
95 | tracking.predict(frame, txt_path, i)
96 | fps.stop(False)
97 |
98 | print(f"FPS: {fps.fps}")
99 | total_fps.append(fps.fps)
100 | del reader
101 |
102 | print(f"Average FPS for MOT16: {round(sum(total_fps) / len(total_fps))}")
--------------------------------------------------------------------------------
/gifs/cars_out.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/gifs/cars_out.gif
--------------------------------------------------------------------------------
/gifs/newyork_out.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/gifs/newyork_out.gif
--------------------------------------------------------------------------------
/gifs/test_out.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/gifs/test_out.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | seaborn
2 | ftfy
3 | regex
4 | matplotlib
5 | numpy
6 | opencv-python
7 | scipy
8 | tqdm
9 |
--------------------------------------------------------------------------------
/track.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from tracking import load_feature_extractor
8 | from tracking.sort.tracker import DeepSORTTracker
9 | from tracking.utils import *
10 |
11 | import sys
12 | sys.path.insert(0, 'yolov5')
13 | from yolov5.models.experimental import attempt_load
14 |
15 |
16 |
17 | class Tracking:
18 | def __init__(self,
19 | yolo_model,
20 | reid_model,
21 | img_size=640,
22 | filter_class=None,
23 | conf_thres=0.25,
24 | iou_thres=0.45,
25 | max_cosine_dist=0.4, # the higher the value, the easier it is to assume it is the same person
26 | max_iou_dist=0.7, # how much bboxes should overlap to determine the identity of the unassigned track
27 |         nn_budget=None,          # indicates how many previous frames of feature vectors should be retained for the distance calculation for each track
28 | max_age=60, # specifies after how many frames unallocated tracks will be deleted
29 | n_init=3 # specifies after how many frames newly allocated tracks will be activated
30 | ) -> None:
31 | self.img_size = img_size
32 | self.conf_thres = conf_thres
33 | self.iou_thres = iou_thres
34 | self.filter_class = filter_class
35 |
36 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37 | self.model = attempt_load(yolo_model, map_location=self.device)
38 | self.model = self.model.to(self.device)
39 | self.names = self.model.names
40 |
41 | self.patch_model, self.patch_transform = load_feature_extractor(reid_model, self.device)
42 | self.tracker = DeepSORTTracker('cosine', max_cosine_dist, nn_budget, max_iou_dist, max_age, n_init)
43 |
44 |
45 | def preprocess(self, image):
46 | img = letterbox(image, new_shape=self.img_size)
47 | img = np.ascontiguousarray(img.transpose((2, 0, 1)))
48 | img = torch.from_numpy(img).to(self.device)
49 | img = img.float() / 255.0
50 | img = img[None]
51 | return img
52 |
53 |
54 | def extract_features(self, boxes, img):
55 | image_patches = []
56 | for xyxy in boxes:
57 | x1, y1, x2, y2 = map(int, xyxy)
58 | img_patch = Image.fromarray(img[y1:y2, x1:x2])
59 | img_patch = self.patch_transform(img_patch)
60 | image_patches.append(img_patch)
61 |
62 | image_patches = torch.stack(image_patches).to(self.device)
63 | features = self.patch_model.encode_image(image_patches).cpu().numpy()
64 | return features
65 |
66 |
67 | def postprocess(self, pred, img1, img0):
68 | pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=self.filter_class)
69 |
70 | for det in pred:
71 | if len(det):
72 | boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu()
73 | features = self.extract_features(boxes, img0)
74 |
75 | self.tracker.predict()
76 | self.tracker.update(boxes, det[:, 5], features)
77 |
78 | for track in self.tracker.tracks:
79 | if not track.is_confirmed() or track.time_since_update > 1: continue
80 | label = f"{self.names[int(track.class_id)]} #{track.track_id}"
81 | plot_one_box(track.to_tlbr(), img0, color=colors(int(track.class_id)), label=label)
82 | else:
83 | self.tracker.increment_ages()
84 |
85 |
86 | @torch.no_grad()
87 | def predict(self, image):
88 | img = self.preprocess(image)
89 | pred = self.model(img)[0]
90 | self.postprocess(pred, img, image)
91 | return image
92 |
93 |
94 | def argument_parser():
95 | parser = argparse.ArgumentParser()
96 | parser.add_argument('--source', type=str, default='0')
97 | parser.add_argument('--yolo-model', type=str, default='checkpoints/yolov5s.pt')
98 | parser.add_argument('--reid-model', type=str, default='CLIP-RN50')
99 | parser.add_argument('--img-size', type=int, default=640)
100 | parser.add_argument('--filter-class', nargs='+', type=int, default=None)
101 | parser.add_argument('--conf-thres', type=float, default=0.4)
102 | parser.add_argument('--iou-thres', type=float, default=0.5)
103 | parser.add_argument('--max-cosine-dist', type=float, default=0.2)
104 |     parser.add_argument('--max-iou-dist', type=float, default=0.7)
105 | parser.add_argument('--nn-budget', type=int, default=100)
106 | parser.add_argument('--max-age', type=int, default=70)
107 | parser.add_argument('--n-init', type=int, default=3)
108 | return parser.parse_args()
109 |
110 |
111 | if __name__ == '__main__':
112 | args = argument_parser()
113 | tracking = Tracking(
114 | args.yolo_model,
115 | args.reid_model,
116 | args.img_size,
117 | args.filter_class,
118 | args.conf_thres,
119 | args.iou_thres,
120 | args.max_cosine_dist,
121 | args.max_iou_dist,
122 | args.nn_budget,
123 | args.max_age,
124 | args.n_init
125 | )
126 |
127 | if args.source.isnumeric():
128 | webcam = WebcamStream()
129 | fps = FPS()
130 |
131 | for frame in webcam:
132 | fps.start()
133 | output = tracking.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
134 | fps.stop()
135 |         cv2.imshow('frame', cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
136 |         if cv2.waitKey(1) & 0xFF == ord('q'): break  # waitKey is required for imshow to render; press q to quit
137 | else:
138 | reader = VideoReader(args.source)
139 | writer = VideoWriter(f"{args.source.rsplit('.', maxsplit=1)[0]}_out.mp4", reader.fps)
140 | fps = FPS(len(reader.frames))
141 |
142 | for frame in tqdm(reader):
143 | fps.start()
144 | output = tracking.predict(frame.numpy())
145 | fps.stop(False)
146 | writer.update(output)
147 |
148 | print(f"FPS: {fps.fps}")
149 | writer.write()
150 |
--------------------------------------------------------------------------------
/tracking/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import download
3 | from .clip import load as clip_load
4 | from .dino import load as dino_load
5 |
6 |
7 | _MODELS = {
8 | "CLIP-RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
9 | "CLIP-ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
10 | "DINO-XciT-S12/16": "https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
11 | "DINO-XciT-M24/16": "https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
12 | "DINO-ViT-S/16": "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
13 | "DINO-ViT-B/16": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
14 | }
15 |
16 |
17 | def load_feature_extractor(model_name: str, device):
18 | assert model_name in _MODELS
19 | model_path = download(_MODELS[model_name], os.path.expanduser("~/.cache/tracking"))
20 |
21 | if model_name.startswith('CLIP'):
22 | model, transform = clip_load(model_path, device, jit=False)
23 | elif model_name.startswith('DINO'):
24 | model, transform = dino_load(model_name, model_path, device)
25 | return model, transform
26 |
27 |
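# Example usage (a rough sketch; the chosen checkpoint is downloaded to ~/.cache/tracking on first
# use, and 'patch.jpg' below is just a placeholder for a detection crop):
#
#   import torch
#   from PIL import Image
#   from tracking import load_feature_extractor
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model, transform = load_feature_extractor('DINO-ViT-S/16', device)
#   patch = transform(Image.open('patch.jpg')).unsqueeze(0).to(device)
#   with torch.no_grad():
#       feature = model.encode_image(patch)    # appearance embedding used by the tracker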
--------------------------------------------------------------------------------
/tracking/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
--------------------------------------------------------------------------------
/tracking/clip/clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms as T
3 |
4 | from .model import build_model
5 |
6 |
7 | __all__ = ["load"]
8 |
9 |
10 | def _transform(n_px):
11 | return T.Compose([
12 | T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
13 | T.CenterCrop(n_px),
14 | T.ToTensor(),
15 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
16 | ])
17 |
18 |
19 | def load(model_path: str, device, jit=False):
20 | # loading JIT archive
21 | model = torch.jit.load(model_path, map_location="cpu").eval()
22 |
23 | if not jit:
24 | model = build_model(model.state_dict()).to(device)
25 | if str(device) == "cpu":
26 | model.float()
27 | return model, _transform(model.visual.input_resolution)
28 |
29 | # patch the device names
30 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
31 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
32 |
33 | def patch_device(module):
34 | try:
35 | graphs = [module.graph] if hasattr(module, "graph") else []
36 | except RuntimeError:
37 | graphs = []
38 |
39 | if hasattr(module, "forward1"):
40 | graphs.append(module.forward1.graph)
41 |
42 | for graph in graphs:
43 | for node in graph.findAllNodes("prim::Constant"):
44 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
45 | node.copyAttributes(device_node)
46 |
47 | model.apply(patch_device)
48 | patch_device(model.encode_image)
49 | patch_device(model.encode_text)
50 |
51 | # patch dtype to float32 on CPU
52 | if str(device) == "cpu":
53 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
54 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
55 | float_node = float_input.node()
56 |
57 | def patch_float(module):
58 | try:
59 | graphs = [module.graph] if hasattr(module, "graph") else []
60 | except RuntimeError:
61 | graphs = []
62 |
63 | if hasattr(module, "forward1"):
64 | graphs.append(module.forward1.graph)
65 |
66 | for graph in graphs:
67 | for node in graph.findAllNodes("aten::to"):
68 | inputs = list(node.inputs())
69 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
70 | if inputs[i].node()["value"] == 5:
71 | inputs[i].node().copyAttributes(float_node)
72 |
73 | model.apply(patch_float)
74 | patch_float(model.encode_image)
75 | patch_float(model.encode_text)
76 |
77 | model.float()
78 |
79 | return model, _transform(model.input_resolution.item())
--------------------------------------------------------------------------------
/tracking/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 |
20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24 |
25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | self.stride = stride
31 |
32 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34 | self.downsample = nn.Sequential(OrderedDict([
35 | ("-1", nn.AvgPool2d(stride)),
36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37 | ("1", nn.BatchNorm2d(planes * self.expansion))
38 | ]))
39 |
40 | def forward(self, x: torch.Tensor):
41 | identity = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.relu(self.bn2(self.conv2(out)))
45 | out = self.avgpool(out)
46 | out = self.bn3(self.conv3(out))
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 | return out
54 |
55 |
56 | class AttentionPool2d(nn.Module):
57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58 | super().__init__()
59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60 | self.k_proj = nn.Linear(embed_dim, embed_dim)
61 | self.q_proj = nn.Linear(embed_dim, embed_dim)
62 | self.v_proj = nn.Linear(embed_dim, embed_dim)
63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64 | self.num_heads = num_heads
65 |
66 | def forward(self, x):
67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70 | x, _ = F.multi_head_attention_forward(
71 | query=x, key=x, value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False
88 | )
89 |
90 | return x[0]
91 |
92 |
93 | class ModifiedResNet(nn.Module):
94 | """
95 | A ResNet class that is similar to torchvision's but contains the following changes:
96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 | - The final pooling layer is a QKV attention instead of an average pool
99 | """
100 |
101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102 | super().__init__()
103 | self.output_dim = output_dim
104 | self.input_resolution = input_resolution
105 |
106 | # the 3-layer stem
107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(width // 2)
109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110 | self.bn2 = nn.BatchNorm2d(width // 2)
111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112 | self.bn3 = nn.BatchNorm2d(width)
113 | self.avgpool = nn.AvgPool2d(2)
114 | self.relu = nn.ReLU(inplace=True)
115 |
116 | # residual layers
117 | self._inplanes = width # this is a *mutable* variable used during construction
118 | self.layer1 = self._make_layer(width, layers[0])
119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122 |
123 | embed_dim = width * 32 # the ResNet feature dimension
124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125 |
126 | def _make_layer(self, planes, blocks, stride=1):
127 | layers = [Bottleneck(self._inplanes, planes, stride)]
128 |
129 | self._inplanes = planes * Bottleneck.expansion
130 | for _ in range(1, blocks):
131 | layers.append(Bottleneck(self._inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | def stem(x):
137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138 | x = self.relu(bn(conv(x)))
139 | x = self.avgpool(x)
140 | return x
141 |
142 | x = x.type(self.conv1.weight.dtype)
143 | x = stem(x)
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 | x = self.attnpool(x)
149 |
150 | return x
151 |
152 |
153 | class LayerNorm(nn.LayerNorm):
154 | """Subclass torch's LayerNorm to handle fp16."""
155 |
156 | def forward(self, x: torch.Tensor):
157 | orig_type = x.dtype
158 | ret = super().forward(x.type(torch.float32))
159 | return ret.type(orig_type)
160 |
161 |
162 | class QuickGELU(nn.Module):
163 | def forward(self, x: torch.Tensor):
164 | return x * torch.sigmoid(1.702 * x)
165 |
166 |
167 | class ResidualAttentionBlock(nn.Module):
168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169 | super().__init__()
170 |
171 | self.attn = nn.MultiheadAttention(d_model, n_head)
172 | self.ln_1 = LayerNorm(d_model)
173 | self.mlp = nn.Sequential(OrderedDict([
174 | ("c_fc", nn.Linear(d_model, d_model * 4)),
175 | ("gelu", QuickGELU()),
176 | ("c_proj", nn.Linear(d_model * 4, d_model))
177 | ]))
178 | self.ln_2 = LayerNorm(d_model)
179 | self.attn_mask = attn_mask
180 |
181 | def attention(self, x: torch.Tensor):
182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184 |
185 | def forward(self, x: torch.Tensor):
186 | x = x + self.attention(self.ln_1(x))
187 | x = x + self.mlp(self.ln_2(x))
188 | return x
189 |
190 |
191 | class Transformer(nn.Module):
192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193 | super().__init__()
194 | self.width = width
195 | self.layers = layers
196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197 |
198 | def forward(self, x: torch.Tensor):
199 | return self.resblocks(x)
200 |
201 |
202 | class VisionTransformer(nn.Module):
203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204 | super().__init__()
205 | self.input_resolution = input_resolution
206 | self.output_dim = output_dim
207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208 |
209 | scale = width ** -0.5
210 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212 | self.ln_pre = LayerNorm(width)
213 |
214 | self.transformer = Transformer(width, layers, heads)
215 |
216 | self.ln_post = LayerNorm(width)
217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218 |
219 | def forward(self, x: torch.Tensor):
220 | x = self.conv1(x) # shape = [*, width, grid, grid]
221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
224 | x = x + self.positional_embedding.to(x.dtype)
225 | x = self.ln_pre(x)
226 |
227 | x = x.permute(1, 0, 2) # NLD -> LND
228 | x = self.transformer(x)
229 | x = x.permute(1, 0, 2) # LND -> NLD
230 |
231 | x = self.ln_post(x[:, 0, :])
232 |
233 | if self.proj is not None:
234 | x = x @ self.proj
235 |
236 | return x
237 |
238 |
239 | class CLIP(nn.Module):
240 | def __init__(self,
241 | embed_dim: int,
242 | # vision
243 | image_resolution: int,
244 | vision_layers: Union[Tuple[int, int, int, int], int],
245 | vision_width: int,
246 | vision_patch_size: int,
247 | # text
248 | context_length: int,
249 | vocab_size: int,
250 | transformer_width: int,
251 | transformer_heads: int,
252 | transformer_layers: int
253 | ):
254 | super().__init__()
255 |
256 | self.context_length = context_length
257 |
258 | if isinstance(vision_layers, (tuple, list)):
259 | vision_heads = vision_width * 32 // 64
260 | self.visual = ModifiedResNet(
261 | layers=vision_layers,
262 | output_dim=embed_dim,
263 | heads=vision_heads,
264 | input_resolution=image_resolution,
265 | width=vision_width
266 | )
267 | else:
268 | vision_heads = vision_width // 64
269 | self.visual = VisionTransformer(
270 | input_resolution=image_resolution,
271 | patch_size=vision_patch_size,
272 | width=vision_width,
273 | layers=vision_layers,
274 | heads=vision_heads,
275 | output_dim=embed_dim
276 | )
277 |
278 | self.transformer = Transformer(
279 | width=transformer_width,
280 | layers=transformer_layers,
281 | heads=transformer_heads,
282 | attn_mask=self.build_attention_mask()
283 | )
284 |
285 | self.vocab_size = vocab_size
286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
288 | self.ln_final = LayerNorm(transformer_width)
289 |
290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
292 |
293 | self.initialize_parameters()
294 |
295 | def initialize_parameters(self):
296 | nn.init.normal_(self.token_embedding.weight, std=0.02)
297 | nn.init.normal_(self.positional_embedding, std=0.01)
298 |
299 | if isinstance(self.visual, ModifiedResNet):
300 | if self.visual.attnpool is not None:
301 | std = self.visual.attnpool.c_proj.in_features ** -0.5
302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
306 |
307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
308 | for name, param in resnet_block.named_parameters():
309 | if name.endswith("bn3.weight"):
310 | nn.init.zeros_(param)
311 |
312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
313 | attn_std = self.transformer.width ** -0.5
314 | fc_std = (2 * self.transformer.width) ** -0.5
315 | for block in self.transformer.resblocks:
316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
320 |
321 | if self.text_projection is not None:
322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
323 |
324 | def build_attention_mask(self):
325 | # lazily create causal attention mask, with full attention between the vision tokens
326 | # pytorch uses additive attention mask; fill with -inf
327 | mask = torch.empty(self.context_length, self.context_length)
328 | mask.fill_(float("-inf"))
329 | mask.triu_(1) # zero out the lower diagonal
330 | return mask
331 |
332 | @property
333 | def dtype(self):
334 | return self.visual.conv1.weight.dtype
335 |
336 | def encode_image(self, image):
337 | return self.visual(image.type(self.dtype))
338 |
339 | def encode_text(self, text):
340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
341 |
342 | x = x + self.positional_embedding.type(self.dtype)
343 | x = x.permute(1, 0, 2) # NLD -> LND
344 | x = self.transformer(x)
345 | x = x.permute(1, 0, 2) # LND -> NLD
346 | x = self.ln_final(x).type(self.dtype)
347 |
348 | # x.shape = [batch_size, n_ctx, transformer.width]
349 | # take features from the eot embedding (eot_token is the highest number in each sequence)
350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
351 |
352 | return x
353 |
354 | def forward(self, image, text):
355 | image_features = self.encode_image(image)
356 | text_features = self.encode_text(text)
357 |
358 | # normalized features
359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
361 |
362 | # cosine similarity as logits
363 | logit_scale = self.logit_scale.exp()
364 | logits_per_image = logit_scale * image_features @ text_features.t()
365 | logits_per_text = logits_per_image.t()
366 |
367 | # shape = [global_batch_size, global_batch_size]
368 | return logits_per_image, logits_per_text
369 |
370 |
371 | def convert_weights(model: nn.Module):
372 | """Convert applicable model parameters to fp16"""
373 |
374 | def _convert_weights_to_fp16(l):
375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
376 | l.weight.data = l.weight.data.half()
377 | if l.bias is not None:
378 | l.bias.data = l.bias.data.half()
379 |
380 | if isinstance(l, nn.MultiheadAttention):
381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
382 | tensor = getattr(l, attr)
383 | if tensor is not None:
384 | tensor.data = tensor.data.half()
385 |
386 | for name in ["text_projection", "proj"]:
387 | if hasattr(l, name):
388 | attr = getattr(l, name)
389 | if attr is not None:
390 | attr.data = attr.data.half()
391 |
392 | model.apply(_convert_weights_to_fp16)
393 |
394 |
395 | def build_model(state_dict: dict):
396 | vit = "visual.proj" in state_dict
397 |
398 | if vit:
399 | vision_width = state_dict["visual.conv1.weight"].shape[0]
400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
403 | image_resolution = vision_patch_size * grid_size
404 | else:
405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
406 | vision_layers = tuple(counts)
407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
409 | vision_patch_size = None
410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
411 | image_resolution = output_width * 32
412 |
413 | embed_dim = state_dict["text_projection"].shape[1]
414 | context_length = state_dict["positional_embedding"].shape[0]
415 | vocab_size = state_dict["token_embedding.weight"].shape[0]
416 | transformer_width = state_dict["ln_final.weight"].shape[0]
417 | transformer_heads = transformer_width // 64
418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
419 |
420 | model = CLIP(
421 | embed_dim,
422 | image_resolution, vision_layers, vision_width, vision_patch_size,
423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
424 | )
425 |
426 | for key in ["input_resolution", "context_length", "vocab_size"]:
427 | if key in state_dict:
428 | del state_dict[key]
429 |
430 | convert_weights(model)
431 | model.load_state_dict(state_dict)
432 | return model.eval()
433 |
--------------------------------------------------------------------------------
/tracking/dino/__init__.py:
--------------------------------------------------------------------------------
1 | from .dino import *
--------------------------------------------------------------------------------
/tracking/dino/dino.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms as T
3 | from .xcit import XciT
4 | from .vit import ViT
5 |
6 | __all__ = ["load"]
7 |
8 |
9 | def load(model_name, model_path, device):
10 | _, base_name, variant = model_name.split('-')
11 | model = eval(base_name)(variant)
12 | model.load_state_dict(torch.load(model_path, map_location='cpu'))
13 | model = model.to(device)
14 | model.eval()
15 |
16 | transform = T.Compose([
17 | T.Resize((224, 224)),
18 | T.ToTensor(),
19 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20 | ])
21 |
22 | return model, transform
23 |
24 |
--------------------------------------------------------------------------------
/tracking/dino/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn.functional as F
4 | from torch import nn, Tensor
5 |
6 |
7 | class MLP(nn.Module):
8 | def __init__(self, dim, hidden_dim, out_dim=None) -> None:
9 | super().__init__()
10 | out_dim = out_dim or dim
11 | self.fc1 = nn.Linear(dim, hidden_dim)
12 | self.act = nn.GELU()
13 | self.fc2 = nn.Linear(hidden_dim, out_dim)
14 |
15 | def forward(self, x: Tensor) -> Tensor:
16 | return self.fc2(self.act(self.fc1(x)))
17 |
18 |
19 | class PatchEmbedding(nn.Module):
20 | """Image to Patch Embedding
21 | """
22 | def __init__(self, img_size=224, patch_size=16, embed_dim=768):
23 | super().__init__()
24 | assert img_size % patch_size == 0, 'Image size must be divisible by patch size'
25 |
26 | img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
27 |
28 | self.grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)
29 | self.num_patches = self.grid_size[0] * self.grid_size[1]
30 | self.proj = nn.Conv2d(3, embed_dim, patch_size, patch_size)
31 |
32 | def forward(self, x: Tensor) -> Tensor:
33 | x = self.proj(x) # b x hidden_dim x 14 x 14
34 | x = x.flatten(2).swapaxes(1, 2) # b x (14*14) x hidden_dim
35 | return x
36 |
37 |
38 | class Attention(nn.Module):
39 | def __init__(self, dim, heads=12):
40 | super().__init__()
41 | self.num_heads = heads
42 | self.scale = (dim // heads) ** -0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=True)
45 | self.proj = nn.Linear(dim, dim)
46 |
47 | def forward(self, x: Tensor) -> Tensor:
48 | B, N, C = x.shape
49 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
50 | q, k, v = qkv[0], qkv[1], qkv[2]
51 |
52 | attn = (q @ k.transpose(-2, -1)) * self.scale
53 | attn = attn.softmax(dim=-1)
54 |
55 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
56 | x = self.proj(x)
57 | return x
58 |
59 |
60 | class TransformerEncoder(nn.Module):
61 | def __init__(self, dim, heads):
62 | super().__init__()
63 | self.norm1 = nn.LayerNorm(dim)
64 | self.attn = Attention(dim, heads)
65 | self.norm2 = nn.LayerNorm(dim)
66 | self.mlp = MLP(dim, int(dim * 4))
67 |
68 | def forward(self, x: Tensor) -> Tensor:
69 | x += self.attn(self.norm1(x))
70 | x += self.mlp(self.norm2(x))
71 | return x
72 |
73 |
74 | vit_settings = {
75 | 'S/8': [8, 12, 384, 6], #[patch_size, number_of_layers, embed_dim, heads]
76 | 'S/16': [16, 12, 384, 6],
77 | 'B/16': [16, 12, 768, 12]
78 | }
79 |
80 |
81 | class ViT(nn.Module):
82 | def __init__(self, model_name: str = 'S/8', image_size: int = 224) -> None:
83 | super().__init__()
84 | assert model_name in vit_settings.keys(), f"DeiT model name should be in {list(vit_settings.keys())}"
85 | patch_size, layers, embed_dim, heads = vit_settings[model_name]
86 |
87 | self.patch_size = patch_size
88 | self.patch_embed = PatchEmbedding(image_size, patch_size, embed_dim)
89 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
90 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
91 |
92 | self.blocks = nn.ModuleList([
93 | TransformerEncoder(embed_dim, heads)
94 | for i in range(layers)])
95 |
96 | self.norm = nn.LayerNorm(embed_dim)
97 |
98 | def interpolate_pos_encoding(self, x: Tensor, W: int, H: int) -> Tensor:
99 | num_patches = x.shape[1] - 1
100 | N = self.pos_embed.shape[1] - 1
101 |
102 | if num_patches == N and H == W:
103 | return self.pos_embed
104 |
105 | class_pos_embed = self.pos_embed[:, 0]
106 | patch_pos_embed = self.pos_embed[:, 1:]
107 |
108 | dim = x.shape[-1]
109 | w0 = W // self.patch_size
110 | h0 = H // self.patch_size
111 |
112 | w0, h0 = w0 + 0.1, h0 + 0.1
113 |
114 | patch_pos_embed = F.interpolate(
115 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
116 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
117 | mode='bicubic'
118 | )
119 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
120 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
121 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
122 |
123 | def encode_image(self, x):
124 | return self.forward(x)
125 |
126 | def forward(self, x: Tensor) -> Tensor:
127 | B, C, W, H = x.shape
128 | x = self.patch_embed(x)
129 | cls_token = self.cls_token.expand(x.shape[0], -1, -1)
130 | x = torch.cat((cls_token, x), dim=1)
131 | x += self.interpolate_pos_encoding(x, W, H)
132 |
133 | for blk in self.blocks:
134 | x = blk(x)
135 |
136 | x = self.norm(x)
137 | return x[:, 0]
138 |
139 |
140 | if __name__ == '__main__':
141 | model = ViT('S/16')
142 | model.load_state_dict(torch.load('checkpoints/vit/dino_deitsmall16_pretrain.pth', map_location='cpu'))
143 | x = torch.zeros(1, 3, 224, 224)
144 | y = model(x)
145 | print(y.shape)
--------------------------------------------------------------------------------
/tracking/dino/xcit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn.functional as F
4 | from torch import nn, Tensor
5 |
6 |
7 | class MLP(nn.Module):
8 | def __init__(self, dim, hidden_dim, out_dim=None) -> None:
9 | super().__init__()
10 | out_dim = out_dim or dim
11 | self.fc1 = nn.Linear(dim, hidden_dim)
12 | self.act = nn.GELU()
13 | self.fc2 = nn.Linear(hidden_dim, out_dim)
14 |
15 | def forward(self, x: Tensor) -> Tensor:
16 | return self.fc2(self.act(self.fc1(x)))
17 |
18 |
19 | class PositionalEncodingFourier(nn.Module):
20 | def __init__(self, dim: int = 768):
21 | super().__init__()
22 | self.dim = dim
23 | self.hidden_dim = 32
24 | self.token_projection = nn.Conv2d(self.hidden_dim * 2, dim, 1)
25 | self.scale = 2 * math.pi
26 |
27 | def forward(self, B: int, H: int, W: int) -> Tensor:
28 | mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device)
29 | not_mask = ~mask
30 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
31 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
32 | y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
33 | x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
34 |
35 | dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device)
36 | dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / self.hidden_dim)
37 |
38 | pos_x = x_embed[:, :, :, None] / dim_t
39 | pos_y = y_embed[:, :, :, None] / dim_t
40 |
41 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
42 | pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
43 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
44 | pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
46 | pos = self.token_projection(pos)
47 | return pos
48 |
49 |
50 | class Conv3x3(nn.Sequential):
51 | def __init__(self, c1, c2, s=1):
52 | super().__init__(
53 | nn.Conv2d(c1, c2, 3, s, 1, bias=False),
54 | nn.BatchNorm2d(c2)
55 | )
56 |
57 |
58 | class ConvPatchEmbed(nn.Module):
59 | """Image to Patch Embedding using multiple convolutional layers
60 | """
61 | def __init__(self, patch_size=8, embed_dim=768):
62 | super().__init__()
63 | if patch_size == 16:
64 | self.proj = nn.Sequential(
65 | Conv3x3(3, embed_dim // 8, 2),
66 | nn.GELU(),
67 | Conv3x3(embed_dim // 8, embed_dim // 4, 2),
68 | nn.GELU(),
69 | Conv3x3(embed_dim // 4, embed_dim // 2, 2),
70 | nn.GELU(),
71 | Conv3x3(embed_dim // 2, embed_dim, 2),
72 | )
73 | else:
74 | self.proj = nn.Sequential(
75 | Conv3x3(3, embed_dim // 4, 2),
76 | nn.GELU(),
77 | Conv3x3(embed_dim // 4, embed_dim // 2, 2),
78 | nn.GELU(),
79 | Conv3x3(embed_dim // 2, embed_dim, 2),
80 | )
81 |
82 | def forward(self, x: Tensor):
83 | x = self.proj(x)
84 | _, _, H, W = x.shape
85 | x = x.flatten(2).transpose(1, 2)
86 | return x, (H, W)
87 |
88 |
89 | class LPI(nn.Module):
90 | """
91 | Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows
92 |     to augment the implicit communication performed by the block diagonal scatter attention.
93 | Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d
94 | """
95 | def __init__(self, dim: int):
96 | super().__init__()
97 | self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
98 | self.act = nn.GELU()
99 | self.bn = nn.BatchNorm2d(dim)
100 | self.conv2 = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
101 |
102 | def forward(self, x: Tensor, H: int, W: int) -> Tensor:
103 | B, N, C = x.shape
104 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
105 | x = self.conv2(self.bn(self.act(self.conv1(x))))
106 | x = x.reshape(B, C, N).permute(0, 2, 1)
107 | return x
108 |
109 |
110 | class ClassAttention(nn.Module):
111 | """ClassAttention as in CaiT
112 | """
113 | def __init__(self, dim: int, heads: int):
114 | super().__init__()
115 | self.num_heads = heads
116 | self.scale = (dim // heads) ** -0.5
117 |
118 | self.qkv = nn.Linear(dim, dim * 3)
119 | self.proj = nn.Linear(dim, dim)
120 |
121 | def forward(self, x: Tensor) -> Tensor:
122 | B, N, C = x.shape
123 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124 | q, k, v = qkv[0], qkv[1], qkv[2]
125 |
126 | qc = q[:, :, 0:1] # CLS token
127 |
128 | attn_cls = (qc * k).sum(dim=-1) * self.scale
129 | attn_cls = attn_cls.softmax(dim=-1)
130 |
131 | cls_token = (attn_cls.unsqueeze(2) @ v).transpose(1, 2).reshape(B, 1, C)
132 | cls_token = self.proj(cls_token)
133 |
134 | x = torch.cat([cls_token, x[:, 1:]], dim=1)
135 | return x
136 |
137 |
138 | class XCA(nn.Module):
139 | """ Cross-Covariance Attention (XCA) operation where the channels are updated using a weighted
140 | sum. The weights are obtained from the (softmax normalized) Cross-covariance
141 | matrix (Q^T K \\in d_h \\times d_h)
142 | """
143 | def __init__(self, dim: int, heads: int):
144 | super().__init__()
145 | self.num_heads = heads
146 | self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
147 |
148 | self.qkv = nn.Linear(dim, dim * 3)
149 | self.proj = nn.Linear(dim, dim)
150 |
151 | def forward(self, x: Tensor) -> Tensor:
152 | B, N, C = x.shape
153 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
154 | q, k, v = qkv[0].transpose(-2, -1), qkv[1].transpose(-2, -1), qkv[2].transpose(-2, -1)
155 | q = F.normalize(q, dim=-1)
156 | k = F.normalize(k, dim=-1)
157 | attn = (q @ k.transpose(-2, -1)) * self.temperature
158 | attn = attn.softmax(dim=-1)
159 |
160 | x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
161 | x = self.proj(x)
162 | return x
163 |
164 |
165 | class ClassAttentionBlock(nn.Module):
166 | def __init__(self, dim, heads, eta=1e-5):
167 | super().__init__()
168 | self.norm1 = nn.LayerNorm(dim)
169 | self.attn = ClassAttention(dim, heads)
170 | self.norm2 = nn.LayerNorm(dim)
171 | self.mlp = MLP(dim, int(dim * 4))
172 |
173 | self.gamma1 = nn.Parameter(eta * torch.ones(dim))
174 | self.gamma2 = nn.Parameter(eta * torch.ones(dim))
175 |
176 | def forward(self, x: Tensor) -> Tensor:
177 | x = x + (self.gamma1 * self.attn(self.norm1(x)))
178 | x = self.norm2(x)
179 |
180 | x_res = x
181 | cls_token = self.gamma2 * self.mlp(x[:, :1])
182 | x = torch.cat([cls_token, x[:, 1:]], dim=1)
183 | x += x_res
184 | return x
185 |
186 |
187 | class XCABlock(nn.Module):
188 | def __init__(self, dim, heads, eta=1e-5):
189 | super().__init__()
190 | self.norm1 = nn.LayerNorm(dim)
191 | self.attn = XCA(dim, heads)
192 | self.norm2 = nn.LayerNorm(dim)
193 | self.mlp = MLP(dim, int(dim * 4))
194 | self.norm3 = nn.LayerNorm(dim)
195 | self.local_mp = LPI(dim)
196 |
197 | self.gamma1 = nn.Parameter(eta * torch.ones(dim))
198 | self.gamma2 = nn.Parameter(eta * torch.ones(dim))
199 | self.gamma3 = nn.Parameter(eta * torch.ones(dim))
200 |
201 | def forward(self, x: Tensor, H, W) -> Tensor:
202 | x = x + self.gamma1 * self.attn(self.norm1(x))
203 | x = x + self.gamma3 * self.local_mp(self.norm3(x), H, W)
204 | x = x + self.gamma2 * self.mlp(self.norm2(x))
205 | return x
206 |
207 |
208 | xcit_settings = {
209 | 'S12/8': [8, 12, 384, 8], #[patch_size, layers, embed dim, heads]
210 | 'S12/16': [16, 12, 384, 8],
211 | 'M24/16': [16, 24, 512, 8],
212 | }
213 |
214 |
215 | class XciT(nn.Module):
216 | def __init__(self, model_name: str = 'S12/8', *args, **kwargs) -> None:
217 | super().__init__()
218 | assert model_name in xcit_settings.keys(), f"XciT model name should be in {list(xcit_settings.keys())}"
219 | patch_size, layers, embed_dim, heads = xcit_settings[model_name]
220 |
221 | self.patch_embed = ConvPatchEmbed(patch_size, embed_dim)
222 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
223 |
224 | self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
225 |
226 | self.blocks = nn.ModuleList([
227 | XCABlock(embed_dim, heads)
228 | for _ in range(layers)])
229 |
230 | self.cls_attn_blocks = nn.ModuleList([
231 | ClassAttentionBlock(embed_dim, heads)
232 | for _ in range(2)])
233 | self.norm = nn.LayerNorm(embed_dim)
234 |
235 | def encode_image(self, x):
236 | return self.forward(x)
237 |
238 | def forward(self, x):
239 | B = x.shape[0]
240 | x, (Hp, Wp) = self.patch_embed(x)
241 | x += self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
242 |
243 | for blk in self.blocks:
244 | x = blk(x, Hp, Wp)
245 |
246 | cls_tokens = self.cls_token.expand(B, -1, -1)
247 | x = torch.cat((cls_tokens, x), dim=1)
248 |
249 | for blk in self.cls_attn_blocks:
250 | x = blk(x)
251 |
252 | x = self.norm(x)
253 | return x[:, 0]
254 |
255 |
256 | if __name__ == '__main__':
257 | model = XciT('S12/16')
258 | model.load_state_dict(torch.load('checkpoints/xcit/dino_xcit_small_12_p16_pretrain.pth', map_location='cpu'))
259 | x = torch.zeros(1, 3, 224, 224)
260 | y = model(x)
261 | print(y.shape)
262 |
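
A quick shape sanity check for the attention modules above; a minimal sketch, assuming the repository root is on PYTHONPATH so `tracking.dino.xcit` is importable, with arbitrary dummy sizes:

import torch
from tracking.dino.xcit import XCA, XCABlock

x = torch.randn(2, 196, 384)          # (batch, tokens, dim), e.g. a 14x14 patch grid
attn = XCA(dim=384, heads=8)
print(attn(x).shape)                  # torch.Size([2, 196, 384])
block = XCABlock(dim=384, heads=8)
print(block(x, 14, 14).shape)         # the LPI layer inside the block needs the (H, W) token grid

Because XCA softmax-normalizes a d_h x d_h cross-covariance matrix per head instead of an N x N token-token matrix, its cost grows linearly with the number of tokens.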
--------------------------------------------------------------------------------
/tracking/sort/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/tracking/sort/__init__.py
--------------------------------------------------------------------------------
/tracking/sort/detection.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Detection:
5 | """Bounding box detection in a single image
6 | Parameters
7 | ----------
8 | tlwh : (ndarray) bbox in format `(top left x, top left y, width, height)`.
9 |     class_id : (int) Detector class id.
10 |     feature : (ndarray) A feature vector that describes the object contained in
11 |         this image.
12 | """
13 | def __init__(self, tlwh, class_id, feature):
14 | self.tlwh = np.asarray(tlwh, dtype=np.float32)
15 | self.feature = np.asarray(feature, dtype=np.float32)
16 | self.class_id = class_id
17 |
18 | def to_tlbr(self):
19 |         """Convert bbox from `(top left x, top left y, width, height)` to `(min x, min y, max x, max y)`
20 | """
21 | ret = self.tlwh.copy()
22 | ret[2:] += ret[:2]
23 | return ret
24 |
25 | def to_xyah(self):
26 |         """Convert bbox from `(top left x, top left y, width, height)` to `(center x, center y, aspect ratio, height)` where the aspect ratio is `width / height`
27 | """
28 | ret = self.tlwh.copy()
29 | ret[:2] += ret[2:] / 2
30 | ret[2] /= ret[3]
31 | return ret
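
A small worked example of the two box conversions above (a sketch; the 512-d zero feature vector is an arbitrary placeholder for a real appearance embedding):

import numpy as np
from tracking.sort.detection import Detection

det = Detection(tlwh=[100., 50., 40., 80.], class_id=2, feature=np.zeros(512, dtype=np.float32))
print(det.to_tlbr())   # [100.  50. 140. 130.] -> (min x, min y, max x, max y)
print(det.to_xyah())   # [120.  90.   0.5  80.] -> (center x, center y, w/h, h)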
--------------------------------------------------------------------------------
/tracking/sort/kalman_filter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.linalg
3 |
4 |
5 | class KalmanFilter:
6 | """A simple Kalman filter for tracking bounding boxes in image space
7 | The 8-dimensional state space
8 | x, y, a, h, vx, vy, va, vh
9 | contains the bounding box center position (x, y), aspect ratio a, height h,
10 | and their respective velocities.
11 | Object motion follows a constant velocity model. The bounding box location
12 | (x, y, a, h) is taken as direct observation of the state space (linear
13 | observation model).
14 | """
15 |
16 | def __init__(self):
17 | ndim, dt = 4, 1.
18 |
19 | # Create Kalman filter model matrices.
20 | self._motion_mat = np.eye(2 * ndim, 2 * ndim)
21 | for i in range(ndim):
22 | self._motion_mat[i, ndim + i] = dt
23 | self._update_mat = np.eye(ndim, 2 * ndim)
24 |
25 | # Motion and observation uncertainty are chosen relative to the current
26 | # state estimate. These weights control the amount of uncertainty in
27 | # the model. This is a bit hacky.
28 | self._std_weight_position = 1. / 20
29 | self._std_weight_velocity = 1. / 160
30 |
31 | def initiate(self, measurement):
32 | """Create track from unassociated measurement.
33 | Parameters
34 | ----------
35 | measurement : (ndarray) Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, and height h.
36 | Returns
37 | -------
38 | (ndarray, ndarray)
39 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
40 | """
41 | mean_pos = measurement
42 | mean_vel = np.zeros_like(mean_pos)
43 | mean = np.r_[mean_pos, mean_vel]
44 |
45 | std = [
46 | 2 * self._std_weight_position * measurement[3],
47 | 2 * self._std_weight_position * measurement[3],
48 | 1e-2,
49 | 2 * self._std_weight_position * measurement[3],
50 | 10 * self._std_weight_velocity * measurement[3],
51 | 10 * self._std_weight_velocity * measurement[3],
52 | 1e-5,
53 | 10 * self._std_weight_velocity * measurement[3]
54 | ]
55 | covariance = np.diag(np.square(std))
56 | return mean, covariance
57 |
58 | def predict(self, mean, covariance):
59 | """Run Kalman filter prediction step.
60 | Parameters
61 | ----------
62 | mean : ndarray
63 | The 8 dimensional mean vector of the object state at the previous
64 | time step.
65 | covariance : ndarray
66 | The 8x8 dimensional covariance matrix of the object state at the
67 | previous time step.
68 | Returns
69 | -------
70 | (ndarray, ndarray)
71 | Returns the mean vector and covariance matrix of the predicted
72 | state. Unobserved velocities are initialized to 0 mean.
73 | """
74 | std_pos = [
75 | self._std_weight_position * mean[3],
76 | self._std_weight_position * mean[3],
77 | 1e-2,
78 | self._std_weight_position * mean[3]]
79 | std_vel = [
80 | self._std_weight_velocity * mean[3],
81 | self._std_weight_velocity * mean[3],
82 | 1e-5,
83 | self._std_weight_velocity * mean[3]]
84 |
85 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
86 | mean = np.dot(self._motion_mat, mean)
87 | covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
88 |
89 | return mean, covariance
90 |
91 | def project(self, mean, covariance):
92 | """Project state distribution to measurement space.
93 | Parameters
94 | ----------
95 | mean : (ndarray) The state's mean vector (8 dimensional array).
96 | covariance : (ndarray) The state's covariance matrix (8x8 dimensional).
97 | Returns
98 | -------
99 | (ndarray, ndarray)
100 | Returns the projected mean and covariance matrix of the given state estimate.
101 | """
102 | std = [
103 | self._std_weight_position * mean[3],
104 | self._std_weight_position * mean[3],
105 | 1e-1,
106 | self._std_weight_position * mean[3]
107 | ]
108 |
109 | innovation_cov = np.diag(np.square(std))
110 | mean = np.dot(self._update_mat, mean)
111 | covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) + innovation_cov
112 |
113 | return mean, covariance
114 |
115 | def update(self, mean, covariance, measurement):
116 | """Run Kalman filter correction step.
117 | Parameters
118 | ----------
119 | mean : (ndarray) The predicted state's mean vector (8 dimensional).
120 | covariance : (ndarray) The state's covariance matrix (8x8 dimensional).
121 | measurement : (ndarray) The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect ratio, and h the height of the bounding box.
122 |
123 | Returns
124 | -------
125 | (ndarray, ndarray)
126 | Returns the measurement-corrected state distribution.
127 | """
128 | projected_mean, projected_cov = self.project(mean, covariance)
129 |
130 | chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
131 | kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False).T
132 | innovation = measurement - projected_mean
133 |
134 | new_mean = mean + np.dot(innovation, kalman_gain.T)
135 | new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
136 | return new_mean, new_covariance
137 |
138 | def gating_distance(self, mean, covariance, measurements):
139 | """Compute gating distance between state distribution and measurements.
140 | Parameters
141 | ----------
142 | mean : (ndarray) Mean vector over the state distribution (8 dimensional).
143 | covariance : (ndarray) Covariance of the state distribution (8x8 dimensional).
144 | measurements : (ndarray)
145 | An Nx4 dimensional matrix of N measurements, each in
146 | format (x, y, a, h) where (x, y) is the bounding box center
147 | position, a the aspect ratio, and h the height.
148 | Returns
149 | -------
150 | ndarray
151 | Returns an array of length N, where the i-th element contains the
152 | squared Mahalanobis distance between (mean, covariance) and
153 | `measurements[i]`.
154 | """
155 | mean, covariance = self.project(mean, covariance)
156 | cholesky_factor = np.linalg.cholesky(covariance)
157 | d = measurements - mean
158 | z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
159 | squared_maha = np.sum(z * z, axis=0)
160 | return squared_maha
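
A minimal initiate/predict/update cycle for the filter above; a sketch assuming the package is importable as `tracking.sort` and using made-up measurements:

import numpy as np
from tracking.sort.kalman_filter import KalmanFilter

kf = KalmanFilter()
z0 = np.array([320., 240., 0.5, 120.])      # first measurement: (cx, cy, aspect ratio, height)
mean, cov = kf.initiate(z0)                 # 8-dim state, velocities start at 0 mean
mean, cov = kf.predict(mean, cov)           # constant-velocity prediction step
mean, cov = kf.update(mean, cov, np.array([324., 238., 0.5, 121.]))  # correct with the next box
print(mean[:4])                             # filtered (cx, cy, a, h)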
--------------------------------------------------------------------------------
/tracking/sort/matching.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import linear_sum_assignment
3 |
4 |
5 | def iou(bbox, candidates):
6 | """Compute IoU by one box to N candidates
7 | Parameters
8 | ----------
9 | bbox : (ndarray) A bounding box in format `(top left x, top left y, width, height)`.
10 | candidates : (ndarray) A matrix of candidate bounding boxes (one per row) in the same format as `bbox`.
11 |
12 | Returns
13 | -------
14 | ndarray
15 | The intersection over union in [0, 1] between the `bbox` and each candidate. A higher score means a larger fraction of the `bbox` is occluded by the candidate.
16 | """
17 | bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
18 | candidates_tl, candidates_br = candidates[:, :2], candidates[:, :2] + candidates[:, 2:]
19 |
20 | tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
21 | br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
22 | wh = np.maximum(0., br - tl)
23 |
24 | area_intersection = wh.prod(axis=1)
25 | area_bbox = bbox[2:].prod()
26 | area_candidates = candidates[:, 2:].prod(axis=1)
27 | return area_intersection / (area_bbox + area_candidates - area_intersection)
28 |
29 |
30 | def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
31 | """An intersection over union distance metric.
32 | Parameters
33 | ----------
34 | tracks : List[deep_sort.track.Track] A list of tracks.
35 | detections : List[deep_sort.detection.Detection] A list of detections.
36 | track_indices : Optional[List[int]] A list of indices to tracks that should be matched. Defaults to all `tracks`.
37 | detection_indices : Optional[List[int]] A list of indices to detections that should be matched. Defaults to all `detections`.
38 |
39 | Returns
40 | -------
41 | ndarray
42 | Returns a cost matrix of shape len(track_indices), len(detection_indices) where entry (i, j) is `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
43 | """
44 | if track_indices is None: track_indices = np.arange(len(tracks))
45 | if detection_indices is None: detection_indices = np.arange(len(detections))
46 |
47 | cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
48 | for row, track_idx in enumerate(track_indices):
49 | if tracks[track_idx].time_since_update > 1:
50 | cost_matrix[row, :] = 1e+5
51 | continue
52 |
53 | bbox = tracks[track_idx].to_tlwh()
54 | candidates = np.asarray([detections[i].tlwh for i in detection_indices])
55 | cost_matrix[row, :] = 1. - iou(bbox, candidates)
56 | return cost_matrix
57 |
58 |
59 | def _nn_euclidean_distance(a, b):
60 | """Compute pair-wise squared distance between points in `a` and `b`.
61 | Parameters
62 | ----------
63 | a : array_like
64 | An NxM matrix of N samples of dimensionality M.
65 | b : array_like
66 | An LxM matrix of L samples of dimensionality M.
67 | Returns
68 | -------
69 | ndarray
70 |         Returns a vector of length len(b) such that element j contains the
71 |         smallest squared distance between `b[j]` and any sample in `a`.
72 | """
73 | a, b = np.asarray(a), np.asarray(b)
74 | if len(a) == 0 or len(b) == 0:
75 | return np.zeros((len(a), len(b)))
76 | a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
77 | distances = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
78 | distances = np.clip(distances, 0., float(np.inf))
79 | return np.maximum(0.0, distances.min(axis=0))
80 |
81 |
82 | def _nn_cosine_distance(a, b):
83 | """Compute pair-wise cosine distance between points in `a` and `b`.
84 | Parameters
85 | ----------
86 | a : array_like
87 | An NxM matrix of N samples of dimensionality M.
88 | b : array_like
89 | An LxM matrix of L samples of dimensionality M.
90 |
91 | Returns
92 | -------
93 | ndarray
94 |         Returns a vector of length len(b) such that element j contains the
95 |         smallest cosine distance between `b[j]` and any sample in `a`.
96 | """
97 | a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
98 | b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
99 | distances = 1. - np.dot(a, b.T)
100 | return distances.min(axis=0)
101 |
102 |
103 | class NearestNeighborDistanceMetric:
104 | """
105 | A nearest neighbor distance metric that, for each target, returns
106 | the closest distance to any sample that has been observed so far.
107 | Parameters
108 | ----------
109 | metric : str
110 | Either "euclidean" or "cosine".
111 | matching_threshold: float
112 | The matching threshold. Samples with larger distance are considered an
113 | invalid match.
114 | budget : Optional[int]
115 | If not None, fix samples per class to at most this number. Removes
116 | the oldest samples when the budget is reached.
117 | Attributes
118 | ----------
119 | samples : Dict[int -> List[ndarray]]
120 | A dictionary that maps from target identities to the list of samples
121 | that have been observed so far.
122 | """
123 |
124 | def __init__(self, metric, matching_threshold, budget=None):
125 | if metric == "euclidean":
126 | self._metric = _nn_euclidean_distance
127 | elif metric == "cosine":
128 | self._metric = _nn_cosine_distance
129 | else:
130 | raise ValueError("Invalid metric; must be either 'euclidean' or 'cosine'")
131 | self.matching_threshold = matching_threshold
132 | self.budget = budget
133 | self.samples = {}
134 |
135 | def partial_fit(self, features, targets, active_targets):
136 | """Update the distance metric with new data.
137 | Parameters
138 | ----------
139 | features : ndarray
140 | An NxM matrix of N features of dimensionality M.
141 | targets : ndarray
142 | An integer array of associated target identities.
143 | active_targets : List[int]
144 | A list of targets that are currently present in the scene.
145 | """
146 | for feature, target in zip(features, targets):
147 | self.samples.setdefault(target, []).append(feature)
148 | if self.budget is not None:
149 | self.samples[target] = self.samples[target][-self.budget:]
150 | self.samples = {k: self.samples[k] for k in active_targets}
151 |
152 | def distance(self, features, targets):
153 | """Compute distance between features and targets.
154 | Parameters
155 | ----------
156 | features : ndarray
157 | An NxM matrix of N features of dimensionality M.
158 | targets : List[int]
159 | A list of targets to match the given `features` against.
160 | Returns
161 | -------
162 | ndarray
163 | Returns a cost matrix of shape len(targets), len(features), where
164 | element (i, j) contains the closest squared distance between
165 | `targets[i]` and `features[j]`.
166 | """
167 | cost_matrix = np.zeros((len(targets), len(features)))
168 | for i, target in enumerate(targets):
169 | cost_matrix[i, :] = self._metric(self.samples[target], features)
170 | return cost_matrix
171 |
172 |
173 | def min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices=None, detection_indices=None):
174 | """Solve linear assignment problem.
175 | Parameters
176 | ----------
177 |     distance_metric : Callable[List[Track], List[Detection], List[int], List[int]] -> ndarray
178 | The distance metric is given a list of tracks and detections as well as
179 | a list of N track indices and M detection indices. The metric should
180 | return the NxM dimensional cost matrix, where element (i, j) is the
181 | association cost between the i-th track in the given track indices and
182 | the j-th detection in the given detection_indices.
183 | max_distance : float
184 | Gating threshold. Associations with cost larger than this value are
185 | disregarded.
186 | tracks : List[track.Track]
187 | A list of predicted tracks at the current time step.
188 | detections : List[detection.Detection]
189 | A list of detections at the current time step.
190 | track_indices : List[int]
191 | List of track indices that maps rows in `cost_matrix` to tracks in
192 | `tracks` (see description above).
193 | detection_indices : List[int]
194 | List of detection indices that maps columns in `cost_matrix` to
195 | detections in `detections` (see description above).
196 | Returns
197 | -------
198 | (List[(int, int)], List[int], List[int])
199 | Returns a tuple with the following three entries:
200 | * A list of matched track and detection indices.
201 | * A list of unmatched track indices.
202 | * A list of unmatched detection indices.
203 | """
204 | if track_indices is None: track_indices = np.arange(len(tracks))
205 | if detection_indices is None: detection_indices = np.arange(len(detections))
206 |
207 | if len(detection_indices) == 0 or len(track_indices) == 0:
208 | return [], track_indices, detection_indices # Nothing to match.
209 |
210 | cost_matrix = distance_metric(tracks, detections, track_indices, detection_indices)
211 | cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
212 |
213 | row_indices, col_indices = linear_sum_assignment(cost_matrix)
214 |
215 | matches, unmatched_tracks, unmatched_detections = [], [], []
216 | for col, detection_idx in enumerate(detection_indices):
217 | if col not in col_indices:
218 | unmatched_detections.append(detection_idx)
219 | for row, track_idx in enumerate(track_indices):
220 | if row not in row_indices:
221 | unmatched_tracks.append(track_idx)
222 | for row, col in zip(row_indices, col_indices):
223 | track_idx = track_indices[row]
224 | detection_idx = detection_indices[col]
225 | if cost_matrix[row, col] > max_distance:
226 | unmatched_tracks.append(track_idx)
227 | unmatched_detections.append(detection_idx)
228 | else:
229 | matches.append((track_idx, detection_idx))
230 | return matches, unmatched_tracks, unmatched_detections
231 |
232 |
233 | def matching_cascade(distance_metric, max_distance, cascade_depth, tracks, detections, track_indices=None, detection_indices=None):
234 | """Run matching cascade.
235 | Parameters
236 | ----------
237 |     distance_metric : Callable[List[Track], List[Detection], List[int], List[int]] -> ndarray
238 | The distance metric is given a list of tracks and detections as well as
239 | a list of N track indices and M detection indices. The metric should
240 | return the NxM dimensional cost matrix, where element (i, j) is the
241 | association cost between the i-th track in the given track indices and
242 | the j-th detection in the given detection indices.
243 | max_distance : float
244 | Gating threshold. Associations with cost larger than this value are
245 | disregarded.
246 | cascade_depth: int
247 |         The cascade depth, should be set to the maximum track age.
248 | tracks : List[track.Track]
249 | A list of predicted tracks at the current time step.
250 | detections : List[detection.Detection]
251 | A list of detections at the current time step.
252 | track_indices : Optional[List[int]]
253 | List of track indices that maps rows in `cost_matrix` to tracks in
254 | `tracks` (see description above). Defaults to all tracks.
255 | detection_indices : Optional[List[int]]
256 | List of detection indices that maps columns in `cost_matrix` to
257 | detections in `detections` (see description above). Defaults to all
258 | detections.
259 | Returns
260 | -------
261 | (List[(int, int)], List[int], List[int])
262 | Returns a tuple with the following three entries:
263 | * A list of matched track and detection indices.
264 | * A list of unmatched track indices.
265 | * A list of unmatched detection indices.
266 | """
267 | if track_indices is None: track_indices = list(range(len(tracks)))
268 | if detection_indices is None: detection_indices = list(range(len(detections)))
269 |
270 | unmatched_detections = detection_indices
271 | matches = []
272 | for level in range(cascade_depth):
273 | if len(unmatched_detections) == 0: # No detections left
274 | break
275 |
276 | track_indices_l = [k for k in track_indices if tracks[k].time_since_update == 1 + level]
277 | if len(track_indices_l) == 0: # Nothing to match at this level
278 | continue
279 |
280 | matches_l, _, unmatched_detections = min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices_l, unmatched_detections)
281 | matches += matches_l
282 | unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
283 | return matches, unmatched_tracks, unmatched_detections
284 |
285 |
286 | def gate_cost_matrix(kf, cost_matrix, tracks, detections, track_indices, detection_indices):
287 | """Invalidate infeasible entries in cost matrix based on the state
288 | distributions obtained by Kalman filtering.
289 | Parameters
290 | ----------
291 | kf : The Kalman filter.
292 | cost_matrix : ndarray
293 | The NxM dimensional cost matrix, where N is the number of track indices
294 | and M is the number of detection indices, such that entry (i, j) is the
295 | association cost between `tracks[track_indices[i]]` and
296 | `detections[detection_indices[j]]`.
297 | tracks : List[track.Track]
298 | A list of predicted tracks at the current time step.
299 | detections : List[detection.Detection]
300 | A list of detections at the current time step.
301 | track_indices : List[int]
302 | List of track indices that maps rows in `cost_matrix` to tracks in
303 | `tracks` (see description above).
304 | detection_indices : List[int]
305 | List of detection indices that maps columns in `cost_matrix` to
306 | detections in `detections` (see description above).
307 | Returns
308 | -------
309 | ndarray
310 | Returns the modified cost matrix.
311 | """
312 | measurements = np.asarray([detections[i].to_xyah() for i in detection_indices])
313 | for row, track_idx in enumerate(track_indices):
314 | track = tracks[track_idx]
315 | gating_distance = kf.gating_distance(track.mean, track.covariance, measurements)
316 | cost_matrix[row, gating_distance > 9.4877] = 1e+5
317 | return cost_matrix
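
The gating threshold 9.4877 used in `gate_cost_matrix` is the 0.95 quantile of the chi-square distribution with 4 degrees of freedom (one per measurement dimension). Below is a small sketch of the two building blocks that feed the cost matrices, assuming the package is importable as `tracking.sort`; the boxes and unit feature vectors are toy values:

import numpy as np
from tracking.sort.matching import iou, NearestNeighborDistanceMetric

box = np.array([0., 0., 10., 10.])                              # (top left x, top left y, w, h)
candidates = np.array([[0., 0., 10., 10.], [5., 5., 10., 10.]])
print(iou(box, candidates))                                     # [1.0, 0.1428...]

metric = NearestNeighborDistanceMetric('cosine', matching_threshold=0.4, budget=100)
feats = np.eye(3, dtype=np.float32)                             # three toy unit feature vectors
metric.partial_fit(feats, targets=np.array([1, 1, 2]), active_targets=[1, 2])
print(metric.distance(feats, targets=[1, 2]))                   # 2x3 matrix of nearest-neighbor cosine distances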
--------------------------------------------------------------------------------
/tracking/sort/track.py:
--------------------------------------------------------------------------------
1 | class TrackState:
2 | """
3 | Enumeration type for the single target track state. Newly created tracks are
4 | classified as `tentative` until enough evidence has been collected. Then,
5 | the track state is changed to `confirmed`. Tracks that are no longer alive
6 | are classified as `deleted` to mark them for removal from the set of active
7 | tracks.
8 | """
9 |
10 | Tentative = 1
11 | Confirmed = 2
12 | Deleted = 3
13 |
14 |
15 | class Track:
16 | """
17 | A single target track with state space `(x, y, a, h)` and associated
18 | velocities, where `(x, y)` is the center of the bounding box, `a` is the
19 | aspect ratio and `h` is the height.
20 | Parameters
21 | ----------
22 | mean : ndarray
23 | Mean vector of the initial state distribution.
24 | covariance : ndarray
25 | Covariance matrix of the initial state distribution.
26 | track_id : int
27 | A unique track identifier.
28 | n_init : int
29 | Number of consecutive detections before the track is confirmed. The
30 | track state is set to `Deleted` if a miss occurs within the first
31 | `n_init` frames.
32 | max_age : int
33 | The maximum number of consecutive misses before the track state is
34 | set to `Deleted`.
35 | feature : Optional[ndarray]
36 | Feature vector of the detection this track originates from. If not None,
37 | this feature is added to the `features` cache.
38 | Attributes
39 | ----------
40 | mean : ndarray
41 | Mean vector of the initial state distribution.
42 | covariance : ndarray
43 | Covariance matrix of the initial state distribution.
44 | track_id : int
45 | A unique track identifier.
46 | hits : int
47 | Total number of measurement updates.
48 | age : int
49 |         Total number of frames since first occurrence.
50 | time_since_update : int
51 | Total number of frames since last measurement update.
52 | state : TrackState
53 | The current track state.
54 | features : List[ndarray]
55 | A cache of features. On each measurement update, the associated feature
56 | vector is added to this list.
57 | """
58 |
59 | def __init__(self, mean, covariance, track_id, n_init, max_age, feature=None, class_id=None):
60 | self.mean = mean
61 | self.covariance = covariance
62 | self.track_id = track_id
63 | self.hits = 1
64 | self.age = 1
65 | self.time_since_update = 0
66 |
67 | self.state = TrackState.Tentative
68 | self.features = []
69 | if feature is not None:
70 | self.features.append(feature)
71 |
72 | self._n_init = n_init
73 | self._max_age = max_age
74 | self.class_id = class_id
75 |
76 | def to_tlwh(self):
77 | """Get current position in bounding box format `(top left x, top left y,
78 | width, height)`.
79 | Returns
80 | -------
81 | ndarray
82 | The bounding box.
83 | """
84 | ret = self.mean[:4].copy()
85 | ret[2] *= ret[3]
86 | ret[:2] -= ret[2:] / 2
87 | return ret
88 |
89 | def to_tlbr(self):
90 |         """Get current position in bounding box format `(min x, min y, max x,
91 | max y)`.
92 | Returns
93 | -------
94 | ndarray
95 | The bounding box.
96 | """
97 | ret = self.to_tlwh()
98 | ret[2:] = ret[:2] + ret[2:]
99 | return ret
100 |
101 | def increment_age(self):
102 | self.age += 1
103 | self.time_since_update += 1
104 |
105 | def predict(self, kf):
106 | """Propagate the state distribution to the current time step using a
107 | Kalman filter prediction step.
108 | Parameters
109 | ----------
110 | kf : kalman_filter.KalmanFilter
111 | The Kalman filter.
112 | """
113 | self.mean, self.covariance = kf.predict(self.mean, self.covariance)
114 | self.increment_age()
115 |
116 | def update(self, kf, detection):
117 | """Perform Kalman filter measurement update step and update the feature
118 | cache.
119 | Parameters
120 | ----------
121 | kf : kalman_filter.KalmanFilter
122 | The Kalman filter.
123 | detection : Detection
124 | The associated detection.
125 | """
126 | self.mean, self.covariance = kf.update(self.mean, self.covariance, detection.to_xyah())
127 | self.features.append(detection.feature)
128 |
129 | self.hits += 1
130 | self.time_since_update = 0
131 | if self.state == TrackState.Tentative and self.hits >= self._n_init:
132 | self.state = TrackState.Confirmed
133 |
134 | def mark_missed(self):
135 | """Mark this track as missed (no association at the current time step).
136 | """
137 | if self.state == TrackState.Tentative:
138 | self.state = TrackState.Deleted
139 | elif self.time_since_update > self._max_age:
140 | self.state = TrackState.Deleted
141 |
142 | def is_tentative(self):
143 | """Returns True if this track is tentative (unconfirmed).
144 | """
145 | return self.state == TrackState.Tentative
146 |
147 | def is_confirmed(self):
148 | """Returns True if this track is confirmed."""
149 | return self.state == TrackState.Confirmed
150 |
151 | def is_deleted(self):
152 | """Returns True if this track is dead and should be deleted."""
153 | return self.state == TrackState.Deleted
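
A short lifecycle sketch for `Track`: with `n_init=3`, a track needs the initial detection plus two more matched updates before it is confirmed. This assumes the package is importable as `tracking.sort`; the box, class id and 128-d feature are placeholders:

import numpy as np
from tracking.sort.kalman_filter import KalmanFilter
from tracking.sort.detection import Detection
from tracking.sort.track import Track

kf = KalmanFilter()
det = Detection([100., 50., 40., 80.], class_id=0, feature=np.zeros(128, dtype=np.float32))
mean, cov = kf.initiate(det.to_xyah())
track = Track(mean, cov, track_id=1, n_init=3, max_age=60, feature=det.feature, class_id=0)
print(track.is_tentative())        # True: only one hit so far

track.predict(kf)
track.update(kf, det)              # hit 2
track.predict(kf)
track.update(kf, det)              # hit 3: hits >= n_init, state flips to Confirmed
print(track.is_confirmed())        # True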
--------------------------------------------------------------------------------
/tracking/sort/tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .detection import Detection
3 | from .kalman_filter import KalmanFilter
4 | from .matching import NearestNeighborDistanceMetric, iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix
5 | from .track import Track
6 |
7 |
8 | class DeepSORTTracker:
9 | """DeepSORT Tracker
10 | Parameters
11 | ----------
12 |     metric_type : str
13 |         The distance metric used for measurement-to-track association ('cosine' or 'euclidean').
14 |     max_age : int
15 |         Maximum number of consecutive misses before a track is deleted.
16 | n_init : int
17 | Number of consecutive detections before the track is confirmed. The
18 | track state is set to `Deleted` if a miss occurs within the first
19 | `n_init` frames.
20 | Attributes
21 | ----------
22 | metric : nn_matching.NearestNeighborDistanceMetric
23 | The distance metric used for measurement to track association.
24 | max_age : int
25 |         Maximum number of consecutive misses before a track is deleted.
26 | n_init : int
27 | Number of frames that a track remains in initialization phase.
28 | kf : kalman_filter.KalmanFilter
29 | A Kalman filter to filter target trajectories in image space.
30 | tracks : List[Track]
31 | The list of active tracks at the current time step.
32 | """
33 |
34 | def __init__(self, metric_type='cosine', max_cosine_distance=0.4, nn_budget=None, max_iou_distance=0.7, max_age=60, n_init=3):
35 | self.metric = NearestNeighborDistanceMetric(metric_type, max_cosine_distance, nn_budget)
36 | self.max_iou_distance = max_iou_distance
37 | self.max_age = max_age
38 | self.n_init = n_init
39 |
40 | self.kf = KalmanFilter()
41 | self.tracks = []
42 | self._next_id = 1
43 |
44 | def reset(self):
45 | self.tracks = []
46 | self._next_id = 1
47 |
48 | def predict(self):
49 | """Propagate track state distributions one time step forward.
50 | This function should be called once every time step, before `update`.
51 | """
52 | for track in self.tracks:
53 | track.predict(self.kf)
54 |
55 | def increment_ages(self):
56 | for track in self.tracks:
57 | track.increment_age()
58 | track.mark_missed()
59 |
60 | def xyxy2xywh(self, boxes):
61 | boxes[:, 2] -= boxes[:, 0]
62 | boxes[:, 3] -= boxes[:, 1]
63 | return boxes
64 |
65 | def update(self, boxes, classes, features):
66 | detections = [
67 | Detection(bbox, class_id, feature)
68 | for bbox, class_id, feature in zip(self.xyxy2xywh(boxes), classes, features)]
69 |
70 | # Run matching cascade.
71 | matches, unmatched_tracks, unmatched_detections = self._match(detections)
72 |
73 | # Update track set.
74 | for track_idx, detection_idx in matches:
75 | self.tracks[track_idx].update(self.kf, detections[detection_idx])
76 | for track_idx in unmatched_tracks:
77 | self.tracks[track_idx].mark_missed()
78 | for detection_idx in unmatched_detections:
79 | self._initiate_track(detections[detection_idx])
80 |
81 | self.tracks = [t for t in self.tracks if not t.is_deleted()]
82 |
83 | # Update distance metric.
84 | features, targets, active_targets = [], [], []
85 | for track in self.tracks:
86 | if not track.is_confirmed():
87 | continue
88 | active_targets.append(track.track_id)
89 | features += track.features
90 | targets += [track.track_id for _ in track.features]
91 | track.features = []
92 | self.metric.partial_fit(np.asarray(features), np.asarray(targets), active_targets)
93 |
94 | def _match(self, detections):
95 |
96 | def gated_metric(tracks, dets, track_indices, detection_indices):
97 | features = np.array([dets[i].feature for i in detection_indices])
98 | targets = np.array([tracks[i].track_id for i in track_indices])
99 | cost_matrix = self.metric.distance(features, targets)
100 | cost_matrix = gate_cost_matrix(self.kf, cost_matrix, tracks, dets, track_indices, detection_indices)
101 | return cost_matrix
102 |
103 | # Split track set into confirmed and unconfirmed tracks.
104 | confirmed_tracks = [i for i, t in enumerate(self.tracks) if t.is_confirmed()]
105 | unconfirmed_tracks = [i for i, t in enumerate(self.tracks) if not t.is_confirmed()]
106 |
107 | # Associate confirmed tracks using appearance features.
108 | matches_a, unmatched_tracks_a, unmatched_detections = matching_cascade(gated_metric, self.metric.matching_threshold, self.max_age, self.tracks, detections, confirmed_tracks)
109 |
110 | # Associate remaining tracks together with unconfirmed tracks using IOU.
111 | iou_track_candidates = unconfirmed_tracks + [k for k in unmatched_tracks_a if self.tracks[k].time_since_update == 1]
112 | unmatched_tracks_a = [k for k in unmatched_tracks_a if self.tracks[k].time_since_update != 1]
113 | matches_b, unmatched_tracks_b, unmatched_detections = min_cost_matching(iou_cost, self.max_iou_distance, self.tracks, detections, iou_track_candidates, unmatched_detections)
114 |
115 | matches = matches_a + matches_b
116 | unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
117 | return matches, unmatched_tracks, unmatched_detections
118 |
119 | def _initiate_track(self, detection):
120 | mean, covariance = self.kf.initiate(detection.to_xyah())
121 | self.tracks.append(Track(mean, covariance, self._next_id, self.n_init, self.max_age, detection.feature, detection.class_id))
122 | self._next_id += 1
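
A per-frame usage sketch for `DeepSORTTracker`; the boxes, class ids and appearance embeddings below stand in for real detector/encoder outputs, and the 512-d feature size is an arbitrary placeholder:

import numpy as np
from tracking.sort.tracker import DeepSORTTracker

tracker = DeepSORTTracker(max_cosine_distance=0.4, max_age=60, n_init=3)

# One frame: detector boxes in (x1, y1, x2, y2), their class ids, and appearance embeddings.
boxes = np.array([[100., 50., 140., 130.], [300., 80., 360., 200.]], dtype=np.float32)
classes = np.array([2, 2])
features = np.random.rand(2, 512).astype(np.float32)

tracker.predict()                         # propagate existing tracks (no-op on the first frame)
tracker.update(boxes, classes, features)  # associate detections, spawn tentative tracks
for t in tracker.tracks:
    print(t.track_id, t.to_tlbr(), t.is_confirmed())

Note that `update` expects `(x1, y1, x2, y2)` boxes and converts them in place to `(top left x, top left y, w, h)` via `xyxy2xywh` before building the `Detection` objects.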
--------------------------------------------------------------------------------
/tracking/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import time
3 | import random
4 | import torch
5 | import os
6 | import urllib.request
7 | import numpy as np
8 | from pathlib import Path
9 | from tqdm import tqdm
10 | from torchvision import ops, io
11 | from threading import Thread
12 | from torch.backends import cudnn
13 | cudnn.benchmark = True
14 | cudnn.deterministic = False
15 |
16 |
17 | def coco_class_index(class_name: str) -> int:
18 | coco_classes = [
19 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
20 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
21 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
22 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
23 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
24 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
25 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
26 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
27 | 'hair drier', 'toothbrush'
28 | ]
29 | assert class_name.lower() in coco_classes, f"Invalid Class Name.\nAvailable COCO classes: {coco_classes}"
30 | return coco_classes.index(class_name.lower())
31 |
32 |
33 | class Colors:
34 | # Ultralytics color palette https://ultralytics.com/
35 | def __init__(self):
36 | # hex = matplotlib.colors.TABLEAU_COLORS.values()
37 | hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
38 | '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
39 | self.palette = [self.hex2rgb('#' + c) for c in hex]
40 | self.n = len(self.palette)
41 |
42 | def __call__(self, i, bgr=False):
43 | c = self.palette[int(i) % self.n]
44 | return (c[2], c[1], c[0]) if bgr else c
45 |
46 | @staticmethod
47 | def hex2rgb(h): # rgb order (PIL)
48 | return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
49 |
50 |
51 | colors = Colors()
52 |
53 |
54 | class WebcamStream:
55 | def __init__(self, src=0) -> None:
56 | self.cap = cv2.VideoCapture(src)
57 | self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)
58 | assert self.cap.isOpened(), f"Failed to open webcam {src}"
59 | _, self.frame = self.cap.read()
60 |         Thread(target=self.update, args=(), daemon=True).start()
61 |
62 | def update(self):
63 | while self.cap.isOpened():
64 | _, self.frame = self.cap.read()
65 |
66 | def __iter__(self):
67 | self.count = -1
68 | return self
69 |
70 | def __next__(self):
71 | self.count += 1
72 |
73 | if cv2.waitKey(1) == ord('q'):
74 | self.stop()
75 |
76 | return self.frame.copy()
77 |
78 | def stop(self):
79 | cv2.destroyAllWindows()
80 | raise StopIteration
81 |
82 | def __len__(self):
83 | return 0
84 |
85 |
86 | class SequenceStream:
87 | def __init__(self, folder):
88 | self.frames = self.read_frames(folder)
89 |
90 | print(f"Processing '{folder}'...")
91 | print(f"Total Frames: {len(self.frames)}")
92 | print(f"Video Size : {self.frames[0].shape[:-1]}")
93 |
94 | def read_frames(self, folder):
95 | files = sorted(list(Path(folder).glob('*.jpg')))
96 | frames = []
97 | for file in files:
98 | img = cv2.imread(str(file))
99 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
100 | frames.append(img)
101 | return frames
102 |
103 | def __iter__(self):
104 | self.count = 0
105 | return self
106 |
107 | def __len__(self):
108 | return len(self.frames)
109 |
110 | def __next__(self):
111 | if self.count == len(self.frames):
112 | raise StopIteration
113 | frame = self.frames[self.count]
114 | self.count += 1
115 | return frame
116 |
117 |
118 | class VideoReader:
119 | def __init__(self, video: str):
120 | self.frames, _, info = io.read_video(video, pts_unit='sec')
121 | self.fps = info['video_fps']
122 |
123 | print(f"Processing '{video}'...")
124 | print(f"Total Frames: {len(self.frames)}")
125 | print(f"Video Size : {list(self.frames.shape[1:-1])}")
126 | print(f"Video FPS : {self.fps}")
127 |
128 | def __iter__(self):
129 | self.count = 0
130 | return self
131 |
132 | def __len__(self):
133 | return len(self.frames)
134 |
135 | def __next__(self):
136 | if self.count == len(self.frames):
137 | raise StopIteration
138 | frame = self.frames[self.count]
139 | self.count += 1
140 | return frame
141 |
142 |
143 | class VideoWriter:
144 | def __init__(self, file_name, fps):
145 | self.fname = file_name
146 | self.fps = fps
147 | self.frames = []
148 |
149 | def update(self, frame):
150 | if isinstance(frame, np.ndarray):
151 | frame = torch.from_numpy(frame)
152 | self.frames.append(frame)
153 |
154 | def write(self):
155 | print(f"Saving video to '{self.fname}'...")
156 | io.write_video(self.fname, torch.stack(self.frames), self.fps)
157 |
158 |
159 | class FPS:
160 | def __init__(self, avg=10) -> None:
161 | self.accum_time = 0
162 | self.counts = 0
163 | self.avg = avg
164 |
165 | def synchronize(self):
166 | if torch.cuda.is_available():
167 | torch.cuda.synchronize()
168 |
169 | def start(self):
170 | self.synchronize()
171 | self.prev_time = time.time()
172 |
173 | def stop(self, debug=True):
174 | self.synchronize()
175 | self.accum_time += time.time() - self.prev_time
176 | self.counts += 1
177 | if self.counts == self.avg:
178 | self.fps = round(self.counts / self.accum_time)
179 | if debug: print(f"FPS: {self.fps}")
180 | self.counts = 0
181 | self.accum_time = 0
182 |
183 |
184 | def plot_one_box(box, img, color=None, label=None):
185 | color = color or [random.randint(0, 255) for _ in range(3)]
186 | p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
187 | cv2.rectangle(img, p1, p2, color, 2, lineType=cv2.LINE_AA)
188 |
189 | if label:
190 | t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0]
191 | p2 = p1[0] + t_size[0], p1[1] - t_size[1] - 3
192 | cv2.rectangle(img, p1, p2, color, -1, cv2.LINE_AA)
193 | cv2.putText(img, label, (p1[0], p1[1]-2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)
194 |
195 |
196 | def letterbox(img, new_shape=(640, 640)):
197 | H, W = img.shape[:2]
198 | if isinstance(new_shape, int):
199 | new_shape = (new_shape, new_shape)
200 |
201 | r = min(new_shape[0] / H, new_shape[1] / W)
202 | nH, nW = round(H * r), round(W * r)
203 | pH, pW = np.mod(new_shape[0] - nH, 32) / 2, np.mod(new_shape[1] - nW, 32) / 2
204 |
205 | if (H, W) != (nH, nW):
206 | img = cv2.resize(img, (nW, nH), interpolation=cv2.INTER_LINEAR)
207 |
208 | top, bottom = round(pH - 0.1), round(pH + 0.1)
209 | left, right = round(pW - 0.1), round(pW + 0.1)
210 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
211 | return img
212 |
213 |
214 | def scale_boxes(boxes, orig_shape, new_shape):
215 | H, W = orig_shape
216 | nH, nW = new_shape
217 | gain = min(nH / H, nW / W)
218 | pad = (nH - H * gain) / 2, (nW - W * gain) / 2
219 |
220 | boxes[:, ::2] -= pad[1]
221 | boxes[:, 1::2] -= pad[0]
222 | boxes[:, :4] /= gain
223 |
224 | boxes[:, ::2].clamp_(0, orig_shape[1])
225 | boxes[:, 1::2].clamp_(0, orig_shape[0])
226 | return boxes.round()
227 |
228 |
229 | def xywh2xyxy(x):
230 | boxes = x.clone()
231 | boxes[:, 0] = x[:, 0] - x[:, 2] / 2
232 | boxes[:, 1] = x[:, 1] - x[:, 3] / 2
233 | boxes[:, 2] = x[:, 0] + x[:, 2] / 2
234 | boxes[:, 3] = x[:, 1] + x[:, 3] / 2
235 | return boxes
236 |
237 |
238 | def non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None):
239 | candidates = pred[..., 4] > conf_thres
240 |
241 | max_wh = 4096
242 | max_nms = 30000
243 | max_det = 300
244 |
245 | output = [torch.zeros((0, 6), device=pred.device)] * pred.shape[0]
246 |
247 | for xi, x in enumerate(pred):
248 | x = x[candidates[xi]]
249 |
250 | if not x.shape[0]: continue
251 |
252 | # compute conf
253 | x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
254 |
255 | # box
256 | box = xywh2xyxy(x[:, :4])
257 |
258 | # detection matrix nx6
259 | conf, j = x[:, 5:].max(1, keepdim=True)
260 | x = torch.cat([box, conf, j.float()], dim=1)[conf.view(-1) > conf_thres]
261 |
262 | # filter by class
263 | if classes is not None:
264 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
265 |
266 | # check shape
267 | n = x.shape[0]
268 | if not n:
269 | continue
270 | elif n > max_nms:
271 | x = x[x[:, 4].argsort(descending=True)[:max_nms]]
272 |
273 | # batched nms
274 | c = x[:, 5:6] * max_wh
275 | boxes, scores = x[:, :4] + c, x[:, 4]
276 | keep = ops.nms(boxes, scores, iou_thres)
277 |
278 | if keep.shape[0] > max_det:
279 | keep = keep[:max_det]
280 |
281 | output[xi] = x[keep]
282 |
283 | return output
284 |
285 |
286 | def download(url: str, root: str):
287 | os.makedirs(root, exist_ok=True)
288 | filename = os.path.basename(url)
289 | download_target = os.path.join(root, filename)
290 |
291 | if os.path.exists(download_target) and os.path.isfile(download_target):
292 | return download_target
293 |
294 | print(f"Downloading model from {url}")
295 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
296 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
297 | while True:
298 | buffer = source.read(8192)
299 | if not buffer:
300 | break
301 |
302 | output.write(buffer)
303 | loop.update(len(buffer))
304 |
305 | return download_target
306 |
307 |
308 |
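
A minimal end-to-end sketch of the pre/post-processing helpers above, assuming the repository root is on PYTHONPATH; the random image and the random prediction tensor are placeholders for a real frame and a YOLO-style detector output of shape (batch, anchors, 5 + num_classes):

import numpy as np
import torch
from tracking.utils import letterbox, non_max_suppression, scale_boxes, coco_class_index

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)   # stand-in frame (H, W, 3)
padded = letterbox(img, new_shape=640)                           # resize + pad to a /32-friendly shape

pred = torch.rand(1, 1000, 85)                                   # stand-in detector output (xywh, obj, 80 classes)
dets = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45,
                           classes=[coco_class_index('car')])[0] # keep only 'car' detections
if dets.shape[0]:
    boxes = scale_boxes(dets[:, :4], img.shape[:2], padded.shape[:2])  # map boxes back to the original frame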
--------------------------------------------------------------------------------