├── README.md
├── demo.py
├── pp_liteseg.py
├── ppliteseg_paddlepaddle.png
└── results.jpg
/README.md:
--------------------------------------------------------------------------------
1 | # PPLiteSeg.pytorch
2 | A PyTorch implementation of the SOTA real-time segmentation network PP-LiteSeg.
3 |
4 | | Model | Backbone | Training Iters | Train Resolution | Test Resolution | mIoU | mIoU (flip) | mIoU (ms+flip) |
5 | | ------------ | -------- | -------------- | ---------------- | --------------- | ------ | ----------- | -------------- |
6 | | PP-LiteSeg-T | STDC1 | 160000 | 1024x512 | 1024x512 | 73.10% | 73.89% | - |
7 | | PP-LiteSeg-T | STDC1 | 160000 | 1024x512 | 1536x768 | 76.03% | 76.74% | - |
8 | | PP-LiteSeg-T | STDC1 | 160000 | 1024x512 | 2048x1024 | 77.04% | 77.73% | 77.46% |
9 | | PP-LiteSeg-B | STDC2 | 160000 | 1024x512 | 1024x512 | 75.25% | 75.65% | - |
10 | | PP-LiteSeg-B | STDC2 | 160000 | 1024x512 | 1536x768 | 78.75% | 79.23% | - |
11 | | PP-LiteSeg-B | STDC2 | 160000 | 1024x512 | 2048x1024 | 79.04% | 79.52% | 79.85% |
12 |
13 | Here we convert the model and weights of PP-LiteSeg-B (1024x512) from PaddlePaddle to PyTorch.
14 |
15 | ## Model&Weight
16 |
17 | pp_liteseg.py: the PyTorch model definition
18 |
19 | [ppliteset_pp2torch_cityscape_pretrained.pth](https://github.com/midasklr/PPLiteSeg.pytorch/releases/download/weights/ppliteset_pp2torch_cityscape_pretrained.pth): the Cityscapes-pretrained weights, trained with PaddleSeg and converted to PyTorch
20 |
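A minimal loading sketch (mirroring what demo.py does; it assumes the weights file has been downloaded next to pp_liteseg.py):

```python
import torch
from pp_liteseg import PPLiteSeg

model = PPLiteSeg()  # STDC2 backbone, 19 Cityscapes classes by default
ckpt = torch.load("ppliteset_pp2torch_cityscape_pretrained.pth", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt  # demo.py handles both layouts
model.load_state_dict(state_dict)
model.eval()
```
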
21 | Demo of PaddleSeg:
22 |
23 | 
24 |
25 | Demo of PyTorch:
26 |
27 | 
28 |
29 | ## Difference
30 |
31 | ### upsample
32 |
33 | PaddleSeg uses "bilinear" upsampling; in this PyTorch port I use "nearest" mode so the model can be converted to TensorRT.
34 |
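If you do not need TensorRT export and want to match PaddleSeg exactly, the fusion modules' resize mode can be switched through the constructor argument defined in pp_liteseg.py; a small sketch:

```python
from pp_liteseg import PPLiteSeg

model_trt = PPLiteSeg(resize_mode='nearest')            # default here, TensorRT-friendly
model_paddle_like = PPLiteSeg(resize_mode='bilinear')   # matches the original PaddleSeg decoder
# note: PPContextModule's internal upsampling stays 'nearest' as written in pp_liteseg.py
```
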
35 | ### BatchNorm
36 |
37 | PaddleSeg: momentum=0.9
38 |
39 | PyTorch default: momentum=0.1. Note that the two frameworks define BatchNorm momentum complementarily (Paddle weights the running statistics, PyTorch weights the new batch statistics), so Paddle's 0.9 corresponds exactly to PyTorch's 0.1.
40 |
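So when porting a layer by hand the conversion is simply `momentum_torch = 1 - momentum_paddle`; a one-line sketch (the channel count 64 is just an example):

```python
import torch.nn as nn

paddle_momentum = 0.9
bn = nn.BatchNorm2d(64, momentum=1.0 - paddle_momentum)  # 0.1, i.e. PyTorch's default
```
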
41 | ## Train
42 |
43 | Use [ddrnet](https://github.com/midasklr/DDRNet.Pytorch) to train this model; set the loss coefficient of each of the three seg heads (1/8, 1/16 and 1/32) to 1 while training, as sketched below.
44 |
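A minimal sketch of that equally weighted multi-head loss (the CrossEntropyLoss, the ignore index 255 and the dummy shapes are assumptions, not part of this repo; in training mode the model returns three logit maps already upsampled to the input size):

```python
import torch
import torch.nn as nn
from pp_liteseg import PPLiteSeg

model = PPLiteSeg().train()
criterion = nn.CrossEntropyLoss(ignore_index=255)   # assumed Cityscapes ignore label
coeffs = [1.0, 1.0, 1.0]                            # one coefficient per seg head (1/8, 1/16, 1/32)

images = torch.randn(2, 3, 512, 1024)               # dummy batch
labels = torch.randint(0, 19, (2, 512, 1024))

logits_list = model(images)                         # three logit maps at input resolution
loss = sum(c * criterion(logits, labels) for c, logits in zip(coeffs, logits_list))
loss.backward()
```
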
45 | ## Demo
46 |
47 | See demo.py, e.g. `python demo.py --image mainz_000001_009328_leftImg8bit.png --weights ppliteset_pp2torch_cityscape_pretrained.pth`; the blended prediction is saved to results.jpg.
48 |
49 | ## Reference
50 |
51 | 1. https://github.com/PaddlePaddle/PaddleSeg/tree/develop/configs/pp_liteseg
52 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 | import random
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 |
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | from pp_liteseg import PPLiteSeg
23 | import cv2
24 | import torch.nn.functional as F
25 |
26 |
27 |
28 |
29 | def parse_args():
30 |     parser = argparse.ArgumentParser(description='PP-LiteSeg demo')
31 |
32 | parser.add_argument('--image',
33 | help='test image path',
34 | default="mainz_000001_009328_leftImg8bit.png",
35 | type=str)
36 | parser.add_argument('--weights',
37 | help='cityscape pretrained weights',
38 | default="ppliteset_pp2torch_cityscape_pretrained.pth",
39 | type=str)
40 | parser.add_argument('opts',
41 | help="Modify config options using the command-line",
42 | default=None,
43 | nargs=argparse.REMAINDER)
44 |
45 | args = parser.parse_args()
46 |
47 | return args
48 |
49 |
50 | def colorEncode(labelmap, colors, mode='RGB'):
51 | labelmap = labelmap.astype('int')
52 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
53 | dtype=np.uint8)
54 | for label in np.unique(labelmap):
55 | if label < 0:
56 | continue
57 | labelmap_rgb = labelmap_rgb + (labelmap == label)[:, :, np.newaxis] * \
58 | np.tile(colors[label],
59 | (labelmap.shape[0], labelmap.shape[1], 1))
60 |
61 | if mode == 'BGR':
62 | return labelmap_rgb[:, :, ::-1]
63 | else:
64 | return labelmap_rgb
65 |
66 |
67 | def main():
68 | base_size = 512
69 | wh = 2
70 |     mean = [0.5, 0.5, 0.5]
71 | std = [0.5, 0.5, 0.5]
72 | args = parse_args()
73 |
74 | model = PPLiteSeg()
75 |
76 | model.eval()
77 |
78 | print("ppliteseg:", model)
79 | ckpt = torch.load(args.weights)
80 | model = model.cuda()
81 | if 'state_dict' in ckpt:
82 | model.load_state_dict(ckpt['state_dict'])
83 | else:
84 | model.load_state_dict(ckpt)
85 |
86 | img = cv2.imread(args.image)
87 | imgor = img.copy()
88 | img = cv2.resize(img, (wh * base_size, base_size))
89 | image = img.astype(np.float32)[:, :, ::-1]
90 | image = image / 255.0
91 | image -= mean
92 | image /= std
93 |
94 | image = image.transpose((2, 0, 1))
95 | image = torch.from_numpy(image)
96 |
97 | # image = image.permute((2, 0, 1))
98 |
99 | image = image.unsqueeze(0)
100 | image = image.cuda()
101 | start = time.time()
102 | out = model(image)
103 | end = time.time()
104 | print("infer time:", end - start, " s")
105 | out = out[0].squeeze(dim=0)
106 | outadd = F.softmax(out, dim=0)
107 | outadd = torch.argmax(outadd, dim=0)
108 | predadd = outadd.detach().cpu().numpy()
109 | pred = np.int32(predadd)
110 | colors = np.random.randint(0, 255, 19 * 3)
111 | colors = np.reshape(colors, (19, 3))
112 | # colorize prediction
113 | pred_color = colorEncode(pred, colors).astype(np.uint8)
114 | pred_color = cv2.resize(pred_color,(imgor.shape[1],imgor.shape[0]))
115 |
116 | im_vis = cv2.addWeighted(imgor, 0.7, pred_color, 0.3, 0)
117 | cv2.imwrite("results.jpg", im_vis)
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
--------------------------------------------------------------------------------
/pp_liteseg.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import math
19 |
20 |
21 | class ConvBN(nn.Module):
22 | def __init__(self,
23 | in_channels,
24 | out_channels,
25 | kernel_size,
26 | stride=1,
27 | padding=1,
28 | bias = False,
29 | **kwargs):
30 | super().__init__()
31 | self._conv = nn.Conv2d(
32 | in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2 if padding else 0,
33 | bias = bias, **kwargs)
34 | self._batch_norm = nn.BatchNorm2d(out_channels, momentum=0.1)
35 |
36 | def forward(self, x):
37 | x = self._conv(x)
38 | x = self._batch_norm(x)
39 | return x
40 |
41 |
42 | class ConvBNReLU(nn.Module):
43 | def __init__(self,
44 | in_channels,
45 | out_channels,
46 | kernel_size=3,
47 | stride = 1,
48 | padding=1,
49 | bias = False,
50 | **kwargs):
51 | super().__init__()
52 |
53 | self._conv = nn.Conv2d(
54 | in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2 if padding else 0, bias = bias,**kwargs)
55 |
56 | self._batch_norm = nn.BatchNorm2d(out_channels, momentum=0.1)
57 | self._relu = nn.ReLU(inplace=True)
58 |
59 | def forward(self, x):
60 | x = self._conv(x)
61 | x = self._batch_norm(x)
62 | x = self._relu(x)
63 | return x
64 |
65 |
66 | class ConvBNRelu(nn.Module):
67 | def __init__(self,
68 | in_channels,
69 | out_channels,
70 | kernel_size=3,
71 | stride = 1,
72 | padding=1,
73 | bias = False,
74 | **kwargs):
75 | super().__init__()
76 |
77 | self.conv = nn.Conv2d(
78 | in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2 if padding else 0, bias = bias, **kwargs)
79 |
80 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.1)
81 | self.relu = nn.ReLU(inplace=True)
82 |
83 | def forward(self, x):
84 | x = self.conv(x)
85 | x = self.bn(x)
86 | x = self.relu(x)
87 | return x
88 |
89 |
90 | def avg_max_reduce_channel_helper(x, use_concat=True):
91 | # Reduce hw by avg and max, only support single input
92 | assert not isinstance(x, (list, tuple))
93 | # print("x before mean and max:", x.shape)
94 | mean_value = torch.mean(x, dim=1, keepdim=True)
95 | max_value = torch.max(x, dim=1, keepdim=True)[0]
96 | # mean_value = mean_value.unsqueeze(0)
97 | # print("mean max:", mean_value.shape, max_value.shape)
98 |
99 | if use_concat:
100 |         res = torch.cat([mean_value, max_value], dim=1)
101 | else:
102 | res = [mean_value, max_value]
103 | return res
104 |
105 |
106 | def avg_max_reduce_channel(x):
107 | # Reduce hw by avg and max
108 | # Return cat([avg_ch_0, max_ch_0, avg_ch_1, max_ch_1, ...])
109 | if not isinstance(x, (list, tuple)):
110 | return avg_max_reduce_channel_helper(x)
111 | elif len(x) == 1:
112 | return avg_max_reduce_channel_helper(x[0])
113 | else:
114 | res = []
115 | for xi in x:
116 | # print(xi.shape)
117 | res.extend(avg_max_reduce_channel_helper(xi, False))
118 | # print("res:\n",)
119 | # for it in res:
120 | # print(it.shape)
121 | return torch.cat(res, dim=1)
122 |
123 |
124 | class UAFM(nn.Module):
125 | """
126 | The base of Unified Attention Fusion Module.
127 | Args:
128 | x_ch (int): The channel of x tensor, which is the low level feature.
129 | y_ch (int): The channel of y tensor, which is the high level feature.
130 | out_ch (int): The channel of output tensor.
131 | ksize (int, optional): The kernel size of the conv for x tensor. Default: 3.
132 |         resize_mode (str, optional): The resize mode for upsampling the y tensor. Default: nearest.
133 | """
134 |
135 | def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='nearest'):
136 | super().__init__()
137 |
138 | self.conv_x = ConvBNReLU(
139 | x_ch, y_ch, kernel_size=ksize, padding=ksize // 2, bias=False)
140 | self.conv_out = ConvBNReLU(
141 | y_ch, out_ch, kernel_size=3, padding=1, bias=False)
142 | self.resize_mode = resize_mode
143 |
144 | def check(self, x, y):
145 | # print("x dim:",x.ndim)
146 | assert x.ndim == 4 and y.ndim == 4
147 | x_h, x_w = x.shape[2:]
148 | y_h, y_w = y.shape[2:]
149 | assert x_h >= y_h and x_w >= y_w
150 |
151 | def prepare(self, x, y):
152 | x = self.prepare_x(x, y)
153 | y = self.prepare_y(x, y)
154 | return x, y
155 |
156 | def prepare_x(self, x, y):
157 | x = self.conv_x(x)
158 | return x
159 |
160 | def prepare_y(self, x, y):
161 | y_up = F.interpolate(y, x.shape[2:], mode=self.resize_mode)
162 | return y_up
163 |
164 | def fuse(self, x, y):
165 | out = x + y
166 | out = self.conv_out(out)
167 | return out
168 |
169 | def forward(self, x, y):
170 | """
171 | Args:
172 | x (Tensor): The low level feature.
173 | y (Tensor): The high level feature.
174 | """
175 | # print("x,y shape:",x.shape, y.shape)
176 | self.check(x, y)
177 | x, y = self.prepare(x, y)
178 | out = self.fuse(x, y)
179 | return out
180 |
181 |
182 | class UAFM_SpAtten(UAFM):
183 | """
184 | The UAFM with spatial attention, which uses mean and max values.
185 | Args:
186 | x_ch (int): The channel of x tensor, which is the low level feature.
187 | y_ch (int): The channel of y tensor, which is the high level feature.
188 | out_ch (int): The channel of output tensor.
189 | ksize (int, optional): The kernel size of the conv for x tensor. Default: 3.
190 |         resize_mode (str, optional): The resize mode for upsampling the y tensor. Default: nearest.
191 | """
192 |
193 | def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='nearest'):
194 | super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode)
195 |
196 | self.conv_xy_atten = nn.Sequential(
197 | ConvBNReLU(
198 | 4, 2, kernel_size=3, padding=1, bias=False),
199 | ConvBN(
200 | 2, 1, kernel_size=3, padding=1, bias=False))
201 |
202 | def fuse(self, x, y):
203 | """
204 | Args:
205 | x (Tensor): The low level feature.
206 | y (Tensor): The high level feature.
207 | """
208 | # print("x, y shape:",x.shape, y.shape)
209 | atten = avg_max_reduce_channel([x, y])
210 |         atten = torch.sigmoid(self.conv_xy_atten(atten))
211 |
212 | out = x * atten + y * (1 - atten)
213 | out = self.conv_out(out)
214 | return out
215 |
216 |
217 | class CatBottleneck(nn.Module):
218 | def __init__(self, in_planes, out_planes, block_num=3, stride=1):
219 | super(CatBottleneck, self).__init__()
220 | assert block_num > 1, "block number should be larger than 1."
221 | self.conv_list = nn.ModuleList()
222 | self.stride = stride
223 | if stride == 2:
224 | self.avd_layer = nn.Sequential(
225 | nn.Conv2d(
226 | out_planes // 2,
227 | out_planes // 2,
228 | kernel_size=3,
229 | stride=2,
230 | padding=1,
231 | groups=out_planes // 2,
232 | bias=False),
233 | nn.BatchNorm2d(out_planes // 2, momentum=0.1), )
234 | self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
235 | stride = 1
236 |
237 | for idx in range(block_num):
238 | if idx == 0:
239 | self.conv_list.append(
240 | ConvBNRelu(
241 | in_planes, out_planes // 2, kernel_size=1))
242 | elif idx == 1 and block_num == 2:
243 | self.conv_list.append(
244 | ConvBNRelu(
245 | out_planes // 2, out_planes // 2, stride=stride))
246 | elif idx == 1 and block_num > 2:
247 | self.conv_list.append(
248 | ConvBNRelu(
249 | out_planes // 2, out_planes // 4, stride=stride))
250 | elif idx < block_num - 1:
251 | self.conv_list.append(
252 | ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
253 | // int(math.pow(2, idx + 1))))
254 | else:
255 | self.conv_list.append(
256 | ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
257 | // int(math.pow(2, idx))))
258 |
259 | def forward(self, x):
260 | out_list = []
261 | out1 = self.conv_list[0](x)
262 | for idx, conv in enumerate(self.conv_list[1:]):
263 | if idx == 0:
264 | if self.stride == 2:
265 | out = conv(self.avd_layer(out1))
266 | else:
267 | out = conv(out1)
268 | else:
269 | out = conv(out)
270 | out_list.append(out)
271 |
272 | if self.stride == 2:
273 | out1 = self.skip(out1)
274 | out_list.insert(0, out1)
275 | out = torch.cat(out_list, dim=1)
276 | return out
277 |
278 |
279 | class AddBottleneck(nn.Module):
280 | def __init__(self, in_planes, out_planes, block_num=3, stride=1):
281 | super(AddBottleneck, self).__init__()
282 | assert block_num > 1, "block number should be larger than 1."
283 | self.conv_list = nn.ModuleList()
284 | self.stride = stride
285 | if stride == 2:
286 | self.avd_layer = nn.Sequential(
287 | nn.Conv2d(
288 | out_planes // 2,
289 | out_planes // 2,
290 | kernel_size=3,
291 | stride=2,
292 | padding=1,
293 | groups=out_planes // 2,
294 | bias=False),
295 |                 nn.BatchNorm2d(out_planes // 2, momentum=0.1), )
296 | self.skip = nn.Sequential(
297 | nn.Conv2d(
298 | in_planes,
299 | in_planes,
300 | kernel_size=3,
301 | stride=2,
302 | padding=1,
303 | groups=in_planes,
304 |                     bias=False),
305 | nn.BatchNorm2d(in_planes, momentum=0.1),
306 | nn.Conv2d(
307 | in_planes, out_planes, kernel_size=1, bias=False),
308 | nn.BatchNorm2d(out_planes, momentum=0.1), )
309 | stride = 1
310 |
311 | for idx in range(block_num):
312 | if idx == 0:
313 | self.conv_list.append(
314 | ConvBNRelu(
315 |                     in_planes, out_planes // 2, kernel_size=1))
316 | elif idx == 1 and block_num == 2:
317 | self.conv_list.append(
318 | ConvBNRelu(
319 | out_planes // 2, out_planes // 2, stride=stride))
320 | elif idx == 1 and block_num > 2:
321 | self.conv_list.append(
322 | ConvBNRelu(
323 | out_planes // 2, out_planes // 4, stride=stride))
324 | elif idx < block_num - 1:
325 | self.conv_list.append(
326 | ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
327 | // int(math.pow(2, idx + 1))))
328 | else:
329 | self.conv_list.append(
330 | ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
331 | // int(math.pow(2, idx))))
332 |
333 | def forward(self, x):
334 | out_list = []
335 | out = x
336 | for idx, conv in enumerate(self.conv_list):
337 | if idx == 0 and self.stride == 2:
338 | out = self.avd_layer(conv(out))
339 | else:
340 | out = conv(out)
341 | out_list.append(out)
342 | if self.stride == 2:
343 | x = self.skip(x)
344 | return torch.cat(out_list, dim=1) + x
345 |
346 |
347 | class STDCNet(nn.Module):
348 | """
349 | The STDCNet implementation based on Pytorch.
350 |
351 | The original article refers to Meituan
352 | Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation."
353 | (https://arxiv.org/abs/2104.13188)
354 |
355 | Args:
356 | base(int, optional): base channels. Default: 64.
357 |         layers(list, optional): layers numbers list. It determines STDC block numbers of STDCNet's stage3/4/5. Default: [4, 5, 3].
358 | block_num(int,optional): block_num of features block. Default: 4.
359 | type(str,optional): feature fusion method "cat"/"add". Default: "cat".
360 | num_classes(int, optional): class number for image classification. Default: 1000.
361 | dropout(float,optional): dropout ratio. if >0,use dropout ratio. Default: 0.20.
362 | use_conv_last(bool,optional): whether to use the last ConvBNReLU layer . Default: False.
363 | pretrained(str, optional): the path of pretrained model.
364 | """
365 |
366 | def __init__(self,
367 | base=64,
368 | layers=[4, 5, 3],
369 | block_num=4,
370 | type="cat",
371 | num_classes=1000,
372 | dropout=0.20,
373 | use_conv_last=False,
374 | pretrained=None):
375 | super(STDCNet, self).__init__()
376 | if type == "cat":
377 | block = CatBottleneck
378 | elif type == "add":
379 | block = AddBottleneck
380 | self.use_conv_last = use_conv_last
381 | self.feat_channels = [base // 2, base, base * 4, base * 8, base * 16]
382 | self.features = self._make_layers(base, layers, block_num, block)
383 | self.conv_last = ConvBNRelu(base * 16, max(1024, base * 16), 1, 1)
384 |
385 | if (layers == [4, 5, 3]): # stdc1446
386 | self.x2 = nn.Sequential(self.features[:1])
387 | self.x4 = nn.Sequential(self.features[1:2])
388 | self.x8 = nn.Sequential(self.features[2:6])
389 | self.x16 = nn.Sequential(self.features[6:11])
390 | self.x32 = nn.Sequential(self.features[11:])
391 | elif (layers == [2, 2, 2]): # stdc813
392 | self.x2 = nn.Sequential(self.features[:1])
393 | self.x4 = nn.Sequential(self.features[1:2])
394 | self.x8 = nn.Sequential(self.features[2:4])
395 | self.x16 = nn.Sequential(self.features[4:6])
396 | self.x32 = nn.Sequential(self.features[6:])
397 | else:
398 | raise NotImplementedError(
399 | "model with layers:{} is not implemented!".format(layers))
400 |
401 | self.pretrained = pretrained
402 | # self.init_weight()
403 |
404 | def forward(self, x):
405 | """
406 | forward function for feature extract.
407 | """
408 | feat2 = self.x2(x)
409 | feat4 = self.x4(feat2)
410 | feat8 = self.x8(feat4)
411 | feat16 = self.x16(feat8)
412 | feat32 = self.x32(feat16)
413 | if self.use_conv_last:
414 | feat32 = self.conv_last(feat32)
415 | return feat2, feat4, feat8, feat16, feat32
416 |
417 | def _make_layers(self, base, layers, block_num, block):
418 | features = []
419 | features += [ConvBNRelu(3, base // 2, 3, 2)]
420 | features += [ConvBNRelu(base // 2, base, 3, 2)]
421 |
422 | for i, layer in enumerate(layers):
423 | for j in range(layer):
424 | if i == 0 and j == 0:
425 | features.append(block(base, base * 4, block_num, 2))
426 | elif j == 0:
427 | features.append(
428 | block(base * int(math.pow(2, i + 1)), base * int(
429 | math.pow(2, i + 2)), block_num, 2))
430 | else:
431 | features.append(
432 | block(base * int(math.pow(2, i + 2)), base * int(
433 | math.pow(2, i + 2)), block_num, 1))
434 |
435 | return nn.Sequential(*features)
436 |
437 | # def init_weight(self):
438 | # for layer in self.sublayers():
439 | # if isinstance(layer, nn.Conv2D):
440 | # param_init.normal_init(layer.weight, std=0.001)
441 | # elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
442 | # param_init.constant_init(layer.weight, value=1.0)
443 | # param_init.constant_init(layer.bias, value=0.0)
444 | # if self.pretrained is not None:
445 | # utils.load_pretrained_model(self, self.pretrained)
446 |
447 |
448 | def STDC2(**kwargs):
449 | model = STDCNet(base=64, layers=[4, 5, 3], **kwargs)
450 | return model
451 |
452 |
453 | class PPLiteSeg(nn.Module):
454 | """
455 | The PP_LiteSeg implementation based on Pytorch.
456 |
457 | The original article refers to "Juncai Peng, Yi Liu, Shiyu Tang, Yuying Hao, Lutao Chu,
458 | Guowei Chen, Zewu Wu, Zeyu Chen, Zhiliang Yu, Yuning Du, Qingqing Dang,Baohua Lai,
459 | Qiwen Liu, Xiaoguang Hu, Dianhai Yu, Yanjun Ma. PP-LiteSeg: A Superior Real-Time Semantic
460 | Segmentation Model. https://arxiv.org/abs/2204.02681".
461 |
462 | Args:
463 | num_classes (int): The number of target classes.
464 |         backbone (nn.Module): Backbone network, such as stdc1net and resnet18. The backbone must
465 |             have feat_channels, of which the length is 5.
466 | backbone_indices (List(int), optional): The values indicate the indices of output of backbone.
467 | Default: [2, 3, 4].
468 |         arm_type (str, optional): The type of attention refinement module. Default: UAFM_SpAtten.
469 | cm_bin_sizes (List(int), optional): The bin size of context module. Default: [1,2,4].
470 | cm_out_ch (int, optional): The output channel of the last context module. Default: 128.
471 | arm_out_chs (List(int), optional): The out channels of each arm module. Default: [64, 96, 128].
472 | seg_head_inter_chs (List(int), optional): The intermediate channels of segmentation head.
473 | Default: [64, 64, 64].
474 |         resize_mode (str, optional): The resize mode for the upsampling operation in decoder.
475 |             Default: nearest.
476 |         pretrained (str, optional): The path or url of pretrained model. Default: False.
477 |
478 | """
479 |
480 | def __init__(self,
481 | num_classes = 19,
482 | backbone = STDC2(),
483 | backbone_indices=[2, 3, 4],
484 | arm_type='UAFM_SpAtten',
485 | cm_bin_sizes=[1, 2, 4],
486 | cm_out_ch=128,
487 | arm_out_chs=[64, 96, 128],
488 | seg_head_inter_chs=[64, 64, 64],
489 | resize_mode='nearest',
490 | pretrained=False):
491 | super().__init__()
492 |
493 | # backbone
494 |         assert hasattr(backbone, 'feat_channels'), \
495 |             "The backbone should have feat_channels."
496 | assert len(backbone.feat_channels) >= len(backbone_indices), \
497 | f"The length of input backbone_indices ({len(backbone_indices)}) should not be" \
498 | f"greater than the length of feat_channels ({len(backbone.feat_channels)})."
499 | assert len(backbone.feat_channels) > max(backbone_indices), \
500 | f"The max value ({max(backbone_indices)}) of backbone_indices should be " \
501 | f"less than the length of feat_channels ({len(backbone.feat_channels)})."
502 | self.backbone = backbone
503 |
504 |         assert len(backbone_indices) > 1, "The length of backbone_indices " \
505 | "should be greater than 1"
506 | self.backbone_indices = backbone_indices # [..., x16_id, x32_id]
507 | backbone_out_chs = [backbone.feat_channels[i] for i in backbone_indices]
508 |
509 | # head
510 | if len(arm_out_chs) == 1:
511 | arm_out_chs = arm_out_chs * len(backbone_indices)
512 | assert len(arm_out_chs) == len(backbone_indices), "The length of " \
513 | "arm_out_chs and backbone_indices should be equal"
514 |
515 | self.ppseg_head = PPLiteSegHead(backbone_out_chs, arm_out_chs,
516 | cm_bin_sizes, cm_out_ch, arm_type,
517 | resize_mode)
518 |
519 | if len(seg_head_inter_chs) == 1:
520 | seg_head_inter_chs = seg_head_inter_chs * len(backbone_indices)
521 | assert len(seg_head_inter_chs) == len(backbone_indices), "The length of " \
522 | "seg_head_inter_chs and backbone_indices should be equal"
523 | self.seg_heads = nn.ModuleList() # [..., head_16, head32]
524 | print("arm_out_chs:",arm_out_chs, " ; seg_head_inter_chs:",seg_head_inter_chs)
525 | for in_ch, mid_ch in zip(arm_out_chs, seg_head_inter_chs):
526 | self.seg_heads.append(SegHead(in_ch, mid_ch, num_classes))
527 |
528 | # pretrained
529 | self.pretrained = pretrained
530 | # self.init_weight()
531 |
532 | def forward(self, x):
533 | x_hw = x.shape[2:]
534 | # print("x_hw:",x_hw)
535 |
536 | feats_backbone = self.backbone(x) # [x2, x4, x8, x16, x32]
537 | # print(type(feats_backbone))
538 | assert len(feats_backbone) >= len(self.backbone_indices), \
539 | f"The nums of backbone feats ({len(feats_backbone)}) should be greater or " \
540 | f"equal than the nums of backbone_indices ({len(self.backbone_indices)})"
541 |
542 | feats_selected = [feats_backbone[i] for i in self.backbone_indices]
543 |
544 | feats_head = self.ppseg_head(feats_selected) # [..., x8, x16, x32]
545 |
546 | if self.training:
547 | logit_list = []
548 |
549 | for x, seg_head in zip(feats_head, self.seg_heads):
550 | x = seg_head(x)
551 | logit_list.append(x)
552 |
553 | logit_list = [
554 | F.interpolate(
555 | x, x_hw, mode='bilinear', align_corners=None)
556 | for x in logit_list
557 | ]
558 | else:
559 | x = self.seg_heads[0](feats_head[0])
560 | # print("x:",x.shape)
561 | x = F.interpolate(x, x_hw, mode='bilinear', align_corners=None)
562 | logit_list = [x]
563 |
564 | return logit_list
565 |
566 | # def init_weight(self):
567 | # if self.pretrained is not None:
568 | # utils.load_entire_model(self, self.pretrained)
569 |
570 |
571 | class PPLiteSegHead(nn.Module):
572 | """
573 | The head of PPLiteSeg.
574 |
575 | Args:
576 | backbone_out_chs (List(Tensor)): The channels of output tensors in the backbone.
577 | arm_out_chs (List(int)): The out channels of each arm module.
578 | cm_bin_sizes (List(int)): The bin size of context module.
579 | cm_out_ch (int): The output channel of the last context module.
580 | arm_type (str): The type of attention refinement module.
581 | resize_mode (str): The resize mode for the upsampling operation in decoder.
582 | """
583 |
584 | def __init__(self, backbone_out_chs, arm_out_chs, cm_bin_sizes, cm_out_ch,
585 | arm_type, resize_mode):
586 | super().__init__()
587 |
588 | self.cm = PPContextModule(backbone_out_chs[-1], cm_out_ch, cm_out_ch,
589 | cm_bin_sizes)
590 |
591 | # assert hasattr(layers, arm_type), \
592 | # "Not support arm_type ({})".format(arm_type)
593 | arm_class = eval(arm_type)
594 |
595 | self.arm_list = nn.ModuleList() # [..., arm8, arm16, arm32]
596 | for i in range(len(backbone_out_chs)):
597 | low_chs = backbone_out_chs[i]
598 | high_ch = cm_out_ch if i == len(
599 | backbone_out_chs) - 1 else arm_out_chs[i + 1]
600 | out_ch = arm_out_chs[i]
601 | arm = arm_class(
602 | low_chs, high_ch, out_ch, ksize=3, resize_mode=resize_mode)
603 | self.arm_list.append(arm)
604 |
605 | def forward(self, in_feat_list):
606 | """
607 | Args:
608 | in_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32].
609 | x2, x4 and x8 are optional.
610 | Returns:
611 | out_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32].
612 | x2, x4 and x8 are optional.
613 | The length of in_feat_list and out_feat_list are the same.
614 | """
615 |
616 | high_feat = self.cm(in_feat_list[-1])
617 | out_feat_list = []
618 |
619 | for i in reversed(range(len(in_feat_list))):
620 | low_feat = in_feat_list[i]
621 | arm = self.arm_list[i]
622 | high_feat = arm(low_feat, high_feat)
623 | out_feat_list.insert(0, high_feat)
624 |
625 | return out_feat_list
626 |
627 |
628 | class PPContextModule(nn.Module):
629 | """
630 | Simple Context module.
631 |
632 | Args:
633 | in_channels (int): The number of input channels to pyramid pooling module.
634 | inter_channels (int): The number of inter channels to pyramid pooling module.
635 | out_channels (int): The number of output channels after pyramid pooling module.
636 | bin_sizes (tuple, optional): The out size of pooled feature maps. Default: (1, 3).
637 | align_corners (bool): An argument of F.interpolate. It should be set to False
638 | when the output size of feature is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
639 | """
640 |
641 | def __init__(self,
642 | in_channels,
643 | inter_channels,
644 | out_channels,
645 | bin_sizes,
646 | align_corners=None):
647 | super().__init__()
648 |
649 | self.stages = nn.ModuleList([
650 | self._make_stage(in_channels, inter_channels, size)
651 | for size in bin_sizes
652 | ])
653 |
654 | self.conv_out = ConvBNReLU(
655 | in_channels=inter_channels,
656 | out_channels=out_channels,
657 | kernel_size=3,
658 | padding=1,
659 | bias=True)
660 |
661 | self.align_corners = align_corners
662 |
663 | def _make_stage(self, in_channels, out_channels, size):
664 | prior = nn.AdaptiveAvgPool2d(output_size=size)
665 | conv = ConvBNReLU(
666 | in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
667 | return nn.Sequential(prior, conv)
668 |
669 | def forward(self, input):
670 | out = None
671 | input_shape = input.shape[2:]
672 |
673 | for stage in self.stages:
674 | x = stage(input)
675 | x = F.interpolate(
676 | x,
677 | input_shape,
678 | mode='nearest',
679 | align_corners=self.align_corners)
680 | if out is None:
681 | out = x
682 | else:
683 | out += x
684 |
685 | out = self.conv_out(out)
686 | return out
687 |
688 |
689 | class SegHead(nn.Module):
690 | def __init__(self, in_chan, mid_chan, n_classes):
691 | super().__init__()
692 | self.conv = ConvBNReLU(
693 | in_chan,
694 | mid_chan,
695 | kernel_size=3,
696 | stride=1,
697 | padding=1,
698 | bias=False)
699 | # print("="*100)
700 | # print("out:",mid_chan, "n_classes:",n_classes)
701 | self.conv_out = nn.Conv2d(
702 | mid_chan, n_classes, kernel_size=1, bias=False)
703 |
704 | def forward(self, x):
705 | x = self.conv(x)
706 | x = self.conv_out(x)
707 | return x
708 |
709 | #
710 | # def get_seg_model(**kwargs):
711 | # model = PPLiteSeg(pretrained=False)
712 | # return model
713 |
--------------------------------------------------------------------------------
/ppliteseg_paddlepaddle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/PPLiteSeg.pytorch/888892f047a6a02b4cd88ba2e4924df1693464e5/ppliteseg_paddlepaddle.png
--------------------------------------------------------------------------------
/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/PPLiteSeg.pytorch/888892f047a6a02b4cd88ba2e4924df1693464e5/results.jpg
--------------------------------------------------------------------------------