├── .gitignore
├── AirNet
│   ├── __init__.py
│   ├── evaluation.py
│   ├── inference.py
│   ├── inputs.py
│   ├── layers.py
│   ├── requirements.txt
│   ├── test.py
│   └── training.py
├── LICENSE.txt
├── README.md
├── dataset
│   ├── areacoverimage.py
│   ├── check_for_duplicates.py
│   ├── combined_results_imgs.py
│   ├── convert_label_to_correct_classes.py
│   ├── dataset_balance.py
│   ├── geojson_to_png.py
│   ├── getChannels.py
│   ├── import_data.py
│   ├── postgres_init.txt
│   └── split_dataset.py
├── docs
│   ├── Basic.png
│   ├── Example-result-2.png
│   ├── Example-result-3.png
│   ├── Example-result-4.png
│   ├── Example-result-5.png
│   ├── Example-result.png
│   ├── Extended-dropout.png
│   ├── Extended.png
│   ├── arch.PNG
│   └── dataset-example.PNG
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | #This software
2 | Output/
3 | !Output/Readme.txt
4 |
5 | tmp/
6 | tmp
7 | dataset/config.py
8 | /dataset/config.py
9 |
10 |
11 | #pycharm
12 | .idea/
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | .hypothesis/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # dotenv
93 | .env
94 |
95 | # virtualenv
96 | .venv
97 | venv/
98 | ENV/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 |
103 | # Rope project settings
104 | .ropeproject
105 |
106 | # Windows thumbnail cache files
107 | Thumbs.db
108 | ehthumbs.db
109 | ehthumbs_vista.db
110 |
111 | # Folder config file
112 | Desktop.ini
113 |
114 | # Recycle Bin used on file shares
115 | $RECYCLE.BIN/
116 |
117 | # Windows Installer files
118 | *.cab
119 | *.msi
120 | *.msm
121 | *.msp
122 |
123 | # Windows shortcuts
124 | *.lnk
125 |
126 | *.DS_Store
127 | .AppleDouble
128 | .LSOverride
129 |
130 | # Icon must end with two \r
131 | Icon
132 |
133 |
134 | # Thumbnails
135 | ._*
136 |
137 | # Files that might appear in the root of a volume
138 | .DocumentRevisions-V100
139 | .fseventsd
140 | .Spotlight-V100
141 | .TemporaryItems
142 | .Trashes
143 | .VolumeIcon.icns
144 | .com.apple.timemachine.donotpresent
145 |
146 | # Directories potentially created on remote AFP share
147 | .AppleDB
148 | .AppleDesktop
149 | Network Trash Folder
150 | Temporary Items
151 | .apdisk
152 |
--------------------------------------------------------------------------------
/AirNet/__init__.py:
--------------------------------------------------------------------------------
1 | from .inference import inference_basic, inference_basic_dropout, inference_extended_dropout, inference_extended
2 | from .training import training
3 | from .test import test
4 | from .inputs import placeholder_inputs, dataset_inputs, get_filename_list, get_all_test_data
5 | from .evaluation import evaluation, loss_calc, per_class_acc, get_hist, print_hist_summery
6 |
7 | import tensorflow as tf
8 | FLAGS = tf.app.flags.FLAGS
9 |
10 | """ AFFECTS HOW CODE RUNS """
11 |
12 | tf.app.flags.DEFINE_string('model', 'basic_dropout',
13 |                            """ Defining what version of the model to run """)
14 |
15 | #Training
16 | tf.app.flags.DEFINE_string('log_dir', "./tmp/basic_dropout", #Training is on by default, unless testing or finetuning is set to "True"
17 |                            """ dir to store training ckpt """)
18 | tf.app.flags.DEFINE_integer('max_steps', "60000",
19 |                             """ max_steps for training """)
20 |
21 | #Testing
22 | tf.app.flags.DEFINE_boolean('testing', False, #True or False
23 |                             """ Whether to run test or not """)
24 | tf.app.flags.DEFINE_string('model_ckpt_dir', "./tmp/basic_dropout/model.ckpt-22500",
25 |                            """ checkpoint file for model to use for testing """)
26 | tf.app.flags.DEFINE_boolean('save_image', True,
27 |                             """ Whether to save predicted image """)
28 | tf.app.flags.DEFINE_string('res_output_dir', "/home/mators/autoKart/result_imgs",
29 |                            """ Directory to save result images when running test """)
30 | #Finetuning
31 | tf.app.flags.DEFINE_boolean('finetune', True, #True or False
32 |                             """ Whether to finetune or not """)
33 | tf.app.flags.DEFINE_string('finetune_dir', 'tmp/basic_dropout/model.ckpt-22500',
34 |                            """ Path to the checkpoint file to finetune from """)
35 |
36 |
37 | """ TRAINING PARAMETERS """
38 | tf.app.flags.DEFINE_integer('batch_size', "6",
39 |                             """ train batch_size """)
40 | tf.app.flags.DEFINE_integer('test_batch_size', "1",
41 |                             """ batch_size for testing """)
42 | tf.app.flags.DEFINE_integer('eval_batch_size', "6",
43 |                             """ Eval batch_size """)
44 |
45 | tf.app.flags.DEFINE_float('balance_weight_0', 0.8,
46 |                           """ Define the dataset balance weight for class 0 - Not building """)
47 | tf.app.flags.DEFINE_float('balance_weight_1', 1.1,
48 |                           """ Define the dataset balance weight for class 1 - Building """)
49 |
50 |
51 | """ DATASET SPECIFIC PARAMETERS """
52 | #Directories
53 | tf.app.flags.DEFINE_string('train_dir', "/home/mators/aerial_datasets/RGB_Trondheim_full/RGB_images/combined_dataset_v2/train_images",
54 |                            """ path to training images """)
55 | tf.app.flags.DEFINE_string('test_dir', "/home/mators/aerial_datasets/RGB_Trondheim_full/RGB_images/combined_dataset_v2/test_images",
56 |                            """ path to test images """)
57 | tf.app.flags.DEFINE_string('val_dir', "/home/mators/aerial_datasets/RGB_Trondheim_full/RGB_images/combined_dataset_v2/val_images",
58 |                            """ path to val images """)
59 |
60 | #Dataset size. #Epoch = one pass of the whole dataset.
61 | tf.app.flags.DEFINE_integer('num_examples_epoch_train', "7121",
62 |                             """ num examples per epoch for train """)
63 | tf.app.flags.DEFINE_integer('num_examples_epoch_test', "889",
64 |                             """ num examples per epoch for test """)
65 | tf.app.flags.DEFINE_integer('num_examples_epoch_val', "50",
66 |                             """ num examples per epoch for val """)
67 | tf.app.flags.DEFINE_float('fraction_of_examples_in_queue', "0.1",
68 |                           """ Fraction of examples from the dataset to put in the queue. Large datasets need a smaller value, otherwise memory gets full. """)
""") 69 | 70 | #Image size and classes 71 | tf.app.flags.DEFINE_integer('image_h', "512", 72 | """ image height """) 73 | tf.app.flags.DEFINE_integer('image_w', "512", 74 | """ image width """) 75 | tf.app.flags.DEFINE_integer('image_c', "3", 76 | """ number of image channels (RGB) (the depth) """) 77 | tf.app.flags.DEFINE_integer('num_class', "2", #classes are "Building" and "Not building" 78 | """ total class number """) 79 | 80 | 81 | #FOR TESTING: 82 | TEST_ITER = FLAGS.num_examples_epoch_test // FLAGS.batch_size 83 | 84 | 85 | tf.app.flags.DEFINE_float('moving_average_decay', "0.99",#"0.9999", #https://www.tensorflow.org/versions/r0.12/api_docs/python/train/moving_averages 86 | """ The decay to use for the moving average""") 87 | 88 | 89 | if(FLAGS.model == "basic" or FLAGS.model == "basic_dropout"): 90 | tf.app.flags.DEFINE_string('conv_init', 'xavier', # xavier / var_scale 91 | """ Initializer for the convolutional layers. One of: "xavier", "var_scale". """) 92 | tf.app.flags.DEFINE_string('optimizer', "SGD", 93 | """ Optimizer for training. One of: "adam", "SGD", "momentum", "adagrad". """) 94 | 95 | elif(FLAGS.model == "extended" or FLAGS.model == "extended_dropout"): 96 | tf.app.flags.DEFINE_string('conv_init', 'var_scale', # xavier / var_scale 97 | """ Initializer for the convolutional layers. One of "msra", "xavier", "var_scale". """) 98 | tf.app.flags.DEFINE_string('optimizer', "adagrad", 99 | """ Optimizer for training. One of: "adam", "SGD", "momentum", "adagrad". """) 100 | else: 101 | raise ValueError("Determine which initalizer you want to use. Non exist for model ", FLAGS.model) -------------------------------------------------------------------------------- /AirNet/evaluation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | FLAGS = tf.app.flags.FLAGS 4 | 5 | def loss_calc(logits, labels): 6 | """ 7 | logits: tensor, float - [batch_size, width, height, num_classes]. 8 | labels: tensor, int32 - [batch_size, width, height, num_classes]. 
9 | """ 10 | # construct one-hot label array 11 | label_flat = tf.reshape(labels, (-1, 1)) 12 | labels = tf.reshape(tf.one_hot(label_flat, depth=FLAGS.num_class), (-1, FLAGS.num_class)) 13 | 14 | #This motif is needed to hook up the batch_norm updates to the training 15 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 16 | with tf.control_dependencies(update_ops): 17 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)) 18 | tf.summary.scalar('loss', cross_entropy) 19 | return cross_entropy 20 | 21 | def weighted_loss_calc(logits, labels): 22 | class_weights = np.array([ 23 | FLAGS.balance_weight_0, #"Not building" 24 | FLAGS.balance_weight_1 #"Building" 25 | ]) 26 | cross_entropy = tf.nn.weighted_cross_entropy_with_logits(logits=logits, labels=labels, pos_weight=class_weights) 27 | loss = tf.reduce_mean(cross_entropy) 28 | tf.summary.scalar('loss', loss) 29 | return loss 30 | 31 | def evaluation(logits, labels): 32 | labels = tf.to_int64(labels) 33 | correct_prediction = tf.equal(tf.argmax(logits, 3), labels) 34 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 35 | tf.summary.scalar('accuracy', accuracy) 36 | return accuracy 37 | 38 | def per_class_acc(predictions, label_tensor): 39 | labels = label_tensor 40 | num_class = FLAGS.num_class 41 | size = predictions.shape[0] 42 | hist = np.zeros((num_class, num_class)) 43 | for i in range(size): 44 | hist += fast_hist(labels[i].flatten(), predictions[i].argmax(2).flatten(), num_class) 45 | acc_total = np.diag(hist).sum() / hist.sum() 46 | print ('accuracy = %f'%np.nanmean(acc_total)) 47 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 48 | print ('mean IU = %f'%np.nanmean(iu)) 49 | for ii in range(num_class): 50 | if float(hist.sum(1)[ii]) == 0: 51 | acc = 0.0 52 | else: 53 | acc = np.diag(hist)[ii] / float(hist.sum(1)[ii]) 54 | print(" class # %d accuracy = %f "%(ii,acc)) 55 | 56 | def fast_hist(a, b, n): 57 | k = (a >= 0) & (a < n) 58 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n) 59 | 60 | def get_hist(predictions, labels): 61 | num_class = predictions.shape[3] #becomes 2 for aerial - correct 62 | batch_size = predictions.shape[0] 63 | hist = np.zeros((num_class, num_class)) 64 | for i in range(batch_size): 65 | hist += fast_hist(labels[i].flatten(), predictions[i].argmax(2).flatten(), num_class) 66 | return hist 67 | 68 | def print_hist_summery(hist): 69 | acc_total = np.diag(hist).sum() / hist.sum() 70 | print ('accuracy = %f'%np.nanmean(acc_total)) 71 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 72 | print ('mean IU = %f'%np.nanmean(iu)) 73 | for ii in range(hist.shape[0]): 74 | if float(hist.sum(1)[ii]) == 0: 75 | acc = 0.0 76 | else: 77 | acc = np.diag(hist)[ii] / float(hist.sum(1)[ii]) 78 | print(" class # %d accuracy = %f "%(ii, acc)) -------------------------------------------------------------------------------- /AirNet/inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | FLAGS = tf.app.flags.FLAGS 3 | 4 | from .layers import unpool_with_argmax, conv_classifier, conv_layer_with_bn 5 | 6 | def inference_basic(images, is_training): 7 | """ 8 | Args: 9 | images: Images Tensors (placeholder with correct shape, img_h, img_w, img_d) 10 | is_training: If the model is training or testing 11 | """ 12 | initializer = get_weight_initializer() 13 | img_d = images.get_shape().as_list()[3] 14 | norm1 = tf.nn.lrn(images, 
depth_radius=5, bias=1.0, alpha=0.0001, beta=0.75,
15 |                       name='norm1')
16 |     conv1 = conv_layer_with_bn(initializer, norm1, [7, 7, img_d, 64], is_training, name="conv1")
17 |     pool1, pool1_indices = tf.nn.max_pool_with_argmax(conv1, ksize=[1, 2, 2, 1],
18 |                                                       strides=[1, 2, 2, 1], padding='SAME', name='pool1')
19 |
20 |     conv2 = conv_layer_with_bn(initializer, pool1, [7, 7, 64, 64], is_training, name="conv2")
21 |     pool2, pool2_indices = tf.nn.max_pool_with_argmax(conv2, ksize=[1, 2, 2, 1],
22 |                                                       strides=[1, 2, 2, 1], padding='SAME', name='pool2')
23 |
24 |     conv3 = conv_layer_with_bn(initializer, pool2, [7, 7, 64, 64], is_training, name="conv3")
25 |     pool3, pool3_indices = tf.nn.max_pool_with_argmax(conv3, ksize=[1, 2, 2, 1],
26 |                                                       strides=[1, 2, 2, 1], padding='SAME', name='pool3')
27 |
28 |     conv4 = conv_layer_with_bn(initializer, pool3, [7, 7, 64, 64], is_training, name="conv4")
29 |     pool4, pool4_indices = tf.nn.max_pool_with_argmax(conv4, ksize=[1, 2, 2, 1],
30 |                                                       strides=[1, 2, 2, 1], padding='SAME', name='pool4')
31 |
32 |     """ End of encoder - starting decoder """
33 |
34 |     unpool_4 = unpool_with_argmax(pool4, ind=pool4_indices, name='unpool_4')
35 |     conv_decode4 = conv_layer_with_bn(initializer, unpool_4, [7, 7, 64, 64], is_training, False, name="conv_decode4")
36 |
37 |     unpool_3 = unpool_with_argmax(conv_decode4, ind=pool3_indices, name='unpool_3')
38 |     conv_decode3 = conv_layer_with_bn(initializer, unpool_3, [7, 7, 64, 64], is_training, False, name="conv_decode3")
39 |
40 |     unpool_2 = unpool_with_argmax(conv_decode3, ind=pool2_indices, name='unpool_2')
41 |     conv_decode2 = conv_layer_with_bn(initializer, unpool_2, [7, 7, 64, 64], is_training, False, name="conv_decode2")
42 |
43 |     unpool_1 = unpool_with_argmax(conv_decode2, ind=pool1_indices, name='unpool_1')
44 |     conv_decode1 = conv_layer_with_bn(initializer, unpool_1, [7, 7, 64, 64], is_training, False, name="conv_decode1")
45 |
46 |     return conv_classifier(conv_decode1, initializer)
47 |
48 | def inference_basic_dropout(images, is_training, keep_prob):
49 |     """
50 |     Args:
51 |       images: image tensor (placeholder with correct shape, img_h, img_w, img_d)
52 |       is_training: whether the model is training or testing
53 |       keep_prob: probability of keeping a unit; the dropout layers use rate=(1 - keep_prob)
54 |     """
55 |     initializer = get_weight_initializer()
56 |     img_d = images.get_shape().as_list()[3]
57 |     norm1 = tf.nn.lrn(images, depth_radius=5, bias=1.0, alpha=0.0001, beta=0.75, name='norm1')
58 |
59 |     conv1 = conv_layer_with_bn(initializer, norm1, [7, 7, img_d, 64], is_training, name="conv1")
60 |     pool1, pool1_indices = tf.nn.max_pool_with_argmax(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')
61 |     dropout1 = tf.layers.dropout(pool1, rate=(1-keep_prob), training=is_training, name="dropout1")
62 |
63 |     conv2 = conv_layer_with_bn(initializer, dropout1, [7, 7, 64, 64], is_training, name="conv2")
64 |     pool2, pool2_indices = tf.nn.max_pool_with_argmax(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool2')
65 |     dropout2 = tf.layers.dropout(pool2, rate=(1-keep_prob), training=is_training, name="dropout2")
66 |
67 |     conv3 = conv_layer_with_bn(initializer, dropout2, [7, 7, 64, 64], is_training, name="conv3")
68 |     pool3, pool3_indices = tf.nn.max_pool_with_argmax(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool3')
69 |     dropout3 = tf.layers.dropout(pool3, rate=(1-keep_prob), training=is_training, name="dropout3")
70 |
71 |     conv4 = conv_layer_with_bn(initializer, dropout3, [7, 7, 64, 64],
is_training, name="conv4") 72 | pool4, pool4_indices = tf.nn.max_pool_with_argmax(conv4, ksize=[1, 2, 2, 1], 73 | strides=[1, 2, 2, 1], padding='SAME', name='pool4') 74 | 75 | """ End of encoder - starting decoder """ 76 | 77 | unpool_4 = unpool_with_argmax(pool4, ind=pool4_indices, name='unpool_4') 78 | conv_decode4 = conv_layer_with_bn(initializer, unpool_4, [7, 7, 64, 64], is_training, False, name="conv_decode4") 79 | 80 | decode_dropout3 = tf.layers.dropout(conv_decode4, rate=(1-keep_prob), training=is_training, name="decoder_dropout3") 81 | unpool_3 = unpool_with_argmax(decode_dropout3, ind=pool3_indices, name='unpool_3') 82 | conv_decode3 = conv_layer_with_bn(initializer, unpool_3, [7, 7, 64, 64], is_training, False, name="conv_decode3") 83 | 84 | decode_dropout2 = tf.layers.dropout(conv_decode3, rate=(1-keep_prob), training=is_training, name="decoder_dropout2") 85 | unpool_2 = unpool_with_argmax(decode_dropout2, ind=pool2_indices, name='unpool_2') 86 | conv_decode2 = conv_layer_with_bn(initializer, unpool_2, [7, 7, 64, 64], is_training, False, name="conv_decode2") 87 | 88 | decode_dropout1 = tf.layers.dropout(conv_decode2, rate=(1-keep_prob), training=is_training, name="decoder_dropout1") 89 | unpool_1 = unpool_with_argmax(decode_dropout1, ind=pool1_indices, name='unpool_1') 90 | conv_decode1 = conv_layer_with_bn(initializer, unpool_1, [7, 7, 64, 64], is_training, False, name="conv_decode1") 91 | 92 | return conv_classifier(conv_decode1, initializer) 93 | 94 | def inference_extended(images, is_training): 95 | initializer = get_weight_initializer() 96 | img_d = images.get_shape().as_list()[3] 97 | conv1_1 = conv_layer_with_bn(initializer, images, [7, 7, img_d, 64], is_training, name="conv1_1") 98 | conv1_2 = conv_layer_with_bn(initializer, conv1_1, [7, 7, 64, 64], is_training, name="conv1_2") 99 | pool1, pool1_indices = tf.nn.max_pool_with_argmax(conv1_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1') 100 | 101 | conv2_1 = conv_layer_with_bn(initializer, pool1, [7, 7, 64, 64], is_training, name="conv2_1") 102 | conv2_2 = conv_layer_with_bn(initializer, conv2_1, [7, 7, 64, 64], is_training, name="conv2_2") 103 | pool2, pool2_indices = tf.nn.max_pool_with_argmax(conv2_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool2') 104 | 105 | conv3_1 = conv_layer_with_bn(initializer, pool2, [7, 7, 64, 64], is_training, name="conv3_1") 106 | conv3_2 = conv_layer_with_bn(initializer, conv3_1, [7, 7, 64, 64], is_training, name="conv3_2") 107 | conv3_3 = conv_layer_with_bn(initializer, conv3_2, [7, 7, 64, 64], is_training, name="conv3_3") 108 | pool3, pool3_indices = tf.nn.max_pool_with_argmax(conv3_3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool3') 109 | 110 | conv4_1 = conv_layer_with_bn(initializer, pool3, [7, 7, 64, 64], is_training, name="conv4_1") 111 | conv4_2 = conv_layer_with_bn(initializer, conv4_1, [7, 7, 64, 64], is_training, name="conv4_2") 112 | conv4_3 = conv_layer_with_bn(initializer, conv4_2, [7, 7, 64, 64], is_training, name="conv4_3") 113 | pool4, pool4_indices = tf.nn.max_pool_with_argmax(conv4_3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool4') 114 | 115 | conv5_1 = conv_layer_with_bn(initializer, pool4, [7, 7, 64, 64], is_training, name="conv5_1") 116 | conv5_2 = conv_layer_with_bn(initializer, conv5_1, [7, 7, 64, 64], is_training, name="conv5_2") 117 | conv5_3 = conv_layer_with_bn(initializer, conv5_2, [7, 7, 64, 64], is_training, name="conv5_3") 118 | pool5, pool5_indices = 
tf.nn.max_pool_with_argmax(conv5_3, ksize=[1, 2, 2, 1],
119 |                                strides=[1, 2, 2, 1], padding='SAME', name='pool5')
120 |     """ End of encoder """
121 |
122 |     """ Start decoder: each stage unpools the previous stage's output, using the matching encoder pooling indices """
123 |     unpool_5 = unpool_with_argmax(pool5, ind=pool5_indices, name="unpool_5")
124 |     conv_decode5_1 = conv_layer_with_bn(initializer, unpool_5, [7, 7, 64, 64], is_training, False, name="conv_decode5_1")
125 |     conv_decode5_2 = conv_layer_with_bn(initializer, conv_decode5_1, [7, 7, 64, 64], is_training, False, name="conv_decode5_2")
126 |     conv_decode5_3 = conv_layer_with_bn(initializer, conv_decode5_2, [7, 7, 64, 64], is_training, False, name="conv_decode5_3")
127 |
128 |     unpool_4 = unpool_with_argmax(conv_decode5_3, ind=pool4_indices, name="unpool_4")
129 |     conv_decode4_1 = conv_layer_with_bn(initializer, unpool_4, [7, 7, 64, 64], is_training, False, name="conv_decode4_1")
130 |     conv_decode4_2 = conv_layer_with_bn(initializer, conv_decode4_1, [7, 7, 64, 64], is_training, False, name="conv_decode4_2")
131 |     conv_decode4_3 = conv_layer_with_bn(initializer, conv_decode4_2, [7, 7, 64, 64], is_training, False, name="conv_decode4_3")
132 |
133 |     unpool_3 = unpool_with_argmax(conv_decode4_3, ind=pool3_indices, name="unpool_3")
134 |     conv_decode3_1 = conv_layer_with_bn(initializer, unpool_3, [7, 7, 64, 64], is_training, False, name="conv_decode3_1")
135 |     conv_decode3_2 = conv_layer_with_bn(initializer, conv_decode3_1, [7, 7, 64, 64], is_training, False, name="conv_decode3_2")
136 |     conv_decode3_3 = conv_layer_with_bn(initializer, conv_decode3_2, [7, 7, 64, 64], is_training, False, name="conv_decode3_3")
137 |
138 |     unpool_2 = unpool_with_argmax(conv_decode3_3, ind=pool2_indices, name="unpool_2")
139 |     conv_decode2_1 = conv_layer_with_bn(initializer, unpool_2, [7, 7, 64, 64], is_training, False, name="conv_decode2_1")
140 |     conv_decode2_2 = conv_layer_with_bn(initializer, conv_decode2_1, [7, 7, 64, 64], is_training, False, name="conv_decode2_2")
141 |
142 |     unpool_1 = unpool_with_argmax(conv_decode2_2, ind=pool1_indices, name="unpool_1")
143 |     conv_decode1_1 = conv_layer_with_bn(initializer, unpool_1, [7, 7, 64, 64], is_training, False, name="conv_decode1_1")
144 |     conv_decode1_2 = conv_layer_with_bn(initializer, conv_decode1_1, [7, 7, 64, 64], is_training, False, name="conv_decode1_2")
145 |     """ End of decoder """
146 |
147 |     return conv_classifier(conv_decode1_2, initializer)
148 |
149 | def inference_extended_dropout(images, is_training, keep_prob):
150 |     """
151 |     Args:
152 |       images: image tensor (placeholder with correct shape, img_h, img_w, img_d)
153 |       is_training: whether the model is training or testing
154 |       keep_prob: probability of keeping a unit; the dropout layers use rate=(1 - keep_prob)
155 |     """
156 |
157 |     initializer = get_weight_initializer()
158 |     conv1_1 = conv_layer_with_bn(initializer, images, [7, 7, images.get_shape().as_list()[3], 64], is_training, name="conv1_1")
159 |     conv1_2 = conv_layer_with_bn(initializer, conv1_1, [7, 7, 64, 64], is_training, name="conv1_2")
160 |     pool1, pool1_indices = tf.nn.max_pool_with_argmax(conv1_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')
161 |     dropout1 = tf.layers.dropout(pool1, rate=(1-keep_prob), training=is_training, name="dropout1")
162 |
163 |     conv2_1 = conv_layer_with_bn(initializer, dropout1, [7, 7, 64, 64], is_training, name="conv2_1")
164 |     conv2_2 = conv_layer_with_bn(initializer, conv2_1, [7, 7, 64, 64], is_training, name="conv2_2")
165 |     pool2, pool2_indices = tf.nn.max_pool_with_argmax(conv2_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool2')
166 |
dropout2 = tf.layers.dropout(pool2, rate=(1-keep_prob), training=is_training, name="dropout2") 167 | 168 | conv3_1 = conv_layer_with_bn(initializer, dropout2, [7, 7, 64, 64], is_training, name="conv3_1") 169 | conv3_2 = conv_layer_with_bn(initializer, conv3_1, [7, 7, 64, 64], is_training, name="conv3_2") 170 | conv3_3 = conv_layer_with_bn(initializer, conv3_2, [7, 7, 64, 64], is_training, name="conv3_3") 171 | pool3, pool3_indices = tf.nn.max_pool_with_argmax(conv3_3, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME', name='pool3') 172 | dropout3 = tf.layers.dropout(pool3, rate=(1-keep_prob), training=is_training, name="dropout3") 173 | 174 | conv4_1 = conv_layer_with_bn(initializer, dropout3, [7, 7, 64, 64], is_training, name="conv4_1") 175 | conv4_2 = conv_layer_with_bn(initializer, conv4_1, [7, 7, 64, 64], is_training, name="conv4_2") 176 | conv4_3 = conv_layer_with_bn(initializer, conv4_2, [7, 7, 64, 64], is_training, name="conv4_3") 177 | pool4, pool4_indices = tf.nn.max_pool_with_argmax(conv4_3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool4') 178 | dropout4 = tf.layers.dropout(pool4, rate=(1-keep_prob), training=is_training, name="dropout4") 179 | 180 | conv5_1 = conv_layer_with_bn(initializer, dropout4, [7, 7, 64, 64], is_training, name="conv5_1") 181 | conv5_2 = conv_layer_with_bn(initializer, conv5_1, [7, 7, 64, 64], is_training, name="conv5_2") 182 | conv5_3 = conv_layer_with_bn(initializer, conv5_2, [7, 7, 64, 64], is_training, name="conv5_3") 183 | pool5, pool5_indices = tf.nn.max_pool_with_argmax(conv5_3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool5') 184 | dropout5 = tf.layers.dropout(pool5, rate=(1-keep_prob), training=is_training, name="dropout5") 185 | """ End of encoder """ 186 | 187 | """ Start decoder """ 188 | unpool_5 = unpool_with_argmax(dropout5, ind=pool5_indices, name='unpool_5') 189 | conv_decode5_1 = conv_layer_with_bn(initializer, unpool_5, [7, 7, 64, 64], is_training, False, name="conv_decode5_1") 190 | conv_decode5_2 = conv_layer_with_bn(initializer, conv_decode5_1, [7, 7, 64, 64], is_training, False, name="conv_decode5_2") 191 | conv_decode5_3 = conv_layer_with_bn(initializer, conv_decode5_2, [7, 7, 64, 64], is_training, False, name="conv_decode5_3") 192 | 193 | dropout4_decode = tf.layers.dropout(conv_decode5_3, rate=(1-keep_prob), training=is_training, name="dropout4_decode") 194 | unpool_4 = unpool_with_argmax(dropout4_decode, ind=pool4_indices, name='unpool_4') 195 | conv_decode4_1 = conv_layer_with_bn(initializer, unpool_4, [7, 7, 64, 64], is_training, False, name="conv_decode4_1") 196 | conv_decode4_2 = conv_layer_with_bn(initializer, conv_decode4_1, [7, 7, 64, 64], is_training, False, name="conv_decode4_2") 197 | conv_decode4_3 = conv_layer_with_bn(initializer, conv_decode4_2, [7, 7, 64, 64], is_training, False, name="conv_decode4_3") 198 | 199 | dropout3_decode = tf.layers.dropout(conv_decode4_3, rate=(1-keep_prob), training=is_training, name="dropout3_decode") 200 | unpool_3 = unpool_with_argmax(dropout3_decode, ind=pool3_indices, name='unpool_3') 201 | conv_decode3_1 = conv_layer_with_bn(initializer, unpool_3, [7, 7, 64, 64], is_training, False, name="conv_decode3_1") 202 | conv_decode3_2 = conv_layer_with_bn(initializer, conv_decode3_1, [7, 7, 64, 64], is_training, False, name="conv_decode3_2") 203 | conv_decode3_3 = conv_layer_with_bn(initializer, conv_decode3_2, [7, 7, 64, 64], is_training, False, name="conv_decode3_3") 204 | 205 | dropout2_decode = tf.layers.dropout(conv_decode3_3, 
rate=(1-keep_prob), training=is_training, name="dropout2_decode") 206 | unpool_2 = unpool_with_argmax(dropout2_decode, ind=pool2_indices, name='unpool_2') 207 | conv_decode2_1 = conv_layer_with_bn(initializer, unpool_2, [7, 7, 64, 64], is_training, False, name="conv_decode2_1") 208 | conv_decode2_2 = conv_layer_with_bn(initializer, conv_decode2_1, [7, 7, 64, 64], is_training, False, name="conv_decode2_2") 209 | 210 | dropout1_decode = tf.layers.dropout(conv_decode2_2, rate=(1-keep_prob), training=is_training, name="dropout1_deconv") 211 | unpool_1 = unpool_with_argmax(dropout1_decode, ind=pool1_indices, name='unpool_1') 212 | conv_decode1_1 = conv_layer_with_bn(initializer, unpool_1, [7, 7, 64, 64], is_training, False, name="conv_decode1_1") 213 | conv_decode1_2 = conv_layer_with_bn(initializer, conv_decode1_1, [7, 7, 64, 64], is_training, False, name="conv_decode1_2") 214 | """ End of decoder """ 215 | 216 | return conv_classifier(conv_decode1_2, initializer) 217 | 218 | 219 | def get_weight_initializer(): 220 | if(FLAGS.conv_init == "var_scale"): 221 | initializer = tf.contrib.layers.variance_scaling_initializer() 222 | elif(FLAGS.conv_init == "xavier"): 223 | initializer=tf.contrib.layers.xavier_initializer() 224 | else: 225 | raise ValueError("Chosen weight initializer does not exist") 226 | return initializer 227 | -------------------------------------------------------------------------------- /AirNet/inputs.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import ops 3 | from tensorflow.python.framework import dtypes 4 | import os 5 | import numpy as np 6 | import threading 7 | 8 | import skimage 9 | import skimage.io 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | def get_filename_list(path): 14 | 15 | image_filenames = sorted(os.listdir(path+'/images')) #sort by names to get img and label after each other 16 | label_filenames = sorted(os.listdir(path+'/labels')) #sort by names to get img and label after each other 17 | 18 | #Adding correct path to the each filename in the lists 19 | step=0 20 | for name in image_filenames: 21 | image_filenames[step] = path+"/images/"+name 22 | step=step+1 23 | step=0 24 | for name in label_filenames: 25 | label_filenames[step] = path+"/labels/"+name 26 | step=step+1 27 | 28 | return image_filenames, label_filenames 29 | 30 | 31 | def dataset_reader(filename_queue): #prev name: CamVid_reader 32 | 33 | image_filename = filename_queue[0] #tensor of type string 34 | label_filename = filename_queue[1] #tensor of type string 35 | 36 | #get png encoded image 37 | imageValue = tf.read_file(image_filename) 38 | labelValue = tf.read_file(label_filename) 39 | 40 | #decodes a png image into a uint8 or uint16 tensor 41 | #returns a tensor of type dtype with shape [height, width, depth] 42 | image_bytes = tf.image.decode_png(imageValue) 43 | label_bytes = tf.image.decode_png(labelValue) #Labels are png, not jpeg 44 | 45 | image = tf.reshape(image_bytes, (FLAGS.image_h, FLAGS.image_w, FLAGS.image_c)) 46 | label = tf.reshape(label_bytes, (FLAGS.image_h, FLAGS.image_w, 1)) 47 | 48 | return image, label 49 | 50 | def dataset_inputs(image_filenames, label_filenames, batch_size, running_train_set=True): 51 | images = ops.convert_to_tensor(image_filenames, dtype=dtypes.string) 52 | labels = ops.convert_to_tensor(label_filenames, dtype=dtypes.string) 53 | 54 | 55 | filename_queue = tf.train.slice_input_producer([images, labels], shuffle=True) 56 | 57 | image, label = 
dataset_reader(filename_queue)
58 |     reshaped_image = tf.cast(image, tf.float32)
59 |     min_fraction_of_examples_in_queue = FLAGS.fraction_of_examples_in_queue
60 |     min_queue_examples = int(FLAGS.num_examples_epoch_train *
61 |                              min_fraction_of_examples_in_queue)
62 |
63 |     print ('Filling queue with %d input images before starting to train. '
64 |            'This may take some time.' % min_queue_examples)
65 |
66 |     # Generate a batch of images and labels by building up a queue of examples.
67 |     return _generate_image_and_label_batch(reshaped_image, label,
68 |                                            min_queue_examples, batch_size,
69 |                                            shuffle=True)
70 |
71 | def _generate_image_and_label_batch(image, label, min_queue_examples,
72 |                                     batch_size, shuffle):
73 |     """Construct a queued batch of images and labels.
74 |     Args:
75 |       image: 3-D Tensor of [height, width, 3] of type float32.
76 |       label: 3-D Tensor of [height, width, 1] of type int32.
77 |       min_queue_examples: int32, minimum number of samples to retain
78 |         in the queue that provides batches of examples.
79 |       batch_size: Number of images per batch.
80 |       shuffle: boolean indicating whether to use a shuffling queue.
81 |     Returns:
82 |       images: Images. 4D tensor of [batch_size, height, width, 3] size.
83 |       labels: Labels. 4D tensor of [batch_size, height, width, 1] size.
84 |     """
85 |     # Create a queue that shuffles the examples, and then
86 |     # read 'batch_size' images + labels from the example queue.
87 |
88 |     #TODO: test if setting threads to a higher number helps
89 |     num_preprocess_threads = 1
90 |     if shuffle:
91 |         images, label_batch = tf.train.shuffle_batch(
92 |             [image, label],
93 |             batch_size=batch_size,
94 |             num_threads=num_preprocess_threads,
95 |             capacity=min_queue_examples + 3 * batch_size,
96 |             min_after_dequeue=min_queue_examples)
97 |     else:
98 |         images, label_batch = tf.train.batch(
99 |             [image, label],
100 |             batch_size=batch_size,
101 |             num_threads=num_preprocess_threads,
102 |             capacity=min_queue_examples + 3 * batch_size)
103 |
104 |     # Display the training images in the visualizer.
105 |     tf.summary.image('training_images', images)
106 |     print('generating image and label batch:')
107 |     return images, label_batch
108 |
109 | def get_all_test_data(im_list, la_list):
110 |     images = []
111 |     labels = []
112 |     index = 0
113 |     for im_filename, la_filename in zip(im_list, la_list):
114 |         im = np.array(skimage.io.imread(im_filename), np.float32)
115 |         im = im[np.newaxis]
116 |         la = skimage.io.imread(la_filename)
117 |         la = la[np.newaxis]
118 |         la = la[...,np.newaxis]
119 |         images.append(im)
120 |         labels.append(la)
121 |     return images, labels
122 |
123 |
124 | def placeholder_inputs(batch_size):
125 |     images = tf.placeholder(tf.float32, shape=[batch_size, FLAGS.image_h, FLAGS.image_w, FLAGS.image_c]) #use the configured channel count rather than hard-coding 3
126 |     labels = tf.placeholder(tf.int64, [batch_size, FLAGS.image_h, FLAGS.image_w, 1])
127 |     is_training = tf.placeholder(tf.bool, name='is_training')
128 |     keep_prob = tf.placeholder(tf.float32, name="keep_probability")
129 |
130 |     return images, labels, is_training, keep_prob
--------------------------------------------------------------------------------
/AirNet/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | FLAGS = tf.app.flags.FLAGS
4 |
5 |
6 | def unpool_with_argmax(pool, ind, name = None, ksize=[1, 2, 2, 1]):
7 |
8 |     """
9 |     Unpooling layer after max_pool_with_argmax.
10 | Args: 11 | pool: max pooled output tensor 12 | ind: argmax indices 13 | ksize: ksize is the same as for the pool 14 | Return: 15 | unpool: unpooling tensor 16 | """ 17 | with tf.variable_scope(name): 18 | input_shape = pool.get_shape().as_list() 19 | output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]) 20 | 21 | flat_input_size = np.prod(input_shape) 22 | flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]] 23 | 24 | pool_ = tf.reshape(pool, [flat_input_size]) 25 | batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1]) 26 | b = tf.ones_like(ind) * batch_range 27 | b = tf.reshape(b, [flat_input_size, 1]) 28 | ind_ = tf.reshape(ind, [flat_input_size, 1]) 29 | ind_ = tf.concat([b, ind_], 1) 30 | 31 | ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape) 32 | ret = tf.reshape(ret, output_shape) 33 | return ret 34 | 35 | 36 | def conv_classifier(input_layer, initializer): 37 | # output predicted class number (2) 38 | with tf.variable_scope('conv_classifier') as scope: #all variables prefixed with "conv_classifier/" 39 | shape=[1, 1, 64, FLAGS.num_class] 40 | kernel = _variable_with_weight_decay('weights', shape=shape, initializer=initializer, wd=None) 41 | #kernel = tf.get_variable('weights', shape, initializer=initializer) 42 | conv = tf.nn.conv2d(input_layer, filter=kernel, strides=[1, 1, 1, 1], padding='SAME') 43 | biases = _variable_on_cpu('biases', [FLAGS.num_class], tf.constant_initializer(0.0)) 44 | conv_classifier = tf.nn.bias_add(conv, biases, name=scope.name) 45 | return conv_classifier 46 | 47 | 48 | def conv_layer_with_bn(initializer, inputT, shape, is_training, activation=True, name=None): 49 | in_channel = shape[2] 50 | out_channel = shape[3] 51 | k_size = shape[0] 52 | 53 | with tf.variable_scope(name) as scope: 54 | kernel = _variable_with_weight_decay('weights', shape=shape, initializer=initializer, wd=None) 55 | #kernel = tf.get_variable(scope.name, shape, initializer=initializer) 56 | conv = tf.nn.conv2d(inputT, kernel, [1, 1, 1, 1], padding='SAME') 57 | biases = tf.Variable(tf.constant(0.0, shape=[out_channel], dtype=tf.float32), 58 | trainable=True, name='biases') 59 | bias = tf.nn.bias_add(conv, biases) 60 | 61 | if activation is True: #only use relu during encoder 62 | conv_out = tf.nn.relu(batch_norm_layer(bias, is_training, scope.name)) 63 | else: 64 | conv_out = batch_norm_layer(bias, is_training, scope.name) 65 | return conv_out 66 | 67 | def batch_norm_layer(inputT, is_training, scope): 68 | return tf.cond(is_training, 69 | lambda: tf.contrib.layers.batch_norm(inputT, is_training=True, 70 | center=False, decay=FLAGS.moving_average_decay, scope=scope), 71 | lambda: tf.contrib.layers.batch_norm(inputT, is_training=False, 72 | center=False, reuse = True, decay=FLAGS.moving_average_decay, scope=scope)) 73 | 74 | 75 | 76 | def _variable_with_weight_decay(name, shape, initializer, wd): 77 | """ Helper to create an initialized Variable with weight decay. 78 | Note that the Variable is initialized with a truncated normal distribution. 79 | A weight decay is added only if one is specified. 80 | Args: 81 | name: name of the variable 82 | shape: list of ints 83 | stddev: standard deviation of a truncated Gaussian 84 | wd: add L2Loss weight decay multiplied by this float. If None, weight 85 | decay is not added for this Variable. 
86 | Returns: 87 | Variable Tensor 88 | """ 89 | var = _variable_on_cpu(name, shape, initializer) 90 | 91 | if wd is not None: 92 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') 93 | tf.add_to_collection('losses', weight_decay) 94 | return var 95 | 96 | 97 | def _variable_on_cpu(name, shape, initializer): 98 | """Helper to create a Variable stored on CPU memory. 99 | Args: 100 | name: name of the variable 101 | shape: list of ints 102 | initializer: initializer for Variable 103 | Returns: 104 | Variable Tensor 105 | """ 106 | with tf.device('/cpu:0'): 107 | #dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 #added this after, cause it was in cifar model 108 | var = tf.get_variable(name, shape, initializer=initializer)#, dtype=dtype) 109 | return var -------------------------------------------------------------------------------- /AirNet/requirements.txt: -------------------------------------------------------------------------------- 1 | GDAL==2.1.3 2 | numpy==1.13.2 3 | Pillow==4.2.1 4 | virtualenv==15.0.1 5 | image==1.5.15 6 | matplotlib==2.0.2 7 | psycopg2==2.7.3.1 8 | scikit-image==0.13.1 9 | tensorflow-gpu==1.3.0 10 | tensorflow-tensorboard==0.1.7 11 | 12 | -------------------------------------------------------------------------------- /AirNet/test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import AirNet 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | FLAGS = tf.app.flags.FLAGS 7 | 8 | 9 | def test(): 10 | print("----------- In test method ----------") 11 | image_filenames, label_filenames = AirNet.get_filename_list(FLAGS.test_dir) 12 | 13 | test_data_node, test_labels_node, is_training, keep_prob = AirNet.placeholder_inputs(batch_size=1) 14 | 15 | 16 | if FLAGS.model == "basic": 17 | logits = AirNet.inference_basic(test_data_node, is_training) 18 | elif FLAGS.model == "extended": 19 | logits = AirNet.inference_extended(test_data_node, is_training) 20 | elif FLAGS.model == "basic_dropout": 21 | logits = AirNet.inference_basic_dropout(test_data_node, is_training, keep_prob) 22 | elif FLAGS.model == "extended_dropout": 23 | logits = AirNet.inference_extended_dropout(test_data_node, is_training, keep_prob) 24 | else: 25 | raise ValueError("The selected model does not exist") 26 | 27 | pred = tf.argmax(logits, axis=3) 28 | saver = tf.train.Saver() 29 | 30 | with tf.Session() as sess: 31 | saver.restore(sess, FLAGS.model_ckpt_dir) 32 | images, labels = AirNet.get_all_test_data(image_filenames, label_filenames) 33 | 34 | # Start the queue runners. 
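# (Note: this test graph is fed through feed_dict below rather than through the
# queue-based input pipeline, so the coordinator mainly makes sure any queue
# threads are shut down cleanly at the end of the run.)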
35 |         coord = tf.train.Coordinator()
36 |         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
37 |         hist = np.zeros((FLAGS.num_class, FLAGS.num_class))
38 |
39 |         step = 0
40 |         for image_batch, label_batch in zip(images, labels):
41 |             feed_dict = {
42 |                 test_data_node: image_batch,
43 |                 test_labels_node: label_batch,
44 |                 is_training: False,
45 |                 keep_prob: 1.0 #Dropout is disabled during testing: a keep probability of 1.0 keeps every unit
46 |             }
47 |
48 |             dense_prediction, im = sess.run(fetches=[logits, pred], feed_dict=feed_dict)
49 |             AirNet.per_class_acc(dense_prediction, label_batch)
50 |             # output_image to verify
51 |             if (FLAGS.save_image):
52 |                 if(step < 10):
53 |                     numb_img = "000"+str(step)
54 |                 elif(step < 100):
55 |                     numb_img = "00"+str(step)
56 |                 elif(step < 1000):
57 |                     numb_img = "0"+str(step)
58 |                 write_image(im[0], os.path.join(FLAGS.res_output_dir, 'testing_image'+numb_img+'.png')) #Saving all test images
59 |             step = step+1
60 |             hist += AirNet.get_hist(dense_prediction, label_batch)
61 |         acc_total = np.diag(hist).sum() / hist.sum()
62 |         iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
63 |         print("acc: ", acc_total)
64 |         print("IU: ", iu)
65 |         print("mean IU: ", np.nanmean(iu))
66 |
67 |         coord.request_stop()
68 |         coord.join(threads)
69 |
70 |
71 | def write_image(image, filename):
72 |     """ store label data to colored image """
73 |     Not_building = [0,0,0] #black
74 |     Building = [128,128,0] #green-ish
75 |
76 |     r = image.copy()
77 |     g = image.copy()
78 |     b = image.copy()
79 |
80 |     label_colours = np.array([Not_building, Building])
81 |     for label in range(0, FLAGS.num_class): #for every class (the range is set by num_class)
82 |         #Replace all pixels carrying this label value with the label colour
83 |         r[image==label] = label_colours[label,0] #red is channel/depth 0
84 |         g[image==label] = label_colours[label,1] #green is channel/depth 1
85 |         b[image==label] = label_colours[label,2] #blue is channel/depth 2
86 |     rgb = np.zeros((image.shape[0], image.shape[1], 3))
87 |     rgb[:,:,0] = r
88 |     rgb[:,:,1] = g
89 |     rgb[:,:,2] = b
90 |     im = Image.fromarray(np.uint8(rgb))
91 |     im.save(filename)
--------------------------------------------------------------------------------
/AirNet/training.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.slim as slim
3 |
4 | FLAGS = tf.app.flags.FLAGS
5 |
6 | def training(loss):
7 |
8 |     global_step = tf.Variable(0, name='global_step', trainable=False)
9 |
10 |     #This motif is needed to hook up the batch_norm updates to the training
11 |     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
12 |     with tf.control_dependencies(update_ops):
13 |         if(FLAGS.optimizer == "SGD"):
14 |             print("Running with SGD optimizer")
15 |             optimizer = tf.train.GradientDescentOptimizer(0.1)
16 |         elif(FLAGS.optimizer == "adam"):
17 |             print("Running with adam optimizer")
18 |             optimizer = tf.train.AdamOptimizer(0.001)
19 |         elif(FLAGS.optimizer == "adagrad"):
20 |             print("Running with adagrad optimizer")
21 |             optimizer = tf.train.AdagradOptimizer(0.01)
22 |         else:
23 |             raise ValueError("optimizer was not recognized.")
24 |
25 |         train_op = optimizer.minimize(loss=loss, global_step=global_step)
26 |     #alternative: one call that handles optimizer selection ('SGD', 'Adam', 'Adagrad'):
27 |     #train_op = tf.contrib.layers.optimize_loss(loss, optimizer="SGD", global_step=global_step, learning_rate = 0.1)
28 |     return train_op, global_step
--------------------------------------------------------------------------------
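Taken together, the modules above compose into a standard TF1 queue-based training loop. The following is only a minimal sketch, not the repository's actual train.py: it assumes the flag defaults from AirNet/__init__.py, picks the basic_dropout model with a keep probability of 0.5, and all local variable names are illustrative.

    import tensorflow as tf
    import AirNet

    FLAGS = tf.app.flags.FLAGS

    # Build the graph: placeholders -> model -> loss -> train op.
    images, labels, is_training, keep_prob = AirNet.placeholder_inputs(FLAGS.batch_size)
    logits = AirNet.inference_basic_dropout(images, is_training, keep_prob)
    loss = AirNet.loss_calc(logits=logits, labels=labels)
    train_op, global_step = AirNet.training(loss=loss)

    # Queue-based input pipeline (see AirNet/inputs.py).
    image_files, label_files = AirNet.get_filename_list(FLAGS.train_dir)
    image_batch_op, label_batch_op = AirNet.dataset_inputs(image_files, label_files, FLAGS.batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for step in range(FLAGS.max_steps):
            # Pull a numpy batch out of the queue, then feed it to the placeholders.
            img_batch, lbl_batch = sess.run([image_batch_op, label_batch_op])
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict={images: img_batch, labels: lbl_batch,
                                                is_training: True, keep_prob: 0.5})
        coord.request_stop()
        coord.join(threads)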
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Matthew Shun-Shin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SegNet --> AirNet
2 | AirNet is a segmentation network based on [SegNet](https://mi.eng.cam.ac.uk/projects/segnet/), but with some modifications. The goal is to use the model to segment multispectral images, so that geographical information (e.g. building information) can be extracted. The model is implemented in [Tensorflow](https://www.tensorflow.org/).
3 |
4 | ## Recognition
5 | - SegNet implementation by tkuanlun350: https://github.com/tkuanlun350/Tensorflow-SegNet.
6 | - Unraveling method from mshunshin: https://github.com/mshunshin/SegNetCMR/blob/master/SegNetCMR/layers.py
7 |
8 |
9 | ## Architecture
10 | *NB! If you are unfamiliar with how convolutional neural networks work, I have written a [blogpost](https://geoit.geoforum.no/2017/12/20/maskinlaering-flyfoto/) that explains the basic concepts.*
11 |
12 | I've implemented four different versions of the model:
13 | - [AirNet-basic](#AirNet-basic)
14 | - [AirNet-basic-dropout](#AirNet-Basic-dropout)
15 | - [AirNet-extended](#AirNet-Extended)
16 | - [AirNet-extended-dropout](#AirNet-Extended-dropout)
17 |
18 | All the AirNet models have:
19 | - an encoder-decoder structure, with a different number of layers in the different versions
20 | - batch normalization with a moving average decay of 0.99
21 | - upsampling layers that use the pooling indices from max pooling in the encoder: the positions of the pixels that are kept during max pooling are saved, and used in the decoder to place the pixels back in their original positions.
22 |
23 | ![Airnet extended architecture](docs/arch.PNG)
24 |
25 |
26 | ### Optimizers and initializers
27 | Different optimizers and initializers have been tested on each of the models. The ones that were tested are listed below, and the ones chosen based on performance are given in the model descriptions. A sketch of how these choices map onto TensorFlow calls follows.
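The mapping below is a minimal sketch based on `AirNet/training.py` and `AirNet/inference.py`; the learning rates are the values hard-coded there. (A Momentum optimizer is listed as tested, but the current `training.py` only has branches for SGD, Adam and Adagrad.)

    import tensorflow as tf

    # 'optimizer' flag -> optimizer; learning rates as hard-coded in training.py
    optimizers = {
        "SGD": lambda: tf.train.GradientDescentOptimizer(0.1),
        "adam": lambda: tf.train.AdamOptimizer(0.001),
        "adagrad": lambda: tf.train.AdagradOptimizer(0.01),
    }

    # 'conv_init' flag -> weight initializer, as in inference.get_weight_initializer()
    initializers = {
        "xavier": tf.contrib.layers.xavier_initializer,
        "var_scale": tf.contrib.layers.variance_scaling_initializer,
    }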
28 |
29 | #### Optimizers:
30 | - Stochastic gradient descent
31 | - Adam
32 | - Adagrad
33 | - Momentum
34 | #### Weight initializers:
35 | - Variance scale
36 | - Xavier
37 |
38 | ### AirNet Basic
39 | - Four encoders and four decoders
40 | - Stochastic gradient descent optimizer
41 | - Xavier initializer
42 |
43 |
44 |
45 | ### AirNet Basic dropout
46 | Same architecture as Basic, except that dropout layers are added after the pooling layers, with a dropout rate of 0.5.
47 |
48 |
49 | ### AirNet Extended
50 | The extended model is much larger, with 5 encoders and 5 decoders. It takes longer to train and is slower during inference, but achieves higher performance when trained sufficiently.
51 |
52 | - Five encoders and decoders
53 | - Adagrad optimizer
54 | - Variance scale weight initializer
55 |
56 |
57 |
58 | ### AirNet Extended dropout
59 | Same architecture as Extended, except that dropout layers are added after the pooling layers, with a dropout rate of 0.5.
60 |
61 |
62 |
63 |
64 | ## Usage
65 | ### Requirements
66 |
67 | - Tensorflow GPU 1.3.0
68 | - Python 3.5
69 |
70 | `pip install -r AirNet/requirements.txt`
71 |
72 |
73 | ### Run TensorBoard:
74 |     tensorboard --logdir=path/to/log-directory
75 |
76 |
77 | ## Dataset
78 | To verify the model I used the CamVid dataset. It can be downloaded from https://github.com/alexgkendall/SegNet-Tutorial and used in the model by setting the correct paths and dataset size in [AirNet/\__init__.py](https://github.com/Norkart/autoKart/blob/master/AirNet/__init__.py); an example configuration is sketched below.
79 |
80 | The datasets of aerial images used to train and test the model were constructed through an automatic mapping between vector data and aerial images of Norway. Two datasets were built, one with IR images and one with RGB images, each containing around 4500 images. The data is unfortunately not open source.
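As a sketch, pointing the model at a local CamVid copy would mean editing the DEFINE_* defaults along these lines. The paths are illustrative, and the sizes match the commonly used SegNet-Tutorial split (367 train / 233 test / 101 val images at 360x480 with 11 classes) — double-check them against your download:

    tf.app.flags.DEFINE_string('train_dir', "/data/CamVid/train",
                               """ path to training images """)
    tf.app.flags.DEFINE_string('test_dir', "/data/CamVid/test",
                               """ path to test images """)
    tf.app.flags.DEFINE_string('val_dir', "/data/CamVid/val",
                               """ path to val images """)
    tf.app.flags.DEFINE_integer('num_examples_epoch_train', "367",
                                """ num examples per epoch for train """)
    tf.app.flags.DEFINE_integer('num_examples_epoch_test', "233",
                                """ num examples per epoch for test """)
    tf.app.flags.DEFINE_integer('num_examples_epoch_val', "101",
                                """ num examples per epoch for val """)
    tf.app.flags.DEFINE_integer('image_h', "360", """ image height """)
    tf.app.flags.DEFINE_integer('image_w', "480", """ image width """)
    tf.app.flags.DEFINE_integer('num_class', "11", """ total class number """)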
81 |
82 | ![Example dataset](docs/dataset-example.PNG)
83 |
84 |
85 |
86 | ## Results
87 | The IR images gave a small increase in performance, and examples of the segmentation can be seen here:
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/dataset/areacoverimage.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import os
5 | import sys
6 | import requests
7 | import shutil
8 | import numpy as np
9 | from PIL import Image
10 | import time #used by the retry loop in loadWMS
11 |
12 | from osgeo import gdal
13 | from osgeo import osr
14 | from osgeo import ogr
15 |
16 |
17 | #Define global variables:
18 | database_name = ""
19 | base_dir = ""
20 | image_dir = ""
21 | label_dir = ""
22 |
23 | wms_url_std = ''
24 | wms_url_hist = ''
25 |
26 | #Postgres database specification:
27 | databaseServer = "localhost"
28 | databasePort = "5432"
29 | databaseUser = "postgres"
30 | databasePW = "1234"
31 |
32 |
33 | """ ------------------- SET PARAMETERS ------------------- """
34 |
35 | col_pixel_size = 0.2
36 | col_image_dim = 512
37 | cat_image_dim = 512 #prev 27
38 |
39 | #Information about where areas can be found
40 | wms_areas = [
41 |     {
42 |         'name':"Trondheim-2016",
43 |         'rgb_wms_url': wms_url_std,
44 |         'ir_wms_url': wms_url_std
45 |     }
46 |     # ,
47 |     # {
48 |     #     'name':"Oslo-Ostlandet-2016",
49 |     #     'rgb_wms_url': wms_url_hist,
50 |     #     'ir_wms_url': None
51 |     # }
52 | ]
53 |
54 | datasetName="RGB_Trondheim_full/"
55 | start_img_nr = 0 #Change if creating datasets separately that should later be merged
56 |
57 | """ Choose which parts of the dataset to create """
58 | #TIF
59 | create_IR_images = False
60 | create_RGB_images = False
61 |
62 | convert_IR_to_png = False
63 | convert_RGB_to_png = False
64 |
65 | split_IR_dataset = False
66 | split_RGB_dataset = True
67 |
68 |
69 | """ ------------------------------------------------------- """
70 |
71 | # Setup working spatial reference
72 | srs_epsg_nr = 25832
73 | srs = osr.SpatialReference()
74 | srs.ImportFromEPSG(srs_epsg_nr)
75 | srs_epsg_str = 'epsg:{0}'.format(srs_epsg_nr)
76 |
77 |
78 | feature_table = [
79 |     (0.0, "artype >= 90", "Unknown/novalue"),
80 |     (0.10, "artype = 30 and artreslag = 31", "Barskog"),
81 |     (0.10, "artype = 30 and artreslag = 32", "Loevskog"),
82 |     (0.10, "artype = 30 and artreslag >= 33", "Skog, blandet eller ukjent"),
83 |     (0.10, "artype = 50 and argrunnf >= 43 and argrunnf <= 45", "Jorddekt aapen mark"),
84 |     (0.10, "artype >= 20 and artype < 30", "Dyrket"),
85 |     (0.10, "artype >= 80 and artype < 89", "Water"),
86 |     (0.05, "artype = 50 and argrunnf = 41", "Blokkmark"),
87 |     (0.05, "artype = 50 and argrunnf = 42", "Fjell i dagen"),
88 |     (0.05, "artype = 60", "Myr"),
89 |     (0.01, "artype = 70", "Sne/is/bre"),
90 |     (0.05, "artype = 50 and argrunnf > 45", "Menneskepaavirket eller ukjent aapen mark"),
91 |     (0.20, "artype = 12", "Vei/jernbane/transport"),
92 |     (0.20, "artype >= 10 and artype < 12", "Bebygd"),
93 |     (0.90, "byggtyp is not null", "Bygning"),
94 |     (0.90, "objtype = 'annenbygning' ", "Bygning")
95 | ]
96 |
97 |
98 | def createImageDS(filename, x_min, y_min, x_max, y_max, pixel_size, srs=None):
99 |     # Create the destination data source
100 |     x_res = int((x_max - x_min) / pixel_size) # resolution
101 |     y_res = int((y_max - y_min) / pixel_size) # resolution
102 |     ds = gdal.GetDriverByName('GTiff').Create(filename, x_res,
103 |                                               y_res, 1, gdal.GDT_Byte)
104 | 
ds.SetGeoTransform(( 105 | x_min, pixel_size, 0, 106 | y_max, 0, -pixel_size, 107 | )) 108 | if srs: 109 | # Make the target raster have the same projection as the source 110 | ds.SetProjection(srs.ExportToWkt()) 111 | else: 112 | # Source has no projection (needs GDAL >= 1.7.0 to work) 113 | ds.SetProjection('LOCAL_CS["arbitrary"]') 114 | 115 | # Set nodata 116 | band = ds.GetRasterBand(1) 117 | band.SetNoDataValue(0) 118 | 119 | return ds 120 | 121 | def create_png_versions_with_correct_pixel_values(path, img_type): 122 | print("Converting to png for") 123 | print(img_type) 124 | save_to_path = base_dir+wms_area_name+"/png/"+ img_type 125 | create_dir(save_to_path) 126 | 127 | #test_img = Image.open(path+"/map_categories_0.tif") 128 | #print(test_img) 129 | 130 | images = sorted(os.listdir(path)) 131 | print(images) 132 | print("Images is opened and sorted") 133 | 134 | for image_file in images: 135 | print("For image") 136 | outfile = image_file.replace(' ', '')[:-3]+"png" #removing .tif ending and replacing with .png 137 | try: 138 | im = Image.open(os.path.join(path, image_file)) 139 | print("Generating png for %s" % image_file) 140 | except Exception: 141 | print("Could not open image in path:") 142 | print(os.path.join(path, image_file)) 143 | continue 144 | 145 | if img_type == "labels": 146 | print("type is labels:") 147 | pixels = im.load() 148 | 149 | for x in range(0, im.size[0]): 150 | for y in range(0, im.size[1]): 151 | if pixels[x,y] == 14: #converting to correct classname 152 | pixels[x,y] = 1 153 | else: 154 | im=im.convert('RGB') 155 | im.thumbnail(im.size) 156 | im.save(os.path.join(save_to_path, outfile), "PNG", quality=100) 157 | 158 | 159 | def loadWMS(img_file, url, x_min, y_min, x_max, y_max, x_sz, y_sz, srs, layers, format='image/tiff', styles=None): 160 | #Styles is set when working with height data 161 | # Set WMS parameters 162 | # eksempel url: 'http://www.webatlas.no/wms-orto-std/' 163 | hdr = None 164 | 165 | #http://www.webatlas.no/wms-orto-std/?request=GetMap&layers=Trondheim-2016&width=512&height=512&srs=epsg:25832&format=image/tiff&bbox=560000, 7020000, 570000, 7030000 166 | 167 | params = { 168 | 'request': 'GetMap', 169 | #'layers': layers, 170 | 'layers': "Trondheim-2016", 171 | 'width': x_sz, 172 | 'height': y_sz, 173 | 'srs': srs, 174 | 'format': format, 175 | 'bbox': '{0}, {1}, {2}, {3}'.format(x_min, y_min, x_max, y_max) 176 | } 177 | 178 | 179 | if styles: 180 | params['styles'] = styles 181 | 182 | # Do request 183 | for i in range(10): 184 | try: 185 | req = requests.get(url, stream=True, params=params, headers=hdr, timeout=None) 186 | break 187 | except requests.exceptions.ConnectionError as err: 188 | time.sleep(10) 189 | else: 190 | print("Unable to fetch image") 191 | 192 | # Handle response 193 | if req.status_code == 200: 194 | print("request status is 200") 195 | if req.headers['content-type'] == format: 196 | # If response is OK and an image, save image file 197 | with open(img_file, 'wb') as out_file: 198 | shutil.copyfileobj(req.raw, out_file) 199 | return True 200 | 201 | else: 202 | # If no image, print error to stdout 203 | print("Content-type: ", req.headers['content-type'], " url: ", req.url, " Content: ", req.text, file=sys.stderr) 204 | 205 | # Use existing 206 | elif req.status_code == 304: 207 | return True 208 | 209 | # Handle error 210 | else: 211 | print("Status: ", req.status_code, " url: ", req.url, file=sys.stderr) 212 | 213 | return False 214 | 215 | 216 | def create_dir(path): 217 | if not os.path.exists(path): 218 | 
os.makedirs(path)
219 |
220 |
221 | def createDataset():
222 |     #Define the global variables as global, to get correct values
223 |     global image_nr
224 |     print("In create dataset")
225 |
226 |     cat_pixel_size = col_pixel_size * col_image_dim / cat_image_dim
227 |
228 |     connString = "PG: host=%s port=%s dbname=%s user=%s password=%s" % (databaseServer, databasePort, database_name, databaseUser, databasePW)
229 |
230 |     conn = ogr.Open(connString)
231 |     layer = conn.GetLayer("ar_bygg") #ar_bygg is the database table
232 |
233 |     layer_x_min, layer_x_max, layer_y_min, layer_y_max = layer.GetExtent()
234 |
235 |     target_fill = [0] * len(feature_table)
236 |
237 |     bbox_size_x = col_pixel_size * col_image_dim
238 |     bbox_size_y = col_pixel_size * cat_image_dim
239 |
240 |     #Init boundingbox (bbox)
241 |     x_min = layer_x_min
242 |     y_min = layer_y_min
243 |     x_max = x_min + bbox_size_x
244 |     y_max = y_min + bbox_size_y
245 |
246 |     it = 0
247 |     while y_max < (layer_y_max - bbox_size_y): #As long as it hasn't reached the end of the area
248 |         it = it+1
249 |         if(x_max > layer_x_max + bbox_size_x): #If at the end of x, skip to the next y column
250 |             print("\n Reached x end, moving y length")
251 |             x_min = layer_x_min # reset x_min --> start at layer_x_min again
252 |             #skip one column (y-length):
253 |             y_min = y_min + bbox_size_y
254 |             y_max = y_min + bbox_size_y
255 |
256 |         # Create new boundingbox by moving across x-axis
257 |         x_min = x_min + bbox_size_x #start point + the length of the previous box
258 |         x_max = x_min + bbox_size_x
259 |
260 |         # Create ring
261 |         ring = ogr.Geometry(ogr.wkbLinearRing)
262 |         ring.AddPoint(x_min, y_min)
263 |         ring.AddPoint(x_max, y_min)
264 |         ring.AddPoint(x_max, y_max)
265 |         ring.AddPoint(x_min, y_max)
266 |         ring.AddPoint(x_min, y_min)
267 |
268 |         # Create polygon
269 |         poly = ogr.Geometry(ogr.wkbPolygon)
270 |         poly.AddGeometry(ring)
271 |
272 |         # set spatial filter
273 |         layer.SetSpatialFilter(poly)
274 |
275 |         # Test for the existence of data
276 |         layer.SetAttributeFilter(None)
277 |         # if no data, go to next
278 |         if layer.GetFeatureCount() < 1:
279 |             #print("Feature count less than 1")
280 |             continue
281 |
282 |         good_data = True
283 |         for feature in reversed(feature_table):
284 |             if feature[0] > np.random.random_sample():
285 |                 layer.SetAttributeFilter(feature[1])
286 |                 #if layer.GetFeatureCount() < 1:
287 |                     #print("Dropped on", feature[2])
288 |                     #good_data = False
289 |                     #break
290 |
291 |         if not good_data:
292 |             #print("Not good data")
293 |             continue #skipping to next iteration
294 |
295 |         # Create image
296 |         target_ds = gdal.GetDriverByName('GTiff').Create(os.path.join(label_dir, "map_categories_{0}.tif".format(image_nr)),
297 |                                                          cat_image_dim, cat_image_dim, 1, gdal.GDT_Byte)
298 |         target_ds.SetGeoTransform((
299 |             x_min, cat_pixel_size, 0,
300 |             y_max, 0, -cat_pixel_size,
301 |         ))
302 |         target_ds.SetProjection(srs.ExportToWkt())
303 |
304 |         # Fill raster
305 |         no_data = True
306 |         for i, attr_filter in enumerate([feature[1] for feature in feature_table]):
307 |             if attr_filter:
308 |                 # Rasterize
309 |                 layer.SetAttributeFilter(attr_filter)
310 |                 if layer.GetFeatureCount() > 0:
311 |                     no_data = False
312 |                     target_fill[i] += 1
313 |                     if gdal.RasterizeLayer(target_ds, [1], layer, burn_values=[i], options=['ALL_TOUCHED=TRUE']) != 0:
314 |                         raise Exception("error rasterizing layer %d" % i)
315 |
316 |         if no_data:
317 |             print("no data")
318 |             continue #skipping to next iteration
319 |
320 |         # Load color image
321 |         colorImgFile = os.path.join(image_dir, "map_color_{0}.tif".format(image_nr))
322 |
323 | 
#Extracting images
324 | loadWMS(colorImgFile, wms_url, x_min, y_min, x_max, y_max,
325 | col_image_dim, col_image_dim, srs_epsg_str, wms_area_name)
326 | 
327 | image_nr += 1 #used when naming the output image files
328 | 
329 | #print("feature and fill: ")
330 | #for fill, feature in zip(target_fill, feature_table):
331 | #print(feature[2], fill)
332 | 
333 | 
334 | def split_dataset():
335 | """ Splitting dataset into three parts: Training, testing, validation"""
336 | training_size = 0.8
337 | val_size = 0.1
338 | 
339 | images = os.listdir(image_dir)
340 | labels = os.listdir(label_dir)
341 | 
342 | tot_number = len(images)
343 | processed_number = 0
344 | 
345 | outpath_base = base_dir+"combined_dataset_v2/"
346 | create_dir(outpath_base)
347 | 
348 | for image_name in images:
349 | 
350 | imagenr = image_name.split("map_color_")[1].split(".png")[0]
351 | image = Image.open(os.path.join(image_dir, image_name))
352 | 
353 | #NB! Some images don't have a label version - skip them!
354 | for label_name in labels:
355 | labelnr = label_name.split("map_categories_")[1].split(".png")[0]
356 | if(imagenr == labelnr):
357 | outpath_images, outpath_labels = set_split_dataset_outpath(processed_number, tot_number, training_size, val_size, outpath_base)
358 | # os.rename(os.path.join(image_dir, image_name), os.path.join(outpath_images, newName_image) )
359 | # os.rename(os.path.join(label_dir, label_name), os.path.join(outpath_labels, newName_label) )
360 | shutil.copy2(os.path.join(image_dir, image_name), os.path.join(outpath_images, image_name) )
361 | shutil.copy2(os.path.join(label_dir, label_name), os.path.join(outpath_labels, label_name) )
362 | 
363 | processed_number = processed_number + 1
364 | # continue
365 | break
366 | 
367 | 
368 | def set_split_dataset_outpath(processed_number, tot_number, training_size, val_size, outpath_base):
369 | if(processed_number > (tot_number * (training_size+val_size)) ):
370 | outpath_images = outpath_base + "test_images/images"
371 | outpath_labels = outpath_base + "test_images/labels"
372 | elif(processed_number > (tot_number * training_size)):
373 | print("\n Val images, processed_number is:")
374 | print(processed_number)
375 | outpath_images = outpath_base + "val_images/images"
376 | outpath_labels = outpath_base + "val_images/labels"
377 | else:
378 | outpath_images = outpath_base + "train_images/images"
379 | outpath_labels = outpath_base + "train_images/labels"
380 | 
381 | create_dir(outpath_images)
382 | create_dir(outpath_labels)
383 | return outpath_images, outpath_labels
384 | 
385 | def setAreaVariables(img_type, img_dir_type):
386 | #For each area:
387 | #Defining that these variables should be modified globally:
388 | global database_name, base_dir, image_dir, label_dir
389 | 
390 | database_name = wms_area_name.split("-")[0]
391 | print("DATABASE name is: ")
392 | print(database_name)
393 | base_dir = '/home/mators/aerial_datasets/'+datasetName+img_type+'/'
394 | create_dir('/home/mators/aerial_datasets')
395 | create_dir(base_dir)
396 | image_dir = base_dir+wms_area_name+"/"+img_dir_type+'/images/'
397 | label_dir = base_dir+wms_area_name+"/"+img_dir_type+'/labels/'
398 | create_dir(image_dir)
399 | create_dir(label_dir)
400 | 
401 | 
402 | if __name__ == "__main__":
403 | global image_nr, wms_url, wms_area_name
404 | 
405 | image_nr=start_img_nr
406 | 
407 | for i in range(0, len(wms_areas)):
408 | #print("Creating images for area:")
409 | #print(wms_areas[i]["name"])
410 | if create_RGB_images:
411 | 
wms_area_name = wms_areas[i]["name"]
412 | print("wms area name:")
413 | print(wms_area_name)
414 | setAreaVariables("RGB_images", "tif")
415 | wms_url = wms_areas[i]["rgb_wms_url"]
416 | createDataset()
417 | 
418 | if create_IR_images:
419 | wms_area_name = wms_areas[i]["name"].split("-")[0]+"-IR-"+wms_areas[i]["name"].split("-")[1]
420 | setAreaVariables("IR_images", "tif")
421 | wms_url = wms_areas[i]["ir_wms_url"]
422 | createDataset()
423 | 
424 | if convert_RGB_to_png:
425 | wms_area_name = wms_areas[i]["name"]
426 | setAreaVariables("RGB_images", "tif")
427 | create_png_versions_with_correct_pixel_values(image_dir, "images")
428 | create_png_versions_with_correct_pixel_values(label_dir, "labels")
429 | 
430 | if convert_IR_to_png:
431 | wms_area_name = wms_areas[i]["name"].split("-")[0]+"-IR-"+wms_areas[i]["name"].split("-")[1]
432 | #converting to png for images and labels
433 | print("converting IR images to PNG")
434 | setAreaVariables("IR_images", "tif")
435 | create_png_versions_with_correct_pixel_values(image_dir, "images")
436 | create_png_versions_with_correct_pixel_values(label_dir, "labels")
437 | 
438 | if split_RGB_dataset:
439 | wms_area_name = wms_areas[i]["name"]
440 | setAreaVariables("RGB_images", "png")
441 | split_dataset()
442 | 
443 | if split_IR_dataset:
444 | wms_area_name = wms_areas[i]["name"].split("-")[0]+"-IR-"+wms_areas[i]["name"].split("-")[1]
445 | setAreaVariables("IR_images", "png")
446 | split_dataset()
447 | 
--------------------------------------------------------------------------------
/dataset/check_for_duplicates.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | from PIL import Image
4 | import numpy
5 | 
6 | from PIL import ImageChops
7 | 
8 | """ TESTED:
9 | 
10 | No duplicates in:
11 | - the first part of the validation images (stopped early because the check took too much time while training was running)
12 | 
13 | """
14 | 
15 | image_path="../../IR_images/combined_dataset/val_images/images"
16 | # image_path="../../IR_images/combined_dataset/val_images/images"
17 | 
18 | images = sorted(os.listdir(image_path))
19 | 
20 | 
21 | for image_file_1 in images:
22 | for image_file_2 in images:
23 | image1 = Image.open(os.path.join(image_path,image_file_1))
24 | image2 = Image.open(os.path.join(image_path,image_file_2))
25 | #pixels = image.load()
26 | 
27 | if image_file_1 != image_file_2 and ImageChops.difference(image1, image2).getbbox() is None: #skip self-comparison, then compare pixel content
28 | # if(image1==image2):# and image_file_1 != image_file_2):
29 | print("Same image!!!")
30 | print(image_file_1)
31 | print(image_file_2)
32 | # else:
33 | # print("not same")
34 | # print(image_file_1)
35 | # print(image_file_2)
--------------------------------------------------------------------------------
/dataset/combined_results_imgs.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy
4 | 
5 | """
6 | Script that creates result images by overlaying the label results on the input images with some opacity.
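For reference (PIL semantics): Image.blend(im1, im2, alpha) computes
im1*(1-alpha) + im2*alpha per pixel, so the Image.blend(label, image, .7)
call below keeps 70% of the input image and lets the label show through
at 30% opacity.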
7 | 
8 | """
9 | 
10 | # image_path = "../../IR_images/combined_dataset/"
11 | # label_path = "../result_imgs/for_combined_images/"
12 | # image_path = "../../aerial_datasets/IR_RGB_0.1res/IR_images/combined_dataset/test_images/images" #Input images
13 | image_path = "../../aerial_datasets/IR_RGB_0.1res/IR_images/combined_dataset/test_images/images/" #Ground truth
14 | label_path = "../result_imgs_IR/"
15 | outpath = "../result_imgs_combined/"
16 | 
17 | images = sorted(os.listdir(image_path))
18 | labels = sorted(os.listdir(label_path))
19 | 
20 | #IR test = 309
21 | #IR labels = 309
22 | 
23 | #First rename all image-labels based on what number in the sorted list they are? Smallest number = 0, next smallest = 1, and so on?
24 | 
25 | processed = 0
26 | 
27 | for image_name in images:
28 | for label_name in labels:
29 | labelnr = label_name.split("testing_image")[1].split(".png")[0]
30 | 
31 | if(int(labelnr) == int(processed)):
32 | label = Image.open(os.path.join(label_path, label_name))
33 | print("same number!")
34 | image = Image.open(os.path.join(image_path, image_name))
35 | # image = Image.open(os.path.join(image_path, image_name)).convert('RGB')
36 | print(label)
37 | print(image)
38 | blend_img = Image.blend(label, image, .7)
39 | blend_img.save(os.path.join(outpath, image_name))
40 | # Image.blend(label, image, .7).save(os.path.join(outpath, label_name+"_label"))
41 | processed += 1
42 | 
43 | 
44 | # for label_name in labels:
45 | # labelnr = label_name.split("testing_image")[1].split(".png")[0]
46 | # label = Image.open(os.path.join(label_path, label_name))
47 | # if(len(labelnr) != -1):
48 | # # labelnr = "400"+labelnr
49 | # # print(int(float(labelnr)))
50 | # labelnr = int(float(labelnr)) + 4271
51 | # print("new label number is: ")
52 | # print(labelnr)
53 | # # elif(len(labelnr) == 2):
54 | # # labelnr = "40"+labelnr
55 | # # elif(len(labelnr) == 4):
56 | # # labelnr = "4"+labelnr
57 | # # print(labelnr)
58 | # for image_name in images:
59 | # imagenr = image_name.split("map_color_")[1].split(".png")[0]
60 | # print('imagenr is:')
61 | # print(imagenr)
62 | # if(int(labelnr) == int(imagenr)):
63 | # print("same number!")
64 | # image = Image.open(os.path.join(image_path, image_name))
65 | # Image.blend(label, image, .7).save(os.path.join(outpath, label_name))
--------------------------------------------------------------------------------
/dataset/convert_label_to_correct_classes.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy
4 | 
5 | """
6 | Script that converts all pixels in an image to the correct classes.
7 | Originally the building class was class number 14, so all building pixels have the value 14.
8 | They therefore need to be changed to 1, since they really belong to class 1 and not class 14.
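A vectorized numpy equivalent of the pixel loop below (a sketch, using the
numpy import above): arr = numpy.array(image); arr[arr == 14] = 1;
image = Image.fromarray(arr.astype(numpy.uint8))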
9 | """ 10 | 11 | image_path="../../aerial_img_4600/val_images/png/labels" 12 | outpath = "../masterproject/aerial_img_4600/val_images/png/labels_correct" 13 | 14 | images = sorted(os.listdir(image_path)) 15 | 16 | existing_pixelvalues = [] 17 | 18 | for image_file in images: 19 | 20 | image = Image.open(os.path.join(image_path,image_file)) 21 | pixels = image.load() 22 | 23 | for x in range(0, image.size[0]): 24 | for y in range(0, image.size[1]): 25 | if pixels[x,y] not in existing_pixelvalues: 26 | existing_pixelvalues.append(pixels[x,y]) 27 | if pixels[x,y] == 14: 28 | pixels[x,y] = 1 29 | 30 | print(existing_pixelvalues) 31 | break#testing with just one image 32 | #image.save(os.path.join(outpath, image_file), "PNG", quality=100) 33 | -------------------------------------------------------------------------------- /dataset/dataset_balance.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy 4 | 5 | """ Calculate median frequency balancing for dataset, to avoid issues with unbalanced dataset. 6 | 7 | we weight each pixel by 8 | ac = median_freq / freq(c) 9 | where freq(c) is the number of pixels of class c divided by the total number of pixels in images where c is present, 10 | and median_freq is the median of these frequencies. 11 | 12 | """ 13 | 14 | # image_path="../masterproject/aerial_img_1400/test_images/jpeg/labels" 15 | image_path="../../aerial_datasets/IR_RGB_0.1res/IR_images/combined_dataset/train_images/labels/" 16 | images = sorted(os.listdir(image_path)) 17 | 18 | #Saving total number of pixels in dataset belonging to each class 19 | tot_num_class0=0.0 20 | tot_num_class1=0.0 21 | tot_pixels_dataset=0.0 22 | 23 | class1_freqs = [] 24 | class0_freqs = [] 25 | class_freqs = [] 26 | 27 | for image_file in images: 28 | #Saving number of pixels in one image belonging to each class 29 | print("inspecting image %s"%image_file) 30 | num_class0 = 0.0 #not building 31 | num_class1 = 0.0 #building 32 | 33 | image = Image.open(os.path.join(image_path,image_file)) 34 | pixels = image.load() 35 | 36 | img_array = numpy.asarray(image, dtype="float") 37 | 38 | num_class0 = float(numpy.count_nonzero(img_array == 0.0)) 39 | num_class1 = float(numpy.count_nonzero(img_array == 1.0)) 40 | tot_pixels = num_class0 + num_class1 41 | 42 | freq_class0 = num_class0 / tot_pixels 43 | freq_class1 = num_class1 / tot_pixels 44 | 45 | class_freqs.append(freq_class0) 46 | class_freqs.append(freq_class1) 47 | 48 | tot_num_class0 = tot_num_class0 + num_class0 49 | tot_num_class1 = tot_num_class1 + num_class1 50 | tot_pixels_dataset = tot_pixels_dataset + tot_pixels 51 | 52 | median_freq = numpy.median(class_freqs) 53 | 54 | tot_freq_class0= tot_num_class0 / tot_pixels_dataset 55 | tot_freq_class1= tot_num_class1 / tot_pixels_dataset 56 | 57 | class0_score = median_freq / tot_freq_class0 58 | class1_score = median_freq / tot_freq_class1 59 | 60 | print('Score for class not building: ') 61 | print(class0_score) 62 | print('Score for class building: ') 63 | print(class1_score) 64 | 65 | """ RESULT: 66 | Score for class not building: 67 | 0.545399958039 68 | Score for class building: 69 | 6.006613019 70 | """ 71 | -------------------------------------------------------------------------------- /dataset/geojson_to_png.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image, ImageDraw 3 | # import geojson 4 | import csv 5 | 6 | 
path="../famous_datasets/SpaceNet/processedBuildingLabels/vectordata/geojson" 7 | outpath="../famous_datasets/SpaceNet/processedBuildingLabels/vectordata/png" 8 | 9 | 10 | files = sorted(os.listdir(path)) 11 | 12 | 13 | out = Image.new("L",(49,87)) 14 | dout = ImageDraw.Draw(out) 15 | import csv 16 | 17 | for name in files: 18 | with open(os.path.join(path, name), 'r') as f: 19 | reader = csv.reader(f) 20 | for row in reader: 21 | dout.point((int(row[0]) / 10,int(row[1]) / 10),fill=int(int(row[2]) * 2.55)) 22 | #print(row[0] + " " + row[1] + " " + row[2]) 23 | out.show() 24 | break 25 | 26 | # for name in files: 27 | # # outfile = name.replace(' ', '')[:-3]+"jpeg" 28 | # outfile = name.replace(' ', '')[:-3]+"png" 29 | # geojson = geojson.loads(name) 30 | # im = Image.open(os.path.join(path, name)) 31 | # im=im.convert('RGB') 32 | # print("Generating png for %s" % name) 33 | # im.thumbnail(im.size) 34 | # #im.save(outfile, "PNG", quality=100) 35 | # # im.save(os.path.join(outpath, outfile), "JPEG", quality=100) 36 | # im.save(os.path.join(outpath, outfile), "PNG", quality=100) 37 | -------------------------------------------------------------------------------- /dataset/getChannels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | image_path = "../../IR_images/combined_dataset/test_images/images/map_color_4271.png" 9 | 10 | img = Image.open(image_path) 11 | 12 | print(img.mode) 13 | print(img.size) 14 | -------------------------------------------------------------------------------- /dataset/import_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | from osgeo import ogr 5 | import psycopg2 6 | 7 | 8 | 9 | def importAr(to_cur, filename): 10 | print("Importing:", filename) 11 | dataSource = ogr.Open(filename) 12 | dataLayer = dataSource.GetLayer(0) 13 | 14 | for feature in dataLayer: 15 | geom = feature.GetGeometryRef().ExportToWkt() 16 | id = feature.GetField("poly_id") 17 | objtype = feature.GetField("objtype") 18 | artype = feature.GetField("artype") 19 | arskogbon = feature.GetField("arskogbon") 20 | artreslag = feature.GetField("artreslag") 21 | argrunnf = feature.GetField("argrunnf") 22 | 23 | to_tuple = ( id, objtype, artype, arskogbon, artreslag, argrunnf, geom) 24 | to_cur.execute("""INSERT into ar_bygg (id, objtype, artype, arskogbon, artreslag, argrunnf, geom) 25 | SELECT %s, %s, %s, %s, %s, %s, ST_GeometryFromText(%s);""", 26 | to_tuple) 27 | 28 | to_conn.commit() 29 | dataSource.Destroy() 30 | 31 | def importBygg(to_cur, filename): 32 | print("Importing:", filename) 33 | dataSource = ogr.Open(filename) 34 | dataLayer = dataSource.GetLayer(0) 35 | print("dataLayer") 36 | print(dataLayer) 37 | 38 | #Genererer id'er fortloopende 39 | for id, feature in enumerate(dataLayer): 40 | geom = feature.GetGeometryRef() 41 | if not geom: 42 | continue 43 | geom.FlattenTo2D() 44 | print("Feature") 45 | print(feature) 46 | 47 | for i in range(1, feature.GetFieldCount()): 48 | field = feature.GetDefnRef().GetFieldDefn(i).GetName() 49 | if( i == 4): 50 | continue 51 | #print(field.encode('utf-8')) 52 | 53 | byggtyp = feature.GetField("BYGGTYP_NB") 54 | #poly_id = feature.GetField("LOKALID ") 55 | objtype = feature.GetField("OBJTYPE") 56 | 57 | to_tuple = (id, objtype, byggtyp, 'SRID=25832;' + geom.ExportToWkt()) 58 | 59 | to_cur.execute("""INSERT 
into ar_bygg (id, objtype, byggtyp, geom)
60 | SELECT %s, %s, %s, ST_GeometryFromText(%s);""",
61 | to_tuple)
62 | 
63 | to_conn.commit()
64 | dataSource.Destroy()
65 | 
66 | # to_conn = psycopg2.connect("host=localhost port=5433 dbname=ar-bygg-ostfold user=postgres password=24Pils")
67 | to_conn = psycopg2.connect("host=localhost port=5432 dbname=Asker user=postgres password=1234")
68 | to_cur = to_conn.cursor()
69 | 
70 | # by_filebase = "./ar5_bygg_01/32_FKB_{0}_Bygning/32_{0}bygning_flate.shp"
71 | #by_filebase = "../../fkb-data/32_FKB_{0}_Bygning/32_{0}bygning_flate.shp"
72 | by_filebase = "../../fkb-data/Asker-shape/Bygning_polygon.shp"
73 | 
74 | #Municipality number list - specified in the folder name
75 | k_nr_list = [ #For FKB_area_for_IR_ad_RGB
76 | #"0228" #ralingen
77 | #"0230" #lorenskog
78 | #"0231" #skedsmo
79 | #"0233" #nittedal
80 | #"1003" #farsund
81 | #"0226" #sorum
82 | #"0227" #fet
83 | #"1601" #Trondheim
84 | "0220" #Asker
85 | ]
86 | #k_nr_list = [ #For the IR_area_fkb folder
87 | # "0211"
88 | # "0213",
89 | # "0214",
90 | # "0215",
91 | # "0216",
92 | # "0217",
93 | # "0219",
94 | # "0220",
95 | #"0426",
96 | #"0427"
97 | #]
98 | 
99 | # k_nr_list = [ #For the ar5_bygg_01 folder
100 | # "0101",
101 | # "0104",
102 | # "0105",
103 | # "0106",
104 | # "0111",
105 | # "0118",
106 | # "0119",
107 | # "0121",
108 | # "0122",
109 | # "0123",
110 | # "0124",
111 | # "0125",
112 | # "0127",
113 | # "0128",
114 | # "0135",
115 | # "0136",
116 | # "0137",
117 | # "0138",
118 | # ]
119 | 
120 | print('by_file')
121 | print(by_filebase)
122 | 
123 | # importAr(to_cur, ar_file)
124 | importBygg(to_cur, by_filebase)
125 | 
126 | ''' for k_nr in k_nr_list:
127 | # ar_file = ar_filebase.format(k_nr)
128 | #by_file = by_filebase.format(k_nr)
129 | print('by_file')
130 | print(by_filebase)
131 | 
132 | # importAr(to_cur, ar_file)
133 | importBygg(to_cur, by_filebase) '''
--------------------------------------------------------------------------------
/dataset/postgres_init.txt:
--------------------------------------------------------------------------------
1 | CREATE EXTENSION postgis;
2 | 
3 | CREATE SEQUENCE ar_bygg_gid_seq
4 | INCREMENT 1
5 | MINVALUE 1
6 | MAXVALUE 9223372036854775807
7 | START 483286
8 | CACHE 1;
9 | ALTER TABLE ar_bygg_gid_seq
10 | OWNER TO postgres;
11 | 
12 | CREATE TABLE ar_bygg
13 | (
14 | gid integer NOT NULL DEFAULT nextval('ar_bygg_gid_seq'::regclass),
15 | objtype text,
16 | id bigint,
17 | byggtyp smallint,
18 | artype smallint,
19 | arskogbon smallint,
20 | artreslag smallint,
21 | argrunnf smallint,
22 | geom geometry(Geometry,25832),
23 | CONSTRAINT ar_bygg_gid_pk PRIMARY KEY (gid)
24 | )
25 | WITH (
26 | OIDS=FALSE
27 | );
28 | ALTER TABLE ar_bygg
29 | OWNER TO postgres;
30 | 
31 | -- Index: ar_bygg_geom_idx
32 | 
33 | -- DROP INDEX ar_bygg_geom_idx;
34 | 
35 | CREATE INDEX ar_bygg_geom_idx
36 | ON ar_bygg
37 | USING gist
38 | (geom);
--------------------------------------------------------------------------------
/dataset/split_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import requests
4 | import shutil
5 | import numpy as np
6 | from PIL import Image
7 | 
8 | """ Splitting dataset into three parts: Training, testing, validation"""
9 | 
10 | # SET PARAMETERS:
11 | # data_base_path = "../../IR_images/asker_berum/png/"
12 | # data_base_path = "../../IR_images/follo/png/"
13 | data_base_path = "../../IR_images/elverum_vaaler/png/" #use new_number code!!
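# Note on "new_number code" (explanatory, inferred from the comments below):
# image numbers restart for each area, so when several areas are merged into
# combined_dataset the same number can occur twice. The renaming further down
# therefore prefixes each number (e.g. "10" + the original number) to keep
# file names unique.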
14 | outpath_base = "../../IR_images/combined_dataset/"
15 | 
16 | #------------------------
17 | 
18 | image_path = data_base_path+"images/"
19 | label_path = data_base_path+"labels/"
20 | 
21 | #outpath_images = ""#outpath_base + "train_images/images"
22 | #outpath_labels = ""
23 | 
24 | training_size = 0.7
25 | val_size = 0.1
26 | 
27 | images = os.listdir(image_path)
28 | labels = os.listdir(label_path)
29 | 
30 | tot_number = len(images)
31 | print('tot_number')
32 | print(tot_number)
33 | processed_number = 0
34 | 
35 | def create_dir(path):
36 | if not os.path.exists(path):
37 | os.makedirs(path)
38 | 
39 | def set_outpath(processed_number):
40 | if(processed_number > (tot_number * (training_size+val_size)) ):
41 | print("\n TEST images!")
42 | outpath_images = outpath_base + "test_images/images"
43 | outpath_labels = outpath_base + "test_images/labels"
44 | elif(processed_number > (tot_number * training_size)):
45 | print("\n Val images, processed_number is:")
46 | print(processed_number)
47 | outpath_images = outpath_base + "val_images/images"
48 | outpath_labels = outpath_base + "val_images/labels"
49 | else:
50 | outpath_images = outpath_base + "train_images/images"
51 | outpath_labels = outpath_base + "train_images/labels"
52 | 
53 | create_dir(outpath_images)
54 | create_dir(outpath_labels)
55 | return outpath_images, outpath_labels
56 | 
57 | 
58 | 
59 | for image_name in images:
60 | 
61 | imagenr = image_name.split("map_color_")[1].split(".png")[0]
62 | image = Image.open(os.path.join(image_path, image_name))
63 | 
64 | #NB! Some images don't have a label version - skip them!
65 | for label_name in labels:
66 | labelnr = label_name.split("map_categories_")[1].split(".png")[0]
67 | print(labelnr)
68 | if(imagenr == labelnr):
69 | outpath_images, outpath_labels = set_outpath(processed_number)
70 | 
71 | new_number = "10"+image_name.split("map_color_")[1].split(".png")[0] #prefix keeps numbers unique - the areas were numbered separately, so several images share a number
72 | 
73 | newName_image = "map_color_"+new_number+".png"
74 | newName_label = "map_categories_"+new_number+".png"
75 | 
76 | os.rename(os.path.join(image_path, image_name), os.path.join(outpath_images, newName_image) )
77 | os.rename(os.path.join(label_path, label_name), os.path.join(outpath_labels, newName_label) )
78 | processed_number = processed_number + 1
79 | break #the image has been matched and moved, go to the next one
--------------------------------------------------------------------------------
/docs/Basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Basic.png
--------------------------------------------------------------------------------
/docs/Example-result-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Example-result-2.png
--------------------------------------------------------------------------------
/docs/Example-result-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Example-result-3.png
--------------------------------------------------------------------------------
/docs/Example-result-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Example-result-4.png
--------------------------------------------------------------------------------
/docs/Example-result-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Example-result-5.png
--------------------------------------------------------------------------------
/docs/Example-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Example-result.png
--------------------------------------------------------------------------------
/docs/Extended-dropout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Extended-dropout.png
--------------------------------------------------------------------------------
/docs/Extended.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/Extended.png
--------------------------------------------------------------------------------
/docs/arch.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/arch.PNG
--------------------------------------------------------------------------------
/docs/dataset-example.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathildor/TF-SegNet/dff209c8174b5e8fa77b4c2644298f6903a09445/docs/dataset-example.PNG
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | import time
4 | from datetime import datetime
5 | import numpy as np
6 | import AirNet
7 | 
8 | FLAGS = tf.app.flags.FLAGS
9 | 
10 | def train(is_finetune=False):
11 | 
12 | startstep = 0 if not is_finetune else int(FLAGS.finetune_dir.split('-')[-1])
13 | image_filenames, label_filenames = AirNet.get_filename_list(FLAGS.train_dir)
14 | val_image_filenames, val_label_filenames = AirNet.get_filename_list(FLAGS.val_dir)
15 | 
16 | with tf.Graph().as_default():
17 | 
18 | images, labels, is_training, keep_prob = AirNet.placeholder_inputs(batch_size=FLAGS.batch_size)
19 | 
20 | images, labels = AirNet.dataset_inputs(image_filenames, label_filenames, FLAGS.batch_size) #NOTE: these queue tensors shadow the image/label placeholders above; the fetched batches are fed back into them below
21 | val_images, val_labels = AirNet.dataset_inputs(val_image_filenames, val_label_filenames, FLAGS.eval_batch_size, False)
22 | 
23 | if FLAGS.model == "basic":
24 | logits = AirNet.inference_basic(images, is_training)
25 | elif FLAGS.model == "extended":
26 | logits = AirNet.inference_extended(images, is_training)
27 | elif FLAGS.model == "basic_dropout":
28 | logits = AirNet.inference_basic_dropout(images, is_training, keep_prob)
29 | elif FLAGS.model == "extended_dropout":
30 | logits = AirNet.inference_extended_dropout(images, is_training, keep_prob)
31 | else:
32 | raise ValueError("The selected model does not exist")
33 | 
34 | loss = AirNet.loss_calc(logits=logits, labels=labels)
35 | train_op, global_step = 
AirNet.training(loss=loss)
36 | accuracy = tf.argmax(logits, axis=3) #per-pixel predicted class, not an accuracy scalar
37 | 
38 | summary = tf.summary.merge_all()
39 | saver = tf.train.Saver(max_to_keep=100000)
40 | 
41 | with tf.Session() as sess:
42 | 
43 | if(is_finetune):
44 | print("\n =====================================================")
45 | print(" Finetuning with model: ", FLAGS.model)
46 | print("\n Batch size is: ", FLAGS.batch_size)
47 | print(" ckpt files are saved to: ", FLAGS.log_dir)
48 | print(" Max iterations to train is: ", FLAGS.max_steps)
49 | print(" =====================================================")
50 | saver.restore(sess, FLAGS.finetune_dir)
51 | else:
52 | print("\n =====================================================")
53 | print(" Training from scratch with model: ", FLAGS.model)
54 | print("\n Batch size is: ", FLAGS.batch_size)
55 | print(" ckpt files are saved to: ", FLAGS.log_dir)
56 | print(" Max iterations to train is: ", FLAGS.max_steps)
57 | print(" =====================================================")
58 | sess.run(tf.variables_initializer(tf.global_variables()))
59 | sess.run(tf.local_variables_initializer())
60 | 
61 | # Start the queue runners.
62 | coord = tf.train.Coordinator()
63 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
64 | 
65 | train_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
66 | #val_writer = tf.summary.FileWriter(#TEST_WRITER_DIR)
67 | 
68 | """ Starting iterations to train the network """
69 | for step in range(startstep+1, startstep + FLAGS.max_steps+1):
70 | images_batch, labels_batch = sess.run(fetches=[images, labels])
71 | 
72 | train_feed_dict = {images: images_batch,
73 | labels: labels_batch,
74 | is_training: True,
75 | keep_prob: 0.5}
76 | 
77 | start_time = time.time()
78 | 
79 | _, train_loss_value, train_accuracy_value, train_summary_str = sess.run([train_op, loss, accuracy, summary], feed_dict=train_feed_dict)
80 | 
81 | #Finding duration for training batch
82 | duration = time.time() - start_time
83 | 
84 | if step % 10 == 0: #Print info about training
85 | examples_per_sec = FLAGS.batch_size / duration
86 | sec_per_batch = float(duration)
87 | 
88 | print('\n--- Normal training ---')
89 | format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
90 | 'sec/batch)')
91 | print(format_str % (datetime.now(), step, train_loss_value,
92 | examples_per_sec, sec_per_batch))
93 | 
94 | # eval current training batch per-class accuracy
95 | pred = sess.run(logits, feed_dict=train_feed_dict)
96 | AirNet.per_class_acc(pred, labels_batch) #printing class accuracy
97 | 
98 | train_writer.add_summary(train_summary_str, step)
99 | train_writer.flush()
100 | 
101 | if step % 100 == 0 or (step + 1) == FLAGS.max_steps:
102 | test_iter = FLAGS.num_examples_epoch_test // FLAGS.batch_size
103 | """ Validate training by running validation dataset """
104 | print("\n===========================================================")
105 | print("--- Running test on VALIDATION dataset ---")
106 | total_val_loss=0.0
107 | hist = np.zeros((FLAGS.num_class, FLAGS.num_class))
108 | for val_step in range(test_iter):
109 | val_images_batch, val_labels_batch = sess.run(fetches=[val_images, val_labels])
110 | 
111 | val_feed_dict = { images: val_images_batch,
112 | labels: val_labels_batch,
113 | is_training: True,
114 | keep_prob: 1.0}
115 | 
116 | _val_loss, _val_pred = sess.run(fetches=[loss, logits], feed_dict=val_feed_dict)
117 | total_val_loss += _val_loss
118 | hist += AirNet.get_hist(_val_pred, val_labels_batch)
119 | print("Validation Loss: ", total_val_loss / 
test_iter, ". If this value increases the model is likely overfitting.") 120 | AirNet.print_hist_summery(hist) 121 | print("===========================================================") 122 | 123 | # Save the model checkpoint periodically. 124 | if step % 1000 == 0 or step % 500 == 0 or (step + 1) == FLAGS.max_steps: 125 | print("\n--- SAVING SESSION ---") 126 | checkpoint_path = os.path.join(FLAGS.log_dir, 'model.ckpt') 127 | saver.save(sess, checkpoint_path, global_step=step) 128 | print("=========================") 129 | 130 | coord.request_stop() 131 | coord.join(threads) 132 | 133 | def main(args): 134 | if FLAGS.testing: 135 | print("Testing the model!") 136 | AirNet.test() 137 | elif FLAGS.finetune: 138 | train(is_finetune=True) 139 | else: 140 | train(is_finetune=False) 141 | 142 | if __name__ == "__main__": 143 | tf.app.run() # wrapper that handles flags parsing. --------------------------------------------------------------------------------