├── 70-lstm-layout.pdf
├── 80-lstm-text.pdf
├── README.md
├── lines
│   └── 0.png
├── make_training_labels
│   ├── K47LYBN.framed.png
│   ├── K47LYBN.framed.txt
│   ├── K47LYBN.lines.png
│   ├── W001.png
│   ├── data_pre_process.py
│   └── out.png
├── md_lstm.py
├── segmentation.py
└── train_test.py
/70-lstm-layout.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/70-lstm-layout.pdf
--------------------------------------------------------------------------------
/80-lstm-text.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/80-lstm-text.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OCR Segmentation
2 | A deep learning model for page layout analysis / segmentation.
3 |
4 | ## dependencies
5 | tensorflow 1.8
6 | python 3
7 |
8 | ## dataset
9 | [uw3-framed-lines-degraded-000](https://storage.googleapis.com/tmb-ocr/uw3-framed-lines-degraded-000.tgz)
10 |
11 | ## make training labels
12 | python3 data_pre_process.py
13 |
14 | ## train
15 | python3 train_test.py
16 |
17 | ## test
18 | python3 segmentation.py
19 |
20 | ## references
21 | - [Multi-Dimensional Recurrent Neural Networks](https://arxiv.org/abs/0705.2011)
22 | - [Robust, Simple Page Segmentation Using Hybrid Convolutional MDLSTM Networks](https://github.com/wanghaisheng/awesome-ocr/files/2042377/Robust_.Simple.Page.Segmentation.Using.Hybrid.Convolutional.MDLSTM.Networks.pdf)
23 | - [https://github.com/NVlabs/ocroseg](https://github.com/NVlabs/ocroseg)
24 | - [https://github.com/philipperemy/tensorflow-multi-dimensional-lstm](https://github.com/philipperemy/tensorflow-multi-dimensional-lstm)
--------------------------------------------------------------------------------
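Note on data layout: train_test.py reads ./uw3/label.txt, where each line holds an image path and a label-image path separated by a single space, both relative to ./uw3/ (see read_labeled_image_list below). A minimal sketch for building that index, assuming the dataset is extracted to ./uw3/ and labels were generated with data_pre_process.py's *.framed.png -> *.lines.png naming; the exact tarball layout is an assumption:

```python
# sketch: build ./uw3/label.txt as read by read_labeled_image_list() in
# train_test.py; pairs each *.framed.png with its generated *.lines.png
import glob
import os

with open("./uw3/label.txt", "w", encoding="utf-8") as out:
    for framed in sorted(glob.glob("./uw3/**/*.framed.png", recursive=True)):
        lines = framed.replace(".framed.png", ".lines.png")
        if os.path.exists(lines):
            out.write("%s %s\n" % (os.path.relpath(framed, "./uw3/"),
                                   os.path.relpath(lines, "./uw3/")))
```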
/lines/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/lines/0.png
--------------------------------------------------------------------------------
/make_training_labels/K47LYBN.framed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/K47LYBN.framed.png
--------------------------------------------------------------------------------
/make_training_labels/K47LYBN.framed.txt:
--------------------------------------------------------------------------------
1 | 939,259,1436,305
2 | 1073,3236,1118,3265
3 | 262,448,2104,504
4 | 957,2918,1395,2964
5 | 864,2982,1487,3036
6 | 258,539,373,587
7 | 356,621,2104,671
8 | 263,699,2102,755
9 | 259,790,1204,839
10 | 357,867,2103,922
11 | 262,956,2094,1006
12 | 257,1040,2102,1090
13 | 258,1123,2102,1179
14 | 258,1200,2100,1255
15 | 257,1289,2102,1338
16 | 258,1366,2104,1428
17 | 257,1449,1520,1504
18 | 354,1532,2075,1587
19 | 256,1614,2101,1677
20 | 255,1704,2105,1760
21 | 256,1782,688,1837
22 |
--------------------------------------------------------------------------------
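Each row above is one text-line bounding box in x1,y1,x2,y2 pixel order; data_pre_process.py slices the page as image[y1:y2, x1:x2]. A one-off parse, for illustration:

```python
# parse the box file the same way data_pre_process.py does (x1, y1, x2, y2)
with open("K47LYBN.framed.txt", encoding="utf-8") as f:
    boxes = [tuple(int(v) for v in line.split(",")) for line in f if line.strip()]
print(boxes[0])  # (939, 259, 1436, 305)
```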
/make_training_labels/K47LYBN.lines.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/K47LYBN.lines.png
--------------------------------------------------------------------------------
/make_training_labels/W001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/W001.png
--------------------------------------------------------------------------------
/make_training_labels/data_pre_process.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import scipy.ndimage as ndi
5 |
6 |
7 | def _get_text_line_image(image,line_label):
8 | x1 = line_label[0]
9 | y1 = line_label[1]
10 | x2 = line_label[2]
11 | y2 = line_label[3]
12 | h = y2-y1
13 | image_line = image[y1:y2, x1:x2]
14 |
15 | nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(image_line)
16 |
17 | image_connect = np.zeros(image_line.shape, dtype=np.uint8)
18 |
19 |
20 | for i in range(1, stats.shape[0]):
21 | x1 = stats[i, 0]
22 | y1 = stats[i, 1]
23 | x2 = stats[i, 2] + x1
24 | y2 = stats[i, 3] + y1
25 | image_connect[y1:y2, x1:x2] = 255
26 |
27 |
28 | image_anigauss = ndi.gaussian_filter(image_connect, sigma=[h/2, h], mode='constant', cval=0)  # anisotropic blur: smear the filled boxes along the line direction
29 | image_line_center_index = np.argmax(image_anigauss, 0)  # per-column row index of the line's vertical center
30 |
31 | image_center_gauss_index = ndi.gaussian_filter1d(image_line_center_index, sigma=h/3)  # smooth the center curve
32 | for k in range(len(image_center_gauss_index)):
33 | image_anigauss[image_center_gauss_index[k], k] = 255  # mark the smoothed center line
34 |
35 | image_bin2 = (np.uint8(image_anigauss)>254)*255.0  # keep only the marked center pixels
36 |
37 | kernel = np.ones((3, 3), np.uint8)
38 | image_dilation = cv2.dilate(image_bin2, kernel, iterations=3)
39 | temp = (np.uint8(image_dilation)>254)*255
40 |
41 | return temp
42 |
43 |
44 |
45 |
46 | if __name__ == '__main__':
47 | image_name="K47LYBN.framed.png"
48 | image = cv2.imread(image_name, 0)
49 | _, image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
50 |
51 | image = 255-image
52 | label_image = np.zeros_like(image)
53 | text_file = os.path.splitext(image_name)[0]+'.txt'
54 | with open(text_file, "r", encoding="utf-8") as f:
55 | labels = f.readlines()
56 | labels=[labels[i].rstrip("\n").split(",") for i in range(len(labels))]
57 | label_arrays=np.asarray(labels,np.int32)
58 | for line in label_arrays:
59 | label_line_image = _get_text_line_image(image,line)
60 | x1 = line[0]
61 | y1 = line[1]
62 | x2 = line[2]
63 | y2 = line[3]
64 | label_image[y1:y2, x1:x2] = label_line_image
65 |
66 | cv2.imwrite(image_name.replace("framed","lines"), label_image)
67 |
68 |
--------------------------------------------------------------------------------
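For intuition: _get_text_line_image reduces each labeled box to a thin center-line mask. Connected components are filled, blurred with an anisotropic Gaussian (sigma h/2 vertically, h horizontally), and the per-column argmax of the blur traces the line's vertical center, which is then smoothed and dilated into the training label. A toy sketch of the blur-and-argmax step on synthetic data:

```python
# toy illustration of the anisotropic-blur + per-column-argmax trick in
# _get_text_line_image: the blur smears a filled text mask along the line,
# and argmax over rows then picks the line's vertical center in each column
import numpy as np
import scipy.ndimage as ndi

h = 8
mask = np.zeros((20, 60), np.uint8)
mask[6:6 + h, 5:55] = 255                    # one synthetic text line
blur = ndi.gaussian_filter(mask, sigma=[h / 2, h], mode='constant', cval=0)
centers = np.argmax(blur, axis=0)            # center row per column
print(centers[10:50])                        # stays near rows 9-10, the middle of the line
```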
/make_training_labels/out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watersink/ocrsegment/995cf725d277cd4b892f033aa6f9ec81965bd743/make_training_labels/out.png
--------------------------------------------------------------------------------
/md_lstm.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.rnn import RNNCell, LSTMStateTuple
3 | from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _linear
4 | from tensorflow.python.ops.rnn import dynamic_rnn
5 |
6 |
7 | def ln(tensor, scope=None, epsilon=1e-5):
8 | """ Layer normalizes a 2D tensor along its second axis """
9 | assert (len(tensor.get_shape()) == 2)
10 | m, v = tf.nn.moments(tensor, [1], keep_dims=True)
11 | if not isinstance(scope, str):
12 | scope = ''
13 | with tf.variable_scope(scope + 'layer_norm'):
14 | scale = tf.get_variable('scale',
15 | shape=[tensor.get_shape()[1]],
16 | initializer=tf.constant_initializer(1))
17 | shift = tf.get_variable('shift',
18 | shape=[tensor.get_shape()[1]],
19 | initializer=tf.constant_initializer(0))
20 | ln_initial = (tensor - m) / tf.sqrt(v + epsilon)
21 |
22 | return ln_initial * scale + shift
23 |
24 |
25 | class MultiDimensionalLSTMCell(RNNCell):
26 | """
27 | Adapted from TF's BasicLSTMCell to use Layer Normalization.
28 | Note that state_is_tuple is always True.
29 | """
30 |
31 | def __init__(self, num_units, forget_bias=0.0, activation=tf.nn.tanh):
32 | self._num_units = num_units
33 | self._forget_bias = forget_bias
34 | self._activation = activation
35 |
36 | @property
37 | def state_size(self):
38 | return LSTMStateTuple(self._num_units, self._num_units)
39 |
40 | @property
41 | def output_size(self):
42 | return self._num_units
43 |
44 | def __call__(self, inputs, state, scope=None):
45 | """Long short-term memory cell (LSTM).
46 | @param: inputs (batch,n)
47 | @param state: the states and hidden unit of the two cells
48 | """
49 | with tf.variable_scope(scope or type(self).__name__):
50 | c1, c2, h1, h2 = state
51 |
52 | # change bias argument to False since LN will add bias via shift
53 | concat = _linear([inputs, h1, h2], 5 * self._num_units, False)
54 |
55 | i, j, f1, f2, o = tf.split(value=concat, num_or_size_splits=5, axis=1)
56 |
57 | # add layer normalization to each gate
58 | i = ln(i, scope='i/')
59 | j = ln(j, scope='j/')
60 | f1 = ln(f1, scope='f1/')
61 | f2 = ln(f2, scope='f2/')
62 | o = ln(o, scope='o/')
63 |
64 | new_c = (c1 * tf.nn.sigmoid(f1 + self._forget_bias) +
65 | c2 * tf.nn.sigmoid(f2 + self._forget_bias) + tf.nn.sigmoid(i) *
66 | self._activation(j))
67 |
68 | # add layer_normalization in calculation of new hidden state
69 | new_h = self._activation(ln(new_c, scope='new_h/')) * tf.nn.sigmoid(o)
70 | new_state = LSTMStateTuple(new_c, new_h)
71 |
72 | return new_h, new_state
73 |
74 |
75 | def multi_dimensional_rnn_while_loop(rnn_size, input_data, sh, dims=None, scope_n="layer1"):
76 | """Implements naive multi dimension recurrent neural networks
77 |
78 | @param rnn_size: the hidden units
79 | @param input_data: the data to process of shape [batch,h,w,channels]
80 | @param sh: [height,width] of the windows
81 | @param dims: dimensions to reverse the input data,eg.
82 | dims=[False,True,True,False] => true means reverse dimension
83 | @param scope_n : the scope
84 |
85 | returns [batch,h/sh[0],w/sh[1],rnn_size] the output of the lstm
86 | """
87 |
88 | with tf.variable_scope("MultiDimensionalLSTMCell-" + scope_n):
89 |
90 | # Create multidimensional cell with selected size
91 | cell = MultiDimensionalLSTMCell(rnn_size)
92 |
93 | # Get the shape of the input (batch_size, x, y, channels)
94 | shape = input_data.get_shape().as_list()
95 | batch_size = shape[0]
96 | X_dim = shape[1]
97 | Y_dim = shape[2]
98 | channels = shape[3]
99 | # Window size
100 | X_win = sh[0]
101 | Y_win = sh[1]
102 | # Get the runtime batch size
103 | batch_size_runtime = tf.shape(input_data)[0]
104 |
105 | # If the input cannot be exactly sampled by the window, we pad it with zeros
106 | if X_dim % X_win != 0:
107 | # Get offset size
108 | offset = tf.zeros([batch_size_runtime, X_win - (X_dim % X_win), Y_dim, channels])
109 | # Concatenate X dimension
110 | input_data = tf.concat(axis=1, values=[input_data, offset])
111 | # Get new shape
112 | shape = input_data.get_shape().as_list()
113 | # Update shape value
114 | X_dim = shape[1]
115 |
116 | # The same but for Y axis
117 | if Y_dim % Y_win != 0:
118 | # Get offset size
119 | offset = tf.zeros([batch_size_runtime, X_dim, Y_win - (Y_dim % Y_win), channels])
120 | # Concatenate Y dimension
121 | input_data = tf.concat(axis=2, values=[input_data, offset])
122 | # Get new shape
123 | shape = input_data.get_shape().as_list()
124 | # Update shape value
125 | Y_dim = shape[2]
126 |
127 | # Get the steps to perform in X and Y axis
128 | h, w = int(X_dim / X_win), int(Y_dim / Y_win)
129 |
130 | # Get the number of features (total number of input values per step)
131 | features = Y_win * X_win * channels
132 |
133 | # Reshape input data to a tensor containing the step indexes and features inputs
134 | # The batch size is inferred from the tensor size
135 | x = tf.reshape(input_data, [batch_size_runtime, h, w, features])
136 |
137 | # Reverse the selected dimensions
138 | if dims is not None:
139 | assert dims[0] is False and dims[3] is False
140 | x = tf.reverse(x, [i for i, d in enumerate(dims) if d])  # TF 1.x tf.reverse takes axis indices, not a boolean mask
141 |
142 | # Reorder inputs to (h, w, batch_size, features)
143 | x = tf.transpose(x, [1, 2, 0, 3])
144 | # Reshape to a one dimensional tensor of (h*w*batch_size , features)
145 |
146 | x = tf.reshape(x, [-1, features])
147 | # Split tensor into h*w tensors of size (batch_size , features)
148 | x = tf.split(axis=0, num_or_size_splits=h * w, value=x)
149 | # Create an input tensor array (literally an array of tensors) to use inside the loop
150 | inputs_ta = tf.TensorArray(dtype=tf.float32, size=h * w, name='input_ta')
151 | # Unstack the input X in the tensor array
152 | inputs_ta = inputs_ta.unstack(x)
153 | # Create an input tensor array for the states
154 | states_ta = tf.TensorArray(dtype=tf.float32, size=h * w + 1, name='state_ta', clear_after_read=False)
155 | # And an other for the output
156 | outputs_ta = tf.TensorArray(dtype=tf.float32, size=h * w, name='output_ta')
157 | # initial cell hidden states
158 | # Write to the last position of the array, the LSTMStateTuple filled with zeros
159 | states_ta = states_ta.write(h * w, LSTMStateTuple(tf.zeros([batch_size_runtime, rnn_size], tf.float32),
160 | tf.zeros([batch_size_runtime, rnn_size], tf.float32)))
161 |
162 | # Function to get the sample skipping one row
163 | def get_up(t_, w_):
164 | return t_ - tf.constant(w_)
165 |
166 | # Function to get the previous sample
167 | def get_last(t_, w_):
168 | return t_ - tf.constant(1)
169 |
170 | # Controls the initial index
171 | time = tf.constant(0)
172 | zero = tf.constant(0)
173 |
174 | # Body of the while loop operation that applies the MD LSTM
175 | def body(time_, outputs_ta_, states_ta_):
176 |
177 | # If the current position is less or equal than the width, we are in the first row
178 | # and we need to read the zero state we added in row (h*w).
179 | # If not, get the sample located at a width distance.
180 | state_up = tf.cond(tf.less_equal(time_, tf.constant(w)),
181 | lambda: states_ta_.read(h * w),
182 | lambda: states_ta_.read(get_up(time_, w)))
183 |
184 | # If it is the first step we read the zero state if not we read the inmediate last
185 | state_last = tf.cond(tf.less(zero, tf.mod(time_, tf.constant(w))),
186 | lambda: states_ta_.read(get_last(time_, w)),
187 | lambda: states_ta_.read(h * w))
188 |
189 | # We build the input state in both dimensions
190 | current_state = state_up[0], state_last[0], state_up[1], state_last[1]
191 | # Now we calculate the output state and the cell output
192 | out, state = cell(inputs_ta.read(time_), current_state)
193 | # We write the output to the output tensor array
194 | outputs_ta_ = outputs_ta_.write(time_, out)
195 | # And save the output state to the state tensor array
196 | states_ta_ = states_ta_.write(time_, state)
197 |
198 | # Return outputs and incremented time step
199 | return time_ + 1, outputs_ta_, states_ta_
200 |
201 | # Loop output condition. The index, given by the time, should be less than the
202 | # total number of steps defined within the image
203 | def condition(time_, outputs_ta_, states_ta_):
204 | return tf.less(time_, tf.constant(h * w))
205 |
206 | # Run the looped operation
207 | result, outputs_ta, states_ta = tf.while_loop(condition, body, [time, outputs_ta, states_ta],
208 | parallel_iterations=1)
209 |
210 | # Extract the output tensors from the processed tensor array
211 | outputs = outputs_ta.stack()
212 | states = states_ta.stack()
213 |
214 | # Reshape outputs to match the shape of the input
215 | y = tf.reshape(outputs, [h, w, batch_size_runtime, rnn_size])
216 |
217 | # Reorder the dimensions to match the input
218 | y = tf.transpose(y, [2, 0, 1, 3])
219 | # Reverse if selected
220 | if dims is not None:
221 | y = tf.reverse(y, [i for i, d in enumerate(dims) if d])
222 |
223 | # Return the output and the inner states
224 | return y, states
225 |
226 |
227 | def horizontal_standard_lstm(input_data, rnn_size, scope_n="layer1"):
228 | with tf.variable_scope("MultiDimensionalLSTMCell-" + scope_n):
229 | # input is (b, h, w, c)
230 | b, _, _, c = input_data.get_shape().as_list()
231 | h,w=tf.shape(input_data)[1],tf.shape(input_data)[2]
232 | # fold h into the batch: each row becomes an independent horizontal sequence.
233 | new_input_data = tf.reshape(input_data, (b * h, w, c)) # horizontal.
234 | rnn_out, _ = tf.nn.bidirectional_dynamic_rnn(
235 | tf.contrib.rnn.LSTMCell(rnn_size//2),
236 | tf.contrib.rnn.LSTMCell(rnn_size//2),
237 | inputs=new_input_data,
238 | dtype=tf.float32,
239 | time_major=False)
240 | rnn_out=tf.concat(rnn_out, 2)
241 |
242 |
243 | rnn_out = tf.reshape(rnn_out, (b, h, w, rnn_size))
244 | return rnn_out
245 |
246 |
247 | def snake_standard_lstm(input_data, rnn_size, scope_n="layer1"):
248 | with tf.variable_scope("MultiDimensionalLSTMCell-" + scope_n):
249 | # input is (b, h, w, c)
250 | b, _, _, c = input_data.get_shape().as_list()
251 | h,w=tf.shape(input_data)[1],tf.shape(input_data)[2]
252 | # flatten (h, w) into one long sequence scanned in raster ("snake") order.
253 | new_input_data = tf.reshape(input_data, (b, w * h, c)) # snake.
254 | rnn_out, _ = tf.nn.bidirectional_dynamic_rnn(
255 | tf.contrib.rnn.LSTMCell(rnn_size//2),
256 | tf.contrib.rnn.LSTMCell(rnn_size//2),
257 | inputs=new_input_data,
258 | dtype=tf.float32,
259 | time_major=False)
260 | rnn_out=tf.concat(rnn_out, 2)
261 |
262 |
263 | rnn_out = tf.reshape(rnn_out, (b, h, w, rnn_size))
264 | return rnn_out
265 |
266 | def horizontal_vertical_lstm_inorder(input_data, rnn_size, scope_n="layer1"):
267 | with tf.variable_scope("MultiDimensionalLSTMCell-horizontal-" + scope_n):
268 | # input is (b, h, w, c)
269 | #horizontal
270 | b_h, _, _, c_h = input_data.get_shape().as_list()
271 | h_h,w_h=tf.shape(input_data)[1],tf.shape(input_data)[2]
272 | # fold h into the batch: each row becomes an independent horizontal sequence.
273 | new_input_data_h = tf.reshape(input_data, (b_h * h_h, w_h, c_h)) # horizontal.
274 | rnn_out_h, _ = tf.nn.bidirectional_dynamic_rnn(
275 | tf.contrib.rnn.LSTMCell(rnn_size//2),
276 | tf.contrib.rnn.LSTMCell(rnn_size//2),
277 | inputs=new_input_data_h,
278 | dtype=tf.float32,
279 | time_major=False)
280 | rnn_out_h=tf.concat(rnn_out_h, 2)
281 |
282 | rnn_out_h = tf.reshape(rnn_out_h, (b_h, h_h, w_h, rnn_size))
283 | #vertical
284 | with tf.variable_scope("MultiDimensionalLSTMCell-vertical-" + scope_n):
285 | new_input_data_v=tf.transpose(rnn_out_h,(0,2,1,3))
286 | b_v, _, _, c_v = new_input_data_v.get_shape().as_list()
287 | h_v,w_v=tf.shape(new_input_data_v)[1],tf.shape(new_input_data_v)[2]
288 | new_input_data_v = tf.reshape(new_input_data_v, (b_v * h_v, w_v, c_v))
289 | rnn_out_v, _ = tf.nn.bidirectional_dynamic_rnn(
290 | tf.contrib.rnn.LSTMCell(rnn_size//2),
291 | tf.contrib.rnn.LSTMCell(rnn_size//2),
292 | inputs=new_input_data_v,
293 | dtype=tf.float32,
294 | time_major=False)
295 | rnn_out_v=tf.concat(rnn_out_v, 2)
296 |
297 |
298 | rnn_out_v = tf.reshape(rnn_out_v, (b_v, h_v, w_v, rnn_size))
299 | rnn_out_v=tf.transpose(rnn_out_v,(0,2,1,3))
300 | return rnn_out_v
301 |
302 |
303 | def horizontal_vertical_lstm_together(input_data, rnn_size, scope_n="layer1"):
304 | with tf.variable_scope("MultiDimensionalLSTMCell-horizontal-" + scope_n):
305 | # input is (b, h, w, c)
306 | #horizontal
307 | b_h, _, _, c_h = input_data.get_shape().as_list()
308 | h_h,w_h=tf.shape(input_data)[1],tf.shape(input_data)[2]
310 | # fold h into the batch: each row becomes an independent horizontal sequence.
310 | new_input_data_h = tf.reshape(input_data, (b_h * h_h, w_h, c_h)) # horizontal.
311 | # Forward
312 | lstm_fw_cell = tf.contrib.rnn.LSTMCell(rnn_size//4)
313 | #lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=0.5)
314 | # Backward
315 | lstm_bw_cell = tf.contrib.rnn.LSTMCell(rnn_size//4)
316 | #lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=0.5)
317 |
318 |
319 | rnn_out_h, _ = tf.nn.bidirectional_dynamic_rnn(
320 | lstm_fw_cell,
321 | lstm_bw_cell,
322 | inputs=new_input_data_h,
323 | dtype=tf.float32,
324 | time_major=False)
325 | rnn_out_h=tf.concat(rnn_out_h, 2)
326 | rnn_out_h = tf.reshape(rnn_out_h, (b_h, h_h, w_h, rnn_size//2))
327 | #vertical
328 | with tf.variable_scope("MultiDimensionalLSTMCell-vertical-" + scope_n):
329 | new_input_data_v=tf.transpose(input_data,(0,2,1,3))
330 | b_v, _, _, c_v = new_input_data_v.get_shape().as_list()
331 | h_v,w_v=tf.shape(new_input_data_v)[1],tf.shape(new_input_data_v)[2]
332 | new_input_data_v = tf.reshape(new_input_data_v, (b_v * h_v, w_v, c_v))
333 | # Forward
334 | lstm_fw_cell = tf.contrib.rnn.LSTMCell(rnn_size//4)
335 | #lstm_fw_cell = tf.contrib.rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=0.5)
336 | # Backward
337 | lstm_bw_cell = tf.contrib.rnn.LSTMCell(rnn_size//4)
338 | #lstm_bw_cell = tf.contrib.rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=0.5)
339 |
340 |
341 | rnn_out_v, _ = tf.nn.bidirectional_dynamic_rnn(
342 | lstm_fw_cell,
343 | lstm_bw_cell,
344 | inputs=new_input_data_v,
345 | dtype=tf.float32,
346 | time_major=False)
347 | rnn_out_v=tf.concat(rnn_out_v, 2)
348 |
349 | rnn_out_v = tf.reshape(rnn_out_v, (b_v, h_v, w_v, rnn_size//2))
350 | rnn_out_v=tf.transpose(rnn_out_v,(0,2,1,3))
351 | rnn_out=tf.concat([rnn_out_h,rnn_out_v],axis=3)
352 | #rnn_out=tf.add(rnn_out_h,rnn_out_v)
353 | return rnn_out
354 |
355 |
356 |
--------------------------------------------------------------------------------
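A minimal smoke test for the while-loop MD-LSTM, assuming a TF 1.x (tf.contrib-era) environment; the shapes follow the function's docstring:

```python
# smoke test for multi_dimensional_rnn_while_loop (TF 1.x assumed): with a
# 1x1 window, every pixel is one step and the output keeps the h x w grid
import numpy as np
import tensorflow as tf
from md_lstm import multi_dimensional_rnn_while_loop

x = tf.placeholder(tf.float32, [1, 8, 8, 1])          # (batch, h, w, channels)
y, _ = multi_dimensional_rnn_while_loop(rnn_size=16, input_data=x,
                                        sh=[1, 1], scope_n="demo")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, {x: np.random.rand(1, 8, 8, 1)})
    print(out.shape)                                  # (1, 8, 8, 16)
```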
/segmentation.py:
--------------------------------------------------------------------------------
1 | from numpy import *
2 | import tensorflow as tf
3 | from md_lstm import horizontal_vertical_lstm_inorder
4 |
5 | import os
6 | import cv2
7 | import numpy as np
8 | import scipy.ndimage as ndi
9 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
10 |
11 |
12 |
13 | class record:
14 | def __init__(self, **kw):
15 | self.__dict__.update(kw)
16 |
17 |
18 | def sl_width(s):
19 | return s.stop - s.start
20 |
21 |
22 | def sl_area(s):
23 | return sl_width(s[0]) * sl_width(s[1])
24 |
25 |
26 | def sl_dim0(s):
27 | return sl_width(s[0])
28 |
29 |
30 | def sl_dim1(s):
31 | return sl_width(s[1])
32 |
33 |
34 | def sl_tuple(s):
35 | return s[0].start, s[0].stop, s[1].start, s[1].stop
36 |
37 |
38 | def hysteresis_threshold(image, lo, hi):
39 | binlo = (image > lo)
40 | lablo, n = ndi.label(binlo)
41 | n += 1
42 | good = set((lablo * (image > hi)).flat)
43 | markers = zeros(n, 'i')
44 | for index in good:
45 | if index == 0:
46 | continue
47 | markers[index] = 1
48 | return markers[lablo]
49 |
50 |
51 | def zoom_like(image, shape):
52 | h, w = shape
53 | ih, iw = image.shape
54 | scale = diag([ih * 1.0/h, iw * 1.0/w])
55 | return ndi.affine_transform(image, scale, output_shape=(h, w), order=1)
56 |
57 |
58 | def remove_big(image, max_h=100, max_w=100):
59 | """Remove large components."""
60 | assert image.ndim == 2
61 | bin = (image > 0.5 * amax(image))
62 | labels, n = ndi.label(bin)
63 | objects = ndi.find_objects(labels)
64 | indexes = ones(n+1, 'i')
65 | for i, (yr, xr) in enumerate(objects):
66 | if yr.stop-yr.start < max_h and xr.stop-xr.start < max_w:
67 | continue
68 | indexes[i+1] = 0
69 | indexes[0] = 0
70 | return indexes[labels]
71 |
72 |
73 | def compute_boxmap(binary, lo=10, hi=5000, dtype='i'):
74 | objects = binary_objects(binary)
75 | bysize = sorted(objects, key=sl_area)
76 | boxmap = zeros(binary.shape, dtype)
77 | for o in bysize:
78 | if sl_area(o)**.5 < lo:
79 | continue
80 | if sl_area(o)**.5 > hi:
81 | continue
82 | boxmap[o] = 1
83 | return boxmap
84 |
85 |
86 | def binary_objects(binary):
87 | labels, n = ndi.label(binary)
88 | objects = ndi.find_objects(labels)
89 | return objects
90 |
91 |
92 | def propagate_labels(image, labels, conflict=0):
93 | """Given an image and a set of labels, apply the labels
94 | to all the regions in the image that overlap a label.
95 | Assign the value `conflict` to any labels that have a conflict."""
96 | rlabels, _ = ndi.label(image)
97 | cors = correspondences(rlabels, labels)
98 | outputs = zeros(amax(rlabels) + 1, 'i')
99 | oops = -(1 << 30)
100 | for o, i in cors.T:
101 | if outputs[o] != 0:
102 | outputs[o] = oops
103 | else:
104 | outputs[o] = i
105 | outputs[outputs == oops] = conflict
106 | outputs[0] = 0
107 | return outputs[rlabels]
108 |
109 |
110 | def correspondences(labels1, labels2):
111 | """Given two labeled images, compute an array giving the correspondences
112 | between labels in the two images."""
113 | q = 100000
114 | assert amin(labels1) >= 0 and amin(labels2) >= 0
115 | assert amax(labels2) < q
116 | combo = labels1 * q + labels2
117 | result = unique(combo)
118 | result = array([result // q, result % q])
119 | return result
120 |
121 |
122 | def spread_labels(labels, maxdist=9999999):
123 | """Spread the given labels to the background"""
124 | distances, features = ndi.distance_transform_edt(
125 | labels == 0, return_distances=1, return_indices=1)
126 | indexes = features[0] * labels.shape[1] + features[1]
127 | spread = labels.ravel()[indexes.ravel()].reshape(*labels.shape)
128 | spread *= (distances < maxdist)
129 | return spread
130 |
131 |
132 | def estimate_scale(binary):
133 | objects = binary_objects(binary)
134 | bysize = sorted(objects, key=sl_area)
135 | scalemap = zeros(binary.shape)
136 | for o in bysize:
137 | if amax(scalemap[o]) > 0:
138 | continue
139 | scalemap[o] = sl_area(o)**0.5
140 | scale = median(scalemap[(scalemap > 3) & (scalemap < 100)])
141 | return scale
142 |
143 |
157 | def compute_lines(segmentation, scale):
158 | """Given a line segmentation map, computes a list
159 | of tuples consisting of 2D slices and masked images."""
160 | lobjects = ndi.find_objects(segmentation)
161 | lines = []
162 | for i, o in enumerate(lobjects):
163 | if o is None:
164 | continue
165 | if sl_dim1(o) < 2 * scale or sl_dim0(o) < scale:
166 | continue
167 | mask = (segmentation[o] == i + 1)
168 | if amax(mask) == 0:
169 | continue
170 | result = dict(label=i+1,
171 | bounds=o,
172 | mask=mask)
173 | lines.append(result)
174 | return lines
175 |
176 |
177 | def pad_image(image, d, cval=None):
178 | result = ones(array(image.shape) + 2 * d)
179 | result[:, :] = amax(image) if cval is None else cval
180 | result[d:-d, d:-d] = image
181 | return result
182 |
183 |
184 | def extract(image, y0, x0, y1, x1, mode='nearest', cval=0):
185 | h, w = image.shape
186 | ch, cw = y1 - y0, x1 - x0
187 | y, x = clip(y0, 0, max(h - ch, 0)), clip(x0, 0, max(w - cw, 0))
188 | sub = image[y:y + ch, x:x + cw]
189 | try:
190 | r = ndi.shift(sub, (y - y0, x - x0), mode=mode, cval=cval, order=0)
191 | if cw > w or ch > h:
192 | pady0, padx0 = max(-y0, 0), max(-x0, 0)
193 | r = ndi.affine_transform(r, eye(2), offset=(
194 | pady0, padx0), cval=1, output_shape=(ch, cw))
195 | return r
196 |
197 | except RuntimeError:
198 | # workaround for platform differences between 32bit and 64bit
199 | # scipy.ndimage
200 | dtype = sub.dtype
201 | sub = array(sub, dtype='float64')
202 | sub = ndi.shift(sub, (y - y0, x - x0), mode=mode, cval=cval, order=0)
203 | sub = array(sub, dtype=dtype)
204 | return sub
205 |
206 |
207 | def extract_masked(image, linedesc, pad=5, expand=0, background=None):
208 | """Extract a subimage from the image using the line descriptor.
209 | A line descriptor consists of bounds and a mask."""
210 | assert amin(image) >= 0 and amax(image) <= 1
211 | if background is None or background == "min":
212 | background = amin(image)
213 | elif background == "max":
214 | background = amax(image)
215 | bounds = linedesc["bounds"]
216 | y0, x0, y1, x1 = [int(x) for x in [bounds[0].start, bounds[1].start,
217 | bounds[0].stop, bounds[1].stop]]
218 | if pad > 0:
219 | mask = pad_image(linedesc["mask"], pad, cval=0)
220 | else:
221 | mask = linedesc["mask"]
222 | line = extract(image, y0 - pad, x0 - pad, y1 + pad, x1 + pad)
223 | if expand > 0:
224 | mask = ndi.maximum_filter(mask, (expand, expand))
225 | line = where(mask, line, background)
226 | return line
227 |
228 |
229 | def reading_order(lines, highlight=None, debug=0):
230 | """Given the list of lines (a list of 2D slices), computes
231 | the partial reading order. The output is a binary 2D array
232 | such that order[i,j] is true if line i comes before line j
233 | in reading order."""
234 | order = zeros((len(lines), len(lines)), 'B')
235 |
236 | def x_overlaps(u, v):
237 | return u[1].start < v[1].stop and u[1].stop > v[1].start
238 |
239 | def above(u, v):
240 | return u[0].start < v[0].start
241 |
242 | def left_of(u, v):
243 | return u[1].stop < v[1].start
244 |
245 | def separates(w, u, v):
246 | if w[0].stop < min(u[0].start, v[0].start):
247 | return 0
248 | if w[0].start > max(u[0].stop, v[0].stop):
249 | return 0
250 | if w[1].start < u[1].stop and w[1].stop > v[1].start:
251 | return 1
252 |
253 | if highlight is not None:  # debug-only path: expects pylab names (clf, title, imshow, ginput, plot) and ocropy's sl module, neither of which this file imports
254 | clf()
255 | title("highlight")
256 | imshow(binary)
257 | ginput(1, debug)
258 | for i, u in enumerate(lines):
259 | for j, v in enumerate(lines):
260 | if x_overlaps(u, v):
261 | if above(u, v):
262 | order[i, j] = 1
263 | else:
264 | if [w for w in lines if separates(w, u, v)] == []:
265 | if left_of(u, v):
266 | order[i, j] = 1
267 | if j == highlight and order[i, j]:
268 | print(i, j, end=" ")
269 | y0, x0 = sl.center(lines[i])
270 | y1, x1 = sl.center(lines[j])
271 | plot([x0, x1 + 200], [y0, y1])
272 | if highlight is not None:
273 | print()
274 | ginput(1, debug)
275 | return order
276 |
277 |
278 | def topsort(order):
279 | """Given a binary array defining a partial order (o[i,j]==True means i0.5)*1
398 | output=np.asarray(output,np.uint8)
399 | result=cv2.resize(output,(width,height))
400 |
401 | return zoom_like(result, image.shape)
402 |
403 | def line_seeds(self, image):
404 | poutput = self.line_probs(image)
405 | binoutput = hysteresis_threshold(poutput, self.lo, self.hi)
406 | self.lines = binoutput
407 | seeds, _ = ndi.label(binoutput)
408 | return seeds
409 |
410 | def line_segmentation(self, pimage, max_size=(300, 300)):
411 | self.image = pimage
412 | self.binary = pimage > self.docthreshold
413 | if max_size is not None:
414 | self.binary = remove_big(self.binary, *max_size)
415 | self.boxmap = compute_boxmap(self.binary, dtype="B")
416 | self.seeds = self.line_seeds(pimage)
417 | self.llabels = propagate_labels(self.boxmap, self.seeds, conflict=0)
418 | self.spread = spread_labels(self.seeds, maxdist=self.basic_size)
419 | self.llabels = where(self.llabels > 0, self.llabels,
420 | self.spread * self.binary)
421 | self.segmentation = self.llabels * self.binary
422 | return self.segmentation
423 |
424 | def extract_textlines(self, image, docimage=None, max_size=(300, 300), scale=5.0, pad=5, expand=0, background=None):
425 | if len(image.shape)!=2:
426 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
427 | image=1-image/255
428 |
429 |
430 | if docimage is None:
431 | docimage = image
432 | assert image.shape == docimage.shape
433 | self.lineimage = self.line_segmentation(image, max_size=max_size)
434 | lines = compute_lines(self.lineimage, scale)
435 | for line in lines:
436 | line["image"] = extract_masked(
437 | docimage, line, pad=pad, expand=expand, background=background)
438 | return lines
439 |
440 |
441 | if __name__=="__main__":
442 | seg = Segmenter()
443 | image = cv2.imread("./make_training_labels/W001.png")
444 | lines = seg.extract_textlines(image)
445 | for num,line in enumerate(lines):
446 | cv2.imwrite("./lines/%d.png"%num,line['image']*255)
447 | cv2.imwrite("out.png", seg.lines*255)
448 |
--------------------------------------------------------------------------------
/train_test.py:
--------------------------------------------------------------------------------
1 | #references:
2 | #https://github.com/NVlabs/ocroseg
3 | #https://github.com/philipperemy/tensorflow-multi-dimensional-lstm
4 | #https://github.com/tmbdev/ocropy
5 | #paper:
6 | #Multi-Dimensional Recurrent Neural Networks
7 | #Robust, Simple Page Segmentation Using Hybrid Convolutional MDLSTM Networks
8 | #dataset:
9 | #https://storage.googleapis.com/tmb-ocr/uw3-framed-lines-degraded-000.tgz
10 |
11 | import os
12 | import math
13 | import tensorflow as tf
14 | from PIL import Image
15 | from functools import partial
16 |
17 | from md_lstm import multi_dimensional_rnn_while_loop,horizontal_standard_lstm,snake_standard_lstm,horizontal_vertical_lstm_inorder,horizontal_vertical_lstm_together
18 |
19 | import cv2
20 | import numpy as np
21 |
22 |
23 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
24 | batch_size=1
25 | batch_height=None
26 | batch_width=None
27 | batch_channel=1
28 | save_steps=1000
29 |
30 |
31 | def get_tf_dataset(dataset_text_file,batch_size=1, channels=1,shuffle_size=10):
32 | def _parse_function(filename, labelname):
33 |
34 | image_string = tf.read_file(filename)
35 | image_decoded = tf.image.decode_jpeg(image_string, channels=channels)
36 | #image = tf.image.resize_images(image_decoded, size)
37 | #image=255-image
38 | image = tf.cast(image_decoded, tf.float32) * (1. / 255)
39 |
40 | label_string = tf.read_file(labelname)
41 | label_decoded = tf.image.decode_jpeg(label_string, channels=channels)
42 | label = tf.image.resize_images(label_decoded, [tf.round(tf.div(tf.shape(label_decoded)[0],4)),tf.round(tf.div(tf.shape(label_decoded)[1],4))])
43 | label = tf.cast(label, tf.float32) * (1. / 255)
44 | label =tf.cast(label>0.5, tf.float32)
45 |
46 |
47 | return image, label
48 |
49 | def read_labeled_image_list(dataset_text_file):
50 | base_dir="./uw3/"
51 | filenames=[]
52 | labels=[]
53 | with open(dataset_text_file,"r",encoding="utf-8") as f_l:
54 | filenames_labels=f_l.readlines()
55 |
56 | one_epoch_num=len(filenames_labels)
57 |
58 | for filename_label in filenames_labels:
59 | filenames.append(base_dir+filename_label.split(" ")[0])
60 | labels.append(base_dir+filename_label.split(" ")[1].strip("\n"))
61 | return filenames,labels,one_epoch_num
62 |
63 | filenames, labels,one_epoch_num = read_labeled_image_list(dataset_text_file)
64 |
65 | filenames = tf.constant(filenames, name='filename_list')
66 | labels = tf.constant(labels, name='label_list')
67 |
68 | #tensorflow1.3:tf.contrib.data.Dataset.from_tensor_slices
69 | #tensorflow1.4+:tf.data.Dataset.from_tensor_slices
70 | dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
71 | dataset = dataset.shuffle(shuffle_size)
72 | dataset = dataset.map(_parse_function)
73 | dataset = dataset.batch(batch_size=batch_size)
74 | dataset = dataset.repeat()
75 |
76 | return dataset,one_epoch_num
77 |
78 |
79 | def network(is_training=False):
80 | network = {}
81 | network["inputs"] = tf.placeholder(tf.float32, [batch_size, batch_height,batch_width, batch_channel],
82 | name='inputs')
83 | network["conv1"] = tf.layers.conv2d(inputs=network["inputs"], filters=32, kernel_size=(3, 3), padding="same",
84 | activation=None, name="conv1")
85 | #with tf.variable_scope("BN"):
86 | network["batch_norm1"] = tf.contrib.layers.batch_norm(
87 | network["conv1"],
88 | decay=0.9,
89 | center=True,
90 | scale=True,
91 | epsilon=0.001,
92 | updates_collections=None,
93 | is_training=is_training,
94 | zero_debias_moving_mean=True,
95 | scope="BN1")
96 | network["batch_norm1"] = tf.nn.relu(network["batch_norm1"])
97 | network["pool1"] = tf.layers.max_pooling2d(inputs=network["batch_norm1"], pool_size=[2, 2], strides=2)
98 | network["conv2"] = tf.layers.conv2d(inputs=network["pool1"], filters=64, kernel_size=(3, 3), padding="same",
99 | activation=None, name="conv2")
100 | #with tf.variable_scope("BN"):
101 | network["batch_norm2"] = tf.contrib.layers.batch_norm(
102 | network["conv2"],
103 | decay=0.9,
104 | center=True,
105 | scale=True,
106 | epsilon=0.001,
107 | updates_collections=None,
108 | is_training=is_training,
109 | scope="BN2")
110 | network["batch_norm2"] = tf.nn.relu(network["batch_norm2"])
111 | network["pool2"] = tf.layers.max_pooling2d(inputs=network["batch_norm2"], pool_size=[2, 2], strides=2)
112 | network["conv3"] = tf.layers.conv2d(inputs=network["pool2"], filters=128, kernel_size=(3, 3), padding="same",
113 | activation=None, name="conv3")
114 | #with tf.variable_scope("BN"):
115 | network["batch_norm3"] = tf.contrib.layers.batch_norm(
116 | network["conv3"],
117 | decay=0.9,
118 | center=True,
119 | scale=True,
120 | epsilon=0.001,
121 | updates_collections=None,
122 | is_training=is_training,
123 | scope="BN3")
124 | network["batch_norm3"] = tf.nn.relu(network["batch_norm3"])
125 |
126 | network["LSTM2D1"] = horizontal_vertical_lstm_inorder(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
127 | #network["LSTM2D1"] = horizontal_vertical_lstm_together(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
128 | #network["LSTM2D1"] = horizontal_standard_lstm(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
129 | #network["LSTM2D1"] = snake_standard_lstm(rnn_size=128, input_data=network["batch_norm3"], scope_n="LSTM2D1")
130 | #network["LSTM2D1"], _ = multi_dimensional_rnn_while_loop(rnn_size=128, input_data=network["batch_norm3"],sh=[1, 1], dims=None, scope_n="LSTM2D1")
131 |
132 | network["conv4"] = tf.layers.conv2d(inputs=network["LSTM2D1"], filters=64, kernel_size=(3, 3), padding="same",
133 | activation=None, name="conv4")
134 | #with tf.variable_scope("BN"):
135 | network["batch_norm4"] = tf.contrib.layers.batch_norm(
136 | network["conv4"],
137 | decay=0.9,
138 | center=True,
139 | scale=True,
140 | epsilon=0.001,
141 | updates_collections=None,
142 | is_training=is_training,
143 | scope="BN4")
144 | network["batch_norm4"] = tf.nn.relu(network["batch_norm4"])
145 | network["LSTM2D2"] = horizontal_vertical_lstm_inorder(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
146 | #network["LSTM2D2"] = horizontal_vertical_lstm_together(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
147 | #network["LSTM2D2"] = horizontal_standard_lstm(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
148 | #network["LSTM2D2"] = snake_standard_lstm(rnn_size=128, input_data=network["batch_norm4"], scope_n="LSTM2D2")
149 | #network["LSTM2D2"], _ = multi_dimensional_rnn_while_loop(rnn_size=128, input_data=network["batch_norm4"],sh=[1, 1], dims=None, scope_n="LSTM2D2")
150 |
151 | network["conv5"] = tf.layers.conv2d(inputs=network["LSTM2D2"], filters=1, kernel_size=(3, 3), padding="same",
152 | activation=None, name="conv5")
153 | network["outputs"] = tf.nn.sigmoid(network["conv5"])
154 | return network
155 |
156 |
157 |
158 |
159 |
160 |
161 | def train():
162 | #network
163 | global_step = tf.Variable(0, trainable=False)
164 | learning_rate = tf.train.exponential_decay(learning_rate=0.001,
165 | global_step=global_step,
166 | decay_steps=10000,
167 | decay_rate=0.1,
168 | staircase=True)
169 |
170 | model=network(is_training=True)
171 | y_ = tf.placeholder(tf.float32, [batch_size, batch_height, batch_width, batch_channel], name='labels')
172 | loss = tf.reduce_mean(tf.losses.mean_squared_error(labels=y_,predictions=model["outputs"]))
173 | accuracy=tf.reduce_sum(tf.cast((tf.cast(model["outputs"]>0.5,tf.int32)+tf.cast(y_,tf.int32))>1,tf.float32))/tf.reduce_sum(y_)  # recall on positive (line) pixels: true positives / total positive labels
174 |
175 |
176 | update_ops= tf.get_collection(tf.GraphKeys.UPDATE_OPS)
177 | if update_ops:
178 | with tf.control_dependencies(update_ops):
179 | grad_update = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss,global_step=global_step)
180 | else:
181 | grad_update = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss,global_step=global_step)
182 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
183 |
184 |
185 | #tensorboard
186 | tf.summary.scalar("loss", loss)
187 | tf.summary.scalar("accuracy", accuracy)
188 | for update_op in update_ops:
189 | tf.summary.histogram(update_op.name, update_op)
190 | for var in tf.trainable_variables():
191 | tf.summary.histogram(var.name, var)
192 | merge_summary = tf.summary.merge_all()
193 |
194 | dataset,one_epoch_num = get_tf_dataset(dataset_text_file="./uw3/label.txt",batch_size=batch_size,channels=batch_channel)
195 | iterator = dataset.make_one_shot_iterator()
196 | img_batch, label_batch = iterator.get_next()
197 |
198 | init = tf.global_variables_initializer()
199 |
200 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as session:
201 | session.run(init)
202 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
203 | #saver.restore(save_path="./save/ocrseg.ckpt-2000", sess=session)
204 |
205 | #tensorboard
206 | summary_writer = tf.summary.FileWriter("./summary/", session.graph)
207 |
208 | epoch=0
209 | while True:
210 | try:
211 | img_batch_i, label_batch_i = session.run([img_batch,label_batch])
212 |
213 | if img_batch_i.shape[0]!=batch_size:
214 | print("the last iter of one epoch")
215 | continue
216 | except tf.errors.OutOfRangeError:
217 | print("one epoch over!")
218 | continue
219 | feed = {model["inputs"]: img_batch_i,y_: label_batch_i}
220 | learning_rate_train,loss_train,accuracy_train,step,summary,_=session.run([learning_rate,loss,accuracy,global_step,merge_summary,grad_update], feed_dict=feed)
221 | print("learning rate:%f epoch:%d iter:%d loss:%f accuracy:%f"%(learning_rate_train,epoch,step,loss_train,accuracy_train))
222 |
223 | #tensorboard
224 | summary_writer.add_summary(summary, step)
225 |
226 | if step > 0 and step % save_steps == 0:
227 | save_path = saver.save(session, "save/ocrseg.ckpt", global_step=step)
228 | print(save_path)
229 | if step > 0:
230 | epoch=step*batch_size//one_epoch_num
231 |
232 | def test():
233 | model=network(is_training=False)
234 | init = tf.global_variables_initializer()
235 | with tf.Session() as session:
236 | session.run(init)
237 | saver = tf.train.Saver(tf.global_variables())
238 | saver.restore(save_path="./save/ocrseg.ckpt-1000",sess=session)
239 |
240 | image=cv2.imread("./make_training_labels/W001.png",0)
241 | height,width=image.shape
242 | image=image.reshape((1,image.shape[0],image.shape[1],1))
243 | image=1-image/255
244 | feed = {model["inputs"]: image}
245 | output = session.run(model["outputs"], feed_dict=feed)
246 | output = output.reshape((output.shape[1],output.shape[2]))
247 | print("max:%f min:%f"%(np.max(np.max(output)),np.min(np.min(output))))
248 | output=(output>0.5)*255
249 | output=np.asarray(output,np.uint8)
250 | output=cv2.resize(output,(width,height))
251 | cv2.imwrite("out.png",output)
252 |
253 | if __name__=="__main__":
254 | train()
255 | #test()
256 |
--------------------------------------------------------------------------------
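The labels are shrunk to 1/4 scale in _parse_function because network() downsamples twice with 2x2 max-pooling before the LSTM stages, so the sigmoid map in network["outputs"] lives on an H/4 x W/4 grid. A quick shape sanity check, assuming TF 1.x:

```python
# shape check for network() in train_test.py (TF 1.x assumed): two 2x2
# max-pools shrink a 64x64 input to the 16x16 output grid the labels match
import numpy as np
import tensorflow as tf
from train_test import network

model = network(is_training=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(model["outputs"],
                   {model["inputs"]: np.zeros((1, 64, 64, 1), np.float32)})
    print(out.shape)   # (1, 16, 16, 1)
```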