├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── lib
│   ├── backbone
│   │   ├── fpn.py
│   │   └── resnet50.py
│   ├── data
│   │   └── CrowdHuman.py
│   ├── det_oprs
│   │   ├── anchors_generator.py
│   │   ├── bbox_opr.py
│   │   ├── cascade_roi_target.py
│   │   ├── find_top_rpn_proposals.py
│   │   ├── fpn_anchor_target.py
│   │   ├── fpn_roi_target.py
│   │   ├── loss_opr.py
│   │   ├── retina_anchor_target.py
│   │   └── utils.py
│   ├── evaluate
│   │   ├── APMRToolkits
│   │   │   ├── __init__.py
│   │   │   ├── database.py
│   │   │   └── image.py
│   │   ├── JIToolkits
│   │   │   ├── JI_tools.py
│   │   │   └── matching.py
│   │   ├── __init__.py
│   │   ├── compute_APMR.py
│   │   └── compute_JI.py
│   ├── layers
│   │   ├── batch_norm.py
│   │   └── pooler.py
│   ├── module
│   │   └── rpn.py
│   └── utils
│       ├── SGD_bias.py
│       ├── misc_utils.py
│       ├── nms_utils.py
│       └── visual_utils.py
├── model
│   ├── rcnn_emd_refine
│   │   ├── config.py
│   │   └── network.py
│   ├── rcnn_emd_simple
│   │   ├── config.py
│   │   └── network.py
│   ├── rcnn_fpn_baseline
│   │   ├── config.py
│   │   └── network.py
│   ├── retina_emd_simple
│   │   ├── config.py
│   │   └── network.py
│   └── retina_fpn_baseline
│       ├── config.py
│       └── network.py
├── requirements.txt
└── tools
    ├── eval_json.py
    ├── inference.py
    ├── test.py
    ├── train.py
    └── visulize_json.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 | **/__pycache__
3 | **/outputs
4 | **/model_dump
5 |
6 | **/*.out
7 | **/*.json
8 | **/*.pkl
9 | **/*.pth
10 | **/*.log
11 |
12 | **/*.jpg
13 | **/*.jpeg
14 | **/*.png
15 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PYTORCH="1.5"
2 | ARG CUDA="10.1"
3 | ARG CUDNN="7"
4 |
5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6 |
7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
9 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
10 |
11 | RUN apt-get update && apt-get install -y vim git ninja-build libglib2.0-0 libsm6 libgl1-mesa-dev libxrender-dev libxext6 \
12 | && apt-get clean \
13 | && rm -rf /var/lib/apt/lists/*
14 |
15 | # Install mmdetection
16 | #RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection
17 | WORKDIR /crowddet
18 | ADD . .
19 | RUN pip install --default-timeout=1800 -r requirements.txt --ignore-installed
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Megvii Technology
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Detection in Crowded Scenes: One Proposal, Multiple Predictions
2 |
3 | This is the PyTorch implementation of our paper "[Detection in Crowded Scenes: One Proposal, Multiple Predictions](https://openaccess.thecvf.com/content_CVPR_2020/html/Chu_Detection_in_Crowded_Scenes_One_Proposal_Multiple_Predictions_CVPR_2020_paper.html)" (https://arxiv.org/abs/2003.09163), published in CVPR 2020.
4 |
5 | Our method aims at detecting highly-overlapped instances in crowded scenes.
6 |
7 | The key of our approach is to let each proposal predict a set of instances that might be highly overlapped, rather than a single instance as in previous proposal-based frameworks. With this scheme, the predictions of nearby proposals are expected to infer the **same set** of instances rather than **distinguishing individuals**, which is much easier to learn. Equipped with new techniques such as EMD Loss and Set NMS, our detector can effectively handle the difficulty of detecting highly overlapped objects.
8 |
9 | The network structure and results are shown here:
10 |
11 |
12 |
13 |
14 | # Citation
15 |
16 | If you use the code in your research, please cite:
17 | ```
18 | @InProceedings{Chu_2020_CVPR,
19 | author = {Chu, Xuangeng and Zheng, Anlin and Zhang, Xiangyu and Sun, Jian},
20 | title = {Detection in Crowded Scenes: One Proposal, Multiple Predictions},
21 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
22 | month = {June},
23 | year = {2020}
24 | }
25 | ```
26 |
27 | # Run
28 | 1. Set up the environment with Docker
29 | - Requirements: Install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
30 | - Create docker image:
31 | ```shell
32 | sudo docker build . -t crowddet
33 | ```
34 | - Run docker image:
35 | ```shell
36 | sudo docker run --gpus all --shm-size=8g -it --rm crowddet
37 | ```
38 |
39 | 2. CrowdHuman data:
40 | * CrowdHuman is a benchmark dataset to better evaluate detectors in crowd scenarios. The dataset can be downloaded from http://www.crowdhuman.org/. The path of the dataset is set in `config.py`.
41 |
42 | 3. Steps to run:
43 | * Step1: training. More training and testing settings can be set in `config.py`.
44 | ```
45 | cd tools
46 | python3 train.py -md rcnn_fpn_baseline
47 | ```
48 |
49 | * Step2: testing. If you have four GPUs, you can use ` -d 0-3 ` to use all of your GPUs.
50 | The result json file will be evaluated automatically.
51 | ```
52 | cd tools
53 | python3 test.py -md rcnn_fpn_baseline -r 40
54 | ```
55 |
56 | * Step3: evaluating a result json file, running inference on one picture, and visualizing a result json file.
57 | ` -r ` means the resume epoch, ` -n ` means the number of pictures to visualize.
58 | ```
59 | cd tools
60 | python3 eval_json.py -f your_json_path.json
61 | python3 inference.py -md rcnn_fpn_baseline -r 40 -i your_image_path.png
62 | python3 visulize_json.py -f your_json_path.json -n 3
63 | ```
64 |
65 | # Models
66 |
67 | We used MegEngine in the original research (https://github.com/megvii-model/CrowdDetection); this project is a re-implementation based on PyTorch.
68 |
69 | We use a pre-trained model from [MegEngine Model Hub](https://megengine.org.cn/model-hub) and convert it to PyTorch. You can get this model from [here](https://drive.google.com/file/d/1lfYQHC63oM2Dynbfj6uD7XnpDIaA5kNr/view?usp=sharing).
70 | These models can also be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1U3I-qNIrXuYQzUEDDdISTw)(code:yx46).
71 | | Model | Top1 acc | Top5 acc |
72 | | --- | --- | --- |
73 | | ResNet50 | 76.254 | 93.056 |
74 |
75 | All models are based on ResNet-50 FPN.
76 | | | AP | MR | JI | Model |
77 | | --- | --- | --- | --- | --- |
78 | | RCNN FPN Baseline (convert from MegEngine) | 0.8718 | 0.4239 | 0.7949 | [rcnn_fpn_baseline_mge.pth](https://drive.google.com/file/d/19LBc_6vizKr06Wky0s7TAnvlqP8PjSA_/view?usp=sharing) |
79 | | RCNN EMD Simple (convert from MegEngine) | 0.9052 | 0.4196 | 0.8209 | [rcnn_emd_simple_mge.pth](https://drive.google.com/file/d/1f_vjFrjTxXYR5nPnYZRrU-yffYTGUnyL/view?usp=sharing) |
80 | | RCNN EMD with RM (convert from MegEngine) | 0.9097 | 0.4102 | 0.8271 | [rcnn_emd_refine_mge.pth](https://drive.google.com/file/d/1qYJ0b7QsYZsP5_8yIjya_kj_tu90ALDJ/view?usp=sharing) |
81 | | RCNN FPN Baseline (trained with PyTorch) | 0.8665 | 0.4243 | 0.7949 | [rcnn_fpn_baseline.pth](https://drive.google.com/file/d/10poBJ1qwlV0iS6i_lnbw9cbdt4tpTvh1/view?usp=sharing) |
82 | | RCNN EMD Simple (trained with PyTorch) | 0.8997 | 0.4167 | 0.8225 | [rcnn_emd_simple.pth](https://drive.google.com/file/d/1Rryeqz5sMWTTm3epEfqDpK1H8EsPvlLe/view?usp=sharing) |
83 | | RCNN EMD with RM (trained with PyTorch) | 0.9030 | 0.4128 | 0.8263 | [rcnn_emd_refine.pth](https://drive.google.com/file/d/1jk_b7Ws528uCfEgOesLS_iBsqXcHl2Ju/view?usp=sharing) |
84 | | RetinaNet FPN Baseline | 0.8188 | 0.5644 | 0.7316 | [retina_fpn_baseline.pth](https://drive.google.com/file/d/1w1CmE4MfYB4NT5Uyx85dPkR87gEFXhBJ/view?usp=sharing) |
85 | | RetinaNet EMD Simple | 0.8292 | 0.5481 | 0.7393 | [retina_emd_simple.pth](https://drive.google.com/file/d/1LwUlTf4YAH3wp-HXAAuXyTD11SeDgwhE/view?usp=sharing) |
86 |
89 |
90 | # Contact
91 |
92 | If you have any questions, please do not hesitate to contact Xuangeng Chu (xg_chu@pku.edu.cn).
93 |
--------------------------------------------------------------------------------
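Note on Set NMS: the README describes Set NMS only in prose. Below is a minimal, self-contained sketch of the rule that boxes predicted from the same proposal never suppress each other, while ordinary greedy NMS applies otherwise. It is not the repository's implementation (that lives in `lib/utils/nms_utils.py`, which is not included in this section); the array layout and the `iou` helper are assumptions made for this example.

```python
import numpy as np

def set_nms(boxes, scores, proposal_ids, iou_thresh=0.5):
    """Greedy NMS, except boxes from the same proposal never suppress each other.
    boxes: (N, 4) [x1, y1, x2, y2]; scores: (N,); proposal_ids: (N,) source-proposal id."""
    order = scores.argsort()[::-1]                  # visit boxes from high to low score
    suppressed = np.zeros(len(boxes), dtype=bool)
    keep = []
    for i in order:
        if suppressed[i]:
            continue
        keep.append(i)
        for j in order:
            if suppressed[j] or j == i:
                continue
            if proposal_ids[i] == proposal_ids[j]:  # same proposal: exempt from suppression
                continue
            if iou(boxes[i], boxes[j]) > iou_thresh:
                suppressed[j] = True
    return keep

def iou(a, b):
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-9)
```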
/lib/backbone/fpn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | class FPN(nn.Module):
8 | """
9 | This module implements Feature Pyramid Network.
10 | It creates pyramid features built on top of some input feature maps.
11 | """
12 | def __init__(self, bottom_up, layers_begin, layers_end):
13 | super(FPN, self).__init__()
14 | assert layers_begin > 1 and layers_begin < 6
15 |         assert layers_end > 4 and layers_end < 8
16 | in_channels = [256, 512, 1024, 2048]
17 | fpn_dim = 256
18 | in_channels = in_channels[layers_begin-2:]
19 |
20 | lateral_convs = nn.ModuleList()
21 | output_convs = nn.ModuleList()
22 |         for idx, in_channel in enumerate(in_channels):
23 |             lateral_conv = nn.Conv2d(in_channel, fpn_dim, kernel_size=1)
24 | output_conv = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=1, padding=1)
25 | nn.init.kaiming_normal_(lateral_conv.weight, mode='fan_out')
26 | nn.init.constant_(lateral_conv.bias, 0)
27 | nn.init.kaiming_normal_(output_conv.weight, mode='fan_out')
28 | nn.init.constant_(output_conv.bias, 0)
29 | lateral_convs.append(lateral_conv)
30 | output_convs.append(output_conv)
31 |
32 | self.lateral_convs = lateral_convs[::-1]
33 | self.output_convs = output_convs[::-1]
34 | self.bottom_up = bottom_up
35 | self.output_b = layers_begin
36 | self.output_e = layers_end
37 | if self.output_e == 7:
38 | self.p6 = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=2, padding=1)
39 | self.p7 = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, stride=2, padding=1)
40 | for l in [self.p6, self.p7]:
41 | nn.init.kaiming_uniform_(l.weight, a=1) # pyre-ignore
42 | nn.init.constant_(l.bias, 0)
43 |
44 | def forward(self, x):
45 | bottom_up_features = self.bottom_up(x)
46 | bottom_up_features = bottom_up_features[self.output_b - 2:]
47 | bottom_up_features = bottom_up_features[::-1]
48 | results = []
49 | prev_features = self.lateral_convs[0](bottom_up_features[0])
50 | results.append(self.output_convs[0](prev_features))
51 | for l_id, (features, lateral_conv, output_conv) in enumerate(zip(
52 | bottom_up_features[1:], self.lateral_convs[1:], self.output_convs[1:])):
53 | top_down_features = F.interpolate(prev_features, scale_factor=2, mode="bilinear", align_corners=False)
54 | lateral_features = lateral_conv(features)
55 | prev_features = lateral_features + top_down_features
56 | results.append(output_conv(prev_features))
57 | if(self.output_e == 6):
58 | p6 = F.max_pool2d(results[0], kernel_size=1, stride=2, padding=0)
59 | results.insert(0, p6)
60 | elif(self.output_e == 7):
61 | p6 = self.p6(results[0])
62 | results.insert(0, p6)
63 | p7 = self.p7(F.relu(results[0]))
64 | results.insert(0, p7)
65 | return results
66 |
67 |
--------------------------------------------------------------------------------
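A quick shape check for the FPN above, as a minimal sketch rather than repository code: it assumes `lib/` is importable (the `tools/` scripts put it on `sys.path`) and uses pyramid levels 2..6 purely as an example.

```python
import torch
from backbone.resnet50 import ResNet50   # assumes lib/ is on sys.path
from backbone.fpn import FPN

backbone = ResNet50(freeze_at=2)          # random init is enough for a shape check
fpn = FPN(backbone, layers_begin=2, layers_end=6)
feats = fpn(torch.randn(1, 3, 224, 224))
for f in feats:                           # returned coarsest-first: P6, P5, P4, P3, P2
    print(f.shape)
# 224x224 input -> [1, 256, 4, 4], [1, 256, 7, 7], [1, 256, 14, 14], [1, 256, 28, 28], [1, 256, 56, 56]
```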
/lib/backbone/resnet50.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from layers.batch_norm import FrozenBatchNorm2d
6 |
7 | class Bottleneck(nn.Module):
8 | def __init__(self, in_cha, neck_cha, out_cha, stride, has_bias=False):
9 | super(Bottleneck, self).__init__()
10 |
11 | self.downsample = None
12 |         if in_cha != out_cha or stride != 1:
13 | self.downsample = nn.Sequential(
14 | nn.Conv2d(in_cha, out_cha, kernel_size=1, stride=stride, bias=has_bias),
15 | FrozenBatchNorm2d(out_cha),
16 | )
17 |
18 | # The original MSRA ResNet models have stride in the first 1x1 conv
19 | # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations
20 | # have stride in the 3x3 conv
21 | self.conv1 = nn.Conv2d(in_cha, neck_cha, kernel_size=1, stride=1, bias=has_bias)
22 | self.bn1 = FrozenBatchNorm2d(neck_cha)
23 |
24 | self.conv2 = nn.Conv2d(neck_cha, neck_cha, kernel_size=3, stride=stride, padding=1, bias=has_bias)
25 | self.bn2 = FrozenBatchNorm2d(neck_cha)
26 |
27 | self.conv3 = nn.Conv2d(neck_cha, out_cha, kernel_size=1, bias=has_bias)
28 | self.bn3 = FrozenBatchNorm2d(out_cha)
29 |
30 | def forward(self, x):
31 | identity = x
32 |
33 | x = self.conv1(x)
34 | x = self.bn1(x)
35 | x = F.relu_(x)
36 | x = self.conv2(x)
37 | x = self.bn2(x)
38 | x = F.relu_(x)
39 | x = self.conv3(x)
40 | x = self.bn3(x)
41 | if self.downsample is not None:
42 | identity = self.downsample(identity)
43 |
44 | x += identity
45 | x = F.relu_(x)
46 | return x
47 |
48 |
49 | class ResNet50(nn.Module):
50 | def __init__(self, freeze_at, has_bias=False):
51 | super(ResNet50, self).__init__()
52 | self.has_bias = has_bias
53 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=has_bias)
54 | self.bn1 = FrozenBatchNorm2d(64)
55 |
56 | block_counts = [3, 4, 6, 3]
57 | bottleneck_channels_list = [64, 128, 256, 512]
58 | out_channels_list = [256, 512, 1024, 2048]
59 | stride_list = [1, 2, 2, 2]
60 | in_channels = 64
61 |
62 | self.layer1 = self._make_layer(block_counts[0], 64,
63 | bottleneck_channels_list[0], out_channels_list[0], stride_list[0])
64 | self.layer2 = self._make_layer(block_counts[1], out_channels_list[0],
65 | bottleneck_channels_list[1], out_channels_list[1], stride_list[1])
66 | self.layer3 = self._make_layer(block_counts[2], out_channels_list[1],
67 | bottleneck_channels_list[2], out_channels_list[2], stride_list[2])
68 | self.layer4 = self._make_layer(block_counts[3], out_channels_list[2],
69 | bottleneck_channels_list[3], out_channels_list[3], stride_list[3])
70 |
71 | for l in self.modules():
72 | if isinstance(l, nn.Conv2d):
73 | nn.init.kaiming_normal_(l.weight, mode='fan_out')
74 | if self.has_bias:
75 | nn.init.constant_(l.bias, 0)
76 |
77 | self._freeze_backbone(freeze_at)
78 |
79 | def _make_layer(self, num_blocks, in_channels, bottleneck_channels, out_channels, stride):
80 | layers = []
81 | for _ in range(num_blocks):
82 | layers.append(Bottleneck(in_channels, bottleneck_channels, out_channels, stride, self.has_bias))
83 | stride = 1
84 | in_channels = out_channels
85 | return nn.Sequential(*layers)
86 |
87 | def _freeze_backbone(self, freeze_at):
88 | if freeze_at < 0:
89 | return
90 | if freeze_at >= 1:
91 | for p in self.conv1.parameters():
92 | p.requires_grad = False
93 | if freeze_at >= 2:
94 | for p in self.layer1.parameters():
95 | p.requires_grad = False
96 | if freeze_at >= 3:
97 |             print("Freezing too many layers! Only the first 2 layers are frozen.")
98 |
99 | def forward(self, x):
100 | outputs = []
101 | # stem
102 | x = self.conv1(x)
103 | x = self.bn1(x)
104 | x = F.relu_(x)
105 | x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
106 | # blocks
107 | x = self.layer1(x)
108 | outputs.append(x)
109 | x = self.layer2(x)
110 | outputs.append(x)
111 | x = self.layer3(x)
112 | outputs.append(x)
113 | x = self.layer4(x)
114 | outputs.append(x)
115 | return outputs
116 |
117 |
--------------------------------------------------------------------------------
/lib/data/CrowdHuman.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 |
6 | from utils import misc_utils
7 |
8 | class CrowdHuman(torch.utils.data.Dataset):
9 | def __init__(self, config, if_train):
10 | if if_train:
11 | self.training = True
12 | source = config.train_source
13 | self.short_size = config.train_image_short_size
14 | self.max_size = config.train_image_max_size
15 | else:
16 | self.training = False
17 | source = config.eval_source
18 | self.short_size = config.eval_image_short_size
19 | self.max_size = config.eval_image_max_size
20 | self.records = misc_utils.load_json_lines(source)
21 | self.config = config
22 |
23 | def __getitem__(self, index):
24 | return self.load_record(self.records[index])
25 |
26 | def __len__(self):
27 | return len(self.records)
28 |
29 | def load_record(self, record):
30 | if self.training:
31 | if_flap = np.random.randint(2) == 1
32 | else:
33 | if_flap = False
34 | # image
35 | image_path = os.path.join(self.config.image_folder, record['ID']+'.png')
36 | image = misc_utils.load_img(image_path)
37 | image_h = image.shape[0]
38 | image_w = image.shape[1]
39 | if if_flap:
40 | image = cv2.flip(image, 1)
41 | if self.training:
42 | # ground_truth
43 | gtboxes = misc_utils.load_gt(record, 'gtboxes', 'fbox', self.config.class_names)
44 |             keep = (gtboxes[:, 2] >= 0) * (gtboxes[:, 3] >= 0)
45 |             gtboxes = gtboxes[keep, :]
46 | gtboxes[:, 2:4] += gtboxes[:, :2]
47 | if if_flap:
48 | gtboxes = flip_boxes(gtboxes, image_w)
49 | # im_info
50 | nr_gtboxes = gtboxes.shape[0]
51 | im_info = np.array([0, 0, 1, image_h, image_w, nr_gtboxes])
52 | return image, gtboxes, im_info
53 | else:
54 | # image
55 | t_height, t_width, scale = target_size(
56 | image_h, image_w, self.short_size, self.max_size)
57 | # INTER_CUBIC, INTER_LINEAR, INTER_NEAREST, INTER_AREA, INTER_LANCZOS4
58 | resized_image = cv2.resize(image, (t_width, t_height), interpolation=cv2.INTER_LINEAR)
59 | resized_image = resized_image.transpose(2, 0, 1)
60 | image = torch.tensor(resized_image).float()
61 | gtboxes = misc_utils.load_gt(record, 'gtboxes', 'fbox', self.config.class_names)
62 | gtboxes[:, 2:4] += gtboxes[:, :2]
63 | gtboxes = torch.tensor(gtboxes)
64 | # im_info
65 | nr_gtboxes = gtboxes.shape[0]
66 | im_info = torch.tensor([t_height, t_width, scale, image_h, image_w, nr_gtboxes])
67 | return image, gtboxes, im_info, record['ID']
68 |
69 | def merge_batch(self, data):
70 | # image
71 | images = [it[0] for it in data]
72 | gt_boxes = [it[1] for it in data]
73 | im_info = np.array([it[2] for it in data])
74 | batch_height = np.max(im_info[:, 3])
75 | batch_width = np.max(im_info[:, 4])
76 | padded_images = [pad_image(
77 | im, batch_height, batch_width, self.config.image_mean) for im in images]
78 | t_height, t_width, scale = target_size(
79 | batch_height, batch_width, self.short_size, self.max_size)
80 | # INTER_CUBIC, INTER_LINEAR, INTER_NEAREST, INTER_AREA, INTER_LANCZOS4
81 | resized_images = np.array([cv2.resize(
82 | im, (t_width, t_height), interpolation=cv2.INTER_LINEAR) for im in padded_images])
83 | resized_images = resized_images.transpose(0, 3, 1, 2)
84 | images = torch.tensor(resized_images).float()
85 | # ground_truth
86 | ground_truth = []
87 | for it in gt_boxes:
88 | gt_padded = np.zeros((self.config.max_boxes_of_image, self.config.nr_box_dim))
89 | it[:, 0:4] *= scale
90 | max_box = min(self.config.max_boxes_of_image, len(it))
91 | gt_padded[:max_box] = it[:max_box]
92 | ground_truth.append(gt_padded)
93 | ground_truth = torch.tensor(ground_truth).float()
94 | # im_info
95 | im_info[:, 0] = t_height
96 | im_info[:, 1] = t_width
97 | im_info[:, 2] = scale
98 | im_info = torch.tensor(im_info)
99 | if max(im_info[:, -1] < 2):
100 | return None, None, None
101 | else:
102 | return images, ground_truth, im_info
103 |
104 | def target_size(height, width, short_size, max_size):
105 | im_size_min = np.min([height, width])
106 | im_size_max = np.max([height, width])
107 | scale = (short_size + 0.0) / im_size_min
108 | if scale * im_size_max > max_size:
109 | scale = (max_size + 0.0) / im_size_max
110 | t_height, t_width = int(round(height * scale)), int(
111 | round(width * scale))
112 | return t_height, t_width, scale
113 |
114 | def flip_boxes(boxes, im_w):
115 | flip_boxes = boxes.copy()
116 | for i in range(flip_boxes.shape[0]):
117 | flip_boxes[i, 0] = im_w - boxes[i, 2] - 1
118 | flip_boxes[i, 2] = im_w - boxes[i, 0] - 1
119 | return flip_boxes
120 |
121 | def pad_image(img, height, width, mean_value):
122 | o_h, o_w, _ = img.shape
123 | margins = np.zeros(2, np.int32)
124 | assert o_h <= height
125 | margins[0] = height - o_h
126 | img = cv2.copyMakeBorder(
127 | img, 0, margins[0], 0, 0, cv2.BORDER_CONSTANT, value=0)
128 | img[o_h:, :, :] = mean_value
129 | assert o_w <= width
130 | margins[1] = width - o_w
131 | img = cv2.copyMakeBorder(
132 | img, 0, 0, 0, margins[1], cv2.BORDER_CONSTANT, value=0)
133 | img[:, o_w:, :] = mean_value
134 | return img
135 |
--------------------------------------------------------------------------------
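Worked example of the `target_size` rule above (the sizes are made up, not CrowdHuman defaults): the image is scaled so the short side reaches `short_size`, unless that would push the long side past `max_size`.

```python
height, width = 600, 1600
short_size, max_size = 800, 1400

scale = short_size / min(height, width)       # 800 / 600 = 1.333...
if scale * max(height, width) > max_size:     # 1.333... * 1600 = 2133 > 1400
    scale = max_size / max(height, width)     # fall back to 1400 / 1600 = 0.875
t_height, t_width = round(height * scale), round(width * scale)
print(t_height, t_width, scale)               # 525 1400 0.875
```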
/lib/det_oprs/anchors_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | class AnchorGenerator():
6 | """default anchor generator for fpn.
7 | This class generate anchors by feature map in level.
8 | """
9 | def __init__(self, base_size=16, ratios=[0.5, 1, 2],
10 | base_scale=2):
11 | self.base_size = base_size
12 | self.base_scale = np.array(base_scale)
13 | self.anchor_ratios = np.array(ratios)
14 |
15 | def _whctrs(self, anchor):
16 | """convert anchor box into (w, h, ctr_x, ctr_y)
17 | """
18 | w = anchor[:, 2] - anchor[:, 0] + 1
19 | h = anchor[:, 3] - anchor[:, 1] + 1
20 | x_ctr = anchor[:, 0] + 0.5 * (w - 1)
21 | y_ctr = anchor[:, 1] + 0.5 * (h - 1)
22 | return w, h, x_ctr, y_ctr
23 |
24 | def get_plane_anchors(self, anchor_scales: np.ndarray):
25 | """get anchors per location on feature map.
26 | The anchor number is anchor_scales x anchor_ratios
27 | """
28 | base_anchor = np.array([[0, 0, self.base_size - 1, self.base_size - 1]])
29 | off = self.base_size // 2 - 8
30 | w, h, x_ctr, y_ctr = self._whctrs(base_anchor)
31 | # ratio enumerate
32 | size = w * h
33 | size_ratios = size / self.anchor_ratios
34 | ws = np.round(np.sqrt(size_ratios))
35 | hs = np.round(ws * self.anchor_ratios)
36 | # scale enumerate
37 | anchor_scales = anchor_scales[None, ...]
38 | ws = (ws[:, None] * anchor_scales).reshape(-1, 1)
39 | hs = (hs[:, None] * anchor_scales).reshape(-1, 1)
40 | # make anchors
41 | anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
42 | y_ctr - 0.5 * (hs - 1),
43 | x_ctr + 0.5 * (ws - 1),
44 | y_ctr + 0.5 * (hs - 1))) - off
45 | return anchors.astype(np.float32)
46 |
47 | def get_center_offsets(self, fm_map, stride):
48 | fm_height, fm_width = fm_map.shape[-2], fm_map.shape[-1]
49 | f_device = fm_map.device
50 | shift_x = torch.arange(0, fm_width, device=f_device) * stride
51 | shift_y = torch.arange(0, fm_height, device=f_device) * stride
52 | broad_shift_x = shift_x.reshape(-1, shift_x.shape[0]).repeat(fm_height,1)
53 | broad_shift_y = shift_y.reshape(shift_y.shape[0], -1).repeat(1,fm_width)
54 | flatten_shift_x = broad_shift_x.flatten().reshape(-1,1)
55 | flatten_shift_y = broad_shift_y.flatten().reshape(-1,1)
56 | shifts = torch.cat(
57 | [flatten_shift_x, flatten_shift_y, flatten_shift_x, flatten_shift_y],
58 | axis=1)
59 | return shifts
60 |
61 | def get_anchors_by_feature(self, fm_map, base_stride, off_stride):
62 | # shifts shape: [A, 4]
63 | shifts = self.get_center_offsets(fm_map, base_stride * off_stride)
64 | # plane_anchors shape: [B, 4], e.g. B=3
65 | plane_anchors = self.get_plane_anchors(self.base_scale * off_stride)
66 | plane_anchors = torch.tensor(plane_anchors, device=fm_map.device)
67 | # all_anchors shape: [A, B, 4]
68 | all_anchors = plane_anchors[None, :] + shifts[:, None]
69 | all_anchors = all_anchors.reshape(-1, 4)
70 | return all_anchors
71 |
72 | @torch.no_grad()
73 | def __call__(self, featmap, base_stride, off_stride):
74 | return self.get_anchors_by_feature(featmap, base_stride, off_stride)
75 |
76 |
--------------------------------------------------------------------------------
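A small usage sketch of the anchor generator above (not taken from the RPN code): with the default ratios and a single base scale, each feature-map location gets 3 anchors, so the anchor count is locations × 3.

```python
import torch
from det_oprs.anchors_generator import AnchorGenerator   # assumes lib/ is on sys.path

gen = AnchorGenerator(base_size=16, ratios=[0.5, 1, 2], base_scale=2)
fm = torch.zeros(1, 256, 2, 3)        # dummy feature map, H=2, W=3
anchors = gen(fm, 4, 2)               # base_stride=4, off_stride=2 -> stride 8 on the image
print(anchors.shape)                  # torch.Size([18, 4]) = 2*3 locations x 3 anchors
```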
/lib/det_oprs/bbox_opr.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | def filter_boxes_opr(boxes, min_size):
5 | """Remove all boxes with any side smaller than min_size."""
6 | ws = boxes[:, 2] - boxes[:, 0] + 1
7 | hs = boxes[:, 3] - boxes[:, 1] + 1
8 | keep = (ws >= min_size) * (hs >= min_size)
9 | return keep
10 |
11 | def clip_boxes_opr(boxes, im_info):
12 | """ Clip the boxes into the image region."""
13 | w = im_info[1] - 1
14 | h = im_info[0] - 1
15 | boxes[:, 0::4] = boxes[:, 0::4].clamp(min=0, max=w)
16 | boxes[:, 1::4] = boxes[:, 1::4].clamp(min=0, max=h)
17 | boxes[:, 2::4] = boxes[:, 2::4].clamp(min=0, max=w)
18 | boxes[:, 3::4] = boxes[:, 3::4].clamp(min=0, max=h)
19 | return boxes
20 |
21 | def batch_clip_proposals(proposals, im_info):
22 |     """ Clip the proposal boxes into the image region."""
23 |     w = im_info[1] - 1
24 |     h = im_info[0] - 1
25 |     proposals[:, 0::4] = proposals[:, 0::4].clamp(min=0, max=w)
26 |     proposals[:, 1::4] = proposals[:, 1::4].clamp(min=0, max=h)
27 |     proposals[:, 2::4] = proposals[:, 2::4].clamp(min=0, max=w)
28 |     proposals[:, 3::4] = proposals[:, 3::4].clamp(min=0, max=h)
29 |     return proposals
30 |
31 | def bbox_transform_inv_opr(bbox, deltas):
32 |     """ Transforms the learned deltas to the final bbox coordinates, the axis is 1"""
33 |     max_delta = math.log(1000.0 / 16)
34 | bbox_width = bbox[:, 2] - bbox[:, 0] + 1
35 | bbox_height = bbox[:, 3] - bbox[:, 1] + 1
36 | bbox_ctr_x = bbox[:, 0] + 0.5 * bbox_width
37 | bbox_ctr_y = bbox[:, 1] + 0.5 * bbox_height
38 | pred_ctr_x = bbox_ctr_x + deltas[:, 0] * bbox_width
39 | pred_ctr_y = bbox_ctr_y + deltas[:, 1] * bbox_height
40 |
41 | dw = deltas[:, 2]
42 | dh = deltas[:, 3]
43 | dw = torch.clamp(dw, max=max_delta)
44 | dh = torch.clamp(dh, max=max_delta)
45 | pred_width = bbox_width * torch.exp(dw)
46 | pred_height = bbox_height * torch.exp(dh)
47 |
48 | pred_x1 = pred_ctr_x - 0.5 * pred_width
49 | pred_y1 = pred_ctr_y - 0.5 * pred_height
50 | pred_x2 = pred_ctr_x + 0.5 * pred_width
51 | pred_y2 = pred_ctr_y + 0.5 * pred_height
52 | pred_boxes = torch.cat((pred_x1.reshape(-1, 1), pred_y1.reshape(-1, 1),
53 | pred_x2.reshape(-1, 1), pred_y2.reshape(-1, 1)), dim=1)
54 | return pred_boxes
55 |
56 | def bbox_transform_opr(bbox, gt):
57 | """ Transform the bounding box and ground truth to the loss targets.
58 | The 4 box coordinates are in axis 1"""
59 | bbox_width = bbox[:, 2] - bbox[:, 0] + 1
60 | bbox_height = bbox[:, 3] - bbox[:, 1] + 1
61 | bbox_ctr_x = bbox[:, 0] + 0.5 * bbox_width
62 | bbox_ctr_y = bbox[:, 1] + 0.5 * bbox_height
63 |
64 | gt_width = gt[:, 2] - gt[:, 0] + 1
65 | gt_height = gt[:, 3] - gt[:, 1] + 1
66 | gt_ctr_x = gt[:, 0] + 0.5 * gt_width
67 | gt_ctr_y = gt[:, 1] + 0.5 * gt_height
68 |
69 | target_dx = (gt_ctr_x - bbox_ctr_x) / bbox_width
70 | target_dy = (gt_ctr_y - bbox_ctr_y) / bbox_height
71 | target_dw = torch.log(gt_width / bbox_width)
72 | target_dh = torch.log(gt_height / bbox_height)
73 | target = torch.cat((target_dx.reshape(-1, 1), target_dy.reshape(-1, 1),
74 | target_dw.reshape(-1, 1), target_dh.reshape(-1, 1)), dim=1)
75 | return target
76 |
77 | def box_overlap_opr(box, gt):
78 | assert box.ndim == 2
79 | assert gt.ndim == 2
80 | area_box = (box[:, 2] - box[:, 0] + 1) * (box[:, 3] - box[:, 1] + 1)
81 | area_gt = (gt[:, 2] - gt[:, 0] + 1) * (gt[:, 3] - gt[:, 1] + 1)
82 | width_height = torch.min(box[:, None, 2:], gt[:, 2:]) - torch.max(
83 | box[:, None, :2], gt[:, :2]) + 1 # [N,M,2]
84 | width_height.clamp_(min=0) # [N,M,2]
85 | inter = width_height.prod(dim=2) # [N,M]
86 | del width_height
87 | # handle empty boxes
88 | iou = torch.where(
89 | inter > 0,
90 | inter / (area_box[:, None] + area_gt - inter),
91 | torch.zeros(1, dtype=inter.dtype, device=inter.device),
92 | )
93 | return iou
94 |
95 | def box_overlap_ignore_opr(box, gt, ignore_label=-1):
96 | assert box.ndim == 2
97 | assert gt.ndim == 2
98 | assert gt.shape[-1] > 4
99 | area_box = (box[:, 2] - box[:, 0] + 1) * (box[:, 3] - box[:, 1] + 1)
100 | area_gt = (gt[:, 2] - gt[:, 0] + 1) * (gt[:, 3] - gt[:, 1] + 1)
101 | width_height = torch.min(box[:, None, 2:], gt[:, 2:4]) - torch.max(
102 | box[:, None, :2], gt[:, :2]) # [N,M,2]
103 | width_height.clamp_(min=0) # [N,M,2]
104 | inter = width_height.prod(dim=2) # [N,M]
105 | del width_height
106 | # handle empty boxes
107 | iou = torch.where(
108 | inter > 0,
109 | inter / (area_box[:, None] + area_gt - inter),
110 | torch.zeros(1, dtype=inter.dtype, device=inter.device))
111 | ioa = torch.where(
112 | inter > 0,
113 | inter / (area_box[:, None]),
114 | torch.zeros(1, dtype=inter.dtype, device=inter.device))
115 | gt_ignore_mask = gt[:, 4].eq(ignore_label).repeat(box.shape[0], 1)
116 | iou *= ~gt_ignore_mask
117 | ioa *= gt_ignore_mask
118 | return iou, ioa
119 |
120 |
--------------------------------------------------------------------------------
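A round-trip check for the two box transforms above, with made-up boxes. Note that `bbox_transform_opr` encodes widths with the inclusive `+1` convention while `bbox_transform_inv_opr` decodes plain `ctr ± 0.5·size` corners, so the max corner comes back one pixel larger than the inclusive gt corner:

```python
import torch
from det_oprs.bbox_opr import bbox_transform_opr, bbox_transform_inv_opr   # assumes lib/ is on sys.path

anchor = torch.tensor([[10., 10., 50., 90.]])
gt     = torch.tensor([[12., 20., 60., 100.]])

deltas  = bbox_transform_opr(anchor, gt)          # regression targets (dx, dy, dw, dh)
decoded = bbox_transform_inv_opr(anchor, deltas)
print(decoded)                                    # tensor([[12., 20., 61., 101.]])
```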
/lib/det_oprs/cascade_roi_target.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 |
4 | import numpy as np
5 | from config import config
6 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr, box_overlap_ignore_opr
7 |
8 | @torch.no_grad()
9 | def cascade_roi_target(rpn_rois, im_info, gt_boxes, pos_threshold=0.5, top_k=1):
10 | return_rois = []
11 | return_labels = []
12 | return_bbox_targets = []
13 | # get per image proposals and gt_boxes
14 | for bid in range(config.train_batch_per_gpu):
15 | gt_boxes_perimg = gt_boxes[bid, :int(im_info[bid, 5]), :]
16 | batch_inds = torch.ones((gt_boxes_perimg.shape[0], 1)).type_as(gt_boxes_perimg) * bid
17 | gt_rois = torch.cat([batch_inds, gt_boxes_perimg[:, :4]], axis=1)
18 | batch_roi_inds = torch.nonzero(rpn_rois[:, 0] == bid, as_tuple=False).flatten()
19 | all_rois = torch.cat([rpn_rois[batch_roi_inds], gt_rois], axis=0)
20 | overlaps_normal, overlaps_ignore = box_overlap_ignore_opr(
21 | all_rois[:, 1:5], gt_boxes_perimg)
22 | overlaps_normal, overlaps_normal_indices = overlaps_normal.sort(descending=True, dim=1)
23 | overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort(descending=True, dim=1)
24 | # gt max and indices, ignore max and indices
25 | max_overlaps_normal = overlaps_normal[:, :top_k].flatten()
26 | gt_assignment_normal = overlaps_normal_indices[:, :top_k].flatten()
27 | max_overlaps_ignore = overlaps_ignore[:, :top_k].flatten()
28 | gt_assignment_ignore = overlaps_ignore_indices[:, :top_k].flatten()
29 | # cons masks
30 | ignore_assign_mask = (max_overlaps_normal < pos_threshold) * (
31 | max_overlaps_ignore > max_overlaps_normal)
32 | max_overlaps = max_overlaps_normal * ~ignore_assign_mask + \
33 | max_overlaps_ignore * ignore_assign_mask
34 | gt_assignment = gt_assignment_normal * ~ignore_assign_mask + \
35 | gt_assignment_ignore * ignore_assign_mask
36 | labels = gt_boxes_perimg[gt_assignment, 4]
37 | fg_mask = (max_overlaps >= pos_threshold) * (labels != config.ignore_label)
38 | bg_mask = (max_overlaps < config.bg_threshold_high)
39 | fg_mask = fg_mask.reshape(-1, top_k)
40 | bg_mask = bg_mask.reshape(-1, top_k)
41 | #pos_max = config.num_rois * config.fg_ratio
42 | #fg_inds_mask = subsample_masks(fg_mask[:, 0], pos_max, True)
43 | #neg_max = config.num_rois - fg_inds_mask.sum()
44 | #bg_inds_mask = subsample_masks(bg_mask[:, 0], neg_max, True)
45 | labels = labels * fg_mask.flatten()
46 | #keep_mask = fg_inds_mask + bg_inds_mask
47 | # labels
48 | labels = labels.reshape(-1, top_k)#[keep_mask]
49 | gt_assignment = gt_assignment.reshape(-1, top_k).flatten()
50 | target_boxes = gt_boxes_perimg[gt_assignment, :4]
51 | #rois = all_rois[keep_mask]
52 | target_rois = all_rois.repeat(1, top_k).reshape(-1, all_rois.shape[-1])
53 | bbox_targets = bbox_transform_opr(target_rois[:, 1:5], target_boxes)
54 | if config.rcnn_bbox_normalize_targets:
55 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(bbox_targets)
56 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(bbox_targets)
57 | minus_opr = mean_opr / std_opr
58 | bbox_targets = bbox_targets / std_opr - minus_opr
59 | bbox_targets = bbox_targets.reshape(-1, top_k * 4)
60 | return_rois.append(all_rois)
61 | return_labels.append(labels)
62 | return_bbox_targets.append(bbox_targets)
63 | if config.train_batch_per_gpu == 1:
64 |         return all_rois, labels, bbox_targets
65 | else:
66 | return_rois = torch.cat(return_rois, axis=0)
67 | return_labels = torch.cat(return_labels, axis=0)
68 | return_bbox_targets = torch.cat(return_bbox_targets, axis=0)
69 | return return_rois, return_labels, return_bbox_targets
70 |
71 |
--------------------------------------------------------------------------------
/lib/det_oprs/find_top_rpn_proposals.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from config import config
5 | from det_oprs.bbox_opr import bbox_transform_inv_opr, clip_boxes_opr, \
6 | filter_boxes_opr
7 | from torchvision.ops import nms
8 |
9 | @torch.no_grad()
10 | def find_top_rpn_proposals(is_train, rpn_bbox_offsets_list, rpn_cls_prob_list,
11 | all_anchors_list, im_info):
12 | prev_nms_top_n = config.train_prev_nms_top_n \
13 | if is_train else config.test_prev_nms_top_n
14 | post_nms_top_n = config.train_post_nms_top_n \
15 | if is_train else config.test_post_nms_top_n
16 | batch_per_gpu = config.train_batch_per_gpu if is_train else 1
17 | nms_threshold = config.rpn_nms_threshold
18 | box_min_size = config.rpn_min_box_size
19 | bbox_normalize_targets = config.rpn_bbox_normalize_targets
20 | bbox_normalize_means = config.bbox_normalize_means
21 | bbox_normalize_stds = config.bbox_normalize_stds
22 | list_size = len(rpn_bbox_offsets_list)
23 |
24 | return_rois = []
25 | return_inds = []
26 | for bid in range(batch_per_gpu):
27 | batch_proposals_list = []
28 | batch_probs_list = []
29 | for l in range(list_size):
30 | # get proposals and probs
31 | offsets = rpn_bbox_offsets_list[l][bid] \
32 | .permute(1, 2, 0).reshape(-1, 4)
33 | if bbox_normalize_targets:
34 |                 std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(offsets)
35 |                 mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(offsets)
36 |                 offsets = offsets * std_opr
37 |                 offsets = offsets + mean_opr
38 | all_anchors = all_anchors_list[l]
39 | proposals = bbox_transform_inv_opr(all_anchors, offsets)
40 | if config.anchor_within_border:
41 | proposals = clip_boxes_opr(proposals, im_info[bid, :])
42 | probs = rpn_cls_prob_list[l][bid] \
43 | .permute(1,2,0).reshape(-1, 2)
44 | probs = torch.softmax(probs, dim=-1)[:, 1]
45 | # gather the proposals and probs
46 | batch_proposals_list.append(proposals)
47 | batch_probs_list.append(probs)
48 | batch_proposals = torch.cat(batch_proposals_list, dim=0)
49 | batch_probs = torch.cat(batch_probs_list, dim=0)
50 | # filter the zero boxes.
51 | batch_keep_mask = filter_boxes_opr(
52 | batch_proposals, box_min_size * im_info[bid, 2])
53 | batch_proposals = batch_proposals[batch_keep_mask]
54 | batch_probs = batch_probs[batch_keep_mask]
55 | # prev_nms_top_n
56 | num_proposals = min(prev_nms_top_n, batch_probs.shape[0])
57 | batch_probs, idx = batch_probs.sort(descending=True)
58 | batch_probs = batch_probs[:num_proposals]
59 | topk_idx = idx[:num_proposals].flatten()
60 | batch_proposals = batch_proposals[topk_idx]
61 | # For each image, run a total-level NMS, and choose topk results.
62 | keep = nms(batch_proposals, batch_probs, nms_threshold)
63 | keep = keep[:post_nms_top_n]
64 | batch_proposals = batch_proposals[keep]
65 | #batch_probs = batch_probs[keep]
66 | # cons the rois
67 | batch_inds = torch.ones(batch_proposals.shape[0], 1).type_as(batch_proposals) * bid
68 | batch_rois = torch.cat([batch_inds, batch_proposals], axis=1)
69 | return_rois.append(batch_rois)
70 |
71 | if batch_per_gpu == 1:
72 | return batch_rois
73 | else:
74 | concated_rois = torch.cat(return_rois, axis=0)
75 | return concated_rois
76 |
--------------------------------------------------------------------------------
/lib/det_oprs/fpn_anchor_target.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr
5 | from config import config
6 |
7 | def fpn_rpn_reshape(pred_cls_score_list, pred_bbox_offsets_list):
8 | final_pred_bbox_offsets_list = []
9 | final_pred_cls_score_list = []
10 | for bid in range(config.train_batch_per_gpu):
11 | batch_pred_bbox_offsets_list = []
12 | batch_pred_cls_score_list = []
13 | for i in range(len(pred_cls_score_list)):
14 | pred_cls_score_perlvl = pred_cls_score_list[i][bid] \
15 | .permute(1, 2, 0).reshape(-1, 2)
16 | pred_bbox_offsets_perlvl = pred_bbox_offsets_list[i][bid] \
17 | .permute(1, 2, 0).reshape(-1, 4)
18 | batch_pred_cls_score_list.append(pred_cls_score_perlvl)
19 | batch_pred_bbox_offsets_list.append(pred_bbox_offsets_perlvl)
20 | batch_pred_cls_score = torch.cat(batch_pred_cls_score_list, dim=0)
21 | batch_pred_bbox_offsets = torch.cat(batch_pred_bbox_offsets_list, dim=0)
22 | final_pred_cls_score_list.append(batch_pred_cls_score)
23 | final_pred_bbox_offsets_list.append(batch_pred_bbox_offsets)
24 | final_pred_cls_score = torch.cat(final_pred_cls_score_list, dim=0)
25 | final_pred_bbox_offsets = torch.cat(final_pred_bbox_offsets_list, dim=0)
26 | return final_pred_cls_score, final_pred_bbox_offsets
27 |
28 | def fpn_anchor_target_opr_core_impl(
29 | gt_boxes, im_info, anchors, allow_low_quality_matches=True):
30 | ignore_label = config.ignore_label
31 | # get the gt boxes
32 | valid_gt_boxes = gt_boxes[:int(im_info[5]), :]
33 | valid_gt_boxes = valid_gt_boxes[valid_gt_boxes[:, -1].gt(0)]
34 | # compute the iou matrix
35 | anchors = anchors.type_as(valid_gt_boxes)
36 | overlaps = box_overlap_opr(anchors, valid_gt_boxes[:, :4])
37 | # match the dtboxes
38 | max_overlaps, argmax_overlaps = torch.max(overlaps, axis=1)
39 | #_, gt_argmax_overlaps = torch.max(overlaps, axis=0)
40 | gt_argmax_overlaps = my_gt_argmax(overlaps)
41 | del overlaps
42 | # all ignore
43 | labels = torch.ones(anchors.shape[0], device=gt_boxes.device, dtype=torch.long) * ignore_label
44 | # set negative ones
45 | labels = labels * (max_overlaps >= config.rpn_negative_overlap)
46 | # set positive ones
47 | fg_mask = (max_overlaps >= config.rpn_positive_overlap)
48 | if allow_low_quality_matches:
49 | gt_id = torch.arange(valid_gt_boxes.shape[0]).type_as(argmax_overlaps)
50 | argmax_overlaps[gt_argmax_overlaps] = gt_id
51 | max_overlaps[gt_argmax_overlaps] = 1
52 | fg_mask = (max_overlaps >= config.rpn_positive_overlap)
53 | # set positive ones
54 | fg_mask_ind = torch.nonzero(fg_mask, as_tuple=False).flatten()
55 | labels[fg_mask_ind] = 1
56 | # bbox targets
57 | bbox_targets = bbox_transform_opr(
58 | anchors, valid_gt_boxes[argmax_overlaps, :4])
59 | if config.rpn_bbox_normalize_targets:
60 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(bbox_targets)
61 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(bbox_targets)
62 | minus_opr = mean_opr / std_opr
63 | bbox_targets = bbox_targets / std_opr - minus_opr
64 | return labels, bbox_targets
65 |
66 | @torch.no_grad()
67 | def fpn_anchor_target(boxes, im_info, all_anchors_list):
68 | final_labels_list = []
69 | final_bbox_targets_list = []
70 | for bid in range(config.train_batch_per_gpu):
71 | batch_labels_list = []
72 | batch_bbox_targets_list = []
73 | for i in range(len(all_anchors_list)):
74 | anchors_perlvl = all_anchors_list[i]
75 | rpn_labels_perlvl, rpn_bbox_targets_perlvl = fpn_anchor_target_opr_core_impl(
76 | boxes[bid], im_info[bid], anchors_perlvl)
77 | batch_labels_list.append(rpn_labels_perlvl)
78 | batch_bbox_targets_list.append(rpn_bbox_targets_perlvl)
79 | # here we samples the rpn_labels
80 | concated_batch_labels = torch.cat(batch_labels_list, dim=0)
81 | concated_batch_bbox_targets = torch.cat(batch_bbox_targets_list, dim=0)
82 | # sample labels
83 | pos_idx, neg_idx = subsample_labels(concated_batch_labels,
84 | config.num_sample_anchors, config.positive_anchor_ratio)
85 | concated_batch_labels.fill_(-1)
86 | concated_batch_labels[pos_idx] = 1
87 | concated_batch_labels[neg_idx] = 0
88 |
89 | final_labels_list.append(concated_batch_labels)
90 | final_bbox_targets_list.append(concated_batch_bbox_targets)
91 | final_labels = torch.cat(final_labels_list, dim=0)
92 | final_bbox_targets = torch.cat(final_bbox_targets_list, dim=0)
93 | return final_labels, final_bbox_targets
94 |
95 | def my_gt_argmax(overlaps):
96 | gt_max_overlaps, _ = torch.max(overlaps, axis=0)
97 | gt_max_mask = overlaps == gt_max_overlaps
98 | gt_argmax_overlaps = []
99 | for i in range(overlaps.shape[-1]):
100 | gt_max_inds = torch.nonzero(gt_max_mask[:, i], as_tuple=False).flatten()
101 | gt_max_ind = gt_max_inds[torch.randperm(gt_max_inds.numel(), device=gt_max_inds.device)[0,None]]
102 | gt_argmax_overlaps.append(gt_max_ind)
103 | gt_argmax_overlaps = torch.cat(gt_argmax_overlaps)
104 | return gt_argmax_overlaps
105 |
106 | def subsample_labels(labels, num_samples, positive_fraction):
107 | positive = torch.nonzero((labels != config.ignore_label) & (labels != 0), as_tuple=False).squeeze(1)
108 | negative = torch.nonzero(labels == 0, as_tuple=False).squeeze(1)
109 |
110 | num_pos = int(num_samples * positive_fraction)
111 | num_pos = min(positive.numel(), num_pos)
112 | num_neg = num_samples - num_pos
113 | num_neg = min(negative.numel(), num_neg)
114 |
115 | # randomly select positive and negative examples
116 | perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
117 | perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
118 |
119 | pos_idx = positive[perm1]
120 | neg_idx = negative[perm2]
121 | return pos_idx, neg_idx
122 |
--------------------------------------------------------------------------------
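A usage sketch for `subsample_labels` above, with illustrative counts; labels follow the same convention as the file (1 = positive, 0 = negative, -1 = ignore) and `config.ignore_label` is assumed to be -1. The function caps positives at `num_samples * positive_fraction` and fills the remaining budget with negatives.

```python
import torch
from det_oprs.fpn_anchor_target import subsample_labels   # assumes lib/ and a model/* dir (for config) on sys.path

labels = torch.tensor([1, 1, 0, 0, 0, 0, -1, 1, 0, 0])
pos_idx, neg_idx = subsample_labels(labels, num_samples=4, positive_fraction=0.5)
print(pos_idx.numel(), neg_idx.numel())   # 2 2

labels.fill_(-1)                          # everything not sampled stays ignored
labels[pos_idx] = 1
labels[neg_idx] = 0
```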
/lib/det_oprs/fpn_roi_target.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 |
4 | import numpy as np
5 | from config import config
6 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr, box_overlap_ignore_opr
7 |
8 | @torch.no_grad()
9 | def fpn_roi_target(rpn_rois, im_info, gt_boxes, top_k=1):
10 | return_rois = []
11 | return_labels = []
12 | return_bbox_targets = []
13 | # get per image proposals and gt_boxes
14 | for bid in range(config.train_batch_per_gpu):
15 | gt_boxes_perimg = gt_boxes[bid, :int(im_info[bid, 5]), :]
16 | batch_inds = torch.ones((gt_boxes_perimg.shape[0], 1)).type_as(gt_boxes_perimg) * bid
17 | gt_rois = torch.cat([batch_inds, gt_boxes_perimg[:, :4]], axis=1)
18 | batch_roi_inds = torch.nonzero(rpn_rois[:, 0] == bid, as_tuple=False).flatten()
19 | all_rois = torch.cat([rpn_rois[batch_roi_inds], gt_rois], axis=0)
20 | overlaps_normal, overlaps_ignore = box_overlap_ignore_opr(
21 | all_rois[:, 1:5], gt_boxes_perimg)
22 | overlaps_normal, overlaps_normal_indices = overlaps_normal.sort(descending=True, dim=1)
23 | overlaps_ignore, overlaps_ignore_indices = overlaps_ignore.sort(descending=True, dim=1)
24 | # gt max and indices, ignore max and indices
25 | max_overlaps_normal = overlaps_normal[:, :top_k].flatten()
26 | gt_assignment_normal = overlaps_normal_indices[:, :top_k].flatten()
27 | max_overlaps_ignore = overlaps_ignore[:, :top_k].flatten()
28 | gt_assignment_ignore = overlaps_ignore_indices[:, :top_k].flatten()
29 | # cons masks
30 | ignore_assign_mask = (max_overlaps_normal < config.fg_threshold) * (
31 | max_overlaps_ignore > max_overlaps_normal)
32 | max_overlaps = max_overlaps_normal * ~ignore_assign_mask + \
33 | max_overlaps_ignore * ignore_assign_mask
34 | gt_assignment = gt_assignment_normal * ~ignore_assign_mask + \
35 | gt_assignment_ignore * ignore_assign_mask
36 | labels = gt_boxes_perimg[gt_assignment, 4]
37 | fg_mask = (max_overlaps >= config.fg_threshold) * (labels != config.ignore_label)
38 | bg_mask = (max_overlaps < config.bg_threshold_high) * (
39 | max_overlaps >= config.bg_threshold_low)
40 | fg_mask = fg_mask.reshape(-1, top_k)
41 | bg_mask = bg_mask.reshape(-1, top_k)
42 | pos_max = config.num_rois * config.fg_ratio
43 | fg_inds_mask = subsample_masks(fg_mask[:, 0], pos_max, True)
44 | neg_max = config.num_rois - fg_inds_mask.sum()
45 | bg_inds_mask = subsample_masks(bg_mask[:, 0], neg_max, True)
46 | labels = labels * fg_mask.flatten()
47 | keep_mask = fg_inds_mask + bg_inds_mask
48 | # labels
49 | labels = labels.reshape(-1, top_k)[keep_mask]
50 | gt_assignment = gt_assignment.reshape(-1, top_k)[keep_mask].flatten()
51 | target_boxes = gt_boxes_perimg[gt_assignment, :4]
52 | rois = all_rois[keep_mask]
53 | target_rois = rois.repeat(1, top_k).reshape(-1, all_rois.shape[-1])
54 | bbox_targets = bbox_transform_opr(target_rois[:, 1:5], target_boxes)
55 | if config.rcnn_bbox_normalize_targets:
56 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(bbox_targets)
57 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(bbox_targets)
58 | minus_opr = mean_opr / std_opr
59 | bbox_targets = bbox_targets / std_opr - minus_opr
60 | bbox_targets = bbox_targets.reshape(-1, top_k * 4)
61 | return_rois.append(rois)
62 | return_labels.append(labels)
63 | return_bbox_targets.append(bbox_targets)
64 | if config.train_batch_per_gpu == 1:
65 | return rois, labels, bbox_targets
66 | else:
67 | return_rois = torch.cat(return_rois, axis=0)
68 | return_labels = torch.cat(return_labels, axis=0)
69 | return_bbox_targets = torch.cat(return_bbox_targets, axis=0)
70 | return return_rois, return_labels, return_bbox_targets
71 |
72 | def subsample_masks(masks, num_samples, sample_value):
73 | positive = torch.nonzero(masks.eq(sample_value), as_tuple=False).squeeze(1)
74 | num_mask = len(positive)
75 | num_samples = int(num_samples)
76 | num_final_samples = min(num_mask, num_samples)
77 | num_final_negative = num_mask - num_final_samples
78 | perm = torch.randperm(num_mask, device=masks.device)[:num_final_negative]
79 | negative = positive[perm]
80 | masks[negative] = not sample_value
81 | return masks
82 |
83 |
--------------------------------------------------------------------------------
/lib/det_oprs/loss_opr.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from config import config
4 |
5 | def softmax_loss(score, label, ignore_label=-1):
6 | with torch.no_grad():
7 | max_score = score.max(axis=1, keepdims=True)[0]
8 | score -= max_score
9 | log_prob = score - torch.log(torch.exp(score).sum(axis=1, keepdims=True))
10 | mask = label != ignore_label
11 | vlabel = label * mask
12 | onehot = torch.zeros(vlabel.shape[0], config.num_classes, device=score.device)
13 | onehot.scatter_(1, vlabel.reshape(-1, 1), 1)
14 | loss = -(log_prob * onehot).sum(axis=1)
15 | loss = loss * mask
16 | return loss
17 |
18 | def smooth_l1_loss(pred, target, beta: float):
19 |     if beta < 1e-5:
20 |         loss = torch.abs(pred - target)
21 |     else:
22 |         abs_x = torch.abs(pred - target)
23 |         in_mask = abs_x < beta
24 |         loss = torch.where(in_mask, 0.5 * abs_x ** 2 / beta, abs_x - 0.5 * beta)
25 | return loss.sum(axis=1)
26 |
27 | def focal_loss(inputs, targets, alpha=-1, gamma=2):
28 | class_range = torch.arange(1, inputs.shape[1] + 1, device=inputs.device)
29 | pos_pred = (1 - inputs) ** gamma * torch.log(inputs)
30 | neg_pred = inputs ** gamma * torch.log(1 - inputs)
31 |
32 | pos_loss = (targets == class_range) * pos_pred * alpha
33 | neg_loss = (targets != class_range) * neg_pred * (1 - alpha)
34 | loss = -(pos_loss + neg_loss)
35 | return loss.sum(axis=1)
36 |
37 | def emd_loss_softmax(p_b0, p_s0, p_b1, p_s1, targets, labels):
38 | # reshape
39 | pred_delta = torch.cat([p_b0, p_b1], axis=1).reshape(-1, p_b0.shape[-1])
40 | pred_score = torch.cat([p_s0, p_s1], axis=1).reshape(-1, p_s0.shape[-1])
41 | targets = targets.reshape(-1, 4)
42 | labels = labels.long().flatten()
43 | # cons masks
44 | valid_masks = labels >= 0
45 | fg_masks = labels > 0
46 | # multiple class
47 | pred_delta = pred_delta.reshape(-1, config.num_classes, 4)
48 | fg_gt_classes = labels[fg_masks]
49 | pred_delta = pred_delta[fg_masks, fg_gt_classes, :]
50 | # loss for regression
51 | localization_loss = smooth_l1_loss(
52 | pred_delta,
53 | targets[fg_masks],
54 | config.rcnn_smooth_l1_beta)
55 | # loss for classification
56 | objectness_loss = softmax_loss(pred_score, labels)
57 | loss = objectness_loss * valid_masks
58 | loss[fg_masks] = loss[fg_masks] + localization_loss
59 | loss = loss.reshape(-1, 2).sum(axis=1)
60 | return loss.reshape(-1, 1)
61 |
62 | def emd_loss_focal(p_b0, p_s0, p_b1, p_s1, targets, labels):
63 | pred_delta = torch.cat([p_b0, p_b1], axis=1).reshape(-1, p_b0.shape[-1])
64 | pred_score = torch.cat([p_s0, p_s1], axis=1).reshape(-1, p_s0.shape[-1])
65 | targets = targets.reshape(-1, 4)
66 | labels = labels.long().reshape(-1, 1)
67 | valid_mask = (labels >= 0).flatten()
68 | objectness_loss = focal_loss(pred_score, labels,
69 | config.focal_loss_alpha, config.focal_loss_gamma)
70 | fg_masks = (labels > 0).flatten()
71 | localization_loss = smooth_l1_loss(
72 | pred_delta[fg_masks],
73 | targets[fg_masks],
74 | config.smooth_l1_beta)
75 | loss = objectness_loss * valid_mask
76 | loss[fg_masks] = loss[fg_masks] + localization_loss
77 | loss = loss.reshape(-1, 2).sum(axis=1)
78 | return loss.reshape(-1, 1)
79 |
--------------------------------------------------------------------------------
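The `emd_loss_softmax` above scores one fixed assignment of the two predictions to the two targets. In the EMD loss of the paper, both orderings are evaluated and the cheaper one is kept per proposal; that outer step happens in `model/*/network.py`, which is not part of this section. A hedged sketch of that reduction:

```python
import torch
from det_oprs.loss_opr import emd_loss_softmax   # assumes lib/ and a model/* dir (for config) on sys.path

def emd_loss_min(p_b0, p_s0, p_b1, p_s1, targets, labels):
    # assignment A: prediction 0 -> target 0, prediction 1 -> target 1
    loss_a = emd_loss_softmax(p_b0, p_s0, p_b1, p_s1, targets, labels)
    # assignment B: the two predictions swapped against the same targets
    loss_b = emd_loss_softmax(p_b1, p_s1, p_b0, p_s0, targets, labels)
    loss = torch.cat([loss_a, loss_b], dim=1)     # (num_proposals, 2)
    min_loss, _ = loss.min(dim=1)                 # keep the cheaper assignment per proposal
    return min_loss.mean()
```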
/lib/det_oprs/retina_anchor_target.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import numpy as np
4 | from config import config
5 | from det_oprs.bbox_opr import box_overlap_opr, bbox_transform_opr
6 |
7 | @torch.no_grad()
8 | def retina_anchor_target(anchors, gt_boxes, im_info, top_k=1):
9 | total_anchor = anchors.shape[0]
10 | return_labels = []
11 | return_bbox_targets = []
12 | # get per image proposals and gt_boxes
13 | for bid in range(config.train_batch_per_gpu):
14 | gt_boxes_perimg = gt_boxes[bid, :int(im_info[bid, 5]), :]
15 | anchors = anchors.type_as(gt_boxes_perimg)
16 | overlaps = box_overlap_opr(anchors, gt_boxes_perimg[:, :-1])
17 | # gt max and indices
18 | max_overlaps, gt_assignment = overlaps.topk(top_k, dim=1, sorted=True)
19 |         max_overlaps = max_overlaps.flatten()
20 |         gt_assignment = gt_assignment.flatten()
21 | _, gt_assignment_for_gt = torch.max(overlaps, axis=0)
22 | del overlaps
23 | # cons labels
24 | labels = gt_boxes_perimg[gt_assignment, 4]
25 | labels = labels * (max_overlaps >= config.negative_thresh)
26 | ignore_mask = (max_overlaps < config.positive_thresh) * (
27 | max_overlaps >= config.negative_thresh)
28 | labels[ignore_mask] = -1
29 | # cons bbox targets
30 | target_boxes = gt_boxes_perimg[gt_assignment, :4]
31 | target_anchors = anchors.repeat(1, top_k).reshape(-1, anchors.shape[-1])
32 | bbox_targets = bbox_transform_opr(target_anchors, target_boxes)
33 | if config.allow_low_quality:
34 | labels[gt_assignment_for_gt] = gt_boxes_perimg[:, 4]
35 | low_quality_bbox_targets = bbox_transform_opr(
36 | anchors[gt_assignment_for_gt], gt_boxes_perimg[:, :4])
37 | bbox_targets[gt_assignment_for_gt] = low_quality_bbox_targets
38 | labels = labels.reshape(-1, 1 * top_k)
39 | bbox_targets = bbox_targets.reshape(-1, 4 * top_k)
40 | return_labels.append(labels)
41 | return_bbox_targets.append(bbox_targets)
42 |
43 | if config.train_batch_per_gpu == 1:
44 | return labels, bbox_targets
45 | else:
46 | return_labels = torch.cat(return_labels, axis=0)
47 | return_bbox_targets = torch.cat(return_bbox_targets, axis=0)
48 | return return_labels, return_bbox_targets
49 |
50 |
--------------------------------------------------------------------------------
/lib/det_oprs/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import pickle
3 |
4 | import torch
5 |
6 | def get_padded_tensor(tensor, multiple_number, pad_value=0):
7 | t_height, t_width = tensor.shape[-2], tensor.shape[-1]
8 | padded_height = (t_height + multiple_number - 1) // \
9 | multiple_number * multiple_number
10 | padded_width = (t_width + multiple_number - 1) // \
11 | multiple_number * multiple_number
12 | ndim = tensor.ndim
13 | if ndim == 4:
14 | padded_tensor = torch.ones([tensor.shape[0], tensor.shape[1], padded_height, padded_width]) * pad_value
15 | padded_tensor = padded_tensor.type_as(tensor)
16 | padded_tensor[:, :, :t_height, :t_width] = tensor
17 | elif ndim == 3:
18 | padded_tensor = torch.ones([tensor.shape[0], padded_height, padded_width]) * pad_value
19 | padded_tensor = padded_tensor.type_as(tensor)
20 | padded_tensor[:, :t_height, :t_width] = tensor
21 | else:
22 | raise Exception('Not supported tensor dim: {}'.format(ndim))
23 | return padded_tensor
24 |
25 | def _init_backbone(backbone, model_path, strict):
26 | state_dict = _load_c2_pickled_weights(model_path)
27 | state_dict = _rename_weights_for_resnet50(state_dict)
28 | backbone.load_state_dict(state_dict, strict=strict)
29 | del state_dict
30 |
31 | def _rename_basic_resnet_weights(layer_keys):
32 | layer_keys = [k.replace("_", ".") for k in layer_keys]
33 | layer_keys = [k.replace(".w", ".weight") for k in layer_keys]
34 | layer_keys = [k.replace(".bn", "_bn") for k in layer_keys]
35 | layer_keys = [k.replace(".b", ".bias") for k in layer_keys]
36 | layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys]
37 | layer_keys = [k.replace(".biaseta", ".beta") for k in layer_keys]
38 |     # Affine-Channel -> BatchNorm renaming
39 | layer_keys = [k.replace("running.mean", "running_mean") for k in layer_keys]
40 | layer_keys = [k.replace("running.var", "running_var") for k in layer_keys]
41 | layer_keys = [k.replace(".beta", ".bias") for k in layer_keys]
42 | layer_keys = [k.replace(".gamma", ".weight") for k in layer_keys]
43 | layer_keys = [k.replace("res.conv1_bn", "bn1") for k in layer_keys]
44 | ## Make torchvision-compatible
45 | layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys]
46 | layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys]
47 | layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys]
48 | layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys]
49 |
50 | layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
51 | layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys]
52 | layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
53 | layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys]
54 | layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
55 | layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys]
56 | layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys]
57 | layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys]
58 | return layer_keys
59 |
60 | def _rename_weights_for_resnet50(weights):
61 | original_keys = sorted(weights.keys())
62 | layer_keys = sorted(weights.keys())
63 | layer_keys = _rename_basic_resnet_weights(layer_keys)
64 | key_map = {k: v for k, v in zip(original_keys, layer_keys)}
65 | new_weights = OrderedDict()
66 | for k in original_keys:
67 | v = weights[k]
68 | if "_momentum" in k:
69 | continue
70 | if "fc1000" in k:
71 | continue
72 | w = torch.from_numpy(v)
73 | new_weights[key_map[k]] = w
74 | return new_weights
75 |
76 | def _load_c2_pickled_weights(file_path):
77 | with open(file_path, "rb") as f:
78 | if torch._six.PY3:
79 | data = pickle.load(f, encoding="latin1")
80 | else:
81 | data = pickle.load(f)
82 | if "blobs" in data:
83 | weights = data["blobs"]
84 | else:
85 | weights = data
86 | return weights
87 |
--------------------------------------------------------------------------------
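An illustrative use of `get_padded_tensor` above: it pads H and W up to the next multiple of `multiple_number`, which keeps downsampled feature maps aligned across FPN levels.

```python
import torch
from det_oprs.utils import get_padded_tensor   # assumes lib/ is on sys.path

img = torch.randn(2, 3, 750, 1100)
padded = get_padded_tensor(img, multiple_number=64, pad_value=0)
print(padded.shape)                             # torch.Size([2, 3, 768, 1152])
```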
/lib/evaluate/APMRToolkits/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf8 -*-
2 | __author__ = 'jyn'
3 | __email__ = 'jyn@megvii.com'
4 |
5 | from .image import *
6 | from .database import *
7 |
--------------------------------------------------------------------------------
/lib/evaluate/APMRToolkits/database.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from .image import *
5 |
6 | PERSON_CLASSES = ['background', 'person']
7 | # DBBase
8 | class Database(object):
9 | def __init__(self, gtpath=None, dtpath=None, body_key=None, head_key=None, mode=0):
10 | """
11 | mode=0: only body; mode=1: only head
12 | """
13 | self.images = dict()
14 | self.eval_mode = mode
15 | self.loadData(gtpath, body_key, head_key, True)
16 | self.loadData(dtpath, body_key, head_key, False)
17 |
18 | self._ignNum = sum([self.images[i]._ignNum for i in self.images])
19 | self._gtNum = sum([self.images[i]._gtNum for i in self.images])
20 | self._imageNum = len(self.images)
21 | self.scorelist = None
22 |
23 | def loadData(self, fpath, body_key=None, head_key=None, if_gt=True):
24 | assert os.path.isfile(fpath), fpath + " does not exist!"
25 | with open(fpath, "r") as f:
26 | lines = f.readlines()
27 | records = [json.loads(line.strip('\n')) for line in lines]
28 | if if_gt:
29 | for record in records:
30 | self.images[record["ID"]] = Image(self.eval_mode)
31 | self.images[record["ID"]].load(record, body_key, head_key, PERSON_CLASSES, True)
32 | else:
33 | for record in records:
34 | self.images[record["ID"]].load(record, body_key, head_key, PERSON_CLASSES, False)
35 | self.images[record["ID"]].clip_all_boader()
36 |
37 | def compare(self, thres=0.5, matching=None):
38 | """
39 | match the detection results with the groundtruth in the whole database
40 | """
41 | assert matching is None or matching == "VOC", matching
42 | scorelist = list()
43 | for ID in self.images:
44 | if matching == "VOC":
45 | result = self.images[ID].compare_voc(thres)
46 | else:
47 | result = self.images[ID].compare_caltech(thres)
48 | scorelist.extend(result)
49 |         # Sort in descending order of dtbox score.
50 | scorelist.sort(key=lambda x: x[0][-1], reverse=True)
51 | self.scorelist = scorelist
52 |
53 | def eval_MR(self, ref="CALTECH_-2"):
54 | """
55 | evaluate by Caltech-style log-average miss rate
56 | ref: str - "CALTECH_-2"/"CALTECH_-4"
57 | """
58 |         # find the first index whose value is >= target
59 | def _find_gt(lst, target):
60 | for idx, item in enumerate(lst):
61 | if item >= target:
62 | return idx
63 | return len(lst)-1
64 |
65 | assert ref == "CALTECH_-2" or ref == "CALTECH_-4", ref
66 | if ref == "CALTECH_-2":
67 | # CALTECH_MRREF_2: anchor points (from 10^-2 to 1) as in P.Dollar's paper
68 | ref = [0.0100, 0.0178, 0.03160, 0.0562, 0.1000, 0.1778, 0.3162, 0.5623, 1.000]
69 | else:
70 | # CALTECH_MRREF_4: anchor points (from 10^-4 to 1) as in S.Zhang's paper
71 | ref = [0.0001, 0.0003, 0.00100, 0.0032, 0.0100, 0.0316, 0.1000, 0.3162, 1.000]
72 |
73 | if self.scorelist is None:
74 | self.compare()
75 |
76 | tp, fp = 0.0, 0.0
77 | fppiX, fppiY = list(), list()
78 | for i, item in enumerate(self.scorelist):
79 | if item[1] == 1:
80 | tp += 1.0
81 | elif item[1] == 0:
82 | fp += 1.0
83 |
84 | fn = (self._gtNum - self._ignNum) - tp
85 | recall = tp / (tp + fn)
86 | precision = tp / (tp + fp)
87 | missrate = 1.0 - recall
88 | fppi = fp / self._imageNum
89 | fppiX.append(fppi)
90 | fppiY.append(missrate)
91 |
92 | score = list()
93 | for pos in ref:
94 | argmin = _find_gt(fppiX, pos)
95 | if argmin >= 0:
96 | score.append(fppiY[argmin])
97 | score = np.array(score)
98 | MR = np.exp(np.log(score).mean())
99 | return MR, (fppiX, fppiY)
100 |
101 | def eval_AP(self):
102 | """
103 | :meth: evaluate by average precision
104 | """
105 | # calculate general ap score
106 | def _calculate_map(recall, precision):
107 | assert len(recall) == len(precision)
108 | area = 0
109 | for i in range(1, len(recall)):
110 | delta_h = (precision[i-1] + precision[i]) / 2
111 | delta_w = recall[i] - recall[i-1]
112 | area += delta_w * delta_h
113 | return area
114 |
115 | tp, fp = 0.0, 0.0
116 | rpX, rpY = list(), list()
117 | total_det = len(self.scorelist)
118 | total_gt = self._gtNum - self._ignNum
119 | total_images = self._imageNum
120 |
121 | fpn = []
122 | recalln = []
123 | thr = []
124 | fppi = []
125 | for i, item in enumerate(self.scorelist):
126 | if item[1] == 1:
127 | tp += 1.0
128 | elif item[1] == 0:
129 | fp += 1.0
130 | fn = total_gt - tp
131 | recall = tp / (tp + fn)
132 | precision = tp / (tp + fp)
133 | rpX.append(recall)
134 | rpY.append(precision)
135 | fpn.append(fp)
136 | recalln.append(tp)
137 | thr.append(item[0][-1])
138 | fppi.append(fp/total_images)
139 |
140 | AP = _calculate_map(rpX, rpY)
141 | return AP, (rpX, rpY, thr, fpn, recalln, fppi)
142 |
143 |
--------------------------------------------------------------------------------
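A minimal usage sketch for the Database class above, assuming it is importable from this package; the odgt/json paths and the 'box' detection key are placeholders, not files shipped with the repository.

from evaluate.APMRToolkits import Database  # import path depends on how lib/ is put on sys.path

# Placeholder ground-truth odgt and detection json, full-body evaluation (mode=0).
db = Database('annotation_val.odgt', 'detections.json', body_key='box', head_key=None, mode=0)
db.compare(thres=0.5)              # greedy Caltech-style matching on every image
AP, _ = db.eval_AP()               # area under the precision-recall curve
MR, _ = db.eval_MR('CALTECH_-2')   # log-average miss rate over 9 FPPI anchor points
print('AP:{:.4f}, MR:{:.4f}'.format(AP, MR))
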
/lib/evaluate/APMRToolkits/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Image(object):
4 | def __init__(self, mode):
5 | self.ID = None
6 | self._width = None
7 | self._height = None
8 | self.dtboxes = None
9 | self.gtboxes = None
10 | self.eval_mode = mode
11 |
12 | self._ignNum = None
13 | self._gtNum = None
14 | self._dtNum = None
15 |
16 | def load(self, record, body_key, head_key, class_names, gtflag):
17 | """
18 | :meth: read the object from a dict
19 | """
20 | if "ID" in record and self.ID is None:
21 | self.ID = record['ID']
22 | if "width" in record and self._width is None:
23 | self._width = record["width"]
24 | if "height" in record and self._height is None:
25 | self._height = record["height"]
26 | if gtflag:
27 | self._gtNum = len(record["gtboxes"])
28 | body_bbox, head_bbox = self.load_gt_boxes(record, 'gtboxes', class_names)
29 | if self.eval_mode == 0:
30 | self.gtboxes = body_bbox
31 | self._ignNum = (body_bbox[:, -1] == -1).sum()
32 | elif self.eval_mode == 1:
33 | self.gtboxes = head_bbox
34 | self._ignNum = (head_bbox[:, -1] == -1).sum()
35 | elif self.eval_mode == 2:
36 | gt_tag = np.array([body_bbox[i,-1]!=-1 and head_bbox[i,-1]!=-1 for i in range(len(body_bbox))])
37 | self._ignNum = (gt_tag == 0).sum()
38 | self.gtboxes = np.hstack((body_bbox[:, :-1], head_bbox[:, :-1], gt_tag.reshape(-1, 1)))
39 | else:
40 | raise Exception('Unknown evaluation mode!')
41 | if not gtflag:
42 | self._dtNum = len(record["dtboxes"])
43 | if self.eval_mode == 0:
44 | self.dtboxes = self.load_det_boxes(record, 'dtboxes', body_key, 'score')
45 | elif self.eval_mode == 1:
46 | self.dtboxes = self.load_det_boxes(record, 'dtboxes', head_key, 'score')
47 | elif self.eval_mode == 2:
48 | body_dtboxes = self.load_det_boxes(record, 'dtboxes', body_key)
49 | head_dtboxes = self.load_det_boxes(record, 'dtboxes', head_key, 'score')
50 | self.dtboxes = np.hstack((body_dtboxes, head_dtboxes))
51 | else:
52 | raise Exception('Unknown evaluation mode!')
53 |
54 | def compare_caltech(self, thres):
55 | """
56 | :meth: match the detection results with the groundtruth by Caltech matching strategy
57 | :param thres: iou threshold
58 | :type thres: float
59 | :return: a list of tuples (dtbox, imageID), in the descending sort of dtbox.score
60 | """
61 | dtboxes = self.dtboxes if self.dtboxes is not None else list()
62 | gtboxes = self.gtboxes if self.gtboxes is not None else list()
63 | dt_matched = np.zeros(dtboxes.shape[0])
64 | gt_matched = np.zeros(gtboxes.shape[0])
65 |
66 | dtboxes = np.array(sorted(dtboxes, key=lambda x: x[-1], reverse=True))
67 | gtboxes = np.array(sorted(gtboxes, key=lambda x: x[-1], reverse=True))
68 | if len(dtboxes):
69 | overlap_iou = self.box_overlap_opr(dtboxes, gtboxes, True)
70 | overlap_ioa = self.box_overlap_opr(dtboxes, gtboxes, False)
71 | else:
72 | return list()
73 |
74 | scorelist = list()
75 | for i, dt in enumerate(dtboxes):
76 | maxpos = -1
77 | maxiou = thres
78 | for j, gt in enumerate(gtboxes):
79 | if gt_matched[j] == 1:
80 | continue
81 | if gt[-1] > 0:
82 | overlap = overlap_iou[i][j]
83 | if overlap > maxiou:
84 | maxiou = overlap
85 | maxpos = j
86 | else:
87 | if maxpos >= 0:
88 | break
89 | else:
90 | overlap = overlap_ioa[i][j]
91 | if overlap > thres:
92 | maxiou = overlap
93 | maxpos = j
94 | if maxpos >= 0:
95 | if gtboxes[maxpos, -1] > 0:
96 | gt_matched[maxpos] = 1
97 | dt_matched[i] = 1
98 | scorelist.append((dt, 1, self.ID))
99 | else:
100 | dt_matched[i] = -1
101 | else:
102 | dt_matched[i] = 0
103 | scorelist.append((dt, 0, self.ID))
104 | return scorelist
105 |
106 | def compare_caltech_union(self, thres):
107 | """
108 | :meth: match the detection results with the groundtruth by Caltech matching strategy
109 | :param thres: iou threshold
110 | :type thres: float
111 | :return: a list of tuples (dtbox, imageID), in the descending sort of dtbox.score
112 | """
113 | dtboxes = self.dtboxes if self.dtboxes is not None else list()
114 | gtboxes = self.gtboxes if self.gtboxes is not None else list()
115 | if len(dtboxes) == 0:
116 | return list()
117 | dt_matched = np.zeros(dtboxes.shape[0])
118 | gt_matched = np.zeros(gtboxes.shape[0])
119 |
120 | dtboxes = np.array(sorted(dtboxes, key=lambda x: x[-1], reverse=True))
121 | gtboxes = np.array(sorted(gtboxes, key=lambda x: x[-1], reverse=True))
122 | dt_body_boxes = np.hstack((dtboxes[:, :4], dtboxes[:, -1][:,None]))
123 | dt_head_boxes = dtboxes[:, 4:8]
124 | gt_body_boxes = np.hstack((gtboxes[:, :4], gtboxes[:, -1][:,None]))
125 | gt_head_boxes = gtboxes[:, 4:8]
126 | overlap_iou = self.box_overlap_opr(dt_body_boxes, gt_body_boxes, True)
127 | overlap_head = self.box_overlap_opr(dt_head_boxes, gt_head_boxes, True)
128 | overlap_ioa = self.box_overlap_opr(dt_body_boxes, gt_body_boxes, False)
129 |
130 | scorelist = list()
131 | for i, dt in enumerate(dtboxes):
132 | maxpos = -1
133 | maxiou = thres
134 | for j, gt in enumerate(gtboxes):
135 | if gt_matched[j] == 1:
136 | continue
137 | if gt[-1] > 0:
138 | o_body = overlap_iou[i][j]
139 | o_head = overlap_head[i][j]
140 | if o_body > maxiou and o_head > maxiou:
141 | maxiou = o_body
142 | maxpos = j
143 | else:
144 | if maxpos >= 0:
145 | break
146 | else:
147 | o_body = overlap_ioa[i][j]
148 | if o_body > thres:
149 | maxiou = o_body
150 | maxpos = j
151 | if maxpos >= 0:
152 | if gtboxes[maxpos, -1] > 0:
153 | gt_matched[maxpos] = 1
154 | dt_matched[i] = 1
155 | scorelist.append((dt, 1, self.ID))
156 | else:
157 | dt_matched[i] = -1
158 | else:
159 | dt_matched[i] = 0
160 | scorelist.append((dt, 0, self.ID))
161 | return scorelist
162 |
163 | def box_overlap_opr(self, dboxes:np.ndarray, gboxes:np.ndarray, if_iou):
164 | eps = 1e-6
165 | assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4
166 | N, K = dboxes.shape[0], gboxes.shape[0]
167 | dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1))
168 | gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1))
169 |
170 | iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0])
171 | ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1])
172 | inter = np.maximum(0, iw) * np.maximum(0, ih)
173 |
174 | dtarea = (dtboxes[:,:,2] - dtboxes[:,:,0]) * (dtboxes[:,:,3] - dtboxes[:,:,1])
175 | if if_iou:
176 | gtarea = (gtboxes[:,:,2] - gtboxes[:,:,0]) * (gtboxes[:,:,3] - gtboxes[:,:,1])
177 | ious = inter / (dtarea + gtarea - inter + eps)
178 | else:
179 | ious = inter / (dtarea + eps)
180 | return ious
181 |
182 | def clip_all_boader(self):
183 |
184 | def _clip_boundary(boxes,height,width):
185 | assert boxes.shape[-1]>=4
186 | boxes[:,0] = np.minimum(np.maximum(boxes[:,0],0), width - 1)
187 | boxes[:,1] = np.minimum(np.maximum(boxes[:,1],0), height - 1)
188 | boxes[:,2] = np.maximum(np.minimum(boxes[:,2],width), 0)
189 | boxes[:,3] = np.maximum(np.minimum(boxes[:,3],height), 0)
190 | return boxes
191 |
192 | assert self.dtboxes.shape[-1]>=4
193 | assert self.gtboxes.shape[-1]>=4
194 | assert self._width is not None and self._height is not None
195 | if self.eval_mode == 2:
196 | self.dtboxes[:, :4] = _clip_boundary(self.dtboxes[:, :4], self._height, self._width)
197 | self.gtboxes[:, :4] = _clip_boundary(self.gtboxes[:, :4], self._height, self._width)
198 | self.dtboxes[:, 4:8] = _clip_boundary(self.dtboxes[:, 4:8], self._height, self._width)
199 | self.gtboxes[:, 4:8] = _clip_boundary(self.gtboxes[:, 4:8], self._height, self._width)
200 | else:
201 | self.dtboxes = _clip_boundary(self.dtboxes, self._height, self._width)
202 | self.gtboxes = _clip_boundary(self.gtboxes, self._height, self._width)
203 |
204 | def load_gt_boxes(self, dict_input, key_name, class_names):
205 | assert key_name in dict_input
206 | if len(dict_input[key_name]) < 1:
207 | return np.empty([0, 5])
208 | head_bbox = []
209 | body_bbox = []
210 | for rb in dict_input[key_name]:
211 | if rb['tag'] in class_names:
212 | body_tag = class_names.index(rb['tag'])
213 | head_tag = 1
214 | else:
215 | body_tag = -1
216 | head_tag = -1
217 | if 'extra' in rb:
218 | if 'ignore' in rb['extra']:
219 | if rb['extra']['ignore'] != 0:
220 | body_tag = -1
221 | head_tag = -1
222 | if 'head_attr' in rb:
223 | if 'ignore' in rb['head_attr']:
224 | if rb['head_attr']['ignore'] != 0:
225 | head_tag = -1
226 | head_bbox.append(np.hstack((rb['hbox'], head_tag)))
227 | body_bbox.append(np.hstack((rb['fbox'], body_tag)))
228 | head_bbox = np.array(head_bbox)
229 | head_bbox[:, 2:4] += head_bbox[:, :2]
230 | body_bbox = np.array(body_bbox)
231 | body_bbox[:, 2:4] += body_bbox[:, :2]
232 | return body_bbox, head_bbox
233 |
234 | def load_det_boxes(self, dict_input, key_name, key_box, key_score=None, key_tag=None):
235 | assert key_name in dict_input
236 | if len(dict_input[key_name]) < 1:
237 | return np.empty([0, 5])
238 | else:
239 | assert key_box in dict_input[key_name][0]
240 | if key_score:
241 | assert key_score in dict_input[key_name][0]
242 | if key_tag:
243 | assert key_tag in dict_input[key_name][0]
244 | if key_score:
245 | if key_tag:
246 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score], rb[key_tag])) for rb in dict_input[key_name]])
247 | else:
248 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score])) for rb in dict_input[key_name]])
249 | else:
250 | if key_tag:
251 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_tag])) for rb in dict_input[key_name]])
252 | else:
253 | bboxes = np.vstack([rb[key_box] for rb in dict_input[key_name]])
254 | bboxes[:, 2:4] += bboxes[:, :2]
255 | return bboxes
256 |
257 | def compare_voc(self, thres):
258 | """
259 | :meth: match the detection results with the groundtruth by VOC matching strategy
260 | :param thres: iou threshold
261 | :type thres: float
262 | :return: a list of tuples (dtbox, imageID), in the descending sort of dtbox.score
263 | """
264 | if self.dtboxes is None:
265 | return list()
266 | dtboxes = self.dtboxes
267 | gtboxes = self.gtboxes if self.gtboxes is not None else list()
268 | dtboxes.sort(key=lambda x: x.score, reverse=True)
269 | gtboxes.sort(key=lambda x: x.ign)
270 |
271 | scorelist = list()
272 | for i, dt in enumerate(dtboxes):
273 | maxpos = -1
274 | maxiou = thres
275 |
276 | for j, gt in enumerate(gtboxes):
277 | overlap = dt.iou(gt)
278 | if overlap > maxiou:
279 | maxiou = overlap
280 | maxpos = j
281 |
282 | if maxpos >= 0:
283 | if gtboxes[maxpos].ign == 0:
284 | gtboxes[maxpos].matched = 1
285 | dtboxes[i].matched = 1
286 | scorelist.append((dt, self.ID))
287 | else:
288 | dtboxes[i].matched = -1
289 | else:
290 | dtboxes[i].matched = 0
291 | scorelist.append((dt, self.ID))
292 | return scorelist
293 |
--------------------------------------------------------------------------------
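A small numpy check of `box_overlap_opr` above on (x1, y1, x2, y2) boxes; the coordinates are made up purely for illustration.

import numpy as np

dt = np.array([[0., 0., 10., 10., 0.9],     # overlaps the gt exactly
               [5., 5., 15., 15., 0.8]])    # overlaps a 5x5 corner of the gt
gt = np.array([[0., 0., 10., 10., 1.]])

img = Image(mode=0)
iou = img.box_overlap_opr(dt, gt, True)     # intersection over union
ioa = img.box_overlap_opr(dt, gt, False)    # intersection over detection area
# Expected: iou ~= [[1.0], [25/175 ~= 0.143]], ioa ~= [[1.0], [0.25]]
print(iou.ravel(), ioa.ravel())
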
/lib/evaluate/JIToolkits/JI_tools.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import numpy as np
3 | from .matching import maxWeightMatching
4 |
5 | def compute_matching(dt_boxes, gt_boxes, bm_thr):
6 | assert dt_boxes.shape[-1] > 3 and gt_boxes.shape[-1] > 3
7 | if dt_boxes.shape[0] < 1 or gt_boxes.shape[0] < 1:
8 | return list()
9 | N, K = dt_boxes.shape[0], gt_boxes.shape[0]
10 | ious = compute_iou_matrix(dt_boxes, gt_boxes)
11 | rows, cols = np.where(ious > bm_thr)
12 | bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)]
13 | mates = maxWeightMatching(bipartites)
14 | if len(mates) < 1:
15 | return list()
16 | rows = np.where(np.array(mates) > -1)[0]
17 | indices = np.where(rows < N + 1)[0]
18 | rows = rows[indices]
19 | cols = np.array([mates[i] for i in rows])
20 | matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)]
21 | return matches
22 |
23 | def compute_head_body_matching(dt_body, dt_head, gt_body, gt_head, bm_thr):
24 | assert dt_body.shape[-1] > 3 and gt_body.shape[-1] > 3
25 | assert dt_head.shape[-1] > 3 and gt_head.shape[-1] > 3
26 | assert dt_body.shape[0] == dt_head.shape[0]
27 | assert gt_body.shape[0] == gt_head.shape[0]
28 | N, K = dt_body.shape[0], gt_body.shape[0]
29 | ious_body = compute_iou_matrix(dt_body, gt_body)
30 | ious_head = compute_iou_matrix(dt_head, gt_head)
31 | mask_body = ious_body > bm_thr
32 | mask_head = ious_head > bm_thr
33 |     # only keep detections that match on both body and head
34 | mask = np.array(mask_body) & np.array(mask_head)
35 | ious = np.zeros((N, K))
36 | ious[mask] = (ious_body[mask] + ious_head[mask]) / 2
37 | rows, cols = np.where(ious > bm_thr)
38 | bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)]
39 | mates = maxWeightMatching(bipartites)
40 | if len(mates) < 1:
41 | return list()
42 | rows = np.where(np.array(mates) > -1)[0]
43 | indices = np.where(rows < N + 1)[0]
44 | rows = rows[indices]
45 | cols = np.array([mates[i] for i in rows])
46 | matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)]
47 | return matches
48 |
49 | def compute_multi_head_body_matching(dt_body, dt_head_0, dt_head_1, gt_body, gt_head, bm_thr):
50 | assert dt_body.shape[-1] > 3 and gt_body.shape[-1] > 3
51 | assert dt_head_0.shape[-1] > 3 and gt_head.shape[-1] > 3
52 | assert dt_head_1.shape[-1] > 3 and gt_head.shape[-1] > 3
53 | assert dt_body.shape[0] == dt_head_0.shape[0]
54 | assert gt_body.shape[0] == gt_head.shape[0]
55 | N, K = dt_body.shape[0], gt_body.shape[0]
56 | ious_body = compute_iou_matrix(dt_body, gt_body)
57 | ious_head_0 = compute_iou_matrix(dt_head_0, gt_head)
58 | ious_head_1 = compute_iou_matrix(dt_head_1, gt_head)
59 | mask_body = ious_body > bm_thr
60 | mask_head_0 = ious_head_0 > bm_thr
61 | mask_head_1 = ious_head_1 > bm_thr
62 | mask_head = mask_head_0 | mask_head_1
63 |     # only keep detections that match on body and at least one of the head boxes
64 | mask = np.array(mask_body) & np.array(mask_head)
65 | ious = np.zeros((N, K))
66 | #ious[mask] = (ious_body[mask] + ious_head[mask]) / 2
67 | ious[mask] = ious_body[mask]
68 | rows, cols = np.where(ious > bm_thr)
69 | bipartites = [(i + 1, j + N + 1, ious[i, j]) for (i, j) in zip(rows, cols)]
70 | mates = maxWeightMatching(bipartites)
71 | if len(mates) < 1:
72 | return list()
73 | rows = np.where(np.array(mates) > -1)[0]
74 | indices = np.where(rows < N + 1)[0]
75 | rows = rows[indices]
76 | cols = np.array([mates[i] for i in rows])
77 | matches = [(i-1, j - N - 1) for (i, j) in zip(rows, cols)]
78 | return matches
79 |
80 | def get_head_body_ignores(dt_body, dt_head, gt_body, gt_head, bm_thr):
81 | if gt_body.size:
82 | body_ioas = compute_ioa_matrix(dt_body, gt_body)
83 | head_ioas = compute_ioa_matrix(dt_head, gt_head)
84 | body_ioas = np.max(body_ioas, axis=1)
85 | head_ioas = np.max(head_ioas, axis=1)
86 | head_rows = np.where(head_ioas > bm_thr)[0]
87 | body_rows = np.where(body_ioas > bm_thr)[0]
88 | rows = set.union(set(head_rows), set(body_rows))
89 | return len(rows)
90 | else:
91 | return 0
92 |
93 | def get_ignores(dt_boxes, gt_boxes, bm_thr):
94 | if gt_boxes.size:
95 | ioas = compute_ioa_matrix(dt_boxes, gt_boxes)
96 | ioas = np.max(ioas, axis = 1)
97 | rows = np.where(ioas > bm_thr)[0]
98 | return len(rows)
99 | else:
100 | return 0
101 |
102 | def compute_ioa_matrix(dboxes: np.ndarray, gboxes: np.ndarray):
103 | eps = 1e-6
104 | assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4
105 | N, K = dboxes.shape[0], gboxes.shape[0]
106 | dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1))
107 | gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1))
108 |
109 | iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0])
110 | ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1])
111 | inter = np.maximum(0, iw) * np.maximum(0, ih)
112 |
113 | dtarea = np.maximum(dtboxes[:,:,2] - dtboxes[:,:,0], 0) * np.maximum(dtboxes[:,:,3] - dtboxes[:,:,1], 0)
114 | ioas = inter / (dtarea + eps)
115 | return ioas
116 |
117 | def compute_iou_matrix(dboxes:np.ndarray, gboxes:np.ndarray):
118 | eps = 1e-6
119 | assert dboxes.shape[-1] >= 4 and gboxes.shape[-1] >= 4
120 | N, K = dboxes.shape[0], gboxes.shape[0]
121 | dtboxes = np.tile(np.expand_dims(dboxes, axis = 1), (1, K, 1))
122 | gtboxes = np.tile(np.expand_dims(gboxes, axis = 0), (N, 1, 1))
123 |
124 | iw = np.minimum(dtboxes[:,:,2], gtboxes[:,:,2]) - np.maximum(dtboxes[:,:,0], gtboxes[:,:,0])
125 | ih = np.minimum(dtboxes[:,:,3], gtboxes[:,:,3]) - np.maximum(dtboxes[:,:,1], gtboxes[:,:,1])
126 | inter = np.maximum(0, iw) * np.maximum(0, ih)
127 |
128 | dtarea = (dtboxes[:,:,2] - dtboxes[:,:,0]) * (dtboxes[:,:,3] - dtboxes[:,:,1])
129 | gtarea = (gtboxes[:,:,2] - gtboxes[:,:,0]) * (gtboxes[:,:,3] - gtboxes[:,:,1])
130 | ious = inter / (dtarea + gtarea - inter + eps)
131 | return ious
132 |
133 |
--------------------------------------------------------------------------------
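A minimal sketch of `compute_matching` on toy boxes, assuming the function above is importable; the coordinates are invented and only show how detection/ground-truth indices come back as matched pairs.

import numpy as np

dt = np.array([[0., 0., 10., 10.],
               [20., 20., 30., 30.]])
gt = np.array([[1., 1., 11., 11.],
               [21., 21., 31., 31.]])

# IoU pairs above bm_thr become bipartite edges (dt node i+1 <-> gt node j+N+1),
# and maxWeightMatching keeps the highest-weight consistent set of pairs.
matches = compute_matching(dt, gt, bm_thr=0.5)
print(matches)   # expected: [(0, 0), (1, 1)]
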
/lib/evaluate/JIToolkits/matching.py:
--------------------------------------------------------------------------------
1 | """Weighted maximum matching in general graphs.
2 |
3 | The algorithm is taken from "Efficient Algorithms for Finding Maximum
4 | Matching in Graphs" by Zvi Galil, ACM Computing Surveys, 1986.
5 | It is based on the "blossom" method for finding augmenting paths and
6 | the "primal-dual" method for finding a matching of maximum weight, both
7 | due to Jack Edmonds.
8 | Some ideas came from "Implementation of algorithms for maximum matching
9 | on non-bipartite graphs" by H.J. Gabow, Stanford Ph.D. thesis, 1973.
10 |
11 | A C program for maximum weight matching by Ed Rothberg was used extensively
12 | to validate this new code.
13 | """
14 |
15 | #
16 | # Changes:
17 | #
18 | # 2013-04-07
19 | # * Added Python 3 compatibility with contributions from Daniel Saunders.
20 | #
21 | # 2008-06-08
22 | # * First release.
23 | #
24 |
25 | from __future__ import print_function
26 |
27 | # If assigned, DEBUG(str) is called with lots of debug messages.
28 | DEBUG = None
29 | """def DEBUG(s):
30 | from sys import stderr
31 | print('DEBUG:', s, file=stderr)
32 | """
33 |
34 | # Check delta2/delta3 computation after every substage;
35 | # only works on integer weights, slows down the algorithm to O(n^4).
36 | CHECK_DELTA = False
37 |
38 | # Check optimality of solution before returning; only works on integer weights.
39 | CHECK_OPTIMUM = False
40 |
41 | def maxWeightMatching(edges, maxcardinality=False):
42 | """Compute a maximum-weighted matching in the general undirected
43 | weighted graph given by "edges". If "maxcardinality" is true,
44 | only maximum-cardinality matchings are considered as solutions.
45 |
46 | Edges is a sequence of tuples (i, j, wt) describing an undirected
47 | edge between vertex i and vertex j with weight wt. There is at most
48 | one edge between any two vertices; no vertex has an edge to itself.
49 | Vertices are identified by consecutive, non-negative integers.
50 |
51 | Return a list "mate", such that mate[i] == j if vertex i is
52 | matched to vertex j, and mate[i] == -1 if vertex i is not matched.
53 |
54 | This function takes time O(n ** 3)."""
55 |
56 | #
57 | # Vertices are numbered 0 .. (nvertex-1).
58 | # Non-trivial blossoms are numbered nvertex .. (2*nvertex-1)
59 | #
60 | # Edges are numbered 0 .. (nedge-1).
61 | # Edge endpoints are numbered 0 .. (2*nedge-1), such that endpoints
62 | # (2*k) and (2*k+1) both belong to edge k.
63 | #
64 | # Many terms used in the comments (sub-blossom, T-vertex) come from
65 | # the paper by Galil; read the paper before reading this code.
66 | #
67 |
68 | # Python 2/3 compatibility.
69 | from sys import version as sys_version
70 | if sys_version < '3':
71 | integer_types = (int, long)
72 | else:
73 | integer_types = (int,)
74 |
75 | # Deal swiftly with empty graphs.
76 | if not edges:
77 | return [ ]
78 |
79 | # Count vertices.
80 | nedge = len(edges)
81 | nvertex = 0
82 | for (i, j, w) in edges:
83 | assert i >= 0 and j >= 0 and i != j
84 | if i >= nvertex:
85 | nvertex = i + 1
86 | if j >= nvertex:
87 | nvertex = j + 1
88 |
89 | # Find the maximum edge weight.
90 | maxweight = max(0, max([ wt for (i, j, wt) in edges ]))
91 |
92 | # If p is an edge endpoint,
93 | # endpoint[p] is the vertex to which endpoint p is attached.
94 | # Not modified by the algorithm.
95 | endpoint = [ edges[p//2][p%2] for p in range(2*nedge) ]
96 |
97 | # If v is a vertex,
98 | # neighbend[v] is the list of remote endpoints of the edges attached to v.
99 | # Not modified by the algorithm.
100 | neighbend = [ [ ] for i in range(nvertex) ]
101 | for k in range(len(edges)):
102 | (i, j, w) = edges[k]
103 | neighbend[i].append(2*k+1)
104 | neighbend[j].append(2*k)
105 |
106 | # If v is a vertex,
107 | # mate[v] is the remote endpoint of its matched edge, or -1 if it is single
108 | # (i.e. endpoint[mate[v]] is v's partner vertex).
109 | # Initially all vertices are single; updated during augmentation.
110 | mate = nvertex * [ -1 ]
111 |
112 | # If b is a top-level blossom,
113 | # label[b] is 0 if b is unlabeled (free);
114 | # 1 if b is an S-vertex/blossom;
115 | # 2 if b is a T-vertex/blossom.
116 | # The label of a vertex is found by looking at the label of its
117 | # top-level containing blossom.
118 | # If v is a vertex inside a T-blossom,
119 | # label[v] is 2 iff v is reachable from an S-vertex outside the blossom.
120 | # Labels are assigned during a stage and reset after each augmentation.
121 | label = (2 * nvertex) * [ 0 ]
122 |
123 | # If b is a labeled top-level blossom,
124 | # labelend[b] is the remote endpoint of the edge through which b obtained
125 | # its label, or -1 if b's base vertex is single.
126 | # If v is a vertex inside a T-blossom and label[v] == 2,
127 | # labelend[v] is the remote endpoint of the edge through which v is
128 | # reachable from outside the blossom.
129 | labelend = (2 * nvertex) * [ -1 ]
130 |
131 | # If v is a vertex,
132 | # inblossom[v] is the top-level blossom to which v belongs.
133 | # If v is a top-level vertex, v is itself a blossom (a trivial blossom)
134 | # and inblossom[v] == v.
135 | # Initially all vertices are top-level trivial blossoms.
136 | inblossom = list(range(nvertex))
137 |
138 | # If b is a sub-blossom,
139 | # blossomparent[b] is its immediate parent (sub-)blossom.
140 | # If b is a top-level blossom, blossomparent[b] is -1.
141 | blossomparent = (2 * nvertex) * [ -1 ]
142 |
143 | # If b is a non-trivial (sub-)blossom,
144 | # blossomchilds[b] is an ordered list of its sub-blossoms, starting with
145 | # the base and going round the blossom.
146 | blossomchilds = (2 * nvertex) * [ None ]
147 |
148 | # If b is a (sub-)blossom,
149 | # blossombase[b] is its base VERTEX (i.e. recursive sub-blossom).
150 | blossombase = list(range(nvertex)) + nvertex * [ -1 ]
151 |
152 | # If b is a non-trivial (sub-)blossom,
153 | # blossomendps[b] is a list of endpoints on its connecting edges,
154 | # such that blossomendps[b][i] is the local endpoint of blossomchilds[b][i]
155 | # on the edge that connects it to blossomchilds[b][wrap(i+1)].
156 | blossomendps = (2 * nvertex) * [ None ]
157 |
158 | # If v is a free vertex (or an unreached vertex inside a T-blossom),
159 | # bestedge[v] is the edge to an S-vertex with least slack,
160 | # or -1 if there is no such edge.
161 | # If b is a (possibly trivial) top-level S-blossom,
162 | # bestedge[b] is the least-slack edge to a different S-blossom,
163 | # or -1 if there is no such edge.
164 | # This is used for efficient computation of delta2 and delta3.
165 | bestedge = (2 * nvertex) * [ -1 ]
166 |
167 | # If b is a non-trivial top-level S-blossom,
168 | # blossombestedges[b] is a list of least-slack edges to neighbouring
169 | # S-blossoms, or None if no such list has been computed yet.
170 | # This is used for efficient computation of delta3.
171 | blossombestedges = (2 * nvertex) * [ None ]
172 |
173 | # List of currently unused blossom numbers.
174 | unusedblossoms = list(range(nvertex, 2*nvertex))
175 |
176 | # If v is a vertex,
177 |     # dualvar[v] = 2 * u(v) where u(v) is v's variable in the dual
178 | # optimization problem (multiplication by two ensures integer values
179 | # throughout the algorithm if all edge weights are integers).
180 | # If b is a non-trivial blossom,
181 | # dualvar[b] = z(b) where z(b) is b's variable in the dual optimization
182 | # problem.
183 | dualvar = nvertex * [ maxweight ] + nvertex * [ 0 ]
184 |
185 | # If allowedge[k] is true, edge k has zero slack in the optimization
186 | # problem; if allowedge[k] is false, the edge's slack may or may not
187 | # be zero.
188 | allowedge = nedge * [ False ]
189 |
190 | # Queue of newly discovered S-vertices.
191 | queue = [ ]
192 |
193 | # Return 2 * slack of edge k (does not work inside blossoms).
194 | def slack(k):
195 | (i, j, wt) = edges[k]
196 | return dualvar[i] + dualvar[j] - 2 * wt
197 |
198 | # Generate the leaf vertices of a blossom.
199 | def blossomLeaves(b):
200 | if b < nvertex:
201 | yield b
202 | else:
203 | for t in blossomchilds[b]:
204 | if t < nvertex:
205 | yield t
206 | else:
207 | for v in blossomLeaves(t):
208 | yield v
209 |
210 | # Assign label t to the top-level blossom containing vertex w
211 | # and record the fact that w was reached through the edge with
212 | # remote endpoint p.
213 | def assignLabel(w, t, p):
214 | if DEBUG: DEBUG('assignLabel(%d,%d,%d)' % (w, t, p))
215 | b = inblossom[w]
216 | assert label[w] == 0 and label[b] == 0
217 | label[w] = label[b] = t
218 | labelend[w] = labelend[b] = p
219 | bestedge[w] = bestedge[b] = -1
220 | if t == 1:
221 | # b became an S-vertex/blossom; add it(s vertices) to the queue.
222 | queue.extend(blossomLeaves(b))
223 | if DEBUG: DEBUG('PUSH ' + str(list(blossomLeaves(b))))
224 | elif t == 2:
225 | # b became a T-vertex/blossom; assign label S to its mate.
226 | # (If b is a non-trivial blossom, its base is the only vertex
227 | # with an external mate.)
228 | base = blossombase[b]
229 | assert mate[base] >= 0
230 | assignLabel(endpoint[mate[base]], 1, mate[base] ^ 1)
231 |
232 | # Trace back from vertices v and w to discover either a new blossom
233 | # or an augmenting path. Return the base vertex of the new blossom or -1.
234 | def scanBlossom(v, w):
235 | if DEBUG: DEBUG('scanBlossom(%d,%d)' % (v, w))
236 | # Trace back from v and w, placing breadcrumbs as we go.
237 | path = [ ]
238 | base = -1
239 | while v != -1 or w != -1:
240 | # Look for a breadcrumb in v's blossom or put a new breadcrumb.
241 | b = inblossom[v]
242 | if label[b] & 4:
243 | base = blossombase[b]
244 | break
245 | assert label[b] == 1
246 | path.append(b)
247 | label[b] = 5
248 | # Trace one step back.
249 | assert labelend[b] == mate[blossombase[b]]
250 | if labelend[b] == -1:
251 | # The base of blossom b is single; stop tracing this path.
252 | v = -1
253 | else:
254 | v = endpoint[labelend[b]]
255 | b = inblossom[v]
256 | assert label[b] == 2
257 | # b is a T-blossom; trace one more step back.
258 | assert labelend[b] >= 0
259 | v = endpoint[labelend[b]]
260 | # Swap v and w so that we alternate between both paths.
261 | if w != -1:
262 | v, w = w, v
263 | # Remove breadcrumbs.
264 | for b in path:
265 | label[b] = 1
266 | # Return base vertex, if we found one.
267 | return base
268 |
269 | # Construct a new blossom with given base, containing edge k which
270 | # connects a pair of S vertices. Label the new blossom as S; set its dual
271 | # variable to zero; relabel its T-vertices to S and add them to the queue.
272 | def addBlossom(base, k):
273 | (v, w, wt) = edges[k]
274 | bb = inblossom[base]
275 | bv = inblossom[v]
276 | bw = inblossom[w]
277 | # Create blossom.
278 | b = unusedblossoms.pop()
279 | if DEBUG: DEBUG('addBlossom(%d,%d) (v=%d w=%d) -> %d' % (base, k, v, w, b))
280 | blossombase[b] = base
281 | blossomparent[b] = -1
282 | blossomparent[bb] = b
283 | # Make list of sub-blossoms and their interconnecting edge endpoints.
284 | blossomchilds[b] = path = [ ]
285 | blossomendps[b] = endps = [ ]
286 | # Trace back from v to base.
287 | while bv != bb:
288 | # Add bv to the new blossom.
289 | blossomparent[bv] = b
290 | path.append(bv)
291 | endps.append(labelend[bv])
292 | assert (label[bv] == 2 or
293 | (label[bv] == 1 and labelend[bv] == mate[blossombase[bv]]))
294 | # Trace one step back.
295 | assert labelend[bv] >= 0
296 | v = endpoint[labelend[bv]]
297 | bv = inblossom[v]
298 | # Reverse lists, add endpoint that connects the pair of S vertices.
299 | path.append(bb)
300 | path.reverse()
301 | endps.reverse()
302 | endps.append(2*k)
303 | # Trace back from w to base.
304 | while bw != bb:
305 | # Add bw to the new blossom.
306 | blossomparent[bw] = b
307 | path.append(bw)
308 | endps.append(labelend[bw] ^ 1)
309 | assert (label[bw] == 2 or
310 | (label[bw] == 1 and labelend[bw] == mate[blossombase[bw]]))
311 | # Trace one step back.
312 | assert labelend[bw] >= 0
313 | w = endpoint[labelend[bw]]
314 | bw = inblossom[w]
315 | # Set label to S.
316 | assert label[bb] == 1
317 | label[b] = 1
318 | labelend[b] = labelend[bb]
319 | # Set dual variable to zero.
320 | dualvar[b] = 0
321 | # Relabel vertices.
322 | for v in blossomLeaves(b):
323 | if label[inblossom[v]] == 2:
324 | # This T-vertex now turns into an S-vertex because it becomes
325 | # part of an S-blossom; add it to the queue.
326 | queue.append(v)
327 | inblossom[v] = b
328 | # Compute blossombestedges[b].
329 | bestedgeto = (2 * nvertex) * [ -1 ]
330 | for bv in path:
331 | if blossombestedges[bv] is None:
332 | # This subblossom does not have a list of least-slack edges;
333 | # get the information from the vertices.
334 | nblists = [ [ p // 2 for p in neighbend[v] ]
335 | for v in blossomLeaves(bv) ]
336 | else:
337 | # Walk this subblossom's least-slack edges.
338 | nblists = [ blossombestedges[bv] ]
339 | for nblist in nblists:
340 | for k in nblist:
341 | (i, j, wt) = edges[k]
342 | if inblossom[j] == b:
343 | i, j = j, i
344 | bj = inblossom[j]
345 | if (bj != b and label[bj] == 1 and
346 | (bestedgeto[bj] == -1 or
347 | slack(k) < slack(bestedgeto[bj]))):
348 | bestedgeto[bj] = k
349 | # Forget about least-slack edges of the subblossom.
350 | blossombestedges[bv] = None
351 | bestedge[bv] = -1
352 | blossombestedges[b] = [ k for k in bestedgeto if k != -1 ]
353 | # Select bestedge[b].
354 | bestedge[b] = -1
355 | for k in blossombestedges[b]:
356 | if bestedge[b] == -1 or slack(k) < slack(bestedge[b]):
357 | bestedge[b] = k
358 | if DEBUG: DEBUG('blossomchilds[%d]=' % b + repr(blossomchilds[b]))
359 |
360 | # Expand the given top-level blossom.
361 | def expandBlossom(b, endstage):
362 | if DEBUG: DEBUG('expandBlossom(%d,%d) %s' % (b, endstage, repr(blossomchilds[b])))
363 | # Convert sub-blossoms into top-level blossoms.
364 | for s in blossomchilds[b]:
365 | blossomparent[s] = -1
366 | if s < nvertex:
367 | inblossom[s] = s
368 | elif endstage and dualvar[s] == 0:
369 | # Recursively expand this sub-blossom.
370 | expandBlossom(s, endstage)
371 | else:
372 | for v in blossomLeaves(s):
373 | inblossom[v] = s
374 | # If we expand a T-blossom during a stage, its sub-blossoms must be
375 | # relabeled.
376 | if (not endstage) and label[b] == 2:
377 | # Start at the sub-blossom through which the expanding
378 |             # blossom obtained its label, and relabel sub-blossoms until
379 | # we reach the base.
380 | # Figure out through which sub-blossom the expanding blossom
381 | # obtained its label initially.
382 | assert labelend[b] >= 0
383 | entrychild = inblossom[endpoint[labelend[b] ^ 1]]
384 | # Decide in which direction we will go round the blossom.
385 | j = blossomchilds[b].index(entrychild)
386 | if j & 1:
387 | # Start index is odd; go forward and wrap.
388 | j -= len(blossomchilds[b])
389 | jstep = 1
390 | endptrick = 0
391 | else:
392 | # Start index is even; go backward.
393 | jstep = -1
394 | endptrick = 1
395 | # Move along the blossom until we get to the base.
396 | p = labelend[b]
397 | while j != 0:
398 | # Relabel the T-sub-blossom.
399 | label[endpoint[p ^ 1]] = 0
400 | label[endpoint[blossomendps[b][j-endptrick]^endptrick^1]] = 0
401 | assignLabel(endpoint[p ^ 1], 2, p)
402 | # Step to the next S-sub-blossom and note its forward endpoint.
403 | allowedge[blossomendps[b][j-endptrick]//2] = True
404 | j += jstep
405 | p = blossomendps[b][j-endptrick] ^ endptrick
406 | # Step to the next T-sub-blossom.
407 | allowedge[p//2] = True
408 | j += jstep
409 | # Relabel the base T-sub-blossom WITHOUT stepping through to
410 | # its mate (so don't call assignLabel).
411 | bv = blossomchilds[b][j]
412 | label[endpoint[p ^ 1]] = label[bv] = 2
413 | labelend[endpoint[p ^ 1]] = labelend[bv] = p
414 | bestedge[bv] = -1
415 | # Continue along the blossom until we get back to entrychild.
416 | j += jstep
417 | while blossomchilds[b][j] != entrychild:
418 | # Examine the vertices of the sub-blossom to see whether
419 | # it is reachable from a neighbouring S-vertex outside the
420 | # expanding blossom.
421 | bv = blossomchilds[b][j]
422 | if label[bv] == 1:
423 | # This sub-blossom just got label S through one of its
424 | # neighbours; leave it.
425 | j += jstep
426 | continue
427 | for v in blossomLeaves(bv):
428 | if label[v] != 0:
429 | break
430 | # If the sub-blossom contains a reachable vertex, assign
431 | # label T to the sub-blossom.
432 | if label[v] != 0:
433 | assert label[v] == 2
434 | assert inblossom[v] == bv
435 | label[v] = 0
436 | label[endpoint[mate[blossombase[bv]]]] = 0
437 | assignLabel(v, 2, labelend[v])
438 | j += jstep
439 | # Recycle the blossom number.
440 | label[b] = labelend[b] = -1
441 | blossomchilds[b] = blossomendps[b] = None
442 | blossombase[b] = -1
443 | blossombestedges[b] = None
444 | bestedge[b] = -1
445 | unusedblossoms.append(b)
446 |
447 | # Swap matched/unmatched edges over an alternating path through blossom b
448 | # between vertex v and the base vertex. Keep blossom bookkeeping consistent.
449 | def augmentBlossom(b, v):
450 | if DEBUG: DEBUG('augmentBlossom(%d,%d)' % (b, v))
451 | # Bubble up through the blossom tree from vertex v to an immediate
452 | # sub-blossom of b.
453 | t = v
454 | while blossomparent[t] != b:
455 | t = blossomparent[t]
456 | # Recursively deal with the first sub-blossom.
457 | if t >= nvertex:
458 | augmentBlossom(t, v)
459 | # Decide in which direction we will go round the blossom.
460 | i = j = blossomchilds[b].index(t)
461 | if i & 1:
462 | # Start index is odd; go forward and wrap.
463 | j -= len(blossomchilds[b])
464 | jstep = 1
465 | endptrick = 0
466 | else:
467 | # Start index is even; go backward.
468 | jstep = -1
469 | endptrick = 1
470 | # Move along the blossom until we get to the base.
471 | while j != 0:
472 | # Step to the next sub-blossom and augment it recursively.
473 | j += jstep
474 | t = blossomchilds[b][j]
475 | p = blossomendps[b][j-endptrick] ^ endptrick
476 | if t >= nvertex:
477 | augmentBlossom(t, endpoint[p])
478 | # Step to the next sub-blossom and augment it recursively.
479 | j += jstep
480 | t = blossomchilds[b][j]
481 | if t >= nvertex:
482 | augmentBlossom(t, endpoint[p ^ 1])
483 | # Match the edge connecting those sub-blossoms.
484 | mate[endpoint[p]] = p ^ 1
485 | mate[endpoint[p ^ 1]] = p
486 | if DEBUG: DEBUG('PAIR %d %d (k=%d)' % (endpoint[p], endpoint[p^1], p//2))
487 | # Rotate the list of sub-blossoms to put the new base at the front.
488 | blossomchilds[b] = blossomchilds[b][i:] + blossomchilds[b][:i]
489 | blossomendps[b] = blossomendps[b][i:] + blossomendps[b][:i]
490 | blossombase[b] = blossombase[blossomchilds[b][0]]
491 | assert blossombase[b] == v
492 |
493 | # Swap matched/unmatched edges over an alternating path between two
494 | # single vertices. The augmenting path runs through edge k, which
495 | # connects a pair of S vertices.
496 | def augmentMatching(k):
497 | (v, w, wt) = edges[k]
498 | if DEBUG: DEBUG('augmentMatching(%d) (v=%d w=%d)' % (k, v, w))
499 | if DEBUG: DEBUG('PAIR %d %d (k=%d)' % (v, w, k))
500 | for (s, p) in ((v, 2*k+1), (w, 2*k)):
501 | # Match vertex s to remote endpoint p. Then trace back from s
502 | # until we find a single vertex, swapping matched and unmatched
503 | # edges as we go.
504 | while 1:
505 | bs = inblossom[s]
506 | assert label[bs] == 1
507 | assert labelend[bs] == mate[blossombase[bs]]
508 | # Augment through the S-blossom from s to base.
509 | if bs >= nvertex:
510 | augmentBlossom(bs, s)
511 | # Update mate[s]
512 | mate[s] = p
513 | # Trace one step back.
514 | if labelend[bs] == -1:
515 | # Reached single vertex; stop.
516 | break
517 | t = endpoint[labelend[bs]]
518 | bt = inblossom[t]
519 | assert label[bt] == 2
520 | # Trace one step back.
521 | assert labelend[bt] >= 0
522 | s = endpoint[labelend[bt]]
523 | j = endpoint[labelend[bt] ^ 1]
524 | # Augment through the T-blossom from j to base.
525 | assert blossombase[bt] == t
526 | if bt >= nvertex:
527 | augmentBlossom(bt, j)
528 | # Update mate[j]
529 | mate[j] = labelend[bt]
530 | # Keep the opposite endpoint;
531 | # it will be assigned to mate[s] in the next step.
532 | p = labelend[bt] ^ 1
533 | if DEBUG: DEBUG('PAIR %d %d (k=%d)' % (s, t, p//2))
534 |
535 | # Verify that the optimum solution has been reached.
536 | def verifyOptimum():
537 | if maxcardinality:
538 | # Vertices may have negative dual;
539 | # find a constant non-negative number to add to all vertex duals.
540 | vdualoffset = max(0, -min(dualvar[:nvertex]))
541 | else:
542 | vdualoffset = 0
543 | # 0. all dual variables are non-negative
544 | assert min(dualvar[:nvertex]) + vdualoffset >= 0
545 | assert min(dualvar[nvertex:]) >= 0
546 | # 0. all edges have non-negative slack and
547 | # 1. all matched edges have zero slack;
548 | for k in range(nedge):
549 | (i, j, wt) = edges[k]
550 | s = dualvar[i] + dualvar[j] - 2 * wt
551 | iblossoms = [ i ]
552 | jblossoms = [ j ]
553 | while blossomparent[iblossoms[-1]] != -1:
554 | iblossoms.append(blossomparent[iblossoms[-1]])
555 | while blossomparent[jblossoms[-1]] != -1:
556 | jblossoms.append(blossomparent[jblossoms[-1]])
557 | iblossoms.reverse()
558 | jblossoms.reverse()
559 | for (bi, bj) in zip(iblossoms, jblossoms):
560 | if bi != bj:
561 | break
562 | s += 2 * dualvar[bi]
563 | assert s >= 0
564 | if mate[i] // 2 == k or mate[j] // 2 == k:
565 | assert mate[i] // 2 == k and mate[j] // 2 == k
566 | assert s == 0
567 | # 2. all single vertices have zero dual value;
568 | for v in range(nvertex):
569 | assert mate[v] >= 0 or dualvar[v] + vdualoffset == 0
570 | # 3. all blossoms with positive dual value are full.
571 | for b in range(nvertex, 2*nvertex):
572 | if blossombase[b] >= 0 and dualvar[b] > 0:
573 | assert len(blossomendps[b]) % 2 == 1
574 | for p in blossomendps[b][1::2]:
575 | assert mate[endpoint[p]] == p ^ 1
576 | assert mate[endpoint[p ^ 1]] == p
577 | # Ok.
578 |
579 | # Check optimized delta2 against a trivial computation.
580 | def checkDelta2():
581 | for v in range(nvertex):
582 | if label[inblossom[v]] == 0:
583 | bd = None
584 | bk = -1
585 | for p in neighbend[v]:
586 | k = p // 2
587 | w = endpoint[p]
588 | if label[inblossom[w]] == 1:
589 | d = slack(k)
590 | if bk == -1 or d < bd:
591 | bk = k
592 | bd = d
593 | if DEBUG and (bestedge[v] != -1 or bk != -1) and (bestedge[v] == -1 or bd != slack(bestedge[v])):
594 | DEBUG('v=' + str(v) + ' bk=' + str(bk) + ' bd=' + str(bd) + ' bestedge=' + str(bestedge[v]) + ' slack=' + str(slack(bestedge[v])))
595 | assert (bk == -1 and bestedge[v] == -1) or (bestedge[v] != -1 and bd == slack(bestedge[v]))
596 |
597 | # Check optimized delta3 against a trivial computation.
598 | def checkDelta3():
599 | bk = -1
600 | bd = None
601 | tbk = -1
602 | tbd = None
603 | for b in range(2 * nvertex):
604 | if blossomparent[b] == -1 and label[b] == 1:
605 | for v in blossomLeaves(b):
606 | for p in neighbend[v]:
607 | k = p // 2
608 | w = endpoint[p]
609 | if inblossom[w] != b and label[inblossom[w]] == 1:
610 | d = slack(k)
611 | if bk == -1 or d < bd:
612 | bk = k
613 | bd = d
614 | if bestedge[b] != -1:
615 | (i, j, wt) = edges[bestedge[b]]
616 | assert inblossom[i] == b or inblossom[j] == b
617 | assert inblossom[i] != b or inblossom[j] != b
618 | assert label[inblossom[i]] == 1 and label[inblossom[j]] == 1
619 | if tbk == -1 or slack(bestedge[b]) < tbd:
620 | tbk = bestedge[b]
621 | tbd = slack(bestedge[b])
622 | if DEBUG and bd != tbd:
623 | DEBUG('bk=%d tbk=%d bd=%s tbd=%s' % (bk, tbk, repr(bd), repr(tbd)))
624 | assert bd == tbd
625 |
626 | # Main loop: continue until no further improvement is possible.
627 | for t in range(nvertex):
628 |
629 | # Each iteration of this loop is a "stage".
630 | # A stage finds an augmenting path and uses that to improve
631 | # the matching.
632 | if DEBUG: DEBUG('STAGE %d' % t)
633 |
634 | # Remove labels from top-level blossoms/vertices.
635 | label[:] = (2 * nvertex) * [ 0 ]
636 |
637 | # Forget all about least-slack edges.
638 | bestedge[:] = (2 * nvertex) * [ -1 ]
639 | blossombestedges[nvertex:] = nvertex * [ None ]
640 |
641 | # Loss of labeling means that we can not be sure that currently
642 |         # allowable edges remain allowable throughout this stage.
643 | allowedge[:] = nedge * [ False ]
644 |
645 | # Make queue empty.
646 | queue[:] = [ ]
647 |
648 | # Label single blossoms/vertices with S and put them in the queue.
649 | for v in range(nvertex):
650 | if mate[v] == -1 and label[inblossom[v]] == 0:
651 | assignLabel(v, 1, -1)
652 |
653 | # Loop until we succeed in augmenting the matching.
654 | augmented = 0
655 | while 1:
656 |
657 | # Each iteration of this loop is a "substage".
658 | # A substage tries to find an augmenting path;
659 | # if found, the path is used to improve the matching and
660 | # the stage ends. If there is no augmenting path, the
661 | # primal-dual method is used to pump some slack out of
662 | # the dual variables.
663 | if DEBUG: DEBUG('SUBSTAGE')
664 |
665 | # Continue labeling until all vertices which are reachable
666 | # through an alternating path have got a label.
667 | while queue and not augmented:
668 |
669 | # Take an S vertex from the queue.
670 | v = queue.pop()
671 | if DEBUG: DEBUG('POP v=%d' % v)
672 | assert label[inblossom[v]] == 1
673 |
674 | # Scan its neighbours:
675 | for p in neighbend[v]:
676 | k = p // 2
677 | w = endpoint[p]
678 | # w is a neighbour to v
679 | if inblossom[v] == inblossom[w]:
680 | # this edge is internal to a blossom; ignore it
681 | continue
682 | if not allowedge[k]:
683 | kslack = slack(k)
684 | if kslack <= 0:
685 | # edge k has zero slack => it is allowable
686 | allowedge[k] = True
687 | if allowedge[k]:
688 | if label[inblossom[w]] == 0:
689 | # (C1) w is a free vertex;
690 | # label w with T and label its mate with S (R12).
691 | assignLabel(w, 2, p ^ 1)
692 | elif label[inblossom[w]] == 1:
693 | # (C2) w is an S-vertex (not in the same blossom);
694 | # follow back-links to discover either an
695 | # augmenting path or a new blossom.
696 | base = scanBlossom(v, w)
697 | if base >= 0:
698 | # Found a new blossom; add it to the blossom
699 | # bookkeeping and turn it into an S-blossom.
700 | addBlossom(base, k)
701 | else:
702 | # Found an augmenting path; augment the
703 | # matching and end this stage.
704 | augmentMatching(k)
705 | augmented = 1
706 | break
707 | elif label[w] == 0:
708 | # w is inside a T-blossom, but w itself has not
709 | # yet been reached from outside the blossom;
710 | # mark it as reached (we need this to relabel
711 | # during T-blossom expansion).
712 | assert label[inblossom[w]] == 2
713 | label[w] = 2
714 | labelend[w] = p ^ 1
715 | elif label[inblossom[w]] == 1:
716 | # keep track of the least-slack non-allowable edge to
717 | # a different S-blossom.
718 | b = inblossom[v]
719 | if bestedge[b] == -1 or kslack < slack(bestedge[b]):
720 | bestedge[b] = k
721 | elif label[w] == 0:
722 | # w is a free vertex (or an unreached vertex inside
723 | # a T-blossom) but we can not reach it yet;
724 | # keep track of the least-slack edge that reaches w.
725 | if bestedge[w] == -1 or kslack < slack(bestedge[w]):
726 | bestedge[w] = k
727 |
728 | if augmented:
729 | break
730 |
731 | # There is no augmenting path under these constraints;
732 | # compute delta and reduce slack in the optimization problem.
733 | # (Note that our vertex dual variables, edge slacks and delta's
734 | # are pre-multiplied by two.)
735 | deltatype = -1
736 | delta = deltaedge = deltablossom = None
737 |
738 | # Verify data structures for delta2/delta3 computation.
739 | if CHECK_DELTA:
740 | checkDelta2()
741 | checkDelta3()
742 |
743 |             # Compute delta1: the minimum value of any vertex dual.
744 | if not maxcardinality:
745 | deltatype = 1
746 | delta = min(dualvar[:nvertex])
747 |
748 | # Compute delta2: the minimum slack on any edge between
749 | # an S-vertex and a free vertex.
750 | for v in range(nvertex):
751 | if label[inblossom[v]] == 0 and bestedge[v] != -1:
752 | d = slack(bestedge[v])
753 | if deltatype == -1 or d < delta:
754 | delta = d
755 | deltatype = 2
756 | deltaedge = bestedge[v]
757 |
758 | # Compute delta3: half the minimum slack on any edge between
759 | # a pair of S-blossoms.
760 | for b in range(2 * nvertex):
761 | if ( blossomparent[b] == -1 and label[b] == 1 and
762 | bestedge[b] != -1 ):
763 | kslack = slack(bestedge[b])
764 | if isinstance(kslack, integer_types):
765 | assert (kslack % 2) == 0
766 | d = kslack // 2
767 | else:
768 | d = kslack / 2
769 | if deltatype == -1 or d < delta:
770 | delta = d
771 | deltatype = 3
772 | deltaedge = bestedge[b]
773 |
774 | # Compute delta4: minimum z variable of any T-blossom.
775 | for b in range(nvertex, 2*nvertex):
776 | if ( blossombase[b] >= 0 and blossomparent[b] == -1 and
777 | label[b] == 2 and
778 | (deltatype == -1 or dualvar[b] < delta) ):
779 | delta = dualvar[b]
780 | deltatype = 4
781 | deltablossom = b
782 |
783 | if deltatype == -1:
784 | # No further improvement possible; max-cardinality optimum
785 | # reached. Do a final delta update to make the optimum
786 |                 # verifiable.
787 | assert maxcardinality
788 | deltatype = 1
789 | delta = max(0, min(dualvar[:nvertex]))
790 |
791 | # Update dual variables according to delta.
792 | for v in range(nvertex):
793 | if label[inblossom[v]] == 1:
794 | # S-vertex: 2*u = 2*u - 2*delta
795 | dualvar[v] -= delta
796 | elif label[inblossom[v]] == 2:
797 | # T-vertex: 2*u = 2*u + 2*delta
798 | dualvar[v] += delta
799 | for b in range(nvertex, 2*nvertex):
800 | if blossombase[b] >= 0 and blossomparent[b] == -1:
801 | if label[b] == 1:
802 | # top-level S-blossom: z = z + 2*delta
803 | dualvar[b] += delta
804 | elif label[b] == 2:
805 | # top-level T-blossom: z = z - 2*delta
806 | dualvar[b] -= delta
807 |
808 | # Take action at the point where minimum delta occurred.
809 | if DEBUG: DEBUG('delta%d=%f' % (deltatype, delta))
810 | if deltatype == 1:
811 | # No further improvement possible; optimum reached.
812 | break
813 | elif deltatype == 2:
814 | # Use the least-slack edge to continue the search.
815 | allowedge[deltaedge] = True
816 | (i, j, wt) = edges[deltaedge]
817 | if label[inblossom[i]] == 0:
818 | i, j = j, i
819 | assert label[inblossom[i]] == 1
820 | queue.append(i)
821 | elif deltatype == 3:
822 | # Use the least-slack edge to continue the search.
823 | allowedge[deltaedge] = True
824 | (i, j, wt) = edges[deltaedge]
825 | assert label[inblossom[i]] == 1
826 | queue.append(i)
827 | elif deltatype == 4:
828 | # Expand the least-z blossom.
829 | expandBlossom(deltablossom, False)
830 |
831 |             # End of this substage.
832 |
833 | # Stop when no more augmenting path can be found.
834 | if not augmented:
835 | break
836 |
837 | # End of a stage; expand all S-blossoms which have dualvar = 0.
838 | for b in range(nvertex, 2*nvertex):
839 | if ( blossomparent[b] == -1 and blossombase[b] >= 0 and
840 | label[b] == 1 and dualvar[b] == 0 ):
841 | expandBlossom(b, True)
842 |
843 | # Verify that we reached the optimum solution.
844 | if CHECK_OPTIMUM:
845 | verifyOptimum()
846 |
847 | # Transform mate[] such that mate[v] is the vertex to which v is paired.
848 | for v in range(nvertex):
849 | if mate[v] >= 0:
850 | mate[v] = endpoint[mate[v]]
851 | for v in range(nvertex):
852 | assert mate[v] == -1 or mate[mate[v]] == v
853 |
854 | return mate
855 |
856 |
857 | # Unit tests
858 | if __name__ == '__main__':
859 | x = maxWeightMatching([(1,4,10),(1,5,20),(2,5,40),(2,6,60),(3,4,30)])
860 | print(x)
861 |
862 | # end
--------------------------------------------------------------------------------
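For reference, a tiny hedged example of how the returned mate list reads; the edge list is invented.

# Invented weighted edges (i, j, wt); mate[v] is the matched vertex of v, or -1.
mates = maxWeightMatching([(1, 3, 0.9), (2, 4, 0.8), (1, 4, 0.2)])
# Vertex 0 is unused, so this should give [-1, 3, 4, 1, 2]:
# matching 1-3 and 2-4 (total weight 1.7) beats matching 1-4 alone (0.2).
print(mates)
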
/lib/evaluate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xg-chu/CrowdDet/4a0c674c40fdcb3e2706d39544f088390f9f63fe/lib/evaluate/__init__.py
--------------------------------------------------------------------------------
/lib/evaluate/compute_APMR.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .APMRToolkits import *
3 |
4 | dbName = 'human'
5 | def compute_APMR(dt_path, gt_path, target_key=None, mode=0):
6 | database = Database(gt_path, dt_path, target_key, None, mode)
7 | database.compare()
8 | mAP,_ = database.eval_AP()
9 | mMR,_ = database.eval_MR()
10 | line = 'AP:{:.4f}, MR:{:.4f}.'.format(mAP, mMR)
11 | return mAP, mMR
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser(description='Analyze a json result file with iou match')
15 |     parser.add_argument('--detfile', required=True, help='path of json result file to load')
16 |     parser.add_argument('--gtfile', required=True, help='path of the ground truth odgt file')
17 |     parser.add_argument('--target_key', default=None, required=True)
18 |     args = parser.parse_args()
19 |     compute_APMR(args.detfile, args.gtfile, args.target_key, 0)
20 |
--------------------------------------------------------------------------------
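A hedged call sketch for `compute_APMR` above; the file paths and the 'box' key are placeholders.

# Placeholder paths: a CrowdHuman-style ground-truth odgt file and a dumped
# detection json, evaluated on full-body boxes (mode=0).
mAP, mMR = compute_APMR('detections.json', 'annotation_val.odgt', target_key='box', mode=0)
print('AP:{:.4f}, MR:{:.4f}'.format(mAP, mMR))
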
/lib/evaluate/compute_JI.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import argparse
5 | from multiprocessing import Queue, Process
6 |
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 | from .JIToolkits.JI_tools import compute_matching, get_ignores
11 | sys.path.insert(0, '../')
12 | import utils.misc_utils as misc_utils
13 |
14 | gtfile = '/data/annotation_val.odgt'
15 | nr_procs = 10
16 |
17 | def evaluation_all(path, target_key):
18 | records = misc_utils.load_json_lines(path)
19 | res_line = []
20 | res_JI = []
21 | for i in range(10):
22 | score_thr = 1e-1 * i
23 | total = len(records)
24 | stride = math.ceil(total / nr_procs)
25 | result_queue = Queue(10000)
26 | results, procs = [], []
27 | for i in range(nr_procs):
28 | start = i*stride
29 | end = np.min([start+stride,total])
30 | sample_data = records[start:end]
31 | p = Process(target= compute_JI_with_ignore, args=(result_queue, sample_data, score_thr, target_key))
32 | p.start()
33 | procs.append(p)
34 | tqdm.monitor_interval = 0
35 | pbar = tqdm(total=total, leave = False, ascii = True)
36 | for i in range(total):
37 | t = result_queue.get()
38 | results.append(t)
39 | pbar.update(1)
40 | for p in procs:
41 | p.join()
42 | pbar.close()
43 | line, mean_ratio = gather(results)
44 | line = 'score_thr:{:.1f}, {}'.format(score_thr, line)
45 | print(line)
46 | res_line.append(line)
47 | res_JI.append(mean_ratio)
48 | return res_line, max(res_JI)
49 |
50 | def compute_JI_with_ignore(result_queue, records, score_thr, target_key, bm_thresh=0.5):
51 | for record in records:
52 | gt_boxes = misc_utils.load_bboxes(record, 'gtboxes', target_key, 'tag')
53 | gt_boxes[:,2:4] += gt_boxes[:,:2]
54 | gt_boxes = misc_utils.clip_boundary(gt_boxes, record['height'], record['width'])
55 | dt_boxes = misc_utils.load_bboxes(record, 'dtboxes', target_key, 'score')
56 | dt_boxes[:,2:4] += dt_boxes[:,:2]
57 | dt_boxes = misc_utils.clip_boundary(dt_boxes, record['height'], record['width'])
58 | keep = dt_boxes[:, -1] > score_thr
59 | dt_boxes = dt_boxes[keep][:, :-1]
60 |
61 | gt_tag = np.array(gt_boxes[:,-1]!=-1)
62 | matches = compute_matching(dt_boxes, gt_boxes[gt_tag, :4], bm_thresh)
63 | # get the unmatched_indices
64 | matched_indices = np.array([j for (j,_) in matches])
65 | unmatched_indices = list(set(np.arange(dt_boxes.shape[0])) - set(matched_indices))
66 | num_ignore_dt = get_ignores(dt_boxes[unmatched_indices], gt_boxes[~gt_tag, :4], bm_thresh)
67 | matched_indices = np.array([j for (_,j) in matches])
68 | unmatched_indices = list(set(np.arange(gt_boxes[gt_tag].shape[0])) - set(matched_indices))
69 | num_ignore_gt = get_ignores(gt_boxes[gt_tag][unmatched_indices], gt_boxes[~gt_tag, :4], bm_thresh)
70 |         # compute results
71 | eps = 1e-6
72 | k = len(matches)
73 | m = gt_tag.sum() - num_ignore_gt
74 | n = dt_boxes.shape[0] - num_ignore_dt
75 | ratio = k / (m + n -k + eps)
76 | recall = k / (m + eps)
77 | cover = k / (n + eps)
78 | noise = 1 - cover
79 | result_dict = dict(ratio = ratio, recall = recall, cover = cover,
80 | noise = noise, k = k, m = m, n = n)
81 | result_queue.put_nowait(result_dict)
82 |
83 | def gather(results):
84 | assert len(results)
85 | img_num = 0
86 | for result in results:
87 | if result['n'] != 0 or result['m'] != 0:
88 | img_num += 1
89 | mean_ratio = np.sum([rb['ratio'] for rb in results]) / img_num
90 | mean_cover = np.sum([rb['cover'] for rb in results]) / img_num
91 | mean_recall = np.sum([rb['recall'] for rb in results]) / img_num
92 | mean_noise = 1 - mean_cover
93 | valids = np.sum([rb['k'] for rb in results])
94 | total = np.sum([rb['n'] for rb in results])
95 | gtn = np.sum([rb['m'] for rb in results])
96 |
97 | #line = 'mean_ratio:{:.4f}, mean_cover:{:.4f}, mean_recall:{:.4f}, mean_noise:{:.4f}, valids:{}, total:{}, gtn:{}'.format(
98 | # mean_ratio, mean_cover, mean_recall, mean_noise, valids, total, gtn)
99 | line = 'mean_ratio:{:.4f}, valids:{}, total:{}, gtn:{}'.format(
100 | mean_ratio, valids, total, gtn)
101 | return line, mean_ratio
102 |
103 | def common_process(func, cls_list, nr_procs):
104 | total = len(cls_list)
105 | stride = math.ceil(total / nr_procs)
106 | result_queue = Queue(10000)
107 | results, procs = [], []
108 | for i in range(nr_procs):
109 | start = i*stride
110 | end = np.min([start+stride,total])
111 | sample_data = cls_list[start:end]
112 | p = Process(target= func,args=(result_queue, sample_data))
113 | p.start()
114 | procs.append(p)
115 | for i in range(total):
116 | t = result_queue.get()
117 | if t is None:
118 | continue
119 | results.append(t)
120 | for p in procs:
121 | p.join()
122 | return results
123 |
124 | if __name__ == "__main__":
125 | parser = argparse.ArgumentParser(description='Analyze a json result file with iou match')
126 | parser.add_argument('--detfile', required=True, help='path of json result file to load')
127 | parser.add_argument('--target_key', required=True)
128 | args = parser.parse_args()
129 | evaluation_all(args.detfile, args.target_key)
130 |
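The per-image ratio computed above is a Jaccard index over the matching: with k matched (detection, ground-truth) pairs, m non-ignored ground-truth boxes and n non-ignored detections, JI = k / (m + n - k), while recall and cover are k/m and k/n. A tiny worked example with made-up counts:

# Hypothetical counts for one image: 8 valid GT boxes, 9 valid detections, 7 matched pairs.
k, m, n = 7, 8, 9
ratio = k / (m + n - k)   # 7 / 10 = 0.70, the per-image JI
recall = k / m            # fraction of ground truth that found a match
cover = k / n             # fraction of detections that found a match
noise = 1 - cover         # fraction of detections left unmatched
print(ratio, recall, cover, noise)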
--------------------------------------------------------------------------------
/lib/layers/batch_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class FrozenBatchNorm2d(nn.Module):
5 | """
6 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
7 | """
8 |
9 | def __init__(self, num_features, eps=1e-5):
10 | super().__init__()
11 |         self.num_features, self.eps = num_features, eps
12 | self.register_buffer("weight", torch.ones(num_features))
13 | self.register_buffer("bias", torch.zeros(num_features))
14 | self.register_buffer("running_mean", torch.zeros(num_features))
15 | self.register_buffer("running_var", torch.ones(num_features) - eps)
16 |
17 | def forward(self, x):
18 | scale = self.weight * (self.running_var + self.eps).rsqrt()
19 | bias = self.bias - self.running_mean * scale
20 | scale = scale.reshape(1, -1, 1, 1)
21 | bias = bias.reshape(1, -1, 1, 1)
22 | return x * scale + bias
23 |
24 | def __repr__(self):
25 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
26 |
27 |
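FrozenBatchNorm2d is a fixed affine transform: its forward pass matches eval-mode nn.BatchNorm2d evaluated with the stored statistics, but nothing is ever updated. A small sketch that copies statistics out of a regular BatchNorm2d and checks the two agree; the freeze_bn helper is an illustration, not part of the repo, and the import assumes lib/ is on sys.path:

import torch
from torch import nn
from layers.batch_norm import FrozenBatchNorm2d

def freeze_bn(bn):
    # Copy affine parameters and running statistics into the fixed buffers.
    frozen = FrozenBatchNorm2d(bn.num_features, eps=bn.eps)
    with torch.no_grad():
        frozen.weight.copy_(bn.weight)
        frozen.bias.copy_(bn.bias)
        frozen.running_mean.copy_(bn.running_mean)
        frozen.running_var.copy_(bn.running_var)
    return frozen

bn = nn.BatchNorm2d(8).eval()
x = torch.randn(2, 8, 4, 4)
with torch.no_grad():
    print(torch.allclose(bn(x), freeze_bn(bn)(x), atol=1e-5))   # True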
--------------------------------------------------------------------------------
/lib/layers/pooler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torchvision.ops import roi_align
5 |
6 | def assign_boxes_to_levels(rois, min_level, max_level, canonical_box_size=224, canonical_level=4):
7 | """
8 | rois (Tensor): A tensor of shape (N, 5).
9 | min_level (int), max_level (int), canonical_box_size (int), canonical_level (int).
10 | Return a tensor of length N.
11 | """
12 | eps = 1e-6
13 | box_sizes = torch.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
14 | # Eqn.(1) in FPN paper
15 | level_assignments = torch.floor(
16 | canonical_level + torch.log2(box_sizes / canonical_box_size + eps)
17 | )
18 | # clamp level to (min, max), in case the box size is too large or too small
19 | # for the available feature maps
20 | level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
21 | return level_assignments.to(torch.int64) - min_level
22 |
23 | def roi_pooler(fpn_fms, rois, stride, pool_shape, pooler_type):
24 | if pooler_type == "ROIAlign":
25 | pooler_aligned = False
26 | elif pooler_type == "ROIAlignV2":
27 | pooler_aligned = True
28 | else:
29 | raise ValueError("Unknown pooler type: {}".format(pooler_type))
30 | assert len(fpn_fms) == len(stride)
31 | max_level = int(math.log2(stride[-1]))
32 | min_level = int(math.log2(stride[0]))
33 | assert (len(stride) == max_level - min_level + 1)
34 | level_assignments = assign_boxes_to_levels(rois, min_level, max_level, 224, 4)
35 | dtype, device = fpn_fms[0].dtype, fpn_fms[0].device
36 | output = torch.zeros((len(rois), fpn_fms[0].shape[1], pool_shape[0], pool_shape[1]),
37 | dtype=dtype, device=device)
38 | for level, (fm_level, scale_level) in enumerate(zip(fpn_fms, stride)):
39 | inds = torch.nonzero(level_assignments == level, as_tuple=False).squeeze(1)
40 | rois_level = rois[inds]
41 | output[inds] = roi_align(fm_level, rois_level, pool_shape, spatial_scale=1.0/scale_level,
42 | sampling_ratio=-1, aligned=pooler_aligned)
43 | return output
44 |
45 |
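assign_boxes_to_levels is Eqn.(1) of the FPN paper: level = floor(canonical_level + log2(sqrt(box area) / canonical_box_size)), clamped to [min_level, max_level] and shifted so the result indexes into fpn_fms. With the defaults used by roi_pooler (canonical size 224 at level 4, strides 4..32 so levels 2..5), a quick standalone check of the arithmetic:

import math

def expected_level(box_w, box_h, min_level=2, max_level=5,
                   canonical_box_size=224, canonical_level=4):
    # Same arithmetic as assign_boxes_to_levels, for a single box.
    size = math.sqrt(box_w * box_h)
    level = math.floor(canonical_level + math.log2(size / canonical_box_size + 1e-6))
    return min(max(level, min_level), max_level) - min_level

print(expected_level(224, 224))   # canonical box -> level 4 -> index 2 (p4)
print(expected_level(56, 56))     # 4x smaller per side -> level 2 -> index 0 (p2)
print(expected_level(900, 900))   # huge box clamps to level 5 -> index 3 (p5)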
--------------------------------------------------------------------------------
/lib/module/rpn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from config import config
7 | from det_oprs.anchors_generator import AnchorGenerator
8 | from det_oprs.find_top_rpn_proposals import find_top_rpn_proposals
9 | from det_oprs.fpn_anchor_target import fpn_anchor_target, fpn_rpn_reshape
10 | from det_oprs.loss_opr import softmax_loss, smooth_l1_loss
11 |
12 | class RPN(nn.Module):
13 | def __init__(self, rpn_channel = 256):
14 | super().__init__()
15 | self.anchors_generator = AnchorGenerator(
16 | config.anchor_base_size,
17 | config.anchor_aspect_ratios,
18 | config.anchor_base_scale)
19 | self.rpn_conv = nn.Conv2d(256, rpn_channel, kernel_size=3, stride=1, padding=1)
20 | self.rpn_cls_score = nn.Conv2d(rpn_channel, config.num_cell_anchors * 2, kernel_size=1, stride=1)
21 | self.rpn_bbox_offsets = nn.Conv2d(rpn_channel, config.num_cell_anchors * 4, kernel_size=1, stride=1)
22 |
23 | for l in [self.rpn_conv, self.rpn_cls_score, self.rpn_bbox_offsets]:
24 | nn.init.normal_(l.weight, std=0.01)
25 | nn.init.constant_(l.bias, 0)
26 |
27 | def forward(self, features, im_info, boxes=None):
28 | # prediction
29 | pred_cls_score_list = []
30 | pred_bbox_offsets_list = []
31 | for x in features:
32 | t = F.relu(self.rpn_conv(x))
33 | pred_cls_score_list.append(self.rpn_cls_score(t))
34 | pred_bbox_offsets_list.append(self.rpn_bbox_offsets(t))
35 | # get anchors
36 | all_anchors_list = []
37 | # stride: 64,32,16,8,4 p6->p2
38 | base_stride = 4
39 | off_stride = 2**(len(features)-1) # 16
40 | for fm in features:
41 | layer_anchors = self.anchors_generator(fm, base_stride, off_stride)
42 | off_stride = off_stride // 2
43 | all_anchors_list.append(layer_anchors)
44 | # sample from the predictions
45 | rpn_rois = find_top_rpn_proposals(
46 | self.training, pred_bbox_offsets_list, pred_cls_score_list,
47 | all_anchors_list, im_info)
48 | rpn_rois = rpn_rois.type_as(features[0])
49 | if self.training:
50 | rpn_labels, rpn_bbox_targets = fpn_anchor_target(
51 | boxes, im_info, all_anchors_list)
52 | #rpn_labels = rpn_labels.astype(np.int32)
53 | pred_cls_score, pred_bbox_offsets = fpn_rpn_reshape(
54 | pred_cls_score_list, pred_bbox_offsets_list)
55 | # rpn loss
56 | valid_masks = rpn_labels >= 0
57 | objectness_loss = softmax_loss(
58 | pred_cls_score[valid_masks],
59 | rpn_labels[valid_masks])
60 |
61 | pos_masks = rpn_labels > 0
62 | localization_loss = smooth_l1_loss(
63 | pred_bbox_offsets[pos_masks],
64 | rpn_bbox_targets[pos_masks],
65 | config.rpn_smooth_l1_beta)
66 | normalizer = 1 / valid_masks.sum().item()
67 | loss_rpn_cls = objectness_loss.sum() * normalizer
68 | loss_rpn_loc = localization_loss.sum() * normalizer
69 | loss_dict = {}
70 | loss_dict['loss_rpn_cls'] = loss_rpn_cls
71 | loss_dict['loss_rpn_loc'] = loss_rpn_loc
72 | return rpn_rois, loss_dict
73 | else:
74 | return rpn_rois
75 |
76 |
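The anchor strides follow the comment in forward(): base_stride is 4, off_stride starts at 2**(len(features)-1) and halves per level, so their product walks 64, 32, 16, 8, 4 from p6 down to p2. How AnchorGenerator actually combines the two factors lives in det_oprs/anchors_generator.py; treating it as a plain product below is an assumption made only to show the progression:

base_stride = 4
num_levels = 5                       # p6, p5, p4, p3, p2
off_stride = 2 ** (num_levels - 1)   # 16
strides = []
for _ in range(num_levels):
    strides.append(base_stride * off_stride)
    off_stride //= 2
print(strides)   # [64, 32, 16, 8, 4]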
--------------------------------------------------------------------------------
/lib/utils/SGD_bias.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer, required
3 |
4 | class SGD(Optimizer):
5 | """Implements stochastic gradient descent (optionally with momentum).
6 | Args:
7 | params (iterable): iterable of parameters to optimize or dicts defining
8 | parameter groups
9 | lr (float): learning rate
10 | momentum (float, optional): momentum factor (default: 0)
11 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
12 | dampening (float, optional): dampening for momentum (default: 0)
13 | nesterov (bool, optional): enables Nesterov momentum (default: False)
14 | """
15 |
16 | def __init__(self, params, lr=required, momentum=0, dampening=0,
17 | weight_decay=0, nesterov=False):
18 | if lr is not required and lr < 0.0:
19 | raise ValueError("Invalid learning rate: {}".format(lr))
20 | if momentum < 0.0:
21 | raise ValueError("Invalid momentum value: {}".format(momentum))
22 | if weight_decay < 0.0:
23 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
24 |
25 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
26 | weight_decay=weight_decay, nesterov=nesterov)
27 | if nesterov and (momentum <= 0 or dampening != 0):
28 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
29 | super(SGD, self).__init__(params, defaults)
30 |
31 | def __setstate__(self, state):
32 | super(SGD, self).__setstate__(state)
33 | for group in self.param_groups:
34 | group.setdefault('nesterov', False)
35 |
36 | @torch.no_grad()
37 | def step(self):
38 | """Performs a single optimization step.
39 | """
40 | loss = None
41 | for group in self.param_groups:
42 | weight_decay = group['weight_decay']
43 | momentum = group['momentum']
44 | dampening = group['dampening']
45 | nesterov = group['nesterov']
46 |
47 | for p in group['params']:
48 | if p.grad is None:
49 | continue
50 | d_p = p.grad
51 | if weight_decay != 0 and p.dim() > 1:
52 | d_p = d_p.add(p, alpha=weight_decay)
53 | if momentum != 0:
54 | param_state = self.state[p]
55 | if 'momentum_buffer' not in param_state:
56 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
57 | else:
58 | buf = param_state['momentum_buffer']
59 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
60 | if nesterov:
61 | d_p = d_p.add(buf, alpha=momentum)
62 | else:
63 | d_p = buf
64 |
65 | p.add_(d_p, alpha=-group['lr'])
66 |
67 | return loss
68 |
69 |
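The only difference from torch.optim.SGD is the p.dim() > 1 guard: weight decay touches convolution and linear weights but skips biases and other 1-D parameters such as normalization affine terms. A minimal usage sketch; the toy model and hyper-parameters are placeholders, and the import assumes lib/ is on sys.path:

import torch
from torch import nn
from utils.SGD_bias import SGD

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 1))
opt = SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

loss = model(torch.randn(2, 3, 16, 16)).mean()
loss.backward()
opt.step()        # decays only parameters with dim() > 1; biases stay undecayed
opt.zero_grad()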
--------------------------------------------------------------------------------
/lib/utils/misc_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 |
5 | def load_img(image_path):
6 | import cv2
7 | img = cv2.imread(image_path, cv2.IMREAD_COLOR)
8 | return img
9 |
10 | def load_json_lines(fpath):
11 | assert os.path.exists(fpath)
12 | with open(fpath,'r') as fid:
13 | lines = fid.readlines()
14 | records = [json.loads(line.strip('\n')) for line in lines]
15 | return records
16 |
17 | def save_json_lines(content,fpath):
18 | with open(fpath,'w') as fid:
19 | for db in content:
20 | line = json.dumps(db)+'\n'
21 | fid.write(line)
22 |
23 | def device_parser(str_device):
24 | if '-' in str_device:
25 | device_id = str_device.split('-')
26 | device_id = [i for i in range(int(device_id[0]), int(device_id[1])+1)]
27 | else:
28 | device_id = [int(str_device)]
29 | return device_id
30 |
31 | def ensure_dir(dirpath):
32 | if not os.path.exists(dirpath):
33 | os.makedirs(dirpath)
34 |
35 | def xyxy_to_xywh(boxes):
36 | assert boxes.shape[1]>=4
37 | boxes[:, 2:4] -= boxes[:,:2]
38 | return boxes
39 |
40 | def xywh_to_xyxy(boxes):
41 | assert boxes.shape[1]>=4
42 | boxes[:, 2:4] += boxes[:,:2]
43 | return boxes
44 |
45 | def load_bboxes(dict_input, key_name, key_box, key_score=None, key_tag=None):
46 | assert key_name in dict_input
47 | if len(dict_input[key_name]) < 1:
48 | return np.empty([0, 5])
49 | else:
50 | assert key_box in dict_input[key_name][0]
51 | if key_score:
52 | assert key_score in dict_input[key_name][0]
53 | if key_tag:
54 | assert key_tag in dict_input[key_name][0]
55 | if key_score:
56 | if key_tag:
57 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score], rb[key_tag])) for rb in dict_input[key_name]])
58 | else:
59 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_score])) for rb in dict_input[key_name]])
60 | else:
61 | if key_tag:
62 | bboxes = np.vstack([np.hstack((rb[key_box], rb[key_tag])) for rb in dict_input[key_name]])
63 | else:
64 | bboxes = np.vstack([rb[key_box] for rb in dict_input[key_name]])
65 | return bboxes
66 |
67 | def load_masks(dict_input, key_name, key_box):
68 | assert key_name in dict_input
69 | if len(dict_input[key_name]) < 1:
70 | return np.empty([0, 28, 28])
71 | else:
72 | assert key_box in dict_input[key_name][0]
73 | masks = np.array([rb[key_box] for rb in dict_input[key_name]])
74 | return masks
75 |
76 | def load_gt(dict_input, key_name, key_box, class_names):
77 | assert key_name in dict_input
78 | if len(dict_input[key_name]) < 1:
79 | return np.empty([0, 5])
80 | else:
81 | assert key_box in dict_input[key_name][0]
82 | bbox = []
83 | for rb in dict_input[key_name]:
84 | if rb['tag'] in class_names:
85 | tag = class_names.index(rb['tag'])
86 | else:
87 | tag = -1
88 | if 'extra' in rb:
89 | if 'ignore' in rb['extra']:
90 | if rb['extra']['ignore'] != 0:
91 | tag = -1
92 | bbox.append(np.hstack((rb[key_box], tag)))
93 | bboxes = np.vstack(bbox).astype(np.float64)
94 | return bboxes
95 |
96 | def boxes_dump(boxes, is_gt):
97 | result = []
98 | boxes = boxes.tolist()
99 | for box in boxes:
100 | if is_gt:
101 | box_dict = {}
102 | box_dict['box'] = [box[0], box[1], box[2]-box[0], box[3]-box[1]]
103 | box_dict['tag'] = box[-1]
104 | result.append(box_dict)
105 | else:
106 | box_dict = {}
107 | box_dict['box'] = [box[0], box[1], box[2]-box[0], box[3]-box[1]]
108 | box_dict['tag'] = 1
109 | box_dict['score'] = box[-1]
110 | result.append(box_dict)
111 | return result
112 |
113 | def clip_boundary(boxes,height,width):
114 | assert boxes.shape[-1]>=4
115 | boxes[:,0] = np.minimum(np.maximum(boxes[:,0],0), width - 1)
116 | boxes[:,1] = np.minimum(np.maximum(boxes[:,1],0), height - 1)
117 | boxes[:,2] = np.maximum(np.minimum(boxes[:,2],width), 0)
118 | boxes[:,3] = np.maximum(np.minimum(boxes[:,3],height), 0)
119 | return boxes
120 |
121 |
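These helpers assume the odgt convention used throughout the repo: one JSON object per line, each carrying a box list whose entries hold a box in xywh form plus a tag (and optionally an ignore flag under 'extra'). A minimal round-trip sketch; the record contents and the /tmp path are made up, and the import assumes lib/ is on sys.path:

from utils.misc_utils import save_json_lines, load_json_lines, load_gt

record = {
    'ID': 'example_image',
    'height': 768, 'width': 1024,
    'gtboxes': [{'tag': 'person', 'fbox': [100, 120, 80, 240], 'extra': {'ignore': 0}}],
}
save_json_lines([record], '/tmp/example.odgt')
loaded = load_json_lines('/tmp/example.odgt')
boxes = load_gt(loaded[0], 'gtboxes', 'fbox', ['background', 'person'])
print(boxes)   # one row: [x, y, w, h, class_index]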
--------------------------------------------------------------------------------
/lib/utils/nms_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pdb
3 |
4 | def set_cpu_nms(dets, thresh):
5 | """Pure Python NMS baseline."""
6 | def _overlap(det_boxes, basement, others):
7 | eps = 1e-8
8 | x1_basement, y1_basement, x2_basement, y2_basement \
9 | = det_boxes[basement, 0], det_boxes[basement, 1], \
10 | det_boxes[basement, 2], det_boxes[basement, 3]
11 | x1_others, y1_others, x2_others, y2_others \
12 | = det_boxes[others, 0], det_boxes[others, 1], \
13 | det_boxes[others, 2], det_boxes[others, 3]
14 | areas_basement = (x2_basement - x1_basement) * (y2_basement - y1_basement)
15 | areas_others = (x2_others - x1_others) * (y2_others - y1_others)
16 | xx1 = np.maximum(x1_basement, x1_others)
17 | yy1 = np.maximum(y1_basement, y1_others)
18 | xx2 = np.minimum(x2_basement, x2_others)
19 | yy2 = np.minimum(y2_basement, y2_others)
20 | w = np.maximum(0.0, xx2 - xx1)
21 | h = np.maximum(0.0, yy2 - yy1)
22 | inter = w * h
23 | ovr = inter / (areas_basement + areas_others - inter + eps)
24 | return ovr
25 | scores = dets[:, 4]
26 | order = np.argsort(-scores)
27 | dets = dets[order]
28 |
29 | numbers = dets[:, -1]
30 | keep = np.ones(len(dets)) == 1
31 | ruler = np.arange(len(dets))
32 | while ruler.size>0:
33 | basement = ruler[0]
34 | ruler=ruler[1:]
35 | num = numbers[basement]
36 | # calculate the body overlap
37 | overlap = _overlap(dets[:, :4], basement, ruler)
38 | indices = np.where(overlap > thresh)[0]
39 | loc = np.where(numbers[ruler][indices] == num)[0]
40 | # the mask won't change in the step
41 | mask = keep[ruler[indices][loc]]#.copy()
42 | keep[ruler[indices]] = False
43 | keep[ruler[indices][loc][mask]] = True
44 | ruler[~keep[ruler]] = -1
45 | ruler = ruler[ruler>0]
46 | keep = keep[np.argsort(order)]
47 | return keep
48 |
49 | def cpu_nms(dets, base_thr):
50 | """Pure Python NMS baseline."""
51 | x1 = dets[:, 0]
52 | y1 = dets[:, 1]
53 | x2 = dets[:, 2]
54 | y2 = dets[:, 3]
55 | scores = dets[:, 4]
56 |
57 | areas = (x2 - x1) * (y2 - y1)
58 | order = np.argsort(-scores)
59 |
60 | keep = []
61 | eps = 1e-8
62 | while len(order) > 0:
63 | i = order[0]
64 | keep.append(i)
65 | xx1 = np.maximum(x1[i], x1[order[1:]])
66 | yy1 = np.maximum(y1[i], y1[order[1:]])
67 | xx2 = np.minimum(x2[i], x2[order[1:]])
68 | yy2 = np.minimum(y2[i], y2[order[1:]])
69 |
70 | w = np.maximum(0.0, xx2 - xx1)
71 | h = np.maximum(0.0, yy2 - yy1)
72 | inter = w * h
73 | ovr = inter / (areas[i] + areas[order[1:]] - inter + eps)
74 |
75 | inds = np.where(ovr <= base_thr)[0]
76 | indices = np.where(ovr > base_thr)[0]
77 | order = order[inds + 1]
78 | return np.array(keep)
79 |
80 | def _test():
81 | box1 = np.array([33,45,145,230,0.7])[None,:]
82 | box2 = np.array([44,54,123,348,0.8])[None,:]
83 | box3 = np.array([88,12,340,342,0.65])[None,:]
84 | boxes = np.concatenate([box1,box2,box3],axis = 0)
85 | nms_thresh = 0.5
86 |     keep = cpu_nms(boxes, nms_thresh)
87 | alive_boxes = boxes[keep]
88 |
89 | if __name__=='__main__':
90 | _test()
91 |
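set_cpu_nms expects each row as (x1, y1, x2, y2, score, id), where the last column identifies the proposal a box came from; boxes sharing an id never suppress one another, so both predictions of a single proposal can survive even when they overlap heavily. A small sketch with made-up coordinates; the import assumes lib/ is on sys.path:

import numpy as np
from utils.nms_utils import set_cpu_nms

dets = np.array([
    [100, 100, 200, 300, 0.95, 0],
    [110, 105, 210, 305, 0.90, 0],   # same id as the best box -> kept despite the overlap
    [102, 101, 201, 299, 0.80, 1],   # different id, high IoU with the best box -> suppressed
])
keep = set_cpu_nms(dets, 0.5)
print(keep)   # boolean mask over the input order: [True, True, False]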
--------------------------------------------------------------------------------
/lib/utils/visual_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import cv2
5 |
6 | color = {'green':(0,255,0),
7 | 'blue':(255,165,0),
8 | 'dark red':(0,0,139),
9 | 'red':(0, 0, 255),
10 | 'dark slate blue':(139,61,72),
11 | 'aqua':(255,255,0),
12 | 'brown':(42,42,165),
13 | 'deep pink':(147,20,255),
14 |          'fuchsia':(255,0,255),
15 |          'yellow':(0,238,238),
16 | 'orange':(0,165,255),
17 | 'saddle brown':(19,69,139),
18 | 'black':(0,0,0),
19 | 'white':(255,255,255)}
20 |
21 | def draw_boxes(img, boxes, scores=None, tags=None, line_thick=1, line_color='white'):
22 | width = img.shape[1]
23 | height = img.shape[0]
24 | for i in range(len(boxes)):
25 | one_box = boxes[i]
26 | one_box = np.array([max(one_box[0], 0), max(one_box[1], 0),
27 | min(one_box[2], width - 1), min(one_box[3], height - 1)])
28 | x1,y1,x2,y2 = np.array(one_box[:4]).astype(int)
29 | cv2.rectangle(img, (x1,y1), (x2,y2), color[line_color], line_thick)
30 | if scores is not None:
31 | text = "{} {:.3f}".format(tags[i], scores[i])
32 | cv2.putText(img, text, (x1, y1 - 7), cv2.FONT_ITALIC, 0.5, color[line_color], line_thick)
33 | return img
34 |
35 |
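draw_boxes takes xyxy boxes and draws them (with an optional 'tag score' label) using OpenCV's BGR convention, which is why the tuples above look reversed relative to RGB. A minimal usage sketch on a dummy canvas; the import assumes lib/ is on sys.path:

import numpy as np
from utils.visual_utils import draw_boxes

img = np.zeros((480, 640, 3), dtype=np.uint8)
boxes = np.array([[50, 60, 200, 400], [220, 80, 360, 420]])
img = draw_boxes(img, boxes, scores=np.array([0.91, 0.78]), tags=['person', 'person'],
                 line_thick=2, line_color='green')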
--------------------------------------------------------------------------------
/model/rcnn_emd_refine/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 |
6 | def add_path(path):
7 | if path not in sys.path:
8 | sys.path.insert(0, path)
9 |
10 | root_dir = '../../'
11 | add_path(os.path.join(root_dir))
12 | add_path(os.path.join(root_dir, 'lib'))
13 |
14 | class Crowd_human:
15 | class_names = ['background', 'person']
16 | num_classes = len(class_names)
17 | root_folder = '/data/CrowdHuman'
18 | image_folder = '/data/CrowdHuman/images'
19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt')
20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt')
21 |
22 | class Config:
23 | output_dir = 'outputs'
24 | model_dir = os.path.join(output_dir, 'model_dump')
25 | eval_dir = os.path.join(output_dir, 'eval_dump')
26 | init_weights = '/data/model/resnet50_fbaug.pth'
27 |
28 | # ----------data config---------- #
29 | image_mean = np.array([103.530, 116.280, 123.675])
30 | image_std = np.array([57.375, 57.120, 58.395])
31 | train_image_short_size = 800
32 | train_image_max_size = 1400
33 | eval_resize = True
34 | eval_image_short_size = 800
35 | eval_image_max_size = 1400
36 | seed_dataprovider = 3
37 | train_source = Crowd_human.train_source
38 | eval_source = Crowd_human.eval_source
39 | image_folder = Crowd_human.image_folder
40 | class_names = Crowd_human.class_names
41 | num_classes = Crowd_human.num_classes
42 | class_names2id = dict(list(zip(class_names, list(range(num_classes)))))
43 | gt_boxes_name = 'fbox'
44 |
45 | # ----------train config---------- #
46 | backbone_freeze_at = 2
47 | rpn_channel = 256
48 |
49 | train_batch_per_gpu = 2
50 | momentum = 0.9
51 | weight_decay = 1e-4
52 | base_lr = 1e-3 * 1.25
53 |
54 | warm_iter = 800
55 | max_epoch = 30
56 | lr_decay = [24, 27]
57 | nr_images_epoch = 15000
58 | log_dump_interval = 20
59 |
60 | # ----------test config---------- #
61 | test_nms = 0.5
62 | test_nms_method = 'set_nms'
63 | visulize_threshold = 0.3
64 | pred_cls_threshold = 0.01
65 |
66 | # ----------model config---------- #
67 | batch_filter_box_size = 0
68 | nr_box_dim = 5
69 | ignore_label = -1
70 | max_boxes_of_image = 500
71 |
72 | # ----------rois generator config---------- #
73 | anchor_base_size = 32
74 | anchor_base_scale = [1]
75 | anchor_aspect_ratios = [1, 2, 3]
76 | num_cell_anchors = len(anchor_aspect_ratios)
77 | anchor_within_border = False
78 |
79 | rpn_min_box_size = 2
80 | rpn_nms_threshold = 0.7
81 | train_prev_nms_top_n = 12000
82 | train_post_nms_top_n = 2000
83 | test_prev_nms_top_n = 6000
84 | test_post_nms_top_n = 1000
85 |
86 | # ----------binding&training config---------- #
87 | rpn_smooth_l1_beta = 1
88 | rcnn_smooth_l1_beta = 1
89 |
90 | num_sample_anchors = 256
91 | positive_anchor_ratio = 0.5
92 | rpn_positive_overlap = 0.7
93 | rpn_negative_overlap = 0.3
94 | rpn_bbox_normalize_targets = False
95 |
96 | num_rois = 512
97 | fg_ratio = 0.5
98 | fg_threshold = 0.5
99 | bg_threshold_high = 0.5
100 | bg_threshold_low = 0.0
101 | rcnn_bbox_normalize_targets = True
102 | bbox_normalize_means = np.array([0, 0, 0, 0])
103 | bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2])
104 |
105 | config = Config()
106 |
107 |
--------------------------------------------------------------------------------
/model/rcnn_emd_refine/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from config import config
7 | from backbone.resnet50 import ResNet50
8 | from backbone.fpn import FPN
9 | from module.rpn import RPN
10 | from layers.pooler import roi_pooler
11 | from det_oprs.bbox_opr import bbox_transform_inv_opr
12 | from det_oprs.fpn_roi_target import fpn_roi_target
13 | from det_oprs.loss_opr import emd_loss_softmax
14 | from det_oprs.utils import get_padded_tensor
15 |
16 | class Network(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False)
20 | self.FPN = FPN(self.resnet50, 2, 6)
21 | self.RPN = RPN(config.rpn_channel)
22 | self.RCNN = RCNN()
23 |         assert config.num_classes == 2, 'Only two classes (1 fg / 1 bg) are supported.'
24 |
25 | def forward(self, image, im_info, gt_boxes=None):
26 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / (
27 | torch.tensor(config.image_std[None, :, None, None]).type_as(image))
28 | image = get_padded_tensor(image, 64)
29 | if self.training:
30 | return self._forward_train(image, im_info, gt_boxes)
31 | else:
32 | return self._forward_test(image, im_info)
33 |
34 | def _forward_train(self, image, im_info, gt_boxes):
35 | loss_dict = {}
36 | fpn_fms = self.FPN(image)
37 | # fpn_fms stride: 64,32,16,8,4, p6->p2
38 | rpn_rois, loss_dict_rpn = self.RPN(fpn_fms, im_info, gt_boxes)
39 | rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target(
40 | rpn_rois, im_info, gt_boxes, top_k=2)
41 | loss_dict_rcnn = self.RCNN(fpn_fms, rcnn_rois,
42 | rcnn_labels, rcnn_bbox_targets)
43 | loss_dict.update(loss_dict_rpn)
44 | loss_dict.update(loss_dict_rcnn)
45 | return loss_dict
46 |
47 | def _forward_test(self, image, im_info):
48 | fpn_fms = self.FPN(image)
49 | rpn_rois = self.RPN(fpn_fms, im_info)
50 | pred_bbox = self.RCNN(fpn_fms, rpn_rois)
51 | return pred_bbox.cpu().detach()
52 |
53 | class RCNN(nn.Module):
54 | def __init__(self):
55 | super().__init__()
56 | # roi head
57 | self.fc1 = nn.Linear(256*7*7, 1024)
58 | self.fc2 = nn.Linear(1024, 1024)
59 | self.fc3 = nn.Linear(1044, 1024)
60 |
61 | for l in [self.fc1, self.fc2, self.fc3]:
62 | nn.init.kaiming_uniform_(l.weight, a=1)
63 | nn.init.constant_(l.bias, 0)
64 | # box predictor
65 | self.emd_pred_cls_0 = nn.Linear(1024, config.num_classes)
66 | self.emd_pred_delta_0 = nn.Linear(1024, config.num_classes * 4)
67 | self.emd_pred_cls_1 = nn.Linear(1024, config.num_classes)
68 | self.emd_pred_delta_1 = nn.Linear(1024, config.num_classes * 4)
69 | self.ref_pred_cls_0 = nn.Linear(1024, config.num_classes)
70 | self.ref_pred_delta_0 = nn.Linear(1024, config.num_classes * 4)
71 | self.ref_pred_cls_1 = nn.Linear(1024, config.num_classes)
72 | self.ref_pred_delta_1 = nn.Linear(1024, config.num_classes * 4)
73 | for l in [self.emd_pred_cls_0, self.emd_pred_cls_1,
74 | self.ref_pred_cls_0, self.ref_pred_cls_1]:
75 | nn.init.normal_(l.weight, std=0.001)
76 | nn.init.constant_(l.bias, 0)
77 | for l in [self.emd_pred_delta_0, self.emd_pred_delta_1,
78 | self.ref_pred_delta_0, self.ref_pred_delta_1]:
79 | nn.init.normal_(l.weight, std=0.001)
80 | nn.init.constant_(l.bias, 0)
81 |
82 | def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
83 | # stride: 64,32,16,8,4 -> 4, 8, 16, 32
84 | fpn_fms = fpn_fms[1:][::-1]
85 | stride = [4, 8, 16, 32]
86 | pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7), "ROIAlignV2")
87 | flatten_feature = torch.flatten(pool_features, start_dim=1)
88 | flatten_feature = F.relu_(self.fc1(flatten_feature))
89 | flatten_feature = F.relu_(self.fc2(flatten_feature))
90 | pred_emd_cls_0 = self.emd_pred_cls_0(flatten_feature)
91 | pred_emd_delta_0 = self.emd_pred_delta_0(flatten_feature)
92 | pred_emd_cls_1 = self.emd_pred_cls_1(flatten_feature)
93 | pred_emd_delta_1 = self.emd_pred_delta_1(flatten_feature)
94 | pred_emd_scores_0 = F.softmax(pred_emd_cls_0, dim=-1)
95 | pred_emd_scores_1 = F.softmax(pred_emd_cls_1, dim=-1)
96 | # cons refine feature
97 | boxes_feature_0 = torch.cat((pred_emd_delta_0[:, 4:],
98 | pred_emd_scores_0[:, 1][:, None]), dim=1).repeat(1, 4)
99 | boxes_feature_1 = torch.cat((pred_emd_delta_1[:, 4:],
100 | pred_emd_scores_1[:, 1][:, None]), dim=1).repeat(1, 4)
101 | boxes_feature_0 = torch.cat((flatten_feature, boxes_feature_0), dim=1)
102 | boxes_feature_1 = torch.cat((flatten_feature, boxes_feature_1), dim=1)
103 | refine_feature_0 = F.relu_(self.fc3(boxes_feature_0))
104 | refine_feature_1 = F.relu_(self.fc3(boxes_feature_1))
105 | # refine
106 | pred_ref_cls_0 = self.ref_pred_cls_0(refine_feature_0)
107 | pred_ref_delta_0 = self.ref_pred_delta_0(refine_feature_0)
108 | pred_ref_cls_1 = self.ref_pred_cls_1(refine_feature_1)
109 | pred_ref_delta_1 = self.ref_pred_delta_1(refine_feature_1)
110 | if self.training:
111 | loss0 = emd_loss_softmax(
112 | pred_emd_delta_0, pred_emd_cls_0,
113 | pred_emd_delta_1, pred_emd_cls_1,
114 | bbox_targets, labels)
115 | loss1 = emd_loss_softmax(
116 | pred_emd_delta_1, pred_emd_cls_1,
117 | pred_emd_delta_0, pred_emd_cls_0,
118 | bbox_targets, labels)
119 | loss2 = emd_loss_softmax(
120 | pred_ref_delta_0, pred_ref_cls_0,
121 | pred_ref_delta_1, pred_ref_cls_1,
122 | bbox_targets, labels)
123 | loss3 = emd_loss_softmax(
124 | pred_ref_delta_1, pred_ref_cls_1,
125 | pred_ref_delta_0, pred_ref_cls_0,
126 | bbox_targets, labels)
127 | loss_rcnn = torch.cat([loss0, loss1], axis=1)
128 | loss_ref = torch.cat([loss2, loss3], axis=1)
129 | # requires_grad = False
130 | _, min_indices_rcnn = loss_rcnn.min(axis=1)
131 | _, min_indices_ref = loss_ref.min(axis=1)
132 | loss_rcnn = loss_rcnn[torch.arange(loss_rcnn.shape[0]), min_indices_rcnn]
133 | loss_rcnn = loss_rcnn.mean()
134 | loss_ref = loss_ref[torch.arange(loss_ref.shape[0]), min_indices_ref]
135 | loss_ref = loss_ref.mean()
136 | loss_dict = {}
137 | loss_dict['loss_rcnn_emd'] = loss_rcnn
138 | loss_dict['loss_ref_emd'] = loss_ref
139 | return loss_dict
140 | else:
141 | class_num = pred_ref_cls_0.shape[-1] - 1
142 | tag = torch.arange(class_num).type_as(pred_ref_cls_0)+1
143 | tag = tag.repeat(pred_ref_cls_0.shape[0], 1).reshape(-1,1)
144 | pred_scores_0 = F.softmax(pred_ref_cls_0, dim=-1)[:, 1:].reshape(-1, 1)
145 | pred_scores_1 = F.softmax(pred_ref_cls_1, dim=-1)[:, 1:].reshape(-1, 1)
146 | pred_delta_0 = pred_ref_delta_0[:, 4:].reshape(-1, 4)
147 | pred_delta_1 = pred_ref_delta_1[:, 4:].reshape(-1, 4)
148 | base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4)
149 | pred_bbox_0 = restore_bbox(base_rois, pred_delta_0, True)
150 | pred_bbox_1 = restore_bbox(base_rois, pred_delta_1, True)
151 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1)
152 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1)
153 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1)
154 | return pred_bbox
155 |
156 | def restore_bbox(rois, deltas, unnormalize=True):
157 | if unnormalize:
158 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas)
159 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas)
160 | deltas = deltas * std_opr
161 | deltas = deltas + mean_opr
162 | pred_bbox = bbox_transform_inv_opr(rois, deltas)
163 | return pred_bbox
164 |
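The training branch is the EMD (set) loss: each proposal carries two predictions, both orderings of (prediction, target) pairs are scored, and only the cheaper ordering contributes for that proposal; the refinement head repeats the same min-over-permutations on its own outputs. A shapes-only sketch of the permutation-min step with made-up per-proposal losses:

import torch

loss0 = torch.tensor([[0.8], [0.3], [1.2]])      # cost of the identity assignment
loss1 = torch.tensor([[0.5], [0.9], [1.0]])      # cost of the swapped assignment
loss = torch.cat([loss0, loss1], dim=1)          # (num_proposals, 2)
_, min_indices = loss.min(dim=1)                 # cheaper ordering per proposal
loss_emd = loss[torch.arange(loss.shape[0]), min_indices].mean()
print(loss_emd)                                  # mean of (0.5, 0.3, 1.0) = 0.6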
--------------------------------------------------------------------------------
/model/rcnn_emd_simple/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 |
6 | def add_path(path):
7 | if path not in sys.path:
8 | sys.path.insert(0, path)
9 |
10 | root_dir = '../../'
11 | add_path(os.path.join(root_dir))
12 | add_path(os.path.join(root_dir, 'lib'))
13 |
14 | class Crowd_human:
15 | class_names = ['background', 'person']
16 | num_classes = len(class_names)
17 | root_folder = '/data/CrowdHuman'
18 | image_folder = '/data/CrowdHuman/images'
19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt')
20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt')
21 |
22 | class Config:
23 | output_dir = 'outputs'
24 | model_dir = os.path.join(output_dir, 'model_dump')
25 | eval_dir = os.path.join(output_dir, 'eval_dump')
26 | init_weights = '/data/model/resnet50_fbaug.pth'
27 |
28 | # ----------data config---------- #
29 | image_mean = np.array([103.530, 116.280, 123.675])
30 | image_std = np.array([57.375, 57.120, 58.395])
31 | train_image_short_size = 800
32 | train_image_max_size = 1400
33 | eval_resize = True
34 | eval_image_short_size = 800
35 | eval_image_max_size = 1400
36 | seed_dataprovider = 3
37 | train_source = Crowd_human.train_source
38 | eval_source = Crowd_human.eval_source
39 | image_folder = Crowd_human.image_folder
40 | class_names = Crowd_human.class_names
41 | num_classes = Crowd_human.num_classes
42 | class_names2id = dict(list(zip(class_names, list(range(num_classes)))))
43 | gt_boxes_name = 'fbox'
44 |
45 | # ----------train config---------- #
46 | backbone_freeze_at = 2
47 | rpn_channel = 256
48 |
49 | train_batch_per_gpu = 2
50 | momentum = 0.9
51 | weight_decay = 1e-4
52 | base_lr = 1e-3 * 1.25
53 |
54 | warm_iter = 800
55 | max_epoch = 30
56 | lr_decay = [24, 27]
57 | nr_images_epoch = 15000
58 | log_dump_interval = 20
59 |
60 | # ----------test config---------- #
61 | test_nms = 0.5
62 | test_nms_method = 'set_nms'
63 | visulize_threshold = 0.3
64 | pred_cls_threshold = 0.01
65 |
66 | # ----------model config---------- #
67 | batch_filter_box_size = 0
68 | nr_box_dim = 5
69 | ignore_label = -1
70 | max_boxes_of_image = 500
71 |
72 | # ----------rois generator config---------- #
73 | anchor_base_size = 32
74 | anchor_base_scale = [1]
75 | anchor_aspect_ratios = [1, 2, 3]
76 | num_cell_anchors = len(anchor_aspect_ratios)
77 | anchor_within_border = False
78 |
79 | rpn_min_box_size = 2
80 | rpn_nms_threshold = 0.7
81 | train_prev_nms_top_n = 12000
82 | train_post_nms_top_n = 2000
83 | test_prev_nms_top_n = 6000
84 | test_post_nms_top_n = 1000
85 |
86 | # ----------binding&training config---------- #
87 | rpn_smooth_l1_beta = 1
88 | rcnn_smooth_l1_beta = 1
89 |
90 | num_sample_anchors = 256
91 | positive_anchor_ratio = 0.5
92 | rpn_positive_overlap = 0.7
93 | rpn_negative_overlap = 0.3
94 | rpn_bbox_normalize_targets = False
95 |
96 | num_rois = 512
97 | fg_ratio = 0.5
98 | fg_threshold = 0.5
99 | bg_threshold_high = 0.5
100 | bg_threshold_low = 0.0
101 | rcnn_bbox_normalize_targets = True
102 | bbox_normalize_means = np.array([0, 0, 0, 0])
103 | bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2])
104 |
105 | config = Config()
106 |
107 |
--------------------------------------------------------------------------------
/model/rcnn_emd_simple/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from config import config
7 | from backbone.resnet50 import ResNet50
8 | from backbone.fpn import FPN
9 | from module.rpn import RPN
10 | from layers.pooler import roi_pooler
11 | from det_oprs.bbox_opr import bbox_transform_inv_opr
12 | from det_oprs.fpn_roi_target import fpn_roi_target
13 | from det_oprs.loss_opr import emd_loss_softmax
14 | from det_oprs.utils import get_padded_tensor
15 |
16 | class Network(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False)
20 | self.FPN = FPN(self.resnet50, 2, 6)
21 | self.RPN = RPN(config.rpn_channel)
22 | self.RCNN = RCNN()
23 |
24 | def forward(self, image, im_info, gt_boxes=None):
25 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / (
26 | torch.tensor(config.image_std[None, :, None, None]).type_as(image))
27 | image = get_padded_tensor(image, 64)
28 | if self.training:
29 | return self._forward_train(image, im_info, gt_boxes)
30 | else:
31 | return self._forward_test(image, im_info)
32 |
33 | def _forward_train(self, image, im_info, gt_boxes):
34 | loss_dict = {}
35 | fpn_fms = self.FPN(image)
36 | # fpn_fms stride: 64,32,16,8,4, p6->p2
37 | rpn_rois, loss_dict_rpn = self.RPN(fpn_fms, im_info, gt_boxes)
38 | rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target(
39 | rpn_rois, im_info, gt_boxes, top_k=2)
40 | loss_dict_rcnn = self.RCNN(fpn_fms, rcnn_rois,
41 | rcnn_labels, rcnn_bbox_targets)
42 | loss_dict.update(loss_dict_rpn)
43 | loss_dict.update(loss_dict_rcnn)
44 | return loss_dict
45 |
46 | def _forward_test(self, image, im_info):
47 | fpn_fms = self.FPN(image)
48 | rpn_rois = self.RPN(fpn_fms, im_info)
49 | pred_bbox = self.RCNN(fpn_fms, rpn_rois)
50 | return pred_bbox.cpu().detach()
51 |
52 | class RCNN(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 | # roi head
56 | self.fc1 = nn.Linear(256*7*7, 1024)
57 | self.fc2 = nn.Linear(1024, 1024)
58 |
59 | for l in [self.fc1, self.fc2]:
60 | nn.init.kaiming_uniform_(l.weight, a=1)
61 | nn.init.constant_(l.bias, 0)
62 | # box predictor
63 | self.emd_pred_cls_0 = nn.Linear(1024, config.num_classes)
64 | self.emd_pred_delta_0 = nn.Linear(1024, config.num_classes * 4)
65 | self.emd_pred_cls_1 = nn.Linear(1024, config.num_classes)
66 | self.emd_pred_delta_1 = nn.Linear(1024, config.num_classes * 4)
67 | for l in [self.emd_pred_cls_0, self.emd_pred_cls_1]:
68 | nn.init.normal_(l.weight, std=0.01)
69 | nn.init.constant_(l.bias, 0)
70 | for l in [self.emd_pred_delta_0, self.emd_pred_delta_1]:
71 | nn.init.normal_(l.weight, std=0.001)
72 | nn.init.constant_(l.bias, 0)
73 |
74 | def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
75 | # stride: 64,32,16,8,4 -> 4, 8, 16, 32
76 | fpn_fms = fpn_fms[1:][::-1]
77 | stride = [4, 8, 16, 32]
78 | pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7), "ROIAlignV2")
79 | flatten_feature = torch.flatten(pool_features, start_dim=1)
80 | flatten_feature = F.relu_(self.fc1(flatten_feature))
81 | flatten_feature = F.relu_(self.fc2(flatten_feature))
82 | pred_emd_cls_0 = self.emd_pred_cls_0(flatten_feature)
83 | pred_emd_delta_0 = self.emd_pred_delta_0(flatten_feature)
84 | pred_emd_cls_1 = self.emd_pred_cls_1(flatten_feature)
85 | pred_emd_delta_1 = self.emd_pred_delta_1(flatten_feature)
86 | if self.training:
87 | loss0 = emd_loss_softmax(
88 | pred_emd_delta_0, pred_emd_cls_0,
89 | pred_emd_delta_1, pred_emd_cls_1,
90 | bbox_targets, labels)
91 | loss1 = emd_loss_softmax(
92 | pred_emd_delta_1, pred_emd_cls_1,
93 | pred_emd_delta_0, pred_emd_cls_0,
94 | bbox_targets, labels)
95 | loss = torch.cat([loss0, loss1], axis=1)
96 | # requires_grad = False
97 | _, min_indices = loss.min(axis=1)
98 | loss_emd = loss[torch.arange(loss.shape[0]), min_indices]
99 | loss_emd = loss_emd.mean()
100 | loss_dict = {}
101 | loss_dict['loss_rcnn_emd'] = loss_emd
102 | return loss_dict
103 | else:
104 | class_num = pred_emd_cls_0.shape[-1] - 1
105 | tag = torch.arange(class_num).type_as(pred_emd_cls_0)+1
106 | tag = tag.repeat(pred_emd_cls_0.shape[0], 1).reshape(-1,1)
107 | pred_scores_0 = F.softmax(pred_emd_cls_0, dim=-1)[:, 1:].reshape(-1, 1)
108 | pred_scores_1 = F.softmax(pred_emd_cls_1, dim=-1)[:, 1:].reshape(-1, 1)
109 | pred_delta_0 = pred_emd_delta_0[:, 4:].reshape(-1, 4)
110 | pred_delta_1 = pred_emd_delta_1[:, 4:].reshape(-1, 4)
111 | base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4)
112 | pred_bbox_0 = restore_bbox(base_rois, pred_delta_0, True)
113 | pred_bbox_1 = restore_bbox(base_rois, pred_delta_1, True)
114 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1)
115 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1)
116 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1)
117 | return pred_bbox
118 |
119 | def restore_bbox(rois, deltas, unnormalize=True):
120 | if unnormalize:
121 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas)
122 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas)
123 | deltas = deltas * std_opr
124 | deltas = deltas + mean_opr
125 | pred_bbox = bbox_transform_inv_opr(rois, deltas)
126 | return pred_bbox
127 |
--------------------------------------------------------------------------------
/model/rcnn_fpn_baseline/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 |
6 | def add_path(path):
7 | if path not in sys.path:
8 | sys.path.insert(0, path)
9 |
10 | root_dir = '../../'
11 | add_path(os.path.join(root_dir))
12 | add_path(os.path.join(root_dir, 'lib'))
13 |
14 | class Crowd_human:
15 | class_names = ['background', 'person']
16 | num_classes = len(class_names)
17 | root_folder = '/data/CrowdHuman'
18 | image_folder = '/data/CrowdHuman/images'
19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt')
20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt')
21 |
22 | class Config:
23 | output_dir = 'outputs'
24 | model_dir = os.path.join(output_dir, 'model_dump')
25 | eval_dir = os.path.join(output_dir, 'eval_dump')
26 | init_weights = '/data/model/resnet50_fbaug.pth'
27 |
28 | # ----------data config---------- #
29 | image_mean = np.array([103.530, 116.280, 123.675])
30 | image_std = np.array([57.375, 57.120, 58.395])
31 | train_image_short_size = 800
32 | train_image_max_size = 1400
33 | eval_resize = True
34 | eval_image_short_size = 800
35 | eval_image_max_size = 1400
36 | seed_dataprovider = 3
37 | train_source = Crowd_human.train_source
38 | eval_source = Crowd_human.eval_source
39 | image_folder = Crowd_human.image_folder
40 | class_names = Crowd_human.class_names
41 | num_classes = Crowd_human.num_classes
42 | class_names2id = dict(list(zip(class_names, list(range(num_classes)))))
43 | gt_boxes_name = 'fbox'
44 |
45 | # ----------train config---------- #
46 | backbone_freeze_at = 2
47 | rpn_channel = 256
48 |
49 | train_batch_per_gpu = 2
50 | momentum = 0.9
51 | weight_decay = 1e-4
52 | base_lr = 1e-3 * 1.25
53 |
54 | warm_iter = 800
55 | max_epoch = 30
56 | lr_decay = [24, 27]
57 | nr_images_epoch = 15000
58 | log_dump_interval = 20
59 |
60 | # ----------test config---------- #
61 | test_nms = 0.5
62 | test_nms_method = 'normal_nms'
63 | visulize_threshold = 0.3
64 | pred_cls_threshold = 0.01
65 |
66 | # ----------model config---------- #
67 | batch_filter_box_size = 0
68 | nr_box_dim = 5
69 | ignore_label = -1
70 | max_boxes_of_image = 500
71 |
72 | # ----------rois generator config---------- #
73 | anchor_base_size = 32
74 | anchor_base_scale = [1]
75 | anchor_aspect_ratios = [1, 2, 3]
76 | num_cell_anchors = len(anchor_aspect_ratios)
77 | anchor_within_border = False
78 |
79 | rpn_min_box_size = 2
80 | rpn_nms_threshold = 0.7
81 | train_prev_nms_top_n = 12000
82 | train_post_nms_top_n = 2000
83 | test_prev_nms_top_n = 6000
84 | test_post_nms_top_n = 1000
85 |
86 | # ----------binding&training config---------- #
87 | rpn_smooth_l1_beta = 1
88 | rcnn_smooth_l1_beta = 1
89 |
90 | num_sample_anchors = 256
91 | positive_anchor_ratio = 0.5
92 | rpn_positive_overlap = 0.7
93 | rpn_negative_overlap = 0.3
94 | rpn_bbox_normalize_targets = False
95 |
96 | num_rois = 512
97 | fg_ratio = 0.5
98 | fg_threshold = 0.5
99 | bg_threshold_high = 0.5
100 | bg_threshold_low = 0.0
101 | rcnn_bbox_normalize_targets = True
102 | bbox_normalize_means = np.array([0, 0, 0, 0])
103 | bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2])
104 |
105 | config = Config()
106 |
107 |
--------------------------------------------------------------------------------
/model/rcnn_fpn_baseline/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from config import config
7 | from backbone.resnet50 import ResNet50
8 | from backbone.fpn import FPN
9 | from module.rpn import RPN
10 | from layers.pooler import roi_pooler
11 | from det_oprs.bbox_opr import bbox_transform_inv_opr
12 | from det_oprs.fpn_roi_target import fpn_roi_target
13 | from det_oprs.loss_opr import softmax_loss, smooth_l1_loss
14 | from det_oprs.utils import get_padded_tensor
15 |
16 | class Network(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False)
20 | self.FPN = FPN(self.resnet50, 2, 6)
21 | self.RPN = RPN(config.rpn_channel)
22 | self.RCNN = RCNN()
23 |
24 | def forward(self, image, im_info, gt_boxes=None):
25 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / (
26 | torch.tensor(config.image_std[None, :, None, None]).type_as(image))
27 | image = get_padded_tensor(image, 64)
28 | if self.training:
29 | return self._forward_train(image, im_info, gt_boxes)
30 | else:
31 | return self._forward_test(image, im_info)
32 |
33 | def _forward_train(self, image, im_info, gt_boxes):
34 | loss_dict = {}
35 | fpn_fms = self.FPN(image)
36 | # fpn_fms stride: 64,32,16,8,4, p6->p2
37 | rpn_rois, loss_dict_rpn = self.RPN(fpn_fms, im_info, gt_boxes)
38 | rcnn_rois, rcnn_labels, rcnn_bbox_targets = fpn_roi_target(
39 | rpn_rois, im_info, gt_boxes, top_k=1)
40 | loss_dict_rcnn = self.RCNN(fpn_fms, rcnn_rois,
41 | rcnn_labels, rcnn_bbox_targets)
42 | loss_dict.update(loss_dict_rpn)
43 | loss_dict.update(loss_dict_rcnn)
44 | return loss_dict
45 |
46 | def _forward_test(self, image, im_info):
47 | fpn_fms = self.FPN(image)
48 | rpn_rois = self.RPN(fpn_fms, im_info)
49 | pred_bbox = self.RCNN(fpn_fms, rpn_rois)
50 | return pred_bbox.cpu().detach()
51 |
52 | class RCNN(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 | # roi head
56 | self.fc1 = nn.Linear(256*7*7, 1024)
57 | self.fc2 = nn.Linear(1024, 1024)
58 |
59 | for l in [self.fc1, self.fc2]:
60 | nn.init.kaiming_uniform_(l.weight, a=1)
61 | nn.init.constant_(l.bias, 0)
62 | # box predictor
63 | self.pred_cls = nn.Linear(1024, config.num_classes)
64 | self.pred_delta = nn.Linear(1024, config.num_classes * 4)
65 | for l in [self.pred_cls]:
66 | nn.init.normal_(l.weight, std=0.01)
67 | nn.init.constant_(l.bias, 0)
68 | for l in [self.pred_delta]:
69 | nn.init.normal_(l.weight, std=0.001)
70 | nn.init.constant_(l.bias, 0)
71 |
72 | def forward(self, fpn_fms, rcnn_rois, labels=None, bbox_targets=None):
73 | # input p2-p5
74 | fpn_fms = fpn_fms[1:][::-1]
75 | stride = [4, 8, 16, 32]
76 | pool_features = roi_pooler(fpn_fms, rcnn_rois, stride, (7, 7), "ROIAlignV2")
77 | flatten_feature = torch.flatten(pool_features, start_dim=1)
78 | flatten_feature = F.relu_(self.fc1(flatten_feature))
79 | flatten_feature = F.relu_(self.fc2(flatten_feature))
80 | pred_cls = self.pred_cls(flatten_feature)
81 | pred_delta = self.pred_delta(flatten_feature)
82 | if self.training:
83 | # loss for regression
84 | labels = labels.long().flatten()
85 | fg_masks = labels > 0
86 | valid_masks = labels >= 0
87 | # multi class
88 | pred_delta = pred_delta.reshape(-1, config.num_classes, 4)
89 | fg_gt_classes = labels[fg_masks]
90 | pred_delta = pred_delta[fg_masks, fg_gt_classes, :]
91 | localization_loss = smooth_l1_loss(
92 | pred_delta,
93 | bbox_targets[fg_masks],
94 | config.rcnn_smooth_l1_beta)
95 | # loss for classification
96 | objectness_loss = softmax_loss(pred_cls, labels)
97 | objectness_loss = objectness_loss * valid_masks
98 | normalizer = 1.0 / valid_masks.sum().item()
99 | loss_rcnn_loc = localization_loss.sum() * normalizer
100 | loss_rcnn_cls = objectness_loss.sum() * normalizer
101 | loss_dict = {}
102 | loss_dict['loss_rcnn_loc'] = loss_rcnn_loc
103 | loss_dict['loss_rcnn_cls'] = loss_rcnn_cls
104 | return loss_dict
105 | else:
106 | class_num = pred_cls.shape[-1] - 1
107 | tag = torch.arange(class_num).type_as(pred_cls)+1
108 | tag = tag.repeat(pred_cls.shape[0], 1).reshape(-1,1)
109 | pred_scores = F.softmax(pred_cls, dim=-1)[:, 1:].reshape(-1, 1)
110 | pred_delta = pred_delta[:, 4:].reshape(-1, 4)
111 | base_rois = rcnn_rois[:, 1:5].repeat(1, class_num).reshape(-1, 4)
112 | pred_bbox = restore_bbox(base_rois, pred_delta, True)
113 | pred_bbox = torch.cat([pred_bbox, pred_scores, tag], axis=1)
114 | return pred_bbox
115 |
116 | def restore_bbox(rois, deltas, unnormalize=True):
117 | if unnormalize:
118 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas)
119 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas)
120 | deltas = deltas * std_opr
121 | deltas = deltas + mean_opr
122 | pred_bbox = bbox_transform_inv_opr(rois, deltas)
123 | return pred_bbox
124 |
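The regression head predicts one delta set per class, so the training branch keeps only the deltas of each foreground proposal's ground-truth class: pred_delta is reshaped to (N, num_classes, 4) and indexed with (fg_masks, fg_gt_classes). A shapes-only sketch of that selection with made-up labels:

import torch

num_classes = 2                                        # background + person, as in config.py
pred_delta = torch.randn(5, num_classes * 4)           # 5 proposals
labels = torch.tensor([1, 0, 1, -1, 1])                # fg = 1, bg = 0, ignore = -1

fg_masks = labels > 0
fg_gt_classes = labels[fg_masks]
picked = pred_delta.reshape(-1, num_classes, 4)[fg_masks, fg_gt_classes, :]
print(picked.shape)                                    # torch.Size([3, 4])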
--------------------------------------------------------------------------------
/model/retina_emd_simple/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 |
6 | def add_path(path):
7 | if path not in sys.path:
8 | sys.path.insert(0, path)
9 |
10 | root_dir = '../../'
11 | add_path(os.path.join(root_dir))
12 | add_path(os.path.join(root_dir, 'lib'))
13 |
14 | class Crowd_human:
15 | class_names = ['background', 'person']
16 | num_classes = len(class_names)
17 | root_folder = '/data/CrowdHuman'
18 | image_folder = '/data/CrowdHuman/images'
19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt')
20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt')
21 |
22 | class Config:
23 | output_dir = 'outputs'
24 | model_dir = os.path.join(output_dir, 'model_dump')
25 | eval_dir = os.path.join(output_dir, 'eval_dump')
26 | init_weights = '/data/model/resnet50_fbaug.pth'
27 |
28 | # ----------data config---------- #
29 | image_mean = np.array([103.530, 116.280, 123.675])
30 | image_std = np.array([57.375, 57.120, 58.395])
31 | train_image_short_size = 800
32 | train_image_max_size = 1400
33 | eval_resize = True
34 | eval_image_short_size = 800
35 | eval_image_max_size = 1400
36 | seed_dataprovider = 3
37 | train_source = Crowd_human.train_source
38 | eval_source = Crowd_human.eval_source
39 | image_folder = Crowd_human.image_folder
40 | class_names = Crowd_human.class_names
41 | num_classes = Crowd_human.num_classes
42 | class_names2id = dict(list(zip(class_names, list(range(num_classes)))))
43 | gt_boxes_name = 'fbox'
44 |
45 | # ----------train config---------- #
46 | backbone_freeze_at = 2
47 | train_batch_per_gpu = 2
48 | momentum = 0.9
49 | weight_decay = 1e-4
50 | base_lr = 3.125e-4
51 | focal_loss_alpha = 0.25
52 | focal_loss_gamma = 2
53 |
54 | warm_iter = 800
55 | max_epoch = 50
56 | lr_decay = [33, 43]
57 | nr_images_epoch = 15000
58 | log_dump_interval = 20
59 |
60 | # ----------test config---------- #
61 | test_layer_topk = 1000
62 | test_nms = 0.5
63 | test_nms_method = 'set_nms'
64 | visulize_threshold = 0.3
65 | pred_cls_threshold = 0.01
66 |
67 | # ----------dataset config---------- #
68 | nr_box_dim = 5
69 | max_boxes_of_image = 500
70 |
71 | # --------anchor generator config-------- #
72 |     anchor_base_size = 32 # the minimum anchor size on the biggest (highest-resolution) feature map.
73 | anchor_base_scale = [2**0, 2**(1/3), 2**(2/3)]
74 | anchor_aspect_ratios = [1, 2, 3]
75 | num_cell_anchors = len(anchor_aspect_ratios) * len(anchor_base_scale)
76 |
77 | # ----------binding&training config---------- #
78 | smooth_l1_beta = 0.1
79 | negative_thresh = 0.4
80 | positive_thresh = 0.5
81 | allow_low_quality = True
82 |
83 | config = Config()
84 |
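With three octave scales and three aspect ratios the head places num_cell_anchors = 3 * 3 = 9 anchors at every feature-map cell, the scales stacking as 2**0, 2**(1/3) and 2**(2/3) of anchor_base_size; how the base size grows across pyramid levels is handled by det_oprs/anchors_generator.py and is not reproduced here. A quick numeric check:

anchor_base_size = 32
anchor_base_scale = [2**0, 2**(1/3), 2**(2/3)]
anchor_aspect_ratios = [1, 2, 3]

print(len(anchor_aspect_ratios) * len(anchor_base_scale))            # 9
print([round(anchor_base_size * s, 1) for s in anchor_base_scale])   # [32.0, 40.3, 50.8]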
--------------------------------------------------------------------------------
/model/retina_emd_simple/network.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from config import config
8 | from backbone.resnet50 import ResNet50
9 | from backbone.fpn import FPN
10 | from det_oprs.anchors_generator import AnchorGenerator
11 | from det_oprs.retina_anchor_target import retina_anchor_target
12 | from det_oprs.bbox_opr import bbox_transform_inv_opr
13 | from det_oprs.loss_opr import emd_loss_focal
14 | from det_oprs.utils import get_padded_tensor
15 |
16 | class Network(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False)
20 | self.FPN = FPN(self.resnet50, 3, 7)
21 | self.R_Head = RetinaNet_Head()
22 | self.R_Anchor = RetinaNet_Anchor()
23 | self.R_Criteria = RetinaNet_Criteria()
24 |
25 | def forward(self, image, im_info, gt_boxes=None):
26 | # pre-processing the data
27 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / (
28 | torch.tensor(config.image_std[None, :, None, None]).type_as(image))
29 | image = get_padded_tensor(image, 64)
30 | # do inference
31 | # stride: 128,64,32,16,8, p7->p3
32 | fpn_fms = self.FPN(image)
33 | anchors_list = self.R_Anchor(fpn_fms)
34 | pred_cls_list, pred_reg_list = self.R_Head(fpn_fms)
35 | # release the useless data
36 | if self.training:
37 | loss_dict = self.R_Criteria(
38 | pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info)
39 | return loss_dict
40 | else:
41 | #pred_bbox = union_inference(
42 | # anchors_list, pred_cls_list, pred_reg_list, im_info)
43 | pred_bbox = per_layer_inference(
44 | anchors_list, pred_cls_list, pred_reg_list, im_info)
45 | return pred_bbox.cpu().detach()
46 |
47 | class RetinaNet_Anchor():
48 | def __init__(self):
49 | self.anchors_generator = AnchorGenerator(
50 | config.anchor_base_size,
51 | config.anchor_aspect_ratios,
52 | config.anchor_base_scale)
53 |
54 | def __call__(self, fpn_fms):
55 | # get anchors
56 | all_anchors_list = []
57 | base_stride = 8
58 | off_stride = 2**(len(fpn_fms)-1) # 16
59 | for fm in fpn_fms:
60 | layer_anchors = self.anchors_generator(fm, base_stride, off_stride)
61 | off_stride = off_stride // 2
62 | all_anchors_list.append(layer_anchors)
63 | return all_anchors_list
64 |
65 | class RetinaNet_Criteria(nn.Module):
66 | def __init__(self):
67 | super().__init__()
68 | self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small
69 | self.loss_normalizer_momentum = 0.9
70 |
71 | def __call__(self, pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info):
72 | all_anchors = torch.cat(anchors_list, axis=0)
73 | all_pred_cls = torch.cat(pred_cls_list, axis=1).reshape(-1, (config.num_classes-1)*2)
74 | all_pred_cls = torch.sigmoid(all_pred_cls)
75 | all_pred_reg = torch.cat(pred_reg_list, axis=1).reshape(-1, 4*2)
76 | # get ground truth
77 | labels, bbox_targets = retina_anchor_target(all_anchors, gt_boxes, im_info, top_k=2)
78 | all_pred_cls = all_pred_cls.reshape(-1, 2, config.num_classes-1)
79 | all_pred_reg = all_pred_reg.reshape(-1, 2, 4)
80 | loss0 = emd_loss_focal(
81 | all_pred_reg[:, 0], all_pred_cls[:, 0],
82 | all_pred_reg[:, 1], all_pred_cls[:, 1],
83 | bbox_targets, labels)
84 | loss1 = emd_loss_focal(
85 | all_pred_reg[:, 1], all_pred_cls[:, 1],
86 | all_pred_reg[:, 0], all_pred_cls[:, 0],
87 | bbox_targets, labels)
88 | del all_anchors
89 | del all_pred_cls
90 | del all_pred_reg
91 | loss = torch.cat([loss0, loss1], axis=1)
92 | # requires_grad = False
93 | _, min_indices = loss.min(axis=1)
94 | loss_emd = loss[torch.arange(loss.shape[0]), min_indices]
95 | # only main labels
96 | num_pos = (labels[:, 0] > 0).sum().item()
97 | self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
98 | 1 - self.loss_normalizer_momentum) * max(num_pos, 1)
99 | loss_emd = loss_emd.sum() / self.loss_normalizer
100 | loss_dict = {}
101 | loss_dict['retina_emd'] = loss_emd
102 | return loss_dict
103 |
104 | class RetinaNet_Head(nn.Module):
105 | def __init__(self):
106 | super().__init__()
107 | num_convs = 4
108 | in_channels = 256
109 | cls_subnet = []
110 | bbox_subnet = []
111 | for _ in range(num_convs):
112 | cls_subnet.append(
113 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
114 | )
115 | cls_subnet.append(nn.ReLU(inplace=True))
116 | bbox_subnet.append(
117 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
118 | )
119 | bbox_subnet.append(nn.ReLU(inplace=True))
120 | self.cls_subnet = nn.Sequential(*cls_subnet)
121 | self.bbox_subnet = nn.Sequential(*bbox_subnet)
122 | # predictor
123 | self.cls_score = nn.Conv2d(
124 | in_channels, config.num_cell_anchors * (config.num_classes-1) * 2,
125 | kernel_size=3, stride=1, padding=1)
126 | self.bbox_pred = nn.Conv2d(
127 | in_channels, config.num_cell_anchors * 4 * 2,
128 | kernel_size=3, stride=1, padding=1)
129 |
130 | # Initialization
131 | for modules in [self.cls_subnet, self.bbox_subnet,
132 | self.cls_score, self.bbox_pred]:
133 | for layer in modules.modules():
134 | if isinstance(layer, nn.Conv2d):
135 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
136 | torch.nn.init.constant_(layer.bias, 0)
137 | prior_prob = 0.01
138 | # Use prior in model initialization to improve stability
139 | bias_value = -(math.log((1 - prior_prob) / prior_prob))
140 | torch.nn.init.constant_(self.cls_score.bias, bias_value)
141 |
142 | def forward(self, features):
143 | pred_cls = []
144 | pred_reg = []
145 | for feature in features:
146 | pred_cls.append(self.cls_score(self.cls_subnet(feature)))
147 | pred_reg.append(self.bbox_pred(self.bbox_subnet(feature)))
148 | # reshape the predictions
149 | assert pred_cls[0].dim() == 4
150 | pred_cls_list = [
151 | _.permute(0, 2, 3, 1).reshape(pred_cls[0].shape[0], -1, (config.num_classes-1)*2)
152 | for _ in pred_cls]
153 | pred_reg_list = [
154 | _.permute(0, 2, 3, 1).reshape(pred_reg[0].shape[0], -1, 4*2)
155 | for _ in pred_reg]
156 | return pred_cls_list, pred_reg_list
157 |
158 | def per_layer_inference(anchors_list, pred_cls_list, pred_reg_list, im_info):
159 | keep_anchors = []
160 | keep_cls = []
161 | keep_reg = []
162 | class_num = pred_cls_list[0].shape[-1] // 2
163 | for l_id in range(len(anchors_list)):
164 | anchors = anchors_list[l_id].reshape(-1, 4)
165 | pred_cls = pred_cls_list[l_id][0].reshape(-1, class_num*2)
166 | pred_reg = pred_reg_list[l_id][0].reshape(-1, 4*2)
167 | if len(anchors) > config.test_layer_topk:
168 | ruler = pred_cls.max(axis=1)[0]
169 | _, inds = ruler.topk(config.test_layer_topk, dim=0)
170 | inds = inds.flatten()
171 | keep_anchors.append(anchors[inds])
172 | keep_cls.append(torch.sigmoid(pred_cls[inds]))
173 | keep_reg.append(pred_reg[inds])
174 | else:
175 | keep_anchors.append(anchors)
176 | keep_cls.append(torch.sigmoid(pred_cls))
177 | keep_reg.append(pred_reg)
178 | keep_anchors = torch.cat(keep_anchors, axis = 0)
179 | keep_cls = torch.cat(keep_cls, axis = 0)
180 | keep_reg = torch.cat(keep_reg, axis = 0)
181 | # multiclass
182 | tag = torch.arange(class_num).type_as(keep_cls)+1
183 | tag = tag.repeat(keep_cls.shape[0], 1).reshape(-1,1)
184 | pred_scores_0 = keep_cls[:, :class_num].reshape(-1, 1)
185 | pred_scores_1 = keep_cls[:, class_num:].reshape(-1, 1)
186 | pred_delta_0 = keep_reg[:, :4]
187 | pred_delta_1 = keep_reg[:, 4:]
188 | pred_bbox_0 = restore_bbox(keep_anchors, pred_delta_0, False)
189 | pred_bbox_1 = restore_bbox(keep_anchors, pred_delta_1, False)
190 | pred_bbox_0 = pred_bbox_0.repeat(1, class_num).reshape(-1, 4)
191 | pred_bbox_1 = pred_bbox_1.repeat(1, class_num).reshape(-1, 4)
192 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1)
193 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1)
194 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1)
195 | return pred_bbox
196 |
197 | def union_inference(anchors_list, pred_cls_list, pred_reg_list, im_info):
198 | anchors = torch.cat(anchors_list, axis = 0)
199 | pred_cls = torch.cat(pred_cls_list, axis = 1)[0]
200 | pred_cls = torch.sigmoid(pred_cls)
201 | pred_reg = torch.cat(pred_reg_list, axis = 1)[0]
202 | class_num = pred_cls.shape[-1] // 2
203 | # multiclass
204 | tag = torch.arange(class_num).type_as(pred_cls)+1
205 | tag = tag.repeat(pred_cls.shape[0], 1).reshape(-1,1)
206 | pred_scores_0 = pred_cls[:, :class_num].reshape(-1, 1)
207 | pred_scores_1 = pred_cls[:, class_num:].reshape(-1, 1)
208 | pred_delta_0 = pred_reg[:, :4]
209 | pred_delta_1 = pred_reg[:, 4:]
210 | pred_bbox_0 = restore_bbox(anchors, pred_delta_0, False)
211 | pred_bbox_1 = restore_bbox(anchors, pred_delta_1, False)
212 | pred_bbox_0 = pred_bbox_0.repeat(1, class_num).reshape(-1, 4)
213 | pred_bbox_1 = pred_bbox_1.repeat(1, class_num).reshape(-1, 4)
214 | pred_bbox_0 = torch.cat([pred_bbox_0, pred_scores_0, tag], axis=1)
215 | pred_bbox_1 = torch.cat([pred_bbox_1, pred_scores_1, tag], axis=1)
216 | pred_bbox = torch.cat((pred_bbox_0, pred_bbox_1), axis=1)
217 | return pred_bbox
218 |
219 | def restore_bbox(rois, deltas, unnormalize=True):
220 | if unnormalize:
221 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas)
222 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas)
223 | deltas = deltas * std_opr
224 | deltas = deltas + mean_opr
225 | pred_bbox = bbox_transform_inv_opr(rois, deltas)
226 | return pred_bbox
227 |
228 |
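Note on the EMD selection in model/retina_emd_simple/network.py above: each anchor carries two predictions, emd_loss_focal is evaluated for both orderings of those predictions against the top-2 targets, and only the cheaper ordering contributes to the loss. A minimal sketch of that min-selection, using dummy per-anchor loss columns in place of the real emd_loss_focal outputs:

    import torch

    # dummy (N, 1) losses for the two prediction orderings
    loss0 = torch.tensor([[0.7], [0.2], [0.9]])
    loss1 = torch.tensor([[0.4], [0.5], [0.1]])

    loss = torch.cat([loss0, loss1], dim=1)            # (N, 2)
    _, min_indices = loss.min(dim=1)                   # cheaper ordering per anchor
    loss_emd = loss[torch.arange(loss.shape[0]), min_indices]
    print(loss_emd)                                    # tensor([0.4000, 0.2000, 0.1000])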
--------------------------------------------------------------------------------
/model/retina_fpn_baseline/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 |
6 | def add_path(path):
7 | if path not in sys.path:
8 | sys.path.insert(0, path)
9 |
10 | root_dir = '../../'
11 | add_path(os.path.join(root_dir))
12 | add_path(os.path.join(root_dir, 'lib'))
13 |
14 | class Crowd_human:
15 | class_names = ['background', 'person']
16 | num_classes = len(class_names)
17 | root_folder = '/data/CrowdHuman'
18 | image_folder = '/data/CrowdHuman/images'
19 | train_source = os.path.join('/data/CrowdHuman/annotation_train.odgt')
20 | eval_source = os.path.join('/data/CrowdHuman/annotation_val.odgt')
21 |
22 | class Config:
23 | output_dir = 'outputs'
24 | model_dir = os.path.join(output_dir, 'model_dump')
25 | eval_dir = os.path.join(output_dir, 'eval_dump')
26 | init_weights = '/data/model/resnet50_fbaug.pth'
27 |
28 | # ----------data config---------- #
29 | image_mean = np.array([103.530, 116.280, 123.675])
30 | image_std = np.array([57.375, 57.120, 58.395])
31 | train_image_short_size = 800
32 | train_image_max_size = 1400
33 | eval_resize = True
34 | eval_image_short_size = 800
35 | eval_image_max_size = 1400
36 | seed_dataprovider = 3
37 | train_source = Crowd_human.train_source
38 | eval_source = Crowd_human.eval_source
39 | image_folder = Crowd_human.image_folder
40 | class_names = Crowd_human.class_names
41 | num_classes = Crowd_human.num_classes
42 | class_names2id = dict(list(zip(class_names, list(range(num_classes)))))
43 | gt_boxes_name = 'fbox'
44 |
45 | # ----------train config---------- #
46 | backbone_freeze_at = 2
47 | train_batch_per_gpu = 2
48 | momentum = 0.9
49 | weight_decay = 1e-4
50 | base_lr = 3.125e-4
51 | focal_loss_alpha = 0.25
52 | focal_loss_gamma = 2
53 |
54 | warm_iter = 800
55 | max_epoch = 50
56 | lr_decay = [33, 43]
57 | nr_images_epoch = 15000
58 | log_dump_interval = 20
59 |
60 | # ----------test config---------- #
61 | test_layer_topk = 1000
62 | test_nms = 0.5
63 | test_nms_method = 'normal_nms'
64 | visulize_threshold = 0.3
65 | pred_cls_threshold = 0.01
66 |
67 | # ----------dataset config---------- #
68 | nr_box_dim = 5
69 | max_boxes_of_image = 500
70 |
71 | # --------anchor generator config-------- #
72 | anchor_base_size = 32 # the minimum anchor size on the largest (finest-resolution) feature map.
73 | anchor_base_scale = [2**0, 2**(1/3), 2**(2/3)]
74 | anchor_aspect_ratios = [1, 2, 3]
75 | num_cell_anchors = len(anchor_aspect_ratios) * len(anchor_base_scale)
76 |
77 | # ----------binding&training config---------- #
78 | smooth_l1_beta = 0.1
79 | negative_thresh = 0.4
80 | positive_thresh = 0.5
81 | allow_low_quality = True
82 |
83 | config = Config()
84 |
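Note: the anchor settings above give num_cell_anchors = 3 scales x 3 aspect ratios = 9 anchors per feature-map location. A quick sketch of the per-cell (scale, ratio) combinations; how AnchorGenerator turns each pair into an actual box lives in lib/det_oprs/anchors_generator.py and is not reproduced here:

    from itertools import product

    anchor_base_scale = [2**0, 2**(1/3), 2**(2/3)]
    anchor_aspect_ratios = [1, 2, 3]

    cell_anchors = list(product(anchor_base_scale, anchor_aspect_ratios))
    print(len(cell_anchors))   # 9, matching num_cell_anchors above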
--------------------------------------------------------------------------------
/model/retina_fpn_baseline/network.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from config import config
8 | from backbone.resnet50 import ResNet50
9 | from backbone.fpn import FPN
10 | from det_oprs.anchors_generator import AnchorGenerator
11 | from det_oprs.retina_anchor_target import retina_anchor_target
12 | from det_oprs.bbox_opr import bbox_transform_inv_opr
13 | from det_oprs.loss_opr import focal_loss, smooth_l1_loss
14 | from det_oprs.utils import get_padded_tensor
15 |
16 | class Network(nn.Module):
17 | def __init__(self):
18 | super().__init__()
19 | self.resnet50 = ResNet50(config.backbone_freeze_at, False)
20 | self.FPN = FPN(self.resnet50, 3, 7)
21 | self.R_Head = RetinaNet_Head()
22 | self.R_Anchor = RetinaNet_Anchor()
23 | self.R_Criteria = RetinaNet_Criteria()
24 |
25 | def forward(self, image, im_info, gt_boxes=None):
26 | # pre-processing the data
27 | image = (image - torch.tensor(config.image_mean[None, :, None, None]).type_as(image)) / (
28 | torch.tensor(config.image_std[None, :, None, None]).type_as(image))
29 | image = get_padded_tensor(image, 64)
30 | # do inference
31 | # stride: 128,64,32,16,8, p7->p3
32 | fpn_fms = self.FPN(image)
33 | anchors_list = self.R_Anchor(fpn_fms)
34 | pred_cls_list, pred_reg_list = self.R_Head(fpn_fms)
35 | # release the useless data
36 | if self.training:
37 | loss_dict = self.R_Criteria(
38 | pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info)
39 | return loss_dict
40 | else:
41 | #pred_bbox = union_inference(
42 | # anchors_list, pred_cls_list, pred_reg_list, im_info)
43 | pred_bbox = per_layer_inference(
44 | anchors_list, pred_cls_list, pred_reg_list, im_info)
45 | return pred_bbox.cpu().detach()
46 |
47 | class RetinaNet_Anchor():
48 | def __init__(self):
49 | self.anchors_generator = AnchorGenerator(
50 | config.anchor_base_size,
51 | config.anchor_aspect_ratios,
52 | config.anchor_base_scale)
53 |
54 | def __call__(self, fpn_fms):
55 | # get anchors
56 | all_anchors_list = []
57 | base_stride = 8
58 | off_stride = 2**(len(fpn_fms)-1) # 16
59 | for fm in fpn_fms:
60 | layer_anchors = self.anchors_generator(fm, base_stride, off_stride)
61 | off_stride = off_stride // 2
62 | all_anchors_list.append(layer_anchors)
63 | return all_anchors_list
64 |
65 | class RetinaNet_Criteria(nn.Module):
66 | def __init__(self):
67 | super().__init__()
68 | self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small
69 | self.loss_normalizer_momentum = 0.9
70 |
71 | def __call__(self, pred_cls_list, pred_reg_list, anchors_list, gt_boxes, im_info):
72 | all_anchors = torch.cat(anchors_list, axis=0)
73 | all_pred_cls = torch.cat(pred_cls_list, axis=1).reshape(-1, config.num_classes-1)
74 | all_pred_cls = torch.sigmoid(all_pred_cls)
75 | all_pred_reg = torch.cat(pred_reg_list, axis=1).reshape(-1, 4)
76 | # get ground truth
77 | labels, bbox_target = retina_anchor_target(all_anchors, gt_boxes, im_info, top_k=1)
78 | # regression loss
79 | fg_mask = (labels > 0).flatten()
80 | valid_mask = (labels >= 0).flatten()
81 | loss_reg = smooth_l1_loss(
82 | all_pred_reg[fg_mask],
83 | bbox_target[fg_mask],
84 | config.smooth_l1_beta)
85 | loss_cls = focal_loss(
86 | all_pred_cls[valid_mask],
87 | labels[valid_mask],
88 | config.focal_loss_alpha,
89 | config.focal_loss_gamma)
90 | num_pos_anchors = fg_mask.sum().item()
91 | self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
92 | 1 - self.loss_normalizer_momentum
93 | ) * max(num_pos_anchors, 1)
94 | loss_reg = loss_reg.sum() / self.loss_normalizer
95 | loss_cls = loss_cls.sum() / self.loss_normalizer
96 | loss_dict = {}
97 | loss_dict['retina_focal_loss'] = loss_cls
98 | loss_dict['retina_smooth_l1'] = loss_reg
99 | return loss_dict
100 |
101 | class RetinaNet_Head(nn.Module):
102 | def __init__(self):
103 | super().__init__()
104 | num_convs = 4
105 | in_channels = 256
106 | cls_subnet = []
107 | bbox_subnet = []
108 | for _ in range(num_convs):
109 | cls_subnet.append(
110 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
111 | )
112 | cls_subnet.append(nn.ReLU(inplace=True))
113 | bbox_subnet.append(
114 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
115 | )
116 | bbox_subnet.append(nn.ReLU(inplace=True))
117 | self.cls_subnet = nn.Sequential(*cls_subnet)
118 | self.bbox_subnet = nn.Sequential(*bbox_subnet)
119 | # predictor
120 | self.cls_score = nn.Conv2d(
121 | in_channels, config.num_cell_anchors * (config.num_classes-1),
122 | kernel_size=3, stride=1, padding=1)
123 | self.bbox_pred = nn.Conv2d(
124 | in_channels, config.num_cell_anchors * 4,
125 | kernel_size=3, stride=1, padding=1)
126 |
127 | # Initialization
128 | for modules in [self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred]:
129 | for layer in modules.modules():
130 | if isinstance(layer, nn.Conv2d):
131 | torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
132 | torch.nn.init.constant_(layer.bias, 0)
133 | prior_prob = 0.01
134 | # Use prior in model initialization to improve stability
135 | bias_value = -(math.log((1 - prior_prob) / prior_prob))
136 | torch.nn.init.constant_(self.cls_score.bias, bias_value)
137 |
138 | def forward(self, features):
139 | pred_cls = []
140 | pred_reg = []
141 | for feature in features:
142 | pred_cls.append(self.cls_score(self.cls_subnet(feature)))
143 | pred_reg.append(self.bbox_pred(self.bbox_subnet(feature)))
144 | # reshape the predictions
145 | assert pred_cls[0].dim() == 4
146 | pred_cls_list = [
147 | _.permute(0, 2, 3, 1).reshape(pred_cls[0].shape[0], -1, config.num_classes-1)
148 | for _ in pred_cls]
149 | pred_reg_list = [
150 | _.permute(0, 2, 3, 1).reshape(pred_reg[0].shape[0], -1, 4)
151 | for _ in pred_reg]
152 | return pred_cls_list, pred_reg_list
153 |
154 | def per_layer_inference(anchors_list, pred_cls_list, pred_reg_list, im_info):
155 | keep_anchors = []
156 | keep_cls = []
157 | keep_reg = []
158 | class_num = pred_cls_list[0].shape[-1]
159 | for l_id in range(len(anchors_list)):
160 | anchors = anchors_list[l_id].reshape(-1, 4)
161 | pred_cls = pred_cls_list[l_id][0].reshape(-1, class_num)
162 | pred_reg = pred_reg_list[l_id][0].reshape(-1, 4)
163 | if len(anchors) > config.test_layer_topk:
164 | ruler = pred_cls.max(axis=1)[0]
165 | _, inds = ruler.topk(config.test_layer_topk, dim=0)
166 | inds = inds.flatten()
167 | keep_anchors.append(anchors[inds])
168 | keep_cls.append(torch.sigmoid(pred_cls[inds]))
169 | keep_reg.append(pred_reg[inds])
170 | else:
171 | keep_anchors.append(anchors)
172 | keep_cls.append(torch.sigmoid(pred_cls))
173 | keep_reg.append(pred_reg)
174 | keep_anchors = torch.cat(keep_anchors, axis = 0)
175 | keep_cls = torch.cat(keep_cls, axis = 0)
176 | keep_reg = torch.cat(keep_reg, axis = 0)
177 | # multiclass
178 | tag = torch.arange(class_num).type_as(keep_cls)+1
179 | tag = tag.repeat(keep_cls.shape[0], 1).reshape(-1,1)
180 | pred_scores = keep_cls.reshape(-1, 1)
181 | pred_bbox = restore_bbox(keep_anchors, keep_reg, False)
182 | pred_bbox = pred_bbox.repeat(1, class_num).reshape(-1, 4)
183 | pred_bbox = torch.cat([pred_bbox, pred_scores, tag], axis=1)
184 | return pred_bbox
185 |
186 | def union_inference(anchors_list, pred_cls_list, pred_reg_list, im_info):
187 | anchors = torch.cat(anchors_list, axis = 0)
188 | pred_cls = torch.cat(pred_cls_list, axis = 1)[0]
189 | pred_cls = torch.sigmoid(pred_cls)
190 | pred_reg = torch.cat(pred_reg_list, axis = 1)[0]
191 | class_num = pred_cls_list[0].shape[-1]
192 | # multiclass
193 | tag = torch.arange(class_num).type_as(pred_cls)+1
194 | tag = tag.repeat(pred_cls.shape[0], 1).reshape(-1,1)
195 | pred_scores = pred_cls.reshape(-1, 1)
196 | pred_bbox = restore_bbox(anchors, pred_reg, False)
197 | pred_bbox = pred_bbox.repeat(1, class_num).reshape(-1, 4)
198 | pred_bbox = torch.cat([pred_bbox, pred_scores, tag], axis=1)
199 | return pred_bbox
200 |
201 | def restore_bbox(rois, deltas, unnormalize=True):
202 | if unnormalize:
203 | std_opr = torch.tensor(config.bbox_normalize_stds[None, :]).type_as(deltas)
204 | mean_opr = torch.tensor(config.bbox_normalize_means[None, :]).type_as(deltas)
205 | deltas = deltas * std_opr
206 | deltas = deltas + mean_opr
207 | pred_bbox = bbox_transform_inv_opr(rois, deltas)
208 | return pred_bbox
209 |
210 |
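Note on the classification bias initialization in RetinaNet_Head above: bias = -log((1 - prior_prob) / prior_prob) makes the freshly initialized head output a foreground probability of prior_prob after the sigmoid, which keeps the focal loss from being swamped by the many easy negatives early in training. A quick check of that identity:

    import math
    import torch

    prior_prob = 0.01
    bias_value = -math.log((1 - prior_prob) / prior_prob)
    # sigmoid(log(p / (1 - p))) == p, so the initial foreground score is ~0.01
    print(torch.sigmoid(torch.tensor(bias_value)))     # ~0.0100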
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | tqdm
4 |
--------------------------------------------------------------------------------
/tools/eval_json.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | sys.path.insert(0, '../lib')
6 | from utils import misc_utils
7 | from evaluate import compute_JI, compute_APMR
8 |
9 | def eval_all(args):
10 | # ground truth file
11 | gt_path = '/data/CrowdHuman/annotation_val.odgt'
12 | assert os.path.exists(gt_path), "Wrong ground truth path!"
13 | misc_utils.ensure_dir('outputs')
14 | # output file
15 | eval_path = os.path.join('outputs', 'result_eval.md')
16 | eval_fid = open(eval_path,'a')
17 | eval_fid.write((args.json_file+'\n'))
18 | # eval JI
19 | res_line, JI = compute_JI.evaluation_all(args.json_file, 'box')
20 | for line in res_line:
21 | eval_fid.write(line+'\n')
22 | # eval AP, MR
23 | AP, MR = compute_APMR.compute_APMR(args.json_file, gt_path, 'box')
24 | line = 'AP:{:.4f}, MR:{:.4f}, JI:{:.4f}.'.format(AP, MR, JI)
25 | print(line)
26 | eval_fid.write(line+'\n\n')
27 | eval_fid.close()
28 |
29 | def run_eval():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--json_file', '-f', default=None, required=True, type=str)
32 | args = parser.parse_args()
33 | eval_all(args)
34 |
35 | if __name__ == '__main__':
36 | run_eval()
37 |
--------------------------------------------------------------------------------
/tools/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | import cv2
6 | import torch
7 | import numpy as np
8 |
9 | sys.path.insert(0, '../lib')
10 | from utils import misc_utils, visual_utils, nms_utils
11 |
12 | def inference(args, config, network):
13 | # model_path
14 | misc_utils.ensure_dir('outputs')
15 | saveDir = os.path.join('../model', args.model_dir, config.model_dir)
16 | model_file = os.path.join(saveDir,
17 | 'dump-{}.pth'.format(args.resume_weights))
18 | assert os.path.exists(model_file)
19 | # build network
20 | net = network()
21 | net.eval()
22 | check_point = torch.load(model_file, map_location=torch.device('cpu'))
23 | net.load_state_dict(check_point['state_dict'])
24 | # get data
25 | image, resized_img, im_info = get_data(
26 | args.img_path, config.eval_image_short_size, config.eval_image_max_size)
27 | pred_boxes = net(resized_img, im_info).numpy()
28 | pred_boxes = post_process(pred_boxes, config, im_info[0, 2])
29 | pred_tags = pred_boxes[:, 5].astype(np.int32).flatten()
30 | pred_tags_name = np.array(config.class_names)[pred_tags]
31 | # inplace draw
32 | image = visual_utils.draw_boxes(
33 | image,
34 | pred_boxes[:, :4],
35 | scores=pred_boxes[:, 4],
36 | tags=pred_tags_name,
37 | line_thick=1, line_color='white')
38 | name = args.img_path.split('/')[-1].split('.')[-2]
39 | fpath = 'outputs/{}.png'.format(name)
40 | cv2.imwrite(fpath, image)
41 |
42 | def post_process(pred_boxes, config, scale):
43 | if config.test_nms_method == 'set_nms':
44 | assert pred_boxes.shape[-1] > 6, "set_nms requires an EMD network (two predictions per anchor)."
45 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!"
46 | top_k = pred_boxes.shape[-1] // 6
47 | n = pred_boxes.shape[0]
48 | pred_boxes = pred_boxes.reshape(-1, 6)
49 | idents = np.tile(np.arange(n)[:,None], (1, top_k)).reshape(-1, 1)
50 | pred_boxes = np.hstack((pred_boxes, idents))
51 | keep = pred_boxes[:, 4] > config.pred_cls_threshold
52 | pred_boxes = pred_boxes[keep]
53 | keep = nms_utils.set_cpu_nms(pred_boxes, 0.5)
54 | pred_boxes = pred_boxes[keep]
55 | elif config.test_nms_method == 'normal_nms':
56 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!"
57 | pred_boxes = pred_boxes.reshape(-1, 6)
58 | keep = pred_boxes[:, 4] > config.pred_cls_threshold
59 | pred_boxes = pred_boxes[keep]
60 | keep = nms_utils.cpu_nms(pred_boxes, config.test_nms)
61 | pred_boxes = pred_boxes[keep]
62 | elif config.test_nms_method == 'none':
63 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!"
64 | pred_boxes = pred_boxes.reshape(-1, 6)
65 | keep = pred_boxes[:, 4] > config.pred_cls_threshold
66 | pred_boxes = pred_boxes[keep]
67 | #if pred_boxes.shape[0] > config.detection_per_image and \
68 | # config.test_nms_method != 'none':
69 | # order = np.argsort(-pred_boxes[:, 4])
70 | # order = order[:config.detection_per_image]
71 | # pred_boxes = pred_boxes[order]
72 | # recover the original image scale
73 | pred_boxes[:, :4] /= scale
74 | keep = pred_boxes[:, 4] > config.visulize_threshold
75 | pred_boxes = pred_boxes[keep]
76 | return pred_boxes
77 |
78 | def get_data(img_path, short_size, max_size):
79 | image = cv2.imread(img_path, cv2.IMREAD_COLOR)
80 | resized_img, scale = resize_img(
81 | image, short_size, max_size)
82 |
83 | original_height, original_width = image.shape[0:2]
84 | height, width = resized_img.shape[0:2]
85 | resized_img = resized_img.transpose(2, 0, 1)
86 | im_info = np.array([height, width, scale, original_height, original_width, 0])
87 | return image, torch.tensor([resized_img]).float(), torch.tensor([im_info])
88 |
89 | def resize_img(image, short_size, max_size):
90 | height = image.shape[0]
91 | width = image.shape[1]
92 | im_size_min = np.min([height, width])
93 | im_size_max = np.max([height, width])
94 | scale = (short_size + 0.0) / im_size_min
95 | if scale * im_size_max > max_size:
96 | scale = (max_size + 0.0) / im_size_max
97 | t_height, t_width = int(round(height * scale)), int(
98 | round(width * scale))
99 | resized_image = cv2.resize(
100 | image, (t_width, t_height), interpolation=cv2.INTER_LINEAR)
101 | return resized_image, scale
102 |
103 | def run_inference():
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument('--model_dir', '-md', default=None, required=True, type=str)
106 | parser.add_argument('--resume_weights', '-r', default=None, required=True, type=str)
107 | parser.add_argument('--img_path', '-i', default=None, required=True, type=str)
108 | args = parser.parse_args()
109 | # import libs
110 | model_root_dir = os.path.join('../model/', args.model_dir)
111 | sys.path.insert(0, model_root_dir)
112 | from config import config
113 | from network import Network
114 | inference(args, config, Network)
115 |
116 | if __name__ == '__main__':
117 | run_inference()
118 |
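Note on resize_img above: the short side is scaled to short_size (800 at eval time) unless that would push the long side past max_size (1400), in which case the long side is capped instead. A worked example with a hypothetical 1080x1920 input:

    # hypothetical 1080 x 1920 image, short_size=800, max_size=1400
    height, width = 1080, 1920
    scale = 800 / min(height, width)          # ~0.741, long side would reach ~1422
    if scale * max(height, width) > 1400:
        scale = 1400 / max(height, width)     # ~0.729, long side capped at 1400
    print(round(height * scale), round(width * scale))   # roughly 788 x 1400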
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import argparse
5 |
6 | import numpy as np
7 | from tqdm import tqdm
8 | import torch
9 | from torch.multiprocessing import Queue, Process
10 |
11 | sys.path.insert(0, '../lib')
12 | sys.path.insert(0, '../model')
13 | from data.CrowdHuman import CrowdHuman
14 | from utils import misc_utils, nms_utils
15 | from evaluate import compute_JI, compute_APMR
16 |
17 | def eval_all(args, config, network):
18 | # model_path
19 | saveDir = os.path.join('../model', args.model_dir, config.model_dir)
20 | evalDir = os.path.join('../model', args.model_dir, config.eval_dir)
21 | misc_utils.ensure_dir(evalDir)
22 | model_file = os.path.join(saveDir,
23 | 'dump-{}.pth'.format(args.resume_weights))
24 | assert os.path.exists(model_file)
25 | # get devices
26 | str_devices = args.devices
27 | devices = misc_utils.device_parser(str_devices)
28 | # load data
29 | crowdhuman = CrowdHuman(config, if_train=False)
30 | #crowdhuman.records = crowdhuman.records[:10]
31 | # multiprocessing
32 | num_devs = len(devices)
33 | len_dataset = len(crowdhuman)
34 | num_image = math.ceil(len_dataset / num_devs)
35 | result_queue = Queue(500)
36 | procs = []
37 | all_results = []
38 | for i in range(num_devs):
39 | start = i * num_image
40 | end = min(start + num_image, len_dataset)
41 | proc = Process(target=inference, args=(
42 | config, network, model_file, devices[i], crowdhuman, start, end, result_queue))
43 | proc.start()
44 | procs.append(proc)
45 | pbar = tqdm(total=len_dataset, ncols=50)
46 | for i in range(len_dataset):
47 | t = result_queue.get()
48 | all_results.append(t)
49 | pbar.update(1)
50 | pbar.close()
51 | for p in procs:
52 | p.join()
53 | fpath = os.path.join(evalDir, 'dump-{}.json'.format(args.resume_weights))
54 | misc_utils.save_json_lines(all_results, fpath)
55 | # evaluation
56 | eval_path = os.path.join(evalDir, 'eval-{}.json'.format(args.resume_weights))
57 | eval_fid = open(eval_path,'w')
58 | res_line, JI = compute_JI.evaluation_all(fpath, 'box')
59 | for line in res_line:
60 | eval_fid.write(line+'\n')
61 | AP, MR = compute_APMR.compute_APMR(fpath, config.eval_source, 'box')
62 | line = 'AP:{:.4f}, MR:{:.4f}, JI:{:.4f}.'.format(AP, MR, JI)
63 | print(line)
64 | eval_fid.write(line+'\n')
65 | eval_fid.close()
66 |
67 | def inference(config, network, model_file, device, dataset, start, end, result_queue):
68 | torch.set_default_tensor_type('torch.FloatTensor')
69 | torch.multiprocessing.set_sharing_strategy('file_system')
70 | # init model
71 | net = network()
72 | net.cuda(device)
73 | net = net.eval()
74 | check_point = torch.load(model_file)
75 | net.load_state_dict(check_point['state_dict'])
76 | # init data
77 | dataset.records = dataset.records[start:end]
78 | data_iter = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
79 | # inference
80 | for (image, gt_boxes, im_info, ID) in data_iter:
81 | pred_boxes = net(image.cuda(device), im_info.cuda(device))
82 | scale = im_info[0, 2]
83 | if config.test_nms_method == 'set_nms':
84 | assert pred_boxes.shape[-1] > 6, "set_nms requires an EMD network (two predictions per anchor)."
85 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!"
86 | top_k = pred_boxes.shape[-1] // 6
87 | n = pred_boxes.shape[0]
88 | pred_boxes = pred_boxes.reshape(-1, 6)
89 | idents = np.tile(np.arange(n)[:,None], (1, top_k)).reshape(-1, 1)
90 | pred_boxes = np.hstack((pred_boxes, idents))
91 | keep = pred_boxes[:, 4] > config.pred_cls_threshold
92 | pred_boxes = pred_boxes[keep]
93 | keep = nms_utils.set_cpu_nms(pred_boxes, 0.5)
94 | pred_boxes = pred_boxes[keep]
95 | elif config.test_nms_method == 'normal_nms':
96 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!"
97 | pred_boxes = pred_boxes.reshape(-1, 6)
98 | keep = pred_boxes[:, 4] > config.pred_cls_threshold
99 | pred_boxes = pred_boxes[keep]
100 | keep = nms_utils.cpu_nms(pred_boxes, config.test_nms)
101 | pred_boxes = pred_boxes[keep]
102 | elif config.test_nms_method == 'none':
103 | assert pred_boxes.shape[-1] % 6 == 0, "Prediction dim Error!"
104 | pred_boxes = pred_boxes.reshape(-1, 6)
105 | keep = pred_boxes[:, 4] > config.pred_cls_threshold
106 | pred_boxes = pred_boxes[keep]
107 | else:
108 | raise ValueError('Unknown NMS method.')
109 | #if pred_boxes.shape[0] > config.detection_per_image and \
110 | # config.test_nms_method != 'none':
111 | # order = np.argsort(-pred_boxes[:, 4])
112 | # order = order[:config.detection_per_image]
113 | # pred_boxes = pred_boxes[order]
114 | # recover the original image scale
115 | pred_boxes[:, :4] /= scale
116 | pred_boxes[:, 2:4] -= pred_boxes[:, :2]
117 | gt_boxes = gt_boxes[0].numpy()
118 | gt_boxes[:, 2:4] -= gt_boxes[:, :2]
119 | result_dict = dict(ID=ID[0], height=int(im_info[0, -3]), width=int(im_info[0, -2]),
120 | dtboxes=boxes_dump(pred_boxes), gtboxes=boxes_dump(gt_boxes))
121 | result_queue.put_nowait(result_dict)
122 |
123 | def boxes_dump(boxes):
124 | if boxes.shape[-1] == 7:
125 | result = [{'box':[round(i, 1) for i in box[:4]],
126 | 'score':round(float(box[4]), 5),
127 | 'tag':int(box[5]),
128 | 'proposal_num':int(box[6])} for box in boxes]
129 | elif boxes.shape[-1] == 6:
130 | result = [{'box':[round(i, 1) for i in box[:4].tolist()],
131 | 'score':round(float(box[4]), 5),
132 | 'tag':int(box[5])} for box in boxes]
133 | elif boxes.shape[-1] == 5:
134 | result = [{'box':[round(i, 1) for i in box[:4]],
135 | 'tag':int(box[4])} for box in boxes]
136 | else:
137 | raise ValueError('Unknown box dim.')
138 | return result
139 |
140 | def run_test():
141 | parser = argparse.ArgumentParser()
142 | parser.add_argument('--model_dir', '-md', default=None, required=True, type=str)
143 | parser.add_argument('--resume_weights', '-r', default=None, required=True, type=str)
144 | parser.add_argument('--devices', '-d', default='0', type=str)
145 | os.environ['NCCL_IB_DISABLE'] = '1'
146 | args = parser.parse_args()
147 | # import libs
148 | model_root_dir = os.path.join('../model/', args.model_dir)
149 | sys.path.insert(0, model_root_dir)
150 | from config import config
151 | from network import Network
152 | eval_all(args, config, Network)
153 |
154 | if __name__ == '__main__':
155 | run_test()
156 |
157 |
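Note on the set_nms branch in tools/test.py above: the identity column appended to pred_boxes records which anchor/proposal each of the two EMD predictions came from. The actual implementation lives in lib/utils/nms_utils.py (not shown here); a plausible reading, sketched below, is a greedy NMS in which boxes sharing an identity never suppress each other, so both predictions of a heavily overlapping pair can survive:

    import numpy as np

    def _iou(a, b):
        # intersection-over-union of two xyxy boxes
        x1, y1 = max(a[0], b[0]), max(a[1], b[1])
        x2, y2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
        area_a = (a[2] - a[0]) * (a[3] - a[1])
        area_b = (b[2] - b[0]) * (b[3] - b[1])
        return inter / max(area_a + area_b - inter, 1e-9)

    def set_nms_sketch(dets, thresh=0.5):
        # dets: (N, 7) ndarray, rows of [x1, y1, x2, y2, score, tag, ident] as built above
        order = np.argsort(-dets[:, 4])
        suppressed = np.zeros(len(dets), dtype=bool)
        keep = []
        for rank, i in enumerate(order):
            if suppressed[i]:
                continue
            keep.append(i)
            for j in order[rank + 1:]:
                if suppressed[j] or dets[i, 6] == dets[j, 6]:
                    continue   # boxes from the same proposal never suppress each other
                if _iou(dets[i, :4], dets[j, :4]) > thresh:
                    suppressed[j] = True
        return keep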
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import torch
5 |
6 | sys.path.insert(0, '../lib')
7 | sys.path.insert(0, '../model')
8 | from data.CrowdHuman import CrowdHuman
9 | from utils import misc_utils, SGD_bias
10 |
11 | class Train_config:
12 | # size
13 | world_size = 0
14 | mini_batch_size = 0
15 | iter_per_epoch = 0
16 | total_epoch = 0
17 | # learning
18 | warm_iter = 0
19 | learning_rate = 0
20 | momentum = 0
21 | weight_decay = 0
22 | lr_decay = []
23 | # model
24 | log_dump_interval = 0
25 | resume_weights = None
26 | init_weights = None
27 | model_dir = ''
28 | log_path = ''
29 |
30 | def do_train_epoch(net, data_iter, optimizer, rank, epoch, train_config):
31 | if rank == 0:
32 | fid_log = open(train_config.log_path,'a')
33 | if epoch >= train_config.lr_decay[0]:
34 | for param_group in optimizer.param_groups:
35 | param_group['lr'] = train_config.learning_rate / 10
36 | if epoch >= train_config.lr_decay[1]:
37 | for param_group in optimizer.param_groups:
38 | param_group['lr'] = train_config.learning_rate / 100
39 | for (images, gt_boxes, im_info), step in zip(data_iter, range(0, train_config.iter_per_epoch)):
40 | if images is None:
41 | continue
42 | # warm up
43 | if epoch == 1 and step < train_config.warm_iter:
44 | alpha = step / train_config.warm_iter
45 | lr_new = 0.33 * train_config.learning_rate + \
46 | 0.67 * alpha * train_config.learning_rate
47 | for group in optimizer.param_groups:
48 | group['lr'] = lr_new
49 | elif epoch == 1 and step == train_config.warm_iter:
50 | for group in optimizer.param_groups:
51 | group['lr'] = train_config.learning_rate
52 | # get training data
53 | optimizer.zero_grad()
54 | # forward pass
55 | outputs = net(images.cuda(rank), im_info.cuda(rank), gt_boxes.cuda(rank))
56 | # collect the loss
57 | total_loss = sum([outputs[key].mean() for key in outputs.keys()])
58 | assert torch.isfinite(total_loss).all(), outputs
59 | total_loss.backward()
60 | optimizer.step()
61 | # statistics
62 | if rank == 0:
63 | if step % train_config.log_dump_interval == 0:
64 | statistic_total_loss = total_loss.item()
65 | line = 'Epoch:{}, iter:{}, lr:{:.5f}, loss is {:.4f}.'.format(
66 | epoch, step, optimizer.param_groups[0]['lr'], statistic_total_loss)
67 | print(line)
68 | print(outputs)
69 | fid_log.write(line+'\n')
70 | fid_log.write(str(outputs)+'\n')
71 | fid_log.flush()
72 | if rank == 0:
73 | fid_log.close()
74 |
75 | def train_worker(rank, train_config, network, config):
76 | # set the parallel
77 | torch.distributed.init_process_group(backend='nccl',
78 | init_method='env://', world_size=train_config.world_size, rank=rank)
79 | # initialize model
80 | net = network()
81 | # load pretrain model
82 | backbone_dict = torch.load(train_config.init_weights)
83 | del backbone_dict['state_dict']['fc.weight']
84 | del backbone_dict['state_dict']['fc.bias']
85 | net.resnet50.load_state_dict(backbone_dict['state_dict'])
86 | net.cuda(rank)
87 | begin_epoch = 1
88 | # build optimizer
89 | #optimizer = SGD_bias.SGD(net.parameters(),
90 | optimizer = torch.optim.SGD(net.parameters(),
91 | lr=train_config.learning_rate, momentum=train_config.momentum,
92 | weight_decay=train_config.weight_decay)
93 | if train_config.resume_weights:
94 | model_file = os.path.join(train_config.model_dir,
95 | 'dump-{}.pth'.format(train_config.resume_weights))
96 | check_point = torch.load(model_file, map_location=torch.device('cpu'))
97 | net.load_state_dict(check_point['state_dict'])
98 | begin_epoch = train_config.resume_weights + 1
99 | # using distributed data parallel
100 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank], broadcast_buffers=False)
101 | # build data provider
102 | crowdhuman = CrowdHuman(config, if_train=True)
103 | data_iter = torch.utils.data.DataLoader(dataset=crowdhuman,
104 | batch_size=train_config.mini_batch_size,
105 | num_workers=2,
106 | collate_fn=crowdhuman.merge_batch,
107 | shuffle=True)
108 | for epoch_id in range(begin_epoch, train_config.total_epoch+1):
109 | do_train_epoch(net, data_iter, optimizer, rank, epoch_id, train_config)
110 | if rank == 0:
111 | #save the model
112 | fpath = os.path.join(train_config.model_dir,
113 | 'dump-{}.pth'.format(epoch_id))
114 | model = dict(epoch = epoch_id,
115 | state_dict = net.module.state_dict(),
116 | optimizer = optimizer.state_dict())
117 | torch.save(model,fpath)
118 |
119 | def multi_train(params, config, network):
120 | # check gpus
121 | if not torch.cuda.is_available():
122 | print('No GPU exists!')
123 | return
124 | else:
125 | num_gpus = torch.cuda.device_count()
126 | torch.set_default_tensor_type('torch.FloatTensor')
127 | # setting training config
128 | train_config = Train_config()
129 | train_config.world_size = num_gpus
130 | train_config.total_epoch = config.max_epoch
131 | train_config.iter_per_epoch = \
132 | config.nr_images_epoch // (num_gpus * config.train_batch_per_gpu)
133 | train_config.mini_batch_size = config.train_batch_per_gpu
134 | train_config.warm_iter = config.warm_iter
135 | train_config.learning_rate = \
136 | config.base_lr * config.train_batch_per_gpu * num_gpus
137 | train_config.momentum = config.momentum
138 | train_config.weight_decay = config.weight_decay
139 | train_config.lr_decay = config.lr_decay
140 | train_config.model_dir = os.path.join('../model/', params.model_dir, config.model_dir)
141 | line = 'network.lr.{}.train.{}'.format(
142 | train_config.learning_rate, train_config.total_epoch)
143 | train_config.log_path = os.path.join('../model/', params.model_dir, config.output_dir, line+'.log')
144 | train_config.resume_weights = params.resume_weights
145 | train_config.init_weights = config.init_weights
146 | train_config.log_dump_interval = config.log_dump_interval
147 | misc_utils.ensure_dir(train_config.model_dir)
148 | # print the training config
149 | line = 'Num of GPUs:{}, learning rate:{:.5f}, mini batch size:{}, \
150 | \ntrain_epoch:{}, iter_per_epoch:{}, decay_epoch:{}'.format(
151 | num_gpus, train_config.learning_rate, train_config.mini_batch_size,
152 | train_config.total_epoch, train_config.iter_per_epoch, train_config.lr_decay)
153 | print(line)
154 | print("Init multi-processing training...")
155 | torch.multiprocessing.spawn(train_worker, nprocs=num_gpus, args=(train_config, network, config))
156 |
157 | def run_train():
158 | parser = argparse.ArgumentParser()
159 | parser.add_argument('--model_dir', '-md', default=None,required=True,type=str)
160 | parser.add_argument('--resume_weights', '-r', default=None,type=int)
161 | os.environ['MASTER_ADDR'] = '127.0.0.1'
162 | os.environ['MASTER_PORT'] = '8888'
163 | os.environ['NCCL_IB_DISABLE'] = '1'
164 | #os.environ['NCCL_DEBUG'] = 'INFO'
165 | args = parser.parse_args()
166 | # import libs
167 | model_root_dir = os.path.join('../model/', args.model_dir)
168 | sys.path.insert(0, model_root_dir)
169 | from config import config
170 | from network import Network
171 | multi_train(args, config, Network)
172 |
173 | if __name__ == '__main__':
174 | run_train()
175 |
176 |
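Note on the scaling in multi_train above: the learning rate grows linearly with the total batch size (base_lr * train_batch_per_gpu * num_gpus) and the per-epoch iteration count is nr_images_epoch divided by that batch size. A worked example with the retina_fpn_baseline config values and a hypothetical 8-GPU machine:

    # hypothetical 8-GPU run with the retina_fpn_baseline config above
    num_gpus = 8
    base_lr = 3.125e-4
    train_batch_per_gpu = 2
    nr_images_epoch = 15000

    learning_rate = base_lr * train_batch_per_gpu * num_gpus              # 5.0e-3
    iter_per_epoch = nr_images_epoch // (num_gpus * train_batch_per_gpu)  # 937
    print(learning_rate, iter_per_epoch)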
--------------------------------------------------------------------------------
/tools/visulize_json.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import argparse
5 |
6 | sys.path.insert(0, '../lib')
7 | from utils import misc_utils, visual_utils
8 |
9 | img_root = '/data/CrowdHuman/images/'
10 | def eval_all(args):
11 | # json file
12 | assert os.path.exists(args.json_file), "Wrong json path!"
13 | misc_utils.ensure_dir('outputs')
14 | records = misc_utils.load_json_lines(args.json_file)[:args.number]
15 | for record in records:
16 | dtboxes = misc_utils.load_bboxes(
17 | record, key_name='dtboxes', key_box='box', key_score='score', key_tag='tag')
18 | gtboxes = misc_utils.load_bboxes(record, 'gtboxes', 'box')
19 | dtboxes = misc_utils.xywh_to_xyxy(dtboxes)
20 | gtboxes = misc_utils.xywh_to_xyxy(gtboxes)
21 | keep = dtboxes[:, -2] > args.visual_thresh
22 | dtboxes = dtboxes[keep]
23 | len_dt = len(dtboxes)
24 | len_gt = len(gtboxes)
25 | line = "{}: dt:{}, gt:{}.".format(record['ID'], len_dt, len_gt)
26 | print(line)
27 | img_path = img_root + record['ID'] + '.png'
28 | img = misc_utils.load_img(img_path)
29 | visual_utils.draw_boxes(img, dtboxes, line_thick=1, line_color='blue')
30 | visual_utils.draw_boxes(img, gtboxes, line_thick=1, line_color='white')
31 | fpath = 'outputs/{}.png'.format(record['ID'])
32 | cv2.imwrite(fpath, img)
33 |
34 |
35 | def run_eval():
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('--json_file', '-f', default=None, required=True, type=str)
38 | parser.add_argument('--number', '-n', default=3, type=int)
39 | parser.add_argument('--visual_thresh', '-v', default=0.3, type=float)
40 | args = parser.parse_args()
41 | eval_all(args)
42 |
43 | if __name__ == '__main__':
44 | run_eval()
45 |
--------------------------------------------------------------------------------