├── README.md ├── __pycache__ ├── ctdet.cpython-37.pyc ├── dataset.cpython-37.pyc ├── loss.cpython-37.pyc ├── model.cpython-37.pyc └── utils.cpython-37.pyc ├── ctdet.py ├── dataset.py ├── loss.py ├── model.py ├── model_orig.py ├── outputs ├── epoch_0.jpg ├── epoch_1.jpg ├── epoch_10.jpg ├── epoch_11.jpg ├── epoch_12.jpg ├── epoch_13.jpg ├── epoch_14.jpg ├── epoch_15.jpg ├── epoch_16.jpg ├── epoch_17.jpg ├── epoch_18.jpg ├── epoch_19.jpg ├── epoch_2.jpg ├── epoch_3.jpg ├── epoch_4.jpg ├── epoch_5.jpg ├── epoch_6.jpg ├── epoch_7.jpg ├── epoch_8.jpg ├── epoch_9.jpg └── mask_loss.png ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | 5 | 6 | This repo contains a minimalist implementation of the paper *Objects as Points* by Zhou et al. I found the approach pretty interesting and novel. It doesn't use anchor boxes and requires minimal post-processing. The essential idea of the paper is to treat objects as points denoted by their centers rather than as bounding boxes. This idea is elegant and makes intuitive sense. And most importantly, it performs pretty well. The authors achieve strong results on multiple tasks, namely object detection on MS COCO, 3D detection on KITTI, and pose estimation! 7 | 8 | ## Dependencies 9 | - Python > 3.6 10 | - PyTorch > 1.0 11 | 12 | If you have installed PyTorch in a conda environment, it should work without much hassle. 13 | 14 | ## Running the model 15 | No need to download any dataset. Just run `python train.py` to start training the model. The model outputs are saved in the `outputs` folder. 16 | 17 | 18 | ## Model Outputs 19 | The visualization of the model output looks something like this. 20 | ![Model Outputs](./outputs/epoch_12.jpg) 21 | 22 | Here are a few details about this implementation. 23 | 24 | 25 | ## The Toy dataset 26 | The MS COCO and KITTI datasets are big, and not everyone wants to dig into them just to run a few experiments. So, I decided to write a function that automatically generates a toy dataset of circles. Each generated image contains multiple circles (max 5). The circles may overlap. 27 | ## The Model 28 | I didn't want to complicate things, so I used ResNet-18 as the encoder and followed UNet in the upsampling stages. Let's say the input to the model is 1x1x256x256 (Batch Size x Channels x Height x Width); then the output of 29 | the model will be 1x(N+4)x64x64. Here 1 is the batch size, N+4 is the number of output channels and 64 is the downsampled feature-map width. Here N is the number of object categories. In my implementation we only have one category of object, i.e. circle. So, in our case the output is 1x5x64x64. Also, notice that we could have upsampled all the way back to the input resolution, like in UNet; the authors go for a 64x64 output. An advantage of training on downsampled outputs is that it reduces the number of parameters. 30 | 31 | Now an obvious question: if our model output is 4 times smaller (say 64x64) than the input image, how are we going to get a fine prediction of the object location? It seems like we could be off by up to 4 pixels in both height and width... The authors propose an innovative solution: they predict an offset map to fine-tune the object position. So, if the model output is 1x5x64x64, the first channel (1x1x64x64) is the score for an object centered at each pixel, and the second and third channels represent the width and height. If the score at, say, coordinate (30,23) is 1, it means that some object is likely present at coordinate (30\*4,23\*4) in the original image. The last 2 channels are the offsets, which are used to get a better estimate of the object location.
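To make the decoding concrete, here is a minimal sketch of how a single box could be recovered from a 1x5x64x64 output. This is illustrative only (the function name `decode_one` and the hard-coded `stride=4` are my own, not part of the repo); the repo's actual decoder is `ctdet_decode` in `ctdet.py`, which additionally applies max-pooling NMS and keeps the top-K peaks instead of a single argmax.

```python
import torch

def decode_one(output: torch.Tensor, stride: int = 4):
    """Illustrative decode of one box from a 1x5x64x64 output.
    Channel 0: center score (logit), channels 1-2: width/height,
    channels 3-4: sub-pixel offsets assumed already squashed to [0, 1]."""
    scores = torch.sigmoid(output[0, 0])                  # [64, 64] center heatmap
    y, x = divmod(int(scores.argmax()), scores.shape[1])  # strongest center cell
    w, h = float(output[0, 1, y, x]), float(output[0, 2, y, x])
    off_x, off_y = float(output[0, 3, y, x]), float(output[0, 4, y, x])
    cx, cy = (x + off_x) * stride, (y + off_y) * stride   # back to input-image pixels
    box = (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)
    return box, float(scores[y, x])
```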
33 | Also, note that we use a sigmoid activation for the offsets. This ensures that they are bounded to [0, 1]. 34 | 35 | ## The loss functions 36 | 37 | We use two types of loss: 38 | - Mask Loss 39 | 40 | ![Mask Loss](./outputs/mask_loss.png) 41 | 42 | This is a modified version of the focal loss (shown above). Here α and β are hyper-parameters of the focal loss. We set α = 2 and β = 4. In the original implementation, the authors use an object-centered Gaussian kernel to splat the ground-truth centers onto the heatmap: 43 | 44 | Y_xyc = exp(-((x - p̃_x)² + (y - p̃_y)²) / (2σ_p²)) 45 | 46 | Here p̃_x and p̃_y are the low-resolution center of the object, and σ_p is an object-size-adaptive standard deviation. For this toy dataset, using the Gaussian kernel didn't seem to make much difference. So, in the interest of keeping it minimal, I just set Y_xyc = 1 at the center pixel and 0 elsewhere. 47 | 48 | - Size and offset Loss 49 | 50 | For the size and offset predictions, we basically use an L1 loss normalized by the number of object center points. 51 | 52 | ## Training details 53 | 54 | I used the Adam optimizer with default settings. The model starts giving reasonable output after training for around 10 epochs. 55 | We perform non-maximum suppression (using max pooling) on the model's class score maps. The top 100 predictions are selected. We use the offset maps to fine-tune the object coordinates. During visualization, I set a hard threshold of 0.25, i.e. a bounding box is drawn only if the model confidence is > 0.25. If you don't care about false alarms, you should probably set a lower threshold. 56 | 57 | ## Conclusion 58 | 59 | I found the paper to be pretty interesting and elegant. I wanted to make the idea more accessible, so I am sharing this code. I hope this minimalist implementation helps you understand the paper better. -------------------------------------------------------------------------------- /__pycache__/ctdet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/ctdet.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /ctdet.py:
-------------------------------------------------------------------------------- 1 | # adapted from 2 | # https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/post_process.py 3 | import torch 4 | import torch.nn as nn 5 | from utils import _gather_feat, _tranpose_and_gather_feat 6 | 7 | 8 | def _nms(heat, kernel=3): 9 | pad = (kernel - 1) // 2 10 | 11 | hmax = nn.functional.max_pool2d( 12 | heat, (kernel, kernel), stride=1, padding=pad) 13 | keep = (hmax == heat).float() 14 | return heat * keep 15 | 16 | 17 | def _topk(scores, K=40): 18 | batch, cat, height, width = scores.size() 19 | 20 | topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) 21 | 22 | topk_inds = topk_inds % (height * width) 23 | topk_ys = (topk_inds / width).int().float() 24 | topk_xs = (topk_inds % width).int().float() 25 | 26 | topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) 27 | topk_clses = (topk_ind / K).int() 28 | topk_inds = _gather_feat( 29 | topk_inds.view(batch, -1, 1), topk_ind).view(batch, K) 30 | topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K) 31 | topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K) 32 | 33 | return topk_score, topk_inds, topk_clses, topk_ys, topk_xs 34 | 35 | 36 | def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100, model_scale=4): 37 | batch, cat, height, width = heat.size() 38 | 39 | heat = torch.sigmoid(heat) 40 | # perform nms on heatmaps 41 | heat = _nms(heat) 42 | 43 | scores, inds, clses, ys, xs = _topk(heat, K=K) 44 | 45 | xs_raw = xs.view(batch, K, 1) + 0.5 46 | ys_raw = ys.view(batch, K, 1) + 0.5 47 | 48 | if reg is not None: 49 | reg = _tranpose_and_gather_feat(reg, inds) 50 | reg = reg.view(batch, K, 2) 51 | # check if it's correct and not reversed with ys 52 | xs = xs.view(batch, K, 1) + reg[:, :, 0:1] 53 | ys = ys.view(batch, K, 1) + reg[:, :, 1:2] 54 | 55 | wh = _tranpose_and_gather_feat(wh, inds) 56 | if cat_spec_wh: 57 | wh = wh.view(batch, K, cat, 2) 58 | clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long() 59 | wh = wh.gather(2, clses_ind).view(batch, K, 2) 60 | else: 61 | wh = wh.view(batch, K, 2) 62 | 63 | xs, ys = xs*model_scale, ys*model_scale 64 | 65 | xs_raw, ys_raw = xs_raw*model_scale, ys_raw*model_scale 66 | 67 | clses = clses.view(batch, K, 1).float() 68 | scores = scores.view(batch, K, 1) 69 | 70 | bboxes = torch.cat([xs - wh[..., 1:2] / 2, 71 | ys - wh[..., 0:1] / 2, 72 | xs + wh[..., 1:2] / 2, 73 | ys + wh[..., 0:1] / 2], dim=2) 74 | 75 | bboxes_raw = torch.cat([xs_raw - wh[..., 1:2] / 2, 76 | ys_raw - wh[..., 0:1] / 2, 77 | xs_raw + wh[..., 1:2] / 2, 78 | ys_raw + wh[..., 0:1] / 2], dim=2) 79 | 80 | return bboxes_raw, bboxes, scores, clses 81 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | class ToyDataset(Dataset): 7 | """Car dataset.""" 8 | 9 | def __init__(self, img_shape=(256,256),max_radius=64, 10 | num_classes=1,max_objects=5): 11 | super().__init__() 12 | self.img_shape = np.array(img_shape) 13 | self.num_classes = num_classes 14 | self.max_width = 64 15 | self.max_height = 64 16 | self.max_radius = min(img_shape)//4 17 | self.max_objects = max_objects 18 | 19 | 20 | w, h = self.img_shape//4 21 | # prepare mesh center points 22 | x_arr = np.arange(w) + 0.5 23 | y_arr = np.arange(h) + 0.5 24 | self.xy_mesh = 
np.stack(np.meshgrid(x_arr, y_arr)) # [2, h, w] 25 | 26 | def __len__(self): 27 | return 1000 28 | 29 | def __getitem__(self, idx): 30 | 31 | im = np.zeros(self.img_shape,dtype=np.float32) 32 | heatmap = np.zeros((self.num_classes+4,self.img_shape[0]//4,self.img_shape[1]//4),dtype=np.float32) 33 | for _ in range(np.random.randint(0,5)): 34 | x,y = np.random.randint(0,self.img_shape[0]),np.random.randint(0,self.img_shape[1]) 35 | radius = np.random.randint(10,self.max_radius) 36 | im = np.maximum(im,cv2.circle(im,(y,x),radius=radius,color=1,thickness=-1)) 37 | 38 | center = np.array([x,y])/4 39 | x, y = np.floor(center).astype(np.int) 40 | # print('center,wh',center,wh) 41 | 42 | # sigma = gaussian_radius(wh) 43 | # dist_squared = np.sum((self.xy_mesh - center[:, None, None]) ** 2, axis=0) 44 | # gauss = np.exp(-1 * dist_squared / (2 * sigma ** 2)) 45 | # heatmap[0, :, :] = np.maximum(heatmap[0, :, :], gauss) 46 | 47 | heatmap[0,x,y] = 1 48 | 49 | # size 50 | heatmap[-4:-2,x,y] = np.array([2*radius,2*radius]) 51 | 52 | # offset 53 | heatmap[-2:, x,y] = center - np.floor(center) 54 | 55 | return im[None,:,:], heatmap -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def regr_loss(regr, gt_regr, mask): 6 | num = mask.float().sum()*2 7 | 8 | regr = regr[mask == 1] 9 | gt_regr = gt_regr[mask == 1] 10 | regr_loss = F.l1_loss( 11 | regr, gt_regr, size_average=False) 12 | regr_loss = regr_loss / (num + 1e-4) 13 | return regr_loss 14 | 15 | def _neg_loss(pred, gt,alpha=2,beta=4): 16 | ''' Modified focal loss. Exactly the same as CornerNet. 17 | Runs faster and costs a little bit more memory 18 | Arguments: 19 | pred (batch x c x h x w) 20 | gt_regr (batch x c x h x w) 21 | ''' 22 | pos_inds = gt.eq(1).float() 23 | neg_inds = gt.lt(1).float() 24 | 25 | neg_weights = torch.pow(1 - gt, beta) 26 | 27 | loss = 0 28 | 29 | pos_loss = torch.log(pred) * torch.pow(1 - pred, alpha) * pos_inds 30 | neg_loss = torch.log(1 - pred) * torch.pow(pred, alpha) * neg_weights * neg_inds 31 | 32 | num_pos = pos_inds.float().sum() 33 | pos_loss = pos_loss.sum() 34 | neg_loss = neg_loss.sum() 35 | 36 | if num_pos == 0: 37 | loss = loss - neg_loss 38 | else: 39 | loss = loss - (pos_loss + neg_loss) / num_pos 40 | return loss 41 | 42 | def criterion(prediction, true, size_average=True): 43 | 44 | # Binary mask loss 45 | pred_mask = torch.sigmoid(prediction[:, 0]) 46 | 47 | mask_loss = _neg_loss(pred_mask[:,None,:,:],true[:,0:1,:,:]) 48 | 49 | 50 | size_loss_x = regr_loss(prediction[:,-4,:,:], 51 | true[:,-4,:,:],true[:,0]) 52 | size_loss_y = regr_loss(prediction[:,-3,:,:], 53 | true[:,-3,:,:],true[:,0]) 54 | 55 | offset_loss_x = regr_loss(prediction[:,-2,:,:], 56 | true[:,-2,:,:],true[:,0]) 57 | 58 | offset_loss_y = regr_loss(prediction[:,-1,:,:], 59 | true[:,-1,:,:],true[:,0]) 60 | 61 | return mask_loss, size_loss_x+size_loss_y, offset_loss_x+offset_loss_y -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 5 | """3x3 convolution with padding""" 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=dilation, groups=groups, bias=False, dilation=dilation) 8 | 9 | 10 | def 
conv1x1(in_planes, out_planes, stride=1): 11 | """1x1 convolution""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | __constants__ = ['downsample'] 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | __constants__ = ['downsample'] 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 61 | base_width=64, dilation=1, norm_layer=None): 62 | super(Bottleneck, self).__init__() 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm2d 65 | width = int(planes * (base_width / 64.)) * groups 66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 67 | self.conv1 = conv1x1(inplanes, width) 68 | self.bn1 = norm_layer(width) 69 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 70 | self.bn2 = norm_layer(width) 71 | self.conv3 = conv1x1(width, planes * self.expansion) 72 | self.bn3 = norm_layer(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | 94 | out += identity 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 103 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 104 | norm_layer=None): 105 | super(ResNet, self).__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | self._norm_layer = norm_layer 109 | 110 | self.inplanes = 64 111 | self.dilation = 1 112 | if replace_stride_with_dilation is None: 113 | # each element in the tuple indicates if we should replace 114 | # the 2x2 stride with a dilated convolution instead 115 | replace_stride_with_dilation = [False, False, False] 116 | if len(replace_stride_with_dilation) != 3: 117 | raise ValueError("replace_stride_with_dilation should be None " 118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 119 | self.groups = groups 120 | 
self.base_width = width_per_group 121 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = norm_layer(self.inplanes) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 128 | dilate=replace_stride_with_dilation[0]) 129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 130 | dilate=replace_stride_with_dilation[1]) 131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 132 | dilate=replace_stride_with_dilation[2]) 133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 134 | self.fc = nn.Linear(512 * block.expansion, num_classes) 135 | 136 | for m in self.modules(): 137 | if isinstance(m, nn.Conv2d): 138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 140 | nn.init.constant_(m.weight, 1) 141 | nn.init.constant_(m.bias, 0) 142 | 143 | # Zero-initialize the last BN in each residual branch, 144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 146 | if zero_init_residual: 147 | for m in self.modules(): 148 | if isinstance(m, Bottleneck): 149 | nn.init.constant_(m.bn3.weight, 0) 150 | elif isinstance(m, BasicBlock): 151 | nn.init.constant_(m.bn2.weight, 0) 152 | 153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 154 | norm_layer = self._norm_layer 155 | downsample = None 156 | previous_dilation = self.dilation 157 | if dilate: 158 | self.dilation *= stride 159 | stride = 1 160 | if stride != 1 or self.inplanes != planes * block.expansion: 161 | downsample = nn.Sequential( 162 | conv1x1(self.inplanes, planes * block.expansion, stride), 163 | norm_layer(planes * block.expansion), 164 | ) 165 | 166 | layers = [] 167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 168 | self.base_width, previous_dilation, norm_layer)) 169 | self.inplanes = planes * block.expansion 170 | for _ in range(1, blocks): 171 | layers.append(block(self.inplanes, planes, groups=self.groups, 172 | base_width=self.base_width, dilation=self.dilation, 173 | norm_layer=norm_layer)) 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def forward(self, x): 178 | x = self.conv1(x) 179 | x = self.bn1(x) 180 | x = self.relu(x) 181 | x = self.maxpool(x) 182 | 183 | x1 = self.layer1(x) 184 | x2 = self.layer2(x1) 185 | x3 = self.layer3(x2) 186 | x4 = self.layer4(x3) 187 | 188 | return x1,x2,x3,x4 189 | 190 | class Res18UnetCenterNet(nn.Module): 191 | 192 | stride: int = 4 193 | 194 | def __init__(self,n_fg_class=1): 195 | super().__init__() 196 | out_ch = n_fg_class + 4 197 | 198 | self.encoder = ResNet(BasicBlock, [2, 2, 2, 2]) 199 | self.dc1 = nn.ConvTranspose2d(512,256,2,2) 200 | self.dc2 = nn.ConvTranspose2d(512,128,2,2) 201 | self.dc3 = nn.ConvTranspose2d(256,64,2,2) 202 | self.dc4 = nn.Conv2d(128,out_ch,kernel_size=3,stride=1,padding=1) 203 | 204 | def forward(self,x): 205 | x1,x2,x3,x4 = self.encoder(x) 206 | h = self.dc1(x4) 207 | h = self.dc2(torch.cat([x3, h],1)) 208 | h = self.dc3(torch.cat([x2, h],1)) 209 | h = self.dc4(torch.cat([x1, h],1)) 210 | 211 | # force 0-1 value range to scores and offsets 212 | C = h.shape[1] 213 | scores, sizes, offsets = torch.split(h,(C-4,C-3,2),1) 214 | 
215 | offsets = torch.sigmoid(offsets) 216 | 217 | return torch.cat([scores, sizes, offsets],1) 218 | 219 | if __name__=='__main__': 220 | # resnet18 = ResNet(BasicBlock, [2, 2, 2, 2]) 221 | a = torch.zeros((1,1,256,256)) 222 | # x1,x2,x3,x4 = resnet18(a) 223 | 224 | res_unet = Res18UnetCenterNet() 225 | pred = res_unet(a) 226 | print(pred.shape) 227 | print() 228 | -------------------------------------------------------------------------------- /model_orig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 5 | """3x3 convolution with padding""" 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=dilation, groups=groups, bias=False, dilation=dilation) 8 | 9 | 10 | def conv1x1(in_planes, out_planes, stride=1): 11 | """1x1 convolution""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | __constants__ = ['downsample'] 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | __constants__ = ['downsample'] 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 61 | base_width=64, dilation=1, norm_layer=None): 62 | super(Bottleneck, self).__init__() 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm2d 65 | width = int(planes * (base_width / 64.)) * groups 66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 67 | self.conv1 = conv1x1(inplanes, width) 68 | self.bn1 = norm_layer(width) 69 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 70 | self.bn2 = norm_layer(width) 71 | self.conv3 = conv1x1(width, planes * self.expansion) 72 | self.bn3 = norm_layer(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | 94 | out += identity 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 
100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 103 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 104 | norm_layer=None): 105 | super(ResNet, self).__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | self._norm_layer = norm_layer 109 | 110 | self.inplanes = 64 111 | self.dilation = 1 112 | if replace_stride_with_dilation is None: 113 | # each element in the tuple indicates if we should replace 114 | # the 2x2 stride with a dilated convolution instead 115 | replace_stride_with_dilation = [False, False, False] 116 | if len(replace_stride_with_dilation) != 3: 117 | raise ValueError("replace_stride_with_dilation should be None " 118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 119 | self.groups = groups 120 | self.base_width = width_per_group 121 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = norm_layer(self.inplanes) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 128 | dilate=replace_stride_with_dilation[0]) 129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 130 | dilate=replace_stride_with_dilation[1]) 131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 132 | dilate=replace_stride_with_dilation[2]) 133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 134 | self.fc = nn.Linear(512 * block.expansion, num_classes) 135 | 136 | for m in self.modules(): 137 | if isinstance(m, nn.Conv2d): 138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 140 | nn.init.constant_(m.weight, 1) 141 | nn.init.constant_(m.bias, 0) 142 | 143 | # Zero-initialize the last BN in each residual branch, 144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 146 | if zero_init_residual: 147 | for m in self.modules(): 148 | if isinstance(m, Bottleneck): 149 | nn.init.constant_(m.bn3.weight, 0) 150 | elif isinstance(m, BasicBlock): 151 | nn.init.constant_(m.bn2.weight, 0) 152 | 153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 154 | norm_layer = self._norm_layer 155 | downsample = None 156 | previous_dilation = self.dilation 157 | if dilate: 158 | self.dilation *= stride 159 | stride = 1 160 | if stride != 1 or self.inplanes != planes * block.expansion: 161 | downsample = nn.Sequential( 162 | conv1x1(self.inplanes, planes * block.expansion, stride), 163 | norm_layer(planes * block.expansion), 164 | ) 165 | 166 | layers = [] 167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 168 | self.base_width, previous_dilation, norm_layer)) 169 | self.inplanes = planes * block.expansion 170 | for _ in range(1, blocks): 171 | layers.append(block(self.inplanes, planes, groups=self.groups, 172 | base_width=self.base_width, dilation=self.dilation, 173 | norm_layer=norm_layer)) 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def forward(self, x): 178 | x = self.conv1(x) 179 | x = self.bn1(x) 180 | x = self.relu(x) 181 | x = self.maxpool(x) 182 | 183 | x1 = self.layer1(x) 184 | x2 = self.layer2(x1) 185 | x3 = self.layer3(x2) 186 | x4 = self.layer4(x3) 187 | 188 | return x1,x2,x3,x4 189 | 190 | class Res18UnetCenterNet(nn.Module): 191 | 192 | stride: int = 4 193 | 194 | def __init__(self,n_fg_class=1): 195 | super().__init__() 196 | out_ch = n_fg_class + 4 197 | 198 | self.encoder = ResNet(BasicBlock, [2, 2, 2, 2]) 199 | self.dc1 = nn.ConvTranspose2d(512,256,2,2) 200 | self.dc2 = nn.ConvTranspose2d(512,128,2,2) 201 | self.dc3 = nn.ConvTranspose2d(256,64,2,2) 202 | self.dc4 = nn.Conv2d(128,out_ch,kernel_size=3,stride=1,padding=1) 203 | 204 | def forward(self,x): 205 | x1,x2,x3,x4 = self.encoder(x) 206 | h = self.dc1(x4) 207 | h = self.dc2(torch.cat([x3, h],1)) 208 | h = self.dc3(torch.cat([x2, h],1)) 209 | h = self.dc4(torch.cat([x1, h],1)) 210 | 211 | # force 0-1 value range to scores and offsets 212 | C = h.shape[1] 213 | scores, sizes, offsets = torch.split(h,(C-4,C-3,2),1) 214 | 215 | offsets = torch.sigmoid(offsets) 216 | 217 | return torch.cat([scores, sizes, offsets],1) 218 | 219 | if __name__=='__main__': 220 | # resnet18 = ResNet(BasicBlock, [2, 2, 2, 2]) 221 | a = torch.zeros((1,1,256,256)) 222 | # x1,x2,x3,x4 = resnet18(a) 223 | 224 | res_unet = Res18UnetCenterNet() 225 | pred = res_unet(a) 226 | print(pred.shape) 227 | print() 228 | -------------------------------------------------------------------------------- /outputs/epoch_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_0.jpg -------------------------------------------------------------------------------- /outputs/epoch_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_1.jpg -------------------------------------------------------------------------------- /outputs/epoch_10.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_10.jpg -------------------------------------------------------------------------------- /outputs/epoch_11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_11.jpg -------------------------------------------------------------------------------- /outputs/epoch_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_12.jpg -------------------------------------------------------------------------------- /outputs/epoch_13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_13.jpg -------------------------------------------------------------------------------- /outputs/epoch_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_14.jpg -------------------------------------------------------------------------------- /outputs/epoch_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_15.jpg -------------------------------------------------------------------------------- /outputs/epoch_16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_16.jpg -------------------------------------------------------------------------------- /outputs/epoch_17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_17.jpg -------------------------------------------------------------------------------- /outputs/epoch_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_18.jpg -------------------------------------------------------------------------------- /outputs/epoch_19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_19.jpg -------------------------------------------------------------------------------- /outputs/epoch_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_2.jpg -------------------------------------------------------------------------------- /outputs/epoch_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_3.jpg 
-------------------------------------------------------------------------------- /outputs/epoch_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_4.jpg -------------------------------------------------------------------------------- /outputs/epoch_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_5.jpg -------------------------------------------------------------------------------- /outputs/epoch_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_6.jpg -------------------------------------------------------------------------------- /outputs/epoch_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_7.jpg -------------------------------------------------------------------------------- /outputs/epoch_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_8.jpg -------------------------------------------------------------------------------- /outputs/epoch_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_9.jpg -------------------------------------------------------------------------------- /outputs/mask_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/mask_loss.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import Res18UnetCenterNet 3 | from loss import criterion 4 | from dataset import ToyDataset 5 | from ctdet import ctdet_decode 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import cv2 9 | 10 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 11 | 12 | 13 | def train_model(model, optimizer, dataloader, epoch): 14 | model.train() 15 | running_mask_loss, running_size_loss, running_offset_loss = 0, 0, 0 16 | for batch_idx, (img_batch, mask_batch) in enumerate(dataloader): 17 | img_batch = img_batch.to(device) 18 | mask_batch = mask_batch.to(device) 19 | # center_index = center_index.to(device) 20 | 21 | optimizer.zero_grad() 22 | output = model(img_batch) 23 | 24 | mask_loss, size_loss, offset_loss = criterion(output, mask_batch) 25 | 26 | loss = mask_loss + size_loss + offset_loss 27 | 28 | loss.backward() 29 | 30 | optimizer.step() 31 | 32 | running_mask_loss += mask_loss 33 | running_size_loss += size_loss 34 | running_offset_loss += offset_loss 35 | 36 | if batch_idx % 5 == 0: 37 | print(f'\r{running_mask_loss/(batch_idx+1):.3f} {running_size_loss/(batch_idx+1):.3f} {running_offset_loss/(batch_idx+1):.3f}', 38 | end='', flush=True) 39 | 40 
| print('\r', end='', flush=True) 41 | print(f"Epoch: {epoch} mask_loss: {running_mask_loss/(batch_idx): .3f} " 42 | f"size_loss: {running_size_loss/(batch_idx): .3f} offset_loss: {running_offset_loss/(batch_idx): .3f}") 43 | 44 | 45 | @torch.no_grad() 46 | def eval_model(model, dataloader, output_folder, epoch=0, thresh=0.25): 47 | 48 | for (img_batch, mask_batch) in dataloader: 49 | img_batch = img_batch.to(device) 50 | mask_batch = mask_batch.to(device) 51 | predictions = model(img_batch) 52 | bboxes_raw, bboxes, scores, clses = ctdet_decode(predictions[:, 0:1], predictions[:, -4:-2, :, :], 53 | predictions[:, -2:, :, :]) 54 | bboxes = bboxes.long().cpu().numpy() 55 | 56 | for batch_idx, (im, mask, pred) in enumerate(zip(img_batch, mask_batch, predictions)): 57 | im = im.permute(1, 2, 0).cpu().squeeze().numpy()*255 58 | im = np.repeat(im[:, :, None], 3, 2) 59 | 60 | score_pos = [] 61 | for score, bbox, bbox_raw in zip(scores[batch_idx], bboxes[batch_idx], bboxes_raw[batch_idx]): 62 | if score > thresh: 63 | im = np.maximum(im, cv2.rectangle( 64 | im, (bbox[2], bbox[3]), (bbox[0], bbox[1]), (0, 255, 0), 2)) 65 | 66 | # uncomment to visualize bbox without offset correction 67 | # im = np.maximum(im, cv2.rectangle( 68 | # im, (bbox_raw[2], bbox_raw[3]), (bbox_raw[0], bbox_raw[1]), (255, 0, 0), 2)) 69 | score_pos.append((bbox[2]+5, bbox[3]+5, score)) 70 | else: 71 | break 72 | 73 | plt.subplot(2, 3, 1) 74 | plt.title('Image with pred bbox') 75 | plt.imshow(im.astype(np.int)) 76 | for pos_x, pos_y, score in score_pos: 77 | plt.text(pos_x, pos_y, f'{score.item():.2}', fontsize=6, c='r') 78 | 79 | plt.subplot(2, 3, 2) 80 | plt.title('Mask') 81 | plt.imshow(mask[0].cpu().squeeze()) 82 | 83 | plt.subplot(2, 3, 3) 84 | plt.title('Mask Prediction') 85 | plt.imshow(pred[0].cpu().squeeze()) 86 | 87 | plt.subplot(2, 3, 4) 88 | plt.title('Width Prediction') 89 | plt.imshow(pred[1].cpu().squeeze()) 90 | 91 | plt.subplot(2, 3, 5) 92 | plt.title('Height Prediction') 93 | plt.imshow(pred[3].cpu().squeeze()) 94 | 95 | plt.subplot(2, 3, 6) 96 | plt.title('Width offset Prediction') 97 | plt.imshow(pred[4].cpu().squeeze()) 98 | 99 | plt.suptitle(f'Epoch {epoch}') 100 | # plt.show() 101 | # print(f'{output_folder}_{epoch}.jpg') 102 | plt.savefig(f'{output_folder}/epoch_{epoch}.jpg') 103 | plt.close() 104 | break 105 | break 106 | 107 | 108 | if __name__ == '__main__': 109 | import os 110 | 111 | model = Res18UnetCenterNet() 112 | model = model.to(device) 113 | optim = torch.optim.Adam(model.parameters()) 114 | dataset = ToyDataset(img_shape=(128, 128)) 115 | 116 | dataloader = torch.utils.data.DataLoader(dataset=dataset, 117 | batch_size=8, num_workers=0) 118 | 119 | output_folder = './outputs' 120 | os.makedirs(output_folder, exist_ok=True) 121 | for epoch in range(20): 122 | train_model(model, optim, dataloader, epoch) 123 | eval_model(model, dataloader, output_folder, epoch=epoch) 124 | 125 | print() 126 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | def _sigmoid(x): 9 | y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4) 10 | return y 11 | 12 | def _gather_feat(feat, ind, mask=None): 13 | dim = feat.size(2) 14 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) 15 | feat = feat.gather(1, ind) 16 | 
if mask is not None: 17 | mask = mask.unsqueeze(2).expand_as(feat) 18 | feat = feat[mask] 19 | feat = feat.view(-1, dim) 20 | return feat 21 | 22 | def _tranpose_and_gather_feat(feat, ind): 23 | feat = feat.permute(0, 2, 3, 1).contiguous() 24 | feat = feat.view(feat.size(0), -1, feat.size(3)) 25 | feat = _gather_feat(feat, ind) 26 | return feat 27 | 28 | def flip_tensor(x): 29 | return torch.flip(x, [3]) 30 | # tmp = x.detach().cpu().numpy()[..., ::-1].copy() 31 | # return torch.from_numpy(tmp).to(x.device) 32 | 33 | def flip_lr(x, flip_idx): 34 | tmp = x.detach().cpu().numpy()[..., ::-1].copy() 35 | shape = tmp.shape 36 | for e in flip_idx: 37 | tmp[:, e[0], ...], tmp[:, e[1], ...] = \ 38 | tmp[:, e[1], ...].copy(), tmp[:, e[0], ...].copy() 39 | return torch.from_numpy(tmp.reshape(shape)).to(x.device) 40 | 41 | def flip_lr_off(x, flip_idx): 42 | tmp = x.detach().cpu().numpy()[..., ::-1].copy() 43 | shape = tmp.shape 44 | tmp = tmp.reshape(tmp.shape[0], 17, 2, 45 | tmp.shape[2], tmp.shape[3]) 46 | tmp[:, :, 0, :, :] *= -1 47 | for e in flip_idx: 48 | tmp[:, e[0], ...], tmp[:, e[1], ...] = \ 49 | tmp[:, e[1], ...].copy(), tmp[:, e[0], ...].copy() 50 | return torch.from_numpy(tmp.reshape(shape)).to(x.device) --------------------------------------------------------------------------------