├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── lib ├── backbone │ ├── fpn.py │ └── resnet50.py ├── data │ └── CrowdHuman.py ├── det_oprs │ ├── anchors_generator.py │ ├── bbox_opr.py │ ├── cascade_roi_target.py │ ├── find_top_rpn_proposals.py │ ├── fpn_anchor_target.py │ ├── fpn_roi_target.py │ ├── loss_opr.py │ ├── retina_anchor_target.py │ └── utils.py ├── evaluate │ ├── APMRToolkits │ │ ├── __init__.py │ │ ├── database.py │ │ └── image.py │ ├── JIToolkits │ │ ├── JI_tools.py │ │ └── matching.py │ ├── __init__.py │ ├── compute_APMR.py │ └── compute_JI.py ├── layers │ ├── batch_norm.py │ └── pooler.py ├── module │ └── rpn.py └── utils │ ├── SGD_bias.py │ ├── misc_utils.py │ ├── nms_utils.py │ └── visual_utils.py ├── model ├── rcnn_emd_refine │ ├── config.py │ └── network.py ├── rcnn_emd_simple │ ├── config.py │ └── network.py ├── rcnn_fpn_baseline │ ├── config.py │ └── network.py ├── retina_emd_simple │ ├── config.py │ └── network.py └── retina_fpn_baseline │ ├── config.py │ └── network.py ├── requirements.txt └── tools ├── eval_json.py ├── inference.py ├── test.py ├── train.py └── visulize_json.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | **/__pycache__ 3 | **/outputs 4 | **/model_dump 5 | 6 | **/*.out 7 | **/*.json 8 | **/*.pkl 9 | **/*.pth 10 | **/*.log 11 | 12 | **/*.jpg 13 | **/*.jpeg 14 | **/*.png 15 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.5" 2 | ARG CUDA="10.1" 3 | ARG CUDNN="7" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | 7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" 8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 9 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" 10 | 11 | RUN apt-get update && apt-get install -y vim git ninja-build libglib2.0-0 libsm6 libgl1-mesa-dev libxrender-dev libxext6 \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # Install mmdetection 16 | #RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection 17 | WORKDIR /crowddet 18 | ADD . . 19 | RUN pip install --default-timeout=1800 -r requirements.txt --ignore-installed 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Megvii Technology 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Detection in Crowded Scenes: One Proposal, Multiple Predictions 2 | 3 | This is the PyTorch implementation of our paper "[Detection in Crowded Scenes: One Proposal, Multiple Predictions](https://openaccess.thecvf.com/content_CVPR_2020/html/Chu_Detection_in_Crowded_Scenes_One_Proposal_Multiple_Predictions_CVPR_2020_paper.html)", https://arxiv.org/abs/2003.09163, published in CVPR 2020. 4 | 5 | Our method aims at detecting highly-overlapped instances in crowded scenes. 6 | 7 | The key of our approach is to let each proposal predict a set of instances that might be highly overlapped, rather than the single instance predicted in previous proposal-based frameworks. With this scheme, the predictions of nearby proposals are expected to infer the **same set** of instances rather than **distinguish individuals**, which is much easier to learn. Equipped with new techniques such as EMD Loss and Set NMS, our detector can effectively handle the difficulty of detecting highly overlapped objects. 8 | 9 | The network structure and results are shown here: 10 | 11 | 12 | 13 | 14 | # Citation 15 | 16 | If you use the code in your research, please cite: 17 | ``` 18 | @InProceedings{Chu_2020_CVPR, 19 | author = {Chu, Xuangeng and Zheng, Anlin and Zhang, Xiangyu and Sun, Jian}, 20 | title = {Detection in Crowded Scenes: One Proposal, Multiple Predictions}, 21 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 22 | month = {June}, 23 | year = {2020} 24 | } 25 | ``` 26 | 27 | # Run 28 | 1. Setup the environment with Docker 29 | - Requirements: Install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) 30 | - Create the docker image: 31 | ```shell 32 | sudo docker build . -t crowddet 33 | ``` 34 | - Run the docker image: 35 | ```shell 36 | sudo docker run --gpus all --shm-size=8g -it --rm crowddet 37 | ``` 38 | 39 | 2. CrowdHuman data: 40 | * CrowdHuman is a benchmark dataset to better evaluate detectors in crowd scenarios. The dataset can be downloaded from http://www.crowdhuman.org/. The path of the dataset is set in `config.py`. 41 | 42 | 3. Steps to run: 43 | * Step 1: training. More training and testing settings can be set in `config.py`. 44 | ``` 45 | cd tools 46 | python3 train.py -md rcnn_fpn_baseline 47 | ``` 48 | 49 | * Step 2: testing. If you have four GPUs, you can use ` -d 0-3 ` to use all of your GPUs. 50 | The result json file will be evaluated automatically. 51 | ``` 52 | cd tools 53 | python3 test.py -md rcnn_fpn_baseline -r 40 54 | ``` 55 | 56 | * Step 3: evaluating a json file, running inference on one picture, and visualizing a json file. 57 | ` -r ` means the resume epoch, ` -n ` means the number of visualization pictures.
58 | ``` 59 | cd tools 60 | python3 eval_json.py -f your_json_path.json 61 | python3 inference.py -md rcnn_fpn_baseline -r 40 -i your_image_path.png 62 | python3 visulize_json.py -f your_json_path.json -n 3 63 | ``` 64 | 65 | # Models 66 | 67 | We used MegEngine in the original research (https://github.com/megvii-model/CrowdDetection); this project is a re-implementation based on PyTorch. 68 | 69 | We use the pre-trained ResNet-50 model from [MegEngine Model Hub](https://megengine.org.cn/model-hub) and convert it to PyTorch. You can get this model from [here](https://drive.google.com/file/d/1lfYQHC63oM2Dynbfj6uD7XnpDIaA5kNr/view?usp=sharing). 70 | These models can also be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1U3I-qNIrXuYQzUEDDdISTw) (code: yx46). 71 | | Model | Top1 acc | Top5 acc | 72 | | --- | --- | --- | 73 | | ResNet50 | 76.254 | 93.056 | 74 | 75 | All models are based on ResNet-50 FPN. 76 | | | AP | MR | JI | Model | 77 | | --- | --- | --- | --- | --- | 78 | | RCNN FPN Baseline (converted from MegEngine) | 0.8718 | 0.4239 | 0.7949 | [rcnn_fpn_baseline_mge.pth](https://drive.google.com/file/d/19LBc_6vizKr06Wky0s7TAnvlqP8PjSA_/view?usp=sharing) | 79 | | RCNN EMD Simple (converted from MegEngine) | 0.9052 | 0.4196 | 0.8209 | [rcnn_emd_simple_mge.pth](https://drive.google.com/file/d/1f_vjFrjTxXYR5nPnYZRrU-yffYTGUnyL/view?usp=sharing) | 80 | | RCNN EMD with RM (converted from MegEngine) | 0.9097 | 0.4102 | 0.8271 | [rcnn_emd_refine_mge.pth](https://drive.google.com/file/d/1qYJ0b7QsYZsP5_8yIjya_kj_tu90ALDJ/view?usp=sharing) | 81 | | RCNN FPN Baseline (trained with PyTorch) | 0.8665 | 0.4243 | 0.7949 | [rcnn_fpn_baseline.pth](https://drive.google.com/file/d/10poBJ1qwlV0iS6i_lnbw9cbdt4tpTvh1/view?usp=sharing) | 82 | | RCNN EMD Simple (trained with PyTorch) | 0.8997 | 0.4167 | 0.8225 | [rcnn_emd_simple.pth](https://drive.google.com/file/d/1Rryeqz5sMWTTm3epEfqDpK1H8EsPvlLe/view?usp=sharing) | 83 | | RCNN EMD with RM (trained with PyTorch) | 0.9030 | 0.4128 | 0.8263 | [rcnn_emd_refine.pth](https://drive.google.com/file/d/1jk_b7Ws528uCfEgOesLS_iBsqXcHl2Ju/view?usp=sharing) | 84 | | RetinaNet FPN Baseline | 0.8188 | 0.5644 | 0.7316 | [retina_fpn_baseline.pth](https://drive.google.com/file/d/1w1CmE4MfYB4NT5Uyx85dPkR87gEFXhBJ/view?usp=sharing) | 85 | | RetinaNet EMD Simple | 0.8292 | 0.5481 | 0.7393 | [retina_emd_simple.pth](https://drive.google.com/file/d/1LwUlTf4YAH3wp-HXAAuXyTD11SeDgwhE/view?usp=sharing) | 86 | 89 | 90 | # Contact 91 | 92 | If you have any questions, please do not hesitate to contact Xuangeng Chu (xg_chu@pku.edu.cn). 93 | -------------------------------------------------------------------------------- /lib/backbone/fpn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | class FPN(nn.Module): 8 | """ 9 | This module implements Feature Pyramid Network. 10 | It creates pyramid features built on top of some input feature maps.
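    Each selected backbone level is first projected to 256 channels by a 1x1 lateral conv;
    coarser levels are upsampled by a factor of 2 and added top-down, and a 3x3 conv
    produces each output map. Depending on layers_end, extra levels (p6/p7) are appended.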
11 | """ 12 | def __init__(self, bottom_up, layers_begin, layers_end): 13 | super(FPN, self).__init__() 14 | assert layers_begin > 1 and layers_begin < 6 15 | assert layers_end > 4 and layers_begin < 8 16 | in_channels = [256, 512, 1024, 2048] 17 | fpn_dim = 256 18 | in_channels = in_channels[layers_begin-2:] 19 | 20 | lateral_convs = nn.ModuleList() 21 | output_convs = nn.ModuleList() 22 | for idx, in_channels in enumerate(in_channels): 23 | lateral_conv = nn.Conv2d(in_channels, fpn_dim, kernel_size=1) 24 | output_conv = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=1, padding=1) 25 | nn.init.kaiming_normal_(lateral_conv.weight, mode='fan_out') 26 | nn.init.constant_(lateral_conv.bias, 0) 27 | nn.init.kaiming_normal_(output_conv.weight, mode='fan_out') 28 | nn.init.constant_(output_conv.bias, 0) 29 | lateral_convs.append(lateral_conv) 30 | output_convs.append(output_conv) 31 | 32 | self.lateral_convs = lateral_convs[::-1] 33 | self.output_convs = output_convs[::-1] 34 | self.bottom_up = bottom_up 35 | self.output_b = layers_begin 36 | self.output_e = layers_end 37 | if self.output_e == 7: 38 | self.p6 = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=2, padding=1) 39 | self.p7 = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=2, padding=1) 40 | for l in [self.p6, self.p7]: 41 | nn.init.kaiming_uniform_(l.weight, a=1) # pyre-ignore 42 | nn.init.constant_(l.bias, 0) 43 | 44 | def forward(self, x): 45 | bottom_up_features = self.bottom_up(x) 46 | bottom_up_features = bottom_up_features[self.output_b - 2:] 47 | bottom_up_features = bottom_up_features[::-1] 48 | results = [] 49 | prev_features = self.lateral_convs[0](bottom_up_features[0]) 50 | results.append(self.output_convs[0](prev_features)) 51 | for l_id, (features, lateral_conv, output_conv) in enumerate(zip( 52 | bottom_up_features[1:], self.lateral_convs[1:], self.output_convs[1:])): 53 | top_down_features = F.interpolate(prev_features, scale_factor=2, mode="bilinear", align_corners=False) 54 | lateral_features = lateral_conv(features) 55 | prev_features = lateral_features + top_down_features 56 | results.append(output_conv(prev_features)) 57 | if(self.output_e == 6): 58 | p6 = F.max_pool2d(results[0], kernel_size=1, stride=2, padding=0) 59 | results.insert(0, p6) 60 | elif(self.output_e == 7): 61 | p6 = self.p6(results[0]) 62 | results.insert(0, p6) 63 | p7 = self.p7(F.relu(results[0])) 64 | results.insert(0, p7) 65 | return results 66 | 67 | -------------------------------------------------------------------------------- /lib/backbone/resnet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from layers.batch_norm import FrozenBatchNorm2d 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, in_cha, neck_cha, out_cha, stride, has_bias=False): 9 | super(Bottleneck, self).__init__() 10 | 11 | self.downsample = None 12 | if in_cha!= out_cha or stride != 1: 13 | self.downsample = nn.Sequential( 14 | nn.Conv2d(in_cha, out_cha, kernel_size=1, stride=stride, bias=has_bias), 15 | FrozenBatchNorm2d(out_cha), 16 | ) 17 | 18 | # The original MSRA ResNet models have stride in the first 1x1 conv 19 | # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations 20 | # have stride in the 3x3 conv 21 | self.conv1 = nn.Conv2d(in_cha, neck_cha, kernel_size=1, stride=1, bias=has_bias) 22 | self.bn1 = FrozenBatchNorm2d(neck_cha) 23 | 24 | self.conv2 = nn.Conv2d(neck_cha, neck_cha, kernel_size=3, stride=stride, 
padding=1, bias=has_bias) 25 | self.bn2 = FrozenBatchNorm2d(neck_cha) 26 | 27 | self.conv3 = nn.Conv2d(neck_cha, out_cha, kernel_size=1, bias=has_bias) 28 | self.bn3 = FrozenBatchNorm2d(out_cha) 29 | 30 | def forward(self, x): 31 | identity = x 32 | 33 | x = self.conv1(x) 34 | x = self.bn1(x) 35 | x = F.relu_(x) 36 | x = self.conv2(x) 37 | x = self.bn2(x) 38 | x = F.relu_(x) 39 | x = self.conv3(x) 40 | x = self.bn3(x) 41 | if self.downsample is not None: 42 | identity = self.downsample(identity) 43 | 44 | x += identity 45 | x = F.relu_(x) 46 | return x 47 | 48 | 49 | class ResNet50(nn.Module): 50 | def __init__(self, freeze_at, has_bias=False): 51 | super(ResNet50, self).__init__() 52 | self.has_bias = has_bias 53 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=has_bias) 54 | self.bn1 = FrozenBatchNorm2d(64) 55 | 56 | block_counts = [3, 4, 6, 3] 57 | bottleneck_channels_list = [64, 128, 256, 512] 58 | out_channels_list = [256, 512, 1024, 2048] 59 | stride_list = [1, 2, 2, 2] 60 | in_channels = 64 61 | 62 | self.layer1 = self._make_layer(block_counts[0], 64, 63 | bottleneck_channels_list[0], out_channels_list[0], stride_list[0]) 64 | self.layer2 = self._make_layer(block_counts[1], out_channels_list[0], 65 | bottleneck_channels_list[1], out_channels_list[1], stride_list[1]) 66 | self.layer3 = self._make_layer(block_counts[2], out_channels_list[1], 67 | bottleneck_channels_list[2], out_channels_list[2], stride_list[2]) 68 | self.layer4 = self._make_layer(block_counts[3], out_channels_list[2], 69 | bottleneck_channels_list[3], out_channels_list[3], stride_list[3]) 70 | 71 | for l in self.modules(): 72 | if isinstance(l, nn.Conv2d): 73 | nn.init.kaiming_normal_(l.weight, mode='fan_out') 74 | if self.has_bias: 75 | nn.init.constant_(l.bias, 0) 76 | 77 | self._freeze_backbone(freeze_at) 78 | 79 | def _make_layer(self, num_blocks, in_channels, bottleneck_channels, out_channels, stride): 80 | layers = [] 81 | for _ in range(num_blocks): 82 | layers.append(Bottleneck(in_channels, bottleneck_channels, out_channels, stride, self.has_bias)) 83 | stride = 1 84 | in_channels = out_channels 85 | return nn.Sequential(*layers) 86 | 87 | def _freeze_backbone(self, freeze_at): 88 | if freeze_at < 0: 89 | return 90 | if freeze_at >= 1: 91 | for p in self.conv1.parameters(): 92 | p.requires_grad = False 93 | if freeze_at >= 2: 94 | for p in self.layer1.parameters(): 95 | p.requires_grad = False 96 | if freeze_at >= 3: 97 | print("Freeze too much layers! 
Only freeze the first 2 layers.") 98 | 99 | def forward(self, x): 100 | outputs = [] 101 | # stem 102 | x = self.conv1(x) 103 | x = self.bn1(x) 104 | x = F.relu_(x) 105 | x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) 106 | # blocks 107 | x = self.layer1(x) 108 | outputs.append(x) 109 | x = self.layer2(x) 110 | outputs.append(x) 111 | x = self.layer3(x) 112 | outputs.append(x) 113 | x = self.layer4(x) 114 | outputs.append(x) 115 | return outputs 116 | 117 | -------------------------------------------------------------------------------- /lib/data/CrowdHuman.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | 6 | from utils import misc_utils 7 | 8 | class CrowdHuman(torch.utils.data.Dataset): 9 | def __init__(self, config, if_train): 10 | if if_train: 11 | self.training = True 12 | source = config.train_source 13 | self.short_size = config.train_image_short_size 14 | self.max_size = config.train_image_max_size 15 | else: 16 | self.training = False 17 | source = config.eval_source 18 | self.short_size = config.eval_image_short_size 19 | self.max_size = config.eval_image_max_size 20 | self.records = misc_utils.load_json_lines(source) 21 | self.config = config 22 | 23 | def __getitem__(self, index): 24 | return self.load_record(self.records[index]) 25 | 26 | def __len__(self): 27 | return len(self.records) 28 | 29 | def load_record(self, record): 30 | if self.training: 31 | if_flap = np.random.randint(2) == 1 32 | else: 33 | if_flap = False 34 | # image 35 | image_path = os.path.join(self.config.image_folder, record['ID']+'.png') 36 | image = misc_utils.load_img(image_path) 37 | image_h = image.shape[0] 38 | image_w = image.shape[1] 39 | if if_flap: 40 | image = cv2.flip(image, 1) 41 | if self.training: 42 | # ground_truth 43 | gtboxes = misc_utils.load_gt(record, 'gtboxes', 'fbox', self.config.class_names) 44 | keep = (gtboxes[:, 2]>=0) * (gtboxes[:, 3]>=0) 45 | gtboxes=gtboxes[keep, :] 46 | gtboxes[:, 2:4] += gtboxes[:, :2] 47 | if if_flap: 48 | gtboxes = flip_boxes(gtboxes, image_w) 49 | # im_info 50 | nr_gtboxes = gtboxes.shape[0] 51 | im_info = np.array([0, 0, 1, image_h, image_w, nr_gtboxes]) 52 | return image, gtboxes, im_info 53 | else: 54 | # image 55 | t_height, t_width, scale = target_size( 56 | image_h, image_w, self.short_size, self.max_size) 57 | # INTER_CUBIC, INTER_LINEAR, INTER_NEAREST, INTER_AREA, INTER_LANCZOS4 58 | resized_image = cv2.resize(image, (t_width, t_height), interpolation=cv2.INTER_LINEAR) 59 | resized_image = resized_image.transpose(2, 0, 1) 60 | image = torch.tensor(resized_image).float() 61 | gtboxes = misc_utils.load_gt(record, 'gtboxes', 'fbox', self.config.class_names) 62 | gtboxes[:, 2:4] += gtboxes[:, :2] 63 | gtboxes = torch.tensor(gtboxes) 64 | # im_info 65 | nr_gtboxes = gtboxes.shape[0] 66 | im_info = torch.tensor([t_height, t_width, scale, image_h, image_w, nr_gtboxes]) 67 | return image, gtboxes, im_info, record['ID'] 68 | 69 | def merge_batch(self, data): 70 | # image 71 | images = [it[0] for it in data] 72 | gt_boxes = [it[1] for it in data] 73 | im_info = np.array([it[2] for it in data]) 74 | batch_height = np.max(im_info[:, 3]) 75 | batch_width = np.max(im_info[:, 4]) 76 | padded_images = [pad_image( 77 | im, batch_height, batch_width, self.config.image_mean) for im in images] 78 | t_height, t_width, scale = target_size( 79 | batch_height, batch_width, self.short_size, self.max_size) 80 | # INTER_CUBIC, INTER_LINEAR, INTER_NEAREST, 
INTER_AREA, INTER_LANCZOS4 81 | resized_images = np.array([cv2.resize( 82 | im, (t_width, t_height), interpolation=cv2.INTER_LINEAR) for im in padded_images]) 83 | resized_images = resized_images.transpose(0, 3, 1, 2) 84 | images = torch.tensor(resized_images).float() 85 | # ground_truth 86 | ground_truth = [] 87 | for it in gt_boxes: 88 | gt_padded = np.zeros((self.config.max_boxes_of_image, self.config.nr_box_dim)) 89 | it[:, 0:4] *= scale 90 | max_box = min(self.config.max_boxes_of_image, len(it)) 91 | gt_padded[:max_box] = it[:max_box] 92 | ground_truth.append(gt_padded) 93 | ground_truth = torch.tensor(ground_truth).float() 94 | # im_info 95 | im_info[:, 0] = t_height 96 | im_info[:, 1] = t_width 97 | im_info[:, 2] = scale 98 | im_info = torch.tensor(im_info) 99 | if max(im_info[:, -1] < 2): 100 | return None, None, None 101 | else: 102 | return images, ground_truth, im_info 103 | 104 | def target_size(height, width, short_size, max_size): 105 | im_size_min = np.min([height, width]) 106 | im_size_max = np.max([height, width]) 107 | scale = (short_size + 0.0) / im_size_min 108 | if scale * im_size_max > max_size: 109 | scale = (max_size + 0.0) / im_size_max 110 | t_height, t_width = int(round(height * scale)), int( 111 | round(width * scale)) 112 | return t_height, t_width, scale 113 | 114 | def flip_boxes(boxes, im_w): 115 | flip_boxes = boxes.copy() 116 | for i in range(flip_boxes.shape[0]): 117 | flip_boxes[i, 0] = im_w - boxes[i, 2] - 1 118 | flip_boxes[i, 2] = im_w - boxes[i, 0] - 1 119 | return flip_boxes 120 | 121 | def pad_image(img, height, width, mean_value): 122 | o_h, o_w, _ = img.shape 123 | margins = np.zeros(2, np.int32) 124 | assert o_h <= height 125 | margins[0] = height - o_h 126 | img = cv2.copyMakeBorder( 127 | img, 0, margins[0], 0, 0, cv2.BORDER_CONSTANT, value=0) 128 | img[o_h:, :, :] = mean_value 129 | assert o_w <= width 130 | margins[1] = width - o_w 131 | img = cv2.copyMakeBorder( 132 | img, 0, 0, 0, margins[1], cv2.BORDER_CONSTANT, value=0) 133 | img[:, o_w:, :] = mean_value 134 | return img 135 | -------------------------------------------------------------------------------- /lib/det_oprs/anchors_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class AnchorGenerator(): 6 | """default anchor generator for fpn. 7 | This class generate anchors by feature map in level. 8 | """ 9 | def __init__(self, base_size=16, ratios=[0.5, 1, 2], 10 | base_scale=2): 11 | self.base_size = base_size 12 | self.base_scale = np.array(base_scale) 13 | self.anchor_ratios = np.array(ratios) 14 | 15 | def _whctrs(self, anchor): 16 | """convert anchor box into (w, h, ctr_x, ctr_y) 17 | """ 18 | w = anchor[:, 2] - anchor[:, 0] + 1 19 | h = anchor[:, 3] - anchor[:, 1] + 1 20 | x_ctr = anchor[:, 0] + 0.5 * (w - 1) 21 | y_ctr = anchor[:, 1] + 0.5 * (h - 1) 22 | return w, h, x_ctr, y_ctr 23 | 24 | def get_plane_anchors(self, anchor_scales: np.ndarray): 25 | """get anchors per location on feature map. 26 | The anchor number is anchor_scales x anchor_ratios 27 | """ 28 | base_anchor = np.array([[0, 0, self.base_size - 1, self.base_size - 1]]) 29 | off = self.base_size // 2 - 8 30 | w, h, x_ctr, y_ctr = self._whctrs(base_anchor) 31 | # ratio enumerate 32 | size = w * h 33 | size_ratios = size / self.anchor_ratios 34 | ws = np.round(np.sqrt(size_ratios)) 35 | hs = np.round(ws * self.anchor_ratios) 36 | # scale enumerate 37 | anchor_scales = anchor_scales[None, ...] 
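        # ws/hs so far hold one width/height per aspect ratio; broadcasting them against
        # anchor_scales yields one (w, h) per (ratio, scale) combination, which is then
        # flattened into column vectors below.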
38 | ws = (ws[:, None] * anchor_scales).reshape(-1, 1) 39 | hs = (hs[:, None] * anchor_scales).reshape(-1, 1) 40 | # make anchors 41 | anchors = np.hstack((x_ctr - 0.5 * (ws - 1), 42 | y_ctr - 0.5 * (hs - 1), 43 | x_ctr + 0.5 * (ws - 1), 44 | y_ctr + 0.5 * (hs - 1))) - off 45 | return anchors.astype(np.float32) 46 | 47 | def get_center_offsets(self, fm_map, stride): 48 | fm_height, fm_width = fm_map.shape[-2], fm_map.shape[-1] 49 | f_device = fm_map.device 50 | shift_x = torch.arange(0, fm_width, device=f_device) * stride 51 | shift_y = torch.arange(0, fm_height, device=f_device) * stride 52 | broad_shift_x = shift_x.reshape(-1, shift_x.shape[0]).repeat(fm_height,1) 53 | broad_shift_y = shift_y.reshape(shift_y.shape[0], -1).repeat(1,fm_width) 54 | flatten_shift_x = broad_shift_x.flatten().reshape(-1,1) 55 | flatten_shift_y = broad_shift_y.flatten().reshape(-1,1) 56 | shifts = torch.cat( 57 | [flatten_shift_x, flatten_shift_y, flatten_shift_x, flatten_shift_y], 58 | axis=1) 59 | return shifts 60 | 61 | def get_anchors_by_feature(self, fm_map, base_stride, off_stride): 62 | # shifts shape: [A, 4] 63 | shifts = self.get_center_offsets(fm_map, base_stride * off_stride) 64 | # plane_anchors shape: [B, 4], e.g. B=3 65 | plane_anchors = self.get_plane_anchors(self.base_scale * off_stride) 66 | plane_anchors = torch.tensor(plane_anchors, device=fm_map.device) 67 | # all_anchors shape: [A, B, 4] 68 | all_anchors = plane_anchors[None, :] + shifts[:, None] 69 | all_anchors = all_anchors.reshape(-1, 4) 70 | return all_anchors 71 | 72 | @torch.no_grad() 73 | def __call__(self, featmap, base_stride, off_stride): 74 | return self.get_anchors_by_feature(featmap, base_stride, off_stride) 75 | 76 | -------------------------------------------------------------------------------- /lib/det_oprs/bbox_opr.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | def filter_boxes_opr(boxes, min_size): 5 | """Remove all boxes with any side smaller than min_size.""" 6 | ws = boxes[:, 2] - boxes[:, 0] + 1 7 | hs = boxes[:, 3] - boxes[:, 1] + 1 8 | keep = (ws >= min_size) * (hs >= min_size) 9 | return keep 10 | 11 | def clip_boxes_opr(boxes, im_info): 12 | """ Clip the boxes into the image region.""" 13 | w = im_info[1] - 1 14 | h = im_info[0] - 1 15 | boxes[:, 0::4] = boxes[:, 0::4].clamp(min=0, max=w) 16 | boxes[:, 1::4] = boxes[:, 1::4].clamp(min=0, max=h) 17 | boxes[:, 2::4] = boxes[:, 2::4].clamp(min=0, max=w) 18 | boxes[:, 3::4] = boxes[:, 3::4].clamp(min=0, max=h) 19 | return boxes 20 | 21 | def batch_clip_proposals(proposals, im_info): 22 | """ Clip the boxes into the image region.""" 23 | w = im_info[1] - 1 24 | h = im_info[0] - 1 25 | boxes[:, 0::4] = boxes[:, 0::4].clamp(min=0, max=w) 26 | boxes[:, 1::4] = boxes[:, 1::4].clamp(min=0, max=h) 27 | boxes[:, 2::4] = boxes[:, 2::4].clamp(min=0, max=w) 28 | boxes[:, 3::4] = boxes[:, 3::4].clamp(min=0, max=h) 29 | return boxes 30 | 31 | def bbox_transform_inv_opr(bbox, deltas): 32 | max_delta = math.log(1000.0 / 16) 33 | """ Transforms the learned deltas to the final bbox coordinates, the axis is 1""" 34 | bbox_width = bbox[:, 2] - bbox[:, 0] + 1 35 | bbox_height = bbox[:, 3] - bbox[:, 1] + 1 36 | bbox_ctr_x = bbox[:, 0] + 0.5 * bbox_width 37 | bbox_ctr_y = bbox[:, 1] + 0.5 * bbox_height 38 | pred_ctr_x = bbox_ctr_x + deltas[:, 0] * bbox_width 39 | pred_ctr_y = bbox_ctr_y + deltas[:, 1] * bbox_height 40 | 41 | dw = deltas[:, 2] 42 | dh = deltas[:, 3] 43 | dw = torch.clamp(dw, max=max_delta) 44 | 
dh = torch.clamp(dh, max=max_delta) 45 | pred_width = bbox_width * torch.exp(dw) 46 | pred_height = bbox_height * torch.exp(dh) 47 | 48 | pred_x1 = pred_ctr_x - 0.5 * pred_width 49 | pred_y1 = pred_ctr_y - 0.5 * pred_height 50 | pred_x2 = pred_ctr_x + 0.5 * pred_width 51 | pred_y2 = pred_ctr_y + 0.5 * pred_height 52 | pred_boxes = torch.cat((pred_x1.reshape(-1, 1), pred_y1.reshape(-1, 1), 53 | pred_x2.reshape(-1, 1), pred_y2.reshape(-1, 1)), dim=1) 54 | return pred_boxes 55 | 56 | def bbox_transform_opr(bbox, gt): 57 | """ Transform the bounding box and ground truth to the loss targets. 58 | The 4 box coordinates are in axis 1""" 59 | bbox_width = bbox[:, 2] - bbox[:, 0] + 1 60 | bbox_height = bbox[:, 3] - bbox[:, 1] + 1 61 | bbox_ctr_x = bbox[:, 0] + 0.5 * bbox_width 62 | bbox_ctr_y = bbox[:, 1] + 0.5 * bbox_height 63 | 64 | gt_width = gt[:, 2] - gt[:, 0] + 1 65 | gt_height = gt[:, 3] - gt[:, 1] + 1 66 | gt_ctr_x = gt[:, 0] + 0.5 * gt_width 67 | gt_ctr_y = gt[:, 1] + 0.5 * gt_height 68 | 69 | target_dx = (gt_ctr_x - bbox_ctr_x) / bbox_width 70 | target_dy = (gt_ctr_y - bbox_ctr_y) / bbox_height 71 | target_dw = torch.log(gt_width / bbox_width) 72 | target_dh = torch.log(gt_height / bbox_height) 73 | target = torch.cat((target_dx.reshape(-1, 1), target_dy.reshape(-1, 1), 74 | target_dw.reshape(-1, 1), target_dh.reshape(-1, 1)), dim=1) 75 | return target 76 | 77 | def box_overlap_opr(box, gt): 78 | assert box.ndim == 2 79 | assert gt.ndim == 2 80 | area_box = (box[:, 2] - box[:, 0] + 1) * (box[:, 3] - box[:, 1] + 1) 81 | area_gt = (gt[:, 2] - gt[:, 0] + 1) * (gt[:, 3] - gt[:, 1] + 1) 82 | width_height = torch.min(box[:, None, 2:], gt[:, 2:]) - torch.max( 83 | box[:, None, :2], gt[:, :2]) + 1 # [N,M,2] 84 | width_height.clamp_(min=0) # [N,M,2] 85 | inter = width_height.prod(dim=2) # [N,M] 86 | del width_height 87 | # handle empty boxes 88 | iou = torch.where( 89 | inter > 0, 90 | inter / (area_box[:, None] + area_gt - inter), 91 | torch.zeros(1, dtype=inter.dtype, device=inter.device), 92 | ) 93 | return iou 94 | 95 | def box_overlap_ignore_opr(box, gt, ignore_label=-1): 96 | assert box.ndim == 2 97 | assert gt.ndim == 2 98 | assert gt.shape[-1] > 4 99 | area_box = (box[:, 2] - box[:, 0] + 1) * (box[:, 3] - box[:, 1] + 1) 100 | area_gt = (gt[:, 2] - gt[:, 0] + 1) * (gt[:, 3] - gt[:, 1] + 1) 101 | width_height = torch.min(box[:, None, 2:], gt[:, 2:4]) - torch.max( 102 | box[:, None, :2], gt[:, :2]) # [N,M,2] 103 | width_height.clamp_(min=0) # [N,M,2] 104 | inter = width_height.prod(dim=2) # [N,M] 105 | del width_height 106 | # handle empty boxes 107 | iou = torch.where( 108 | inter > 0, 109 | inter / (area_box[:, None] + area_gt - inter), 110 | torch.zeros(1, dtype=inter.dtype, device=inter.device)) 111 | ioa = torch.where( 112 | inter > 0, 113 | inter / (area_box[:, None]), 114 | torch.zeros(1, dtype=inter.dtype, device=inter.device)) 115 | gt_ignore_mask = gt[:, 4].eq(ignore_label).repeat(box.shape[0], 1) 116 | iou *= ~gt_ignore_mask 117 | ioa *= gt_ignore_mask 118 | return iou, ioa 119 | 120 | -------------------------------------------------------------------------------- /lib/det_oprs/cascade_roi_target.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | import numpy as np 5 | from config import config 6 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr, box_overlap_ignore_opr 7 | 8 | @torch.no_grad() 9 | def cascade_roi_target(rpn_rois, im_info, gt_boxes, pos_threshold=0.5, top_k=1): 
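    # Appends the ground-truth boxes to the RPN proposals, matches every roi to its
    # top_k best-overlapping (non-ignored) ground-truth boxes, and returns the rois,
    # per-roi class labels and (optionally normalized) box regression targets;
    # top_k > 1 produces multiple targets per proposal for the multi-instance (EMD) heads.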
10 | return_rois = [] 11 | return_labels = [] 12 | return_bbox_targets = [] 13 | # get per image proposals and gt_boxes 14 | for bid in range(config.train_batch_per_gpu): 15 | gt_boxes_perimg = gt_boxes[bid, :int(im_info[bid, 5]), :] 16 | batch_inds = torch.ones((gt_boxes_perimg.shape[0], 1)).type_as(gt_boxes_perimg) * bid 17 | gt_rois = torch.cat([batch_inds, gt_boxes_perimg[:, :4]], axis=1) 18 | batch_roi_inds = torch.nonzero(rpn_rois[:, 0] == bid, as_tuple=False).flatten() 19 | all_rois = torch.cat([rpn_rois[batch_roi_inds], gt_rois], axis=0) 20 | overlaps_normal, overlaps_ignore = box_overlap_ignore_opr( 21 | all_rois[:, 1:5], gt_boxes_perimg) 22 | overlaps_normal, overlaps_normal_indices = overlaps_normal.sort(descending=True, dim=1) 23 | overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort(descending=True, dim=1) 24 | # gt max and indices, ignore max and indices 25 | max_overlaps_normal = overlaps_normal[:, :top_k].flatten() 26 | gt_assignment_normal = overlaps_normal_indices[:, :top_k].flatten() 27 | max_overlaps_ignore = overlaps_ignore[:, :top_k].flatten() 28 | gt_assignment_ignore = overlaps_ignore_indices[:, :top_k].flatten() 29 | # cons masks 30 | ignore_assign_mask = (max_overlaps_normal < pos_threshold) * ( 31 | max_overlaps_ignore > max_overlaps_normal) 32 | max_overlaps = max_overlaps_normal * ~ignore_assign_mask + \ 33 | max_overlaps_ignore * ignore_assign_mask 34 | gt_assignment = gt_assignment_normal * ~ignore_assign_mask + \ 35 | gt_assignment_ignore * ignore_assign_mask 36 | labels = gt_boxes_perimg[gt_assignment, 4] 37 | fg_mask = (max_overlaps >= pos_threshold) * (labels != config.ignore_label) 38 | bg_mask = (max_overlaps < config.bg_threshold_high) 39 | fg_mask = fg_mask.reshape(-1, top_k) 40 | bg_mask = bg_mask.reshape(-1, top_k) 41 | #pos_max = config.num_rois * config.fg_ratio 42 | #fg_inds_mask = subsample_masks(fg_mask[:, 0], pos_max, True) 43 | #neg_max = config.num_rois - fg_inds_mask.sum() 44 | #bg_inds_mask = subsample_masks(bg_mask[:, 0], neg_max, True) 45 | labels = labels * fg_mask.flatten() 46 | #keep_mask = fg_inds_mask + bg_inds_mask 47 | # labels 48 | labels = labels.reshape(-1, top_k)#[keep_mask] 49 | gt_assignment = gt_assignment.reshape(-1, top_k).flatten() 50 | target_boxes = gt_boxes_perimg[gt_assignment, :4] 51 | #rois = all_rois[keep_mask] 52 | target_rois = all_rois.repeat(1, top_k).reshape(-1, all_rois.shape[-1]) 53 | bbox_targets = bbox_transform_opr(target_rois[:, 1:5], target_boxes) 54 | if config.rcnn_bbox_normalize_targets: 55 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(bbox_targets) 56 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(bbox_targets) 57 | minus_opr = mean_opr / std_opr 58 | bbox_targets = bbox_targets / std_opr - minus_opr 59 | bbox_targets = bbox_targets.reshape(-1, top_k * 4) 60 | return_rois.append(all_rois) 61 | return_labels.append(labels) 62 | return_bbox_targets.append(bbox_targets) 63 | if config.train_batch_per_gpu == 1: 64 | return rois, labels, bbox_targets 65 | else: 66 | return_rois = torch.cat(return_rois, axis=0) 67 | return_labels = torch.cat(return_labels, axis=0) 68 | return_bbox_targets = torch.cat(return_bbox_targets, axis=0) 69 | return return_rois, return_labels, return_bbox_targets 70 | 71 | -------------------------------------------------------------------------------- /lib/det_oprs/find_top_rpn_proposals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 
as F 3 | 4 | from config import config 5 | from det_oprs.bbox_opr import bbox_transform_inv_opr, clip_boxes_opr, \ 6 | filter_boxes_opr 7 | from torchvision.ops import nms 8 | 9 | @torch.no_grad() 10 | def find_top_rpn_proposals(is_train, rpn_bbox_offsets_list, rpn_cls_prob_list, 11 | all_anchors_list, im_info): 12 | prev_nms_top_n = config.train_prev_nms_top_n \ 13 | if is_train else config.test_prev_nms_top_n 14 | post_nms_top_n = config.train_post_nms_top_n \ 15 | if is_train else config.test_post_nms_top_n 16 | batch_per_gpu = config.train_batch_per_gpu if is_train else 1 17 | nms_threshold = config.rpn_nms_threshold 18 | box_min_size = config.rpn_min_box_size 19 | bbox_normalize_targets = config.rpn_bbox_normalize_targets 20 | bbox_normalize_means = config.bbox_normalize_means 21 | bbox_normalize_stds = config.bbox_normalize_stds 22 | list_size = len(rpn_bbox_offsets_list) 23 | 24 | return_rois = [] 25 | return_inds = [] 26 | for bid in range(batch_per_gpu): 27 | batch_proposals_list = [] 28 | batch_probs_list = [] 29 | for l in range(list_size): 30 | # get proposals and probs 31 | offsets = rpn_bbox_offsets_list[l][bid] \ 32 | .permute(1, 2, 0).reshape(-1, 4) 33 | if bbox_normalize_targets: 34 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(bbox_targets) 35 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(bbox_targets) 36 | pred_offsets = pred_offsets * std_opr 37 | pred_offsets = pred_offsets + mean_opr 38 | all_anchors = all_anchors_list[l] 39 | proposals = bbox_transform_inv_opr(all_anchors, offsets) 40 | if config.anchor_within_border: 41 | proposals = clip_boxes_opr(proposals, im_info[bid, :]) 42 | probs = rpn_cls_prob_list[l][bid] \ 43 | .permute(1,2,0).reshape(-1, 2) 44 | probs = torch.softmax(probs, dim=-1)[:, 1] 45 | # gather the proposals and probs 46 | batch_proposals_list.append(proposals) 47 | batch_probs_list.append(probs) 48 | batch_proposals = torch.cat(batch_proposals_list, dim=0) 49 | batch_probs = torch.cat(batch_probs_list, dim=0) 50 | # filter the zero boxes. 51 | batch_keep_mask = filter_boxes_opr( 52 | batch_proposals, box_min_size * im_info[bid, 2]) 53 | batch_proposals = batch_proposals[batch_keep_mask] 54 | batch_probs = batch_probs[batch_keep_mask] 55 | # prev_nms_top_n 56 | num_proposals = min(prev_nms_top_n, batch_probs.shape[0]) 57 | batch_probs, idx = batch_probs.sort(descending=True) 58 | batch_probs = batch_probs[:num_proposals] 59 | topk_idx = idx[:num_proposals].flatten() 60 | batch_proposals = batch_proposals[topk_idx] 61 | # For each image, run a total-level NMS, and choose topk results. 
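        # torchvision.ops.nms returns the indices of the kept boxes sorted by descending
        # score, so slicing the first post_nms_top_n entries keeps the best survivors.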
62 | keep = nms(batch_proposals, batch_probs, nms_threshold) 63 | keep = keep[:post_nms_top_n] 64 | batch_proposals = batch_proposals[keep] 65 | #batch_probs = batch_probs[keep] 66 | # cons the rois 67 | batch_inds = torch.ones(batch_proposals.shape[0], 1).type_as(batch_proposals) * bid 68 | batch_rois = torch.cat([batch_inds, batch_proposals], axis=1) 69 | return_rois.append(batch_rois) 70 | 71 | if batch_per_gpu == 1: 72 | return batch_rois 73 | else: 74 | concated_rois = torch.cat(return_rois, axis=0) 75 | return concated_rois 76 | -------------------------------------------------------------------------------- /lib/det_oprs/fpn_anchor_target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr 5 | from config import config 6 | 7 | def fpn_rpn_reshape(pred_cls_score_list, pred_bbox_offsets_list): 8 | final_pred_bbox_offsets_list = [] 9 | final_pred_cls_score_list = [] 10 | for bid in range(config.train_batch_per_gpu): 11 | batch_pred_bbox_offsets_list = [] 12 | batch_pred_cls_score_list = [] 13 | for i in range(len(pred_cls_score_list)): 14 | pred_cls_score_perlvl = pred_cls_score_list[i][bid] \ 15 | .permute(1, 2, 0).reshape(-1, 2) 16 | pred_bbox_offsets_perlvl = pred_bbox_offsets_list[i][bid] \ 17 | .permute(1, 2, 0).reshape(-1, 4) 18 | batch_pred_cls_score_list.append(pred_cls_score_perlvl) 19 | batch_pred_bbox_offsets_list.append(pred_bbox_offsets_perlvl) 20 | batch_pred_cls_score = torch.cat(batch_pred_cls_score_list, dim=0) 21 | batch_pred_bbox_offsets = torch.cat(batch_pred_bbox_offsets_list, dim=0) 22 | final_pred_cls_score_list.append(batch_pred_cls_score) 23 | final_pred_bbox_offsets_list.append(batch_pred_bbox_offsets) 24 | final_pred_cls_score = torch.cat(final_pred_cls_score_list, dim=0) 25 | final_pred_bbox_offsets = torch.cat(final_pred_bbox_offsets_list, dim=0) 26 | return final_pred_cls_score, final_pred_bbox_offsets 27 | 28 | def fpn_anchor_target_opr_core_impl( 29 | gt_boxes, im_info, anchors, allow_low_quality_matches=True): 30 | ignore_label = config.ignore_label 31 | # get the gt boxes 32 | valid_gt_boxes = gt_boxes[:int(im_info[5]), :] 33 | valid_gt_boxes = valid_gt_boxes[valid_gt_boxes[:, -1].gt(0)] 34 | # compute the iou matrix 35 | anchors = anchors.type_as(valid_gt_boxes) 36 | overlaps = box_overlap_opr(anchors, valid_gt_boxes[:, :4]) 37 | # match the dtboxes 38 | max_overlaps, argmax_overlaps = torch.max(overlaps, axis=1) 39 | #_, gt_argmax_overlaps = torch.max(overlaps, axis=0) 40 | gt_argmax_overlaps = my_gt_argmax(overlaps) 41 | del overlaps 42 | # all ignore 43 | labels = torch.ones(anchors.shape[0], device=gt_boxes.device, dtype=torch.long) * ignore_label 44 | # set negative ones 45 | labels = labels * (max_overlaps >= config.rpn_negative_overlap) 46 | # set positive ones 47 | fg_mask = (max_overlaps >= config.rpn_positive_overlap) 48 | if allow_low_quality_matches: 49 | gt_id = torch.arange(valid_gt_boxes.shape[0]).type_as(argmax_overlaps) 50 | argmax_overlaps[gt_argmax_overlaps] = gt_id 51 | max_overlaps[gt_argmax_overlaps] = 1 52 | fg_mask = (max_overlaps >= config.rpn_positive_overlap) 53 | # set positive ones 54 | fg_mask_ind = torch.nonzero(fg_mask, as_tuple=False).flatten() 55 | labels[fg_mask_ind] = 1 56 | # bbox targets 57 | bbox_targets = bbox_transform_opr( 58 | anchors, valid_gt_boxes[argmax_overlaps, :4]) 59 | if config.rpn_bbox_normalize_targets: 60 | std_opr = 
torch.tensor(config.bbox_normalize_stds[None, :]).type_as(bbox_targets) 61 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(bbox_targets) 62 | minus_opr = mean_opr / std_opr 63 | bbox_targets = bbox_targets / std_opr - minus_opr 64 | return labels, bbox_targets 65 | 66 | @torch.no_grad() 67 | def fpn_anchor_target(boxes, im_info, all_anchors_list): 68 | final_labels_list = [] 69 | final_bbox_targets_list = [] 70 | for bid in range(config.train_batch_per_gpu): 71 | batch_labels_list = [] 72 | batch_bbox_targets_list = [] 73 | for i in range(len(all_anchors_list)): 74 | anchors_perlvl = all_anchors_list[i] 75 | rpn_labels_perlvl, rpn_bbox_targets_perlvl = fpn_anchor_target_opr_core_impl( 76 | boxes[bid], im_info[bid], anchors_perlvl) 77 | batch_labels_list.append(rpn_labels_perlvl) 78 | batch_bbox_targets_list.append(rpn_bbox_targets_perlvl) 79 | # here we samples the rpn_labels 80 | concated_batch_labels = torch.cat(batch_labels_list, dim=0) 81 | concated_batch_bbox_targets = torch.cat(batch_bbox_targets_list, dim=0) 82 | # sample labels 83 | pos_idx, neg_idx = subsample_labels(concated_batch_labels, 84 | config.num_sample_anchors, config.positive_anchor_ratio) 85 | concated_batch_labels.fill_(-1) 86 | concated_batch_labels[pos_idx] = 1 87 | concated_batch_labels[neg_idx] = 0 88 | 89 | final_labels_list.append(concated_batch_labels) 90 | final_bbox_targets_list.append(concated_batch_bbox_targets) 91 | final_labels = torch.cat(final_labels_list, dim=0) 92 | final_bbox_targets = torch.cat(final_bbox_targets_list, dim=0) 93 | return final_labels, final_bbox_targets 94 | 95 | def my_gt_argmax(overlaps): 96 | gt_max_overlaps, _ = torch.max(overlaps, axis=0) 97 | gt_max_mask = overlaps == gt_max_overlaps 98 | gt_argmax_overlaps = [] 99 | for i in range(overlaps.shape[-1]): 100 | gt_max_inds = torch.nonzero(gt_max_mask[:, i], as_tuple=False).flatten() 101 | gt_max_ind = gt_max_inds[torch.randperm(gt_max_inds.numel(), device=gt_max_inds.device)[0,None]] 102 | gt_argmax_overlaps.append(gt_max_ind) 103 | gt_argmax_overlaps = torch.cat(gt_argmax_overlaps) 104 | return gt_argmax_overlaps 105 | 106 | def subsample_labels(labels, num_samples, positive_fraction): 107 | positive = torch.nonzero((labels != config.ignore_label) & (labels != 0), as_tuple=False).squeeze(1) 108 | negative = torch.nonzero(labels == 0, as_tuple=False).squeeze(1) 109 | 110 | num_pos = int(num_samples * positive_fraction) 111 | num_pos = min(positive.numel(), num_pos) 112 | num_neg = num_samples - num_pos 113 | num_neg = min(negative.numel(), num_neg) 114 | 115 | # randomly select positive and negative examples 116 | perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] 117 | perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] 118 | 119 | pos_idx = positive[perm1] 120 | neg_idx = negative[perm2] 121 | return pos_idx, neg_idx 122 | -------------------------------------------------------------------------------- /lib/det_oprs/fpn_roi_target.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | import numpy as np 5 | from config import config 6 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr, box_overlap_ignore_opr 7 | 8 | @torch.no_grad() 9 | def fpn_roi_target(rpn_rois, im_info, gt_boxes, top_k=1): 10 | return_rois = [] 11 | return_labels = [] 12 | return_bbox_targets = [] 13 | # get per image proposals and gt_boxes 14 | for bid in 
range(config.train_batch_per_gpu): 15 | gt_boxes_perimg = gt_boxes[bid, :int(im_info[bid, 5]), :] 16 | batch_inds = torch.ones((gt_boxes_perimg.shape[0], 1)).type_as(gt_boxes_perimg) * bid 17 | gt_rois = torch.cat([batch_inds, gt_boxes_perimg[:, :4]], axis=1) 18 | batch_roi_inds = torch.nonzero(rpn_rois[:, 0] == bid, as_tuple=False).flatten() 19 | all_rois = torch.cat([rpn_rois[batch_roi_inds], gt_rois], axis=0) 20 | overlaps_normal, overlaps_ignore = box_overlap_ignore_opr( 21 | all_rois[:, 1:5], gt_boxes_perimg) 22 | overlaps_normal, overlaps_normal_indices = overlaps_normal.sort(descending=True, dim=1) 23 | overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort(descending=True, dim=1) 24 | # gt max and indices, ignore max and indices 25 | max_overlaps_normal = overlaps_normal[:, :top_k].flatten() 26 | gt_assignment_normal = overlaps_normal_indices[:, :top_k].flatten() 27 | max_overlaps_ignore = overlaps_ignore[:, :top_k].flatten() 28 | gt_assignment_ignore = overlaps_ignore_indices[:, :top_k].flatten() 29 | # cons masks 30 | ignore_assign_mask = (max_overlaps_normal < config.fg_threshold) * ( 31 | max_overlaps_ignore > max_overlaps_normal) 32 | max_overlaps = max_overlaps_normal * ~ignore_assign_mask + \ 33 | max_overlaps_ignore * ignore_assign_mask 34 | gt_assignment = gt_assignment_normal * ~ignore_assign_mask + \ 35 | gt_assignment_ignore * ignore_assign_mask 36 | labels = gt_boxes_perimg[gt_assignment, 4] 37 | fg_mask = (max_overlaps >= config.fg_threshold) * (labels != config.ignore_label) 38 | bg_mask = (max_overlaps < config.bg_threshold_high) * ( 39 | max_overlaps >= config.bg_threshold_low) 40 | fg_mask = fg_mask.reshape(-1, top_k) 41 | bg_mask = bg_mask.reshape(-1, top_k) 42 | pos_max = config.num_rois * config.fg_ratio 43 | fg_inds_mask = subsample_masks(fg_mask[:, 0], pos_max, True) 44 | neg_max = config.num_rois - fg_inds_mask.sum() 45 | bg_inds_mask = subsample_masks(bg_mask[:, 0], neg_max, True) 46 | labels = labels * fg_mask.flatten() 47 | keep_mask = fg_inds_mask + bg_inds_mask 48 | # labels 49 | labels = labels.reshape(-1, top_k)[keep_mask] 50 | gt_assignment = gt_assignment.reshape(-1, top_k)[keep_mask].flatten() 51 | target_boxes = gt_boxes_perimg[gt_assignment, :4] 52 | rois = all_rois[keep_mask] 53 | target_rois = rois.repeat(1, top_k).reshape(-1, all_rois.shape[-1]) 54 | bbox_targets = bbox_transform_opr(target_rois[:, 1:5], target_boxes) 55 | if config.rcnn_bbox_normalize_targets: 56 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(bbox_targets) 57 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(bbox_targets) 58 | minus_opr = mean_opr / std_opr 59 | bbox_targets = bbox_targets / std_opr - minus_opr 60 | bbox_targets = bbox_targets.reshape(-1, top_k * 4) 61 | return_rois.append(rois) 62 | return_labels.append(labels) 63 | return_bbox_targets.append(bbox_targets) 64 | if config.train_batch_per_gpu == 1: 65 | return rois, labels, bbox_targets 66 | else: 67 | return_rois = torch.cat(return_rois, axis=0) 68 | return_labels = torch.cat(return_labels, axis=0) 69 | return_bbox_targets = torch.cat(return_bbox_targets, axis=0) 70 | return return_rois, return_labels, return_bbox_targets 71 | 72 | def subsample_masks(masks, num_samples, sample_value): 73 | positive = torch.nonzero(masks.eq(sample_value), as_tuple=False).squeeze(1) 74 | num_mask = len(positive) 75 | num_samples = int(num_samples) 76 | num_final_samples = min(num_mask, num_samples) 77 | num_final_negative = num_mask - num_final_samples 78 | perm = 
torch.randperm(num_mask, device=masks.device)[:num_final_negative] 79 | negative = positive[perm] 80 | masks[negative] = not sample_value 81 | return masks 82 | 83 | -------------------------------------------------------------------------------- /lib/det_oprs/loss_opr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from config import config 4 | 5 | def softmax_loss(score, label, ignore_label=-1): 6 | with torch.no_grad(): 7 | max_score = score.max(axis=1, keepdims=True)[0] 8 | score -= max_score 9 | log_prob = score - torch.log(torch.exp(score).sum(axis=1, keepdims=True)) 10 | mask = label != ignore_label 11 | vlabel = label * mask 12 | onehot = torch.zeros(vlabel.shape[0], config.num_classes, device=score.device) 13 | onehot.scatter_(1, vlabel.reshape(-1, 1), 1) 14 | loss = -(log_prob * onehot).sum(axis=1) 15 | loss = loss * mask 16 | return loss 17 | 18 | def smooth_l1_loss(pred, target, beta: float): 19 | if beta < 1e-5: 20 | loss = torch.abs(input - target) 21 | else: 22 | abs_x = torch.abs(pred- target) 23 | in_mask = abs_x < beta 24 | loss = torch.where(in_mask, 0.5 * abs_x ** 2 / beta, abs_x - 0.5 * beta) 25 | return loss.sum(axis=1) 26 | 27 | def focal_loss(inputs, targets, alpha=-1, gamma=2): 28 | class_range = torch.arange(1, inputs.shape[1] + 1, device=inputs.device) 29 | pos_pred = (1 - inputs) ** gamma * torch.log(inputs) 30 | neg_pred = inputs ** gamma * torch.log(1 - inputs) 31 | 32 | pos_loss = (targets == class_range) * pos_pred * alpha 33 | neg_loss = (targets != class_range) * neg_pred * (1 - alpha) 34 | loss = -(pos_loss + neg_loss) 35 | return loss.sum(axis=1) 36 | 37 | def emd_loss_softmax(p_b0, p_s0, p_b1, p_s1, targets, labels): 38 | # reshape 39 | pred_delta = torch.cat([p_b0, p_b1], axis=1).reshape(-1, p_b0.shape[-1]) 40 | pred_score = torch.cat([p_s0, p_s1], axis=1).reshape(-1, p_s0.shape[-1]) 41 | targets = targets.reshape(-1, 4) 42 | labels = labels.long().flatten() 43 | # cons masks 44 | valid_masks = labels >= 0 45 | fg_masks = labels > 0 46 | # multiple class 47 | pred_delta = pred_delta.reshape(-1, config.num_classes, 4) 48 | fg_gt_classes = labels[fg_masks] 49 | pred_delta = pred_delta[fg_masks, fg_gt_classes, :] 50 | # loss for regression 51 | localization_loss = smooth_l1_loss( 52 | pred_delta, 53 | targets[fg_masks], 54 | config.rcnn_smooth_l1_beta) 55 | # loss for classification 56 | objectness_loss = softmax_loss(pred_score, labels) 57 | loss = objectness_loss * valid_masks 58 | loss[fg_masks] = loss[fg_masks] + localization_loss 59 | loss = loss.reshape(-1, 2).sum(axis=1) 60 | return loss.reshape(-1, 1) 61 | 62 | def emd_loss_focal(p_b0, p_s0, p_b1, p_s1, targets, labels): 63 | pred_delta = torch.cat([p_b0, p_b1], axis=1).reshape(-1, p_b0.shape[-1]) 64 | pred_score = torch.cat([p_s0, p_s1], axis=1).reshape(-1, p_s0.shape[-1]) 65 | targets = targets.reshape(-1, 4) 66 | labels = labels.long().reshape(-1, 1) 67 | valid_mask = (labels >= 0).flatten() 68 | objectness_loss = focal_loss(pred_score, labels, 69 | config.focal_loss_alpha, config.focal_loss_gamma) 70 | fg_masks = (labels > 0).flatten() 71 | localization_loss = smooth_l1_loss( 72 | pred_delta[fg_masks], 73 | targets[fg_masks], 74 | config.smooth_l1_beta) 75 | loss = objectness_loss * valid_mask 76 | loss[fg_masks] = loss[fg_masks] + localization_loss 77 | loss = loss.reshape(-1, 2).sum(axis=1) 78 | return loss.reshape(-1, 1) 79 | -------------------------------------------------------------------------------- 
/lib/det_oprs/retina_anchor_target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | from config import config 5 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr 6 | 7 | @torch.no_grad() 8 | def retina_anchor_target(anchors, gt_boxes, im_info, top_k=1): 9 | total_anchor = anchors.shape[0] 10 | return_labels = [] 11 | return_bbox_targets = [] 12 | # get per image proposals and gt_boxes 13 | for bid in range(config.train_batch_per_gpu): 14 | gt_boxes_perimg = gt_boxes[bid, :int(im_info[bid, 5]), :] 15 | anchors = anchors.type_as(gt_boxes_perimg) 16 | overlaps = box_overlap_opr(anchors, gt_boxes_perimg[:, :-1]) 17 | # gt max and indices 18 | max_overlaps, gt_assignment = overlaps.topk(top_k, dim=1, sorted=True) 19 | max_overlaps= max_overlaps.flatten() 20 | gt_assignment= gt_assignment.flatten() 21 | _, gt_assignment_for_gt = torch.max(overlaps, axis=0) 22 | del overlaps 23 | # cons labels 24 | labels = gt_boxes_perimg[gt_assignment, 4] 25 | labels = labels * (max_overlaps >= config.negative_thresh) 26 | ignore_mask = (max_overlaps < config.positive_thresh) * ( 27 | max_overlaps >= config.negative_thresh) 28 | labels[ignore_mask] = -1 29 | # cons bbox targets 30 | target_boxes = gt_boxes_perimg[gt_assignment, :4] 31 | target_anchors = anchors.repeat(1, top_k).reshape(-1, anchors.shape[-1]) 32 | bbox_targets = bbox_transform_opr(target_anchors, target_boxes) 33 | if config.allow_low_quality: 34 | labels[gt_assignment_for_gt] = gt_boxes_perimg[:, 4] 35 | low_quality_bbox_targets = bbox_transform_opr( 36 | anchors[gt_assignment_for_gt], gt_boxes_perimg[:, :4]) 37 | bbox_targets[gt_assignment_for_gt] = low_quality_bbox_targets 38 | labels = labels.reshape(-1, 1 * top_k) 39 | bbox_targets = bbox_targets.reshape(-1, 4 * top_k) 40 | return_labels.append(labels) 41 | return_bbox_targets.append(bbox_targets) 42 | 43 | if config.train_batch_per_gpu == 1: 44 | return labels, bbox_targets 45 | else: 46 | return_labels = torch.cat(return_labels, axis=0) 47 | return_bbox_targets = torch.cat(return_bbox_targets, axis=0) 48 | return return_labels, return_bbox_targets 49 | 50 | -------------------------------------------------------------------------------- /lib/det_oprs/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import pickle 3 | 4 | import torch 5 | 6 | def get_padded_tensor(tensor, multiple_number, pad_value=0): 7 | t_height, t_width = tensor.shape[-2], tensor.shape[-1] 8 | padded_height = (t_height + multiple_number - 1) // \ 9 | multiple_number * multiple_number 10 | padded_width = (t_width + multiple_number - 1) // \ 11 | multiple_number * multiple_number 12 | ndim = tensor.ndim 13 | if ndim == 4: 14 | padded_tensor = torch.ones([tensor.shape[0], tensor.shape[1], padded_height, padded_width]) * pad_value 15 | padded_tensor = padded_tensor.type_as(tensor) 16 | padded_tensor[:, :, :t_height, :t_width] = tensor 17 | elif ndim == 3: 18 | padded_tensor = torch.ones([tensor.shape[0], padded_height, padded_width]) * pad_value 19 | padded_tensor = padded_tensor.type_as(tensor) 20 | padded_tensor[:, :t_height, :t_width] = tensor 21 | else: 22 | raise Exception('Not supported tensor dim: {}'.format(ndim)) 23 | return padded_tensor 24 | 25 | def _init_backbone(backbone, model_path, strict): 26 | state_dict = _load_c2_pickled_weights(model_path) 27 | state_dict = _rename_weights_for_resnet50(state_dict) 28 | 
backbone.load_state_dict(state_dict, strict=strict) 29 | del state_dict 30 | 31 | def _rename_basic_resnet_weights(layer_keys): 32 | layer_keys = [k.replace("_", ".") for k in layer_keys] 33 | layer_keys = [k.replace(".w", ".weight") for k in layer_keys] 34 | layer_keys = [k.replace(".bn", "_bn") for k in layer_keys] 35 | layer_keys = [k.replace(".b", ".bias") for k in layer_keys] 36 | layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys] 37 | layer_keys = [k.replace(".biaseta", ".beta") for k in layer_keys] 38 | # Affine-Channel -> BatchNorm enaming 39 | layer_keys = [k.replace("running.mean", "running_mean") for k in layer_keys] 40 | layer_keys = [k.replace("running.var", "running_var") for k in layer_keys] 41 | layer_keys = [k.replace(".beta", ".bias") for k in layer_keys] 42 | layer_keys = [k.replace(".gamma", ".weight") for k in layer_keys] 43 | layer_keys = [k.replace("res.conv1_bn", "bn1") for k in layer_keys] 44 | ## Make torchvision-compatible 45 | layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys] 46 | layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys] 47 | layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys] 48 | layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys] 49 | 50 | layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys] 51 | layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys] 52 | layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys] 53 | layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys] 54 | layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys] 55 | layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys] 56 | layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys] 57 | layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys] 58 | return layer_keys 59 | 60 | def _rename_weights_for_resnet50(weights): 61 | original_keys = sorted(weights.keys()) 62 | layer_keys = sorted(weights.keys()) 63 | layer_keys = _rename_basic_resnet_weights(layer_keys) 64 | key_map = {k: v for k, v in zip(original_keys, layer_keys)} 65 | new_weights = OrderedDict() 66 | for k in original_keys: 67 | v = weights[k] 68 | if "_momentum" in k: 69 | continue 70 | if "fc1000" in k: 71 | continue 72 | w = torch.from_numpy(v) 73 | new_weights[key_map[k]] = w 74 | return new_weights 75 | 76 | def _load_c2_pickled_weights(file_path): 77 | with open(file_path, "rb") as f: 78 | if torch._six.PY3: 79 | data = pickle.load(f, encoding="latin1") 80 | else: 81 | data = pickle.load(f) 82 | if "blobs" in data: 83 | weights = data["blobs"] 84 | else: 85 | weights = data 86 | return weights 87 | -------------------------------------------------------------------------------- /lib/evaluate/APMRToolkits/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf8 -*- 2 | __author__ = 'jyn' 3 | __email__ = 'jyn@megvii.com' 4 | 5 | from .image import * 6 | from .database import * 7 | -------------------------------------------------------------------------------- /lib/evaluate/APMRToolkits/database.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from .image import * 5 | 6 | PERSON_CLASSES = ['background', 'person'] 7 | # DBBase 8 | class Database(object): 9 | def __init__(self, gtpath=None, dtpath=None, body_key=None, head_key=None, mode=0): 10 | """ 11 | mode=0: only 
body; mode=1: only head 12 | """ 13 | self.images = dict() 14 | self.eval_mode = mode 15 | self.loadData(gtpath, body_key, head_key, True) 16 | self.loadData(dtpath, body_key, head_key, False) 17 | 18 | self._ignNum = sum([self.images[i]._ignNum for i in self.images]) 19 | self._gtNum = sum([self.images[i]._gtNum for i in self.images]) 20 | self._imageNum = len(self.images) 21 | self.scorelist = None 22 | 23 | def loadData(self, fpath, body_key=None, head_key=None, if_gt=True): 24 | assert os.path.isfile(fpath), fpath + " does not exist!" 25 | with open(fpath, "r") as f: 26 | lines = f.readlines() 27 | records = [json.loads(line.strip('\n')) for line in lines] 28 | if if_gt: 29 | for record in records: 30 | self.images[record["ID"]] = Image(self.eval_mode) 31 | self.images[record["ID"]].load(record, body_key, head_key, PERSON_CLASSES, True) 32 | else: 33 | for record in records: 34 | self.images[record["ID"]].load(record, body_key, head_key, PERSON_CLASSES, False) 35 | self.images[record["ID"]].clip_all_boader() 36 | 37 | def compare(self, thres=0.5, matching=None): 38 | """ 39 | match the detection results with the groundtruth in the whole database 40 | """ 41 | assert matching is None or matching == "VOC", matching 42 | scorelist = list() 43 | for ID in self.images: 44 | if matching == "VOC": 45 | result = self.images[ID].compare_voc(thres) 46 | else: 47 | result = self.images[ID].compare_caltech(thres) 48 | scorelist.extend(result) 49 | # In the descending sort of dtbox score. 50 | scorelist.sort(key=lambda x: x[0][-1], reverse=True) 51 | self.scorelist = scorelist 52 | 53 | def eval_MR(self, ref="CALTECH_-2"): 54 | """ 55 | evaluate by Caltech-style log-average miss rate 56 | ref: str - "CALTECH_-2"/"CALTECH_-4" 57 | """ 58 | # find greater_than 59 | def _find_gt(lst, target): 60 | for idx, item in enumerate(lst): 61 | if item >= target: 62 | return idx 63 | return len(lst)-1 64 | 65 | assert ref == "CALTECH_-2" or ref == "CALTECH_-4", ref 66 | if ref == "CALTECH_-2": 67 | # CALTECH_MRREF_2: anchor points (from 10^-2 to 1) as in P.Dollar's paper 68 | ref = [0.0100, 0.0178, 0.03160, 0.0562, 0.1000, 0.1778, 0.3162, 0.5623, 1.000] 69 | else: 70 | # CALTECH_MRREF_4: anchor points (from 10^-4 to 1) as in S.Zhang's paper 71 | ref = [0.0001, 0.0003, 0.00100, 0.0032, 0.0100, 0.0316, 0.1000, 0.3162, 1.000] 72 | 73 | if self.scorelist is None: 74 | self.compare() 75 | 76 | tp, fp = 0.0, 0.0 77 | fppiX, fppiY = list(), list() 78 | for i, item in enumerate(self.scorelist): 79 | if item[1] == 1: 80 | tp += 1.0 81 | elif item[1] == 0: 82 | fp += 1.0 83 | 84 | fn = (self._gtNum - self._ignNum) - tp 85 | recall = tp / (tp + fn) 86 | precision = tp / (tp + fp) 87 | missrate = 1.0 - recall 88 | fppi = fp / self._imageNum 89 | fppiX.append(fppi) 90 | fppiY.append(missrate) 91 | 92 | score = list() 93 | for pos in ref: 94 | argmin = _find_gt(fppiX, pos) 95 | if argmin >= 0: 96 | score.append(fppiY[argmin]) 97 | score = np.array(score) 98 | MR = np.exp(np.log(score).mean()) 99 | return MR, (fppiX, fppiY) 100 | 101 | def eval_AP(self): 102 | """ 103 | :meth: evaluate by average precision 104 | """ 105 | # calculate general ap score 106 | def _calculate_map(recall, precision): 107 | assert len(recall) == len(precision) 108 | area = 0 109 | for i in range(1, len(recall)): 110 | delta_h = (precision[i-1] + precision[i]) / 2 111 | delta_w = recall[i] - recall[i-1] 112 | area += delta_w * delta_h 113 | return area 114 | 115 | tp, fp = 0.0, 0.0 116 | rpX, rpY = list(), list() 117 | total_det = len(self.scorelist) 
118 | total_gt = self._gtNum - self._ignNum 119 | total_images = self._imageNum 120 | 121 | fpn = [] 122 | recalln = [] 123 | thr = [] 124 | fppi = [] 125 | for i, item in enumerate(self.scorelist): 126 | if item[1] == 1: 127 | tp += 1.0 128 | elif item[1] == 0: 129 | fp += 1.0 130 | fn = total_gt - tp 131 | recall = tp / (tp + fn) 132 | precision = tp / (tp + fp) 133 | rpX.append(recall) 134 | rpY.append(precision) 135 | fpn.append(fp) 136 | recalln.append(tp) 137 | thr.append(item[0][-1]) 138 | fppi.append(fp/total_images) 139 | 140 | AP = _calculate_map(rpX, rpY) 141 | return AP, (rpX, rpY, thr, fpn, recalln, fppi) 142 | 143 | -------------------------------------------------------------------------------- /lib/evaluate/APMRToolkits/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Image(object): 4 | def __init__(self, mode): 5 | self.ID = None 6 | self._width = None 7 | self._height = None 8 | self.dtboxes = None 9 | self.gtboxes = None 10 | self.eval_mode = mode 11 | 12 | self._ignNum = None 13 | self._gtNum = None 14 | self._dtNum = None 15 | 16 | def load(self, record, body_key, head_key, class_names, gtflag): 17 | """ 18 | :meth: read the object from a dict 19 | """ 20 | if "ID" in record and self.ID is None: 21 | self.ID = record['ID'] 22 | if "width" in record and self._width is None: 23 | self._width = record["width"] 24 | if "height" in record and self._height is None: 25 | self._height = record["height"] 26 | if gtflag: 27 | self._gtNum = len(record["gtboxes"]) 28 | body_bbox, head_bbox = self.load_gt_boxes(record, 'gtboxes', class_names) 29 | if self.eval_mode == 0: 30 | self.gtboxes = body_bbox 31 | self._ignNum = (body_bbox[:, -1] == -1).sum() 32 | elif self.eval_mode == 1: 33 | self.gtboxes = head_bbox 34 | self._ignNum = (head_bbox[:, -1] == -1).sum() 35 | elif self.eval_mode == 2: 36 | gt_tag = np.array([body_bbox[i,-1]!=-1 and head_bbox[i,-1]!=-1 for i in range(len(body_bbox))]) 37 | self._ignNum = (gt_tag == 0).sum() 38 | self.gtboxes = np.hstack((body_bbox[:, :-1], head_bbox[:, :-1], gt_tag.reshape(-1, 1))) 39 | else: 40 | raise Exception('Unknown evaluation mode!') 41 | if not gtflag: 42 | self._dtNum = len(record["dtboxes"]) 43 | if self.eval_mode == 0: 44 | self.dtboxes = self.load_det_boxes(record, 'dtboxes', body_key, 'score') 45 | elif self.eval_mode == 1: 46 | self.dtboxes = self.load_det_boxes(record, 'dtboxes', head_key, 'score') 47 | elif self.eval_mode == 2: 48 | body_dtboxes = self.load_det_boxes(record, 'dtboxes', body_key) 49 | head_dtboxes = self.load_det_boxes(record, 'dtboxes', head_key, 'score') 50 | self.dtboxes = np.hstack((body_dtboxes, head_dtboxes)) 51 | else: 52 | raise Exception('Unknown evaluation mode!') 53 | 54 | def compare_caltech(self, thres): 55 | """ 56 | :meth: match the detection results with the groundtruth by Caltech matching strategy 57 | :param thres: iou threshold 58 | :type thres: float 59 | :return: a list of tuples (dtbox, imageID), in the descending sort of dtbox.score 60 | """ 61 | dtboxes = self.dtboxes if self.dtboxes is not None else list() 62 | gtboxes = self.gtboxes if self.gtboxes is not None else list() 63 | dt_matched = np.zeros(dtboxes.shape[0]) 64 | gt_matched = np.zeros(gtboxes.shape[0]) 65 | 66 | dtboxes = np.array(sorted(dtboxes, key=lambda x: x[-1], reverse=True)) 67 | gtboxes = np.array(sorted(gtboxes, key=lambda x: x[-1], reverse=True)) 68 | if len(dtboxes): 69 | overlap_iou = self.box_overlap_opr(dtboxes, gtboxes, True) 70 | 
overlap_ioa = self.box_overlap_opr(dtboxes, gtboxes, False) 71 | else: 72 | return list() 73 | 74 | scorelist = list() 75 | for i, dt in enumerate(dtboxes): 76 | maxpos = -1 77 | maxiou = thres 78 | for j, gt in enumerate(gtboxes): 79 | if gt_matched[j] == 1: 80 | continue 81 | if gt[-1] > 0: 82 | overlap = overlap_iou[i][j] 83 | if overlap > maxiou: 84 | maxiou = overlap 85 | maxpos = j 86 | else: 87 | if maxpos >= 0: 88 | break 89 | else: 90 | overlap = overlap_ioa[i][j] 91 | if overlap > thres: 92 | maxiou = overlap 93 | maxpos = j 94 | if maxpos >= 0: 95 | if gtboxes[maxpos, -1] > 0: 96 | gt_matched[maxpos] = 1 97 | dt_matched[i] = 1 98 | scorelist.append((dt, 1, self.ID)) 99 | else: 100 | dt_matched[i] = -1 101 | else: 102 | dt_matched[i] = 0 103 | scorelist.append((dt, 0, self.ID)) 104 | return scorelist 105 | 106 | def compare_caltech_union(self, thres): 107 | """ 108 | :meth: match the detection results with the groundtruth by Caltech matching strategy 109 | :param thres: iou threshold 110 | :type thres: float 111 | :return: a list of tuples (dtbox, imageID), in the descending sort of dtbox.score 112 | """ 113 | dtboxes = self.dtboxes if self.dtboxes is not None else list() 114 | gtboxes = self.gtboxes if self.gtboxes is not None else list() 115 | if len(dtboxes) == 0: 116 | return list() 117 | dt_matched = np.zeros(dtboxes.shape[0]) 118 | gt_matched = np.zeros(gtboxes.shape[0]) 119 | 120 | dtboxes = np.array(sorted(dtboxes, key=lambda x: x[-1], reverse=True)) 121 | gtboxes = np.array(sorted(gtboxes, key=lambda x: x[-1], reverse=True)) 122 | dt_body_boxes = np.hstack((dtboxes[:, :4], dtboxes[:, -1][:,None])) 123 | dt_head_boxes = dtboxes[:, 4:8] 124 | gt_body_boxes = np.hstack((gtboxes[:, :4], gtboxes[:, -1][:,None])) 125 | gt_head_boxes = gtboxes[:, 4:8] 126 | overlap_iou = self.box_overlap_opr(dt_body_boxes, gt_body_boxes, True) 127 | overlap_head = self.box_overlap_opr(dt_head_boxes, gt_head_boxes, True) 128 | overlap_ioa = self.box_overlap_opr(dt_body_boxes, gt_body_boxes, False) 129 | 130 | scorelist = list() 131 | for i, dt in enumerate(dtboxes): 132 | maxpos = -1 133 | maxiou = thres 134 | for j, gt in enumerate(gtboxes): 135 | if gt_matched[j] == 1: 136 | continue 137 | if gt[-1] > 0: 138 | o_body = overlap_iou[i][j] 139 | o_head = overlap_head[i][j] 140 | if o_body > maxiou and o_head > maxiou: 141 | maxiou = o_body 142 | maxpos = j 143 | else: 144 | if maxpos >= 0: 145 | break 146 | else: 147 | o_body = overlap_ioa[i][j] 148 | if o_body > thres: 149 | maxiou = o_body 150 | maxpos = j 151 | if maxpos >= 0: 152 | if gtboxes[maxpos, -1] > 0: 153 | gt_matched[maxpos] = 1 154 | dt_matched[i] = 1 155 | scorelist.append((dt, 1, self.ID)) 156 | else: 157 | dt_matched[i] = -1 158 | else: 159 | dt_matched[i] = 0 160 | scorelist.append((dt, 0, self.ID)) 161 | return scorelist 162 | 163 | def box_overlap_opr(self, dboxes:np.ndarray, gboxes:np.ndarray, if_iou): 164 | eps = 1e-6 165 | assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4 166 | N, K = dboxes.shape[0], gboxes.shape[0] 167 | dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1)) 168 | gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1)) 169 | 170 | iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0]) 171 | ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1]) 172 | inter = np.maximum(0, iw) * np.maximum(0, ih) 173 | 174 | dtarea = (dtboxes[:,:,2] - dtboxes[:,:,0]) * (dtboxes[:,:,3] - dtboxes[:,:,1]) 175 | if if_iou: 176 | 
gtarea = (gtboxes[:,:,2] - gtboxes[:,:,0]) * (gtboxes[:,:,3] - gtboxes[:,:,1]) 177 | ious = inter / (dtarea + gtarea - inter + eps) 178 | else: 179 | ious = inter / (dtarea + eps) 180 | return ious 181 | 182 | def clip_all_boader(self): 183 | 184 | def _clip_boundary(boxes,height,width): 185 | assert boxes.shape[-1]>=4 186 | boxes[:,0] = np.minimum(np.maximum(boxes[:,0],0), width - 1) 187 | boxes[:,1] = np.minimum(np.maximum(boxes[:,1],0), height - 1) 188 | boxes[:,2] = np.maximum(np.minimum(boxes[:,2],width), 0) 189 | boxes[:,3] = np.maximum(np.minimum(boxes[:,3],height), 0) 190 | return boxes 191 | 192 | assert self.dtboxes.shape[-1]>=4 193 | assert self.gtboxes.shape[-1]>=4 194 | assert self._width is not None and self._height is not None 195 | if self.eval_mode == 2: 196 | self.dtboxes[:, :4] = _clip_boundary(self.dtboxes[:, :4], self._height, self._width) 197 | self.gtboxes[:, :4] = _clip_boundary(self.gtboxes[:, :4], self._height, self._width) 198 | self.dtboxes[:, 4:8] = _clip_boundary(self.dtboxes[:, 4:8], self._height, self._width) 199 | self.gtboxes[:, 4:8] = _clip_boundary(self.gtboxes[:, 4:8], self._height, self._width) 200 | else: 201 | self.dtboxes = _clip_boundary(self.dtboxes, self._height, self._width) 202 | self.gtboxes = _clip_boundary(self.gtboxes, self._height, self._width) 203 | 204 | def load_gt_boxes(self, dict_input, key_name, class_names): 205 | assert key_name in dict_input 206 | if len(dict_input[key_name]) < 1: 207 | return np.empty([0, 5]) 208 | head_bbox = [] 209 | body_bbox = [] 210 | for rb in dict_input[key_name]: 211 | if rb['tag'] in class_names: 212 | body_tag = class_names.index(rb['tag']) 213 | head_tag = 1 214 | else: 215 | body_tag = -1 216 | head_tag = -1 217 | if 'extra' in rb: 218 | if 'ignore' in rb['extra']: 219 | if rb['extra']['ignore'] != 0: 220 | body_tag = -1 221 | head_tag = -1 222 | if 'head_attr' in rb: 223 | if 'ignore' in rb['head_attr']: 224 | if rb['head_attr']['ignore'] != 0: 225 | head_tag = -1 226 | head_bbox.append(np.hstack((rb['hbox'], head_tag))) 227 | body_bbox.append(np.hstack((rb['fbox'], body_tag))) 228 | head_bbox = np.array(head_bbox) 229 | head_bbox[:, 2:4] += head_bbox[:, :2] 230 | body_bbox = np.array(body_bbox) 231 | body_bbox[:, 2:4] += body_bbox[:, :2] 232 | return body_bbox, head_bbox 233 | 234 | def load_det_boxes(self, dict_input, key_name, key_box, key_score=None, key_tag=None): 235 | assert key_name in dict_input 236 | if len(dict_input[key_name]) < 1: 237 | return np.empty([0, 5]) 238 | else: 239 | assert key_box in dict_input[key_name][0] 240 | if key_score: 241 | assert key_score in dict_input[key_name][0] 242 | if key_tag: 243 | assert key_tag in dict_input[key_name][0] 244 | if key_score: 245 | if key_tag: 246 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score], rb[key_tag])) for rb in dict_input[key_name]]) 247 | else: 248 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score])) for rb in dict_input[key_name]]) 249 | else: 250 | if key_tag: 251 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_tag])) for rb in dict_input[key_name]]) 252 | else: 253 | bboxes = np.vstack([rb[key_box] for rb in dict_input[key_name]]) 254 | bboxes[:, 2:4] += bboxes[:, :2] 255 | return bboxes 256 | 257 | def compare_voc(self, thres): 258 | """ 259 | :meth: match the detection results with the groundtruth by VOC matching strategy 260 | :param thres: iou threshold 261 | :type thres: float 262 | :return: a list of tuples (dtbox, imageID), in the descending sort of dtbox.score 263 | """ 264 | if self.dtboxes is 
None: 265 | return list() 266 | dtboxes = self.dtboxes 267 | gtboxes = self.gtboxes if self.gtboxes is not None else list() 268 | dtboxes.sort(key=lambda x: x.score, reverse=True) 269 | gtboxes.sort(key=lambda x: x.ign) 270 | 271 | scorelist = list() 272 | for i, dt in enumerate(dtboxes): 273 | maxpos = -1 274 | maxiou = thres 275 | 276 | for j, gt in enumerate(gtboxes): 277 | overlap = dt.iou(gt) 278 | if overlap > maxiou: 279 | maxiou = overlap 280 | maxpos = j 281 | 282 | if maxpos >= 0: 283 | if gtboxes[maxpos].ign == 0: 284 | gtboxes[maxpos].matched = 1 285 | dtboxes[i].matched = 1 286 | scorelist.append((dt, self.ID)) 287 | else: 288 | dtboxes[i].matched = -1 289 | else: 290 | dtboxes[i].matched = 0 291 | scorelist.append((dt, self.ID)) 292 | return scorelist 293 | -------------------------------------------------------------------------------- /lib/evaluate/JIToolkits/JI_tools.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import numpy as np 3 | from .matching import maxWeightMatching 4 | 5 | def compute_matching(dt_boxes, gt_boxes, bm_thr): 6 | assert dt_boxes.shape[-1] > 3 and gt_boxes.shape[-1] > 3 7 | if dt_boxes.shape[0] < 1 or gt_boxes.shape[0] < 1: 8 | return list() 9 | N, K = dt_boxes.shape[0], gt_boxes.shape[0] 10 | ious = compute_iou_matrix(dt_boxes, gt_boxes) 11 | rows, cols = np.where(ious > bm_thr) 12 | bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)] 13 | mates = maxWeightMatching(bipartites) 14 | if len(mates) < 1: 15 | return list() 16 | rows = np.where(np.array(mates) > -1)[0] 17 | indices = np.where(rows < N + 1)[0] 18 | rows = rows[indices] 19 | cols = np.array([mates[i] for i in rows]) 20 | matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)] 21 | return matches 22 | 23 | def compute_head_body_matching(dt_body, dt_head, gt_body, gt_head, bm_thr): 24 | assert dt_body.shape[-1] > 3 and gt_body.shape[-1] > 3 25 | assert dt_head.shape[-1] > 3 and gt_head.shape[-1] > 3 26 | assert dt_body.shape[0] == dt_head.shape[0] 27 | assert gt_body.shape[0] == gt_head.shape[0] 28 | N, K = dt_body.shape[0], gt_body.shape[0] 29 | ious_body = compute_iou_matrix(dt_body, gt_body) 30 | ious_head = compute_iou_matrix(dt_head, gt_head) 31 | mask_body = ious_body > bm_thr 32 | mask_head = ious_head > bm_thr 33 | # only keep the both matches detections 34 | mask = np.array(mask_body) & np.array(mask_head) 35 | ious = np.zeros((N, K)) 36 | ious[mask] = (ious_body[mask] + ious_head[mask]) / 2 37 | rows, cols = np.where(ious > bm_thr) 38 | bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)] 39 | mates = maxWeightMatching(bipartites) 40 | if len(mates) < 1: 41 | return list() 42 | rows = np.where(np.array(mates) > -1)[0] 43 | indices = np.where(rows < N + 1)[0] 44 | rows = rows[indices] 45 | cols = np.array([mates[i] for i in rows]) 46 | matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)] 47 | return matches 48 | 49 | def compute_multi_head_body_matching(dt_body, dt_head_0, dt_head_1, gt_body, gt_head, bm_thr): 50 | assert dt_body.shape[-1] > 3 and gt_body.shape[-1] > 3 51 | assert dt_head_0.shape[-1] > 3 and gt_head.shape[-1] > 3 52 | assert dt_head_1.shape[-1] > 3 and gt_head.shape[-1] > 3 53 | assert dt_body.shape[0] == dt_head_0.shape[0] 54 | assert gt_body.shape[0] == gt_head.shape[0] 55 | N, K = dt_body.shape[0], gt_body.shape[0] 56 | ious_body = compute_iou_matrix(dt_body, gt_body) 57 | ious_head_0 = compute_iou_matrix(dt_head_0, gt_head) 58 | ious_head_1 = 
compute_iou_matrix(dt_head_1, gt_head) 59 | mask_body = ious_body > bm_thr 60 | mask_head_0 = ious_head_0 > bm_thr 61 | mask_head_1 = ious_head_1 > bm_thr 62 | mask_head = mask_head_0 | mask_head_1 63 | # only keep the both matches detections 64 | mask = np.array(mask_body) & np.array(mask_head) 65 | ious = np.zeros((N, K)) 66 | #ious[mask] = (ious_body[mask] + ious_head[mask]) / 2 67 | ious[mask] = ious_body[mask] 68 | rows, cols = np.where(ious > bm_thr) 69 | bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)] 70 | mates = maxWeightMatching(bipartites) 71 | if len(mates) < 1: 72 | return list() 73 | rows = np.where(np.array(mates) > -1)[0] 74 | indices = np.where(rows < N + 1)[0] 75 | rows = rows[indices] 76 | cols = np.array([mates[i] for i in rows]) 77 | matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)] 78 | return matches 79 | 80 | def get_head_body_ignores(dt_body, dt_head, gt_body, gt_head, bm_thr): 81 | if gt_body.size: 82 | body_ioas = compute_ioa_matrix(dt_body, gt_body) 83 | head_ioas = compute_ioa_matrix(dt_head, gt_head) 84 | body_ioas = np.max(body_ioas, axis=1) 85 | head_ioas = np.max(head_ioas, axis=1) 86 | head_rows = np.where(head_ioas > bm_thr)[0] 87 | body_rows = np.where(body_ioas > bm_thr)[0] 88 | rows = set.union(set(head_rows), set(body_rows)) 89 | return len(rows) 90 | else: 91 | return 0 92 | 93 | def get_ignores(dt_boxes, gt_boxes, bm_thr): 94 | if gt_boxes.size: 95 | ioas = compute_ioa_matrix(dt_boxes, gt_boxes) 96 | ioas = np.max(ioas, axis = 1) 97 | rows = np.where(ioas > bm_thr)[0] 98 | return len(rows) 99 | else: 100 | return 0 101 | 102 | def compute_ioa_matrix(dboxes: np.ndarray, gboxes: np.ndarray): 103 | eps = 1e-6 104 | assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4 105 | N, K = dboxes.shape[0], gboxes.shape[0] 106 | dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1)) 107 | gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1)) 108 | 109 | iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0]) 110 | ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1]) 111 | inter = np.maximum(0, iw) * np.maximum(0, ih) 112 | 113 | dtarea = np.maximum(dtboxes[:,:,2] - dtboxes[:,:,0], 0) * np.maximum(dtboxes[:,:,3] - dtboxes[:,:,1], 0) 114 | ioas = inter / (dtarea + eps) 115 | return ioas 116 | 117 | def compute_iou_matrix(dboxes:np.ndarray, gboxes:np.ndarray): 118 | eps = 1e-6 119 | assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4 120 | N, K = dboxes.shape[0], gboxes.shape[0] 121 | dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1)) 122 | gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1)) 123 | 124 | iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0]) 125 | ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1]) 126 | inter = np.maximum(0, iw) * np.maximum(0, ih) 127 | 128 | dtarea = (dtboxes[:,:,2] - dtboxes[:,:,0]) * (dtboxes[:,:,3] - dtboxes[:,:,1]) 129 | gtarea = (gtboxes[:,:,2] - gtboxes[:,:,0]) * (gtboxes[:,:,3] - gtboxes[:,:,1]) 130 | ious = inter / (dtarea + gtarea - inter + eps) 131 | return ious 132 | 133 | -------------------------------------------------------------------------------- /lib/evaluate/JIToolkits/matching.py: -------------------------------------------------------------------------------- 1 | """Weighted maximum matching in general graphs. 
2 | 3 | The algorithm is taken from "Efficient Algorithms for Finding Maximum 4 | Matching in Graphs" by Zvi Galil, ACM Computing Surveys, 1986. 5 | It is based on the "blossom" method for finding augmenting paths and 6 | the "primal-dual" method for finding a matching of maximum weight, both 7 | due to Jack Edmonds. 8 | Some ideas came from "Implementation of algorithms for maximum matching 9 | on non-bipartite graphs" by H.J. Gabow, Standford Ph.D. thesis, 1973. 10 | 11 | A C program for maximum weight matching by Ed Rothberg was used extensively 12 | to validate this new code. 13 | """ 14 | 15 | # 16 | # Changes: 17 | # 18 | # 2013-04-07 19 | # * Added Python 3 compatibility with contributions from Daniel Saunders. 20 | # 21 | # 2008-06-08 22 | # * First release. 23 | # 24 | 25 | from __future__ import print_function 26 | 27 | # If assigned, DEBUG(str) is called with lots of debug messages. 28 | DEBUG = None 29 | """def DEBUG(s): 30 | from sys import stderr 31 | print('DEBUG:', s, file=stderr) 32 | """ 33 | 34 | # Check delta2/delta3 computation after every substage; 35 | # only works on integer weights, slows down the algorithm to O(n^4). 36 | CHECK_DELTA = False 37 | 38 | # Check optimality of solution before returning; only works on integer weights. 39 | CHECK_OPTIMUM = False 40 | 41 | def maxWeightMatching(edges, maxcardinality=False): 42 | """Compute a maximum-weighted matching in the general undirected 43 | weighted graph given by "edges". If "maxcardinality" is true, 44 | only maximum-cardinality matchings are considered as solutions. 45 | 46 | Edges is a sequence of tuples (i, j, wt) describing an undirected 47 | edge between vertex i and vertex j with weight wt. There is at most 48 | one edge between any two vertices; no vertex has an edge to itself. 49 | Vertices are identified by consecutive, non-negative integers. 50 | 51 | Return a list "mate", such that mate[i] == j if vertex i is 52 | matched to vertex j, and mate[i] == -1 if vertex i is not matched. 53 | 54 | This function takes time O(n ** 3).""" 55 | 56 | # 57 | # Vertices are numbered 0 .. (nvertex-1). 58 | # Non-trivial blossoms are numbered nvertex .. (2*nvertex-1) 59 | # 60 | # Edges are numbered 0 .. (nedge-1). 61 | # Edge endpoints are numbered 0 .. (2*nedge-1), such that endpoints 62 | # (2*k) and (2*k+1) both belong to edge k. 63 | # 64 | # Many terms used in the comments (sub-blossom, T-vertex) come from 65 | # the paper by Galil; read the paper before reading this code. 66 | # 67 | 68 | # Python 2/3 compatibility. 69 | from sys import version as sys_version 70 | if sys_version < '3': 71 | integer_types = (int, long) 72 | else: 73 | integer_types = (int,) 74 | 75 | # Deal swiftly with empty graphs. 76 | if not edges: 77 | return [ ] 78 | 79 | # Count vertices. 80 | nedge = len(edges) 81 | nvertex = 0 82 | for (i, j, w) in edges: 83 | assert i >= 0 and j >= 0 and i != j 84 | if i >= nvertex: 85 | nvertex = i + 1 86 | if j >= nvertex: 87 | nvertex = j + 1 88 | 89 | # Find the maximum edge weight. 90 | maxweight = max(0, max([ wt for (i, j, wt) in edges ])) 91 | 92 | # If p is an edge endpoint, 93 | # endpoint[p] is the vertex to which endpoint p is attached. 94 | # Not modified by the algorithm. 95 | endpoint = [ edges[p//2][p%2] for p in range(2*nedge) ] 96 | 97 | # If v is a vertex, 98 | # neighbend[v] is the list of remote endpoints of the edges attached to v. 99 | # Not modified by the algorithm. 
100 | neighbend = [ [ ] for i in range(nvertex) ] 101 | for k in range(len(edges)): 102 | (i, j, w) = edges[k] 103 | neighbend[i].append(2*k+1) 104 | neighbend[j].append(2*k) 105 | 106 | # If v is a vertex, 107 | # mate[v] is the remote endpoint of its matched edge, or -1 if it is single 108 | # (i.e. endpoint[mate[v]] is v's partner vertex). 109 | # Initially all vertices are single; updated during augmentation. 110 | mate = nvertex * [ -1 ] 111 | 112 | # If b is a top-level blossom, 113 | # label[b] is 0 if b is unlabeled (free); 114 | # 1 if b is an S-vertex/blossom; 115 | # 2 if b is a T-vertex/blossom. 116 | # The label of a vertex is found by looking at the label of its 117 | # top-level containing blossom. 118 | # If v is a vertex inside a T-blossom, 119 | # label[v] is 2 iff v is reachable from an S-vertex outside the blossom. 120 | # Labels are assigned during a stage and reset after each augmentation. 121 | label = (2 * nvertex) * [ 0 ] 122 | 123 | # If b is a labeled top-level blossom, 124 | # labelend[b] is the remote endpoint of the edge through which b obtained 125 | # its label, or -1 if b's base vertex is single. 126 | # If v is a vertex inside a T-blossom and label[v] == 2, 127 | # labelend[v] is the remote endpoint of the edge through which v is 128 | # reachable from outside the blossom. 129 | labelend = (2 * nvertex) * [ -1 ] 130 | 131 | # If v is a vertex, 132 | # inblossom[v] is the top-level blossom to which v belongs. 133 | # If v is a top-level vertex, v is itself a blossom (a trivial blossom) 134 | # and inblossom[v] == v. 135 | # Initially all vertices are top-level trivial blossoms. 136 | inblossom = list(range(nvertex)) 137 | 138 | # If b is a sub-blossom, 139 | # blossomparent[b] is its immediate parent (sub-)blossom. 140 | # If b is a top-level blossom, blossomparent[b] is -1. 141 | blossomparent = (2 * nvertex) * [ -1 ] 142 | 143 | # If b is a non-trivial (sub-)blossom, 144 | # blossomchilds[b] is an ordered list of its sub-blossoms, starting with 145 | # the base and going round the blossom. 146 | blossomchilds = (2 * nvertex) * [ None ] 147 | 148 | # If b is a (sub-)blossom, 149 | # blossombase[b] is its base VERTEX (i.e. recursive sub-blossom). 150 | blossombase = list(range(nvertex)) + nvertex * [ -1 ] 151 | 152 | # If b is a non-trivial (sub-)blossom, 153 | # blossomendps[b] is a list of endpoints on its connecting edges, 154 | # such that blossomendps[b][i] is the local endpoint of blossomchilds[b][i] 155 | # on the edge that connects it to blossomchilds[b][wrap(i+1)]. 156 | blossomendps = (2 * nvertex) * [ None ] 157 | 158 | # If v is a free vertex (or an unreached vertex inside a T-blossom), 159 | # bestedge[v] is the edge to an S-vertex with least slack, 160 | # or -1 if there is no such edge. 161 | # If b is a (possibly trivial) top-level S-blossom, 162 | # bestedge[b] is the least-slack edge to a different S-blossom, 163 | # or -1 if there is no such edge. 164 | # This is used for efficient computation of delta2 and delta3. 165 | bestedge = (2 * nvertex) * [ -1 ] 166 | 167 | # If b is a non-trivial top-level S-blossom, 168 | # blossombestedges[b] is a list of least-slack edges to neighbouring 169 | # S-blossoms, or None if no such list has been computed yet. 170 | # This is used for efficient computation of delta3. 171 | blossombestedges = (2 * nvertex) * [ None ] 172 | 173 | # List of currently unused blossom numbers. 
174 | unusedblossoms = list(range(nvertex, 2*nvertex)) 175 | 176 | # If v is a vertex, 177 | # dualvar[v] = 2 * u(v) where u(v) is the v's variable in the dual 178 | # optimization problem (multiplication by two ensures integer values 179 | # throughout the algorithm if all edge weights are integers). 180 | # If b is a non-trivial blossom, 181 | # dualvar[b] = z(b) where z(b) is b's variable in the dual optimization 182 | # problem. 183 | dualvar = nvertex * [ maxweight ] + nvertex * [ 0 ] 184 | 185 | # If allowedge[k] is true, edge k has zero slack in the optimization 186 | # problem; if allowedge[k] is false, the edge's slack may or may not 187 | # be zero. 188 | allowedge = nedge * [ False ] 189 | 190 | # Queue of newly discovered S-vertices. 191 | queue = [ ] 192 | 193 | # Return 2 * slack of edge k (does not work inside blossoms). 194 | def slack(k): 195 | (i, j, wt) = edges[k] 196 | return dualvar[i] + dualvar[j] - 2 * wt 197 | 198 | # Generate the leaf vertices of a blossom. 199 | def blossomLeaves(b): 200 | if b < nvertex: 201 | yield b 202 | else: 203 | for t in blossomchilds[b]: 204 | if t < nvertex: 205 | yield t 206 | else: 207 | for v in blossomLeaves(t): 208 | yield v 209 | 210 | # Assign label t to the top-level blossom containing vertex w 211 | # and record the fact that w was reached through the edge with 212 | # remote endpoint p. 213 | def assignLabel(w, t, p): 214 | if DEBUG: DEBUG('assignLabel(%d,%d,%d)' % (w, t, p)) 215 | b = inblossom[w] 216 | assert label[w] == 0 and label[b] == 0 217 | label[w] = label[b] = t 218 | labelend[w] = labelend[b] = p 219 | bestedge[w] = bestedge[b] = -1 220 | if t == 1: 221 | # b became an S-vertex/blossom; add it(s vertices) to the queue. 222 | queue.extend(blossomLeaves(b)) 223 | if DEBUG: DEBUG('PUSH ' + str(list(blossomLeaves(b)))) 224 | elif t == 2: 225 | # b became a T-vertex/blossom; assign label S to its mate. 226 | # (If b is a non-trivial blossom, its base is the only vertex 227 | # with an external mate.) 228 | base = blossombase[b] 229 | assert mate[base] >= 0 230 | assignLabel(endpoint[mate[base]], 1, mate[base] ^ 1) 231 | 232 | # Trace back from vertices v and w to discover either a new blossom 233 | # or an augmenting path. Return the base vertex of the new blossom or -1. 234 | def scanBlossom(v, w): 235 | if DEBUG: DEBUG('scanBlossom(%d,%d)' % (v, w)) 236 | # Trace back from v and w, placing breadcrumbs as we go. 237 | path = [ ] 238 | base = -1 239 | while v != -1 or w != -1: 240 | # Look for a breadcrumb in v's blossom or put a new breadcrumb. 241 | b = inblossom[v] 242 | if label[b] & 4: 243 | base = blossombase[b] 244 | break 245 | assert label[b] == 1 246 | path.append(b) 247 | label[b] = 5 248 | # Trace one step back. 249 | assert labelend[b] == mate[blossombase[b]] 250 | if labelend[b] == -1: 251 | # The base of blossom b is single; stop tracing this path. 252 | v = -1 253 | else: 254 | v = endpoint[labelend[b]] 255 | b = inblossom[v] 256 | assert label[b] == 2 257 | # b is a T-blossom; trace one more step back. 258 | assert labelend[b] >= 0 259 | v = endpoint[labelend[b]] 260 | # Swap v and w so that we alternate between both paths. 261 | if w != -1: 262 | v, w = w, v 263 | # Remove breadcrumbs. 264 | for b in path: 265 | label[b] = 1 266 | # Return base vertex, if we found one. 267 | return base 268 | 269 | # Construct a new blossom with given base, containing edge k which 270 | # connects a pair of S vertices. 
Label the new blossom as S; set its dual 271 | # variable to zero; relabel its T-vertices to S and add them to the queue. 272 | def addBlossom(base, k): 273 | (v, w, wt) = edges[k] 274 | bb = inblossom[base] 275 | bv = inblossom[v] 276 | bw = inblossom[w] 277 | # Create blossom. 278 | b = unusedblossoms.pop() 279 | if DEBUG: DEBUG('addBlossom(%d,%d) (v=%d w=%d) -> %d' % (base, k, v, w, b)) 280 | blossombase[b] = base 281 | blossomparent[b] = -1 282 | blossomparent[bb] = b 283 | # Make list of sub-blossoms and their interconnecting edge endpoints. 284 | blossomchilds[b] = path = [ ] 285 | blossomendps[b] = endps = [ ] 286 | # Trace back from v to base. 287 | while bv != bb: 288 | # Add bv to the new blossom. 289 | blossomparent[bv] = b 290 | path.append(bv) 291 | endps.append(labelend[bv]) 292 | assert (label[bv] == 2 or 293 | (label[bv] == 1 and labelend[bv] == mate[blossombase[bv]])) 294 | # Trace one step back. 295 | assert labelend[bv] >= 0 296 | v = endpoint[labelend[bv]] 297 | bv = inblossom[v] 298 | # Reverse lists, add endpoint that connects the pair of S vertices. 299 | path.append(bb) 300 | path.reverse() 301 | endps.reverse() 302 | endps.append(2*k) 303 | # Trace back from w to base. 304 | while bw != bb: 305 | # Add bw to the new blossom. 306 | blossomparent[bw] = b 307 | path.append(bw) 308 | endps.append(labelend[bw] ^ 1) 309 | assert (label[bw] == 2 or 310 | (label[bw] == 1 and labelend[bw] == mate[blossombase[bw]])) 311 | # Trace one step back. 312 | assert labelend[bw] >= 0 313 | w = endpoint[labelend[bw]] 314 | bw = inblossom[w] 315 | # Set label to S. 316 | assert label[bb] == 1 317 | label[b] = 1 318 | labelend[b] = labelend[bb] 319 | # Set dual variable to zero. 320 | dualvar[b] = 0 321 | # Relabel vertices. 322 | for v in blossomLeaves(b): 323 | if label[inblossom[v]] == 2: 324 | # This T-vertex now turns into an S-vertex because it becomes 325 | # part of an S-blossom; add it to the queue. 326 | queue.append(v) 327 | inblossom[v] = b 328 | # Compute blossombestedges[b]. 329 | bestedgeto = (2 * nvertex) * [ -1 ] 330 | for bv in path: 331 | if blossombestedges[bv] is None: 332 | # This subblossom does not have a list of least-slack edges; 333 | # get the information from the vertices. 334 | nblists = [ [ p // 2 for p in neighbend[v] ] 335 | for v in blossomLeaves(bv) ] 336 | else: 337 | # Walk this subblossom's least-slack edges. 338 | nblists = [ blossombestedges[bv] ] 339 | for nblist in nblists: 340 | for k in nblist: 341 | (i, j, wt) = edges[k] 342 | if inblossom[j] == b: 343 | i, j = j, i 344 | bj = inblossom[j] 345 | if (bj != b and label[bj] == 1 and 346 | (bestedgeto[bj] == -1 or 347 | slack(k) < slack(bestedgeto[bj]))): 348 | bestedgeto[bj] = k 349 | # Forget about least-slack edges of the subblossom. 350 | blossombestedges[bv] = None 351 | bestedge[bv] = -1 352 | blossombestedges[b] = [ k for k in bestedgeto if k != -1 ] 353 | # Select bestedge[b]. 354 | bestedge[b] = -1 355 | for k in blossombestedges[b]: 356 | if bestedge[b] == -1 or slack(k) < slack(bestedge[b]): 357 | bestedge[b] = k 358 | if DEBUG: DEBUG('blossomchilds[%d]=' % b + repr(blossomchilds[b])) 359 | 360 | # Expand the given top-level blossom. 361 | def expandBlossom(b, endstage): 362 | if DEBUG: DEBUG('expandBlossom(%d,%d) %s' % (b, endstage, repr(blossomchilds[b]))) 363 | # Convert sub-blossoms into top-level blossoms. 
364 | for s in blossomchilds[b]: 365 | blossomparent[s] = -1 366 | if s < nvertex: 367 | inblossom[s] = s 368 | elif endstage and dualvar[s] == 0: 369 | # Recursively expand this sub-blossom. 370 | expandBlossom(s, endstage) 371 | else: 372 | for v in blossomLeaves(s): 373 | inblossom[v] = s 374 | # If we expand a T-blossom during a stage, its sub-blossoms must be 375 | # relabeled. 376 | if (not endstage) and label[b] == 2: 377 | # Start at the sub-blossom through which the expanding 378 | # blossom obtained its label, and relabel sub-blossoms untili 379 | # we reach the base. 380 | # Figure out through which sub-blossom the expanding blossom 381 | # obtained its label initially. 382 | assert labelend[b] >= 0 383 | entrychild = inblossom[endpoint[labelend[b] ^ 1]] 384 | # Decide in which direction we will go round the blossom. 385 | j = blossomchilds[b].index(entrychild) 386 | if j & 1: 387 | # Start index is odd; go forward and wrap. 388 | j -= len(blossomchilds[b]) 389 | jstep = 1 390 | endptrick = 0 391 | else: 392 | # Start index is even; go backward. 393 | jstep = -1 394 | endptrick = 1 395 | # Move along the blossom until we get to the base. 396 | p = labelend[b] 397 | while j != 0: 398 | # Relabel the T-sub-blossom. 399 | label[endpoint[p ^ 1]] = 0 400 | label[endpoint[blossomendps[b][j-endptrick]^endptrick^1]] = 0 401 | assignLabel(endpoint[p ^ 1], 2, p) 402 | # Step to the next S-sub-blossom and note its forward endpoint. 403 | allowedge[blossomendps[b][j-endptrick]//2] = True 404 | j += jstep 405 | p = blossomendps[b][j-endptrick] ^ endptrick 406 | # Step to the next T-sub-blossom. 407 | allowedge[p//2] = True 408 | j += jstep 409 | # Relabel the base T-sub-blossom WITHOUT stepping through to 410 | # its mate (so don't call assignLabel). 411 | bv = blossomchilds[b][j] 412 | label[endpoint[p ^ 1]] = label[bv] = 2 413 | labelend[endpoint[p ^ 1]] = labelend[bv] = p 414 | bestedge[bv] = -1 415 | # Continue along the blossom until we get back to entrychild. 416 | j += jstep 417 | while blossomchilds[b][j] != entrychild: 418 | # Examine the vertices of the sub-blossom to see whether 419 | # it is reachable from a neighbouring S-vertex outside the 420 | # expanding blossom. 421 | bv = blossomchilds[b][j] 422 | if label[bv] == 1: 423 | # This sub-blossom just got label S through one of its 424 | # neighbours; leave it. 425 | j += jstep 426 | continue 427 | for v in blossomLeaves(bv): 428 | if label[v] != 0: 429 | break 430 | # If the sub-blossom contains a reachable vertex, assign 431 | # label T to the sub-blossom. 432 | if label[v] != 0: 433 | assert label[v] == 2 434 | assert inblossom[v] == bv 435 | label[v] = 0 436 | label[endpoint[mate[blossombase[bv]]]] = 0 437 | assignLabel(v, 2, labelend[v]) 438 | j += jstep 439 | # Recycle the blossom number. 440 | label[b] = labelend[b] = -1 441 | blossomchilds[b] = blossomendps[b] = None 442 | blossombase[b] = -1 443 | blossombestedges[b] = None 444 | bestedge[b] = -1 445 | unusedblossoms.append(b) 446 | 447 | # Swap matched/unmatched edges over an alternating path through blossom b 448 | # between vertex v and the base vertex. Keep blossom bookkeeping consistent. 449 | def augmentBlossom(b, v): 450 | if DEBUG: DEBUG('augmentBlossom(%d,%d)' % (b, v)) 451 | # Bubble up through the blossom tree from vertex v to an immediate 452 | # sub-blossom of b. 453 | t = v 454 | while blossomparent[t] != b: 455 | t = blossomparent[t] 456 | # Recursively deal with the first sub-blossom. 
457 | if t >= nvertex: 458 | augmentBlossom(t, v) 459 | # Decide in which direction we will go round the blossom. 460 | i = j = blossomchilds[b].index(t) 461 | if i & 1: 462 | # Start index is odd; go forward and wrap. 463 | j -= len(blossomchilds[b]) 464 | jstep = 1 465 | endptrick = 0 466 | else: 467 | # Start index is even; go backward. 468 | jstep = -1 469 | endptrick = 1 470 | # Move along the blossom until we get to the base. 471 | while j != 0: 472 | # Step to the next sub-blossom and augment it recursively. 473 | j += jstep 474 | t = blossomchilds[b][j] 475 | p = blossomendps[b][j-endptrick] ^ endptrick 476 | if t >= nvertex: 477 | augmentBlossom(t, endpoint[p]) 478 | # Step to the next sub-blossom and augment it recursively. 479 | j += jstep 480 | t = blossomchilds[b][j] 481 | if t >= nvertex: 482 | augmentBlossom(t, endpoint[p ^ 1]) 483 | # Match the edge connecting those sub-blossoms. 484 | mate[endpoint[p]] = p ^ 1 485 | mate[endpoint[p ^ 1]] = p 486 | if DEBUG: DEBUG('PAIR %d %d (k=%d)' % (endpoint[p], endpoint[p^1], p//2)) 487 | # Rotate the list of sub-blossoms to put the new base at the front. 488 | blossomchilds[b] = blossomchilds[b][i:] + blossomchilds[b][:i] 489 | blossomendps[b] = blossomendps[b][i:] + blossomendps[b][:i] 490 | blossombase[b] = blossombase[blossomchilds[b][0]] 491 | assert blossombase[b] == v 492 | 493 | # Swap matched/unmatched edges over an alternating path between two 494 | # single vertices. The augmenting path runs through edge k, which 495 | # connects a pair of S vertices. 496 | def augmentMatching(k): 497 | (v, w, wt) = edges[k] 498 | if DEBUG: DEBUG('augmentMatching(%d) (v=%d w=%d)' % (k, v, w)) 499 | if DEBUG: DEBUG('PAIR %d %d (k=%d)' % (v, w, k)) 500 | for (s, p) in ((v, 2*k+1), (w, 2*k)): 501 | # Match vertex s to remote endpoint p. Then trace back from s 502 | # until we find a single vertex, swapping matched and unmatched 503 | # edges as we go. 504 | while 1: 505 | bs = inblossom[s] 506 | assert label[bs] == 1 507 | assert labelend[bs] == mate[blossombase[bs]] 508 | # Augment through the S-blossom from s to base. 509 | if bs >= nvertex: 510 | augmentBlossom(bs, s) 511 | # Update mate[s] 512 | mate[s] = p 513 | # Trace one step back. 514 | if labelend[bs] == -1: 515 | # Reached single vertex; stop. 516 | break 517 | t = endpoint[labelend[bs]] 518 | bt = inblossom[t] 519 | assert label[bt] == 2 520 | # Trace one step back. 521 | assert labelend[bt] >= 0 522 | s = endpoint[labelend[bt]] 523 | j = endpoint[labelend[bt] ^ 1] 524 | # Augment through the T-blossom from j to base. 525 | assert blossombase[bt] == t 526 | if bt >= nvertex: 527 | augmentBlossom(bt, j) 528 | # Update mate[j] 529 | mate[j] = labelend[bt] 530 | # Keep the opposite endpoint; 531 | # it will be assigned to mate[s] in the next step. 532 | p = labelend[bt] ^ 1 533 | if DEBUG: DEBUG('PAIR %d %d (k=%d)' % (s, t, p//2)) 534 | 535 | # Verify that the optimum solution has been reached. 536 | def verifyOptimum(): 537 | if maxcardinality: 538 | # Vertices may have negative dual; 539 | # find a constant non-negative number to add to all vertex duals. 540 | vdualoffset = max(0, -min(dualvar[:nvertex])) 541 | else: 542 | vdualoffset = 0 543 | # 0. all dual variables are non-negative 544 | assert min(dualvar[:nvertex]) + vdualoffset >= 0 545 | assert min(dualvar[nvertex:]) >= 0 546 | # 0. all edges have non-negative slack and 547 | # 1. 
all matched edges have zero slack; 548 | for k in range(nedge): 549 | (i, j, wt) = edges[k] 550 | s = dualvar[i] + dualvar[j] - 2 * wt 551 | iblossoms = [ i ] 552 | jblossoms = [ j ] 553 | while blossomparent[iblossoms[-1]] != -1: 554 | iblossoms.append(blossomparent[iblossoms[-1]]) 555 | while blossomparent[jblossoms[-1]] != -1: 556 | jblossoms.append(blossomparent[jblossoms[-1]]) 557 | iblossoms.reverse() 558 | jblossoms.reverse() 559 | for (bi, bj) in zip(iblossoms, jblossoms): 560 | if bi != bj: 561 | break 562 | s += 2 * dualvar[bi] 563 | assert s >= 0 564 | if mate[i] // 2 == k or mate[j] // 2 == k: 565 | assert mate[i] // 2 == k and mate[j] // 2 == k 566 | assert s == 0 567 | # 2. all single vertices have zero dual value; 568 | for v in range(nvertex): 569 | assert mate[v] >= 0 or dualvar[v] + vdualoffset == 0 570 | # 3. all blossoms with positive dual value are full. 571 | for b in range(nvertex, 2*nvertex): 572 | if blossombase[b] >= 0 and dualvar[b] > 0: 573 | assert len(blossomendps[b]) % 2 == 1 574 | for p in blossomendps[b][1::2]: 575 | assert mate[endpoint[p]] == p ^ 1 576 | assert mate[endpoint[p ^ 1]] == p 577 | # Ok. 578 | 579 | # Check optimized delta2 against a trivial computation. 580 | def checkDelta2(): 581 | for v in range(nvertex): 582 | if label[inblossom[v]] == 0: 583 | bd = None 584 | bk = -1 585 | for p in neighbend[v]: 586 | k = p // 2 587 | w = endpoint[p] 588 | if label[inblossom[w]] == 1: 589 | d = slack(k) 590 | if bk == -1 or d < bd: 591 | bk = k 592 | bd = d 593 | if DEBUG and (bestedge[v] != -1 or bk != -1) and (bestedge[v] == -1 or bd != slack(bestedge[v])): 594 | DEBUG('v=' + str(v) + ' bk=' + str(bk) + ' bd=' + str(bd) + ' bestedge=' + str(bestedge[v]) + ' slack=' + str(slack(bestedge[v]))) 595 | assert (bk == -1 and bestedge[v] == -1) or (bestedge[v] != -1 and bd == slack(bestedge[v])) 596 | 597 | # Check optimized delta3 against a trivial computation. 598 | def checkDelta3(): 599 | bk = -1 600 | bd = None 601 | tbk = -1 602 | tbd = None 603 | for b in range(2 * nvertex): 604 | if blossomparent[b] == -1 and label[b] == 1: 605 | for v in blossomLeaves(b): 606 | for p in neighbend[v]: 607 | k = p // 2 608 | w = endpoint[p] 609 | if inblossom[w] != b and label[inblossom[w]] == 1: 610 | d = slack(k) 611 | if bk == -1 or d < bd: 612 | bk = k 613 | bd = d 614 | if bestedge[b] != -1: 615 | (i, j, wt) = edges[bestedge[b]] 616 | assert inblossom[i] == b or inblossom[j] == b 617 | assert inblossom[i] != b or inblossom[j] != b 618 | assert label[inblossom[i]] == 1 and label[inblossom[j]] == 1 619 | if tbk == -1 or slack(bestedge[b]) < tbd: 620 | tbk = bestedge[b] 621 | tbd = slack(bestedge[b]) 622 | if DEBUG and bd != tbd: 623 | DEBUG('bk=%d tbk=%d bd=%s tbd=%s' % (bk, tbk, repr(bd), repr(tbd))) 624 | assert bd == tbd 625 | 626 | # Main loop: continue until no further improvement is possible. 627 | for t in range(nvertex): 628 | 629 | # Each iteration of this loop is a "stage". 630 | # A stage finds an augmenting path and uses that to improve 631 | # the matching. 632 | if DEBUG: DEBUG('STAGE %d' % t) 633 | 634 | # Remove labels from top-level blossoms/vertices. 635 | label[:] = (2 * nvertex) * [ 0 ] 636 | 637 | # Forget all about least-slack edges. 638 | bestedge[:] = (2 * nvertex) * [ -1 ] 639 | blossombestedges[nvertex:] = nvertex * [ None ] 640 | 641 | # Loss of labeling means that we can not be sure that currently 642 | # allowable edges remain allowable througout this stage. 643 | allowedge[:] = nedge * [ False ] 644 | 645 | # Make queue empty. 
646 | queue[:] = [ ] 647 | 648 | # Label single blossoms/vertices with S and put them in the queue. 649 | for v in range(nvertex): 650 | if mate[v] == -1 and label[inblossom[v]] == 0: 651 | assignLabel(v, 1, -1) 652 | 653 | # Loop until we succeed in augmenting the matching. 654 | augmented = 0 655 | while 1: 656 | 657 | # Each iteration of this loop is a "substage". 658 | # A substage tries to find an augmenting path; 659 | # if found, the path is used to improve the matching and 660 | # the stage ends. If there is no augmenting path, the 661 | # primal-dual method is used to pump some slack out of 662 | # the dual variables. 663 | if DEBUG: DEBUG('SUBSTAGE') 664 | 665 | # Continue labeling until all vertices which are reachable 666 | # through an alternating path have got a label. 667 | while queue and not augmented: 668 | 669 | # Take an S vertex from the queue. 670 | v = queue.pop() 671 | if DEBUG: DEBUG('POP v=%d' % v) 672 | assert label[inblossom[v]] == 1 673 | 674 | # Scan its neighbours: 675 | for p in neighbend[v]: 676 | k = p // 2 677 | w = endpoint[p] 678 | # w is a neighbour to v 679 | if inblossom[v] == inblossom[w]: 680 | # this edge is internal to a blossom; ignore it 681 | continue 682 | if not allowedge[k]: 683 | kslack = slack(k) 684 | if kslack <= 0: 685 | # edge k has zero slack => it is allowable 686 | allowedge[k] = True 687 | if allowedge[k]: 688 | if label[inblossom[w]] == 0: 689 | # (C1) w is a free vertex; 690 | # label w with T and label its mate with S (R12). 691 | assignLabel(w, 2, p ^ 1) 692 | elif label[inblossom[w]] == 1: 693 | # (C2) w is an S-vertex (not in the same blossom); 694 | # follow back-links to discover either an 695 | # augmenting path or a new blossom. 696 | base = scanBlossom(v, w) 697 | if base >= 0: 698 | # Found a new blossom; add it to the blossom 699 | # bookkeeping and turn it into an S-blossom. 700 | addBlossom(base, k) 701 | else: 702 | # Found an augmenting path; augment the 703 | # matching and end this stage. 704 | augmentMatching(k) 705 | augmented = 1 706 | break 707 | elif label[w] == 0: 708 | # w is inside a T-blossom, but w itself has not 709 | # yet been reached from outside the blossom; 710 | # mark it as reached (we need this to relabel 711 | # during T-blossom expansion). 712 | assert label[inblossom[w]] == 2 713 | label[w] = 2 714 | labelend[w] = p ^ 1 715 | elif label[inblossom[w]] == 1: 716 | # keep track of the least-slack non-allowable edge to 717 | # a different S-blossom. 718 | b = inblossom[v] 719 | if bestedge[b] == -1 or kslack < slack(bestedge[b]): 720 | bestedge[b] = k 721 | elif label[w] == 0: 722 | # w is a free vertex (or an unreached vertex inside 723 | # a T-blossom) but we can not reach it yet; 724 | # keep track of the least-slack edge that reaches w. 725 | if bestedge[w] == -1 or kslack < slack(bestedge[w]): 726 | bestedge[w] = k 727 | 728 | if augmented: 729 | break 730 | 731 | # There is no augmenting path under these constraints; 732 | # compute delta and reduce slack in the optimization problem. 733 | # (Note that our vertex dual variables, edge slacks and delta's 734 | # are pre-multiplied by two.) 735 | deltatype = -1 736 | delta = deltaedge = deltablossom = None 737 | 738 | # Verify data structures for delta2/delta3 computation. 739 | if CHECK_DELTA: 740 | checkDelta2() 741 | checkDelta3() 742 | 743 | # Compute delta1: the minumum value of any vertex dual. 
744 | if not maxcardinality: 745 | deltatype = 1 746 | delta = min(dualvar[:nvertex]) 747 | 748 | # Compute delta2: the minimum slack on any edge between 749 | # an S-vertex and a free vertex. 750 | for v in range(nvertex): 751 | if label[inblossom[v]] == 0 and bestedge[v] != -1: 752 | d = slack(bestedge[v]) 753 | if deltatype == -1 or d < delta: 754 | delta = d 755 | deltatype = 2 756 | deltaedge = bestedge[v] 757 | 758 | # Compute delta3: half the minimum slack on any edge between 759 | # a pair of S-blossoms. 760 | for b in range(2 * nvertex): 761 | if ( blossomparent[b] == -1 and label[b] == 1 and 762 | bestedge[b] != -1 ): 763 | kslack = slack(bestedge[b]) 764 | if isinstance(kslack, integer_types): 765 | assert (kslack % 2) == 0 766 | d = kslack // 2 767 | else: 768 | d = kslack / 2 769 | if deltatype == -1 or d < delta: 770 | delta = d 771 | deltatype = 3 772 | deltaedge = bestedge[b] 773 | 774 | # Compute delta4: minimum z variable of any T-blossom. 775 | for b in range(nvertex, 2*nvertex): 776 | if ( blossombase[b] >= 0 and blossomparent[b] == -1 and 777 | label[b] == 2 and 778 | (deltatype == -1 or dualvar[b] < delta) ): 779 | delta = dualvar[b] 780 | deltatype = 4 781 | deltablossom = b 782 | 783 | if deltatype == -1: 784 | # No further improvement possible; max-cardinality optimum 785 | # reached. Do a final delta update to make the optimum 786 | # verifyable. 787 | assert maxcardinality 788 | deltatype = 1 789 | delta = max(0, min(dualvar[:nvertex])) 790 | 791 | # Update dual variables according to delta. 792 | for v in range(nvertex): 793 | if label[inblossom[v]] == 1: 794 | # S-vertex: 2*u = 2*u - 2*delta 795 | dualvar[v] -= delta 796 | elif label[inblossom[v]] == 2: 797 | # T-vertex: 2*u = 2*u + 2*delta 798 | dualvar[v] += delta 799 | for b in range(nvertex, 2*nvertex): 800 | if blossombase[b] >= 0 and blossomparent[b] == -1: 801 | if label[b] == 1: 802 | # top-level S-blossom: z = z + 2*delta 803 | dualvar[b] += delta 804 | elif label[b] == 2: 805 | # top-level T-blossom: z = z - 2*delta 806 | dualvar[b] -= delta 807 | 808 | # Take action at the point where minimum delta occurred. 809 | if DEBUG: DEBUG('delta%d=%f' % (deltatype, delta)) 810 | if deltatype == 1: 811 | # No further improvement possible; optimum reached. 812 | break 813 | elif deltatype == 2: 814 | # Use the least-slack edge to continue the search. 815 | allowedge[deltaedge] = True 816 | (i, j, wt) = edges[deltaedge] 817 | if label[inblossom[i]] == 0: 818 | i, j = j, i 819 | assert label[inblossom[i]] == 1 820 | queue.append(i) 821 | elif deltatype == 3: 822 | # Use the least-slack edge to continue the search. 823 | allowedge[deltaedge] = True 824 | (i, j, wt) = edges[deltaedge] 825 | assert label[inblossom[i]] == 1 826 | queue.append(i) 827 | elif deltatype == 4: 828 | # Expand the least-z blossom. 829 | expandBlossom(deltablossom, False) 830 | 831 | # End of a this substage. 832 | 833 | # Stop when no more augmenting path can be found. 834 | if not augmented: 835 | break 836 | 837 | # End of a stage; expand all S-blossoms which have dualvar = 0. 838 | for b in range(nvertex, 2*nvertex): 839 | if ( blossomparent[b] == -1 and blossombase[b] >= 0 and 840 | label[b] == 1 and dualvar[b] == 0 ): 841 | expandBlossom(b, True) 842 | 843 | # Verify that we reached the optimum solution. 844 | if CHECK_OPTIMUM: 845 | verifyOptimum() 846 | 847 | # Transform mate[] such that mate[v] is the vertex to which v is paired. 
848 |     for v in range(nvertex):
849 |         if mate[v] >= 0:
850 |             mate[v] = endpoint[mate[v]]
851 |     for v in range(nvertex):
852 |         assert mate[v] == -1 or mate[mate[v]] == v
853 | 
854 |     return mate
855 | 
856 | 
857 | # Unit tests
858 | if __name__ == '__main__':
859 |     x = maxWeightMatching([(1,4,10),(1,5,20),(2,5,40),(2,6,60),(3,4,30)])
860 |     print(x)
861 | 
862 | # end
--------------------------------------------------------------------------------
/lib/evaluate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xg-chu/CrowdDet/4a0c674c40fdcb3e2706d39544f088390f9f63fe/lib/evaluate/__init__.py
--------------------------------------------------------------------------------
/lib/evaluate/compute_APMR.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .APMRToolkits import *
3 | 
4 | dbName = 'human'
5 | def compute_APMR(dt_path, gt_path, target_key=None, mode=0):
6 |     database = Database(gt_path, dt_path, target_key, None, mode)
7 |     database.compare()
8 |     mAP,_ = database.eval_AP()
9 |     mMR,_ = database.eval_MR()
10 |     line = 'AP:{:.4f}, MR:{:.4f}.'.format(mAP, mMR)
11 |     return mAP, mMR
12 | 
13 | if __name__ == "__main__":
14 |     parser = argparse.ArgumentParser(description='Analyze a json result file with iou match')
15 |     parser.add_argument('--detfile', required=True, help='path of json result file to load')
16 |     parser.add_argument('--gtfile', required=True, help='path of ground-truth odgt file to load')
17 |     parser.add_argument('--target_key', default=None, required=True)
18 |     args = parser.parse_args()
19 |     compute_APMR(args.detfile, args.gtfile, args.target_key, 0)
20 | 
--------------------------------------------------------------------------------
/lib/evaluate/compute_JI.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import argparse
5 | from multiprocessing import Queue, Process
6 | 
7 | from tqdm import tqdm
8 | import numpy as np
9 | 
10 | from .JIToolkits.JI_tools import compute_matching, get_ignores
11 | sys.path.insert(0, '../')
12 | import utils.misc_utils as misc_utils
13 | 
14 | gtfile = '/data/annotation_val.odgt'
15 | nr_procs = 10
16 | 
17 | def evaluation_all(path, target_key):
18 |     records = misc_utils.load_json_lines(path)
19 |     res_line = []
20 |     res_JI = []
21 |     for i in range(10):
22 |         score_thr = 1e-1 * i
23 |         total = len(records)
24 |         stride = math.ceil(total / nr_procs)
25 |         result_queue = Queue(10000)
26 |         results, procs = [], []
27 |         for i in range(nr_procs):
28 |             start = i*stride
29 |             end = np.min([start+stride,total])
30 |             sample_data = records[start:end]
31 |             p = Process(target= compute_JI_with_ignore, args=(result_queue, sample_data, score_thr, target_key))
32 |             p.start()
33 |             procs.append(p)
34 |         tqdm.monitor_interval = 0
35 |         pbar = tqdm(total=total, leave = False, ascii = True)
36 |         for i in range(total):
37 |             t = result_queue.get()
38 |             results.append(t)
39 |             pbar.update(1)
40 |         for p in procs:
41 |             p.join()
42 |         pbar.close()
43 |         line, mean_ratio = gather(results)
44 |         line = 'score_thr:{:.1f}, {}'.format(score_thr, line)
45 |         print(line)
46 |         res_line.append(line)
47 |         res_JI.append(mean_ratio)
48 |     return res_line, max(res_JI)
49 | 
50 | def compute_JI_with_ignore(result_queue, records, score_thr, target_key, bm_thresh=0.5):
51 |     for record in records:
52 |         gt_boxes = misc_utils.load_bboxes(record, 'gtboxes', target_key, 'tag')
53 |         gt_boxes[:,2:4] += gt_boxes[:,:2]
54 |         gt_boxes = misc_utils.clip_boundary(gt_boxes, record['height'], record['width'])
55 | 
        dt_boxes = misc_utils.load_bboxes(record, 'dtboxes', target_key, 'score')
56 |         dt_boxes[:,2:4] += dt_boxes[:,:2]
57 |         dt_boxes = misc_utils.clip_boundary(dt_boxes, record['height'], record['width'])
58 |         keep = dt_boxes[:, -1] > score_thr
59 |         dt_boxes = dt_boxes[keep][:, :-1]
60 | 
61 |         gt_tag = np.array(gt_boxes[:,-1]!=-1)
62 |         matches = compute_matching(dt_boxes, gt_boxes[gt_tag, :4], bm_thresh)
63 |         # get the unmatched_indices
64 |         matched_indices = np.array([j for (j,_) in matches])
65 |         unmatched_indices = list(set(np.arange(dt_boxes.shape[0])) - set(matched_indices))
66 |         num_ignore_dt = get_ignores(dt_boxes[unmatched_indices], gt_boxes[~gt_tag, :4], bm_thresh)
67 |         matched_indices = np.array([j for (_,j) in matches])
68 |         unmatched_indices = list(set(np.arange(gt_boxes[gt_tag].shape[0])) - set(matched_indices))
69 |         num_ignore_gt = get_ignores(gt_boxes[gt_tag][unmatched_indices], gt_boxes[~gt_tag, :4], bm_thresh)
70 |         # compute results
71 |         eps = 1e-6
72 |         k = len(matches)
73 |         m = gt_tag.sum() - num_ignore_gt
74 |         n = dt_boxes.shape[0] - num_ignore_dt
75 |         ratio = k / (m + n - k + eps)
76 |         recall = k / (m + eps)
77 |         cover = k / (n + eps)
78 |         noise = 1 - cover
79 |         result_dict = dict(ratio = ratio, recall = recall, cover = cover,
80 |             noise = noise, k = k, m = m, n = n)
81 |         result_queue.put_nowait(result_dict)
82 | 
83 | def gather(results):
84 |     assert len(results)
85 |     img_num = 0
86 |     for result in results:
87 |         if result['n'] != 0 or result['m'] != 0:
88 |             img_num += 1
89 |     mean_ratio = np.sum([rb['ratio'] for rb in results]) / img_num
90 |     mean_cover = np.sum([rb['cover'] for rb in results]) / img_num
91 |     mean_recall = np.sum([rb['recall'] for rb in results]) / img_num
92 |     mean_noise = 1 - mean_cover
93 |     valids = np.sum([rb['k'] for rb in results])
94 |     total = np.sum([rb['n'] for rb in results])
95 |     gtn = np.sum([rb['m'] for rb in results])
96 | 
97 |     #line = 'mean_ratio:{:.4f}, mean_cover:{:.4f}, mean_recall:{:.4f}, mean_noise:{:.4f}, valids:{}, total:{}, gtn:{}'.format(
98 |     #    mean_ratio, mean_cover, mean_recall, mean_noise, valids, total, gtn)
99 |     line = 'mean_ratio:{:.4f}, valids:{}, total:{}, gtn:{}'.format(
100 |         mean_ratio, valids, total, gtn)
101 |     return line, mean_ratio
102 | 
103 | def common_process(func, cls_list, nr_procs):
104 |     total = len(cls_list)
105 |     stride = math.ceil(total / nr_procs)
106 |     result_queue = Queue(10000)
107 |     results, procs = [], []
108 |     for i in range(nr_procs):
109 |         start = i*stride
110 |         end = np.min([start+stride,total])
111 |         sample_data = cls_list[start:end]
112 |         p = Process(target= func,args=(result_queue, sample_data))
113 |         p.start()
114 |         procs.append(p)
115 |     for i in range(total):
116 |         t = result_queue.get()
117 |         if t is None:
118 |             continue
119 |         results.append(t)
120 |     for p in procs:
121 |         p.join()
122 |     return results
123 | 
124 | if __name__ == "__main__":
125 |     parser = argparse.ArgumentParser(description='Analyze a json result file with iou match')
126 |     parser.add_argument('--detfile', required=True, help='path of json result file to load')
127 |     parser.add_argument('--target_key', required=True)
128 |     args = parser.parse_args()
129 |     evaluation_all(args.detfile, args.target_key)
130 | 
--------------------------------------------------------------------------------
/lib/layers/batch_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | class FrozenBatchNorm2d(nn.Module):
5 |     """
6 |     BatchNorm2d where the batch statistics and 
the affine parameters are fixed. 7 | """ 8 | 9 | def __init__(self, num_features, eps=1e-5): 10 | super().__init__() 11 | self.eps = eps 12 | self.register_buffer("weight", torch.ones(num_features)) 13 | self.register_buffer("bias", torch.zeros(num_features)) 14 | self.register_buffer("running_mean", torch.zeros(num_features)) 15 | self.register_buffer("running_var", torch.ones(num_features) - eps) 16 | 17 | def forward(self, x): 18 | scale = self.weight * (self.running_var + self.eps).rsqrt() 19 | bias = self.bias - self.running_mean * scale 20 | scale = scale.reshape(1, -1, 1, 1) 21 | bias = bias.reshape(1, -1, 1, 1) 22 | return x * scale + bias 23 | 24 | def __repr__(self): 25 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) 26 | 27 | -------------------------------------------------------------------------------- /lib/layers/pooler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torchvision.ops import roi_align 5 | 6 | def assign_boxes_to_levels(rois, min_level, max_level, canonical_box_size=224, canonical_level=4): 7 | """ 8 | rois (Tensor): A tensor of shape (N, 5). 9 | min_level (int), max_level (int), canonical_box_size (int), canonical_level (int). 10 | Return a tensor of length N. 11 | """ 12 | eps = 1e-6 13 | box_sizes = torch.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2])) 14 | # Eqn.(1) in FPN paper 15 | level_assignments = torch.floor( 16 | canonical_level + torch.log2(box_sizes / canonical_box_size + eps) 17 | ) 18 | # clamp level to (min, max), in case the box size is too large or too small 19 | # for the available feature maps 20 | level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level) 21 | return level_assignments.to(torch.int64) - min_level 22 | 23 | def roi_pooler(fpn_fms, rois, stride, pool_shape, pooler_type): 24 | if pooler_type == "ROIAlign": 25 | pooler_aligned = False 26 | elif pooler_type == "ROIAlignV2": 27 | pooler_aligned = True 28 | else: 29 | raise ValueError("Unknown pooler type: {}".format(pooler_type)) 30 | assert len(fpn_fms) == len(stride) 31 | max_level = int(math.log2(stride[-1])) 32 | min_level = int(math.log2(stride[0])) 33 | assert (len(stride) == max_level - min_level + 1) 34 | level_assignments = assign_boxes_to_levels(rois, min_level, max_level, 224, 4) 35 | dtype, device = fpn_fms[0].dtype, fpn_fms[0].device 36 | output = torch.zeros((len(rois), fpn_fms[0].shape[1], pool_shape[0], pool_shape[1]), 37 | dtype=dtype, device=device) 38 | for level, (fm_level, scale_level) in enumerate(zip(fpn_fms, stride)): 39 | inds = torch.nonzero(level_assignments == level, as_tuple=False).squeeze(1) 40 | rois_level = rois[inds] 41 | output[inds] = roi_align(fm_level, rois_level, pool_shape, spatial_scale=1.0/scale_level, 42 | sampling_ratio=-1, aligned=pooler_aligned) 43 | return output 44 | 45 | -------------------------------------------------------------------------------- /lib/module/rpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from config import config 7 | from det_oprs.anchors_generator import AnchorGenerator 8 | from det_oprs.find_top_rpn_proposals import find_top_rpn_proposals 9 | from det_oprs.fpn_anchor_target import fpn_anchor_target, fpn_rpn_reshape 10 | from det_oprs.loss_opr import softmax_loss, smooth_l1_loss 11 | 12 | class 
RPN(nn.Module): 13 | def __init__(self, rpn_channel = 256): 14 | super().__init__() 15 | self.anchors_generator = AnchorGenerator( 16 | config.anchor_base_size, 17 | config.anchor_aspect_ratios, 18 | config.anchor_base_scale) 19 | self.rpn_conv = nn.Conv2d(256, rpn_channel, kernel_size=3, stride=1, padding=1) 20 | self.rpn_cls_score = nn.Conv2d(rpn_channel, config.num_cell_anchors * 2, kernel_size=1, stride=1) 21 | self.rpn_bbox_offsets = nn.Conv2d(rpn_channel, config.num_cell_anchors * 4, kernel_size=1, stride=1) 22 | 23 | for l in [self.rpn_conv, self.rpn_cls_score, self.rpn_bbox_offsets]: 24 | nn.init.normal_(l.weight, std=0.01) 25 | nn.init.constant_(l.bias, 0) 26 | 27 | def forward(self, features, im_info, boxes=None): 28 | # prediction 29 | pred_cls_score_list = [] 30 | pred_bbox_offsets_list = [] 31 | for x in features: 32 | t = F.relu(self.rpn_conv(x)) 33 | pred_cls_score_list.append(self.rpn_cls_score(t)) 34 | pred_bbox_offsets_list.append(self.rpn_bbox_offsets(t)) 35 | # get anchors 36 | all_anchors_list = [] 37 | # stride: 64,32,16,8,4 p6->p2 38 | base_stride = 4 39 | off_stride = 2**(len(features)-1) # 16 40 | for fm in features: 41 | layer_anchors = self.anchors_generator(fm, base_stride, off_stride) 42 | off_stride = off_stride // 2 43 | all_anchors_list.append(layer_anchors) 44 | # sample from the predictions 45 | rpn_rois = find_top_rpn_proposals( 46 | self.training, pred_bbox_offsets_list, pred_cls_score_list, 47 | all_anchors_list, im_info) 48 | rpn_rois = rpn_rois.type_as(features[0]) 49 | if self.training: 50 | rpn_labels, rpn_bbox_targets = fpn_anchor_target( 51 | boxes, im_info, all_anchors_list) 52 | #rpn_labels = rpn_labels.astype(np.int32) 53 | pred_cls_score, pred_bbox_offsets = fpn_rpn_reshape( 54 | pred_cls_score_list, pred_bbox_offsets_list) 55 | # rpn loss 56 | valid_masks = rpn_labels >= 0 57 | objectness_loss = softmax_loss( 58 | pred_cls_score[valid_masks], 59 | rpn_labels[valid_masks]) 60 | 61 | pos_masks = rpn_labels > 0 62 | localization_loss = smooth_l1_loss( 63 | pred_bbox_offsets[pos_masks], 64 | rpn_bbox_targets[pos_masks], 65 | config.rpn_smooth_l1_beta) 66 | normalizer = 1 / valid_masks.sum().item() 67 | loss_rpn_cls = objectness_loss.sum() * normalizer 68 | loss_rpn_loc = localization_loss.sum() * normalizer 69 | loss_dict = {} 70 | loss_dict['loss_rpn_cls'] = loss_rpn_cls 71 | loss_dict['loss_rpn_loc'] = loss_rpn_loc 72 | return rpn_rois, loss_dict 73 | else: 74 | return rpn_rois 75 | 76 | -------------------------------------------------------------------------------- /lib/utils/SGD_bias.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | class SGD(Optimizer): 5 | """Implements stochastic gradient descent (optionally with momentum). 
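    Unlike the stock torch.optim.SGD, the step() below applies weight decay only to
    parameters with dim() > 1, so biases and normalization affine parameters are left
    undecayed (presumably the reason for the module name SGD_bias).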
6 | Args: 7 | params (iterable): iterable of parameters to optimize or dicts defining 8 | parameter groups 9 | lr (float): learning rate 10 | momentum (float, optional): momentum factor (default: 0) 11 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 12 | dampening (float, optional): dampening for momentum (default: 0) 13 | nesterov (bool, optional): enables Nesterov momentum (default: False) 14 | """ 15 | 16 | def __init__(self, params, lr=required, momentum=0, dampening=0, 17 | weight_decay=0, nesterov=False): 18 | if lr is not required and lr < 0.0: 19 | raise ValueError("Invalid learning rate: {}".format(lr)) 20 | if momentum < 0.0: 21 | raise ValueError("Invalid momentum value: {}".format(momentum)) 22 | if weight_decay < 0.0: 23 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 24 | 25 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 26 | weight_decay=weight_decay, nesterov=nesterov) 27 | if nesterov and (momentum <= 0 or dampening != 0): 28 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 29 | super(SGD, self).__init__(params, defaults) 30 | 31 | def __setstate__(self, state): 32 | super(SGD, self).__setstate__(state) 33 | for group in self.param_groups: 34 | group.setdefault('nesterov', False) 35 | 36 | @torch.no_grad() 37 | def step(self): 38 | """Performs a single optimization step. 39 | """ 40 | loss = None 41 | for group in self.param_groups: 42 | weight_decay = group['weight_decay'] 43 | momentum = group['momentum'] 44 | dampening = group['dampening'] 45 | nesterov = group['nesterov'] 46 | 47 | for p in group['params']: 48 | if p.grad is None: 49 | continue 50 | d_p = p.grad 51 | if weight_decay != 0 and p.dim() > 1: 52 | d_p = d_p.add(p, alpha=weight_decay) 53 | if momentum != 0: 54 | param_state = self.state[p] 55 | if 'momentum_buffer' not in param_state: 56 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 57 | else: 58 | buf = param_state['momentum_buffer'] 59 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 60 | if nesterov: 61 | d_p = d_p.add(buf, alpha=momentum) 62 | else: 63 | d_p = buf 64 | 65 | p.add_(d_p, alpha=-group['lr']) 66 | 67 | return loss 68 | 69 | -------------------------------------------------------------------------------- /lib/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | def load_img(image_path): 6 | import cv2 7 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 8 | return img 9 | 10 | def load_json_lines(fpath): 11 | assert os.path.exists(fpath) 12 | with open(fpath,'r') as fid: 13 | lines = fid.readlines() 14 | records = [json.loads(line.strip('\n')) for line in lines] 15 | return records 16 | 17 | def save_json_lines(content,fpath): 18 | with open(fpath,'w') as fid: 19 | for db in content: 20 | line = json.dumps(db)+'\n' 21 | fid.write(line) 22 | 23 | def device_parser(str_device): 24 | if '-' in str_device: 25 | device_id = str_device.split('-') 26 | device_id = [i for i in range(int(device_id[0]), int(device_id[1])+1)] 27 | else: 28 | device_id = [int(str_device)] 29 | return device_id 30 | 31 | def ensure_dir(dirpath): 32 | if not os.path.exists(dirpath): 33 | os.makedirs(dirpath) 34 | 35 | def xyxy_to_xywh(boxes): 36 | assert boxes.shape[1]>=4 37 | boxes[:, 2:4] -= boxes[:,:2] 38 | return boxes 39 | 40 | def xywh_to_xyxy(boxes): 41 | assert boxes.shape[1]>=4 42 | boxes[:, 2:4] += boxes[:,:2] 43 | return boxes 44 | 45 
| def load_bboxes(dict_input, key_name, key_box, key_score=None, key_tag=None): 46 | assert key_name in dict_input 47 | if len(dict_input[key_name]) < 1: 48 | return np.empty([0, 5]) 49 | else: 50 | assert key_box in dict_input[key_name][0] 51 | if key_score: 52 | assert key_score in dict_input[key_name][0] 53 | if key_tag: 54 | assert key_tag in dict_input[key_name][0] 55 | if key_score: 56 | if key_tag: 57 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score], rb[key_tag])) for rb in dict_input[key_name]]) 58 | else: 59 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score])) for rb in dict_input[key_name]]) 60 | else: 61 | if key_tag: 62 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_tag])) for rb in dict_input[key_name]]) 63 | else: 64 | bboxes = np.vstack([rb[key_box] for rb in dict_input[key_name]]) 65 | return bboxes 66 | 67 | def load_masks(dict_input, key_name, key_box): 68 | assert key_name in dict_input 69 | if len(dict_input[key_name]) < 1: 70 | return np.empty([0, 28, 28]) 71 | else: 72 | assert key_box in dict_input[key_name][0] 73 | masks = np.array([rb[key_box] for rb in dict_input[key_name]]) 74 | return masks 75 | 76 | def load_gt(dict_input, key_name, key_box, class_names): 77 | assert key_name in dict_input 78 | if len(dict_input[key_name]) < 1: 79 | return np.empty([0, 5]) 80 | else: 81 | assert key_box in dict_input[key_name][0] 82 | bbox = [] 83 | for rb in dict_input[key_name]: 84 | if rb['tag'] in class_names: 85 | tag = class_names.index(rb['tag']) 86 | else: 87 | tag = -1 88 | if 'extra' in rb: 89 | if 'ignore' in rb['extra']: 90 | if rb['extra']['ignore'] != 0: 91 | tag = -1 92 | bbox.append(np.hstack((rb[key_box], tag))) 93 | bboxes = np.vstack(bbox).astype(np.float64) 94 | return bboxes 95 | 96 | def boxes_dump(boxes, is_gt): 97 | result = [] 98 | boxes = boxes.tolist() 99 | for box in boxes: 100 | if is_gt: 101 | box_dict = {} 102 | box_dict['box'] = [box[0], box[1], box[2]-box[0], box[3]-box[1]] 103 | box_dict['tag'] = box[-1] 104 | result.append(box_dict) 105 | else: 106 | box_dict = {} 107 | box_dict['box'] = [box[0], box[1], box[2]-box[0], box[3]-box[1]] 108 | box_dict['tag'] = 1 109 | box_dict['score'] = box[-1] 110 | result.append(box_dict) 111 | return result 112 | 113 | def clip_boundary(boxes,height,width): 114 | assert boxes.shape[-1]>=4 115 | boxes[:,0] = np.minimum(np.maximum(boxes[:,0],0), width - 1) 116 | boxes[:,1] = np.minimum(np.maximum(boxes[:,1],0), height - 1) 117 | boxes[:,2] = np.maximum(np.minimum(boxes[:,2],width), 0) 118 | boxes[:,3] = np.maximum(np.minimum(boxes[:,3],height), 0) 119 | return boxes 120 | 121 | -------------------------------------------------------------------------------- /lib/utils/nms_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | 4 | def set_cpu_nms(dets, thresh): 5 | """Pure Python NMS baseline.""" 6 | def _overlap(det_boxes, basement, others): 7 | eps = 1e-8 8 | x1_basement, y1_basement, x2_basement, y2_basement \ 9 | = det_boxes[basement, 0], det_boxes[basement, 1], \ 10 | det_boxes[basement, 2], det_boxes[basement, 3] 11 | x1_others, y1_others, x2_others, y2_others \ 12 | = det_boxes[others, 0], det_boxes[others, 1], \ 13 | det_boxes[others, 2], det_boxes[others, 3] 14 | areas_basement = (x2_basement - x1_basement) * (y2_basement - y1_basement) 15 | areas_others = (x2_others - x1_others) * (y2_others - y1_others) 16 | xx1 = np.maximum(x1_basement, x1_others) 17 | yy1 = np.maximum(y1_basement, y1_others) 18 | 
xx2 = np.minimum(x2_basement, x2_others) 19 | yy2 = np.minimum(y2_basement, y2_others) 20 | w = np.maximum(0.0, xx2 - xx1) 21 | h = np.maximum(0.0, yy2 - yy1) 22 | inter = w * h 23 | ovr = inter / (areas_basement + areas_others - inter + eps) 24 | return ovr 25 | scores = dets[:, 4] 26 | order = np.argsort(-scores) 27 | dets = dets[order] 28 | 29 | numbers = dets[:, -1] 30 | keep = np.ones(len(dets)) == 1 31 | ruler = np.arange(len(dets)) 32 | while ruler.size>0: 33 | basement = ruler[0] 34 | ruler=ruler[1:] 35 | num = numbers[basement] 36 | # calculate the body overlap 37 | overlap = _overlap(dets[:, :4], basement, ruler) 38 | indices = np.where(overlap > thresh)[0] 39 | loc = np.where(numbers[ruler][indices] == num)[0] 40 | # the mask won't change in the step 41 | mask = keep[ruler[indices][loc]]#.copy() 42 | keep[ruler[indices]] = False 43 | keep[ruler[indices][loc][mask]] = True 44 | ruler[~keep[ruler]] = -1 45 | ruler = ruler[ruler>0] 46 | keep = keep[np.argsort(order)] 47 | return keep 48 | 49 | def cpu_nms(dets, base_thr): 50 | """Pure Python NMS baseline.""" 51 | x1 = dets[:, 0] 52 | y1 = dets[:, 1] 53 | x2 = dets[:, 2] 54 | y2 = dets[:, 3] 55 | scores = dets[:, 4] 56 | 57 | areas = (x2 - x1) * (y2 - y1) 58 | order = np.argsort(-scores) 59 | 60 | keep = [] 61 | eps = 1e-8 62 | while len(order) > 0: 63 | i = order[0] 64 | keep.append(i) 65 | xx1 = np.maximum(x1[i], x1[order[1:]]) 66 | yy1 = np.maximum(y1[i], y1[order[1:]]) 67 | xx2 = np.minimum(x2[i], x2[order[1:]]) 68 | yy2 = np.minimum(y2[i], y2[order[1:]]) 69 | 70 | w = np.maximum(0.0, xx2 - xx1) 71 | h = np.maximum(0.0, yy2 - yy1) 72 | inter = w * h 73 | ovr = inter / (areas[i] + areas[order[1:]] - inter + eps) 74 | 75 | inds = np.where(ovr <= base_thr)[0] 76 | indices = np.where(ovr > base_thr)[0] 77 | order = order[inds + 1] 78 | return np.array(keep) 79 | 80 | def _test(): 81 | box1 = np.array([33,45,145,230,0.7])[None,:] 82 | box2 = np.array([44,54,123,348,0.8])[None,:] 83 | box3 = np.array([88,12,340,342,0.65])[None,:] 84 | boxes = np.concatenate([box1,box2,box3],axis = 0) 85 | nms_thresh = 0.5 86 | keep = cpu_nms(boxes, nms_thresh) 87 | alive_boxes = boxes[keep] 88 | 89 | if __name__=='__main__': 90 | _test() 91 | -------------------------------------------------------------------------------- /lib/utils/visual_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import cv2 5 | 6 | color = {'green':(0,255,0), 7 | 'blue':(255,165,0), 8 | 'dark red':(0,0,139), 9 | 'red':(0, 0, 255), 10 | 'dark slate blue':(139,61,72), 11 | 'aqua':(255,255,0), 12 | 'brown':(42,42,165), 13 | 'deep pink':(147,20,255), 14 | 'fuchisia':(255,0,255), 15 | 'yello':(0,238,238), 16 | 'orange':(0,165,255), 17 | 'saddle brown':(19,69,139), 18 | 'black':(0,0,0), 19 | 'white':(255,255,255)} 20 | 21 | def draw_boxes(img, boxes, scores=None, tags=None, line_thick=1, line_color='white'): 22 | width = img.shape[1] 23 | height = img.shape[0] 24 | for i in range(len(boxes)): 25 | one_box = boxes[i] 26 | one_box = np.array([max(one_box[0], 0), max(one_box[1], 0), 27 | min(one_box[2], width - 1), min(one_box[3], height - 1)]) 28 | x1,y1,x2,y2 = np.array(one_box[:4]).astype(int) 29 | cv2.rectangle(img, (x1,y1), (x2,y2), color[line_color], line_thick) 30 | if scores is not None: 31 | text = "{} {:.3f}".format(tags[i], scores[i]) 32 | cv2.putText(img, text, (x1, y1 - 7), cv2.FONT_ITALIC, 0.5, color[line_color], line_thick) 33 | return img 34 | 35 |
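# A minimal usage sketch for draw_boxes (illustrative only: the image path, the
# detection array and the 0.3 threshold are placeholders; 0.3 mirrors the
# visulize_threshold field used in the model configs):
#
#   import numpy as np
#   img = cv2.imread('example.jpg', cv2.IMREAD_COLOR)
#   dets = np.array([[50., 60., 180., 400., 0.92],
#                    [300., 40., 420., 380., 0.81]])   # x1, y1, x2, y2, score
#   keep = dets[:, -1] > 0.3
#   img = draw_boxes(img, dets[keep, :4], scores=dets[keep, -1],
#                    tags=['person'] * int(keep.sum()),
#                    line_thick=2, line_color='green')
#   cv2.imwrite('example_vis.jpg', img)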
-------------------------------------------------------------------------------- /model/rcnn_emd_refine/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | def add_path(path): 7 | if path not in sys.path: 8 | sys.path.insert(0, path) 9 | 10 | root_dir = '../../' 11 | add_path(os.path.join(root_dir)) 12 | add_path(os.path.join(root_dir, 'lib')) 13 | 14 | class Crowd_human: 15 | class_names = ['background', 'person'] 16 | num_classes = len(class_names) 17 | root_folder = '/data/CrowdHuman' 18 | image_folder = '/data/CrowdHuman/images' 19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt') 20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt') 21 | 22 | class Config: 23 | output_dir = 'outputs' 24 | model_dir = os.path.join(output_dir, 'model_dump') 25 | eval_dir = os.path.join(output_dir, 'eval_dump') 26 | init_weights = '/data/model/resnet50_fbaug.pth' 27 | 28 | # ----------data config---------- # 29 | image_mean = np.array([103.530, 116.280, 123.675]) 30 | image_std = np.array([57.375, 57.120, 58.395]) 31 | train_image_short_size = 800 32 | train_image_max_size = 1400 33 | eval_resize = True 34 | eval_image_short_size = 800 35 | eval_image_max_size = 1400 36 | seed_dataprovider = 3 37 | train_source = Crowd_human.train_source 38 | eval_source = Crowd_human.eval_source 39 | image_folder = Crowd_human.image_folder 40 | class_names = Crowd_human.class_names 41 | num_classes = Crowd_human.num_classes 42 | class_names2id = dict(list(zip(class_names, list(range(num_classes))))) 43 | gt_boxes_name = 'fbox' 44 | 45 | # ----------train config---------- # 46 | backbone_freeze_at = 2 47 | rpn_channel = 256 48 | 49 | train_batch_per_gpu = 2 50 | momentum = 0.9 51 | weight_decay = 1e-4 52 | base_lr = 1e-3 * 1.25 53 | 54 | warm_iter = 800 55 | max_epoch = 30 56 | lr_decay = [24, 27] 57 | nr_images_epoch = 15000 58 | log_dump_interval = 20 59 | 60 | # ----------test config---------- # 61 | test_nms = 0.5 62 | test_nms_method = 'set_nms' 63 | visulize_threshold = 0.3 64 | pred_cls_threshold = 0.01 65 | 66 | # ----------model config---------- # 67 | batch_filter_box_size = 0 68 | nr_box_dim = 5 69 | ignore_label = -1 70 | max_boxes_of_image = 500 71 | 72 | # ----------rois generator config---------- # 73 | anchor_base_size = 32 74 | anchor_base_scale = [1] 75 | anchor_aspect_ratios = [1, 2, 3] 76 | num_cell_anchors = len(anchor_aspect_ratios) 77 | anchor_within_border = False 78 | 79 | rpn_min_box_size = 2 80 | rpn_nms_threshold = 0.7 81 | train_prev_nms_top_n = 12000 82 | train_post_nms_top_n = 2000 83 | test_prev_nms_top_n = 6000 84 | test_post_nms_top_n = 1000 85 | 86 | # ----------binding&training config---------- # 87 | rpn_smooth_l1_beta = 1 88 | rcnn_smooth_l1_beta = 1 89 | 90 | num_sample_anchors = 256 91 | positive_anchor_ratio = 0.5 92 | rpn_positive_overlap = 0.7 93 | rpn_negative_overlap = 0.3 94 | rpn_bbox_normalize_targets = False 95 | 96 | num_rois = 512 97 | fg_ratio = 0.5 98 | fg_threshold = 0.5 99 | bg_threshold_high = 0.5 100 | bg_threshold_low = 0.0 101 | rcnn_bbox_normalize_targets = True 102 | bbox_normalize_means = np.array([0, 0, 0, 0]) 103 | bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2]) 104 | 105 | config = Config() 106 | 107 | -------------------------------------------------------------------------------- /model/rcnn_emd_refine/network.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from config import config 7 | from backbone.resnet50 import ResNet50 8 | from backbone.fpn import FPN 9 | from module.rpn import RPN 10 | from layers.pooler import roi_pooler 11 | from det_oprs.bbox_opr import bbox_transform_inv_opr 12 | from det_oprs.fpn_roi_target import fpn_roi_target 13 | from det_oprs.loss_opr import emd_loss_softmax 14 | from det_oprs.utils import get_padded_tensor 15 | 16 | class Network(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False) 20 | self.FPN = FPN(self.resnet50, 2, 6) 21 | self.RPN = RPN(config.rpn_channel) 22 | self.RCNN = RCNN() 23 | assert config.num_classes == 2, 'Only support two class(1fg/1bg).' 24 | 25 | def forward(self, image, im_info, gt_boxes=None): 26 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / ( 27 | torch.tensor(config.image_std[None, :, None, None]).type_as(image)) 28 | image = get_padded_tensor(image, 64) 29 | if self.training: 30 | return self._forward_train(image, im_info, gt_boxes) 31 | else: 32 | return self._forward_test(image, im_info) 33 | 34 | def _forward_train(self, image, im_info, gt_boxes): 35 | loss_dict = {} 36 | fpn_fms = self.FPN(image) 37 | # fpn_fms stride: 64,32,16,8,4, p6->p2 38 | rpn_rois, loss_dict_rpn = self.RPN(fpn_fms, im_info, gt_boxes) 39 | rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target( 40 | rpn_rois, im_info, gt_boxes, top_k=2) 41 | loss_dict_rcnn = self.RCNN(fpn_fms, rcnn_rois, 42 | rcnn_labels, rcnn_bbox_targets) 43 | loss_dict.update(loss_dict_rpn) 44 | loss_dict.update(loss_dict_rcnn) 45 | return loss_dict 46 | 47 | def _forward_test(self, image, im_info): 48 | fpn_fms = self.FPN(image) 49 | rpn_rois = self.RPN(fpn_fms, im_info) 50 | pred_bbox = self.RCNN(fpn_fms, rpn_rois) 51 | return pred_bbox.cpu().detach() 52 | 53 | class RCNN(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | # roi head 57 | self.fc1 = nn.Linear(256*7*7, 1024) 58 | self.fc2 = nn.Linear(1024, 1024) 59 | self.fc3 = nn.Linear(1044, 1024) 60 | 61 | for l in [self.fc1, self.fc2, self.fc3]: 62 | nn.init.kaiming_uniform_(l.weight, a=1) 63 | nn.init.constant_(l.bias, 0) 64 | # box predictor 65 | self.emd_pred_cls_0 = nn.Linear(1024, config.num_classes) 66 | self.emd_pred_delta_0 = nn.Linear(1024, config.num_classes * 4) 67 | self.emd_pred_cls_1 = nn.Linear(1024, config.num_classes) 68 | self.emd_pred_delta_1 = nn.Linear(1024, config.num_classes * 4) 69 | self.ref_pred_cls_0 = nn.Linear(1024, config.num_classes) 70 | self.ref_pred_delta_0 = nn.Linear(1024, config.num_classes * 4) 71 | self.ref_pred_cls_1 = nn.Linear(1024, config.num_classes) 72 | self.ref_pred_delta_1 = nn.Linear(1024, config.num_classes * 4) 73 | for l in [self.emd_pred_cls_0, self.emd_pred_cls_1, 74 | self.ref_pred_cls_0, self.ref_pred_cls_1]: 75 | nn.init.normal_(l.weight, std=0.001) 76 | nn.init.constant_(l.bias, 0) 77 | for l in [self.emd_pred_delta_0, self.emd_pred_delta_1, 78 | self.ref_pred_delta_0, self.ref_pred_delta_1]: 79 | nn.init.normal_(l.weight, std=0.001) 80 | nn.init.constant_(l.bias, 0) 81 | 82 | def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None): 83 | # stride: 64,32,16,8,4 -> 4, 8, 16, 32 84 | fpn_fms = fpn_fms[1:][::-1] 85 | stride = [4, 8, 16, 32] 86 | pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7), "ROIAlignV2") 87 | flatten_feature = torch.flatten(pool_features, start_dim=1) 88 | 
flatten_feature = F.relu_(self.fc1(flatten_feature)) 89 | flatten_feature = F.relu_(self.fc2(flatten_feature)) 90 | pred_emd_cls_0 = self.emd_pred_cls_0(flatten_feature) 91 | pred_emd_delta_0 = self.emd_pred_delta_0(flatten_feature) 92 | pred_emd_cls_1 = self.emd_pred_cls_1(flatten_feature) 93 | pred_emd_delta_1 = self.emd_pred_delta_1(flatten_feature) 94 | pred_emd_scores_0 = F.softmax(pred_emd_cls_0, dim=-1) 95 | pred_emd_scores_1 = F.softmax(pred_emd_cls_1, dim=-1) 96 | # cons refine feature 97 | boxes_feature_0 = torch.cat((pred_emd_delta_0[:, 4:], 98 | pred_emd_scores_0[:, 1][:, None]), dim=1).repeat(1, 4) 99 | boxes_feature_1 = torch.cat((pred_emd_delta_1[:, 4:], 100 | pred_emd_scores_1[:, 1][:, None]), dim=1).repeat(1, 4) 101 | boxes_feature_0 = torch.cat((flatten_feature, boxes_feature_0), dim=1) 102 | boxes_feature_1 = torch.cat((flatten_feature, boxes_feature_1), dim=1) 103 | refine_feature_0 = F.relu_(self.fc3(boxes_feature_0)) 104 | refine_feature_1 = F.relu_(self.fc3(boxes_feature_1)) 105 | # refine 106 | pred_ref_cls_0 = self.ref_pred_cls_0(refine_feature_0) 107 | pred_ref_delta_0 = self.ref_pred_delta_0(refine_feature_0) 108 | pred_ref_cls_1 = self.ref_pred_cls_1(refine_feature_1) 109 | pred_ref_delta_1 = self.ref_pred_delta_1(refine_feature_1) 110 | if self.training: 111 | loss0 = emd_loss_softmax( 112 | pred_emd_delta_0, pred_emd_cls_0, 113 | pred_emd_delta_1, pred_emd_cls_1, 114 | bbox_targets, labels) 115 | loss1 = emd_loss_softmax( 116 | pred_emd_delta_1, pred_emd_cls_1, 117 | pred_emd_delta_0, pred_emd_cls_0, 118 | bbox_targets, labels) 119 | loss2 = emd_loss_softmax( 120 | pred_ref_delta_0, pred_ref_cls_0, 121 | pred_ref_delta_1, pred_ref_cls_1, 122 | bbox_targets, labels) 123 | loss3 = emd_loss_softmax( 124 | pred_ref_delta_1, pred_ref_cls_1, 125 | pred_ref_delta_0, pred_ref_cls_0, 126 | bbox_targets, labels) 127 | loss_rcnn = torch.cat([loss0, loss1], axis=1) 128 | loss_ref = torch.cat([loss2, loss3], axis=1) 129 | # requires_grad = False 130 | _, min_indices_rcnn = loss_rcnn.min(axis=1) 131 | _, min_indices_ref = loss_ref.min(axis=1) 132 | loss_rcnn = loss_rcnn[torch.arange(loss_rcnn.shape[0]), min_indices_rcnn] 133 | loss_rcnn = loss_rcnn.mean() 134 | loss_ref = loss_ref[torch.arange(loss_ref.shape[0]), min_indices_ref] 135 | loss_ref = loss_ref.mean() 136 | loss_dict = {} 137 | loss_dict['loss_rcnn_emd'] = loss_rcnn 138 | loss_dict['loss_ref_emd'] = loss_ref 139 | return loss_dict 140 | else: 141 | class_num = pred_ref_cls_0.shape[-1] - 1 142 | tag = torch.arange(class_num).type_as(pred_ref_cls_0)+1 143 | tag = tag.repeat(pred_ref_cls_0.shape[0], 1).reshape(-1,1) 144 | pred_scores_0 = F.softmax(pred_ref_cls_0, dim=-1)[:, 1:].reshape(-1, 1) 145 | pred_scores_1 = F.softmax(pred_ref_cls_1, dim=-1)[:, 1:].reshape(-1, 1) 146 | pred_delta_0 = pred_ref_delta_0[:, 4:].reshape(-1, 4) 147 | pred_delta_1 = pred_ref_delta_1[:, 4:].reshape(-1, 4) 148 | base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4) 149 | pred_bbox_0 = restore_bbox(base_rois, pred_delta_0, True) 150 | pred_bbox_1 = restore_bbox(base_rois, pred_delta_1, True) 151 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1) 152 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1) 153 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1) 154 | return pred_bbox 155 | 156 | def restore_bbox(rois, deltas, unnormalize=True): 157 | if unnormalize: 158 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas) 159 | mean_opr = 
torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas) 160 | deltas = deltas * std_opr 161 | deltas = deltas + mean_opr 162 | pred_bbox = bbox_transform_inv_opr(rois, deltas) 163 | return pred_bbox 164 | -------------------------------------------------------------------------------- /model/rcnn_emd_simple/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | def add_path(path): 7 | if path not in sys.path: 8 | sys.path.insert(0, path) 9 | 10 | root_dir = '../../' 11 | add_path(os.path.join(root_dir)) 12 | add_path(os.path.join(root_dir, 'lib')) 13 | 14 | class Crowd_human: 15 | class_names = ['background', 'person'] 16 | num_classes = len(class_names) 17 | root_folder = '/data/CrowdHuman' 18 | image_folder = '/data/CrowdHuman/images' 19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt') 20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt') 21 | 22 | class Config: 23 | output_dir = 'outputs' 24 | model_dir = os.path.join(output_dir, 'model_dump') 25 | eval_dir = os.path.join(output_dir, 'eval_dump') 26 | init_weights = '/data/model/resnet50_fbaug.pth' 27 | 28 | # ----------data config---------- # 29 | image_mean = np.array([103.530, 116.280, 123.675]) 30 | image_std = np.array([57.375, 57.120, 58.395]) 31 | train_image_short_size = 800 32 | train_image_max_size = 1400 33 | eval_resize = True 34 | eval_image_short_size = 800 35 | eval_image_max_size = 1400 36 | seed_dataprovider = 3 37 | train_source = Crowd_human.train_source 38 | eval_source = Crowd_human.eval_source 39 | image_folder = Crowd_human.image_folder 40 | class_names = Crowd_human.class_names 41 | num_classes = Crowd_human.num_classes 42 | class_names2id = dict(list(zip(class_names, list(range(num_classes))))) 43 | gt_boxes_name = 'fbox' 44 | 45 | # ----------train config---------- # 46 | backbone_freeze_at = 2 47 | rpn_channel = 256 48 | 49 | train_batch_per_gpu = 2 50 | momentum = 0.9 51 | weight_decay = 1e-4 52 | base_lr = 1e-3 * 1.25 53 | 54 | warm_iter = 800 55 | max_epoch = 30 56 | lr_decay = [24, 27] 57 | nr_images_epoch = 15000 58 | log_dump_interval = 20 59 | 60 | # ----------test config---------- # 61 | test_nms = 0.5 62 | test_nms_method = 'set_nms' 63 | visulize_threshold = 0.3 64 | pred_cls_threshold = 0.01 65 | 66 | # ----------model config---------- # 67 | batch_filter_box_size = 0 68 | nr_box_dim = 5 69 | ignore_label = -1 70 | max_boxes_of_image = 500 71 | 72 | # ----------rois generator config---------- # 73 | anchor_base_size = 32 74 | anchor_base_scale = [1] 75 | anchor_aspect_ratios = [1, 2, 3] 76 | num_cell_anchors = len(anchor_aspect_ratios) 77 | anchor_within_border = False 78 | 79 | rpn_min_box_size = 2 80 | rpn_nms_threshold = 0.7 81 | train_prev_nms_top_n = 12000 82 | train_post_nms_top_n = 2000 83 | test_prev_nms_top_n = 6000 84 | test_post_nms_top_n = 1000 85 | 86 | # ----------binding&training config---------- # 87 | rpn_smooth_l1_beta = 1 88 | rcnn_smooth_l1_beta = 1 89 | 90 | num_sample_anchors = 256 91 | positive_anchor_ratio = 0.5 92 | rpn_positive_overlap = 0.7 93 | rpn_negative_overlap = 0.3 94 | rpn_bbox_normalize_targets = False 95 | 96 | num_rois = 512 97 | fg_ratio = 0.5 98 | fg_threshold = 0.5 99 | bg_threshold_high = 0.5 100 | bg_threshold_low = 0.0 101 | rcnn_bbox_normalize_targets = True 102 | bbox_normalize_means = np.array([0, 0, 0, 0]) 103 | bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2]) 104 | 105 | config = Config() 106 | 
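# A minimal sketch of what `test_nms_method = 'set_nms'` above implies at test
# time: set_cpu_nms in lib/utils/nms_utils.py reads the last column of each
# detection row as the id of the proposal that produced it, so boxes predicted
# by the same proposal never suppress each other. The packing of that id column
# happens outside this file, so the layout below is an assumption for
# illustration only:
#
#   import numpy as np
#   from utils import nms_utils
#   # columns: x1, y1, x2, y2, score, proposal_id
#   dets = np.array([[10., 10., 60., 120., 0.95, 0.],
#                    [12., 11., 63., 118., 0.90, 0.],   # same proposal -> kept
#                    [11., 12., 61., 121., 0.85, 1.]])  # other proposal -> suppressed
#   keep = nms_utils.set_cpu_nms(dets, config.test_nms)  # keep == [True, True, False]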
107 | -------------------------------------------------------------------------------- /model/rcnn_emd_simple/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from config import config 7 | from backbone.resnet50 import ResNet50 8 | from backbone.fpn import FPN 9 | from module.rpn import RPN 10 | from layers.pooler import roi_pooler 11 | from det_oprs.bbox_opr import bbox_transform_inv_opr 12 | from det_oprs.fpn_roi_target import fpn_roi_target 13 | from det_oprs.loss_opr import emd_loss_softmax 14 | from det_oprs.utils import get_padded_tensor 15 | 16 | class Network(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False) 20 | self.FPN = FPN(self.resnet50, 2, 6) 21 | self.RPN = RPN(config.rpn_channel) 22 | self.RCNN = RCNN() 23 | 24 | def forward(self, image, im_info, gt_boxes=None): 25 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / ( 26 | torch.tensor(config.image_std[None, :, None, None]).type_as(image)) 27 | image = get_padded_tensor(image, 64) 28 | if self.training: 29 | return self._forward_train(image, im_info, gt_boxes) 30 | else: 31 | return self._forward_test(image, im_info) 32 | 33 | def _forward_train(self, image, im_info, gt_boxes): 34 | loss_dict = {} 35 | fpn_fms = self.FPN(image) 36 | # fpn_fms stride: 64,32,16,8,4, p6->p2 37 | rpn_rois, loss_dict_rpn = self.RPN(fpn_fms, im_info, gt_boxes) 38 | rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target( 39 | rpn_rois, im_info, gt_boxes, top_k=2) 40 | loss_dict_rcnn = self.RCNN(fpn_fms, rcnn_rois, 41 | rcnn_labels, rcnn_bbox_targets) 42 | loss_dict.update(loss_dict_rpn) 43 | loss_dict.update(loss_dict_rcnn) 44 | return loss_dict 45 | 46 | def _forward_test(self, image, im_info): 47 | fpn_fms = self.FPN(image) 48 | rpn_rois = self.RPN(fpn_fms, im_info) 49 | pred_bbox = self.RCNN(fpn_fms, rpn_rois) 50 | return pred_bbox.cpu().detach() 51 | 52 | class RCNN(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | # roi head 56 | self.fc1 = nn.Linear(256*7*7, 1024) 57 | self.fc2 = nn.Linear(1024, 1024) 58 | 59 | for l in [self.fc1, self.fc2]: 60 | nn.init.kaiming_uniform_(l.weight, a=1) 61 | nn.init.constant_(l.bias, 0) 62 | # box predictor 63 | self.emd_pred_cls_0 = nn.Linear(1024, config.num_classes) 64 | self.emd_pred_delta_0 = nn.Linear(1024, config.num_classes * 4) 65 | self.emd_pred_cls_1 = nn.Linear(1024, config.num_classes) 66 | self.emd_pred_delta_1 = nn.Linear(1024, config.num_classes * 4) 67 | for l in [self.emd_pred_cls_0, self.emd_pred_cls_1]: 68 | nn.init.normal_(l.weight, std=0.01) 69 | nn.init.constant_(l.bias, 0) 70 | for l in [self.emd_pred_delta_0, self.emd_pred_delta_1]: 71 | nn.init.normal_(l.weight, std=0.001) 72 | nn.init.constant_(l.bias, 0) 73 | 74 | def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None): 75 | # stride: 64,32,16,8,4 -> 4, 8, 16, 32 76 | fpn_fms = fpn_fms[1:][::-1] 77 | stride = [4, 8, 16, 32] 78 | pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7), "ROIAlignV2") 79 | flatten_feature = torch.flatten(pool_features, start_dim=1) 80 | flatten_feature = F.relu_(self.fc1(flatten_feature)) 81 | flatten_feature = F.relu_(self.fc2(flatten_feature)) 82 | pred_emd_cls_0 = self.emd_pred_cls_0(flatten_feature) 83 | pred_emd_delta_0 = self.emd_pred_delta_0(flatten_feature) 84 | pred_emd_cls_1 = 
self.emd_pred_cls_1(flatten_feature) 85 | pred_emd_delta_1 = self.emd_pred_delta_1(flatten_feature) 86 | if self.training: 87 | loss0 = emd_loss_softmax( 88 | pred_emd_delta_0, pred_emd_cls_0, 89 | pred_emd_delta_1, pred_emd_cls_1, 90 | bbox_targets, labels) 91 | loss1 = emd_loss_softmax( 92 | pred_emd_delta_1, pred_emd_cls_1, 93 | pred_emd_delta_0, pred_emd_cls_0, 94 | bbox_targets, labels) 95 | loss = torch.cat([loss0, loss1], axis=1) 96 | # requires_grad = False 97 | _, min_indices = loss.min(axis=1) 98 | loss_emd = loss[torch.arange(loss.shape[0]), min_indices] 99 | loss_emd = loss_emd.mean() 100 | loss_dict = {} 101 | loss_dict['loss_rcnn_emd'] = loss_emd 102 | return loss_dict 103 | else: 104 | class_num = pred_emd_cls_0.shape[-1] - 1 105 | tag = torch.arange(class_num).type_as(pred_emd_cls_0)+1 106 | tag = tag.repeat(pred_emd_cls_0.shape[0], 1).reshape(-1,1) 107 | pred_scores_0 = F.softmax(pred_emd_cls_0, dim=-1)[:, 1:].reshape(-1, 1) 108 | pred_scores_1 = F.softmax(pred_emd_cls_1, dim=-1)[:, 1:].reshape(-1, 1) 109 | pred_delta_0 = pred_emd_delta_0[:, 4:].reshape(-1, 4) 110 | pred_delta_1 = pred_emd_delta_1[:, 4:].reshape(-1, 4) 111 | base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4) 112 | pred_bbox_0 = restore_bbox(base_rois, pred_delta_0, True) 113 | pred_bbox_1 = restore_bbox(base_rois, pred_delta_1, True) 114 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1) 115 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1) 116 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1) 117 | return pred_bbox 118 | 119 | def restore_bbox(rois, deltas, unnormalize=True): 120 | if unnormalize: 121 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas) 122 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas) 123 | deltas = deltas * std_opr 124 | deltas = deltas + mean_opr 125 | pred_bbox = bbox_transform_inv_opr(rois, deltas) 126 | return pred_bbox 127 | -------------------------------------------------------------------------------- /model/rcnn_fpn_baseline/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | def add_path(path): 7 | if path not in sys.path: 8 | sys.path.insert(0, path) 9 | 10 | root_dir = '../../' 11 | add_path(os.path.join(root_dir)) 12 | add_path(os.path.join(root_dir, 'lib')) 13 | 14 | class Crowd_human: 15 | class_names = ['background', 'person'] 16 | num_classes = len(class_names) 17 | root_folder = '/data/CrowdHuman' 18 | image_folder = '/data/CrowdHuman/images' 19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt') 20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt') 21 | 22 | class Config: 23 | output_dir = 'outputs' 24 | model_dir = os.path.join(output_dir, 'model_dump') 25 | eval_dir = os.path.join(output_dir, 'eval_dump') 26 | init_weights = '/data/model/resnet50_fbaug.pth' 27 | 28 | # ----------data config---------- # 29 | image_mean = np.array([103.530, 116.280, 123.675]) 30 | image_std = np.array([57.375, 57.120, 58.395]) 31 | train_image_short_size = 800 32 | train_image_max_size = 1400 33 | eval_resize = True 34 | eval_image_short_size = 800 35 | eval_image_max_size = 1400 36 | seed_dataprovider = 3 37 | train_source = Crowd_human.train_source 38 | eval_source = Crowd_human.eval_source 39 | image_folder = Crowd_human.image_folder 40 | class_names = Crowd_human.class_names 41 | num_classes = Crowd_human.num_classes 42 | 
class_names2id = dict(list(zip(class_names, list(range(num_classes))))) 43 | gt_boxes_name = 'fbox' 44 | 45 | # ----------train config---------- # 46 | backbone_freeze_at = 2 47 | rpn_channel = 256 48 | 49 | train_batch_per_gpu = 2 50 | momentum = 0.9 51 | weight_decay = 1e-4 52 | base_lr = 1e-3 * 1.25 53 | 54 | warm_iter = 800 55 | max_epoch = 30 56 | lr_decay = [24, 27] 57 | nr_images_epoch = 15000 58 | log_dump_interval = 20 59 | 60 | # ----------test config---------- # 61 | test_nms = 0.5 62 | test_nms_method = 'normal_nms' 63 | visulize_threshold = 0.3 64 | pred_cls_threshold = 0.01 65 | 66 | # ----------model config---------- # 67 | batch_filter_box_size = 0 68 | nr_box_dim = 5 69 | ignore_label = -1 70 | max_boxes_of_image = 500 71 | 72 | # ----------rois generator config---------- # 73 | anchor_base_size = 32 74 | anchor_base_scale = [1] 75 | anchor_aspect_ratios = [1, 2, 3] 76 | num_cell_anchors = len(anchor_aspect_ratios) 77 | anchor_within_border = False 78 | 79 | rpn_min_box_size = 2 80 | rpn_nms_threshold = 0.7 81 | train_prev_nms_top_n = 12000 82 | train_post_nms_top_n = 2000 83 | test_prev_nms_top_n = 6000 84 | test_post_nms_top_n = 1000 85 | 86 | # ----------binding&training config---------- # 87 | rpn_smooth_l1_beta = 1 88 | rcnn_smooth_l1_beta = 1 89 | 90 | num_sample_anchors = 256 91 | positive_anchor_ratio = 0.5 92 | rpn_positive_overlap = 0.7 93 | rpn_negative_overlap = 0.3 94 | rpn_bbox_normalize_targets = False 95 | 96 | num_rois = 512 97 | fg_ratio = 0.5 98 | fg_threshold = 0.5 99 | bg_threshold_high = 0.5 100 | bg_threshold_low = 0.0 101 | rcnn_bbox_normalize_targets = True 102 | bbox_normalize_means = np.array([0, 0, 0, 0]) 103 | bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2]) 104 | 105 | config = Config() 106 | 107 | -------------------------------------------------------------------------------- /model/rcnn_fpn_baseline/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from config import config 7 | from backbone.resnet50 import ResNet50 8 | from backbone.fpn import FPN 9 | from module.rpn import RPN 10 | from layers.pooler import roi_pooler 11 | from det_oprs.bbox_opr import bbox_transform_inv_opr 12 | from det_oprs.fpn_roi_target import fpn_roi_target 13 | from det_oprs.loss_opr import softmax_loss, smooth_l1_loss 14 | from det_oprs.utils import get_padded_tensor 15 | 16 | class Network(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False) 20 | self.FPN = FPN(self.resnet50, 2, 6) 21 | self.RPN = RPN(config.rpn_channel) 22 | self.RCNN = RCNN() 23 | 24 | def forward(self, image, im_info, gt_boxes=None): 25 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / ( 26 | torch.tensor(config.image_std[None, :, None, None]).type_as(image)) 27 | image = get_padded_tensor(image, 64) 28 | if self.training: 29 | return self._forward_train(image, im_info, gt_boxes) 30 | else: 31 | return self._forward_test(image, im_info) 32 | 33 | def _forward_train(self, image, im_info, gt_boxes): 34 | loss_dict = {} 35 | fpn_fms = self.FPN(image) 36 | # fpn_fms stride: 64,32,16,8,4, p6->p2 37 | rpn_rois, loss_dict_rpn = self.RPN(fpn_fms, im_info, gt_boxes) 38 | rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target( 39 | rpn_rois, im_info, gt_boxes, top_k=1) 40 | loss_dict_rcnn = self.RCNN(fpn_fms, rcnn_rois, 41 | 
rcnn_labels, rcnn_bbox_targets) 42 | loss_dict.update(loss_dict_rpn) 43 | loss_dict.update(loss_dict_rcnn) 44 | return loss_dict 45 | 46 | def _forward_test(self, image, im_info): 47 | fpn_fms = self.FPN(image) 48 | rpn_rois = self.RPN(fpn_fms, im_info) 49 | pred_bbox = self.RCNN(fpn_fms, rpn_rois) 50 | return pred_bbox.cpu().detach() 51 | 52 | class RCNN(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | # roi head 56 | self.fc1 = nn.Linear(256*7*7, 1024) 57 | self.fc2 = nn.Linear(1024, 1024) 58 | 59 | for l in [self.fc1, self.fc2]: 60 | nn.init.kaiming_uniform_(l.weight, a=1) 61 | nn.init.constant_(l.bias, 0) 62 | # box predictor 63 | self.pred_cls = nn.Linear(1024, config.num_classes) 64 | self.pred_delta = nn.Linear(1024, config.num_classes * 4) 65 | for l in [self.pred_cls]: 66 | nn.init.normal_(l.weight, std=0.01) 67 | nn.init.constant_(l.bias, 0) 68 | for l in [self.pred_delta]: 69 | nn.init.normal_(l.weight, std=0.001) 70 | nn.init.constant_(l.bias, 0) 71 | 72 | def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None): 73 | # input p2-p5 74 | fpn_fms = fpn_fms[1:][::-1] 75 | stride = [4, 8, 16, 32] 76 | pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7), "ROIAlignV2") 77 | flatten_feature = torch.flatten(pool_features, start_dim=1) 78 | flatten_feature = F.relu_(self.fc1(flatten_feature)) 79 | flatten_feature = F.relu_(self.fc2(flatten_feature)) 80 | pred_cls = self.pred_cls(flatten_feature) 81 | pred_delta = self.pred_delta(flatten_feature) 82 | if self.training: 83 | # loss for regression 84 | labels = labels.long().flatten() 85 | fg_masks = labels > 0 86 | valid_masks = labels >= 0 87 | # multi class 88 | pred_delta = pred_delta.reshape(-1, config.num_classes, 4) 89 | fg_gt_classes = labels[fg_masks] 90 | pred_delta = pred_delta[fg_masks, fg_gt_classes, :] 91 | localization_loss = smooth_l1_loss( 92 | pred_delta, 93 | bbox_targets[fg_masks], 94 | config.rcnn_smooth_l1_beta) 95 | # loss for classification 96 | objectness_loss = softmax_loss(pred_cls, labels) 97 | objectness_loss = objectness_loss * valid_masks 98 | normalizer = 1.0 / valid_masks.sum().item() 99 | loss_rcnn_loc = localization_loss.sum() * normalizer 100 | loss_rcnn_cls = objectness_loss.sum() * normalizer 101 | loss_dict = {} 102 | loss_dict['loss_rcnn_loc'] = loss_rcnn_loc 103 | loss_dict['loss_rcnn_cls'] = loss_rcnn_cls 104 | return loss_dict 105 | else: 106 | class_num = pred_cls.shape[-1] - 1 107 | tag = torch.arange(class_num).type_as(pred_cls)+1 108 | tag = tag.repeat(pred_cls.shape[0], 1).reshape(-1,1) 109 | pred_scores = F.softmax(pred_cls, dim=-1)[:, 1:].reshape(-1, 1) 110 | pred_delta = pred_delta[:, 4:].reshape(-1, 4) 111 | base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4) 112 | pred_bbox = restore_bbox(base_rois, pred_delta, True) 113 | pred_bbox = torch.cat([pred_bbox, pred_scores, tag], axis=1) 114 | return pred_bbox 115 | 116 | def restore_bbox(rois, deltas, unnormalize=True): 117 | if unnormalize: 118 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas) 119 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas) 120 | deltas = deltas * std_opr 121 | deltas = deltas + mean_opr 122 | pred_bbox = bbox_transform_inv_opr(rois, deltas) 123 | return pred_bbox 124 | -------------------------------------------------------------------------------- /model/retina_emd_simple/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import 
numpy as np 5 | 6 | def add_path(path): 7 | if path not in sys.path: 8 | sys.path.insert(0, path) 9 | 10 | root_dir = '../../' 11 | add_path(os.path.join(root_dir)) 12 | add_path(os.path.join(root_dir, 'lib')) 13 | 14 | class Crowd_human: 15 | class_names = ['background', 'person'] 16 | num_classes = len(class_names) 17 | root_folder = '/data/CrowdHuman' 18 | image_folder = '/data/CrowdHuman/images' 19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt') 20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt') 21 | 22 | class Config: 23 | output_dir = 'outputs' 24 | model_dir = os.path.join(output_dir, 'model_dump') 25 | eval_dir = os.path.join(output_dir, 'eval_dump') 26 | init_weights = '/data/model/resnet50_fbaug.pth' 27 | 28 | # ----------data config---------- # 29 | image_mean = np.array([103.530, 116.280, 123.675]) 30 | image_std = np.array([57.375, 57.120, 58.395]) 31 | train_image_short_size = 800 32 | train_image_max_size = 1400 33 | eval_resize = True 34 | eval_image_short_size = 800 35 | eval_image_max_size = 1400 36 | seed_dataprovider = 3 37 | train_source = Crowd_human.train_source 38 | eval_source = Crowd_human.eval_source 39 | image_folder = Crowd_human.image_folder 40 | class_names = Crowd_human.class_names 41 | num_classes = Crowd_human.num_classes 42 | class_names2id = dict(list(zip(class_names, list(range(num_classes))))) 43 | gt_boxes_name = 'fbox' 44 | 45 | # ----------train config---------- # 46 | backbone_freeze_at = 2 47 | train_batch_per_gpu = 2 48 | momentum = 0.9 49 | weight_decay = 1e-4 50 | base_lr = 3.125e-4 51 | focal_loss_alpha = 0.25 52 | focal_loss_gamma = 2 53 | 54 | warm_iter = 800 55 | max_epoch = 50 56 | lr_decay = [33, 43] 57 | nr_images_epoch = 15000 58 | log_dump_interval = 20 59 | 60 | # ----------test config---------- # 61 | test_layer_topk = 1000 62 | test_nms = 0.5 63 | test_nms_method = 'set_nms' 64 | visulize_threshold = 0.3 65 | pred_cls_threshold = 0.01 66 | 67 | # ----------dataset config---------- # 68 | nr_box_dim = 5 69 | max_boxes_of_image = 500 70 | 71 | # --------anchor generator config-------- # 72 | anchor_base_size = 32 # the minimize anchor size in the bigest feature map. 
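    # A rough reading of the anchor settings below (the exact construction lives in
    # det_oprs/anchors_generator.py, which is not reproduced here, so treat these
    # numbers as indicative): each feature-map location presumably gets
    # len(anchor_base_scale) * len(anchor_aspect_ratios) = 3 * 3 = 9 anchors, with
    # sizes around anchor_base_size * {2**0, 2**(1/3), 2**(2/3)} ~= 32, 40 and 51 px
    # at the finest pyramid level, stretched by the aspect ratios 1, 2 and 3.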
73 | anchor_base_scale = [2**0, 2**(1/3), 2**(2/3)] 74 | anchor_aspect_ratios = [1, 2, 3] 75 | num_cell_anchors = len(anchor_aspect_ratios) * len(anchor_base_scale) 76 | 77 | # ----------binding&training config---------- # 78 | smooth_l1_beta = 0.1 79 | negative_thresh = 0.4 80 | positive_thresh = 0.5 81 | allow_low_quality = True 82 | 83 | config = Config() 84 | -------------------------------------------------------------------------------- /model/retina_emd_simple/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from config import config 8 | from backbone.resnet50 import ResNet50 9 | from backbone.fpn import FPN 10 | from det_oprs.anchors_generator import AnchorGenerator 11 | from det_oprs.retina_anchor_target import retina_anchor_target 12 | from det_oprs.bbox_opr import bbox_transform_inv_opr 13 | from det_oprs.loss_opr import emd_loss_focal 14 | from det_oprs.utils import get_padded_tensor 15 | 16 | class Network(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False) 20 | self.FPN = FPN(self.resnet50, 3, 7) 21 | self.R_Head = RetinaNet_Head() 22 | self.R_Anchor = RetinaNet_Anchor() 23 | self.R_Criteria = RetinaNet_Criteria() 24 | 25 | def forward(self, image, im_info, gt_boxes=None): 26 | # pre-processing the data 27 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / ( 28 | torch.tensor(config.image_std[None, :, None, None]).type_as(image)) 29 | image = get_padded_tensor(image, 64) 30 | # do inference 31 | # stride: 128,64,32,16,8, p7->p3 32 | fpn_fms = self.FPN(image) 33 | anchors_list = self.R_Anchor(fpn_fms) 34 | pred_cls_list, pred_reg_list = self.R_Head(fpn_fms) 35 | # release the useless data 36 | if self.training: 37 | loss_dict = self.R_Criteria( 38 | pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info) 39 | return loss_dict 40 | else: 41 | #pred_bbox = union_inference( 42 | # anchors_list, pred_cls_list, pred_reg_list, im_info) 43 | pred_bbox = per_layer_inference( 44 | anchors_list, pred_cls_list, pred_reg_list, im_info) 45 | return pred_bbox.cpu().detach() 46 | 47 | class RetinaNet_Anchor(): 48 | def __init__(self): 49 | self.anchors_generator = AnchorGenerator( 50 | config.anchor_base_size, 51 | config.anchor_aspect_ratios, 52 | config.anchor_base_scale) 53 | 54 | def __call__(self, fpn_fms): 55 | # get anchors 56 | all_anchors_list = [] 57 | base_stride = 8 58 | off_stride = 2**(len(fpn_fms)-1) # 16 59 | for fm in fpn_fms: 60 | layer_anchors = self.anchors_generator(fm, base_stride, off_stride) 61 | off_stride = off_stride // 2 62 | all_anchors_list.append(layer_anchors) 63 | return all_anchors_list 64 | 65 | class RetinaNet_Criteria(nn.Module): 66 | def __init__(self): 67 | super().__init__() 68 | self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small 69 | self.loss_normalizer_momentum = 0.9 70 | 71 | def __call__(self, pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info): 72 | all_anchors = torch.cat(anchors_list, axis=0) 73 | all_pred_cls = torch.cat(pred_cls_list, axis=1).reshape(-1, (config.num_classes-1)*2) 74 | all_pred_cls = torch.sigmoid(all_pred_cls) 75 | all_pred_reg = torch.cat(pred_reg_list, axis=1).reshape(-1, 4*2) 76 | # get ground truth 77 | labels, bbox_targets = retina_anchor_target(all_anchors, gt_boxes, im_info, top_k=2) 78 | 
all_pred_cls = all_pred_cls.reshape(-1, 2, config.num_classes-1) 79 | all_pred_reg = all_pred_reg.reshape(-1, 2, 4) 80 | loss0 = emd_loss_focal( 81 | all_pred_reg[:, 0], all_pred_cls[:, 0], 82 | all_pred_reg[:, 1], all_pred_cls[:, 1], 83 | bbox_targets, labels) 84 | loss1 = emd_loss_focal( 85 | all_pred_reg[:, 1], all_pred_cls[:, 1], 86 | all_pred_reg[:, 0], all_pred_cls[:, 0], 87 | bbox_targets, labels) 88 | del all_anchors 89 | del all_pred_cls 90 | del all_pred_reg 91 | loss = torch.cat([loss0, loss1], axis=1) 92 | # requires_grad = False 93 | _, min_indices = loss.min(axis=1) 94 | loss_emd = loss[torch.arange(loss.shape[0]), min_indices] 95 | # only main labels 96 | num_pos = (labels[:, 0] > 0).sum().item() 97 | self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + ( 98 | 1 - self.loss_normalizer_momentum) * max(num_pos, 1) 99 | loss_emd = loss_emd.sum() / self.loss_normalizer 100 | loss_dict = {} 101 | loss_dict['retina_emd'] = loss_emd 102 | return loss_dict 103 | 104 | class RetinaNet_Head(nn.Module): 105 | def __init__(self): 106 | super().__init__() 107 | num_convs = 4 108 | in_channels = 256 109 | cls_subnet = [] 110 | bbox_subnet = [] 111 | for _ in range(num_convs): 112 | cls_subnet.append( 113 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 114 | ) 115 | cls_subnet.append(nn.ReLU(inplace=True)) 116 | bbox_subnet.append( 117 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 118 | ) 119 | bbox_subnet.append(nn.ReLU(inplace=True)) 120 | self.cls_subnet = nn.Sequential(*cls_subnet) 121 | self.bbox_subnet = nn.Sequential(*bbox_subnet) 122 | # predictor 123 | self.cls_score = nn.Conv2d( 124 | in_channels, config.num_cell_anchors * (config.num_classes-1) * 2, 125 | kernel_size=3, stride=1, padding=1) 126 | self.bbox_pred = nn.Conv2d( 127 | in_channels, config.num_cell_anchors * 4 * 2, 128 | kernel_size=3, stride=1, padding=1) 129 | 130 | # Initialization 131 | for modules in [self.cls_subnet, self.bbox_subnet, 132 | self.cls_score, self.bbox_pred]: 133 | for layer in modules.modules(): 134 | if isinstance(layer, nn.Conv2d): 135 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01) 136 | torch.nn.init.constant_(layer.bias, 0) 137 | prior_prob = 0.01 138 | # Use prior in model initialization to improve stability 139 | bias_value = -(math.log((1 - prior_prob) / prior_prob)) 140 | torch.nn.init.constant_(self.cls_score.bias, bias_value) 141 | 142 | def forward(self, features): 143 | pred_cls = [] 144 | pred_reg = [] 145 | for feature in features: 146 | pred_cls.append(self.cls_score(self.cls_subnet(feature))) 147 | pred_reg.append(self.bbox_pred(self.bbox_subnet(feature))) 148 | # reshape the predictions 149 | assert pred_cls[0].dim() == 4 150 | pred_cls_list = [ 151 | _.permute(0, 2, 3, 1).reshape(pred_cls[0].shape[0], -1, (config.num_classes-1)*2) 152 | for _ in pred_cls] 153 | pred_reg_list = [ 154 | _.permute(0, 2, 3, 1).reshape(pred_reg[0].shape[0], -1, 4*2) 155 | for _ in pred_reg] 156 | return pred_cls_list, pred_reg_list 157 | 158 | def per_layer_inference(anchors_list, pred_cls_list, pred_reg_list, im_info): 159 | keep_anchors = [] 160 | keep_cls = [] 161 | keep_reg = [] 162 | class_num = pred_cls_list[0].shape[-1] // 2 163 | for l_id in range(len(anchors_list)): 164 | anchors = anchors_list[l_id].reshape(-1, 4) 165 | pred_cls = pred_cls_list[l_id][0].reshape(-1, class_num*2) 166 | pred_reg = pred_reg_list[l_id][0].reshape(-1, 4*2) 167 | if len(anchors) > config.test_layer_topk: 168 | ruler = 
pred_cls.max(axis=1)[0] 169 | _, inds = ruler.topk(config.test_layer_topk, dim=0) 170 | inds = inds.flatten() 171 | keep_anchors.append(anchors[inds]) 172 | keep_cls.append(torch.sigmoid(pred_cls[inds])) 173 | keep_reg.append(pred_reg[inds]) 174 | else: 175 | keep_anchors.append(anchors) 176 | keep_cls.append(torch.sigmoid(pred_cls)) 177 | keep_reg.append(pred_reg) 178 | keep_anchors = torch.cat(keep_anchors, axis = 0) 179 | keep_cls = torch.cat(keep_cls, axis = 0) 180 | keep_reg = torch.cat(keep_reg, axis = 0) 181 | # multiclass 182 | tag = torch.arange(class_num).type_as(keep_cls)+1 183 | tag = tag.repeat(keep_cls.shape[0], 1).reshape(-1,1) 184 | pred_scores_0 = keep_cls[:, :class_num].reshape(-1, 1) 185 | pred_scores_1 = keep_cls[:, class_num:].reshape(-1, 1) 186 | pred_delta_0 = keep_reg[:, :4] 187 | pred_delta_1 = keep_reg[:, 4:] 188 | pred_bbox_0 = restore_bbox(keep_anchors, pred_delta_0, False) 189 | pred_bbox_1 = restore_bbox(keep_anchors, pred_delta_1, False) 190 | pred_bbox_0 = pred_bbox_0.repeat(1, class_num).reshape(-1, 4) 191 | pred_bbox_1 = pred_bbox_1.repeat(1, class_num).reshape(-1, 4) 192 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1) 193 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1) 194 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1) 195 | return pred_bbox 196 | 197 | def union_inference(anchors_list, pred_cls_list, pred_reg_list, im_info): 198 | anchors = torch.cat(anchors_list, axis = 0) 199 | pred_cls = torch.cat(pred_cls_list, axis = 1)[0] 200 | pred_cls = torch.sigmoid(pred_cls) 201 | pred_reg = torch.cat(pred_reg_list, axis = 1)[0] 202 | class_num = pred_cls.shape[-1] // 2 203 | # multiclass 204 | tag = torch.arange(class_num).type_as(pred_cls)+1 205 | tag = tag.repeat(pred_cls.shape[0], 1).reshape(-1,1) 206 | pred_scores_0 = pred_cls[:, :class_num].reshape(-1, 1) 207 | pred_scores_1 = pred_cls[:, class_num:].reshape(-1, 1) 208 | pred_delta_0 = pred_reg[:, :4] 209 | pred_delta_1 = pred_reg[:, 4:] 210 | pred_bbox_0 = restore_bbox(anchors, pred_delta_0, False) 211 | pred_bbox_1 = restore_bbox(anchors, pred_delta_1, False) 212 | pred_bbox_0 = pred_bbox_0.repeat(1, class_num).reshape(-1, 4) 213 | pred_bbox_1 = pred_bbox_1.repeat(1, class_num).reshape(-1, 4) 214 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1) 215 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1) 216 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1) 217 | return pred_bbox 218 | 219 | def restore_bbox(rois, deltas, unnormalize=True): 220 | if unnormalize: 221 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas) 222 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas) 223 | deltas = deltas * std_opr 224 | deltas = deltas + mean_opr 225 | pred_bbox = bbox_transform_inv_opr(rois, deltas) 226 | return pred_bbox 227 | 228 | -------------------------------------------------------------------------------- /model/retina_fpn_baseline/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | def add_path(path): 7 | if path not in sys.path: 8 | sys.path.insert(0, path) 9 | 10 | root_dir = '../../' 11 | add_path(os.path.join(root_dir)) 12 | add_path(os.path.join(root_dir, 'lib')) 13 | 14 | class Crowd_human: 15 | class_names = ['background', 'person'] 16 | num_classes = len(class_names) 17 | root_folder = '/data/CrowdHuman' 18 | image_folder = 
'/data/CrowdHuman/images' 19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt') 20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt') 21 | 22 | class Config: 23 | output_dir = 'outputs' 24 | model_dir = os.path.join(output_dir, 'model_dump') 25 | eval_dir = os.path.join(output_dir, 'eval_dump') 26 | init_weights = '/data/model/resnet50_fbaug.pth' 27 | 28 | # ----------data config---------- # 29 | image_mean = np.array([103.530, 116.280, 123.675]) 30 | image_std = np.array([57.375, 57.120, 58.395]) 31 | train_image_short_size = 800 32 | train_image_max_size = 1400 33 | eval_resize = True 34 | eval_image_short_size = 800 35 | eval_image_max_size = 1400 36 | seed_dataprovider = 3 37 | train_source = Crowd_human.train_source 38 | eval_source = Crowd_human.eval_source 39 | image_folder = Crowd_human.image_folder 40 | class_names = Crowd_human.class_names 41 | num_classes = Crowd_human.num_classes 42 | class_names2id = dict(list(zip(class_names, list(range(num_classes))))) 43 | gt_boxes_name = 'fbox' 44 | 45 | # ----------train config---------- # 46 | backbone_freeze_at = 2 47 | train_batch_per_gpu = 2 48 | momentum = 0.9 49 | weight_decay = 1e-4 50 | base_lr = 3.125e-4 51 | focal_loss_alpha = 0.25 52 | focal_loss_gamma = 2 53 | 54 | warm_iter = 800 55 | max_epoch = 50 56 | lr_decay = [33, 43] 57 | nr_images_epoch = 15000 58 | log_dump_interval = 20 59 | 60 | # ----------test config---------- # 61 | test_layer_topk = 1000 62 | test_nms = 0.5 63 | test_nms_method = 'normal_nms' 64 | visulize_threshold = 0.3 65 | pred_cls_threshold = 0.01 66 | 67 | # ----------dataset config---------- # 68 | nr_box_dim = 5 69 | max_boxes_of_image = 500 70 | 71 | # --------anchor generator config-------- # 72 | anchor_base_size = 32 # the minimum anchor size in the biggest feature map.
73 | anchor_base_scale = [2**0, 2**(1/3), 2**(2/3)] 74 | anchor_aspect_ratios = [1, 2, 3] 75 | num_cell_anchors = len(anchor_aspect_ratios) * len(anchor_base_scale) 76 | 77 | # ----------binding&training config---------- # 78 | smooth_l1_beta = 0.1 79 | negative_thresh = 0.4 80 | positive_thresh = 0.5 81 | allow_low_quality = True 82 | 83 | config = Config() 84 | -------------------------------------------------------------------------------- /model/retina_fpn_baseline/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from config import config 8 | from backbone.resnet50 import ResNet50 9 | from backbone.fpn import FPN 10 | from det_oprs.anchors_generator import AnchorGenerator 11 | from det_oprs.retina_anchor_target import retina_anchor_target 12 | from det_oprs.bbox_opr import bbox_transform_inv_opr 13 | from det_oprs.loss_opr import focal_loss, smooth_l1_loss 14 | from det_oprs.utils import get_padded_tensor 15 | 16 | class Network(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False) 20 | self.FPN = FPN(self.resnet50, 3, 7) 21 | self.R_Head = RetinaNet_Head() 22 | self.R_Anchor = RetinaNet_Anchor() 23 | self.R_Criteria = RetinaNet_Criteria() 24 | 25 | def forward(self, image, im_info, gt_boxes=None): 26 | # pre-processing the data 27 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / ( 28 | torch.tensor(config.image_std[None, :, None, None]).type_as(image)) 29 | image = get_padded_tensor(image, 64) 30 | # do inference 31 | # stride: 128,64,32,16,8, p7->p3 32 | fpn_fms = self.FPN(image) 33 | anchors_list = self.R_Anchor(fpn_fms) 34 | pred_cls_list, pred_reg_list = self.R_Head(fpn_fms) 35 | # release the useless data 36 | if self.training: 37 | loss_dict = self.R_Criteria( 38 | pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info) 39 | return loss_dict 40 | else: 41 | #pred_bbox = union_inference( 42 | # anchors_list, pred_cls_list, pred_reg_list, im_info) 43 | pred_bbox = per_layer_inference( 44 | anchors_list, pred_cls_list, pred_reg_list, im_info) 45 | return pred_bbox.cpu().detach() 46 | 47 | class RetinaNet_Anchor(): 48 | def __init__(self): 49 | self.anchors_generator = AnchorGenerator( 50 | config.anchor_base_size, 51 | config.anchor_aspect_ratios, 52 | config.anchor_base_scale) 53 | 54 | def __call__(self, fpn_fms): 55 | # get anchors 56 | all_anchors_list = [] 57 | base_stride = 8 58 | off_stride = 2**(len(fpn_fms)-1) # 16 59 | for fm in fpn_fms: 60 | layer_anchors = self.anchors_generator(fm, base_stride, off_stride) 61 | off_stride = off_stride // 2 62 | all_anchors_list.append(layer_anchors) 63 | return all_anchors_list 64 | 65 | class RetinaNet_Criteria(nn.Module): 66 | def __init__(self): 67 | super().__init__() 68 | self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small 69 | self.loss_normalizer_momentum = 0.9 70 | 71 | def __call__(self, pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info): 72 | all_anchors = torch.cat(anchors_list, axis=0) 73 | all_pred_cls = torch.cat(pred_cls_list, axis=1).reshape(-1, config.num_classes-1) 74 | all_pred_cls = torch.sigmoid(all_pred_cls) 75 | all_pred_reg = torch.cat(pred_reg_list, axis=1).reshape(-1, 4) 76 | # get ground truth 77 | labels, bbox_target = retina_anchor_target(all_anchors, gt_boxes, im_info, top_k=1) 
78 | # regression loss 79 | fg_mask = (labels > 0).flatten() 80 | valid_mask = (labels >= 0).flatten() 81 | loss_reg = smooth_l1_loss( 82 | all_pred_reg[fg_mask], 83 | bbox_target[fg_mask], 84 | config.smooth_l1_beta) 85 | loss_cls = focal_loss( 86 | all_pred_cls[valid_mask], 87 | labels[valid_mask], 88 | config.focal_loss_alpha, 89 | config.focal_loss_gamma) 90 | num_pos_anchors = fg_mask.sum().item() 91 | self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + ( 92 | 1 - self.loss_normalizer_momentum 93 | ) * max(num_pos_anchors, 1) 94 | loss_reg = loss_reg.sum() / self.loss_normalizer 95 | loss_cls = loss_cls.sum() / self.loss_normalizer 96 | loss_dict = {} 97 | loss_dict['retina_focal_loss'] = loss_cls 98 | loss_dict['retina_smooth_l1'] = loss_reg 99 | return loss_dict 100 | 101 | class RetinaNet_Head(nn.Module): 102 | def __init__(self): 103 | super().__init__() 104 | num_convs = 4 105 | in_channels = 256 106 | cls_subnet = [] 107 | bbox_subnet = [] 108 | for _ in range(num_convs): 109 | cls_subnet.append( 110 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 111 | ) 112 | cls_subnet.append(nn.ReLU(inplace=True)) 113 | bbox_subnet.append( 114 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 115 | ) 116 | bbox_subnet.append(nn.ReLU(inplace=True)) 117 | self.cls_subnet = nn.Sequential(*cls_subnet) 118 | self.bbox_subnet = nn.Sequential(*bbox_subnet) 119 | # predictor 120 | self.cls_score = nn.Conv2d( 121 | in_channels, config.num_cell_anchors * (config.num_classes-1), 122 | kernel_size=3, stride=1, padding=1) 123 | self.bbox_pred = nn.Conv2d( 124 | in_channels, config.num_cell_anchors * 4, 125 | kernel_size=3, stride=1, padding=1) 126 | 127 | # Initialization 128 | for modules in [self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred]: 129 | for layer in modules.modules(): 130 | if isinstance(layer, nn.Conv2d): 131 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01) 132 | torch.nn.init.constant_(layer.bias, 0) 133 | prior_prob = 0.01 134 | # Use prior in model initialization to improve stability 135 | bias_value = -(math.log((1 - prior_prob) / prior_prob)) 136 | torch.nn.init.constant_(self.cls_score.bias, bias_value) 137 | 138 | def forward(self, features): 139 | pred_cls = [] 140 | pred_reg = [] 141 | for feature in features: 142 | pred_cls.append(self.cls_score(self.cls_subnet(feature))) 143 | pred_reg.append(self.bbox_pred(self.bbox_subnet(feature))) 144 | # reshape the predictions 145 | assert pred_cls[0].dim() == 4 146 | pred_cls_list = [ 147 | _.permute(0, 2, 3, 1).reshape(pred_cls[0].shape[0], -1, config.num_classes-1) 148 | for _ in pred_cls] 149 | pred_reg_list = [ 150 | _.permute(0, 2, 3, 1).reshape(pred_reg[0].shape[0], -1, 4) 151 | for _ in pred_reg] 152 | return pred_cls_list, pred_reg_list 153 | 154 | def per_layer_inference(anchors_list, pred_cls_list, pred_reg_list, im_info): 155 | keep_anchors = [] 156 | keep_cls = [] 157 | keep_reg = [] 158 | class_num = pred_cls_list[0].shape[-1] 159 | for l_id in range(len(anchors_list)): 160 | anchors = anchors_list[l_id].reshape(-1, 4) 161 | pred_cls = pred_cls_list[l_id][0].reshape(-1, class_num) 162 | pred_reg = pred_reg_list[l_id][0].reshape(-1, 4) 163 | if len(anchors) > config.test_layer_topk: 164 | ruler = pred_cls.max(axis=1)[0] 165 | _, inds = ruler.topk(config.test_layer_topk, dim=0) 166 | inds = inds.flatten() 167 | keep_anchors.append(anchors[inds]) 168 | keep_cls.append(torch.sigmoid(pred_cls[inds])) 169 | 
keep_reg.append(pred_reg[inds]) 170 | else: 171 | keep_anchors.append(anchors) 172 | keep_cls.append(torch.sigmoid(pred_cls)) 173 | keep_reg.append(pred_reg) 174 | keep_anchors = torch.cat(keep_anchors, axis = 0) 175 | keep_cls = torch.cat(keep_cls, axis = 0) 176 | keep_reg = torch.cat(keep_reg, axis = 0) 177 | # multiclass 178 | tag = torch.arange(class_num).type_as(keep_cls)+1 179 | tag = tag.repeat(keep_cls.shape[0], 1).reshape(-1,1) 180 | pred_scores = keep_cls.reshape(-1, 1) 181 | pred_bbox = restore_bbox(keep_anchors, keep_reg, False) 182 | pred_bbox = pred_bbox.repeat(1, class_num).reshape(-1, 4) 183 | pred_bbox = torch.cat([pred_bbox, pred_scores, tag], axis=1) 184 | return pred_bbox 185 | 186 | def union_inference(anchors_list, pred_cls_list, pred_reg_list, im_info): 187 | anchors = torch.cat(anchors_list, axis = 0) 188 | pred_cls = torch.cat(pred_cls_list, axis = 1)[0] 189 | pred_cls = torch.sigmoid(pred_cls) 190 | pred_reg = torch.cat(pred_reg_list, axis = 1)[0] 191 | class_num = pred_cls_list[0].shape[-1] 192 | # multiclass 193 | tag = torch.arange(class_num).type_as(pred_cls)+1 194 | tag = tag.repeat(pred_cls.shape[0], 1).reshape(-1,1) 195 | pred_scores = pred_cls.reshape(-1, 1) 196 | pred_bbox = restore_bbox(anchors, pred_reg, False) 197 | pred_bbox = pred_bbox.repeat(1, class_num).reshape(-1, 4) 198 | pred_bbox = torch.cat([pred_bbox, pred_scores, tag], axis=1) 199 | return pred_bbox 200 | 201 | def restore_bbox(rois, deltas, unnormalize=True): 202 | if unnormalize: 203 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas) 204 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas) 205 | deltas = deltas * std_opr 206 | deltas = deltas + mean_opr 207 | pred_bbox = bbox_transform_inv_opr(rois, deltas) 208 | return pred_bbox 209 | 210 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | -------------------------------------------------------------------------------- /tools/eval_json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | sys.path.insert(0, '../lib') 6 | from utils import misc_utils 7 | from evaluate import compute_JI, compute_APMR 8 | 9 | def eval_all(args): 10 | # ground truth file 11 | gt_path = '/data/CrowdHuman/annotation_val.odgt' 12 | assert os.path.exists(gt_path), "Wrong ground truth path!"
13 | misc_utils.ensure_dir('outputs') 14 | # output file 15 | eval_path = os.path.join('outputs', 'result_eval.md') 16 | eval_fid = open(eval_path,'a') 17 | eval_fid.write((args.json_file+'\n')) 18 | # eval JI 19 | res_line, JI = compute_JI.evaluation_all(args.json_file, 'box') 20 | for line in res_line: 21 | eval_fid.write(line+'\n') 22 | # eval AP, MR 23 | AP, MR = compute_APMR.compute_APMR(args.json_file, gt_path, 'box') 24 | line = 'AP:{:.4f}, MR:{:.4f}, JI:{:.4f}.'.format(AP, MR, JI) 25 | print(line) 26 | eval_fid.write(line+'\n\n') 27 | eval_fid.close() 28 | 29 | def run_eval(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--json_file', '-f', default=None, required=True, type=str) 32 | args = parser.parse_args() 33 | eval_all(args) 34 | 35 | if __name__ == '__main__': 36 | run_eval() 37 | -------------------------------------------------------------------------------- /tools/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | import cv2 6 | import torch 7 | import numpy as np 8 | 9 | sys.path.insert(0, '../lib') 10 | from utils import misc_utils, visual_utils, nms_utils 11 | 12 | def inference(args, config, network): 13 | # model_path 14 | misc_utils.ensure_dir('outputs') 15 | saveDir = os.path.join('../model', args.model_dir, config.model_dir) 16 | model_file = os.path.join(saveDir, 17 | 'dump-{}.pth'.format(args.resume_weights)) 18 | assert os.path.exists(model_file) 19 | # build network 20 | net = network() 21 | net.eval() 22 | check_point = torch.load(model_file, map_location=torch.device('cpu')) 23 | net.load_state_dict(check_point['state_dict']) 24 | # get data 25 | image, resized_img, im_info = get_data( 26 | args.img_path, config.eval_image_short_size, config.eval_image_max_size) 27 | pred_boxes = net(resized_img, im_info).numpy() 28 | pred_boxes = post_process(pred_boxes, config, im_info[0, 2]) 29 | pred_tags = pred_boxes[:, 5].astype(np.int32).flatten() 30 | pred_tags_name = np.array(config.class_names)[pred_tags] 31 | # inplace draw 32 | image = visual_utils.draw_boxes( 33 | image, 34 | pred_boxes[:, :4], 35 | scores=pred_boxes[:, 4], 36 | tags=pred_tags_name, 37 | line_thick=1, line_color='white') 38 | name = args.img_path.split('/')[-1].split('.')[-2] 39 | fpath = 'outputs/{}.png'.format(name) 40 | cv2.imwrite(fpath, image) 41 | 42 | def post_process(pred_boxes, config, scale): 43 | if config.test_nms_method == 'set_nms': 44 | assert pred_boxes.shape[-1] > 6, "Not EMD Network! Using normal_nms instead." 45 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!" 46 | top_k = pred_boxes.shape[-1] // 6 47 | n = pred_boxes.shape[0] 48 | pred_boxes = pred_boxes.reshape(-1, 6) 49 | idents = np.tile(np.arange(n)[:,None], (1, top_k)).reshape(-1, 1) 50 | pred_boxes = np.hstack((pred_boxes, idents)) 51 | keep = pred_boxes[:, 4] > config.pred_cls_threshold 52 | pred_boxes = pred_boxes[keep] 53 | keep = nms_utils.set_cpu_nms(pred_boxes, 0.5) 54 | pred_boxes = pred_boxes[keep] 55 | elif config.test_nms_method == 'normal_nms': 56 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!" 57 | pred_boxes = pred_boxes.reshape(-1, 6) 58 | keep = pred_boxes[:, 4] > config.pred_cls_threshold 59 | pred_boxes = pred_boxes[keep] 60 | keep = nms_utils.cpu_nms(pred_boxes, config.test_nms) 61 | pred_boxes = pred_boxes[keep] 62 | elif config.test_nms_method == 'none': 63 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!" 
64 | pred_boxes = pred_boxes.reshape(-1, 6) 65 | keep = pred_boxes[:, 4] > config.pred_cls_threshold 66 | pred_boxes = pred_boxes[keep] 67 | #if pred_boxes.shape[0] > config.detection_per_image and \ 68 | # config.test_nms_method != 'none': 69 | # order = np.argsort(-pred_boxes[:, 4]) 70 | # order = order[:config.detection_per_image] 71 | # pred_boxes = pred_boxes[order] 72 | # recovery the scale 73 | pred_boxes[:, :4] /= scale 74 | keep = pred_boxes[:, 4] > config.visulize_threshold 75 | pred_boxes = pred_boxes[keep] 76 | return pred_boxes 77 | 78 | def get_data(img_path, short_size, max_size): 79 | image = cv2.imread(img_path, cv2.IMREAD_COLOR) 80 | resized_img, scale = resize_img( 81 | image, short_size, max_size) 82 | 83 | original_height, original_width = image.shape[0:2] 84 | height, width = resized_img.shape[0:2] 85 | resized_img = resized_img.transpose(2, 0, 1) 86 | im_info = np.array([height, width, scale, original_height, original_width, 0]) 87 | return image, torch.tensor([resized_img]).float(), torch.tensor([im_info]) 88 | 89 | def resize_img(image, short_size, max_size): 90 | height = image.shape[0] 91 | width = image.shape[1] 92 | im_size_min = np.min([height, width]) 93 | im_size_max = np.max([height, width]) 94 | scale = (short_size + 0.0) / im_size_min 95 | if scale * im_size_max > max_size: 96 | scale = (max_size + 0.0) / im_size_max 97 | t_height, t_width = int(round(height * scale)), int( 98 | round(width * scale)) 99 | resized_image = cv2.resize( 100 | image, (t_width, t_height), interpolation=cv2.INTER_LINEAR) 101 | return resized_image, scale 102 | 103 | def run_inference(): 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--model_dir', '-md', default=None, required=True, type=str) 106 | parser.add_argument('--resume_weights', '-r', default=None, required=True, type=str) 107 | parser.add_argument('--img_path', '-i', default=None, required=True, type=str) 108 | args = parser.parse_args() 109 | # import libs 110 | model_root_dir = os.path.join('../model/', args.model_dir) 111 | sys.path.insert(0, model_root_dir) 112 | from config import config 113 | from network import Network 114 | inference(args, config, Network) 115 | 116 | if __name__ == '__main__': 117 | run_inference() 118 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import argparse 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | import torch 9 | from torch.multiprocessing import Queue, Process 10 | 11 | sys.path.insert(0, '../lib') 12 | sys.path.insert(0, '../model') 13 | from data.CrowdHuman import CrowdHuman 14 | from utils import misc_utils, nms_utils 15 | from evaluate import compute_JI, compute_APMR 16 | 17 | def eval_all(args, config, network): 18 | # model_path 19 | saveDir = os.path.join('../model', args.model_dir, config.model_dir) 20 | evalDir = os.path.join('../model', args.model_dir, config.eval_dir) 21 | misc_utils.ensure_dir(evalDir) 22 | model_file = os.path.join(saveDir, 23 | 'dump-{}.pth'.format(args.resume_weights)) 24 | assert os.path.exists(model_file) 25 | # get devices 26 | str_devices = args.devices 27 | devices = misc_utils.device_parser(str_devices) 28 | # load data 29 | crowdhuman = CrowdHuman(config, if_train=False) 30 | #crowdhuman.records = crowdhuman.records[:10] 31 | # multiprocessing 32 | num_devs = len(devices) 33 | len_dataset = len(crowdhuman) 34 | num_image = 
math.ceil(len_dataset / num_devs) 35 | result_queue = Queue(500) 36 | procs = [] 37 | all_results = [] 38 | for i in range(num_devs): 39 | start = i * num_image 40 | end = min(start + num_image, len_dataset) 41 | proc = Process(target=inference, args=( 42 | config, network, model_file, devices[i], crowdhuman, start, end, result_queue)) 43 | proc.start() 44 | procs.append(proc) 45 | pbar = tqdm(total=len_dataset, ncols=50) 46 | for i in range(len_dataset): 47 | t = result_queue.get() 48 | all_results.append(t) 49 | pbar.update(1) 50 | pbar.close() 51 | for p in procs: 52 | p.join() 53 | fpath = os.path.join(evalDir, 'dump-{}.json'.format(args.resume_weights)) 54 | misc_utils.save_json_lines(all_results, fpath) 55 | # evaluation 56 | eval_path = os.path.join(evalDir, 'eval-{}.json'.format(args.resume_weights)) 57 | eval_fid = open(eval_path,'w') 58 | res_line, JI = compute_JI.evaluation_all(fpath, 'box') 59 | for line in res_line: 60 | eval_fid.write(line+'\n') 61 | AP, MR = compute_APMR.compute_APMR(fpath, config.eval_source, 'box') 62 | line = 'AP:{:.4f}, MR:{:.4f}, JI:{:.4f}.'.format(AP, MR, JI) 63 | print(line) 64 | eval_fid.write(line+'\n') 65 | eval_fid.close() 66 | 67 | def inference(config, network, model_file, device, dataset, start, end, result_queue): 68 | torch.set_default_tensor_type('torch.FloatTensor') 69 | torch.multiprocessing.set_sharing_strategy('file_system') 70 | # init model 71 | net = network() 72 | net.cuda(device) 73 | net = net.eval() 74 | check_point = torch.load(model_file) 75 | net.load_state_dict(check_point['state_dict']) 76 | # init data 77 | dataset.records = dataset.records[start:end]; 78 | data_iter = torch.utils.data.DataLoader(dataset=dataset, shuffle=False) 79 | # inference 80 | for (image, gt_boxes, im_info, ID) in data_iter: 81 | pred_boxes = net(image.cuda(device), im_info.cuda(device)) 82 | scale = im_info[0, 2] 83 | if config.test_nms_method == 'set_nms': 84 | assert pred_boxes.shape[-1] > 6, "Not EMD Network! Using normal_nms instead." 85 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!" 86 | top_k = pred_boxes.shape[-1] // 6 87 | n = pred_boxes.shape[0] 88 | pred_boxes = pred_boxes.reshape(-1, 6) 89 | idents = np.tile(np.arange(n)[:,None], (1, top_k)).reshape(-1, 1) 90 | pred_boxes = np.hstack((pred_boxes, idents)) 91 | keep = pred_boxes[:, 4] > config.pred_cls_threshold 92 | pred_boxes = pred_boxes[keep] 93 | keep = nms_utils.set_cpu_nms(pred_boxes, 0.5) 94 | pred_boxes = pred_boxes[keep] 95 | elif config.test_nms_method == 'normal_nms': 96 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!" 97 | pred_boxes = pred_boxes.reshape(-1, 6) 98 | keep = pred_boxes[:, 4] > config.pred_cls_threshold 99 | pred_boxes = pred_boxes[keep] 100 | keep = nms_utils.cpu_nms(pred_boxes, config.test_nms) 101 | pred_boxes = pred_boxes[keep] 102 | elif config.test_nms_method == 'none': 103 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!" 
104 | pred_boxes = pred_boxes.reshape(-1, 6) 105 | keep = pred_boxes[:, 4] > config.pred_cls_threshold 106 | pred_boxes = pred_boxes[keep] 107 | else: 108 | raise ValueError('Unknown NMS method.') 109 | #if pred_boxes.shape[0] > config.detection_per_image and \ 110 | # config.test_nms_method != 'none': 111 | # order = np.argsort(-pred_boxes[:, 4]) 112 | # order = order[:config.detection_per_image] 113 | # pred_boxes = pred_boxes[order] 114 | # recovery the scale 115 | pred_boxes[:, :4] /= scale 116 | pred_boxes[:, 2:4] -= pred_boxes[:, :2] 117 | gt_boxes = gt_boxes[0].numpy() 118 | gt_boxes[:, 2:4] -= gt_boxes[:, :2] 119 | result_dict = dict(ID=ID[0], height=int(im_info[0, -3]), width=int(im_info[0, -2]), 120 | dtboxes=boxes_dump(pred_boxes), gtboxes=boxes_dump(gt_boxes)) 121 | result_queue.put_nowait(result_dict) 122 | 123 | def boxes_dump(boxes): 124 | if boxes.shape[-1] == 7: 125 | result = [{'box':[round(i, 1) for i in box[:4]], 126 | 'score':round(float(box[4]), 5), 127 | 'tag':int(box[5]), 128 | 'proposal_num':int(box[6])} for box in boxes] 129 | elif boxes.shape[-1] == 6: 130 | result = [{'box':[round(i, 1) for i in box[:4].tolist()], 131 | 'score':round(float(box[4]), 5), 132 | 'tag':int(box[5])} for box in boxes] 133 | elif boxes.shape[-1] == 5: 134 | result = [{'box':[round(i, 1) for i in box[:4]], 135 | 'tag':int(box[4])} for box in boxes] 136 | else: 137 | raise ValueError('Unknown box dim.') 138 | return result 139 | 140 | def run_test(): 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--model_dir', '-md', default=None, required=True, type=str) 143 | parser.add_argument('--resume_weights', '-r', default=None, required=True, type=str) 144 | parser.add_argument('--devices', '-d', default='0', type=str) 145 | os.environ['NCCL_IB_DISABLE'] = '1' 146 | args = parser.parse_args() 147 | # import libs 148 | model_root_dir = os.path.join('../model/', args.model_dir) 149 | sys.path.insert(0, model_root_dir) 150 | from config import config 151 | from network import Network 152 | eval_all(args, config, Network) 153 | 154 | if __name__ == '__main__': 155 | run_test() 156 | 157 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import torch 5 | 6 | sys.path.insert(0, '../lib') 7 | sys.path.insert(0, '../model') 8 | from data.CrowdHuman import CrowdHuman 9 | from utils import misc_utils, SGD_bias 10 | 11 | class Train_config: 12 | # size 13 | world_size = 0 14 | mini_batch_size = 0 15 | iter_per_epoch = 0 16 | total_epoch = 0 17 | # learning 18 | warm_iter = 0 19 | learning_rate = 0 20 | momentum = 0 21 | weight_decay = 0 22 | lr_decay = [] 23 | # model 24 | log_dump_interval = 0 25 | resume_weights = None 26 | init_weights = None 27 | model_dir = '' 28 | log_path = '' 29 | 30 | def do_train_epoch(net, data_iter, optimizer, rank, epoch, train_config): 31 | if rank == 0: 32 | fid_log = open(train_config.log_path,'a') 33 | if epoch >= train_config.lr_decay[0]: 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = train_config.learning_rate / 10 36 | if epoch >= train_config.lr_decay[1]: 37 | for param_group in optimizer.param_groups: 38 | param_group['lr'] = train_config.learning_rate / 100 39 | for (images, gt_boxes, im_info), step in zip(data_iter, range(0, train_config.iter_per_epoch)): 40 | if images is None: 41 | continue 42 | # warm up 43 | if epoch == 1 and step < 
train_config.warm_iter: 44 | alpha = step / train_config.warm_iter 45 | lr_new = 0.33 * train_config.learning_rate + \ 46 | 0.67 * alpha * train_config.learning_rate 47 | for group in optimizer.param_groups: 48 | group['lr'] = lr_new 49 | elif epoch == 1 and step == train_config.warm_iter: 50 | for group in optimizer.param_groups: 51 | group['lr'] = train_config.learning_rate 52 | # get training data 53 | optimizer.zero_grad() 54 | # forward 55 | outputs = net(images.cuda(rank), im_info.cuda(rank), gt_boxes.cuda(rank)) 56 | # collect the loss 57 | total_loss = sum([outputs[key].mean() for key in outputs.keys()]) 58 | assert torch.isfinite(total_loss).all(), outputs 59 | total_loss.backward() 60 | optimizer.step() 61 | # statistics 62 | if rank == 0: 63 | if step % train_config.log_dump_interval == 0: 64 | statistic_total_loss = total_loss.item() 65 | line = 'Epoch:{}, iter:{}, lr:{:.5f}, loss is {:.4f}.'.format( 66 | epoch, step, optimizer.param_groups[0]['lr'], statistic_total_loss) 67 | print(line) 68 | print(outputs) 69 | fid_log.write(line+'\n') 70 | fid_log.write(str(outputs)+'\n') 71 | fid_log.flush() 72 | if rank == 0: 73 | fid_log.close() 74 | 75 | def train_worker(rank, train_config, network, config): 76 | # set the parallel 77 | torch.distributed.init_process_group(backend='nccl', 78 | init_method='env://', world_size=train_config.world_size, rank=rank) 79 | # initialize model 80 | net = network() 81 | # load the pretrained backbone 82 | backbone_dict = torch.load(train_config.init_weights) 83 | del backbone_dict['state_dict']['fc.weight'] 84 | del backbone_dict['state_dict']['fc.bias'] 85 | net.resnet50.load_state_dict(backbone_dict['state_dict']) 86 | net.cuda(rank) 87 | begin_epoch = 1 88 | # build optimizer 89 | #optimizer = SGD_bias.SGD(net.parameters(), 90 | optimizer = torch.optim.SGD(net.parameters(), 91 | lr=train_config.learning_rate, momentum=train_config.momentum, 92 | weight_decay=train_config.weight_decay) 93 | if train_config.resume_weights: 94 | model_file = os.path.join(train_config.model_dir, 95 | 'dump-{}.pth'.format(train_config.resume_weights)) 96 | check_point = torch.load(model_file, map_location=torch.device('cpu')) 97 | net.load_state_dict(check_point['state_dict']) 98 | begin_epoch = train_config.resume_weights + 1 99 | # using distributed data parallel 100 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank], broadcast_buffers=False) 101 | # build data provider 102 | crowdhuman = CrowdHuman(config, if_train=True) 103 | data_iter = torch.utils.data.DataLoader(dataset=crowdhuman, 104 | batch_size=train_config.mini_batch_size, 105 | num_workers=2, 106 | collate_fn=crowdhuman.merge_batch, 107 | shuffle=True) 108 | for epoch_id in range(begin_epoch, train_config.total_epoch+1): 109 | do_train_epoch(net, data_iter, optimizer, rank, epoch_id, train_config) 110 | if rank == 0: 111 | # save the model 112 | fpath = os.path.join(train_config.model_dir, 113 | 'dump-{}.pth'.format(epoch_id)) 114 | model = dict(epoch = epoch_id, 115 | state_dict = net.module.state_dict(), 116 | optimizer = optimizer.state_dict()) 117 | torch.save(model,fpath) 118 | 119 | def multi_train(params, config, network): 120 | # check gpus 121 | if not torch.cuda.is_available(): 122 | print('No GPU exists!') 123 | return 124 | else: 125 | num_gpus = torch.cuda.device_count() 126 | torch.set_default_tensor_type('torch.FloatTensor') 127 | # setting training config 128 | train_config = Train_config() 129 | train_config.world_size = num_gpus 130 | train_config.total_epoch = config.max_epoch 131
| train_config.iter_per_epoch = \ 132 | config.nr_images_epoch // (num_gpus * config.train_batch_per_gpu) 133 | train_config.mini_batch_size = config.train_batch_per_gpu 134 | train_config.warm_iter = config.warm_iter 135 | train_config.learning_rate = \ 136 | config.base_lr * config.train_batch_per_gpu * num_gpus 137 | train_config.momentum = config.momentum 138 | train_config.weight_decay = config.weight_decay 139 | train_config.lr_decay = config.lr_decay 140 | train_config.model_dir = os.path.join('../model/', params.model_dir, config.model_dir) 141 | line = 'network.lr.{}.train.{}'.format( 142 | train_config.learning_rate, train_config.total_epoch) 143 | train_config.log_path = os.path.join('../model/', params.model_dir, config.output_dir, line+'.log') 144 | train_config.resume_weights = params.resume_weights 145 | train_config.init_weights = config.init_weights 146 | train_config.log_dump_interval = config.log_dump_interval 147 | misc_utils.ensure_dir(train_config.model_dir) 148 | # print the training config 149 | line = 'Num of GPUs:{}, learning rate:{:.5f}, mini batch size:{}, \ 150 | \ntrain_epoch:{}, iter_per_epoch:{}, decay_epoch:{}'.format( 151 | num_gpus, train_config.learning_rate, train_config.mini_batch_size, 152 | train_config.total_epoch, train_config.iter_per_epoch, train_config.lr_decay) 153 | print(line) 154 | print("Init multi-processing training...") 155 | torch.multiprocessing.spawn(train_worker, nprocs=num_gpus, args=(train_config, network, config)) 156 | 157 | def run_train(): 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument('--model_dir', '-md', default=None,required=True,type=str) 160 | parser.add_argument('--resume_weights', '-r', default=None,type=int) 161 | os.environ['MASTER_ADDR'] = '127.0.0.1' 162 | os.environ['MASTER_PORT'] = '8888' 163 | os.environ['NCCL_IB_DISABLE'] = '1' 164 | #os.environ['NCCL_DEBUG'] = 'INFO' 165 | args = parser.parse_args() 166 | # import libs 167 | model_root_dir = os.path.join('../model/', args.model_dir) 168 | sys.path.insert(0, model_root_dir) 169 | from config import config 170 | from network import Network 171 | multi_train(args, config, Network) 172 | 173 | if __name__ == '__main__': 174 | run_train() 175 | 176 | -------------------------------------------------------------------------------- /tools/visulize_json.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | 6 | sys.path.insert(0, '../lib') 7 | from utils import misc_utils, visual_utils 8 | 9 | img_root = '/data/CrowdHuman/images/' 10 | def eval_all(args): 11 | # json file 12 | assert os.path.exists(args.json_file), "Wrong json path!" 
13 | misc_utils.ensure_dir('outputs') 14 | records = misc_utils.load_json_lines(args.json_file)[:args.number] 15 | for record in records: 16 | dtboxes = misc_utils.load_bboxes( 17 | record, key_name='dtboxes', key_box='box', key_score='score', key_tag='tag') 18 | gtboxes = misc_utils.load_bboxes(record, 'gtboxes', 'box') 19 | dtboxes = misc_utils.xywh_to_xyxy(dtboxes) 20 | gtboxes = misc_utils.xywh_to_xyxy(gtboxes) 21 | keep = dtboxes[:, -2] > args.visual_thresh 22 | dtboxes = dtboxes[keep] 23 | len_dt = len(dtboxes) 24 | len_gt = len(gtboxes) 25 | line = "{}: dt:{}, gt:{}.".format(record['ID'], len_dt, len_gt) 26 | print(line) 27 | img_path = img_root + record['ID'] + '.png' 28 | img = misc_utils.load_img(img_path) 29 | visual_utils.draw_boxes(img, dtboxes, line_thick=1, line_color='blue') 30 | visual_utils.draw_boxes(img, gtboxes, line_thick=1, line_color='white') 31 | fpath = 'outputs/{}.png'.format(record['ID']) 32 | cv2.imwrite(fpath, img) 33 | 34 | 35 | def run_eval(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--json_file', '-f', default=None, required=True, type=str) 38 | parser.add_argument('--number', '-n', default=3, type=int) 39 | parser.add_argument('--visual_thresh', '-v', default=0.3, type=float) 40 | args = parser.parse_args() 41 | eval_all(args) 42 | 43 | if __name__ == '__main__': 44 | run_eval() 45 | --------------------------------------------------------------------------------