├── 2007_000033.jpg
├── 2008_000438.jpg
├── 2008_000438_resized.jpg
├── BusServices-FlyBus_468x280.jpg
├── README.md
├── model_final.py
├── predict.py
├── predict2.py
├── result_straight.png
├── result_straight_enhanced.png
├── result_straight_enhanced_edge.png
├── result_straight_th.png
├── train1.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── image_reader.cpython-36.pyc
    └── image_reader.py

/2007_000033.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/2007_000033.jpg
--------------------------------------------------------------------------------
/2008_000438.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/2008_000438.jpg
--------------------------------------------------------------------------------
/2008_000438_resized.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/2008_000438_resized.jpg
--------------------------------------------------------------------------------
/BusServices-FlyBus_468x280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/BusServices-FlyBus_468x280.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MRF-Segmentation
Semantic segmentation based on Markov Random Fields (MRF) and deep learning, implemented as a project for the Computer Vision (CSE 578) course at [IIIT Hyderabad](https://www.iiit.ac.in/).
3 | 4 | Paper Followed - [Deep Learning Markov Random Field for Semantic Segmentation](https://arxiv.org/pdf/1606.07230.pdf) 5 | 6 | Github Link - [MRF-Segmentation](https://github.com/Pi-Rasp/MRF-Segmentation) 7 | 8 | Authors - [Nitin Nilesh](https://github.com/Pi-Rasp), [Abhinav Dhere](https://github.com/abhinavdhere), [Ashish Kubade](https://github.com/Ashj9) 9 | 10 | -------------------------------------------------------------------------------- /model_final.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import keras 4 | from keras.models import Sequential 5 | from keras.layers import Input, Dense, Activation, Flatten, Reshape, Dropout 6 | from keras.layers import Conv2D, Conv3D, MaxPooling2D, UpSampling2D, Conv2DTranspose 7 | from keras.models import Model 8 | from keras import backend as K 9 | import h5py 10 | from keras.applications.vgg16 import VGG16 11 | import gc 12 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 13 | 14 | def DPN(input_shape): 15 | #Take VGG first 4 blocks, where last block without pool layer 16 | vgg16 = VGG16(weights='imagenet',include_top=False, input_shape=(224,224,3)) 17 | #Block 5 18 | x = Conv2D(512, (5, 5), activation='relu', padding='same', name='block5_conv1')(vgg16.layers[-6].output) 19 | x = Conv2D(512, (5, 5), activation='relu', padding='same', name='block5_conv2')(x) 20 | x = Conv2D(512, (5, 5), activation='relu', padding='same', name='block5_conv3')(x) 21 | #Block 6 22 | x = Conv2D(512, (25, 25), activation='relu', padding='same', name='block6_conv1', kernel_initializer='glorot_uniform')(x) 23 | x = Conv2D(4096, (1, 1), activation='relu', padding='same', name = 'block6_conv2',kernel_initializer='glorot_uniform')(x) 24 | x = Conv2D(21, (1, 1), activation='relu', padding='same', name = 'block6_conv3', kernel_initializer='glorot_uniform')(x) 25 | x = UpSampling2D((8,8))(x) 26 | #x = Reshape((224*224,1))(x) 27 | #x = Activation('softmax')(x) 28 | # Model 
created for Unary terms 29 | 30 | model = Model(inputs=vgg16.input, outputs=x) 31 | model.summary() 32 | change_weights_conv(model,vgg16,14,15) 33 | change_weights_conv(model,vgg16, 15,16) 34 | change_weights_conv(model,vgg16, 16,17) 35 | print('Conv layers weight transfer done') 36 | #change_weights_fc1(model, 17, weights_path, obj=[1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6], old_dim = (7,7,512,4096), new_dim = (25,25,512,4096)) 37 | #print('FC1 Done') 38 | #change_weights_fc2(model, 18, weights_path, new_dim = (1,1,4096,4096)) 39 | #print('FC2 Done') 40 | return model 41 | 42 | def change_weights_conv(model,vgg16,m_layer_number,v_layer_number,obj = [1,2], new_dim = (5,5,512,512)): 43 | weights = vgg16.layers[v_layer_number].get_weights()[0] 44 | bias = vgg16.layers[v_layer_number].get_weights()[1] 45 | weights_modified = np.random.randn(5,5,512,512)*10 46 | for i in range(weights.shape[2]): 47 | for j in range(weights.shape[3]): 48 | temp = np.insert(weights[:,:,i,j],obj,0,axis=1) 49 | temp = np.insert(temp.T,obj,0,axis=1).T 50 | weights_modified[:,:,i,j] = temp 51 | new_weights = [weights_modified,bias] 52 | model.layers[m_layer_number].set_weights(new_weights) 53 | del new_weights 54 | gc.collect() 55 | 56 | def change_weights_fc1(model, m_layer_number, weights_path, obj, old_dim, new_dim): 57 | f = h5py.File(weights_path) 58 | weights = np.asarray(f['fc1']['fc1_W_1:0']) 59 | bias = np.asarray(f['fc1']['fc1_b_1:0']) 60 | weights = weights.reshape(7,7,512,4096) 61 | weights_modified = np.zeros(new_dim) 62 | for i in range(weights.shape[2]): 63 | for j in range(weights.shape[3]): 64 | temp = np.insert(weights[:,:,i,j],obj,0,axis=1) 65 | temp = np.insert(temp.T,obj,0,axis=1).T 66 | weights_modified[:,:,i,j] = temp 67 | new_weights = [weights_modified,bias] 68 | model.layers[m_layer_number].set_weights(new_weights) 69 | del new_weights 70 | gc.collect() 71 | 72 | def change_weights_fc2(model, m_layer_number, weights_path, new_dim): 73 | f = h5py.File(weights_path) 74 | 
weights = np.asarray(f['fc2']['fc2_W_1:0']) 75 | bias = np.asarray(f['fc2']['fc2_b_1:0']) 76 | weights_modified = np.zeros(new_dim) 77 | for i in range(weights.shape[0]): 78 | for j in range(weights.shape[1]): 79 | weights_modified[:,:,i,j] = weights[i,j] 80 | new_weights = [weights_modified,bias] 81 | model.layers[m_layer_number].set_weights(new_weights) 82 | del new_weights 83 | gc.collect() 84 | #K.clear_session() 85 | 86 | def main(): 87 | model = DPN((224,224,3)) 88 | model.summary() 89 | print('done') 90 | 91 | if __name__=='__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from keras.models import load_model 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | from utils import pascal_palette, interp_map 6 | import scipy.misc as misc 7 | import matplotlib.pyplot as plt 8 | import keras.backend as K 9 | from keras.preprocessing.image import load_img, img_to_array,flip_axis 10 | 11 | input_width, input_height = 224, 224 12 | label_margin = 186 13 | pascal_mean = np.array([102.93, 111.36, 116.52]) 14 | input_image_path = '2007_000033.jpg' 15 | 16 | def get_trained_model(): 17 | model_name = 'DPN_Adam_12ep_1e_8_wodec.h5' 18 | model = load_model(model_name) 19 | return model 20 | 21 | def decode_segmap(label_mask): 22 | label_colours = pascal_palette 23 | r = label_mask.copy() 24 | g = label_mask.copy() 25 | b = label_mask.copy() 26 | for ll in range(0, 21): 27 | r[label_mask == ll] = label_colours[ll, 0] 28 | g[label_mask == ll] = label_colours[ll, 1] 29 | b[label_mask == ll] = label_colours[ll, 2] 30 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 31 | rgb[:, :, 0] = r 32 | rgb[:, :, 1] = g 33 | rgb[:, :, 2] = b 34 | plt.imshow(rgb) 35 | plt.show() 36 | return rgb 37 | 38 | def forward_pass(): 39 | model = get_trained_model() 40 | #input_image = 
np.array(Image.open(input_image_path)).astype(np.float32) 41 | input_image = misc.imread(input_image_path).astype(np.float32) 42 | input_image = misc.imresize(input_image,(input_height,input_width,3),interp='bicubic') 43 | input_image = input_image.astype(float) / 255.0 44 | image = input_image[:, :, ::-1] - pascal_mean 45 | image_size = image.shape 46 | 47 | net_in = np.zeros((1, input_height, input_width, 3), dtype=np.float32) 48 | '''output_height = input_height - 2 * label_margin 49 | output_width = input_width - 2 * label_margin 50 | image = np.pad(image,((label_margin, label_margin), 51 | (label_margin, label_margin), 52 | (0, 0)), 'reflect') 53 | 54 | margins_h = (0, input_height - image.shape[0]) 55 | margins_w = (0, input_width - image.shape[1]) 56 | image = np.pad(image, 57 | (margins_h, 58 | margins_w, 59 | (0, 0)), 'reflect')''' 60 | 61 | # Run inference 62 | net_in[0] = image 63 | prob = K.eval(K.softmax(model.predict(net_in)[0])) 64 | prob_edge = int(prob.shape[0]) 65 | prob = prob.reshape((input_height,input_width)) 66 | seg_image = decode_segmap(prob).astype(np.uint8) 67 | misc.imsave('result.png',seg_image) 68 | #print(prob) 69 | 70 | if __name__ == '__main__': 71 | forward_pass() 72 | -------------------------------------------------------------------------------- /predict2.py: -------------------------------------------------------------------------------- 1 | from keras.models import load_model 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | from utils import pascal_palette, interp_map 6 | from utils import image_reader as ir 7 | import scipy.misc as misc 8 | import matplotlib.pyplot as plt 9 | import keras.backend as K 10 | from keras.preprocessing.image import load_img, img_to_array,flip_axis 11 | 12 | input_width, input_height = 224, 224 13 | label_margin = 186 14 | pascal_mean = np.array([102.93, 111.36, 116.52]) 15 | input_image_path = '2009_004436.jpg' 16 | 17 | def add_context_margin(image, margin_size, **pad_kwargs): 
18 | """ Adds a margin-size border around the image, used for 19 | providing context. """ 20 | return np.pad(image, 21 | ((margin_size, margin_size), 22 | (margin_size, margin_size), 23 | (0, 0)), **pad_kwargs) 24 | 25 | def pad_to_square(image, min_size, **pad_kwargs): 26 | """ Add padding to make sure that the image is larger than (min_size * min_size). 27 | This time, the image is aligned to the top left corner. """ 28 | 29 | h, w = image.shape[:2] 30 | 31 | if h >= min_size and w >= min_size: 32 | return image 33 | 34 | top = bottom = left = right = 0 35 | 36 | if h < min_size: 37 | top = (min_size - h) // 2 38 | bottom = min_size - h - top 39 | if w < min_size: 40 | left = (min_size - w) // 2 41 | right = min_size - w - left 42 | 43 | return np.pad(image, 44 | ((top, bottom), 45 | (left, right), 46 | (0, 0)), **pad_kwargs) 47 | 48 | def pad_image(image): 49 | image_pad_kwargs = dict(mode='reflect') 50 | image = add_context_margin(image, label_margin, **image_pad_kwargs) 51 | return pad_to_square(image, 224, **image_pad_kwargs) 52 | 53 | def crop_to(image, target_h=224, target_w=224): 54 | h_off = (image.shape[0] - target_h) // 2 55 | w_off = (image.shape[1] - target_w) // 2 56 | return image[h_off:h_off + target_h, 57 | w_off:w_off + target_w, :] 58 | 59 | def rgb_to_bgr(image): 60 | # Swap color channels to use pretrained VGG weights 61 | return image[:, :, ::-1] 62 | 63 | def remove_mean(image): 64 | # Note that there's no 0..1 normalization in VGG 65 | return image - pascal_mean 66 | 67 | def nll(y_true,y_pred): 68 | import keras.backend as K 69 | return K.sum(K.binary_crossentropy(y_true,y_pred)) 70 | 71 | def get_trained_model(): 72 | model_name = 'DPN_Adam_50ep_1e_8_wodec_womgpu.h5' 73 | model = load_model(model_name,custom_objects={'nll': nll}) 74 | return model 75 | 76 | def decode_segmap(label_mask): 77 | label_colours = pascal_palette 78 | r = label_mask.copy() 79 | g = label_mask.copy() 80 | b = label_mask.copy() 81 | for ll in range(0, 21): 82 | 
r[label_mask == ll] = label_colours[ll, 0] 83 | g[label_mask == ll] = label_colours[ll, 1] 84 | b[label_mask == ll] = label_colours[ll, 2] 85 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 86 | rgb[:, :, 0] = r#/255.0 *100.0 87 | rgb[:, :, 1] = g#/255.0*100.0 88 | rgb[:, :, 2] = b#/255.0*100.0 89 | #plt.imshow(rgb) 90 | #plt.show() 91 | return rgb 92 | 93 | def forward_pass(): 94 | import scipy.io as sio 95 | model = get_trained_model() 96 | image = misc.imread(input_image_path) 97 | image = misc.imresize(image,(224,224,3),interp='nearest') 98 | #image = pad_image(image) 99 | #image = crop_to(image) 100 | #image = rgb_to_bgr(image) 101 | #image = remove_mean(image) 102 | #image = image.astype(np.float64) 103 | image = image.astype(float)/255.0 104 | 105 | net_in = np.zeros((1, input_height, input_width, 3), dtype=np.float32) 106 | '''output_height = input_height - 2 * label_margin 107 | output_width = input_width - 2 * label_margin 108 | image = np.pad(image,((label_margin, label_margin), 109 | (label_margin, label_margin), 110 | (0, 0)), 'reflect') 111 | 112 | margins_h = (0, input_height - image.shape[0]) 113 | margins_w = (0, input_width - image.shape[1]) 114 | image = np.pad(image, 115 | (margins_h, 116 | margins_w, 117 | (0, 0)), 'reflect')''' 118 | 119 | # Run inference 120 | net_in[0] = image 121 | prob = ((model.predict(net_in)[0]))#K.eval, K.softmax 122 | prob_edge = int(prob.shape[0]) 123 | prob = prob.reshape((input_height,input_width)) 124 | seg_image = decode_segmap(prob)#.astype(np.uint8) 125 | #misc.imsave('result_straight.png',seg_image) 126 | #sio.savemat('result_raw.mat',mdict={'seg_image':seg_image}) 127 | plt.imshow(seg_image) 128 | plt.show() 129 | #print(prob) 130 | 131 | if __name__ == '__main__': 132 | forward_pass() 133 | -------------------------------------------------------------------------------- /result_straight.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/result_straight.png -------------------------------------------------------------------------------- /result_straight_enhanced.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/result_straight_enhanced.png -------------------------------------------------------------------------------- /result_straight_enhanced_edge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/result_straight_enhanced_edge.png -------------------------------------------------------------------------------- /result_straight_th.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/result_straight_th.png -------------------------------------------------------------------------------- /train1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import time 6 | import shutil 7 | 8 | import click 9 | import numpy as np 10 | from keras import callbacks, optimizers 11 | from IPython import embed 12 | 13 | from model_final import * 14 | from utils.image_reader import ( 15 | RandomTransformer, 16 | SegmentationDataGenerator) 17 | from keras.utils.multi_gpu_utils import multi_gpu_model 18 | import keras.backend as K 19 | import tensorflow as tf 20 | import torch 21 | import numpy as np 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | @click.command() 26 | @click.option('--train-list-fname', type=click.Path(exists=True), 27 | 
default='./VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt') 28 | @click.option('--val-list-fname', type=click.Path(exists=True), 29 | default='./VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt') 30 | @click.option('--img-root', type=click.Path(exists=True), 31 | default='./VOCdevkit/VOC2012/JPEGImages') 32 | @click.option('--mask-root', type=click.Path(exists=True), 33 | default='./VOCdevkit/VOC2012/SegmentationClass') 34 | @click.option('--batch-size', type=int, default=1) 35 | @click.option('--learning-rate', type=float, default=1e-7) 36 | def train(train_list_fname, 37 | val_list_fname, 38 | img_root, 39 | mask_root, 40 | batch_size, 41 | learning_rate): 42 | 43 | # Create image generators for the training and validation sets. Validation has 44 | # no data augmentation. 45 | transformer_train = RandomTransformer(horizontal_flip=True, vertical_flip=True) 46 | datagen_train = SegmentationDataGenerator(transformer_train) 47 | 48 | transformer_val = RandomTransformer(horizontal_flip=False, vertical_flip=False) 49 | datagen_val = SegmentationDataGenerator(transformer_val) 50 | 51 | '''train_desc = '{}-lr{:.0e}-bs{:03d}'.format( 52 | time.strftime("%Y-%m-%d %H:%M"), 53 | learning_rate, 54 | batch_size) 55 | checkpoints_folder = 'trained/' + train_desc 56 | try: 57 | os.makedirs(checkpoints_folder) 58 | except OSError: 59 | shutil.rmtree(checkpoints_folder, ignore_errors=True) 60 | os.makedirs(checkpoints_folder) 61 | 62 | model_checkpoint = callbacks.ModelCheckpoint( 63 | checkpoints_folder + '/ep{epoch:02d}-vl{val_loss:.4f}.hdf5', 64 | monitor='loss')''' 65 | 66 | 67 | '''tensorboard_cback = callbacks.TensorBoard( 68 | log_dir='{}/tboard'.format(checkpoints_folder), 69 | histogram_freq=0, 70 | write_graph=False, 71 | write_images=False) 72 | csv_log_cback = callbacks.CSVLogger( 73 | '{}/history.log'.format(checkpoints_folder)) 74 | reduce_lr_cback = callbacks.ReduceLROnPlateau( 75 | monitor='val_loss', 76 | factor=0.2, 77 | patience=5, 78 | verbose=1, 79 | 
min_lr=0.05 * learning_rate)''' 80 | 81 | model = DPN((224,224,3)) 82 | #gpu_model = multi_gpu_model(model, gpus=2) 83 | def nll(y_true,y_pred): 84 | loss = K.sum(K.binary_crossentropy(y_true,y_pred),axis=-1) 85 | return loss 86 | def cross_entropy_2d(target,input): 87 | input = K.eval(input) 88 | input = Variable(torch.from_numpy(input).type(torch.FloatTensor)) 89 | target = K.eval(target) 90 | target = Variable(torch.from_numpy(target).type(torch.LongTensor)) 91 | 92 | n, h, w, c = input.size() 93 | log_p = F.log_softmax(input, dim=1) 94 | log_p = log_p.contiguous().view(-1, c) 95 | log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0] 96 | log_p = log_p.view(-1, c) 97 | 98 | mask = target >= 0 99 | target = target[mask] 100 | loss = F.nll_loss(log_p, target, ignore_index=250,weight=weight, size_average=False) 101 | loss /= mask.data.sum() 102 | return loss.data[0] 103 | model.compile(loss=nll,optimizer=optimizers.Adam(lr=1e-8),metrics=['accuracy']) 104 | # Build absolute image paths 105 | def build_abs_paths(basenames): 106 | img_fnames = [os.path.join(img_root, f) + '.jpg' for f in basenames] 107 | mask_fnames = [os.path.join(mask_root, f) + '.png' for f in basenames] 108 | return img_fnames, mask_fnames 109 | 110 | train_basenames = [l.strip() for l in open(train_list_fname).readlines()] 111 | val_basenames = [l.strip() for l in open(val_list_fname).readlines()][:512] 112 | 113 | train_img_fnames, train_mask_fnames = build_abs_paths(train_basenames) 114 | val_img_fnames, val_mask_fnames = build_abs_paths(val_basenames) 115 | 116 | '''skipped_report_cback = callbacks.LambdaCallback( 117 | on_epoch_end=lambda a, b: open( 118 | '{}/skipped.txt'.format(checkpoints_folder), 'a').write( 119 | '{}\n'.format(datagen_train.skipped_count)))''' 120 | 121 | model.fit_generator( 122 | datagen_train.flow_from_list( 123 | train_img_fnames, 124 | train_mask_fnames, 125 | shuffle=True, 126 | batch_size=8, 127 | img_target_size=(224, 224), 128 | mask_target_size=(224, 
224)), 129 | steps_per_epoch=len(train_basenames)//8, 130 | epochs=50, 131 | validation_data=datagen_val.flow_from_list( 132 | val_img_fnames, 133 | val_mask_fnames, 134 | batch_size=8, 135 | img_target_size=(224, 224), 136 | mask_target_size=(224, 224)),validation_steps=len(val_basenames)//8) 137 | model.save('DPN_Adam_50ep_1e_8_wodec_womgpu.h5') 138 | 139 | 140 | if __name__ == '__main__': 141 | train() 142 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | pascal_nclasses = 21 4 | pascal_palette = np.array([(0, 0, 0) 5 | , (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128) 6 | , (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0) 7 | , (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128) 8 | , (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)], dtype=np.uint8) 9 | 10 | 11 | # 0=background 12 | # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle 13 | # 6=bus, 7=car, 8=cat, 9=chair, 10=cow 14 | # 11=diningtable, 12=dog, 13=horse, 14=motorbike, 15=person 15 | # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor 16 | 17 | 18 | def mask_to_label(mask_rgb): 19 | """From color-coded RGB mask to classes [0-21] 20 | mask_labels = np.zeros(mask_rgb.shape[:2]) 21 | 22 | for i in range(mask_rgb.shape[0]): 23 | for j in range(mask_rgb.shape[1]): 24 | mask_labels[i, j] = pascal_palette.index(tuple(mask_rgb[i, j, :].astype(np.uint8))) 25 | 26 | return mask_labels""" 27 | mask = mask_rgb.astype(int) 28 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 29 | for ii, label in enumerate(pascal_palette): 30 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 31 | label_mask = label_mask.astype(int) 32 | return label_mask 33 | 34 | 35 | def interp_map(prob, zoom, width, height): 36 | zoom_prob = 
np.zeros((height, width, prob.shape[2]), dtype=np.float32) 37 | for c in range(prob.shape[2]): 38 | for h in range(height): 39 | for w in range(width): 40 | r0 = h // zoom 41 | r1 = r0 + 1 42 | c0 = w // zoom 43 | c1 = c0 + 1 44 | rt = float(h) / zoom - r0 45 | ct = float(w) / zoom - c0 46 | v0 = rt * prob[r1, c0, c] + (1 - rt) * prob[r0, c0, c] 47 | v1 = rt * prob[r1, c1, c] + (1 - rt) * prob[r0, c1, c] 48 | zoom_prob[h, w, c] = (1 - ct) * v0 + ct * v1 49 | return zoom_prob 50 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_reader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nitinnilesh/MRF-Segmentation/3181ae862f59d8017ba59059b6f7e7d0bf6b27c7/utils/__pycache__/image_reader.cpython-36.pyc -------------------------------------------------------------------------------- /utils/image_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import namedtuple 4 | from utils import mask_to_label 5 | import click 6 | import numpy as np 7 | from IPython import embed 8 | from keras.preprocessing.image import ( 9 | load_img, img_to_array, 10 | flip_axis) 11 | 12 | # The set of parameters that describes an instance of 13 | # (random) augmentation 14 | TransformParams = namedtuple( 15 | 'TransformParameters', 16 | ('do_hor_flip', 'do_vert_flip')) 17 | 18 | pascal_mean = np.array([102.93, 111.36, 116.52]) 19 | 20 | label_margin = 186 21 | 22 | 23 | def load_img_array(fname, grayscale=False, 
target_size=None, dim_ordering='default'): 24 | """Loads and image file and returns an array.""" 25 | img = load_img(fname, 26 | grayscale=grayscale, 27 | target_size=target_size) 28 | x = img_to_array(img) 29 | return x 30 | 31 | 32 | class RandomTransformer: 33 | """To consistently add data augmentation to image pairs, we split the process in 34 | two steps. First, we generate a stream of random augmentation parameters, that 35 | can be zipped together with the images. Second, we do the actual transformation, 36 | that has no randomness since the parameters are passed in.""" 37 | 38 | def __init__(self, 39 | horizontal_flip=False, 40 | vertical_flip=False): 41 | self.horizontal_flip = horizontal_flip 42 | self.vertical_flip = vertical_flip 43 | 44 | def random_params_gen(self) -> TransformParams: 45 | """Returns a generator of random transformation parameters.""" 46 | while True: 47 | do_hor_flip = self.horizontal_flip and (np.random.random() < 0.5) 48 | do_vert_flip = self.vertical_flip and (np.random.random() < 0.5) 49 | 50 | yield TransformParams(do_hor_flip=do_hor_flip, 51 | do_vert_flip=do_vert_flip) 52 | 53 | @staticmethod 54 | def transform(x: np.array, params: TransformParams) -> np.array: 55 | """Transforms a single image according to the parameters given.""" 56 | if params.do_hor_flip: 57 | x = flip_axis(x, 1) 58 | 59 | if params.do_vert_flip: 60 | x = flip_axis(x, 0) 61 | 62 | return x 63 | 64 | 65 | class SegmentationDataGenerator: 66 | """A data generator for segmentation tasks, similar to ImageDataGenerator 67 | in Keras, but with distinct pipelines for images and masks. 68 | 69 | The idea is that this object holds no data, and only knows how to run 70 | the pipeline to load, augment, and batch samples. The actual data (csv, 71 | numpy, etc..) 
must be passed in to the fit/flow functions directly.""" 72 | 73 | skipped_count = 0 74 | 75 | def __init__(self, 76 | random_transformer: RandomTransformer): 77 | self.random_transformer = random_transformer 78 | 79 | def get_processed_pairs(self, 80 | img_fnames, 81 | mask_fnames): 82 | # Generators for image data 83 | img_arrs = (load_img_array(f) for f in img_fnames) 84 | mask_arrs = (load_img_array(f, grayscale=True) for f in mask_fnames) 85 | 86 | def add_context_margin(image, margin_size, **pad_kwargs): 87 | """ Adds a margin-size border around the image, used for 88 | providing context. """ 89 | return np.pad(image, 90 | ((margin_size, margin_size), 91 | (margin_size, margin_size), 92 | (0, 0)), **pad_kwargs) 93 | 94 | def pad_to_square(image, min_size, **pad_kwargs): 95 | """ Add padding to make sure that the image is larger than (min_size * min_size). 96 | This time, the image is aligned to the top left corner. """ 97 | 98 | h, w = image.shape[:2] 99 | 100 | if h >= min_size and w >= min_size: 101 | return image 102 | 103 | top = bottom = left = right = 0 104 | 105 | if h < min_size: 106 | top = (min_size - h) // 2 107 | bottom = min_size - h - top 108 | if w < min_size: 109 | left = (min_size - w) // 2 110 | right = min_size - w - left 111 | 112 | return np.pad(image, 113 | ((top, bottom), 114 | (left, right), 115 | (0, 0)), **pad_kwargs) 116 | 117 | def pad_image(image): 118 | image_pad_kwargs = dict(mode='reflect') 119 | image = add_context_margin(image, label_margin, **image_pad_kwargs) 120 | return pad_to_square(image, 224, **image_pad_kwargs) 121 | 122 | def pad_label(image): 123 | # Same steps as the image, but the borders are constant white 124 | label_pad_kwargs = dict(mode='constant', constant_values=255) 125 | image = add_context_margin(image, label_margin, **label_pad_kwargs) 126 | return pad_to_square(image, 224, **label_pad_kwargs) 127 | 128 | pairs = ((pad_image(image), pad_label(label)) for 129 | image, label in zip(img_arrs, mask_arrs)) 
130 | 131 | # random/center crop 132 | def crop_to(image, target_h=224, target_w=224): 133 | # TODO: random cropping 134 | h_off = (image.shape[0] - target_h) // 2 135 | w_off = (image.shape[1] - target_w) // 2 136 | return image[h_off:h_off + target_h, 137 | w_off:w_off + target_w, :] 138 | 139 | pairs = ((crop_to(image), crop_to(label)) for image, label in pairs) 140 | 141 | # random augmentation 142 | augmentation_params = self.random_transformer.random_params_gen() 143 | transf_fn = self.random_transformer.transform 144 | pairs = ((transf_fn(image, params), transf_fn(label, params)) for 145 | ((image, label), params) in zip(pairs, augmentation_params)) 146 | 147 | def rgb_to_bgr(image): 148 | # Swap color channels to use pretrained VGG weights 149 | return image[:, :, ::-1] 150 | 151 | pairs = ((rgb_to_bgr(image), rgb_to_bgr(label)) for 152 | image, label in pairs) 153 | 154 | def remove_mean(image): 155 | # Note that there's no 0..1 normalization in VGG 156 | return image - pascal_mean 157 | 158 | pairs = ((remove_mean(image), label) for image, label in pairs) 159 | 160 | def slice_label(image, offset, label_size, stride): 161 | # Builds label_size * label_size pixels labels, starting from 162 | # offset from the original image, and stride stride 163 | return image[offset:offset + label_size * stride:stride, 164 | offset:offset + label_size * stride:stride] 165 | 166 | #pairs = ((image, slice_label(label, label_margin, 16, 8)) for 167 | # image, label in pairs) 168 | 169 | return pairs 170 | 171 | def flow_from_list(self, 172 | img_fnames, 173 | mask_fnames, 174 | batch_size, 175 | img_target_size, 176 | mask_target_size, 177 | shuffle=False): 178 | assert batch_size > 0 179 | 180 | paired_fnames = list(zip(img_fnames, mask_fnames)) 181 | 182 | while True: 183 | # Starting a new epoch.. 
184 | if shuffle: 185 | random.shuffle(paired_fnames) # Shuffles in place 186 | img_fnames, mask_fnames = zip(*paired_fnames) 187 | 188 | pairs = self.get_processed_pairs(img_fnames, mask_fnames) 189 | 190 | i = 0 191 | img_batch = np.zeros((batch_size, img_target_size[0], img_target_size[1], 3)) 192 | mask_batch = np.zeros((batch_size, mask_target_size[0] * mask_target_size[1], 1)) 193 | for img, mask in pairs: 194 | #mask[mask==255] = 0 195 | mask = mask_to_label(mask) 196 | mask = mask.astype(int) 197 | img = img.astype(np.float64) 198 | img = img.astype(float)/255.0 199 | # Fill up the batch one pair at a time 200 | img_batch[i] = img 201 | # Pass the label image as 1D array to avoid the problematic Reshape 202 | # layer after Softmax (see model.py) 203 | mask_batch[i] = np.reshape(mask, (-1, 1)) 204 | #mask_batch[i] = mask 205 | 206 | # TODO: remove this ugly workaround to skip pairs whose mask 207 | # has non-labeled pixels. 208 | #if 255. in mask: 209 | # self.skipped_count += 1 210 | # continue 211 | 212 | i += 1 213 | if i == batch_size: 214 | i = 0 215 | yield img_batch, mask_batch 216 | 217 | 218 | @click.command() 219 | @click.option('--list-fname', type=click.Path(exists=True), 220 | default='./VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt') 221 | @click.option('--img-root', type=click.Path(exists=True), 222 | default='./VOCdevkit/VOC2012/JPEGImages') 223 | @click.option('--mask-root', type=click.Path(exists=True), 224 | default='./VOCdevkit/VOC2012/SegmentationClass') 225 | def test_datagen(list_fname, img_root, mask_root): 226 | datagen = SegmentationDataGenerator() 227 | 228 | basenames = [l.strip() for l in open(list_fname).readlines()] 229 | img_fnames = [os.path.join(img_root, f) + '.jpg' for f in basenames] 230 | mask_fnames = [os.path.join(mask_root, f) + '.png' for f in basenames] 231 | 232 | datagen.flow_from_list(img_fnames, mask_fnames) 233 | 234 | 235 | if __name__ == '__main__': 236 | test_datagen() 237 | 
--------------------------------------------------------------------------------