├── .gitignore ├── LICENSE ├── README.md ├── benchmark.py ├── configs ├── r50_deformable_detr.sh ├── r50_deformable_detr_plus_iterative_bbox_refinement.sh ├── r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh ├── r50_deformable_detr_single_scale.sh ├── r50_deformable_detr_single_scale_dc5.sh ├── r50_motr_demo.sh ├── r50_motr_eval.sh ├── r50_motr_submit.sh ├── r50_motr_submit_dance.sh ├── r50_motr_train.sh └── r50_motr_train_dance.sh ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── coco.cpython-36.pyc │ ├── coco_eval.cpython-36.pyc │ ├── data_prefetcher.cpython-36.pyc │ ├── detmot.cpython-36.pyc │ ├── joint.cpython-36.pyc │ ├── panoptic_eval.cpython-36.pyc │ ├── samplers.cpython-36.pyc │ ├── static_detmot.cpython-36.pyc │ └── transforms.cpython-36.pyc ├── coco.py ├── coco_eval.py ├── coco_panoptic.py ├── dance.py ├── data_path │ ├── bdd100k.train │ ├── bdd100k.val │ ├── crowdhuman.train │ ├── crowdhuman.val │ ├── detmot16.train │ ├── detmot17.train │ ├── gen_bdd100k_mot.py │ ├── gen_labels_15.py │ ├── gen_labels_16.py │ ├── joint.train │ ├── mot16.train │ ├── mot17.train │ └── prepare.py ├── data_prefetcher.py ├── detmot.py ├── joint.py ├── panoptic_eval.py ├── samplers.py ├── static_detmot.py ├── torchvision_datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── coco.cpython-36.pyc │ └── coco.py └── transforms.py ├── demo.py ├── engine.py ├── eval.py ├── figs ├── demo.avi └── motr.png ├── main.py ├── models ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── backbone.cpython-36.pyc │ ├── deformable_detr.cpython-36.pyc │ ├── deformable_transformer.cpython-36.pyc │ ├── deformable_transformer_plus.cpython-36.pyc │ ├── matcher.cpython-36.pyc │ ├── memory_bank.cpython-36.pyc │ ├── motr.cpython-36.pyc │ ├── position_encoding.cpython-36.pyc │ ├── qim.cpython-36.pyc │ └── segmentation.cpython-36.pyc ├── backbone.py ├── deformable_detr.py ├── deformable_transformer.py ├── deformable_transformer_plus.py ├── matcher.py ├── memory_bank.py ├── motr.py ├── ops │ ├── functions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── ms_deform_attn_func.cpython-36.pyc │ │ └── ms_deform_attn_func.py │ ├── make.sh │ ├── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── ms_deform_attn.cpython-36.pyc │ │ └── ms_deform_attn.py │ ├── setup.py │ ├── src │ │ ├── cpu │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn_cpu.h │ │ ├── cuda │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_im2col_cuda.cuh │ │ ├── ms_deform_attn.h │ │ └── vision.cpp │ └── test.py ├── position_encoding.py ├── qim.py ├── relu_dropout.py ├── segmentation.py └── structures │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── boxes.cpython-36.pyc │ └── instances.cpython-36.pyc │ ├── boxes.py │ └── instances.py ├── requirements.txt ├── submit.py ├── submit_dance.py ├── tools ├── launch.py ├── run_dist_launch.sh └── run_dist_slurm.sh └── util ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── box_ops.cpython-36.pyc ├── evaluation.cpython-36.pyc ├── misc.cpython-36.pyc ├── motdet_eval.cpython-36.pyc ├── plot_utils.cpython-36.pyc └── tool.cpython-36.pyc ├── box_ops.py ├── checkpoint.py ├── evaluation.py ├── misc.py ├── motdet_eval.py ├── plot_utils.py └── tool.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pth 3 | *.train 4 | exps/ 5 | build/ 6 | *.egg 7 | *.egg-info 8 | *.mp4 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOTR: End-to-End Multiple-Object Tracking with TRansformer 2 | 3 | 4 | 5 | 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motr-end-to-end-multiple-object-tracking-with/multi-object-tracking-on-mot17)](https://paperswithcode.com/sota/multi-object-tracking-on-mot17?p=motr-end-to-end-multiple-object-tracking-with) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motr-end-to-end-multiple-object-tracking-with/multi-object-tracking-on-mot16)](https://paperswithcode.com/sota/multi-object-tracking-on-mot16?p=motr-end-to-end-multiple-object-tracking-with) 8 | 9 | 10 | 11 | This repository is an official implementation of the paper [MOTR: End-to-End Multiple-Object Tracking with TRansformer](https://arxiv.org/pdf/2105.03247.pdf). 12 | 13 | ## Introduction 14 | 15 | **TL; DR.** MOTR is a fully end-to-end multiple-object tracking framework based on Transformer. It directly outputs the tracks within the video sequences without any association procedures. 16 | 17 |
18 | 19 |
20 | 21 | **Abstract.** The key challenge in multiple-object tracking task is temporal modeling of the object under track. Existing tracking-by-detection methods adopt simple heuristics, such as spatial or appearance similarity. Such methods, in spite of their commonality, are overly simple and lack the ability to learn temporal variations from data in an end-to-end manner.In this paper, we present MOTR, a fully end-to-end multiple-object tracking framework. It learns to model the long-range temporal variation of the objects. It performs temporal association implicitly and avoids previous explicit heuristics. Built upon DETR, MOTR introduces the concept of "track query". Each track query models the entire track of an object. It is transferred and updated frame-by-frame to perform iterative predictions in a seamless manner. Tracklet-aware label assignment is proposed for one-to-one assignment between track queries and object tracks. Temporal aggregation network together with collective average loss is further proposed to enhance the long-range temporal relation. Experimental results show that MOTR achieves competitive performance and can serve as a strong Transformer-based baseline for future research. 22 | 23 | ## Updates 24 | - (2021/09/23) Report BDD100K results and release corresponding codes [motr_bdd100k](https://github.com/megvii-model/MOTR/tree/motr_bdd100k). 25 | - (2022/02/09) Higher performance achieved by not clipping the bounding boxes inside the image. 26 | - (2022/02/11) Add checkpoint support for training on RTX 2080ti. 27 | - (2022/02/11) Report [DanceTrack](https://github.com/DanceTrack/DanceTrack) results and [scripts](configs/r50_motr_train_dance.sh). 28 | - (2022/05/12) Higher performance achieved by removing the public detection filtering (filter_pub_det) trick. 29 | - (2022/07/04) MOTR is accepted by ECCV 2022. 30 | 31 | ## Main Results 32 | 33 | ### MOT17 34 | 35 | | **Method** | **Dataset** | **Train Data** | **HOTA** | **DetA** | **AssA** | **MOTA** | **IDF1** | **IDS** | **URL** | 36 | | :--------: | :---------: | :------------------: | :------: | :------: | :------: | :------: | :------: | :-----: | :-----------------------------------------------------------------------------------------: | 37 | | MOTR | MOT17 | MOT17+CrowdHuman Val | 57.8 | 60.3 | 55.7 | 73.4 | 68.6 | 2439 | [model](https://drive.google.com/file/d/1K9AbtzTCBNsOD8LYA1k16kf4X0uJi8PC/view?usp=sharing) | 38 | 39 | ### DanceTrack 40 | 41 | | **Method** | **Dataset** | **Train Data** | **HOTA** | **DetA** | **AssA** | **MOTA** | **IDF1** | **URL** | 42 | | :--------: | :---------: | :------------: | :------: | :------: | :------: | :------: | :------: | :-----------------------------------------------------------------------------------------: | 43 | | MOTR | DanceTrack | DanceTrack | 54.2 | 73.5 | 40.2 | 79.7 | 51.5 | [model](https://drive.google.com/file/d/1zs5o1oK8diafVfewRl3heSVQ7-XAty3J/view?usp=sharing) | 44 | 45 | ### BDD100K 46 | 47 | | **Method** | **Dataset** | **Train Data** | **MOTA** | **IDF1** | **IDS** | **URL** | 48 | | :--------: | :---------: | :------------: | :------: | :------: | :-----: | :-----------------------------------------------------------------------------------------: | 49 | | MOTR | BDD100K | BDD100K | 32.0 | 43.5 | 3493 | [model](https://drive.google.com/file/d/13fsTj9e6Hk7qVcybWi1X5KbZEsFCHa6e/view?usp=sharing) | 50 | 51 | *Note:* 52 | 53 | 1. MOTR on MOT17 and DanceTrack is trained on 8 NVIDIA RTX 2080ti GPUs. 54 | 2. The training time for MOT17 is about 2.5 days on V100 or 4 days on RTX 2080ti; 55 | 3. The inference speed is about 7.5 FPS for resolution 1536x800; 56 | 4. All models of MOTR are trained with ResNet50 with pre-trained weights on COCO dataset. 57 | 58 | 59 | ## Installation 60 | 61 | The codebase is built on top of [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR). 62 | 63 | ### Requirements 64 | 65 | * Linux, CUDA>=9.2, GCC>=5.4 66 | 67 | * Python>=3.7 68 | 69 | We recommend you to use Anaconda to create a conda environment: 70 | ```bash 71 | conda create -n deformable_detr python=3.7 pip 72 | ``` 73 | Then, activate the environment: 74 | ```bash 75 | conda activate deformable_detr 76 | ``` 77 | 78 | * PyTorch>=1.5.1, torchvision>=0.6.1 (following instructions [here](https://pytorch.org/)) 79 | 80 | For example, if your CUDA version is 9.2, you could install pytorch and torchvision as following: 81 | ```bash 82 | conda install pytorch=1.5.1 torchvision=0.6.1 cudatoolkit=9.2 -c pytorch 83 | ``` 84 | 85 | * Other requirements 86 | ```bash 87 | pip install -r requirements.txt 88 | ``` 89 | 90 | * Build MultiScaleDeformableAttention 91 | ```bash 92 | cd ./models/ops 93 | sh ./make.sh 94 | ``` 95 | 96 | ## Usage 97 | 98 | ### Dataset preparation 99 | 100 | 1. Please download [MOT17 dataset](https://motchallenge.net/) and [CrowdHuman dataset](https://www.crowdhuman.org/) and organize them like [FairMOT](https://github.com/ifzhang/FairMOT) as following: 101 | 102 | ``` 103 | . 104 | ├── crowdhuman 105 | │   ├── images 106 | │   └── labels_with_ids 107 | ├── MOT15 108 | │   ├── images 109 | │   ├── labels_with_ids 110 | │   ├── test 111 | │   └── train 112 | ├── MOT17 113 | │   ├── images 114 | │   ├── labels_with_ids 115 | ├── DanceTrack 116 | │   ├── train 117 | │   ├── test 118 | ├── bdd100k 119 | │   ├── images 120 | │ ├── track 121 | │ ├── train 122 | │ ├── val 123 | │   ├── labels 124 | │ ├── track 125 | │ ├── train 126 | │ ├── val 127 | 128 | ``` 129 | 130 | 2. For BDD100K dataset, you can use the following script to generate txt file: 131 | 132 | 133 | ```bash 134 | cd datasets/data_path 135 | python3 generate_bdd100k_mot.py 136 | cd ../../ 137 | ``` 138 | 139 | ### Training and Evaluation 140 | 141 | #### Training on single node 142 | 143 | You can download COCO pretrained weights from [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR). Then training MOTR on 8 GPUs as following: 144 | 145 | ```bash 146 | sh configs/r50_motr_train.sh 147 | 148 | ``` 149 | 150 | #### Evaluation on MOT15 151 | 152 | You can download the pretrained model of MOTR (the link is in "Main Results" session), then run following command to evaluate it on MOT15 train dataset: 153 | 154 | ```bash 155 | sh configs/r50_motr_eval.sh 156 | 157 | ``` 158 | 159 | For visual in demo video, you can enable 'vis=True' in eval.py like: 160 | ```bash 161 | det.detect(vis=True) 162 | 163 | ``` 164 | 165 | #### Evaluation on MOT17 166 | 167 | You can download the pretrained model of MOTR (the link is in "Main Results" session), then run following command to evaluate it on MOT17 test dataset (submit to server): 168 | 169 | ```bash 170 | sh configs/r50_motr_submit.sh 171 | 172 | ``` 173 | #### Evaluation on BDD100K 174 | 175 | For BDD100K dataset, please refer [motr_bdd100k](https://github.com/megvii-model/MOTR/tree/motr_bdd100k). 176 | 177 | 178 | #### Test on Video Demo 179 | 180 | We also provide a demo interface which allows for a quick processing of a given video. 181 | 182 | ```bash 183 | EXP_DIR=exps/e2e_motr_r50_joint 184 | python3 demo.py \ 185 | --meta_arch motr \ 186 | --dataset_file e2e_joint \ 187 | --epoch 200 \ 188 | --with_box_refine \ 189 | --lr_drop 100 \ 190 | --lr 2e-4 \ 191 | --lr_backbone 2e-5 \ 192 | --pretrained ${EXP_DIR}/motr_final.pth \ 193 | --output_dir ${EXP_DIR} \ 194 | --batch_size 1 \ 195 | --sample_mode 'random_interval' \ 196 | --sample_interval 10 \ 197 | --sampler_steps 50 90 120 \ 198 | --sampler_lengths 2 3 4 5 \ 199 | --update_query_pos \ 200 | --merger_dropout 0 \ 201 | --dropout 0 \ 202 | --random_drop 0.1 \ 203 | --fp_ratio 0.3 \ 204 | --query_interaction_layer 'QIM' \ 205 | --extra_track_attn \ 206 | --resume ${EXP_DIR}/motr_final.pth \ 207 | --input_video figs/demo.avi 208 | ``` 209 | 210 | ## Citing MOTR 211 | If you find MOTR useful in your research, please consider citing: 212 | ```bibtex 213 | @inproceedings{zeng2021motr, 214 | title={MOTR: End-to-End Multiple-Object Tracking with TRansformer}, 215 | author={Zeng, Fangao and Dong, Bin and Zhang, Yuang and Wang, Tiancai and Zhang, Xiangyu and Wei, Yichen}, 216 | booktitle={European Conference on Computer Vision (ECCV)}, 217 | year={2022} 218 | } 219 | ``` 220 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | Benchmark inference speed of Deformable DETR. 14 | """ 15 | import os 16 | import time 17 | import argparse 18 | 19 | import torch 20 | 21 | from main import get_args_parser as get_main_args_parser 22 | from models import build_model 23 | from datasets import build_dataset 24 | from util.misc import nested_tensor_from_tensor_list 25 | 26 | 27 | def get_benckmark_arg_parser(): 28 | parser = argparse.ArgumentParser('Benchmark inference speed of Deformable DETR.') 29 | parser.add_argument('--num_iters', type=int, default=300, help='total iters to benchmark speed') 30 | parser.add_argument('--warm_iters', type=int, default=5, help='ignore first several iters that are very slow') 31 | parser.add_argument('--batch_size', type=int, default=1, help='batch size in inference') 32 | parser.add_argument('--resume', type=str, help='load the pre-trained checkpoint') 33 | return parser 34 | 35 | 36 | @torch.no_grad() 37 | def measure_average_inference_time(model, inputs, num_iters=100, warm_iters=5): 38 | ts = [] 39 | for iter_ in range(num_iters): 40 | torch.cuda.synchronize() 41 | t_ = time.perf_counter() 42 | model(inputs) 43 | torch.cuda.synchronize() 44 | t = time.perf_counter() - t_ 45 | if iter_ >= warm_iters: 46 | ts.append(t) 47 | print(ts) 48 | return sum(ts) / len(ts) 49 | 50 | 51 | def benchmark(): 52 | args, _ = get_benckmark_arg_parser().parse_known_args() 53 | main_args = get_main_args_parser().parse_args(_) 54 | assert args.warm_iters < args.num_iters and args.num_iters > 0 and args.warm_iters >= 0 55 | assert args.batch_size > 0 56 | assert args.resume is None or os.path.exists(args.resume) 57 | dataset = build_dataset('val', main_args) 58 | model, _, _ = build_model(main_args) 59 | model.cuda() 60 | model.eval() 61 | if args.resume is not None: 62 | ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage) 63 | model.load_state_dict(ckpt['model']) 64 | inputs = nested_tensor_from_tensor_list([dataset.__getitem__(0)[0].cuda() for _ in range(args.batch_size)]) 65 | t = measure_average_inference_time(model, inputs, args.num_iters, args.warm_iters) 66 | return 1.0 / t * args.batch_size 67 | 68 | 69 | if __name__ == '__main__': 70 | fps = benchmark() 71 | print(f'Inference Speed: {fps:.1f} FPS') 72 | 73 | -------------------------------------------------------------------------------- /configs/r50_deformable_detr.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------ 3 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | set -x 10 | 11 | EXP_DIR=exps/r50_deformable_detr 12 | PY_ARGS=${@:1} 13 | 14 | python -u main.py \ 15 | --output_dir ${EXP_DIR} \ 16 | ${PY_ARGS} 17 | -------------------------------------------------------------------------------- /configs/r50_deformable_detr_plus_iterative_bbox_refinement.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------ 3 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | set -x 10 | 11 | EXP_DIR=exps/r50_deformable_detr_plus_iterative_bbox_refinement 12 | PY_ARGS=${@:1} 13 | 14 | python -u main.py \ 15 | --output_dir ${EXP_DIR} \ 16 | --with_box_refine \ 17 | ${PY_ARGS} 18 | -------------------------------------------------------------------------------- /configs/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------ 3 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | set -x 10 | 11 | EXP_DIR=exps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage 12 | PY_ARGS=${@:1} 13 | 14 | python -u main.py \ 15 | --output_dir ${EXP_DIR} \ 16 | --with_box_refine \ 17 | --two_stage \ 18 | ${PY_ARGS} 19 | -------------------------------------------------------------------------------- /configs/r50_deformable_detr_single_scale.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------ 3 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | set -x 10 | 11 | EXP_DIR=exps/r50_deformable_detr_single_scale 12 | PY_ARGS=${@:1} 13 | 14 | python -u main.py \ 15 | --num_feature_levels 1 \ 16 | --output_dir ${EXP_DIR} \ 17 | ${PY_ARGS} 18 | -------------------------------------------------------------------------------- /configs/r50_deformable_detr_single_scale_dc5.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------ 3 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | set -x 10 | 11 | EXP_DIR=exps/r50_deformable_detr_single_scale_dc5 12 | PY_ARGS=${@:1} 13 | 14 | python -u main.py \ 15 | --num_feature_levels 1 \ 16 | --dilation \ 17 | --output_dir ${EXP_DIR} \ 18 | ${PY_ARGS} 19 | -------------------------------------------------------------------------------- /configs/r50_motr_demo.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | 8 | EXP_DIR=exps/e2e_motr_r50_joint 9 | python3 demo.py \ 10 | --meta_arch motr \ 11 | --dataset_file e2e_joint \ 12 | --epoch 200 \ 13 | --with_box_refine \ 14 | --lr_drop 100 \ 15 | --lr 2e-4 \ 16 | --lr_backbone 2e-5 \ 17 | --pretrained ${EXP_DIR}/motr_final.pth \ 18 | --output_dir ${EXP_DIR} \ 19 | --batch_size 1 \ 20 | --sample_mode 'random_interval' \ 21 | --sample_interval 10 \ 22 | --sampler_steps 50 90 120 \ 23 | --sampler_lengths 2 3 4 5 \ 24 | --update_query_pos \ 25 | --merger_dropout 0 \ 26 | --dropout 0 \ 27 | --random_drop 0.1 \ 28 | --fp_ratio 0.3 \ 29 | --query_interaction_layer 'QIM' \ 30 | --extra_track_attn \ 31 | --resume ${EXP_DIR}/motr_final.pth \ 32 | --input_video figs/demo.avi -------------------------------------------------------------------------------- /configs/r50_motr_eval.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | 8 | # for MOT17 9 | 10 | # EXP_DIR=exps/e2e_motr_r50_joint 11 | # python3 eval.py \ 12 | # --meta_arch motr \ 13 | # --dataset_file e2e_joint \ 14 | # --epoch 200 \ 15 | # --with_box_refine \ 16 | # --lr_drop 100 \ 17 | # --lr 2e-4 \ 18 | # --lr_backbone 2e-5 \ 19 | # --pretrained ${EXP_DIR}/motr_final.pth \ 20 | # --output_dir ${EXP_DIR} \ 21 | # --batch_size 1 \ 22 | # --sample_mode 'random_interval' \ 23 | # --sample_interval 10 \ 24 | # --sampler_steps 50 90 120 \ 25 | # --sampler_lengths 2 3 4 5 \ 26 | # --update_query_pos \ 27 | # --merger_dropout 0 \ 28 | # --dropout 0 \ 29 | # --random_drop 0.1 \ 30 | # --fp_ratio 0.3 \ 31 | # --query_interaction_layer 'QIM' \ 32 | # --extra_track_attn \ 33 | # --data_txt_path_train ./datasets/data_path/joint.train \ 34 | # --data_txt_path_val ./datasets/data_path/mot17.train \ 35 | # --resume ${EXP_DIR}/motr_final.pth \ -------------------------------------------------------------------------------- /configs/r50_motr_submit.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | 8 | EXP_DIR=exps/e2e_motr_r50_joint 9 | python3 submit.py \ 10 | --meta_arch motr \ 11 | --dataset_file e2e_joint \ 12 | --epoch 200 \ 13 | --with_box_refine \ 14 | --lr_drop 100 \ 15 | --lr 2e-4 \ 16 | --lr_backbone 2e-5 \ 17 | --pretrained ${EXP_DIR}/motr_final.pth \ 18 | --output_dir ${EXP_DIR} \ 19 | --batch_size 1 \ 20 | --sample_mode 'random_interval' \ 21 | --sample_interval 10 \ 22 | --sampler_steps 50 90 150 \ 23 | --sampler_lengths 2 3 4 5 \ 24 | --update_query_pos \ 25 | --merger_dropout 0 \ 26 | --dropout 0 \ 27 | --random_drop 0.1 \ 28 | --fp_ratio 0.3 \ 29 | --query_interaction_layer 'QIM' \ 30 | --extra_track_attn \ 31 | --data_txt_path_train ./datasets/data_path/joint.train \ 32 | --data_txt_path_val ./datasets/data_path/mot17.train \ 33 | --resume ${EXP_DIR}/motr_final.pth \ 34 | --exp_name pub_submit_17 -------------------------------------------------------------------------------- /configs/r50_motr_submit_dance.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | 8 | EXP_DIR=exps/e2e_motr_r50_dance 9 | python3 submit_dance.py \ 10 | --meta_arch motr \ 11 | --dataset_file e2e_joint \ 12 | --mot_path /data/datasets \ 13 | --epoch 200 \ 14 | --with_box_refine \ 15 | --lr_drop 100 \ 16 | --lr 2e-4 \ 17 | --lr_backbone 2e-5 \ 18 | --output_dir ${EXP_DIR} \ 19 | --batch_size 1 \ 20 | --sample_mode 'random_interval' \ 21 | --sample_interval 10 \ 22 | --sampler_steps 50 90 150 \ 23 | --sampler_lengths 2 3 4 5 \ 24 | --update_query_pos \ 25 | --merger_dropout 0 \ 26 | --dropout 0 \ 27 | --random_drop 0.1 \ 28 | --fp_ratio 0.3 \ 29 | --query_interaction_layer 'QIM' \ 30 | --extra_track_attn \ 31 | --data_txt_path_train ./datasets/data_path/joint.train \ 32 | --data_txt_path_val ./datasets/data_path/mot17.train \ 33 | --resume ${EXP_DIR}/checkpoint.pth \ 34 | --exp_name tracker 35 | -------------------------------------------------------------------------------- /configs/r50_motr_train.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | 8 | 9 | # for MOT17 10 | 11 | # PRETRAIN=coco_model_final.pth 12 | # EXP_DIR=exps/e2e_motr_r50_joint 13 | # python3 -m torch.distributed.launch --nproc_per_node=8 \ 14 | # --use_env main.py \ 15 | # --meta_arch motr \ 16 | # --use_checkpoint \ 17 | # --dataset_file e2e_joint \ 18 | # --epoch 200 \ 19 | # --with_box_refine \ 20 | # --lr_drop 100 \ 21 | # --lr 2e-4 \ 22 | # --lr_backbone 2e-5 \ 23 | # --pretrained ${PRETRAIN} \ 24 | # --output_dir ${EXP_DIR} \ 25 | # --batch_size 1 \ 26 | # --sample_mode 'random_interval' \ 27 | # --sample_interval 10 \ 28 | # --sampler_steps 50 90 150 \ 29 | # --sampler_lengths 2 3 4 5 \ 30 | # --update_query_pos \ 31 | # --merger_dropout 0 \ 32 | # --dropout 0 \ 33 | # --random_drop 0.1 \ 34 | # --fp_ratio 0.3 \ 35 | # --query_interaction_layer 'QIM' \ 36 | # --extra_track_attn \ 37 | # --data_txt_path_train ./datasets/data_path/joint.train \ 38 | # --data_txt_path_val ./datasets/data_path/mot17.train \ -------------------------------------------------------------------------------- /configs/r50_motr_train_dance.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | 8 | 9 | # for MOT17 10 | 11 | PRETRAIN=r50_deformable_detr_plus_iterative_bbox_refinement-checkpoint.pth 12 | EXP_DIR=exps/e2e_motr_r50_dance 13 | python3 -m torch.distributed.launch --nproc_per_node=8 \ 14 | --use_env main.py \ 15 | --meta_arch motr \ 16 | --use_checkpoint \ 17 | --dataset_file e2e_dance \ 18 | --epoch 20 \ 19 | --with_box_refine \ 20 | --lr_drop 10 \ 21 | --lr 2e-4 \ 22 | --lr_backbone 2e-5 \ 23 | --pretrained ${PRETRAIN} \ 24 | --output_dir ${EXP_DIR} \ 25 | --batch_size 1 \ 26 | --sample_mode 'random_interval' \ 27 | --sample_interval 10 \ 28 | --sampler_steps 5 9 15 \ 29 | --sampler_lengths 2 3 4 5 \ 30 | --update_query_pos \ 31 | --merger_dropout 0 \ 32 | --dropout 0 \ 33 | --random_drop 0.1 \ 34 | --fp_ratio 0.3 \ 35 | --query_interaction_layer 'QIM' \ 36 | --extra_track_attn \ 37 | --data_txt_path_train ./datasets/data_path/joint.train \ 38 | --data_txt_path_val ./datasets/data_path/mot17.train \ 39 | |& tee ${EXP_DIR}/output.log 40 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | import torch.utils.data 12 | import torchvision 13 | 14 | from .coco import build as build_coco 15 | from .detmot import build as build_e2e_mot 16 | from .dance import build as build_e2e_dance 17 | from .static_detmot import build as build_e2e_static_mot 18 | from .joint import build as build_e2e_joint 19 | from .torchvision_datasets import CocoDetection 20 | 21 | def get_coco_api_from_dataset(dataset): 22 | for _ in range(10): 23 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 24 | # break 25 | if isinstance(dataset, torch.utils.data.Subset): 26 | dataset = dataset.dataset 27 | if isinstance(dataset, CocoDetection): 28 | return dataset.coco 29 | 30 | 31 | def build_dataset(image_set, args): 32 | if args.dataset_file == 'coco': 33 | return build_coco(image_set, args) 34 | if args.dataset_file == 'coco_panoptic': 35 | # to avoid making panopticapi required for coco 36 | from .coco_panoptic import build as build_coco_panoptic 37 | return build_coco_panoptic(image_set, args) 38 | if args.dataset_file == 'e2e_joint': 39 | return build_e2e_joint(image_set, args) 40 | if args.dataset_file == 'e2e_static_mot': 41 | return build_e2e_static_mot(image_set, args) 42 | if args.dataset_file == 'e2e_mot': 43 | return build_e2e_mot(image_set, args) 44 | if args.dataset_file == 'e2e_dance': 45 | return build_e2e_dance(image_set, args) 46 | raise ValueError(f'dataset {args.dataset_file} not supported') 47 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/coco.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/coco.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/coco_eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/coco_eval.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_prefetcher.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/data_prefetcher.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/detmot.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/detmot.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/joint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/joint.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/panoptic_eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/panoptic_eval.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/samplers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/samplers.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/static_detmot.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/static_detmot.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | COCO dataset which returns image_id for evaluation. 14 | 15 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 16 | """ 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.utils.data 21 | from pycocotools import mask as coco_mask 22 | 23 | from .torchvision_datasets import CocoDetection as TvCocoDetection 24 | from util.misc import get_local_rank, get_local_size 25 | import datasets.transforms as T 26 | 27 | 28 | class CocoDetection(TvCocoDetection): 29 | def __init__(self, img_folder, ann_file, transforms, return_masks, cache_mode=False, local_rank=0, local_size=1): 30 | super(CocoDetection, self).__init__(img_folder, ann_file, 31 | cache_mode=cache_mode, local_rank=local_rank, local_size=local_size) 32 | self._transforms = transforms 33 | self.prepare = ConvertCocoPolysToMask(return_masks) 34 | 35 | def __getitem__(self, idx): 36 | img, target = super(CocoDetection, self).__getitem__(idx) 37 | image_id = self.ids[idx] 38 | target = {'image_id': image_id, 'annotations': target} 39 | img, target = self.prepare(img, target) 40 | if self._transforms is not None: 41 | img, target = self._transforms(img, target) 42 | return img, target 43 | 44 | 45 | def convert_coco_poly_to_mask(segmentations, height, width): 46 | masks = [] 47 | for polygons in segmentations: 48 | rles = coco_mask.frPyObjects(polygons, height, width) 49 | mask = coco_mask.decode(rles) 50 | if len(mask.shape) < 3: 51 | mask = mask[..., None] 52 | mask = torch.as_tensor(mask, dtype=torch.uint8) 53 | mask = mask.any(dim=2) 54 | masks.append(mask) 55 | if masks: 56 | masks = torch.stack(masks, dim=0) 57 | else: 58 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 59 | return masks 60 | 61 | 62 | class ConvertCocoPolysToMask(object): 63 | def __init__(self, return_masks=False): 64 | self.return_masks = return_masks 65 | 66 | def __call__(self, image, target): 67 | w, h = image.size 68 | 69 | image_id = target["image_id"] 70 | image_id = torch.tensor([image_id]) 71 | 72 | anno = target["annotations"] 73 | 74 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 75 | 76 | boxes = [obj["bbox"] for obj in anno] 77 | # guard against no boxes via resizing 78 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 79 | boxes[:, 2:] += boxes[:, :2] 80 | boxes[:, 0::2].clamp_(min=0, max=w) 81 | boxes[:, 1::2].clamp_(min=0, max=h) 82 | 83 | classes = [obj["category_id"] for obj in anno] 84 | classes = torch.tensor(classes, dtype=torch.int64) 85 | 86 | if self.return_masks: 87 | segmentations = [obj["segmentation"] for obj in anno] 88 | masks = convert_coco_poly_to_mask(segmentations, h, w) 89 | 90 | keypoints = None 91 | if anno and "keypoints" in anno[0]: 92 | keypoints = [obj["keypoints"] for obj in anno] 93 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 94 | num_keypoints = keypoints.shape[0] 95 | if num_keypoints: 96 | keypoints = keypoints.view(num_keypoints, -1, 3) 97 | 98 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 99 | boxes = boxes[keep] 100 | classes = classes[keep] 101 | if self.return_masks: 102 | masks = masks[keep] 103 | if keypoints is not None: 104 | keypoints = keypoints[keep] 105 | 106 | target = {} 107 | target["boxes"] = boxes 108 | target["labels"] = classes 109 | if self.return_masks: 110 | target["masks"] = masks 111 | target["image_id"] = image_id 112 | if keypoints is not None: 113 | target["keypoints"] = keypoints 114 | 115 | # for conversion to coco api 116 | area = torch.tensor([obj["area"] for obj in anno]) 117 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 118 | target["area"] = area[keep] 119 | target["iscrowd"] = iscrowd[keep] 120 | 121 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 122 | target["size"] = torch.as_tensor([int(h), int(w)]) 123 | 124 | return image, target 125 | 126 | 127 | def make_coco_transforms(image_set): 128 | 129 | normalize = T.Compose([ 130 | T.ToTensor(), 131 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 132 | ]) 133 | 134 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 135 | 136 | if image_set == 'train': 137 | return T.Compose([ 138 | T.RandomHorizontalFlip(), 139 | T.RandomSelect( 140 | T.RandomResize(scales, max_size=1333), 141 | T.Compose([ 142 | T.RandomResize([400, 500, 600]), 143 | T.RandomSizeCrop(384, 600), 144 | T.RandomResize(scales, max_size=1333), 145 | ]) 146 | ), 147 | normalize, 148 | ]) 149 | 150 | if image_set == 'val': 151 | return T.Compose([ 152 | T.RandomResize([800], max_size=1333), 153 | normalize, 154 | ]) 155 | 156 | raise ValueError(f'unknown {image_set}') 157 | 158 | 159 | def build(image_set, args): 160 | root = Path(args.coco_path) 161 | assert root.exists(), f'provided COCO path {root} does not exist' 162 | mode = 'instances' 163 | PATHS = { 164 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), 165 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), 166 | } 167 | 168 | img_folder, ann_file = PATHS[image_set] 169 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks, 170 | cache_mode=args.cache_mode, local_rank=get_local_rank(), local_size=get_local_size()) 171 | return dataset 172 | -------------------------------------------------------------------------------- /datasets/coco_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | COCO evaluator that works in distributed mode. 14 | 15 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 16 | The difference is that there is less copy-pasting from pycocotools 17 | in the end of the file, as python3 can suppress prints with contextlib 18 | """ 19 | import os 20 | import contextlib 21 | import copy 22 | import numpy as np 23 | import torch 24 | 25 | from pycocotools.cocoeval import COCOeval 26 | from pycocotools.coco import COCO 27 | import pycocotools.mask as mask_util 28 | 29 | from util.misc import all_gather 30 | 31 | 32 | class CocoEvaluator(object): 33 | def __init__(self, coco_gt, iou_types): 34 | assert isinstance(iou_types, (list, tuple)) 35 | coco_gt = copy.deepcopy(coco_gt) 36 | self.coco_gt = coco_gt 37 | 38 | self.iou_types = iou_types 39 | self.coco_eval = {} 40 | for iou_type in iou_types: 41 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 42 | 43 | self.img_ids = [] 44 | self.eval_imgs = {k: [] for k in iou_types} 45 | 46 | def update(self, predictions): 47 | img_ids = list(np.unique(list(predictions.keys()))) 48 | self.img_ids.extend(img_ids) 49 | 50 | for iou_type in self.iou_types: 51 | results = self.prepare(predictions, iou_type) 52 | 53 | # suppress pycocotools prints 54 | with open(os.devnull, 'w') as devnull: 55 | print("self.coco_gt={}".format(self.coco_gt)) 56 | with contextlib.redirect_stdout(devnull): 57 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 58 | coco_eval = self.coco_eval[iou_type] 59 | 60 | coco_eval.cocoDt = coco_dt 61 | coco_eval.params.imgIds = list(img_ids) 62 | img_ids, eval_imgs = evaluate(coco_eval) 63 | 64 | self.eval_imgs[iou_type].append(eval_imgs) 65 | 66 | def synchronize_between_processes(self): 67 | for iou_type in self.iou_types: 68 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 69 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 70 | 71 | def accumulate(self): 72 | for coco_eval in self.coco_eval.values(): 73 | coco_eval.accumulate() 74 | 75 | def summarize(self): 76 | for iou_type, coco_eval in self.coco_eval.items(): 77 | print("IoU metric: {}".format(iou_type)) 78 | coco_eval.summarize() 79 | 80 | def prepare(self, predictions, iou_type): 81 | if iou_type == "bbox": 82 | return self.prepare_for_coco_detection(predictions) 83 | elif iou_type == "segm": 84 | return self.prepare_for_coco_segmentation(predictions) 85 | elif iou_type == "keypoints": 86 | return self.prepare_for_coco_keypoint(predictions) 87 | else: 88 | raise ValueError("Unknown iou type {}".format(iou_type)) 89 | 90 | def prepare_for_coco_detection(self, predictions): 91 | coco_results = [] 92 | for original_id, prediction in predictions.items(): 93 | if len(prediction) == 0: 94 | continue 95 | 96 | boxes = prediction["boxes"] 97 | boxes = convert_to_xywh(boxes).tolist() 98 | scores = prediction["scores"].tolist() 99 | labels = prediction["labels"].tolist() 100 | 101 | coco_results.extend( 102 | [ 103 | { 104 | "image_id": original_id, 105 | "category_id": labels[k], 106 | "bbox": box, 107 | "score": scores[k], 108 | } 109 | for k, box in enumerate(boxes) 110 | ] 111 | ) 112 | return coco_results 113 | 114 | def prepare_for_coco_segmentation(self, predictions): 115 | coco_results = [] 116 | for original_id, prediction in predictions.items(): 117 | if len(prediction) == 0: 118 | continue 119 | 120 | scores = prediction["scores"] 121 | labels = prediction["labels"] 122 | masks = prediction["masks"] 123 | 124 | masks = masks > 0.5 125 | 126 | scores = prediction["scores"].tolist() 127 | labels = prediction["labels"].tolist() 128 | 129 | rles = [ 130 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 131 | for mask in masks 132 | ] 133 | for rle in rles: 134 | rle["counts"] = rle["counts"].decode("utf-8") 135 | 136 | coco_results.extend( 137 | [ 138 | { 139 | "image_id": original_id, 140 | "category_id": labels[k], 141 | "segmentation": rle, 142 | "score": scores[k], 143 | } 144 | for k, rle in enumerate(rles) 145 | ] 146 | ) 147 | return coco_results 148 | 149 | def prepare_for_coco_keypoint(self, predictions): 150 | coco_results = [] 151 | for original_id, prediction in predictions.items(): 152 | if len(prediction) == 0: 153 | continue 154 | 155 | boxes = prediction["boxes"] 156 | boxes = convert_to_xywh(boxes).tolist() 157 | scores = prediction["scores"].tolist() 158 | labels = prediction["labels"].tolist() 159 | keypoints = prediction["keypoints"] 160 | keypoints = keypoints.flatten(start_dim=1).tolist() 161 | 162 | coco_results.extend( 163 | [ 164 | { 165 | "image_id": original_id, 166 | "category_id": labels[k], 167 | 'keypoints': keypoint, 168 | "score": scores[k], 169 | } 170 | for k, keypoint in enumerate(keypoints) 171 | ] 172 | ) 173 | return coco_results 174 | 175 | 176 | def convert_to_xywh(boxes): 177 | xmin, ymin, xmax, ymax = boxes.unbind(1) 178 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 179 | 180 | 181 | def merge(img_ids, eval_imgs): 182 | all_img_ids = all_gather(img_ids) 183 | all_eval_imgs = all_gather(eval_imgs) 184 | 185 | merged_img_ids = [] 186 | for p in all_img_ids: 187 | merged_img_ids.extend(p) 188 | 189 | merged_eval_imgs = [] 190 | for p in all_eval_imgs: 191 | merged_eval_imgs.append(p) 192 | 193 | merged_img_ids = np.array(merged_img_ids) 194 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 195 | 196 | # keep only unique (and in sorted order) images 197 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 198 | merged_eval_imgs = merged_eval_imgs[..., idx] 199 | 200 | return merged_img_ids, merged_eval_imgs 201 | 202 | 203 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 204 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 205 | img_ids = list(img_ids) 206 | eval_imgs = list(eval_imgs.flatten()) 207 | 208 | coco_eval.evalImgs = eval_imgs 209 | coco_eval.params.imgIds = img_ids 210 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 211 | 212 | 213 | ################################################################# 214 | # From pycocotools, just removed the prints and fixed 215 | # a Python3 bug about unicode not defined 216 | ################################################################# 217 | 218 | 219 | def evaluate(self): 220 | ''' 221 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 222 | :return: None 223 | ''' 224 | # tic = time.time() 225 | # print('Running per image evaluation...') 226 | p = self.params 227 | # add backward compatibility if useSegm is specified in params 228 | if p.useSegm is not None: 229 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 230 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) 231 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 232 | p.imgIds = list(np.unique(p.imgIds)) 233 | if p.useCats: 234 | p.catIds = list(np.unique(p.catIds)) 235 | p.maxDets = sorted(p.maxDets) 236 | self.params = p 237 | 238 | self._prepare() 239 | # loop through images, area range, max detection number 240 | catIds = p.catIds if p.useCats else [-1] 241 | 242 | if p.iouType == 'segm' or p.iouType == 'bbox': 243 | computeIoU = self.computeIoU 244 | elif p.iouType == 'keypoints': 245 | computeIoU = self.computeOks 246 | self.ious = { 247 | (imgId, catId): computeIoU(imgId, catId) 248 | for imgId in p.imgIds 249 | for catId in catIds} 250 | 251 | evaluateImg = self.evaluateImg 252 | maxDet = p.maxDets[-1] 253 | evalImgs = [ 254 | evaluateImg(imgId, catId, areaRng, maxDet) 255 | for catId in catIds 256 | for areaRng in p.areaRng 257 | for imgId in p.imgIds 258 | ] 259 | # this is NOT in the pycocotools code, but could be done outside 260 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 261 | self._paramsEval = copy.deepcopy(self.params) 262 | # toc = time.time() 263 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 264 | return p.imgIds, evalImgs 265 | 266 | ################################################################# 267 | # end of straight copy from pycocotools, just removing the prints 268 | ################################################################# 269 | -------------------------------------------------------------------------------- /datasets/coco_panoptic.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | import json 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image 18 | 19 | from panopticapi.utils import rgb2id 20 | from util.box_ops import masks_to_boxes 21 | 22 | from .coco import make_coco_transforms 23 | 24 | 25 | class CocoPanoptic: 26 | def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): 27 | with open(ann_file, 'r') as f: 28 | self.coco = json.load(f) 29 | 30 | # sort 'images' field so that they are aligned with 'annotations' 31 | # i.e., in alphabetical order 32 | self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) 33 | # sanity check 34 | if "annotations" in self.coco: 35 | for img, ann in zip(self.coco['images'], self.coco['annotations']): 36 | assert img['file_name'][:-4] == ann['file_name'][:-4] 37 | 38 | self.img_folder = img_folder 39 | self.ann_folder = ann_folder 40 | self.ann_file = ann_file 41 | self.transforms = transforms 42 | self.return_masks = return_masks 43 | 44 | def __getitem__(self, idx): 45 | ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] 46 | img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') 47 | ann_path = Path(self.ann_folder) / ann_info['file_name'] 48 | 49 | img = Image.open(img_path).convert('RGB') 50 | w, h = img.size 51 | if "segments_info" in ann_info: 52 | masks = np.asarray(Image.open(ann_path), dtype=np.uint32) 53 | masks = rgb2id(masks) 54 | 55 | ids = np.array([ann['id'] for ann in ann_info['segments_info']]) 56 | masks = masks == ids[:, None, None] 57 | 58 | masks = torch.as_tensor(masks, dtype=torch.uint8) 59 | labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) 60 | 61 | target = {} 62 | target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) 63 | if self.return_masks: 64 | target['masks'] = masks 65 | target['labels'] = labels 66 | 67 | target["boxes"] = masks_to_boxes(masks) 68 | 69 | target['size'] = torch.as_tensor([int(h), int(w)]) 70 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 71 | if "segments_info" in ann_info: 72 | for name in ['iscrowd', 'area']: 73 | target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) 74 | 75 | if self.transforms is not None: 76 | img, target = self.transforms(img, target) 77 | 78 | return img, target 79 | 80 | def __len__(self): 81 | return len(self.coco['images']) 82 | 83 | def get_height_and_width(self, idx): 84 | img_info = self.coco['images'][idx] 85 | height = img_info['height'] 86 | width = img_info['width'] 87 | return height, width 88 | 89 | 90 | def build(image_set, args): 91 | img_folder_root = Path(args.coco_path) 92 | ann_folder_root = Path(args.coco_panoptic_path) 93 | assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' 94 | assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' 95 | mode = 'panoptic' 96 | PATHS = { 97 | "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), 98 | "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), 99 | } 100 | 101 | img_folder, ann_file = PATHS[image_set] 102 | img_folder_path = img_folder_root / img_folder 103 | ann_folder = ann_folder_root / f'{mode}_{img_folder}' 104 | ann_file = ann_folder_root / ann_file 105 | 106 | dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file, 107 | transforms=make_coco_transforms(image_set), return_masks=args.masks) 108 | 109 | return dataset 110 | -------------------------------------------------------------------------------- /datasets/data_path/gen_bdd100k_mot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import cv2 5 | from tqdm import tqdm 6 | from collections import defaultdict 7 | 8 | 9 | def convert(img_dir, split, label_dir, save_label_dir, filter_crowd=False, filter_ignore=False): 10 | cat2id = {'train':6, 'car':3, 'bus':5, 'other person': 1, 'rider':2, 'pedestrian':1, 'other vehicle':3, 'motorcycle':7, 'bicycle':8, 'trailer':4, 'truck':4} 11 | 12 | coco = defaultdict(list) 13 | coco["categories"] = [ 14 | {"supercategory": "human", "id": 1, "name": "pedestrian"}, 15 | {"supercategory": "human", "id": 2, "name": "rider"}, 16 | {"supercategory": "vehicle", "id": 3, "name": "car"}, 17 | {"supercategory": "vehicle", "id": 4, "name": "truck"}, 18 | {"supercategory": "vehicle", "id": 5, "name": "bus"}, 19 | {"supercategory": "vehicle", "id": 6, "name": "train"}, 20 | {"supercategory": "bike", "id": 7, "name": "motorcycle"}, 21 | {"supercategory": "bike", "id": 8, "name": "bicycle"}, 22 | ] 23 | attr_id_dict = { 24 | frame["name"]: frame["id"] for frame in coco["categories"] 25 | } 26 | 27 | all_categories = set() 28 | img_dir = os.path.join(img_dir, split) 29 | label_dir = os.path.join(label_dir, split) 30 | vids = os.listdir(img_dir) 31 | for vid in tqdm(vids): 32 | txt_label_dir = os.path.join(save_label_dir, split, vid) 33 | os.makedirs(txt_label_dir, exist_ok=True) 34 | annos = json.load(open(os.path.join(label_dir, vid+'.json'), 'r')) 35 | for anno in annos: 36 | name = anno['name'] 37 | labels = anno['labels'] 38 | videoName = anno['videoName'] 39 | frameIndex = anno['frameIndex'] 40 | img = cv2.imread(os.path.join(img_dir, vid, name)) 41 | seq_height, seq_width, _ = img.shape 42 | if len(labels) < 1: 43 | continue 44 | # for label in labels: 45 | # category = label['category'] 46 | # all_categories.add(category) 47 | with open(os.path.join(txt_label_dir, name.replace('jpg', 'txt')), 'w') as f: 48 | for label in labels: 49 | obj_id = label['id'] 50 | category = label['category'] 51 | attributes = label['attributes'] 52 | is_crowd = attributes['crowd'] 53 | 54 | if filter_crowd and is_crowd: 55 | continue 56 | if filter_ignore and (category not in attr_id_dict.keys()): 57 | continue 58 | 59 | box2d = label['box2d'] 60 | x1 = box2d['x1'] 61 | x2 = box2d['x2'] 62 | y1 = box2d['y1'] 63 | y2 = box2d['y2'] 64 | w = x2-x1 65 | h = y2-y1 66 | cx = (x1+x2) / 2 67 | cy = (y1+y2) / 2 68 | label_str = '{:d} {:d} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format( 69 | cat2id[category], int(obj_id), cx / seq_width, cy / seq_height, w / seq_width, h / seq_height) 70 | f.write(label_str) 71 | # print(f'all categories are {all_categories}.') 72 | 73 | def generate_txt(img_dir,label_dir,txt_path='bdd100k.train',split='train'): 74 | img_dir = os.path.join(img_dir, split) 75 | label_dir = os.path.join(label_dir, split) 76 | all_vids = os.listdir(img_dir) 77 | all_frames = [] 78 | for vid in tqdm(all_vids): 79 | fids = os.listdir(os.path.join(img_dir, vid)) 80 | fids.sort() 81 | for fid in fids: 82 | if os.path.exists(os.path.join(label_dir, vid, fid.replace('jpg', 'txt'))): 83 | all_frames.append(f'images/track/{split}/{vid}/{fid}') 84 | with open(txt_path, 'w') as f: 85 | for frame in all_frames: 86 | f.write(frame+'\n') 87 | 88 | '''no filter''' 89 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 90 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 91 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/labels/track' 92 | # split = 'train' 93 | # convert(img_dir, split, label_dir, save_label_dir) 94 | 95 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 96 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 97 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/labels/track' 98 | # split = 'val' 99 | # convert(img_dir, split, label_dir, save_label_dir) 100 | 101 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 102 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 103 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/labels/track' 104 | # split = 'train' 105 | # generate_txt(img_dir,save_label_dir,txt_path='bdd100k.train',split='train') 106 | 107 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 108 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 109 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/labels/track' 110 | # split = 'val' 111 | # generate_txt(img_dir,save_label_dir,txt_path='bdd100k.val',split='val') 112 | 113 | 114 | '''for filter''' 115 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 116 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 117 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/filter_labels/track' 118 | # split = 'train' 119 | # convert(img_dir, split, label_dir, save_label_dir, filter_crowd=True, filter_ignore=True) 120 | 121 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 122 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 123 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/filter_labels/track' 124 | # split = 'val' 125 | # convert(img_dir, split, label_dir, save_label_dir, filter_crowd=True, filter_ignore=True) 126 | 127 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 128 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 129 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/filter_labels/track' 130 | # split = 'train' 131 | # generate_txt(img_dir,save_label_dir,txt_path='filter.bdd100k.train',split='train') 132 | 133 | # img_dir = '/data/Dataset/bdd100k/bdd100k/images/track' 134 | # label_dir = '/data/Dataset/bdd100k/bdd100k/labels/box_track_20' 135 | # save_label_dir = '/data/Dataset/bdd100k/bdd100k/filter_labels/track' 136 | # split = 'val' 137 | # generate_txt(img_dir,save_label_dir,txt_path='filter.bdd100k.val',split='val') 138 | 139 | -------------------------------------------------------------------------------- /datasets/data_path/gen_labels_15.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | import numpy as np 4 | import cv2 5 | from tqdm import tqdm 6 | 7 | def mkdirs(d): 8 | if not osp.exists(d): 9 | os.makedirs(d) 10 | 11 | seq_root = '/data/workspace/datasets/mot/MOT15/images/train' 12 | label_root = '/data/workspace/datasets/mot/MOT15/labels_with_ids/train' 13 | mkdirs(label_root) 14 | seqs = ['ADL-Rundle-6', 'ETH-Bahnhof', 'KITTI-13', 'PETS09-S2L1', 'TUD-Stadtmitte', 'ADL-Rundle-8', 'KITTI-17', 15 | 'ETH-Pedcross2', 'ETH-Sunnyday', 'TUD-Campus', 'Venice-2'] 16 | 17 | tid_curr = 0 18 | tid_last = -1 19 | for seq in tqdm(seqs): 20 | 21 | # seq_info = open(osp.join(seq_root, seq, 'seqinfo.ini')).read() 22 | # seq_width = int(seq_info[seq_info.find('imWidth=') + 8:seq_info.find('\nimHeight')]) 23 | # seq_height = int(seq_info[seq_info.find('imHeight=') + 9:seq_info.find('\nimExt')]) 24 | 25 | all_imgs = os.listdir(osp.join(seq_root, seq, 'img1')) 26 | fm = cv2.imread(osp.join(seq_root, seq, 'img1', all_imgs[0])) 27 | seq_height, seq_width, c = fm.shape 28 | 29 | gt_txt = osp.join(seq_root, seq, 'gt', 'gt.txt') 30 | gt = np.loadtxt(gt_txt, dtype=np.float64, delimiter=',') 31 | idx = np.lexsort(gt.T[:2, :]) 32 | gt = gt[idx, :] 33 | 34 | seq_label_root = osp.join(label_root, seq, 'img1') 35 | mkdirs(seq_label_root) 36 | 37 | for fid, tid, x, y, w, h, mark, _, _, _ in gt: 38 | if mark == 0: 39 | continue 40 | fid = int(fid) 41 | tid = int(tid) 42 | if not tid == tid_last: 43 | tid_curr += 1 44 | tid_last = tid 45 | x += w / 2 46 | y += h / 2 47 | label_fpath = osp.join(seq_label_root, '{:06d}.txt'.format(fid)) 48 | label_str = '0 {:d} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format( 49 | tid_curr, x / seq_width, y / seq_height, w / seq_width, h / seq_height) 50 | with open(label_fpath, 'a') as f: 51 | f.write(label_str) -------------------------------------------------------------------------------- /datasets/data_path/gen_labels_16.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | import numpy as np 4 | def mkdirs(d): 5 | if not osp.exists(d): 6 | os.makedirs(d) 7 | 8 | seq_root = '/data/workspace/datasets/mot/MOT16/images/train' 9 | label_root = '/data/workspace/datasets/mot/MOT16/labels_with_ids/train' 10 | mkdirs(label_root) 11 | seqs = [s for s in os.listdir(seq_root)] 12 | 13 | tid_curr = 0 14 | tid_last = -1 15 | for seq in seqs: 16 | seq_info = open(osp.join(seq_root, seq, 'seqinfo.ini')).read() 17 | seq_width = int(seq_info[seq_info.find('imWidth=') + 8:seq_info.find('\nimHeight')]) 18 | seq_height = int(seq_info[seq_info.find('imHeight=') + 9:seq_info.find('\nimExt')]) 19 | 20 | gt_txt = osp.join(seq_root, seq, 'gt', 'gt.txt') 21 | gt = np.loadtxt(gt_txt, dtype=np.float64, delimiter=',') 22 | idx = np.lexsort(gt.T[:2, :]) 23 | gt = gt[idx, :] 24 | 25 | seq_label_root = osp.join(label_root, seq, 'img1') 26 | mkdirs(seq_label_root) 27 | 28 | for fid, tid, x, y, w, h, mark, _, _ in gt: 29 | if mark == 0: 30 | continue 31 | fid = int(fid) 32 | tid = int(tid) 33 | if not tid == tid_last: 34 | tid_curr += 1 35 | tid_last = tid 36 | x += w / 2 37 | y += h / 2 38 | label_fpath = osp.join(seq_label_root, '{:06d}.txt'.format(fid)) 39 | label_str = '0 {:d} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format( 40 | tid_curr, x / seq_width, y / seq_height, w / seq_width, h / seq_height) 41 | with open(label_fpath, 'a') as f: 42 | f.write(label_str) -------------------------------------------------------------------------------- /datasets/data_path/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from typing import List 4 | 5 | 6 | def solve_MOT_train(root, year): 7 | assert year in [15, 16, 17] 8 | dataset_path = 'MOT{}/images/train'.format(year) 9 | data_root = os.path.join(root, dataset_path) 10 | if year == 17: 11 | video_paths = [] 12 | for video_name in os.listdir(data_root): 13 | if 'SDP' in video_name: 14 | video_paths.append(video_name) 15 | else: 16 | video_paths = os.listdir(data_root) 17 | 18 | frames = [] 19 | for video_name in video_paths: 20 | files = os.listdir(os.path.join(data_root, video_name, 'img1')) 21 | files.sort() 22 | for i in range(1, len(files) + 1): 23 | frames.append(os.path.join(dataset_path, video_name, 'img1', '%06d.jpg' % i)) 24 | return frames 25 | 26 | 27 | def solve_CUHK(root): 28 | dataset_path = 'ethz/CUHK-SYSU' 29 | data_root = os.path.join(root, dataset_path) 30 | file_names = os.listdir(os.path.join(data_root, 'images')) 31 | file_names.sort() 32 | 33 | frames = [] 34 | for i in range(len(file_names)): 35 | if os.path.exists(os.path.join(root, 'ethz/CUHK-SYSU/labels_with_ids', f's{i + 1}.txt')): 36 | if os.path.exists(os.path.join(root, 'ethz/CUHK-SYSU/images', f's{i + 1}.jpg')): 37 | frames.append(os.path.join('ethz/CUHK-SYSU/images', f's{i + 1}.jpg')) 38 | return frames 39 | 40 | def solve_ETHZ(root): 41 | dataset_path = 'ethz/ETHZ' 42 | data_root = os.path.join(root, dataset_path) 43 | video_paths = [] 44 | for name in os.listdir(data_root): 45 | if name not in ['eth01', 'eth03']: 46 | video_paths.append(name) 47 | 48 | frames = [] 49 | for video_path in video_paths: 50 | files = os.listdir(os.path.join(data_root, video_path, 'images')) 51 | files.sort() 52 | for img_name in files: 53 | if os.path.exists(os.path.join(data_root, video_path, 'labels_with_ids', img_name.replace('.png', '.txt'))): 54 | if os.path.exists(os.path.join(data_root, video_path, 'images', img_name)): 55 | frames.append(os.path.join('ethz/ETHZ', video_path, 'images', img_name)) 56 | return frames 57 | 58 | 59 | def solve_PRW(root): 60 | dataset_path = 'ethz/PRW' 61 | data_root = os.path.join(root, dataset_path) 62 | frame_paths = os.listdir(os.path.join(data_root, 'images')) 63 | frame_paths.sort() 64 | frames = [] 65 | for i in range(len(frame_paths)): 66 | if os.path.exists(os.path.join(data_root, 'labels_with_ids', frame_paths[i].split('.')[0] + '.txt')): 67 | if os.path.exists(os.path.join(data_root, 'images', frame_paths[i])): 68 | frames.append(os.path.join(dataset_path, 'images', frame_paths[i])) 69 | return frames 70 | 71 | 72 | dataset_catalog = { 73 | 'MOT15': partial(solve_MOT_train, year=15), 74 | 'MOT16': partial(solve_MOT_train, year=16), 75 | 'MOT17': partial(solve_MOT_train, year=17), 76 | 'CUHK-SYSU': solve_CUHK, 77 | 'ETHZ': solve_ETHZ, 78 | 'PRW': solve_PRW, 79 | } 80 | 81 | 82 | def solve(dataset_list: List[str], root, save_path): 83 | all_frames = [] 84 | for dataset_name in dataset_list: 85 | dataset_frames = dataset_catalog[dataset_name](root) 86 | print("solve {} frames from dataset:{} ".format(len(dataset_frames), dataset_name)) 87 | all_frames.extend(dataset_frames) 88 | print("totally {} frames are solved.".format(len(all_frames))) 89 | with open(save_path, 'w') as f: 90 | for u in all_frames: 91 | line = '{}'.format(u) + '\n' 92 | f.writelines(line) 93 | 94 | root = '/data/workspace/datasets/mot' 95 | save_path = '/data/workspace/detr-mot/datasets/data_path/mot17.train' # for fangao 96 | dataset_list = ['MOT17', ] 97 | 98 | solve(dataset_list, root, save_path) 99 | -------------------------------------------------------------------------------- /datasets/data_prefetcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | import torch 13 | from functools import partial 14 | from models.structures import Instances 15 | 16 | def to_cuda(samples, targets, device): 17 | samples = samples.to(device, non_blocking=True) 18 | targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets] 19 | return samples, targets 20 | 21 | 22 | def tensor_to_cuda(tensor: torch.Tensor, device): 23 | return tensor.to(device) 24 | 25 | 26 | def is_tensor_or_instances(data): 27 | return isinstance(data, torch.Tensor) or isinstance(data, Instances) 28 | 29 | 30 | def data_apply(data, check_func, apply_func): 31 | if isinstance(data, dict): 32 | for k in data.keys(): 33 | if check_func(data[k]): 34 | data[k] = apply_func(data[k]) 35 | elif isinstance(data[k], dict) or isinstance(data[k], list): 36 | data_apply(data[k], check_func, apply_func) 37 | else: 38 | raise ValueError() 39 | elif isinstance(data, list): 40 | for i in range(len(data)): 41 | if check_func(data[i]): 42 | data[i] = apply_func(data[i]) 43 | elif isinstance(data[i], dict) or isinstance(data[i], list): 44 | data_apply(data[i], check_func, apply_func) 45 | else: 46 | raise ValueError("invalid type {}".format(type(data[i]))) 47 | else: 48 | raise ValueError("invalid type {}".format(type(data))) 49 | return data 50 | 51 | 52 | def data_dict_to_cuda(data_dict, device): 53 | return data_apply(data_dict, is_tensor_or_instances, partial(tensor_to_cuda, device=device)) 54 | 55 | 56 | class data_prefetcher(): 57 | def __init__(self, loader, device, prefetch=True): 58 | self.loader = iter(loader) 59 | self.prefetch = prefetch 60 | self.device = device 61 | if prefetch: 62 | self.stream = torch.cuda.Stream() 63 | self.preload() 64 | 65 | def preload(self): 66 | try: 67 | self.next_samples, self.next_targets = next(self.loader) 68 | except StopIteration: 69 | self.next_samples = None 70 | self.next_targets = None 71 | return 72 | # if record_stream() doesn't work, another option is to make sure device inputs are created 73 | # on the main stream. 74 | # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') 75 | # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') 76 | # Need to make sure the memory allocated for next_* is not still in use by the main stream 77 | # at the time we start copying to next_*: 78 | # self.stream.wait_stream(torch.cuda.current_stream()) 79 | with torch.cuda.stream(self.stream): 80 | self.next_samples, self.next_targets = to_cuda(self.next_samples, self.next_targets, self.device) 81 | # more code for the alternative if record_stream() doesn't work: 82 | # copy_ will record the use of the pinned source tensor in this side stream. 83 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 84 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 85 | # self.next_input = self.next_input_gpu 86 | # self.next_target = self.next_target_gpu 87 | 88 | # With Amp, it isn't necessary to manually convert data to half. 89 | # if args.fp16: 90 | # self.next_input = self.next_input.half() 91 | # else: 92 | 93 | def next(self): 94 | if self.prefetch: 95 | torch.cuda.current_stream().wait_stream(self.stream) 96 | samples = self.next_samples 97 | targets = self.next_targets 98 | if samples is not None: 99 | samples.record_stream(torch.cuda.current_stream()) 100 | if targets is not None: 101 | for t in targets: 102 | for k, v in t.items(): 103 | v.record_stream(torch.cuda.current_stream()) 104 | self.preload() 105 | else: 106 | try: 107 | samples, targets = next(self.loader) 108 | samples, targets = to_cuda(samples, targets, self.device) 109 | except StopIteration: 110 | print("catch_stop_iter") 111 | samples = None 112 | targets = None 113 | 114 | return samples, targets 115 | -------------------------------------------------------------------------------- /datasets/detmot.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | """ 12 | MOT dataset which returns image_id for evaluation. 13 | """ 14 | from pathlib import Path 15 | import cv2 16 | import numpy as np 17 | import torch 18 | import torch.utils.data 19 | import os.path as osp 20 | from PIL import Image, ImageDraw 21 | import copy 22 | import datasets.transforms as T 23 | from models.structures import Instances 24 | 25 | 26 | class DetMOTDetection: 27 | def __init__(self, args, data_txt_path: str, seqs_folder, transforms): 28 | self.args = args 29 | self._transforms = transforms 30 | self.num_frames_per_batch = max(args.sampler_lengths) 31 | self.sample_mode = args.sample_mode 32 | self.sample_interval = args.sample_interval 33 | self.vis = args.vis 34 | self.video_dict = {} 35 | 36 | with open(data_txt_path, 'r') as file: 37 | self.img_files = file.readlines() 38 | self.img_files = [osp.join(seqs_folder, x.split(',')[0].strip()) for x in self.img_files] 39 | self.img_files = list(filter(lambda x: len(x) > 0, self.img_files)) 40 | self.label_files = [(x.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt')) 41 | for x in self.img_files] 42 | # The number of images per sample: 1 + (num_frames - 1) * interval. 43 | # The number of valid samples: num_images - num_image_per_sample + 1. 44 | self.item_num = len(self.img_files) - (self.num_frames_per_batch - 1) * self.sample_interval 45 | 46 | self._register_videos() 47 | 48 | # video sampler. 49 | self.sampler_steps: list = args.sampler_steps 50 | self.lengths: list = args.sampler_lengths 51 | print("sampler_steps={} lenghts={}".format(self.sampler_steps, self.lengths)) 52 | if self.sampler_steps is not None and len(self.sampler_steps) > 0: 53 | # Enable sampling length adjustment. 54 | assert len(self.lengths) > 0 55 | assert len(self.lengths) == len(self.sampler_steps) + 1 56 | for i in range(len(self.sampler_steps) - 1): 57 | assert self.sampler_steps[i] < self.sampler_steps[i + 1] 58 | self.item_num = len(self.img_files) - (self.lengths[-1] - 1) * self.sample_interval 59 | self.period_idx = 0 60 | self.num_frames_per_batch = self.lengths[0] 61 | self.current_epoch = 0 62 | 63 | def _register_videos(self): 64 | for label_name in self.label_files: 65 | video_name = '/'.join(label_name.split('/')[:-1]) 66 | if video_name not in self.video_dict: 67 | print("register {}-th video: {} ".format(len(self.video_dict) + 1, video_name)) 68 | self.video_dict[video_name] = len(self.video_dict) 69 | assert len(self.video_dict) <= 300 70 | 71 | def set_epoch(self, epoch): 72 | self.current_epoch = epoch 73 | if self.sampler_steps is None or len(self.sampler_steps) == 0: 74 | # fixed sampling length. 75 | return 76 | 77 | for i in range(len(self.sampler_steps)): 78 | if epoch >= self.sampler_steps[i]: 79 | self.period_idx = i + 1 80 | print("set epoch: epoch {} period_idx={}".format(epoch, self.period_idx)) 81 | self.num_frames_per_batch = self.lengths[self.period_idx] 82 | 83 | def step_epoch(self): 84 | # one epoch finishes. 85 | print("Dataset: epoch {} finishes".format(self.current_epoch)) 86 | self.set_epoch(self.current_epoch + 1) 87 | 88 | @staticmethod 89 | def _targets_to_instances(targets: dict, img_shape) -> Instances: 90 | gt_instances = Instances(tuple(img_shape)) 91 | gt_instances.boxes = targets['boxes'] 92 | gt_instances.labels = targets['labels'] 93 | gt_instances.obj_ids = targets['obj_ids'] 94 | gt_instances.area = targets['area'] 95 | return gt_instances 96 | 97 | def _pre_single_frame(self, idx: int): 98 | img_path = self.img_files[idx] 99 | label_path = self.label_files[idx] 100 | img = Image.open(img_path) 101 | targets = {} 102 | w, h = img._size 103 | assert w > 0 and h > 0, "invalid image {} with shape {} {}".format(img_path, w, h) 104 | if osp.isfile(label_path): 105 | labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6) 106 | 107 | # normalized cewh to pixel xyxy format 108 | labels = labels0.copy() 109 | labels[:, 2] = w * (labels0[:, 2] - labels0[:, 4] / 2) 110 | labels[:, 3] = h * (labels0[:, 3] - labels0[:, 5] / 2) 111 | labels[:, 4] = w * (labels0[:, 2] + labels0[:, 4] / 2) 112 | labels[:, 5] = h * (labels0[:, 3] + labels0[:, 5] / 2) 113 | else: 114 | raise ValueError('invalid label path: {}'.format(label_path)) 115 | video_name = '/'.join(label_path.split('/')[:-1]) 116 | obj_idx_offset = self.video_dict[video_name] * 100000 # 100000 unique ids is enough for a video. 117 | targets['boxes'] = [] 118 | targets['area'] = [] 119 | targets['iscrowd'] = [] 120 | targets['labels'] = [] 121 | targets['obj_ids'] = [] 122 | targets['image_id'] = torch.as_tensor(idx) 123 | targets['size'] = torch.as_tensor([h, w]) 124 | targets['orig_size'] = torch.as_tensor([h, w]) 125 | for label in labels: 126 | targets['boxes'].append(label[2:6].tolist()) 127 | targets['area'].append(label[4] * label[5]) 128 | targets['iscrowd'].append(0) 129 | targets['labels'].append(0) 130 | obj_id = label[1] + obj_idx_offset if label[1] >= 0 else label[1] 131 | targets['obj_ids'].append(obj_id) # relative id 132 | 133 | targets['area'] = torch.as_tensor(targets['area']) 134 | targets['iscrowd'] = torch.as_tensor(targets['iscrowd']) 135 | targets['labels'] = torch.as_tensor(targets['labels']) 136 | targets['obj_ids'] = torch.as_tensor(targets['obj_ids']) 137 | targets['boxes'] = torch.as_tensor(targets['boxes'], dtype=torch.float32).reshape(-1, 4) 138 | targets['boxes'][:, 0::2].clamp_(min=0, max=w) 139 | targets['boxes'][:, 1::2].clamp_(min=0, max=h) 140 | return img, targets 141 | 142 | def _get_sample_range(self, start_idx): 143 | 144 | # take default sampling method for normal dataset. 145 | assert self.sample_mode in ['fixed_interval', 'random_interval'], 'invalid sample mode: {}'.format(self.sample_mode) 146 | if self.sample_mode == 'fixed_interval': 147 | sample_interval = self.sample_interval 148 | elif self.sample_mode == 'random_interval': 149 | sample_interval = np.random.randint(1, self.sample_interval + 1) 150 | default_range = start_idx, start_idx + (self.num_frames_per_batch - 1) * sample_interval + 1, sample_interval 151 | return default_range 152 | 153 | def pre_continuous_frames(self, start, end, interval=1): 154 | targets = [] 155 | images = [] 156 | for i in range(start, end, interval): 157 | img_i, targets_i = self._pre_single_frame(i) 158 | images.append(img_i) 159 | targets.append(targets_i) 160 | return images, targets 161 | 162 | def __getitem__(self, idx): 163 | sample_start, sample_end, sample_interval = self._get_sample_range(idx) 164 | images, targets = self.pre_continuous_frames(sample_start, sample_end, sample_interval) 165 | data = {} 166 | if self._transforms is not None: 167 | images, targets = self._transforms(images, targets) 168 | gt_instances = [] 169 | for img_i, targets_i in zip(images, targets): 170 | gt_instances_i = self._targets_to_instances(targets_i, img_i.shape[1:3]) 171 | gt_instances.append(gt_instances_i) 172 | data.update({ 173 | 'imgs': images, 174 | 'gt_instances': gt_instances, 175 | }) 176 | if self.args.vis: 177 | data['ori_img'] = [target_i['ori_img'] for target_i in targets] 178 | return data 179 | 180 | def __len__(self): 181 | return self.item_num 182 | 183 | 184 | class DetMOTDetectionValidation(DetMOTDetection): 185 | def __init__(self, args, seqs_folder, transforms): 186 | args.data_txt_path = args.val_data_txt_path 187 | super().__init__(args, seqs_folder, transforms) 188 | 189 | 190 | def make_detmot_transforms(image_set, args=None): 191 | normalize = T.MotCompose([ 192 | T.MotToTensor(), 193 | T.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 194 | ]) 195 | 196 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 197 | 198 | if image_set == 'train': 199 | color_transforms = [] 200 | scale_transforms = [ 201 | T.MotRandomHorizontalFlip(), 202 | T.MotRandomResize(scales, max_size=1333), 203 | normalize, 204 | ] 205 | 206 | return T.MotCompose(color_transforms + scale_transforms) 207 | 208 | if image_set == 'val': 209 | return T.MotCompose([ 210 | T.MotRandomResize([800], max_size=1333), 211 | normalize, 212 | ]) 213 | 214 | raise ValueError(f'unknown {image_set}') 215 | 216 | 217 | def build(image_set, args): 218 | root = Path(args.mot_path) 219 | assert root.exists(), f'provided MOT path {root} does not exist' 220 | transforms = make_detmot_transforms(image_set, args) 221 | if image_set == 'train': 222 | data_txt_path = args.data_txt_path_train 223 | dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, transforms=transforms) 224 | if image_set == 'val': 225 | data_txt_path = args.data_txt_path_val 226 | dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, transforms=transforms) 227 | return dataset 228 | 229 | -------------------------------------------------------------------------------- /datasets/panoptic_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | import json 13 | import os 14 | 15 | import util.misc as utils 16 | 17 | try: 18 | from panopticapi.evaluation import pq_compute 19 | except ImportError: 20 | pass 21 | 22 | 23 | class PanopticEvaluator(object): 24 | def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): 25 | self.gt_json = ann_file 26 | self.gt_folder = ann_folder 27 | if utils.is_main_process(): 28 | if not os.path.exists(output_dir): 29 | os.mkdir(output_dir) 30 | self.output_dir = output_dir 31 | self.predictions = [] 32 | 33 | def update(self, predictions): 34 | for p in predictions: 35 | with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: 36 | f.write(p.pop("png_string")) 37 | 38 | self.predictions += predictions 39 | 40 | def synchronize_between_processes(self): 41 | all_predictions = utils.all_gather(self.predictions) 42 | merged_predictions = [] 43 | for p in all_predictions: 44 | merged_predictions += p 45 | self.predictions = merged_predictions 46 | 47 | def summarize(self): 48 | if utils.is_main_process(): 49 | json_data = {"annotations": self.predictions} 50 | predictions_json = os.path.join(self.output_dir, "predictions.json") 51 | with open(predictions_json, "w") as f: 52 | f.write(json.dumps(json_data)) 53 | return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir) 54 | return None 55 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | import os 13 | import math 14 | import torch 15 | import torch.distributed as dist 16 | from torch.utils.data.sampler import Sampler 17 | 18 | 19 | class DistributedSampler(Sampler): 20 | """Sampler that restricts data loading to a subset of the dataset. 21 | It is especially useful in conjunction with 22 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 23 | process can pass a DistributedSampler instance as a DataLoader sampler, 24 | and load a subset of the original dataset that is exclusive to it. 25 | .. note:: 26 | Dataset is assumed to be of constant size. 27 | Arguments: 28 | dataset: Dataset used for sampling. 29 | num_replicas (optional): Number of processes participating in 30 | distributed training. 31 | rank (optional): Rank of the current process within num_replicas. 32 | """ 33 | 34 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 35 | if num_replicas is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | num_replicas = dist.get_world_size() 39 | if rank is None: 40 | if not dist.is_available(): 41 | raise RuntimeError("Requires distributed package to be available") 42 | rank = dist.get_rank() 43 | self.dataset = dataset 44 | self.num_replicas = num_replicas 45 | self.rank = rank 46 | self.epoch = 0 47 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 48 | self.total_size = self.num_samples * self.num_replicas 49 | self.shuffle = shuffle 50 | 51 | def __iter__(self): 52 | if self.shuffle: 53 | # deterministically shuffle based on epoch 54 | g = torch.Generator() 55 | g.manual_seed(self.epoch) 56 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 57 | else: 58 | indices = torch.arange(len(self.dataset)).tolist() 59 | 60 | # add extra samples to make it evenly divisible 61 | indices += indices[: (self.total_size - len(indices))] 62 | assert len(indices) == self.total_size 63 | 64 | # subsample 65 | offset = self.num_samples * self.rank 66 | indices = indices[offset : offset + self.num_samples] 67 | assert len(indices) == self.num_samples 68 | 69 | return iter(indices) 70 | 71 | def __len__(self): 72 | return self.num_samples 73 | 74 | def set_epoch(self, epoch): 75 | self.epoch = epoch 76 | 77 | 78 | class NodeDistributedSampler(Sampler): 79 | """Sampler that restricts data loading to a subset of the dataset. 80 | It is especially useful in conjunction with 81 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 82 | process can pass a DistributedSampler instance as a DataLoader sampler, 83 | and load a subset of the original dataset that is exclusive to it. 84 | .. note:: 85 | Dataset is assumed to be of constant size. 86 | Arguments: 87 | dataset: Dataset used for sampling. 88 | num_replicas (optional): Number of processes participating in 89 | distributed training. 90 | rank (optional): Rank of the current process within num_replicas. 91 | """ 92 | 93 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 94 | if num_replicas is None: 95 | if not dist.is_available(): 96 | raise RuntimeError("Requires distributed package to be available") 97 | num_replicas = dist.get_world_size() 98 | if rank is None: 99 | if not dist.is_available(): 100 | raise RuntimeError("Requires distributed package to be available") 101 | rank = dist.get_rank() 102 | if local_rank is None: 103 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 104 | if local_size is None: 105 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 106 | self.dataset = dataset 107 | self.shuffle = shuffle 108 | self.num_replicas = num_replicas 109 | self.num_parts = local_size 110 | self.rank = rank 111 | self.local_rank = local_rank 112 | self.epoch = 0 113 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 114 | self.total_size = self.num_samples * self.num_replicas 115 | 116 | self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts 117 | 118 | def __iter__(self): 119 | if self.shuffle: 120 | # deterministically shuffle based on epoch 121 | g = torch.Generator() 122 | g.manual_seed(self.epoch) 123 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 124 | else: 125 | indices = torch.arange(len(self.dataset)).tolist() 126 | indices = [i for i in indices if i % self.num_parts == self.local_rank] 127 | 128 | # add extra samples to make it evenly divisible 129 | indices += indices[:(self.total_size_parts - len(indices))] 130 | assert len(indices) == self.total_size_parts 131 | 132 | # subsample 133 | indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] 134 | assert len(indices) == self.num_samples 135 | 136 | return iter(indices) 137 | 138 | def __len__(self): 139 | return self.num_samples 140 | 141 | def set_epoch(self, epoch): 142 | self.epoch = epoch 143 | -------------------------------------------------------------------------------- /datasets/torchvision_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | from .coco import CocoDetection 13 | -------------------------------------------------------------------------------- /datasets/torchvision_datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/torchvision_datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/torchvision_datasets/__pycache__/coco.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/datasets/torchvision_datasets/__pycache__/coco.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/torchvision_datasets/coco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | Copy-Paste from torchvision, but add utility of caching images on memory 14 | """ 15 | from torchvision.datasets.vision import VisionDataset 16 | from PIL import Image 17 | import os 18 | import os.path 19 | import tqdm 20 | from io import BytesIO 21 | 22 | 23 | class CocoDetection(VisionDataset): 24 | """`MS Coco Detection `_ Dataset. 25 | Args: 26 | root (string): Root directory where images are downloaded to. 27 | annFile (string): Path to json annotation file. 28 | transform (callable, optional): A function/transform that takes in an PIL image 29 | and returns a transformed version. E.g, ``transforms.ToTensor`` 30 | target_transform (callable, optional): A function/transform that takes in the 31 | target and transforms it. 32 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 33 | and returns a transformed version. 34 | """ 35 | 36 | def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None, 37 | cache_mode=False, local_rank=0, local_size=1): 38 | super(CocoDetection, self).__init__(root, transforms, transform, target_transform) 39 | from pycocotools.coco import COCO 40 | self.coco = COCO(annFile) 41 | self.ids = list(sorted(self.coco.imgs.keys())) 42 | self.cache_mode = cache_mode 43 | self.local_rank = local_rank 44 | self.local_size = local_size 45 | if cache_mode: 46 | self.cache = {} 47 | self.cache_images() 48 | 49 | def cache_images(self): 50 | self.cache = {} 51 | for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids): 52 | if index % self.local_size != self.local_rank: 53 | continue 54 | path = self.coco.loadImgs(img_id)[0]['file_name'] 55 | with open(os.path.join(self.root, path), 'rb') as f: 56 | self.cache[path] = f.read() 57 | 58 | def get_image(self, path): 59 | if self.cache_mode: 60 | if path not in self.cache.keys(): 61 | with open(os.path.join(self.root, path), 'rb') as f: 62 | self.cache[path] = f.read() 63 | return Image.open(BytesIO(self.cache[path])).convert('RGB') 64 | return Image.open(os.path.join(self.root, path)).convert('RGB') 65 | 66 | def __getitem__(self, index): 67 | """ 68 | Args: 69 | index (int): Index 70 | Returns: 71 | tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. 72 | """ 73 | coco = self.coco 74 | img_id = self.ids[index] 75 | ann_ids = coco.getAnnIds(imgIds=img_id) 76 | target = coco.loadAnns(ann_ids) 77 | 78 | path = coco.loadImgs(img_id)[0]['file_name'] 79 | 80 | img = self.get_image(path) 81 | if self.transforms is not None: 82 | img, target = self.transforms(img, target) 83 | 84 | return img, target 85 | 86 | def __len__(self): 87 | return len(self.ids) 88 | -------------------------------------------------------------------------------- /figs/demo.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/figs/demo.avi -------------------------------------------------------------------------------- /figs/motr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/figs/motr.png -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/.DS_Store -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | from .deformable_detr import build as build_deformable_detr 11 | from .motr import build as build_motr 12 | 13 | 14 | def build_model(args): 15 | arch_catalog = { 16 | 'deformable_detr': build_deformable_detr, 17 | 'motr': build_motr, 18 | } 19 | assert args.meta_arch in arch_catalog, 'invalid arch: {}'.format(args.meta_arch) 20 | build_func = arch_catalog[args.meta_arch] 21 | return build_func(args) 22 | 23 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/backbone.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/backbone.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/deformable_detr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/deformable_detr.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/deformable_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/deformable_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/deformable_transformer_plus.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/deformable_transformer_plus.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/matcher.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/matcher.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/memory_bank.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/memory_bank.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/motr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/motr.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/position_encoding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/position_encoding.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/qim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/qim.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/segmentation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/__pycache__/segmentation.cpython-36.pyc -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | Backbone modules. 14 | """ 15 | from collections import OrderedDict 16 | import torch.nn as nn 17 | import torch 18 | import torch.nn.functional as F 19 | import torchvision 20 | from torch import nn 21 | from torchvision.models._utils import IntermediateLayerGetter 22 | from typing import Dict, List 23 | 24 | from util.misc import NestedTensor, is_main_process 25 | from .position_encoding import build_position_encoding 26 | 27 | class FrozenBatchNorm2d(torch.nn.Module): 28 | """ 29 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 30 | 31 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 32 | without which any other models than torchvision.models.resnet[18,34,50,101] 33 | produce nans. 34 | """ 35 | 36 | def __init__(self, n, eps=1e-5): 37 | super(FrozenBatchNorm2d, self).__init__() 38 | self.register_buffer("weight", torch.ones(n)) 39 | self.register_buffer("bias", torch.zeros(n)) 40 | self.register_buffer("running_mean", torch.zeros(n)) 41 | self.register_buffer("running_var", torch.ones(n)) 42 | self.eps = eps 43 | 44 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 45 | missing_keys, unexpected_keys, error_msgs): 46 | num_batches_tracked_key = prefix + 'num_batches_tracked' 47 | if num_batches_tracked_key in state_dict: 48 | del state_dict[num_batches_tracked_key] 49 | 50 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 51 | state_dict, prefix, local_metadata, strict, 52 | missing_keys, unexpected_keys, error_msgs) 53 | 54 | def forward(self, x): 55 | # move reshapes to the beginning 56 | # to make it fuser-friendly 57 | w = self.weight.reshape(1, -1, 1, 1) 58 | b = self.bias.reshape(1, -1, 1, 1) 59 | rv = self.running_var.reshape(1, -1, 1, 1) 60 | rm = self.running_mean.reshape(1, -1, 1, 1) 61 | eps = self.eps 62 | scale = w * (rv + eps).rsqrt() 63 | bias = b - rm * scale 64 | return x * scale + bias 65 | 66 | 67 | class BackboneBase(nn.Module): 68 | 69 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): 70 | super().__init__() 71 | for name, parameter in backbone.named_parameters(): 72 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 73 | parameter.requires_grad_(False) 74 | 75 | if return_interm_layers: 76 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 77 | self.strides = [8, 16, 32] 78 | self.num_channels = [512, 1024, 2048] 79 | else: 80 | return_layers = {'layer4': "0"} 81 | self.strides = [32] 82 | self.num_channels = [2048] 83 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 84 | 85 | def forward(self, tensor_list: NestedTensor): 86 | xs = self.body(tensor_list.tensors) 87 | out: Dict[str, NestedTensor] = {} 88 | for name, x in xs.items(): 89 | m = tensor_list.mask 90 | assert m is not None 91 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 92 | out[name] = NestedTensor(x, mask) 93 | return out 94 | 95 | 96 | class Backbone(BackboneBase): 97 | """ResNet backbone with frozen BatchNorm.""" 98 | def __init__(self, name: str, 99 | train_backbone: bool, 100 | return_interm_layers: bool, 101 | dilation: bool,): 102 | norm_layer = FrozenBatchNorm2d 103 | backbone = getattr(torchvision.models, name)( 104 | replace_stride_with_dilation=[False, False, dilation], 105 | pretrained=is_main_process(), norm_layer=norm_layer) 106 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" 107 | super().__init__(backbone, train_backbone, return_interm_layers) 108 | if dilation: 109 | self.strides[-1] = self.strides[-1] // 2 110 | 111 | 112 | class Joiner(nn.Sequential): 113 | def __init__(self, backbone, position_embedding): 114 | super().__init__(backbone, position_embedding) 115 | self.strides = backbone.strides 116 | self.num_channels = backbone.num_channels 117 | 118 | def forward(self, tensor_list: NestedTensor): 119 | xs = self[0](tensor_list) 120 | out: List[NestedTensor] = [] 121 | pos = [] 122 | for name, x in sorted(xs.items()): 123 | out.append(x) 124 | 125 | # position encoding 126 | for x in out: 127 | pos.append(self[1](x).to(x.tensors.dtype)) 128 | 129 | return out, pos 130 | 131 | 132 | def build_backbone(args): 133 | position_embedding = build_position_encoding(args) 134 | train_backbone = args.lr_backbone > 0 135 | return_interm_layers = args.masks or (args.num_feature_levels > 1) 136 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 137 | model = Joiner(backbone, position_embedding) 138 | return model 139 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | Modules to compute the matching cost and solve the corresponding LSAP. 14 | """ 15 | import torch 16 | from scipy.optimize import linear_sum_assignment 17 | from torch import nn 18 | 19 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 20 | from models.structures import Instances 21 | 22 | 23 | class HungarianMatcher(nn.Module): 24 | """This class computes an assignment between the targets and the predictions of the network 25 | 26 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 27 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 28 | while the others are un-matched (and thus treated as non-objects). 29 | """ 30 | 31 | def __init__(self, 32 | cost_class: float = 1, 33 | cost_bbox: float = 1, 34 | cost_giou: float = 1): 35 | """Creates the matcher 36 | 37 | Params: 38 | cost_class: This is the relative weight of the classification error in the matching cost 39 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 40 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 41 | """ 42 | super().__init__() 43 | self.cost_class = cost_class 44 | self.cost_bbox = cost_bbox 45 | self.cost_giou = cost_giou 46 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 47 | 48 | def forward(self, outputs, targets, use_focal=True): 49 | """ Performs the matching 50 | 51 | Params: 52 | outputs: This is a dict that contains at least these entries: 53 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 54 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 55 | 56 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 57 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 58 | objects in the target) containing the class labels 59 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 60 | 61 | Returns: 62 | A list of size batch_size, containing tuples of (index_i, index_j) where: 63 | - index_i is the indices of the selected predictions (in order) 64 | - index_j is the indices of the corresponding selected targets (in order) 65 | For each batch element, it holds: 66 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 67 | """ 68 | with torch.no_grad(): 69 | bs, num_queries = outputs["pred_logits"].shape[:2] 70 | 71 | # We flatten to compute the cost matrices in a batch 72 | if use_focal: 73 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() 74 | else: 75 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 76 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 77 | 78 | # Also concat the target labels and boxes 79 | if isinstance(targets[0], Instances): 80 | tgt_ids = torch.cat([gt_per_img.labels for gt_per_img in targets]) 81 | tgt_bbox = torch.cat([gt_per_img.boxes for gt_per_img in targets]) 82 | else: 83 | tgt_ids = torch.cat([v["labels"] for v in targets]) 84 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 85 | 86 | # Compute the classification cost. 87 | if use_focal: 88 | alpha = 0.25 89 | gamma = 2.0 90 | neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) 91 | pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) 92 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 93 | else: 94 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 95 | # but approximate it in 1 - proba[target class]. 96 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 97 | cost_class = -out_prob[:, tgt_ids] 98 | 99 | # Compute the L1 cost between boxes 100 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 101 | 102 | # Compute the giou cost betwen boxes 103 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), 104 | box_cxcywh_to_xyxy(tgt_bbox)) 105 | 106 | # Final cost matrix 107 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 108 | C = C.view(bs, num_queries, -1).cpu() 109 | 110 | if isinstance(targets[0], Instances): 111 | sizes = [len(gt_per_img.boxes) for gt_per_img in targets] 112 | else: 113 | sizes = [len(v["boxes"]) for v in targets] 114 | 115 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 116 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 117 | 118 | 119 | def build_matcher(args): 120 | return HungarianMatcher(cost_class=args.set_cost_class, 121 | cost_bbox=args.set_cost_bbox, 122 | cost_giou=args.set_cost_giou) 123 | -------------------------------------------------------------------------------- /models/memory_bank.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, Tensor 8 | 9 | from typing import List 10 | 11 | from models.structures import Instances 12 | 13 | 14 | class MemoryBank(nn.Module): 15 | def __init__(self, args, dim_in, hidden_dim, dim_out): 16 | super().__init__() 17 | self._build_layers(args, dim_in, hidden_dim, dim_out) 18 | for p in self.parameters(): 19 | if p.dim() > 1: 20 | nn.init.xavier_uniform_(p) 21 | 22 | def _build_layers(self, args, dim_in, hidden_dim, dim_out): 23 | self.save_thresh = args.memory_bank_score_thresh 24 | self.save_period = 3 25 | self.max_his_length = args.memory_bank_len 26 | 27 | self.save_proj = nn.Linear(dim_in, dim_in) 28 | 29 | self.temporal_attn = nn.MultiheadAttention(dim_in, 8, dropout=0) 30 | self.temporal_fc1 = nn.Linear(dim_in, hidden_dim) 31 | self.temporal_fc2 = nn.Linear(hidden_dim, dim_in) 32 | self.temporal_norm1 = nn.LayerNorm(dim_in) 33 | self.temporal_norm2 = nn.LayerNorm(dim_in) 34 | 35 | self.track_cls = nn.Linear(dim_in, 1) 36 | 37 | self.self_attn = None 38 | if args.memory_bank_with_self_attn: 39 | self.spatial_attn = nn.MultiheadAttention(dim_in, 8, dropout=0) 40 | self.spatial_fc1 = nn.Linear(dim_in, hidden_dim) 41 | self.spatial_fc2 = nn.Linear(hidden_dim, dim_in) 42 | self.spatial_norm1 = nn.LayerNorm(dim_in) 43 | self.spatial_norm2 = nn.LayerNorm(dim_in) 44 | else: 45 | self.spatial_attn = None 46 | 47 | def update(self, track_instances): 48 | embed = track_instances.output_embedding[:, None] #( N, 1, 256) 49 | scores = track_instances.scores 50 | mem_padding_mask = track_instances.mem_padding_mask 51 | device = embed.device 52 | 53 | save_period = track_instances.save_period 54 | if self.training: 55 | saved_idxes = scores > 0 56 | else: 57 | saved_idxes = (save_period == 0) & (scores > self.save_thresh) 58 | # saved_idxes = (save_period == 0) 59 | save_period[save_period > 0] -= 1 60 | save_period[saved_idxes] = self.save_period 61 | 62 | saved_embed = embed[saved_idxes] 63 | if len(saved_embed) > 0: 64 | prev_embed = track_instances.mem_bank[saved_idxes] 65 | save_embed = self.save_proj(saved_embed) 66 | mem_padding_mask[saved_idxes] = torch.cat([mem_padding_mask[saved_idxes, 1:], torch.zeros((len(saved_embed), 1), dtype=torch.bool, device=device)], dim=1) 67 | track_instances.mem_bank = track_instances.mem_bank.clone() 68 | track_instances.mem_bank[saved_idxes] = torch.cat([prev_embed[:, 1:], save_embed], dim=1) 69 | 70 | def _forward_spatial_attn(self, track_instances): 71 | if len(track_instances) == 0: 72 | return track_instances 73 | 74 | embed = track_instances.output_embedding 75 | dim = embed.shape[-1] 76 | query_pos = track_instances.query_pos[:, :dim] 77 | k = q = (embed + query_pos) 78 | v = embed 79 | embed2 = self.spatial_attn( 80 | q[:, None], 81 | k[:, None], 82 | v[:, None] 83 | )[0][:, 0] 84 | embed = self.spatial_norm1(embed + embed2) 85 | embed2 = self.spatial_fc2(F.relu(self.spatial_fc1(embed))) 86 | embed = self.spatial_norm2(embed + embed2) 87 | track_instances.output_embedding = embed 88 | return track_instances 89 | 90 | def _forward_track_cls(self, track_instances): 91 | track_instances.track_scores = self.track_cls(track_instances.output_embedding)[..., 0] 92 | return track_instances 93 | 94 | def _forward_temporal_attn(self, track_instances): 95 | if len(track_instances) == 0: 96 | return track_instances 97 | 98 | dim = track_instances.query_pos.shape[1] 99 | key_padding_mask = track_instances.mem_padding_mask 100 | 101 | valid_idxes = key_padding_mask[:, -1] == 0 102 | embed = track_instances.output_embedding[valid_idxes] # (n, 256) 103 | 104 | if len(embed) > 0: 105 | prev_embed = track_instances.mem_bank[valid_idxes] 106 | key_padding_mask = key_padding_mask[valid_idxes] 107 | embed2 = self.temporal_attn( 108 | embed[None], # (num_track, dim) to (1, num_track, dim) 109 | prev_embed.transpose(0, 1), # (num_track, mem_len, dim) to (mem_len, num_track, dim) 110 | prev_embed.transpose(0, 1), 111 | key_padding_mask=key_padding_mask, 112 | )[0][0] 113 | 114 | embed = self.temporal_norm1(embed + embed2) 115 | embed2 = self.temporal_fc2(F.relu(self.temporal_fc1(embed))) 116 | embed = self.temporal_norm2(embed + embed2) 117 | track_instances.output_embedding = track_instances.output_embedding.clone() 118 | track_instances.output_embedding[valid_idxes] = embed 119 | 120 | return track_instances 121 | 122 | def forward_temporal_attn(self, track_instances): 123 | return self._forward_temporal_attn(track_instances) 124 | 125 | def forward(self, track_instances: Instances, update_bank=True) -> Instances: 126 | track_instances = self._forward_temporal_attn(track_instances) 127 | if update_bank: 128 | self.update(track_instances) 129 | if self.spatial_attn is not None: 130 | track_instances = self._forward_spatial_attn(track_instances) 131 | if self.track_cls is not None: 132 | track_instances = self._forward_track_cls(track_instances) 133 | return track_instances 134 | 135 | 136 | def build_memory_bank(args, dim_in, hidden_dim, dim_out): 137 | name = args.memory_bank_type 138 | memory_banks = { 139 | 'MemoryBank': MemoryBank, 140 | } 141 | assert name in memory_banks 142 | return memory_banks[name](args, dim_in, hidden_dim, dim_out) 143 | -------------------------------------------------------------------------------- /models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /models/ops/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/ops/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/ops/functions/__pycache__/ms_deform_attn_func.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-36.pyc -------------------------------------------------------------------------------- /models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | import MultiScaleDeformableAttention as MSDA 22 | 23 | 24 | class MSDeformAttnFunction(Function): 25 | @staticmethod 26 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 27 | ctx.im2col_step = im2col_step 28 | output = MSDA.ms_deform_attn_forward( 29 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 30 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 31 | return output 32 | 33 | @staticmethod 34 | @once_differentiable 35 | def backward(ctx, grad_output): 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 37 | grad_value, grad_sampling_loc, grad_attn_weight = \ 38 | MSDA.ms_deform_attn_backward( 39 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 40 | 41 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 42 | 43 | 44 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 45 | # for debug and test only, 46 | # need to use cuda version instead 47 | N_, S_, M_, D_ = value.shape 48 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 49 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 50 | sampling_grids = 2 * sampling_locations - 1 51 | sampling_value_list = [] 52 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 53 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 54 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 55 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 56 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 57 | # N_*M_, D_, Lq_, P_ 58 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 59 | mode='bilinear', padding_mode='zeros', align_corners=False) 60 | sampling_value_list.append(sampling_value_l_) 61 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 62 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 63 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 64 | return output.transpose(1, 2).contiguous() 65 | -------------------------------------------------------------------------------- /models/ops/make.sh: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | python setup.py build install 10 | -------------------------------------------------------------------------------- /models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /models/ops/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/ops/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/ops/modules/__pycache__/ms_deform_attn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/ops/modules/__pycache__/ms_deform_attn.cpython-36.pyc -------------------------------------------------------------------------------- /models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import warnings 17 | import math 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from torch.nn.init import xavier_uniform_, constant_ 23 | 24 | from ..functions import MSDeformAttnFunction 25 | 26 | 27 | def _is_power_of_2(n): 28 | if (not isinstance(n, int)) or (n < 0): 29 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 30 | return (n & (n-1) == 0) and n != 0 31 | 32 | 33 | class MSDeformAttn(nn.Module): 34 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, sigmoid_attn=False): 35 | """ 36 | Multi-Scale Deformable Attention Module 37 | :param d_model hidden dimension 38 | :param n_levels number of feature levels 39 | :param n_heads number of attention heads 40 | :param n_points number of sampling points per attention head per feature level 41 | """ 42 | super().__init__() 43 | if d_model % n_heads != 0: 44 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 45 | _d_per_head = d_model // n_heads 46 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 47 | if not _is_power_of_2(_d_per_head): 48 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 49 | "which is more efficient in our CUDA implementation.") 50 | 51 | self.im2col_step = 64 52 | self.sigmoid_attn = sigmoid_attn 53 | 54 | self.d_model = d_model 55 | self.n_levels = n_levels 56 | self.n_heads = n_heads 57 | self.n_points = n_points 58 | 59 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 60 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 61 | self.value_proj = nn.Linear(d_model, d_model) 62 | self.output_proj = nn.Linear(d_model, d_model) 63 | 64 | self._reset_parameters() 65 | 66 | def _reset_parameters(self): 67 | constant_(self.sampling_offsets.weight.data, 0.) 68 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 69 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 70 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 71 | for i in range(self.n_points): 72 | grid_init[:, :, i, :] *= i + 1 73 | with torch.no_grad(): 74 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 75 | constant_(self.attention_weights.weight.data, 0.) 76 | constant_(self.attention_weights.bias.data, 0.) 77 | xavier_uniform_(self.value_proj.weight.data) 78 | constant_(self.value_proj.bias.data, 0.) 79 | xavier_uniform_(self.output_proj.weight.data) 80 | constant_(self.output_proj.bias.data, 0.) 81 | 82 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 83 | """ 84 | :param query (N, Length_{query}, C) 85 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 86 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 87 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 88 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 89 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 90 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 91 | 92 | :return output (N, Length_{query}, C) 93 | """ 94 | N, Len_q, _ = query.shape 95 | N, Len_in, _ = input_flatten.shape 96 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 97 | 98 | value = self.value_proj(input_flatten) 99 | if input_padding_mask is not None: 100 | value.masked_fill_(input_padding_mask[..., None], float(0)) 101 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 102 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 103 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 104 | if self.sigmoid_attn: 105 | attention_weights = attention_weights.sigmoid().view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 106 | else: 107 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 108 | # N, Len_q, n_heads, n_levels, n_points, 2 109 | if reference_points.shape[-1] == 2: 110 | sampling_locations = reference_points[:, :, None, :, None, :] \ 111 | + sampling_offsets / input_spatial_shapes[None, None, None, :, None, (1, 0)] 112 | elif reference_points.shape[-1] == 4: 113 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 114 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 115 | else: 116 | raise ValueError( 117 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 118 | output = MSDeformAttnFunction.apply( 119 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 120 | output = self.output_proj(output) 121 | return output 122 | -------------------------------------------------------------------------------- /models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | Various positional encodings for the transformer. 14 | """ 15 | import math 16 | import torch 17 | from torch import nn 18 | 19 | from util.misc import NestedTensor 20 | 21 | 22 | class PositionEmbeddingSine(nn.Module): 23 | """ 24 | This is a more standard version of the position embedding, very similar to the one 25 | used by the Attention is all you need paper, generalized to work on images. 26 | """ 27 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 28 | super().__init__() 29 | self.num_pos_feats = num_pos_feats 30 | self.temperature = temperature 31 | self.normalize = normalize 32 | if scale is not None and normalize is False: 33 | raise ValueError("normalize should be True if scale is passed") 34 | if scale is None: 35 | scale = 2 * math.pi 36 | self.scale = scale 37 | 38 | def forward(self, tensor_list: NestedTensor): 39 | x = tensor_list.tensors 40 | mask = tensor_list.mask 41 | assert mask is not None 42 | not_mask = ~mask 43 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 44 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 45 | if self.normalize: 46 | eps = 1e-6 47 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 48 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 49 | 50 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 51 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 52 | 53 | pos_x = x_embed[:, :, :, None] / dim_t 54 | pos_y = y_embed[:, :, :, None] / dim_t 55 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 56 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 57 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 58 | return pos 59 | 60 | 61 | class PositionEmbeddingLearned(nn.Module): 62 | """ 63 | Absolute pos embedding, learned. 64 | """ 65 | def __init__(self, num_pos_feats=256): 66 | super().__init__() 67 | self.row_embed = nn.Embedding(50, num_pos_feats) 68 | self.col_embed = nn.Embedding(50, num_pos_feats) 69 | self.reset_parameters() 70 | 71 | def reset_parameters(self): 72 | nn.init.uniform_(self.row_embed.weight) 73 | nn.init.uniform_(self.col_embed.weight) 74 | 75 | def forward(self, tensor_list: NestedTensor): 76 | x = tensor_list.tensors 77 | h, w = x.shape[-2:] 78 | i = torch.arange(w, device=x.device) 79 | j = torch.arange(h, device=x.device) 80 | x_emb = self.col_embed(i) 81 | y_emb = self.row_embed(j) 82 | pos = torch.cat([ 83 | x_emb.unsqueeze(0).repeat(h, 1, 1), 84 | y_emb.unsqueeze(1).repeat(1, w, 1), 85 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 86 | return pos 87 | 88 | 89 | def build_position_encoding(args): 90 | N_steps = args.hidden_dim // 2 91 | if args.position_embedding in ('v2', 'sine'): 92 | # TODO find a better way of exposing other arguments 93 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 94 | elif args.position_embedding in ('v3', 'learned'): 95 | position_embedding = PositionEmbeddingLearned(N_steps) 96 | else: 97 | raise ValueError(f"not supported {args.position_embedding}") 98 | 99 | return position_embedding 100 | -------------------------------------------------------------------------------- /models/qim.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | 5 | import random 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.nn import functional as F 9 | from typing import Optional, List 10 | 11 | from util import box_ops 12 | from util.misc import inverse_sigmoid 13 | from models.structures import Boxes, Instances, pairwise_iou 14 | 15 | 16 | def random_drop_tracks(track_instances: Instances, drop_probability: float) -> Instances: 17 | if drop_probability > 0 and len(track_instances) > 0: 18 | keep_idxes = torch.rand_like(track_instances.scores) > drop_probability 19 | track_instances = track_instances[keep_idxes] 20 | return track_instances 21 | 22 | 23 | class QueryInteractionBase(nn.Module): 24 | def __init__(self, args, dim_in, hidden_dim, dim_out): 25 | super().__init__() 26 | self.args = args 27 | self._build_layers(args, dim_in, hidden_dim, dim_out) 28 | self._reset_parameters() 29 | 30 | def _build_layers(self, args, dim_in, hidden_dim, dim_out): 31 | raise NotImplementedError() 32 | 33 | def _reset_parameters(self): 34 | for p in self.parameters(): 35 | if p.dim() > 1: 36 | nn.init.xavier_uniform_(p) 37 | 38 | def _select_active_tracks(self, data: dict) -> Instances: 39 | raise NotImplementedError() 40 | 41 | def _update_track_embedding(self, track_instances): 42 | raise NotImplementedError() 43 | 44 | 45 | class FFN(nn.Module): 46 | def __init__(self, d_model, d_ffn, dropout=0): 47 | super().__init__() 48 | self.linear1 = nn.Linear(d_model, d_ffn) 49 | self.activation = nn.ReLU(True) 50 | self.dropout1 = nn.Dropout(dropout) 51 | self.linear2 = nn.Linear(d_ffn, d_model) 52 | self.dropout2 = nn.Dropout(dropout) 53 | self.norm = nn.LayerNorm(d_model) 54 | 55 | def forward(self, tgt): 56 | tgt2 = self.linear2(self.dropout1(self.activation(self.linear1(tgt)))) 57 | tgt = tgt + self.dropout2(tgt2) 58 | tgt = self.norm(tgt) 59 | return tgt 60 | 61 | 62 | class QueryInteractionModule(QueryInteractionBase): 63 | def __init__(self, args, dim_in, hidden_dim, dim_out): 64 | super().__init__(args, dim_in, hidden_dim, dim_out) 65 | self.random_drop = args.random_drop 66 | self.fp_ratio = args.fp_ratio 67 | self.update_query_pos = args.update_query_pos 68 | 69 | def _build_layers(self, args, dim_in, hidden_dim, dim_out): 70 | dropout = args.merger_dropout 71 | 72 | self.self_attn = nn.MultiheadAttention(dim_in, 8, dropout) 73 | self.linear1 = nn.Linear(dim_in, hidden_dim) 74 | self.dropout = nn.Dropout(dropout) 75 | self.linear2 = nn.Linear(hidden_dim, dim_in) 76 | 77 | if args.update_query_pos: 78 | self.linear_pos1 = nn.Linear(dim_in, hidden_dim) 79 | self.linear_pos2 = nn.Linear(hidden_dim, dim_in) 80 | self.dropout_pos1 = nn.Dropout(dropout) 81 | self.dropout_pos2 = nn.Dropout(dropout) 82 | self.norm_pos = nn.LayerNorm(dim_in) 83 | 84 | self.linear_feat1 = nn.Linear(dim_in, hidden_dim) 85 | self.linear_feat2 = nn.Linear(hidden_dim, dim_in) 86 | self.dropout_feat1 = nn.Dropout(dropout) 87 | self.dropout_feat2 = nn.Dropout(dropout) 88 | self.norm_feat = nn.LayerNorm(dim_in) 89 | 90 | self.norm1 = nn.LayerNorm(dim_in) 91 | self.norm2 = nn.LayerNorm(dim_in) 92 | if args.update_query_pos: 93 | self.norm3 = nn.LayerNorm(dim_in) 94 | 95 | self.dropout1 = nn.Dropout(dropout) 96 | self.dropout2 = nn.Dropout(dropout) 97 | if args.update_query_pos: 98 | self.dropout3 = nn.Dropout(dropout) 99 | self.dropout4 = nn.Dropout(dropout) 100 | 101 | self.activation = nn.ReLU(True) 102 | 103 | def _random_drop_tracks(self, track_instances: Instances) -> Instances: 104 | return random_drop_tracks(track_instances, self.random_drop) 105 | 106 | def _add_fp_tracks(self, track_instances: Instances, active_track_instances: Instances) -> Instances: 107 | inactive_instances = track_instances[track_instances.obj_idxes < 0] 108 | 109 | # add fp for each active track in a specific probability. 110 | fp_prob = torch.ones_like(active_track_instances.scores) * self.fp_ratio 111 | selected_active_track_instances = active_track_instances[torch.bernoulli(fp_prob).bool()] 112 | 113 | if len(inactive_instances) > 0 and len(selected_active_track_instances) > 0: 114 | num_fp = len(selected_active_track_instances) 115 | if num_fp >= len(inactive_instances): 116 | fp_track_instances = inactive_instances 117 | else: 118 | inactive_boxes = Boxes(box_ops.box_cxcywh_to_xyxy(inactive_instances.pred_boxes)) 119 | selected_active_boxes = Boxes(box_ops.box_cxcywh_to_xyxy(selected_active_track_instances.pred_boxes)) 120 | ious = pairwise_iou(inactive_boxes, selected_active_boxes) 121 | # select the fp with the largest IoU for each active track. 122 | fp_indexes = ious.max(dim=0).indices 123 | 124 | # remove duplicate fp. 125 | fp_indexes = torch.unique(fp_indexes) 126 | fp_track_instances = inactive_instances[fp_indexes] 127 | 128 | merged_track_instances = Instances.cat([active_track_instances, fp_track_instances]) 129 | return merged_track_instances 130 | 131 | return active_track_instances 132 | 133 | def _select_active_tracks(self, data: dict) -> Instances: 134 | track_instances: Instances = data['track_instances'] 135 | if self.training: 136 | active_idxes = (track_instances.obj_idxes >= 0) & (track_instances.iou > 0.5) 137 | active_track_instances = track_instances[active_idxes] 138 | # set -2 instead of -1 to ensure that these tracks will not be selected in matching. 139 | active_track_instances = self._random_drop_tracks(active_track_instances) 140 | if self.fp_ratio > 0: 141 | active_track_instances = self._add_fp_tracks(track_instances, active_track_instances) 142 | else: 143 | active_track_instances = track_instances[track_instances.obj_idxes >= 0] 144 | 145 | return active_track_instances 146 | 147 | def _update_track_embedding(self, track_instances: Instances) -> Instances: 148 | if len(track_instances) == 0: 149 | return track_instances 150 | dim = track_instances.query_pos.shape[1] 151 | out_embed = track_instances.output_embedding 152 | query_pos = track_instances.query_pos[:, :dim // 2] 153 | query_feat = track_instances.query_pos[:, dim//2:] 154 | q = k = query_pos + out_embed 155 | 156 | tgt = out_embed 157 | tgt2 = self.self_attn(q[:, None], k[:, None], value=tgt[:, None])[0][:, 0] 158 | tgt = tgt + self.dropout1(tgt2) 159 | tgt = self.norm1(tgt) 160 | 161 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 162 | tgt = tgt + self.dropout2(tgt2) 163 | tgt = self.norm2(tgt) 164 | 165 | if self.update_query_pos: 166 | query_pos2 = self.linear_pos2(self.dropout_pos1(self.activation(self.linear_pos1(tgt)))) 167 | query_pos = query_pos + self.dropout_pos2(query_pos2) 168 | query_pos = self.norm_pos(query_pos) 169 | track_instances.query_pos[:, :dim // 2] = query_pos 170 | 171 | query_feat2 = self.linear_feat2(self.dropout_feat1(self.activation(self.linear_feat1(tgt)))) 172 | query_feat = query_feat + self.dropout_feat2(query_feat2) 173 | query_feat = self.norm_feat(query_feat) 174 | track_instances.query_pos[:, dim//2:] = query_feat 175 | 176 | track_instances.ref_pts = inverse_sigmoid(track_instances.pred_boxes[:, :2].detach().clone()) 177 | return track_instances 178 | 179 | def forward(self, data) -> Instances: 180 | active_track_instances = self._select_active_tracks(data) 181 | active_track_instances = self._update_track_embedding(active_track_instances) 182 | init_track_instances: Instances = data['init_track_instances'] 183 | merged_track_instances = Instances.cat([init_track_instances, active_track_instances]) 184 | return merged_track_instances 185 | 186 | 187 | def build(args, layer_name, dim_in, hidden_dim, dim_out): 188 | interaction_layers = { 189 | 'QIM': QueryInteractionModule, 190 | } 191 | assert layer_name in interaction_layers, 'invalid query interaction layer: {}'.format(layer_name) 192 | return interaction_layers[layer_name](args, dim_in, hidden_dim, dim_out) 193 | -------------------------------------------------------------------------------- /models/relu_dropout.py: -------------------------------------------------------------------------------- 1 | # https://gist.github.com/vadimkantorov/360ece06de4fd2641fa9ed1085f76d48 2 | import torch 3 | 4 | class ReLUDropout(torch.nn.Dropout): 5 | def forward(self, input): 6 | return relu_dropout(input, p=self.p, training=self.training, inplace=self.inplace) 7 | 8 | def relu_dropout(x, p=0, inplace=False, training=False): 9 | if not training or p == 0: 10 | return x.clamp_(min=0) if inplace else x.clamp(min=0) 11 | 12 | mask = (x < 0) | (torch.rand_like(x) > 1 - p) 13 | return x.masked_fill_(mask, 0).div_(1 - p) if inplace else x.masked_fill(mask, 0).div(1 - p) 14 | -------------------------------------------------------------------------------- /models/structures/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Detectron2 (https://github.com/facebookresearch/detectron2) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | from .boxes import Boxes, BoxMode, pairwise_iou, pairwise_ioa, matched_boxlist_iou 6 | from .instances import Instances 7 | 8 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /models/structures/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/structures/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/structures/__pycache__/boxes.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/structures/__pycache__/boxes.cpython-36.pyc -------------------------------------------------------------------------------- /models/structures/__pycache__/instances.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/models/structures/__pycache__/instances.cpython-36.pyc -------------------------------------------------------------------------------- /models/structures/instances.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Detectron2 (https://github.com/facebookresearch/detectron2) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | 6 | import itertools 7 | from typing import Any, Dict, List, Tuple, Union 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class Instances: 13 | """ 14 | This class represents a list of instances in an image. 15 | It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields". 16 | All fields must have the same ``__len__`` which is the number of instances. 17 | 18 | All other (non-field) attributes of this class are considered private: 19 | they must start with '_' and are not modifiable by a user. 20 | 21 | Some basic usage: 22 | 23 | 1. Set/get/check a field: 24 | 25 | .. code-block:: python 26 | 27 | instances.gt_boxes = Boxes(...) 28 | print(instances.pred_masks) # a tensor of shape (N, H, W) 29 | print('gt_masks' in instances) 30 | 31 | 2. ``len(instances)`` returns the number of instances 32 | 3. Indexing: ``instances[indices]`` will apply the indexing on all the fields 33 | and returns a new :class:`Instances`. 34 | Typically, ``indices`` is a integer vector of indices, 35 | or a binary mask of length ``num_instances`` 36 | 37 | .. code-block:: python 38 | 39 | category_3_detections = instances[instances.pred_classes == 3] 40 | confident_detections = instances[instances.scores > 0.9] 41 | """ 42 | 43 | def __init__(self, image_size: Tuple[int, int], **kwargs: Any): 44 | """ 45 | Args: 46 | image_size (height, width): the spatial size of the image. 47 | kwargs: fields to add to this `Instances`. 48 | """ 49 | self._image_size = image_size 50 | self._fields: Dict[str, Any] = {} 51 | for k, v in kwargs.items(): 52 | self.set(k, v) 53 | 54 | @property 55 | def image_size(self) -> Tuple[int, int]: 56 | """ 57 | Returns: 58 | tuple: height, width 59 | """ 60 | return self._image_size 61 | 62 | def __setattr__(self, name: str, val: Any) -> None: 63 | if name.startswith("_"): 64 | super().__setattr__(name, val) 65 | else: 66 | self.set(name, val) 67 | 68 | def __getattr__(self, name: str) -> Any: 69 | if name == "_fields" or name not in self._fields: 70 | raise AttributeError("Cannot find field '{}' in the given Instances!".format(name)) 71 | return self._fields[name] 72 | 73 | def set(self, name: str, value: Any) -> None: 74 | """ 75 | Set the field named `name` to `value`. 76 | The length of `value` must be the number of instances, 77 | and must agree with other existing fields in this object. 78 | """ 79 | data_len = len(value) 80 | if len(self._fields): 81 | assert ( 82 | len(self) == data_len 83 | ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self)) 84 | self._fields[name] = value 85 | 86 | def has(self, name: str) -> bool: 87 | """ 88 | Returns: 89 | bool: whether the field called `name` exists. 90 | """ 91 | return name in self._fields 92 | 93 | def remove(self, name: str) -> None: 94 | """ 95 | Remove the field called `name`. 96 | """ 97 | del self._fields[name] 98 | 99 | def get(self, name: str) -> Any: 100 | """ 101 | Returns the field called `name`. 102 | """ 103 | return self._fields[name] 104 | 105 | def get_fields(self) -> Dict[str, Any]: 106 | """ 107 | Returns: 108 | dict: a dict which maps names (str) to data of the fields 109 | 110 | Modifying the returned dict will modify this instance. 111 | """ 112 | return self._fields 113 | 114 | # Tensor-like methods 115 | def to(self, *args: Any, **kwargs: Any) -> "Instances": 116 | """ 117 | Returns: 118 | Instances: all fields are called with a `to(device)`, if the field has this method. 119 | """ 120 | ret = Instances(self._image_size) 121 | for k, v in self._fields.items(): 122 | if hasattr(v, "to"): 123 | v = v.to(*args, **kwargs) 124 | ret.set(k, v) 125 | return ret 126 | 127 | def numpy(self): 128 | ret = Instances(self._image_size) 129 | for k, v in self._fields.items(): 130 | if hasattr(v, "numpy"): 131 | v = v.numpy() 132 | ret.set(k, v) 133 | return ret 134 | 135 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances": 136 | """ 137 | Args: 138 | item: an index-like object and will be used to index all the fields. 139 | 140 | Returns: 141 | If `item` is a string, return the data in the corresponding field. 142 | Otherwise, returns an `Instances` where all fields are indexed by `item`. 143 | """ 144 | if type(item) == int: 145 | if item >= len(self) or item < -len(self): 146 | raise IndexError("Instances index out of range!") 147 | else: 148 | item = slice(item, None, len(self)) 149 | 150 | ret = Instances(self._image_size) 151 | for k, v in self._fields.items(): 152 | ret.set(k, v[item]) 153 | return ret 154 | 155 | def __len__(self) -> int: 156 | for v in self._fields.values(): 157 | # use __len__ because len() has to be int and is not friendly to tracing 158 | return v.__len__() 159 | raise NotImplementedError("Empty Instances does not support __len__!") 160 | 161 | def __iter__(self): 162 | raise NotImplementedError("`Instances` object is not iterable!") 163 | 164 | @staticmethod 165 | def cat(instance_lists: List["Instances"]) -> "Instances": 166 | """ 167 | Args: 168 | instance_lists (list[Instances]) 169 | 170 | Returns: 171 | Instances 172 | """ 173 | assert all(isinstance(i, Instances) for i in instance_lists) 174 | assert len(instance_lists) > 0 175 | if len(instance_lists) == 1: 176 | return instance_lists[0] 177 | 178 | image_size = instance_lists[0].image_size 179 | for i in instance_lists[1:]: 180 | assert i.image_size == image_size 181 | ret = Instances(image_size) 182 | for k in instance_lists[0]._fields.keys(): 183 | values = [i.get(k) for i in instance_lists] 184 | v0 = values[0] 185 | if isinstance(v0, torch.Tensor): 186 | values = torch.cat(values, dim=0) 187 | elif isinstance(v0, list): 188 | values = list(itertools.chain(*values)) 189 | elif hasattr(type(v0), "cat"): 190 | values = type(v0).cat(values) 191 | else: 192 | raise ValueError("Unsupported type {} for concatenation".format(type(v0))) 193 | ret.set(k, values) 194 | return ret 195 | 196 | def __str__(self) -> str: 197 | s = self.__class__.__name__ + "(" 198 | s += "num_instances={}, ".format(len(self)) 199 | s += "image_height={}, ".format(self._image_size[0]) 200 | s += "image_width={}, ".format(self._image_size[1]) 201 | s += "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items()))) 202 | return s 203 | 204 | __repr__ = __str__ 205 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pycocotools 2 | tqdm 3 | cython 4 | scipy 5 | motmetrics 6 | opencv-python 7 | seaborn 8 | lap 9 | -------------------------------------------------------------------------------- /tools/launch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | r""" 13 | `torch.distributed.launch` is a module that spawns up multiple distributed 14 | training processes on each of the training nodes. 15 | The utility can be used for single-node distributed training, in which one or 16 | more processes per node will be spawned. The utility can be used for either 17 | CPU training or GPU training. If the utility is used for GPU training, 18 | each distributed process will be operating on a single GPU. This can achieve 19 | well-improved single-node training performance. It can also be used in 20 | multi-node distributed training, by spawning up multiple processes on each node 21 | for well-improved multi-node distributed training performance as well. 22 | This will especially be benefitial for systems with multiple Infiniband 23 | interfaces that have direct-GPU support, since all of them can be utilized for 24 | aggregated communication bandwidth. 25 | In both cases of single-node distributed training or multi-node distributed 26 | training, this utility will launch the given number of processes per node 27 | (``--nproc_per_node``). If used for GPU training, this number needs to be less 28 | or euqal to the number of GPUs on the current system (``nproc_per_node``), 29 | and each process will be operating on a single GPU from *GPU 0 to 30 | GPU (nproc_per_node - 1)*. 31 | **How to use this module:** 32 | 1. Single-Node multi-process distributed training 33 | :: 34 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 35 | YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other 36 | arguments of your training script) 37 | 2. Multi-Node multi-process distributed training: (e.g. two nodes) 38 | Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* 39 | :: 40 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 41 | --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" 42 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 43 | and all other arguments of your training script) 44 | Node 2: 45 | :: 46 | >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE 47 | --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" 48 | --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 49 | and all other arguments of your training script) 50 | 3. To look up what optional arguments this module offers: 51 | :: 52 | >>> python -m torch.distributed.launch --help 53 | **Important Notices:** 54 | 1. This utilty and multi-process distributed (single-node or 55 | multi-node) GPU training currently only achieves the best performance using 56 | the NCCL distributed backend. Thus NCCL backend is the recommended backend to 57 | use for GPU training. 58 | 2. In your training program, you must parse the command-line argument: 59 | ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module. 60 | If your training program uses GPUs, you should ensure that your code only 61 | runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: 62 | Parsing the local_rank argument 63 | :: 64 | >>> import argparse 65 | >>> parser = argparse.ArgumentParser() 66 | >>> parser.add_argument("--local_rank", type=int) 67 | >>> args = parser.parse_args() 68 | Set your device to local rank using either 69 | :: 70 | >>> torch.cuda.set_device(arg.local_rank) # before your code runs 71 | or 72 | :: 73 | >>> with torch.cuda.device(arg.local_rank): 74 | >>> # your code to run 75 | 3. In your training program, you are supposed to call the following function 76 | at the beginning to start the distributed backend. You need to make sure that 77 | the init_method uses ``env://``, which is the only supported ``init_method`` 78 | by this module. 79 | :: 80 | torch.distributed.init_process_group(backend='YOUR BACKEND', 81 | init_method='env://') 82 | 4. In your training program, you can either use regular distributed functions 83 | or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your 84 | training program uses GPUs for training and you would like to use 85 | :func:`torch.nn.parallel.DistributedDataParallel` module, 86 | here is how to configure it. 87 | :: 88 | model = torch.nn.parallel.DistributedDataParallel(model, 89 | device_ids=[arg.local_rank], 90 | output_device=arg.local_rank) 91 | Please ensure that ``device_ids`` argument is set to be the only GPU device id 92 | that your code will be operating on. This is generally the local rank of the 93 | process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, 94 | and ``output_device`` needs to be ``args.local_rank`` in order to use this 95 | utility 96 | 5. Another way to pass ``local_rank`` to the subprocesses via environment variable 97 | ``LOCAL_RANK``. This behavior is enabled when you launch the script with 98 | ``--use_env=True``. You must adjust the subprocess example above to replace 99 | ``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher 100 | will not pass ``--local_rank`` when you specify this flag. 101 | .. warning:: 102 | ``local_rank`` is NOT globally unique: it is only unique per process 103 | on a machine. Thus, don't use it to decide if you should, e.g., 104 | write to a networked filesystem. See 105 | https://github.com/pytorch/pytorch/issues/12042 for an example of 106 | how things can go wrong if you don't do this correctly. 107 | """ 108 | 109 | 110 | import sys 111 | import subprocess 112 | import os 113 | import socket 114 | from argparse import ArgumentParser, REMAINDER 115 | 116 | import torch 117 | 118 | 119 | def parse_args(): 120 | """ 121 | Helper function parsing the command line options 122 | @retval ArgumentParser 123 | """ 124 | parser = ArgumentParser(description="PyTorch distributed training launch " 125 | "helper utilty that will spawn up " 126 | "multiple distributed processes") 127 | 128 | # Optional arguments for the launch helper 129 | parser.add_argument("--nnodes", type=int, default=1, 130 | help="The number of nodes to use for distributed " 131 | "training") 132 | parser.add_argument("--node_rank", type=int, default=0, 133 | help="The rank of the node for multi-node distributed " 134 | "training") 135 | parser.add_argument("--nproc_per_node", type=int, default=1, 136 | help="The number of processes to launch on each node, " 137 | "for GPU training, this is recommended to be set " 138 | "to the number of GPUs in your system so that " 139 | "each process can be bound to a single GPU.") 140 | parser.add_argument("--master_addr", default="127.0.0.1", type=str, 141 | help="Master node (rank 0)'s address, should be either " 142 | "the IP address or the hostname of node 0, for " 143 | "single node multi-proc training, the " 144 | "--master_addr can simply be 127.0.0.1") 145 | parser.add_argument("--master_port", default=29500, type=int, 146 | help="Master node (rank 0)'s free port that needs to " 147 | "be used for communciation during distributed " 148 | "training") 149 | 150 | # positional 151 | parser.add_argument("training_script", type=str, 152 | help="The full path to the single GPU training " 153 | "program/script to be launched in parallel, " 154 | "followed by all the arguments for the " 155 | "training script") 156 | 157 | # rest from the training program 158 | parser.add_argument('training_script_args', nargs=REMAINDER) 159 | return parser.parse_args() 160 | 161 | 162 | def main(): 163 | args = parse_args() 164 | 165 | # world size in terms of number of processes 166 | dist_world_size = args.nproc_per_node * args.nnodes 167 | 168 | # set PyTorch distributed related environmental variables 169 | current_env = os.environ.copy() 170 | current_env["MASTER_ADDR"] = args.master_addr 171 | current_env["MASTER_PORT"] = str(args.master_port) 172 | current_env["WORLD_SIZE"] = str(dist_world_size) 173 | 174 | processes = [] 175 | 176 | for local_rank in range(0, args.nproc_per_node): 177 | # each process's rank 178 | dist_rank = args.nproc_per_node * args.node_rank + local_rank 179 | current_env["RANK"] = str(dist_rank) 180 | current_env["LOCAL_RANK"] = str(local_rank) 181 | 182 | cmd = [args.training_script] + args.training_script_args 183 | 184 | process = subprocess.Popen(cmd, env=current_env) 185 | processes.append(process) 186 | 187 | for process in processes: 188 | process.wait() 189 | if process.returncode != 0: 190 | raise subprocess.CalledProcessError(returncode=process.returncode, 191 | cmd=process.args) 192 | 193 | 194 | if __name__ == "__main__": 195 | main() -------------------------------------------------------------------------------- /tools/run_dist_launch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------ 3 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # Modified from DETR (https://github.com/facebookresearch/detr) 9 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 10 | # ------------------------------------------------------------------------ 11 | 12 | 13 | set -x 14 | 15 | GPUS=$1 16 | RUN_COMMAND=${@:2} 17 | if [ $GPUS -lt 8 ]; then 18 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} 19 | else 20 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 21 | fi 22 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 23 | MASTER_PORT=${MASTER_PORT:-"29500"} 24 | NODE_RANK=${NODE_RANK:-0} 25 | 26 | let "NNODES=GPUS/GPUS_PER_NODE" 27 | 28 | python3 ./tools/launch.py \ 29 | --nnodes ${NNODES} \ 30 | --node_rank ${NODE_RANK} \ 31 | --master_addr ${MASTER_ADDR} \ 32 | --master_port ${MASTER_PORT} \ 33 | --nproc_per_node ${GPUS_PER_NODE} \ 34 | ${RUN_COMMAND} -------------------------------------------------------------------------------- /tools/run_dist_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------ 3 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # Modified from DETR (https://github.com/facebookresearch/detr) 9 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 10 | # ------------------------------------------------------------------------ 11 | 12 | 13 | set -x 14 | 15 | PARTITION=$1 16 | JOB_NAME=$2 17 | GPUS=$3 18 | RUN_COMMAND=${@:4} 19 | if [ $GPUS -lt 8 ]; then 20 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} 21 | else 22 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 23 | fi 24 | CPUS_PER_TASK=${CPUS_PER_TASK:-4} 25 | SRUN_ARGS=${SRUN_ARGS:-""} 26 | 27 | srun -p ${PARTITION} \ 28 | --job-name=${JOB_NAME} \ 29 | --gres=gpu:${GPUS_PER_NODE} \ 30 | --ntasks=${GPUS} \ 31 | --ntasks-per-node=${GPUS_PER_NODE} \ 32 | --cpus-per-task=${CPUS_PER_TASK} \ 33 | --kill-on-bad-exit=1 \ 34 | ${SRUN_ARGS} \ 35 | ${RUN_COMMAND} 36 | 37 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/box_ops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/util/__pycache__/box_ops.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/util/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/util/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/motdet_eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/util/__pycache__/motdet_eval.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/plot_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/util/__pycache__/plot_utils.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/tool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/MOTR/8690da3392159635ca37c31975126acf40220724/util/__pycache__/tool.cpython-36.pyc -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | Utilities for bounding box manipulation and GIoU. 14 | """ 15 | import torch 16 | from torchvision.ops.boxes import box_area 17 | 18 | 19 | def box_cxcywh_to_xyxy(x): 20 | x_c, y_c, w, h = x.unbind(-1) 21 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 22 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 23 | return torch.stack(b, dim=-1) 24 | 25 | 26 | def box_xyxy_to_cxcywh(x): 27 | x0, y0, x1, y1 = x.unbind(-1) 28 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 29 | (x1 - x0), (y1 - y0)] 30 | return torch.stack(b, dim=-1) 31 | 32 | 33 | # modified from torchvision to also return the union 34 | def box_iou(boxes1, boxes2): 35 | area1 = box_area(boxes1) 36 | area2 = box_area(boxes2) 37 | 38 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 39 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 40 | 41 | wh = (rb - lt).clamp(min=0) # [N,M,2] 42 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 43 | 44 | union = area1[:, None] + area2 - inter 45 | 46 | iou = inter / union 47 | return iou, union 48 | 49 | 50 | def generalized_box_iou(boxes1, boxes2): 51 | """ 52 | Generalized IoU from https://giou.stanford.edu/ 53 | 54 | The boxes should be in [x0, y0, x1, y1] format 55 | 56 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 57 | and M = len(boxes2) 58 | """ 59 | # degenerate boxes gives inf / nan results 60 | # so do an early check 61 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 62 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 63 | iou, union = box_iou(boxes1, boxes2) 64 | 65 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 66 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 67 | 68 | wh = (rb - lt).clamp(min=0) # [N,M,2] 69 | area = wh[:, :, 0] * wh[:, :, 1] 70 | 71 | return iou - (area - union) / area 72 | 73 | 74 | def masks_to_boxes(masks): 75 | """Compute the bounding boxes around the provided masks 76 | 77 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 78 | 79 | Returns a [N, 4] tensors, with the boxes in xyxy format 80 | """ 81 | if masks.numel() == 0: 82 | return torch.zeros((0, 4), device=masks.device) 83 | 84 | h, w = masks.shape[-2:] 85 | 86 | y = torch.arange(0, h, dtype=torch.float) 87 | x = torch.arange(0, w, dtype=torch.float) 88 | y, x = torch.meshgrid(y, x) 89 | 90 | x_mask = (masks * x.unsqueeze(0)) 91 | x_max = x_mask.flatten(1).max(-1)[0] 92 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 93 | 94 | y_mask = (masks * y.unsqueeze(0)) 95 | y_max = y_mask.flatten(1).max(-1)[0] 96 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 97 | 98 | return torch.stack([x_min, y_min, x_max, y_max], 1) 99 | -------------------------------------------------------------------------------- /util/checkpoint.py: -------------------------------------------------------------------------------- 1 | # from: https://github.com/csrhddlam/pytorch-checkpoint 2 | 3 | import torch 4 | import warnings 5 | 6 | 7 | def detach_variable(inputs): 8 | if isinstance(inputs, tuple): 9 | out = [] 10 | for inp in inputs: 11 | x = inp.detach() 12 | x.requires_grad = inp.requires_grad 13 | out.append(x) 14 | return tuple(out) 15 | else: 16 | raise RuntimeError( 17 | "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) 18 | 19 | 20 | def check_backward_validity(inputs): 21 | if not any(inp.requires_grad for inp in inputs): 22 | warnings.warn("None of the inputs have requires_grad=True. Gradients will be None") 23 | 24 | 25 | class CheckpointFunction(torch.autograd.Function): 26 | @staticmethod 27 | def forward(ctx, run_function, length, *args): 28 | ctx.run_function = run_function 29 | ctx.input_tensors = list(args[:length]) 30 | ctx.input_params = list(args[length:]) 31 | with torch.no_grad(): 32 | output_tensors = ctx.run_function(*ctx.input_tensors) 33 | return output_tensors 34 | 35 | @staticmethod 36 | def backward(ctx, *output_grads): 37 | for i in range(len(ctx.input_tensors)): 38 | temp = ctx.input_tensors[i] 39 | ctx.input_tensors[i] = temp.detach() 40 | ctx.input_tensors[i].requires_grad = temp.requires_grad 41 | with torch.enable_grad(): 42 | output_tensors = ctx.run_function(*ctx.input_tensors) 43 | to_autograd = [] 44 | for i in range(len(ctx.input_tensors)): 45 | if ctx.input_tensors[i].requires_grad: 46 | to_autograd.append(ctx.input_tensors[i]) 47 | output_tensors, output_grads = zip(*filter(lambda t: t[0].requires_grad, zip(output_tensors, output_grads))) 48 | input_grads = torch.autograd.grad(output_tensors, to_autograd + ctx.input_params, output_grads, allow_unused=True) 49 | input_grads = list(input_grads) 50 | for i in range(len(ctx.input_tensors)): 51 | if not ctx.input_tensors[i].requires_grad: 52 | input_grads.insert(i, None) 53 | return (None, None) + tuple(input_grads) 54 | -------------------------------------------------------------------------------- /util/evaluation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | import os 13 | import numpy as np 14 | import copy 15 | import motmetrics as mm 16 | mm.lap.default_solver = 'lap' 17 | import os 18 | from typing import Dict 19 | import numpy as np 20 | import logging 21 | 22 | def read_results(filename, data_type: str, is_gt=False, is_ignore=False): 23 | if data_type in ('mot', 'lab'): 24 | read_fun = read_mot_results 25 | else: 26 | raise ValueError('Unknown data type: {}'.format(data_type)) 27 | 28 | return read_fun(filename, is_gt, is_ignore) 29 | 30 | # def read_mot_results(filename, is_gt, is_ignore): 31 | # results_dict = dict() 32 | # if os.path.isfile(filename): 33 | # with open(filename, 'r') as f: 34 | # for line in f.readlines(): 35 | # linelist = line.split(',') 36 | # if len(linelist) < 7: 37 | # continue 38 | # fid = int(linelist[0]) 39 | # if fid < 1: 40 | # continue 41 | # results_dict.setdefault(fid, list()) 42 | 43 | # if is_gt: 44 | # mark = int(float(linelist[6])) 45 | # if mark == 0 : 46 | # continue 47 | # score = 1 48 | # elif is_ignore: 49 | # score = 1 50 | # else: 51 | # score = float(linelist[6]) 52 | 53 | # tlwh = tuple(map(float, linelist[2:6])) 54 | # target_id = int(float(linelist[1])) 55 | # results_dict[fid].append((tlwh, target_id, score)) 56 | 57 | # return results_dict 58 | 59 | def read_mot_results(filename, is_gt, is_ignore): 60 | valid_labels = {1} 61 | ignore_labels = {0, 2, 7, 8, 12} 62 | results_dict = dict() 63 | if os.path.isfile(filename): 64 | with open(filename, 'r') as f: 65 | for line in f.readlines(): 66 | linelist = line.split(',') 67 | if len(linelist) < 7: 68 | continue 69 | fid = int(linelist[0]) 70 | if fid < 1: 71 | continue 72 | results_dict.setdefault(fid, list()) 73 | 74 | if is_gt: 75 | if 'MOT16-' in filename or 'MOT17-' in filename: 76 | label = int(float(linelist[7])) 77 | mark = int(float(linelist[6])) 78 | if mark == 0 or label not in valid_labels: 79 | continue 80 | score = 1 81 | elif is_ignore: 82 | if 'MOT16-' in filename or 'MOT17-' in filename: 83 | label = int(float(linelist[7])) 84 | vis_ratio = float(linelist[8]) 85 | if label not in ignore_labels and vis_ratio >= 0: 86 | continue 87 | elif 'MOT15' in filename: 88 | label = int(float(linelist[6])) 89 | if label not in ignore_labels: 90 | continue 91 | else: 92 | continue 93 | score = 1 94 | else: 95 | score = float(linelist[6]) 96 | 97 | tlwh = tuple(map(float, linelist[2:6])) 98 | target_id = int(linelist[1]) 99 | 100 | results_dict[fid].append((tlwh, target_id, score)) 101 | 102 | return results_dict 103 | 104 | def unzip_objs(objs): 105 | if len(objs) > 0: 106 | tlwhs, ids, scores = zip(*objs) 107 | else: 108 | tlwhs, ids, scores = [], [], [] 109 | tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4) 110 | return tlwhs, ids, scores 111 | 112 | 113 | class Evaluator(object): 114 | def __init__(self, data_root, seq_name, data_type='mot'): 115 | 116 | self.data_root = data_root 117 | self.seq_name = seq_name 118 | self.data_type = data_type 119 | 120 | self.load_annotations() 121 | self.reset_accumulator() 122 | 123 | def load_annotations(self): 124 | assert self.data_type == 'mot' 125 | 126 | gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt') 127 | self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True) 128 | self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True) 129 | 130 | def reset_accumulator(self): 131 | self.acc = mm.MOTAccumulator(auto_id=True) 132 | 133 | def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False): 134 | # results 135 | trk_tlwhs = np.copy(trk_tlwhs) 136 | trk_ids = np.copy(trk_ids) 137 | 138 | # gts 139 | gt_objs = self.gt_frame_dict.get(frame_id, []) 140 | gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2] 141 | 142 | # ignore boxes 143 | ignore_objs = self.gt_ignore_frame_dict.get(frame_id, []) 144 | ignore_tlwhs = unzip_objs(ignore_objs)[0] 145 | # remove ignored results 146 | keep = np.ones(len(trk_tlwhs), dtype=bool) 147 | iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5) 148 | if len(iou_distance) > 0: 149 | match_is, match_js = mm.lap.linear_sum_assignment(iou_distance) 150 | match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js]) 151 | match_ious = iou_distance[match_is, match_js] 152 | 153 | match_js = np.asarray(match_js, dtype=int) 154 | match_js = match_js[np.logical_not(np.isnan(match_ious))] 155 | keep[match_js] = False 156 | trk_tlwhs = trk_tlwhs[keep] 157 | trk_ids = trk_ids[keep] 158 | 159 | # get distance matrix 160 | iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5) 161 | 162 | # acc 163 | self.acc.update(gt_ids, trk_ids, iou_distance) 164 | 165 | if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'): 166 | events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics 167 | else: 168 | events = None 169 | return events 170 | 171 | def eval_file(self, filename): 172 | self.reset_accumulator() 173 | 174 | result_frame_dict = read_results(filename, self.data_type, is_gt=False) 175 | frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys()))) 176 | for frame_id in frames: 177 | trk_objs = result_frame_dict.get(frame_id, []) 178 | trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2] 179 | self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False) 180 | 181 | return self.acc 182 | 183 | @staticmethod 184 | def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')): 185 | names = copy.deepcopy(names) 186 | if metrics is None: 187 | metrics = mm.metrics.motchallenge_metrics 188 | metrics = copy.deepcopy(metrics) 189 | 190 | mh = mm.metrics.create() 191 | summary = mh.compute_many( 192 | accs, 193 | metrics=metrics, 194 | names=names, 195 | generate_overall=True 196 | ) 197 | 198 | return summary 199 | 200 | @staticmethod 201 | def save_summary(summary, filename): 202 | import pandas as pd 203 | writer = pd.ExcelWriter(filename) 204 | summary.to_excel(writer) 205 | writer.save() -------------------------------------------------------------------------------- /util/plot_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | 12 | """ 13 | Plotting utilities to visualize training logs. 14 | """ 15 | import cv2 16 | import torch 17 | import pandas as pd 18 | import numpy as np 19 | import seaborn as sns 20 | import matplotlib.pyplot as plt 21 | 22 | from torch import Tensor 23 | 24 | from pathlib import Path, PurePath 25 | 26 | 27 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 28 | ''' 29 | Function to plot specific fields from training log(s). Plots both training and test results. 30 | 31 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 32 | - fields = which results to plot from each log file - plots both training and test for each field. 33 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 34 | - log_name = optional, name of log file if different than default 'log.txt'. 35 | 36 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 37 | - solid lines are training results, dashed lines are test results. 38 | 39 | ''' 40 | func_name = "plot_utils.py::plot_logs" 41 | 42 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 43 | # convert single Path to list to avoid 'not iterable' error 44 | 45 | if not isinstance(logs, list): 46 | if isinstance(logs, PurePath): 47 | logs = [logs] 48 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 49 | else: 50 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 51 | Expect list[Path] or single Path obj, received {type(logs)}") 52 | 53 | # verify valid dir(s) and that every item in list is Path object 54 | for i, dir in enumerate(logs): 55 | if not isinstance(dir, PurePath): 56 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 57 | if dir.exists(): 58 | continue 59 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 60 | 61 | # load log file(s) and plot 62 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 63 | 64 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 65 | 66 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 67 | for j, field in enumerate(fields): 68 | if field == 'mAP': 69 | coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() 70 | axs[j].plot(coco_eval, c=color) 71 | else: 72 | df.interpolate().ewm(com=ewm_col).mean().plot( 73 | y=[f'train_{field}', f'test_{field}'], 74 | ax=axs[j], 75 | color=[color] * 2, 76 | style=['-', '--'] 77 | ) 78 | for ax, field in zip(axs, fields): 79 | ax.legend([Path(p).name for p in logs]) 80 | ax.set_title(field) 81 | 82 | 83 | def plot_precision_recall(files, naming_scheme='iter'): 84 | if naming_scheme == 'exp_id': 85 | # name becomes exp_id 86 | names = [f.parts[-3] for f in files] 87 | elif naming_scheme == 'iter': 88 | names = [f.stem for f in files] 89 | else: 90 | raise ValueError(f'not supported {naming_scheme}') 91 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 92 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 93 | data = torch.load(f) 94 | # precision is n_iou, n_points, n_cat, n_area, max_det 95 | precision = data['precision'] 96 | recall = data['params'].recThrs 97 | scores = data['scores'] 98 | # take precision for all classes, all areas and 100 detections 99 | precision = precision[0, :, :, 0, -1].mean(1) 100 | scores = scores[0, :, :, 0, -1].mean(1) 101 | prec = precision.mean() 102 | rec = data['recall'][0, :, 0, -1].mean() 103 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 104 | f'score={scores.mean():0.3f}, ' + 105 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 106 | ) 107 | axs[0].plot(recall, precision, c=color) 108 | axs[1].plot(recall, scores, c=color) 109 | 110 | axs[0].set_title('Precision / Recall') 111 | axs[0].legend(names) 112 | axs[1].set_title('Scores / Recall') 113 | axs[1].legend(names) 114 | return fig, axs 115 | 116 | 117 | def draw_boxes(image: Tensor, boxes: Tensor, color=(0, 255, 0), texts=None) -> np.ndarray: 118 | if isinstance(image, Tensor): 119 | cv_image = image.detach().cpu().numpy() 120 | else: 121 | cv_image = image 122 | if isinstance(boxes, Tensor): 123 | cv_boxes = boxes.detach().cpu().numpy() 124 | else: 125 | cv_boxes = boxes 126 | 127 | tl = round(0.002 * max(image.shape[0:2])) + 1 # line thickness 128 | tf = max(tl - 1, 1) 129 | for i in range(len(boxes)): 130 | box = cv_boxes[i] 131 | x1, y1 = box[0:2] 132 | x2, y2 = box[2:4] 133 | cv2.rectangle(cv_image, (int(x1), int(y1)), (int(x2), int(y2)), color=color) 134 | if texts is not None: 135 | cv2.putText(cv_image, texts[i], (int(x1), int(y1+10)), 0, tl/3, [225, 255, 255], 136 | thickness=tf, 137 | lineType=cv2.LINE_AA) 138 | return cv_image 139 | 140 | 141 | def draw_ref_pts(image: Tensor, ref_pts: Tensor) -> np.ndarray: 142 | if isinstance(image, Tensor): 143 | cv_image = image.detach().cpu().numpy() 144 | else: 145 | cv_image = image 146 | if isinstance(ref_pts, Tensor): 147 | cv_pts = ref_pts.detach().cpu().numpy() 148 | else: 149 | cv_pts = ref_pts 150 | for i in range(len(cv_pts)): 151 | x, y, is_pos = cv_pts[i] 152 | color = (0, 1, 0) if is_pos else (1, 1, 1) 153 | cv2.circle(cv_image, (int(x), int(y)), 2, color) 154 | return cv_image 155 | 156 | 157 | def image_hwc2chw(image: np.ndarray): 158 | image = np.ascontiguousarray(image.transpose(2, 0, 1)) 159 | return image 160 | -------------------------------------------------------------------------------- /util/tool.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 5 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 6 | # ------------------------------------------------------------------------ 7 | # Modified from DETR (https://github.com/facebookresearch/detr) 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | # ------------------------------------------------------------------------ 10 | 11 | import torch 12 | import numpy as np 13 | 14 | 15 | def load_model(model, model_path, optimizer=None, resume=False, 16 | lr=None, lr_step=None): 17 | start_epoch = 0 18 | checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) 19 | print(f'loaded {model_path}') 20 | state_dict = checkpoint['model'] 21 | model_state_dict = model.state_dict() 22 | 23 | # check loaded parameters and created model parameters 24 | msg = 'If you see this, your model does not fully load the ' + \ 25 | 'pre-trained weight. Please make sure ' + \ 26 | 'you set the correct --num_classes for your own dataset.' 27 | for k in state_dict: 28 | if k in model_state_dict: 29 | if state_dict[k].shape != model_state_dict[k].shape: 30 | print('Skip loading parameter {}, required shape{}, ' \ 31 | 'loaded shape{}. {}'.format( 32 | k, model_state_dict[k].shape, state_dict[k].shape, msg)) 33 | if 'class_embed' in k: 34 | print("load class_embed: {} shape={}".format(k, state_dict[k].shape)) 35 | if model_state_dict[k].shape[0] == 1: 36 | state_dict[k] = state_dict[k][1:2] 37 | elif model_state_dict[k].shape[0] == 2: 38 | state_dict[k] = state_dict[k][1:3] 39 | elif model_state_dict[k].shape[0] == 3: 40 | state_dict[k] = state_dict[k][1:4] 41 | else: 42 | raise NotImplementedError('invalid shape: {}'.format(model_state_dict[k].shape)) 43 | continue 44 | state_dict[k] = model_state_dict[k] 45 | else: 46 | print('Drop parameter {}.'.format(k) + msg) 47 | for k in model_state_dict: 48 | if not (k in state_dict): 49 | print('No param {}.'.format(k) + msg) 50 | state_dict[k] = model_state_dict[k] 51 | model.load_state_dict(state_dict, strict=False) 52 | 53 | # resume optimizer parameters 54 | if optimizer is not None and resume: 55 | if 'optimizer' in checkpoint: 56 | optimizer.load_state_dict(checkpoint['optimizer']) 57 | start_epoch = checkpoint['epoch'] 58 | start_lr = lr 59 | for step in lr_step: 60 | if start_epoch >= step: 61 | start_lr *= 0.1 62 | for param_group in optimizer.param_groups: 63 | param_group['lr'] = start_lr 64 | print('Resumed optimizer with start lr', start_lr) 65 | else: 66 | print('No optimizer parameters in checkpoint.') 67 | if optimizer is not None: 68 | return model, optimizer, start_epoch 69 | else: 70 | return model 71 | 72 | 73 | 74 | --------------------------------------------------------------------------------