├── __init__.py ├── nets ├── __init__.py ├── resnet_utils.py └── resnet_v1.py ├── .DS_Store ├── tx_infer_data ├── 1.jpg ├── 2.jpg ├── .DS_Store ├── col │ ├── 1.jpg │ └── 2.jpg ├── row │ ├── 1.jpg │ └── 2.jpg ├── ncol │ ├── 1.jpg │ ├── 2.jpg │ └── vanke_2016_1241_nb_3.jpg ├── nrow │ ├── 1.jpg │ ├── 2.jpg │ └── vanke_2016_1241_nb_3.jpg └── vanke_2016_1241_nb_3.jpg ├── model └── checkpoint ├── requirements.txt ├── README.md ├── post_tx.py ├── data_util.py ├── inference.py ├── model.py ├── dataf.py ├── data_f.py └── train.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/.DS_Store -------------------------------------------------------------------------------- /tx_infer_data/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/1.jpg -------------------------------------------------------------------------------- /tx_infer_data/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/2.jpg -------------------------------------------------------------------------------- /tx_infer_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/.DS_Store -------------------------------------------------------------------------------- /tx_infer_data/col/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/col/1.jpg -------------------------------------------------------------------------------- /tx_infer_data/col/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/col/2.jpg -------------------------------------------------------------------------------- /tx_infer_data/row/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/row/1.jpg -------------------------------------------------------------------------------- /tx_infer_data/row/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/row/2.jpg -------------------------------------------------------------------------------- /model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "./model.ckpt-11111" 2 | all_model_checkpoint_paths: "./model.ckpt-11111" 3 | -------------------------------------------------------------------------------- /tx_infer_data/ncol/1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/ncol/1.jpg
--------------------------------------------------------------------------------
/tx_infer_data/ncol/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/ncol/2.jpg
--------------------------------------------------------------------------------
/tx_infer_data/nrow/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/nrow/1.jpg
--------------------------------------------------------------------------------
/tx_infer_data/nrow/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/nrow/2.jpg
--------------------------------------------------------------------------------
/tx_infer_data/vanke_2016_1241_nb_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/vanke_2016_1241_nb_3.jpg
--------------------------------------------------------------------------------
/tx_infer_data/ncol/vanke_2016_1241_nb_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/ncol/vanke_2016_1241_nb_3.jpg
--------------------------------------------------------------------------------
/tx_infer_data/nrow/vanke_2016_1241_nb_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/tableImageParser_tx/HEAD/tx_infer_data/nrow/vanke_2016_1241_nb_3.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Shapely==1.5.13
2 | Flask==0.10.1
3 | matplotlib==1.5.1
4 | scipy==0.19.0
5 | plumbum==1.6.2
6 | numpy==1.12.1
7 | ipython==6.1.0
8 | Pillow==4.2.1
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tableImageParser
2 | 
3 | ### This project reproduces Tencent's table structure parsing, following their published write-up: https://zhuanlan.zhihu.com/p/69793742 (Table Structure Recognition from Tencent)
4 | 
5 | ## Requirements
6 | ```pip install -r requirements.txt```
7 | 
8 | ## Demo 🌰
9 | - Edit the paths required by the main function in inference.py
10 | 
11 | ```python inference.py```
12 | 
13 | ## Training
14 | - Set checkpoint_path in train.py to the model directory
15 | - Set training_data_path in dataf.py to the training data directory
16 | 
17 | ```python train.py```
18 | 
19 | ## Visualization Examples
20 | ### Example 🌰 1
21 | ![raw](https://github.com/tommyMessi/tableImageParser_tx/blob/master/tx_infer_data/vanke_2016_1241_nb_3.jpg)
22 | ![nrow](https://github.com/tommyMessi/tableImageParser_tx/blob/master/tx_infer_data/nrow/vanke_2016_1241_nb_3.jpg)
23 | ![ncol](https://github.com/tommyMessi/tableImageParser_tx/blob/master/tx_infer_data/ncol/vanke_2016_1241_nb_3.jpg)
24 | ### Example 🌰 2
25 | ![raw](https://github.com/tommyMessi/tableImageParser_tx/blob/master/tx_infer_data/1.jpg)
26 | ![row](https://github.com/tommyMessi/tableImageParser_tx/blob/master/tx_infer_data/row/1.jpg)
27 | ![col](https://github.com/tommyMessi/tableImageParser_tx/blob/master/tx_infer_data/col/1.jpg)
28 | 
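For scripted use, the `Detector` class in inference.py (included below) can also be driven directly instead of editing its `main` — a minimal sketch, assuming a trained checkpoint under `./model/`; the image path here is an illustration only:

```python
# Run the four table-line score maps (nrow/ncol/row/col) on one image.
import cv2
from inference import Detector

detector = Detector('./model/')                  # directory containing model.ckpt-*
image = cv2.imread('./tx_infer_data/1.jpg', 0)   # grayscale, as inference.py's main reads it
nrow, ncol, row, col, ratio_h, ratio_w = detector.main_detection(image)
# Each map holds per-pixel scores in [0, 1]; main() binarizes them at 0.9 and
# rescales by 1/ratio_w, 1/ratio_h to recover the original image size.
```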
29 | ## Update 8.17
30 | - Reworked the data preprocessing in dataf.py; adjust it to fit your own data (data augmentation matters a lot for the final results).
31 | - Added post_tx.py, a post-processing demo: it collects the line segments returned by the Hough transform (cv2.HoughLinesP) and uses them to repair broken, disconnected table lines.
32 | - Pretrained model: https://pan.baidu.com/s/1JXEKuWYtbyF6vFGQIzyE6g (extraction code: 4mbb)
33 | 
34 | ## Misc
35 | For the training data and pretrained model, follow the WeChat official account hulugeAI and leave the message: table parser
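post_tx.py below stops at drawing the raw HoughLinesP segments; the reconnection step described in the update bullet still has to fuse those segments into continuous lines. A sketch of one way to do that for near-horizontal row lines — `merge_row_segments` and its `y_tol` tolerance are hypothetical names, not part of the repo:

```python
# Fuse near-horizontal HoughLinesP segments into continuous row lines by
# clustering segments with similar y and spanning each cluster end to end.
import numpy as np

def merge_row_segments(lines, y_tol=5):
    # lines: the (N, 1, 4) array returned by cv2.HoughLinesP
    rows = []  # each entry: [x_min, y_mean, x_max, n_segments]
    for x1, y1, x2, y2 in lines.reshape(-1, 4):
        if abs(y2 - y1) > y_tol:            # skip non-horizontal segments
            continue
        y = (y1 + y2) / 2.0
        for r in rows:
            if abs(r[1] - y) <= y_tol:      # same row: widen span, update mean y
                r[0] = min(r[0], x1, x2)
                r[2] = max(r[2], x1, x2)
                r[1] = (r[1] * r[3] + y) / (r[3] + 1)
                r[3] += 1
                break
        else:
            rows.append([min(x1, x2), y, max(x1, x2), 1])
    # one full-width segment per detected row line
    return [(int(xa), int(yy), int(xb), int(yy)) for xa, yy, xb, _ in rows]
```

Column lines can be handled the same way with the x and y roles swapped.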
--------------------------------------------------------------------------------
/post_tx.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import math
5 | 
6 | 
7 | def line_row_gen(img_path):
8 |     # Detect candidate row lines: binarize the score-map image, collect
9 |     # HoughLinesP segments and draw them in green for inspection.
10 |     img = cv2.imread(img_path)
11 |     img_temp = np.ones_like(img) * 255  # blank canvas, returned when nothing is found
12 | 
13 |     gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
14 |     ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV)
15 | 
16 |     edges = cv2.Canny(gray, 50, 150, apertureSize=3)
17 |     # cv2.imshow('edge', edges)
18 |     # cv2.imshow('binary', binary)
19 |     # cv2.waitKey(0)
20 | 
21 |     minLineLength = 100
22 |     maxLineGap = 100
23 |     lines = cv2.HoughLinesP(binary, 1, np.pi/180, 100, minLineLength=minLineLength, maxLineGap=maxLineGap)
24 |     try:
25 |         for line in lines:
26 |             for x1, y1, x2, y2 in line:
27 |                 cv2.line(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
28 |     except TypeError:  # HoughLinesP returns None when no segment is found
29 |         return img_temp, 0
30 |     # points = [(box[0], box[1]), (box[2], box[1]), (box[2], box[3]), (box[0], box[3])]
31 |     # cv2.fillPoly(image, [np.array(points)], (255, 0, 0))
32 |     # cv2.imwrite('E:/image/myhoughlinesp.jpg', img)
33 |     cv2.imshow('1', img)
34 |     cv2.waitKey(0)
35 |     return img_temp, 1
36 | 
37 | 
38 | def angle(v1, v2):
39 |     # Included angle, in degrees, between segments v1 and v2 given as [x1, y1, x2, y2].
40 |     dx1 = v1[2] - v1[0]
41 |     dy1 = v1[3] - v1[1]
42 |     dx2 = v2[2] - v2[0]
43 |     dy2 = v2[3] - v2[1]
44 |     angle1 = math.atan2(dy1, dx1)
45 |     angle1 = int(angle1 * 180 / math.pi)
46 |     angle2 = math.atan2(dy2, dx2)
47 |     angle2 = int(angle2 * 180 / math.pi)
48 |     if angle1 * angle2 >= 0:
49 |         included_angle = abs(angle1 - angle2)
50 |     else:
51 |         included_angle = abs(angle1) + abs(angle2)
52 |         if included_angle > 180:
53 |             included_angle = 360 - included_angle
54 |     return included_angle
55 | 
56 | 
57 | def line_col_gen(img_path):
58 |     # Detect candidate column lines: keep only Hough segments within 15 degrees
59 |     # of perpendicular to the horizontal reference segment AB.
60 |     AB = [0, 0, 100, 0]
61 | 
62 |     img = cv2.imread(img_path)
63 |     img_temp = np.ones_like(img) * 255
64 | 
65 |     gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
66 |     ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV)
67 | 
68 |     edges = cv2.Canny(gray, 50, 150, apertureSize=3)
69 |     # cv2.imshow('edge', edges)
70 |     # cv2.imshow('binary', binary)
71 |     # cv2.waitKey(0)
72 | 
73 |     minLineLength = 20  # columns use shorter, tighter Hough parameters than rows
74 |     maxLineGap = 50
75 |     lines = cv2.HoughLinesP(binary, 1, np.pi/180, 100, minLineLength=minLineLength, maxLineGap=maxLineGap)
76 |     try:
77 |         for line in lines:
78 |             for x1, y1, x2, y2 in line:
79 |                 CD = [x1, y1, x2, y2]
80 |                 angle_cross = angle(AB, CD)
81 |                 if 90 - 15 < angle_cross < 90 + 15:
82 |                     cv2.line(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
83 |     except TypeError:  # no segments found
84 |         return img_temp, 0
85 |     # cv2.imwrite('E:/image/myhoughlinesp.jpg', img)
86 |     cv2.imshow('2', img)
87 |     cv2.waitKey(0)
88 |     return img_temp, 1
89 | 
90 | 
91 | def tx_post(row_path, nrow_path, col_path, ncol_path):
92 |     # Run line extraction on each of the four predicted score maps.
93 |     row_image, is_row_exist = line_row_gen(row_path)
94 |     nrow_image, is_nrow_exist = line_row_gen(nrow_path)
95 |     col_image, is_col_exist = line_col_gen(col_path)
96 |     ncol_image, is_ncol_exist = line_col_gen(ncol_path)
97 | 
98 | 
99 | if __name__ == '__main__':
100 |     img_root = r'\result\test'
101 |     col_root = r'\result\col'
102 |     row_root = r'\result\row'
103 |     ncol_root = r'\result\ncol'
104 |     nrow_root = r'\result\nrow'
105 | 
106 |     img_names = os.listdir(col_root)
107 |     for img_name in img_names:
108 |         col_path = os.path.join(col_root, img_name)
109 |         ncol_path = os.path.join(ncol_root, img_name)
110 |         row_path = os.path.join(row_root, img_name)
111 |         nrow_path = os.path.join(nrow_root, img_name)
112 |         # save_path = os.path.join(save_root, img_name)
113 |         tx_post(row_path, nrow_path, col_path, ncol_path)
--------------------------------------------------------------------------------
/data_util.py:
--------------------------------------------------------------------------------
1 | '''
2 | this file is modified from the keras implementation of multi-threaded data processing,
3 | see https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
4 | '''
5 | import time
6 | import numpy as np
7 | import threading
8 | import multiprocessing
9 | try:
10 |     import queue
11 | except ImportError:
12 |     import Queue as queue
13 | 
14 | 
15 | class GeneratorEnqueuer():
16 |     """Builds a queue out of a data generator.
17 | 
18 |     Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
19 | 
20 |     # Arguments
21 |         generator: a generator function which endlessly yields data
22 |         use_multiprocessing: use multiprocessing if True, otherwise threading
23 |         wait_time: time to sleep in-between calls to `put()`
24 |         random_seed: Initial seed for workers,
25 |             will be incremented by one for each worker.
26 |     """
27 | 
28 |     def __init__(self, generator,
29 |                  use_multiprocessing=False,
30 |                  wait_time=0.05,
31 |                  random_seed=None):
32 |         self.wait_time = wait_time
33 |         self._generator = generator
34 |         self._use_multiprocessing = use_multiprocessing
35 |         self._threads = []
36 |         self._stop_event = None
37 |         self.queue = None
38 |         self.random_seed = random_seed
39 | 
40 |     def start(self, workers=1, max_queue_size=10):
41 |         """Kicks off threads which add data from the generator into the queue.
42 | 
43 |         # Arguments
44 |             workers: number of worker threads
45 |             max_queue_size: queue size
46 |                 (when full, threads could block on `put()`)
47 |         """
48 | 
49 |         def data_generator_task():
50 |             while not self._stop_event.is_set():
51 |                 try:
52 |                     if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
53 |                         generator_output = next(self._generator)
54 |                         self.queue.put(generator_output)
55 |                     else:
56 |                         time.sleep(self.wait_time)
57 |                 except Exception:
58 |                     self._stop_event.set()
59 |                     raise
60 | 
61 |         try:
62 |             if self._use_multiprocessing:
63 |                 self.queue = multiprocessing.Queue(maxsize=max_queue_size)
64 |                 self._stop_event = multiprocessing.Event()
65 |             else:
66 |                 self.queue = queue.Queue()
67 |                 self._stop_event = threading.Event()
68 | 
69 |             for _ in range(workers):
70 |                 if self._use_multiprocessing:
71 |                     # Reset random seed else all children processes
72 |                     # share the same seed
73 |                     np.random.seed(self.random_seed)
74 |                     thread = multiprocessing.Process(target=data_generator_task)
75 |                     thread.daemon = True
76 |                     if self.random_seed is not None:
77 |                         self.random_seed += 1
78 |                 else:
79 |                     thread = threading.Thread(target=data_generator_task)
80 |                 self._threads.append(thread)
81 |                 thread.start()
82 |         except:
83 |             self.stop()
84 |             raise
85 | 
86 |     def is_running(self):
87 |         return self._stop_event is not None and not self._stop_event.is_set()
88 | 
89 |     def stop(self, timeout=None):
90 |         """Stops running threads and waits for them to exit, if necessary.
91 | 
92 |         Should be called by the same thread which called `start()`.
93 | 
94 |         # Arguments
95 |             timeout: maximum time to wait on `thread.join()`.
96 |         """
97 |         if self.is_running():
98 |             self._stop_event.set()
99 | 
100 |         for thread in self._threads:
101 |             if thread.is_alive():
102 |                 if self._use_multiprocessing:
103 |                     thread.terminate()
104 |                 else:
105 |                     thread.join(timeout)
106 | 
107 |         if self._use_multiprocessing:
108 |             if self.queue is not None:
109 |                 self.queue.close()
110 | 
111 |         self._threads = []
112 |         self._stop_event = None
113 |         self.queue = None
114 | 
115 |     def get(self):
116 |         """Creates a generator to extract data from the queue.
117 | 
118 |         Skip the data if it is `None`.
119 | 
120 |         # Returns
121 |             A generator
122 |         """
123 |         while self.is_running():
124 |             if not self.queue.empty():
125 |                 inputs = self.queue.get()
126 |                 if inputs is not None:
127 |                     yield inputs
128 |             else:
129 |                 time.sleep(self.wait_time)
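data_util.py's GeneratorEnqueuer is consumed by get_batch in dataf.py and data_f.py below; a toy standalone sketch of the same pattern (the counting generator is made up for illustration):

```python
# Wrap an endless generator in a worker thread and pull items off the queue.
from data_util import GeneratorEnqueuer

def counting_generator():
    i = 0
    while True:      # must be endless, as the enqueuer expects
        yield i
        i += 1

enqueuer = GeneratorEnqueuer(counting_generator(), use_multiprocessing=False)
enqueuer.start(workers=1, max_queue_size=10)
batches = enqueuer.get()                  # generator view over the queue
print([next(batches) for _ in range(3)])  # -> [0, 1, 2]
enqueuer.stop()
```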
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import tensorflow as tf
4 | 
5 | import logging
6 | log = logging.getLogger(__name__)
7 | 
8 | import model
9 | import time
10 | 
11 | import os
12 | import random
13 | 
14 | class Detector(object):
15 |     def __init__(self, model_dir):
16 |         os.environ['CUDA_VISIBLE_DEVICES'] = '1'
17 |         config = tf.ConfigProto(allow_soft_placement=True)
18 |         config.gpu_options.per_process_gpu_memory_fraction = 1.0
19 |         config.gpu_options.allow_growth = True
20 |         self.input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
21 |         self.session = tf.Session(config=config)
22 |         self.global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
23 |         self.score_nrow, self.score_ncol, self.score_row, self.score_col = model.model(self.input_images, is_training=False)
24 |         self.variable_averages = tf.train.ExponentialMovingAverage(0.997, self.global_step)
25 |         self.saver = tf.train.Saver(self.variable_averages.variables_to_restore())
26 |         self.ckpt_state = tf.train.get_checkpoint_state(model_dir)
27 |         print(self.ckpt_state)
28 |         self.model_path = os.path.join(model_dir, os.path.basename(self.ckpt_state.model_checkpoint_path))
29 |         print(self.model_path)
30 |         self.saver.restore(self.session, self.model_path)
31 | 
32 | 
33 |     def main_detection(self, image):
34 |         # img_e_c = image[:,:,::-1]
35 |         img_e = np.expand_dims(image, axis=2)
36 |         img_e_c = np.concatenate((img_e, img_e, img_e), axis=-1)
37 |         im_resized, (ratio_h, ratio_w) = resize_image(img_e_c)
38 |         score_nrow, score_ncol, score_row, score_col = self.session.run([self.score_nrow, self.score_ncol, self.score_row, self.score_col], feed_dict={self.input_images: [im_resized]})
39 |         return score_nrow[0], score_ncol[0], score_row[0], score_col[0], ratio_h, ratio_w
40 | 
41 | def resize_image(im):
42 |     h, w, _ = im.shape
43 |     size = (int(512), int(512))
44 |     im = cv2.resize(im, size, interpolation=cv2.INTER_AREA)
45 |     # la_p = cv2.resize(label_im, size, interpolation=cv2.INTER_AREA)
46 | 
47 |     ratio_h = 512 / float(h)
48 |     ratio_w = 512 / float(w)
49 | 
50 |     return im, (ratio_h, ratio_w)
51 | 
52 | def iou_count(list1, list2):
53 |     xx1 = np.maximum(list1[0], list2[0])
54 |     yy1 = np.maximum(list1[1], list2[1])
55 |     xx2 = np.minimum(list1[4], list2[4])
56 |     yy2 = np.minimum(list1[5], list2[5])
57 | 
58 |     w = np.maximum(0.0, xx2 - xx1 + 1)
59 |     h = np.maximum(0.0, yy2 - yy1 + 1)
60 |     w = np.maximum(0.0, xx2 - xx1 + 1)
61 |     h = np.maximum(0.0,
yy2 - yy1 + 1) 62 | 63 | inter = w * h 64 | area1 = (list1[4] - list1[0] + 1) * (list1[5] - list1[1] + 1) 65 | area2 = (list2[4] - list2[0] + 1) * (list2[5] - list2[1] + 1) 66 | iou = inter / min(area1, area2) 67 | return iou 68 | 69 | if __name__ == '__main__': 70 | 71 | result_path = './result/' 72 | instance = Detector('./model/') 73 | images = os.listdir('./image/') 74 | 75 | row_root = './tx_infer_data/row' 76 | col_root = './tx_infer_data/col' 77 | nrow_root = './tx_infer_data/nrow' 78 | ncol_root = './tx_infer_data/ncol' 79 | 80 | 81 | i_l = [] 82 | for x in range(len(images)): 83 | print(images[x]) 84 | image_path = os.path.join('./image/',images[x]) 85 | image_name = images[x] 86 | txt_name = image_name.replace('.jpg','.txt') 87 | row_path = os.path.join(row_root, image_name) 88 | col_path = os.path.join(col_root, image_name) 89 | nrow_path = os.path.join(nrow_root, image_name) 90 | ncol_path = os.path.join(ncol_root, image_name) 91 | 92 | image = cv2.imread(image_path, 0) 93 | # image = cv2.imread(image_path) 94 | image_color = cv2.imread(image_path) 95 | # instance.table_detection(image, image_color) 96 | score_nrow, score_ncol, score_row, score_col, ratio_h, ratio_w = instance.main_detection(image) 97 | 98 | score_nrow = np.where(score_nrow > 0.9, score_nrow, 0) 99 | score_nrow = np.where(score_nrow < 0.9, score_nrow, 1) 100 | 101 | score_ncol = np.where(score_ncol > 0.9, score_ncol, 0) 102 | score_ncol = np.where(score_ncol < 0.9, score_ncol, 1) 103 | 104 | score_row = np.where(score_row > 0.9, score_row, 0) 105 | score_row = np.where(score_row < 0.9, score_row, 1) 106 | 107 | score_col = np.where(score_col > 0.9, score_col, 0) 108 | score_col = np.where(score_col < 0.9, score_col, 1) 109 | 110 | nmap = cv2.bitwise_and(score_nrow, score_ncol) 111 | lmap = cv2.bitwise_and(score_row, score_col) 112 | pre_map = cv2.bitwise_and(nmap, lmap) 113 | 114 | result = os.path.join(result_path, images[x]) 115 | score_nrow_map = cv2.resize(score_nrow, dsize=None, fx=1/ratio_w, fy=1/ratio_h, interpolation=cv2.INTER_AREA) 116 | score_ncol_map = cv2.resize(score_ncol, dsize=None, fx=1 / ratio_w, fy=1 / ratio_h, interpolation=cv2.INTER_AREA) 117 | score_row_map = cv2.resize(score_row, dsize=None, fx=1 / ratio_w, fy=1 / ratio_h, interpolation=cv2.INTER_AREA) 118 | score_col_map = cv2.resize(score_col, dsize=None, fx=1 / ratio_w, fy=1 / ratio_h, interpolation=cv2.INTER_AREA) 119 | pre_map = cv2.resize(pre_map, dsize=None, fx=1 / ratio_w, fy=1 / ratio_h, interpolation=cv2.INTER_AREA) 120 | # mask_result = os.path.join(result_path, 'mask_'+images[x]) 121 | # print(mask_result) 122 | cv2.imwrite(row_path, score_row_map*255) 123 | cv2.imwrite(col_path, score_col_map*255) 124 | cv2.imwrite(nrow_path, score_nrow_map * 255) 125 | cv2.imwrite(ncol_path, score_ncol_map * 255) 126 | cv2.imwrite(result, pre_map*255) 127 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from tensorflow.contrib import slim 5 | 6 | tf.app.flags.DEFINE_integer('text_scale', 512, '') 7 | 8 | from nets import resnet_v1 9 | 10 | FLAGS = tf.app.flags.FLAGS 11 | 12 | 13 | def unpool(inputs): 14 | return tf.image.resize_bilinear(inputs, size=[tf.shape(inputs)[1]*2, tf.shape(inputs)[2]*2]) 15 | 16 | 17 | def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]): 18 | ''' 19 | image normalization 20 | :param images: 21 | :param means: 22 | 
:return: 23 | ''' 24 | num_channels = images.get_shape().as_list()[-1] 25 | if len(means) != num_channels: 26 | raise ValueError('len(means) must match the number of channels') 27 | channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images) 28 | for i in range(num_channels): 29 | channels[i] -= means[i] 30 | return tf.concat(axis=3, values=channels) 31 | 32 | 33 | def model(images, weight_decay=1e-5, is_training=True): 34 | ''' 35 | define the model, we use slim's implemention of resnet 36 | ''' 37 | images = mean_image_subtraction(images) 38 | 39 | with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)): 40 | logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50') 41 | 42 | with tf.variable_scope('feature_fusion', values=[end_points.values]): 43 | batch_norm_params = { 44 | 'decay': 0.997, 45 | 'epsilon': 1e-5, 46 | 'scale': True, 47 | 'is_training': is_training 48 | } 49 | with slim.arg_scope([slim.conv2d], 50 | activation_fn=tf.nn.relu, 51 | normalizer_fn=slim.batch_norm, 52 | normalizer_params=batch_norm_params, 53 | weights_regularizer=slim.l2_regularizer(weight_decay)): 54 | f = [end_points['pool5'], end_points['pool4'], 55 | end_points['pool3'], end_points['pool2']] 56 | for i in range(4): 57 | print('Shape of f_{} {}'.format(i, f[i].shape)) 58 | g = [None, None, None, None] 59 | h = [None, None, None, None] 60 | num_outputs = [None, 128, 64, 32] 61 | for i in range(4): 62 | if i == 0: 63 | h[i] = f[i] 64 | else: 65 | c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1) 66 | h[i] = slim.conv2d(c1_1, num_outputs[i], 3) 67 | if i <= 2: 68 | g[i] = unpool(h[i]) 69 | else: 70 | g[i] = slim.conv2d(h[i], num_outputs[i], 3) 71 | print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape)) 72 | 73 | # here we use a slightly different way for regression part, 74 | # we first use a sigmoid to limit the regression range, and also 75 | # this is do with the angle map 76 | F_score_nrow = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) 77 | 78 | F_score_ncol = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) 79 | 80 | F_score_row = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) 81 | 82 | F_score_col = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) 83 | 84 | return F_score_nrow, F_score_ncol, F_score_row, F_score_col 85 | 86 | 87 | def dice_coefficient(y_true_cls, y_pred_cls, 88 | training_mask): 89 | ''' 90 | dice loss 91 | :param y_true_cls: 92 | :param y_pred_cls: 93 | :param training_mask: 94 | :return: 95 | ''' 96 | eps = 1e-5 97 | intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask) 98 | union = tf.reduce_sum(y_true_cls*training_mask) + tf.reduce_sum(y_pred_cls*training_mask) + eps 99 | loss = 1. 
- (2 * intersection / union) 100 | tf.summary.scalar('classification_dice_loss', loss) 101 | return loss 102 | 103 | def focal_loss(y_true_cls, y_pred_cls): 104 | ''' 105 | :param y_true_cls: 106 | :param y_pred_cls: 107 | :return: 108 | ''' 109 | gamma = 2 110 | alpha = 0.5 111 | 112 | dim = tf.reduce_prod(tf.shape(y_true_cls)[1:]) 113 | flat_y_true_cls = tf.reshape(y_true_cls, [-1, dim]) 114 | flat_y_pred_cls = tf.reshape(y_pred_cls, [-1, dim]) 115 | pt = flat_y_true_cls*flat_y_pred_cls+(1.0-flat_y_true_cls)*(1.0 - flat_y_pred_cls) 116 | weight_map = alpha*tf.pow((1.0-pt),gamma) 117 | weighted_loss = tf.multiply(((flat_y_true_cls * tf.log(flat_y_pred_cls + 1e-9)) + ((1 - flat_y_true_cls) * tf.log(1 - flat_y_pred_cls + 1e-9))), weight_map) 118 | cross_entropy = -tf.reduce_sum(weighted_loss,axis = 1) 119 | cross_entropy_mean = tf.reduce_mean(cross_entropy) 120 | tf.summary.scalar('classification_focal_loss', cross_entropy_mean) 121 | return cross_entropy_mean 122 | 123 | 124 | def loss(y_true_cls_nrow, y_pred_cls_nrow, 125 | y_true_cls_ncol, y_pred_cls_ncol, 126 | y_true_cls_row, y_pred_cls_row, 127 | y_true_cls_col, y_pred_cls_col, 128 | training_mask): 129 | ''' 130 | define the loss used for training, contraning two part, 131 | the first part we use dice loss instead of weighted logloss, 132 | the second part is the iou loss defined in the paper 133 | :param training_mask: mask used in training, to ignore some text annotated by ### 134 | :return: 135 | ''' 136 | classification_loss_nrow = dice_coefficient(y_true_cls_nrow, y_pred_cls_nrow, training_mask) 137 | classification_loss_ncol = dice_coefficient(y_true_cls_ncol, y_pred_cls_ncol, training_mask) 138 | classification_loss_row = dice_coefficient(y_true_cls_row, y_pred_cls_row, training_mask) 139 | classification_loss_col = dice_coefficient(y_true_cls_col, y_pred_cls_col, training_mask) 140 | 141 | 142 | return tf.reduce_mean(classification_loss_row+classification_loss_ncol+classification_loss_nrow+classification_loss_col) 143 | -------------------------------------------------------------------------------- /dataf.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import glob 3 | import csv 4 | import cv2 5 | import time 6 | import os 7 | import numpy as np 8 | from shapely.geometry import Polygon 9 | 10 | import tensorflow as tf 11 | 12 | from data_util import GeneratorEnqueuer 13 | 14 | tf.app.flags.DEFINE_string('training_data_path', './tx_data/image', 15 | 'training dataset to use') 16 | 17 | 18 | FLAGS = tf.app.flags.FLAGS 19 | 20 | 21 | def get_images(): 22 | files = [] 23 | for ext in ['jpg', 'png', 'jpeg', 'JPG']: 24 | files.extend(glob.glob( 25 | os.path.join(FLAGS.training_data_path, '*.{}'.format(ext)))) 26 | return files 27 | 28 | 29 | def load_annoataion(p): 30 | ''' 31 | load annotation from the text file 32 | :param p: 33 | :return: 34 | ''' 35 | text_polys = [] 36 | text_tags = [] 37 | if not os.path.exists(p): 38 | return np.array(text_polys, dtype=np.float32) 39 | with open(p, 'r') as f: 40 | reader = csv.reader(f) 41 | for line in reader: 42 | label = line[-1] 43 | # strip BOM. 
\ufeff for python3, \xef\xbb\bf for python2 44 | line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line] 45 | 46 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 47 | text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 48 | if label == '*' or label == '###': 49 | text_tags.append(True) 50 | else: 51 | text_tags.append(False) 52 | return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool) 53 | 54 | 55 | 56 | def crop_area(im, label_im,crop_background=False, max_tries=150): 57 | size = (int(512), int(512)) 58 | im_p = cv2.resize(im, size, interpolation=cv2.INTER_AREA) 59 | la_p = cv2.resize(label_im, size, interpolation=cv2.INTER_AREA) 60 | return im_p,la_p 61 | 62 | 63 | def point_dist_to_line(p1, p2, p3): 64 | # compute the distance from p3 to p1-p2 65 | return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1) 66 | 67 | 68 | def fit_line(p1, p2): 69 | # fit a line ax+by+c = 0 70 | if p1[0] == p1[1]: 71 | return [1., 0., -p1[0]] 72 | else: 73 | [k, b] = np.polyfit(p1, p2, deg=1) 74 | return [k, -1., b] 75 | 76 | 77 | def line_cross_point(line1, line2): 78 | # line1 0= ax+by+c, compute the cross point of line1 and line2 79 | if line1[0] != 0 and line1[0] == line2[0]: 80 | print('Cross point does not exist') 81 | return None 82 | if line1[0] == 0 and line2[0] == 0: 83 | print('Cross point does not exist') 84 | return None 85 | if line1[1] == 0: 86 | x = -line1[2] 87 | y = line2[0] * x + line2[2] 88 | elif line2[1] == 0: 89 | x = -line2[2] 90 | y = line1[0] * x + line1[2] 91 | else: 92 | k1, _, b1 = line1 93 | k2, _, b2 = line2 94 | x = -(b1-b2)/(k1-k2) 95 | y = k1*x + b1 96 | return np.array([x, y], dtype=np.float32) 97 | 98 | 99 | def line_verticle(line, point): 100 | # get the verticle line from line across point 101 | if line[1] == 0: 102 | verticle = [0, -1, point[1]] 103 | else: 104 | if line[0] == 0: 105 | verticle = [1, 0, -point[0]] 106 | else: 107 | verticle = [-1./line[0], -1, point[1] - (-1/line[0] * point[0])] 108 | return verticle 109 | 110 | 111 | def generator_label(label_im, label_str): 112 | label_name = label_str.split('/')[-1] 113 | h, w = label_im.shape 114 | score_map = np.zeros((h, w), dtype=np.uint8) 115 | for i in range(h): 116 | for j in range(w): 117 | if label_im[i][j] == 0: 118 | score_map[i][j] = 0 119 | else: 120 | score_map[i][j] = 1 121 | 122 | return score_map 123 | 124 | def generator(input_size=512, batch_size=32, 125 | background_ratio=3./8, 126 | random_scale=np.array([0.5, 1, 2.0, 3.0]), 127 | vis=True): 128 | image_list = np.array(get_images()) 129 | print('{} training images in {}'.format( 130 | image_list.shape[0], FLAGS.training_data_path)) 131 | index = np.arange(0, image_list.shape[0]) 132 | while True: 133 | np.random.shuffle(index) 134 | images = [] 135 | image_fns = [] 136 | score_maps_nrow = [] 137 | 138 | score_maps_ncol = [] 139 | 140 | score_maps_row = [] 141 | 142 | score_maps_col = [] 143 | training_masks = [] 144 | for i in index: 145 | try: 146 | im_fn = image_list[i] 147 | im = cv2.imread(im_fn) 148 | if '.png' in im_fn: 149 | im_fn = im_fn.replace('.png','.jpg') 150 | 151 | h, w, _ = im.shape 152 | label_fn_nrow = im_fn.replace('image', 'label_nrow') 153 | label_fn_ncol = im_fn.replace('image', 'label_ncol') 154 | label_fn_row = im_fn.replace('image', 'label_row') 155 | label_fn_col = im_fn.replace('image', 'label_col') 156 | 157 | if not os.path.exists(label_fn_nrow): 158 | print('text file {} does not exists'.format(label_fn_nrow)) 159 | continue 160 | if not 
os.path.exists(label_fn_ncol): 161 | print('text file {} does not exists'.format(label_fn_ncol)) 162 | continue 163 | if not os.path.exists(label_fn_row): 164 | print('text file {} does not exists'.format(label_fn_row)) 165 | continue 166 | if not os.path.exists(label_fn_col): 167 | print('text file {} does not exists'.format(label_fn_col)) 168 | continue 169 | label_im_nrow = cv2.imread(label_fn_nrow, cv2.IMREAD_GRAYSCALE) 170 | label_im_ncol = cv2.imread(label_fn_ncol, cv2.IMREAD_GRAYSCALE) 171 | label_im_row = cv2.imread(label_fn_row, cv2.IMREAD_GRAYSCALE) 172 | label_im_col = cv2.imread(label_fn_col, cv2.IMREAD_GRAYSCALE) 173 | 174 | score_map_nrow = generator_label(label_im_nrow, label_fn_nrow) 175 | score_map_ncol = generator_label(label_im_ncol, label_fn_ncol) 176 | score_map_row = generator_label(label_im_row, label_fn_row) 177 | score_map_col = generator_label(label_im_col, label_fn_col) 178 | 179 | im, score_map_nrow = crop_area(im, score_map_nrow, crop_background=True) 180 | im, score_map_ncol = crop_area(im, score_map_ncol, crop_background=True) 181 | im, score_map_row = crop_area(im, score_map_row, crop_background=True) 182 | im, score_map_col = crop_area(im, score_map_col, crop_background=True) 183 | im = cv2.resize(im, dsize=(input_size, input_size), interpolation=cv2.INTER_AREA) 184 | 185 | score_map_nrow = cv2.resize(score_map_nrow, dsize=(input_size, input_size), interpolation=cv2.INTER_AREA) 186 | score_map_ncol = cv2.resize(score_map_ncol, dsize=(input_size, input_size), interpolation=cv2.INTER_AREA) 187 | score_map_row = cv2.resize(score_map_row, dsize=(input_size, input_size), interpolation=cv2.INTER_AREA) 188 | score_map_col = cv2.resize(score_map_col, dsize=(input_size, input_size), interpolation=cv2.INTER_AREA) 189 | 190 | training_mask = np.ones((input_size, input_size), dtype=np.uint8) 191 | 192 | images.append(im[:, :, ::-1].astype(np.float32)) 193 | image_fns.append(im_fn) 194 | 195 | score_maps_nrow.append(score_map_nrow[::2, ::2, np.newaxis].astype(np.float32)) 196 | 197 | score_maps_ncol.append(score_map_ncol[::2, ::2, np.newaxis].astype(np.float32)) 198 | 199 | score_maps_row.append(score_map_row[::2, ::2, np.newaxis].astype(np.float32)) 200 | 201 | score_maps_col.append(score_map_col[::2, ::2, np.newaxis].astype(np.float32)) 202 | 203 | training_masks.append(training_mask[::2, ::2, np.newaxis].astype(np.float32)) 204 | 205 | if len(images) == batch_size: 206 | yield images, image_fns, score_maps_nrow, score_maps_ncol, \ 207 | score_maps_row, score_maps_col, training_masks 208 | images = [] 209 | image_fns = [] 210 | score_maps_nrow = [] 211 | 212 | score_maps_ncol = [] 213 | 214 | score_maps_row = [] 215 | 216 | score_maps_col = [] 217 | 218 | training_masks = [] 219 | except Exception as e: 220 | import traceback 221 | print(im_fn) 222 | traceback.print_exc() 223 | continue 224 | 225 | 226 | def get_batch(num_workers, **kwargs): 227 | try: 228 | enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True) 229 | print('Generator use 10 batches for buffering, this may take a while, you can tune this yourself.') 230 | enqueuer.start(max_queue_size=10, workers=num_workers) 231 | generator_output = None 232 | while True: 233 | while enqueuer.is_running(): 234 | if not enqueuer.queue.empty(): 235 | generator_output = enqueuer.queue.get() 236 | break 237 | else: 238 | time.sleep(0.01) 239 | yield generator_output 240 | generator_output = None 241 | finally: 242 | if enqueuer is not None: 243 | enqueuer.stop() 244 | 245 | 246 | 247 | if __name__ == 
'__main__': 248 | pass 249 | -------------------------------------------------------------------------------- /data_f.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import glob 3 | import csv 4 | import cv2 5 | import time 6 | import os 7 | import numpy as np 8 | from shapely.geometry import Polygon 9 | import random 10 | 11 | import tensorflow as tf 12 | 13 | from data_util import GeneratorEnqueuer 14 | 15 | tf.app.flags.DEFINE_string('training_data_path', './data_tx/raw_img/', 16 | 'training dataset to use') 17 | 18 | 19 | FLAGS = tf.app.flags.FLAGS 20 | 21 | 22 | def get_images(): 23 | files = [] 24 | for ext in ['jpg', 'png', 'jpeg', 'JPG']: 25 | files.extend(glob.glob( 26 | os.path.join(FLAGS.training_data_path, '*.{}'.format(ext)))) 27 | return files 28 | 29 | 30 | def load_annoataion(p): 31 | ''' 32 | load annotation from the text file 33 | :param p: 34 | :return: 35 | ''' 36 | text_polys = [] 37 | text_tags = [] 38 | if not os.path.exists(p): 39 | return np.array(text_polys, dtype=np.float32) 40 | with open(p, 'r') as f: 41 | reader = csv.reader(f) 42 | for line in reader: 43 | label = line[-1] 44 | # strip BOM. \ufeff for python3, \xef\xbb\bf for python2 45 | line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line] 46 | 47 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 48 | text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 49 | if label == '*' or label == '###': 50 | text_tags.append(True) 51 | else: 52 | text_tags.append(False) 53 | return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool) 54 | 55 | def resize_train(im, label_row, label_col, label_nrow, label_ncol): 56 | h, w, _ = im.shape 57 | 58 | if h<450: 59 | h_new = 512*h/w 60 | pad = random.randint(10,512 - int(h_new)) 61 | 62 | im = cv2.copyMakeBorder(im, 0,pad,0,0,cv2.BORDER_CONSTANT,value=[255,255,255]) 63 | label_row = cv2.copyMakeBorder(label_row, 0, pad, 0, 0, cv2.BORDER_CONSTANT,value=1) 64 | label_col = cv2.copyMakeBorder(label_col, 0, pad, 0, 0, cv2.BORDER_CONSTANT,value=1) 65 | label_nrow = cv2.copyMakeBorder(label_nrow, 0, pad, 0, 0, cv2.BORDER_CONSTANT,value=1) 66 | label_ncol = cv2.copyMakeBorder(label_ncol, 0, pad, 0, 0, cv2.BORDER_CONSTANT,value=1) 67 | 68 | 69 | size = (int(512), int(512)) 70 | im_1 = cv2.resize(im, size, interpolation=cv2.INTER_AREA) 71 | label_row = cv2.resize(label_row, size, interpolation=cv2.INTER_AREA) 72 | label_col = cv2.resize(label_col, size, interpolation=cv2.INTER_AREA) 73 | label_nrow = cv2.resize(label_nrow, size, interpolation=cv2.INTER_AREA) 74 | label_ncol = cv2.resize(label_ncol, size, interpolation=cv2.INTER_AREA) 75 | 76 | return im_1,label_row,label_col,label_nrow,label_ncol 77 | 78 | 79 | def crop_area(im, label_row,label_col, label_nrow, label_ncol): 80 | 81 | im_p,la_row,la_col,la_nrow,la_ncol = resize_train(im, label_row, label_col, label_nrow, label_ncol) 82 | 83 | return im_p, la_row, la_col, la_nrow, la_ncol 84 | 85 | 86 | def point_dist_to_line(p1, p2, p3): 87 | # compute the distance from p3 to p1-p2 88 | return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1) 89 | 90 | 91 | def fit_line(p1, p2): 92 | # fit a line ax+by+c = 0 93 | if p1[0] == p1[1]: 94 | return [1., 0., -p1[0]] 95 | else: 96 | [k, b] = np.polyfit(p1, p2, deg=1) 97 | return [k, -1., b] 98 | 99 | 100 | def line_cross_point(line1, line2): 101 | # line1 0= ax+by+c, compute the cross point of line1 and line2 102 | if line1[0] != 0 and line1[0] == line2[0]: 103 | 
print('Cross point does not exist') 104 | return None 105 | if line1[0] == 0 and line2[0] == 0: 106 | print('Cross point does not exist') 107 | return None 108 | if line1[1] == 0: 109 | x = -line1[2] 110 | y = line2[0] * x + line2[2] 111 | elif line2[1] == 0: 112 | x = -line2[2] 113 | y = line1[0] * x + line1[2] 114 | else: 115 | k1, _, b1 = line1 116 | k2, _, b2 = line2 117 | x = -(b1-b2)/(k1-k2) 118 | y = k1*x + b1 119 | return np.array([x, y], dtype=np.float32) 120 | 121 | 122 | def line_verticle(line, point): 123 | # get the verticle line from line across point 124 | if line[1] == 0: 125 | verticle = [0, -1, point[1]] 126 | else: 127 | if line[0] == 0: 128 | verticle = [1, 0, -point[0]] 129 | else: 130 | verticle = [-1./line[0], -1, point[1] - (-1/line[0] * point[0])] 131 | return verticle 132 | 133 | 134 | def generator_label(label_im, label_str): 135 | label_name = label_str.split('/')[-1] 136 | h, w = label_im.shape 137 | score_map = np.zeros((h, w), dtype=np.uint8) 138 | for i in range(h): 139 | for j in range(w): 140 | if label_im[i][j] == 0: 141 | score_map[i][j] = 0 142 | else: 143 | score_map[i][j] = 1 144 | 145 | return score_map 146 | 147 | def generator(input_size=512, batch_size=32, 148 | background_ratio=3./8, 149 | random_scale=np.array([0.5, 1, 2.0, 3.0]), 150 | vis=True): 151 | image_list = np.array(get_images()) 152 | print('{} training images in {}'.format( 153 | image_list.shape[0], FLAGS.training_data_path)) 154 | index = np.arange(0, image_list.shape[0]) 155 | while True: 156 | np.random.shuffle(index) 157 | images = [] 158 | image_fns = [] 159 | score_maps_nrow = [] 160 | 161 | score_maps_ncol = [] 162 | 163 | score_maps_row = [] 164 | 165 | score_maps_col = [] 166 | training_masks = [] 167 | for i in index: 168 | try: 169 | im_fn = image_list[i] 170 | im = cv2.imread(im_fn) 171 | if '.png' in im_fn: 172 | im_fn = im_fn.replace('.png','.jpg') 173 | 174 | h, w, _ = im.shape 175 | label_fn_nrow = im_fn.replace('raw', 'nrow') 176 | label_fn_ncol = im_fn.replace('raw', 'ncol') 177 | label_fn_row = im_fn.replace('raw', 'row') 178 | label_fn_col = im_fn.replace('raw', 'col') 179 | 180 | if not os.path.exists(label_fn_nrow): 181 | print('text file {} does not exists'.format(label_fn_nrow)) 182 | continue 183 | if not os.path.exists(label_fn_ncol): 184 | print('text file {} does not exists'.format(label_fn_ncol)) 185 | continue 186 | if not os.path.exists(label_fn_row): 187 | print('text file {} does not exists'.format(label_fn_row)) 188 | continue 189 | if not os.path.exists(label_fn_col): 190 | print('text file {} does not exists'.format(label_fn_col)) 191 | continue 192 | label_im_nrow = cv2.imread(label_fn_nrow, cv2.IMREAD_GRAYSCALE) 193 | label_im_ncol = cv2.imread(label_fn_ncol, cv2.IMREAD_GRAYSCALE) 194 | label_im_row = cv2.imread(label_fn_row, cv2.IMREAD_GRAYSCALE) 195 | label_im_col = cv2.imread(label_fn_col, cv2.IMREAD_GRAYSCALE) 196 | 197 | score_map_nrow = generator_label(label_im_nrow, label_fn_nrow) 198 | score_map_ncol = generator_label(label_im_ncol, label_fn_ncol) 199 | score_map_row = generator_label(label_im_row, label_fn_row) 200 | score_map_col = generator_label(label_im_col, label_fn_col) 201 | 202 | im, score_map_row,score_map_col,score_map_nrow,score_map_ncol = crop_area(im, score_map_row, score_map_col,score_map_nrow,score_map_ncol) 203 | 204 | training_mask = np.ones((input_size, input_size), dtype=np.uint8) 205 | 206 | images.append(im[:, :, ::-1].astype(np.float32)) 207 | image_fns.append(im_fn) 208 | 209 | 
score_maps_nrow.append(score_map_nrow[::2, ::2, np.newaxis].astype(np.float32)) 210 | 211 | score_maps_ncol.append(score_map_ncol[::2, ::2, np.newaxis].astype(np.float32)) 212 | 213 | score_maps_row.append(score_map_row[::2, ::2, np.newaxis].astype(np.float32)) 214 | 215 | score_maps_col.append(score_map_col[::2, ::2, np.newaxis].astype(np.float32)) 216 | 217 | training_masks.append(training_mask[::2, ::2, np.newaxis].astype(np.float32)) 218 | 219 | if len(images) == batch_size: 220 | yield images, image_fns, score_maps_nrow, score_maps_ncol, \ 221 | score_maps_row, score_maps_col, training_masks 222 | images = [] 223 | image_fns = [] 224 | score_maps_nrow = [] 225 | 226 | score_maps_ncol = [] 227 | 228 | score_maps_row = [] 229 | 230 | score_maps_col = [] 231 | 232 | training_masks = [] 233 | except Exception as e: 234 | import traceback 235 | print(im_fn) 236 | traceback.print_exc() 237 | continue 238 | 239 | 240 | def get_batch(num_workers, **kwargs): 241 | try: 242 | enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True) 243 | print('Generator use 10 batches for buffering, this may take a while, you can tune this yourself.') 244 | enqueuer.start(max_queue_size=10, workers=num_workers) 245 | generator_output = None 246 | while True: 247 | while enqueuer.is_running(): 248 | if not enqueuer.queue.empty(): 249 | generator_output = enqueuer.queue.get() 250 | break 251 | else: 252 | time.sleep(0.01) 253 | yield generator_output 254 | generator_output = None 255 | finally: 256 | if enqueuer is not None: 257 | enqueuer.stop() 258 | 259 | 260 | 261 | if __name__ == '__main__': 262 | pass 263 | -------------------------------------------------------------------------------- /nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import collections 3 | import tensorflow as tf 4 | 5 | slim = tf.contrib.slim 6 | 7 | 8 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 9 | """A named tuple describing a ResNet block. 10 | 11 | Its parts are: 12 | scope: The scope of the `Block`. 13 | unit_fn: The ResNet unnit function which takes as input a `Tensor` ad 14 | returns another `Tensor` with the output of the ResNet unit. 15 | args: A list of length equal to the number of units in the `Block`. The list 16 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 17 | block to serve as argument to unit_fn. 18 | """ 19 | 20 | 21 | def subsample(inputs, factor, scope=None): 22 | """Subsamples the input along the spatial dimensions. 23 | 24 | Args: 25 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 26 | factor: The subsampling factor. 27 | scope: Optional variable_scope. 28 | 29 | Returns: 30 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 31 | input, either intact (if factor == 1) or subsampled (if factor > 1). 32 | """ 33 | if factor == 1: 34 | return inputs 35 | else: 36 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 37 | 38 | 39 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 40 | """Strided 2-D convolution with 'SAME' padding. 41 | 42 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 43 | 'VALID' padding. 
44 | 45 | Note that 46 | 47 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 48 | 49 | is equivalent to 50 | 51 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 52 | net = subsample(net, factor=stride) 53 | 54 | whereas 55 | 56 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 57 | 58 | is different when the input's height or width is even, which is why we add the 59 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 60 | 61 | Args: 62 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 63 | num_outputs: An integer, the number of output filters. 64 | kernel_size: An int with the kernel_size of the filters. 65 | stride: An integer, the output stride. 66 | rate: An integer, rate for atrous convolution. 67 | scope: Scope. 68 | 69 | Returns: 70 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 71 | the convolution output. 72 | """ 73 | if stride == 1: 74 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 75 | padding='SAME', scope=scope) 76 | else: 77 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 78 | pad_total = kernel_size_effective - 1 79 | pad_beg = pad_total // 2 80 | pad_end = pad_total - pad_beg 81 | inputs = tf.pad(inputs, 82 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 83 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 84 | rate=rate, padding='VALID', scope=scope) 85 | 86 | 87 | @slim.add_arg_scope 88 | def stack_blocks_dense(net, blocks, output_stride=None, 89 | outputs_collections=None): 90 | """Stacks ResNet `Blocks` and controls output feature density. 91 | 92 | First, this function creates scopes for the ResNet in the form of 93 | 'block_name/unit_1', 'block_name/unit_2', etc. 94 | 95 | Second, this function allows the user to explicitly control the ResNet 96 | output_stride, which is the ratio of the input to output spatial resolution. 97 | This is useful for dense prediction tasks such as semantic segmentation or 98 | object detection. 99 | 100 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 101 | factor of 2 when transitioning between consecutive ResNet blocks. This results 102 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 103 | half the nominal network stride (e.g., output_stride=4), then we compute 104 | responses twice. 105 | 106 | Control of the output feature density is implemented by atrous convolution. 107 | 108 | Args: 109 | net: A `Tensor` of size [batch, height, width, channels]. 110 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 111 | element is a ResNet `Block` object describing the units in the `Block`. 112 | output_stride: If `None`, then the output will be computed at the nominal 113 | network stride. If output_stride is not `None`, it specifies the requested 114 | ratio of input to output spatial resolution, which needs to be equal to 115 | the product of unit strides from the start up to some level of the ResNet. 116 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 117 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 118 | is equivalent to output_stride=24). 119 | outputs_collections: Collection to add the ResNet block outputs. 120 | 121 | Returns: 122 | net: Output tensor with stride equal to the specified output_stride. 123 | 124 | Raises: 125 | ValueError: If the target output_stride is not valid. 
126 | """ 127 | # The current_stride variable keeps track of the effective stride of the 128 | # activations. This allows us to invoke atrous convolution whenever applying 129 | # the next residual unit would result in the activations having stride larger 130 | # than the target output_stride. 131 | current_stride = 1 132 | 133 | # The atrous convolution rate parameter. 134 | rate = 1 135 | 136 | for block in blocks: 137 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 138 | for i, unit in enumerate(block.args): 139 | if output_stride is not None and current_stride > output_stride: 140 | raise ValueError('The target output_stride cannot be reached.') 141 | 142 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 143 | unit_depth, unit_depth_bottleneck, unit_stride = unit 144 | # If we have reached the target output_stride, then we need to employ 145 | # atrous convolution with stride=1 and multiply the atrous rate by the 146 | # current unit's stride for use in subsequent layers. 147 | if output_stride is not None and current_stride == output_stride: 148 | net = block.unit_fn(net, 149 | depth=unit_depth, 150 | depth_bottleneck=unit_depth_bottleneck, 151 | stride=1, 152 | rate=rate) 153 | rate *= unit_stride 154 | 155 | else: 156 | net = block.unit_fn(net, 157 | depth=unit_depth, 158 | depth_bottleneck=unit_depth_bottleneck, 159 | stride=unit_stride, 160 | rate=1) 161 | current_stride *= unit_stride 162 | print(sc.name, net.shape) 163 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 164 | 165 | if output_stride is not None and current_stride != output_stride: 166 | raise ValueError('The target output_stride cannot be reached.') 167 | 168 | return net 169 | 170 | 171 | def resnet_arg_scope(weight_decay=0.0001, 172 | batch_norm_decay=0.997, 173 | batch_norm_epsilon=1e-5, 174 | batch_norm_scale=True): 175 | """Defines the default ResNet arg scope. 176 | 177 | TODO(gpapan): The batch-normalization related default values above are 178 | appropriate for use in conjunction with the reference ResNet models 179 | released at https://github.com/KaimingHe/deep-residual-networks. When 180 | training ResNets from scratch, they might need to be tuned. 181 | 182 | Args: 183 | weight_decay: The weight decay to use for regularizing the model. 184 | batch_norm_decay: The moving average decay when estimating layer activation 185 | statistics in batch normalization. 186 | batch_norm_epsilon: Small constant to prevent division by zero when 187 | normalizing activations by their variance in batch normalization. 188 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 189 | activations in the batch normalization layer. 190 | 191 | Returns: 192 | An `arg_scope` to use for the resnet models. 193 | """ 194 | batch_norm_params = { 195 | 'decay': batch_norm_decay, 196 | 'epsilon': batch_norm_epsilon, 197 | 'scale': batch_norm_scale, 198 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 199 | } 200 | 201 | with slim.arg_scope( 202 | [slim.conv2d], 203 | weights_regularizer=slim.l2_regularizer(weight_decay), 204 | weights_initializer=slim.variance_scaling_initializer(), 205 | activation_fn=tf.nn.relu, 206 | normalizer_fn=slim.batch_norm, 207 | normalizer_params=batch_norm_params): 208 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 209 | # The following implies padding='SAME' for pool1, which makes feature 210 | # alignment easier for dense prediction tasks. This is also used in 211 | # https://github.com/facebook/fb.resnet.torch. 
However the accompanying 212 | # code of 'Deep Residual Learning for Image Recognition' uses 213 | # padding='VALID' for pool1. You can switch to that choice by setting 214 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 215 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 216 | return arg_sc 217 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.contrib import slim 5 | import cv2 6 | 7 | tf.app.flags.DEFINE_integer('input_size', 512, '') 8 | tf.app.flags.DEFINE_integer('batch_size_per_gpu', 3, '') 9 | tf.app.flags.DEFINE_integer('num_readers', 16, '') 10 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, '') 11 | tf.app.flags.DEFINE_integer('max_steps', 100000, '') 12 | tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '') 13 | tf.app.flags.DEFINE_string('gpu_list', '1', '') 14 | tf.app.flags.DEFINE_string('checkpoint_path', './model/', '') 15 | tf.app.flags.DEFINE_boolean('restore', False, 'whether to resotre from checkpoint') 16 | tf.app.flags.DEFINE_integer('save_checkpoint_steps', 100, '') 17 | tf.app.flags.DEFINE_integer('save_summary_steps', 100, '') 18 | tf.app.flags.DEFINE_string('pretrained_model_path', None, '') 19 | 20 | import model as model 21 | import dataf 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | gpus = list(range(len(FLAGS.gpu_list.split(',')))) 26 | 27 | 28 | def tower_loss(images, score_maps_nrow, score_maps_ncol, score_maps_row, 29 | score_maps_col, training_masks, reuse_variables=None): 30 | # Build inference graph 31 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables): 32 | f_score_nrow, f_score_ncol, \ 33 | f_score_row, f_score_col = model.model(images, is_training=True) 34 | 35 | model_loss = model.loss(score_maps_nrow, f_score_nrow, 36 | score_maps_ncol, f_score_ncol, 37 | score_maps_row, f_score_row, 38 | score_maps_col, f_score_col, 39 | training_masks) 40 | total_loss = tf.add_n([model_loss] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 41 | 42 | # add summary 43 | if reuse_variables is None: 44 | tf.summary.image('input', images) 45 | tf.summary.image('score_map', score_maps_nrow) 46 | tf.summary.image('score_map_pred', f_score_nrow * 255) 47 | 48 | tf.summary.image('score_map', score_maps_ncol) 49 | tf.summary.image('score_map_pred', f_score_ncol * 255) 50 | 51 | tf.summary.image('score_map', score_maps_row) 52 | tf.summary.image('score_map_pred', f_score_row * 255) 53 | 54 | tf.summary.image('score_map', score_maps_col) 55 | tf.summary.image('score_map_pred', f_score_col * 255) 56 | 57 | tf.summary.image('training_masks', training_masks) 58 | tf.summary.scalar('model_loss', model_loss) 59 | tf.summary.scalar('total_loss', total_loss) 60 | 61 | return total_loss, model_loss 62 | 63 | 64 | def average_gradients(tower_grads): 65 | average_grads = [] 66 | for grad_and_vars in zip(*tower_grads): 67 | grads = [] 68 | for g, _ in grad_and_vars: 69 | expanded_g = tf.expand_dims(g, 0) 70 | grads.append(expanded_g) 71 | 72 | grad = tf.concat(grads, 0) 73 | grad = tf.reduce_mean(grad, 0) 74 | 75 | v = grad_and_vars[0][1] 76 | grad_and_var = (grad, v) 77 | average_grads.append(grad_and_var) 78 | 79 | return average_grads 80 | 81 | 82 | def main(argv=None): 83 | import os 84 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list 85 | if not tf.gfile.Exists(FLAGS.checkpoint_path): 86 | 
tf.gfile.MkDir(FLAGS.checkpoint_path) 87 | else: 88 | if not FLAGS.restore: 89 | tf.gfile.DeleteRecursively(FLAGS.checkpoint_path) 90 | tf.gfile.MkDir(FLAGS.checkpoint_path) 91 | 92 | input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images') 93 | input_score_maps_nrow = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_nrow') 94 | 95 | input_score_maps_ncol = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_ncol') 96 | 97 | input_score_maps_row = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_row') 98 | 99 | input_score_maps_col = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps_col') 100 | 101 | input_training_masks = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_training_masks') 102 | 103 | global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 104 | learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, decay_steps=10000, decay_rate=0.94, staircase=True) 105 | # add summary 106 | tf.summary.scalar('learning_rate', learning_rate) 107 | opt = tf.train.AdamOptimizer(learning_rate) 108 | # opt = tf.train.MomentumOptimizer(learning_rate, 0.9) 109 | 110 | # split 111 | input_images_split = tf.split(input_images, len(gpus)) 112 | input_score_maps_split_nrow = tf.split(input_score_maps_nrow, len(gpus)) 113 | 114 | input_score_maps_split_ncol = tf.split(input_score_maps_ncol, len(gpus)) 115 | 116 | input_score_maps_split_row = tf.split(input_score_maps_row, len(gpus)) 117 | 118 | input_score_maps_split_col = tf.split(input_score_maps_col, len(gpus)) 119 | input_training_masks_split = tf.split(input_training_masks, len(gpus)) 120 | 121 | tower_grads = [] 122 | reuse_variables = None 123 | for i, gpu_id in enumerate(gpus): 124 | with tf.device('/gpu:%d' % gpu_id): 125 | with tf.name_scope('model_%d' % gpu_id) as scope: 126 | iis = input_images_split[i] 127 | isms_nrow = input_score_maps_split_nrow[i] 128 | 129 | isms_ncol = input_score_maps_split_ncol[i] 130 | 131 | isms_row = input_score_maps_split_row[i] 132 | 133 | isms_col = input_score_maps_split_col[i] 134 | itms = input_training_masks_split[i] 135 | # total_loss, model_loss = tower_loss(iis, isms, igms, itms, reuse_variables) 136 | 137 | total_loss, model_loss = tower_loss(iis, isms_nrow, 138 | isms_ncol, isms_row, 139 | isms_col, itms, reuse_variables) 140 | 141 | batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)) 142 | reuse_variables = True 143 | 144 | grads = opt.compute_gradients(total_loss) 145 | tower_grads.append(grads) 146 | 147 | grads = average_gradients(tower_grads) 148 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) 149 | 150 | summary_op = tf.summary.merge_all() 151 | # save moving average 152 | variable_averages = tf.train.ExponentialMovingAverage( 153 | FLAGS.moving_average_decay, global_step) 154 | variables_averages_op = variable_averages.apply(tf.trainable_variables()) 155 | # batch norm updates 156 | with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]): 157 | train_op = tf.no_op(name='train_op') 158 | 159 | saver = tf.train.Saver(tf.global_variables()) 160 | summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph()) 161 | 162 | init = tf.global_variables_initializer() 163 | 164 | if FLAGS.pretrained_model_path is not None: 165 | variable_restore_op 
    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph())

    init = tf.global_variables_initializer()

    if FLAGS.pretrained_model_path is not None:
        variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
                                                             slim.get_trainable_variables(),
                                                             ignore_missing_vars=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            saver.restore(sess, ckpt)
        else:
            sess.run(init)
            if FLAGS.pretrained_model_path is not None:
                variable_restore_op(sess)

        data_generator = dataf.get_batch(num_workers=FLAGS.num_readers,
                                         input_size=FLAGS.input_size,
                                         batch_size=FLAGS.batch_size_per_gpu * len(gpus))

        start = time.time()
        for step in range(FLAGS.max_steps):
            data = next(data_generator)
            ml, tl, _ = sess.run([model_loss, total_loss, train_op],
                                 feed_dict={input_images: data[0],
                                            input_score_maps_nrow: data[2],
                                            input_score_maps_ncol: data[3],
                                            input_score_maps_row: data[4],
                                            input_score_maps_col: data[5],
                                            input_training_masks: data[6]})
            if np.isnan(tl):
                print('Loss diverged, stop training')
                break

            if step % 10 == 0:
                avg_time_per_step = (time.time() - start) / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu * len(gpus)) / (time.time() - start)
                start = time.time()
                print('Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'.format(
                    step, ml, tl, avg_time_per_step, avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_path + 'model.ckpt', global_step=global_step)

            if step % FLAGS.save_summary_steps == 0:
                _, tl, summary_str = sess.run([train_op, total_loss, summary_op],
                                              feed_dict={input_images: data[0],
                                                         input_score_maps_nrow: data[2],
                                                         input_score_maps_ncol: data[3],
                                                         input_score_maps_row: data[4],
                                                         input_score_maps_col: data[5],
                                                         input_training_masks: data[6]})
                summary_writer.add_summary(summary_str, global_step=step)


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
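Because the loop above saves the raw variables while also maintaining EMA shadows via `variable_averages`, an inference script would normally restore the shadow values rather than the raw weights. A minimal sketch, assuming the model graph has already been rebuilt and that the decay matches the `FLAGS.moving_average_decay` used in training (the 0.997 here is a placeholder, not a value taken from this repo):

```python
import tensorflow as tf

# Sketch only: map each variable to its EMA shadow stored in the checkpoint.
variable_averages = tf.train.ExponentialMovingAverage(0.997)
saver = tf.train.Saver(variable_averages.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./model/'))
```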
/nets/resnet_v1.py:
--------------------------------------------------------------------------------
# from __future__ import absolute_import
# from __future__ import division
# from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import slim

from . import resnet_utils

resnet_arg_scope = resnet_utils.resnet_arg_scope


@slim.add_arg_scope
def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
               outputs_collections=None, scope=None):
    """Bottleneck residual unit variant with BN after convolutions.

    This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
    its definition. Note that we use here the bottleneck variant which has an
    extra bottleneck layer.

    When putting together two consecutive ResNet blocks that use this unit, one
    should use stride = 2 in the last unit of the first block.

    Args:
      inputs: A tensor of size [batch, height, width, channels].
      depth: The depth of the ResNet unit output.
      depth_bottleneck: The depth of the bottleneck layers.
      stride: The ResNet unit's stride. Determines the amount of downsampling of
        the unit's output compared to its input.
      rate: An integer, rate for atrous convolution.
      outputs_collections: Collection to add the ResNet unit output.
      scope: Optional variable_scope.

    Returns:
      The ResNet unit's output.
    """
    with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
        depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
        if depth == depth_in:
            shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
        else:
            shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
                                   activation_fn=None, scope='shortcut')

        residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
                               scope='conv1')
        residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
                                            rate=rate, scope='conv2')
        residual = slim.conv2d(residual, depth, [1, 1], stride=1,
                               activation_fn=None, scope='conv3')

        output = tf.nn.relu(shortcut + residual)

        return slim.utils.collect_named_outputs(outputs_collections,
                                                sc.original_name_scope,
                                                output)


def resnet_v1(inputs,
              blocks,
              num_classes=None,
              is_training=True,
              global_pool=True,
              output_stride=None,
              include_root_block=True,
              spatial_squeeze=True,
              reuse=None,
              scope=None):
    """Generator for v1 ResNet models.

    This function generates a family of ResNet v1 models. See the resnet_v1_*()
    methods for specific model instantiations, obtained by selecting different
    block instantiations that produce ResNets of various depths.

    Training for image classification on ImageNet is usually done with [224, 224]
    inputs, resulting in [7, 7] feature maps at the output of the last ResNet
    block for the ResNets defined in [1] that have nominal stride equal to 32.
    However, for dense prediction tasks we advise that one uses inputs with
    spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
    this case the feature maps at the ResNet output will have spatial shape
    [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
    and corners exactly aligned with the input image corners, which greatly
    facilitates alignment of the features to the image. Using as input [225, 225]
    images results in [8, 8] feature maps at the output of the last ResNet block.

    For dense prediction tasks, the ResNet needs to run in fully-convolutional
    (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
    have nominal stride equal to 32, and a good choice in FCN mode is to use
    output_stride=16 in order to increase the density of the computed features at
    small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.

    Args:
      inputs: A tensor of size [batch, height_in, width_in, channels].
      blocks: A list of length equal to the number of ResNet blocks. Each element
        is a resnet_utils.Block object describing the units in the block.
      num_classes: Number of predicted classes for classification tasks. If None,
        we return the features before the logit layer.
      is_training: whether the network is in training mode.
      global_pool: If True, we perform global average pooling before computing the
        logits. Set to True for image classification, False for dense prediction.
      output_stride: If None, then the output will be computed at the nominal
        network stride. If output_stride is not None, it specifies the requested
        ratio of input to output spatial resolution.
      include_root_block: If True, include the initial convolution followed by
        max-pooling; if False, exclude it.
      spatial_squeeze: If True, logits is of shape [B, C]; if False, logits is
        of shape [B, 1, 1, C], where B is batch_size and C is the number of classes.
      reuse: whether or not the network and its variables should be reused. To be
        able to reuse, 'scope' must be given.
      scope: Optional variable_scope.

    Returns:
      net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
        If global_pool is False, then height_out and width_out are reduced by a
        factor of output_stride compared to the respective height_in and width_in;
        else both height_out and width_out equal one. If num_classes is None, then
        net is the output of the last ResNet block, potentially after global
        average pooling. If num_classes is not None, net contains the pre-softmax
        activations.
      end_points: A dictionary from components of the network to the corresponding
        activation.

    Raises:
      ValueError: If the target output_stride is not valid.
    """
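    # Worked example of the formula above: a [321, 321] input with
    # output_stride=16 gives feature maps of spatial shape
    # [(321 - 1) / 16 + 1, (321 - 1) / 16 + 1] = [21, 21], with feature-map
    # corners aligned to the input image corners.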
    with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
        end_points_collection = sc.name + '_end_points'
        with slim.arg_scope([slim.conv2d, bottleneck,
                             resnet_utils.stack_blocks_dense],
                            outputs_collections=end_points_collection):
            with slim.arg_scope([slim.batch_norm], is_training=is_training):
                net = inputs
                if include_root_block:
                    if output_stride is not None:
                        if output_stride % 4 != 0:
                            raise ValueError('The output_stride needs to be a multiple of 4.')
                        # integer division keeps the stride bookkeeping integral under Python 3
                        output_stride //= 4
                    net = resnet_utils.conv2d_same(net, 64, 7, stride=1, scope='conv1')
                    net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')

                    net = slim.utils.collect_named_outputs(end_points_collection, 'pool2', net)

                net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)

                end_points = slim.utils.convert_collection_to_dict(end_points_collection)

                # end_points['pool2'] = end_points['resnet_v1_50/pool1/MaxPool:0']
                try:
                    end_points['pool3'] = end_points['resnet_v1_50/block1']
                    end_points['pool4'] = end_points['resnet_v1_50/block2']
                except KeyError:
                    end_points['pool3'] = end_points['Detection/resnet_v1_50/block1']
                    end_points['pool4'] = end_points['Detection/resnet_v1_50/block2']
                end_points['pool5'] = net
                # if global_pool:
                #     # Global average pooling.
                #     net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
                # if num_classes is not None:
                #     net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                #                       normalizer_fn=None, scope='logits')
                #     if spatial_squeeze:
                #         logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
                #     else:
                #         logits = net
                # # Convert end_points_collection into a dictionary of end_points.
                # end_points = slim.utils.convert_collection_to_dict(end_points_collection)
                # if num_classes is not None:
                #     end_points['predictions'] = slim.softmax(logits, scope='predictions')
                return net, end_points


resnet_v1.default_image_size = 224
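# In the Block specs below, each (depth, depth_bottleneck, stride) tuple
# parameterizes one bottleneck() unit; e.g. [(256, 64, 1)] * 2 + [(256, 64, 2)]
# builds three units whose last one downsamples with stride 2, as the
# bottleneck() docstring recommends for consecutive blocks.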
def resnet_v1_50(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='resnet_v1_50'):
    """ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
    blocks = [
        resnet_utils.Block(
            'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
        resnet_utils.Block(
            'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
        resnet_utils.Block(
            'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
        resnet_utils.Block(
            'block4', bottleneck, [(2048, 512, 1)] * 3)
    ]
    return resnet_v1(inputs, blocks, num_classes, is_training,
                     global_pool=global_pool, output_stride=output_stride,
                     include_root_block=True, spatial_squeeze=spatial_squeeze,
                     reuse=reuse, scope=scope)


resnet_v1_50.default_image_size = resnet_v1.default_image_size


def resnet_v1_101(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=True,
                  reuse=None,
                  scope='resnet_v1_101'):
    """ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
    blocks = [
        resnet_utils.Block(
            'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
        resnet_utils.Block(
            'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
        resnet_utils.Block(
            'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
        resnet_utils.Block(
            'block4', bottleneck, [(2048, 512, 1)] * 3)
    ]
    return resnet_v1(inputs, blocks, num_classes, is_training,
                     global_pool=global_pool, output_stride=output_stride,
                     include_root_block=True, spatial_squeeze=spatial_squeeze,
                     reuse=reuse, scope=scope)


resnet_v1_101.default_image_size = resnet_v1.default_image_size


def resnet_v1_152(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=True,
                  reuse=None,
                  scope='resnet_v1_152'):
    """ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
    blocks = [
        resnet_utils.Block(
            'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
        resnet_utils.Block(
            'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]),
        resnet_utils.Block(
            'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
        resnet_utils.Block(
            'block4', bottleneck, [(2048, 512, 1)] * 3)]
    return resnet_v1(inputs, blocks, num_classes, is_training,
                     global_pool=global_pool, output_stride=output_stride,
                     include_root_block=True, spatial_squeeze=spatial_squeeze,
                     reuse=reuse, scope=scope)


resnet_v1_152.default_image_size = resnet_v1.default_image_size
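# The depth variants differ only in how many bottleneck units each block
# repeats: ResNet-50 is 3-4-6-3, ResNet-101 is 3-4-23-3, ResNet-152 is
# 3-8-36-3, and ResNet-200 below is 3-24-36-3.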
def resnet_v1_200(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=True,
                  reuse=None,
                  scope='resnet_v1_200'):
    """ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
    blocks = [
        resnet_utils.Block(
            'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
        resnet_utils.Block(
            'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]),
        resnet_utils.Block(
            'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
        resnet_utils.Block(
            'block4', bottleneck, [(2048, 512, 1)] * 3)]
    return resnet_v1(inputs, blocks, num_classes, is_training,
                     global_pool=global_pool, output_stride=output_stride,
                     include_root_block=True, spatial_squeeze=spatial_squeeze,
                     reuse=reuse, scope=scope)


resnet_v1_200.default_image_size = resnet_v1.default_image_size


if __name__ == '__main__':
    # renamed from `input` to avoid shadowing the Python builtin
    inputs = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input')
    with slim.arg_scope(resnet_arg_scope()) as sc:
        logits = resnet_v1_50(inputs)
--------------------------------------------------------------------------------
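For reference, a minimal sketch (not part of the repository) of how this modified resnet_v1_50 can serve as the dense feature extractor behind the row/column score maps. The pool3/pool4/pool5 keys come straight from the end_points wiring above; the 512x512 input size and the run-from-repo-root import are assumptions:

```python
import tensorflow as tf
from tensorflow.contrib import slim
from nets import resnet_v1

images = tf.placeholder(tf.float32, shape=[None, 512, 512, 3], name='images')
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
    # The global-pooling branch is commented out in this resnet_v1, so the
    # network always runs fully convolutionally and returns dense features.
    net, end_points = resnet_v1.resnet_v1_50(images, is_training=False)

# Multi-scale features exposed by the modified resnet_v1 above:
pool3 = end_points['pool3']  # output of block1
pool4 = end_points['pool4']  # output of block2
pool5 = end_points['pool5']  # output of the last block
```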