├── .gitignore
├── README.md
├── generator.py
├── model.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
pretrained
pretrained_*
*.ipynb
.ipynb_checkpoints
images
*.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# keras-PSPNet

Pyramid Scene Parsing Network is a semantic segmentation model based on the Fully Convolutional Network.
This repository contains an implementation for training and testing it in Keras and TensorFlow.


## Architecture

- atrous convolution
- residual module
- pyramid pooling module

## Prerequisites

- Python 3.6
- OpenCV for Python
- Keras, TensorFlow

## Usage

### train
- Segmentation involving multiple categories

` python train.py --options `

- Segmentation of a mask image

` python train_mask.py --options `

- options
  - image dir
  - mask image dir
  - batch_size, n_epochs, epoch_steps, input shape
  - class weights
  - device num

### test
- Input a test image
- The response is JSON containing category name and color (pixel-based prediction)

` python predict.py --input_path [path/to/input_image] `

--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
import os

import cv2
import numpy as np

from keras.preprocessing.image import img_to_array


def category_label(labels, dims, n_labels):
    # one-hot encode an integer label map, then flatten to (H * W, n_labels)
    x = np.zeros([dims[0], dims[1], n_labels])
    for i in range(dims[0]):
        for j in range(dims[1]):
            x[i, j, labels[i][j]] = 1
    x = x.reshape(dims[0] * dims[1], n_labels)

    return x


# generator that we will use to read the data from the directory
def data_gen_small(img_dir, mask_dir, lists, batch_size, dims, n_labels):
    while True:
        ix = np.random.choice(np.arange(len(lists)), batch_size)
        imgs = []
        labels = []
        for i in ix:
            # images (append the extension; os.path.join would insert a separator)
            img_path = os.path.join(img_dir, lists.iloc[i, 0] + ".jpg")
            original_img = cv2.imread(img_path)[:, :, ::-1]  # BGR -> RGB
            resized_img = cv2.resize(original_img, (dims[0], dims[1]))
            array_img = img_to_array(resized_img) / 255
            imgs.append(array_img)
            # masks (read from mask_dir; nearest neighbour keeps labels intact)
            mask_path = os.path.join(mask_dir, lists.iloc[i, 0] + ".png")
            original_mask = cv2.imread(mask_path)
            resized_mask = cv2.resize(
                original_mask, (dims[0], dims[1]), interpolation=cv2.INTER_NEAREST
            )
            array_mask = category_label(resized_mask[:, :, 0], dims, n_labels)
            labels.append(array_mask)
        imgs = np.array(imgs)
        labels = np.array(labels)

        yield imgs, labels
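

# A minimal usage sketch (illustrative, not part of the training pipeline):
# drives data_gen_small with a hypothetical CSV of image basenames, assuming
# ./imgs holds <name>.jpg files and ./masks holds matching <name>.png label maps.
if __name__ == "__main__":
    import pandas as pd

    lists = pd.read_csv("train.csv", header=None)  # one basename per row
    gen = data_gen_small("./imgs", "./masks", lists, 2, [256, 256], 20)
    imgs, labels = next(gen)
    print(imgs.shape)    # (2, 256, 256, 3)
    print(labels.shape)  # (2, 256 * 256, 20)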
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import keras.backend as K
from keras.engine import InputSpec
from keras.engine.topology import Layer
from keras.layers import (Activation, AveragePooling2D, BatchNormalization,
                          Conv2D, Conv2DTranspose, Dense,
                          GlobalAveragePooling2D, Input, Lambda, MaxPooling2D,
                          Permute, Reshape, ZeroPadding2D, add, concatenate,
                          multiply)
from keras.models import Model
from keras.utils import conv_utils


class CroppingLike2D(Layer):
    def __init__(self, target_shape, offset=None, data_format=None, **kwargs):
        super(CroppingLike2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.target_shape = target_shape
        if offset is None or offset == "centered":
            self.offset = "centered"
        elif isinstance(offset, int):
            self.offset = (offset, offset)
        elif hasattr(offset, "__len__"):
            if len(offset) != 2:
                raise ValueError(
                    "`offset` should have two elements. Found: " + str(offset)
                )
            self.offset = offset
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == "channels_first":
            return (
                input_shape[0],
                input_shape[1],
                self.target_shape[2],
                self.target_shape[3],
            )
        else:
            return (
                input_shape[0],
                self.target_shape[1],
                self.target_shape[2],
                input_shape[3],
            )

    def call(self, inputs):
        input_shape = K.int_shape(inputs)
        if self.data_format == "channels_first":
            input_height = input_shape[2]
            input_width = input_shape[3]
            target_height = self.target_shape[2]
            target_width = self.target_shape[3]
            if target_height > input_height or target_width > input_width:
                raise ValueError(
                    "The target shape needs to be smaller than "
                    "or equal to the Tensor being cropped."
                )

            if self.offset == "centered":
                self.offset = [
                    int((input_height - target_height) / 2),
                    int((input_width - target_width) / 2),
                ]

            if self.offset[0] + target_height > input_height:
                raise ValueError(
                    "Height index out of range: " + str(self.offset[0] + target_height)
                )
            if self.offset[1] + target_width > input_width:
                raise ValueError(
                    "Width index out of range: " + str(self.offset[1] + target_width)
                )

            return inputs[
                :,
                :,
                self.offset[0] : self.offset[0] + target_height,
                self.offset[1] : self.offset[1] + target_width,
            ]
        elif self.data_format == "channels_last":
            input_height = input_shape[1]
            input_width = input_shape[2]
            target_height = self.target_shape[1]
            target_width = self.target_shape[2]
            if target_height > input_height or target_width > input_width:
                raise ValueError(
                    "The target shape needs to be smaller than "
                    "or equal to the Tensor being cropped."
                )

            if self.offset == "centered":
                self.offset = [
                    int((input_height - target_height) / 2),
                    int((input_width - target_width) / 2),
                ]

            if self.offset[0] + target_height > input_height:
                raise ValueError(
                    "Height index out of range: " + str(self.offset[0] + target_height)
                )
            if self.offset[1] + target_width > input_width:
                raise ValueError(
                    "Width index out of range: " + str(self.offset[1] + target_width)
                )
            output = inputs[
                :,
                self.offset[0] : self.offset[0] + target_height,
                self.offset[1] : self.offset[1] + target_width,
                :,
            ]
            return output
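

# Example of the centered-crop arithmetic above (hypothetical shapes,
# channels_last): cropping a (None, 70, 70, 21) tensor to a (None, 64, 64, 21)
# target gives offset = [(70 - 64) // 2, (70 - 64) // 2] = [3, 3], so the
# layer returns inputs[:, 3:67, 3:67, :].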


class BilinearUpSampling2D(Layer):
    def __init__(self, target_shape=None, factor=None, data_format=None, **kwargs):
        # compute data format
        if data_format is None:
            data_format = K.image_data_format()
        assert data_format in {"channels_last", "channels_first"}

        self.data_format = data_format
        self.input_spec = [InputSpec(ndim=4)]
        self.target_shape = target_shape
        self.factor = factor
        if self.data_format == "channels_first":
            self.target_size = (target_shape[2], target_shape[3])
        elif self.data_format == "channels_last":
            self.target_size = (target_shape[1], target_shape[2])
        super(BilinearUpSampling2D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        if self.data_format == "channels_last":
            return (
                input_shape[0],
                self.target_size[0],
                self.target_size[1],
                input_shape[3],
            )
        else:
            return (
                input_shape[0],
                input_shape[1],
                self.target_size[0],
                self.target_size[1],
            )

    def call(self, inputs):
        return K.resize_images(inputs, self.factor, self.factor, self.data_format)

    def get_config(self):
        config = {"target_shape": self.target_shape, "data_format": self.data_format}
        base_config = super(BilinearUpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def identity_block(
    input_tensor,
    kernel_size,
    filters,
    stage,
    block,
    dilation_rate=1,
    multigrid=[1, 2, 1],
    use_se=True,
):
    # conv filters
    filters1, filters2, filters3 = filters

    # compute data format
    if K.image_data_format() == "channels_last":
        bn_axis = 3
    else:
        bn_axis = 1

    # layer names
    conv_name_base = "res" + str(stage) + block + "_branch"
    bn_name_base = "bn" + str(stage) + block + "_branch"

    # dilation rate: without dilation, fall back to a uniform multigrid
    if dilation_rate < 2:
        multigrid = [1, 1, 1]

    # forward
    x = Conv2D(
        filters1,
        (1, 1),
        name=conv_name_base + "2a",
        dilation_rate=dilation_rate * multigrid[0],
    )(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + "2a")(x)
    x = Activation("relu")(x)

    x = Conv2D(
        filters2,
        kernel_size,
        padding="same",
        name=conv_name_base + "2b",
        dilation_rate=dilation_rate * multigrid[1],
    )(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + "2b")(x)
    x = Activation("relu")(x)

    x = Conv2D(
        filters3,
        (1, 1),
        name=conv_name_base + "2c",
        dilation_rate=dilation_rate * multigrid[2],
    )(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + "2c")(x)

    # squeeze-and-excitation layer (skipped from stage 5 onwards)
    if use_se and stage < 5:
        se = _squeeze_excite_block(x, filters3, k=1, name=conv_name_base + "_se")
        x = multiply([x, se])
    x = add([x, input_tensor])
    x = Activation("relu")(x)

    return x
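

# Example of the multigrid scaling above: with dilation_rate=2 and
# multigrid=[1, 2, 1], the three convolutions of the block use dilation rates
# 2, 4 and 2; with dilation_rate < 2 the multigrid collapses to [1, 1, 1],
# i.e. plain convolutions.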


def _conv(**conv_params):
    # conv params
    filters = conv_params["filters"]
    kernel_size = conv_params["kernel_size"]
    strides = conv_params.setdefault("strides", (1, 1))
    dilation_rate = conv_params.setdefault("dilation_rate", (1, 1))
    kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal")
    padding = conv_params.setdefault("padding", "same")
    # callers label the convolution via `block` or `name`; use it as the layer name
    name = conv_params.get("name", conv_params.get("block"))

    def f(input):
        conv = Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            dilation_rate=dilation_rate,
            kernel_initializer=kernel_initializer,
            activation="linear",
            name=name,
        )(input)
        return conv

    return f


# Atrous Spatial Pyramid Pooling block
def aspp_block(
    x, num_filters=256, rate_scale=1, output_stride=16, input_shape=(512, 512, 3)
):
    # compute data format
    if K.image_data_format() == "channels_last":
        bn_axis = 3
    else:
        bn_axis = 1

    # forward
    conv3_3_1 = ZeroPadding2D(padding=(6 * rate_scale, 6 * rate_scale))(x)
    conv3_3_1 = _conv(
        filters=num_filters,
        kernel_size=(3, 3),
        dilation_rate=(6 * rate_scale, 6 * rate_scale),
        padding="valid",
        block="assp_3_3_1_%s" % output_stride,
    )(conv3_3_1)
    conv3_3_1 = BatchNormalization(axis=bn_axis, name="bn_3_3_1_%s" % output_stride)(
        conv3_3_1
    )

    conv3_3_2 = ZeroPadding2D(padding=(12 * rate_scale, 12 * rate_scale))(x)
    conv3_3_2 = _conv(
        filters=num_filters,
        kernel_size=(3, 3),
        dilation_rate=(12 * rate_scale, 12 * rate_scale),
        padding="valid",
        block="assp_3_3_2_%s" % output_stride,
    )(conv3_3_2)
    conv3_3_2 = BatchNormalization(axis=bn_axis, name="bn_3_3_2_%s" % output_stride)(
        conv3_3_2
    )

    conv3_3_3 = ZeroPadding2D(padding=(18 * rate_scale, 18 * rate_scale))(x)
    conv3_3_3 = _conv(
        filters=num_filters,
        kernel_size=(3, 3),
        dilation_rate=(18 * rate_scale, 18 * rate_scale),
        padding="valid",
        block="assp_3_3_3_%s" % output_stride,
    )(conv3_3_3)
    conv3_3_3 = BatchNormalization(axis=bn_axis, name="bn_3_3_3_%s" % output_stride)(
        conv3_3_3
    )

    conv1_1 = _conv(
        filters=num_filters,
        kernel_size=(1, 1),
        padding="same",
        block="assp_1_1_%s" % output_stride,
    )(x)
    conv1_1 = BatchNormalization(axis=bn_axis, name="bn_1_1_%s" % output_stride)(
        conv1_1
    )

    # channel merge (Keras 2: concatenate replaces the old Keras 1 `merge`)
    y = concatenate(
        [conv3_3_1, conv3_3_2, conv3_3_3, conv1_1],
        # global_feat,
        axis=bn_axis,
    )

    # y = _conv_bn_relu(filters=1, kernel_size=(1, 1), padding='same')(y)
    y = _conv(
        filters=256,
        kernel_size=(1, 1),
        padding="same",
        block="assp_out_%s" % output_stride,
    )(y)
    y = BatchNormalization(axis=bn_axis, name="bn_out_%s" % output_stride)(y)

    return y


# residual module
def conv_block(
    input_tensor,
    kernel_size,
    filters,
    stage,
    block,
    strides=(2, 2),
    dilation_rate=1,
    multigrid=[1, 2, 1],
    use_se=True,
):
    # conv filters
    filters1, filters2, filters3 = filters

    # compute data format
    if K.image_data_format() == "channels_last":
        bn_axis = 3
    else:
        bn_axis = 1
    conv_name_base = "res" + str(stage) + block + "_branch"
    bn_name_base = "bn" + str(stage) + block + "_branch"

    # dilation rate: dilated blocks keep stride 1; undilated ones use a uniform multigrid
    if dilation_rate > 1:
        strides = (1, 1)
    else:
        multigrid = [1, 1, 1]

    # forward
    x = Conv2D(
        filters1,
        (1, 1),
        strides=strides,
        name=conv_name_base + "2a",
        dilation_rate=dilation_rate * multigrid[0],
    )(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + "2a")(x)
    x = Activation("relu")(x)

    x = Conv2D(
        filters2,
        kernel_size,
        padding="same",
        name=conv_name_base + "2b",
        dilation_rate=dilation_rate * multigrid[1],
    )(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + "2b")(x)
    x = Activation("relu")(x)

    x = Conv2D(
        filters3,
        (1, 1),
        name=conv_name_base + "2c",
        dilation_rate=dilation_rate * multigrid[2],
    )(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + "2c")(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=conv_name_base + "1")(
        input_tensor
    )
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + "1")(shortcut)

    # squeeze-and-excitation layer (skipped from stage 5 onwards)
    if use_se and stage < 5:
        se = _squeeze_excite_block(x, filters3, k=1, name=conv_name_base + "_se")
        x = multiply([x, se])
    x = add([x, shortcut])
    x = Activation("relu")(x)

    return x
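

# Note on the stride/dilation switch in conv_block: once dilation_rate > 1 the
# block keeps strides at (1, 1), so the feature map stops shrinking and the
# receptive field grows through dilation instead of downsampling; this is what
# holds the backbone at output_stride 8 or 16.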


def duc(x, factor=8, output_shape=(512, 512, 1)):
    if K.image_data_format() == "channels_last":
        bn_axis = 3
    else:
        bn_axis = 1
    H, W, c, r = output_shape[0], output_shape[1], output_shape[2], factor
    # integer division: Reshape needs integer target dimensions
    h = H // r
    w = W // r
    x = Conv2D(c * r * r, (3, 3), padding="same", name="conv_duc_%s" % factor)(x)
    x = BatchNormalization(axis=bn_axis, name="bn_duc_%s" % factor)(x)
    x = Activation("relu")(x)
    # pixel shuffle: rearrange the c*r*r channels into an r-times larger map
    x = Permute((3, 1, 2))(x)
    x = Reshape((c, r, r, h, w))(x)
    x = Permute((1, 4, 2, 5, 3))(x)
    x = Reshape((c, H, W))(x)
    x = Permute((2, 3, 1))(x)

    return x


def Interp(x, shape):
    """ interpolation """
    from keras.backend import tf as ktf

    new_height, new_width = shape
    resized = ktf.image.resize_images(
        x, [int(new_height), int(new_width)], align_corners=True
    )
    return resized


def interp_block(
    x, num_filters=512, level=1, input_shape=(512, 512, 3), output_stride=16
):
    """ interpolation block """
    feature_map_shape = (
        input_shape[0] // output_stride,
        input_shape[1] // output_stride,
    )

    # compute data format
    if K.image_data_format() == "channels_last":
        bn_axis = 3
    else:
        bn_axis = 1

    if output_stride == 16:
        scale = 5
    elif output_stride == 8:
        scale = 10

    kernel = (level * scale, level * scale)
    strides = (level * scale, level * scale)
    global_feat = AveragePooling2D(
        kernel, strides=strides, name="pool_level_%s_%s" % (level, output_stride)
    )(x)
    global_feat = _conv(
        filters=num_filters,
        kernel_size=(1, 1),
        padding="same",
        name="conv_level_%s_%s" % (level, output_stride),
    )(global_feat)
    global_feat = BatchNormalization(
        axis=bn_axis, name="bn_level_%s_%s" % (level, output_stride)
    )(global_feat)
    global_feat = Lambda(Interp, arguments={"shape": feature_map_shape})(global_feat)

    return global_feat
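

# Example of the pooling geometry above (assuming input_shape=(512, 512, 3)
# and output_stride=16): the backbone feature map is 32x32 and scale=5, so
# levels [6, 3, 2, 1] average-pool with windows 30, 15, 10 and 5, and Interp
# resizes each result back to 32x32 for concatenation.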


# squeeze and excitation function
def _squeeze_excite_block(input, filters, k=1, name=None):
    init = input
    if K.image_data_format() == "channels_last":
        se_shape = (1, 1, filters * k)
    else:
        se_shape = (filters * k, 1, 1)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(
        (filters * k) // 16,
        activation="relu",
        kernel_initializer="he_normal",
        use_bias=False,
        name=name + "_fc1",
    )(se)
    se = Dense(
        filters * k,
        activation="sigmoid",
        kernel_initializer="he_normal",
        use_bias=False,
        name=name + "_fc2",
    )(se)
    return se


def pyramid_pooling_module(
    x, num_filters=512, input_shape=(512, 512, 3), output_stride=16, levels=[6, 3, 2, 1]
):
    """ pyramid pooling function """

    # compute data format
    if K.image_data_format() == "channels_last":
        bn_axis = 3
    else:
        bn_axis = 1

    pyramid_pooling_blocks = [x]
    for level in levels:
        pyramid_pooling_blocks.append(
            interp_block(
                x,
                num_filters=num_filters,
                level=level,
                input_shape=input_shape,
                output_stride=output_stride,
            )
        )

    y = concatenate(pyramid_pooling_blocks)
    # y = merge(pyramid_pooling_blocks, mode='concat', concat_axis=3)
    y = _conv(
        filters=num_filters,
        kernel_size=(3, 3),
        padding="same",
        block="pyramid_out_%s" % output_stride,
    )(y)
    y = BatchNormalization(axis=bn_axis, name="bn_pyramid_out_%s" % output_stride)(y)
    y = Activation("relu")(y)

    return y


def crop_deconv(
    classes,
    scale=1,
    kernel_size=(4, 4),
    strides=(2, 2),
    crop_offset="centered",
    weight_decay=0.0,
    block_name="featx",
):
    def f(x, y):
        def scaling(xx, ss=1):
            return xx * ss

        scaled = Lambda(
            scaling, arguments={"ss": scale}, name="scale_{}".format(block_name)
        )(x)
        score = Conv2D(
            filters=classes,
            kernel_size=(1, 1),
            activation="linear",
            kernel_initializer="he_normal",
            name="score_{}".format(block_name),
        )(scaled)

        if y is None:
            upscore = Conv2DTranspose(
                filters=classes,
                kernel_size=kernel_size,
                strides=strides,
                padding="valid",
                kernel_initializer="he_normal",
                use_bias=False,
                name="upscore_{}".format(block_name),
            )(score)
        else:
            crop = CroppingLike2D(
                target_shape=K.int_shape(y),
                offset=crop_offset,
                name="crop_{}".format(block_name),
            )(score)
            merged = add([y, crop])
            upscore = Conv2DTranspose(
                filters=classes,
                kernel_size=kernel_size,
                strides=strides,
                padding="valid",
                kernel_initializer="he_normal",
                use_bias=False,
                name="upscore_{}".format(block_name),
            )(merged)
        return upscore

    return f
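

# Usage sketch for crop_deconv (illustrative; `pool5`/`pool4` are hypothetical
# feature maps from an FCN-style backbone):
#
#     up = crop_deconv(classes=21, block_name="feat5")(pool5, None)  # score, 2x upsample
#     up = crop_deconv(classes=21, block_name="feat4")(pool4, up)    # score, fuse, upsample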


def PSPNet50(
    input_shape=(512, 512, 3),
    n_labels=20,
    output_stride=16,
    num_blocks=4,
    multigrid=[1, 1, 1],
    levels=[6, 3, 2, 1],
    use_se=True,
    output_mode="softmax",
    upsample_type="deconv",
):
    # Input shape
    img_input = Input(shape=input_shape)

    # compute data format
    if K.image_data_format() == "channels_last":
        bn_axis = 3
    else:
        bn_axis = 1

    x = Conv2D(64, (7, 7), strides=(2, 2), padding="same", name="conv1")(img_input)
    x = BatchNormalization(axis=bn_axis, name="bn_conv1")(x)
    x = Activation("relu")(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(
        x, 3, [64, 64, 256], stage=2, block="a", strides=(1, 1), use_se=use_se
    )
    x = identity_block(x, 3, [64, 64, 256], stage=2, block="b", use_se=use_se)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block="c", use_se=use_se)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block="a", use_se=use_se)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block="b", use_se=use_se)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block="c", use_se=use_se)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block="d", use_se=use_se)

    if output_stride == 8:
        rate_scale = 2
    elif output_stride == 16:
        rate_scale = 1

    x = conv_block(
        x,
        3,
        [256, 256, 1024],
        stage=4,
        block="a",
        dilation_rate=1 * rate_scale,
        multigrid=multigrid,
        use_se=use_se,
    )
    x = identity_block(
        x,
        3,
        [256, 256, 1024],
        stage=4,
        block="b",
        dilation_rate=1 * rate_scale,
        multigrid=multigrid,
        use_se=use_se,
    )
    x = identity_block(
        x,
        3,
        [256, 256, 1024],
        stage=4,
        block="c",
        dilation_rate=1 * rate_scale,
        multigrid=multigrid,
        use_se=use_se,
    )
    x = identity_block(
        x,
        3,
        [256, 256, 1024],
        stage=4,
        block="d",
        dilation_rate=1 * rate_scale,
        multigrid=multigrid,
        use_se=use_se,
    )
    x = identity_block(
        x,
        3,
        [256, 256, 1024],
        stage=4,
        block="e",
        dilation_rate=1 * rate_scale,
        multigrid=multigrid,
        use_se=use_se,
    )
    x = identity_block(
        x,
        3,
        [256, 256, 1024],
        stage=4,
        block="f",
        dilation_rate=1 * rate_scale,
        multigrid=multigrid,
        use_se=use_se,
    )

    init_rate = 2
    for block in range(4, num_blocks + 1):
        if block == 4:
            block = ""
        x = conv_block(
            x,
            3,
            [512, 512, 2048],
            stage=5,
            block="a%s" % block,
            dilation_rate=init_rate * rate_scale,
            multigrid=multigrid,
            use_se=use_se,
        )
        x = identity_block(
            x,
            3,
            [512, 512, 2048],
            stage=5,
            block="b%s" % block,
            dilation_rate=init_rate * rate_scale,
            multigrid=multigrid,
            use_se=use_se,
        )
        x = identity_block(
            x,
            3,
            [512, 512, 2048],
            stage=5,
            block="c%s" % block,
            dilation_rate=init_rate * rate_scale,
            multigrid=multigrid,
            use_se=use_se,
        )
        init_rate *= 2

    x = pyramid_pooling_module(
        x,
        num_filters=512,
        input_shape=input_shape,
        output_stride=output_stride,
        levels=levels,
    )

    # x = merge([
    #     x1,
    #     x2,
    # ], mode='concat', concat_axis=3)

    # upsample_type
    if upsample_type == "duc":
        x = duc(
            x,
            factor=output_stride,
            output_shape=(input_shape[0], input_shape[1], n_labels),
        )
        out = _conv(
            filters=n_labels,
            kernel_size=(1, 1),
            padding="same",
            block="out_duc_%s" % output_stride,
        )(x)

    elif upsample_type == "bilinear":
        x = _conv(
            filters=n_labels,
            kernel_size=(1, 1),
            padding="same",
            block="out_bilinear_%s" % output_stride,
        )(x)
        out = BilinearUpSampling2D(
            (n_labels, input_shape[0], input_shape[1]), factor=output_stride
        )(x)

    elif upsample_type == "deconv":
        out = Conv2DTranspose(
            filters=n_labels,
            kernel_size=(output_stride * 2, output_stride * 2),
            strides=(output_stride, output_stride),
            padding="same",
            kernel_initializer="he_normal",
            kernel_regularizer=None,
            use_bias=False,
            name="upscore_{}".format("out"),
        )(x)

    # flatten the spatial dimensions for the per-pixel softmax
    out = Reshape(
        (input_shape[0] * input_shape[1], n_labels),
        input_shape=(input_shape[0], input_shape[1], n_labels),
    )(out)
    # default "softmax"
    out = Activation(output_mode)(out)

    model = Model(inputs=img_input, outputs=out)

    return model
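

# Quick smoke test (illustrative, not part of the training pipeline): build
# the model and check that 20 labels on a 512x512 input yield a flattened
# (None, 512 * 512, 20) softmax.
if __name__ == "__main__":
    model = PSPNet50(input_shape=(512, 512, 3), n_labels=20)
    print(model.output_shape)  # (None, 262144, 20)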
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os

import keras.backend.tensorflow_backend as KTF
import pandas as pd
import tensorflow as tf
from generator import data_gen_small
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from model import PSPNet50


def argparser():
    # command line arguments
    parser = argparse.ArgumentParser(description="PSPNet LIP dataset")
    parser.add_argument("--train_list", help="train list path")
    parser.add_argument("--trainimg_dir", help="train image dir path")
    parser.add_argument("--trainmsk_dir", help="train mask dir path")
    parser.add_argument("--val_list", help="val list path")
    parser.add_argument("--valimg_dir", help="val image dir path")
    parser.add_argument("--valmsk_dir", help="val mask dir path")
    parser.add_argument("--batch_size", default=5, type=int, help="batch size")
    parser.add_argument("--n_epochs", default=10, type=int, help="number of epochs")
    parser.add_argument(
        "--epoch_steps", default=6000, type=int, help="number of steps per epoch"
    )
    parser.add_argument(
        "--val_steps", default=1000, type=int, help="number of validation steps"
    )
    parser.add_argument("--n_labels", default=20, type=int, help="number of labels")
    parser.add_argument(
        "--input_shape", default=(512, 512, 3), help="Input images shape"
    )
    parser.add_argument("--output_stride", default=16, type=int, help="output stride")
    parser.add_argument(
        "--output_mode", default="softmax", type=str, help="output activation"
    )
    parser.add_argument(
        "--upsample_type", default="deconv", type=str, help="upsampling type"
    )
    parser.add_argument(
        "--loss", default="categorical_crossentropy", type=str, help="loss function"
    )
    parser.add_argument("--optimizer", default="adadelta", type=str, help="optimizer")
    # main() reads args.log_dir for checkpoints and TensorBoard, so expose it here
    parser.add_argument("--log_dir", default="./logs", type=str, help="log dir path")
    parser.add_argument("--gpu", default="0", type=str, help="GPU device number")
    args = parser.parse_args()

    return args
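

# Example invocation (all paths are placeholders):
#   python train.py --train_list train.csv --trainimg_dir imgs/train \
#       --trainmsk_dir masks/train --val_list val.csv --valimg_dir imgs/val \
#       --valmsk_dir masks/val --batch_size 5 --n_epochs 10 --gpu 0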


def main(args):
    # device number
    if args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # set the necessary lists
    train_list = pd.read_csv(args.train_list, header=None)
    val_list = pd.read_csv(args.val_list, header=None)

    # set the necessary directories
    trainimg_dir = args.trainimg_dir
    trainmsk_dir = args.trainmsk_dir
    valimg_dir = args.valimg_dir
    valmsk_dir = args.valmsk_dir

    # get old session
    old_session = KTF.get_session()

    with tf.Graph().as_default():
        session = tf.Session("")
        KTF.set_session(session)
        KTF.set_learning_phase(1)

        # set callbacks
        cp_cb = ModelCheckpoint(
            filepath=args.log_dir,
            monitor="val_loss",
            verbose=1,
            save_best_only=True,
            mode="auto",
            period=2,
        )
        es_cb = EarlyStopping(monitor="val_loss", patience=2, verbose=1, mode="auto")
        tb_cb = TensorBoard(log_dir=args.log_dir, write_images=True)

        # set generators
        train_gen = data_gen_small(
            trainimg_dir,
            trainmsk_dir,
            train_list,
            args.batch_size,
            [args.input_shape[0], args.input_shape[1]],
            args.n_labels,
        )
        val_gen = data_gen_small(
            valimg_dir,
            valmsk_dir,
            val_list,
            args.batch_size,
            [args.input_shape[0], args.input_shape[1]],
            args.n_labels,
        )

        # set model
        pspnet = PSPNet50(
            input_shape=args.input_shape,
            n_labels=args.n_labels,
            output_mode=args.output_mode,
            upsample_type=args.upsample_type,
        )
        pspnet.summary()

        # compile model
        pspnet.compile(loss=args.loss, optimizer=args.optimizer, metrics=["accuracy"])

        # fit with generator
        pspnet.fit_generator(
            generator=train_gen,
            steps_per_epoch=args.epoch_steps,
            epochs=args.n_epochs,
            validation_data=val_gen,
            validation_steps=args.val_steps,
            callbacks=[cp_cb, es_cb, tb_cb],
        )


if __name__ == "__main__":
    args = argparser()
    main(args)
--------------------------------------------------------------------------------