├── .gitignore ├── 000129.jpg ├── 000129.png ├── 000999.jpg ├── 000999.png ├── LICENSE ├── README.md ├── create_labels.py ├── eval.py ├── model_contour.py ├── ops.py ├── requirements .txt ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /000129.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raj-08/tensorflow-object-contour-detection/f73dede38ce2177675bdf9d75f206696f51621d5/000129.jpg -------------------------------------------------------------------------------- /000129.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raj-08/tensorflow-object-contour-detection/f73dede38ce2177675bdf9d75f206696f51621d5/000129.png -------------------------------------------------------------------------------- /000999.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raj-08/tensorflow-object-contour-detection/f73dede38ce2177675bdf9d75f206696f51621d5/000999.jpg -------------------------------------------------------------------------------- /000999.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raj-08/tensorflow-object-contour-detection/f73dede38ce2177675bdf9d75f206696f51621d5/000999.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Raj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-object-contour-detection 2 | 3 | This is a tensorflow implimentation of Object Contour Detection with a Fully Convolutional Encoder-Decoder Network (https://arxiv.org/pdf/1603.04530.pdf) . 4 | 5 | **REQUIREMENTS :** 6 | 7 | ``` 8 | pip install requirements.txt 9 | ``` 10 | **Label Preparation :** 11 | 12 | To prepare the labels for contour detection from PASCAL Dataset , run create_lables.py and edit the file to add the path of the labels and new labels to be generated . Use this path for labels during training. 13 | 14 | **TRAINING :** 15 | 16 | ``` 17 | python train.py \ 18 | --max_to_keep=50 \ 19 | --Epochs=100 \ 20 | --momentum=0.9 \ 21 | --learning_rate=.0000001 \ 22 | --train_crop_size=480 \ 23 | --clip_by_value=1.0 \ 24 | --train_text = ${path to text file} \ 25 | --log_dir = ${path to where logs will be saved} \ 26 | --tf_initial_checkpoint=${PATH_TO_CHECKPOINT} \ 27 | --label_dir = ${path to label directory} \ 28 | --image_dir = ${path to image directory} 29 | ``` 30 | **EVALUATION :** 31 | ``` 32 | python eval.py \ 33 | --checkpoint=${path to checkpoint to be evaluated} \ 34 | --save_preds=${path to folder where predictions will be saved} \ 35 | --image_dir = ${path to image directory} \ 36 | --eval_crop_size=480 \ 37 | --eval_text = ${path to eval text file} 38 | 39 | ``` 40 | **Results :** 41 | 42 | Image_1 43 | 44 | prediction_1 45 | 46 | Image_1 47 | 48 | prediction_1 49 | -------------------------------------------------------------------------------- /create_labels.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | path_to_new_labels = '' 3 | if path_to_new_labels==False: 4 | print('Path to new labels missing') 5 | 6 | alls = os.listdir(path_to_labels) 7 | for al in alls: 8 | GT = cv2.imread(path_to_labels+al,0) 9 | GT[GT!=255]=0 10 | cl = cv2.imwrite(path_to_new_labels+al,GT) 11 | print('Label written') 12 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from model_contour import build_model 4 | slim = tf.contrib.slim 5 | flags = tf.app.flags 6 | from tensorflow.python.ops import variables 7 | from utils import random_crop_and_pad_image 8 | import os 9 | import cv2 10 | import time 11 | 12 | FLAGS = flags.FLAGS 13 | flags.DEFINE_integer('eval_crop_size', 480, 14 | 'Image crop size [height, width] for evaluation.') 15 | 16 | flags.DEFINE_string('checkpoint', None, 17 | 'The initial checkpoint in tensorflow format.') 18 | 19 | 20 | flags.DEFINE_string('image_dir', None, 21 | 'The Image Directory.') 22 | 23 | flags.DEFINE_string('save_preds',None, 24 | 'Path to folder where predictions will be saved.') 25 | 26 | flags.DEFINE_string('eval_text', None, 27 | 'The Path to the text file containing names of Images and Labels')###This text file should not have extensions in their names such as 8192.png or 8192.jpg instead just the name such as 8192 28 | 29 | 30 | 31 | Image_directory = '/home/ubuntu/gfav/deeplab/tensorflow_deeplab_resnet/data/pascal/VOCdevkit/VOC2007/JPEGImages/' 32 | my_log_dir='./logs' 33 | 34 | def load(saver, sess, ckpt_path): 35 | saver.restore(sess, ckpt_path) 36 | print("Restored model parameters from {}".format(ckpt_path)) 37 | 38 | image_ph = tf.placeholder(tf.uint8,[1,None,None,3],name='image_placeholder') 39 | size = FLAGS.train_crop_size 40 | image,label=random_crop_and_pad_image(tf.squeeze(image_ph),size,size) 41 | norm_image = tf.image.per_image_standardization(tf.squeeze(image)) 42 | norm_image = tf.expand_dims(norm_image,dim=0) 43 | pred = build_model(norm_image) 44 | restore_var = tf.trainable_variables() 45 | pred = tf.nn.sigmoid(pred) 46 | loader = tf.train.Saver(var_list=restore_var) 47 | 48 | init = variables.global_variables_initializer() 49 | with tf.Session() as sess: 50 | sess.run(init) 51 | load(loader, sess, FLAGS.checkpoint) 52 | f = open(FLAGS.eval_text,'r') 53 | message = f.read() 54 | lines = message.split('\n') 55 | for l in lines: 56 | try : 57 | input_image = cv2.imread(Image_directory+l+'.jpg') 58 | feed_dict={image_ph:input_image} 59 | P= sess.run(pred, feed_dict=feed_dict) 60 | np.save(save_preds+l,P) 61 | except: 62 | print("ERROR") -------------------------------------------------------------------------------- /model_contour.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ops import conv,batch_normalization,unpool_with_argmax,sigmoid,relu 3 | is_training = True 4 | def build_model(input_img): 5 | 6 | conv1=conv(input_img,7, 7, 64, 1, 1, biased=False, relu=False, name='conv1') 7 | bn_conv1=batch_normalization(conv1,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv1') 8 | conv2=conv(bn_conv1,3, 3, 64, 1, 1, biased=False, relu=False, name='conv2') 9 | bn_conv2=batch_normalization(conv2,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv2') 10 | pool1, pool1_indices = tf.nn.max_pool_with_argmax(bn_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1') 11 | print(pool1) 12 | print('INd',pool1_indices) 13 | 14 | conv3=conv(pool1,3, 3,128, 1, 1, biased=False, relu=False, name='conv3') 15 | bn_conv3=batch_normalization(conv3,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv3') 16 | conv4=conv(bn_conv3,3, 3, 128, 1, 1, biased=False, relu=False, name='conv4') 17 | bn_conv4=batch_normalization(conv4,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv4') 18 | pool2, pool2_indices = tf.nn.max_pool_with_argmax(bn_conv4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool2') 19 | print(pool2) 20 | print('INd',pool2_indices) 21 | conv5=conv(pool2,3, 3, 256, 1, 1, biased=False, relu=False, name='conv5') 22 | bn_conv5=batch_normalization(conv5,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv5') 23 | conv6=conv(bn_conv5,3, 3, 256, 1, 1, biased=False, relu=False, name='conv6') 24 | bn_conv6=batch_normalization(conv6,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv6') 25 | conv7=conv(bn_conv6,3, 3, 256, 1, 1, biased=False, relu=False, name='conv7') 26 | bn_conv7=batch_normalization(conv7,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv7') 27 | pool3, pool3_indices = tf.nn.max_pool_with_argmax(bn_conv7, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool3') 28 | conv8=conv(pool3,3, 3, 512, 1, 1, biased=False, relu=False, name='conv8') 29 | bn_conv8=batch_normalization(conv8,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv8') 30 | conv9=conv(bn_conv8,3, 3, 512, 1, 1, biased=False, relu=False, name='conv9') 31 | bn_conv9=batch_normalization(conv9,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv9') 32 | conv10=conv(bn_conv9,3, 3, 512, 1, 1, biased=False, relu=False, name='conv10') 33 | bn_conv10=batch_normalization(conv10,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv10') 34 | pool4, pool4_indices = tf.nn.max_pool_with_argmax(bn_conv10, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool4') 35 | conv11=conv(pool4,3, 3, 512, 1, 1, biased=False, relu=False, name='conv11') 36 | bn_conv11=batch_normalization(conv11,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv11') 37 | conv12=conv(bn_conv11,3, 3, 512, 1, 1, biased=False, relu=False, name='conv12') 38 | bn_conv12=batch_normalization(conv12,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv12') 39 | conv13=conv(bn_conv12,3, 3, 512, 1, 1, biased=False, relu=False, name='conv13') 40 | bn_conv13=batch_normalization(conv13,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv13') 41 | pool5, pool5_indices = tf.nn.max_pool_with_argmax(bn_conv13, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool5') 42 | conv14=conv(pool5,3, 3, 4096, 1, 1, biased=False, relu=False, name='conv14') 43 | bn_conv14=batch_normalization(conv14,is_training=is_training, activation_fn=tf.nn.relu, name='bn_conv14') 44 | deconv6=conv(bn_conv14,1, 1, 512, 1, 1, biased=False, relu=False, name='conv15') 45 | bn_deconv6=batch_normalization(deconv6,is_training=is_training, activation_fn=tf.nn.relu, name='bn_deconv6') 46 | #print(bn_deconv6) 47 | unpool_5 = unpool_with_argmax(bn_deconv6, ind=pool5_indices, name="unpool_5") 48 | deconv5=conv(unpool_5,5, 5, 512, 1, 1, biased=False, relu=False, name='deconv5') 49 | bn_deconv5=batch_normalization(deconv5,is_training=is_training, activation_fn=tf.nn.relu, name='bn_deconv5') 50 | 51 | #print(bn_deconv5) 52 | unpool_4 = unpool_with_argmax(bn_deconv5, ind=pool4_indices, name="unpool_4") 53 | deconv4=conv(unpool_4,5, 5, 256, 1, 1, biased=False, relu=False, name='deconv4') 54 | bn_deconv4=batch_normalization(deconv4,is_training=is_training, activation_fn=tf.nn.relu, name='bn_deconv4') 55 | unpool_3 = unpool_with_argmax(bn_deconv4, ind=pool3_indices, name="unpool_3") 56 | deconv3=conv(unpool_3,5, 5,128, 1, 1, biased=False, relu=False, name='deconv3') 57 | bn_deconv3=batch_normalization(deconv3,is_training=is_training, activation_fn=tf.nn.relu, name='bn_decon3') 58 | print(bn_deconv3) 59 | unpool_2 = unpool_with_argmax(bn_deconv3, ind=pool2_indices, name="unpool_2") 60 | deconv2=conv(unpool_2,5, 5,64, 1, 1, biased=False, relu=False, name='deconv2') 61 | bn_deconv2=batch_normalization(deconv2,is_training=is_training, activation_fn=tf.nn.relu, name='bn_deconv2') 62 | print(bn_deconv2) 63 | unpool_1 = unpool_with_argmax(bn_deconv2, ind=pool1_indices, name="unpool_1") 64 | deconv1=conv(unpool_1,5, 5,32, 1, 1, biased=False, relu=False, name='deconv1') 65 | bn_deconv1=batch_normalization(deconv1,is_training=is_training, activation_fn=tf.nn.relu, name='bn_deconv1') 66 | deconv0=conv(bn_deconv1,5, 5, 1, 1, 1, biased=False, relu=False, name='deconv0') 67 | # bn_deconv0=batch_normalization(deconv0,is_training=is_training, activation_fn=tf.nn.relu, name='bn_deconv0') 68 | # pred = sigmoid(bn_deconv0,name='pred') 69 | 70 | return deconv0; 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | 5 | DEFAULT_PADDING = 'SAME' 6 | 7 | def make_var(name, shape,trainable=True): 8 | '''Creates a new TensorFlow variable.''' 9 | return tf.get_variable(name, shape, trainable=trainable) 10 | 11 | def batch_normalization(input, name, is_training, activation_fn=None, scale=True): 12 | with tf.variable_scope(name) as scope: 13 | output = slim.batch_norm( 14 | input, 15 | activation_fn=activation_fn, 16 | is_training=is_training, 17 | updates_collections=None, 18 | scale=scale, 19 | scope=scope) 20 | return output 21 | 22 | def conv(input,k_h,k_w,c_o,s_h,s_w,name,relu=True,padding=DEFAULT_PADDING,group=1,biased=True): 23 | c_i = input.get_shape()[-1] 24 | convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding) 25 | with tf.variable_scope(name) as scope: 26 | kernel = make_var('weights', shape=[k_h, k_w, c_i / group, c_o],trainable=True) 27 | if group == 1: 28 | output = convolve(input, kernel) 29 | else: 30 | input_groups = tf.split(3, group, input) 31 | kernel_groups = tf.split(3, group, kernel) 32 | output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)] 33 | output = tf.concat(3, output_groups) 34 | if biased: 35 | biases = make_var('biases', [c_o]) 36 | output = tf.nn.bias_add(output, biases) 37 | if relu: 38 | output = tf.nn.relu(output, name=scope.name) 39 | return output 40 | 41 | def relu(input): 42 | return tf.nn.relu(input) 43 | 44 | def sigmoid(input, name): 45 | return tf.nn.sigmoid(input, name=name) 46 | 47 | def unpool_with_argmax(pool, ind, name = None, ksize=[1, 2, 2, 1]): 48 | with tf.variable_scope(name): 49 | input_shape = pool.get_shape().as_list() 50 | output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]) 51 | flat_input_size = np.prod(input_shape) 52 | flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]] 53 | pool_ = tf.reshape(pool, [flat_input_size]) 54 | batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1]) 55 | b = tf.ones_like(ind) * batch_range 56 | b = tf.reshape(b, [flat_input_size, 1]) 57 | ind_ = tf.reshape(ind, [flat_input_size, 1]) 58 | ind_ = tf.concat([b, ind_], 1) 59 | ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape) 60 | ret = tf.reshape(ret, output_shape) 61 | return ret 62 | 63 | def dropout(input, name): 64 | #keep = 1 - use_dropout + (use_dropout * keep_prob) 65 | return tf.nn.dropout(input,0.75, name=name) 66 | -------------------------------------------------------------------------------- /requirements .txt: -------------------------------------------------------------------------------- 1 | cv2 2 | tensorflow 3 | numpy 4 | time 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from model_contour import build_model 4 | from utils import random_crop_and_pad_image_and_labels 5 | import os 6 | import cv2 7 | import time 8 | from tensorflow.python.ops import variables 9 | slim = tf.contrib.slim 10 | flags = tf.app.flags 11 | FLAGS = flags.FLAGS 12 | 13 | flags.DEFINE_integer('max_to_keep', 50, 14 | 'Maximium number of checkpoints to be saved.') 15 | 16 | flags.DEFINE_float('learning_power', 0.9, 17 | 'The power value used in the poly learning policy.') 18 | 19 | flags.DEFINE_integer('Epochs', 100, 20 | 'The number of steps used for training') 21 | 22 | flags.DEFINE_float('momentum', 0.9, 'The momentum value to use') 23 | 24 | flags.DEFINE_integer('train_crop_size', 480 , 25 | 'Image crop size [height, width] during training.') 26 | 27 | flags.DEFINE_string('tf_initial_checkpoint', None, 28 | 'The initial checkpoint in tensorflow format.') 29 | 30 | flags.DEFINE_float('learning_rate', .0000001, 31 | 'Learning rate employed during slow start.') 32 | 33 | flags.DEFINE_string('image_dir', None, 34 | 'The Image Directory.') 35 | 36 | flags.DEFINE_string('label_dir', None, 37 | 'The Label Directory.') 38 | 39 | flags.DEFINE_string('log_dir', None, 40 | 'The Logs Directory.') 41 | 42 | flags.DEFINE_float('clip_by_value', 1.0, 'The value to be used for clipping.') 43 | 44 | 45 | flags.DEFINE_string('train_text', None, 46 | 'The Path to the text file containing names of Images and Labels')###This text file should not have extensions in their names such as 8192.png or 8192.jpg instead just the name such as 8192 47 | 48 | Image_directory = FLAGS.image_dir 49 | Label_directory = FLAGS.label_dir 50 | my_log_dir = FLAGS.log_dir 51 | 52 | def save(saver, sess, logdir, step): 53 | 54 | model_name = 'model.ckpt' 55 | checkpoint_path = os.path.join(logdir, model_name) 56 | 57 | if not os.path.exists(logdir): 58 | os.makedirs(logdir) 59 | saver.save(sess, checkpoint_path, global_step=step) 60 | print('The checkpoint has been created.') 61 | 62 | def load(saver, sess, ckpt_path): 63 | saver.restore(sess, ckpt_path) 64 | print("Restored model parameters from {}".format(ckpt_path)) 65 | 66 | def main(unused_argv): 67 | 68 | image_ph = tf.placeholder(tf.uint8,[1,None,None,3],name='image_placeholder') 69 | label_ph = tf.placeholder(tf.uint8,[1,None,None,1],name='label_placeholder') 70 | size = FLAGS.train_crop_size 71 | image,label=random_crop_and_pad_image_and_labels(tf.squeeze(image_ph),tf.squeeze(label_ph,axis=0),size,size) 72 | norm_image = tf.image.per_image_standardization(tf.squeeze(image)) 73 | norm_image = tf.expand_dims(norm_image,dim=0) 74 | print(norm_image) 75 | print(label_ph) 76 | 77 | pred = build_model(norm_image) 78 | one_hot_labels = slim.one_hot_encoding( 79 | tf.cast(label,dtype=tf.uint8), 1, on_value=1.0, off_value=0.0) 80 | 81 | total_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(tf.squeeze(one_hot_labels),dtype=tf.float32),logits =tf.squeeze(pred) ) 82 | total_loss = tf.reduce_sum(total_loss) 83 | all_trainables = tf.trainable_variables() 84 | 85 | total_loss_scalar = tf.summary.scalar("total_cost", total_loss) 86 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=FLAGS.max_to_keep) 87 | train_summary_op = tf.summary.merge([total_loss_scalar]) 88 | train_writer = tf.summary.FileWriter(my_log_dir+'/train', 89 | graph=tf.get_default_graph()) 90 | optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) 91 | grads = tf.gradients(total_loss, all_trainables) 92 | grads_and_vars = zip(grads, all_trainables) 93 | if FLAGS.clip_by_value: 94 | clipped_value=[(tf.clip_by_value(grad, -FLAGS.clip_by_value, +FLAGS.clip_by_value), var) for grad, var in grads_and_vars] 95 | train_op = optimizer.apply_gradients(clipped_value) 96 | else: 97 | train_op = optimizer.apply_gradients(grads_and_vars) 98 | 99 | init = variables.global_variables_initializer() 100 | with tf.Session() as sess: 101 | sess.run(init) 102 | if FLAGS.tf_initial_checkpoint==True: 103 | load(loader, sess, FLAGS.tf_initial_checkpoint) 104 | print('Training Starts........') 105 | step_iter = 0 106 | for epoch in range(FLAGS.Epochs): 107 | save(saver, sess, my_log_dir, step_iter) 108 | i=0; 109 | f = open(FLAGS.train_text,'r') 110 | message = f.read() 111 | lines = message.split('\n') 112 | for l in lines: 113 | step_iter =step_iter+1 114 | i=i+1; 115 | try : 116 | input_image = cv2.imread(Image_directory+l+'.jpg') 117 | labs_person = cv2.imread(Label_directory+l+'.png',0) 118 | labs_person = labs_person/255.0 119 | labs_person = np.expand_dims(labs_person,axis=0) 120 | labs_person = np.expand_dims(labs_person,axis=3) 121 | input_image = np.expand_dims(input_image,axis=0) 122 | start_time = time.time() 123 | feed_dict={image_ph:input_image,label_ph:labs_person} 124 | L,P,_,sum_op = sess.run([total_loss,pred,train_op,train_summary_op], feed_dict=feed_dict) 125 | train_writer.add_summary(sum_op, step_iter) 126 | duration = time.time() - start_time 127 | print('::Step::'+str(epoch)+','+str(i), '::total_loss::'+ str(L),'::time::'+str(duration)) 128 | except: 129 | print("ERROR") 130 | 131 | if __name__ == '__main__': 132 | flags.mark_flag_as_required('image_dir') 133 | flags.mark_flag_as_required('label_dir') 134 | flags.mark_flag_as_required('log_dir') 135 | flags.mark_flag_as_required('train_text') 136 | tf.app.run() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | def random_crop_and_pad_image_and_labels(image, label, crop_h, crop_w, ignore_label=255): 3 | combined = tf.concat(axis=2, values=[image, label]) 4 | image_shape = tf.shape(image) 5 | combined_pad = tf.image.pad_to_bounding_box(combined, 0, 0, tf.maximum(crop_h, image_shape[0]), tf.maximum(crop_w, image_shape[1])) 6 | last_image_dim = tf.shape(image)[-1] 7 | last_label_dim = tf.shape(label)[-1] 8 | combined_crop = tf.random_crop(combined_pad, [crop_h,crop_w,4]) 9 | img_crop = combined_crop[:, :, :last_image_dim] 10 | label_crop = combined_crop[:, :, last_image_dim:] 11 | label_crop = label_crop + ignore_label 12 | label_crop = tf.cast(label_crop, dtype=tf.uint8) 13 | img_crop.set_shape((crop_h, crop_w, 3)) 14 | label_crop.set_shape((crop_h,crop_w, 1)) 15 | return img_crop, label_crop 16 | 17 | def random_crop_and_pad_image(image, crop_h, crop_w, ignore_label=255): 18 | image_shape = tf.shape(image) 19 | pad = tf.image.pad_to_bounding_box(image, 0, 0, tf.maximum(crop_h, image_shape[0]), tf.maximum(crop_w, image_shape[1])) 20 | last_image_dim = tf.shape(image)[-1] 21 | img_crop = tf.random_crop(pad, [crop_h,crop_w,3]) 22 | img_crop.set_shape((crop_h, crop_w, 3)) 23 | return img_crop --------------------------------------------------------------------------------