├── README.md
├── demo.py
├── pp_liteseg.py
├── ppliteseg_paddlepaddle.png
└── results.jpg

/README.md:
--------------------------------------------------------------------------------
1 | # PPLiteSeg.pytorch
2 | A PyTorch implementation of the SOTA real-time segmentation network PP-LiteSeg.
3 | 
4 | | Model | Backbone | Training Iters | Train Resolution | Test Resolution | mIoU | mIoU (flip) | mIoU (ms+flip) |
5 | | ------------ | -------- | -------------- | ---------------- | --------------- | ------ | ----------- | -------------- |
6 | | PP-LiteSeg-T | STDC1 | 160000 | 1024x512 | 1024x512 | 73.10% | 73.89% | - |
7 | | PP-LiteSeg-T | STDC1 | 160000 | 1024x512 | 1536x768 | 76.03% | 76.74% | - |
8 | | PP-LiteSeg-T | STDC1 | 160000 | 1024x512 | 2048x1024 | 77.04% | 77.73% | 77.46% |
9 | | PP-LiteSeg-B | STDC2 | 160000 | 1024x512 | 1024x512 | 75.25% | 75.65% | - |
10 | | PP-LiteSeg-B | STDC2 | 160000 | 1024x512 | 1536x768 | 78.75% | 79.23% | - |
11 | | PP-LiteSeg-B | STDC2 | 160000 | 1024x512 | 2048x1024 | 79.04% | 79.52% | 79.85% |
12 | 
13 | Here we convert the model and weights of PP-LiteSeg-B (1024x512) from PaddlePaddle to PyTorch.
14 | 
15 | ## Model&Weight
16 | 
17 | pp_liteseg.py : the PyTorch model
18 | 
19 | [ppliteset_pp2torch_cityscape_pretrained.pth](https://github.com/midasklr/PPLiteSeg.pytorch/releases/download/weights/ppliteset_pp2torch_cityscape_pretrained.pth): the Cityscapes pretrained weights trained with PaddleSeg
20 | 
21 | Demo of PaddleSeg:
22 | 
23 | ![demo of PaddleSeg](ppliteseg_paddlepaddle.png)

24 | 
25 | Demo of PyTorch:
26 | 
27 | ![demo of PyTorch](results.jpg)
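The PyTorch result above is written out by demo.py (results.jpg). A condensed sketch of that inference path, assuming the pretrained .pth from the release page sits in the repo root, a CUDA GPU is available, and the input is a Cityscapes-style image (the filename below is simply demo.py's default):

```python
import cv2
import numpy as np
import torch

from pp_liteseg import PPLiteSeg

model = PPLiteSeg().cuda().eval()
ckpt = torch.load("ppliteset_pp2torch_cityscape_pretrained.pth")
model.load_state_dict(ckpt["state_dict"] if "state_dict" in ckpt else ckpt)

img = cv2.imread("mainz_000001_009328_leftImg8bit.png")          # BGR, H x W x 3
x = cv2.resize(img, (1024, 512)).astype(np.float32)[:, :, ::-1]  # resize + BGR -> RGB
x = (x / 255.0 - 0.5) / 0.5                                      # same normalization as demo.py
x = torch.from_numpy(x.transpose(2, 0, 1)).unsqueeze(0).cuda()

with torch.no_grad():
    pred = model(x)[0].argmax(dim=1).squeeze(0).cpu().numpy()    # H x W map of class ids
```

demo.py additionally colorizes the prediction and blends it over the original image before saving results.jpg.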

28 | 
29 | ## Difference
30 | 
31 | ### Upsample
32 | 
33 | PaddleSeg uses "bilinear" mode for upsampling, while this PyTorch port uses "nearest" mode so that the model can be converted to TensorRT more easily.
34 | 
35 | ### BatchNorm
36 | 
37 | PaddleSeg uses momentum=0.9,
38 | 
39 | while the default setting of PyTorch is momentum=0.1. (The two frameworks define BatchNorm momentum in opposite directions, so these two values describe the same running-statistics update.)
40 | 
41 | ## Train
42 | 
43 | Use [DDRNet.Pytorch](https://github.com/midasklr/DDRNet.Pytorch) to train this model, and set the loss coefficients of the three seg heads (1/8, 1/16 and 1/32) to 1 during training (a minimal loss sketch is included at the end of this document).
44 | 
45 | ## Demo
46 | 
47 | See demo.py (a condensed sketch is shown above, right after the demo images).
48 | 
49 | ## Reference
50 | 
51 | 1. https://github.com/PaddlePaddle/PaddleSeg/tree/develop/configs/pp_liteseg
52 | 
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 | 
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 | import random
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 | import time
18 | import numpy as np
19 | 
20 | import torch
21 | import torch.nn as nn
22 | from pp_liteseg import PPLiteSeg
23 | import cv2
24 | import torch.nn.functional as F
25 | 
26 | 
27 | 
28 | 
29 | def parse_args():
30 |     parser = argparse.ArgumentParser(description='PP-LiteSeg demo')
31 | 
32 |     parser.add_argument('--image',
33 |                         help='test image path',
34 |                         default="mainz_000001_009328_leftImg8bit.png",
35 |                         type=str)
36 |     parser.add_argument('--weights',
37 |                         help='Cityscapes pretrained weights',
38 |                         default="ppliteset_pp2torch_cityscape_pretrained.pth",
39 |                         type=str)
40 |     parser.add_argument('opts',
41 |                         help="Modify config options using the command-line",
42 |                         default=None,
43 |                         nargs=argparse.REMAINDER)
44 | 
45 |     args = parser.parse_args()
46 | 
47 |     return args
48 | 
49 | 
50 | def colorEncode(labelmap, colors, mode='RGB'):
51 |     labelmap = labelmap.astype('int')
52 |     labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
53 |                             dtype=np.uint8)
54 |     for label in np.unique(labelmap):
55 |         if label < 0:
56 |             continue
57 |         labelmap_rgb = labelmap_rgb + (labelmap == label)[:, :, np.newaxis] * \
58 |             np.tile(colors[label],
59 |                     (labelmap.shape[0], labelmap.shape[1], 1))
60 | 
61 |     if mode == 'BGR':
62 |         return labelmap_rgb[:, :, ::-1]
63 |     else:
64 |         return labelmap_rgb
65 | 
66 | 
67 | def main():
68 |     base_size = 512
69 |     wh = 2
70 |     mean = [0.5, 0.5, 0.5]
71 |     std = [0.5, 0.5, 0.5]
72 |     args = parse_args()
73 | 
74 |     model = PPLiteSeg()
75 | 
76 |     model.eval()
77 | 
78 |     print("ppliteseg:", model)
79 |     ckpt = torch.load(args.weights)
80 |     model = model.cuda()
81 |     if 'state_dict' in ckpt:
82 |         model.load_state_dict(ckpt['state_dict'])
83 |     else:
84 |         model.load_state_dict(ckpt)
85 | 
86 |     img = cv2.imread(args.image)
87 |     imgor = img.copy()
88 |     img = cv2.resize(img, (wh * base_size, base_size))
89 |     image = img.astype(np.float32)[:, :, ::-1]
90 |     image = image / 255.0
91 |     image -= mean
92 |     image /= std
93 | 
94 |     image = image.transpose((2, 0, 1))
95 |     image = torch.from_numpy(image)
96 | 
97 |     # image = image.permute((2, 0, 1))
98 | 
99 |     image = image.unsqueeze(0)
100 |     image = image.cuda()
101 |     start = time.time()
102 |     out = model(image)
103 |     end = time.time()
104 |     print("infer time:", end -
start, " s") 105 | out = out[0].squeeze(dim=0) 106 | outadd = F.softmax(out, dim=0) 107 | outadd = torch.argmax(outadd, dim=0) 108 | predadd = outadd.detach().cpu().numpy() 109 | pred = np.int32(predadd) 110 | colors = np.random.randint(0, 255, 19 * 3) 111 | colors = np.reshape(colors, (19, 3)) 112 | # colorize prediction 113 | pred_color = colorEncode(pred, colors).astype(np.uint8) 114 | pred_color = cv2.resize(pred_color,(imgor.shape[1],imgor.shape[0])) 115 | 116 | im_vis = cv2.addWeighted(imgor, 0.7, pred_color, 0.3, 0) 117 | cv2.imwrite("results.jpg", im_vis) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /pp_liteseg.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import math 19 | 20 | 21 | class ConvBN(nn.Module): 22 | def __init__(self, 23 | in_channels, 24 | out_channels, 25 | kernel_size, 26 | stride=1, 27 | padding=1, 28 | bias = False, 29 | **kwargs): 30 | super().__init__() 31 | self._conv = nn.Conv2d( 32 | in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2 if padding else 0, 33 | bias = bias, **kwargs) 34 | self._batch_norm = nn.BatchNorm2d(out_channels, momentum=0.1) 35 | 36 | def forward(self, x): 37 | x = self._conv(x) 38 | x = self._batch_norm(x) 39 | return x 40 | 41 | 42 | class ConvBNReLU(nn.Module): 43 | def __init__(self, 44 | in_channels, 45 | out_channels, 46 | kernel_size=3, 47 | stride = 1, 48 | padding=1, 49 | bias = False, 50 | **kwargs): 51 | super().__init__() 52 | 53 | self._conv = nn.Conv2d( 54 | in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2 if padding else 0, bias = bias,**kwargs) 55 | 56 | self._batch_norm = nn.BatchNorm2d(out_channels, momentum=0.1) 57 | self._relu = nn.ReLU(inplace=True) 58 | 59 | def forward(self, x): 60 | x = self._conv(x) 61 | x = self._batch_norm(x) 62 | x = self._relu(x) 63 | return x 64 | 65 | 66 | class ConvBNRelu(nn.Module): 67 | def __init__(self, 68 | in_channels, 69 | out_channels, 70 | kernel_size=3, 71 | stride = 1, 72 | padding=1, 73 | bias = False, 74 | **kwargs): 75 | super().__init__() 76 | 77 | self.conv = nn.Conv2d( 78 | in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2 if padding else 0, bias = bias, **kwargs) 79 | 80 | self.bn = nn.BatchNorm2d(out_channels, momentum=0.1) 81 | self.relu = nn.ReLU(inplace=True) 82 | 83 | def forward(self, x): 84 | x = self.conv(x) 85 | x = self.bn(x) 86 | x = self.relu(x) 87 | return x 88 | 89 | 90 | def avg_max_reduce_channel_helper(x, use_concat=True): 91 | # Reduce hw by avg and max, only support single input 92 | assert not isinstance(x, (list, tuple)) 93 | # print("x before mean and max:", x.shape) 94 | mean_value = 
torch.mean(x, dim=1, keepdim=True)
95 |     max_value = torch.max(x, dim=1, keepdim=True)[0]
96 |     # mean_value = mean_value.unsqueeze(0)
97 |     # print("mean max:", mean_value.shape, max_value.shape)
98 | 
99 |     if use_concat:
100 |         res = torch.cat([mean_value, max_value], dim=1)
101 |     else:
102 |         res = [mean_value, max_value]
103 |     return res
104 | 
105 | 
106 | def avg_max_reduce_channel(x):
107 |     # Reduce the channel dim by avg and max
108 |     # Return cat([avg_ch_0, max_ch_0, avg_ch_1, max_ch_1, ...])
109 |     if not isinstance(x, (list, tuple)):
110 |         return avg_max_reduce_channel_helper(x)
111 |     elif len(x) == 1:
112 |         return avg_max_reduce_channel_helper(x[0])
113 |     else:
114 |         res = []
115 |         for xi in x:
116 |             # print(xi.shape)
117 |             res.extend(avg_max_reduce_channel_helper(xi, False))
118 |         # print("res:\n",)
119 |         # for it in res:
120 |         #     print(it.shape)
121 |         return torch.cat(res, dim=1)
122 | 
123 | 
124 | class UAFM(nn.Module):
125 |     """
126 |     The base of Unified Attention Fusion Module.
127 |     Args:
128 |         x_ch (int): The channel of x tensor, which is the low level feature.
129 |         y_ch (int): The channel of y tensor, which is the high level feature.
130 |         out_ch (int): The channel of output tensor.
131 |         ksize (int, optional): The kernel size of the conv for x tensor. Default: 3.
132 |         resize_mode (str, optional): The resize mode used when upsampling the y tensor. Default: nearest.
133 |     """
134 | 
135 |     def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='nearest'):
136 |         super().__init__()
137 | 
138 |         self.conv_x = ConvBNReLU(
139 |             x_ch, y_ch, kernel_size=ksize, padding=ksize // 2, bias=False)
140 |         self.conv_out = ConvBNReLU(
141 |             y_ch, out_ch, kernel_size=3, padding=1, bias=False)
142 |         self.resize_mode = resize_mode
143 | 
144 |     def check(self, x, y):
145 |         # print("x dim:", x.ndim)
146 |         assert x.ndim == 4 and y.ndim == 4
147 |         x_h, x_w = x.shape[2:]
148 |         y_h, y_w = y.shape[2:]
149 |         assert x_h >= y_h and x_w >= y_w
150 | 
151 |     def prepare(self, x, y):
152 |         x = self.prepare_x(x, y)
153 |         y = self.prepare_y(x, y)
154 |         return x, y
155 | 
156 |     def prepare_x(self, x, y):
157 |         x = self.conv_x(x)
158 |         return x
159 | 
160 |     def prepare_y(self, x, y):
161 |         y_up = F.interpolate(y, x.shape[2:], mode=self.resize_mode)
162 |         return y_up
163 | 
164 |     def fuse(self, x, y):
165 |         out = x + y
166 |         out = self.conv_out(out)
167 |         return out
168 | 
169 |     def forward(self, x, y):
170 |         """
171 |         Args:
172 |             x (Tensor): The low level feature.
173 |             y (Tensor): The high level feature.
174 |         """
175 |         # print("x,y shape:", x.shape, y.shape)
176 |         self.check(x, y)
177 |         x, y = self.prepare(x, y)
178 |         out = self.fuse(x, y)
179 |         return out
180 | 
181 | 
182 | class UAFM_SpAtten(UAFM):
183 |     """
184 |     The UAFM with spatial attention, which uses mean and max values.
185 |     Args:
186 |         x_ch (int): The channel of x tensor, which is the low level feature.
187 |         y_ch (int): The channel of y tensor, which is the high level feature.
188 |         out_ch (int): The channel of output tensor.
189 |         ksize (int, optional): The kernel size of the conv for x tensor. Default: 3.
190 |         resize_mode (str, optional): The resize mode used when upsampling the y tensor. Default: nearest.
191 | """ 192 | 193 | def __init__(self, x_ch, y_ch, out_ch, ksize=3, resize_mode='nearest'): 194 | super().__init__(x_ch, y_ch, out_ch, ksize, resize_mode) 195 | 196 | self.conv_xy_atten = nn.Sequential( 197 | ConvBNReLU( 198 | 4, 2, kernel_size=3, padding=1, bias=False), 199 | ConvBN( 200 | 2, 1, kernel_size=3, padding=1, bias=False)) 201 | 202 | def fuse(self, x, y): 203 | """ 204 | Args: 205 | x (Tensor): The low level feature. 206 | y (Tensor): The high level feature. 207 | """ 208 | # print("x, y shape:",x.shape, y.shape) 209 | atten = avg_max_reduce_channel([x, y]) 210 | atten = F.sigmoid(self.conv_xy_atten(atten)) 211 | 212 | out = x * atten + y * (1 - atten) 213 | out = self.conv_out(out) 214 | return out 215 | 216 | 217 | class CatBottleneck(nn.Module): 218 | def __init__(self, in_planes, out_planes, block_num=3, stride=1): 219 | super(CatBottleneck, self).__init__() 220 | assert block_num > 1, "block number should be larger than 1." 221 | self.conv_list = nn.ModuleList() 222 | self.stride = stride 223 | if stride == 2: 224 | self.avd_layer = nn.Sequential( 225 | nn.Conv2d( 226 | out_planes // 2, 227 | out_planes // 2, 228 | kernel_size=3, 229 | stride=2, 230 | padding=1, 231 | groups=out_planes // 2, 232 | bias=False), 233 | nn.BatchNorm2d(out_planes // 2, momentum=0.1), ) 234 | self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 235 | stride = 1 236 | 237 | for idx in range(block_num): 238 | if idx == 0: 239 | self.conv_list.append( 240 | ConvBNRelu( 241 | in_planes, out_planes // 2, kernel_size=1)) 242 | elif idx == 1 and block_num == 2: 243 | self.conv_list.append( 244 | ConvBNRelu( 245 | out_planes // 2, out_planes // 2, stride=stride)) 246 | elif idx == 1 and block_num > 2: 247 | self.conv_list.append( 248 | ConvBNRelu( 249 | out_planes // 2, out_planes // 4, stride=stride)) 250 | elif idx < block_num - 1: 251 | self.conv_list.append( 252 | ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes 253 | // int(math.pow(2, idx + 1)))) 254 | else: 255 | self.conv_list.append( 256 | ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes 257 | // int(math.pow(2, idx)))) 258 | 259 | def forward(self, x): 260 | out_list = [] 261 | out1 = self.conv_list[0](x) 262 | for idx, conv in enumerate(self.conv_list[1:]): 263 | if idx == 0: 264 | if self.stride == 2: 265 | out = conv(self.avd_layer(out1)) 266 | else: 267 | out = conv(out1) 268 | else: 269 | out = conv(out) 270 | out_list.append(out) 271 | 272 | if self.stride == 2: 273 | out1 = self.skip(out1) 274 | out_list.insert(0, out1) 275 | out = torch.cat(out_list, dim=1) 276 | return out 277 | 278 | 279 | class AddBottleneck(nn.Module): 280 | def __init__(self, in_planes, out_planes, block_num=3, stride=1): 281 | super(AddBottleneck, self).__init__() 282 | assert block_num > 1, "block number should be larger than 1." 
283 |         self.conv_list = nn.ModuleList()
284 |         self.stride = stride
285 |         if stride == 2:
286 |             self.avd_layer = nn.Sequential(
287 |                 nn.Conv2d(
288 |                     out_planes // 2,
289 |                     out_planes // 2,
290 |                     kernel_size=3,
291 |                     stride=2,
292 |                     padding=1,
293 |                     groups=out_planes // 2,
294 |                     bias=False),
295 |                 nn.BatchNorm2d(out_planes // 2, momentum=0.1), )
296 |             self.skip = nn.Sequential(
297 |                 nn.Conv2d(
298 |                     in_planes,
299 |                     in_planes,
300 |                     kernel_size=3,
301 |                     stride=2,
302 |                     padding=1,
303 |                     groups=in_planes,
304 |                     bias=False),
305 |                 nn.BatchNorm2d(in_planes, momentum=0.1),
306 |                 nn.Conv2d(
307 |                     in_planes, out_planes, kernel_size=1, bias=False),
308 |                 nn.BatchNorm2d(out_planes, momentum=0.1), )
309 |             stride = 1
310 | 
311 |         for idx in range(block_num):
312 |             if idx == 0:
313 |                 self.conv_list.append(
314 |                     ConvBNRelu(
315 |                         in_planes, out_planes // 2, kernel_size=1))
316 |             elif idx == 1 and block_num == 2:
317 |                 self.conv_list.append(
318 |                     ConvBNRelu(
319 |                         out_planes // 2, out_planes // 2, stride=stride))
320 |             elif idx == 1 and block_num > 2:
321 |                 self.conv_list.append(
322 |                     ConvBNRelu(
323 |                         out_planes // 2, out_planes // 4, stride=stride))
324 |             elif idx < block_num - 1:
325 |                 self.conv_list.append(
326 |                     ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
327 |                                // int(math.pow(2, idx + 1))))
328 |             else:
329 |                 self.conv_list.append(
330 |                     ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
331 |                                // int(math.pow(2, idx))))
332 | 
333 |     def forward(self, x):
334 |         out_list = []
335 |         out = x
336 |         for idx, conv in enumerate(self.conv_list):
337 |             if idx == 0 and self.stride == 2:
338 |                 out = self.avd_layer(conv(out))
339 |             else:
340 |                 out = conv(out)
341 |             out_list.append(out)
342 |         if self.stride == 2:
343 |             x = self.skip(x)
344 |         return torch.cat(out_list, dim=1) + x
345 | 
346 | 
347 | class STDCNet(nn.Module):
348 |     """
349 |     The STDCNet implementation based on PyTorch.
350 | 
351 |     The original article refers to Meituan
352 |     Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation."
353 |     (https://arxiv.org/abs/2104.13188)
354 | 
355 |     Args:
356 |         base(int, optional): base channels. Default: 64.
357 |         layers(list, optional): layers numbers list. It determines STDC block numbers of STDCNet's stage3, stage4 and stage5. Default: [4, 5, 3].
358 |         block_num(int, optional): block_num of features block. Default: 4.
359 |         type(str, optional): feature fusion method, "cat"/"add". Default: "cat".
360 |         num_classes(int, optional): class number for image classification. Default: 1000.
361 |         dropout(float, optional): dropout ratio. If >0, use dropout. Default: 0.20.
362 |         use_conv_last(bool, optional): whether to use the last ConvBNReLU layer. Default: False.
363 |         pretrained(str, optional): the path of the pretrained model.
364 | """ 365 | 366 | def __init__(self, 367 | base=64, 368 | layers=[4, 5, 3], 369 | block_num=4, 370 | type="cat", 371 | num_classes=1000, 372 | dropout=0.20, 373 | use_conv_last=False, 374 | pretrained=None): 375 | super(STDCNet, self).__init__() 376 | if type == "cat": 377 | block = CatBottleneck 378 | elif type == "add": 379 | block = AddBottleneck 380 | self.use_conv_last = use_conv_last 381 | self.feat_channels = [base // 2, base, base * 4, base * 8, base * 16] 382 | self.features = self._make_layers(base, layers, block_num, block) 383 | self.conv_last = ConvBNRelu(base * 16, max(1024, base * 16), 1, 1) 384 | 385 | if (layers == [4, 5, 3]): # stdc1446 386 | self.x2 = nn.Sequential(self.features[:1]) 387 | self.x4 = nn.Sequential(self.features[1:2]) 388 | self.x8 = nn.Sequential(self.features[2:6]) 389 | self.x16 = nn.Sequential(self.features[6:11]) 390 | self.x32 = nn.Sequential(self.features[11:]) 391 | elif (layers == [2, 2, 2]): # stdc813 392 | self.x2 = nn.Sequential(self.features[:1]) 393 | self.x4 = nn.Sequential(self.features[1:2]) 394 | self.x8 = nn.Sequential(self.features[2:4]) 395 | self.x16 = nn.Sequential(self.features[4:6]) 396 | self.x32 = nn.Sequential(self.features[6:]) 397 | else: 398 | raise NotImplementedError( 399 | "model with layers:{} is not implemented!".format(layers)) 400 | 401 | self.pretrained = pretrained 402 | # self.init_weight() 403 | 404 | def forward(self, x): 405 | """ 406 | forward function for feature extract. 407 | """ 408 | feat2 = self.x2(x) 409 | feat4 = self.x4(feat2) 410 | feat8 = self.x8(feat4) 411 | feat16 = self.x16(feat8) 412 | feat32 = self.x32(feat16) 413 | if self.use_conv_last: 414 | feat32 = self.conv_last(feat32) 415 | return feat2, feat4, feat8, feat16, feat32 416 | 417 | def _make_layers(self, base, layers, block_num, block): 418 | features = [] 419 | features += [ConvBNRelu(3, base // 2, 3, 2)] 420 | features += [ConvBNRelu(base // 2, base, 3, 2)] 421 | 422 | for i, layer in enumerate(layers): 423 | for j in range(layer): 424 | if i == 0 and j == 0: 425 | features.append(block(base, base * 4, block_num, 2)) 426 | elif j == 0: 427 | features.append( 428 | block(base * int(math.pow(2, i + 1)), base * int( 429 | math.pow(2, i + 2)), block_num, 2)) 430 | else: 431 | features.append( 432 | block(base * int(math.pow(2, i + 2)), base * int( 433 | math.pow(2, i + 2)), block_num, 1)) 434 | 435 | return nn.Sequential(*features) 436 | 437 | # def init_weight(self): 438 | # for layer in self.sublayers(): 439 | # if isinstance(layer, nn.Conv2D): 440 | # param_init.normal_init(layer.weight, std=0.001) 441 | # elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): 442 | # param_init.constant_init(layer.weight, value=1.0) 443 | # param_init.constant_init(layer.bias, value=0.0) 444 | # if self.pretrained is not None: 445 | # utils.load_pretrained_model(self, self.pretrained) 446 | 447 | 448 | def STDC2(**kwargs): 449 | model = STDCNet(base=64, layers=[4, 5, 3], **kwargs) 450 | return model 451 | 452 | 453 | class PPLiteSeg(nn.Module): 454 | """ 455 | The PP_LiteSeg implementation based on Pytorch. 456 | 457 | The original article refers to "Juncai Peng, Yi Liu, Shiyu Tang, Yuying Hao, Lutao Chu, 458 | Guowei Chen, Zewu Wu, Zeyu Chen, Zhiliang Yu, Yuning Du, Qingqing Dang,Baohua Lai, 459 | Qiwen Liu, Xiaoguang Hu, Dianhai Yu, Yanjun Ma. PP-LiteSeg: A Superior Real-Time Semantic 460 | Segmentation Model. https://arxiv.org/abs/2204.02681". 461 | 462 | Args: 463 | num_classes (int): The number of target classes. 
464 |         backbone(nn.Module): Backbone network, such as stdc1net and resnet18. The backbone must
465 |             have feat_channels, of which the length is 5.
466 |         backbone_indices (List(int), optional): The values indicate the indices of output of backbone.
467 |             Default: [2, 3, 4].
468 |         arm_type (str, optional): The type of attention refinement module. Default: UAFM_SpAtten.
469 |         cm_bin_sizes (List(int), optional): The bin size of context module. Default: [1, 2, 4].
470 |         cm_out_ch (int, optional): The output channel of the last context module. Default: 128.
471 |         arm_out_chs (List(int), optional): The out channels of each arm module. Default: [64, 96, 128].
472 |         seg_head_inter_chs (List(int), optional): The intermediate channels of segmentation head.
473 |             Default: [64, 64, 64].
474 |         resize_mode (str, optional): The resize mode for the upsampling operation in decoder.
475 |             Default: nearest.
476 |         pretrained (str, optional): The path or url of pretrained model. Default: None.
477 | 
478 |     """
479 | 
480 |     def __init__(self,
481 |                  num_classes = 19,
482 |                  backbone = STDC2(),
483 |                  backbone_indices=[2, 3, 4],
484 |                  arm_type='UAFM_SpAtten',
485 |                  cm_bin_sizes=[1, 2, 4],
486 |                  cm_out_ch=128,
487 |                  arm_out_chs=[64, 96, 128],
488 |                  seg_head_inter_chs=[64, 64, 64],
489 |                  resize_mode='nearest',
490 |                  pretrained=False):
491 |         super().__init__()
492 | 
493 |         # backbone
494 |         assert hasattr(backbone, 'feat_channels'), \
495 |             "The backbone should have feat_channels."
496 |         assert len(backbone.feat_channels) >= len(backbone_indices), \
497 |             f"The length of input backbone_indices ({len(backbone_indices)}) should not be " \
498 |             f"greater than the length of feat_channels ({len(backbone.feat_channels)})."
499 |         assert len(backbone.feat_channels) > max(backbone_indices), \
500 |             f"The max value ({max(backbone_indices)}) of backbone_indices should be " \
501 |             f"less than the length of feat_channels ({len(backbone.feat_channels)})."
502 | self.backbone = backbone 503 | 504 | assert len(backbone_indices) > 1, "The lenght of backbone_indices " \ 505 | "should be greater than 1" 506 | self.backbone_indices = backbone_indices # [..., x16_id, x32_id] 507 | backbone_out_chs = [backbone.feat_channels[i] for i in backbone_indices] 508 | 509 | # head 510 | if len(arm_out_chs) == 1: 511 | arm_out_chs = arm_out_chs * len(backbone_indices) 512 | assert len(arm_out_chs) == len(backbone_indices), "The length of " \ 513 | "arm_out_chs and backbone_indices should be equal" 514 | 515 | self.ppseg_head = PPLiteSegHead(backbone_out_chs, arm_out_chs, 516 | cm_bin_sizes, cm_out_ch, arm_type, 517 | resize_mode) 518 | 519 | if len(seg_head_inter_chs) == 1: 520 | seg_head_inter_chs = seg_head_inter_chs * len(backbone_indices) 521 | assert len(seg_head_inter_chs) == len(backbone_indices), "The length of " \ 522 | "seg_head_inter_chs and backbone_indices should be equal" 523 | self.seg_heads = nn.ModuleList() # [..., head_16, head32] 524 | print("arm_out_chs:",arm_out_chs, " ; seg_head_inter_chs:",seg_head_inter_chs) 525 | for in_ch, mid_ch in zip(arm_out_chs, seg_head_inter_chs): 526 | self.seg_heads.append(SegHead(in_ch, mid_ch, num_classes)) 527 | 528 | # pretrained 529 | self.pretrained = pretrained 530 | # self.init_weight() 531 | 532 | def forward(self, x): 533 | x_hw = x.shape[2:] 534 | # print("x_hw:",x_hw) 535 | 536 | feats_backbone = self.backbone(x) # [x2, x4, x8, x16, x32] 537 | # print(type(feats_backbone)) 538 | assert len(feats_backbone) >= len(self.backbone_indices), \ 539 | f"The nums of backbone feats ({len(feats_backbone)}) should be greater or " \ 540 | f"equal than the nums of backbone_indices ({len(self.backbone_indices)})" 541 | 542 | feats_selected = [feats_backbone[i] for i in self.backbone_indices] 543 | 544 | feats_head = self.ppseg_head(feats_selected) # [..., x8, x16, x32] 545 | 546 | if self.training: 547 | logit_list = [] 548 | 549 | for x, seg_head in zip(feats_head, self.seg_heads): 550 | x = seg_head(x) 551 | logit_list.append(x) 552 | 553 | logit_list = [ 554 | F.interpolate( 555 | x, x_hw, mode='bilinear', align_corners=None) 556 | for x in logit_list 557 | ] 558 | else: 559 | x = self.seg_heads[0](feats_head[0]) 560 | # print("x:",x.shape) 561 | x = F.interpolate(x, x_hw, mode='bilinear', align_corners=None) 562 | logit_list = [x] 563 | 564 | return logit_list 565 | 566 | # def init_weight(self): 567 | # if self.pretrained is not None: 568 | # utils.load_entire_model(self, self.pretrained) 569 | 570 | 571 | class PPLiteSegHead(nn.Module): 572 | """ 573 | The head of PPLiteSeg. 574 | 575 | Args: 576 | backbone_out_chs (List(Tensor)): The channels of output tensors in the backbone. 577 | arm_out_chs (List(int)): The out channels of each arm module. 578 | cm_bin_sizes (List(int)): The bin size of context module. 579 | cm_out_ch (int): The output channel of the last context module. 580 | arm_type (str): The type of attention refinement module. 581 | resize_mode (str): The resize mode for the upsampling operation in decoder. 
582 | """ 583 | 584 | def __init__(self, backbone_out_chs, arm_out_chs, cm_bin_sizes, cm_out_ch, 585 | arm_type, resize_mode): 586 | super().__init__() 587 | 588 | self.cm = PPContextModule(backbone_out_chs[-1], cm_out_ch, cm_out_ch, 589 | cm_bin_sizes) 590 | 591 | # assert hasattr(layers, arm_type), \ 592 | # "Not support arm_type ({})".format(arm_type) 593 | arm_class = eval(arm_type) 594 | 595 | self.arm_list = nn.ModuleList() # [..., arm8, arm16, arm32] 596 | for i in range(len(backbone_out_chs)): 597 | low_chs = backbone_out_chs[i] 598 | high_ch = cm_out_ch if i == len( 599 | backbone_out_chs) - 1 else arm_out_chs[i + 1] 600 | out_ch = arm_out_chs[i] 601 | arm = arm_class( 602 | low_chs, high_ch, out_ch, ksize=3, resize_mode=resize_mode) 603 | self.arm_list.append(arm) 604 | 605 | def forward(self, in_feat_list): 606 | """ 607 | Args: 608 | in_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32]. 609 | x2, x4 and x8 are optional. 610 | Returns: 611 | out_feat_list (List(Tensor)): Such as [x2, x4, x8, x16, x32]. 612 | x2, x4 and x8 are optional. 613 | The length of in_feat_list and out_feat_list are the same. 614 | """ 615 | 616 | high_feat = self.cm(in_feat_list[-1]) 617 | out_feat_list = [] 618 | 619 | for i in reversed(range(len(in_feat_list))): 620 | low_feat = in_feat_list[i] 621 | arm = self.arm_list[i] 622 | high_feat = arm(low_feat, high_feat) 623 | out_feat_list.insert(0, high_feat) 624 | 625 | return out_feat_list 626 | 627 | 628 | class PPContextModule(nn.Module): 629 | """ 630 | Simple Context module. 631 | 632 | Args: 633 | in_channels (int): The number of input channels to pyramid pooling module. 634 | inter_channels (int): The number of inter channels to pyramid pooling module. 635 | out_channels (int): The number of output channels after pyramid pooling module. 636 | bin_sizes (tuple, optional): The out size of pooled feature maps. Default: (1, 3). 637 | align_corners (bool): An argument of F.interpolate. It should be set to False 638 | when the output size of feature is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. 
639 | """ 640 | 641 | def __init__(self, 642 | in_channels, 643 | inter_channels, 644 | out_channels, 645 | bin_sizes, 646 | align_corners=None): 647 | super().__init__() 648 | 649 | self.stages = nn.ModuleList([ 650 | self._make_stage(in_channels, inter_channels, size) 651 | for size in bin_sizes 652 | ]) 653 | 654 | self.conv_out = ConvBNReLU( 655 | in_channels=inter_channels, 656 | out_channels=out_channels, 657 | kernel_size=3, 658 | padding=1, 659 | bias=True) 660 | 661 | self.align_corners = align_corners 662 | 663 | def _make_stage(self, in_channels, out_channels, size): 664 | prior = nn.AdaptiveAvgPool2d(output_size=size) 665 | conv = ConvBNReLU( 666 | in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True) 667 | return nn.Sequential(prior, conv) 668 | 669 | def forward(self, input): 670 | out = None 671 | input_shape = input.shape[2:] 672 | 673 | for stage in self.stages: 674 | x = stage(input) 675 | x = F.interpolate( 676 | x, 677 | input_shape, 678 | mode='nearest', 679 | align_corners=self.align_corners) 680 | if out is None: 681 | out = x 682 | else: 683 | out += x 684 | 685 | out = self.conv_out(out) 686 | return out 687 | 688 | 689 | class SegHead(nn.Module): 690 | def __init__(self, in_chan, mid_chan, n_classes): 691 | super().__init__() 692 | self.conv = ConvBNReLU( 693 | in_chan, 694 | mid_chan, 695 | kernel_size=3, 696 | stride=1, 697 | padding=1, 698 | bias=False) 699 | # print("="*100) 700 | # print("out:",mid_chan, "n_classes:",n_classes) 701 | self.conv_out = nn.Conv2d( 702 | mid_chan, n_classes, kernel_size=1, bias=False) 703 | 704 | def forward(self, x): 705 | x = self.conv(x) 706 | x = self.conv_out(x) 707 | return x 708 | 709 | # 710 | # def get_seg_model(**kwargs): 711 | # model = PPLiteSeg(pretrained=False) 712 | # return model 713 | -------------------------------------------------------------------------------- /ppliteseg_paddlepaddle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/PPLiteSeg.pytorch/888892f047a6a02b4cd88ba2e4924df1693464e5/ppliteseg_paddlepaddle.png -------------------------------------------------------------------------------- /results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/PPLiteSeg.pytorch/888892f047a6a02b4cd88ba2e4924df1693464e5/results.jpg --------------------------------------------------------------------------------
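As referenced from the Train section of the README above: the three seg-head outputs are supervised jointly, each with loss coefficient 1. The actual training loop lives in the DDRNet.Pytorch repo; the snippet below is only a minimal sketch of that loss under assumed settings (plain cross-entropy on Cityscapes train IDs with ignore label 255), not the original training code.

```python
import torch
import torch.nn.functional as F

from pp_liteseg import PPLiteSeg

model = PPLiteSeg().train()                     # train mode: forward() returns 3 logits (1/8, 1/16, 1/32 heads),
images = torch.randn(2, 3, 512, 1024)           # each already upsampled to the input resolution
labels = torch.randint(0, 19, (2, 512, 1024))   # Cityscapes train IDs

logits = model(images)
loss = sum(F.cross_entropy(p, labels, ignore_index=255) for p in logits)  # coefficient 1 for every head
loss.backward()
```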