├── 70-lstm-layout.pdf
├── 80-lstm-text.pdf
├── README.md
├── lines
│   └── 0.png
├── make_training_labels
│   ├── K47LYBN.framed.png
│   ├── K47LYBN.framed.txt
│   ├── K47LYBN.lines.png
│   ├── W001.png
│   ├── data_pre_process.py
│   └── out.png
├── md_lstm.py
├── segmentation.py
└── train_test.py
/70-lstm-layout.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/70-lstm-layout.pdf
--------------------------------------------------------------------------------
/80-lstm-text.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/80-lstm-text.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# OCR Segmentation
A deep learning model for page layout analysis / segmentation.

## dependencies
- TensorFlow 1.8
- Python 3

## dataset
[uw3-framed-lines-degraded-000](https://storage.googleapis.com/tmb-ocr/uw3-framed-lines-degraded-000.tgz)

## make training labels
python3 data_pre_process.py

## train
python3 train_test.py

## test
python3 segmentation.py
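
`segmentation.py` also exposes a `Segmenter` class; a minimal usage sketch based on its `__main__` block (a trained checkpoint under `./save/` is assumed):

```python
import cv2
from segmentation import Segmenter

seg = Segmenter()                                      # restores the trained model
image = cv2.imread("./make_training_labels/W001.png")  # sample page in this repo
lines = seg.extract_textlines(image)                   # dicts: label, bounds, mask, image
for num, line in enumerate(lines):
    cv2.imwrite("./lines/%d.png" % num, line["image"] * 255)
cv2.imwrite("out.png", seg.lines * 255)                # line-center probability map
```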
![image](https://github.com/watersink/ocrsegment/blob/master/make_training_labels/W001.png)
![image](https://github.com/watersink/ocrsegment/blob/master/make_training_labels/out.png)
![image](https://github.com/watersink/ocrsegment/blob/master/lines/0.png)

## references
[Multi-Dimensional Recurrent Neural Networks](https://arxiv.org/abs/0705.2011)<br>
[Robust, Simple Page Segmentation Using Hybrid Convolutional MDLSTM Networks](https://github.com/wanghaisheng/awesome-ocr/files/2042377/Robust_.Simple.Page.Segmentation.Using.Hybrid.Convolutional.MDLSTM.Networks.pdf)<br>
[https://github.com/NVlabs/ocroseg](https://github.com/NVlabs/ocroseg)<br>
[https://github.com/philipperemy/tensorflow-multi-dimensional-lstm](https://github.com/philipperemy/tensorflow-multi-dimensional-lstm)<br>
--------------------------------------------------------------------------------
/lines/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/lines/0.png
--------------------------------------------------------------------------------
/make_training_labels/K47LYBN.framed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/K47LYBN.framed.png
--------------------------------------------------------------------------------
/make_training_labels/K47LYBN.framed.txt:
--------------------------------------------------------------------------------
939,259,1436,305
1073,3236,1118,3265
262,448,2104,504
957,2918,1395,2964
864,2982,1487,3036
258,539,373,587
356,621,2104,671
263,699,2102,755
259,790,1204,839
357,867,2103,922
262,956,2094,1006
257,1040,2102,1090
258,1123,2102,1179
258,1200,2100,1255
257,1289,2102,1338
258,1366,2104,1428
257,1449,1520,1504
354,1532,2075,1587
256,1614,2101,1677
255,1704,2105,1760
256,1782,688,1837
--------------------------------------------------------------------------------
/make_training_labels/K47LYBN.lines.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/K47LYBN.lines.png
--------------------------------------------------------------------------------
/make_training_labels/W001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/W001.png
--------------------------------------------------------------------------------
/make_training_labels/data_pre_process.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import os
import scipy.ndimage as ndi


def _get_text_line_image(image, line_label):
    """Build the centerline label for one text line, given its x1,y1,x2,y2 box."""
    x1, y1, x2, y2 = line_label
    h = y2 - y1
    image_line = image[y1:y2, x1:x2]

    # Fill the bounding box of every connected component so broken strokes merge.
    nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(image_line)

    image_connect = np.zeros(image_line.shape, dtype=np.uint8)

    for i in range(1, stats.shape[0]):
        x1 = stats[i, 0]
        y1 = stats[i, 1]
        x2 = stats[i, 2] + x1
        y2 = stats[i, 3] + y1
        image_connect[y1:y2, x1:x2] = 255

    # Anisotropic Gaussian (sigma = [rows, cols] = [h/2, h]); the per-column
    # argmax then traces the vertical center of the text line.
    image_anigauss = ndi.gaussian_filter(image_connect, sigma=[h / 2, h], mode='constant', cval=0)
    image_line_center_index = np.argmax(image_anigauss, 0)

    # Smooth the centerline and burn it into the response map.
    image_center_gauss_index = ndi.gaussian_filter1d(image_line_center_index, sigma=h / 3)
    for k in range(len(image_center_gauss_index)):
        image_anigauss[image_center_gauss_index[k], k] = 255

    image_bin2 = (np.uint8(image_anigauss) > 254) * 255.0

    # Thicken the one-pixel centerline into a usable label stroke.
    kernel = np.ones((3, 3), np.uint8)
    image_dilation = cv2.dilate(image_bin2, kernel, iterations=3)
    temp = (np.uint8(image_dilation) > 254) * 255

    return temp
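
def make_labels_for_directory(pattern="*.framed.png"):
    """Convenience wrapper (a sketch, not part of the original script): run the
    same pipeline as the __main__ block below over every framed page matching
    `pattern`; the glob pattern follows the uw3 naming scheme, e.g.
    K47LYBN.framed.png / K47LYBN.framed.txt."""
    import glob
    for image_name in glob.glob(pattern):
        image = cv2.imread(image_name, 0)
        _, image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        image = 255 - image
        label_image = np.zeros_like(image)
        text_file = os.path.splitext(image_name)[0] + '.txt'
        with open(text_file, "r", encoding="utf-8") as f:
            boxes = [l.rstrip("\n").split(",") for l in f.readlines()]
        for line in np.asarray(boxes, np.int32):
            x1, y1, x2, y2 = line
            label_image[y1:y2, x1:x2] = _get_text_line_image(image, line)
        cv2.imwrite(image_name.replace("framed", "lines"), label_image)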

if __name__ == '__main__':
    image_name = "K47LYBN.framed.png"
    image = cv2.imread(image_name, 0)
    _, image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # Invert so that ink is foreground (255) and paper is background (0).
    image = 255 - image
    label_image = np.zeros_like(image)
    text_file = os.path.splitext(image_name)[0] + '.txt'
    with open(text_file, "r", encoding="utf-8") as f:
        labels = f.readlines()
    labels = [l.rstrip("\n").split(",") for l in labels]
    label_arrays = np.asarray(labels, np.int32)
    for line in label_arrays:
        label_line_image = _get_text_line_image(image, line)
        x1, y1, x2, y2 = line
        label_image[y1:y2, x1:x2] = label_line_image

    cv2.imwrite(image_name.replace("framed", "lines"), label_image)
--------------------------------------------------------------------------------
/make_training_labels/out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/out.png
--------------------------------------------------------------------------------
/md_lstm.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell, LSTMStateTuple
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _linear
from tensorflow.python.ops.rnn import dynamic_rnn


def ln(tensor, scope=None, epsilon=1e-5):
    """Layer-normalizes a 2D tensor along its second axis."""
    assert (len(tensor.get_shape()) == 2)
    m, v = tf.nn.moments(tensor, [1], keep_dims=True)
    if not isinstance(scope, str):
        scope = ''
    with tf.variable_scope(scope + 'layer_norm'):
        scale = tf.get_variable('scale',
                                shape=[tensor.get_shape()[1]],
                                initializer=tf.constant_initializer(1))
        shift = tf.get_variable('shift',
                                shape=[tensor.get_shape()[1]],
                                initializer=tf.constant_initializer(0))
    ln_initial = (tensor - m) / tf.sqrt(v + epsilon)

    return ln_initial * scale + shift


class MultiDimensionalLSTMCell(RNNCell):
    """
    Adapted from TF's BasicLSTMCell to use Layer Normalization.
    Note that state_is_tuple is always True.
    """

    def __init__(self, num_units, forget_bias=0.0, activation=tf.nn.tanh):
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._activation = activation

    @property
    def state_size(self):
        return LSTMStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        """Multi-dimensional long short-term memory cell (2D LSTM).

        @param inputs: input tensor of shape (batch, n)
        @param state: the cell and hidden states of the two predecessor cells,
            as a tuple (c1, c2, h1, h2)
        """
        with tf.variable_scope(scope or type(self).__name__):
            c1, c2, h1, h2 = state

            # change bias argument to False since LN will add bias via shift
            concat = _linear([inputs, h1, h2], 5 * self._num_units, False)

            i, j, f1, f2, o = tf.split(value=concat, num_or_size_splits=5, axis=1)

            # add layer normalization to each gate
            i = ln(i, scope='i/')
            j = ln(j, scope='j/')
            f1 = ln(f1, scope='f1/')
            f2 = ln(f2, scope='f2/')
            o = ln(o, scope='o/')

            # one forget gate per predecessor:
            # new_c = c_up * sigmoid(f1) + c_left * sigmoid(f2) + sigmoid(i) * tanh(j)
            new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) +
                     c2 * tf.nn.sigmoid(f2 + self._forget_bias) + tf.nn.sigmoid(i) *
                     self._activation(j))

            # add layer normalization in the calculation of the new hidden state
            new_h = self._activation(ln(new_c, scope='new_h/')) * tf.nn.sigmoid(o)
            new_state = LSTMStateTuple(new_c, new_h)

            return new_h, new_state
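
# A minimal smoke test of the cell (an illustrative sketch, not part of the
# original module; TF 1.x with contrib is assumed, and shapes are arbitrary):
#
#   cell = MultiDimensionalLSTMCell(num_units=8)
#   x = tf.zeros([2, 4])                                # (batch, features)
#   zero = tf.zeros([2, 8])
#   h, new_state = cell(x, (zero, zero, zero, zero))    # state = (c_up, c_left, h_up, h_left)
#   # h: (2, 8); new_state: LSTMStateTuple(c, h)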

def multi_dimensional_rnn_while_loop(rnn_size, input_data, sh, dims=None, scope_n="layer1"):
    """Implements a naive multi-dimensional recurrent neural network.

    @param rnn_size: the number of hidden units
    @param input_data: the data to process, of shape [batch, h, w, channels]
    @param sh: [height, width] of the windows
    @param dims: dimensions along which to reverse the input data, e.g.
        dims=[False, True, True, False] => True means reverse that dimension
    @param scope_n: the scope

    returns [batch, h/sh[0], w/sh[1], rnn_size]: the output of the lstm
    """

    with tf.variable_scope("MultiDimensionalLSTMCell-" + scope_n):

        # Create a multidimensional cell with the selected size
        cell = MultiDimensionalLSTMCell(rnn_size)

        # Get the shape of the input (batch_size, x, y, channels)
        shape = input_data.get_shape().as_list()
        batch_size = shape[0]
        X_dim = shape[1]
        Y_dim = shape[2]
        channels = shape[3]
        # Window size
        X_win = sh[0]
        Y_win = sh[1]
        # Get the runtime batch size
        batch_size_runtime = tf.shape(input_data)[0]

        # If the input cannot be tiled exactly by the window, pad it with zeros
        if X_dim % X_win != 0:
            # Get offset size
            offset = tf.zeros([batch_size_runtime, X_win - (X_dim % X_win), Y_dim, channels])
            # Concatenate along the X dimension
            input_data = tf.concat(axis=1, values=[input_data, offset])
            # Get new shape
            shape = input_data.get_shape().as_list()
            # Update shape value
            X_dim = shape[1]

        # The same but for the Y axis
        if Y_dim % Y_win != 0:
            # Get offset size
            offset = tf.zeros([batch_size_runtime, X_dim, Y_win - (Y_dim % Y_win), channels])
            # Concatenate along the Y dimension
            input_data = tf.concat(axis=2, values=[input_data, offset])
            # Get new shape
            shape = input_data.get_shape().as_list()
            # Update shape value
            Y_dim = shape[2]

        # Get the number of steps to perform along the X and Y axes
        h, w = int(X_dim / X_win), int(Y_dim / Y_win)

        # Get the number of features (total number of input values per step)
        features = Y_win * X_win * channels

        # Reshape input data to a tensor containing the step indexes and feature inputs
        # The batch size is inferred from the tensor size
        x = tf.reshape(input_data, [batch_size_runtime, h, w, features])

        # Reverse the selected dimensions
        if dims is not None:
            assert dims[0] is False and dims[3] is False
            # NOTE: in TF 1.x tf.reverse expects a list of axis *indices*; the
            # boolean-mask form used here dates from TF 0.x, and this branch is
            # not exercised by train_test.py (which passes dims=None).
            x = tf.reverse(x, dims)
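        # Note on the scan order (explanatory note, not original code): after
        # the transpose/reshape below, linear step index t visits position
        # (row, col) = (t // w, t % w) in raster order, so the neighbour above
        # lives at index t - w (see get_up) and the left neighbour at t - 1
        # (see get_last).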

        # Reorder inputs to (h, w, batch_size, features)
        x = tf.transpose(x, [1, 2, 0, 3])
        # Reshape to a one-dimensional tensor of (h*w*batch_size, features)
        x = tf.reshape(x, [-1, features])
        # Split the tensor into h*w tensors of size (batch_size, features)
        x = tf.split(axis=0, num_or_size_splits=h * w, value=x)
        # Create an input tensor array (literally an array of tensors) to use inside the loop
        inputs_ta = tf.TensorArray(dtype=tf.float32, size=h * w, name='input_ta')
        # Unstack the input X into the tensor array
        inputs_ta = inputs_ta.unstack(x)
        # Create a tensor array for the states
        states_ta = tf.TensorArray(dtype=tf.float32, size=h * w + 1, name='state_ta', clear_after_read=False)
        # And another one for the outputs
        outputs_ta = tf.TensorArray(dtype=tf.float32, size=h * w, name='output_ta')
        # Initial cell hidden states:
        # write an LSTMStateTuple filled with zeros to the last position of the array
        states_ta = states_ta.write(h * w, LSTMStateTuple(tf.zeros([batch_size_runtime, rnn_size], tf.float32),
                                                          tf.zeros([batch_size_runtime, rnn_size], tf.float32)))

        # Function to get the state one row up (skip one full row of samples)
        def get_up(t_, w_):
            return t_ - tf.constant(w_)

        # Function to get the previous state in the same row
        def get_last(t_, w_):
            return t_ - tf.constant(1)

        # Controls the initial index
        time = tf.constant(0)
        zero = tf.constant(0)

        # Body of the while loop operation that applies the MD LSTM
        def body(time_, outputs_ta_, states_ta_):

            # Cells in the first row (time_ < w) have no upper neighbour and
            # read the zero state stored at index h*w; all others read the
            # state located one row width away.
            state_up = tf.cond(tf.less(time_, tf.constant(w)),
                               lambda: states_ta_.read(h * w),
                               lambda: states_ta_.read(get_up(time_, w)))

            # If it is the first step of a row we read the zero state,
            # otherwise we read the immediately preceding state
            state_last = tf.cond(tf.less(zero, tf.mod(time_, tf.constant(w))),
                                 lambda: states_ta_.read(get_last(time_, w)),
                                 lambda: states_ta_.read(h * w))

            # We build the input state from both dimensions
            current_state = state_up[0], state_last[0], state_up[1], state_last[1]
            # Now we calculate the cell output and the new state
            out, state = cell(inputs_ta.read(time_), current_state)
            # We write the output to the output tensor array
            outputs_ta_ = outputs_ta_.write(time_, out)
            # And save the state to the state tensor array
            states_ta_ = states_ta_.write(time_, state)

            # Return the outputs and the incremented time step
            return time_ + 1, outputs_ta_, states_ta_
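        # Note (not original code): parallel_iterations must stay at 1 below,
        # because every step reads states that earlier steps wrote into
        # states_ta; the raster scan is inherently sequential.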
        # Loop condition: the time index should stay below the total
        # number of steps defined within the image
        def condition(time_, outputs_ta_, states_ta_):
            return tf.less(time_, tf.constant(h * w))

        # Run the looped operation
        result, outputs_ta, states_ta = tf.while_loop(condition, body, [time, outputs_ta, states_ta],
                                                      parallel_iterations=1)

        # Extract the output tensors from the processed tensor array
        outputs = outputs_ta.stack()
        states = states_ta.stack()

        # Reshape the outputs to match the shape of the input
        y = tf.reshape(outputs, [h, w, batch_size_runtime, rnn_size])

        # Reorder the dimensions to match the input
        y = tf.transpose(y, [2, 0, 1, 3])
        # Reverse if selected
        if dims is not None:
            y = tf.reverse(y, dims)

        # Return the output and the inner states
        return y, states


def horizontal_standard_lstm(input_data, rnn_size, scope_n="layer1"):
    with tf.variable_scope("MultiDimensionalLSTMCell-" + scope_n):
        # input is (b, h, w, c)
        b, _, _, c = input_data.get_shape().as_list()
        h, w = tf.shape(input_data)[1], tf.shape(input_data)[2]
        # fold the height into the batch: every row becomes one horizontal sequence
        new_input_data = tf.reshape(input_data, (b * h, w, c))  # horizontal.
        rnn_out, _ = tf.nn.bidirectional_dynamic_rnn(
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            inputs=new_input_data,
            dtype=tf.float32,
            time_major=False)
        rnn_out = tf.concat(rnn_out, 2)

        rnn_out = tf.reshape(rnn_out, (b, h, w, rnn_size))
        return rnn_out


def snake_standard_lstm(input_data, rnn_size, scope_n="layer1"):
    with tf.variable_scope("MultiDimensionalLSTMCell-" + scope_n):
        # input is (b, h, w, c)
        b, _, _, c = input_data.get_shape().as_list()
        h, w = tf.shape(input_data)[1], tf.shape(input_data)[2]
        # flatten the whole page into one long raster sequence
        new_input_data = tf.reshape(input_data, (b, w * h, c))  # snake.
        rnn_out, _ = tf.nn.bidirectional_dynamic_rnn(
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            inputs=new_input_data,
            dtype=tf.float32,
            time_major=False)
        rnn_out = tf.concat(rnn_out, 2)

        rnn_out = tf.reshape(rnn_out, (b, h, w, rnn_size))
        return rnn_out
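
# All variants above and below share one interface: (b, h, w, c) in,
# (b, h, w, rnn_size) out. A usage sketch (shapes are illustrative; the batch
# size must be static because of the (b * h, w, c) reshapes):
#
#   x = tf.placeholder(tf.float32, [1, None, None, 32])   # (b, h, w, c)
#   y = horizontal_standard_lstm(x, rnn_size=64)          # -> (1, h, w, 64)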

def horizontal_vertical_lstm_inorder(input_data, rnn_size, scope_n="layer1"):
    # input is (b, h, w, c); a horizontal pass runs first, then its output is
    # fed to a vertical pass (the two passes are applied in order).
    with tf.variable_scope("MultiDimensionalLSTMCell-horizontal-" + scope_n):
        # horizontal
        b_h, _, _, c_h = input_data.get_shape().as_list()
        h_h, w_h = tf.shape(input_data)[1], tf.shape(input_data)[2]
        # fold the height into the batch: every row is one horizontal sequence
        new_input_data_h = tf.reshape(input_data, (b_h * h_h, w_h, c_h))  # horizontal.
        rnn_out_h, _ = tf.nn.bidirectional_dynamic_rnn(
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            inputs=new_input_data_h,
            dtype=tf.float32,
            time_major=False)
        rnn_out_h = tf.concat(rnn_out_h, 2)

        rnn_out_h = tf.reshape(rnn_out_h, (b_h, h_h, w_h, rnn_size))
    # vertical
    with tf.variable_scope("MultiDimensionalLSTMCell-vertical-" + scope_n):
        new_input_data_v = tf.transpose(rnn_out_h, (0, 2, 1, 3))
        b_v, _, _, c_v = new_input_data_v.get_shape().as_list()
        h_v, w_v = tf.shape(new_input_data_v)[1], tf.shape(new_input_data_v)[2]
        new_input_data_v = tf.reshape(new_input_data_v, (b_v * h_v, w_v, c_v))
        rnn_out_v, _ = tf.nn.bidirectional_dynamic_rnn(
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            tf.contrib.rnn.LSTMCell(rnn_size // 2),
            inputs=new_input_data_v,
            dtype=tf.float32,
            time_major=False)
        rnn_out_v = tf.concat(rnn_out_v, 2)

        rnn_out_v = tf.reshape(rnn_out_v, (b_v, h_v, w_v, rnn_size))
        rnn_out_v = tf.transpose(rnn_out_v, (0, 2, 1, 3))
        return rnn_out_v


def horizontal_vertical_lstm_together(input_data, rnn_size, scope_n="layer1"):
    # input is (b, h, w, c); horizontal and vertical passes run independently
    # on the input and their features are concatenated at the end.
    with tf.variable_scope("MultiDimensionalLSTMCell-horizontal-" + scope_n):
        # horizontal
        b_h, _, _, c_h = input_data.get_shape().as_list()
        h_h, w_h = tf.shape(input_data)[1], tf.shape(input_data)[2]
        new_input_data_h = tf.reshape(input_data, (b_h * h_h, w_h, c_h))  # horizontal.
        # Forward
        lstm_fw_cell = tf.contrib.rnn.LSTMCell(rnn_size // 4)
        # lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=0.5)
        # Backward
        lstm_bw_cell = tf.contrib.rnn.LSTMCell(rnn_size // 4)
        # lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=0.5)

        rnn_out_h, _ = tf.nn.bidirectional_dynamic_rnn(
            lstm_fw_cell,
            lstm_bw_cell,
            inputs=new_input_data_h,
            dtype=tf.float32,
            time_major=False)
        rnn_out_h = tf.concat(rnn_out_h, 2)
        rnn_out_h = tf.reshape(rnn_out_h, (b_h, h_h, w_h, rnn_size // 2))
    # vertical
    with tf.variable_scope("MultiDimensionalLSTMCell-vertical-" + scope_n):
        new_input_data_v = tf.transpose(input_data, (0, 2, 1, 3))
        b_v, _, _, c_v = new_input_data_v.get_shape().as_list()
        h_v, w_v = tf.shape(new_input_data_v)[1], tf.shape(new_input_data_v)[2]
        new_input_data_v = tf.reshape(new_input_data_v, (b_v * h_v, w_v, c_v))
        # Forward
        lstm_fw_cell = tf.contrib.rnn.LSTMCell(rnn_size // 4)
        # lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=0.5)
        # Backward
        lstm_bw_cell = tf.contrib.rnn.LSTMCell(rnn_size // 4)
        # lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=0.5)

        rnn_out_v, _ = tf.nn.bidirectional_dynamic_rnn(
            lstm_fw_cell,
            lstm_bw_cell,
            inputs=new_input_data_v,
            dtype=tf.float32,
            time_major=False)
        rnn_out_v = tf.concat(rnn_out_v, 2)

        rnn_out_v = tf.reshape(rnn_out_v, (b_v, h_v, w_v, rnn_size // 2))
        rnn_out_v = tf.transpose(rnn_out_v, (0, 2, 1, 3))
        rnn_out = tf.concat([rnn_out_h, rnn_out_v], axis=3)
        # rnn_out = tf.add(rnn_out_h, rnn_out_v)
        return rnn_out
--------------------------------------------------------------------------------
/segmentation.py:
--------------------------------------------------------------------------------
from numpy import *
import tensorflow as tf
from md_lstm import horizontal_vertical_lstm_inorder

import os
import cv2
import numpy as np
import scipy.ndimage as ndi
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


class record:
    def __init__(self, **kw):
        self.__dict__.update(kw)


def sl_width(s):
    return s.stop - s.start


def sl_area(s):
    return sl_width(s[0]) * sl_width(s[1])


def sl_dim0(s):
    return sl_width(s[0])


def sl_dim1(s):
    return sl_width(s[1])


def sl_tuple(s):
    return s[0].start, s[0].stop, s[1].start, s[1].stop


def hysteresis_threshold(image, lo, hi):
    """Keep every connected component of (image > lo) that contains at least
    one pixel above hi."""
    binlo = (image > lo)
    lablo, n = ndi.label(binlo)
    n += 1
    good = set((lablo * (image > hi)).flat)
    markers = zeros(n, 'i')
    for index in good:
        if index == 0:
            continue
        markers[index] = 1
    return markers[lablo]
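
# A tiny worked example of the hysteresis rule (illustrative values, not
# original code):
#
#   img = np.array([[0.2, 0.6, 0.9, 0.6, 0.2],
#                   [0.0, 0.6, 0.0, 0.0, 0.6]])
#   hysteresis_threshold(img, lo=0.5, hi=0.8)
#   # the component {0.6, 0.9, 0.6, 0.6} survives because it contains 0.9 > hi;
#   # the isolated 0.6 at [1, 4] is dropped.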

def zoom_like(image, shape):
    """Resample `image` to the given (h, w) shape with bilinear interpolation."""
    h, w = shape
    ih, iw = image.shape
    scale = diag([ih * 1.0 / h, iw * 1.0 / w])
    return ndi.affine_transform(image, scale, output_shape=(h, w), order=1)


def remove_big(image, max_h=100, max_w=100):
    """Remove components larger than max_h x max_w."""
    assert image.ndim == 2
    bin = (image > 0.5 * amax(image))
    labels, n = ndi.label(bin)
    objects = ndi.find_objects(labels)
    indexes = ones(n + 1, 'i')
    for i, (yr, xr) in enumerate(objects):
        if yr.stop - yr.start < max_h and xr.stop - xr.start < max_w:
            continue
        indexes[i + 1] = 0
    indexes[0] = 0
    return indexes[labels]


def binary_objects(binary):
    labels, n = ndi.label(binary)
    objects = ndi.find_objects(labels)
    return objects


def compute_boxmap(binary, lo=10, hi=5000, dtype='i'):
    """Mark the bounding boxes of components whose sqrt-area lies in [lo, hi]."""
    objects = binary_objects(binary)
    bysize = sorted(objects, key=sl_area)
    boxmap = zeros(binary.shape, dtype)
    for o in bysize:
        if sl_area(o) ** .5 < lo:
            continue
        if sl_area(o) ** .5 > hi:
            continue
        boxmap[o] = 1
    return boxmap


def propagate_labels(image, labels, conflict=0):
    """Given an image and a set of labels, apply the labels
    to all the regions in the image that overlap a label.
    Assign the value `conflict` to any region that overlaps more than one label."""
    rlabels, _ = ndi.label(image)
    cors = correspondences(rlabels, labels)
    outputs = zeros(amax(rlabels) + 1, 'i')
    oops = -(1 << 30)
    for o, i in cors.T:
        if outputs[o] != 0:
            outputs[o] = oops
        else:
            outputs[o] = i
    outputs[outputs == oops] = conflict
    outputs[0] = 0
    return outputs[rlabels]


def correspondences(labels1, labels2):
    """Given two labeled images, compute an array giving the correspondences
    between labels in the two images."""
    q = 100000
    assert amin(labels1) >= 0 and amin(labels2) >= 0
    assert amax(labels2) < q
    combo = labels1 * q + labels2
    result = unique(combo)
    result = array([result // q, result % q])
    return result


def spread_labels(labels, maxdist=9999999):
    """Spread the given labels to the background."""
    distances, features = ndi.distance_transform_edt(
        labels == 0, return_distances=1, return_indices=1)
    indexes = features[0] * labels.shape[1] + features[1]
    spread = labels.ravel()[indexes.ravel()].reshape(*labels.shape)
    spread *= (distances < maxdist)
    return spread


def estimate_scale(binary):
    """Estimate the typical character scale as the median sqrt-area of the
    small-to-medium connected components."""
    objects = binary_objects(binary)
    bysize = sorted(objects, key=sl_area)
    scalemap = zeros(binary.shape)
    for o in bysize:
        if amax(scalemap[o]) > 0:
            continue
        scalemap[o] = sl_area(o) ** 0.5
    scale = median(scalemap[(scalemap > 3) & (scalemap < 100)])
    return scale


def compute_lines(segmentation, scale):
    """Given a line segmentation map, computes a list
    of dicts, each consisting of a label, a 2D slice (bounds) and a mask."""
    lobjects = ndi.find_objects(segmentation)
    lines = []
    for i, o in enumerate(lobjects):
        if o is None:
            continue
        if sl_dim1(o) < 2 * scale or sl_dim0(o) < scale:
            continue
        mask = (segmentation[o] == i + 1)
        if amax(mask) == 0:
            continue
        result = dict(label=i + 1,
                      bounds=o,
                      mask=mask)
        lines.append(result)
    return lines
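
# Sketch of consuming compute_lines output (illustrative, not original code;
# `segmentation` is a labeled line map such as line_segmentation below returns):
#
#   scale = estimate_scale(binary)
#   for line in compute_lines(segmentation, scale):
#       ys, xs = line["bounds"]
#       crop = image[ys.start:ys.stop, xs.start:xs.stop] * line["mask"]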

def pad_image(image, d, cval=None):
    result = ones(array(image.shape) + 2 * d)
    result[:, :] = amax(image) if cval is None else cval
    result[d:-d, d:-d] = image
    return result


def extract(image, y0, x0, y1, x1, mode='nearest', cval=0):
    """Extract the window [y0:y1, x0:x1] from `image`, shifting and padding
    where the window leaves the image."""
    h, w = image.shape
    ch, cw = y1 - y0, x1 - x0
    y, x = clip(y0, 0, max(h - ch, 0)), clip(x0, 0, max(w - cw, 0))
    sub = image[y:y + ch, x:x + cw]
    try:
        r = ndi.shift(sub, (y - y0, x - x0), mode=mode, cval=cval, order=0)
        if cw > w or ch > h:
            pady0, padx0 = max(-y0, 0), max(-x0, 0)
            r = ndi.affine_transform(r, eye(2), offset=(
                pady0, padx0), cval=1, output_shape=(ch, cw))
        return r

    except RuntimeError:
        # workaround for platform differences between 32bit and 64bit
        # scipy.ndimage
        dtype = sub.dtype
        sub = array(sub, dtype='float64')
        sub = ndi.shift(sub, (y - y0, x - x0), mode=mode, cval=cval, order=0)
        sub = array(sub, dtype=dtype)
        return sub


def extract_masked(image, linedesc, pad=5, expand=0, background=None):
    """Extract a subimage from the image using the line descriptor.
    A line descriptor consists of bounds and a mask."""
    assert amin(image) >= 0 and amax(image) <= 1
    if background is None or background == "min":
        background = amin(image)
    elif background == "max":
        background = amax(image)
    bounds = linedesc["bounds"]
    y0, x0, y1, x1 = [int(x) for x in [bounds[0].start, bounds[1].start,
                                       bounds[0].stop, bounds[1].stop]]
    if pad > 0:
        mask = pad_image(linedesc["mask"], pad, cval=0)
    else:
        mask = linedesc["mask"]
    line = extract(image, y0 - pad, x0 - pad, y1 + pad, x1 + pad)
    if expand > 0:
        mask = ndi.maximum_filter(mask, (expand, expand))
    line = where(mask, line, background)
    return line


def reading_order(lines, highlight=None, debug=0):
    """Given the list of lines (a list of 2D slices), computes
    the partial reading order. The output is a binary 2D array
    such that order[i,j] is true if line i comes before line j
    in reading order."""
    order = zeros((len(lines), len(lines)), 'B')

    def x_overlaps(u, v):
        return u[1].start < v[1].stop and u[1].stop > v[1].start

    def above(u, v):
        return u[0].start < v[0].start

    def left_of(u, v):
        return u[1].stop < v[1].start

    def separates(w, u, v):
        if w[0].stop < min(u[0].start, v[0].start):
            return 0
        if w[0].start > max(u[0].stop, v[0].stop):
            return 0
        if w[1].start < u[1].stop and w[1].stop > v[1].start:
            return 1

    # NOTE: the highlight/debug path below assumes pylab (clf, title, imshow,
    # ginput, plot) and an `sl` slice-utility module in scope; it only runs
    # when highlight is not None.
    if highlight is not None:
        clf()
        title("highlight")
        imshow(binary)
        ginput(1, debug)
    for i, u in enumerate(lines):
        for j, v in enumerate(lines):
            if x_overlaps(u, v):
                if above(u, v):
                    order[i, j] = 1
            else:
                if [w for w in lines if separates(w, u, v)] == []:
                    if left_of(u, v):
                        order[i, j] = 1
            if j == highlight and order[i, j]:
                print(i, j)
                y0, x0 = sl.center(lines[i])
                y1, x1 = sl.center(lines[j])
                plot([x0, x1 + 200], [y0, y1])
    if highlight is not None:
        print()
        ginput(1, debug)
    return order


def topsort(order):
    """Given a binary array defining a partial order (o[i,j]==True means i<j),
    compute a topological sort."""
    # ... (the body of topsort, the Segmenter class definition and __init__,
    # and the start of Segmenter.line_probs are missing from this dump; the
    # tail of line_probs follows) ...

        output = (output > 0.5) * 1
        output = np.asarray(output, np.uint8)
        result = cv2.resize(output, (width, height))

        return zoom_like(result, image.shape)

    def line_seeds(self, image):
        """Threshold the line-probability map into connected seed components."""
        poutput = self.line_probs(image)
        binoutput = hysteresis_threshold(poutput, self.lo, self.hi)
        self.lines = binoutput
        seeds, _ = ndi.label(binoutput)
        return seeds

    def line_segmentation(self, pimage, max_size=(300, 300)):
        self.image = pimage
        self.binary = pimage > self.docthreshold
        if max_size is not None:
            self.binary = remove_big(self.binary, *max_size)
        self.boxmap = compute_boxmap(self.binary, dtype="B")
        self.seeds = self.line_seeds(pimage)
        self.llabels = propagate_labels(self.boxmap, self.seeds, conflict=0)
        self.spread = spread_labels(self.seeds, maxdist=self.basic_size)
        self.llabels = where(self.llabels > 0, self.llabels,
                             self.spread * self.binary)
        self.segmentation = self.llabels * self.binary
        return self.segmentation
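    # Pipeline summary (explanatory note, not original code):
    #   1. binarize the page and drop oversized components (remove_big);
    #   2. mark candidate character boxes (compute_boxmap);
    #   3. predict line centers with the MDLSTM net and threshold them into
    #      seeds (line_seeds);
    #   4. propagate seed labels to overlapping character boxes
    #      (propagate_labels) and spread them to nearby unlabeled ink
    #      (spread_labels), yielding one integer label per text line.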
    def extract_textlines(self, image, docimage=None, max_size=(300, 300), scale=5.0, pad=5, expand=0, background=None):
        if len(image.shape) != 2:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # cv2.imread returns BGR
        image = 1 - image / 255

        if docimage is None:
            docimage = image
        assert image.shape == docimage.shape
        self.lineimage = self.line_segmentation(image, max_size=max_size)
        lines = compute_lines(self.lineimage, scale)
        for line in lines:
            line["image"] = extract_masked(
                docimage, line, pad=pad, expand=expand, background=background)
        return lines


if __name__ == "__main__":
    seg = Segmenter()
    image = cv2.imread("./make_training_labels/W001.png")  # sample page shipped with the repo
    lines = seg.extract_textlines(image)
    for num, line in enumerate(lines):
        cv2.imwrite("./lines/%d.png" % num, line['image'] * 255)
    cv2.imwrite("out.png", seg.lines * 255)
--------------------------------------------------------------------------------
/train_test.py:
--------------------------------------------------------------------------------
# references:
# https://github.com/NVlabs/ocroseg
# https://github.com/philipperemy/tensorflow-multi-dimensional-lstm
# https://github.com/tmbdev/ocropy
# papers:
# Multi-Dimensional Recurrent Neural Networks
# Robust, Simple Page Segmentation Using Hybrid Convolutional MDLSTM Networks
# dataset:
# https://storage.googleapis.com/tmb-ocr/uw3-framed-lines-degraded-000.tgz

import os
import math
import tensorflow as tf
from PIL import Image
from functools import partial

from md_lstm import multi_dimensional_rnn_while_loop, horizontal_standard_lstm, snake_standard_lstm, horizontal_vertical_lstm_inorder, horizontal_vertical_lstm_together

import cv2
import numpy as np


os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Pages vary in size, so the batch size is fixed at 1 and height/width stay dynamic.
batch_size = 1
batch_height = None
batch_width = None
batch_channel = 1
save_steps = 1000
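
# The dataset list file (./uw3/label.txt) is read by read_labeled_image_list
# below: one "<image> <label>" pair per line, both paths relative to ./uw3/,
# e.g. (illustrative paths):
#
#   W001.png W001.lines.png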

def get_tf_dataset(dataset_text_file, batch_size=1, channels=1, shuffle_size=10):
    def _parse_function(filename, labelname):
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=channels)
        # image = tf.image.resize_images(image_decoded, size)
        # image = 255 - image
        image = tf.cast(image_decoded, tf.float32) * (1. / 255)

        label_string = tf.read_file(labelname)
        label_decoded = tf.image.decode_jpeg(label_string, channels=channels)
        # Downsample the label by 4 to match the two stride-2 poolings in network().
        label = tf.image.resize_images(label_decoded,
                                       [tf.shape(label_decoded)[0] // 4,
                                        tf.shape(label_decoded)[1] // 4])
        label = tf.cast(label, tf.float32) * (1. / 255)
        label = tf.cast(label > 0.5, tf.float32)

        return image, label

    def read_labeled_image_list(dataset_text_file):
        base_dir = "./uw3/"
        filenames = []
        labels = []
        with open(dataset_text_file, "r", encoding="utf-8") as f_l:
            filenames_labels = f_l.readlines()

        one_epoch_num = len(filenames_labels)

        for filename_label in filenames_labels:
            filenames.append(base_dir + filename_label.split(" ")[0])
            labels.append(base_dir + filename_label.split(" ")[1].strip("\n"))
        return filenames, labels, one_epoch_num

    filenames, labels, one_epoch_num = read_labeled_image_list(dataset_text_file)

    filenames = tf.constant(filenames, name='filename_list')
    labels = tf.constant(labels, name='label_list')

    # tensorflow 1.3: tf.contrib.data.Dataset.from_tensor_slices
    # tensorflow 1.4+: tf.data.Dataset.from_tensor_slices
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(batch_size=batch_size)
    dataset = dataset.repeat()

    return dataset, one_epoch_num

def network(is_training=False):
    network = {}
    network["inputs"] = tf.placeholder(tf.float32, [batch_size, batch_height, batch_width, batch_channel],
                                       name='inputs')
    network["conv1"] = tf.layers.conv2d(inputs=network["inputs"], filters=32, kernel_size=(3, 3), padding="same",
                                        activation=None, name="conv1")
    network["batch_norm1"] = tf.contrib.layers.batch_norm(
        network["conv1"],
        decay=0.9,
        center=True,
        scale=True,
        epsilon=0.001,
        updates_collections=None,
        is_training=is_training,
        zero_debias_moving_mean=True,
        scope="BN1")
    network["batch_norm1"] = tf.nn.relu(network["batch_norm1"])
    network["pool1"] = tf.layers.max_pooling2d(inputs=network["batch_norm1"], pool_size=[2, 2], strides=2)
    network["conv2"] = tf.layers.conv2d(inputs=network["pool1"], filters=64, kernel_size=(3, 3), padding="same",
                                        activation=None, name="conv2")
    network["batch_norm2"] = tf.contrib.layers.batch_norm(
        network["conv2"],
        decay=0.9,
        center=True,
        scale=True,
        epsilon=0.001,
        updates_collections=None,
        is_training=is_training,
        scope="BN2")
    network["batch_norm2"] = tf.nn.relu(network["batch_norm2"])
    network["pool2"] = tf.layers.max_pooling2d(inputs=network["batch_norm2"], pool_size=[2, 2], strides=2)
    network["conv3"] = tf.layers.conv2d(inputs=network["pool2"], filters=128, kernel_size=(3, 3), padding="same",
                                        activation=None, name="conv3")
    network["batch_norm3"] = tf.contrib.layers.batch_norm(
        network["conv3"],
        decay=0.9,
        center=True,
        scale=True,
        epsilon=0.001,
        updates_collections=None,
        is_training=is_training,
        scope="BN3")
    network["batch_norm3"] = tf.nn.relu(network["batch_norm3"])

    network["LSTM2D1"] = horizontal_vertical_lstm_inorder(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
    # Alternative 2D-context layers, kept for experimentation:
    # network["LSTM2D1"] = horizontal_vertical_lstm_together(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
    # network["LSTM2D1"] = horizontal_standard_lstm(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
    # network["LSTM2D1"] = snake_standard_lstm(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
    # network["LSTM2D1"], _ = multi_dimensional_rnn_while_loop(rnn_size=128, input_data=network["batch_norm3"], sh=[1, 1], dims=None, scope_n="LSTM2D1")

    network["conv4"] = tf.layers.conv2d(inputs=network["LSTM2D1"], filters=64, kernel_size=(3, 3), padding="same",
                                        activation=None, name="conv4")
    network["batch_norm4"] = tf.contrib.layers.batch_norm(
        network["conv4"],
        decay=0.9,
        center=True,
        scale=True,
        epsilon=0.001,
        updates_collections=None,
        is_training=is_training,
        scope="BN4")
    network["batch_norm4"] = tf.nn.relu(network["batch_norm4"])
    network["LSTM2D2"] = horizontal_vertical_lstm_inorder(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
    # network["LSTM2D2"] = horizontal_vertical_lstm_together(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
    # network["LSTM2D2"] = horizontal_standard_lstm(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
    # network["LSTM2D2"] = snake_standard_lstm(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
    # network["LSTM2D2"], _ = multi_dimensional_rnn_while_loop(rnn_size=128, input_data=network["batch_norm4"], sh=[1, 1], dims=None, scope_n="LSTM2D2")

    network["conv5"] = tf.layers.conv2d(inputs=network["LSTM2D2"], filters=1, kernel_size=(3, 3), padding="same",
                                        activation=None, name="conv5")
    network["outputs"] = tf.nn.sigmoid(network["conv5"])
    return network
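
# Shape trace through network() for a 1 x H x W x 1 input (explanatory note;
# H and W are assumed divisible by 4):
#   conv1 + pool1   -> (1, H/2, W/2, 32)
#   conv2 + pool2   -> (1, H/4, W/4, 64)
#   conv3           -> (1, H/4, W/4, 128)
#   LSTM2D1/LSTM2D2 -> (1, H/4, W/4, 128)
#   conv5 + sigmoid -> (1, H/4, W/4, 1) line-probability map,
# which is why _parse_function downsamples the labels by 4.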

def train():
    # network
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=0.001,
                                               global_step=global_step,
                                               decay_steps=10000,
                                               decay_rate=0.1,
                                               staircase=True)

    model = network(is_training=True)
    y_ = tf.placeholder(tf.float32, [batch_size, batch_height, batch_width, batch_channel], name='labels')
    loss = tf.reduce_mean(tf.losses.mean_squared_error(labels=y_, predictions=model["outputs"]))
    # "accuracy" here is really the foreground recall: true positives / positive labels.
    accuracy = tf.reduce_sum(tf.cast((tf.cast(model["outputs"] > 0.5, tf.int32) + tf.cast(y_, tf.int32)) > 1, tf.float32)) / tf.reduce_sum(y_)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
        with tf.control_dependencies(update_ops):
            grad_update = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)
    else:
        grad_update = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)

    # tensorboard
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    for update_op in update_ops:
        tf.summary.histogram(update_op.name, update_op)
    for var in tf.trainable_variables():
        tf.summary.histogram(var.name, var)
    merge_summary = tf.summary.merge_all()

    dataset, one_epoch_num = get_tf_dataset(dataset_text_file="./uw3/label.txt", batch_size=batch_size, channels=batch_channel)
    iterator = dataset.make_one_shot_iterator()
    img_batch, label_batch = iterator.get_next()

    init = tf.global_variables_initializer()

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as session:
        session.run(init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
        # saver.restore(save_path="./save/ocrseg.ckpt-2000", sess=session)

        # tensorboard
        summary_writer = tf.summary.FileWriter("./summary/", session.graph)

        epoch = 0
        while True:
            try:
                img_batch_i, label_batch_i = session.run([img_batch, label_batch])

                if img_batch_i.shape[0] != batch_size:
                    print("the last iter of one epoch")
                    continue
            except tf.errors.OutOfRangeError:
                print("one epoch over!")
                continue
            feed = {model["inputs"]: img_batch_i, y_: label_batch_i}
            learning_rate_train, loss_train, accuracy_train, step, summary, _ = session.run(
                [learning_rate, loss, accuracy, global_step, merge_summary, grad_update], feed_dict=feed)
            print("learning rate:%f epoch:%d iter:%d loss:%f accuracy:%f" % (learning_rate_train, epoch, step, loss_train, accuracy_train))

            # tensorboard
            summary_writer.add_summary(summary, step)

            if step > 0 and step % save_steps == 0:
                save_path = saver.save(session, "save/ocrseg.ckpt", global_step=step)
                print(save_path)
            if step > 0:
                epoch = step * batch_size // one_epoch_num


def test():
    model = network(is_training=False)
    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)
        saver = tf.train.Saver(tf.global_variables())
        saver.restore(save_path="./save/ocrseg.ckpt-1000", sess=session)

        image = cv2.imread("./make_training_labels/W001.png", 0)
        height, width = image.shape
        image = image.reshape((1, image.shape[0], image.shape[1], 1))
        image = 1 - image / 255
        feed = {model["inputs"]: image}
        output = session.run(model["outputs"], feed_dict=feed)
        output = output.reshape((output.shape[1], output.shape[2]))
        print("max:%f min:%f" % (np.max(output), np.min(output)))
        output = (output > 0.5) * 255
        output = np.asarray(output, np.uint8)
        output = cv2.resize(output, (width, height))
        cv2.imwrite("out.png", output)


if __name__ == "__main__":
    train()
    # test()
--------------------------------------------------------------------------------