├── README.md
├── __pycache__
│   ├── ctdet.cpython-37.pyc
│   ├── dataset.cpython-37.pyc
│   ├── loss.cpython-37.pyc
│   ├── model.cpython-37.pyc
│   └── utils.cpython-37.pyc
├── ctdet.py
├── dataset.py
├── loss.py
├── model.py
├── model_orig.py
├── outputs
│   ├── epoch_0.jpg
│   ├── epoch_1.jpg
│   ├── epoch_10.jpg
│   ├── epoch_11.jpg
│   ├── epoch_12.jpg
│   ├── epoch_13.jpg
│   ├── epoch_14.jpg
│   ├── epoch_15.jpg
│   ├── epoch_16.jpg
│   ├── epoch_17.jpg
│   ├── epoch_18.jpg
│   ├── epoch_19.jpg
│   ├── epoch_2.jpg
│   ├── epoch_3.jpg
│   ├── epoch_4.jpg
│   ├── epoch_5.jpg
│   ├── epoch_6.jpg
│   ├── epoch_7.jpg
│   ├── epoch_8.jpg
│   ├── epoch_9.jpg
│   └── mask_loss.png
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
5 |
6 | This repo contains a minimalist implementation of the paper *Objects as Points* by Zhou et al. I found the approach pretty interesting and novel. It doesn't use anchor boxes and requires minimal post-processing. The essential idea of the paper is to treat objects as points, denoted by their centers, rather than bounding boxes. This idea is elegant and makes intuitive sense. Most importantly, it performs pretty well: the authors were able to achieve high accuracy on multiple tasks, namely object detection on MS COCO, 3D detection on KITTI, and pose estimation!
7 |
8 | ## Dependencies
9 | - Python > 3.6
10 | - PyTorch > 1.0
11 |
12 | If you have installed PyTorch in a conda environment, it should work without much hassle.
13 |
14 | ## Running the model
15 | No need to download any dataset. Just run `python train.py` to start training the model. The model outputs are saved in the `outputs` folder.
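
For example, from the repository root:

```bash
# start training; per-epoch visualizations are written to ./outputs
python train.py
```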
16 |
17 |
18 | ## Model Outputs
19 | The visualization of the model output looks something like this:
20 | 
21 |
22 | Here are a few details about this implementation.
23 |
24 |
25 | ## The Toy dataset
26 | The MS COCO and KITTI datasets are big, and not everyone is interested in understanding them just to do a few experiments. So I decided to write a function that automatically generates a toy dataset of circles. Each generated image contains multiple circles (at most 5), which may overlap.
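
Here is a minimal sketch of how such an image can be generated with OpenCV (the real dataset class in `dataset.py` additionally builds the target heatmap; the numbers below are just illustrative):

```python
import numpy as np
import cv2

# blank 256x256 canvas with a few random, possibly overlapping circles
im = np.zeros((256, 256), dtype=np.float32)
for _ in range(np.random.randint(1, 6)):      # up to 5 circles
    cx = np.random.randint(0, 256)            # random center
    cy = np.random.randint(0, 256)
    radius = np.random.randint(10, 64)        # random radius
    cv2.circle(im, (cx, cy), radius=radius, color=1, thickness=-1)
```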
27 | ## The Model
28 | I didn't want to complicate things, so I used ResNet-18 as the encoder and a UNet-style decoder in the upsampling stages. Say the input to the model is 1x1x256x256 (Batch Size x Channels x Height x Width); then the output of the
29 | model will be 1x(N+4)x64x64. Here 1 is the batch size, N+4 is the number of output channels, and 64 is the downsampled feature-map width. N is the number of object categories; in this implementation there is only one category (circle), so the output is 1x5x64x64. Also, notice that we could have upsampled all the way back to the input resolution, as in UNet, but the authors went with a 64x64 output. An advantage of training on downsampled maps is that it reduces the number of parameters.
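
A quick shape sanity check (this mirrors the `__main__` block of `model.py`):

```python
import torch
from model import Res18UnetCenterNet

model = Res18UnetCenterNet(n_fg_class=1)
x = torch.zeros((1, 1, 256, 256))   # Batch x Channels x Height x Width
out = model(x)
print(out.shape)                    # torch.Size([1, 5, 64, 64])
```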
30 |
31 | Now an obvious question: if the model output is 4 times smaller than the input image (64x64 versus 256x256), how are we going to get a fine prediction of the object location? It seems like we can be off by up to 4 pixels in both height and width. The authors propose a simple solution: they predict an offset map to fine-tune the object position.
32 | So, if the model output is 1x5x64x64, the first channel (1x1x64x64) is the score of an object being present at each pixel, and the second and third channels represent the width and height. If, say, the score at coordinate (30, 23) is 1, it means that some object may be present near coordinate (30\*4, 23\*4) in the original image. The last 2 channels are the offsets; they are used to get a better estimate of the object location.
33 | Also, note that we use a sigmoid activation for the offsets, which ensures that they lie in [0, 1].
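
Concretely, recovering an object's center in the input image from a peak on the 64x64 map works roughly like this (a sketch of what `ctdet_decode` in `ctdet.py` does; the numbers are only illustrative):

```python
# peak location on the 64x64 output map plus the predicted sub-pixel offset
x_feat, y_feat = 30, 23           # coordinates of a high-score pixel
offset_x, offset_y = 0.4, 0.7     # predicted offsets, in [0, 1] thanks to the sigmoid
stride = 4                        # the input image is 4x larger than the output map

x_img = (x_feat + offset_x) * stride    # 121.6 instead of a coarse 120
y_img = (y_feat + offset_y) * stride    # 94.8 instead of a coarse 92
```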
34 |
35 | ## The loss functions
36 |
37 | We use two types of loss:
38 | - Mask Loss
39 |
40 | 
41 |
42 | This is a modified version of the focal loss. Here $\alpha$ and $\beta$ are the focal-loss hyperparameters; we set $\alpha = 2$ and $\beta = 4$. In the original implementation, the authors splat each ground-truth center with an object-centered Gaussian kernel
43 |
44 | $$Y_{xyc} = \exp\left(-\frac{(x - \tilde{p}_x)^2 + (y - \tilde{p}_y)^2}{2\sigma_p^2}\right)$$
45 |
46 | Here $\tilde{p}_x$ and $\tilde{p}_y$ are the low-resolution center of the object and $\sigma_p$ is an object-size-adaptive standard deviation. For this toy dataset, using the Gaussian kernel didn't seem to make much difference, so in the interest of keeping it minimal I just set $Y_{xyc} = 1$ at the object center.
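
In code, the modified focal loss boils down to something like this (a condensed sketch of `_neg_loss` from `loss.py`):

```python
import torch

def modified_focal_loss(pred, gt, alpha=2, beta=4):
    """pred, gt: (batch, classes, h, w); pred must already be in (0, 1)."""
    pos_inds = gt.eq(1).float()              # pixels that are object centers
    neg_inds = gt.lt(1).float()              # all remaining pixels
    neg_weights = torch.pow(1 - gt, beta)    # down-weight pixels close to a center

    pos_loss = torch.log(pred) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, alpha) * neg_weights * neg_inds

    num_pos = pos_inds.sum()
    if num_pos == 0:
        return -neg_loss.sum()
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos
```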
47 |
48 | - Size and offset Loss
49 |
50 | For the size and offset predictions, we basically use an L1 loss normalized by the number of object center points.
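
Roughly, for each regression channel (a sketch of `regr_loss` from `loss.py`, where the mask is the ground-truth center map):

```python
import torch.nn.functional as F

def l1_regr_loss(pred, gt, mask):
    # pred, gt: (batch, h, w) regression maps; mask: (batch, h, w) center map
    num = mask.float().sum() * 2                       # as in the repo's regr_loss
    loss = F.l1_loss(pred[mask == 1], gt[mask == 1], reduction='sum')
    return loss / (num + 1e-4)
```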
51 |
52 | ## Training details
53 |
54 | I used the Adam optimizer with default settings. The model starts giving reasonable output after training for around 10 epochs.
55 | We perform non-maximum suppression (using max pooling) on the model's class maps and select the top 100 predictions. We use the offset maps to fine-tune the object coordinates. During visualization, I set a hard threshold of 0.25, i.e. a bounding box is drawn only if the model confidence is > 0.25. If you don't care about false alarms, you should probably set a lower threshold.
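
The non-maximum suppression step is just a max-pooling trick (this mirrors `_nms` in `ctdet.py`): a pixel survives only if it equals the maximum of its 3x3 neighbourhood.

```python
import torch.nn.functional as F

def nms_via_maxpool(heat, kernel=3):
    # keep only local maxima of the heatmap; everything else is zeroed out
    pad = (kernel - 1) // 2
    hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep
```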
56 |
57 | ## Conclusion
58 |
59 | I found the paper to be pretty interesting and elegant. I wanted to make the idea more accessible, so I am sharing this code. I hope this minimalist implementation helps you understand the paper better.
--------------------------------------------------------------------------------
/__pycache__/ctdet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/ctdet.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/ctdet.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/post_process.py
3 | import torch
4 | import torch.nn as nn
5 | from utils import _gather_feat, _tranpose_and_gather_feat
6 |
7 |
8 | def _nms(heat, kernel=3):
9 | pad = (kernel - 1) // 2
10 |
11 | hmax = nn.functional.max_pool2d(
12 | heat, (kernel, kernel), stride=1, padding=pad)
13 | keep = (hmax == heat).float()
14 | return heat * keep
15 |
16 |
17 | def _topk(scores, K=40):
18 | batch, cat, height, width = scores.size()
19 |
20 | topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
21 |
22 | topk_inds = topk_inds % (height * width)
23 | topk_ys = (topk_inds / width).int().float()
24 | topk_xs = (topk_inds % width).int().float()
25 |
26 | topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
27 | topk_clses = (topk_ind / K).int()
28 | topk_inds = _gather_feat(
29 | topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
30 | topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
31 | topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
32 |
33 | return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
34 |
35 |
36 | def ctdet_decode(heat, wh, reg=None, cat_spec_wh=False, K=100, model_scale=4):
37 | batch, cat, height, width = heat.size()
38 |
39 | heat = torch.sigmoid(heat)
40 | # perform nms on heatmaps
41 | heat = _nms(heat)
42 |
43 | scores, inds, clses, ys, xs = _topk(heat, K=K)
44 |
45 | xs_raw = xs.view(batch, K, 1) + 0.5
46 | ys_raw = ys.view(batch, K, 1) + 0.5
47 |
48 | if reg is not None:
49 | reg = _tranpose_and_gather_feat(reg, inds)
50 | reg = reg.view(batch, K, 2)
51 | # check if it's correct and not reversed with ys
52 | xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
53 | ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
54 |
55 | wh = _tranpose_and_gather_feat(wh, inds)
56 | if cat_spec_wh:
57 | wh = wh.view(batch, K, cat, 2)
58 | clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 2).long()
59 | wh = wh.gather(2, clses_ind).view(batch, K, 2)
60 | else:
61 | wh = wh.view(batch, K, 2)
62 |
63 | xs, ys = xs*model_scale, ys*model_scale
64 |
65 | xs_raw, ys_raw = xs_raw*model_scale, ys_raw*model_scale
66 |
67 | clses = clses.view(batch, K, 1).float()
68 | scores = scores.view(batch, K, 1)
69 |
70 | bboxes = torch.cat([xs - wh[..., 1:2] / 2,
71 | ys - wh[..., 0:1] / 2,
72 | xs + wh[..., 1:2] / 2,
73 | ys + wh[..., 0:1] / 2], dim=2)
74 |
75 | bboxes_raw = torch.cat([xs_raw - wh[..., 1:2] / 2,
76 | ys_raw - wh[..., 0:1] / 2,
77 | xs_raw + wh[..., 1:2] / 2,
78 | ys_raw + wh[..., 0:1] / 2], dim=2)
79 |
80 | return bboxes_raw, bboxes, scores, clses
81 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import cv2
3 | import numpy as np
4 |
5 |
6 | class ToyDataset(Dataset):
7 | """Car dataset."""
8 |
9 | def __init__(self, img_shape=(256,256),max_radius=64,
10 | num_classes=1,max_objects=5):
11 | super().__init__()
12 | self.img_shape = np.array(img_shape)
13 | self.num_classes = num_classes
14 | self.max_width = 64
15 | self.max_height = 64
16 | self.max_radius = min(img_shape)//4
17 | self.max_objects = max_objects
18 |
19 |
20 | w, h = self.img_shape//4
21 | # prepare mesh center points
22 | x_arr = np.arange(w) + 0.5
23 | y_arr = np.arange(h) + 0.5
24 | self.xy_mesh = np.stack(np.meshgrid(x_arr, y_arr)) # [2, h, w]
25 |
26 | def __len__(self):
27 | return 1000
28 |
29 | def __getitem__(self, idx):
30 |
31 | im = np.zeros(self.img_shape,dtype=np.float32)
32 | heatmap = np.zeros((self.num_classes+4,self.img_shape[0]//4,self.img_shape[1]//4),dtype=np.float32)
33 | for _ in range(np.random.randint(0,5)):
34 | x,y = np.random.randint(0,self.img_shape[0]),np.random.randint(0,self.img_shape[1])
35 | radius = np.random.randint(10,self.max_radius)
36 | im = np.maximum(im,cv2.circle(im,(y,x),radius=radius,color=1,thickness=-1))
37 |
38 | center = np.array([x,y])/4
39 | x, y = np.floor(center).astype(int)
40 | # print('center,wh',center,wh)
41 |
42 | # sigma = gaussian_radius(wh)
43 | # dist_squared = np.sum((self.xy_mesh - center[:, None, None]) ** 2, axis=0)
44 | # gauss = np.exp(-1 * dist_squared / (2 * sigma ** 2))
45 | # heatmap[0, :, :] = np.maximum(heatmap[0, :, :], gauss)
46 |
47 | heatmap[0,x,y] = 1
48 |
49 | # size
50 | heatmap[-4:-2,x,y] = np.array([2*radius,2*radius])
51 |
52 | # offset
53 | heatmap[-2:, x,y] = center - np.floor(center)
54 |
55 | return im[None,:,:], heatmap
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def regr_loss(regr, gt_regr, mask):
6 | num = mask.float().sum()*2
7 |
8 | regr = regr[mask == 1]
9 | gt_regr = gt_regr[mask == 1]
10 | regr_loss = F.l1_loss(
11 | regr, gt_regr, reduction='sum')
12 | regr_loss = regr_loss / (num + 1e-4)
13 | return regr_loss
14 |
15 | def _neg_loss(pred, gt,alpha=2,beta=4):
16 | ''' Modified focal loss. Exactly the same as CornerNet.
17 | Runs faster and costs a little bit more memory
18 | Arguments:
19 | pred (batch x c x h x w)
20 | gt_regr (batch x c x h x w)
21 | '''
22 | pos_inds = gt.eq(1).float()
23 | neg_inds = gt.lt(1).float()
24 |
25 | neg_weights = torch.pow(1 - gt, beta)
26 |
27 | loss = 0
28 |
29 | pos_loss = torch.log(pred) * torch.pow(1 - pred, alpha) * pos_inds
30 | neg_loss = torch.log(1 - pred) * torch.pow(pred, alpha) * neg_weights * neg_inds
31 |
32 | num_pos = pos_inds.float().sum()
33 | pos_loss = pos_loss.sum()
34 | neg_loss = neg_loss.sum()
35 |
36 | if num_pos == 0:
37 | loss = loss - neg_loss
38 | else:
39 | loss = loss - (pos_loss + neg_loss) / num_pos
40 | return loss
41 |
42 | def criterion(prediction, true, size_average=True):
43 |
44 | # Binary mask loss
45 | pred_mask = torch.sigmoid(prediction[:, 0])
46 |
47 | mask_loss = _neg_loss(pred_mask[:,None,:,:],true[:,0:1,:,:])
48 |
49 |
50 | size_loss_x = regr_loss(prediction[:,-4,:,:],
51 | true[:,-4,:,:],true[:,0])
52 | size_loss_y = regr_loss(prediction[:,-3,:,:],
53 | true[:,-3,:,:],true[:,0])
54 |
55 | offset_loss_x = regr_loss(prediction[:,-2,:,:],
56 | true[:,-2,:,:],true[:,0])
57 |
58 | offset_loss_y = regr_loss(prediction[:,-1,:,:],
59 | true[:,-1,:,:],true[:,0])
60 |
61 | return mask_loss, size_loss_x+size_loss_y, offset_loss_x+offset_loss_y
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
5 | """3x3 convolution with padding"""
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=dilation, groups=groups, bias=False, dilation=dilation)
8 |
9 |
10 | def conv1x1(in_planes, out_planes, stride=1):
11 | """1x1 convolution"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 | __constants__ = ['downsample']
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
20 | base_width=64, dilation=1, norm_layer=None):
21 | super(BasicBlock, self).__init__()
22 | if norm_layer is None:
23 | norm_layer = nn.BatchNorm2d
24 | if groups != 1 or base_width != 64:
25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
26 | if dilation > 1:
27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = norm_layer(planes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv2 = conv3x3(planes, planes)
33 | self.bn2 = norm_layer(planes)
34 | self.downsample = downsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | identity = x
39 |
40 | out = self.conv1(x)
41 | out = self.bn1(out)
42 | out = self.relu(out)
43 |
44 | out = self.conv2(out)
45 | out = self.bn2(out)
46 |
47 | if self.downsample is not None:
48 | identity = self.downsample(x)
49 |
50 | out += identity
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 | __constants__ = ['downsample']
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
61 | base_width=64, dilation=1, norm_layer=None):
62 | super(Bottleneck, self).__init__()
63 | if norm_layer is None:
64 | norm_layer = nn.BatchNorm2d
65 | width = int(planes * (base_width / 64.)) * groups
66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
67 | self.conv1 = conv1x1(inplanes, width)
68 | self.bn1 = norm_layer(width)
69 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
70 | self.bn2 = norm_layer(width)
71 | self.conv3 = conv1x1(width, planes * self.expansion)
72 | self.bn3 = norm_layer(planes * self.expansion)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | identity = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | identity = self.downsample(x)
93 |
94 | out += identity
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 |
100 | class ResNet(nn.Module):
101 |
102 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
103 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
104 | norm_layer=None):
105 | super(ResNet, self).__init__()
106 | if norm_layer is None:
107 | norm_layer = nn.BatchNorm2d
108 | self._norm_layer = norm_layer
109 |
110 | self.inplanes = 64
111 | self.dilation = 1
112 | if replace_stride_with_dilation is None:
113 | # each element in the tuple indicates if we should replace
114 | # the 2x2 stride with a dilated convolution instead
115 | replace_stride_with_dilation = [False, False, False]
116 | if len(replace_stride_with_dilation) != 3:
117 | raise ValueError("replace_stride_with_dilation should be None "
118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
119 | self.groups = groups
120 | self.base_width = width_per_group
121 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
122 | bias=False)
123 | self.bn1 = norm_layer(self.inplanes)
124 | self.relu = nn.ReLU(inplace=True)
125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
126 | self.layer1 = self._make_layer(block, 64, layers[0])
127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
128 | dilate=replace_stride_with_dilation[0])
129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
130 | dilate=replace_stride_with_dilation[1])
131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
132 | dilate=replace_stride_with_dilation[2])
133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
134 | self.fc = nn.Linear(512 * block.expansion, num_classes)
135 |
136 | for m in self.modules():
137 | if isinstance(m, nn.Conv2d):
138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
140 | nn.init.constant_(m.weight, 1)
141 | nn.init.constant_(m.bias, 0)
142 |
143 | # Zero-initialize the last BN in each residual branch,
144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
146 | if zero_init_residual:
147 | for m in self.modules():
148 | if isinstance(m, Bottleneck):
149 | nn.init.constant_(m.bn3.weight, 0)
150 | elif isinstance(m, BasicBlock):
151 | nn.init.constant_(m.bn2.weight, 0)
152 |
153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
154 | norm_layer = self._norm_layer
155 | downsample = None
156 | previous_dilation = self.dilation
157 | if dilate:
158 | self.dilation *= stride
159 | stride = 1
160 | if stride != 1 or self.inplanes != planes * block.expansion:
161 | downsample = nn.Sequential(
162 | conv1x1(self.inplanes, planes * block.expansion, stride),
163 | norm_layer(planes * block.expansion),
164 | )
165 |
166 | layers = []
167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
168 | self.base_width, previous_dilation, norm_layer))
169 | self.inplanes = planes * block.expansion
170 | for _ in range(1, blocks):
171 | layers.append(block(self.inplanes, planes, groups=self.groups,
172 | base_width=self.base_width, dilation=self.dilation,
173 | norm_layer=norm_layer))
174 |
175 | return nn.Sequential(*layers)
176 |
177 | def forward(self, x):
178 | x = self.conv1(x)
179 | x = self.bn1(x)
180 | x = self.relu(x)
181 | x = self.maxpool(x)
182 |
183 | x1 = self.layer1(x)
184 | x2 = self.layer2(x1)
185 | x3 = self.layer3(x2)
186 | x4 = self.layer4(x3)
187 |
188 | return x1,x2,x3,x4
189 |
190 | class Res18UnetCenterNet(nn.Module):
191 |
192 | stride: int = 4
193 |
194 | def __init__(self,n_fg_class=1):
195 | super().__init__()
196 | out_ch = n_fg_class + 4
197 |
198 | self.encoder = ResNet(BasicBlock, [2, 2, 2, 2])
199 | self.dc1 = nn.ConvTranspose2d(512,256,2,2)
200 | self.dc2 = nn.ConvTranspose2d(512,128,2,2)
201 | self.dc3 = nn.ConvTranspose2d(256,64,2,2)
202 | self.dc4 = nn.Conv2d(128,out_ch,kernel_size=3,stride=1,padding=1)
203 |
204 | def forward(self,x):
205 | x1,x2,x3,x4 = self.encoder(x)
206 | h = self.dc1(x4)
207 | h = self.dc2(torch.cat([x3, h],1))
208 | h = self.dc3(torch.cat([x2, h],1))
209 | h = self.dc4(torch.cat([x1, h],1))
210 |
211 | # force 0-1 value range to scores and offsets
212 | C = h.shape[1]
213 | scores, sizes, offsets = torch.split(h, (C-4, 2, 2), 1)  # class scores, sizes, offsets
214 |
215 | offsets = torch.sigmoid(offsets)
216 |
217 | return torch.cat([scores, sizes, offsets],1)
218 |
219 | if __name__=='__main__':
220 | # resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
221 | a = torch.zeros((1,1,256,256))
222 | # x1,x2,x3,x4 = resnet18(a)
223 |
224 | res_unet = Res18UnetCenterNet()
225 | pred = res_unet(a)
226 | print(pred.shape)
227 | print()
228 |
--------------------------------------------------------------------------------
/model_orig.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
5 | """3x3 convolution with padding"""
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=dilation, groups=groups, bias=False, dilation=dilation)
8 |
9 |
10 | def conv1x1(in_planes, out_planes, stride=1):
11 | """1x1 convolution"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 | __constants__ = ['downsample']
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
20 | base_width=64, dilation=1, norm_layer=None):
21 | super(BasicBlock, self).__init__()
22 | if norm_layer is None:
23 | norm_layer = nn.BatchNorm2d
24 | if groups != 1 or base_width != 64:
25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
26 | if dilation > 1:
27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = norm_layer(planes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv2 = conv3x3(planes, planes)
33 | self.bn2 = norm_layer(planes)
34 | self.downsample = downsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | identity = x
39 |
40 | out = self.conv1(x)
41 | out = self.bn1(out)
42 | out = self.relu(out)
43 |
44 | out = self.conv2(out)
45 | out = self.bn2(out)
46 |
47 | if self.downsample is not None:
48 | identity = self.downsample(x)
49 |
50 | out += identity
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 | __constants__ = ['downsample']
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
61 | base_width=64, dilation=1, norm_layer=None):
62 | super(Bottleneck, self).__init__()
63 | if norm_layer is None:
64 | norm_layer = nn.BatchNorm2d
65 | width = int(planes * (base_width / 64.)) * groups
66 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
67 | self.conv1 = conv1x1(inplanes, width)
68 | self.bn1 = norm_layer(width)
69 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
70 | self.bn2 = norm_layer(width)
71 | self.conv3 = conv1x1(width, planes * self.expansion)
72 | self.bn3 = norm_layer(planes * self.expansion)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | identity = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | identity = self.downsample(x)
93 |
94 | out += identity
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 |
100 | class ResNet(nn.Module):
101 |
102 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
103 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
104 | norm_layer=None):
105 | super(ResNet, self).__init__()
106 | if norm_layer is None:
107 | norm_layer = nn.BatchNorm2d
108 | self._norm_layer = norm_layer
109 |
110 | self.inplanes = 64
111 | self.dilation = 1
112 | if replace_stride_with_dilation is None:
113 | # each element in the tuple indicates if we should replace
114 | # the 2x2 stride with a dilated convolution instead
115 | replace_stride_with_dilation = [False, False, False]
116 | if len(replace_stride_with_dilation) != 3:
117 | raise ValueError("replace_stride_with_dilation should be None "
118 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
119 | self.groups = groups
120 | self.base_width = width_per_group
121 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
122 | bias=False)
123 | self.bn1 = norm_layer(self.inplanes)
124 | self.relu = nn.ReLU(inplace=True)
125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
126 | self.layer1 = self._make_layer(block, 64, layers[0])
127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
128 | dilate=replace_stride_with_dilation[0])
129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
130 | dilate=replace_stride_with_dilation[1])
131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
132 | dilate=replace_stride_with_dilation[2])
133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
134 | self.fc = nn.Linear(512 * block.expansion, num_classes)
135 |
136 | for m in self.modules():
137 | if isinstance(m, nn.Conv2d):
138 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
139 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
140 | nn.init.constant_(m.weight, 1)
141 | nn.init.constant_(m.bias, 0)
142 |
143 | # Zero-initialize the last BN in each residual branch,
144 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
145 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
146 | if zero_init_residual:
147 | for m in self.modules():
148 | if isinstance(m, Bottleneck):
149 | nn.init.constant_(m.bn3.weight, 0)
150 | elif isinstance(m, BasicBlock):
151 | nn.init.constant_(m.bn2.weight, 0)
152 |
153 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
154 | norm_layer = self._norm_layer
155 | downsample = None
156 | previous_dilation = self.dilation
157 | if dilate:
158 | self.dilation *= stride
159 | stride = 1
160 | if stride != 1 or self.inplanes != planes * block.expansion:
161 | downsample = nn.Sequential(
162 | conv1x1(self.inplanes, planes * block.expansion, stride),
163 | norm_layer(planes * block.expansion),
164 | )
165 |
166 | layers = []
167 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
168 | self.base_width, previous_dilation, norm_layer))
169 | self.inplanes = planes * block.expansion
170 | for _ in range(1, blocks):
171 | layers.append(block(self.inplanes, planes, groups=self.groups,
172 | base_width=self.base_width, dilation=self.dilation,
173 | norm_layer=norm_layer))
174 |
175 | return nn.Sequential(*layers)
176 |
177 | def forward(self, x):
178 | x = self.conv1(x)
179 | x = self.bn1(x)
180 | x = self.relu(x)
181 | x = self.maxpool(x)
182 |
183 | x1 = self.layer1(x)
184 | x2 = self.layer2(x1)
185 | x3 = self.layer3(x2)
186 | x4 = self.layer4(x3)
187 |
188 | return x1,x2,x3,x4
189 |
190 | class Res18UnetCenterNet(nn.Module):
191 |
192 | stride: int = 4
193 |
194 | def __init__(self,n_fg_class=1):
195 | super().__init__()
196 | out_ch = n_fg_class + 4
197 |
198 | self.encoder = ResNet(BasicBlock, [2, 2, 2, 2])
199 | self.dc1 = nn.ConvTranspose2d(512,256,2,2)
200 | self.dc2 = nn.ConvTranspose2d(512,128,2,2)
201 | self.dc3 = nn.ConvTranspose2d(256,64,2,2)
202 | self.dc4 = nn.Conv2d(128,out_ch,kernel_size=3,stride=1,padding=1)
203 |
204 | def forward(self,x):
205 | x1,x2,x3,x4 = self.encoder(x)
206 | h = self.dc1(x4)
207 | h = self.dc2(torch.cat([x3, h],1))
208 | h = self.dc3(torch.cat([x2, h],1))
209 | h = self.dc4(torch.cat([x1, h],1))
210 |
211 | # force 0-1 value range to scores and offsets
212 | C = h.shape[1]
213 | scores, sizes, offsets = torch.split(h,(C-4,C-3,2),1)
214 |
215 | offsets = torch.sigmoid(offsets)
216 |
217 | return torch.cat([scores, sizes, offsets],1)
218 |
219 | if __name__=='__main__':
220 | # resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
221 | a = torch.zeros((1,1,256,256))
222 | # x1,x2,x3,x4 = resnet18(a)
223 |
224 | res_unet = Res18UnetCenterNet()
225 | pred = res_unet(a)
226 | print(pred.shape)
227 | print()
228 |
--------------------------------------------------------------------------------
/outputs/epoch_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_0.jpg
--------------------------------------------------------------------------------
/outputs/epoch_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_1.jpg
--------------------------------------------------------------------------------
/outputs/epoch_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_10.jpg
--------------------------------------------------------------------------------
/outputs/epoch_11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_11.jpg
--------------------------------------------------------------------------------
/outputs/epoch_12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_12.jpg
--------------------------------------------------------------------------------
/outputs/epoch_13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_13.jpg
--------------------------------------------------------------------------------
/outputs/epoch_14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_14.jpg
--------------------------------------------------------------------------------
/outputs/epoch_15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_15.jpg
--------------------------------------------------------------------------------
/outputs/epoch_16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_16.jpg
--------------------------------------------------------------------------------
/outputs/epoch_17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_17.jpg
--------------------------------------------------------------------------------
/outputs/epoch_18.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_18.jpg
--------------------------------------------------------------------------------
/outputs/epoch_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_19.jpg
--------------------------------------------------------------------------------
/outputs/epoch_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_2.jpg
--------------------------------------------------------------------------------
/outputs/epoch_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_3.jpg
--------------------------------------------------------------------------------
/outputs/epoch_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_4.jpg
--------------------------------------------------------------------------------
/outputs/epoch_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_5.jpg
--------------------------------------------------------------------------------
/outputs/epoch_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_6.jpg
--------------------------------------------------------------------------------
/outputs/epoch_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_7.jpg
--------------------------------------------------------------------------------
/outputs/epoch_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_8.jpg
--------------------------------------------------------------------------------
/outputs/epoch_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/epoch_9.jpg
--------------------------------------------------------------------------------
/outputs/mask_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sidml/Understanding-Centernet/8c971592c53f87cbd9350c7322f407e1319967b1/outputs/mask_loss.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model import Res18UnetCenterNet
3 | from loss import criterion
4 | from dataset import ToyDataset
5 | from ctdet import ctdet_decode
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import cv2
9 |
10 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
11 |
12 |
13 | def train_model(model, optimizer, dataloader, epoch):
14 | model.train()
15 | running_mask_loss, running_size_loss, running_offset_loss = 0, 0, 0
16 | for batch_idx, (img_batch, mask_batch) in enumerate(dataloader):
17 | img_batch = img_batch.to(device)
18 | mask_batch = mask_batch.to(device)
19 | # center_index = center_index.to(device)
20 |
21 | optimizer.zero_grad()
22 | output = model(img_batch)
23 |
24 | mask_loss, size_loss, offset_loss = criterion(output, mask_batch)
25 |
26 | loss = mask_loss + size_loss + offset_loss
27 |
28 | loss.backward()
29 |
30 | optimizer.step()
31 |
32 | running_mask_loss += mask_loss
33 | running_size_loss += size_loss
34 | running_offset_loss += offset_loss
35 |
36 | if batch_idx % 5 == 0:
37 | print(f'\r{running_mask_loss/(batch_idx+1):.3f} {running_size_loss/(batch_idx+1):.3f} {running_offset_loss/(batch_idx+1):.3f}',
38 | end='', flush=True)
39 |
40 | print('\r', end='', flush=True)
41 | print(f"Epoch: {epoch} mask_loss: {running_mask_loss/(batch_idx): .3f} "
42 | f"size_loss: {running_size_loss/(batch_idx): .3f} offset_loss: {running_offset_loss/(batch_idx): .3f}")
43 |
44 |
45 | @torch.no_grad()
46 | def eval_model(model, dataloader, output_folder, epoch=0, thresh=0.25):
47 |
48 | for (img_batch, mask_batch) in dataloader:
49 | img_batch = img_batch.to(device)
50 | mask_batch = mask_batch.to(device)
51 | predictions = model(img_batch)
52 | bboxes_raw, bboxes, scores, clses = ctdet_decode(predictions[:, 0:1], predictions[:, -4:-2, :, :],
53 | predictions[:, -2:, :, :])
54 | bboxes = bboxes.long().cpu().numpy()
55 |
56 | for batch_idx, (im, mask, pred) in enumerate(zip(img_batch, mask_batch, predictions)):
57 | im = im.permute(1, 2, 0).cpu().squeeze().numpy()*255
58 | im = np.repeat(im[:, :, None], 3, 2)
59 |
60 | score_pos = []
61 | for score, bbox, bbox_raw in zip(scores[batch_idx], bboxes[batch_idx], bboxes_raw[batch_idx]):
62 | if score > thresh:
63 | im = np.maximum(im, cv2.rectangle(
64 | im, (bbox[2], bbox[3]), (bbox[0], bbox[1]), (0, 255, 0), 2))
65 |
66 | # uncomment to visualize bbox without offset correction
67 | # im = np.maximum(im, cv2.rectangle(
68 | # im, (bbox_raw[2], bbox_raw[3]), (bbox_raw[0], bbox_raw[1]), (255, 0, 0), 2))
69 | score_pos.append((bbox[2]+5, bbox[3]+5, score))
70 | else:
71 | break
72 |
73 | plt.subplot(2, 3, 1)
74 | plt.title('Image with pred bbox')
75 | plt.imshow(im.astype(int))
76 | for pos_x, pos_y, score in score_pos:
77 | plt.text(pos_x, pos_y, f'{score.item():.2}', fontsize=6, c='r')
78 |
79 | plt.subplot(2, 3, 2)
80 | plt.title('Mask')
81 | plt.imshow(mask[0].cpu().squeeze())
82 |
83 | plt.subplot(2, 3, 3)
84 | plt.title('Mask Prediction')
85 | plt.imshow(pred[0].cpu().squeeze())
86 |
87 | plt.subplot(2, 3, 4)
88 | plt.title('Width Prediction')
89 | plt.imshow(pred[1].cpu().squeeze())
90 |
91 | plt.subplot(2, 3, 5)
92 | plt.title('Height Prediction')
93 | plt.imshow(pred[3].cpu().squeeze())
94 |
95 | plt.subplot(2, 3, 6)
96 | plt.title('Width offset Prediction')
97 | plt.imshow(pred[4].cpu().squeeze())
98 |
99 | plt.suptitle(f'Epoch {epoch}')
100 | # plt.show()
101 | # print(f'{output_folder}_{epoch}.jpg')
102 | plt.savefig(f'{output_folder}/epoch_{epoch}.jpg')
103 | plt.close()
104 | break
105 | break
106 |
107 |
108 | if __name__ == '__main__':
109 | import os
110 |
111 | model = Res18UnetCenterNet()
112 | model = model.to(device)
113 | optim = torch.optim.Adam(model.parameters())
114 | dataset = ToyDataset(img_shape=(128, 128))
115 |
116 | dataloader = torch.utils.data.DataLoader(dataset=dataset,
117 | batch_size=8, num_workers=0)
118 |
119 | output_folder = './outputs'
120 | os.makedirs(output_folder, exist_ok=True)
121 | for epoch in range(20):
122 | train_model(model, optim, dataloader, epoch)
123 | eval_model(model, dataloader, output_folder, epoch=epoch)
124 |
125 | print()
126 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | def _sigmoid(x):
9 | y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
10 | return y
11 |
12 | def _gather_feat(feat, ind, mask=None):
13 | dim = feat.size(2)
14 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
15 | feat = feat.gather(1, ind)
16 | if mask is not None:
17 | mask = mask.unsqueeze(2).expand_as(feat)
18 | feat = feat[mask]
19 | feat = feat.view(-1, dim)
20 | return feat
21 |
22 | def _tranpose_and_gather_feat(feat, ind):
23 | feat = feat.permute(0, 2, 3, 1).contiguous()
24 | feat = feat.view(feat.size(0), -1, feat.size(3))
25 | feat = _gather_feat(feat, ind)
26 | return feat
27 |
28 | def flip_tensor(x):
29 | return torch.flip(x, [3])
30 | # tmp = x.detach().cpu().numpy()[..., ::-1].copy()
31 | # return torch.from_numpy(tmp).to(x.device)
32 |
33 | def flip_lr(x, flip_idx):
34 | tmp = x.detach().cpu().numpy()[..., ::-1].copy()
35 | shape = tmp.shape
36 | for e in flip_idx:
37 | tmp[:, e[0], ...], tmp[:, e[1], ...] = \
38 | tmp[:, e[1], ...].copy(), tmp[:, e[0], ...].copy()
39 | return torch.from_numpy(tmp.reshape(shape)).to(x.device)
40 |
41 | def flip_lr_off(x, flip_idx):
42 | tmp = x.detach().cpu().numpy()[..., ::-1].copy()
43 | shape = tmp.shape
44 | tmp = tmp.reshape(tmp.shape[0], 17, 2,
45 | tmp.shape[2], tmp.shape[3])
46 | tmp[:, :, 0, :, :] *= -1
47 | for e in flip_idx:
48 | tmp[:, e[0], ...], tmp[:, e[1], ...] = \
49 | tmp[:, e[1], ...].copy(), tmp[:, e[0], ...].copy()
50 | return torch.from_numpy(tmp.reshape(shape)).to(x.device)
--------------------------------------------------------------------------------