#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MSER-based text detection with refinement.

The steps are:
1. Convert the image to gray scale.
2. Preprocess: increase the image contrast using CLAHE so MSER regions are
   detected accurately, then scale the image to SVGA size (800x600).
3. Detect MSER regions.
4. Filter MSER regions by bounding-box aspect ratio and stroke-width ratio.
5. K-means to align MSER regions in the horizontal direction and cover the
   regions lying on one text line.

NOTE(review): this file was recovered from an HTML-mangled dump. Comparison
operators and one inner loop of swt() were lost in the mangling and have been
reconstructed from context; every reconstructed spot is marked with a
"TODO confirm" comment. The script was also updated from Python 2 print
statements to Python 3 print() calls.

@author: lili
"""

import pprint

import cv2
import numpy as np
import scipy.misc as smp  # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2
from PIL import Image
from pytesseract import image_to_string


color = (0, 255, 0)      # BGR color used for debug drawing
char_height = 20.0       # assumed text height in pixels, drives cluster count


def bbox(points):
    """Return the axis-aligned bounding box of an (N, 2) point array.

    Result is a 2x2 float array: row 0 = (xmin, ymin), row 1 = (xmax, ymax).
    """
    res = np.zeros((2, 2))
    res[0, :] = np.min(points, axis=0)
    res[1, :] = np.max(points, axis=0)
    return res


def bbox_width(bbox):
    """Inclusive pixel width of a bbox produced by bbox()."""
    return (bbox[1, 0] - bbox[0, 0] + 1)


def bbox_height(bbox):
    """Inclusive pixel height of a bbox produced by bbox()."""
    return (bbox[1, 1] - bbox[0, 1] + 1)


def aspect_ratio(region):
    """Width/height ratio of the region's bounding box."""
    bb = bbox(region)
    return (bbox_width(bb) / bbox_height(bb))


def filter_on_ar(regions):
    """Filter text regions based on aspect ratio.

    NOTE(review): the threshold comparison was lost in the dump (only a
    truthiness test survives); kept as dumped. Unused by run(), which applies
    its own aspect-ratio thresholds. TODO confirm the intended bounds.
    """
    return [x for x in regions if aspect_ratio(x)]


def dbg_draw_txt_contours(img, mser):
    """Debug helper: draw MSER contours onto img (display call commented out)."""
    overlapped_img = cv2.drawContours(img, mser, -1, color)
    new_img = smp.toimage(overlapped_img)
    new_img = np.array(new_img)
    # new_img.show()


def dbg_draw_txt_rect(img, bbox_list):
    """Debug helper: draw each bbox in bbox_list on a gray image and save it."""
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR, dstCn=3)
    processed_imgname = "/home/lili/Workspace/MSER_images/MSER2/MSER_refinement.png"
    for b in bbox_list:
        pt1 = tuple(map(int, b[0]))
        pt2 = tuple(map(int, b[1]))
        img = cv2.rectangle(img, pt1, pt2, color, 1)
    new_img = smp.toimage(img)
    new_img = np.array(new_img)
    cv2.imwrite(processed_imgname, new_img)


def preprocess_img(img):
    """Enhance contrast (CLAHE) and resize a gray image toward SVGA size.

    NOTE(review): never called by run() in this dump — presumably meant to be
    applied to gray_img before MSER detection; verify against callers.
    """
    # CLAHE: adaptive localized hist-eq; clipLimit avoids noise amplification.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img = clahe.apply(img)
    height, width = img.shape
    # SVGA size is 800x600: landscape scales to 800 wide, portrait to 600 wide.
    # TODO confirm: the comparison operator was lost in the dump ('&' residue).
    if width > height:
        scale = 800. / width
    else:
        scale = 600. / width
    # Avoid shrinking (kept disabled, as in the original):
    # if scale < 1.0:
    #     scale = 1.0
    dst = cv2.resize(img, (0, 0), None, scale, scale, cv2.INTER_LINEAR)
    return dst


def swt_window_func(l):
    """Return the neighbours of a 3x3 window (flattened, center at index 4)
    that compare against the center value.

    NOTE(review): unused in this file, and the comparison operator was lost
    in the dump ('n & center' residue). '<' matches the smaller-distance
    neighbour selection used in swt(). TODO confirm.
    """
    center = l[4]
    filtered_l = np.append(l[:4], l[5:])
    res = [n for n in filtered_l if n < center]
    return res


