├── BiERF-PSPNet ├── model.py └── train.py ├── README.md └── erf-pspnet ├── dataset ├── train │ └── 0001TP_006690.png ├── trainannot │ └── 0001TP_006690.png ├── val │ └── 0001TP_008550.png └── valannot │ └── 0001TP_008550.png ├── model.py ├── train.py └── trainial.py /BiERF-PSPNet/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | def conv(inputs,filters,kernel_size,strides=(1, 1),padding='SAME',dilation_rate=(1, 1),activation=tf.nn.relu,use_bias=True,regularizer=None,name=None,reuse=None): 3 | out=tf.layers.conv2d( 4 | inputs, 5 | filters=filters, 6 | kernel_size=kernel_size, 7 | strides=strides, 8 | padding=padding, 9 | dilation_rate=dilation_rate, 10 | activation=activation, 11 | use_bias=use_bias, 12 | kernel_regularizer=regularizer, 13 | bias_initializer=tf.zeros_initializer(), 14 | kernel_initializer= tf.random_normal_initializer(stddev=0.1), 15 | name=name, 16 | reuse=reuse) 17 | return out 18 | 19 | def batch(inputs,training=True,reuse=None,momentum=0.9,name='n'): 20 | out=tf.layers.batch_normalization(inputs,training=training,reuse=reuse,momentum=momentum,name=name) 21 | return out 22 | 23 | def downsample(x, n_filters, is_training,l2=None, name="down",momentum=0.9,reuse=None): 24 | with tf.variable_scope(name): 25 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 26 | n_filters_in = x.shape.as_list()[-1] 27 | n_filters_conv = n_filters - n_filters_in 28 | x=tf.concat([conv(x, n_filters_conv, kernel_size=[3, 3],activation=None,strides=2,name='conv',regularizer=reg,reuse=reuse),tf.layers.max_pooling2d(x,[2,2],padding='SAME',strides=2,name='pool')],-1) 29 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='batch')) 30 | return x 31 | 32 | def factorized_res_module(x, is_training, dropout=0.3, dilation=[1,1], l2=None, name="fres",reuse=None,momentum=0.9): 33 | with tf.variable_scope(name): 34 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 35 | n_filters = x.shape.as_list()[-1] 36 | y =conv(x,n_filters,kernel_size=[3,1],dilation_rate=dilation[0],name='conv_a_3x1',regularizer=reg,reuse=reuse) 37 | y =conv(y,n_filters,kernel_size=[1,3],dilation_rate=dilation[0],activation=None,name='conv_a_1x3',regularizer=reg,reuse=reuse) 38 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='batch1')) 39 | y = conv(y,n_filters,kernel_size=[3,1],dilation_rate=[dilation[1],1],name='conv_b_3x1',regularizer=reg,reuse=reuse) 40 | y = conv(y,n_filters,kernel_size=[1,3],dilation_rate=[1,dilation[1]],activation=None,name='conv_b_1x3',regularizer=reg,reuse=reuse) 41 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='batch2')) 42 | y=tf.layers.dropout(y,rate=dropout,training=is_training) 43 | y =tf.add(x,y,name='add') 44 | return y 45 | 46 | def Encoder(x, is_training,l2=None,reuse=None,momentum=0.9): 47 | #x = tf.div(x, 255., name="rescaled_inputs") 48 | net=downsample(x, 16, is_training=is_training,name="d1",momentum=momentum,reuse=reuse) 49 | net=downsample(net, 64, is_training=is_training,name="d2",momentum=momentum,reuse=reuse) 50 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres3",reuse=reuse,momentum=momentum) 51 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres4",reuse=reuse,momentum=momentum) 52 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 
1], l2=l2,name="fres5",reuse=reuse,momentum=momentum) 53 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres6",reuse=reuse,momentum=momentum) 54 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres7",reuse=reuse,momentum=momentum) 55 | net=downsample(net, 128, is_training=is_training,name="d8",momentum=momentum,reuse=reuse) 56 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 2], l2=l2,name="fres9",reuse=reuse,momentum=momentum) 57 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 4], l2=l2,name="fres10",reuse=reuse,momentum=momentum) 58 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 8], l2=l2,name="fres11",reuse=reuse,momentum=momentum) 59 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 16], l2=l2,name="fres12",reuse=reuse,momentum=momentum) 60 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 2], l2=l2,name="fres13",reuse=reuse,momentum=momentum) 61 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 4], l2=l2,name="fres14",reuse=reuse,momentum=momentum) 62 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 8], l2=l2,name="fres15",reuse=reuse,momentum=momentum) 63 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 16], l2=l2,name="fres16",reuse=reuse,momentum=momentum) 64 | return net 65 | 66 | def spatial(x,name='spatial',is_training=False,l2=None,reuse=None,momentum=0.9): 67 | with tf.variable_scope(name): 68 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 69 | x = conv(x,64,kernel_size=3,name='conv1',activation=None,regularizer=reg,reuse=reuse,strides=2) 70 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='batch1')) 71 | x = conv(x,128,kernel_size=3,name='conv2',activation=None,regularizer=reg,reuse=reuse,strides=2) 72 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='batch2')) 73 | x = conv(x,256,kernel_size=3,name='conv3',activation=None,regularizer=reg,reuse=reuse,strides=2) 74 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='batch3')) 75 | return x 76 | 77 | def Decoder(x,shape=[480,640],name='decoder',is_training=False,l2=None,reuse=None,momentum=0.9): 78 | p1=x 79 | p2=tf.layers.average_pooling2d(x,pool_size=[2,2],strides=2,padding='SAME',name='pool2') 80 | p3=tf.layers.average_pooling2d(x,pool_size=[4,4],strides=4,padding='SAME',name='pool3') 81 | p4=tf.layers.average_pooling2d(x,pool_size=[8,8],strides=8,padding='SAME',name='pool4') 82 | with tf.variable_scope(name): 83 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 84 | j1=conv(p1, 32, kernel_size=1,activation=None,name='conv1',regularizer=reg,reuse=reuse,use_bias=None) 85 | j1=tf.nn.relu(batch(j1,training=is_training,reuse=reuse,momentum=momentum,name='batch1')) 86 | j2=conv(p2, 32, kernel_size=1,activation=None,name='conv2',regularizer=reg,reuse=reuse,use_bias=None) 87 | j2=tf.nn.relu(batch(j2,training=is_training,reuse=reuse,momentum=momentum,name='batch2')) 88 | j3=conv(p3, 32, kernel_size=1,activation=None,name='conv3',regularizer=reg,reuse=reuse,use_bias=None) 89 | j3=tf.nn.relu(batch(j3,training=is_training,reuse=reuse,momentum=momentum,name='batch3')) 90 | j4=conv(p4, 32, 
kernel_size=1,activation=None,name='conv4',regularizer=reg,reuse=reuse,use_bias=None)
91 | j4=tf.nn.relu(batch(j4,training=is_training,reuse=reuse,momentum=momentum,name='batch4'))
92 | f2=tf.image.resize_images(j2, [shape[0]//8,shape[1]//8],method=0)
93 | f3=tf.image.resize_images(j3, [shape[0]//8,shape[1]//8],method=0)
94 | f4=tf.image.resize_images(j4, [shape[0]//8,shape[1]//8],method=0)
95 | net=tf.concat([p1,j1,f2,f3,f4],-1)
96 | net=conv(net, 256, kernel_size=3,activation=None,name='conv5',regularizer=reg,reuse=reuse,use_bias=None)
97 | net=tf.nn.relu(batch(net,training=is_training,reuse=reuse,momentum=momentum,name='batch5'))
98 | return net
99 |
100 | def FeatureFusionModule(input_1, input_2, numclasses,name='fusion',is_training=False,l2=None,reuse=None,shape=[360,720],momentum=0.9):
101 | inputs = tf.concat([input_1, input_2], axis=-1)
102 | with tf.variable_scope(name):
103 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2)
104 | inputs = conv(inputs, 256, kernel_size=3,activation=None,name='conv1',regularizer=reg,reuse=reuse)
105 | inputs=tf.nn.relu(batch(inputs,training=is_training,reuse=reuse,momentum=momentum,name='batch1'))
106 | inputs = conv(inputs, numclasses, kernel_size=3,activation=None,name='conv2',regularizer=reg,reuse=reuse)
107 | inputs=tf.nn.relu(batch(inputs,training=is_training,reuse=reuse,momentum=momentum,name='batch2'))
108 | # Global average pooling
109 | net = tf.reduce_mean(inputs, [1, 2], keepdims=True)
110 | net = conv(net, numclasses, kernel_size=1,activation=None,name='conv3',regularizer=reg,reuse=reuse)
111 | net = tf.nn.relu(batch(net,training=is_training,reuse=reuse,momentum=momentum,name='batch3'))
112 | net = conv(net, numclasses, kernel_size=1,activation=None,name='conv4',regularizer=reg,reuse=reuse)
113 | net = tf.sigmoid(net)
114 | net = tf.multiply(inputs, net)
115 | net = tf.add(inputs, net)
116 | net = conv(net, numclasses, kernel_size=1,activation=None,name='conv5',regularizer=reg,reuse=reuse)#added at the end
117 | net = tf.nn.relu(batch(net,training=is_training,reuse=reuse,momentum=momentum,name='batch4'))#added at the end
118 | net = tf.image.resize_images(net, [shape[0],shape[1]],method=0)
119 | #net = conv(net, numclasses, kernel_size=1,activation=None,name='conv5',regularizer=reg,reuse=reuse)
120 | return net
121 |
122 | def erfpspcontext(x1,x2,l2,shape=[480,640],shape2=[1024,2048],numclasses=66,reuse=None,is_training=True,momentum=0.9):
123 | x1 = Encoder(x1,is_training=is_training,l2=l2,reuse=reuse,momentum=momentum)
124 | x1 = Decoder(x1,shape=shape,is_training=is_training,l2=l2,reuse=reuse,momentum=momentum)
125 | x1 = tf.image.resize_images(x1, [shape2[0]//8,shape2[1]//8],method=0)
126 | x2 = spatial(x2,is_training=is_training,l2=l2,reuse=reuse,momentum=momentum)
127 | y =FeatureFusionModule(x1, x2, numclasses,is_training=is_training,l2=l2,reuse=reuse,shape=shape,momentum=momentum)
128 | probabilities = tf.nn.softmax(y, name='logits_to_softmax')
129 | return y,probabilities
130 |
131 |
132 |
133 |
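A minimal sketch of how this two-branch graph is wired up, mirroring the `model.erfpspcontext` call in `/BiERF-PSPNet/train.py` below (the placeholder shapes and class count are illustrative assumptions, not fixed requirements):

```python
import tensorflow as tf
import model  # this repository's BiERF-PSPNet/model.py

# Context branch sees a downscaled frame; spatial branch sees the full-resolution frame.
small = tf.placeholder(tf.float32, [None, 360, 720, 3])
full = tf.placeholder(tf.float32, [None, 1024, 2048, 3])
# Encoder/Decoder run on `small`, spatial() downsamples `full` by 8, and
# FeatureFusionModule merges both and resizes the logits back to `shape`.
logits, probs = model.erfpspcontext(small, full, l2=None,
                                    shape=[360, 720], shape2=[1024, 2048],
                                    numclasses=19, is_training=False)
```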
-------------------------------------------------------------------------------- /BiERF-PSPNet/train.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.platform import tf_logging as logging
3 | import os
4 | import time
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import model
8 | import random
9 | import csv
10 | import math
11 |
12 | #==============INPUT ARGUMENTS==================
13 | flags = tf.app.flags
14 |
15 | #Directory arguments
16 | flags.DEFINE_string('dataset_dir', './dataset', 'The dataset directory to find the train, validation and test images.')
17 | flags.DEFINE_string('logdir', './log/contextnetchangeforfinalsss', 'The log directory to save your checkpoint and event files.')
18 | #Training arguments
19 | flags.DEFINE_integer('num_classes', 19, 'The number of classes to predict.')
20 | flags.DEFINE_integer('batch_size', 6, 'The batch_size for training.')
21 | flags.DEFINE_integer('eval_batch_size', 12, 'The batch size used for validation.')
22 | flags.DEFINE_integer('image_height',360, "The input height of the images.")
23 | flags.DEFINE_integer('image_width', 720, "The input width of the images.")
24 | flags.DEFINE_integer('num_epochs', 300, "The number of epochs to train your model.")
25 | flags.DEFINE_integer('num_epochs_before_decay', 100, 'The number of epochs before decaying your learning rate.')
26 | flags.DEFINE_float('weight_decay', 2e-4, "The weight decay for ENet convolution layers.")
27 | flags.DEFINE_float('learning_rate_decay_factor', 1e-1, 'The learning rate decay factor.')
28 | flags.DEFINE_float('initial_learning_rate', 1e-3, 'The initial learning rate for your training.')
29 | flags.DEFINE_boolean('Start_train',True, "Whether to train from scratch (True) or restore a checkpoint (False).")
30 | FLAGS = flags.FLAGS
31 |
32 | Start_train = FLAGS.Start_train
33 | log_name = 'model.ckpt'
34 |
35 | num_classes = FLAGS.num_classes
36 | batch_size = FLAGS.batch_size
37 | eval_batch_size = FLAGS.eval_batch_size
38 | image_height = FLAGS.image_height
39 | image_width = FLAGS.image_width
40 |
41 | #Training parameters
42 | initial_learning_rate = FLAGS.initial_learning_rate
43 | num_epochs_before_decay = FLAGS.num_epochs_before_decay
44 | num_epochs =FLAGS.num_epochs
45 | learning_rate_decay_factor = FLAGS.learning_rate_decay_factor
46 | weight_decay = FLAGS.weight_decay
47 | epsilon = 1e-8
48 |
49 |
50 | #Directories
51 | dataset_dir = FLAGS.dataset_dir
52 | logdir = FLAGS.logdir
53 |
54 | #===============PREPARATION FOR TRAINING==================
55 | #Get the images into a list
56 | image_files = sorted([os.path.join(dataset_dir, 'train', file) for file in os.listdir(dataset_dir + "/train") if file.endswith('.png')])
57 | annotation_files = sorted([os.path.join(dataset_dir, "trainannot", file) for file in os.listdir(dataset_dir + "/trainannot") if file.endswith('.png')])
58 | image_val_files = sorted([os.path.join(dataset_dir, 'val', file) for file in os.listdir(dataset_dir + "/val") if file.endswith('.png')])
59 | annotation_val_files = sorted([os.path.join(dataset_dir, "valannot", file) for file in os.listdir(dataset_dir + "/valannot") if file.endswith('.png')])
60 | #Save the per-class metrics to a CSV (Excel-readable) file
61 | csvname=logdir[6:]+'.csv'
62 |
63 | with open(csvname,'a', newline='') as out:
64 | csv_write = csv.writer(out,dialect='excel')
65 | a=[str(i) for i in range(num_classes)]
66 | csv_write.writerow(a)
67 |
68 | #Know the number steps to take before decaying the learning rate and batches per epoch
69 | num_batches_per_epoch = math.ceil(len(image_files) / batch_size)
70 | num_steps_per_epoch = num_batches_per_epoch
71 | decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)
72 |
73 | #=================CLASS WEIGHTS===============================
74 | #Median frequency balancing class_weights
75 |
76 | class_weights=np.array([8.6979065, 8.497886 , 8.741297 , 5.983605 , 8.662319 , 8.681756 ,
77 | 8.683093 , 8.763641 , 8.576978 , 2.7114885, 6.237076 , 3.582358 ,
78 | 8.439253 , 8.316548 , 8.129169 , 4.312109 , 8.170293 , 6.91469 ,
79 |
8.135018 ], dtype=np.float32) 80 | 81 | class_weights1=np.array([0, 0 , 0 ,0 , 0 , 0 , 82 | 0 , 0 , 0 ,2.7114885, 0 , 3.582358 , 83 | 8.439253 , 0 ,0 , 4.312109 , 8.170293 , 6.91469 , 84 | 0 , 0. ], dtype=np.float32) 85 | class_weights2=np.array([0, 0 , 0 , 5.983605 , 0 , 0 , 86 | 0 , 0, 0 , 0, 6.237076 , 0 , 87 | 0, 8.316548 , 8.129169 , 0 , 0 , 0 , 88 | 8.135018 , 0. ], dtype=np.float32) 89 | class_weights3=np.array([8.6979065, 8.497886 , 8.741297 , 0 , 8.662319 , 8.681756 , 90 | 8.683093 , 8.763641 , 8.576978 , 0, 0 , 0 , 91 | 0 , 0 , 0 ,0 , 0 , 0 , 92 | 0, 0. ], dtype=np.float32) 93 | 94 | def weighted_cross_entropy(onehot_labels, logits, class_weights): 95 | a=tf.reduce_sum(-tf.log(tf.clip_by_value(logits, 1e-10, 1.0))*onehot_labels*class_weights) 96 | return a 97 | 98 | 99 | 100 | def decode(a,b): 101 | a = tf.read_file(a) 102 | a=tf.image.decode_png(a, channels=3) 103 | a = tf.image.convert_image_dtype(a, dtype=tf.float32) 104 | b = tf.read_file(b) 105 | b = tf.image.decode_png(b,channels=1) 106 | c=a 107 | a=tf.image.resize_images(a, [image_height,image_width],method=0) 108 | b=tf.image.resize_images(b, [image_height,image_width],method=1) 109 | a.set_shape(shape=(image_height, image_width, 3)) 110 | b.set_shape(shape=(image_height, image_width,1)) 111 | c.set_shape(shape=(1024, 2048, 3)) 112 | return a,b,c 113 | def decodev(a,b): 114 | a = tf.read_file(a) 115 | a=tf.image.decode_png(a, channels=3) 116 | a = tf.image.convert_image_dtype(a, dtype=tf.float32) 117 | b = tf.read_file(b) 118 | b = tf.image.decode_png(b,channels=1) 119 | c = a 120 | a=tf.image.resize_images(a, [image_height,image_width],method=0) 121 | b=tf.image.resize_images(b, [image_height,image_width],method=1) 122 | a.set_shape(shape=(image_height, image_width, 3)) 123 | b.set_shape(shape=(image_height, image_width,1)) 124 | c.set_shape(shape=(1024, 2048, 3)) 125 | return a,b,c 126 | def run(): 127 | with tf.Graph().as_default() as graph: 128 | tf.logging.set_verbosity(tf.logging.INFO) 129 | 130 | #===================TRAINING BRANCH======================= 131 | #Load the files into one input queue 132 | images = tf.convert_to_tensor(image_files) 133 | annotations = tf.convert_to_tensor(annotation_files) 134 | tdataset = tf.data.Dataset.from_tensor_slices((images,annotations)) 135 | tdataset = tdataset.map(decode) 136 | tdataset = tdataset.shuffle(100).batch(batch_size).repeat(num_epochs) 137 | titerator = tdataset.make_initializable_iterator() 138 | images,annotations,bigimages = titerator.get_next() 139 | 140 | 141 | images_val = tf.convert_to_tensor(image_val_files) 142 | annotations_val = tf.convert_to_tensor(annotation_val_files) 143 | vdataset = tf.data.Dataset.from_tensor_slices((images_val,annotations_val)) 144 | vdataset = vdataset.map(decodev) 145 | vdataset = vdataset.batch(eval_batch_size).repeat(num_epochs*3) 146 | viterator = vdataset.make_initializable_iterator() 147 | images_val,annotations_val,bigimages_val = viterator.get_next() 148 | 149 | 150 | 151 | 152 | 153 | #perform one-hot-encoding on the ground truth annotation to get same shape as the logits 154 | _, probabilities= model.erfpspcontext(images,bigimages,numclasses=num_classes, shape=[image_height,image_width], shape2=[1024,2048],l2=weight_decay,reuse=None,is_training=True) 155 | annotations = tf.reshape(annotations, shape=[-1, image_height, image_width]) 156 | raw_gt = tf.reshape(annotations, [-1,]) 157 | indices = tf.squeeze(tf.where(tf.less_equal(raw_gt,num_classes-1)), 1) 158 | gt = tf.cast(tf.gather(raw_gt, indices), tf.int32) 159 | gt 
= gt - 1
160 | gt_one = tf.one_hot(gt, num_classes, axis=-1)
161 | raw_prediction = tf.reshape(probabilities, [-1, num_classes])
162 | prediction = tf.gather(raw_prediction, indices)
163 |
164 |
165 |
166 |
167 | annotations_ohe = tf.one_hot(annotations, num_classes, axis=-1)
168 | MASK = tf.reduce_sum(1-annotations_ohe[:,:,:,0])
169 |
170 | los=weighted_cross_entropy(gt_one, prediction, class_weights)/MASK
171 | loss=tf.losses.add_loss(los)
172 | total_loss = tf.losses.get_total_loss()
173 | global_step = tf.train.get_or_create_global_step()
174 | #Define your exponentially decaying learning rate
175 | lr = tf.train.exponential_decay(
176 | learning_rate = initial_learning_rate,
177 | global_step = global_step,
178 | decay_steps = decay_steps,
179 | decay_rate = learning_rate_decay_factor,
180 | staircase = True)
181 |
182 | #Now we can define the optimizer that takes on the learning rate
183 | optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon)
184 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
185 | updates_op = tf.group(*update_ops)
186 | #Create the train_op.
187 | with tf.control_dependencies([updates_op]):
188 | train_op = optimizer.minimize(total_loss,global_step=global_step)
189 | #State the metrics that you want to predict. We get predictions that are not one-hot encoded.
190 | #Validation branch
191 | _, probabilities_val= model.erfpspcontext(images_val,bigimages_val,numclasses=num_classes, shape=[image_height,image_width], shape2=[1024,2048],l2=None,reuse=True,is_training=None)
192 | raw_gt_v = tf.reshape(annotations_val, [-1,])
193 | indices_v = tf.squeeze(tf.where(tf.greater(raw_gt_v,0)), 1)
194 | gt_v = tf.cast(tf.gather(raw_gt_v, indices_v), tf.int32)
195 | gt_v = gt_v-1
196 | gt_one_v = tf.one_hot(gt_v, num_classes, axis=-1)
197 | raw_prediction_v = tf.argmax(tf.reshape(probabilities_val, [-1, num_classes]),-1)
198 | prediction_v = tf.gather(raw_prediction_v, indices_v)
199 | prediction_ohe_v = tf.one_hot(prediction_v, num_classes, axis=-1)
200 | and_val=gt_one_v*prediction_ohe_v
201 | and_sum=tf.reduce_sum(and_val,[0])
202 | or_val=tf.to_int32((gt_one_v+prediction_ohe_v)>0.5)
203 | or_sum=tf.reduce_sum(or_val,axis=[0])
204 | T_sum=tf.reduce_sum(gt_one_v,axis=[0])
205 | R_sum = tf.reduce_sum(prediction_ohe_v,axis=[0])
206 | mPrecision=0
207 | mRecall_rate=0
208 | mIoU=0
209 | #Now we need to create a training step function that runs both the train_op, metrics_op and updates the global_step concurrently.
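For reference, the four accumulators above (`and_sum`, `or_sum`, `T_sum`, `R_sum`) are all that `eval()` further down needs; the same arithmetic in NumPy, with made-up counts:

```python
import numpy as np
# Per-class sums accumulated over all validation batches (illustrative numbers):
and_ = np.array([80., 40.])   # intersection: gt_onehot * pred_onehot
or_ = np.array([100., 70.])   # union: (gt_onehot + pred_onehot) > 0.5
T_ = np.array([90., 60.])     # ground-truth pixel count per class
R_ = np.array([85., 50.])     # predicted pixel count per class
IoU = and_ / or_              # per-class IoU
Recall_rate = and_ / T_       # per-class recall
Precision = and_ / R_         # per-class precision
mIoU = IoU.mean()             # eval() averages over classes the same way
```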
210 | def train_step(sess, train_op, global_step ,loss=total_loss):
211 | '''
212 | Runs a session for the three arguments provided and logs the time elapsed for each global step
213 | '''
214 | #Check the time for each sess run
215 | start_time = time.time()
216 | _,total_loss, global_step_count= sess.run([train_op,loss, global_step ])
217 | time_elapsed = time.time() - start_time
218 | global_step_count=global_step_count+1
219 | #Run the logging to show some results
220 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
221 |
222 | return total_loss
223 | A = tf.Variable(tf.constant(0.0), dtype=tf.float32)
224 | a=tf.placeholder(shape=[],dtype=tf.float32)
225 | Precision=tf.assign(A, a)
226 | B = tf.Variable(tf.constant(0.0), dtype=tf.float32)
227 | b=tf.placeholder(shape=[],dtype=tf.float32)
228 | Recall=tf.assign(B, b)
229 | C = tf.Variable(tf.constant(0.0), dtype=tf.float32)
230 | c=tf.placeholder(shape=[],dtype=tf.float32)
231 | mIOU=tf.assign(C, c)
232 | predictions = tf.argmax(probabilities, -1)
233 | segmentation_output = tf.cast(tf.reshape((predictions+1)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8)
234 | segmentation_ground_truth = tf.cast(tf.reshape(tf.cast(annotations, dtype=tf.float32)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8)
235 | tf.summary.scalar('Monitor/Total_Loss', total_loss)
236 | tf.summary.scalar('Monitor/Precision', Precision)
237 | tf.summary.scalar('Monitor/Recall_rate', Recall)
238 | tf.summary.scalar('Monitor/mIoU', mIOU)
239 | tf.summary.scalar('Monitor/learning_rate', lr)
240 | tf.summary.image('Images/original_image', bigimages, max_outputs=1)
241 | tf.summary.image('Images/segmentation_output', segmentation_output, max_outputs=1)
242 | tf.summary.image('Images/segmentation_ground_truth', segmentation_ground_truth, max_outputs=1)
243 | my_summary_op = tf.summary.merge_all()
244 |
245 | def train_sum(sess, train_op, global_step,sums,loss=total_loss,pre=0,recall=0,iou=0):
246 | start_time = time.time()
247 | _,total_loss, global_step_count,ss = sess.run([train_op,loss, global_step,sums ],feed_dict={a:pre,b:recall,c:iou})
248 | time_elapsed = time.time() - start_time
249 | global_step_count=global_step_count+1
250 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
251 |
252 | return total_loss,ss
253 |
254 | def eval_step(sess,i ):
255 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = sess.run([and_sum,or_sum,T_sum,R_sum])
256 | #Log some information
257 | logging.info('STEP: %d ',i)
258 | return and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch
259 | def eval(num_class,csvname,session,image_val,eval_batch):
260 | or_=np.zeros((num_class), dtype=np.float32)
261 | and_=np.zeros((num_class), dtype=np.float32)
262 | T_=np.zeros((num_class), dtype=np.float32)
263 | R_=np.zeros((num_class), dtype=np.float32)
264 | for i in range(math.ceil(len(image_val) / eval_batch)):
265 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = eval_step(session,i+1)
266 | and_=and_+and_eval_batch
267 | or_=or_+or_eval_batch
268 | T_=T_+T_eval_batch
269 | R_=R_+R_eval_batch
270 | Recall_rate=and_/T_
271 | Precision=and_/R_
272 | IoU=and_/or_
273 | mPrecision=np.mean(Precision)
274 | mRecall_rate=np.mean(Recall_rate)
275 | mIoU=np.mean(IoU)
276 | print("Precision:")
277 | print(Precision)
278 | print("Recall rate:")
279 | print(Recall_rate)
280 | print("IoU:")
281 | print(IoU)
282 | print("mPrecision:")
283 | print(mPrecision)
284 | print("mRecall_rate:")
285 | print(mRecall_rate)
286 | print("mIoU")
287 | print(mIoU)
288 | with open(csvname,'a', newline='') as out:
289 | csv_write = csv.writer(out,dialect='excel')
290 | csv_write.writerow(Precision)
291 | csv_write.writerow(Recall_rate)
292 | csv_write.writerow(IoU)
293 | return mPrecision,mRecall_rate,mIoU
294 | gpu_options = tf.GPUOptions(allow_growth=True)
295 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
296 | init = tf.global_variables_initializer()
297 | saver=tf.train.Saver(max_to_keep=10)
298 | with tf.Session(config=config) as sess:
299 | sess.run(init)
300 | sess.run([titerator.initializer,viterator.initializer])
301 | step = 0
302 | if not Start_train:
303 | #Input the checkpoint address and the step number.
304 | checkpoint='./log/bierfpsp/model.ckpt-37127'
305 | saver.restore(sess, checkpoint)
306 | step = 37127
307 | sess.run(tf.assign(global_step,step))
308 | summary_writer = tf.summary.FileWriter(logdir, sess.graph)
309 | final = num_steps_per_epoch * num_epochs
310 | for i in range(step,final,1):
311 | if i % num_batches_per_epoch == 0:
312 | logging.info('Epoch %s/%s', i/num_batches_per_epoch + 1, num_epochs)
313 | learning_rate_value = sess.run([lr])
314 | logging.info('Current Learning Rate: %s', learning_rate_value)
315 | if i != step:
316 | saver.save(sess, os.path.join(logdir,log_name),global_step=i)
317 | mPrecision,mRecall_rate,mIoU=eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size)
318 | if i % min(num_steps_per_epoch, 10) == 0:
319 | loss,summaries = train_sum(sess, train_op,global_step,sums=my_summary_op,loss=total_loss,pre=mPrecision,recall=mRecall_rate,iou=mIoU)
320 | summary_writer.add_summary(summaries,global_step=i+1)
321 | else:
322 | loss = train_step(sess, train_op, global_step)
323 | summary_writer.close()
324 | eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size)
325 | logging.info('Final Loss: %s', loss)
326 | logging.info('Finished training! Saving model to disk now.')
327 | saver.save(sess, os.path.join(logdir,log_name), global_step = final)
328 |
329 |
330 | if __name__ == '__main__':
331 | run()
332 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # ERF-PSPNET
2 | ERF-PSPNET implemented in TensorFlow
3 | ## Paper
4 | The code is implemented according to the following papers.
5 | + [Unifying Terrain Awareness for the Visually Impaired through Real-Time Semantic Segmentation](https://www.mdpi.com/1424-8220/18/5/1506 )
6 | + [Unifying terrain awareness through real-time semantic segmentation](http://www.wangkaiwei.org/file/publications/iv2018_kailun.pdf )
7 |
8 | ## Main Dependencies
9 | ```
10 | tensorflow 1.11
11 | OpenCV
12 | Python 3.6.5
13 | ```
14 |
15 | ## Description
16 | This repository provides real-time semantic segmentation networks designed to assist visually impaired people. The code not only implements a TensorFlow version of ERF-PSPNet, but also includes the mIoU-calculation code that has been missing from other TensorFlow implementations. Training and evaluation are combined, so every class's IoU is recorded after each epoch, giving visual supervision during training. In addition, we adopt ***IAL***, a loss function that pushes the model to focus on important classes, and we implement BiERF-PSPNet, which is inspired by BiSeNet.
17 |
18 | ## Usage
19 | Usage is simple: download the code, create a folder named ***dataset***, create four sub-folders under it named ***train, trainannot, val, valannot***, and put the data you want to train or evaluate on into them, following the example in this repository.
20 |
21 | ## Note
22 | In the future, I will update the demo code for video, an improved version of the network, and loss functions like ***IAL*** for specific tasks.
23 |
-------------------------------------------------------------------------------- /erf-pspnet/dataset/train/0001TP_006690.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Katexiang/ERF-PSPNET/e8dd8f6f0bf092b31a618a0b518c51f1a8bd4897/erf-pspnet/dataset/train/0001TP_006690.png
-------------------------------------------------------------------------------- /erf-pspnet/dataset/trainannot/0001TP_006690.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Katexiang/ERF-PSPNET/e8dd8f6f0bf092b31a618a0b518c51f1a8bd4897/erf-pspnet/dataset/trainannot/0001TP_006690.png
-------------------------------------------------------------------------------- /erf-pspnet/dataset/val/0001TP_008550.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Katexiang/ERF-PSPNET/e8dd8f6f0bf092b31a618a0b518c51f1a8bd4897/erf-pspnet/dataset/val/0001TP_008550.png
-------------------------------------------------------------------------------- /erf-pspnet/dataset/valannot/0001TP_008550.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Katexiang/ERF-PSPNET/e8dd8f6f0bf092b31a618a0b518c51f1a8bd4897/erf-pspnet/dataset/valannot/0001TP_008550.png
-------------------------------------------------------------------------------- /erf-pspnet/model.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import math
3 | def conv(inputs,filters,kernel_size,strides=(1, 1),padding='SAME',dilation_rate=(1, 1),activation=tf.nn.relu,use_bias=True,regularizer=None,name=None,reuse=None):
4 | out=tf.layers.conv2d(
5 | inputs,
6 | filters=filters,
7 | kernel_size=kernel_size,
8 | strides=strides,
9 | padding=padding,
10 | dilation_rate=dilation_rate,
11 | activation=activation,
12 | use_bias=use_bias,
13 | kernel_regularizer=regularizer,
14 | bias_initializer=tf.zeros_initializer(),
15 | kernel_initializer= tf.random_normal_initializer(stddev=0.1),
16 | name=name,
17 | reuse=reuse)
18 | return out
19 |
20 | def batch(inputs,training=True,reuse=None,momentum=0.9,name='n'):
21 | out=tf.layers.batch_normalization(inputs,training=training,reuse=reuse,momentum=momentum,name=name)
22 | return out
23 |
24 | def downsample(x, n_filters, is_training,l2=None, name="down",momentum=0.9,reuse=None):
25 | with tf.variable_scope(name):
26 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2)
27 | n_filters_in = x.shape.as_list()[-1]
28 | n_filters_conv = n_filters - n_filters_in
29 | x=tf.concat([conv(x, n_filters_conv, kernel_size=[3, 3],activation=None,strides=2,name='conv',regularizer=reg,reuse=reuse),tf.layers.max_pooling2d(x,[2,2],padding='SAME',strides=2,name='pool')],-1)
30
| x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='batch')) 31 | return x 32 | 33 | def factorized_res_module(x, is_training, dropout=0.3, dilation=[1,1], l2=None, name="fres",reuse=None,momentum=0.9): 34 | with tf.variable_scope(name): 35 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 36 | n_filters = x.shape.as_list()[-1] 37 | y =conv(x,n_filters,kernel_size=[3,1],dilation_rate=dilation[0],name='conv_a_3x1',regularizer=reg,reuse=reuse) 38 | y =conv(y,n_filters,kernel_size=[1,3],dilation_rate=dilation[0],activation=None,name='conv_a_1x3',regularizer=reg,reuse=reuse) 39 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='batch1')) 40 | y = conv(y,n_filters,kernel_size=[3,1],dilation_rate=[dilation[1],1],name='conv_b_3x1',regularizer=reg,reuse=reuse) 41 | y = conv(y,n_filters,kernel_size=[1,3],dilation_rate=[1,dilation[1]],activation=None,name='conv_b_1x3',regularizer=reg,reuse=reuse) 42 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='batch2')) 43 | y=tf.layers.dropout(y,rate=dropout,training=is_training) 44 | y =tf.add(x,y,name='add') 45 | return y 46 | 47 | def Encoder(x, is_training,l2=None,reuse=None,momentum=0.9): 48 | #x = tf.div(x, 255., name="rescaled_inputs") 49 | net=downsample(x, 16, is_training=is_training,name="d1",l2=l2,momentum=momentum,reuse=reuse) 50 | net=downsample(net, 64, is_training=is_training,name="d2",l2=l2,momentum=momentum,reuse=reuse) 51 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres3",reuse=reuse,momentum=momentum) 52 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres4",reuse=reuse,momentum=momentum) 53 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres5",reuse=reuse,momentum=momentum) 54 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres6",reuse=reuse,momentum=momentum) 55 | net = factorized_res_module(net, is_training=is_training, dropout=0.03, dilation=[1, 1], l2=l2,name="fres7",reuse=reuse,momentum=momentum) 56 | net=downsample(net, 128, is_training=is_training,name="d8",momentum=momentum,reuse=reuse) 57 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 2], l2=l2,name="fres9",reuse=reuse,momentum=momentum) 58 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 4], l2=l2,name="fres10",reuse=reuse,momentum=momentum) 59 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 8], l2=l2,name="fres11",reuse=reuse,momentum=momentum) 60 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 16], l2=l2,name="fres12",reuse=reuse,momentum=momentum) 61 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 2], l2=l2,name="fres13",reuse=reuse,momentum=momentum) 62 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 4], l2=l2,name="fres14",reuse=reuse,momentum=momentum) 63 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 8], l2=l2,name="fres15",reuse=reuse,momentum=momentum) 64 | net = factorized_res_module(net, is_training=is_training, dropout=0.3, dilation=[1, 16], l2=l2,name="fres16",reuse=reuse,momentum=momentum) 65 | return net 66 | 67 | def 
Decoder(x,numclasses,shape=[480,640],name='decoder',is_training=False,l2=None,reuse=None,momentum=0.9): 68 | height=shape[0]//8 69 | weight=shape[1]//8 70 | p1=x 71 | h=math.floor(height/2) 72 | w=math.floor(weight/2) 73 | kh=height-(2-1) * h 74 | kw=weight-(2-1) * w 75 | p2=tf.nn.avg_pool(x,[1,kh,kw,1],[1,h,w,1],padding='VALID') 76 | h=math.floor(height/4) 77 | w=math.floor(weight/4) 78 | kh=height-(4-1) * h 79 | kw=weight-(4-1) * w 80 | p3=tf.nn.avg_pool(x,[1,kh,kw,1],[1,h,w,1],padding='VALID') 81 | h=math.floor(height/8) 82 | w=math.floor(weight/8) 83 | kh=height-(8-1) * h 84 | kw=weight-(8-1) * w 85 | p4=tf.nn.avg_pool(x,[1,kh,kw,1],[1,h,w,1],padding='VALID') 86 | with tf.variable_scope(name): 87 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 88 | j1=conv(p1, 32, kernel_size=1,activation=None,name='conv1',regularizer=reg,reuse=reuse,use_bias=None) 89 | j1=tf.nn.relu(batch(j1,training=is_training,reuse=reuse,momentum=momentum,name='batch1')) 90 | j2=conv(p2, 32, kernel_size=1,activation=None,name='conv2',regularizer=reg,reuse=reuse,use_bias=None) 91 | j2=tf.nn.relu(batch(j2,training=is_training,reuse=reuse,momentum=momentum,name='batch2')) 92 | j3=conv(p3, 32, kernel_size=1,activation=None,name='conv3',regularizer=reg,reuse=reuse,use_bias=None) 93 | j3=tf.nn.relu(batch(j3,training=is_training,reuse=reuse,momentum=momentum,name='batch3')) 94 | j4=conv(p4, 32, kernel_size=1,activation=None,name='conv4',regularizer=reg,reuse=reuse,use_bias=None) 95 | j4=tf.nn.relu(batch(j4,training=is_training,reuse=reuse,momentum=momentum,name='batch4')) 96 | f2=tf.image.resize_images(j2, [shape[0]//8,shape[1]//8],method=0) 97 | f3=tf.image.resize_images(j3, [shape[0]//8,shape[1]//8],method=0) 98 | f4=tf.image.resize_images(j4, [shape[0]//8,shape[1]//8],method=0) 99 | net=tf.concat([p1,j1,f2,f3,f4],-1) 100 | net=conv(net, 256, kernel_size=3,activation=None,name='conv5',regularizer=reg,reuse=reuse,use_bias=None) 101 | net=tf.nn.relu(batch(net,training=is_training,reuse=reuse,momentum=momentum,name='batch5')) 102 | net=tf.layers.dropout(net,rate=0.1,training=is_training) 103 | net=conv(net, numclasses, kernel_size=1,activation=None,name='conv6',regularizer=reg,reuse=reuse,use_bias=None) 104 | final=tf.image.resize_images(net, [shape[0],shape[1]],method=0) 105 | probabilities = tf.nn.softmax(final, name='logits_to_softmax') 106 | return final,probabilities 107 | 108 | def erfpsp(x,l2,shape=[480,640],numclasses=66,reuse=None,is_training=True,momentum=0.9): 109 | x=Encoder(x, is_training=is_training,l2=l2,reuse=reuse,momentum=momentum) 110 | x=Decoder(x,numclasses,shape=shape,name='decoder',is_training=is_training,l2=l2,reuse=reuse,momentum=momentum) 111 | return x -------------------------------------------------------------------------------- /erf-pspnet/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.platform import tf_logging as logging 3 | import os 4 | import time 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import model 8 | import random 9 | import csv 10 | import math 11 | 12 | #==============INPUT ARGUMENTS================== 13 | flags = tf.app.flags 14 | 15 | #Directory arguments 16 | flags.DEFINE_string('dataset_dir', './dataset', 'The dataset directory to find the train, validation and test images.') 17 | flags.DEFINE_string('logdir', './log/camvid', 'The log directory to save your checkpoint and event files.') 18 | #Training arguments 19 | 
flags.DEFINE_integer('num_classes', 19, 'The number of classes to predict.')
20 | flags.DEFINE_integer('batch_size', 8, 'The batch_size for training.')
21 | flags.DEFINE_integer('eval_batch_size', 24, 'The batch size used for validation.')
22 | flags.DEFINE_integer('image_height',360, "The input height of the images.")
23 | flags.DEFINE_integer('image_width', 720, "The input width of the images.")
24 | flags.DEFINE_integer('num_epochs', 300, "The number of epochs to train your model.")
25 | flags.DEFINE_integer('num_epochs_before_decay', 100, 'The number of epochs before decaying your learning rate.')
26 | flags.DEFINE_float('weight_decay', 2e-4, "The weight decay for ENet convolution layers.")
27 | flags.DEFINE_float('learning_rate_decay_factor', 1e-1, 'The learning rate decay factor.')
28 | flags.DEFINE_float('initial_learning_rate', 1e-3, 'The initial learning rate for your training.')
29 | flags.DEFINE_boolean('Start_train',True, "Whether to train from scratch (True) or restore a checkpoint (False).")
30 |
31 | FLAGS = flags.FLAGS
32 |
33 | Start_train = FLAGS.Start_train
34 | log_name = 'model.ckpt'
35 |
36 | num_classes = FLAGS.num_classes
37 | batch_size = FLAGS.batch_size
38 | eval_batch_size = FLAGS.eval_batch_size
39 | image_height = FLAGS.image_height
40 | image_width = FLAGS.image_width
41 |
42 | #Training parameters
43 | initial_learning_rate = FLAGS.initial_learning_rate
44 | num_epochs_before_decay = FLAGS.num_epochs_before_decay
45 | num_epochs =FLAGS.num_epochs
46 | learning_rate_decay_factor = FLAGS.learning_rate_decay_factor
47 | weight_decay = FLAGS.weight_decay
48 | epsilon = 1e-8
49 |
50 |
51 | #Directories
52 | dataset_dir = FLAGS.dataset_dir
53 | logdir = FLAGS.logdir
54 |
55 | #===============PREPARATION FOR TRAINING==================
56 | #Get the images into a list
57 | image_files = sorted([os.path.join(dataset_dir, 'train', file) for file in os.listdir(dataset_dir + "/train") if file.endswith('.png')])
58 | annotation_files = sorted([os.path.join(dataset_dir, "trainannot", file) for file in os.listdir(dataset_dir + "/trainannot") if file.endswith('.png')])
59 | image_val_files = sorted([os.path.join(dataset_dir, 'val', file) for file in os.listdir(dataset_dir + "/val") if file.endswith('.png')])
60 | annotation_val_files = sorted([os.path.join(dataset_dir, "valannot", file) for file in os.listdir(dataset_dir + "/valannot") if file.endswith('.png')])
61 | #Save the per-class metrics to a CSV (Excel-readable) file
62 | csvname=logdir[6:]+'.csv'
63 | with open(csvname,'a', newline='') as out:
64 | csv_write = csv.writer(out,dialect='excel')
65 | a=[str(i) for i in range(num_classes)]
66 | csv_write.writerow(a)
67 | #Know the number steps to take before decaying the learning rate and batches per epoch
68 | num_batches_per_epoch = math.ceil(len(image_files) / batch_size)
69 | num_steps_per_epoch = num_batches_per_epoch
70 | decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)
71 |
72 | #=================CLASS WEIGHTS===============================
73 | #Median frequency balancing class_weights
74 |
75 | class_weights = np.array( [40.69042899, 47.6765088 , 12.70029695, 45.20543212, 45.78372173,
76 | 45.82527748, 48.40614895, 42.75593537, 3.36208549, 14.03151966,
77 | 4.9866471 , 39.25440643, 36.51259517, 32.81231979, 6.69824427,
78 | 33.55546509, 18.48781934, 32.97432129, 46.28665742],dtype=np.float32)
79 |
80 | def weighted_cross_entropy(onehot_labels, logits, class_weights):
81 | #a=tf.reduce_sum(-tf.log(tf.clip_by_value(logits, 1e-10, 1.0))*(1-logits)*(1-logits)*onehot_labels*class_weights)
82 |
a=tf.reduce_sum(-tf.log(tf.clip_by_value(logits, 1e-10, 1.0))*onehot_labels*class_weights) 83 | return a 84 | 85 | 86 | def decode(a,b): 87 | a = tf.read_file(a) 88 | a=tf.image.decode_png(a, channels=3) 89 | a = tf.image.convert_image_dtype(a, dtype=tf.float32) 90 | b = tf.read_file(b) 91 | b = tf.image.decode_png(b,channels=1) 92 | a=tf.image.resize_images(a, [image_height,image_width],method=0) 93 | b=tf.image.resize_images(b, [image_height,image_width],method=1) 94 | c=tf.image.convert_image_dtype(a, dtype=tf.uint8) 95 | a.set_shape(shape=(image_height, image_width, 3)) 96 | b.set_shape(shape=(image_height, image_width,1)) 97 | c.set_shape(shape=(image_height, image_width, 3)) 98 | return a,b,c 99 | def decodev(a,b): 100 | a = tf.read_file(a) 101 | a=tf.image.decode_png(a, channels=3) 102 | a = tf.image.convert_image_dtype(a, dtype=tf.float32) 103 | b = tf.read_file(b) 104 | b = tf.image.decode_png(b,channels=1) 105 | a=tf.image.resize_images(a, [image_height,image_width],method=0) 106 | b=tf.image.resize_images(b, [image_height,image_width],method=1) 107 | a.set_shape(shape=(image_height, image_width, 3)) 108 | b.set_shape(shape=(image_height, image_width,1)) 109 | return a,b 110 | 111 | def run(): 112 | with tf.Graph().as_default() as graph: 113 | tf.logging.set_verbosity(tf.logging.INFO) 114 | 115 | #===================TRAINING BRANCH======================= 116 | #Load the files into one input queue 117 | images = tf.convert_to_tensor(image_files) 118 | annotations = tf.convert_to_tensor(annotation_files) 119 | tdataset = tf.data.Dataset.from_tensor_slices((images,annotations)) 120 | tdataset = tdataset.map(decode) 121 | tdataset = tdataset.shuffle(100).batch(batch_size).repeat(num_epochs) 122 | titerator = tdataset.make_initializable_iterator() 123 | images,annotations,realimg = titerator.get_next() 124 | 125 | 126 | images_val = tf.convert_to_tensor(image_val_files) 127 | annotations_val = tf.convert_to_tensor(annotation_val_files) 128 | vdataset = tf.data.Dataset.from_tensor_slices((images_val,annotations_val)) 129 | vdataset = vdataset.map(decodev) 130 | vdataset = vdataset.batch(eval_batch_size).repeat(num_epochs*3) 131 | viterator = vdataset.make_initializable_iterator() 132 | images_val,annotations_val = viterator.get_next() 133 | 134 | 135 | 136 | 137 | #perform one-hot-encoding on the ground truth annotation to get same shape as the logits 138 | _, probabilities= model.erfpsp(images,numclasses=num_classes, shape=[image_height,image_width], l2=weight_decay,reuse=None,is_training=True) 139 | annotations = tf.reshape(annotations, shape=[-1, image_height, image_width]) 140 | raw_gt = tf.reshape(annotations, [-1,]) 141 | indices = tf.squeeze(tf.where(tf.greater(raw_gt,0)), 1) 142 | gt = tf.cast(tf.gather(raw_gt, indices), tf.int32) 143 | gt = gt - 1 144 | gt_one = tf.one_hot(gt, num_classes, axis=-1) 145 | raw_prediction = tf.reshape(probabilities, [-1, num_classes]) 146 | prediction = tf.gather(raw_prediction, indices) 147 | annotations_ohe = tf.one_hot(annotations, num_classes+1, axis=-1) 148 | MASK = tf.reduce_sum(1-annotations_ohe[:,:,:,0]) 149 | 150 | los=weighted_cross_entropy(gt_one, prediction, class_weights)/MASK 151 | loss=tf.losses.add_loss(los) 152 | total_loss = tf.losses.get_total_loss() 153 | global_step = tf.train.get_or_create_global_step() 154 | #Define your exponentially decaying learning rate 155 | lr = tf.train.exponential_decay( 156 | learning_rate = initial_learning_rate, 157 | global_step = global_step, 158 | decay_steps = decay_steps, 159 | decay_rate = 
learning_rate_decay_factor,
160 | staircase = True)
161 |
162 | #Now we can define the optimizer that takes on the learning rate
163 | optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon)
164 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
165 | updates_op = tf.group(*update_ops)
166 | #Create the train_op.
167 | with tf.control_dependencies([updates_op]):
168 | train_op = optimizer.minimize(total_loss,global_step=global_step)
169 |
170 |
171 | _, probabilities_val= model.erfpsp(images_val,numclasses=num_classes, shape=[image_height,image_width], l2=None,reuse=True,is_training=None)
172 | raw_gt_v = tf.reshape(annotations_val, [-1,])
173 | indices_v = tf.squeeze(tf.where(tf.greater(raw_gt_v,0)), 1)
174 | gt_v = tf.cast(tf.gather(raw_gt_v, indices_v), tf.int32)
175 | gt_v = gt_v-1
176 | gt_one_v = tf.one_hot(gt_v, num_classes, axis=-1)
177 | raw_prediction_v = tf.argmax(tf.reshape(probabilities_val, [-1, num_classes]),-1)
178 | prediction_v = tf.gather(raw_prediction_v, indices_v)
179 | prediction_ohe_v = tf.one_hot(prediction_v, num_classes, axis=-1)
180 | and_val=gt_one_v*prediction_ohe_v
181 | and_sum=tf.reduce_sum(and_val,[0])
182 | or_val=tf.to_int32((gt_one_v+prediction_ohe_v)>0.5)
183 | or_sum=tf.reduce_sum(or_val,axis=[0])
184 | T_sum=tf.reduce_sum(gt_one_v,axis=[0])
185 | R_sum = tf.reduce_sum(prediction_ohe_v,axis=[0])
186 | mPrecision=0
187 | mRecall_rate=0
188 | mIoU=0
189 | #Now we need to create a training step function that runs both the train_op, metrics_op and updates the global_step concurrently.
190 | def train_step(sess, train_op, global_step ,loss=total_loss):
191 | #Check the time for each sess run
192 | start_time = time.time()
193 | _,total_loss, global_step_count= sess.run([train_op,loss, global_step ])
194 | time_elapsed = time.time() - start_time
195 | global_step_count=global_step_count+1
196 | #Run the logging to show some results
197 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
198 |
199 | return total_loss
200 | #Now finally create all the summaries you need to monitor and group them into one summary op.
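The block that follows feeds Python-side evaluation results back into TensorBoard; the pattern, distilled (the names here are illustrative):

```python
import tensorflow as tf
metric_var = tf.Variable(0.0)                      # graph-side holder for the metric
metric_ph = tf.placeholder(tf.float32, shape=[])   # fed with the Python value
metric_assign = tf.assign(metric_var, metric_ph)   # running this op stores the value
tf.summary.scalar('Monitor/some_metric', metric_assign)
# later: sess.run(summary_op, feed_dict={metric_ph: python_value})
```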
201 | A = tf.Variable(tf.constant(0.0), dtype=tf.float32)
202 | a=tf.placeholder(shape=[],dtype=tf.float32)
203 | Precision=tf.assign(A, a)
204 | B = tf.Variable(tf.constant(0.0), dtype=tf.float32)
205 | b=tf.placeholder(shape=[],dtype=tf.float32)
206 | Recall=tf.assign(B, b)
207 | C = tf.Variable(tf.constant(0.0), dtype=tf.float32)
208 | c=tf.placeholder(shape=[],dtype=tf.float32)
209 | mIOU=tf.assign(C, c)
210 | predictions = tf.argmax(probabilities, -1)
211 | segmentation_output = tf.cast(tf.reshape((predictions+1)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8)
212 | segmentation_ground_truth = tf.cast(tf.reshape(tf.cast(annotations, dtype=tf.float32)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8)
213 | tf.summary.scalar('Monitor/Total_Loss', total_loss)
214 | tf.summary.scalar('Monitor/Precision', Precision)
215 | tf.summary.scalar('Monitor/Recall_rate', Recall)
216 | tf.summary.scalar('Monitor/mIoU', mIOU)
217 | tf.summary.scalar('Monitor/learning_rate', lr)
218 | tf.summary.image('Images/original_image', realimg, max_outputs=1)
219 | tf.summary.image('Images/segmentation_output', segmentation_output, max_outputs=1)
220 | tf.summary.image('Images/segmentation_ground_truth', segmentation_ground_truth, max_outputs=1)
221 | my_summary_op = tf.summary.merge_all()
222 |
223 | def train_sum(sess, train_op, global_step,sums,loss=total_loss,pre=0,recall=0,iou=0):
224 | start_time = time.time()
225 | _,total_loss, global_step_count,ss = sess.run([train_op,loss, global_step,sums ],feed_dict={a:pre,b:recall,c:iou})
226 | time_elapsed = time.time() - start_time
227 | global_step_count=global_step_count+1
228 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
229 |
230 | return total_loss,ss
231 |
232 | def eval_step(sess,i ):
233 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = sess.run([and_sum,or_sum,T_sum,R_sum])
234 | #Log some information
235 | logging.info('STEP: %d ',i)
236 | return and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch
237 | def eval(num_class,csvname,session,image_val,eval_batch):
238 | or_=np.zeros((num_class), dtype=np.float32)
239 | and_=np.zeros((num_class), dtype=np.float32)
240 | T_=np.zeros((num_class), dtype=np.float32)
241 | R_=np.zeros((num_class), dtype=np.float32)
242 | for i in range(math.ceil(len(image_val) / eval_batch)):
243 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = eval_step(session,i+1)
244 | and_=and_+and_eval_batch
245 | or_=or_+or_eval_batch
246 | T_=T_+T_eval_batch
247 | R_=R_+R_eval_batch
248 | Recall_rate=and_/T_
249 | Precision=and_/R_
250 | IoU=and_/or_
251 | mPrecision=np.mean(Precision)
252 | mRecall_rate=np.mean(Recall_rate)
253 | mIoU=np.mean(IoU)
254 | print("Precision:")
255 | print(Precision)
256 | print("Recall rate:")
257 | print(Recall_rate)
258 | print("IoU:")
259 | print(IoU)
260 | print("mPrecision:")
261 | print(mPrecision)
262 | print("mRecall_rate:")
263 | print(mRecall_rate)
264 | print("mIoU")
265 | print(mIoU)
266 | with open(csvname,'a', newline='') as out:
267 | csv_write = csv.writer(out,dialect='excel')
268 | csv_write.writerow(Precision)
269 | csv_write.writerow(Recall_rate)
270 | csv_write.writerow(IoU)
271 | return mPrecision,mRecall_rate,mIoU
272 | gpu_options = tf.GPUOptions(allow_growth=True)
273 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
274 | init = tf.global_variables_initializer()
275 | saver=tf.train.Saver(max_to_keep=10)
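The restore below uses a hard-coded checkpoint path and step; a sketch of a more flexible alternative built on `tf.train.latest_checkpoint` (an assumption about preferred usage, not what the script ships):

```python
# Inside the tf.Session block, instead of a hard-coded checkpoint:
ckpt = tf.train.latest_checkpoint(logdir)      # e.g. './log/camvid/model.ckpt-37127'
if ckpt is not None:
    saver.restore(sess, ckpt)
    step = int(ckpt.split('-')[-1])            # recover the step from the filename
    sess.run(tf.assign(global_step, step))
```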
276 | with tf.Session(config=config) as sess:
277 | sess.run(init)
278 | sess.run([titerator.initializer,viterator.initializer])
279 | step = 0
280 | if not Start_train:
281 | #Input the checkpoint address and the step number.
282 | checkpoint='./log/erfpsp/model.ckpt-37127'
283 | saver.restore(sess, checkpoint)
284 | step = 37127
285 | sess.run(tf.assign(global_step,step))
286 | summary_writer = tf.summary.FileWriter(logdir, sess.graph)
287 | final = num_steps_per_epoch * num_epochs
288 | for i in range(step,final,1):
289 | if i % num_batches_per_epoch == 0:
290 | logging.info('Epoch %s/%s', i/num_batches_per_epoch + 1, num_epochs)
291 | learning_rate_value = sess.run([lr])
292 | logging.info('Current Learning Rate: %s', learning_rate_value)
293 | if i != step:
294 | saver.save(sess, os.path.join(logdir,log_name),global_step=i)
295 | mPrecision,mRecall_rate,mIoU=eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size)
296 | if i % min(num_steps_per_epoch, 10) == 0:
297 | loss,summaries = train_sum(sess, train_op,global_step,sums=my_summary_op,loss=total_loss,pre=mPrecision,recall=mRecall_rate,iou=mIoU)
298 | summary_writer.add_summary(summaries,global_step=i+1)
299 | else:
300 | loss = train_step(sess, train_op, global_step)
301 | summary_writer.close()
302 | eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size)
303 | logging.info('Final Loss: %s', loss)
304 | logging.info('Finished training! Saving model to disk now.')
305 | saver.save(sess, os.path.join(logdir,log_name), global_step = final)
306 |
307 |
308 | if __name__ == '__main__':
309 | run()
310 |
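The README's Note promises demo code; in the meantime, a minimal single-image inference sketch for the erf-pspnet model (the checkpoint path, input size and class count are assumptions):

```python
import numpy as np
import tensorflow as tf
import model  # this repository's erf-pspnet/model.py

image = tf.placeholder(tf.float32, [1, 360, 720, 3])
_, probs = model.erfpsp(image, l2=None, shape=[360, 720], numclasses=19, is_training=False)
pred = tf.argmax(probs, -1)  # per-pixel class ids, shape (1, 360, 720)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './log/camvid/model.ckpt-37127')  # assumed checkpoint
    frame = np.zeros((1, 360, 720, 3), np.float32)        # stand-in for a real frame in [0,1]
    labels = sess.run(pred, feed_dict={image: frame})
```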
-------------------------------------------------------------------------------- /erf-pspnet/trainial.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.platform import tf_logging as logging
3 | import os
4 | import time
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import model
8 | import random
9 | import csv
10 | import math
11 |
12 | #==============INPUT ARGUMENTS==================
13 | flags = tf.app.flags
14 |
15 | #Directory arguments
16 | flags.DEFINE_string('dataset_dir', './dataset', 'The dataset directory to find the train, validation and test images.')
17 | flags.DEFINE_string('logdir', './log/camvid', 'The log directory to save your checkpoint and event files.')
18 | #Training arguments
19 | flags.DEFINE_integer('num_classes', 11, 'The number of classes to predict.')
20 | flags.DEFINE_integer('batch_size', 8, 'The batch_size for training.')
21 | flags.DEFINE_integer('eval_batch_size', 24, 'The batch size used for validation.')
22 | flags.DEFINE_integer('image_height',360, "The input height of the images.")
23 | flags.DEFINE_integer('image_width', 480, "The input width of the images.")
24 | flags.DEFINE_integer('num_epochs', 300, "The number of epochs to train your model.")
25 | flags.DEFINE_integer('num_epochs_before_decay', 100, 'The number of epochs before decaying your learning rate.')
26 | flags.DEFINE_float('weight_decay', 2e-4, "The weight decay for ENet convolution layers.")
27 | flags.DEFINE_float('learning_rate_decay_factor', 1e-1, 'The learning rate decay factor.')
28 | flags.DEFINE_float('initial_learning_rate', 1e-3, 'The initial learning rate for your training.')
29 | flags.DEFINE_boolean('Start_train',True, "Whether to train from scratch (True) or restore a checkpoint (False).")
30 |
31 | FLAGS = flags.FLAGS
32 |
33 | Start_train = FLAGS.Start_train
34 | log_name = 'model.ckpt'
35 |
36 | num_classes = FLAGS.num_classes
37 | batch_size = FLAGS.batch_size
38 | eval_batch_size = FLAGS.eval_batch_size
39 | image_height = FLAGS.image_height
40 | image_width = FLAGS.image_width
41 |
42 | #Training parameters
43 | initial_learning_rate = FLAGS.initial_learning_rate
44 | num_epochs_before_decay = FLAGS.num_epochs_before_decay
45 | num_epochs =FLAGS.num_epochs
46 | learning_rate_decay_factor = FLAGS.learning_rate_decay_factor
47 | weight_decay = FLAGS.weight_decay
48 | epsilon = 1e-8
49 |
50 |
51 | #Directories
52 | dataset_dir = FLAGS.dataset_dir
53 | logdir = FLAGS.logdir
54 |
55 | #===============PREPARATION FOR TRAINING==================
56 | #Get the images into a list
57 | image_files = sorted([os.path.join(dataset_dir, 'train', file) for file in os.listdir(dataset_dir + "/train") if file.endswith('.png')])
58 | annotation_files = sorted([os.path.join(dataset_dir, "trainannot", file) for file in os.listdir(dataset_dir + "/trainannot") if file.endswith('.png')])
59 | image_val_files = sorted([os.path.join(dataset_dir, 'val', file) for file in os.listdir(dataset_dir + "/val") if file.endswith('.png')])
60 | annotation_val_files = sorted([os.path.join(dataset_dir, "valannot", file) for file in os.listdir(dataset_dir + "/valannot") if file.endswith('.png')])
61 | #Save the per-class metrics to a CSV (Excel-readable) file
62 | csvname=logdir[6:]+'.csv'
63 | with open(csvname,'a', newline='') as out:
64 | csv_write = csv.writer(out,dialect='excel')
65 | a=[str(i) for i in range(num_classes)]
66 | csv_write.writerow(a)
67 | #Know the number steps to take before decaying the learning rate and batches per epoch
68 | num_batches_per_epoch = math.ceil(len(image_files) / batch_size)
69 | num_steps_per_epoch = num_batches_per_epoch
70 | decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)
71 |
72 | #=================CLASS WEIGHTS===============================
73 | #Median frequency balancing class_weights
74 |
75 | class_weights1=np.array([ 4.57716287, 0, 0, 0, 9.60914246, 0, 0, 0 , 0,0,6.10711717], dtype=np.float32)
76 | class_weights2=np.array([ 0, 42.28705255, 3.46893819, 16.45916311, 0,0, 33.06296333, 0 , 0,0,0], dtype=np.float32)
77 | class_weights3=np.array([ 0, 0, 0, 0, 0, 33.93236668, 0, 13.5811212 , 40.96211531,44.98280801,0], dtype=np.float32)
78 |
79 | def weighted_cross_entropy(onehot_labels, logits, class_weights):
80 | #a=tf.reduce_sum(-tf.log(tf.clip_by_value(logits, 1e-10, 1.0))*(1-logits)*(1-logits)*onehot_labels*class_weights)
81 | a=tf.reduce_sum(-tf.log(tf.clip_by_value(logits, 1e-10, 1.0))*onehot_labels*class_weights)
82 | return a
83 |
84 |
85 | def decode(a,b):
86 | a = tf.read_file(a)
87 | a=tf.image.decode_png(a, channels=3)
88 | a = tf.image.convert_image_dtype(a, dtype=tf.float32)
89 | b = tf.read_file(b)
90 | b = tf.image.decode_png(b,channels=1)
91 | a=tf.image.resize_images(a, [image_height,image_width],method=0)
92 | b=tf.image.resize_images(b, [image_height,image_width],method=1)
93 | c=tf.image.convert_image_dtype(a, dtype=tf.uint8)
94 | a.set_shape(shape=(image_height, image_width, 3))
95 | b.set_shape(shape=(image_height, image_width,1))
96 | c.set_shape(shape=(image_height, image_width, 3))
97 | return a,b,c
98 | def decodev(a,b):
99 | a = tf.read_file(a)
100 | a=tf.image.decode_png(a, channels=3)
101 | a = tf.image.convert_image_dtype(a, dtype=tf.float32)
102 | b = tf.read_file(b)
103 | b = tf.image.decode_png(b,channels=1)
104 | a=tf.image.resize_images(a, [image_height,image_width],method=0)
105 | b=tf.image.resize_images(b, [image_height,image_width],method=1)
106 | a.set_shape(shape=(image_height, image_width, 3))
107 | b.set_shape(shape=(image_height, image_width,1))
108 | return a,b
109 | def run():
110 |     with tf.Graph().as_default() as graph:
111 |         tf.logging.set_verbosity(tf.logging.INFO)
112 |
113 |         #===================TRAINING BRANCH=======================
114 |         #Load the files into one input queue
115 |         images = tf.convert_to_tensor(image_files)
116 |         annotations = tf.convert_to_tensor(annotation_files)
117 |         tdataset = tf.data.Dataset.from_tensor_slices((images, annotations))
118 |         tdataset = tdataset.map(decode)
119 |         tdataset = tdataset.shuffle(100).batch(batch_size).repeat(num_epochs)
120 |         titerator = tdataset.make_initializable_iterator()
121 |         images, annotations, realimg = titerator.get_next()
122 |
123 |
124 |         images_val = tf.convert_to_tensor(image_val_files)
125 |         annotations_val = tf.convert_to_tensor(annotation_val_files)
126 |         vdataset = tf.data.Dataset.from_tensor_slices((images_val, annotations_val))
127 |         vdataset = vdataset.map(decodev)
128 |         vdataset = vdataset.batch(eval_batch_size).repeat(num_epochs * 3)
129 |         viterator = vdataset.make_initializable_iterator()
130 |         images_val, annotations_val = viterator.get_next()
131 |
132 |
133 |
134 |
135 |
136 |         #Run the model, then one-hot-encode the ground truth annotations to match the shape of the logits
137 |         _, probabilities = model.erfpsp(images, numclasses=num_classes, shape=[image_height, image_width], l2=weight_decay, reuse=None, is_training=True)
138 |         annotations = tf.reshape(annotations, shape=[-1, image_height, image_width])
139 |         raw_gt = tf.reshape(annotations, [-1])
140 |         indices = tf.squeeze(tf.where(tf.greater(raw_gt, 0)), 1)  # keep only labeled (non-void) pixels
141 |         gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
142 |         gt = gt - 1  # shift labels down so the void class 0 is dropped
143 |         gt_one = tf.one_hot(gt, num_classes, axis=-1)
144 |         raw_prediction = tf.reshape(probabilities, [-1, num_classes])
145 |         prediction = tf.gather(raw_prediction, indices)
146 |
147 |
148 |
149 |
150 |         annotations_ohe = tf.one_hot(annotations, num_classes + 1, axis=-1)
151 |         MASK = tf.reduce_sum(1 - annotations_ohe[:, :, :, 0])  # number of labeled pixels
152 |         m = tf.split(annotations_ohe, num_or_size_splits=num_classes + 1, axis=-1)
153 |         M1 = tf.reduce_sum(tf.concat([m[2], m[3], m[4], m[6], m[7], m[8], m[9], m[10]], axis=-1), -1, keepdims=True)
154 |         M2 = tf.reduce_sum(tf.concat([m[6], m[8], m[9], m[10]], axis=-1), -1, keepdims=True)
155 |         X_ = tf.reduce_sum(probabilities * annotations_ohe[:, :, :, 1:], -1, keepdims=True)  # predicted probability of the true class
156 |         mask = tf.reshape(1 - annotations_ohe[:, :, :, 0], shape=[-1, image_height, image_width, 1])
157 |         f1 = tf.reduce_sum(tf.pow(tf.sqrt(M1 + 0.5) * (X_ - M1) * mask, 2)) / (2 * MASK)
158 |         f2 = tf.reduce_sum(tf.pow(tf.sqrt(M2 + 0.5) * (X_ - M2) * M1, 2)) / (2 * tf.reduce_sum(M1))
159 |
160 |         loss1 = weighted_cross_entropy(gt_one, prediction, class_weights1)
161 |         loss2 = weighted_cross_entropy(gt_one, prediction, class_weights2)
162 |         loss3 = weighted_cross_entropy(gt_one, prediction, class_weights3)
163 |         los = (loss1 + loss2 * (f1 + 2) + loss3 * (f1 + 2) * (f2 + 2)) / MASK  # normalize by the labeled-pixel count
164 |         tf.losses.add_loss(los)  # add_loss returns None, so nothing useful can be assigned here
165 |         total_loss = tf.losses.get_total_loss()
166 |         global_step = tf.train.get_or_create_global_step()
167 |         #Define your exponentially decaying learning rate
168 |         lr = tf.train.exponential_decay(
169 |             learning_rate=initial_learning_rate,
170 |             global_step=global_step,
171 |             decay_steps=decay_steps,
172 |             decay_rate=learning_rate_decay_factor,
173 |             staircase=True)
174 |
175 |         #Now we can define the optimizer that takes on the learning rate
176 |         optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon)
177 |         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
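        # The batch-norm layers in model.py register their moving mean/variance
        # updates under tf.GraphKeys.UPDATE_OPS; grouping those ops and making
        # train_op depend on them (below) refreshes the statistics on every
        # training step.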
178 |         updates_op = tf.group(*update_ops)
179 |         #Create the train_op.
180 |         with tf.control_dependencies([updates_op]):
181 |             train_op = optimizer.minimize(total_loss, global_step=global_step)
182 |         #State the metrics that you want to predict. We get predictions that are not one-hot-encoded.
183 |         #Validation branch
184 |         _, probabilities_val = model.erfpsp(images_val, numclasses=num_classes, shape=[image_height, image_width], l2=None, reuse=True, is_training=False)
185 |         raw_gt_v = tf.reshape(annotations_val, [-1])  # flatten; decodev already resized the labels to image_height x image_width
186 |         indices_v = tf.squeeze(tf.where(tf.greater(raw_gt_v, 0)), 1)
187 |         gt_v = tf.cast(tf.gather(raw_gt_v, indices_v), tf.int32)
188 |         gt_v = gt_v - 1
189 |         gt_one_v = tf.one_hot(gt_v, num_classes, axis=-1)
190 |         raw_prediction_v = tf.argmax(tf.reshape(probabilities_val, [-1, num_classes]), -1)
191 |         prediction_v = tf.gather(raw_prediction_v, indices_v)
192 |         prediction_ohe_v = tf.one_hot(prediction_v, num_classes, axis=-1)
193 |         and_val = gt_one_v * prediction_ohe_v  # per-class true positives
194 |         and_sum = tf.reduce_sum(and_val, [0])
195 |         or_val = tf.to_int32((gt_one_v + prediction_ohe_v) > 0.5)  # per-class union
196 |         or_sum = tf.reduce_sum(or_val, axis=[0])
197 |         T_sum = tf.reduce_sum(gt_one_v, axis=[0])  # per-class ground-truth pixel counts
198 |         R_sum = tf.reduce_sum(prediction_ohe_v, axis=[0])  # per-class predicted pixel counts
199 |         mPrecision = 0
200 |         mRecall_rate = 0
201 |         mIoU = 0
202 |         #Now we need to create a training step function that runs both the train_op and the metrics op and updates the global_step concurrently.
203 |         def train_step(sess, train_op, global_step, loss=total_loss):
204 |             #Check the time for each sess run
205 |             start_time = time.time()
206 |             _, total_loss, global_step_count = sess.run([train_op, loss, global_step])
207 |             time_elapsed = time.time() - start_time
208 |             global_step_count = global_step_count + 1
209 |             #Run the logging to show some results
210 |             logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
211 |
212 |             return total_loss
213 |         #Now finally create all the summaries you need to monitor and group them into one summary op.
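        # The Variable/placeholder/assign triplets below are a common TF1
        # pattern for logging values computed in plain Python (here the mean
        # precision, recall and IoU produced by eval()): feeding a/b/c runs the
        # assign ops, and the scalar summaries read the freshly assigned values.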
214 |         A = tf.Variable(tf.constant(0.0), dtype=tf.float32)
215 |         a = tf.placeholder(shape=[], dtype=tf.float32)
216 |         Precision = tf.assign(A, a)
217 |         B = tf.Variable(tf.constant(0.0), dtype=tf.float32)
218 |         b = tf.placeholder(shape=[], dtype=tf.float32)
219 |         Recall = tf.assign(B, b)
220 |         C = tf.Variable(tf.constant(0.0), dtype=tf.float32)
221 |         c = tf.placeholder(shape=[], dtype=tf.float32)
222 |         mIOU = tf.assign(C, c)
223 |         predictions = tf.argmax(probabilities, -1)
224 |         segmentation_output = tf.cast(tf.reshape((predictions + 1) * 255 / num_classes, shape=[-1, image_height, image_width, 1]), tf.uint8)
225 |         segmentation_ground_truth = tf.cast(tf.reshape(tf.cast(annotations, dtype=tf.float32) * 255 / num_classes, shape=[-1, image_height, image_width, 1]), tf.uint8)
226 |         tf.summary.scalar('Monitor/Total_Loss', total_loss)
227 |         tf.summary.scalar('Monitor/Precision', Precision)
228 |         tf.summary.scalar('Monitor/Recall_rate', Recall)
229 |         tf.summary.scalar('Monitor/mIoU', mIOU)
230 |         tf.summary.scalar('Monitor/learning_rate', lr)
231 |         tf.summary.image('Images/original_image', realimg, max_outputs=1)
232 |         tf.summary.image('Images/segmentation_output', segmentation_output, max_outputs=1)
233 |         tf.summary.image('Images/segmentation_ground_truth', segmentation_ground_truth, max_outputs=1)
234 |         my_summary_op = tf.summary.merge_all()
235 |         def train_sum(sess, train_op, global_step, sums, loss=total_loss, pre=0, recall=0, iou=0):
236 |             start_time = time.time()
237 |             _, total_loss, global_step_count, ss = sess.run([train_op, loss, global_step, sums], feed_dict={a: pre, b: recall, c: iou})
238 |             time_elapsed = time.time() - start_time
239 |             global_step_count = global_step_count + 1
240 |             logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
241 |
242 |             return total_loss, ss
243 |
244 |         def eval_step(sess, i):
245 |             and_eval_batch, or_eval_batch, T_eval_batch, R_eval_batch = sess.run([and_sum, or_sum, T_sum, R_sum])
246 |             #Log some information
247 |             logging.info('STEP: %d ', i)
248 |             return and_eval_batch, or_eval_batch, T_eval_batch, R_eval_batch
249 |         def eval(num_class, csvname, session, image_val, eval_batch):
250 |             or_ = np.zeros((num_class), dtype=np.float32)
251 |             and_ = np.zeros((num_class), dtype=np.float32)
252 |             T_ = np.zeros((num_class), dtype=np.float32)
253 |             R_ = np.zeros((num_class), dtype=np.float32)
254 |             for i in range(math.ceil(len(image_val) / eval_batch)):
255 |                 and_eval_batch, or_eval_batch, T_eval_batch, R_eval_batch = eval_step(session, i + 1)
256 |                 and_ = and_ + and_eval_batch
257 |                 or_ = or_ + or_eval_batch
258 |                 T_ = T_ + T_eval_batch
259 |                 R_ = R_ + R_eval_batch
260 |             Recall_rate = and_ / T_
261 |             Precision = and_ / R_
262 |             IoU = and_ / or_
263 |             mPrecision = np.mean(Precision)
264 |             mRecall_rate = np.mean(Recall_rate)
265 |             mIoU = np.mean(IoU)
266 |             print("Precision:")
267 |             print(Precision)
268 |             print("Recall rate:")
269 |             print(Recall_rate)
270 |             print("IoU:")
271 |             print(IoU)
272 |             print("mPrecision:")
273 |             print(mPrecision)
274 |             print("mRecall_rate:")
275 |             print(mRecall_rate)
276 |             print("mIoU:")
277 |             print(mIoU)
278 |             with open(csvname, 'a', newline='') as out:
279 |                 csv_write = csv.writer(out, dialect='excel')
280 |                 csv_write.writerow(Precision)
281 |                 csv_write.writerow(Recall_rate)
282 |                 csv_write.writerow(IoU)
283 |             return mPrecision, mRecall_rate, mIoU
284 |         gpu_options = tf.GPUOptions(allow_growth=True)
285 |         config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
286 |         init = tf.global_variables_initializer()
287 |         saver = tf.train.Saver(max_to_keep=10)
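        # allow_growth lets TensorFlow claim GPU memory on demand rather than
        # reserving it all up front, and max_to_keep=10 retains only the ten
        # most recent checkpoints in logdir.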
288 |         with tf.Session(config=config) as sess:
289 |             sess.run(init)
290 |             sess.run([titerator.initializer, viterator.initializer])
291 |             step = 0
292 |             if not Start_train:
293 |                 # Point this at the checkpoint to resume from, and set the matching step number.
294 |                 checkpoint = './log/erfpspial/model.ckpt-37127'
295 |                 saver.restore(sess, checkpoint)
296 |                 step = 37127
297 |                 sess.run(tf.assign(global_step, step))
298 |             summary_writer = tf.summary.FileWriter(logdir, sess.graph)
299 |             final = num_steps_per_epoch * num_epochs
300 |             for i in range(step, final):
301 |                 if i % num_batches_per_epoch == 0:
302 |                     logging.info('Epoch %s/%s', i // num_batches_per_epoch + 1, num_epochs)
303 |                     learning_rate_value = sess.run(lr)
304 |                     logging.info('Current Learning Rate: %s', learning_rate_value)
305 |                     if i != step:
306 |                         saver.save(sess, os.path.join(logdir, log_name), global_step=i)
307 |                         mPrecision, mRecall_rate, mIoU = eval(num_class=num_classes, csvname=csvname, session=sess, image_val=image_val_files, eval_batch=eval_batch_size)
308 |                 if i % min(num_steps_per_epoch, 10) == 0:
309 |                     loss, summaries = train_sum(sess, train_op, global_step, sums=my_summary_op, loss=total_loss, pre=mPrecision, recall=mRecall_rate, iou=mIoU)
310 |                     summary_writer.add_summary(summaries, global_step=i + 1)
311 |                 else:
312 |                     loss = train_step(sess, train_op, global_step)
313 |             summary_writer.close()
314 |             eval(num_class=num_classes, csvname=csvname, session=sess, image_val=image_val_files, eval_batch=eval_batch_size)
315 |             logging.info('Final Loss: %s', loss)
316 |             logging.info('Finished training! Saving model to disk now.')
317 |             saver.save(sess, os.path.join(logdir, log_name), global_step=final)
318 |
319 |
320 | if __name__ == '__main__':
321 |     run()
322 |
--------------------------------------------------------------------------------