├── README.md └── ShuffleNetV2 ├── model ├── checkpoint ├── model.ckpt.data-00000-of-00001 ├── model.ckpt.index ├── model.ckpt.meta ├── model_1.ckpt.data-00000-of-00001 ├── model_1.ckpt.index ├── model_1.ckpt.meta └── tensorboard │ ├── events.out.tfevents.1536153204.DESKTOP-BNNBN92 │ └── events.out.tfevents.1536336490.DESKTOP-BNNBN92 ├── rst └── rst.txt └── src ├── main.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | # ShuffleNetV2_tensorflow 2 | A TensorFlow-based implementation of ShuffleNetV2 on the Tiny ImageNet dataset 3 | 4 | ## File Organization 5 | ShuffleNetV2 6 | |---data: tiny imagenet dataset 7 | |---|---test 8 | |---|---train 9 | |---|---val 10 | |---model: save checkpoint file 11 | |---|---tensorboard: save tensorboard file 12 | |---src: source codes 13 | |---|---main.py: data load, model training and test functions 14 | |---|---model.py: ShuffleNetV2 model 15 | |---rst: result on the test set 16 | ## Environment 17 | Win10, anaconda 1.8.7, python 3.6.5, tensorflow 1.8. 18 | ## Description 19 | The model uses the 0.5x width configuration. 20 | I trained the model from scratch on the [Tiny ImageNet dataset](http://tiny-imagenet.herokuapp.com/). 21 | I trained for two rounds. 22 | On the first round, the learning rate was set to 0.5 and tricks such as warm-up and exponential decay were used. 23 | It trained for 150 epochs. The accuracy on the validation set reached above 80%. 24 | On the second round, the learning rate was set to 0.0005 and exponential decay was used. 25 | It trained for 50 epochs. The accuracy on the validation set reached about 90%. 26 | I tested on the test dataset, but I failed to get the accuracy because the website crashed after I uploaded my result. 
27 | -------------------------------------------------------------------------------- /ShuffleNetV2/model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model_1.ckpt" 2 | all_model_checkpoint_paths: "model_1.ckpt" 3 | -------------------------------------------------------------------------------- /ShuffleNetV2/model/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /ShuffleNetV2/model/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/model.ckpt.index -------------------------------------------------------------------------------- /ShuffleNetV2/model/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/model.ckpt.meta -------------------------------------------------------------------------------- /ShuffleNetV2/model/model_1.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/model_1.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /ShuffleNetV2/model/model_1.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/model_1.ckpt.index -------------------------------------------------------------------------------- /ShuffleNetV2/model/model_1.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/model_1.ckpt.meta -------------------------------------------------------------------------------- /ShuffleNetV2/model/tensorboard/events.out.tfevents.1536153204.DESKTOP-BNNBN92: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/tensorboard/events.out.tfevents.1536153204.DESKTOP-BNNBN92 -------------------------------------------------------------------------------- /ShuffleNetV2/model/tensorboard/events.out.tfevents.1536336490.DESKTOP-BNNBN92: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monkeyCv/ShuffleNetV2_tensorflow/e8922669d1a0a538c9065885d295cf05a0a41136/ShuffleNetV2/model/tensorboard/events.out.tfevents.1536336490.DESKTOP-BNNBN92 -------------------------------------------------------------------------------- /ShuffleNetV2/src/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Aug 21 20:32:19 2018 4 | 5 | @author: Yuxi1989 6 | """ 7 | import tensorflow as tf 8 | from model import model 9 | import os 10 | import random 11 | 12 | def load_data(): 13 | #dir param 14 | train_dir=os.path.join('..','data','train') 15 | test_dir=os.path.join('..','data','test') 16 | val_dir=os.path.join('..','data','val') 17 | 18 | train_imgs=[] 19 | train_labels=[] 20 | val_imgs=[] 21 | 
val_labels=[] 22 | test_imgs=[] 23 | label_correspond={} 24 | 25 | for i,label in enumerate(os.listdir(train_dir)): 26 | label_correspond[label]=i 27 | img_dir=os.path.join(train_dir,label,'images') 28 | for img in os.listdir(img_dir): 29 | train_imgs.append(os.path.join(img_dir,img)) 30 | train_labels.append(i) 31 | 32 | with open(os.path.join(val_dir,'val_annotations.txt')) as file: 33 | file_labels=[] 34 | for line in file.readlines(): 35 | strings=line.split('\t') 36 | file_labels.append(label_correspond[strings[1]]) 37 | for img in sorted(os.listdir(os.path.join(val_dir,'images'))): 38 | val_imgs.append(os.path.join(val_dir,'images',img)) 39 | img_idx=int(img.split('_')[1].split('.')[0]) 40 | val_labels.append(file_labels[img_idx]) 41 | 42 | for img in sorted(os.listdir(os.path.join(test_dir,'images'))): 43 | test_imgs.append(os.path.join(test_dir,'images',img)) 44 | 45 | train=list(zip(train_imgs,train_labels)) 46 | random.shuffle(train) 47 | train_imgs[:],train_labels[:]=zip(*train ) 48 | return train_imgs,train_labels,val_imgs,val_labels, test_imgs,label_correspond 49 | 50 | def train(train_imgs,train_labels,val_imgs,val_labels): 51 | has_train=True 52 | TB_LOG_DIR=os.path.join('..','model') 53 | ckpt = tf.train.get_checkpoint_state(TB_LOG_DIR) 54 | if not ckpt and not ckpt.model_checkpoint_path: 55 | has_train=False 56 | if has_train==False: 57 | #dataset param 58 | EPOCHS=150 59 | SHUFFLE_SZ=1000 60 | BATCH_SZ=200 61 | #model param 62 | OUTPUT_CNS=[24,48,96,192,1024] 63 | CLASS_NUM=200 64 | WEIGHT_DECAY=4e-5 65 | #training param 66 | WARM_UP_LR=0.002 67 | LEARNING_RATE=0.5 68 | LEARNING_RATE_DECAY=0.95 69 | TOTAL_STEPS=EPOCHS*100000//BATCH_SZ 70 | LEARNING_RATE_STEPS=TOTAL_STEPS//100 71 | MOMENTUM=0.9 72 | #display 73 | DISPLAY_STEP=TOTAL_STEPS//100 74 | TB_LOG_DIR=os.path.join('..','model') 75 | #validation 76 | VAL_SZ=10000 77 | else: 78 | #dataset param 79 | EPOCHS=50 80 | SHUFFLE_SZ=1000 81 | BATCH_SZ=200 82 | #model param 83 | 
OUTPUT_CNS=[24,48,96,192,1024] 84 | CLASS_NUM=200 85 | WEIGHT_DECAY=4e-5 86 | #training param 87 | WARM_UP_LR=0.0005 88 | LEARNING_RATE=0.0005 89 | LEARNING_RATE_DECAY=0.9 90 | TOTAL_STEPS=EPOCHS*100000//BATCH_SZ 91 | LEARNING_RATE_STEPS=TOTAL_STEPS//100 92 | MOMENTUM=0.9 93 | #display 94 | DISPLAY_STEP=TOTAL_STEPS//100 95 | TB_LOG_DIR=os.path.join('..','model') 96 | #validation 97 | VAL_SZ=10000 98 | 99 | 100 | imgpaths=tf.convert_to_tensor(train_imgs) 101 | labels=tf.convert_to_tensor(train_labels) 102 | valimgpaths=tf.convert_to_tensor(val_imgs) 103 | vallabels=tf.convert_to_tensor(val_labels) 104 | 105 | #sess=tf.Session() 106 | def _parse_function(imgpath,label): 107 | img=tf.read_file(imgpath) 108 | img_decoded=tf.image.decode_jpeg(img,3) 109 | img_decoded.set_shape([64,64,3]) 110 | img_decoded=tf.cast(img_decoded,dtype=tf.float32) 111 | return img_decoded,label 112 | dataset=tf.data.Dataset.from_tensor_slices((imgpaths,labels)).map(_parse_function) 113 | dataset=dataset.shuffle(buffer_size=SHUFFLE_SZ) 114 | dataset=dataset.repeat(EPOCHS) 115 | dataset=dataset.batch(BATCH_SZ) 116 | iterator=dataset.make_initializable_iterator() 117 | batch_imgs,batch_labels=iterator.get_next() 118 | 119 | valset=tf.data.Dataset.from_tensor_slices((valimgpaths,vallabels)).map(_parse_function) 120 | valset=valset.batch(VAL_SZ) 121 | valiterator=dataset.make_initializable_iterator() 122 | valbatch_imgs,valbatch_labels=valiterator.get_next() 123 | #dimgs,dlabels=sess.run([batch_imgs,batch_labels]) 124 | 125 | initial=tf.variance_scaling_initializer() 126 | regular=tf.contrib.layers.l2_regularizer(1.0) 127 | logits=model(batch_imgs,OUTPUT_CNS,CLASS_NUM,True,regular,initial) 128 | with tf.name_scope('loss'): 129 | loss=tf.reduce_mean( 130 | tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=batch_labels)) 131 | reg=tf.losses.get_regularization_loss() 132 | loss+=WEIGHT_DECAY*reg 133 | with tf.name_scope('train'): 134 | 
global_step=tf.get_variable('step',shape=[],trainable=False, 135 | initializer=tf.zeros_initializer(dtype=tf.int64)) 136 | def get_lr(global_step,total_step,base_lr,warm_up_lr): 137 | warm_up_total_step=total_step//20 138 | transition_total_step=warm_up_total_step 139 | remain_total_step=total_step-warm_up_total_step-transition_total_step 140 | transition_dlrt=tf.convert_to_tensor((1.0*base_lr-warm_up_lr)/transition_total_step,dtype=tf.float32) 141 | base_lrt=tf.convert_to_tensor(base_lr,dtype=tf.float32) 142 | warm_up_lrt=tf.convert_to_tensor(warm_up_lr,dtype=tf.float32) 143 | warm_up_total_step=tf.convert_to_tensor(warm_up_total_step,dtype=tf.float32) 144 | transition_total_step=tf.convert_to_tensor(transition_total_step,dtype=tf.float32) 145 | remain_total_step=tf.convert_to_tensor(remain_total_step,dtype=tf.float32) 146 | transition_lr=(tf.cast(global_step,tf.float32)-warm_up_total_step)*transition_dlrt+warm_up_lrt 147 | remain_lr=tf.train.exponential_decay(base_lrt,tf.cast(global_step,tf.float32)-warm_up_total_step-transition_total_step, 148 | remain_total_step//120 ,LEARNING_RATE_DECAY) 149 | lr=tf.case({tf.less(global_step,warm_up_total_step): lambda:warm_up_lrt, 150 | tf.greater(global_step,transition_total_step+warm_up_total_step): lambda:remain_lr}, 151 | default=lambda:transition_lr,exclusive=True) 152 | return lr 153 | if has_train==False: 154 | learning_rate=get_lr(global_step,TOTAL_STEPS,LEARNING_RATE,WARM_UP_LR) 155 | else: 156 | learning_rate=tf.train.exponential_decay(LEARNING_RATE,global_step,LEARNING_RATE_STEPS,LEARNING_RATE_DECAY) 157 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 158 | with tf.control_dependencies(update_ops): 159 | train_op=tf.train.MomentumOptimizer(learning_rate=learning_rate, 160 | momentum=MOMENTUM).minimize(loss) 161 | with tf.control_dependencies([train_op]): 162 | global_step_update=tf.assign_add(global_step,1) 163 | 164 | if has_train==False: 165 | init=tf.global_variables_initializer() 166 | 167 | with 
tf.name_scope('batch_train_accuracy'): 168 | logits_train=model(batch_imgs,OUTPUT_CNS,CLASS_NUM,False,regular,initial) 169 | correct_pred_train=tf.equal(tf.cast(tf.argmax(logits_train,1),dtype=tf.int32),batch_labels) 170 | accuracy_train=tf.reduce_mean(tf.cast(correct_pred_train,tf.float32)) 171 | 172 | with tf.name_scope('val_accuracy'): 173 | logits_val=model(valbatch_imgs,OUTPUT_CNS,CLASS_NUM,False,regular,initial) 174 | correct_pred_val=tf.equal(tf.cast(tf.argmax(logits_val,1),dtype=tf.int32),valbatch_labels) 175 | accuracy_val=tf.reduce_mean(tf.cast(correct_pred_val,tf.float32)) 176 | 177 | sess=tf.Session() 178 | if has_train==False: 179 | sess.run(init) 180 | else: 181 | saver=tf.train.Saver() 182 | saver.restore(sess,ckpt.model_checkpoint_path) 183 | sess.run(iterator.initializer) 184 | 185 | tf.summary.scalar('loss',loss) 186 | tf.summary.scalar('batch_train_accuracy',accuracy_train) 187 | tf.summary.scalar('val_accuracy',accuracy_val) 188 | tf.summary.scalar('learning_rate',learning_rate) 189 | tb_merge_summary_op=tf.summary.merge_all() 190 | summary_writer=tf.summary.FileWriter(os.path.join(TB_LOG_DIR,'tensorboard'),graph=sess.graph) 191 | 192 | saver=tf.train.Saver() 193 | 194 | sess.run(tf.assign(global_step,0.0)) 195 | for step in range(1,TOTAL_STEPS+1): 196 | try: 197 | #_,print_step=sess.run(train_op) 198 | sess.run(global_step_update) 199 | except tf.errors.OutOfRangeError: 200 | break 201 | if step%DISPLAY_STEP==0 or step==1: 202 | sess.run(valiterator.initializer) 203 | l,acct,accv,lr,summary_str=sess.run([loss,accuracy_train,accuracy_val,learning_rate,tb_merge_summary_op]) 204 | summary_writer.add_summary(summary_str,step) 205 | print("epoch {:d} steps {:d}: loss={:.4f}, accuracy_batch_train={:.4f}, accuracy_val={:.4f}, learning_rate={:.5f}".format( 206 | step//(TOTAL_STEPS//EPOCHS),step,l,acct,accv,lr)) 207 | 208 | summary_writer.close() 209 | saver.save(sess,os.path.join(TB_LOG_DIR,'model_1.ckpt')) 210 | 211 | def test(test_imgs): 212 | 
TB_LOG_DIR=os.path.join('..','model') 213 | ckpt = tf.train.get_checkpoint_state(TB_LOG_DIR) 214 | if not ckpt and not ckpt.model_checkpoint_path: 215 | print("No model! Please train the model first!") 216 | return 217 | 218 | imgpaths=tf.convert_to_tensor(test_imgs) 219 | OUTPUT_CNS=[24,48,96,192,1024] 220 | CLASS_NUM=200 221 | BATCH_SZ=10000 222 | 223 | def _parse_function(imgpath): 224 | img=tf.read_file(imgpath) 225 | img_decoded=tf.image.decode_jpeg(img,3) 226 | img_decoded.set_shape([64,64,3]) 227 | img_decoded=tf.cast(img_decoded,dtype=tf.float32) 228 | return img_decoded 229 | dataset=tf.data.Dataset.from_tensor_slices(imgpaths).map(_parse_function) 230 | dataset=dataset.batch(BATCH_SZ) 231 | iterator=dataset.make_one_shot_iterator() 232 | batch_imgs=iterator.get_next() 233 | initial=tf.variance_scaling_initializer() 234 | regular=tf.contrib.layers.l2_regularizer(1.0) 235 | model(batch_imgs,OUTPUT_CNS,CLASS_NUM,True,regular,initial) 236 | logits_test=model(batch_imgs,OUTPUT_CNS,CLASS_NUM,False,regular,initial) 237 | pred=tf.cast(tf.argmax(logits_test,1),dtype=tf.int32) 238 | 239 | sess=tf.Session() 240 | saver=tf.train.Saver() 241 | saver.restore(sess,ckpt.model_checkpoint_path) 242 | prediction=sess.run(pred) 243 | return prediction 244 | 245 | def parse_pred(pred,label_correspond,test_imgs): 246 | imgs=[x.split('\\')[-1] for x in test_imgs] 247 | corrs={v:k for k,v in label_correspond.items()} 248 | labels=[corrs[x] for x in pred] 249 | num=len(labels) 250 | with open('../rst/rst.txt','w') as file: 251 | for idx in range(num): 252 | file.write('{!s} {!s}\n'.format(imgs[idx],labels[idx])) 253 | 254 | if __name__=='__main__': 255 | tf.reset_default_graph() 256 | train_imgs,train_labels,val_imgs,val_labels, test_imgs,label_correspond = load_data() 257 | #train(train_imgs,train_labels,val_imgs,val_labels) 258 | pred=test(test_imgs) 259 | parse_pred(pred,label_correspond,test_imgs) 260 | 261 | 262 | 263 | 264 | 265 | 
-------------------------------------------------------------------------------- /ShuffleNetV2/src/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 20 20:07:23 2018 4 | 5 | @author: Yuxi1989 6 | """ 7 | 8 | import tensorflow as tf 9 | 10 | def unit_c(l,cns,is_training,regular,initial,idx):#NHWC 11 | with tf.variable_scope('unit_c{}'.format(idx)): 12 | left,right=tf.split(l,2,axis=3) 13 | s=left.get_shape().as_list() 14 | H=s[1] 15 | W=s[2] 16 | left_c=s[3] 17 | conv1=tf.layers.conv2d(right,cns//2,1,padding='SAME', 18 | kernel_initializer=initial, 19 | kernel_regularizer=regular) 20 | bn1=tf.layers.batch_normalization(conv1,training=is_training) 21 | relu1=tf.nn.relu(bn1) 22 | depth_conv_kernel=tf.get_variable('dconv',shape=[3,3,cns//2,1],initializer=initial,regularizer=regular) 23 | conv2=tf.nn.depthwise_conv2d(relu1,filter=depth_conv_kernel, 24 | strides=[1,1,1,1],padding='SAME') 25 | bn2=tf.layers.batch_normalization(conv2,training=is_training) 26 | conv3=tf.layers.conv2d(bn2,cns-left_c,1,padding='SAME', 27 | kernel_initializer=initial, 28 | kernel_regularizer=regular) 29 | bn3=tf.layers.batch_normalization(conv3,training=is_training) 30 | relu3=tf.nn.relu(bn3) 31 | out=tf.concat([left,relu3],axis=3) 32 | out=tf.reshape(out,shape=[-1,H,W,cns//2,2])#shuffle 33 | out=tf.transpose(out,perm=[0,1,2,4,3]) 34 | out=tf.reshape(out,shape=[-1,H,W,cns]) 35 | return out 36 | 37 | def unit_d(l,cns,is_training,regular,initial): 38 | with tf.variable_scope('unit_d'): 39 | left,right=l,l 40 | s=l.get_shape().as_list() 41 | H=s[1]//2 42 | W=s[2]//2 43 | depth_conv_kernel1=tf.get_variable('dconv1',shape=[3,3,cns//2,1],initializer=initial,regularizer=regular) 44 | lconv1=tf.nn.depthwise_conv2d(left,filter=depth_conv_kernel1, 45 | strides=[1,2,2,1],padding='SAME') 46 | lbn1=tf.layers.batch_normalization(lconv1,training=is_training) 47 | 
lconv2=tf.layers.conv2d(lbn1,cns//2,1,padding='SAME', 48 | kernel_initializer=initial, 49 | kernel_regularizer=regular) 50 | lbn2=tf.layers.batch_normalization(lconv2,training=is_training) 51 | lrelu2=tf.nn.relu(lbn2) 52 | rconv1=tf.layers.conv2d(right,cns//2,1,padding='SAME', 53 | kernel_initializer=initial, 54 | kernel_regularizer=regular) 55 | rbn1=tf.layers.batch_normalization(rconv1,training=is_training) 56 | rrelu1=tf.nn.relu(rbn1) 57 | depth_conv_kernel2=tf.get_variable('dconv2',shape=[3,3,cns//2,1],initializer=initial,regularizer=regular) 58 | rconv2=tf.nn.depthwise_conv2d(rrelu1,filter=depth_conv_kernel2, 59 | strides=[1,2,2,1],padding='SAME') 60 | rbn2=tf.layers.batch_normalization(rconv2,training=is_training) 61 | rconv3=tf.layers.conv2d(rbn2,cns//2,1,padding='SAME', 62 | kernel_initializer=initial, 63 | kernel_regularizer=regular) 64 | rbn3=tf.layers.batch_normalization(rconv3,training=is_training) 65 | rrelu3=tf.nn.relu(rbn3) 66 | out=tf.concat([lrelu2,rrelu3],axis=3) 67 | out=tf.reshape(out,shape=[-1,H,W,cns//2,2]) 68 | out=tf.transpose(out,perm=[0,1,2,4,3]) 69 | out=tf.reshape(out,shape=[-1,H,W,cns]) 70 | return out 71 | 72 | def stage_1(img,cns,regular,initial): 73 | with tf.variable_scope('stage_1'): 74 | conv=tf.layers.conv2d(img,cns,3,strides=2,padding='SAME', 75 | kernel_initializer=initial, 76 | kernel_regularizer=regular) 77 | pool=tf.layers.max_pooling2d(conv,3,strides=2,padding='SAME') 78 | return pool 79 | 80 | def stage_2(l,cns,is_training,regular,initial): 81 | with tf.variable_scope('stage_2'): 82 | l=unit_d(l,cns,is_training,regular,initial) 83 | for i in range(3): 84 | l=unit_c(l,cns,is_training,regular,initial,i) 85 | return l 86 | 87 | def stage_3(l,cns,is_training,regular,initial): 88 | with tf.variable_scope('stage_3'): 89 | l=unit_d(l,cns,is_training,regular,initial) 90 | for i in range(7): 91 | l=unit_c(l,cns,is_training,regular,initial,i) 92 | return l 93 | 94 | def stage_4(l,cns,is_training,regular,initial): 95 | with 
tf.variable_scope('stage_4'): 96 | l=unit_d(l,cns,is_training,regular,initial) 97 | for i in range(3): 98 | l=unit_c(l,cns,is_training,regular,initial,i) 99 | return l 100 | 101 | def stage_5(l,cns,cls,regular,initial): 102 | with tf.variable_scope('stage_5'): 103 | conv=tf.layers.conv2d(l,cns,1,padding='SAME', 104 | kernel_initializer=initial, 105 | kernel_regularizer=regular) 106 | pool=tf.reduce_mean(conv,axis=[1,2]) 107 | fc=tf.layers.dense(pool,cls) 108 | return fc 109 | 110 | def model(img,cns,cls,is_training,regular,initial): 111 | with tf.variable_scope('model',reuse=not is_training): 112 | o1=stage_1(img,cns[0],regular,initial) 113 | o2=stage_2(o1,cns[1],is_training,regular,initial) 114 | o3=stage_3(o2,cns[2],is_training,regular,initial) 115 | o4=stage_4(o3,cns[3],is_training,regular,initial) 116 | logits=stage_5(o4,cns[4],cls,regular,initial) 117 | return logits --------------------------------------------------------------------------------