├── .circleci └── config.yml ├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DETR.png └── ISSUE_TEMPLATE │ ├── bugs.md │ ├── questions-help-support.md │ └── unexpected-problems-bugs.md ├── .gitignore ├── LICENSE ├── README.md ├── cfgs ├── submit.yaml ├── track.yaml ├── track_reid.yaml ├── train.yaml ├── train_coco_person_masks.yaml ├── train_crowdhuman.yaml ├── train_deformable.yaml ├── train_full_res.yaml ├── train_mot17.yaml ├── train_mot17_crowdhuman.yaml ├── train_mot20_crowdhuman.yaml ├── train_mot_coco_person.yaml ├── train_mots20.yaml ├── train_multi_frame.yaml └── train_tracking.yaml ├── data ├── .gitignore └── snakeboard │ └── snakeboard.mp4 ├── docs ├── INSTALL.md ├── MOT17-03-SDP.gif ├── MOTS20-07.gif ├── TRAIN.md ├── method.png ├── snakeboard.gif └── visdom.gif ├── logs ├── .gitignore └── visdom │ └── .gitignore ├── models └── .gitignore ├── requirements.txt ├── setup.py └── src ├── combine_frames.py ├── compute_best_mean_epoch_from_splits.py ├── generate_coco_from_crowdhuman.py ├── generate_coco_from_mot.py ├── parse_mot_results_to_tex.py ├── run_with_submitit.py ├── track.py ├── track_param_search.py ├── trackformer ├── __init__.py ├── datasets │ ├── __init__.py │ ├── coco.py │ ├── coco_eval.py │ ├── coco_panoptic.py │ ├── crowdhuman.py │ ├── mot.py │ ├── panoptic_eval.py │ ├── tracking │ │ ├── __init__.py │ │ ├── demo_sequence.py │ │ ├── factory.py │ │ ├── mot17_sequence.py │ │ ├── mot20_sequence.py │ │ ├── mot_wrapper.py │ │ └── mots20_sequence.py │ └── transforms.py ├── engine.py ├── models │ ├── __init__.py │ ├── backbone.py │ ├── deformable_detr.py │ ├── deformable_transformer.py │ ├── detr.py │ ├── detr_segmentation.py │ ├── detr_tracking.py │ ├── matcher.py │ ├── ops │ │ ├── .gitignore │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ ├── make.sh │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn.py │ │ ├── setup.py │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ ├── test.py │ │ └── test_double_precision.py │ ├── position_encoding.py │ ├── tracker.py │ └── transformer.py ├── util │ ├── __init__.py │ ├── box_ops.py │ ├── misc.py │ ├── plot_utils.py │ └── track_utils.py └── vis.py └── train.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | jobs: 4 | python_lint: 5 | docker: 6 | - image: circleci/python:3.7 7 | steps: 8 | - checkout 9 | - run: 10 | command: | 11 | pip install --user --progress-bar off flake8 typing 12 | flake8 . 13 | 14 | test: 15 | docker: 16 | - image: circleci/python:3.7 17 | steps: 18 | - checkout 19 | - run: 20 | command: | 21 | pip install --user --progress-bar off scipy pytest 22 | pip install --user --progress-bar off --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html 23 | pytest . 24 | 25 | workflows: 26 | build: 27 | jobs: 28 | - python_lint 29 | - test 30 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 
4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DETR 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis. 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 4 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/) 36 | 37 | ## License 38 | By contributing to DETR, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 40 | -------------------------------------------------------------------------------- /.github/DETR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/.github/DETR.png -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bugs.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "🐛 Bugs" 3 | about: Report bugs in DETR 4 | title: Please read & provide the following 5 | 6 | --- 7 | 8 | ## Instructions To Reproduce the 🐛 Bug: 9 | 10 | 1. what changes you made (`git diff`) or what code you wrote 11 | ``` 12 | 13 | ``` 14 | 2. what exact command you run: 15 | 3. what you observed (including __full logs__): 16 | ``` 17 | 18 | ``` 19 | 4. please simplify the steps as much as possible so they do not require additional resources to 20 | run, such as a private dataset. 21 | 22 | ## Expected behavior: 23 | 24 | If there are no obvious error in "what you observed" provided above, 25 | please tell us the expected behavior. 
26 | 27 | ## Environment: 28 | 29 | Provide your environment information using the following command: 30 | ``` 31 | python -m torch.utils.collect_env 32 | ``` 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-help-support.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "How to do something❓" 3 | about: How to do something using DETR? 4 | 5 | --- 6 | 7 | ## ❓ How to do something using DETR 8 | 9 | Describe what you want to do, including: 10 | 1. what inputs you will provide, if any: 11 | 2. what outputs you are expecting: 12 | 13 | 14 | NOTE: 15 | 16 | 1. Only general answers are provided. 17 | If you want to ask about "why X did not work", please use the 18 | [Unexpected behaviors](https://github.com/facebookresearch/detr/issues/new/choose) issue template. 19 | 20 | 2. About how to implement new models / new dataloader / new training logic, etc., check documentation first. 21 | 22 | 3. We do not answer general machine learning / computer vision questions that are not specific to DETR, such as how a model works, how to improve your training/make it converge, or what algorithm/methods can be used to achieve X. 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Unexpected behaviors" 3 | about: Run into unexpected behaviors when using DETR 4 | title: Please read & provide the following 5 | 6 | --- 7 | 8 | If you do not know the root cause of the problem, and wish someone to help you, please 9 | post according to this template: 10 | 11 | ## Instructions To Reproduce the Issue: 12 | 13 | 1. what changes you made (`git diff`) or what code you wrote 14 | ``` 15 | 16 | ``` 17 | 2. what exact command you run: 18 | 3. what you observed (including __full logs__): 19 | ``` 20 | 21 | ``` 22 | 4. please simplify the steps as much as possible so they do not require additional resources to 23 | run, such as a private dataset. 24 | 25 | ## Expected behavior: 26 | 27 | If there are no obvious error in "what you observed" provided above, 28 | please tell us the expected behavior. 29 | 30 | If you expect the model to converge / work better, note that we do not give suggestions 31 | on how to train a new model. 32 | Only in one of the two conditions we will help with it: 33 | (1) You're unable to reproduce the results in DETR model zoo. 34 | (2) It indicates a DETR bug. 
35 | 36 | ## Environment: 37 | 38 | Provide your environment information using the following command: 39 | ``` 40 | python -m torch.utils.collect_env 41 | ``` 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .nfs* 2 | *.ipynb 3 | *.pyc 4 | .dumbo.json 5 | .DS_Store 6 | .*.swp 7 | *.pth 8 | **/__pycache__/** 9 | .ipynb_checkpoints/ 10 | datasets/data/ 11 | experiment-* 12 | *.tmp 13 | *.pkl 14 | **/.mypy_cache/* 15 | .mypy_cache/* 16 | not_tracked_dir/ 17 | .vscode 18 | .python-version 19 | *.sbatch 20 | *.egg-info 21 | src/trackformer/models/ops/build* 22 | src/trackformer/models/ops/dist* 23 | src/trackformer/models/ops/lib* 24 | src/trackformer/models/ops/temp* 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrackFormer: Multi-Object Tracking with Transformers 2 | 3 | This repository provides the official implementation of the [TrackFormer: Multi-Object Tracking with Transformers](https://arxiv.org/abs/2101.02702) paper by [Tim Meinhardt](https://dvl.in.tum.de/team/meinhardt/), [Alexander Kirillov](https://alexander-kirillov.github.io/), [Laura Leal-Taixe](https://dvl.in.tum.de/team/lealtaixe/) and [Christoph Feichtenhofer](https://feichtenhofer.github.io/). The codebase builds upon [DETR](https://github.com/facebookresearch/detr), [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR) and [Tracktor](https://github.com/phil-bergmann/tracking_wo_bnw). 4 | 5 | 6 | 7 |
8 | MOT17-03-SDP 9 | MOTS20-07 10 |
11 | 12 | ## Abstract 13 | 14 | The challenging task of multi-object tracking (MOT) requires simultaneous reasoning about track initialization, identity, and spatiotemporal trajectories. 15 | We formulate this task as a frame-to-frame set prediction problem and introduce TrackFormer, an end-to-end MOT approach based on an encoder-decoder Transformer architecture. 16 | Our model achieves data association between frames via attention by evolving a set of track predictions through a video sequence. 17 | The Transformer decoder initializes new tracks from static object queries and autoregressively follows existing tracks in space and time with the new concept of identity preserving track queries. 18 | Both decoder query types benefit from self- and encoder-decoder attention on global frame-level features, thereby omitting any additional graph optimization and matching or modeling of motion and appearance. 19 | TrackFormer represents a new tracking-by-attention paradigm and yields state-of-the-art performance on the task of multi-object tracking (MOT17) and segmentation (MOTS20). 20 | 21 |
22 | TrackFormer casts multi-object tracking as a set prediction problem performing joint detection and tracking-by-attention. The architecture consists of a CNN for image feature extraction, a Transformer encoder for image feature encoding and a Transformer decoder which applies self- and encoder-decoder attention to produce output embeddings with bounding box and class information. 23 |
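To make the figure above more concrete, the following is a schematic PyTorch sketch of the tracking-by-attention forward pass. It is an illustration only and not the implementation in `src/trackformer/models`: all module and argument names are invented for this sketch, and deformable attention, multi-scale features, the Hungarian matcher and the losses are omitted. The defaults of `hidden_dim=256` and 100 object queries correspond to the values in `cfgs/train.yaml`.

```
import torch
import torch.nn as nn
import torchvision


class TrackingByAttentionSketch(nn.Module):
    """Illustrative sketch only; see src/trackformer/models for the real model."""

    def __init__(self, hidden_dim=256, num_object_queries=100, num_classes=1):
        super().__init__()
        # CNN for image feature extraction (ResNet-50 without pooling/classification head)
        resnet = torchvision.models.resnet50()
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.input_proj = nn.Conv2d(2048, hidden_dim, kernel_size=1)
        # Transformer encoder for feature encoding and decoder for the queries
        self.transformer = nn.Transformer(
            d_model=hidden_dim, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
        # Static object queries initialize new tracks
        self.object_queries = nn.Embedding(num_object_queries, hidden_dim)
        # Output embeddings are decoded into class and bounding box information
        self.class_head = nn.Linear(hidden_dim, num_classes + 1)  # incl. no-object class
        self.box_head = nn.Linear(hidden_dim, 4)                  # normalized cxcywh boxes

    def forward(self, frame, track_queries=None):
        # frame: (batch, 3, H, W); track_queries: (num_tracks, batch, hidden_dim) or None
        features = self.input_proj(self.backbone(frame))   # (batch, hidden_dim, h, w)
        memory_in = features.flatten(2).permute(2, 0, 1)    # (h*w, batch, hidden_dim)
        queries = self.object_queries.weight.unsqueeze(1).repeat(1, frame.size(0), 1)
        if track_queries is not None:
            # Identity preserving track queries from the previous frame follow existing tracks
            queries = torch.cat([track_queries, queries], dim=0)
        hs = self.transformer(memory_in, queries)            # (num_queries, batch, hidden_dim)
        return self.class_head(hs), self.box_head(hs).sigmoid(), hs
```

At inference time, output embeddings with a sufficiently high classification score would be carried over as track queries for the next frame (cf. the score thresholds in `cfgs/track.yaml`); `src/trackformer/models/tracker.py` implements this logic in full.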
24 | 25 | ## Installation 26 | 27 | We refer to our [docs/INSTALL.md](docs/INSTALL.md) for detailed installation instructions. 28 | 29 | ## Train TrackFormer 30 | 31 | We refer to our [docs/TRAIN.md](docs/TRAIN.md) for detailed training instructions. 32 | 33 | ## Evaluate TrackFormer 34 | 35 | In order to evaluate TrackFormer on a multi-object tracking dataset, we provide the `src/track.py` script which supports several datasets and splits interchangeably via the `dataset_name` argument (see `src/datasets/tracking/factory.py` for an overview of all datasets). The default tracking configuration is specified in `cfgs/track.yaml`. To facilitate the reproducibility of our results, we provide evaluation metrics for both the train and test sets. 36 | 37 | ### MOT17 38 | 39 | #### Private detections 40 | 41 | ``` 42 | python src/track.py with reid 43 | ``` 44 | 45 |
46 | 47 | | MOT17 | MOTA | IDF1 | MT | ML | FP | FN | ID SW. | 48 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 49 | | **Train** | 74.2 | 71.7 | 849 | 177 | 7431 | 78057 | 1449 | 50 | | **Test** | 74.1 | 68.0 | 1113 | 246 | 34602 | 108777 | 2829 | 51 | 52 |
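For reference, the MOTA values reported in the MOT17 and MOT20 tables combine the listed error counts according to the standard CLEAR-MOT definition. A minimal sketch (the total number of ground-truth boxes is not part of the tables and has to be taken from the dataset annotations):

```
# CLEAR-MOT accuracy computed from the error counts reported in the tables.
def mota(false_positives, false_negatives, id_switches, num_gt_boxes):
    # num_gt_boxes: total number of annotated ground-truth boxes of the split
    return 1.0 - (false_positives + false_negatives + id_switches) / num_gt_boxes
```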
53 | 54 | #### Public detections (DPM, FRCNN, SDP) 55 | 56 | ``` 57 | python src/track.py with \ 58 | reid \ 59 | tracker_cfg.public_detections=min_iou_0_5 \ 60 | obj_detect_checkpoint_file=models/mot17_deformable_multi_frame/checkpoint_epoch_50.pth 61 | ``` 62 | 63 |
64 | 65 | | MOT17 | MOTA | IDF1 | MT | ML | FP | FN | ID SW. | 66 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 67 | | **Train** | 64.6 | 63.7 | 621 | 675 | 4827 | 111958 | 2556 | 68 | | **Test** | 62.3 | 57.6 | 688 | 638 | 16591 | 192123 | 4018 | 69 | 70 |
71 | 72 | ### MOT20 73 | 74 | #### Private detections 75 | 76 | ``` 77 | python src/track.py with \ 78 | reid \ 79 | dataset_name=MOT20-ALL \ 80 | obj_detect_checkpoint_file=models/mot20_crowdhuman_deformable_multi_frame/checkpoint_epoch_50.pth 81 | ``` 82 | 83 |
84 | 85 | | MOT20 | MOTA | IDF1 | MT | ML | FP | FN | ID SW. | 86 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 87 | | **Train** | 81.0 | 73.3 | 1540 | 124 | 20807 | 192665 | 1961 | 88 | | **Test** | 68.6 | 65.7 | 666 | 181 | 20348 | 140373 | 1532 | 89 | 90 |
91 | 92 | ### MOTS20 93 | 94 | ``` 95 | python src/track.py with \ 96 | dataset_name=MOTS20-ALL \ 97 | obj_detect_checkpoint_file=models/mots20_train_masks/checkpoint.pth 98 | ``` 99 | 100 | Our tracking script only runs the MOT17 metrics evaluation but also outputs MOTS20 mask prediction files. To evaluate these, download the official [MOTChallengeEvalKit](https://github.com/dendorferpatrick/MOTChallengeEvalKit). 101 | 102 |
103 | 104 | | MOTS20 | sMOTSA | IDF1 | FP | FN | IDs | 105 | | :---: | :---: | :---: | :---: | :---: | :---: | 106 | | **Train** | -- | -- | -- | -- | -- | 107 | | **Test** | 54.9 | 63.6 | 2233 | 7195 | 278 | 108 | 109 |
110 | 111 | ### Demo 112 | 113 | To facilitate the application of TrackFormer, we provide a demo interface which allows for quick processing of a given video sequence. 114 | 115 | ``` 116 | ffmpeg -i data/snakeboard/snakeboard.mp4 -vf fps=30 data/snakeboard/%06d.png 117 | 118 | python src/track.py with \ 119 | dataset_name=DEMO \ 120 | data_root_dir=data/snakeboard \ 121 | output_dir=data/snakeboard \ 122 | write_images=pretty 123 | ``` 124 | 125 |
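If `ffmpeg` is not available, the frame extraction step can also be done with OpenCV, which is pinned in `requirements.txt`. The snippet below is a hypothetical equivalent that simply writes out every frame, i.e., it assumes the source video already plays at the desired frame rate, whereas the `ffmpeg` call above resamples to 30 fps:

```
# Hypothetical ffmpeg-free frame extraction for the demo sequence.
import os
import cv2

video_path = 'data/snakeboard/snakeboard.mp4'
output_dir = 'data/snakeboard'
os.makedirs(output_dir, exist_ok=True)

capture = cv2.VideoCapture(video_path)
frame_id = 0
while True:
    ok, frame = capture.read()
    if not ok:
        break
    frame_id += 1
    # Same zero-padded naming scheme as the ffmpeg command above.
    cv2.imwrite(os.path.join(output_dir, f'{frame_id:06d}.png'), frame)
capture.release()
```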
126 | Snakeboard demo 127 |
128 | 129 | ## Publication 130 | If you use this software in your research, please cite our publication: 131 | 132 | ``` 133 | @InProceedings{meinhardt2021trackformer, 134 | title={TrackFormer: Multi-Object Tracking with Transformers}, 135 | author={Tim Meinhardt and Alexander Kirillov and Laura Leal-Taixe and Christoph Feichtenhofer}, 136 | year={2022}, 137 | month = {June}, 138 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 139 | } 140 | ``` -------------------------------------------------------------------------------- /cfgs/submit.yaml: -------------------------------------------------------------------------------- 1 | # Number of gpus to request on each node 2 | num_gpus: 1 3 | vram: 12GB 4 | # memory allocated per GPU in GB 5 | mem_per_gpu: 20 6 | # Number of nodes to request 7 | nodes: 1 8 | # Duration of the job 9 | timeout: 4320 10 | # Job dir. Leave empty for automatic. 11 | job_dir: '' 12 | # Use to run jobs locally. ('debug', 'local', 'slurm') 13 | cluster: debug 14 | # Partition. Leave empty for automatic. 15 | slurm_partition: '' 16 | # Constraint. Leave empty for automatic. 17 | slurm_constraint: '' 18 | slurm_comment: '' 19 | slurm_gres: '' 20 | slurm_exclude: '' 21 | cpus_per_task: 2 -------------------------------------------------------------------------------- /cfgs/track.yaml: -------------------------------------------------------------------------------- 1 | output_dir: null 2 | verbose: false 3 | seed: 666 4 | 5 | obj_detect_checkpoint_file: models/mot17_crowdhuman_deformable_multi_frame/checkpoint_epoch_40.pth 6 | 7 | interpolate: False 8 | # if available load tracking results and only evaluate 9 | load_results_dir: null 10 | 11 | # dataset (look into src/datasets/tracking/factory.py) 12 | dataset_name: MOT17-ALL-ALL 13 | data_root_dir: data 14 | 15 | # [False, 'debug', 'pretty'] 16 | # compile video with: `ffmpeg -f image2 -framerate 15 -i %06d.jpg -vcodec libx264 -y movie.mp4 -vf scale=320:-1` 17 | write_images: False 18 | # Maps are only visualized if write_images is True 19 | generate_attention_maps: False 20 | 21 | # track, evaluate and write images only for a range of frames (in float fraction) 22 | frame_range: 23 | start: 0.0 24 | end: 1.0 25 | 26 | tracker_cfg: 27 | # [False, 'center_distance', 'min_iou_0_5'] 28 | public_detections: False 29 | # score threshold for detections 30 | detection_obj_score_thresh: 0.4 31 | # score threshold for keeping the track alive 32 | track_obj_score_thresh: 0.4 33 | # NMS threshold for detection 34 | detection_nms_thresh: 0.9 35 | # NMS threshold while tracking 36 | track_nms_thresh: 0.9 37 | # number of consecutive steps a score has to be below track_obj_score_thresh for a track to be terminated 38 | steps_termination: 1 39 | # distance of previous frame for multi-frame attention 40 | prev_frame_dist: 1 41 | # How many timesteps inactive tracks are kept and considered for reid 42 | inactive_patience: -1 43 | # How similar image and old track need to be to be considered the same person 44 | reid_sim_threshold: 0.0 45 | reid_sim_only: false 46 | reid_score_thresh: 0.4 47 | reid_greedy_matching: false 48 | -------------------------------------------------------------------------------- /cfgs/track_reid.yaml: -------------------------------------------------------------------------------- 1 | tracker_cfg: 2 | inactive_patience: 5 3 | -------------------------------------------------------------------------------- /cfgs/train.yaml:
-------------------------------------------------------------------------------- 1 | lr: 0.0002 2 | lr_backbone_names: ['backbone.0'] 3 | lr_backbone: 0.00002 4 | lr_linear_proj_names: ['reference_points', 'sampling_offsets'] 5 | lr_linear_proj_mult: 0.1 6 | lr_track: 0.0001 7 | overwrite_lrs: false 8 | overwrite_lr_scheduler: false 9 | batch_size: 2 10 | weight_decay: 0.0001 11 | epochs: 50 12 | lr_drop: 40 13 | # gradient clipping max norm 14 | clip_max_norm: 0.1 15 | # Deformable DETR 16 | deformable: false 17 | with_box_refine: false 18 | two_stage: false 19 | # Model parameters 20 | freeze_detr: false 21 | load_mask_head_from_model: null 22 | # Backbone 23 | # Name of the convolutional backbone to use. ('resnet50', 'resnet101') 24 | backbone: resnet50 25 | # If true, we replace stride with dilation in the last convolutional block (DC5) 26 | dilation: false 27 | # Type of positional embedding to use on top of the image features. ('sine', 'learned') 28 | position_embedding: sine 29 | # Number of feature levels the encoder processes from the backbone 30 | num_feature_levels: 1 31 | # Transformer 32 | # Number of encoding layers in the transformer 33 | enc_layers: 6 34 | # Number of decoding layers in the transformer 35 | dec_layers: 6 36 | # Intermediate size of the feedforward layers in the transformer blocks 37 | dim_feedforward: 2048 38 | # Size of the embeddings (dimension of the transformer) 39 | hidden_dim: 256 40 | # Dropout applied in the transformer 41 | dropout: 0.1 42 | # Number of attention heads inside the transformer's attentions 43 | nheads: 8 44 | # Number of object queries 45 | num_queries: 100 46 | pre_norm: false 47 | dec_n_points: 4 48 | enc_n_points: 4 49 | # Tracking 50 | tracking: false 51 | # In addition to detection also run tracking evaluation with default configuration from `cfgs/track.yaml` 52 | tracking_eval: true 53 | # Range of possible random previous frames 54 | track_prev_frame_range: 0 55 | track_prev_frame_rnd_augs: 0.01 56 | track_prev_prev_frame: False 57 | track_backprop_prev_frame: False 58 | track_query_false_positive_prob: 0.1 59 | track_query_false_negative_prob: 0.4 60 | # only for vanilla DETR 61 | track_query_false_positive_eos_weight: true 62 | track_attention: false 63 | multi_frame_attention: false 64 | multi_frame_encoding: true 65 | multi_frame_attention_separate_encoder: true 66 | merge_frame_features: false 67 | overflow_boxes: false 68 | # Segmentation 69 | masks: false 70 | # Matcher 71 | # Class coefficient in the matching cost 72 | set_cost_class: 1.0 73 | # L1 box coefficient in the matching cost 74 | set_cost_bbox: 5.0 75 | # giou box coefficient in the matching cost 76 | set_cost_giou: 2.0 77 | # Loss 78 | # Disables auxiliary decoding losses (loss at each layer) 79 | aux_loss: true 80 | mask_loss_coef: 1.0 81 | dice_loss_coef: 1.0 82 | cls_loss_coef: 1.0 83 | bbox_loss_coef: 5.0 84 | giou_loss_coef: 2 85 | # Relative classification weight of the no-object class 86 | eos_coef: 0.1 87 | focal_loss: false 88 | focal_alpha: 0.25 89 | focal_gamma: 2 90 | # Dataset 91 | dataset: coco 92 | train_split: train 93 | val_split: val 94 | coco_path: data/coco_2017 95 | coco_panoptic_path: null 96 | mot_path_train: data/MOT17 97 | mot_path_val: data/MOT17 98 | crowdhuman_path: data/CrowdHuman 99 | # allows for joint training of mot and crowdhuman/coco_person with the `mot_crowdhuman`/`mot_coco_person` dataset 100 | crowdhuman_train_split: null 101 | coco_person_train_split: null 102 | coco_and_crowdhuman_prev_frame_rnd_augs: 0.2 103 | 
coco_min_num_objects: 0 104 | img_transform: 105 | max_size: 1333 106 | val_width: 800 107 | # Miscellaneous 108 | # path where to save, empty for no saving 109 | output_dir: '' 110 | # device to use for training / testing 111 | device: cuda 112 | seed: 42 113 | # resume from checkpoint 114 | resume: '' 115 | resume_shift_neuron: False 116 | # resume optimization from checkpoint 117 | resume_optim: false 118 | # resume Visdom visualization 119 | resume_vis: false 120 | start_epoch: 1 121 | eval_only: false 122 | eval_train: false 123 | num_workers: 2 124 | val_interval: 5 125 | debug: false 126 | # epoch interval for model saving. if 0 only save last and best models 127 | save_model_interval: 5 128 | # distributed training parameters 129 | # number of distributed processes 130 | world_size: 1 131 | # url used to set up distributed training 132 | dist_url: env:// 133 | # Visdom params 134 | # vis_server: http://localhost 135 | vis_server: '' 136 | vis_port: 8090 137 | vis_and_log_interval: 50 138 | no_vis: false 139 | -------------------------------------------------------------------------------- /cfgs/train_coco_person_masks.yaml: -------------------------------------------------------------------------------- 1 | dataset: coco_person 2 | 3 | load_mask_head_from_model: models/detr-r50-panoptic-00ce5173.pth 4 | freeze_detr: true 5 | masks: true 6 | 7 | lr: 0.0001 8 | lr_drop: 50 9 | epochs: 50 -------------------------------------------------------------------------------- /cfgs/train_crowdhuman.yaml: -------------------------------------------------------------------------------- 1 | dataset: mot_crowdhuman 2 | crowdhuman_train_split: train_val 3 | train_split: null 4 | val_split: mot17_train_cross_val_frame_0_5_to_1_0_coco 5 | epochs: 80 6 | lr_drop: 50 -------------------------------------------------------------------------------- /cfgs/train_deformable.yaml: -------------------------------------------------------------------------------- 1 | deformable: true 2 | num_feature_levels: 4 3 | num_queries: 300 4 | dim_feedforward: 1024 5 | focal_loss: true 6 | focal_alpha: 0.25 7 | focal_gamma: 2 8 | cls_loss_coef: 2.0 9 | set_cost_class: 2.0 10 | overflow_boxes: true 11 | with_box_refine: true -------------------------------------------------------------------------------- /cfgs/train_full_res.yaml: -------------------------------------------------------------------------------- 1 | img_transform: 2 | max_size: 1920 3 | val_width: 1080 -------------------------------------------------------------------------------- /cfgs/train_mot17.yaml: -------------------------------------------------------------------------------- 1 | dataset: mot 2 | 3 | train_split: mot17_train_coco 4 | val_split: mot17_train_cross_val_frame_0_5_to_1_0_coco 5 | 6 | mot_path_train: data/MOT17 7 | mot_path_val: data/MOT17 8 | 9 | resume: models/r50_deformable_detr_plus_iterative_bbox_refinement-checkpoint_hidden_dim_288.pth 10 | 11 | epochs: 50 12 | lr_drop: 10 -------------------------------------------------------------------------------- /cfgs/train_mot17_crowdhuman.yaml: -------------------------------------------------------------------------------- 1 | dataset: mot_crowdhuman 2 | 3 | crowdhuman_train_split: train_val 4 | train_split: mot17_train_coco 5 | val_split: mot17_train_cross_val_frame_0_5_to_1_0_coco 6 | 7 | mot_path_train: data/MOT17 8 | mot_path_val: data/MOT17 9 | 10 | resume: models/crowdhuman_deformable_trackformer/checkpoint_epoch_80.pth 11 | 12 | epochs: 40 13 | lr_drop: 10 
-------------------------------------------------------------------------------- /cfgs/train_mot20_crowdhuman.yaml: -------------------------------------------------------------------------------- 1 | dataset: mot_crowdhuman 2 | 3 | crowdhuman_train_split: train_val 4 | train_split: mot20_train_coco 5 | val_split: mot20_train_cross_val_frame_0_5_to_1_0_coco 6 | 7 | mot_path_train: data/MOT20 8 | mot_path_val: data/MOT20 9 | 10 | resume: models/crowdhuman_deformable_trackformer/checkpoint_epoch_80.pth 11 | 12 | epochs: 50 13 | lr_drop: 10 -------------------------------------------------------------------------------- /cfgs/train_mot_coco_person.yaml: -------------------------------------------------------------------------------- 1 | dataset: mot_coco_person 2 | coco_person_train_split: train 3 | train_split: null 4 | val_split: mot17_train_cross_val_frame_0_5_to_1_0_coco -------------------------------------------------------------------------------- /cfgs/train_mots20.yaml: -------------------------------------------------------------------------------- 1 | dataset: mot 2 | mot_path: data/MOTS20 3 | train_split: mots20_train_coco 4 | val_split: mots20_train_coco 5 | 6 | resume: models/mot17_train_pretrain_CH_deformable_with_coco_person_masks/checkpoint.pth 7 | masks: true 8 | lr: 0.00001 9 | lr_backbone: 0.000001 10 | 11 | epochs: 40 12 | lr_drop: 40 -------------------------------------------------------------------------------- /cfgs/train_multi_frame.yaml: -------------------------------------------------------------------------------- 1 | num_queries: 500 2 | hidden_dim: 288 3 | multi_frame_attention: true 4 | multi_frame_encoding: true 5 | multi_frame_attention_separate_encoder: true -------------------------------------------------------------------------------- /cfgs/train_tracking.yaml: -------------------------------------------------------------------------------- 1 | tracking: true 2 | tracking_eval: true 3 | track_prev_frame_range: 5 4 | track_query_false_positive_eos_weight: true -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !snakeboard 4 | -------------------------------------------------------------------------------- /data/snakeboard/snakeboard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/data/snakeboard/snakeboard.mp4 -------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | 1. Clone and enter this repository: 4 | ``` 5 | git clone git@github.com:timmeinhardt/trackformer.git 6 | cd trackformer 7 | ``` 8 | 9 | 2. Install packages for Python 3.7: 10 | 11 | 1. `pip3 install -r requirements.txt` 12 | 2. Install PyTorch 1.5 and torchvision 0.6 from [here](https://pytorch.org/get-started/previous-versions/#v150). 13 | 3. Install pycocotools (with fixed ignore flag): `pip3 install -U 'git+https://github.com/timmeinhardt/cocoapi.git#subdirectory=PythonAPI'` 14 | 4. Install MultiScaleDeformableAttention package: `python src/trackformer/models/ops/setup.py build --build-base=src/trackformer/models/ops/ install` 15 | 16 | 3. Download and unpack datasets in the `data` directory: 17 | 18 | 1.
[MOT17](https://motchallenge.net/data/MOT17/): 19 | 20 | ``` 21 | wget https://motchallenge.net/data/MOT17.zip 22 | unzip MOT17.zip 23 | python src/generate_coco_from_mot.py 24 | ``` 25 | 26 | 2. (Optional) [MOT20](https://motchallenge.net/data/MOT20/): 27 | 28 | ``` 29 | wget https://motchallenge.net/data/MOT20.zip 30 | unzip MOT20.zip 31 | python src/generate_coco_from_mot.py --mot20 32 | ``` 33 | 34 | 3. (Optional) [MOTS20](https://motchallenge.net/data/MOTS/): 35 | 36 | ``` 37 | wget https://motchallenge.net/data/MOTS.zip 38 | unzip MOTS.zip 39 | python src/generate_coco_from_mot.py --mots 40 | ``` 41 | 42 | 4. (Optional) [CrowdHuman](https://www.crowdhuman.org/download.html): 43 | 44 | 1. Create a `CrowdHuman` and `CrowdHuman/annotations` directory. 45 | 2. Download and extract the `train` and `val` datasets including their corresponding `*.odgt` annotation file into the `CrowdHuman` directory. 46 | 3. Create a `CrowdHuman/train_val` directory and merge or symlink the `train` and `val` image folders. 47 | 4. Run `python src/generate_coco_from_crowdhuman.py` 48 | 5. The final folder structure should resemble this: 49 | ~~~ 50 | |-- data 51 | |-- CrowdHuman 52 | | |-- train 53 | | | |-- *.jpg 54 | | |-- val 55 | | | |-- *.jpg 56 | | |-- train_val 57 | | | |-- *.jpg 58 | | |-- annotations 59 | | | |-- annotation_train.odgt 60 | | | |-- annotation_val.odgt 61 | | | |-- train_val.json 62 | ~~~ 63 | 64 | 4. Download and unpack pretrained TrackFormer model files in the `models` directory: 65 | 66 | ``` 67 | wget https://vision.in.tum.de/webshare/u/meinhard/trackformer_models_v1.zip 68 | unzip trackformer_models_v1.zip 69 | ``` 70 | 71 | 5. (optional) The evaluation of MOTS20 metrics requires two steps: 72 | 1. Run TrackFormer with `src/track.py` and output prediction files 73 | 2. Download the official MOTChallenge [devkit](https://github.com/dendorferpatrick/MOTChallengeEvalKit) and run the MOTS evaluation on the prediction files 74 | 75 | In order to configure, log and reproduce our computational experiments, we structure our code with the [Sacred](http://sacred.readthedocs.io/en/latest/index.html) framework. For a detailed explanation of the Sacred interface, please read its documentation. 76 | -------------------------------------------------------------------------------- /docs/MOT17-03-SDP.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/docs/MOT17-03-SDP.gif -------------------------------------------------------------------------------- /docs/MOTS20-07.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/docs/MOTS20-07.gif -------------------------------------------------------------------------------- /docs/TRAIN.md: -------------------------------------------------------------------------------- 1 | # Train TrackFormer 2 | 3 | We provide the code as well as intermediate models of our entire training pipeline for multiple datasets. Monitoring of the training/evaluation progress is possible via command line as well as [Visdom](https://github.com/fossasia/visdom.git). For the latter, a Visdom server must be running at `vis_port` and `vis_server` (see `cfgs/train.yaml`). We set `vis_server=''` by default to deactivate Visdom logging.
Even with these Visdom parameters set, you can deactivate logging for a particular training run with the `no_vis=True` flag. 4 | 5 |
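To verify that a Visdom server is actually reachable before starting a training, the following hypothetical snippet uses the `visdom` package pinned in `requirements.txt` and assumes the address and port from `cfgs/train.yaml` (`http://localhost`, port `8090`):

```
# Hypothetical connectivity check for the Visdom logging server.
# A local server can be started with: python -m visdom.server -port 8090
import visdom

vis = visdom.Visdom(server='http://localhost', port=8090, raise_exceptions=False)
if vis.check_connection():
    vis.text('Visdom server reachable, TrackFormer trainings can log here.', win='check')
else:
    print('No Visdom server reachable at http://localhost:8090.')
```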
6 | Snakeboard demo 7 |
8 | 9 | The settings for each dataset are specified in the respective configuration files, e.g., `cfgs/train_crowdhuman.yaml`. The following train commands produced the pretrained model files mentioned in [docs/INSTALL.md](INSTALL.md). 10 | 11 | ## CrowdHuman pre-training 12 | 13 | ``` 14 | python src/train.py with \ 15 | crowdhuman \ 16 | deformable \ 17 | multi_frame \ 18 | tracking \ 19 | output_dir=models/crowdhuman_deformable_multi_frame \ 20 | ``` 21 | 22 | ## MOT17 23 | 24 | #### Private detections 25 | 26 | ``` 27 | python src/train.py with \ 28 | mot17_crowdhuman \ 29 | deformable \ 30 | multi_frame \ 31 | tracking \ 32 | output_dir=models/mot17_crowdhuman_deformable_multi_frame \ 33 | ``` 34 | 35 | #### Public detections 36 | 37 | ``` 38 | python src/train.py with \ 39 | mot17 \ 40 | deformable \ 41 | multi_frame \ 42 | tracking \ 43 | output_dir=models/mot17_deformable_multi_frame \ 44 | ``` 45 | 46 | ## MOT20 47 | 48 | #### Private detections 49 | 50 | ``` 51 | python src/train.py with \ 52 | mot20_crowdhuman \ 53 | deformable \ 54 | multi_frame \ 55 | tracking \ 56 | output_dir=models/mot20_crowdhuman_deformable_multi_frame \ 57 | ``` 58 | 59 | ## MOTS20 60 | 61 | For our MOTS20 test set submission, we finetune a MOT17 private detection model without deformable attention, i.e., vanilla DETR, which was pre-trained on the CrowdHuman dataset. The finetuning itself consists of two training steps: (i) the original DETR panoptic segmentation head on the COCO person segmentation data and (ii) the entire TrackFormer model (including segmentation head) on the MOTS20 training set. At this point, we only provide the final model files in [docs/INSTALL.md](INSTALL.md). 62 | 63 | 76 | 77 | 80 | 81 | ## Custom Dataset 82 | 83 | TrackFormer can be trained on additional/new object detection or multi-object tracking datasets without changing our codebase. The `crowdhuman` or `mot` datasets merely require a [COCO style](https://www.immersivelimit.com/tutorials/create-coco-annotations-from-scratch) annotation file and the following folder structure: 84 | 85 | ~~~ 86 | |-- data 87 | |-- custom_dataset 88 | | |-- train 89 | | | |-- *.jpg 90 | | |-- val 91 | | | |-- *.jpg 92 | | |-- annotations 93 | | | |-- train.json 94 | | | |-- val.json 95 | ~~~ 96 | 97 | In the case of a multi-object tracking dataset, the original COCO annotation style must be extended with `seq_length`, `first_frame_image_id` and `track_id` fields. See the `src/generate_coco_from_mot.py` script for details. For example, the following command finetunes our `MOT17` private model for an additional 20 epochs on a custom dataset: 98 | 99 | ``` 100 | python src/train.py with \ 101 | mot17 \ 102 | deformable \ 103 | multi_frame \ 104 | tracking \ 105 | resume=models/mot17_crowdhuman_deformable_trackformer/checkpoint_epoch_40.pth \ 106 | output_dir=models/custom_dataset_deformable \ 107 | mot_path_train=data/custom_dataset \ 108 | mot_path_val=data/custom_dataset \ 109 | train_split=train \ 110 | val_split=val \ 111 | epochs=20 \ 112 | ``` 113 | 114 | ## Run with multiple GPUs 115 | 116 | All reported results are obtained by training with a batch size of 2 and 7 GPUs, i.e., an effective batch size of 14. If you have fewer GPUs at your disposal, adjust the learning rates accordingly.
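A common heuristic for this adjustment (an assumption on our side, not a setting prescribed by the codebase) is to scale the learning rates from `cfgs/train.yaml` linearly with the effective batch size:

```
# Hypothetical helper implementing the linear learning-rate scaling heuristic.
# The base values are lr, lr_backbone and lr_track from cfgs/train.yaml and the
# reference effective batch size of 14 (batch size 2 on each of 7 GPUs).
def scaled_lrs(num_gpus, batch_size_per_gpu=2, reference_batch_size=14, base_lrs=None):
    if base_lrs is None:
        base_lrs = {'lr': 0.0002, 'lr_backbone': 0.00002, 'lr_track': 0.0001}
    factor = (num_gpus * batch_size_per_gpu) / reference_batch_size
    return {name: value * factor for name, value in base_lrs.items()}

# Example: training on 2 instead of 7 GPUs.
print(scaled_lrs(num_gpus=2))
```

The resulting values can be passed on the command line in the same way as the other config overrides in the examples above (e.g., `lr=...`).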
To start the CrowdHuman pre-training with 7 GPUs execute: 117 | 118 | ``` 119 | python -m torch.distributed.launch --nproc_per_node=7 --use_env src/train.py with \ 120 | crowdhuman \ 121 | deformable \ 122 | multi_frame \ 123 | tracking \ 124 | output_dir=models/crowdhuman_deformable_multi_frame \ 125 | ``` 126 | 127 | ## Run SLURM jobs with Submitit 128 | 129 | Furthermore, we provide a script for starting Slurm jobs with [submitit](https://github.com/facebookincubator/submitit). This includes a convenient command line interface for Slurm options as well as preemption and resuming capabilities. The aforementioned CrowdHuman pre-training can be executed on 7 x 32 GB GPUs with the following command: 130 | 131 | ``` 132 | python src/run_with_submitit.py with \ 133 | num_gpus=7 \ 134 | vram=32GB \ 135 | cluster=slurm \ 136 | train.crowdhuman \ 137 | train.deformable \ 138 | train.trackformer \ 139 | train.tracking \ 140 | train.output_dir=models/crowdhuman_train_val_deformable \ 141 | ``` -------------------------------------------------------------------------------- /docs/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/docs/method.png -------------------------------------------------------------------------------- /docs/snakeboard.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/docs/snakeboard.gif -------------------------------------------------------------------------------- /docs/visdom.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/docs/visdom.gif -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !visdom 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /logs/visdom/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argon2-cffi==20.1.0 2 | astroid==2.4.2 3 | async-generator==1.10 4 | attrs==19.3.0 5 | backcall==0.2.0 6 | bleach==3.2.3 7 | certifi==2020.4.5.2 8 | cffi==1.14.4 9 | chardet==3.0.4 10 | cloudpickle==1.6.0 11 | colorama==0.4.3 12 | cycler==0.10.0 13 | Cython==0.29.20 14 | decorator==4.4.2 15 | defusedxml==0.6.0 16 | docopt==0.6.2 17 | entrypoints==0.3 18 | filelock==3.0.12 19 | flake8==3.8.3 20 | flake8-import-order==0.18.1 21 | future==0.18.2 22 | gdown==3.12.2 23 | gitdb==4.0.5 24 | GitPython==3.1.3 25 | idna==2.9 26 | imageio==2.8.0 27 | importlib-metadata==1.6.1 28 | ipykernel==5.4.3 29 | ipython==7.19.0 30 | ipython-genutils==0.2.0 31 | ipywidgets==7.6.3 32 | isort==5.6.4 33 | jedi==0.18.0 34 | Jinja2==2.11.2 35 | jsonpatch==1.25 36 | jsonpickle==1.4.1 37 | jsonpointer==2.0 38 | jsonschema==3.2.0 39 | jupyter==1.0.0 40 | 
jupyter-client==6.1.11 41 | jupyter-console==6.2.0 42 | jupyter-core==4.7.0 43 | jupyterlab-pygments==0.1.2 44 | jupyterlab-widgets==1.0.0 45 | kiwisolver==1.2.0 46 | lap==0.4.0 47 | lapsolver==1.1.0 48 | lazy-object-proxy==1.4.3 49 | MarkupSafe==1.1.1 50 | matplotlib==3.2.1 51 | mccabe==0.6.1 52 | mistune==0.8.4 53 | more-itertools==8.4.0 54 | motmetrics==1.2.0 55 | munch==2.5.0 56 | nbclient==0.5.1 57 | nbconvert==6.0.7 58 | nbformat==5.1.2 59 | nest-asyncio==1.5.1 60 | networkx==2.4 61 | ninja==1.10.0.post2 62 | notebook==6.2.0 63 | numpy==1.18.5 64 | opencv-python==4.2.0.34 65 | packaging==20.4 66 | pandas==1.0.5 67 | pandocfilters==1.4.3 68 | parso==0.8.1 69 | pexpect==4.8.0 70 | pickleshare==0.7.5 71 | Pillow==7.1.2 72 | pluggy==0.13.1 73 | prometheus-client==0.9.0 74 | prompt-toolkit==3.0.14 75 | ptyprocess==0.7.0 76 | py==1.8.2 77 | py-cpuinfo==6.0.0 78 | pyaml==20.4.0 79 | pycodestyle==2.6.0 80 | pycparser==2.20 81 | pyflakes==2.2.0 82 | Pygments==2.7.4 83 | pylint==2.6.0 84 | pyparsing==2.4.7 85 | pyrsistent==0.17.3 86 | PySocks==1.7.1 87 | pytest==5.4.3 88 | pytest-benchmark==3.2.3 89 | python-dateutil==2.8.1 90 | pytz==2020.1 91 | PyWavelets==1.1.1 92 | PyYAML==5.3.1 93 | pyzmq==19.0.1 94 | qtconsole==5.0.2 95 | QtPy==1.9.0 96 | requests==2.23.0 97 | sacred==0.8.1 98 | scikit-image==0.17.2 99 | scipy==1.4.1 100 | seaborn==0.10.1 101 | Send2Trash==1.5.0 102 | six==1.15.0 103 | smmap==3.0.4 104 | submitit==1.1.5 105 | terminado==0.9.2 106 | testpath==0.4.4 107 | tifffile==2020.6.3 108 | toml==0.10.2 109 | torchfile==0.1.0 110 | tornado==6.1 111 | tqdm==4.46.1 112 | traitlets==5.0.5 113 | typed-ast==1.4.1 114 | typing-extensions==3.7.4.3 115 | urllib3==1.25.9 116 | visdom==0.1.8.9 117 | wcwidth==0.2.5 118 | webencodings==0.5.1 119 | websocket-client==0.57.0 120 | widgetsnbextension==3.5.1 121 | wrapt==1.12.1 122 | xmltodict==0.12.0 123 | zipp==3.1.0 124 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='trackformer', 4 | packages=['trackformer'], 5 | package_dir={'':'src'}, 6 | version='0.0.1', 7 | install_requires=[],) 8 | -------------------------------------------------------------------------------- /src/combine_frames.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Combine two sets of frames to one. 
4 | """ 5 | import os 6 | import os.path as osp 7 | 8 | from PIL import Image 9 | 10 | OUTPUT_DIR = 'models/mot17_masks_track_rcnn_and_v3_combined' 11 | 12 | FRAME_DIR_1 = 'models/mot17_masks_track_rcnn/MOTS20-TEST' 13 | FRAME_DIR_2 = 'models/mot17_masks_v3/MOTS20-ALL' 14 | 15 | 16 | if __name__ == '__main__': 17 | seqs_1 = os.listdir(FRAME_DIR_1) 18 | seqs_2 = os.listdir(FRAME_DIR_2) 19 | 20 | if not osp.exists(OUTPUT_DIR): 21 | os.makedirs(OUTPUT_DIR) 22 | 23 | for seq in seqs_1: 24 | if seq in seqs_2: 25 | print(seq) 26 | seg_output_dir = osp.join(OUTPUT_DIR, seq) 27 | if not osp.exists(seg_output_dir): 28 | os.makedirs(seg_output_dir) 29 | 30 | frames = os.listdir(osp.join(FRAME_DIR_1, seq)) 31 | 32 | for frame in frames: 33 | img_1 = Image.open(osp.join(FRAME_DIR_1, seq, frame)) 34 | img_2 = Image.open(osp.join(FRAME_DIR_2, seq, frame)) 35 | 36 | width = img_1.size[0] 37 | height = img_2.size[1] 38 | 39 | combined_frame = Image.new('RGB', (width, height * 2)) 40 | combined_frame.paste(img_1, (0, 0)) 41 | combined_frame.paste(img_2, (0, height)) 42 | 43 | combined_frame.save(osp.join(seg_output_dir, f'{frame}')) 44 | -------------------------------------------------------------------------------- /src/compute_best_mean_epoch_from_splits.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import json 4 | import numpy as np 5 | 6 | 7 | LOG_DIR = 'logs/visdom' 8 | 9 | METRICS = ['MOTA', 'IDF1', 'BBOX AP IoU=0.50:0.95', 'MASK AP IoU=0.50:0.95'] 10 | 11 | RUNS = [ 12 | 'mot17_train_1_deformable_full_res', 13 | 'mot17_train_2_deformable_full_res', 14 | 'mot17_train_3_deformable_full_res', 15 | 'mot17_train_4_deformable_full_res', 16 | 'mot17_train_5_deformable_full_res', 17 | 'mot17_train_6_deformable_full_res', 18 | 'mot17_train_7_deformable_full_res', 19 | ] 20 | 21 | RUNS = [ 22 | 'mot17_train_1_no_pretrain_deformable_tracking', 23 | 'mot17_train_2_no_pretrain_deformable_tracking', 24 | 'mot17_train_3_no_pretrain_deformable_tracking', 25 | 'mot17_train_4_no_pretrain_deformable_tracking', 26 | 'mot17_train_5_no_pretrain_deformable_tracking', 27 | 'mot17_train_6_no_pretrain_deformable_tracking', 28 | 'mot17_train_7_no_pretrain_deformable_tracking', 29 | ] 30 | 31 | RUNS = [ 32 | 'mot17_train_1_coco_pretrain_deformable_tracking_lr=0.00001', 33 | 'mot17_train_2_coco_pretrain_deformable_tracking_lr=0.00001', 34 | 'mot17_train_3_coco_pretrain_deformable_tracking_lr=0.00001', 35 | 'mot17_train_4_coco_pretrain_deformable_tracking_lr=0.00001', 36 | 'mot17_train_5_coco_pretrain_deformable_tracking_lr=0.00001', 37 | 'mot17_train_6_coco_pretrain_deformable_tracking_lr=0.00001', 38 | 'mot17_train_7_coco_pretrain_deformable_tracking_lr=0.00001', 39 | ] 40 | 41 | RUNS = [ 42 | 'mot17_train_1_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', 43 | 'mot17_train_2_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', 44 | 'mot17_train_3_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', 45 | 'mot17_train_4_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', 46 | 'mot17_train_5_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', 47 | 'mot17_train_6_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', 48 | 'mot17_train_7_crowdhuman_coco_pretrain_deformable_tracking_lr=0.00001', 49 | ] 50 | 51 | # RUNS = [ 52 | # 'mot17_train_1_no_pretrain_deformable_tracking_eos_coef=0.2', 53 | # 'mot17_train_2_no_pretrain_deformable_tracking_eos_coef=0.2', 54 | # 
'mot17_train_3_no_pretrain_deformable_tracking_eos_coef=0.2', 55 | # 'mot17_train_4_no_pretrain_deformable_tracking_eos_coef=0.2', 56 | # 'mot17_train_5_no_pretrain_deformable_tracking_eos_coef=0.2', 57 | # 'mot17_train_6_no_pretrain_deformable_tracking_eos_coef=0.2', 58 | # 'mot17_train_7_no_pretrain_deformable_tracking_eos_coef=0.2', 59 | # ] 60 | 61 | # RUNS = [ 62 | # 'mot17_train_1_no_pretrain_deformable_tracking_lr_drop=50', 63 | # 'mot17_train_2_no_pretrain_deformable_tracking_lr_drop=50', 64 | # 'mot17_train_3_no_pretrain_deformable_tracking_lr_drop=50', 65 | # 'mot17_train_4_no_pretrain_deformable_tracking_lr_drop=50', 66 | # 'mot17_train_5_no_pretrain_deformable_tracking_lr_drop=50', 67 | # 'mot17_train_6_no_pretrain_deformable_tracking_lr_drop=50', 68 | # 'mot17_train_7_no_pretrain_deformable_tracking_lr_drop=50', 69 | # ] 70 | 71 | # RUNS = [ 72 | # 'mot17_train_1_no_pretrain_deformable_tracking_save_model_interval=1', 73 | # 'mot17_train_2_no_pretrain_deformable_tracking_save_model_interval=1', 74 | # 'mot17_train_3_no_pretrain_deformable_tracking_save_model_interval=1', 75 | # 'mot17_train_4_no_pretrain_deformable_tracking_save_model_interval=1', 76 | # 'mot17_train_5_no_pretrain_deformable_tracking_save_model_interval=1', 77 | # 'mot17_train_6_no_pretrain_deformable_tracking_save_model_interval=1', 78 | # 'mot17_train_7_no_pretrain_deformable_tracking_save_model_interval=1', 79 | # ] 80 | 81 | # RUNS = [ 82 | # 'mot17_train_1_no_pretrain_deformable_tracking_save_model_interval=1', 83 | # 'mot17_train_2_no_pretrain_deformable_tracking_save_model_interval=1', 84 | # 'mot17_train_3_no_pretrain_deformable_tracking_save_model_interval=1', 85 | # 'mot17_train_4_no_pretrain_deformable_tracking_save_model_interval=1', 86 | # 'mot17_train_5_no_pretrain_deformable_tracking_save_model_interval=1', 87 | # 'mot17_train_6_no_pretrain_deformable_tracking_save_model_interval=1', 88 | # 'mot17_train_7_no_pretrain_deformable_tracking_save_model_interval=1', 89 | # ] 90 | 91 | # RUNS = [ 92 | # 'mot17_train_1_no_pretrain_deformable_full_res', 93 | # 'mot17_train_2_no_pretrain_deformable_full_res', 94 | # 'mot17_train_3_no_pretrain_deformable_full_res', 95 | # 'mot17_train_4_no_pretrain_deformable_full_res', 96 | # 'mot17_train_5_no_pretrain_deformable_full_res', 97 | # 'mot17_train_6_no_pretrain_deformable_full_res', 98 | # 'mot17_train_7_no_pretrain_deformable_full_res', 99 | # ] 100 | 101 | # RUNS = [ 102 | # 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', 103 | # 'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', 104 | # 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', 105 | # 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', 106 | # 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', 107 | # 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', 108 | # 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_eos_weight=False', 109 | # ] 110 | 111 | # RUNS = [ 112 | # 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', 113 | # 'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', 114 | # 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', 115 | # 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', 
116 | # 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', 117 | # 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', 118 | # 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0', 119 | # ] 120 | 121 | # RUNS = [ 122 | # 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', 123 | # 'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', 124 | # 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', 125 | # 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', 126 | # 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', 127 | # 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', 128 | # 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0', 129 | # ] 130 | 131 | # RUNS = [ 132 | # 'mot17_train_1_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', 133 | # 'mot17_train_2_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', 134 | # 'mot17_train_3_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', 135 | # 'mot17_train_4_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', 136 | # 'mot17_train_5_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', 137 | # 'mot17_train_6_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', 138 | # 'mot17_train_7_no_pretrain_deformable_tracking_track_query_false_positive_prob=0_0_track_prev_frame_range=0_track_query_false_negative_prob=0_0', 139 | # ] 140 | 141 | # RUNS = [ 142 | # 'mot17_train_1_no_pretrain_deformable', 143 | # 'mot17_train_2_no_pretrain_deformable', 144 | # 'mot17_train_3_no_pretrain_deformable', 145 | # 'mot17_train_4_no_pretrain_deformable', 146 | # 'mot17_train_5_no_pretrain_deformable', 147 | # 'mot17_train_6_no_pretrain_deformable', 148 | # 'mot17_train_7_no_pretrain_deformable', 149 | # ] 150 | 151 | # 152 | # MOTS 4-fold split 153 | # 154 | 155 | # RUNS = [ 156 | # 'mots20_train_1_coco_tracking', 157 | # 'mots20_train_2_coco_tracking', 158 | # 'mots20_train_3_coco_tracking', 159 | # 'mots20_train_4_coco_tracking', 160 | # ] 161 | 162 | # RUNS = [ 163 | # 'mots20_train_1_coco_tracking_full_res_masks=False', 164 | # 'mots20_train_2_coco_tracking_full_res_masks=False', 165 | # 'mots20_train_3_coco_tracking_full_res_masks=False', 166 | # 'mots20_train_4_coco_tracking_full_res_masks=False', 167 | # ] 168 | 169 | # RUNS = [ 170 | # 'mots20_train_1_coco_full_res_pretrain_masks=False_lr_0_0001', 171 | # 'mots20_train_2_coco_full_res_pretrain_masks=False_lr_0_0001', 172 | # 'mots20_train_3_coco_full_res_pretrain_masks=False_lr_0_0001', 173 | # 'mots20_train_4_coco_full_res_pretrain_masks=False_lr_0_0001', 174 | # ] 175 | 176 | # RUNS = [ 177 | # 
'mots20_train_1_coco_tracking_full_res_masks=False_pretrain', 178 | # 'mots20_train_2_coco_tracking_full_res_masks=False_pretrain', 179 | # 'mots20_train_3_coco_tracking_full_res_masks=False_pretrain', 180 | # 'mots20_train_4_coco_tracking_full_res_masks=False_pretrain', 181 | # ] 182 | 183 | # RUNS = [ 184 | # 'mot17det_train_1_mots_track_bbox_proposals_pretrain_train_1_mots_vis_save_model_interval_1', 185 | # 'mot17det_train_2_mots_track_bbox_proposals_pretrain_train_3_mots_vis_save_model_interval_1', 186 | # 'mot17det_train_3_mots_track_bbox_proposals_pretrain_train_4_mots_vis_save_model_interval_1', 187 | # 'mot17det_train_4_mots_track_bbox_proposals_pretrain_train_6_mots_vis_save_model_interval_1', 188 | # ] 189 | 190 | if __name__ == '__main__': 191 | results = {} 192 | 193 | for r in RUNS: 194 | print(r) 195 | log_file = os.path.join(LOG_DIR, f"{r}.json") 196 | 197 | with open(log_file) as json_file: 198 | data = json.load(json_file) 199 | 200 | window = [ 201 | window for window in data['jsons'].values() 202 | if window['title'] == 'VAL EVAL EPOCHS'][0] 203 | 204 | for m in METRICS: 205 | if m not in window['legend']: 206 | continue 207 | elif m not in results: 208 | results[m] = [] 209 | 210 | idxs = window['legend'].index(m) 211 | 212 | values = window['content']['data'][idxs]['y'] 213 | results[m].append(values) 214 | 215 | print(f'NUM EPOCHS: {len(values)}') 216 | 217 | min_length = min([len(l) for l in next(iter(results.values()))]) 218 | 219 | for metric in results.keys(): 220 | results[metric] = [l[:min_length] for l in results[metric]] 221 | 222 | mean_results = { 223 | metric: np.array(results[metric]).mean(axis=0) 224 | for metric in results.keys()} 225 | 226 | print("* METRIC INTERVAL = BEST EPOCHS") 227 | for metric in results.keys(): 228 | best_interval = mean_results[metric].argmax() 229 | print(mean_results[metric]) 230 | print( 231 | f'{metric}: {mean_results[metric].max():.2%} at {best_interval + 1}/{len(mean_results[metric])} ' 232 | f'{[(mmetric, f"{mean_results[mmetric][best_interval]:.2%}") for mmetric in results.keys() if not mmetric == metric]}') 233 | -------------------------------------------------------------------------------- /src/generate_coco_from_crowdhuman.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Generates COCO data and annotation structure from CrowdHuman data. 4 | """ 5 | import json 6 | import os 7 | import cv2 8 | 9 | from generate_coco_from_mot import check_coco_from_mot 10 | 11 | DATA_ROOT = 'data/CrowdHuman' 12 | VIS_THRESHOLD = 0.0 13 | 14 | 15 | def generate_coco_from_crowdhuman(split_name='train_val', split='train_val'): 16 | """ 17 | Generate COCO data from CrowdHuman. 
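Illustrative note: the CrowdHuman .odgt annotation files read below contain one JSON record per line with an image ID and a list of gtboxes; only boxes tagged 'person' are converted, using the full-body box fbox in [x, y, w, h] pixel format. A single, simplified and hypothetical record looks roughly like:

    {"ID": "some_image_name",
     "gtboxes": [{"tag": "person", "fbox": [72, 202, 163, 503],
                  "extra": {"ignore": 0}}]}

where ID matches the image file name without its extension.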
18 | """ 19 | annotations = {} 20 | annotations['type'] = 'instances' 21 | annotations['images'] = [] 22 | annotations['categories'] = [{"supercategory": "person", 23 | "name": "person", 24 | "id": 1}] 25 | annotations['annotations'] = [] 26 | annotation_file = os.path.join(DATA_ROOT, f'annotations/{split_name}.json') 27 | 28 | # IMAGES 29 | imgs_list_dir = os.listdir(os.path.join(DATA_ROOT, split)) 30 | for i, img in enumerate(sorted(imgs_list_dir)): 31 | im = cv2.imread(os.path.join(DATA_ROOT, split, img)) 32 | h, w, _ = im.shape 33 | 34 | annotations['images'].append({ 35 | "file_name": img, 36 | "height": h, 37 | "width": w, 38 | "id": i, }) 39 | 40 | # GT 41 | annotation_id = 0 42 | img_file_name_to_id = { 43 | os.path.splitext(img_dict['file_name'])[0]: img_dict['id'] 44 | for img_dict in annotations['images']} 45 | 46 | for split in ['train', 'val']: 47 | if split not in split_name: 48 | continue 49 | odgt_annos_file = os.path.join(DATA_ROOT, f'annotations/annotation_{split}.odgt') 50 | with open(odgt_annos_file, 'r+') as anno_file: 51 | datalist = anno_file.readlines() 52 | 53 | ignores = 0 54 | for data in datalist: 55 | json_data = json.loads(data) 56 | gtboxes = json_data['gtboxes'] 57 | for gtbox in gtboxes: 58 | if gtbox['tag'] == 'person': 59 | bbox = gtbox['fbox'] 60 | area = bbox[2] * bbox[3] 61 | 62 | ignore = False 63 | visibility = 1.0 64 | # if 'occ' in gtbox['extra']: 65 | # visibility = 1.0 - gtbox['extra']['occ'] 66 | # if visibility <= VIS_THRESHOLD: 67 | # ignore = True 68 | 69 | if 'ignore' in gtbox['extra']: 70 | ignore = ignore or bool(gtbox['extra']['ignore']) 71 | 72 | ignores += int(ignore) 73 | 74 | annotation = { 75 | "id": annotation_id, 76 | "bbox": bbox, 77 | "image_id": img_file_name_to_id[json_data['ID']], 78 | "segmentation": [], 79 | "ignore": int(ignore), 80 | "visibility": visibility, 81 | "area": area, 82 | "iscrowd": 0, 83 | "category_id": annotations['categories'][0]['id'],} 84 | 85 | annotation_id += 1 86 | annotations['annotations'].append(annotation) 87 | 88 | # max objs per image 89 | num_objs_per_image = {} 90 | for anno in annotations['annotations']: 91 | image_id = anno["image_id"] 92 | if image_id in num_objs_per_image: 93 | num_objs_per_image[image_id] += 1 94 | else: 95 | num_objs_per_image[image_id] = 1 96 | 97 | print(f'max objs per image: {max([n for n in num_objs_per_image.values()])}') 98 | print(f'ignore augs: {ignores}/{len(annotations["annotations"])}') 99 | print(len(annotations['images'])) 100 | 101 | # for img_id, num_objs in num_objs_per_image.items(): 102 | # if num_objs > 50 or num_objs < 2: 103 | # annotations['images'] = [ 104 | # img for img in annotations['images'] 105 | # if img_id != img['id']] 106 | 107 | # annotations['annotations'] = [ 108 | # anno for anno in annotations['annotations'] 109 | # if img_id != anno['image_id']] 110 | 111 | # print(len(annotations['images'])) 112 | 113 | with open(annotation_file, 'w') as anno_file: 114 | json.dump(annotations, anno_file, indent=4) 115 | 116 | 117 | if __name__ == '__main__': 118 | generate_coco_from_crowdhuman(split_name='train_val', split='train_val') 119 | # generate_coco_from_crowdhuman(split_name='train', split='train') 120 | 121 | # coco_dir = os.path.join('data/CrowdHuman', 'train_val') 122 | # annotation_file = os.path.join('data/CrowdHuman/annotations', 'train_val.json') 123 | # check_coco_from_mot(coco_dir, annotation_file, img_id=9012) 124 | -------------------------------------------------------------------------------- /src/parse_mot_results_to_tex.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Parse MOT results and generate a LaTeX table. 4 | """ 5 | 6 | MOTS = False 7 | MOT20 = False 8 | # F_CONTENT = """ 9 | # MOTA IDF1 MOTP MT ML FP FN Recall Precision FAF IDSW Frag 10 | # MOT17-01-DPM 41.6 44.2 77.1 5 8 496 3252 49.6 86.6 1.1 22 58 11 | # MOT17-01-FRCNN 41.0 42.1 77.1 6 9 571 3207 50.3 85.0 1.3 25 61 12 | # MOT17-01-SDP 41.8 44.3 76.8 7 8 612 3112 51.8 84.5 1.4 27 65 13 | # MOT17-03-DPM 79.3 71.6 79.1 94 8 1142 20297 80.6 98.7 0.8 191 525 14 | # MOT17-03-FRCNN 79.6 72.7 79.1 93 7 1234 19945 80.9 98.6 0.8 180 508 15 | # MOT17-03-SDP 80.0 72.0 79.0 93 8 1223 19530 81.3 98.6 0.8 181 526 16 | # MOT17-06-DPM 54.8 42.0 79.5 54 63 314 4839 58.9 95.7 0.3 175 244 17 | # MOT17-06-FRCNN 55.6 42.9 79.3 57 59 363 4676 60.3 95.1 0.3 190 264 18 | # MOT17-06-SDP 55.5 43.8 79.3 56 61 354 4712 60.0 95.2 0.3 181 262 19 | # MOT17-07-DPM 44.8 42.0 76.6 11 16 1322 7851 53.5 87.2 2.6 147 275 20 | # MOT17-07-FRCNN 45.5 41.5 76.6 13 15 1263 7785 53.9 87.8 2.5 156 289 21 | # MOT17-07-SDP 45.2 42.4 76.6 13 15 1332 7775 54.0 87.3 2.7 147 279 22 | # MOT17-08-DPM 26.5 32.2 83.0 11 37 378 15066 28.7 94.1 0.6 88 146 23 | # MOT17-08-FRCNN 26.5 31.9 83.1 11 36 332 15113 28.5 94.8 0.5 89 141 24 | # MOT17-08-SDP 26.6 32.3 83.1 11 36 350 15067 28.7 94.5 0.6 91 147 25 | # MOT17-12-DPM 46.1 53.1 82.7 16 45 207 4434 48.8 95.3 0.2 30 50 26 | # MOT17-12-FRCNN 46.1 52.6 82.6 15 45 197 4443 48.7 95.5 0.2 30 48 27 | # MOT17-12-SDP 46.0 53.0 82.6 16 45 221 4426 48.9 95.0 0.2 30 52 28 | # MOT17-14-DPM 31.6 36.6 74.8 13 78 636 11812 36.1 91.3 0.8 196 331 29 | # MOT17-14-FRCNN 31.6 37.6 74.6 13 77 780 11653 37.0 89.8 1.0 202 350 30 | # MOT17-14-SDP 31.7 37.1 74.7 13 76 749 11677 36.8 90.1 1.0 205 344 31 | # OVERALL 61.5 59.6 78.9 621 752 14076 200672 64.4 96.3 0.8 2583 4965 32 | # """ 33 | 34 | F_CONTENT = """ 35 | MOTA MOTP IDF1 IDP IDR TP FP FN Rcll Prcn MTR PTR MLR MT PT ML IDSW FAR FM 36 | MOT17-01-DPM 49.92 79.58 42.97 58.18 34.06 3518 258 2932 54.54 93.17 20.83 45.83 33.33 5 11 8 40 0.57 50 37 | MOT17-01-FRCNN 50.87 79.26 42.33 55.77 34.11 3637 308 2813 56.39 92.19 33.33 41.67 25.00 8 10 6 48 0.68 57 38 | MOT17-01-SDP 53.66 78.16 45.33 54.31 38.90 4064 556 2386 63.01 87.97 41.67 37.50 20.83 10 9 5 47 1.24 72 39 | MOT17-03-DPM 74.05 79.41 66.45 76.34 58.83 79279 1389 25396 75.74 98.28 57.43 30.41 12.16 85 45 18 374 0.93 420 40 | MOT17-03-FRCNN 75.34 79.45 66.98 76.21 59.75 80635 1434 24040 77.03 98.25 56.76 32.43 10.81 84 48 16 335 0.96 409 41 | MOT17-03-SDP 79.64 79.04 65.84 72.00 60.65 86043 2134 18632 82.20 97.58 64.19 27.03 8.78 95 40 13 545 1.42 522 42 | MOT17-06-DPM 53.62 82.55 51.83 64.47 43.33 7209 711 4575 61.18 91.02 28.38 37.84 33.78 63 84 75 180 0.60 170 43 | MOT17-06-FRCNN 57.21 81.73 54.75 63.67 48.02 7928 960 3856 67.28 89.20 32.88 45.50 21.62 73 101 48 226 0.80 223 44 | MOT17-06-SDP 56.43 81.93 54.00 62.70 47.42 7895 1017 3889 67.00 88.59 36.94 37.39 25.68 82 83 57 228 0.85 222 45 | MOT17-07-DPM 52.59 80.54 48.08 66.84 37.54 9230 258 7663 54.64 97.28 20.00 53.33 26.67 12 32 16 88 0.52 148 46 | MOT17-07-FRCNN 52.39 80.11 47.88 64.56 38.05 9456 499 7437 55.98 94.99 20.00 61.67 18.33 12 37 11 106 1.00 174 47 | MOT17-07-SDP 54.56 79.84 47.81 62.29 38.79 9928 590 6965 58.77 94.39 26.67 55.00 18.33 16 33 11 121 1.18 199 48 | MOT17-08-DPM 32.52 83.93 31.85 60.34 21.63 7286 288 13838 34.49 96.20 13.16 44.74 42.11 10 34 32 128 0.46 154 49 | 
MOT17-08-FRCNN 31.11 84.47 31.68 62.05 21.27 6958 285 14166 32.94 96.07 13.16 39.47 47.37 10 30 36 102 0.46 120 50 | MOT17-08-SDP 34.96 83.31 33.05 58.02 23.11 7972 443 13152 37.74 94.74 15.79 48.68 35.53 12 37 27 144 0.71 175 51 | MOT17-12-DPM 51.26 83.01 57.74 72.70 47.88 5102 606 3565 58.87 89.38 23.08 42.86 34.07 21 39 31 53 0.67 86 52 | MOT17-12-FRCNN 47.71 83.16 56.73 72.39 46.64 4882 702 3785 56.33 87.43 20.88 43.96 35.16 19 40 32 45 0.78 72 53 | MOT17-12-SDP 48.88 82.87 57.46 70.30 48.59 5140 850 3527 59.31 85.81 24.18 45.05 30.77 22 41 28 54 0.94 89 54 | MOT17-14-DPM 38.07 77.47 42.03 66.15 30.80 7978 627 10505 43.16 92.71 9.15 52.44 38.41 15 86 63 314 0.84 296 55 | MOT17-14-FRCNN 37.78 76.70 41.78 59.55 32.18 8688 1300 9795 47.01 86.98 10.37 55.49 34.15 17 91 56 406 1.73 382 56 | MOT17-14-SDP 40.40 76.40 42.38 57.96 33.40 9277 1376 9206 50.19 87.08 10.37 59.76 29.88 17 98 49 434 1.83 437 57 | OVERALL\t62.30\t79.77\t57.58\t70.58\t48.62\t372105\t16591 192123 65.95 95.73 29.21 43.69 27.09 688 1029 638 4018 0.93 4477 58 | """ 59 | 60 | 61 | # MOTS = True 62 | # F_CONTENT = """ 63 | # sMOTSA MOTSA MOTSP IDF1 MT ML MTR PTR MLR GT TP FP FN Rcll Prcn FM FMR IDSW IDSWR 64 | # MOTS20-01 59.79 79.56 77.60 68.00 10 0 83.33 16.67 0.00 12 2742 255 364 88.28 91.49 37 41.91 16 18.1 65 | # MOTS20-06 63.91 78.72 82.85 65.14 115 22 60.53 27.89 11.58 190 8479 595 1335 86.40 93.44 218 252.32 158 182.9 66 | # MOTS20-07 43.17 58.52 76.59 53.60 15 17 25.86 44.83 29.31 58 8445 834 4433 65.58 91.01 177 269.91 75 114.4 67 | # MOTS20-12 62.04 74.64 84.93 76.83 41 9 60.29 26.47 13.24 68 5408 549 1063 83.57 90.78 76 90.94 29 34.7 68 | # OVERALL 54.86 69.92 80.62 63.58 181 48 55.18 30.18 14.63 328 25074 2233 7195 77.70 91.82 508 653.77 278 357.8 69 | # """ 70 | 71 | 72 | MOT20 = True 73 | F_CONTENT = """ 74 | MOTA MOTP IDF1 IDP IDR HOTA DetA AssA DetRe DetPr AssRe AssPr LocA TP FP FN Rcll Prcn IDSW\tMT\tML 75 | MOT20-04 82.72 82.57 75.59 79.81 71.79 63.21 68.29 58.64 73.11 81.27 63.43 80.18 84.53 236919 9639 37165 86.44 96.09 566\t490\t28 76 | MOT20-06 55.88 79.00 53.51 68.11 44.07 43.85 45.80 42.23 49.13 75.94 45.95 74.07 81.72 80317 5582 52440 60.50 93.50 545\t96\t72 77 | MOT20-07 56.21 85.22 59.05 78.90 47.18 49.19 48.45 50.21 50.63 84.68 53.31 83.48 86.86 19245 547 13856 58.14 97.24 92\t41\t20 78 | MOT20-08 46.03 77.71 48.34 65.65 38.26 38.89 38.46 39.70 41.87 71.85 43.36 71.76 81.08 40572 4580 36912 52.36 89.86 329\t39\t61 79 | OVERALL\t68.64 81.42 65.70 75.63 58.08 54.67 56.68 52.97 60.84 79.22 57.39 78.50 83.69 377053 20348 140373 72.87 94.88 1532\t666\t181 80 | """ 81 | 82 | 83 | if __name__ == '__main__': 84 | # remove empty lines at start and beginning of F_CONTENT 85 | F_CONTENT = F_CONTENT.strip() 86 | F_CONTENT = F_CONTENT.splitlines() 87 | 88 | start_ixs = range(1, len(F_CONTENT) - 1, 3) 89 | if MOTS or MOT20: 90 | start_ixs = range(1, len(F_CONTENT) - 1) 91 | 92 | metrics_res = {} 93 | 94 | for i in range(len(['DPM', 'FRCNN', 'SDP'])): 95 | for start in start_ixs: 96 | f_list = F_CONTENT[start + i].strip().split('\t') 97 | metrics_res[f_list[0]] = f_list[1:] 98 | 99 | if MOTS or MOT20: 100 | break 101 | 102 | metrics_names = F_CONTENT[0].replace('\n', '').split() 103 | 104 | print(metrics_names) 105 | 106 | metrics_res['ALL'] = F_CONTENT[-1].strip().split('\t')[1:] 107 | 108 | for full_seq_name, data in metrics_res.items(): 109 | seq_name = '-'.join(full_seq_name.split('-')[:2]) 110 | detection_name = full_seq_name.split('-')[-1] 111 | 112 | if MOTS: 113 | print(f"{seq_name} & " 114 | 
f"{float(data[metrics_names.index('sMOTSA')]):.1f} & " 115 | f"{float(data[metrics_names.index('IDF1')]):.1f} & " 116 | f"{float(data[metrics_names.index('MOTSA')]):.1f} & " 117 | f"{data[metrics_names.index('FP')]} & " 118 | f"{data[metrics_names.index('FN')]} & " 119 | f"{data[metrics_names.index('IDSW')]} \\\\") 120 | else: 121 | print(f"{seq_name} & {detection_name} & " 122 | f"{float(data[metrics_names.index('MOTA')]):.1f} & " 123 | f"{float(data[metrics_names.index('IDF1')]):.1f} & " 124 | f"{data[metrics_names.index('MT')]} & " 125 | f"{data[metrics_names.index('ML')]} & " 126 | f"{data[metrics_names.index('FP')]} & " 127 | f"{data[metrics_names.index('FN')]} & " 128 | f"{data[metrics_names.index('IDSW')]} \\\\") 129 | -------------------------------------------------------------------------------- /src/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | A script to run multinode training with submitit. 4 | """ 5 | import os 6 | import sys 7 | import uuid 8 | from pathlib import Path 9 | from argparse import Namespace 10 | 11 | import sacred 12 | import submitit 13 | 14 | import train 15 | from trackformer.util.misc import nested_dict_to_namespace 16 | 17 | WORK_DIR = str(Path(__file__).parent.absolute()) 18 | 19 | 20 | ex = sacred.Experiment('submit', ingredients=[train.ex]) 21 | ex.add_config('cfgs/submit.yaml') 22 | 23 | 24 | def get_shared_folder() -> Path: 25 | user = os.getenv("USER") 26 | if Path("/storage/slurm").is_dir(): 27 | path = Path(f"/storage/slurm/{user}/runs") 28 | path.mkdir(exist_ok=True) 29 | return path 30 | raise RuntimeError("No shared folder available") 31 | 32 | 33 | def get_init_file() -> Path: 34 | # Init file must not exist, but it's parent dir must exist. 
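# (Hedged reference sketch, not part of the original function: the file URI
# produced here is later handed to torch.distributed as the rendezvous point
# for multi-node training, roughly along the lines of
#
#     import torch.distributed as dist
#     dist.init_process_group(backend='nccl',
#                             init_method=init_file.as_uri(),
#                             world_size=world_size, rank=rank)
#
# where world_size and rank are filled in from the submitit job environment.)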
35 | os.makedirs(str(get_shared_folder()), exist_ok=True) 36 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 37 | if init_file.exists(): 38 | os.remove(str(init_file)) 39 | return init_file 40 | 41 | 42 | class Trainer: 43 | def __init__(self, args: Namespace) -> None: 44 | self.args = args 45 | 46 | def __call__(self) -> None: 47 | sys.path.append(WORK_DIR) 48 | 49 | import train 50 | self._setup_gpu_args() 51 | train.train(self.args) 52 | 53 | def checkpoint(self) -> submitit.helpers.DelayedSubmission: 54 | import os 55 | 56 | import submitit 57 | 58 | self.args.dist_url = get_init_file().as_uri() 59 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 60 | if os.path.exists(checkpoint_file): 61 | self.args.resume = checkpoint_file 62 | self.args.resume_optim = True 63 | self.args.resume_vis = True 64 | self.args.load_mask_head_from_model = None 65 | print("Requeuing ", self.args) 66 | empty_trainer = type(self)(self.args) 67 | return submitit.helpers.DelayedSubmission(empty_trainer) 68 | 69 | def _setup_gpu_args(self) -> None: 70 | from pathlib import Path 71 | 72 | import submitit 73 | 74 | job_env = submitit.JobEnvironment() 75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 76 | print(self.args.output_dir) 77 | self.args.gpu = job_env.local_rank 78 | self.args.rank = job_env.global_rank 79 | self.args.world_size = job_env.num_tasks 80 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 81 | 82 | 83 | def main(args: Namespace): 84 | # Note that the folder will depend on the job_id, to easily track experiments 85 | if args.job_dir == "": 86 | args.job_dir = get_shared_folder() / "%j" 87 | 88 | executor = submitit.AutoExecutor( 89 | folder=args.job_dir, cluster=args.cluster, slurm_max_num_timeout=30) 90 | 91 | # cluster setup is defined by environment variables 92 | num_gpus_per_node = args.num_gpus 93 | nodes = args.nodes 94 | timeout_min = args.timeout 95 | 96 | if args.slurm_gres: 97 | slurm_gres = args.slurm_gres 98 | else: 99 | slurm_gres = f'gpu:{num_gpus_per_node},VRAM:{args.vram}' 100 | # slurm_gres = f'gpu:rtx_8000:{num_gpus_per_node}' 101 | 102 | executor.update_parameters( 103 | mem_gb=args.mem_per_gpu * num_gpus_per_node, 104 | gpus_per_node=num_gpus_per_node, 105 | tasks_per_node=num_gpus_per_node, # one task per GPU 106 | cpus_per_task=args.cpus_per_task, 107 | nodes=nodes, 108 | timeout_min=timeout_min, # max is 60 * 72, 109 | slurm_partition=args.slurm_partition, 110 | slurm_constraint=args.slurm_constraint, 111 | slurm_comment=args.slurm_comment, 112 | slurm_exclude=args.slurm_exclude, 113 | slurm_gres=slurm_gres 114 | ) 115 | 116 | executor.update_parameters(name="fair_track") 117 | 118 | args.train.dist_url = get_init_file().as_uri() 119 | # args.output_dir = args.job_dir 120 | 121 | trainer = Trainer(args.train) 122 | job = executor.submit(trainer) 123 | 124 | print("Submitted job_id:", job.job_id) 125 | 126 | if args.cluster == 'debug': 127 | job.wait() 128 | 129 | 130 | @ex.main 131 | def load_config(_config, _run): 132 | """ We use sacred only for config loading from YAML files. 
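(Hedged usage note: the training experiment is attached as a sacred ingredient, so submission options such as num_gpus, nodes or timeout come from cfgs/submit.yaml, while training options can be overridden through the train. prefix, e.g. something along the lines of

    python src/run_with_submitit.py with num_gpus=8 nodes=2 train.output_dir=models/my_run

where models/my_run is a hypothetical output directory.)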
""" 133 | sacred.commands.print_config(_run) 134 | 135 | 136 | if __name__ == '__main__': 137 | # TODO: hierachical Namespacing for nested dict 138 | config = ex.run_commandline().config 139 | args = nested_dict_to_namespace(config) 140 | # args.train = Namespace(**config['train']) 141 | main(args) 142 | -------------------------------------------------------------------------------- /src/track.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import sys 4 | import time 5 | from os import path as osp 6 | 7 | import motmetrics as mm 8 | import numpy as np 9 | import sacred 10 | import torch 11 | import tqdm 12 | import yaml 13 | from torch.utils.data import DataLoader 14 | 15 | from trackformer.datasets.tracking import TrackDatasetFactory 16 | from trackformer.models import build_model 17 | from trackformer.models.tracker import Tracker 18 | from trackformer.util.misc import nested_dict_to_namespace 19 | from trackformer.util.track_utils import (evaluate_mot_accums, get_mot_accum, 20 | interpolate_tracks, plot_sequence) 21 | 22 | mm.lap.default_solver = 'lap' 23 | 24 | ex = sacred.Experiment('track') 25 | ex.add_config('cfgs/track.yaml') 26 | ex.add_named_config('reid', 'cfgs/track_reid.yaml') 27 | 28 | 29 | @ex.automain 30 | def main(seed, dataset_name, obj_detect_checkpoint_file, tracker_cfg, 31 | write_images, output_dir, interpolate, verbose, load_results_dir, 32 | data_root_dir, generate_attention_maps, frame_range, 33 | _config, _log, _run, obj_detector_model=None): 34 | if write_images: 35 | assert output_dir is not None 36 | 37 | # obj_detector_model is only provided when run as evaluation during 38 | # training. in that case we omit verbose outputs. 
39 | if obj_detector_model is None: 40 | sacred.commands.print_config(_run) 41 | 42 | # set all seeds 43 | if seed is not None: 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed(seed) 46 | np.random.seed(seed) 47 | torch.backends.cudnn.deterministic = True 48 | 49 | if output_dir is not None: 50 | if not osp.exists(output_dir): 51 | os.makedirs(output_dir) 52 | 53 | yaml.dump( 54 | _config, 55 | open(osp.join(output_dir, 'track.yaml'), 'w'), 56 | default_flow_style=False) 57 | 58 | ########################## 59 | # Initialize the modules # 60 | ########################## 61 | 62 | # object detection 63 | if obj_detector_model is None: 64 | obj_detect_config_path = os.path.join( 65 | os.path.dirname(obj_detect_checkpoint_file), 66 | 'config.yaml') 67 | obj_detect_args = nested_dict_to_namespace(yaml.unsafe_load(open(obj_detect_config_path))) 68 | img_transform = obj_detect_args.img_transform 69 | obj_detector, _, obj_detector_post = build_model(obj_detect_args) 70 | 71 | obj_detect_checkpoint = torch.load( 72 | obj_detect_checkpoint_file, map_location=lambda storage, loc: storage) 73 | 74 | obj_detect_state_dict = obj_detect_checkpoint['model'] 75 | # obj_detect_state_dict = { 76 | # k: obj_detect_state_dict[k] if k in obj_detect_state_dict 77 | # else v 78 | # for k, v in obj_detector.state_dict().items()} 79 | 80 | obj_detect_state_dict = { 81 | k.replace('detr.', ''): v 82 | for k, v in obj_detect_state_dict.items() 83 | if 'track_encoding' not in k} 84 | 85 | obj_detector.load_state_dict(obj_detect_state_dict) 86 | if 'epoch' in obj_detect_checkpoint: 87 | _log.info(f"INIT object detector [EPOCH: {obj_detect_checkpoint['epoch']}]") 88 | 89 | obj_detector.cuda() 90 | else: 91 | obj_detector = obj_detector_model['model'] 92 | obj_detector_post = obj_detector_model['post'] 93 | img_transform = obj_detector_model['img_transform'] 94 | 95 | if hasattr(obj_detector, 'tracking'): 96 | obj_detector.tracking() 97 | 98 | track_logger = None 99 | if verbose: 100 | track_logger = _log.info 101 | tracker = Tracker( 102 | obj_detector, obj_detector_post, tracker_cfg, 103 | generate_attention_maps, track_logger, verbose) 104 | 105 | time_total = 0 106 | num_frames = 0 107 | mot_accums = [] 108 | dataset = TrackDatasetFactory( 109 | dataset_name, root_dir=data_root_dir, img_transform=img_transform) 110 | 111 | for seq in dataset: 112 | tracker.reset() 113 | 114 | _log.info(f"------------------") 115 | _log.info(f"TRACK SEQ: {seq}") 116 | 117 | start_frame = int(frame_range['start'] * len(seq)) 118 | end_frame = int(frame_range['end'] * len(seq)) 119 | 120 | seq_loader = DataLoader( 121 | torch.utils.data.Subset(seq, range(start_frame, end_frame))) 122 | 123 | num_frames += len(seq_loader) 124 | 125 | results = seq.load_results(load_results_dir) 126 | 127 | if not results: 128 | start = time.time() 129 | 130 | for frame_id, frame_data in enumerate(tqdm.tqdm(seq_loader, file=sys.stdout)): 131 | with torch.no_grad(): 132 | tracker.step(frame_data) 133 | 134 | results = tracker.get_results() 135 | 136 | time_total += time.time() - start 137 | 138 | _log.info(f"NUM TRACKS: {len(results)} ReIDs: {tracker.num_reids}") 139 | _log.info(f"RUNTIME: {time.time() - start :.2f} s") 140 | 141 | if interpolate: 142 | results = interpolate_tracks(results) 143 | 144 | if output_dir is not None: 145 | _log.info(f"WRITE RESULTS") 146 | seq.write_results(results, output_dir) 147 | else: 148 | _log.info("LOAD RESULTS") 149 | 150 | if seq.no_gt: 151 | _log.info("NO GT AVAILBLE") 152 | else: 153 | mot_accum = 
get_mot_accum(results, seq_loader) 154 | mot_accums.append(mot_accum) 155 | 156 | if verbose: 157 | mot_events = mot_accum.mot_events 158 | reid_events = mot_events[mot_events['Type'] == 'SWITCH'] 159 | match_events = mot_events[mot_events['Type'] == 'MATCH'] 160 | 161 | switch_gaps = [] 162 | for index, event in reid_events.iterrows(): 163 | frame_id, _ = index 164 | match_events_oid = match_events[match_events['OId'] == event['OId']] 165 | match_events_oid_earlier = match_events_oid[ 166 | match_events_oid.index.get_level_values('FrameId') < frame_id] 167 | 168 | if not match_events_oid_earlier.empty: 169 | match_events_oid_earlier_frame_ids = \ 170 | match_events_oid_earlier.index.get_level_values('FrameId') 171 | last_occurrence = match_events_oid_earlier_frame_ids.max() 172 | switch_gap = frame_id - last_occurrence 173 | switch_gaps.append(switch_gap) 174 | 175 | switch_gaps_hist = None 176 | if switch_gaps: 177 | switch_gaps_hist, _ = np.histogram( 178 | switch_gaps, bins=list(range(0, max(switch_gaps) + 10, 10))) 179 | switch_gaps_hist = switch_gaps_hist.tolist() 180 | 181 | _log.info(f'SWITCH_GAPS_HIST (bin_width=10): {switch_gaps_hist}') 182 | 183 | if output_dir is not None and write_images: 184 | _log.info("PLOT SEQ") 185 | plot_sequence( 186 | results, seq_loader, osp.join(output_dir, dataset_name, str(seq)), 187 | write_images, generate_attention_maps) 188 | 189 | if time_total: 190 | _log.info(f"RUNTIME ALL SEQS (w/o EVAL or IMG WRITE): " 191 | f"{time_total:.2f} s for {num_frames} frames " 192 | f"({num_frames / time_total:.2f} Hz)") 193 | 194 | if obj_detector_model is None: 195 | _log.info(f"EVAL:") 196 | 197 | summary, str_summary = evaluate_mot_accums( 198 | mot_accums, 199 | [str(s) for s in dataset if not s.no_gt]) 200 | 201 | _log.info(f'\n{str_summary}') 202 | 203 | return summary 204 | 205 | return mot_accums 206 | -------------------------------------------------------------------------------- /src/track_param_search.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from itertools import product 3 | 4 | import numpy as np 5 | 6 | from track import ex 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | 12 | # configs = [ 13 | # {'dataset_name': ["MOT17-02-FRCNN", "MOT17-10-FRCNN", "MOT17-13-FRCNN"], 14 | # 'obj_detect_checkpoint_file': 'models/mot17det_train_cross_val_1_mots_vis_track_bbox_proposals_track_encoding_bbox_proposals_prev_frame_5/checkpoint_best_MOTA.pth'}, 15 | # {'dataset_name': ["MOT17-04-FRCNN", "MOT17-11-FRCNN"], 16 | # 'obj_detect_checkpoint_file': 'models/mot17det_train_cross_val_2_mots_vis_track_bbox_proposals_track_encoding_bbox_proposals_prev_frame_5/checkpoint_best_MOTA.pth'}, 17 | # {'dataset_name': ["MOT17-05-FRCNN", "MOT17-09-FRCNN"], 18 | # 'obj_detect_checkpoint_file': 'models/mot17det_train_cross_val_3_mots_vis_track_bbox_proposals_track_encoding_bbox_proposals_prev_frame_5/checkpoint_best_MOTA.pth'}, 19 | # ] 20 | 21 | # configs = [ 22 | # {'dataset_name': ["MOT17-02-FRCNN"], 23 | # 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_1_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, 24 | # {'dataset_name': ["MOT17-04-FRCNN"], 25 | # 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_2_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, 26 | # {'dataset_name': ["MOT17-05-FRCNN"], 27 | # 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_3_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, 28 | # {'dataset_name': ["MOT17-09-FRCNN"], 29 | # 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_4_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, 30 | # {'dataset_name': ["MOT17-10-FRCNN"], 31 | # 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_5_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, 32 | # {'dataset_name': ["MOT17-11-FRCNN"], 33 | # 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_6_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, 34 | # {'dataset_name': ["MOT17-13-FRCNN"], 35 | # 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot17_train_7_no_pretrain_deformable/checkpoint_best_BBOX_AP_IoU_0_50-0_95.pth'}, 36 | # ] 37 | 38 | # dataset_name = ["MOT17-02-FRCNN", "MOT17-04-FRCNN", "MOT17-05-FRCNN", "MOT17-09-FRCNN", "MOT17-10-FRCNN", "MOT17-11-FRCNN", "MOT17-13-FRCNN"] 39 | 40 | # general_tracker_cfg = {'public_detections': False, 'reid_sim_only': True, 'reid_greedy_matching': False} 41 | general_tracker_cfg = {'public_detections': 'min_iou_0_5'} 42 | # general_tracker_cfg = {'public_detections': False} 43 | 44 | # dataset_name = 'MOT17-TRAIN-FRCNN' 45 | dataset_name = 'MOT17-TRAIN-ALL' 46 | # dataset_name = 'MOT20-TRAIN' 47 | 48 | configs = [ 49 | {'dataset_name': dataset_name, 50 | 51 | 'frame_range': {'start': 0.5}, 52 | 'obj_detect_checkpoint_file': '/storage/user/meinhard/fair_track/models/mot_mot17_train_cross_val_frame_0_0_to_0_5_coco_pretrained_num_queries_500_batch_size=2_num_gpus_7_num_classes_20_AP_det_overflow_boxes_True_prev_frame_rnd_augs_0_2_uniform_false_negative_prob_multi_frame_hidden_dim_288_sep_encoders_batch_queries/checkpoint_epoch_50.pth'}, 53 | ] 54 | 55 | tracker_param_grids = { 56 | # 'detection_obj_score_thresh': [0.3, 0.4, 0.5, 0.6], 57 | # 'track_obj_score_thresh': [0.3, 0.4, 0.5, 0.6], 58 | 'detection_obj_score_thresh': [0.4], 59 | 
'track_obj_score_thresh': [0.4], 60 | # 'detection_nms_thresh': [0.95, 0.9, 0.0], 61 | # 'track_nms_thresh': [0.95, 0.9, 0.0], 62 | # 'detection_nms_thresh': [0.9], 63 | # 'track_nms_thresh': [0.9], 64 | # 'reid_sim_threshold': [0.0, 0.5, 1.0, 10, 50, 100, 200], 65 | 'reid_score_thresh': [0.4], 66 | # 'inactive_patience': [-1, 5, 10, 20, 30, 40, 50] 67 | # 'reid_score_thresh': [0.8], 68 | # 'inactive_patience': [-1], 69 | # 'inactive_patience': [-1, 5, 10] 70 | } 71 | 72 | # compute all config combinations 73 | tracker_param_cfgs = [dict(zip(tracker_param_grids, v)) 74 | for v in product(*tracker_param_grids.values())] 75 | 76 | # add empty metric arrays 77 | metrics = ['mota', 'idf1'] 78 | tracker_param_cfgs = [ 79 | {'config': {**general_tracker_cfg, **tracker_cfg}} 80 | for tracker_cfg in tracker_param_cfgs] 81 | 82 | for m in metrics: 83 | for tracker_cfg in tracker_param_cfgs: 84 | tracker_cfg[m] = [] 85 | 86 | total_num_experiments = len(tracker_param_cfgs) * len(configs) 87 | print(f'NUM experiments: {total_num_experiments}') 88 | 89 | # run all tracker config combinations for all experiment configurations 90 | exp_counter = 1 91 | for config in configs: 92 | for tracker_cfg in tracker_param_cfgs: 93 | print(f"EXPERIMENT: {exp_counter}/{total_num_experiments}") 94 | 95 | config['tracker_cfg'] = tracker_cfg['config'] 96 | run = ex.run(config_updates=config) 97 | eval_summary = run.result 98 | 99 | for m in metrics: 100 | tracker_cfg[m].append(eval_summary[m]['OVERALL']) 101 | 102 | exp_counter += 1 103 | 104 | # compute mean for all metrices 105 | for m in metrics: 106 | for tracker_cfg in tracker_param_cfgs: 107 | tracker_cfg[m] = np.array(tracker_cfg[m]).mean() 108 | 109 | for cfg in tracker_param_cfgs: 110 | print([cfg[m] for m in metrics], cfg['config']) 111 | 112 | # compute and plot best metric config 113 | for m in metrics: 114 | best_metric_cfg_idx = np.array( 115 | [cfg[m] for cfg in tracker_param_cfgs]).argmax() 116 | 117 | print(f"BEST {m.upper()} CFG: {tracker_param_cfgs[best_metric_cfg_idx]['config']}") 118 | 119 | # TODO 120 | best_mota_plus_idf1_cfg_idx = np.array( 121 | [cfg['mota'] + cfg['idf1'] for cfg in tracker_param_cfgs]).argmax() 122 | print(f"BEST MOTA PLUS IDF1 CFG: {tracker_param_cfgs[best_mota_plus_idf1_cfg_idx]['config']}") 123 | -------------------------------------------------------------------------------- /src/trackformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timmeinhardt/trackformer/e468bf156b029869f6de1be358bc11cd1f517f3c/src/trackformer/__init__.py -------------------------------------------------------------------------------- /src/trackformer/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Submodule interface. 
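Illustrative sketch (variable names are assumptions, the function itself is defined below): build_dataset is meant as the entry point for constructing the train and val datasets from the parsed config, dispatching on args.dataset:

    # args.dataset selects the branch, e.g. 'coco', 'mot', 'crowdhuman', 'mot_crowdhuman'
    dataset_train = build_dataset('train', args)
    dataset_val = build_dataset('val', args)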
4 | """ 5 | from argparse import Namespace 6 | from pycocotools.coco import COCO 7 | from torch.utils.data import Dataset, Subset 8 | from torchvision.datasets import CocoDetection 9 | 10 | from .coco import build as build_coco 11 | from .crowdhuman import build_crowdhuman 12 | from .mot import build_mot, build_mot_crowdhuman, build_mot_coco_person 13 | 14 | 15 | def get_coco_api_from_dataset(dataset: Subset) -> COCO: 16 | """Return COCO class from PyTorch dataset for evaluation with COCO eval.""" 17 | for _ in range(10): 18 | # if isinstance(dataset, CocoDetection): 19 | # break 20 | if isinstance(dataset, Subset): 21 | dataset = dataset.dataset 22 | 23 | if not isinstance(dataset, CocoDetection): 24 | raise NotImplementedError 25 | 26 | return dataset.coco 27 | 28 | 29 | def build_dataset(split: str, args: Namespace) -> Dataset: 30 | """Helper function to build dataset for different splits ('train' or 'val').""" 31 | if args.dataset == 'coco': 32 | dataset = build_coco(split, args) 33 | elif args.dataset == 'coco_person': 34 | dataset = build_coco(split, args, 'person_keypoints') 35 | elif args.dataset == 'mot': 36 | dataset = build_mot(split, args) 37 | elif args.dataset == 'crowdhuman': 38 | dataset = build_crowdhuman(split, args) 39 | elif args.dataset == 'mot_crowdhuman': 40 | dataset = build_mot_crowdhuman(split, args) 41 | elif args.dataset == 'mot_coco_person': 42 | dataset = build_mot_coco_person(split, args) 43 | elif args.dataset == 'coco_panoptic': 44 | # to avoid making panopticapi required for coco 45 | from .coco_panoptic import build as build_coco_panoptic 46 | dataset = build_coco_panoptic(split, args) 47 | else: 48 | raise ValueError(f'dataset {args.dataset} not supported') 49 | 50 | return dataset 51 | -------------------------------------------------------------------------------- /src/trackformer/datasets/coco_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | COCO evaluator that works in distributed mode. 
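A typical evaluation loop looks roughly like the following (hedged sketch with illustrative variable names; the real call sites live in the training/evaluation code):

    coco_evaluator = CocoEvaluator(coco_gt, iou_types=('bbox',))
    for predictions in predictions_per_batch:
        # predictions: {image_id: {'boxes': ..., 'scores': ..., 'labels': ...}}
        coco_evaluator.update(predictions)
    coco_evaluator.synchronize_between_processes()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()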
4 | 5 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 6 | The difference is that there is less copy-pasting from pycocotools 7 | in the end of the file, as python3 can suppress prints with contextlib 8 | """ 9 | import os 10 | import contextlib 11 | import copy 12 | import numpy as np 13 | import torch 14 | 15 | from pycocotools.cocoeval import COCOeval 16 | from pycocotools.coco import COCO 17 | import pycocotools.mask as mask_util 18 | 19 | from ..util.misc import all_gather 20 | 21 | 22 | class CocoEvaluator(object): 23 | def __init__(self, coco_gt, iou_types): 24 | assert isinstance(iou_types, (list, tuple)) 25 | coco_gt = copy.deepcopy(coco_gt) 26 | self.coco_gt = coco_gt 27 | 28 | self.iou_types = iou_types 29 | self.coco_eval = {} 30 | for iou_type in iou_types: 31 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 32 | 33 | self.img_ids = [] 34 | self.eval_imgs = {k: [] for k in iou_types} 35 | 36 | def update(self, predictions): 37 | img_ids = list(np.unique(list(predictions.keys()))) 38 | self.img_ids.extend(img_ids) 39 | 40 | for prediction in predictions.values(): 41 | prediction["labels"] += 1 42 | 43 | for iou_type in self.iou_types: 44 | results = self.prepare(predictions, iou_type) 45 | 46 | # suppress pycocotools prints 47 | with open(os.devnull, 'w') as devnull: 48 | with contextlib.redirect_stdout(devnull): 49 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 50 | coco_eval = self.coco_eval[iou_type] 51 | 52 | coco_eval.cocoDt = coco_dt 53 | coco_eval.params.imgIds = list(img_ids) 54 | img_ids, eval_imgs = evaluate(coco_eval) 55 | 56 | self.eval_imgs[iou_type].append(eval_imgs) 57 | 58 | def synchronize_between_processes(self): 59 | for iou_type in self.iou_types: 60 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 61 | create_common_coco_eval( 62 | self.coco_eval[iou_type], 63 | self.img_ids, 64 | self.eval_imgs[iou_type]) 65 | 66 | def accumulate(self): 67 | for coco_eval in self.coco_eval.values(): 68 | coco_eval.accumulate() 69 | 70 | def summarize(self): 71 | for iou_type, coco_eval in self.coco_eval.items(): 72 | print(f"IoU metric: {iou_type}") 73 | coco_eval.summarize() 74 | 75 | def prepare(self, predictions, iou_type): 76 | if iou_type == "bbox": 77 | return self.prepare_for_coco_detection(predictions) 78 | elif iou_type == "segm": 79 | return self.prepare_for_coco_segmentation(predictions) 80 | elif iou_type == "keypoints": 81 | return self.prepare_for_coco_keypoint(predictions) 82 | else: 83 | raise ValueError("Unknown iou type {}".format(iou_type)) 84 | 85 | def prepare_for_coco_detection(self, predictions): 86 | coco_results = [] 87 | for original_id, prediction in predictions.items(): 88 | if len(prediction) == 0: 89 | continue 90 | 91 | boxes = prediction["boxes"] 92 | boxes = convert_to_xywh(boxes).tolist() 93 | scores = prediction["scores"].tolist() 94 | labels = prediction["labels"].tolist() 95 | 96 | coco_results.extend( 97 | [ 98 | { 99 | "image_id": original_id, 100 | "category_id": labels[k], 101 | "bbox": box, 102 | "score": scores[k], 103 | } 104 | for k, box in enumerate(boxes) 105 | ] 106 | ) 107 | return coco_results 108 | 109 | def prepare_for_coco_segmentation(self, predictions): 110 | coco_results = [] 111 | for original_id, prediction in predictions.items(): 112 | if len(prediction) == 0: 113 | continue 114 | 115 | scores = prediction["scores"] 116 | labels = prediction["labels"] 117 | masks = prediction["masks"] 118 | 119 | 
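# Descriptive note: the soft mask predictions are binarized at a fixed 0.5
# threshold and converted below into COCO run-length encodings (RLE) via
# pycocotools; mask_util.encode expects Fortran-contiguous uint8 arrays
# (hence order="F"), and the RLE byte counts are decoded to utf-8 so the
# results stay JSON-serializable.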
masks = masks > 0.5 120 | 121 | scores = prediction["scores"].tolist() 122 | labels = prediction["labels"].tolist() 123 | 124 | rles = [ 125 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 126 | for mask in masks 127 | ] 128 | for rle in rles: 129 | rle["counts"] = rle["counts"].decode("utf-8") 130 | 131 | coco_results.extend( 132 | [ 133 | { 134 | "image_id": original_id, 135 | "category_id": labels[k], 136 | "segmentation": rle, 137 | "score": scores[k], 138 | } 139 | for k, rle in enumerate(rles) 140 | ] 141 | ) 142 | return coco_results 143 | 144 | def prepare_for_coco_keypoint(self, predictions): 145 | coco_results = [] 146 | for original_id, prediction in predictions.items(): 147 | if len(prediction) == 0: 148 | continue 149 | 150 | boxes = prediction["boxes"] 151 | boxes = convert_to_xywh(boxes).tolist() 152 | scores = prediction["scores"].tolist() 153 | labels = prediction["labels"].tolist() 154 | keypoints = prediction["keypoints"] 155 | keypoints = keypoints.flatten(start_dim=1).tolist() 156 | 157 | coco_results.extend( 158 | [ 159 | { 160 | "image_id": original_id, 161 | "category_id": labels[k], 162 | 'keypoints': keypoint, 163 | "score": scores[k], 164 | } 165 | for k, keypoint in enumerate(keypoints) 166 | ] 167 | ) 168 | return coco_results 169 | 170 | 171 | def convert_to_xywh(boxes): 172 | xmin, ymin, xmax, ymax = boxes.unbind(1) 173 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 174 | 175 | 176 | def merge(img_ids, eval_imgs): 177 | all_img_ids = all_gather(img_ids) 178 | all_eval_imgs = all_gather(eval_imgs) 179 | 180 | merged_img_ids = [] 181 | for p in all_img_ids: 182 | merged_img_ids.extend(p) 183 | 184 | merged_eval_imgs = [] 185 | for p in all_eval_imgs: 186 | merged_eval_imgs.append(p) 187 | 188 | merged_img_ids = np.array(merged_img_ids) 189 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 190 | 191 | # keep only unique (and in sorted order) images 192 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 193 | merged_eval_imgs = merged_eval_imgs[..., idx] 194 | 195 | return merged_img_ids, merged_eval_imgs 196 | 197 | 198 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 199 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 200 | img_ids = list(img_ids) 201 | eval_imgs = list(eval_imgs.flatten()) 202 | 203 | coco_eval.evalImgs = eval_imgs 204 | coco_eval.params.imgIds = img_ids 205 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 206 | 207 | 208 | ################################################################# 209 | # From pycocotools, just removed the prints and fixed 210 | # a Python3 bug about unicode not defined 211 | ################################################################# 212 | 213 | 214 | def evaluate(self): 215 | ''' 216 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 217 | :return: None 218 | ''' 219 | # tic = time.time() 220 | # print('Running per image evaluation...') 221 | p = self.params 222 | # add backward compatibility if useSegm is specified in params 223 | if p.useSegm is not None: 224 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 225 | print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) 226 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 227 | p.imgIds = list(np.unique(p.imgIds)) 228 | if p.useCats: 229 | p.catIds = list(np.unique(p.catIds)) 230 | p.maxDets = sorted(p.maxDets) 231 | self.params = p 232 | 233 | self._prepare() 234 | # loop through images, area range, max detection number 235 | catIds = p.catIds if p.useCats else [-1] 236 | 237 | if p.iouType == 'segm' or p.iouType == 'bbox': 238 | computeIoU = self.computeIoU 239 | elif p.iouType == 'keypoints': 240 | computeIoU = self.computeOks 241 | self.ious = { 242 | (imgId, catId): computeIoU(imgId, catId) 243 | for imgId in p.imgIds 244 | for catId in catIds} 245 | 246 | evaluateImg = self.evaluateImg 247 | maxDet = p.maxDets[-1] 248 | evalImgs = [ 249 | evaluateImg(imgId, catId, areaRng, maxDet) 250 | for catId in catIds 251 | for areaRng in p.areaRng 252 | for imgId in p.imgIds 253 | ] 254 | # this is NOT in the pycocotools code, but could be done outside 255 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 256 | self._paramsEval = copy.deepcopy(self.params) 257 | # toc = time.time() 258 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 259 | return p.imgIds, evalImgs 260 | 261 | ################################################################# 262 | # end of straight copy from pycocotools, just removing the prints 263 | ################################################################# 264 | -------------------------------------------------------------------------------- /src/trackformer/datasets/coco_panoptic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import json 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | 9 | from panopticapi.utils import rgb2id 10 | from util.box_ops import masks_to_boxes 11 | 12 | from .coco import make_coco_transforms 13 | 14 | 15 | class CocoPanoptic: 16 | def __init__(self, img_folder, ann_folder, ann_file, transforms=None, norm_transforms=None, return_masks=True): 17 | with open(ann_file, 'r') as f: 18 | self.coco = json.load(f) 19 | 20 | # sort 'images' field so that they are aligned with 'annotations' 21 | # i.e., in alphabetical order 22 | self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) 23 | # sanity check 24 | if "annotations" in self.coco: 25 | for img, ann in zip(self.coco['images'], self.coco['annotations']): 26 | assert img['file_name'][:-4] == ann['file_name'][:-4] 27 | 28 | self.img_folder = img_folder 29 | self.ann_folder = ann_folder 30 | self.ann_file = ann_file 31 | self.transforms = transforms 32 | self.norm_transforms = norm_transforms 33 | self.return_masks = return_masks 34 | 35 | def __getitem__(self, idx): 36 | ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] 37 | img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') 38 | ann_path = Path(self.ann_folder) / ann_info['file_name'] 39 | 40 | img = Image.open(img_path).convert('RGB') 41 | w, h = img.size 42 | if "segments_info" in ann_info: 43 | masks = np.asarray(Image.open(ann_path), dtype=np.uint32) 44 | masks = rgb2id(masks) 45 | 46 | ids = np.array([ann['id'] for ann in ann_info['segments_info']]) 47 | masks = masks == ids[:, None, None] 48 | 49 | masks = torch.as_tensor(masks, dtype=torch.uint8) 50 | labels = torch.tensor([ann['category_id'] for ann in 
ann_info['segments_info']], dtype=torch.int64) 51 | 52 | target = {} 53 | target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) 54 | if self.return_masks: 55 | target['masks'] = masks 56 | target['labels'] = labels 57 | 58 | target["boxes"] = masks_to_boxes(masks) 59 | 60 | target['size'] = torch.as_tensor([int(h), int(w)]) 61 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 62 | if "segments_info" in ann_info: 63 | for name in ['iscrowd', 'area']: 64 | target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) 65 | 66 | if self.transforms is not None: 67 | img, target = self.transforms(img, target) 68 | if self.norm_transforms is not None: 69 | img, target = self.norm_transforms(img, target) 70 | 71 | return img, target 72 | 73 | def __len__(self): 74 | return len(self.coco['images']) 75 | 76 | def get_height_and_width(self, idx): 77 | img_info = self.coco['images'][idx] 78 | height = img_info['height'] 79 | width = img_info['width'] 80 | return height, width 81 | 82 | 83 | def build(image_set, args): 84 | img_folder_root = Path(args.coco_path) 85 | ann_folder_root = Path(args.coco_panoptic_path) 86 | assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' 87 | assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' 88 | mode = 'panoptic' 89 | PATHS = { 90 | "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), 91 | "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), 92 | } 93 | 94 | img_folder, ann_file = PATHS[image_set] 95 | img_folder_path = img_folder_root / img_folder 96 | ann_folder = ann_folder_root / f'{mode}_{img_folder}' 97 | ann_file = ann_folder_root / ann_file 98 | 99 | transforms, norm_transforms = make_coco_transforms(image_set, args.img_transform, args.overflow_boxes) 100 | dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file, 101 | transforms=transforms, norm_transforms=norm_transforms, return_masks=args.masks) 102 | 103 | return dataset 104 | -------------------------------------------------------------------------------- /src/trackformer/datasets/crowdhuman.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | CrowdHuman dataset with tracking training augmentations. 
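Illustrative note: build_crowdhuman below expects a COCO-style layout under args.crowdhuman_path as produced by src/generate_coco_from_crowdhuman.py, roughly:

    data/CrowdHuman/
        train_val/                    # images of the chosen split
        annotations/train_val.json    # generated COCO annotations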
4 | """ 5 | from pathlib import Path 6 | 7 | from .coco import CocoDetection, make_coco_transforms 8 | 9 | 10 | def build_crowdhuman(image_set, args): 11 | root = Path(args.crowdhuman_path) 12 | assert root.exists(), f'provided COCO path {root} does not exist' 13 | 14 | split = getattr(args, f"{image_set}_split") 15 | 16 | img_folder = root / split 17 | ann_file = root / f'annotations/{split}.json' 18 | 19 | if image_set == 'train': 20 | prev_frame_rnd_augs = args.coco_and_crowdhuman_prev_frame_rnd_augs 21 | elif image_set == 'val': 22 | prev_frame_rnd_augs = 0.0 23 | 24 | transforms, norm_transforms = make_coco_transforms( 25 | image_set, args.img_transform, args.overflow_boxes) 26 | dataset = CocoDetection( 27 | img_folder, ann_file, transforms, norm_transforms, 28 | return_masks=args.masks, 29 | prev_frame=args.tracking, 30 | prev_frame_rnd_augs=prev_frame_rnd_augs) 31 | 32 | return dataset 33 | -------------------------------------------------------------------------------- /src/trackformer/datasets/mot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | MOT dataset with tracking training augmentations. 4 | """ 5 | import bisect 6 | import copy 7 | import csv 8 | import os 9 | import random 10 | from pathlib import Path 11 | 12 | import torch 13 | 14 | from . import transforms as T 15 | from .coco import CocoDetection, make_coco_transforms 16 | from .coco import build as build_coco 17 | from .crowdhuman import build_crowdhuman 18 | 19 | 20 | class MOT(CocoDetection): 21 | 22 | def __init__(self, *args, prev_frame_range=1, **kwargs): 23 | super(MOT, self).__init__(*args, **kwargs) 24 | 25 | self._prev_frame_range = prev_frame_range 26 | 27 | @property 28 | def sequences(self): 29 | return self.coco.dataset['sequences'] 30 | 31 | @property 32 | def frame_range(self): 33 | if 'frame_range' in self.coco.dataset: 34 | return self.coco.dataset['frame_range'] 35 | else: 36 | return {'start': 0, 'end': 1.0} 37 | 38 | def seq_length(self, idx): 39 | return self.coco.imgs[idx]['seq_length'] 40 | 41 | def sample_weight(self, idx): 42 | return 1.0 / self.seq_length(idx) 43 | 44 | def __getitem__(self, idx): 45 | random_state = { 46 | 'random': random.getstate(), 47 | 'torch': torch.random.get_rng_state()} 48 | 49 | img, target = self._getitem_from_id(idx, random_state, random_jitter=False) 50 | 51 | if self._prev_frame: 52 | frame_id = self.coco.imgs[idx]['frame_id'] 53 | 54 | # PREV 55 | # first frame has no previous frame 56 | prev_frame_id = random.randint( 57 | max(0, frame_id - self._prev_frame_range), 58 | min(frame_id + self._prev_frame_range, self.seq_length(idx) - 1)) 59 | prev_image_id = self.coco.imgs[idx]['first_frame_image_id'] + prev_frame_id 60 | 61 | prev_img, prev_target = self._getitem_from_id(prev_image_id, random_state) 62 | target[f'prev_image'] = prev_img 63 | target[f'prev_target'] = prev_target 64 | 65 | if self._prev_prev_frame: 66 | # PREV PREV frame equidistant as prev_frame 67 | prev_prev_frame_id = min(max(0, prev_frame_id + prev_frame_id - frame_id), self.seq_length(idx) - 1) 68 | prev_prev_image_id = self.coco.imgs[idx]['first_frame_image_id'] + prev_prev_frame_id 69 | 70 | prev_prev_img, prev_prev_target = self._getitem_from_id(prev_prev_image_id, random_state) 71 | target[f'prev_prev_image'] = prev_prev_img 72 | target[f'prev_prev_target'] = prev_prev_target 73 | 74 | return img, target 75 | 76 | def write_result_files(self, results, output_dir): 
77 | """Write the detections in the format for the MOT17Det sumbission 78 | 79 | Each file contains these lines: 80 | , , , , , , , , , 81 | 82 | """ 83 | 84 | files = {} 85 | for image_id, res in results.items(): 86 | img = self.coco.loadImgs(image_id)[0] 87 | file_name_without_ext = os.path.splitext(img['file_name'])[0] 88 | seq_name, frame = file_name_without_ext.split('_') 89 | frame = int(frame) 90 | 91 | outfile = os.path.join(output_dir, f"{seq_name}.txt") 92 | 93 | # check if out in keys and create empty list if not 94 | if outfile not in files.keys(): 95 | files[outfile] = [] 96 | 97 | for box, score in zip(res['boxes'], res['scores']): 98 | if score <= 0.7: 99 | continue 100 | x1 = box[0].item() 101 | y1 = box[1].item() 102 | x2 = box[2].item() 103 | y2 = box[3].item() 104 | files[outfile].append( 105 | [frame, -1, x1, y1, x2 - x1, y2 - y1, score.item(), -1, -1, -1]) 106 | 107 | for k, v in files.items(): 108 | with open(k, "w") as of: 109 | writer = csv.writer(of, delimiter=',') 110 | for d in v: 111 | writer.writerow(d) 112 | 113 | 114 | class WeightedConcatDataset(torch.utils.data.ConcatDataset): 115 | 116 | def sample_weight(self, idx): 117 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 118 | if dataset_idx == 0: 119 | sample_idx = idx 120 | else: 121 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 122 | 123 | if hasattr(self.datasets[dataset_idx], 'sample_weight'): 124 | return self.datasets[dataset_idx].sample_weight(sample_idx) 125 | else: 126 | return 1 / len(self.datasets[dataset_idx]) 127 | 128 | 129 | def build_mot(image_set, args): 130 | if image_set == 'train': 131 | root = Path(args.mot_path_train) 132 | prev_frame_rnd_augs = args.track_prev_frame_rnd_augs 133 | prev_frame_range=args.track_prev_frame_range 134 | elif image_set == 'val': 135 | root = Path(args.mot_path_val) 136 | prev_frame_rnd_augs = 0.0 137 | prev_frame_range = 1 138 | else: 139 | ValueError(f'unknown {image_set}') 140 | 141 | assert root.exists(), f'provided MOT17Det path {root} does not exist' 142 | 143 | split = getattr(args, f"{image_set}_split") 144 | 145 | img_folder = root / split 146 | ann_file = root / f"annotations/{split}.json" 147 | 148 | transforms, norm_transforms = make_coco_transforms( 149 | image_set, args.img_transform, args.overflow_boxes) 150 | 151 | dataset = MOT( 152 | img_folder, ann_file, transforms, norm_transforms, 153 | prev_frame_range=prev_frame_range, 154 | return_masks=args.masks, 155 | overflow_boxes=args.overflow_boxes, 156 | remove_no_obj_imgs=False, 157 | prev_frame=args.tracking, 158 | prev_frame_rnd_augs=prev_frame_rnd_augs, 159 | prev_prev_frame=args.track_prev_prev_frame, 160 | ) 161 | 162 | return dataset 163 | 164 | 165 | def build_mot_crowdhuman(image_set, args): 166 | if image_set == 'train': 167 | args_crowdhuman = copy.deepcopy(args) 168 | args_crowdhuman.train_split = args.crowdhuman_train_split 169 | 170 | crowdhuman_dataset = build_crowdhuman('train', args_crowdhuman) 171 | 172 | if getattr(args, f"{image_set}_split") is None: 173 | return crowdhuman_dataset 174 | 175 | dataset = build_mot(image_set, args) 176 | 177 | if image_set == 'train': 178 | dataset = torch.utils.data.ConcatDataset( 179 | [dataset, crowdhuman_dataset]) 180 | 181 | return dataset 182 | 183 | 184 | def build_mot_coco_person(image_set, args): 185 | if image_set == 'train': 186 | args_coco_person = copy.deepcopy(args) 187 | args_coco_person.train_split = args.coco_person_train_split 188 | 189 | coco_person_dataset = build_coco('train', args_coco_person, 
'person_keypoints') 190 | 191 | if getattr(args, f"{image_set}_split") is None: 192 | return coco_person_dataset 193 | 194 | dataset = build_mot(image_set, args) 195 | 196 | if image_set == 'train': 197 | dataset = torch.utils.data.ConcatDataset( 198 | [dataset, coco_person_dataset]) 199 | 200 | return dataset 201 | -------------------------------------------------------------------------------- /src/trackformer/datasets/panoptic_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import json 3 | import os 4 | 5 | from ..util import misc as utils 6 | 7 | try: 8 | from panopticapi.evaluation import pq_compute 9 | except ImportError: 10 | pass 11 | 12 | 13 | class PanopticEvaluator(object): 14 | def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): 15 | self.gt_json = ann_file 16 | self.gt_folder = ann_folder 17 | if utils.is_main_process(): 18 | if not os.path.exists(output_dir): 19 | os.mkdir(output_dir) 20 | self.output_dir = output_dir 21 | self.predictions = [] 22 | 23 | def update(self, predictions): 24 | for p in predictions: 25 | with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: 26 | f.write(p.pop("png_string")) 27 | 28 | self.predictions += predictions 29 | 30 | def synchronize_between_processes(self): 31 | all_predictions = utils.all_gather(self.predictions) 32 | merged_predictions = [] 33 | for p in all_predictions: 34 | merged_predictions += p 35 | self.predictions = merged_predictions 36 | 37 | def summarize(self): 38 | if utils.is_main_process(): 39 | json_data = {"annotations": self.predictions} 40 | predictions_json = os.path.join(self.output_dir, "predictions.json") 41 | with open(predictions_json, "w") as f: 42 | f.write(json.dumps(json_data)) 43 | return pq_compute( 44 | self.gt_json, predictions_json, 45 | gt_folder=self.gt_folder, pred_folder=self.output_dir) 46 | return None 47 | -------------------------------------------------------------------------------- /src/trackformer/datasets/tracking/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Submodule interface. 4 | """ 5 | from .factory import TrackDatasetFactory 6 | -------------------------------------------------------------------------------- /src/trackformer/datasets/tracking/demo_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | MOT17 sequence dataset. 4 | """ 5 | import configparser 6 | import csv 7 | import os 8 | from pathlib import Path 9 | import os.path as osp 10 | from argparse import Namespace 11 | from typing import Optional, Tuple, List 12 | 13 | import numpy as np 14 | import torch 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | 18 | from ..coco import make_coco_transforms 19 | from ..transforms import Compose 20 | 21 | 22 | class DemoSequence(Dataset): 23 | """DemoSequence (MOT17) Dataset. 
24 | """ 25 | 26 | def __init__(self, root_dir: str = 'data', img_transform: Namespace = None) -> None: 27 | """ 28 | Args: 29 | seq_name (string): Sequence to take 30 | vis_threshold (float): Threshold of visibility of persons 31 | above which they are selected 32 | """ 33 | super().__init__() 34 | 35 | self._data_dir = Path(root_dir) 36 | assert self._data_dir.is_dir(), f'data_root_dir:{root_dir} does not exist.' 37 | 38 | self.transforms = Compose(make_coco_transforms('val', img_transform, overflow_boxes=True)) 39 | 40 | self.data = self._sequence() 41 | self.no_gt = True 42 | 43 | def __len__(self) -> int: 44 | return len(self.data) 45 | 46 | def __str__(self) -> str: 47 | return self._data_dir.name 48 | 49 | def __getitem__(self, idx: int) -> dict: 50 | """Return the ith image converted to blob""" 51 | data = self.data[idx] 52 | img = Image.open(data['im_path']).convert("RGB") 53 | width_orig, height_orig = img.size 54 | 55 | img, _ = self.transforms(img) 56 | width, height = img.size(2), img.size(1) 57 | 58 | sample = {} 59 | sample['img'] = img 60 | sample['img_path'] = data['im_path'] 61 | sample['dets'] = torch.tensor([]) 62 | sample['orig_size'] = torch.as_tensor([int(height_orig), int(width_orig)]) 63 | sample['size'] = torch.as_tensor([int(height), int(width)]) 64 | 65 | return sample 66 | 67 | def _sequence(self) -> List[dict]: 68 | total = [] 69 | for filename in sorted(os.listdir(self._data_dir)): 70 | extension = os.path.splitext(filename)[1] 71 | if extension in ['.png', '.jpg']: 72 | total.append({'im_path': osp.join(self._data_dir, filename)}) 73 | 74 | return total 75 | 76 | def load_results(self, results_dir: str) -> dict: 77 | return {} 78 | 79 | def write_results(self, results: dict, output_dir: str) -> None: 80 | """Write the tracks in the format for MOT16/MOT17 sumbission 81 | 82 | results: dictionary with 1 dictionary for every track with 83 | {..., i:np.array([x1,y1,x2,y2]), ...} at key track_num 84 | 85 | Each file contains these lines: 86 | , , , , , , , , , 87 | """ 88 | 89 | # format_str = "{}, -1, {}, {}, {}, {}, {}, -1, -1, -1" 90 | if not os.path.exists(output_dir): 91 | os.makedirs(output_dir) 92 | 93 | result_file_path = osp.join(output_dir, self._data_dir.name) 94 | 95 | with open(result_file_path, "w") as r_file: 96 | writer = csv.writer(r_file, delimiter=',') 97 | 98 | for i, track in results.items(): 99 | for frame, data in track.items(): 100 | x1 = data['bbox'][0] 101 | y1 = data['bbox'][1] 102 | x2 = data['bbox'][2] 103 | y2 = data['bbox'][3] 104 | 105 | writer.writerow([ 106 | frame + 1, 107 | i + 1, 108 | x1 + 1, 109 | y1 + 1, 110 | x2 - x1 + 1, 111 | y2 - y1 + 1, 112 | -1, -1, -1, -1]) 113 | -------------------------------------------------------------------------------- /src/trackformer/datasets/tracking/factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Factory of tracking datasets. 4 | """ 5 | from typing import Union 6 | 7 | from torch.utils.data import ConcatDataset 8 | 9 | from .demo_sequence import DemoSequence 10 | from .mot_wrapper import MOT17Wrapper, MOT20Wrapper, MOTS20Wrapper 11 | 12 | DATASETS = {} 13 | 14 | # Fill all available datasets, change here to modify / add new datasets. 
15 | for split in ['TRAIN', 'TEST', 'ALL', '01', '02', '03', '04', '05', 16 | '06', '07', '08', '09', '10', '11', '12', '13', '14']: 17 | for dets in ['DPM', 'FRCNN', 'SDP', 'ALL']: 18 | name = f'MOT17-{split}' 19 | if dets: 20 | name = f"{name}-{dets}" 21 | DATASETS[name] = ( 22 | lambda kwargs, split=split, dets=dets: MOT17Wrapper(split, dets, **kwargs)) 23 | 24 | 25 | for split in ['TRAIN', 'TEST', 'ALL', '01', '02', '03', '04', '05', 26 | '06', '07', '08']: 27 | name = f'MOT20-{split}' 28 | DATASETS[name] = ( 29 | lambda kwargs, split=split: MOT20Wrapper(split, **kwargs)) 30 | 31 | 32 | for split in ['TRAIN', 'TEST', 'ALL', '01', '02', '05', '06', '07', '09', '11', '12']: 33 | name = f'MOTS20-{split}' 34 | DATASETS[name] = ( 35 | lambda kwargs, split=split: MOTS20Wrapper(split, **kwargs)) 36 | 37 | DATASETS['DEMO'] = (lambda kwargs: [DemoSequence(**kwargs), ]) 38 | 39 | 40 | class TrackDatasetFactory: 41 | """A central class to manage the individual dataset loaders. 42 | 43 | This class contains the datasets. Once initialized the individual parts (e.g. sequences) 44 | can be accessed. 45 | """ 46 | 47 | def __init__(self, datasets: Union[str, list], **kwargs) -> None: 48 | """Initialize the corresponding dataloader. 49 | 50 | Keyword arguments: 51 | datasets -- the name of the dataset or list of dataset names 52 | kwargs -- arguments used to call the datasets 53 | """ 54 | if isinstance(datasets, str): 55 | datasets = [datasets] 56 | 57 | self._data = None 58 | for dataset in datasets: 59 | assert dataset in DATASETS, f"[!] Dataset not found: {dataset}" 60 | 61 | if self._data is None: 62 | self._data = DATASETS[dataset](kwargs) 63 | else: 64 | self._data = ConcatDataset([self._data, DATASETS[dataset](kwargs)]) 65 | 66 | def __len__(self) -> int: 67 | return len(self._data) 68 | 69 | def __getitem__(self, idx: int): 70 | return self._data[idx] 71 | -------------------------------------------------------------------------------- /src/trackformer/datasets/tracking/mot17_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | MOT17 sequence dataset. 4 | """ 5 | import configparser 6 | import csv 7 | import os 8 | import os.path as osp 9 | from argparse import Namespace 10 | from typing import Optional, Tuple, List 11 | 12 | import numpy as np 13 | import torch 14 | from PIL import Image 15 | from torch.utils.data import Dataset 16 | 17 | from ..coco import make_coco_transforms 18 | from ..transforms import Compose 19 | 20 | 21 | class MOT17Sequence(Dataset): 22 | """Multiple Object Tracking (MOT17) Dataset. 23 | 24 | This dataloader is designed so that it can handle only one sequence, 25 | if more have to be handled one should inherit from this class. 
26 | """ 27 | data_folder = 'MOT17' 28 | 29 | def __init__(self, root_dir: str = 'data', seq_name: Optional[str] = None, 30 | dets: str = '', vis_threshold: float = 0.0, img_transform: Namespace = None) -> None: 31 | """ 32 | Args: 33 | seq_name (string): Sequence to take 34 | vis_threshold (float): Threshold of visibility of persons 35 | above which they are selected 36 | """ 37 | super().__init__() 38 | 39 | self._seq_name = seq_name 40 | self._dets = dets 41 | self._vis_threshold = vis_threshold 42 | 43 | self._data_dir = osp.join(root_dir, self.data_folder) 44 | 45 | self._train_folders = os.listdir(os.path.join(self._data_dir, 'train')) 46 | self._test_folders = os.listdir(os.path.join(self._data_dir, 'test')) 47 | 48 | self.transforms = Compose(make_coco_transforms('val', img_transform, overflow_boxes=True)) 49 | 50 | self.data = [] 51 | self.no_gt = True 52 | if seq_name is not None: 53 | full_seq_name = seq_name 54 | if self._dets is not None: 55 | full_seq_name = f"{seq_name}-{dets}" 56 | assert full_seq_name in self._train_folders or full_seq_name in self._test_folders, \ 57 | 'Image set does not exist: {}'.format(full_seq_name) 58 | 59 | self.data = self._sequence() 60 | self.no_gt = not osp.exists(self.get_gt_file_path()) 61 | 62 | def __len__(self) -> int: 63 | return len(self.data) 64 | 65 | def __getitem__(self, idx: int) -> dict: 66 | """Return the ith image converted to blob""" 67 | data = self.data[idx] 68 | img = Image.open(data['im_path']).convert("RGB") 69 | width_orig, height_orig = img.size 70 | 71 | img, _ = self.transforms(img) 72 | width, height = img.size(2), img.size(1) 73 | 74 | sample = {} 75 | sample['img'] = img 76 | sample['dets'] = torch.tensor([det[:4] for det in data['dets']]) 77 | sample['img_path'] = data['im_path'] 78 | sample['gt'] = data['gt'] 79 | sample['vis'] = data['vis'] 80 | sample['orig_size'] = torch.as_tensor([int(height_orig), int(width_orig)]) 81 | sample['size'] = torch.as_tensor([int(height), int(width)]) 82 | 83 | return sample 84 | 85 | def _sequence(self) -> List[dict]: 86 | # public detections 87 | dets = {i: [] for i in range(1, self.seq_length + 1)} 88 | det_file = self.get_det_file_path() 89 | 90 | if osp.exists(det_file): 91 | with open(det_file, "r") as inf: 92 | reader = csv.reader(inf, delimiter=',') 93 | for row in reader: 94 | x1 = float(row[2]) - 1 95 | y1 = float(row[3]) - 1 96 | # This -1 accounts for the width (width of 1 x1=x2) 97 | x2 = x1 + float(row[4]) - 1 98 | y2 = y1 + float(row[5]) - 1 99 | score = float(row[6]) 100 | bbox = np.array([x1, y1, x2, y2, score], dtype=np.float32) 101 | dets[int(float(row[0]))].append(bbox) 102 | 103 | # accumulate total 104 | img_dir = osp.join( 105 | self.get_seq_path(), 106 | self.config['Sequence']['imDir']) 107 | 108 | boxes, visibility = self.get_track_boxes_and_visbility() 109 | 110 | total = [ 111 | {'gt': boxes[i], 112 | 'im_path': osp.join(img_dir, f"{i:06d}.jpg"), 113 | 'vis': visibility[i], 114 | 'dets': dets[i]} 115 | for i in range(1, self.seq_length + 1)] 116 | 117 | return total 118 | 119 | def get_track_boxes_and_visbility(self) -> Tuple[dict, dict]: 120 | """ Load ground truth boxes and their visibility.""" 121 | boxes = {} 122 | visibility = {} 123 | 124 | for i in range(1, self.seq_length + 1): 125 | boxes[i] = {} 126 | visibility[i] = {} 127 | 128 | gt_file = self.get_gt_file_path() 129 | if not osp.exists(gt_file): 130 | return boxes, visibility 131 | 132 | with open(gt_file, "r") as inf: 133 | reader = csv.reader(inf, delimiter=',') 134 | for row in reader: 135 
| # class person, certainity 1 136 | if int(row[6]) == 1 and int(row[7]) == 1 and float(row[8]) >= self._vis_threshold: 137 | # Make pixel indexes 0-based, should already be 0-based (or not) 138 | x1 = int(row[2]) - 1 139 | y1 = int(row[3]) - 1 140 | # This -1 accounts for the width (width of 1 x1=x2) 141 | x2 = x1 + int(row[4]) - 1 142 | y2 = y1 + int(row[5]) - 1 143 | bbox = np.array([x1, y1, x2, y2], dtype=np.float32) 144 | 145 | frame_id = int(row[0]) 146 | track_id = int(row[1]) 147 | 148 | boxes[frame_id][track_id] = bbox 149 | visibility[frame_id][track_id] = float(row[8]) 150 | 151 | return boxes, visibility 152 | 153 | def get_seq_path(self) -> str: 154 | """ Return directory path of sequence. """ 155 | full_seq_name = self._seq_name 156 | if self._dets is not None: 157 | full_seq_name = f"{self._seq_name}-{self._dets}" 158 | 159 | if full_seq_name in self._train_folders: 160 | return osp.join(self._data_dir, 'train', full_seq_name) 161 | else: 162 | return osp.join(self._data_dir, 'test', full_seq_name) 163 | 164 | def get_config_file_path(self) -> str: 165 | """ Return config file of sequence. """ 166 | return osp.join(self.get_seq_path(), 'seqinfo.ini') 167 | 168 | def get_gt_file_path(self) -> str: 169 | """ Return ground truth file of sequence. """ 170 | return osp.join(self.get_seq_path(), 'gt', 'gt.txt') 171 | 172 | def get_det_file_path(self) -> str: 173 | """ Return public detections file of sequence. """ 174 | if self._dets is None: 175 | return "" 176 | 177 | return osp.join(self.get_seq_path(), 'det', 'det.txt') 178 | 179 | @property 180 | def config(self) -> dict: 181 | """ Return config of sequence. """ 182 | config_file = self.get_config_file_path() 183 | 184 | assert osp.exists(config_file), \ 185 | f'Config file does not exist: {config_file}' 186 | 187 | config = configparser.ConfigParser() 188 | config.read(config_file) 189 | return config 190 | 191 | @property 192 | def seq_length(self) -> int: 193 | """ Return sequence length, i.e, number of frames. """ 194 | return int(self.config['Sequence']['seqLength']) 195 | 196 | def __str__(self) -> str: 197 | return f"{self._seq_name}-{self._dets}" 198 | 199 | @property 200 | def results_file_name(self) -> str: 201 | """ Generate file name of results file. """ 202 | assert self._seq_name is not None, "[!] 
No seq_name, probably using combined database" 203 | 204 | if self._dets is None: 205 | return f"{self._seq_name}.txt" 206 | 207 | return f"{self}.txt" 208 | 209 | def write_results(self, results: dict, output_dir: str) -> None: 210 | """Write the tracks in the format for MOT16/MOT17 sumbission 211 | 212 | results: dictionary with 1 dictionary for every track with 213 | {..., i:np.array([x1,y1,x2,y2]), ...} at key track_num 214 | 215 | Each file contains these lines: 216 | , , , , , , , , , 217 | """ 218 | 219 | # format_str = "{}, -1, {}, {}, {}, {}, {}, -1, -1, -1" 220 | if not os.path.exists(output_dir): 221 | os.makedirs(output_dir) 222 | 223 | result_file_path = osp.join(output_dir, self.results_file_name) 224 | 225 | with open(result_file_path, "w") as r_file: 226 | writer = csv.writer(r_file, delimiter=',') 227 | 228 | for i, track in results.items(): 229 | for frame, data in track.items(): 230 | x1 = data['bbox'][0] 231 | y1 = data['bbox'][1] 232 | x2 = data['bbox'][2] 233 | y2 = data['bbox'][3] 234 | 235 | writer.writerow([ 236 | frame + 1, 237 | i + 1, 238 | x1 + 1, 239 | y1 + 1, 240 | x2 - x1 + 1, 241 | y2 - y1 + 1, 242 | -1, -1, -1, -1]) 243 | 244 | def load_results(self, results_dir: str) -> dict: 245 | results = {} 246 | if results_dir is None: 247 | return results 248 | 249 | file_path = osp.join(results_dir, self.results_file_name) 250 | 251 | if not os.path.isfile(file_path): 252 | return results 253 | 254 | with open(file_path, "r") as file: 255 | csv_reader = csv.reader(file, delimiter=',') 256 | 257 | for row in csv_reader: 258 | frame_id, track_id = int(row[0]) - 1, int(row[1]) - 1 259 | 260 | if track_id not in results: 261 | results[track_id] = {} 262 | 263 | x1 = float(row[2]) - 1 264 | y1 = float(row[3]) - 1 265 | x2 = float(row[4]) - 1 + x1 266 | y2 = float(row[5]) - 1 + y1 267 | 268 | results[track_id][frame_id] = {} 269 | results[track_id][frame_id]['bbox'] = [x1, y1, x2, y2] 270 | results[track_id][frame_id]['score'] = 1.0 271 | 272 | return results 273 | 274 | -------------------------------------------------------------------------------- /src/trackformer/datasets/tracking/mot20_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | MOT20 sequence dataset. 4 | """ 5 | 6 | from .mot17_sequence import MOT17Sequence 7 | 8 | 9 | class MOT20Sequence(MOT17Sequence): 10 | """Multiple Object Tracking (MOT20) Dataset. 11 | 12 | This dataloader is designed so that it can handle only one sequence, 13 | if more have to be handled one should inherit from this class. 14 | """ 15 | data_folder = 'MOT20' 16 | -------------------------------------------------------------------------------- /src/trackformer/datasets/tracking/mot_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | MOT wrapper which combines sequences to a dataset. 4 | """ 5 | from torch.utils.data import Dataset 6 | 7 | from .mot17_sequence import MOT17Sequence 8 | from .mot20_sequence import MOT20Sequence 9 | from .mots20_sequence import MOTS20Sequence 10 | 11 | 12 | class MOT17Wrapper(Dataset): 13 | """A Wrapper for the MOT_Sequence class to return multiple sequences.""" 14 | 15 | def __init__(self, split: str, dets: str, **kwargs) -> None: 16 | """Initliazes all subset of the dataset. 
17 | 18 | Keyword arguments: 19 | split -- the split of the dataset to use 20 | kwargs -- kwargs for the MOT17Sequence dataset 21 | """ 22 | train_sequences = [ 23 | 'MOT17-02', 'MOT17-04', 'MOT17-05', 'MOT17-09', 24 | 'MOT17-10', 'MOT17-11', 'MOT17-13'] 25 | test_sequences = [ 26 | 'MOT17-01', 'MOT17-03', 'MOT17-06', 'MOT17-07', 27 | 'MOT17-08', 'MOT17-12', 'MOT17-14'] 28 | 29 | if split == "TRAIN": 30 | sequences = train_sequences 31 | elif split == "TEST": 32 | sequences = test_sequences 33 | elif split == "ALL": 34 | sequences = train_sequences + test_sequences 35 | sequences = sorted(sequences) 36 | elif f"MOT17-{split}" in train_sequences + test_sequences: 37 | sequences = [f"MOT17-{split}"] 38 | else: 39 | raise NotImplementedError("MOT17 split not available.") 40 | 41 | self._data = [] 42 | for seq in sequences: 43 | if dets == 'ALL': 44 | self._data.append(MOT17Sequence(seq_name=seq, dets='DPM', **kwargs)) 45 | self._data.append(MOT17Sequence(seq_name=seq, dets='FRCNN', **kwargs)) 46 | self._data.append(MOT17Sequence(seq_name=seq, dets='SDP', **kwargs)) 47 | else: 48 | self._data.append(MOT17Sequence(seq_name=seq, dets=dets, **kwargs)) 49 | 50 | def __len__(self) -> int: 51 | return len(self._data) 52 | 53 | def __getitem__(self, idx: int): 54 | return self._data[idx] 55 | 56 | 57 | class MOT20Wrapper(Dataset): 58 | """A Wrapper for the MOT_Sequence class to return multiple sequences.""" 59 | 60 | def __init__(self, split: str, **kwargs) -> None: 61 | """Initliazes all subset of the dataset. 62 | 63 | Keyword arguments: 64 | split -- the split of the dataset to use 65 | kwargs -- kwargs for the MOT20Sequence dataset 66 | """ 67 | train_sequences = ['MOT20-01', 'MOT20-02', 'MOT20-03', 'MOT20-05',] 68 | test_sequences = ['MOT20-04', 'MOT20-06', 'MOT20-07', 'MOT20-08',] 69 | 70 | if split == "TRAIN": 71 | sequences = train_sequences 72 | elif split == "TEST": 73 | sequences = test_sequences 74 | elif split == "ALL": 75 | sequences = train_sequences + test_sequences 76 | sequences = sorted(sequences) 77 | elif f"MOT20-{split}" in train_sequences + test_sequences: 78 | sequences = [f"MOT20-{split}"] 79 | else: 80 | raise NotImplementedError("MOT20 split not available.") 81 | 82 | self._data = [] 83 | for seq in sequences: 84 | self._data.append(MOT20Sequence(seq_name=seq, dets=None, **kwargs)) 85 | 86 | def __len__(self) -> int: 87 | return len(self._data) 88 | 89 | def __getitem__(self, idx: int): 90 | return self._data[idx] 91 | 92 | 93 | class MOTS20Wrapper(MOT17Wrapper): 94 | """A Wrapper for the MOT_Sequence class to return multiple sequences.""" 95 | 96 | def __init__(self, split: str, **kwargs) -> None: 97 | """Initliazes all subset of the dataset. 
98 | 99 | Keyword arguments: 100 | split -- the split of the dataset to use 101 | kwargs -- kwargs for the MOTS20Sequence dataset 102 | """ 103 | train_sequences = ['MOTS20-02', 'MOTS20-05', 'MOTS20-09', 'MOTS20-11'] 104 | test_sequences = ['MOTS20-01', 'MOTS20-06', 'MOTS20-07', 'MOTS20-12'] 105 | 106 | if split == "TRAIN": 107 | sequences = train_sequences 108 | elif split == "TEST": 109 | sequences = test_sequences 110 | elif split == "ALL": 111 | sequences = train_sequences + test_sequences 112 | sequences = sorted(sequences) 113 | elif f"MOTS20-{split}" in train_sequences + test_sequences: 114 | sequences = [f"MOTS20-{split}"] 115 | else: 116 | raise NotImplementedError("MOTS20 split not available.") 117 | 118 | self._data = [] 119 | for seq in sequences: 120 | self._data.append(MOTS20Sequence(seq_name=seq, **kwargs)) 121 | -------------------------------------------------------------------------------- /src/trackformer/datasets/tracking/mots20_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | MOTS20 sequence dataset. 4 | """ 5 | import csv 6 | import os 7 | import os.path as osp 8 | from argparse import Namespace 9 | from typing import Optional, Tuple 10 | 11 | import numpy as np 12 | import pycocotools.mask as rletools 13 | 14 | from .mot17_sequence import MOT17Sequence 15 | 16 | 17 | class MOTS20Sequence(MOT17Sequence): 18 | """Multiple Object and Segmentation Tracking (MOTS20) Dataset. 19 | 20 | This dataloader is designed so that it can handle only one sequence, 21 | if more have to be handled one should inherit from this class. 22 | """ 23 | data_folder = 'MOTS20' 24 | 25 | def __init__(self, root_dir: str = 'data', seq_name: Optional[str] = None, 26 | vis_threshold: float = 0.0, img_transform: Namespace = None) -> None: 27 | """ 28 | Args: 29 | seq_name (string): Sequence to take 30 | vis_threshold (float): Threshold of visibility of persons 31 | above which they are selected 32 | """ 33 | super().__init__(root_dir, seq_name, None, vis_threshold, img_transform) 34 | 35 | def get_track_boxes_and_visbility(self) -> Tuple[dict, dict]: 36 | boxes = {} 37 | visibility = {} 38 | 39 | for i in range(1, self.seq_length + 1): 40 | boxes[i] = {} 41 | visibility[i] = {} 42 | 43 | gt_file = self.get_gt_file_path() 44 | if not osp.exists(gt_file): 45 | return boxes, visibility 46 | 47 | mask_objects_per_frame = load_mots_gt(gt_file) 48 | for frame_id, mask_objects in mask_objects_per_frame.items(): 49 | for mask_object in mask_objects: 50 | # class_id = 1 is car 51 | # class_id = 2 is pedestrian 52 | # class_id = 10 IGNORE 53 | if mask_object.class_id in [1, 10]: 54 | continue 55 | 56 | bbox = rletools.toBbox(mask_object.mask) 57 | x1, y1, w, h = [int(c) for c in bbox] 58 | bbox = np.array([x1, y1, x1 + w, y1 + h], dtype=np.float32) 59 | 60 | # area = bbox[2] * bbox[3] 61 | # image_id = img_file_name_to_id[f"{seq}_{frame_id:06d}.jpg"] 62 | 63 | # segmentation = { 64 | # 'size': mask_object.mask['size'], 65 | # 'counts': mask_object.mask['counts'].decode(encoding='UTF-8')} 66 | 67 | boxes[frame_id][mask_object.track_id] = bbox 68 | visibility[frame_id][mask_object.track_id] = 1.0 69 | 70 | return boxes, visibility 71 | 72 | def write_results(self, results: dict, output_dir: str) -> None: 73 | if not os.path.exists(output_dir): 74 | os.makedirs(output_dir) 75 | 76 | result_file_path = osp.join(output_dir, f"{self._seq_name}.txt") 77 | 78 | with 
open(result_file_path, "w") as res_file: 79 | writer = csv.writer(res_file, delimiter=' ') 80 | for i, track in results.items(): 81 | for frame, data in track.items(): 82 | mask = np.asfortranarray(data['mask']) 83 | rle_mask = rletools.encode(mask) 84 | 85 | writer.writerow([ 86 | frame + 1, 87 | i + 1, 88 | 2, # class pedestrian 89 | mask.shape[0], 90 | mask.shape[1], 91 | rle_mask['counts'].decode(encoding='UTF-8')]) 92 | 93 | def load_results(self, results_dir: str) -> dict: 94 | results = {} 95 | 96 | if results_dir is None: 97 | return results 98 | 99 | file_path = osp.join(results_dir, self.results_file_name) 100 | 101 | if not os.path.isfile(file_path): 102 | return results 103 | 104 | mask_objects_per_frame = load_mots_gt(file_path) 105 | 106 | for frame_id, mask_objects in mask_objects_per_frame.items(): 107 | for mask_object in mask_objects: 108 | # class_id = 1 is car 109 | # class_id = 2 is pedestrian 110 | # class_id = 10 IGNORE 111 | if mask_object.class_id in [1, 10]: 112 | continue 113 | 114 | bbox = rletools.toBbox(mask_object.mask) 115 | x1, y1, w, h = [int(c) for c in bbox] 116 | bbox = np.array([x1, y1, x1 + w, y1 + h], dtype=np.float32) 117 | 118 | # area = bbox[2] * bbox[3] 119 | # image_id = img_file_name_to_id[f"{seq}_{frame_id:06d}.jpg"] 120 | 121 | # segmentation = { 122 | # 'size': mask_object.mask['size'], 123 | # 'counts': mask_object.mask['counts'].decode(encoding='UTF-8')} 124 | 125 | track_id = mask_object.track_id - 1 126 | if track_id not in results: 127 | results[track_id] = {} 128 | 129 | results[track_id][frame_id - 1] = {} 130 | results[track_id][frame_id - 1]['mask'] = rletools.decode(mask_object.mask) 131 | results[track_id][frame_id - 1]['bbox'] = bbox.tolist() 132 | results[track_id][frame_id - 1]['score'] = 1.0 133 | 134 | return results 135 | 136 | def __str__(self) -> str: 137 | return self._seq_name 138 | 139 | 140 | class SegmentedObject: 141 | """ 142 | Helper class for segmentation objects. 
143 | """ 144 | def __init__(self, mask: dict, class_id: int, track_id: int) -> None: 145 | self.mask = mask 146 | self.class_id = class_id 147 | self.track_id = track_id 148 | 149 | 150 | def load_mots_gt(path: str) -> dict: 151 | """Load MOTS ground truth from path.""" 152 | objects_per_frame = {} 153 | track_ids_per_frame = {} # Check that no frame contains two objects with same id 154 | combined_mask_per_frame = {} # Check that no frame contains overlapping masks 155 | 156 | with open(path, "r") as gt_file: 157 | for line in gt_file: 158 | line = line.strip() 159 | fields = line.split(" ") 160 | 161 | frame = int(fields[0]) 162 | if frame not in objects_per_frame: 163 | objects_per_frame[frame] = [] 164 | if frame not in track_ids_per_frame: 165 | track_ids_per_frame[frame] = set() 166 | if int(fields[1]) in track_ids_per_frame[frame]: 167 | assert False, f"Multiple objects with track id {fields[1]} in frame {fields[0]}" 168 | else: 169 | track_ids_per_frame[frame].add(int(fields[1])) 170 | 171 | class_id = int(fields[2]) 172 | if not(class_id == 1 or class_id == 2 or class_id == 10): 173 | assert False, "Unknown object class " + fields[2] 174 | 175 | mask = { 176 | 'size': [int(fields[3]), int(fields[4])], 177 | 'counts': fields[5].encode(encoding='UTF-8')} 178 | if frame not in combined_mask_per_frame: 179 | combined_mask_per_frame[frame] = mask 180 | elif rletools.area(rletools.merge([ 181 | combined_mask_per_frame[frame], mask], 182 | intersect=True)): 183 | assert False, "Objects with overlapping masks in frame " + fields[0] 184 | else: 185 | combined_mask_per_frame[frame] = rletools.merge( 186 | [combined_mask_per_frame[frame], mask], 187 | intersect=False) 188 | objects_per_frame[frame].append(SegmentedObject( 189 | mask, 190 | class_id, 191 | int(fields[1]) 192 | )) 193 | 194 | return objects_per_frame 195 | -------------------------------------------------------------------------------- /src/trackformer/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | 4 | from .backbone import build_backbone 5 | from .deformable_detr import DeformableDETR, DeformablePostProcess 6 | from .deformable_transformer import build_deforamble_transformer 7 | from .detr import DETR, PostProcess, SetCriterion 8 | from .detr_segmentation import (DeformableDETRSegm, DeformableDETRSegmTracking, 9 | DETRSegm, DETRSegmTracking, 10 | PostProcessPanoptic, PostProcessSegm) 11 | from .detr_tracking import DeformableDETRTracking, DETRTracking 12 | from .matcher import build_matcher 13 | from .transformer import build_transformer 14 | 15 | 16 | def build_model(args): 17 | if args.dataset == 'coco': 18 | num_classes = 91 19 | elif args.dataset == 'coco_panoptic': 20 | num_classes = 250 21 | elif args.dataset in ['coco_person', 'mot', 'mot_crowdhuman', 'crowdhuman', 'mot_coco_person']: 22 | # num_classes = 91 23 | num_classes = 20 24 | # num_classes = 1 25 | else: 26 | raise NotImplementedError 27 | 28 | device = torch.device(args.device) 29 | backbone = build_backbone(args) 30 | matcher = build_matcher(args) 31 | 32 | detr_kwargs = { 33 | 'backbone': backbone, 34 | 'num_classes': num_classes - 1 if args.focal_loss else num_classes, 35 | 'num_queries': args.num_queries, 36 | 'aux_loss': args.aux_loss, 37 | 'overflow_boxes': args.overflow_boxes} 38 | 39 | tracking_kwargs = { 40 | 'track_query_false_positive_prob': args.track_query_false_positive_prob, 41 | 'track_query_false_negative_prob': args.track_query_false_negative_prob, 42 | 'matcher': matcher, 43 | 'backprop_prev_frame': args.track_backprop_prev_frame,} 44 | 45 | mask_kwargs = { 46 | 'freeze_detr': args.freeze_detr} 47 | 48 | if args.deformable: 49 | transformer = build_deforamble_transformer(args) 50 | 51 | detr_kwargs['transformer'] = transformer 52 | detr_kwargs['num_feature_levels'] = args.num_feature_levels 53 | detr_kwargs['with_box_refine'] = args.with_box_refine 54 | detr_kwargs['two_stage'] = args.two_stage 55 | detr_kwargs['multi_frame_attention'] = args.multi_frame_attention 56 | detr_kwargs['multi_frame_encoding'] = args.multi_frame_encoding 57 | detr_kwargs['merge_frame_features'] = args.merge_frame_features 58 | 59 | if args.tracking: 60 | if args.masks: 61 | model = DeformableDETRSegmTracking(mask_kwargs, tracking_kwargs, detr_kwargs) 62 | else: 63 | model = DeformableDETRTracking(tracking_kwargs, detr_kwargs) 64 | else: 65 | if args.masks: 66 | model = DeformableDETRSegm(mask_kwargs, detr_kwargs) 67 | else: 68 | model = DeformableDETR(**detr_kwargs) 69 | else: 70 | transformer = build_transformer(args) 71 | 72 | detr_kwargs['transformer'] = transformer 73 | 74 | if args.tracking: 75 | if args.masks: 76 | model = DETRSegmTracking(mask_kwargs, tracking_kwargs, detr_kwargs) 77 | else: 78 | model = DETRTracking(tracking_kwargs, detr_kwargs) 79 | else: 80 | if args.masks: 81 | model = DETRSegm(mask_kwargs, detr_kwargs) 82 | else: 83 | model = DETR(**detr_kwargs) 84 | 85 | weight_dict = {'loss_ce': args.cls_loss_coef, 86 | 'loss_bbox': args.bbox_loss_coef, 87 | 'loss_giou': args.giou_loss_coef,} 88 | 89 | if args.masks: 90 | weight_dict["loss_mask"] = args.mask_loss_coef 91 | weight_dict["loss_dice"] = args.dice_loss_coef 92 | 93 | # TODO this is a hack 94 | if args.aux_loss: 95 | aux_weight_dict = {} 96 | for i in range(args.dec_layers - 1): 97 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 98 | 99 | if args.two_stage: 100 | aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()}) 101 | weight_dict.update(aux_weight_dict) 
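# For illustration only (the layer count, two-stage flag and coefficients below
# are assumptions, not repository defaults): with dec_layers == 6 and two_stage
# enabled, the update above expands the base dictionary
#
#     weight_dict = {'loss_ce': 2.0, 'loss_bbox': 5.0, 'loss_giou': 2.0}
#
# by per-decoder-layer and encoder copies that reuse the same coefficients:
#
#     aux = {f'{k}_{i}': v for i in range(6 - 1) for k, v in weight_dict.items()}
#     aux.update({f'{k}_enc': v for k, v in weight_dict.items()})
#     weight_dict.update(aux)
#     # adds 'loss_ce_0', ..., 'loss_ce_4', 'loss_bbox_0', ..., 'loss_giou_enc'
#
# so that the SetCriterion built below weights the auxiliary decoder (and encoder)
# outputs without further configuration.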
102 | 103 | losses = ['labels', 'boxes', 'cardinality'] 104 | if args.masks: 105 | losses.append('masks') 106 | 107 | criterion = SetCriterion( 108 | num_classes, 109 | matcher=matcher, 110 | weight_dict=weight_dict, 111 | eos_coef=args.eos_coef, 112 | losses=losses, 113 | focal_loss=args.focal_loss, 114 | focal_alpha=args.focal_alpha, 115 | focal_gamma=args.focal_gamma, 116 | tracking=args.tracking, 117 | track_query_false_positive_eos_weight=args.track_query_false_positive_eos_weight,) 118 | criterion.to(device) 119 | 120 | if args.focal_loss: 121 | postprocessors = {'bbox': DeformablePostProcess()} 122 | else: 123 | postprocessors = {'bbox': PostProcess()} 124 | if args.masks: 125 | postprocessors['segm'] = PostProcessSegm() 126 | if args.dataset == "coco_panoptic": 127 | is_thing_map = {i: i <= 90 for i in range(201)} 128 | postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) 129 | 130 | return model, criterion, postprocessors 131 | -------------------------------------------------------------------------------- /src/trackformer/models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 4 | """ 5 | from typing import Dict, List 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from torchvision.ops.feature_pyramid_network import (FeaturePyramidNetwork, 13 | LastLevelMaxPool) 14 | 15 | from ..util.misc import NestedTensor, is_main_process 16 | from .position_encoding import build_position_encoding 17 | 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | 23 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 24 | without which any other models than torchvision.models.resnet[18,34,50,101] 25 | produce nans. 
26 | """ 27 | 28 | def __init__(self, n): 29 | super(FrozenBatchNorm2d, self).__init__() 30 | self.register_buffer("weight", torch.ones(n)) 31 | self.register_buffer("bias", torch.zeros(n)) 32 | self.register_buffer("running_mean", torch.zeros(n)) 33 | self.register_buffer("running_var", torch.ones(n)) 34 | 35 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 36 | missing_keys, unexpected_keys, error_msgs): 37 | num_batches_tracked_key = prefix + 'num_batches_tracked' 38 | if num_batches_tracked_key in state_dict: 39 | del state_dict[num_batches_tracked_key] 40 | 41 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 42 | state_dict, prefix, local_metadata, strict, 43 | missing_keys, unexpected_keys, error_msgs) 44 | 45 | def forward(self, x): 46 | # move reshapes to the beginning 47 | # to make it fuser-friendly 48 | w = self.weight.reshape(1, -1, 1, 1) 49 | b = self.bias.reshape(1, -1, 1, 1) 50 | rv = self.running_var.reshape(1, -1, 1, 1) 51 | rm = self.running_mean.reshape(1, -1, 1, 1) 52 | eps = 1e-5 53 | scale = w * (rv + eps).rsqrt() 54 | bias = b - rm * scale 55 | return x * scale + bias 56 | 57 | 58 | class BackboneBase(nn.Module): 59 | 60 | def __init__(self, backbone: nn.Module, train_backbone: bool, 61 | return_interm_layers: bool): 62 | super().__init__() 63 | for name, parameter in backbone.named_parameters(): 64 | if (not train_backbone 65 | or 'layer2' not in name 66 | and 'layer3' not in name 67 | and 'layer4' not in name): 68 | parameter.requires_grad_(False) 69 | if return_interm_layers: 70 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 71 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 72 | self.strides = [4, 8, 16, 32] 73 | self.num_channels = [256, 512, 1024, 2048] 74 | else: 75 | return_layers = {'layer4': "0"} 76 | self.strides = [32] 77 | self.num_channels = [2048] 78 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 79 | 80 | def forward(self, tensor_list: NestedTensor): 81 | xs = self.body(tensor_list.tensors) 82 | out: Dict[str, NestedTensor] = {} 83 | for name, x in xs.items(): 84 | m = tensor_list.mask 85 | assert m is not None 86 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 87 | out[name] = NestedTensor(x, mask) 88 | return out 89 | 90 | 91 | class Backbone(BackboneBase): 92 | """ResNet backbone with frozen BatchNorm.""" 93 | def __init__(self, name: str, 94 | train_backbone: bool, 95 | return_interm_layers: bool, 96 | dilation: bool): 97 | norm_layer = FrozenBatchNorm2d 98 | backbone = getattr(torchvision.models, name)( 99 | replace_stride_with_dilation=[False, False, dilation], 100 | pretrained=is_main_process(), norm_layer=norm_layer) 101 | super().__init__(backbone, train_backbone, 102 | return_interm_layers) 103 | if dilation: 104 | self.strides[-1] = self.strides[-1] // 2 105 | 106 | 107 | class Joiner(nn.Sequential): 108 | def __init__(self, backbone, position_embedding): 109 | super().__init__(backbone, position_embedding) 110 | self.strides = backbone.strides 111 | self.num_channels = backbone.num_channels 112 | 113 | def forward(self, tensor_list: NestedTensor): 114 | xs = self[0](tensor_list) 115 | out: List[NestedTensor] = [] 116 | pos = [] 117 | for x in xs.values(): 118 | out.append(x) 119 | # position encoding 120 | pos.append(self[1](x).to(x.tensors.dtype)) 121 | 122 | return out, pos 123 | 124 | 125 | def build_backbone(args): 126 | position_embedding = build_position_encoding(args) 127 | train_backbone = 
args.lr_backbone > 0 128 | return_interm_layers = args.masks or (args.num_feature_levels > 1) 129 | backbone = Backbone(args.backbone, 130 | train_backbone, 131 | return_interm_layers, 132 | args.dilation) 133 | model = Joiner(backbone, position_embedding) 134 | return model 135 | -------------------------------------------------------------------------------- /src/trackformer/models/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Modules to compute the matching cost and solve the corresponding LSAP. 4 | """ 5 | import numpy as np 6 | import torch 7 | from scipy.optimize import linear_sum_assignment 8 | from torch import nn 9 | 10 | from ..util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 11 | 12 | 13 | class HungarianMatcher(nn.Module): 14 | """This class computes an assignment between the targets and the predictions of the network 15 | 16 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 17 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best 18 | predictions, while the others are un-matched (and thus treated as non-objects). 19 | """ 20 | 21 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, 22 | focal_loss: bool = False, focal_alpha: float = 0.25, focal_gamma: float = 2.0): 23 | """Creates the matcher 24 | 25 | Params: 26 | cost_class: This is the relative weight of the classification error in the matching cost 27 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates 28 | in the matching cost 29 | cost_giou: This is the relative weight of the giou loss of the bounding box in the 30 | matching cost 31 | """ 32 | super().__init__() 33 | self.cost_class = cost_class 34 | self.cost_bbox = cost_bbox 35 | self.cost_giou = cost_giou 36 | self.focal_loss = focal_loss 37 | self.focal_alpha = focal_alpha 38 | self.focal_gamma = focal_gamma 39 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 40 | 41 | @torch.no_grad() 42 | def forward(self, outputs, targets): 43 | """ Performs the matching 44 | 45 | Params: 46 | outputs: This is a dict that contains at least these entries: 47 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the 48 | classification logits 49 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted 50 | box coordinates 51 | 52 | targets: This is a list of targets (len(targets) = batch_size), where each target 53 | is a dict containing: 54 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number 55 | of ground-truth objects in the target) containing the class labels 56 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 57 | 58 | Returns: 59 | A list of size batch_size, containing tuples of (index_i, index_j) where: 60 | - index_i is the indices of the selected predictions (in order) 61 | - index_j is the indices of the corresponding selected targets (in order) 62 | For each batch element, it holds: 63 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 64 | """ 65 | batch_size, num_queries = outputs["pred_logits"].shape[:2] 66 | 67 | # We flatten to compute the cost matrices in a batch 68 | # 69 | # [batch_size * num_queries, num_classes] 70 | if self.focal_loss: 71 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() 72 | 
else: 73 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) 74 | 75 | # [batch_size * num_queries, 4] 76 | out_bbox = outputs["pred_boxes"].flatten(0, 1) 77 | 78 | # Also concat the target labels and boxes 79 | tgt_ids = torch.cat([v["labels"] for v in targets]) 80 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 81 | 82 | # Compute the classification cost. 83 | if self.focal_loss: 84 | neg_cost_class = (1 - self.focal_alpha) * (out_prob ** self.focal_gamma) * (-(1 - out_prob + 1e-8).log()) 85 | pos_cost_class = self.focal_alpha * ((1 - out_prob) ** self.focal_gamma) * (-(out_prob + 1e-8).log()) 86 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 87 | else: 88 | # Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. 89 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 90 | cost_class = -out_prob[:, tgt_ids] 91 | 92 | # Compute the L1 cost between boxes 93 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 94 | 95 | # Compute the giou cost betwen boxes 96 | cost_giou = -generalized_box_iou( 97 | box_cxcywh_to_xyxy(out_bbox), 98 | box_cxcywh_to_xyxy(tgt_bbox)) 99 | 100 | # Final cost matrix 101 | cost_matrix = self.cost_bbox * cost_bbox \ 102 | + self.cost_class * cost_class \ 103 | + self.cost_giou * cost_giou 104 | cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() 105 | 106 | sizes = [len(v["boxes"]) for v in targets] 107 | 108 | for i, target in enumerate(targets): 109 | if 'track_query_match_ids' not in target: 110 | continue 111 | 112 | prop_i = 0 113 | for j in range(cost_matrix.shape[1]): 114 | # if target['track_queries_fal_pos_mask'][j] or target['track_queries_placeholder_mask'][j]: 115 | if target['track_queries_fal_pos_mask'][j]: 116 | # false positive and palceholder track queries should not 117 | # be matched to any target 118 | cost_matrix[i, j] = np.inf 119 | elif target['track_queries_mask'][j]: 120 | track_query_id = target['track_query_match_ids'][prop_i] 121 | prop_i += 1 122 | 123 | cost_matrix[i, j] = np.inf 124 | cost_matrix[i, :, track_query_id + sum(sizes[:i])] = np.inf 125 | cost_matrix[i, j, track_query_id + sum(sizes[:i])] = -1 126 | 127 | indices = [linear_sum_assignment(c[i]) 128 | for i, c in enumerate(cost_matrix.split(sizes, -1))] 129 | 130 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) 131 | for i, j in indices] 132 | 133 | 134 | def build_matcher(args): 135 | return HungarianMatcher( 136 | cost_class=args.set_cost_class, 137 | cost_bbox=args.set_cost_bbox, 138 | cost_giou=args.set_cost_giou, 139 | focal_loss=args.focal_loss, 140 | focal_alpha=args.focal_alpha, 141 | focal_gamma=args.focal_gamma,) 142 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | *egg-info 4 | *.linux* 5 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch, ms_deform_attn_core_pytorch_mot 2 | 3 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/functions/ms_deform_attn_func.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.autograd import Function 9 | from torch.autograd.function import once_differentiable 10 | 11 | import MultiScaleDeformableAttention as MSDA 12 | 13 | 14 | class MSDeformAttnFunction(Function): 15 | @staticmethod 16 | def forward(ctx, value, value_spatial_shapes, sampling_locations, attention_weights, im2col_step): 17 | ctx.im2col_step = im2col_step 18 | output = MSDA.ms_deform_attn_forward( 19 | value, value_spatial_shapes, sampling_locations, attention_weights, ctx.im2col_step) 20 | ctx.save_for_backward(value, value_spatial_shapes, sampling_locations, attention_weights) 21 | return output 22 | 23 | @staticmethod 24 | @once_differentiable 25 | def backward(ctx, grad_output): 26 | value, value_spatial_shapes, sampling_locations, attention_weights = ctx.saved_tensors 27 | grad_value, grad_sampling_loc, grad_attn_weight = \ 28 | MSDA.ms_deform_attn_backward( 29 | value, value_spatial_shapes, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 30 | 31 | return grad_value, None, grad_sampling_loc, grad_attn_weight, None 32 | 33 | 34 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 35 | # for debug and test only, 36 | # need to use cuda version instead 37 | N_, S_, M_, D_ = value.shape 38 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 39 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 40 | sampling_grids = 2 * sampling_locations - 1 41 | sampling_value_list = [] 42 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 43 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 44 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 45 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 46 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 47 | # N_*M_, D_, Lq_, P_ 48 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 49 | mode='bilinear', padding_mode='zeros', align_corners=False) 50 | sampling_value_list.append(sampling_value_l_) 51 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 52 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 53 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 54 | return output.transpose(1, 2).contiguous() 55 | 56 | def ms_deform_attn_core_pytorch_mot(query, value, value_spatial_shapes, sampling_locations, key_proj, attention_weights=None): 57 | # for debug and test only, 58 | # need to use cuda version instead 59 | N_, S_, M_, D_ = value.shape 60 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 61 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 62 | sampling_grids = 2 * sampling_locations - 1 63 | sampling_value_list = [] 64 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 65 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 66 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 67 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 68 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 
2).flatten(0, 1) 69 | # N_*M_, D_, Lq_, P_ 70 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 71 | mode='bilinear', padding_mode='zeros', align_corners=False) 72 | sampling_value_list.append(sampling_value_l_) 73 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 74 | q = query.transpose(1, 2).reshape(N_*M_, D_, Lq_, 1) 75 | v = torch.stack(sampling_value_list, dim=-2).flatten(-2) # (N_*M_, D_, Lq_, L_*P_) 76 | k = key_proj(v.reshape(N_, M_*D_, Lq_, L_*P_).permute(0, 2, 3, 1)).permute(0, 3, 1, 2).reshape(N_*M_, D_, Lq_, L_*P_) 77 | 78 | sim = (q * k).sum(1).reshape(N_*M_, 1, Lq_, L_*P_) 79 | attention_weights = F.softmax(sim, -1) 80 | output = (v * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 81 | 82 | return output.transpose(1, 2).contiguous() -------------------------------------------------------------------------------- /src/trackformer/models/ops/make.sh: -------------------------------------------------------------------------------- 1 | python setup.py build install 2 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .ms_deform_attn import MSDeformAttn 2 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from torch.nn.init import xavier_uniform_, constant_ 10 | 11 | from ..functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch 12 | from ..functions import ms_deform_attn_core_pytorch_mot 13 | 14 | 15 | class MSDeformAttn(nn.Module): 16 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, im2col_step=64): 17 | super().__init__() 18 | assert d_model % n_heads == 0 19 | 20 | self.im2col_step = im2col_step 21 | 22 | self.d_model = d_model 23 | self.n_levels = n_levels 24 | self.n_heads = n_heads 25 | self.n_points = n_points 26 | 27 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 28 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 29 | self.value_proj = nn.Linear(d_model, d_model) 30 | self.output_proj = nn.Linear(d_model, d_model) 31 | 32 | self._reset_parameters() 33 | 34 | def _reset_parameters(self): 35 | constant_(self.sampling_offsets.weight.data, 0.) 36 | grid_init = torch.tensor([-1, -1, -1, 0, -1, 1, 0, -1, 0, 1, 1, -1, 1, 0, 1, 1], dtype=torch.float32) \ 37 | .view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 38 | for i in range(self.n_points): 39 | grid_init[:, :, i, :] *= i + 1 40 | with torch.no_grad(): 41 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 42 | constant_(self.attention_weights.weight.data, 0.) 43 | constant_(self.attention_weights.bias.data, 0.) 44 | xavier_uniform_(self.value_proj.weight.data) 45 | constant_(self.value_proj.bias.data, 0.) 46 | xavier_uniform_(self.output_proj.weight.data) 47 | constant_(self.output_proj.bias.data, 0.) 
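# Shape sketch for the forward pass below (the values are illustrative assumptions,
# not part of the original file; running it requires the compiled
# MultiScaleDeformableAttention CUDA extension from make.sh and CUDA tensors):
#
#     attn = MSDeformAttn(d_model=256, n_levels=1, n_heads=8, n_points=4).cuda()
#     query = torch.rand(2, 100, 256).cuda()      # (N, Len_q, C)
#     feats = torch.rand(2, 32 * 32, 256).cuda()  # (N, H*W, C), a single level
#     shapes = torch.as_tensor([[32, 32]]).cuda() # (n_levels, 2)
#     ref = torch.rand(2, 100, 1, 2).cuda()       # normalized (x, y) reference points
#     out = attn(query, ref, feats, shapes)       # -> (2, 100, 256)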
48 | 49 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_padding_mask=None, query_attn_mask=None): 50 | """ 51 | :param query (N, Length_{query}, C) 52 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 53 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 54 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 55 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 56 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 57 | 58 | :return output (N, Length_{query}, C) 59 | """ 60 | N, Len_q, _ = query.shape 61 | N, Len_in, _ = input_flatten.shape 62 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 63 | 64 | value = self.value_proj(input_flatten) 65 | if input_padding_mask is not None: 66 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 67 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 68 | 69 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 70 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 71 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 72 | 73 | if query_attn_mask is not None: 74 | attention_weights = attention_weights.masked_fill(query_attn_mask[..., None, None, None], float(0)) 75 | 76 | # N, Len_q, n_heads, n_levels, n_points, 2 77 | if reference_points.shape[-1] == 2: 78 | sampling_locations = reference_points[:, :, None, :, None, :] \ 79 | + sampling_offsets / input_spatial_shapes[None, None, None, :, None, :] 80 | elif reference_points.shape[-1] == 4: 81 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 82 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 83 | else: 84 | raise ValueError( 85 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 86 | output = MSDeformAttnFunction.apply( 87 | value, input_spatial_shapes, sampling_locations, attention_weights, self.im2col_step) 88 | output = self.output_proj(output) 89 | return output 90 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import glob 5 | 6 | import torch 7 | 8 | from torch.utils.cpp_extension import CUDA_HOME 9 | from torch.utils.cpp_extension import CppExtension 10 | from torch.utils.cpp_extension import CUDAExtension 11 | 12 | from setuptools import find_packages 13 | from setuptools import setup 14 | 15 | requirements = ["torch", "torchvision"] 16 | 17 | def get_extensions(): 18 | this_dir = os.path.dirname(os.path.abspath(__file__)) 19 | extensions_dir = os.path.join(this_dir, "src") 20 | 21 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 22 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 23 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 24 | 25 | sources = main_file + source_cpu 26 | extension = CppExtension 27 | extra_compile_args = {"cxx": []} 28 | define_macros = [] 29 | 30 | if 
torch.cuda.is_available() and CUDA_HOME is not None: 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | else: 41 | raise NotImplementedError('Cuda is not available') 42 | 43 | sources = [os.path.join(extensions_dir, s) for s in sources] 44 | include_dirs = [extensions_dir] 45 | ext_modules = [ 46 | extension( 47 | "MultiScaleDeformableAttention", 48 | sources, 49 | include_dirs=include_dirs, 50 | define_macros=define_macros, 51 | extra_compile_args=extra_compile_args, 52 | ) 53 | ] 54 | return ext_modules 55 | 56 | setup( 57 | name="MultiScaleDeformableAttention", 58 | version="1.0", 59 | author="Weijie Su", 60 | url="xxx", 61 | description="Multi-Scale Deformable Attention Module in Deformable DETR", 62 | packages=find_packages(exclude=("configs", "tests",)), 63 | # install_requires=requirements, 64 | ext_modules=get_extensions(), 65 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 66 | ) 67 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | 7 | at::Tensor 8 | ms_deform_attn_cpu_forward( 9 | const at::Tensor &value, 10 | const at::Tensor &spatial_shapes, 11 | const at::Tensor &sampling_loc, 12 | const at::Tensor &attn_weight, 13 | const int im2col_step) 14 | { 15 | AT_ERROR("Not implement on cpu"); 16 | } 17 | 18 | std::vector 19 | ms_deform_attn_cpu_backward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const at::Tensor &grad_output, 25 | const int im2col_step) 26 | { 27 | AT_ERROR("Not implement on cpu"); 28 | } 29 | 30 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor 5 | ms_deform_attn_cpu_forward( 6 | const at::Tensor &value, 7 | const at::Tensor &spatial_shapes, 8 | const at::Tensor &sampling_loc, 9 | const at::Tensor &attn_weight, 10 | const int im2col_step); 11 | 12 | std::vector 13 | ms_deform_attn_cpu_backward( 14 | const at::Tensor &value, 15 | const at::Tensor &spatial_shapes, 16 | const at::Tensor &sampling_loc, 17 | const at::Tensor &attn_weight, 18 | const at::Tensor &grad_output, 19 | const int im2col_step); 20 | 21 | 22 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cuda/ms_deform_im2col_cuda.cuh" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | // #include 10 | // #include 11 | // #include 12 | 13 | // extern THCState *state; 14 | 15 | // author: Charles Shang 16 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 17 | 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int 
im2col_step) 25 | // value: N_, S_, M_, D_ 26 | // spatial_shapes: L_, 2 27 | // sampling_loc: N_, Lq_, M_, L_, P_, 2 28 | { 29 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 30 | 31 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 32 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 33 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 34 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 35 | 36 | const int batch = value.size(0); 37 | const int spatial_size = value.size(1); 38 | const int num_heads = value.size(2); 39 | const int channels = value.size(3); 40 | 41 | const int num_levels = spatial_shapes.size(0); 42 | 43 | const int num_query = sampling_loc.size(1); 44 | const int num_point = sampling_loc.size(4); 45 | 46 | const int im2col_step_ = std::min(batch, im2col_step); 47 | 48 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 49 | 50 | auto output = at::empty({batch, num_query, num_heads, channels}, value.options()); 51 | 52 | auto level_start_index = at::zeros({num_levels}, spatial_shapes.options()); 53 | for (int lvl = 1; lvl < num_levels; ++lvl) 54 | { 55 | auto shape_prev = spatial_shapes.select(0, lvl-1); 56 | auto size_prev = at::mul(shape_prev.select(0, 0), shape_prev.select(0, 1)); 57 | level_start_index.select(0, lvl) = at::add(level_start_index.select(0, lvl-1), size_prev); 58 | } 59 | 60 | // define alias for easy use 61 | const int batch_n = im2col_step_; 62 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 63 | auto per_value_size = spatial_size * num_heads * channels; 64 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 65 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 66 | for (int n = 0; n < batch/im2col_step_; ++n) 67 | { 68 | auto columns = at::empty({num_levels*num_point, batch_n, num_query, num_heads, channels}, value.options()); 69 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 70 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 71 | value.data() + n * im2col_step_ * per_value_size, 72 | spatial_shapes.data(), 73 | level_start_index.data(), 74 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 75 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 76 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 77 | columns.data()); 78 | 79 | })); 80 | output_n.select(0, n) = at::sum(columns, 0); 81 | } 82 | 83 | output = output.view({batch, num_query, num_heads*channels}); 84 | 85 | return output; 86 | } 87 | 88 | 89 | std::vector ms_deform_attn_cuda_backward( 90 | const at::Tensor &value, 91 | const at::Tensor &spatial_shapes, 92 | const at::Tensor &sampling_loc, 93 | const at::Tensor &attn_weight, 94 | const at::Tensor &grad_output, 95 | const int im2col_step) 96 | { 97 | 98 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 103 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 104 | 105 | const int batch = value.size(0); 106 | const int spatial_size = 
value.size(1); 107 | const int num_heads = value.size(2); 108 | const int channels = value.size(3); 109 | 110 | const int num_levels = spatial_shapes.size(0); 111 | 112 | const int num_query = sampling_loc.size(1); 113 | const int num_point = sampling_loc.size(4); 114 | 115 | const int im2col_step_ = std::min(batch, im2col_step); 116 | 117 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 118 | 119 | auto grad_value = at::zeros_like(value); 120 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 121 | auto grad_attn_weight = at::zeros_like(attn_weight); 122 | 123 | auto level_start_index = at::zeros({num_levels}, spatial_shapes.options()); 124 | for (int lvl = 1; lvl < num_levels; ++lvl) 125 | { 126 | auto shape_prev = spatial_shapes.select(0, lvl-1); 127 | auto size_prev = at::mul(shape_prev.select(0, 0), shape_prev.select(0, 1)); 128 | level_start_index.select(0, lvl) = at::add(level_start_index.select(0, lvl-1), size_prev); 129 | } 130 | 131 | const int batch_n = im2col_step_; 132 | auto per_value_size = spatial_size * num_heads * channels; 133 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 134 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 135 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 136 | for (int n = 0; n < batch/im2col_step_; ++n) 137 | { 138 | auto grad_output_g = grad_output_n.select(0, n); 139 | AT_DISPATCH_FLOATING_TYPES(value.type(), "deform_conv_backward_cuda", ([&] { 140 | 141 | // gradient w.r.t. sampling location & attention weight 142 | ms_deformable_col2im_coord_cuda(at::cuda::getCurrentCUDAStream(), 143 | grad_output_g.data(), 144 | value.data() + n * im2col_step_ * per_value_size, 145 | spatial_shapes.data(), 146 | level_start_index.data(), 147 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 148 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 149 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 150 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 151 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 152 | // gradient w.r.t. 
value 153 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 154 | grad_output_g.data(), 155 | spatial_shapes.data(), 156 | level_start_index.data(), 157 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 158 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 159 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 160 | grad_value.data() + n * im2col_step_ * per_value_size); 161 | 162 | })); 163 | } 164 | 165 | return { 166 | grad_value, grad_sampling_loc, grad_attn_weight 167 | }; 168 | } -------------------------------------------------------------------------------- /src/trackformer/models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ms_deform_attn_cuda_forward( 5 | const at::Tensor &value, 6 | const at::Tensor &spatial_shapes, 7 | const at::Tensor &sampling_loc, 8 | const at::Tensor &attn_weight, 9 | const int im2col_step); 10 | 11 | std::vector ms_deform_attn_cuda_backward( 12 | const at::Tensor &value, 13 | const at::Tensor &spatial_shapes, 14 | const at::Tensor &sampling_loc, 15 | const at::Tensor &attn_weight, 16 | const at::Tensor &grad_output, 17 | const int im2col_step); 18 | 19 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/ms_deform_attn_cpu.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/ms_deform_attn_cuda.h" 7 | #endif 8 | 9 | 10 | at::Tensor 11 | ms_deform_attn_forward( 12 | const at::Tensor &value, 13 | const at::Tensor &spatial_shapes, 14 | const at::Tensor &sampling_loc, 15 | const at::Tensor &attn_weight, 16 | const int im2col_step) 17 | { 18 | if (value.type().is_cuda()) 19 | { 20 | #ifdef WITH_CUDA 21 | return ms_deform_attn_cuda_forward( 22 | value, spatial_shapes, sampling_loc, attn_weight, im2col_step); 23 | #else 24 | AT_ERROR("Not compiled with GPU support"); 25 | #endif 26 | } 27 | AT_ERROR("Not implemented on the CPU"); 28 | } 29 | 30 | std::vector 31 | ms_deform_attn_backward( 32 | const at::Tensor &value, 33 | const at::Tensor &spatial_shapes, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | if (value.type().is_cuda()) 40 | { 41 | #ifdef WITH_CUDA 42 | return ms_deform_attn_cuda_backward( 43 | value, spatial_shapes, sampling_loc, attn_weight, grad_output, im2col_step); 44 | #else 45 | AT_ERROR("Not compiled with GPU support"); 46 | #endif 47 | } 48 | AT_ERROR("Not implemented on the CPU"); 49 | } 50 | 51 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "ms_deform_attn.h" 3 | 4 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 5 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 6 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 7 | } 8 | -------------------------------------------------------------------------------- /src/trackformer/models/ops/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import 
print_function 4 | from __future__ import division 5 | 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import gradcheck 10 | 11 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 12 | 13 | 14 | N, M, D = 2, 2, 4 15 | Lq, L, P = 3, 3, 2 16 | shapes = torch.as_tensor([(8, 8), (4, 4), (2, 2)], dtype=torch.long).cuda() 17 | S = sum([(H*W).item() for H, W in shapes]) 18 | 19 | 20 | torch.manual_seed(3) 21 | 22 | 23 | def check_forward_equal_with_pytorch(): 24 | value = torch.rand(N, S, M, D).cuda() * 0.01 25 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 26 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 27 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 28 | im2col_step = 2 29 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights) 30 | output_cuda = MSDeformAttnFunction.apply(value, shapes, sampling_locations, attention_weights, im2col_step) 31 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 32 | max_abs_err = (output_cuda - output_pytorch).abs().max() 33 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 34 | 35 | print(f'* {fwdok} check_forward_equal_with_pytorch: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 36 | 37 | 38 | def check_backward_equal_with_pytorch(): 39 | value = torch.rand(N, S, M, D).cuda() * 0.01 40 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 41 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 42 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 43 | im2col_step = 2 44 | value.requires_grad = True 45 | sampling_locations.requires_grad = True 46 | attention_weights.requires_grad = True 47 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights) 48 | output_cuda = MSDeformAttnFunction.apply(value, shapes, sampling_locations, attention_weights, im2col_step) 49 | loss_pytorch = output_pytorch.abs().sum() 50 | loss_cuda = output_cuda.abs().sum() 51 | 52 | grad_value_pytorch = torch.autograd.grad(loss_pytorch, value, retain_graph=True)[0] 53 | grad_value_cuda = torch.autograd.grad(loss_cuda, value, retain_graph=True)[0] 54 | bwdok = torch.allclose(grad_value_cuda, grad_value_pytorch, rtol=1e-2, atol=1e-3) 55 | max_abs_err = (grad_value_cuda - grad_value_pytorch).abs().max() 56 | zero_grad_mask = grad_value_pytorch == 0 57 | max_rel_err = ((grad_value_cuda - grad_value_pytorch).abs() / grad_value_pytorch.abs())[~zero_grad_mask].max() 58 | if zero_grad_mask.sum() == 0: 59 | max_abs_err_0 = 0 60 | else: 61 | max_abs_err_0 = (grad_value_cuda - grad_value_pytorch).abs()[zero_grad_mask].max() 62 | print(f'* {bwdok} check_backward_equal_with_pytorch - input1: ' 63 | f'max_abs_err {max_abs_err:.2e} ' 64 | f'max_rel_err {max_rel_err:.2e} ' 65 | f'max_abs_err_0 {max_abs_err_0:.2e}') 66 | 67 | grad_sampling_loc_pytorch = torch.autograd.grad(loss_pytorch, sampling_locations, retain_graph=True)[0] 68 | grad_sampling_loc_cuda = torch.autograd.grad(loss_cuda, sampling_locations, retain_graph=True)[0] 69 | bwdok = torch.allclose(grad_sampling_loc_cuda, grad_sampling_loc_pytorch, rtol=1e-2, atol=1e-3) 70 | max_abs_err = (grad_sampling_loc_cuda - grad_sampling_loc_pytorch).abs().max() 71 | zero_grad_mask = grad_sampling_loc_pytorch == 0 72 | max_rel_err = ((grad_sampling_loc_cuda - grad_sampling_loc_pytorch).abs() / 
grad_sampling_loc_pytorch.abs())[~zero_grad_mask].max() 73 | if zero_grad_mask.sum() == 0: 74 | max_abs_err_0 = 0 75 | else: 76 | max_abs_err_0 = (grad_sampling_loc_cuda - grad_sampling_loc_pytorch).abs()[zero_grad_mask].max() 77 | print(f'* {bwdok} check_backward_equal_with_pytorch - input2: ' 78 | f'max_abs_err {max_abs_err:.2e} ' 79 | f'max_rel_err {max_rel_err:.2e} ' 80 | f'max_abs_err_0 {max_abs_err_0:.2e}') 81 | 82 | grad_attn_weight_pytorch = torch.autograd.grad(loss_pytorch, attention_weights, retain_graph=True)[0] 83 | grad_attn_weight_cuda = torch.autograd.grad(loss_cuda, attention_weights, retain_graph=True)[0] 84 | bwdok = torch.allclose(grad_attn_weight_cuda, grad_attn_weight_pytorch, rtol=1e-2, atol=1e-3) 85 | max_abs_err = (grad_attn_weight_cuda - grad_attn_weight_pytorch).abs().max() 86 | zero_grad_mask = grad_attn_weight_pytorch == 0 87 | max_rel_err = ((grad_attn_weight_cuda - grad_attn_weight_pytorch).abs() / grad_attn_weight_pytorch.abs())[~zero_grad_mask].max() 88 | if zero_grad_mask.sum() == 0: 89 | max_abs_err_0 = 0 90 | else: 91 | max_abs_err_0 = (grad_attn_weight_cuda - grad_attn_weight_pytorch).abs()[zero_grad_mask].max() 92 | print(f'* {bwdok} check_backward_equal_with_pytorch - input3: ' 93 | f'max_abs_err {max_abs_err:.2e} ' 94 | f'max_rel_err {max_rel_err:.2e} ' 95 | f'max_abs_err_0 {max_abs_err_0:.2e}') 96 | 97 | 98 | def check_gradient_ms_deform_attn( 99 | use_pytorch=False, 100 | grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 101 | 102 | value = torch.rand(N, S, M, D).cuda() * 0.01 103 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 104 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 105 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 106 | im2col_step = 2 107 | if use_pytorch: 108 | func = ms_deform_attn_core_pytorch 109 | else: 110 | func = MSDeformAttnFunction.apply 111 | 112 | value.requires_grad = grad_value 113 | sampling_locations.requires_grad = grad_sampling_loc 114 | attention_weights.requires_grad = grad_attn_weight 115 | 116 | eps = 1e-3 if not grad_sampling_loc else 2e-4 117 | if use_pytorch: 118 | gradok = gradcheck(func, (value, shapes, sampling_locations, attention_weights), 119 | eps=eps, atol=1e-3, rtol=1e-2, raise_exception=True) 120 | else: 121 | gradok = gradcheck(func, (value, shapes, sampling_locations, attention_weights, im2col_step), 122 | eps=eps, atol=1e-3, rtol=1e-2, raise_exception=True) 123 | 124 | print(f'* {gradok} ' 125 | f'check_gradient_ms_deform_attn(' 126 | f'{use_pytorch}, {grad_value}, {grad_sampling_loc}, {grad_attn_weight})') 127 | 128 | 129 | if __name__ == '__main__': 130 | print('checking forward') 131 | check_forward_equal_with_pytorch() 132 | 133 | print('checking backward') 134 | check_backward_equal_with_pytorch() 135 | 136 | print('checking gradient of pytorch version') 137 | check_gradient_ms_deform_attn(True, True, False, False) 138 | check_gradient_ms_deform_attn(True, False, True, False) 139 | check_gradient_ms_deform_attn(True, False, False, True) 140 | check_gradient_ms_deform_attn(True, True, True, True) 141 | 142 | print('checking gradient of cuda version') 143 | check_gradient_ms_deform_attn(False, True, False, False) 144 | check_gradient_ms_deform_attn(False, False, True, False) 145 | check_gradient_ms_deform_attn(False, False, False, True) 146 | check_gradient_ms_deform_attn(False, True, True, True) 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- 
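As a quick way to exercise the module and kernels above, here is a minimal usage sketch of MSDeformAttn from modules/ms_deform_attn.py. It assumes the MultiScaleDeformableAttention extension has been built via make.sh, that a CUDA device is available, and that src/ is on the Python path; the import path, tensor sizes, and feature-map shapes below are illustrative assumptions, not values taken from the repository.

# Minimal MSDeformAttn usage sketch (assumptions: extension built via make.sh,
# CUDA device available, `src/` on PYTHONPATH so the package imports as shown).
import torch
from trackformer.models.ops.modules import MSDeformAttn

N, Len_q, d_model, n_levels = 2, 100, 256, 4
spatial_shapes = torch.as_tensor([(32, 32), (16, 16), (8, 8), (4, 4)], dtype=torch.long).cuda()
Len_in = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # total flattened feature length

attn = MSDeformAttn(d_model=d_model, n_levels=n_levels, n_heads=8, n_points=4).cuda()

query = torch.rand(N, Len_q, d_model).cuda()                  # (N, Length_{query}, C)
reference_points = torch.rand(N, Len_q, n_levels, 2).cuda()   # normalized (x, y) in [0, 1]
input_flatten = torch.rand(N, Len_in, d_model).cuda()         # concatenated multi-scale feature maps

output = attn(query, reference_points, input_flatten, spatial_shapes)
print(output.shape)  # torch.Size([2, 100, 256])

This is the same call pattern that test.py verifies: the pure-PyTorch reference ms_deform_attn_core_pytorch and the compiled MSDeformAttnFunction are checked against each other for forward outputs and gradients.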
/src/trackformer/models/ops/test_double_precision.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import gradcheck 10 | 11 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 12 | 13 | 14 | N, M, D = 2, 2, 4 15 | Lq, L, P = 3, 3, 2 16 | shapes = torch.as_tensor([(12, 8), (6, 4), (3, 2)], dtype=torch.long).cuda() 17 | S = sum([(H*W).item() for H, W in shapes]) 18 | 19 | torch.manual_seed(3) 20 | 21 | @torch.no_grad() 22 | def check_forward_equal_with_pytorch(): 23 | value = torch.rand(N, S, M, D).cuda() * 0.01 24 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 25 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 26 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 27 | im2col_step = 2 28 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 29 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 30 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 31 | max_abs_err = (output_cuda - output_pytorch).abs().max() 32 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 33 | 34 | print(f'* {fwdok} check_forward_equal_with_pytorch: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 35 | 36 | 37 | def check_backward_equal_with_pytorch(): 38 | value = torch.rand(N, S, M, D).cuda() * 0.01 39 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 40 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 41 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 42 | im2col_step = 2 43 | value.requires_grad = True 44 | sampling_locations.requires_grad = True 45 | attention_weights.requires_grad = True 46 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()) 47 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, sampling_locations.double(), attention_weights.double(), im2col_step) 48 | loss_pytorch = output_pytorch.abs().sum() 49 | loss_cuda = output_cuda.abs().sum() 50 | 51 | grad_value_pytorch = torch.autograd.grad(loss_pytorch, value, retain_graph=True)[0].detach().cpu() 52 | grad_value_cuda = torch.autograd.grad(loss_cuda, value, retain_graph=True)[0].detach().cpu() 53 | bwdok = torch.allclose(grad_value_cuda, grad_value_pytorch, rtol=1e-2, atol=1e-3) 54 | max_abs_err = (grad_value_cuda - grad_value_pytorch).abs().max() 55 | zero_grad_mask = grad_value_pytorch == 0 56 | max_rel_err = ((grad_value_cuda - grad_value_pytorch).abs() / grad_value_pytorch.abs())[~zero_grad_mask].max() 57 | if zero_grad_mask.sum() == 0: 58 | max_abs_err_0 = 0 59 | else: 60 | max_abs_err_0 = (grad_value_cuda - grad_value_pytorch).abs()[zero_grad_mask].max() 61 | print(f'* {bwdok} check_backward_equal_with_pytorch - input1: ' 62 | f'max_abs_err {max_abs_err:.2e} ' 63 | f'max_rel_err {max_rel_err:.2e} ' 64 | f'max_abs_err_0 {max_abs_err_0:.2e}') 65 | 66 | grad_sampling_loc_pytorch = torch.autograd.grad(loss_pytorch, sampling_locations, retain_graph=True)[0].detach().cpu() 67 | 
grad_sampling_loc_cuda = torch.autograd.grad(loss_cuda, sampling_locations, retain_graph=True)[0].detach().cpu() 68 | bwdok = torch.allclose(grad_sampling_loc_cuda, grad_sampling_loc_pytorch, rtol=1e-2, atol=1e-3) 69 | max_abs_err = (grad_sampling_loc_cuda - grad_sampling_loc_pytorch).abs().max() 70 | zero_grad_mask = grad_sampling_loc_pytorch == 0 71 | max_rel_err = ((grad_sampling_loc_cuda - grad_sampling_loc_pytorch).abs() / grad_sampling_loc_pytorch.abs())[~zero_grad_mask].max() 72 | if zero_grad_mask.sum() == 0: 73 | max_abs_err_0 = 0 74 | else: 75 | max_abs_err_0 = (grad_sampling_loc_cuda - grad_sampling_loc_pytorch).abs()[zero_grad_mask].max() 76 | print(f'* {bwdok} check_backward_equal_with_pytorch - input2: ' 77 | f'max_abs_err {max_abs_err:.2e} ' 78 | f'max_rel_err {max_rel_err:.2e} ' 79 | f'max_abs_err_0 {max_abs_err_0:.2e}') 80 | 81 | grad_attn_weight_pytorch = torch.autograd.grad(loss_pytorch, attention_weights, retain_graph=True)[0].detach().cpu() 82 | grad_attn_weight_cuda = torch.autograd.grad(loss_cuda, attention_weights, retain_graph=True)[0].detach().cpu() 83 | bwdok = torch.allclose(grad_attn_weight_cuda, grad_attn_weight_pytorch, rtol=1e-2, atol=1e-3) 84 | max_abs_err = (grad_attn_weight_cuda - grad_attn_weight_pytorch).abs().max() 85 | zero_grad_mask = grad_attn_weight_pytorch == 0 86 | max_rel_err = ((grad_attn_weight_cuda - grad_attn_weight_pytorch).abs() / grad_attn_weight_pytorch.abs())[~zero_grad_mask].max() 87 | if zero_grad_mask.sum() == 0: 88 | max_abs_err_0 = 0 89 | else: 90 | max_abs_err_0 = (grad_attn_weight_cuda - grad_attn_weight_pytorch).abs()[zero_grad_mask].max() 91 | print(f'* {bwdok} check_backward_equal_with_pytorch - input3: ' 92 | f'max_abs_err {max_abs_err:.2e} ' 93 | f'max_rel_err {max_rel_err:.2e} ' 94 | f'max_abs_err_0 {max_abs_err_0:.2e}') 95 | 96 | 97 | def check_gradient_ms_deform_attn( 98 | use_pytorch=False, 99 | grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 100 | 101 | value = torch.rand(N, S, M, D).cuda() * 0.01 102 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 103 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 104 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 105 | im2col_step = 2 106 | if use_pytorch: 107 | func = ms_deform_attn_core_pytorch 108 | else: 109 | func = MSDeformAttnFunction.apply 110 | 111 | value.requires_grad = grad_value 112 | sampling_locations.requires_grad = grad_sampling_loc 113 | attention_weights.requires_grad = grad_attn_weight 114 | 115 | eps = 1e-3 if not grad_sampling_loc else 2e-4 116 | if use_pytorch: 117 | gradok = gradcheck(func, (value.double(), shapes, sampling_locations.double(), attention_weights.double())) 118 | else: 119 | gradok = gradcheck(func, (value.double(), shapes, sampling_locations.double(), attention_weights.double(), im2col_step)) 120 | 121 | print(f'* {gradok} ' 122 | f'check_gradient_ms_deform_attn(' 123 | f'{use_pytorch}, {grad_value}, {grad_sampling_loc}, {grad_attn_weight})') 124 | 125 | 126 | if __name__ == '__main__': 127 | print('checking forward') 128 | check_forward_equal_with_pytorch() 129 | 130 | print('checking backward') 131 | check_backward_equal_with_pytorch() 132 | 133 | print('checking gradient of pytorch version') 134 | check_gradient_ms_deform_attn(True, True, False, False) 135 | check_gradient_ms_deform_attn(True, False, True, False) 136 | check_gradient_ms_deform_attn(True, False, False, True) 137 | check_gradient_ms_deform_attn(True, True, True, True) 138 | 139 | print('checking 
gradient of cuda version') 140 | check_gradient_ms_deform_attn(False, True, False, False) 141 | check_gradient_ms_deform_attn(False, False, True, False) 142 | check_gradient_ms_deform_attn(False, False, False, True) 143 | check_gradient_ms_deform_attn(False, True, True, True) 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /src/trackformer/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from ..util.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine3D(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | # def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | def __init__(self, num_pos_feats=64, num_frames=2, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | self.frames = num_frames 24 | 25 | if scale is not None and normalize is False: 26 | raise ValueError("normalize should be True if scale is passed") 27 | if scale is None: 28 | scale = 2 * math.pi 29 | self.scale = scale 30 | 31 | def forward(self, tensor_list: NestedTensor): 32 | x = tensor_list.tensors 33 | mask = tensor_list.mask 34 | n, h, w = mask.shape 35 | # assert n == 1 36 | # mask = mask.reshape(1, 1, h, w) 37 | mask = mask.view(n, 1, h, w) 38 | mask = mask.expand(n, self.frames, h, w) 39 | 40 | assert mask is not None 41 | not_mask = ~mask 42 | # y_embed = not_mask.cumsum(1, dtype=torch.float32) 43 | # x_embed = not_mask.cumsum(2, dtype=torch.float32) 44 | 45 | z_embed = not_mask.cumsum(1, dtype=torch.float32) 46 | y_embed = not_mask.cumsum(2, dtype=torch.float32) 47 | x_embed = not_mask.cumsum(3, dtype=torch.float32) 48 | 49 | if self.normalize: 50 | eps = 1e-6 51 | # y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 52 | # x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 53 | 54 | z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale 55 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale 56 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale 57 | 58 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 59 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 60 | 61 | # pos_x = x_embed[:, :, :, None] / dim_t 62 | # pos_y = y_embed[:, :, :, None] / dim_t 63 | # pos_x = torch.stack(( 64 | # pos_x[:, :, :, 0::2].sin(), 65 | # pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 66 | # pos_y = torch.stack(( 67 | # pos_y[:, :, :, 0::2].sin(), 68 | # pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 69 | # pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 70 | 71 | pos_x = x_embed[:, :, :, :, None] / dim_t 72 | pos_y = y_embed[:, :, :, :, None] / dim_t 73 | pos_z = z_embed[:, :, :, :, None] / dim_t 74 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 75 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 76 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, 
:, :, 1::2].cos()), dim=5).flatten(4) 77 | # pos_w = torch.zeros_like(pos_z) 78 | # pos = torch.cat((pos_w, pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) 79 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) 80 | 81 | return pos 82 | 83 | 84 | class PositionEmbeddingSine(nn.Module): 85 | """ 86 | This is a more standard version of the position embedding, very similar to the one 87 | used by the Attention is all you need paper, generalized to work on images. 88 | """ 89 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 90 | super().__init__() 91 | self.num_pos_feats = num_pos_feats 92 | self.temperature = temperature 93 | self.normalize = normalize 94 | if scale is not None and normalize is False: 95 | raise ValueError("normalize should be True if scale is passed") 96 | if scale is None: 97 | scale = 2 * math.pi 98 | self.scale = scale 99 | 100 | def forward(self, tensor_list: NestedTensor): 101 | x = tensor_list.tensors 102 | mask = tensor_list.mask 103 | assert mask is not None 104 | not_mask = ~mask 105 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 106 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 107 | if self.normalize: 108 | eps = 1e-6 109 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 110 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 111 | 112 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 113 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 114 | 115 | pos_x = x_embed[:, :, :, None] / dim_t 116 | pos_y = y_embed[:, :, :, None] / dim_t 117 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 118 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 119 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 120 | return pos 121 | 122 | 123 | class PositionEmbeddingLearned(nn.Module): 124 | """ 125 | Absolute pos embedding, learned. 
126 | """ 127 | def __init__(self, num_pos_feats=256): 128 | super().__init__() 129 | self.row_embed = nn.Embedding(50, num_pos_feats) 130 | self.col_embed = nn.Embedding(50, num_pos_feats) 131 | self.reset_parameters() 132 | 133 | def reset_parameters(self): 134 | nn.init.uniform_(self.row_embed.weight) 135 | nn.init.uniform_(self.col_embed.weight) 136 | 137 | def forward(self, tensor_list: NestedTensor): 138 | x = tensor_list.tensors 139 | h, w = x.shape[-2:] 140 | i = torch.arange(w, device=x.device) 141 | j = torch.arange(h, device=x.device) 142 | x_emb = self.col_embed(i) 143 | y_emb = self.row_embed(j) 144 | pos = torch.cat([ 145 | x_emb.unsqueeze(0).repeat(h, 1, 1), 146 | y_emb.unsqueeze(1).repeat(1, w, 1), 147 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 148 | return pos 149 | 150 | 151 | def build_position_encoding(args): 152 | # n_steps = args.hidden_dim // 2 153 | # n_steps = args.hidden_dim // 4 154 | if args.multi_frame_attention and args.multi_frame_encoding: 155 | n_steps = args.hidden_dim // 3 156 | sine_emedding_func = PositionEmbeddingSine3D 157 | else: 158 | n_steps = args.hidden_dim // 2 159 | sine_emedding_func = PositionEmbeddingSine 160 | 161 | if args.position_embedding in ('v2', 'sine'): 162 | # TODO find a better way of exposing other arguments 163 | position_embedding = sine_emedding_func(n_steps, normalize=True) 164 | elif args.position_embedding in ('v3', 'learned'): 165 | position_embedding = PositionEmbeddingLearned(n_steps) 166 | else: 167 | raise ValueError(f"not supported {args.position_embedding}") 168 | 169 | return position_embedding 170 | -------------------------------------------------------------------------------- /src/trackformer/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /src/trackformer/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /src/trackformer/util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 
3 | """ 4 | from pathlib import Path, PurePath 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | import seaborn as sns 10 | import torch 11 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 12 | 13 | 14 | def fig_to_numpy(fig): 15 | w, h = fig.get_size_inches() * fig.dpi 16 | w = int(w.item()) 17 | h = int(h.item()) 18 | canvas = FigureCanvas(fig) 19 | canvas.draw() 20 | numpy_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3) 21 | return np.copy(numpy_image) 22 | 23 | 24 | def get_vis_win_names(vis_dict): 25 | vis_win_names = { 26 | outer_k: { 27 | inner_k: inner_v.win 28 | for inner_k, inner_v in outer_v.items() 29 | } 30 | for outer_k, outer_v in vis_dict.items() 31 | } 32 | return vis_win_names 33 | 34 | 35 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 36 | ''' 37 | Function to plot specific fields from training log(s). Plots both training and test results. 38 | 39 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 40 | - fields = which results to plot from each log file - plots both training and test for each field. 41 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 42 | - log_name = optional, name of log file if different than default 'log.txt'. 43 | 44 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 45 | - solid lines are training results, dashed lines are test results. 46 | 47 | ''' 48 | func_name = "plot_utils.py::plot_logs" 49 | 50 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 51 | # convert single Path to list to avoid 'not iterable' error 52 | 53 | if not isinstance(logs, list): 54 | if isinstance(logs, PurePath): 55 | logs = [logs] 56 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 57 | else: 58 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 59 | Expect list[Path] or single Path obj, received {type(logs)}") 60 | 61 | # verify valid dir(s) and that every item in list is Path object 62 | for i, dir in enumerate(logs): 63 | if not isinstance(dir, PurePath): 64 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 65 | if dir.exists(): 66 | continue 67 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 68 | 69 | # load log file(s) and plot 70 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 71 | 72 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 73 | 74 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 75 | for j, field in enumerate(fields): 76 | if field == 'mAP': 77 | coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() 78 | axs[j].plot(coco_eval, c=color) 79 | else: 80 | df.interpolate().ewm(com=ewm_col).mean().plot( 81 | y=[f'train_{field}', f'test_{field}'], 82 | ax=axs[j], 83 | color=[color] * 2, 84 | style=['-', '--'] 85 | ) 86 | for ax, field in zip(axs, fields): 87 | ax.legend([Path(p).name for p in logs]) 88 | ax.set_title(field) 89 | 90 | 91 | def plot_precision_recall(files, naming_scheme='iter'): 92 | if naming_scheme == 'exp_id': 93 | # name becomes exp_id 94 | names = [f.parts[-3] for f in files] 95 | elif naming_scheme == 'iter': 96 | names = [f.stem for f in files] 97 | else: 98 | raise ValueError(f'not 
supported {naming_scheme}') 99 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 100 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 101 | data = torch.load(f) 102 | # precision is n_iou, n_points, n_cat, n_area, max_det 103 | precision = data['precision'] 104 | recall = data['params'].recThrs 105 | scores = data['scores'] 106 | # take precision for all classes, all areas and 100 detections 107 | precision = precision[0, :, :, 0, -1].mean(1) 108 | scores = scores[0, :, :, 0, -1].mean(1) 109 | prec = precision.mean() 110 | rec = data['recall'][0, :, 0, -1].mean() 111 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 112 | f'score={scores.mean():0.3f}, ' + 113 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 114 | ) 115 | axs[0].plot(recall, precision, c=color) 116 | axs[1].plot(recall, scores, c=color) 117 | 118 | axs[0].set_title('Precision / Recall') 119 | axs[0].legend(names) 120 | axs[1].set_title('Scores / Recall') 121 | axs[1].legend(names) 122 | return fig, axs 123 | --------------------------------------------------------------------------------
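Finally, a short sketch of how the plotting helpers above are typically invoked. The run directories and evaluation files named here are hypothetical; plot_logs expects each directory to contain a JSON-lines log.txt with the train_*/test_* fields it parses via pd.read_json(..., lines=True), and plot_precision_recall expects torch-saved COCO evaluation dumps with 'precision', 'recall', 'scores' and 'params' entries, as read in the functions above.

# Usage sketch for the plotting utilities (all paths below are hypothetical examples,
# and the import assumes `src/` is on the Python path).
from pathlib import Path
import matplotlib.pyplot as plt
from trackformer.util.plot_utils import plot_logs, plot_precision_recall

runs = [Path('models/run_a'), Path('models/run_b')]           # each must contain a log.txt
plot_logs(runs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=1)

eval_files = sorted(Path('models/run_a/eval').glob('*.pth'))  # torch-saved COCO eval results
if eval_files:
    plot_precision_recall(eval_files, naming_scheme='iter')

plt.show()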