├── EAST_box_supervision ├── dataset │ ├── ICDAR15.py │ ├── MSRA_TD500.py │ ├── SynthText.py │ └── dataset.py ├── eval.py ├── evaluate │ ├── msra │ │ ├── eval.py │ │ └── file_util.py │ ├── rrc_evaluation_funcs.py │ └── script.py ├── lib │ ├── detect.py │ └── utils.py ├── network │ ├── loss.py │ └── model.py ├── test_msra.py ├── trainSyndata.py ├── train_ICDAR15.py └── train_msra.py ├── PSENet_box_supervision ├── config │ ├── icdar15 │ │ ├── icdar15_baseline.py │ │ └── icdar15_pseudo.py │ ├── msra │ │ ├── msra_baseline.py │ │ └── msra_pseudo.py │ └── totaltext │ │ ├── psenet_baseline.py │ │ └── psenet_pseudo.py ├── dataset │ ├── __init__.py │ ├── augment.py │ ├── augment_img.py │ ├── icdar15_load.py │ ├── msra_td500.py │ ├── synthtext_load.py │ ├── total_aug.py │ └── total_text_load.py ├── evaluation │ ├── __init__.py │ ├── eval_tt.sh │ ├── msra │ │ ├── eval.py │ │ └── file_util.py │ ├── rrc_evaluation_funcs.py │ └── script.py ├── models │ ├── ShuffleNetV2.py │ ├── __init__.py │ ├── loss.py │ ├── mobilenetv3.py │ ├── model.py │ └── resnet.py ├── predict.py ├── pse │ ├── Makefile │ ├── __init__.py │ ├── include │ │ └── pybind11 │ │ │ ├── attr.h │ │ │ ├── buffer_info.h │ │ │ ├── cast.h │ │ │ ├── chrono.h │ │ │ ├── class_support.h │ │ │ ├── common.h │ │ │ ├── complex.h │ │ │ ├── descr.h │ │ │ ├── detail │ │ │ ├── class.h │ │ │ ├── common.h │ │ │ ├── descr.h │ │ │ ├── init.h │ │ │ ├── internals.h │ │ │ └── typeid.h │ │ │ ├── eigen.h │ │ │ ├── embed.h │ │ │ ├── eval.h │ │ │ ├── functional.h │ │ │ ├── iostream.h │ │ │ ├── numpy.h │ │ │ ├── operators.h │ │ │ ├── options.h │ │ │ ├── pybind11.h │ │ │ ├── pytypes.h │ │ │ ├── stl.h │ │ │ ├── stl_bind.h │ │ │ └── typeid.h │ ├── ncnn │ │ └── examples │ │ │ ├── CMakeLists.txt │ │ │ ├── psenet.cpp │ │ │ └── run.sh │ ├── pse.cpp │ └── pse.so ├── pse_pyx │ ├── __init__.py │ ├── pse.c │ ├── pse.cpp │ ├── pse.pyx │ ├── pse.pyx.bak1 │ ├── pse.pyx.bak2 │ ├── pse.pyx.bak2_select │ ├── pse.pyx.bak3 │ └── setup.py ├── test_icdar15.py ├── test_msra.py ├── train_icdar15.py ├── train_msra.py ├── train_synthtext.py ├── train_totaltext.py └── utils │ ├── .ipynb_checkpoints │ ├── config_icdar15-checkpoint.py │ └── config_icdar17-checkpoint.py │ ├── __init__.py │ ├── config_icdar15.py │ ├── config_synthtext.py │ ├── config_totaltext.py │ ├── lr_scheduler.py │ └── utils.py ├── README.md ├── TextBoxSeg ├── .gitignore ├── LICENSE ├── configs │ ├── textseg.yaml │ ├── textseg2.yaml │ └── textseg_total.yaml ├── data │ ├── st800k_crop │ └── st800k_crop2 ├── demo │ ├── TextBoxSegTool.ipynb │ ├── curved_st800k_crop.py │ ├── imgs │ │ ├── text1.png │ │ ├── text10.png │ │ ├── text11.png │ │ ├── text12.png │ │ ├── text13.png │ │ ├── text14.png │ │ ├── text15.png │ │ ├── text16.png │ │ ├── text2.png │ │ ├── text3.png │ │ ├── text4.png │ │ ├── text5.png │ │ ├── text6.png │ │ ├── text7.png │ │ ├── text8.png │ │ └── text9.png │ ├── result │ │ ├── text1.png │ │ ├── text10.png │ │ ├── text11.png │ │ ├── text12.png │ │ ├── text13.png │ │ ├── text14.png │ │ ├── text15.png │ │ ├── text16.png │ │ ├── text2.png │ │ ├── text3.png │ │ ├── text4.png │ │ ├── text5.png │ │ ├── text6.png │ │ ├── text7.png │ │ ├── text8.png │ │ └── text9.png │ ├── st800k_crop.py │ ├── st800k_crop2.py │ └── test.py ├── demo_paper.py ├── readme.md ├── segmentron │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── config.py │ │ └── settings.py │ ├── data │ │ ├── __init__.py │ │ └── dataloader │ │ │ ├── Curved_Synthtext.py │ │ │ ├── Curved_Synthtext_attention.py │ │ │ ├── TextSegmentation_total.py │ │ │ ├── 
__init__.py │ │ │ ├── seg_data_base.py │ │ │ ├── st800k.py │ │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── eespnet.py │ │ │ ├── hrnet.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ └── xception.py │ │ ├── model_zoo.py │ │ ├── segbase.py │ │ ├── textseg.py │ │ └── textseg_attention.py │ ├── modules │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── batch_norm.py │ │ ├── cc_attention.py │ │ ├── csrc │ │ │ ├── criss_cross_attention │ │ │ │ ├── ca.h │ │ │ │ └── ca_cuda.cu │ │ │ └── vision.cpp │ │ ├── module.py │ │ └── sync_bn │ │ │ └── syncbn.py │ ├── solver │ │ ├── __init__.py │ │ ├── loss.py │ │ ├── lovasz_losses.py │ │ ├── lr_scheduler.py │ │ └── optimizer.py │ └── utils │ │ ├── __init__.py │ │ ├── default_setup.py │ │ ├── distributed.py │ │ ├── download.py │ │ ├── env.py │ │ ├── filesystem.py │ │ ├── filter_negative.py │ │ ├── line_chart.py │ │ ├── logger.py │ │ ├── options.py │ │ ├── parallel.py │ │ ├── registry.py │ │ ├── score.py │ │ ├── show.py │ │ └── visualize.py ├── setup.py └── tools │ ├── demo_ctw1500.py │ ├── demo_ic15.py │ ├── demo_ic15_pslabel.py │ ├── demo_icdar19ArT.py │ ├── demo_icdar19LSVT.py │ ├── demo_msra.py │ ├── demo_paper.py │ ├── demo_paper_icdar19ArT.py │ ├── demo_paper_msra.py │ ├── demo_paper_sw.py │ ├── demo_paper_tt.py │ ├── demo_tt.py │ ├── demo_tt_DB.py │ ├── demo_tt_sw.py │ ├── gen_ctw1500.py │ ├── gen_ic15_pslabel.py │ ├── gen_icd17_pslabel.py │ ├── gen_icdar19ArT.py │ ├── gen_icdar19LSVT.py │ ├── gen_msra.py │ ├── gen_tt_DB.py │ ├── gen_tt_pslabel.py │ ├── show │ ├── 108_image.jpg │ ├── 108_mask.jpg │ ├── 108_mask_1.jpg │ ├── 108_mask_2.jpg │ ├── 108_mask_3.jpg │ ├── 108_mask_4.jpg │ ├── 108_mask_5.jpg │ ├── 108_mask_6.jpg │ └── 108_origin_image.jpg │ ├── test_demo.py │ └── train.py └── image └── 1606325537.png /EAST_box_supervision/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from network.model import EAST 3 | from network.loss import Loss 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from lib.detect import detect 8 | from evaluate.script import getresult 9 | import argparse 10 | import os 11 | import cv2 12 | from torchvision import transforms 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 14 | 15 | parser = argparse.ArgumentParser(description='EAST reimplementation') 16 | 17 | # Model path 18 | parser.add_argument('--resume', default="/home/wwj/workspace/Sence_Text_detection/AAAI_EAST/Baseline/EAST_v1/worksapce/ICDAR17/best_model_640aug.pth", type=str, 19 | help='Checkpoint state_dict file to resume training from') 20 | parser.add_argument('--eval_path', default="/data/data_weijiawu/Sence_Text_detection/Paper-ACCV/DomainAdaptive/ICDAR2015/EAST_v2/ICDAR15/Test/image/", type=str, 21 | help='the test image of target domain ') 22 | 23 | parser.add_argument('--output_path', default="/home/wwj/workspace/Sence_Text_detection/AAAI_EAST/Baseline/EAST_v1/evaluate/15_submit/", type=str, 24 | help='the predicted output of target domain') 25 | 26 | parser.add_argument('--vis_path', default="/home/wwj/workspace/Sence_Text_detection/AAAI_EAST/Baseline/EAST_v1/worksapce/ICDAR17/show/", type=str, 27 | help='the predicted output of target domain') 28 | args = parser.parse_args() 29 | 30 | def cvt2HeatmapImg(img): 31 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 32 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 33 | return img 34 | 35 | 36 | def test(model, input_path,output_path): 37 | 
model.eval() 38 | # model_box.eval() 39 | image_list = os.listdir(input_path) 40 | print(" ----------------------------------------------------------------") 41 | print(" Starting Eval...") 42 | print(" ----------------------------------------------------------------") 43 | 44 | 45 | # getresult(output_path) 46 | # for size in range(640,2048,32): 47 | # print("short_line:",size) 48 | for one_image in tqdm(image_list): 49 | image_path = os.path.join(input_path, one_image) 50 | img = Image.open(image_path).convert('RGB') 51 | orign_img = cv2.imread(image_path) 52 | filename, file_ext = os.path.splitext(os.path.basename(one_image)) 53 | filename = filename.split("ts_")[-1] 54 | res_file = output_path + "res_" + filename + '.txt' 55 | vis_file = args.vis_path + filename + '.jpg' 56 | print(res_file) 57 | boxes = detect(img, model, device) 58 | 59 | 60 | with open(res_file, 'w') as f: 61 | if boxes is None: 62 | continue 63 | for i, box in enumerate(boxes): 64 | poly = np.array(box).astype(np.int32) 65 | points = np.reshape(poly, -1) 66 | # print(points[8]) 67 | strResult = ','.join( 68 | [str(points[0]), str(points[1]), str(points[2]), str(points[3]), str(points[4]), str(points[5]), 69 | str(points[6]), str(points[7])]) + '\r\n' 70 | # strResult = ','.join( 71 | # [str(points[1]), str(points[0]), str(points[3]), str(points[2]), str(points[5]), str(points[4]), 72 | # str(points[7]), str(points[6]), str("1.0")]) + '\r\n' 73 | f.write(strResult) 74 | 75 | for bbox in boxes: 76 | # bbox = bbox / scale.repeat(int(len(bbox) / 2)) 77 | bbox = np.array(bbox, np.int32) 78 | cv2.drawContours(orign_img, [bbox[:8].reshape(int(bbox.shape[0] / 2), 2)], -1, (0, 0, 255), 2) 79 | # cv2.imwrite(vis_file, orign_img) 80 | # f_score_new = getresult(output_path) 81 | 82 | 83 | if __name__ == '__main__': 84 | 85 | device = torch.device("cuda") 86 | model = EAST() 87 | data_parallel = False 88 | # if torch.cuda.device_count() > 1: 89 | # model = nn.DataParallel(model) 90 | # data_parallel = True 91 | model.to(device) 92 | 93 | print("loading pretrained model from ",args.resume) 94 | model.load_state_dict(torch.load(args.resume)) 95 | # print(getresult(args.output_path)) 96 | test(model,args.eval_path,args.output_path) 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /EAST_box_supervision/evaluate/msra/eval.py: -------------------------------------------------------------------------------- 1 | import evaluate.msra.file_util as file_util 2 | import Polygon as plg 3 | import numpy as np 4 | import math 5 | import cv2 6 | 7 | 8 | pred_root = '/home/wwj/workspace/Sence_Text_detection/AAAI_EAST/Baseline/EAST_v1/worksapce/MSRA/submit' 9 | gt_root = '/data/data_weijiawu/TD500/Test_gt/' 10 | 11 | 12 | def get_pred(path): 13 | lines = file_util.read_file(path).split('\n') 14 | bboxes = [] 15 | for line in lines: 16 | if line == '': 17 | continue 18 | bbox = line.split(',') 19 | if len(bbox) % 2 == 1: 20 | print(path) 21 | bbox = [int(x) for x in bbox] 22 | bboxes.append(bbox) 23 | return bboxes 24 | 25 | 26 | def get_gt(path): 27 | lines = file_util.read_file(path).split('\n') 28 | bboxes = [] 29 | tags = [] 30 | for line in lines: 31 | if line == '': 32 | continue 33 | # line = util.str.remove_all(line, '\xef\xbb\xbf') 34 | # gt = util.str.split(line, ' ') 35 | gt = line.split(' ') 36 | 37 | w_ = float(gt[4]) 38 | h_ = float(gt[5]) 39 | x1 = float(gt[2]) + w_ / 2.0 40 | y1 = float(gt[3]) + h_ / 2.0 41 | theta = float(gt[6]) / math.pi * 180 # radians to degrees for cv2.boxPoints 42 | 43 | bbox = cv2.boxPoints(((x1, y1), (w_, h_), theta)) 44 | bbox = bbox.reshape(-1) 45 | 46 | bboxes.append(bbox) 47 | tags.append(int(gt[1])) 48 | return np.array(bboxes), tags 49 | 50 | 51 | def get_union(pD, pG): 52 | areaA = pD.area() 53 | areaB = pG.area() 54 | return areaA + areaB - get_intersection(pD, pG) 55 | 56 | 57 | def get_intersection(pD, pG): 58 | pInt = pD & pG 59 | if len(pInt) == 0: 60 | return 0 61 | return pInt.area() 62 | 63 | def get_msra_result(pred_path_root,gt_path_root): 64 | th = 0.5 65 | pred_list = file_util.read_dir(pred_path_root) 66 | 67 | count, tp, fp, tn, ta = 0, 0, 0, 0, 0 68 | for pred_path in pred_list: 69 | count = count + 1 70 | preds = get_pred(pred_path) 71 | gt_path = gt_path_root + pred_path.split('/')[-1].split('.')[0].split('res_')[-1] + '.gt' 72 | gts, tags = get_gt(gt_path) 73 | 74 | ta = ta + len(preds) 75 | for gt, tag in zip(gts, tags): 76 | gt = np.array(gt) 77 | gt = gt.reshape(int(gt.shape[0] / 2), 2) 78 | gt_p = plg.Polygon(gt) 79 | difficult = tag 80 | flag = 0 81 | for pred in preds: 82 | pred = np.array(pred) 83 | pred = pred.reshape(int(pred.shape[0] / 2), 2) 84 | pred_p = plg.Polygon(pred) 85 | 86 | union = get_union(pred_p, gt_p) 87 | inter = get_intersection(pred_p, gt_p) 88 | iou = float(inter) / union 89 | if iou >= th: 90 | flag = 1 91 | tp = tp + 1 92 | break 93 | 94 | if flag == 0 and difficult == 0: 95 | fp = fp + 1 96 | 97 | recall = float(tp) / (tp + fp) # fp counts unmatched non-difficult ground truths, i.e. false negatives 98 | precision = float(tp) / ta 99 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 100 | 101 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 102 | return hmean 103 | 104 | if __name__ == '__main__': 105 | th = 0.5 106 | pred_list = file_util.read_dir(pred_root) 107 | 108 | count, tp, fp, tn, ta = 0, 0, 0, 0, 0 109 | for pred_path in pred_list: 110 | count = count + 1 111 | preds = get_pred(pred_path) 112 | gt_path = gt_root + pred_path.split('/')[-1].split('.')[0].split('res_')[-1] + '.gt' 113 | gts, tags = get_gt(gt_path) 114 | 115 | ta = ta + len(preds) 116 | for gt, tag in zip(gts, tags): 117 | gt = np.array(gt) 118 | gt = gt.reshape(int(gt.shape[0] / 2), 2) 119 | gt_p = plg.Polygon(gt) 120 | difficult = tag 121 | flag = 0 122 | for pred in preds: 123 | pred = np.array(pred) 124 | pred = pred.reshape(int(pred.shape[0] / 2), 2) 125 | pred_p = plg.Polygon(pred) 126 | 127 | union = get_union(pred_p, gt_p) 128 | inter = get_intersection(pred_p, gt_p) 129 | iou = float(inter) / union 130 | if iou >= th: 131 | flag = 1 132 | tp = tp + 1 133 | break 134 | 135 | if flag == 0 and difficult == 0: 136 | fp = fp + 1 137 | 138 | recall = float(tp) / (tp + fp) 139 | precision = float(tp) / ta 140 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 141 | 142 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 143 | -------------------------------------------------------------------------------- /EAST_box_supervision/evaluate/msra/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def read_dir(root): 4 | file_path_list = [] 5 | for file_path, dirs, files in os.walk(root): 6 | for file in files: 7 | file_path_list.append(os.path.join(file_path, file).replace('\\', '/')) 8 | file_path_list.sort() 9 | return file_path_list 10 | 11 | def read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 
| def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /EAST_box_supervision/lib/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def adjust_box_sort(box): 3 | start = -1 4 | _box = list(np.array(box).reshape(-1,2)) 5 | min_x = min(box[0::2]) 6 | min_y = min(box[1::2]) 7 | _box.sort(key=lambda x:(x[0]-min_x)**2+(x[1]-min_y)**2) 8 | start_point = list(_box[0]) 9 | for i in range(0,8,2): 10 | x,y = box[i],box[i+1] 11 | if [x,y] == start_point: 12 | start = i//2 13 | break 14 | 15 | new_box = [] 16 | new_box.extend(box[start*2:]) 17 | new_box.extend(box[:start*2]) 18 | return new_box 19 | 20 | 21 | def setup_logger(log_file_path: str = None): 22 | """Return a logger with a default ColoredFormatter.""" 23 | import logging 24 | from colorlog import ColoredFormatter 25 | import os 26 | if log_file_path and os.path.exists(log_file_path): 27 | os.remove(log_file_path) 28 | logging.basicConfig(filename=log_file_path, format='%(asctime)s %(levelname)-8s %(filename)s: %(message)s', 29 | # format of the log records 30 | datefmt='%Y-%m-%d %H:%M:%S', ) 31 | formatter = ColoredFormatter("%(asctime)s %(log_color)s%(levelname)-8s %(reset)s %(filename)s: %(message)s", 32 | datefmt='%Y-%m-%d %H:%M:%S', 33 | reset=True, 34 | log_colors={ 35 | 'DEBUG': 'blue', 36 | 'INFO': 'green', 37 | 'WARNING': 'yellow', 38 | 'ERROR': 'red', 39 | 'CRITICAL': 'red', 40 | }) 41 | 42 | logger = logging.getLogger('project') 43 | handler = logging.StreamHandler() 44 | handler.setFormatter(formatter) 45 | logger.addHandler(handler) 46 | logger.setLevel(logging.DEBUG) 47 | logger.info('logger init finished') 48 | return logger -------------------------------------------------------------------------------- /EAST_box_supervision/network/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_dice_loss(gt_score, pred_score): 6 | inter = torch.sum(gt_score * pred_score) 7 | union = torch.sum(gt_score) + torch.sum(pred_score) + 1e-5 8 | return 1. 
- (2 * inter / union) 9 | 10 | 11 | def get_geo_loss(gt_geo, pred_geo): 12 | d1_gt, d2_gt, d3_gt, d4_gt, angle_gt = torch.split(gt_geo, 1, 1) 13 | d1_pred, d2_pred, d3_pred, d4_pred, angle_pred = torch.split(pred_geo, 1, 1) 14 | area_gt = (d1_gt + d2_gt) * (d3_gt + d4_gt) 15 | area_pred = (d1_pred + d2_pred) * (d3_pred + d4_pred) 16 | w_union = torch.min(d3_gt, d3_pred) + torch.min(d4_gt, d4_pred) 17 | h_union = torch.min(d1_gt, d1_pred) + torch.min(d2_gt, d2_pred) 18 | area_intersect = w_union * h_union 19 | area_union = area_gt + area_pred - area_intersect 20 | iou_loss_map = -torch.log((area_intersect + 1.0)/(area_union + 1.0)) 21 | angle_loss_map = 1 - torch.cos(angle_pred - angle_gt) 22 | return iou_loss_map, angle_loss_map 23 | 24 | 25 | class Loss(nn.Module): 26 | def __init__(self, weight_angle=10): 27 | super(Loss, self).__init__() 28 | self.weight_angle = weight_angle 29 | 30 | def forward(self, gt_score, pred_score, gt_geo, pred_geo, valid_map): 31 | if torch.sum(gt_score) < 1: 32 | return torch.sum(pred_score + pred_geo) * 0 33 | 34 | # gt_score = gt_score*valid_map 35 | classify_loss = get_dice_loss(gt_score, pred_score*valid_map) 36 | iou_loss_map, angle_loss_map = get_geo_loss(gt_geo, pred_geo) 37 | 38 | angle_loss = torch.sum(angle_loss_map*gt_score) / torch.sum(gt_score) 39 | iou_loss = torch.sum(iou_loss_map*gt_score) / torch.sum(gt_score) 40 | geo_loss = self.weight_angle * angle_loss + iou_loss 41 | # print('classify loss is {:.8f}, angle loss is {:.8f}, iou loss is {:.8f}'.format(classify_loss, angle_loss, iou_loss)) 42 | return geo_loss + classify_loss 43 | -------------------------------------------------------------------------------- /EAST_box_supervision/train_msra.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torch import nn 4 | from torch.optim import lr_scheduler 5 | from dataset.MSRA_TD500 import MSRA_TD500 6 | from network.model import EAST 7 | from network.loss import Loss 8 | import os 9 | import time 10 | import numpy as np 11 | from PIL import Image, ImageDraw 12 | from tqdm import tqdm 13 | from lib.detect import detect_msra 14 | from evaluate.msra.eval import get_msra_result 15 | import argparse 16 | 17 | from lib.utils import setup_logger 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 20 | 21 | parser = argparse.ArgumentParser(description='EAST reimplementation') 22 | 23 | # Model path 24 | parser.add_argument('--exp_name',default= "MSRA", help='Where to store logs and models') 25 | parser.add_argument('--resume', default="/home/wwj/workspace/Sence_Text_detection/AAAI_EAST/Baseline/EAST_v1/worksapce/SynthText/synthtext_1_model.pth", type=str, 26 | help='Checkpoint state_dict file to resume training from') 27 | parser.add_argument('--msra_data', default='/home/xjc/Dataset/MSRA-TD500/', type=str, 28 | help='the path of the MSRA-TD500 training data') 29 | parser.add_argument('--hust_path', default="/home/xjc/Dataset/HUST-TR400/", type=str, 30 | help='the path of the HUST-TR400 training data') 31 | parser.add_argument('--workspace', default="/home/xjc/Desktop/CVPR_SemiText/SemiText/EAST_box_supervision/worksapce/", type=str, 32 | help='where to save models and logs') 33 | parser.add_argument('--gt_name', default="msra_gt.zip", type=str, help='gt name') 34 | parser.add_argument('--is_box_pseudo', default=True, type=bool, help='whether to train with box-supervised pseudo labels') 35 | 36 | # Training strategy 37 | parser.add_argument('--epoch_iter', default=400, type = int, 38 | help='the max epoch iter') 39 | parser.add_argument('--batch_size', default=8, type = int, 40 | help='batch size of training') 41 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, 42 | help='initial learning rate') 43 | parser.add_argument('--momentum', default=0.9, type=float, 44 | help='Momentum value for optim') 45 | parser.add_argument('--weight_decay', default=5e-4, type=float, 46 | help='Weight decay for SGD') 47 | parser.add_argument('--gamma', default=0.1, type=float, 48 | help='Gamma update for SGD') 49 | parser.add_argument('--num_workers', default=10, type=int, 50 | help='Number of workers used in dataloading') 51 | 52 | args = parser.parse_args() 53 | 54 | 55 | def train(epoch, model, optimizer,train_loader_source,scheduler,criterion): 56 | model.train() 57 | scheduler.step() # note: with PyTorch >= 1.1 the scheduler is expected to step after the optimizer 58 | epoch_loss = 0 59 | epoch_time = time.time() 60 | 61 | for i, (img_target, gt_score_target, gt_geo_target, valid_map_target) in enumerate(train_loader_source): 62 | start_time = time.time() 63 | img, gt_score, gt_geo, valid_map = img_target.to(device), gt_score_target.to(device), gt_geo_target.to(device), valid_map_target.to(device) 64 | 65 | pred_score, pred_geo = model(img) 66 | 67 | loss = criterion(gt_score, pred_score, gt_geo, pred_geo, valid_map) 68 | 69 | epoch_loss += loss.item() 70 | optimizer.zero_grad() 71 | loss.backward() 72 | optimizer.step() 73 | 74 | if i%20==0: 75 | logger.info('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format( \ 76 | epoch + 1, args.epoch_iter, i + 1, int(len(train_loader_source)), time.time() - start_time, loss.item())) 77 | 78 | # if i>4000 and i%1000==0: 79 | # f_score = test(epoch, model, args.t_eval_path, args.t_output_path, f_score, args.save_model) 80 | logger.info('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(epoch_loss / len(train_loader_source),time.time() - epoch_time)) 81 | logger.info(time.asctime(time.localtime(time.time()))) 82 | 83 | 84 | 85 | if __name__ == '__main__': 86 | args.workspace = os.path.join(args.workspace, args.exp_name) 87 | os.makedirs(args.workspace, exist_ok=True) 88 | logger = setup_logger(os.path.join(args.workspace, 'train_MSRA_log')) 89 | criterion = Loss() 90 | device = torch.device("cuda") 91 | model = EAST() 92 | # model = nn.DataParallel(model) 93 | data_parallel = False 94 | if torch.cuda.device_count() > 1: 95 | model = nn.DataParallel(model) 96 | data_parallel = True 97 | model.to(device) 98 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 99 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[150,220], gamma=0.1) 100 | 101 | # generate the first round of pseudo labels 102 | logger.info("loading pretrained model from "+args.resume) 103 | # model.load_state_dict(torch.load(args.resume)) 104 | 105 | trainset = MSRA_TD500(args.msra_data,args.hust_path,args.is_box_pseudo) 106 | train_loader_target = data.DataLoader(trainset, batch_size=args.batch_size, 107 | shuffle=True, num_workers=args.num_workers, drop_last=True) 108 | 109 | f_score = 0.5 110 | for epoch in range(args.epoch_iter): 111 | train( epoch, model, optimizer,train_loader_target,scheduler,criterion) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /PSENet_box_supervision/config/icdar15/icdar15_baseline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # data config 4 | exp_name = "icdar15/psenet_baseline" 5 | trainroot = '/home/xjc/Dataset/icdar15/' 6 | testroot = 
'/home/xjc/Dataset/icdar15/' 7 | 8 | workspace_dir = '/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet/workdirs/' 9 | workspace = "" 10 | gt_name = "icd15_gt.zip" 11 | data_shape = 640 12 | 13 | # train config 14 | gpu_id = '1,2,3' 15 | workers = 10 16 | start_epoch = 0 17 | epochs = 600 18 | 19 | train_batch_size = 16 20 | 21 | lr = 1e-4 22 | end_lr = 1e-7 23 | lr_gamma = 0.1 24 | lr_decay_step = [100,200] 25 | weight_decay = 5e-4 26 | warm_up_epoch = 6 27 | warm_up_lr = lr * lr_gamma 28 | 29 | display_input_images = False 30 | display_output_images = False 31 | visualization = True 32 | is_pseudo = False 33 | display_interval = 10 34 | show_images_interval = 50 35 | save_interval=5 36 | 37 | pretrained = True 38 | restart_training = True 39 | checkpoint = '' 40 | 41 | # net config 42 | backbone = 'resnet50' 43 | Lambda = 0.7 44 | kernel_num = 6 45 | m = 0.5 46 | OHEM_ratio = 3 47 | scale = 1 48 | # random seed 49 | seed = 2 50 | 51 | 52 | def print(): 53 | from pprint import pformat 54 | tem_d = {} 55 | for k, v in globals().items(): 56 | if not k.startswith('_') and not callable(v): 57 | tem_d[k] = v 58 | return pformat(tem_d) 59 | -------------------------------------------------------------------------------- /PSENet_box_supervision/config/icdar15/icdar15_pseudo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # data config 4 | exp_name = "icdar15/psenet_pseudo" 5 | trainroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2015/' 6 | testroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2015/' 7 | 8 | workspace_dir = '/data/glusterfs_cv_04/11121171/CVPR_Text/SemiText/PSENet/workdirs/' 9 | workspace = "" 10 | gt_name = "icd15_gt.zip" 11 | data_shape = 640 12 | 13 | # train config 14 | gpu_id = '0,1' 15 | workers = 10 16 | start_epoch = 0 17 | epochs = 600 18 | 19 | train_batch_size = 16 20 | 21 | lr = 1e-4 22 | end_lr = 1e-7 23 | lr_gamma = 0.1 24 | lr_decay_step = [100,200] 25 | weight_decay = 5e-4 26 | warm_up_epoch = 6 27 | warm_up_lr = lr * lr_gamma 28 | 29 | display_input_images = False 30 | display_output_images = False 31 | visualization = True 32 | is_pseudo = True 33 | display_interval = 10 34 | show_images_interval = 50 35 | save_interval=5 36 | 37 | pretrained = True 38 | restart_training = True 39 | checkpoint = '' 40 | 41 | # net config 42 | backbone = 'resnet50' 43 | Lambda = 0.7 44 | kernel_num = 6 45 | m = 0.5 46 | OHEM_ratio = 3 47 | scale = 1 48 | # random seed 49 | seed = 2 50 | 51 | 52 | def print(): 53 | from pprint import pformat 54 | tem_d = {} 55 | for k, v in globals().items(): 56 | if not k.startswith('_') and not callable(v): 57 | tem_d[k] = v 58 | return pformat(tem_d) 59 | -------------------------------------------------------------------------------- /PSENet_box_supervision/config/msra/msra_baseline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # data config 4 | exp_name = "msra/msra_baseline" 5 | msra_path = '/home/xjc/Dataset/MSRA-TD500/' 6 | hust_path = '/home/xjc/Dataset/HUST-TR400/' 7 | 8 | workspace_dir = '/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet_box_supervision/workspace/' 9 | workspace = "" 10 | gt_name = "msra_gt.zip" 11 | data_shape = 640 12 | 13 | # train config 14 | gpu_id = '0' 15 | workers = 10 16 | start_epoch = 0 17 | epochs = 600 18 | 19 | train_batch_size = 8 20 | 21 | lr = 1e-4 22 | end_lr = 1e-7 23 | lr_gamma = 0.1 24 | lr_decay_step = 
[100,200] 25 | weight_decay = 5e-4 26 | warm_up_epoch = 6 27 | warm_up_lr = lr * lr_gamma 28 | 29 | display_input_images = False 30 | display_output_images = False 31 | visualization = False 32 | is_box_pseudo = False 33 | display_interval = 10 34 | show_images_interval = 50 35 | save_interval=5 36 | 37 | pretrained = True 38 | restart_training = True 39 | checkpoint = '' 40 | 41 | # net config 42 | backbone = 'resnet50' 43 | Lambda = 0.7 44 | kernel_num = 6 45 | min_scale = 0.5 46 | OHEM_ratio = 3 47 | scale = 1 48 | # random seed 49 | seed = 2 50 | 51 | 52 | def print(): 53 | from pprint import pformat 54 | tem_d = {} 55 | for k, v in globals().items(): 56 | if not k.startswith('_') and not callable(v): 57 | tem_d[k] = v 58 | return pformat(tem_d) 59 | -------------------------------------------------------------------------------- /PSENet_box_supervision/config/msra/msra_pseudo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # data config 4 | exp_name = "msra/msra_pseudo" 5 | msra_path = '/home/xjc/Dataset/MSRA-TD500/' 6 | hust_path = '/home/xjc/Dataset/HUST-TR400/' 7 | 8 | workspace_dir = '/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet_box_supervision/workspace/' 9 | workspace = "" 10 | gt_name = "msra_gt.zip" 11 | data_shape = 640 12 | 13 | # train config 14 | gpu_id = '0' 15 | workers = 10 16 | start_epoch = 0 17 | epochs = 600 18 | 19 | train_batch_size = 8 20 | 21 | lr = 1e-4 22 | end_lr = 1e-7 23 | lr_gamma = 0.1 24 | lr_decay_step = [100,200] 25 | weight_decay = 5e-4 26 | warm_up_epoch = 6 27 | warm_up_lr = lr * lr_gamma 28 | 29 | display_input_images = False 30 | display_output_images = False 31 | visualization = False 32 | is_box_pseudo = True 33 | display_interval = 10 34 | show_images_interval = 50 35 | save_interval=5 36 | 37 | pretrained = True 38 | restart_training = True 39 | checkpoint = '' 40 | 41 | # net config 42 | backbone = 'resnet50' 43 | Lambda = 0.7 44 | kernel_num = 6 45 | min_scale = 0.5 46 | OHEM_ratio = 3 47 | scale = 1 48 | # random seed 49 | seed = 2 50 | 51 | 52 | def print(): 53 | from pprint import pformat 54 | tem_d = {} 55 | for k, v in globals().items(): 56 | if not k.startswith('_') and not callable(v): 57 | tem_d[k] = v 58 | return pformat(tem_d) 59 | -------------------------------------------------------------------------------- /PSENet_box_supervision/config/totaltext/psenet_baseline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # data config 4 | exp_name = "totaltext/psenet_baseline" 5 | 6 | train_data_dir = '/data/glusterfs_cv_04/11121171/data/Total/Images/Train/' 7 | train_gt_dir = '/data/glusterfs_cv_04/11121171/data/Total/gt/Train/' 8 | test_data_dir = '/data/glusterfs_cv_04/11121171/data/Total/Images/Test/' 9 | test_gt_dir = '/data/glusterfs_cv_04/11121171/data/Total/gt/Test/' 10 | 11 | workspace_dir = '/data/glusterfs_cv_04/11121171/CVPR_Text/SemiText/PSENet/workdirs/' 12 | workspace = "/data/glusterfs_cv_04/11121171/CVPR_Text/SemiText/PSENet/workdirs/eval" 13 | data_shape = 640 14 | 15 | # train config 16 | gpu_id = '0,1' 17 | workers = 10 18 | start_epoch = 0 19 | epochs = 100 20 | train_batch_size = 20 21 | 22 | lr = 1e-4 23 | lr_gamma = 0.1 24 | lr_decay_step = [60,80] 25 | 26 | display_input_images = False 27 | display_output_images = False 28 | visualization = True 29 | display_interval = 10 30 | show_images_interval = 50 31 | is_pseudo = False 32 | pretrained = 
True 33 | restart_training = False 34 | checkpoint = '' 35 | save_interval = 5 36 | 37 | # net config 38 | backbone = 'resnet50' 39 | Lambda = 0.7 40 | kernel_num = 7 41 | min_scale = 0.4 42 | OHEM_ratio = 3 43 | scale = 1 44 | min_kernel_area = 10.0 45 | # random seed 46 | seed = 2 47 | binary_th = 1.0 48 | 49 | 50 | def pprint(): 51 | from pprint import pformat 52 | tem_d = {} 53 | for k, v in globals().items(): 54 | if not k.startswith('_') and not callable(v): 55 | tem_d[k] = v 56 | return pformat(tem_d) 57 | -------------------------------------------------------------------------------- /PSENet_box_supervision/config/totaltext/psenet_pseudo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # data config 4 | exp_name = "totaltext/psenet_pseudo" 5 | 6 | train_data_dir = '/home/xjc/Dataset/total-text/Images/Train/' 7 | train_gt_dir = '/home/xjc/Dataset/total-text/pseudolabel_attention/Train_pseudo/' 8 | test_data_dir = '/home/xjc/Dataset/total-text/Images/Test/' 9 | test_gt_dir = '/home/xjc/Dataset/total-text/gt/Test/' 10 | 11 | workspace_dir = '/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet/workdirs/' 12 | workspace = "/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet/workdirs/eval" 13 | data_shape = 640 14 | 15 | # train config 16 | gpu_id = '2,3' 17 | workers = 10 18 | start_epoch = 0 19 | epochs = 150 20 | train_batch_size = 8 21 | 22 | lr = 1e-4 23 | lr_gamma = 0.1 24 | lr_decay_step = [80,120] 25 | 26 | display_input_images = False 27 | display_output_images = False 28 | visualization = False 29 | display_interval = 10 30 | show_images_interval = 50 31 | 32 | is_pseudo = True 33 | pretrained = True 34 | restart_training = False 35 | checkpoint = '' 36 | save_interval = 5 37 | 38 | # net config 39 | backbone = 'resnet50' 40 | Lambda = 0.7 41 | kernel_num = 7 42 | min_scale = 0.4 43 | OHEM_ratio = 3 44 | scale = 1 45 | min_kernel_area = 10.0 46 | # random seed 47 | seed = 2 48 | binary_th = 1.0 49 | 50 | 51 | def pprint(): 52 | from pprint import pformat 53 | tem_d = {} 54 | for k, v in globals().items(): 55 | if not k.startswith('_') and not callable(v): 56 | tem_d[k] = v 57 | return pformat(tem_d) 58 | -------------------------------------------------------------------------------- /PSENet_box_supervision/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/17/19 2:09 AM 3 | # @Author : zhoujun -------------------------------------------------------------------------------- /PSENet_box_supervision/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/PSENet_box_supervision/evaluation/__init__.py -------------------------------------------------------------------------------- /PSENet_box_supervision/evaluation/eval_tt.sh: -------------------------------------------------------------------------------- 1 | cd total_text 2 | python Deteval.py 3 | cd .. 
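A minimal usage sketch for the MSRA-TD500 evaluation module that follows (the directory paths are hypothetical placeholders; get_msra_result is defined in evaluation/msra/eval.py below and prints precision/recall/F-measure before returning the F-measure):

    from evaluation.msra.eval import get_msra_result

    # predictions: one res_<image>.txt per test image, each line a comma-separated
    # polygon "x1,y1,x2,y2,x3,y3,x4,y4"
    # ground truth: one <image>.gt per image in MSRA-TD500 format,
    # "index difficult x y w h theta" (theta in radians)
    hmean = get_msra_result('/path/to/submit/', '/path/to/Test_gt/')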
-------------------------------------------------------------------------------- /PSENet_box_supervision/evaluation/msra/eval.py: -------------------------------------------------------------------------------- 1 | import evaluation.msra.file_util as file_util 2 | import Polygon as plg 3 | import numpy as np 4 | import math 5 | import cv2 6 | import os 7 | 8 | pred_root = '/home/wwj/workspace/Sence_Text_detection/AAAI_EAST/Baseline/EAST_v1/worksapce/MSRA/submit' 9 | gt_root = '/data/data_weijiawu/TD500/Test_gt/' 10 | 11 | 12 | def get_pred(path): 13 | lines = file_util.read_file(path).split('\n') 14 | bboxes = [] 15 | for line in lines: 16 | if line == '': 17 | continue 18 | bbox = line.split(',') 19 | if len(bbox) % 2 == 1: 20 | print(path) 21 | bbox = [int(x) for x in bbox] 22 | bboxes.append(bbox) 23 | return bboxes 24 | 25 | 26 | def get_gt(path): 27 | lines = file_util.read_file(path).split('\n') 28 | bboxes = [] 29 | tags = [] 30 | for line in lines: 31 | if line == '': 32 | continue 33 | # line = util.str.remove_all(line, '\xef\xbb\xbf') 34 | # gt = util.str.split(line, ' ') 35 | gt = line.split(' ') 36 | 37 | w_ = float(gt[4]) 38 | h_ = float(gt[5]) 39 | x1 = float(gt[2]) + w_ / 2.0 40 | y1 = float(gt[3]) + h_ / 2.0 41 | theta = float(gt[6]) / math.pi * 180 # radians to degrees for cv2.boxPoints 42 | 43 | bbox = cv2.boxPoints(((x1, y1), (w_, h_), theta)) 44 | bbox = bbox.reshape(-1) 45 | 46 | bboxes.append(bbox) 47 | tags.append(int(gt[1])) 48 | return np.array(bboxes), tags 49 | 50 | 51 | def get_union(pD, pG): 52 | areaA = pD.area() 53 | areaB = pG.area() 54 | return areaA + areaB - get_intersection(pD, pG) 55 | 56 | 57 | def get_intersection(pD, pG): 58 | pInt = pD & pG 59 | if len(pInt) == 0: 60 | return 0 61 | return pInt.area() 62 | 63 | def get_msra_result(pred_path_root,gt_path_root): 64 | th = 0.5 65 | pred_list = file_util.read_dir(pred_path_root) 66 | 67 | count, tp, fp, tn, ta = 0, 0, 0, 0, 0 68 | for pred_path in pred_list: 69 | count = count + 1 70 | preds = get_pred(pred_path) 71 | gt_path = os.path.join(gt_path_root ,pred_path.split('/')[-1].split('.')[0].split('res_')[-1] + '.gt') 72 | gts, tags = get_gt(gt_path) 73 | 74 | ta = ta + len(preds) 75 | for gt, tag in zip(gts, tags): 76 | gt = np.array(gt) 77 | gt = gt.reshape(int(gt.shape[0] / 2), 2) 78 | gt_p = plg.Polygon(gt) 79 | difficult = tag 80 | flag = 0 81 | for pred in preds: 82 | pred = np.array(pred) 83 | pred = pred.reshape(int(pred.shape[0] / 2), 2) 84 | pred_p = plg.Polygon(pred) 85 | 86 | union = get_union(pred_p, gt_p) 87 | inter = get_intersection(pred_p, gt_p) 88 | iou = float(inter) / union 89 | if iou >= th: 90 | flag = 1 91 | tp = tp + 1 92 | break 93 | 94 | if flag == 0 and difficult == 0: 95 | fp = fp + 1 96 | 97 | recall = float(tp) / (tp + fp) # fp counts unmatched non-difficult ground truths, i.e. false negatives 98 | try: 99 | precision = float(tp) / ta 100 | except ZeroDivisionError: 101 | precision = 0 102 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 103 | 104 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 105 | return hmean 106 | 107 | if __name__ == '__main__': 108 | th = 0.5 109 | pred_list = file_util.read_dir(pred_root) 110 | 111 | count, tp, fp, tn, ta = 0, 0, 0, 0, 0 112 | for pred_path in pred_list: 113 | count = count + 1 114 | preds = get_pred(pred_path) 115 | gt_path = gt_root + pred_path.split('/')[-1].split('.')[0].split('res_')[-1] + '.gt' 116 | gts, tags = get_gt(gt_path) 117 | 118 | ta = ta + len(preds) 119 | for gt, tag in zip(gts, tags): 120 | gt = np.array(gt) 121 | gt = gt.reshape(int(gt.shape[0] / 2), 2) 122 | 
gt_p = plg.Polygon(gt) 123 | difficult = tag 124 | flag = 0 125 | for pred in preds: 126 | pred = np.array(pred) 127 | pred = pred.reshape(int(pred.shape[0] / 2), 2) 128 | pred_p = plg.Polygon(pred) 129 | 130 | union = get_union(pred_p, gt_p) 131 | inter = get_intersection(pred_p, gt_p) 132 | iou = float(inter) / union 133 | if iou >= th: 134 | flag = 1 135 | tp = tp + 1 136 | break 137 | 138 | if flag == 0 and difficult == 0: 139 | fp = fp + 1 140 | 141 | recall = float(tp) / (tp + fp) 142 | precision = float(tp) / ta 143 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 144 | 145 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 146 | -------------------------------------------------------------------------------- /PSENet_box_supervision/evaluation/msra/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def read_dir(root): 4 | file_path_list = [] 5 | for file_path, dirs, files in os.walk(root): 6 | for file in files: 7 | file_path_list.append(os.path.join(file_path, file).replace('\\', '/')) 8 | file_path_list.sort() 9 | return file_path_list 10 | 11 | def read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 | def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /PSENet_box_supervision/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/2 18:18 3 | # @Author : zhoujun 4 | from .model import PSENet -------------------------------------------------------------------------------- /PSENet_box_supervision/models/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/29/19 11:03 AM 3 | # @Author : zhoujun 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | 8 | 9 | class PSELoss(nn.Module): 10 | def __init__(self, Lambda, ratio=3, reduction='mean'): 11 | """Implement PSE Loss. 
12 | """ 13 | super(PSELoss, self).__init__() 14 | assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']" 15 | self.Lambda = Lambda 16 | self.ratio = ratio 17 | self.reduction = reduction 18 | 19 | def forward(self, outputs, labels, training_masks): 20 | texts = outputs[:, -1, :, :] 21 | kernels = outputs[:, :-1, :, :] 22 | gt_texts = labels[:, -1, :, :] 23 | gt_kernels = labels[:, :-1, :, :] 24 | 25 | selected_masks = self.ohem_batch(texts, gt_texts, training_masks) 26 | selected_masks = selected_masks.to(outputs.device) 27 | 28 | loss_text = self.dice_loss(texts, gt_texts, selected_masks) 29 | 30 | loss_kernels = [] 31 | mask0 = torch.sigmoid(texts).data.cpu().numpy() 32 | mask1 = training_masks.data.cpu().numpy() 33 | selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32') 34 | selected_masks = torch.from_numpy(selected_masks).float() 35 | selected_masks = selected_masks.to(outputs.device) 36 | kernels_num = gt_kernels.size()[1] 37 | for i in range(kernels_num): 38 | kernel_i = kernels[:, i, :, :] 39 | gt_kernel_i = gt_kernels[:, i, :, :] 40 | loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks) 41 | loss_kernels.append(loss_kernel_i) 42 | loss_kernels = torch.stack(loss_kernels).mean(0) 43 | if self.reduction == 'mean': 44 | loss_text = loss_text.mean() 45 | loss_kernels = loss_kernels.mean() 46 | elif self.reduction == 'sum': 47 | loss_text = loss_text.sum() 48 | loss_kernels = loss_kernels.sum() 49 | 50 | loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels 51 | return loss_text, loss_kernels, loss 52 | 53 | def dice_loss(self, input, target, mask): 54 | input = torch.sigmoid(input) 55 | 56 | input = input.contiguous().view(input.size()[0], -1) 57 | target = target.contiguous().view(target.size()[0], -1) 58 | mask = mask.contiguous().view(mask.size()[0], -1) 59 | 60 | input = input * mask 61 | target = target * mask 62 | 63 | a = torch.sum(input * target, 1) 64 | b = torch.sum(input * input, 1) + 0.001 65 | c = torch.sum(target * target, 1) + 0.001 66 | d = (2 * a) / (b + c) 67 | return 1 - d 68 | 69 | def ohem_single(self, score, gt_text, training_mask): 70 | pos_num = (int)(np.sum(gt_text > 0.5)) - (int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5))) 71 | 72 | if pos_num == 0: 73 | # selected_mask = gt_text.copy() * 0 # may be not good 74 | selected_mask = training_mask 75 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 76 | return selected_mask 77 | 78 | neg_num = (int)(np.sum(gt_text <= 0.5)) 79 | neg_num = (int)(min(pos_num * 3, neg_num)) 80 | 81 | if neg_num == 0: 82 | selected_mask = training_mask 83 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 84 | return selected_mask 85 | 86 | neg_score = score[gt_text <= 0.5] 87 | # 将负样本得分从高到低排序 88 | neg_score_sorted = np.sort(-neg_score) 89 | threshold = -neg_score_sorted[neg_num - 1] 90 | # 选出 得分高的 负样本 和正样本 的 mask 91 | selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5) 92 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 93 | return selected_mask 94 | 95 | def ohem_batch(self, scores, gt_texts, training_masks): 96 | scores = scores.data.cpu().numpy() 97 | gt_texts = gt_texts.data.cpu().numpy() 98 | training_masks = training_masks.data.cpu().numpy() 99 | 100 | selected_masks = [] 101 | for i in range(scores.shape[0]): 102 | 
selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])) 103 | 104 | selected_masks = np.concatenate(selected_masks, 0) 105 | selected_masks = torch.from_numpy(selected_masks).float() 106 | 107 | return selected_masks 108 | -------------------------------------------------------------------------------- /PSENet_box_supervision/predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/4/19 11:14 AM 3 | # @Author : zhoujun 4 | import torch 5 | from torchvision import transforms 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 11 | from pse import decode as pse_decode 12 | 13 | 14 | class Pytorch_model: 15 | def __init__(self, model_path, net, scale, gpu_id=None): 16 | ''' 17 | 初始化pytorch模型 18 | :param model_path: 模型地址(可以是模型的参数或者参数和计算图一起保存的文件) 19 | :param net: 网络计算图,如果在model_path中指定的是参数的保存路径,则需要给出网络的计算图 20 | :param img_channel: 图像的通道数: 1,3 21 | :param gpu_id: 在哪一块gpu上运行 22 | ''' 23 | 24 | self.scale = scale 25 | if gpu_id is not None and isinstance(gpu_id, int) and torch.cuda.is_available(): 26 | 27 | self.device = torch.device("cuda:{}".format(gpu_id)) 28 | else: 29 | self.device = torch.device("cpu") 30 | print('device:', self.device) 31 | self.net = torch.load(model_path, map_location=self.device)['state_dict'] 32 | 33 | 34 | if net is not None: 35 | # 如果网络计算图和参数是分开保存的,就执行参数加载 36 | net = net.to(self.device) 37 | net.scale = scale 38 | try: 39 | sk = {} 40 | for k in self.net: 41 | sk[k[7:]] = self.net[k] 42 | net.load_state_dict(sk) 43 | except: 44 | net.load_state_dict(self.net) 45 | self.net = net 46 | print('load models') 47 | self.net.eval() 48 | 49 | def predict(self, img: str, long_size: int = 2000): 50 | ''' 51 | 对传入的图像进行预测,支持图像地址,opecv 读取图片,偏慢 52 | :param img: 图像地址 53 | :param is_numpy: 54 | :return: 55 | ''' 56 | assert os.path.exists(img), 'file is not exists' 57 | img = cv2.imread(img) 58 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 59 | h, w = img.shape[:2] 60 | 61 | scale = long_size / max(h, w) 62 | img = cv2.resize(img, None, fx=scale, fy=scale) 63 | # 将图片由(w,h)变为(1,img_channel,h,w) 64 | tensor = transforms.ToTensor()(img) 65 | tensor = tensor.unsqueeze_(0) 66 | 67 | tensor = tensor.to(self.device) 68 | with torch.no_grad(): 69 | # torch.cuda.synchronize() 70 | start = time.time() 71 | preds = self.net(tensor) 72 | preds, boxes_list = pse_decode(preds[0], self.scale) 73 | scale = (preds.shape[1] / w, preds.shape[0] / h) 74 | # print(scale) 75 | # preds, boxes_list = decode(preds,num_pred=-1) 76 | if len(boxes_list): 77 | boxes_list = boxes_list / scale 78 | # torch.cuda.synchronize() 79 | t = time.time() - start 80 | return preds, boxes_list, t 81 | 82 | 83 | def _get_annotation(label_path): 84 | boxes = [] 85 | with open(label_path, encoding='utf-8', mode='r') as f: 86 | for line in f.readlines(): 87 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',') 88 | try: 89 | label = params[8] 90 | if label == '*' or label == '###': 91 | continue 92 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, params[:8])) 93 | boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 94 | except: 95 | print('load label failed on {}'.format(label_path)) 96 | return np.array(boxes, dtype=np.float32) 97 | 98 | 99 | if __name__ == '__main__': 100 | import config 101 | from models import PSENet 102 | import matplotlib.pyplot as plt 103 | from utils.utils import show_img, draw_bbox 104 | 105 | 
os.environ['CUDA_VISIBLE_DEVICES'] = str('2') 106 | 107 | model_path = 'output/psenet_icd2015_resnet152_author_crop_adam_warm_up_myloss/best_r0.714011_p0.708214_f10.711100.pth' 108 | 109 | # model_path = 'output/psenet_icd2015_new_loss/final.pth' 110 | img_id = 10 111 | img_path = '/data2/dataset/ICD15/test/img/img_{}.jpg'.format(img_id) 112 | label_path = '/data2/dataset/ICD15/test/gt/gt_img_{}.txt'.format(img_id) 113 | label = _get_annotation(label_path) 114 | 115 | # 初始化网络 116 | net = PSENet(backbone='resnet152', pretrained=False, result_num=config.n) 117 | model = Pytorch_model(model_path, net=net, scale=1, gpu_id=0) 118 | # for i in range(100): 119 | # models.predict(img_path) 120 | preds, boxes_list,t = model.predict(img_path) 121 | print(boxes_list) 122 | show_img(preds) 123 | img = draw_bbox(img_path, boxes_list, color=(0, 0, 255)) 124 | cv2.imwrite('result.jpg', img) 125 | # img = draw_bbox(img, label,color=(0,0,255)) 126 | show_img(img, color=True) 127 | 128 | plt.show() 129 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags) 2 | LDFLAGS = $(shell python3-config --ldflags) 3 | 4 | DEPS = $(shell find include -xtype f) 5 | CXX_SOURCES = pse.cpp 6 | 7 | LIB_SO = pse.so 8 | 9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC 11 | 12 | clean: 13 | rm -rf $(LIB_SO) 14 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/buffer_info.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/buffer_info.h: Python buffer object interface 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of entries between adjacent entries (for each per dimension) 25 | 26 | buffer_info() { } 27 | 28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 29 | detail::any_container shape_in, detail::any_container strides_in) 30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 31 | shape(std::move(shape_in)), strides(std::move(strides_in)) { 32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 34 | for (size_t i = 0; i < (size_t) ndim; ++i) 35 | size *= shape[i]; 36 | } 37 | 38 | template 39 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) 40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } 41 | 42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) 43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } 44 | 45 | template 46 | buffer_info(T *ptr, ssize_t size) 47 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } 48 | 49 | explicit buffer_info(Py_buffer *view, bool ownview = true) 50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { 52 | this->view = view; 53 | this->ownview = ownview; 54 | } 55 | 56 | buffer_info(const buffer_info &) = delete; 57 | buffer_info& operator=(const buffer_info &) = delete; 58 | 59 | buffer_info(buffer_info &&other) { 60 | (*this) = std::move(other); 61 | } 62 | 63 | buffer_info& operator=(buffer_info &&rhs) { 64 | ptr = rhs.ptr; 65 | itemsize = rhs.itemsize; 66 | size = rhs.size; 67 | format = std::move(rhs.format); 68 | ndim = rhs.ndim; 69 | shape = std::move(rhs.shape); 70 | strides = std::move(rhs.strides); 71 | std::swap(view, rhs.view); 72 | std::swap(ownview, rhs.ownview); 73 | return *this; 74 | } 75 | 76 | ~buffer_info() { 77 | if (view && ownview) { PyBuffer_Release(view); delete view; } 78 | } 79 | 80 | private: 81 | struct private_ctr_tag { }; 82 | 83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 84 | detail::any_container &&shape_in, detail::any_container &&strides_in) 85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } 86 | 87 | Py_buffer *view = nullptr; 88 | bool ownview = false; 89 | }; 90 | 91 | NAMESPACE_BEGIN(detail) 92 | 93 | template struct compare_buffer_info { 94 | static bool compare(const buffer_info& b) { 95 | return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 96 | } 97 | }; 98 | 99 | template struct compare_buffer_info::value>> { 100 | static bool compare(const buffer_info& b) { 101 | return (size_t) 
b.itemsize == sizeof(T) && (b.format == format_descriptor::value || 102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); 104 | } 105 | }; 106 | 107 | NAMESPACE_END(detail) 108 | NAMESPACE_END(PYBIND11_NAMESPACE) 109 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 3 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/detail/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | #if !defined(_MSC_VER) 18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr 19 | #else 20 | # define PYBIND11_DESCR_CONSTEXPR const 21 | #endif 22 | 23 | /* Concatenate type signatures at compile time */ 24 | template 25 | struct descr { 26 | char text[N + 1]; 27 | 28 | constexpr descr() : text{'\0'} { } 29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } 30 | 31 | template 32 | constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } 33 | 34 | template 35 | constexpr descr(char c, Chars... cs) : text{c, static_cast(cs)..., '\0'} { } 36 | 37 | static constexpr std::array types() { 38 | return {{&typeid(Ts)..., nullptr}}; 39 | } 40 | }; 41 | 42 | template 43 | constexpr descr plus_impl(const descr &a, const descr &b, 44 | index_sequence, index_sequence) { 45 | return {a.text[Is1]..., b.text[Is2]...}; 46 | } 47 | 48 | template 49 | constexpr descr operator+(const descr &a, const descr &b) { 50 | return plus_impl(a, b, make_index_sequence(), make_index_sequence()); 51 | } 52 | 53 | template 54 | constexpr descr _(char const(&text)[N]) { return descr(text); } 55 | constexpr descr<0> _(char const(&)[1]) { return {}; } 56 | 57 | template struct int_to_str : int_to_str { }; 58 | template struct int_to_str<0, Digits...> { 59 | static constexpr auto digits = descr(('0' + Digits)...); 60 | }; 61 | 62 | // Ternary description (like std::conditional) 63 | template 64 | constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { 65 | return _(text1); 66 | } 67 | template 68 | constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { 69 | return _(text2); 70 | } 71 | 72 | template 73 | constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } 74 | template 75 | constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } 76 | 77 | template auto constexpr _() -> decltype(int_to_str::digits) { 78 | return int_to_str::digits; 79 | } 80 | 81 | template constexpr descr<1, Type> _() { return {'%'}; } 82 | 83 | constexpr descr<0> concat() { return {}; } 84 | 85 | template 86 | constexpr descr concat(const descr &descr) { return descr; } 87 | 88 | template 89 | constexpr auto concat(const descr &d, const Args &...args) 90 | -> decltype(std::declval>() + concat(args...)) { 91 | return d + _(", ") + concat(args...); 92 | } 93 | 94 | template 95 | constexpr descr type_descr(const descr &descr) { 96 | return _("{") + descr + _("}"); 97 | } 98 | 99 | NAMESPACE_END(detail) 100 | NAMESPACE_END(PYBIND11_NAMESPACE) 101 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/detail/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include <cstdio> 13 | #include <cstdlib> 14 | 15 | #if defined(__GNUG__) 16 | #include <cxxabi.h> 17 | #endif 18 | 19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr<char, void (*)(void *)> res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template <typename T> static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(PYBIND11_NAMESPACE) 54 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-embedded.de> and 6 | Wenzel Jakob <wenzel.jakob@epfl.ch> 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template <eval_mode mode = eval_expr> 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal<object>(result); 50 | } 51 | 52 | template <eval_mode mode = eval_expr, size_t N> 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ?
str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval<mode>(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval<eval_statements>(expr, global, local); 62 | } 63 | 64 | template <size_t N> 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval<eval_statements>(s, global, local); 67 | } 68 | 69 | template <eval_mode mode = eval_statements> 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal<object>(PyFile_FromString( 91 | const_cast<char *>(fname_str.c_str()), 92 | const_cast<char *>("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal<object>(result); 115 | } 116 | 117 | NAMESPACE_END(PYBIND11_NAMESPACE) 118 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include <functional> 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template <typename Return, typename... Args> 19 | struct type_caster<std::function<Return(Args...)>> { 20 | using type = std::function<Return(Args...)>; 21 | using retval_type = conditional_t<std::is_void<Return>::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance<function>(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow<function>(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided.
44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward<Args>(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast<Return>()); 62 | }; 63 | return true; 64 | } 65 | 66 | template <typename Func> 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target<function_type>(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward<Func>(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ") 79 | + make_caster<retval_type>::name + _("]")); 80 | }; 81 | 82 | NAMESPACE_END(detail) 83 | NAMESPACE_END(PYBIND11_NAMESPACE) 84 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings.
55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include <cstdio> 13 | #include <cstdlib> 14 | 15 | #if defined(__GNUG__) 16 | #include <cxxabi.h> 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr<char, void (*)(void *)> res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template <typename T> static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/ncnn/examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenCV QUIET COMPONENTS core highgui imgproc imgcodecs) 2 | if(NOT OpenCV_FOUND) 3 | find_package(OpenCV REQUIRED COMPONENTS core highgui imgproc) 4 | endif() 5 | 6 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../src) 7 | include_directories(${CMAKE_CURRENT_BINARY_DIR}/../src) 8 | 9 | set(NCNN_EXAMPLE_LINK_LIBRARIES ncnn ${OpenCV_LIBS}) 10 | if(NCNN_VULKAN) 11 | list(APPEND NCNN_EXAMPLE_LINK_LIBRARIES ${Vulkan_LIBRARY}) 12 | endif() 13 | 14 | add_executable(psenet psenet.cpp) 15 | target_link_libraries(psenet ${NCNN_EXAMPLE_LINK_LIBRARIES}) 16 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/ncnn/examples/run.sh: -------------------------------------------------------------------------------- 1 | /home/zj/project/ncnn/build/examples/psenet /home/zj/project/ncnn/examples/psenet.bin /home/zj/project/ncnn/examples/psenet.param /home/zj/card/8_6217921001182693.jpg 600 -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/pse.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // pse 3 | // reference https://github.com/whai362/PSENet/issues/15 4 | // Created by liuheng on 11/3/19. 5 | // Copyright © 2019 liuheng. All rights reserved.
6 | // 7 | #include <queue> 8 | #include "include/pybind11/pybind11.h" 9 | #include "include/pybind11/numpy.h" 10 | #include "include/pybind11/stl.h" 11 | #include "include/pybind11/stl_bind.h" 12 | 13 | namespace py = pybind11; 14 | 15 | namespace pse{ 16 | //S5->S0, small->big 17 | std::vector<std::vector<int32_t>> pse( 18 | py::array_t<int32_t, py::array::c_style> label_map, 19 | py::array_t<uint8_t, py::array::c_style> Sn, 20 | int c = 6) 21 | { 22 | auto pbuf_label_map = label_map.request(); 23 | auto pbuf_Sn = Sn.request(); 24 | if (pbuf_label_map.ndim != 2 || pbuf_label_map.shape[0]==0 || pbuf_label_map.shape[1]==0) 25 | throw std::runtime_error("label map must have a shape of (h>0, w>0)"); 26 | int h = pbuf_label_map.shape[0]; 27 | int w = pbuf_label_map.shape[1]; 28 | if (pbuf_Sn.ndim != 3 || pbuf_Sn.shape[0] != c || pbuf_Sn.shape[1]!=h || pbuf_Sn.shape[2]!=w) 29 | throw std::runtime_error("Sn must have a shape of (c>0, h>0, w>0)"); 30 | 31 | std::vector<std::vector<int32_t>> res; 32 | for (size_t i = 0; i<h; i++) 33 | res.push_back(std::vector<int32_t>(w, 0)); 34 | auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr); 35 | auto ptr_Sn = static_cast<uint8_t *>(pbuf_Sn.ptr); 36 | 37 | std::queue<std::tuple<int, int, int32_t>> q, next_q; 38 | 39 | for (size_t i = 0; i<h; i++) 40 | { 41 | auto p_label_map = ptr_label_map + i*w; 42 | for (size_t j = 0; j<w; j++) 43 | { 44 | int32_t label = p_label_map[j]; 45 | if (label>0) 46 | { 47 | q.push(std::make_tuple(i, j, label)); 48 | res[i][j] = label; 49 | } 50 | } 51 | } 52 | 53 | int dx[4] = {-1, 1, 0, 0}; 54 | int dy[4] = {0, 0, -1, 1}; 55 | // merge from small to large kernel progressively 56 | for (int i = 1; i<c; i++) 57 | { 58 | auto p_Sn = ptr_Sn + i*h*w; 59 | while(!q.empty()){ 60 | //get each queue member in q 61 | auto q_n = q.front(); 62 | q.pop(); 63 | 64 | int y = std::get<0>(q_n); 65 | int x = std::get<1>(q_n); 66 | int32_t l = std::get<2>(q_n); 67 | //store the edge pixel after one expansion 68 | bool is_edge = true; 69 | for (int idx=0; idx<4; idx++) 70 | { 71 | int index_y = y + dy[idx]; 72 | int index_x = x + dx[idx]; 73 | if (index_y<0 || index_y>=h || index_x<0 || index_x>=w) 74 | continue; 75 | if (!p_Sn[index_y*w+index_x] || res[index_y][index_x]>0) 76 | continue; 77 | q.push(std::make_tuple(index_y, index_x, l)); 78 | res[index_y][index_x]=l; 79 | is_edge = false; 80 | } 81 | if (is_edge){ 82 | next_q.push(std::make_tuple(y, x, l)); 83 | } 84 | } 85 | std::swap(q, next_q); 86 | } 87 | return res; 88 | } 89 | } 90 | 91 | PYBIND11_MODULE(pse, m){ 92 | m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6); 93 | } 94 | 95 | -------------------------------------------------------------------------------- /PSENet_box_supervision/pse/pse.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/PSENet_box_supervision/pse/pse.so -------------------------------------------------------------------------------- /PSENet_box_supervision/pse_pyx/__init__.py: -------------------------------------------------------------------------------- 1 | from .pse import pse
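Note: a minimal, hypothetical usage sketch of the compiled pse extension above — only the binding name pse_cpp and its signature come from pse.cpp; `score`, the 0.5 binarization threshold, and the import path are assumptions:

    import cv2
    import numpy as np
    from pse import pse_cpp  # module built from pse.cpp (shipped here as pse.so)

    # score: assumed (c, h, w) float kernel scores, stacked small -> big (S5 ... S0)
    kernels = (score > 0.5).astype(np.uint8)
    # seed labels are the connected components of the smallest kernel
    _, label_map = cv2.connectedComponents(kernels[0], connectivity=4)
    # grow the seeds kernel by kernel; result is an (h, w) instance label map
    pred = np.array(pse_cpp(label_map.astype(np.int32), kernels, c=kernels.shape[0]))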
-------------------------------------------------------------------------------- /PSENet_box_supervision/pse_pyx/pse.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | cimport numpy as np 5 | cimport cython 6 | cimport libcpp 7 | cimport libcpp.pair 8 | cimport libcpp.queue 9 | from libcpp.pair cimport * 10 | from libcpp.queue cimport * 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, 15 | np.ndarray[np.float32_t, ndim=3] emb, 16 | np.ndarray[np.int32_t, ndim=2] label, 17 | np.ndarray[np.int32_t, ndim=2] cc, 18 | int kernel_num, 19 | int label_num, 20 | float min_area=0): 21 | cdef np.ndarray[np.int32_t, ndim=2] pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 22 | cdef np.ndarray[np.float32_t, ndim=2] mean_emb = np.zeros((label_num, 4), dtype=np.float32) 23 | cdef np.ndarray[np.float32_t, ndim=1] area = np.full((label_num,), -1, dtype=np.float32) 24 | cdef np.ndarray[np.int32_t, ndim=1] flag = np.zeros((label_num,), dtype=np.int32) 25 | cdef np.ndarray[np.uint8_t, ndim=3] inds = np.zeros((label_num, label.shape[0], label.shape[1]), dtype=np.uint8) 26 | cdef np.ndarray[np.int32_t, ndim=2] p = np.zeros((label_num, 2), dtype=np.int32) 27 | 28 | cdef np.float32_t max_rate = 1024 29 | for i in range(1, label_num): 30 | ind = label == i 31 | inds[i] = ind 32 | 33 | area[i] = np.sum(ind) 34 | 35 | if area[i] < min_area: 36 | label[ind] = 0 37 | continue 38 | 39 | px, py = np.where(ind) 40 | p[i] = (px[0], py[0]) 41 | 42 | for j in range(1, i): 43 | if area[j] < min_area: 44 | continue 45 | if cc[p[i, 0], p[i, 1]] != cc[p[j, 0], p[j, 1]]: 46 | continue 47 | rate = area[i] / area[j] 48 | if rate < 1 / max_rate or rate > max_rate: 49 | flag[i] = 1 50 | mean_emb[i] = np.mean(emb[:, ind], axis=1) 51 | 52 | if flag[j] == 0: 53 | flag[j] = 1 54 | mean_emb[j] = np.mean(emb[:, inds[j].astype(np.bool)], axis=1) 55 | 56 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ 57 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 58 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ 59 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 60 | cdef np.int16_t* dx = [-1, 1, 0, 0] 61 | cdef np.int16_t* dy = [0, 0, -1, 1] 62 | cdef np.int16_t tmpx, tmpy 63 | 64 | points = np.array(np.where(label > 0)).transpose((1, 0)) 65 | for point_idx in range(points.shape[0]): 66 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 67 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 68 | pred[tmpx, tmpy] = label[tmpx, tmpy] 69 | 70 | cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur 71 | cdef int cur_label 72 | for kernel_idx in range(kernel_num - 1, -1, -1): 73 | while not que.empty(): 74 | cur = que.front() 75 | que.pop() 76 | cur_label = pred[cur.first, cur.second] 77 | 78 | is_edge = True 79 | for j in range(4): 80 | tmpx = cur.first + dx[j] 81 | tmpy = cur.second + dy[j] 82 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 83 | continue 84 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 85 | continue 86 | if flag[cur_label] == 1 and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3: 87 | continue 88 | 89 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 90 | pred[tmpx, tmpy] = cur_label 91 | is_edge = False 92 | if is_edge: 93 | nxt_que.push(cur) 94 | 95 | que, nxt_que = nxt_que, que 96 | 97 | return pred 98 | 99 | def pse(kernels, emb, min_area): 100 | kernel_num = kernels.shape[0] 101 | _, cc = cv2.connectedComponents(kernels[0], connectivity=4) 102 | label_num, label = cv2.connectedComponents(kernels[1], connectivity=4) 103 | 104 | return _pse(kernels[:-1], emb, label, cc, kernel_num, label_num, min_area) -------------------------------------------------------------------------------- /PSENet_box_supervision/pse_pyx/pse.pyx.bak1: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | cimport numpy as np 5 | cimport cython 6 | cimport libcpp 7 | cimport libcpp.pair 8 | cimport libcpp.queue 9 | from libcpp.pair cimport * 10 | from libcpp.queue cimport *
11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, 15 | np.ndarray[np.float32_t, ndim=3] emb, 16 | np.ndarray[np.int32_t, ndim=2] label, 17 | int kernel_num, 18 | int label_num, 19 | float min_area=0, 20 | cal_dist=True): 21 | cdef np.ndarray[np.int32_t, ndim=2] pred 22 | cdef np.ndarray[np.float32_t, ndim=2] mean_emb 23 | cdef np.ndarray[np.float32_t, ndim=3] dist 24 | pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 25 | mean_emb = np.zeros((label_num, 4), dtype=np.float32) 26 | 27 | for label_idx in range(1, label_num): 28 | ind = label == label_idx 29 | if np.sum(ind) < min_area: 30 | label[ind] = 0 31 | elif cal_dist: 32 | mean_emb[label_idx] = np.mean(emb[:, ind], axis=1) 33 | 34 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ 35 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 36 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ 37 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 38 | cdef np.int16_t* dx = [-1, 1, 0, 0] 39 | cdef np.int16_t* dy = [0, 0, -1, 1] 40 | cdef np.int16_t tmpx, tmpy 41 | 42 | points = np.array(np.where(label > 0)).transpose((1, 0)) 43 | for point_idx in range(points.shape[0]): 44 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 45 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 46 | pred[tmpx, tmpy] = label[tmpx, tmpy] 47 | 48 | cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur 49 | cdef int cur_label 50 | for kernel_idx in range(kernel_num - 1, -1, -1): 51 | while not que.empty(): 52 | cur = que.front() 53 | que.pop() 54 | cur_label = pred[cur.first, cur.second] 55 | 56 | is_edge = True 57 | for j in range(4): 58 | tmpx = cur.first + dx[j] 59 | tmpy = cur.second + dy[j] 60 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 61 | continue 62 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 63 | continue 64 | if cal_dist and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3: 65 | continue 66 | 67 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 68 | pred[tmpx, tmpy] = cur_label 69 | is_edge = False 70 | if is_edge: 71 | nxt_que.push(cur) 72 | 73 | que, nxt_que = nxt_que, que 74 | 75 | return pred 76 | 77 | def pse(kernels, emb, min_area, cal_dist=False): 78 | kernel_num = kernels.shape[0] 79 | label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4) 80 | 81 | return _pse(kernels[:-1], emb, label, kernel_num, label_num, min_area, cal_dist) -------------------------------------------------------------------------------- /PSENet_box_supervision/pse_pyx/pse.pyx.bak2: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | cimport numpy as np 5 | cimport cython 6 | cimport libcpp 7 | cimport libcpp.pair 8 | cimport libcpp.queue 9 | from libcpp.pair cimport * 10 | from libcpp.queue cimport * 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, 15 | np.ndarray[np.float32_t, ndim=3] emb, 16 | np.ndarray[np.int32_t, ndim=2] label, 17 | np.ndarray[np.int32_t, ndim=2] cc, 18 | int kernel_num, 19 | int label_num, 20 | float min_area=0): 21 | cdef np.ndarray[np.int32_t, ndim=2] pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 22 | cdef np.ndarray[np.float32_t, ndim=2] mean_emb = np.zeros((label_num, 4), 
dtype=np.float32) 23 | cdef np.ndarray[np.float32_t, ndim=1] area = np.full((label_num,), -1, dtype=np.float32) 24 | cdef np.ndarray[np.int32_t, ndim=1] flag = np.zeros((label_num,), dtype=np.int32) 25 | cdef np.ndarray[np.uint8_t, ndim=3] inds = np.zeros((label_num, label.shape[0], label.shape[1]), dtype=np.uint8) 26 | cdef np.ndarray[np.int32_t, ndim=2] p = np.zeros((label_num, 2), dtype=np.int32) 27 | 28 | cdef np.float32_t max_rate = 1024 29 | for i in range(1, label_num): 30 | ind = label == i 31 | inds[i] = ind 32 | 33 | area[i] = np.sum(ind) 34 | 35 | if area[i] < min_area: 36 | label[ind] = 0 37 | continue 38 | 39 | px, py = np.where(ind) 40 | p[i] = (px[0], py[0]) 41 | 42 | for j in range(1, i): 43 | if area[j] < min_area: 44 | continue 45 | if cc[p[i, 0], p[i, 1]] != cc[p[j, 0], p[j, 1]]: 46 | continue 47 | rate = area[i] / area[j] 48 | if rate < 1 / max_rate or rate > max_rate: 49 | flag[i] = 1 50 | mean_emb[i] = np.mean(emb[:, ind], axis=1) 51 | 52 | if flag[j] == 0: 53 | flag[j] = 1 54 | mean_emb[j] = np.mean(emb[:, inds[j].astype(np.bool)], axis=1) 55 | 56 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ 57 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 58 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ 59 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 60 | cdef np.int16_t* dx = [-1, 1, 0, 0] 61 | cdef np.int16_t* dy = [0, 0, -1, 1] 62 | cdef np.int16_t tmpx, tmpy 63 | 64 | points = np.array(np.where(label > 0)).transpose((1, 0)) 65 | for point_idx in range(points.shape[0]): 66 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 67 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 68 | pred[tmpx, tmpy] = label[tmpx, tmpy] 69 | 70 | cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur 71 | cdef int cur_label 72 | for kernel_idx in range(kernel_num - 1, -1, -1): 73 | while not que.empty(): 74 | cur = que.front() 75 | que.pop() 76 | cur_label = pred[cur.first, cur.second] 77 | 78 | is_edge = True 79 | for j in range(4): 80 | tmpx = cur.first + dx[j] 81 | tmpy = cur.second + dy[j] 82 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 83 | continue 84 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 85 | continue 86 | if flag[cur_label] == 1 and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3: 87 | continue 88 | 89 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 90 | pred[tmpx, tmpy] = cur_label 91 | is_edge = False 92 | if is_edge: 93 | nxt_que.push(cur) 94 | 95 | que, nxt_que = nxt_que, que 96 | 97 | return pred 98 | 99 | def pse(kernels, emb, min_area): 100 | kernel_num = kernels.shape[0] 101 | _, cc = cv2.connectedComponents(kernels[0], connectivity=4) 102 | label_num, label = cv2.connectedComponents(kernels[1], connectivity=4) 103 | 104 | return _pse(kernels[:-1], emb, label, cc, kernel_num, label_num, min_area) -------------------------------------------------------------------------------- /PSENet_box_supervision/pse_pyx/pse.pyx.bak2_select: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | cimport numpy as np 5 | cimport cython 6 | cimport libcpp 7 | cimport libcpp.pair 8 | cimport libcpp.queue 9 | from libcpp.pair cimport * 10 | from libcpp.queue cimport * 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, 15 | np.ndarray[np.float32_t, ndim=3] emb, 16 | 
np.ndarray[np.int32_t, ndim=2] label, 17 | np.ndarray[np.int32_t, ndim=2] cc, 18 | int kernel_num, 19 | int label_num, 20 | float min_area=0): 21 | cdef np.ndarray[np.int32_t, ndim=2] pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 22 | cdef np.ndarray[np.float32_t, ndim=2] mean_emb = np.zeros((label_num, 4), dtype=np.float32) 23 | cdef np.ndarray[np.float32_t, ndim=1] area = np.full((label_num,), -1, dtype=np.float32) 24 | cdef np.ndarray[np.int32_t, ndim=1] flag = np.zeros((label_num,), dtype=np.int32) 25 | cdef np.ndarray[np.uint8_t, ndim=3] inds = np.zeros((label_num, label.shape[0], label.shape[1]), dtype=np.uint8) 26 | cdef np.ndarray[np.int32_t, ndim=2] p = np.zeros((label_num, 2), dtype=np.int32) 27 | 28 | cdef np.float32_t max_rate = 1024 29 | for i in range(1, label_num): 30 | ind = label == i 31 | inds[i] = ind 32 | 33 | area[i] = np.sum(ind) 34 | 35 | if area[i] < min_area: 36 | label[ind] = 0 37 | continue 38 | 39 | px, py = np.where(ind) 40 | p[i] = (px[0], py[0]) 41 | 42 | for j in range(1, i): 43 | if area[j] < min_area: 44 | continue 45 | if cc[p[i, 0], p[i, 1]] != cc[p[j, 0], p[j, 1]]: 46 | continue 47 | rate = area[i] / area[j] 48 | if rate < 1 / max_rate or rate > max_rate: 49 | flag[i] = 1 50 | mean_emb[i] = np.mean(emb[:, ind], axis=1) 51 | 52 | if flag[j] == 0: 53 | flag[j] = 1 54 | mean_emb[j] = np.mean(emb[:, inds[j].astype(np.bool)], axis=1) 55 | 56 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ 57 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 58 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ 59 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 60 | cdef np.int16_t* dx = [-1, 1, 0, 0] 61 | cdef np.int16_t* dy = [0, 0, -1, 1] 62 | cdef np.int16_t tmpx, tmpy 63 | 64 | points = np.array(np.where(label > 0)).transpose((1, 0)) 65 | for point_idx in range(points.shape[0]): 66 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 67 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 68 | pred[tmpx, tmpy] = label[tmpx, tmpy] 69 | 70 | cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur 71 | cdef int cur_label 72 | for kernel_idx in range(kernel_num - 1, -1, -1): 73 | while not que.empty(): 74 | cur = que.front() 75 | que.pop() 76 | cur_label = pred[cur.first, cur.second] 77 | 78 | is_edge = True 79 | for j in range(4): 80 | tmpx = cur.first + dx[j] 81 | tmpy = cur.second + dy[j] 82 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 83 | continue 84 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 85 | continue 86 | if flag[cur_label] == 1 and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3: 87 | continue 88 | 89 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 90 | pred[tmpx, tmpy] = cur_label 91 | is_edge = False 92 | if is_edge: 93 | nxt_que.push(cur) 94 | 95 | que, nxt_que = nxt_que, que 96 | 97 | return pred 98 | 99 | def pse(kernels, emb, min_area): 100 | kernel_num = kernels.shape[0] 101 | _, cc = cv2.connectedComponents(kernels[0], connectivity=4) 102 | label_num, label = cv2.connectedComponents(kernels[1], connectivity=4) 103 | 104 | return _pse(kernels[:-1], emb, label, cc, kernel_num, label_num, min_area) -------------------------------------------------------------------------------- /PSENet_box_supervision/pse_pyx/pse.pyx.bak3: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | cimport numpy as np 5 | cimport cython 6 | cimport 
libcpp 7 | cimport libcpp.pair 8 | cimport libcpp.queue 9 | from libcpp.pair cimport * 10 | from libcpp.queue cimport * 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, 15 | np.ndarray[np.float32_t, ndim=3] emb, 16 | np.ndarray[np.int32_t, ndim=2] label, 17 | int kernel_num, 18 | int label_num, 19 | float min_area=0): 20 | cdef np.ndarray[np.int32_t, ndim=2] pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 21 | cdef np.ndarray[np.float32_t, ndim=2] mean_emb = np.zeros((label_num, 4), dtype=np.float32) 22 | cdef np.ndarray[np.float32_t, ndim=1] area = np.full((label_num,), -1, dtype=np.float32) 23 | cdef np.ndarray[np.uint8_t, ndim=1] flag = np.zeros((label_num,), dtype=np.uint8) 24 | cdef np.ndarray[np.uint8_t, ndim=3] inds = np.zeros((label_num, label.shape[0], label.shape[1]), dtype=np.uint8) 25 | cdef np.ndarray[np.int32_t, ndim=2] rect = np.zeros((label_num, 4), dtype=np.int32) 26 | 27 | cdef np.float32_t max_rate = 64 28 | for i in range(1, label_num): 29 | ind = label == i 30 | inds[i] = ind 31 | 32 | area[i] = np.sum(ind) 33 | 34 | if area[i] < min_area: 35 | label[ind] = 0 36 | continue 37 | 38 | x, y = np.where(ind) 39 | rect[i] = (np.min(x), np.min(y), np.max(x), np.max(y)) 40 | 41 | for j in range(1, i): 42 | if area[j] < min_area: 43 | continue 44 | if 1 / max_rate <= area[i] / area[j] <= max_rate: 45 | continue 46 | if min(rect[i, 2], rect[j, 2]) < max(rect[i, 0], rect[j, 0]): 47 | continue 48 | if min(rect[i, 3], rect[j, 3]) < max(rect[i, 1], rect[j, 1]): 49 | continue 50 | 51 | flag[i] = 1 52 | mean_emb[i] = np.mean(emb[:, ind], axis=1) 53 | 54 | if flag[j] == 0: 55 | flag[j] = 1 56 | mean_emb[j] = np.mean(emb[:, inds[j].astype(np.bool)], axis=1) 57 | 58 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ 59 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 60 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ 61 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 62 | cdef np.int16_t* dx = [-1, 1, 0, 0] 63 | cdef np.int16_t* dy = [0, 0, -1, 1] 64 | cdef np.int16_t tmpx, tmpy 65 | 66 | points = np.array(np.where(label > 0)).transpose((1, 0)) 67 | for point_idx in range(points.shape[0]): 68 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 69 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 70 | pred[tmpx, tmpy] = label[tmpx, tmpy] 71 | 72 | cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur 73 | cdef int cur_label 74 | for kernel_idx in range(kernel_num - 1, -1, -1): 75 | while not que.empty(): 76 | cur = que.front() 77 | que.pop() 78 | cur_label = pred[cur.first, cur.second] 79 | 80 | is_edge = True 81 | for j in range(4): 82 | tmpx = cur.first + dx[j] 83 | tmpy = cur.second + dy[j] 84 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 85 | continue 86 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 87 | continue 88 | if flag[cur_label] == 1 and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3: 89 | continue 90 | 91 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 92 | pred[tmpx, tmpy] = cur_label 93 | is_edge = False 94 | if is_edge: 95 | nxt_que.push(cur) 96 | 97 | que, nxt_que = nxt_que, que 98 | 99 | return pred 100 | 101 | def pse(kernels, emb, min_area): 102 | kernel_num = kernels.shape[0] 103 | label_num, label = cv2.connectedComponents(kernels[1], connectivity=4) 104 | 105 | return _pse(kernels[:-1], emb, label, kernel_num, 
label_num, min_area) -------------------------------------------------------------------------------- /PSENet_box_supervision/pse_pyx/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup, Extension 2 | from Cython.Build import cythonize 3 | import numpy 4 | setup(ext_modules = cythonize(Extension( 5 | 'pse', 6 | sources=['pse.pyx'], 7 | language='c++', 8 | include_dirs=[numpy.get_include()], 9 | library_dirs=[], 10 | libraries=[], 11 | extra_compile_args=['-O3'], 12 | extra_link_args=[] 13 | ))) -------------------------------------------------------------------------------- /PSENet_box_supervision/test_icdar15.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import shutil 4 | import glob 5 | import time 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | from torch import nn 10 | import torch.utils.data as Data 11 | from torchvision import transforms 12 | import torchvision.utils as vutils 13 | from utils.utils import write_result_as_txt,debug,load_checkpoint, save_checkpoint, setup_logger 14 | from models import PSENet 15 | from evaluation.script import getresult 16 | from pse import decode as pse_decode 17 | from pse import decode_icdar17 as pse_decode_17 18 | from mmcv import Config 19 | import argparse 20 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 21 | 22 | 23 | def scale_image(img, short_size=800): 24 | h, w = img.shape[0:2] 25 | scale = short_size * 1.0 / min(h, w) 26 | max_scale = 3200.0 / max(h, w) 27 | scale = min(scale, max_scale) 28 | 29 | # img = cv2.resize(img, dsize=None, fx=scale, fy=scale) 30 | h = (int)(h * scale + 0.5) 31 | w = (int)(w * scale + 0.5) 32 | if h % 32 != 0: 33 | h = h + (32 - h % 32) 34 | if w % 32 != 0: 35 | w = w + (32 - w % 32) 36 | img = cv2.resize(img, dsize=(w, h)) 37 | return img 38 | 39 | 40 | def eval(model, config, device, thre, long_size=1280): 41 | model.eval() 42 | img_path = os.path.join(config.testroot, 'test_img') 43 | save_path = os.path.join(config.workspace, 'output_eval') 44 | 45 | # if os.path.exists(save_path): 46 | # shutil.rmtree(save_path, ignore_errors=True) 47 | if not os.path.exists(save_path): 48 | os.makedirs(save_path) 49 | long_size = 2240 50 | # Run prediction on every test image 51 | img_paths = [os.path.join(img_path, x) for x in os.listdir(img_path)] 52 | for img_path in tqdm(img_paths, desc='test models'): 53 | img_name = os.path.basename(img_path).split('.')[0] 54 | 55 | save_name = os.path.join(save_path, 'res_' + img_name + '.txt') 56 | 57 | assert os.path.exists(img_path), 'file does not exist' 58 | img = cv2.imread(img_path) 59 | h, w = img.shape[:2] 60 | scale = long_size / max(h, w) 61 | img = cv2.resize(img, None, fx=scale, fy=scale) 62 | # Convert the image from (w,h) to (1,img_channel,h,w) 63 | tensor = transforms.ToTensor()(img) 64 | tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor) 65 | tensor = tensor.unsqueeze_(0) 66 | tensor = tensor.to(device) 67 | with torch.no_grad(): 68 | preds = model(tensor) 69 | preds, boxes_list = pse_decode(preds[0], config.scale) 70 | scale = (preds.shape[1] * 1.0 / w, preds.shape[0] * 1.0 / h) 71 | if len(boxes_list): 72 | boxes_list = boxes_list / scale 73 | np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d') 74 | # Compute recall, precision and f1 75 | 76 | methodHmean,methodPrecision,methodRecall = getresult(save_path, config.gt_name) 77 | print("precision: {} , recall: {}, f1: {}".format(methodPrecision,methodRecall,methodHmean)) 78 |
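# Shape bookkeeping in eval() above, sketched for a hypothetical 1280x720 image:
# with long_size = 2240, scale = 2240 / 1280 = 1.75, so the network input is
# 2240x1260; pse_decode then returns boxes on the prediction map, and since
# scale = (preds.shape[1] / w, preds.shape[0] / h) relates that map to the
# original image, dividing boxes_list by it restores original coordinates
# before the res_<name>.txt files are written for the ICDAR evaluation script.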
{}".format(methodPrecision,methodRecall,methodHmean)) 78 | 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser(description='test a model') 82 | parser.add_argument('--config', type=str, default="./config/icdar15/icdar15_ST.py", help='') 83 | parser.add_argument('--resume_from', '-r', type=str, 84 | default="/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet/workdirs/icdar15/psenet_ST/Best_model_0.863204.pth", 85 | help='') 86 | parser.add_argument('--vis', action='store_true', help='') 87 | args = parser.parse_args() 88 | 89 | config = Config.fromfile(args.config) 90 | config.workspace = os.path.join(config.workspace_dir, config.exp_name) 91 | config.checkpoint = args.resume_from 92 | config.visualization = config.visualization 93 | 94 | if not os.path.exists(config.workspace): 95 | os.makedirs(config.workspace) 96 | logger = setup_logger(os.path.join(config.workspace, 'test_log')) 97 | logger.info(config.print()) 98 | 99 | model = PSENet(backbone=config.backbone, 100 | pretrained=config.pretrained, 101 | result_num=config.kernel_num, 102 | scale=config.scale) 103 | num_gpus = torch.cuda.device_count() 104 | device = torch.device("cuda:0") 105 | model = nn.DataParallel(model) 106 | model = model.to(device) 107 | state = torch.load(config.checkpoint) 108 | model.load_state_dict(state['state_dict']) 109 | logger.info('test epoch {}'.format(state['epoch'])) 110 | 111 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 112 | 113 | start_epoch = load_checkpoint(config.checkpoint, model, logger, device, optimizer) 114 | for i in range(60,100,2): 115 | thre = i*0.01 116 | print(thre) 117 | eval(model, config, device,thre) 118 | -------------------------------------------------------------------------------- /PSENet_box_supervision/test_msra.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import shutil 4 | import glob 5 | import time 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | from torch import nn 10 | import torch.utils.data as Data 11 | from torchvision import transforms 12 | from utils.utils import write_result_as_txt,debug,load_checkpoint, save_checkpoint, setup_logger 13 | from models import PSENet 14 | from evaluation.script import getresult 15 | from pse import decode_msra as pse_decode 16 | from mmcv import Config 17 | import argparse 18 | from evaluation.msra.eval import get_msra_result 19 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 20 | 21 | 22 | def scale_image(img, short_size=704): 23 | h, w = img.shape[0:2] 24 | scale = short_size * 1.0 / min(h, w) 25 | max_scale = 3200.0 / max(h, w) 26 | scale = min(scale, max_scale) 27 | 28 | # img = cv2.resize(img, dsize=None, fx=scale, fy=scale) 29 | h = (int)(h * scale + 0.5) 30 | w = (int)(w * scale + 0.5) 31 | if h % 32 != 0: 32 | h = h + (32 - h % 32) 33 | if w % 32 != 0: 34 | w = w + (32 - w % 32) 35 | img = cv2.resize(img, dsize=(w, h)) 36 | return img 37 | 38 | 39 | def eval(model, save_path, test_path, device,threshold): 40 | model.eval() 41 | 42 | save_path = os.path.join(save_path, "submit") 43 | os.makedirs(save_path, exist_ok=True) 44 | img_path_root = os.path.join(test_path, "MSRA-TD500", "test") 45 | 46 | # 预测所有测试图片 47 | img_paths = [os.path.join(img_path_root, x) for x in os.listdir(img_path_root)] 48 | for img_path in tqdm(img_paths, desc='test models'): 49 | if not img_path.endswith('.JPG') and not img_path.endswith('.jpg'): 50 | continue 51 | img_name = os.path.basename(img_path).split('.')[0] 52 | 
save_name = os.path.join(save_path, 'res_' + img_name + '.txt') 53 | 54 | assert os.path.exists(img_path), 'file does not exist' 55 | img = cv2.imread(img_path) 56 | org_img = img.copy() 57 | h, w = img.shape[:2] 58 | 59 | img = scale_image(img) 60 | # Convert the image from (w,h) to (1,img_channel,h,w) 61 | tensor = transforms.ToTensor()(img) 62 | tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor) 63 | tensor = tensor.unsqueeze_(0) 64 | tensor = tensor.to(device) 65 | with torch.no_grad(): 66 | preds = model(tensor) 67 | preds, boxes_list = pse_decode(preds[0], config.scale, org_img) 68 | 69 | np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d') 70 | # Compute recall, precision and f1 71 | f_score_new = get_msra_result(save_path, img_path_root) 72 | return f_score_new 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser(description='test a model') 76 | parser.add_argument('--config', type=str, default="./config/msra/msra_baseline.py", help='') 77 | parser.add_argument('--resume_from', '-r', type=str, 78 | default="/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet_box_supervision/workspace/msra/msra_pseudo/Best_model_0.794250.pth", 79 | help='') 80 | parser.add_argument('--vis', action='store_true', help='') 81 | args = parser.parse_args() 82 | 83 | config = Config.fromfile(args.config) 84 | config.workspace = os.path.join(config.workspace_dir, config.exp_name) 85 | config.checkpoint = args.resume_from 86 | config.visualization = config.visualization 87 | 88 | if not os.path.exists(config.workspace): 89 | os.makedirs(config.workspace) 90 | logger = setup_logger(os.path.join(config.workspace, 'test_log')) 91 | logger.info(config.print()) 92 | 93 | model = PSENet(backbone=config.backbone, 94 | pretrained=config.pretrained, 95 | result_num=config.kernel_num, 96 | scale=config.scale) 97 | num_gpus = torch.cuda.device_count() 98 | device = torch.device("cuda:0") 99 | model = nn.DataParallel(model) 100 | model = model.to(device) 101 | state = torch.load(config.checkpoint) 102 | model.load_state_dict(state['state_dict']) 103 | logger.info('test epoch {}'.format(state['epoch'])) 104 | 105 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 106 | 107 | start_epoch = load_checkpoint(config.checkpoint, model, logger, device, optimizer) 108 | for i in range(70,80): 109 | thre = i*0.01 110 | print(thre) 111 | eval(model, config.workspace, config.msra_path, device, thre) 112 | # eval(model, config, device,thre) 113 | break -------------------------------------------------------------------------------- /PSENet_box_supervision/utils/.ipynb_checkpoints/config_icdar15-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # @Author : zhoujun 4 | 5 | 6 | # data config 7 | exp_name = "ICDAR15" 8 | trainroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2015/' 9 | testroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2015/' 10 | workspace_dir = '/data/glusterfs_cv_04/11121171/CVPR_Text/PSENet/workspace' 11 | workspace = "" 12 | gt_name = "icd15_gt.zip" 13 | data_shape = 640 14 | 15 | # train config 16 | gpu_id = '0' 17 | workers = 10 18 | start_epoch = 0 19 | epochs = 600 20 | 21 | train_batch_size = 10 22 | 23 | lr = 1e-4 24 | end_lr = 1e-7 25 | lr_gamma = 0.1 26 | lr_decay_step = [200,400] 27 | weight_decay = 5e-4 28 | warm_up_epoch = 6 29 | warm_up_lr = lr * lr_gamma 30 | 31 | display_input_images = False 32 | display_output_images = False 33 |
display_interval = 10 34 | show_images_interval = 50 35 | 36 | pretrained = True 37 | restart_training = True 38 | checkpoint = '' 39 | 40 | # net config 41 | backbone = 'resnet50' 42 | Lambda = 0.7 43 | n = 6 44 | m = 0.5 45 | OHEM_ratio = 3 46 | scale = 1 47 | # random seed 48 | seed = 2 49 | 50 | 51 | def print(): 52 | from pprint import pformat 53 | tem_d = {} 54 | for k, v in globals().items(): 55 | if not k.startswith('_') and not callable(v): 56 | tem_d[k] = v 57 | return pformat(tem_d) 58 | -------------------------------------------------------------------------------- /PSENet_box_supervision/utils/.ipynb_checkpoints/config_icdar17-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # @Author : zhoujun 4 | 5 | 6 | # data config 7 | exp_name = "ICDAR17" 8 | trainroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2017/' 9 | testroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2017/' 10 | workspace_dir = '/data/glusterfs_cv_04/11121171/CVPR_Text/PSENet/workspace' 11 | workspace = "" 12 | gt_name = "icd17_gt.zip" 13 | data_shape = 800 14 | 15 | # train config 16 | gpu_id = '0,1' 17 | workers = 10 18 | start_epoch = 0 19 | epochs = 600 20 | 21 | train_batch_size = 24 22 | 23 | lr = 1e-4 24 | end_lr = 1e-7 25 | lr_gamma = 0.1 26 | lr_decay_step = [100,200] 27 | weight_decay = 5e-4 28 | warm_up_epoch = 6 29 | warm_up_lr = lr * lr_gamma 30 | 31 | display_input_images = False 32 | display_output_images = False 33 | display_interval = 10 34 | show_images_interval = 50 35 | 36 | pretrained = True 37 | restart_training = False 38 | checkpoint = '' 39 | 40 | # net config 41 | backbone = 'resnet50' 42 | Lambda = 0.7 43 | n = 6 44 | m = 0.5 45 | OHEM_ratio = 3 46 | scale = 1 47 | # random seed 48 | seed = 2 49 | 50 | 51 | def print(): 52 | from pprint import pformat 53 | tem_d = {} 54 | for k, v in globals().items(): 55 | if not k.startswith('_') and not callable(v): 56 | tem_d[k] = v 57 | return pformat(tem_d) 58 | -------------------------------------------------------------------------------- /PSENet_box_supervision/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3/22/19 11:45 AM 3 | # @Author : zhoujun 4 | from .utils import * -------------------------------------------------------------------------------- /PSENet_box_supervision/utils/config_icdar15.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # data config 4 | exp_name = "ICDAR15" 5 | trainroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2015/' 6 | testroot = '/data/glusterfs_cv_04/11121171/data/ICDAR2015/' 7 | workspace_dir = '/data/glusterfs_cv_04/11121171/CVPR_Text/PSENet/workspace' 8 | workspace = "" 9 | gt_name = "icd15_gt.zip" 10 | data_shape = 640 11 | 12 | # train config 13 | gpu_id = '0,1' 14 | workers = 10 15 | start_epoch = 0 16 | epochs = 600 17 | 18 | train_batch_size = 16 19 | 20 | lr = 1e-4 21 | end_lr = 1e-7 22 | lr_gamma = 0.1 23 | lr_decay_step = [100,200] 24 | weight_decay = 5e-4 25 | warm_up_epoch = 6 26 | warm_up_lr = lr * lr_gamma 27 | 28 | display_input_images = False 29 | display_output_images = False 30 | display_interval = 10 31 | show_images_interval = 50 32 | save_interval=5 33 | 34 | pretrained = True 35 | restart_training = True 36 | checkpoint = '' 37 | 38 | # net config 39 | backbone = 'resnet50' 40 | Lambda = 0.7 
41 | n = 6 42 | m = 0.5 43 | OHEM_ratio = 3 44 | scale = 1 45 | # random seed 46 | seed = 2 47 | 48 | 49 | def print(): 50 | from pprint import pformat 51 | tem_d = {} 52 | for k, v in globals().items(): 53 | if not k.startswith('_') and not callable(v): 54 | tem_d[k] = v 55 | return pformat(tem_d) 56 | -------------------------------------------------------------------------------- /PSENet_box_supervision/utils/config_synthtext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | 4 | 5 | # data config 6 | exp_name = "Synthtext" 7 | trainroot = '/home/xjc/Dataset/SynthText/' 8 | testroot = '/home/xjc/Dataset/SynthText/' 9 | workspace_dir = '/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet/workdirs/' 10 | workspace = "" 11 | gt_name = "" 12 | data_shape = 640 13 | 14 | # train config 15 | gpu_id = '2,3' 16 | workers = 10 17 | start_epoch = 0 18 | epochs = 600 19 | 20 | train_batch_size = 14 21 | 22 | lr = 1e-4 23 | end_lr = 1e-7 24 | lr_gamma = 0.1 25 | lr_decay_step = [200,400] 26 | weight_decay = 5e-4 27 | warm_up_epoch = 6 28 | warm_up_lr = lr * lr_gamma 29 | 30 | display_input_images = False 31 | display_output_images = False 32 | display_interval = 10 33 | show_images_interval = 50 34 | 35 | pretrained = True 36 | restart_training = False 37 | checkpoint = '' 38 | 39 | # net config 40 | backbone = 'resnet50' 41 | Lambda = 0.7 42 | n = 7 43 | m = 0.5 44 | OHEM_ratio = 3 45 | scale = 1 46 | # random seed 47 | seed = 2 48 | 49 | 50 | def print(): 51 | from pprint import pformat 52 | tem_d = {} 53 | for k, v in globals().items(): 54 | if not k.startswith('_') and not callable(v): 55 | tem_d[k] = v 56 | return pformat(tem_d) 57 | -------------------------------------------------------------------------------- /PSENet_box_supervision/utils/config_totaltext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # data config 4 | exp_name = "Total_Text" 5 | trainroot = '/mnt/lustre/share/xieenze/Text/total_text/' 6 | testroot = '/mnt/lustre/share/xieenze/Text/total_text/' 7 | workspace_dir = './workdirs/' 8 | workspace = "./workdirs/eval" 9 | data_shape = 640 10 | 11 | # train config 12 | gpu_id = '0,1,2,3' 13 | workers = 10 14 | start_epoch = 0 15 | epochs = 100 16 | 17 | train_batch_size = 20 18 | 19 | lr = 1e-4 20 | end_lr = 1e-7 21 | lr_gamma = 0.1 22 | lr_decay_step = [60,80] 23 | weight_decay = 5e-4 24 | warm_up_epoch = 6 25 | warm_up_lr = lr * lr_gamma 26 | 27 | display_input_images = False 28 | display_output_images = False 29 | visualization = False 30 | display_interval = 10 31 | show_images_interval = 50 32 | 33 | pretrained = True 34 | restart_training = False 35 | checkpoint = '' 36 | save_interval = 5 37 | 38 | # net config 39 | backbone = 'resnet50' 40 | Lambda = 0.7 41 | kernel_num = 7 42 | min_scale = 0.4 43 | OHEM_ratio = 3 44 | scale = 1 45 | min_kernel_area = 10.0 46 | # random seed 47 | seed = 2 48 | binary_th = 1.0 49 | 50 | 51 | def pprint(): 52 | from pprint import pformat 53 | tem_d = {} 54 | for k, v in globals().items(): 55 | if not k.startswith('_') and not callable(v): 56 | tem_d[k] = v 57 | return pformat(tem_d) 58 | -------------------------------------------------------------------------------- /PSENet_box_supervision/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/19/19 3:37 PM 3 | # @Author : zhoujun 4 | 
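# A minimal usage sketch (assumed values, mirroring the configs above) of the
# WarmupMultiStepLR defined below: the LR factor ramps linearly from
# warmup_factor to 1 over the first warmup_iters steps (with warmup_factor=1/3
# and warmup_iters=6 it is x1/3 at step 0 and x2/3 at step 3), after which the
# schedule behaves like a plain MultiStepLR(milestones, gamma):
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   scheduler = WarmupMultiStepLR(optimizer, milestones=[200, 400], gamma=0.1,
#                                 warmup_factor=1.0 / 3, warmup_iters=6)
#   for epoch in range(epochs):
#       train_one_epoch(...)  # hypothetical training step
#       scheduler.step()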
from torch.optim.lr_scheduler import MultiStepLR 5 | 6 | 7 | class WarmupMultiStepLR(MultiStepLR): 8 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, 9 | warmup_iters=500, last_epoch=-1): 10 | self.warmup_factor = warmup_factor 11 | self.warmup_iters = warmup_iters 12 | super().__init__(optimizer, milestones, gamma, last_epoch) 13 | 14 | def get_lr(self): 15 | lr = super().get_lr() 16 | if self.last_epoch < self.warmup_iters: 17 | alpha = self.last_epoch / self.warmup_iters 18 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 19 | return [l * warmup_factor for l in lr] 20 | return lr -------------------------------------------------------------------------------- /TextBoxSeg/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # mac os 107 | __MACOSX/ 108 | 109 | #model 110 | trash 111 | workdirs 112 | *.bak 113 | datasets 114 | -------------------------------------------------------------------------------- /TextBoxSeg/configs/textseg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | ROOT: "data/st800k_crop" 3 | NAME: "st800k" 4 | MEAN: [0.485, 0.456, 0.406] 5 | STD: [0.229, 0.224, 0.225] 6 | TRAIN: 7 | EPOCHS: 16 8 | BATCH_SIZE: 8 9 | BASE_SIZE: (128,32) 10 | MODEL_SAVE_DIR: 'workdirs/debug' 11 | TEST: 12 | BATCH_SIZE: 1 13 | TEST_MODEL_PATH: 'workdirs/debug/16.pth' 14 | 15 | SOLVER: 16 | LR: 0.02 17 | 18 | MODEL: 19 | MODEL_NAME: "TextSeg" 20 | BACKBONE: "resnet50" 21 | 22 | -------------------------------------------------------------------------------- /TextBoxSeg/configs/textseg2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | ROOT: "/data/glusterfs_cv_04/11121171/data/SynthText/" 3 | NAME: "st800k" 4 | MEAN: [0.485, 0.456, 
0.406] 5 | STD: [0.229, 0.224, 0.225] 6 | TRAIN: 7 | EPOCHS: 50 8 | BATCH_SIZE: 24 9 | BASE_SIZE: (128,32) 10 | MODEL_SAVE_DIR: '/data/glusterfs_cv_04/11121171/CVPR_Text/SemiText/TextBoxSeg/workdirs' 11 | TEST: 12 | BATCH_SIZE: 1 13 | TEST_MODEL_PATH: '/data/glusterfs_cv_04/11121171/CVPR_Text/SemiText/TextBoxSeg/workdirs/50_32.pth' 14 | 15 | SOLVER: 16 | LR: 0.02 17 | 18 | MODEL: 19 | MODEL_NAME: "TextSeg" 20 | BACKBONE: "resnet50" 21 | 22 | -------------------------------------------------------------------------------- /TextBoxSeg/configs/textseg_total.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | ROOT: "/data/data_weijiawu/CurvedSynthText/SegData/" 3 | # st800k_attention , st800k_total 4 | NAME: "st800k_attention" 5 | MEAN: [0.485, 0.456, 0.406] 6 | STD: [0.229, 0.224, 0.225] 7 | TRAIN: 8 | EPOCHS: 100 9 | BATCH_SIZE: 24 10 | BASE_SIZE: (128,128) # (160,128) or (128,128) 11 | MODEL_SAVE_DIR: '../workdirs/' 12 | 13 | TEST: 14 | BATCH_SIZE: 1 15 | TEST_MODEL_PATH: '../workdirs/50_attention.pth' 16 | #50_attention.pth 17 | 18 | SOLVER: 19 | LR: 0.02 20 | 21 | MODEL: 22 | # model_name textseg_attention or TextSeg 23 | MODEL_NAME: "textseg_attention" 24 | BACKBONE: "resnet50" 25 | 26 | -------------------------------------------------------------------------------- /TextBoxSeg/data/st800k_crop: -------------------------------------------------------------------------------- 1 | /mnt/lustre/share_data/xieenze/xez_space/Text/st800k_crop -------------------------------------------------------------------------------- /TextBoxSeg/data/st800k_crop2: -------------------------------------------------------------------------------- 1 | /mnt/lustre/share_data/xieenze/xez_space/Text/st800k_crop2 -------------------------------------------------------------------------------- /TextBoxSeg/demo/curved_st800k_crop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import glob 5 | from tqdm import tqdm 6 | import scipy.io as sio 7 | import matplotlib.pyplot as plt 8 | import time 9 | import random 10 | # from IPython import embed 11 | import re 12 | import itertools 13 | import json 14 | def readimg(path): 15 | return cv2.imread(path) 16 | 17 | 18 | def show(img): 19 | if len(img.shape) == 3: 20 | return plt.imshow(img[:, :, ::-1]) 21 | else: 22 | return plt.imshow(img) 23 | 24 | 25 | def makedirs(path): 26 | """Create a directory recursively if it does not exist. 27 | Similar to mkdir -p, you can skip checking existence before this function.
28 | Parameters 29 | ---------- 30 | path : str 31 | Path of the desired dir 32 | """ 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | else: 36 | pass 37 | 38 | 39 | def get_patchs(img, char_boxs): 40 | img_h, img_w, _ = img.shape 41 | 42 | char_boxs = np.array(char_boxs).transpose(2, 1, 0) 43 | 44 | contours = [] 45 | for c in char_boxs: 46 | contours.append(c[0]) 47 | contours.append(c[1]) 48 | for c in char_boxs[::-1, :, :]: 49 | contours.append(c[2]) 50 | contours.append(c[3]) 51 | contours = np.array(contours) 52 | 53 | mask = np.zeros((img_h, img_w)) 54 | cv2.fillPoly(mask, [contours], 1) 55 | 56 | return img, mask 57 | 58 | 59 | if __name__ == "__main__": 60 | st800k_path = '/data/data_weijiawu/CurvedSynthText/' 61 | save_path = '/data/data_weijiawu/CurvedSynthText/SegData/' 62 | # print(len(os.listdir(os.path.join(save_path,"image")))) 63 | # raise NameError 64 | max_num = 500000 65 | for sub in ('image', 'mask'): # ensure save_path/image and save_path/mask exist; cv2.imwrite fails silently without them 66 | makedirs(os.path.join(save_path, sub)) 67 | json_filename = '{}/Label.json'.format(st800k_path) 68 | 69 | print('load json label..') 70 | with open(json_filename) as f: 71 | pop_data = json.load(f) 72 | print('json label loaded') 73 | 74 | shuffle_ids = [i for i in range(len(pop_data))] 75 | random.shuffle(shuffle_ids) 76 | save_number = 0 77 | total_p_imgs, total_p_masks = [], [] 78 | print('generate patch...') 79 | for img_idx in tqdm(shuffle_ids): 80 | assert len(total_p_imgs) == len(total_p_masks) 81 | 82 | data = pop_data[img_idx] 83 | char_box = data["chars"] 84 | 85 | file_number,_,image_name = data["img"].split("/") 86 | file_number = file_number.zfill(4) 87 | image_path = os.path.join(st800k_path,"img/images",file_number,image_name) 88 | 89 | if os.path.exists(image_path): 90 | image = readimg(image_path) 91 | try: 92 | patch_imgs, patch_masks = get_patchs(image,char_box) 93 | except Exception: # malformed char boxes make transpose/fillPoly throw; skip such samples 94 | continue 95 | 96 | 97 | img_path = os.path.join(save_path, 'image', '{}.png'.format(save_number)) 98 | mask_path = os.path.join(save_path, 'mask', '{}.png'.format(save_number)) 99 | # print(img_path) 100 | cv2.imwrite(img_path, patch_imgs) 101 | cv2.imwrite(mask_path, patch_masks) 102 | save_number += 1 103 | if save_number >= max_num: 104 | break 105 | 106 | # if min(img.shape[:2]) < 20: 107 | # continue 108 | 109 | # print('save images...') 110 | # debug = False 111 | # for i in tqdm(range(len(total_p_imgs))): 112 | # img = total_p_imgs[i] 113 | # mask = total_p_masks[i] 114 | # if min(img.shape[:2]) < 20: 115 | # continue 116 | # if debug: 117 | # print('debug vis') 118 | # mask *= 255 119 | # img_path = os.path.join(save_path, 'image', '{}.png'.format(i)) 120 | # mask_path = os.path.join(save_path, 'mask', '{}.png'.format(i)) 121 | # 122 | # cv2.imwrite(img_path, img) 123 | # cv2.imwrite(mask_path, mask) -------------------------------------------------------------------------------- /TextBoxSeg/demo/imgs/text1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/demo/imgs/text1.png -------------------------------------------------------------------------------- /TextBoxSeg/demo/imgs/text10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/demo/imgs/text10.png
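A minimal loader sketch (not a repository file) for the image/mask pairs that the crop scripts in this demo folder write out; it assumes the save layout they produce, save_path/image/{i}.png plus save_path/mask/{i}.png with mask values in {0, 1}, and SAVE_ROOT below is a placeholder path:

import os
import cv2

SAVE_ROOT = '/path/to/st800k_crop'  # placeholder; point this at the save_path used by the script

def load_pair(i):
    # crops are BGR images; masks are single-channel with values in {0, 1}
    img = cv2.imread(os.path.join(SAVE_ROOT, 'image', '{}.png'.format(i)))
    mask = cv2.imread(os.path.join(SAVE_ROOT, 'mask', '{}.png'.format(i)), cv2.IMREAD_GRAYSCALE)
    return img, mask

img, mask = load_pair(0)
overlay = img.copy()
overlay[mask > 0] = (0, 255, 0)  # paint text pixels green
cv2.imwrite('overlay_0.png', cv2.addWeighted(img, 0.5, overlay, 0.5, 0))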
-------------------------------------------------------------------------------- [The remaining demo images -- TextBoxSeg/demo/imgs/text2.png-text9.png, text11.png-text16.png, and TextBoxSeg/demo/result/text1.png-text16.png -- are binary PNGs; each entry follows the same pattern as the two above, a link to https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/demo/<imgs|result>/<name>.png, and is omitted here.] --------------------------------------------------------------------------------
https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/demo/result/text7.png -------------------------------------------------------------------------------- /TextBoxSeg/demo/result/text8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/demo/result/text8.png -------------------------------------------------------------------------------- /TextBoxSeg/demo/result/text9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/demo/result/text9.png -------------------------------------------------------------------------------- /TextBoxSeg/demo/st800k_crop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import glob 5 | from tqdm import tqdm 6 | import scipy.io as sio 7 | import matplotlib.pyplot as plt 8 | import time 9 | import random 10 | from IPython import embed 11 | 12 | def readimg(path): 13 | return cv2.imread(path) 14 | 15 | def show(img): 16 | if len(img.shape) == 3: 17 | return plt.imshow(img[:,:,::-1]) 18 | else: 19 | return plt.imshow(img) 20 | 21 | def makedirs(path): 22 | """Create directory recursively if not exists. 23 | Similar to `makedir -p`, you can skip checking existence before this function. 24 | Parameters 25 | ---------- 26 | path : str 27 | Path of the desired dir 28 | """ 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | else: 32 | pass 33 | 34 | 35 | def get_patchs(img, r_boxes): 36 | patch_imgs, patch_masks = [], [] 37 | img_h, img_w, _ = img.shape 38 | for i, r_box in enumerate(r_boxes): 39 | rect = cv2.boundingRect(r_box) 40 | x,y,w,h = rect 41 | # cv2.rectangle(img, (x, y), (x+w, y+h), (0,255,0), 2) 42 | mask = np.zeros((img_h,img_w)) 43 | cv2.fillPoly(mask, [r_box], 1) 44 | 45 | patch_img = img[y:y+h+1,x:x+w+1]; patch_imgs.append(patch_img) 46 | patch_mask = mask[y:y+h+1,x:x+w+1]; patch_masks.append(patch_mask) 47 | 48 | return patch_imgs, patch_masks 49 | 50 | 51 | 52 | 53 | if __name__ == "__main__": 54 | st800k_path = '/mnt/lustre/share_data/xieenze/xez_space/Text/SynthText' 55 | save_path = '/mnt/lustre/share_data/xieenze/xez_space/Text/st800k_crop' 56 | max_num = 10000 57 | 58 | gt_mat = '{}/gt.mat'.format(st800k_path) 59 | 60 | print('load mat...') 61 | mat = sio.loadmat(gt_mat) 62 | print('mat loaded') 63 | 64 | charBB, wordBB = mat['charBB'][0], mat['wordBB'][0] 65 | img_names = mat['imnames'][0] 66 | img_paths = [os.path.join(st800k_path, i[0]) for i in img_names] 67 | 68 | shuffle_ids = [i for i in range(len(img_paths))] 69 | random.shuffle(shuffle_ids) 70 | 71 | total_p_imgs, total_p_masks = [], [] 72 | print('generate patch...') 73 | for img_idx in tqdm(shuffle_ids[:max_num]): 74 | assert len(total_p_imgs) == len(total_p_masks) 75 | img = readimg(img_paths[img_idx]) 76 | if len(wordBB[img_idx].shape) == 2: 77 | continue 78 | r_boxes = wordBB[img_idx].transpose(2, 1, 0) 79 | r_boxes = np.array(r_boxes, dtype='int32') 80 | p_imgs, p_masks = get_patchs(img, r_boxes) 81 | total_p_imgs.extend(p_imgs) 82 | total_p_masks.extend(p_masks) 83 | 84 | 85 | 86 | 
print('save images...') 87 | debug = False 88 | for i in tqdm(range(len(total_p_imgs))): 89 | img = total_p_imgs[i] 90 | mask = total_p_masks[i] 91 | if min(img.shape[:2]) < 20: 92 | continue 93 | if debug: 94 | print('debug vis') 95 | mask *= 255 96 | img_path = os.path.join(save_path, 'image', '{}.png'.format(i)) 97 | mask_path = os.path.join(save_path, 'mask', '{}.png'.format(i)) 98 | cv2.imwrite(img_path, img) 99 | cv2.imwrite(mask_path, mask) -------------------------------------------------------------------------------- /TextBoxSeg/demo/st800k_crop2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import glob 5 | from tqdm import tqdm 6 | import scipy.io as sio 7 | import matplotlib.pyplot as plt 8 | import time 9 | import random 10 | from IPython import embed 11 | import re 12 | import itertools 13 | 14 | def readimg(path): 15 | return cv2.imread(path) 16 | 17 | 18 | def show(img): 19 | if len(img.shape) == 3: 20 | return plt.imshow(img[:, :, ::-1]) 21 | else: 22 | return plt.imshow(img) 23 | 24 | 25 | def makedirs(path): 26 | """Create directory recursively if not exists. 27 | Similar to `mkdir -p`, you can skip checking existence before this function. 28 | Parameters 29 | ---------- 30 | path : str 31 | Path of the desired dir 32 | """ 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | else: 36 | pass 37 | 38 | 39 | def get_patchs(img, r_box, char_boxs): 40 | patch_imgs, patch_masks = [], [] 41 | img_h, img_w, _ = img.shape 42 | contours = [] 43 | for c in char_boxs: 44 | contours.append(c[0]) 45 | contours.append(c[1]) 46 | for c in char_boxs[::-1, :, :]: 47 | contours.append(c[2]) 48 | contours.append(c[3]) 49 | contours = np.array(contours) 50 | 51 | rect = cv2.boundingRect(r_box) 52 | x, y, w, h = rect 53 | mask = np.zeros((img_h, img_w)) 54 | cv2.fillPoly(mask, [contours], 1) 55 | 56 | patch_img = img[y:y + h + 1, x:x + w + 1] 57 | patch_imgs.append(patch_img) 58 | patch_mask = mask[y:y + h + 1, x:x + w + 1] 59 | patch_masks.append(patch_mask) 60 | 61 | return patch_imgs, patch_masks 62 | 63 | 64 | if __name__ == "__main__": 65 | st800k_path = '/data/glusterfs_cv_04/11121171/data/SynthText/' 66 | save_path = '/data/glusterfs_cv_04/11121171/data/SynthText/' 67 | max_num = 15000 68 | 69 | gt_mat = '{}/gt.mat'.format(st800k_path) 70 | 71 | print('load mat...') 72 | mat = sio.loadmat(gt_mat) 73 | print('mat loaded') 74 | 75 | charBB, wordBB = mat['charBB'][0], mat['wordBB'][0] 76 | wordText = mat['txt'][0] 77 | 78 | img_names = mat['imnames'][0] 79 | img_paths = [os.path.join(st800k_path, i[0]) for i in img_names] 80 | 81 | shuffle_ids = [i for i in range(len(img_paths))] 82 | random.shuffle(shuffle_ids) 83 | save_number = 0 84 | total_p_imgs, total_p_masks = [], [] 85 | print('generate patch...') 86 | for img_idx in tqdm(shuffle_ids[:max_num]): 87 | assert len(total_p_imgs) == len(total_p_masks) 88 | img = readimg(img_paths[img_idx]) 89 | 90 | if len(wordBB[img_idx].shape) == 2: 91 | continue 92 | 93 | r_boxes = wordBB[img_idx].transpose(2, 1, 0) 94 | r_boxes = np.array(r_boxes, dtype='int32') 95 | 96 | char_boxes = np.array(charBB[img_idx].transpose(2, 1, 0),dtype='int32') 97 | text = wordText[img_idx] 98 | words = [re.split(' \n|\n |\n| ', t.strip()) for t in text] 99 | words = list(itertools.chain(*words)) 100 | words = [t for t in words if len(t) > 0] 101 | 102 | total = 0 103 | for i in range(len(words)): 104 | r_box = r_boxes[i] 105 | char_box =
char_boxes[total:total+len(words[i])] 106 | assert (len(char_box) == len(words[i])) 107 | total += len(words[i]) 108 | patch_img, patch_mask = get_patchs(img, r_box, char_box) 109 | 110 | 111 | # total_p_imgs.extend(patch_img) 112 | # total_p_masks.extend(patch_mask) 113 | for y in range(len(patch_img)): 114 | if min(patch_img[y].shape[:2]) < 20: # filter tiny crops; the original tested img, the full image, by mistake 115 | continue 116 | img_path = os.path.join(save_path, 'image', '{}.png'.format(save_number)) 117 | mask_path = os.path.join(save_path, 'mask', '{}.png'.format(save_number)) 118 | # print(img_path) 119 | cv2.imwrite(img_path, patch_img[y]) 120 | cv2.imwrite(mask_path, patch_mask[y]) 121 | save_number += 1 122 | 123 | # print('save images...') 124 | # debug = False 125 | # for i in tqdm(range(len(total_p_imgs))): 126 | # img = total_p_imgs[i] 127 | # mask = total_p_masks[i] 128 | # if min(img.shape[:2]) < 20: 129 | # continue 130 | # if debug: 131 | # print('debug vis') 132 | # mask *= 255 133 | # img_path = os.path.join(save_path, 'image', '{}.png'.format(i)) 134 | # mask_path = os.path.join(save_path, 'mask', '{}.png'.format(i)) 135 | # 136 | # cv2.imwrite(img_path, img) 137 | # cv2.imwrite(mask_path, mask) -------------------------------------------------------------------------------- /TextBoxSeg/demo/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | print(os.listdir("/home/xjc/Dataset/Curved_SynthText/img/Data")[:100]) 3 | print(len(os.listdir("/home/xjc/Dataset/Curved_SynthText/img/Data"))) -------------------------------------------------------------------------------- /TextBoxSeg/readme.md: -------------------------------------------------------------------------------- 1 | (1) Generate the word-segmentation SynthText GT 2 | Crop SynthText, then generate the word mask GT. 3 | demo/st800k_crop2.py (the mask GT is stitched together from the char boxes). 4 | demo/st800k_crop.py (the mask GT is the default SynthText word GT). 5 | 6 | 7 | (2) Train the segmentation network 8 | sh tools/dist_train.sh configs/textseg2.yaml 9 | textseg.yaml and textseg2.yaml correspond to crop and crop2 from step (1) 10 | 11 | (3) Generate pseudo labels 12 | Generate the IC15 pseudo labels: python3 tools/gen_ic15_pslabel.py --config-file configs/textseg2.yaml 13 | Visualize the IC15 pseudo labels: python3 tools/demo_ic15.py --config-file configs/textseg2.yaml 14 | The images are written to demo/trash/IC15 15 | 16 | Generate the Total-Text pseudo labels: python3 tools/gen_tt_pslabel.py --config-file configs/textseg2.yaml 17 | Visualize the Total-Text pseudo labels: python3 tools/demo_tt.py --config-file configs/textseg2.yaml 18 | The images are written to demo/trash/TT 19 | 20 | # Not sure whether cv2.imread or Image.open gives better results...... -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/__init__.py: -------------------------------------------------------------------------------- 1 | from .
import modules, models, utils, data -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .settings import cfg -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/config/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import codecs 7 | import yaml 8 | import six 9 | import time 10 | 11 | from ast import literal_eval 12 | 13 | class SegmentronConfig(dict): 14 | def __init__(self, *args, **kwargs): 15 | super(SegmentronConfig, self).__init__(*args, **kwargs) 16 | self.immutable = False 17 | 18 | def __setattr__(self, key, value, create_if_not_exist=True): 19 | if key in ["immutable"]: 20 | self.__dict__[key] = value 21 | return 22 | 23 | t = self 24 | keylist = key.split(".") 25 | for k in keylist[:-1]: 26 | t = t.__getattr__(k, create_if_not_exist) 27 | 28 | t.__getattr__(keylist[-1], create_if_not_exist) 29 | t[keylist[-1]] = value 30 | 31 | def __getattr__(self, key, create_if_not_exist=True): 32 | if key in ["immutable"]: 33 | if key not in self.__dict__: 34 | self.__dict__[key] = False 35 | return self.__dict__[key] 36 | 37 | if not key in self: 38 | if not create_if_not_exist: 39 | raise KeyError 40 | self[key] = SegmentronConfig() 41 | return self[key] 42 | 43 | def __setitem__(self, key, value): 44 | # 45 | if self.immutable: 46 | raise AttributeError( 47 | 'Attempted to set "{}" to "{}", but SegConfig is immutable'. 48 | format(key, value)) 49 | # 50 | if isinstance(value, six.string_types): 51 | try: 52 | value = literal_eval(value) 53 | except ValueError: 54 | pass 55 | except SyntaxError: 56 | pass 57 | super(SegmentronConfig, self).__setitem__(key, value) 58 | 59 | def update_from_other_cfg(self, other): 60 | if isinstance(other, dict): 61 | other = SegmentronConfig(other) 62 | assert isinstance(other, SegmentronConfig) 63 | cfg_list = [("", other)] 64 | while len(cfg_list): 65 | prefix, tdic = cfg_list[0] 66 | cfg_list = cfg_list[1:] 67 | for key, value in tdic.items(): 68 | key = "{}.{}".format(prefix, key) if prefix else key 69 | if isinstance(value, dict): 70 | cfg_list.append((key, value)) 71 | continue 72 | try: 73 | self.__setattr__(key, value, create_if_not_exist=False) 74 | except KeyError: 75 | raise KeyError('Non-existent config key: {}'.format(key)) 76 | 77 | def remove_irrelevant_cfg(self): 78 | model_name = self.MODEL.MODEL_NAME 79 | 80 | from ..models.model_zoo import MODEL_REGISTRY 81 | model_list = MODEL_REGISTRY.get_list() 82 | model_list_lower = [x.lower() for x in model_list] 83 | # print('model_list:', model_list) 84 | assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\ 85 | .format(model_list, model_name) 86 | pop_keys = [] 87 | for key in self.MODEL.keys(): 88 | if key.lower() in model_list_lower and key.lower() != model_name.lower(): 89 | pop_keys.append(key) 90 | for key in pop_keys: 91 | self.MODEL.pop(key) 92 | 93 | 94 | 95 | def check_and_freeze(self): 96 | self.TIME_STAMP = time.strftime('%Y-%m-%d-%H-%M', time.localtime()) 97 | # TODO: remove irrelevant config and then freeze 98 | self.remove_irrelevant_cfg() 99 | self.immutable = True 100 | 101 | def update_from_list(self, config_list): 
102 | if len(config_list) % 2 != 0: 103 | raise ValueError( 104 | "Command line options config format error! Please check it: {}". 105 | format(config_list)) 106 | for key, value in zip(config_list[0::2], config_list[1::2]): 107 | try: 108 | self.__setattr__(key, value, create_if_not_exist=False) 109 | except KeyError: 110 | raise KeyError('Non-existent config key: {}'.format(key)) 111 | 112 | def update_from_file(self, config_file): 113 | with codecs.open(config_file, 'r', 'utf-8') as file: 114 | loaded_cfg = yaml.load(file, Loader=yaml.FullLoader) 115 | self.update_from_other_cfg(loaded_cfg) 116 | 117 | def set_immutable(self, immutable): 118 | self.immutable = immutable 119 | for value in self.values(): 120 | if isinstance(value, SegmentronConfig): 121 | value.set_immutable(immutable) 122 | 123 | def is_immutable(self): 124 | return self.immutable -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/config/settings.py: -------------------------------------------------------------------------------- 1 | from .config import SegmentronConfig 2 | 3 | cfg = SegmentronConfig() 4 | 5 | ########################## basic set ########################################### 6 | # random seed 7 | cfg.SEED = 1024 8 | # train time stamp, auto generate, do not need to set 9 | cfg.TIME_STAMP = '' 10 | # root path 11 | cfg.ROOT_PATH = '' 12 | # model phase ['train', 'test'] 13 | cfg.PHASE = 'train' 14 | 15 | ########################## dataset config ######################################### 16 | # dataset name 17 | cfg.DATASET.NAME = '' 18 | cfg.DATASET.ROOT = '' 19 | # pixel mean 20 | cfg.DATASET.MEAN = [0.5, 0.5, 0.5] 21 | # pixel std 22 | cfg.DATASET.STD = [0.5, 0.5, 0.5] 23 | # dataset ignore index 24 | cfg.DATASET.IGNORE_INDEX = -1 25 | # workers 26 | cfg.DATASET.WORKERS = 8 27 | # val dataset mode 28 | cfg.DATASET.MODE = 'testval' 29 | ########################### data augment ###################################### 30 | # data augment image mirror 31 | cfg.AUG.MIRROR = True 32 | # blur probability 33 | cfg.AUG.BLUR_PROB = 0.0 34 | # blur radius 35 | cfg.AUG.BLUR_RADIUS = 0.0 36 | # color jitter, float or tuple: (0.1, 0.2, 0.3, 0.4) 37 | cfg.AUG.COLOR_JITTER = None 38 | ########################### train config ########################################## 39 | # epochs 40 | cfg.TRAIN.EPOCHS = 30 41 | # batch size 42 | cfg.TRAIN.BATCH_SIZE = 1 43 | # train crop size 44 | cfg.TRAIN.CROP_SIZE = 769 45 | # train base size 46 | cfg.TRAIN.BASE_SIZE = 512 47 | # model output dir 48 | cfg.TRAIN.MODEL_SAVE_DIR = 'workdirs/' 49 | # log dir 50 | cfg.TRAIN.LOG_SAVE_DIR = cfg.TRAIN.MODEL_SAVE_DIR 51 | # pretrained model for eval or finetune 52 | cfg.TRAIN.PRETRAINED_MODEL_PATH = '' 53 | # use pretrained backbone model over imagenet 54 | cfg.TRAIN.BACKBONE_PRETRAINED = True 55 | # backbone pretrained model path, if not specific, will load from url when backbone pretrained enabled 56 | cfg.TRAIN.BACKBONE_PRETRAINED_PATH = '' 57 | # resume model path 58 | cfg.TRAIN.RESUME_MODEL_PATH = '' 59 | # whether to use synchronize bn 60 | cfg.TRAIN.SYNC_BATCH_NORM = True 61 | # save model every checkpoint-epoch 62 | cfg.TRAIN.SNAPSHOT_EPOCH = 1 63 | 64 | ########################### optimizer config ################################## 65 | # base learning rate 66 | cfg.SOLVER.LR = 1e-4 67 | # optimizer method 68 | cfg.SOLVER.OPTIMIZER = "sgd" 69 | # optimizer epsilon 70 | cfg.SOLVER.EPSILON = 1e-8 71 | # optimizer momentum 72 | cfg.SOLVER.MOMENTUM = 0.9 73 | # weight decay 74 | 
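# How these defaults interact with the YAML configs (a hedged sketch, not repository
# code): SegmentronConfig resolves dotted keys, so update_from_file()/update_from_list()
# map a key like 'SOLVER.LR' onto cfg.SOLVER.LR, and string values are parsed with
# ast.literal_eval before assignment. With hypothetical values:
#
#     cfg.update_from_list(['SOLVER.LR', '0.01', 'TRAIN.BATCH_SIZE', '16'])
#     assert cfg.SOLVER.LR == 0.01 and cfg.TRAIN.BATCH_SIZE == 16
#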
cfg.SOLVER.WEIGHT_DECAY = 1e-4 #0.00004 75 | # decoder lr x10 76 | cfg.SOLVER.DECODER_LR_FACTOR = 10.0 77 | # lr scheduler mode 78 | cfg.SOLVER.LR_SCHEDULER = "poly" 79 | # poly power 80 | cfg.SOLVER.POLY.POWER = 0.9 81 | # step gamma 82 | cfg.SOLVER.STEP.GAMMA = 0.1 83 | # milestone of step lr scheduler 84 | cfg.SOLVER.STEP.DECAY_EPOCH = [10, 20] 85 | # warm up epochs can be float 86 | cfg.SOLVER.WARMUP.EPOCHS = 0. 87 | # warm up factor 88 | cfg.SOLVER.WARMUP.FACTOR = 1.0 / 3 89 | # warm up method 90 | cfg.SOLVER.WARMUP.METHOD = 'linear' 91 | # whether to use ohem 92 | cfg.SOLVER.OHEM = False 93 | # whether to use aux loss 94 | cfg.SOLVER.AUX = False 95 | # aux loss weight 96 | cfg.SOLVER.AUX_WEIGHT = 0.4 97 | # loss name 98 | cfg.SOLVER.LOSS_NAME = '' 99 | ########################## test config ########################################### 100 | # val/test model path 101 | cfg.TEST.TEST_MODEL_PATH = '' 102 | # test batch size 103 | cfg.TEST.BATCH_SIZE = 1 104 | # eval crop size 105 | cfg.TEST.CROP_SIZE = None 106 | # multiscale eval 107 | cfg.TEST.SCALES = [1.0] 108 | # flip 109 | cfg.TEST.FLIP = False 110 | 111 | ########################## visual config ########################################### 112 | # visual result output dir 113 | cfg.VISUAL.OUTPUT_DIR = '../runs/visual/' 114 | 115 | ########################## model ####################################### 116 | # model name 117 | cfg.MODEL.MODEL_NAME = '' 118 | # model backbone 119 | cfg.MODEL.BACKBONE = '' 120 | # model backbone channel scale 121 | cfg.MODEL.BACKBONE_SCALE = 1.0 122 | # support resnet b, c. b is standard resnet in pytorch official repo 123 | # cfg.MODEL.RESNET_VARIANT = 'b' 124 | # multi branch loss weight 125 | cfg.MODEL.MULTI_LOSS_WEIGHT = [1.0] 126 | # gn groups 127 | cfg.MODEL.DEFAULT_GROUP_NUMBER = 32 128 | # whole model default epsilon 129 | cfg.MODEL.DEFAULT_EPSILON = 1e-5 130 | # batch norm, support ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN'] 131 | cfg.MODEL.BN_TYPE = 'BN' 132 | # batch norm epsilon for encoder, if set None will use api default value. 133 | cfg.MODEL.BN_EPS_FOR_ENCODER = None 134 | # batch norm epsilon for decoder, if set None will use api default value. 135 | cfg.MODEL.BN_EPS_FOR_DECODER = None 136 | # backbone output stride 137 | cfg.MODEL.OUTPUT_STRIDE = 16 138 | # BatchNorm momentum, if set None will use api default value.
139 | cfg.MODEL.BN_MOMENTUM = None 140 | 141 | 142 | ########################## DeepLab config #################################### 143 | # whether to use aspp 144 | cfg.MODEL.DEEPLABV3_PLUS.USE_ASPP = True 145 | # whether to use decoder 146 | cfg.MODEL.DEEPLABV3_PLUS.ENABLE_DECODER = True 147 | # whether aspp use sep conv 148 | cfg.MODEL.DEEPLABV3_PLUS.ASPP_WITH_SEP_CONV = True 149 | # whether decoder use sep conv 150 | cfg.MODEL.DEEPLABV3_PLUS.DECODER_USE_SEP_CONV = True 151 | ########################## Demo #################################### 152 | 153 | cfg.DEMO_DIR = 'demo/imgs' 154 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/segmentron/data/__init__.py -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/data/dataloader/TextSegmentation_total.py: -------------------------------------------------------------------------------- 1 | """Prepare the st800k word-crop segmentation dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | 7 | from PIL import Image 8 | from .seg_data_base import SegmentationDataset_total 9 | # from IPython import embed 10 | import cv2 11 | 12 | class TextSegmentation_total(SegmentationDataset_total): 13 | """st800k word-crop semantic segmentation dataset. 14 | 15 | Parameters 16 | ---------- 17 | root : string 18 | Path to the st800k crop folder. Default is 'data/st800k_crop' 19 | split: string 20 | 'train', 'validation', 'test' 21 | transform : callable, optional 22 | A function that transforms the image 23 | """ 24 | BASE_DIR = 'st800k' 25 | NUM_CLASS = 2 26 | 27 | def __init__(self, root='data/st800k_crop', split='train', mode=None, transform=None, debug=False, **kwargs): 28 | super(TextSegmentation_total, self).__init__(root, split, mode, transform, **kwargs) 29 | assert os.path.exists(self.root), "Please put dataset in {}".format(root) 30 | self.images, self.mask_paths = _get_st800kcrop_pairs(self.root, self.split) 31 | assert (len(self.images) == len(self.mask_paths)) 32 | if len(self.images) == 0: 33 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 34 | self.valid_classes = [0,1] 35 | self._key = np.array([0,1]) 36 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') + 1 37 | self.debug = debug 38 | self.mode = mode 39 | 40 | def _class_to_index(self, mask): 41 | # assert the value 42 | values = np.unique(mask) 43 | for value in values: 44 | assert (value in self._mapping) 45 | index = np.digitize(mask.ravel(), self._mapping, right=True) 46 | return self._key[index].reshape(mask.shape) 47 | 48 | def __getitem__(self, index): 49 | try: 50 | img = Image.open(self.images[index]).convert('RGB') 51 | except Exception: 52 | print("invalid image:",self.images[index] ) 53 | return self.__getitem__((index + 1) % len(self.images)) 54 | if self.mode == 'test': 55 | if self.transform is not None: 56 | img = self.transform(img) 57 | return img, os.path.basename(self.images[index]) 58 | mask = Image.open(self.mask_paths[index]) 59 | 60 | # synchronized transform 61 | if self.mode == 'train': 62 | img, mask = self._sync_transform(img, mask) 63 | elif self.mode == 'val': 64 | img, mask = self._val_sync_transform(img, mask) 65 | else: 66 | assert self.mode == 'testval' 67 | img, mask =
self._img_transform(img), self._mask_transform(mask) 68 | 69 | if self.debug: 70 | print('debug vis') 71 | _img = Image.fromarray(img) 72 | _img.save('trash/img.png') 73 | _mask = Image.fromarray(mask.float().data.cpu().numpy()*255).convert('L') 74 | _mask.save('trash/mask.png') 75 | 76 | # general resize, normalize and toTensor 77 | if self.transform is not None: 78 | img = self.transform(img) 79 | return img, mask, self.images[index] 80 | 81 | def _mask_transform(self, mask): 82 | target = self._class_to_index(np.array(mask).astype('int32')) 83 | return torch.LongTensor(np.array(target).astype('int32')) 84 | 85 | def __len__(self): 86 | return len(self.images) 87 | 88 | @property 89 | def pred_offset(self): 90 | return 0 91 | 92 | @property 93 | def classes(self): 94 | """Category names.""" 95 | return ('background', 'text') 96 | 97 | 98 | def _get_st800kcrop_pairs(folder, split='train'): 99 | 100 | def get_path_pairs(img_folder, mask_folder): 101 | img_paths = [] 102 | mask_paths = [] 103 | imgs = os.listdir(img_folder) 104 | 105 | for imgname in imgs: 106 | imgpath = os.path.join(img_folder, imgname) 107 | maskname = imgname 108 | maskpath = os.path.join(mask_folder, maskname) 109 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 110 | img_paths.append(imgpath) 111 | mask_paths.append(maskpath) 112 | else: 113 | logging.info('cannot find the mask or image: {} {}'.format(imgpath, maskpath)) 114 | 115 | logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 116 | return img_paths, mask_paths 117 | 118 | if split == 'train': 119 | img_folder = os.path.join(folder, 'image') 120 | mask_folder = os.path.join(folder, 'mask') 121 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 122 | else: 123 | img_paths, mask_paths = [], [] # only the 'train' split is materialized; avoids a NameError on return 124 | return img_paths, mask_paths 125 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/data/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides data loaders and transformers for popular vision datasets. 3 | """ 4 | from .st800k import TextSegmentation 5 | from .TextSegmentation_total import TextSegmentation_total 6 | from .Curved_Synthtext_attention import TextSegmentation_attention 7 | datasets = { 8 | 'st800k': TextSegmentation, 9 | 'st800k_total': TextSegmentation_total, 10 | 'st800k_attention': TextSegmentation_attention 11 | 12 | } 13 | 14 | 15 | def get_segmentation_dataset(name, **kwargs): 16 | """Segmentation Datasets""" 17 | return datasets[name.lower()](**kwargs) -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/data/dataloader/st800k.py: -------------------------------------------------------------------------------- 1 | """Prepare the st800k word-crop segmentation dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | 7 | from PIL import Image 8 | from .seg_data_base import SegmentationDataset 9 | # from IPython import embed 10 | import cv2 11 | 12 | class TextSegmentation(SegmentationDataset): 13 | """st800k word-crop semantic segmentation dataset. 14 | 15 | Parameters 16 | ---------- 17 | root : string 18 | Path to the st800k crop folder.
Default is 'data/st800k_crop' 19 | split: string 20 | 'train', 'validation', 'test' 21 | transform : callable, optional 22 | A function that transforms the image 23 | """ 24 | BASE_DIR = 'st800k' 25 | NUM_CLASS = 2 26 | 27 | def __init__(self, root='data/st800k_crop', split='train', mode=None, transform=None, debug=False, **kwargs): 28 | super(TextSegmentation, self).__init__(root, split, mode, transform, **kwargs) 29 | assert os.path.exists(self.root), "Please put dataset in {}".format(root) 30 | self.images, self.mask_paths = _get_st800kcrop_pairs(self.root, self.split) 31 | assert (len(self.images) == len(self.mask_paths)) 32 | if len(self.images) == 0: 33 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 34 | self.valid_classes = [0,1] 35 | self._key = np.array([0,1]) 36 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') + 1 37 | self.debug = debug 38 | 39 | def _class_to_index(self, mask): 40 | # assert the value 41 | values = np.unique(mask) 42 | for value in values: 43 | assert (value in self._mapping) 44 | index = np.digitize(mask.ravel(), self._mapping, right=True) 45 | return self._key[index].reshape(mask.shape) 46 | 47 | def __getitem__(self, index): 48 | try: 49 | img = Image.open(self.images[index]).convert('RGB') 50 | except Exception: 51 | print("invalid image:",self.images[index] ) 52 | return self.__getitem__((index + 1) % len(self.images)) 53 | if self.mode == 'test': 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | return img, os.path.basename(self.images[index]) 57 | mask = Image.open(self.mask_paths[index]) 58 | 59 | # synchronized transform 60 | if self.mode == 'train': 61 | img, mask = self._sync_transform(img, mask) 62 | elif self.mode == 'val': 63 | img, mask = self._val_sync_transform(img, mask) 64 | else: 65 | assert self.mode == 'testval' 66 | img, mask = self._img_transform(img), self._mask_transform(mask) 67 | 68 | if self.debug: 69 | print('debug vis') 70 | _img = Image.fromarray(img) 71 | _img.save('trash/img.png') 72 | _mask = Image.fromarray(mask.float().data.cpu().numpy()*255).convert('L') 73 | _mask.save('trash/mask.png') 74 | 75 | # general resize, normalize and toTensor 76 | if self.transform is not None: 77 | img = self.transform(img) 78 | return img, mask, self.images[index] 79 | 80 | def _mask_transform(self, mask): 81 | target = self._class_to_index(np.array(mask).astype('int32')) 82 | return torch.LongTensor(np.array(target).astype('int32')) 83 | 84 | def __len__(self): 85 | return len(self.images) 86 | 87 | @property 88 | def pred_offset(self): 89 | return 0 90 | 91 | @property 92 | def classes(self): 93 | """Category names.""" 94 | return ('background', 'text') 95 | 96 | 97 | def _get_st800kcrop_pairs(folder, split='train'): 98 | 99 | def get_path_pairs(img_folder, mask_folder): 100 | img_paths = [] 101 | mask_paths = [] 102 | imgs = os.listdir(img_folder) 103 | 104 | for imgname in imgs: 105 | imgpath = os.path.join(img_folder, imgname) 106 | maskname = imgname 107 | maskpath = os.path.join(mask_folder, maskname) 108 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 109 | img_paths.append(imgpath) 110 | mask_paths.append(maskpath) 111 | else: 112 | logging.info('cannot find the mask or image: {} {}'.format(imgpath, maskpath)) 113 | 114 | logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 115 | return img_paths, mask_paths 116 | 117 | if split == 'train': 118 | img_folder = os.path.join(folder, 'image') 119 | mask_folder = os.path.join(folder, 'mask') 120 |
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 121 | else: 122 | img_paths, mask_paths = [], [] # only the 'train' split is materialized; avoids a NameError on return 123 | return img_paths, mask_paths 124 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/data/dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import errno 4 | import tarfile 5 | from six.moves import urllib 6 | from torch.utils.model_zoo import tqdm 7 | 8 | def read_lines(p): 9 | """return the text in a file in lines as a list """ 10 | p = get_absolute_path(p) 11 | f = open(p, 'r', encoding='utf-8-sig') # 'rU' was removed in Python 3.11; universal newlines are the default 12 | return f.readlines() 13 | 14 | 15 | def remove_all(s, sub): 16 | return replace_all(s, sub, '') 17 | 18 | def replace_all(s, old, new, reg=False): 19 | if reg: 20 | import re 21 | targets = re.findall(old, s) 22 | for t in targets: 23 | s = s.replace(t, new) 24 | else: 25 | s = s.replace(old, new) 26 | return s 27 | 28 | 29 | def get_absolute_path(p): 30 | if p.startswith('~'): 31 | p = os.path.expanduser(p) 32 | return os.path.abspath(p) 33 | 34 | 35 | def gen_bar_updater(): 36 | pbar = tqdm(total=None) 37 | 38 | def bar_update(count, block_size, total_size): 39 | if pbar.total is None and total_size: 40 | pbar.total = total_size 41 | progress_bytes = count * block_size 42 | pbar.update(progress_bytes - pbar.n) 43 | 44 | return bar_update 45 | 46 | def check_integrity(fpath, md5=None): 47 | if md5 is None: 48 | return True 49 | if not os.path.isfile(fpath): 50 | return False 51 | md5o = hashlib.md5() 52 | with open(fpath, 'rb') as f: 53 | # read in 1MB chunks 54 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 55 | md5o.update(chunk) 56 | md5c = md5o.hexdigest() 57 | if md5c != md5: 58 | return False 59 | return True 60 | 61 | def makedir_exist_ok(dirpath): 62 | try: 63 | os.makedirs(dirpath) 64 | except OSError as e: 65 | if e.errno == errno.EEXIST: 66 | pass 67 | else: 68 | raise # the original swallowed every OSError here, hiding real failures 69 | 70 | def download_url(url, root, filename=None, md5=None): 71 | """Download a file from a url and place it in root.""" 72 | root = os.path.expanduser(root) 73 | if not filename: 74 | filename = os.path.basename(url) 75 | fpath = os.path.join(root, filename) 76 | 77 | makedir_exist_ok(root) 78 | 79 | # downloads file 80 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 81 | print('Using downloaded and verified file: ' + fpath) 82 | else: 83 | try: 84 | print('Downloading ' + url + ' to ' + fpath) 85 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 86 | except OSError: 87 | if url[:5] == 'https': 88 | url = url.replace('https:', 'http:') 89 | print('Failed download. Trying https -> http instead.'
90 | ' Downloading ' + url + ' to ' + fpath) 91 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 92 | 93 | def download_extract(url, root, filename, md5): 94 | download_url(url, root, filename, md5) 95 | with tarfile.open(os.path.join(root, filename), "r") as tar: 96 | tar.extractall(path=root) -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Model Zoo""" 2 | from .model_zoo import MODEL_REGISTRY 3 | from .textseg import TextSeg 4 | from .textseg_attention import textseg_attention 5 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import BACKBONE_REGISTRY, get_segmentation_backbone 2 | from .xception import * 3 | from .mobilenet import * 4 | from .resnet import * 5 | from .hrnet import * 6 | from .eespnet import * 7 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/models/backbones/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from ...utils.registry import Registry 7 | from ...config import cfg 8 | 9 | BACKBONE_REGISTRY = Registry("BACKBONE") 10 | BACKBONE_REGISTRY.__doc__ = """ 11 | Registry for backbone, i.e. resnet. 12 | 13 | The registered object will be called with `obj()` 14 | and expected to return a `nn.Module` object. 15 | """ 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnet50c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet50-25c4b509.pth', 24 | 'resnet101c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet101-2a57e44d.pth', 25 | 'resnet152c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet152-0d43d698.pth', 26 | 'xception65': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/tf-xception65-270e81cf.pth', 27 | 'hrnet_w18_small_v1': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/hrnet-w18-small-v1-08f8ae64.pth', 28 | 'mobilenet_v2': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/mobilenetV2-15498621.pth', 29 | } 30 | 31 | 32 | def load_backbone_pretrained(model, backbone): 33 | if cfg.PHASE == 'train' and cfg.TRAIN.BACKBONE_PRETRAINED and (not cfg.TRAIN.PRETRAINED_MODEL_PATH): 34 | if os.path.isfile(cfg.TRAIN.BACKBONE_PRETRAINED_PATH): 35 | logging.info('Load backbone pretrained model from {}'.format( 36 | cfg.TRAIN.BACKBONE_PRETRAINED_PATH 37 | )) 38 | msg = model.load_state_dict(torch.load(cfg.TRAIN.BACKBONE_PRETRAINED_PATH), strict=False) 39 | logging.info(msg) 40 | elif backbone not in model_urls: 41 | logging.info('{} has no pretrained model'.format(backbone)) 42 | return 43 | else: 44 | logging.info('load backbone pretrained model from url..') 45 | msg = 
model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) 46 | logging.info(msg) 47 | 48 | 49 | def get_segmentation_backbone(backbone, norm_layer=torch.nn.BatchNorm2d): 50 | """ 51 | Built the backbone model, defined by `cfg.MODEL.BACKBONE`. 52 | """ 53 | model = BACKBONE_REGISTRY.get(backbone)(norm_layer) 54 | load_backbone_pretrained(model, backbone) 55 | return model 56 | 57 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/models/model_zoo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from collections import OrderedDict 5 | from segmentron.utils.registry import Registry 6 | from ..config import cfg 7 | 8 | MODEL_REGISTRY = Registry("MODEL") 9 | MODEL_REGISTRY.__doc__ = """ 10 | Registry for segment model, i.e. the whole model. 11 | 12 | The registered object will be called with `obj()` 13 | and expected to return a `nn.Module` object. 14 | """ 15 | 16 | 17 | def get_segmentation_model(): 18 | """ 19 | Built the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`. 20 | """ 21 | model_name = cfg.MODEL.MODEL_NAME 22 | model = MODEL_REGISTRY.get(model_name)() 23 | load_model_pretrain(model) 24 | return model 25 | 26 | 27 | def load_model_pretrain(model): 28 | if cfg.PHASE == 'train': 29 | if cfg.TRAIN.PRETRAINED_MODEL_PATH: 30 | logging.info('load pretrained model from {}'.format(cfg.TRAIN.PRETRAINED_MODEL_PATH)) 31 | state_dict_to_load = torch.load(cfg.TRAIN.PRETRAINED_MODEL_PATH) 32 | keys_wrong_shape = [] 33 | state_dict_suitable = OrderedDict() 34 | state_dict = model.state_dict() 35 | for k, v in state_dict_to_load.items(): 36 | if v.shape == state_dict[k].shape: 37 | state_dict_suitable[k] = v 38 | else: 39 | keys_wrong_shape.append(k) 40 | logging.info('Shape unmatched weights: {}'.format(keys_wrong_shape)) 41 | msg = model.load_state_dict(state_dict_suitable, strict=False) 42 | logging.info(msg) 43 | else: 44 | if cfg.TEST.TEST_MODEL_PATH: 45 | logging.info('load test model from {}'.format(cfg.TEST.TEST_MODEL_PATH)) 46 | msg = model.load_state_dict(torch.load(cfg.TEST.TEST_MODEL_PATH), strict=False) 47 | logging.info(msg) -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/models/segbase.py: -------------------------------------------------------------------------------- 1 | """Base Model for Semantic Segmentation""" 2 | import math 3 | import numbers 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .backbones import get_segmentation_backbone 10 | from ..data.dataloader import datasets 11 | from ..modules import get_norm 12 | from ..config import cfg 13 | __all__ = ['SegBaseModel'] 14 | 15 | 16 | class SegBaseModel(nn.Module): 17 | r"""Base Model for Semantic Segmentation 18 | """ 19 | def __init__(self, need_backbone=True): 20 | super(SegBaseModel, self).__init__() 21 | self.nclass = datasets[cfg.DATASET.NAME].NUM_CLASS 22 | self.aux = cfg.SOLVER.AUX 23 | self.norm_layer = get_norm(cfg.MODEL.BN_TYPE) 24 | self.backbone = None 25 | self.encoder = None 26 | if need_backbone: 27 | self.get_backbone() 28 | 29 | def get_backbone(self): 30 | self.backbone = cfg.MODEL.BACKBONE.lower() 31 | self.encoder = get_segmentation_backbone(self.backbone, self.norm_layer) 32 | 33 | def base_forward(self, x): 34 | """forwarding backbone network""" 35 | c1, c2, c3, c4 = self.encoder(x) 36 | return c1, c2, c3, c4 37 | 38 | def 
demo(self, x): 39 | pred = self.forward(x) 40 | if self.aux: 41 | pred = pred[0] 42 | return pred 43 | 44 | def evaluate(self, image): 45 | """evaluating network with inputs and targets""" 46 | scales = cfg.TEST.SCALES 47 | batch, _, h, w = image.shape 48 | base_size = max(h, w) 49 | # scores = torch.zeros((batch, self.nclass, h, w)).to(image.device) 50 | scores = None 51 | for scale in scales: 52 | long_size = int(math.ceil(base_size * scale)) 53 | if h > w: 54 | height = long_size 55 | width = int(1.0 * w * long_size / h + 0.5) 56 | else: 57 | width = long_size 58 | height = int(1.0 * h * long_size / w + 0.5) 59 | 60 | # resize image to current size 61 | cur_img = _resize_image(image, height, width) 62 | outputs = self.forward(cur_img)[0][..., :height, :width] 63 | 64 | score = _resize_image(outputs, h, w) 65 | 66 | if scores is None: 67 | scores = score 68 | else: 69 | scores += score 70 | return scores 71 | 72 | 73 | def _resize_image(img, h, w): 74 | return F.interpolate(img, size=[h, w], mode='bilinear', align_corners=True) 75 | 76 | 77 | def _pad_image(img, crop_size): 78 | b, c, h, w = img.shape 79 | assert(c == 3) 80 | padh = crop_size[0] - h if h < crop_size[0] else 0 81 | padw = crop_size[1] - w if w < crop_size[1] else 0 82 | if padh == 0 and padw == 0: 83 | return img 84 | img_pad = F.pad(img, (0, padw, 0, padh)) # F.pad pads the last dim first: (left, right, top, bottom); the original swapped padh/padw 85 | 86 | # TODO clean this code 87 | # mean = cfg.DATASET.MEAN 88 | # std = cfg.DATASET.STD 89 | # pad_values = -np.array(mean) / np.array(std) 90 | # img_pad = torch.zeros((b, c, h + padh, w + padw)).to(img.device) 91 | # for i in range(c): 92 | # # print(img[:, i, :, :].unsqueeze(1).shape) 93 | # img_pad[:, i, :, :] = torch.squeeze( 94 | # F.pad(img[:, i, :, :].unsqueeze(1), (0, padh, 0, padw), 95 | # 'constant', value=pad_values[i]), 1) 96 | # assert(img_pad.shape[2] >= crop_size[0] and img_pad.shape[3] >= crop_size[1]) 97 | 98 | return img_pad 99 | 100 | 101 | def _crop_image(img, h0, h1, w0, w1): 102 | return img[:, :, h0:h1, w0:w1] 103 | 104 | 105 | def _flip_image(img): 106 | assert(img.ndim == 4) 107 | return img.flip((3)) 108 | 109 | 110 | def _to_tuple(size): 111 | if isinstance(size, (list, tuple)): 112 | assert len(size) == 2, 'Expect eval crop size to contain two elements, ' \ 113 | 'but received {}'.format(len(size)) 114 | return tuple(size) 115 | elif isinstance(size, numbers.Number): 116 | return tuple((size, size)) 117 | else: 118 | raise ValueError('Unsupported datatype: {}'.format(type(size))) 119 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/models/textseg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _ConvBNReLU, SeparableConv2d, _ASPP, _FCNHead 8 | from ..config import cfg 9 | from IPython import embed 10 | import math 11 | 12 | __all__ = ['TextSeg'] 13 | 14 | def _resize_image(img, h, w): 15 | return F.interpolate(img, size=[h, w], mode='bilinear', align_corners=True) 16 | 17 | @MODEL_REGISTRY.register(name='TextSeg') 18 | class TextSeg(SegBaseModel): 19 | def __init__(self): 20 | super(TextSeg, self).__init__() 21 | if self.backbone.startswith('mobilenet'): 22 | c1_channels = 24 23 | c4_channels = 320 24 | else: 25 | c1_channels = 256 26 | c4_channels = 2048 27 | self.head = _DeepLabHead(self.nclass, c1_channels=c1_channels, c4_channels=c4_channels) 28 | if self.aux: 29
| self.auxlayer = _FCNHead(728, self.nclass) 30 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 31 | 32 | def forward(self, x): 33 | size = x.size()[2:] 34 | c1, _, c3, c4 = self.encoder(x) 35 | 36 | outputs = list() 37 | x = self.head(c4, c1) 38 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 39 | 40 | outputs.append(x) 41 | if self.aux: 42 | auxout = self.auxlayer(c3) 43 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 44 | outputs.append(auxout) 45 | return tuple(outputs) 46 | 47 | 48 | class _DeepLabHead(nn.Module): 49 | def __init__(self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d): 50 | super(_DeepLabHead, self).__init__() 51 | self.use_aspp = cfg.MODEL.DEEPLABV3_PLUS.USE_ASPP 52 | self.use_decoder = cfg.MODEL.DEEPLABV3_PLUS.ENABLE_DECODER 53 | last_channels = c4_channels 54 | if self.use_aspp: 55 | self.aspp = _ASPP(c4_channels, 256) 56 | last_channels = 256 57 | if self.use_decoder: 58 | self.c1_block = _ConvBNReLU(c1_channels, 48, 1, norm_layer=norm_layer) 59 | last_channels += 48 60 | self.block = nn.Sequential( 61 | SeparableConv2d(last_channels, 256, 3, norm_layer=norm_layer, relu_first=False), 62 | SeparableConv2d(256, 256, 3, norm_layer=norm_layer, relu_first=False), 63 | nn.Conv2d(256, nclass, 1)) 64 | 65 | def forward(self, x, c1): 66 | size = c1.size()[2:] 67 | if self.use_aspp: 68 | x = self.aspp(x) 69 | if self.use_decoder: 70 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 71 | c1 = self.c1_block(c1) 72 | return self.block(torch.cat([x, c1], dim=1)) 73 | 74 | return self.block(x) -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Seg NN Modules""" 2 | 3 | from .basic import * 4 | from .module import * 5 | from .batch_norm import get_norm -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/modules/cc_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd.function import once_differentiable 6 | from segmentron import _C 7 | 8 | __all__ = ['CrissCrossAttention', 'ca_weight', 'ca_map'] 9 | 10 | 11 | class _CAWeight(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, t, f): 14 | weight = _C.ca_forward(t, f) 15 | 16 | ctx.save_for_backward(t, f) 17 | 18 | return weight 19 | 20 | @staticmethod 21 | @once_differentiable 22 | def backward(ctx, dw): 23 | t, f = ctx.saved_tensors 24 | 25 | dt, df = _C.ca_backward(dw, t, f) 26 | return dt, df 27 | 28 | 29 | class _CAMap(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, weight, g): 32 | out = _C.ca_map_forward(weight, g) 33 | 34 | ctx.save_for_backward(weight, g) 35 | 36 | return out 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, dout): 41 | weight, g = ctx.saved_tensors 42 | 43 | dw, dg = _C.ca_map_backward(dout, weight, g) 44 | 45 | return dw, dg 46 | 47 | 48 | ca_weight = _CAWeight.apply 49 | ca_map = _CAMap.apply 50 | 51 | 52 | class CrissCrossAttention(nn.Module): 53 | """Criss-Cross Attention Module""" 54 | 55 | def __init__(self, in_channels): 56 | super(CrissCrossAttention, self).__init__() 57 | self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 58 | self.key_conv = 
nn.Conv2d(in_channels, in_channels // 8, 1) 59 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1) 60 | self.gamma = nn.Parameter(torch.zeros(1)) 61 | 62 | def forward(self, x): 63 | proj_query = self.query_conv(x) 64 | proj_key = self.key_conv(x) 65 | proj_value = self.value_conv(x) 66 | 67 | energy = ca_weight(proj_query, proj_key) 68 | attention = F.softmax(energy, 1) 69 | out = ca_map(attention, proj_value) 70 | out = self.gamma * out + x 71 | 72 | return out 73 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/modules/csrc/criss_cross_attention/ca.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | #include <vector> 4 | 5 | namespace segmentron { 6 | at::Tensor ca_forward_cuda( 7 | const at::Tensor& t, 8 | const at::Tensor& f); 9 | 10 | std::tuple<at::Tensor, at::Tensor> ca_backward_cuda( 11 | const at::Tensor& dw, 12 | const at::Tensor& t, 13 | const at::Tensor& f); 14 | 15 | at::Tensor ca_map_forward_cuda( 16 | const at::Tensor& weight, 17 | const at::Tensor& g); 18 | 19 | std::tuple<at::Tensor, at::Tensor> ca_map_backward_cuda( 20 | const at::Tensor& dout, 21 | const at::Tensor& weight, 22 | const at::Tensor& g); 23 | 24 | 25 | at::Tensor ca_forward(const at::Tensor& t, 26 | const at::Tensor& f) { 27 | if (t.type().is_cuda()) { 28 | #ifdef WITH_CUDA 29 | return ca_forward_cuda(t, f); 30 | #else 31 | AT_ERROR("Not compiled with GPU support"); 32 | #endif 33 | } 34 | AT_ERROR("Not implemented on the CPU"); 35 | } 36 | 37 | std::tuple<at::Tensor, at::Tensor> ca_backward(const at::Tensor& dw, 38 | const at::Tensor& t, 39 | const at::Tensor& f) { 40 | if (dw.type().is_cuda()) { 41 | #ifdef WITH_CUDA 42 | return ca_backward_cuda(dw, t, f); 43 | #else 44 | AT_ERROR("Not compiled with GPU support"); 45 | #endif 46 | } 47 | AT_ERROR("Not implemented on the CPU"); 48 | } 49 | 50 | at::Tensor ca_map_forward(const at::Tensor& weight, 51 | const at::Tensor& g) { 52 | if (weight.type().is_cuda()) { 53 | #ifdef WITH_CUDA 54 | return ca_map_forward_cuda(weight, g); 55 | #else 56 | AT_ERROR("Not compiled with GPU support"); 57 | #endif 58 | } 59 | AT_ERROR("Not implemented on the CPU"); 60 | } 61 | 62 | std::tuple<at::Tensor, at::Tensor> ca_map_backward(const at::Tensor& dout, 63 | const at::Tensor& weight, 64 | const at::Tensor& g) { 65 | if (dout.type().is_cuda()) { 66 | #ifdef WITH_CUDA 67 | return ca_map_backward_cuda(dout, weight, g); 68 | #else 69 | AT_ERROR("Not compiled with GPU support"); 70 | #endif 71 | } 72 | AT_ERROR("Not implemented on the CPU"); 73 | } 74 | 75 | } // namespace segmentron 76 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/modules/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "criss_cross_attention/ca.h" 3 | 4 | namespace segmentron { 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("ca_forward", &ca_forward, "ca_forward"); 8 | m.def("ca_backward", &ca_backward, "ca_backward"); 9 | m.def("ca_map_forward", &ca_map_forward, "ca_map_forward"); 10 | m.def("ca_map_backward", &ca_map_backward, "ca_map_backward"); 11 | } 12 | 13 | } // namespace segmentron 14 |
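// NOTE (illustrative sketch, not from the original file): once built by setup.py,
// the bindings above surface in Python as segmentron._C, which cc_attention.py
// wraps; `t`, `f`, `g` stand for the CUDA tensors from the query/key/value convs:
//
//   from segmentron import _C
//   weight = _C.ca_forward(t, f)           # criss-cross affinity weights
//   out    = _C.ca_map_forward(weight, g)  # aggregate values along the cross
//
// In practice the autograd wrappers ca_weight / ca_map in cc_attention.py are used
// instead of calling _C directly, so gradients flow through ca_backward above.
-------------------------------------------------------------------------------- /TextBoxSeg/segmentron/solver/__init__.py: --------------------------------------------------------------------------------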
https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/segmentron/solver/__init__.py -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/solver/optimizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from torch import optim 5 | from segmentron.config import cfg 6 | 7 | 8 | def _set_batch_norm_attr(named_modules, attr, value): 9 | for m in named_modules: 10 | if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)): 11 | setattr(m[1], attr, value) 12 | 13 | 14 | def _get_parameters(model): 15 | params_list = list() 16 | if hasattr(model, 'encoder') and model.encoder is not None and hasattr(model, 'decoder'): 17 | params_list.append({'params': model.encoder.parameters(), 'lr': cfg.SOLVER.LR}) 18 | if cfg.MODEL.BN_EPS_FOR_ENCODER: 19 | logging.info('Set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER)) 20 | _set_batch_norm_attr(model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER) 21 | 22 | for module in model.decoder: 23 | params_list.append({'params': getattr(model, module).parameters(), 24 | 'lr': cfg.SOLVER.LR * cfg.SOLVER.DECODER_LR_FACTOR}) 25 | 26 | if cfg.MODEL.BN_EPS_FOR_DECODER: 27 | logging.info('Set bn custom eps for bn in decoder: {}'.format(cfg.MODEL.BN_EPS_FOR_DECODER)) 28 | for module in model.decoder: 29 | _set_batch_norm_attr(getattr(model, module).named_modules(), 'eps', 30 | cfg.MODEL.BN_EPS_FOR_DECODER) 31 | else: 32 | logging.info('Model does not have an encoder/decoder split; params list comes from model.parameters(), ' 33 | 'and the arguments BN_EPS_FOR_ENCODER, BN_EPS_FOR_DECODER, DECODER_LR_FACTOR are not used!') 34 | params_list = model.parameters() 35 | 36 | if cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE in ['BN']: 37 | logging.info('Set bn custom momentum: {}'.format(cfg.MODEL.BN_MOMENTUM)) 38 | _set_batch_norm_attr(model.named_modules(), 'momentum', cfg.MODEL.BN_MOMENTUM) 39 | elif cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE not in ['BN']: 40 | logging.info('Batch norm type is {}, custom bn momentum is not effective!'.format(cfg.MODEL.BN_TYPE)) 41 | 42 | return params_list 43 | 44 | 45 | def get_optimizer(model): 46 | parameters = _get_parameters(model) 47 | opt_lower = cfg.SOLVER.OPTIMIZER.lower() 48 | 49 | if opt_lower == 'sgd': 50 | optimizer = optim.SGD( 51 | parameters, lr=cfg.SOLVER.LR, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 52 | elif opt_lower == 'adam': 53 | optimizer = optim.Adam( 54 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 55 | elif opt_lower == 'adadelta': 56 | optimizer = optim.Adadelta( 57 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 58 | elif opt_lower == 'rmsprop': 59 | optimizer = optim.RMSprop( 60 | parameters, lr=cfg.SOLVER.LR, alpha=0.9, eps=cfg.SOLVER.EPSILON, 61 | momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 62 | else: 63 | raise ValueError("Expected optimizer method in [sgd, adam, adadelta, rmsprop], but received " 64 | "{}".format(opt_lower)) 65 | 66 | return optimizer 67 |
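# NOTE (illustrative sketch, not from the original file): the intended call
# pattern, assuming `cfg` has already been loaded from a yaml config
# (cfg.SOLVER.OPTIMIZER, cfg.SOLVER.LR, ...) and `model` follows the
# encoder/decoder convention checked in _get_parameters:
#
#     from segmentron.models.model_zoo import get_segmentation_model
#     model = get_segmentation_model()
#     optimizer = get_optimizer(model)  # encoder params at SOLVER.LR, decoder
#                                       # params at SOLVER.LR * DECODER_LR_FACTOR
-------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | from __future__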
import absolute_import 3 | 4 | from .download import download, check_sha1 5 | from .filesystem import makedirs 6 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/default_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import json 4 | import torch 5 | 6 | from .distributed import get_rank, synchronize 7 | from .logger import setup_logger 8 | from .env import seed_all_rng 9 | from ..config import cfg 10 | 11 | def default_setup(args): 12 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 13 | args.num_gpus = num_gpus 14 | args.distributed = num_gpus > 1 15 | 16 | if not args.no_cuda and torch.cuda.is_available(): 17 | # cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = True 19 | args.device = "cuda" 20 | else: 21 | args.distributed = False 22 | args.device = "cpu" 23 | if args.distributed: 24 | torch.cuda.set_device(args.local_rank) 25 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 26 | synchronize() 27 | 28 | # TODO 29 | # if args.save_pred: 30 | # outdir = '../runs/pred_pic/{}_{}_{}'.format(args.model, args.backbone, args.dataset) 31 | # if not os.path.exists(outdir): 32 | # os.makedirs(outdir) 33 | 34 | save_dir = cfg.TRAIN.MODEL_SAVE_DIR if cfg.PHASE == 'train' else None 35 | setup_logger("Segmentron", save_dir, get_rank(), filename='{}_{}_{}_{}_log.txt'.format( 36 | cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, cfg.DATASET.NAME, cfg.TIME_STAMP)) 37 | 38 | logging.info("Using {} GPUs".format(num_gpus)) 39 | logging.info(args) 40 | logging.info(json.dumps(cfg, indent=8)) 41 | 42 | seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + get_rank()) -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import requests 4 | from tqdm import tqdm 5 | 6 | def check_sha1(filename, sha1_hash): 7 | """Check whether the sha1 hash of the file content matches the expected hash. 8 | Parameters 9 | ---------- 10 | filename : str 11 | Path to the file. 12 | sha1_hash : str 13 | Expected sha1 hash in hexadecimal digits. 14 | Returns 15 | ------- 16 | bool 17 | Whether the file content matches the expected hash. 18 | """ 19 | sha1 = hashlib.sha1() 20 | with open(filename, 'rb') as f: 21 | while True: 22 | data = f.read(1048576) 23 | if not data: 24 | break 25 | sha1.update(data) 26 | 27 | sha1_file = sha1.hexdigest() 28 | l = min(len(sha1_file), len(sha1_hash)) 29 | return sha1_file[0:l] == sha1_hash[0:l] 30 | 31 | def download(url, path=None, overwrite=False, sha1_hash=None): 32 | """Download a given URL 33 | Parameters 34 | ---------- 35 | url : str 36 | URL to download 37 | path : str, optional 38 | Destination path to store downloaded file. By default stores to the 39 | current directory with same name as in url. 40 | overwrite : bool, optional 41 | Whether to overwrite the destination file if it already exists. 42 | sha1_hash : str, optional 43 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 44 | but doesn't match. 45 | Returns 46 | ------- 47 | str 48 | The file path of the downloaded file.
49 | """ 50 | if path is None: 51 | fname = url.split('/')[-1] 52 | else: 53 | path = os.path.expanduser(path) 54 | if os.path.isdir(path): 55 | fname = os.path.join(path, url.split('/')[-1]) 56 | else: 57 | fname = path 58 | 59 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 60 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 61 | if not os.path.exists(dirname): 62 | os.makedirs(dirname) 63 | 64 | print('Downloading %s from %s...' % (fname, url)) 65 | r = requests.get(url, stream=True) 66 | if r.status_code != 200: 67 | raise RuntimeError("Failed downloading url %s" % url) 68 | total_length = r.headers.get('content-length') 69 | with open(fname, 'wb') as f: 70 | if total_length is None: # no content length header 71 | for chunk in r.iter_content(chunk_size=1024): 72 | if chunk: # filter out keep-alive new chunks 73 | f.write(chunk) 74 | else: 75 | total_length = int(total_length) 76 | for chunk in tqdm(r.iter_content(chunk_size=1024), 77 | total=int(total_length / 1024. + 0.5), 78 | unit='KB', unit_scale=False, dynamic_ncols=True): 79 | f.write(chunk) 80 | 81 | if sha1_hash and not check_sha1(fname, sha1_hash): 82 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 83 | 'The repo may be outdated or download may be incomplete. ' \ 84 | 'If the "repo_url" is overridden, consider switching to ' \ 85 | 'the default repo.'.format(fname)) 86 | 87 | return fname
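# NOTE (illustrative sketch, not from the original file): a small usage example
# for check_sha1/download above; the URL and hash below are placeholders, not
# real artifacts of this repo:
#
#     fpath = download('https://example.com/weights.pth', path='./weights',
#                      sha1_hash='0123456789abcdef')  # re-downloads on hash mismatch
#     assert check_sha1(fpath, '0123456789abcdef')    # a hash prefix is sufficient
-------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/env.py: -------------------------------------------------------------------------------- 1 | # this code is heavily based on detectron2 2 | 3 | import logging 4 | import numpy as np 5 | import os 6 | import random 7 | from datetime import datetime 8 | import torch 9 | 10 | __all__ = ["seed_all_rng"] 11 | 12 | 13 | def seed_all_rng(seed=None): 14 | """ 15 | Set the random seed for the RNG in torch, numpy and python. 16 | 17 | Args: 18 | seed (int): if None, will use a strong random seed.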
19 | """ 20 | if seed is None: 21 | seed = ( 22 | os.getpid() 23 | + int(datetime.now().strftime("%S%f")) 24 | + int.from_bytes(os.urandom(2), "big") 25 | ) 26 | logger = logging.getLogger(__name__) 27 | logger.info("Using a generated random seed {}".format(seed)) 28 | np.random.seed(seed) 29 | torch.set_rng_state(torch.manual_seed(seed).get_state()) 30 | random.seed(seed) 31 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | """Filesystem utility functions.""" 2 | from __future__ import absolute_import 3 | import os 4 | import errno 5 | import torch 6 | import logging 7 | 8 | from ..config import cfg 9 | 10 | def save_checkpoint(model, epoch, optimizer=None, lr_scheduler=None, is_best=False): 11 | """Save Checkpoint""" 12 | directory = os.path.expanduser(cfg.TRAIN.MODEL_SAVE_DIR) 13 | # directory = os.path.join(directory, '{}_{}_{}_{}'.format(cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, 14 | # cfg.DATASET.NAME, cfg.TIME_STAMP)) 15 | if not os.path.exists(directory): 16 | os.makedirs(directory) 17 | filename = '{}.pth'.format(str(epoch)) 18 | filename = os.path.join(directory, filename) 19 | model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict() 20 | if is_best: 21 | best_filename = 'best_model.pth' 22 | best_filename = os.path.join(directory, best_filename) 23 | torch.save(model_state_dict, best_filename) 24 | else: 25 | if not os.path.exists(filename): 26 | torch.save(model_state_dict, filename) 27 | logging.info('Epoch {} model saved in: {}'.format(epoch, filename)) 28 | 29 | # remove last epoch 30 | pre_filename = '{}.pth'.format(str(epoch - 1)) 31 | pre_filename = os.path.join(directory, pre_filename) 32 | # try: 33 | # if os.path.exists(pre_filename): 34 | # os.remove(pre_filename) 35 | # except OSError as e: 36 | # logging.info(e) 37 |
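# NOTE (illustrative sketch, not from the original file): a minimal call pattern
# for save_checkpoint, assuming cfg.TRAIN.MODEL_SAVE_DIR points at a writable
# directory; `model` may be wrapped in DataParallel, which is unwrapped via
# `model.module` above. `num_epochs` is a hypothetical name:
#
#     for epoch in range(num_epochs):
#         ...  # train one epoch
#         save_checkpoint(model, epoch)             # writes '{epoch}.pth'
#     save_checkpoint(model, epoch, is_best=True)   # writes 'best_model.pth'
38 | def makedirs(path): 39 | """Create directory recursively if it does not exist. 40 | Similar to `mkdir -p`; you can skip checking existence before calling this function.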
41 | Parameters 42 | ---------- 43 | path : str 44 | Path of the desired dir 45 | """ 46 | try: 47 | os.makedirs(path) 48 | except OSError as exc: 49 | if exc.errno != errno.EEXIST: 50 | raise 51 | 52 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/filter_negative.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | # import cv2 4 | import random 5 | import numpy as np 6 | 7 | def cvt2HeatmapImg(img): 8 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 9 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 10 | return img 11 | 12 | def get_negative(img): 13 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 14 | 15 | gray = cv2.GaussianBlur(gray, (3, 3), 0) 16 | canny = cv2.Canny(gray, 50, 150) 17 | kernel = np.ones((5, 5), np.uint8) 18 | canny = cv2.dilate(canny, kernel, iterations=1) 19 | cv2.imwrite("/mnt/lustre/xieenze/wuweijia/CVPR_SemiText/SemiText/TextBoxSeg/workdirs/show/edge.jpg", 20 | canny) 21 | 22 | vis_3 = np.zeros(img.shape[:2]) 23 | blurred = cv2.blur(canny, (9, 9)) 24 | 25 | _, thresh = cv2.threshold(blurred, 5, 1, cv2.THRESH_BINARY) 26 | 27 | thresh = np.array(1 - thresh) 28 | # thresh = cv2.cvtColor(thresh, cv2.COLOR_BGR2GRAY) 29 | 30 | dist = cv2.distanceTransform(thresh.copy().astype(np.uint8), cv2.DIST_L2, 5) 31 | 32 | cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX) 33 | cv2.imwrite("/mnt/lustre/xieenze/wuweijia/CVPR_SemiText/SemiText/TextBoxSeg/workdirs/show/distance.jpg", 34 | cvt2HeatmapImg(dist)) 35 | distance_map = np.array(dist > 0.1) * 1 36 | cv2.imwrite("/mnt/lustre/xieenze/wuweijia/CVPR_SemiText/SemiText/TextBoxSeg/workdirs/show/final_.jpg", 37 | distance_map*255) 38 | # distance_map = cv2.threshold(dist, 0.5, 1, cv2.THRESH_BINARY) 39 | cnts, _ = cv2.findContours(distance_map.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 40 | cnts.sort(key=lambda x: cv2.contourArea(x), reverse=True) 41 | 42 | for idx, cnt in enumerate(cnts): 43 | # cnt = cnts[idx] 44 | if idx >= 4: 45 | break 46 | if cv2.contourArea(cnt) < 2000 and idx > 0: 47 | break 48 | vis_3 = cv2.fillPoly(vis_3, [cnt], 1) 49 | 50 | vis_3 = vis_3*distance_map 51 | return vis_3 52 | 53 | def order_points(pts): 54 | # pts: the contour points to order 55 | # the returned list stores, in order: top-left, top-right, bottom-right, bottom-left 56 | rect = np.zeros((4, 2), dtype = "float32") 57 | # the top-left point has the smallest coordinate sum, the bottom-right the largest 58 | s = pts.sum(axis = 1) 59 | rect[0] = pts[np.argmin(s)] 60 | rect[2] = pts[np.argmax(s)] 61 | # compute the per-point difference (y - x): 62 | # the top-right point has the smallest difference, 63 | # the bottom-left point has the largest difference 64 | diff = np.diff(pts, axis = 1) 65 | rect[1] = pts[np.argmin(diff)] 66 | rect[3] = pts[np.argmax(diff)] 67 | # return the ordered coordinates (top-left, top-right, bottom-right, bottom-left) 68 | return rect 69 |
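# NOTE (illustrative worked example, not from the original file) for the sum/diff
# corner-ordering trick above; the coordinates are arbitrary:
#
#     pts = np.array([[10, 90], [90, 10], [10, 10], [90, 90]], dtype='float32')
#     # sums:          100       100       20        180   -> min = (10,10) TL, max = (90,90) BR
#     # diffs (y - x):  80       -80        0          0   -> min = (90,10) TR, max = (10,90) BL
#     order_points(pts)
#     # -> [[10, 10], [90, 10], [90, 90], [10, 90]]  (tl, tr, br, bl)
70 | # 71 | image_path = "/mnt/lustre/xieenze/wuweijia/CVPR_SemiText/SemiText/TextBoxSeg/workdirs/show/IMG_0003.jpg" 72 | 73 | # image_list = os.listdir(image_path) 74 | # idx = random.randint(0,1000) 75 | # image_path_ = os.path.join(image_path,image_list[idx]) 76 | img = cv2.imread(image_path) 77 | background = get_negative(img) 78 | # cv2.imwrite("/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet_self_training/workdirs/icdar15/image_result_1.jpg",boxing) 79 | # vis_1 = img.copy() 80 | # vis_2 = img.copy() 81 | # vis_3 = np.zeros_like(img) 82 | # 83 | # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 84 | # 85 | # gray = cv2.GaussianBlur(gray,(3,3),0) 86 | # canny = cv2.Canny(gray, 50, 150) 87 | # 88 | # blurred = cv2.blur(canny, (9, 9)) 89 | # 90 | # _, thresh = cv2.threshold(blurred, 5, 1, cv2.THRESH_BINARY) 91 | # # print(thresh.max())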
92 | # # cnts, _ = cv2.findContours( (1-thresh).copy(), cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE) 93 | # # c = sorted(cnts, key=cv2.contourArea, reverse=True)[0] 94 | # # first find the contour points 95 | # #### distance transform 96 | # thresh = np.array(1-thresh) 97 | # # thresh = cv2.cvtColor(thresh, cv2.COLOR_BGR2GRAY) 98 | # 99 | # dist = cv2.distanceTransform(thresh.copy().astype(np.uint8), cv2.DIST_L2, 5) 100 | # cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX) 101 | # 102 | # distance_map = np.array(dist>0.3)*1 103 | # # distance_map = cv2.threshold(dist, 0.5, 1, cv2.THRESH_BINARY) 104 | # cnts, _ = cv2.findContours( distance_map.astype(np.uint8), cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE) 105 | # cnts.sort(key=lambda x:cv2.contourArea(x), reverse=True) 106 | # 107 | # for idx,cnt in enumerate(cnts): 108 | # # cnt = cnts[idx] 109 | # if idx>=3: 110 | # break 111 | # if cv2.contourArea(cnt)<2500 and idx>0: 112 | # break 113 | # vis_3 = cv2.fillPoly(vis_3,[cnt],255) 114 | # # rect = order_points(c.reshape(c.shape[0], 2)) 115 | # # print(rect) 116 | # print(vis_3.shape) 117 | # # cv2.fillPoly(vis_3,np.array([rect]).astype(np.int32),255) 118 | # # merge the images for side-by-side visualization 119 | # canny = np.expand_dims(canny, axis=2) 120 | # canny = np.concatenate((canny, canny, canny), axis=-1) 121 | # 122 | # # dist = np.expand_dims(dist, axis=2) 123 | # dist = np.expand_dims(dist, axis=2) 124 | # dist = np.concatenate((dist, dist, dist), axis=-1) 125 | # print(canny.shape) 126 | # boxing_list = [img, canny, dist*255, vis_3] 127 | # boxing = np.concatenate(boxing_list, axis=1) 128 | # 129 | # cv2.imwrite("/home/xjc/Desktop/CVPR_SemiText/SemiText/PSENet_self_training/workdirs/icdar15/image_result_1.jpg",boxing) -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/line_chart.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # line chart 4 | x = ["0.5","1","5","10","15","20"]  # x-coordinates of the points 5 | k1 = [200,300,500,800,1300,2000]  # y-values of line 1 6 | k2 = [40,60,100,160,240,340]  # y-values of line 2 7 | # plt.legend(loc=2,prop={'size':6}) 8 | plt.plot(x,k1,'s-',color = 'r',label="Annotation Cost")  # 's-': square markers 9 | plt.plot(x,k2,'o-',color = 'g',label="Collection Cost")  # 'o-': circle markers 10 | # plt.xticks(fontsize=15) 11 | # plt.yticks(fontsize=15) 12 | plt.xlabel("Data Scale(k)",size=15)  # x-axis label 13 | plt.ylabel("Cost($)",size=15)  # y-axis label 14 | plt.legend(loc = "best",prop={'size':12})  # legend 15 | plt.savefig("filename.png") 16 | # plt.show() -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | __all__ = ['setup_logger'] 6 | 7 | 8 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'): 9 | if distributed_rank > 0: 10 | return 11 | 12 | logging.root.name = name 13 | logging.root.setLevel(logging.INFO) 14 | # don't log results for the non-master process 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.DEBUG) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | logging.root.addHandler(ch) 20 | 21 | if save_dir: 22 | if not os.path.exists(save_dir): 23 | os.makedirs(save_dir) 24 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) # 'a+' to append, 'w' to overwrite 25 | fh.setLevel(logging.DEBUG) 26 | fh.setFormatter(formatter) 27 |
logging.root.addHandler(fh) 28 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description='Segmentron') 5 | parser.add_argument('--config-file',default="../configs/textseg_total.yaml", metavar="FILE", 6 | help='config file path') 7 | # cuda setting 8 | parser.add_argument('--no-cuda', action='store_true', default=False, 9 | help='disables CUDA training') 10 | parser.add_argument('--local_rank', type=int, default=0) 11 | # checkpoint and log 12 | parser.add_argument('--resume', type=str, default=None, 13 | help='put the path to resuming file if needed') 14 | parser.add_argument('--log-iter', type=int, default=10, 15 | help='print log every log-iter') 16 | # for evaluation 17 | parser.add_argument('--val-epoch', type=int, default=1, 18 | help='run validation every val-epoch') 19 | parser.add_argument('--skip-val', action='store_true', default=False, 20 | help='skip validation during training') 21 | # for visual 22 | parser.add_argument('--input-img', type=str, default='tools/demo_vis.png', 23 | help='path to the input image or a directory of images') 24 | # config options 25 | parser.add_argument('opts', help='See config for all options', 26 | default=None, nargs=argparse.REMAINDER) 27 | args = parser.parse_args() 28 | 29 | return args -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/registry.py: -------------------------------------------------------------------------------- 1 | # this code is heavily based on detectron2 2 | 3 | import logging 4 | import torch 5 | 6 | from ..config import cfg 7 | 8 | class Registry(object): 9 | """ 10 | The registry that provides name -> object mapping, to support third-party users' custom modules. 11 | 12 | To create a registry (inside segmentron): 13 | 14 | .. code-block:: python 15 | 16 | BACKBONE_REGISTRY = Registry('BACKBONE') 17 | 18 | To register an object: 19 | 20 | .. code-block:: python 21 | 22 | @BACKBONE_REGISTRY.register() 23 | class MyBackbone(): 24 | ... 25 | 26 | Or: 27 | 28 | .. code-block:: python 29 | 30 | BACKBONE_REGISTRY.register(MyBackbone) 31 | """ 32 | 33 | def __init__(self, name): 34 | """ 35 | Args: 36 | name (str): the name of this registry 37 | """ 38 | self._name = name 39 | 40 | self._obj_map = {} 41 | 42 | def _do_register(self, name, obj): 43 | assert ( 44 | name not in self._obj_map 45 | ), "An object named '{}' was already registered in '{}' registry!".format(name, self._name) 46 | self._obj_map[name] = obj 47 | 48 | def register(self, obj=None, name=None): 49 | """ 50 | Register the given object under the name `obj.__name__`. 51 | Can be used as either a decorator or not. See docstring of this class for usage.
52 | """ 53 | if obj is None: 54 | # used as a decorator 55 | def deco(func_or_class, name=name): 56 | if name is None: 57 | name = func_or_class.__name__ 58 | self._do_register(name, func_or_class) 59 | return func_or_class 60 | 61 | return deco 62 | 63 | # used as a function call 64 | if name is None: 65 | name = obj.__name__ 66 | self._do_register(name, obj) 67 | 68 | 69 | 70 | def get(self, name): 71 | ret = self._obj_map.get(name) 72 | if ret is None: 73 | raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name)) 74 | 75 | return ret 76 | 77 | def get_list(self): 78 | return list(self._obj_map.keys()) 79 | -------------------------------------------------------------------------------- /TextBoxSeg/segmentron/utils/show.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import scipy.io 4 | import random 5 | 6 | 7 | def mask_image(image,mask_2d): 8 | h, w = mask_2d.shape 9 | 10 | # mask_3d = np.ones((h, w), dtype="uint8") * 255 11 | mask_3d_color = np.zeros((h, w, 3), dtype="uint8") 12 | # mask_3d[mask_2d[:, :] == 1] = 0 13 | 14 | image.astype("uint8") 15 | # beijing = cv2.bitwise_and(image, image, mask=mask_3d) 16 | # R = random.randint(100,200) 17 | # G = random.randint(100,200) 18 | # B = random.randint(100,200) 19 | 20 | # mask_3d_color[mask_2d[:, :] == 1] = np.random.randint(50, 200, (1, 3), dtype=np.uint8) 21 | # 22 | # add_image = cv2.add(image, mask_3d_color) 23 | 24 | mask = (mask_2d!=0).astype(bool) 25 | 26 | mask_3d_color[mask_2d[:, :] == 1] = np.random.randint(0, 255, (1, 3), dtype=np.uint8) 27 | image[mask] = image[mask] * 0.5 + mask_3d_color[mask] * 0.5 28 | 29 | return image 30 | 31 | def read_mat_lindes(p): 32 | f = scipy.io.loadmat(p) 33 | return f 34 | 35 | def get_bboxes(gt_path): 36 | 37 | bboxes = [] 38 | tags = [] 39 | point_nums = [] 40 | data = read_mat_lindes(gt_path) 41 | data_polygt = data['polygt'] 42 | # for lines in data_polygt: 43 | for i, lines in enumerate(data_polygt): 44 | X = np.array(lines[1]) 45 | Y = np.array(lines[3]) 46 | 47 | point_num = len(X[0]) 48 | point_nums.append(point_num) 49 | word = np.array(lines[4]) 50 | if len(word) == 0: 51 | word = '?' 
52 | else: 53 | word = str(word[0].encode("utf-8")) 54 | 55 | if "#" in word: 56 | tags.append(True) 57 | else: 58 | tags.append(False) 59 | arr = np.concatenate([X, Y]).T 60 | 61 | # x, y, w_1, h_1 = cv2.boundingRect(arr.astype('int32')) 62 | # box=[x,y,x+w_1,y,x+w_1,y+h_1,x,y+h_1] 63 | # box = np.asarray(box) / ([w * 1.0, h * 1.0] * 4) 64 | 65 | # box = [] 66 | 67 | # for i in range(point_num): 68 | # box.append(arr[i][0]) 69 | # box.append(arr[i][1]) 70 | 71 | box = np.asarray(arr) 72 | bboxes.append(box) 73 | 74 | return bboxes 75 | # gt_path = "/mnt/lustre/share_data/xieenze/wuweijia/Total_text/pseudolabel_attention/Train_bbox/poly_gt_img89.mat" 76 | # image = cv2.imread("/mnt/lustre/share_data/xieenze/wuweijia/Total_text/Images/Train/img89.jpg") 77 | # gt = get_bboxes(gt_path) 78 | # for i,gt_one in enumerate(gt): 79 | # if i==1: 80 | # x_,y_ = 586, 713 81 | # # continue 82 | # cv2.drawContours(image, [gt_one], 0, (0, 0, 255), 15) 83 | # 84 | # gt_bpath = "/mnt/lustre/share_data/xieenze/wuweijia/Total_text/pseudolabel_attention/Train_pseudo/poly_gt_img89.mat" 85 | # 86 | # 87 | # mask = cv2.imread("/mnt/lustre/xieenze/wuweijia/CVPR_SemiText/SemiText/TextBoxSeg/workdirs/show/resultd.png") 88 | # patch = image[713:1700, 586:1089] 89 | # mask = cv2.resize(mask,(patch.shape[1],patch.shape[0]))[:,:,0] 90 | # try: 91 | # _, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 92 | # except: 93 | # contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 94 | # gt = get_bboxes(gt_bpath) 95 | # for i,gt_one in enumerate(gt): 96 | # if i==1: 97 | # # continue 98 | # gt_one = contours[0] 99 | # print(gt_one.shape) 100 | # gt_one = gt_one[:, 0, :] 101 | # 102 | # gt_one[:,0] += 586 103 | # gt_one[:, 1] += 713 104 | # 105 | # mask_1 = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) 106 | # cv2.fillPoly(mask_1, [gt_one], 1) 107 | # image = mask_image(image, mask_1) 108 | # cv2.drawContours(image, [gt_one], 0, (0, 255, 0), 10) 109 | # 110 | # 111 | # cv2.imwrite("/mnt/lustre/xieenze/wuweijia/CVPR_SemiText/SemiText/TextBoxSeg/workdirs/show/show.jpg",image) 112 | # mask_2d = cv2.imread("/home/xjc/Desktop/CVPR_SemiText/SemiText/TextBoxSeg/demo/resultd.png",cv2.IMREAD_GRAYSCALE) 113 | 114 | # print(mask_2d.shape) 115 | 116 | # cv2.imwrite("/home/xjc/Desktop/CVPR_SemiText/SemiText/TextBoxSeg/demo/show.jpg",add_image) 117 | -------------------------------------------------------------------------------- /TextBoxSeg/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | from setuptools import find_packages, setup 6 | import torch 7 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 8 | 9 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 10 | assert torch_ver >= [1, 1], "Requires PyTorch >= 1.1" 11 | 12 | 13 | def get_extensions(): 14 | this_dir = os.path.dirname(os.path.abspath(__file__)) 15 | extensions_dir = os.path.join(this_dir, "segmentron", "modules", "csrc") 16 | 17 | main_source = os.path.join(extensions_dir, "vision.cpp") 18 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp")) 19 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob( 20 | os.path.join(extensions_dir, "*.cu") 21 | ) 22 | 23 | sources = [main_source] + sources 24 | 25 | extension = CppExtension 26 | 27 | extra_compile_args = {"cxx": []} 28 | define_macros = [] 29 | 30 | if 
(torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | 41 | # It's better if PyTorch can do this by default... 42 | CC = os.environ.get("CC", None) 43 | if CC is not None: 44 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC)) 45 | 46 | sources = [os.path.join(extensions_dir, s) for s in sources] 47 | 48 | include_dirs = [extensions_dir] 49 | 50 | ext_modules = [ 51 | extension( 52 | "segmentron._C", 53 | sources, 54 | include_dirs=include_dirs, 55 | define_macros=define_macros, 56 | extra_compile_args=extra_compile_args, 57 | ) 58 | ] 59 | 60 | return ext_modules 61 | 62 | 63 | setup( 64 | name="segmentron", 65 | version="0.1", 66 | author="LikeLy-Journey", 67 | url="https://github.com/LikeLy-Journey/SegmenTron", 68 | description="platform for semantic segmentation based on PyTorch.", 69 | # packages=find_packages(exclude=("configs", "tests")), 70 | # python_requires=">=3.6", 71 | # install_requires=[ 72 | # "termcolor>=1.1", 73 | # "Pillow", 74 | # "yacs>=0.1.6", 75 | # "tabulate", 76 | # "cloudpickle", 77 | # "matplotlib", 78 | # "tqdm>4.29.0", 79 | # "tensorboard", 80 | # ], 81 | # extras_require={"all": ["shapely", "psutil"]}, 82 | ext_modules=get_extensions(), 83 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 84 | ) 85 | -------------------------------------------------------------------------------- /TextBoxSeg/tools/demo_ic15.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | cur_path = os.path.abspath(os.path.dirname(__file__)) 6 | root_path = os.path.split(cur_path)[0] 7 | sys.path.append(root_path) 8 | 9 | from torchvision import transforms 10 | from PIL import Image 11 | from segmentron.utils.visualize import get_color_pallete 12 | from segmentron.models.model_zoo import get_segmentation_model 13 | from segmentron.utils.options import parse_args 14 | from segmentron.utils.default_setup import default_setup 15 | from segmentron.config import cfg 16 | from IPython import embed 17 | import numpy as np 18 | from tqdm import trange 19 | import cv2 20 | 21 | def demo(): 22 | args = parse_args() 23 | cfg.update_from_file(args.config_file) 24 | cfg.PHASE = 'test' 25 | cfg.ROOT_PATH = root_path 26 | cfg.check_and_freeze() 27 | default_setup(args) 28 | 29 | # output folder 30 | output_dir = 'demo/trash/IC15' 31 | if not os.path.exists(output_dir): 32 | os.makedirs(output_dir) 33 | 34 | # image transform 35 | transform = transforms.Compose([ 36 | transforms.ToTensor(), 37 | transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD), 38 | ]) 39 | 40 | model = get_segmentation_model().to(args.device) 41 | model.eval() 42 | 43 | # get img_patch from IC15 44 | if os.path.exists('/mnt/lustre/share_data/xieenze/xez_space/Text/ICDAR2015/'): 45 | ic15_root_path = '/mnt/lustre/share_data/xieenze/xez_space/Text/ICDAR2015/' 46 | else: 47 | ic15_root_path = '/mnt/lustre/share/xieenze/Text/ICDAR2015/' 48 | ic15_train_data = ic15_root_path + 'ch4_training_images' 49 | ic15_train_gt = ic15_root_path + 'ch4_training_localization_transcription_gt' 50 | assert os.path.exists(ic15_train_data) and os.path.exists(ic15_train_gt) 51 | 52 | 53 | patch_imgs = [] 54 | for i in
trange(1, 501): 55 | img_path = 'img_{}.jpg'.format(i) 56 | img_path = os.path.join(ic15_train_data, img_path) 57 | gt_path = 'gt_img_{}.txt'.format(i) 58 | gt_path = os.path.join(ic15_train_gt, gt_path) 59 | 60 | if os.path.exists(gt_path) and os.path.exists(img_path): 61 | img, boxes = parse_img_gt(img_path, gt_path) 62 | img = np.array(img) 63 | if boxes == []: 64 | continue 65 | for box in boxes: 66 | x1, y1, x2, y2 = box 67 | patch = img[y1:y2 + 1, x1:x2 + 1] 68 | patch_imgs.append(Image.fromarray(patch)) 69 | # only evaluate 500 patches for now 70 | if len(patch_imgs) > 500: 71 | break 72 | else: 73 | print(img_path) 74 | print('total patch images:{}'.format(len(patch_imgs))) 75 | 76 | pool_imgs, pool_masks = [], [] 77 | count = 0 78 | for image in patch_imgs: 79 | # image = Image.open(img_path).convert('RGB') 80 | resized_img = image.resize(cfg.TRAIN.BASE_SIZE) 81 | resized_img = transform(resized_img).unsqueeze(0).to(args.device) 82 | with torch.no_grad(): 83 | output = model(resized_img) 84 | 85 | pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy() 86 | 87 | img = np.array(image.resize(cfg.TRAIN.BASE_SIZE)) 88 | mask = np.array(get_color_pallete(pred, cfg.DATASET.NAME))[:,:,None].repeat(3,-1) * 255 89 | if len(pool_imgs)<20: 90 | pool_imgs.append(img) 91 | pool_masks.append(mask) 92 | else: 93 | big_img = np.concatenate(pool_imgs, axis=0) 94 | big_mask = np.concatenate(pool_masks, axis=0) 95 | big_img_mask = Image.fromarray(np.concatenate([big_img, big_mask], axis=1)) 96 | big_img_mask.save('{}/{}.png'.format(output_dir, count)) 97 | print('{}/{}.png'.format(output_dir, count)) 98 | count += 1 99 | pool_imgs, pool_masks = [], [] 100 | 101 | 102 | def parse_img_gt(img_path, gt_path): 103 | img = Image.open(img_path) 104 | with open(gt_path,'r') as f: 105 | data=f.readlines() 106 | boxes = [] 107 | for d in data: 108 | d = d.replace('\n','').split(',') 109 | polygon = d[:8]; text = d[8] 110 | if "#" in text: 111 | continue # skip ignored (###) annotations 112 | polygon = [int(i.replace('\ufeff','')) for i in polygon] 113 | polygon_np = np.array(polygon).reshape([-1, 2]) 114 | x, y, w, h = cv2.boundingRect(polygon_np) 115 | boxes.append([x,y,x+w,y+h]) 116 | return img,boxes 117 | 118 | if __name__ == '__main__': 119 | demo() -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_image.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_mask.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_mask_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_mask_1.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_mask_2.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_mask_2.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_mask_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_mask_3.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_mask_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_mask_4.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_mask_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_mask_5.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_mask_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_mask_6.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/show/108_origin_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/TextBoxSeg/tools/show/108_origin_image.jpg -------------------------------------------------------------------------------- /TextBoxSeg/tools/test_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | cur_path = os.path.abspath(os.path.dirname(__file__)) 6 | root_path = os.path.split(cur_path)[0] 7 | sys.path.append(root_path) 8 | 9 | from torchvision import transforms 10 | from PIL import Image 11 | from segmentron.utils.visualize import get_color_pallete 12 | from segmentron.models.model_zoo import get_segmentation_model 13 | from segmentron.utils.options import parse_args 14 | from segmentron.utils.default_setup import default_setup 15 | from segmentron.config import cfg 16 | from IPython import embed 17 | 18 | def demo(): 19 | args = parse_args() 20 | cfg.update_from_file(args.config_file) 21 | cfg.PHASE = 'test' 22 | cfg.ROOT_PATH = root_path 23 | cfg.check_and_freeze() 24 | default_setup(args) 25 | 26 | # output folder 27 | output_dir = 'demo/result' 28 | if not os.path.exists(output_dir): 29 | os.makedirs(output_dir) 30 | 31 | # image transform 32 | transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD), 35 | ]) 36 | 37 | model = get_segmentation_model().to(args.device) 38 | 
model.eval() 39 | 40 | if os.path.isdir(args.input_img): 41 | img_paths = [os.path.join(args.input_img, x) for x in os.listdir(args.input_img)] 42 | else: 43 | img_paths = [args.input_img] 44 | for img_path in img_paths: 45 | image = Image.open(img_path).convert('RGB') 46 | resized_img = image.resize(cfg.TRAIN.BASE_SIZE) 47 | resized_img = transform(resized_img).unsqueeze(0).to(args.device) 48 | with torch.no_grad(): 49 | output = model(resized_img) 50 | 51 | pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy() 52 | mask = get_color_pallete(pred, cfg.DATASET.NAME).resize(image.size) 53 | outname = os.path.splitext(os.path.split(img_path)[-1])[0] + '.png' 54 | print('generate {}'.format(outname)) 55 | mask.save(os.path.join(output_dir, outname)) 56 | 57 | 58 | if __name__ == '__main__': 59 | demo() -------------------------------------------------------------------------------- /image/1606325537.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weijiawu/Polygon-free-Unconstrained-Scene-Text-Detection-with-Box-Annotations/a377c12b31259c56995f730423209ea79af68009/image/1606325537.png --------------------------------------------------------------------------------