├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── eval_mot.py ├── gifs ├── cars_out.gif ├── newyork_out.gif └── test_out.gif ├── requirements.txt ├── track.py └── tracking ├── __init__.py ├── clip ├── __init__.py ├── clip.py └── model.py ├── dino ├── __init__.py ├── dino.py ├── vit.py └── xcit.py ├── sort ├── __init__.py ├── detection.py ├── kalman_filter.py ├── matching.py ├── track.py └── tracker.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- 2 | *.jpg 3 | *.jpeg 4 | *.png 5 | *.bmp 6 | *.tif 7 | *.tiff 8 | *.heic 9 | *.JPG 10 | *.JPEG 11 | *.PNG 12 | *.BMP 13 | *.TIF 14 | *.TIFF 15 | *.HEIC 16 | *.mp4 17 | *.mov 18 | *.MOV 19 | *.avi 20 | *.data 21 | *.json 22 | 23 | *.cfg 24 | !cfg/yolov3*.cfg 25 | */tracktor/* 26 | storage.googleapis.com 27 | runs/* 28 | data/* 29 | !data/images/zidane.jpg 30 | !data/images/bus.jpg 31 | !data/coco.names 32 | !data/coco_paper.names 33 | !data/coco.data 34 | !data/coco_*.data 35 | !data/coco_*.txt 36 | !data/trainvalno5k.shapes 37 | !data/*.sh 38 | 39 | test.py 40 | test_imgs/ 41 | 42 | pycocotools/* 43 | results*.txt 44 | gcp_test*.sh 45 | 46 | checkpoints/ 47 | output/ 48 | assests/*/ 49 | 50 | # Datasets ------------------------------------------------------------------------------------------------------------- 51 | coco/ 52 | coco128/ 53 | VOC/ 54 | 55 | # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- 56 | *.m~ 57 | *.mat 58 | !targets*.mat 59 | 60 | # Neural Network weights ----------------------------------------------------------------------------------------------- 61 | *.weights 62 | *.pt 63 | *.onnx 64 | *.mlmodel 65 | *.torchscript 66 | darknet53.conv.74 67 | yolov3-tiny.conv.15 68 | 69 | # GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- 70 | # Byte-compiled / optimized / DLL files 71 | __pycache__/ 72 | *.py[cod] 73 | *$py.class 74 | 75 | # C extensions 76 | *.so 77 | 78 | # Distribution / packaging 79 | .Python 80 | env/ 81 | build/ 82 | develop-eggs/ 83 | dist/ 84 | downloads/ 85 | eggs/ 86 | .eggs/ 87 | lib/ 88 | lib64/ 89 | parts/ 90 | sdist/ 91 | var/ 92 | wheels/ 93 | *.egg-info/ 94 | wandb/ 95 | .installed.cfg 96 | *.egg 97 | 98 | 99 | # PyInstaller 100 | # Usually these files are written by a python script from a template 101 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
102 | *.manifest 103 | *.spec 104 | 105 | # Installer logs 106 | pip-log.txt 107 | pip-delete-this-directory.txt 108 | 109 | # Unit test / coverage reports 110 | htmlcov/ 111 | .tox/ 112 | .coverage 113 | .coverage.* 114 | .cache 115 | nosetests.xml 116 | coverage.xml 117 | *.cover 118 | .hypothesis/ 119 | 120 | # Translations 121 | *.mo 122 | *.pot 123 | 124 | # Django stuff: 125 | *.log 126 | local_settings.py 127 | 128 | # Flask stuff: 129 | instance/ 130 | .webassets-cache 131 | 132 | # Scrapy stuff: 133 | .scrapy 134 | 135 | # Sphinx documentation 136 | docs/_build/ 137 | 138 | # PyBuilder 139 | target/ 140 | 141 | # Jupyter Notebook 142 | .ipynb_checkpoints 143 | 144 | # pyenv 145 | .python-version 146 | 147 | # celery beat schedule file 148 | celerybeat-schedule 149 | 150 | # SageMath parsed files 151 | *.sage.py 152 | 153 | # dotenv 154 | .env 155 | 156 | # virtualenv 157 | .venv* 158 | venv*/ 159 | ENV*/ 160 | 161 | # Spyder project settings 162 | .spyderproject 163 | .spyproject 164 | 165 | # Rope project settings 166 | .ropeproject 167 | 168 | # mkdocs documentation 169 | /site 170 | 171 | # mypy 172 | .mypy_cache/ 173 | 174 | 175 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- 176 | 177 | # General 178 | .DS_Store 179 | .AppleDouble 180 | .LSOverride 181 | 182 | # Icon must end with two \r 183 | Icon 184 | Icon? 185 | 186 | # Thumbnails 187 | ._* 188 | 189 | # Files that might appear in the root of a volume 190 | .DocumentRevisions-V100 191 | .fseventsd 192 | .Spotlight-V100 193 | .TemporaryItems 194 | .Trashes 195 | .VolumeIcon.icns 196 | .com.apple.timemachine.donotpresent 197 | 198 | # Directories potentially created on remote AFP share 199 | .AppleDB 200 | .AppleDesktop 201 | Network Trash Folder 202 | Temporary Items 203 | .apdisk 204 | 205 | 206 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 207 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 208 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 209 | 210 | # User-specific stuff: 211 | .idea/* 212 | .idea/**/workspace.xml 213 | .idea/**/tasks.xml 214 | .idea/dictionaries 215 | .html # Bokeh Plots 216 | .pg # TensorFlow Frozen Graphs 217 | .avi # videos 218 | 219 | # Sensitive or high-churn files: 220 | .idea/**/dataSources/ 221 | .idea/**/dataSources.ids 222 | .idea/**/dataSources.local.xml 223 | .idea/**/sqlDataSources.xml 224 | .idea/**/dynamic.xml 225 | .idea/**/uiDesigner.xml 226 | 227 | # Gradle: 228 | .idea/**/gradle.xml 229 | .idea/**/libraries 230 | 231 | # CMake 232 | cmake-build-debug/ 233 | cmake-build-release/ 234 | 235 | # Mongo Explorer plugin: 236 | .idea/**/mongoSettings.xml 237 | 238 | ## File-based project format: 239 | *.iws 240 | 241 | ## Plugin-specific files: 242 | 243 | # IntelliJ 244 | out/ 245 | 246 | # mpeltonen/sbt-idea plugin 247 | .idea_modules/ 248 | 249 | # JIRA plugin 250 | atlassian-ide-plugin.xml 251 | 252 | # Cursive Clojure plugin 253 | .idea/replstate.xml 254 | 255 | # Crashlytics plugin (for Android Studio and IntelliJ) 256 | com_crashlytics_export_strings.xml 257 | crashlytics.properties 258 | crashlytics-build.properties 259 | fabric.properties 260 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "yolov5"] 2 | path = yolov5 3 | url = 
https://github.com/ultralytics/yolov5 4 | [submodule "TrackEval"] 5 | path = TrackEval 6 | url = https://github.com/JonathonLuiten/TrackEval 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 sithu3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
Simple Object Tracking
2 |
3 |
4 | Multi-Object Tracking with YOLOv5, CLIP, DINO and DeepSORT
5 |
6 |
7 |
8 |
9 |
10 | ## Introduction
11 |
12 | This is a simple two-stage multi-object tracker that combines [YOLOv5](https://github.com/ultralytics/yolov5) and [DeepSORT](https://arxiv.org/abs/1703.07402) with zero-shot or self-supervised feature extractors.
13 |
14 | Normally, in DeepSORT, the deep part of the model is trained on a person re-identification dataset like [Market1501](https://www.kaggle.com/pengcw1/market-1501/data). We replace this model with a zero-shot or self-supervised model, which makes the tracker ready to track any class without re-training.
15 |
16 | SOTA models like [CLIP](https://arxiv.org/abs/2103.00020) (zero-shot) and [DINO](https://arxiv.org/abs/2104.14294v2) (SSL) are currently being experimented with. If better models come out, I will consider adding them.
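Because the extractor is pluggable, swapping CLIP for DINO is a one-line change. The sketch below shows how a detected-object crop is turned into an appearance embedding via `load_feature_extractor` from `tracking/__init__.py`; the model name `'CLIP-RN50'` and the dummy crop are only placeholders, and the pretrained weights are downloaded to `~/.cache/tracking` on first use.

```python
import numpy as np
import torch
from PIL import Image
from tracking import load_feature_extractor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# any key from tracking/__init__.py works here, e.g. 'CLIP-RN50' or 'DINO-ViT-S/16'
model, transform = load_feature_extractor('CLIP-RN50', device)

# stand-in for a detected-object crop, i.e. frame[y1:y2, x1:x2] in RGB
crop = np.zeros((128, 64, 3), dtype=np.uint8)

patch = transform(Image.fromarray(crop)).unsqueeze(0).to(device)
with torch.no_grad():
    feature = model.encode_image(patch)  # appearance embedding handed to DeepSORT
print(feature.shape)
```

Inside `track.py`, `Tracking.extract_features` does exactly this for every detected box before passing the embeddings to the DeepSORT tracker.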
17 |
18 | ## Requirements
19 |
20 | * torch >= 1.8.1
21 | * torchvision >= 0.9.1
22 |
23 | Other requirements can be installed with `pip install -r requirements.txt`.
24 |
25 | Clone the repository recursively:
26 |
27 | ```bash
28 | $ git clone --recursive https://github.com/sithu31296/simple-object-tracking.git
29 | ```
30 |
31 | Then download a YOLO model's weights from [YOLOv5](https://github.com/ultralytics/yolov5) and place them in `checkpoints`.
32 |
33 | ## Tracking
34 |
35 | Track all classes:
36 |
37 | ```bash
38 | ## webcam
39 | $ python track.py --source 0 --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50
40 |
41 | ## video
42 | $ python track.py --source VIDEO_PATH --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50
43 | ```
44 |
45 | Track only specified classes:
46 |
47 | ```bash
48 | ## track only person class
49 | $ python track.py --source 0 --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50 --filter-class 0
50 |
51 | ## track person and car classes
52 | $ python track.py --source 0 --yolo-model checkpoints/yolov5s.pt --reid-model CLIP-RN50 --filter-class 0 2
53 | ```
54 |
55 | Available ReID models (Feature Extractors):
56 | * **CLIP**: `CLIP-RN50`, `CLIP-ViT-B/32`
57 | * **DINO**: `DINO-XciT-S12/16`, `DINO-XciT-M24/16`, `DINO-ViT-S/16`, `DINO-ViT-B/16`
58 |
59 | Check [here](tracking/utils.py#L14) to get the COCO class index for your class.
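The tracker can also be driven from Python instead of the CLI. A minimal sketch, assuming you run it from the repository root (so the `yolov5` submodule is importable), that the YOLOv5 weights are already in `checkpoints`, and that `input.mp4` stands in for your own video:

```python
import cv2
from track import Tracking

# same knobs as the CLI flags above; filter_class=[0] keeps only COCO class 0 (person)
tracker = Tracking(
    yolo_model='checkpoints/yolov5s.pt',
    reid_model='CLIP-RN50',
    filter_class=[0],
)

cap = cv2.VideoCapture('input.mp4')  # placeholder input video
while True:
    ok, frame = cap.read()
    if not ok:
        break
    # track.py works on RGB frames; predict() draws the tracks onto the frame it is given
    annotated = tracker.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cv2.imshow('tracks', cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cap.release()
```

For offline processing, the `__main__` block of `track.py` wraps the same loop with the repo's `VideoReader`/`VideoWriter` helpers and an FPS counter.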
60 |
61 | ## Evaluate on MOT16
62 |
63 | * Download the MOT16 dataset from [here](https://motchallenge.net/data/MOT16.zip) and unzip it.
64 | * Download the mot-challenge ground-truth [data](https://omnomnom.vision.rwth-aachen.de/data/TrackEval/data.zip) for evaluating with TrackEval. Then, unzip it under the project directory.
65 | * Save the tracking results of MOT16 with the following command:
66 |
67 | ```bash
68 | $ python eval_mot.py --root MOT16_ROOT_DIR --yolo-model checkpoints/yolov5m.pt --reid-model CLIP-RN50
69 | ```
70 |
71 | * Evaluate with TrackEval:
72 |
73 | ```bash
74 | $ python TrackEval/scripts/run_mot_challenge.py \
75 | --BENCHMARK MOT16 \
76 | --GT_FOLDER PROJECT_ROOT/data/gt/mot_challenge/ \
77 | --TRACKERS_FOLDER PROJECT_ROOT/data/trackers/mot_challenge/ \
78 | --TRACKERS_TO_EVAL mot_det \
79 | --SPLIT_TO_EVAL train \
80 | --USE_PARALLEL True \
81 | --NUM_PARALLEL_CORES 4 \
82 | --PRINT_ONLY_COMBINED True \
83 | ```
84 |
85 | > Note: the `FOLDER` parameters in `run_mot_challenge.py` must be absolute paths.
86 |
87 | For tracking persons, instead of using a COCO-pretrained model, a model trained on a multi-person dataset will give better accuracy. You can download a YOLOv5m model trained on the [CrowdHuman](https://www.crowdhuman.org/) dataset from [here](https://drive.google.com/file/d/1gglIwqxaH2iTvy6lZlXuAcMpd_U0GCUb/view?usp=sharing). The weights are from [deepakcrk/yolov5-crowdhuman](https://github.com/deepakcrk/yolov5-crowdhuman). It has two classes, 'person' and 'head', so you can use this model for both person and head tracking.
88 |
89 | ## Results
90 |
91 | **MOT16 Evaluation Results**
92 |
93 | Detector | Feature Extractor | MOTA↑ | HOTA↑ | IDF1↑ | IDsw↓ | MT↑ | ML↓ | FP↓ | FN↓ | FPS (GTX1660ti)
94 | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | ---
95 | YOLOv5m (COCO) | CLIP (RN50) | 35.42 | 35.37 | 39.42 | **486** | 115 | 192 | **6880** | 63931 | 7
96 | YOLOv5m (CrowdHuman) | CLIP (RN50) | 53.25 | 43.25 | 52.12 | 912 | 196 | **89** | 14076 | 36625 | 6
97 | YOLOv5m (CrowdHuman) | CLIP (ViT-B/32) | 53.35 | 43.03 | 51.25 | 896 | **199** | 91 | 14035 | **36575** | 4
98 | ||
99 | YOLOv5m (CrowdHuman) | DINO (XciT-S12/16) | 54.41 | 47.44 | 59.01 | 511 | 184 | 101 | 12265 | 37555 | 8
100 | YOLOv5m (CrowdHuman) | DINO (ViT-S/16) | 54.56 | 47.61 | 58.94 | 519 | 189 | 97 | 12346 | 37308 | 8
101 | YOLOv5m (CrowdHuman) | DINO (XciT-M24/16) | 54.56 | **47.71** | **59.77** | 504 | 187 | 96 | 12364 | 37306 | 5
102 | YOLOv5m (CrowdHuman) | DINO (ViT-B/16) | **54.58** | 47.55 | 58.89 | 507 | 184 | 97 | 12017 | 37621 | 5
103 |
104 | **FPS Results**
105 |
106 | Detector | Feature Extractor | GPU | Precision | Image Size | Detection
/Frame | FPS 107 | --- | --- | --- | --- | --- | --- | --- 108 | YOLOv5s | CLIP-RN50 | GTX-1660ti | FP32 | 480x640 | 1 | 38 109 | YOLOv5s | CLIP-ViT-B/32 | GTX-1660ti | FP32 | 480x640 | 1 | 30 110 | || 111 | YOLOv5s | DINO-XciT-S12/16 | GTX-1660ti | FP32 | 480x640 | 1 | 36 112 | YOLOv5s | DINO-ViT-B/16 | GTX-1660ti | FP32 | 480x640 | 1 | 30 113 | YOLOv5s | DINO-XciT-M24/16 | GTX-1660ti | FP32 | 480x640 | 1 | 25 114 | 115 | 116 | ## References 117 | 118 | * https://github.com/ultralytics/yolov5 119 | * https://github.com/JonathonLuiten/TrackEval 120 | 121 | ## Citations 122 | 123 | ``` 124 | @inproceedings{caron2021emerging, 125 | title={Emerging Properties in Self-Supervised Vision Transformers}, 126 | author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand}, 127 | booktitle={Proceedings of the International Conference on Computer Vision (ICCV)}, 128 | year={2021} 129 | } 130 | 131 | @article{el2021xcit, 132 | title={XCiT: Cross-Covariance Image Transformers}, 133 | author={El-Nouby, Alaaeldin and Touvron, Hugo and Caron, Mathilde and Bojanowski, Piotr and Douze, Matthijs and Joulin, Armand and Laptev, Ivan and Neverova, Natalia and Synnaeve, Gabriel and Verbeek, Jakob and others}, 134 | journal={arXiv preprint arXiv:2106.09681}, 135 | year={2021} 136 | } 137 | 138 | @misc{radford2021learning, 139 | title={Learning Transferable Visual Models From Natural Language Supervision}, 140 | author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever}, 141 | year={2021}, 142 | eprint={2103.00020}, 143 | archivePrefix={arXiv}, 144 | primaryClass={cs.CV} 145 | } 146 | 147 | @inproceedings{Wojke2017simple, 148 | title={Simple Online and Realtime Tracking with a Deep Association Metric}, 149 | author={Wojke, Nicolai and Bewley, Alex and Paulus, Dietrich}, 150 | booktitle={2017 IEEE International Conference on Image Processing (ICIP)}, 151 | year={2017}, 152 | pages={3645--3649}, 153 | organization={IEEE}, 154 | doi={10.1109/ICIP.2017.8296962} 155 | } 156 | 157 | @inproceedings{Wojke2018deep, 158 | title={Deep Cosine Metric Learning for Person Re-identification}, 159 | author={Wojke, Nicolai and Bewley, Alex}, 160 | booktitle={2018 IEEE Winter Conference on Applications of Computer Vision (WACV)}, 161 | year={2018}, 162 | pages={748--756}, 163 | organization={IEEE}, 164 | doi={10.1109/WACV.2018.00087} 165 | } 166 | ``` -------------------------------------------------------------------------------- /eval_mot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import shutil 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | from tracking.utils import * 7 | 8 | from track import Tracking 9 | 10 | 11 | class EvalTracking(Tracking): 12 | def __init__(self, yolo_model, reid_model, img_size, filter_class, conf_thres, iou_thres, max_cosine_dist, max_iou_dist, nn_budget, max_age, n_init) -> None: 13 | super().__init__(yolo_model, reid_model, img_size=img_size, filter_class=filter_class, conf_thres=conf_thres, iou_thres=iou_thres, max_cosine_dist=max_cosine_dist, max_iou_dist=max_iou_dist, nn_budget=nn_budget, max_age=max_age, n_init=n_init) 14 | 15 | def postprocess(self, pred, img1, img0, txt_path, frame_idx): 16 | pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, 
classes=self.filter_class) 17 | 18 | for det in pred: 19 | if len(det): 20 | boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu() 21 | features = self.extract_features(boxes, img0) 22 | 23 | self.tracker.predict() 24 | self.tracker.update(boxes, det[:, 5], features) 25 | 26 | for track in self.tracker.tracks: 27 | if not track.is_confirmed() or track.time_since_update > 1: continue 28 | 29 | x1, y1, x2, y2 = track.to_tlbr() 30 | w, h = x2 - x1, y2 - y1 31 | 32 | with open(txt_path, 'a') as f: 33 | f.write(f"{frame_idx+1},{track.track_id},{x1:.4f},{y1:.4f},{w:.4f},{h:.4f},-1,-1,-1,-1\n") 34 | else: 35 | self.tracker.increment_ages() 36 | 37 | @torch.no_grad() 38 | def predict(self, image, txt_path, frame_idx): 39 | img = self.preprocess(image) 40 | pred = self.model(img)[0] 41 | self.postprocess(pred, img, image, txt_path, frame_idx) 42 | 43 | 44 | def argument_parser(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--root', type=str, default='/home/sithu/datasets/MOT16') 47 | parser.add_argument('--yolo-model', type=str, default='checkpoints/crowdhuman_yolov5m.pt') 48 | parser.add_argument('--reid-model', type=str, default='CLIP-RN50') 49 | parser.add_argument('--img-size', type=int, default=640) 50 | parser.add_argument('--filter-class', nargs='+', type=int, default=0) 51 | parser.add_argument('--conf-thres', type=float, default=0.4) 52 | parser.add_argument('--iou-thres', type=float, default=0.5) 53 | parser.add_argument('--max-cosine-dist', type=float, default=0.2) 54 | parser.add_argument('--max-iou-dist', type=int, default=0.7) 55 | parser.add_argument('--nn-budget', type=int, default=100) 56 | parser.add_argument('--max-age', type=int, default=70) 57 | parser.add_argument('--n-init', type=int, default=3) 58 | return parser.parse_args() 59 | 60 | 61 | if __name__ == '__main__': 62 | args = argument_parser() 63 | tracking = EvalTracking( 64 | args.yolo_model, 65 | args.reid_model, 66 | args.img_size, 67 | args.filter_class, 68 | args.conf_thres, 69 | args.iou_thres, 70 | args.max_cosine_dist, 71 | args.max_iou_dist, 72 | args.nn_budget, 73 | args.max_age, 74 | args.n_init 75 | ) 76 | 77 | save_path = Path('data') / 'trackers' / 'mot_challenge' / 'MOT16-train' / 'mot_det' / 'data' 78 | if save_path.exists(): 79 | shutil.rmtree(save_path) 80 | save_path.mkdir(parents=True) 81 | 82 | root = Path(args.root) / 'train' 83 | folders = root.iterdir() 84 | 85 | total_fps = [] 86 | 87 | for folder in folders: 88 | tracking.tracker.reset() 89 | reader = SequenceStream(folder / 'img1') 90 | txt_path = save_path / f"{folder.stem}.txt" 91 | fps = FPS(len(reader.frames)) 92 | 93 | for i, frame in tqdm(enumerate(reader), total=len(reader)): 94 | fps.start() 95 | tracking.predict(frame, txt_path, i) 96 | fps.stop(False) 97 | 98 | print(f"FPS: {fps.fps}") 99 | total_fps.append(fps.fps) 100 | del reader 101 | 102 | print(f"Average FPS for MOT16: {round(sum(total_fps) / len(total_fps))}") -------------------------------------------------------------------------------- /gifs/cars_out.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/gifs/cars_out.gif -------------------------------------------------------------------------------- /gifs/newyork_out.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/gifs/newyork_out.gif -------------------------------------------------------------------------------- /gifs/test_out.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/gifs/test_out.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | seaborn 2 | ftfy 3 | regex 4 | matplotlib 5 | numpy 6 | opencv-python 7 | scipy 8 | tqdm 9 | -------------------------------------------------------------------------------- /track.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from tracking import load_feature_extractor 8 | from tracking.sort.tracker import DeepSORTTracker 9 | from tracking.utils import * 10 | 11 | import sys 12 | sys.path.insert(0, 'yolov5') 13 | from yolov5.models.experimental import attempt_load 14 | 15 | 16 | 17 | class Tracking: 18 | def __init__(self, 19 | yolo_model, 20 | reid_model, 21 | img_size=640, 22 | filter_class=None, 23 | conf_thres=0.25, 24 | iou_thres=0.45, 25 | max_cosine_dist=0.4, # the higher the value, the easier it is to assume it is the same person 26 | max_iou_dist=0.7, # how much bboxes should overlap to determine the identity of the unassigned track 27 | nn_budget=None, # indicates how many previous frames of features vectors should be retained for distance calc for ecah track 28 | max_age=60, # specifies after how many frames unallocated tracks will be deleted 29 | n_init=3 # specifies after how many frames newly allocated tracks will be activated 30 | ) -> None: 31 | self.img_size = img_size 32 | self.conf_thres = conf_thres 33 | self.iou_thres = iou_thres 34 | self.filter_class = filter_class 35 | 36 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | self.model = attempt_load(yolo_model, map_location=self.device) 38 | self.model = self.model.to(self.device) 39 | self.names = self.model.names 40 | 41 | self.patch_model, self.patch_transform = load_feature_extractor(reid_model, self.device) 42 | self.tracker = DeepSORTTracker('cosine', max_cosine_dist, nn_budget, max_iou_dist, max_age, n_init) 43 | 44 | 45 | def preprocess(self, image): 46 | img = letterbox(image, new_shape=self.img_size) 47 | img = np.ascontiguousarray(img.transpose((2, 0, 1))) 48 | img = torch.from_numpy(img).to(self.device) 49 | img = img.float() / 255.0 50 | img = img[None] 51 | return img 52 | 53 | 54 | def extract_features(self, boxes, img): 55 | image_patches = [] 56 | for xyxy in boxes: 57 | x1, y1, x2, y2 = map(int, xyxy) 58 | img_patch = Image.fromarray(img[y1:y2, x1:x2]) 59 | img_patch = self.patch_transform(img_patch) 60 | image_patches.append(img_patch) 61 | 62 | image_patches = torch.stack(image_patches).to(self.device) 63 | features = self.patch_model.encode_image(image_patches).cpu().numpy() 64 | return features 65 | 66 | 67 | def postprocess(self, pred, img1, img0): 68 | pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=self.filter_class) 69 | 70 | for det in pred: 71 | if len(det): 72 | boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu() 73 | features = 
self.extract_features(boxes, img0) 74 | 75 | self.tracker.predict() 76 | self.tracker.update(boxes, det[:, 5], features) 77 | 78 | for track in self.tracker.tracks: 79 | if not track.is_confirmed() or track.time_since_update > 1: continue 80 | label = f"{self.names[int(track.class_id)]} #{track.track_id}" 81 | plot_one_box(track.to_tlbr(), img0, color=colors(int(track.class_id)), label=label) 82 | else: 83 | self.tracker.increment_ages() 84 | 85 | 86 | @torch.no_grad() 87 | def predict(self, image): 88 | img = self.preprocess(image) 89 | pred = self.model(img)[0] 90 | self.postprocess(pred, img, image) 91 | return image 92 | 93 | 94 | def argument_parser(): 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('--source', type=str, default='0') 97 | parser.add_argument('--yolo-model', type=str, default='checkpoints/yolov5s.pt') 98 | parser.add_argument('--reid-model', type=str, default='CLIP-RN50') 99 | parser.add_argument('--img-size', type=int, default=640) 100 | parser.add_argument('--filter-class', nargs='+', type=int, default=None) 101 | parser.add_argument('--conf-thres', type=float, default=0.4) 102 | parser.add_argument('--iou-thres', type=float, default=0.5) 103 | parser.add_argument('--max-cosine-dist', type=float, default=0.2) 104 | parser.add_argument('--max-iou-dist', type=int, default=0.7) 105 | parser.add_argument('--nn-budget', type=int, default=100) 106 | parser.add_argument('--max-age', type=int, default=70) 107 | parser.add_argument('--n-init', type=int, default=3) 108 | return parser.parse_args() 109 | 110 | 111 | if __name__ == '__main__': 112 | args = argument_parser() 113 | tracking = Tracking( 114 | args.yolo_model, 115 | args.reid_model, 116 | args.img_size, 117 | args.filter_class, 118 | args.conf_thres, 119 | args.iou_thres, 120 | args.max_cosine_dist, 121 | args.max_iou_dist, 122 | args.nn_budget, 123 | args.max_age, 124 | args.n_init 125 | ) 126 | 127 | if args.source.isnumeric(): 128 | webcam = WebcamStream() 129 | fps = FPS() 130 | 131 | for frame in webcam: 132 | fps.start() 133 | output = tracking.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 134 | fps.stop() 135 | cv2.imshow('frame', cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) 136 | 137 | else: 138 | reader = VideoReader(args.source) 139 | writer = VideoWriter(f"{args.source.rsplit('.', maxsplit=1)[0]}_out.mp4", reader.fps) 140 | fps = FPS(len(reader.frames)) 141 | 142 | for frame in tqdm(reader): 143 | fps.start() 144 | output = tracking.predict(frame.numpy()) 145 | fps.stop(False) 146 | writer.update(output) 147 | 148 | print(f"FPS: {fps.fps}") 149 | writer.write() 150 | -------------------------------------------------------------------------------- /tracking/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import download 3 | from .clip import load as clip_load 4 | from .dino import load as dino_load 5 | 6 | 7 | _MODELS = { 8 | "CLIP-RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 9 | "CLIP-ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 10 | "DINO-XciT-S12/16": "https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth", 11 | "DINO-XciT-M24/16": "https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth", 12 | "DINO-ViT-S/16": 
"https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", 13 | "DINO-ViT-B/16": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 14 | } 15 | 16 | 17 | def load_feature_extractor(model_name: str, device): 18 | assert model_name in _MODELS 19 | model_path = download(_MODELS[model_name], os.path.expanduser("~/.cache/tracking")) 20 | 21 | if model_name.startswith('CLIP'): 22 | model, transform = clip_load(model_path, device, jit=False) 23 | elif model_name.startswith('DINO'): 24 | model, transform = dino_load(model_name, model_path, device) 25 | return model, transform 26 | 27 | -------------------------------------------------------------------------------- /tracking/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * -------------------------------------------------------------------------------- /tracking/clip/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms as T 3 | 4 | from .model import build_model 5 | 6 | 7 | __all__ = ["load"] 8 | 9 | 10 | def _transform(n_px): 11 | return T.Compose([ 12 | T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC), 13 | T.CenterCrop(n_px), 14 | T.ToTensor(), 15 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 16 | ]) 17 | 18 | 19 | def load(model_path: str, device, jit=False): 20 | # loading JIT archive 21 | model = torch.jit.load(model_path, map_location="cpu").eval() 22 | 23 | if not jit: 24 | model = build_model(model.state_dict()).to(device) 25 | if str(device) == "cpu": 26 | model.float() 27 | return model, _transform(model.visual.input_resolution) 28 | 29 | # patch the device names 30 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 31 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 32 | 33 | def patch_device(module): 34 | try: 35 | graphs = [module.graph] if hasattr(module, "graph") else [] 36 | except RuntimeError: 37 | graphs = [] 38 | 39 | if hasattr(module, "forward1"): 40 | graphs.append(module.forward1.graph) 41 | 42 | for graph in graphs: 43 | for node in graph.findAllNodes("prim::Constant"): 44 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 45 | node.copyAttributes(device_node) 46 | 47 | model.apply(patch_device) 48 | patch_device(model.encode_image) 49 | patch_device(model.encode_text) 50 | 51 | # patch dtype to float32 on CPU 52 | if str(device) == "cpu": 53 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 54 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 55 | float_node = float_input.node() 56 | 57 | def patch_float(module): 58 | try: 59 | graphs = [module.graph] if hasattr(module, "graph") else [] 60 | except RuntimeError: 61 | graphs = [] 62 | 63 | if hasattr(module, "forward1"): 64 | graphs.append(module.forward1.graph) 65 | 66 | for graph in graphs: 67 | for node in graph.findAllNodes("aten::to"): 68 | inputs = list(node.inputs()) 69 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 70 | if inputs[i].node()["value"] == 5: 71 | inputs[i].node().copyAttributes(float_node) 72 | 73 | model.apply(patch_float) 74 | patch_float(model.encode_image) 75 | patch_float(model.encode_text) 76 | 77 | model.float() 78 | 79 | return 
model, _transform(model.input_resolution.item()) -------------------------------------------------------------------------------- /tracking/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains 
the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + 
self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | 
self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logits_per_image.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def convert_weights(model: nn.Module): 
372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /tracking/dino/__init__.py: -------------------------------------------------------------------------------- 1 | from .dino import * -------------------------------------------------------------------------------- /tracking/dino/dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms as T 3 | from .xcit import XciT 4 | from .vit import ViT 5 | 6 | __all__ = ["load"] 7 | 8 | 9 | def load(model_name, model_path, device): 10 | _, base_name, variant = model_name.split('-') 11 | model = eval(base_name)(variant) 12 | model.load_state_dict(torch.load(model_path, map_location='cpu')) 13 | model = model.to(device) 14 | model.eval() 15 | 16 | transform = T.Compose([ 17 | T.Resize((224, 224)), 18 | T.ToTensor(), 19 
| T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 20 | ]) 21 | 22 | return model, transform 23 | 24 | -------------------------------------------------------------------------------- /tracking/dino/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | from torch import nn, Tensor 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, dim, hidden_dim, out_dim=None) -> None: 9 | super().__init__() 10 | out_dim = out_dim or dim 11 | self.fc1 = nn.Linear(dim, hidden_dim) 12 | self.act = nn.GELU() 13 | self.fc2 = nn.Linear(hidden_dim, out_dim) 14 | 15 | def forward(self, x: Tensor) -> Tensor: 16 | return self.fc2(self.act(self.fc1(x))) 17 | 18 | 19 | class PatchEmbedding(nn.Module): 20 | """Image to Patch Embedding 21 | """ 22 | def __init__(self, img_size=224, patch_size=16, embed_dim=768): 23 | super().__init__() 24 | assert img_size % patch_size == 0, 'Image size must be divisible by patch size' 25 | 26 | img_size = (img_size, img_size) if isinstance(img_size, int) else img_size 27 | 28 | self.grid_size = (img_size[0] // patch_size, img_size[1] // patch_size) 29 | self.num_patches = self.grid_size[0] * self.grid_size[1] 30 | self.proj = nn.Conv2d(3, embed_dim, patch_size, patch_size) 31 | 32 | def forward(self, x: Tensor) -> Tensor: 33 | x = self.proj(x) # b x hidden_dim x 14 x 14 34 | x = x.flatten(2).swapaxes(1, 2) # b x (14*14) x hidden_dim 35 | return x 36 | 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, heads=12): 40 | super().__init__() 41 | self.num_heads = heads 42 | self.scale = (dim // heads) ** -0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 45 | self.proj = nn.Linear(dim, dim) 46 | 47 | def forward(self, x: Tensor) -> Tensor: 48 | B, N, C = x.shape 49 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 50 | q, k, v = qkv[0], qkv[1], qkv[2] 51 | 52 | attn = (q @ k.transpose(-2, -1)) * self.scale 53 | attn = attn.softmax(dim=-1) 54 | 55 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 56 | x = self.proj(x) 57 | return x 58 | 59 | 60 | class TransformerEncoder(nn.Module): 61 | def __init__(self, dim, heads): 62 | super().__init__() 63 | self.norm1 = nn.LayerNorm(dim) 64 | self.attn = Attention(dim, heads) 65 | self.norm2 = nn.LayerNorm(dim) 66 | self.mlp = MLP(dim, int(dim * 4)) 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | x += self.attn(self.norm1(x)) 70 | x += self.mlp(self.norm2(x)) 71 | return x 72 | 73 | 74 | vit_settings = { 75 | 'S/8': [8, 12, 384, 6], #[patch_size, number_of_layers, embed_dim, heads] 76 | 'S/16': [16, 12, 384, 6], 77 | 'B/16': [16, 12, 768, 12] 78 | } 79 | 80 | 81 | class ViT(nn.Module): 82 | def __init__(self, model_name: str = 'S/8', image_size: int = 224) -> None: 83 | super().__init__() 84 | assert model_name in vit_settings.keys(), f"DeiT model name should be in {list(vit_settings.keys())}" 85 | patch_size, layers, embed_dim, heads = vit_settings[model_name] 86 | 87 | self.patch_size = patch_size 88 | self.patch_embed = PatchEmbedding(image_size, patch_size, embed_dim) 89 | self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) 90 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 91 | 92 | self.blocks = nn.ModuleList([ 93 | TransformerEncoder(embed_dim, heads) 94 | for i in range(layers)]) 95 | 96 | self.norm = nn.LayerNorm(embed_dim) 97 | 98 | def interpolate_pos_encoding(self, x: Tensor, W: int, H: 
int) -> Tensor: 99 | num_patches = x.shape[1] - 1 100 | N = self.pos_embed.shape[1] - 1 101 | 102 | if num_patches == N and H == W: 103 | return self.pos_embed 104 | 105 | class_pos_embed = self.pos_embed[:, 0] 106 | patch_pos_embed = self.pos_embed[:, 1:] 107 | 108 | dim = x.shape[-1] 109 | w0 = W // self.patch_size 110 | h0 = H // self.patch_size 111 | 112 | w0, h0 = w0 + 0.1, h0 + 0.1 113 | 114 | patch_pos_embed = F.interpolate( 115 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 116 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 117 | mode='bicubic' 118 | ) 119 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 120 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 121 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 122 | 123 | def encode_image(self, x): 124 | return self.forward(x) 125 | 126 | def forward(self, x: Tensor) -> Tensor: 127 | B, C, W, H = x.shape 128 | x = self.patch_embed(x) 129 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) 130 | x = torch.cat((cls_token, x), dim=1) 131 | x += self.interpolate_pos_encoding(x, W, H) 132 | 133 | for blk in self.blocks: 134 | x = blk(x) 135 | 136 | x = self.norm(x) 137 | return x[:, 0] 138 | 139 | 140 | if __name__ == '__main__': 141 | model = ViT('S/16') 142 | model.load_state_dict(torch.load('checkpoints/vit/dino_deitsmall16_pretrain.pth', map_location='cpu')) 143 | x = torch.zeros(1, 3, 224, 224) 144 | y = model(x) 145 | print(y.shape) -------------------------------------------------------------------------------- /tracking/dino/xcit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | from torch import nn, Tensor 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, dim, hidden_dim, out_dim=None) -> None: 9 | super().__init__() 10 | out_dim = out_dim or dim 11 | self.fc1 = nn.Linear(dim, hidden_dim) 12 | self.act = nn.GELU() 13 | self.fc2 = nn.Linear(hidden_dim, out_dim) 14 | 15 | def forward(self, x: Tensor) -> Tensor: 16 | return self.fc2(self.act(self.fc1(x))) 17 | 18 | 19 | class PositionalEncodingFourier(nn.Module): 20 | def __init__(self, dim: int = 768): 21 | super().__init__() 22 | self.dim = dim 23 | self.hidden_dim = 32 24 | self.token_projection = nn.Conv2d(self.hidden_dim * 2, dim, 1) 25 | self.scale = 2 * math.pi 26 | 27 | def forward(self, B: int, H: int, W: int) -> Tensor: 28 | mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device) 29 | not_mask = ~mask 30 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 31 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 32 | y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale 33 | x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale 34 | 35 | dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device) 36 | dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / self.hidden_dim) 37 | 38 | pos_x = x_embed[:, :, :, None] / dim_t 39 | pos_y = y_embed[:, :, :, None] / dim_t 40 | 41 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), 42 | pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 43 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), 44 | pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | pos = self.token_projection(pos) 47 | return pos 48 | 49 | 50 | class Conv3x3(nn.Sequential): 51 | def 
__init__(self, c1, c2, s=1): 52 | super().__init__( 53 | nn.Conv2d(c1, c2, 3, s, 1, bias=False), 54 | nn.BatchNorm2d(c2) 55 | ) 56 | 57 | 58 | class ConvPatchEmbed(nn.Module): 59 | """Image to Patch Embedding using multiple convolutional layers 60 | """ 61 | def __init__(self, patch_size=8, embed_dim=768): 62 | super().__init__() 63 | if patch_size == 16: 64 | self.proj = nn.Sequential( 65 | Conv3x3(3, embed_dim // 8, 2), 66 | nn.GELU(), 67 | Conv3x3(embed_dim // 8, embed_dim // 4, 2), 68 | nn.GELU(), 69 | Conv3x3(embed_dim // 4, embed_dim // 2, 2), 70 | nn.GELU(), 71 | Conv3x3(embed_dim // 2, embed_dim, 2), 72 | ) 73 | else: 74 | self.proj = nn.Sequential( 75 | Conv3x3(3, embed_dim // 4, 2), 76 | nn.GELU(), 77 | Conv3x3(embed_dim // 4, embed_dim // 2, 2), 78 | nn.GELU(), 79 | Conv3x3(embed_dim // 2, embed_dim, 2), 80 | ) 81 | 82 | def forward(self, x: Tensor): 83 | x = self.proj(x) 84 | _, _, H, W = x.shape 85 | x = x.flatten(2).transpose(1, 2) 86 | return x, (H, W) 87 | 88 | 89 | class LPI(nn.Module): 90 | """ 91 | Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows 92 | to augment the implicit communcation performed by the block diagonal scatter attention. 93 | Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d 94 | """ 95 | def __init__(self, dim: int): 96 | super().__init__() 97 | self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 98 | self.act = nn.GELU() 99 | self.bn = nn.BatchNorm2d(dim) 100 | self.conv2 = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 101 | 102 | def forward(self, x: Tensor, H: int, W: int) -> Tensor: 103 | B, N, C = x.shape 104 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 105 | x = self.conv2(self.bn(self.act(self.conv1(x)))) 106 | x = x.reshape(B, C, N).permute(0, 2, 1) 107 | return x 108 | 109 | 110 | class ClassAttention(nn.Module): 111 | """ClassAttention as in CaiT 112 | """ 113 | def __init__(self, dim: int, heads: int): 114 | super().__init__() 115 | self.num_heads = heads 116 | self.scale = (dim // heads) ** -0.5 117 | 118 | self.qkv = nn.Linear(dim, dim * 3) 119 | self.proj = nn.Linear(dim, dim) 120 | 121 | def forward(self, x: Tensor) -> Tensor: 122 | B, N, C = x.shape 123 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 124 | q, k, v = qkv[0], qkv[1], qkv[2] 125 | 126 | qc = q[:, :, 0:1] # CLS token 127 | 128 | attn_cls = (qc * k).sum(dim=-1) * self.scale 129 | attn_cls = attn_cls.softmax(dim=-1) 130 | 131 | cls_token = (attn_cls.unsqueeze(2) @ v).transpose(1, 2).reshape(B, 1, C) 132 | cls_token = self.proj(cls_token) 133 | 134 | x = torch.cat([cls_token, x[:, 1:]], dim=1) 135 | return x 136 | 137 | 138 | class XCA(nn.Module): 139 | """ Cross-Covariance Attention (XCA) operation where the channels are updated using a weighted 140 | sum. 
The weights are obtained from the (softmax normalized) Cross-covariance 141 | matrix (Q^T K \\in d_h \\times d_h) 142 | """ 143 | def __init__(self, dim: int, heads: int): 144 | super().__init__() 145 | self.num_heads = heads 146 | self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) 147 | 148 | self.qkv = nn.Linear(dim, dim * 3) 149 | self.proj = nn.Linear(dim, dim) 150 | 151 | def forward(self, x: Tensor) -> Tensor: 152 | B, N, C = x.shape 153 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 154 | q, k, v = qkv[0].transpose(-2, -1), qkv[1].transpose(-2, -1), qkv[2].transpose(-2, -1) 155 | q = F.normalize(q, dim=-1) 156 | k = F.normalize(k, dim=-1) 157 | attn = (q @ k.transpose(-2, -1)) * self.temperature 158 | attn = attn.softmax(dim=-1) 159 | 160 | x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) 161 | x = self.proj(x) 162 | return x 163 | 164 | 165 | class ClassAttentionBlock(nn.Module): 166 | def __init__(self, dim, heads, eta=1e-5): 167 | super().__init__() 168 | self.norm1 = nn.LayerNorm(dim) 169 | self.attn = ClassAttention(dim, heads) 170 | self.norm2 = nn.LayerNorm(dim) 171 | self.mlp = MLP(dim, int(dim * 4)) 172 | 173 | self.gamma1 = nn.Parameter(eta * torch.ones(dim)) 174 | self.gamma2 = nn.Parameter(eta * torch.ones(dim)) 175 | 176 | def forward(self, x: Tensor) -> Tensor: 177 | x = x + (self.gamma1 * self.attn(self.norm1(x))) 178 | x = self.norm2(x) 179 | 180 | x_res = x 181 | cls_token = self.gamma2 * self.mlp(x[:, :1]) 182 | x = torch.cat([cls_token, x[:, 1:]], dim=1) 183 | x += x_res 184 | return x 185 | 186 | 187 | class XCABlock(nn.Module): 188 | def __init__(self, dim, heads, eta=1e-5): 189 | super().__init__() 190 | self.norm1 = nn.LayerNorm(dim) 191 | self.attn = XCA(dim, heads) 192 | self.norm2 = nn.LayerNorm(dim) 193 | self.mlp = MLP(dim, int(dim * 4)) 194 | self.norm3 = nn.LayerNorm(dim) 195 | self.local_mp = LPI(dim) 196 | 197 | self.gamma1 = nn.Parameter(eta * torch.ones(dim)) 198 | self.gamma2 = nn.Parameter(eta * torch.ones(dim)) 199 | self.gamma3 = nn.Parameter(eta * torch.ones(dim)) 200 | 201 | def forward(self, x: Tensor, H, W) -> Tensor: 202 | x = x + self.gamma1 * self.attn(self.norm1(x)) 203 | x = x + self.gamma3 * self.local_mp(self.norm3(x), H, W) 204 | x = x + self.gamma2 * self.mlp(self.norm2(x)) 205 | return x 206 | 207 | 208 | xcit_settings = { 209 | 'S12/8': [8, 12, 384, 8], #[patch_size, layers, embed dim, heads] 210 | 'S12/16': [16, 12, 384, 8], 211 | 'M24/16': [16, 24, 512, 8], 212 | } 213 | 214 | 215 | class XciT(nn.Module): 216 | def __init__(self, model_name: str = 'S12/8', *args, **kwargs) -> None: 217 | super().__init__() 218 | assert model_name in xcit_settings.keys(), f"XciT model name should be in {list(xcit_settings.keys())}" 219 | patch_size, layers, embed_dim, heads = xcit_settings[model_name] 220 | 221 | self.patch_embed = ConvPatchEmbed(patch_size, embed_dim) 222 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 223 | 224 | self.pos_embeder = PositionalEncodingFourier(dim=embed_dim) 225 | 226 | self.blocks = nn.ModuleList([ 227 | XCABlock(embed_dim, heads) 228 | for _ in range(layers)]) 229 | 230 | self.cls_attn_blocks = nn.ModuleList([ 231 | ClassAttentionBlock(embed_dim, heads) 232 | for _ in range(2)]) 233 | self.norm = nn.LayerNorm(embed_dim) 234 | 235 | def encode_image(self, x): 236 | return self.forward(x) 237 | 238 | def forward(self, x): 239 | B = x.shape[0] 240 | x, (Hp, Wp) = self.patch_embed(x) 241 | x += self.pos_embeder(B, Hp, Wp).reshape(B, -1, 
x.shape[1]).permute(0, 2, 1) 242 | 243 | for blk in self.blocks: 244 | x = blk(x, Hp, Wp) 245 | 246 | cls_tokens = self.cls_token.expand(B, -1, -1) 247 | x = torch.cat((cls_tokens, x), dim=1) 248 | 249 | for blk in self.cls_attn_blocks: 250 | x = blk(x) 251 | 252 | x = self.norm(x) 253 | return x[:, 0] 254 | 255 | 256 | if __name__ == '__main__': 257 | model = XciT('S12/16') 258 | model.load_state_dict(torch.load('checkpoints/xcit/dino_xcit_small_12_p16_pretrain.pth', map_location='cpu')) 259 | x = torch.zeros(1, 3, 224, 224) 260 | y = model(x) 261 | print(y.shape) 262 | -------------------------------------------------------------------------------- /tracking/sort/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/simple-object-tracking/4e86e53e78799c5cb92d2f1cddd2df071530e98e/tracking/sort/__init__.py -------------------------------------------------------------------------------- /tracking/sort/detection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Detection: 5 | """Bounding box detection in a single image 6 | Parameters 7 | ---------- 8 | tlwh : (ndarray) bbox in format `(top left x, top left y, width, height)`. 9 | confidence : (float) Detector confidence score. 10 | class_id : (ndarray) Detector class. 11 | feature : (ndarray) A feature vector that describes the object contained in this image. 12 | """ 13 | def __init__(self, tlwh, class_id, feature): 14 | self.tlwh = np.asarray(tlwh, dtype=np.float32) 15 | self.feature = np.asarray(feature, dtype=np.float32) 16 | self.class_id = class_id 17 | 18 | def to_tlbr(self): 19 | """Convert bbox from (top, left, width, height) to (top, left, bottom, right) 20 | """ 21 | ret = self.tlwh.copy() 22 | ret[2:] += ret[:2] 23 | return ret 24 | 25 | def to_xyah(self): 26 | """Convert bbox from (top, left, width, height) to (center x, center y, aspect ratio, height) where the aspect ratio is `width / height` 27 | """ 28 | ret = self.tlwh.copy() 29 | ret[:2] += ret[2:] / 2 30 | ret[2] /= ret[3] 31 | return ret -------------------------------------------------------------------------------- /tracking/sort/kalman_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | 4 | 5 | class KalmanFilter: 6 | """A simple Kalman filter for tracking bounding boxes in image space 7 | The 8-dimensional state space 8 | x, y, a, h, vx, vy, va, vh 9 | contains the bounding box center position (x, y), aspect ratio a, height h, 10 | and their respective velocities. 11 | Object motion follows a constant velocity model. The bounding box location 12 | (x, y, a, h) is taken as direct observation of the state space (linear 13 | observation model). 14 | """ 15 | 16 | def __init__(self): 17 | ndim, dt = 4, 1. 18 | 19 | # Create Kalman filter model matrices. 20 | self._motion_mat = np.eye(2 * ndim, 2 * ndim) 21 | for i in range(ndim): 22 | self._motion_mat[i, ndim + i] = dt 23 | self._update_mat = np.eye(ndim, 2 * ndim) 24 | 25 | # Motion and observation uncertainty are chosen relative to the current 26 | # state estimate. These weights control the amount of uncertainty in 27 | # the model. This is a bit hacky. 28 | self._std_weight_position = 1. / 20 29 | self._std_weight_velocity = 1. / 160 30 | 31 | def initiate(self, measurement): 32 | """Create track from unassociated measurement. 
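The measurement supplies the positional half of the 8-dimensional state; velocities start at zero, and the initial standard deviations are scaled by the box height `h` (with small fixed constants for the aspect-ratio terms).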
33 | Parameters 34 | ---------- 35 | measurement : (ndarray) Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, and height h. 36 | Returns 37 | ------- 38 | (ndarray, ndarray) 39 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean. 40 | """ 41 | mean_pos = measurement 42 | mean_vel = np.zeros_like(mean_pos) 43 | mean = np.r_[mean_pos, mean_vel] 44 | 45 | std = [ 46 | 2 * self._std_weight_position * measurement[3], 47 | 2 * self._std_weight_position * measurement[3], 48 | 1e-2, 49 | 2 * self._std_weight_position * measurement[3], 50 | 10 * self._std_weight_velocity * measurement[3], 51 | 10 * self._std_weight_velocity * measurement[3], 52 | 1e-5, 53 | 10 * self._std_weight_velocity * measurement[3] 54 | ] 55 | covariance = np.diag(np.square(std)) 56 | return mean, covariance 57 | 58 | def predict(self, mean, covariance): 59 | """Run Kalman filter prediction step. 60 | Parameters 61 | ---------- 62 | mean : ndarray 63 | The 8 dimensional mean vector of the object state at the previous 64 | time step. 65 | covariance : ndarray 66 | The 8x8 dimensional covariance matrix of the object state at the 67 | previous time step. 68 | Returns 69 | ------- 70 | (ndarray, ndarray) 71 | Returns the mean vector and covariance matrix of the predicted 72 | state. Unobserved velocities are initialized to 0 mean. 73 | """ 74 | std_pos = [ 75 | self._std_weight_position * mean[3], 76 | self._std_weight_position * mean[3], 77 | 1e-2, 78 | self._std_weight_position * mean[3]] 79 | std_vel = [ 80 | self._std_weight_velocity * mean[3], 81 | self._std_weight_velocity * mean[3], 82 | 1e-5, 83 | self._std_weight_velocity * mean[3]] 84 | 85 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) 86 | mean = np.dot(self._motion_mat, mean) 87 | covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov 88 | 89 | return mean, covariance 90 | 91 | def project(self, mean, covariance): 92 | """Project state distribution to measurement space. 93 | Parameters 94 | ---------- 95 | mean : (ndarray) The state's mean vector (8 dimensional array). 96 | covariance : (ndarray) The state's covariance matrix (8x8 dimensional). 97 | Returns 98 | ------- 99 | (ndarray, ndarray) 100 | Returns the projected mean and covariance matrix of the given state estimate. 101 | """ 102 | std = [ 103 | self._std_weight_position * mean[3], 104 | self._std_weight_position * mean[3], 105 | 1e-1, 106 | self._std_weight_position * mean[3] 107 | ] 108 | 109 | innovation_cov = np.diag(np.square(std)) 110 | mean = np.dot(self._update_mat, mean) 111 | covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) + innovation_cov 112 | 113 | return mean, covariance 114 | 115 | def update(self, mean, covariance, measurement): 116 | """Run Kalman filter correction step. 117 | Parameters 118 | ---------- 119 | mean : (ndarray) The predicted state's mean vector (8 dimensional). 120 | covariance : (ndarray) The state's covariance matrix (8x8 dimensional). 121 | measurement : (ndarray) The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect ratio, and h the height of the bounding box. 122 | 123 | Returns 124 | ------- 125 | (ndarray, ndarray) 126 | Returns the measurement-corrected state distribution. 
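Implementation note: the Kalman gain is obtained by Cholesky-factorizing the projected covariance (`scipy.linalg.cho_factor` / `cho_solve`) rather than forming an explicit matrix inverse, which is cheaper and numerically more stable.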
127 | """ 128 | projected_mean, projected_cov = self.project(mean, covariance) 129 | 130 | chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) 131 | kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False).T 132 | innovation = measurement - projected_mean 133 | 134 | new_mean = mean + np.dot(innovation, kalman_gain.T) 135 | new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) 136 | return new_mean, new_covariance 137 | 138 | def gating_distance(self, mean, covariance, measurements): 139 | """Compute gating distance between state distribution and measurements. 140 | Parameters 141 | ---------- 142 | mean : (ndarray) Mean vector over the state distribution (8 dimensional). 143 | covariance : (ndarray) Covariance of the state distribution (8x8 dimensional). 144 | measurements : (ndarray) 145 | An Nx4 dimensional matrix of N measurements, each in 146 | format (x, y, a, h) where (x, y) is the bounding box center 147 | position, a the aspect ratio, and h the height. 148 | Returns 149 | ------- 150 | ndarray 151 | Returns an array of length N, where the i-th element contains the 152 | squared Mahalanobis distance between (mean, covariance) and 153 | `measurements[i]`. 154 | """ 155 | mean, covariance = self.project(mean, covariance) 156 | cholesky_factor = np.linalg.cholesky(covariance) 157 | d = measurements - mean 158 | z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) 159 | squared_maha = np.sum(z * z, axis=0) 160 | return squared_maha -------------------------------------------------------------------------------- /tracking/sort/matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import linear_sum_assignment 3 | 4 | 5 | def iou(bbox, candidates): 6 | """Compute IoU by one box to N candidates 7 | Parameters 8 | ---------- 9 | bbox : (ndarray) A bounding box in format `(top left x, top left y, width, height)`. 10 | candidates : (ndarray) A matrix of candidate bounding boxes (one per row) in the same format as `bbox`. 11 | 12 | Returns 13 | ------- 14 | ndarray 15 | The intersection over union in [0, 1] between the `bbox` and each candidate. A higher score means a larger fraction of the `bbox` is occluded by the candidate. 16 | """ 17 | bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:] 18 | candidates_tl, candidates_br = candidates[:, :2], candidates[:, :2] + candidates[:, 2:] 19 | 20 | tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]] 21 | br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]] 22 | wh = np.maximum(0., br - tl) 23 | 24 | area_intersection = wh.prod(axis=1) 25 | area_bbox = bbox[2:].prod() 26 | area_candidates = candidates[:, 2:].prod(axis=1) 27 | return area_intersection / (area_bbox + area_candidates - area_intersection) 28 | 29 | 30 | def iou_cost(tracks, detections, track_indices=None, detection_indices=None): 31 | """An intersection over union distance metric. 32 | Parameters 33 | ---------- 34 | tracks : List[deep_sort.track.Track] A list of tracks. 35 | detections : List[deep_sort.detection.Detection] A list of detections. 36 | track_indices : Optional[List[int]] A list of indices to tracks that should be matched. 
Defaults to all `tracks`. 37 | detection_indices : Optional[List[int]] A list of indices to detections that should be matched. Defaults to all `detections`. 38 | 39 | Returns 40 | ------- 41 | ndarray 42 | Returns a cost matrix of shape len(track_indices), len(detection_indices) where entry (i, j) is `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`. 43 | """ 44 | if track_indices is None: track_indices = np.arange(len(tracks)) 45 | if detection_indices is None: detection_indices = np.arange(len(detections)) 46 | 47 | cost_matrix = np.zeros((len(track_indices), len(detection_indices))) 48 | for row, track_idx in enumerate(track_indices): 49 | if tracks[track_idx].time_since_update > 1: 50 | cost_matrix[row, :] = 1e+5 51 | continue 52 | 53 | bbox = tracks[track_idx].to_tlwh() 54 | candidates = np.asarray([detections[i].tlwh for i in detection_indices]) 55 | cost_matrix[row, :] = 1. - iou(bbox, candidates) 56 | return cost_matrix 57 | 58 | 59 | def _nn_euclidean_distance(a, b): 60 | """Compute pair-wise squared distance between points in `a` and `b`. 61 | Parameters 62 | ---------- 63 | a : array_like 64 | An NxM matrix of N samples of dimensionality M. 65 | b : array_like 66 | An LxM matrix of L samples of dimensionality M. 67 | Returns 68 | ------- 69 | ndarray 70 | Returns a matrix of size len(a), len(b) such that eleement (i, j) 71 | contains the squared distance between `a[i]` and `b[j]`. 72 | """ 73 | a, b = np.asarray(a), np.asarray(b) 74 | if len(a) == 0 or len(b) == 0: 75 | return np.zeros((len(a), len(b))) 76 | a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) 77 | distances = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :] 78 | distances = np.clip(distances, 0., float(np.inf)) 79 | return np.maximum(0.0, distances.min(axis=0)) 80 | 81 | 82 | def _nn_cosine_distance(a, b): 83 | """Compute pair-wise cosine distance between points in `a` and `b`. 84 | Parameters 85 | ---------- 86 | a : array_like 87 | An NxM matrix of N samples of dimensionality M. 88 | b : array_like 89 | An LxM matrix of L samples of dimensionality M. 90 | 91 | Returns 92 | ------- 93 | ndarray 94 | Returns a matrix of size len(a), len(b) such that eleement (i, j) 95 | contains the squared distance between `a[i]` and `b[j]`. 96 | """ 97 | a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) 98 | b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) 99 | distances = 1. - np.dot(a, b.T) 100 | return distances.min(axis=0) 101 | 102 | 103 | class NearestNeighborDistanceMetric: 104 | """ 105 | A nearest neighbor distance metric that, for each target, returns 106 | the closest distance to any sample that has been observed so far. 107 | Parameters 108 | ---------- 109 | metric : str 110 | Either "euclidean" or "cosine". 111 | matching_threshold: float 112 | The matching threshold. Samples with larger distance are considered an 113 | invalid match. 114 | budget : Optional[int] 115 | If not None, fix samples per class to at most this number. Removes 116 | the oldest samples when the budget is reached. 117 | Attributes 118 | ---------- 119 | samples : Dict[int -> List[ndarray]] 120 | A dictionary that maps from target identities to the list of samples 121 | that have been observed so far. 
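When `budget` is set, only the most recent `budget` feature vectors are retained per target (enforced in `partial_fit`).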
122 | """ 123 | 124 | def __init__(self, metric, matching_threshold, budget=None): 125 | if metric == "euclidean": 126 | self._metric = _nn_euclidean_distance 127 | elif metric == "cosine": 128 | self._metric = _nn_cosine_distance 129 | else: 130 | raise ValueError("Invalid metric; must be either 'euclidean' or 'cosine'") 131 | self.matching_threshold = matching_threshold 132 | self.budget = budget 133 | self.samples = {} 134 | 135 | def partial_fit(self, features, targets, active_targets): 136 | """Update the distance metric with new data. 137 | Parameters 138 | ---------- 139 | features : ndarray 140 | An NxM matrix of N features of dimensionality M. 141 | targets : ndarray 142 | An integer array of associated target identities. 143 | active_targets : List[int] 144 | A list of targets that are currently present in the scene. 145 | """ 146 | for feature, target in zip(features, targets): 147 | self.samples.setdefault(target, []).append(feature) 148 | if self.budget is not None: 149 | self.samples[target] = self.samples[target][-self.budget:] 150 | self.samples = {k: self.samples[k] for k in active_targets} 151 | 152 | def distance(self, features, targets): 153 | """Compute distance between features and targets. 154 | Parameters 155 | ---------- 156 | features : ndarray 157 | An NxM matrix of N features of dimensionality M. 158 | targets : List[int] 159 | A list of targets to match the given `features` against. 160 | Returns 161 | ------- 162 | ndarray 163 | Returns a cost matrix of shape len(targets), len(features), where 164 | element (i, j) contains the closest squared distance between 165 | `targets[i]` and `features[j]`. 166 | """ 167 | cost_matrix = np.zeros((len(targets), len(features))) 168 | for i, target in enumerate(targets): 169 | cost_matrix[i, :] = self._metric(self.samples[target], features) 170 | return cost_matrix 171 | 172 | 173 | def min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices=None, detection_indices=None): 174 | """Solve linear assignment problem. 175 | Parameters 176 | ---------- 177 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray 178 | The distance metric is given a list of tracks and detections as well as 179 | a list of N track indices and M detection indices. The metric should 180 | return the NxM dimensional cost matrix, where element (i, j) is the 181 | association cost between the i-th track in the given track indices and 182 | the j-th detection in the given detection_indices. 183 | max_distance : float 184 | Gating threshold. Associations with cost larger than this value are 185 | disregarded. 186 | tracks : List[track.Track] 187 | A list of predicted tracks at the current time step. 188 | detections : List[detection.Detection] 189 | A list of detections at the current time step. 190 | track_indices : List[int] 191 | List of track indices that maps rows in `cost_matrix` to tracks in 192 | `tracks` (see description above). 193 | detection_indices : List[int] 194 | List of detection indices that maps columns in `cost_matrix` to 195 | detections in `detections` (see description above). 196 | Returns 197 | ------- 198 | (List[(int, int)], List[int], List[int]) 199 | Returns a tuple with the following three entries: 200 | * A list of matched track and detection indices. 201 | * A list of unmatched track indices. 202 | * A list of unmatched detection indices. 
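Costs above `max_distance` are clamped to `max_distance + 1e-5` before the Hungarian solve (`scipy.optimize.linear_sum_assignment`); any resulting assignment whose cost still exceeds `max_distance` is reported as unmatched rather than matched.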
203 | """ 204 | if track_indices is None: track_indices = np.arange(len(tracks)) 205 | if detection_indices is None: detection_indices = np.arange(len(detections)) 206 | 207 | if len(detection_indices) == 0 or len(track_indices) == 0: 208 | return [], track_indices, detection_indices # Nothing to match. 209 | 210 | cost_matrix = distance_metric(tracks, detections, track_indices, detection_indices) 211 | cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 212 | 213 | row_indices, col_indices = linear_sum_assignment(cost_matrix) 214 | 215 | matches, unmatched_tracks, unmatched_detections = [], [], [] 216 | for col, detection_idx in enumerate(detection_indices): 217 | if col not in col_indices: 218 | unmatched_detections.append(detection_idx) 219 | for row, track_idx in enumerate(track_indices): 220 | if row not in row_indices: 221 | unmatched_tracks.append(track_idx) 222 | for row, col in zip(row_indices, col_indices): 223 | track_idx = track_indices[row] 224 | detection_idx = detection_indices[col] 225 | if cost_matrix[row, col] > max_distance: 226 | unmatched_tracks.append(track_idx) 227 | unmatched_detections.append(detection_idx) 228 | else: 229 | matches.append((track_idx, detection_idx)) 230 | return matches, unmatched_tracks, unmatched_detections 231 | 232 | 233 | def matching_cascade(distance_metric, max_distance, cascade_depth, tracks, detections, track_indices=None, detection_indices=None): 234 | """Run matching cascade. 235 | Parameters 236 | ---------- 237 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray 238 | The distance metric is given a list of tracks and detections as well as 239 | a list of N track indices and M detection indices. The metric should 240 | return the NxM dimensional cost matrix, where element (i, j) is the 241 | association cost between the i-th track in the given track indices and 242 | the j-th detection in the given detection indices. 243 | max_distance : float 244 | Gating threshold. Associations with cost larger than this value are 245 | disregarded. 246 | cascade_depth: int 247 | The cascade depth, should be se to the maximum track age. 248 | tracks : List[track.Track] 249 | A list of predicted tracks at the current time step. 250 | detections : List[detection.Detection] 251 | A list of detections at the current time step. 252 | track_indices : Optional[List[int]] 253 | List of track indices that maps rows in `cost_matrix` to tracks in 254 | `tracks` (see description above). Defaults to all tracks. 255 | detection_indices : Optional[List[int]] 256 | List of detection indices that maps columns in `cost_matrix` to 257 | detections in `detections` (see description above). Defaults to all 258 | detections. 259 | Returns 260 | ------- 261 | (List[(int, int)], List[int], List[int]) 262 | Returns a tuple with the following three entries: 263 | * A list of matched track and detection indices. 264 | * A list of unmatched track indices. 265 | * A list of unmatched detection indices. 
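Tracks are considered in order of increasing `time_since_update`, so tracks updated most recently get first access to the detections; deeper cascade levels only see detections left unmatched by earlier levels.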
266 | """ 267 | if track_indices is None: track_indices = list(range(len(tracks))) 268 | if detection_indices is None: detection_indices = list(range(len(detections))) 269 | 270 | unmatched_detections = detection_indices 271 | matches = [] 272 | for level in range(cascade_depth): 273 | if len(unmatched_detections) == 0: # No detections left 274 | break 275 | 276 | track_indices_l = [k for k in track_indices if tracks[k].time_since_update == 1 + level] 277 | if len(track_indices_l) == 0: # Nothing to match at this level 278 | continue 279 | 280 | matches_l, _, unmatched_detections = min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices_l, unmatched_detections) 281 | matches += matches_l 282 | unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) 283 | return matches, unmatched_tracks, unmatched_detections 284 | 285 | 286 | def gate_cost_matrix(kf, cost_matrix, tracks, detections, track_indices, detection_indices): 287 | """Invalidate infeasible entries in cost matrix based on the state 288 | distributions obtained by Kalman filtering. 289 | Parameters 290 | ---------- 291 | kf : The Kalman filter. 292 | cost_matrix : ndarray 293 | The NxM dimensional cost matrix, where N is the number of track indices 294 | and M is the number of detection indices, such that entry (i, j) is the 295 | association cost between `tracks[track_indices[i]]` and 296 | `detections[detection_indices[j]]`. 297 | tracks : List[track.Track] 298 | A list of predicted tracks at the current time step. 299 | detections : List[detection.Detection] 300 | A list of detections at the current time step. 301 | track_indices : List[int] 302 | List of track indices that maps rows in `cost_matrix` to tracks in 303 | `tracks` (see description above). 304 | detection_indices : List[int] 305 | List of detection indices that maps columns in `cost_matrix` to 306 | detections in `detections` (see description above). 307 | Returns 308 | ------- 309 | ndarray 310 | Returns the modified cost matrix. 311 | """ 312 | measurements = np.asarray([detections[i].to_xyah() for i in detection_indices]) 313 | for row, track_idx in enumerate(track_indices): 314 | track = tracks[track_idx] 315 | gating_distance = kf.gating_distance(track.mean, track.covariance, measurements) 316 | cost_matrix[row, gating_distance > 9.4877] = 1e+5 317 | return cost_matrix -------------------------------------------------------------------------------- /tracking/sort/track.py: -------------------------------------------------------------------------------- 1 | class TrackState: 2 | """ 3 | Enumeration type for the single target track state. Newly created tracks are 4 | classified as `tentative` until enough evidence has been collected. Then, 5 | the track state is changed to `confirmed`. Tracks that are no longer alive 6 | are classified as `deleted` to mark them for removal from the set of active 7 | tracks. 8 | """ 9 | 10 | Tentative = 1 11 | Confirmed = 2 12 | Deleted = 3 13 | 14 | 15 | class Track: 16 | """ 17 | A single target track with state space `(x, y, a, h)` and associated 18 | velocities, where `(x, y)` is the center of the bounding box, `a` is the 19 | aspect ratio and `h` is the height. 20 | Parameters 21 | ---------- 22 | mean : ndarray 23 | Mean vector of the initial state distribution. 24 | covariance : ndarray 25 | Covariance matrix of the initial state distribution. 26 | track_id : int 27 | A unique track identifier. 28 | n_init : int 29 | Number of consecutive detections before the track is confirmed. 
The 30 | track state is set to `Deleted` if a miss occurs within the first 31 | `n_init` frames. 32 | max_age : int 33 | The maximum number of consecutive misses before the track state is 34 | set to `Deleted`. 35 | feature : Optional[ndarray] 36 | Feature vector of the detection this track originates from. If not None, 37 | this feature is added to the `features` cache. 38 | Attributes 39 | ---------- 40 | mean : ndarray 41 | Mean vector of the initial state distribution. 42 | covariance : ndarray 43 | Covariance matrix of the initial state distribution. 44 | track_id : int 45 | A unique track identifier. 46 | hits : int 47 | Total number of measurement updates. 48 | age : int 49 | Total number of frames since first occurance. 50 | time_since_update : int 51 | Total number of frames since last measurement update. 52 | state : TrackState 53 | The current track state. 54 | features : List[ndarray] 55 | A cache of features. On each measurement update, the associated feature 56 | vector is added to this list. 57 | """ 58 | 59 | def __init__(self, mean, covariance, track_id, n_init, max_age, feature=None, class_id=None): 60 | self.mean = mean 61 | self.covariance = covariance 62 | self.track_id = track_id 63 | self.hits = 1 64 | self.age = 1 65 | self.time_since_update = 0 66 | 67 | self.state = TrackState.Tentative 68 | self.features = [] 69 | if feature is not None: 70 | self.features.append(feature) 71 | 72 | self._n_init = n_init 73 | self._max_age = max_age 74 | self.class_id = class_id 75 | 76 | def to_tlwh(self): 77 | """Get current position in bounding box format `(top left x, top left y, 78 | width, height)`. 79 | Returns 80 | ------- 81 | ndarray 82 | The bounding box. 83 | """ 84 | ret = self.mean[:4].copy() 85 | ret[2] *= ret[3] 86 | ret[:2] -= ret[2:] / 2 87 | return ret 88 | 89 | def to_tlbr(self): 90 | """Get current position in bounding box format `(min x, miny, max x, 91 | max y)`. 92 | Returns 93 | ------- 94 | ndarray 95 | The bounding box. 96 | """ 97 | ret = self.to_tlwh() 98 | ret[2:] = ret[:2] + ret[2:] 99 | return ret 100 | 101 | def increment_age(self): 102 | self.age += 1 103 | self.time_since_update += 1 104 | 105 | def predict(self, kf): 106 | """Propagate the state distribution to the current time step using a 107 | Kalman filter prediction step. 108 | Parameters 109 | ---------- 110 | kf : kalman_filter.KalmanFilter 111 | The Kalman filter. 112 | """ 113 | self.mean, self.covariance = kf.predict(self.mean, self.covariance) 114 | self.increment_age() 115 | 116 | def update(self, kf, detection): 117 | """Perform Kalman filter measurement update step and update the feature 118 | cache. 119 | Parameters 120 | ---------- 121 | kf : kalman_filter.KalmanFilter 122 | The Kalman filter. 123 | detection : Detection 124 | The associated detection. 125 | """ 126 | self.mean, self.covariance = kf.update(self.mean, self.covariance, detection.to_xyah()) 127 | self.features.append(detection.feature) 128 | 129 | self.hits += 1 130 | self.time_since_update = 0 131 | if self.state == TrackState.Tentative and self.hits >= self._n_init: 132 | self.state = TrackState.Confirmed 133 | 134 | def mark_missed(self): 135 | """Mark this track as missed (no association at the current time step). 136 | """ 137 | if self.state == TrackState.Tentative: 138 | self.state = TrackState.Deleted 139 | elif self.time_since_update > self._max_age: 140 | self.state = TrackState.Deleted 141 | 142 | def is_tentative(self): 143 | """Returns True if this track is tentative (unconfirmed). 
144 | """ 145 | return self.state == TrackState.Tentative 146 | 147 | def is_confirmed(self): 148 | """Returns True if this track is confirmed.""" 149 | return self.state == TrackState.Confirmed 150 | 151 | def is_deleted(self): 152 | """Returns True if this track is dead and should be deleted.""" 153 | return self.state == TrackState.Deleted -------------------------------------------------------------------------------- /tracking/sort/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .detection import Detection 3 | from .kalman_filter import KalmanFilter 4 | from .matching import NearestNeighborDistanceMetric, iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix 5 | from .track import Track 6 | 7 | 8 | class DeepSORTTracker: 9 | """DeepSORT Tracker 10 | Parameters 11 | ---------- 12 | metric : nn_matching.NearestNeighborDistanceMetric 13 | A distance metric for measurement-to-track association. 14 | max_age : int 15 | Maximum number of missed misses before a track is deleted. 16 | n_init : int 17 | Number of consecutive detections before the track is confirmed. The 18 | track state is set to `Deleted` if a miss occurs within the first 19 | `n_init` frames. 20 | Attributes 21 | ---------- 22 | metric : nn_matching.NearestNeighborDistanceMetric 23 | The distance metric used for measurement to track association. 24 | max_age : int 25 | Maximum number of missed misses before a track is deleted. 26 | n_init : int 27 | Number of frames that a track remains in initialization phase. 28 | kf : kalman_filter.KalmanFilter 29 | A Kalman filter to filter target trajectories in image space. 30 | tracks : List[Track] 31 | The list of active tracks at the current time step. 32 | """ 33 | 34 | def __init__(self, metric_type='cosine', max_cosine_distance=0.4, nn_budget=None, max_iou_distance=0.7, max_age=60, n_init=3): 35 | self.metric = NearestNeighborDistanceMetric(metric_type, max_cosine_distance, nn_budget) 36 | self.max_iou_distance = max_iou_distance 37 | self.max_age = max_age 38 | self.n_init = n_init 39 | 40 | self.kf = KalmanFilter() 41 | self.tracks = [] 42 | self._next_id = 1 43 | 44 | def reset(self): 45 | self.tracks = [] 46 | self._next_id = 1 47 | 48 | def predict(self): 49 | """Propagate track state distributions one time step forward. 50 | This function should be called once every time step, before `update`. 51 | """ 52 | for track in self.tracks: 53 | track.predict(self.kf) 54 | 55 | def increment_ages(self): 56 | for track in self.tracks: 57 | track.increment_age() 58 | track.mark_missed() 59 | 60 | def xyxy2xywh(self, boxes): 61 | boxes[:, 2] -= boxes[:, 0] 62 | boxes[:, 3] -= boxes[:, 1] 63 | return boxes 64 | 65 | def update(self, boxes, classes, features): 66 | detections = [ 67 | Detection(bbox, class_id, feature) 68 | for bbox, class_id, feature in zip(self.xyxy2xywh(boxes), classes, features)] 69 | 70 | # Run matching cascade. 71 | matches, unmatched_tracks, unmatched_detections = self._match(detections) 72 | 73 | # Update track set. 74 | for track_idx, detection_idx in matches: 75 | self.tracks[track_idx].update(self.kf, detections[detection_idx]) 76 | for track_idx in unmatched_tracks: 77 | self.tracks[track_idx].mark_missed() 78 | for detection_idx in unmatched_detections: 79 | self._initiate_track(detections[detection_idx]) 80 | 81 | self.tracks = [t for t in self.tracks if not t.is_deleted()] 82 | 83 | # Update distance metric. 
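# Gather the appearance features accumulated by each confirmed track, hand them to the
# nearest-neighbor metric so the per-identity gallery stays current, then clear each
# track's local feature cache (the metric keeps its own copies via partial_fit).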
84 | features, targets, active_targets = [], [], [] 85 | for track in self.tracks: 86 | if not track.is_confirmed(): 87 | continue 88 | active_targets.append(track.track_id) 89 | features += track.features 90 | targets += [track.track_id for _ in track.features] 91 | track.features = [] 92 | self.metric.partial_fit(np.asarray(features), np.asarray(targets), active_targets) 93 | 94 | def _match(self, detections): 95 | 96 | def gated_metric(tracks, dets, track_indices, detection_indices): 97 | features = np.array([dets[i].feature for i in detection_indices]) 98 | targets = np.array([tracks[i].track_id for i in track_indices]) 99 | cost_matrix = self.metric.distance(features, targets) 100 | cost_matrix = gate_cost_matrix(self.kf, cost_matrix, tracks, dets, track_indices, detection_indices) 101 | return cost_matrix 102 | 103 | # Split track set into confirmed and unconfirmed tracks. 104 | confirmed_tracks = [i for i, t in enumerate(self.tracks) if t.is_confirmed()] 105 | unconfirmed_tracks = [i for i, t in enumerate(self.tracks) if not t.is_confirmed()] 106 | 107 | # Associate confirmed tracks using appearance features. 108 | matches_a, unmatched_tracks_a, unmatched_detections = matching_cascade(gated_metric, self.metric.matching_threshold, self.max_age, self.tracks, detections, confirmed_tracks) 109 | 110 | # Associate remaining tracks together with unconfirmed tracks using IOU. 111 | iou_track_candidates = unconfirmed_tracks + [k for k in unmatched_tracks_a if self.tracks[k].time_since_update == 1] 112 | unmatched_tracks_a = [k for k in unmatched_tracks_a if self.tracks[k].time_since_update != 1] 113 | matches_b, unmatched_tracks_b, unmatched_detections = min_cost_matching(iou_cost, self.max_iou_distance, self.tracks, detections, iou_track_candidates, unmatched_detections) 114 | 115 | matches = matches_a + matches_b 116 | unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b)) 117 | return matches, unmatched_tracks, unmatched_detections 118 | 119 | def _initiate_track(self, detection): 120 | mean, covariance = self.kf.initiate(detection.to_xyah()) 121 | self.tracks.append(Track(mean, covariance, self._next_id, self.n_init, self.max_age, detection.feature, detection.class_id)) 122 | self._next_id += 1 -------------------------------------------------------------------------------- /tracking/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import random 4 | import torch 5 | import os 6 | import urllib.request 7 | import numpy as np 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from torchvision import ops, io 11 | from threading import Thread 12 | from torch.backends import cudnn 13 | cudnn.benchmark = True 14 | cudnn.deterministic = False 15 | 16 | 17 | def coco_class_index(class_name: str) -> int: 18 | coco_classes = [ 19 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 20 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 21 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 22 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 23 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 24 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 25 | 'potted plant', 'bed', 'dining 
table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 26 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 27 | 'hair drier', 'toothbrush' 28 | ] 29 | assert class_name.lower() in coco_classes, f"Invalid Class Name.\nAvailable COCO classes: {coco_classes}" 30 | return coco_classes.index(class_name.lower()) 31 | 32 | 33 | class Colors: 34 | # Ultralytics color palette https://ultralytics.com/ 35 | def __init__(self): 36 | # hex = matplotlib.colors.TABLEAU_COLORS.values() 37 | hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', 38 | '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') 39 | self.palette = [self.hex2rgb('#' + c) for c in hex] 40 | self.n = len(self.palette) 41 | 42 | def __call__(self, i, bgr=False): 43 | c = self.palette[int(i) % self.n] 44 | return (c[2], c[1], c[0]) if bgr else c 45 | 46 | @staticmethod 47 | def hex2rgb(h): # rgb order (PIL) 48 | return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) 49 | 50 | 51 | colors = Colors() 52 | 53 | 54 | class WebcamStream: 55 | def __init__(self, src=0) -> None: 56 | self.cap = cv2.VideoCapture(src) 57 | self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) 58 | assert self.cap.isOpened(), f"Failed to open webcam {src}" 59 | _, self.frame = self.cap.read() 60 | Thread(target=self.update, args=([]), daemon=True).start() 61 | 62 | def update(self): 63 | while self.cap.isOpened(): 64 | _, self.frame = self.cap.read() 65 | 66 | def __iter__(self): 67 | self.count = -1 68 | return self 69 | 70 | def __next__(self): 71 | self.count += 1 72 | 73 | if cv2.waitKey(1) == ord('q'): 74 | self.stop() 75 | 76 | return self.frame.copy() 77 | 78 | def stop(self): 79 | cv2.destroyAllWindows() 80 | raise StopIteration 81 | 82 | def __len__(self): 83 | return 0 84 | 85 | 86 | class SequenceStream: 87 | def __init__(self, folder): 88 | self.frames = self.read_frames(folder) 89 | 90 | print(f"Processing '{folder}'...") 91 | print(f"Total Frames: {len(self.frames)}") 92 | print(f"Video Size : {self.frames[0].shape[:-1]}") 93 | 94 | def read_frames(self, folder): 95 | files = sorted(list(Path(folder).glob('*.jpg'))) 96 | frames = [] 97 | for file in files: 98 | img = cv2.imread(str(file)) 99 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 100 | frames.append(img) 101 | return frames 102 | 103 | def __iter__(self): 104 | self.count = 0 105 | return self 106 | 107 | def __len__(self): 108 | return len(self.frames) 109 | 110 | def __next__(self): 111 | if self.count == len(self.frames): 112 | raise StopIteration 113 | frame = self.frames[self.count] 114 | self.count += 1 115 | return frame 116 | 117 | 118 | class VideoReader: 119 | def __init__(self, video: str): 120 | self.frames, _, info = io.read_video(video, pts_unit='sec') 121 | self.fps = info['video_fps'] 122 | 123 | print(f"Processing '{video}'...") 124 | print(f"Total Frames: {len(self.frames)}") 125 | print(f"Video Size : {list(self.frames.shape[1:-1])}") 126 | print(f"Video FPS : {self.fps}") 127 | 128 | def __iter__(self): 129 | self.count = 0 130 | return self 131 | 132 | def __len__(self): 133 | return len(self.frames) 134 | 135 | def __next__(self): 136 | if self.count == len(self.frames): 137 | raise StopIteration 138 | frame = self.frames[self.count] 139 | self.count += 1 140 | return frame 141 | 142 | 143 | class VideoWriter: 144 | def __init__(self, file_name, fps): 145 | self.fname = file_name 146 | self.fps = 
fps 147 | self.frames = [] 148 | 149 | def update(self, frame): 150 | if isinstance(frame, np.ndarray): 151 | frame = torch.from_numpy(frame) 152 | self.frames.append(frame) 153 | 154 | def write(self): 155 | print(f"Saving video to '{self.fname}'...") 156 | io.write_video(self.fname, torch.stack(self.frames), self.fps) 157 | 158 | 159 | class FPS: 160 | def __init__(self, avg=10) -> None: 161 | self.accum_time = 0 162 | self.counts = 0 163 | self.avg = avg 164 | 165 | def synchronize(self): 166 | if torch.cuda.is_available(): 167 | torch.cuda.synchronize() 168 | 169 | def start(self): 170 | self.synchronize() 171 | self.prev_time = time.time() 172 | 173 | def stop(self, debug=True): 174 | self.synchronize() 175 | self.accum_time += time.time() - self.prev_time 176 | self.counts += 1 177 | if self.counts == self.avg: 178 | self.fps = round(self.counts / self.accum_time) 179 | if debug: print(f"FPS: {self.fps}") 180 | self.counts = 0 181 | self.accum_time = 0 182 | 183 | 184 | def plot_one_box(box, img, color=None, label=None): 185 | color = color or [random.randint(0, 255) for _ in range(3)] 186 | p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) 187 | cv2.rectangle(img, p1, p2, color, 2, lineType=cv2.LINE_AA) 188 | 189 | if label: 190 | t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0] 191 | p2 = p1[0] + t_size[0], p1[1] - t_size[1] - 3 192 | cv2.rectangle(img, p1, p2, color, -1, cv2.LINE_AA) 193 | cv2.putText(img, label, (p1[0], p1[1]-2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA) 194 | 195 | 196 | def letterbox(img, new_shape=(640, 640)): 197 | H, W = img.shape[:2] 198 | if isinstance(new_shape, int): 199 | new_shape = (new_shape, new_shape) 200 | 201 | r = min(new_shape[0] / H, new_shape[1] / W) 202 | nH, nW = round(H * r), round(W * r) 203 | pH, pW = np.mod(new_shape[0] - nH, 32) / 2, np.mod(new_shape[1] - nW, 32) / 2 204 | 205 | if (H, W) != (nH, nW): 206 | img = cv2.resize(img, (nW, nH), interpolation=cv2.INTER_LINEAR) 207 | 208 | top, bottom = round(pH - 0.1), round(pH + 0.1) 209 | left, right = round(pW - 0.1), round(pW + 0.1) 210 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) 211 | return img 212 | 213 | 214 | def scale_boxes(boxes, orig_shape, new_shape): 215 | H, W = orig_shape 216 | nH, nW = new_shape 217 | gain = min(nH / H, nW / W) 218 | pad = (nH - H * gain) / 2, (nW - W * gain) / 2 219 | 220 | boxes[:, ::2] -= pad[1] 221 | boxes[:, 1::2] -= pad[0] 222 | boxes[:, :4] /= gain 223 | 224 | boxes[:, ::2].clamp_(0, orig_shape[1]) 225 | boxes[:, 1::2].clamp_(0, orig_shape[0]) 226 | return boxes.round() 227 | 228 | 229 | def xywh2xyxy(x): 230 | boxes = x.clone() 231 | boxes[:, 0] = x[:, 0] - x[:, 2] / 2 232 | boxes[:, 1] = x[:, 1] - x[:, 3] / 2 233 | boxes[:, 2] = x[:, 0] + x[:, 2] / 2 234 | boxes[:, 3] = x[:, 1] + x[:, 3] / 2 235 | return boxes 236 | 237 | 238 | def non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None): 239 | candidates = pred[..., 4] > conf_thres 240 | 241 | max_wh = 4096 242 | max_nms = 30000 243 | max_det = 300 244 | 245 | output = [torch.zeros((0, 6), device=pred.device)] * pred.shape[0] 246 | 247 | for xi, x in enumerate(pred): 248 | x = x[candidates[xi]] 249 | 250 | if not x.shape[0]: continue 251 | 252 | # compute conf 253 | x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf 254 | 255 | # box 256 | box = xywh2xyxy(x[:, :4]) 257 | 258 | # detection matrix nx6 259 | conf, j = x[:, 5:].max(1, keepdim=True) 260 | x = torch.cat([box, conf, 
j.float()], dim=1)[conf.view(-1) > conf_thres] 261 | 262 | # filter by class 263 | if classes is not None: 264 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] 265 | 266 | # check shape 267 | n = x.shape[0] 268 | if not n: 269 | continue 270 | elif n > max_nms: 271 | x = x[x[:, 4].argsort(descending=True)[:max_nms]] 272 | 273 | # batched nms 274 | c = x[:, 5:6] * max_wh 275 | boxes, scores = x[:, :4] + c, x[:, 4] 276 | keep = ops.nms(boxes, scores, iou_thres) 277 | 278 | if keep.shape[0] > max_det: 279 | keep = keep[:max_det] 280 | 281 | output[xi] = x[keep] 282 | 283 | return output 284 | 285 | 286 | def download(url: str, root: str): 287 | os.makedirs(root, exist_ok=True) 288 | filename = os.path.basename(url) 289 | download_target = os.path.join(root, filename) 290 | 291 | if os.path.exists(download_target) and os.path.isfile(download_target): 292 | return download_target 293 | 294 | print(f"Downloading model from {url}") 295 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 296 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 297 | while True: 298 | buffer = source.read(8192) 299 | if not buffer: 300 | break 301 | 302 | output.write(buffer) 303 | loop.update(len(buffer)) 304 | 305 | return download_target 306 | 307 | 308 | --------------------------------------------------------------------------------
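As a rough, hypothetical sketch (not code from this repository) of how the modules above are typically driven: a detector yields `xyxy` boxes and class ids per frame, an embedding model yields one appearance feature per box, and `DeepSORTTracker` is stepped with `predict()` followed by `update()`. The `detector`, `embedder`, and `video_frames` names below are placeholders, not APIs defined in this repo.

from tracking.sort.tracker import DeepSORTTracker

tracker = DeepSORTTracker(metric_type='cosine', max_cosine_distance=0.4, max_age=60, n_init=3)

for frame in video_frames:                      # placeholder: any iterable of frames
    boxes, classes = detector(frame)            # placeholder: Nx4 xyxy boxes, N class ids
    features = embedder(frame, boxes)           # placeholder: NxD appearance embeddings
    tracker.predict()                           # propagate every track one step forward
    tracker.update(boxes, classes, features)    # associate detections, create/confirm/delete tracks
    for track in tracker.tracks:
        if track.is_confirmed() and track.time_since_update == 0:
            x1, y1, x2, y2 = track.to_tlbr()    # report or draw with track.track_id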