├── .gitignore ├── LICENSE ├── README.md ├── Structure.png └── lib ├── __init__.py ├── configs ├── TADAM_MOT16.yaml ├── TADAM_MOT17.yaml ├── TADAM_MOT20.yaml ├── __init__.py └── config.py ├── dataset ├── __init__.py └── mot.py ├── modules ├── __init__.py ├── attention.py ├── detector.py ├── faster_rcnn.py ├── genenralized_rcnn_transform.py ├── identity.py ├── integration.py ├── memory.py └── roi_heads.py ├── tracking ├── __init__.py ├── detection.py ├── test_tracker.py ├── tracker.py └── tracklet.py ├── training ├── __init__.py ├── group_by_aspect_ratio.py ├── train.py └── train_utils.py └── utils ├── __init__.py ├── image_processing.py ├── kalman_filter.py ├── log.py ├── matching.py ├── model_loader.py ├── official_benchmark.py ├── timer.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Folders 2 | logs/ 3 | __pycache__/ 4 | output/* 5 | .vscode/ 6 | datasets -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Song Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Online Multiple Object Tracking with Cross-Task Synergy 2 | This repository is the implementation of the CVPR 2021 paper ["Online Multiple Object Tracking with Cross-Task Synergy"](https://arxiv.org/abs/2104.00380) 3 | ![Structure of TADAM](Structure.png) 4 | 5 | ## Installation 6 | Tested on python=3.8 with torch=1.8.1 and torchvision=0.9.1. 7 | 8 | It should also be compatible with python>=3.6, torch>=1.4.0 and torchvision>=0.4.0. 9 | Not tested on lower versions. 10 | 11 | ### 1. Clone the repository 12 | ``` 13 | git clone https://github.com/songguocode/TADAM.git 14 | ``` 15 | 16 | ### 2. Create conda env and activate 17 | ``` 18 | conda create -n TADAM python=3.8 19 | conda activate TADAM 20 | ``` 21 | 22 | ### 3. Install required packages 23 | ``` 24 | pip install torch torchvision scipy opencv-python yacs 25 | ``` 26 | All models are set to run on GPU, thus make sure graphics card driver is properly installed, as well as CUDA. 
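If unsure whether the driver and GPU are visible at all (assuming an NVIDIA card), a quick check is to run the following in a terminal; it should list the installed GPUs together with the driver and CUDA versions.
```
nvidia-smi
```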
27 | 28 | To check if `torch` is running with CUDA, run in python: 29 | ``` 30 | import torch 31 | torch.cuda.is_available() 32 | ``` 33 | It is working if `True` is returned. 34 | 35 | See the [PyTorch Official Site](https://pytorch.org/get-started/locally/) if `torch` is not installed or working properly. 36 | 37 | ### 4. Clone MOTChallenge benchmark evaluation code 38 | ``` 39 | git clone https://github.com/JonathonLuiten/TrackEval.git 40 | ``` 41 | 42 | By now there should be two folders, `TADAM` and `TrackEval`. 43 | 44 | Refer to [MOTChallenge-Official](https://github.com/JonathonLuiten/TrackEval/blob/master/docs/MOTChallenge-Official/Readme.md) for instructions. 45 | 46 | Download the provided `data.zip`, unzip it as a folder `data`, and copy it inside `TrackEval` as `TrackEval/data`. 47 | 48 | Move into the `TADAM` folder 49 | ``` 50 | cd TADAM 51 | ``` 52 | 53 | ### 5. Prepare MOTChallenge data 54 | Download [MOT16](https://motchallenge.net/data/MOT16.zip), [MOT17](https://motchallenge.net/data/MOT17.zip), [MOT17Det](https://motchallenge.net/data/MOT17Det.zip), and [MOT20](https://motchallenge.net/data/MOT20.zip) and place them inside a `datasets` folder. 55 | 56 | There are two options to provide the `datasets` location for training/testing: 57 | * a. Add a symbolic link inside the `TADAM` folder by `ln -s path_of_datasets datasets` 58 | * b. In `TADAM/configs/config.py`, assign `__C.PATHS.DATASET_ROOT` with `path_of_datasets` 59 | 60 | ### 6. Download Models 61 | The training base of TADAM is a detector pretrained on COCO. The base model `coco_checkpoint.pth` is provided in [Google Drive](https://drive.google.com/drive/folders/13vVgYkq6lulxYmhW2FQucptbDDzij92i?usp=sharing). 62 | 63 | Trained models are also provided for reference: 64 | * TADAM_MOT16.pth 65 | * TADAM_MOT17.pth 66 | * TADAM_MOT20.pth 67 | 68 | Create a folder `output/models` and place all models inside. 69 | 70 | ## Train 71 | 1. Training on a single GPU, with MOT17 as an example 72 | ``` 73 | python -m lib.training.train TADAM_MOT17 --config TADAM_MOT17 74 | ``` 75 | The first `TADAM_MOT17` specifies the output name of the trained model, which can be changed as preferred. 76 | 77 | The second `TADAM_MOT17` refers to the config file `lib/configs/TADAM_MOT17.yaml` that loads training parameters. Switch configs to train on the respective datasets. Config files are located in `lib/configs`. 78 | 79 | 2. Training on multiple GPUs with [Distributed Data Parallel](https://pytorch.org/docs/stable/notes/ddp.html) 80 | ``` 81 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=2 --use_env -m lib.training.train TADAM_MOT17 --config TADAM_MOT17 82 | ``` 83 | The argument `--nproc_per_node=2` specifies how many GPUs are used for training; here 2 cards are used. 84 | 85 | The trained model will be stored inside `output/models` with the specified output name. 86 | 87 | ## Evaluate 88 | ``` 89 | python -m lib.tracking.test_tracker --result-name xxx --config TADAM_MOT17 --evaluation 90 | ``` 91 | Change `xxx` to the preferred result name. 92 | `--evaluation` toggles on evaluation right after obtaining tracking results. Remove it to produce tracking results without evaluation. Evaluation requires results for all sequences of the specified dataset. 93 | 94 | Either run evaluation after training, or download and test the provided trained models. 95 | 96 | Note that if the output name of the trained model is changed, it must be specified in the corresponding `.yaml` config file, i.e. replace the value in `MODEL: TADAM_MOT17.pth` with the expected model file name.
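For example, if the model was trained with the output name `my_TADAM_MOT17`, the corresponding entry in `lib/configs/TADAM_MOT17.yaml` would be changed as sketched below (the model filename here is purely illustrative):
```
NAMES:
  DATASET: MOT17
  MODEL: my_TADAM_MOT17.pth
```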
97 | 98 | Add `--which_set test` flag to run on test dataset. Note that `--evaluation` should be removed, as no evaluation result is available for test dataset. 99 | 100 | Code from `TrackEval` is used for evaluation, and it is set to run on multiple cores (8 cores) by default. 101 | 102 | To run an evaluation after obtaining tracking results (with sequences result files), run: 103 | ``` 104 | python -m lib.utils.official_benchmark --result-name xxx --config TADAM_MOT17 105 | ``` 106 | Replace `xxx` with the result name, and choose config accordingly. 107 | 108 | Tracking results can be found in `output/results` under respective dataset name folders. 109 | Detailed result is stored in a `xxx_detailed.csv` file, while the summary is given in a `xxx_summary.txt` file. 110 | 111 | ### Results for reference 112 | The evaluation results on train sets are given here for reference. See paper for reported test sets results. 113 | * MOT16 114 | ``` 115 | MOTA MOTP MODA CLR_Re CLR_Pr MTR PTR MLR CLR_TP CLR_FN 116 | 63.7 91.6 63.9 64.5 99.0 35.6 40.8 23.6 71242 39165 117 | CLR_FP IDSW MT PT ML Frag sMOTA IDF1 IDR IDP 118 | 689 186 184 211 122 316 58.3 68.0 56.2 86.2 119 | IDTP IDFN IDFP Dets GT_Dets IDs GT_IDs 120 | 62013 48394 9918 71931 110407 446 517 121 | ``` 122 | * MOT17 123 | ``` 124 | MOTA MOTP MODA CLR_Re CLR_Pr MTR PTR MLR CLR_TP CLR_FN 125 | 68.0 91.3 68.2 69.0 98.8 43.5 37.5 19.0 232600 104291 126 | CLR_FP IDSW MT PT ML Frag sMOTA IDF1 IDR IDP 127 | 2845 742 712 615 311 1182 62.0 71.6 60.8 87.0 128 | IDTP IDFN IDFP Dets GT_Dets IDs GT_IDs 129 | 204819 132072 30626 235445 336891 1455 1638 130 | ``` 131 | * MOT20 132 | ``` 133 | MOTA MOTP MODA CLR_Re CLR_Pr MTR PTR MLR CLR_TP CLR_FN 134 | 80.2 87.0 80.4 82.2 97.9 64.0 28.8 7.18 932899 201715 135 | CLR_FP IDSW MT PT ML Frag sMOTA IDF1 IDR IDP 136 | 20355 2275 1418 638 159 2737 69.5 72.3 66.5 79.2 137 | IDTP IDFN IDFP Dets GT_Dets IDs GT_IDs 138 | 754621 379993 198633 953254 1134614 2953 2215 139 | ``` 140 | Results could differ slightly, and small variations should be acceptable. 141 | 142 | ## Visualization 143 | A visualization tool is provided to preview datasets' ground-truths, provided detections, and generated tracking results. 144 | ``` 145 | python -m lib.utils.visualization --config TADAM_MOT17 --which-set train --sequence 02 --public-detection FRCNN --result xxx --start-frame 1 --scale 0.8 146 | ``` 147 | Specify config files, train/test split, and sequence with `--config`, `--which-set`, `--sequence` respectively. `--public-detection` should only be specified for MOT17. 148 | 149 | Replace `--result xxx` with the tracking results 150 | `--start-frame 1` means viewing from frame 1, while `--scale 0.8` resizes viewing window with given ratio. 151 | 152 | Commands in visualization window: 153 | * "<": previous frame 154 | * ">": next frame 155 | * "t": toggle between viewing ground_truths, provided detections, and tracking results 156 | * "s": save current frame with all rendered elements 157 | * "h": hide frame information on window's top-left corner 158 | * "i": hide identity index on bounding boxes' top-left corner 159 | * "Esc" or "q": exit program 160 | 161 | ## Pretrain detector on COCO 162 | Basic detector is pretrained on COCO dataset, before training on MOT. A Faster-RCNN FPN with ResNet101 backbone is adopted in this code, which can be replaced by other similar detectors with code modifications. 
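For reference, below is a minimal sketch of constructing such a Faster R-CNN FPN detector with a ResNet101 backbone via torchvision, mirroring how `lib/modules/detector.py` builds its base model from `resnet_fpn_backbone`; it illustrates the architecture only and is not the repository's training script.
```python
# Minimal sketch (illustrative): Faster R-CNN with a ResNet101-FPN backbone,
# the same combination this repository builds on for COCO pretraining.
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FasterRCNN

# ImageNet-pretrained ResNet101 wrapped with a Feature Pyramid Network
backbone = resnet_fpn_backbone("resnet101", pretrained=True)
# Two classes (background + pedestrian), matching the Detector class in lib/modules/detector.py
model = FasterRCNN(backbone, num_classes=2)
```
A detector pretrained in this form serves as the base checkpoint (`coco_checkpoint.pth`, see `NAMES.CHECKPOINT` in `lib/configs/config.py`) that the `Detector` class loads before MOT training.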
163 | 164 | Refer to [Object detection reference training scripts](https://github.com/pytorch/vision/tree/master/references/detection) on how to train a PyTorch-based detector. 165 | 166 | See [Tracking without bells and whistles](https://github.com/phil-bergmann/tracking_wo_bnw) for a jupyter notebook hands-on, which is also based on the aforementioned reference codes. 167 | 168 | ## Publication 169 | If you use the code in your research, please cite: 170 | ``` 171 | @InProceedings{TADAM_2021_CVPR, 172 | author = {Guo, Song and Wang, Jingya and Wang, Xinchao and Tao, Dacheng}, 173 | title = {Online Multiple Object Tracking With Cross-Task Synergy}, 174 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 175 | month = {June}, 176 | year = {2021}, 177 | } 178 | ``` 179 | -------------------------------------------------------------------------------- /Structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/Structure.png -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/lib/__init__.py -------------------------------------------------------------------------------- /lib/configs/TADAM_MOT16.yaml: -------------------------------------------------------------------------------- 1 | NAMES: 2 | DATASET: MOT16 3 | MODEL: TADAM_MOT16.pth -------------------------------------------------------------------------------- /lib/configs/TADAM_MOT17.yaml: -------------------------------------------------------------------------------- 1 | NAMES: 2 | DATASET: MOT17 3 | MODEL: TADAM_MOT17.pth -------------------------------------------------------------------------------- /lib/configs/TADAM_MOT20.yaml: -------------------------------------------------------------------------------- 1 | NAMES: 2 | DATASET: MOT20 3 | MODEL: TADAM_MOT20.pth 4 | TRAINING: 5 | BATCH_SIZE: 1 -------------------------------------------------------------------------------- /lib/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/lib/configs/__init__.py -------------------------------------------------------------------------------- /lib/configs/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from yacs.config import CfgNode as CN 3 | from ..utils.log import log_or_print 4 | 5 | __C = CN() 6 | 7 | base_config = __C 8 | 9 | __C.NAMES = CN() 10 | __C.NAMES.BACKBONE = "resnet101" 11 | __C.NAMES.DATASET = "MOT17" 12 | __C.NAMES.CHECKPOINT = "coco_checkpoint.pth" 13 | __C.NAMES.MODEL = "TADAM_MOT17.pth" 14 | 15 | __C.PATHS = CN() 16 | __C.PATHS.DATASET_ROOT = "datasets" 17 | __C.PATHS.MODEL_ROOT = "output/models" 18 | __C.PATHS.RESULT_ROOT = "output/results" 19 | __C.PATHS.EVAL_ROOT = "../TrackEval" 20 | 21 | __C.TRAINING = CN() 22 | __C.TRAINING = CN() 23 | __C.TRAINING.BATCH_SIZE = 2 24 | __C.TRAINING.EPOCHS = 12 25 | __C.TRAINING.ID_LOSS_RATIO = 0.1 26 | __C.TRAINING.LR = 0.002 27 | __C.TRAINING.LR_GAMMA = 0.5 28 | __C.TRAINING.LR_STEP_SIZE = 3 29 | __C.TRAINING.PRINT_FREQ = 200 30 | __C.TRAINING.SAVE_FREQ = 3 31 | 
__C.TRAINING.RANDOM_SEED = 123456 32 | __C.TRAINING.MOMENTUM = 0.9 33 | __C.TRAINING.WEIGHT_DECAY = 0.0005 34 | __C.TRAINING.VIS_THRESHOLD = 0.1 35 | __C.TRAINING.WARMUP_LR = 0.02 36 | __C.TRAINING.WARMUP_EPOCHS = 3 37 | __C.TRAINING.WORKERS = 2 38 | 39 | __C.TRACKING = CN() 40 | __C.TRACKING.MIN_BOX_SIZE = 3 41 | __C.TRACKING.MIN_SCORE_ACTIVE_TRACKLET = 0.5 42 | __C.TRACKING.MIN_SCORE_DETECTION = 0.05 43 | __C.TRACKING.NMS_ACTIVE_TRACKLET = 0.7 44 | __C.TRACKING.NMS_DETECTION = 0.3 45 | __C.TRACKING.MIN_OVERLAP_AS_DISTRACTOR = 0.2 46 | __C.TRACKING.MIN_RECOVER_GIOU = -0.4 47 | __C.TRACKING.MIN_RECOVER_SCORE = 0.5 48 | __C.TRACKING.MAX_LOST_FRAMES_BEFORE_REMOVE = 100 49 | 50 | 51 | def load_config(config_file=None): 52 | """ 53 | Load configurations 54 | Add or overwrite config from yaml file if specified 55 | """ 56 | config = base_config 57 | if config_file is not None: 58 | config_file_path = os.path.join("lib", "configs", f"{config_file}.yaml") 59 | if os.path.isfile(config_file_path): 60 | config.merge_from_file(config_file_path) 61 | msg = f"Merged config from '{config_file_path}'" 62 | else: 63 | print(f"Cannot open the specified yaml config file '{config_file_path}'", level="critical") 64 | exit(0) 65 | else: 66 | msg = f"No yaml config file is specified. Using default config." 67 | return config, msg 68 | -------------------------------------------------------------------------------- /lib/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/lib/dataset/__init__.py -------------------------------------------------------------------------------- /lib/dataset/mot.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import csv 3 | import os 4 | import itertools 5 | 6 | from PIL import Image 7 | import torch 8 | from torchvision.transforms.functional import to_tensor 9 | 10 | 11 | class MOTDetection(torch.utils.data.Dataset): 12 | """ 13 | Data class for detection 14 | Loads all images in all sequences at once 15 | To be used for training 16 | """ 17 | 18 | def __init__( 19 | self, 20 | root="../datasets/", 21 | dataset="MOT17Det", 22 | transforms=None, 23 | vis_threshold=0.1, 24 | ): 25 | predefined_datasets = ["MOT16", "MOT17Det", "MOT20"] 26 | assert dataset in predefined_datasets, \ 27 | f"Provided dataset name '{dataset}' is not in predefined datasets: {predefined_datasets}" 28 | 29 | self.root = os.path.join(root, dataset, "train") 30 | self.transforms = transforms 31 | self._vis_threshold = vis_threshold 32 | self._classes = ("__background__", "pedestrian") 33 | self._global_id_counter = 0 34 | self._local_to_global_dict = {} 35 | self._global_to_local_dict = {} 36 | self._img_paths = [] 37 | self._aspect_ratios = [] 38 | 39 | for f in sorted(os.listdir(self.root)): 40 | path = os.path.join(self.root, f) 41 | config_file = os.path.join(path, "seqinfo.ini") 42 | 43 | assert os.path.exists(config_file), f"Path does not exist: {config_file}" 44 | 45 | config = configparser.ConfigParser() 46 | config.read(config_file) 47 | seq_len = int(config["Sequence"]["seqLength"]) 48 | im_width = int(config["Sequence"]["imWidth"]) 49 | im_height = int(config["Sequence"]["imHeight"]) 50 | im_ext = config["Sequence"]["imExt"] 51 | im_dir = config["Sequence"]["imDir"] 52 | 53 | _imDir = os.path.join(path, im_dir) 54 | aspect_ratio = im_width / im_height 55 | 56 | # Collect global gt_id 57 | 
self.process_ids(path) 58 | 59 | for i in range(1, seq_len + 1): 60 | img_path = os.path.join(_imDir, f"{i:06d}{im_ext}") 61 | assert os.path.exists(img_path), \ 62 | "Path does not exist: {img_path}" 63 | self._img_paths.append(img_path) 64 | self._aspect_ratios.append(aspect_ratio) 65 | 66 | @property 67 | def num_classes(self): 68 | return len(self._classes) 69 | 70 | @property 71 | def num_ids(self): 72 | return self._global_id_counter 73 | 74 | def _get_annotation(self, idx): 75 | """ 76 | Obtain annotation from gt file 77 | """ 78 | 79 | img_path = self._img_paths[idx] 80 | file_index = int(os.path.basename(img_path).split(".")[0]) 81 | 82 | gt_file = os.path.join(os.path.dirname( 83 | os.path.dirname(img_path)), "gt", "gt.txt") 84 | seq_name = os.path.basename(os.path.dirname(os.path.dirname(img_path))) 85 | 86 | assert os.path.exists(gt_file), f"GT file does not exist: {gt_file}" 87 | 88 | bounding_boxes = [] 89 | 90 | with open(gt_file, "r") as inf: 91 | reader = csv.reader(inf, delimiter=",") 92 | for row in reader: 93 | visibility = float(row[8]) 94 | local_id = f"{seq_name}-{int(row[1])}" 95 | if int(row[0]) == file_index and int(row[6]) == 1 and int(row[7]) == 1 and \ 96 | visibility > self._vis_threshold: 97 | bb = {} 98 | bb["gt_id"] = self._local_to_global_dict[local_id] 99 | bb["bb_left"] = int(row[2]) 100 | bb["bb_top"] = int(row[3]) 101 | bb["bb_width"] = int(row[4]) 102 | bb["bb_height"] = int(row[5]) 103 | bb["visibility"] = visibility 104 | 105 | bounding_boxes.append(bb) 106 | 107 | num_objs = len(bounding_boxes) 108 | 109 | boxes = torch.zeros((num_objs, 4), dtype=torch.float32) 110 | visibilities = torch.zeros((num_objs), dtype=torch.float32) 111 | gt_ids = torch.zeros((num_objs), dtype=torch.int64) 112 | 113 | for i, bb in enumerate(bounding_boxes): 114 | x1 = bb["bb_left"] # GS 115 | y1 = bb["bb_top"] 116 | x2 = x1 + bb["bb_width"] 117 | y2 = y1 + bb["bb_height"] 118 | boxes[i, 0] = x1 119 | boxes[i, 1] = y1 120 | boxes[i, 2] = x2 121 | boxes[i, 3] = y2 122 | visibilities[i] = bb["visibility"] 123 | gt_ids[i] = bb["gt_id"] 124 | 125 | return {"boxes": boxes, 126 | "labels": torch.ones((num_objs,), dtype=torch.int64), 127 | "image_id": torch.tensor([idx]), 128 | "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), 129 | "iscrowd": torch.zeros((num_objs,), dtype=torch.int64), 130 | "visibilities": visibilities, 131 | "frame_id": torch.tensor([file_index]), 132 | "gt_ids": gt_ids} 133 | 134 | def process_ids(self, path): 135 | """ 136 | Global id is 0-based, indexed across all sequences 137 | All ids are considered, regardless of used or not 138 | """ 139 | seq_name = os.path.basename(path) 140 | if seq_name not in self._global_to_local_dict.keys(): 141 | self._global_to_local_dict[seq_name] = {} 142 | gt_file = os.path.join(path, "gt", "gt.txt") 143 | with open(gt_file, "r") as inf: 144 | reader = csv.reader(inf, delimiter=",") 145 | for row in reader: 146 | local_id = f"{seq_name}-{int(row[1])}" 147 | if int(row[6]) == 1 and int(row[7]) == 1: 148 | if local_id not in self._local_to_global_dict.keys(): 149 | self._local_to_global_dict[local_id] = self._global_id_counter 150 | self._global_to_local_dict[seq_name][self._global_id_counter] = int(row[1]) 151 | self._global_id_counter += 1 152 | 153 | def __getitem__(self, idx): 154 | # Load image 155 | img_path = self._img_paths[idx] 156 | img = Image.open(img_path).convert("RGB") 157 | # Get annotation 158 | target = self._get_annotation(idx) 159 | # Apply augmentation transforms 160 | if self.transforms 
is not None: 161 | img, target = self.transforms(img, target) 162 | 163 | return img, target 164 | 165 | def __len__(self): 166 | return len(self._img_paths) 167 | 168 | 169 | class MOTTracking(torch.utils.data.Dataset): 170 | """ 171 | Data class for tracking 172 | Loads one sequence at a time 173 | To be used for tracking 174 | """ 175 | 176 | def __init__( 177 | self, 178 | root="../datasets/", 179 | dataset="MOT17", 180 | which_set="train", 181 | sequence="02", 182 | public_detection="None", 183 | vis_threshold=0.1, 184 | ): 185 | # Check dataset 186 | predefined_datasets = ["MOT16", "MOT17", "MOT20"] 187 | assert dataset in predefined_datasets, \ 188 | f"Provided dataset name '{dataset}' is not in predefined datasets: {predefined_datasets}" 189 | # Different public detections for MOT17 190 | if dataset == "MOT17": 191 | assert public_detection in ["DPM", "FRCNN", "SDP"], "Incorrect public detection provided" 192 | public_detection = f"-{public_detection}" 193 | # No public detection names for MOT16 and MOT20 194 | else: 195 | assert public_detection == "None", f"No public detection should be provided for {dataset}" 196 | public_detection = "" 197 | # Check train/test 198 | assert which_set in ["train", "test"], "Invalid choice between 'train' and 'test'" 199 | # Check sequence, convert to two-digits string format 200 | assert sequence.isdigit(), "Non-digit sequence provided" 201 | sequence = f"{int(sequence):02d}" 202 | dict_sequences = { 203 | "MOT16": { 204 | "train": ["02", "04", "05", "09", "10", "11", "13"], 205 | "test": ["01", "03", "06", "07", "08", "12", "14"], 206 | }, 207 | "MOT17": { 208 | "train": ["02", "04", "05", "09", "10", "11", "13"], 209 | "test": ["01", "03", "06", "07", "08", "12", "14"], 210 | }, 211 | "MOT20": { 212 | "train": ["01", "02", "03", "05"], 213 | "test": ["04", "06", "07", "08"], 214 | } 215 | } 216 | assert sequence in dict_sequences[dataset][which_set], \ 217 | f"Sequence for {dataset}/{which_set} must be in [{dict_sequences[dataset][which_set]}]" 218 | 219 | self._img_paths = [] 220 | self._vis_threshold = vis_threshold 221 | 222 | # Load images 223 | self.path = os.path.join(root, dataset, which_set, f"{dataset}-{sequence}{public_detection}") 224 | config_file = os.path.join(self.path, "seqinfo.ini") 225 | 226 | assert os.path.exists(config_file), f"Path does not exist: {config_file}" 227 | 228 | config = configparser.ConfigParser() 229 | config.read(config_file) 230 | seq_len = int(config["Sequence"]["seqLength"]) 231 | im_ext = config["Sequence"]["imExt"] 232 | im_dir = config["Sequence"]["imDir"] 233 | 234 | _imDir = os.path.join(self.path, im_dir) 235 | 236 | for i in range(1, seq_len + 1): 237 | img_path = os.path.join(_imDir, f"{i:06d}{im_ext}") 238 | assert os.path.exists(img_path), \ 239 | "Path does not exist: {img_path}" 240 | self._img_paths.append(img_path) 241 | 242 | def _get_annotation(self, idx): 243 | """ 244 | Obtain annotation for detections (train/test) and ground truths (train only) 245 | """ 246 | img_path = self._img_paths[idx] 247 | file_index = int(os.path.basename(img_path).split(".")[0]) 248 | 249 | det_file = os.path.join(os.path.dirname( 250 | os.path.dirname(img_path)), "det", "det.txt") 251 | assert os.path.exists(det_file), \ 252 | f"Det file does not exist: {det_file}" 253 | det_boxes, _, det_scores, _ = read_mot_file(det_file, file_index, self._vis_threshold, is_gt=False) 254 | 255 | # No GT for test set 256 | if "test" in self.path: 257 | return det_boxes, None, None, None, None 258 | 259 | gt_file = 
os.path.join(os.path.dirname( 260 | os.path.dirname(img_path)), "gt", "gt.txt") 261 | assert os.path.exists(gt_file), \ 262 | f"GT file does not exist: {gt_file}" 263 | gt_boxes, gt_ids, _, gt_visibilities = read_mot_file(gt_file, file_index, self._vis_threshold, is_gt=True) 264 | 265 | return det_boxes, det_scores, gt_boxes, gt_ids, gt_visibilities 266 | 267 | def __getitem__(self, idx): 268 | # Load image 269 | img_path = self._img_paths[idx] 270 | img = Image.open(img_path).convert("RGB") 271 | img = to_tensor(img) 272 | # Get annotation 273 | det_boxes, det_scores, gt_boxes, gt_ids, gt_visibilities = self._get_annotation(idx) 274 | 275 | return img, det_boxes, det_scores, gt_boxes, gt_ids, gt_visibilities 276 | 277 | def __len__(self): 278 | return len(self._img_paths) 279 | 280 | 281 | def read_mot_file(file, file_index, vis_threshold=0.1, is_gt=False): 282 | """ 283 | Read data from mot files, gt or det or tracking result 284 | """ 285 | bounding_boxes = [] 286 | with open(file, "r") as inf: 287 | reader = csv.reader(inf, delimiter=",") 288 | for row in reader: 289 | visibility = float(row[8]) if is_gt else -1.0 290 | if int(row[0]) == file_index and \ 291 | ((is_gt and (int(row[6]) == 1 and int(row[7]) == 1 and visibility > vis_threshold)) or 292 | not is_gt): # Only requires class=pedestrian and confidence=1 for gt 293 | bb = {} 294 | bb["gt_id"] = int(row[1]) 295 | bb["bb_left"] = float(row[2]) 296 | bb["bb_top"] = float(row[3]) 297 | bb["bb_width"] = float(row[4]) 298 | bb["bb_height"] = float(row[5]) 299 | bb["bb_score"] = float(row[6]) if not is_gt else 1 300 | bb["visibility"] = visibility 301 | bounding_boxes.append(bb) 302 | 303 | num_objs = len(bounding_boxes) 304 | boxes = torch.zeros((num_objs, 4), dtype=torch.float32) 305 | scores = torch.zeros((num_objs), dtype=torch.float32) 306 | visibilities = torch.zeros((num_objs), dtype=torch.float32) 307 | ids = torch.zeros((num_objs), dtype=torch.int64) 308 | for i, bb in enumerate(bounding_boxes): 309 | x1 = bb["bb_left"] # GS 310 | y1 = bb["bb_top"] 311 | x2 = x1 + bb["bb_width"] 312 | y2 = y1 + bb["bb_height"] 313 | boxes[i, 0] = x1 314 | boxes[i, 1] = y1 315 | boxes[i, 2] = x2 316 | boxes[i, 3] = y2 317 | scores[i] = bb["bb_score"] 318 | visibilities[i] = bb["visibility"] 319 | ids[i] = bb["gt_id"] 320 | 321 | return boxes, ids, scores, visibilities 322 | 323 | 324 | def collate_fn(batch): 325 | """ 326 | Function for dataloader 327 | """ 328 | return tuple(zip(*batch)) 329 | 330 | 331 | def get_seq_names(dataset, which_set, public_detection, sequence): 332 | """ 333 | Get name of all required sequences 334 | """ 335 | # Process inputs 336 | if public_detection == "all": 337 | if dataset == "MOT17": 338 | public_detection_list = ["DPM", "FRCNN", "SDP"] 339 | else: 340 | public_detection_list = ["None"] 341 | else: 342 | public_detection_list = [public_detection] 343 | 344 | if sequence == "all": 345 | if dataset == "MOT20": 346 | if which_set == "train": 347 | sequence_list = ["01", "02", "03", "05"] 348 | else: 349 | sequence_list = ["04", "06", "07", "08"] 350 | else: 351 | if which_set == "train": 352 | sequence_list = ["02", "04", "05", "09", "10", "11", "13"] 353 | else: 354 | sequence_list = ["01", "03", "06", "07", "08", "12", "14"] 355 | else: 356 | sequence_list = [sequence] 357 | # Iterate through all sequences 358 | full_names = [] 359 | seqs = [] 360 | pds = [] # public detections for each sequence 361 | for pd, seq in list(itertools.product(public_detection_list, sequence_list)): 362 | seqs.append(seq) 363 | 
pd_suffix = f"-{pd}" if dataset == "MOT17" else "" 364 | pds.append(pd) 365 | curr_seq = f"{dataset}-{seq}{pd_suffix}" 366 | full_names.append(curr_seq) 367 | return full_names, seqs, pds 368 | -------------------------------------------------------------------------------- /lib/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/lib/modules/__init__.py -------------------------------------------------------------------------------- /lib/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class NonLocalAttention(nn.Module): 6 | """ 7 | Attention from all positions in object B to all positions in object A 8 | """ 9 | def __init__(self, in_channels=256, inter_channels=None, bn_layer=True): 10 | super(NonLocalAttention, self).__init__() 11 | 12 | self.in_channels = in_channels 13 | self.inter_channels = inter_channels 14 | 15 | # Set default inter_channels 16 | if self.inter_channels is None: 17 | self.inter_channels = in_channels // 2 18 | if self.inter_channels == 0: 19 | self.inter_channels = 1 20 | 21 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 22 | kernel_size=1, stride=1, padding=0) 23 | 24 | if bn_layer: 25 | self.W = nn.Sequential( 26 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 27 | bias=False, kernel_size=1, stride=1, padding=0), 28 | nn.BatchNorm2d(self.in_channels) 29 | ) 30 | else: 31 | self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 32 | kernel_size=1, stride=1, padding=0) 33 | 34 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 35 | kernel_size=1, stride=1, padding=0) 36 | 37 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 38 | kernel_size=1, stride=1, padding=0) 39 | 40 | # Initialize weights 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | nn.init.kaiming_normal_(m.weight.data, mode="fan_in") 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | 47 | def forward(self, a, b): 48 | B, C, H, W = a.size() 49 | 50 | # Pairwise relationship 51 | theta_a = self.theta(a).reshape(B, self.inter_channels, -1).permute(0, 2, 1) 52 | phi_b = self.phi(b).reshape(B, self.inter_channels, -1) 53 | # Correlation of size (B, H * W, H * W) 54 | f = torch.matmul(theta_a, phi_b) 55 | f_div_C = f / f.size(-1) 56 | 57 | # Get representation of b 58 | g_b = self.g(b).view(B, self.inter_channels, -1) 59 | g_b = g_b.permute(0, 2, 1) 60 | 61 | # Combine 62 | y = torch.matmul(f_div_C, g_b) 63 | y = y.permute(0, 2, 1).contiguous() 64 | y = y.view(B, self.inter_channels, *a.size()[2:]) 65 | W_y = self.W(y) 66 | 67 | return W_y 68 | -------------------------------------------------------------------------------- /lib/modules/detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.jit.annotations import Tuple, List, Dict, Optional 6 | from torch.autograd.function import Function 7 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 8 | from torchvision.models.detection.transform import resize_boxes 9 | from torchvision.ops import MultiScaleRoIAlign, box_iou 10 | # Local files 
11 | from .faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor 12 | from .roi_heads import fastrcnn_loss 13 | from .identity import IDModule 14 | from .memory import MemoryNet 15 | from .attention import NonLocalAttention 16 | from .integration import IntegrationModule 17 | from ..utils.model_loader import load_model 18 | from ..tracking.tracklet import Tracklet 19 | 20 | 21 | class Detector(FasterRCNN): 22 | def __init__( 23 | self, 24 | config, 25 | num_classes=2, 26 | num_ids=1000, 27 | tracking=False, 28 | logger=None, 29 | ): 30 | super(Detector, self).__init__( 31 | resnet_fpn_backbone(config.NAMES.BACKBONE, False), 32 | num_classes=num_classes 33 | ) 34 | assert config is not None, "Config not passed" 35 | self.config = config 36 | 37 | # Load components 38 | self.memory_net = MemoryNet( 39 | feature_size=(256, 7, 7), 40 | num_ids=num_ids, 41 | kernel_size=(3, 3), 42 | bias=True 43 | ) 44 | self.roi_heads.id_roi_pool = MultiScaleRoIAlign( 45 | featmap_names=["0", "1", "2", "3"], 46 | output_size=7, 47 | sampling_ratio=2 48 | ) 49 | self.roi_heads.id_module = IDModule( 50 | in_channels=256, 51 | out_channels=256, 52 | num_ids=num_ids 53 | ) 54 | self.roi_heads.target_enhance = NonLocalAttention( 55 | in_channels=256, 56 | inter_channels=128 57 | ) 58 | self.roi_heads.distractor_reduce = NonLocalAttention( 59 | in_channels=256, 60 | inter_channels=128 61 | ) 62 | self.roi_heads.integration = IntegrationModule(min_iou=0.2) 63 | self.roi_heads.hard_box_head = TwoMLPHead( 64 | self.roi_heads.box_head.fc6.in_features, 65 | self.roi_heads.box_head.fc6.out_features 66 | ) 67 | self.roi_heads.hard_box_predictor = FastRCNNPredictor( 68 | self.roi_heads.box_predictor.cls_score.in_features, 69 | num_classes=num_classes 70 | ) 71 | 72 | # Load trained model for tracking 73 | if tracking: 74 | model_path = os.path.join( 75 | self.config.PATHS.MODEL_ROOT, 76 | self.config.NAMES.MODEL 77 | ) 78 | self = load_model(self, model_path, logger) 79 | # Freeze model to speed up inferencing 80 | for param in self.parameters(): 81 | param.requires_grad = False 82 | # Load checkpoint for training 83 | else: 84 | checkpoint_path = os.path.join( 85 | self.config.PATHS.MODEL_ROOT, 86 | self.config.NAMES.CHECKPOINT 87 | ) 88 | self = load_model(self, checkpoint_path, logger) 89 | 90 | def predict_boxes( 91 | self, 92 | frame, 93 | boxes, 94 | prediction_type="detection", 95 | distractor_ious=None, 96 | box_ids=None, 97 | target_bools=None, 98 | target_embeddings=None, 99 | distractor_bools=None, 100 | distractor_embeddings=None 101 | ): 102 | """ 103 | Make predictions from given bounding boxes 104 | Either from public detections (basic), or from tracked targets (with TADA) 105 | """ 106 | device = list(self.parameters())[0].device 107 | images = frame.unsqueeze(0) 108 | images = images.to(device) 109 | boxes.to(device) 110 | 111 | targets = None 112 | original_image_sizes = [images.shape[-2:]] 113 | # Image and box sizes are changed inside the RCNN transform 114 | images, targets = self.transform(images, targets) 115 | 116 | backbone_features = self.backbone(images.tensors) 117 | if isinstance(backbone_features, torch.Tensor): 118 | backbone_features = OrderedDict([(0, backbone_features)]) 119 | 120 | # Resize to transformed size 121 | proposals = [resize_boxes(boxes, original_image_sizes[0], images.image_sizes[0])] 122 | 123 | # Get box features by pooling 124 | box_features = self.roi_heads.box_roi_pool(backbone_features, proposals, images.image_sizes) 125 | 126 | # Basic prediction for given public 
detections 127 | assert prediction_type in ["detection", "tracklet"], "Invalid prediction type" 128 | if prediction_type == "detection": 129 | box_features = self.roi_heads.box_head(box_features) 130 | class_logits, box_regression = self.roi_heads.box_predictor(box_features) 131 | # Target-aware and distractor-aware for tracked targets 132 | else: 133 | # Use distractor iou threshold to choose between easy/hard classifier 134 | awareness_bool = distractor_ious > self.config.TRACKING.MIN_OVERLAP_AS_DISTRACTOR 135 | # No awareness for easy cases 136 | easy_box_features = self.roi_heads.box_head(box_features[~awareness_bool]) 137 | easy_class_logits, easy_box_regression = self.roi_heads.box_predictor(easy_box_features) 138 | # Apply awareness for hard cases 139 | if len(target_bools[awareness_bool]): # In case no hard cases 140 | hard_awareness, _ = self.process_id_embeddings( 141 | [p[awareness_bool] for p in proposals], 142 | backbone_features=backbone_features, 143 | image_shapes=images.image_sizes, 144 | boxes_type="tracklet", 145 | purpose="awareness", 146 | target_bools=target_bools[awareness_bool], 147 | target_embeddings=target_embeddings[awareness_bool], 148 | distractor_ious=distractor_ious[awareness_bool], 149 | distractor_bools=distractor_bools[awareness_bool], 150 | distractor_embeddings=distractor_embeddings 151 | ) 152 | hard_input_features = box_features[awareness_bool] + hard_awareness 153 | hard_box_features = self.roi_heads.hard_box_head(hard_input_features) 154 | hard_class_logits, hard_box_regression = self.roi_heads.hard_box_predictor(hard_box_features) 155 | # Combine both cases, create empty data first 156 | class_logits = torch.zeros([len(boxes)] + list(easy_class_logits.size()[1:]), 157 | dtype=easy_class_logits.dtype, device=easy_class_logits.device) 158 | box_regression = torch.zeros([len(boxes)] + list(easy_box_regression.size()[1:]), 159 | dtype=easy_box_regression.dtype, device=easy_box_regression.device) 160 | # Assign values according to awareness_bool 161 | class_logits[~awareness_bool] = easy_class_logits 162 | box_regression[~awareness_bool] = easy_box_regression 163 | if len(target_bools[awareness_bool]): # In case no hard cases 164 | class_logits[awareness_bool] = hard_class_logits 165 | box_regression[awareness_bool] = hard_box_regression 166 | 167 | # Process predictions 168 | pred_boxes = self.roi_heads.box_coder.decode(box_regression, proposals) 169 | pred_scores = F.softmax(class_logits, -1) 170 | # Get pedestrian class 171 | pred_boxes = pred_boxes[:, 1] 172 | pred_scores = pred_scores[:, 1] 173 | 174 | # Recover original size 175 | pred_boxes_orig = resize_boxes(pred_boxes, images.image_sizes[0], original_image_sizes[0]) 176 | 177 | # Output 178 | id_embeddings, _ = self.process_id_embeddings( 179 | [pred_boxes], 180 | backbone_features=backbone_features, 181 | image_shapes=images.image_sizes, 182 | boxes_type=prediction_type, 183 | purpose="embedding", 184 | target_bools=target_bools, 185 | target_embeddings=target_embeddings, 186 | distractor_ious=distractor_ious, 187 | distractor_bools=distractor_bools, 188 | distractor_embeddings=distractor_embeddings) 189 | 190 | return pred_boxes_orig, pred_scores, id_embeddings 191 | 192 | def custom_train( 193 | self, 194 | images, 195 | targets, 196 | warmup=False 197 | ): 198 | """ 199 | TADAM training 200 | """ 201 | device = list(self.parameters())[0].device 202 | if targets is None: 203 | raise ValueError("In training, targets should be passed") 204 | 205 | # Remove image and target with less than one 
ground truth, happens in MOT16/17-05 206 | non_empty_indices = [] 207 | for i, t in enumerate(targets): 208 | if len(t['boxes']) > 1: 209 | non_empty_indices.append(i) 210 | images = [img for i, img in enumerate(images) if i in non_empty_indices] 211 | targets = [t for i, t in enumerate(targets) if i in non_empty_indices] 212 | 213 | # Extract backbone features 214 | images, targets = self.transform(images, targets) 215 | backbone_features = self.backbone(images.tensors) 216 | if isinstance(backbone_features, torch.Tensor): 217 | backbone_features = OrderedDict([('0', backbone_features)]) 218 | 219 | # Losses 220 | all_losses = {} 221 | 222 | # ====== RPN ====== # 223 | # Only works in training to select RoIs, not used in tracking 224 | proposals, proposal_losses = self.rpn(images, backbone_features, targets) 225 | self.update_losses(all_losses, proposal_losses) 226 | 227 | # ====== Box Features ====== # 228 | if targets is not None: 229 | for t in targets: 230 | assert t["boxes"].dtype in (torch.float, torch.double, torch.half), 'target boxes must of float type' 231 | assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' 232 | proposals, _, labels, regression_targets = self.roi_heads.select_training_samples(proposals, targets) 233 | assert labels is not None and regression_targets is not None, "Invalid labels/regression_targets" 234 | 235 | # Extract box features from images 236 | # box_features from different images are concatenated if input has multiple images 237 | box_features = self.roi_heads.box_roi_pool(backbone_features, proposals, images.image_sizes) 238 | 239 | # ====== Easy cases without identity awareness ====== # 240 | if not warmup: 241 | # Easy cases without awareness 242 | easy_box_features = self.roi_heads.box_head(box_features) 243 | easy_class_logits, easy_box_regression = self.roi_heads.box_predictor(easy_box_features) 244 | # Losses 245 | easy_box_losses = {} 246 | loss_easy_classifier, loss_easy_box_reg = fastrcnn_loss( 247 | easy_class_logits, easy_box_regression, labels, regression_targets) 248 | easy_box_losses = { 249 | "loss_classifier": loss_easy_classifier, 250 | "loss_box_reg": loss_easy_box_reg, 251 | } 252 | self.update_losses(all_losses, easy_box_losses) 253 | # Train hard case head & predictor with easy cases for basic abilities 254 | hard_basic_box_features = self.roi_heads.hard_box_head(box_features) 255 | hard_basic_class_logits, hard_basic_box_regression = self.roi_heads.hard_box_predictor(hard_basic_box_features) 256 | # Losses 257 | hard_basic_box_losses = {} 258 | loss_hard_basic_classifier, loss_hard_basic_box_reg = fastrcnn_loss( 259 | hard_basic_class_logits, hard_basic_box_regression, labels, regression_targets) 260 | hard_basic_box_losses = { 261 | "loss_hard_basic_classifier": loss_hard_basic_classifier, 262 | "loss_hard_basic_box_reg": loss_hard_basic_box_reg, 263 | } 264 | self.update_losses(all_losses, hard_basic_box_losses) 265 | 266 | # ====== Identity training on memory and embedding extraction ====== # 267 | # Obtain id embeddings for gt boxes, no loss in this step 268 | # Only gt boxes participates in identity training for accuracy 269 | gt_id_embeddings, _ = self.process_id_embeddings( 270 | [t['boxes'] for t in targets], 271 | backbone_features=backbone_features, 272 | image_shapes=images.image_sizes, 273 | boxes_type='detection', 274 | purpose='embedding' 275 | ) 276 | # Update tracklets 277 | cat_boxes = torch.cat([t['boxes'] for t in targets]) 278 | cat_gt_ids = torch.cat([t['gt_ids'] for t in targets]) 279 
| memory_losses = [] 280 | for i, (box, gt_id, e) in enumerate(zip(cat_boxes, cat_gt_ids, gt_id_embeddings)): 281 | # Update an existing tracklet. Each tracklet is reset after certain updates 282 | if gt_id.item() in self.all_tracklets_dict.keys(): 283 | tracklet = self.all_tracklets_dict[gt_id.item()] 284 | tracklet.update_embedding(e.detach().unsqueeze(0), training=True) 285 | # Get loss for memory upon update, using identity loss as supervision 286 | memory_losses.append(self.memory_net.memory_loss(tracklet.embedding, gt_id.unsqueeze(0))) 287 | # Initialize a tracklet 288 | else: 289 | self.all_tracklets_dict[gt_id.item()] = Tracklet(-1, box, self.config, 290 | embedding=e.detach().unsqueeze(0), 291 | memory_net=self.memory_net, training=True) 292 | self.all_tracklets_dict[gt_id.item()].tracklet_id = gt_id.item() 293 | # Loss for memory 294 | if len(memory_losses): 295 | # Average over tracklets in the frame 296 | loss_memory = sum(memory_losses) / len(memory_losses) 297 | else: 298 | loss_memory = torch.tensor(0.0, dtype=gt_id_embeddings.dtype, device=device) 299 | self.update_losses(all_losses, { 300 | 'loss_mem': loss_memory 301 | }) 302 | # Loss for embedding extraction 303 | pos_gt_embeddings, neg_gt_embeddings = self.collect_triplet(cat_gt_ids) 304 | _, gt_id_losses = self.process_id_embeddings( 305 | [t['boxes'] for t in targets], 306 | backbone_features=backbone_features, 307 | image_shapes=images.image_sizes, 308 | boxes_type='detection', 309 | purpose='embedding', 310 | matched_bools=torch.ones_like(cat_gt_ids).bool(), 311 | matched_gt_ids=cat_gt_ids, 312 | pos_id_embeddings=pos_gt_embeddings, 313 | neg_id_embeddings=neg_gt_embeddings) 314 | # Reduce identity losses by a ratio after warmup, to balance different losses 315 | self.update_losses(all_losses, { 316 | 'loss_gt_id_crossentropy': 317 | gt_id_losses['loss_id_crossentropy'] if warmup 318 | else gt_id_losses['loss_id_crossentropy'] * self.config.TRAINING.ID_LOSS_RATIO 319 | }) 320 | self.update_losses(all_losses, { 321 | 'loss_gt_id_triplet': 322 | gt_id_losses['loss_id_triplet'] if warmup 323 | else gt_id_losses['loss_id_triplet'] * self.config.TRAINING.ID_LOSS_RATIO 324 | }) 325 | 326 | # ====== Hard cases with identity awareness ====== # 327 | if not warmup: 328 | target_bools, target_embeddings, distractor_bools, distractor_embeddings, distractor_ious = \ 329 | self.collect_attention_embeddings(proposals, targets, min_overlap=0.5, 330 | min_distractor_overlap=self.config.TRACKING.MIN_OVERLAP_AS_DISTRACTOR) 331 | id_awareness, _ = self.process_id_embeddings( 332 | proposals, 333 | backbone_features=backbone_features, 334 | image_shapes=images.image_sizes, 335 | boxes_type='tracklet', 336 | purpose='awareness', 337 | target_bools=target_bools, 338 | target_embeddings=target_embeddings, 339 | distractor_ious=distractor_ious, 340 | distractor_bools=distractor_bools, 341 | distractor_embeddings=distractor_embeddings) 342 | # Collect box features 343 | input_features = box_features + id_awareness 344 | hard_awareness_box_features = self.roi_heads.hard_box_head(input_features) 345 | hard_awareness_class_logits, hard_awareness_box_regression = \ 346 | self.roi_heads.hard_box_predictor(hard_awareness_box_features) 347 | # Losses 348 | loss_hard_awareness_classifier, loss_hard_awareness_box_reg = fastrcnn_loss( 349 | hard_awareness_class_logits, hard_awareness_box_regression, labels, regression_targets) 350 | all_losses.update({ 351 | 'loss_hard_awareness_classifier': loss_hard_awareness_classifier, 352 | 
'loss_hard_awareness_box_reg': loss_hard_awareness_box_reg, 353 | }) 354 | 355 | return all_losses 356 | 357 | def update_losses(self, all_losses, new_loss): 358 | for key, value in new_loss.items(): 359 | if key in all_losses: 360 | all_losses[key] += value 361 | else: 362 | all_losses.update({key: value}) 363 | 364 | def process_id_embeddings( 365 | self, 366 | boxes, 367 | backbone_features=None, 368 | image_shapes=None, 369 | boxes_type="detection", 370 | purpose="embedding", 371 | matched_bools=None, 372 | matched_gt_ids=None, 373 | pos_id_embeddings=None, 374 | neg_id_embeddings=None, 375 | target_bools=None, 376 | target_embeddings=None, 377 | distractor_ious=None, 378 | distractor_bools=None, 379 | distractor_embeddings=None 380 | ): 381 | """ 382 | Outputs identity embeddings and/or awareness 383 | For public detection: extracts identity embeddinng only 384 | For tracked targets: extract identity embedding + awareness 385 | For awareness computation: awareness only 386 | """ 387 | assert purpose in ["embedding", "awareness"], "Invalid purpose" 388 | id_losses = {} 389 | assert backbone_features is not None and image_shapes is not None 390 | id_embeddings = self.roi_heads.id_module(self.roi_heads.id_roi_pool(backbone_features, boxes, image_shapes)) 391 | # ID losses 392 | if self.training and matched_bools is not None and matched_gt_ids is not None: 393 | # Cross entropy loss # GS 394 | loss_id_crossentropy = self.roi_heads.id_module.cross_entropy_loss( 395 | id_embeddings[matched_bools], matched_gt_ids) 396 | id_losses.update({"loss_id_crossentropy": loss_id_crossentropy}) 397 | # Triplet loss # GS 398 | assert pos_id_embeddings is not None 399 | assert neg_id_embeddings is not None 400 | assert len(matched_gt_ids) == len(pos_id_embeddings) 401 | assert len(matched_gt_ids) == len(neg_id_embeddings) 402 | loss_id_triplet = self.roi_heads.id_module.triplet_loss( 403 | id_embeddings[matched_bools], 404 | pos_id_embeddings, 405 | neg_id_embeddings, 406 | margin=0.3) 407 | id_losses.update({"loss_id_triplet": loss_id_triplet}) 408 | # Apply awareness for tracked targets 409 | enhancement = torch.zeros_like(id_embeddings) 410 | reduction = torch.zeros_like(id_embeddings) 411 | if boxes_type == "tracklet": 412 | # Target 413 | assert target_bools is not None and target_embeddings is not None 414 | if len(target_bools): 415 | enhancement[target_bools] = self.roi_heads.target_enhance(id_embeddings[target_bools].detach(), target_embeddings) 416 | # Distractor. 
It is possible to have no distractors 417 | if distractor_bools is not None and distractor_embeddings is not None and len(distractor_bools) and torch.sum(distractor_bools).item() > 0: 418 | reduction[distractor_bools] = self.roi_heads.distractor_reduce(id_embeddings[distractor_bools].detach(), distractor_embeddings) 419 | # Scale up gradient 420 | reduction[distractor_bools] = _ScaleGradient.apply(reduction[distractor_bools], 2.0) 421 | # Combine 422 | awareness = torch.zeros_like(id_embeddings) 423 | if boxes_type == "tracklet": 424 | awareness = self.roi_heads.integration(enhancement, reduction, overlaps=distractor_ious) 425 | # Output 426 | output = \ 427 | id_embeddings * float(purpose == "embedding") + awareness 428 | return output, id_losses 429 | 430 | def collect_triplet(self, pos_ids): 431 | """ 432 | Retrieve pos and neg for triplet in identity training 433 | """ 434 | pos_embeddings = [] 435 | for p_id in pos_ids: 436 | t = self.all_tracklets_dict[p_id.item()] 437 | pos_embeddings.append(t.embedding.detach()) 438 | pos_embeddings = torch.cat(pos_embeddings) 439 | # Generate random neg indices 440 | neg_embeddings = [] 441 | neg_gt_ids = torch.tensor(list(self.all_tracklets_dict.keys()), device=pos_ids.device).unsqueeze(0) 442 | neg_gt_ids_repeated = neg_gt_ids.repeat(pos_ids.size(0), 1) 443 | neg_gt_ids = neg_gt_ids_repeated[neg_gt_ids_repeated != pos_ids.unsqueeze(1)].reshape(pos_ids.size(0), neg_gt_ids.size(1) - 1) 444 | neg_selections = torch.randint(0, neg_gt_ids.size(1), (neg_gt_ids.size(0), 1), device=neg_gt_ids.device) 445 | neg_gt_ids = neg_gt_ids.gather(dim=1, index=neg_selections).squeeze(-1) 446 | for i, p_id in enumerate(pos_ids): 447 | t = self.all_tracklets_dict[neg_gt_ids[i].item()] 448 | neg_embeddings.append(t.embedding.detach()) 449 | neg_embeddings = torch.cat(neg_embeddings) 450 | return pos_embeddings, neg_embeddings 451 | 452 | def collect_attention_embeddings( 453 | self, 454 | proposals, 455 | targets, 456 | min_overlap=0.5, 457 | min_distractor_overlap=0.2 458 | ): 459 | """ 460 | Retrieve target reference and distractor reference 461 | """ 462 | target_bools, target_ref_ids, distractor_bools, distractor_ref_ids, distractor_ious = \ 463 | self.attention_training_match_ids(proposals, targets, 464 | min_overlap=min_overlap, min_distractor_overlap=min_distractor_overlap) 465 | target_embeddings = [] 466 | for target_id in target_ref_ids: 467 | target_embeddings.append(self.all_tracklets_dict[target_id.item()].embedding.detach()) 468 | target_embeddings = torch.cat(target_embeddings) 469 | distractor_embeddings = [] 470 | for distractor_id in distractor_ref_ids: 471 | distractor_embeddings.append(self.all_tracklets_dict[distractor_id.item()].embedding.detach()) 472 | if len(distractor_embeddings): 473 | distractor_embeddings = torch.cat(distractor_embeddings) 474 | else: 475 | distractor_bools = None 476 | distractor_embeddings = None 477 | return target_bools, target_embeddings, distractor_bools, distractor_embeddings, distractor_ious 478 | 479 | def attention_training_match_ids( 480 | self, 481 | proposals, 482 | targets, 483 | min_overlap=0.5, 484 | min_distractor_overlap=0.2 485 | ): 486 | """ 487 | Get proposals that match gt boxes with minimum overlap 488 | Only train attention on such positive proposals 489 | Then match proposals with other gt boxes to seek respective distractors 490 | min_distractor_overlap must be smaller than min_overlap 491 | """ 492 | cat_matched_bools = [] 493 | cat_matched_ids = [] 494 | cat_matched_distractor_bools = [] 495 | 
cat_matched_distractor_ids = [] 496 | cat_distractor_ious = [] 497 | for p, t in zip(proposals, targets): 498 | gt_boxes = t["boxes"].to(p[0].dtype) 499 | gt_ids = t["gt_ids"] 500 | 501 | # Match proposals to gt 502 | match_quality_matrix = box_iou(gt_boxes, p) 503 | top_vals, top_matches = match_quality_matrix.topk(k=2, dim=0) # For each proposal 504 | 505 | # No.1 match should be corresponding target, if matched 506 | matched_bools = top_vals[0] > min_overlap 507 | matched_ids = gt_ids[top_matches[0]][matched_bools] 508 | cat_matched_bools.append(matched_bools) 509 | cat_matched_ids.append(matched_ids) 510 | 511 | # No.2 match should be corresponding distractor, if matched 512 | matched_distractor_bools = top_vals[1] > min_distractor_overlap 513 | matched_distractor_ids = gt_ids[top_matches[1]][matched_distractor_bools] 514 | cat_matched_distractor_bools.append(matched_distractor_bools) 515 | cat_matched_distractor_ids.append(matched_distractor_ids) 516 | cat_distractor_ious.append(top_vals[1]) 517 | # Length of bools is original length 518 | cat_matched_bools = torch.cat(cat_matched_bools) 519 | cat_matched_distractor_bools = torch.cat(cat_matched_distractor_bools) 520 | # Length of ids is matched length, shorter than original length 521 | cat_matched_ids = torch.cat(cat_matched_ids) 522 | cat_matched_distractor_ids = torch.cat(cat_matched_distractor_ids) 523 | cat_distractor_ious = torch.cat(cat_distractor_ious) 524 | return cat_matched_bools, cat_matched_ids, cat_matched_distractor_bools, cat_matched_distractor_ids, cat_distractor_ious 525 | 526 | 527 | class _ScaleGradient(Function): 528 | @staticmethod 529 | def forward(ctx, input, scale): 530 | ctx.scale = scale 531 | return input 532 | 533 | @staticmethod 534 | def backward(ctx, grad_output): 535 | return grad_output * ctx.scale, None 536 | -------------------------------------------------------------------------------- /lib/modules/faster_rcnn.py: -------------------------------------------------------------------------------- 1 | # Adapted from torchvision 2 | 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from torchvision.ops import MultiScaleRoIAlign 6 | from torchvision.models.detection.anchor_utils import AnchorGenerator 7 | from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN 8 | from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork 9 | 10 | # Local files 11 | from .roi_heads import RoIHeads 12 | from .genenralized_rcnn_transform import GeneralizedRCNNTransform 13 | 14 | 15 | class FasterRCNN(GeneralizedRCNN): 16 | """ 17 | Implements Faster R-CNN. 18 | 19 | The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each 20 | image, and should be in 0-1 range. Different images can have different sizes. 21 | 22 | The behavior of the model changes depending if it is in training or evaluation mode. 23 | 24 | During training, the model expects both the input tensors, as well as a targets (list of dictionary), 25 | containing: 26 | - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with 27 | ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. 28 | - labels (Int64Tensor[N]): the class label for each ground-truth box 29 | 30 | The model returns a Dict[Tensor] during training, containing the classification and regression 31 | losses for both the RPN and the R-CNN. 
32 | 33 | During inference, the model requires only the input tensors, and returns the post-processed 34 | predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as 35 | follows: 36 | - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with 37 | ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. 38 | - labels (Int64Tensor[N]): the predicted labels for each image 39 | - scores (Tensor[N]): the scores or each prediction 40 | 41 | Args: 42 | backbone (nn.Module): the network used to compute the features for the model. 43 | It should contain a out_channels attribute, which indicates the number of output 44 | channels that each feature map has (and it should be the same for all feature maps). 45 | The backbone should return a single Tensor or and OrderedDict[Tensor]. 46 | num_classes (int): number of output classes of the model (including the background). 47 | If box_predictor is specified, num_classes should be None. 48 | min_size (int): minimum size of the image to be rescaled before feeding it to the backbone 49 | max_size (int): maximum size of the image to be rescaled before feeding it to the backbone 50 | image_mean (Tuple[float, float, float]): mean values used for input normalization. 51 | They are generally the mean values of the dataset on which the backbone has been trained 52 | on 53 | image_std (Tuple[float, float, float]): std values used for input normalization. 54 | They are generally the std values of the dataset on which the backbone has been trained on 55 | rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature 56 | maps. 57 | rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN 58 | rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training 59 | rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing 60 | rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training 61 | rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing 62 | rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals 63 | rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be 64 | considered as positive during training of the RPN. 65 | rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be 66 | considered as negative during training of the RPN. 67 | rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN 68 | for computing the loss 69 | rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training 70 | of the RPN 71 | rpn_score_thresh (float): during inference, only return proposals with a classification score 72 | greater than rpn_score_thresh 73 | box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in 74 | the locations indicated by the bounding boxes 75 | box_head (nn.Module): module that takes the cropped feature maps as input 76 | box_predictor (nn.Module): module that takes the output of box_head and returns the 77 | classification logits and box regression deltas. 78 | box_score_thresh (float): during inference, only return proposals with a classification score 79 | greater than box_score_thresh 80 | box_nms_thresh (float): NMS threshold for the prediction head. 
Used during inference 81 | box_detections_per_img (int): maximum number of detections per image, for all classes. 82 | box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be 83 | considered as positive during training of the classification head 84 | box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be 85 | considered as negative during training of the classification head 86 | box_batch_size_per_image (int): number of proposals that are sampled during training of the 87 | classification head 88 | box_positive_fraction (float): proportion of positive proposals in a mini-batch during training 89 | of the classification head 90 | bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the 91 | bounding boxes 92 | 93 | Example:: 94 | 95 | >>> import torch 96 | >>> import torchvision 97 | >>> from torchvision.models.detection import FasterRCNN 98 | >>> from torchvision.models.detection.rpn import AnchorGenerator 99 | >>> # load a pre-trained model for classification and return 100 | >>> # only the features 101 | >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features 102 | >>> # FasterRCNN needs to know the number of 103 | >>> # output channels in a backbone. For mobilenet_v2, it's 1280 104 | >>> # so we need to add it here 105 | >>> backbone.out_channels = 1280 106 | >>> 107 | >>> # let's make the RPN generate 5 x 3 anchors per spatial 108 | >>> # location, with 5 different sizes and 3 different aspect 109 | >>> # ratios. We have a Tuple[Tuple[int]] because each feature 110 | >>> # map could potentially have different sizes and 111 | >>> # aspect ratios 112 | >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), 113 | >>> aspect_ratios=((0.5, 1.0, 2.0),)) 114 | >>> 115 | >>> # let's define what are the feature maps that we will 116 | >>> # use to perform the region of interest cropping, as well as 117 | >>> # the size of the crop after rescaling. 118 | >>> # if your backbone returns a Tensor, featmap_names is expected to 119 | >>> # be ['0']. More generally, the backbone should return an 120 | >>> # OrderedDict[Tensor], and in featmap_names you can choose which 121 | >>> # feature maps to use. 
122 | >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], 123 | >>> output_size=7, 124 | >>> sampling_ratio=2) 125 | >>> 126 | >>> # put the pieces together inside a FasterRCNN model 127 | >>> model = FasterRCNN(backbone, 128 | >>> num_classes=2, 129 | >>> rpn_anchor_generator=anchor_generator, 130 | >>> box_roi_pool=roi_pooler) 131 | >>> model.eval() 132 | >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] 133 | >>> predictions = model(x) 134 | """ 135 | 136 | def __init__(self, backbone, num_classes=None, 137 | # transform parameters 138 | min_size=800, max_size=1333, 139 | image_mean=None, image_std=None, 140 | # RPN parameters 141 | rpn_anchor_generator=None, rpn_head=None, 142 | rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, 143 | rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, 144 | rpn_nms_thresh=0.7, 145 | rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, 146 | rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, 147 | rpn_score_thresh=0.0, 148 | # Box parameters 149 | box_roi_pool=None, box_head=None, box_predictor=None, 150 | box_score_thresh=0.05, box_nms_thresh=0.3, box_detections_per_img=100, 151 | box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, 152 | box_batch_size_per_image=256, box_positive_fraction=0.75, 153 | bbox_reg_weights=None): 154 | 155 | if not hasattr(backbone, "out_channels"): 156 | raise ValueError( 157 | "backbone should contain an attribute out_channels " 158 | "specifying the number of output channels (assumed to be the " 159 | "same for all the levels)") 160 | 161 | assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) 162 | assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) 163 | 164 | if num_classes is not None: 165 | if box_predictor is not None: 166 | raise ValueError("num_classes should be None when box_predictor is specified") 167 | else: 168 | if box_predictor is None: 169 | raise ValueError("num_classes should not be None when box_predictor " 170 | "is not specified") 171 | 172 | out_channels = backbone.out_channels 173 | 174 | if rpn_anchor_generator is None: 175 | anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) 176 | aspect_ratios = ((1.0, 2.0, 3.0),) * len(anchor_sizes) 177 | rpn_anchor_generator = AnchorGenerator( 178 | anchor_sizes, aspect_ratios 179 | ) 180 | if rpn_head is None: 181 | rpn_head = RPNHead( 182 | out_channels, rpn_anchor_generator.num_anchors_per_location()[0] 183 | ) 184 | 185 | rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) 186 | rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) 187 | 188 | rpn = RegionProposalNetwork( 189 | rpn_anchor_generator, rpn_head, 190 | rpn_fg_iou_thresh, rpn_bg_iou_thresh, 191 | rpn_batch_size_per_image, rpn_positive_fraction, 192 | rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, 193 | score_thresh=rpn_score_thresh) 194 | 195 | if box_roi_pool is None: 196 | box_roi_pool = MultiScaleRoIAlign( 197 | featmap_names=["0", "1", "2", "3"], 198 | output_size=7, 199 | sampling_ratio=2) 200 | 201 | if box_head is None: 202 | resolution = box_roi_pool.output_size[0] 203 | representation_size = 1024 204 | box_head = TwoMLPHead( 205 | out_channels * resolution ** 2, 206 | representation_size) 207 | 208 | if box_predictor is None: 209 | representation_size = 1024 210 | box_predictor = FastRCNNPredictor( 211 | representation_size, 212 | num_classes) 213 | 214 | roi_heads = RoIHeads( 215 | # Box 216 | box_roi_pool, box_head, 
box_predictor, 217 | box_fg_iou_thresh, box_bg_iou_thresh, 218 | box_batch_size_per_image, box_positive_fraction, 219 | bbox_reg_weights, 220 | box_score_thresh, box_nms_thresh, box_detections_per_img) 221 | 222 | if image_mean is None: 223 | image_mean = [0.485, 0.456, 0.406] 224 | if image_std is None: 225 | image_std = [0.229, 0.224, 0.225] 226 | transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) 227 | 228 | super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform) 229 | 230 | 231 | class TwoMLPHead(nn.Module): 232 | """ 233 | Standard heads for FPN-based models 234 | 235 | Args: 236 | in_channels (int): number of input channels 237 | representation_size (int): size of the intermediate representation 238 | """ 239 | 240 | def __init__(self, in_channels, representation_size): 241 | super(TwoMLPHead, self).__init__() 242 | 243 | self.fc6 = nn.Linear(in_channels, representation_size) 244 | self.fc7 = nn.Linear(representation_size, representation_size) 245 | 246 | def forward(self, x): 247 | x = x.flatten(start_dim=1) 248 | 249 | x = F.relu(self.fc6(x)) 250 | x = F.relu(self.fc7(x)) 251 | 252 | return x 253 | 254 | 255 | class FastRCNNPredictor(nn.Module): 256 | """ 257 | Standard classification + bounding box regression layers 258 | for Fast R-CNN. 259 | 260 | Args: 261 | in_channels (int): number of input channels 262 | num_classes (int): number of output classes (including background) 263 | """ 264 | 265 | def __init__(self, in_channels, num_classes): 266 | super(FastRCNNPredictor, self).__init__() 267 | self.cls_score = nn.Linear(in_channels, num_classes) 268 | self.bbox_pred = nn.Linear(in_channels, num_classes * 4) 269 | 270 | def forward(self, x): 271 | if x.dim() == 4: 272 | assert list(x.shape[2:]) == [1, 1] 273 | x = x.flatten(start_dim=1) 274 | scores = self.cls_score(x) 275 | bbox_deltas = self.bbox_pred(x) 276 | 277 | return scores, bbox_deltas 278 | -------------------------------------------------------------------------------- /lib/modules/genenralized_rcnn_transform.py: -------------------------------------------------------------------------------- 1 | # Adapted from torchvision 2 | 3 | from __future__ import division 4 | 5 | import math 6 | import torch 7 | from torch import nn, Tensor 8 | import torchvision 9 | from torch.jit.annotations import List, Tuple, Dict, Optional 10 | 11 | from torchvision.models.detection.image_list import ImageList 12 | 13 | 14 | class GeneralizedRCNNTransform(nn.Module): 15 | """ 16 | Performs input / target transformation before feeding the data to a GeneralizedRCNN 17 | model. 18 | 19 | The transformations it perform are: 20 | - input normalization (mean subtraction and std division) 21 | - input / target resizing to match min_size / max_size 22 | 23 | It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets 24 | """ 25 | 26 | def __init__(self, min_size, max_size, image_mean, image_std): 27 | super(GeneralizedRCNNTransform, self).__init__() 28 | if not isinstance(min_size, (list, tuple)): 29 | min_size = (min_size,) 30 | self.min_size = min_size 31 | self.max_size = max_size 32 | self.image_mean = image_mean 33 | self.image_std = image_std 34 | 35 | def forward(self, images, targets=None): 36 | # type: (...) 
-> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]] 37 | images = [img for img in images] 38 | for i in range(len(images)): 39 | image = images[i] 40 | target_index = targets[i] if targets is not None else None 41 | 42 | if image.dim() != 3: 43 | raise ValueError("images is expected to be a list of 3d tensors " 44 | f"of shape [C, H, W], got {image.shape}") 45 | image = self.normalize(image) 46 | image, target_index = self.resize(image, target_index) 47 | images[i] = image 48 | if targets is not None and target_index is not None: 49 | targets[i] = target_index 50 | 51 | image_sizes = [img.shape[-2:] for img in images] 52 | images = self.batch_images(images) 53 | image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], []) 54 | for image_size in image_sizes: 55 | assert len(image_size) == 2 56 | image_sizes_list.append((image_size[0], image_size[1])) 57 | 58 | image_list = ImageList(images, image_sizes_list) 59 | return image_list, targets 60 | 61 | def normalize(self, image): 62 | dtype, device = image.dtype, image.device 63 | mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device) 64 | std = torch.as_tensor(self.image_std, dtype=dtype, device=device) 65 | return (image - mean[:, None, None]) / std[:, None, None] 66 | 67 | def torch_choice(self, l): 68 | # type: (List[int]) -> int 69 | """ 70 | Implements `random.choice` via torch ops so it can be compiled with 71 | TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 72 | is fixed. 73 | """ 74 | index = int(torch.empty(1).uniform_(0., float(len(l))).item()) 75 | return l[index] 76 | 77 | def resize(self, image, target): 78 | # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]] 79 | h, w = image.shape[-2:] 80 | im_shape = torch.tensor(image.shape[-2:]) 81 | min_size = float(torch.min(im_shape)) 82 | max_size = float(torch.max(im_shape)) 83 | if self.training: 84 | size = float(self.torch_choice(self.min_size)) 85 | else: 86 | # FIXME assume for now that testing uses the largest scale 87 | size = float(self.min_size[-1]) 88 | scale_factor = size / min_size 89 | if max_size * scale_factor > self.max_size: 90 | scale_factor = self.max_size / max_size 91 | image = torch.nn.functional.interpolate( 92 | image[None], scale_factor=scale_factor, mode="bilinear", 93 | recompute_scale_factor=True, align_corners=False)[0] 94 | 95 | if target is None: 96 | return image, target 97 | 98 | bbox = target["boxes"] 99 | bbox = resize_boxes(bbox, (h, w), image.shape[-2:]) 100 | target["boxes"] = bbox 101 | 102 | return image, target 103 | 104 | # _onnx_batch_images() is an implementation of 105 | # batch_images() that is supported by ONNX tracing. 
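# It takes the per-dimension maximum size over the batch, rounds H and W up to a multiple of size_divisible, and zero-pads every image to that common shape.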
106 | @torch.jit.unused 107 | def _onnx_batch_images(self, images, size_divisible=32): 108 | # type: (List[Tensor], int) -> Tensor 109 | max_size = [] 110 | for i in range(images[0].dim()): 111 | max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64) 112 | max_size.append(max_size_i) 113 | stride = size_divisible 114 | max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64) 115 | max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64) 116 | max_size = tuple(max_size) 117 | 118 | # work around for 119 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 120 | # which is not yet supported in onnx 121 | padded_imgs = [] 122 | for img in images: 123 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 124 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 125 | padded_imgs.append(padded_img) 126 | 127 | return torch.stack(padded_imgs) 128 | 129 | def max_by_axis(self, the_list): 130 | # type: (List[List[int]]) -> List[int] 131 | maxes = the_list[0] 132 | for sublist in the_list[1:]: 133 | for index, item in enumerate(sublist): 134 | maxes[index] = max(maxes[index], item) 135 | return maxes 136 | 137 | def batch_images(self, images, size_divisible=32): 138 | # type: (List[Tensor], int) -> Tensor 139 | if torchvision._is_tracing(): 140 | # batch_images() does not export well to ONNX 141 | # call _onnx_batch_images() instead 142 | return self._onnx_batch_images(images, size_divisible) 143 | 144 | max_size = self.max_by_axis([list(img.shape) for img in images]) 145 | stride = float(size_divisible) 146 | max_size = list(max_size) 147 | max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride) 148 | max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride) 149 | 150 | batch_shape = [len(images)] + max_size 151 | batched_imgs = images[0].new_full(batch_shape, 0) 152 | for img, pad_img in zip(images, batched_imgs): 153 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 154 | 155 | return batched_imgs 156 | 157 | def postprocess(self, result, image_shapes, original_image_sizes): 158 | # type: (...) 
-> List[Dict[str, Tensor]] 159 | if self.training: 160 | return result 161 | for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): 162 | boxes = pred["boxes"] 163 | boxes = resize_boxes(boxes, im_s, o_im_s) 164 | result[i]["boxes"] = boxes 165 | return result 166 | 167 | 168 | def resize_boxes(boxes, original_size, new_size): 169 | # type: (Tensor, List[int], List[int]) -> Tensor 170 | ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)] 171 | ratio_height, ratio_width = ratios 172 | xmin, ymin, xmax, ymax = boxes.unbind(1) 173 | 174 | xmin = xmin * ratio_width 175 | xmax = xmax * ratio_width 176 | ymin = ymin * ratio_height 177 | ymax = ymax * ratio_height 178 | return torch.stack((xmin, ymin, xmax, ymax), dim=1) 179 | -------------------------------------------------------------------------------- /lib/modules/identity.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | 4 | 5 | class BasicConv(nn.Module): 6 | def __init__( 7 | self, 8 | in_planes, 9 | out_planes, 10 | kernel_size, 11 | stride=1, 12 | padding=1, 13 | dilation=1, 14 | groups=1, 15 | relu=True, 16 | bn=True, 17 | bias=True 18 | ): 19 | super(BasicConv, self).__init__() 20 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 21 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 22 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 23 | self.relu = nn.ReLU() if relu else None 24 | 25 | def forward(self, x): 26 | x = self.conv(x) 27 | if self.bn is not None: 28 | x = self.bn(x) 29 | if self.relu is not None: 30 | x = self.relu(x) 31 | return x 32 | 33 | 34 | class IDModule(nn.Module): 35 | def __init__( 36 | self, 37 | in_channels=256, 38 | out_channels=128, 39 | num_ids=1000 40 | ): 41 | super(IDModule, self).__init__() 42 | 43 | self.in_channels = in_channels 44 | self.layers = nn.Sequential( 45 | BasicConv(in_channels, in_channels, 1, stride=1, padding=0, relu=True, bn=True), 46 | BasicConv(in_channels, in_channels, 1, stride=1, padding=0, relu=True, bn=True), 47 | BasicConv(in_channels, out_channels, 1, stride=1, padding=0, relu=False, bn=False) 48 | ) 49 | 50 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 51 | self.classifier = nn.Linear(out_channels, num_ids) 52 | 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 56 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 57 | nn.init.constant_(m.weight, 1) 58 | nn.init.constant_(m.bias, 0) 59 | 60 | def forward(self, x): 61 | id_feature = self.layers(x) 62 | return id_feature 63 | 64 | def cross_entropy_loss(self, id_feature, labels): 65 | pooled_id_feature = self.avg_pool(id_feature).squeeze(-1).squeeze(-1) 66 | id_predictions = self.classifier(pooled_id_feature) 67 | loss = nn.CrossEntropyLoss(reduction="mean")(id_predictions, labels) 68 | return loss 69 | 70 | def triplet_loss(self, anchor_feature, pos_feature, neg_feature, margin=0.1): 71 | if neg_feature.shape[1] != anchor_feature.shape[1]: 72 | loss = 0 73 | for i in range(neg_feature.shape[1]): 74 | loss += nn.TripletMarginLoss(margin=margin, p=2, reduction="mean")( 75 | self.avg_pool(anchor_feature), 76 | self.avg_pool(pos_feature), 77 | self.avg_pool(neg_feature[:, i, :, :, :])) 78 | loss = loss / neg_feature.shape[1] 79 | else: 80 | loss = nn.TripletMarginLoss(margin=margin, 
p=2, reduction="mean")( 81 | self.avg_pool(anchor_feature), 82 | self.avg_pool(pos_feature), 83 | self.avg_pool(neg_feature) 84 | ) 85 | return loss 86 | -------------------------------------------------------------------------------- /lib/modules/integration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class IntegrationModule(nn.Module): 6 | def __init__( 7 | self, 8 | min_iou=0.2, 9 | enhance_weight_max=1.0, 10 | reduce_weight_max=1.0 11 | ): 12 | super(IntegrationModule, self).__init__() 13 | self.min_iou = min_iou 14 | self.enhance_weight_max = enhance_weight_max 15 | self.reduce_weight_max = reduce_weight_max 16 | 17 | def forward(self, enhance_feature, reduce_feature, overlaps): 18 | enhance_weight = self.compute_weight(overlaps, self.enhance_weight_max, self.min_iou) 19 | reduce_weight = self.compute_weight(overlaps, self.reduce_weight_max, self.min_iou) 20 | return enhance_feature * enhance_weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - \ 21 | reduce_feature * reduce_weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 22 | 23 | def compute_weight(self, ious, weight_max, iou_min): 24 | weight = weight_max * torch.min(torch.max((ious - iou_min) / (1.0 - iou_min), torch.zeros_like(ious)), torch.ones_like(ious)) 25 | return weight 26 | -------------------------------------------------------------------------------- /lib/modules/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Memory(object): 6 | def __init__(self, initial_feature, memory_net): 7 | self.h_state = initial_feature 8 | 9 | def update(self, new_feature, memory_net): 10 | self.h_state = memory_net(new_feature, self.h_state) 11 | 12 | def train_update(self, feature_sequence, memory_net): 13 | for i, f in enumerate(feature_sequence): 14 | if i == 0: 15 | h = f 16 | else: 17 | h = memory_net(f, h) 18 | self.h_state = h 19 | 20 | 21 | class MemoryNet(nn.Module): 22 | def __init__( 23 | self, 24 | feature_size=(256, 7, 7), 25 | num_ids=1000, 26 | kernel_size=(1, 1), 27 | bias=True, 28 | ): 29 | super(MemoryNet, self).__init__() 30 | C, H, W = feature_size 31 | self.init_state = nn.Conv2d( 32 | in_channels=C, 33 | out_channels=C, 34 | kernel_size=kernel_size, 35 | padding=(kernel_size[0] // 2, kernel_size[1] // 2), 36 | bias=bias 37 | ) 38 | self.loss_classifier = nn.Linear(feature_size[0], num_ids) 39 | self.cell = ConvGRUCell(feature_size=feature_size, kernel_size=kernel_size, bias=bias) 40 | 41 | def forward(self, new_feature, state): 42 | return self.cell(new_feature, state) 43 | 44 | # To train memory, compute after update 45 | def memory_loss(self, feature, gt_id, pos_feature=None, neg_feature=None): 46 | pooled_feature = nn.AdaptiveAvgPool2d(1)(feature).squeeze(-1).squeeze(-1) 47 | prediction = self.loss_classifier(pooled_feature) 48 | crossentropy_loss = nn.CrossEntropyLoss(reduction='mean')(prediction, gt_id) 49 | if pos_feature is not None and neg_feature is not None: 50 | pooled_pos = nn.AdaptiveAvgPool2d(1)(pos_feature).squeeze(-1).squeeze(-1) 51 | pooled_neg = nn.AdaptiveAvgPool2d(1)(neg_feature).squeeze(-1).squeeze(-1) 52 | triplet_loss = nn.TripletMarginLoss(reduction='mean')(pooled_feature, pooled_pos, pooled_neg) 53 | return crossentropy_loss + triplet_loss 54 | else: 55 | return crossentropy_loss 56 | 57 | 58 | class ConvGRUCell(nn.Module): 59 | def __init__( 60 | self, 61 | feature_size=(256, 7, 7), 62 | 
kernel_size=(3, 3), 63 | bias=True, 64 | ): 65 | super(ConvGRUCell, self).__init__() 66 | self.C, self.H, self.W = feature_size 67 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 68 | self.hidden_dim = self.C 69 | self.bias = bias 70 | 71 | self.conv_gates = nn.Conv2d( 72 | in_channels=self.C + self.hidden_dim, 73 | out_channels=self.hidden_dim * 2, # For update_gate, reset_gate 74 | kernel_size=kernel_size, 75 | padding=self.padding, 76 | bias=self.bias 77 | ) 78 | 79 | self.conv_candidate = nn.Conv2d( 80 | in_channels=self.C + self.hidden_dim, 81 | out_channels=self.hidden_dim, # For candidate neural memory 82 | kernel_size=kernel_size, 83 | padding=self.padding, 84 | bias=self.bias 85 | ) 86 | 87 | def forward(self, new_feature, h_prev): 88 | combined = torch.cat([new_feature, h_prev], dim=1) 89 | gates = self.conv_gates(combined) 90 | reset_gate, update_gate = torch.split(gates, self.hidden_dim, dim=1) 91 | reset_gate = torch.sigmoid(reset_gate) 92 | update_gate = torch.sigmoid(update_gate) 93 | combined = torch.cat([new_feature, reset_gate * h_prev], dim=1) 94 | candidate = torch.tanh(self.conv_candidate(combined)) 95 | h_new = (1 - update_gate) * h_prev + update_gate * candidate 96 | return h_new 97 | -------------------------------------------------------------------------------- /lib/modules/roi_heads.py: -------------------------------------------------------------------------------- 1 | # Adapted from torchvision 2 | 3 | from __future__ import division 4 | import torch 5 | from torch import nn, Tensor 6 | import torch.nn.functional as F 7 | import torchvision.models.detection._utils as det_utils 8 | from torchvision.ops import boxes as box_ops 9 | from torch.jit.annotations import Optional, List, Dict, Tuple 10 | 11 | 12 | def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): 13 | # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] 14 | """ 15 | Computes the loss for Faster R-CNN. 
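Classification uses cross-entropy over all sampled proposals, while box regression uses a smooth L1 loss computed on positive samples only and normalized by the total number of sampled labels.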
16 | 17 | Arguments: 18 | class_logits (Tensor) 19 | box_regression (Tensor) 20 | labels (list[BoxList]) 21 | regression_targets (Tensor) 22 | cls_weights (Tensor) 23 | 24 | Returns: 25 | classification_loss (Tensor) 26 | box_loss (Tensor) 27 | """ 28 | 29 | labels = torch.cat(labels, dim=0) 30 | regression_targets = torch.cat(regression_targets, dim=0) 31 | classification_loss = F.cross_entropy(class_logits, labels) 32 | 33 | # get indices that correspond to the regression targets for 34 | # the corresponding ground truth labels, to be used with 35 | # advanced indexing 36 | sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) 37 | labels_pos = labels[sampled_pos_inds_subset] 38 | N, num_classes = class_logits.shape 39 | box_regression = box_regression.reshape(N, -1, 4) 40 | 41 | box_loss = F.smooth_l1_loss( 42 | box_regression[sampled_pos_inds_subset, labels_pos], 43 | regression_targets[sampled_pos_inds_subset], 44 | reduction="sum", 45 | ) 46 | box_loss = box_loss / labels.numel() 47 | 48 | return classification_loss, box_loss 49 | 50 | 51 | class RoIHeads(torch.nn.Module): 52 | __annotations__ = { 53 | "box_coder": det_utils.BoxCoder, 54 | "proposal_matcher": det_utils.Matcher, 55 | "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, 56 | } 57 | 58 | def __init__( 59 | self, 60 | box_roi_pool, 61 | box_head, 62 | box_predictor, 63 | # Faster R-CNN training 64 | fg_iou_thresh, bg_iou_thresh, 65 | batch_size_per_image, positive_fraction, 66 | bbox_reg_weights, 67 | # Faster R-CNN inference 68 | score_thresh, 69 | nms_thresh, 70 | detections_per_img, 71 | ): 72 | super(RoIHeads, self).__init__() 73 | 74 | self.box_similarity = box_ops.box_iou 75 | # assign ground-truth boxes for each proposal 76 | self.proposal_matcher = det_utils.Matcher( 77 | fg_iou_thresh, 78 | bg_iou_thresh, 79 | allow_low_quality_matches=False) 80 | 81 | self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( 82 | batch_size_per_image, 83 | positive_fraction) 84 | 85 | if bbox_reg_weights is None: 86 | bbox_reg_weights = (10., 10., 5., 5.) 
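# Default BoxCoder weights match torchvision's Faster R-CNN: center deltas (dx, dy) are weighted by 10, size deltas (dw, dh) by 5.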
87 | self.box_coder = det_utils.BoxCoder(bbox_reg_weights) 88 | 89 | self.box_roi_pool = box_roi_pool 90 | self.box_head = box_head 91 | self.box_predictor = box_predictor 92 | 93 | self.score_thresh = score_thresh 94 | self.nms_thresh = nms_thresh 95 | self.detections_per_img = detections_per_img 96 | 97 | def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): 98 | # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]] 99 | matched_idxs = [] 100 | labels = [] 101 | for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels): 102 | # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands 103 | match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image) 104 | matched_idxs_in_image = self.proposal_matcher(match_quality_matrix) 105 | 106 | clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0) 107 | 108 | labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image] 109 | labels_in_image = labels_in_image.to(dtype=torch.int64) 110 | 111 | # Label background (below the low threshold) 112 | bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD 113 | labels_in_image[bg_inds] = torch.tensor(0, device=proposals_in_image.device) 114 | 115 | # Label ignore proposals (between low and high thresholds) 116 | ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS 117 | # -1 is ignored by sampler 118 | labels_in_image[ignore_inds] = torch.tensor(-1, device=proposals_in_image.device) 119 | 120 | matched_idxs.append(clamped_matched_idxs_in_image) 121 | labels.append(labels_in_image) 122 | return matched_idxs, labels 123 | 124 | def subsample(self, labels): 125 | # type: (List[Tensor]) -> List[Tensor] 126 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) 127 | sampled_inds = [] 128 | for img_idx, (pos_inds_img, neg_inds_img) in enumerate( 129 | zip(sampled_pos_inds, sampled_neg_inds) 130 | ): 131 | img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) 132 | sampled_inds.append(img_sampled_inds) 133 | return sampled_inds 134 | 135 | def add_gt_proposals(self, proposals, gt_boxes): 136 | # type: (List[Tensor], List[Tensor]) -> List[Tensor] 137 | proposals = [ 138 | torch.cat((proposal, gt_box)) 139 | for proposal, gt_box in zip(proposals, gt_boxes) 140 | ] 141 | 142 | return proposals 143 | 144 | def check_targets(self, targets): 145 | # type: (Optional[List[Dict[str, Tensor]]]) -> None 146 | assert targets is not None 147 | assert all(["boxes" in t for t in targets]) 148 | assert all(["labels" in t for t in targets]) 149 | 150 | def select_training_samples(self, 151 | proposals, # type: List[Tensor] 152 | targets # type: Optional[List[Dict[str, Tensor]]] 153 | ): 154 | # type: (...) 
-> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]] 155 | self.check_targets(targets) 156 | assert targets is not None 157 | dtype = proposals[0].dtype 158 | gt_boxes = [t["boxes"].to(dtype) for t in targets] 159 | gt_labels = [t["labels"] for t in targets] 160 | 161 | # append ground-truth bboxes to propos 162 | proposals = self.add_gt_proposals(proposals, gt_boxes) 163 | 164 | # get matching gt indices for each proposal 165 | matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) 166 | # sample a fixed proportion of positive-negative proposals 167 | sampled_inds = self.subsample(labels) 168 | matched_gt_boxes = [] 169 | num_images = len(proposals) 170 | for img_id in range(num_images): 171 | img_sampled_inds = sampled_inds[img_id] 172 | proposals[img_id] = proposals[img_id][img_sampled_inds] 173 | labels[img_id] = labels[img_id][img_sampled_inds] 174 | matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds] 175 | matched_gt_boxes.append(gt_boxes[img_id][matched_idxs[img_id]]) 176 | 177 | regression_targets = self.box_coder.encode(matched_gt_boxes, proposals) 178 | return proposals, matched_idxs, labels, regression_targets 179 | 180 | def postprocess_detections(self, 181 | class_logits, # type: Tensor 182 | box_regression, # type: Tensor 183 | proposals, # type: List[Tensor] 184 | image_shapes # type: List[Tuple[int, int]] 185 | ): 186 | # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]] 187 | device = class_logits.device 188 | num_classes = class_logits.shape[-1] 189 | 190 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] 191 | pred_boxes = self.box_coder.decode(box_regression, proposals) 192 | 193 | pred_scores = F.softmax(class_logits, -1) 194 | 195 | # split boxes and scores per image 196 | if len(boxes_per_image) == 1: 197 | # TODO : remove this when ONNX support dynamic split sizes 198 | # and just assign to pred_boxes instead of pred_boxes_list 199 | pred_boxes_list = [pred_boxes] 200 | pred_scores_list = [pred_scores] 201 | else: 202 | pred_boxes_list = pred_boxes.split(boxes_per_image, 0) 203 | pred_scores_list = pred_scores.split(boxes_per_image, 0) 204 | 205 | all_boxes = [] 206 | all_scores = [] 207 | all_labels = [] 208 | for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes): 209 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 210 | 211 | # create labels for each prediction 212 | labels = torch.arange(num_classes, device=device) 213 | labels = labels.view(1, -1).expand_as(scores) 214 | 215 | # Remove predictions with the background label 216 | boxes = boxes[:, 1:] 217 | scores = scores[:, 1:] 218 | labels = labels[:, 1:] 219 | 220 | # batch everything, by making every class prediction be a separate instance 221 | boxes = boxes.reshape(-1, 4) 222 | scores = scores.reshape(-1) 223 | labels = labels.reshape(-1) 224 | 225 | # remove low scoring boxes 226 | inds = torch.nonzero(scores > self.score_thresh).squeeze(1) 227 | boxes, scores, labels = boxes[inds], scores[inds], labels[inds] 228 | 229 | # remove empty boxes 230 | keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) 231 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 232 | 233 | # non-maximum suppression, independently done per class 234 | keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) 235 | # keep only topk scoring predictions 236 | keep = keep[:self.detections_per_img] 237 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 238 | 239 | 
all_boxes.append(boxes) 240 | all_scores.append(scores) 241 | all_labels.append(labels) 242 | 243 | return all_boxes, all_scores, all_labels 244 | 245 | def forward(self, 246 | features, # type: Dict[str, Tensor] 247 | proposals, # type: List[Tensor] 248 | image_shapes, # type: List[Tuple[int, int]] 249 | targets=None # type: Optional[List[Dict[str, Tensor]]] 250 | ): 251 | # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] 252 | """ 253 | Arguments: 254 | features (List[Tensor]) 255 | proposals (List[Tensor[N, 4]]) 256 | image_shapes (List[Tuple[H, W]]) 257 | targets (List[Dict]) 258 | """ 259 | if targets is not None: 260 | for t in targets: 261 | # TODO: https://github.com/pytorch/pytorch/issues/26731 262 | floating_point_types = (torch.float, torch.double, torch.half) 263 | assert t["boxes"].dtype in floating_point_types, "target boxes must of float type" 264 | assert t["labels"].dtype == torch.int64, "target labels must of int64 type" 265 | if self.has_keypoint(): 266 | assert t["keypoints"].dtype == torch.float32, "target keypoints must of float type" 267 | 268 | if self.training: 269 | proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) 270 | else: 271 | labels = None 272 | regression_targets = None 273 | 274 | # Extract features from images 275 | # box_features from different images are concatenated if input has multiple images 276 | box_features = self.box_roi_pool(features, proposals, image_shapes) 277 | 278 | # Two MLP head 279 | box_features = self.box_head(box_features) 280 | # Prediction 281 | class_logits, box_regression = self.box_predictor(box_features) 282 | 283 | # Get prediction result with processing 284 | result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) 285 | pred_boxes, pred_scores, pred_labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) 286 | num_images = len(pred_boxes) 287 | for i in range(num_images): 288 | result.append( 289 | { 290 | "boxes": pred_boxes[i], 291 | "labels": pred_labels[i], 292 | "scores": pred_scores[i], 293 | } 294 | ) 295 | 296 | # Losses 297 | losses = {} 298 | if self.training: 299 | assert labels is not None and regression_targets is not None 300 | loss_classifier, loss_box_reg = fastrcnn_loss( 301 | class_logits, box_regression, labels, regression_targets) 302 | 303 | losses = { 304 | "loss_classifier": loss_classifier, 305 | "loss_box_reg": loss_box_reg, 306 | } 307 | 308 | return result, losses 309 | -------------------------------------------------------------------------------- /lib/tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/lib/tracking/__init__.py -------------------------------------------------------------------------------- /lib/tracking/detection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Detection(object): 5 | def __init__(self, box, score, embedding): 6 | self.box = box.detach() 7 | self.score = score 8 | self.embedding = embedding 9 | self.brand_new = True 10 | 11 | @property 12 | def ltwh(self): 13 | # Retrieve left, top, width, height from box (x1y1x2y2) 14 | ltwh = np.asarray(self.box.clone().detach().cpu().numpy()) 15 | ltwh[2:] -= ltwh[:2] 16 | return ltwh 17 | 18 | @property 19 | def avg_embedding(self): 20 | return self.embedding.mean(-1).mean(-1) 21 | 
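# Illustrative usage sketch (not part of the original file): a Detection wraps one
# filtered detector output. `ltwh` converts the x1y1x2y2 box into MOTChallenge's
# left-top-width-height format, and `avg_embedding` averages the ROI feature spatially.
# The 1x256x7x7 embedding shape below is an assumption taken from the memory module's
# default feature_size; the detector may produce a different shape.
# >>> import torch
# >>> det = Detection(torch.tensor([10., 20., 50., 80.]), 0.9, torch.zeros(1, 256, 7, 7))
# >>> det.ltwh                  # array([10., 20., 40., 60.])
# >>> det.avg_embedding.shape   # torch.Size([1, 256])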
-------------------------------------------------------------------------------- /lib/tracking/test_tracker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import itertools 5 | from ..configs.config import load_config 6 | from .tracker import OnlineTracker 7 | from ..dataset.mot import MOTTracking, collate_fn, get_seq_names 8 | from ..utils.image_processing import tensor_to_cv2 9 | from ..utils.visualization import plot_boxes 10 | from ..utils.log import get_logger, log_or_print 11 | from ..utils.timer import Timer 12 | from ..utils.official_benchmark import benchmark 13 | 14 | 15 | # Follows MOTChallenge data format, ltwh: left, top, width, height 16 | def write_results(filename, results, logger): 17 | with open(filename, "w") as f: 18 | for frame_id, ltwhs, track_ids in results: 19 | for ltwh, track_id in zip(ltwhs, track_ids): 20 | if track_id < 0: 21 | continue 22 | l, t, w, h = ltwh 23 | line = f"{frame_id},{track_id},{l},{t},{w},{h},1,-1,-1,-1\n" 24 | f.write(line) 25 | log_or_print(logger, f"Saved results to '{filename}'") 26 | 27 | 28 | def test_sequence( 29 | dataloader, 30 | config, 31 | logger, 32 | seq_result_file, 33 | plot_frames=False, 34 | ): 35 | tracker = OnlineTracker(config, logger) 36 | 37 | timer = Timer() 38 | results = [] 39 | log_or_print(logger, f"Started processing frames") 40 | for frame_id, batch in enumerate(dataloader): 41 | # Start from 1 42 | frame_id += 1 43 | 44 | # Get detections 45 | frame, det_boxes, _, _, _, _ = batch 46 | # 0 index due to collate function 47 | frame = frame[0].cuda() 48 | det_boxes = det_boxes[0].cuda() 49 | 50 | # Track objects in this frame 51 | timer.tic() 52 | online_targets = tracker.update(frame_id, frame, det_boxes) 53 | online_ltwhs = [] 54 | online_boxes = [] 55 | online_scores = [] 56 | online_ids = [] 57 | for t in online_targets: 58 | online_ltwhs.append(t.ltwh) 59 | online_boxes.append(t.box.detach().clone().cpu().numpy()) 60 | online_scores.append(t.score) 61 | online_ids.append(t.tracklet_id) 62 | timer.toc() 63 | 64 | # Log every 20 frames 65 | if frame_id % 20 == 0: 66 | log_or_print(logger, f"Processed frame {frame_id} ({1. / max(1e-5, timer.average_time):.2f} fps)") 67 | 68 | # Store results 69 | results.append((frame_id, online_ltwhs, online_ids)) 70 | 71 | # Visualize tracking results 72 | if plot_frames: 73 | cv2_frame = tensor_to_cv2(frame) 74 | image = plot_boxes(cv2_frame, online_boxes, obj_ids=online_ids, 75 | scores=online_scores, show_scores=True, 76 | image_scale=900 / cv2_frame.shape[0], show_info=True, 77 | frame_id=frame_id, fps=1. 
/ timer.average_time) 78 | cv2.imshow("Tracking", image) 79 | while True: 80 | # Wait for keys 81 | key = cv2.waitKey(0) 82 | # Quit, press "q" or "Esc" 83 | if key in [ord("q"), 27]: 84 | exit(0) 85 | # Next, press "space" 86 | elif key == 32: 87 | break 88 | 89 | # Write files 90 | write_results(seq_result_file, results, logger) 91 | 92 | # Release GPU memory 93 | del tracker 94 | 95 | 96 | def test( 97 | config, 98 | logger, 99 | dataset="MOT17", 100 | which_set="train", 101 | public_detection="all", 102 | sequence="all", 103 | result_name="TADAM_MOT17", 104 | evaluation=True, 105 | plot_frames=False, 106 | ): 107 | # Set directories 108 | result_folder = os.path.join(config.PATHS.RESULT_ROOT, dataset, result_name) 109 | if not os.path.isdir(result_folder): 110 | os.makedirs(result_folder) 111 | 112 | full_seq_names, seqs, pds = get_seq_names(dataset, which_set, public_detection, sequence) 113 | for seq_name, seq, pd in zip(full_seq_names, seqs, pds): 114 | dataloader = torch.utils.data.DataLoader( 115 | MOTTracking(config.PATHS.DATASET_ROOT, dataset, which_set, seq, pd), 116 | batch_size=1, shuffle=False, num_workers=2, collate_fn=collate_fn) 117 | seq_result_file = os.path.join(result_folder, f"{seq_name}.txt") 118 | log_or_print(logger, f"Sequence: {seq_name}") 119 | test_sequence(dataloader, config, logger, seq_result_file, plot_frames) 120 | 121 | # Use matlab code to evaluate training set 122 | if which_set == "train" and evaluation: 123 | log_or_print(logger, f"Starting Evaluation") 124 | benchmark(dataset, result_name, config.PATHS.EVAL_ROOT, config.PATHS.RESULT_ROOT, full_seq_names, logger) 125 | summary = os.path.join(config.PATHS.RESULT_ROOT, dataset, result_name, f"{result_name}_result_summary.txt") 126 | titles = [] 127 | values = [] 128 | with open(summary, "r") as f: 129 | titles = f.readline().split() 130 | values = f.readline().split() 131 | log_or_print(logger, f"Evaluation Summary") 132 | for i in range(len(titles) // 10 + 1): 133 | log_or_print(logger, "\t".join(titles[i * 10: min((i + 1) * 10, len(titles))])) 134 | log_or_print(logger, "\t".join(values[i * 10: min((i + 1) * 10, len(titles))])) 135 | 136 | 137 | if __name__ == "__main__": 138 | import argparse 139 | parser = argparse.ArgumentParser(description="MOT tracking") 140 | parser.add_argument("--result-name", default="TADAM_MOT17_train", type=str, help="name for saving results") 141 | parser.add_argument("--config", default=None, type=str, help="config file to be loaded") 142 | parser.add_argument("--which_set", default="train", type=str, 143 | choices=["train", "test"], help="which set to run on") 144 | parser.add_argument("--public-detection", default="all", choices=["all", "DPM", "FRCNN", "SDP"], 145 | type=str, help="test on specified public detection, valid for MOT17 only. default is all") 146 | parser.add_argument("--sequence", default="all", type=str, help="test on specified sequence. default is all") 147 | parser.add_argument("--evaluation", action="store_true", help="enable evaluation on results. 
requires matlab") 148 | parser.add_argument("--plot-frames", action="store_true", help="show frames of tracking") 149 | parser.add_argument("-v", "--verbose", action="store_true", help="Display details in console log") 150 | args = parser.parse_args() 151 | 152 | config, cfg_msg = load_config(args.config) 153 | logger = get_logger(name="global", save_file=True, overwrite_file=True, 154 | log_dir=os.path.join(config.PATHS.RESULT_ROOT, config.NAMES.DATASET, args.result_name), 155 | log_name=f"{args.result_name}", console_verbose=args.verbose) 156 | log_or_print(logger, cfg_msg) 157 | 158 | test(config, logger, dataset=config.NAMES.DATASET, which_set=args.which_set, 159 | public_detection=args.public_detection, sequence=args.sequence, 160 | result_name=args.result_name, evaluation=args.evaluation, plot_frames=args.plot_frames) 161 | -------------------------------------------------------------------------------- /lib/tracking/tracker.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | from torchvision.ops import nms, box_iou 5 | 6 | from ..modules.detector import Detector 7 | from .tracklet import Tracklet, TrackletList 8 | from .detection import Detection 9 | 10 | from ..utils.matching import iou, box_giou, reid_distance, linear_assignment 11 | from ..utils.image_processing import tensor_to_cv2, cmc_align 12 | from ..utils.log import log_or_print 13 | 14 | 15 | class OnlineTracker(object): 16 | def __init__(self, config, logger): 17 | self.config = config 18 | self.logger = logger 19 | 20 | # Model, in cuda by default 21 | self.detector = Detector(num_classes=2, tracking=True, 22 | config=self.config, logger=self.logger).cuda() 23 | 24 | # Tracklets 25 | self.active_tracklets = TrackletList("Active") 26 | self.lost_tracklets = TrackletList("Lost") 27 | 28 | # Record last frame for camera movement compensation 29 | self.last_frame = None 30 | 31 | def update(self, frame_id, frame, det_boxes): 32 | # Frame info 33 | log_or_print(self.logger, f"========== Frame {frame_id:4d} ==========", level="debug") 34 | frame_height, frame_width = frame.shape[1:] 35 | 36 | # ------------------------------ # 37 | # ------Process Detections------ # 38 | # ------------------------------ # 39 | 40 | detections = [] 41 | if len(det_boxes): # In case no detections available 42 | det_boxes, det_scores, det_embeddings = \ 43 | self.detector.predict_boxes( 44 | frame, 45 | det_boxes, 46 | prediction_type="detection" 47 | ) 48 | 49 | # Filter low scores 50 | high_score_keep = torch.ge(det_scores, self.config.TRACKING.MIN_SCORE_DETECTION) 51 | det_scores = det_scores[high_score_keep] 52 | det_boxes = det_boxes[high_score_keep] 53 | det_embeddings = det_embeddings[high_score_keep] 54 | 55 | if len(det_scores): # in case of empty 56 | # Apply nms to suppress close detections, especially for DPM detector 57 | indices = nms(det_boxes, det_scores, self.config.TRACKING.NMS_DETECTION).cpu().numpy() 58 | det_boxes = det_boxes[indices] 59 | det_scores = det_scores[indices] 60 | det_embeddings = det_embeddings[indices] 61 | 62 | if len(det_scores): # in case of empty 63 | # Add as detection objects 64 | for index in range(len(det_scores)): 65 | detections.append(Detection(det_boxes[index], det_scores[index].item(), 66 | embedding=det_embeddings[index].clone().detach().unsqueeze(0))) 67 | if len(detections): 68 | log_or_print(self.logger, f"{len(detections)} detections found after filtering", level="debug") 69 | 70 | # ---------------------------- # 
71 | # -------Tracklet Update------ # 72 | # ---------------------------- # 73 | 74 | # Camera movement compensation 75 | warp = None 76 | if self.last_frame is not None: 77 | warp = cmc_align(self.last_frame, tensor_to_cv2(frame)) 78 | for t in itertools.chain(self.active_tracklets, self.lost_tracklets): 79 | t.cmc_update(warp) 80 | 81 | # Box outside or overlap with edge. Do twice 82 | # First time here as CMC could move boxes 83 | # Second time later after prediction which also move boxes 84 | if len(self.active_tracklets): 85 | indices = [] 86 | for i, t in enumerate(self.active_tracklets): 87 | if t.box[0] > frame_width - self.config.TRACKING.MIN_BOX_SIZE or \ 88 | t.box[1] > frame_height - self.config.TRACKING.MIN_BOX_SIZE or \ 89 | t.box[2] < self.config.TRACKING.MIN_BOX_SIZE or \ 90 | t.box[3] < self.config.TRACKING.MIN_BOX_SIZE: 91 | indices.append(i) 92 | self.move_tracklets("Outside_cmc\t", 93 | [t for i, t in enumerate(self.active_tracklets) if i in indices], 94 | self.active_tracklets, self.lost_tracklets) 95 | 96 | # Prediction on active tracklets 97 | if len(self.active_tracklets): 98 | # Collect embeddings for target attention 99 | target_embeddings = torch.cat([t.embedding for i, t in enumerate(self.active_tracklets)]) 100 | target_bools = torch.ones(len(self.active_tracklets), device=target_embeddings.device, dtype=torch.bool) 101 | 102 | # Collect embeddings for distractor attention 103 | distractor_bools, distractor_embeddings, distractor_ious = \ 104 | self.collect_distractors( 105 | self.active_tracklets, 106 | self.active_tracklets, # GS 107 | self.config.TRACKING.MIN_OVERLAP_AS_DISTRACTOR 108 | ) 109 | 110 | # Prediction 111 | t_boxes = torch.stack([t.box for t in self.active_tracklets]) 112 | t_boxes, t_scores, t_embeddings = \ 113 | self.detector.predict_boxes( 114 | frame, 115 | t_boxes, 116 | prediction_type="tracklet", 117 | box_ids=torch.tensor([t.tracklet_id for t in self.active_tracklets]).cuda(), 118 | target_bools=target_bools, 119 | target_embeddings=target_embeddings, 120 | distractor_ious=distractor_ious, 121 | distractor_bools=distractor_bools, 122 | distractor_embeddings=distractor_embeddings 123 | ) 124 | 125 | # Filter by scores 126 | high_score_keep = torch.ge(t_scores, self.config.TRACKING.MIN_SCORE_ACTIVE_TRACKLET) 127 | # Update with new position for high score ones 128 | for i, t in enumerate(self.active_tracklets): 129 | if high_score_keep[i]: 130 | t.update(t_boxes[i, :], t_scores[i].item(), 131 | t_embeddings[i].clone().detach().unsqueeze(0)) 132 | log_or_print(self.logger, f"Updated\t\t{t.tracklet_id}", level="debug") 133 | # Send low score ones to lost list 134 | self.move_tracklets("Low class score", 135 | [t for i, t in enumerate(self.active_tracklets) if not high_score_keep[i]], 136 | self.active_tracklets, self.lost_tracklets) 137 | 138 | # NMS 139 | if len(self.active_tracklets): # In case no active tracklets after filtering 140 | scores = torch.tensor([t.score for t in self.active_tracklets]).cuda() 141 | boxes = torch.stack([t.box for t in self.active_tracklets]) 142 | indices = nms(boxes, scores, self.config.TRACKING.NMS_ACTIVE_TRACKLET).cpu().numpy() 143 | self.move_tracklets("NMS\t", 144 | [t for i, t in enumerate(self.active_tracklets) if i not in indices], 145 | self.active_tracklets, self.lost_tracklets) 146 | 147 | # Box outside or overlap with edge, second time 148 | if len(self.active_tracklets): 149 | indices = [] 150 | for i, t in enumerate(self.active_tracklets): 151 | if t.box[0] > frame_width - 
self.config.TRACKING.MIN_BOX_SIZE or \ 152 | t.box[1] > frame_height - self.config.TRACKING.MIN_BOX_SIZE or \ 153 | t.box[2] < self.config.TRACKING.MIN_BOX_SIZE or \ 154 | t.box[3] < self.config.TRACKING.MIN_BOX_SIZE: 155 | indices.append(i) 156 | self.move_tracklets("Outside_pred", 157 | [t for i, t in enumerate(self.active_tracklets) if i in indices], 158 | self.active_tracklets, self.lost_tracklets) 159 | 160 | # Remove ones with too small boxes 161 | if len(self.active_tracklets): # In case no active tracklets after filtering 162 | indices = [] 163 | for i, t in enumerate(self.active_tracklets): 164 | if t.ltwh[2] < self.config.TRACKING.MIN_BOX_SIZE or t.ltwh[3] < self.config.TRACKING.MIN_BOX_SIZE: 165 | indices.append(i) 166 | self.move_tracklets("Too small\t", 167 | [t for i, t in enumerate(self.active_tracklets) if i in indices], 168 | self.active_tracklets, self.lost_tracklets) 169 | 170 | # Remove boxes too close to edge with only a narrow visible region 171 | # Necessary for MOT16&17, as gt annotation boxes could go outside 172 | # Trained results follow same pattern, thus leave narrow boxes at edges 173 | if len(self.active_tracklets): # In case no active tracklets after filtering 174 | indices = [] 175 | for i, t in enumerate(self.active_tracklets): 176 | _, _, w, h = t.ltwh 177 | min_ratio = 0.25 178 | if t.box[0] > frame_width - w * min_ratio or \ 179 | t.box[1] > frame_height - h * min_ratio or \ 180 | t.box[2] < w * min_ratio or \ 181 | t.box[3] < h * min_ratio: 182 | indices.append(i) 183 | self.move_tracklets("Edge\t\t", 184 | [t for i, t in enumerate(self.active_tracklets) if i in indices], 185 | self.active_tracklets, self.lost_tracklets) 186 | 187 | # -------------------- # 188 | # ------Matching------ # 189 | # -------------------- # 190 | 191 | if len(detections): # In case no detections 192 | # Check if a detection is covered by an active tracklet, simplified matching 193 | for det in detections: 194 | for t in self.active_tracklets: 195 | if iou(det.ltwh, t.ltwh) > self.config.TRACKING.NMS_DETECTION: 196 | det.brand_new = False 197 | break 198 | detections = [det for det in detections if det.brand_new] 199 | 200 | # Matching between lost tracklets and detections, then recover ones matched 201 | if len(self.lost_tracklets) and len(detections): # In case no detections/lost 202 | # Use ReID as cost 203 | t_boxes = torch.stack([t.box for t in self.lost_tracklets]) 204 | det_boxes = torch.stack([d.box for d in detections]) 205 | # Use GIoU to prevent matching too far away 206 | giou_matrix = box_giou(t_boxes, det_boxes).cpu().numpy() 207 | # Use id embedding similarity as cost 208 | cost_matrix = reid_distance(self.lost_tracklets, detections) 209 | # Filter out those GIoU unqualified (set cost to 1) 210 | cost_matrix[giou_matrix < self.config.TRACKING.MIN_RECOVER_GIOU] = 1. 
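# Net effect: lost tracklets are matched to detections by appearance (ReID distance), while pairs whose GIoU falls below MIN_RECOVER_GIOU get cost 1 and are effectively excluded from recovery.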
211 | # Linear assignment for matching 212 | matches, _, unassigned_detection_indices = linear_assignment(cost_matrix, 213 | threshold=1 - self.config.TRACKING.MIN_RECOVER_SCORE) 214 | # Matched with a detection 215 | recover_t_indices = [] 216 | recover_det_indices = [] 217 | for matched_t_index, matched_det_index in matches: 218 | t = self.lost_tracklets[matched_t_index] 219 | det = detections[matched_det_index] 220 | recover_t_indices.append(matched_t_index) 221 | recover_det_indices.append(matched_det_index) 222 | t.recover(frame_id, det.box, det.embedding.clone().detach()) 223 | if len(recover_t_indices): 224 | self.move_tracklets("Recovered\t", [t for i, t in enumerate(self.lost_tracklets) 225 | if i in recover_t_indices], self.lost_tracklets, self.active_tracklets) 226 | # Remove matched detections 227 | detections = [det for i, det in enumerate(detections) if i not in recover_det_indices] 228 | 229 | # Initiate new tracklets from unmatched new detections 230 | for i, d in enumerate(detections): 231 | t = Tracklet(frame_id, d.box, self.config, 232 | embedding=d.embedding.clone().detach(), 233 | memory_net=self.detector.memory_net) 234 | log_or_print(self.logger, f"New\t\t\t{t.tracklet_id}", level="debug") 235 | self.active_tracklets.append(t) 236 | 237 | # Update info 238 | for t in self.active_tracklets: 239 | t.update_active_info(frame_id) 240 | for t in self.lost_tracklets: 241 | t.update_lost_info() 242 | 243 | # Remove long lost tracklets 244 | for t in self.lost_tracklets: 245 | if t.time_since_update >= t.max_lost_frames: 246 | log_or_print(self.logger, f"Removed\t{t.tracklet_id}", level="debug") 247 | self.lost_tracklets.remove(t) 248 | del t 249 | 250 | log_or_print(self.logger, f"Active Tracklets\t{sorted([t.tracklet_id for t in self.active_tracklets])}", level="debug") 251 | log_or_print(self.logger, f"Lost Tracklets\t{sorted([t.tracklet_id for t in self.lost_tracklets])}", level="debug") 252 | 253 | self.last_frame = tensor_to_cv2(frame) 254 | torch.cuda.empty_cache() 255 | 256 | return self.active_tracklets 257 | 258 | def move_tracklets(self, reason, tracklets, source, target): 259 | for t in tracklets: 260 | if t in source: 261 | log_or_print(self.logger, f"{reason}\t{t.tracklet_id}\t{source.name} ==> {target.name}", level="debug") 262 | source.remove(t) 263 | target.append(t) 264 | 265 | # Potential distractors input includes target themselves 266 | def collect_distractors(self, target_tracklets, distractor_tracklets, min_distractor_overlap=0.2): 267 | zero_bools = torch.zeros(len(target_tracklets), 268 | device=target_tracklets[0].box.device, dtype=torch.bool) 269 | zero_embeddings = torch.zeros([len(distractor_tracklets)] + list(target_tracklets[0].embedding.size()[1:]), 270 | device=target_tracklets[0].embedding.device, dtype=target_tracklets[0].embedding.dtype) 271 | zero_ious = torch.zeros(len(distractor_tracklets), 272 | device=target_tracklets[0].box.device, dtype=target_tracklets[0].box.dtype) 273 | # No need to compute if only one tracklet exists 274 | if len(distractor_tracklets) < 2: 275 | return zero_bools, zero_embeddings, zero_ious 276 | # Match by IoU 277 | t_boxes = torch.stack([t.box for t in target_tracklets]) 278 | d_boxes = torch.stack([d.box for d in distractor_tracklets]) 279 | iou_matrix = box_iou(t_boxes, d_boxes) 280 | top_vals, top_matches = iou_matrix.topk(k=2, dim=1) 281 | # Distractors are ones with second largest IoU 282 | # As we pass same lists for target and distractor, the one with largest IoU is always itself 283 | distractor_ious = 
top_vals[:, 1] 284 | distractor_bools = top_vals[:, 1] >= min_distractor_overlap 285 | distractor_indices = top_matches[:, 1][distractor_bools] 286 | distractor_embeddings = [] 287 | for d_index in distractor_indices: 288 | distractor_embeddings.append(distractor_tracklets[d_index.item()].embedding) 289 | if len(distractor_embeddings): 290 | distractor_embeddings = torch.cat(distractor_embeddings) 291 | return distractor_bools, distractor_embeddings, distractor_ious 292 | # If all distractors do no qualify the min_distractor_overlap 293 | else: 294 | return zero_bools, zero_embeddings, distractor_ious 295 | 296 | # Release GPU memory 297 | def __del__(self): 298 | if "active_tracklets" in self.__dict__: 299 | del self.active_tracklets 300 | if "lost_tracklets" in self.__dict__: 301 | del self.lost_tracklets 302 | torch.cuda.empty_cache() 303 | -------------------------------------------------------------------------------- /lib/tracking/tracklet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..utils.kalman_filter import KalmanFilter 5 | from ..modules.memory import Memory 6 | 7 | 8 | class TrackletList(list): 9 | def __init__(self, name): 10 | self.name = name 11 | 12 | 13 | class Tracklet(object): 14 | # Shared among all instances, not accessible outside 15 | _count = 0 16 | 17 | def __init__( 18 | self, 19 | frame_id, 20 | box, 21 | config, 22 | memory_net=None, 23 | embedding=None, 24 | training=False, 25 | ): 26 | # Initialization with default 27 | self.tracklet_id = self.next_id() 28 | self.tracklet_len = 0 29 | self.time_since_update = 0 30 | self.score = 1.0 31 | self.frame_id = frame_id 32 | self.start_frame = frame_id 33 | 34 | # Position 35 | self.box = box.detach() 36 | 37 | # Kalman filter 38 | self.kalman_filter = KalmanFilter(dim=4) 39 | self.kalman_filter.initiate(self.box.cpu().numpy()) 40 | 41 | # Config 42 | self.config = config 43 | 44 | # Memory aggregation 45 | assert memory_net is not None, "MemoryNet not passed" 46 | self.memory_net = memory_net 47 | if training: 48 | self.memory_train_input = [embedding] 49 | self.memory = Memory(embedding, memory_net) 50 | 51 | # Lost properties 52 | self.max_lost_frames = config.TRACKING.MAX_LOST_FRAMES_BEFORE_REMOVE 53 | 54 | # For training 55 | self.train_update_count = 0 56 | 57 | @staticmethod 58 | def next_id(): 59 | Tracklet._count += 1 60 | return Tracklet._count 61 | 62 | @property 63 | def ltwh(self): 64 | # Retrieve left, top, width, height from box (x1y1x2y2) 65 | ltwh = np.asarray(self.box.clone().detach().cpu().numpy()) 66 | ltwh[2:] -= ltwh[:2] 67 | return ltwh 68 | 69 | def cmc_update(self, warp): 70 | """ 71 | warp: affine transform matrix, np.array or None (no ECC) 72 | """ 73 | warp_tensor = torch.tensor(warp, dtype=self.box.dtype, device=self.box.device) 74 | p1 = torch.tensor([self.box[0], self.box[1], 1], dtype=self.box.dtype, device=self.box.device).view(3, 1) 75 | p2 = torch.tensor([self.box[2], self.box[3], 1], dtype=self.box.dtype, device=self.box.device).view(3, 1) 76 | p1_n = torch.mm(warp_tensor, p1).view(1, 2) 77 | p2_n = torch.mm(warp_tensor, p2).view(1, 2) 78 | box = torch.cat((p1_n, p2_n), 1).view(1, -1).squeeze(0) 79 | self.update(box) 80 | 81 | def update(self, box, score=None, embedding=None): 82 | # Prevent abnormal aspect ratios 83 | aspect_ratio = (box[3] - box[1]) / (box[2] - box[0]) 84 | # Keep center unchanged, use original width & height 85 | if aspect_ratio < 1.0 or aspect_ratio > 4.0: 86 | cx = 
box[0] + (box[2] - box[0]) / 2 87 | cy = box[1] + (box[3] - box[1]) / 2 88 | w = self.box[2] - self.box[0] 89 | h = self.box[3] - self.box[1] 90 | box = torch.tensor([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], 91 | dtype=self.box.dtype, device=self.box.device) 92 | self.box = box 93 | 94 | # Update kalman filter 95 | self.kalman_filter.predict() 96 | self.kalman_filter.update(box.cpu().numpy()) 97 | 98 | if score is not None: 99 | self.score = score 100 | if embedding is not None: 101 | self.update_embedding(embedding) 102 | 103 | def predict(self): 104 | self.box = torch.tensor(self.kalman_filter.predict(), dtype=self.box.dtype, device=self.box.device) 105 | 106 | def recover(self, frame_id, box, embedding): 107 | self.frame_id = frame_id 108 | self.box = box.detach() 109 | self.time_since_update = 0 110 | self.update_embedding(embedding) 111 | 112 | # Reset prediction 113 | self.kalman_filter.initiate(self.box.cpu().numpy()) 114 | 115 | # Reset score 116 | self.score = 1.0 117 | 118 | @property 119 | def avg_embedding(self): 120 | return self.embedding.mean(-1).mean(-1) 121 | 122 | @property 123 | def embedding(self): 124 | return self.memory.h_state 125 | 126 | def update_embedding(self, new_embedding, training=False): 127 | # Reset for training when updated for 4 times 128 | if training and self.train_update_count >= 4: 129 | self.memory = Memory(new_embedding, self.memory_net) 130 | self.memory_train_input = [new_embedding] 131 | self.train_update_count = 0 132 | else: 133 | # For tracking 134 | if not training: 135 | self.memory.update(new_embedding, self.memory_net) 136 | # For training 137 | else: 138 | self.memory_train_input.append(new_embedding) 139 | self.memory.train_update(self.memory_train_input, self.memory_net) 140 | self.train_update_count += 1 141 | 142 | def update_active_info(self, frame_id): 143 | # Except for new tracklet 144 | if frame_id > self.frame_id: 145 | self.frame_id = frame_id 146 | self.time_since_update = 0 147 | self.tracklet_len += 1 148 | 149 | def update_lost_info(self): 150 | self.time_since_update += 1 151 | 152 | def __del__(self): 153 | if "memory" in self.__dict__: 154 | del self.memory 155 | # Release GPU memory 156 | torch.cuda.empty_cache() 157 | -------------------------------------------------------------------------------- /lib/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/lib/training/__init__.py -------------------------------------------------------------------------------- /lib/training/group_by_aspect_ratio.py: -------------------------------------------------------------------------------- 1 | # Adapted from torchvision 2 | 3 | import bisect 4 | from collections import defaultdict 5 | import copy 6 | import numpy as np 7 | 8 | from torch.utils.data.sampler import BatchSampler, Sampler 9 | 10 | 11 | class GroupedBatchSampler(BatchSampler): 12 | """ 13 | Wraps another sampler to yield a mini-batch of indices. 14 | It enforces that the batch only contain elements from the same group. 15 | It also tries to provide mini-batches which follows an ordering which is 16 | as close as possible to the ordering from the original sampler. 17 | Arguments: 18 | sampler (Sampler): Base sampler. 19 | group_ids (list[int]): If the sampler produces indices in range [0, N), 20 | `group_ids` must be a list of `N` ints which contains the group id of each sample. 
21 | The group ids must be a continuous set of integers starting from 22 | 0, i.e. they must be in the range [0, num_groups). 23 | batch_size (int): Size of mini-batch. 24 | """ 25 | def __init__(self, sampler, group_ids, batch_size): 26 | if not isinstance(sampler, Sampler): 27 | raise ValueError( 28 | "sampler should be an instance of " 29 | f"torch.utils.data.Sampler, but got sampler={sampler}" 30 | ) 31 | self.sampler = sampler 32 | self.group_ids = group_ids 33 | self.batch_size = batch_size 34 | 35 | def __iter__(self): 36 | buffer_per_group = defaultdict(list) 37 | samples_per_group = defaultdict(list) 38 | 39 | num_batches = 0 40 | for idx in self.sampler: 41 | group_id = self.group_ids[idx] 42 | buffer_per_group[group_id].append(idx) 43 | samples_per_group[group_id].append(idx) 44 | if len(buffer_per_group[group_id]) == self.batch_size: 45 | yield buffer_per_group[group_id] 46 | num_batches += 1 47 | del buffer_per_group[group_id] 48 | assert len(buffer_per_group[group_id]) < self.batch_size 49 | 50 | # now we have run out of elements that satisfy 51 | # the group criteria, let's return the remaining 52 | # elements so that the size of the sampler is 53 | # deterministic 54 | expected_num_batches = len(self) 55 | num_remaining = expected_num_batches - num_batches 56 | if num_remaining > 0: 57 | # for the remaining batches, take first the buffers with largest number 58 | # of elements 59 | for group_id, _ in sorted(buffer_per_group.items(), 60 | key=lambda x: len(x[1]), reverse=True): 61 | remaining = self.batch_size - len(buffer_per_group[group_id]) 62 | buffer_per_group[group_id].extend( 63 | samples_per_group[group_id][:remaining]) 64 | assert len(buffer_per_group[group_id]) == self.batch_size 65 | yield buffer_per_group[group_id] 66 | num_remaining -= 1 67 | if num_remaining == 0: 68 | break 69 | assert num_remaining == 0 70 | 71 | def __len__(self): 72 | return len(self.sampler) // self.batch_size 73 | 74 | 75 | def _quantize(x, bins): 76 | bins = copy.deepcopy(bins) 77 | bins = sorted(bins) 78 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 79 | return quantized 80 | 81 | 82 | def create_aspect_ratio_groups(dataset, k=0): 83 | bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0] 84 | groups = _quantize(dataset._aspect_ratios, bins) 85 | # count number of elements per group 86 | counts = np.unique(groups, return_counts=True)[1] 87 | fbins = [0] + bins + [np.inf] 88 | print(f"Using {fbins} as bins for aspect ratio quantization") 89 | print(f"Count of instances per bin: {counts}") 90 | return groups 91 | -------------------------------------------------------------------------------- /lib/training/train.py: -------------------------------------------------------------------------------- 1 | r"""PyTorch Detection Training. 2 | 3 | To run in a multi-gpu environment, use the distributed launcher:: 4 | 5 | python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \ 6 | train.py ... 7 | 8 | Example: 9 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=2 --use_env \ 10 | -m lib.training.train TADAM_MOT17 --config TADAM_MOT17 11 | 12 | Note: 13 | If conducting multiple mgpu training on the same machine, add --master_port=$any_number in first line 14 | Use different numbers for trainings. 
This is to avoid setup conflict between different trainings 15 | """ 16 | 17 | import os 18 | import sys 19 | import time 20 | import datetime 21 | import random 22 | import math 23 | import numpy as np 24 | import torch 25 | 26 | from ..modules.detector import Detector 27 | from ..dataset.mot import MOTDetection, collate_fn 28 | from .train_utils import init_distributed_mode, get_rank, get_transform, \ 29 | MetricLogger, SmoothedValue, reduce_dict, save_on_master 30 | from ..configs.config import load_config 31 | from ..utils.log import get_logger, log_or_print 32 | 33 | 34 | def train_mot(training_name, save_dir, config, logger, is_distributed=False): 35 | # Deterministic 36 | seed = config.TRAINING.RANDOM_SEED 37 | random.seed(seed) 38 | np.random.seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | 41 | # Get current device 42 | # Operations other than training should only happen on device 0 43 | current_device = get_rank() 44 | 45 | # Dataset 46 | dataset_train = MOTDetection( 47 | root=config.PATHS.DATASET_ROOT, 48 | dataset=config.NAMES.DATASET if config.NAMES.DATASET != "MOT17" else "MOT17Det", 49 | transforms=get_transform(train=True), 50 | vis_threshold=config.TRAINING.VIS_THRESHOLD, 51 | ) 52 | # Dataloader 53 | if is_distributed: 54 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train) 55 | else: 56 | train_sampler = torch.utils.data.RandomSampler(dataset_train) 57 | train_batch_sampler = torch.utils.data.BatchSampler( 58 | train_sampler, config.TRAINING.BATCH_SIZE, drop_last=True) 59 | data_loader_train = torch.utils.data.DataLoader( 60 | dataset_train, 61 | batch_sampler=train_batch_sampler, 62 | num_workers=config.TRAINING.WORKERS, 63 | collate_fn=collate_fn 64 | ) 65 | 66 | # Create model and load checkpoint 67 | log_or_print(logger, f"Creating model on device #{current_device}") 68 | model = Detector( 69 | config, 70 | num_classes=2, 71 | num_ids=dataset_train.num_ids, 72 | tracking=False, 73 | logger=logger 74 | ) 75 | device = "cuda" 76 | model.to(device) 77 | # For distributed training 78 | if is_distributed: 79 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[current_device]) 80 | 81 | # Warmup epochs 82 | total_warmup_epochs = config.TRAINING.WARMUP_EPOCHS 83 | remaining_warmup_epochs = total_warmup_epochs 84 | 85 | # Optimizer 86 | if config.TRAINING.WARMUP_EPOCHS > 0: 87 | # Set lr for id warmup components separately 88 | id_warmup_lr_list = ["roi_heads.id_module", "memory_net"] 89 | base_params = [] 90 | warmup_params = [] 91 | for name, p in model.named_parameters(): 92 | if p.requires_grad: 93 | in_warmup_list = False 94 | for w_n in id_warmup_lr_list: 95 | if name.startswith(w_n): 96 | in_warmup_list = True 97 | if in_warmup_list: 98 | warmup_params.append(p) 99 | else: 100 | base_params.append(p) 101 | params = [ 102 | {"params": base_params}, 103 | {"params": warmup_params, "lr": config.TRAINING.WARMUP_LR} 104 | ] 105 | else: 106 | params = [p for p in model.parameters() if p.requires_grad] 107 | optimizer = torch.optim.SGD(params, lr=config.TRAINING.LR, 108 | momentum=config.TRAINING.MOMENTUM, 109 | weight_decay=config.TRAINING.WEIGHT_DECAY) 110 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 111 | step_size=config.TRAINING.LR_STEP_SIZE, gamma=config.TRAINING.LR_GAMMA) 112 | 113 | # Ready for training 114 | if is_distributed: 115 | log_or_print(logger, f"Multiple GPU training on device #{current_device}") 116 | else: 117 | log_or_print(logger, f"Single GPU training on device #{current_device}") 118 
| 119 | # Train 120 | start_time = time.time() 121 | # Epoch starts from 1 for easier understanding 122 | for epoch in range(1, 1 + config.TRAINING.EPOCHS + total_warmup_epochs): 123 | # Sync 124 | if is_distributed: 125 | train_sampler.set_epoch(epoch) 126 | torch.cuda.synchronize() 127 | # Reset lr for formal training 128 | if epoch == total_warmup_epochs + 1: 129 | for g in optimizer.param_groups: 130 | g["lr"] = config.TRAINING.LR 131 | 132 | # Initialize tracklets for the epoch, for identity training 133 | if is_distributed: 134 | model.module.all_tracklets_dict = {} 135 | else: 136 | model.all_tracklets_dict = {} 137 | 138 | model.train() 139 | metric_logger = MetricLogger(delimiter=" ", logger=logger) 140 | metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) 141 | header = f"device: [{current_device}] {'warmup ' if remaining_warmup_epochs > 0 else ''}" + \ 142 | f"epoch: [{epoch if remaining_warmup_epochs > 0 else epoch - total_warmup_epochs:2d}/" + \ 143 | f"{total_warmup_epochs if remaining_warmup_epochs > 0 else config.TRAINING.EPOCHS:2d}]" 144 | 145 | for step, (images, targets) in metric_logger.log_every(data_loader_train, config.TRAINING.PRINT_FREQ, header): 146 | step += 1 # Starts with 1 for easier understanding 147 | 148 | # Move to cuda 149 | images = list(image.to(device) for image in images) 150 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 151 | 152 | # Loss 153 | if is_distributed: 154 | loss_dict = model.module.custom_train(images, targets, 155 | warmup=remaining_warmup_epochs > 0) 156 | else: 157 | loss_dict = model.custom_train(images, targets, 158 | warmup=remaining_warmup_epochs > 0) 159 | losses = sum(loss for loss in loss_dict.values()) 160 | 161 | # Reduce losses over all GPUs for logging purposes 162 | loss_dict_reduced = reduce_dict(loss_dict) 163 | 164 | # Overall loss 165 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 166 | loss_value = losses_reduced.item() 167 | 168 | # Detect loss explosion 169 | if not math.isfinite(loss_value): 170 | logger.info(f"Loss is {loss_value}, stopping training") 171 | logger.info(f"Last loss {loss_dict_reduced}") 172 | sys.exit(1) 173 | 174 | optimizer.zero_grad() 175 | losses.backward() 176 | # Clip gradients in warmup only, may happen if lr is large 177 | if remaining_warmup_epochs > 0: 178 | torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0) 179 | optimizer.step() 180 | 181 | # Update log 182 | if remaining_warmup_epochs > 0: 183 | metric_logger.update(**{"lr_warmup": optimizer.param_groups[1]["lr"]}) 184 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 185 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 186 | 187 | # No lr_scheduler stepping in warmup 188 | if remaining_warmup_epochs > 0: 189 | # Do not step lr during warmup 190 | remaining_warmup_epochs -= 1 191 | # Save checkpoint at end of warmup 192 | if remaining_warmup_epochs == 0: 193 | # Save checkpoint, only in main process 194 | save_on_master( 195 | {"state_dict": model.module.state_dict() if is_distributed else model.state_dict()}, 196 | os.path.join(save_dir, f"checkpoint_{training_name}_warmup_epoch_{total_warmup_epochs}.pth") 197 | ) 198 | # End of warmup 199 | else: 200 | # Move forward in lr_scheduler 201 | lr_scheduler.step() 202 | # Save checkpoints, at SAVE_FREQ or last epoch 203 | if (epoch - total_warmup_epochs) % config.TRAINING.SAVE_FREQ == 0 or \ 204 | (epoch - total_warmup_epochs) == config.TRAINING.EPOCHS: 205 | # Save checkpoint, only in 
main process 206 | save_on_master( 207 | {"state_dict": model.module.state_dict() if is_distributed else model.state_dict()}, 208 | os.path.join(save_dir, f"checkpoint_{training_name}_epoch_{epoch - total_warmup_epochs}.pth") 209 | ) 210 | 211 | # Clean up 212 | if is_distributed: 213 | torch.distributed.destroy_process_group() 214 | 215 | # Save final model to model root 216 | model_path = os.path.join(config.PATHS.MODEL_ROOT, f"{training_name}.pth") 217 | save_on_master({"state_dict": model.module.state_dict() if is_distributed else model.state_dict()}, model_path) 218 | log_or_print(logger, f"Saved trained model to '{model_path}'") 219 | 220 | # Log total time for training 221 | total_time = time.time() - start_time 222 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 223 | log_or_print(logger, f"Training time {total_time_str}") 224 | 225 | 226 | if __name__ == "__main__": 227 | import argparse 228 | parser = argparse.ArgumentParser(description="train on mot") 229 | parser.add_argument("name", help="name for training, required") 230 | parser.add_argument("--config", default="TADAM_MOT17", type=str, help="config file to load") 231 | parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") 232 | args = parser.parse_args() 233 | 234 | # Load config 235 | config, cfg_msg = load_config(args.config) 236 | # Create folder for output 237 | save_dir = os.path.join(config.PATHS.MODEL_ROOT, "checkpoints", args.name) 238 | if not os.path.isdir(save_dir): 239 | os.makedirs(save_dir) 240 | 241 | # Change default url in case multiple training in progress to avoid conflict 242 | if args.dist_url == "env://": 243 | args.dist_url += f"{args.name}_{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}" 244 | 245 | # Setup distributed training 246 | init_distributed_msg = init_distributed_mode(args) 247 | 248 | # Logger 249 | logger = get_logger(name="global", save_file=True, overwrite_file=True, 250 | log_dir=save_dir, log_name=f"{args.name}") 251 | log_or_print(logger, cfg_msg) 252 | log_or_print(logger, init_distributed_msg) 253 | 254 | train_mot(args.name, save_dir, config, logger, is_distributed=args.distributed) 255 | -------------------------------------------------------------------------------- /lib/training/train_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from torchvision 2 | 3 | import os 4 | import time 5 | import pickle 6 | import random 7 | import datetime 8 | from collections import defaultdict, deque 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from torchvision.transforms import functional as F 13 | 14 | from ..utils.log import log_or_print 15 | 16 | 17 | def get_transform(train): 18 | """ 19 | Returns transform to be applied upon dataloading 20 | """ 21 | transforms = [] 22 | # Convert PIL images to tensors 23 | transforms.append(ToTensor()) 24 | if train: 25 | # Randomly flip training images together with ground truth during training 26 | transforms.append(RandomHorizontalFlip(0.5)) 27 | return Compose(transforms) 28 | 29 | 30 | class ToTensor(object): 31 | def __call__(self, image, target): 32 | image = F.to_tensor(image) 33 | return image, target 34 | 35 | 36 | class RandomHorizontalFlip(object): 37 | def __init__(self, prob): 38 | self.prob = prob 39 | 40 | def __call__(self, image, target): 41 | if random.random() < self.prob: 42 | height, width = image.shape[-2:] 43 | image = image.flip(-1) 44 | bbox = target["boxes"] 45 | bbox[:, [0, 2]] = 
width - bbox[:, [2, 0]] 46 | target["boxes"] = bbox 47 | return image, target 48 | 49 | 50 | class Compose(object): 51 | def __init__(self, transforms): 52 | self.transforms = transforms 53 | 54 | def __call__(self, image, target): 55 | for t in self.transforms: 56 | image, target = t(image, target) 57 | return image, target 58 | 59 | 60 | class SmoothedValue(object): 61 | """Track a series of values and provide access to smoothed values over a 62 | window or the global series average. 63 | """ 64 | 65 | def __init__(self, window_size=20, fmt=None): 66 | if fmt is None: 67 | fmt = "{median:.4f} ({global_avg:.4f})" 68 | self.deque = deque(maxlen=window_size) 69 | self.total = 0.0 70 | self.count = 0 71 | self.fmt = fmt 72 | 73 | def update(self, value, n=1): 74 | self.deque.append(value) 75 | self.count += n 76 | self.total += value * n 77 | 78 | def synchronize_between_processes(self): 79 | """ 80 | Warning: does not synchronize the deque! 81 | """ 82 | if not is_dist_avail_and_initialized(): 83 | return 84 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 85 | dist.barrier() 86 | dist.all_reduce(t) 87 | t = t.tolist() 88 | self.count = int(t[0]) 89 | self.total = t[1] 90 | 91 | @property 92 | def median(self): 93 | d = torch.tensor(list(self.deque)) 94 | return d.median().item() 95 | 96 | @property 97 | def avg(self): 98 | d = torch.tensor(list(self.deque), dtype=torch.float32) 99 | return d.mean().item() 100 | 101 | @property 102 | def global_avg(self): 103 | return self.total / self.count 104 | 105 | @property 106 | def max(self): 107 | return max(self.deque) 108 | 109 | @property 110 | def value(self): 111 | return self.deque[-1] 112 | 113 | def __str__(self): 114 | return self.fmt.format( 115 | median=self.median, 116 | avg=self.avg, 117 | global_avg=self.global_avg, 118 | max=self.max, 119 | value=self.value) 120 | 121 | 122 | def all_gather(data): 123 | """ 124 | Run all_gather on arbitrary picklable data (not necessarily tensors) 125 | Args: 126 | data: any picklable object 127 | Returns: 128 | list[data]: list of data gathered from each rank 129 | """ 130 | world_size = get_world_size() 131 | if world_size == 1: 132 | return [data] 133 | 134 | # serialized to a Tensor 135 | buffer = pickle.dumps(data) 136 | storage = torch.ByteStorage.from_buffer(buffer) 137 | tensor = torch.ByteTensor(storage).to("cuda") 138 | 139 | # obtain Tensor size of each rank 140 | local_size = torch.tensor([tensor.numel()], device="cuda") 141 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 142 | dist.all_gather(size_list, local_size) 143 | size_list = [int(size.item()) for size in size_list] 144 | max_size = max(size_list) 145 | 146 | # receiving Tensor from all ranks 147 | # we pad the tensor because torch all_gather does not support 148 | # gathering tensors of different shapes 149 | tensor_list = [] 150 | for _ in size_list: 151 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 152 | if local_size != max_size: 153 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 154 | tensor = torch.cat((tensor, padding), dim=0) 155 | dist.all_gather(tensor_list, tensor) 156 | 157 | data_list = [] 158 | for size, tensor in zip(size_list, tensor_list): 159 | buffer = tensor.cpu().numpy().tobytes()[:size] 160 | data_list.append(pickle.loads(buffer)) 161 | 162 | return data_list 163 | 164 | 165 | def reduce_dict(input_dict, average=True): 166 | """ 167 | Args: 168 | input_dict (dict): all the 
values will be reduced 169 | average (bool): whether to do average or sum 170 | Reduce the values in the dictionary from all processes so that all processes 171 | have the averaged results. Returns a dict with the same fields as 172 | input_dict, after reduction. 173 | """ 174 | world_size = get_world_size() 175 | if world_size < 2: 176 | return input_dict 177 | with torch.no_grad(): 178 | names = [] 179 | values = [] 180 | # sort the keys so that they are consistent across processes 181 | for k in sorted(input_dict.keys()): 182 | names.append(k) 183 | values.append(input_dict[k]) 184 | values = torch.stack(values, dim=0) 185 | dist.all_reduce(values) 186 | if average: 187 | values /= world_size 188 | reduced_dict = {k: v for k, v in zip(names, values)} 189 | return reduced_dict 190 | 191 | 192 | # For tensorboard writer 193 | def write_scalars(writer, scalars, names, n_iter, tag=None): 194 | for i, scalar, in enumerate(scalars): 195 | if tag is not None: 196 | name = os.path.join(tag, names[i]) 197 | else: 198 | name = names[i] 199 | writer.add_scalar(name, scalar, n_iter) 200 | 201 | 202 | class MetricLogger(object): 203 | def __init__(self, delimiter="\t", logger=None): 204 | self.meters = defaultdict(SmoothedValue) 205 | self.delimiter = delimiter 206 | self.logger = logger 207 | 208 | def update(self, **kwargs): 209 | for k, v in kwargs.items(): 210 | if isinstance(v, torch.Tensor): 211 | v = v.item() 212 | assert isinstance(v, (float, int)) 213 | self.meters[k].update(v) 214 | 215 | def __getattr__(self, attr): 216 | if attr in self.meters: 217 | return self.meters[attr] 218 | if attr in self.__dict__: 219 | return self.__dict__[attr] 220 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") 221 | 222 | def __str__(self): 223 | loss_str = [] 224 | for i, (name, meter) in enumerate(self.meters.items()): 225 | if i % 2 == 0: 226 | end = "\n" if i == len(self.meters.items()) - 1 else "" 227 | loss_str.append(f"\t\t\t\t{name}: {meter}{end}") 228 | else: 229 | loss_str.append(f"{name}: {meter}\n") 230 | 231 | return self.delimiter.join(loss_str) 232 | 233 | def synchronize_between_processes(self): 234 | for meter in self.meters.values(): 235 | meter.synchronize_between_processes() 236 | 237 | def add_meter(self, name, meter): 238 | self.meters[name] = meter 239 | 240 | def log_every(self, iterable, print_freq, header=None): 241 | i = 0 242 | if not header: 243 | header = "" 244 | start_time = time.time() 245 | end = time.time() 246 | iter_time = SmoothedValue(fmt="{avg:.4f}") 247 | data_time = SmoothedValue(fmt="{avg:.4f}") 248 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 249 | if torch.cuda.is_available(): 250 | log_msg = self.delimiter.join([ 251 | header, 252 | "iter: [{0" + space_fmt + "}/{1}]", 253 | "eta: {eta}\n", 254 | "{meters}", 255 | "\t\t\t\titer_time: {time}", 256 | "data_time: {data}", 257 | "max_mem: {memory:.0f}MB" 258 | ]) 259 | else: 260 | log_msg = self.delimiter.join([ 261 | header, 262 | "[{0" + space_fmt + "}/{1}]", 263 | "eta: {eta}\n", 264 | "{meters}", 265 | "\t\t\t\titer_time: {time}", 266 | "data_time: {data}" 267 | ]) 268 | MB = 1024.0 * 1024.0 269 | for obj in iterable: 270 | data_time.update(time.time() - end) 271 | yield i, obj 272 | iter_time.update(time.time() - end) 273 | if i % print_freq == 0 or i == len(iterable) - 1: 274 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 275 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 276 | if torch.cuda.is_available(): 277 | 
log_or_print(self.logger, log_msg.format( 278 | i, len(iterable), eta=eta_string, 279 | meters=str(self), 280 | time=str(iter_time), data=str(data_time), 281 | memory=torch.cuda.max_memory_allocated() / MB)) 282 | else: 283 | log_or_print(self.logger, log_msg.format( 284 | i, len(iterable), eta=eta_string, 285 | meters=str(self), 286 | time=str(iter_time), data=str(data_time))) 287 | i += 1 288 | end = time.time() 289 | total_time = time.time() - start_time 290 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 291 | log_or_print(self.logger, f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)") 292 | 293 | 294 | def setup_for_distributed(is_master): 295 | """ 296 | This function disables printing when not in master process 297 | """ 298 | import builtins as __builtin__ 299 | builtin_print = __builtin__.print 300 | 301 | def print(*args, **kwargs): 302 | force = kwargs.pop("force", False) 303 | if is_master or force: 304 | builtin_print(*args, **kwargs) 305 | 306 | __builtin__.print = print 307 | 308 | 309 | def is_dist_avail_and_initialized(): 310 | if not dist.is_available(): 311 | return False 312 | if not dist.is_initialized(): 313 | return False 314 | return True 315 | 316 | 317 | def get_world_size(): 318 | if not is_dist_avail_and_initialized(): 319 | return 1 320 | return dist.get_world_size() 321 | 322 | 323 | def get_rank(): 324 | if not is_dist_avail_and_initialized(): 325 | return 0 326 | return dist.get_rank() 327 | 328 | 329 | def is_main_process(): 330 | return get_rank() == 0 331 | 332 | 333 | def save_on_master(*args, **kwargs): 334 | if is_main_process(): 335 | torch.save(*args, **kwargs) 336 | 337 | 338 | def init_distributed_mode(args): 339 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 340 | args.rank = int(os.environ["RANK"]) 341 | args.world_size = int(os.environ["WORLD_SIZE"]) 342 | args.gpu = int(os.environ["LOCAL_RANK"]) 343 | elif "SLURM_PROCID" in os.environ: 344 | args.rank = int(os.environ["SLURM_PROCID"]) 345 | args.gpu = args.rank % torch.cuda.device_count() 346 | else: 347 | args.distributed = False 348 | return "Not using distributed mode" 349 | 350 | args.distributed = True 351 | 352 | torch.cuda.set_device(args.gpu) 353 | args.dist_backend = "nccl" 354 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 355 | world_size=args.world_size, rank=args.rank) 356 | torch.distributed.barrier() 357 | setup_for_distributed(args.rank == 0) 358 | return f"Distributed init #{args.gpu} (rank {args.rank}): {args.dist_url}" 359 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/songguocode/TADAM/abd0b7422c3582e36c928778894cee8a159f896e/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/image_processing.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from torchvision.transforms.functional import to_pil_image 4 | 5 | 6 | def tensor_to_cv2(tensor_image): 7 | # To PIL image first 8 | pil_image = to_pil_image(tensor_image) 9 | opencv_image = np.array(pil_image) 10 | # Convert RGB to BGR 11 | opencv_image = opencv_image[:, :, ::-1].copy() 12 | return opencv_image 13 | 14 | 15 | def cmc_align(last_frame, curr_frame): 16 | """ 17 | Camera movement compensation 18 | Returns 
warp matrix for position alignment 19 | """ 20 | last_gray = cv2.cvtColor(np.array(last_frame), cv2.COLOR_RGB2GRAY) 21 | curr_gray = cv2.cvtColor(np.array(curr_frame), cv2.COLOR_RGB2GRAY) 22 | warp_mode = cv2.MOTION_EUCLIDEAN 23 | # warp_mode = cv2.MOTION_AFFINE 24 | warp_matrix = np.eye(2, 3, dtype=np.float32) 25 | number_of_iterations = 100 26 | termination_eps = 0.00001 27 | criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps) 28 | (cc, warp_matrix) = cv2.findTransformECC(last_gray, curr_gray, warp_matrix, warp_mode, criteria, None, 5) 29 | return warp_matrix 30 | -------------------------------------------------------------------------------- /lib/utils/kalman_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class KalmanFilter(object): 5 | def __init__(self, dim=2, dt=1., uncertainty_x=0.05, uncertainty_v=0.00625): 6 | # For x: [x1, x2, ..., vx1, vx2, ...], where vx is velocity of change in x 7 | # For z: [z1, z2, ...] 8 | self.dim = dim 9 | 10 | self.uncertainty_x = uncertainty_x 11 | self.uncertainty_v = uncertainty_v 12 | 13 | # State transition matrix 14 | self.F = np.eye(2 * dim, 2 * dim) 15 | # Set all x_new = x + vx * 1 16 | for i in range(dim): 17 | self.F[i][i + dim] = dt 18 | 19 | std = np.r_[np.ones(dim) * uncertainty_x, np.ones(dim) * uncertainty_v] 20 | 21 | # Process uncertainty 22 | self.Q = np.diag(np.square(std)) 23 | 24 | # Measurement function, convert x to z 25 | self.H = np.eye(dim, 2 * dim) 26 | 27 | # Covariance matrix 28 | self.P = np.diag(np.square(std)) 29 | 30 | # State uncertainty 31 | std_z = np.ones(dim) * uncertainty_x 32 | self.R = np.diag(np.square(std_z)) 33 | 34 | def initiate(self, measurement): 35 | # Use first value as initial value 36 | v = np.zeros_like(measurement) 37 | self.x = np.r_[measurement, v].astype(float) 38 | 39 | std = np.r_[np.ones(self.dim) * self.uncertainty_x, np.ones(self.dim) * self.uncertainty_v] 40 | # Covariance matrix 41 | self.P = np.diag(np.square(std)) 42 | 43 | def predict(self, warp=None): 44 | """ 45 | warp: numpy array of shape (2, 3) 46 | """ 47 | if warp is not None: 48 | x = np.dot(self.F, self.x) 49 | x1 = np.array([[x[0], x[1], 1]]).T 50 | x2 = np.array([[x[2], x[3], 1]]).T 51 | x1_n = np.dot(warp, x1).reshape((1, 2)) 52 | x2_n = np.dot(warp, x2).reshape((1, 2)) 53 | x[:self.dim] = np.concatenate((x1_n, x2_n), axis=1) 54 | self.x = x 55 | else: 56 | self.x = np.dot(self.F, self.x) 57 | self.P = np.dot(self.F, self.P).dot(self.F.T) + self.Q 58 | return self.x[:self.dim] 59 | 60 | def update(self, z): 61 | self.S = np.dot(self.H, self.P).dot(self.H.T) + self.R 62 | self.K = np.dot(self.P, self.H.T).dot(np.linalg.inv(self.S)) 63 | y = z - np.dot(self.H, self.x) 64 | self.x += np.dot(self.K, y) 65 | self.P = self.P - np.dot(self.K, self.H).dot(self.P) 66 | return self.x[:self.dim], self.x[self.dim:] 67 | -------------------------------------------------------------------------------- /lib/utils/log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import date, datetime 4 | 5 | level_dict = {"DEBUG": 10, "INFO": 20, "WARNING": 30, "ERROR": 40, "CRITICAL": 50} 6 | 7 | 8 | def log_or_print(logger, msg, level="info"): 9 | level = level.upper() 10 | assert level in level_dict.keys() 11 | if logger is not None: 12 | logger.log(level_dict[level], msg) 13 | else: 14 | timestamp = datetime.now().strftime("%Y-%m-%d 
%H:%M:%S") 15 | print(f"{timestamp} [{level}]:\t{msg}") 16 | 17 | 18 | def get_logger( 19 | name="global", 20 | save_file=False, 21 | overwrite_file=False, 22 | log_dir=None, 23 | log_name=None, 24 | console_verbose=False, 25 | ): 26 | """ 27 | Setup and return a logger 28 | """ 29 | formatter = logging.Formatter( 30 | fmt="%(asctime)s [%(levelname)s]:\t%(message)s", datefmt="%Y-%m-%d %H:%M:%S") 31 | 32 | handlers = [] 33 | # Console handler 34 | console_handler = logging.StreamHandler() 35 | console_handler.setFormatter(formatter) 36 | console_handler.setLevel(logging.DEBUG if console_verbose else logging.INFO) 37 | handlers.append(console_handler) 38 | # File handler 39 | if save_file: 40 | if not os.path.exists(log_dir): 41 | os.makedirs(log_dir) 42 | file_handler = logging.FileHandler(os.path.join(log_dir, f"{log_name}_log.txt"), 43 | mode="w" if overwrite_file else "a") 44 | file_handler.setFormatter(formatter) 45 | file_handler.setLevel(logging.DEBUG) 46 | handlers.append(file_handler) 47 | # Removes any handler in RootLogger, due to a defect in torch DDP 48 | # See https://discuss.pytorch.org/t/distributed-1-8-0-logging-twice-in-a-single-process-same-code-works-properly-in-1-7-0/114103/6 49 | logging.getLogger().handlers = [] 50 | # Setup logger 51 | logger = logging.getLogger(name) 52 | logger.setLevel(logging.DEBUG) 53 | for h in handlers: 54 | logger.addHandler(h) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /lib/utils/matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | 7 | def indices_to_matches(cost_matrix, indices, threshold): 8 | """ 9 | Generates three list of indices 10 | """ 11 | matched_cost = cost_matrix[tuple(zip(*indices))] 12 | matched_mask = (matched_cost <= threshold) 13 | matches = indices[matched_mask] 14 | unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) 15 | unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) 16 | return matches, unmatched_a, unmatched_b 17 | 18 | 19 | def linear_assignment(cost_matrix, threshold): 20 | """ 21 | Cost above threshold will not be assigned 22 | return: matches, unmatched_a, unmatched_b 23 | """ 24 | if cost_matrix.size == 0: 25 | return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) 26 | cost_matrix[cost_matrix > threshold] = threshold + 1e-4 27 | indices = linear_sum_assignment(cost_matrix) # In two arrays, like [[1,2,3,4,5], [3,4,2,5,1]] 28 | # Convert to pairs like [1, 3] 29 | indices_array = np.transpose(np.asarray(indices)) 30 | return indices_to_matches(cost_matrix, indices_array, threshold) 31 | 32 | 33 | def iou(altwh, bltwh): 34 | # Get coordinates of intersection 35 | xA = max(altwh[0], bltwh[0]) 36 | yA = max(altwh[1], bltwh[1]) 37 | xB = min(altwh[0] + altwh[2], bltwh[0] + bltwh[2]) 38 | yB = min(altwh[1] + altwh[3], bltwh[1] + bltwh[3]) 39 | 40 | intersection = max(0, xB - xA) * max(0, yB - yA) 41 | 42 | areaA = altwh[2] * altwh[3] 43 | areaB = bltwh[2] * bltwh[3] 44 | union = areaA + areaB - intersection 45 | 46 | return intersection / union 47 | 48 | 49 | def reid_distance(tracklets, detections): 50 | cost_matrix = torch.zeros(len(tracklets), len(detections)) 51 | if cost_matrix.shape[0] * cost_matrix.shape[1] == 0: 52 | return cost_matrix.numpy() 53 | 54 | for i, t in 
enumerate(tracklets): 55 | for j, det in enumerate(detections): 56 | cost_matrix[i, j] = 1 - F.cosine_similarity(t.avg_embedding, det.avg_embedding) 57 | return np.maximum(cost_matrix.detach().cpu().numpy(), 0.0) 58 | 59 | 60 | def box_area(boxes): 61 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 62 | 63 | 64 | def box_giou(boxes1, boxes2): 65 | """ 66 | Intersection-over-Union (Jaccard index) of boxes. 67 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 68 | Returns a Tensor[N, M] for all possible pairs 69 | """ 70 | area1 = box_area(boxes1) 71 | area2 = box_area(boxes2) 72 | 73 | # Get coordinates of potential common area 74 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 75 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 76 | wh = (rb - lt).clamp(min=0) # [N,M,2] 77 | # Common area 78 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 79 | # Union of the two 80 | union = area1[:, None] + area2 - inter 81 | 82 | # Get coordinates of smallest convex hull covering A and B 83 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 84 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 85 | wh = (rb - lt) # [N,M,2] 86 | # Convex hull 87 | areaC = wh[:, :, 0] * wh[:, :, 1] # [N, M] 88 | 89 | return inter / union - (areaC - union) / areaC 90 | -------------------------------------------------------------------------------- /lib/utils/model_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from ..utils.log import log_or_print 5 | 6 | 7 | def load_model(model, model_path, logger=None): 8 | # Check file 9 | if not os.path.isfile(model_path): 10 | log_or_print(logger, f"Invalid model at '{model_path}'. Please check path of file", level="critical") 11 | exit(0) 12 | 13 | # Load model checkpoint 14 | checkpoint = torch.load(model_path, map_location=torch.device("cpu")) 15 | if "state_dict" in checkpoint: 16 | state_dict = checkpoint["state_dict"] 17 | elif "model" in checkpoint: 18 | state_dict = checkpoint["model"] 19 | else: 20 | log_or_print(logger, "Model file does not contain 'state_dict' or 'model'. Check again.", level="critical") 21 | exit(0) 22 | 23 | model_dict = model.state_dict() 24 | new_state_dict = OrderedDict() 25 | matched_layers, discarded_layers = [], [] 26 | 27 | for k, v in state_dict.items(): 28 | # Discard "module." 
characters which is caused by parallel training 29 | if k.startswith("module."): 30 | k = k[7:] 31 | if k in model_dict and model_dict[k].size() == v.size(): 32 | new_state_dict[k] = v 33 | matched_layers.append(k) 34 | else: 35 | discarded_layers.append(k) 36 | 37 | model_dict.update(new_state_dict) 38 | model.load_state_dict(model_dict) 39 | 40 | if len(matched_layers) == 0: 41 | log_or_print( 42 | logger, f"The pretrained weights '{model_path}' cannot be loaded " 43 | "as no matched layers are found, please check the key names carefully.", 44 | level="critical" 45 | ) 46 | exit(0) 47 | else: 48 | log_or_print(logger, f"Successfully loaded pretrained weights from '{model_path}'") 49 | if len(discarded_layers) > 0: 50 | log_or_print(logger, "** The following layers are discarded " 51 | f"due to unmatched keys or layer size: '{discarded_layers}'", level="warning") 52 | return model 53 | -------------------------------------------------------------------------------- /lib/utils/official_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from ..configs.config import load_config 5 | from ..dataset.mot import get_seq_names 6 | from ..utils.log import log_or_print, get_logger 7 | 8 | 9 | def benchmark(dataset, result_name, eval_root, result_root, seq_names, logger=None): 10 | """ 11 | Copy results to TrackEval, evaluate then retrieve results 12 | Only train set is provided with gt 13 | Evaluates all sequences at once for now 14 | """ 15 | # # Create seqmaps file 16 | # seqmap_path = os.path.join(eval_root, "data/gt/mot_challenge/seqmaps") 17 | # seqmap = f"{result_name}.txt" 18 | # seqmap_file = os.path.join(seqmap_path, seqmap) 19 | # if not os.path.isfile(seqmap_file): 20 | # with open(seqmap_file, "w") as f: 21 | # f.write("name\n") 22 | # for i, name in enumerate(seq_names): 23 | # with open(seqmap_file, "a+") as f: 24 | # if i == len(seq_names) - 1: 25 | # f.write(name) 26 | # else: 27 | # f.write(name + "\n") 28 | # log_or_print(logger, f"Created seqmaps file at {seqmap_file}") 29 | # Copy results to TrackEval 30 | destination_folder = os.path.join(eval_root, "data/trackers/mot_challenge", f"{dataset}-train", result_name) 31 | data_folder = os.path.join(destination_folder, "data") 32 | if not os.path.isdir(data_folder): 33 | os.makedirs(data_folder) 34 | print(os.path.isdir(data_folder)) 35 | for name in seq_names: 36 | result_file = os.path.join(result_root, dataset, result_name, f"{name}.txt") 37 | target_file = os.path.join(data_folder, f"{name}.txt") 38 | assert os.path.isfile(result_file), f"No result file found at '{result_file}'" 39 | shutil.copyfile(result_file, target_file) 40 | log_or_print(logger, f"Copied {name}.txt for evaluation", "debug") 41 | # Evaluate and copy back results 42 | os.system(f"python {eval_root}/scripts/run_mot_challenge.py --USE_PARALLEL True \ 43 | --TRACKERS_TO_EVAL {result_name} \ 44 | --BENCHMARK {dataset} --METRICS CLEAR Identity") 45 | result_detail = os.path.join(destination_folder, "pedestrian_detailed.csv") 46 | result_summary = os.path.join(destination_folder, "pedestrian_summary.txt") 47 | shutil.copyfile(result_detail, os.path.join(result_root, dataset, result_name, f"{result_name}_result_detailed.csv")) 48 | shutil.copyfile(result_summary, os.path.join(result_root, dataset, result_name, f"{result_name}_result_summary.txt")) 49 | log_or_print(logger, f"Retrieved evaluation results and saved in {os.path.join(result_root, result_name)}") 50 | 51 | 52 | if __name__ 
== "__main__": 53 | import argparse 54 | parser = argparse.ArgumentParser(description="Evaluate tracking result on benchmark") 55 | parser.add_argument("--result-name", default="TADAM_MOT17_train", type=str, help="result folder name") 56 | parser.add_argument("--config", default="TADAM_MOT17", type=str, help="config file") 57 | # parser.add_argument("--public-detection", default="all", choices=["all", "DPM", "FRCNN", "SDP"], 58 | # type=str, help="test on specified public detection, valid for MOT17 only. default is all") 59 | # parser.add_argument("--sequence", default="all", type=str, help="test on specified sequence. default is all") 60 | args = parser.parse_args() 61 | 62 | config, cfg_msg = load_config(args.config) 63 | logger = get_logger(name="global", save_file=False, console_verbose=False) 64 | log_or_print(logger, cfg_msg) 65 | 66 | dataset = config.NAMES.DATASET 67 | full_seq_names, _, _ = get_seq_names(dataset, "train", "all", "all") 68 | 69 | benchmark(dataset, args.result_name, config.PATHS.EVAL_ROOT, 70 | config.PATHS.RESULT_ROOT, full_seq_names, logger) 71 | -------------------------------------------------------------------------------- /lib/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/utils/timer.py 2 | 3 | import time 4 | 5 | 6 | class Timer(object): 7 | """A simple timer.""" 8 | def __init__(self): 9 | self.total_time = 0. 10 | self.calls = 0 11 | self.start_time = 0. 12 | self.diff = 0. 13 | self.average_time = 0. 14 | 15 | self.duration = 0. 16 | 17 | def tic(self): 18 | # using time.time instead of time.clock because time time.clock 19 | # does not normalize for multithreading 20 | self.start_time = time.time() 21 | 22 | def toc(self, average=True): 23 | self.diff = time.time() - self.start_time 24 | self.total_time += self.diff 25 | self.calls += 1 26 | self.average_time = self.total_time / self.calls 27 | if average: 28 | self.duration = self.average_time 29 | else: 30 | self.duration = self.diff 31 | return self.duration 32 | 33 | def clear(self): 34 | self.total_time = 0. 35 | self.calls = 0 36 | self.start_time = 0. 37 | self.diff = 0. 38 | self.average_time = 0. 39 | self.duration = 0. 40 | -------------------------------------------------------------------------------- /lib/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from ..dataset.mot import MOTTracking, collate_fn, read_mot_file 9 | from ..configs.config import load_config 10 | from ..utils.log import log_or_print, get_logger 11 | 12 | 13 | def voc_color_code(num_colors=100): 14 | def to_binary(val, idx): 15 | return ((val & (1 << idx)) != 0) 16 | 17 | color_code = np.zeros((num_colors, 3), dtype=np.uint8) 18 | for i in range(num_colors): 19 | r = g = b = 0 20 | c = i 21 | for j in range(8): 22 | r |= (to_binary(c, 0) << 7 - j) 23 | g |= (to_binary(c, 1) << 7 - j) 24 | b |= (to_binary(c, 2) << 7 - j) 25 | c >>= 3 26 | color_code[i, :] = [r, g, b] 27 | return color_code 28 | 29 | 30 | def plot_boxes( 31 | frame, 32 | boxes, 33 | obj_ids=None, 34 | show_ids=True, 35 | scores=None, 36 | show_scores=False, 37 | polygons=None, 38 | show_polygons=False, 39 | masks=None, 40 | show_masks=False, 41 | show_info=False, 42 | frame_id=0, 43 | image_scale=1, 44 | text_scale=2, 45 | line_thickness=2, 46 | fps=0. 
47 | ): 48 | """ 49 | Draw a frame with bounding boxes 50 | """ 51 | im = np.copy(frame) 52 | im_h, im_w = im.shape[:2] 53 | 54 | # Determin colors first: 55 | colors = [] 56 | for i in range(len(boxes)): 57 | if obj_ids is None: 58 | colors.append((255, 0, 0)) 59 | else: 60 | obj_id = obj_ids[i] 61 | # color = get_color(abs(int(obj_id))) 62 | colors.append(tuple([int(c) for c in voc_color_code(256)[abs(int(obj_id)) % 256]])) 63 | 64 | # Draw masks first 65 | # Input should be K x H x W of float, where K is number of objects 66 | if masks is not None and show_masks: 67 | final_mask = np.zeros_like(im, dtype=np.uint8) 68 | for i, mask in enumerate(masks): 69 | mask = np.expand_dims(masks[i], axis=-1) 70 | final_mask += np.uint8(np.concatenate((mask * colors[i][0], mask * colors[i][1], mask * colors[i][2]), axis=-1)) 71 | im = cv2.addWeighted(im, 0.77, final_mask, 0.5, -1) 72 | 73 | for i, box in enumerate(boxes): 74 | if obj_ids is not None: 75 | obj_id = obj_ids[i] 76 | else: 77 | obj_id = None 78 | 79 | x1, y1, x2, y2 = box 80 | intbox = tuple(map(int, (x1, y1, x2, y2))) 81 | # Draw box 82 | if not show_polygons: 83 | cv2.rectangle(im, intbox[0:2], intbox[2:4], color=colors[i], thickness=line_thickness) 84 | 85 | # Draw Polygons 86 | polygon = None 87 | if polygons is not None and show_polygons: 88 | polygon = polygons[i] 89 | if polygon is not None: 90 | cv2.polylines(im, [polygon.reshape((-1, 1, 2))], True, (0, 255, 0), 3) 91 | 92 | # Draw id at top-left corner of box 93 | if obj_id is not None and show_ids: 94 | cv2.putText(im, f"{obj_id:d}", (int(x1), int(y1) + 20), cv2.FONT_HERSHEY_PLAIN, 95 | text_scale * 0.6, (0, 255, 255), thickness=1) 96 | 97 | # Draw scores at bottom-left corner of box 98 | score = None 99 | if scores is not None and show_scores: 100 | score = scores[i] 101 | cv2.putText(im, f"{score:.4f}", (int(x1), int(y2) - 20), cv2.FONT_HERSHEY_PLAIN, 102 | text_scale * 0.6, (0, 255, 255), thickness=1) 103 | if show_info: 104 | cv2.putText(im, "frame: %d fps: %.2f num: %d" % (frame_id, fps, len(boxes)), (0, int(15 * text_scale)), 105 | cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), thickness=2) 106 | 107 | # Resize 108 | im = cv2.resize(im, (int(im_w * image_scale), int(im_h * image_scale))) 109 | 110 | return im 111 | 112 | 113 | def get_color(idx): 114 | idx = idx * 3 115 | color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255) 116 | return color 117 | 118 | 119 | class Content(object): 120 | def __init__(self): 121 | self.list = ["GT", "DET", "RESULT"] 122 | self.current = self.list[0] 123 | 124 | def next(self): 125 | self.current = self.list[(self.list.index(self.current) + 1) % len(self.list)] 126 | 127 | 128 | def show_mot( 129 | dataloader, 130 | dataset_root="../datasets", 131 | dataset="MOT17", 132 | which_set="train", 133 | sequence="MOT17-FRCNN-02", 134 | vis_threshold=0.1, 135 | result_root="output/results", 136 | result=None, 137 | start_frame=1, 138 | scale=1.0, 139 | ): 140 | """ 141 | Visualize MOT detections/ground truths/tracking results 142 | """ 143 | content = Content() 144 | hide_info = False 145 | hide_ids = False 146 | save_image = False 147 | save_dir = "./output/images" 148 | if not os.path.isdir(save_dir): 149 | os.makedirs(save_dir) 150 | 151 | # Load data 152 | include_result = False 153 | if result is not None: 154 | include_result = True 155 | result_path = os.path.join(result_root, dataset, result, f"{sequence}.txt") 156 | assert os.path.isfile(result_path), f"No valid result file found at '{result_path}'" 157 | log_or_print(logger, 
f"Loaded result file at '{result_path}'") 158 | 159 | # Load all data at once and store 160 | det_by_frame = [] 161 | gt_by_frame = [] 162 | result_by_frame = [] 163 | for frame_id, batch in enumerate(dataloader): 164 | frame_id += 1 # Start from 1 165 | # Get MOT data 166 | _, det_boxes, det_scores, gt_boxes, gt_ids, gt_visibilities = batch 167 | # Make a detached copy to stop dataloader from using file 168 | det_boxes = det_boxes[0].detach().clone().cpu().numpy() 169 | det_scores = det_scores[0].detach().clone().cpu().numpy() 170 | gt_boxes = gt_boxes[0].detach().clone().cpu().numpy() 171 | gt_ids = gt_ids[0].detach().cpu().clone().numpy() 172 | gt_visibilities = gt_visibilities[0].detach().clone().cpu().numpy() 173 | del batch 174 | 175 | # Detections 176 | det_by_frame.append((det_boxes, det_scores)) 177 | 178 | # GT 179 | if gt_boxes is not None: # In case of test sets 180 | gt_boxes = [gt_box for i, gt_box in enumerate(gt_boxes) if gt_visibilities[i] > vis_threshold] 181 | gt_ids = [gt_id for i, gt_id in enumerate(gt_ids) if gt_visibilities[i] > vis_threshold] 182 | gt_visibilities = [gt_visibility for i, gt_visibility in enumerate(gt_visibilities) if gt_visibilities[i] > vis_threshold] 183 | gt_by_frame.append((gt_boxes, gt_ids, gt_visibilities)) 184 | else: 185 | gt_by_frame.append(([], [], [])) 186 | 187 | # Result 188 | if include_result: 189 | result_boxes, result_ids, result_scores, _ = read_mot_file(result_path, frame_id) 190 | # In case no result for the frame, empty lists are returned in that case 191 | if len(result_boxes) and len(result_ids) and len(result_scores): 192 | result_boxes = result_boxes.cpu().numpy() 193 | result_ids = result_ids.cpu().numpy() 194 | result_scores = result_scores.cpu().numpy() 195 | result_by_frame.append((result_boxes, result_scores, result_ids)) 196 | else: 197 | result_by_frame.append(([], [], [])) 198 | 199 | # Load images 200 | image_dir = os.path.join(dataset_root, dataset, which_set, sequence, "img1") 201 | file_list = os.listdir(image_dir) 202 | 203 | def get_index(x_str): 204 | return x_str[:-4] 205 | file_list = sorted(file_list, key=get_index) 206 | 207 | # Show 208 | window_name = f"MOT Visualization - {sequence}" 209 | frame_id = start_frame 210 | filename = file_list[frame_id - 1] 211 | im = draw_frame(image_dir, filename, frame_id, det_by_frame[frame_id - 1], 212 | gt_by_frame[frame_id - 1], result_by_frame[frame_id - 1], content.current, hide_info, hide_ids, scale) 213 | while True: 214 | cv2.imshow(window_name, im) 215 | # Save if toggled 216 | if save_image: 217 | cv2.imwrite(os.path.join(save_dir, f"{sequence}-{content.current}-{filename.split('.')[0]}.jpg"), im) 218 | key = cv2.waitKey(0) 219 | # Prev frame, press Key "<" 220 | if key == 44: 221 | frame_id = (frame_id - 1) % len(file_list) 222 | if frame_id == 0: 223 | frame_id = len(file_list) 224 | filename = file_list[frame_id - 1] 225 | im = draw_frame(image_dir, filename, frame_id, det_by_frame[frame_id - 1], 226 | gt_by_frame[frame_id - 1], result_by_frame[frame_id - 1], content.current, hide_info, hide_ids, scale) 227 | # Next frame, press Key ">" 228 | elif key == 46: 229 | frame_id = (frame_id + 1) % len(file_list) 230 | if frame_id == 0: 231 | frame_id = len(file_list) 232 | filename = file_list[frame_id - 1] 233 | im = draw_frame(image_dir, filename, frame_id, det_by_frame[frame_id - 1], 234 | gt_by_frame[frame_id - 1], result_by_frame[frame_id - 1], content.current, hide_info, hide_ids, scale) 235 | # Exit, press Key "q" or Esc 236 | elif key == 113 or key == 27: 
237 | break 238 | # Other options 239 | else: 240 | # Rotate among GT, DET, RESULT, press Key "t" 241 | if key == 116: 242 | content.next() 243 | # Save crops, press Key "c" 244 | elif key == 99: 245 | if content.current == "GT": 246 | boxes_info = gt_by_frame[frame_id - 1] 247 | elif content.current == "DET": 248 | boxes_info = det_by_frame[frame_id - 1] 249 | elif content.current == "RESULT": 250 | boxes_info = result_by_frame[frame_id - 1] 251 | save_crops(image_dir, filename, frame_id, boxes_info, content.current, save_dir, sequence) 252 | # Save image, press key "s" 253 | elif key == 115: 254 | save_image = not save_image 255 | # Hide info in image, press key "h" 256 | elif key == 104: 257 | hide_info = not hide_info 258 | # Hide ids in image, press key "i" 259 | elif key == 105: 260 | hide_ids = not hide_ids 261 | im = draw_frame(image_dir, filename, frame_id, det_by_frame[frame_id - 1], 262 | gt_by_frame[frame_id - 1], result_by_frame[frame_id - 1], content.current, hide_info, hide_ids, scale) 263 | 264 | 265 | def draw_frame( 266 | image_dir, 267 | filename, 268 | frame_id, 269 | detections, 270 | groundtruths, 271 | results, 272 | content_selection, 273 | hide_info, 274 | hide_ids, 275 | scale 276 | ): 277 | """ 278 | Draw a frame with given detections/ground truths/results 279 | """ 280 | im = cv2.imread(os.path.join(image_dir, filename)) 281 | content_info_position = (0, 22) 282 | content_info_color = (0, 255, 255) 283 | content_info_thickness = 2 284 | 285 | if content_selection == "DET": 286 | if not hide_info: 287 | cv2.putText(im, f"Frame: {frame_id:5d} Detections", content_info_position, 288 | cv2.FONT_HERSHEY_PLAIN, content_info_thickness, content_info_color, thickness=2) 289 | im = plot_boxes(im, detections[0], scores=detections[1], show_scores=True, image_scale=scale) 290 | elif content_selection == "GT": 291 | if not hide_info: 292 | cv2.putText(im, f"Frame: {frame_id:5d} Grount Truths", content_info_position, 293 | cv2.FONT_HERSHEY_PLAIN, content_info_thickness, content_info_color, thickness=2) 294 | im = plot_boxes(im, groundtruths[0], obj_ids=groundtruths[1], show_ids=not hide_ids, 295 | scores=groundtruths[2], show_scores=True, image_scale=scale) 296 | elif content_selection == "RESULT": 297 | if not hide_info: 298 | cv2.putText(im, f"Frame: {frame_id:5d} Results", content_info_position, 299 | cv2.FONT_HERSHEY_PLAIN, content_info_thickness, content_info_color, thickness=2) 300 | im = plot_boxes(im, results[0], scores=results[1], obj_ids=results[2], show_ids=not hide_ids, image_scale=scale) 301 | return im 302 | 303 | 304 | def save_crops(image_dir, filename, frame_id, boxes_info, content, save_dir, sequence): 305 | im = cv2.imread(os.path.join(image_dir, filename)) 306 | height, width = im.shape[:2] 307 | try: 308 | for i, box in enumerate(boxes_info[0]): 309 | log_or_print(logger, f"ID {boxes_info[1][i]} box: {boxes_info[0][i]}") 310 | crop = im[max(0, int(box[1])):min(height, int(box[1] + box[3])), max(0, int(box[0])):min(width, int(box[0] + box[2])), :] 311 | cv2.imwrite(os.path.join(save_dir, f"{sequence}-{content}-FRAME-{frame_id}-ID-{boxes_info[1][i]}.jpg"), crop) 312 | except Exception: 313 | log_or_print(logger, "No bounding boxes in current content", level="warning") 314 | 315 | 316 | if __name__ == "__main__": 317 | """ 318 | Run a visualization demo with parameters 319 | """ 320 | # arguments 321 | import argparse 322 | parser = argparse.ArgumentParser(description="MOT Visualization") 323 | parser.add_argument("--config", default="TADAM_MOT17", type=str, 
help="config file to load") 324 | parser.add_argument("--which-set", default="train", type=str, choices=["train", "test"], help="which sequence") 325 | parser.add_argument("--sequence", default="02", type=str, help="which sequence") 326 | parser.add_argument("--public-detection", default="FRCNN", type=str, 327 | choices=["None", "DPM", "FRCNN", "SDP"], help="public detection") 328 | parser.add_argument("--result", default=None, type=str, help="name for loading results") 329 | parser.add_argument("--start-frame", default=1, type=int, help="start frame") 330 | parser.add_argument("--scale", default=1, type=float, help="visual size of image") 331 | parser.add_argument("--vis_threshold", default=0.1, type=float, help="visibility threshold for gt") 332 | args = parser.parse_args() 333 | 334 | config, cfg_msg = load_config(args.config) 335 | logger = get_logger(name="global", save_file=False, console_verbose=False) 336 | log_or_print(logger, cfg_msg) 337 | 338 | public_detection = args.public_detection if config.NAMES.DATASET == "MOT17" else "None" 339 | dataloader = torch.utils.data.DataLoader(MOTTracking(config.PATHS.DATASET_ROOT, 340 | config.NAMES.DATASET, args.which_set, args.sequence, public_detection, args.vis_threshold), 341 | batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn) 342 | public_detection = f"-{public_detection}" if public_detection != "None" else "" 343 | sequence = f"{config.NAMES.DATASET}-{int(args.sequence):02d}{public_detection}" 344 | 345 | # Show info 346 | log_or_print(logger, f"Showing {config.NAMES.DATASET}/{args.which_set}/{sequence}") 347 | 348 | show_mot(dataloader, config.PATHS.DATASET_ROOT, config.NAMES.DATASET, args.which_set, sequence, 349 | args.vis_threshold, config.PATHS.RESULT_ROOT, args.result, args.start_frame, args.scale) 350 | --------------------------------------------------------------------------------