├── .gitignore ├── LICENSE ├── README.md ├── backbone ├── base.py ├── resnet101.py ├── resnet18.py └── resnet50.py ├── bbox.py ├── config ├── config.py ├── eval_config.py └── train_config.py ├── data └── .gitignore ├── dataset ├── base.py ├── coco2017.py ├── coco2017_animal.py ├── coco2017_car.py ├── coco2017_person.py ├── voc2007.py └── voc2007_cat_dog.py ├── eval.py ├── evaluator.py ├── images ├── feature-pyramid.png ├── inference-result.jpg ├── inference-sample.jpg ├── nms_cuda.png ├── rpn_find_labels_1.png ├── rpn_find_labels_2.png └── test_nms.png ├── infer.py ├── logger.py ├── model.py ├── nms ├── build.py ├── nms.py ├── src │ ├── nms.c │ ├── nms.h │ ├── nms_cuda.cu │ └── nms_cuda.h └── test │ ├── nms-large-input.npy │ ├── nms-large-output.npy │ └── test_nms.py ├── outputs └── .gitignore ├── realtime.py ├── roi ├── align │ ├── build.py │ ├── crop_and_resize.py │ ├── roi_align.py │ └── src │ │ ├── crop_and_resize.c │ │ ├── crop_and_resize.h │ │ ├── crop_and_resize_gpu.c │ │ ├── crop_and_resize_gpu.h │ │ └── cuda │ │ ├── crop_and_resize_kernel.cu │ │ └── crop_and_resize_kernel.h └── wrapper.py ├── rpn └── region_proposal_network.py ├── train.py └── voc_eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Potter Hsu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # easy-fpn.pytorch 2 | 3 | An easy implementation of [FPN](https://arxiv.org/pdf/1612.03144.pdf) in PyTorch based on our [easy-faster-rcnn.pytorch](https://github.com/potterhsu/easy-faster-rcnn.pytorch) project. 

## Demo

![](images/inference-result.jpg?raw=true)


## Features

* Supports PyTorch 0.4.1
* Supports `PASCAL VOC 2007` and `MS COCO 2017` datasets
* Supports `ResNet-18`, `ResNet-50` and `ResNet-101` backbones (from the official PyTorch models)
* Supports `ROI Pooling` and `ROI Align` pooling modes
* Matches the performance reported by the original paper
* Efficient, with maintainable, readable and clean code


## Benchmarking

* PASCAL VOC 2007

    * Train: 2007 trainval (5011 images)
    * Eval: 2007 test (4952 images)

| Implementation | Backbone | GPU | Training Speed (FPS) | Inference Speed (FPS) | mAP | image_min_side | image_max_side | anchor_ratios | anchor_scales | pooling_mode | rpn_pre_nms_top_n (train) | rpn_post_nms_top_n (train) | rpn_pre_nms_top_n (eval) | rpn_post_nms_top_n (eval) | learning_rate | momentum | weight_decay | step_lr_size | step_lr_gamma | num_steps_to_finish |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Ours | ResNet-101 | GTX 1080 Ti | ~ 3.3 | ~ 9.5 | 0.7627 \| 0.7604 (60k \| 70k) | 800 | 1333 | [(1, 2), (1, 1), (2, 1)] | [1] | align | 12000 | 2000 | 6000 | 1000 | 0.001 | 0.9 | 0.0001 | 50000 | 0.1 | 70000 |

> Scroll to right for more configurations
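
The VOC row above uses the stock `-s`/`-b` options and, apart from the total number of training steps, the default `TrainConfig`/`EvalConfig` values (see `config/`). Assuming the PASCAL VOC 2007 data has been prepared as described in `Setup` below, a minimal sketch of the commands to reproduce it is:

```
$ python train.py -s=voc2007 -b=resnet101
$ python eval.py -s=voc2007 -b=resnet101 /path/to/checkpoint.pth
```

See `Usage` below for the custom-configuration variants of these commands.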
* MS COCO 2017

    * Train: 2017 Train drops images without any objects (117266 images)
    * Eval: 2017 Val drops images without any objects (4952 images)

| Implementation | Backbone | GPU | Training Speed (FPS) | Inference Speed (FPS) | AP@[.5:.95] | image_min_side | image_max_side | anchor_ratios | anchor_scales | pooling_mode | rpn_pre_nms_top_n (train) | rpn_post_nms_top_n (train) | rpn_pre_nms_top_n (eval) | rpn_post_nms_top_n (eval) | learning_rate | momentum | weight_decay | step_lr_size | step_lr_gamma | num_steps_to_finish |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Original Paper | ResNet-101 | - | - | - | 0.362 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
| Ours | ResNet-101 | GTX 1080 Ti | ~ 3.3 | ~ 9.5 | 0.363 | 800 | 1333 | [(1, 2), (1, 1), (2, 1)] | [1] | align | 12000 | 2000 | 6000 | 1000 | 0.001 | 0.9 | 0.0001 | 900000 | 0.1 | 1640000 |
163 | 164 | > Scroll to right for more configurations 165 | 166 | * PASCAL VOC 2007 Cat Dog 167 | 168 | * Train: 2007 trainval drops categories other than cat and dog (750 images) 169 | * Eval: 2007 test drops categories other than cat and dog (728 images) 170 | 171 | * MS COCO 2017 Person 172 | 173 | * Train: 2017 Train drops categories other than person (64115 images) 174 | * Eval: 2017 Val drops categories other than person (2693 images) 175 | 176 | * MS COCO 2017 Car 177 | 178 | * Train: 2017 Train drops categories other than car (12251 images) 179 | * Eval: 2017 Val drops categories other than car (535 images) 180 | 181 | * MS COCO 2017 Animal 182 | 183 | * Train: 2017 Train drops categories other than bird, cat, dog, horse, sheep, cow, elephant, bear, zebra and giraffe (23989 images) 184 | * Eval: 2017 Val drops categories other than bird, cat, dog, horse, sheep, cow, elephant, bear, zebra and giraffe (1016 images) 185 | 186 | 187 | ## Requirements 188 | 189 | * Python 3.6 190 | * torch 0.4.1 191 | * torchvision 0.2.1 192 | * tqdm 193 | 194 | ``` 195 | $ pip install tqdm 196 | ``` 197 | 198 | * tensorboardX 199 | 200 | ``` 201 | $ pip install tensorboardX 202 | ``` 203 | 204 | 205 | ## Setup 206 | 207 | 1. Prepare data 208 | 1. For `PASCAL VOC 2007` 209 | 210 | 1. Download dataset 211 | 212 | - [Training / Validation](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) (5011 images) 213 | - [Test](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar) (4952 images) 214 | 215 | 1. Extract to data folder, now your folder structure should be like: 216 | 217 | ``` 218 | easy-faster-rcnn.pytorch 219 | - data 220 | - VOCdevkit 221 | - VOC2007 222 | - Annotations 223 | - 000001.xml 224 | - 000002.xml 225 | ... 226 | - ImageSets 227 | - Main 228 | ... 229 | test.txt 230 | ... 231 | trainval.txt 232 | ... 233 | - JPEGImages 234 | - 000001.jpg 235 | - 000002.jpg 236 | ... 237 | - ... 238 | ``` 239 | 240 | 1. For `MS COCO 2017` 241 | 242 | 1. Download dataset 243 | 244 | - [2017 Train images [18GB]](http://images.cocodataset.org/zips/train2017.zip) (118287 images) 245 | > COCO 2017 Train = COCO 2015 Train + COCO 2015 Val - COCO 2015 Val Sample 5k 246 | - [2017 Val images [1GB]](http://images.cocodataset.org/zips/val2017.zip) (5000 images) 247 | > COCO 2017 Val = COCO 2015 Val Sample 5k (formerly known as `minival`) 248 | - [2017 Train/Val annotations [241MB]](http://images.cocodataset.org/annotations/annotations_trainval2017.zip) 249 | 250 | 1. Extract to data folder, now your folder structure should be like: 251 | 252 | ``` 253 | easy-faster-rcnn.pytorch 254 | - data 255 | - COCO 256 | - annotations 257 | - instances_train2017.json 258 | - instances_val2017.json 259 | ... 260 | - train2017 261 | - 000000000009.jpg 262 | - 000000000025.jpg 263 | ... 264 | - val2017 265 | - 000000000139.jpg 266 | - 000000000285.jpg 267 | ... 268 | - ... 269 | ``` 270 | 271 | 1. Build CUDA modules 272 | 273 | 1. Define your CUDA architecture code 274 | 275 | ``` 276 | $ export CUDA_ARCH=sm_61 277 | ``` 278 | 279 | * `sm_61` is for `GTX 1080 Ti`, to see others visit [here](http://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/) 280 | 281 | * To check your GPU architecture, you might need following script to find out GPU information 282 | 283 | ``` 284 | $ nvidia-smi -L 285 | ``` 286 | 287 | 1. 
Build `Non-Maximum-Suppression` module 288 | 289 | ``` 290 | $ nvcc -arch=$CUDA_ARCH -c --compiler-options -fPIC -o nms/src/nms_cuda.o nms/src/nms_cuda.cu 291 | $ python nms/build.py 292 | $ python -m nms.test.test_nms 293 | ``` 294 | 295 | * Result after unit testing 296 | 297 | ![](images/test_nms.png?raw=true) 298 | 299 | 1. Build `ROI-Align` module (modified from [RoIAlign.pytorch](https://github.com/longcw/RoIAlign.pytorch)) 300 | 301 | ``` 302 | $ nvcc -arch=$CUDA_ARCH -c --compiler-options -fPIC -o roi/align/src/cuda/crop_and_resize_kernel.cu.o roi/align/src/cuda/crop_and_resize_kernel.cu 303 | $ python roi/align/build.py 304 | ``` 305 | 306 | 1. Install `pycocotools` for `MS COCO 2017` dataset 307 | 308 | 1. Clone and build COCO API 309 | 310 | ``` 311 | $ git clone https://github.com/cocodataset/cocoapi 312 | $ cd cocoapi/PythonAPI 313 | $ make 314 | ``` 315 | > It's not necessary to be under project directory 316 | 317 | 1. If an error with message `pycocotools/_mask.c: No such file or directory` has occurred, please install `cython` and try again 318 | 319 | ``` 320 | $ pip install cython 321 | ``` 322 | 323 | 1. Copy `pycocotools` into project 324 | 325 | ``` 326 | $ cp -R pycocotools /path/to/project 327 | ``` 328 | 329 | 330 | ## Usage 331 | 332 | 1. Train 333 | 334 | * To apply default configuration (see also `config/`) 335 | ``` 336 | $ python train.py -s=coco2017 -b=resnet101 337 | ``` 338 | 339 | * To apply custom configuration (see also `train.py`) 340 | ``` 341 | $ python train.py -s=coco2017 -b=resnet101 --pooling_mode=align 342 | ``` 343 | 344 | 1. Evaluate 345 | 346 | * To apply default configuration (see also `config/`) 347 | ``` 348 | $ python eval.py -s=coco2017 -b=resnet101 /path/to/checkpoint.pth 349 | ``` 350 | 351 | * To apply custom configuration (see also `eval.py`) 352 | ``` 353 | $ python eval.py -s=coco2017 -b=resnet101 --pooling_mode=align /path/to/checkpoint.pth 354 | ``` 355 | 356 | 1. 
Infer 357 | 358 | * To apply default configuration (see also `config/`) 359 | ``` 360 | $ python infer.py -c=/path/to/checkpoint.pth -s=coco2017 -b=resnet101 /path/to/input/image.jpg /path/to/output/image.jpg 361 | ``` 362 | 363 | * To apply custom configuration (see also `infer.py`) 364 | ``` 365 | $ python infer.py -c=/path/to/checkpoint.pth -s=coco2017 -b=resnet101 -p=0.9 /path/to/input/image.jpg /path/to/output/image.jpg 366 | ``` 367 | 368 | 369 | ## Notes 370 | 371 | * Illustration for feature pyramid (see `forward` in `model.py`) 372 | 373 | ```python 374 | # Bottom-up pathway 375 | c1 = self.conv1(image) 376 | c2 = self.conv2(c1) 377 | c3 = self.conv3(c2) 378 | c4 = self.conv4(c3) 379 | c5 = self.conv5(c4) 380 | 381 | # Top-down pathway and lateral connections 382 | p5 = self.lateral_c5(c5) 383 | p4 = self.lateral_c4(c4) + F.interpolate(input=p5, size=(c4.shape[2], c4.shape[3]), mode='nearest') 384 | p3 = self.lateral_c3(c3) + F.interpolate(input=p4, size=(c3.shape[2], c3.shape[3]), mode='nearest') 385 | p2 = self.lateral_c2(c2) + F.interpolate(input=p3, size=(c2.shape[2], c2.shape[3]), mode='nearest') 386 | 387 | # Reduce the aliasing effect 388 | p4 = self.dealiasing_p4(p4) 389 | p3 = self.dealiasing_p3(p3) 390 | p2 = self.dealiasing_p2(p2) 391 | 392 | p6 = F.max_pool2d(input=p5, kernel_size=2) 393 | ``` 394 | 395 | ![](images/feature-pyramid.png) 396 | 397 | * Illustration for "find labels for each `anchor_bboxes`" in `region_proposal_network.py` 398 | 399 | ![](images/rpn_find_labels_1.png) 400 | 401 | ![](images/rpn_find_labels_2.png) 402 | 403 | * Illustration for NMS CUDA 404 | 405 | ![](images/nms_cuda.png) 406 | -------------------------------------------------------------------------------- /backbone/base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Type, NamedTuple 2 | 3 | from torch import nn 4 | 5 | 6 | class Base(object): 7 | 8 | OPTIONS = ['resnet18', 'resnet50', 'resnet101'] 9 | 10 | class ConvLayers(NamedTuple): 11 | conv1: nn.Module 12 | conv2: nn.Module 13 | conv3: nn.Module 14 | conv4: nn.Module 15 | conv5: nn.Module 16 | 17 | class LateralLayers(NamedTuple): 18 | lateral_c2: nn.Module 19 | lateral_c3: nn.Module 20 | lateral_c4: nn.Module 21 | lateral_c5: nn.Module 22 | 23 | class DealiasingLayers(NamedTuple): 24 | dealiasing_p2: nn.Module 25 | dealiasing_p3: nn.Module 26 | dealiasing_p4: nn.Module 27 | 28 | @staticmethod 29 | def from_name(name: str) -> Type['Base']: 30 | if name == 'resnet18': 31 | from backbone.resnet18 import ResNet18 32 | return ResNet18 33 | elif name == 'resnet50': 34 | from backbone.resnet50 import ResNet50 35 | return ResNet50 36 | elif name == 'resnet101': 37 | from backbone.resnet101 import ResNet101 38 | return ResNet101 39 | else: 40 | raise ValueError 41 | 42 | def __init__(self, pretrained: bool): 43 | super().__init__() 44 | self._pretrained = pretrained 45 | 46 | def features(self) -> Tuple[ConvLayers, LateralLayers, DealiasingLayers, int]: 47 | raise NotImplementedError 48 | -------------------------------------------------------------------------------- /backbone/resnet101.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torchvision 4 | from torch import nn 5 | 6 | from backbone.base import Base 7 | 8 | 9 | class ResNet101(Base): 10 | 11 | def __init__(self, pretrained: bool): 12 | super().__init__(pretrained) 13 | 14 | def features(self) -> Tuple[Base.ConvLayers, 
Base.LateralLayers, Base.DealiasingLayers, int]: 15 | resnet101 = torchvision.models.resnet101(pretrained=self._pretrained) 16 | 17 | # list(resnet101.children()) consists of following modules 18 | # [0] = Conv2d, [1] = BatchNorm2d, [2] = ReLU, 19 | # [3] = MaxPool2d, [4] = Sequential(Bottleneck...), 20 | # [5] = Sequential(Bottleneck...), 21 | # [6] = Sequential(Bottleneck...), 22 | # [7] = Sequential(Bottleneck...), 23 | # [8] = AvgPool2d, [9] = Linear 24 | children = list(resnet101.children()) 25 | 26 | conv1 = nn.Sequential(*children[:3]) 27 | conv2 = nn.Sequential(*([children[3]] + list(children[4].children()))) 28 | conv3 = children[5] 29 | conv4 = children[6] 30 | conv5 = children[7] 31 | 32 | num_features_out = 256 33 | 34 | lateral_c2 = nn.Conv2d(in_channels=256, out_channels=num_features_out, kernel_size=1) 35 | lateral_c3 = nn.Conv2d(in_channels=512, out_channels=num_features_out, kernel_size=1) 36 | lateral_c4 = nn.Conv2d(in_channels=1024, out_channels=num_features_out, kernel_size=1) 37 | lateral_c5 = nn.Conv2d(in_channels=2048, out_channels=num_features_out, kernel_size=1) 38 | 39 | dealiasing_p2 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 40 | dealiasing_p3 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 41 | dealiasing_p4 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 42 | 43 | for parameters in [module.parameters() for module in [conv1, conv2]]: 44 | for parameter in parameters: 45 | parameter.requires_grad = False 46 | 47 | conv_layers = Base.ConvLayers(conv1, conv2, conv3, conv4, conv5) 48 | lateral_layers = Base.LateralLayers(lateral_c2, lateral_c3, lateral_c4, lateral_c5) 49 | dealiasing_layers = Base.DealiasingLayers(dealiasing_p2, dealiasing_p3, dealiasing_p4) 50 | 51 | return conv_layers, lateral_layers, dealiasing_layers, num_features_out 52 | -------------------------------------------------------------------------------- /backbone/resnet18.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torchvision 4 | from torch import nn 5 | 6 | from backbone.base import Base 7 | 8 | 9 | class ResNet18(Base): 10 | 11 | def __init__(self, pretrained: bool): 12 | super().__init__(pretrained) 13 | 14 | def features(self) -> Tuple[Base.ConvLayers, Base.LateralLayers, Base.DealiasingLayers, int]: 15 | resnet18 = torchvision.models.resnet18(pretrained=self._pretrained) 16 | 17 | # list(resnet18.children()) consists of following modules 18 | # [0] = Conv2d, [1] = BatchNorm2d, [2] = ReLU, [3] = MaxPool2d, 19 | # [4] = Sequential(Bottleneck...), 20 | # [5] = Sequential(Bottleneck...), 21 | # [6] = Sequential(Bottleneck...), 22 | # [7] = Sequential(Bottleneck...), 23 | # [8] = AvgPool2d, [9] = Linear 24 | children = list(resnet18.children()) 25 | 26 | conv1 = nn.Sequential(*children[:4]) 27 | conv2 = children[4] 28 | conv3 = children[5] 29 | conv4 = children[6] 30 | conv5 = children[7] 31 | 32 | num_features_out = 256 33 | 34 | lateral_c2 = nn.Conv2d(in_channels=64, out_channels=num_features_out, kernel_size=1) 35 | lateral_c3 = nn.Conv2d(in_channels=128, out_channels=num_features_out, kernel_size=1) 36 | lateral_c4 = nn.Conv2d(in_channels=256, out_channels=num_features_out, kernel_size=1) 37 | lateral_c5 = nn.Conv2d(in_channels=512, out_channels=num_features_out, kernel_size=1) 38 | 39 | dealiasing_p2 = nn.Conv2d(in_channels=num_features_out, 
out_channels=num_features_out, kernel_size=3, padding=1) 40 | dealiasing_p3 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 41 | dealiasing_p4 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 42 | 43 | for parameters in [module.parameters() for module in [conv1, conv2]]: 44 | for parameter in parameters: 45 | parameter.requires_grad = False 46 | 47 | conv_layers = Base.ConvLayers(conv1, conv2, conv3, conv4, conv5) 48 | lateral_layers = Base.LateralLayers(lateral_c2, lateral_c3, lateral_c4, lateral_c5) 49 | dealiasing_layers = Base.DealiasingLayers(dealiasing_p2, dealiasing_p3, dealiasing_p4) 50 | 51 | return conv_layers, lateral_layers, dealiasing_layers, num_features_out 52 | -------------------------------------------------------------------------------- /backbone/resnet50.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Callable 2 | 3 | import torchvision 4 | from torch import nn, Tensor 5 | 6 | from backbone.base import Base 7 | 8 | 9 | class ResNet50(Base): 10 | 11 | def __init__(self, pretrained: bool): 12 | super().__init__(pretrained) 13 | 14 | def features(self) -> Tuple[Base.ConvLayers, Base.LateralLayers, Base.DealiasingLayers, int]: 15 | resnet50 = torchvision.models.resnet50(pretrained=self._pretrained) 16 | 17 | # list(resnet50.children()) consists of following modules 18 | # [0] = Conv2d, [1] = BatchNorm2d, [2] = ReLU, [3] = MaxPool2d, 19 | # [4] = Sequential(Bottleneck...), 20 | # [5] = Sequential(Bottleneck...), 21 | # [6] = Sequential(Bottleneck...), 22 | # [7] = Sequential(Bottleneck...), 23 | # [8] = AvgPool2d, [9] = Linear 24 | children = list(resnet50.children()) 25 | 26 | conv1 = nn.Sequential(*children[:4]) 27 | conv2 = children[4] 28 | conv3 = children[5] 29 | conv4 = children[6] 30 | conv5 = children[7] 31 | 32 | num_features_out = 256 33 | 34 | lateral_c2 = nn.Conv2d(in_channels=256, out_channels=num_features_out, kernel_size=1) 35 | lateral_c3 = nn.Conv2d(in_channels=512, out_channels=num_features_out, kernel_size=1) 36 | lateral_c4 = nn.Conv2d(in_channels=1024, out_channels=num_features_out, kernel_size=1) 37 | lateral_c5 = nn.Conv2d(in_channels=2048, out_channels=num_features_out, kernel_size=1) 38 | 39 | dealiasing_p2 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 40 | dealiasing_p3 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 41 | dealiasing_p4 = nn.Conv2d(in_channels=num_features_out, out_channels=num_features_out, kernel_size=3, padding=1) 42 | 43 | for parameters in [module.parameters() for module in [conv1, conv2]]: 44 | for parameter in parameters: 45 | parameter.requires_grad = False 46 | 47 | conv_layers = Base.ConvLayers(conv1, conv2, conv3, conv4, conv5) 48 | lateral_layers = Base.LateralLayers(lateral_c2, lateral_c3, lateral_c4, lateral_c5) 49 | dealiasing_layers = Base.DealiasingLayers(dealiasing_p2, dealiasing_p3, dealiasing_p4) 50 | 51 | return conv_layers, lateral_layers, dealiasing_layers, num_features_out 52 | -------------------------------------------------------------------------------- /bbox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | class BBox(object): 6 | 7 | def __init__(self, left: float, top: float, right: float, bottom: float): 8 | super().__init__() 9 | self.left = 
left 10 | self.top = top 11 | self.right = right 12 | self.bottom = bottom 13 | 14 | def __repr__(self) -> str: 15 | return 'BBox[l={:.1f}, t={:.1f}, r={:.1f}, b={:.1f}]'.format( 16 | self.left, self.top, self.right, self.bottom) 17 | 18 | def tolist(self): 19 | return [self.left, self.top, self.right, self.bottom] 20 | 21 | @staticmethod 22 | def to_center_base(bboxes: Tensor): 23 | return torch.stack([ 24 | (bboxes[:, 0] + bboxes[:, 2]) / 2, 25 | (bboxes[:, 1] + bboxes[:, 3]) / 2, 26 | bboxes[:, 2] - bboxes[:, 0], 27 | bboxes[:, 3] - bboxes[:, 1] 28 | ], dim=1) 29 | 30 | @staticmethod 31 | def from_center_base(center_based_bboxes: Tensor) -> Tensor: 32 | return torch.stack([ 33 | center_based_bboxes[:, 0] - center_based_bboxes[:, 2] / 2, 34 | center_based_bboxes[:, 1] - center_based_bboxes[:, 3] / 2, 35 | center_based_bboxes[:, 0] + center_based_bboxes[:, 2] / 2, 36 | center_based_bboxes[:, 1] + center_based_bboxes[:, 3] / 2 37 | ], dim=1) 38 | 39 | @staticmethod 40 | def calc_transformer(src_bboxes: Tensor, dst_bboxes: Tensor) -> Tensor: 41 | center_based_src_bboxes = BBox.to_center_base(src_bboxes) 42 | center_based_dst_bboxes = BBox.to_center_base(dst_bboxes) 43 | transformers = torch.stack([ 44 | (center_based_dst_bboxes[:, 0] - center_based_src_bboxes[:, 0]) / center_based_dst_bboxes[:, 2], 45 | (center_based_dst_bboxes[:, 1] - center_based_src_bboxes[:, 1]) / center_based_dst_bboxes[:, 3], 46 | torch.log(center_based_dst_bboxes[:, 2] / center_based_src_bboxes[:, 2]), 47 | torch.log(center_based_dst_bboxes[:, 3] / center_based_src_bboxes[:, 3]) 48 | ], dim=1) 49 | return transformers 50 | 51 | @staticmethod 52 | def apply_transformer(src_bboxes: Tensor, transformers: Tensor) -> Tensor: 53 | center_based_src_bboxes = BBox.to_center_base(src_bboxes) 54 | center_based_dst_bboxes = torch.stack([ 55 | transformers[:, 0] * center_based_src_bboxes[:, 2] + center_based_src_bboxes[:, 0], 56 | transformers[:, 1] * center_based_src_bboxes[:, 3] + center_based_src_bboxes[:, 1], 57 | torch.exp(transformers[:, 2]) * center_based_src_bboxes[:, 2], 58 | torch.exp(transformers[:, 3]) * center_based_src_bboxes[:, 3] 59 | ], dim=1) 60 | dst_bboxes = BBox.from_center_base(center_based_dst_bboxes) 61 | return dst_bboxes 62 | 63 | @staticmethod 64 | def iou(source: Tensor, other: Tensor) -> Tensor: 65 | source = source.repeat(other.shape[0], 1, 1).permute(1, 0, 2) 66 | other = other.repeat(source.shape[0], 1, 1) 67 | 68 | source_area = (source[:, :, 2] - source[:, :, 0]) * (source[:, :, 3] - source[:, :, 1]) 69 | other_area = (other[:, :, 2] - other[:, :, 0]) * (other[:, :, 3] - other[:, :, 1]) 70 | 71 | intersection_left = torch.max(source[:, :, 0], other[:, :, 0]) 72 | intersection_top = torch.max(source[:, :, 1], other[:, :, 1]) 73 | intersection_right = torch.min(source[:, :, 2], other[:, :, 2]) 74 | intersection_bottom = torch.min(source[:, :, 3], other[:, :, 3]) 75 | intersection_width = torch.clamp(intersection_right - intersection_left, min=0) 76 | intersection_height = torch.clamp(intersection_bottom - intersection_top, min=0) 77 | intersection_area = intersection_width * intersection_height 78 | 79 | return intersection_area / (source_area + other_area - intersection_area) 80 | 81 | @staticmethod 82 | def inside(source: Tensor, other: Tensor) -> bool: 83 | source = source.repeat(other.shape[0], 1, 1).permute(1, 0, 2) 84 | other = other.repeat(source.shape[0], 1, 1) 85 | return ((source[:, :, 0] >= other[:, :, 0]) * (source[:, :, 1] >= other[:, :, 1]) * 86 | (source[:, :, 2] <= other[:, :, 2]) * 
(source[:, :, 3] <= other[:, :, 3])) 87 | 88 | @staticmethod 89 | def clip(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor: 90 | return torch.stack([ 91 | torch.clamp(bboxes[:, 0], min=left, max=right), 92 | torch.clamp(bboxes[:, 1], min=top, max=bottom), 93 | torch.clamp(bboxes[:, 2], min=left, max=right), 94 | torch.clamp(bboxes[:, 3], min=top, max=bottom) 95 | ], dim=1) 96 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import ast 4 | 5 | from roi.wrapper import Wrapper as ROIWrapper 6 | 7 | 8 | class Config(object): 9 | 10 | IMAGE_MIN_SIDE: float = 800.0 11 | IMAGE_MAX_SIDE: float = 1333.0 12 | 13 | ANCHOR_RATIOS: List[Tuple[int, int]] = [(1, 2), (1, 1), (2, 1)] 14 | ANCHOR_SCALES: List[int] = [1] 15 | POOLING_MODE: ROIWrapper.Mode = ROIWrapper.Mode.ALIGN 16 | 17 | @classmethod 18 | def describe(cls): 19 | text = '\nConfig:\n' 20 | attrs = [attr for attr in dir(cls) if not callable(getattr(cls, attr)) and not attr.startswith('__')] 21 | text += '\n'.join(['\t{:s} = {:s}'.format(attr, str(getattr(cls, attr))) for attr in attrs]) + '\n' 22 | 23 | return text 24 | 25 | @classmethod 26 | def setup(cls, image_min_side: float = None, image_max_side: float = None, 27 | anchor_ratios: List[Tuple[int, int]] = None, anchor_scales: List[int] = None, pooling_mode: str = None): 28 | if image_min_side is not None: 29 | cls.IMAGE_MIN_SIDE = image_min_side 30 | if image_max_side is not None: 31 | cls.IMAGE_MAX_SIDE = image_max_side 32 | 33 | if anchor_ratios is not None: 34 | cls.ANCHOR_RATIOS = ast.literal_eval(anchor_ratios) 35 | if anchor_scales is not None: 36 | cls.ANCHOR_SCALES = ast.literal_eval(anchor_scales) 37 | if pooling_mode is not None: 38 | cls.POOLING_MODE = ROIWrapper.Mode(pooling_mode) 39 | -------------------------------------------------------------------------------- /config/eval_config.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from config.config import Config 4 | 5 | 6 | class EvalConfig(Config): 7 | 8 | RPN_PRE_NMS_TOP_N: int = 6000 // 5 # for each level 9 | RPN_POST_NMS_TOP_N: int = 1000 10 | 11 | @classmethod 12 | def setup(cls, image_min_side: float = None, image_max_side: float = None, 13 | anchor_ratios: List[Tuple[int, int]] = None, anchor_scales: List[int] = None, pooling_mode: str = None, 14 | rpn_pre_nms_top_n: int = None, rpn_post_nms_top_n: int = None): 15 | super().setup(image_min_side, image_max_side, anchor_ratios, anchor_scales, pooling_mode) 16 | 17 | if rpn_pre_nms_top_n is not None: 18 | cls.RPN_PRE_NMS_TOP_N = rpn_pre_nms_top_n 19 | if rpn_post_nms_top_n is not None: 20 | cls.RPN_POST_NMS_TOP_N = rpn_post_nms_top_n 21 | -------------------------------------------------------------------------------- /config/train_config.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import ast 4 | 5 | from config.config import Config 6 | 7 | 8 | class TrainConfig(Config): 9 | 10 | RPN_PRE_NMS_TOP_N: int = 12000 // 5 # for each level 11 | RPN_POST_NMS_TOP_N: int = 2000 12 | 13 | LEARNING_RATE: float = 0.001 14 | MOMENTUM: float = 0.9 15 | WEIGHT_DECAY: float = 0.0001 16 | STEP_LR_SIZES: List[int] = [50000, 70000] 17 | STEP_LR_GAMMA: float = 0.1 18 | 19 | NUM_STEPS_TO_DISPLAY: int = 20 20 | 
NUM_STEPS_TO_SNAPSHOT: int = 10000 21 | NUM_STEPS_TO_FINISH: int = 80000 22 | 23 | @classmethod 24 | def setup(cls, image_min_side: float = None, image_max_side: float = None, 25 | anchor_ratios: List[Tuple[int, int]] = None, anchor_scales: List[int] = None, pooling_mode: str = None, 26 | rpn_pre_nms_top_n: int = None, rpn_post_nms_top_n: int = None, 27 | learning_rate: float = None, momentum: float = None, weight_decay: float = None, 28 | step_lr_sizes: List[int] = None, step_lr_gamma: float = None, 29 | num_steps_to_display: int = None, num_steps_to_snapshot: int = None, num_steps_to_finish: int = None): 30 | super().setup(image_min_side, image_max_side, anchor_ratios, anchor_scales, pooling_mode) 31 | 32 | if rpn_pre_nms_top_n is not None: 33 | cls.RPN_PRE_NMS_TOP_N = rpn_pre_nms_top_n 34 | if rpn_post_nms_top_n is not None: 35 | cls.RPN_POST_NMS_TOP_N = rpn_post_nms_top_n 36 | 37 | if learning_rate is not None: 38 | cls.LEARNING_RATE = learning_rate 39 | if momentum is not None: 40 | cls.MOMENTUM = momentum 41 | if weight_decay is not None: 42 | cls.WEIGHT_DECAY = weight_decay 43 | if step_lr_sizes is not None: 44 | cls.STEP_LR_SIZES = ast.literal_eval(step_lr_sizes) 45 | if step_lr_gamma is not None: 46 | cls.STEP_LR_GAMMA = step_lr_gamma 47 | 48 | if num_steps_to_display is not None: 49 | cls.NUM_STEPS_TO_DISPLAY = num_steps_to_display 50 | if num_steps_to_snapshot is not None: 51 | cls.NUM_STEPS_TO_SNAPSHOT = num_steps_to_snapshot 52 | if num_steps_to_finish is not None: 53 | cls.NUM_STEPS_TO_FINISH = num_steps_to_finish 54 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /dataset/base.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Tuple, List, Type 3 | 4 | import PIL 5 | import torch.utils.data.dataset 6 | from PIL import Image 7 | from torch import Tensor 8 | from torchvision.transforms import transforms 9 | 10 | 11 | class Base(torch.utils.data.dataset.Dataset): 12 | 13 | class Mode(Enum): 14 | TRAIN = 'train' 15 | EVAL = 'eval' 16 | 17 | OPTIONS = ['voc2007', 'coco2017', 'voc2007-cat-dog', 'coco2017-person', 'coco2017-car', 'coco2017-animal'] 18 | 19 | @staticmethod 20 | def from_name(name: str) -> Type['Base']: 21 | if name == 'voc2007': 22 | from dataset.voc2007 import VOC2007 23 | return VOC2007 24 | elif name == 'coco2017': 25 | from dataset.coco2017 import COCO2017 26 | return COCO2017 27 | elif name == 'voc2007-cat-dog': 28 | from dataset.voc2007_cat_dog import VOC2007CatDog 29 | return VOC2007CatDog 30 | elif name == 'coco2017-person': 31 | from dataset.coco2017_person import COCO2017Person 32 | return COCO2017Person 33 | elif name == 'coco2017-car': 34 | from dataset.coco2017_car import COCO2017Car 35 | return COCO2017Car 36 | elif name == 'coco2017-animal': 37 | from dataset.coco2017_animal import COCO2017Animal 38 | return COCO2017Animal 39 | else: 40 | raise ValueError 41 | 42 | def __init__(self, path_to_data_dir: str, mode: Mode, image_min_side: float, image_max_side: float): 43 | self._path_to_data_dir = path_to_data_dir 44 | self._mode = mode 45 | self._image_min_side = image_min_side 46 | self._image_max_side = image_max_side 47 | 48 | def __len__(self) -> int: 49 | raise NotImplementedError 50 | 51 | def __getitem__(self, index: int) -> 
Tuple[str, Tensor, float, Tensor, Tensor]: 52 | raise NotImplementedError 53 | 54 | def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]: 55 | raise NotImplementedError 56 | 57 | def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]): 58 | raise NotImplementedError 59 | 60 | @staticmethod 61 | def num_classes() -> int: 62 | raise NotImplementedError 63 | 64 | @staticmethod 65 | def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]: 66 | # resize according to the rules: 67 | # 1. scale shorter side to IMAGE_MIN_SIDE 68 | # 2. after scaling, if longer side > IMAGE_MAX_SIDE, scale longer side to IMAGE_MAX_SIDE 69 | scale_for_shorter_side = image_min_side / min(image.width, image.height) 70 | longer_side_after_scaling = max(image.width, image.height) * scale_for_shorter_side 71 | scale_for_longer_side = (image_max_side / longer_side_after_scaling) if longer_side_after_scaling > image_max_side else 1 72 | scale = scale_for_shorter_side * scale_for_longer_side 73 | 74 | transform = transforms.Compose([ 75 | transforms.Resize((round(image.height * scale), round(image.width * scale))), # interpolation `BILINEAR` is applied by default 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 78 | ]) 79 | image = transform(image) 80 | 81 | return image, scale 82 | -------------------------------------------------------------------------------- /dataset/coco2017.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import random 5 | from typing import List, Tuple, Dict 6 | 7 | import torch 8 | import torch.utils.data.dataset 9 | from PIL import Image, ImageOps 10 | from pycocotools.coco import COCO 11 | from pycocotools.cocoeval import COCOeval 12 | from torch import Tensor 13 | from torchvision.datasets import CocoDetection 14 | from tqdm import tqdm 15 | 16 | from bbox import BBox 17 | from dataset.base import Base 18 | from io import StringIO 19 | import sys 20 | 21 | 22 | class COCO2017(Base): 23 | 24 | class Annotation(object): 25 | class Object(object): 26 | def __init__(self, bbox: BBox, label: int) -> None: 27 | super().__init__() 28 | self.bbox = bbox 29 | self.label = label 30 | 31 | def __repr__(self) -> str: 32 | return 'Object[label={:d}, bbox={!s}]'.format( 33 | self.label, self.bbox) 34 | 35 | def __init__(self, filename: str, objects: List[Object]) -> None: 36 | super().__init__() 37 | self.filename = filename 38 | self.objects = objects 39 | 40 | CATEGORY_TO_LABEL_DICT = { 41 | 'background': 0, 'person': 1, 'bicycle': 2, 'car': 3, 'motorcycle': 4, 42 | 'airplane': 5, 'bus': 6, 'train': 7, 'truck': 8, 'boat': 9, 43 | 'traffic light': 10, 'fire hydrant': 11, 'street sign': 12, 'stop sign': 13, 'parking meter': 14, 44 | 'bench': 15, 'bird': 16, 'cat': 17, 'dog': 18, 'horse': 19, 45 | 'sheep': 20, 'cow': 21, 'elephant': 22, 'bear': 23, 'zebra': 24, 46 | 'giraffe': 25, 'hat': 26, 'backpack': 27, 'umbrella': 28, 'shoe': 29, 47 | 'eye glasses': 30, 'handbag': 31, 'tie': 32, 'suitcase': 33, 'frisbee': 34, 48 | 'skis': 35, 'snowboard': 36, 'sports ball': 37, 'kite': 38, 'baseball bat': 39, 49 | 'baseball glove': 40, 'skateboard': 41, 'surfboard': 42, 'tennis racket': 43, 'bottle': 44, 50 | 'plate': 45, 'wine glass': 46, 
'cup': 47, 'fork': 48, 'knife': 49, 51 | 'spoon': 50, 'bowl': 51, 'banana': 52, 'apple': 53, 'sandwich': 54, 52 | 'orange': 55, 'broccoli': 56, 'carrot': 57, 'hot dog': 58, 'pizza': 59, 53 | 'donut': 60, 'cake': 61, 'chair': 62, 'couch': 63, 'potted plant': 64, 54 | 'bed': 65, 'mirror': 66, 'dining table': 67, 'window': 68, 'desk': 69, 55 | 'toilet': 70, 'door': 71, 'tv': 72, 'laptop': 73, 'mouse': 74, 56 | 'remote': 75, 'keyboard': 76, 'cell phone': 77, 'microwave': 78, 'oven': 79, 57 | 'toaster': 80, 'sink': 81, 'refrigerator': 82, 'blender': 83, 'book': 84, 58 | 'clock': 85, 'vase': 86, 'scissors': 87, 'teddy bear': 88, 'hair drier': 89, 59 | 'toothbrush': 90, 'hair brush': 91 60 | } 61 | 62 | LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()} 63 | 64 | def __init__(self, path_to_data_dir: str, mode: Base.Mode, image_min_side: float, image_max_side: float): 65 | super().__init__(path_to_data_dir, mode, image_min_side, image_max_side) 66 | 67 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 68 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 69 | path_to_caches_dir = os.path.join('caches', 'coco2017', f'{self._mode.value}') 70 | path_to_image_ids_pickle = os.path.join(path_to_caches_dir, 'image-ids.pkl') 71 | path_to_image_id_dict_pickle = os.path.join(path_to_caches_dir, 'image-id-dict.pkl') 72 | 73 | if self._mode == COCO2017.Mode.TRAIN: 74 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'train2017') 75 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_train2017.json') 76 | elif self._mode == COCO2017.Mode.EVAL: 77 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'val2017') 78 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 79 | else: 80 | raise ValueError('invalid mode') 81 | 82 | coco_dataset = CocoDetection(root=path_to_jpeg_images_dir, annFile=path_to_annotation) 83 | 84 | if os.path.exists(path_to_image_ids_pickle) and os.path.exists(path_to_image_id_dict_pickle): 85 | print('loading cache files...') 86 | 87 | with open(path_to_image_ids_pickle, 'rb') as f: 88 | self._image_ids = pickle.load(f) 89 | 90 | with open(path_to_image_id_dict_pickle, 'rb') as f: 91 | self._image_id_to_annotation_dict = pickle.load(f) 92 | else: 93 | print('generating cache files...') 94 | 95 | os.makedirs(path_to_caches_dir, exist_ok=True) 96 | 97 | self._image_ids: List[str] = [] 98 | self._image_id_to_annotation_dict: Dict[str, COCO2017.Annotation] = {} 99 | 100 | for idx, (image, annotation) in enumerate(tqdm(coco_dataset)): 101 | if len(annotation) > 0: 102 | image_id = str(annotation[0]['image_id']) # all image_id in annotation are the same 103 | self._image_ids.append(image_id) 104 | self._image_id_to_annotation_dict[image_id] = COCO2017.Annotation( 105 | filename=os.path.join(path_to_jpeg_images_dir, '{:012d}.jpg'.format(int(image_id))), 106 | objects=[COCO2017.Annotation.Object( 107 | bbox=BBox( # `ann['bbox']` is in the format [left, top, width, height] 108 | left=ann['bbox'][0], 109 | top=ann['bbox'][1], 110 | right=ann['bbox'][0] + ann['bbox'][2], 111 | bottom=ann['bbox'][1] + ann['bbox'][3] 112 | ), 113 | label=ann['category_id']) 114 | for ann in annotation] 115 | ) 116 | 117 | with open(path_to_image_ids_pickle, 'wb') as f: 118 | pickle.dump(self._image_ids, f) 119 | 120 | with open(path_to_image_id_dict_pickle, 'wb') as f: 121 | pickle.dump(self._image_id_to_annotation_dict, f) 122 | 123 | def __len__(self) -> int: 124 | return 
len(self._image_id_to_annotation_dict) 125 | 126 | def __getitem__(self, index: int) -> Tuple[str, Tensor, float, Tensor, Tensor]: 127 | image_id = self._image_ids[index] 128 | annotation = self._image_id_to_annotation_dict[image_id] 129 | 130 | bboxes = [obj.bbox.tolist() for obj in annotation.objects] 131 | labels = [obj.label for obj in annotation.objects] 132 | 133 | bboxes = torch.tensor(bboxes, dtype=torch.float) 134 | labels = torch.tensor(labels, dtype=torch.long) 135 | 136 | image = Image.open(annotation.filename).convert('RGB') # for some grayscale images 137 | 138 | # random flip on only training mode 139 | if self._mode == COCO2017.Mode.TRAIN and random.random() > 0.5: 140 | image = ImageOps.mirror(image) 141 | bboxes[:, [0, 2]] = image.width - bboxes[:, [2, 0]] # index 0 and 2 represent `left` and `right` respectively 142 | 143 | image, scale = COCO2017.preprocess(image, self._image_min_side, self._image_max_side) 144 | bboxes *= scale 145 | 146 | return image_id, image, scale, bboxes, labels 147 | 148 | def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]: 149 | self._write_results(path_to_results_dir, image_ids, bboxes, classes, probs) 150 | 151 | annType = 'bbox' 152 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 153 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 154 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 155 | 156 | cocoGt = COCO(path_to_annotation) 157 | cocoDt = cocoGt.loadRes(os.path.join(path_to_results_dir, 'results.json')) 158 | 159 | cocoEval = COCOeval(cocoGt, cocoDt, annType) 160 | cocoEval.evaluate() 161 | cocoEval.accumulate() 162 | 163 | original_stdout = sys.stdout 164 | string_stdout = StringIO() 165 | sys.stdout = string_stdout 166 | cocoEval.summarize() 167 | sys.stdout = original_stdout 168 | 169 | mean_ap = cocoEval.stats[0].item() # stats[0] records AP@[0.5:0.95] 170 | detail = string_stdout.getvalue() 171 | 172 | return mean_ap, detail 173 | 174 | def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]): 175 | results = [] 176 | for image_id, bbox, cls, prob in zip(image_ids, bboxes, classes, probs): 177 | results.append( 178 | { 179 | 'image_id': int(image_id), # COCO evaluation requires `image_id` to be type `int` 180 | 'category_id': cls, 181 | 'bbox': [ # format [left, top, width, height] is expected 182 | bbox[0], 183 | bbox[1], 184 | bbox[2] - bbox[0], 185 | bbox[3] - bbox[1] 186 | ], 187 | 'score': prob 188 | } 189 | ) 190 | 191 | with open(os.path.join(path_to_results_dir, 'results.json'), 'w') as f: 192 | json.dump(results, f) 193 | 194 | @staticmethod 195 | def num_classes() -> int: 196 | return 92 197 | -------------------------------------------------------------------------------- /dataset/coco2017_animal.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import random 5 | import sys 6 | from io import StringIO 7 | from typing import List, Tuple, Dict 8 | 9 | import torch 10 | import torch.utils.data.dataset 11 | from PIL import Image, ImageOps 12 | from pycocotools.coco import COCO 13 | from pycocotools.cocoeval import COCOeval 14 | from torch import Tensor 15 | from torchvision.datasets import CocoDetection 16 | from tqdm import tqdm 17 | 18 | from bbox import BBox 19 | from 
dataset.base import Base 20 | from dataset.coco2017 import COCO2017 21 | 22 | 23 | class COCO2017Animal(Base): 24 | 25 | class Annotation(object): 26 | class Object(object): 27 | def __init__(self, bbox: BBox, label: int) -> None: 28 | super().__init__() 29 | self.bbox = bbox 30 | self.label = label 31 | 32 | def __repr__(self) -> str: 33 | return 'Object[label={:d}, bbox={!s}]'.format( 34 | self.label, self.bbox) 35 | 36 | def __init__(self, filename: str, objects: List[Object]) -> None: 37 | super().__init__() 38 | self.filename = filename 39 | self.objects = objects 40 | 41 | CATEGORY_TO_LABEL_DICT = { 42 | 'background': 0, 43 | 'bird': 1, 'cat': 2, 'dog': 3, 'horse': 4, 'sheep': 5, 44 | 'cow': 6, 'elephant': 7, 'bear': 8, 'zebra': 9, 'giraffe': 10 45 | } 46 | 47 | LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()} 48 | 49 | def __init__(self, path_to_data_dir: str, mode: Base.Mode, image_min_side: float, image_max_side: float): 50 | super().__init__(path_to_data_dir, mode, image_min_side, image_max_side) 51 | 52 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 53 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 54 | path_to_caches_dir = os.path.join('caches', 'coco2017-animal', f'{self._mode.value}') 55 | path_to_image_ids_pickle = os.path.join(path_to_caches_dir, 'image-ids.pkl') 56 | path_to_image_id_dict_pickle = os.path.join(path_to_caches_dir, 'image-id-dict.pkl') 57 | 58 | if self._mode == COCO2017Animal.Mode.TRAIN: 59 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'train2017') 60 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_train2017.json') 61 | elif self._mode == COCO2017Animal.Mode.EVAL: 62 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'val2017') 63 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 64 | else: 65 | raise ValueError('invalid mode') 66 | 67 | coco_dataset = CocoDetection(root=path_to_jpeg_images_dir, annFile=path_to_annotation) 68 | 69 | if os.path.exists(path_to_image_ids_pickle) and os.path.exists(path_to_image_id_dict_pickle): 70 | print('loading cache files...') 71 | 72 | with open(path_to_image_ids_pickle, 'rb') as f: 73 | self._image_ids = pickle.load(f) 74 | 75 | with open(path_to_image_id_dict_pickle, 'rb') as f: 76 | self._image_id_to_annotation_dict = pickle.load(f) 77 | else: 78 | print('generating cache files...') 79 | 80 | os.makedirs(path_to_caches_dir, exist_ok=True) 81 | 82 | self._image_id_to_annotation_dict: Dict[str, COCO2017Animal.Annotation] = {} 83 | for idx, (image, annotation) in enumerate(tqdm(coco_dataset)): 84 | if len(annotation) > 0: 85 | image_id = str(annotation[0]['image_id']) # all image_id in annotation are the same 86 | annotation = COCO2017Animal.Annotation( 87 | filename=os.path.join(path_to_jpeg_images_dir, '{:012d}.jpg'.format(int(image_id))), 88 | objects=[COCO2017Animal.Annotation.Object( 89 | bbox=BBox( # `ann['bbox']` is in the format [left, top, width, height] 90 | left=ann['bbox'][0], 91 | top=ann['bbox'][1], 92 | right=ann['bbox'][0] + ann['bbox'][2], 93 | bottom=ann['bbox'][1] + ann['bbox'][3] 94 | ), 95 | label=ann['category_id']) 96 | for ann in annotation] 97 | ) 98 | annotation.objects = [obj for obj in annotation.objects 99 | if obj.label in [COCO2017.CATEGORY_TO_LABEL_DICT[category] # filtering label should refer to original `COCO2017` dataset 100 | for category in COCO2017Animal.CATEGORY_TO_LABEL_DICT.keys()][1:]] 101 | 102 | if len(annotation.objects) > 0: 
103 | self._image_id_to_annotation_dict[image_id] = annotation 104 | 105 | self._image_ids = list(self._image_id_to_annotation_dict.keys()) 106 | 107 | with open(path_to_image_ids_pickle, 'wb') as f: 108 | pickle.dump(self._image_ids, f) 109 | 110 | with open(path_to_image_id_dict_pickle, 'wb') as f: 111 | pickle.dump(self._image_id_to_annotation_dict, f) 112 | 113 | def __len__(self) -> int: 114 | return len(self._image_id_to_annotation_dict) 115 | 116 | def __getitem__(self, index: int) -> Tuple[str, Tensor, float, Tensor, Tensor]: 117 | image_id = self._image_ids[index] 118 | annotation = self._image_id_to_annotation_dict[image_id] 119 | 120 | bboxes = [obj.bbox.tolist() for obj in annotation.objects] 121 | labels = [COCO2017Animal.CATEGORY_TO_LABEL_DICT[COCO2017.LABEL_TO_CATEGORY_DICT[obj.label]] for obj in annotation.objects] # mapping from original `COCO2017` dataset 122 | 123 | bboxes = torch.tensor(bboxes, dtype=torch.float) 124 | labels = torch.tensor(labels, dtype=torch.long) 125 | 126 | image = Image.open(annotation.filename).convert('RGB') # for some grayscale images 127 | 128 | # random flip on only training mode 129 | if self._mode == COCO2017Animal.Mode.TRAIN and random.random() > 0.5: 130 | image = ImageOps.mirror(image) 131 | bboxes[:, [0, 2]] = image.width - bboxes[:, [2, 0]] # index 0 and 2 represent `left` and `right` respectively 132 | 133 | image, scale = COCO2017Animal.preprocess(image, self._image_min_side, self._image_max_side) 134 | bboxes *= scale 135 | 136 | return image_id, image, scale, bboxes, labels 137 | 138 | def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]: 139 | self._write_results(path_to_results_dir, image_ids, bboxes, classes, probs) 140 | 141 | annType = 'bbox' 142 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 143 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 144 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 145 | 146 | cocoGt = COCO(path_to_annotation) 147 | cocoDt = cocoGt.loadRes(os.path.join(path_to_results_dir, 'results.json')) 148 | 149 | cocoEval = COCOeval(cocoGt, cocoDt, annType) 150 | cocoEval.params.catIds = [COCO2017.CATEGORY_TO_LABEL_DICT[category] # filtering label should refer to original `COCO2017` dataset 151 | for category in COCO2017Animal.CATEGORY_TO_LABEL_DICT.keys()] 152 | cocoEval.evaluate() 153 | cocoEval.accumulate() 154 | 155 | original_stdout = sys.stdout 156 | string_stdout = StringIO() 157 | sys.stdout = string_stdout 158 | cocoEval.summarize() 159 | sys.stdout = original_stdout 160 | 161 | mean_ap = cocoEval.stats[0].item() # stats[0] records AP@[0.5:0.95] 162 | detail = string_stdout.getvalue() 163 | 164 | return mean_ap, detail 165 | 166 | def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]): 167 | results = [] 168 | for image_id, bbox, cls, prob in zip(image_ids, bboxes, classes, probs): 169 | results.append( 170 | { 171 | 'image_id': int(image_id), # COCO evaluation requires `image_id` to be type `int` 172 | 'category_id': COCO2017.CATEGORY_TO_LABEL_DICT[COCO2017Animal.LABEL_TO_CATEGORY_DICT[cls]], # mapping to original `COCO2017` dataset 173 | 'bbox': [ # format [left, top, width, height] is expected 174 | bbox[0], 175 | bbox[1], 176 | bbox[2] - bbox[0], 177 | bbox[3] - bbox[1] 178 | ], 179 | 'score': prob 180 | } 181 | ) 182 | 183 | with 
open(os.path.join(path_to_results_dir, 'results.json'), 'w') as f: 184 | json.dump(results, f) 185 | 186 | @staticmethod 187 | def num_classes() -> int: 188 | return 11 189 | -------------------------------------------------------------------------------- /dataset/coco2017_car.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import random 5 | import sys 6 | from io import StringIO 7 | from typing import List, Tuple, Dict 8 | 9 | import torch 10 | import torch.utils.data.dataset 11 | from PIL import Image, ImageOps 12 | from pycocotools.coco import COCO 13 | from pycocotools.cocoeval import COCOeval 14 | from torch import Tensor 15 | from torchvision.datasets import CocoDetection 16 | from tqdm import tqdm 17 | 18 | from bbox import BBox 19 | from dataset.base import Base 20 | from dataset.coco2017 import COCO2017 21 | 22 | 23 | class COCO2017Car(Base): 24 | 25 | class Annotation(object): 26 | class Object(object): 27 | def __init__(self, bbox: BBox, label: int) -> None: 28 | super().__init__() 29 | self.bbox = bbox 30 | self.label = label 31 | 32 | def __repr__(self) -> str: 33 | return 'Object[label={:d}, bbox={!s}]'.format( 34 | self.label, self.bbox) 35 | 36 | def __init__(self, filename: str, objects: List[Object]) -> None: 37 | super().__init__() 38 | self.filename = filename 39 | self.objects = objects 40 | 41 | CATEGORY_TO_LABEL_DICT = { 42 | 'background': 0, 'car': 1 43 | } 44 | 45 | LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()} 46 | 47 | def __init__(self, path_to_data_dir: str, mode: Base.Mode, image_min_side: float, image_max_side: float): 48 | super().__init__(path_to_data_dir, mode, image_min_side, image_max_side) 49 | 50 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 51 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 52 | path_to_caches_dir = os.path.join('caches', 'coco2017-car', f'{self._mode.value}') 53 | path_to_image_ids_pickle = os.path.join(path_to_caches_dir, 'image-ids.pkl') 54 | path_to_image_id_dict_pickle = os.path.join(path_to_caches_dir, 'image-id-dict.pkl') 55 | 56 | if self._mode == COCO2017Car.Mode.TRAIN: 57 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'train2017') 58 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_train2017.json') 59 | elif self._mode == COCO2017Car.Mode.EVAL: 60 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'val2017') 61 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 62 | else: 63 | raise ValueError('invalid mode') 64 | 65 | coco_dataset = CocoDetection(root=path_to_jpeg_images_dir, annFile=path_to_annotation) 66 | 67 | if os.path.exists(path_to_image_ids_pickle) and os.path.exists(path_to_image_id_dict_pickle): 68 | print('loading cache files...') 69 | 70 | with open(path_to_image_ids_pickle, 'rb') as f: 71 | self._image_ids = pickle.load(f) 72 | 73 | with open(path_to_image_id_dict_pickle, 'rb') as f: 74 | self._image_id_to_annotation_dict = pickle.load(f) 75 | else: 76 | print('generating cache files...') 77 | 78 | os.makedirs(path_to_caches_dir, exist_ok=True) 79 | 80 | self._image_id_to_annotation_dict: Dict[str, COCO2017Car.Annotation] = {} 81 | for idx, (image, annotation) in enumerate(tqdm(coco_dataset)): 82 | if len(annotation) > 0: 83 | image_id = str(annotation[0]['image_id']) # all image_id in annotation are the same 84 | annotation = COCO2017Car.Annotation( 85 | 
filename=os.path.join(path_to_jpeg_images_dir, '{:012d}.jpg'.format(int(image_id))), 86 | objects=[COCO2017Car.Annotation.Object( 87 | bbox=BBox( # `ann['bbox']` is in the format [left, top, width, height] 88 | left=ann['bbox'][0], 89 | top=ann['bbox'][1], 90 | right=ann['bbox'][0] + ann['bbox'][2], 91 | bottom=ann['bbox'][1] + ann['bbox'][3] 92 | ), 93 | label=ann['category_id']) 94 | for ann in annotation] 95 | ) 96 | annotation.objects = [obj for obj in annotation.objects 97 | if obj.label in [COCO2017.CATEGORY_TO_LABEL_DICT['car']]] # filtering label should refer to original `COCO2017` dataset 98 | 99 | if len(annotation.objects) > 0: 100 | self._image_id_to_annotation_dict[image_id] = annotation 101 | 102 | self._image_ids = list(self._image_id_to_annotation_dict.keys()) 103 | 104 | with open(path_to_image_ids_pickle, 'wb') as f: 105 | pickle.dump(self._image_ids, f) 106 | 107 | with open(path_to_image_id_dict_pickle, 'wb') as f: 108 | pickle.dump(self._image_id_to_annotation_dict, f) 109 | 110 | def __len__(self) -> int: 111 | return len(self._image_id_to_annotation_dict) 112 | 113 | def __getitem__(self, index: int) -> Tuple[str, Tensor, float, Tensor, Tensor]: 114 | image_id = self._image_ids[index] 115 | annotation = self._image_id_to_annotation_dict[image_id] 116 | 117 | bboxes = [obj.bbox.tolist() for obj in annotation.objects] 118 | labels = [COCO2017Car.CATEGORY_TO_LABEL_DICT[COCO2017.LABEL_TO_CATEGORY_DICT[obj.label]] for obj in annotation.objects] # mapping from original `COCO2017` dataset 119 | 120 | bboxes = torch.tensor(bboxes, dtype=torch.float) 121 | labels = torch.tensor(labels, dtype=torch.long) 122 | 123 | image = Image.open(annotation.filename).convert('RGB') # for some grayscale images 124 | 125 | # random flip on only training mode 126 | if self._mode == COCO2017Car.Mode.TRAIN and random.random() > 0.5: 127 | image = ImageOps.mirror(image) 128 | bboxes[:, [0, 2]] = image.width - bboxes[:, [2, 0]] # index 0 and 2 represent `left` and `right` respectively 129 | 130 | image, scale = COCO2017Car.preprocess(image, self._image_min_side, self._image_max_side) 131 | bboxes *= scale 132 | 133 | return image_id, image, scale, bboxes, labels 134 | 135 | def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]: 136 | self._write_results(path_to_results_dir, image_ids, bboxes, classes, probs) 137 | 138 | annType = 'bbox' 139 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 140 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 141 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 142 | 143 | cocoGt = COCO(path_to_annotation) 144 | cocoDt = cocoGt.loadRes(os.path.join(path_to_results_dir, 'results.json')) 145 | 146 | cocoEval = COCOeval(cocoGt, cocoDt, annType) 147 | cocoEval.params.catIds = COCO2017.CATEGORY_TO_LABEL_DICT['car'] # filtering label should refer to original `COCO2017` dataset 148 | cocoEval.evaluate() 149 | cocoEval.accumulate() 150 | 151 | original_stdout = sys.stdout 152 | string_stdout = StringIO() 153 | sys.stdout = string_stdout 154 | cocoEval.summarize() 155 | sys.stdout = original_stdout 156 | 157 | mean_ap = cocoEval.stats[0].item() # stats[0] records AP@[0.5:0.95] 158 | detail = string_stdout.getvalue() 159 | 160 | return mean_ap, detail 161 | 162 | def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: 
List[float]): 163 | results = [] 164 | for image_id, bbox, cls, prob in zip(image_ids, bboxes, classes, probs): 165 | results.append( 166 | { 167 | 'image_id': int(image_id), # COCO evaluation requires `image_id` to be type `int` 168 | 'category_id': COCO2017.CATEGORY_TO_LABEL_DICT[COCO2017Car.LABEL_TO_CATEGORY_DICT[cls]], # mapping to original `COCO2017` dataset 169 | 'bbox': [ # format [left, top, width, height] is expected 170 | bbox[0], 171 | bbox[1], 172 | bbox[2] - bbox[0], 173 | bbox[3] - bbox[1] 174 | ], 175 | 'score': prob 176 | } 177 | ) 178 | 179 | with open(os.path.join(path_to_results_dir, 'results.json'), 'w') as f: 180 | json.dump(results, f) 181 | 182 | @staticmethod 183 | def num_classes() -> int: 184 | return 2 185 | -------------------------------------------------------------------------------- /dataset/coco2017_person.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import random 5 | import sys 6 | from io import StringIO 7 | from typing import List, Tuple, Dict 8 | 9 | import torch 10 | import torch.utils.data.dataset 11 | from PIL import Image, ImageOps 12 | from pycocotools.coco import COCO 13 | from pycocotools.cocoeval import COCOeval 14 | from torch import Tensor 15 | from torchvision.datasets import CocoDetection 16 | from tqdm import tqdm 17 | 18 | from bbox import BBox 19 | from dataset.base import Base 20 | from dataset.coco2017 import COCO2017 21 | 22 | 23 | class COCO2017Person(Base): 24 | 25 | class Annotation(object): 26 | class Object(object): 27 | def __init__(self, bbox: BBox, label: int) -> None: 28 | super().__init__() 29 | self.bbox = bbox 30 | self.label = label 31 | 32 | def __repr__(self) -> str: 33 | return 'Object[label={:d}, bbox={!s}]'.format( 34 | self.label, self.bbox) 35 | 36 | def __init__(self, filename: str, objects: List[Object]) -> None: 37 | super().__init__() 38 | self.filename = filename 39 | self.objects = objects 40 | 41 | CATEGORY_TO_LABEL_DICT = { 42 | 'background': 0, 'person': 1 43 | } 44 | 45 | LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()} 46 | 47 | def __init__(self, path_to_data_dir: str, mode: Base.Mode, image_min_side: float, image_max_side: float): 48 | super().__init__(path_to_data_dir, mode, image_min_side, image_max_side) 49 | 50 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 51 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 52 | path_to_caches_dir = os.path.join('caches', 'coco2017-person', f'{self._mode.value}') 53 | path_to_image_ids_pickle = os.path.join(path_to_caches_dir, 'image-ids.pkl') 54 | path_to_image_id_dict_pickle = os.path.join(path_to_caches_dir, 'image-id-dict.pkl') 55 | 56 | if self._mode == COCO2017Person.Mode.TRAIN: 57 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'train2017') 58 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_train2017.json') 59 | elif self._mode == COCO2017Person.Mode.EVAL: 60 | path_to_jpeg_images_dir = os.path.join(path_to_coco_dir, 'val2017') 61 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 62 | else: 63 | raise ValueError('invalid mode') 64 | 65 | coco_dataset = CocoDetection(root=path_to_jpeg_images_dir, annFile=path_to_annotation) 66 | 67 | if os.path.exists(path_to_image_ids_pickle) and os.path.exists(path_to_image_id_dict_pickle): 68 | print('loading cache files...') 69 | 70 | with open(path_to_image_ids_pickle, 'rb') as f: 71 | 
self._image_ids = pickle.load(f) 72 | 73 | with open(path_to_image_id_dict_pickle, 'rb') as f: 74 | self._image_id_to_annotation_dict = pickle.load(f) 75 | else: 76 | print('generating cache files...') 77 | 78 | os.makedirs(path_to_caches_dir, exist_ok=True) 79 | 80 | self._image_id_to_annotation_dict: Dict[str, COCO2017Person.Annotation] = {} 81 | for idx, (image, annotation) in enumerate(tqdm(coco_dataset)): 82 | if len(annotation) > 0: 83 | image_id = str(annotation[0]['image_id']) # all image_id in annotation are the same 84 | annotation = COCO2017Person.Annotation( 85 | filename=os.path.join(path_to_jpeg_images_dir, '{:012d}.jpg'.format(int(image_id))), 86 | objects=[COCO2017Person.Annotation.Object( 87 | bbox=BBox( # `ann['bbox']` is in the format [left, top, width, height] 88 | left=ann['bbox'][0], 89 | top=ann['bbox'][1], 90 | right=ann['bbox'][0] + ann['bbox'][2], 91 | bottom=ann['bbox'][1] + ann['bbox'][3] 92 | ), 93 | label=ann['category_id']) 94 | for ann in annotation] 95 | ) 96 | annotation.objects = [obj for obj in annotation.objects 97 | if obj.label in [COCO2017.CATEGORY_TO_LABEL_DICT['person']]] # filtering label should refer to original `COCO2017` dataset 98 | 99 | if len(annotation.objects) > 0: 100 | self._image_id_to_annotation_dict[image_id] = annotation 101 | 102 | self._image_ids = list(self._image_id_to_annotation_dict.keys()) 103 | 104 | with open(path_to_image_ids_pickle, 'wb') as f: 105 | pickle.dump(self._image_ids, f) 106 | 107 | with open(path_to_image_id_dict_pickle, 'wb') as f: 108 | pickle.dump(self._image_id_to_annotation_dict, f) 109 | 110 | def __len__(self) -> int: 111 | return len(self._image_id_to_annotation_dict) 112 | 113 | def __getitem__(self, index: int) -> Tuple[str, Tensor, float, Tensor, Tensor]: 114 | image_id = self._image_ids[index] 115 | annotation = self._image_id_to_annotation_dict[image_id] 116 | 117 | bboxes = [obj.bbox.tolist() for obj in annotation.objects] 118 | labels = [COCO2017Person.CATEGORY_TO_LABEL_DICT[COCO2017.LABEL_TO_CATEGORY_DICT[obj.label]] for obj in annotation.objects] # mapping from original `COCO2017` dataset 119 | 120 | bboxes = torch.tensor(bboxes, dtype=torch.float) 121 | labels = torch.tensor(labels, dtype=torch.long) 122 | 123 | image = Image.open(annotation.filename).convert('RGB') # for some grayscale images 124 | 125 | # random flip on only training mode 126 | if self._mode == COCO2017Person.Mode.TRAIN and random.random() > 0.5: 127 | image = ImageOps.mirror(image) 128 | bboxes[:, [0, 2]] = image.width - bboxes[:, [2, 0]] # index 0 and 2 represent `left` and `right` respectively 129 | 130 | image, scale = COCO2017Person.preprocess(image, self._image_min_side, self._image_max_side) 131 | bboxes *= scale 132 | 133 | return image_id, image, scale, bboxes, labels 134 | 135 | def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]: 136 | self._write_results(path_to_results_dir, image_ids, bboxes, classes, probs) 137 | 138 | annType = 'bbox' 139 | path_to_coco_dir = os.path.join(self._path_to_data_dir, 'COCO') 140 | path_to_annotations_dir = os.path.join(path_to_coco_dir, 'annotations') 141 | path_to_annotation = os.path.join(path_to_annotations_dir, 'instances_val2017.json') 142 | 143 | cocoGt = COCO(path_to_annotation) 144 | cocoDt = cocoGt.loadRes(os.path.join(path_to_results_dir, 'results.json')) 145 | 146 | cocoEval = COCOeval(cocoGt, cocoDt, annType) 147 | cocoEval.params.catIds = 
COCO2017.CATEGORY_TO_LABEL_DICT['person'] # filtering label should refer to original `COCO2017` dataset 148 | cocoEval.evaluate() 149 | cocoEval.accumulate() 150 | 151 | original_stdout = sys.stdout 152 | string_stdout = StringIO() 153 | sys.stdout = string_stdout 154 | cocoEval.summarize() 155 | sys.stdout = original_stdout 156 | 157 | mean_ap = cocoEval.stats[0].item() # stats[0] records AP@[0.5:0.95] 158 | detail = string_stdout.getvalue() 159 | 160 | return mean_ap, detail 161 | 162 | def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]): 163 | results = [] 164 | for image_id, bbox, cls, prob in zip(image_ids, bboxes, classes, probs): 165 | results.append( 166 | { 167 | 'image_id': int(image_id), # COCO evaluation requires `image_id` to be type `int` 168 | 'category_id': COCO2017.CATEGORY_TO_LABEL_DICT[COCO2017Person.LABEL_TO_CATEGORY_DICT[cls]], # mapping to original `COCO2017` dataset 169 | 'bbox': [ # format [left, top, width, height] is expected 170 | bbox[0], 171 | bbox[1], 172 | bbox[2] - bbox[0], 173 | bbox[3] - bbox[1] 174 | ], 175 | 'score': prob 176 | } 177 | ) 178 | 179 | with open(os.path.join(path_to_results_dir, 'results.json'), 'w') as f: 180 | json.dump(results, f) 181 | 182 | @staticmethod 183 | def num_classes() -> int: 184 | return 2 185 | -------------------------------------------------------------------------------- /dataset/voc2007.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import xml.etree.ElementTree as ET 4 | from typing import List, Tuple 5 | 6 | import numpy as np 7 | import torch.utils.data 8 | from PIL import Image, ImageOps 9 | from torch import Tensor 10 | 11 | from bbox import BBox 12 | from dataset.base import Base 13 | from voc_eval import voc_eval 14 | 15 | 16 | class VOC2007(Base): 17 | 18 | class Annotation(object): 19 | class Object(object): 20 | def __init__(self, name: str, difficult: bool, bbox: BBox): 21 | super().__init__() 22 | self.name = name 23 | self.difficult = difficult 24 | self.bbox = bbox 25 | 26 | def __repr__(self) -> str: 27 | return 'Object[name={:s}, difficult={!s}, bbox={!s}]'.format( 28 | self.name, self.difficult, self.bbox) 29 | 30 | def __init__(self, filename: str, objects: List[Object]): 31 | super().__init__() 32 | self.filename = filename 33 | self.objects = objects 34 | 35 | CATEGORY_TO_LABEL_DICT = { 36 | 'background': 0, 37 | 'aeroplane': 1, 'bicycle': 2, 'bird': 3, 'boat': 4, 'bottle': 5, 38 | 'bus': 6, 'car': 7, 'cat': 8, 'chair': 9, 'cow': 10, 39 | 'diningtable': 11, 'dog': 12, 'horse': 13, 'motorbike': 14, 'person': 15, 40 | 'pottedplant': 16, 'sheep': 17, 'sofa': 18, 'train': 19, 'tvmonitor': 20 41 | } 42 | 43 | LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()} 44 | 45 | def __init__(self, path_to_data_dir: str, mode: Base.Mode, image_min_side: float, image_max_side: float): 46 | super().__init__(path_to_data_dir, mode, image_min_side, image_max_side) 47 | 48 | path_to_voc2007_dir = os.path.join(self._path_to_data_dir, 'VOCdevkit', 'VOC2007') 49 | path_to_imagesets_main_dir = os.path.join(path_to_voc2007_dir, 'ImageSets', 'Main') 50 | path_to_annotations_dir = os.path.join(path_to_voc2007_dir, 'Annotations') 51 | self._path_to_jpeg_images_dir = os.path.join(path_to_voc2007_dir, 'JPEGImages') 52 | 53 | if self._mode == VOC2007.Mode.TRAIN: 54 | path_to_image_ids_txt = os.path.join(path_to_imagesets_main_dir, 'trainval.txt') 55 
| elif self._mode == VOC2007.Mode.EVAL: 56 | path_to_image_ids_txt = os.path.join(path_to_imagesets_main_dir, 'test.txt') 57 | else: 58 | raise ValueError('invalid mode') 59 | 60 | with open(path_to_image_ids_txt, 'r') as f: 61 | lines = f.readlines() 62 | self._image_ids = [line.rstrip() for line in lines] 63 | 64 | self._image_id_to_annotation_dict = {} 65 | for image_id in self._image_ids: 66 | path_to_annotation_xml = os.path.join(path_to_annotations_dir, f'{image_id}.xml') 67 | tree = ET.ElementTree(file=path_to_annotation_xml) 68 | root = tree.getroot() 69 | 70 | self._image_id_to_annotation_dict[image_id] = VOC2007.Annotation( 71 | filename=next(root.iterfind('filename')).text, 72 | objects=[VOC2007.Annotation.Object(name=next(tag_object.iterfind('name')).text, 73 | difficult=next(tag_object.iterfind('difficult')).text == '1', 74 | bbox=BBox( 75 | left=float(next(tag_object.iterfind('bndbox/xmin')).text), 76 | top=float(next(tag_object.iterfind('bndbox/ymin')).text), 77 | right=float(next(tag_object.iterfind('bndbox/xmax')).text), 78 | bottom=float(next(tag_object.iterfind('bndbox/ymax')).text)) 79 | ) 80 | for tag_object in root.iterfind('object')] 81 | ) 82 | 83 | def __len__(self) -> int: 84 | return len(self._image_id_to_annotation_dict) 85 | 86 | def __getitem__(self, index: int) -> Tuple[str, Tensor, float, Tensor, Tensor]: 87 | image_id = self._image_ids[index] 88 | annotation = self._image_id_to_annotation_dict[image_id] 89 | 90 | bboxes = [obj.bbox.tolist() for obj in annotation.objects if not obj.difficult] 91 | labels = [VOC2007.CATEGORY_TO_LABEL_DICT[obj.name] for obj in annotation.objects if not obj.difficult] 92 | 93 | bboxes = torch.tensor(bboxes, dtype=torch.float) 94 | labels = torch.tensor(labels, dtype=torch.long) 95 | 96 | image = Image.open(os.path.join(self._path_to_jpeg_images_dir, annotation.filename)) 97 | 98 | # random flip on only training mode 99 | if self._mode == VOC2007.Mode.TRAIN and random.random() > 0.5: 100 | image = ImageOps.mirror(image) 101 | bboxes[:, [0, 2]] = image.width - bboxes[:, [2, 0]] # index 0 and 2 represent `left` and `right` respectively 102 | 103 | image, scale = VOC2007.preprocess(image, self._image_min_side, self._image_max_side) 104 | bboxes *= scale 105 | 106 | return image_id, image, scale, bboxes, labels 107 | 108 | def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]: 109 | self._write_results(path_to_results_dir, image_ids, bboxes, classes, probs) 110 | 111 | path_to_voc2007_dir = os.path.join(self._path_to_data_dir, 'VOCdevkit', 'VOC2007') 112 | path_to_main_dir = os.path.join(path_to_voc2007_dir, 'ImageSets', 'Main') 113 | path_to_annotations_dir = os.path.join(path_to_voc2007_dir, 'Annotations') 114 | 115 | class_to_ap_dict = {} 116 | for c in range(1, VOC2007.num_classes()): 117 | category = VOC2007.LABEL_TO_CATEGORY_DICT[c] 118 | try: 119 | path_to_cache_dir = os.path.join('caches', 'voc2007') 120 | os.makedirs(path_to_cache_dir, exist_ok=True) 121 | _, _, ap = voc_eval(detpath=os.path.join(path_to_results_dir, 'comp3_det_test_{:s}.txt'.format(category)), 122 | annopath=os.path.join(path_to_annotations_dir, '{:s}.xml'), 123 | imagesetfile=os.path.join(path_to_main_dir, 'test.txt'), 124 | classname=category, 125 | cachedir=path_to_cache_dir, 126 | ovthresh=0.5, 127 | use_07_metric=True) 128 | except IndexError: 129 | ap = 0 130 | 131 | class_to_ap_dict[c] = ap 132 | 133 | mean_ap = np.mean([v for k, v in 
class_to_ap_dict.items()]).item() 134 | 135 | detail = '' 136 | for c in range(1, VOC2007.num_classes()): 137 | detail += '{:d}: {:s} AP = {:.4f}\n'.format(c, VOC2007.LABEL_TO_CATEGORY_DICT[c], class_to_ap_dict[c]) 138 | 139 | return mean_ap, detail 140 | 141 | def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]): 142 | class_to_txt_files_dict = {} 143 | for c in range(1, VOC2007.num_classes()): 144 | class_to_txt_files_dict[c] = open(os.path.join(path_to_results_dir, 'comp3_det_test_{:s}.txt'.format(VOC2007.LABEL_TO_CATEGORY_DICT[c])), 'w') 145 | 146 | for image_id, bbox, cls, prob in zip(image_ids, bboxes, classes, probs): 147 | class_to_txt_files_dict[cls].write('{:s} {:f} {:f} {:f} {:f} {:f}\n'.format(image_id, prob, 148 | bbox[0], bbox[1], bbox[2], bbox[3])) 149 | 150 | for _, f in class_to_txt_files_dict.items(): 151 | f.close() 152 | 153 | @staticmethod 154 | def num_classes() -> int: 155 | return 21 156 | -------------------------------------------------------------------------------- /dataset/voc2007_cat_dog.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import xml.etree.ElementTree as ET 4 | from typing import List, Tuple 5 | 6 | import numpy as np 7 | import torch.utils.data 8 | from PIL import Image, ImageOps 9 | from torch import Tensor 10 | 11 | from bbox import BBox 12 | from dataset.base import Base 13 | from voc_eval import voc_eval 14 | 15 | 16 | class VOC2007CatDog(Base): 17 | 18 | class Annotation(object): 19 | class Object(object): 20 | def __init__(self, name: str, difficult: bool, bbox: BBox): 21 | super().__init__() 22 | self.name = name 23 | self.difficult = difficult 24 | self.bbox = bbox 25 | 26 | def __repr__(self) -> str: 27 | return 'Object[name={:s}, difficult={!s}, bbox={!s}]'.format( 28 | self.name, self.difficult, self.bbox) 29 | 30 | def __init__(self, filename: str, objects: List[Object]): 31 | super().__init__() 32 | self.filename = filename 33 | self.objects = objects 34 | 35 | CATEGORY_TO_LABEL_DICT = { 36 | 'background': 0, 37 | 'cat': 1, 'dog': 2 38 | } 39 | 40 | LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()} 41 | 42 | def __init__(self, path_to_data_dir: str, mode: Base.Mode, image_min_side: float, image_max_side: float): 43 | super().__init__(path_to_data_dir, mode, image_min_side, image_max_side) 44 | 45 | path_to_voc2007_dir = os.path.join(self._path_to_data_dir, 'VOCdevkit', 'VOC2007') 46 | path_to_imagesets_main_dir = os.path.join(path_to_voc2007_dir, 'ImageSets', 'Main') 47 | path_to_annotations_dir = os.path.join(path_to_voc2007_dir, 'Annotations') 48 | self._path_to_jpeg_images_dir = os.path.join(path_to_voc2007_dir, 'JPEGImages') 49 | 50 | if self._mode == VOC2007CatDog.Mode.TRAIN: 51 | path_to_image_ids_txt = os.path.join(path_to_imagesets_main_dir, 'trainval.txt') 52 | elif self._mode == VOC2007CatDog.Mode.EVAL: 53 | path_to_image_ids_txt = os.path.join(path_to_imagesets_main_dir, 'test.txt') 54 | else: 55 | raise ValueError('invalid mode') 56 | 57 | with open(path_to_image_ids_txt, 'r') as f: 58 | lines = f.readlines() 59 | image_ids = [line.rstrip() for line in lines] 60 | 61 | self._image_id_to_annotation_dict = {} 62 | for image_id in image_ids: 63 | path_to_annotation_xml = os.path.join(path_to_annotations_dir, f'{image_id}.xml') 64 | tree = ET.ElementTree(file=path_to_annotation_xml) 65 | root = tree.getroot() 66 | 67 | annotation = 
VOC2007CatDog.Annotation( 68 | filename=next(root.iterfind('filename')).text, 69 | objects=[VOC2007CatDog.Annotation.Object(name=next(tag_object.iterfind('name')).text, 70 | difficult=next(tag_object.iterfind('difficult')).text == '1', 71 | bbox=BBox( 72 | left=float(next(tag_object.iterfind('bndbox/xmin')).text), 73 | top=float(next(tag_object.iterfind('bndbox/ymin')).text), 74 | right=float(next(tag_object.iterfind('bndbox/xmax')).text), 75 | bottom=float(next(tag_object.iterfind('bndbox/ymax')).text)) 76 | ) 77 | for tag_object in root.iterfind('object')] 78 | ) 79 | annotation.objects = [obj for obj in annotation.objects if obj.name in ['cat', 'dog'] and not obj.difficult] 80 | 81 | if len(annotation.objects) > 0: 82 | self._image_id_to_annotation_dict[image_id] = annotation 83 | 84 | self._image_ids = list(self._image_id_to_annotation_dict.keys()) 85 | 86 | def __len__(self) -> int: 87 | return len(self._image_id_to_annotation_dict) 88 | 89 | def __getitem__(self, index: int) -> Tuple[str, Tensor, float, Tensor, Tensor]: 90 | image_id = self._image_ids[index] 91 | annotation = self._image_id_to_annotation_dict[image_id] 92 | 93 | bboxes = [obj.bbox.tolist() for obj in annotation.objects] 94 | labels = [VOC2007CatDog.CATEGORY_TO_LABEL_DICT[obj.name] for obj in annotation.objects] 95 | 96 | bboxes = torch.tensor(bboxes, dtype=torch.float) 97 | labels = torch.tensor(labels, dtype=torch.long) 98 | 99 | image = Image.open(os.path.join(self._path_to_jpeg_images_dir, annotation.filename)) 100 | 101 | # random flip on only training mode 102 | if self._mode == VOC2007CatDog.Mode.TRAIN and random.random() > 0.5: 103 | image = ImageOps.mirror(image) 104 | bboxes[:, [0, 2]] = image.width - bboxes[:, [2, 0]] # index 0 and 2 represent `left` and `right` respectively 105 | 106 | image, scale = VOC2007CatDog.preprocess(image, self._image_min_side, self._image_max_side) 107 | bboxes *= scale 108 | 109 | return image_id, image, scale, bboxes, labels 110 | 111 | def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]: 112 | self._write_results(path_to_results_dir, image_ids, bboxes, classes, probs) 113 | 114 | path_to_voc2007_dir = os.path.join(self._path_to_data_dir, 'VOCdevkit', 'VOC2007') 115 | path_to_main_dir = os.path.join(path_to_voc2007_dir, 'ImageSets', 'Main') 116 | path_to_annotations_dir = os.path.join(path_to_voc2007_dir, 'Annotations') 117 | 118 | class_to_ap_dict = {} 119 | for c in range(1, VOC2007CatDog.num_classes()): 120 | category = VOC2007CatDog.LABEL_TO_CATEGORY_DICT[c] 121 | try: 122 | path_to_cache_dir = os.path.join('caches', 'voc2007-cat-dog') 123 | os.makedirs(path_to_cache_dir, exist_ok=True) 124 | _, _, ap = voc_eval(detpath=os.path.join(path_to_results_dir, 'comp3_det_test_{:s}.txt'.format(category)), 125 | annopath=os.path.join(path_to_annotations_dir, '{:s}.xml'), 126 | imagesetfile=os.path.join(path_to_main_dir, 'test.txt'), 127 | classname=category, 128 | cachedir=path_to_cache_dir, 129 | ovthresh=0.5, 130 | use_07_metric=True) 131 | except IndexError: 132 | ap = 0 133 | 134 | class_to_ap_dict[c] = ap 135 | 136 | mean_ap = np.mean([v for k, v in class_to_ap_dict.items()]).item() 137 | 138 | detail = '' 139 | for c in range(1, VOC2007CatDog.num_classes()): 140 | detail += '{:d}: {:s} AP = {:.4f}\n'.format(c, VOC2007CatDog.LABEL_TO_CATEGORY_DICT[c], class_to_ap_dict[c]) 141 | 142 | return mean_ap, detail 143 | 144 | def _write_results(self, path_to_results_dir: str, 
image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]): 145 | class_to_txt_files_dict = {} 146 | for c in range(1, VOC2007CatDog.num_classes()): 147 | class_to_txt_files_dict[c] = open(os.path.join(path_to_results_dir, 'comp3_det_test_{:s}.txt'.format(VOC2007CatDog.LABEL_TO_CATEGORY_DICT[c])), 'w') 148 | 149 | for image_id, bbox, cls, prob in zip(image_ids, bboxes, classes, probs): 150 | class_to_txt_files_dict[cls].write('{:s} {:f} {:f} {:f} {:f} {:f}\n'.format(image_id, prob, 151 | bbox[0], bbox[1], bbox[2], bbox[3])) 152 | 153 | for _, f in class_to_txt_files_dict.items(): 154 | f.close() 155 | 156 | @staticmethod 157 | def num_classes() -> int: 158 | return 3 159 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import uuid 6 | 7 | from backbone.base import Base as BackboneBase 8 | from config.eval_config import EvalConfig as Config 9 | from dataset.base import Base as DatasetBase 10 | from evaluator import Evaluator 11 | from logger import Logger as Log 12 | from model import Model 13 | from roi.wrapper import Wrapper as ROIWrapper 14 | 15 | 16 | def _eval(path_to_checkpoint: str, dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_results_dir: str): 17 | dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.EVAL, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE) 18 | evaluator = Evaluator(dataset, path_to_data_dir, path_to_results_dir) 19 | 20 | Log.i('Found {:d} samples'.format(len(dataset))) 21 | 22 | backbone = BackboneBase.from_name(backbone_name)(pretrained=False) 23 | model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE, 24 | anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES, 25 | rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda() 26 | model.load(path_to_checkpoint) 27 | 28 | mean_ap, detail = evaluator.evaluate(model) 29 | 30 | Log.i('mean AP = {:.4f}'.format(mean_ap)) 31 | Log.i('\n' + detail) 32 | 33 | 34 | if __name__ == '__main__': 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('checkpoint', type=str, help='path to evaluating checkpoint') 38 | parser.add_argument('-s', '--dataset', type=str, choices=DatasetBase.OPTIONS, required=True, help='name of dataset') 39 | parser.add_argument('-b', '--backbone', type=str, choices=BackboneBase.OPTIONS, required=True, help='name of backbone model') 40 | parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to data directory') 41 | parser.add_argument('--image_min_side', type=float, help='default: {:g}'.format(Config.IMAGE_MIN_SIDE)) 42 | parser.add_argument('--image_max_side', type=float, help='default: {:g}'.format(Config.IMAGE_MAX_SIDE)) 43 | parser.add_argument('--anchor_ratios', type=str, help='default: "{!s}"'.format(Config.ANCHOR_RATIOS)) 44 | parser.add_argument('--anchor_scales', type=str, help='default: "{!s}"'.format(Config.ANCHOR_SCALES)) 45 | parser.add_argument('--pooling_mode', type=str, choices=ROIWrapper.OPTIONS, help='default: {.value:s}'.format(Config.POOLING_MODE)) 46 | parser.add_argument('--rpn_pre_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_PRE_NMS_TOP_N)) 47 | parser.add_argument('--rpn_post_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_POST_NMS_TOP_N)) 48 | args = 
parser.parse_args() 49 | 50 | path_to_checkpoint = args.checkpoint 51 | dataset_name = args.dataset 52 | backbone_name = args.backbone 53 | path_to_data_dir = args.data_dir 54 | 55 | path_to_results_dir = os.path.join(os.path.dirname(path_to_checkpoint), 'results-{:s}-{:s}-{:s}'.format( 56 | time.strftime('%Y%m%d%H%M%S'), path_to_checkpoint.split(os.path.sep)[-1].split(os.path.curdir)[0], 57 | str(uuid.uuid4()).split('-')[0])) 58 | os.makedirs(path_to_results_dir) 59 | 60 | Config.setup(image_min_side=args.image_min_side, image_max_side=args.image_max_side, 61 | anchor_ratios=args.anchor_ratios, anchor_scales=args.anchor_scales, pooling_mode=args.pooling_mode, 62 | rpn_pre_nms_top_n=args.rpn_pre_nms_top_n, rpn_post_nms_top_n=args.rpn_post_nms_top_n) 63 | 64 | Log.initialize(os.path.join(path_to_results_dir, 'eval.log')) 65 | Log.i('Arguments:') 66 | for k, v in vars(args).items(): 67 | Log.i(f'\t{k} = {v}') 68 | Log.i(Config.describe()) 69 | 70 | _eval(path_to_checkpoint, dataset_name, backbone_name, path_to_data_dir, path_to_results_dir) 71 | 72 | main() 73 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | 7 | from dataset.base import Base as DatasetBase 8 | from model import Model 9 | 10 | 11 | class Evaluator(object): 12 | def __init__(self, dataset: DatasetBase, path_to_data_dir: str, path_to_results_dir: str): 13 | super().__init__() 14 | self._dataset = dataset 15 | self._dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) 16 | self._path_to_data_dir = path_to_data_dir 17 | self._path_to_results_dir = path_to_results_dir 18 | 19 | def evaluate(self, model: Model) -> Tuple[float, str]: 20 | all_image_ids, all_detection_bboxes, all_detection_classes, all_detection_probs = [], [], [], [] 21 | 22 | with torch.no_grad(): 23 | for batch_index, (image_id_batch, image_batch, scale_batch, _, _) in enumerate(tqdm(self._dataloader)): 24 | image_id = image_id_batch[0] 25 | image = image_batch[0].cuda() 26 | scale = scale_batch[0].item() 27 | 28 | forward_input = Model.ForwardInput.Eval(image) 29 | forward_output: Model.ForwardOutput.Eval = model.eval().forward(forward_input) 30 | 31 | detection_bboxes, detection_classes, detection_probs = forward_output 32 | detection_bboxes /= scale 33 | 34 | selected_indices = (detection_probs > 0.05).nonzero().view(-1) 35 | detection_bboxes = detection_bboxes[selected_indices] 36 | detection_classes = detection_classes[selected_indices] 37 | detection_probs = detection_probs[selected_indices] 38 | 39 | all_detection_bboxes.extend(detection_bboxes.tolist()) 40 | all_detection_classes.extend(detection_classes.tolist()) 41 | all_detection_probs.extend(detection_probs.tolist()) 42 | all_image_ids.extend([image_id] * len(detection_bboxes)) 43 | 44 | mean_ap, detail = self._dataset.evaluate(self._path_to_results_dir, all_image_ids, all_detection_bboxes, all_detection_classes, all_detection_probs) 45 | return mean_ap, detail 46 | -------------------------------------------------------------------------------- /images/feature-pyramid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/images/feature-pyramid.png 
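images/feature-pyramid.png illustrates the top-down pathway with lateral connections that model.py builds further below. The following is a rough standalone sketch of the same idea, assuming the standard FPN choices of 1x1 lateral convolutions and 256 output channels (the detection head in model.py expects 256-channel features) and ResNet-50/101-style channel counts for C2-C5; the class and attribute names are illustrative only and not part of this repository:

from torch import nn
from torch.nn import functional as F

class FeaturePyramidSketch(nn.Module):
    # 1x1 lateral convs project C2..C5 to a common width; each level is upsampled
    # (nearest interpolation) and added to the lateral map one level down, and a 3x3
    # conv then reduces the aliasing introduced by upsampling.
    def __init__(self, in_channels=(256, 512, 1024, 2048), out_channels=256):
        super().__init__()
        self.laterals = nn.ModuleList([nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels])
        self.dealias = nn.ModuleList([nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) for _ in range(3)])

    def forward(self, c2, c3, c4, c5):
        p5 = self.laterals[3](c5)
        p4 = self.laterals[2](c4) + F.interpolate(p5, size=c4.shape[2:], mode='nearest')
        p3 = self.laterals[1](c3) + F.interpolate(p4, size=c3.shape[2:], mode='nearest')
        p2 = self.laterals[0](c2) + F.interpolate(p3, size=c2.shape[2:], mode='nearest')
        p2, p3, p4 = self.dealias[0](p2), self.dealias[1](p3), self.dealias[2](p4)
        p6 = F.max_pool2d(p5, kernel_size=1, stride=2)  # extra coarse level, used only by the RPN in model.py
        return p2, p3, p4, p5, p6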
-------------------------------------------------------------------------------- /images/inference-result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/images/inference-result.jpg -------------------------------------------------------------------------------- /images/inference-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/images/inference-sample.jpg -------------------------------------------------------------------------------- /images/nms_cuda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/images/nms_cuda.png -------------------------------------------------------------------------------- /images/rpn_find_labels_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/images/rpn_find_labels_1.png -------------------------------------------------------------------------------- /images/rpn_find_labels_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/images/rpn_find_labels_2.png -------------------------------------------------------------------------------- /images/test_nms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/images/test_nms.png -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | from PIL import ImageDraw 6 | from torchvision.transforms import transforms 7 | from dataset.base import Base as DatasetBase 8 | from backbone.base import Base as BackboneBase 9 | from bbox import BBox 10 | from model import Model 11 | from roi.wrapper import Wrapper as ROIWrapper 12 | from config.eval_config import EvalConfig as Config 13 | 14 | 15 | def _infer(path_to_input_image: str, path_to_output_image: str, path_to_checkpoint: str, dataset_name: str, backbone_name: str, prob_thresh: float): 16 | image = transforms.Image.open(path_to_input_image) 17 | dataset_class = DatasetBase.from_name(dataset_name) 18 | image_tensor, scale = dataset_class.preprocess(image, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE) 19 | 20 | backbone = BackboneBase.from_name(backbone_name)(pretrained=False) 21 | model = Model(backbone, dataset_class.num_classes(), pooling_mode=Config.POOLING_MODE, 22 | anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES, 23 | rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda() 24 | model.load(path_to_checkpoint) 25 | 26 | forward_input = Model.ForwardInput.Eval(image_tensor.cuda()) 27 | forward_output: Model.ForwardOutput.Eval = model.eval().forward(forward_input) 28 | 29 | detection_bboxes = forward_output.detection_bboxes / scale 30 | detection_classes = forward_output.detection_classes 31 | detection_probs = 
forward_output.detection_probs 32 | 33 | kept_indices = detection_probs > prob_thresh 34 | detection_bboxes = detection_bboxes[kept_indices] 35 | detection_classes = detection_classes[kept_indices] 36 | detection_probs = detection_probs[kept_indices] 37 | 38 | draw = ImageDraw.Draw(image) 39 | 40 | for bbox, cls, prob in zip(detection_bboxes.tolist(), detection_classes.tolist(), detection_probs.tolist()): 41 | color = random.choice(['red', 'green', 'blue', 'yellow', 'purple', 'white']) 42 | bbox = BBox(left=bbox[0], top=bbox[1], right=bbox[2], bottom=bbox[3]) 43 | category = dataset_class.LABEL_TO_CATEGORY_DICT[cls] 44 | 45 | draw.rectangle(((bbox.left, bbox.top), (bbox.right, bbox.bottom)), outline=color) 46 | draw.text((bbox.left, bbox.top), text=f'{category:s} {prob:.3f}', fill=color) 47 | 48 | image.save(path_to_output_image) 49 | print(f'Output image is saved to {path_to_output_image}') 50 | 51 | 52 | if __name__ == '__main__': 53 | def main(): 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('input', type=str, help='path to input image') 56 | parser.add_argument('output', type=str, help='path to output result image') 57 | parser.add_argument('-c', '--checkpoint', type=str, required=True, help='path to checkpoint') 58 | parser.add_argument('-s', '--dataset', type=str, choices=DatasetBase.OPTIONS, required=True, help='name of dataset') 59 | parser.add_argument('-b', '--backbone', type=str, choices=BackboneBase.OPTIONS, required=True, help='name of backbone model') 60 | parser.add_argument('-p', '--probability_threshold', type=float, default=0.6, help='threshold of detection probability') 61 | parser.add_argument('--image_min_side', type=float, help='default: {:g}'.format(Config.IMAGE_MIN_SIDE)) 62 | parser.add_argument('--image_max_side', type=float, help='default: {:g}'.format(Config.IMAGE_MAX_SIDE)) 63 | parser.add_argument('--anchor_ratios', type=str, help='default: "{!s}"'.format(Config.ANCHOR_RATIOS)) 64 | parser.add_argument('--anchor_scales', type=str, help='default: "{!s}"'.format(Config.ANCHOR_SCALES)) 65 | parser.add_argument('--pooling_mode', type=str, choices=ROIWrapper.OPTIONS, help='default: {.value:s}'.format(Config.POOLING_MODE)) 66 | parser.add_argument('--rpn_pre_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_PRE_NMS_TOP_N)) 67 | parser.add_argument('--rpn_post_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_POST_NMS_TOP_N)) 68 | args = parser.parse_args() 69 | 70 | path_to_input_image = args.input 71 | path_to_output_image = args.output 72 | path_to_checkpoint = args.checkpoint 73 | dataset_name = args.dataset 74 | backbone_name = args.backbone 75 | prob_thresh = args.probability_threshold 76 | 77 | os.makedirs(os.path.join(os.path.curdir, os.path.dirname(path_to_output_image)), exist_ok=True) 78 | 79 | Config.setup(image_min_side=args.image_min_side, image_max_side=args.image_max_side, 80 | anchor_ratios=args.anchor_ratios, anchor_scales=args.anchor_scales, pooling_mode=args.pooling_mode, 81 | rpn_pre_nms_top_n=args.rpn_pre_nms_top_n, rpn_post_nms_top_n=args.rpn_post_nms_top_n) 82 | 83 | print('Arguments:') 84 | for k, v in vars(args).items(): 85 | print(f'\t{k} = {v}') 86 | print(Config.describe()) 87 | 88 | _infer(path_to_input_image, path_to_output_image, path_to_checkpoint, dataset_name, backbone_name, prob_thresh) 89 | 90 | main() 91 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import 
logging 2 | 3 | 4 | class Logger(object): 5 | Initialized = False 6 | 7 | @staticmethod 8 | def initialize(path_to_log_file): 9 | logging.basicConfig(level=logging.INFO, 10 | format='%(asctime)s %(levelname)-8s %(message)s', 11 | datefmt='%Y-%m-%d %H:%M:%S', 12 | handlers=[logging.FileHandler(path_to_log_file), 13 | logging.StreamHandler()]) 14 | Logger.Initialized = True 15 | 16 | @staticmethod 17 | def log(level, message): 18 | assert Logger.Initialized, 'Logger has not been initialized' 19 | logging.log(level, message) 20 | 21 | @staticmethod 22 | def d(message): 23 | Logger.log(logging.DEBUG, message) 24 | 25 | @staticmethod 26 | def i(message): 27 | Logger.log(logging.INFO, message) 28 | 29 | @staticmethod 30 | def w(message): 31 | Logger.log(logging.WARNING, message) 32 | 33 | @staticmethod 34 | def e(message): 35 | Logger.log(logging.ERROR, message) 36 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, Tuple, List, NamedTuple 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | from torch.nn import functional as F 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | 10 | from backbone.base import Base as BackboneBase 11 | from bbox import BBox 12 | from nms.nms import NMS 13 | from roi.wrapper import Wrapper as ROIWrapper 14 | from rpn.region_proposal_network import RegionProposalNetwork 15 | 16 | 17 | class Model(nn.Module): 18 | 19 | class ForwardInput(object): 20 | class Train(NamedTuple): 21 | image: Tensor 22 | gt_classes: Tensor 23 | gt_bboxes: Tensor 24 | 25 | class Eval(NamedTuple): 26 | image: Tensor 27 | 28 | class ForwardOutput(object): 29 | class Train(NamedTuple): 30 | anchor_objectness_loss: Tensor 31 | anchor_transformer_loss: Tensor 32 | proposal_class_loss: Tensor 33 | proposal_transformer_loss: Tensor 34 | 35 | class Eval(NamedTuple): 36 | detection_bboxes: Tensor 37 | detection_classes: Tensor 38 | detection_probs: Tensor 39 | 40 | def __init__(self, backbone: BackboneBase, num_classes: int, pooling_mode: ROIWrapper.Mode, 41 | anchor_ratios: List[Tuple[int, int]], anchor_scales: List[int], rpn_pre_nms_top_n: int, rpn_post_nms_top_n: int): 42 | super().__init__() 43 | 44 | conv_layers, lateral_layers, dealiasing_layers, num_features_out = backbone.features() 45 | self.conv1, self.conv2, self.conv3, self.conv4, self.conv5 = conv_layers 46 | self.lateral_c2, self.lateral_c3, self.lateral_c4, self.lateral_c5 = lateral_layers 47 | self.dealiasing_p2, self.dealiasing_p3, self.dealiasing_p4 = dealiasing_layers 48 | 49 | self._bn_modules = [it for it in self.conv1.modules() if isinstance(it, nn.BatchNorm2d)] + \ 50 | [it for it in self.conv2.modules() if isinstance(it, nn.BatchNorm2d)] + \ 51 | [it for it in self.conv3.modules() if isinstance(it, nn.BatchNorm2d)] + \ 52 | [it for it in self.conv4.modules() if isinstance(it, nn.BatchNorm2d)] + \ 53 | [it for it in self.conv5.modules() if isinstance(it, nn.BatchNorm2d)] + \ 54 | [it for it in self.lateral_c2.modules() if isinstance(it, nn.BatchNorm2d)] + \ 55 | [it for it in self.lateral_c3.modules() if isinstance(it, nn.BatchNorm2d)] + \ 56 | [it for it in self.lateral_c4.modules() if isinstance(it, nn.BatchNorm2d)] + \ 57 | [it for it in self.lateral_c5.modules() if isinstance(it, nn.BatchNorm2d)] + \ 58 | [it for it in self.dealiasing_p2.modules() if isinstance(it, nn.BatchNorm2d)] + \ 59 | [it for it in 
self.dealiasing_p3.modules() if isinstance(it, nn.BatchNorm2d)] + \ 60 | [it for it in self.dealiasing_p4.modules() if isinstance(it, nn.BatchNorm2d)] 61 | 62 | self.num_classes = num_classes 63 | 64 | self.rpn = RegionProposalNetwork(num_features_out, anchor_ratios, anchor_scales, rpn_pre_nms_top_n, rpn_post_nms_top_n) 65 | self.detection = Model.Detection(pooling_mode, self.num_classes) 66 | 67 | def forward(self, forward_input: Union[ForwardInput.Train, ForwardInput.Eval]) -> Union[ForwardOutput.Train, ForwardOutput.Eval]: 68 | # freeze batch normalization modules for each forwarding process just in case model was switched to `train` at any time 69 | for bn_module in self._bn_modules: 70 | bn_module.eval() 71 | for parameter in bn_module.parameters(): 72 | parameter.requires_grad = False 73 | 74 | image = forward_input.image.unsqueeze(dim=0) 75 | image_height, image_width = image.shape[2], image.shape[3] 76 | 77 | # Bottom-up pathway 78 | c1 = self.conv1(image) 79 | c2 = self.conv2(c1) 80 | c3 = self.conv3(c2) 81 | c4 = self.conv4(c3) 82 | c5 = self.conv5(c4) 83 | 84 | # Top-down pathway and lateral connections 85 | p5 = self.lateral_c5(c5) 86 | p4 = self.lateral_c4(c4) + F.interpolate(input=p5, size=(c4.shape[2], c4.shape[3]), mode='nearest') 87 | p3 = self.lateral_c3(c3) + F.interpolate(input=p4, size=(c3.shape[2], c3.shape[3]), mode='nearest') 88 | p2 = self.lateral_c2(c2) + F.interpolate(input=p3, size=(c2.shape[2], c2.shape[3]), mode='nearest') 89 | 90 | # Reduce the aliasing effect 91 | p4 = self.dealiasing_p4(p4) 92 | p3 = self.dealiasing_p3(p3) 93 | p2 = self.dealiasing_p2(p2) 94 | 95 | p6 = F.max_pool2d(input=p5, kernel_size=1, stride=2) 96 | 97 | # NOTE: We define the anchors to have areas of {32^2, 64^2, 128^2, 256^2, 512^2} pixels on {P2, P3, P4, P5, P6} respectively 98 | 99 | anchor_objectnesses = [] 100 | anchor_transformers = [] 101 | anchor_bboxes = [] 102 | proposal_bboxes = [] 103 | 104 | for p, anchor_size in zip([p2, p3, p4, p5, p6], [32, 64, 128, 256, 512]): 105 | p_anchor_objectnesses, p_anchor_transformers = self.rpn.forward(features=p, image_width=image_width, image_height=image_height) 106 | p_anchor_bboxes = self.rpn.generate_anchors(image_width, image_height, 107 | num_x_anchors=p.shape[3], num_y_anchors=p.shape[2], 108 | anchor_size=anchor_size).cuda() 109 | p_proposal_bboxes = self.rpn.generate_proposals(p_anchor_bboxes, p_anchor_objectnesses, p_anchor_transformers, 110 | image_width, image_height) 111 | anchor_objectnesses.append(p_anchor_objectnesses) 112 | anchor_transformers.append(p_anchor_transformers) 113 | anchor_bboxes.append(p_anchor_bboxes) 114 | proposal_bboxes.append(p_proposal_bboxes) 115 | 116 | anchor_objectnesses = torch.cat(anchor_objectnesses, dim=0) 117 | anchor_transformers = torch.cat(anchor_transformers, dim=0) 118 | anchor_bboxes = torch.cat(anchor_bboxes, dim=0) 119 | proposal_bboxes = torch.cat(proposal_bboxes, dim=0) 120 | 121 | if self.training: 122 | forward_input: Model.ForwardInput.Train 123 | 124 | anchor_sample_fg_indices, anchor_sample_selected_indices, gt_anchor_objectnesses, gt_anchor_transformers = self.rpn.sample(anchor_bboxes, forward_input.gt_bboxes, image_width, image_height) 125 | anchor_objectnesses = anchor_objectnesses[anchor_sample_selected_indices] 126 | anchor_transformers = anchor_transformers[anchor_sample_fg_indices] 127 | anchor_objectness_loss, anchor_transformer_loss = self.rpn.loss(anchor_objectnesses, anchor_transformers, gt_anchor_objectnesses, gt_anchor_transformers) 128 | 129 | 
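# sample at most 128 proposals for the detection head (up to 32 foreground by IoU >= 0.5, background
# fills the rest) along with their class labels and normalized box-regression targets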
proposal_sample_fg_indices, proposal_sample_selected_indices, gt_proposal_classes, gt_proposal_transformers = self.detection.sample(proposal_bboxes, forward_input.gt_classes, forward_input.gt_bboxes) 130 | proposal_bboxes = proposal_bboxes[proposal_sample_selected_indices] 131 | proposal_classes, proposal_transformers = self.detection.forward(p2, p3, p4, p5, proposal_bboxes, image_width, image_height) 132 | proposal_class_loss, proposal_transformer_loss = self.detection.loss(proposal_classes, proposal_transformers, gt_proposal_classes, gt_proposal_transformers) 133 | 134 | forward_output = Model.ForwardOutput.Train(anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss) 135 | else: 136 | proposal_classes, proposal_transformers = self.detection.forward(p2, p3, p4, p5, proposal_bboxes, image_width, image_height) 137 | detection_bboxes, detection_classes, detection_probs = self.detection.generate_detections(proposal_bboxes, proposal_classes, proposal_transformers, image_width, image_height) 138 | forward_output = Model.ForwardOutput.Eval(detection_bboxes, detection_classes, detection_probs) 139 | 140 | return forward_output 141 | 142 | def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str: 143 | path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth') 144 | checkpoint = { 145 | 'state_dict': self.state_dict(), 146 | 'step': step, 147 | 'optimizer_state_dict': optimizer.state_dict(), 148 | 'scheduler_state_dict': scheduler.state_dict() 149 | } 150 | torch.save(checkpoint, path_to_checkpoint) 151 | return path_to_checkpoint 152 | 153 | def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> 'Model': 154 | checkpoint = torch.load(path_to_checkpoint) 155 | self.load_state_dict(checkpoint['state_dict']) 156 | step = checkpoint['step'] 157 | if optimizer is not None: 158 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 159 | if scheduler is not None: 160 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 161 | return step 162 | 163 | class Detection(nn.Module): 164 | 165 | def __init__(self, pooling_mode: ROIWrapper.Mode, num_classes: int): 166 | super().__init__() 167 | self._pooling_mode = pooling_mode 168 | self._hidden = nn.Sequential( 169 | nn.Linear(256 * 7 * 7, 1024), 170 | nn.ReLU(), 171 | nn.Linear(1024, 1024), 172 | nn.ReLU() 173 | ) 174 | self.num_classes = num_classes 175 | self._class = nn.Linear(1024, num_classes) 176 | self._transformer = nn.Linear(1024, num_classes * 4) 177 | self._transformer_normalize_mean = torch.tensor([0., 0., 0., 0.], dtype=torch.float).cuda() 178 | self._transformer_normalize_std = torch.tensor([.1, .1, .2, .2], dtype=torch.float).cuda() 179 | 180 | def forward(self, p2: Tensor, p3: Tensor, p4: Tensor, p5: Tensor, proposal_bboxes: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor]: 181 | w = proposal_bboxes[:, 2] - proposal_bboxes[:, 0] 182 | h = proposal_bboxes[:, 3] - proposal_bboxes[:, 1] 183 | k0 = 4 184 | k = torch.floor(k0 + torch.log2(torch.sqrt(w * h) / 224)).long() 185 | k = torch.clamp(k, min=2, max=5) 186 | 187 | k_to_p_dict = {2: p2, 3: p3, 4: p4, 5: p5} 188 | unique_k = torch.unique(k) 189 | 190 | # NOTE: `picked_indices` is for recording the order of selection from `proposal_bboxes` 191 | # so that `pools` can be then restored to make it have a consistent correspondence 192 | # with `proposal_bboxes`. 
For example: 193 | # 194 | # proposal_bboxes => B0 B1 B2 195 | # picked_indices => 1 2 0 196 | # pools => BP1 BP2 BP0 197 | # sorted_indices => 2 0 1 198 | # pools => BP0 BP1 BP2 199 | 200 | pools = [] 201 | picked_indices = [] 202 | 203 | for uk in unique_k: 204 | uk = uk.item() 205 | p = k_to_p_dict[uk] 206 | uk_indices = (k == uk).nonzero().view(-1) 207 | uk_proposal_bboxes = proposal_bboxes[uk_indices] 208 | pool = ROIWrapper.apply(p, uk_proposal_bboxes, mode=self._pooling_mode, image_width=image_width, image_height=image_height) 209 | pools.append(pool) 210 | picked_indices.append(uk_indices) 211 | 212 | pools = torch.cat(pools, dim=0) 213 | picked_indices = torch.cat(picked_indices, dim=0) 214 | 215 | _, sorted_indices = torch.sort(picked_indices) 216 | pools = pools[sorted_indices] 217 | 218 | pools = pools.view(pools.shape[0], -1) 219 | hidden = self._hidden(pools) 220 | classes = self._class(hidden) 221 | transformers = self._transformer(hidden) 222 | return classes, transformers 223 | 224 | def sample(self, proposal_bboxes: Tensor, gt_classes: Tensor, gt_bboxes: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 225 | sample_fg_indices = torch.arange(end=len(proposal_bboxes), dtype=torch.long) 226 | sample_selected_indices = torch.arange(end=len(proposal_bboxes), dtype=torch.long) 227 | 228 | # find labels for each `proposal_bboxes` 229 | labels = torch.ones(len(proposal_bboxes), dtype=torch.long).cuda() * -1 230 | ious = BBox.iou(proposal_bboxes, gt_bboxes) 231 | proposal_max_ious, proposal_assignments = ious.max(dim=1) 232 | labels[proposal_max_ious < 0.5] = 0 233 | labels[proposal_max_ious >= 0.5] = gt_classes[proposal_assignments[proposal_max_ious >= 0.5]] 234 | 235 | # select 128 samples 236 | fg_indices = (labels > 0).nonzero().view(-1) 237 | bg_indices = (labels == 0).nonzero().view(-1) 238 | fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 32)]] 239 | bg_indices = bg_indices[torch.randperm(len(bg_indices))[:128 - len(fg_indices)]] 240 | selected_indices = torch.cat([fg_indices, bg_indices]) 241 | selected_indices = selected_indices[torch.randperm(len(selected_indices))] 242 | 243 | proposal_bboxes = proposal_bboxes[selected_indices] 244 | gt_proposal_transformers = BBox.calc_transformer(proposal_bboxes, gt_bboxes[proposal_assignments[selected_indices]]) 245 | gt_proposal_classes = labels[selected_indices] 246 | 247 | gt_proposal_transformers = (gt_proposal_transformers - self._transformer_normalize_mean) / self._transformer_normalize_std 248 | 249 | gt_proposal_transformers = gt_proposal_transformers.cuda() 250 | gt_proposal_classes = gt_proposal_classes.cuda() 251 | 252 | sample_fg_indices = sample_fg_indices[fg_indices] 253 | sample_selected_indices = sample_selected_indices[selected_indices] 254 | 255 | return sample_fg_indices, sample_selected_indices, gt_proposal_classes, gt_proposal_transformers 256 | 257 | def loss(self, proposal_classes: Tensor, proposal_transformers: Tensor, gt_proposal_classes: Tensor, gt_proposal_transformers: Tensor) -> Tuple[Tensor, Tensor]: 258 | cross_entropy = F.cross_entropy(input=proposal_classes, target=gt_proposal_classes) 259 | 260 | proposal_transformers = proposal_transformers.view(-1, self.num_classes, 4) 261 | proposal_transformers = proposal_transformers[torch.arange(end=len(proposal_transformers), dtype=torch.long).cuda(), gt_proposal_classes] 262 | 263 | fg_indices = gt_proposal_classes.nonzero().view(-1) 264 | 265 | # NOTE: The default of `reduction` is `elementwise_mean`, which is divided by N x 4 
(number of all elements), here we replaced by N for better performance 266 | smooth_l1_loss = F.smooth_l1_loss(input=proposal_transformers[fg_indices], target=gt_proposal_transformers[fg_indices], reduction='sum') 267 | smooth_l1_loss /= len(gt_proposal_transformers) 268 | 269 | return cross_entropy, smooth_l1_loss 270 | 271 | def generate_detections(self, proposal_bboxes: Tensor, proposal_classes: Tensor, proposal_transformers: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor, Tensor]: 272 | proposal_transformers = proposal_transformers.view(-1, self.num_classes, 4) 273 | mean = self._transformer_normalize_mean.repeat(1, self.num_classes, 1) 274 | std = self._transformer_normalize_std.repeat(1, self.num_classes, 1) 275 | 276 | proposal_transformers = proposal_transformers * std - mean 277 | proposal_bboxes = proposal_bboxes.view(-1, 1, 4).repeat(1, self.num_classes, 1) 278 | detection_bboxes = BBox.apply_transformer(proposal_bboxes.view(-1, 4), proposal_transformers.view(-1, 4)) 279 | 280 | detection_bboxes = detection_bboxes.view(-1, self.num_classes, 4) 281 | 282 | detection_bboxes[:, :, [0, 2]] = detection_bboxes[:, :, [0, 2]].clamp(min=0, max=image_width) 283 | detection_bboxes[:, :, [1, 3]] = detection_bboxes[:, :, [1, 3]].clamp(min=0, max=image_height) 284 | 285 | proposal_probs = F.softmax(proposal_classes, dim=1) 286 | 287 | detection_bboxes = detection_bboxes.cpu() 288 | proposal_probs = proposal_probs.cpu() 289 | 290 | generated_bboxes = [] 291 | generated_classes = [] 292 | generated_probs = [] 293 | 294 | for c in range(1, self.num_classes): 295 | detection_class_bboxes = detection_bboxes[:, c, :] 296 | proposal_class_probs = proposal_probs[:, c] 297 | 298 | _, sorted_indices = proposal_class_probs.sort(descending=True) 299 | detection_class_bboxes = detection_class_bboxes[sorted_indices] 300 | proposal_class_probs = proposal_class_probs[sorted_indices] 301 | 302 | kept_indices = NMS.suppress(detection_class_bboxes.cuda(), threshold=0.3) 303 | detection_class_bboxes = detection_class_bboxes[kept_indices] 304 | proposal_class_probs = proposal_class_probs[kept_indices] 305 | 306 | generated_bboxes.append(detection_class_bboxes) 307 | generated_classes.append(torch.ones(len(kept_indices), dtype=torch.int) * c) 308 | generated_probs.append(proposal_class_probs) 309 | 310 | generated_bboxes = torch.cat(generated_bboxes, dim=0) 311 | generated_classes = torch.cat(generated_classes, dim=0) 312 | generated_probs = torch.cat(generated_probs, dim=0) 313 | return generated_bboxes, generated_classes, generated_probs 314 | -------------------------------------------------------------------------------- /nms/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.ffi import create_extension 4 | 5 | ffi = create_extension( 6 | name='_ext.nms', 7 | headers=['src/nms.h'], 8 | sources=['src/nms.c'], 9 | extra_objects=[os.path.join(os.path.dirname(os.path.abspath(__file__)), it) for it in ['src/nms_cuda.o']], 10 | relative_to=__file__, 11 | with_cuda=True 12 | ) 13 | 14 | if __name__ == '__main__': 15 | ffi.build() 16 | -------------------------------------------------------------------------------- /nms/nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from nms._ext import nms 4 | from torch import Tensor 5 | 6 | 7 | class NMS(object): 8 | 9 | @staticmethod 10 | def suppress(sorted_bboxes: Tensor, threshold: float) -> Tensor: 11 | 
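# `sorted_bboxes` is expected to be a CUDA float tensor of shape [N, 4] in [left, top, right, bottom]
# order, already sorted by descending score; the CUDA extension fills `kept_indices` with the row
# indices of the boxes that survive suppression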
kept_indices = torch.tensor([], dtype=torch.long).cuda() 12 | nms.suppress(sorted_bboxes.contiguous(), threshold, kept_indices) 13 | return kept_indices 14 | -------------------------------------------------------------------------------- /nms/src/nms.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "nms_cuda.h" 3 | 4 | extern THCState *state; 5 | 6 | int suppress(THCudaTensor *bboxes, float threshold, THCudaLongTensor *keepIndices) { 7 | if (!((THCudaTensor_nDimension(state, bboxes) == 2) && (THCudaTensor_size(state, bboxes, 1) == 4))) 8 | return 0; 9 | 10 | long numBoxes = THCudaTensor_size(state, bboxes, 0); 11 | THLongTensor *keepIndicesTmp = THLongTensor_newWithSize1d(numBoxes); 12 | 13 | long numKeepBoxes; 14 | nms(THCudaTensor_data(state, bboxes), numBoxes, threshold, THLongTensor_data(keepIndicesTmp), &numKeepBoxes); 15 | 16 | THLongTensor_resize1d(keepIndicesTmp, numKeepBoxes); 17 | THCudaLongTensor_resize1d(state, keepIndices, numKeepBoxes); 18 | THCudaLongTensor_copyCPU(state, keepIndices, keepIndicesTmp); 19 | 20 | THLongTensor_free(keepIndicesTmp); 21 | 22 | return 1; 23 | } 24 | -------------------------------------------------------------------------------- /nms/src/nms.h: -------------------------------------------------------------------------------- 1 | int suppress(THCudaTensor *bboxes, float threshold, THCudaLongTensor *keepIndices); -------------------------------------------------------------------------------- /nms/src/nms_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "nms_cuda.h" 2 | 3 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) 4 | 5 | typedef unsigned long long MaskType; 6 | 7 | const long numThreadsPerBlock = sizeof(MaskType) * 8; 8 | 9 | __device__ inline float iou(const float *bbox1, const float *bbox2) { 10 | float intersectionLeft = max(bbox1[0], bbox2[0]); 11 | float intersectionTop = max(bbox1[1], bbox2[1]); 12 | float intersectionRight = min(bbox1[2], bbox2[2]); 13 | float intersectionBottom = min(bbox1[3], bbox2[3]); 14 | float intersectionWidth = max(intersectionRight - intersectionLeft, 0.f); 15 | float intersectionHeight = max(intersectionBottom - intersectionTop, 0.f); 16 | float intersectionArea = intersectionWidth * intersectionHeight; 17 | float bbox1Area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]); 18 | float bbox2Area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]); 19 | return intersectionArea / (bbox1Area + bbox2Area - intersectionArea); 20 | } 21 | 22 | __global__ void nms_kernel(const float *bboxes, long numBoxes, float threshold, MaskType *suppressionMask) { 23 | int i; 24 | int bidX = blockIdx.x; 25 | int bidY = blockIdx.y; 26 | int tid = threadIdx.x; 27 | const long blockBoxStartX = bidX * numThreadsPerBlock; 28 | const long blockBoxStartY = bidY * numThreadsPerBlock; 29 | const long blockBoxEndX = min(blockBoxStartX + numThreadsPerBlock, numBoxes); 30 | const long blockBoxEndY = min(blockBoxStartY + numThreadsPerBlock, numBoxes); 31 | const long currentBoxY = blockBoxStartY + tid; 32 | 33 | if (currentBoxY < blockBoxEndY) { 34 | MaskType suppression = 0; 35 | 36 | const float *currentBox = bboxes + currentBoxY * 4; 37 | for (i = 0; i < blockBoxEndX - blockBoxStartX; ++i) { 38 | long targetBoxX = blockBoxStartX + i; 39 | if (targetBoxX > currentBoxY) { 40 | const float *targetBox = bboxes + targetBoxX * 4; 41 | if (iou(currentBox, targetBox) > threshold) { 42 | suppression |= 1ULL << i; 43 | } 44 | } 45 | } 46 | 
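// Each box owns one 64-bit word per column block: bit i of this thread's word marks that the box at
// index (blockBoxStartX + i) comes after the current box in the input order (callers pass boxes sorted
// by descending score) and overlaps it with IoU above `threshold`. The host-side loop in nms() below
// keeps a box only when no previously kept box has marked it, OR-ing each kept row into a running mask.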
47 | const long numBlockCols = DIVUP(numBoxes, numThreadsPerBlock); 48 | suppressionMask[currentBoxY * numBlockCols + bidX] = suppression; 49 | } 50 | } 51 | 52 | void nms(const float *bboxesInDevice, long numBoxes, float threshold, long *keepIndices, long *numKeepBoxes) { 53 | int i, j; 54 | const long numBlockCols = DIVUP(numBoxes, numThreadsPerBlock); 55 | 56 | MaskType *suppressionMaskInDevice; 57 | cudaMalloc(&suppressionMaskInDevice, sizeof(MaskType) * numBoxes * numBlockCols); 58 | 59 | dim3 blocks(numBlockCols, numBlockCols); 60 | dim3 threads(numThreadsPerBlock); 61 | nms_kernel<<>>(bboxesInDevice, numBoxes, threshold, suppressionMaskInDevice); 62 | 63 | MaskType *suppressionMask = (MaskType *) malloc(sizeof(MaskType) * numBoxes * numBlockCols); 64 | cudaMemcpy(suppressionMask, suppressionMaskInDevice, sizeof(MaskType) * numBoxes * numBlockCols, cudaMemcpyDeviceToHost); 65 | 66 | MaskType *maskRow = (MaskType *) malloc(sizeof(MaskType) * numBlockCols); 67 | memset(maskRow, 0, sizeof(MaskType) * numBlockCols); 68 | long nKeepBoxes = 0; 69 | for (i = 0; i < numBoxes; ++i) { 70 | long block = i / numThreadsPerBlock; 71 | long offset = i % numThreadsPerBlock; 72 | if (!(maskRow[block] & (1ULL << offset))) { 73 | keepIndices[nKeepBoxes++] = i; 74 | for (j = 0; j < numBlockCols; ++j) { 75 | maskRow[j] |= suppressionMask[i * numBlockCols + j]; 76 | } 77 | } 78 | } 79 | *numKeepBoxes = nKeepBoxes; 80 | 81 | cudaFree(suppressionMaskInDevice); 82 | free(suppressionMask); 83 | free(maskRow); 84 | } 85 | -------------------------------------------------------------------------------- /nms/src/nms_cuda.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | void nms(const float *bboxesInDevice, long numBoxes, float threshold, long *keepIndices, long *numKeepBoxes); 6 | 7 | #ifdef __cplusplus 8 | } 9 | #endif -------------------------------------------------------------------------------- /nms/test/nms-large-input.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/nms/test/nms-large-input.npy -------------------------------------------------------------------------------- /nms/test/nms-large-output.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/potterhsu/easy-fpn.pytorch/cac901f2570bd8dba7bb456128c7c7985c255ea4/nms/test/nms-large-output.npy -------------------------------------------------------------------------------- /nms/test/test_nms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import unittest 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from nms.nms import NMS 9 | 10 | 11 | class TestNMS(unittest.TestCase): 12 | def _run_nms(self, bboxes): 13 | start = time.time() 14 | kept_indices = NMS.suppress(bboxes.contiguous(), threshold=0.7) 15 | print('%s in %.3fs, %d -> %d' % (self.id(), time.time() - start, len(bboxes), len(kept_indices))) 16 | return kept_indices 17 | 18 | def test_nms_empty(self): 19 | bboxes = torch.FloatTensor().cuda() 20 | kept_indices = self._run_nms(bboxes) 21 | self.assertEqual(len(kept_indices), 0) 22 | 23 | def test_nms_single(self): 24 | bboxes = torch.FloatTensor([[5, 5, 10, 10]]).cuda() 25 | kept_indices = self._run_nms(bboxes) 26 | self.assertEqual(len(kept_indices), 1) 27 | 
self.assertListEqual(kept_indices.tolist(), [0]) 28 | 29 | def test_nms_small(self): 30 | # bboxes = torch.FloatTensor([[5, 5, 10, 10], [5, 5, 10, 10], [5, 5, 30, 30]]).cuda() 31 | bboxes = torch.FloatTensor([[5, 5, 10, 10], [5, 5, 30, 30]]).cuda() 32 | kept_indices = self._run_nms(bboxes) 33 | self.assertEqual(len(kept_indices), 2) 34 | # self.assertListEqual(kept_indices.tolist(), [0, 2]) 35 | self.assertListEqual(kept_indices.tolist(), [0, 1]) 36 | 37 | def test_nms_large(self): 38 | # detections format: [[left, top, right, bottom, score], ...], which (right, bottom) is included in area 39 | detections = np.load(os.path.join('nms', 'test', 'nms-large-input.npy')) 40 | bboxes = torch.FloatTensor(detections).cuda() 41 | sorted_indices = torch.sort(bboxes[:, 4], dim=0, descending=True)[1] 42 | bboxes = bboxes[:, 0:4][sorted_indices] 43 | 44 | # point of (right, bottom) in our bbox definition is not included in area 45 | bboxes[:, 2] += 1 46 | bboxes[:, 3] += 1 47 | 48 | kept_indices = self._run_nms(bboxes) 49 | kept_indices_for_detection = sorted_indices[kept_indices] 50 | self.assertEqual(len(kept_indices_for_detection), 1934) 51 | 52 | expect = np.load(os.path.join('nms', 'test', 'nms-large-output.npy')) 53 | self.assertListEqual(kept_indices_for_detection.tolist(), expect.tolist()) 54 | 55 | 56 | if __name__ == '__main__': 57 | assert torch.cuda.is_available(), 'NMS module requires CUDA support' 58 | torch.FloatTensor().cuda() # dummy for initializing GPU 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /outputs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /realtime.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import random 4 | import time 5 | 6 | import cv2 7 | import numpy as np 8 | from PIL import ImageDraw, Image 9 | 10 | from backbone.base import Base as BackboneBase 11 | from config.eval_config import EvalConfig as Config 12 | from dataset.base import Base as DatasetBase 13 | from bbox import BBox 14 | from model import Model 15 | from roi.wrapper import Wrapper as ROIWrapper 16 | 17 | 18 | def _realtime(path_to_input_stream_endpoint: str, period_of_inference: int, path_to_checkpoint: str, dataset_name: str, backbone_name: str, prob_thresh: float): 19 | video_capture = cv2.VideoCapture(path_to_input_stream_endpoint) 20 | 21 | dataset_class = DatasetBase.from_name(dataset_name) 22 | backbone = BackboneBase.from_name(backbone_name)(pretrained=False) 23 | model = Model(backbone, dataset_class.num_classes(), pooling_mode=Config.POOLING_MODE, 24 | anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES, 25 | rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda() 26 | model.load(path_to_checkpoint) 27 | 28 | for sn in itertools.count(start=1): 29 | _, frame = video_capture.read() 30 | 31 | if sn % period_of_inference != 0: 32 | continue 33 | 34 | timestamp = time.time() 35 | 36 | image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 37 | image = Image.fromarray(image) 38 | image_tensor, scale = dataset_class.preprocess(image, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE) 39 | 40 | forward_input = Model.ForwardInput.Eval(image_tensor.cuda()) 41 | forward_output: Model.ForwardOutput.Eval = model.eval().forward(forward_input) 42 | 43 | detection_bboxes 
= forward_output.detection_bboxes / scale 44 | detection_classes = forward_output.detection_classes 45 | detection_probs = forward_output.detection_probs 46 | 47 | kept_indices = detection_probs > prob_thresh 48 | detection_bboxes = detection_bboxes[kept_indices] 49 | detection_classes = detection_classes[kept_indices] 50 | detection_probs = detection_probs[kept_indices] 51 | 52 | draw = ImageDraw.Draw(image) 53 | 54 | for bbox, cls, prob in zip(detection_bboxes.tolist(), detection_classes.tolist(), detection_probs.tolist()): 55 | color = random.choice(['red', 'green', 'blue', 'yellow', 'purple', 'white']) 56 | bbox = BBox(left=bbox[0], top=bbox[1], right=bbox[2], bottom=bbox[3]) 57 | category = dataset_class.LABEL_TO_CATEGORY_DICT[cls] 58 | 59 | draw.rectangle(((bbox.left, bbox.top), (bbox.right, bbox.bottom)), outline=color) 60 | draw.text((bbox.left, bbox.top), text=f'{category:s} {prob:.3f}', fill=color) 61 | 62 | image = np.array(image) 63 | frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 64 | 65 | elapse = time.time() - timestamp 66 | fps = 1 / elapse 67 | cv2.putText(frame, f'FPS = {fps:.1f}', (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA) 68 | 69 | cv2.imshow('easy-faster-rcnn.pytorch', frame) 70 | if cv2.waitKey(10) == 27: 71 | break 72 | 73 | cv2.destroyAllWindows() 74 | 75 | 76 | if __name__ == '__main__': 77 | def main(): 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('input', type=str, help='path to input stream endpoint') 80 | parser.add_argument('period', type=int, help='period of inference') 81 | parser.add_argument('-c', '--checkpoint', type=str, required=True, help='path to checkpoint') 82 | parser.add_argument('-s', '--dataset', type=str, choices=DatasetBase.OPTIONS, required=True, help='name of dataset') 83 | parser.add_argument('-b', '--backbone', type=str, choices=BackboneBase.OPTIONS, required=True, help='name of backbone model') 84 | parser.add_argument('-p', '--probability_threshold', type=float, default=0.6, help='threshold of detection probability') 85 | parser.add_argument('--image_min_side', type=float, help='default: {:g}'.format(Config.IMAGE_MIN_SIDE)) 86 | parser.add_argument('--image_max_side', type=float, help='default: {:g}'.format(Config.IMAGE_MAX_SIDE)) 87 | parser.add_argument('--anchor_ratios', type=str, help='default: "{!s}"'.format(Config.ANCHOR_RATIOS)) 88 | parser.add_argument('--anchor_scales', type=str, help='default: "{!s}"'.format(Config.ANCHOR_SCALES)) 89 | parser.add_argument('--pooling_mode', type=str, choices=ROIWrapper.OPTIONS, help='default: {.value:s}'.format(Config.POOLING_MODE)) 90 | parser.add_argument('--rpn_pre_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_PRE_NMS_TOP_N)) 91 | parser.add_argument('--rpn_post_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_POST_NMS_TOP_N)) 92 | args = parser.parse_args() 93 | 94 | path_to_input_stream_endpoint = args.input 95 | period_of_inference = args.period 96 | path_to_checkpoint = args.checkpoint 97 | dataset_name = args.dataset 98 | backbone_name = args.backbone 99 | prob_thresh = args.probability_threshold 100 | 101 | Config.setup(image_min_side=args.image_min_side, image_max_side=args.image_max_side, 102 | anchor_ratios=args.anchor_ratios, anchor_scales=args.anchor_scales, pooling_mode=args.pooling_mode, 103 | rpn_pre_nms_top_n=args.rpn_pre_nms_top_n, rpn_post_nms_top_n=args.rpn_post_nms_top_n) 104 | 105 | print('Arguments:') 106 | for k, v in vars(args).items(): 107 | print(f'\t{k} = {v}') 108 | print(Config.describe()) 
109 | 110 | _realtime(path_to_input_stream_endpoint, period_of_inference, path_to_checkpoint, dataset_name, backbone_name, prob_thresh) 111 | 112 | main() 113 | -------------------------------------------------------------------------------- /roi/align/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | 6 | sources = ['src/crop_and_resize.c'] 7 | headers = ['src/crop_and_resize.h'] 8 | defines = [] 9 | with_cuda = False 10 | 11 | extra_objects = [] 12 | if torch.cuda.is_available(): 13 | print('Including CUDA code.') 14 | sources += ['src/crop_and_resize_gpu.c'] 15 | headers += ['src/crop_and_resize_gpu.h'] 16 | defines += [('WITH_CUDA', None)] 17 | extra_objects += ['src/cuda/crop_and_resize_kernel.cu.o'] 18 | with_cuda = True 19 | 20 | extra_compile_args = ['-std=c99'] 21 | 22 | this_file = os.path.dirname(os.path.realpath(__file__)) 23 | print(this_file) 24 | sources = [os.path.join(this_file, fname) for fname in sources] 25 | headers = [os.path.join(this_file, fname) for fname in headers] 26 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 27 | 28 | ffi = create_extension( 29 | '_ext.crop_and_resize', 30 | headers=headers, 31 | sources=sources, 32 | define_macros=defines, 33 | relative_to=__file__, 34 | with_cuda=with_cuda, 35 | extra_objects=extra_objects, 36 | extra_compile_args=extra_compile_args 37 | ) 38 | 39 | if __name__ == '__main__': 40 | ffi.build() 41 | -------------------------------------------------------------------------------- /roi/align/crop_and_resize.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | 7 | from ._ext import crop_and_resize as _backend 8 | 9 | 10 | class CropAndResizeFunction(Function): 11 | 12 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 13 | self.crop_height = crop_height 14 | self.crop_width = crop_width 15 | self.extrapolation_value = extrapolation_value 16 | 17 | def forward(self, image, boxes, box_ind): 18 | crops = torch.zeros_like(image) 19 | 20 | if image.is_cuda: 21 | _backend.crop_and_resize_gpu_forward( 22 | image, boxes, box_ind, 23 | self.extrapolation_value, self.crop_height, self.crop_width, crops) 24 | else: 25 | _backend.crop_and_resize_forward( 26 | image, boxes, box_ind, 27 | self.extrapolation_value, self.crop_height, self.crop_width, crops) 28 | 29 | # save for backward 30 | self.im_size = image.size() 31 | self.save_for_backward(boxes, box_ind) 32 | 33 | return crops 34 | 35 | def backward(self, grad_outputs): 36 | boxes, box_ind = self.saved_tensors 37 | 38 | grad_outputs = grad_outputs.contiguous() 39 | grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size) 40 | 41 | if grad_outputs.is_cuda: 42 | _backend.crop_and_resize_gpu_backward( 43 | grad_outputs, boxes, box_ind, grad_image 44 | ) 45 | else: 46 | _backend.crop_and_resize_backward( 47 | grad_outputs, boxes, box_ind, grad_image 48 | ) 49 | 50 | return grad_image, None, None 51 | 52 | 53 | class CropAndResize(nn.Module): 54 | """ 55 | Crop and resize ported from tensorflow 56 | See more details on https://www.tensorflow.org/api_docs/python/tf/image/crop_and_resize 57 | """ 58 | 59 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 60 | super(CropAndResize, self).__init__() 61 | 62 | 
self.crop_height = crop_height 63 | self.crop_width = crop_width 64 | self.extrapolation_value = extrapolation_value 65 | 66 | def forward(self, image, boxes, box_ind): 67 | return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(image, boxes, box_ind) 68 | -------------------------------------------------------------------------------- /roi/align/roi_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .crop_and_resize import CropAndResizeFunction, CropAndResize 5 | 6 | 7 | class RoIAlign(nn.Module): 8 | 9 | def __init__(self, crop_height, crop_width, extrapolation_value=0, transform_fpcoor=True): 10 | super(RoIAlign, self).__init__() 11 | 12 | self.crop_height = crop_height 13 | self.crop_width = crop_width 14 | self.extrapolation_value = extrapolation_value 15 | self.transform_fpcoor = transform_fpcoor 16 | 17 | def forward(self, featuremap, boxes, box_ind): 18 | """ 19 | RoIAlign based on crop_and_resize. 20 | See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301 21 | :param featuremap: NxCxHxW 22 | :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization** 23 | :param box_ind: M 24 | :return: MxCxoHxoW 25 | """ 26 | x1, y1, x2, y2 = torch.split(boxes, 1, dim=1) 27 | image_height, image_width = featuremap.size()[2:4] 28 | 29 | if self.transform_fpcoor: 30 | spacing_w = (x2 - x1) / float(self.crop_width) 31 | spacing_h = (y2 - y1) / float(self.crop_height) 32 | 33 | nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1) 34 | ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1) 35 | nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1) 36 | nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1) 37 | 38 | boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1) 39 | else: 40 | x1 = x1 / float(image_width - 1) 41 | x2 = x2 / float(image_width - 1) 42 | y1 = y1 / float(image_height - 1) 43 | y2 = y2 / float(image_height - 1) 44 | boxes = torch.cat((y1, x1, y2, x2), 1) 45 | 46 | boxes = boxes.detach().contiguous() 47 | box_ind = box_ind.detach() 48 | return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind) 49 | -------------------------------------------------------------------------------- /roi/align/src/crop_and_resize.c: -------------------------------------------------------------------------------- 1 | #include <TH/TH.h> 2 | #include <stdio.h> 3 | #include <math.h> 4 | 5 | 6 | void CropAndResizePerBox( 7 | const float * image_data, 8 | const int batch_size, 9 | const int depth, 10 | const int image_height, 11 | const int image_width, 12 | 13 | const float * boxes_data, 14 | const int * box_index_data, 15 | const int start_box, 16 | const int limit_box, 17 | 18 | float * corps_data, 19 | const int crop_height, 20 | const int crop_width, 21 | const float extrapolation_value 22 | ) { 23 | const int image_channel_elements = image_height * image_width; 24 | const int image_elements = depth * image_channel_elements; 25 | 26 | const int channel_elements = crop_height * crop_width; 27 | const int crop_elements = depth * channel_elements; 28 | 29 | int b; 30 | #pragma omp parallel for 31 | for (b = start_box; b < limit_box; ++b) { 32 | const float * box = boxes_data + b * 4; 33 | const float y1 = box[0]; 34 | const float x1 = box[1]; 35 | const float y2 = box[2]; 36 | const float x2 = box[3]; 37 | 38 | 
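// boxes are given in [y1, x1, y2, x2] order as normalized coordinates (the Python wrappers scale pixel coordinates into [0, 1]); box_index_data selects which image in the batch each crop reads from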
const int b_in = box_index_data[b]; 39 | if (b_in < 0 || b_in >= batch_size) { 40 | printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); 41 | exit(-1); 42 | } 43 | 44 | const float height_scale = 45 | (crop_height > 1) 46 | ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 47 | : 0; 48 | const float width_scale = 49 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) 50 | : 0; 51 | 52 | for (int y = 0; y < crop_height; ++y) 53 | { 54 | const float in_y = (crop_height > 1) 55 | ? y1 * (image_height - 1) + y * height_scale 56 | : 0.5 * (y1 + y2) * (image_height - 1); 57 | 58 | if (in_y < 0 || in_y > image_height - 1) 59 | { 60 | for (int x = 0; x < crop_width; ++x) 61 | { 62 | for (int d = 0; d < depth; ++d) 63 | { 64 | // crops(b, y, x, d) = extrapolation_value; 65 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; 66 | } 67 | } 68 | continue; 69 | } 70 | 71 | const int top_y_index = floorf(in_y); 72 | const int bottom_y_index = ceilf(in_y); 73 | const float y_lerp = in_y - top_y_index; 74 | 75 | for (int x = 0; x < crop_width; ++x) 76 | { 77 | const float in_x = (crop_width > 1) 78 | ? x1 * (image_width - 1) + x * width_scale 79 | : 0.5 * (x1 + x2) * (image_width - 1); 80 | if (in_x < 0 || in_x > image_width - 1) 81 | { 82 | for (int d = 0; d < depth; ++d) 83 | { 84 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; 85 | } 86 | continue; 87 | } 88 | 89 | const int left_x_index = floorf(in_x); 90 | const int right_x_index = ceilf(in_x); 91 | const float x_lerp = in_x - left_x_index; 92 | 93 | for (int d = 0; d < depth; ++d) 94 | { 95 | const float *pimage = image_data + b_in * image_elements + d * image_channel_elements; 96 | 97 | const float top_left = pimage[top_y_index * image_width + left_x_index]; 98 | const float top_right = pimage[top_y_index * image_width + right_x_index]; 99 | const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; 100 | const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; 101 | 102 | const float top = top_left + (top_right - top_left) * x_lerp; 103 | const float bottom = 104 | bottom_left + (bottom_right - bottom_left) * x_lerp; 105 | 106 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = top + (bottom - top) * y_lerp; 107 | } 108 | } // end for x 109 | } // end for y 110 | } // end for b 111 | 112 | } 113 | 114 | 115 | void crop_and_resize_forward( 116 | THFloatTensor * image, 117 | THFloatTensor * boxes, // [y1, x1, y2, x2] 118 | THIntTensor * box_index, // range in [0, batch_size) 119 | const float extrapolation_value, 120 | const int crop_height, 121 | const int crop_width, 122 | THFloatTensor * crops 123 | ) { 124 | const int batch_size = THFloatTensor_size(image, 0); 125 | const int depth = THFloatTensor_size(image, 1); 126 | const int image_height = THFloatTensor_size(image, 2); 127 | const int image_width = THFloatTensor_size(image, 3); 128 | 129 | const int num_boxes = THFloatTensor_size(boxes, 0); 130 | 131 | // init output space 132 | THFloatTensor_resize4d(crops, num_boxes, depth, crop_height, crop_width); 133 | THFloatTensor_zero(crops); 134 | 135 | // crop_and_resize for each box 136 | CropAndResizePerBox( 137 | THFloatTensor_data(image), 138 | batch_size, 139 | depth, 140 | image_height, 141 | image_width, 142 | 143 | THFloatTensor_data(boxes), 144 | THIntTensor_data(box_index), 145 | 0, 146 | num_boxes, 147 | 148 | 
THFloatTensor_data(crops), 149 | crop_height, 150 | crop_width, 151 | extrapolation_value 152 | ); 153 | 154 | } 155 | 156 | 157 | void crop_and_resize_backward( 158 | THFloatTensor * grads, 159 | THFloatTensor * boxes, // [y1, x1, y2, x2] 160 | THIntTensor * box_index, // range in [0, batch_size) 161 | THFloatTensor * grads_image // resize to [bsize, c, hc, wc] 162 | ) 163 | { 164 | // shape 165 | const int batch_size = THFloatTensor_size(grads_image, 0); 166 | const int depth = THFloatTensor_size(grads_image, 1); 167 | const int image_height = THFloatTensor_size(grads_image, 2); 168 | const int image_width = THFloatTensor_size(grads_image, 3); 169 | 170 | const int num_boxes = THFloatTensor_size(grads, 0); 171 | const int crop_height = THFloatTensor_size(grads, 2); 172 | const int crop_width = THFloatTensor_size(grads, 3); 173 | 174 | // n_elements 175 | const int image_channel_elements = image_height * image_width; 176 | const int image_elements = depth * image_channel_elements; 177 | 178 | const int channel_elements = crop_height * crop_width; 179 | const int crop_elements = depth * channel_elements; 180 | 181 | // init output space 182 | THFloatTensor_zero(grads_image); 183 | 184 | // data pointer 185 | const float * grads_data = THFloatTensor_data(grads); 186 | const float * boxes_data = THFloatTensor_data(boxes); 187 | const int * box_index_data = THIntTensor_data(box_index); 188 | float * grads_image_data = THFloatTensor_data(grads_image); 189 | 190 | for (int b = 0; b < num_boxes; ++b) { 191 | const float * box = boxes_data + b * 4; 192 | const float y1 = box[0]; 193 | const float x1 = box[1]; 194 | const float y2 = box[2]; 195 | const float x2 = box[3]; 196 | 197 | const int b_in = box_index_data[b]; 198 | if (b_in < 0 || b_in >= batch_size) { 199 | printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); 200 | exit(-1); 201 | } 202 | 203 | const float height_scale = 204 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 205 | : 0; 206 | const float width_scale = 207 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) 208 | : 0; 209 | 210 | for (int y = 0; y < crop_height; ++y) 211 | { 212 | const float in_y = (crop_height > 1) 213 | ? y1 * (image_height - 1) + y * height_scale 214 | : 0.5 * (y1 + y2) * (image_height - 1); 215 | if (in_y < 0 || in_y > image_height - 1) 216 | { 217 | continue; 218 | } 219 | const int top_y_index = floorf(in_y); 220 | const int bottom_y_index = ceilf(in_y); 221 | const float y_lerp = in_y - top_y_index; 222 | 223 | for (int x = 0; x < crop_width; ++x) 224 | { 225 | const float in_x = (crop_width > 1) 226 | ? 
x1 * (image_width - 1) + x * width_scale 227 | : 0.5 * (x1 + x2) * (image_width - 1); 228 | if (in_x < 0 || in_x > image_width - 1) 229 | { 230 | continue; 231 | } 232 | const int left_x_index = floorf(in_x); 233 | const int right_x_index = ceilf(in_x); 234 | const float x_lerp = in_x - left_x_index; 235 | 236 | for (int d = 0; d < depth; ++d) 237 | { 238 | float *pimage = grads_image_data + b_in * image_elements + d * image_channel_elements; 239 | const float grad_val = grads_data[crop_elements * b + channel_elements * d + y * crop_width + x]; 240 | 241 | const float dtop = (1 - y_lerp) * grad_val; 242 | pimage[top_y_index * image_width + left_x_index] += (1 - x_lerp) * dtop; 243 | pimage[top_y_index * image_width + right_x_index] += x_lerp * dtop; 244 | 245 | const float dbottom = y_lerp * grad_val; 246 | pimage[bottom_y_index * image_width + left_x_index] += (1 - x_lerp) * dbottom; 247 | pimage[bottom_y_index * image_width + right_x_index] += x_lerp * dbottom; 248 | } // end d 249 | } // end x 250 | } // end y 251 | } // end b 252 | } -------------------------------------------------------------------------------- /roi/align/src/crop_and_resize.h: -------------------------------------------------------------------------------- 1 | void crop_and_resize_forward( 2 | THFloatTensor * image, 3 | THFloatTensor * boxes, // [y1, x1, y2, x2] 4 | THIntTensor * box_index, // range in [0, batch_size) 5 | const float extrapolation_value, 6 | const int crop_height, 7 | const int crop_width, 8 | THFloatTensor * crops 9 | ); 10 | 11 | void crop_and_resize_backward( 12 | THFloatTensor * grads, 13 | THFloatTensor * boxes, // [y1, x1, y2, x2] 14 | THIntTensor * box_index, // range in [0, batch_size) 15 | THFloatTensor * grads_image // resize to [bsize, c, hc, wc] 16 | ); -------------------------------------------------------------------------------- /roi/align/src/crop_and_resize_gpu.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | #include "cuda/crop_and_resize_kernel.h" 3 | 4 | extern THCState *state; 5 | 6 | 7 | void crop_and_resize_gpu_forward( 8 | THCudaTensor * image, 9 | THCudaTensor * boxes, // [y1, x1, y2, x2] 10 | THCudaIntTensor * box_index, // range in [0, batch_size) 11 | const float extrapolation_value, 12 | const int crop_height, 13 | const int crop_width, 14 | THCudaTensor * crops 15 | ) { 16 | const int batch_size = THCudaTensor_size(state, image, 0); 17 | const int depth = THCudaTensor_size(state, image, 1); 18 | const int image_height = THCudaTensor_size(state, image, 2); 19 | const int image_width = THCudaTensor_size(state, image, 3); 20 | 21 | const int num_boxes = THCudaTensor_size(state, boxes, 0); 22 | 23 | // init output space 24 | THCudaTensor_resize4d(state, crops, num_boxes, depth, crop_height, crop_width); 25 | THCudaTensor_zero(state, crops); 26 | 27 | cudaStream_t stream = THCState_getCurrentStream(state); 28 | CropAndResizeLaucher( 29 | THCudaTensor_data(state, image), 30 | THCudaTensor_data(state, boxes), 31 | THCudaIntTensor_data(state, box_index), 32 | num_boxes, batch_size, image_height, image_width, 33 | crop_height, crop_width, depth, extrapolation_value, 34 | THCudaTensor_data(state, crops), 35 | stream 36 | ); 37 | } 38 | 39 | 40 | void crop_and_resize_gpu_backward( 41 | THCudaTensor * grads, 42 | THCudaTensor * boxes, // [y1, x1, y2, x2] 43 | THCudaIntTensor * box_index, // range in [0, batch_size) 44 | THCudaTensor * grads_image // resize to [bsize, c, hc, wc] 45 | ) { 46 | // shape 47 | const int batch_size = 
THCudaTensor_size(state, grads_image, 0); 48 | const int depth = THCudaTensor_size(state, grads_image, 1); 49 | const int image_height = THCudaTensor_size(state, grads_image, 2); 50 | const int image_width = THCudaTensor_size(state, grads_image, 3); 51 | 52 | const int num_boxes = THCudaTensor_size(state, grads, 0); 53 | const int crop_height = THCudaTensor_size(state, grads, 2); 54 | const int crop_width = THCudaTensor_size(state, grads, 3); 55 | 56 | // init output space 57 | THCudaTensor_zero(state, grads_image); 58 | 59 | cudaStream_t stream = THCState_getCurrentStream(state); 60 | CropAndResizeBackpropImageLaucher( 61 | THCudaTensor_data(state, grads), 62 | THCudaTensor_data(state, boxes), 63 | THCudaIntTensor_data(state, box_index), 64 | num_boxes, batch_size, image_height, image_width, 65 | crop_height, crop_width, depth, 66 | THCudaTensor_data(state, grads_image), 67 | stream 68 | ); 69 | } -------------------------------------------------------------------------------- /roi/align/src/crop_and_resize_gpu.h: -------------------------------------------------------------------------------- 1 | void crop_and_resize_gpu_forward( 2 | THCudaTensor * image, 3 | THCudaTensor * boxes, // [y1, x1, y2, x2] 4 | THCudaIntTensor * box_index, // range in [0, batch_size) 5 | const float extrapolation_value, 6 | const int crop_height, 7 | const int crop_width, 8 | THCudaTensor * crops 9 | ); 10 | 11 | void crop_and_resize_gpu_backward( 12 | THCudaTensor * grads, 13 | THCudaTensor * boxes, // [y1, x1, y2, x2] 14 | THCudaIntTensor * box_index, // range in [0, batch_size) 15 | THCudaTensor * grads_image // resize to [bsize, c, hc, wc] 16 | ); -------------------------------------------------------------------------------- /roi/align/src/cuda/crop_and_resize_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <math.h> 2 | #include <stdio.h> 3 | #include "crop_and_resize_kernel.h" 4 | 5 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 7 | i += blockDim.x * gridDim.x) 8 | 9 | 10 | __global__ 11 | void CropAndResizeKernel( 12 | const int nthreads, const float *image_ptr, const float *boxes_ptr, 13 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 14 | int image_width, int crop_height, int crop_width, int depth, 15 | float extrapolation_value, float *crops_ptr) 16 | { 17 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 18 | { 19 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 20 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 21 | int idx = out_idx; 22 | const int x = idx % crop_width; 23 | idx /= crop_width; 24 | const int y = idx % crop_height; 25 | idx /= crop_height; 26 | const int d = idx % depth; 27 | const int b = idx / depth; 28 | 29 | const float y1 = boxes_ptr[b * 4]; 30 | const float x1 = boxes_ptr[b * 4 + 1]; 31 | const float y2 = boxes_ptr[b * 4 + 2]; 32 | const float x2 = boxes_ptr[b * 4 + 3]; 33 | 34 | const int b_in = box_ind_ptr[b]; 35 | if (b_in < 0 || b_in >= batch) 36 | { 37 | continue; 38 | } 39 | 40 | const float height_scale = 41 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 42 | : 0; 43 | const float width_scale = 44 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 45 | 46 | const float in_y = (crop_height > 1) 47 | ? 
y1 * (image_height - 1) + y * height_scale 48 | : 0.5 * (y1 + y2) * (image_height - 1); 49 | if (in_y < 0 || in_y > image_height - 1) 50 | { 51 | crops_ptr[out_idx] = extrapolation_value; 52 | continue; 53 | } 54 | 55 | const float in_x = (crop_width > 1) 56 | ? x1 * (image_width - 1) + x * width_scale 57 | : 0.5 * (x1 + x2) * (image_width - 1); 58 | if (in_x < 0 || in_x > image_width - 1) 59 | { 60 | crops_ptr[out_idx] = extrapolation_value; 61 | continue; 62 | } 63 | 64 | const int top_y_index = floorf(in_y); 65 | const int bottom_y_index = ceilf(in_y); 66 | const float y_lerp = in_y - top_y_index; 67 | 68 | const int left_x_index = floorf(in_x); 69 | const int right_x_index = ceilf(in_x); 70 | const float x_lerp = in_x - left_x_index; 71 | 72 | const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width; 73 | const float top_left = pimage[top_y_index * image_width + left_x_index]; 74 | const float top_right = pimage[top_y_index * image_width + right_x_index]; 75 | const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; 76 | const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; 77 | 78 | const float top = top_left + (top_right - top_left) * x_lerp; 79 | const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; 80 | crops_ptr[out_idx] = top + (bottom - top) * y_lerp; 81 | } 82 | } 83 | 84 | __global__ 85 | void CropAndResizeBackpropImageKernel( 86 | const int nthreads, const float *grads_ptr, const float *boxes_ptr, 87 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 88 | int image_width, int crop_height, int crop_width, int depth, 89 | float *grads_image_ptr) 90 | { 91 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 92 | { 93 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 94 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 95 | int idx = out_idx; 96 | const int x = idx % crop_width; 97 | idx /= crop_width; 98 | const int y = idx % crop_height; 99 | idx /= crop_height; 100 | const int d = idx % depth; 101 | const int b = idx / depth; 102 | 103 | const float y1 = boxes_ptr[b * 4]; 104 | const float x1 = boxes_ptr[b * 4 + 1]; 105 | const float y2 = boxes_ptr[b * 4 + 2]; 106 | const float x2 = boxes_ptr[b * 4 + 3]; 107 | 108 | const int b_in = box_ind_ptr[b]; 109 | if (b_in < 0 || b_in >= batch) 110 | { 111 | continue; 112 | } 113 | 114 | const float height_scale = 115 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 116 | : 0; 117 | const float width_scale = 118 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 119 | 120 | const float in_y = (crop_height > 1) 121 | ? y1 * (image_height - 1) + y * height_scale 122 | : 0.5 * (y1 + y2) * (image_height - 1); 123 | if (in_y < 0 || in_y > image_height - 1) 124 | { 125 | continue; 126 | } 127 | 128 | const float in_x = (crop_width > 1) 129 | ? 
x1 * (image_width - 1) + x * width_scale 130 | : 0.5 * (x1 + x2) * (image_width - 1); 131 | if (in_x < 0 || in_x > image_width - 1) 132 | { 133 | continue; 134 | } 135 | 136 | const int top_y_index = floorf(in_y); 137 | const int bottom_y_index = ceilf(in_y); 138 | const float y_lerp = in_y - top_y_index; 139 | 140 | const int left_x_index = floorf(in_x); 141 | const int right_x_index = ceilf(in_x); 142 | const float x_lerp = in_x - left_x_index; 143 | 144 | float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width; 145 | const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; 146 | atomicAdd( 147 | pimage + top_y_index * image_width + left_x_index, 148 | (1 - x_lerp) * dtop 149 | ); 150 | atomicAdd( 151 | pimage + top_y_index * image_width + right_x_index, 152 | x_lerp * dtop 153 | ); 154 | 155 | const float dbottom = y_lerp * grads_ptr[out_idx]; 156 | atomicAdd( 157 | pimage + bottom_y_index * image_width + left_x_index, 158 | (1 - x_lerp) * dbottom 159 | ); 160 | atomicAdd( 161 | pimage + bottom_y_index * image_width + right_x_index, 162 | x_lerp * dbottom 163 | ); 164 | } 165 | } 166 | 167 | 168 | void CropAndResizeLaucher( 169 | const float *image_ptr, const float *boxes_ptr, 170 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 171 | int image_width, int crop_height, int crop_width, int depth, 172 | float extrapolation_value, float *crops_ptr, cudaStream_t stream) 173 | { 174 | const int total_count = num_boxes * crop_height * crop_width * depth; 175 | const int thread_per_block = 1024; 176 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 177 | cudaError_t err; 178 | 179 | if (total_count > 0) 180 | { 181 | CropAndResizeKernel<<<block_count, thread_per_block, 0, stream>>>( 182 | total_count, image_ptr, boxes_ptr, 183 | box_ind_ptr, num_boxes, batch, image_height, image_width, 184 | crop_height, crop_width, depth, extrapolation_value, crops_ptr); 185 | 186 | err = cudaGetLastError(); 187 | if (cudaSuccess != err) 188 | { 189 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 190 | exit(-1); 191 | } 192 | } 193 | } 194 | 195 | 196 | void CropAndResizeBackpropImageLaucher( 197 | const float *grads_ptr, const float *boxes_ptr, 198 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 199 | int image_width, int crop_height, int crop_width, int depth, 200 | float *grads_image_ptr, cudaStream_t stream) 201 | { 202 | const int total_count = num_boxes * crop_height * crop_width * depth; 203 | const int thread_per_block = 1024; 204 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 205 | cudaError_t err; 206 | 207 | if (total_count > 0) 208 | { 209 | CropAndResizeBackpropImageKernel<<<block_count, thread_per_block, 0, stream>>>( 210 | total_count, grads_ptr, boxes_ptr, 211 | box_ind_ptr, num_boxes, batch, image_height, image_width, 212 | crop_height, crop_width, depth, grads_image_ptr); 213 | 214 | err = cudaGetLastError(); 215 | if (cudaSuccess != err) 216 | { 217 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 218 | exit(-1); 219 | } 220 | } 221 | } -------------------------------------------------------------------------------- /roi/align/src/cuda/crop_and_resize_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _CropAndResize_Kernel 2 | #define _CropAndResize_Kernel 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void CropAndResizeLaucher( 9 | const float *image_ptr, const float *boxes_ptr, 10 | const int *box_ind_ptr, int 
num_boxes, int batch, int image_height, 11 | int image_width, int crop_height, int crop_width, int depth, 12 | float extrapolation_value, float *crops_ptr, cudaStream_t stream); 13 | 14 | void CropAndResizeBackpropImageLaucher( 15 | const float *grads_ptr, const float *boxes_ptr, 16 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 17 | int image_width, int crop_height, int crop_width, int depth, 18 | float *grads_image_ptr, cudaStream_t stream); 19 | 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | 24 | #endif -------------------------------------------------------------------------------- /roi/wrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | from enum import Enum 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.nn import functional as F 7 | 8 | from roi.align.crop_and_resize import CropAndResizeFunction 9 | 10 | 11 | class Wrapper(object): 12 | 13 | class Mode(Enum): 14 | POOLING = 'pooling' 15 | ALIGN = 'align' 16 | 17 | OPTIONS = ['pooling', 'align'] 18 | 19 | @staticmethod 20 | def apply(features: Tensor, proposal_bboxes: Tensor, mode: Mode, image_width: int, image_height: int) -> Tensor: 21 | _, _, feature_map_height, feature_map_width = features.shape 22 | proposal_bboxes = proposal_bboxes.detach() 23 | 24 | scale_x = image_width / feature_map_width 25 | scale_y = image_height / feature_map_height 26 | 27 | if mode == Wrapper.Mode.POOLING: 28 | pool = [] 29 | for proposal_bbox in proposal_bboxes: 30 | start_x = max(min(round(proposal_bbox[0].item() / scale_x), feature_map_width - 1), 0) # [0, feature_map_width) 31 | start_y = max(min(round(proposal_bbox[1].item() / scale_y), feature_map_height - 1), 0) # (0, feature_map_height] 32 | end_x = max(min(round(proposal_bbox[2].item() / scale_x) + 1, feature_map_width), 1) # [0, feature_map_width) 33 | end_y = max(min(round(proposal_bbox[3].item() / scale_y) + 1, feature_map_height), 1) # (0, feature_map_height] 34 | roi_feature_map = features[..., start_y:end_y, start_x:end_x] 35 | pool.append(F.adaptive_max_pool2d(input=roi_feature_map, output_size=7)) 36 | pool = torch.cat(pool, dim=0) 37 | elif mode == Wrapper.Mode.ALIGN: 38 | x1 = proposal_bboxes[:, 0::4] / scale_x 39 | y1 = proposal_bboxes[:, 1::4] / scale_y 40 | x2 = proposal_bboxes[:, 2::4] / scale_x 41 | y2 = proposal_bboxes[:, 3::4] / scale_y 42 | 43 | crops = CropAndResizeFunction(crop_height=7 * 2, crop_width=7 * 2)( 44 | features, 45 | torch.cat([y1 / (feature_map_height - 1), x1 / (feature_map_width - 1), 46 | y2 / (feature_map_height - 1), x2 / (feature_map_width - 1)], 47 | dim=1), 48 | torch.zeros(proposal_bboxes.shape[0], dtype=torch.int, device=proposal_bboxes.device) 49 | ) 50 | pool = F.max_pool2d(input=crops, kernel_size=2, stride=2) 51 | else: 52 | raise ValueError 53 | 54 | return pool 55 | 56 | -------------------------------------------------------------------------------- /rpn/region_proposal_network.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn, Tensor 6 | from torch.nn import functional as F 7 | 8 | from bbox import BBox 9 | from nms.nms import NMS 10 | 11 | 12 | class RegionProposalNetwork(nn.Module): 13 | 14 | def __init__(self, num_features_out: int, anchor_ratios: List[Tuple[int, int]], anchor_scales: List[int], pre_nms_top_n: int, post_nms_top_n: int): 15 | super().__init__() 16 | 17 | self._features = nn.Sequential( 
18 | nn.Conv2d(in_channels=num_features_out, out_channels=512, kernel_size=3, padding=1), 19 | nn.ReLU() 20 | ) 21 | 22 | self._anchor_ratios = anchor_ratios 23 | self._anchor_scales = anchor_scales 24 | 25 | num_anchor_ratios = len(self._anchor_ratios) 26 | num_anchor_scales = len(self._anchor_scales) 27 | num_anchors = num_anchor_ratios * num_anchor_scales 28 | 29 | self._pre_nms_top_n = pre_nms_top_n 30 | self._post_nms_top_n = post_nms_top_n 31 | 32 | self._objectness = nn.Conv2d(in_channels=512, out_channels=num_anchors * 2, kernel_size=1) 33 | self._transformer = nn.Conv2d(in_channels=512, out_channels=num_anchors * 4, kernel_size=1) 34 | 35 | def forward(self, features: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor]: 36 | features = self._features(features) 37 | objectnesses = self._objectness(features) 38 | transformers = self._transformer(features) 39 | 40 | objectnesses = objectnesses.permute(0, 2, 3, 1).contiguous().view(-1, 2) 41 | transformers = transformers.permute(0, 2, 3, 1).contiguous().view(-1, 4) 42 | 43 | return objectnesses, transformers 44 | 45 | def sample(self, anchor_bboxes: Tensor, gt_bboxes: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 46 | sample_fg_indices = torch.arange(end=len(anchor_bboxes), dtype=torch.long) 47 | sample_selected_indices = torch.arange(end=len(anchor_bboxes), dtype=torch.long) 48 | 49 | anchor_bboxes = anchor_bboxes.cpu() 50 | gt_bboxes = gt_bboxes.cpu() 51 | 52 | # remove cross-boundary 53 | boundary = torch.tensor(BBox(0, 0, image_width, image_height).tolist(), dtype=torch.float) 54 | inside_indices = BBox.inside(anchor_bboxes, boundary.unsqueeze(dim=0)).squeeze().nonzero().view(-1) 55 | 56 | anchor_bboxes = anchor_bboxes[inside_indices] 57 | sample_fg_indices = sample_fg_indices[inside_indices] 58 | sample_selected_indices = sample_selected_indices[inside_indices] 59 | 60 | # find labels for each `anchor_bboxes` 61 | labels = torch.ones(len(anchor_bboxes), dtype=torch.long) * -1 62 | ious = BBox.iou(anchor_bboxes, gt_bboxes) 63 | anchor_max_ious, anchor_assignments = ious.max(dim=1) 64 | gt_max_ious, gt_assignments = ious.max(dim=0) 65 | anchor_additions = (ious == gt_max_ious).nonzero()[:, 0] 66 | labels[anchor_max_ious < 0.3] = 0 67 | labels[anchor_additions] = 1 68 | labels[anchor_max_ious >= 0.7] = 1 69 | 70 | # select 256 samples 71 | fg_indices = (labels == 1).nonzero().view(-1) 72 | bg_indices = (labels == 0).nonzero().view(-1) 73 | fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 128)]] 74 | bg_indices = bg_indices[torch.randperm(len(bg_indices))[:256 - len(fg_indices)]] 75 | selected_indices = torch.cat([fg_indices, bg_indices]) 76 | selected_indices = selected_indices[torch.randperm(len(selected_indices))] 77 | 78 | gt_anchor_objectnesses = labels[selected_indices] 79 | gt_bboxes = gt_bboxes[anchor_assignments[fg_indices]] 80 | anchor_bboxes = anchor_bboxes[fg_indices] 81 | gt_anchor_transformers = BBox.calc_transformer(anchor_bboxes, gt_bboxes) 82 | 83 | gt_anchor_objectnesses = gt_anchor_objectnesses.cuda() 84 | gt_anchor_transformers = gt_anchor_transformers.cuda() 85 | 86 | sample_fg_indices = sample_fg_indices[fg_indices] 87 | sample_selected_indices = sample_selected_indices[selected_indices] 88 | 89 | return sample_fg_indices, sample_selected_indices, gt_anchor_objectnesses, gt_anchor_transformers 90 | 91 | def loss(self, anchor_objectnesses: Tensor, anchor_transformers: Tensor, gt_anchor_objectnesses: Tensor, 
gt_anchor_transformers: Tensor) -> Tuple[Tensor, Tensor]: 92 | cross_entropy = F.cross_entropy(input=anchor_objectnesses, target=gt_anchor_objectnesses) 93 | 94 | # NOTE: The default of `reduction` is `elementwise_mean`, which is divided by N x 4 (number of all elements), here we replaced by N for better performance 95 | smooth_l1_loss = F.smooth_l1_loss(input=anchor_transformers, target=gt_anchor_transformers, reduction='sum') 96 | smooth_l1_loss /= len(gt_anchor_transformers) 97 | 98 | return cross_entropy, smooth_l1_loss 99 | 100 | def generate_anchors(self, image_width: int, image_height: int, num_x_anchors: int, num_y_anchors: int, anchor_size: int) -> Tensor: 101 | center_ys = np.linspace(start=0, stop=image_height, num=num_y_anchors + 2)[1:-1] 102 | center_xs = np.linspace(start=0, stop=image_width, num=num_x_anchors + 2)[1:-1] 103 | ratios = np.array(self._anchor_ratios) 104 | ratios = ratios[:, 0] / ratios[:, 1] 105 | scales = np.array(self._anchor_scales) 106 | 107 | # NOTE: it's important to let `center_ys` be the major index (i.e., move horizontally and then vertically) for consistency with 2D convolution 108 | 109 | # giving the string 'ij' returns a meshgrid with matrix indexing, i.e., with shape (#center_ys, #center_xs, #ratios, #scales) 110 | center_ys, center_xs, ratios, scales = np.meshgrid(center_ys, center_xs, ratios, scales, indexing='ij') 111 | 112 | center_ys = center_ys.reshape(-1) 113 | center_xs = center_xs.reshape(-1) 114 | ratios = ratios.reshape(-1) 115 | scales = scales.reshape(-1) 116 | 117 | widths = anchor_size * scales * np.sqrt(1 / ratios) 118 | heights = anchor_size * scales * np.sqrt(ratios) 119 | 120 | center_based_anchor_bboxes = np.stack((center_xs, center_ys, widths, heights), axis=1) 121 | center_based_anchor_bboxes = torch.from_numpy(center_based_anchor_bboxes).float() 122 | anchor_bboxes = BBox.from_center_base(center_based_anchor_bboxes) 123 | 124 | return anchor_bboxes 125 | 126 | def generate_proposals(self, anchor_bboxes: Tensor, objectnesses: Tensor, transformers: Tensor, image_width: int, image_height: int) -> Tensor: 127 | proposal_score = objectnesses[:, 1] 128 | _, sorted_indices = torch.sort(proposal_score, dim=0, descending=True) 129 | 130 | sorted_transformers = transformers[sorted_indices] 131 | sorted_anchor_bboxes = anchor_bboxes[sorted_indices] 132 | 133 | proposal_bboxes = BBox.apply_transformer(sorted_anchor_bboxes, sorted_transformers.detach()) 134 | proposal_bboxes = BBox.clip(proposal_bboxes, 0, 0, image_width, image_height) 135 | 136 | proposal_bboxes = proposal_bboxes[:self._pre_nms_top_n] 137 | kept_indices = NMS.suppress(proposal_bboxes, threshold=0.7) 138 | proposal_bboxes = proposal_bboxes[kept_indices] 139 | proposal_bboxes = proposal_bboxes[:self._post_nms_top_n] 140 | 141 | return proposal_bboxes 142 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import uuid 5 | from collections import deque 6 | from typing import Optional 7 | 8 | from tensorboardX import SummaryWriter 9 | from torch import optim 10 | from torch.optim.lr_scheduler import MultiStepLR 11 | from torch.utils.data import DataLoader 12 | 13 | from backbone.base import Base as BackboneBase 14 | from config.train_config import TrainConfig as Config 15 | from dataset.base import Base as DatasetBase 16 | from logger import Logger as Log 17 | from model import Model 18 | from 
roi.wrapper import Wrapper as ROIWrapper 19 | 20 | 21 | def _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]): 22 | dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE) 23 | dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True) 24 | 25 | Log.i('Found {:d} samples'.format(len(dataset))) 26 | 27 | backbone = BackboneBase.from_name(backbone_name)(pretrained=True) 28 | model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE, 29 | anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES, 30 | rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda() 31 | optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE, 32 | momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY) 33 | scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA) 34 | 35 | step = 0 36 | time_checkpoint = time.time() 37 | losses = deque(maxlen=100) 38 | summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries')) 39 | should_stop = False 40 | 41 | num_steps_to_display = Config.NUM_STEPS_TO_DISPLAY 42 | num_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOT 43 | num_steps_to_finish = Config.NUM_STEPS_TO_FINISH 44 | 45 | if path_to_resuming_checkpoint is not None: 46 | step = model.load(path_to_resuming_checkpoint, optimizer, scheduler) 47 | Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}') 48 | 49 | Log.i('Start training') 50 | 51 | while not should_stop: 52 | for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader): 53 | assert image_batch.shape[0] == 1, 'only batch size of 1 is supported' 54 | 55 | image = image_batch[0].cuda() 56 | bboxes = bboxes_batch[0].cuda() 57 | labels = labels_batch[0].cuda() 58 | 59 | forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes) 60 | forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input) 61 | 62 | anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_output 63 | loss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_loss 64 | 65 | optimizer.zero_grad() 66 | loss.backward() 67 | optimizer.step() 68 | scheduler.step() 69 | losses.append(loss.item()) 70 | summary_writer.add_scalar('train/anchor_objectness_loss', anchor_objectness_loss.item(), step) 71 | summary_writer.add_scalar('train/anchor_transformer_loss', anchor_transformer_loss.item(), step) 72 | summary_writer.add_scalar('train/proposal_class_loss', proposal_class_loss.item(), step) 73 | summary_writer.add_scalar('train/proposal_transformer_loss', proposal_transformer_loss.item(), step) 74 | summary_writer.add_scalar('train/loss', loss.item(), step) 75 | step += 1 76 | 77 | if step == num_steps_to_finish: 78 | should_stop = True 79 | 80 | if step % num_steps_to_display == 0: 81 | elapsed_time = time.time() - time_checkpoint 82 | time_checkpoint = time.time() 83 | steps_per_sec = num_steps_to_display / elapsed_time 84 | samples_per_sec = dataloader.batch_size * steps_per_sec 85 | eta = (num_steps_to_finish - step) / steps_per_sec / 3600 86 | avg_loss = sum(losses) / len(losses) 87 | lr = scheduler.get_lr()[0] 88 | Log.i(f'[Step {step}] Avg. 
Loss = {avg_loss:.6f}, Learning Rate = {lr:.6f} ({samples_per_sec:.2f} samples/sec; ETA {eta:.1f} hrs)') 89 | 90 | if step % num_steps_to_snapshot == 0 or should_stop: 91 | path_to_checkpoint = model.save(path_to_checkpoints_dir, step, optimizer, scheduler) 92 | Log.i(f'Model has been saved to {path_to_checkpoint}') 93 | 94 | if should_stop: 95 | break 96 | 97 | Log.i('Done') 98 | 99 | 100 | if __name__ == '__main__': 101 | def main(): 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('-s', '--dataset', type=str, choices=DatasetBase.OPTIONS, required=True, help='name of dataset') 104 | parser.add_argument('-b', '--backbone', type=str, choices=BackboneBase.OPTIONS, required=True, help='name of backbone model') 105 | parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to data directory') 106 | parser.add_argument('-o', '--outputs_dir', type=str, default='./outputs', help='path to outputs directory') 107 | parser.add_argument('-r', '--resume_checkpoint', type=str, help='path to resuming checkpoint') 108 | parser.add_argument('--image_min_side', type=float, help='default: {:g}'.format(Config.IMAGE_MIN_SIDE)) 109 | parser.add_argument('--image_max_side', type=float, help='default: {:g}'.format(Config.IMAGE_MAX_SIDE)) 110 | parser.add_argument('--anchor_ratios', type=str, help='default: "{!s}"'.format(Config.ANCHOR_RATIOS)) 111 | parser.add_argument('--anchor_scales', type=str, help='default: "{!s}"'.format(Config.ANCHOR_SCALES)) 112 | parser.add_argument('--pooling_mode', type=str, choices=ROIWrapper.OPTIONS, help='default: {.value:s}'.format(Config.POOLING_MODE)) 113 | parser.add_argument('--rpn_pre_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_PRE_NMS_TOP_N)) 114 | parser.add_argument('--rpn_post_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_POST_NMS_TOP_N)) 115 | parser.add_argument('--learning_rate', type=float, help='default: {:g}'.format(Config.LEARNING_RATE)) 116 | parser.add_argument('--momentum', type=float, help='default: {:g}'.format(Config.MOMENTUM)) 117 | parser.add_argument('--weight_decay', type=float, help='default: {:g}'.format(Config.WEIGHT_DECAY)) 118 | parser.add_argument('--step_lr_sizes', type=str, help='default: {!s}'.format(Config.STEP_LR_SIZES)) 119 | parser.add_argument('--step_lr_gamma', type=float, help='default: {:g}'.format(Config.STEP_LR_GAMMA)) 120 | parser.add_argument('--num_steps_to_display', type=int, help='default: {:d}'.format(Config.NUM_STEPS_TO_DISPLAY)) 121 | parser.add_argument('--num_steps_to_snapshot', type=int, help='default: {:d}'.format(Config.NUM_STEPS_TO_SNAPSHOT)) 122 | parser.add_argument('--num_steps_to_finish', type=int, help='default: {:d}'.format(Config.NUM_STEPS_TO_FINISH)) 123 | args = parser.parse_args() 124 | 125 | dataset_name = args.dataset 126 | backbone_name = args.backbone 127 | path_to_data_dir = args.data_dir 128 | path_to_outputs_dir = args.outputs_dir 129 | path_to_resuming_checkpoint = args.resume_checkpoint 130 | 131 | path_to_checkpoints_dir = os.path.join(path_to_outputs_dir, 'checkpoints-{:s}-{:s}-{:s}-{:s}'.format( 132 | time.strftime('%Y%m%d%H%M%S'), dataset_name, backbone_name, str(uuid.uuid4()).split('-')[0])) 133 | os.makedirs(path_to_checkpoints_dir) 134 | 135 | Config.setup(image_min_side=args.image_min_side, image_max_side=args.image_max_side, 136 | anchor_ratios=args.anchor_ratios, anchor_scales=args.anchor_scales, pooling_mode=args.pooling_mode, 137 | rpn_pre_nms_top_n=args.rpn_pre_nms_top_n, rpn_post_nms_top_n=args.rpn_post_nms_top_n, 138 
| learning_rate=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, 139 | step_lr_sizes=args.step_lr_sizes, step_lr_gamma=args.step_lr_gamma, 140 | num_steps_to_display=args.num_steps_to_display, num_steps_to_snapshot=args.num_steps_to_snapshot, num_steps_to_finish=args.num_steps_to_finish) 141 | 142 | Log.initialize(os.path.join(path_to_checkpoints_dir, 'train.log')) 143 | Log.i('Arguments:') 144 | for k, v in vars(args).items(): 145 | Log.i(f'\t{k} = {v}') 146 | Log.i(Config.describe()) 147 | 148 | _train(dataset_name, backbone_name, path_to_data_dir, path_to_checkpoints_dir, path_to_resuming_checkpoint) 149 | 150 | main() 151 | -------------------------------------------------------------------------------- /voc_eval.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast/er R-CNN 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Bharath Hariharan 5 | # -------------------------------------------------------- 6 | 7 | import xml.etree.ElementTree as ET 8 | import os 9 | import _pickle as cPickle 10 | import numpy as np 11 | 12 | def parse_rec(filename): 13 | """ Parse a PASCAL VOC xml file """ 14 | tree = ET.parse(filename) 15 | objects = [] 16 | for obj in tree.findall('object'): 17 | obj_struct = {} 18 | obj_struct['name'] = obj.find('name').text 19 | obj_struct['pose'] = obj.find('pose').text 20 | obj_struct['truncated'] = int(obj.find('truncated').text) 21 | obj_struct['difficult'] = int(obj.find('difficult').text) 22 | bbox = obj.find('bndbox') 23 | obj_struct['bbox'] = [int(bbox.find('xmin').text), 24 | int(bbox.find('ymin').text), 25 | int(bbox.find('xmax').text), 26 | int(bbox.find('ymax').text)] 27 | objects.append(obj_struct) 28 | 29 | return objects 30 | 31 | def voc_ap(rec, prec, use_07_metric=False): 32 | """ ap = voc_ap(rec, prec, [use_07_metric]) 33 | Compute VOC AP given precision and recall. 34 | If use_07_metric is true, uses the 35 | VOC 07 11 point method (default:False). 36 | """ 37 | if use_07_metric: 38 | # 11 point metric 39 | ap = 0. 40 | for t in np.arange(0., 1.1, 0.1): 41 | if np.sum(rec >= t) == 0: 42 | p = 0 43 | else: 44 | p = np.max(prec[rec >= t]) 45 | ap = ap + p / 11. 46 | else: 47 | # correct AP calculation 48 | # first append sentinel values at the end 49 | mrec = np.concatenate(([0.], rec, [1.])) 50 | mpre = np.concatenate(([0.], prec, [0.])) 51 | 52 | # compute the precision envelope 53 | for i in range(mpre.size - 1, 0, -1): 54 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 55 | 56 | # to calculate area under PR curve, look for points 57 | # where X axis (recall) changes value 58 | i = np.where(mrec[1:] != mrec[:-1])[0] 59 | 60 | # and sum (\Delta recall) * prec 61 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 62 | return ap 63 | 64 | def voc_eval(detpath, 65 | annopath, 66 | imagesetfile, 67 | classname, 68 | cachedir, 69 | ovthresh=0.5, 70 | use_07_metric=False): 71 | """rec, prec, ap = voc_eval(detpath, 72 | annopath, 73 | imagesetfile, 74 | classname, 75 | [ovthresh], 76 | [use_07_metric]) 77 | Top level function that does the PASCAL VOC evaluation. 78 | detpath: Path to detections 79 | detpath.format(classname) should produce the detection results file. 80 | annopath: Path to annotations 81 | annopath.format(imagename) should be the xml annotations file. 82 | imagesetfile: Text file containing the list of images, one image per line. 
83 | classname: Category name (duh) 84 | cachedir: Directory for caching the annotations 85 | [ovthresh]: Overlap threshold (default = 0.5) 86 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 87 | (default False) 88 | """ 89 | # assumes detections are in detpath.format(classname) 90 | # assumes annotations are in annopath.format(imagename) 91 | # assumes imagesetfile is a text file with each line an image name 92 | # cachedir caches the annotations in a pickle file 93 | 94 | # first load gt 95 | if not os.path.isdir(cachedir): 96 | os.mkdir(cachedir) 97 | cachefile = os.path.join(cachedir, 'annots.pkl') 98 | # read list of images 99 | with open(imagesetfile, 'r') as f: 100 | lines = f.readlines() 101 | imagenames = [x.strip() for x in lines] 102 | 103 | if not os.path.isfile(cachefile): 104 | # load annots 105 | recs = {} 106 | for i, imagename in enumerate(imagenames): 107 | recs[imagename] = parse_rec(annopath.format(imagename)) 108 | if i % 100 == 0: 109 | print('Reading annotation for {:d}/{:d}'.format( 110 | i + 1, len(imagenames))) 111 | # save 112 | print('Saving cached annotations to {:s}'.format(cachefile)) 113 | with open(cachefile, 'wb') as f: 114 | cPickle.dump(recs, f) 115 | else: 116 | # load 117 | with open(cachefile, 'rb') as f: 118 | recs = cPickle.load(f) 119 | 120 | # extract gt objects for this class 121 | class_recs = {} 122 | npos = 0 123 | for imagename in imagenames: 124 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 125 | bbox = np.array([x['bbox'] for x in R]) 126 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 127 | det = [False] * len(R) 128 | npos = npos + sum(~difficult) 129 | class_recs[imagename] = {'bbox': bbox, 130 | 'difficult': difficult, 131 | 'det': det} 132 | 133 | # read dets 134 | detfile = detpath.format(classname) 135 | with open(detfile, 'r') as f: 136 | lines = f.readlines() 137 | 138 | splitlines = [x.strip().split(' ') for x in lines] 139 | image_ids = [x[0] for x in splitlines] 140 | confidence = np.array([float(x[1]) for x in splitlines]) 141 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 142 | 143 | # sort by confidence 144 | sorted_ind = np.argsort(-confidence) 145 | sorted_scores = np.sort(-confidence) 146 | BB = BB[sorted_ind, :] 147 | image_ids = [image_ids[x] for x in sorted_ind] 148 | 149 | # go down dets and mark TPs and FPs 150 | nd = len(image_ids) 151 | tp = np.zeros(nd) 152 | fp = np.zeros(nd) 153 | for d in range(nd): 154 | R = class_recs[image_ids[d]] 155 | bb = BB[d, :].astype(float) 156 | ovmax = -np.inf 157 | BBGT = R['bbox'].astype(float) 158 | 159 | if BBGT.size > 0: 160 | # compute overlaps 161 | # intersection 162 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 163 | iymin = np.maximum(BBGT[:, 1], bb[1]) 164 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 165 | iymax = np.minimum(BBGT[:, 3], bb[3]) 166 | iw = np.maximum(ixmax - ixmin + 1., 0.) 167 | ih = np.maximum(iymax - iymin + 1., 0.) 168 | inters = iw * ih 169 | 170 | # union 171 | uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) + 172 | (BBGT[:, 2] - BBGT[:, 0] + 1.) * 173 | (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters) 174 | 175 | overlaps = inters / uni 176 | ovmax = np.max(overlaps) 177 | jmax = np.argmax(overlaps) 178 | 179 | if ovmax > ovthresh: 180 | if not R['difficult'][jmax]: 181 | if not R['det'][jmax]: 182 | tp[d] = 1. 183 | R['det'][jmax] = 1 184 | else: 185 | fp[d] = 1. 186 | else: 187 | fp[d] = 1. 
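# note: R['det'][jmax] marks a ground-truth box as already matched, so at most one detection is counted as a true positive per ground-truth box; further detections of the same box, and detections below the overlap threshold, become false positives, while matches to 'difficult' objects are ignored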
188 | 189 | # compute precision recall 190 | fp = np.cumsum(fp) 191 | tp = np.cumsum(tp) 192 | rec = tp / float(npos) 193 | # avoid divide by zero in case the first detection matches a difficult 194 | # ground truth 195 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 196 | ap = voc_ap(rec, prec, use_07_metric) 197 | 198 | return rec, prec, ap 199 | --------------------------------------------------------------------------------
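For reference, a minimal usage sketch of voc_eval driven once per class, assuming the repository root is the working directory and that per-class detection files have already been written in the format the function parses (one "image_id confidence xmin ymin xmax ymax" line per detection); the results/ and VOCdevkit/ paths and the class list below are illustrative placeholders, not files provided by this repository.

# usage sketch (paths and classes are hypothetical)
from voc_eval import voc_eval

classes = ['cat', 'dog']  # example subset of the VOC categories
for classname in classes:
    rec, prec, ap = voc_eval(
        detpath='results/comp3_det_test_{:s}.txt',          # detpath.format(classname) -> detection file
        annopath='VOCdevkit/VOC2007/Annotations/{:s}.xml',   # annopath.format(imagename) -> annotation file
        imagesetfile='VOCdevkit/VOC2007/ImageSets/Main/test.txt',
        classname=classname,
        cachedir='annotations_cache',
        ovthresh=0.5,
        use_07_metric=True)
    print('{:s}: AP = {:.4f}'.format(classname, ap))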