├── img
│   └── result.jpg
├── README.md
├── .gitignore
├── predict.py
├── train.py
└── squeezeunet.py


/img/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lhelontra/squeeze-unet/HEAD/img/result.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Squeeze-UNet: Semantic Segmentation for Embedded Devices

The model is inspired by [SqueezeNet](https://arxiv.org/abs/1602.07360) and [U-Net](https://arxiv.org/abs/1505.04597).

The [CamVid](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid) dataset is used for training and evaluation.

![Example of a predicted image.](img/result.jpg)

## Requirements

* Keras 2

## About the network

Like a typical U-Net, the architecture has an encoder path and a decoder path, both built from the fire modules proposed by SqueezeNet, which keeps the parameter count small enough for embedded devices.
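A minimal usage sketch, mirroring what `predict.py` does (the dummy frame below stands in for a preprocessed input):

```python
import numpy as np
from keras.layers import Input

from squeezeunet import SqueezeUNet

# Build the network for 12 CamVid classes on 224x224 RGB inputs.
model = SqueezeUNet(Input((224, 224, 3)), num_classes=12,
                    deconv_ksize=3, activation='softmax')

# One dummy frame; real inputs come from the cvision pipeline (see predict.py).
frame = np.zeros((1, 224, 224, 3), dtype=np.float32)
probs = model.predict(frame)          # shape (1, 224, 224, 12)
labels = np.argmax(probs[0], axis=2)  # per-pixel class IDs
```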
## TODO

- [ ] Report speed vs. accuracy.
- [ ] Add a download link for the CamVid dataset converted with cvision.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Virtual env
.virtualenv/
.idea

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
eggs/
.eggs/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other info into it.
*.manifest

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# Pyenv
.python-version
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cv2
import numpy as np
from keras.layers import Input
from nnutil.image import load_img, im_standardize

import cvision
from squeezeunet import SqueezeUNet


def label_img(img):
    """Map a single-channel class-ID mask to the CamVid BGR colour palette."""
    img = img.copy()
    Sky = [128, 128, 128]
    Building = [128, 0, 0]
    Pole = [192, 192, 128]
    # Road_marking = [255, 69, 0]
    Road = [128, 64, 128]
    Pavement = [60, 40, 222]
    Tree = [128, 128, 0]
    SignSymbol = [192, 128, 128]
    Fence = [64, 64, 128]
    Car = [64, 0, 128]
    Pedestrian = [64, 64, 0]
    Bicyclist = [0, 128, 192]
    Unlabelled = [0, 0, 0]

    # Order must match the class indices used during training (see train.py).
    label_colours = [Unlabelled, Sky, Building, Pole, Road, Pavement, Tree,
                     SignSymbol, Fence, Car, Pedestrian, Bicyclist]

    if img.ndim < 3 or img.shape[-1] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # Paint every pixel whose three channels equal the class ID with its colour.
    for class_id, colour in enumerate(label_colours):
        img[np.where((img == [class_id, class_id, class_id]).all(axis=-1))] = colour
    return img


img_rows = 224
img_cols = 224
channels = 3
inputs = Input((img_rows, img_cols, channels))

model = SqueezeUNet(inputs, num_classes=12, deconv_ksize=3, activation='softmax')
model.load_weights('squeezeunet.h5')

# Preprocessing pipeline defined with cvision.
pipeline_predict_transform = cvision.importFromSketch("sketchs/predict_pipeline.sketch")

img = load_img("~/datasets/camvid/images_validation/0001TP_006720.png", color_mode="bgr")
img = pipeline_predict_transform.run(img)
img = im_standardize(img, rescale=None, mean=None, std=None)

# Predict a (rows, cols, num_classes) softmax map and reduce it to class IDs.
masks = model.predict(np.expand_dims(img, axis=0), verbose=0)
mask = np.argmax(masks[0], axis=2).astype(np.uint8)
mask = label_img(mask)

cv2.imshow("segmentation", mask)
cv2.waitKey(0)
--------------------------------------------------------------------------------
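A side note on visualization: `predict.py` shows the colour mask on its own. To see it over the input frame instead, one hypothetical option (not part of this repo) is plain alpha blending with OpenCV:

```python
import cv2

def overlay_mask(bgr_img, colour_mask, alpha=0.6):
    """Alpha-blend a BGR colour mask (from label_img) onto the input frame.
    Both arrays must be uint8 and the same size."""
    return cv2.addWeighted(bgr_img, alpha, colour_mask, 1.0 - alpha, 0)
```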
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
from keras.layers import Input
from keras.optimizers import Adam

import cvision
from nnutil.dataset import ImageDataGenerator
from squeezeunet import SqueezeUNet

# Input geometry and training schedule.
img_rows = 224
img_cols = 224
channels = 3
epochs = 10
batch_size = 1
nb_train_samples = 5000
nb_validation_samples = 2000
save_to_dir = None
inputs = Input((img_rows, img_cols, channels))

# CamVid class IDs; label_img() in predict.py relies on this exact order.
class_indices = {
    "unlabelled": 0,
    "sky": 1,
    "building": 2,
    "pole": 3,
    "road": 4,
    "pavement": 5,
    "tree": 6,
    "signsymbol": 7,
    "fence": 8,
    "car": 9,
    "pedestrian": 10,
    "bicyclist": 11,
}

datasets = ["~/datasets/camvid/train-camvid.json"]
validation_datasets = ["~/datasets/camvid/validation-camvid.json"]

# Augmentation / preprocessing pipeline defined with cvision.
pipeline_train_transform = cvision.importFromSketch("sketchs/train_pipeline.sketch")
train_imGenerator = ImageDataGenerator(pipeline_train_transform,
                                       dataset_mean=None,
                                       dataset_std_normalization=None)

segmentation_generator = train_imGenerator.flow_json_segmentation(
    datasets,
    class_indices=class_indices,
    class_mode="categorical",
    skipImageNonAnnotations=True,
    nonAnnotationLabel=None,
    batch_size=batch_size,
    mask_rescale=None,
    mask_transform=None,
    save_to_dir=save_to_dir,
    save_prefix='test',
    save_format='jpg')

validation_generator = train_imGenerator.flow_json_segmentation(
    validation_datasets,
    class_indices=class_indices,
    class_mode="categorical",
    skipImageNonAnnotations=True,
    nonAnnotationLabel=None,
    batch_size=batch_size,
    mask_rescale=None,
    mask_transform=None,
    save_to_dir=save_to_dir,
    save_prefix='test',
    save_format='jpg')

model = SqueezeUNet(inputs, num_classes=12, deconv_ksize=3, activation='softmax')
model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=1e-05), metrics=["accuracy"])

# Resume from previous weights when available.
if os.path.exists('squeezeunet.h5'):
    model.load_weights('squeezeunet.h5')

model.fit_generator(segmentation_generator,
                    class_weight="auto",
                    steps_per_epoch=nb_train_samples,
                    epochs=epochs,
                    validation_data=validation_generator,
                    validation_steps=nb_validation_samples)

model.save_weights('squeezeunet.h5', overwrite=True)
--------------------------------------------------------------------------------
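`train.py` only writes `squeezeunet.h5` once, after the last epoch. A common refinement, sketched here with standard Keras 2 callbacks (not part of the repo), is to checkpoint the best validation weights and stop early when validation loss stalls:

```python
from keras.callbacks import ModelCheckpoint, EarlyStopping

callbacks = [
    # Keep only the weights with the best validation loss seen so far.
    ModelCheckpoint('squeezeunet.h5', monitor='val_loss',
                    save_best_only=True, save_weights_only=True),
    # Abort training after 3 epochs without improvement.
    EarlyStopping(monitor='val_loss', patience=3),
]
# Then pass callbacks=callbacks to model.fit_generator(...) above.
```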
/squeezeunet.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Dropout
from keras.layers import concatenate, Conv2DTranspose, BatchNormalization
from keras import backend as K


def fire_module(x, fire_id, squeeze=16, expand=64):
    """SqueezeNet fire module: a 1x1 squeeze convolution followed by parallel
    1x1 and 3x3 expand convolutions whose outputs are concatenated."""
    f_name = "fire{0}/{1}"
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    x = Conv2D(squeeze, (1, 1), activation='relu', padding='same', name=f_name.format(fire_id, "squeeze1x1"))(x)
    x = BatchNormalization(axis=channel_axis)(x)

    left = Conv2D(expand, (1, 1), activation='relu', padding='same', name=f_name.format(fire_id, "expand1x1"))(x)
    right = Conv2D(expand, (3, 3), activation='relu', padding='same', name=f_name.format(fire_id, "expand3x3"))(x)
    x = concatenate([left, right], axis=channel_axis, name=f_name.format(fire_id, "concat"))
    return x


def SqueezeUNet(inputs, num_classes=None, deconv_ksize=3, dropout=0.5, activation='sigmoid'):
    """SqueezeUNet is an implementation based on SqueezeNet v1.1 and U-Net for semantic segmentation.

    :param inputs: input layer.
    :param num_classes: number of classes; defaults to the number of input channels.
    :param deconv_ksize: integer or tuple of (width, height) of the 2D deconvolution window.
    :param dropout: dropout rate applied at the bottleneck.
    :param activation: activation of the top layer.
    :returns: SqueezeUNet model.
    """
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    if num_classes is None:
        num_classes = K.int_shape(inputs)[channel_axis]

    # Encoder: SqueezeNet v1.1 layout, halving the grid at each stage.
    x01 = Conv2D(64, (3, 3), strides=(2, 2), padding='same', activation='relu', name='conv1')(inputs)
    x02 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1', padding='same')(x01)

    x03 = fire_module(x02, fire_id=2, squeeze=16, expand=64)
    x04 = fire_module(x03, fire_id=3, squeeze=16, expand=64)
    x05 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3', padding="same")(x04)

    x06 = fire_module(x05, fire_id=4, squeeze=32, expand=128)
    x07 = fire_module(x06, fire_id=5, squeeze=32, expand=128)
    x08 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool5', padding="same")(x07)

    # Bottleneck: four fire modules at the coarsest resolution.
    x09 = fire_module(x08, fire_id=6, squeeze=48, expand=192)
    x10 = fire_module(x09, fire_id=7, squeeze=48, expand=192)
    x11 = fire_module(x10, fire_id=8, squeeze=64, expand=256)
    x12 = fire_module(x11, fire_id=9, squeeze=64, expand=256)

    if dropout != 0.0:
        x12 = Dropout(dropout)(x12)

    # Decoder: transposed convolutions with U-Net style skip connections
    # back into the encoder, each followed by a fire module.
    up1 = concatenate([
        Conv2DTranspose(192, deconv_ksize, strides=(1, 1), padding='same')(x12),
        x10,
    ], axis=channel_axis)
    up1 = fire_module(up1, fire_id=10, squeeze=48, expand=192)

    up2 = concatenate([
        Conv2DTranspose(128, deconv_ksize, strides=(1, 1), padding='same')(up1),
        x08,
    ], axis=channel_axis)
    up2 = fire_module(up2, fire_id=11, squeeze=32, expand=128)

    up3 = concatenate([
        Conv2DTranspose(64, deconv_ksize, strides=(2, 2), padding='same')(up2),
        x05,
    ], axis=channel_axis)
    up3 = fire_module(up3, fire_id=12, squeeze=16, expand=64)

    up4 = concatenate([
        Conv2DTranspose(32, deconv_ksize, strides=(2, 2), padding='same')(up3),
        x02,
    ], axis=channel_axis)
    up4 = fire_module(up4, fire_id=13, squeeze=16, expand=32)
    up4 = UpSampling2D(size=(2, 2))(up4)

    # Head: fuse with the stride-2 conv1 features, upsample to the input
    # resolution, and predict per-pixel class scores with a 1x1 convolution.
    x = concatenate([up4, x01], axis=channel_axis)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu')(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(num_classes, (1, 1), activation=activation)(x)

    return Model(inputs=inputs, outputs=x)
--------------------------------------------------------------------------------
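A quick sanity check of the wiring (a sketch; the 224×224×3 input matches `train.py`, and the shapes follow from the strides above):

```python
from keras.layers import Input
from squeezeunet import SqueezeUNet

# The encoder halves the grid four times (224 -> 112 -> 56 -> 28 -> 14) and
# the decoder mirrors it back up, so the mask resolution matches the input.
model = SqueezeUNet(Input((224, 224, 3)), num_classes=12,
                    deconv_ksize=3, activation='softmax')
print(model.output_shape)    # expected: (None, 224, 224, 12)
print(model.count_params())  # parameter count, relevant for embedded targets
```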