├── .github ├── dependabot.yml └── workflows │ └── ci-testing.yml ├── .gitignore ├── CenterNet ├── __init__.py ├── centernet.py ├── centernet_detection.py ├── centernet_multi_pose.py ├── centernet_test.py ├── decode │ ├── __init__.py │ ├── ctdet.py │ └── multi_pose.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── large_hourglass.py │ │ ├── msra_resnet.py │ │ ├── pose_dla_dcn.py │ │ └── resnet_dcn.py │ └── heads.py ├── sample │ ├── __init__.py │ ├── ctdet.py │ └── multi_pose.py ├── transforms │ ├── __init__.py │ ├── image.py │ └── sample.py └── utils │ ├── __init__.py │ ├── decode.py │ ├── gaussian.py │ ├── losses.py │ └── nms.py ├── LICENSE ├── README.md ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── data └── coco_annotation.json ├── requirements.txt ├── test_models.py ├── test_sample_encode_decode.py ├── test_train_detection.py ├── test_train_multi_pose.py ├── test_transforms.py └── utilities.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | 2 | version: 2 3 | 4 | updates: 5 | - package-ecosystem: "pip" # See documentation for possible values 6 | directory: "/" # Location of package manifests 7 | schedule: 8 | interval: "daily" 9 | 10 | - package-ecosystem: "pip" 11 | directory: "/tests" 12 | schedule: 13 | interval: "daily" 14 | 15 | - package-ecosystem: "github-actions" 16 | directory: "/" 17 | schedule: 18 | # Check for updates to GitHub Actions every weekday 19 | interval: "daily" 20 | -------------------------------------------------------------------------------- /.github/workflows/ci-testing.yml: -------------------------------------------------------------------------------- 1 | name: CI testing 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: 5 | # Trigger the workflow on push or pull request, but only for the master branch 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | jobs: 14 | pytest: 15 | 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: [ubuntu-18.04, ubuntu-20.04] 21 | python-version: [3.7] 22 | 23 | # Timeout: https://stackoverflow.com/a/59076067/4521646 24 | timeout-minutes: 35 25 | 26 | steps: 27 | - name: Checkout repository and submodules 28 | uses: actions/checkout@v3.0.2 29 | with: 30 | submodules: recursive 31 | 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v3 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646 38 | - name: Setup macOS 39 | if: runner.os == 'macOS' 40 | run: | 41 | brew install libomp # https://github.com/pytorch/pytorch/issues/20030 42 | 43 | # Note: This uses an internal pip API and may not always work 44 | # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow 45 | - name: Get pip cache 46 | id: pip-cache 47 | run: | 48 | python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" 49 | 50 | - name: Cache pip 51 | uses: actions/cache@v3.0.2 52 | with: 53 | path: ${{ steps.pip-cache.outputs.dir }} 54 | key: ${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }} 55 | restore-keys: | 56 | ${{ runner.os }}-py${{ matrix.python-version }}- 57 | 58 | - name: Install dependencies 59 | run: | 60 | sudo apt-get install ninja-build 61 | pip install 
--requirement requirements.txt --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --use-feature=2020-resolver 62 | pip install --requirement tests/requirements.txt --quiet 63 | python --version 64 | pip --version 65 | pip list 66 | shell: bash 67 | 68 | - name: Cache weights 69 | uses: actions/cache@v3.0.2 70 | with: 71 | key: weights-cache 72 | path: ~/.cache/torch 73 | 74 | - name: Tests 75 | run: | 76 | coverage run --source CenterNet -m py.test CenterNet tests -v --ignore CenterNet/models/backbones/DCNv2/tests --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml 77 | 78 | - name: Statistics 79 | if: success() 80 | run: | 81 | coverage report 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .github 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # Lightning /research 30 | test_tube_exp/ 31 | tests/tests_tt_dir/ 32 | tests/save_dir 33 | default/ 34 | test_tube_logs/ 35 | test_tube_data/ 36 | model_weights/ 37 | tests/save_dir 38 | tests/tests_tt_dir/ 39 | processed/ 40 | raw/ 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder CenterNet settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope CenterNet settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | 120 | # IDEs 121 | .idea 122 | .vscode 123 | 124 | # seed CenterNet 125 | lightning_logs/ 126 | .DS_Store 127 | -------------------------------------------------------------------------------- /CenterNet/__init__.py: -------------------------------------------------------------------------------- 1 | from CenterNet.centernet import CenterNet 2 | from CenterNet.centernet_detection import CenterNetDetection 3 | from CenterNet.centernet_multi_pose import CenterNetMultiPose 4 | -------------------------------------------------------------------------------- /CenterNet/centernet.py: -------------------------------------------------------------------------------- 1 | from 
argparse import ArgumentParser 2 | 3 | import torch 4 | import pytorch_lightning as pl 5 | 6 | from CenterNet.models import create_model 7 | 8 | 9 | class CenterNet(pl.LightningModule): 10 | def __init__(self, arch): 11 | super().__init__() 12 | self.arch = arch 13 | 14 | # Backbone specific args 15 | self.head_conv = 256 if "dla" in arch or "hourglass" in arch else 64 16 | self.num_stacks = 2 if "hourglass" in arch else 1 17 | self.padding = 127 if "hourglass" in arch else 31 18 | 19 | self.backbone = create_model(arch) 20 | 21 | self.down_ratio = 4 22 | 23 | def load_pretrained_weights(self, model_weight_path, strict=True): 24 | mapping = { 25 | "hm": "heatmap", 26 | "wh": "width_height", 27 | "reg": "regression", 28 | "hm_hp": "heatmap_keypoints", 29 | "hp_offset": "heatmap_keypoints_offset", 30 | "hps": "keypoints", 31 | } 32 | 33 | print(f"Loading weights from: {model_weight_path}") 34 | checkpoint = torch.load(model_weight_path) 35 | backbone = { 36 | k.replace("module.", ""): v 37 | for k, v in checkpoint["state_dict"].items() 38 | if k.split(".")[1] not in mapping 39 | } 40 | self.backbone.load_state_dict(backbone, strict=strict) 41 | 42 | # These next lines are some special magic. 43 | # Try not to touch them and enjoy their beauty. 44 | # (The new decoupled heads require these amazing mapping functions 45 | # to load the old pretrained weights) 46 | heads = { 47 | ("0." if self.num_stacks == 1 else "") 48 | + ".".join( 49 | [mapping[k.replace("module.", "").split(".")[0]], "fc"] 50 | + k.split(".")[2:] 51 | ).replace("conv.", ""): v 52 | for k, v in checkpoint["state_dict"].items() 53 | if k.split(".")[1] in mapping 54 | } 55 | if self.arch == "hourglass": 56 | heads = { 57 | ".".join( 58 | k.split(".")[2:3] + k.split(".")[:2] + k.split(".")[3:] 59 | ).replace("fc.1", "fc.2"): v 60 | for k, v in heads.items() 61 | } 62 | self.heads.load_state_dict(heads, strict=strict) 63 | 64 | def forward(self, x): 65 | return self.backbone.forward(x) 66 | 67 | def loss(self, outputs, target): 68 | return 0, {} 69 | 70 | def training_step(self, batch, batch_idx): 71 | img, target = batch 72 | outputs = self(img) 73 | loss, loss_stats = self.loss(outputs, target) 74 | 75 | self.log(f"train_loss", loss, on_epoch=True) 76 | 77 | for key, value in loss_stats.items(): 78 | self.log(f"train/{key}", value) 79 | 80 | return loss 81 | 82 | def validation_step(self, batch, batch_idx): 83 | img, target = batch 84 | outputs = self(img) 85 | loss, loss_stats = self.loss(outputs, target) 86 | 87 | self.log(f"val_loss", loss, on_epoch=True, sync_dist=True) 88 | 89 | for name, value in loss_stats.items(): 90 | self.log(f"val/{name}", value, on_epoch=True, sync_dist=True) 91 | 92 | return {"loss": loss, "loss_stats": loss_stats} 93 | 94 | def configure_optimizers(self): 95 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) 96 | lr_scheduler = { 97 | "scheduler": torch.optim.lr_scheduler.MultiStepLR( 98 | optimizer, milestones=self.learning_rate_milestones 99 | ), 100 | "name": "learning_rate", 101 | "interval": "epoch", 102 | "frequency": 1 103 | } 104 | 105 | return [optimizer], [lr_scheduler] 106 | 107 | @staticmethod 108 | def add_model_specific_args(parent_parser): 109 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 110 | parser.add_argument( 111 | "--arch", 112 | default="dla_34", 113 | help="backbone architecture. 
Currently tested " 114 | "res_18 | res_101 | resdcn_18 | resdcn_101 | dla_34 | hourglass", 115 | ) 116 | 117 | parser.add_argument("--learning_rate", type=float, default=25e-5) 118 | parser.add_argument("--learning_rate_milestones", default="90, 120") 119 | return parser 120 | -------------------------------------------------------------------------------- /CenterNet/centernet_detection.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import numpy as np 4 | import imgaug as ia 5 | import imgaug.augmenters as iaa 6 | import torch 7 | import torchvision 8 | import torch.nn.functional as F 9 | import torchvision.transforms.functional as VF 10 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 11 | from pytorch_lightning.loggers import TensorBoardLogger 12 | from torch.utils.data import DataLoader 13 | import pytorch_lightning as pl 14 | from pycocotools.cocoeval import COCOeval 15 | from torchvision.datasets import CocoDetection 16 | 17 | from CenterNet import CenterNet 18 | from CenterNet.models.heads import CenterHead 19 | from CenterNet.sample.ctdet import CenterDetectionSample 20 | from CenterNet.transforms import CategoryIdToClass, ImageAugmentation 21 | from CenterNet.transforms.sample import ComposeSample 22 | from CenterNet.decode.ctdet import ctdet_decode 23 | from CenterNet.utils.losses import RegL1Loss, FocalLoss 24 | from CenterNet.utils.decode import sigmoid_clamped 25 | from CenterNet.utils.nms import soft_nms 26 | 27 | 28 | class CenterNetDetection(CenterNet): 29 | mean = [0.408, 0.447, 0.470] 30 | std = [0.289, 0.274, 0.278] 31 | max_objs = 128 32 | valid_ids = [ 33 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 34 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 35 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 36 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 37 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 38 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 39 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 40 | 82, 84, 85, 86, 87, 88, 89, 90 41 | ] 42 | 43 | def __init__( 44 | self, 45 | arch, 46 | learning_rate=1e-4, 47 | learning_rate_milestones=None, 48 | hm_weight=1, 49 | wh_weight=0.1, 50 | off_weight=1, 51 | num_classes=80, 52 | test_coco=None, 53 | test_coco_ids=None, 54 | test_scales=None, 55 | test_flip=False, 56 | ): 57 | super().__init__(arch) 58 | 59 | self.num_classes = num_classes 60 | heads = {"heatmap": self.num_classes, "width_height": 2, "regression": 2} 61 | self.heads = torch.nn.ModuleList( 62 | [ 63 | CenterHead(heads, self.backbone.out_channels, self.head_conv) 64 | for _ in range(self.num_stacks) 65 | ] 66 | ) 67 | 68 | self.learning_rate_milestones = ( 69 | learning_rate_milestones 70 | if learning_rate_milestones is not None 71 | else [] 72 | ) 73 | 74 | # Test 75 | self.test_coco = test_coco 76 | self.test_coco_ids = test_coco_ids 77 | self.test_max_per_image = 100 78 | self.test_scales = [1] if test_scales is None else test_scales 79 | self.test_flip = test_flip 80 | 81 | # Loss 82 | self.criterion = FocalLoss() 83 | self.criterion_regression = RegL1Loss() 84 | self.criterion_width_height = RegL1Loss() 85 | 86 | self.save_hyperparameters() 87 | 88 | def forward(self, x): 89 | outputs = self.backbone(x) 90 | 91 | rets = [] 92 | for head, output in zip(self.heads, outputs): 93 | rets.append(head(output)) 94 | 95 | return rets 96 | 97 | def loss(self, outputs, target): 98 | hm_loss, wh_loss, off_loss = 0, 0, 0 99 | num_stacks = len(outputs) 100 | 101 | for s in 
range(num_stacks): 102 | output = outputs[s] 103 | output["heatmap"] = sigmoid_clamped(output["heatmap"]) 104 | 105 | hm_loss += self.criterion(output["heatmap"], target["heatmap"]) 106 | wh_loss += self.criterion_width_height( 107 | output["width_height"], 108 | target["regression_mask"], 109 | target["indices"], 110 | target["width_height"], 111 | ) 112 | off_loss += self.criterion_regression( 113 | output["regression"], 114 | target["regression_mask"], 115 | target["indices"], 116 | target["regression"], 117 | ) 118 | 119 | loss = ( 120 | self.hparams.hm_weight * hm_loss 121 | + self.hparams.wh_weight * wh_loss 122 | + self.hparams.off_weight * off_loss 123 | ) / num_stacks 124 | loss_stats = { 125 | "loss": loss, 126 | "hm_loss": hm_loss, 127 | "wh_loss": wh_loss, 128 | "off_loss": off_loss, 129 | } 130 | return loss, loss_stats 131 | 132 | def test_step(self, batch, batch_idx): 133 | img, target = batch 134 | image_id = self.test_coco_ids[batch_idx] if self.test_coco_ids else batch_idx 135 | 136 | # Test augmentation 137 | images = [] 138 | meta = [] 139 | for scale in self.test_scales: 140 | _, _, height, width = img.shape 141 | new_height = int(height * scale) 142 | new_width = int(width * scale) 143 | pad_top_bottom = ((new_height | self.padding) + 1 - new_height) // 2 144 | pad_left_right = ((new_width | self.padding) + 1 - new_width) // 2 145 | 146 | img_scaled = VF.resize(img, [new_height, new_width]) 147 | img_scaled = F.pad( 148 | img_scaled, 149 | (pad_left_right, pad_left_right, pad_top_bottom, pad_top_bottom), 150 | ) 151 | img_scaled = VF.normalize(img_scaled, self.mean, self.std) 152 | 153 | if self.test_flip: 154 | img_scaled = torch.cat([img_scaled, VF.hflip(img_scaled)]) 155 | 156 | images.append(img_scaled) 157 | meta.append({ 158 | "scale": [new_width / width, new_height / height], 159 | "padding": [pad_left_right, pad_top_bottom], 160 | }) 161 | 162 | # Forward 163 | outputs = [] 164 | for image in images: 165 | outputs.append(self(image)[-1]) 166 | 167 | if self.test_flip: 168 | for output in outputs: 169 | output["heatmap"] = (output["heatmap"][0:1] + VF.hflip(output["heatmap"][1:2])) / 2 170 | output["width_height"] = (output["width_height"][0:1] + VF.hflip(output["width_height"][1:2])) / 2 171 | output["regression"] = output["regression"][0:1] 172 | 173 | return image_id, outputs, meta 174 | 175 | def test_step_end(self, outputs): 176 | image_id, outputs, metas = outputs 177 | 178 | detections = [] 179 | for i in range(len(outputs)): 180 | output = outputs[i] 181 | meta = metas[i] 182 | 183 | detection = ctdet_decode( 184 | output["heatmap"].sigmoid_(), 185 | output["width_height"], 186 | reg=output["regression"], 187 | ) 188 | detection = detection.cpu().detach().squeeze() 189 | 190 | # Transform detection to original image 191 | padding = torch.FloatTensor(meta["padding"] + meta["padding"]) 192 | scale = torch.FloatTensor(meta["scale"] + meta["scale"]) 193 | detection[:, :4] *= self.down_ratio # Scale to input 194 | detection[:, :4] -= padding # Remove pad 195 | detection[:, :4] /= scale # Compensate scale 196 | 197 | # Group detections by class 198 | class_predictions = {} 199 | classes = detection[:, -1] 200 | for j in range(self.num_classes): 201 | indices = classes == j 202 | class_predictions[j + 1] = detection[indices, :5].numpy().reshape(-1, 5) 203 | 204 | detections.append(class_predictions) 205 | 206 | # Merge detections 207 | results = {} 208 | for j in range(1, self.num_classes + 1): 209 | results[j] = np.concatenate( 210 | [detection[j] for 
detection in detections], axis=0 211 | ) 212 | if len(self.test_scales) > 1: 213 | keep_indices = soft_nms(results[j], Nt=0.5, method=2) 214 | results[j] = results[j][keep_indices] 215 | 216 | # Keep only best detections 217 | scores = np.hstack([results[j][:, 4] for j in range(1, self.num_classes + 1)]) 218 | if len(scores) > self.test_max_per_image: 219 | kth = len(scores) - self.test_max_per_image 220 | thresh = np.partition(scores, kth)[kth] 221 | for j in range(1, self.num_classes + 1): 222 | keep_indices = results[j][:, 4] >= thresh 223 | results[j] = results[j][keep_indices] 224 | 225 | return image_id, results 226 | 227 | def test_epoch_end(self, detections): 228 | if not self.test_coco: 229 | return detections 230 | 231 | # Convert to COCO eval format 232 | # Format: imageID, x1, y1, w, h, score, class 233 | data = [] 234 | for image_id, detection in detections: 235 | for class_index, box in detection.items(): 236 | if box.shape[0] == 0: 237 | continue 238 | 239 | category_id = self.valid_ids[class_index - 1] 240 | category_ids = np.repeat(category_id, box.shape[0]).reshape((-1, 1)) 241 | image_ids = np.repeat(image_id, box.shape[0]).reshape((-1, 1)) 242 | 243 | box[:, 2] -= box[:, 0] 244 | box[:, 3] -= box[:, 1] 245 | 246 | data.append(np.hstack((image_ids, box, category_ids))) 247 | 248 | data = np.concatenate(data, axis=0) 249 | 250 | coco_detections = self.test_coco.loadRes(data) 251 | 252 | coco_eval = COCOeval(self.test_coco, coco_detections, "bbox") 253 | coco_eval.evaluate() 254 | coco_eval.accumulate() 255 | coco_eval.summarize() 256 | 257 | prefix = "" 258 | if len(self.test_scales) > 1: 259 | prefix += "multi-scale_" 260 | if self.test_flip: 261 | prefix += "flip_" 262 | 263 | stats = ["ap", "ap_50", "ap_75", "ap_S", "ap_M", "ap_L"] 264 | for num, name in enumerate(stats): 265 | self.log(f"test/{prefix}{name}", coco_eval.stats[num], sync_dist=True) 266 | 267 | 268 | def cli_main(): 269 | pl.seed_everything(5318008) 270 | ia.seed(107734) 271 | 272 | # ------------ 273 | # args 274 | # ------------ 275 | parser = ArgumentParser() 276 | parser.add_argument("image_root") 277 | parser.add_argument("annotation_root") 278 | 279 | parser.add_argument("--pretrained_weights_path") 280 | parser.add_argument("--batch_size", default=32, type=int) 281 | parser.add_argument("--num_workers", default=8, type=int) 282 | parser = pl.Trainer.add_argparse_args(parser) 283 | parser = CenterNetDetection.add_model_specific_args(parser) 284 | args = parser.parse_args() 285 | 286 | # ------------ 287 | # data 288 | # ------------ 289 | train_transform = ComposeSample( 290 | [ 291 | ImageAugmentation( 292 | iaa.Sequential([ 293 | iaa.Resize({"shorter-side": "keep-aspect-ratio", "longer-side": 500}), 294 | iaa.Sequential([ 295 | iaa.Fliplr(0.5), 296 | iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))), 297 | iaa.LinearContrast((0.75, 1.5)), 298 | iaa.AdditiveGaussianNoise( 299 | loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5 300 | ), 301 | iaa.Multiply((0.8, 1.2), per_channel=0.1), 302 | iaa.Affine( 303 | scale={"x": (0.6, 1.4), "y": (0.6, 1.4)}, 304 | translate_percent={ 305 | "x": (-0.2, 0.2), 306 | "y": (-0.2, 0.2), 307 | }, 308 | rotate=(-5, 5), 309 | shear=(-3, 3), 310 | ), 311 | ], random_order=True), 312 | iaa.PadToFixedSize(width=500, height=500), 313 | iaa.CropToFixedSize(width=500, height=500), 314 | iaa.PadToFixedSize(width=512, height=512, position="center"), 315 | ]), 316 | torchvision.transforms.Compose([ 317 | torchvision.transforms.ToTensor(), 318 | 
torchvision.transforms.Normalize(CenterNetDetection.mean, CenterNetDetection.std, inplace=True), 319 | ]), 320 | ), 321 | CategoryIdToClass(CenterNetDetection.valid_ids), 322 | CenterDetectionSample(), 323 | ] 324 | ) 325 | 326 | valid_transform = ComposeSample( 327 | [ 328 | ImageAugmentation( 329 | iaa.Sequential([ 330 | iaa.Resize({"shorter-side": "keep-aspect-ratio", "longer-side": 500}), 331 | iaa.PadToFixedSize(width=512, height=512, position="center"), 332 | ]), 333 | torchvision.transforms.Compose([ 334 | torchvision.transforms.ToTensor(), 335 | torchvision.transforms.Normalize(CenterNetDetection.mean, CenterNetDetection.std, inplace=True), 336 | ]), 337 | ), 338 | CategoryIdToClass(CenterNetDetection.valid_ids), 339 | CenterDetectionSample(), 340 | ] 341 | ) 342 | 343 | test_transform = ImageAugmentation(img_transforms=torchvision.transforms.ToTensor()) 344 | 345 | coco_train = CocoDetection( 346 | os.path.join(args.image_root, "train2017"), 347 | os.path.join(args.annotation_root, "instances_train2017.json"), 348 | transforms=train_transform, 349 | ) 350 | 351 | coco_val = CocoDetection( 352 | os.path.join(args.image_root, "val2017"), 353 | os.path.join(args.annotation_root, "instances_val2017.json"), 354 | transforms=valid_transform, 355 | ) 356 | 357 | coco_test = CocoDetection( 358 | os.path.join(args.image_root, "val2017"), 359 | os.path.join(args.annotation_root, "instances_val2017.json"), 360 | transforms=test_transform, 361 | ) 362 | 363 | train_loader = DataLoader( 364 | coco_train, 365 | batch_size=args.batch_size, 366 | num_workers=args.num_workers, 367 | pin_memory=True, 368 | ) 369 | val_loader = DataLoader( 370 | coco_val, 371 | batch_size=args.batch_size, 372 | num_workers=args.num_workers, 373 | pin_memory=True, 374 | ) 375 | test_loader = DataLoader(coco_test, batch_size=1, num_workers=0, pin_memory=True) 376 | 377 | # ------------ 378 | # model 379 | # ------------ 380 | args.learning_rate_milestones = list(map(int, args.learning_rate_milestones.split(","))) 381 | model = CenterNetDetection( 382 | args.arch, args.learning_rate, 383 | args.learning_rate_milestones, 384 | test_coco=coco_test.coco, 385 | test_coco_ids=list(sorted(coco_test.coco.imgs.keys())) 386 | ) 387 | if args.pretrained_weights_path: 388 | model.load_pretrained_weights(args.pretrained_weights_path) 389 | 390 | # ------------ 391 | # training 392 | # ------------ 393 | logger = TensorBoardLogger("../tb_logs", name=f"multi_pose_{args.arch}") 394 | callbacks = [ 395 | ModelCheckpoint( 396 | monitor="val_loss", 397 | mode="min", 398 | save_top_k=5, 399 | save_last=True, 400 | every_n_epochs=10 401 | ), 402 | LearningRateMonitor(logging_interval="epoch"), 403 | ] 404 | 405 | trainer = pl.Trainer.from_argparse_args( 406 | args, 407 | callbacks=callbacks, 408 | logger=logger 409 | ) 410 | trainer.fit(model, train_loader, val_loader) 411 | 412 | # ------------ 413 | # testing 414 | # ------------ 415 | trainer.test(dataloaders=test_loader) 416 | 417 | 418 | if __name__ == "__main__": 419 | cli_main() 420 | -------------------------------------------------------------------------------- /CenterNet/centernet_multi_pose.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import numpy as np 5 | import imgaug.augmenters as iaa 6 | import torch 7 | import torchvision 8 | import torch.nn.functional as F 9 | import torchvision.transforms.functional as VF 10 | import pytorch_lightning as pl 11 | from 
pycocotools.cocoeval import COCOeval 12 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | from torch.utils.data import DataLoader 15 | from torchvision.datasets import CocoDetection 16 | 17 | from CenterNet import CenterNet 18 | from CenterNet.models.heads import CenterHead 19 | from CenterNet.sample.ctdet import CenterDetectionSample 20 | from CenterNet.sample.multi_pose import MultiPoseSample 21 | from CenterNet.transforms import ImageAugmentation 22 | from CenterNet.transforms.sample import MultiSampleTransform, PoseFlip, ComposeSample 23 | from CenterNet.decode.multi_pose import multi_pose_decode 24 | from CenterNet.utils.decode import sigmoid_clamped 25 | from CenterNet.utils.losses import RegL1Loss, FocalLoss, RegWeightedL1Loss 26 | from CenterNet.utils.nms import soft_nms_39 27 | 28 | 29 | class CenterNetMultiPose(CenterNet): 30 | mean = [0.408, 0.447, 0.470] 31 | std = [0.289, 0.274, 0.278] 32 | flip_idx = [ 33 | 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 34 | ] 35 | 36 | def __init__( 37 | self, 38 | arch, 39 | learning_rate=1e-4, 40 | learning_rate_milestones=None, 41 | hm_weight=1, 42 | wh_weight=0.1, 43 | off_weight=1, 44 | hp_weight=1, 45 | hm_hp_weight=1, 46 | test_coco=None, 47 | test_coco_ids=None, 48 | test_scales=None, 49 | test_flip=True, 50 | ): 51 | super().__init__(arch) 52 | 53 | heads = { 54 | "heatmap": 1, 55 | "width_height": 2, 56 | "regression": 2, 57 | "heatmap_keypoints": 17, 58 | "keypoints": 34, 59 | "heatmap_keypoints_offset": 2, 60 | } 61 | self.heads = torch.nn.ModuleList( 62 | [ 63 | CenterHead(heads, self.backbone.out_channels, self.head_conv) 64 | for _ in range(self.num_stacks) 65 | ] 66 | ) 67 | 68 | self.learning_rate_milestones = ( 69 | learning_rate_milestones if learning_rate_milestones is not None else [] 70 | ) 71 | 72 | # Test 73 | self.test_coco = test_coco 74 | self.test_coco_ids = test_coco_ids 75 | self.test_max_per_image = 20 76 | self.test_scales = [1] if test_scales is None else test_scales 77 | self.test_flip = test_flip 78 | 79 | # Loss 80 | self.criterion = FocalLoss() 81 | self.criterion_heatmap_keypoints = FocalLoss() 82 | self.criterion_keypoints = RegWeightedL1Loss() 83 | self.criterion_regression = RegL1Loss() 84 | self.criterion_width_height = RegL1Loss() 85 | 86 | self.save_hyperparameters() 87 | 88 | def forward(self, x): 89 | outputs = self.backbone(x) 90 | 91 | rets = [] 92 | for head, output in zip(self.heads, outputs): 93 | rets.append(head(output)) 94 | 95 | return rets 96 | 97 | def loss(self, outputs, target): 98 | hm_loss, wh_loss, off_loss = 0, 0, 0 99 | kp_loss, off_loss, hm_kp_loss, hm_offset_loss = 0, 0, 0, 0 100 | num_stacks = len(outputs) 101 | 102 | for s in range(num_stacks): 103 | output = outputs[s] 104 | output["heatmap"] = sigmoid_clamped(output["heatmap"]) 105 | output["heatmap_keypoints"] = sigmoid_clamped(output["heatmap_keypoints"]) 106 | 107 | hm_loss += self.criterion(output["heatmap"], target["heatmap"]) 108 | wh_loss += self.criterion_width_height( 109 | output["width_height"], 110 | target["regression_mask"], 111 | target["indices"], 112 | target["width_height"], 113 | ) 114 | off_loss += self.criterion_regression( 115 | output["regression"], 116 | target["regression_mask"], 117 | target["indices"], 118 | target["regression"], 119 | ) 120 | 121 | kp_loss += self.criterion_keypoints( 122 | output["keypoints"], 123 | target["keypoints_mask"], 124 | target["indices"], 125 | 
target["keypoints"], 126 | ) 127 | hm_kp_loss += self.criterion_heatmap_keypoints( 128 | output["heatmap_keypoints"], target["heatmap_keypoints"] 129 | ) 130 | hm_offset_loss += self.criterion_regression( 131 | output["heatmap_keypoints_offset"], 132 | target["heatmap_keypoints_mask"], 133 | target["heatmap_keypoints_indices"], 134 | target["heatmap_keypoints_offset"], 135 | ) 136 | 137 | loss = ( 138 | self.hparams.hm_weight * hm_loss 139 | + self.hparams.wh_weight * wh_loss 140 | + self.hparams.off_weight * off_loss 141 | + self.hparams.hp_weight * kp_loss 142 | + self.hparams.hm_hp_weight * hm_kp_loss 143 | + self.hparams.off_weight * hm_offset_loss 144 | ) / num_stacks 145 | 146 | loss_stats = { 147 | "loss": loss, 148 | "hm_loss": hm_loss, 149 | "kp_loss": kp_loss, 150 | "hm_kp_loss": hm_kp_loss, 151 | "hm_offset_loss": hm_offset_loss, 152 | "wh_loss": wh_loss, 153 | "off_loss": off_loss, 154 | } 155 | return loss, loss_stats 156 | 157 | def test_step(self, batch, batch_idx): 158 | img, target = batch 159 | image_id = self.test_coco_ids[batch_idx] if self.test_coco_ids else batch_idx 160 | 161 | # Test augmentation 162 | images = [] 163 | meta = [] 164 | for scale in self.test_scales: 165 | _, _, height, width = img.shape 166 | new_height = int(height * scale) 167 | new_width = int(width * scale) 168 | pad_top_bottom = ((new_height | self.padding) + 1 - new_height) // 2 169 | pad_left_right = ((new_width | self.padding) + 1 - new_width) // 2 170 | 171 | img_scaled = VF.resize(img, [new_height, new_width]) 172 | img_scaled = F.pad( 173 | img_scaled, 174 | (pad_left_right, pad_left_right, pad_top_bottom, pad_top_bottom), 175 | ) 176 | img_scaled = VF.normalize(img_scaled, self.mean, self.std) 177 | 178 | if self.test_flip: 179 | img_scaled = torch.cat([img_scaled, VF.hflip(img_scaled)]) 180 | 181 | images.append(img_scaled) 182 | meta.append({ 183 | "scale": [new_width / width, new_height / height], 184 | "padding": [pad_left_right, pad_top_bottom], 185 | }) 186 | 187 | # Forward 188 | outputs = [] 189 | for image in images: 190 | outputs.append(self(image)[-1]) 191 | 192 | if self.test_flip: 193 | for output in outputs: 194 | output["heatmap"] = ( 195 | output["heatmap"][0:1] + VF.hflip(output["heatmap"][1:2]) 196 | ) / 2 197 | output["width_height"] = ( 198 | output["width_height"][0:1] + VF.hflip(output["width_height"][1:2]) 199 | ) / 2 200 | output["regression"] = output["regression"][0:1] 201 | 202 | # Flip pose aware 203 | num, points, height, width = output["keypoints"][1:2].shape 204 | flipped_keypoints = VF.hflip(output["keypoints"][1:2]).view(1, points // 2, 2, height, width) 205 | flipped_keypoints[:, :, 0, :, :] *= -1 206 | flipped_keypoints = flipped_keypoints[0:1, self.flip_idx].view(1, points, height, width) 207 | output["keypoints"] = (output["keypoints"][0:1] + flipped_keypoints) / 2 208 | 209 | flipped_heatmap = VF.hflip(output["heatmap_keypoints"][1:2])[0:1, self.flip_idx] 210 | output["heatmap_keypoints"] = (output["heatmap_keypoints"][0:1] + flipped_heatmap) / 2 211 | output["heatmap_keypoints_offset"] = output["heatmap_keypoints_offset"][0:1] 212 | 213 | return image_id, outputs, meta 214 | 215 | def test_step_end(self, outputs): 216 | image_id, outputs, metas = outputs 217 | 218 | detections = [] 219 | for i in range(len(outputs)): 220 | output = outputs[i] 221 | meta = metas[i] 222 | 223 | detection = multi_pose_decode( 224 | output["heatmap"].sigmoid_(), 225 | output["width_height"], 226 | output["keypoints"], 227 | reg=output["regression"], 228 | 
hm_hp=output["heatmap_keypoints"].sigmoid_(), 229 | hp_offset=output["heatmap_keypoints_offset"], 230 | ) 231 | detection = detection.cpu().detach().squeeze() 232 | 233 | # Transform detection to original image 234 | padding = torch.FloatTensor(meta["padding"]) 235 | scale = torch.FloatTensor(meta["scale"]) 236 | 237 | # Bounding Box 238 | detection[:, :4] *= self.down_ratio # Scale to input 239 | detection[:, :4] -= torch.cat([padding, padding]) # Remove pad 240 | detection[:, :4] /= torch.cat([scale, scale]) # Compensate scale 241 | 242 | # Keypoints 243 | points = detection[:, 5:39].view(-1, 17, 2) 244 | points *= self.down_ratio 245 | points -= padding 246 | points /= scale 247 | detection[:, 5:39] = points.view(-1, 34) 248 | 249 | detections.append(detection.numpy()) 250 | 251 | results = np.concatenate(detections, axis=0) 252 | if len(self.test_scales) > 1: 253 | keep_indices = soft_nms_39(results, Nt=0.5, method=2) 254 | results = results[keep_indices] 255 | 256 | # Keep only best detections 257 | scores = results[:, 4] 258 | if len(scores) > self.test_max_per_image: 259 | kth = len(scores) - self.test_max_per_image 260 | thresh = np.partition(scores, kth)[kth] 261 | keep_indices = results[:, 4] >= thresh 262 | results = results[keep_indices] 263 | 264 | return image_id, results.tolist() 265 | 266 | def test_epoch_end(self, results): 267 | if not self.test_coco: 268 | return 269 | 270 | category_id = 1 271 | 272 | # Convert to COCO annotation format 273 | data = [] 274 | for image_id, detections in results: 275 | for detection in detections: 276 | bbox = detection[:4] 277 | bbox[2] -= bbox[0] 278 | bbox[3] -= bbox[1] 279 | score = detection[4] 280 | 281 | keypoints = ( 282 | np.concatenate([ 283 | np.array(detection[5:39], dtype=np.float32).reshape(-1, 2), 284 | np.ones((17, 1), dtype=np.float32), 285 | ], axis=1) 286 | .reshape(51) 287 | .tolist() 288 | ) 289 | 290 | data.append({ 291 | "image_id": int(image_id), 292 | "category_id": int(category_id), 293 | "bbox": bbox, 294 | "score": score, 295 | "keypoints": keypoints, 296 | }) 297 | 298 | coco_detections = self.test_coco.loadRes(data) 299 | 300 | coco_eval_kp = COCOeval(self.test_coco, coco_detections, "keypoints") 301 | coco_eval_kp.evaluate() 302 | coco_eval_kp.accumulate() 303 | coco_eval_kp.summarize() 304 | 305 | coco_eval = COCOeval(self.test_coco, coco_detections, "bbox") 306 | coco_eval.evaluate() 307 | coco_eval.accumulate() 308 | coco_eval.summarize() 309 | 310 | prefix = "" 311 | if len(self.test_scales) > 1: 312 | prefix += "multi-scale_" 313 | if self.test_flip: 314 | prefix += "flip_" 315 | 316 | stats = ["ap", "ap_50", "ap_75", "ap_S", "ap_M", "ap_L"] 317 | for num, name in enumerate(stats): 318 | self.log(f"test/kp_{prefix}{name}", coco_eval_kp.stats[num], sync_dist=True) 319 | 320 | for num, name in enumerate(stats): 321 | self.log(f"test/bbox_{prefix}{name}", coco_eval.stats[num], sync_dist=True) 322 | 323 | 324 | def cli_main(): 325 | pl.seed_everything(5318008) 326 | 327 | # ------------ 328 | # args 329 | # ------------ 330 | parser = ArgumentParser() 331 | parser.add_argument("image_root") 332 | parser.add_argument("annotation_root") 333 | 334 | parser.add_argument("--pretrained_weights_path") 335 | parser.add_argument("--batch_size", default=32, type=int) 336 | parser.add_argument("--num_workers", default=8, type=int) 337 | parser = pl.Trainer.add_argparse_args(parser) 338 | parser = CenterNetMultiPose.add_model_specific_args(parser) 339 | args = parser.parse_args() 340 | 341 | # ------------ 342 | # data 
343 | # ------------ 344 | train_transform = ComposeSample([ 345 | ImageAugmentation( 346 | iaa.Sequential([ 347 | iaa.Resize({"shorter-side": "keep-aspect-ratio", "longer-side": 500}), 348 | iaa.Sequential([ 349 | iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 0.5))), 350 | iaa.LinearContrast((0.75, 1.5)), 351 | iaa.AdditiveGaussianNoise( 352 | loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5 353 | ), 354 | iaa.Multiply((0.8, 1.2), per_channel=0.1), 355 | iaa.Affine( 356 | scale={"x": (0.75, 1.25), "y": (0.75, 1.15)}, 357 | translate_percent={ 358 | "x": (-0.2, 0.2), 359 | "y": (-0.2, 0.2), 360 | }, 361 | rotate=(-5, 5), 362 | shear=(-3, 3), 363 | ), 364 | ], random_order=True), 365 | iaa.PadToFixedSize(width=500, height=500), 366 | iaa.CropToFixedSize(width=500, height=500), 367 | iaa.PadToFixedSize(width=512, height=512, position="center"), 368 | ]), 369 | torchvision.transforms.Compose([ 370 | torchvision.transforms.ToTensor(), 371 | torchvision.transforms.Normalize(CenterNetMultiPose.mean, CenterNetMultiPose.std, inplace=True), 372 | ]), 373 | ), 374 | PoseFlip(0.5), 375 | MultiSampleTransform([CenterDetectionSample(num_classes=1), MultiPoseSample()]), 376 | ]) 377 | 378 | valid_transform = ComposeSample([ 379 | ImageAugmentation( 380 | iaa.Sequential([ 381 | iaa.Resize({"shorter-side": "keep-aspect-ratio", "longer-side": 500}), 382 | iaa.PadToFixedSize(width=512, height=512, position="center"), 383 | ]), 384 | torchvision.transforms.Compose([ 385 | torchvision.transforms.ToTensor(), 386 | torchvision.transforms.Normalize(CenterNetMultiPose.mean, CenterNetMultiPose.std, inplace=True), 387 | ]), 388 | ), 389 | MultiSampleTransform([CenterDetectionSample(num_classes=1), MultiPoseSample()]), 390 | ]) 391 | 392 | test_transform = ImageAugmentation(img_transforms=torchvision.transforms.ToTensor()) 393 | 394 | coco_train = CocoDetection( 395 | os.path.join(args.image_root, "train2017"), 396 | os.path.join(args.annotation_root, "person_keypoints_train2017.json"), 397 | transforms=train_transform, 398 | ) 399 | 400 | coco_val = CocoDetection( 401 | os.path.join(args.image_root, "val2017"), 402 | os.path.join(args.annotation_root, "person_keypoints_val2017.json"), 403 | transforms=valid_transform, 404 | ) 405 | 406 | coco_test = CocoDetection( 407 | os.path.join(args.image_root, "val2017"), 408 | os.path.join(args.annotation_root, "person_keypoints_val2017.json"), 409 | transforms=test_transform, 410 | ) 411 | 412 | train_loader = DataLoader( 413 | coco_train, 414 | batch_size=args.batch_size, 415 | num_workers=args.num_workers, 416 | pin_memory=True, 417 | ) 418 | val_loader = DataLoader( 419 | coco_val, 420 | batch_size=args.batch_size, 421 | num_workers=args.num_workers, 422 | pin_memory=True, 423 | ) 424 | test_loader = DataLoader(coco_test, batch_size=1, num_workers=0, pin_memory=True) 425 | 426 | # ------------ 427 | # model 428 | # ------------ 429 | learning_rate_milestones = list( 430 | map(int, args.learning_rate_milestones.split(",")) 431 | ) 432 | model = CenterNetMultiPose( 433 | args.arch, 434 | args.learning_rate, 435 | learning_rate_milestones, 436 | test_coco=coco_test.coco, 437 | test_coco_ids=list(sorted(coco_test.coco.imgs.keys())), 438 | ) 439 | if args.pretrained_weights_path: 440 | model.load_pretrained_weights(args.pretrained_weights_path) 441 | 442 | # ------------ 443 | # training 444 | # ------------ 445 | logger = TensorBoardLogger("../tb_logs", name=f"multi_pose_{args.arch}") 446 | callbacks = [ 447 | ModelCheckpoint( 448 | monitor="val_loss", 449 | mode="min", 450 
| save_top_k=5, 451 | save_last=True, 452 | every_n_epochs=10 453 | ), 454 | LearningRateMonitor(logging_interval="epoch"), 455 | ] 456 | 457 | trainer = pl.Trainer.from_argparse_args( 458 | args, 459 | callbacks=callbacks, 460 | logger=logger 461 | ) 462 | trainer.fit(model, train_loader, val_loader) 463 | 464 | # ------------ 465 | # testing 466 | # ------------ 467 | trainer.test(dataloaders=test_loader) 468 | 469 | 470 | if __name__ == "__main__": 471 | cli_main() 472 | -------------------------------------------------------------------------------- /CenterNet/centernet_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | import torchvision 6 | from torchvision.datasets import CocoDetection 7 | 8 | from CenterNet.centernet import CenterNet 9 | from CenterNet.centernet_detection import CenterNetDetection 10 | from CenterNet.centernet_multi_pose import CenterNetMultiPose 11 | from CenterNet.transforms import ImageAugmentation 12 | from CenterNet.transforms.sample import ComposeSample 13 | 14 | task = { 15 | "detection": (CenterNetDetection, "instances_val2017.json"), 16 | "multi_pose": (CenterNetMultiPose, "person_keypoints_val2017.json") 17 | } 18 | 19 | 20 | def cli_test(): 21 | pl.seed_everything(5318008) 22 | 23 | # ------------ 24 | # args 25 | # ------------ 26 | parser = ArgumentParser() 27 | parser.add_argument("image_root") 28 | parser.add_argument("annotation_root") 29 | 30 | parser.add_argument("--task", choices=["detection", "multi_pose"], default="detection") 31 | 32 | parser.add_argument("--pretrained_weights_path") 33 | parser.add_argument("--ckpt_path") 34 | 35 | parser.add_argument("--flip", action='store_true') 36 | parser.add_argument("--multi_scale", action='store_true') 37 | 38 | parser = pl.Trainer.add_argparse_args(parser) 39 | parser = CenterNet.add_model_specific_args(parser) 40 | args = parser.parse_args() 41 | 42 | pl_module, file = task[args.task] 43 | 44 | # ------------ 45 | # data 46 | # ------------ 47 | test_transform = ComposeSample( 48 | [ImageAugmentation(img_transforms=torchvision.transforms.ToTensor())] 49 | ) 50 | 51 | coco_test = CocoDetection( 52 | os.path.join(args.image_root, "val2017"), 53 | os.path.join(args.annotation_root, file), 54 | transforms=test_transform, 55 | ) 56 | test_loader = DataLoader(coco_test, batch_size=1, num_workers=0, pin_memory=True) 57 | 58 | # ------------ 59 | # model 60 | # ------------ 61 | model = pl_module( 62 | args.arch, 63 | args.learning_rate, 64 | test_coco=coco_test.coco, 65 | test_coco_ids=list(sorted(coco_test.coco.imgs.keys())), 66 | test_flip=args.flip, 67 | test_scales=[.5, .75, 1, 1.25, 1.5] if args.multi_scale else None 68 | ) 69 | if args.pretrained_weights_path: 70 | model.load_pretrained_weights(args.pretrained_weights_path) 71 | 72 | if args.ckpt_path: 73 | ckpt = pl.utilities.cloud_io.load(args.ckpt_path) 74 | model.load_state_dict(ckpt['state_dict']) 75 | 76 | # ------------ 77 | # testing 78 | # ------------ 79 | trainer = pl.Trainer.from_argparse_args(args) 80 | trainer.test(model, test_loader) 81 | 82 | 83 | if __name__ == "__main__": 84 | cli_test() 85 | -------------------------------------------------------------------------------- /CenterNet/decode/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tteepe/CenterNet-pytorch-lightning/2febb56103046064a42b502a4145c37728917042/CenterNet/decode/__init__.py -------------------------------------------------------------------------------- /CenterNet/decode/ctdet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils.decode import _nms, _topk, _transpose_and_gather_feat 4 | 5 | 6 | def ctdet_decode(heat, wh, reg=None, K=100): 7 | batch, cat, height, width = heat.size() 8 | 9 | # heat = torch.sigmoid(heat) 10 | # perform nms on heatmaps 11 | heat = _nms(heat) 12 | 13 | scores, inds, clses, ys, xs = _topk(heat, K=K) 14 | if reg is not None: 15 | reg = _transpose_and_gather_feat(reg, inds) 16 | reg = reg.view(batch, K, 2) 17 | xs = xs.view(batch, K, 1) + reg[:, :, 0:1] 18 | ys = ys.view(batch, K, 1) + reg[:, :, 1:2] 19 | else: 20 | xs = xs.view(batch, K, 1) + 0.5 21 | ys = ys.view(batch, K, 1) + 0.5 22 | wh = _transpose_and_gather_feat(wh, inds) 23 | 24 | wh = wh.view(batch, K, 2) 25 | clses = clses.view(batch, K, 1).float() 26 | scores = scores.view(batch, K, 1) 27 | bboxes = torch.cat( 28 | [ 29 | xs - wh[..., 0:1] / 2, 30 | ys - wh[..., 1:2] / 2, 31 | xs + wh[..., 0:1] / 2, 32 | ys + wh[..., 1:2] / 2, 33 | ], 34 | dim=2, 35 | ) 36 | detections = torch.cat([bboxes, scores, clses], dim=2) 37 | 38 | return detections 39 | -------------------------------------------------------------------------------- /CenterNet/decode/multi_pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ..utils.decode import _nms, _topk, _topk_channel, _transpose_and_gather_feat 5 | 6 | 7 | def multi_pose_decode(heat, wh, kps, reg=None, hm_hp=None, hp_offset=None, K=100): 8 | batch, cat, height, width = heat.size() 9 | num_joints = kps.shape[1] // 2 10 | # heat = torch.sigmoid(heat) 11 | # perform nms on heatmaps 12 | heat = _nms(heat) 13 | scores, inds, clses, ys, xs = _topk(heat, K=K) 14 | 15 | kps = _transpose_and_gather_feat(kps, inds) 16 | kps = kps.view(batch, K, num_joints * 2) 17 | kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints) 18 | kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints) 19 | if reg is not None: 20 | reg = _transpose_and_gather_feat(reg, inds) 21 | reg = reg.view(batch, K, 2) 22 | xs = xs.view(batch, K, 1) + reg[:, :, 0:1] 23 | ys = ys.view(batch, K, 1) + reg[:, :, 1:2] 24 | else: 25 | xs = xs.view(batch, K, 1) + 0.5 26 | ys = ys.view(batch, K, 1) + 0.5 27 | wh = _transpose_and_gather_feat(wh, inds) 28 | wh = wh.view(batch, K, 2) 29 | clses = clses.view(batch, K, 1).float() 30 | scores = scores.view(batch, K, 1) 31 | 32 | bboxes = torch.cat( 33 | [ 34 | xs - wh[..., 0:1] / 2, 35 | ys - wh[..., 1:2] / 2, 36 | xs + wh[..., 0:1] / 2, 37 | ys + wh[..., 1:2] / 2, 38 | ], 39 | dim=2, 40 | ) 41 | if hm_hp is not None: 42 | hm_hp = _nms(hm_hp) 43 | thresh = 0.1 44 | kps = ( 45 | kps.view(batch, K, num_joints, 2).permute(0, 2, 1, 3).contiguous() 46 | ) # b x J x K x 2 47 | reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2) 48 | hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K) # b x J x K 49 | if hp_offset is not None: 50 | hp_offset = _transpose_and_gather_feat(hp_offset, hm_inds.view(batch, -1)) 51 | hp_offset = hp_offset.view(batch, num_joints, K, 2) 52 | hm_xs = hm_xs + hp_offset[:, :, :, 0] 53 | hm_ys = hm_ys + hp_offset[:, :, :, 1] 54 | else: 55 | hm_xs = hm_xs + 0.5 56 | hm_ys = hm_ys + 0.5 57 | 58 | mask = 
(hm_score > thresh).float() 59 | hm_score = (1 - mask) * -1 + mask * hm_score 60 | hm_ys = (1 - mask) * (-10000) + mask * hm_ys 61 | hm_xs = (1 - mask) * (-10000) + mask * hm_xs 62 | hm_kps = ( 63 | torch.stack([hm_xs, hm_ys], dim=-1) 64 | .unsqueeze(2) 65 | .expand(batch, num_joints, K, K, 2) 66 | ) 67 | dist = ((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5 68 | min_dist, min_ind = dist.min(dim=3) # b x J x K 69 | hm_score = hm_score.gather(2, min_ind).unsqueeze(-1) # b x J x K x 1 70 | min_dist = min_dist.unsqueeze(-1) 71 | min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand( 72 | batch, num_joints, K, 1, 2 73 | ) 74 | hm_kps = hm_kps.gather(3, min_ind) 75 | hm_kps = hm_kps.view(batch, num_joints, K, 2) 76 | l = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1) 77 | t = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1) 78 | r = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1) 79 | b = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1) 80 | mask = ( 81 | (hm_kps[..., 0:1] < l) 82 | + (hm_kps[..., 0:1] > r) 83 | + (hm_kps[..., 1:2] < t) 84 | + (hm_kps[..., 1:2] > b) 85 | + (hm_score < thresh) 86 | + (min_dist > (torch.max(b - t, r - l) * 0.3)) 87 | ) 88 | mask = (mask > 0).float() 89 | hm_score = hm_score * (1 - mask) 90 | hm_score = hm_score.view(batch, K, num_joints) 91 | mask = (mask > 0).float().expand(batch, num_joints, K, 2) 92 | kps = (1 - mask) * hm_kps + mask * kps 93 | kps = kps.permute(0, 2, 1, 3).contiguous().view(batch, K, num_joints * 2) 94 | detections = torch.cat([bboxes, scores, kps, clses, hm_score], dim=2) 95 | 96 | return detections 97 | 98 | 99 | def multi_pose_post_process(dets, c, s, h, w): 100 | # dets: batch x max_dets x 40 101 | # return list of 39 in image coord 102 | ret = [] 103 | for i in range(dets.shape[0]): 104 | # bbox = transform_preds(dets[i, :, :4].reshape(-1, 2), c[i], s[i], (w, h)) 105 | bbox = dets[i, :, :4].reshape(-1, 2), c[i], s[i], (w, h) 106 | # pts = transform_preds(dets[i, :, 5:39].reshape(-1, 2), c[i], s[i], (w, h)) 107 | pts = dets[i, :, 5:39].reshape(-1, 2), c[i], s[i], (w, h) 108 | top_preds = ( 109 | np.concatenate( 110 | [bbox.reshape(-1, 4), dets[i, :, 4:5], pts.reshape(-1, 34)], axis=1 111 | ) 112 | .astype(np.float32) 113 | .tolist() 114 | ) 115 | ret.append({np.ones(1, dtype=np.int32)[0]: top_preds}) 116 | return ret 117 | -------------------------------------------------------------------------------- /CenterNet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones.msra_resnet import get_pose_net 2 | from .backbones.pose_dla_dcn import get_pose_net as get_dla_dcn 3 | from .backbones.resnet_dcn import get_pose_net as get_pose_net_dcn 4 | from .backbones.large_hourglass import get_large_hourglass_net 5 | 6 | _model_factory = { 7 | "res": get_pose_net, # default Resnet with deconv 8 | "dla": get_dla_dcn, 9 | "resdcn": get_pose_net_dcn, 10 | "hourglass": get_large_hourglass_net 11 | } 12 | 13 | 14 | def create_model(arch): 15 | num_layers = int(arch[arch.find("_") + 1:]) if "_" in arch else 0 16 | arch = arch[: arch.find("_")] if "_" in arch else arch 17 | get_model = _model_factory[arch] 18 | model = get_model(num_layers=num_layers) 19 | return model 20 | -------------------------------------------------------------------------------- /CenterNet/models/backbones/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tteepe/CenterNet-pytorch-lightning/2febb56103046064a42b502a4145c37728917042/CenterNet/models/backbones/__init__.py -------------------------------------------------------------------------------- /CenterNet/models/backbones/large_hourglass.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # This code is base on 3 | # CornerNet (https://github.com/princeton-vl/CornerNet) 4 | # Copyright (c) 2018, University of Michigan 5 | # Licensed under the BSD 3-Clause License 6 | # ------------------------------------------------------------------------------ 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class convolution(nn.Module): 12 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 13 | super(convolution, self).__init__() 14 | 15 | pad = (k - 1) // 2 16 | self.conv = nn.Conv2d( 17 | inp_dim, 18 | out_dim, 19 | (k, k), 20 | padding=(pad, pad), 21 | stride=(stride, stride), 22 | bias=not with_bn, 23 | ) 24 | self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential() 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | conv = self.conv(x) 29 | bn = self.bn(conv) 30 | relu = self.relu(bn) 31 | return relu 32 | 33 | 34 | class fully_connected(nn.Module): 35 | def __init__(self, inp_dim, out_dim, with_bn=True): 36 | super(fully_connected, self).__init__() 37 | self.with_bn = with_bn 38 | 39 | self.linear = nn.Linear(inp_dim, out_dim) 40 | if self.with_bn: 41 | self.bn = nn.BatchNorm1d(out_dim) 42 | self.relu = nn.ReLU(inplace=True) 43 | 44 | def forward(self, x): 45 | linear = self.linear(x) 46 | bn = self.bn(linear) if self.with_bn else linear 47 | relu = self.relu(bn) 48 | return relu 49 | 50 | 51 | class residual(nn.Module): 52 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 53 | super(residual, self).__init__() 54 | 55 | self.conv1 = nn.Conv2d( 56 | inp_dim, 57 | out_dim, 58 | (3, 3), 59 | padding=(1, 1), 60 | stride=(stride, stride), 61 | bias=False, 62 | ) 63 | self.bn1 = nn.BatchNorm2d(out_dim) 64 | self.relu1 = nn.ReLU(inplace=True) 65 | 66 | self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1), bias=False) 67 | self.bn2 = nn.BatchNorm2d(out_dim) 68 | 69 | self.skip = ( 70 | nn.Sequential( 71 | nn.Conv2d( 72 | inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False 73 | ), 74 | nn.BatchNorm2d(out_dim), 75 | ) 76 | if stride != 1 or inp_dim != out_dim 77 | else nn.Sequential() 78 | ) 79 | self.relu = nn.ReLU(inplace=True) 80 | 81 | def forward(self, x): 82 | conv1 = self.conv1(x) 83 | bn1 = self.bn1(conv1) 84 | relu1 = self.relu1(bn1) 85 | 86 | conv2 = self.conv2(relu1) 87 | bn2 = self.bn2(conv2) 88 | 89 | skip = self.skip(x) 90 | return self.relu(bn2 + skip) 91 | 92 | 93 | def make_layer(k, inp_dim, out_dim, modules, layer=convolution, **kwargs): 94 | layers = [layer(k, inp_dim, out_dim, **kwargs)] 95 | for _ in range(1, modules): 96 | layers.append(layer(k, out_dim, out_dim, **kwargs)) 97 | return nn.Sequential(*layers) 98 | 99 | 100 | def make_layer_revr(k, inp_dim, out_dim, modules, layer=convolution, **kwargs): 101 | layers = [] 102 | for _ in range(modules - 1): 103 | layers.append(layer(k, inp_dim, inp_dim, **kwargs)) 104 | layers.append(layer(k, inp_dim, out_dim, **kwargs)) 105 | return nn.Sequential(*layers) 106 | 107 | 108 | class MergeUp(nn.Module): 109 | def forward(self, up1, up2): 110 | return up1 + up2 111 | 112 | 113 | def make_merge_layer(dim): 
114 | return MergeUp() 115 | 116 | 117 | # def make_pool_layer(dim): 118 | # return nn.MaxPool2d(kernel_size=2, stride=2) 119 | 120 | 121 | def make_pool_layer(dim): 122 | return nn.Sequential() 123 | 124 | 125 | def make_unpool_layer(dim): 126 | return nn.Upsample(scale_factor=2) 127 | 128 | 129 | def make_kp_layer(cnv_dim, curr_dim, out_dim): 130 | return nn.Sequential( 131 | convolution(3, cnv_dim, curr_dim, with_bn=False), 132 | nn.Conv2d(curr_dim, out_dim, (1, 1)), 133 | ) 134 | 135 | 136 | def make_inter_layer(dim): 137 | return residual(3, dim, dim) 138 | 139 | 140 | def make_cnv_layer(inp_dim, out_dim): 141 | return convolution(3, inp_dim, out_dim) 142 | 143 | 144 | class kp_module(nn.Module): 145 | def __init__( 146 | self, 147 | n, 148 | dims, 149 | modules, 150 | layer=residual, 151 | make_up_layer=make_layer, 152 | make_low_layer=make_layer, 153 | make_hg_layer=make_layer, 154 | make_hg_layer_revr=make_layer_revr, 155 | make_pool_layer=make_pool_layer, 156 | make_unpool_layer=make_unpool_layer, 157 | make_merge_layer=make_merge_layer, 158 | **kwargs 159 | ): 160 | super(kp_module, self).__init__() 161 | 162 | self.n = n 163 | 164 | curr_mod = modules[0] 165 | next_mod = modules[1] 166 | 167 | curr_dim = dims[0] 168 | next_dim = dims[1] 169 | 170 | self.up1 = make_up_layer(3, curr_dim, curr_dim, curr_mod, layer=layer, **kwargs) 171 | self.max1 = make_pool_layer(curr_dim) 172 | self.low1 = make_hg_layer( 173 | 3, curr_dim, next_dim, curr_mod, layer=layer, **kwargs 174 | ) 175 | self.low2 = ( 176 | kp_module( 177 | n - 1, 178 | dims[1:], 179 | modules[1:], 180 | layer=layer, 181 | make_up_layer=make_up_layer, 182 | make_low_layer=make_low_layer, 183 | make_hg_layer=make_hg_layer, 184 | make_hg_layer_revr=make_hg_layer_revr, 185 | make_pool_layer=make_pool_layer, 186 | make_unpool_layer=make_unpool_layer, 187 | make_merge_layer=make_merge_layer, 188 | **kwargs 189 | ) 190 | if self.n > 1 191 | else make_low_layer(3, next_dim, next_dim, next_mod, layer=layer, **kwargs) 192 | ) 193 | self.low3 = make_hg_layer_revr( 194 | 3, next_dim, curr_dim, curr_mod, layer=layer, **kwargs 195 | ) 196 | self.up2 = make_unpool_layer(curr_dim) 197 | 198 | self.merge = make_merge_layer(curr_dim) 199 | 200 | def forward(self, x): 201 | up1 = self.up1(x) 202 | max1 = self.max1(x) 203 | low1 = self.low1(max1) 204 | low2 = self.low2(low1) 205 | low3 = self.low3(low2) 206 | up2 = self.up2(low3) 207 | return self.merge(up1, up2) 208 | 209 | 210 | class exkp(nn.Module): 211 | def __init__( 212 | self, 213 | n, 214 | nstack, 215 | dims, 216 | modules, 217 | pre=None, 218 | cnv_dim=256, 219 | make_tl_layer=None, 220 | make_br_layer=None, 221 | make_cnv_layer=make_cnv_layer, 222 | make_heat_layer=make_kp_layer, 223 | make_tag_layer=make_kp_layer, 224 | make_regr_layer=make_kp_layer, 225 | make_up_layer=make_layer, 226 | make_low_layer=make_layer, 227 | make_hg_layer=make_layer, 228 | make_hg_layer_revr=make_layer_revr, 229 | make_pool_layer=make_pool_layer, 230 | make_unpool_layer=make_unpool_layer, 231 | make_merge_layer=make_merge_layer, 232 | make_inter_layer=make_inter_layer, 233 | kp_layer=residual, 234 | ): 235 | super(exkp, self).__init__() 236 | 237 | self.nstack = nstack 238 | self.out_channels = 256 239 | 240 | curr_dim = dims[0] 241 | 242 | self.pre = ( 243 | nn.Sequential( 244 | convolution(7, 3, 128, stride=2), residual(3, 128, 256, stride=2) 245 | ) 246 | if pre is None 247 | else pre 248 | ) 249 | 250 | self.kps = nn.ModuleList( 251 | [ 252 | kp_module( 253 | n, 254 | dims, 255 | modules, 256 | 
layer=kp_layer, 257 | make_up_layer=make_up_layer, 258 | make_low_layer=make_low_layer, 259 | make_hg_layer=make_hg_layer, 260 | make_hg_layer_revr=make_hg_layer_revr, 261 | make_pool_layer=make_pool_layer, 262 | make_unpool_layer=make_unpool_layer, 263 | make_merge_layer=make_merge_layer, 264 | ) 265 | for _ in range(nstack) 266 | ] 267 | ) 268 | self.cnvs = nn.ModuleList( 269 | [make_cnv_layer(curr_dim, cnv_dim) for _ in range(nstack)] 270 | ) 271 | 272 | self.inters = nn.ModuleList( 273 | [make_inter_layer(curr_dim) for _ in range(nstack - 1)] 274 | ) 275 | 276 | self.inters_ = nn.ModuleList( 277 | [ 278 | nn.Sequential( 279 | nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False), 280 | nn.BatchNorm2d(curr_dim), 281 | ) 282 | for _ in range(nstack - 1) 283 | ] 284 | ) 285 | self.cnvs_ = nn.ModuleList( 286 | [ 287 | nn.Sequential( 288 | nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False), 289 | nn.BatchNorm2d(curr_dim), 290 | ) 291 | for _ in range(nstack - 1) 292 | ] 293 | ) 294 | 295 | self.relu = nn.ReLU(inplace=True) 296 | 297 | def forward(self, image): 298 | # print('image shape', image.shape) 299 | inter = self.pre(image) 300 | outs = [] 301 | 302 | for ind in range(self.nstack): 303 | kp_, cnv_ = self.kps[ind], self.cnvs[ind] 304 | kp = kp_(inter) 305 | cnv = cnv_(kp) 306 | 307 | outs.append(cnv) 308 | 309 | if ind < self.nstack - 1: 310 | inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv) 311 | inter = self.relu(inter) 312 | inter = self.inters[ind](inter) 313 | return outs 314 | 315 | 316 | def make_hg_layer(kernel, dim0, dim1, mod, layer=convolution, **kwargs): 317 | layers = [layer(kernel, dim0, dim1, stride=2)] 318 | layers += [layer(kernel, dim1, dim1) for _ in range(mod - 1)] 319 | return nn.Sequential(*layers) 320 | 321 | 322 | class HourglassNet(exkp): 323 | def __init__(self, num_stacks=2): 324 | n = 5 325 | dims = [256, 256, 384, 384, 384, 512] 326 | modules = [2, 2, 2, 2, 2, 4] 327 | 328 | super(HourglassNet, self).__init__( 329 | n, 330 | num_stacks, 331 | dims, 332 | modules, 333 | make_tl_layer=None, 334 | make_br_layer=None, 335 | make_pool_layer=make_pool_layer, 336 | make_hg_layer=make_hg_layer, 337 | kp_layer=residual, 338 | cnv_dim=256, 339 | ) 340 | 341 | 342 | def get_large_hourglass_net(num_layers): 343 | return HourglassNet() 344 | -------------------------------------------------------------------------------- /CenterNet/models/backbones/msra_resnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Xingyi Zhou 6 | # ------------------------------------------------------------------------------ 7 | 8 | import torch.nn as nn 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | BN_MOMENTUM = 0.1 12 | 13 | model_urls = { 14 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 15 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 16 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 17 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 18 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d( 25 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 26 | ) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 68 | self.conv2 = nn.Conv2d( 69 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 70 | ) 71 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 72 | self.conv3 = nn.Conv2d( 73 | planes, planes * self.expansion, kernel_size=1, bias=False 74 | ) 75 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class PoseResNet(nn.Module): 104 | def __init__(self, block, layers, **kwargs): 105 | self.inplanes = 64 106 | self.out_channels = 256 107 | self.deconv_with_bias = False 108 | 109 | super(PoseResNet, self).__init__() 110 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 111 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 114 | self.layer1 = self._make_layer(block, 64, layers[0]) 115 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 116 | self.layer3 
= self._make_layer(block, 256, layers[2], stride=2) 117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 118 | 119 | # used for deconv layers 120 | self.deconv_layers = self._make_deconv_layer( 121 | 3, 122 | [256, 256, 256], 123 | [4, 4, 4], 124 | ) 125 | # self.final_layer = [] 126 | 127 | # self.final_layer = nn.ModuleList(self.final_layer) 128 | 129 | def _make_layer(self, block, planes, blocks, stride=1): 130 | downsample = None 131 | if stride != 1 or self.inplanes != planes * block.expansion: 132 | downsample = nn.Sequential( 133 | nn.Conv2d( 134 | self.inplanes, 135 | planes * block.expansion, 136 | kernel_size=1, 137 | stride=stride, 138 | bias=False, 139 | ), 140 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def _get_deconv_cfg(self, deconv_kernel, index): 152 | if deconv_kernel == 4: 153 | padding = 1 154 | output_padding = 0 155 | elif deconv_kernel == 3: 156 | padding = 1 157 | output_padding = 1 158 | elif deconv_kernel == 2: 159 | padding = 0 160 | output_padding = 0 161 | 162 | return deconv_kernel, padding, output_padding 163 | 164 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels): 165 | assert num_layers == len( 166 | num_filters 167 | ), "ERROR: num_deconv_layers is different len(num_deconv_filters)" 168 | assert num_layers == len( 169 | num_kernels 170 | ), "ERROR: num_deconv_layers is different len(num_deconv_filters)" 171 | 172 | layers = [] 173 | for i in range(num_layers): 174 | kernel, padding, output_padding = self._get_deconv_cfg(num_kernels[i], i) 175 | 176 | planes = num_filters[i] 177 | layers.append( 178 | nn.ConvTranspose2d( 179 | in_channels=self.inplanes, 180 | out_channels=planes, 181 | kernel_size=kernel, 182 | stride=2, 183 | padding=padding, 184 | output_padding=output_padding, 185 | bias=self.deconv_with_bias, 186 | ) 187 | ) 188 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) 189 | layers.append(nn.ReLU(inplace=True)) 190 | self.inplanes = planes 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def forward(self, x): 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | 200 | x = self.layer1(x) 201 | x = self.layer2(x) 202 | x = self.layer3(x) 203 | x = self.layer4(x) 204 | 205 | x = self.deconv_layers(x) 206 | 207 | return [x] 208 | 209 | def init_weights(self, num_layers, pretrained=True): 210 | if pretrained: 211 | # print('=> init resnet deconv model_weights from normal distribution') 212 | for _, m in self.deconv_layers.named_modules(): 213 | if isinstance(m, nn.ConvTranspose2d): 214 | # print('=> init {}.weight as normal(0, 0.001)'.format(name)) 215 | # print('=> init {}.bias as 0'.format(name)) 216 | nn.init.normal_(m.weight, std=0.001) 217 | if self.deconv_with_bias: 218 | nn.init.constant_(m.bias, 0) 219 | elif isinstance(m, nn.BatchNorm2d): 220 | # print('=> init {}.weight as 1'.format(name)) 221 | # print('=> init {}.bias as 0'.format(name)) 222 | nn.init.constant_(m.weight, 1) 223 | nn.init.constant_(m.bias, 0) 224 | # print('=> init final conv model_weights from normal distribution') 225 | # for head in self.heads: 226 | # final_layer = self.__getattr__(head) 227 | # for i, m in enumerate(final_layer.modules()): 228 | # if isinstance(m, 
nn.Conv2d): 229 | # # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 230 | # # print('=> init {}.weight as normal(0, 0.001)'.format(name)) 231 | # # print('=> init {}.bias as 0'.format(name)) 232 | # if m.weight.shape[0] == self.heads[head]: 233 | # if "heatmap" in head: 234 | # nn.init.constant_(m.bias, -2.19) 235 | # else: 236 | # nn.init.normal_(m.weight, std=0.001) 237 | # nn.init.constant_(m.bias, 0) 238 | # pretrained_state_dict = torch.load(pretrained) 239 | url = model_urls["resnet{}".format(num_layers)] 240 | pretrained_state_dict = model_zoo.load_url(url) 241 | print("=> loading pretrained backbone {}".format(url)) 242 | self.load_state_dict(pretrained_state_dict, strict=False) 243 | else: 244 | print("=> imagenet pretrained backbone dose not exist") 245 | print("=> please download it first") 246 | raise ValueError("imagenet pretrained backbone does not exist") 247 | 248 | 249 | resnet_spec = { 250 | 18: (BasicBlock, [2, 2, 2, 2]), 251 | 34: (BasicBlock, [3, 4, 6, 3]), 252 | 50: (Bottleneck, [3, 4, 6, 3]), 253 | 101: (Bottleneck, [3, 4, 23, 3]), 254 | 152: (Bottleneck, [3, 8, 36, 3]), 255 | } 256 | 257 | 258 | def get_pose_net(num_layers): 259 | block_class, layers = resnet_spec[num_layers] 260 | 261 | model = PoseResNet(block_class, layers) 262 | model.init_weights(num_layers, pretrained=True) 263 | return model 264 | -------------------------------------------------------------------------------- /CenterNet/models/backbones/pose_dla_dcn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import numpy as np 4 | from os.path import join 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | from DCN.dcn_v2 import DCN 12 | 13 | BN_MOMENTUM = 0.1 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def get_model_url(data="imagenet", name="dla34", hash="ba72cf86"): 18 | return join("http://dl.yf.io/dla/models", data, "{}-{}.pth".format(name, hash)) 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | "3x3 convolution with padding" 23 | return nn.Conv2d( 24 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 25 | ) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | def __init__(self, inplanes, planes, stride=1, dilation=1): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = nn.Conv2d( 32 | inplanes, 33 | planes, 34 | kernel_size=3, 35 | stride=stride, 36 | padding=dilation, 37 | bias=False, 38 | dilation=dilation, 39 | ) 40 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.conv2 = nn.Conv2d( 43 | planes, 44 | planes, 45 | kernel_size=3, 46 | stride=1, 47 | padding=dilation, 48 | bias=False, 49 | dilation=dilation, 50 | ) 51 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 52 | self.stride = stride 53 | 54 | def forward(self, x, residual=None): 55 | if residual is None: 56 | residual = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 2 73 | 74 | def __init__(self, inplanes, planes, stride=1, dilation=1): 75 | super(Bottleneck, self).__init__() 76 | expansion = Bottleneck.expansion 77 | bottle_planes = planes // expansion 78 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, 
kernel_size=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) 80 | self.conv2 = nn.Conv2d( 81 | bottle_planes, 82 | bottle_planes, 83 | kernel_size=3, 84 | stride=stride, 85 | padding=dilation, 86 | bias=False, 87 | dilation=dilation, 88 | ) 89 | self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) 90 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) 91 | self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.stride = stride 94 | 95 | def forward(self, x, residual=None): 96 | if residual is None: 97 | residual = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | out += residual 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class BottleneckX(nn.Module): 117 | expansion = 2 118 | cardinality = 32 119 | 120 | def __init__(self, inplanes, planes, stride=1, dilation=1): 121 | super(BottleneckX, self).__init__() 122 | cardinality = BottleneckX.cardinality 123 | # dim = int(math.floor(planes * (BottleneckV5.expansion / 64.0))) 124 | # bottle_planes = dim * cardinality 125 | bottle_planes = planes * cardinality // 32 126 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) 127 | self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) 128 | self.conv2 = nn.Conv2d( 129 | bottle_planes, 130 | bottle_planes, 131 | kernel_size=3, 132 | stride=stride, 133 | padding=dilation, 134 | bias=False, 135 | dilation=dilation, 136 | groups=cardinality, 137 | ) 138 | self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) 139 | self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) 140 | self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 141 | self.relu = nn.ReLU(inplace=True) 142 | self.stride = stride 143 | 144 | def forward(self, x, residual=None): 145 | if residual is None: 146 | residual = x 147 | 148 | out = self.conv1(x) 149 | out = self.bn1(out) 150 | out = self.relu(out) 151 | 152 | out = self.conv2(out) 153 | out = self.bn2(out) 154 | out = self.relu(out) 155 | 156 | out = self.conv3(out) 157 | out = self.bn3(out) 158 | 159 | out += residual 160 | out = self.relu(out) 161 | 162 | return out 163 | 164 | 165 | class Root(nn.Module): 166 | def __init__(self, in_channels, out_channels, kernel_size, residual): 167 | super(Root, self).__init__() 168 | self.conv = nn.Conv2d( 169 | in_channels, 170 | out_channels, 171 | 1, 172 | stride=1, 173 | bias=False, 174 | padding=(kernel_size - 1) // 2, 175 | ) 176 | self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM) 177 | self.relu = nn.ReLU(inplace=True) 178 | self.residual = residual 179 | 180 | def forward(self, *x): 181 | children = x 182 | x = self.conv(torch.cat(x, 1)) 183 | x = self.bn(x) 184 | if self.residual: 185 | x += children[0] 186 | x = self.relu(x) 187 | 188 | return x 189 | 190 | 191 | class Tree(nn.Module): 192 | def __init__( 193 | self, 194 | levels, 195 | block, 196 | in_channels, 197 | out_channels, 198 | stride=1, 199 | level_root=False, 200 | root_dim=0, 201 | root_kernel_size=1, 202 | dilation=1, 203 | root_residual=False, 204 | ): 205 | super(Tree, self).__init__() 206 | if root_dim == 0: 207 | root_dim = 2 * out_channels 208 | if level_root: 209 | root_dim += in_channels 210 | if levels == 1: 211 | self.tree1 = block(in_channels, out_channels, 
stride, dilation=dilation) 212 | self.tree2 = block(out_channels, out_channels, 1, dilation=dilation) 213 | else: 214 | self.tree1 = Tree( 215 | levels - 1, 216 | block, 217 | in_channels, 218 | out_channels, 219 | stride, 220 | root_dim=0, 221 | root_kernel_size=root_kernel_size, 222 | dilation=dilation, 223 | root_residual=root_residual, 224 | ) 225 | self.tree2 = Tree( 226 | levels - 1, 227 | block, 228 | out_channels, 229 | out_channels, 230 | root_dim=root_dim + out_channels, 231 | root_kernel_size=root_kernel_size, 232 | dilation=dilation, 233 | root_residual=root_residual, 234 | ) 235 | if levels == 1: 236 | self.root = Root(root_dim, out_channels, root_kernel_size, root_residual) 237 | self.level_root = level_root 238 | self.root_dim = root_dim 239 | self.downsample = None 240 | self.project = None 241 | self.levels = levels 242 | if stride > 1: 243 | self.downsample = nn.MaxPool2d(stride, stride=stride) 244 | if in_channels != out_channels: 245 | self.project = nn.Sequential( 246 | nn.Conv2d( 247 | in_channels, out_channels, kernel_size=1, stride=1, bias=False 248 | ), 249 | nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), 250 | ) 251 | 252 | def forward(self, x, residual=None, children=None): 253 | children = [] if children is None else children 254 | bottom = self.downsample(x) if self.downsample else x 255 | residual = self.project(bottom) if self.project else bottom 256 | if self.level_root: 257 | children.append(bottom) 258 | x1 = self.tree1(x, residual) 259 | if self.levels == 1: 260 | x2 = self.tree2(x1) 261 | x = self.root(x2, x1, *children) 262 | else: 263 | children.append(x1) 264 | x = self.tree2(x1, children=children) 265 | return x 266 | 267 | 268 | class DLA(nn.Module): 269 | def __init__( 270 | self, 271 | levels, 272 | channels, 273 | num_classes=1000, 274 | block=BasicBlock, 275 | residual_root=False, 276 | linear_root=False, 277 | ): 278 | super(DLA, self).__init__() 279 | self.channels = channels 280 | self.num_classes = num_classes 281 | self.base_layer = nn.Sequential( 282 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False), 283 | nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM), 284 | nn.ReLU(inplace=True), 285 | ) 286 | self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) 287 | self.level1 = self._make_conv_level( 288 | channels[0], channels[1], levels[1], stride=2 289 | ) 290 | self.level2 = Tree( 291 | levels[2], 292 | block, 293 | channels[1], 294 | channels[2], 295 | 2, 296 | level_root=False, 297 | root_residual=residual_root, 298 | ) 299 | self.level3 = Tree( 300 | levels[3], 301 | block, 302 | channels[2], 303 | channels[3], 304 | 2, 305 | level_root=True, 306 | root_residual=residual_root, 307 | ) 308 | self.level4 = Tree( 309 | levels[4], 310 | block, 311 | channels[3], 312 | channels[4], 313 | 2, 314 | level_root=True, 315 | root_residual=residual_root, 316 | ) 317 | self.level5 = Tree( 318 | levels[5], 319 | block, 320 | channels[4], 321 | channels[5], 322 | 2, 323 | level_root=True, 324 | root_residual=residual_root, 325 | ) 326 | 327 | # for m in self.modules(): 328 | # if isinstance(m, nn.Conv2d): 329 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 330 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 331 | # elif isinstance(m, nn.BatchNorm2d): 332 | # m.weight.data.fill_(1) 333 | # m.bias.data.zero_() 334 | 335 | def _make_level(self, block, inplanes, planes, blocks, stride=1): 336 | downsample = None 337 | if stride != 1 or inplanes != planes: 338 | downsample = nn.Sequential( 339 | nn.MaxPool2d(stride, stride=stride), 340 | nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False), 341 | nn.BatchNorm2d(planes, momentum=BN_MOMENTUM), 342 | ) 343 | 344 | layers = [] 345 | layers.append(block(inplanes, planes, stride, downsample=downsample)) 346 | for i in range(1, blocks): 347 | layers.append(block(inplanes, planes)) 348 | 349 | return nn.Sequential(*layers) 350 | 351 | def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): 352 | modules = [] 353 | for i in range(convs): 354 | modules.extend( 355 | [ 356 | nn.Conv2d( 357 | inplanes, 358 | planes, 359 | kernel_size=3, 360 | stride=stride if i == 0 else 1, 361 | padding=dilation, 362 | bias=False, 363 | dilation=dilation, 364 | ), 365 | nn.BatchNorm2d(planes, momentum=BN_MOMENTUM), 366 | nn.ReLU(inplace=True), 367 | ] 368 | ) 369 | inplanes = planes 370 | return nn.Sequential(*modules) 371 | 372 | def forward(self, x): 373 | y = [] 374 | x = self.base_layer(x) 375 | for i in range(6): 376 | x = getattr(self, "level{}".format(i))(x) 377 | y.append(x) 378 | return y 379 | 380 | def load_pretrained_model(self, data="imagenet", name="dla34", hash="ba72cf86"): 381 | # fc = self.fc 382 | if name.endswith(".pth"): 383 | model_weights = torch.load(data + name) 384 | else: 385 | model_url = get_model_url(data, name, hash) 386 | model_weights = model_zoo.load_url(model_url) 387 | num_classes = len(model_weights[list(model_weights.keys())[-1]]) 388 | self.fc = nn.Conv2d( 389 | self.channels[-1], 390 | num_classes, 391 | kernel_size=1, 392 | stride=1, 393 | padding=0, 394 | bias=True, 395 | ) 396 | self.load_state_dict(model_weights) 397 | # self.fc = fc 398 | 399 | 400 | def dla34(pretrained=True, **kwargs): # DLA-34 401 | model = DLA( 402 | [1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock, **kwargs 403 | ) 404 | if pretrained: 405 | model.load_pretrained_model(data="imagenet", name="dla34", hash="ba72cf86") 406 | return model 407 | 408 | 409 | class Identity(nn.Module): 410 | def __init__(self): 411 | super(Identity, self).__init__() 412 | 413 | def forward(self, x): 414 | return x 415 | 416 | 417 | def fill_fc_weights(layers): 418 | for m in layers.modules(): 419 | if isinstance(m, nn.Conv2d): 420 | if m.bias is not None: 421 | nn.init.constant_(m.bias, 0) 422 | 423 | 424 | def fill_up_weights(up): 425 | w = up.weight.data 426 | f = math.ceil(w.size(2) / 2) 427 | c = (2 * f - 1 - f % 2) / (2.0 * f) 428 | for i in range(w.size(2)): 429 | for j in range(w.size(3)): 430 | w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) 431 | for c in range(1, w.size(0)): 432 | w[c, 0, :, :] = w[0, 0, :, :] 433 | 434 | 435 | class DeformConv(nn.Module): 436 | def __init__(self, chi, cho): 437 | super(DeformConv, self).__init__() 438 | self.actf = nn.Sequential( 439 | nn.BatchNorm2d(cho, momentum=BN_MOMENTUM), nn.ReLU(inplace=True) 440 | ) 441 | self.conv = DCN( 442 | chi, 443 | cho, 444 | kernel_size=(3, 3), 445 | stride=1, 446 | padding=1, 447 | dilation=1, 448 | deformable_groups=1, 449 | ) 450 | 451 | def forward(self, x): 452 | x = self.conv(x) 453 | x = self.actf(x) 454 | return x 455 | 456 | 457 | class IDAUp(nn.Module): 458 | def __init__(self, o, channels, up_f): 459 | super(IDAUp, 
self).__init__() 460 | for i in range(1, len(channels)): 461 | c = channels[i] 462 | f = int(up_f[i]) 463 | proj = DeformConv(c, o) 464 | node = DeformConv(o, o) 465 | 466 | up = nn.ConvTranspose2d( 467 | o, 468 | o, 469 | f * 2, 470 | stride=f, 471 | padding=f // 2, 472 | output_padding=0, 473 | groups=o, 474 | bias=False, 475 | ) 476 | fill_up_weights(up) 477 | 478 | setattr(self, "proj_" + str(i), proj) 479 | setattr(self, "up_" + str(i), up) 480 | setattr(self, "node_" + str(i), node) 481 | 482 | def forward(self, layers, startp, endp): 483 | for i in range(startp + 1, endp): 484 | upsample = getattr(self, "up_" + str(i - startp)) 485 | project = getattr(self, "proj_" + str(i - startp)) 486 | layers[i] = upsample(project(layers[i])) 487 | node = getattr(self, "node_" + str(i - startp)) 488 | layers[i] = node(layers[i] + layers[i - 1]) 489 | 490 | 491 | class DLAUp(nn.Module): 492 | def __init__(self, startp, channels, scales, in_channels=None): 493 | super(DLAUp, self).__init__() 494 | self.startp = startp 495 | if in_channels is None: 496 | in_channels = channels 497 | self.channels = channels 498 | channels = list(channels) 499 | scales = np.array(scales, dtype=int) 500 | for i in range(len(channels) - 1): 501 | j = -i - 2 502 | setattr( 503 | self, 504 | "ida_{}".format(i), 505 | IDAUp(channels[j], in_channels[j:], scales[j:] // scales[j]), 506 | ) 507 | scales[j + 1 :] = scales[j] 508 | in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]] 509 | 510 | def forward(self, layers): 511 | out = [layers[-1]] # start with 32 512 | for i in range(len(layers) - self.startp - 1): 513 | ida = getattr(self, "ida_{}".format(i)) 514 | ida(layers, len(layers) - i - 2, len(layers)) 515 | out.insert(0, layers[-1]) 516 | return out 517 | 518 | 519 | class Interpolate(nn.Module): 520 | def __init__(self, scale, mode): 521 | super(Interpolate, self).__init__() 522 | self.scale = scale 523 | self.mode = mode 524 | 525 | def forward(self, x): 526 | x = F.interpolate( 527 | x, scale_factor=self.scale, mode=self.mode, align_corners=False 528 | ) 529 | return x 530 | 531 | 532 | class DLASeg(nn.Module): 533 | def __init__( 534 | self, 535 | base_name, 536 | pretrained, 537 | down_ratio, 538 | final_kernel, 539 | last_level, 540 | out_channel=0, 541 | ): 542 | super(DLASeg, self).__init__() 543 | assert down_ratio in [2, 4, 8, 16] 544 | self.first_level = int(np.log2(down_ratio)) 545 | self.last_level = last_level 546 | self.base = globals()[base_name](pretrained=pretrained) 547 | channels = self.base.channels 548 | scales = [2 ** i for i in range(len(channels[self.first_level :]))] 549 | self.dla_up = DLAUp(self.first_level, channels[self.first_level :], scales) 550 | 551 | if out_channel == 0: 552 | out_channel = channels[self.first_level] 553 | self.out_channels = out_channel 554 | 555 | self.ida_up = IDAUp( 556 | out_channel, 557 | channels[self.first_level: self.last_level], 558 | [2 ** i for i in range(self.last_level - self.first_level)], 559 | ) 560 | 561 | def forward(self, x): 562 | x = self.base(x) 563 | x = self.dla_up(x) 564 | 565 | y = [] 566 | for i in range(self.last_level - self.first_level): 567 | y.append(x[i].clone()) 568 | self.ida_up(y, 0, len(y)) 569 | 570 | return [y[-1]] 571 | 572 | 573 | def get_pose_net(num_layers, down_ratio=4): 574 | model = DLASeg( 575 | "dla{}".format(num_layers), 576 | pretrained=True, 577 | down_ratio=down_ratio, 578 | final_kernel=1, 579 | last_level=5, 580 | ) 581 | return model 582 | 
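# --- Hedged usage sketch (not part of the repository) ---------------------------
# A minimal check of the DLA-DCN backbone factory defined above, assuming the
# DCNv2 extension (DCN.dcn_v2) is built and importable and the pretrained DLA-34
# weights can be downloaded. It only illustrates the expected output: a single
# feature map whose stride matches down_ratio and whose channel count equals
# model.out_channels (channels[first_level], i.e. 64 for down_ratio=4).
#
# import torch
# from CenterNet.models.backbones.pose_dla_dcn import get_pose_net
#
# backbone = get_pose_net(num_layers=34, down_ratio=4)   # DLA-34 with deformable up-sampling
# backbone.eval()
# with torch.no_grad():
#     features = backbone(torch.randn(1, 3, 512, 512))   # forward returns a one-element list
# print(features[0].shape)   # expected: torch.Size([1, 64, 128, 128]) for down_ratio=4
# ---------------------------------------------------------------------------------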
-------------------------------------------------------------------------------- /CenterNet/models/backbones/resnet_dcn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Dequan Wang and Xingyi Zhou 6 | # ------------------------------------------------------------------------------ 7 | 8 | import math 9 | import logging 10 | 11 | import torch.nn as nn 12 | import torch.utils.model_zoo as model_zoo 13 | 14 | from DCN.dcn_v2 import DCN 15 | 16 | 17 | BN_MOMENTUM = 0.1 18 | logger = logging.getLogger(__name__) 19 | 20 | model_urls = { 21 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 22 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 23 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 24 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 25 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 26 | } 27 | 28 | 29 | def conv3x3(in_planes, out_planes, stride=1): 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d( 32 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 33 | ) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None): 40 | super(BasicBlock, self).__init__() 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None): 72 | super(Bottleneck, self).__init__() 73 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 75 | self.conv2 = nn.Conv2d( 76 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 77 | ) 78 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 79 | self.conv3 = nn.Conv2d( 80 | planes, planes * self.expansion, kernel_size=1, bias=False 81 | ) 82 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | residual = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | residual = self.downsample(x) 103 | 104 | out += residual 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | def fill_up_weights(up): 111 | w = up.weight.data 112 | f = math.ceil(w.size(2) / 2) 113 | c = (2 * f - 1 - f 
% 2) / (2.0 * f) 114 | for i in range(w.size(2)): 115 | for j in range(w.size(3)): 116 | w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) 117 | for c in range(1, w.size(0)): 118 | w[c, 0, :, :] = w[0, 0, :, :] 119 | 120 | 121 | def fill_fc_weights(layers): 122 | for m in layers.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | nn.init.normal_(m.weight, std=0.001) 125 | # torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') 126 | # torch.nn.init.xavier_normal_(m.weight.data) 127 | if m.bias is not None: 128 | nn.init.constant_(m.bias, 0) 129 | 130 | 131 | class PoseResNet(nn.Module): 132 | def __init__(self, block, layers): 133 | self.inplanes = 64 134 | self.out_channels = 64 135 | self.deconv_with_bias = False 136 | 137 | super(PoseResNet, self).__init__() 138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 139 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 146 | 147 | # used for deconv layers 148 | self.deconv_layers = self._make_deconv_layer( 149 | 3, 150 | [256, 128, 64], 151 | [4, 4, 4], 152 | ) 153 | 154 | def _make_layer(self, block, planes, blocks, stride=1): 155 | downsample = None 156 | if stride != 1 or self.inplanes != planes * block.expansion: 157 | downsample = nn.Sequential( 158 | nn.Conv2d( 159 | self.inplanes, 160 | planes * block.expansion, 161 | kernel_size=1, 162 | stride=stride, 163 | bias=False, 164 | ), 165 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 166 | ) 167 | 168 | layers = [] 169 | layers.append(block(self.inplanes, planes, stride, downsample)) 170 | self.inplanes = planes * block.expansion 171 | for i in range(1, blocks): 172 | layers.append(block(self.inplanes, planes)) 173 | 174 | return nn.Sequential(*layers) 175 | 176 | def _get_deconv_cfg(self, deconv_kernel, index): 177 | if deconv_kernel == 4: 178 | padding = 1 179 | output_padding = 0 180 | elif deconv_kernel == 3: 181 | padding = 1 182 | output_padding = 1 183 | elif deconv_kernel == 2: 184 | padding = 0 185 | output_padding = 0 186 | 187 | return deconv_kernel, padding, output_padding 188 | 189 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels): 190 | assert num_layers == len( 191 | num_filters 192 | ), "ERROR: num_deconv_layers is different len(num_deconv_filters)" 193 | assert num_layers == len( 194 | num_kernels 195 | ), "ERROR: num_deconv_layers is different len(num_deconv_filters)" 196 | 197 | layers = [] 198 | for i in range(num_layers): 199 | kernel, padding, output_padding = self._get_deconv_cfg(num_kernels[i], i) 200 | 201 | planes = num_filters[i] 202 | fc = DCN( 203 | self.inplanes, 204 | planes, 205 | kernel_size=(3, 3), 206 | stride=1, 207 | padding=1, 208 | dilation=1, 209 | deformable_groups=1, 210 | ) 211 | # fc = nn.Conv2d(self.inplanes, planes, 212 | # kernel_size=3, stride=1, 213 | # padding=1, dilation=1, bias=False) 214 | # fill_fc_weights(fc) 215 | up = nn.ConvTranspose2d( 216 | in_channels=planes, 217 | out_channels=planes, 218 | kernel_size=kernel, 219 | stride=2, 220 | padding=padding, 221 | output_padding=output_padding, 222 | bias=self.deconv_with_bias, 223 | ) 224 | fill_up_weights(up) 
225 | 226 | layers.append(fc) 227 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) 228 | layers.append(nn.ReLU(inplace=True)) 229 | layers.append(up) 230 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) 231 | layers.append(nn.ReLU(inplace=True)) 232 | self.inplanes = planes 233 | 234 | return nn.Sequential(*layers) 235 | 236 | def forward(self, x): 237 | x = self.conv1(x) 238 | x = self.bn1(x) 239 | x = self.relu(x) 240 | x = self.maxpool(x) 241 | 242 | x = self.layer1(x) 243 | x = self.layer2(x) 244 | x = self.layer3(x) 245 | x = self.layer4(x) 246 | 247 | x = self.deconv_layers(x) 248 | 249 | return [x] 250 | 251 | def init_weights(self, num_layers): 252 | if 1: 253 | url = model_urls["resnet{}".format(num_layers)] 254 | pretrained_state_dict = model_zoo.load_url(url) 255 | print("=> loading pretrained backbone {}".format(url)) 256 | self.load_state_dict(pretrained_state_dict, strict=False) 257 | print("=> init deconv model_weights from normal distribution") 258 | for name, m in self.deconv_layers.named_modules(): 259 | if isinstance(m, nn.BatchNorm2d): 260 | nn.init.constant_(m.weight, 1) 261 | nn.init.constant_(m.bias, 0) 262 | 263 | 264 | resnet_spec = { 265 | 18: (BasicBlock, [2, 2, 2, 2]), 266 | 34: (BasicBlock, [3, 4, 6, 3]), 267 | 50: (Bottleneck, [3, 4, 6, 3]), 268 | 101: (Bottleneck, [3, 4, 23, 3]), 269 | 152: (Bottleneck, [3, 8, 36, 3]), 270 | } 271 | 272 | 273 | def get_pose_net(num_layers): 274 | block_class, layers = resnet_spec[num_layers] 275 | 276 | model = PoseResNet(block_class, layers) 277 | model.init_weights(num_layers) 278 | return model 279 | -------------------------------------------------------------------------------- /CenterNet/models/heads.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class HeadConv(nn.Module): 5 | def __init__(self, out_channels: int, intermediate_channel: int, head_conv: int): 6 | super().__init__() 7 | self.out_channels = out_channels 8 | 9 | self.fc = nn.Sequential( 10 | nn.Conv2d( 11 | intermediate_channel, head_conv, kernel_size=3, padding=1, bias=True 12 | ), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(head_conv, out_channels, kernel_size=1, stride=1, padding=0), 15 | ) 16 | 17 | def forward(self, x): 18 | return self.fc(x) 19 | 20 | def fill_fc_weights(self): 21 | for m in self.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | nn.init.normal_(m.weight, std=0.001) 24 | if m.bias is not None: 25 | nn.init.constant_(m.bias, 0) 26 | 27 | 28 | class CenterHead(nn.Module): 29 | def __init__(self, heads, intermediate_channel, head_conv): 30 | super().__init__() 31 | 32 | self.heads = heads 33 | for name, out_channel in heads.items(): 34 | self.__setattr__(name, HeadConv(out_channel, intermediate_channel, head_conv)) 35 | 36 | self.init_weights() 37 | 38 | def forward(self, x): 39 | ret = {} 40 | for name in self.heads.keys(): 41 | ret[name] = self.__getattr__(name)(x) 42 | 43 | return ret 44 | 45 | def init_weights(self): 46 | for name in self.heads.keys(): 47 | if name.startswith("heatmap"): 48 | self.__getattr__(name).fc[-1].bias.data.fill_(-2.19) 49 | else: 50 | self.__getattr__(name).fill_fc_weights() 51 | -------------------------------------------------------------------------------- /CenterNet/sample/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctdet import CenterDetectionSample 2 | from .multi_pose import MultiPoseSample 3 | 
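# --- Hedged usage sketch (not part of the repository) ---------------------------
# A small illustration of the sample encoders exported above, assuming the
# CenterNet package is importable. CenterDetectionSample (defined in the next
# file) converts a COCO-style annotation list into the heatmap / size / offset
# targets consumed by the detection losses; with down_ratio=4 a 512x512 image
# yields 128x128 target maps. The bbox values below are made-up example numbers.
#
# import torch
# from CenterNet.sample import CenterDetectionSample
#
# encode = CenterDetectionSample(down_ratio=4, num_classes=80)
# img = torch.zeros(3, 512, 512)                                  # CHW image tensor
# target = [{"bbox": [100, 120, 60, 40], "category_id": 1}]       # COCO bbox: x, y, w, h
# img, ret = encode(img, target)
# print(ret["heatmap"].shape)        # torch.Size([80, 128, 128])
# print(ret["width_height"][0])      # box size on the output grid, here tensor([15., 10.])
# ---------------------------------------------------------------------------------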
-------------------------------------------------------------------------------- /CenterNet/sample/ctdet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from CenterNet.utils.gaussian import draw_umich_gaussian, draw_msra_gaussian, gaussian_radius 7 | 8 | 9 | class CenterDetectionSample: 10 | def __init__( 11 | self, 12 | down_ratio=4, 13 | num_classes=80, 14 | max_objects=128, 15 | gaussian_type="umich", 16 | ): 17 | 18 | self.down_ratio = down_ratio 19 | 20 | self.num_classes = num_classes 21 | self.max_objects = max_objects 22 | self.gaussian_type = gaussian_type 23 | 24 | @staticmethod 25 | def _coco_box_to_bbox(box): 26 | return np.array( 27 | [box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32 28 | ) 29 | 30 | def scale_point(self, point, output_size): 31 | x, y = point / self.down_ratio 32 | output_h, output_w = output_size 33 | 34 | x = np.clip(x, 0, output_w - 1) 35 | y = np.clip(y, 0, output_h - 1) 36 | 37 | return [x, y] 38 | 39 | def __call__(self, img, target): 40 | _, input_w, input_h = img.shape 41 | 42 | output_h = input_h // self.down_ratio 43 | output_w = input_w // self.down_ratio 44 | 45 | heatmap = torch.zeros( 46 | (self.num_classes, output_h, output_w), dtype=torch.float32 47 | ) 48 | width_height = torch.zeros((self.max_objects, 2), dtype=torch.float32) 49 | regression = torch.zeros((self.max_objects, 2), dtype=torch.float32) 50 | regression_mask = torch.zeros(self.max_objects, dtype=torch.bool) 51 | indices = torch.zeros(self.max_objects, dtype=torch.int64) 52 | 53 | draw_gaussian = ( 54 | draw_msra_gaussian if self.gaussian_type == "msra" else draw_umich_gaussian 55 | ) 56 | 57 | num_objects = min(len(target), self.max_objects) 58 | for k in range(num_objects): 59 | ann = target[k] 60 | bbox = self._coco_box_to_bbox(ann["bbox"]) 61 | cls_id = ann["class_id"] if "class_id" in ann else int(ann["category_id"]) - 1 62 | 63 | # Scale to output size 64 | bbox[:2] = self.scale_point(bbox[:2], (output_h, output_w)) 65 | bbox[2:] = self.scale_point(bbox[2:], (output_h, output_w)) 66 | 67 | h, w = bbox[3] - bbox[1], bbox[2] - bbox[0] 68 | if h > 0 and w > 0: 69 | radius = gaussian_radius((math.ceil(h), math.ceil(w))) 70 | radius = max(0, int(radius)) 71 | ct = torch.FloatTensor( 72 | [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] 73 | ) 74 | ct_int = ct.to(torch.int32) 75 | 76 | draw_gaussian(heatmap[cls_id], ct_int, radius) 77 | width_height[k] = torch.tensor([1.0 * w, 1.0 * h]) 78 | indices[k] = ct_int[1] * output_w + ct_int[0] 79 | regression[k] = ct - ct_int 80 | regression_mask[k] = 1 81 | 82 | ret = { 83 | "heatmap": heatmap, 84 | "regression_mask": regression_mask, 85 | "indices": indices, 86 | "width_height": width_height, 87 | "regression": regression, 88 | } 89 | 90 | return img, ret 91 | -------------------------------------------------------------------------------- /CenterNet/sample/multi_pose.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from CenterNet.utils.gaussian import draw_umich_gaussian, draw_msra_gaussian, gaussian_radius 7 | 8 | 9 | class MultiPoseSample: 10 | def __init__( 11 | self, down_ratio=4, max_objects=128, gaussian_type="msra", num_joints=17 12 | ): 13 | 14 | self.down_ratio = down_ratio 15 | 16 | self.max_objects = max_objects 17 | self.gaussian_type = gaussian_type 18 | self.num_joints = num_joints 19 | 20 | 
@staticmethod 21 | def _coco_box_to_bbox(box): 22 | return np.array( 23 | [box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32 24 | ) 25 | 26 | def scale_point(self, point, output_size): 27 | x, y = point / self.down_ratio 28 | output_h, output_w = output_size 29 | 30 | x = np.clip(x, 0, output_w - 1) 31 | y = np.clip(y, 0, output_h - 1) 32 | 33 | return [x, y] 34 | 35 | def __call__(self, img, target): 36 | _, input_w, input_h = img.shape 37 | 38 | output_h = input_h // self.down_ratio 39 | output_w = input_w // self.down_ratio 40 | 41 | heatmap_keypoints = torch.zeros( 42 | (self.num_joints, output_h, output_w), dtype=torch.float32 43 | ) 44 | keypoints = torch.zeros( 45 | (self.max_objects, self.num_joints * 2), dtype=torch.float32 46 | ) 47 | keypoints_mask = torch.zeros( 48 | (self.max_objects, self.num_joints * 2), dtype=torch.bool 49 | ) 50 | heatmap_keypoints_offset = torch.zeros( 51 | (self.max_objects * self.num_joints, 2), dtype=torch.float32 52 | ) 53 | 54 | heatmap_keypoints_indices = torch.zeros( 55 | (self.max_objects * self.num_joints), dtype=torch.int64 56 | ) 57 | heatmap_keypoints_mask = torch.zeros( 58 | (self.max_objects * self.num_joints), dtype=torch.bool 59 | ) 60 | 61 | draw_gaussian = ( 62 | draw_msra_gaussian if self.gaussian_type == "msra" else draw_umich_gaussian 63 | ) 64 | 65 | num_objects = min(len(target), self.max_objects) 66 | for k in range(num_objects): 67 | ann = target[k] 68 | bbox = self._coco_box_to_bbox(ann["bbox"]) 69 | 70 | # Scale to output size 71 | bbox[:2] = self.scale_point(bbox[:2], (output_h, output_w)) 72 | bbox[2:] = self.scale_point(bbox[2:], (output_h, output_w)) 73 | 74 | ct_int = torch.IntTensor([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]) 75 | 76 | h, w = bbox[3] - bbox[1], bbox[2] - bbox[0] 77 | if h > 0 and w > 0: 78 | hp_radius = gaussian_radius((math.ceil(h), math.ceil(w))) 79 | pts = torch.from_numpy( 80 | np.array(ann["keypoints"], np.float32).reshape(self.num_joints, 3) 81 | ) 82 | 83 | for j in range(self.num_joints): 84 | if pts[j, 2] == 0: 85 | continue 86 | 87 | pts[j, :2] = torch.FloatTensor( 88 | self.scale_point(pts[j, :2], (output_h, output_w)) 89 | ) 90 | 91 | keypoints[k, j * 2: j * 2 + 2] = pts[j, :2] - ct_int 92 | keypoints_mask[k, j * 2: j * 2 + 2] = 1 93 | 94 | pt_int = pts[j, :2].to(torch.int32) 95 | heatmap_keypoints_offset[k * self.num_joints + j] = pts[j, :2] - pt_int 96 | heatmap_keypoints_indices[k * self.num_joints + j] = ( 97 | pt_int[1] * output_w + pt_int[0] 98 | ) 99 | heatmap_keypoints_mask[k * self.num_joints + j] = 1 100 | 101 | draw_gaussian(heatmap_keypoints[j], pt_int, hp_radius) 102 | 103 | ret = { 104 | "heatmap_keypoints": heatmap_keypoints, 105 | "keypoints": keypoints, 106 | "keypoints_mask": keypoints_mask, 107 | "heatmap_keypoints_offset": heatmap_keypoints_offset, 108 | "heatmap_keypoints_indices": heatmap_keypoints_indices, 109 | "heatmap_keypoints_mask": heatmap_keypoints_mask, 110 | } 111 | 112 | return img, ret 113 | -------------------------------------------------------------------------------- /CenterNet/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .image import * 2 | from .sample import * 3 | -------------------------------------------------------------------------------- /CenterNet/transforms/image.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from imgaug.augmentables import Keypoint, 
KeypointsOnImage, BoundingBox, BoundingBoxesOnImage 7 | from imgaug.augmenters import Augmenter, Identity 8 | 9 | 10 | class ImageAugmentation: 11 | def __init__(self, imgaug_augmenter: Augmenter = Identity(), img_transforms=None, num_joints=17): 12 | self.ia_sequence = imgaug_augmenter 13 | self.img_transforms = img_transforms 14 | self.num_joints = num_joints 15 | 16 | def __call__(self, img, target): 17 | # PIL to array BGR 18 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 19 | target = copy.deepcopy(target) 20 | 21 | # Prepare augmentables for imgaug 22 | bounding_boxes = [] 23 | keypoints = [] 24 | for idx in range(len(target)): 25 | ann = target[idx] 26 | 27 | # Bounding Box 28 | box = ann['bbox'] 29 | bounding_boxes.append(BoundingBox( 30 | x1=box[0], 31 | y1=box[1], 32 | x2=box[0] + box[2], 33 | y2=box[1] + box[3], 34 | label=idx 35 | )) 36 | 37 | # Keypoints 38 | if 'num_keypoints' not in ann or ann['num_keypoints'] == 0: 39 | continue 40 | 41 | points = np.array(ann['keypoints'], np.float32).reshape(self.num_joints, 3) 42 | for i in range(self.num_joints): 43 | keypoints.append(Keypoint(x=points[i][0], y=points[i][1])) 44 | 45 | # Augmentation 46 | image_aug, bbs_aug, kps_aug = self.ia_sequence( 47 | image=img, 48 | bounding_boxes=BoundingBoxesOnImage(bounding_boxes, shape=img.shape), 49 | keypoints=KeypointsOnImage(keypoints, shape=img.shape) 50 | ) 51 | 52 | # Write augmentation back to annotations 53 | for bb in bbs_aug: 54 | target[bb.label]['bbox'] = [ 55 | bb.x1, bb.y1, bb.width, bb.height 56 | ] 57 | 58 | for ann in target: 59 | if 'num_keypoints' not in ann or ann['num_keypoints'] == 0: 60 | continue 61 | 62 | aug_keypoints = [] 63 | points = np.array(ann['keypoints'], np.float32).reshape(self.num_joints, 3) 64 | for i in range(self.num_joints): 65 | aug_kp = kps_aug.items.pop(0) 66 | kp_type = int(points[i][2]) 67 | if kp_type == 0: 68 | aug_keypoints.extend([0, 0, 0]) 69 | else: 70 | aug_keypoints.extend([aug_kp.x, aug_kp.y, kp_type]) 71 | 72 | ann['keypoints'] = aug_keypoints 73 | 74 | # torchvision transforms 75 | if self.img_transforms: 76 | image_aug = self.img_transforms(image_aug) 77 | 78 | return image_aug, target 79 | -------------------------------------------------------------------------------- /CenterNet/transforms/sample.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import numpy as np 5 | from collections import Callable 6 | 7 | import torchvision.transforms.functional as VF 8 | 9 | 10 | class ComposeSample: 11 | """Composes several transforms together on sample of image and target 12 | 13 | Args: 14 | transforms (list of ``Transform`` objects): list of transforms to compose. 
15 | """ 16 | 17 | def __init__(self, transforms): 18 | self.transforms = transforms 19 | 20 | def __call__(self, img, target): 21 | for t in self.transforms: 22 | img, target = t(img, target) 23 | return img, target 24 | 25 | def __repr__(self): 26 | format_string = self.__class__.__name__ + '(' 27 | for t in self.transforms: 28 | format_string += '\n' 29 | format_string += ' {0}'.format(t) 30 | format_string += '\n)' 31 | return format_string 32 | 33 | 34 | class MultiSampleTransform: 35 | def __init__(self, transforms: [Callable]): 36 | self.transforms = transforms 37 | 38 | def __call__(self, img, target): 39 | ret_all = {} 40 | 41 | for transform in self.transforms: 42 | img, ret = transform(img, target) 43 | 44 | ret_all.update(ret) 45 | 46 | return img, ret_all 47 | 48 | 49 | class PoseFlip: 50 | flip_idx_array = [ 51 | 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 52 | ] 53 | 54 | def __init__(self, flip_probability=0.5, num_joints=17): 55 | self.flip_probability = flip_probability 56 | 57 | self.num_joints = num_joints 58 | 59 | def __call__(self, img, target): 60 | if torch.rand(1) < self.flip_probability: 61 | img = VF.hflip(img) 62 | target = copy.deepcopy(target) 63 | 64 | for i in range(len(target)): 65 | # change x1 66 | bbox = target[i]["bbox"] 67 | width = img.shape[2] 68 | bbox[0] = width - bbox[0] - 1 69 | 70 | if 'num_keypoints' not in target[i] or target[i]['num_keypoints'] == 0: 71 | continue 72 | 73 | points = np.array(target[i]['keypoints'], np.float32).reshape(self.num_joints, 3) 74 | points[:, 0] = width - points[:, 0] - 1 75 | points[points[:, 2] == 0] = 0 76 | points_flipped = points[self.flip_idx_array, :] 77 | 78 | target[i]['keypoints'] = points_flipped.reshape(-1).tolist() 79 | target[i]["bbox"] = bbox 80 | 81 | return img, target 82 | 83 | 84 | class CategoryIdToClass: 85 | def __init__(self, valid_ids): 86 | self.valid_ids = valid_ids 87 | self.category_ids = {v: i for i, v in enumerate(self.valid_ids)} 88 | 89 | def __call__(self, img, target): 90 | for ann in target: 91 | ann["class_id"] = int(self.category_ids[int(ann["category_id"])]) 92 | 93 | return img, target 94 | 95 | -------------------------------------------------------------------------------- /CenterNet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tteepe/CenterNet-pytorch-lightning/2febb56103046064a42b502a4145c37728917042/CenterNet/utils/__init__.py -------------------------------------------------------------------------------- /CenterNet/utils/decode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def _nms(heat, kernel=3): 6 | pad = (kernel - 1) // 2 7 | 8 | hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad) 9 | keep = (hmax == heat).float() 10 | return heat * keep 11 | 12 | 13 | def _topk(scores, K=40): 14 | batch, cat, height, width = scores.size() 15 | 16 | topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) 17 | 18 | topk_inds = topk_inds % (height * width) 19 | topk_ys = (topk_inds / width).int().float() 20 | topk_xs = (topk_inds % width).int().float() 21 | 22 | topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) 23 | topk_clses = (topk_ind / K).int() 24 | topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K) 25 | topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K) 26 | topk_xs = 
_gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K) 27 | 28 | return topk_score, topk_inds, topk_clses, topk_ys, topk_xs 29 | 30 | 31 | def _topk_channel(scores, K=40): 32 | batch, cat, height, width = scores.size() 33 | 34 | topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) 35 | 36 | topk_inds = topk_inds % (height * width) 37 | topk_ys = (topk_inds / width).int().float() 38 | topk_xs = (topk_inds % width).int().float() 39 | 40 | return topk_scores, topk_inds, topk_ys, topk_xs 41 | 42 | 43 | def sigmoid_clamped(x, clamp=1e-4): 44 | y = torch.clamp(x.sigmoid_(), min=clamp, max=1 - clamp) 45 | return y 46 | 47 | 48 | def _gather_feat(feat, ind, mask=None): 49 | dim = feat.size(2) 50 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) 51 | feat = feat.gather(1, ind) 52 | if mask is not None: 53 | mask = mask.unsqueeze(2).expand_as(feat) 54 | feat = feat[mask] 55 | feat = feat.view(-1, dim) 56 | return feat 57 | 58 | 59 | def _transpose_and_gather_feat(feat, ind): 60 | feat = feat.permute(0, 2, 3, 1).contiguous() 61 | feat = feat.view(feat.size(0), -1, feat.size(3)) 62 | feat = _gather_feat(feat, ind) 63 | return feat 64 | -------------------------------------------------------------------------------- /CenterNet/utils/gaussian.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def gaussian_radius(det_size, min_overlap=0.7): 7 | height, width = det_size 8 | 9 | a1 = 1 10 | b1 = height + width 11 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 12 | sq1 = math.sqrt(b1 ** 2 - 4 * a1 * c1) 13 | r1 = (b1 + sq1) / 2 14 | 15 | a2 = 4 16 | b2 = 2 * (height + width) 17 | c2 = (1 - min_overlap) * width * height 18 | sq2 = math.sqrt(b2 ** 2 - 4 * a2 * c2) 19 | r2 = (b2 + sq2) / 2 20 | 21 | a3 = 4 * min_overlap 22 | b3 = -2 * min_overlap * (height + width) # may cause a negative radius 23 | c3 = (min_overlap - 1) * width * height 24 | sq3 = math.sqrt(b3 ** 2 - 4 * a3 * c3) 25 | r3 = (b3 + sq3) / 2 26 | return min(r1, r2, r3) 27 | 28 | 29 | def gaussian2D(shape, sigma=1): 30 | m, n = [(ss - 1.0) / 2.0 for ss in shape] 31 | 32 | # y, x = np.ogrid[-m:m + 1, -n:n + 1] 33 | y = torch.arange(-m, m + 1).unsqueeze(-1) 34 | x = torch.arange(-n, n + 1).unsqueeze(0) 35 | 36 | h = torch.exp(-(x * x + y * y) / (2 * sigma * sigma)) 37 | h[h < torch.finfo(h.dtype).eps * h.max()] = 0 38 | return h 39 | 40 | 41 | def draw_umich_gaussian(heatmap, center, radius, k=1): 42 | diameter = 2 * radius + 1 43 | gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) 44 | 45 | x, y = int(center[0]), int(center[1]) 46 | 47 | height, width = heatmap.shape[0:2] 48 | 49 | left, right = int(min(x, radius)), int(min(width - x, radius + 1)) 50 | top, bottom = int(min(y, radius)), int(min(height - y, radius + 1)) 51 | 52 | masked_heatmap = heatmap[y - top : y + bottom, x - left : x + right] 53 | masked_gaussian = gaussian[ 54 | radius - top : radius + bottom, radius - left : radius + right 55 | ] 56 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug 57 | torch.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) 58 | return heatmap 59 | 60 | 61 | def draw_msra_gaussian(heatmap, center, sigma): 62 | tmp_size = sigma * 3 63 | mu_x = int(center[0] + 0.5) 64 | mu_y = int(center[1] + 0.5) 65 | w, h = heatmap.shape[0], heatmap.shape[1] 66 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] 67 | br = [int(mu_x + tmp_size + 1), int(mu_y + 
tmp_size + 1)] 68 | if br[0] >= h or br[1] >= w or ul[0] < 0 or ul[1] < 0: 69 | return heatmap 70 | size = 2 * tmp_size + 1 71 | x = np.arange(0, size, 1, np.float32) 72 | y = x[:, np.newaxis] 73 | x0 = y0 = size // 2 74 | g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) # if the sigma is 0 here, will cause a Nan value 75 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0] 76 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1] 77 | img_x = max(0, ul[0]), min(br[0], h) 78 | img_y = max(0, ul[1]), min(br[1], w) 79 | heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum( 80 | heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]], 81 | g[g_y[0] : g_y[1], g_x[0] : g_x[1]], 82 | ) 83 | return heatmap 84 | -------------------------------------------------------------------------------- /CenterNet/utils/losses.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Portions of this code are from 3 | # CornerNet (https://github.com/princeton-vl/CornerNet) 4 | # Copyright (c) 2018, University of Michigan 5 | # Licensed under the BSD 3-Clause License 6 | # ------------------------------------------------------------------------------ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .decode import _transpose_and_gather_feat 12 | 13 | 14 | def _neg_loss(pred, gt): 15 | """Modified focal loss. Exactly the same as CornerNet. 16 | Runs faster and costs a little bit more memory 17 | Arguments: 18 | pred (batch x c x h x w) 19 | gt (batch x c x h x w) 20 | """ 21 | pos_inds = gt.eq(1).float() 22 | neg_inds = gt.lt(1).float() 23 | 24 | neg_weights = torch.pow(1 - gt, 4) 25 | 26 | loss = 0 27 | 28 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds 29 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds 30 | 31 | num_pos = pos_inds.float().sum() 32 | pos_loss = pos_loss.sum() 33 | neg_loss = neg_loss.sum() 34 | 35 | if num_pos == 0: 36 | loss = loss - neg_loss 37 | else: 38 | loss = loss - (pos_loss + neg_loss) / num_pos 39 | return loss 40 | 41 | 42 | class FocalLoss(nn.Module): 43 | """nn.Module warpper for focal loss""" 44 | 45 | def __init__(self): 46 | super(FocalLoss, self).__init__() 47 | self.neg_loss = _neg_loss 48 | 49 | def forward(self, out, target): 50 | return self.neg_loss(out, target) 51 | 52 | 53 | class RegL1Loss(nn.Module): 54 | def __init__(self): 55 | super(RegL1Loss, self).__init__() 56 | 57 | def forward(self, output, mask, ind, target): 58 | pred = _transpose_and_gather_feat(output, ind) 59 | mask = mask.unsqueeze(2).expand_as(pred).float() 60 | # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean') 61 | loss = F.l1_loss(pred * mask, target * mask, reduction="sum") 62 | loss = loss / (mask.sum() + 1e-4) 63 | return loss 64 | 65 | 66 | class NormRegL1Loss(nn.Module): 67 | def __init__(self): 68 | super(NormRegL1Loss, self).__init__() 69 | 70 | def forward(self, output, mask, ind, target): 71 | pred = _transpose_and_gather_feat(output, ind) 72 | mask = mask.unsqueeze(2).expand_as(pred).float() 73 | # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean') 74 | pred = pred / (target + 1e-4) 75 | target = target * 0 + 1 76 | loss = F.l1_loss(pred * mask, target * mask, size_average=False) 77 | loss = loss / (mask.sum() + 1e-4) 78 | return loss 79 | 80 | 81 | class RegWeightedL1Loss(nn.Module): 82 | def __init__(self): 83 | 
super(RegWeightedL1Loss, self).__init__() 84 | 85 | def forward(self, output, mask, ind, target): 86 | pred = _transpose_and_gather_feat(output, ind) 87 | mask = mask.float() 88 | # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean') 89 | loss = F.l1_loss(pred * mask, target * mask, reduction="sum") 90 | loss = loss / (mask.sum() + 1e-4) 91 | return loss 92 | -------------------------------------------------------------------------------- /CenterNet/utils/nms.py: -------------------------------------------------------------------------------- 1 | from numba import jit 2 | import numpy as np 3 | 4 | 5 | @jit(nopython=True) 6 | def soft_nms(boxes, sigma=0.5, Nt=0.3, threshold=0.001, method=0): 7 | """ 8 | Soft-NMS: Improving Object Detection With One Line of Code 9 | Copyright (c) University of Maryland, College Park 10 | Licensed under The MIT License [see LICENSE for details] 11 | Written by Navaneeth Bodla and Bharat Singh 12 | 13 | :param boxes: bounding_boxes [x1, y1, x2, y2, sc] 14 | :param sigma: Sigma 15 | :param Nt: 16 | :param threshold: 17 | :param method: 0 | 1 | 2 - Oringal NMS | linear | gaussian 18 | :return: 19 | """ 20 | N = boxes.shape[0] 21 | 22 | for i in range(N): 23 | maxscore = boxes[i, 4] 24 | maxpos = i 25 | 26 | tx1 = boxes[i, 0] 27 | ty1 = boxes[i, 1] 28 | tx2 = boxes[i, 2] 29 | ty2 = boxes[i, 3] 30 | ts = boxes[i, 4] 31 | 32 | pos = i + 1 33 | # get max box 34 | while pos < N: 35 | if maxscore < boxes[pos, 4]: 36 | maxscore = boxes[pos, 4] 37 | maxpos = pos 38 | pos = pos + 1 39 | 40 | # add max box as a detection 41 | boxes[i, 0] = boxes[maxpos, 0] 42 | boxes[i, 1] = boxes[maxpos, 1] 43 | boxes[i, 2] = boxes[maxpos, 2] 44 | boxes[i, 3] = boxes[maxpos, 3] 45 | boxes[i, 4] = boxes[maxpos, 4] 46 | 47 | # swap ith box with position of max box 48 | boxes[maxpos, 0] = tx1 49 | boxes[maxpos, 1] = ty1 50 | boxes[maxpos, 2] = tx2 51 | boxes[maxpos, 3] = ty2 52 | boxes[maxpos, 4] = ts 53 | 54 | tx1 = boxes[i, 0] 55 | ty1 = boxes[i, 1] 56 | tx2 = boxes[i, 2] 57 | ty2 = boxes[i, 3] 58 | ts = boxes[i, 4] 59 | 60 | pos = i + 1 61 | # NMS iterations, note that N changes if detection boxes fall below threshold 62 | while pos < N: 63 | x1 = boxes[pos, 0] 64 | y1 = boxes[pos, 1] 65 | x2 = boxes[pos, 2] 66 | y2 = boxes[pos, 3] 67 | s = boxes[pos, 4] 68 | 69 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 70 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 71 | if iw > 0: 72 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 73 | if ih > 0: 74 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 75 | ov = iw * ih / ua # iou between max box and detection box 76 | 77 | if method == 1: # linear 78 | if ov > Nt: 79 | weight = 1 - ov 80 | else: 81 | weight = 1 82 | elif method == 2: # gaussian 83 | weight = np.exp(-(ov * ov) / sigma) 84 | else: # original NMS 85 | if ov > Nt: 86 | weight = 0 87 | else: 88 | weight = 1 89 | 90 | boxes[pos, 4] = weight * boxes[pos, 4] 91 | 92 | # if box score falls below threshold, discard the box by swapping with last box 93 | # update N 94 | if boxes[pos, 4] < threshold: 95 | boxes[pos, 0] = boxes[N - 1, 0] 96 | boxes[pos, 1] = boxes[N - 1, 1] 97 | boxes[pos, 2] = boxes[N - 1, 2] 98 | boxes[pos, 3] = boxes[N - 1, 3] 99 | boxes[pos, 4] = boxes[N - 1, 4] 100 | N = N - 1 101 | pos = pos - 1 102 | 103 | pos = pos + 1 104 | 105 | keep = [i for i in range(N)] 106 | return keep 107 | 108 | 109 | @jit(nopython=True) 110 | def soft_nms_39(boxes, sigma=0.5, Nt=0.3, threshold=0.001, method=0): 111 | N = boxes.shape[0] 112 | 113 | for i in range(N): 114 | 
maxscore = boxes[i, 4] 115 | maxpos = i 116 | 117 | tx1 = boxes[i, 0] 118 | ty1 = boxes[i, 1] 119 | tx2 = boxes[i, 2] 120 | ty2 = boxes[i, 3] 121 | ts = boxes[i, 4] 122 | 123 | pos = i + 1 124 | # get max box 125 | while pos < N: 126 | if maxscore < boxes[pos, 4]: 127 | maxscore = boxes[pos, 4] 128 | maxpos = pos 129 | pos = pos + 1 130 | 131 | # add max box as a detection 132 | boxes[i, 0] = boxes[maxpos, 0] 133 | boxes[i, 1] = boxes[maxpos, 1] 134 | boxes[i, 2] = boxes[maxpos, 2] 135 | boxes[i, 3] = boxes[maxpos, 3] 136 | boxes[i, 4] = boxes[maxpos, 4] 137 | 138 | # swap ith box with position of max box 139 | boxes[maxpos, 0] = tx1 140 | boxes[maxpos, 1] = ty1 141 | boxes[maxpos, 2] = tx2 142 | boxes[maxpos, 3] = ty2 143 | boxes[maxpos, 4] = ts 144 | 145 | for j in range(5, 39): 146 | tmp = boxes[i, j] 147 | boxes[i, j] = boxes[maxpos, j] 148 | boxes[maxpos, j] = tmp 149 | 150 | tx1 = boxes[i, 0] 151 | ty1 = boxes[i, 1] 152 | tx2 = boxes[i, 2] 153 | ty2 = boxes[i, 3] 154 | ts = boxes[i, 4] 155 | 156 | pos = i + 1 157 | # NMS iterations, note that N changes if detection boxes fall below threshold 158 | while pos < N: 159 | x1 = boxes[pos, 0] 160 | y1 = boxes[pos, 1] 161 | x2 = boxes[pos, 2] 162 | y2 = boxes[pos, 3] 163 | s = boxes[pos, 4] 164 | 165 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 166 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 167 | if iw > 0: 168 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 169 | if ih > 0: 170 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 171 | ov = iw * ih / ua # iou between max box and detection box 172 | 173 | if method == 1: # linear 174 | if ov > Nt: 175 | weight = 1 - ov 176 | else: 177 | weight = 1 178 | elif method == 2: # gaussian 179 | weight = np.exp(-(ov * ov) / sigma) 180 | else: # original NMS 181 | if ov > Nt: 182 | weight = 0 183 | else: 184 | weight = 1 185 | 186 | boxes[pos, 4] = weight * boxes[pos, 4] 187 | 188 | # if box score falls below threshold, discard the box by swapping with last box 189 | # update N 190 | if boxes[pos, 4] < threshold: 191 | boxes[pos, 0] = boxes[N - 1, 0] 192 | boxes[pos, 1] = boxes[N - 1, 1] 193 | boxes[pos, 2] = boxes[N - 1, 2] 194 | boxes[pos, 3] = boxes[N - 1, 3] 195 | boxes[pos, 4] = boxes[N - 1, 4] 196 | for j in range(5, 39): 197 | tmp = boxes[pos, j] 198 | boxes[pos, j] = boxes[N - 1, j] 199 | boxes[N - 1, j] = tmp 200 | N = N - 1 201 | pos = pos - 1 202 | 203 | pos = pos + 1 204 | 205 | keep = [i for i in range(N)] 206 | return keep 207 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2021 Torben Teepe 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | http://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 
191 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CenterNet w/ PyTorchLightning 2 | 3 | ![CI testing](https://github.com/tteepe/CenterNet-pytorch-lightning/workflows/CI%20testing/badge.svg?branch=main&event=push) 4 | [![DOI](https://zenodo.org/badge/334429075.svg)](https://zenodo.org/badge/latestdoi/334429075) 5 | 6 | 7 | ## Description 8 | My attempt at a cleaner implementation of the glorious [CenterNet](https://github.com/xingyizhou/CenterNet). 9 | 10 | ### Features 11 | - Decoupled backbones and heads for easier backbone integration 12 | - Split sample creation into image augmentation (with [imgaug](https://github.com/aleju/imgaug)) and actual sample creation 13 | - Comes shipped with Lightning modules but can also be used with good ol' plain PyTorch 14 | - Stripped all code not used to reproduce the results in the paper 15 | - Smaller code base with more meaningful variable names 16 | - Requires significantly less memory 17 | - Same or slightly better results than the original implementation 18 | 19 | 20 | ### ToDos 21 | Some features of the original repository are not implemented yet but pull requests are welcome! 22 | - [ ] 3D bounding box detection 23 | - [ ] ExtremeNet detection 24 | - [ ] Pascal VOC dataset 25 | 26 | ## How to run 27 | First, install dependencies 28 | ```shell 29 | # Install ninja for DCNv2 JIT compilation 30 | sudo apt-get install ninja-build 31 | 32 | # clone CenterNet 33 | git clone https://github.com/tteepe/CenterNet-pytorch-lightning 34 | 35 | # install CenterNet 36 | cd CenterNet-pytorch-lightning 37 | pip install -e . 38 | pip install -r requirements.txt 39 | ``` 40 | Next, navigate to any file and run it. 41 | ```shell 42 | # module folder 43 | cd CenterNet 44 | 45 | # run module 46 | python centernet_detection.py 47 | python centernet_multi_pose.py 48 | ``` 49 | 50 | ## Imports 51 | This project is setup as a package which means you can now easily import any file into any other file like so: 52 | 53 | ```python 54 | from pytorch_lightning import Trainer 55 | from torchvision.datasets import CocoDetection 56 | from CenterNet import CenterNetDetection 57 | 58 | # model 59 | model = CenterNetDetection("dla_34") 60 | 61 | # data 62 | train = CocoDetection("train2017", "instances_train2017.json") 63 | val = CocoDetection("val2017", "instances_val2017.json") 64 | 65 | # train 66 | trainer = Trainer() 67 | trainer.fit(model, train, val) 68 | 69 | # test using the best backbone! 70 | test = CocoDetection("test2017", "image_info_test2017.json") 71 | trainer.test(test_dataloaders=test) 72 | ``` 73 | 74 | ## BibTeX 75 | If you want to cite the implementation feel free to use this or [zenodo](https://zenodo.org/record/4569502): 76 | 77 | ```bibtex 78 | @article{teepe2021centernet, 79 | title={CenterNet PyTorch Lightning}, 80 | author={Teepe, Torben and Gilg, Johannes}, 81 | journal={GitHub. 
Note: https://github.com/tteepe/CenterNet-pytorch-lightning}, 82 | volume={1}, 83 | year={2021} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | DCNv2 @ git+https://github.com/tteepe/DCNv2 2 | imgaug~=0.4.0 3 | json_tricks 4 | matplotlib 5 | numba 6 | opencv-python 7 | pycocotools 8 | pytorch-lightning~=1.5.10 9 | setuptools 10 | torchvision~=0.11.3 11 | torch~=1.10.2 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | norecursedirs = 3 | .git 4 | dist 5 | build 6 | addopts = 7 | --strict 8 | --doctest-modules 9 | --durations=0 10 | 11 | [coverage:report] 12 | exclude_lines = 13 | pragma: no-cover 14 | pass 15 | 16 | [flake8] 17 | max-line-length = 120 18 | exclude = .tox,*.egg,build,temp 19 | select = E,W,F 20 | doctests = True 21 | verbose = 2 22 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 23 | format = pylint 24 | # see: https://www.flake8rules.com/ 25 | ignore = 26 | E731 # Do not assign a lambda expression, use a def 27 | W504 # Line break occurred after a binary operator 28 | F401 # Module imported but unused 29 | F841 # Local variable name is assigned to but never used 30 | W605 # Invalid escape sequence 'x' 31 | 32 | # setup.cfg or tox.ini 33 | [check-manifest] 34 | ignore = 35 | *.yml 36 | .github 37 | .github/* 38 | 39 | [metadata] 40 | license_file = LICENSE 41 | description-file = README.md 42 | # long_description = file:README.md 43 | # long_description_content_type = text/markdown 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name='CenterNet', 7 | version='0.1.0', 8 | description='Refactored version of CenterNet (Objects as Points). 
With PyTorch Lightning and imgaug.', 9 | author='Torben Teepe', 10 | author_email='torben@tee.pe', 11 | url='https://github.com/tteepe/CenterNet-pytorch-lightning', 12 | install_requires=['pytorch-lightning'], 13 | packages=find_packages(), 14 | ) 15 | 16 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tteepe/CenterNet-pytorch-lightning/2febb56103046064a42b502a4145c37728917042/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/coco_annotation.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "area": 2913.1104, 4 | "bbox": [ 5 | 412.8, 6 | 157.61, 7 | 53.05, 8 | 138.01 9 | ], 10 | "category_id": 1, 11 | "id": 230831, 12 | "image_id": 139, 13 | "iscrowd": 0, 14 | "keypoints": [ 15 | 427.0, 170.0, 1.0, 16 | 0.0, 0.0, 0.0, 17 | 429.0, 169.0, 2.0, 18 | 0.0, 0.0, 0.0, 19 | 434.0, 168.0, 2.0, 20 | 446.0, 177.0, 2.0, 21 | 441.0, 177.0, 2.0, 22 | 430.0, 206.0, 2.0, 23 | 437.0, 200.0, 2.0, 24 | 420.0, 215.0, 2.0, 25 | 430.0, 220.0, 2.0, 26 | 452.0, 223.0, 2.0, 27 | 445.0, 226.0, 2.0, 28 | 454.0, 257.0, 2.0, 29 | 447.0, 260.0, 2.0, 30 | 459.0, 286.0, 2.0, 31 | 455.0, 290.0, 2.0 32 | ], 33 | "num_keypoints": 15, 34 | "segmentation": [ 35 | [428.19 , 219.47 , 430.94 , 209.57 , 430.39 , 210.12 , 421.32 , 216.17 , 412.8 , 217.27 , 413.9 , 214.24 , 422.42 , 211.22 , 429.29 , 201.6 , 430.67 , 181.8 , 430.12 , 175.2 , 427.09 , 168.06 , 426.27 , 164.21 , 430.94 , 159.26 , 440.29 , 157.61 , 446.06 , 163.93 , 448.53 , 168.06 , 448.53 , 173.01 , 449.08 , 174.93 , 454.03 , 185.1 , 455.41 , 188.4 , 458.43 , 195 , 460.08 , 210.94 , 462.28 , 226.61 , 460.91 , 233.76 , 454.31 , 234.04 , 460.08 , 256.85 , 462.56 , 268.13 , 465.58 , 290.67 , 465.85 , 293.14 , 463.38 , 295.62 , 452.66 , 295.34 , 448.26 , 294.52 , 443.59 , 282.7 , 446.06 , 235.14 , 446.34 , 230.19 , 438.09 , 232.39 , 438.09 , 221.67 , 434.24 , 221.12 , 427.09 , 219.74] 36 | ] 37 | }, 38 | { 39 | "area": 435.14495, 40 | "bbox": [ 41 | 384.43, 42 | 172.21, 43 | 15.12, 44 | 35.74 45 | ], 46 | "category_id": 1, 47 | "id": 233201, 48 | "image_id": 139, 49 | "iscrowd": 0, 50 | "keypoints": [ 51 | 0.0, 0.0, 0.0, 52 | 0.0, 0.0, 0.0, 53 | 0.0, 0.0, 0.0, 54 | 0.0, 0.0, 0.0, 55 | 0.0, 0.0, 0.0, 56 | 0.0, 0.0, 0.0, 57 | 0.0, 0.0, 0.0, 58 | 0.0, 0.0, 0.0, 59 | 0.0, 0.0, 0.0, 60 | 0.0, 0.0, 0.0, 61 | 0.0, 0.0, 0.0, 62 | 0.0, 0.0, 0.0, 63 | 0.0, 0.0, 0.0, 64 | 0.0, 0.0, 0.0, 65 | 0.0, 0.0, 0.0, 66 | 0.0, 0.0, 0.0, 67 | 0.0, 0.0, 0.0 68 | ], 69 | "num_keypoints": 0, 70 | "segmentation": [ 71 | [384.98 , 206.58 , 384.43 , 199.98 , 385.25 , 193.66 , 385.25 , 190.08 , 387.18 , 185.13 , 387.18 , 182.93 , 386.08 , 181.01 , 385.25 , 178.81 , 385.25 , 175.79 , 388 , 172.76 , 394.88 , 172.21 , 398.72 , 173.31 , 399.27 , 176.06 , 399.55 , 183.48 , 397.9 , 185.68 , 395.15 , 188.98 , 396.8 , 193.38 , 398.45 , 194.48 , 399 , 205.75 , 395.43 , 207.95 , 388.83 , 206.03] 72 | ] 73 | } 74 | ] 75 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | coverage 2 | codecov>=2.1 3 | pytest>=3.0.5 4 | pytest-cov 5 | pytest-flake8 6 | flake8 7 | check-manifest 8 | twine>=3.4.2 9 | -------------------------------------------------------------------------------- 
/tests/test_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from CenterNet.models import create_model 3 | from CenterNet.models.heads import CenterHead 4 | 5 | 6 | device = "cuda" if torch.cuda.is_available() else "cpu" 7 | supported_backbones = [ 8 | "res_18", "res_101", "resdcn_18", "resdcn_101", "dla_34", "hourglass" 9 | ] 10 | 11 | 12 | def test_models(): 13 | sample_input = torch.rand((1, 3, 512, 512), device=device) 14 | 15 | for arch in supported_backbones: 16 | print(f"Testing: {arch}") 17 | model = create_model(arch).to(device) 18 | heads = { 19 | "heatmap": 1, 20 | "width_height": 2, 21 | "regression": 2, 22 | "heatmap_keypoints": 17, 23 | "heatpoint_offset": 2, 24 | "keypoints": 34, 25 | } 26 | 27 | head = CenterHead(heads, model.out_channels, 64).to(device) 28 | 29 | out_backbone = model(sample_input) 30 | output = head(out_backbone[-1]) 31 | 32 | assert output 33 | 34 | for name, data in output.items(): 35 | shape = getattr(head, name).out_channels 36 | head_shape = torch.Size( 37 | [1, shape, sample_input.shape[2] // 4, sample_input.shape[3] // 4] 38 | ) 39 | assert data.shape == head_shape 40 | 41 | 42 | if __name__ == "__main__": 43 | test_models() 44 | -------------------------------------------------------------------------------- /tests/test_sample_encode_decode.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import imgaug.augmenters as iaa 6 | import pytest 7 | 8 | from CenterNet.decode.ctdet import ctdet_decode 9 | from CenterNet.sample.ctdet import CenterDetectionSample 10 | from CenterNet.transforms import CategoryIdToClass, ImageAugmentation 11 | from CenterNet.transforms.sample import ComposeSample 12 | 13 | 14 | def test_cdet_encoding_decoding(): 15 | sample_encoding = ComposeSample([ 16 | ImageAugmentation( 17 | iaa.Identity(), 18 | torchvision.transforms.ToTensor() 19 | ), 20 | CategoryIdToClass(range(0, 100)), 21 | CenterDetectionSample() 22 | ]) 23 | 24 | img = (255 * np.random.rand(512, 512, 3)).astype(np.uint8) 25 | with open('tests/data/coco_annotation.json') as json_file: 26 | coco_annotation = json.load(json_file) 27 | 28 | ann_center = np.zeros((len(coco_annotation), 2)) 29 | for i in range(len(coco_annotation)): 30 | x, y, w, h = coco_annotation[i]["bbox"] 31 | ann_center[i, 0] = x + w/2 32 | ann_center[i, 1] = y + h/2 33 | 34 | img, output = sample_encoding(img, coco_annotation) 35 | 36 | heatmap = output['heatmap'].unsqueeze(0) 37 | batch, cat, height, width = heatmap.size() 38 | wh = torch.zeros((batch, width, height, 2)) 39 | reg = torch.zeros((batch, width, height, 2)) 40 | 41 | # Create fake output from sample 42 | indices = output['indices'].unsqueeze(0) 43 | indices_x = indices % width 44 | indices_y = indices // width 45 | wh[:, indices_y, indices_x] = output['width_height'].unsqueeze(0) 46 | wh = wh.permute(0, 3, 1, 2) 47 | reg[:, indices_y, indices_x] = output['regression'].unsqueeze(0) 48 | reg = reg.permute(0, 3, 1, 2) 49 | 50 | # Decode fake output 51 | detections = ctdet_decode(heatmap, wh, reg).squeeze().numpy() 52 | detections = 4 * detections[detections[:, 4] > 0.5] 53 | 54 | center = (detections[:, :2] + detections[:, 2:4]) / 2. 
55 | 56 | assert abs(np.sum(center) - np.sum(ann_center)) == pytest.approx(0., abs=1e-3) 57 | 58 | 59 | if __name__ == "__main__": 60 | test_cdet_encoding_decoding() 61 | -------------------------------------------------------------------------------- /tests/test_train_detection.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer, seed_everything 2 | import torch 3 | from torch.utils.data import DataLoader 4 | 5 | from CenterNet.sample.ctdet import CenterDetectionSample 6 | from CenterNet.centernet_detection import CenterNetDetection 7 | from tests.utilities import CocoFakeDataset 8 | 9 | 10 | def test_detection(): 11 | """ 12 | Simple smoke test for CenterNetDetection 13 | """ 14 | seed_everything(1234) 15 | dataset = CocoFakeDataset( 16 | transforms=CenterDetectionSample() 17 | ) 18 | 19 | test_val_loader = DataLoader( 20 | dataset, 21 | batch_size=2, 22 | num_workers=1, 23 | pin_memory=True, 24 | ) 25 | test_loader = DataLoader( 26 | dataset, 27 | batch_size=1, 28 | num_workers=0, 29 | pin_memory=True, 30 | ) 31 | 32 | model = CenterNetDetection( 33 | "dla_34", 34 | test_flip=True, 35 | test_scales=[.5, 1, 1.5] 36 | ) 37 | 38 | trainer = Trainer( 39 | limit_train_batches=2, 40 | limit_val_batches=1, 41 | limit_test_batches=1, 42 | max_epochs=1, 43 | gpus=1 if torch.cuda.is_available() else 0 44 | ) 45 | trainer.fit(model, test_val_loader, test_val_loader) 46 | 47 | trainer.test(model, test_dataloaders=test_loader) 48 | 49 | 50 | if __name__ == "__main__": 51 | test_detection() 52 | -------------------------------------------------------------------------------- /tests/test_train_multi_pose.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer, seed_everything 2 | import torch 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | from .utilities import CocoFakeDataset 7 | 8 | from CenterNet.centernet_multi_pose import CenterNetMultiPose 9 | from CenterNet.transforms import MultiSampleTransform 10 | from CenterNet.sample.ctdet import CenterDetectionSample 11 | from CenterNet.sample.multi_pose import MultiPoseSample 12 | 13 | 14 | def test_multi_pose(): 15 | """ 16 | Simple smoke test for CenterNetMultiPose 17 | """ 18 | seed_everything(1234) 19 | dataset = CocoFakeDataset( 20 | transforms=MultiSampleTransform([CenterDetectionSample(), MultiPoseSample()]), 21 | ) 22 | 23 | test_val_loader = DataLoader( 24 | dataset, 25 | batch_size=2, 26 | num_workers=1, 27 | pin_memory=True, 28 | ) 29 | test_loader = DataLoader( 30 | dataset, 31 | batch_size=1, 32 | num_workers=0, 33 | pin_memory=True, 34 | ) 35 | 36 | model = CenterNetMultiPose( 37 | "dla_34", 38 | test_flip=True, 39 | test_scales=[.5, 1, 1.5] 40 | ) 41 | 42 | trainer = Trainer( 43 | limit_train_batches=2, 44 | limit_val_batches=1, 45 | limit_test_batches=1, 46 | max_epochs=1, 47 | gpus=1 if torch.cuda.is_available() else 0 48 | ) 49 | trainer.fit(model, test_val_loader, test_val_loader) 50 | 51 | trainer.test(model, test_dataloaders=test_loader) 52 | 53 | 54 | if __name__ == "__main__": 55 | test_multi_pose() 56 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | import imgaug.augmenters as iaa 5 | 6 | from CenterNet.transforms import ImageAugmentation, PoseFlip 7 | 8 | 9 | def 
test_image_augmentation(): 10 | img_aug_none = ImageAugmentation( 11 | iaa.Identity() 12 | ) 13 | img_aug_change = ImageAugmentation( 14 | iaa.Fliplr(1) 15 | ) 16 | 17 | sample_img = (255 * np.random.rand(512, 512, 3)).astype(np.uint8) 18 | with open('tests/data/coco_annotation.json') as json_file: 19 | coco_annotation = json.load(json_file) 20 | 21 | # Expect no change 22 | img_aug, ann_aug = img_aug_none(sample_img, coco_annotation) 23 | 24 | assert np.sum(img_aug[:, :, ::-1] - sample_img) == 0 25 | 26 | for i in range(len(coco_annotation)): 27 | assert ann_aug[i]['keypoints'] == coco_annotation[i]['keypoints'] 28 | np.testing.assert_array_almost_equal( 29 | np.array(ann_aug[i]['bbox']), 30 | np.array(coco_annotation[i]['bbox']) 31 | ) 32 | 33 | # Expect change 34 | img_aug, ann_aug = img_aug_change(sample_img, coco_annotation) 35 | 36 | assert np.sum(img_aug[:, :, ::-1] - sample_img) != 0 37 | 38 | for i in range(len(coco_annotation)): 39 | if ann_aug[i]["num_keypoints"] != 0: 40 | assert ann_aug[i]['keypoints'] != coco_annotation[i]['keypoints'] 41 | assert ann_aug[i]['bbox'] != coco_annotation[i]['bbox'] 42 | 43 | 44 | def test_pose_flip(): 45 | sample_img = torch.rand((1, 3, 512, 512)) 46 | with open('tests/data/coco_annotation.json') as json_file: 47 | coco_annotation = json.load(json_file) 48 | 49 | flip = PoseFlip(1) 50 | 51 | # Flip 52 | img, ann = flip(sample_img, coco_annotation) 53 | 54 | assert torch.sum(img - sample_img) != 0 55 | for i in range(len(coco_annotation)): 56 | assert ann[i]['bbox'] != coco_annotation[i]['bbox'] 57 | if ann[i]["num_keypoints"] == 0: 58 | assert ann[i]['keypoints'] == coco_annotation[i]['keypoints'] 59 | else: 60 | assert ann[i]['keypoints'] != coco_annotation[i]['keypoints'] 61 | 62 | # Flip back to original 63 | img, ann = flip(img, ann) 64 | 65 | assert torch.sum(img - sample_img) == 0 66 | for i in range(len(coco_annotation)): 67 | assert ann[i]['keypoints'] == coco_annotation[i]['keypoints'] 68 | np.testing.assert_array_almost_equal( 69 | np.array(ann[i]['bbox']), 70 | np.array(coco_annotation[i]['bbox']) 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | test_image_augmentation() 76 | test_pose_flip() 77 | -------------------------------------------------------------------------------- /tests/utilities.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class CocoFakeDataset(Dataset): 8 | def __init__(self, 9 | transforms=None, 10 | annotation_path="tests/data/coco_annotation.json", 11 | length=1000 12 | ): 13 | self.transforms = transforms 14 | with open(annotation_path) as json_file: 15 | self.coco_annotation = json.load(json_file) 16 | self.length = length 17 | 18 | def __getitem__(self, index): 19 | img = torch.rand((3, 512, 512)) 20 | annotation = self.coco_annotation.copy() 21 | 22 | if self.transforms: 23 | img, annotation = self.transforms(img, annotation) 24 | 25 | return img, annotation 26 | 27 | def __len__(self): 28 | return self.length 29 | --------------------------------------------------------------------------------
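
The focal loss in `CenterNet/utils/losses.py` works directly on dense heatmaps: ground-truth peaks are exactly 1, everything else is down-weighted by `(1 - gt)^4`, and predictions must already lie strictly inside (0, 1) because `_neg_loss` takes `log(pred)` and `log(1 - pred)` without any safeguards. A minimal sketch of exercising it on dummy tensors (the shapes, the sigmoid, and the clamp below are illustrative assumptions, not part of the library code):

```python
import torch

from CenterNet.utils.losses import FocalLoss

# Dummy batch: 2 images, 80 classes, 128x128 output heatmap (shapes are illustrative).
logits = torch.randn(2, 80, 128, 128)

# The loss takes log(pred) and log(1 - pred) directly, so squash and clamp first.
pred = torch.sigmoid(logits).clamp(1e-4, 1 - 1e-4)

# Ground truth: mostly zeros, with a single peak of exactly 1.0 marking an object center.
gt = torch.zeros_like(pred)
gt[0, 3, 64, 64] = 1.0

criterion = FocalLoss()
loss = criterion(pred, gt)
print(loss.item())
```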
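
Likewise, `soft_nms` in `CenterNet/utils/nms.py` rescores overlapping detections in place and returns the indices of the boxes that survive the score threshold. A small sketch under assumed inputs (the box values and the `float32` cast are illustrative; the numba-compiled function expects rows of `[x1, y1, x2, y2, score]`):

```python
import numpy as np

from CenterNet.utils.nms import soft_nms

# Two heavily overlapping boxes plus one distant box, as [x1, y1, x2, y2, score] rows.
boxes = np.array(
    [
        [10.0, 10.0, 50.0, 50.0, 0.90],
        [12.0, 12.0, 52.0, 52.0, 0.85],
        [200.0, 200.0, 240.0, 240.0, 0.70],
    ],
    dtype=np.float32,
)

# method=2 selects the Gaussian rescoring variant; `boxes` is modified in place,
# so the second box keeps its coordinates but gets a reduced confidence score.
keep = soft_nms(boxes, sigma=0.5, Nt=0.3, threshold=0.001, method=2)
print(keep, boxes[:, 4])
```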