├── .gitignore ├── .isort.cfg ├── .style.yapf ├── README.md ├── cost.py ├── data ├── test.txt ├── tracking │ ├── evaluate_tracking.seqmap │ ├── evaluate_tracking.seqmap.test │ ├── evaluate_tracking.seqmap.training │ └── label_02 │ │ ├── 0000.txt │ │ ├── 0001.txt │ │ ├── 0002.txt │ │ ├── 0003.txt │ │ ├── 0004.txt │ │ ├── 0005.txt │ │ ├── 0006.txt │ │ ├── 0007.txt │ │ ├── 0008.txt │ │ ├── 0009.txt │ │ ├── 0010.txt │ │ ├── 0011.txt │ │ ├── 0012.txt │ │ ├── 0013.txt │ │ ├── 0014.txt │ │ ├── 0015.txt │ │ ├── 0016.txt │ │ ├── 0017.txt │ │ ├── 0018.txt │ │ ├── 0019.txt │ │ └── 0020.txt ├── train.txt └── val.txt ├── dataset ├── __init__.py ├── common.py ├── patchwise_dataset.py └── test_seq_dataset.py ├── dev.sh ├── eval_seq.py ├── experiments ├── pp_pv_40e_dualadd_subabs_C │ ├── config.yaml │ ├── eval.sh │ ├── test.sh │ └── train.sh ├── pp_pv_40e_mul_A │ ├── config.yaml │ ├── eval.sh │ └── train.sh ├── pp_pv_40e_mul_B │ ├── config.yaml │ ├── eval.sh │ └── train.sh ├── pp_pv_40e_mul_C │ ├── config.yaml │ ├── eval.sh │ └── train.sh └── rrc_pfv_40e_subabs_dualadd_C │ ├── config.yaml │ ├── eval.sh │ ├── test.sh │ └── train.sh ├── kitti_devkit ├── README.md ├── evaluate_tracking.py ├── mailpy.py └── munkres.py ├── main.py ├── modules ├── __init__.py ├── appear_net.py ├── dropblock.py ├── fusion_net.py ├── gcn.py ├── ghm_loss.py ├── new_end.py ├── point_net.py ├── score_net.py ├── tracking_net.py └── vgg.py ├── point_cloud ├── box_np_ops.py ├── geometry.py ├── point_cloud_ops.py └── preprocess.py ├── requirements.txt ├── solvers.py ├── test.py ├── tracking_model.py └── utils ├── build_util.py ├── data_util.py ├── kitti_util.py ├── learning_schedules_fastai.py ├── optim_util.py └── train_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | __pycache__ 4 | *.tar 5 | checkpoint* 6 | events* 7 | log.txt 8 | *.ipynb* 9 | *.ipynb 10 | experiments/* 11 | data/* 12 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length = 79 3 | multi_line_output = 0 4 | known_first_party = mmdet 5 | known_third_party = Cython,albumentations,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,setuptools,six,terminaltables,torch 6 | no_lines_before = STDLIB,LOCALFOLDER 7 | default_section = THIRDPARTY -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE = pep8 3 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 4 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robust Multi-Modality Multi-Object Tracking 2 | 3 | This is the project page for our ICCV2019 paper: **Robust Multi-Modality Multi-Object Tracking**. 
4 | 5 | **Authors**: [Wenwei Zhang](http://zhangwenwei.cn), [Hui Zhou](https://scholar.google.com/citations?user=i35tdbMAAAAJ&hl=zh-CN), [Shuyang Sun](https://kevin-ssy.github.io/), [Zhe Wang](https://wang-zhe.me/), [Jianping Shi](http://shijianping.me/), [Chen Change Loy](http://personal.ie.cuhk.edu.hk/~ccloy/) 6 | 7 | [[ArXiv]](https://arxiv.org/abs/1909.03850)  [[Project Page]](#)  [[Poster]](http://zhangwenwei.cn/files/mmMOT_poster_final.pdf) 8 | 9 | ## Introduction 10 | 11 | In this work, we design a generic sensor-agnostic multi-modality MOT framework (mmMOT), where each modality (i.e., each sensor) is capable of performing its role independently to preserve reliability, and can further improve its accuracy through a novel multi-modality fusion module. Our mmMOT can be trained in an end-to-end manner, enabling joint optimization of the base feature extractor of each modality and an adjacency estimator across modalities. Our mmMOT also makes the first attempt to encode a deep representation of the point cloud in the data association process of MOT. 12 | 13 | For more details, please refer to our [paper](https://arxiv.org/abs/1909.03850). 14 | 15 | ## Install 16 | 17 | This project is based on PyTorch>=1.0; you can install it by following the [official guide](https://pytorch.org/get-started/locally/). 18 | 19 | We recommend building a new conda environment to run the project as follows: 20 | ```bash 21 | conda create -n mmmot python=3.7 cython 22 | conda activate mmmot 23 | conda install pytorch torchvision -c pytorch 24 | conda install numba 25 | ``` 26 | 27 | Then install the remaining packages with pip: 28 | ```bash 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | You can also follow the guide to install [SECOND](https://github.com/traveller59/second.pytorch); we use the same environment as SECOND. 33 | 34 | 35 | ## Usage 36 | 37 | We provide several configs and scripts in the `experiments` directory. 38 | 39 | To evaluate the pretrained models or the reimplemented models, you can run the command 40 | ```bash 41 | python -u eval_seq.py --config ${work_path}/config.yaml \ 42 | --load-path=${work_path}/${model} \ 43 | --result-path=${work_path}/results \ 44 | --result_sha=eval 45 | ``` 46 | The `--result_sha` option is used to distinguish different evaluation attempts. 47 | You can also simply run a command like 48 | ``` 49 | sh ./experiments/pp_pv_40e_mul_A/eval.sh ${partition} 50 | ``` 51 | 52 | To train the model on your own, you can run the command 53 | ``` 54 | python -u main.py --config ${work_path}/config.yaml \ 55 | --result-path=${work_path}/results 56 | ``` 57 | You can also simply run a command like 58 | ``` 59 | sh ./experiments/pp_pv_40e_mul_A/train.sh ${partition} 60 | ``` 61 | 62 | **Note:** Both the train and eval scripts use `srun` by default; just comment out the `srun` lines if you do not use it. 63 | 64 | 65 | ## Pretrained Models 66 | 67 | We provide four models in the [Google Drive](https://drive.google.com/open?id=1IJ6rWSJw-BExQP-N25RNmQzUeTYSmwj6). 68 | The corresponding configs can be found in the `experiments` directory. 69 | 70 | Following the usage above, you can directly run inference with these models and obtain the results below: 71 | 72 | 73 | | Name | Method | MOTA | 74 | | :-----------: | :-----: | :--: | 75 | |pp_pv_40e_mul_A|Fusion Module A| 77.57| 76 | |pp_pv_40e_mul_B|Fusion Module B| 77.62| 77 | |pp_pv_40e_mul_C|Fusion Module C| 78.18| 78 | |pp_pv_40e_dualadd_subabs_C|Fusion Module C++| 80.08| 79 | 80 | The results of Fusion Modules A, B and C are the same as those in Table 1 of the paper. 
81 | Fusion Module C++ indicates that the model uses `absolute subtraction` and `softmax with addition` to improve the results; its MOTA is the same as that in the last row of Table 3 of the [paper](https://arxiv.org/abs/1909.03850). 82 | 83 | 84 | ## Data 85 | 86 | Currently the codebase supports the [PointPillars](https://github.com/nutonomy/second.pytorch)/[SECOND](https://github.com/traveller59/second.pytorch) detectors, and also supports the [RRC-Net](https://github.com/xiaohaoChen/rrc_detection) detector. 87 | 88 | In the [paper](https://arxiv.org/abs/1909.03850), we train a PointPillars model with the [official codebase](https://github.com/nutonomy/second.pytorch) to obtain the train/val detection results for the ablation study. The detection data are provided in the [Google Drive](https://drive.google.com/open?id=1IJ6rWSJw-BExQP-N25RNmQzUeTYSmwj6). Once you download the two pkl files, put them in the `data` directory. 89 | 90 | We also provide the data split used in our paper in the `data` directory. You need to download and unzip the data from the [KITTI Tracking Benchmark](http://www.cvlibs.net/datasets/kitti/eval_tracking.php) and put it in the `kitti_t_o` directory or any path you like. 91 | Do remember to change the paths in the configs accordingly (a sketch of the expected directory layout is given at the end of this README). 92 | 93 | The RRC detections are obtained from the [link](https://drive.google.com/file/d/1ZR1qEf2qjQYA9zALLl-ZXuWhqG9lxzsM/view) provided by [MOTBeyondPixels](https://github.com/JunaidCS032/MOTBeyondPixels). We use the RRC detections for the [KITTI Tracking Benchmark](http://www.cvlibs.net/datasets/kitti/eval_tracking.php). 94 | 95 | 96 | ## Citation 97 | 98 | If you use this codebase or model in your research, please cite: 99 | ``` 100 | @InProceedings{mmMOT_2019_ICCV, 101 | author = {Zhang, Wenwei and Zhou, Hui and Sun, Shuyang and Wang, Zhe and Shi, Jianping and Loy, Chen Change}, 102 | title = {Robust Multi-Modality Multi-Object Tracking}, 103 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 104 | month = {October}, 105 | year = {2019} 106 | } 107 | ``` 108 | 109 | ## Acknowledgement 110 | 111 | This code benefits a lot from [SECOND](https://github.com/traveller59/second.pytorch) and uses the detection results provided by [MOTBeyondPixels](https://github.com/JunaidCS032/MOTBeyondPixels). The GHM loss implementation is from [GHM_Detection](https://github.com/libuyu/GHM_Detection). 
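## Directory Layout (Reference)

For convenience, below is a sketch of the directory layout that the default config paths (e.g., `./data/pp_train_dets.pkl`, `./kitti_t_o/training`) expect. The KITTI sub-folder names (`image_02`, `velodyne`, `calib`, `oxts`) follow the standard KITTI tracking release and are listed only for illustration; adjust the paths in the configs if your layout differs.

```
.                                        # project root (config paths are relative to it)
├── data
│   ├── train.txt / val.txt / test.txt   # data splits provided in this repo
│   ├── pp_train_dets.pkl                # PointPillars detections (train split)
│   ├── pp_val_dets.pkl                  # PointPillars detections (val split)
│   ├── RRC_Detections_mat/              # optional, only for the RRC configs
│   └── tracking/                        # seqmaps and label_02 used by the devkit
└── kitti_t_o
    ├── training
    │   ├── image_02
    │   ├── velodyne
    │   ├── calib
    │   └── oxts
    └── testing
```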
112 | 113 | -------------------------------------------------------------------------------- /cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CostLoss(nn.Module): 7 | 8 | def __init__(self, p=1): 9 | super(CostLoss, self).__init__() 10 | self.distance = nn.L1Loss(reduction='mean') 11 | 12 | def forward(self, y, gt_y, cost): 13 | distance = self.distance(y, gt_y) 14 | loss = cost.mul(y - gt_y).mean(-1) + distance 15 | return loss 16 | 17 | 18 | class NoDistanceLoss(nn.Module): 19 | 20 | def __init__(self, p=1): 21 | super(NoDistanceLoss, self).__init__() 22 | 23 | def forward(self, assign_det, assign_link, assign_new, assign_end, gt_det, 24 | gt_link, gt_new, gt_end, det_score, link_score, new_score, 25 | end_score): 26 | 27 | loss = [] 28 | loss.append(det_score.mul(assign_det - gt_det).view(-1)) 29 | loss.append(new_score.mul(assign_new - gt_new).view(-1)) 30 | loss.append(end_score.mul(assign_end - gt_end).view(-1)) 31 | for i in range(len(link_score)): 32 | loss.append(link_score[i].mul(assign_link[i] - 33 | gt_link[i]).view(-1)) 34 | loss = F.relu(torch.cat(loss).sum()) 35 | 36 | return loss 37 | 38 | 39 | class DistanceLoss(nn.Module): 40 | 41 | def __init__(self, p=1): 42 | super(DistanceLoss, self).__init__() 43 | self.distance = nn.L1Loss(reduction='none') 44 | 45 | def forward(self, assign_det, assign_link, assign_new, assign_end, gt_det, 46 | gt_link, gt_new, gt_end, det_score, link_score, new_score, 47 | end_score): 48 | 49 | loss = [] 50 | loss.append(det_score.mul(assign_det - gt_det).view(-1)) 51 | loss.append(new_score.mul(assign_new - gt_new).view(-1)) 52 | loss.append(end_score.mul(assign_end - gt_end).view(-1)) 53 | 54 | distance = [] 55 | distance.append(self.distance(assign_det, gt_det).view(-1)) 56 | distance.append(self.distance(assign_new, gt_new).view(-1)) 57 | distance.append(self.distance(assign_end, gt_end).view(-1)) 58 | for i in range(len(link_score)): 59 | loss.append(link_score[i].mul(assign_link[i] - 60 | gt_link[i]).view(-1)) 61 | distance.append(self.distance(assign_link[i], gt_link[i]).view(-1)) 62 | loss = F.relu(torch.cat(loss + distance).sum()) 63 | 64 | return loss 65 | 66 | 67 | class LinkLoss(nn.Module): 68 | 69 | def __init__(self, smooth_ratio=0, loss_type='l2'): 70 | super(LinkLoss, self).__init__() 71 | self.smooth_ratio = smooth_ratio 72 | self.loss_type = loss_type 73 | assert loss_type in ['l1', 'l2'] 74 | if 'l2' in loss_type: 75 | self.l2_loss = nn.MSELoss() 76 | if 'l1' in loss_type: 77 | print("Use smooth l1 loss for link") 78 | self.l1_loss = nn.SmoothL1Loss() 79 | 80 | def forward(self, det_split, gt_det, link_score, gt_link): 81 | loss = 0 82 | idx_base = 0 83 | for i in range(len(link_score)): 84 | curr_num = det_split[i].item() 85 | next_num = det_split[i + 1].item() 86 | mask = link_score[i].new_ones(size=link_score[i].size()) 87 | curr_det_mask = (gt_det[idx_base:idx_base + curr_num] == 1).float() 88 | next_det_mask = (gt_det[idx_base + curr_num:idx_base + curr_num + 89 | next_num] == 1).float() 90 | mask.mul_(curr_det_mask.unsqueeze(-1).repeat(1, mask.size(-1))) 91 | mask.mul_(next_det_mask.unsqueeze(0).repeat(mask.size(-2), 1)) 92 | if 'l2' in self.loss_type: 93 | loss += self.l2_loss(link_score[i].mul(mask), 94 | gt_link[i].repeat(mask.size(0), 1, 1)) 95 | if 'l1' in self.loss_type: 96 | loss += self.l1_loss(link_score[i].mul(mask), 97 | gt_link[i].repeat(mask.size(0), 1, 1)) 98 | return loss 
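# A minimal usage sketch for LinkLoss; the shapes below are assumptions read
# off the forward pass above: two consecutive frames with 2 and 3 detections,
# and link scores predicted by 3 heads, so link_score[0] is (3, 2, 3) while
# gt_link[0] is (2, 3).
#
#   det_split = torch.tensor([2, 3])             # detections per frame
#   gt_det = torch.tensor([1., 1., 1., 0., 1.])  # 1 = detection matched to a GT box
#   link_score = [torch.randn(3, 2, 3)]
#   gt_link = [torch.randint(0, 2, (2, 3)).float()]
#   loss = LinkLoss(loss_type='l2')(det_split, gt_det, link_score, gt_link)
#
# Link scores involving unmatched detections are zeroed by the mask, so only
# links between valid detections contribute to the loss.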
99 | 100 | 101 | class DetLoss(nn.Module): 102 | 103 | def __init__(self, loss_type='bce', ignore_index=-1): 104 | super(DetLoss, self).__init__() 105 | self.loss_type = loss_type 106 | self.ignore_index = ignore_index 107 | if loss_type == 'ghm': 108 | print("Use Gradient Harmonized Loss") 109 | from modules.ghm_loss import GHMC_Loss 110 | self.GHMC_Loss = GHMC_Loss(bins=30, momentum=0.75) 111 | 112 | def forward(self, det_score, gt_score): 113 | """ 114 | 115 | :param det_score: 3xL 116 | :param gt_score: L 117 | :return: loss 118 | """ 119 | gt_score = gt_score.unsqueeze(0).repeat(det_score.size(0), 1) 120 | if 'bce' in self.loss_type: 121 | loss = F.binary_cross_entropy_with_logits(det_score, gt_score) 122 | if 'l2' in self.loss_type: 123 | mask = 1 - gt_score.eq(self.ignore_index) 124 | loss = F.mse_loss(det_score.mul(mask.float()), gt_score) 125 | if 'l1' in self.loss_type: 126 | mask = 1 - gt_score.eq(self.ignore_index) 127 | loss = F.smooth_l1_loss(det_score.mul(mask.float()), gt_score) 128 | if 'ghm' in self.loss_type: 129 | mask = 1 - gt_score.eq(self.ignore_index) 130 | loss = self.GHMC_Loss(det_score, gt_score, mask) 131 | return loss 132 | 133 | 134 | class TrackingLoss(nn.Module): 135 | 136 | def __init__(self, 137 | smooth_ratio=0, 138 | detloss_type='bce', 139 | endloss_type='l2', 140 | det_ratio=0.4, 141 | trans_ratio=0.4, 142 | trans_last=False, 143 | linkloss_type='l2_softmax'): 144 | super(TrackingLoss, self).__init__() 145 | self.link_loss = LinkLoss(smooth_ratio, linkloss_type) 146 | self.det_ratio = det_ratio 147 | self.trans_ratio = trans_ratio 148 | self.trans_last = trans_last 149 | self.detloss_type = detloss_type 150 | print("Det ratio " + str(det_ratio)) 151 | if self.trans_last: 152 | print( 153 | f"Only calculate the last transform with weight {trans_ratio}") 154 | self.det_loss = DetLoss(detloss_type) 155 | self.end_loss = DetLoss(endloss_type) 156 | 157 | def forward(self, 158 | det_split, 159 | gt_det, 160 | gt_link, 161 | gt_new, 162 | gt_end, 163 | det_score, 164 | link_score, 165 | new_score, 166 | end_score, 167 | trans=None): 168 | 169 | loss = self.det_loss(det_score, gt_det) * self.det_ratio 170 | loss += self.end_loss(new_score, gt_new[det_split[0]:]) * 0.4 171 | loss += self.end_loss(end_score, gt_end[:-det_split[-1]]) * 0.4 172 | loss += self.link_loss(det_split, gt_det, link_score, gt_link) 173 | 174 | if trans is not None: 175 | if self.trans_last: 176 | for i in range(len(trans)): 177 | identity = trans[0].new_tensor( 178 | torch.eye(trans[i].size(-1))) 179 | loss += F.mse_loss(trans[i] * trans[i].transpose(-1, -2), 180 | identity) * self.trans_ratio 181 | else: 182 | identity = trans[-1].new_tensor(torch.eye(trans[-1].size(-1))) 183 | loss += F.mse_loss(trans[-1] * trans[-1].transpose(-1, -2), 184 | identity) * self.trans_ratio 185 | return loss 186 | -------------------------------------------------------------------------------- /data/tracking/evaluate_tracking.seqmap: -------------------------------------------------------------------------------- 1 | 0000 empty 000000 000154 2 | 0001 empty 000000 000447 3 | 0002 empty 000000 000233 4 | 0003 empty 000000 000144 5 | 0004 empty 000000 000314 6 | 0005 empty 000000 000297 7 | 0006 empty 000000 000270 8 | 0007 empty 000000 000800 9 | 0008 empty 000000 000390 10 | 0009 empty 000000 000803 11 | 0010 empty 000000 000294 12 | 0011 empty 000000 000373 13 | 0012 empty 000000 000078 14 | 0013 empty 000000 000340 15 | 0014 empty 000000 000106 16 | 0015 empty 000000 000376 17 | 0016 empty 000000 000209 
18 | 0017 empty 000000 000145 19 | 0018 empty 000000 000339 20 | 0019 empty 000000 001059 21 | 0020 empty 000000 000837 22 | -------------------------------------------------------------------------------- /data/tracking/evaluate_tracking.seqmap.test: -------------------------------------------------------------------------------- 1 | 0000 empty 000000 000465 2 | 0001 empty 000000 000147 3 | 0002 empty 000000 000243 4 | 0003 empty 000000 000257 5 | 0004 empty 000000 000421 6 | 0005 empty 000000 000809 7 | 0006 empty 000000 000114 8 | 0007 empty 000000 000215 9 | 0008 empty 000000 000165 10 | 0009 empty 000000 000349 11 | 0010 empty 000000 001176 12 | 0011 empty 000000 000774 13 | 0012 empty 000000 000694 14 | 0013 empty 000000 000152 15 | 0014 empty 000000 000850 16 | 0015 empty 000000 000701 17 | 0016 empty 000000 000510 18 | 0017 empty 000000 000305 19 | 0018 empty 000000 000180 20 | 0019 empty 000000 000404 21 | 0020 empty 000000 000173 22 | 0021 empty 000000 000203 23 | 0022 empty 000000 000436 24 | 0023 empty 000000 000430 25 | 0024 empty 000000 000316 26 | 0025 empty 000000 000176 27 | 0026 empty 000000 000170 28 | 0027 empty 000000 000085 29 | 0028 empty 000000 000175 30 | -------------------------------------------------------------------------------- /data/tracking/evaluate_tracking.seqmap.training: -------------------------------------------------------------------------------- 1 | 0000 empty 000000 000154 2 | 0001 empty 000000 000447 3 | 0002 empty 000000 000233 4 | 0003 empty 000000 000144 5 | 0004 empty 000000 000314 6 | 0005 empty 000000 000297 7 | 0006 empty 000000 000270 8 | 0007 empty 000000 000800 9 | 0008 empty 000000 000390 10 | 0009 empty 000000 000803 11 | 0010 empty 000000 000294 12 | 0011 empty 000000 000373 13 | 0012 empty 000000 000078 14 | 0013 empty 000000 000340 15 | 0014 empty 000000 000106 16 | 0015 empty 000000 000376 17 | 0016 empty 000000 000209 18 | 0017 empty 000000 000145 19 | 0018 empty 000000 000339 20 | 0019 empty 000000 001059 21 | 0020 empty 000000 000837 22 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .patchwise_dataset import * 2 | from .test_seq_dataset import * -------------------------------------------------------------------------------- /dataset/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | from PIL import Image 4 | 5 | 6 | import torch 7 | import torchvision 8 | 9 | from utils.data_util import generate_seq_dets, generate_seq_gts, generate_seq_dets_rrc, LABEL, LABEL_VERSE, \ 10 | get_rotate_mat, align_pos, align_points, get_frame_det_info, get_transform_mat 11 | 12 | 13 | TRAIN_SEQ_ID = ['0003', '0001', '0013', '0009', '0004', \ 14 | '0020', '0006', '0015', '0008', '0012'] 15 | VALID_SEQ_ID = ['0005', '0007', '0017', '0011', '0002', \ 16 | '0014', '0000', '0010', '0016', '0019', '0018'] 17 | TEST_SEQ_ID = [f'{i:04d}' for i in range(29)] 18 | # Valid sequence 0017 has no cars in detection, 19 | # so it should not be included if val with GT detection 20 | # VALID_SEQ_ID = ['0005', '0007', '0011', '0002', '0014', \ 21 | # '0000', '0010', '0016', '0019', '0018'] 22 | TRAINVAL_SEQ_ID = [f'{i:04d}' for i in range(21)] 23 | 24 | 25 | def pil_loader(img_str): 26 | buff = io.BytesIO(img_str) 27 | 28 | with Image.open(buff) as img: 29 | img = img.convert('RGB') 30 | return img 31 | 32 | 33 | def 
opencv_loader(value_str): 34 | img_array = np.frombuffer(value_str, np.uint8) 35 | img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) 36 | img = img[:, :, ::-1] 37 | img = torchvision.transforms.ToPILImage()(img) 38 | return img 39 | 40 | 41 | def iou_calculator(x1, y1, x2, y2, gt_x1, gt_y1, gt_x2, gt_y2): 42 | if gt_x1 >= x2 or gt_x2 <= x1 or gt_y1 >= y2 or gt_y2 <= y1: 43 | iou = 0 44 | else: 45 | w1 = x2 - gt_x1 46 | w2 = gt_x2 - x1 47 | h1 = y2 - gt_y1 48 | h2 = gt_y2 - y1 49 | w = w1 if w1 < w2 else w2 50 | h = h1 if h1 < h2 else h2 51 | iu = w * h 52 | iou = iu / ((x2 - x1) * (y2 - y1) + (gt_x2 - gt_x1) * (gt_y2 - gt_y1) - iu) 53 | return iou 54 | 55 | 56 | def generate_det_id(bbox, gt_det, modality): 57 | (x1, y1, x2, y2) = bbox 58 | gt_det_num = gt_det['detection']['id'].shape[0] 59 | gt_id = -1 60 | gt_cls = -1 61 | max_iou = 0 62 | for i in range(gt_det_num): 63 | gt_x1 = np.floor(gt_det['detection']['bbox'][i][0]) 64 | gt_y1 = np.floor(gt_det['detection']['bbox'][i][1]) 65 | gt_x2 = np.ceil(gt_det['detection']['bbox'][i][2]) 66 | gt_y2 = np.ceil(gt_det['detection']['bbox'][i][3]) 67 | iou = iou_calculator(x1, y1, x2, y2, gt_x1, gt_y1, gt_x2, gt_y2) 68 | if max_iou <= iou: 69 | gt_id = gt_det['detection']['id'][i] 70 | gt_cls = gt_det['detection']['name'][i] 71 | max_iou = iou 72 | 73 | if max_iou < 0.3 or gt_cls != LABEL[modality]: 74 | gt_cls = 0 75 | gt_id = -1 76 | elif gt_cls == LABEL[modality]: 77 | gt_cls = 1 78 | assert gt_id != -1 79 | return gt_id, gt_cls 80 | 81 | 82 | def calculate_distance(dets, gt_dets): 83 | import motmetrics as mm 84 | distance = [] 85 | # dets format: X1, Y1, X2, Y2 86 | # distance input format: X1, Y1, W, H 87 | # for i in range(len(dets)): 88 | det = dets.copy() 89 | det[:, 2:] = det[:, 2:] - det[:, :2] 90 | gt_det = gt_dets.copy() 91 | gt_det[:, 2:] = gt_det[:, 2:] - gt_det[:, :2] 92 | return mm.distances.iou_matrix(gt_det, det, max_iou=0.5) 93 | 94 | 95 | def generate_det_id_matrix(dets_bbox, gt_dets): 96 | distance = calculate_distance(dets_bbox, gt_dets['bbox']) 97 | mat = distance.copy() # the smaller the value, the close between det and gt 98 | mat[np.isnan(mat)] = 10 # just set it to a big number 99 | v, idx = torch.min(torch.Tensor(mat), dim=-1) 100 | gt_id = -1 * torch.ones((dets_bbox.shape[0], 1)) 101 | gt_cls = torch.zeros((dets_bbox.shape[0], 1)) 102 | for i in range(len(idx)): 103 | gt_id[idx[i]] = int(gt_dets['id'][i]) 104 | # gt_cls[idx[i]] = 1 # This is modified because gt also has person and dontcare now 105 | if gt_dets['name'][i] == LABEL['Car']: 106 | gt_cls[idx[i]] = 1 107 | elif gt_dets['name'][i] == LABEL['DontCare']: 108 | gt_cls[idx[i]] = -1 109 | else: 110 | gt_cls[idx[i]] = 0 111 | return gt_id.long(), gt_cls.long() 112 | 113 | 114 | def bbox_jitter(bbox, jitter): 115 | shift_bbox = bbox.copy() 116 | shift = np.random.randint(jitter[0], jitter[1], size=bbox.shape) 117 | shift_bbox += shift 118 | return shift_bbox -------------------------------------------------------------------------------- /dataset/patchwise_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | from PIL import Image 4 | import pickle 5 | import csv 6 | import random 7 | 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | from functools import partial 14 | 15 | # For Point Cloud 16 | from point_cloud.preprocess import read_and_prep_points 17 | 18 | # For data structure 19 | 
from utils.data_util import generate_seq_dets, generate_seq_gts, generate_seq_dets_rrc, LABEL, LABEL_VERSE, \ 20 | get_rotate_mat, align_pos, align_points, get_frame_det_info, get_transform_mat 21 | 22 | 23 | from .common import * 24 | 25 | 26 | class PatchwiseDataset(Dataset): 27 | 28 | def __init__(self, root_dir, meta_file, link_file, det_file, det_type='2D', fix_iou=0.2, fix_count=2, 29 | tracker_type='3D', use_frustum=False, without_reflectivity=True, bbox_jitter=False, transform=None, 30 | num_point_features=4, gt_ratio=0, sample_max_len=2, modality='Car', train=True): 31 | self.root_dir = root_dir 32 | self.gt_ratio = gt_ratio 33 | self.train = train 34 | self.bbox_jitter = bbox_jitter 35 | self.sample_max_len = sample_max_len 36 | self.modality = modality 37 | self.num_point_features = num_point_features 38 | self.tracker_type = tracker_type 39 | self.det_type = det_type 40 | self.use_frustum = use_frustum 41 | self.without_reflectivity = without_reflectivity 42 | # self.rank = link.get_rank() 43 | # if self.rank == 0: 44 | if "trainval" in link_file: 45 | self.seq_ids = TRAINVAL_SEQ_ID 46 | elif "train" in link_file: 47 | self.seq_ids = TRAIN_SEQ_ID 48 | else: 49 | self.seq_ids = VALID_SEQ_ID 50 | 51 | self.sequence_det = generate_seq_dets(root_dir, link_file, det_file, self.seq_ids, 52 | iou_threshold=fix_iou, fix_threshold=fix_count, 53 | allow_empty=(self.gt_ratio == 1)) 54 | self.sequence_gt = generate_seq_gts(root_dir, self.seq_ids, self.sequence_det, modality) 55 | 56 | if transform == None: 57 | self.transform = transforms.Compose([transforms.ToTensor()]) 58 | else: 59 | self.transform = transform 60 | 61 | self.get_pointcloud = partial(read_and_prep_points, root_path=root_dir, 62 | without_reflectivity = without_reflectivity, 63 | num_point_features=num_point_features, 64 | det_type=self.det_type, use_frustum=use_frustum) 65 | 66 | self.metas = self._generate_meta() 67 | 68 | def __len__(self): 69 | return len(self.metas) 70 | 71 | def __getitem__(self, idx): 72 | use_gt = 0 73 | if self.gt_ratio > 0: 74 | if random.random() < self.gt_ratio: 75 | use_gt = 1 76 | if self.tracker_type == '3D': 77 | return self._generate_img_lidar(idx, use_gt) 78 | elif self.tracker_type == '2D': 79 | return self._generate_img(idx, use_gt) 80 | 81 | def _generate_img(self, idx, use_gt): 82 | frames = self.metas[idx][use_gt] 83 | gt_frames = self.metas[idx][1] 84 | det_imgs = [] 85 | det_split = [] 86 | det_ids = [] 87 | det_cls = [] 88 | 89 | for (frame, gt_frame) in zip(frames, gt_frames): 90 | path = f"{self.root_dir}/image_02/{frame['image_path']}" 91 | img = Image.open(path) 92 | det_num = frame['detection']['bbox'].shape[0] 93 | if self.bbox_jitter is not None: 94 | shift_bbox = bbox_jitter(frame['detection']['bbox'], self.bbox_jitter) 95 | else: 96 | shift_bbox = frame['detection']['bbox'] 97 | frame_ids, frame_cls = generate_det_id_matrix(shift_bbox, gt_frame['detection']) 98 | for i in range(frame['detection']['bbox'].shape[0]): 99 | x1 = np.floor(shift_bbox[i, 0]) 100 | y1 = np.floor(shift_bbox[i, 1]) 101 | x2 = np.ceil(shift_bbox[i, 2]) 102 | y2 = np.ceil(shift_bbox[i, 3]) 103 | det_imgs.append( 104 | self.transform(img.crop((x1, y1, x2, y2)).resize((224, 224), Image.BILINEAR)).unsqueeze(0)) 105 | 106 | assert len(frame_ids) > 0 107 | det_split.append(det_num) 108 | det_ids.append(frame_ids) 109 | det_cls.append(frame_cls) 110 | 111 | det_imgs = torch.cat(det_imgs, dim=0) 112 | det_info = [] 113 | return det_imgs, det_info, det_ids, det_cls, det_split 114 | 115 | def 
_generate_img_lidar(self, idx, use_gt): 116 | frames = self.metas[idx][use_gt] 117 | gt_frames = self.metas[idx][1] 118 | det_imgs = [] 119 | det_split = [] 120 | det_ids = [] 121 | det_cls = [] 122 | det_info = get_frame_det_info() 123 | R = [] 124 | T = [] 125 | pos = [] 126 | rad = [] 127 | delta_rad = [] 128 | first_flag = 0 129 | for (frame, gt_frame) in zip(frames, gt_frames): 130 | path = f"{self.root_dir}/image_02/{frame['image_path']}" 131 | img = Image.open(path) 132 | frame['frame_info']['img_shape'] = img.size 133 | det_num = frame['detection']['bbox'].shape[0] 134 | if self.bbox_jitter is not None: 135 | shift_bbox = bbox_jitter(frame['detection']['bbox'], self.bbox_jitter) 136 | else: 137 | shift_bbox = frame['detection']['bbox'] 138 | frame['frame_info']['img_shape'] = np.array([img.size[1], img.size[0]]) # w, h -> h, w 139 | point_cloud = self.get_pointcloud(info=frame['frame_info'], point_path=frame['point_path'], 140 | dets=frame['detection'], shift_bbox=shift_bbox) 141 | pos.append(frame['frame_info']['pos']) 142 | rad.append(frame['frame_info']['rad']) 143 | 144 | # Align the bbox to the same coordinate 145 | loc = [] 146 | rot = [] 147 | dim = [] 148 | if len(rad) >= 2: 149 | delta_rad.append(rad[-1] - rad[-2]) 150 | R.append(get_rotate_mat(delta_rad[-1], rotate_order=[1, 2, 3])) 151 | T.append(get_transform_mat(pos[-1] - pos[-2], rad[-2][-1])) 152 | location, rotation_y = align_pos(R, T, frame['frame_info']['calib/Tr_velo_to_cam'], 153 | frame['frame_info']['calib/Tr_imu_to_velo'], 154 | frame['frame_info']['calib/R0_rect'], delta_rad, 155 | frame['detection']['location'], 156 | frame['detection']['rotation_y']) 157 | point_cloud['points'][:,:3] = align_points(R, T, frame['frame_info']['calib/Tr_imu_to_velo'], point_cloud['points'][:,:3]) 158 | 159 | frame_ids, frame_cls = generate_det_id_matrix(shift_bbox, gt_frame['detection']) 160 | 161 | for i in range(frame['detection']['bbox'].shape[0]): 162 | x1 = np.floor(shift_bbox[i, 0]) 163 | y1 = np.floor(shift_bbox[i, 1]) 164 | x2 = np.ceil(shift_bbox[i, 2]) 165 | y2 = np.ceil(shift_bbox[i, 3]) 166 | dim.append(frame['detection']['dimensions'][i:i+1]) 167 | loc.append(location[i:i+1]) 168 | rot.append(rotation_y[i:i+1].reshape(1,1)) 169 | 170 | det_imgs.append( 171 | self.transform(img.crop((x1, y1, x2, y2)).resize((224, 224), Image.BILINEAR)).unsqueeze(0)) 172 | 173 | assert len(frame_ids) > 0 174 | det_split.append(det_num) 175 | det_ids.append(frame_ids) 176 | det_cls.append(frame_cls) 177 | det_info['loc'].append(torch.Tensor(np.concatenate(loc, axis=0))) 178 | det_info['rot'].append(torch.Tensor(np.concatenate(rot, axis=0))) 179 | det_info['dim'].append(torch.Tensor(np.concatenate(dim, axis=0))) 180 | det_info['points'].append(torch.Tensor(point_cloud['points'])) 181 | det_info['points_split'].append(torch.Tensor(point_cloud['points_split'])[first_flag:]) 182 | det_info['info_id'].append(frame['frame_info']['info_id']) 183 | if first_flag == 0: 184 | first_flag += 1 185 | 186 | det_imgs = torch.cat(det_imgs, dim=0) 187 | det_info['loc'] = torch.cat(det_info['loc'], dim=0) 188 | det_info['rot'] = torch.cat(det_info['rot'], dim=0) 189 | det_info['dim'] = torch.cat(det_info['dim'], dim=0) 190 | det_info['points'] = torch.cat(det_info['points'], dim=0) 191 | det_info['bbox'] = frame['detection']['bbox'] # temporally for debug 192 | 193 | # Shift the point split idx 194 | start = 0 195 | for i in range(len(det_info['points_split'])): 196 | det_info['points_split'][i] += start 197 | start = det_info['points_split'][i][-1] 
198 | det_info['points_split'] = torch.cat(det_info['points_split'], dim=0) 199 | 200 | return det_imgs, det_info, det_ids, det_cls, det_split 201 | 202 | def _generate_meta(self): 203 | metas = [] 204 | for seq_id in self.seq_ids: 205 | seq_length = len(self.sequence_gt[seq_id]) 206 | for i in range(seq_length - self.sample_max_len + 1): 207 | gt_frames = [] 208 | det_frames = [] 209 | for j in range(self.sample_max_len): 210 | frame_id = self.sequence_gt[seq_id][i + j]['frame_id'] 211 | if self.sequence_det[seq_id].__contains__(frame_id): 212 | gt_frames.append(self.sequence_gt[seq_id][i + j]) 213 | det_frames.append(self.sequence_det[seq_id][frame_id]) 214 | else: 215 | continue 216 | if len(gt_frames) == self.sample_max_len: 217 | metas.append((det_frames, gt_frames)) 218 | 219 | return metas 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | -------------------------------------------------------------------------------- /dataset/test_seq_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | from PIL import Image 4 | import pickle 5 | import csv 6 | import random 7 | 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | from functools import partial 14 | 15 | # For Point Cloud 16 | from point_cloud.preprocess import read_and_prep_points 17 | 18 | 19 | from .common import * 20 | 21 | from utils.data_util import generate_seq_dets, generate_seq_gts, generate_seq_dets_rrc, LABEL, LABEL_VERSE, \ 22 | get_rotate_mat, align_pos, align_points, get_frame_det_info, get_transform_mat 23 | 24 | 25 | class TestSequenceDataset(object): 26 | 27 | def __init__(self, root_dir, meta_file, link_file, det_file, det_type='2D', tracker_type='3D', 28 | use_frustum=False, without_reflectivity=True, fix_iou=0.2, fix_count=2, 29 | transform=None, num_point_features=4, gt_ratio=0, sample_max_len=2, modality='Car'): 30 | self.root_dir = root_dir 31 | self.sample_max_len = sample_max_len 32 | self.modality = modality 33 | self.det_type = det_type 34 | self.num_point_features = num_point_features 35 | self.test = False 36 | self.tracker_type = tracker_type 37 | self.use_frustum = use_frustum 38 | self.without_reflectivity = without_reflectivity 39 | 40 | if "trainval" in link_file: 41 | self.seq_ids = TRAINVAL_SEQ_ID 42 | elif "train" in link_file: 43 | self.seq_ids = TRAIN_SEQ_ID 44 | elif "val" in link_file: 45 | self.seq_ids = VALID_SEQ_ID 46 | elif 'test' in link_file: 47 | self.test = True 48 | self.seq_ids = TEST_SEQ_ID 49 | 50 | self.sequence_det = generate_seq_dets(root_dir, link_file, det_file, self.seq_ids, 51 | iou_threshold=fix_iou, fix_threshold=fix_count, 52 | allow_empty=True, test=self.test) 53 | 54 | if transform == None: 55 | self.transform = transforms.Compose([transforms.ToTensor()]) 56 | else: 57 | self.transform = transform 58 | 59 | self.get_pointcloud = partial(read_and_prep_points, root_path=root_dir, 60 | use_frustum=use_frustum, without_reflectivity=without_reflectivity, 61 | num_point_features=num_point_features, det_type=self.det_type) 62 | 63 | self.metas = self._generate_meta_seq() 64 | 65 | def __len__(self): 66 | return len(self.metas) 67 | 68 | def __getitem__(self, idx): 69 | return self.metas[idx] 70 | 71 | def _generate_meta_seq(self): 72 | metas = [] 73 | for seq_id in self.seq_ids: 74 | if seq_id == '0007': 75 | els = list(self.sequence_det[seq_id].items()) 76 | seq_length = int(els[-1][0]) 77 | 
else: 78 | seq_length = len(self.sequence_det[seq_id]) 79 | det_seq = [] 80 | gt_seq = [] 81 | # TODO: support interval > 2 82 | for i in range(0, seq_length - self.sample_max_len + 1, self.sample_max_len - 1): 83 | det_frames = [] 84 | frame_id = f'{i:06d}' 85 | # Get first frame, skip the empty frame 86 | if frame_id in self.sequence_det[seq_id] and \ 87 | len(self.sequence_det[seq_id][frame_id]['detection']['name']) > 0: 88 | det_frames.append(self.sequence_det[seq_id][frame_id]) 89 | else: 90 | continue 91 | # Get next frame untill the end, 92 | # 10 could handle most case where objs are still linked 93 | for j in range(1, seq_length-i): 94 | frame_id = f'{i + j:06d}' 95 | if frame_id in self.sequence_det[seq_id] and \ 96 | len(self.sequence_det[seq_id][frame_id]['detection']['name']) > 0: 97 | det_frames.append(self.sequence_det[seq_id][frame_id]) 98 | if len(det_frames) == self.sample_max_len: 99 | det_seq.append(det_frames) 100 | # if j > 1: 101 | # print(f"In ID-{seq_id}, {i:06d}->{i + j:06d} are linked!") 102 | break 103 | 104 | metas.append(TestSequence(name=seq_id, modality=self.modality, det_type=self.det_type, 105 | tracker_type=self.tracker_type, root_dir=self.root_dir, 106 | det_frames=det_seq, use_frustum=self.use_frustum, 107 | without_reflectivity=self.without_reflectivity, 108 | interval=self.sample_max_len, transform=self.transform, 109 | get_pointcloud=self.get_pointcloud)) 110 | return metas 111 | 112 | 113 | class TestSequence(Dataset): 114 | 115 | def __init__(self, name, modality, det_type, tracker_type, root_dir, det_frames, 116 | use_frustum, without_reflectivity,interval, transform, get_pointcloud): 117 | self.det_frames = det_frames 118 | self.interval = interval 119 | self.root_dir = root_dir 120 | self.metas = det_frames 121 | self.idx = 0 122 | self.seq_len = len(det_frames) + interval -1 123 | self.name = name 124 | self.modality = modality 125 | self.get_pointcloud = get_pointcloud 126 | self.det_type = det_type 127 | self.tracker_type = tracker_type 128 | self.use_frustum = use_frustum 129 | self.without_reflectivity = without_reflectivity 130 | 131 | if transform == None: 132 | self.transform = transforms.Compose([transforms.ToTensor()]) 133 | else: 134 | self.transform = transform 135 | assert len(det_frames) == len(det_frames) 136 | 137 | def __getitem__(self, idx): 138 | if self.tracker_type == '3D': 139 | return self._generate_img_lidar(idx) 140 | elif self.tracker_type == '2D': 141 | return self._generate_img(idx) 142 | 143 | def __len__(self): 144 | return len(self.metas) 145 | 146 | def _generate_img(self, idx): 147 | frames = self.metas[idx] 148 | det_imgs = [] 149 | det_split = [] 150 | dets = [] 151 | 152 | for frame in frames: 153 | path = f"{self.root_dir}/image_02/{frame['image_path']}" 154 | img = Image.open(path) 155 | det_num = frame['detection']['bbox'].shape[0] 156 | for i in range(det_num): 157 | x1 = np.floor(frame['detection']['bbox'][i][0]) 158 | y1 = np.floor(frame['detection']['bbox'][i][1]) 159 | x2 = np.ceil(frame['detection']['bbox'][i][2]) 160 | y2 = np.ceil(frame['detection']['bbox'][i][3]) 161 | det_imgs.append( 162 | self.transform(img.crop((x1, y1, x2, y2)).resize((224, 224), Image.BILINEAR)).unsqueeze(0)) 163 | 164 | 165 | if 'image_idx' in frame['detection'].keys(): 166 | frame['detection'].pop('image_idx') 167 | 168 | dets.append(frame['detection']) 169 | det_split.append(det_num) 170 | 171 | det_imgs = torch.cat(det_imgs, dim=0) 172 | 173 | det_info = [] 174 | return det_imgs, det_info, dets, det_split 175 | 176 | def 
_generate_img_lidar(self, idx): 177 | frames = self.metas[idx] 178 | 179 | det_imgs = [] 180 | det_split = [] 181 | dets = [] 182 | det_info = get_frame_det_info() 183 | R = [] 184 | T = [] 185 | pos = [] 186 | rad = [] 187 | delta_rad = [] 188 | first_flag = 0 189 | for frame in frames: 190 | path =f"{self.root_dir}/image_02/{frame['image_path']}" 191 | img = Image.open(path) 192 | det_num = frame['detection']['bbox'].shape[0] 193 | frame['frame_info']['img_shape'] = np.array([img.size[1], img.size[0]]) # w, h -> h, w 194 | point_cloud = self.get_pointcloud(info=frame['frame_info'], point_path=frame['point_path'], 195 | dets=frame['detection'], shift_bbox = frame['detection']['bbox']) 196 | pos.append(frame['frame_info']['pos']) 197 | rad.append(frame['frame_info']['rad']) 198 | 199 | # Align the bbox to the same coordinate 200 | if len(rad) >= 2: 201 | delta_rad.append(rad[-1] - rad[-2]) 202 | R.append(get_rotate_mat(delta_rad[-1], rotate_order=[1, 2, 3])) 203 | T.append(get_transform_mat(pos[-1] - pos[-2], rad[-2][-1])) 204 | location, rotation_y = align_pos(R, T, frame['frame_info']['calib/Tr_velo_to_cam'], 205 | frame['frame_info']['calib/Tr_imu_to_velo'], 206 | frame['frame_info']['calib/R0_rect'], delta_rad, 207 | frame['detection']['location'], 208 | frame['detection']['rotation_y']) 209 | point_cloud['points'][:,:3] = align_points(R, T, frame['frame_info']['calib/Tr_imu_to_velo'], 210 | point_cloud['points'][:,:3] ) 211 | 212 | for i in range(det_num): 213 | x1 = np.floor(frame['detection']['bbox'][i][0]) 214 | y1 = np.floor(frame['detection']['bbox'][i][1]) 215 | x2 = np.ceil(frame['detection']['bbox'][i][2]) 216 | y2 = np.ceil(frame['detection']['bbox'][i][3]) 217 | det_imgs.append( 218 | self.transform(img.crop((x1, y1, x2, y2)).resize((224, 224), Image.BILINEAR)).unsqueeze(0)) 219 | 220 | if 'image_idx' in frame['detection'].keys(): 221 | frame['detection'].pop('image_idx') 222 | dets.append(frame['detection']) 223 | det_split.append(det_num) 224 | det_info['loc'].append(torch.Tensor(location)) 225 | det_info['rot'].append(torch.Tensor(rotation_y)) 226 | det_info['dim'].append(torch.Tensor(frame['detection']['dimensions'])) 227 | det_info['points'].append(torch.Tensor(point_cloud['points'])) 228 | det_info['points_split'].append(torch.Tensor(point_cloud['points_split'])[first_flag:]) 229 | det_info['info_id'].append(frame['frame_info']['info_id']) 230 | if first_flag == 0: 231 | first_flag += 1 232 | 233 | det_imgs = torch.cat(det_imgs, dim=0) 234 | det_info['loc'] = torch.cat(det_info['loc'], dim=0) 235 | det_info['rot'] = torch.cat(det_info['rot'], dim=0) 236 | det_info['dim'] = torch.cat(det_info['dim'], dim=0) 237 | det_info['points'] = torch.cat(det_info['points'], dim=0) 238 | 239 | # Shift the point split idx 240 | start = 0 241 | for i in range(len(det_info['points_split'])): 242 | det_info['points_split'][i] += start 243 | start = det_info['points_split'][i][-1] 244 | det_info['points_split'] = torch.cat(det_info['points_split'], dim=0) 245 | 246 | return det_imgs, det_info, dets, det_split 247 | -------------------------------------------------------------------------------- /dev.sh: -------------------------------------------------------------------------------- 1 | flake8 . 2 | isort -rc --check-only --diff . 3 | yapf -r -d --style .style.yapf . 
-------------------------------------------------------------------------------- /eval_seq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pprint 5 | import time 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import yaml 11 | from easydict import EasyDict 12 | from kitti_devkit.evaluate_tracking import evaluate 13 | from torch.utils.data import DataLoader 14 | from tracking_model import TrackingModule 15 | from utils.build_util import build_augmentation, build_dataset, build_model 16 | from utils.data_util import write_kitti_result 17 | from utils.train_util import AverageMeter, create_logger, load_state 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch mmMOT Evaluation') 20 | parser.add_argument('--config', default='cfgs/config_res50.yaml') 21 | parser.add_argument('--load-path', default='', type=str) 22 | parser.add_argument('--result-path', default='', type=str) 23 | parser.add_argument('--recover', action='store_true') 24 | parser.add_argument('-e', '--evaluate', action='store_true') 25 | parser.add_argument('--result_sha', default='last') 26 | parser.add_argument('--memory', action='store_true') 27 | 28 | 29 | def main(): 30 | global args, config, best_mota 31 | args = parser.parse_args() 32 | 33 | with open(args.config) as f: 34 | config = yaml.load(f, Loader=yaml.FullLoader) 35 | 36 | config = EasyDict(config['common']) 37 | config.save_path = os.path.dirname(args.config) 38 | 39 | # create model 40 | model = build_model(config) 41 | model.cuda() 42 | 43 | # optionally resume from a checkpoint 44 | load_state(args.load_path, model) 45 | 46 | cudnn.benchmark = True 47 | 48 | # Data loading code 49 | train_transform, valid_transform = build_augmentation(config.augmentation) 50 | 51 | # build dataset 52 | train_dataset = build_dataset( 53 | config, 54 | set_source='train', 55 | evaluate=True, 56 | valid_transform=valid_transform) 57 | val_dataset = build_dataset( 58 | config, 59 | set_source='val', 60 | evaluate=True, 61 | valid_transform=valid_transform) 62 | 63 | logger = create_logger('global_logger', config.save_path + '/eval_log.txt') 64 | logger.info('args: {}'.format(pprint.pformat(args))) 65 | logger.info('config: {}'.format(pprint.pformat(config))) 66 | 67 | tracking_module = TrackingModule(model, None, None, config.det_type) 68 | 69 | logger.info('Evaluation on traing set:') 70 | validate(train_dataset, tracking_module, args.result_sha, part='train') 71 | logger.info('Evaluation on validation set:') 72 | validate(val_dataset, tracking_module, args.result_sha, part='val') 73 | 74 | 75 | def validate(val_loader, 76 | tracking_module, 77 | step, 78 | part='train', 79 | fusion_list=None, 80 | fuse_prob=False): 81 | prec = AverageMeter(0) 82 | rec = AverageMeter(0) 83 | mota = AverageMeter(0) 84 | motp = AverageMeter(0) 85 | 86 | logger = logging.getLogger('global_logger') 87 | for i, (sequence) in enumerate(val_loader): 88 | logger.info('Test: [{}/{}]\tSequence ID: KITTI-{}'.format( 89 | i, len(val_loader), sequence.name)) 90 | seq_loader = DataLoader( 91 | sequence, 92 | batch_size=config.batch_size, 93 | shuffle=False, 94 | num_workers=config.workers, 95 | pin_memory=True) 96 | if len(seq_loader) == 0: 97 | tracking_module.eval() 98 | logger.info('Empty Sequence ID: KITTI-{}, skip'.format( 99 | sequence.name)) 100 | else: 101 | if args.memory: 102 | seq_prec, seq_rec, seq_mota, seq_motp = validate_mem_seq( 103 | 
seq_loader, tracking_module) 104 | else: 105 | seq_prec, seq_rec, seq_mota, seq_motp = validate_seq( 106 | seq_loader, tracking_module) 107 | prec.update(seq_prec, 1) 108 | rec.update(seq_rec, 1) 109 | mota.update(seq_mota, 1) 110 | motp.update(seq_motp, 1) 111 | 112 | write_kitti_result( 113 | args.result_path, 114 | sequence.name, 115 | step, 116 | tracking_module.frames_id, 117 | tracking_module.frames_det, 118 | part=part) 119 | 120 | total_num = torch.Tensor([prec.count]) 121 | logger.info( 122 | '* Prec: {:.3f}\tRec: {:.3f}\tMOTA: {:.3f}\tMOTP: {:.3f}\ttotal_num={}' 123 | .format(prec.avg, rec.avg, mota.avg, motp.avg, total_num.item())) 124 | MOTA, MOTP, recall, prec, F1, fp, fn, id_switches = evaluate( 125 | step, args.result_path, part=part) 126 | 127 | tracking_module.train() 128 | return MOTA, MOTP, recall, prec, F1, fp, fn, id_switches 129 | 130 | 131 | def validate_seq(val_loader, 132 | tracking_module, 133 | fusion_list=None, 134 | fuse_prob=False): 135 | batch_time = AverageMeter(0) 136 | 137 | # switch to evaluate mode 138 | tracking_module.eval() 139 | 140 | logger = logging.getLogger('global_logger') 141 | end = time.time() 142 | 143 | with torch.no_grad(): 144 | for i, (input, det_info, dets, det_split) in enumerate(val_loader): 145 | input = input.cuda() 146 | if len(det_info) > 0: 147 | for k, v in det_info.items(): 148 | det_info[k] = det_info[k].cuda() if not isinstance( 149 | det_info[k], list) else det_info[k] 150 | 151 | # compute output 152 | aligned_ids, aligned_dets, frame_start = tracking_module.predict( 153 | input[0], det_info, dets, det_split) 154 | 155 | batch_time.update(time.time() - end) 156 | end = time.time() 157 | if i % config.print_freq == 0: 158 | logger.info('Test Frame: [{0}/{1}]\tTime {batch_time.val:.3f}' 159 | '({batch_time.avg:.3f})'.format( 160 | i, len(val_loader), batch_time=batch_time)) 161 | 162 | return 0, 0, 0, 0 163 | 164 | 165 | def validate_mem_seq(val_loader, 166 | tracking_module, 167 | fusion_list=None, 168 | fuse_prob=False): 169 | batch_time = AverageMeter(0) 170 | 171 | # switch to evaluate mode 172 | tracking_module.eval() 173 | 174 | logger = logging.getLogger('global_logger') 175 | end = time.time() 176 | 177 | with torch.no_grad(): 178 | for i, (input, det_info, dets, det_split, gt_dets, gt_ids, 179 | gt_cls) in enumerate(val_loader): 180 | input = input.cuda() 181 | if len(det_info) > 0: 182 | for k, v in det_info.items(): 183 | det_info[k] = det_info[k].cuda() if not isinstance( 184 | det_info[k], list) else det_info[k] 185 | 186 | # compute output 187 | results = tracking_module.mem_predict(input[0], det_info, dets, 188 | det_split) 189 | aligned_ids, aligned_dets, frame_start = results 190 | 191 | # measure elapsed time 192 | batch_time.update(time.time() - end) 193 | end = time.time() 194 | if i % config.print_freq == 0: 195 | logger.info( 196 | 'Test Frame: [{0}/{1}]\tTime ' 197 | '{batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 198 | i, len(val_loader), batch_time=batch_time)) 199 | 200 | return 0, 0, 0, 0 201 | 202 | 203 | if __name__ == '__main__': 204 | main() 205 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_dualadd_subabs_C/config.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | model: 3 | point_arch: v1 4 | point_len: 512 5 | appear_arch: vgg 6 | appear_len: 512 7 | appear_skippool: True 8 | appear_fpn: False 9 | 10 | end_arch: v2 11 | end_mode: avg 12 | 13 | affinity_op: minus_abs # multiply 
or addminus 14 | softmax_mode: dual_add 15 | 16 | score_arch: branch_cls 17 | neg_threshold: 0.2 18 | 19 | score_fusion_arch: C # A, B, C 20 | test_mode: 2 #0:image;1:LiDAR;2:fusion 21 | 22 | gt_det_ratio : 0 23 | sample_max_len : 2 24 | det_type: 3D 25 | tracker_type: 3D 26 | use_frustum: False 27 | without_reflectivity: True 28 | train_fix_iou: 1 29 | train_fix_count: 0 30 | val_fix_iou: 1 31 | val_fix_count: 0 32 | use_dropout: False 33 | dropblock: 0 34 | 35 | augmentation: 36 | input_size: 224 37 | test_resize: 224 38 | 39 | loss: 40 | det_loss: bce 41 | link_loss: l2 42 | smooth_ratio: 0 43 | det_ratio: 1.5 44 | trans_ratio: 0.001 45 | trans_last: True 46 | 47 | workers: 1 48 | batch_size: 1 49 | lr_scheduler: 50 | #type: COSINE 51 | optim: Adam 52 | type: one_cycle 53 | base_lr: 0.0003 54 | lr_max: 0.0006 55 | moms: [0.95, 0.85] 56 | div_factor: 10.0 57 | pct_start: 0.4 58 | max_iter: 134200 # 40 * 3355 59 | 60 | fixed_wd: true 61 | use_moving_average: false 62 | momentum: 0.9 63 | #weight_decay: 0.0001 64 | weight_decay: 0.01 # super converge. decrease this when you increase steps. 65 | 66 | val_freq: 3355 # exact num of samples in 1 epoch with pp 67 | print_freq: 100 68 | 69 | train_root: ./kitti_t_o/training 70 | train_source: ./kitti_t_o/training/ 71 | train_link : ./data/train.txt 72 | train_det: ./data/pp_train_dets.pkl 73 | 74 | val_root: ./kitti_t_o/training 75 | val_source: ./kitti_t_o/training/ 76 | val_link : ./data/val.txt 77 | val_det : ./data/pp_val_dets.pkl 78 | 79 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_dualadd_subabs_C/eval.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u eval_seq.py --config $work_path/config.yaml \ 5 | --load-path=./pretrain_models/pp_pv_40e_dualadd_subabs_C.pth \ 6 | --result-path=$work_path/results \ 7 | --result_sha=Tracking \ 8 | 2>&1|tee $work_path/Eval-${now}.log 9 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_dualadd_subabs_C/test.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | # srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u test.py --config $work_path/config.yaml \ 5 | --result-path=$work_path/test_results \ 6 | --result_sha=test \ 7 | --load-path=$work_path/ckpt_best.pth.tar \ 8 | 2>&1|tee $work_path/Test-img_${now}.log & 9 | 10 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_dualadd_subabs_C/train.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | # srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u main.py --config $work_path/config.yaml \ 5 | --result-path=$work_path/results \ 6 | 2>&1|tee $work_path/T-${now}.log & 7 | #--load-path=$work_path/ckpt.pth.tar \ 8 | #--recover 9 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_A/config.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | model: 3 | point_arch: v1 4 | point_len: 512 5 | appear_arch: vgg 6 | appear_len: 512 7 | appear_skippool: True 8 | appear_fpn: False 9 | 10 | 
end_arch: v2 11 | end_mode: avg 12 | 13 | affinity_op: multiply # multiply or addminus 14 | softmax_mode: none 15 | 16 | score_arch: branch_cls # end is v1 17 | neg_threshold: 0.2 18 | 19 | score_fusion_arch: A 20 | test_mode: 2 #0:image;1:LiDAR;2:fusion 21 | 22 | gt_det_ratio : 0 23 | sample_max_len : 2 24 | det_type: 3D 25 | tracker_type: 3D 26 | use_frustum: False 27 | without_reflectivity: True 28 | train_fix_iou: 1 29 | train_fix_count: 0 30 | val_fix_iou: 1 31 | val_fix_count: 0 32 | use_dropout: False 33 | dropblock: 0 34 | 35 | augmentation: 36 | input_size: 224 37 | test_resize: 224 38 | 39 | loss: 40 | det_loss: bce 41 | link_loss: l2 42 | smooth_ratio: 0 43 | det_ratio: 1.5 44 | trans_ratio: 0.001 45 | trans_last: True 46 | 47 | workers: 1 48 | batch_size: 1 49 | lr_scheduler: 50 | #type: COSINE 51 | optim: Adam 52 | type: one_cycle 53 | base_lr: 0.0003 54 | lr_max: 0.0006 55 | moms: [0.95, 0.85] 56 | div_factor: 10.0 57 | pct_start: 0.4 58 | max_iter: 134200 # 40 * 3355 59 | 60 | fixed_wd: true 61 | use_moving_average: false 62 | momentum: 0.9 63 | #weight_decay: 0.0001 64 | weight_decay: 0.01 # super converge. decrease this when you increase steps. 65 | 66 | val_freq: 3355 # exact num of samples in 1 epoch with pp 67 | print_freq: 100 68 | 69 | train_root: ./kitti_t_o/training 70 | train_source: ./kitti_t_o/training/ 71 | train_link : ./data/train.txt 72 | train_det: ./data/pp_train_dets.pkl 73 | 74 | val_root: ./kitti_t_o/training 75 | val_source: ./kitti_t_o/training/ 76 | val_link : ./data/val.txt 77 | val_det : ./data/pp_val_dets.pkl 78 | 79 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_A/eval.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u eval_seq.py --config $work_path/config.yaml \ 5 | --load-path=./pretrain_models/pp_pv_40e_mul_A-gpu.pth \ 6 | --result-path=$work_path/results \ 7 | --result_sha=all \ 8 | 2>&1|tee $work_path/Eval-pts-${now}.log 9 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_A/train.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u main.py --config $work_path/config.yaml \ 5 | --result-path=$work_path/results \ 6 | 2>&1|tee $work_path/T-${now}.log 7 | #--load-path=$work_path/ckpt.pth.tar \ 8 | #--recover 9 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_B/config.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | model: 3 | point_arch: v1 4 | point_len: 512 5 | appear_arch: vgg 6 | appear_len: 512 7 | appear_skippool: True 8 | appear_fpn: False 9 | 10 | end_arch: v2 11 | end_mode: avg 12 | 13 | affinity_op: multiply # multiply or addminus 14 | softmax_mode: none 15 | 16 | score_arch: branch_cls # end is v1 17 | neg_threshold: 0.2 18 | 19 | score_fusion_arch: B 20 | test_mode: 2 #0:image;1:LiDAR;2:fusion 21 | 22 | gt_det_ratio : 0 23 | sample_max_len : 2 24 | det_type: 3D 25 | tracker_type: 3D 26 | use_frustum: False 27 | without_reflectivity: True 28 | train_fix_iou: 1 29 | train_fix_count: 0 30 | val_fix_iou: 1 31 | val_fix_count: 0 32 | use_dropout: False 33 | 
dropblock: 0 34 | 35 | augmentation: 36 | input_size: 224 37 | test_resize: 224 38 | 39 | loss: 40 | det_loss: bce 41 | link_loss: l2 42 | smooth_ratio: 0 43 | det_ratio: 1.5 44 | trans_ratio: 0.001 45 | trans_last: True 46 | 47 | workers: 1 48 | batch_size: 1 49 | lr_scheduler: 50 | #type: COSINE 51 | optim: Adam 52 | type: one_cycle 53 | base_lr: 0.0003 54 | lr_max: 0.0006 55 | moms: [0.95, 0.85] 56 | div_factor: 10.0 57 | pct_start: 0.4 58 | max_iter: 134200 # 40 * 3355 59 | 60 | fixed_wd: true 61 | use_moving_average: false 62 | momentum: 0.9 63 | #weight_decay: 0.0001 64 | weight_decay: 0.01 # super converge. decrease this when you increase steps. 65 | 66 | val_freq: 3355 # exact num of samples in 1 epoch with pp 67 | print_freq: 100 68 | 69 | train_root: ./kitti_t_o/training 70 | train_source: ./kitti_t_o/training/ 71 | train_link : ./data/train.txt 72 | train_det: ./data/pp_train_dets.pkl 73 | 74 | val_root: ./kitti_t_o/training 75 | val_source: ./kitti_t_o/training/ 76 | val_link : ./data/val.txt 77 | val_det : ./data/pp_val_dets.pkl 78 | 79 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_B/eval.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u eval_seq.py --config $work_path/config.yaml \ 5 | --load-path=./pretrain_models/pp_pv_40e_mul_B-gpu.pth \ 6 | --result-path=$work_path/results \ 7 | --result_sha=all \ 8 | 2>&1|tee $work_path/Eval-pts-${now}.log 9 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_B/train.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u main.py --config $work_path/config.yaml \ 5 | --result-path=$work_path/results \ 6 | 2>&1|tee $work_path/T-${now}.log 7 | #--load-path=$work_path/ckpt.pth.tar \ 8 | #--recover 9 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_C/config.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | model: 3 | point_arch: v1 4 | point_len: 512 5 | appear_arch: vgg 6 | appear_len: 512 7 | appear_skippool: True 8 | appear_fpn: False 9 | 10 | end_arch: v2 11 | end_mode: avg 12 | 13 | affinity_op: multiply # multiply or addminus 14 | softmax_mode: none 15 | 16 | score_arch: branch_cls # end is v1 17 | neg_threshold: 0.2 18 | 19 | score_fusion_arch: C 20 | test_mode: 2 #0:image;1:LiDAR;2:fusion 21 | 22 | gt_det_ratio : 0 23 | sample_max_len : 2 24 | det_type: 3D 25 | tracker_type: 3D 26 | use_frustum: False 27 | without_reflectivity: True 28 | train_fix_iou: 1 29 | train_fix_count: 0 30 | val_fix_iou: 1 31 | val_fix_count: 0 32 | use_dropout: False 33 | dropblock: 0 34 | 35 | 36 | augmentation: 37 | input_size: 224 38 | test_resize: 224 39 | 40 | loss: 41 | det_loss: bce 42 | link_loss: l2 43 | smooth_ratio: 0 44 | det_ratio: 1.5 45 | trans_ratio: 0.001 46 | trans_last: True 47 | 48 | workers: 1 49 | batch_size: 1 50 | lr_scheduler: 51 | #type: COSINE 52 | optim: Adam 53 | type: one_cycle 54 | base_lr: 0.0003 55 | lr_max: 0.0006 56 | moms: [0.95, 0.85] 57 | div_factor: 10.0 58 | pct_start: 0.4 59 | max_iter: 134200 # 40 * 3355 60 | 61 | fixed_wd: true 62 | 
use_moving_average: false 63 | momentum: 0.9 64 | #weight_decay: 0.0001 65 | weight_decay: 0.01 # super converge. decrease this when you increase steps. 66 | 67 | val_freq: 3355 # exact num of samples in 1 epoch with pp 68 | print_freq: 100 69 | 70 | train_root: ./kitti_t_o/training 71 | train_source: ./kitti_t_o/training/ 72 | train_link : ./data/train.txt 73 | train_det: ./data/pp_train_dets.pkl 74 | 75 | val_root: ./kitti_t_o/training 76 | val_source: ./kitti_t_o/training/ 77 | val_link : ./data/val.txt 78 | val_det : ./data/pp_val_dets.pkl 79 | 80 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_C/eval.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u eval_seq.py --config $work_path/config.yaml \ 5 | --load-path=./pretrain_models/pp_pv_40e_mul_C-gpu.pth \ 6 | --result-path=$work_path/results \ 7 | --result_sha=all \ 8 | 2>&1|tee $work_path/Eval-pts-${now}.log 9 | -------------------------------------------------------------------------------- /experiments/pp_pv_40e_mul_C/train.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u main.py --config $work_path/config.yaml \ 5 | --result-path=$work_path/results \ 6 | 2>&1|tee $work_path/T-${now}.log 7 | #--load-path=$work_path/ckpt.pth.tar \ 8 | #--recover 9 | -------------------------------------------------------------------------------- /experiments/rrc_pfv_40e_subabs_dualadd_C/config.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | model: 3 | point_arch: v1 4 | point_len: 512 5 | appear_arch: vgg 6 | appear_len: 512 7 | appear_skippool: True 8 | appear_fpn: False 9 | 10 | end_arch: v2 11 | end_mode: avg 12 | 13 | affinity_op: minus_abs # multiply or addminus 14 | softmax_mode: dual_add 15 | 16 | score_arch: branch_cls 17 | neg_threshold: 0 18 | 19 | score_fusion_arch: C 20 | test_mode: 2 #0:image;1:LiDAR;2:fusion 21 | 22 | gt_det_ratio : 0 23 | sample_max_len : 2 24 | det_type: 3D 25 | tracker_type: 3D 26 | use_frustum: True 27 | without_reflectivity: True 28 | train_fix_iou: 1 29 | train_fix_count: 0 30 | val_fix_iou: 1 31 | val_fix_count: 0 32 | use_dropout: True 33 | dropblock: 5 34 | 35 | augmentation: 36 | input_size: 224 37 | test_resize: 224 38 | 39 | loss: 40 | det_loss: bce 41 | link_loss: l2 42 | smooth_ratio: 0 43 | det_ratio: 1.5 44 | trans_ratio: 0.001 45 | trans_last: True 46 | 47 | workers: 1 48 | batch_size: 1 49 | lr_scheduler: 50 | optim: Adam 51 | type: one_cycle 52 | base_lr: 0.0003 53 | lr_max: 0.0006 54 | moms: [0.95, 0.85] 55 | div_factor: 10.0 56 | pct_start: 0.4 57 | max_iter: 133320 # 40 * 3333 58 | 59 | fixed_wd: true 60 | use_moving_average: false 61 | momentum: 0.9 62 | #weight_decay: 0.0001 63 | weight_decay: 0.01 # super converge. decrease this when you increase steps. 
64 | 65 | val_freq: 3333 # exact num of samples in 1 epoch with pp 66 | print_freq: 100 67 | 68 | train_root: ./kitti_t_o/training 69 | train_source: ./kitti_t_o/training/ 70 | train_link : ./data/train.txt 71 | train_det: ./data/RRC_Detections_mat/train 72 | 73 | val_root: ./kitti_t_o/training 74 | val_source: ./kitti_t_o/training/ 75 | val_link : ./data/val.txt 76 | val_det : ./data/RRC_Detections_mat/train 77 | 78 | test_root: ./kitti_t_o/testing 79 | test_source: ./kitti_t_o/testing/ 80 | test_link: ./data/test.txt 81 | test_det : ./data/RRC_Detections_mat/test 82 | -------------------------------------------------------------------------------- /experiments/rrc_pfv_40e_subabs_dualadd_C/eval.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u eval_seq.py --config $work_path/config.yaml \ 5 | --load-path=$work_path/ckpt_best.pth.tar \ 6 | --result-path=$work_path/results \ 7 | --result_sha=all \ 8 | 2>&1|tee $work_path/Eval-split-${now}.log 9 | -------------------------------------------------------------------------------- /experiments/rrc_pfv_40e_subabs_dualadd_C/test.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u test.py --config $work_path/config.yaml \ 5 | --result-path=$work_path/test_results \ 6 | --result_sha=test_mm \ 7 | --load-path=$work_path/ckpt_best.pth.tar \ 8 | 2>&1|tee $work_path/Test-img_${now}.log 9 | 10 | -------------------------------------------------------------------------------- /experiments/rrc_pfv_40e_subabs_dualadd_C/train.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | work_path=$(dirname $0) 3 | srun --mpi=pmi2 -p $1 -n1 --gres=gpu:1 --ntasks-per-node=1 \ 4 | python -u main.py --config $work_path/config.yaml \ 5 | --result-path=$work_path/results \ 6 | 2>&1|tee $work_path/T-${now}.log 7 | #--load-path=$work_path/ckpt.pth.tar \ 8 | #--recover 9 | -------------------------------------------------------------------------------- /kitti_devkit/README.md: -------------------------------------------------------------------------------- 1 | # KITTI Devkit 2 | 3 | This devkit includes the code to evaluate the MOT results. 4 | We borrow the original evaluation code provided by the [KITTI MOT Benchmark](http://www.cvlibs.net/datasets/kitti/eval_tracking.php), with only small modifications to the function API. 5 | This keeps our evaluation consistent with the official benchmark. 6 | 7 | The `seqmaps` from the original devkit are placed in the `data/tracking` directory.
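For reference, the snippet below is a minimal sketch of how `main.py` drives this devkit; it is not part of the devkit itself. The `'last'` tag and the `./results` path are placeholder values, and the result files are assumed to have already been written by `utils.data_util.write_kitti_result`.

```python
# Minimal usage sketch; the literal arguments here are illustrative only.
from kitti_devkit.evaluate_tracking import evaluate

# `step` tags the result sub-directory to score and `part` selects the
# 'train' or 'val' sequence split (mirroring the call made in main.py).
MOTA, MOTP, recall, prec, F1, fp, fn, id_switches = evaluate(
    'last',        # evaluation tag, e.g. the checkpoint step
    './results',   # result root, i.e. the --result-path argument
    part='val')
print('MOTA: {:.4f}, MOTP: {:.4f}, ID switches: {}'.format(MOTA, MOTP, id_switches))
```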
8 | -------------------------------------------------------------------------------- /kitti_devkit/mailpy.py: -------------------------------------------------------------------------------- 1 | class Mail: 2 | """ Dummy class to print messages without sending e-mails""" 3 | def __init__(self,mailaddress): 4 | pass 5 | def msg(self,msg): 6 | print(msg) 7 | def finalize(self,success,benchmark,sha_key,mailaddress=None): 8 | if success: 9 | print("Results for %s (benchmark: %s) sucessfully created" % (benchmark,sha_key)) 10 | else: 11 | print("Creating results for %s (benchmark: %s) failed" % (benchmark,sha_key)) 12 | 13 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pprint 5 | import time 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import yaml 11 | from easydict import EasyDict 12 | from kitti_devkit.evaluate_tracking import evaluate 13 | from tensorboardX import SummaryWriter 14 | from torch.utils.data import DataLoader 15 | # from models import model_entry 16 | from tracking_model import TrackingModule 17 | from utils.build_util import (build_augmentation, build_criterion, 18 | build_dataset, build_lr_scheduler, build_model, 19 | build_optim) 20 | from utils.data_util import write_kitti_result 21 | from utils.train_util import (AverageMeter, DistributedGivenIterationSampler, 22 | create_logger, load_state, save_checkpoint) 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch mmMOT Training') 25 | parser.add_argument('--config', default='cfgs/config_res50.yaml') 26 | parser.add_argument('--load-path', default='', type=str) 27 | parser.add_argument('--result-path', default='', type=str) 28 | parser.add_argument('--recover', action='store_true') 29 | parser.add_argument('-e', '--evaluate', action='store_true') 30 | parser.add_argument('--part', default='val', type=str) 31 | 32 | 33 | def main(): 34 | global args, config, best_mota 35 | args = parser.parse_args() 36 | 37 | with open(args.config) as f: 38 | config = yaml.load(f, Loader=yaml.FullLoader) 39 | 40 | config = EasyDict(config['common']) 41 | config.save_path = os.path.dirname(args.config) 42 | 43 | # create model 44 | model = build_model(config) 45 | model.cuda() 46 | 47 | optimizer = build_optim(model, config) 48 | 49 | criterion = build_criterion(config.loss) 50 | 51 | # optionally resume from a checkpoint 52 | last_iter = -1 53 | best_mota = 0 54 | if args.load_path: 55 | if args.recover: 56 | best_mota, last_iter = load_state( 57 | args.load_path, model, optimizer=optimizer) 58 | else: 59 | load_state(args.load_path, model) 60 | 61 | cudnn.benchmark = True 62 | 63 | # Data loading code 64 | train_transform, valid_transform = build_augmentation(config.augmentation) 65 | 66 | # train 67 | train_dataset = build_dataset( 68 | config, 69 | set_source='train', 70 | evaluate=False, 71 | train_transform=train_transform) 72 | trainval_dataset = build_dataset( 73 | config, 74 | set_source='train', 75 | evaluate=True, 76 | valid_transform=valid_transform) 77 | val_dataset = build_dataset( 78 | config, 79 | set_source='val', 80 | evaluate=True, 81 | valid_transform=valid_transform) 82 | 83 | train_sampler = DistributedGivenIterationSampler( 84 | train_dataset, 85 | config.lr_scheduler.max_iter, 86 | config.batch_size, 87 | world_size=1, 88 | rank=0, 89 | last_iter=last_iter) 90 | 91 | train_loader 
= DataLoader( 92 | train_dataset, 93 | batch_size=config.batch_size, 94 | shuffle=False, 95 | num_workers=config.workers, 96 | pin_memory=True, 97 | sampler=train_sampler) 98 | 99 | lr_scheduler = build_lr_scheduler(config.lr_scheduler, optimizer) 100 | 101 | tb_logger = SummaryWriter(config.save_path + '/events') 102 | logger = create_logger('global_logger', config.save_path + '/log.txt') 103 | logger.info('args: {}'.format(pprint.pformat(args))) 104 | logger.info('config: {}'.format(pprint.pformat(config))) 105 | 106 | tracking_module = TrackingModule(model, optimizer, criterion, 107 | config.det_type) 108 | if args.evaluate: 109 | logger.info('Evaluation on traing set:') 110 | validate(trainval_dataset, tracking_module, "last", part='train') 111 | logger.info('Evaluation on validation set:') 112 | validate(val_dataset, tracking_module, "last", part='val') 113 | return 114 | train(train_loader, val_dataset, trainval_dataset, tracking_module, 115 | lr_scheduler, last_iter + 1, tb_logger) 116 | 117 | 118 | def train(train_loader, val_loader, trainval_loader, tracking_module, 119 | lr_scheduler, start_iter, tb_logger): 120 | 121 | global best_mota 122 | 123 | batch_time = AverageMeter(config.print_freq) 124 | data_time = AverageMeter(config.print_freq) 125 | losses = AverageMeter(config.print_freq) 126 | 127 | # switch to train mode 128 | tracking_module.model.train() 129 | 130 | logger = logging.getLogger('global_logger') 131 | 132 | end = time.time() 133 | 134 | for i, (input, det_info, det_id, det_cls, 135 | det_split) in enumerate(train_loader): 136 | curr_step = start_iter + i 137 | # measure data loading time 138 | if lr_scheduler is not None: 139 | lr_scheduler.step(curr_step) 140 | current_lr = lr_scheduler.get_lr() 141 | data_time.update(time.time() - end) 142 | # transfer input to gpu 143 | input = input.cuda() 144 | if len(det_info) > 0: 145 | for k, v in det_info.items(): 146 | det_info[k] = det_info[k].cuda() if not isinstance( 147 | det_info[k], list) else det_info[k] 148 | # forward 149 | loss = tracking_module.step( 150 | input.squeeze(0), det_info, det_id, det_cls, det_split) 151 | 152 | # measure elapsed time 153 | batch_time.update(time.time() - end) 154 | losses.update(loss.item()) 155 | if (curr_step + 1) % config.print_freq == 0: 156 | tb_logger.add_scalar('loss_train', losses.avg, curr_step) 157 | logger.info('Iter: [{0}/{1}]\t' 158 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 159 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 160 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format( 161 | curr_step + 1, 162 | len(train_loader), 163 | batch_time=batch_time, 164 | data_time=data_time, 165 | loss=losses)) 166 | 167 | if curr_step > 0 and (curr_step + 1) % config.val_freq == 0: 168 | logger.info('Evaluation on validation set:') 169 | MOTA, MOTP, recall, prec, F1, fp, fn, id_switches = validate( 170 | val_loader, 171 | tracking_module, 172 | str(curr_step + 1), 173 | part=args.part) 174 | if tb_logger is not None: 175 | tb_logger.add_scalar('prec', prec, curr_step) 176 | tb_logger.add_scalar('recall', recall, curr_step) 177 | tb_logger.add_scalar('mota', MOTA, curr_step) 178 | tb_logger.add_scalar('motp', MOTP, curr_step) 179 | tb_logger.add_scalar('fp', fp, curr_step) 180 | tb_logger.add_scalar('fn', fn, curr_step) 181 | tb_logger.add_scalar('f1', F1, curr_step) 182 | tb_logger.add_scalar('id_switches', id_switches, curr_step) 183 | if lr_scheduler is not None: 184 | tb_logger.add_scalar('lr', current_lr, curr_step) 185 | 186 | # remember best mota and save 
checkpoint 187 | is_best = MOTA > best_mota 188 | best_mota = max(MOTA, best_mota) 189 | 190 | save_checkpoint( 191 | { 192 | 'step': curr_step, 193 | 'score_arch': config.model.score_arch, 194 | 'appear_arch': config.model.appear_arch, 195 | 'best_mota': best_mota, 196 | 'state_dict': tracking_module.model.state_dict(), 197 | 'optimizer': tracking_module.optimizer.state_dict(), 198 | }, is_best, config.save_path + '/ckpt') 199 | 200 | end = time.time() 201 | 202 | 203 | def validate(val_loader, 204 | tracking_module, 205 | step, 206 | part='train', 207 | fusion_list=None, 208 | fuse_prob=False): 209 | 210 | logger = logging.getLogger('global_logger') 211 | for i, (sequence) in enumerate(val_loader): 212 | logger.info('Test: [{}/{}]\tSequence ID: KITTI-{}'.format( 213 | i, len(val_loader), sequence.name)) 214 | seq_loader = DataLoader( 215 | sequence, 216 | batch_size=config.batch_size, 217 | shuffle=False, 218 | num_workers=config.workers, 219 | pin_memory=True) 220 | if len(seq_loader) == 0: 221 | tracking_module.eval() 222 | logger.info('Empty Sequence ID: KITTI-{}, skip'.format( 223 | sequence.name)) 224 | else: 225 | validate_seq(seq_loader, tracking_module) 226 | 227 | write_kitti_result( 228 | args.result_path, 229 | sequence.name, 230 | step, 231 | tracking_module.frames_id, 232 | tracking_module.frames_det, 233 | part=part) 234 | MOTA, MOTP, recall, prec, F1, fp, fn, id_switches = evaluate( 235 | step, args.result_path, part=part) 236 | 237 | tracking_module.train() 238 | return MOTA, MOTP, recall, prec, F1, fp, fn, id_switches 239 | 240 | 241 | def validate_seq(val_loader, 242 | tracking_module, 243 | fusion_list=None, 244 | fuse_prob=False): 245 | batch_time = AverageMeter(0) 246 | 247 | # switch to evaluate mode 248 | tracking_module.eval() 249 | 250 | logger = logging.getLogger('global_logger') 251 | end = time.time() 252 | 253 | with torch.no_grad(): 254 | for i, (input, det_info, dets, det_split) in enumerate(val_loader): 255 | input = input.cuda() 256 | if len(det_info) > 0: 257 | for k, v in det_info.items(): 258 | det_info[k] = det_info[k].cuda() if not isinstance( 259 | det_info[k], list) else det_info[k] 260 | 261 | # compute output 262 | aligned_ids, aligned_dets, frame_start = tracking_module.predict( 263 | input[0], det_info, dets, det_split) 264 | 265 | batch_time.update(time.time() - end) 266 | end = time.time() 267 | if i % config.print_freq == 0: 268 | logger.info( 269 | 'Test Frame: [{0}/{1}]\tTime ' 270 | '{batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 271 | i, len(val_loader), batch_time=batch_time)) 272 | 273 | 274 | if __name__ == '__main__': 275 | main() 276 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .appear_net import * # noqa 2 | from .fusion_net import * # noqa 3 | from .gcn import * # noqa 4 | from .new_end import * # noqa 5 | from .point_net import * # noqa 6 | from .score_net import * # noqa 7 | from .tracking_net import TrackingNet # noqa 8 | -------------------------------------------------------------------------------- /modules/appear_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | from .dropblock import DropBlock2D 6 | from .vgg import vgg16_bn_128, vgg16_bn_256, vgg16_bn_512 # noqa 7 | 8 | 9 | class SkipPool(nn.Module): 10 | 11 | def __init__(self, channels, reduction, 
out_channels, dropblock_size=5): 12 | super(SkipPool, self).__init__() 13 | self.channels = channels 14 | self.reduction = reduction 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | self.dropblock = None 17 | if dropblock_size: 18 | self.dropblock = DropBlock2D(block_size=dropblock_size) 19 | self.fc = nn.Sequential( 20 | nn.GroupNorm(1, channels), 21 | nn.Conv2d(channels, max(channels // reduction, 64), 1, 1), 22 | nn.GroupNorm(1, max(channels // reduction, 64)), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(max(channels // reduction, 64), out_channels, 1, 1), 25 | nn.GroupNorm(1, out_channels), nn.ReLU(inplace=True)) 26 | 27 | def forward(self, x): 28 | if self.dropblock is not None: 29 | x = self.dropblock(x) 30 | out = self.avg_pool(x) 31 | out = self.fc(out).view((x.size(0), -1)) 32 | return out 33 | 34 | 35 | class AppearanceNet(nn.Module): 36 | 37 | def __init__(self, 38 | arch='vgg', 39 | out_channels=512, 40 | skippool=True, 41 | fpn=False, 42 | dropblock=5): 43 | super(AppearanceNet, self).__init__() 44 | self.arch = arch 45 | self.skippool = skippool 46 | self.fpn = fpn 47 | self.out_channels = out_channels 48 | self.dropblock = dropblock 49 | reduction = 512 // out_channels 50 | assert not (skippool and fpn) 51 | 52 | if arch == 'vgg': 53 | base_channel = 64 // reduction 54 | vgg_net = eval("vgg16_bn_%s" % str(out_channels)) 55 | loaded_model = vgg_net() 56 | if skippool: 57 | print("use Skip Pooling in appearance model") 58 | self.layers, self.global_pool = self._parse_vgg_layers( 59 | loaded_model) 60 | elif arch == 'resnet50': 61 | loaded_model = torchvision.models.resnet50(pretrained=True) 62 | base_channel = 256 63 | self.layers = Resnet(loaded_model) 64 | if skippool: 65 | print("use Skip Pooling in appearance model") 66 | self.global_pool = self._parse_res_layers(4) 67 | elif arch == 'resnet101': 68 | print("use resnet101") 69 | loaded_model = torchvision.models.resnet101(pretrained=True) 70 | base_channel = 256 71 | self.layers = Resnet(loaded_model) 72 | if skippool: 73 | print("use Skip Pooling in appearance model") 74 | self.global_pool = self._parse_res_layers(4) 75 | elif arch == 'resnet152': 76 | print("use resnet152") 77 | loaded_model = torchvision.models.resnet152(pretrained=True) 78 | base_channel = 256 79 | self.layers = Resnet(loaded_model) 80 | if skippool: 81 | print("use Skip Pooling in appearance model") 82 | self.global_pool = self._parse_res_layers(4) 83 | if fpn: 84 | print("use FPN in appearance model") 85 | # FPN Module 86 | self.fpn_in = [] 87 | fpn_inplanes = (base_channel, base_channel * 2, base_channel * 4, 88 | base_channel * 8) 89 | for fpn_inplane in fpn_inplanes: # skip the top layer 90 | self.fpn_in.append( 91 | nn.Sequential( 92 | nn.Conv2d( 93 | fpn_inplane, 94 | out_channels, 95 | kernel_size=1, 96 | bias=False), nn.BatchNorm2d(out_channels), 97 | nn.ReLU(inplace=True))) 98 | self.fpn_in = nn.ModuleList(self.fpn_in) 99 | self.conv_last = nn.Sequential( 100 | nn.AdaptiveAvgPool2d(1), 101 | nn.Conv2d(out_channels, out_channels, 1, 1), 102 | nn.BatchNorm2d(out_channels), 103 | nn.ReLU(inplace=True), 104 | ) 105 | 106 | if not skippool and not fpn: 107 | self.conv_last = nn.Sequential( 108 | nn.AdaptiveAvgPool2d(1), 109 | nn.Conv2d(base_channel * 8, out_channels, 1, 1), 110 | nn.BatchNorm2d(out_channels), 111 | nn.ReLU(inplace=True), 112 | ) 113 | 114 | def _parse_res_layers(self, layout): 115 | pool_layers = [] 116 | base_channels = 256 117 | block_size = 0 118 | for i in range(layout): # layout range from 0-3 119 | if i > 1 and 
self.dropblock: 120 | block_size = self.dropblock 121 | pool_layers.append( 122 | self._make_scalar_layer( 123 | base_channels * pow(2, i), 124 | 4, 125 | self.out_channels // 4, 126 | dropblock_size=block_size)) 127 | 128 | return nn.ModuleList(pool_layers) 129 | 130 | def _parse_vgg_layers(self, loaded_model): 131 | layers = [] 132 | blocks = [] 133 | pool_layers = [] 134 | channels = 0 135 | first = 0 136 | block_size = 0 137 | for m in loaded_model.features.children(): 138 | blocks.append(m) 139 | if isinstance(m, nn.MaxPool2d): 140 | first += 1 141 | if first == 1: 142 | continue 143 | if first > 3 and self.dropblock: 144 | block_size = self.dropblock 145 | layers.append(nn.Sequential(*blocks)) 146 | blocks = [] 147 | pool_layers.append( 148 | self._make_scalar_layer( 149 | channels, 150 | 4, 151 | self.out_channels // 4, 152 | dropblock_size=block_size)) 153 | 154 | elif isinstance(m, nn.Conv2d): 155 | channels = m.out_channels 156 | 157 | return nn.ModuleList(layers), nn.ModuleList(pool_layers) 158 | 159 | def _make_scalar_layer(self, 160 | channels, 161 | reduction, 162 | out_channels, 163 | dropblock_size=5): 164 | return SkipPool(channels, reduction, out_channels, dropblock_size) 165 | 166 | def vgg_forward(self, x): 167 | pool_out = [] 168 | for layer in self.layers: 169 | x = layer(x) 170 | pool_out.append(x) 171 | 172 | return pool_out 173 | 174 | def res_forward(self, x): 175 | feats = self.layers(x, return_feature_maps=True) 176 | return feats 177 | 178 | def forward(self, x): 179 | if self.arch == 'vgg': 180 | feats = self.vgg_forward(x) 181 | else: 182 | feats = self.res_forward(x) 183 | 184 | if self.skippool: 185 | pool_out = [] 186 | for layer, feat in zip(self.global_pool, feats): 187 | pool_out.append(layer(feat)) 188 | 189 | out = torch.cat(pool_out, dim=-1) 190 | return out 191 | 192 | if self.fpn: 193 | out = feats[-1] 194 | out = self.fpn_in[-1](out) # last output 195 | 196 | for i in reversed(range(len(feats) - 1)): 197 | conv_x = feats[i] 198 | conv_x = self.fpn_in[i](conv_x) # lateral branch 199 | 200 | out = nn.functional.interpolate( 201 | out, 202 | size=conv_x.size()[2:], 203 | mode='bilinear', 204 | align_corners=False) # top-down branch 205 | out = conv_x + out 206 | else: 207 | out = feats[-1] 208 | 209 | out = self.conv_last(out).squeeze(-1).squeeze(-1) # NxCx1x1 -> N*C 210 | return out 211 | 212 | 213 | class Resnet(nn.Module): 214 | 215 | def __init__(self, orig_resnet, use_dropblock=False): 216 | super(Resnet, self).__init__() 217 | 218 | # take pretrained resnet, except AvgPool and FC 219 | self.conv1 = orig_resnet.conv1 220 | self.bn1 = orig_resnet.bn1 221 | self.relu1 = orig_resnet.relu 222 | self.maxpool = orig_resnet.maxpool 223 | self.layer1 = orig_resnet.layer1 224 | self.layer2 = orig_resnet.layer2 225 | self.layer3 = orig_resnet.layer3 226 | self.layer4 = orig_resnet.layer4 227 | 228 | self.dropblock = None 229 | if use_dropblock: 230 | print("Apply dropblock after group 3 & 4") 231 | self.dropblock = DropBlock2D(block_size=5) 232 | 233 | def forward(self, x, return_feature_maps=False): 234 | conv_out = [] 235 | 236 | x = self.relu1(self.bn1(self.conv1(x))) 237 | x = self.maxpool(x) 238 | 239 | x = self.layer1(x) 240 | conv_out.append(x) 241 | x = self.layer2(x) 242 | conv_out.append(x) 243 | 244 | x = self.layer3(x) 245 | if self.dropblock is not None: 246 | x = self.dropblock(x) 247 | conv_out.append(x) 248 | 249 | x = self.layer4(x) 250 | if self.dropblock is not None: 251 | x = self.dropblock(x) 252 | conv_out.append(x) 253 | 254 | if 
return_feature_maps: 255 | return conv_out 256 | return [x] 257 | -------------------------------------------------------------------------------- /modules/dropblock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class DropBlock2D(nn.Module): 7 | r"""Randomly zeroes 2D spatial blocks of the input tensor. 8 | As described in the paper 9 | `DropBlock: A regularization method for convolutional networks`_ , 10 | dropping whole blocks of feature map allows to remove semantic 11 | information as compared to regular dropout. 12 | Args: 13 | drop_prob (float): probability of an element to be dropped. 14 | block_size (int): size of the block to drop 15 | Shape: 16 | - Input: `(N, C, H, W)` 17 | - Output: `(N, C, H, W)` 18 | .. _DropBlock: A regularization method for convolutional networks: 19 | https://arxiv.org/abs/1810.12890 20 | """ 21 | 22 | def __init__(self, drop_prob=0.1, block_size=7): 23 | super(DropBlock2D, self).__init__() 24 | 25 | self.drop_prob = drop_prob 26 | self.block_size = block_size 27 | 28 | def forward(self, x): 29 | # shape: (bsize, channels, height, width) 30 | 31 | assert x.dim() == 4, \ 32 | "Expected input with 4 dimensions (bsize, channels, height, width)" 33 | 34 | if not self.training or self.drop_prob == 0.: 35 | return x 36 | else: 37 | # get gamma value 38 | gamma = self._compute_gamma(x) 39 | 40 | # sample mask 41 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() 42 | 43 | # place mask on input device 44 | mask = mask.to(x.device) 45 | 46 | # compute block mask 47 | block_mask = self._compute_block_mask(mask) 48 | 49 | # apply block mask 50 | out = x * block_mask[:, None, :, :] 51 | 52 | # scale output 53 | out = out * block_mask.numel() / block_mask.sum() 54 | 55 | return out 56 | 57 | def _compute_block_mask(self, mask): 58 | block_mask = F.max_pool2d( 59 | input=mask[:, None, :, :], 60 | kernel_size=(self.block_size, self.block_size), 61 | stride=(1, 1), 62 | padding=self.block_size // 2) 63 | 64 | if self.block_size % 2 == 0: 65 | block_mask = block_mask[:, :, :-1, :-1] 66 | 67 | block_mask = 1 - block_mask.squeeze(1) 68 | 69 | return block_mask 70 | 71 | def _compute_gamma(self, x): 72 | return self.drop_prob / (self.block_size**2) 73 | 74 | def extra_repr(self): 75 | return 'drop_prob={drop_prob}, block_size={block_size}'.format( 76 | **self.__dict__) 77 | 78 | 79 | class DropBlock3D(DropBlock2D): 80 | r"""Randomly zeroes 3D spatial blocks of the input tensor. 81 | An extension to the concept described in the paper 82 | `DropBlock: A regularization method for convolutional networks`_ , 83 | dropping whole blocks of feature map allows to remove semantic 84 | information as compared to regular dropout. 85 | Args: 86 | drop_prob (float): probability of an element to be dropped. 87 | block_size (int): size of the block to drop 88 | Shape: 89 | - Input: `(N, C, D, H, W)` 90 | - Output: `(N, C, D, H, W)` 91 | .. 
_DropBlock: A regularization method for convolutional networks: 92 | https://arxiv.org/abs/1810.12890 93 | """ 94 | 95 | def __init__(self, drop_prob, block_size): 96 | super(DropBlock3D, self).__init__(drop_prob, block_size) 97 | 98 | def forward(self, x): 99 | # shape: (bsize, channels, depth, height, width) 100 | 101 | assert x.dim() == 5, \ 102 | "Expected input with 5 dimensions (bsize, channels, depth, height, width)" # noqa 103 | 104 | if not self.training or self.drop_prob == 0.: 105 | return x 106 | else: 107 | # get gamma value 108 | gamma = self._compute_gamma(x) 109 | 110 | # sample mask 111 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() 112 | 113 | # place mask on input device 114 | mask = mask.to(x.device) 115 | 116 | # compute block mask 117 | block_mask = self._compute_block_mask(mask) 118 | 119 | # apply block mask 120 | out = x * block_mask[:, None, :, :, :] 121 | 122 | # scale output 123 | out = out * block_mask.numel() / block_mask.sum() 124 | 125 | return out 126 | 127 | def _compute_block_mask(self, mask): 128 | block_mask = F.max_pool3d( 129 | input=mask[:, None, :, :, :], 130 | kernel_size=(self.block_size, self.block_size, self.block_size), 131 | stride=(1, 1, 1), 132 | padding=self.block_size // 2) 133 | 134 | if self.block_size % 2 == 0: 135 | block_mask = block_mask[:, :, :-1, :-1, :-1] 136 | 137 | block_mask = 1 - block_mask.squeeze(1) 138 | 139 | return block_mask 140 | 141 | def _compute_gamma(self, x): 142 | return self.drop_prob / (self.block_size**3) 143 | -------------------------------------------------------------------------------- /modules/fusion_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Common fusion module 6 | class fusion_module_C(nn.Module): 7 | 8 | def __init__(self, appear_len, point_len, out_channels): 9 | super(fusion_module_C, self).__init__() 10 | print( 11 | "Fusion Module C: split sigmoid weight gated point, image fusion") 12 | self.appear_len = appear_len 13 | self.point_len = point_len 14 | self.gate_p = nn.Sequential( 15 | nn.Conv1d(point_len, point_len, 1, 1), 16 | nn.Sigmoid(), 17 | ) 18 | self.gate_i = nn.Sequential( 19 | nn.Conv1d(appear_len, appear_len, 1, 1), 20 | nn.Sigmoid(), 21 | ) 22 | self.input_p = nn.Sequential( 23 | nn.Conv1d(point_len, out_channels, 1, 1), 24 | nn.GroupNorm(out_channels, out_channels), 25 | ) 26 | self.input_i = nn.Sequential( 27 | nn.Conv1d(appear_len, out_channels, 1, 1), 28 | nn.GroupNorm(out_channels, out_channels), 29 | ) 30 | 31 | def forward(self, objs): 32 | """ 33 | objs : 1xDxN 34 | """ 35 | feats = objs.view(2, -1, objs.size(-1)) # 1x2DxL -> 2xDxL 36 | gate_p = self.gate_p(feats[:1]) # 2xDxL 37 | gate_i = self.gate_i(feats[1:]) # 2xDxL 38 | obj_fused = gate_p.mul(self.input_p(feats[:1])) + gate_i.mul( 39 | self.input_i(feats[1:])) 40 | 41 | obj_feats = torch.cat([feats, obj_fused.div(gate_p + gate_i)], dim=0) 42 | return obj_feats 43 | 44 | 45 | class fusion_module_B(nn.Module): 46 | 47 | def __init__(self, appear_len, point_len, out_channels): 48 | super(fusion_module_B, self).__init__() 49 | print("Fusion Module B: point, weighted image" 50 | "& linear fusion, with split input w") 51 | self.appear_len = appear_len 52 | self.point_len = point_len 53 | self.input_p = nn.Sequential( 54 | nn.Conv1d(out_channels, out_channels, 1, 1), 55 | nn.GroupNorm(out_channels, out_channels), 56 | ) 57 | self.input_i = nn.Sequential( 58 | nn.Conv1d(out_channels, out_channels, 1, 1), 59 | 
nn.GroupNorm(out_channels, out_channels), 60 | ) 61 | 62 | def forward(self, objs): 63 | """ 64 | objs : 1xDxN 65 | """ 66 | 67 | feats = objs.view(2, -1, objs.size(-1)) # 1x2DxL -> 2xDxL 68 | obj_fused = self.input_p(feats[:1]) + self.input_i(feats[1:]) 69 | obj_feats = torch.cat([feats, obj_fused], dim=0) 70 | return obj_feats 71 | 72 | 73 | class fusion_module_A(nn.Module): 74 | 75 | def __init__(self, appear_len, point_len, out_channels): 76 | super(fusion_module_A, self).__init__() 77 | print("Fusion Module A: concatenate point, image & linear fusion") 78 | self.appear_len = appear_len 79 | self.point_len = point_len 80 | self.input_w = nn.Sequential( 81 | nn.Conv1d(out_channels * 2, out_channels, 1, 1), 82 | nn.GroupNorm(out_channels, out_channels), 83 | ) 84 | 85 | def forward(self, objs): 86 | """ 87 | objs : 1xDxN 88 | """ 89 | feats = objs.view(2, -1, objs.size(-1)) # 1x2DxL -> 2xDxL 90 | obj_fused = self.input_w(objs) # 1x2DxL -> 1xDxL 91 | obj_feats = torch.cat([feats, obj_fused], dim=0) 92 | return obj_feats 93 | -------------------------------------------------------------------------------- /modules/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Similarity function 6 | def batch_multiply(objs, dets): 7 | """ 8 | 9 | :param objs: BxDxN 10 | :param dets: BxDxM 11 | :return:BxDxNxM 12 | """ 13 | x = torch.einsum('bci,bcj->bcij', objs, dets) 14 | return x 15 | 16 | 17 | def batch_minus_abs(objs, dets): 18 | """ 19 | 20 | :param objs: BxDxN 21 | :param dets: BxDxM 22 | :return: Bx2dxNxM 23 | """ 24 | obj_mat = objs.unsqueeze(-1).repeat(1, 1, 1, dets.size(-1)) # BxDxNxM 25 | det_mat = dets.unsqueeze(-2).repeat(1, 1, objs.size(-1), 1) # BxDxNxM 26 | related_pos = (obj_mat - det_mat) / 2 # BxDxNxM 27 | x = related_pos.abs() # Bx2DxNxM 28 | return x 29 | 30 | 31 | def batch_minus(objs, dets): 32 | """ 33 | 34 | :param objs: BxDxN 35 | :param dets: BxDxM 36 | :return: Bx2dxNxM 37 | """ 38 | obj_mat = objs.unsqueeze(-1).repeat(1, 1, 1, dets.size(-1)) # BxDxNxM 39 | det_mat = dets.unsqueeze(-2).repeat(1, 1, objs.size(-1), 1) # BxDxNxM 40 | related_pos = (obj_mat - det_mat) / 2 # BxDxNxM 41 | return related_pos 42 | 43 | 44 | # GCN 45 | class affinity_module(nn.Module): 46 | 47 | def __init__(self, in_channels, new_end, affinity_op='multiply'): 48 | super(affinity_module, self).__init__() 49 | print(f"Use {affinity_op} similarity with fusion module") 50 | self.in_channels = in_channels 51 | expansion = 1 52 | 53 | if affinity_op in ['multiply', 'minus', 'minus_abs']: 54 | self.affinity = eval(f"batch_{affinity_op}") 55 | else: 56 | print("Not Implement!!") 57 | 58 | self.w_new_end = new_end(in_channels * expansion) 59 | self.conv1 = nn.Sequential( 60 | nn.Conv2d(in_channels * expansion, in_channels, 1, 1), 61 | nn.GroupNorm(in_channels, in_channels), nn.ReLU(inplace=True), 62 | nn.Conv2d(in_channels, in_channels, 1, 1), 63 | nn.GroupNorm(in_channels, in_channels), nn.ReLU(inplace=True), 64 | nn.Conv2d(in_channels, in_channels // 4, 1, 1), 65 | nn.GroupNorm(in_channels // 4, in_channels // 4), 66 | nn.ReLU(inplace=True), nn.Conv2d(in_channels // 4, 1, 1, 1)) 67 | 68 | def forward(self, objs, dets): 69 | """ 70 | objs : 1xDxN 71 | dets : 1xDxM 72 | obj_feats: 3xDxN 73 | det_feats: 3xDxN 74 | """ 75 | # if self.fusion_net is not None: 76 | # objs = self.fusion_net(objs) 77 | # dets = self.fusion_net(dets) 78 | x = self.affinity(objs, dets) 79 | new_score, end_score = self.w_new_end(x) 80 | out 
= self.conv1(x) 81 | 82 | return out, new_score, end_score 83 | -------------------------------------------------------------------------------- /modules/ghm_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Mopdified implementation of GHM-C and GHM-R losses. 3 | Details can be found in the paper `Gradient Harmonized Single-stage Detector`: 4 | https://arxiv.org/abs/1811.05181 5 | Copyright (c) 2018 Multimedia Laboratory, CUHK. 6 | Licensed under the MIT License (see LICENSE for details) 7 | Written by Buyu Li 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class GHMC_Loss(nn.Module): 16 | 17 | def __init__(self, bins=10, momentum=0): 18 | super(GHMC_Loss, self).__init__() 19 | # mmt=0.75 in experiments 20 | # bins=30, not sensitive to class 21 | self.bins = bins 22 | self.momentum = momentum 23 | self.edges = [float(x) / bins for x in range(bins + 1)] 24 | self.edges[-1] += 1e-6 25 | if momentum > 0: 26 | self.acc_sum = [0.0 for _ in range(bins)] 27 | 28 | def forward(self, input, target, mask): 29 | """ Args: 30 | input [batch_num, class_num]: 31 | The direct prediction of classification fc layer. 32 | target [batch_num, class_num]: 33 | Binary target (0 or 1) for each sample each class. The value is -1 34 | when the sample is ignored. 35 | """ 36 | edges = self.edges 37 | mmt = self.momentum 38 | weights = torch.zeros_like(input) 39 | # gradient length 40 | g = torch.abs(input.sigmoid().detach() - target) 41 | 42 | valid = mask > 0 43 | tot = max(valid.float().sum().item(), 1.0) 44 | n = 0 # n valid bins 45 | for i in range(self.bins): 46 | inds = (g >= edges[i]) & (g < edges[i + 1]) & valid 47 | num_in_bin = inds.sum().item() 48 | if num_in_bin > 0: 49 | if mmt > 0: 50 | self.acc_sum[i] = mmt * self.acc_sum[i] \ 51 | + (1 - mmt) * num_in_bin 52 | weights[inds] = tot / self.acc_sum[i] 53 | else: 54 | weights[inds] = tot / num_in_bin 55 | n += 1 56 | if n > 0: 57 | weights = weights / n 58 | 59 | loss = F.binary_cross_entropy_with_logits( 60 | input, target, weights, reduction='sum') / tot 61 | return loss 62 | 63 | 64 | class GHMR_Loss(nn.Module): 65 | 66 | def __init__(self, mu=0.02, bins=10, momentum=0): 67 | super(GHMR_Loss, self).__init__() 68 | # momentum=0.5 for general 69 | # bin=30 70 | self.mu = mu 71 | self.bins = bins 72 | self.edges = [float(x) / bins for x in range(bins + 1)] 73 | self.edges[-1] = 1e3 74 | self.momentum = momentum 75 | if momentum > 0: 76 | self.acc_sum = [0.0 for _ in range(bins)] 77 | 78 | def forward(self, input, target, mask): 79 | """ Args: 80 | input [batch_num, 4 (* class_num)]: 81 | The prediction of box regression layer. Channel number can be 4 or 82 | (4 * class_num) depending on whether it is class-agnostic. 83 | target [batch_num, 4 (* class_num)]: 84 | The target regression values with the same size of input. 
85 | """ 86 | mu = self.mu 87 | edges = self.edges 88 | mmt = self.momentum 89 | 90 | # ASL1 loss 91 | diff = input - target 92 | loss = torch.sqrt(diff * diff + mu * mu) - mu 93 | 94 | # gradient length 95 | g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach() 96 | weights = torch.zeros_like(g) 97 | 98 | valid = mask > 0 99 | tot = max(mask.float().sum().item(), 1.0) 100 | n = 0 # n: valid bins 101 | for i in range(self.bins): 102 | inds = (g >= edges[i]) & (g < edges[i + 1]) & valid 103 | num_in_bin = inds.sum().item() 104 | if num_in_bin > 0: 105 | n += 1 106 | if mmt > 0: 107 | self.acc_sum[i] = mmt * self.acc_sum[i] \ 108 | + (1 - mmt) * num_in_bin 109 | weights[inds] = tot / self.acc_sum[i] 110 | else: 111 | weights[inds] = tot / num_in_bin 112 | if n > 0: 113 | weights /= n 114 | 115 | loss = loss * weights 116 | loss = loss.sum() / tot 117 | return loss 118 | -------------------------------------------------------------------------------- /modules/new_end.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class NewEndIndicator_v1(nn.Module): 6 | 7 | def __init__(self, in_channels, kernel_size, reduction, mode='avg'): 8 | super(NewEndIndicator_v1, self).__init__() 9 | self.mode = mode 10 | self.w_end_conv = nn.Sequential( 11 | nn.GroupNorm(1, in_channels), 12 | nn.Conv2d(in_channels, in_channels // reduction, 1, 1), 13 | nn.GroupNorm(1, in_channels // reduction), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(in_channels // reduction, 1, 1, 1), 16 | ) 17 | self.w_new_conv = nn.Sequential( 18 | nn.GroupNorm(1, in_channels), 19 | nn.Conv2d(in_channels, in_channels // reduction, 1, 1), 20 | nn.GroupNorm(1, in_channels // reduction), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(in_channels // reduction, 1, 1, 1), 23 | ) 24 | 25 | def forward(self, x): 26 | """ 27 | x: 1xCxNxM 28 | w_new: Mx1 29 | w_end: Nx1 30 | """ 31 | if self.mode == 'avg': 32 | new_vec = F.adaptive_avg_pool2d(x, (1, x.size(-1))) 33 | end_vec = F.adaptive_avg_pool2d(x, (x.size(-2), 1)) 34 | else: 35 | new_vec = F.adaptive_max_pool2d(x, (1, x.size(-1))) 36 | end_vec = F.adaptive_max_pool2d(x, (x.size(-2), 1)) 37 | w_new = 1 - self.w_new_conv(new_vec).view((new_vec.size(-1), -1)) 38 | w_end = 1 - self.w_end_conv(end_vec).view((end_vec.size(-2), -1)) 39 | 40 | return w_new, w_end 41 | 42 | 43 | class NewEndIndicator_v2(nn.Module): 44 | 45 | def __init__(self, in_channels, kernel_size, reduction, mode='avg'): 46 | super(NewEndIndicator_v2, self).__init__() 47 | self.mode = mode 48 | self.conv0 = nn.Sequential( 49 | nn.Conv2d(in_channels, in_channels, 1, 1), 50 | nn.GroupNorm(1, in_channels), 51 | nn.ReLU(inplace=True), 52 | ) 53 | self.conv1 = nn.Sequential( 54 | nn.Conv1d(in_channels, min(in_channels, 512), 1, 1), 55 | nn.GroupNorm(1, min(in_channels, 512)), nn.ReLU(inplace=True), 56 | nn.Conv1d(min(in_channels, 512), in_channels // reduction, 1, 1), 57 | nn.GroupNorm(1, in_channels // reduction), nn.ReLU(inplace=True), 58 | nn.Conv1d(in_channels // reduction, 1, 1, 1), nn.Sigmoid()) 59 | print(f"End version V2 by {mode}") 60 | print(self) 61 | 62 | def forward(self, x): 63 | """ 64 | x: BxCxNxM 65 | w_new: BxM 66 | w_end: BxN 67 | """ 68 | x = self.conv0(x) 69 | if self.mode == 'avg': 70 | new_vec = x.mean(dim=-2, keepdim=False) # 1xCxM 71 | end_vec = x.mean(dim=-1, keepdim=False) # 1xCxN 72 | else: 73 | new_vec = x.max(dim=-2, keepdim=False)[0] # 1xCxM 74 | end_vec = x.max(dim=-1, keepdim=False)[0] # 1xCxN 
75 | w_new = self.conv1(new_vec).squeeze(1) # BxCxM->Bx1xM->BxM 76 | w_end = self.conv1(end_vec).squeeze(1) # BxCxN->Bx1xN->BxN 77 | 78 | # if not self.training: 79 | # new_mask = w_new.gt(0.9).float()+0.05 80 | # end_mask = w_end.gt(0.9).float()+0.05 81 | # return new_mask, end_mask 82 | return w_new, w_end 83 | -------------------------------------------------------------------------------- /modules/point_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PointNet_v1(nn.Module): 6 | 7 | def __init__(self, in_channels, out_channels=512, use_dropout=False): 8 | super(PointNet_v1, self).__init__() 9 | self.feat = PointNetfeatGN(in_channels, out_channels) 10 | reduction = 512 // out_channels 11 | self.reduction = reduction 12 | self.conv1 = torch.nn.Conv1d(1088 // reduction, 512 // reduction, 1) 13 | self.conv2 = torch.nn.Conv1d(512 // reduction, out_channels, 1) 14 | self.bn1 = nn.GroupNorm(512 // reduction, 512 // reduction) 15 | self.bn2 = nn.GroupNorm(16 // reduction, out_channels) 16 | self.out_channels = out_channels 17 | self.relu = nn.ReLU(inplace=True) 18 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 19 | self.avg_bn = nn.GroupNorm(512 // reduction, 512 // reduction) 20 | self.dropout = None 21 | if use_dropout: 22 | print("Use dropout in pointnet") 23 | self.dropout = nn.Dropout(p=0.5) 24 | 25 | def forward(self, x, point_split): 26 | x, trans = self.feat(x, point_split) 27 | x = torch.cat(x, dim=1) 28 | x = self.relu(self.bn1(self.conv1(x))) 29 | if self.dropout is not None: 30 | x = self.dropout(x) 31 | 32 | max_feats = [] 33 | for i in range(len(point_split) - 1): 34 | start = point_split[i].item() 35 | end = point_split[i + 1].item() 36 | max_feat = self.avg_pool(x[:, :, start:end]) 37 | max_feats.append(max_feat.view(-1, 512 // self.reduction, 1)) 38 | 39 | max_feats = torch.cat(max_feats, dim=-1) 40 | out = self.relu(self.bn2(self.conv2(max_feats))).transpose( 41 | -1, -2).squeeze(0) 42 | assert out.size(0) == len(point_split) - 1 43 | 44 | return out, trans 45 | 46 | 47 | class STN3d(nn.Module): 48 | 49 | def __init__(self, in_channels, out_size=3, feature_channels=512): 50 | super(STN3d, self).__init__() 51 | reduction = 512 // feature_channels 52 | self.reduction = reduction 53 | self.out_size = out_size 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv1 = nn.Conv1d(in_channels, 64 // reduction, 1) 56 | self.bn1 = nn.GroupNorm(64 // reduction, 64 // reduction) 57 | self.conv2 = nn.Conv1d(64 // reduction, 128 // reduction, 1) 58 | self.bn2 = nn.GroupNorm(128 // reduction, 128 // reduction) 59 | self.conv3 = nn.Conv1d(128 // reduction, 1024 // reduction, 1) 60 | self.bn3 = nn.GroupNorm(1024 // reduction, 1024 // reduction) 61 | self.idt = nn.Parameter(torch.eye(self.out_size), requires_grad=False) 62 | 63 | self.fc1 = nn.Linear(1024 // reduction, 512 // reduction) 64 | self.fc_bn1 = nn.GroupNorm(512 // reduction, 512 // reduction) 65 | self.fc2 = nn.Linear(512 // reduction, 256 // reduction) 66 | self.fc_bn2 = nn.GroupNorm(256 // reduction, 256 // reduction) 67 | 68 | self.output = nn.Linear(256 // reduction, out_size * out_size) 69 | nn.init.constant_(self.output.weight.data, 0) 70 | nn.init.constant_(self.output.bias.data, 0) 71 | 72 | def forward(self, x): 73 | x = self.relu(self.bn1(self.conv1(x))) 74 | x = self.relu(self.bn2(self.conv2(x))) 75 | x = self.relu(self.bn3(self.conv3(x))) 76 | x = torch.max(x, -1, keepdim=True)[0] 77 | x = x.view(-1, 1024 // self.reduction) 78 
| 79 | x = self.relu(self.fc_bn1(self.fc1(x))) 80 | x = self.relu(self.fc_bn2(self.fc2(x))) 81 | x = self.output(x).view(-1, self.out_size, self.out_size) 82 | 83 | x = x + self.idt 84 | # idt = x.new_tensor(torch.eye(self.out_size)) 85 | # x = x + idt 86 | return x 87 | 88 | 89 | class PointNetfeatGN(nn.Module): 90 | 91 | def __init__(self, in_channels=3, out_channels=512, global_feat=True): 92 | super(PointNetfeatGN, self).__init__() 93 | self.relu = nn.ReLU(inplace=True) 94 | self.stn1 = STN3d(in_channels, in_channels, out_channels) 95 | reduction = 512 // out_channels 96 | self.reduction = reduction 97 | self.conv1 = nn.Conv1d(in_channels, 64 // reduction, 1) 98 | self.bn1 = nn.GroupNorm(64 // reduction, 64 // reduction) 99 | 100 | self.conv2 = nn.Conv1d(64 // reduction, 64 // reduction, 1) 101 | self.bn2 = nn.GroupNorm(64 // reduction, 64 // reduction) 102 | self.stn2 = STN3d(64 // reduction, 64 // reduction, out_channels) 103 | 104 | self.conv3 = nn.Conv1d(64 // reduction, 64 // reduction, 1) 105 | self.bn3 = nn.GroupNorm(64 // reduction, 64 // reduction) 106 | 107 | self.conv4 = nn.Conv1d(64 // reduction, 128 // reduction, 1) 108 | self.bn4 = nn.GroupNorm(128 // reduction, 128 // reduction) 109 | self.conv5 = nn.Conv1d(128 // reduction, 1024 // reduction, 1) 110 | self.bn5 = nn.GroupNorm(1024 // reduction, 1024 // reduction) 111 | self.global_feat = global_feat 112 | print("use avg in pointnet feat") 113 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 114 | 115 | def forward(self, x, point_split): 116 | conv_out = [] 117 | trans = [] 118 | 119 | trans1 = self.stn1(x) 120 | trans.append(trans1) 121 | x = x.transpose(2, 1) 122 | x = torch.bmm(x, trans1) 123 | x = x.transpose(2, 1) 124 | 125 | x = self.relu(self.bn1(self.conv1(x))) 126 | 127 | trans2 = self.stn2(x) 128 | trans.append(trans2) 129 | x = x.transpose(2, 1) 130 | x = torch.bmm(x, trans2) 131 | x = x.transpose(2, 1) 132 | conv_out.append(x) 133 | 134 | x = self.relu(self.bn2(self.conv2(x))) 135 | x = self.relu(self.bn3(self.conv3(x))) 136 | 137 | x = self.relu(self.bn4(self.conv4(x))) 138 | x = self.relu(self.bn5(self.conv5(x))) 139 | max_feats = [] 140 | for i in range(len(point_split) - 1): 141 | start = point_split[i].item() 142 | end = point_split[i + 1].item() 143 | max_feat = self.avg_pool(x[:, :, start:end]) 144 | max_feats.append( 145 | max_feat.view(-1, 1024 // self.reduction, 146 | 1).repeat(1, 1, end - start)) 147 | 148 | max_feats = torch.cat(max_feats, dim=-1) 149 | 150 | assert max_feats.size(-1) == x.size(-1) 151 | conv_out.append(max_feats) 152 | 153 | return conv_out, trans 154 | -------------------------------------------------------------------------------- /modules/score_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision 3 | 4 | 5 | class ScoringNet(nn.Module): 6 | 7 | def __init__(self, arch='resnet18'): 8 | super(ScoringNet, self).__init__() 9 | self.out_channels = 1 10 | if 'vgg' in arch: 11 | self.arch = 'vgg' 12 | self.features = torchvision.models.vgg16_bn( 13 | pretrained=True).features 14 | self.score_fc = nn.Sequential( 15 | nn.Linear(512 * 7 * 7, 4096), 16 | nn.ReLU(True), 17 | nn.Dropout(), 18 | nn.Linear(4096, 4096), 19 | nn.ReLU(True), 20 | nn.Dropout(), 21 | nn.Linear(4096, 1), 22 | ) 23 | elif "res" in arch: 24 | if 'resnet18' in arch: 25 | extension = 1 26 | resnet = torchvision.models.resnet18(pretrained=True) 27 | elif 'resnet50' in arch: 28 | extension = 4 29 | resnet = 
torchvision.models.resnet50(pretrained=True) 30 | self.arch = arch 31 | self.conv1 = resnet.conv1 32 | self.bn1 = resnet.bn1 33 | self.relu = resnet.relu 34 | self.maxpool = resnet.maxpool 35 | self.layer1 = resnet.layer1 36 | self.layer2 = resnet.layer2 37 | self.layer3 = resnet.layer3 38 | self.layer4 = resnet.layer4 39 | self.avgpool = nn.AvgPool2d(7, stride=1) 40 | self.fc = nn.Linear(512 * extension, 1) 41 | 42 | def forward(self, x): 43 | if 'vgg' in self.arch: 44 | x = self.features(x) 45 | x = x.view(x.size(0), -1) 46 | x = self.score_fc(x) 47 | return x 48 | else: 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | x = self.relu(x) 52 | x = self.maxpool(x) 53 | 54 | x = self.layer1(x) 55 | x = self.layer2(x) 56 | x = self.layer3(x) 57 | x = self.layer4(x) 58 | 59 | x = self.avgpool(x) 60 | x = x.view(x.size(0), -1) 61 | x = self.fc(x) 62 | return x.transpose(-1, -2) # 1xL 63 | -------------------------------------------------------------------------------- /modules/tracking_net.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .appear_net import AppearanceNet 8 | from .fusion_net import * # noqa 9 | from .gcn import affinity_module 10 | from .new_end import * # noqa 11 | from .point_net import * # noqa 12 | from .score_net import * # noqa 13 | 14 | 15 | class TrackingNet(nn.Module): 16 | 17 | def __init__(self, 18 | seq_len, 19 | appear_len=512, 20 | appear_skippool=False, 21 | appear_fpn=False, 22 | score_arch='vgg', 23 | score_fusion_arch='C', 24 | appear_arch='vgg', 25 | point_arch='v1', 26 | point_len=512, 27 | softmax_mode='single', 28 | test_mode=0, 29 | affinity_op='multiply', 30 | dropblock=5, 31 | end_arch='v2', 32 | end_mode='avg', 33 | without_reflectivity=True, 34 | neg_threshold=0, 35 | use_dropout=False): 36 | super(TrackingNet, self).__init__() 37 | self.seq_len = seq_len 38 | self.score_arch = score_arch 39 | self.neg_threshold = neg_threshold 40 | self.test_mode = test_mode # 0:image;1:image;2:fusion 41 | point_in_channels = 4 - int(without_reflectivity) 42 | 43 | if point_len == 0: 44 | in_channels = appear_len 45 | else: 46 | in_channels = point_len 47 | 48 | self.fusion_module = None 49 | fusion = eval(f"fusion_module_{score_fusion_arch}") 50 | self.fusion_module = fusion( 51 | appear_len, point_len, out_channels=point_len) 52 | 53 | if appear_len == 0: 54 | print('No image appearance used') 55 | self.appearance = None 56 | else: 57 | self.appearance = AppearanceNet( 58 | appear_arch, 59 | appear_len, 60 | skippool=appear_skippool, 61 | fpn=appear_fpn, 62 | dropblock=dropblock) 63 | 64 | # build new end indicator 65 | if end_arch in ['v1', 'v2']: 66 | new_end = partial( 67 | eval("NewEndIndicator_%s" % end_arch), 68 | kernel_size=5, 69 | reduction=4, 70 | mode=end_mode) 71 | 72 | # build point net 73 | if point_len == 0: 74 | print("No point cloud used") 75 | self.point_net = None 76 | elif point_arch in ['v1']: 77 | point_net = eval("PointNet_%s" % point_arch) 78 | self.point_net = point_net( 79 | point_in_channels, 80 | out_channels=point_len, 81 | use_dropout=use_dropout) 82 | else: 83 | print("Not implemented!!") 84 | 85 | # build affinity matrix module 86 | assert in_channels != 0 87 | self.w_link = affinity_module( 88 | in_channels, new_end=new_end, affinity_op=affinity_op) 89 | 90 | # build negative rejection module 91 | if score_arch in ['branch_cls', 'branch_reg']: 92 | self.w_det = 
nn.Sequential( 93 | nn.Conv1d(in_channels, in_channels, 1, 1), 94 | nn.BatchNorm1d(in_channels), 95 | nn.ReLU(inplace=True), 96 | nn.Conv1d(in_channels, in_channels // 2, 1, 1), 97 | nn.BatchNorm1d(in_channels // 2), 98 | nn.ReLU(inplace=True), 99 | nn.Conv1d(in_channels // 2, 1, 1, 1), 100 | ) 101 | else: 102 | print("Not implement yet") 103 | 104 | self.softmax_mode = softmax_mode 105 | 106 | def associate(self, objs, dets): 107 | link_mat, new_score, end_score = self.w_link(objs, dets) 108 | 109 | if self.softmax_mode == 'single': 110 | link_score = F.softmax(link_mat, dim=-1) 111 | elif self.softmax_mode == 'dual': 112 | link_score_prev = F.softmax(link_mat, dim=-1) 113 | link_score_next = F.softmax(link_mat, dim=-2) 114 | link_score = link_score_prev.mul(link_score_next) 115 | elif self.softmax_mode == 'dual_add': 116 | link_score_prev = F.softmax(link_mat, dim=-1) 117 | link_score_next = F.softmax(link_mat, dim=-2) 118 | link_score = (link_score_prev + link_score_next) / 2 119 | elif self.softmax_mode == 'dual_max': 120 | link_score_prev = F.softmax(link_mat, dim=-1) 121 | link_score_next = F.softmax(link_mat, dim=-2) 122 | link_score = torch.max(link_score_prev, link_score_next) 123 | else: 124 | link_score = link_mat 125 | 126 | return link_score, new_score, end_score 127 | 128 | def feature(self, dets, det_info): 129 | feats = [] 130 | 131 | if self.appearance is not None: 132 | appear = self.appearance(dets) 133 | feats.append(appear) 134 | 135 | trans = None 136 | if self.point_net is not None: 137 | points, trans = self.point_net( 138 | det_info['points'].transpose(-1, -2), 139 | det_info['points_split'].long().squeeze(0)) 140 | feats.append(points) 141 | 142 | feats = torch.cat(feats, dim=-1).t().unsqueeze(0) # LxD->1xDxL 143 | if self.fusion_module is not None: 144 | feats = self.fusion_module(feats) 145 | return feats, trans 146 | 147 | return feats, trans 148 | 149 | def determine_det(self, dets, feats): 150 | det_scores = self.w_det(feats).squeeze(1) # Bx1xL -> BxL 151 | 152 | if not self.training: 153 | # add mask 154 | if 'cls' in self.score_arch: 155 | det_scores = det_scores.sigmoid() 156 | 157 | 158 | # print(det_scores[:, -1].size()) 159 | # mask = det_scores[:, -1].lt(self.neg_threshold) 160 | # det_scores[:, -1] -= mask.float() 161 | mask = det_scores.lt(self.neg_threshold) 162 | det_scores -= mask.float() 163 | return det_scores 164 | 165 | def forward(self, dets, det_info, dets_split): 166 | feats, trans = self.feature(dets, det_info) 167 | det_scores = self.determine_det(dets, feats) 168 | 169 | start = 0 170 | link_scores = [] 171 | new_scores = [] 172 | end_scores = [] 173 | for i in range(len(dets_split) - 1): 174 | prev_end = start + dets_split[i].item() 175 | end = prev_end + dets_split[i + 1].item() 176 | link_score, new_score, end_score = self.associate( 177 | feats[:, :, start:prev_end], feats[:, :, prev_end:end]) 178 | link_scores.append(link_score.squeeze(1)) 179 | new_scores.append(new_score) 180 | end_scores.append(end_score) 181 | start = prev_end 182 | 183 | if not self.training: 184 | fake_new = det_scores.new_zeros( 185 | (det_scores.size(0), link_scores[0].size(-2))) 186 | fake_end = det_scores.new_zeros( 187 | (det_scores.size(0), link_scores[-1].size(-1))) 188 | new_scores = torch.cat([fake_new] + new_scores, dim=1) 189 | end_scores = torch.cat(end_scores + [fake_end], dim=1) 190 | else: 191 | new_scores = torch.cat(new_scores, dim=1) 192 | end_scores = torch.cat(end_scores, dim=1) 193 | return det_scores, link_scores, new_scores, end_scores, 
trans 194 | -------------------------------------------------------------------------------- /modules/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = [ 7 | 'VGG', 8 | 'vgg11', 9 | 'vgg11_bn', 10 | 'vgg13', 11 | 'vgg13_bn', 12 | 'vgg16', 13 | 'vgg16_bn', 14 | 'vgg19_bn', 15 | 'vgg19', 16 | ] 17 | 18 | model_urls = { 19 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 20 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 21 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 22 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 23 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 24 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 25 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 26 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 27 | } 28 | 29 | 30 | class VGG(nn.Module): 31 | 32 | def __init__(self, features, num_classes=1000): 33 | super(VGG, self).__init__() 34 | self.features = features 35 | self.classifier = nn.Sequential( 36 | nn.Linear(512 * 7 * 7, 4096), 37 | nn.ReLU(True), 38 | nn.Dropout(), 39 | nn.Linear(4096, 4096), 40 | nn.ReLU(True), 41 | nn.Dropout(), 42 | nn.Linear(4096, num_classes), 43 | ) 44 | self._initialize_weights() 45 | 46 | def forward(self, x): 47 | x = self.features(x) 48 | x = x.view(x.size(0), -1) 49 | x = self.classifier(x) 50 | return x 51 | 52 | def _initialize_weights(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 57 | if m.bias is not None: 58 | m.bias.data.zero_() 59 | elif isinstance(m, nn.BatchNorm2d): 60 | m.weight.data.fill_(1) 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.Linear): 63 | m.weight.data.normal_(0, 0.01) 64 | m.bias.data.zero_() 65 | 66 | 67 | def make_layers(cfg, batch_norm=False): 68 | layers = [] 69 | in_channels = 3 70 | for v in cfg: 71 | if v == 'M': 72 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 73 | else: 74 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 75 | if batch_norm: 76 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 77 | else: 78 | layers += [conv2d, nn.ReLU(inplace=True)] 79 | in_channels = v 80 | return nn.Sequential(*layers) 81 | 82 | 83 | cfg = { 84 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 85 | 'B': 86 | [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 87 | 'D': [ 88 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 89 | 512, 512, 512, 'M' 90 | ], 91 | 'D_128': [ 92 | 16, 16, 'M', 32, 32, 'M', 64, 64, 64, 'M', 128, 128, 128, 'M', 128, 93 | 128, 128, 'M' 94 | ], 95 | 'D_256': [ 96 | 32, 32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 256, 97 | 256, 256, 'M' 98 | ], 99 | 'E': [ 100 | 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 101 | 512, 'M', 512, 512, 512, 512, 'M' 102 | ], 103 | } 104 | 105 | 106 | def vgg11(pretrained=False, **kwargs): 107 | """VGG 11-layer model (configuration "A") 108 | 109 | Args: 110 | pretrained (bool): If True, returns a model pre-trained on ImageNet 111 | """ 112 | model = VGG(make_layers(cfg['A']), **kwargs) 113 | if pretrained: 114 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 115 | return model 116 | 117 | 118 | def vgg11_bn(pretrained=False, **kwargs): 119 | """VGG 11-layer model (configuration "A") with batch normalization 120 | 121 | Args: 122 | pretrained (bool): If True, returns a model pre-trained on ImageNet 123 | """ 124 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 125 | if pretrained: 126 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 127 | return model 128 | 129 | 130 | def vgg13(pretrained=False, **kwargs): 131 | """VGG 13-layer model (configuration "B") 132 | 133 | Args: 134 | pretrained (bool): If True, returns a model pre-trained on ImageNet 135 | """ 136 | model = VGG(make_layers(cfg['B']), **kwargs) 137 | if pretrained: 138 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 139 | return model 140 | 141 | 142 | def vgg13_bn(pretrained=False, **kwargs): 143 | """VGG 13-layer model (configuration "B") with batch normalization 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 149 | if pretrained: 150 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 151 | return model 152 | 153 | 154 | def vgg16(pretrained=False, **kwargs): 155 | """VGG 16-layer model (configuration "D") 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | model = VGG(make_layers(cfg['D']), **kwargs) 161 | if pretrained: 162 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 163 | return model 164 | 165 | 166 | def vgg16_bn(pretrained=False, **kwargs): 167 | """VGG 16-layer model (configuration "D") with batch normalization 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 
171 | """ 172 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 173 | if pretrained: 174 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 175 | return model 176 | 177 | 178 | def vgg16_bn_512(pretrained=False, **kwargs): 179 | return vgg16_bn(pretrained, **kwargs) 180 | 181 | 182 | def vgg19(pretrained=False, **kwargs): 183 | """VGG 19-layer model (configuration "E") 184 | 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | model = VGG(make_layers(cfg['E']), **kwargs) 189 | if pretrained: 190 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 191 | return model 192 | 193 | 194 | def vgg19_bn(pretrained=False, **kwargs): 195 | """VGG 19-layer model (configuration 'E') with batch normalization 196 | 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | """ 200 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 201 | if pretrained: 202 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 203 | return model 204 | 205 | 206 | def vgg16_bn_256(pretrained=False, **kwargs): 207 | """VGG 16-layer model (configuration "D") with batch normalization 208 | 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | """ 212 | model = VGG(make_layers(cfg['D_256'], batch_norm=True), **kwargs) 213 | if pretrained: 214 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 215 | return model 216 | 217 | 218 | def vgg16_bn_128(pretrained=False, **kwargs): 219 | """VGG 16-layer model (configuration "D") with batch normalization 220 | 221 | Args: 222 | pretrained (bool): If True, returns a model pre-trained on ImageNet 223 | """ 224 | model = VGG(make_layers(cfg['D_128'], batch_norm=True), **kwargs) 225 | if pretrained: 226 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 227 | return model 228 | -------------------------------------------------------------------------------- /point_cloud/geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numba 3 | 4 | 5 | @numba.njit 6 | def is_line_segment_intersection_jit(lines1, lines2): 7 | """check if line segments1 and line segments2 have cross point 8 | 9 | Args: 10 | lines1 (float, [N, 2, 2]): [description] 11 | lines2 (float, [M, 2, 2]): [description] 12 | 13 | Returns: 14 | [type]: [description] 15 | """ 16 | 17 | # Return true if line segments AB and CD intersect 18 | N = lines1.shape[0] 19 | M = lines2.shape[0] 20 | ret = np.zeros((N, M), dtype=np.bool_) 21 | for i in range(N): 22 | for j in range(M): 23 | A = lines1[i, 0] 24 | B = lines1[i, 1] 25 | C = lines2[j, 0] 26 | D = lines2[j, 1] 27 | acd = (D[1] - A[1]) * (C[0] - A[0]) > (C[1] - A[1]) * (D[0] - A[0]) 28 | bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0]) 29 | if acd != bcd: 30 | abc = (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0]) 31 | abd = (D[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (D[0] - A[0]) 32 | if abc != abd: 33 | ret[i, j] = True 34 | return ret 35 | 36 | 37 | @numba.njit 38 | def line_segment_intersection(line1, line2, intersection): 39 | A = line1[0] 40 | B = line1[1] 41 | C = line2[0] 42 | D = line2[1] 43 | BA0 = B[0] - A[0] 44 | BA1 = B[1] - A[1] 45 | DA0 = D[0] - A[0] 46 | CA0 = C[0] - A[0] 47 | DA1 = D[1] - A[1] 48 | CA1 = C[1] - A[1] 49 | acd = DA1 * CA0 > CA1 * DA0 50 | bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0]) 51 | if acd != 
bcd: 52 | abc = CA1 * BA0 > BA1 * CA0 53 | abd = DA1 * BA0 > BA1 * DA0 54 | if abc != abd: 55 | DC0 = D[0] - C[0] 56 | DC1 = D[1] - C[1] 57 | ABBA = A[0] * B[1] - B[0] * A[1] 58 | CDDC = C[0] * D[1] - D[0] * C[1] 59 | DH = BA1 * DC0 - BA0 * DC1 60 | intersection[0] = (ABBA * DC0 - BA0 * CDDC) / DH 61 | intersection[1] = (ABBA * DC1 - BA1 * CDDC) / DH 62 | return True 63 | return False 64 | 65 | 66 | def _ccw(A, B, C): 67 | return (C[..., 1] - A[..., 1]) * (B[..., 0] - A[..., 0]) > ( 68 | B[..., 1] - A[..., 1]) * (C[..., 0] - A[..., 0]) 69 | 70 | 71 | def is_line_segment_cross(lines1, lines2): 72 | # 10x slower than jit version with 1000-1000 random lines input. 73 | # lines1, [N, 2, 2] 74 | # lines2, [M, 2, 2] 75 | A = lines1[:, 0, :][:, np.newaxis, :] 76 | B = lines1[:, 1, :][:, np.newaxis, :] 77 | C = lines2[:, 0, :][np.newaxis, :, :] 78 | D = lines2[:, 1, :][np.newaxis, :, :] 79 | return np.logical_and( 80 | _ccw(A, C, D) != _ccw(B, C, D), 81 | _ccw(A, B, C) != _ccw(A, B, D)) 82 | 83 | 84 | def surface_equ_3d(polygon_surfaces): 85 | # return [a, b, c], d in ax+by+cz+d=0 86 | # polygon_surfaces: [num_polygon, num_surfaces, num_points_of_polygon, 3] 87 | surface_vec = polygon_surfaces[:, :, :2, :] - polygon_surfaces[:, :, 1:3, :] 88 | # normal_vec: [..., 3] 89 | normal_vec = np.cross(surface_vec[:, :, 0, :], surface_vec[:, :, 1, :]) 90 | # print(normal_vec.shape, points[..., 0, :].shape) 91 | # d = -np.inner(normal_vec, points[..., 0, :]) 92 | d = np.einsum('aij, aij->ai', normal_vec, polygon_surfaces[:, :, 0, :]) 93 | return normal_vec, -d 94 | 95 | 96 | @numba.njit 97 | def _points_in_convex_polygon_3d_jit(points, polygon_surfaces, normal_vec, d, num_surfaces): 98 | max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3] 99 | num_points = points.shape[0] 100 | num_polygons = polygon_surfaces.shape[0] 101 | ret = np.ones((num_points, num_polygons), dtype=np.bool_) 102 | sign = 0.0 103 | for i in range(num_points): 104 | for j in range(num_polygons): 105 | for k in range(max_num_surfaces): 106 | if k > num_surfaces[j]: 107 | break 108 | sign = points[i, 0] * normal_vec[j, k, 0] \ 109 | + points[i, 1] * normal_vec[j, k, 1] \ 110 | + points[i, 2] * normal_vec[j, k, 2] + d[j, k] 111 | if sign >= 0: 112 | ret[i, j] = False 113 | break 114 | return ret 115 | 116 | 117 | def points_in_convex_polygon_3d_jit(points, 118 | polygon_surfaces, 119 | num_surfaces=None): 120 | """check points is in 3d convex polygons. 121 | Args: 122 | points: [num_points, 3] array. 123 | polygon_surfaces: [num_polygon, max_num_surfaces, 124 | max_num_points_of_surface, 3] 125 | array. all surfaces' normal vector must direct to internal. 126 | max_num_points_of_surface must at least 3. 127 | num_surfaces: [num_polygon] array. indicate how many surfaces 128 | a polygon contain 129 | Returns: 130 | [num_points, num_polygon] bool array. 
131 | """ 132 | max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3] 133 | num_points = points.shape[0] 134 | num_polygons = polygon_surfaces.shape[0] 135 | if num_surfaces is None: 136 | num_surfaces = np.full((num_polygons,), 9999999, dtype=np.int64) 137 | normal_vec, d = surface_equ_3d(polygon_surfaces[:, :, :3, :]) 138 | # normal_vec: [num_polygon, max_num_surfaces, 3] 139 | # d: [num_polygon, max_num_surfaces] 140 | return _points_in_convex_polygon_3d_jit(points, polygon_surfaces, normal_vec, d, num_surfaces) 141 | 142 | 143 | @numba.jit 144 | def points_in_convex_polygon_jit(points, polygon, clockwise=True): 145 | """check points is in 2d convex polygons. True when point in polygon 146 | Args: 147 | points: [num_points, 2] array. 148 | polygon: [num_polygon, num_points_of_polygon, 2] array. 149 | clockwise: bool. indicate polygon is clockwise. 150 | Returns: 151 | [num_points, num_polygon] bool array. 152 | """ 153 | # first convert polygon to directed lines 154 | num_points_of_polygon = polygon.shape[1] 155 | num_points = points.shape[0] 156 | num_polygons = polygon.shape[0] 157 | # if clockwise: 158 | # vec1 = polygon - polygon[:, [num_points_of_polygon - 1] + 159 | # list(range(num_points_of_polygon - 1)), :] 160 | # else: 161 | # vec1 = polygon[:, [num_points_of_polygon - 1] + 162 | # list(range(num_points_of_polygon - 1)), :] - polygon 163 | # vec1: [num_polygon, num_points_of_polygon, 2] 164 | vec1 = np.zeros((2), dtype=polygon.dtype) 165 | ret = np.zeros((num_points, num_polygons), dtype=np.bool_) 166 | success = True 167 | cross = 0.0 168 | for i in range(num_points): 169 | for j in range(num_polygons): 170 | success = True 171 | for k in range(num_points_of_polygon): 172 | if clockwise: 173 | vec1 = polygon[j, k] - polygon[j, k-1] 174 | else: 175 | vec1 = polygon[j, k-1] - polygon[j, k] 176 | cross = vec1[1] * (polygon[j, k, 0] - points[i, 0]) 177 | cross -= vec1[0] * (polygon[j, k, 1] - points[i, 1]) 178 | if cross >= 0: 179 | success = False 180 | break 181 | ret[i, j] = success 182 | return ret 183 | 184 | 185 | def points_in_convex_polygon(points, polygon, clockwise=True): 186 | """check points is in convex polygons. may run 2x faster when write in 187 | cython(don't need to calculate all cross-product between edge and point) 188 | Args: 189 | points: [num_points, 2] array. 190 | polygon: [num_polygon, num_points_of_polygon, 2] array. 191 | clockwise: bool. indicate polygon is clockwise. 192 | Returns: 193 | [num_points, num_polygon] bool array. 194 | """ 195 | # first convert polygon to directed lines 196 | num_lines = polygon.shape[1] 197 | polygon_next = polygon[:, [num_lines - 1] + list(range(num_lines - 1)), :] 198 | if clockwise: 199 | vec1 = (polygon - polygon_next)[np.newaxis, ...] 200 | else: 201 | vec1 = (polygon_next - polygon)[np.newaxis, ...] 202 | vec2 = polygon[np.newaxis, ...] 
- points[:, np.newaxis, np.newaxis, :] 203 | # [num_points, num_polygon, num_points_of_polygon, 2] 204 | cross = np.cross(vec1, vec2) 205 | return np.all(cross > 0, axis=2) 206 | 207 | -------------------------------------------------------------------------------- /point_cloud/point_cloud_ops.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numba 4 | import numpy as np 5 | 6 | 7 | @numba.jit(nopython=True) 8 | def _points_to_voxel_reverse_kernel(points, 9 | voxel_size, 10 | coors_range, 11 | num_points_per_voxel, 12 | coor_to_voxelidx, 13 | voxels, 14 | coors, 15 | max_points=35, 16 | max_voxels=20000): 17 | # put all computations to one loop. 18 | # we shouldn't create large array in main jit code, otherwise 19 | # reduce performance 20 | N = points.shape[0] 21 | # ndim = points.shape[1] - 1 22 | ndim = 3 23 | ndim_minus_1 = ndim - 1 24 | grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size 25 | # np.round(grid_size) 26 | # grid_size = np.round(grid_size).astype(np.int64)(np.int32) 27 | grid_size = np.round(grid_size, 0, grid_size).astype(np.int32) 28 | coor = np.zeros(shape=(3, ), dtype=np.int32) 29 | voxel_num = 0 30 | failed = False 31 | for i in range(N): 32 | failed = False 33 | for j in range(ndim): 34 | c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j]) 35 | if c < 0 or c >= grid_size[j]: 36 | failed = True 37 | break 38 | coor[ndim_minus_1 - j] = c 39 | if failed: 40 | continue 41 | voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]] 42 | if voxelidx == -1: 43 | voxelidx = voxel_num 44 | if voxel_num >= max_voxels: 45 | break 46 | voxel_num += 1 47 | coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx 48 | coors[voxelidx] = coor 49 | num = num_points_per_voxel[voxelidx] 50 | if num < max_points: 51 | voxels[voxelidx, num] = points[i] 52 | num_points_per_voxel[voxelidx] += 1 53 | return voxel_num 54 | 55 | @numba.jit(nopython=True) 56 | def _points_to_voxel_kernel(points, 57 | voxel_size, 58 | coors_range, 59 | num_points_per_voxel, 60 | coor_to_voxelidx, 61 | voxels, 62 | coors, 63 | max_points=35, 64 | max_voxels=20000): 65 | # need mutex if write in cuda, but numba.cuda don't support mutex. 66 | # in addition, pytorch don't support cuda in dataloader(tensorflow support this). 67 | # put all computations to one loop. 
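# Note: unlike _points_to_voxel_reverse_kernel above, this kernel keeps the
# voxel coordinate in xyz order (coor[j] = c) rather than flipping it to zyx,
# so it is the one dispatched when reverse_index=False in points_to_voxel below.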
68 | # we shouldn't create large array in main jit code, otherwise 69 | # decrease performance 70 | N = points.shape[0] 71 | # ndim = points.shape[1] - 1 72 | ndim = 3 73 | grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size 74 | # grid_size = np.round(grid_size).astype(np.int64)(np.int32) 75 | grid_size = np.round(grid_size, 0, grid_size).astype(np.int32) 76 | 77 | lower_bound = coors_range[:3] 78 | upper_bound = coors_range[3:] 79 | coor = np.zeros(shape=(3, ), dtype=np.int32) 80 | voxel_num = 0 81 | failed = False 82 | for i in range(N): 83 | failed = False 84 | for j in range(ndim): 85 | c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j]) 86 | if c < 0 or c >= grid_size[j]: 87 | failed = True 88 | break 89 | coor[j] = c 90 | if failed: 91 | continue 92 | voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]] 93 | if voxelidx == -1: 94 | voxelidx = voxel_num 95 | if voxel_num >= max_voxels: 96 | break 97 | voxel_num += 1 98 | coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx 99 | coors[voxelidx] = coor 100 | num = num_points_per_voxel[voxelidx] 101 | if num < max_points: 102 | voxels[voxelidx, num] = points[i] 103 | num_points_per_voxel[voxelidx] += 1 104 | return voxel_num 105 | 106 | 107 | def points_to_voxel(points, 108 | voxel_size, 109 | coors_range, 110 | max_points=35, 111 | reverse_index=True, 112 | max_voxels=20000): 113 | """convert kitti points(N, >=3) to voxels. This version calculate 114 | everything in one loop. now it takes only 4.2ms(complete point cloud) 115 | with jit and 3.2ghz cpu.(don't calculate other features) 116 | Note: this function in ubuntu seems faster than windows 10. 117 | 118 | Args: 119 | points: [N, ndim] float tensor. points[:, :3] contain xyz points and 120 | points[:, 3:] contain other information such as reflectivity. 121 | voxel_size: [3] list/tuple or array, float. xyz, indicate voxel size 122 | coors_range: [6] list/tuple or array, float. indicate voxel range. 123 | format: xyzxyz, minmax 124 | max_points: int. indicate maximum points contained in a voxel. 125 | reverse_index: boolean. indicate whether return reversed coordinates. 126 | if points has xyz format and reverse_index is True, output 127 | coordinates will be zyx format, but points in features always 128 | xyz format. 129 | max_voxels: int. indicate maximum voxels this function create. 130 | for second, 20000 is a good choice. you should shuffle points 131 | before call this function because max_voxels may drop some points. 132 | 133 | Returns: 134 | voxels: [M, max_points, ndim] float tensor. only contain points. 135 | coordinates: [M, 3] int32 tensor. 136 | num_points_per_voxel: [M] int32 tensor. 137 | """ 138 | if not isinstance(voxel_size, np.ndarray): 139 | voxel_size = np.array(voxel_size, dtype=points.dtype) 140 | if not isinstance(coors_range, np.ndarray): 141 | coors_range = np.array(coors_range, dtype=points.dtype) 142 | voxelmap_shape = (coors_range[3:] - coors_range[:3]) / voxel_size 143 | voxelmap_shape = tuple(np.round(voxelmap_shape).astype(np.int32).tolist()) 144 | if reverse_index: 145 | voxelmap_shape = voxelmap_shape[::-1] 146 | # don't create large array in jit(nopython=True) code. 
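# The output buffers below are therefore allocated here in plain Python and
# handed to the jit kernels, which only fill them in place.
# Minimal usage sketch (the voxel_size / coors_range values are illustrative only):
#   voxels, coords, num_per_voxel = points_to_voxel(
#       points, voxel_size=[0.16, 0.16, 4.0],
#       coors_range=[0.0, -40.0, -3.0, 70.4, 40.0, 1.0])
#   # voxels: [M, max_points, ndim]; coords: [M, 3], in zyx order when reverse_index=True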
147 | num_points_per_voxel = np.zeros(shape=(max_voxels, ), dtype=np.int32) 148 | coor_to_voxelidx = -np.ones(shape=voxelmap_shape, dtype=np.int32) 149 | voxels = np.zeros( 150 | shape=(max_voxels, max_points, points.shape[-1]), dtype=points.dtype) 151 | coors = np.zeros(shape=(max_voxels, 3), dtype=np.int32) 152 | if reverse_index: 153 | voxel_num = _points_to_voxel_reverse_kernel( 154 | points, voxel_size, coors_range, num_points_per_voxel, 155 | coor_to_voxelidx, voxels, coors, max_points, max_voxels) 156 | 157 | else: 158 | voxel_num = _points_to_voxel_kernel( 159 | points, voxel_size, coors_range, num_points_per_voxel, 160 | coor_to_voxelidx, voxels, coors, max_points, max_voxels) 161 | 162 | coors = coors[:voxel_num] 163 | voxels = voxels[:voxel_num] 164 | num_points_per_voxel = num_points_per_voxel[:voxel_num] 165 | # voxels[:, :, -3:] = voxels[:, :, :3] - \ 166 | # voxels[:, :, :3].sum(axis=1, keepdims=True)/num_points_per_voxel.reshape(-1, 1, 1) 167 | return voxels, coors, num_points_per_voxel 168 | 169 | 170 | @numba.jit(nopython=True) 171 | def bound_points_jit(points, upper_bound, lower_bound): 172 | # to use nopython=True, np.bool is not supported. so you need 173 | # convert result to np.bool after this function. 174 | N = points.shape[0] 175 | ndim = points.shape[1] 176 | keep_indices = np.zeros((N, ), dtype=np.int32) 177 | success = 0 178 | for i in range(N): 179 | success = 1 180 | for j in range(ndim): 181 | if points[i, j] < lower_bound[j] or points[i, j] >= upper_bound[j]: 182 | success = 0 183 | break 184 | keep_indices[i] = success 185 | return keep_indices 186 | -------------------------------------------------------------------------------- /point_cloud/preprocess.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | from .box_np_ops import (points_in_rbbox, box_camera_to_lidar, 6 | get_frustum_points, remove_outside_points) 7 | 8 | 9 | 10 | def merge_second_batch(batch_list, _unused=False): 11 | example_merged = defaultdict(list) 12 | for example in batch_list: 13 | for k, v in example.items(): 14 | example_merged[k].append(v) 15 | ret = {} 16 | example_merged.pop("num_voxels") 17 | for key, elems in example_merged.items(): 18 | if key in [ 19 | 'voxels', 'num_points', 'num_gt', 'gt_boxes', 'voxel_labels', 20 | 'match_indices' 21 | ]: 22 | ret[key] = np.concatenate(elems, axis=0) 23 | elif key == 'match_indices_num': 24 | ret[key] = np.concatenate(elems, axis=0) 25 | elif key == 'coordinates': 26 | coors = [] 27 | for i, coor in enumerate(elems): 28 | coor_pad = np.pad( 29 | coor, ((0, 0), (1, 0)), 30 | mode='constant', 31 | constant_values=i) 32 | coors.append(coor_pad) 33 | ret[key] = np.concatenate(coors, axis=0) 34 | else: 35 | ret[key] = np.stack(elems, axis=0) 36 | return ret 37 | 38 | 39 | def remove_points_outside_boxes(points, boxes): 40 | masks = points_in_rbbox(points, boxes) 41 | points = points[masks.any(-1)] 42 | return points 43 | 44 | 45 | def read_and_prep_points(info, root_path, point_path, dets, use_frustum=False, 46 | num_point_features=4, without_reflectivity=False, det_type='3D', shift_bbox=None): 47 | """read data from KITTI-format infos, then call prep function. 
48 | """ 49 | # read point cloud 50 | point_path_split = point_path.split('-') 51 | v_path = f'{root_path}/velodyne/{point_path_split[0]}/{point_path_split[1]}' 52 | # v_path = f'{root_path}/velodyne_reduced/{point_path}' 53 | 54 | points = np.fromfile( 55 | str(v_path), dtype=np.float32, 56 | count=-1).reshape([-1, num_point_features]) 57 | 58 | # Load Calibration 59 | rect = info['calib/R0_rect'].astype(np.float32) 60 | Trv2c = info['calib/Tr_velo_to_cam'].astype(np.float32) 61 | P2 = info['calib/P2'].astype(np.float32) 62 | 63 | # remove point cloud out side image, this might affect the performance 64 | points = remove_outside_points(points, rect, Trv2c, P2, info["img_shape"]) 65 | 66 | # remove the points that is outside the bboxes or frustum 67 | bbox_points = [] 68 | points_split = [0] 69 | if det_type == '3D' and not use_frustum: 70 | loc = dets["location"].copy() # This is in the camera coordinates 71 | dims = dets["dimensions"].copy() # This should be standard lhw(camera) format 72 | rots = dets["rotation_y"].copy() 73 | 74 | # print(gt_names, len(loc)) 75 | boxes = np.concatenate( 76 | [loc, dims, rots[..., np.newaxis]], axis=1).astype(np.float32) 77 | boxes = box_camera_to_lidar(boxes, rect, Trv2c) # change the boxes to velo coordinates 78 | for i in range(boxes.shape[0]): 79 | bbox_point = remove_points_outside_boxes(points, boxes[i:i+1]) 80 | if bbox_point.shape[0] == 0: 81 | bbox_point = np.zeros(shape=(1,4)) 82 | points_split.append(points_split[-1]+bbox_point.shape[0]) 83 | bbox_points.append(bbox_point) 84 | bbox_points = np.concatenate(bbox_points, axis=0) 85 | else: 86 | boxes = shift_bbox.copy() if shift_bbox is not None else dets['bbox'].copy() 87 | for i in range(boxes.shape[0]): 88 | bbox_point = get_frustum_points(points, boxes[i:i+1], rect, Trv2c, P2) 89 | if bbox_point.shape[0] == 0: 90 | bbox_point = np.zeros(shape=(1,4)) 91 | points_split.append(points_split[-1]+bbox_point.shape[0]) 92 | bbox_points.append(bbox_point) 93 | bbox_points = np.concatenate(bbox_points, axis=0) 94 | 95 | 96 | if without_reflectivity: 97 | used_point_axes = list(range(num_point_features)) 98 | used_point_axes.pop(3) 99 | bbox_points = bbox_points[:, used_point_axes] 100 | 101 | example = { 102 | 'points': bbox_points, 103 | 'points_split': points_split, 104 | } 105 | 106 | return example 107 | 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | motmetrics 2 | PyYAML 3 | easydict 4 | munkres 5 | ortools 6 | pyproj 7 | opencv-python 8 | -------------------------------------------------------------------------------- /solvers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import scipy.optimize as optimize 5 | import torch 6 | from ortools.linear_solver import pywraplp 7 | 8 | 9 | def ortools_solve(det_score, 10 | link_score, 11 | new_score, 12 | end_score, 13 | det_split, 14 | gt=None): 15 | solver = pywraplp.Solver('SolveAssignmentProblemMIP', 16 | pywraplp.Solver.CBC_MIXED_INTEGER_PROGRAMMING) 17 | y_det = {} 18 | y_new = {} 19 | y_end = {} 20 | for i in range(det_score.size(0)): 21 | y_det[i] = solver.BoolVar('y_det[%i]' % (i)) 22 | y_new[i] = solver.BoolVar('y_new[%i]' % (i)) 23 | y_end[i] = solver.BoolVar('y_end[%i]' % (i)) 24 | y_link = {} 25 | for i in range(len(link_score)): 26 | y_link[i] = {} 27 | for j in 
range(link_score[i][0].size(0)): 28 | y_link[i][j] = {} 29 | for k in range(link_score[i][0].size(1)): 30 | y_link[i][j][k] = solver.BoolVar(f'y_link[{i}, {j}, {k}]') 31 | w_link_y = [] 32 | for i in range(len(link_score)): 33 | for j in range(link_score[i][0].size(0)): 34 | for k in range(link_score[i][0].size(1)): 35 | w_link_y.append(y_link[i][j][k] * 36 | link_score[i][0][j][k].item()) 37 | w_det_y = [ 38 | y_det[i] * det_score[i].item() for i in range(det_score.size(0)) 39 | ] 40 | w_new_y = [ 41 | y_new[i] * new_score[i].item() for i in range(det_score.size(0)) 42 | ] 43 | w_end_y = [ 44 | y_end[i] * end_score[i].item() for i in range(det_score.size(0)) 45 | ] 46 | 47 | # Objective 48 | if gt is None: 49 | solver.Maximize(solver.Sum(w_det_y + w_new_y + w_end_y + w_link_y)) 50 | else: 51 | (gt_det, gt_new, gt_end, gt_link) = gt 52 | gt_eff_det = gt_det + gt_det.eq(0).float().mul(-1) 53 | gt_eff_new = gt_new + gt_new.eq(0).float().mul(-1) 54 | gt_eff_end = gt_end + gt_end.eq(0).float().mul(-1) 55 | gt_eff_link = [] 56 | for i in range(len(link_score)): 57 | gt_eff_link.append(gt_link[i] + gt_link[i].eq(0).float().mul(-1)) 58 | 59 | delta_det = [ 60 | gt_det[i].item() - y_det[i] * gt_eff_det[i].item() 61 | for i in range(det_score.size(0)) 62 | ] 63 | delta_new = [ 64 | gt_new[i].item() - y_new[i] * gt_eff_new[i].item() 65 | for i in range(det_score.size(0)) 66 | ] 67 | delta_end = [ 68 | gt_end[i].item() - y_end[i] * gt_eff_end[i].item() 69 | for i in range(det_score.size(0)) 70 | ] 71 | delta_link = [] 72 | 73 | for i in range(len(link_score)): 74 | for j in range(link_score[i][0].size(0)): 75 | for k in range(link_score[i][0].size(1)): 76 | delta_link.append(gt_link[i][0][j][k].item() - 77 | y_link[i][j][k] * 78 | gt_eff_link[i][0][j][k].item()) 79 | solver.Maximize( 80 | solver.Sum(w_det_y + w_new_y + w_end_y + w_link_y + delta_det + 81 | delta_new + delta_end + delta_link)) 82 | 83 | # Constraints 84 | # Set constraint for fomular 1 85 | det_start_idx = 0 86 | for i in range(len(det_split) - 1): 87 | det_curr_num = det_split[i].item() 88 | for j in range(det_curr_num): 89 | det_idx = det_start_idx + j 90 | successor_link = [ 91 | y_link[i][j][k] for k in range(len(y_link[i][j])) 92 | ] 93 | # end + successor = det 94 | solver.Add( 95 | solver.Sum([y_end[det_idx], (-1) * y_det[det_idx]] + 96 | successor_link) == 0) 97 | if i == 0: 98 | solver.Add( 99 | solver.Sum([y_new[det_idx], (-1) * y_det[det_idx]]) == 0) 100 | det_start_idx += det_curr_num 101 | det_next_num = det_split[i + 1].item() 102 | for j in range(det_next_num): 103 | det_idx = det_start_idx + j 104 | # new + prec = det 105 | precedding_link = [y_link[i][k][j] for k in range(len(y_link[i]))] 106 | solver.Add( 107 | solver.Sum([y_new[det_idx], (-1) * y_det[det_idx]] + 108 | precedding_link) == 0) 109 | if i == len(det_split) - 2: 110 | solver.Add( 111 | solver.Sum([y_end[det_idx], (-1) * y_det[det_idx]]) == 0) 112 | 113 | sol = solver.Solve() # noqa 114 | 115 | det_start_idx = 0 116 | assign_det = det_score.new_zeros(det_score.size()) 117 | assign_new = det_score.new_zeros(det_score.size()) 118 | assign_end = det_score.new_zeros(det_score.size()) 119 | assign_link = [] 120 | for i in range(len(det_split)): 121 | det_curr_num = det_split[i].item() 122 | if i != len(det_split) - 1: 123 | link_matrix = det_score.new_zeros(link_score[i].size()) 124 | for j in range(det_curr_num): 125 | det_idx = det_start_idx + j 126 | assign_new[det_idx] = y_new[det_idx].solution_value() 127 | assign_end[det_idx] = 
y_end[det_idx].solution_value() 128 | assign_det[det_idx] = y_det[det_idx].solution_value() 129 | if i != len(det_split) - 1: 130 | for k in range(len(y_link[i][j])): 131 | link_matrix[0][j][k] = y_link[i][j][k].solution_value() 132 | 133 | # end + successor = det 134 | det_start_idx += det_curr_num 135 | if i != len(det_split) - 1: 136 | assign_link.append(link_matrix) 137 | 138 | return assign_det, assign_link, assign_new, assign_end 139 | 140 | 141 | class scipy_solver(object): 142 | 143 | def calculate_det_len(self, det_split): 144 | w_det_len = det_split[-1].item() 145 | w_link_len = 0 146 | for i in range(len(det_split) - 1): 147 | w_det_len += det_split[i].item() 148 | w_link_len += det_split[i].item() * det_split[i + 1].item() 149 | 150 | total_len = w_det_len * 3 + w_link_len 151 | return total_len, w_det_len 152 | 153 | def buildLP(self, det_score, link_score, new_score, end_score, det_split): 154 | # LP constriants initialize 155 | 156 | total_len, w_det_len = self.calculate_det_len(det_split) 157 | A_eq = torch.zeros(w_det_len * 2, total_len) 158 | b_eq = torch.zeros(w_det_len * 2) 159 | bounds = [(0, 1)] * total_len 160 | # cost initialize 161 | cost = det_score.new_empty(total_len) 162 | cost[:w_det_len] = det_score.squeeze(-1).clone() 163 | cost[w_det_len:w_det_len * 2] = new_score.clone() 164 | cost[w_det_len * 2:w_det_len * 3] = end_score.clone() 165 | 166 | # inequality to bounds new and end results, not from paper 167 | # y_new + y_end <= 1 168 | b_ub = torch.ones(w_det_len) 169 | A_ub = torch.zeros(w_det_len, total_len) 170 | 171 | # LP constriants calculate 172 | link_start_idx = w_det_len * 3 173 | det_start_idx = 0 174 | # A_eq: [w_det, w_new, w_end, link_1, link_2, link_3...] 175 | for i in range(len(det_split)): 176 | det_curr_num = det_split[i].item( 177 | ) # current frame i has det_i detects 178 | for k in range(det_curr_num): 179 | curr_det_idx = det_start_idx + k 180 | A_eq[curr_det_idx, curr_det_idx] = -1 # indicate current w_det 181 | A_eq[curr_det_idx, 182 | w_det_len + curr_det_idx] = 1 # indicate current w_new 183 | A_eq[w_det_len + curr_det_idx, 184 | curr_det_idx] = -1 # indicate current w_det 185 | A_eq[w_det_len + curr_det_idx, w_det_len * 2 + 186 | curr_det_idx] = 1 # indicate current w_end 187 | A_ub[curr_det_idx, 188 | w_det_len + curr_det_idx] = 1 # indicate current w_new 189 | A_ub[curr_det_idx, w_det_len * 2 + 190 | curr_det_idx] = 1 # indicate current w_end 191 | # calculate link to next frame 192 | if i < len(det_split) - 1: 193 | det_next_num = det_split[ 194 | i + 1] # next frame j has det_j detects 195 | curr_row_idx = link_start_idx + k * det_next_num 196 | A_eq[w_det_len + curr_det_idx, curr_row_idx:curr_row_idx + 197 | det_next_num] = 1 # sum(y_i) 198 | 199 | # calculate cost 200 | cost[curr_row_idx:curr_row_idx + 201 | det_next_num] = link_score[i][0, k].clone() 202 | 203 | # calculate link to prev frame 204 | if i > 0: 205 | det_prev_num = det_split[i - 1] 206 | prev_row_idx = link_start_idx - det_curr_num * det_prev_num 207 | A_eq[curr_det_idx, 208 | prev_row_idx + k:link_start_idx:det_curr_num] = 1 209 | 210 | link_start_idx += det_curr_num * det_next_num 211 | det_start_idx += det_curr_num 212 | 213 | return cost, A_ub, b_ub, A_eq, b_eq, bounds 214 | 215 | def solve(self, det_score, link_score, new_score, end_score, det_split): 216 | cost, A_ub, b_ub, A_eq, b_eq, bounds = self.buildLP( 217 | det_score, link_score, new_score, end_score, det_split) 218 | results = optimize.linprog( 219 | c=-cost.detach().cpu().numpy(), 220 | 
A_eq=A_eq.cpu().numpy(), 221 | b_eq=b_eq.cpu().numpy(), 222 | bounds=bounds, 223 | method='interior-point', 224 | options={ 225 | 'lstsq': False, 226 | 'presolve': True, 227 | '_sparse_presolve': True, 228 | 'sparse': True 229 | }) 230 | 231 | y = det_score.new_tensor(np.around(results.x)) 232 | return y, cost 233 | 234 | def generate_gt(self, cost, det_id, det_cls, det_split): 235 | total_len, w_det_len = self.calculate_det_len(det_split) 236 | gt_y = cost.new_zeros(total_len) 237 | link_start_idx = w_det_len * 3 238 | det_start_idx = 0 239 | for i in range(len(det_split)): 240 | det_curr_num = det_split[i] # current frame i has det_i detects 241 | # Assign the score, according to eq1 242 | for j in range(det_curr_num): 243 | curr_det_idx = det_start_idx + j 244 | # g_det 245 | if det_cls[i][j] == 0: 246 | # gt_y[curr_det_idx] = 0 # if negtive 247 | continue 248 | elif det_cls[i][j] == 1: 249 | gt_y[curr_det_idx] = 1 # positive 250 | 251 | # g_link 252 | if i == len(det_split) - 1: 253 | # end det at last frame 254 | gt_y[w_det_len * 2 + curr_det_idx] = 1 255 | else: 256 | matched = False 257 | det_next_num = det_split[i + 1] 258 | curr_row_idx = link_start_idx + j * det_next_num 259 | for k in range(det_next_num): 260 | if det_id[i][j] == det_id[i + 1][k]: 261 | gt_y[curr_row_idx + k] = 1 262 | matched = True 263 | break 264 | if not matched: 265 | gt_y[w_det_len * 2 + curr_det_idx] = 1 266 | 267 | if i == 0: 268 | # new det at first frame 269 | gt_y[w_det_len + curr_det_idx] = 1 270 | else: 271 | # look prev 272 | matched = False 273 | det_prev_num = det_split[i - 1] 274 | for k in range(det_prev_num): 275 | if det_id[i][j] == det_id[i - 1][k]: 276 | matched = True 277 | break 278 | if not matched: 279 | gt_y[w_det_len + curr_det_idx] = 1 280 | 281 | link_start_idx += det_curr_num * det_next_num 282 | det_start_idx += det_curr_num 283 | 284 | return gt_y 285 | 286 | def assign_det_id(self, y, det_split, dets): 287 | total_len, w_det_len = self.calculate_det_len(det_split) 288 | link_start_idx = w_det_len * 3 289 | det_start_idx = 0 290 | det_ids = [] 291 | already_used_id = [] 292 | fake_ids = [] 293 | dets_out = [] 294 | for i in range(len(det_split)): 295 | frame_id = [] 296 | det_curr_num = det_split[i] 297 | fake_id = [] 298 | det_out = [] 299 | for j in range(det_curr_num): 300 | curr_det_idx = det_start_idx + j 301 | # check w_det 302 | if y[curr_det_idx] != 1: 303 | fake_id.append(-1) 304 | continue 305 | else: 306 | det_out.append(dets[i][:, j]) 307 | 308 | # w_det=1, check whether a new det 309 | if i == 0: 310 | det_prev_num = 0 311 | if len(already_used_id) == 0: 312 | frame_id.append(0) 313 | fake_id.append(0) 314 | already_used_id.append(0) 315 | else: 316 | new_id = already_used_id[-1] + 1 317 | frame_id.append(new_id) 318 | fake_id.append(new_id) 319 | already_used_id.append(new_id) 320 | continue 321 | elif y[w_det_len + curr_det_idx] == 1: 322 | new_id = already_used_id[-1] + 1 323 | frame_id.append(new_id) 324 | fake_id.append(new_id) 325 | already_used_id.append(new_id) 326 | det_prev_num = det_split[i - 1] 327 | else: 328 | # look prev 329 | det_prev_num = det_split[i - 1] 330 | for k in range(det_prev_num): 331 | if y[link_start_idx + k * det_curr_num + j] == 1: 332 | prev_id = fake_ids[-1][k] 333 | frame_id.append(prev_id) 334 | fake_id.append(prev_id) 335 | break 336 | 337 | assert len(fake_id) == det_curr_num 338 | assert len(det_out) != 0 339 | 340 | fake_ids.append(fake_id) 341 | det_ids.append(frame_id) 342 | dets_out.append(torch.cat(det_out, dim=0)) 343 | 
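# The flat LP solution vector is laid out as [y_det | y_new | y_end | y_link
# blocks, one per frame pair] (see the layout comment in buildLP), so after
# finishing frame i we advance past the y_link block between frame i-1 and
# frame i, which holds det_prev_num * det_curr_num entries.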
link_start_idx += det_curr_num * det_prev_num 344 | det_start_idx += det_curr_num 345 | 346 | return det_ids, dets_out 347 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pprint 5 | import time 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import yaml 11 | from easydict import EasyDict 12 | from torch.utils.data import DataLoader 13 | from tracking_model import TrackingModule 14 | from utils.build_util import build_augmentation, build_dataset, build_model 15 | from utils.data_util import write_kitti_result 16 | from utils.train_util import AverageMeter, create_logger, load_state 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch mmMOT Testing') 19 | parser.add_argument('--config', default='cfgs/config_res50.yaml') 20 | parser.add_argument('--load-path', default='', type=str) 21 | parser.add_argument('--result-path', default='', type=str) 22 | parser.add_argument('--recover', action='store_true') 23 | parser.add_argument('-e', '--evaluate', action='store_true') 24 | parser.add_argument('--result_sha', default='last') 25 | 26 | 27 | def main(): 28 | global args, config, best_mota 29 | args = parser.parse_args() 30 | 31 | with open(args.config) as f: 32 | config = yaml.load(f, Loader=yaml.FullLoader) 33 | 34 | config = EasyDict(config['common']) 35 | config.save_path = os.path.dirname(args.config) 36 | 37 | # create model 38 | model = build_model(config) 39 | model.cuda() 40 | 41 | # optionally resume from a checkpoint 42 | last_iter = -1 43 | best_mota = 0 44 | if args.load_path: 45 | if args.recover: 46 | best_mota, last_iter = load_state( 47 | args.load_path, model, optimizer=None) 48 | else: 49 | load_state(args.load_path, model) 50 | 51 | cudnn.benchmark = True 52 | 53 | # Data loading code 54 | train_transform, valid_transform = build_augmentation(config.augmentation) 55 | # train_val 56 | test_dataset = build_dataset( 57 | config, 58 | set_source='test', 59 | evaluate=True, 60 | valid_transform=valid_transform) 61 | 62 | logger = create_logger('global_logger', config.save_path + '/eval_log.txt') 63 | logger.info('args: {}'.format(pprint.pformat(args))) 64 | logger.info('config: {}'.format(pprint.pformat(config))) 65 | 66 | tracking_module = TrackingModule(model, None, None, config.det_type) 67 | 68 | logger.info('Evaluation on traing and validation set:') 69 | validate(test_dataset, tracking_module, args.result_sha, part='all') 70 | 71 | 72 | def validate(val_loader, 73 | tracking_module, 74 | step, 75 | part='train', 76 | fusion_list=None, 77 | fuse_prob=False): 78 | 79 | logger = logging.getLogger('global_logger') 80 | for i, (sequence) in enumerate(val_loader): 81 | logger.info('Test: [{}/{}]\tSequence ID: KITTI-{}'.format( 82 | i, len(val_loader), sequence.name)) 83 | seq_loader = DataLoader( 84 | sequence, 85 | batch_size=config.batch_size, 86 | shuffle=False, 87 | num_workers=config.workers, 88 | pin_memory=True) 89 | if len(seq_loader) == 0: 90 | tracking_module.eval() 91 | logger.info('Empty Sequence ID: KITTI-{}, skip'.format( 92 | sequence.name)) 93 | else: 94 | validate_seq(seq_loader, tracking_module) 95 | 96 | write_kitti_result( 97 | args.result_path, 98 | sequence.name, 99 | step, 100 | tracking_module.frames_id, 101 | tracking_module.frames_det, 102 | part=part) 103 | 104 | tracking_module.train() 105 | return 106 | 107 | 108 | 
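# The entry point above builds the test-split dataset and calls validate(),
# which wraps every KITTI sequence in its own DataLoader, delegates the
# per-frame tracking to validate_seq() below, and writes the accumulated
# ids/detections out per sequence via write_kitti_result().
# A typical invocation, using only the flags defined by the argument parser
# above (all paths are placeholders):
#   python -u test.py --config path/to/config.yaml \
#       --load-path path/to/checkpoint.pth \
#       --result-path path/to/results --result_sha test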
def validate_seq(val_loader, 109 | tracking_module, 110 | fusion_list=None, 111 | fuse_prob=False): 112 | batch_time = AverageMeter(0) 113 | 114 | # switch to evaluate mode 115 | tracking_module.eval() 116 | 117 | logger = logging.getLogger('global_logger') 118 | end = time.time() 119 | # Create an accumulator that will be updated during each frame 120 | 121 | with torch.no_grad(): 122 | for i, (input, det_info, dets, det_split) in enumerate(val_loader): 123 | input = input.cuda() 124 | if len(det_info) > 0: 125 | for k, v in det_info.items(): 126 | det_info[k] = det_info[k].cuda() if not isinstance( 127 | det_info[k], list) else det_info[k] 128 | 129 | # compute output 130 | aligned_ids, aligned_dets, frame_start = tracking_module.predict( 131 | input[0], det_info, dets, det_split) 132 | 133 | # measure elapsed time 134 | batch_time.update(time.time() - end) 135 | end = time.time() 136 | if i % config.print_freq == 0: 137 | logger.info( 138 | 'Test Frame: [{0}/{1}]\tTime' 139 | ' {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 140 | i, len(val_loader), batch_time=batch_time)) 141 | 142 | return 143 | 144 | 145 | if __name__ == '__main__': 146 | main() 147 | -------------------------------------------------------------------------------- /tracking_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from solvers import ortools_solve 4 | from utils.data_util import get_start_gt_anno 5 | 6 | 7 | class TrackingModule(object): 8 | 9 | def __init__(self, model, optimizer, criterion, det_type='3D'): 10 | self.model = model 11 | self.optimizer = optimizer 12 | self.criterion = criterion 13 | self.det_type = det_type 14 | self.used_id = [] 15 | self.last_id = 0 16 | self.frames_id = [] 17 | self.frames_det = [] 18 | self.track_feats = None 19 | if isinstance(model, list): 20 | self.test_mode = model[0].test_mode 21 | else: 22 | self.test_mode = model.test_mode 23 | 24 | def clear_mem(self): 25 | self.used_id = [] 26 | self.last_id = 0 27 | self.frames_id = [] 28 | self.frames_det = [] 29 | self.track_feats = None 30 | return 31 | 32 | def eval(self): 33 | if isinstance(self.model, list): 34 | for i in range(len(self.model)): 35 | self.model[i].eval() 36 | else: 37 | self.model.eval() 38 | self.clear_mem() 39 | return 40 | 41 | def train(self): 42 | if isinstance(self.model, list): 43 | for i in range(len(self.model)): 44 | self.model[i].train() 45 | else: 46 | self.model.train() 47 | self.clear_mem() 48 | return 49 | 50 | def step(self, det_img, det_info, det_id, det_cls, det_split): 51 | det_score, link_score, new_score, end_score, trans = self.model( 52 | det_img, det_info, det_split) 53 | # generate gt_y 54 | gt_det, gt_link, gt_new, gt_end = self.generate_gt( 55 | det_score[0], det_cls, det_id, det_split) 56 | 57 | # calculate loss 58 | loss = self.criterion(det_split, gt_det, gt_link, gt_new, gt_end, 59 | det_score, link_score, new_score, end_score, 60 | trans) 61 | 62 | self.optimizer.zero_grad() 63 | loss.backward() 64 | self.optimizer.step() 65 | 66 | return loss 67 | 68 | def predict(self, det_imgs, det_info, dets, det_split): 69 | det_score, link_score, new_score, end_score, _ = self.model( 70 | det_imgs, det_info, det_split) 71 | 72 | assign_det, assign_link, assign_new, assign_end = ortools_solve( 73 | det_score[self.test_mode], 74 | [link_score[0][self.test_mode:self.test_mode + 1]], 75 | new_score[self.test_mode], end_score[self.test_mode], det_split) 76 | 77 | assign_id, assign_bbox = 
self.assign_det_id(assign_det, assign_link, 78 | assign_new, assign_end, 79 | det_split, dets) 80 | aligned_ids, aligned_dets, frame_start = self.align_id( 81 | assign_id, assign_bbox) 82 | 83 | return aligned_ids, aligned_dets, frame_start 84 | 85 | def mem_assign_det_id(self, feats, assign_det, assign_link, assign_new, 86 | assign_end, det_split, dets): 87 | det_ids = [] 88 | v, idx = torch.max(assign_link[0][0], dim=0) 89 | for i in range(idx.size(0)): 90 | if v[i] == 1: 91 | track_id = idx[i].item() 92 | det_ids.append(track_id) 93 | self.track_feats[track_id] = feats[i:i + 1] 94 | else: 95 | new_id = self.last_id + 1 96 | det_ids.append(new_id) 97 | self.last_id += 1 98 | self.track_feats.append(feats[i:i + 1]) 99 | 100 | for k, v in dets[0].items(): 101 | dets[0][k] = v.squeeze(0) if k != 'frame_idx' else v 102 | dets[0]['id'] = torch.Tensor(det_ids).long() 103 | self.frames_id.append(det_ids) 104 | self.frames_det += dets 105 | assert len(self.track_feats) == (self.last_id + 1) 106 | 107 | return det_ids, dets 108 | 109 | def align_id(self, dets_ids, dets_out): 110 | frame_start = 0 111 | if len(self.used_id) == 0: 112 | # Start of a sequence 113 | self.used_id += dets_ids 114 | self.frames_id += dets_ids 115 | self.frames_det += dets_out 116 | max_id = 0 117 | for i in range(len(dets_ids)): 118 | if dets_out[i]['id'].size(0) == 0: 119 | continue 120 | max_id = np.maximum(np.max(dets_ids[i]), max_id) 121 | self.last_id = np.maximum(self.last_id, max_id) 122 | return dets_ids, dets_out, frame_start 123 | elif self.frames_det[-1]['frame_idx'] != dets_out[0]['frame_idx']: 124 | # in case the sequence is not continuous 125 | aligned_ids = [] 126 | aligned_dets = [] 127 | max_id = 0 128 | id_offset = self.last_id + 1 129 | for i in range(len(dets_ids)): 130 | if dets_out[i]['id'].size(0) == 0: 131 | aligned_ids.append([]) 132 | continue 133 | new_id = dets_ids[i] + id_offset 134 | max_id = np.maximum(np.max(new_id), max_id) 135 | aligned_ids.append(new_id) 136 | dets_out[i]['id'] += id_offset 137 | aligned_dets += dets_out 138 | self.last_id = np.maximum(self.last_id, max_id) 139 | self.frames_id += aligned_ids 140 | self.frames_det += aligned_dets 141 | return aligned_ids, aligned_dets, frame_start 142 | else: 143 | # the first frame of current dets 144 | # and the last frame of last dets is the same 145 | frame_start = 1 146 | aligned_ids = [] 147 | aligned_dets = [] 148 | max_id = 0 149 | id_pairs = {} 150 | """ 151 | assert len(dets_ids[0])== len(self.frames_id[-1]) 152 | """ 153 | # Calculate Id pairs 154 | for i in range(len(dets_ids[0])): 155 | # Use minimum because because sometimes 156 | # they are not totally the same 157 | has_match = False 158 | for j in range(len(self.frames_id[-1])): 159 | if ((self.det_type == '3D' 160 | and torch.sum(dets_out[0]['location'][i] != 161 | self.frames_det[-1]['location'][j]) == 0 162 | and torch.sum(dets_out[0]['bbox'][i] != 163 | self.frames_det[-1]['bbox'][j]) == 0) 164 | or (self.det_type == '2D' and torch.sum( 165 | dets_out[0]['bbox'][i] != self.frames_det[-1] 166 | ['bbox'][j]) == 0)): # noqa 167 | 168 | id_pairs[dets_ids[0][i]] = self.frames_id[-1][j] 169 | has_match = True 170 | break 171 | if not has_match: 172 | id_pairs[dets_ids[0][i]] = self.last_id + 1 173 | self.last_id += 1 174 | if len([v for k, v in id_pairs.items()]) != len( 175 | set([v for k, v in id_pairs.items()])): 176 | print("ID pairs has duplicates!!!") 177 | print(id_pairs) 178 | print(dets_ids) 179 | print(dets_out[0]) 180 | print(self.frames_id[-1]) 181 | 
print(self.frames_det[-1]) 182 | 183 | for i in range(1, len(dets_ids)): 184 | if dets_out[i]['id'].size(0) == 0: 185 | aligned_ids.append([]) 186 | continue 187 | new_id = dets_ids[i].copy() 188 | for j in range(len(dets_ids[i])): 189 | if dets_ids[i][j] in id_pairs.keys(): 190 | new_id[j] = id_pairs[dets_ids[i][j]] 191 | else: 192 | new_id[j] = self.last_id + 1 193 | id_pairs[dets_ids[i][j]] = new_id[j] 194 | self.last_id += 1 195 | if len(new_id) != len( 196 | set(new_id)): # check whether there is duplicate 197 | print('have duplicates!!!') 198 | print(id_pairs) 199 | print(new_id) 200 | print(dets_ids) 201 | print(dets_out) 202 | print(self.frames_id[-1]) 203 | print(self.frames_det[-1]) 204 | import pdb 205 | pdb.set_trace() 206 | 207 | max_id = np.maximum(np.max(new_id), max_id) 208 | self.last_id = np.maximum(self.last_id, max_id) 209 | aligned_ids.append(new_id) 210 | dets_out[i]['id'] = torch.Tensor(new_id).long() 211 | # TODO: This only support check for 2 frame case 212 | if dets_out[1]['id'].size(0) != 0: 213 | aligned_dets += dets_out[1:] 214 | self.frames_id += aligned_ids 215 | self.frames_det += aligned_dets 216 | return aligned_ids, aligned_dets, frame_start 217 | 218 | def assign_det_id(self, assign_det, assign_link, assign_new, assign_end, 219 | det_split, dets): 220 | det_start_idx = 0 221 | det_ids = [] 222 | already_used_id = [] 223 | fake_ids = [] 224 | dets_out = [] 225 | 226 | for i in range(len(det_split)): 227 | frame_id = [] 228 | det_curr_num = det_split[i].item() 229 | fake_id = [] 230 | det_out = get_start_gt_anno() 231 | for j in range(det_curr_num): 232 | curr_det_idx = det_start_idx + j 233 | # check w_det 234 | if assign_det[curr_det_idx] != 1: 235 | fake_id.append(-1) 236 | continue 237 | else: 238 | # det_out.append(dets[i][j]) 239 | det_out['name'].append(dets[i]['name'][:, j]) 240 | det_out['truncated'].append(dets[i]['truncated'][:, j]) 241 | det_out['occluded'].append(dets[i]['occluded'][:, j]) 242 | det_out['alpha'].append(dets[i]['alpha'][:, j]) 243 | det_out['bbox'].append(dets[i]['bbox'][:, j]) 244 | det_out['dimensions'].append(dets[i]['dimensions'][:, j]) 245 | det_out['location'].append(dets[i]['location'][:, j]) 246 | det_out['rotation_y'].append(dets[i]['rotation_y'][:, j]) 247 | 248 | # w_det=1, check whether a new det 249 | if i == 0: 250 | if len(already_used_id) == 0: 251 | frame_id.append(0) 252 | fake_id.append(0) 253 | already_used_id.append(0) 254 | det_out['id'].append(torch.Tensor([0]).long()) 255 | else: 256 | new_id = already_used_id[-1] + 1 257 | frame_id.append(new_id) 258 | fake_id.append(new_id) 259 | already_used_id.append(new_id) 260 | det_out['id'].append(torch.Tensor([new_id]).long()) 261 | continue 262 | elif assign_new[curr_det_idx] == 1: 263 | new_id = already_used_id[-1] + 1 if len( 264 | already_used_id) > 0 else 0 265 | frame_id.append(new_id) 266 | fake_id.append(new_id) 267 | already_used_id.append(new_id) 268 | det_out['id'].append(torch.Tensor([new_id]).long()) 269 | else: 270 | # look prev 271 | det_prev_num = det_split[i - 1] 272 | for k in range(det_prev_num): 273 | if assign_link[i - 1][0][k][j] == 1: 274 | prev_id = fake_ids[-1][k] 275 | frame_id.append(prev_id) 276 | fake_id.append(prev_id) 277 | det_out['id'].append( 278 | torch.Tensor([prev_id]).long()) 279 | break 280 | 281 | assert len(fake_id) == det_curr_num 282 | fake_ids.append(fake_id) 283 | det_ids.append(np.array(frame_id)) 284 | for k, v in det_out.items(): 285 | if len(det_out[k]) == 0: 286 | det_out[k] = torch.Tensor([]) 287 | else: 288 | 
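# fields that survived the assignment (name, bbox, location, ...) were
# collected above as per-detection slices; concatenate them back into a
# single tensor per field for this frame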
det_out[k] = torch.cat(v, dim=0) 289 | det_out['frame_idx'] = dets[i]['frame_idx'] 290 | dets_out.append(det_out) 291 | det_start_idx += det_curr_num 292 | return det_ids, dets_out 293 | 294 | def generate_gt(self, det_score, det_cls, det_id, det_split): 295 | gt_det = det_score.new_zeros(det_score.size()) 296 | gt_new = det_score.new_zeros(det_score.size()) 297 | gt_end = det_score.new_zeros(det_score.size()) 298 | gt_link = [] 299 | det_start_idx = 0 300 | 301 | for i in range(len(det_split)): 302 | det_curr_num = det_split[i] # current frame i has det_i detects 303 | if i != len(det_split) - 1: 304 | link_matrix = det_score.new_zeros( 305 | (1, det_curr_num, det_split[i + 1])) 306 | # Assign the score, according to eq1 307 | for j in range(det_curr_num): 308 | curr_det_idx = det_start_idx + j 309 | # g_det 310 | if det_cls[i][0][j] == 1: 311 | gt_det[curr_det_idx] = 1 # positive 312 | else: 313 | continue 314 | 315 | # g_link for successor frame 316 | if i == len(det_split) - 1: 317 | # end det at last frame 318 | gt_end[curr_det_idx] = 1 319 | else: 320 | matched = False 321 | det_next_num = det_split[i + 1] 322 | for k in range(det_next_num): 323 | if det_id[i][0][j] == det_id[i + 1][0][k]: 324 | link_matrix[0][j][k] = 1 325 | matched = True 326 | break 327 | if not matched: 328 | # no successor means an end det 329 | gt_end[curr_det_idx] = 1 330 | 331 | if i == 0: 332 | # new det at first frame 333 | gt_new[curr_det_idx] = 1 334 | else: 335 | # look prev 336 | matched = False 337 | det_prev_num = det_split[i - 1] 338 | for k in range(det_prev_num): 339 | if det_id[i][0][j] == det_id[i - 1][0][k]: 340 | # have been matched during search in 341 | # previous frame, no need to assign 342 | matched = True 343 | break 344 | if not matched: 345 | gt_new[curr_det_idx] = 1 346 | 347 | det_start_idx += det_curr_num 348 | if i != len(det_split) - 1: 349 | gt_link.append(link_matrix) 350 | 351 | return gt_det, gt_link, gt_new, gt_end 352 | -------------------------------------------------------------------------------- /utils/build_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | from cost import TrackingLoss 6 | from dataset import PatchwiseDataset, TestSequenceDataset 7 | from modules import TrackingNet 8 | 9 | 10 | def children(m: nn.Module): 11 | "Get children of `m`." 12 | return list(m.children()) 13 | 14 | 15 | def num_children(m: nn.Module) -> int: 16 | "Get number of children modules in `m`." 
17 | return len(children(m)) 18 | 19 | 20 | flatten_model = lambda m: sum( 21 | map(flatten_model, m.children()), 22 | [] # noqa 23 | ) if num_children(m) else [m] # noqa 24 | 25 | get_layer_groups = lambda m: [nn.Sequential(*flatten_model(m))] # noqa 26 | 27 | 28 | def build_lr_scheduler(config, optimizer): 29 | from .learning_schedules_fastai import OneCycle 30 | if config.type == 'one_cycle': 31 | print("Use one cycle LR scheduler") 32 | lr_scheduler = OneCycle(optimizer, config.max_iter, config.lr_max, 33 | list(config.moms), config.div_factor, 34 | config.pct_start) 35 | elif config.type == 'constant': 36 | print("Use no LR scheduler") 37 | lr_scheduler = None 38 | return lr_scheduler 39 | 40 | 41 | def build_optim(net, config): 42 | from .optim_util import OptimWrapper 43 | from functools import partial 44 | 45 | if config.lr_scheduler.optim == 'Adam': 46 | optimizer_func = partial(torch.optim.Adam, betas=(0.9, 0.99)) 47 | elif config.lr_scheduler.optim == 'AdaBound': 48 | print("Use AdaBound optim") 49 | from .adabound import AdaBound 50 | optimizer_func = partial(AdaBound, betas=(0.9, 0.99)) 51 | 52 | optimizer = OptimWrapper.create( 53 | optimizer_func, 54 | config.lr_scheduler.base_lr, 55 | get_layer_groups(net), 56 | wd=config.weight_decay, 57 | true_wd=config.fixed_wd, 58 | bn_wd=True) 59 | return optimizer 60 | 61 | 62 | def build_model(config): 63 | model = TrackingNet( 64 | seq_len=config.sample_max_len, 65 | score_arch=config.model.score_arch, 66 | appear_arch=config.model.appear_arch, 67 | appear_len=config.model.appear_len, 68 | appear_skippool=config.model.appear_skippool, 69 | appear_fpn=config.model.appear_fpn, 70 | point_arch=config.model.point_arch, 71 | point_len=config.model.point_len, 72 | without_reflectivity=config.without_reflectivity, 73 | softmax_mode=config.model.softmax_mode, 74 | affinity_op=config.model.affinity_op, 75 | end_arch=config.model.end_arch, 76 | end_mode=config.model.end_mode, 77 | test_mode=config.model.test_mode, 78 | score_fusion_arch=config.model.score_fusion_arch, 79 | neg_threshold=config.model.neg_threshold, 80 | dropblock=config.dropblock, 81 | use_dropout=config.use_dropout, 82 | ) 83 | return model 84 | 85 | 86 | # build dataset 87 | class Cutout(object): 88 | 89 | def __init__(self, length): 90 | self.length = length 91 | 92 | def __call__(self, img): 93 | h, w = img.size(1), img.size(2) 94 | mask = np.ones((h, w), np.float32) 95 | y = np.random.randint(h) 96 | x = np.random.randint(w) 97 | 98 | y1 = np.clip(y - self.length // 2, 0, h) 99 | y2 = np.clip(y + self.length // 2, 0, h) 100 | x1 = np.clip(x - self.length // 2, 0, w) 101 | x2 = np.clip(x + self.length // 2, 0, w) 102 | 103 | mask[y1:y2, x1:x2] = 0. 
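# a patch of roughly length x length centred at the randomly drawn (y, x),
# clipped at the image borders, has just been zeroed in the mask; it is now
# converted to a tensor, broadcast across the channel dimension and
# multiplied into the image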
104 | mask = torch.from_numpy(mask) 105 | mask = mask.expand_as(img) 106 | img *= mask 107 | return img 108 | 109 | 110 | def build_augmentation(config): 111 | normalize = transforms.Normalize( 112 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 113 | aug = [ 114 | transforms.RandomResizedCrop(config.input_size), 115 | transforms.RandomHorizontalFlip() 116 | ] 117 | 118 | rotation = config.get('rotation', 0) 119 | colorjitter = config.get('colorjitter', None) 120 | cutout = config.get('cutout', None) 121 | 122 | if rotation > 0: 123 | print("rotation applied") 124 | aug.append(transforms.RandomRotation(rotation)) 125 | 126 | if colorjitter is not None: 127 | print("colorjitter applied") 128 | aug.append(transforms.ColorJitter(*colorjitter)) 129 | 130 | aug.append(transforms.ToTensor()) 131 | aug.append(normalize) 132 | 133 | if cutout is not None: 134 | print("cutout applied") 135 | aug.append(Cutout(config.cutout_length)) 136 | 137 | valid_transform = transforms.Compose([ 138 | transforms.Resize(config.test_resize), 139 | transforms.CenterCrop(config.input_size), 140 | transforms.ToTensor(), 141 | normalize, 142 | ]) 143 | train_transform = transforms.Compose(aug) 144 | return train_transform, valid_transform 145 | 146 | 147 | def build_criterion(config): 148 | criterion = TrackingLoss( 149 | smooth_ratio=config.smooth_ratio, 150 | detloss_type=config.det_loss, 151 | det_ratio=config.det_ratio, 152 | trans_ratio=config.trans_ratio, 153 | trans_last=config.trans_last, 154 | linkloss_type=config.link_loss) 155 | return criterion 156 | 157 | 158 | def build_dataset(config, 159 | set_source='train', 160 | evaluate=False, 161 | train_transform=None, 162 | valid_transform=None): 163 | if set_source == 'train' and not evaluate: 164 | train_dataset = PatchwiseDataset( 165 | root_dir=config.train_root, 166 | meta_file=config.train_source, 167 | link_file=config.train_link, 168 | det_file=config.train_det, 169 | det_type=config.det_type, 170 | tracker_type=config.tracker_type, 171 | use_frustum=config.use_frustum, 172 | without_reflectivity=config.without_reflectivity, 173 | bbox_jitter=config.augmentation.get('bboxjitter', None), 174 | transform=train_transform, 175 | fix_iou=config.train_fix_iou, 176 | fix_count=config.train_fix_count, 177 | gt_ratio=config.gt_det_ratio, 178 | sample_max_len=config.sample_max_len, 179 | train=True) 180 | return train_dataset 181 | elif set_source == 'train' and evaluate: 182 | # train_val 183 | trainval_dataset = TestSequenceDataset( 184 | root_dir=config.train_root, 185 | meta_file=config.train_source, 186 | link_file=config.train_link, 187 | det_file=config.train_det, 188 | det_type=config.det_type, 189 | tracker_type=config.tracker_type, 190 | use_frustum=config.use_frustum, 191 | without_reflectivity=config.without_reflectivity, 192 | transform=valid_transform, 193 | fix_iou=config.val_fix_iou, 194 | fix_count=config.val_fix_count, 195 | gt_ratio=config.gt_det_ratio, 196 | sample_max_len=config.sample_max_len) 197 | return trainval_dataset 198 | elif set_source == 'val' and evaluate: 199 | # val 200 | val_dataset = TestSequenceDataset( 201 | root_dir=config.val_root, 202 | meta_file=config.val_source, 203 | link_file=config.val_link, 204 | det_file=config.val_det, 205 | det_type=config.det_type, 206 | tracker_type=config.tracker_type, 207 | use_frustum=config.use_frustum, 208 | without_reflectivity=config.without_reflectivity, 209 | transform=valid_transform, 210 | fix_iou=config.val_fix_iou, 211 | fix_count=config.val_fix_count, 212 | 
gt_ratio=config.gt_det_ratio, 213 | sample_max_len=config.sample_max_len) 214 | return val_dataset 215 | elif set_source == 'test' and evaluate: 216 | test_dataset = TestSequenceDataset( 217 | root_dir=config.test_root, 218 | meta_file=config.test_source, 219 | link_file=config.test_link, 220 | det_file=config.test_det, 221 | det_type=config.det_type, 222 | tracker_type=config.tracker_type, 223 | use_frustum=config.use_frustum, 224 | without_reflectivity=config.without_reflectivity, 225 | transform=valid_transform, 226 | fix_iou=config.val_fix_iou, 227 | fix_count=config.val_fix_count, 228 | gt_ratio=config.gt_det_ratio, 229 | sample_max_len=config.sample_max_len) 230 | return test_dataset 231 | else: 232 | print("Error: Not implement!!!!") 233 | -------------------------------------------------------------------------------- /utils/learning_schedules_fastai.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | 5 | from .optim_util import OptimWrapper 6 | 7 | 8 | class LRSchedulerStep(object): 9 | 10 | def __init__(self, fai_optimizer: OptimWrapper, total_step, lr_phases, 11 | mom_phases): 12 | # if not isinstance(fai_optimizer, OptimWrapper): 13 | # raise TypeError('{} is not a fastai OptimWrapper'.format( 14 | # type(fai_optimizer).__name__)) 15 | self.optimizer = fai_optimizer 16 | self.total_step = total_step 17 | self.lr_phases = [] 18 | 19 | for i, (start, lambda_func) in enumerate(lr_phases): 20 | if len(self.lr_phases) != 0: 21 | assert self.lr_phases[-1][0] < start 22 | if isinstance(lambda_func, str): 23 | lambda_func = eval(lambda_func) 24 | if i < len(lr_phases) - 1: 25 | self.lr_phases.append( 26 | (int(start * total_step), 27 | int(lr_phases[i + 1][0] * total_step), lambda_func)) 28 | else: 29 | self.lr_phases.append( 30 | (int(start * total_step), total_step, lambda_func)) 31 | assert self.lr_phases[0][0] == 0 32 | self.mom_phases = [] 33 | for i, (start, lambda_func) in enumerate(mom_phases): 34 | if len(self.mom_phases) != 0: 35 | assert self.mom_phases[-1][0] < start 36 | if isinstance(lambda_func, str): 37 | lambda_func = eval(lambda_func) 38 | if i < len(mom_phases) - 1: 39 | self.mom_phases.append( 40 | (int(start * total_step), 41 | int(mom_phases[i + 1][0] * total_step), lambda_func)) 42 | else: 43 | self.mom_phases.append( 44 | (int(start * total_step), total_step, lambda_func)) 45 | assert self.mom_phases[0][0] == 0 46 | self.current_lr = 0 47 | 48 | def step(self, step): 49 | for start, end, func in self.lr_phases: 50 | if step >= start: 51 | self.current_lr = func((step - start) / (end - start)) 52 | self.optimizer.lr = self.current_lr 53 | for start, end, func in self.mom_phases: 54 | if step >= start: 55 | self.optimizer.mom = func((step - start) / (end - start)) 56 | 57 | def get_lr(self): 58 | return self.current_lr 59 | 60 | 61 | def annealing_cos(start, end, pct): 62 | # print(pct, start, end) 63 | "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." 
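# value = end + (start - end) / 2 * (cos(pi * pct) + 1)
# pct = 0 makes the cosine term 2 and returns `start`; pct = 1 makes it 0 and
# returns `end`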
64 | cos_out = np.cos(np.pi * pct) + 1 65 | return end + (start - end) / 2 * cos_out 66 | 67 | 68 | class OneCycle(LRSchedulerStep): 69 | 70 | def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor, 71 | pct_start): 72 | self.lr_max = lr_max 73 | self.moms = moms 74 | self.div_factor = div_factor 75 | self.pct_start = pct_start 76 | a1 = int(total_step * self.pct_start) 77 | a2 = total_step - a1 78 | low_lr = self.lr_max / self.div_factor 79 | lr_phases = ((0, partial(annealing_cos, low_lr, self.lr_max)), 80 | (self.pct_start, 81 | partial(annealing_cos, self.lr_max, low_lr / 1e4))) 82 | mom_phases = ((0, partial(annealing_cos, *self.moms)), 83 | (self.pct_start, partial(annealing_cos, 84 | *self.moms[::-1]))) 85 | fai_optimizer.lr, fai_optimizer.mom = low_lr, self.moms[0] 86 | super().__init__(fai_optimizer, total_step, lr_phases, mom_phases) 87 | 88 | 89 | class FakeOptim: 90 | 91 | def __init__(self): 92 | self.lr = 0 93 | self.mom = 0 94 | 95 | 96 | if __name__ == "__main__": 97 | import matplotlib.pyplot as plt 98 | opt = FakeOptim() # 3e-3, wd=0.4, div_factor=10 99 | schd = OneCycle(opt, 100, 3e-3, (0.95, 0.85), 10.0, 0.1) 100 | 101 | lrs = [] 102 | moms = [] 103 | for i in range(100): 104 | schd.step(i) 105 | lrs.append(opt.lr) 106 | moms.append(opt.mom) 107 | plt.plot(lrs) 108 | # plt.plot(moms) 109 | plt.show() 110 | plt.plot(moms) 111 | plt.show() 112 | -------------------------------------------------------------------------------- /utils/optim_util.py: -------------------------------------------------------------------------------- 1 | from collections import Iterable 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn 6 | from torch._utils import _unflatten_dense_tensors 7 | from torch.nn.utils import parameters_to_vector 8 | 9 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm) 10 | 11 | 12 | def split_bn_bias(layer_groups): 13 | "Split the layers in `layer_groups` into batchnorm (`bn_types`) and non-batchnorm groups." 14 | split_groups = [] 15 | for l in layer_groups: 16 | l1, l2 = [], [] 17 | for c in l.children(): 18 | if isinstance(c, bn_types): 19 | l2.append(c) 20 | else: 21 | l1.append(c) 22 | split_groups += [nn.Sequential(*l1), nn.Sequential(*l2)] 23 | return split_groups 24 | 25 | 26 | def model_g2master_g(model_params, 27 | master_params, 28 | flat_master: bool = False) -> None: 29 | "Copy the `model_params` gradients to `master_params` for the optimizer step." 30 | if flat_master: 31 | for model_group, master_group in zip(model_params, master_params): 32 | if len(master_group) != 0: 33 | master_group[0].grad.data.copy_( 34 | parameters_to_vector( 35 | [p.grad.data.float() for p in model_group])) 36 | else: 37 | for model_group, master_group in zip(model_params, master_params): 38 | for model, master in zip(model_group, master_group): 39 | if model.grad is not None: 40 | if master.grad is None: 41 | master.grad = master.data.new(*master.data.size()) 42 | master.grad.data.copy_(model.grad.data) 43 | else: 44 | master.grad = None 45 | 46 | 47 | def master2model(model_params, 48 | master_params, 49 | flat_master: bool = False) -> None: 50 | "Copy `master_params` to `model_params`."
51 | if flat_master: 52 | for model_group, master_group in zip(model_params, master_params): 53 | if len(model_group) != 0: 54 | for model, master in zip( 55 | model_group, 56 | _unflatten_dense_tensors(master_group[0].data, 57 | model_group)): 58 | model.data.copy_(master) 59 | else: 60 | for model_group, master_group in zip(model_params, master_params): 61 | for model, master in zip(model_group, master_group): 62 | model.data.copy_(master.data) 63 | 64 | 65 | def listify(p=None, q=None): 66 | "Make `p` listy and the same length as `q`." 67 | if p is None: 68 | p = [] 69 | elif isinstance(p, str): 70 | p = [p] 71 | elif not isinstance(p, Iterable): 72 | p = [p] 73 | n = q if type(q) == int else len(p) if q is None else len(q) 74 | if len(p) == 1: p = p * n 75 | assert len(p) == n, f'List len mismatch ({len(p)} vs {n})' 76 | return list(p) 77 | 78 | 79 | def trainable_params(m: nn.Module): 80 | "Return list of trainable params in `m`." 81 | res = filter(lambda p: p.requires_grad, m.parameters()) 82 | return res 83 | 84 | 85 | def is_tuple(x) -> bool: 86 | return isinstance(x, tuple) 87 | 88 | 89 | # copy from fastai. 90 | class OptimWrapper(): 91 | "Basic wrapper around `opt` to simplify hyper-parameters changes." 92 | 93 | def __init__(self, opt, wd, true_wd: bool = False, bn_wd: bool = True): 94 | self.opt, self.true_wd, self.bn_wd = opt, true_wd, bn_wd 95 | self.opt_keys = list(self.opt.param_groups[0].keys()) 96 | self.opt_keys.remove('params') 97 | self.read_defaults() 98 | self.wd = wd 99 | 100 | @classmethod 101 | def create(cls, opt_func, lr, layer_groups, **kwargs): 102 | "Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`." 103 | split_groups = split_bn_bias(layer_groups) 104 | opt = opt_func([{ 105 | 'params': trainable_params(l), 106 | 'lr': lr 107 | } for l in split_groups]) # modified for adabound 108 | #opt = opt_func([{'params': trainable_params(l), 'lr': lr} for l in split_groups]) 109 | opt = cls(opt, **kwargs) 110 | opt.lr, opt.opt_func = listify(lr, layer_groups), opt_func 111 | return opt 112 | 113 | def new(self, layer_groups): 114 | "Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters." 115 | opt_func = getattr(self, 'opt_func', self.opt.__class__) 116 | split_groups = split_bn_bias(layer_groups) 117 | opt = opt_func([{ 118 | 'params': trainable_params(l), 119 | 'lr': 0 120 | } for l in split_groups]) 121 | return self.create( 122 | opt_func, 123 | self.lr, 124 | layer_groups, 125 | wd=self.wd, 126 | true_wd=self.true_wd, 127 | bn_wd=self.bn_wd) 128 | 129 | def __repr__(self) -> str: 130 | return f'OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}' 131 | 132 | # Pytorch optimizer methods 133 | def step(self) -> None: 134 | "Set weight decay and step optimizer." 135 | # weight decay outside of optimizer step (AdamW) 136 | if self.true_wd: 137 | for lr, wd, pg1, pg2 in zip(self._lr, self._wd, 138 | self.opt.param_groups[::2], 139 | self.opt.param_groups[1::2]): 140 | for p in pg1['params']: 141 | p.data.mul_(1 - wd * lr) 142 | if self.bn_wd: 143 | for p in pg2['params']: 144 | p.data.mul_(1 - wd * lr) 145 | self.set_val('weight_decay', listify(0, self._wd)) 146 | self.opt.step() 147 | 148 | def zero_grad(self) -> None: 149 | "Clear optimizer gradients." 150 | self.opt.zero_grad() 151 | 152 | # Passthrough to the inner opt. 153 | def __getattr__(self, k: str): 154 | return getattr(self.opt, k, None) 155 | 156 | def clear(self): 157 | "Reset the state of the inner optimizer." 
158 | sd = self.state_dict() 159 | sd['state'] = {} 160 | self.load_state_dict(sd) 161 | 162 | # Hyperparameters as properties 163 | @property 164 | def lr(self) -> float: 165 | return self._lr[-1] 166 | 167 | @lr.setter 168 | def lr(self, val: float) -> None: 169 | self._lr = self.set_val('lr', listify(val, self._lr)) 170 | 171 | @property 172 | def mom(self) -> float: 173 | return self._mom[-1] 174 | 175 | @mom.setter 176 | def mom(self, val: float) -> None: 177 | if 'momentum' in self.opt_keys: 178 | self.set_val('momentum', listify(val, self._mom)) 179 | elif 'betas' in self.opt_keys: 180 | self.set_val('betas', (listify(val, self._mom), self._beta)) 181 | self._mom = listify(val, self._mom) 182 | 183 | @property 184 | def beta(self) -> float: 185 | return None if self._beta is None else self._beta[-1] 186 | 187 | @beta.setter 188 | def beta(self, val: float) -> None: 189 | "Set beta (or alpha as makes sense for given optimizer)." 190 | if val is None: return 191 | if 'betas' in self.opt_keys: 192 | self.set_val('betas', (self._mom, listify(val, self._beta))) 193 | elif 'alpha' in self.opt_keys: 194 | self.set_val('alpha', listify(val, self._beta)) 195 | self._beta = listify(val, self._beta) 196 | 197 | @property 198 | def wd(self) -> float: 199 | return self._wd[-1] 200 | 201 | @wd.setter 202 | def wd(self, val: float) -> None: 203 | "Set weight decay." 204 | if not self.true_wd: 205 | self.set_val( 206 | 'weight_decay', listify(val, self._wd), bn_groups=self.bn_wd) 207 | self._wd = listify(val, self._wd) 208 | 209 | # Helper functions 210 | def read_defaults(self) -> None: 211 | "Read the values inside the optimizer for the hyper-parameters." 212 | self._beta = None 213 | if 'lr' in self.opt_keys: self._lr = self.read_val('lr') 214 | if 'momentum' in self.opt_keys: self._mom = self.read_val('momentum') 215 | if 'alpha' in self.opt_keys: self._beta = self.read_val('alpha') 216 | if 'betas' in self.opt_keys: 217 | self._mom, self._beta = self.read_val('betas') 218 | if 'weight_decay' in self.opt_keys: 219 | self._wd = self.read_val('weight_decay') 220 | 221 | def set_val(self, key: str, val, bn_groups: bool = True): 222 | "Set `val` inside the optimizer dictionary at `key`." 223 | if is_tuple(val): val = [(v1, v2) for v1, v2 in zip(*val)] 224 | for v, pg1, pg2 in zip(val, self.opt.param_groups[::2], 225 | self.opt.param_groups[1::2]): 226 | pg1[key] = v 227 | if bn_groups: pg2[key] = v 228 | return val 229 | 230 | def read_val(self, key: str): 231 | "Read a hyperparameter `key` in the optimizer dictionary." 
232 | val = [pg[key] for pg in self.opt.param_groups[::2]] 233 | if is_tuple(val[0]): val = [o[0] for o in val], [o[1] for o in val] 234 | return val 235 | -------------------------------------------------------------------------------- /utils/train_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | from scipy.stats import truncnorm 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | def create_logger(name, log_file, level=logging.INFO): 12 | l = logging.getLogger(name) 13 | formatter = logging.Formatter( 14 | '[%(asctime)s][%(filename)10s][line:%(lineno)4d][%(levelname)8s] %(message)s' # noqa 15 | ) 16 | fh = logging.FileHandler(log_file) 17 | fh.setFormatter(formatter) 18 | sh = logging.StreamHandler() 19 | sh.setFormatter(formatter) 20 | l.setLevel(level) 21 | l.addHandler(fh) 22 | l.addHandler(sh) 23 | return l 24 | 25 | 26 | class AverageMeter(object): 27 | """Computes and stores the average and current value""" 28 | 29 | def __init__(self, length=0): 30 | self.length = length 31 | self.reset() 32 | 33 | def reset(self): 34 | if self.length > 0: 35 | self.history = [] 36 | else: 37 | self.count = 0 38 | self.sum = 0.0 39 | self.val = 0.0 40 | self.avg = 0.0 41 | 42 | def update(self, val, num=1): 43 | if self.length > 0: 44 | # currently assert num==1 to avoid bad usage, 45 | # refine when there are explicit requirements 46 | assert num == 1 47 | self.history.append(val) 48 | if len(self.history) > self.length: 49 | del self.history[0] 50 | 51 | self.val = self.history[-1] 52 | self.avg = np.mean(self.history) 53 | else: 54 | self.val = val 55 | self.sum += val * num 56 | self.count += num 57 | self.avg = self.sum / self.count 58 | 59 | 60 | class ColorAugmentation(object): 61 | 62 | def __init__(self, eig_vec=None, eig_val=None): 63 | if eig_vec is None: 64 | eig_vec = torch.Tensor([ 65 | [0.4009, 0.7192, -0.5675], 66 | [-0.8140, -0.0045, -0.5808], 67 | [0.4203, -0.6948, -0.5836], 68 | ]) 69 | if eig_val is None: 70 | eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]]) 71 | self.eig_val = eig_val # 1*3 72 | self.eig_vec = eig_vec # 3*3 73 | 74 | def __call__(self, tensor): 75 | assert tensor.size(0) == 3 76 | alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1 77 | quantity = torch.mm(self.eig_val * alpha, self.eig_vec) 78 | tensor = tensor + quantity.view(3, 1, 1) 79 | return tensor 80 | 81 | 82 | def accuracy(output, target, topk=(1, )): 83 | """Computes the precision@k for the specified values of k""" 84 | maxk = max(topk) 85 | batch_size = target.size(0) 86 | 87 | _, pred = output.topk(maxk, 1, True, True) 88 | pred = pred.t() 89 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 90 | 91 | res = [] 92 | for k in topk: 93 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 94 | res.append(correct_k.mul_(100.0 / batch_size)) 95 | return res 96 | 97 | 98 | class DistributedGivenIterationSampler(Sampler): 99 | 100 | def __init__(self, 101 | dataset, 102 | total_iter, 103 | batch_size, 104 | world_size=None, 105 | rank=None, 106 | last_iter=-1): 107 | self.dataset = dataset 108 | self.total_iter = total_iter 109 | self.batch_size = batch_size 110 | self.world_size = world_size 111 | self.rank = rank 112 | self.last_iter = last_iter 113 | 114 | self.total_size = self.total_iter * self.batch_size 115 | 116 | self.indices = self.gen_new_list() 117 | self.call = 0 118 | 119 | def __iter__(self): 120 | if
self.call == 0: 121 | self.call = 1 122 | return iter(self.indices[(self.last_iter + 1) * self.batch_size:]) 123 | else: 124 | raise RuntimeError( 125 | "this sampler is not designed to be called more than once!!") 126 | 127 | def gen_new_list(self): 128 | 129 | np.random.seed(0) 130 | all_size = self.total_size * self.world_size 131 | origin_indices = np.arange(len(self.dataset)) 132 | origin_indices = origin_indices[:all_size] 133 | num_repeat = (all_size - 1) // origin_indices.shape[0] + 1 134 | 135 | total_indices = [] 136 | for i in range(num_repeat): 137 | total_indices.append(np.random.permutation(origin_indices)) 138 | indices = np.concatenate(total_indices, axis=0)[:all_size] 139 | 140 | beg = self.total_size * self.rank 141 | indices = indices[beg:beg + self.total_size] 142 | 143 | assert len(indices) == self.total_size 144 | 145 | return indices 146 | 147 | def __len__(self): 148 | return self.total_size 149 | 150 | 151 | def save_checkpoint(state, is_best, filename): 152 | torch.save(state, filename + '.pth.tar') 153 | if is_best: 154 | shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar') 155 | 156 | 157 | def load_state(path, model, optimizer=None, rank=0): 158 | 159 | def map_func(storage, location): 160 | return storage.cuda() 161 | 162 | if os.path.isfile(path): 163 | if rank == 0: 164 | print("=> loading checkpoint '{}'".format(path)) 165 | 166 | checkpoint = torch.load(path, map_location=map_func) 167 | model.load_state_dict(checkpoint['state_dict'], strict=False) 168 | 169 | if rank == 0: 170 | ckpt_keys = set(checkpoint['state_dict'].keys()) 171 | own_keys = set(model.state_dict().keys()) 172 | missing_keys = own_keys - ckpt_keys 173 | for k in missing_keys: 174 | print('caution: missing keys from checkpoint {}: {}'.format( 175 | path, k)) 176 | 177 | if optimizer is not None: 178 | best_prec1 = checkpoint['best_mota'] 179 | last_iter = checkpoint['step'] 180 | optimizer.load_state_dict(checkpoint['optimizer']) 181 | if rank == 0: 182 | print( 183 | "=> also loaded optimizer from checkpoint '{}' (iter {})". 184 | format(path, last_iter)) 185 | return best_prec1, last_iter 186 | else: 187 | if rank == 0: 188 | print("=> no checkpoint found at '{}'".format(path)) 189 | 190 | 191 | def calculate_distance(dets, gt_dets): 192 | import motmetrics as mm 193 | 194 | det = dets.copy() 195 | det[:, 2:] = det[:, 2:] - det[:, :2] 196 | gt_det = gt_dets.copy() 197 | gt_det[:, 2:] = gt_det[:, 2:] - gt_det[:, :2] 198 | 199 | return mm.distances.iou_matrix(gt_det, det, max_iou=0.3) 200 | 201 | 202 | def truncated_normal_(tensor, mean=0, std=0.001, clip_a=-2, clip_b=2): 203 | size = [*tensor.view(-1).size()] 204 | values = truncnorm.rvs(clip_a, clip_b, size=size[0]) 205 | with torch.no_grad(): 206 | tensor.copy_( 207 | torch.from_numpy(values).view(tensor.size()).mul(std).add(mean)) 208 | --------------------------------------------------------------------------------
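The learning-rate utilities above are easiest to follow end to end, so here is a minimal sketch of how `OptimWrapper` (`utils/optim_util.py`) and `OneCycle` (`utils/learning_schedules_fastai.py`) could be wired around a plain PyTorch optimizer. It is not code from this repository: the toy model, layer grouping, and hyper-parameter values are illustrative assumptions, and it presumes the script runs from the repository root so the `utils` package is importable.

```python
from functools import partial

import torch
from torch import nn, optim

from utils.learning_schedules_fastai import OneCycle
from utils.optim_util import OptimWrapper

# Toy model; split_bn_bias() iterates .children() of each layer group,
# so a single nn.Sequential is a valid (if trivial) layer group.
model = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
layer_groups = [model]  # hypothetical grouping for this sketch

# Wrap plain SGD; wd / true_wd / bn_wd are OptimWrapper.__init__ arguments.
optimizer = OptimWrapper.create(
    partial(optim.SGD, momentum=0.9),
    lr=3e-3,
    layer_groups=layer_groups,
    wd=1e-4,
    true_wd=True,   # decoupled weight decay applied inside OptimWrapper.step()
    bn_wd=False)    # skip decay for the batchnorm parameter group

# One-cycle schedule: warm up from lr_max / div_factor to lr_max over the
# first pct_start fraction of steps, then cosine-anneal back down; momentum
# moves in the opposite direction (both phases use annealing_cos).
total_steps = 1000  # assumed iteration budget
scheduler = OneCycle(
    optimizer, total_steps, lr_max=3e-3, moms=(0.95, 0.85),
    div_factor=10.0, pct_start=0.1)

# Dummy batch, just to make the loop runnable.
x = torch.randn(4, 8)
target = torch.randint(0, 2, (4, ))
criterion = nn.CrossEntropyLoss()

for step in range(total_steps):
    scheduler.step(step)        # sets optimizer.lr / optimizer.mom for this step
    loss = criterion(model(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()            # applies weight decay, then the inner SGD step
```

This mirrors the `if __name__ == "__main__"` demo in `utils/learning_schedules_fastai.py`, but with a real optimizer instead of `FakeOptim`, so the effect of `true_wd` and the batchnorm/non-batchnorm split in `split_bn_bias` is visible.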