├── centernet ├── __init__.py ├── transforms │ ├── __init__.py │ └── transform_centeraffine.py ├── config.py ├── centernet_head.py ├── defaults.py ├── centernet_decode.py ├── centernet_gt.py ├── centernet_deconv.py ├── dataset_mapper.py └── centernet.py ├── .flake8 ├── configs ├── Base-CenterNet.yaml ├── centernet_r_18_C4_1x.yaml └── centernet_r_50_C4_1x.yaml ├── dev └── linter.sh ├── LICENSE ├── README.md ├── train_net.py └── .gitignore /centernet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .centernet import CenterNet 3 | -------------------------------------------------------------------------------- /centernet/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from detectron2.data.transforms import * 2 | 3 | from .transform_centeraffine import * 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | # This is an example .flake8 config, used when developing *Black* itself. 2 | # Keep in sync with setup.cfg which is used for source packages. 3 | 4 | [flake8] 5 | ignore = W503, E203, E221, C901 6 | max-line-length = 100 7 | max-complexity = 18 8 | select = B,C,E,F,W,T4,B9 9 | exclude = build,__init__.py 10 | -------------------------------------------------------------------------------- /configs/Base-CenterNet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "CenterNet" 3 | MASK_ON: False 4 | BACKBONE: 5 | NAME: "build_torch_backbone" 6 | RESNETS: 7 | OUT_FEATURES: ["res3", "res4", "res5"] 8 | FPN: 9 | IN_FEATURES: ["res3", "res4", "res5"] 10 | CENTERNET: 11 | DECONV_CHANNEL: [512, 256, 128, 64] 12 | DECONV_KERNEL: [4, 4, 4] 13 | NUM_CLASSES: 80 14 | MODULATE_DEFORM: True 15 | BIAS_VALUE: -2.19 16 | DOWN_SCALE: 4 17 | MIN_OVERLAP: 0.7 18 | TENSOR_DIM: 128 19 | DATASETS: 20 | TRAIN: ("coco_2017_train",) 21 | TEST: ("coco_2017_val",) 22 | SOLVER: 23 | IMS_PER_BATCH: 128 24 | BASE_LR: 0.02 # Note that RetinaNet uses a different default learning rate 25 | STEPS: (30000, 40000) 26 | MAX_ITER: 45000 27 | INPUT: 28 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 29 | VERSION: 2 30 | -------------------------------------------------------------------------------- /configs/centernet_r_18_C4_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-CenterNet.yaml" 2 | MODEL: 3 | # WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | # WEIGHTS: 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 5 | 6 | META_ARCHITECTURE: "CenterNet" 7 | BACKBONE: 8 | NAME: "build_torch_backbone" 9 | RESNETS: 10 | DEPTH: 18 11 | OUT_FEATURES: ["res5"] 12 | WEIGHTS: "" 13 | PIXEL_MEAN: [0.485, 0.456, 0.406] 14 | PIXEL_STD: [0.229, 0.224, 0.225] 15 | 16 | CENTERNET: 17 | DECONV_CHANNEL: [512, 256, 128, 64] 18 | DECONV_KERNEL: [4, 4, 4] 19 | NUM_CLASSES: 80 20 | MODULATE_DEFORM: True 21 | BIAS_VALUE: -2.19 22 | DOWN_SCALE: 4 23 | MIN_OVERLAP: 0.7 24 | TENSOR_DIM: 128 25 | 26 | SOLVER: 27 | STEPS: (81000, 108000) 28 | MAX_ITER: 126000 29 | 30 | INPUT: 31 | FORMAT: 'RGB' 32 | MIN_SIZE_TEST: 0 33 | 34 | -------------------------------------------------------------------------------- 
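These YAML files are not consumed directly: `setup()` in `train_net.py` first loads the detectron2 defaults, adds the `MODEL.CENTERNET.*` keys from `centernet/config.py`, and only then merges the chosen YAML file (its `_BASE_` entry is resolved recursively) together with any command-line overrides. A minimal sketch of that assembly, useful for inspecting the resolved options from a Python shell; it assumes the repo root as working directory and the override value is just an example:

```python
# Sketch: resolve a full config the same way setup() in train_net.py does.
from detectron2.config import get_cfg

from centernet.config import add_centernet_config

cfg = get_cfg()                                           # detectron2 defaults
add_centernet_config(cfg)                                 # add the MODEL.CENTERNET.* keys
cfg.merge_from_file("configs/centernet_r_18_C4_1x.yaml")  # _BASE_ is resolved recursively
cfg.merge_from_list(["SOLVER.IMS_PER_BATCH", "16"])       # command-line style override
print(cfg.MODEL.CENTERNET.DECONV_CHANNEL, cfg.SOLVER.IMS_PER_BATCH)
```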
/configs/centernet_r_50_C4_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-CenterNet.yaml" 2 | MODEL: 3 | # WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | # WEIGHTS: 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 5 | 6 | META_ARCHITECTURE: "CenterNet" 7 | BACKBONE: 8 | NAME: "build_torch_backbone" 9 | RESNETS: 10 | DEPTH: 50 11 | OUT_FEATURES: ["res5"] 12 | # WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 13 | PIXEL_MEAN: [0.485, 0.456, 0.406] 14 | PIXEL_STD: [0.229, 0.224, 0.225] 15 | 16 | CENTERNET: 17 | DECONV_CHANNEL: [2048, 256, 128, 64] 18 | DECONV_KERNEL: [4, 4, 4] 19 | NUM_CLASSES: 80 20 | MODULATE_DEFORM: True 21 | BIAS_VALUE: -2.19 22 | DOWN_SCALE: 4 23 | MIN_OVERLAP: 0.7 24 | TENSOR_DIM: 128 25 | 26 | SOLVER: 27 | STEPS: (81000, 108000) 28 | MAX_ITER: 126000 29 | 30 | INPUT: 31 | FORMAT: 'RGB' 32 | -------------------------------------------------------------------------------- /dev/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | # Run this script at project root by "./dev/linter.sh" before you commit 5 | 6 | vergte() { 7 | [ "$2" = "$(echo -e "$1\n$2" | sort -V | head -n1)" ] 8 | } 9 | 10 | { 11 | black --version | grep "19.3b0" > /dev/null 12 | } || { 13 | echo "Linter requires black==19.3b0 !" 14 | exit 1 15 | } 16 | 17 | ISORT_TARGET_VERSION="4.3.21" 18 | ISORT_VERSION=$(isort -v | grep VERSION | awk '{print $2}') 19 | vergte "$ISORT_VERSION" "$ISORT_TARGET_VERSION" || { 20 | echo "Linter requires isort>=${ISORT_TARGET_VERSION} !" 21 | exit 1 22 | } 23 | 24 | set -v 25 | 26 | echo "Running isort ..." 27 | isort -y -sp . --atomic 28 | 29 | echo "Running black ..." 30 | black -l 100 . 31 | 32 | echo "Running flake8 ..." 33 | if [ -x "$(command -v flake8-3)" ]; then 34 | flake8-3 . 35 | else 36 | python3 -m flake8 . 37 | fi 38 | 39 | # echo "Running mypy ..." 40 | # Pytorch does not have enough type annotations 41 | # mypy detectron2/solver detectron2/structures detectron2/config 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Li Bin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /centernet/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_centernet_config(cfg): 8 | """ 9 | Add config for CenterNet. 10 | """ 11 | _C = cfg 12 | 13 | _C.MODEL.CENTERNET = CN() 14 | _C.MODEL.CENTERNET.DECONV_CHANNEL = [512, 256, 128, 64] 15 | _C.MODEL.CENTERNET.DECONV_KERNEL = [4, 4, 4] 16 | _C.MODEL.CENTERNET.NUM_CLASSES = 80 17 | _C.MODEL.CENTERNET.MODULATE_DEFORM = True 18 | _C.MODEL.CENTERNET.BIAS_VALUE = -2.19 19 | _C.MODEL.CENTERNET.DOWN_SCALE = 4 20 | _C.MODEL.CENTERNET.MIN_OVERLAP = 0.7 21 | _C.MODEL.CENTERNET.TENSOR_DIM = 128 22 | _C.MODEL.CENTERNET.IN_FEATURES = ["res5"] 23 | _C.MODEL.CENTERNET.OUTPUT_SIZE = [128, 128] 24 | _C.MODEL.CENTERNET.TRAIN_PIPELINES = [ 25 | ("CenterAffine", dict(boarder=128, output_size=(512, 512), random_aug=True)), 26 | ("RandomFlip", dict()), 27 | ("RandomBrightness", dict(intensity_min=0.6, intensity_max=1.4)), 28 | ("RandomContrast", dict(intensity_min=0.6, intensity_max=1.4)), 29 | ("RandomSaturation", dict(intensity_min=0.6, intensity_max=1.4)), 30 | ("RandomLighting", dict(scale=0.1)), 31 | ] 32 | _C.MODEL.CENTERNET.TEST_PIPELINES = [] 33 | _C.MODEL.CENTERNET.LOSS = CN() 34 | _C.MODEL.CENTERNET.LOSS.CLS_WEIGHT = 1 35 | _C.MODEL.CENTERNET.LOSS.WH_WEIGHT = 0.1 36 | _C.MODEL.CENTERNET.LOSS.REG_WEIGHT = 1 37 | _C.INPUT.FORMAT = "RGB" 38 | -------------------------------------------------------------------------------- /centernet/centernet_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SingleHead(nn.Module): 6 | def __init__(self, in_channel, out_channel, bias_fill=False, bias_value=0): 7 | super(SingleHead, self).__init__() 8 | self.feat_conv = nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1) 9 | self.relu = nn.ReLU() 10 | self.out_conv = nn.Conv2d(in_channel, out_channel, kernel_size=1) 11 | if bias_fill: 12 | self.out_conv.bias.data.fill_(bias_value) 13 | 14 | def forward(self, x): 15 | x = self.feat_conv(x) 16 | x = self.relu(x) 17 | x = self.out_conv(x) 18 | return x 19 | 20 | 21 | class CenternetHead(nn.Module): 22 | """ 23 | The head used in CenterNet for object classification and box regression. 24 | It has three subnets, with a common structure but separate parameters. 25 | """ 26 | 27 | def __init__(self, cfg): 28 | super(CenternetHead, self).__init__() 29 | self.cls_head = SingleHead( 30 | 64, 31 | cfg.MODEL.CENTERNET.NUM_CLASSES, 32 | bias_fill=True, 33 | bias_value=cfg.MODEL.CENTERNET.BIAS_VALUE, 34 | ) 35 | self.wh_head = SingleHead(64, 2) 36 | self.reg_head = SingleHead(64, 2) 37 | 38 | def forward(self, x): 39 | cls = self.cls_head(x) 40 | cls = torch.sigmoid(cls) 41 | wh = self.wh_head(x) 42 | reg = self.reg_head(x) 43 | pred = {"cls": cls, "wh": wh, "reg": reg} 44 | return pred 45 | -------------------------------------------------------------------------------- /centernet/defaults.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | """ 5 | This file contains components with some default boilerplate logic users may need 6 | in training / testing.
They will not work for everyone, but many users may find them useful. 7 | The behavior of functions/classes in this file is subject to change, 8 | since they are meant to represent the "common default behavior" people need in their projects. 9 | """ 10 | 11 | import logging 12 | 13 | from detectron2.data import build_detection_test_loader, build_detection_train_loader 14 | from detectron2.engine.defaults import DefaultTrainer 15 | 16 | # from detectron2.modeling import build_model 17 | from centernet.centernet import build_model 18 | from centernet.dataset_mapper import DatasetMapper 19 | 20 | __all__ = ["DefaultTrainer2"] 21 | 22 | 23 | class DefaultTrainer2(DefaultTrainer): 24 | def __init__(self, cfg): 25 | 26 | super().__init__(cfg) 27 | 28 | @classmethod 29 | def build_model(cls, cfg): 30 | """ 31 | Returns: 32 | torch.nn.Module: 33 | It now calls :func:`centernet.centernet.build_model`. 34 | Overwrite it if you'd like a different model. 35 | """ 36 | model = build_model(cfg) 37 | logger = logging.getLogger(__name__) 38 | logger.info("Model:\n{}".format(model)) 39 | return model 40 | 41 | @classmethod 42 | def build_test_loader(cls, cfg, dataset_name): 43 | return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False)) 44 | 45 | @classmethod 46 | def build_train_loader(cls, cfg): 47 | return build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True)) 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CenterNet-better-plus 2 | 3 | This repo is implemented based on [detectron2](https://github.com/facebookresearch/detectron2) and [CenterNet-better](https://github.com/FateScript/CenterNet-better). 4 | 5 | ## Requirements 6 | 7 | - Python >= 3.6 8 | - PyTorch >= 1.4 9 | - torchvision that matches the PyTorch installation. 10 | - OpenCV 11 | - pycocotools 12 | 13 | ```shell 14 | pip install cython; pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 15 | ``` 16 | 17 | - GCC >= 4.9 18 | 19 | ```shell 20 | gcc --version 21 | ``` 22 | 23 | - detectron2 24 | 25 | ```shell 26 | pip install -U 'git+https://github.com/facebookresearch/detectron2.git' 27 | ``` 28 | 29 | ## Training 30 | 31 | ```shell 32 | python train_net.py --num-gpus 8 --config-file configs/centernet_r_18_C4_1x.yaml 33 | ``` 34 | 35 | ## Testing and Evaluation 36 | 37 | ```shell 38 | python train_net.py --num-gpus 8 --config-file configs/centernet_r_18_C4_1x.yaml --eval-only MODEL.WEIGHTS model_0007999.pth 39 | ``` 40 | 41 | ## Performance 42 | 43 | This repo needs less training time to reach better performance: it spends roughly half the training time and gets 1~2 pts higher mAP compared with the old repo. Here are the performance tables.
44 | 45 | Backbone ResNet-50 46 | 47 | | Code | mAP | 48 | | ---------------- | ---- | 49 | | ours | | 50 | | centernet-better | 35.1 | 51 | 52 | Backbone ResNet-18 53 | | Code | mAP | 54 | | ---------------- | ---- | 55 | | ours | 29.7 | 56 | | centernet-better | 29.8 | 57 | 58 | 59 | ## What's coming 60 | 61 | - [ ] Support DLA backbone 62 | - [ ] Support Hourglass backbone 63 | - [ ] Support KeyPoints dataset 64 | 65 | ## Acknowledgement 66 | 67 | - [detectron2](https://github.com/facebookresearch/detectron2) 68 | - [CenterNet](https://github.com/xingyizhou/CenterNet) 69 | - [CenterNet-better](https://github.com/FateScript/CenterNet-better) 70 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | 4 | from detectron2.checkpoint import DetectionCheckpointer 5 | from detectron2.config import get_cfg 6 | from detectron2.engine import default_argument_parser, default_setup, launch 7 | from detectron2.evaluation import COCOEvaluator 8 | from detectron2.evaluation.testing import verify_results 9 | from detectron2.utils import comm 10 | 11 | from centernet.config import add_centernet_config 12 | from centernet.defaults import DefaultTrainer2 13 | 14 | 15 | class Trainer(DefaultTrainer2): 16 | @classmethod 17 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 18 | if output_folder is None: 19 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 20 | return COCOEvaluator(dataset_name, cfg, True, output_folder) 21 | 22 | 23 | def setup(args): 24 | """ 25 | Create configs and perform basic setups. 26 | """ 27 | cfg = get_cfg() 28 | add_centernet_config(cfg) 29 | cfg.merge_from_file(args.config_file) 30 | cfg.merge_from_list(args.opts) 31 | cfg.freeze() 32 | default_setup(cfg, args) 33 | return cfg 34 | 35 | 36 | def main(args): 37 | cfg = setup(args) 38 | 39 | if args.eval_only: 40 | model = Trainer.build_model(cfg) 41 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 42 | cfg.MODEL.WEIGHTS, resume=args.resume 43 | ) 44 | res = Trainer.test(cfg, model) 45 | if comm.is_main_process(): 46 | verify_results(cfg, res) 47 | return res 48 | """ 49 | If you'd like to do anything fancier than the standard training logic, 50 | consider writing your own training loop or subclassing the trainer.
51 | """ 52 | trainer = Trainer(cfg) 53 | trainer.resume_or_load(resume=args.resume) 54 | 55 | return trainer.train() 56 | 57 | 58 | if __name__ == "__main__": 59 | args = default_argument_parser().parse_args() 60 | print("Command Line Args:", args) 61 | launch( 62 | main, 63 | args.num_gpus, 64 | num_machines=args.num_machines, 65 | machine_rank=args.machine_rank, 66 | dist_url=args.dist_url, 67 | args=(args,), 68 | ) 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | .DS_Store 133 | .vscode/ 134 | datasets/ 135 | output/ 136 | *.pth -------------------------------------------------------------------------------- /centernet/centernet_decode.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .transforms import CenterAffine 7 | 8 | 9 | def gather_feature(fmap, index, mask=None, use_transform=False): 10 | if use_transform: 11 | # change a (N, C, H, W) tenor to (N, HxW, C) shape 12 | batch, channel = fmap.shape[:2] 13 | fmap = fmap.view(batch, channel, -1).permute((0, 2, 1)).contiguous() 14 | 15 | dim = fmap.size(-1) 16 | index = index.unsqueeze(len(index.shape)).expand(*index.shape, dim) 17 | fmap = fmap.gather(dim=1, index=index) 18 | if mask is not None: 19 | # this part is not called in Res18 dcn COCO 20 | mask = mask.unsqueeze(2).expand_as(fmap) 21 | fmap = fmap[mask] 22 | fmap = fmap.reshape(-1, dim) 23 | return fmap 24 | 25 | 26 | class CenterNetDecoder(object): 27 | @staticmethod 28 | def decode(fmap, wh, reg=None, cat_spec_wh=False, K=100): 29 | r""" 30 | decode output feature map to detection results 31 | 32 | Args: 33 | fmap(Tensor): output feature map 34 | wh(Tensor): tensor that represents predicted width-height 35 | reg(Tensor): tensor that represens regression of center points 36 | cat_spec_wh(bool): whether apply gather on tensor `wh` or not 37 | K(int): topk value 38 | """ 39 | batch, channel, height, width = fmap.shape 40 | 41 | fmap = CenterNetDecoder.pseudo_nms(fmap) 42 | 43 | scores, index, clses, ys, xs = CenterNetDecoder.topk_score(fmap, K=K) 44 | if reg is not None: 45 | reg = gather_feature(reg, index, use_transform=True) 46 | reg = reg.reshape(batch, K, 2) 47 | xs = xs.view(batch, K, 1) + reg[:, :, 0:1] 48 | ys = ys.view(batch, K, 1) + reg[:, :, 1:2] 49 | else: 50 | xs = xs.view(batch, K, 1) + 0.5 51 | ys = ys.view(batch, K, 1) + 0.5 52 | wh = gather_feature(wh, index, use_transform=True) 53 | 54 | if cat_spec_wh: 55 | wh = wh.view(batch, K, channel, 2) 56 | clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long() 57 | wh = wh.gather(2, clses_ind).reshape(batch, K, 2) 58 | else: 59 | wh = wh.reshape(batch, K, 2) 60 | 61 | clses = clses.reshape(batch, K, 1).float() 62 | scores = scores.reshape(batch, K, 1) 63 | 64 | half_w, half_h = wh[..., 0:1] / 2, wh[..., 1:2] / 2 65 | bboxes = torch.cat([xs - half_w, ys - half_h, xs + half_w, ys + half_h], dim=2) 66 | 67 | detections = (bboxes, scores, clses) 68 | 69 | return detections 70 | 71 | @staticmethod 72 | def transform_boxes(boxes, img_info, scale=1): 73 | r""" 74 | transform predicted boxes to target boxes 75 | 76 | Args: 77 | boxes(Tensor): torch Tensor with (Batch, N, 4) shape 78 | img_info(dict): dict contains all information of original image 79 | scale(float): used for multiscale testing 80 | """ 81 | boxes = 
boxes.cpu().numpy().reshape(-1, 4) 82 | 83 | center = img_info["center"] 84 | size = img_info["size"] 85 | output_size = (img_info["width"], img_info["height"]) 86 | src, dst = CenterAffine.generate_src_and_dst(center, size, output_size) 87 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 88 | 89 | coords = boxes.reshape(-1, 2) 90 | aug_coords = np.column_stack((coords, np.ones(coords.shape[0]))) 91 | target_boxes = np.dot(aug_coords, trans.T).reshape(-1, 4) 92 | return target_boxes 93 | 94 | @staticmethod 95 | def pseudo_nms(fmap, pool_size=3): 96 | r""" 97 | apply max pooling to get the same effect of nms 98 | 99 | Args: 100 | fmap(Tensor): output tensor of previous step 101 | pool_size(int): size of max-pooling 102 | """ 103 | pad = (pool_size - 1) // 2 104 | fmap_max = F.max_pool2d(fmap, pool_size, stride=1, padding=pad) 105 | keep = (fmap_max == fmap).float() 106 | return fmap * keep 107 | 108 | @staticmethod 109 | def topk_score(scores, K=40): 110 | """ 111 | get top K point in score map 112 | """ 113 | batch, channel, height, width = scores.shape 114 | 115 | # get topk score and its index in every H x W(channel dim) feature map 116 | topk_scores, topk_inds = torch.topk(scores.reshape(batch, channel, -1), K) 117 | 118 | topk_inds = topk_inds % (height * width) 119 | topk_ys = (topk_inds / width).int().float() 120 | topk_xs = (topk_inds % width).int().float() 121 | 122 | # get all topk in in a batch 123 | topk_score, index = torch.topk(topk_scores.reshape(batch, -1), K) 124 | # div by K because index is grouped by K(C x K shape) 125 | topk_clses = (index / K).int() 126 | topk_inds = gather_feature(topk_inds.view(batch, -1, 1), index).reshape(batch, K) 127 | topk_ys = gather_feature(topk_ys.reshape(batch, -1, 1), index).reshape(batch, K) 128 | topk_xs = gather_feature(topk_xs.reshape(batch, -1, 1), index).reshape(batch, K) 129 | 130 | return topk_score, topk_inds, topk_clses, topk_ys, topk_xs 131 | -------------------------------------------------------------------------------- /centernet/transforms/transform_centeraffine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | # Modified by Feng Wang. 4 | # File: transformer.py 5 | 6 | import cv2 7 | import numpy as np 8 | from detectron2.data.transforms import Transform, TransformGen 9 | 10 | __all__ = ["CenterAffine", "AffineTransform"] 11 | 12 | 13 | class AffineTransform(Transform): 14 | """ 15 | Augmentation from CenterNet 16 | """ 17 | 18 | def __init__(self, src, dst, output_size): 19 | """ 20 | output_size:(w, h) 21 | """ 22 | super().__init__() 23 | affine = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 24 | self._set_attributes(locals()) 25 | 26 | def apply_image(self, img: np.ndarray) -> np.ndarray: 27 | """ 28 | Apply AffineTransform for the image(s). 29 | Args: 30 | img (ndarray): of shape HxW, HxWxC, or NxHxWxC. The array can be 31 | of type uint8 in range [0, 255], or floating point in range 32 | [0, 1] or [0, 255]. 33 | Returns: 34 | ndarray: the image(s) after applying affine transform. 35 | """ 36 | return cv2.warpAffine(img, self.affine, self.output_size, flags=cv2.INTER_LINEAR) 37 | 38 | def apply_coords(self, coords: np.ndarray) -> np.ndarray: 39 | """ 40 | Affine the coordinates. 41 | Args: 42 | coords (ndarray): floating point array of shape Nx2. Each row is 43 | (x, y). 44 | Returns: 45 | ndarray: the flipped coordinates. 
46 | Note: 47 | The inputs are floating point coordinates, not pixel indices. 48 | The coordinates are transformed by the affine matrix and then 49 | clipped to the output size. 50 | """ 51 | # aug_coord (N, 3) shape, self.affine (2, 3) shape 52 | w, h = self.output_size 53 | aug_coords = np.column_stack((coords, np.ones(coords.shape[0]))) 54 | coords = np.dot(aug_coords, self.affine.T) 55 | coords[..., 0] = np.clip(coords[..., 0], 0, w - 1) 56 | coords[..., 1] = np.clip(coords[..., 1], 0, h - 1) 57 | return coords 58 | 59 | 60 | class CenterAffine(TransformGen): 61 | """ 62 | Affine Transform for CenterNet 63 | """ 64 | 65 | def __init__(self, boarder, output_size, random_aug=True): 66 | """ 67 | Args: 68 | boarder(int): boarder size of image 69 | output_size(tuple): a tuple represents (width, height) of image 70 | random_aug(bool): whether apply random augmentation on annos or not 71 | """ 72 | super().__init__() 73 | self._init(locals()) 74 | 75 | def get_transform(self, img): 76 | """ 77 | generate one `AffineTransform` for input image 78 | """ 79 | img_shape = img.shape[:2] 80 | center, scale = self.generate_center_and_scale(img_shape) 81 | src, dst = self.generate_src_and_dst(center, scale, self.output_size) 82 | return AffineTransform(src, dst, self.output_size) 83 | 84 | @staticmethod 85 | def _get_boarder(boarder, size): 86 | """ 87 | decide the boarder size of image 88 | """ 89 | # NOTE This func may be reimplemented in the future 90 | i = 1 91 | size //= 2 92 | while size <= boarder // i: 93 | i *= 2 94 | return boarder // i 95 | 96 | def generate_center_and_scale(self, img_shape): 97 | r""" 98 | generate center and scale for image randomly 99 | Args: 100 | img_shape(tuple): a tuple represents (height, width) of image 101 | """ 102 | height, width = img_shape 103 | center = np.array([width / 2, height / 2], dtype=np.float32) 104 | scale = float(max(img_shape)) 105 | if self.random_aug: 106 | scale = scale * np.random.choice(np.arange(0.6, 1.4, 0.1)) 107 | h_boarder = self._get_boarder(self.boarder, height) 108 | w_boarder = self._get_boarder(self.boarder, width) 109 | center[0] = np.random.randint(low=w_boarder, high=width - w_boarder) 110 | center[1] = np.random.randint(low=h_boarder, high=height - h_boarder) 111 | else: 112 | raise NotImplementedError("Non-random augmentation not implemented") 113 | 114 | return center, scale 115 | 116 | @staticmethod 117 | def generate_src_and_dst(center, size, output_size): 118 | r""" 119 | generate source and destination for affine transform 120 | """ 121 | if not isinstance(size, np.ndarray) and not isinstance(size, list): 122 | size = np.array([size, size], dtype=np.float32) 123 | src = np.zeros((3, 2), dtype=np.float32) 124 | src_w = size[0] 125 | src_dir = [0, src_w * -0.5] 126 | src[0, :] = center 127 | src[1, :] = src[0, :] + src_dir 128 | src[2, :] = src[1, :] + (src_dir[1], -src_dir[0]) 129 | 130 | dst = np.zeros((3, 2), dtype=np.float32) 131 | dst_w, dst_h = output_size 132 | dst_dir = [0, dst_w * -0.5] 133 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 134 | dst[1, :] = dst[0, :] + dst_dir 135 | dst[2, :] = dst[1, :] + (dst_dir[1], -dst_dir[0]) 136 | 137 | return src, dst 138 | -------------------------------------------------------------------------------- /centernet/centernet_gt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class CenterNetGT(object): 6 | @staticmethod 7 | def generate(config, batched_input): 8 | box_scale = 1 / config.MODEL.CENTERNET.DOWN_SCALE 9
| num_classes = config.MODEL.CENTERNET.NUM_CLASSES 10 | output_size = config.MODEL.CENTERNET.OUTPUT_SIZE 11 | min_overlap = config.MODEL.CENTERNET.MIN_OVERLAP 12 | tensor_dim = config.MODEL.CENTERNET.TENSOR_DIM 13 | 14 | scoremap_list, wh_list, reg_list, reg_mask_list, index_list = [[] for i in range(5)] 15 | for data in batched_input: 16 | # img_size = (data['height'], data['width']) 17 | 18 | bbox_dict = data["instances"].get_fields() 19 | 20 | # init gt tensors 21 | gt_scoremap = torch.zeros(num_classes, *output_size) 22 | gt_wh = torch.zeros(tensor_dim, 2) 23 | gt_reg = torch.zeros_like(gt_wh) 24 | reg_mask = torch.zeros(tensor_dim) 25 | gt_index = torch.zeros(tensor_dim) 26 | # pass 27 | 28 | boxes, classes = bbox_dict["gt_boxes"], bbox_dict["gt_classes"] 29 | num_boxes = boxes.tensor.shape[0] 30 | boxes.scale(box_scale, box_scale) 31 | 32 | centers = boxes.get_centers() 33 | centers_int = centers.to(torch.int32) 34 | gt_index[:num_boxes] = centers_int[..., 1] * output_size[0] + centers_int[..., 0] 35 | gt_reg[:num_boxes] = centers - centers_int 36 | reg_mask[:num_boxes] = 1 37 | 38 | wh = torch.zeros_like(centers) 39 | box_tensor = boxes.tensor 40 | wh[..., 0] = box_tensor[..., 2] - box_tensor[..., 0] 41 | wh[..., 1] = box_tensor[..., 3] - box_tensor[..., 1] 42 | CenterNetGT.generate_score_map(gt_scoremap, classes, wh, centers_int, min_overlap) 43 | gt_wh[:num_boxes] = wh 44 | 45 | scoremap_list.append(gt_scoremap) 46 | wh_list.append(gt_wh) 47 | reg_list.append(gt_reg) 48 | reg_mask_list.append(reg_mask) 49 | index_list.append(gt_index) 50 | 51 | gt_dict = { 52 | "score_map": torch.stack(scoremap_list, dim=0), 53 | "wh": torch.stack(wh_list, dim=0), 54 | "reg": torch.stack(reg_list, dim=0), 55 | "reg_mask": torch.stack(reg_mask_list, dim=0), 56 | "index": torch.stack(index_list, dim=0), 57 | } 58 | return gt_dict 59 | 60 | @staticmethod 61 | def generate_score_map(fmap, gt_class, gt_wh, centers_int, min_overlap): 62 | radius = CenterNetGT.get_gaussian_radius(gt_wh, min_overlap) 63 | radius = torch.clamp_min(radius, 0) 64 | radius = radius.type(torch.int).cpu().numpy() 65 | for i in range(gt_class.shape[0]): 66 | channel_index = gt_class[i] 67 | CenterNetGT.draw_gaussian(fmap[channel_index], centers_int[i], radius[i]) 68 | 69 | @staticmethod 70 | def get_gaussian_radius(box_size, min_overlap): 71 | """ 72 | copyed from CornerNet 73 | box_size (w, h), it could be a torch.Tensor, numpy.ndarray, list or tuple 74 | notice: we are using a bug-version, please refer to fix bug version in CornerNet 75 | """ 76 | box_tensor = torch.Tensor(box_size) 77 | width, height = box_tensor[..., 0], box_tensor[..., 1] 78 | 79 | a1 = 1 80 | b1 = height + width 81 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 82 | sq1 = torch.sqrt(b1 ** 2 - 4 * a1 * c1) 83 | r1 = (b1 + sq1) / 2 84 | 85 | a2 = 4 86 | b2 = 2 * (height + width) 87 | c2 = (1 - min_overlap) * width * height 88 | sq2 = torch.sqrt(b2 ** 2 - 4 * a2 * c2) 89 | r2 = (b2 + sq2) / 2 90 | 91 | a3 = 4 * min_overlap 92 | b3 = -2 * min_overlap * (height + width) 93 | c3 = (min_overlap - 1) * width * height 94 | sq3 = torch.sqrt(b3 ** 2 - 4 * a3 * c3) 95 | r3 = (b3 + sq3) / 2 96 | 97 | return torch.min(r1, torch.min(r2, r3)) 98 | 99 | @staticmethod 100 | def gaussian2D(radius, sigma=1): 101 | # m, n = [(s - 1.) / 2. 
for s in shape] 102 | m, n = radius 103 | y, x = np.ogrid[-m : m + 1, -n : n + 1] 104 | 105 | gauss = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 106 | gauss[gauss < np.finfo(gauss.dtype).eps * gauss.max()] = 0 107 | return gauss 108 | 109 | @staticmethod 110 | def draw_gaussian(fmap, center, radius, k=1): 111 | diameter = 2 * radius + 1 112 | gaussian = CenterNetGT.gaussian2D((radius, radius), sigma=diameter / 6) 113 | gaussian = torch.Tensor(gaussian) 114 | x, y = int(center[0]), int(center[1]) 115 | height, width = fmap.shape[:2] 116 | 117 | left, right = min(x, radius), min(width - x, radius + 1) 118 | top, bottom = min(y, radius), min(height - y, radius + 1) 119 | 120 | masked_fmap = fmap[y - top : y + bottom, x - left : x + right] 121 | masked_gaussian = gaussian[radius - top : radius + bottom, radius - left : radius + right] 122 | if min(masked_gaussian.shape) > 0 and min(masked_fmap.shape) > 0: 123 | masked_fmap = torch.max(masked_fmap, masked_gaussian * k) 124 | fmap[y - top : y + bottom, x - left : x + right] = masked_fmap 125 | # return fmap 126 | -------------------------------------------------------------------------------- /centernet/centernet_deconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from detectron2.layers.deform_conv import DeformConv, ModulatedDeformConv 6 | 7 | 8 | class DeformConvWithOff(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | out_channels, 13 | kernel_size=3, 14 | stride=1, 15 | padding=1, 16 | dilation=1, 17 | deformable_groups=1, 18 | ): 19 | super(DeformConvWithOff, self).__init__() 20 | self.offset_conv = nn.Conv2d( 21 | in_channels, 22 | deformable_groups * 2 * kernel_size * kernel_size, 23 | kernel_size=kernel_size, 24 | stride=stride, 25 | padding=padding, 26 | ) 27 | self.dcn = DeformConv( 28 | in_channels, 29 | out_channels, 30 | kernel_size=kernel_size, 31 | stride=stride, 32 | padding=padding, 33 | dilation=dilation, 34 | deformable_groups=deformable_groups, 35 | ) 36 | 37 | def forward(self, input): 38 | offset = self.offset_conv(input) 39 | output = self.dcn(input, offset) 40 | return output 41 | 42 | 43 | class ModulatedDeformConvWithOff(nn.Module): 44 | def __init__( 45 | self, 46 | in_channels, 47 | out_channels, 48 | kernel_size=3, 49 | stride=1, 50 | padding=1, 51 | dilation=1, 52 | deformable_groups=1, 53 | ): 54 | super(ModulatedDeformConvWithOff, self).__init__() 55 | self.offset_mask_conv = nn.Conv2d( 56 | in_channels, 57 | deformable_groups * 3 * kernel_size * kernel_size, 58 | kernel_size=kernel_size, 59 | stride=stride, 60 | padding=padding, 61 | ) 62 | self.dcnv2 = ModulatedDeformConv( 63 | in_channels, 64 | out_channels, 65 | kernel_size=kernel_size, 66 | stride=stride, 67 | padding=padding, 68 | dilation=dilation, 69 | deformable_groups=deformable_groups, 70 | ) 71 | 72 | def forward(self, input): 73 | x = self.offset_mask_conv(input) 74 | o1, o2, mask = torch.chunk(x, 3, dim=1) 75 | offset = torch.cat((o1, o2), dim=1) 76 | mask = torch.sigmoid(mask) 77 | output = self.dcnv2(input, offset, mask) 78 | return output 79 | 80 | 81 | class DeconvLayer(nn.Module): 82 | def __init__( 83 | self, 84 | in_planes, 85 | out_planes, 86 | deconv_kernel, 87 | deconv_stride=2, 88 | deconv_pad=1, 89 | deconv_out_pad=0, 90 | modulate_deform=True, 91 | ): 92 | super(DeconvLayer, self).__init__() 93 | if modulate_deform: 94 | self.dcn = ModulatedDeformConvWithOff( 95 | in_planes, out_planes, kernel_size=3, 
deformable_groups=1 96 | ) 97 | else: 98 | self.dcn = DeformConvWithOff(in_planes, out_planes, kernel_size=3, deformable_groups=1) 99 | 100 | self.dcn_bn = nn.BatchNorm2d(out_planes) 101 | self.up_sample = nn.ConvTranspose2d( 102 | in_channels=out_planes, 103 | out_channels=out_planes, 104 | kernel_size=deconv_kernel, 105 | stride=deconv_stride, 106 | padding=deconv_pad, 107 | output_padding=deconv_out_pad, 108 | bias=False, 109 | ) 110 | self._deconv_init() 111 | self.up_bn = nn.BatchNorm2d(out_planes) 112 | self.relu = nn.ReLU() 113 | 114 | def forward(self, x): 115 | x = self.dcn(x) 116 | x = self.dcn_bn(x) 117 | x = self.relu(x) 118 | x = self.up_sample(x) 119 | x = self.up_bn(x) 120 | x = self.relu(x) 121 | return x 122 | 123 | def _deconv_init(self): 124 | w = self.up_sample.weight.data 125 | f = math.ceil(w.size(2) / 2) 126 | c = (2 * f - 1 - f % 2) / (2.0 * f) 127 | for i in range(w.size(2)): 128 | for j in range(w.size(3)): 129 | w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) 130 | for c in range(1, w.size(0)): 131 | w[c, 0, :, :] = w[0, 0, :, :] 132 | 133 | 134 | class CenternetDeconv(nn.Module): 135 | """ 136 | The deconvolution (upsampling) part of CenterNet: it stacks three DeconvLayer blocks 137 | (deformable conv + transposed conv) to recover spatial resolution from the backbone feature map. 138 | """ 139 | 140 | def __init__(self, cfg): 141 | super(CenternetDeconv, self).__init__() 142 | # modify into config 143 | channels = cfg.MODEL.CENTERNET.DECONV_CHANNEL 144 | deconv_kernel = cfg.MODEL.CENTERNET.DECONV_KERNEL 145 | modulate_deform = cfg.MODEL.CENTERNET.MODULATE_DEFORM 146 | self.deconv1 = DeconvLayer( 147 | channels[0], 148 | channels[1], 149 | deconv_kernel=deconv_kernel[0], 150 | modulate_deform=modulate_deform, 151 | ) 152 | self.deconv2 = DeconvLayer( 153 | channels[1], 154 | channels[2], 155 | deconv_kernel=deconv_kernel[1], 156 | modulate_deform=modulate_deform, 157 | ) 158 | self.deconv3 = DeconvLayer( 159 | channels[2], 160 | channels[3], 161 | deconv_kernel=deconv_kernel[2], 162 | modulate_deform=modulate_deform, 163 | ) 164 | 165 | def forward(self, x): 166 | x = self.deconv1(x) 167 | x = self.deconv2(x) 168 | x = self.deconv3(x) 169 | return x 170 | -------------------------------------------------------------------------------- /centernet/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | from detectron2.data import detection_utils as utils 7 | from fvcore.common.file_io import PathManager 8 | from PIL import Image 9 | 10 | from . import transforms as T 11 | 12 | """ 13 | This file contains the default mapping that's applied to "dataset dicts". 14 | """ 15 | 16 | __all__ = ["DatasetMapper"] 17 | 18 | 19 | def build_transform_gen(cfg, is_train): 20 | """ 21 | Create a list of :class:`TransformGen` from config. 22 | The generators are read from cfg.MODEL.CENTERNET.TRAIN_PIPELINES / TEST_PIPELINES.
23 | 24 | Returns: 25 | list[TransformGen] 26 | """ 27 | logger = logging.getLogger(__name__) 28 | 29 | tfm_gens = [] 30 | 31 | if is_train: 32 | for (aug, args) in cfg.MODEL.CENTERNET.TRAIN_PIPELINES: 33 | if aug == "ResizeShortestEdge": 34 | check_sample_valid(args) 35 | tfm_gens.append(getattr(T, aug)(**args)) 36 | else: 37 | for (aug, args) in cfg.MODEL.CENTERNET.TEST_PIPELINES: 38 | if aug == "ResizeShortestEdge": 39 | check_sample_valid(args) 40 | tfm_gens.append(getattr(T, aug)(**args)) 41 | 42 | logger.info("TransformGens used: " + str(tfm_gens)) 43 | 44 | return tfm_gens 45 | 46 | 47 | def check_sample_valid(args): 48 | if args["sample_style"] == "range": 49 | assert ( 50 | len(args["min_size"]) == 2 51 | ), f"more than 2 ({len(args['min_size'])}) min_size(s) are provided for ranges" 52 | 53 | 54 | class DatasetMapper: 55 | """ 56 | A callable which takes a dataset dict in centernet Dataset format, 57 | and map it into a format used by the model. 58 | 59 | This is the default callable to be used to map your dataset dict into training data. 60 | You may need to follow it to implement your own one for customized logic. 61 | 62 | The callable currently does the following: 63 | 64 | 1. Read the image from "file_name" 65 | 2. Applies cropping/geometric transforms to the image and annotations 66 | 3. Prepare data and annotations to Tensor and :class:`Instances` 67 | """ 68 | 69 | def __init__(self, cfg, is_train=True): 70 | 71 | if cfg.INPUT.CROP.ENABLED and is_train: 72 | self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE) 73 | logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen)) 74 | else: 75 | self.crop_gen = None 76 | 77 | self.eval_with_gt = cfg.TEST.get("WITH_GT", False) 78 | 79 | self.tfm_gens = build_transform_gen(cfg, is_train) 80 | 81 | # fmt: off 82 | self.img_format = cfg.INPUT.FORMAT 83 | self.mask_on = cfg.MODEL.MASK_ON 84 | self.mask_format = cfg.INPUT.MASK_FORMAT 85 | self.keypoint_on = cfg.MODEL.KEYPOINT_ON 86 | self.load_proposals = cfg.MODEL.LOAD_PROPOSALS 87 | # fmt: on 88 | if self.keypoint_on and is_train: 89 | # Flip only makes sense in training 90 | self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) 91 | else: 92 | self.keypoint_hflip_indices = None 93 | 94 | if self.load_proposals: 95 | self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE 96 | self.proposal_topk = ( 97 | cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN 98 | if is_train 99 | else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST 100 | ) 101 | self.is_train = is_train 102 | 103 | def __call__(self, dataset_dict): 104 | """ 105 | Args: 106 | dataset_dict (dict): Metadata of one image, in centernet Dataset format. 107 | 108 | Returns: 109 | dict: a format that builtin models in centernet accept 110 | """ 111 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 112 | # USER: Write your own image loading if it's not from a file 113 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 114 | utils.check_image_size(dataset_dict, image) 115 | 116 | if "annotations" not in dataset_dict: 117 | image, transforms = T.apply_transform_gens( 118 | ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image 119 | ) 120 | else: 121 | # Crop around an instance if there are instances in the image. 
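# (the generated crop is guaranteed to contain the center of the randomly chosen instance)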
122 | # USER: Remove if you don't use cropping 123 | if self.crop_gen: 124 | crop_tfm = utils.gen_crop_transform_with_instance( 125 | self.crop_gen.get_crop_size(image.shape[:2]), 126 | image.shape[:2], 127 | np.random.choice(dataset_dict["annotations"]), 128 | ) 129 | image = crop_tfm.apply_image(image) 130 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 131 | if self.crop_gen: 132 | transforms = crop_tfm + transforms 133 | 134 | image_shape = image.shape[:2] # h, w 135 | 136 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 137 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 138 | # Therefore it's important to use torch.Tensor. 139 | dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32")) 140 | # Can use uint8 if it turns out to be slow some day 141 | 142 | # USER: Remove if you don't use pre-computed proposals. 143 | if self.load_proposals: 144 | utils.transform_proposals( 145 | dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk 146 | ) 147 | 148 | if not self.is_train and not self.eval_with_gt: 149 | dataset_dict.pop("annotations", None) 150 | dataset_dict.pop("sem_seg_file_name", None) 151 | return dataset_dict 152 | 153 | if "annotations" in dataset_dict: 154 | # USER: Modify this if you want to keep them for some reason. 155 | for anno in dataset_dict["annotations"]: 156 | if not self.mask_on: 157 | anno.pop("segmentation", None) 158 | if not self.keypoint_on: 159 | anno.pop("keypoints", None) 160 | 161 | # USER: Implement additional transformations if you have other types of data 162 | annos = [ 163 | utils.transform_instance_annotations( 164 | obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices 165 | ) 166 | for obj in dataset_dict.pop("annotations") 167 | if obj.get("iscrowd", 0) == 0 168 | ] 169 | instances = utils.annotations_to_instances( 170 | annos, image_shape, mask_format=self.mask_format 171 | ) 172 | # Create a tight bounding box from masks, useful when image is cropped 173 | if self.crop_gen and instances.has("gt_masks"): 174 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 175 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 176 | 177 | # USER: Remove if you don't do semantic/panoptic segmentation. 
178 | if "sem_seg_file_name" in dataset_dict: 179 | with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f: 180 | sem_seg_gt = Image.open(f) 181 | sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8") 182 | sem_seg_gt = transforms.apply_segmentation(sem_seg_gt) 183 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) 184 | dataset_dict["sem_seg"] = sem_seg_gt 185 | return dataset_dict 186 | -------------------------------------------------------------------------------- /centernet/centernet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.models.resnet as resnet 8 | from detectron2.layers import ShapeSpec 9 | 10 | # from centernet.network.backbone import Backbone 11 | from detectron2.modeling import Backbone 12 | from detectron2.modeling.backbone import build_backbone 13 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY 14 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY 15 | from detectron2.structures import Boxes, ImageList, Instances 16 | 17 | from .centernet_decode import CenterNetDecoder, gather_feature 18 | from .centernet_deconv import CenternetDeconv 19 | from .centernet_gt import CenterNetGT 20 | from .centernet_head import CenternetHead 21 | 22 | __all__ = ["CenterNet"] 23 | 24 | _resnet_mapper = {18: resnet.resnet18, 50: resnet.resnet50, 101: resnet.resnet101} 25 | 26 | 27 | class ResnetBackbone(Backbone): 28 | def __init__(self, cfg, input_shape=None, pretrained=True): 29 | super().__init__() 30 | depth = cfg.MODEL.RESNETS.DEPTH 31 | backbone = _resnet_mapper[depth](pretrained=pretrained) 32 | self.stage0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool) 33 | self.stage1 = backbone.layer1 34 | self.stage2 = backbone.layer2 35 | self.stage3 = backbone.layer3 36 | self.stage4 = backbone.layer4 37 | 38 | def forward(self, x): 39 | x = self.stage0(x) 40 | x = self.stage1(x) 41 | x = self.stage2(x) 42 | x = self.stage3(x) 43 | x = self.stage4(x) 44 | return x 45 | 46 | 47 | def reg_l1_loss(output, mask, index, target): 48 | pred = gather_feature(output, index, use_transform=True) 49 | mask = mask.unsqueeze(dim=2).expand_as(pred).float() 50 | # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean') 51 | loss = F.l1_loss(pred * mask, target * mask, reduction="sum") 52 | loss = loss / (mask.sum() + 1e-4) 53 | return loss 54 | 55 | 56 | def modified_focal_loss(pred, gt): 57 | """ 58 | focal loss copied from CenterNet, modified version focal loss 59 | change log: numeric stable version implementation 60 | """ 61 | pos_inds = gt.eq(1).float() 62 | neg_inds = gt.lt(1).float() 63 | 64 | neg_weights = torch.pow(1 - gt, 4) 65 | # clamp min value is set to 1e-12 to maintain the numerical stability 66 | pred = torch.clamp(pred, 1e-12) 67 | 68 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds 69 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds 70 | 71 | num_pos = pos_inds.float().sum() 72 | pos_loss = pos_loss.sum() 73 | neg_loss = neg_loss.sum() 74 | 75 | if num_pos == 0: 76 | loss = -neg_loss 77 | else: 78 | loss = -(pos_loss + neg_loss) / num_pos 79 | return loss 80 | 81 | 82 | @BACKBONE_REGISTRY.register() 83 | def build_torch_backbone(cfg, input_shape=None): 84 | """ 85 | Build a backbone. 
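Note that the returned ResnetBackbone produces a single res5 feature map as a plain tensor rather than a dict of named features (the OUT_FEATURES indexing in CenterNet.forward is therefore left commented out).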
86 | 87 | Returns: 88 | an instance of :class:`Backbone` 89 | """ 90 | if input_shape is None: 91 | input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN)) 92 | 93 | backbone = ResnetBackbone(cfg, input_shape) 94 | assert isinstance(backbone, Backbone) 95 | return backbone 96 | 97 | 98 | @META_ARCH_REGISTRY.register() 99 | class CenterNet(nn.Module): 100 | """ 101 | Implement CenterNet (https://arxiv.org/abs/1904.07850). 102 | """ 103 | 104 | def __init__(self, cfg): 105 | super().__init__() 106 | 107 | self.device = torch.device(cfg.MODEL.DEVICE) 108 | self.cfg = cfg 109 | 110 | # fmt: off 111 | self.num_classes = cfg.MODEL.CENTERNET.NUM_CLASSES 112 | # Loss parameters: 113 | # Inference parameters: 114 | self.max_detections_per_image = cfg.TEST.DETECTIONS_PER_IMAGE 115 | # fmt: on 116 | self.backbone = build_backbone( 117 | cfg, input_shape=ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN)) 118 | ) 119 | self.upsample = CenternetDeconv(cfg) 120 | self.head = CenternetHead(cfg) 121 | 122 | self.mean, self.std = cfg.MODEL.PIXEL_MEAN, cfg.MODEL.PIXEL_STD 123 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) 124 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) 125 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std 126 | 127 | self.to(self.device) 128 | 129 | def forward(self, batched_inputs): 130 | """ 131 | Args: 132 | batched_inputs(list): batched outputs of :class:`DatasetMapper` . 133 | Each item in the list contains the inputs for one image. 134 | For now, each item in the list is a dict that contains: 135 | 136 | * image: Tensor, image in (C, H, W) format. 137 | * instances: Instances 138 | 139 | Other information that's included in the original dicts, such as: 140 | 141 | * "height", "width" (int): the output resolution of the model, used in inference. 142 | See :meth:`postprocess` for details. 
143 | Returns: 144 | dict[str: Tensor]: 145 | """ 146 | images = self.preprocess_image(batched_inputs) 147 | 148 | if not self.training: 149 | return self.inference(images) 150 | 151 | features = self.backbone(images.tensor) 152 | # features = features[self.cfg.MODEL.RESNETS.OUT_FEATURES[0]] 153 | up_fmap = self.upsample(features) 154 | pred_dict = self.head(up_fmap) 155 | 156 | gt_dict = self.get_ground_truth(batched_inputs) 157 | 158 | return self.losses(pred_dict, gt_dict) 159 | 160 | def losses(self, pred_dict, gt_dict): 161 | r""" 162 | calculate losses of pred and gt 163 | 164 | Args: 165 | gt_dict(dict): a dict contains all information of gt 166 | gt_dict = { 167 | "score_map": gt scoremap, 168 | "wh": gt width and height of boxes, 169 | "reg": gt regression of box center point, 170 | "reg_mask": mask of regression, 171 | "index": gt index, 172 | } 173 | pred(dict): a dict contains all information of prediction 174 | pred = { 175 | "cls": predicted score map 176 | "reg": predcited regression 177 | "wh": predicted width and height of box 178 | } 179 | """ 180 | # scoremap loss 181 | pred_score = pred_dict["cls"] 182 | cur_device = pred_score.device 183 | for k in gt_dict: 184 | gt_dict[k] = gt_dict[k].to(cur_device) 185 | 186 | loss_cls = modified_focal_loss(pred_score, gt_dict["score_map"]) 187 | 188 | mask = gt_dict["reg_mask"] 189 | index = gt_dict["index"] 190 | index = index.to(torch.long) 191 | # width and height loss, better version 192 | loss_wh = reg_l1_loss(pred_dict["wh"], mask, index, gt_dict["wh"]) 193 | 194 | # regression loss 195 | loss_reg = reg_l1_loss(pred_dict["reg"], mask, index, gt_dict["reg"]) 196 | 197 | loss_cls *= self.cfg.MODEL.CENTERNET.LOSS.CLS_WEIGHT 198 | loss_wh *= self.cfg.MODEL.CENTERNET.LOSS.WH_WEIGHT 199 | loss_reg *= self.cfg.MODEL.CENTERNET.LOSS.REG_WEIGHT 200 | 201 | loss = {"loss_cls": loss_cls, "loss_box_wh": loss_wh, "loss_center_reg": loss_reg} 202 | # print(loss) 203 | return loss 204 | 205 | @torch.no_grad() 206 | def get_ground_truth(self, batched_inputs): 207 | return CenterNetGT.generate(self.cfg, batched_inputs) 208 | 209 | @torch.no_grad() 210 | def inference(self, images): 211 | """ 212 | image(tensor): ImageList in detectron2.structures 213 | """ 214 | n, c, h, w = images.tensor.shape 215 | new_h, new_w = (h | 31) + 1, (w | 31) + 1 216 | center_wh = np.array([w // 2, h // 2], dtype=np.float32) 217 | size_wh = np.array([new_w, new_h], dtype=np.float32) 218 | down_scale = self.cfg.MODEL.CENTERNET.DOWN_SCALE 219 | img_info = dict( 220 | center=center_wh, size=size_wh, height=new_h // down_scale, width=new_w // down_scale 221 | ) 222 | 223 | pad_value = [-x / y for x, y in zip(self.mean, self.std)] 224 | aligned_img = torch.Tensor(pad_value).reshape((1, -1, 1, 1)).expand(n, c, new_h, new_w) 225 | aligned_img = aligned_img.to(images.tensor.device) 226 | 227 | pad_w, pad_h = math.ceil((new_w - w) / 2), math.ceil((new_h - h) / 2) 228 | aligned_img[..., pad_h : h + pad_h, pad_w : w + pad_w] = images.tensor 229 | 230 | features = self.backbone(aligned_img) 231 | # features = features[self.cfg.MODEL.RESNETS.OUT_FEATURES[0]] 232 | up_fmap = self.upsample(features) 233 | pred_dict = self.head(up_fmap) 234 | results = self.decode_prediction(pred_dict, img_info) 235 | 236 | ori_w, ori_h = img_info["center"] * 2 237 | det_instance = Instances((int(ori_h), int(ori_w)), **results) 238 | 239 | return [{"instances": det_instance}] 240 | 241 | def decode_prediction(self, pred_dict, img_info): 242 | """ 243 | Args: 244 | pred_dict(dict): a dict contains 
all information of prediction 245 | img_info(dict): a dict contains needed information of origin image 246 | """ 247 | fmap = pred_dict["cls"] 248 | reg = pred_dict["reg"] 249 | wh = pred_dict["wh"] 250 | 251 | boxes, scores, classes = CenterNetDecoder.decode(fmap, wh, reg) 252 | # boxes = Boxes(boxes.reshape(boxes.shape[-2:])) 253 | scores = scores.reshape(-1) 254 | classes = classes.reshape(-1).to(torch.int64) 255 | 256 | # dets = CenterNetDecoder.decode(fmap, wh, reg) 257 | boxes = CenterNetDecoder.transform_boxes(boxes, img_info) 258 | boxes = Boxes(boxes) 259 | return dict(pred_boxes=boxes, scores=scores, pred_classes=classes) 260 | 261 | def preprocess_image(self, batched_inputs): 262 | """ 263 | Normalize, pad and batch the input images. 264 | """ 265 | images = [x["image"].to(self.device) for x in batched_inputs] 266 | images = [self.normalizer(img/255.) for img in images] 267 | images = ImageList.from_tensors(images, self.backbone.size_divisibility) 268 | return images 269 | 270 | 271 | def build_model(cfg): 272 | 273 | model = CenterNet(cfg) 274 | return model 275 | --------------------------------------------------------------------------------
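As a quick sanity check of the pieces above, the model can be built straight from a config and run on a dummy image. This is only a sketch, not part of the repo: it assumes a CUDA-capable detectron2 build (the deformable convolutions in `CenternetDeconv` need the GPU ops) and it will download the torchvision ResNet-18 weights on first use.

```python
# Sketch: build CenterNet from the R-18 config and run one forward pass in eval mode.
import torch

from detectron2.config import get_cfg

from centernet.centernet import build_model
from centernet.config import add_centernet_config

cfg = get_cfg()
add_centernet_config(cfg)
cfg.merge_from_file("configs/centernet_r_18_C4_1x.yaml")

model = build_model(cfg)  # CenterNet(cfg); __init__ moves it to cfg.MODEL.DEVICE
model.eval()

# forward() takes a list of dicts; "image" is a (C, H, W) float tensor in RGB,
# 0-255 range; scaling by 1/255 and PIXEL_MEAN/PIXEL_STD happens in preprocess_image.
inputs = [{"image": torch.randint(0, 256, (3, 512, 512)).float()}]
with torch.no_grad():
    outputs = model(inputs)
print(outputs[0]["instances"])  # Instances with pred_boxes, scores, pred_classes
```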