├── README.md
├── Swift_Factorized_Network(SFN)
│   ├── sfn.py
│   └── train.py
└── swiftnet
    ├── swiftnet.py
    └── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # Swiftnet
2 | SwiftNet implemented in TensorFlow
3 | ## Paper
4 | The code is implemented according to the following paper.
5 | + [In Defense of Pre-trained ImageNet Architectures for Real-time Semantic Segmentation of Road-driving Images](https://arxiv.org/pdf/1903.08469.pdf)
6 | ## Main Dependencies
7 | ```
8 | tensorflow 1.12
9 | OpenCV
10 | Python 3.6.5
11 | ```
12 | ## Note
13 | In the future, I will update the demo code for a modified version of SwiftNet that achieves a higher IoU. The pre-trained ResNet-18 parameters are available at https://pan.baidu.com/s/16OjEC-GcWvGIj0eWVe9RiQ.
14 |
--------------------------------------------------------------------------------
/Swift_Factorized_Network(SFN)/sfn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import math
3 | import numpy as np
4 |
5 |
6 |
7 |
8 |
9 | def conv(inputs,filters,kernel_size,strides=(1, 1),padding='SAME',dilation_rate=(1, 1),activation=tf.nn.relu,use_bias=None,regularizer=None,name=None,reuse=None):#thin wrapper around tf.layers.conv2d
10 | out=tf.layers.conv2d(
11 | inputs,
12 | filters=filters,
13 | kernel_size=kernel_size,
14 | strides=strides,
15 | padding=padding,
16 | dilation_rate=dilation_rate,
17 | activation=activation,
18 | use_bias=use_bias,
19 | kernel_regularizer=regularizer,
20 | bias_initializer=tf.zeros_initializer(),
21 | kernel_initializer= tf.random_normal_initializer(stddev=0.1),
22 | name=name,
23 | reuse=reuse)
24 | return out
25 |
26 |
27 | # In[3]:
28 |
29 |
30 | def batch(inputs,training=True,reuse=None,momentum=0.9,name='n'):
31 | out=tf.layers.batch_normalization(inputs,training=training,reuse=reuse,momentum=momentum,name=name)
32 | return out
33 |
34 |
35 | # In[4]:
36 |
37 |
38 | def branch1(x,numOut,l2,stride=1,is_training=True,momentum=0.9,reuse=None):#main branch: two 3x3 convolutions
39 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2)
40 | with tf.variable_scope("conv1"):
41 | y = conv(x, numOut, kernel_size=[3, 3],activation=None,strides=(stride,stride),name='conv',regularizer=reg,reuse=reuse)
42 | y = tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn'))
43 | with tf.variable_scope("conv2"):
44 | y = conv(y, numOut, kernel_size=[3, 3],activation=None,regularizer=reg,name='conv',reuse=reuse)
45 | y = batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn')
46 | return y
47 |
48 |
49 | # In[5]:
50 |
51 |
52 | def branch2(x,numOut,l2,stride=1,is_training=True,momentum=0.9,reuse=None):#projection shortcut: 1x1 convolution
53 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2)
54 | with tf.variable_scope("convshortcut"):
55 | y = conv(x, numOut, kernel_size=[1, 1],activation=None,strides=(stride,stride),name='conv',regularizer=reg,reuse=reuse)
56 | y = batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn')
57 | return y
58 |
59 |
60 | # In[6]:
61 |
62 |
63 | def residual(x,numOut,l2,stride=1,is_training=True,reuse=None,momentum=0.9,branch=False,name='res'):#returns (activated output, pre-activation output)
64 | with tf.variable_scope(name):
65 | block = branch1(x,numOut,l2,stride=stride,is_training=is_training,momentum=momentum,reuse=reuse)
66 | if x.get_shape().as_list()[3] != numOut or branch:
67 | skip = branch2(x, numOut,l2,stride=stride,is_training=is_training,momentum=momentum,reuse=reuse)
68 | return tf.nn.relu(block+skip),block+skip
69 | else:
70 | return 
tf.nn.relu(x+block),x+block 71 | 72 | 73 | # In[7]: 74 | 75 | 76 | def resnet18(x, is_training,l2=None,dropout=0.05,reuse=None,momentum=0.9,name='Resnet18'): 77 | feature=[] 78 | with tf.variable_scope(name): 79 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2/4) 80 | y=conv(x, 64, kernel_size=[7, 7],activation=None,strides=2,name='conv0',regularizer=reg,reuse=reuse) 81 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='conv0/bn')) 82 | y=tf.nn.max_pool(y,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME',name='pool1') 83 | with tf.variable_scope('group0'): 84 | res2a,t=residual(y,64,l2,branch=True,reuse=reuse,is_training=is_training,name='block0') 85 | res2b,t=residual(res2a,64,l2,reuse=reuse,is_training=is_training,name='block1') 86 | feature.append(t) 87 | with tf.variable_scope('group1'): 88 | res3a,t=residual(res2b,128,l2,stride=2,reuse=reuse,is_training=is_training,name='block0') 89 | res3b,t=residual(res3a,128,l2,reuse=reuse,is_training=is_training,name='block1') 90 | feature.append(t) 91 | with tf.variable_scope('group2'): 92 | res4a,t=residual(res3b,256,l2,stride=2,reuse=reuse,is_training=is_training,name='block0') 93 | res4b,t=residual(res4a,256,l2,reuse=reuse,is_training=is_training,name='block1') 94 | feature.append(t) 95 | with tf.variable_scope('group3'): 96 | res5a,t=residual(res4b,512,l2,stride=2,reuse=reuse,is_training=is_training,name='block0') 97 | res5b,t=residual(res5a,512,l2,reuse=reuse,is_training=is_training,name='block1') 98 | feature.append(t) 99 | #pool5=tf.reduce_mean(res5b, [1, 2],keepdims=True) 100 | #dropout = tf.layers.dropout(pool5,rate=dropout,training=is_training) 101 | #y=conv(dropout, 1000, kernel_size=[1, 1],activation=None,name='class',use_bias=True,regularizer=reg,reuse=reuse) 102 | #y=conv(y, 512, kernel_size=[1, 1],activation=None,name='attention',use_bias=None,regularizer=reg,reuse=reuse) 103 | #y=tf.nn.sigmoid(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='attentionbn')) 104 | #y=res5b*y+res5b 105 | #feature.append(y) 106 | return y,feature 107 | 108 | def erfupsample(x,skip,is_training,shape=[512,512],kernal=3,stage=0,l2=None,reuse=None,momentum=0.9,name='up0'): 109 | height=int(shape[0]//math.pow(2,5-stage)) 110 | weight=int(shape[1]//math.pow(2,5-stage)) 111 | with tf.variable_scope(name): 112 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 113 | skip=tf.nn.relu(batch(skip,training=is_training,reuse=reuse,momentum=momentum,name='skipbn')) 114 | skip = conv(skip,128,kernel_size=1,activation=None,name='changedemesion',regularizer=reg,reuse=reuse) 115 | x=tf.image.resize_images(x, [height,weight],method=0,align_corners=True) 116 | skip=x+skip 117 | skip=tf.nn.relu(batch(skip,training=is_training,reuse=reuse,momentum=momentum,name='blendbn0')) 118 | skip1=conv(skip,128,kernel_size=[kernal,1],activation=None,name='skipconv1a',regularizer=reg,reuse=reuse) 119 | skip1=conv(skip1,128,kernel_size=[1,kernal],activation=None,name='skipconv1b',regularizer=reg,reuse=reuse) 120 | skip2=conv(skip,128,kernel_size=[1,kernal],activation=None,name='skipconv2a',regularizer=reg,reuse=reuse) 121 | skip2=conv(skip2,128,kernel_size=[kernal,1],activation=None,name='skipconv2b',regularizer=reg,reuse=reuse) 122 | x=skip+skip1+skip2 123 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='blendbn')) 124 | x=conv(x, 128, kernel_size=3,activation=None,name='blendconv',regularizer=reg,reuse=reuse) 125 | #pool5=tf.reduce_mean(x, [1, 2],keepdims=True) 
126 | #y=conv(pool5, 128, kernel_size=[1, 1],activation=None,name='attention',use_bias=None,regularizer=reg,reuse=reuse)
127 | #y=tf.nn.sigmoid(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='attentionbn'))
128 | #x=x*y+x
129 | return x
130 |
131 |
132 | def SpatialPyramidPooling(x, is_training,shape=[512,512],grids=(8, 4, 2,1),l2=None,reuse=None,momentum=0.9,name='spp'):
133 | levels=[]
134 | height=shape[0]//32
135 | weight=shape[1]//32
136 | with tf.variable_scope(name):
137 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2)
138 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='bn0'))
139 | x=conv(x, 128, kernel_size=1,activation=None,name='conv0',regularizer=reg,reuse=reuse)
140 | levels.append(x)
141 | for i in range(len(grids)):
142 | h=math.floor(height/grids[i])
143 | w=math.floor(weight/grids[i])
144 | kh=height-(grids[i]-1) * h
145 | kw=weight-(grids[i]-1) * w
146 | y=tf.nn.avg_pool(x,[1,kh,kw,1],[1,h,w,1],padding='VALID')
147 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn'+str(i+1)))
148 | y=conv(y, 32, kernel_size=1,activation=None,name='conv'+str(i+1),regularizer=reg,reuse=reuse)
149 | y=tf.image.resize_images(y, [height,weight],method=0,align_corners=True)
150 | levels.append(y)
151 | final=tf.concat(levels,-1)
152 | final=tf.nn.relu(batch(final,training=is_training,reuse=reuse,momentum=momentum,name='blendbn'))
153 | final=conv(final, 128, kernel_size=1,activation=None,name='blendconv',regularizer=reg,reuse=reuse)
154 | final=tf.nn.relu(batch(final,training=is_training,reuse=reuse,momentum=momentum,name='finalbn'))
155 | return final
156 |
157 |
158 |
159 |
160 |
161 |
162 | def swiftnet(x, numclass,is_training,shape,l2=None,dropout=0.05,reuse=None,momentum=0.9):
163 | xclass,feature=resnet18(x, is_training,l2,dropout=dropout,reuse=reuse,momentum=momentum,name='Resnet18')
164 | x=SpatialPyramidPooling(feature[-1], is_training,shape=shape,grids=(8, 4, 2, 1),l2=l2,reuse=reuse,momentum=momentum,name='spp')
165 | x=erfupsample(x,feature[-2],is_training,shape=shape,kernal=3,stage=1,l2=l2,reuse=reuse,momentum=momentum,name='up1')
166 | x=erfupsample(x,feature[-3],is_training,shape=shape,kernal=5,stage=2,l2=l2,reuse=reuse,momentum=momentum,name='up2')
167 | x=erfupsample(x,feature[-4],is_training,shape=shape,kernal=7,stage=3,l2=l2,reuse=reuse,momentum=momentum,name='up3')
168 | with tf.variable_scope('class'):
169 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2)
170 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='classbn'))
171 | x=conv(x, numclass, kernel_size=3,activation=None,name='classconv',regularizer=reg,reuse=reuse)
172 | x=tf.image.resize_images(x, [shape[0],shape[1]],method=0,align_corners=True)
173 | final=tf.nn.softmax(x, name='logits_to_softmax')
174 | return x,final
175 |
176 |
177 |
178 |
179 | def load_weight(sess,resnet18_path,varss):#restore pretrained ImageNet ResNet-18 weights from an .npz file
180 | param = dict(np.load(resnet18_path))
181 | for v in varss:
182 | nameEnd = v.name.split('/')[-1]
183 | if nameEnd == "moving_mean:0":
184 | name = v.name[9:-13]+"mean/EMA"#v.name[9:] strips the 'Resnet18/' scope prefix
185 | elif nameEnd == "moving_variance:0":
186 | name = v.name[9:-17]+"variance/EMA"
187 | elif nameEnd =='kernel:0':
188 | if v.name.split('/')[1]=='conv0':
189 | name='conv0/W'
190 | b=np.expand_dims(param[name][:,:,0,:],2)
191 | g=np.expand_dims(param[name][:,:,1,:],2)
192 | r=np.expand_dims(param[name][:,:,2,:],2)
193 | param[name]=np.concatenate([r,g,b],2)#swap the input-channel order (BGR <-> RGB) of the first convolution
194 | elif v.name.split('/')[1]=='class':
195 | name='linear/W'
196 | else:
197 | name=v.name[9:-13]+'W'
198 | elif nameEnd=='gamma:0':
199 | name=v.name[9:-2]
200 | elif nameEnd=='beta:0':
201 | name=v.name[9:-2]
202 | else:
203 | name='linear/b'
204 | sess.run(v.assign(param[name]))
205 | print("Copy weights: " + name + "---->"+ v.name)
206 |
207 |
208 |
--------------------------------------------------------------------------------
/Swift_Factorized_Network(SFN)/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.platform import tf_logging as logging
3 | import os
4 | import time
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import sfn
8 | import random
9 | import csv
10 | import math
11 | from scipy import misc
12 |
13 | #==============INPUT ARGUMENTS==================
14 | flags = tf.app.flags
15 |
16 | #Directory arguments
17 | flags.DEFINE_string('dataset_dir', './dataset', 'The dataset directory to find the train, validation and test images.')
18 | flags.DEFINE_string('logdir', './log/swiftnet', 'The log directory to save your checkpoint and event files.')
19 | #Training arguments
20 | flags.DEFINE_integer('num_classes', 19, 'The number of classes to predict.')
21 | flags.DEFINE_integer('batch_size', 11, 'The batch_size for training.')
22 | flags.DEFINE_integer('eval_batch_size', 8, 'The batch size used for validation.')
23 | flags.DEFINE_integer('image_height',768, "The input height of the images.")
24 | flags.DEFINE_integer('image_width', 768, "The input width of the images.")
25 | flags.DEFINE_integer('num_epochs', 200, "The number of epochs to train your model.")
26 | flags.DEFINE_integer('num_epochs_before_decay', 200, 'The number of epochs before decaying your learning rate.')
27 | flags.DEFINE_float('weight_decay', 1e-4, "The weight decay for the convolution layers.")
28 | flags.DEFINE_float('learning_rate_decay_factor', 0.6667, 'The learning rate decay factor.')
29 | flags.DEFINE_float('initial_learning_rate', 4e-4, 'The initial learning rate for your training.')
30 | flags.DEFINE_boolean('Start_train',True, "Whether to start training from scratch; set False to resume from a saved checkpoint.")
31 |
32 | #
33 |
34 | FLAGS = flags.FLAGS
35 |
36 | Start_train = FLAGS.Start_train
37 | log_name = 'model.ckpt'
38 |
39 | num_classes = FLAGS.num_classes
40 | batch_size = FLAGS.batch_size
41 | eval_batch_size = FLAGS.eval_batch_size
42 | image_height = FLAGS.image_height
43 | image_width = FLAGS.image_width
44 |
45 | #Training parameters
46 | initial_learning_rate = FLAGS.initial_learning_rate
47 | num_epochs_before_decay = FLAGS.num_epochs_before_decay
48 | num_epochs =FLAGS.num_epochs
49 | learning_rate_decay_factor = FLAGS.learning_rate_decay_factor
50 | weight_decay = FLAGS.weight_decay
51 | epsilon = 1e-8
52 |
53 |
54 | #Directories
55 | dataset_dir = FLAGS.dataset_dir
56 | logdir = FLAGS.logdir
57 |
58 | #===============PREPARATION FOR TRAINING==================
59 | #Get the images into a list
60 | image_files = sorted([os.path.join(dataset_dir, 'train', file) for file in os.listdir(dataset_dir + "/train") if file.endswith('.png')])
61 | annotation_files = sorted([os.path.join(dataset_dir, "trainannot", file) for file in os.listdir(dataset_dir + "/trainannot") if file.endswith('.png')])
62 | image_val_files = sorted([os.path.join(dataset_dir, 'val', file) for file in os.listdir(dataset_dir + "/val") if file.endswith('.png')])
63 | annotation_val_files = sorted([os.path.join(dataset_dir, "valannot", file) for file in os.listdir(dataset_dir + "/valannot") if file.endswith('.png')])
64 | #Save the per-class metrics to a CSV (Excel-compatible) file
65 | csvname=logdir[6:]+'.csv'
66 | with open(csvname,'a', newline='') as out:
67 | csv_write = csv.writer(out,dialect='excel')
68 | a=[str(i) for i in range(num_classes)]
69 | csv_write.writerow(a)
70 | #Compute the number of batches per epoch and the number of steps before decaying the learning rate
71 | num_batches_per_epoch = math.ceil(len(image_files) / batch_size)
72 | num_steps_per_epoch = num_batches_per_epoch
73 | decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)
74 |
75 | #=================CLASS WEIGHTS===============================
76 | class_weights=np.array(
77 | [40.69042899, 47.6765088 , 12.70029695, 45.20543212, 45.78372173,
78 | 45.82527748, 48.40614895, 42.75593537, 3.36208549, 14.03151966,
79 | 4.9866471 , 39.25440643, 36.51259517, 32.81231979, 6.69824427,
80 | 33.55546509, 18.48781934, 32.97432129, 46.28665742],dtype=np.float32)
81 |
82 | def weighted_cross_entropy(onehot_labels, probs, class_weights,annotations_ohe):#'probs' holds softmax probabilities, not raw logits
83 | a=tf.reduce_sum(-tf.log(tf.clip_by_value(probs, 1e-10, 1.0))*onehot_labels*class_weights)
84 | MASK = tf.reduce_sum(1-annotations_ohe[:,:,:,0])#count the pixels that belong to the meaningful (non-void) classes.
85 | return a/MASK
86 |
87 |
88 | #The first augmentation used a maximum scale of 1, then brightness 0.1
89 | def decode(a,b):
90 | a = tf.read_file(a)
91 | a=tf.image.decode_png(a, channels=3)
92 | a=tf.cast(a,dtype=tf.float32)
93 | b = tf.read_file(b)
94 | b = tf.image.decode_png(b,channels=1)
95 | #random scale
96 | scale = tf.random_uniform([1],minval=0.75,maxval=1.25,dtype=tf.float32)
97 | hi=tf.floor(scale*1024)
98 | wi=tf.floor(scale*2048)
99 | s=tf.concat([hi,wi],0)
100 | s=tf.cast(s,dtype=tf.int32)
101 | a=tf.image.resize_images(a, s,method=0,align_corners=True)
102 | b=tf.image.resize_images(b, s,method=1,align_corners=True)
103 | b = tf.image.convert_image_dtype(b, dtype=tf.float32)
104 | #random crop and flip
105 | m=tf.concat([a,b],axis=-1)
106 | m=tf.image.random_crop(m,[image_height,image_width,4])
107 | m=tf.image.random_flip_left_right(m)
108 |
109 | m=tf.split(m,num_or_size_splits=4,axis=-1)
110 | a=tf.concat([m[0],m[1],m[2]],axis=-1)
111 | img=tf.image.convert_image_dtype(a/255,dtype=tf.uint8)
112 | a=a-[123.68,116.779,103.939]
113 | b=m[3]
114 | b = tf.image.convert_image_dtype(b, dtype=tf.uint8)
115 | a.set_shape(shape=(image_height, image_width, 3))
116 | b.set_shape(shape=(image_height, image_width,1))
117 | img.set_shape(shape=(image_height, image_width, 3))
118 | return a,b,img
119 | def decodev(a,b):
120 | a = tf.read_file(a)
121 | a=tf.image.decode_png(a, channels=3)
122 | a=tf.cast(a,dtype=tf.float32)
123 | b = tf.read_file(b)
124 | a = a-[123.68,116.779,103.939]
125 | b = tf.image.decode_png(b,channels=1)
126 | a.set_shape(shape=(1024, 2048, 3))
127 | b.set_shape(shape=(1024, 2048,1))
128 | return a,b
129 | def run():
130 | with tf.Graph().as_default() as graph:
131 | tf.logging.set_verbosity(tf.logging.INFO)
132 | #===================TRAINING BRANCH=======================
133 | #Load the files into one input queue
134 | images = tf.convert_to_tensor(image_files)
135 | annotations = tf.convert_to_tensor(annotation_files)
136 | tdataset = tf.data.Dataset.from_tensor_slices((images,annotations)).map(decode).shuffle(100).batch(batch_size).repeat(num_epochs)
137 | titerator = tdataset.make_initializable_iterator()
138 | images,annotations,realimg = titerator.get_next()
139 | images_val = tf.convert_to_tensor(image_val_files)
140 | annotations_val = tf.convert_to_tensor(annotation_val_files)
141 |
vdataset = tf.data.Dataset.from_tensor_slices((images_val,annotations_val)).map(decodev).batch(eval_batch_size).repeat(num_epochs*3) 142 | viterator = vdataset.make_initializable_iterator() 143 | images_val,annotations_val = viterator.get_next() 144 | #perform one-hot-encoding on the ground truth annotation to get same shape as the logits 145 | _, probabilities= sfn.swiftnet(images, numclass=num_classes,is_training=True,shape=[image_height,image_width],l2=weight_decay,dropout=0.05,reuse=None) 146 | annotations = tf.reshape(annotations, shape=[-1, image_height, image_width]) 147 | #loss function 148 | raw_gt = tf.reshape(annotations, [-1,]) 149 | indices = tf.squeeze(tf.where(tf.greater(raw_gt,0)), 1) 150 | gt = tf.cast(tf.gather(raw_gt, indices), tf.int32) 151 | gt = gt-1 152 | gt_one = tf.one_hot(gt, num_classes, axis=-1) 153 | raw_prediction = tf.reshape(probabilities, [-1, num_classes]) 154 | prediction = tf.gather(raw_prediction, indices) 155 | var=tf.global_variables() 156 | var1=[v for v in var if v.name.split('/')[0]=='Resnet18' and v.name.split('/')[-2]!='attention' and v.name.split('/')[-2]!='attentionbn'] #base_net parameters 157 | var2=[v for v in var if v not in var1] #added parameters 158 | annotations_ohe = tf.one_hot(annotations, num_classes+1, axis=-1) 159 | los=weighted_cross_entropy(gt_one, prediction, class_weights,annotations_ohe) 160 | loss=tf.losses.add_loss(los) 161 | total_loss = tf.losses.get_total_loss() 162 | global_step = tf.train.get_or_create_global_step() 163 | #Define your learning rate and optimizer 164 | lr=tf.train.cosine_decay( 165 | learning_rate=initial_learning_rate, 166 | global_step=global_step, 167 | decay_steps=decay_steps, 168 | alpha=2.5e-3) 169 | optimizer1 = tf.train.AdamOptimizer(learning_rate=lr/4, epsilon=epsilon) 170 | optimizer2 = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon) 171 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 172 | updates_op = tf.group(*update_ops) 173 | with tf.control_dependencies([updates_op]): 174 | grads = tf.gradients(total_loss, var1 + var2) 175 | grads1 = grads[:len(var1)] 176 | grads2 = grads[len(var1):] 177 | train_op1 = optimizer1.apply_gradients(zip(grads1, var1)) 178 | train_op2 = optimizer2.apply_gradients(zip(grads2, var2),global_step=global_step) 179 | train_op = tf.group(train_op1, train_op2) 180 | _, probabilities_val= sfn.swiftnet(images_val, numclass=num_classes,is_training=None,shape=[1024,2048],l2=None,dropout=0,reuse=True) 181 | raw_gt_v = tf.reshape(tf.reshape(annotations_val, shape=[-1, 1024, 2048]),[-1,]) 182 | indices_v = tf.squeeze(tf.where(tf.greater(raw_gt_v,0)), 1) 183 | gt_v = tf.cast(tf.gather(raw_gt_v, indices_v), tf.int32) 184 | gt_v = gt_v-1 185 | gt_one_v = tf.one_hot(gt_v, num_classes, axis=-1) 186 | raw_prediction_v = tf.argmax(tf.reshape(probabilities_val, [-1, num_classes]),-1) 187 | prediction_v = tf.gather(raw_prediction_v, indices_v) 188 | prediction_ohe_v = tf.one_hot(prediction_v, num_classes, axis=-1) 189 | and_val=gt_one_v*prediction_ohe_v 190 | and_sum=tf.reduce_sum(and_val,[0]) 191 | or_val=tf.to_int32((gt_one_v+prediction_ohe_v)>0.5) 192 | or_sum=tf.reduce_sum(or_val,axis=[0]) 193 | T_sum=tf.reduce_sum(gt_one_v,axis=[0]) 194 | R_sum = tf.reduce_sum(prediction_ohe_v,axis=[0]) 195 | mPrecision=0 196 | mRecall_rate=0 197 | mIoU=0 198 | #Now we need to create a training step function that runs both the train_op, metrics_op and updates the global_step concurrently. 
199 | def train_step(sess, train_op, global_step ,loss=total_loss): 200 | #Check the time for each sess run 201 | start_time = time.time() 202 | _,total_loss, global_step_count= sess.run([train_op,loss, global_step ]) 203 | time_elapsed = time.time() - start_time 204 | global_step_count=global_step_count+1 205 | #Run the logging to show some results 206 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed) 207 | 208 | return total_loss 209 | #Now finally create all the summaries you need to monitor and group them into one summary op. 210 | A = tf.Variable(tf.constant(0.0), dtype=tf.float32) 211 | a=tf.placeholder(shape=[],dtype=tf.float32) 212 | Precision=tf.assign(A, a) 213 | B = tf.Variable(tf.constant(0.0), dtype=tf.float32) 214 | b=tf.placeholder(shape=[],dtype=tf.float32) 215 | Recall=tf.assign(B, b) 216 | C = tf.Variable(tf.constant(0.0), dtype=tf.float32) 217 | c=tf.placeholder(shape=[],dtype=tf.float32) 218 | mIOU=tf.assign(C, c) 219 | predictions = tf.argmax(probabilities, -1) 220 | segmentation_output = tf.cast(tf.reshape((predictions+1)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8) 221 | segmentation_ground_truth = tf.cast(tf.reshape(tf.cast(annotations, dtype=tf.float32)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8) 222 | tf.summary.scalar('Monitor/Total_Loss', total_loss) 223 | tf.summary.scalar('Monitor/Precision', Precision) 224 | tf.summary.scalar('Monitor/Recall_rate', Recall) 225 | tf.summary.scalar('Monitor/mIoU', mIOU) 226 | tf.summary.scalar('Monitor/learning_rate', lr) 227 | tf.summary.image('Images/original_image', realimg, max_outputs=1) 228 | tf.summary.image('Images/segmentation_output', segmentation_output, max_outputs=1) 229 | tf.summary.image('Images/segmentation_ground_truth', segmentation_ground_truth, max_outputs=1) 230 | my_summary_op = tf.summary.merge_all() 231 | 232 | def train_sum(sess, train_op, global_step,sums,loss=total_loss,pre=0,recall=0,iou=0): 233 | start_time = time.time() 234 | _,total_loss, global_step_count,ss = sess.run([train_op,loss, global_step,sums ],feed_dict={a:pre,b:recall,c:iou}) 235 | time_elapsed = time.time() - start_time 236 | global_step_count=global_step_count+1 237 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed) 238 | 239 | return total_loss,ss 240 | 241 | def eval_step(sess,i ): 242 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = sess.run([and_sum,or_sum,T_sum,R_sum]) 243 | #Log some information 244 | logging.info('STEP: %d ',i) 245 | return and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch 246 | def eval(num_class,csvname,session,image_val,eval_batch): 247 | or_=np.zeros((num_class), dtype=np.float32) 248 | and_=np.zeros((num_class), dtype=np.float32) 249 | T_=np.zeros((num_class), dtype=np.float32) 250 | R_=np.zeros((num_class), dtype=np.float32) 251 | for i in range(math.ceil(len(image_val) / eval_batch)): 252 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = eval_step(session,i+1) 253 | and_=and_+and_eval_batch 254 | or_=or_+or_eval_batch 255 | T_=T_+T_eval_batch 256 | R_=R_+R_eval_batch 257 | Recall_rate=and_/T_ 258 | Precision=and_/R_ 259 | IoU=and_/or_ 260 | mPrecision=np.mean(Precision) 261 | mRecall_rate=np.mean(Recall_rate) 262 | mIoU=np.mean(IoU) 263 | print("Precision:") 264 | print(Precision) 265 | print("Recall rate:") 266 | print(Recall_rate) 267 | print("IoU:") 268 | print(IoU) 269 | print("mPrecision:") 270 | print(mPrecision) 271 
| print("mRecall_rate:") 272 | print(mRecall_rate) 273 | print("mIoU") 274 | print(mIoU) 275 | with open(csvname,'a', newline='') as out: 276 | csv_write = csv.writer(out,dialect='excel') 277 | csv_write.writerow(Precision) 278 | csv_write.writerow(Recall_rate) 279 | csv_write.writerow(IoU) 280 | return mPrecision,mPrecision,mIoU 281 | 282 | gpu_options = tf.GPUOptions(allow_growth=True) 283 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options) 284 | init = tf.global_variables_initializer() 285 | saver=tf.train.Saver(var_list=tf.global_variables(),max_to_keep=10) 286 | with tf.Session(config=config) as sess: 287 | sess.run(init) 288 | sess.run([titerator.initializer,viterator.initializer]) 289 | sfn.load_weight(sess,'imgnet_resnet18.npz',var1)#load base_net's parameter. 290 | step = 0; 291 | if Start_train is not True: 292 | #input the checkpoint address,and the step number. 293 | checkpoint='./log/swiftnet/model.ckpt-37127' 294 | saver.restore(sess, checkpoint) 295 | step = 37127 296 | sess.run(tf.assign(global_step,step)) 297 | summary_writer = tf.summary.FileWriter(logdir, sess.graph) 298 | final = num_steps_per_epoch * num_epochs 299 | for i in range(step,final,1): 300 | if i % num_batches_per_epoch == 0: 301 | logging.info('Epoch %s/%s', i/num_batches_per_epoch + 1, num_epochs) 302 | learning_rate_value = sess.run([lr]) 303 | logging.info('Current Learning Rate: %s', learning_rate_value) 304 | if i is not step: 305 | saver.save(sess, os.path.join(logdir,log_name),global_step=i) 306 | mPrecision,mRecall_rate,mIoU=eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size) 307 | if i % min(num_steps_per_epoch, 10) == 0: 308 | loss,summaries = train_sum(sess, train_op,global_step,sums=my_summary_op,loss=total_loss,pre=mPrecision,recall=mPrecision,iou=mIoU) 309 | summary_writer.add_summary(summaries,global_step=i+1) 310 | else: 311 | loss = train_step(sess, train_op, global_step) 312 | summary_writer.close() 313 | eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size) 314 | logging.info('Final Loss: %s', loss) 315 | logging.info('Finished training! 
Saving model to disk now.') 316 | saver.save(sess, os.path.join(logdir,log_name), global_step = final) 317 | 318 | 319 | if __name__ == '__main__': 320 | run() 321 | -------------------------------------------------------------------------------- /swiftnet/swiftnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import math 3 | import numpy as np 4 | 5 | 6 | 7 | 8 | 9 | def conv(inputs,filters,kernel_size,strides=(1, 1),padding='SAME',dilation_rate=(1, 1),activation=tf.nn.relu,use_bias=None,regularizer=None,name=None,reuse=None): 10 | out=tf.layers.conv2d( 11 | inputs, 12 | filters=filters, 13 | kernel_size=kernel_size, 14 | strides=strides, 15 | padding=padding, 16 | dilation_rate=dilation_rate, 17 | activation=activation, 18 | use_bias=use_bias, 19 | kernel_regularizer=regularizer, 20 | bias_initializer=tf.zeros_initializer(), 21 | kernel_initializer= tf.random_normal_initializer(stddev=0.1), 22 | name=name, 23 | reuse=reuse) 24 | return out 25 | 26 | 27 | # In[3]: 28 | 29 | 30 | def batch(inputs,training=True,reuse=None,momentum=0.9,name='n'): 31 | out=tf.layers.batch_normalization(inputs,training=training,reuse=reuse,momentum=momentum,name=name) 32 | return out 33 | 34 | 35 | # In[4]: 36 | 37 | 38 | def branch1(x,numOut,l2,stride=1,is_training=True,momentum=0.9,reuse=None): 39 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 40 | with tf.variable_scope("conv1"): 41 | y = conv(x, numOut, kernel_size=[3, 3],activation=None,strides=(stride,stride),name='conv',regularizer=reg,reuse=reuse) 42 | y = tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn')) 43 | with tf.variable_scope("conv2"): 44 | y = conv(y, numOut, kernel_size=[3, 3],activation=None,regularizer=reg,name='conv',reuse=reuse) 45 | y = batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn') 46 | return y 47 | 48 | 49 | # In[5]: 50 | 51 | 52 | def branch2(x,numOut,l2,stride=1,is_training=True,momentum=0.9,reuse=None): 53 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 54 | with tf.variable_scope("convshortcut"): 55 | y = conv(x, numOut, kernel_size=[1, 1],activation=None,strides=(stride,stride),name='conv',regularizer=reg,reuse=reuse) 56 | y = batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn') 57 | return y 58 | 59 | 60 | # In[6]: 61 | 62 | 63 | def residual(x,numOut,l2,stride=1,is_training=True,reuse=None,momentum=0.9,branch=False,name='res'): 64 | with tf.variable_scope(name): 65 | block = branch1(x,numOut,l2,stride=stride,is_training=is_training,momentum=momentum,reuse=reuse) 66 | if x.get_shape().as_list()[3] != numOut or branch: 67 | skip = branch2(x, numOut,l2,stride=stride,is_training=is_training,momentum=momentum,reuse=reuse) 68 | return tf.nn.relu(block+skip),block+skip 69 | else: 70 | return tf.nn.relu(x+block),x+block 71 | 72 | 73 | # In[7]: 74 | 75 | 76 | def resnet18(x, is_training,l2=None,dropout=0.05,reuse=None,momentum=0.9,name='Resnet18'): 77 | feature=[] 78 | with tf.variable_scope(name): 79 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2/4) 80 | y=conv(x, 64, kernel_size=[7, 7],activation=None,strides=2,name='conv0',regularizer=reg,reuse=reuse) 81 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='conv0/bn')) 82 | y=tf.nn.max_pool(y,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME',name='pool1') 83 | with tf.variable_scope('group0'): 84 | 
res2a,t=residual(y,64,l2,branch=True,reuse=reuse,is_training=is_training,name='block0') 85 | res2b,t=residual(res2a,64,l2,reuse=reuse,is_training=is_training,name='block1') 86 | feature.append(t) 87 | with tf.variable_scope('group1'): 88 | res3a,t=residual(res2b,128,l2,stride=2,reuse=reuse,is_training=is_training,name='block0') 89 | res3b,t=residual(res3a,128,l2,reuse=reuse,is_training=is_training,name='block1') 90 | feature.append(t) 91 | with tf.variable_scope('group2'): 92 | res4a,t=residual(res3b,256,l2,stride=2,reuse=reuse,is_training=is_training,name='block0') 93 | res4b,t=residual(res4a,256,l2,reuse=reuse,is_training=is_training,name='block1') 94 | feature.append(t) 95 | with tf.variable_scope('group3'): 96 | res5a,t=residual(res4b,512,l2,stride=2,reuse=reuse,is_training=is_training,name='block0') 97 | res5b,t=residual(res5a,512,l2,reuse=reuse,is_training=is_training,name='block1') 98 | feature.append(t) 99 | #pool5=tf.reduce_mean(res5b, [1, 2],keepdims=True) 100 | #dropout = tf.layers.dropout(pool5,rate=dropout,training=is_training) 101 | #y=conv(dropout, 1000, kernel_size=[1, 1],activation=None,name='class',use_bias=True,regularizer=reg,reuse=reuse) 102 | #y=conv(y, 512, kernel_size=[1, 1],activation=None,name='attention',use_bias=None,regularizer=reg,reuse=reuse) 103 | #y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn')) 104 | #y=res5b*y 105 | #feature.append(y) 106 | return y,feature 107 | 108 | 109 | 110 | 111 | def SpatialPyramidPooling(x, is_training,shape=[512,512],grids=(8, 4, 2),l2=None,reuse=None,momentum=0.9,name='spp'): 112 | levels=[] 113 | height=math.ceil(shape[0]/32) 114 | weight=math.ceil(shape[1]/32) 115 | with tf.variable_scope(name): 116 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 117 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='bn0')) 118 | x=conv(x, 128, kernel_size=1,activation=None,name='conv0',regularizer=reg,reuse=reuse) 119 | levels.append(x) 120 | for i in range(len(grids)): 121 | h=math.floor(height/grids[i]) 122 | w=math.floor(weight/grids[i]) 123 | kh=height-(grids[i]-1) * h 124 | kw=weight-(grids[i]-1) * w 125 | y=tf.nn.avg_pool(x,[1,kh,kw,1],[1,h,w,1],padding='VALID') 126 | y=tf.nn.relu(batch(y,training=is_training,reuse=reuse,momentum=momentum,name='bn'+str(i+1))) 127 | y=conv(y, 42, kernel_size=1,activation=None,name='conv'+str(i+1),regularizer=reg,reuse=reuse) 128 | y=tf.image.resize_images(y, [height,weight],method=0,align_corners=True) 129 | levels.append(y) 130 | final=tf.concat(levels,-1) 131 | final=tf.nn.relu(batch(final,training=is_training,reuse=reuse,momentum=momentum,name='blendbn')) 132 | final=conv(final, 128, kernel_size=1,activation=None,name='blendconv',regularizer=reg,reuse=reuse) 133 | return final 134 | 135 | 136 | 137 | def upsample(x,skip,is_training,shape=[512,512],stage=0,l2=None,reuse=None,momentum=0.9,name='up0'): 138 | height=math.ceil(shape[0]/math.pow(2,5-stage)) 139 | weight=math.ceil(shape[1]/math.pow(2,5-stage)) 140 | with tf.variable_scope(name): 141 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2) 142 | skip=tf.nn.relu(batch(skip,training=is_training,reuse=reuse,momentum=momentum,name='skipbn')) 143 | skip=conv(skip, 128, kernel_size=1,activation=None,name='skipconv',regularizer=reg,reuse=reuse) 144 | x=tf.image.resize_images(x, [height,weight],method=0,align_corners=True) 145 | x=x+skip 146 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='blendbn')) 147 | 
x=conv(x, 128, kernel_size=3,activation=None,name='blendconv',regularizer=reg,reuse=reuse)
148 | return x
149 |
150 |
151 |
152 | def swiftnet(x, numclass,is_training,shape,l2=None,dropout=0.05,reuse=None,momentum=0.9):
153 | xclass,feature=resnet18(x, is_training,l2,dropout=dropout,reuse=reuse,momentum=momentum,name='Resnet18')
154 | x=SpatialPyramidPooling(feature[-1], is_training,shape=shape,grids=(8, 4, 2),l2=l2,reuse=reuse,momentum=momentum,name='spp')
155 | x=upsample(x,feature[-2],is_training,shape=shape,stage=1,l2=l2,reuse=reuse,momentum=momentum,name='up1')
156 | x=upsample(x,feature[-3],is_training,shape=shape,stage=2,l2=l2,reuse=reuse,momentum=momentum,name='up2')
157 | x=upsample(x,feature[-4],is_training,shape=shape,stage=3,l2=l2,reuse=reuse,momentum=momentum,name='up3')
158 | with tf.variable_scope('class'):
159 | reg = None if l2 is None else tf.contrib.layers.l2_regularizer(scale=l2)
160 | x=tf.nn.relu(batch(x,training=is_training,reuse=reuse,momentum=momentum,name='classbn'))
161 | x=conv(x, numclass, kernel_size=3,activation=None,name='classconv',regularizer=reg,reuse=reuse)
162 | x=tf.image.resize_images(x, [shape[0],shape[1]],method=0,align_corners=True)
163 | final=tf.nn.softmax(x, name='logits_to_softmax')
164 | return x,final
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 | def load_weight(sess,resnet18_path,varss):#restore pretrained ImageNet ResNet-18 weights from an .npz file
173 | param = dict(np.load(resnet18_path))
174 | for v in varss:
175 | nameEnd = v.name.split('/')[-1]
176 | if nameEnd == "moving_mean:0":
177 | name = v.name[9:-13]+"mean/EMA"#v.name[9:] strips the 'Resnet18/' scope prefix
178 | elif nameEnd == "moving_variance:0":
179 | name = v.name[9:-17]+"variance/EMA"
180 | elif nameEnd =='kernel:0':
181 | if v.name.split('/')[1]=='conv0':
182 | name='conv0/W'
183 | b=np.expand_dims(param[name][:,:,0,:],2)
184 | g=np.expand_dims(param[name][:,:,1,:],2)
185 | r=np.expand_dims(param[name][:,:,2,:],2)
186 | param[name]=np.concatenate([r,g,b],2)#swap the input-channel order (BGR <-> RGB) of the first convolution
187 | elif v.name.split('/')[1]=='class':
188 | name='linear/W'
189 | else:
190 | name=v.name[9:-13]+'W'
191 | elif nameEnd=='gamma:0':
192 | name=v.name[9:-2]
193 | elif nameEnd=='beta:0':
194 | name=v.name[9:-2]
195 | else:
196 | name='linear/b'
197 | sess.run(v.assign(param[name]))
198 | print("Copy weights: " + name + "---->"+ v.name)
199 |
200 |
201 |
--------------------------------------------------------------------------------
/swiftnet/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.platform import tf_logging as logging
3 | import os
4 | import time
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import swiftnet
8 | import random
9 | import csv
10 | import math
11 | from scipy import misc
12 |
13 | #==============INPUT ARGUMENTS==================
14 | flags = tf.app.flags
15 |
16 | #Directory arguments
17 | flags.DEFINE_string('dataset_dir', './dataset', 'The dataset directory to find the train, validation and test images.')
18 | flags.DEFINE_string('logdir', './log/swiftnet', 'The log directory to save your checkpoint and event files.')
19 | #Training arguments
20 | flags.DEFINE_integer('num_classes', 19, 'The number of classes to predict.')
21 | flags.DEFINE_integer('batch_size', 11, 'The batch_size for training.')
22 | flags.DEFINE_integer('eval_batch_size', 8, 'The batch size used for validation.')
23 | flags.DEFINE_integer('image_height',768, "The input height of the images.")
24 | flags.DEFINE_integer('image_width', 768, "The input width of the images.")
25 | flags.DEFINE_integer('num_epochs', 200, "The number of epochs to train your model.")
"The number of epochs to train your model.") 26 | flags.DEFINE_integer('num_epochs_before_decay', 200, 'The number of epochs before decaying your learning rate.') 27 | flags.DEFINE_float('weight_decay', 1e-4, "The weight decay for ENet convolution layers.") 28 | flags.DEFINE_float('learning_rate_decay_factor', 0.6667, 'The learning rate decay factor.') 29 | flags.DEFINE_float('initial_learning_rate', 4e-4, 'The initial learning rate for your training.') 30 | flags.DEFINE_boolean('Start_train',True, "The input height of the images.") 31 | 32 | # 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | Start_train = FLAGS.Start_train 37 | log_name = 'model.ckpt' 38 | 39 | num_classes = FLAGS.num_classes 40 | batch_size = FLAGS.batch_size 41 | eval_batch_size = FLAGS.eval_batch_size 42 | image_height = FLAGS.image_height 43 | image_width = FLAGS.image_width 44 | 45 | #Training parameters 46 | initial_learning_rate = FLAGS.initial_learning_rate 47 | num_epochs_before_decay = FLAGS.num_epochs_before_decay 48 | num_epochs =FLAGS.num_epochs 49 | learning_rate_decay_factor = FLAGS.learning_rate_decay_factor 50 | weight_decay = FLAGS.weight_decay 51 | epsilon = 1e-8 52 | 53 | 54 | #Directories 55 | dataset_dir = FLAGS.dataset_dir 56 | logdir = FLAGS.logdir 57 | 58 | #===============PREPARATION FOR TRAINING================== 59 | #Get the images into a list 60 | image_files = sorted([os.path.join(dataset_dir, 'train', file) for file in os.listdir(dataset_dir + "/train") if file.endswith('.png')]) 61 | annotation_files = sorted([os.path.join(dataset_dir, "trainannot", file) for file in os.listdir(dataset_dir + "/trainannot") if file.endswith('.png')]) 62 | image_val_files = sorted([os.path.join(dataset_dir, 'val', file) for file in os.listdir(dataset_dir + "/val") if file.endswith('.png')]) 63 | annotation_val_files = sorted([os.path.join(dataset_dir, "valannot", file) for file in os.listdir(dataset_dir + "/valannot") if file.endswith('.png')]) 64 | #保存到excel 65 | csvname=logdir[6:]+'.csv' 66 | with open(csvname,'a', newline='') as out: 67 | csv_write = csv.writer(out,dialect='excel') 68 | a=[str(i) for i in range(num_classes)] 69 | csv_write.writerow(a) 70 | #Know the number steps to take before decaying the learning rate and batches per epoch 71 | num_batches_per_epoch = math.ceil(len(image_files) / batch_size) 72 | num_steps_per_epoch = num_batches_per_epoch 73 | decay_steps = int(num_epochs_before_decay * num_steps_per_epoch) 74 | 75 | #=================CLASS WEIGHTS=============================== 76 | class_weights=np.array( 77 | [40.69042899, 47.6765088 , 12.70029695, 45.20543212, 45.78372173, 78 | 45.82527748, 48.40614895, 42.75593537, 3.36208549, 14.03151966, 79 | 4.9866471 , 39.25440643, 36.51259517, 32.81231979, 6.69824427, 80 | 33.55546509, 18.48781934, 32.97432129, 46.28665742],dtype=np.float32) 81 | 82 | def weighted_cross_entropy(onehot_labels, logits, class_weights,annotations_ohe): 83 | a=tf.reduce_sum(-tf.log(tf.clip_by_value(logits, 1e-10, 1.0))*onehot_labels*class_weights) 84 | MASK = tf.reduce_sum(1-annotations_ohe[:,:,:,0])#calculation the pixel number of the meaningful classes. 
85 | return a/MASK
86 |
87 |
88 | #The first augmentation used a maximum scale of 1, then brightness 0.1
89 | def decode(a,b):
90 | a = tf.read_file(a)
91 | a=tf.image.decode_png(a, channels=3)
92 | a=tf.cast(a,dtype=tf.float32)
93 | b = tf.read_file(b)
94 | b = tf.image.decode_png(b,channels=1)
95 | #random scale
96 | scale = tf.random_uniform([1],minval=0.75,maxval=1.25,dtype=tf.float32)
97 | hi=tf.floor(scale*1024)
98 | wi=tf.floor(scale*2048)
99 | s=tf.concat([hi,wi],0)
100 | s=tf.cast(s,dtype=tf.int32)
101 | a=tf.image.resize_images(a, s,method=0,align_corners=True)
102 | b=tf.image.resize_images(b, s,method=1,align_corners=True)
103 | b = tf.image.convert_image_dtype(b, dtype=tf.float32)
104 | #random crop and flip
105 | m=tf.concat([a,b],axis=-1)
106 | m=tf.image.random_crop(m,[image_height,image_width,4])
107 | m=tf.image.random_flip_left_right(m)
108 |
109 | m=tf.split(m,num_or_size_splits=4,axis=-1)
110 | a=tf.concat([m[0],m[1],m[2]],axis=-1)
111 | img=tf.image.convert_image_dtype(a/255,dtype=tf.uint8)
112 | a=a-[123.68,116.779,103.939]
113 | b=m[3]
114 | b = tf.image.convert_image_dtype(b, dtype=tf.uint8)
115 | a.set_shape(shape=(image_height, image_width, 3))
116 | b.set_shape(shape=(image_height, image_width,1))
117 | img.set_shape(shape=(image_height, image_width, 3))
118 | return a,b,img
119 | def decodev(a,b):
120 | a = tf.read_file(a)
121 | a=tf.image.decode_png(a, channels=3)
122 | a=tf.cast(a,dtype=tf.float32)
123 | b = tf.read_file(b)
124 | a = a-[123.68,116.779,103.939]
125 | b = tf.image.decode_png(b,channels=1)
126 | a.set_shape(shape=(1024, 2048, 3))
127 | b.set_shape(shape=(1024, 2048,1))
128 | return a,b
129 | def run():
130 | with tf.Graph().as_default() as graph:
131 | tf.logging.set_verbosity(tf.logging.INFO)
132 | #===================TRAINING BRANCH=======================
133 | #Load the files into one input queue
134 | images = tf.convert_to_tensor(image_files)
135 | annotations = tf.convert_to_tensor(annotation_files)
136 | tdataset = tf.data.Dataset.from_tensor_slices((images,annotations)).map(decode).shuffle(100).batch(batch_size).repeat(num_epochs)
137 | titerator = tdataset.make_initializable_iterator()
138 | images,annotations,realimg = titerator.get_next()
139 | images_val = tf.convert_to_tensor(image_val_files)
140 | annotations_val = tf.convert_to_tensor(annotation_val_files)
141 | vdataset = tf.data.Dataset.from_tensor_slices((images_val,annotations_val)).map(decodev).batch(eval_batch_size).repeat(num_epochs*3)
142 | viterator = vdataset.make_initializable_iterator()
143 | images_val,annotations_val = viterator.get_next()
144 | #perform one-hot-encoding on the ground truth annotation to get same shape as the logits
145 | _, probabilities= swiftnet.swiftnet(images, numclass=num_classes,is_training=True,shape=[image_height,image_width],l2=weight_decay,dropout=0.05,reuse=None)
146 | annotations = tf.reshape(annotations, shape=[-1, image_height, image_width])
147 | #loss function
148 | raw_gt = tf.reshape(annotations, [-1,])
149 | indices = tf.squeeze(tf.where(tf.greater(raw_gt,0)), 1)
150 | gt = tf.cast(tf.gather(raw_gt, indices), tf.int32)
151 | gt = gt-1
152 | gt_one = tf.one_hot(gt, num_classes, axis=-1)
153 | raw_prediction = tf.reshape(probabilities, [-1, num_classes])
154 | prediction = tf.gather(raw_prediction, indices)
155 | var=tf.global_variables()
156 | var1=[v for v in var if v.name.split('/')[0]=='Resnet18' and v.name.split('/')[-2]!='attention' and v.name.split('/')[-2]!='attentionbn'] #base_net parameters
157 | var2=[v for v in var if v not in var1] #added parameters
158 |
annotations_ohe = tf.one_hot(annotations, num_classes+1, axis=-1) 159 | los=weighted_cross_entropy(gt_one, prediction, class_weights,annotations_ohe) 160 | loss=tf.losses.add_loss(los) 161 | total_loss = tf.losses.get_total_loss() 162 | global_step = tf.train.get_or_create_global_step() 163 | #Define your learning rate and optimizer 164 | lr=tf.train.cosine_decay( 165 | learning_rate=initial_learning_rate, 166 | global_step=global_step, 167 | decay_steps=decay_steps, 168 | alpha=2.5e-3) 169 | optimizer1 = tf.train.AdamOptimizer(learning_rate=lr/4, epsilon=epsilon) 170 | optimizer2 = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon) 171 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 172 | updates_op = tf.group(*update_ops) 173 | with tf.control_dependencies([updates_op]): 174 | grads = tf.gradients(total_loss, var1 + var2) 175 | grads1 = grads[:len(var1)] 176 | grads2 = grads[len(var1):] 177 | train_op1 = optimizer1.apply_gradients(zip(grads1, var1)) 178 | train_op2 = optimizer2.apply_gradients(zip(grads2, var2),global_step=global_step) 179 | train_op = tf.group(train_op1, train_op2) 180 | _, probabilities_val= swiftnet.swiftnet(images_val, numclass=num_classes,is_training=None,shape=[1024,2048],l2=None,dropout=0,reuse=True) 181 | raw_gt_v = tf.reshape(tf.reshape(annotations_val, shape=[-1, 1024, 2048]),[-1,]) 182 | indices_v = tf.squeeze(tf.where(tf.greater(raw_gt_v,0)), 1) 183 | gt_v = tf.cast(tf.gather(raw_gt_v, indices_v), tf.int32) 184 | gt_v = gt_v-1 185 | gt_one_v = tf.one_hot(gt_v, num_classes, axis=-1) 186 | raw_prediction_v = tf.argmax(tf.reshape(probabilities_val, [-1, num_classes]),-1) 187 | prediction_v = tf.gather(raw_prediction_v, indices_v) 188 | prediction_ohe_v = tf.one_hot(prediction_v, num_classes, axis=-1) 189 | and_val=gt_one_v*prediction_ohe_v 190 | and_sum=tf.reduce_sum(and_val,[0]) 191 | or_val=tf.to_int32((gt_one_v+prediction_ohe_v)>0.5) 192 | or_sum=tf.reduce_sum(or_val,axis=[0]) 193 | T_sum=tf.reduce_sum(gt_one_v,axis=[0]) 194 | R_sum = tf.reduce_sum(prediction_ohe_v,axis=[0]) 195 | mPrecision=0 196 | mRecall_rate=0 197 | mIoU=0 198 | #Now we need to create a training step function that runs both the train_op, metrics_op and updates the global_step concurrently. 199 | def train_step(sess, train_op, global_step ,loss=total_loss): 200 | #Check the time for each sess run 201 | start_time = time.time() 202 | _,total_loss, global_step_count= sess.run([train_op,loss, global_step ]) 203 | time_elapsed = time.time() - start_time 204 | global_step_count=global_step_count+1 205 | #Run the logging to show some results 206 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed) 207 | 208 | return total_loss 209 | #Now finally create all the summaries you need to monitor and group them into one summary op. 
210 | A = tf.Variable(tf.constant(0.0), dtype=tf.float32)
211 | a=tf.placeholder(shape=[],dtype=tf.float32)
212 | Precision=tf.assign(A, a)
213 | B = tf.Variable(tf.constant(0.0), dtype=tf.float32)
214 | b=tf.placeholder(shape=[],dtype=tf.float32)
215 | Recall=tf.assign(B, b)
216 | C = tf.Variable(tf.constant(0.0), dtype=tf.float32)
217 | c=tf.placeholder(shape=[],dtype=tf.float32)
218 | mIOU=tf.assign(C, c)
219 | predictions = tf.argmax(probabilities, -1)
220 | segmentation_output = tf.cast(tf.reshape((predictions+1)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8)
221 | segmentation_ground_truth = tf.cast(tf.reshape(tf.cast(annotations, dtype=tf.float32)*255/num_classes, shape=[-1, image_height, image_width, 1]),tf.uint8)
222 | tf.summary.scalar('Monitor/Total_Loss', total_loss)
223 | tf.summary.scalar('Monitor/Precision', Precision)
224 | tf.summary.scalar('Monitor/Recall_rate', Recall)
225 | tf.summary.scalar('Monitor/mIoU', mIOU)
226 | tf.summary.scalar('Monitor/learning_rate', lr)
227 | tf.summary.image('Images/original_image', realimg, max_outputs=1)
228 | tf.summary.image('Images/segmentation_output', segmentation_output, max_outputs=1)
229 | tf.summary.image('Images/segmentation_ground_truth', segmentation_ground_truth, max_outputs=1)
230 | my_summary_op = tf.summary.merge_all()
231 |
232 | def train_sum(sess, train_op, global_step,sums,loss=total_loss,pre=0,recall=0,iou=0):
233 | start_time = time.time()
234 | _,total_loss, global_step_count,ss = sess.run([train_op,loss, global_step,sums ],feed_dict={a:pre,b:recall,c:iou})
235 | time_elapsed = time.time() - start_time
236 | global_step_count=global_step_count+1
237 | logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
238 |
239 | return total_loss,ss
240 |
241 | def eval_step(sess,i ):
242 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = sess.run([and_sum,or_sum,T_sum,R_sum])
243 | #Log some information
244 | logging.info('STEP: %d ',i)
245 | return and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch
246 | def eval(num_class,csvname,session,image_val,eval_batch):
247 | or_=np.zeros((num_class), dtype=np.float32)
248 | and_=np.zeros((num_class), dtype=np.float32)
249 | T_=np.zeros((num_class), dtype=np.float32)
250 | R_=np.zeros((num_class), dtype=np.float32)
251 | for i in range(math.ceil(len(image_val) / eval_batch)):
252 | and_eval_batch,or_eval_batch,T_eval_batch,R_eval_batch = eval_step(session,i+1)
253 | and_=and_+and_eval_batch
254 | or_=or_+or_eval_batch
255 | T_=T_+T_eval_batch
256 | R_=R_+R_eval_batch
257 | Recall_rate=and_/T_
258 | Precision=and_/R_
259 | IoU=and_/or_
260 | mPrecision=np.mean(Precision)
261 | mRecall_rate=np.mean(Recall_rate)
262 | mIoU=np.mean(IoU)
263 | print("Precision:")
264 | print(Precision)
265 | print("Recall rate:")
266 | print(Recall_rate)
267 | print("IoU:")
268 | print(IoU)
269 | print("mPrecision:")
270 | print(mPrecision)
271 | print("mRecall_rate:")
272 | print(mRecall_rate)
273 | print("mIoU:")
274 | print(mIoU)
275 | with open(csvname,'a', newline='') as out:
276 | csv_write = csv.writer(out,dialect='excel')
277 | csv_write.writerow(Precision)
278 | csv_write.writerow(Recall_rate)
279 | csv_write.writerow(IoU)
280 | return mPrecision,mRecall_rate,mIoU
281 |
282 | gpu_options = tf.GPUOptions(allow_growth=True)
283 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
284 | init = tf.global_variables_initializer()
285 | saver=tf.train.Saver(var_list=tf.global_variables(),max_to_keep=10)
286 | with tf.Session(config=config) as sess:
287 | sess.run(init)
288 | sess.run([titerator.initializer,viterator.initializer])
289 | swiftnet.load_weight(sess,'imgnet_resnet18.npz',var1)#load the base network's pretrained parameters.
290 | step = 0
291 | if not Start_train:
292 | #set the checkpoint path and the step number to resume from.
293 | checkpoint='./log/swiftnet/model.ckpt-37127'
294 | saver.restore(sess, checkpoint)
295 | step = 37127
296 | sess.run(tf.assign(global_step,step))
297 | summary_writer = tf.summary.FileWriter(logdir, sess.graph)
298 | final = num_steps_per_epoch * num_epochs
299 | for i in range(step,final,1):
300 | if i % num_batches_per_epoch == 0:
301 | logging.info('Epoch %s/%s', i//num_batches_per_epoch + 1, num_epochs)
302 | learning_rate_value = sess.run([lr])
303 | logging.info('Current Learning Rate: %s', learning_rate_value)
304 | if i != step:
305 | saver.save(sess, os.path.join(logdir,log_name),global_step=i)
306 | mPrecision,mRecall_rate,mIoU=eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size)
307 | if i % min(num_steps_per_epoch, 10) == 0:
308 | loss,summaries = train_sum(sess, train_op,global_step,sums=my_summary_op,loss=total_loss,pre=mPrecision,recall=mRecall_rate,iou=mIoU)
309 | summary_writer.add_summary(summaries,global_step=i+1)
310 | else:
311 | loss = train_step(sess, train_op, global_step)
312 | summary_writer.close()
313 | eval(num_class=num_classes,csvname=csvname,session=sess,image_val=image_val_files,eval_batch=eval_batch_size)
314 | logging.info('Final Loss: %s', loss)
315 | logging.info('Finished training! Saving model to disk now.')
316 | saver.save(sess, os.path.join(logdir,log_name), global_step = final)
317 |
318 |
319 | if __name__ == '__main__':
320 | run()
321 |
--------------------------------------------------------------------------------
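Usage note (illustrative sketch, not part of the repository): the snippet below shows single-image inference against `swiftnet/swiftnet.py`. It assumes TensorFlow 1.12, that it runs from the `swiftnet` directory so `import swiftnet` resolves, that a trained checkpoint exists (the `model.ckpt-XXXXX` path is a placeholder for a real step number), and that passing `is_training=False` behaves like the `is_training=None` used by the validation branch in `train.py`.
```python
import numpy as np
import tensorflow as tf
import swiftnet

# Build the inference graph at full Cityscapes resolution.
image = tf.placeholder(tf.float32, shape=[1, 1024, 2048, 3])
inputs = image - [123.68, 116.779, 103.939]  # same mean subtraction as decodev() in train.py
_, probabilities = swiftnet.swiftnet(inputs, numclass=19, is_training=False,
                                     shape=[1024, 2048], l2=None, dropout=0, reuse=None)
predictions = tf.argmax(probabilities, -1)   # per-pixel class ids in [0, 19)

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './log/swiftnet/model.ckpt-XXXXX')  # placeholder checkpoint path
    img = np.zeros((1, 1024, 2048, 3), dtype=np.float32)    # replace with a real RGB image
    pred = sess.run(predictions, feed_dict={image: img})
```
Because the training script builds the model under the same top-level scopes ('Resnet18', 'spp', 'up1'..'up3', 'class'), a plain `tf.train.Saver()` over this inference graph should find matching variable names in the training checkpoint; optimizer slots and summary variables stored in the checkpoint are simply ignored.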