├── .gitignore ├── cityscapes_train.sh ├── images ├── SegmentationExample.png └── ThisSegmentationDoesNotExist.png ├── layers.py ├── models.py ├── notebooks ├── Enet CamVid Training.ipynb └── Enet FaceSegmentation Inference.ipynb ├── readme.md ├── run.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.homer 2 | *.pyc 3 | MNIST-data 4 | .vscode 5 | .ropeproject 6 | .ipynb_checkpoints 7 | __pycache__ 8 | notebooks/.ipynb_checkpoints 9 | notebooks/logs 10 | logs 11 | enet_tensorflow_debug.py 12 | EnetPaper.pdf 13 | dataset/ 14 | EnetTensorflow.code-workspace 15 | datasets/ 16 | notebooks/Enet FaceSegmentation Training.ipynb 17 | *.cache.* 18 | dataset 19 | EnetTensorflow.code-workspace 20 | datasets 21 | train.cache.data-00000-of-00001 22 | train.cache.index 23 | val.cache.data-00000-of-00001 24 | val.cache.index 25 | notebooks/Enet FaceSegmentation Training.ipynb 26 | enet_tensorflow_debug.py 27 | -------------------------------------------------------------------------------- /cityscapes_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | IW=512 3 | IH=1024 4 | BS=16 5 | EP=10 6 | NC=35 7 | IPTR="./datasets/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train/*/*.png" 8 | LPTR="./datasets/cityscapes/gtFine_trainvaltest/gtFine/train/*/*labelIds*.png" 9 | IPV="./datasets/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/val/*/*.png" 10 | LPV="./datasets/cityscapes/gtFine_trainvaltest/gtFine/val/*/*labelIds*.png" 11 | echo "$IPTR" 12 | python3 run.py -iw $IW -ih $IH -bs $BS -e $EP -nc $NC -iptr "$IPTR" -lptr "$LPTR" -ipv "$IPV" -lpv "$LPV" -------------------------------------------------------------------------------- /images/SegmentationExample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevero/enet_tensorflow/897d0c052fb16b8a0dcc299da5fb7eed673a57e7/images/SegmentationExample.png -------------------------------------------------------------------------------- /images/ThisSegmentationDoesNotExist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevero/enet_tensorflow/897d0c052fb16b8a0dcc299da5fb7eed673a57e7/images/ThisSegmentationDoesNotExist.png -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.keras.utils import conv_utils 3 | 4 | # Tensorflow implementations of max_pooling and unpooling 5 | 6 | 7 | # Keras layers for pooling and unpooling 8 | class MaxPoolWithArgmax2D(tf.keras.layers.Layer): 9 | """2D Pooling layer with pooling indices. 10 | 11 | Arguments 12 | ---------- 13 | 'pool_size' = An integer or tuple/list of 2 integers: 14 | (pool_height, pool_width) specifying the size of the 15 | pooling window. Can be a single integer to specify 16 | the same value for all spatial dimensions. 17 | 'strides' = An integer or tuple/list of 2 integers, 18 | specifying the strides of the pooling operation. 19 | Can be a single integer to specify the same value for 20 | all spatial dimensions. 21 | 'padding' = A string. The padding method, either 'valid' or 'same'. 22 | Case-insensitive. 23 | 'data_format' = A string, one of `channels_last` (default) 24 | or `channels_first`. The ordering of the dimensions in the inputs. 
25 |         `channels_last` corresponds to inputs with shape
26 |         `(batch, height, width, channels)` while `channels_first` corresponds
27 |         to inputs with shape `(batch, channels, height, width)`.
28 |     'name' = A string, the name of the layer.
29 |     """
30 |     def __init__(self,
31 |                  pool_size,
32 |                  strides,
33 |                  padding='valid',
34 |                  data_format=None,
35 |                  name=None,
36 |                  **kwargs):
37 |         super(MaxPoolWithArgmax2D, self).__init__(name=name, **kwargs)
38 |         if data_format is None:
39 |             data_format = tf.keras.backend.image_data_format()
40 |         if strides is None:
41 |             strides = pool_size
42 |         self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
43 |         self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
44 |         self.padding = conv_utils.normalize_padding(padding)
45 |         self.data_format = conv_utils.normalize_data_format(data_format)
46 |         self.input_spec = tf.keras.layers.InputSpec(ndim=4)
47 | 
48 |     def call(self, inputs):
49 | 
50 |         pool_shape = (1, ) + self.pool_size + (1, )
51 |         strides = (1, ) + self.strides + (1, )
52 | 
53 |         if self.data_format == 'channels_last':
54 |             outputs, argmax = tf.nn.max_pool_with_argmax(
55 |                 inputs,
56 |                 ksize=pool_shape,
57 |                 strides=strides,
58 |                 padding=self.padding.upper())
59 |             return (outputs, argmax)
60 |         else:
61 |             outputs, argmax = tf.nn.max_pool_with_argmax(
62 |                 tf.transpose(inputs, perm=[0, 2, 3, 1]),
63 |                 ksize=pool_shape,
64 |                 strides=strides,
65 |                 padding=self.padding.upper())
66 |             return (tf.transpose(outputs, perm=[0, 3, 1, 2]),
67 |                     tf.transpose(argmax, perm=[0, 3, 1, 2]))
68 | 
69 |     def compute_output_shape(self, input_shape):
70 |         input_shape = tf.TensorShape(input_shape).as_list()
71 |         if self.data_format == 'channels_first':
72 |             rows = input_shape[2]
73 |             cols = input_shape[3]
74 |         else:
75 |             rows = input_shape[1]
76 |             cols = input_shape[2]
77 |         rows = conv_utils.conv_output_length(rows, self.pool_size[0],
78 |                                              self.padding, self.strides[0])
79 |         cols = conv_utils.conv_output_length(cols, self.pool_size[1],
80 |                                              self.padding, self.strides[1])
81 |         if self.data_format == 'channels_first':
82 |             return tf.TensorShape([input_shape[0], input_shape[1], rows, cols])
83 |         else:
84 |             return tf.TensorShape([input_shape[0], rows, cols, input_shape[3]])
85 | 
86 |     def get_config(self):
87 |         config = {
88 |             'pool_size': self.pool_size,
89 |             'padding': self.padding,
90 |             'strides': self.strides,
91 |             'data_format': self.data_format
92 |         }
93 |         base_config = super(MaxPoolWithArgmax2D, self).get_config()
94 |         return dict(list(base_config.items()) + list(config.items()))
95 | 
96 | 
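# ---------------------------------------------------------------------------
# Added usage sketch (illustration only, not part of the original file): how
# the pooling indices are laid out. With the default
# include_batch_in_index=False, tf.nn.max_pool_with_argmax encodes the
# position of each maximum within its batch element as
# (y * width + x) * channels + c; MaxUnpool2D below adds the batch offset
# itself through its batch_range tensor. Nothing runs at import time.
def _demo_max_pool_with_argmax():
    x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
    outputs, argmax = MaxPoolWithArgmax2D(pool_size=2, strides=2)(x)
    print(outputs.numpy().ravel())  # maxima of the four 2x2 windows: [ 5.  7. 13. 15.]
    print(argmax.numpy().ravel())   # their flattened positions:      [ 5  7 13 15]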
97 | class MaxUnpool2D(tf.keras.layers.Layer):
98 |     def __init__(self, data_format='channels_last', name=None, **kwargs):
99 |         super(MaxUnpool2D, self).__init__(name=name, **kwargs)
100 |         if data_format is None:
101 |             data_format = tf.keras.backend.image_data_format()
102 |         self.data_format = conv_utils.normalize_data_format(data_format)
103 |         self.input_spec = tf.keras.layers.InputSpec(min_ndim=2, max_ndim=4)
104 | 
105 |     def call(self, inputs, argmax, spatial_output_shape):
106 | 
107 |         # standardize spatial_output_shape
108 |         spatial_output_shape = conv_utils.normalize_tuple(
109 |             spatial_output_shape, 2, 'spatial_output_shape')
110 | 
111 |         # getting input shape
112 |         # input_shape = tf.shape(inputs)
113 |         # input_shape = inputs.get_shape().as_list()
114 |         input_shape = tf.shape(inputs)
115 | 
116 |         # checking if spatial shape is ok
117 |         if self.data_format == 'channels_last':
118 |             output_shape = (input_shape[0],) + \
119 |                 spatial_output_shape + (input_shape[3],)
120 | 
121 |             # assert output_shape[1] * output_shape[2] * output_shape[
122 |             #     3] > tf.math.reduce_max(argmax).numpy(), "HxWxC <= Max(argmax)"
123 |         else:
124 |             output_shape = (input_shape[0],
125 |                             input_shape[1]) + spatial_output_shape
126 |             # assert output_shape[1] * output_shape[2] * output_shape[
127 |             #     3] > tf.math.reduce_max(argmax).numpy(), "CxHxW <= Max(argmax)"
128 | 
129 |         # N * H_in * W_in * C
130 |         # flat_input_size = tf.reduce_prod(input_shape)
131 |         flat_input_size = tf.reduce_prod(input_shape)
132 | 
133 |         # flat output_shape = [N, H_out * W_out * C]
134 |         flat_output_shape = [
135 |             output_shape[0],
136 |             output_shape[1] * output_shape[2] * output_shape[3]
137 |         ]
138 | 
139 |         # flatten input tensor for the use in tf.scatter_nd
140 |         inputs_ = tf.reshape(inputs, [flat_input_size])
141 | 
142 |         # create the tensor [ [[[0]]], [[[1]]], ..., [[[N-1]]] ]
143 |         # corresponding to the batch size but transposed in 4D
144 |         batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64),
145 |                                           dtype=argmax.dtype),
146 |                                  shape=[input_shape[0], 1, 1, 1])
147 | 
148 |         # b is a tensor of size (N, H, W, C) or (N, C, H, W) whose
149 |         # first element of the batch are 3D-array full of 0, ...
150 |         # second element of the batch are 3D-array full of 1, ...
151 |         b = tf.ones_like(argmax) * batch_range
152 |         b = tf.reshape(b, [flat_input_size, 1])
153 | 
154 |         # argmax_ = [ [0, argmax_1], [0, argmax_2], ... [0, argmax_k], ...,
155 |         #             [N-1, argmax_{N*H*W*C}], [N-1, argmax_{N*H*W*C-1}] ]
156 |         argmax_ = tf.reshape(argmax, [flat_input_size, 1])
157 |         argmax_ = tf.concat([b, argmax_], axis=-1)
158 | 
159 |         # reshaping output tensor
160 |         ret = tf.scatter_nd(argmax_,
161 |                             inputs_,
162 |                             shape=tf.cast(flat_output_shape, tf.int64))
163 |         ret = tf.reshape(ret, output_shape)
164 | 
165 |         return ret
166 | 
167 |     def compute_output_shape(self, input_shape, spatial_output_shape):
168 | 
169 |         # getting input shape
170 |         input_shape = tf.TensorShape(input_shape).as_list()
171 | 
172 |         # standardize spatial_output_shape
173 |         spatial_output_shape = conv_utils.normalize_tuple(
174 |             spatial_output_shape, 2, 'spatial_output_shape')
175 | 
176 |         # checking if spatial shape is ok
177 |         if self.data_format == 'channels_last':
178 |             output_shape = (input_shape[0],) + \
179 |                 spatial_output_shape + (input_shape[3],)
180 |             # assert output_shape[1] * output_shape[2] > tf.math.reduce_max(
181 |             #     self.argmax).numpy(), "HxW <= Max(argmax)"
182 |         else:
183 |             output_shape = (input_shape[0],
184 |                             input_shape[1]) + spatial_output_shape
185 |             # assert output_shape[2] * output_shape[3] > tf.math.reduce_max(
186 |             #     self.argmax).numpy(), "HxW <= Max(argmax)"
187 | 
188 |         return output_shape
189 | 
190 |     def get_config(self):
191 |         config = {
192 |             # spatial_output_shape is passed at call time, so it is not stored
193 |             'data_format': self.data_format
194 |         }
195 |         base_config = super(MaxUnpool2D, self).get_config()
196 |         return dict(list(base_config.items()) + list(config.items()))
197 | 
198 | 
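# ---------------------------------------------------------------------------
# Added round-trip sketch (illustration only, not part of the original file):
# pooling with indices and unpooling back. The restored tensor has the
# requested spatial shape, with every pooled maximum scattered back to the
# position it came from and zeros everywhere else.
def _demo_pool_unpool_roundtrip():
    x = tf.random.uniform([2, 8, 8, 3])
    pooled, argmax = MaxPoolWithArgmax2D(pool_size=2, strides=2)(x)
    restored = MaxUnpool2D()(pooled, argmax, spatial_output_shape=[8, 8])
    print(pooled.shape, restored.shape)  # (2, 4, 4, 3) (2, 8, 8, 3)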
199 | class BottleNeck(tf.keras.Model):
200 |     '''
201 |     Enet bottleneck module as in:
202 |     (1) Paszke, A.; Chaurasia, A.; Kim, S.; Culurciello, E. ENet: A Deep Neural
203 |         Network Architecture for Real-Time Semantic Segmentation.
204 |         arXiv:1606.02147 [cs] 2016.
205 |     (2) https://github.com/e-lab/ENet-training/blob/master/train/models/encoder.lua
206 |     (3) https://culurciello.github.io/tech/2016/06/20/training-enet.html
207 | 
208 |     This is the general bottleneck module. It is used both in the encoding and
209 |     decoding paths, the only exception being the upsampling step of the
210 |     decoder, where we use BottleDeck.
211 | 
212 |     Arguments
213 |     ----------
214 |     'output_filters' = an `Integer`: number of output filters
215 |     'kernel_size' = a `List`: size of the kernel for the central convolution
216 |     'kernel_strides' = a `List`: length of the strides for the central conv
217 |     'padding' = a `String`: padding of the central convolution
218 |     'dilation_rate' = a `List`: dilation rate of the central convolution
219 |     'internal_comp_ratio' = an `Integer`: compression ratio of the bottleneck
220 |     'dropout_prob' = a `float`: dropout at the end of the main connection
221 |     'downsample' = a `Boolean`: downsampling flag
222 |     'name' = a `String`: name of the bottleneck
223 | 
224 |     Returns
225 |     -------
226 |     'output_layer' = A `Tensor` with the same type as `input_layer`
227 |     '''
228 |     def __init__(self,
229 |                  output_filters=128,
230 |                  kernel_size=[3, 3],
231 |                  kernel_strides=[1, 1],
232 |                  padding='same',
233 |                  dilation_rate=[1, 1],
234 |                  internal_comp_ratio=4,
235 |                  dropout_prob=0.1,
236 |                  l2=0.0,
237 |                  downsample=False,
238 |                  name='BottleEnc',
239 |                  **kwargs):
240 |         super(BottleNeck, self).__init__(name=name, **kwargs)
241 | 
242 |         # ------- bottleneck parameters -------
243 |         self.output_filters = output_filters
244 |         self.kernel_size = kernel_size
245 |         self.kernel_strides = kernel_strides
246 |         self.padding = padding
247 |         self.dilation_rate = dilation_rate
248 |         self.internal_comp_ratio = internal_comp_ratio
249 |         self.dropout_prob = dropout_prob
250 |         self.l2 = l2
251 |         self.downsample = downsample
252 | 
253 |         # Derived parameters
254 |         self.internal_filters = self.output_filters // self.internal_comp_ratio
255 |         if self.internal_filters == 0:
256 |             self.internal_filters = 1
257 | 
258 |         # downsampling or not
259 |         if self.downsample:
260 |             self.down_kernel = [2, 2]
261 |             self.down_strides = [2, 2]
262 |         else:
263 |             self.down_kernel = [1, 1]
264 |             self.down_strides = [1, 1]
265 | 
266 |         # ------- main connection layers -------
267 | 
268 |         # bottleneck representation compression with valid padding
269 |         # 1x1 usually, 2x2 if downsampling
270 |         self.ConvIn = tf.keras.layers.Conv2D(
271 |             self.internal_filters,
272 |             self.down_kernel,
273 |             strides=self.down_strides,
274 |             use_bias=False,
275 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
276 |             name=self.name + '.' + 'ConvIn')
277 |         self.BNormIn = tf.keras.layers.BatchNormalization(name=self.name +
278 |                                                           '.' + 'BNormIn')
279 |         self.PreLuIn = tf.keras.layers.PReLU(name=self.name + '.' + 'PreLuIn')
280 | 
281 |         # central convolution
282 |         self.asym_flag = self.kernel_size[0] != self.kernel_size[1]
283 |         self.ConvMain = tf.keras.layers.Conv2D(
284 |             self.internal_filters,
285 |             self.kernel_size,
286 |             strides=self.kernel_strides,
287 |             padding=self.padding,
288 |             dilation_rate=self.dilation_rate,
289 |             use_bias=not (self.asym_flag),
290 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
291 |             name=self.name + '.' + 'ConvMain')
292 |         if self.asym_flag:
293 |             self.ConvMainAsym = tf.keras.layers.Conv2D(
294 |                 self.internal_filters,
295 |                 self.kernel_size[::-1],
296 |                 strides=self.kernel_strides,
297 |                 padding=self.padding,
298 |                 dilation_rate=self.dilation_rate,
299 |                 kernel_regularizer=tf.keras.regularizers.l2(l2),
300 |                 name=self.name + '.' + 'ConvMainAsym')
301 |         self.BNormMain = tf.keras.layers.BatchNormalization(name=self.name +
302 |                                                             '.' + 'BNormMain')
303 |         self.PreLuMain = tf.keras.layers.PReLU(name=self.name + '.' +
304 |                                                'PreLuMain')
305 | 
306 |         # bottleneck representation expansion with 1x1 valid convolution
307 |         self.ConvOut = tf.keras.layers.Conv2D(
308 |             self.output_filters, [1, 1],
309 |             strides=[1, 1],
310 |             use_bias=False,
311 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
312 |             name=self.name + '.' + 'ConvOut')
313 |         self.BNormOut = tf.keras.layers.BatchNormalization(name=self.name +
314 |                                                            '.' + 'BNormOut')
315 |         self.DropOut = tf.keras.layers.SpatialDropout2D(dropout_prob,
316 |                                                         name=self.name + '.' +
317 |                                                         'DropOut')
318 | 
319 |         # ------- skip connection layers -------
320 | 
321 |         # downsampling layer
322 |         self.ArgMaxSkip = MaxPoolWithArgmax2D(pool_size=self.down_kernel,
323 |                                               strides=self.down_strides,
324 |                                               name=self.name + '.' +
325 |                                               'ArgMaxSkip')
326 | 
327 |         # matching filter dimension with learned 1x1 convolution
328 |         # this is done differently than in vanilla enet, where
329 |         # you should just pad with zeros.
330 |         self.ConvSkip = tf.keras.layers.Conv2D(
331 |             self.output_filters,
332 |             kernel_size=[1, 1],
333 |             padding='valid',
334 |             use_bias=False,
335 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
336 |             name=name + '.' + 'ConvSkip')
337 | 
338 |         # ------- output layer -------
339 |         self.AddMainSkip = tf.keras.layers.Add(name=self.name + '.' +
340 |                                                'AddSkip')
341 |         self.PreLuMainSkip = tf.keras.layers.PReLU(name=self.name + '.' +
342 |                                                    'PreLuSkip')
343 | 
344 |     def call(self, input_layer):
345 | 
346 |         # input filter from incoming layer
347 |         input_filters = input_layer.get_shape().as_list()[-1]
348 | 
349 |         # ----- main connection ------
350 |         # Bottleneck in
351 |         main = self.ConvIn(input_layer)
352 |         main = self.BNormIn(main)
353 |         main = self.PreLuIn(main)
354 | 
355 |         # Bottleneck main
356 |         main = self.ConvMain(main)
357 |         if self.asym_flag:
358 |             main = self.ConvMainAsym(main)
359 |         main = self.BNormMain(main)
360 |         main = self.PreLuMain(main)
361 | 
362 |         # Bottleneck out
363 |         main = self.ConvOut(main)
364 |         main = self.BNormOut(main)
365 |         main = self.DropOut(main)
366 | 
367 |         # ----- skip connection ------
368 |         skip = input_layer
369 | 
370 |         # downsampling if necessary
371 |         if self.downsample:
372 |             skip, argmax = self.ArgMaxSkip(input_layer)
373 | 
374 |         # matching filter dimension with learned 1x1 convolution
375 |         # this is done differently than in vanilla enet, where
376 |         # you should just pad with zeros.
377 |         if input_filters != self.output_filters:
378 |             skip = self.ConvSkip(skip)
379 | 
380 |         # ------- output layer -------
381 |         addition_layer = self.AddMainSkip([main, skip])
382 |         output_layer = self.PreLuMainSkip(addition_layer)
383 | 
384 |         # when downsampling we also return the input layer, because the
385 |         # decoder needs it to recover the spatial shape for unpooling
386 |         if self.downsample:
387 |             return output_layer, argmax, input_layer
388 |         else:
389 |             return output_layer
390 | 
391 | 
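# ---------------------------------------------------------------------------
# Added usage sketch (illustration only, not part of the original file): a
# plain bottleneck keeps the spatial size, while a downsampling bottleneck
# halves it and additionally returns the pooling indices and its own input,
# which the decoder's BottleDeck consumes on the way back up.
def _demo_bottleneck():
    x = tf.random.uniform([1, 64, 64, 16])
    y = BottleNeck(output_filters=16, name='plain')(x)  # (1, 64, 64, 16)
    z, argmax, pre_pool = BottleNeck(output_filters=32,
                                     downsample=True,
                                     name='down')(x)    # z: (1, 32, 32, 32)
    print(y.shape, z.shape, argmax.shape, pre_pool.shape)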
392 | class BottleDeck(tf.keras.Model):
393 |     '''
394 |     Enet bottleneck module as in:
395 |     (1) Paszke, A.; Chaurasia, A.; Kim, S.; Culurciello, E. ENet: A Deep Neural
396 |         Network Architecture for Real-Time Semantic Segmentation.
397 |         arXiv:1606.02147 [cs] 2016.
398 |     (2) https://github.com/e-lab/ENet-training/blob/master/train/models/encoder.lua
399 |     (3) https://culurciello.github.io/tech/2016/06/20/training-enet.html
400 | 
401 |     This is the general bottleneck decoding module. It is used only in the
402 |     decoding path, for upsampling. In the forward pass we have
403 |     three input tensors:
404 |     - input_layer: the real input tensor
405 |     - upsample_layer: coming from the encoder path, used to get the shape of
406 |       the output tensor
407 |     - argmax: the tensor for the mapping of the upsampled values
408 | 
409 |     Arguments
410 |     ----------
411 |     'output_filters' = an `Integer`: number of output filters
412 |     'kernel_size' = a `List`: size of the kernel for the central convolution
413 |     'kernel_strides' = a `List`: length of the strides for the central conv
414 |     'padding' = a `String`: padding of the central convolution
415 |     'dilation_rate' = a `List`: dilation rate of the central convolution
416 |     'internal_comp_ratio' = an `Integer`: compression ratio of the bottleneck
417 |     'dropout_prob' = a `float`: dropout at the end of the main connection
418 |     'name' = a `String`: name of the bottleneck
419 | 
420 |     Returns
421 |     -------
422 |     'output_layer' = A `Tensor` with the same type as `input_layer`
423 |     '''
424 |     def __init__(self,
425 |                  output_filters=128,
426 |                  kernel_size=[3, 3],
427 |                  kernel_strides=[2, 2],
428 |                  padding='same',
429 |                  dilation_rate=[1, 1],
430 |                  internal_comp_ratio=4,
431 |                  dropout_prob=0.1,
432 |                  l2=0.0,
433 |                  name='BottleDeck',
434 |                  **kwargs):
435 |         super(BottleDeck, self).__init__(name=name, **kwargs)
436 | 
437 |         # ------- bottleneck parameters -------
438 |         self.output_filters = output_filters
439 |         self.kernel_size = kernel_size
440 |         self.kernel_strides = kernel_strides
441 |         self.padding = padding
442 |         self.dilation_rate = dilation_rate
443 |         self.internal_comp_ratio = internal_comp_ratio
444 |         self.dropout_prob = dropout_prob
445 |         self.l2 = l2
446 | 
447 |         # Derived parameters
448 |         self.internal_filters = self.output_filters // self.internal_comp_ratio
449 |         if self.internal_filters == 0:
450 |             self.internal_filters = 1
451 | 
452 |         # ------- main connection layers -------
453 | 
454 |         # bottleneck representation compression with valid padding
455 |         # 1x1 usually, 2x2 if downsampling
456 |         self.ConvIn = tf.keras.layers.Conv2D(
457 |             self.internal_filters,
458 |             kernel_size=[1, 1],
459 |             strides=[1, 1],
460 |             use_bias=False,
461 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
462 |             name=self.name + '.' + 'ConvIn')
463 |         self.BNormIn = tf.keras.layers.BatchNormalization(name=self.name +
464 |                                                           '.' + 'BNormIn')
465 |         self.PreLuIn = tf.keras.layers.PReLU(name=self.name + '.' + 'PreLuIn')
466 | 
467 |         # central transposed convolution, with 'same' padding by default
468 |         self.ConvMain = tf.keras.layers.Conv2DTranspose(
469 |             self.internal_filters,
470 |             self.kernel_size,
471 |             strides=self.kernel_strides,
472 |             padding=self.padding,
473 |             dilation_rate=self.dilation_rate,
474 |             use_bias=True,
475 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
476 |             name=self.name + '.' + 'ConvMain')
477 |         self.BNormMain = tf.keras.layers.BatchNormalization(name=self.name +
478 |                                                             '.' + 'BNormMain')
479 |         self.PreLuMain = tf.keras.layers.PReLU(name=self.name + '.' +
480 |                                                'PreLuMain')
481 | 
482 |         # bottleneck representation expansion with 1x1 valid convolution
483 |         self.ConvOut = tf.keras.layers.Conv2D(
484 |             self.output_filters, [1, 1],
485 |             strides=[1, 1],
486 |             use_bias=False,
487 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
488 |             name=self.name + '.' + 'ConvOut')
489 |         self.BNormOut = tf.keras.layers.BatchNormalization(name=self.name +
490 |                                                            '.' + 'BNormOut')
491 |         self.DropOut = tf.keras.layers.SpatialDropout2D(dropout_prob,
492 |                                                         name=self.name + '.' +
493 |                                                         'DropOut')
494 | 
495 |         # ------- skip connection layers -------
496 | 
497 |         # convolution for the upsampling. It comes before the
498 |         # unpooling layer.
499 |         self.ConvSkip = tf.keras.layers.Conv2D(
500 |             self.output_filters,
501 |             kernel_size=[1, 1],
502 |             padding='valid',
503 |             use_bias=False,
504 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
505 |             name=name + '.' + 'ConvSkip')
506 | 
507 |         # unpooling (upsampling) layer
508 |         self.MaxUnpoolSkip = MaxUnpool2D(name=self.name + '.' +
509 |                                          'MaxUnpoolSkip')
510 | 
511 |         # ------- output layer -------
512 |         self.AddMainSkip = tf.keras.layers.Add(name=self.name + '.' +
513 |                                                'AddMainSkip')
514 |         self.PreluMainSkip = tf.keras.layers.PReLU(name=self.name + '.' +
515 |                                                    'PreluMainSkip')
516 | 
517 |     def call(self, input_layer, argmax, upsample_layer):
518 | 
519 |         # input filter from incoming layer, and upsample layer spatial shape
520 |         input_filters = input_layer.get_shape().as_list()[-1]
521 |         upsample_layer_shape = upsample_layer.get_shape().as_list()[1:3]
522 | 
523 |         # ----- main connection ------
524 |         # Bottleneck in
525 |         main = self.ConvIn(input_layer)
526 |         main = self.BNormIn(main)
527 |         main = self.PreLuIn(main)
528 | 
529 |         # Bottleneck main
530 |         main = self.ConvMain(main)
531 |         main = self.BNormMain(main)
532 |         main = self.PreLuMain(main)
533 | 
534 |         main = self.ConvOut(main)
535 |         main = self.BNormOut(main)
536 |         main = self.DropOut(main)
537 | 
538 |         # ----- skip connection ------
539 |         # matching channels before applying MaxUnpool
540 |         skip = self.ConvSkip(input_layer)
541 | 
542 |         # upsampling with the encoder pooling indices
543 |         skip = self.MaxUnpoolSkip(skip, argmax, upsample_layer_shape)
544 | 
545 |         # ------- output layer -------
546 |         addition_layer = self.AddMainSkip([main, skip])
547 |         output_layer = self.PreluMainSkip(addition_layer)
548 | 
549 |         return output_layer
550 | 
551 | 
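# ---------------------------------------------------------------------------
# Added usage sketch (illustration only, not part of the original file):
# BottleDeck mirrors a downsampling BottleNeck. It receives the encoder's
# pooling indices together with the tensor that entered the downsampling
# bottleneck, whose spatial shape (and channel count, via output_filters)
# it restores.
def _demo_bottledeck():
    x = tf.random.uniform([1, 64, 64, 16])
    y, argmax, pre_pool = BottleNeck(output_filters=32,
                                     downsample=True,
                                     name='enc')(x)  # y: (1, 32, 32, 32)
    restored = BottleDeck(output_filters=16, name='dec')(y, argmax, pre_pool)
    print(restored.shape)  # back to (1, 64, 64, 16)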
552 | class InitBlock(tf.keras.Model):
553 |     '''
554 |     Enet init_block as in:
555 |     (1) Paszke, A.; Chaurasia, A.; Kim, S.; Culurciello, E. ENet: A Deep Neural Network
556 |         Architecture for Real-Time Semantic Segmentation. arXiv:1606.02147 [cs] 2016.
557 |     (2) https://github.com/e-lab/ENet-training/blob/master/train/models/encoderI.lua
558 |     (3) https://culurciello.github.io/tech/2016/06/20/training-enet.html
559 | 
560 | 
561 |     Arguments
562 |     ----------
563 |     'conv_filters' = an `Integer`: number of filters for the convolution
564 |     'kernel_size' = a `List`: size of the kernel for the convolution
565 |     'kernel_strides' = a `List`: length of the strides for the convolution
566 |     'pool_size' = a `List`: size of the pool for the maxpooling
567 |     'pool_strides' = a `List`: length of the strides for the maxpooling
568 |     'padding' = a `String`: padding for the convolution and the maxpooling
569 |     'name' = a `String`: name of the init_block
570 |     '''
571 |     def __init__(self,
572 |                  conv_filters=13,
573 |                  kernel_size=[3, 3],
574 |                  kernel_strides=[2, 2],
575 |                  pool_size=[2, 2],
576 |                  pool_strides=[2, 2],
577 |                  padding='valid',
578 |                  l2=0.0,
579 |                  name='init_block',
580 |                  **kwargs):
581 |         super(InitBlock, self).__init__(name=name, **kwargs)
582 | 
583 |         # ------- init_block parameters -------
584 |         self.conv_filters = conv_filters
585 |         self.kernel_size = kernel_size
586 |         self.kernel_strides = kernel_strides
587 |         self.pool_size = pool_size
588 |         self.pool_strides = pool_strides
589 |         self.padding = padding
590 | 
591 |         # ------- init_block layers -------
592 | 
593 |         # conv connection: need the padding to match the dimension of pool_init
594 |         self.padded_init = tf.keras.layers.ZeroPadding2D()
595 |         self.conv_init = tf.keras.layers.Conv2D(
596 |             conv_filters,
597 |             kernel_size,
598 |             strides=kernel_strides,
599 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
600 |             padding='valid')
601 | 
602 |         # maxpool, where pool_init is to be concatenated with conv_init
603 |         self.pool_init = tf.keras.layers.MaxPool2D(pool_size=pool_size,
604 |                                                    strides=pool_strides,
605 |                                                    padding='valid')
606 | 
607 |         # concatenating the two connections
608 |         self.concatenate = tf.keras.layers.Concatenate(axis=-1)
609 |         self.batch_norm = tf.keras.layers.BatchNormalization()
610 |         self.prelu = tf.keras.layers.PReLU(name=self.name + '.' + 'out_init')
611 | 
612 |     def call(self, input_layer):
613 | 
614 |         # ----- conv connection ------
615 |         # conv connection: need the padding to match the dimension of pool_init
616 |         conv_conn = self.padded_init(input_layer)
617 |         conv_conn = self.conv_init(conv_conn)
618 | 
619 |         # ----- pool connection ------
620 |         pool_conn = self.pool_init(input_layer)
621 | 
622 |         # ------- concat to output layer -------
623 |         output_layer = self.concatenate([conv_conn, pool_conn])
624 |         output_layer = self.batch_norm(output_layer)
625 |         output_layer = self.prelu(output_layer)
626 | 
627 |         return output_layer
628 | 
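# ---------------------------------------------------------------------------
# Added usage sketch (illustration only, not part of the original file): the
# init block concatenates a strided 3x3 convolution with a 2x2 max pool of
# the input, so an RGB image yields conv_filters + 3 channels at half the
# resolution.
def _demo_init_block():
    x = tf.random.uniform([1, 512, 512, 3])
    y = InitBlock(conv_filters=13)(x)
    print(y.shape)  # (1, 256, 256, 16): 13 conv channels + 3 pooled channels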
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from layers import BottleDeck, BottleNeck, InitBlock
3 | import numpy as np
4 | 
5 | 
6 | class EnetModel(tf.keras.Model):
7 |     '''
8 |     Enet model.
9 |     (1) Paszke, A.; Chaurasia, A.; Kim, S.; Culurciello, E.
10 |         ENet: A Deep Neural Network Architecture for Real-Time Semantic
11 |         Segmentation. arXiv:1606.02147 [cs] 2016.
12 | 
13 |     Arguments
14 |     ----------
15 |     'input_layer' = input `Tensor` with type `float32` and
16 |                     shape [batch_size,h,w,channels]
17 |     'C' = an `Integer`: number of classes
18 |     'l2' = a `float`: l2 regularization parameter
19 | 
20 |     Returns
21 |     -------
22 |     'EncOut, DecOut' = A `Tensor` with the same type as `input_layer`
23 |     '''
24 |     def __init__(self, C=12, l2=0.0, MultiObjective=False, **kwargs):
25 |         super(EnetModel, self).__init__(**kwargs)
26 | 
27 |         # initialize parameters
28 |         self.C = C
29 |         self.l2 = l2
30 |         self.MultiObjective = MultiObjective
31 | 
32 |         # # layers
33 |         self.InitBlock = InitBlock(conv_filters=13)
34 | 
35 |         # # first block of bottlenecks
36 |         self.BNeck1_0 = BottleNeck(output_filters=64,
37 |                                    downsample=True,
38 |                                    dropout_prob=0.01,
39 |                                    l2=l2,
40 |                                    name='BNeck1_0')
41 |         self.BNeck1_1 = BottleNeck(output_filters=64,
42 |                                    dropout_prob=0.01,
43 |                                    l2=l2,
44 |                                    name='BNeck1_1')
45 |         self.BNeck1_2 = BottleNeck(output_filters=64,
46 |                                    dropout_prob=0.01,
47 |                                    l2=l2,
48 |                                    name='BNeck1_2')
49 |         self.BNeck1_3 = BottleNeck(output_filters=64,
50 |                                    dropout_prob=0.01,
51 |                                    l2=l2,
52 |                                    name='BNeck1_3')
53 |         self.BNeck1_4 = BottleNeck(output_filters=64,
54 |                                    dropout_prob=0.01,
55 |                                    l2=l2,
56 |                                    name='BNeck1_4')
57 | 
58 |         # # second block of bottlenecks
59 |         self.BNeck2_0 = BottleNeck(output_filters=128,
60 |                                    downsample=True,
61 |                                    l2=l2,
62 |                                    name='BNeck2_0')
63 |         self.BNeck2_1 = BottleNeck(output_filters=128, l2=l2, name='BNeck2_1')
64 |         self.BNeck2_2 = BottleNeck(output_filters=128,
65 |                                    dilation_rate=(2, 2),
66 |                                    l2=l2,
67 |                                    name='BNeck2_2')
68 |         self.BNeck2_3 = BottleNeck(output_filters=128,
69 |                                    kernel_size=(5, 1),
70 |                                    l2=l2,
71 |                                    name='BNeck2_3')
72 |         self.BNeck2_4 = BottleNeck(output_filters=128,
73 |                                    dilation_rate=(4, 4),
74 |                                    l2=l2,
75 |                                    name='BNeck2_4')
76 |         self.BNeck2_5 = BottleNeck(output_filters=128, l2=l2, name='BNeck2_5')
77 |         self.BNeck2_6 = BottleNeck(output_filters=128,
78 |                                    dilation_rate=(8, 8),
79 |                                    l2=l2,
80 |                                    name='BNeck2_6')
81 |         self.BNeck2_7 = BottleNeck(output_filters=128,
82 |                                    kernel_size=(5, 1),
83 |                                    l2=l2,
84 |                                    name='BNeck2_7')
85 |         self.BNeck2_8 = BottleNeck(output_filters=128,
86 |                                    dilation_rate=(16, 16),
87 |                                    l2=l2,
88 |                                    name='BNeck2_8')
89 | 
90 |         # # third block of bottlenecks
91 |         self.BNeck3_1 = BottleNeck(output_filters=128, l2=l2, name='BNeck3_1')
92 |         self.BNeck3_2 = BottleNeck(output_filters=128,
93 |                                    dilation_rate=(2, 2),
94 |                                    l2=l2,
95 |                                    name='BNeck3_2')
96 |         self.BNeck3_3 = BottleNeck(output_filters=128,
97 |                                    kernel_size=(5, 1),
98 |                                    l2=l2,
99 |                                    name='BNeck3_3')
100 |         self.BNeck3_4 = BottleNeck(output_filters=128,
101 |                                    dilation_rate=(4, 4),
102 |                                    l2=l2,
103 |                                    name='BNeck3_4')
104 |         self.BNeck3_5 = BottleNeck(output_filters=128, l2=l2, name='BNeck3_5')
105 |         self.BNeck3_6 = BottleNeck(output_filters=128,
106 |                                    dilation_rate=(8, 8),
107 |                                    l2=l2,
108 |                                    name='BNeck3_6')
109 |         self.BNeck3_7 = BottleNeck(output_filters=128,
110 |                                    kernel_size=(5, 1),
111 |                                    l2=l2,
112 |                                    name='BNeck3_7')
113 |         self.BNeck3_8 = BottleNeck(output_filters=128,
114 |                                    dilation_rate=(16, 16),
115 |                                    l2=l2,
116 |                                    name='BNeck3_8')
117 | 
118 |         # project the encoder output to the number of classes
119 |         # to get the output of the encoder head
120 |         self.ConvEncOut = tf.keras.layers.Conv2D(
121 |             self.C,
122 |             kernel_size=[1, 1],
123 |             padding='valid',
124 |             use_bias=False,
125 |             activation='softmax',
126 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
127 |             name='EncOut')
128 | 
129 |         # fourth block of bottlenecks
130 |         self.BNeck4_0 = BottleDeck(output_filters=64,
131 |                                    internal_comp_ratio=2,
132 |                                    l2=l2,
133 |                                    name='BNeck4_0')
134 |         self.BNeck4_1 = BottleNeck(output_filters=64, l2=l2, name='BNeck4_1')
135 |         self.BNeck4_2 = BottleNeck(output_filters=64, l2=l2, name='BNeck4_2')
136 | 
137 |         # fifth block of bottlenecks
138 |         self.BNeck5_0 = BottleDeck(output_filters=16,
139 |                                    internal_comp_ratio=2,
140 |                                    l2=l2,
141 |                                    name='BNeck5_0')
142 |         self.BNeck5_1 = BottleNeck(output_filters=16, l2=l2, name='BNeck5_1')
143 | 
144 |         # Final ConvTranspose Layer
145 |         self.FullConv = tf.keras.layers.Conv2DTranspose(
146 |             self.C,
147 |             kernel_size=(3, 3),
148 |             strides=(2, 2),
149 |             padding='same',
150 |             activation='softmax',
151 |             kernel_regularizer=tf.keras.regularizers.l2(l2),
152 |             name='DecOut')
153 | 
154 |     def call(self, inputs):
155 | 
156 |         # init block
157 |         x = self.InitBlock(inputs)
158 | 
159 |         # first block of bottlenecks - downsampling
160 |         x, x_argmax1_0, x_upsample1_0 = self.BNeck1_0(x)  # downsample
161 |         x = self.BNeck1_1(x)
162 |         x = self.BNeck1_2(x)
163 |         x = self.BNeck1_3(x)
164 |         x = self.BNeck1_4(x)
165 | 
166 |         # second block of bottlenecks - downsampling
167 |         x, x_argmax2_0, x_upsample2_0 = self.BNeck2_0(x)  # downsample
168 |         x = self.BNeck2_1(x)
169 |         x = self.BNeck2_2(x)
170 |         x = self.BNeck2_3(x)
171 |         x = self.BNeck2_4(x)
172 |         x = self.BNeck2_5(x)
173 |         x = self.BNeck2_6(x)
174 |         x = self.BNeck2_7(x)
175 |         x = self.BNeck2_8(x)
176 | 
177 |         # third block of bottlenecks
178 |         x = self.BNeck3_1(x)
179 |         x = self.BNeck3_2(x)
180 |         x = self.BNeck3_3(x)
181 |         x = self.BNeck3_4(x)
182 |         x = self.BNeck3_5(x)
183 |         x = self.BNeck3_6(x)
184 |         x = self.BNeck3_7(x)
185 |         x = self.BNeck3_8(x)
186 | 
187 |         if self.MultiObjective:
188 |             EncOut = self.ConvEncOut(x)
189 | 
190 |         # fourth block of bottlenecks - upsampling
191 |         x = self.BNeck4_0(x, x_argmax2_0, x_upsample2_0)
192 |         x = self.BNeck4_1(x)
193 |         x = self.BNeck4_2(x)
194 | 
195 |         # fifth block of bottlenecks - upsampling
196 |         x = self.BNeck5_0(x, x_argmax1_0, x_upsample1_0)
197 |         x = self.BNeck5_1(x)
198 | 
199 |         # final full conv to the segmentation maps
200 |         DecOut = self.FullConv(x)
201 | 
202 |         # what is returned depends on the MultiObjective flag
203 |         if self.MultiObjective:
204 |             return EncOut, DecOut
205 |         else:
206 |             return DecOut
207 | 
--------------------------------------------------------------------------------
/notebooks/Enet CamVid Training.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.5-final"},"colab":{"name":"Enet CamVid Training.ipynb","provenance":[{"file_id":"https://github.com/gevero/enet_tensorflow/blob/master/notebooks/Enet%20CamVid%20Training.ipynb","timestamp":1573076117127}],"collapsed_sections":["Sz4BIEzsmqnN","2Ao_SIGICpCZ","GTtXHaRECpCj","vYpZV_YoCpCu","3zBhPKs-CpC3","4MtlcLzOJ3cn","b6B-3jiAgbY4","kI9gr7jTdCgc","juTnAVjQeA01","z58tXooI4RAG"],"machine_shape":"hm"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Sz4BIEzsmqnN"},"source":["## **Data and code setup**"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"pIZUbbQAqJSL"},"outputs":[],"source":["%%capture\n","!pip install gdown\n","!pip3 install 
gpustat"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"dKE3hFiLqkwD"},"outputs":[],"source":["%%capture\n","!git clone https://github.com/gevero/enet_tensorflow.git"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"fF5l82H-mwH8"},"outputs":[],"source":["%%capture\n","!gdown https://drive.google.com/uc?id=1gt0nCGft0winZqHBYaTb1EL6zM8lrKPA\n","!unzip -o camvid.zip"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"2Ao_SIGICpCZ"},"source":["## **Notebook Setup**"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# update to tf 2.0\n","from __future__ import absolute_import, division, print_function, unicode_literals\n","\n","# Install TensorFlow\n","try:\n"," # %tensorflow_version only exists in Colab.\n"," %tensorflow_version 2.1\n","except Exception:\n"," pass\n","\n","# importing standard libraries\n","import tensorflow as tf\n","print(tf.__version__)\n","import matplotlib.pylab as plt\n","import numpy as np\n","import os, os.path\n","from functools import partial\n","from google.colab import files\n","\n","# Importing utils and models\n","import sys\n","sys.path.append('./enet_tensorflow')\n","from utils import preprocess_img_label, map_singlehead, map_doublehead, map_label, tf_dataset_generator, get_class_weights\n","from models import EnetModel"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"GTtXHaRECpCj"},"source":["## **Create training test and validation dataset, and get class weights**"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["%%capture\n","# creating datasets\n","img_pattern = \"./dataset/train/images/*.png\"\n","label_pattern = \"./dataset/train/labels/*.png\"\n","img_pattern_val = \"./dataset/val/images/*.png\"\n","label_pattern_val = \"./dataset/val/labels/*.png\"\n","img_pattern_test = \"./dataset/test/images/*.png\"\n","label_pattern_test = \"./dataset/test/labels/*.png\"\n","\n","# batch size\n","batch_size = 8\n","\n","# image size\n","img_height = 360\n","img_width = 480\n","h_enc = img_height // 8\n","w_enc = img_width // 8\n","h_dec = img_height\n","w_dec = img_width\n","\n","# create (img,label) string tensor lists\n","filelist_train = preprocess_img_label(img_pattern, label_pattern)\n","filelist_val = preprocess_img_label(img_pattern_val, label_pattern_val)\n","filelist_test = preprocess_img_label(img_pattern_test, label_pattern_test)\n","\n","# training dataset size\n","n_train = tf.data.experimental.cardinality(filelist_train).numpy()\n","n_val = tf.data.experimental.cardinality(filelist_val).numpy()\n","n_test = tf.data.experimental.cardinality(filelist_test).numpy()\n","\n","# define mapping functions for single and double head nets\n","map_single = lambda img_file, label_file: map_singlehead(\n"," img_file, label_file, h_dec, w_dec)\n","map_double = lambda img_file, label_file: map_doublehead(\n"," img_file, label_file, h_enc, w_enc, h_dec, w_dec)\n","\n","# create single head datasets\n","train_single_ds = filelist_train.shuffle(n_train).map(map_single).cache().batch(batch_size).repeat()\n","val_single_ds = filelist_val.map(map_single).cache().batch(batch_size).repeat()\n","test_single_ds = filelist_test.map(map_single).cache().batch(batch_size).repeat()\n","\n","# create double head datasets\n","train_double_ds = filelist_train.shuffle(n_train).map(map_double).cache().batch(batch_size).repeat()\n","val_double_ds = 
filelist_val.map(map_double).cache().batch(batch_size).repeat()\n","test_double_ds = filelist_test.map(map_double).cache().batch(batch_size).repeat()\n","\n","# get class weights\n","label_filelist = tf.data.Dataset.list_files(label_pattern, shuffle=False)\n","label_ds = label_filelist.map(lambda x: map_label(x, h_dec, w_dec))\n","class_weights = get_class_weights(label_ds).tolist()\n","class_weights = {i: class_weights[i ] for i in range(0, len(class_weights))} "]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"vYpZV_YoCpCu"},"source":["## **Example (Image,Label) pair from the training set**"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["for img,iml in train_single_ds.take(1):\n"," plt.figure(figsize=(15,10))\n"," plt.subplot(1,2,1)\n"," plt.imshow(img.numpy()[0,:,:,:])\n"," plt.subplot(1,2,2)\n"," plt.imshow(iml.numpy()[0,:,:,0])"]},{"cell_type":"markdown","metadata":{},"source":["## Losses and Metrics"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')\n","loss_1 = tf.keras.losses.SparseCategoricalCrossentropy(name='loss_1')\n","loss_2 = tf.keras.losses.SparseCategoricalCrossentropy(name='loss_2')"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"3zBhPKs-CpC3"},"source":["## **1 - Two stage training: first Encoder then Decoder**"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"DUQLopwUxm4f"},"source":["### Training the encoder"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Enet = EnetModel(C=12,MultiObjective=True,l2=1e-3)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["for layer in Enet.layers[-6:]:\n"," layer.trainable = False"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# compile model: only the first objective matters\n","n_epochs = 60\n","adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)\n","Enet.compile(optimizer=adam_optimizer,\n"," loss=[loss_1,loss_2],\n"," metrics=accuracy,\n"," loss_weights=[1.0,0.0])"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["enet_enc_history = Enet.fit(x= train_double_ds,\n"," epochs=n_epochs,\n"," steps_per_epoch=n_train//batch_size,\n"," validation_data= val_double_ds,\n"," validation_steps=n_val//batch_size//5)"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Ujw0eGL1xuKk"},"source":["### Training the decoder"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"_RhYf5VQxGfy"},"outputs":[],"source":["for layer in Enet.layers[-6:]:\n"," layer.trainable = True\n","for layer in Enet.layers[:-6]:\n"," layer.trainable = False"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"QTaImgvixYQg"},"outputs":[],"source":["# compile model: only the first objective matters\n","n_epochs = 60\n","adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)\n","Enet.compile(optimizer=adam_optimizer,\n"," loss=[loss_1,loss_2],\n"," metrics=accuracy,\n"," loss_weights=[0.0,1.0])"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"colab_type":"code","executionInfo":{"elapsed":1418508,"status":"ok","timestamp":1573586741621,"user":{"displayName":"Giovanni 
Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"vHBuIqOixfDv","outputId":"81dd2129-1431-4c6b-99d4-6c445aebf42f"},"outputs":[],"source":["enet_dec_history = Enet.fit(x= train_double_ds,\n"," epochs=n_epochs,\n"," steps_per_epoch=n_train//batch_size,\n"," validation_data= val_double_ds,\n"," validation_steps=n_val//batch_size//5)"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"txPR_S4VbDFy"},"source":["### Check performance"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":71},"colab_type":"code","executionInfo":{"elapsed":1421537,"status":"ok","timestamp":1573586744665,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"uDwOuc9ZbGYF","outputId":"83661eab-ae6f-43b7-aca3-900192c4cd0b"},"outputs":[],"source":["Enet.evaluate(x=test_double_ds,steps=n_test//batch_size)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":513},"colab_type":"code","executionInfo":{"elapsed":1421530,"status":"ok","timestamp":1573586744670,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"jYhWcnJcbKM0","outputId":"a35a716b-c009-4fbc-f70d-487caf92ffc5"},"outputs":[],"source":["loss = enet_dec_history.history['loss']\n","val_loss = enet_dec_history.history['val_loss']\n","acc = enet_dec_history.history['output_2_accuracy']\n","val_acc = enet_dec_history.history['val_output_2_accuracy']\n","\n","epochs = range(n_epochs)\n","\n","plt.figure(figsize=(12,8))\n","plt.plot(epochs, loss/np.max(loss), 'r', label='Training loss')\n","plt.plot(epochs, val_loss/np.max(val_loss), 'b', label='Validation loss')\n","plt.plot(epochs, acc, 'r:', label='Training accuracy')\n","plt.plot(epochs, val_acc, 'b:', label='Validation accuracy')\n","\n","plt.title('Training and Validation Loss')\n","plt.xlabel('Epoch')\n","plt.ylabel('Loss Value')\n","plt.ylim([0, 1])\n","plt.legend()\n","plt.show()"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"4MtlcLzOJ3cn"},"source":["## **2 - Training both objectives simultaneously**"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"b6B-3jiAgbY4"},"source":["###Training"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"0CCZZVlyJ81i"},"outputs":[],"source":["EnetMulti = EnetModel(C=12,MultiObjective=True,l2=1e-3)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"jrX-MW1aKDeh"},"outputs":[],"source":["# compile model: only the first objective matters\n","n_epochs = 80\n","adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)\n","EnetMulti.compile(optimizer=adam_optimizer,\n"," loss=[loss_1,loss_2],\n"," metrics=accuracy,\n"," loss_weights=[0.5,0.5])"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"colab_type":"code","executionInfo":{"elapsed":2096493,"status":"ok","timestamp":1573587419672,"user":{"displayName":"Giovanni 
Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"F_m8T8kTKPK8","outputId":"d548e126-b14a-4897-84da-05134afb8273"},"outputs":[],"source":["enet_multi_history = EnetMulti.fit(x= train_double_ds,\n"," epochs=n_epochs,\n"," steps_per_epoch=n_train//batch_size,\n"," validation_data= val_double_ds,\n"," validation_steps=n_val//batch_size//5)"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"rVVrPhqGcRKD"},"source":["### Check performance"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":51},"colab_type":"code","executionInfo":{"elapsed":2098622,"status":"ok","timestamp":1573587421812,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"pMYbsKQ5cZuF","outputId":"861613ec-d1f3-4725-c586-cc6edcf1fb2b"},"outputs":[],"source":["EnetMulti.evaluate(x=test_double_ds,steps=n_test//batch_size)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":513},"colab_type":"code","executionInfo":{"elapsed":2099139,"status":"ok","timestamp":1573587422341,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"QrSUQ4wYcecC","outputId":"62bc45f8-13e1-40be-f3dd-665eaed26eab"},"outputs":[],"source":["loss = enet_multi_history.history['loss']\n","val_loss = enet_multi_history.history['val_loss']\n","acc = enet_multi_history.history['output_2_accuracy']\n","val_acc = enet_multi_history.history['val_output_2_accuracy']\n","\n","epochs = range(n_epochs)\n","\n","plt.figure(figsize=(12,8))\n","plt.plot(epochs, loss/np.max(loss), 'r', label='Training loss')\n","plt.plot(epochs, val_loss/np.max(val_loss), 'b', label='Validation loss')\n","plt.plot(epochs, acc, 'r:', label='Training accuracy')\n","plt.plot(epochs, val_acc, 'b:', label='Validation accuracy')\n","\n","plt.title('Training and Validation Loss')\n","plt.xlabel('Epoch')\n","plt.ylabel('Loss Value')\n","plt.ylim([0, 1])\n","plt.legend()\n","plt.show()"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"kI9gr7jTdCgc"},"source":["## **3 - EndtoEnd Training**"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"edMy9qeSdFdm"},"outputs":[],"source":["EnetEndToEnd = EnetModel(C=12,MultiObjective=False,l2=1e-3)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"tcG5RD2PdLhr"},"outputs":[],"source":["# compile model: only the first objective matters\n","n_epochs = 80\n","adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)\n","EnetEndToEnd.compile(optimizer=adam_optimizer,\n"," loss=loss_1,\n"," metrics=accuracy)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"colab_type":"code","executionInfo":{"elapsed":2982715,"status":"ok","timestamp":1573588305963,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"3UZlMLV7dZY6","outputId":"a370436f-c730-4970-dbe9-58be68c56f4d"},"outputs":[],"source":["enet_endtoend_history = 
EnetEndToEnd.fit(x= train_single_ds,\n","                                    epochs=n_epochs,\n","                                    steps_per_epoch=n_train//batch_size,\n","                                    validation_data= val_single_ds,\n","                                    validation_steps=n_val//batch_size//5)"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"Y6W5e9Oydz3t"},"source":["### Check performance"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":51},"colab_type":"code","executionInfo":{"elapsed":2985964,"status":"ok","timestamp":1573588309222,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"LLduMggMd3b1","outputId":"de2a95d1-a5c4-4ad9-9be2-c859c57307b3"},"outputs":[],"source":["EnetEndToEnd.evaluate(x=test_single_ds,steps=n_test//batch_size)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":513},"colab_type":"code","executionInfo":{"elapsed":2985956,"status":"ok","timestamp":1573588309223,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"q7qgxdDQd3m-","outputId":"ad24ea1b-8d4a-4dfe-e62d-4c32b7fdeabe"},"outputs":[],"source":["loss = enet_endtoend_history.history['loss']\n","val_loss = enet_endtoend_history.history['val_loss']\n","acc = enet_endtoend_history.history['accuracy']\n","val_acc = enet_endtoend_history.history['val_accuracy']\n","\n","epochs = range(n_epochs)\n","\n","plt.figure(figsize=(12,8))\n","plt.plot(epochs, loss/np.max(loss), 'r', label='Training loss')\n","plt.plot(epochs, val_loss/np.max(val_loss), 'b', label='Validation loss')\n","plt.plot(epochs, acc, 'r:', label='Training accuracy')\n","plt.plot(epochs, val_acc, 'b:', label='Validation accuracy')\n","\n","plt.title('Training and Validation Loss')\n","plt.xlabel('Epoch')\n","plt.ylabel('Loss Value')\n","plt.ylim([0, 1])\n","plt.legend()\n","plt.show()"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"juTnAVjQeA01"},"source":["## **Test Masks**"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":730},"colab_type":"code","executionInfo":{"elapsed":2879,"status":"ok","timestamp":1573596705474,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"-vHoAvya0-S3","outputId":"e25a3f24-5575-483f-c058-5fb7f2ef2715"},"outputs":[],"source":["def create_mask(pred_mask):\n","  pred_mask = tf.argmax(pred_mask, axis=-1)\n","  pred_mask = pred_mask[..., tf.newaxis]\n","  return pred_mask[0]\n","for img,iml in train_double_ds.take(10):\n","  img_test = img\n","  iml_test = iml\n","\n","img_enc_probs, img_dec_probs = Enet(img_test[0:1,:,:,:])\n","img_enc_probs, img_multi_probs = EnetMulti(img_test[0:1,:,:,:])\n","img_endtoend_probs = EnetEndToEnd(img_test[0:1,:,:,:])\n","img_dec_out = create_mask(img_dec_probs)\n","img_multi_out = create_mask(img_multi_probs)\n","img_endtoend_out = create_mask(img_endtoend_probs)\n","\n","plt.figure(figsize=(20,10))\n","plt.subplot(2,3,1)\n","plt.xticks([])\n","plt.yticks([])\n","plt.title('Image',fontdict={'fontsize':20})\n","plt.imshow(img_test.numpy()[0,:,:,:])\n","\n","plt.subplot(2,3,2)\n","plt.xticks([])\n","plt.yticks([])\n","plt.title('Ground Truth',fontdict={'fontsize':20})\n","plt.imshow(iml_test.numpy()[0,:,:,0])\n","\n","plt.subplot(2,3,4)\n","plt.imshow(img_dec_out[:,:,0])\n","plt.xticks([])\n","plt.yticks([])\n","plt.title('Encoder + Decoder',fontdict={'fontsize':20})\n","\n","plt.subplot(2,3,5)\n","plt.xticks([])\n","plt.yticks([])\n","plt.title('Multiple Objectives',fontdict={'fontsize':20})\n","plt.imshow(img_multi_out[:,:,0])\n","\n","plt.subplot(2,3,6)\n","plt.xticks([])\n","plt.yticks([])\n","plt.title('End to End',fontdict={'fontsize':20})\n","plt.imshow(img_endtoend_out[:,:,0])\n","\n","plt.tight_layout()\n","plt.savefig('./segmentation.png')"]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"z58tXooI4RAG"},"source":["# **Save models**\n","You can download them to your Google Drive. Mount it with the command below and drag and drop the weight files"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":122},"colab_type":"code","executionInfo":{"elapsed":31476,"status":"ok","timestamp":1573589870868,"user":{"displayName":"Giovanni Pellegrini","photoUrl":"https://lh3.googleusercontent.com/a-/AAuE7mCfIl0LjjjdwnrOebKDMDcTn1hUBx5REgwi7mJHEA=s64","userId":"14348157764532225738"},"user_tz":-60},"id":"E0Gnd9YyQ17l","outputId":"dfcffeb9-6249-4fbc-ea09-7f43de06ba06"},"outputs":[],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{},"colab_type":"code","id":"n8AP62KXhyNf"},"outputs":[],"source":["Enet.save_weights('Enet.tf')\n","EnetMulti.save_weights('EnetMulti.tf')\n","EnetEndToEnd.save_weights('EnetEndToEnd.tf')"]}]}
--------------------------------------------------------------------------------
/notebooks/Enet FaceSegmentation Inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "nbformat": 4,
3 |   "nbformat_minor": 0,
4 |   "metadata": {
5 |     "kernelspec": {
6 |       "name": "python3",
7 |       "display_name": "Python 3"
8 |     },
9 |     "language_info": {
10 |       "codemirror_mode": {
11 |         "name": "ipython",
12 |         "version": 3
13 |       },
14 |       "file_extension": ".py",
15 |       "mimetype": "text/x-python",
16 |       "name": "python",
17 |       "nbconvert_exporter": "python",
18 |       "pygments_lexer": "ipython3",
19 |       "version": "3.8.1-final"
20 |     },
21 |     "colab": {
22 |       "name": "Enet FaceSegmentation Inference.ipynb",
23 |       "provenance": [],
24 |       "collapsed_sections": [],
25 |       "toc_visible": true,
26 |       "machine_shape": "hm"
27 |     },
28 |     "accelerator": "GPU"
29 |   },
30 |   "cells": [
31 |     {
32 |       "cell_type": "markdown",
33 |       "metadata": {
34 |         "colab_type": "text",
35 |         "id": "Sz4BIEzsmqnN"
36 |       },
37 |       "source": [
38 |         "## **Data and code setup**"
39 |       ]
40 |     },
41 |     {
42 |       "cell_type": "code",
43 |       "execution_count": 0,
44 |       "metadata": {
45 |         "colab": {},
46 |         "colab_type": "code",
47 |         "id": "pIZUbbQAqJSL"
48 |       },
49 |       "outputs": [],
50 |       "source": [
51 |         "%%capture\n",
52 |         "!pip3 install gdown\n",
53 |         "!pip3 install gpustat"
54 |       ]
55 |     },
56 |     {
57 |       "cell_type": "code",
58 |       "execution_count": 0,
59 |       "metadata": {
60 |         "colab": {},
61 |         "colab_type": "code",
62 |         "id": "dKE3hFiLqkwD"
63 |       },
64 |       "outputs": [],
65 |       "source": [
66 |         "%%capture\n",
67 |         "!git clone https://github.com/gevero/enet_tensorflow.git"
68 |       ]
69 |     },
70 |     {
71 |       "cell_type": "code",
72 |       "execution_count": 0,
73 |       "metadata": {
74 |         "colab": {},
75 |         "colab_type": "code",
76 |         "id": "fF5l82H-mwH8"
77 |       },
78 |       "outputs": [],
79 |       "source": [
80 |         "%%capture\n",
81 | 
"!gdown https://drive.google.com/uc?id=1zQ6PCA7k-1d_s_zrZWftJ0OgS23wKIT_ -O EnetWeights.zip\n", 82 | "!unzip -o EnetWeights.zip" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "colab_type": "text", 89 | "id": "2Ao_SIGICpCZ" 90 | }, 91 | "source": [ 92 | "## **Notebook Setup**" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 0, 98 | "metadata": { 99 | "colab": {}, 100 | "colab_type": "code", 101 | "id": "3p_d27ibCuOW" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "# update to tf 2.0\n", 106 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 107 | "\n", 108 | "# Install TensorFlow\n", 109 | "try:\n", 110 | " # %tensorflow_version only exists in Colab.\n", 111 | " %tensorflow_version 2.1\n", 112 | "except Exception:\n", 113 | " pass\n", 114 | "\n", 115 | "# importing standard libraries\n", 116 | "import tensorflow as tf\n", 117 | "print(tf.__version__)\n", 118 | "import matplotlib.pylab as plt\n", 119 | "import numpy as np\n", 120 | "import os, os.path\n", 121 | "from functools import partial\n", 122 | "from google.colab import files\n", 123 | "\n", 124 | "# Importing utils and models\n", 125 | "import time\n", 126 | "import sys\n", 127 | "sys.path.append('./enet_tensorflow')\n", 128 | "from utils import preprocess_img_label, map_singlehead, map_doublehead, map_label, tf_dataset_generator, get_class_weights\n", 129 | "from models import EnetModel" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "colab_type": "text", 136 | "id": "Q28wciECoFxB" 137 | }, 138 | "source": [ 139 | "## **Selenium setup**" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 0, 145 | "metadata": { 146 | "colab": {}, 147 | "colab_type": "code", 148 | "id": "Hi6LdNCQoK3i" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "%%capture\n", 153 | "!apt install chromium-chromedriver\n", 154 | "!cp /usr/lib/chromium-browser/chromedriver /usr/bin\n", 155 | "!pip install selenium" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 0, 161 | "metadata": { 162 | "colab": {}, 163 | "colab_type": "code", 164 | "id": "O9OP1gFwoO2a" 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "from selenium import webdriver\n", 169 | "\n", 170 | "# download automatically without dialog box\n", 171 | "prefs = {'profile.default_content_setting_values.automatic_downloads': 1}\n", 172 | "\n", 173 | "# set necessary options for headless working\n", 174 | "options = webdriver.ChromeOptions()\n", 175 | "options.add_argument('--headless')\n", 176 | "options.add_argument('--no-sandbox')\n", 177 | "options.add_argument('--disable-dev-shm-usage')\n", 178 | "options.add_experimental_option(\"prefs\", prefs)\n", 179 | "\n", 180 | "# create webdriver\n", 181 | "wd = webdriver.Chrome('chromedriver',options=options)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": { 187 | "colab_type": "text", 188 | "id": "3zBhPKs-CpC3" 189 | }, 190 | "source": [ 191 | "## **Load Weights**" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 7, 197 | "metadata": { 198 | "colab": { 199 | "base_uri": "https://localhost:8080/", 200 | "height": 34 201 | }, 202 | "colab_type": "code", 203 | "id": "QZVBH5whCuOl", 204 | "outputId": "93e52eeb-6b37-4305-8e5d-873588e7b575" 205 | }, 206 | "outputs": [ 207 | { 208 | "data": { 209 | "text/plain": [ 210 | "" 211 | ] 212 | }, 213 | "execution_count": 7, 214 | "metadata": { 215 | "tags": [] 216 | }, 217 | 
"output_type": "execute_result" 218 | } 219 | ], 220 | "source": [ 221 | "Enet = EnetModel(C=3,MultiObjective=True,l2=1e-3)\n", 222 | "Enet.load_weights('./Enet512x512.tf')" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": { 228 | "colab_type": "text", 229 | "id": "juTnAVjQeA01" 230 | }, 231 | "source": [ 232 | "## **Get/refresh image from \"this person does not exist\"**" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 0, 238 | "metadata": { 239 | "colab": {}, 240 | "colab_type": "code", 241 | "id": "7zuecXGPqMiP" 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "# get or refresh the image\n", 246 | "wd.get('https://www.thispersondoesnotexist.com')\n", 247 | "time.sleep(2)\n", 248 | "\n", 249 | "# buttons\n", 250 | "save_button = wd.find_element_by_id('saveButton')\n", 251 | "another_button = wd.find_element_by_xpath('//*[@title=\"Save this person\"]')\n", 252 | "time.sleep(2)\n", 253 | "\n", 254 | "# click them to get image\n", 255 | "another_button.click()\n", 256 | "save_button.click()" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 0, 262 | "metadata": { 263 | "colab": {}, 264 | "colab_type": "code", 265 | "id": "QnT62GNaGaKG" 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | " # decoding image\n", 270 | " img = tf.io.read_file('./person.jpg')\n", 271 | " img = tf.image.decode_jpeg(img, channels=3)\n", 272 | " img = tf.image.convert_image_dtype(img, tf.float32)\n", 273 | " img = tf.image.resize(img,(512,512))\n", 274 | " img = tf.reshape(img,[1,512,512,3])" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 0, 280 | "metadata": { 281 | "colab": {}, 282 | "colab_type": "code", 283 | "id": "-vHoAvya0-S3" 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "def create_mask(pred_mask):\n", 288 | " pred_mask = tf.argmax(pred_mask, axis=-1)\n", 289 | " pred_mask = pred_mask[..., tf.newaxis]\n", 290 | "\n", 291 | " return pred_mask[0]\n", 292 | "\n", 293 | "img_enc_probs, img_dec_probs = Enet(img[0:1,:,:,:])\n", 294 | "img_dec_out = create_mask(img_dec_probs)\n", 295 | "\n", 296 | "# image\n", 297 | "fig = plt.figure(figsize=(20,10))\n", 298 | "plt.subplot(1,3,1)\n", 299 | "plt.xticks([])\n", 300 | "plt.yticks([])\n", 301 | "plt.imshow(img.numpy()[0,:,:,:])\n", 302 | "\n", 303 | "# mask\n", 304 | "plt.subplot(1,3,2)\n", 305 | "plt.xticks([])\n", 306 | "plt.yticks([])\n", 307 | "plt.imshow(img_dec_out[:,:,0],cmap='viridis')\n", 308 | "\n", 309 | "# image + mask\n", 310 | "plt.subplot(1,3,3)\n", 311 | "plt.xticks([])\n", 312 | "plt.yticks([])\n", 313 | "plt.imshow(img.numpy()[0,:,:,:])\n", 314 | "plt.imshow(img_dec_out[:,:,0], alpha=0.5,cmap='viridis')\n", 315 | "\n", 316 | "plt.tight_layout()\n", 317 | "fig.subplots_adjust(wspace=0.0, hspace=0.0)" 318 | ] 319 | } 320 | ] 321 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # ENet 2 | 3 | This repository contains a tensorflow 2.0 implementation of Enet as in: 4 | 5 | Paszke, A.; Chaurasia, A.; Kim, S.; Culurciello, E. ENet: **A Deep Neural Network Architecture for Real-Time Semantic Segmentation.** [arXiv:1606.02147 [cs] 2016.](https://arxiv.org/pdf/1606.02147.pdf) 6 | 7 | Enet is a lightweight neural network geared towards image segmentation for real time applications. 
This TensorFlow 2.0 implementation is greatly indebted to the PyTorch implementation [ENet - Real Time Semantic Segmentation](https://github.com/iArunava/ENet-Real-Time-Semantic-Segmentation) by [iArunava](https://github.com/iArunava).
8 | 
9 | # CamVid: Try it out
10 | You can try it out directly in this [Colab notebook](https://colab.research.google.com/github/gevero/enet_tensorflow/blob/master/notebooks/Enet%20CamVid%20Training.ipynb). In the notebook, ENet is trained in three different ways for comparison:
11 | 
12 | - **First the Encoder and then the Decoder:** as in the original paper, we first train the encoder, freeze its weights, and then train the decoder. This approach provides the most stable training.
13 | - **Encoder and Decoder simultaneously with two objectives:** similar to the original paper, but we train ENet in one go while keeping the stability benefits of the two-stage approach.
14 | - **End to End:** we train ENet in one go with a single objective. It is the least stable method, albeit the simplest.
15 | 
16 | ## CamVid pretrained weights
17 | 
18 | You can find them [here](https://drive.google.com/open?id=1rQN_855G-iHZkPe7KEI-P5PF8U4uIf40) for the [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) dataset. Weights for other datasets will be released as soon as possible.
19 | 
20 | ## A typical example
21 | ![TestImg](https://github.com/gevero/enet_tensorflow/blob/master/images/SegmentationExample.png)
22 | 
23 | # Segment faces that do not exist
24 | If you prefer something different, you can try a version of ENet trained on a face segmentation dataset built upon [CelebHair](https://github.com/ileniTudor/Face-Hair-Segmentation-Dataset) and [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). As with the CamVid dataset, you can download the pretrained weights [here](https://drive.google.com/open?id=1zQ6PCA7k-1d_s_zrZWftJ0OgS23wKIT_). To try face segmentation right away, use this [Colab notebook](https://colab.research.google.com/github/gevero/enet_tensorflow/blob/master/notebooks/Enet%20FaceSegmentation%20Inference.ipynb), which runs ENet on resized images generated by [ThisPersonDoesNotExist](https://www.thispersondoesnotexist.com/).
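
If you would rather script it than use the notebook, here is a minimal sketch of the same inference loop, assuming you run it from the repository root with the pretrained archive unpacked to `Enet512x512.tf` (the names match the inference notebook; `person.jpg` is a placeholder for any face image you supply):

```python
import tensorflow as tf
from models import EnetModel

# two-headed face-segmentation model, matching the pretrained weights
enet = EnetModel(C=3, MultiObjective=True, l2=1e-3)
enet.load_weights('./Enet512x512.tf')

# decode a face image and resize it to the 512x512 training resolution
img = tf.io.read_file('person.jpg')  # placeholder input image
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.reshape(tf.image.resize(img, (512, 512)), [1, 512, 512, 3])

# the decoder head returns full-resolution class probabilities;
# an argmax over the channel axis gives the segmentation mask
_, dec_probs = enet(img)
mask = tf.argmax(dec_probs, axis=-1)[0]
```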
25 | 
26 | ![TestImg](https://github.com/gevero/enet_tensorflow/blob/master/images/ThisSegmentationDoesNotExist.png)
27 | 
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | from train import *
4 | 
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | 
8 | parser.add_argument('-m',
9 | type=str,
10 | default='./datasets/CamVid/ckpt-camvid-enet.pth',
11 | help='The path to the pretrained enet model')
12 | 
13 | parser.add_argument(
14 | '-i',
15 | '--image-path',
16 | type=str,
17 | help='The path to the image on which to perform semantic segmentation')
18 | 
19 | parser.add_argument('-sm',
20 | '--save_model',
21 | type=str,
22 | default='./Enet.tf',
23 | help='TensorFlow model save file')
24 | 
25 | parser.add_argument('-tbl',
26 | '--tensorboard_logs',
27 | type=str,
28 | default='./tb_logs/',
29 | help='Tensorboard logs folder')
30 | 
31 | parser.add_argument('-tt',
32 | '--training-type',
33 | type=int,
34 | default=1,
35 | help='0: encoder then decoder, 1: encoder+decoder with two objectives, 2: end to end')
36 | 
37 | parser.add_argument('-ih',
38 | '--img-height',
39 | type=int,
40 | default=360,
41 | help='The height for the resized image')
42 | 
43 | parser.add_argument('-iw',
44 | '--img-width',
45 | type=int,
46 | default=480,
47 | help='The width for the resized image')
48 | 
49 | parser.add_argument('-lr',
50 | '--learning-rate',
51 | type=float,
52 | default=5e-4,
53 | help='The learning rate')
54 | 
55 | parser.add_argument('-bs',
56 | '--batch-size',
57 | type=int,
58 | default=8,
59 | help='The batch size')
60 | 
61 | parser.add_argument('-wd',
62 | '--weight-decay',
63 | type=float,
64 | default=2e-4,
65 | help='The weight decay')
66 | 
67 | parser.add_argument('-e',
68 | '--epochs',
69 | type=int,
70 | default=10,
71 | help='The number of epochs')
72 | 
73 | parser.add_argument('-nc',
74 | '--num-classes',
75 | type=int,
76 | default=12,
77 | help='The number of classes')
78 | 
79 | parser.add_argument(
80 | '-se',
81 | '--save-every',
82 | type=int,
83 | default=10,
84 | help='The number of epochs after which to save a model')
85 | 
86 | parser.add_argument('-iptr',
87 | '--img-pattern',
88 | type=str,
89 | default='./datasets/CamVid/train/images/*.png',
90 | help='Glob pattern for the training images')
91 | 
92 | parser.add_argument('-lptr',
93 | '--label-pattern',
94 | type=str,
95 | default='./datasets/CamVid/train/labels/*.png',
96 | help='Glob pattern for the training labels')
97 | 
98 | parser.add_argument(
99 | '-ipv',
100 | '--img-pattern-val',
101 | type=str,
102 | # default='./datasets/CamVid/val/images/*.png',
103 | default=
104 | './datasets/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/val/*/*.png',
105 | help='Glob pattern for the validation images')
106 | 
107 | parser.add_argument(
108 | '-lpv',
109 | '--label-pattern-val',
110 | type=str,
111 | # default='./datasets/CamVid/val/labels/*.png',
112 | default=
113 | './datasets/cityscapes/gtFine_trainvaltest/gtFine/val/*/*labelIds*.png',
114 | help='Glob pattern for the validation labels')
115 | 
116 | parser.add_argument('-iptt',
117 | '--img-pattern-test',
118 | type=str,
119 | default='./datasets/CamVid/test/images/*.png',
120 | help='Glob pattern for the test images')
121 | 
122 | parser.add_argument('-lptt',
123 | '--label-pattern-test',
124 | type=str,
125 | default='./datasets/CamVid/test/labels/*.png',
126 | help='Glob pattern for the test labels')
127 | 
128 | parser.add_argument(
129 | '-ctr',
130 | '--cache-train',
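# an empty string caches the dataset in memory; a filename caches it
# to disk (train.py passes this value straight to tf.data.Dataset.cache)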
131 | type=str,
132 | default='',
133 | help='Filename to cache the training data; if empty, cache in memory')
134 | 
135 | parser.add_argument(
136 | '-cv',
137 | '--cache-val',
138 | type=str,
139 | default='',
140 | help='Filename to cache the validation data; if empty, cache in memory')
141 | 
142 | parser.add_argument(
143 | '-ctt',
144 | '--cache-test',
145 | type=str,
146 | default='',
147 | help='Filename to cache the test data; if empty, cache in memory')
148 | 
149 | parser.add_argument('--mode',
150 | choices=['train', 'test'],
151 | default='train',
152 | help='Whether to train or test')
153 | 
154 | FLAGS, unparsed = parser.parse_known_args()
155 | 
156 | if FLAGS.mode.lower() == 'train':
157 | train(FLAGS)
158 | elif FLAGS.mode.lower() == 'test':
159 | test(FLAGS)
160 | else:
161 | raise RuntimeError(
162 | 'Unknown mode passed. Mode should be either \
163 | "train" or "test"')
164 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # importing standard libraries
2 | import tensorflow as tf
3 | import numpy as np
4 | import datetime
5 | 
6 | # Importing utils and models
7 | from functools import partial
8 | from utils import preprocess_img_label, map_singlehead, map_doublehead
9 | from utils import map_label, tf_dataset_generator, get_class_weights
10 | from models import EnetModel
11 | 
12 | 
13 | def train(FLAGS):
14 | 
15 | # -------------------- Defining the hyperparameters --------------------
16 | batch_size = FLAGS.batch_size
17 | epochs = FLAGS.epochs
18 | training_type = FLAGS.training_type
19 | learning_rate = FLAGS.learning_rate
20 | save_every = FLAGS.save_every
21 | num_classes = FLAGS.num_classes
22 | weight_decay = FLAGS.weight_decay
23 | img_pattern = FLAGS.img_pattern
24 | label_pattern = FLAGS.label_pattern
25 | img_pattern_val = FLAGS.img_pattern_val
26 | label_pattern_val = FLAGS.label_pattern_val
27 | tb_logs = FLAGS.tensorboard_logs
28 | img_width = FLAGS.img_width
29 | img_height = FLAGS.img_height
30 | save_model = FLAGS.save_model
31 | cache_train = FLAGS.cache_train
32 | cache_val = FLAGS.cache_val
33 | cache_test = FLAGS.cache_test
34 | print('[INFO] Defined all the hyperparameters successfully!')
35 | 
36 | # setup tensorboard in the folder given by the -tbl flag
37 | log_dir = tb_logs + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
38 | tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
39 | histogram_freq=1)
40 | 
41 | # encoder and decoder dimensions: the encoder head works at 1/8 resolution
42 | h_enc = img_height // 8
43 | w_enc = img_width // 8
44 | h_dec = img_height
45 | w_dec = img_width
46 | 
47 | # create (img,label) string tensor lists
48 | filelist_train = preprocess_img_label(img_pattern, label_pattern)
49 | filelist_val = preprocess_img_label(img_pattern_val, label_pattern_val)
50 | 
51 | # training and validation dataset sizes
52 | n_train = tf.data.experimental.cardinality(filelist_train).numpy()
53 | n_val = tf.data.experimental.cardinality(filelist_val).numpy()
54 | 
55 | # define mapping functions for single and double head nets
56 | map_single = partial(map_singlehead, h_img=h_dec, w_img=w_dec)
57 | map_double = partial(map_doublehead,
58 | h_enc=h_enc,
59 | w_enc=w_enc,
60 | h_dec=h_dec,
61 | w_dec=w_dec)
62 | 
63 | # create dataset
64 | if training_type == 0 or training_type == 1:
65 | map_fn = map_double
66 | else:
67 | map_fn = map_single
68 | train_ds = filelist_train.shuffle(n_train).map(map_fn).cache(
69 | cache_train).batch(batch_size).repeat()
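# validation pipeline: same mapping as training, but no shuffling;
# .repeat() lets model.fit() draw validation_steps batches each epoch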
70 | val_ds = filelist_val.map(map_fn).cache(cache_val).batch(
71 | batch_size).repeat()
72 | 
73 | 
74 | 
75 | # -------------------- get the class weights --------------------
76 | print('[INFO] Starting to define the class weights...')
77 | label_filelist = tf.data.Dataset.list_files(label_pattern, shuffle=False)
78 | label_ds = label_filelist.map(lambda x: map_label(x, h_dec, w_dec))
79 | class_weights = dict(enumerate(get_class_weights(label_ds, num_classes)))  # Keras expects a dict {class index: weight}
80 | print('[INFO] Fetched all class weights successfully!')
81 | 
82 | # -------------------- instantiate model --------------------
83 | if training_type == 0 or training_type == 1:
84 | Enet = EnetModel(C=num_classes, MultiObjective=True, l2=weight_decay)
85 | else:
86 | Enet = EnetModel(C=num_classes, l2=weight_decay)
87 | print('[INFO] Model Instantiated!')
88 | 
89 | # -------------------- start training --------------------
90 | adam_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
91 | 
92 | # -- two-stage training --
93 | if training_type == 0:
94 | 
95 | # freeze decoder layers
96 | for layer in Enet.layers[-6:]:
97 | layer.trainable = False
98 | 
99 | # compile encoder: only the first objective matters
100 | Enet.compile(optimizer=adam_optimizer,
101 | loss=[
102 | 'sparse_categorical_crossentropy',
103 | 'sparse_categorical_crossentropy'
104 | ],
105 | metrics=['accuracy', 'accuracy'],
106 | loss_weights=[1.0, 0.0])
107 | 
108 | # train encoder
109 | Enet.fit(x=train_ds,
110 | epochs=epochs,
111 | steps_per_epoch=n_train // batch_size,
112 | validation_data=val_ds,
113 | validation_steps=n_val // batch_size // 5,
114 | class_weight=class_weights,
115 | callbacks=[tensorboard_callback])
116 | 
117 | # freeze encoder and unfreeze decoder
118 | for layer in Enet.layers[-6:]:
119 | layer.trainable = True
120 | for layer in Enet.layers[:-6]:
121 | layer.trainable = False
122 | 
123 | # compile model: only the second objective matters
124 | Enet.compile(optimizer=adam_optimizer,
125 | loss=[
126 | 'sparse_categorical_crossentropy',
127 | 'sparse_categorical_crossentropy'
128 | ],
129 | metrics=['accuracy', 'accuracy'],
130 | loss_weights=[0.0, 1.0])
131 | 
132 | # train decoder
133 | enet_hist = Enet.fit(x=train_ds,
134 | epochs=epochs,
135 | steps_per_epoch=n_train // batch_size,
136 | validation_data=val_ds,
137 | validation_steps=n_val // batch_size // 5,
138 | class_weight=class_weights,
139 | callbacks=[tensorboard_callback])
140 | 
141 | # -- simultaneous double-objective training --
142 | elif training_type == 1:
143 | 
144 | # compile model
145 | Enet.compile(optimizer=adam_optimizer,
146 | loss=[
147 | 'sparse_categorical_crossentropy',
148 | 'sparse_categorical_crossentropy'
149 | ],
150 | metrics=['accuracy', 'accuracy'],
151 | loss_weights=[0.5, 0.5])
152 | 
153 | # fit model
154 | print('train: ', n_train, 'batch: ', batch_size)
155 | enet_hist = Enet.fit(x=train_ds,
156 | epochs=epochs,
157 | steps_per_epoch=n_train // batch_size,
158 | validation_data=val_ds,
159 | validation_steps=n_val // batch_size // 5,
160 | class_weight=class_weights,
161 | callbacks=[tensorboard_callback])
162 | 
163 | # -- end to end training --
164 | else:
165 | 
166 | # compile model
167 | Enet.compile(optimizer=adam_optimizer,
168 | loss=['sparse_categorical_crossentropy'],
169 | metrics=['accuracy'])
170 | 
171 | enet_hist = Enet.fit(x=train_ds,
172 | epochs=epochs,
173 | steps_per_epoch=n_train // batch_size,
174 | validation_data=val_ds,
175 | validation_steps=n_val // batch_size // 5,
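# validate on roughly a fifth of the validation set per epoch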
176 | class_weight=class_weights,
177 | callbacks=[tensorboard_callback])
178 | 
179 | # -------------------- save model --------------------
180 | Enet.save_weights(save_model)
181 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | 
5 | def process_img(img_file, h_img, w_img):
6 | '''
7 | Function to process the image file. It takes the image file
8 | name as input and returns the corresponding image tensor.
9 | 
10 | Arguments
11 | ----------
12 | 'img_file' = String: image filename
13 | 'h_img' = Integer: output tensor height
14 | 'w_img' = Integer: output tensor width
15 | 
16 | Returns
17 | -------
18 | 'img' = image tensor
19 | '''
20 | 
21 | # decoding image
22 | img = tf.io.read_file(img_file)
23 | img = tf.image.decode_png(img, channels=3)
24 | img = tf.image.convert_image_dtype(img, tf.float32)
25 | img = tf.image.resize(img, [h_img, w_img])
26 | 
27 | return img
28 | 
29 | 
30 | def map_label(label_file, h_iml, w_iml):
31 | '''
32 | Function to process the label file. It takes the label file
33 | name as input and returns the corresponding label tensor.
34 | 
35 | Arguments
36 | ----------
37 | 'label_file' = String: label filename
38 | 'h_iml' = Integer: output tensor height
39 | 'w_iml' = Integer: output tensor width
40 | 
41 | Returns
42 | -------
43 | 'iml' = label tensor
44 | '''
45 | 
46 | # decoding label image
47 | iml = tf.io.read_file(label_file)
48 | iml = tf.image.decode_png(iml, channels=1)
49 | iml = tf.image.convert_image_dtype(iml, tf.uint8)
50 | iml = tf.image.resize(iml, [h_iml, w_iml], method='nearest')
51 | 
52 | return iml
53 | 
54 | 
55 | def tf_dataset_generator(file_pattern, map_fn):
56 | '''
57 | Creates a training tf.dataset from images or labels matching
58 | 'file_pattern'. Here we do not batch or cache the dataset,
59 | because this will be done by chaining methods at a later
60 | stage
61 | 
62 | Arguments
63 | ----------
64 | 'file_pattern' = glob pattern to match dataset files
65 | 'map_fn' = function to map the filename to a tf tensor
66 | 
67 | Returns
68 | -------
69 | 'data_set' = training tf.dataset to plug into model.fit()
70 | '''
71 | 
72 | # create a list of the training images
73 | data_filelist_ds = tf.data.Dataset.list_files(file_pattern, shuffle=False)
74 | 
75 | # create the labeled dataset (returns (img,label) pairs)
76 | data_set = data_filelist_ds.map(
77 | map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
78 | 
79 | return data_set
80 | 
81 | 
82 | def get_class_weights(data_set, num_classes=12, c=1.02):
83 | '''
84 | Gets segmentation class weights from the dataset. This
85 | is intended for large datasets, where the weights must be
86 | computed without loading everything into memory. If data_set
87 | is already divided into batches, the calculation should be
88 | rather efficient, without blowing up the memory.
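The weights follow the ENet paper: w_class = 1 / ln(c + p_class),
where p_class is the fraction of pixels belonging to that class.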
89 | 
90 | Arguments
91 | ----------
92 | 'data_set' = already batched tf.dataset containing the
93 | correctly resized label tensors
94 | 
95 | Returns
96 | -------
97 | 'class_weights' = class weights for the segmentation classes
98 | '''
99 | 
100 | # count how many pixels belong to each label class
101 | each_class = 0.0
102 | tot_num_pixel = 0.0
103 | for label in data_set.take(-1):
104 | 
105 | # flatten the batch of label arrays
106 | label_array = np.array(label.numpy()).flatten()
107 | 
108 | # counting the pixels
109 | each_class = each_class + np.bincount(label_array,
110 | minlength=num_classes)
111 | tot_num_pixel = tot_num_pixel + len(label_array)
112 | 
113 | # computing the weights as in the original paper
114 | propensity_score = each_class / tot_num_pixel
115 | class_weights = 1 / (np.log(c + propensity_score))
116 | 
117 | return class_weights
118 | 
119 | 
120 | def preprocess_img_label(img_pattern, label_pattern):
121 | '''
122 | Creates the string tensor pairs (img_file, label_file)
123 | 
124 | Arguments
125 | ----------
126 | 'img_pattern' = glob pattern to match img files
127 | 'label_pattern' = glob pattern to match label files
128 | 
129 | Returns
130 | -------
131 | 'pair_filelist' = tf.dataset of (img_file, label_file) pairs
132 | '''
133 | 
134 | # create a list of the training images
135 | img_filelist = tf.data.Dataset.list_files(img_pattern, shuffle=False)
136 | 
137 | # create a list of the label images
138 | label_filelist = tf.data.Dataset.list_files(label_pattern, shuffle=False)
139 | 
140 | # pair filenames
141 | pair_filelist = tf.data.Dataset.zip((img_filelist, label_filelist))
142 | 
143 | return pair_filelist
144 | 
145 | 
146 | def map_singlehead(img_file, label_file, h_img, w_img):
147 | '''
148 | Takes the string tensor pair (img_file, label_file)
149 | and outputs the (img,label) tf tensor pair
150 | 
151 | Arguments
152 | ----------
153 | 'img_file' = string tensor with img filename
154 | 'label_file' = string tensor with label filename
155 | 'h_img' = Integer: output tensor height
156 | 'w_img' = Integer: output tensor width
157 | 
158 | Returns
159 | -------
160 | '(img,iml)' = image and label tensor
161 | '''
162 | 
163 | # decoding image
164 | img = tf.io.read_file(img_file)
165 | img = tf.image.decode_png(img, channels=3)
166 | img = tf.image.convert_image_dtype(img, tf.float32)
167 | img = tf.image.resize(img, [h_img, w_img])
168 | 
169 | # decoding label
170 | iml = tf.io.read_file(label_file)
171 | iml = tf.image.decode_png(iml, channels=1)
172 | iml = tf.image.convert_image_dtype(iml, tf.uint8)
173 | iml = tf.image.resize(iml, [h_img, w_img], method='nearest')
174 | 
175 | return (img, iml)
176 | 
177 | 
178 | def map_doublehead(img_file, label_file, h_enc, w_enc, h_dec, w_dec):
179 | '''
180 | Takes the string tensor pair (img_file, label_file)
181 | and outputs the (img,(iml_enc,iml_dec)) tf tensor pair
182 | 
183 | Arguments
184 | ----------
185 | 'img_file' = string tensor with img filename
186 | 'label_file' = string tensor with label filename
187 | 'h_enc' = Integer: output tensor height at encoder head
188 | 'w_enc' = Integer: output tensor width at encoder head
189 | 'h_dec' = Integer: output tensor height at decoder head
190 | 'w_dec' = Integer: output tensor width at decoder head
191 | 
192 | Returns
193 | -------
194 | '(img,(iml_enc,iml_dec))' = image and label tensors
195 | '''
196 | 
197 | # decoding image
198 | img = tf.io.read_file(img_file)
199 | img = tf.image.decode_png(img, channels=3)
200 | img = tf.image.convert_image_dtype(img, tf.float32)
201 | img = tf.image.resize(img, [h_dec, w_dec])
202 | 
203 | # decoding label at both encoder and decoder resolutions
204 | iml = tf.io.read_file(label_file)
205 | iml = tf.image.decode_png(iml, channels=1)
206 | iml = tf.image.convert_image_dtype(iml, tf.uint8)
207 | iml_enc = tf.image.resize(iml, [h_enc, w_enc], method='nearest')
208 | iml_dec = tf.image.resize(iml, [h_dec, w_dec], method='nearest')
209 | return (img, (iml_enc, iml_dec))
210 | 
--------------------------------------------------------------------------------