def swt(gimg):
    """Approximate stroke widths of the foreground pixels of a gray image.

    Thresholds the image (inverted, so dark text becomes foreground), takes a
    distance transform, then propagates larger distance values (stroke
    half-widths) down onto neighbouring pixels with smaller distances, so
    every pixel of a stroke ends up carrying the stroke's ridge value.
    Returns a list of stroke-width values, one per foreground pixel.
    """
    # TODO: fix threshold logically
    threshold = 90
    maxval = 255
    # THRESH_BINARY_INV because we want the distance from foreground pixels
    # (dark text) to background pixels.
    temp, bimg = cv2.threshold(gimg, threshold, maxval, cv2.THRESH_BINARY_INV)
    rows, cols = bimg.shape
    # Pad a zero row below the image to avoid an infinite distance at the edge.
    row_2_pad = np.zeros([1, cols], dtype=np.uint8)
    bimg_padded = np.concatenate((bimg, row_2_pad), axis=0)
    dist = cv2.distanceTransform(bimg_padded, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    dist = np.take(dist, range(rows), axis=0)
    dist = dist.round()

    it = np.nditer([bimg, dist],
                   op_flags=[['readonly'], ['readonly']],
                   flags=['multi_index'])

    # Build, for every foreground pixel, the list of 8-neighbours whose
    # distance value is smaller — candidates to receive the pixel's stroke
    # value during propagation below.
    # NOTE(review): the original body of this loop (lines 122-153) was lost
    # in the dump; this is a reconstruction inferred from the propagation
    # loop that consumes `lookup`. TODO confirm against the original source.
    lookup = []
    for cur_b, cur_d in it:
        r, c = it.multi_index
        if cur_b:
            pval = cur_d
            cur_lup = []
            for i in range(max(r - 1, 0), min(r + 2, rows)):
                for j in range(max(c - 1, 0), min(c + 2, cols)):
                    if (i, j) != (r, c) and 0 < dist[i, j] < pval:
                        cur_lup.append((i, j))
            lookup.append(cur_lup)
        else:
            lookup.append(None)
    # dtype=object: rows are ragged lists/None; modern NumPy rejects implicit
    # ragged arrays.
    lookup = np.array(lookup, dtype=object)
    lookup = lookup.reshape(rows, cols)

    # Propagate from the widest strokes downward so ridge maxima dominate.
    d_max = int(dist.max())
    for stroke in np.arange(d_max, 0, -1):
        stroke_index = np.where(dist == stroke)
        stroke_index = [(a, b) for a, b in zip(stroke_index[0], stroke_index[1])]
        for stidx in stroke_index:
            neigh_index = lookup[stidx]
            for nidx in neigh_index:
                dist[nidx] = stroke

    # Second pass: collect the propagated stroke widths of foreground pixels.
    it.reset()
    sw = []
    for cur_b, cur_d in it:
        if cur_b:
            sw.append(cur_d)
    return sw


def get_swt_frm_mser(region, rows, cols, img):
    """Extract the stroke-width values of an MSER region's bounding box.

    Crops the region's bbox out of the full gray image `img` and runs swt()
    on the crop. `rows`/`cols` are the full-image dimensions (unused here but
    kept for interface compatibility with callers).
    """
    bb = bbox(region)
    xmin = int(bb[0][0])
    ymin = int(bb[0][1])
    width = int(bbox_width(bb))
    height = int(bbox_height(bb))
    selected_pix = []
    xmax = xmin + width
    ymax = ymin + height
    for h in range(ymin, ymax):
        row = np.take(img, (h, ), axis=0)
        horz_pix = np.take(row, range(xmin, xmax))
        selected_pix.append(horz_pix)
    selected_pix = np.array(selected_pix)
    sw = swt(selected_pix)
    return sw


def filter_on_sw(region_dict):
    """Group regions whose median stroke width and height are similar.

    Returns {'groupN': {region_key: region_entry, ...}, ...}. A region joins
    the first group containing an element whose stroke-width-median ratio
    lies in (0.66, 1.5) and whose height ratio is below 2.0; otherwise it
    starts a new group.
    """
    filtered_dict = {}
    distance_th = 4.0  # NOTE(review): unused in this dump; kept as-is
    group_num = 0
    for rkey in region_dict.keys():
        med = region_dict[rkey]['sw_med']
        height = bbox_height(region_dict[rkey]['bbox'])
        added = False
        for fkey in filtered_dict:
            for k in filtered_dict[fkey]:
                elem_med = filtered_dict[fkey][k]['sw_med']
                elem_height = bbox_height(filtered_dict[fkey][k]['bbox'])
                m_ratio = med / elem_med
                h_ratio = height / elem_height
                # FIX: the dump read 'm_ratio < 0.66 and m_ratio < 1.5' —
                # redundant and contradictory; a symmetric similarity band
                # (~2/3 .. 3/2) is clearly intended.
                if 0.66 < m_ratio < 1.5 and h_ratio < 2.0:
                    filtered_dict[fkey][rkey] = region_dict[rkey]
                    added = True
                    break
            if added:
                break
        if not added:
            name = 'group' + str(group_num)
            filtered_dict[name] = {}
            filtered_dict[name][rkey] = region_dict[rkey]
            group_num = group_num + 1
    return filtered_dict


def get_y_center(bb):
    """Vertical center of a bbox ([lower-left, upper-right] pair)."""
    ll = bb[0]
    ur = bb[1]
    return ((ll[1] + ur[1]) / 2.0)


def kmean(region_dict, rows, num_clusters):
    """1-D k-means on region vertical centers to find text lines.

    Mutates region_dict in place, storing a 'clid' cluster id per region, and
    returns a validity list (one bool per cluster). Empty clusters are
    invalidated; clusters closer than char_height/2 are merged at the end.
    """
    # Cluster centers start evenly spaced over the image height.
    clusters = (float(rows) / num_clusters) * np.arange(num_clusters)
    cluster_vld = [True] * num_clusters
    # Initial cost assumes every region is assigned to a cluster at y=0.
    cost = 0.0
    for rkey in region_dict:
        center_y = get_y_center(region_dict[rkey]['bbox'])
        cost += center_y * center_y
    cost = cost / len(region_dict.keys())

    iter_no = 0
    while True:
        iter_no = iter_no + 1
        # Assignment step: nearest cluster center per region.
        for rkey in region_dict:
            center_y = get_y_center(region_dict[rkey]['bbox'])
            dist_y = np.abs(clusters - center_y)
            cluster_id = dist_y.argmin()
            region_dict[rkey]['clid'] = cluster_id

        # Cost of the current assignment (mean squared distance per cluster).
        new_cost = 0.0
        for i, c in enumerate(clusters):
            if cluster_vld[i]:
                num_regions = 0
                cluster_cost = 0.0
                for rkey in region_dict:
                    if (region_dict[rkey]['clid'] == i):
                        center_y = get_y_center(region_dict[rkey]['bbox'])
                        cluster_cost += (center_y - clusters[i]) ** 2
                        num_regions += 1
                if num_regions:
                    cluster_cost /= num_regions
                new_cost += cluster_cost

        # Stop when the new cost is within 5% of the old cost.
        if new_cost >= 0.95 * cost:
            break
        else:
            cost = new_cost

        # Update step: move each center to the mean of its members; clusters
        # that captured no regions are invalidated.
        for i, c in enumerate(clusters):
            if cluster_vld[i]:
                num_regions = 0
                clusters[i] = 0.0
                for rkey in region_dict:
                    if (region_dict[rkey]['clid'] == i):
                        center_y = get_y_center(region_dict[rkey]['bbox'])
                        clusters[i] += center_y
                        num_regions += 1
                if num_regions:
                    clusters[i] = clusters[i] / num_regions
                else:
                    cluster_vld[i] = False

    # Merge nearby clusters (closer than half a character height).
    # NOTE(review): indentation was lost in the dump — this step may have been
    # inside the while loop; placed after convergence as the more sensible
    # reading. TODO confirm.
    for i, cur_cl in enumerate(clusters):
        if cluster_vld[i]:
            for j, iter_cl in enumerate(clusters):
                if abs(cur_cl - iter_cl) <= (char_height / 2.0) and i != j:
                    cluster_vld[j] = False
                    for rkey in region_dict:
                        # Re-point members of the absorbed cluster.
                        if region_dict[rkey]['clid'] == j:
                            region_dict[rkey]['clid'] = i

    return cluster_vld


def dbg_get_cluster_rect(cluster_vld, region_dict):
    """Return one covering bbox [lower-left, upper-right] per valid cluster."""
    bbox_list = []
    vld_count = 0
    for cl_no, vld in enumerate(cluster_vld):
        if vld == True:
            vld_count += 1
        print("vld")
        print(vld)
        if vld:
            # FIX: the dump initialized cur_lL as [100000, 10000] — an
            # inconsistent sentinel; both axes use the same large value.
            cur_lL = [100000, 100000]
            cur_uR = [-100000, -100000]
            for rkey in region_dict.keys():
                if region_dict[rkey]['clid'] == cl_no:
                    region_lL = region_dict[rkey]['bbox'][0]
                    region_uR = region_dict[rkey]['bbox'][1]
                    # Grow the covering box to include this region.
                    if region_lL[0] <= cur_lL[0]:
                        cur_lL[0] = region_lL[0]
                    if region_lL[1] <= cur_lL[1]:
                        cur_lL[1] = region_lL[1]
                    if region_uR[0] >= cur_uR[0]:
                        cur_uR[0] = region_uR[0]
                    if region_uR[1] >= cur_uR[1]:
                        cur_uR[1] = region_uR[1]
            bbox_list.append([cur_lL, cur_uR])
    print("len(bbox_list) in get_text_from_cluster")
    print(len(bbox_list))
    print("vld_count")
    print(vld_count)
    return bbox_list


def get_bbox_img(gimg, bb):
    """Crop the (inclusive) bbox region out of a gray image."""
    y_start = int(bb[0][1])
    y_end = int(bb[1][1])
    x_start = int(bb[0][0])
    x_end = int(bb[1][0])
    row_extracted = gimg.take(range(y_start, y_end + 1), axis=0)
    extracted = row_extracted.take(range(x_start, x_end + 1), axis=1)
    return extracted


def get_text_from_cluster(cluster_vld, region_dict, gimg):
    """OCR each valid cluster's covering bbox with Tesseract and print results."""
    bbox_list = dbg_get_cluster_rect(cluster_vld, region_dict)
    str_list = []
    for bb in bbox_list:
        extracted = get_bbox_img(gimg, bb)
        ext_img = smp.toimage(extracted)
        found = image_to_string(ext_img, cleanup=False)
        str_list.append(found.strip())
    # FIX: the dump did str_list.insert(0, str_list), inserting the list into
    # itself (a self-referential structure); removed.
    pprint.pprint(str_list)


def run(fimage):
    """Full pipeline: MSER detection, AR + stroke-width filtering, clustering,
    and debug drawing of the resulting per-line boxes."""
    processed_imgname = '/home/lili/Workspace/MSER_images/MSER2/MSER_refinement.png'
    ar_thresh_max = 6.0      # max width/height of a candidate region
    ar_thresh_min = 0.5      # min width/height of a candidate region
    sw_ratio_thresh = 1      # max std/mean of stroke widths (text is uniform)
    min_area_ratio = 500.0   # min bbox area in pixels
    width_threshold = 5.0    # min bbox width in pixels

    org_img = cv2.imread(fimage)
    gray_img = cv2.cvtColor(org_img, cv2.COLOR_BGR2GRAY)
    mser = cv2.MSER_create()
    mser.setDelta(4)
    mser_areas, _ = mser.detectRegions(gray_img)
    region_dict = {}
    rows, cols = gray_img.shape
    print("the shape of gray image is ")
    print(rows, cols)

    bbox_list = []
    region_num = 0
    for m in mser_areas:
        name = 'mser_' + str(region_num)
        bb = bbox(m)
        ar = bbox_width(bb) / bbox_height(bb)
        area_ratio = bbox_width(bb) * bbox_height(bb)
        print("area_ratio")
        print(area_ratio)
        print("ar is")
        print(ar)
        # Filter based on aspect ratio / area / width.
        if ar < ar_thresh_max and area_ratio > min_area_ratio and ar > ar_thresh_min and bbox_width(bb) > width_threshold:
            sw = get_swt_frm_mser(m, rows, cols, gray_img)
            sw_std = np.std(sw)
            sw_mean = np.mean(sw)
            sw_ratio = sw_std / sw_mean
            # Filter based on stroke-width uniformity.
            if sw_ratio < sw_ratio_thresh:
                print("sw_ratio")
                print(sw_ratio)
                sw_med = np.median(sw)
                region_dict[name] = {'bbox': bb, 'sw_med': sw_med}
        region_num = region_num + 1

    print("region_num")
    print(region_num)

    print("rows number")
    print(rows)
    print("char_height")
    print(char_height)
    num_clusters = int(rows / char_height)
    cluster_vld = kmean(region_dict, rows, num_clusters)
    print("len(cluster_vld)")
    print(len(cluster_vld))
    bbox_list = dbg_get_cluster_rect(cluster_vld, region_dict)
    print("len(bbox_list) after clustering")
    print(len(bbox_list))
    for bb in bbox_list:
        print(bb)

    # get_text_from_cluster(cluster_vld, region_dict, gray_img)
    print("len(region_dict)")
    print(len(region_dict))
    print("gray_img.shape")
    print(gray_img.shape)
    cpy_img = np.copy(gray_img)
    dbg_draw_txt_rect(cpy_img, bbox_list)


if __name__ == '__main__':
    img_name = "/home/lili/Workspace/FCN_Text/ProposalGeneration/tmp_practice/ee3bdfa09bedf98f836e337585736977-1.png"
    # img_name = "/home/lili/Workspace/MSER_images/MSER2/good_ex.png"
    run(img_name)