├── 012.png ├── 012.png_act.jpg ├── 012.png_predict.jpg ├── README.md ├── __init__.py ├── a.png ├── a.png_act.jpg ├── a.png_predict.jpg ├── cfg.py ├── convert_to_onnx.py ├── generator.py ├── label.py ├── loss.py ├── model.py ├── nms.py ├── predict.py ├── preprocess.py ├── saved_model └── mb3_512_model_epoch_535.pth └── train.py /012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/012.png -------------------------------------------------------------------------------- /012.png_act.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/012.png_act.jpg -------------------------------------------------------------------------------- /012.png_predict.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/012.png_predict.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_AdvancedEast 2 | PyTorch implementation of AdvancedEAST with a MobileNetV3 backbone 3 | 4 | # Reference: https://github.com/huoyijie/AdvancedEAST 5 | # Training 6 | ## Tianchi ICPR dataset download link: https://pan.baidu.com/s/1NSyc-cHKV3IwDo6qojIrKA password: ye9y 7 | ### 1. Modify config params in cfg.py, see default values. 8 | ### 2. python preprocess.py resizes images to 256*256, 384*384, 512*512, 640*640 or 736*736; training on each size separately can speed up the training process. 9 | ### 3. python label.py 10 | ### 4. python train.py 11 | ### 5. python predict.py 12 | Example output: 13 | ![demo](https://github.com/corleonechensiyu/pytorch_AdvancedEast/blob/master/012.png_predict.jpg) 14 | 15 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/__init__.py -------------------------------------------------------------------------------- /a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/a.png -------------------------------------------------------------------------------- /a.png_act.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/a.png_act.jpg -------------------------------------------------------------------------------- /a.png_predict.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/a.png_predict.jpg -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | train_task_id = 'MBV3_512' 4 | initial_epoch = 0 5 | epoch_num = 24 6 | lr = 1e-3 7 | decay = 5e-4 8 | # clipvalue = 0.5 # default 0.5, 0 means no clip 9 | patience = 5 10 | 
load_weights = True 11 | lambda_inside_score_loss = 4.0 12 | lambda_side_vertex_code_loss = 1.0 13 | lambda_side_vertex_coord_loss = 1.0 14 | 15 | total_img = 10000 16 | validation_split_ratio = 0.1 17 | max_train_img_size = int(train_task_id[-3:]) 18 | max_predict_img_size = int(train_task_id[-3:]) # 2400 19 | assert max_train_img_size in [256, 384, 512, 640, 736], \ 20 | 'max_train_img_size must in [256, 384, 512, 640, 736]' 21 | if max_train_img_size == 256: 22 | batch_size = 8 23 | elif max_train_img_size == 384: 24 | batch_size = 4 25 | elif max_train_img_size == 512: 26 | batch_size = 2 27 | else: 28 | batch_size = 1 29 | steps_per_epoch = total_img * (1 - validation_split_ratio) // batch_size 30 | validation_steps = total_img * validation_split_ratio // batch_size 31 | 32 | data_dir = '/home/csy/ICDAR/icpr/train_1000/' 33 | origin_image_dir_name = 'image_1000/' 34 | origin_txt_dir_name = 'txt_1000/' 35 | train_image_dir_name = 'images_%s/' % train_task_id 36 | train_label_dir_name = 'labels_%s/' % train_task_id 37 | show_gt_image_dir_name = 'show_gt_images_%s/' % train_task_id 38 | show_act_image_dir_name = 'show_act_images_%s/' % train_task_id 39 | gen_origin_img = True 40 | draw_gt_quad = True 41 | draw_act_quad = True 42 | val_fname = 'val_%s.txt' % train_task_id 43 | train_fname = 'train_%s.txt' % train_task_id 44 | # in paper it's 0.3, maybe to large to this problem 45 | shrink_ratio = 0.2 46 | # pixels between 0.2 and 0.6 are side pixels 47 | shrink_side_ratio = 0.6 48 | epsilon = 1e-4 49 | 50 | num_channels = 3 51 | feature_layers_range = range(5, 1, -1) 52 | # feature_layers_range = range(3, 0, -1) 53 | feature_layers_num = len(feature_layers_range) 54 | # pixel_size = 4 55 | pixel_size = 2 ** feature_layers_range[-1] 56 | locked_layers = False 57 | 58 | if not os.path.exists('model'): 59 | os.mkdir('model') 60 | if not os.path.exists('saved_model'): 61 | os.mkdir('saved_model') 62 | 63 | model_weights_path = 'model/mbv3_weights_%s.{epoch:03d}-{val_loss:.3f}.h5' \ 64 | % train_task_id 65 | saved_model_file_path = 'saved_model/mbv3_east_model_%s.h5' % train_task_id 66 | # saved_model_weights_file_path = 'saved_model/mbv3_east_model_weights_%s.h5'\ 67 | # % train_task_id 68 | saved_model_weights_file_path ='model/mbv3_weights_3T640.004-0.156.h5' 69 | pixel_threshold = 0.9 70 | side_vertex_pixel_threshold = 0.9 71 | trunc_threshold = 0.1 72 | predict_cut_text_line = False 73 | predict_write2txt = False 74 | -------------------------------------------------------------------------------- /convert_to_onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is used to convert the pytorch model into an onnx format model. 
3 | """ 4 | import sys 5 | 6 | import torch.onnx 7 | from model import EAST 8 | 9 | origin_model_path="./saved_model/mb3_512_model_epoch_535.pth" 10 | 11 | model = EAST().to("cuda") 12 | model.load_state_dict(torch.load(origin_model_path)) 13 | model.eval() 14 | 15 | model_path = "model/mbv3_512_east.onnx" 16 | 17 | dummy_input = torch.randn(1, 3, 512, 512).to("cuda") 18 | 19 | torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=['input'], output_names=['east_detect']) 20 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from PIL import Image 4 | import os 5 | import torch 6 | import torchvision.transforms as transforms 7 | from torch.utils import data 8 | import cfg 9 | 10 | 11 | class custom_dataset(data.Dataset): 12 | def __init__(self, img_path): #train_image_dir_name 13 | super(custom_dataset, self).__init__() 14 | self.img_files = [os.path.join(img_path, img_file) for img_file in sorted(os.listdir(img_path))] 15 | 16 | 17 | def __len__(self): 18 | return len(self.img_files) 19 | 20 | def __getitem__(self, index): 21 | img_filename = self.img_files[index].strip().split('/')[-1] 22 | 23 | gt_file = os.path.join(cfg.data_dir, 24 | cfg.train_label_dir_name, 25 | img_filename[:-4] + '_gt.npy') 26 | y=np.load(gt_file) 27 | img = Image.open(self.img_files[index]) 28 | transform = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.25), \ 29 | transforms.ToTensor(), \ 30 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) 31 | 32 | return transform(img),torch.Tensor(y).permute(2,0,1) 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /label.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image, ImageDraw 4 | from tqdm import tqdm 5 | import cfg 6 | 7 | 8 | def point_inside_of_quad(px, py, quad_xy_list, p_min, p_max): 9 | if (p_min[0] <= px <= p_max[0]) and (p_min[1] <= py <= p_max[1]): 10 | xy_list = np.zeros((4, 2)) 11 | xy_list[:3, :] = quad_xy_list[1:4, :] - quad_xy_list[:3, :] 12 | xy_list[3] = quad_xy_list[0, :] - quad_xy_list[3, :] 13 | yx_list = np.zeros((4, 2)) 14 | yx_list[:, :] = quad_xy_list[:, -1:-3:-1] 15 | a = xy_list * ([py, px] - yx_list) 16 | b = a[:, 0] - a[:, 1] 17 | if np.amin(b) >= 0 or np.amax(b) <= 0: 18 | return True 19 | else: 20 | return False 21 | else: 22 | return False 23 | 24 | 25 | def point_inside_of_nth_quad(px, py, xy_list, shrink_1, long_edge): 26 | nth = -1 27 | vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]], 28 | [[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]] 29 | for ith in range(2): 30 | quad_xy_list = np.concatenate(( 31 | np.reshape(xy_list[vs[long_edge][ith][0]], (1, 2)), 32 | np.reshape(shrink_1[vs[long_edge][ith][1]], (1, 2)), 33 | np.reshape(shrink_1[vs[long_edge][ith][2]], (1, 2)), 34 | np.reshape(xy_list[vs[long_edge][ith][3]], (1, 2))), axis=0) 35 | p_min = np.amin(quad_xy_list, axis=0) 36 | p_max = np.amax(quad_xy_list, axis=0) 37 | if point_inside_of_quad(px, py, quad_xy_list, p_min, p_max): 38 | if nth == -1: 39 | nth = ith 40 | else: 41 | nth = -1 42 | break 43 | return nth 44 | 45 | 46 | def shrink(xy_list, ratio=cfg.shrink_ratio): 47 | if ratio == 0.0: 48 | return xy_list, xy_list 49 | diff_1to3 = xy_list[:3, :] - xy_list[1:4, :] 50 | diff_4 = xy_list[3:4, :] - xy_list[0:1, :] 51 | diff = 
np.concatenate((diff_1to3, diff_4), axis=0) 52 | dis = np.sqrt(np.sum(np.square(diff), axis=-1)) 53 | # determine which are long or short edges 54 | long_edge = int(np.argmax(np.sum(np.reshape(dis, (2, 2)), axis=0))) 55 | short_edge = 1 - long_edge 56 | # cal r length array 57 | r = [np.minimum(dis[i], dis[(i + 1) % 4]) for i in range(4)] 58 | # cal theta array 59 | diff_abs = np.abs(diff) 60 | diff_abs[:, 0] += cfg.epsilon 61 | theta = np.arctan(diff_abs[:, 1] / diff_abs[:, 0]) 62 | # shrink two long edges 63 | temp_new_xy_list = np.copy(xy_list) 64 | shrink_edge(xy_list, temp_new_xy_list, long_edge, r, theta, ratio) 65 | shrink_edge(xy_list, temp_new_xy_list, long_edge + 2, r, theta, ratio) 66 | # shrink two short edges 67 | new_xy_list = np.copy(temp_new_xy_list) 68 | shrink_edge(temp_new_xy_list, new_xy_list, short_edge, r, theta, ratio) 69 | shrink_edge(temp_new_xy_list, new_xy_list, short_edge + 2, r, theta, ratio) 70 | return temp_new_xy_list, new_xy_list, long_edge 71 | 72 | 73 | def shrink_edge(xy_list, new_xy_list, edge, r, theta, ratio=cfg.shrink_ratio): 74 | if ratio == 0.0: 75 | return 76 | start_point = edge 77 | end_point = (edge + 1) % 4 78 | long_start_sign_x = np.sign( 79 | xy_list[end_point, 0] - xy_list[start_point, 0]) 80 | new_xy_list[start_point, 0] = \ 81 | xy_list[start_point, 0] + \ 82 | long_start_sign_x * ratio * r[start_point] * np.cos(theta[start_point]) 83 | long_start_sign_y = np.sign( 84 | xy_list[end_point, 1] - xy_list[start_point, 1]) 85 | new_xy_list[start_point, 1] = \ 86 | xy_list[start_point, 1] + \ 87 | long_start_sign_y * ratio * r[start_point] * np.sin(theta[start_point]) 88 | # long edge one, end point 89 | long_end_sign_x = -1 * long_start_sign_x 90 | new_xy_list[end_point, 0] = \ 91 | xy_list[end_point, 0] + \ 92 | long_end_sign_x * ratio * r[end_point] * np.cos(theta[start_point]) 93 | long_end_sign_y = -1 * long_start_sign_y 94 | new_xy_list[end_point, 1] = \ 95 | xy_list[end_point, 1] + \ 96 | long_end_sign_y * ratio * r[end_point] * np.sin(theta[start_point]) 97 | 98 | 99 | def process_label(data_dir=cfg.data_dir): 100 | with open(os.path.join(data_dir, cfg.val_fname), 'r') as f_val: 101 | f_list = f_val.readlines() 102 | with open(os.path.join(data_dir, cfg.train_fname), 'r') as f_train: 103 | f_list.extend(f_train.readlines()) 104 | for line, _ in zip(f_list, tqdm(range(len(f_list)))): 105 | line_cols = str(line).strip().split(',') 106 | img_name, width, height = \ 107 | line_cols[0].strip(), int(line_cols[1].strip()), \ 108 | int(line_cols[2].strip()) 109 | gt = np.zeros((height // cfg.pixel_size, width // cfg.pixel_size, 7)) 110 | train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name) 111 | xy_list_array = np.load(os.path.join(train_label_dir, 112 | img_name[:-4] + '.npy')) 113 | train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name) 114 | with Image.open(os.path.join(train_image_dir, img_name)) as im: 115 | draw = ImageDraw.Draw(im) 116 | for xy_list in xy_list_array: 117 | _, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio) 118 | shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio) 119 | p_min = np.amin(shrink_xy_list, axis=0) 120 | p_max = np.amax(shrink_xy_list, axis=0) 121 | # floor of the float 122 | ji_min = (p_min / cfg.pixel_size - 0.5).astype(int) - 1 123 | # +1 for ceil of the float and +1 for include the end 124 | ji_max = (p_max / cfg.pixel_size - 0.5).astype(int) + 3 125 | imin = np.maximum(0, ji_min[1]) 126 | imax = np.minimum(height // cfg.pixel_size, ji_max[1]) 127 | jmin = 
np.maximum(0, ji_min[0]) 128 | jmax = np.minimum(width // cfg.pixel_size, ji_max[0]) 129 | for i in range(imin, imax): 130 | for j in range(jmin, jmax): 131 | px = (j + 0.5) * cfg.pixel_size 132 | py = (i + 0.5) * cfg.pixel_size 133 | if point_inside_of_quad(px, py, 134 | shrink_xy_list, p_min, p_max): 135 | gt[i, j, 0] = 1 136 | line_width, line_color = 1, 'red' 137 | ith = point_inside_of_nth_quad(px, py, 138 | xy_list, 139 | shrink_1, 140 | long_edge) 141 | vs = [[[3, 0], [1, 2]], [[0, 1], [2, 3]]] 142 | if ith in range(2): 143 | gt[i, j, 1] = 1 144 | if ith == 0: 145 | line_width, line_color = 2, 'yellow' 146 | else: 147 | line_width, line_color = 2, 'green' 148 | gt[i, j, 2:3] = ith 149 | gt[i, j, 3:5] = \ 150 | xy_list[vs[long_edge][ith][0]] - [px, py] 151 | gt[i, j, 5:] = \ 152 | xy_list[vs[long_edge][ith][1]] - [px, py] 153 | draw.line([(px - 0.5 * cfg.pixel_size, 154 | py - 0.5 * cfg.pixel_size), 155 | (px + 0.5 * cfg.pixel_size, 156 | py - 0.5 * cfg.pixel_size), 157 | (px + 0.5 * cfg.pixel_size, 158 | py + 0.5 * cfg.pixel_size), 159 | (px - 0.5 * cfg.pixel_size, 160 | py + 0.5 * cfg.pixel_size), 161 | (px - 0.5 * cfg.pixel_size, 162 | py - 0.5 * cfg.pixel_size)], 163 | width=line_width, fill=line_color) 164 | act_image_dir = os.path.join(cfg.data_dir, 165 | cfg.show_act_image_dir_name) 166 | if cfg.draw_act_quad: 167 | im.save(os.path.join(act_image_dir, img_name)) 168 | train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name) 169 | np.save(os.path.join(train_label_dir, 170 | img_name[:-4] + '_gt.npy'), gt) 171 | 172 | 173 | if __name__ == '__main__': 174 | process_label() 175 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import cfg 4 | import torch.nn.functional as F 5 | import numpy as np 6 | # def get_dice_loss(gt_score, pred_score): 7 | # inter = torch.sum(gt_score * pred_score) 8 | # union = torch.sum(gt_score) + torch.sum(pred_score) + 1e-5 9 | # return 1. 
- (2 * inter / union) 10 | # 11 | # 12 | # def get_geo_loss(gt_geo, pred_geo): 13 | # d1_gt, d2_gt, d3_gt, d4_gt, angle_gt = torch.split(gt_geo, 1, 1) 14 | # d1_pred, d2_pred, d3_pred, d4_pred, angle_pred = torch.split(pred_geo, 1, 1) 15 | # area_gt = (d1_gt + d2_gt) * (d3_gt + d4_gt) 16 | # area_pred = (d1_pred + d2_pred) * (d3_pred + d4_pred) 17 | # w_union = torch.min(d3_gt, d3_pred) + torch.min(d4_gt, d4_pred) 18 | # h_union = torch.min(d1_gt, d1_pred) + torch.min(d2_gt, d2_pred) 19 | # area_intersect = w_union * h_union 20 | # area_union = area_gt + area_pred - area_intersect 21 | # iou_loss_map = -torch.log((area_intersect + 1.0)/(area_union + 1.0)) 22 | # angle_loss_map = 1 - torch.cos(angle_pred - angle_gt) 23 | # return iou_loss_map, angle_loss_map 24 | 25 | 26 | # class Loss(nn.Module): 27 | # def __init__(self, weight_angle=10): 28 | # super(Loss, self).__init__() 29 | # self.weight_angle = weight_angle 30 | # 31 | # def forward(self, gt_score, pred_score, gt_geo, pred_geo, ignored_map): 32 | # if torch.sum(gt_score) < 1: 33 | # return torch.sum(pred_score + pred_geo) * 0 34 | # 35 | # classify_loss = get_dice_loss(gt_score, pred_score*(1-ignored_map)) 36 | # iou_loss_map, angle_loss_map = get_geo_loss(gt_geo, pred_geo) 37 | # 38 | # angle_loss = torch.sum(angle_loss_map*gt_score) / torch.sum(gt_score) 39 | # iou_loss = torch.sum(iou_loss_map*gt_score) / torch.sum(gt_score) 40 | # geo_loss = self.weight_angle * angle_loss + iou_loss 41 | # print('classify loss is {:.8f}, angle loss is {:.8f}, iou loss is {:.8f}'.format(classify_loss, angle_loss, iou_loss)) 42 | # return geo_loss + classify_loss 43 | 44 | 45 | 46 | 47 | def quad_loss(y_true, y_pred): 48 | # loss for inside_score 49 | logits = y_pred[:, :1, :, :] 50 | labels = y_true[:, :1, :, :] 51 | # balance positive and negative samples in an image 52 | beta = 1 - torch.mean(labels) 53 | # first apply sigmoid activation 54 | predicts = torch.sigmoid(logits) 55 | # log +epsilon for stable cal 56 | inside_score_loss = torch.mean( 57 | -1 * (beta * labels * torch.log(predicts + cfg.epsilon) + 58 | (1 - beta) * (1 - labels) * torch.log(1 - predicts + cfg.epsilon))) 59 | inside_score_loss *= cfg.lambda_inside_score_loss 60 | 61 | # loss for side_vertex_code 62 | vertex_logits = y_pred[:, 1:3, :, :] 63 | vertex_labels = y_true[:, 1:3, :, :] 64 | vertex_beta = 1 - (torch.mean(y_true[:, 1:2, :, :]) 65 | / (torch.mean(labels) + cfg.epsilon)) 66 | vertex_predicts = torch.sigmoid(vertex_logits) 67 | pos = -1 * vertex_beta * vertex_labels * torch.log(vertex_predicts + 68 | cfg.epsilon) 69 | neg = -1 * (1 - vertex_beta) * (1 - vertex_labels) * torch.log( 70 | 1 - vertex_predicts + cfg.epsilon) 71 | 72 | # positive_weights = torch.cast(torch.eq(y_true[:, :, :, 0], 1), tf.float32) 73 | positive_weights = torch.eq(y_true[:, 0, :, :], 1).float() 74 | side_vertex_code_loss = \ 75 | torch.sum(torch.sum(pos + neg, 1) * positive_weights) / ( 76 | torch.sum(positive_weights) + cfg.epsilon) 77 | side_vertex_code_loss *= cfg.lambda_side_vertex_code_loss 78 | 79 | # loss for side_vertex_coord delta 80 | g_hat = y_pred[:, 3:, :, :] 81 | g_true = y_true[:, 3:, :, :] 82 | vertex_weights = torch.eq(y_true[:, 1, :, :], 1).float() 83 | pixel_wise_smooth_l1norm = smooth_l1_loss(g_hat, g_true, vertex_weights) 84 | side_vertex_coord_loss = torch.sum(pixel_wise_smooth_l1norm) / ( 85 | torch.sum(vertex_weights) + cfg.epsilon) 86 | side_vertex_coord_loss *= cfg.lambda_side_vertex_coord_loss 87 | return inside_score_loss , side_vertex_code_loss , 
side_vertex_coord_loss 88 | 89 | 90 | def smooth_l1_loss(prediction_tensor, target_tensor, weights): 91 | n_q = torch.reshape(quad_norm(target_tensor), weights.size()) 92 | diff = prediction_tensor - target_tensor 93 | abs_diff = torch.abs(diff) 94 | abs_diff_lt_1 = torch.lt(abs_diff, 1) 95 | pixel_wise_smooth_l1norm = (torch.sum( 96 | torch.where(abs_diff_lt_1, 0.5 * torch.pow(abs_diff,2), abs_diff - 0.5),1) / n_q) * weights 97 | return pixel_wise_smooth_l1norm 98 | 99 | 100 | def quad_norm(g_true): 101 | t_shape = g_true.permute(0,2,3,1) 102 | shape = t_shape.size() #n h w c 103 | delta_xy_matrix = torch.reshape(t_shape, [-1, 2, 2]) 104 | diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :] 105 | square = torch.pow(diff,2) 106 | distance = torch.sqrt(torch.sum(square, 2)) 107 | distance *= 4.0 108 | distance += cfg.epsilon 109 | return torch.reshape(distance, shape[:-1]) 110 | 111 | class Loss(nn.Module): 112 | def __init__(self): 113 | super(Loss, self).__init__() 114 | return 115 | def forward(self,y_true, y_pred): 116 | inside_score_loss , side_vertex_code_loss , side_vertex_coord_loss=quad_loss(y_true,y_pred) 117 | print('inside_score_loss is {:.8f}\n side_vertex_code_loss is {:.8f}\n side_vertex_coord_loss loss is {:.8f}\n'.format(inside_score_loss, side_vertex_code_loss, 118 | side_vertex_coord_loss)) 119 | return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss 120 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.nn.functional as F 5 | import math 6 | 7 | # from tensorboardX import SummaryWriter 8 | 9 | 10 | class hswish(nn.Module): 11 | def forward(self, x): 12 | out = x * F.relu6(x + 3, inplace=True) / 6 13 | return out 14 | 15 | 16 | class hsigmoid(nn.Module): 17 | def forward(self, x): 18 | out = F.relu6(x + 3, inplace=True) / 6 19 | return out 20 | 21 | 22 | class SeModule(nn.Module): 23 | def __init__(self, in_size, reduction=4): 24 | super(SeModule, self).__init__() 25 | self.se = nn.Sequential( 26 | nn.AdaptiveAvgPool2d(1), 27 | nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False), 28 | nn.BatchNorm2d(in_size // reduction), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False), 31 | nn.BatchNorm2d(in_size), 32 | hsigmoid() 33 | ) 34 | 35 | def forward(self, x): 36 | return x * self.se(x) 37 | 38 | 39 | class Block(nn.Module): 40 | '''expand + depthwise + pointwise''' 41 | def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride): 42 | super(Block, self).__init__() 43 | self.stride = stride 44 | self.se = semodule 45 | 46 | self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0) 47 | self.bn1 = nn.BatchNorm2d(expand_size) 48 | self.nolinear1 = nolinear 49 | self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False) 50 | self.bn2 = nn.BatchNorm2d(expand_size) 51 | self.nolinear2 = nolinear 52 | self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False) 53 | self.bn3 = nn.BatchNorm2d(out_size) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride == 1 and in_size != out_size: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, 
bias=False), 59 | nn.BatchNorm2d(out_size), 60 | ) 61 | 62 | def forward(self, x): 63 | out = self.nolinear1(self.bn1(self.conv1(x))) 64 | out = self.nolinear2(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | if self.se != None: 67 | out = self.se(out) 68 | out = out + self.shortcut(x) if self.stride==1 else out 69 | return out 70 | 71 | 72 | class MobileNetV3_Large(nn.Module): 73 | def __init__(self): 74 | super(MobileNetV3_Large, self).__init__() 75 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(16) 77 | self.hs1 = hswish() 78 | 79 | self.bneck1 = nn.Sequential( 80 | Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1), 81 | Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2), 82 | Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1), 83 | nn.Conv2d(24, 72, kernel_size=1, stride=1, padding=0, bias=False), 84 | nn.BatchNorm2d(72), 85 | nn.ReLU(inplace=True), 86 | ) # 1 72 128 128 87 | 88 | self.bneck2 = nn.Sequential( 89 | nn.Conv2d(72, 72, kernel_size=5, stride=2, padding=2, groups=72, bias=False), 90 | nn.BatchNorm2d(72), 91 | nn.ReLU(inplace=True), 92 | nn.Conv2d(72, 40, kernel_size=1, stride=1, padding=0, bias=False), 93 | nn.BatchNorm2d(40), 94 | SeModule(40), 95 | 96 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1), 97 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1), 98 | nn.Conv2d(40, 240, kernel_size=1, stride=1, padding=0, bias=False), 99 | nn.BatchNorm2d(240), 100 | hswish(), 101 | ) # 1 240 64 64 102 | 103 | self.bneck3 = nn.Sequential( 104 | nn.Conv2d(240, 240, kernel_size=3, stride=2, padding=1, groups=240, bias=False), 105 | nn.BatchNorm2d(240), 106 | hswish(), 107 | nn.Conv2d(240, 80, kernel_size=1, stride=1, padding=0, bias=False), 108 | nn.BatchNorm2d(80), 109 | Block(3, 80, 200, 80, hswish(), None, 1), 110 | Block(3, 80, 184, 80, hswish(), None, 1), 111 | Block(3, 80, 184, 80, hswish(), None, 1), 112 | Block(3, 80, 480, 112, hswish(), SeModule(112), 1), 113 | Block(3, 112, 672, 112, hswish(), SeModule(112), 1), 114 | Block(5, 112, 672, 160, hswish(), SeModule(160), 1), 115 | nn.Conv2d(160, 672, kernel_size=1, stride=1, padding=0, bias=False), 116 | nn.BatchNorm2d(672), 117 | hswish(), 118 | ) # 1 672 32 32 119 | 120 | self.bneck4 = nn.Sequential( 121 | nn.Conv2d(672, 672, kernel_size=5, stride=2, padding=2, groups=672, bias=False), 122 | nn.BatchNorm2d(672), 123 | hswish(), 124 | nn.Conv2d(672, 160, kernel_size=1, stride=1, padding=0, bias=False), 125 | nn.BatchNorm2d(160), 126 | SeModule(160), 127 | Block(5, 160, 960, 160, hswish(), SeModule(160), 1), 128 | ) # 1 160 16 16 129 | 130 | self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False) 131 | self.bn2 = nn.BatchNorm2d(960) 132 | self.hs2 = hswish() 133 | 134 | self.conv3 = nn.Conv2d(960, 640, kernel_size=1, stride=1, padding=0, bias=False) 135 | self.bn3 = nn.BatchNorm2d(640) 136 | self.linear = nn.ReLU(inplace=True) 137 | 138 | self.init_params() 139 | 140 | def init_params(self): 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 144 | if m.bias is not None: 145 | nn.init.constant_(m.bias, 0) 146 | elif isinstance(m, nn.BatchNorm2d): 147 | nn.init.constant_(m.weight, 1) 148 | nn.init.constant_(m.bias, 0) 149 | elif isinstance(m, nn.Linear): 150 | nn.init.normal_(m.weight, std=0.001) 151 | if m.bias is not None: 152 | nn.init.constant_(m.bias, 0) 153 | 154 | def forward(self, x): 155 | out = 
self.hs1(self.bn1(self.conv1(x))) 156 | # print(out.shape) torch.Size([2, 16, 256, 256]) 157 | out1 = self.bneck1(out) 158 | # print(out1.shape) torch.Size([2, 72, 128, 128]) up 159 | out2 = self.bneck2(out1) 160 | # print(out2.shape) torch.Size([2, 240, 64, 64]) up 161 | out3 = self.bneck3(out2) 162 | # print(out3.shape) torch.Size([2, 672, 32, 32]) up 163 | out = self.bneck4(out3) 164 | # print(out4.shape) torch.Size([2, 160, 16, 16]) 165 | out = self.hs2(self.bn2(self.conv2(out))) 166 | # print(out.shape) torch.Size([2, 960, 16, 16]) up 167 | out = self.linear(self.bn3(self.conv3(out))) 168 | # print(out.shape) torch.Size([2, 640, 16, 16]) 169 | return out ,out1,out2,out3, 170 | 171 | class merge(nn.Module): 172 | def __init__(self): 173 | super(merge, self).__init__() 174 | 175 | self.conv1 = nn.Conv2d(1312, 320, 1) 176 | self.bn1 = nn.BatchNorm2d(320) 177 | self.relu1 = nn.ReLU() 178 | self.conv2 = nn.Conv2d(320, 320, 3, padding=1) 179 | self.bn2 = nn.BatchNorm2d(320) 180 | self.relu2 = nn.ReLU() 181 | 182 | self.conv3 = nn.Conv2d(560, 160, 1) 183 | self.bn3 = nn.BatchNorm2d(160) 184 | self.relu3 = nn.ReLU() 185 | self.conv4 = nn.Conv2d(160, 160, 3, padding=1) 186 | self.bn4 = nn.BatchNorm2d(160) 187 | self.relu4 = nn.ReLU() 188 | 189 | self.conv5 = nn.Conv2d(232, 128, 1) 190 | self.bn5 = nn.BatchNorm2d(128) 191 | self.relu5 = nn.ReLU() 192 | self.conv6 = nn.Conv2d(128, 128, 3, padding=1) 193 | self.bn6 = nn.BatchNorm2d(128) 194 | self.relu6 = nn.ReLU() 195 | 196 | # self.conv7 = nn.Conv2d(32, 32, 3, padding=1) 197 | # self.bn7 = nn.BatchNorm2d(32) 198 | # self.relu7 = nn.ReLU() 199 | 200 | for m in self.modules(): 201 | if isinstance(m, nn.Conv2d): 202 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 203 | if m.bias is not None: 204 | nn.init.constant_(m.bias, 0) 205 | elif isinstance(m, nn.BatchNorm2d): 206 | nn.init.constant_(m.weight, 1) 207 | nn.init.constant_(m.bias, 0) 208 | 209 | def forward(self, x,x1,x2,x3): 210 | # print(x.shape) 1 640 16 16 211 | y = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 212 | # print(y.shape) 1 640 32 32 213 | y = torch.cat((y, x3), 1) # 1 1312 32 32 214 | y = self.relu1(self.bn1(self.conv1(y))) 215 | y = self.relu2(self.bn2(self.conv2(y))) # 1 320 32 32 216 | 217 | y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True) #1 320 64 64 218 | y = torch.cat((y, x2), 1) #1 560 64 64 219 | y = self.relu3(self.bn3(self.conv3(y))) 220 | y = self.relu4(self.bn4(self.conv4(y))) # 1 160 64 64 221 | 222 | y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True) #1 160 128 128 223 | y = torch.cat((y, x1), 1) # 1 232 128 128 224 | y = self.relu5(self.bn5(self.conv5(y))) 225 | y = self.relu6(self.bn6(self.conv6(y))) # 1 128 128 128 226 | 227 | # y = self.relu7(self.bn7(self.conv7(y))) 228 | return y 229 | 230 | class output(nn.Module): 231 | def __init__(self): 232 | super(output, self).__init__() 233 | self.conv1 = nn.Conv2d(128, 1, 1) 234 | 235 | self.conv2 = nn.Conv2d(128, 2, 1) 236 | 237 | self.conv3 = nn.Conv2d(128, 4, 1) 238 | 239 | 240 | for m in self.modules(): 241 | if isinstance(m, nn.Conv2d): 242 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 243 | if m.bias is not None: 244 | nn.init.constant_(m.bias, 0) 245 | 246 | def forward(self, x): 247 | inside_score = self.conv1(x) 248 | side_v_code = self.conv2(x) 249 | side_v_coord = self.conv3(x) 250 | east_detect = torch.cat((inside_score, side_v_code,side_v_coord), 1) 251 | return east_detect 252 | 
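# --- Annotation added for reference (not part of the original model.py) ---
# Expected tensor shapes for a 512x512 input (the size used in cfg.py and convert_to_onnx.py):
#   MobileNetV3_Large returns out (N, 640, 16, 16) plus skip features (N, 72, 128, 128), (N, 240, 64, 64), (N, 672, 32, 32);
#   merge applies three 2x bilinear upsamples with those skips and ends at (N, 128, 128, 128);
#   output concatenates inside_score (1) + side_vertex_code (2) + side_vertex_coord (4) = 7 channels,
#   so EAST(x) has shape (N, 7, 128, 128), one prediction per 4x4 pixel block (cfg.pixel_size == 4).
# Hypothetical sanity check:
#   m = EAST(); print(m(torch.randn(1, 3, 512, 512)).shape)  # torch.Size([1, 7, 128, 128])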
253 | 254 | class EAST(nn.Module): 255 | def __init__(self): 256 | super(EAST, self).__init__() 257 | self.extractor = MobileNetV3_Large() 258 | self.merge = merge() 259 | self.output = output() 260 | 261 | def forward(self, x): 262 | 263 | x,x1,x2,x3=self.extractor(x) 264 | 265 | return self.output(self.merge(x,x1,x2,x3)) 266 | 267 | 268 | if __name__ == '__main__': 269 | 270 | m = EAST() 271 | # x = torch.randn(1, 3, 512, 512) 272 | # with SummaryWriter(comment='mobilenetv3') as w: 273 | # w.add_graph(m, (x,)) 274 | # east_detect= m(x) 275 | # print(east_detect.shape) 276 | print(m) 277 | 278 | -------------------------------------------------------------------------------- /nms.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | 4 | import cfg 5 | 6 | 7 | def should_merge(region, i, j): 8 | neighbor = {(i, j - 1)} 9 | return not region.isdisjoint(neighbor) ## isdisjoint is True when the sets share no element, so this returns True when (i, j - 1) already belongs to the region 10 | 11 | 12 | def region_neighbor(region_set): 13 | region_pixels = np.array(list(region_set))### convert {(a, b)} to [[a b]] 14 | j_min = np.amin(region_pixels, axis=0)[1] - 1 #### min along axis 0; without an axis numpy would reduce over all elements 15 | j_max = np.amax(region_pixels, axis=0)[1] + 1 16 | i_m = np.amin(region_pixels, axis=0)[0] + 1 17 | region_pixels[:, 0] += 1 18 | neighbor = {(region_pixels[n, 0], region_pixels[n, 1]) for n in 19 | range(len(region_pixels))} 20 | neighbor.add((i_m, j_min)) 21 | neighbor.add((i_m, j_max)) 22 | return neighbor 23 | 24 | 25 | def region_group(region_list):### e.g. len(region_list) == 36 26 | S = [i for i in range(len(region_list))] 27 | D = [] 28 | while len(S) > 0: 29 | m = S.pop(0) 30 | if len(S) == 0: 31 | # S has only one element, put it to D 32 | D.append([m]) 33 | else: 34 | D.append(rec_region_merge(region_list, m, S)) 35 | return D 36 | 37 | 38 | def rec_region_merge(region_list, m, S): 39 | rows = [m] 40 | tmp = [] 41 | for n in S: 42 | if not region_neighbor(region_list[m]).isdisjoint(region_list[n]) or \ 43 | not region_neighbor(region_list[n]).isdisjoint(region_list[m]): 44 | # region m and region n are adjacent 45 | tmp.append(n)#### append adds the object to the end of the list 46 | for d in tmp: 47 | S.remove(d) ### remove the specified item from the list 48 | for e in tmp: 49 | rows.extend(rec_region_merge(region_list, e, S))#### extend appends all values from another sequence to the end of the list 50 | 51 | 52 | return rows 53 | 54 | 55 | def nms(predict, activation_pixels, threshold=cfg.side_vertex_pixel_threshold): 56 | region_list = [] 57 | for i, j in zip(activation_pixels[0], activation_pixels[1]): 58 | merge = False 59 | for k in range(len(region_list)): 60 | if should_merge(region_list[k], i, j): 61 | # print(region_list) 62 | region_list[k].add((i, j)) 63 | merge = True 64 | # FIXME: overlapping text regions; some pixels are adjacent to several regions, merge them all for now 65 | # break 66 | if not merge: 67 | region_list.append({(i, j)}) 68 | 69 | D = region_group(region_list) 70 | # print(D) 71 | quad_list = np.zeros((len(D), 4, 2)) 72 | score_list = np.zeros((len(D), 4)) 73 | for group, g_th in zip(D, range(len(D))): 74 | total_score = np.zeros((4, 2)) 75 | for row in group: 76 | for ij in region_list[row]: 77 | score = predict[1,ij[0], ij[1]] 78 | if score >= threshold: ##### threshold == 0.9 79 | ith_score = predict[2:3,ij[0], ij[1]] 80 | if not (cfg.trunc_threshold <= ith_score < 1 - 81 | cfg.trunc_threshold): 82 | ith = int(np.around(ith_score)) 83 | total_score[ith * 2:(ith + 1) * 2] += score 84 | px = (ij[1] + 0.5) * cfg.pixel_size 85 | py = (ij[0] + 0.5) * cfg.pixel_size 86 | p_v = [px, py] + np.reshape(predict[3:7,ij[0], ij[1]], 87 | (2, 2)) 88 | quad_list[g_th, ith * 2:(ith + 1) * 2] += 
score * p_v 89 | score_list[g_th] = total_score[:, 0] 90 | quad_list[g_th] /= (total_score + cfg.epsilon) 91 | # print("score_list: ", score_list) 92 | # print("quad_list: ",quad_list) 93 | return score_list, quad_list 94 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import transforms 5 | from model import EAST 6 | 7 | import numpy as np 8 | from PIL import Image, ImageDraw 9 | 10 | import cfg 11 | 12 | from preprocess import resize_image 13 | from nms import nms 14 | 15 | def sigmoid(x): 16 | """`y = 1 / (1 + exp(-x))`""" 17 | return 1 / (1 + np.exp(-x)) 18 | 19 | def load_pil(img): 20 | '''convert PIL Image to torch.Tensor 21 | ''' 22 | t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))]) 23 | return t(img).unsqueeze(0) 24 | 25 | def detect(img_path, model, device,pixel_threshold,quiet=True): 26 | img = Image.open(img_path) 27 | d_wight, d_height = resize_image(img, cfg.max_predict_img_size) 28 | img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB') 29 | with torch.no_grad(): 30 | east_detect=model(load_pil(img).to(device)) 31 | y = np.squeeze(east_detect.cpu().numpy(), axis=0) 32 | y[:3, :, :] = sigmoid(y[:3, :, :]) 33 | cond = np.greater_equal(y[0, :, :], pixel_threshold) 34 | activation_pixels = np.where(cond) 35 | quad_scores, quad_after_nms = nms(y, activation_pixels) 36 | with Image.open(img_path) as im: 37 | d_wight, d_height = resize_image(im, cfg.max_predict_img_size) 38 | scale_ratio_w = d_wight / im.width 39 | scale_ratio_h = d_height / im.height 40 | im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB') 41 | quad_im = im.copy() 42 | draw = ImageDraw.Draw(im) 43 | for i, j in zip(activation_pixels[0], activation_pixels[1]): 44 | px = (j + 0.5) * cfg.pixel_size 45 | py = (i + 0.5) * cfg.pixel_size 46 | line_width, line_color = 1, 'red' 47 | if y[1,i, j] >= cfg.side_vertex_pixel_threshold: 48 | if y[2,i, j] < cfg.trunc_threshold: 49 | line_width, line_color = 2, 'yellow' 50 | elif y[2,i, j] >= 1 - cfg.trunc_threshold: 51 | line_width, line_color = 2, 'green' 52 | draw.line([(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size), 53 | (px + 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size), 54 | (px + 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size), 55 | (px - 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size), 56 | (px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size)], 57 | width=line_width, fill=line_color) 58 | im.save(img_path + '_act.jpg') 59 | quad_draw = ImageDraw.Draw(quad_im) 60 | txt_items = [] 61 | for score, geo, s in zip(quad_scores, quad_after_nms, 62 | range(len(quad_scores))): 63 | 64 | if np.amin(score) > 0: 65 | quad_draw.line([tuple(geo[0]), 66 | tuple(geo[1]), 67 | tuple(geo[2]), 68 | tuple(geo[3]), 69 | tuple(geo[0])], width=2, fill='red') 70 | 71 | rescaled_geo = geo / [scale_ratio_w, scale_ratio_h] 72 | rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist() 73 | txt_item = ','.join(map(str, rescaled_geo_list)) 74 | txt_items.append(txt_item + '\n') 75 | elif not quiet: 76 | print('quad invalid with vertex num less then 4.') 77 | quad_im.save(img_path + '_predict.jpg') 78 | if cfg.predict_write2txt and len(txt_items) > 0: 79 | with open(img_path[:-4] + '.txt', 'w') as f_txt: 80 | f_txt.writelines(txt_items) 81 | 82 | 83 | def parse_args(): 84 | parser = argparse.ArgumentParser() 85 | 
parser.add_argument('--path', '-p', 86 | default='./00001.jpg', 87 | help='image path') 88 | parser.add_argument('--threshold', '-t', 89 | default=cfg.pixel_threshold, 90 | help='pixel activation threshold') 91 | return parser.parse_args() 92 | 93 | 94 | if __name__ == '__main__': 95 | args = parse_args() 96 | img_path = args.path 97 | threshold = float(args.threshold) 98 | print(img_path, threshold) 99 | model_path='./saved_model/mb3_512_model_epoch_535.pth' 100 | 101 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 102 | model = EAST().to(device) 103 | model.load_state_dict(torch.load(model_path)) 104 | model.eval() 105 | 106 | 107 | detect(img_path, model, device,threshold) 108 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageDraw 3 | import os 4 | import random 5 | from tqdm import tqdm 6 | 7 | import cfg 8 | from label import shrink 9 | 10 | 11 | def batch_reorder_vertexes(xy_list_array): 12 | reorder_xy_list_array = np.zeros_like(xy_list_array) 13 | for xy_list, i in zip(xy_list_array, range(len(xy_list_array))): 14 | reorder_xy_list_array[i] = reorder_vertexes(xy_list) 15 | return reorder_xy_list_array 16 | 17 | 18 | def reorder_vertexes(xy_list): 19 | reorder_xy_list = np.zeros_like(xy_list) 20 | # determine the first point with the smallest x, 21 | # if two has same x, choose that with smallest y, 22 | ordered = np.argsort(xy_list, axis=0) 23 | xmin1_index = ordered[0, 0] 24 | xmin2_index = ordered[1, 0] 25 | if xy_list[xmin1_index, 0] == xy_list[xmin2_index, 0]: 26 | if xy_list[xmin1_index, 1] <= xy_list[xmin2_index, 1]: 27 | reorder_xy_list[0] = xy_list[xmin1_index] 28 | first_v = xmin1_index 29 | else: 30 | reorder_xy_list[0] = xy_list[xmin2_index] 31 | first_v = xmin2_index 32 | else: 33 | reorder_xy_list[0] = xy_list[xmin1_index] 34 | first_v = xmin1_index 35 | # connect the first point to others, the third point on the other side of 36 | # the line with the middle slope 37 | others = list(range(4)) 38 | others.remove(first_v) 39 | k = np.zeros((len(others),)) 40 | for index, i in zip(others, range(len(others))): 41 | k[i] = (xy_list[index, 1] - xy_list[first_v, 1]) \ 42 | / (xy_list[index, 0] - xy_list[first_v, 0] + cfg.epsilon) 43 | k_mid = np.argsort(k)[1] 44 | third_v = others[k_mid] 45 | reorder_xy_list[2] = xy_list[third_v] 46 | # determine the second point which on the bigger side of the middle line 47 | others.remove(third_v) 48 | b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0] 49 | second_v, fourth_v = 0, 0 50 | for index, i in zip(others, range(len(others))): 51 | # delta = y - (k * x + b) 52 | delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid) 53 | if delta_y > 0: 54 | second_v = index 55 | else: 56 | fourth_v = index 57 | reorder_xy_list[1] = xy_list[second_v] 58 | reorder_xy_list[3] = xy_list[fourth_v] 59 | # compare slope of 13 and 24, determine the final order 60 | k13 = k[k_mid] 61 | k24 = (xy_list[second_v, 1] - xy_list[fourth_v, 1]) / ( 62 | xy_list[second_v, 0] - xy_list[fourth_v, 0] + cfg.epsilon) 63 | if k13 < k24: 64 | tmp_x, tmp_y = reorder_xy_list[3, 0], reorder_xy_list[3, 1] 65 | for i in range(2, -1, -1): 66 | reorder_xy_list[i + 1] = reorder_xy_list[i] 67 | reorder_xy_list[0, 0], reorder_xy_list[0, 1] = tmp_x, tmp_y 68 | return reorder_xy_list 69 | 70 | 71 | def resize_image(im, 
max_img_size=cfg.max_train_img_size): 72 | im_width = np.minimum(im.width, max_img_size)#####min(475,256)=256 73 | if im_width == max_img_size < im.width: 74 | im_height = int((im_width / im.width) * im.height)####256/475 * 300==161 75 | else: 76 | im_height = im.height#### 300 77 | o_height = np.minimum(im_height, max_img_size) ####min(161,256)===161 78 | if o_height == max_img_size < im_height: 79 | o_width = int((o_height / im_height) * im_width) 80 | else: 81 | o_width = im_width###256 82 | d_wight = o_width - (o_width % 32)####256-0=250 83 | d_height = o_height - (o_height % 32)#####161-1=160 84 | return d_wight, d_height 85 | 86 | 87 | def preprocess(): 88 | data_dir = cfg.data_dir 89 | origin_image_dir = os.path.join(data_dir, cfg.origin_image_dir_name) 90 | origin_txt_dir = os.path.join(data_dir, cfg.origin_txt_dir_name) 91 | train_image_dir = os.path.join(data_dir, cfg.train_image_dir_name) 92 | train_label_dir = os.path.join(data_dir, cfg.train_label_dir_name) 93 | if not os.path.exists(train_image_dir): 94 | os.mkdir(train_image_dir) 95 | if not os.path.exists(train_label_dir): 96 | os.mkdir(train_label_dir) 97 | draw_gt_quad = cfg.draw_gt_quad 98 | show_gt_image_dir = os.path.join(data_dir, cfg.show_gt_image_dir_name) 99 | if not os.path.exists(show_gt_image_dir): 100 | os.mkdir(show_gt_image_dir) 101 | show_act_image_dir = os.path.join(cfg.data_dir, cfg.show_act_image_dir_name) 102 | if not os.path.exists(show_act_image_dir): 103 | os.mkdir(show_act_image_dir) 104 | 105 | o_img_list = os.listdir(origin_image_dir) 106 | print('found %d origin images.' % len(o_img_list)) 107 | train_val_set = [] 108 | for o_img_fname, _ in zip(o_img_list, tqdm(range(len(o_img_list)))): 109 | with Image.open(os.path.join(origin_image_dir, o_img_fname)) as im: 110 | # d_wight, d_height = resize_image(im) 111 | d_wight, d_height = cfg.max_train_img_size, cfg.max_train_img_size 112 | scale_ratio_w = d_wight / im.width 113 | scale_ratio_h = d_height / im.height 114 | im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB') 115 | show_gt_im = im.copy() 116 | # draw on the img 117 | draw = ImageDraw.Draw(show_gt_im) 118 | with open(os.path.join(origin_txt_dir, 119 | o_img_fname[:-4] + '.txt'), 'r') as f: 120 | anno_list = f.readlines() 121 | xy_list_array = np.zeros((len(anno_list), 4, 2)) 122 | for anno, i in zip(anno_list, range(len(anno_list))): 123 | anno_colums = anno.strip().split(',') 124 | anno_array = np.array(anno_colums) 125 | xy_list = np.reshape(anno_array[:8].astype(float), (4, 2)) 126 | xy_list[:, 0] = xy_list[:, 0] * scale_ratio_w 127 | xy_list[:, 1] = xy_list[:, 1] * scale_ratio_h 128 | xy_list = reorder_vertexes(xy_list) 129 | xy_list_array[i] = xy_list 130 | _, shrink_xy_list, _ = shrink(xy_list, cfg.shrink_ratio) 131 | shrink_1, _, long_edge = shrink(xy_list, cfg.shrink_side_ratio) 132 | if draw_gt_quad: 133 | draw.line([tuple(xy_list[0]), tuple(xy_list[1]), 134 | tuple(xy_list[2]), tuple(xy_list[3]), 135 | tuple(xy_list[0]) 136 | ], 137 | width=2, fill='green') 138 | draw.line([tuple(shrink_xy_list[0]), 139 | tuple(shrink_xy_list[1]), 140 | tuple(shrink_xy_list[2]), 141 | tuple(shrink_xy_list[3]), 142 | tuple(shrink_xy_list[0]) 143 | ], 144 | width=2, fill='blue') 145 | vs = [[[0, 0, 3, 3, 0], [1, 1, 2, 2, 1]], 146 | [[0, 0, 1, 1, 0], [2, 2, 3, 3, 2]]] 147 | for q_th in range(2): 148 | draw.line([tuple(xy_list[vs[long_edge][q_th][0]]), 149 | tuple(shrink_1[vs[long_edge][q_th][1]]), 150 | tuple(shrink_1[vs[long_edge][q_th][2]]), 151 | tuple(xy_list[vs[long_edge][q_th][3]]), 
152 | tuple(xy_list[vs[long_edge][q_th][4]])], 153 | width=3, fill='yellow') 154 | if cfg.gen_origin_img: 155 | im.save(os.path.join(train_image_dir, o_img_fname)) 156 | np.save(os.path.join( 157 | train_label_dir, 158 | o_img_fname[:-4] + '.npy'), 159 | xy_list_array) 160 | if draw_gt_quad: 161 | show_gt_im.save(os.path.join(show_gt_image_dir, o_img_fname)) 162 | train_val_set.append('{},{},{}\n'.format(o_img_fname, 163 | d_wight, 164 | d_height)) 165 | 166 | train_img_list = os.listdir(train_image_dir) 167 | print('found %d train images.' % len(train_img_list)) 168 | train_label_list = os.listdir(train_label_dir) 169 | print('found %d train labels.' % len(train_label_list)) 170 | 171 | random.shuffle(train_val_set) 172 | val_count = int(cfg.validation_split_ratio * len(train_val_set)) 173 | with open(os.path.join(data_dir, cfg.val_fname), 'w') as f_val: 174 | f_val.writelines(train_val_set[:val_count]) 175 | with open(os.path.join(data_dir, cfg.train_fname), 'w') as f_train: 176 | f_train.writelines(train_val_set[val_count:]) 177 | 178 | 179 | if __name__ == '__main__': 180 | preprocess() 181 | -------------------------------------------------------------------------------- /saved_model/mb3_512_model_epoch_535.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corleonechensiyu/pytorch_AdvancedEast/5cbcfbc25c9c897526bd637ddda0182cad12ff76/saved_model/mb3_512_model_epoch_535.pth -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torch import nn 4 | from torch.optim import lr_scheduler 5 | from generator import custom_dataset 6 | from model import EAST 7 | from loss import Loss 8 | import os 9 | import time 10 | import numpy as np 11 | import cfg 12 | 13 | def train(train_img_path, pths_path, batch_size, lr,decay, num_workers, epoch_iter, interval,pretained): 14 | file_num = len(os.listdir(train_img_path)) 15 | trainset = custom_dataset(train_img_path) 16 | train_loader = data.DataLoader(trainset, batch_size=batch_size, \ 17 | shuffle=True, num_workers=num_workers, drop_last=True) 18 | 19 | criterion = Loss() 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | model = EAST() 22 | # TODO 可能是bug 23 | if os.path.exists(pretained): 24 | model.load_state_dict(torch.load(pretained)) 25 | 26 | data_parallel = False 27 | if torch.cuda.device_count() > 1: 28 | model = nn.DataParallel(model) 29 | data_parallel = True 30 | model.to(device) 31 | optimizer = torch.optim.Adam(model.parameters(), lr=lr,weight_decay=decay) 32 | # scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94) 33 | 34 | for epoch in range(epoch_iter): 35 | model.train() 36 | optimizer.step() 37 | epoch_loss = 0 38 | epoch_time = time.time() 39 | for i, (img, gt_map) in enumerate(train_loader): 40 | start_time = time.time() 41 | img, gt_map = img.to(device),gt_map.to(device) 42 | east_detect = model(img) 43 | loss = criterion(gt_map, east_detect) 44 | 45 | epoch_loss += loss.item() 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(\ 51 | epoch+1, epoch_iter, i+1, int(file_num/batch_size), time.time()-start_time, loss.item())) 52 | 53 | print('epoch_loss is {:.8f}, epoch_time is 
{:.8f}'.format(epoch_loss/int(file_num/batch_size), time.time()-epoch_time)) 54 | print(time.asctime(time.localtime(time.time()))) 55 | print('='*50) 56 | if (epoch + 1) % interval == 0: 57 | state_dict = model.module.state_dict() if data_parallel else model.state_dict() 58 | torch.save(state_dict, os.path.join(pths_path, cfg.train_task_id+'_model_epoch_{}.pth'.format(epoch+1))) 59 | 60 | 61 | # def test(): 62 | 63 | 64 | if __name__ == '__main__': 65 | train_img_path = os.path.join(cfg.data_dir,cfg.train_image_dir_name) 66 | pths_path = './saved_model' 67 | batch_size = 10 68 | lr = 1e-3 69 | decay =5e-4 70 | num_workers = 4 71 | epoch_iter = 600 72 | save_interval = 5 73 | pretained = './saved_model/mb3_512_model_epoch_535.pth' 74 | train(train_img_path, pths_path, batch_size, lr, decay,num_workers, epoch_iter, save_interval,pretained) 75 | 76 | --------------------------------------------------------------------------------
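For reference, the ONNX graph exported by convert_to_onnx.py can be run without PyTorch. The sketch below is not part of the repository; it assumes onnxruntime is installed and that model/mbv3_512_east.onnx was produced by convert_to_onnx.py. It mirrors the pre- and post-processing in predict.py: normalize the image to [-1, 1], run the graph, apply a sigmoid to the first three channels, then reuse nms.py.

```python
# Minimal ONNX inference sketch (illustrative only; assumes `pip install onnxruntime`).
import numpy as np
import onnxruntime as ort
from PIL import Image

import cfg
from nms import nms


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Same preprocessing as predict.load_pil: resize, then normalize with mean = std = 0.5.
img = Image.open('a.png').resize((512, 512), Image.NEAREST).convert('RGB')
x = (np.asarray(img, dtype=np.float32) / 255.0 - 0.5) / 0.5
x = np.transpose(x, (2, 0, 1))[np.newaxis, ...]    # NCHW, shape (1, 3, 512, 512)

sess = ort.InferenceSession('model/mbv3_512_east.onnx')
y = sess.run(['east_detect'], {'input': x})[0][0]  # (7, 128, 128)
y[:3] = sigmoid(y[:3])                             # scores are exported as logits

# Threshold the inside-score map and group activated pixels into quads, as in predict.detect.
activation_pixels = np.where(y[0] >= cfg.pixel_threshold)
quad_scores, quad_after_nms = nms(y, activation_pixels)
print('%d candidate quads' % len(quad_after_nms))
```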