├── .gitignore
├── README.md
├── convert
│   ├── __init__.py
│   ├── crop_rec.py
│   ├── det
│   │   ├── ArtS2json.py
│   │   ├── LSVT2json.py
│   │   ├── MTWI20182json.py
│   │   ├── RcCTS2json.py
│   │   ├── SROIE2json.py
│   │   ├── SynthText800k2json.py
│   │   ├── __init__.py
│   │   ├── check_json.py
│   │   ├── coco_text.py
│   │   ├── coco_text2json.py
│   │   ├── convert2jpg.py
│   │   ├── icdar20152json.py
│   │   ├── icdar2017rctw2json.py
│   │   ├── iflytek_text_detection.py
│   │   └── mlt20192json.py
│   ├── move_imgs.py
│   ├── rec
│   │   ├── 360w2txt.py
│   │   ├── __init__.py
│   │   ├── baidu2txt.py
│   │   └── mjsyhtn2txt.py
│   ├── simsun.ttc
│   └── utils.py
├── dataset
│   ├── __init__.py
│   ├── convert_det2lmdb.py
│   ├── det.py
│   ├── det_lmdb.py
│   └── rec.py
├── gt_detection.json
└── ocr公开数据集信息.xlsx
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pth
3 | *.pyc
4 | *.pyo
5 | *.log
6 | *.tmp
7 | *.pkl
8 | __pycache__/
9 | .idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Todo
2 | 
3 | - [x] Provide Baidu Netdisk links for the datasets
4 | - [x] Convert the datasets to a unified format (detection and recognition)
5 |   - [x] icdar2015
6 |   - [x] MLT2019
7 |   - [x] COCO-Text_v2
8 |   - [x] ReCTS
9 |   - [x] SROIE
10 |   - [x] ArT
11 |   - [x] LSVT
12 |   - [x] Synth800k
13 |   - [x] icdar2017rctw
14 |   - [x] MTWI 2018
15 |   - [x] Baidu Chinese scene text recognition
16 |   - [x] mjsynth
17 |   - [x] Synthetic Chinese String Dataset (3.6M Chinese images)
18 |   - [x] English recognition data bundle
19 | - [x] Provide reader scripts
20 | 
21 | # Download
22 | After downloading a dataset, remember to change the paths recorded in its annotation files to your own local paths.
23 | 
24 | Shared via Baidu Netdisk: all datasets packed together…
25 | Link: https://pan.baidu.com/s/1TkTWql2XxqPLDnFmVvHsUA?pwd=4358
26 | Extraction code: 4358
27 | Copy this message and open the Baidu Netdisk app to download.
28 | 
29 | # Datasets
30 | 
31 | | Dataset | Homepage | Task | Data | Annotation format | Notes |
32 | | --- | --- | --- | --- | --- | --- |
33 | | ICDAR2015 | https://rrc.cvc.uab.es/?ch=4 | Detection & recognition | Language: English; train: 1,000; test: 500 | x1, y1, x2, y2, x3, y3, x4, y4, transcription | Coordinates: x1, y1, x2, y2, x3, y3, x4, y4; transcription: the text inside the box |
34 | | MLT2019 | https://rrc.cvc.uab.es/?ch=15 | Detection & recognition | Language: mixed; train: 10,000; test: 10,000 | x1,y1,x2,y2,x3,y3,x4,y4,script,transcription | Coordinates: x1, y1, x2, y2, x3, y3, x4, y4; script: language of the text; transcription: the text inside the box |
35 | | COCO-Text_v2 | https://bgshih.github.io/cocotext/ | Detection & recognition | Language: mixed; train: 43,686; validation: 10,000; test: 10,000 | json | |
36 | | ReCTS | https://rrc.cvc.uab.es/?ch=12&com=introduction | Detection & recognition | Language: mixed; train: 20,000; test: 5,000 | { "chars": [ {"points": [x1,y1,x2,y2,x3,y3,x4,y4], "transcription": "trans1", "ignore": 0 }, {"points": [x1,y1,x2,y2,x3,y3,x4,y4], "transcription": "trans2", "ignore": 0 }], "lines": [ {"points": [x1,y1,x2,y2,x3,y3,x4,y4], "transcription": "trans3", "ignore": 0 }] } | points: x1,y1,x2,y2,x3,y3,x4,y4; chars: character-level annotations; lines: line-level annotations; transcription: the text inside the box; ignore: 0 = keep, 1 = ignore |
37 | | SROIE | https://rrc.cvc.uab.es/?ch=13 | Detection & recognition | Language: English; train: 699; test: 400 | x1, y1, x2, y2, x3, y3, x4, y4, transcription | Coordinates: x1, y1, x2, y2, x3, y3, x4, y4; transcription: the text inside the box |
38 | | ArT (includes Total-Text and SCUT-CTW1500) | https://rrc.cvc.uab.es/?ch=14 | Detection & recognition | Language: mixed; train: 5,603; test: 4,563 | { "gt_1": [ {"points": [[x1, y1], [x2, y2], …, [xn, yn]], "transcription": "trans1", "language": "Latin", "illegibility": false }, {"points": [[x1, y1], [x2, y2], …, [xn, yn]], "transcription": "trans2", "language": "Chinese", "illegibility": false }] } | points: x1,y1,x2,y2,x3,y3,x4,y4…xn,yn; transcription: the text inside the box; language: script of the text; illegibility: whether the text is illegible |
39 | | LSVT | https://rrc.cvc.uab.es/?ch=16 | Detection & recognition | Language: mixed; fully annotated: train 30,000, test 20,000; transcription-only: 400,000 | { "gt_1": [ {"points": [[x1, y1], [x2, y2], …, [xn, yn]], "transcription": "trans1", "illegibility": false }, {"points": [[x1, y1], [x2, y2], …, [xn, yn]], "transcription": "trans2", "illegibility": false }] } | points: x1,y1,x2,y2,x3,y3,x4,y4…xn,yn; transcription: the text inside the box; illegibility: whether the text is illegible |
40 | | Synth800k | http://www.robots.ox.ac.uk/~vgg/data/scenetext/ | Detection & recognition | Language: English; 800,000 | imnames: wordBB: charBB: txt: | imnames: file names; wordBB: 2*4*n word boxes per image; charBB: 2*4*n character boxes per image; txt: the strings in each image |
41 | | icdar2017rctw | https://blog.csdn.net/wl1710582732/article/details/89761818 | Detection & recognition | Language: mixed; train: 8,034; test: 4,229 | x1,y1,x2,y2,x3,y3,x4,y4,<difficulty>,transcription | Coordinates: x1, y1, x2, y2, x3, y3, x4, y4; transcription: the text inside the box |
42 | | MTWI 2018 | [Recognition: https://tianchi.aliyun.com/competition/entrance/231684/introduction](https://tianchi.aliyun.com/competition/entrance/231684/introduction) [Detection: https://tianchi.aliyun.com/competition/entrance/231685/introduction](https://tianchi.aliyun.com/competition/entrance/231685/introduction) | Detection & recognition | Language: mixed; train: 10,000; test: 10,000 | x1, y1, x2, y2, x3, y3, x4, y4, transcription | Coordinates: x1, y1, x2, y2, x3, y3, x4, y4; transcription: the text inside the box |
43 | | Baidu Chinese scene text recognition | https://aistudio.baidu.com/aistudio/competition/detail/20 | Recognition | Language: mixed; train: not counted; test: not counted | h,w,name,value | h: image height; w: image width; name: image file name; value: the text in the image |
44 | | mjsynth | http://www.robots.ox.ac.uk/~vgg/data/text/ | Recognition | Language: English; 9,000,000 | - | - |
45 | | Synthetic Chinese String Dataset (3.6M Chinese images) | Link: https://pan.baidu.com/s/1jefn4Jh4jHjQdiWoanjKpQ extraction code: spyi | Recognition | Language: mixed; 3,600,000 | - | - |
46 | | English recognition data bundle (https://github.com/clovaai/deep-text-recognition-benchmark); training: MJSynth and SynthText; validation: IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE | Link: https://pan.baidu.com/s/1KSNLv4EY3zFWHpBYlpFCBQ extraction code: rryk | Recognition | Language: English | - | - |
47 | 
48 | # Data generation tools
49 | 
50 | https://github.com/TianzhongSong/awesome-SynthText
51 | 
52 | # Dataset reader scripts
53 | - [Detection reader script](dataset/det.py)
54 | - [Recognition reader script](dataset/rec.py)
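
A minimal sketch of reading a converted detection annotation file with the bundled loader (the paths below are placeholders; `dataset/det.py` contains the full runnable demo):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset.det import DetDataSet

# placeholder path: any detection json produced by the convert scripts
json_path = 'path/to/detection/train.json'

dataset = DetDataSet(json_path, transform=transforms.ToTensor())
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
for item in loader:
    # each item bundles the image tensor with its polygons, texts,
    # illegibility flags and language tags
    print(item['img_path'], item['texts'])
    break
```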
55 | 
--------------------------------------------------------------------------------
/convert/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/3/18 14:11
3 | # @Author : zhoujun
--------------------------------------------------------------------------------
/convert/crop_rec.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/3/20 20:55
3 | # @Author : zhoujun
4 | """
5 | Crop recognition training samples out of the generated unified-format json files.
6 | """
7 | import os
8 | import cv2
9 | import math
10 | import shutil
11 | import pathlib
12 | import numpy as np
13 | from tqdm import tqdm
14 | from PIL import Image
15 | from matplotlib import pyplot as plt
16 | 
17 | # allow matplotlib to render Chinese text
18 | plt.rcParams['font.sans-serif'] = ['SimHei']  # display CJK labels correctly
19 | plt.rcParams['axes.unicode_minus'] = False  # display the minus sign correctly
20 | 
21 | from convert.utils import load_gt, save
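# Added note: crop() below reads a unified-format gt json via load_gt,
# skips regions marked illegible, rectifies 4-point quads with a
# perspective transform, falls back to the axis-aligned bounding box for
# polygons with more points, and writes one tab-separated label line per crop.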
22 | 
23 | 
24 | def order_points(pts):
25 |     # initialize the ordered corner array
26 |     rect = np.zeros((4, 2), dtype="float32")
27 |     # the top-left point has the smallest x+y sum, the bottom-right the largest
28 |     s = pts.sum(axis=1)
29 |     rect[0] = pts[np.argmin(s)]
30 |     rect[2] = pts[np.argmax(s)]
31 |     # the top-right point has the smallest y-x difference, the bottom-left the largest
32 |     diff = np.diff(pts, axis=1)
33 |     rect[1] = pts[np.argmin(diff)]
34 |     rect[3] = pts[np.argmax(diff)]
35 |     return rect
36 | 
37 | 
38 | def four_point_transform(image, pts):
39 |     # order the corner points and unpack them
40 |     rect = original_coordinate_transformation(pts)
41 |     (tl, tr, br, bl) = rect
42 | 
43 |     # width of the output image: the longer of the two horizontal edges
44 |     widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
45 |     widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
46 |     maxWidth = max(int(widthA), int(widthB))
47 | 
48 |     # height of the output image: the longer of the two vertical edges
49 |     heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
50 |     heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
51 |     maxHeight = max(int(heightA), int(heightB))
52 | 
53 |     # the four corner points of the output image
54 |     dst = np.array([
55 |         [0, 0],
56 |         [maxWidth - 1, 0],
57 |         [maxWidth - 1, maxHeight - 1],
58 |         [0, maxHeight - 1]], dtype="float32")
59 | 
60 |     # compute the perspective transform matrix and apply it
61 |     M = cv2.getPerspectiveTransform(rect, dst)
62 |     # apply the perspective warp
63 |     warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
64 | 
65 |     # return the rectified crop
66 |     return warped
67 | 
68 | 
69 | def original_coordinate_transformation(polygon):
70 |     """
71 |     Reorder the coordinates to:
72 |     x1,y1 x2,y2
73 |     x4,y4 x3,y3
74 |     :param polygon:
75 |     :return:
76 |     """
77 |     x1, y1, x2, y2, x3, y3, x4, y4 = polygon.astype(float).reshape(-1)
78 |     # compare x1 and x3; make x3 the larger of the two
79 |     if x1 > x3:
80 |         x1, y1, x3, y3 = x3, y3, x1, y1
81 |     # compare x2 and x4; make x4 the larger of the two
82 |     if x2 > x4:
83 |         x2, y2, x4, y4 = x4, y4, x2, y2
84 |     # compare y1 and y2; make y1 the larger of the two
85 |     if y2 > y1:
86 |         x2, y2, x1, y1 = x1, y1, x2, y2
87 |     # compare y3 and y4; make y4 the larger of the two
88 |     if y3 > y4:
89 |         x3, y3, x4, y4 = x4, y4, x3, y3
90 |     return np.array([[x2, y2], [x3, y3], [x4, y4], [x1, y1]], dtype=np.float32)
91 | 
92 | 
93 | def crop(save_gt_path, json_path, save_path):
94 |     if os.path.exists(save_path):
95 |         shutil.rmtree(save_path, ignore_errors=True)
96 |     os.makedirs(save_path, exist_ok=True)
97 |     data = load_gt(json_path)
98 |     file_list = []
99 |     for img_path, gt in tqdm(data.items()):
100 |         img = cv2.imread(img_path)
101 |         np_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
102 |         img = Image.fromarray(np_img)
103 |         img_name = pathlib.Path(img_path).stem
104 |         for i, (polygon, text, illegibility, language) in enumerate(
105 |                 zip(gt['polygons'], gt['texts'], gt['illegibility_list'], gt['language_list'])):
106 |             if illegibility:
107 |                 continue
108 |             polygon = np.array(polygon)
109 |             roi_img_save_path = os.path.join(save_path, '{}_{}.jpg'.format(img_name, i))
110 |             # regions with exactly four points are rectified with a perspective transform before saving
111 |             if len(polygon) == 4:
112 |                 roi_img = four_point_transform(np_img, polygon)
113 |                 roi_img = Image.fromarray(roi_img).convert('RGB')
114 |             else:
115 |                 x_min = polygon[:, 0].min()
116 |                 x_max = polygon[:, 0].max()
117 |                 y_min = polygon[:, 1].min()
118 |                 y_max = polygon[:, 1].max()
119 |                 roi_img = img.crop((x_min, y_min, x_max, y_max))
120 |             roi_img.save(roi_img_save_path)
121 |             roi_w, roi_h = roi_img.size
122 |             file_list.append('{}\t{}\t{}\t{}\t{}'.format(roi_img_save_path, text, roi_w, roi_h, language))
123 |             # plt.title(text)
124 |             # plt.imshow(roi_img)
125 |             # plt.show()
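# Added note: each label line written below has the form
#   crop_path \t transcription \t width \t height \t language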
126 |     save(file_list, save_gt_path)
127 | 
128 | 
129 | if __name__ == '__main__':
130 |     json_path = r'D:\dataset\icdar2017rctw\detection\train.json'
131 |     save_path = r'D:\dataset\icdar2017rctw\recognition\train'
132 |     gt_path = pathlib.Path(save_path).parent / 'train.txt'
133 |     crop(gt_path, json_path, save_path)
134 | 
--------------------------------------------------------------------------------
/convert/det/ArtS2json.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/3/18 14:12
3 | # @Author : zhoujun
4 | """
5 | Convert the ArT dataset to the unified format.
6 | """
7 | import os
8 | from tqdm import tqdm
9 | from convert.utils import load, save
10 | 
11 | 
12 | def cvt_det(gt_path, save_path, img_folder):
13 |     """
14 |     Convert ArT-style gt to the unified json format.
15 |     :param gt_path:
16 |     :param save_path:
17 |     :return:
18 |     """
19 |     gt_dict = {'data_root': img_folder}
20 |     data_list = []
21 |     origin_gt = load(gt_path)
22 |     for img_name, gt in tqdm(origin_gt.items()):
23 |         cur_gt = {'img_name': img_name + '.jpg', 'annotations': []}
24 |         for line in gt:
25 |             cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'}
26 |             chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}]
27 |             cur_line_gt['chars'] = chars_gt
28 |             # line-level information
29 |             cur_line_gt['polygon'] = line['points']
30 |             cur_line_gt['text'] = line['transcription']
31 |             cur_line_gt['illegibility'] = line['illegibility']
32 |             cur_line_gt['language'] = line['language']
33 |             cur_gt['annotations'].append(cur_line_gt)
34 |         data_list.append(cur_gt)
35 |     gt_dict['data_list'] = data_list
36 |     save(gt_dict, save_path)
37 | 
38 | 
39 | def cvt_rec(gt_path, save_path, img_folder):
40 |     origin_gt = load(gt_path)
41 |     file_list = []
42 |     for img_name, gt in tqdm(origin_gt.items()):
43 |         assert len(gt) == 1
44 |         gt = gt[0]
45 |         img_path = os.path.join(img_folder, img_name + '.jpg')
46 |         file_list.append(img_path + '\t' + gt['transcription'] + '\t' + gt['language'])
47 |     save(file_list, save_path)
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     gt_path = r'D:\dataset\Art\detection\gt\train_labels.json'
52 |     img_folder = r'D:\dataset\Art\detection\train_images'
53 |     save_path = r'D:\dataset\Art\detection\train.json'
54 |     cvt_det(gt_path, save_path, img_folder)
55 | 
--------------------------------------------------------------------------------
/convert/det/LSVT2json.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/3/18 14:12
3 | # @Author : zhoujun
4 | """
5 | Convert the LSVT dataset to the unified format.
6 | """
7 | import os
8 | from tqdm import tqdm
9 | from convert.utils import load, save
10 | 
11 | 
12 | def cvt(gt_path, save_path, img_folder):
13 |     """
14 |     Convert LSVT-style gt to the unified json format.
15 |     :param gt_path:
16 |     :param save_path:
17 |     :return:
18 |     """
19 |     gt_dict = {'data_root': img_folder}
20 |     data_list = []
21 |     origin_gt = load(gt_path)
22 |     for img_name, gt in tqdm(origin_gt.items()):
23 |         cur_gt = {'img_name': img_name + '.jpg', 'annotations': []}
24 |         for line in gt:
25 |             cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'}
26 |             chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}]
27 |             cur_line_gt['chars'] = chars_gt
28 |             # line-level information
29 |             cur_line_gt['polygon'] = line['points']
30 |             cur_line_gt['text'] = line['transcription']
31 |             cur_line_gt['illegibility'] = line['illegibility']
32 |             cur_gt['annotations'].append(cur_line_gt)
33 | 
data_list.append(cur_gt) 34 | gt_dict['data_list'] = data_list 35 | save(gt_dict, save_path) 36 | 37 | 38 | if __name__ == '__main__': 39 | gt_path = r'D:\dataset\LSVT\detection\train_full_labels.json' 40 | img_folder = r'D:\dataset\LSVT\detection\imgs' 41 | save_path = r'D:\dataset\LSVT\detection\train.json' 42 | cvt(gt_path, save_path, img_folder) 43 | -------------------------------------------------------------------------------- /convert/det/MTWI20182json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/18 14:12 3 | # @Author : zhoujun 4 | """ 5 | 将icdar2015数据集转换为统一格式 6 | """ 7 | import pathlib 8 | from tqdm import tqdm 9 | from convert.utils import load, save, get_file_list 10 | 11 | 12 | def cvt(gt_path, save_path, img_folder): 13 | """ 14 | 将icdar2015格式的gt转换为json格式 15 | :param gt_path: 16 | :param save_path: 17 | :return: 18 | """ 19 | gt_dict = {'data_root': img_folder} 20 | data_list = [] 21 | for file_path in tqdm(get_file_list(gt_path, p_postfix=['.txt'])): 22 | content = load(file_path) 23 | file_path = pathlib.Path(file_path) 24 | img_name = file_path.name.replace('.txt', '.jpg') 25 | cur_gt = {'img_name': img_name, 'annotations': []} 26 | for line in content: 27 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 28 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 29 | cur_line_gt['chars'] = chars_gt 30 | line = line.split(',') 31 | # 字符串级别的信息 32 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 33 | cur_line_gt['polygon'] = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] 34 | cur_line_gt['text'] = line[-1] 35 | cur_line_gt['illegibility'] = True if cur_line_gt['text'] == '*' or cur_line_gt['text'] == '###' else False 36 | cur_gt['annotations'].append(cur_line_gt) 37 | data_list.append(cur_gt) 38 | gt_dict['data_list'] = data_list 39 | save(gt_dict, save_path) 40 | 41 | 42 | if __name__ == '__main__': 43 | gt_path = r'D:\dataset\MTWI2018\detection\gt' 44 | img_folder = r'D:\dataset\MTWI2018\detection\imgs' 45 | save_path = r'D:\dataset\MTWI2018\detection\train.json' 46 | cvt(gt_path, save_path, img_folder) 47 | -------------------------------------------------------------------------------- /convert/det/RcCTS2json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/18 14:12 3 | # @Author : zhoujun 4 | """ 5 | 将icdar2015数据集转换为统一格式 6 | """ 7 | import pathlib 8 | import numpy as np 9 | from tqdm import tqdm 10 | from convert.utils import load, save, get_file_list 11 | 12 | 13 | def decode_chars(char_list): 14 | polygon_list = [] 15 | illegibility_list = [] 16 | text_list = [] 17 | for char_dict in char_list: 18 | polygon_list.append(np.array(char_dict['points']).reshape(-1, 2).tolist()) 19 | illegibility_list.append(True if char_dict['ignore'] == 1 else False) 20 | text_list.append(char_dict['transcription']) 21 | return polygon_list, illegibility_list, text_list 22 | 23 | 24 | def cvt(gt_path, save_path, img_folder): 25 | """ 26 | 将icdar2015格式的gt转换为json格式 27 | :param gt_path: 28 | :param save_path: 29 | :return: 30 | """ 31 | gt_dict = {'data_root': img_folder} 32 | data_list = [] 33 | for file_path in tqdm(get_file_list(gt_path, p_postfix=['.json'])): 34 | content = load(file_path) 35 | file_path = pathlib.Path(file_path) 36 | img_name = file_path.stem + '.jpg' 37 | cur_gt = {'img_name': img_name, 'annotations': []} 38 | 
char_polygon_list, char_illegibility_list, char_text_list = decode_chars(content['chars']) 39 | for line in content['lines']: 40 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 41 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 42 | cur_line_gt['chars'] = chars_gt 43 | # 字符串级别的信息 44 | cur_line_gt['polygon'] = np.array(line['points']).reshape(-1, 2).tolist() 45 | cur_line_gt['text'] = line['transcription'] 46 | cur_line_gt['illegibility'] = True if line['ignore'] == 1 else False 47 | str_len = len(line['transcription']) 48 | # 字符信息 49 | flag = False 50 | for char_idx in range(len(char_polygon_list)): 51 | for str_idx in range(1, str_len + 1): 52 | if ''.join(char_text_list[char_idx:char_idx + str_idx]) == line['transcription']: 53 | chars_gt = [] 54 | for j in range(char_idx, char_idx + str_idx): 55 | chars_gt.append({'polygon': char_polygon_list[j], 'char': char_text_list[j], 56 | 'illegibility': char_illegibility_list[j], 'language': 'Latin'}) 57 | cur_line_gt['chars'] = chars_gt 58 | char_polygon_list = char_polygon_list[char_idx + str_len:] 59 | char_text_list = char_text_list[char_idx + str_len:] 60 | char_illegibility_list = char_illegibility_list[char_idx + str_len:] 61 | flag = True 62 | break 63 | if flag: 64 | break 65 | cur_gt['annotations'].append(cur_line_gt) 66 | data_list.append(cur_gt) 67 | gt_dict['data_list'] = data_list 68 | save(gt_dict, save_path) 69 | 70 | 71 | if __name__ == '__main__': 72 | gt_path = r'D:\dataset\ReCTS\detection\gt' 73 | img_folder = r'D:\dataset\ReCTS\detection\img' 74 | save_path = r'D:\dataset\ReCTS\detection\train.json' 75 | cvt(gt_path, save_path, img_folder) 76 | -------------------------------------------------------------------------------- /convert/det/SROIE2json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/18 14:12 3 | # @Author : zhoujun 4 | """ 5 | 将icdar2015数据集转换为统一格式 6 | """ 7 | import pathlib 8 | from tqdm import tqdm 9 | from convert.utils import load, save, get_file_list 10 | 11 | 12 | def cvt(gt_path, save_path, img_folder): 13 | """ 14 | 将icdar2015格式的gt转换为json格式 15 | :param gt_path: 16 | :param save_path: 17 | :return: 18 | """ 19 | gt_dict = {'data_root': img_folder} 20 | data_list = [] 21 | for file_path in tqdm(get_file_list(gt_path, p_postfix=['.txt'])): 22 | content = load(file_path) 23 | file_path = pathlib.Path(file_path) 24 | img_name = file_path.name.replace('.txt', '.jpg') 25 | cur_gt = {'img_name': img_name, 'annotations': []} 26 | for line in content: 27 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 28 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 29 | cur_line_gt['chars'] = chars_gt 30 | line = line.split(',') 31 | if len(line) < 9: 32 | continue 33 | # 字符串级别的信息 34 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 35 | cur_line_gt['polygon'] = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] 36 | cur_line_gt['text'] = line[-1] 37 | cur_line_gt['illegibility'] = True if cur_line_gt['text'] == '*' or cur_line_gt['text'] == '###' else False 38 | cur_gt['annotations'].append(cur_line_gt) 39 | data_list.append(cur_gt) 40 | gt_dict['data_list'] = data_list 41 | save(gt_dict, save_path) 42 | 43 | 44 | if __name__ == '__main__': 45 | gt_path = r'D:\dataset\SROIE2019\detection\test\gt' 46 | img_folder = r'D:\dataset\SROIE2019\detection\test\imgs' 47 | save_path = 
r'D:\dataset\SROIE2019\detection\test.json' 48 | cvt(gt_path, save_path, img_folder) 49 | -------------------------------------------------------------------------------- /convert/det/SynthText800k2json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/23 9:29 3 | # @Author : zhoujun 4 | import os 5 | import pathlib 6 | import numpy as np 7 | from tqdm import tqdm 8 | import scipy.io as sio 9 | from convert.utils import save 10 | 11 | 12 | class SynthTextDataset(): 13 | def __init__(self, img_folder: str, gt_path: str): 14 | self.img_folder = img_folder 15 | if not os.path.exists(self.img_folder): 16 | raise FileNotFoundError('Dataset folder is not exist.') 17 | 18 | self.targetFilePath = gt_path 19 | if not os.path.exists(self.targetFilePath): 20 | raise FileExistsError('Target file is not exist.') 21 | targets = {} 22 | sio.loadmat(self.targetFilePath, targets, squeeze_me=True, struct_as_record=False, 23 | variable_names=['imnames', 'wordBB', 'txt']) 24 | 25 | self.imageNames = targets['imnames'] 26 | self.wordBBoxes = targets['wordBB'] 27 | self.transcripts = targets['txt'] 28 | 29 | def cvt(self): 30 | gt_dict = {'data_root': self.img_folder} 31 | data_list = [] 32 | pbar = tqdm(total=len(self.imageNames)) 33 | for imageName, wordBBoxes, texts in zip(self.imageNames, self.wordBBoxes, self.transcripts): 34 | wordBBoxes = np.expand_dims(wordBBoxes, axis=2) if (wordBBoxes.ndim == 2) else wordBBoxes 35 | _, _, numOfWords = wordBBoxes.shape 36 | text_polys = wordBBoxes.reshape([8, numOfWords], order='F').T # num_words * 8 37 | text_polys = text_polys.reshape(numOfWords, 4, 2) # num_of_words * 4 * 2 38 | transcripts = [word for line in texts for word in line.split()] 39 | if numOfWords != len(transcripts): 40 | continue 41 | cur_gt = {'img_name': imageName, 'annotations': []} 42 | for polygon, text in zip(text_polys, transcripts): 43 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 44 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 45 | cur_line_gt['chars'] = chars_gt 46 | cur_line_gt['text'] = text 47 | cur_line_gt['polygon'] = polygon.tolist() 48 | cur_line_gt['illegibility'] = text in ['###', '*'] 49 | cur_gt['annotations'].append(cur_line_gt) 50 | data_list.append(cur_gt) 51 | pbar.update(1) 52 | pbar.close() 53 | gt_dict['data_list'] = data_list 54 | save(gt_dict, save_path) 55 | 56 | 57 | if __name__ == '__main__': 58 | img_folder = r'D:\dataset\SynthText800k\detection\imgs' 59 | gt_path = r'D:\dataset\SynthText800k\detection\gt.mat' 60 | save_path = r'D:\dataset\SynthText800k\detection\train1.json' 61 | synth_dataset = SynthTextDataset(img_folder, gt_path) 62 | synth_dataset.cvt() 63 | -------------------------------------------------------------------------------- /convert/det/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/24 11:09 3 | # @Author : zhoujun -------------------------------------------------------------------------------- /convert/det/check_json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/20 20:33 3 | # @Author : zhoujun 4 | """ 5 | 用于检查生成的json文件有没有问题 6 | """ 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from matplotlib import pyplot as plt 10 | 11 | from convert.utils import show_bbox_on_image, load_gt 12 | 13 | if 
__name__ == '__main__': 14 | json_path = r'D:\dataset\自然场景文字检测挑战赛初赛数据\验证集\validation_new.json' 15 | data = load_gt(json_path) 16 | for img_path, gt in tqdm(data.items()): 17 | # print(gt['illegibility_list']) 18 | # print(gt['texts']) 19 | img = Image.open(img_path) 20 | img = show_bbox_on_image(img, gt['polygons'], gt['texts']) 21 | plt.imshow(img) 22 | plt.show() 23 | -------------------------------------------------------------------------------- /convert/det/coco_text.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | __author__ = 'andreasveit' 6 | __version__ = '2.0' 7 | # Interface for accessing the COCO-Text dataset. 8 | 9 | # COCO-Text is a large dataset designed for text detection and recognition. 10 | # This is a Python API that assists in loading, parsing and visualizing the 11 | # annotations. The format of the COCO-Text annotations is also described on 12 | # the project website http://vision.cornell.edu/se3/coco-text/. In addition to this API, please download both 13 | # the COCO images and annotations. 14 | # This dataset is based on Microsoft COCO. Please visit http://mscoco.org/ 15 | # for more information on COCO, including for the image data, object annotatins 16 | # and caption annotations. 17 | 18 | # An alternative to using the API is to load the annotations directly 19 | # into Python dictionary: 20 | # with open(annotation_filename) as json_file: 21 | # coco_text = json.load(json_file) 22 | # Using the API provides additional utility functions. 23 | 24 | # The following API functions are defined: 25 | # COCO_Text - COCO-Text api class that loads COCO annotations and prepare data structures. 26 | # getAnnIds - Get ann ids that satisfy given filter conditions. 27 | # getImgIds - Get img ids that satisfy given filter conditions. 28 | # loadAnns - Load anns with the specified ids. 29 | # loadImgs - Load imgs with the specified ids. 30 | # loadRes - Load algorithm results and create API for accessing them. 31 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 32 | 33 | # COCO-Text Toolbox. Version 1.1 34 | # Data and paper available at: http://vision.cornell.edu/se3/coco-text/ 35 | # Code based on Microsoft COCO Toolbox Version 1.0 by Piotr Dollar and Tsung-Yi Lin 36 | # extended and adapted by Andreas Veit, 2016. 37 | # Licensed under the Simplified BSD License [see bsd.txt] 38 | 39 | import json 40 | import datetime 41 | import matplotlib.pyplot as plt 42 | from matplotlib.collections import PatchCollection 43 | from matplotlib.patches import Rectangle, PathPatch 44 | from matplotlib.path import Path 45 | import numpy as np 46 | import copy 47 | import os 48 | 49 | class COCO_Text: 50 | def __init__(self, annotation_file=None): 51 | """ 52 | Constructor of COCO-Text helper class for reading and visualizing annotations. 
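Loads the annotation file into memory and builds the image/annotation indices via createIndex().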
53 | :param annotation_file (str): location of annotation file 54 | :return: 55 | """ 56 | # load dataset 57 | self.dataset = {} 58 | self.anns = {} 59 | self.imgToAnns = {} 60 | self.catToImgs = {} 61 | self.imgs = {} 62 | self.cats = {} 63 | self.val = [] 64 | self.test = [] 65 | self.train = [] 66 | if not annotation_file == None: 67 | assert os.path.isfile(annotation_file), "file does not exist" 68 | print('loading annotations into memory...') 69 | time_t = datetime.datetime.utcnow() 70 | dataset = json.load(open(annotation_file, 'r')) 71 | print(datetime.datetime.utcnow() - time_t) 72 | self.dataset = dataset 73 | self.createIndex() 74 | 75 | def createIndex(self): 76 | # create index 77 | print('creating index...') 78 | self.imgToAnns = {int(cocoid): self.dataset['imgToAnns'][cocoid] for cocoid in self.dataset['imgToAnns']} 79 | self.imgs = {int(cocoid): self.dataset['imgs'][cocoid] for cocoid in self.dataset['imgs']} 80 | self.anns = {int(annid): self.dataset['anns'][annid] for annid in self.dataset['anns']} 81 | self.cats = self.dataset['cats'] 82 | self.val = [int(cocoid) for cocoid in self.dataset['imgs'] if self.dataset['imgs'][cocoid]['set'] == 'val'] 83 | self.test = [int(cocoid) for cocoid in self.dataset['imgs'] if self.dataset['imgs'][cocoid]['set'] == 'test'] 84 | self.train = [int(cocoid) for cocoid in self.dataset['imgs'] if self.dataset['imgs'][cocoid]['set'] == 'train'] 85 | print('index created!') 86 | 87 | def info(self): 88 | """ 89 | Print information about the annotation file. 90 | :return: 91 | """ 92 | for key, value in self.dataset['info'].items(): 93 | print('%s: %s'%(key, value)) 94 | 95 | def filtering(self, filterDict, criteria): 96 | return [key for key in filterDict if all(criterion(filterDict[key]) for criterion in criteria)] 97 | 98 | def getAnnByCat(self, properties): 99 | """ 100 | Get ann ids that satisfy given properties 101 | :param properties (list of tuples of the form [(category type, category)] e.g., [('readability','readable')] 102 | : get anns for given categories - anns have to satisfy all given property tuples 103 | :return: ids (int array) : integer array of ann ids 104 | """ 105 | return self.filtering(self.anns, [lambda d, x=a, y=b:d[x] == y for (a,b) in properties]) 106 | 107 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[]): 108 | """ 109 | Get ann ids that satisfy given filter conditions. default skips that filter 110 | :param imgIds (int array) : get anns for given imgs 111 | catIds (list of tuples of the form [(category type, category)] e.g., [('readability','readable')] 112 | : get anns for given cats 113 | areaRng (float array) : get anns for given area range (e.g. [0 inf]) 114 | :return: ids (int array) : integer array of ann ids 115 | """ 116 | imgIds = imgIds if type(imgIds) == list else [imgIds] 117 | catIds = catIds if type(catIds) == list else [catIds] 118 | 119 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 120 | anns = list(self.anns.keys()) 121 | else: 122 | if not len(imgIds) == 0: 123 | anns = sum([self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],[]) 124 | else: 125 | anns = list(self.anns.keys()) 126 | anns = anns if len(catIds) == 0 else list(set(anns).intersection(set(self.getAnnByCat(catIds)))) 127 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if self.anns[ann]['area'] > areaRng[0] and self.anns[ann]['area'] < areaRng[1]] 128 | return anns 129 | 130 | def getImgIds(self, imgIds=[], catIds=[]): 131 | ''' 132 | Get img ids that satisfy given filter conditions. 
133 | :param imgIds (int array) : get imgs for given ids 134 | :param catIds (int array) : get imgs with all given cats 135 | :return: ids (int array) : integer array of img ids 136 | ''' 137 | imgIds = imgIds if type(imgIds) == list else [imgIds] 138 | catIds = catIds if type(catIds) == list else [catIds] 139 | 140 | if len(imgIds) == len(catIds) == 0: 141 | ids = list(self.imgs.keys()) 142 | else: 143 | ids = set(imgIds) 144 | if not len(catIds) == 0: 145 | ids = ids.intersection(set([self.anns[annid]['image_id'] for annid in self.getAnnByCat(catIds)])) 146 | return list(ids) 147 | 148 | def loadAnns(self, ids=[]): 149 | """ 150 | Load anns with the specified ids. 151 | :param ids (int array) : integer ids specifying anns 152 | :return: anns (object array) : loaded ann objects 153 | """ 154 | if type(ids) == list: 155 | return [self.anns[id] for id in ids] 156 | elif type(ids) == int: 157 | return [self.anns[ids]] 158 | 159 | def loadImgs(self, ids=[]): 160 | """ 161 | Load anns with the specified ids. 162 | :param ids (int array) : integer ids specifying img 163 | :return: imgs (object array) : loaded img objects 164 | """ 165 | if type(ids) == list: 166 | return [self.imgs[id] for id in ids] 167 | elif type(ids) == int: 168 | return [self.imgs[ids]] 169 | 170 | def showAnns(self, anns, show_mask=False): 171 | """ 172 | Display the specified annotations. 173 | :param anns (array of object): annotations to display 174 | :return: None 175 | """ 176 | if len(anns) == 0: 177 | return 0 178 | ax = plt.gca() 179 | boxes = [] 180 | color = [] 181 | for ann in anns: 182 | c = np.random.random((1, 3)).tolist()[0] 183 | if show_mask: 184 | verts = list(zip(*[iter(ann['mask'])] * 2)) + [(0, 0)] 185 | codes = [Path.MOVETO] + [Path.LINETO] * (len(verts) - 2) + [Path.CLOSEPOLY] 186 | path = Path(verts, codes) 187 | patch = PathPatch(path, facecolor='none') 188 | boxes.append(patch) 189 | text_x, text_y = verts[0] 190 | else: 191 | left, top, width, height = ann['bbox'] 192 | boxes.append(Rectangle([left,top],width,height,alpha=0.4)) 193 | text_x, text_y = left, top 194 | color.append(c) 195 | if 'utf8_string' in list(ann.keys()): 196 | ax.annotate(ann['utf8_string'],(text_x, text_y-4),color=c) 197 | p = PatchCollection(boxes, facecolors=color, edgecolors=(0,0,0,1), linewidths=3, alpha=0.4) 198 | ax.add_collection(p) 199 | 200 | def loadRes(self, resFile): 201 | """ 202 | Load result file and return a result api object. 203 | :param resFile (str) : file name of result file 204 | :return: res (obj) : result api object 205 | """ 206 | res = COCO_Text() 207 | res.dataset['imgs'] = [img for img in self.dataset['imgs']] 208 | 209 | print('Loading and preparing results... 
') 210 | time_t = datetime.datetime.utcnow() 211 | if type(resFile) == str: 212 | anns = json.load(open(resFile)) 213 | else: 214 | anns = resFile 215 | assert type(anns) == list, 'results in not an array of objects' 216 | annsImgIds = [int(ann['image_id']) for ann in anns] 217 | 218 | if set(annsImgIds) != (set(annsImgIds) & set(self.getImgIds())): 219 | print('Results do not correspond to current coco set') 220 | print('skipping ', str(len(set(annsImgIds)) - len(set(annsImgIds) & set(self.getImgIds()))), ' images') 221 | annsImgIds = list(set(annsImgIds) & set(self.getImgIds())) 222 | 223 | res.imgToAnns = {cocoid : [] for cocoid in annsImgIds} 224 | res.imgs = {cocoid: self.imgs[cocoid] for cocoid in annsImgIds} 225 | 226 | assert anns[0]['bbox'] != [], 'results have incorrect format' 227 | for id, ann in enumerate(anns): 228 | if ann['image_id'] not in annsImgIds: 229 | continue 230 | bb = ann['bbox'] 231 | ann['area'] = bb[2]*bb[3] 232 | ann['id'] = id 233 | res.anns[id] = ann 234 | res.imgToAnns[ann['image_id']].append(id) 235 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())) 236 | 237 | return res -------------------------------------------------------------------------------- /convert/det/coco_text2json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/21 12:54 3 | # @Author : zhoujun 4 | """ 5 | 将coco_text数据集转换为统一格式 6 | """ 7 | import os 8 | import numpy as np 9 | from tqdm import tqdm 10 | from convert.utils import save 11 | from convert.det.coco_text import COCO_Text 12 | 13 | def cvt(gt_path, save_path, imgs_folder): 14 | gt_dict = {'data_root': imgs_folder} 15 | data_list = [] 16 | ct = COCO_Text(gt_path) 17 | 18 | train_img_ids = ct.getImgIds(imgIds=ct.val) 19 | for img_id in tqdm(train_img_ids): 20 | img = ct.loadImgs(img_id)[0] 21 | # img_path = os.path.join(imgs_folder, img['file_name']) 22 | # if not os.path.exists(img_path): 23 | # continue 24 | cur_gt = {'img_name': img['file_name'], 'annotations': []} 25 | annIds = ct.getAnnIds(imgIds=img['id']) 26 | anns = ct.loadAnns(annIds) 27 | for ann in anns: 28 | if len(ann['utf8_string']) == 0: 29 | continue 30 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 31 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 32 | cur_line_gt['chars'] = chars_gt 33 | 34 | cur_line_gt['language'] = ann['language'] 35 | chars_gt[0]['language'] = ann['language'] 36 | 37 | cur_line_gt['polygon'] = np.array(ann['mask']).reshape(-1,2).tolist() 38 | cur_line_gt['text'] = ann['utf8_string'] 39 | cur_line_gt['illegibility'] = True if ann['legibility'] == "illegible" else False 40 | cur_gt['annotations'].append(cur_line_gt) 41 | if len(cur_gt['annotations']) > 0: 42 | data_list.append(cur_gt) 43 | gt_dict['data_list'] = data_list 44 | save(gt_dict, save_path) 45 | print(len(gt_dict), len(data_list)) 46 | 47 | 48 | def show_coco(gt_path, imgs_folder): 49 | import numpy as np 50 | import skimage.io as io 51 | import matplotlib.pyplot as plt 52 | 53 | data = COCO_Text(gt_path) 54 | # get all images containing at least one instance of legible text 55 | imgIds = data.getImgIds(imgIds=data.train) 56 | # pick one at random 57 | img = data.loadImgs(imgIds[np.random.randint(0, len(imgIds))])[0] 58 | I = io.imread(os.path.join(imgs_folder, img['file_name'])) 59 | plt.figure() 60 | plt.imshow(I) 61 | annIds = data.getAnnIds(imgIds=img['id']) 62 | anns = 
data.loadAnns(annIds) 63 | data.showAnns(anns) 64 | plt.show() 65 | 66 | 67 | if __name__ == '__main__': 68 | gt_path = r'D:\dataset\COCO_Text\detection\cocotext.v2.json' 69 | imgs_folder = r'D:\dataset\COCO_Text\detection\val' 70 | save_path = r'D:\dataset\COCO_Text\detection\val.json' 71 | cvt(gt_path, save_path, imgs_folder) 72 | # show_coco(gt_path, imgs_folder) 73 | -------------------------------------------------------------------------------- /convert/det/convert2jpg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/21 10:37 3 | # @Author : zhoujun 4 | """ 5 | 用于将图片统一转换为jpg 6 | """ 7 | import os 8 | import pathlib 9 | from tqdm import tqdm 10 | from convert.utils import get_file_list 11 | 12 | if __name__ == '__main__': 13 | img_folder = r'D:\dataset\mlt2019\detection\imgs' 14 | for img_path in tqdm(get_file_list(img_folder, p_postfix=['.*'])): 15 | img_path = pathlib.Path(img_path) 16 | save_path = img_path.parent / (img_path.stem + '.jpg') 17 | if img_path != save_path: 18 | os.rename(img_path, save_path) 19 | -------------------------------------------------------------------------------- /convert/det/icdar20152json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/18 14:12 3 | # @Author : zhoujun 4 | """ 5 | 将icdar2015数据集转换为统一格式 6 | """ 7 | import pathlib 8 | from tqdm import tqdm 9 | from convert.utils import load, save, get_file_list 10 | 11 | 12 | def cvt(gt_path, save_path, img_folder): 13 | """ 14 | 将icdar2015格式的gt转换为json格式 15 | :param gt_path: 16 | :param save_path: 17 | :return: 18 | """ 19 | gt_dict = {'data_root': img_folder} 20 | data_list = [] 21 | for file_path in tqdm(get_file_list(gt_path, p_postfix=['.txt'])): 22 | content = load(file_path) 23 | file_path = pathlib.Path(file_path) 24 | img_name = file_path.name.replace('gt_', '').replace('.txt', '.jpg') 25 | cur_gt = {'img_name': img_name, 'annotations': []} 26 | for line in content: 27 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 28 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 29 | cur_line_gt['chars'] = chars_gt 30 | line = line.split(',') 31 | # 字符串级别的信息 32 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 33 | cur_line_gt['polygon'] = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] 34 | cur_line_gt['text'] = line[-1] 35 | cur_line_gt['illegibility'] = True if cur_line_gt['text'] == '*' or cur_line_gt['text'] == '###' else False 36 | cur_gt['annotations'].append(cur_line_gt) 37 | data_list.append(cur_gt) 38 | gt_dict['data_list'] = data_list 39 | save(gt_dict, save_path) 40 | 41 | 42 | if __name__ == '__main__': 43 | gt_path = r'D:\dataset\icdar2015\detection\test\gt' 44 | img_folder = r'D:\dataset\icdar2015\detection\test\imgs' 45 | save_path = r'D:\dataset\icdar2015\detection\test.json' 46 | cvt(gt_path, save_path, img_folder) 47 | -------------------------------------------------------------------------------- /convert/det/icdar2017rctw2json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/18 14:12 3 | # @Author : zhoujun 4 | """ 5 | 将icdar2015数据集转换为统一格式 6 | """ 7 | import pathlib 8 | from tqdm import tqdm 9 | from convert.utils import load, save, get_file_list 10 | 11 | 12 | def cvt(save_path, img_folder): 13 | """ 14 | 将icdar2015格式的gt转换为json格式 15 | :param gt_path: 16 | 
:param save_path: 17 | :return: 18 | """ 19 | gt_dict = {'data_root': img_folder} 20 | data_list = [] 21 | for img_path in tqdm(get_file_list(img_folder, p_postfix=['.jpg'])): 22 | img_path = pathlib.Path(img_path) 23 | gt_path = pathlib.Path(img_folder) / img_path.name.replace('.jpg', '.txt') 24 | content = load(gt_path) 25 | cur_gt = {'img_name': img_path.name, 'annotations': []} 26 | for line in content: 27 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 28 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 29 | cur_line_gt['chars'] = chars_gt 30 | line = line.split(',') 31 | # 字符串级别的信息 32 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 33 | cur_line_gt['polygon'] = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] 34 | cur_line_gt['text'] = line[-1][1:-1] 35 | cur_line_gt['illegibility'] = True if line[8] == '1' else False 36 | cur_gt['annotations'].append(cur_line_gt) 37 | data_list.append(cur_gt) 38 | gt_dict['data_list'] = data_list 39 | save(gt_dict, save_path) 40 | 41 | 42 | if __name__ == '__main__': 43 | img_folder = r'D:\dataset\icdar2017rctw\detection\imgs' 44 | save_path = r'D:\dataset\icdar2017rctw\detection\train.json' 45 | cvt(save_path, img_folder) 46 | -------------------------------------------------------------------------------- /convert/det/iflytek_text_detection.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/7/7 10:08 3 | # @Author : zhoujun 4 | import numpy as np 5 | from tqdm import tqdm 6 | from convert.utils import load, save 7 | 8 | 9 | def cvt(gt_path, save_path, imgs_folder): 10 | gt_dict = {'data_root': imgs_folder} 11 | data_list = [] 12 | ct = load(gt_path) 13 | 14 | for img_id, anns in tqdm(ct.items()): 15 | img_name = img_id.replace('gt', 'img') + '.jpg' 16 | cur_gt = {'img_name': img_name, 'annotations': []} 17 | for ann in anns: 18 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 19 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 20 | cur_line_gt['chars'] = chars_gt 21 | 22 | cur_line_gt['polygon'] = ann['points'] 23 | cur_line_gt['illegibility'] = ann['illegibility'] 24 | cur_gt['annotations'].append(cur_line_gt) 25 | if len(cur_gt['annotations']) > 0: 26 | data_list.append(cur_gt) 27 | gt_dict['data_list'] = data_list 28 | save(gt_dict, save_path) 29 | print(len(gt_dict), len(data_list)) 30 | 31 | 32 | if __name__ == '__main__': 33 | gt_path = r'D:\dataset\自然场景文字检测挑战赛初赛数据\验证集\validation.json' 34 | imgs_folder = r'D:\dataset\自然场景文字检测挑战赛初赛数据\验证集\new_image' 35 | save_path = r'D:\dataset\自然场景文字检测挑战赛初赛数据\验证集\validation_new.json' 36 | cvt(gt_path, save_path, imgs_folder) 37 | # show_coco(gt_path, imgs_folder) 38 | -------------------------------------------------------------------------------- /convert/det/mlt20192json.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/18 14:12 3 | # @Author : zhoujun 4 | """ 5 | 将mlt2019数据集转换为统一格式 6 | """ 7 | import glob 8 | import pathlib 9 | from tqdm import tqdm 10 | from convert.utils import load, save, get_file_list 11 | 12 | 13 | def cvt(gt_path, save_path, img_folder): 14 | """ 15 | 将icdar2015格式的gt转换为json格式 16 | :param gt_path: 17 | :param save_path: 18 | :return: 19 | """ 20 | gt_dict = {'data_root': img_folder} 21 | data_list = [] 22 | for file_path in tqdm(get_file_list(gt_path, 
p_postfix=['.txt'])): 23 | content = load(file_path) 24 | file_path = pathlib.Path(file_path) 25 | img_name = file_path.name.replace('.txt', '.jpg') 26 | cur_gt = {'img_name': img_name, 'annotations': []} 27 | for line in content: 28 | cur_line_gt = {'polygon': [], 'text': '', 'illegibility': False, 'language': 'Latin'} 29 | chars_gt = [{'polygon': [], 'char': '', 'illegibility': False, 'language': 'Latin'}] 30 | cur_line_gt['chars'] = chars_gt 31 | line = line.split(',') 32 | lang = line[8] 33 | cur_line_gt['language'] = lang 34 | chars_gt[0]['language'] = lang 35 | # 字符串级别的信息 36 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 37 | cur_line_gt['polygon'] = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] 38 | cur_line_gt['text'] = line[-1] 39 | cur_line_gt['illegibility'] = True if cur_line_gt['text'] == '*' or cur_line_gt['text'] == '###' else False 40 | cur_gt['annotations'].append(cur_line_gt) 41 | data_list.append(cur_gt) 42 | gt_dict['data_list'] = data_list 43 | save(gt_dict, save_path) 44 | 45 | 46 | if __name__ == '__main__': 47 | gt_path = r'D:\dataset\mlt2019\detection\gt' 48 | img_folder = r'D:\dataset\mlt2019\detection\imgs' 49 | save_path = r'D:\dataset\mlt2019\detection\gt.json' 50 | cvt(gt_path, save_path, img_folder) 51 | -------------------------------------------------------------------------------- /convert/move_imgs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/21 16:17 3 | # @Author : zhoujun 4 | """ 5 | 根据json,将图片移动到指定文件夹,以便删除不需要的图片 6 | """ 7 | import os 8 | import shutil 9 | from tqdm import tqdm 10 | from convert.utils import load_gt 11 | 12 | if __name__ == '__main__': 13 | json_path = r'D:\dataset\COCO_Text\detection\val.json' 14 | save_path = r'D:\dataset\COCO_Text\detection\val' 15 | os.makedirs(save_path,exist_ok=True) 16 | data = load_gt(json_path) 17 | for img_path, gt in tqdm(data.items()): 18 | dst_path = os.path.join(save_path,os.path.basename(img_path)) 19 | shutil.move(img_path,dst_path) 20 | -------------------------------------------------------------------------------- /convert/rec/360w2txt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/24 11:26 3 | # @Author : zhoujun 4 | 5 | import os 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | 10 | # 支持中文 11 | plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 12 | plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 13 | from convert.utils import load, save 14 | 15 | def cvt(gt_path, save_path, img_folder): 16 | content = load(gt_path) 17 | file_list = [] 18 | for i,line in tqdm(enumerate(content)): 19 | try: 20 | line = line.split('.jpg ') 21 | img_path = os.path.join(img_folder, line[-2]) 22 | file_list.append(img_path + '.jpg' + '\t' + line[-1] + '\t' + 'Chinese') 23 | # img = Image.open(img_path) 24 | # plt.title(line[-1]) 25 | # plt.imshow(img) 26 | # plt.show() 27 | except: 28 | a = 1 29 | save(file_list, save_path) 30 | 31 | 32 | if __name__ == '__main__': 33 | img_folder = r'D:\dataset\360w\train_images' 34 | gt_path = r'D:\BaiduNetdiskDownload\360_train.txt' 35 | save_path = r'D:\BaiduNetdiskDownload\train.txt' 36 | cvt(gt_path, save_path, img_folder) -------------------------------------------------------------------------------- /convert/rec/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 
@Time : 2020/3/24 11:10 3 | # @Author : zhoujun -------------------------------------------------------------------------------- /convert/rec/baidu2txt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/24 11:10 3 | # @Author : zhoujun 4 | import os 5 | from PIL import Image 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | 9 | # 支持中文 10 | plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 11 | plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 12 | from convert.utils import load, save 13 | 14 | 15 | def cvt(gt_path, save_path, img_folder): 16 | content = load(gt_path) 17 | file_list = [] 18 | for line in tqdm(content): 19 | line = line.split('\t') 20 | img_path = os.path.join(img_folder, line[-2]) 21 | if not os.path.exists(img_path): 22 | print(img_path) 23 | file_list.append(img_path + '\t' + line[-1] + '\t' + 'Chinese') 24 | # img = Image.open(img_path) 25 | # plt.title(line[-1]) 26 | # plt.imshow(img) 27 | # plt.show() 28 | save(file_list, save_path) 29 | 30 | 31 | if __name__ == '__main__': 32 | img_folder = r'D:\dataset\百度中文场景文字识别\train_images' 33 | gt_path = r'D:\dataset\百度中文场景文字识别\train.list' 34 | save_path = r'D:\dataset\百度中文场景文字识别\train.txt' 35 | cvt(gt_path, save_path, img_folder) 36 | -------------------------------------------------------------------------------- /convert/rec/mjsyhtn2txt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/24 11:10 3 | # @Author : zhoujun 4 | import os 5 | import pathlib 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | from convert.utils import load, save 10 | 11 | 12 | def cvt(gt_path, save_path, img_folder): 13 | content = load(gt_path) 14 | file_list = [] 15 | for line in tqdm(content): 16 | img_relative_path = line.split(' ')[0] 17 | img_path = os.path.join(img_folder, img_relative_path) 18 | img_path = pathlib.Path(img_path) 19 | label = img_path.stem.split('_')[1] 20 | if not img_path.exists(): 21 | print(img_path) 22 | file_list.append(str(img_path) + '\t' + label + '\t' + 'English') 23 | # img = Image.open(img_path) 24 | # plt.title(label) 25 | # plt.imshow(img) 26 | # plt.show() 27 | save(file_list, save_path) 28 | 29 | 30 | if __name__ == '__main__': 31 | img_folder = r'D:\dataset\mjsynth\imgs' 32 | gt_path = r'D:\dataset\mjsynth\annotation_test.txt' 33 | save_path = r'D:\dataset\mjsynth\test.txt' 34 | cvt(gt_path, save_path, img_folder) 35 | -------------------------------------------------------------------------------- /convert/simsun.ttc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/OCR_DataSet/862956dd894e025e1a636f17660770bb25880129/convert/simsun.ttc -------------------------------------------------------------------------------- /convert/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/20 19:54 3 | # @Author : zhoujun 4 | import json 5 | import os 6 | import glob 7 | import pathlib 8 | from natsort import natsorted 9 | 10 | __all__ = ['load', 'save', 'get_file_list', 'show_bbox_on_image', 'load_gt'] 11 | 12 | 13 | def get_file_list(folder_path: str, p_postfix: list = None) -> list: 14 | """ 15 | 获取所给文件目录里的指定后缀的文件,读取文件列表目前使用的是 os.walk 和 os.listdir ,这两个目前比 pathlib 快很多 16 | :param filder_path: 文件夹名称 17 | :param p_postfix: 文件后缀,如果为 
[.*]将返回全部文件 18 | :return: 获取到的指定类型的文件列表 19 | """ 20 | assert os.path.exists(folder_path) and os.path.isdir(folder_path) 21 | if p_postfix is None: 22 | p_postfix = ['.jpg'] 23 | if isinstance(p_postfix, str): 24 | p_postfix = [p_postfix] 25 | file_list = [x for x in glob.glob(folder_path + '/**/*.*', recursive=True) if 26 | os.path.splitext(x)[-1] in p_postfix or '.*' in p_postfix] 27 | return natsorted(file_list) 28 | 29 | 30 | def load(file_path: str): 31 | file_path = pathlib.Path(file_path) 32 | func_dict = {'.txt': load_txt, '.json': load_json, '.list': load_txt} 33 | assert file_path.suffix in func_dict 34 | return func_dict[file_path.suffix](file_path) 35 | 36 | 37 | def load_txt(file_path: str): 38 | with open(file_path, 'r', encoding='utf8') as f: 39 | content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()] 40 | return content 41 | 42 | 43 | def load_json(file_path: str): 44 | with open(file_path, 'r', encoding='utf8') as f: 45 | content = json.load(f) 46 | return content 47 | 48 | 49 | def save(data, file_path): 50 | file_path = pathlib.Path(file_path) 51 | func_dict = {'.txt': save_txt, '.json': save_json} 52 | assert file_path.suffix in func_dict 53 | return func_dict[file_path.suffix](data, file_path) 54 | 55 | 56 | def save_txt(data, file_path): 57 | """ 58 | 将一个list的数组写入txt文件里 59 | :param data: 60 | :param file_path: 61 | :return: 62 | """ 63 | if not isinstance(data, list): 64 | data = [data] 65 | with open(file_path, mode='w', encoding='utf8') as f: 66 | f.write('\n'.join(data)) 67 | 68 | 69 | def save_json(data, file_path): 70 | with open(file_path, 'w', encoding='utf-8') as json_file: 71 | json.dump(data, json_file, ensure_ascii=False, indent=4) 72 | 73 | 74 | def show_bbox_on_image(image, polygons=None, txt=None, color=None, font_path='convert/simsun.ttc'): 75 | """ 76 | 在图片上绘制 文本框和文本 77 | :param image: 78 | :param polygons: 文本框 79 | :param txt: 文本 80 | :param color: 绘制的颜色 81 | :param font_path: 字体 82 | :return: 83 | """ 84 | from PIL import ImageDraw, ImageFont 85 | image = image.convert('RGB') 86 | draw = ImageDraw.Draw(image) 87 | if len(txt) == 0: 88 | txt = None 89 | if color is None: 90 | color = (255, 0, 0) 91 | if txt is not None: 92 | font = ImageFont.truetype(font_path, 20) 93 | for i, box in enumerate(polygons): 94 | if txt is not None: 95 | draw.text((int(box[0][0]) + 20, int(box[0][1]) - 20), str(txt[i]), fill='red', font=font) 96 | for j in range(len(box) - 1): 97 | draw.line((box[j][0], box[j][1], box[j + 1][0], box[j + 1][1]), fill=color, width=5) 98 | draw.line((box[-1][0], box[-1][1], box[0][0], box[0][1]), fill=color, width=5) 99 | return image 100 | 101 | 102 | def load_gt(json_path): 103 | """ 104 | 从json文件中读取出 文本行的坐标和gt,字符的坐标和gt 105 | :param json_path: 106 | :return: 107 | """ 108 | content = load(json_path) 109 | d = {} 110 | for gt in content['data_list']: 111 | img_path = os.path.join(content['data_root'], gt['img_name']) 112 | polygons = [] 113 | texts = [] 114 | illegibility_list = [] 115 | language_list = [] 116 | for annotation in gt['annotations']: 117 | if len(annotation['polygon']) == 0: 118 | continue 119 | polygons.append(annotation['polygon']) 120 | texts.append(annotation['text']) 121 | illegibility_list.append(annotation['illegibility']) 122 | language_list.append(annotation['language']) 123 | for char_annotation in annotation['chars']: 124 | if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0: 125 | continue 126 | polygons.append(char_annotation['polygon']) 127 | 
texts.append(char_annotation['char']) 128 | illegibility_list.append(char_annotation['illegibility']) 129 | language_list.append(char_annotation['language']) 130 | d[img_path] = {'polygons': polygons, 'texts': texts, 'illegibility_list': illegibility_list, 131 | 'language_list': language_list} 132 | return d 133 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/24 11:36 3 | # @Author : zhoujun -------------------------------------------------------------------------------- /dataset/convert_det2lmdb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/4/2 14:19 3 | # @Author : zhoujun 4 | 5 | import os 6 | import lmdb 7 | import cv2 8 | import numpy as np 9 | import argparse 10 | import shutil 11 | import sys 12 | from convert.utils import load_gt 13 | 14 | def checkImageIsValid(imageBin): 15 | if imageBin is None: 16 | return False 17 | 18 | try: 19 | imageBuf = np.fromstring(imageBin, dtype=np.uint8) 20 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 21 | imgH, imgW = img.shape[0], img.shape[1] 22 | except: 23 | return False 24 | else: 25 | if imgH * imgW == 0: 26 | return False 27 | 28 | return True 29 | 30 | 31 | def writeCache(env, cache): 32 | with env.begin(write=True) as txn: 33 | for k, v in cache.items(): 34 | if type(k) == str: 35 | k = k.encode() 36 | if type(v) == str: 37 | v = v.encode() 38 | txn.put(k, v) 39 | 40 | 41 | def createDataset(outputPath, data_dict, map_size=79951162, checkValid=True): 42 | """ 43 | Create LMDB dataset for CRNN training. 44 | 45 | ARGS: 46 | outputPath : LMDB output path 47 | data_dict : a dict contains img_path,texts,text_polys 48 | checkValid : if true, check the validity of every image 49 | """ 50 | # If lmdb file already exists, remove it. Or the new data will add to it. 
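# Added note: map_size is the maximum size of the LMDB file in bytes (the
# default above is roughly 76 MB); increase it for large datasets, otherwise
# lmdb raises lmdb.MapFullError once the database outgrows it.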
51 | if os.path.exists(outputPath): 52 | shutil.rmtree(outputPath) 53 | os.makedirs(outputPath) 54 | else: 55 | os.makedirs(outputPath) 56 | 57 | nSamples = len(data_dict) 58 | env = lmdb.open(outputPath, map_size=map_size) 59 | cache = {} 60 | cnt = 1 61 | for img_path in data_dict: 62 | data = data_dict[img_path] 63 | if not os.path.exists(img_path): 64 | print('%s does not exist' % img_path) 65 | continue 66 | with open(img_path, 'rb') as f: 67 | imageBin = f.read() 68 | if checkValid: 69 | if not checkImageIsValid(imageBin): 70 | print('%s is not a valid image' % img_path) 71 | continue 72 | 73 | imageKey = 'image-%09d' % cnt 74 | polygonsKey = 'polygons-%09d' % cnt 75 | textsKey = 'texts-%09d' % cnt 76 | illegibilityKey = 'illegibility-%09d' % cnt 77 | languageKey = 'language-%09d' % cnt 78 | cache[imageKey] = imageBin 79 | cache[polygonsKey] = np.array(data['polygons']).tostring() 80 | cache[textsKey] = '\t'.join(data['texts']) 81 | cache[illegibilityKey] = '\t'.join([str(x) for x in data['illegibility_list']]) 82 | cache[languageKey] = '\t'.join(data['language_list']) 83 | if cnt % 1000 == 0: 84 | writeCache(env, cache) 85 | cache = {} 86 | print('Written %d / %d' % (cnt, nSamples)) 87 | cnt += 1 88 | nSamples = cnt - 1 89 | cache['num-samples'] = str(nSamples) 90 | writeCache(env, cache) 91 | env.close() 92 | print('Created dataset with %d samples' % nSamples) 93 | 94 | 95 | def show_demo(demo_number, image_path_list, label_list): 96 | print('\nShow some demo to prevent creating wrong lmdb data') 97 | print('The first line is the path to image and the second line is the image label') 98 | for i in range(demo_number): 99 | print('image: %s\nlabel: %s\n' % (image_path_list[i], label_list[i])) 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser() 104 | # parser.add_argument('--out', type = str, required = True, help = 'lmdb data output path') 105 | parser.add_argument('--json_path', type=str, default='E:\\zj\\dataset\\icdar2015 (2)\\detection\\test.json',help='path to gt json') 106 | parser.add_argument('--save_floder', type=str,default=r'E:\zj\dataset\icdar2015 (2)', help='path to save lmdb') 107 | args = parser.parse_args() 108 | 109 | data_dict = load_gt(args.json_path) 110 | out_lmdb = os.path.join(args.save_floder,'train') 111 | createDataset(out_lmdb, data_dict, map_size=79951162) 112 | -------------------------------------------------------------------------------- /dataset/det.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/3/24 11:36 3 | # @Author : zhoujun 4 | import os 5 | import sys 6 | 7 | project = 'OCR_DataSet' # 工作项目根目录 8 | sys.path.append(os.getcwd().split(project)[0] + project) 9 | import numpy as np 10 | from PIL import Image 11 | from torch.utils.data import Dataset, DataLoader 12 | 13 | from convert.utils import load, show_bbox_on_image 14 | 15 | class DetDataSet(Dataset): 16 | def __init__(self, json_path, transform=None, target_transform=None): 17 | self.data_list = self.load_data(json_path) 18 | self.transform = transform 19 | self.target_transform = target_transform 20 | 21 | def load_data(self, json_path): 22 | """ 23 | 从json文件中读取出 文本行的坐标和gt,字符的坐标和gt 24 | :param json_path: 25 | :return: 26 | """ 27 | content = load(json_path) 28 | d = [] 29 | for gt in content['data_list']: 30 | img_path = os.path.join(content['data_root'], gt['img_name']) 31 | polygons = [] 32 | texts = [] 33 | illegibility_list = [] 34 | language_list = [] 35 | for annotation 
--------------------------------------------------------------------------------
/dataset/det.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/3/24 11:36
3 | # @Author : zhoujun
4 | import os
5 | import sys
6 | 
7 | project = 'OCR_DataSet'  # project root directory
8 | sys.path.append(os.getcwd().split(project)[0] + project)
9 | import numpy as np
10 | from PIL import Image
11 | from torch.utils.data import Dataset, DataLoader
12 | 
13 | from convert.utils import load, show_bbox_on_image
14 | 
15 | 
16 | class DetDataSet(Dataset):
17 |     def __init__(self, json_path, transform=None, target_transform=None):
18 |         self.data_list = self.load_data(json_path)
19 |         self.transform = transform
20 |         self.target_transform = target_transform
21 | 
22 |     def load_data(self, json_path):
23 |         """
24 |         Read the text-line polygons and labels, plus the character-level polygons and labels, from the json file
25 |         :param json_path: path to the unified gt json
26 |         :return: list of dicts, one per image
27 |         """
28 |         content = load(json_path)
29 |         d = []
30 |         for gt in content['data_list']:
31 |             img_path = os.path.join(content['data_root'], gt['img_name'])
32 |             polygons = []
33 |             texts = []
34 |             illegibility_list = []
35 |             language_list = []
36 |             for annotation in gt['annotations']:
37 |                 if len(annotation['polygon']) == 0:
38 |                     continue
39 |                 polygons.append(annotation['polygon'])
40 |                 texts.append(annotation['text'])
41 |                 illegibility_list.append(annotation['illegibility'])
42 |                 language_list.append(annotation['language'])
43 |                 for char_annotation in annotation['chars']:
44 |                     if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
45 |                         continue
46 |                     polygons.append(char_annotation['polygon'])
47 |                     texts.append(char_annotation['char'])
48 |                     illegibility_list.append(char_annotation['illegibility'])
49 |                     language_list.append(char_annotation['language'])
50 |             d.append({'img_path': img_path, 'polygons': np.array(polygons), 'texts': texts,
51 |                       'illegibility': illegibility_list,
52 |                       'language': language_list})
53 |         return d
54 | 
55 |     def __getitem__(self, item):
56 |         # copy the dict so the cached entry in data_list is not mutated with image data
57 |         item_dict = dict(self.data_list[item])
58 |         item_dict['img'] = Image.open(item_dict['img_path']).convert('RGB')
59 |         item_dict['img'] = self.pre_processing(item_dict)
60 |         # build the label
61 |         item_dict['texts'] = self.make_label(item_dict)
62 |         if self.transform:
63 |             item_dict['img'] = self.transform(item_dict['img'])
64 |         if self.target_transform:
65 |             item_dict['texts'] = self.target_transform(item_dict['texts'])
66 |         return item_dict
67 | 
68 |     def __len__(self):
69 |         return len(self.data_list)
70 | 
71 |     def make_label(self, item_dict):
72 |         return item_dict['texts']
73 | 
74 |     def pre_processing(self, item_dict):
75 |         return item_dict['img']
76 | 
77 | 
78 | if __name__ == '__main__':
79 |     import time
80 |     from tqdm import tqdm
81 |     from torchvision import transforms
82 |     from matplotlib import pyplot as plt
83 | 
84 |     # matplotlib settings for rendering Chinese text
85 |     plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
86 |     plt.rcParams['axes.unicode_minus'] = False  # display minus signs correctly
87 | 
88 |     json_path = r'D:\dataset\自然场景文字检测挑战赛初赛数据\训练集\train.json'
89 | 
90 |     dataset = DetDataSet(json_path, transform=transforms.ToTensor())
91 |     train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
92 |     pbar = tqdm(total=len(train_loader))
93 |     tic = time.time()
94 |     for i, data in enumerate(train_loader):
95 |         img = data['img'][0].numpy().transpose(1, 2, 0) * 255
96 |         texts = [x[0] for x in data['texts']]
97 |         img = show_bbox_on_image(Image.fromarray(img.astype(np.uint8)), data['polygons'][0], texts)
98 |         plt.imshow(img)
99 |         plt.show()
100 |         pbar.update(1)
101 |     pbar.close()
102 |     print(len(train_loader) / (time.time() - tic))  # samples per second
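
Because each image carries a different number of polygons, the default DataLoader collate only works with batch_size=1, as in the demo above. Below is a sketch of a collate_fn for larger batches; det_collate is not part of the repository, and it assumes a tensor transform such as ToTensor plus a pre_processing override that resizes all images to a common shape:

import torch

def det_collate(batch):
    # stack images (requires a common size); keep variable-length fields as plain lists
    return {
        'img': torch.stack([b['img'] for b in batch]),
        'polygons': [b['polygons'] for b in batch],  # one (N_i, 4, 2) array per image for quad datasets
        'texts': [b['texts'] for b in batch],
        'illegibility': [b['illegibility'] for b in batch],
        'language': [b['language'] for b in batch],
    }

# usage: DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=det_collate)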
--------------------------------------------------------------------------------
/dataset/det_lmdb.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/4/2 18:41
3 | # @Author : zhoujun
4 | import lmdb
5 | import six
6 | import sys
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
10 | 
11 | 
12 | class lmdbDataset(Dataset):
13 |     def __init__(self, lmdb_path=None, transform=None, target_transform=None):
14 |         self.env = lmdb.open(lmdb_path, max_readers=12, readonly=True, lock=False, readahead=False, meminit=False)
15 | 
16 |         if not self.env:
17 |             print('cannot create lmdb from %s' % lmdb_path)
18 |             sys.exit(0)
19 | 
20 |         with self.env.begin(write=False) as txn:
21 |             self.nSamples = int(txn.get('num-samples'.encode('utf-8')))
22 | 
23 |         self.transform = transform
24 |         self.target_transform = target_transform
25 | 
26 |     def __len__(self):
27 |         return self.nSamples
28 | 
29 |     def __getitem__(self, index):
30 |         assert index < len(self), 'index range error'
31 |         index += 1  # lmdb keys are 1-based
32 |         item = {}
33 |         with self.env.begin(write=False) as txn:
34 |             img_key = 'image-%09d' % index
35 |             imgbuf = txn.get(img_key.encode('utf-8'))
36 | 
37 |             buf = six.BytesIO()
38 |             buf.write(imgbuf)
39 |             buf.seek(0)
40 |             try:
41 |                 img = Image.open(buf).convert('RGB')
42 |             except IOError:
43 |                 print('Corrupted image for %d' % index)
44 |                 # fall back to the next sample, wrapping around at the end
45 |                 return self[index % len(self)]
46 | 
47 |             if self.transform is not None:
48 |                 img = self.transform(img)
49 |             item['img'] = img
50 |             polygonsKey = 'polygons-%09d' % index
51 |             textsKey = 'texts-%09d' % index
52 |             illegibilityKey = 'illegibility-%09d' % index
53 |             languageKey = 'language-%09d' % index
54 |             polygons = txn.get(polygonsKey.encode('utf-8'))
55 |             # float64 matches the dtype written by convert_det2lmdb.py
56 |             item['polygons'] = np.frombuffer(polygons, dtype=np.float64).reshape(-1, 4, 2)
57 | 
58 |             item['texts'] = txn.get(textsKey.encode('utf-8')).decode().split('\t')
59 |             illegibility = txn.get(illegibilityKey.encode('utf-8')).decode().split('\t')
60 |             item['illegibility'] = [x.lower() == 'true' for x in illegibility]
61 |             item['language'] = txn.get(languageKey.encode('utf-8')).decode().split('\t')
62 | 
63 |         if self.target_transform is not None:
64 |             item['texts'] = self.target_transform(item['texts'])
65 | 
66 |         return item
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     import time
71 |     from tqdm import tqdm
72 |     from torchvision import transforms
73 |     from matplotlib import pyplot as plt
74 | 
75 |     # matplotlib settings for rendering Chinese text
76 |     plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
77 |     plt.rcParams['axes.unicode_minus'] = False  # display minus signs correctly
78 | 
79 |     lmdb_path = r'E:\zj\dataset\icdar2015 (2)\train'
80 | 
81 |     dataset = lmdbDataset(lmdb_path, transform=transforms.ToTensor())
82 |     train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
83 |     pbar = tqdm(total=len(train_loader))
84 |     tic = time.time()
85 |     for i, data in enumerate(train_loader):
86 |         pass
87 |         # visualization is disabled for the throughput test; uncomment to inspect samples
88 |         # img = data['img'][0].numpy().transpose(1, 2, 0) * 255
89 |         # label = [x[0] for x in data['texts']]
90 |         # img = show_bbox_on_image(Image.fromarray(img.astype(np.uint8)), data['polygons'][0], label)
91 |         # plt.imshow(img)
92 |         # plt.show()
93 |         # pbar.update(1)
94 |     pbar.close()
95 |     print(len(train_loader) / (time.time() - tic))  # samples per second
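
Since every converted dataset becomes its own LMDB folder, several of them can be merged into one training set with ConcatDataset, which det_lmdb.py already imports. A sketch under the assumption that the repo root is on sys.path and both LMDB folders exist (the second path is a placeholder):

from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from dataset.det_lmdb import lmdbDataset

train_set = ConcatDataset([
    lmdbDataset(r'E:\zj\dataset\icdar2015 (2)\train', transform=transforms.ToTensor()),
    lmdbDataset(r'E:\zj\dataset\mlt2019\train', transform=transforms.ToTensor()),  # placeholder path
])
train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True, num_workers=0)
print('merged samples:', len(train_set))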
--------------------------------------------------------------------------------
/dataset/rec.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/3/24 11:36
3 | # @Author : zhoujun
4 | import os
5 | import numpy as np
6 | from PIL import Image
7 | from torch.utils.data import Dataset, DataLoader
8 | 
9 | from convert.utils import load
10 | 
11 | 
12 | class RecDataSet(Dataset):
13 |     def __init__(self, txt_path, transform=None, target_transform=None):
14 |         self.data_list = load(txt_path)
15 |         self.transform = transform
16 |         self.target_transform = target_transform
17 | 
18 |     def __getitem__(self, item):
19 |         try:
20 |             line = self.data_list[item].split('\t')
21 |             img = Image.open(line[0]).convert('RGB')
22 |             img = self.pre_processing(img)
23 |             # build the label
24 |             label = self.make_label(line[1])
25 |             if self.transform:
26 |                 img = self.transform(img)
27 |             if self.target_transform:
28 |                 label = self.target_transform(label)
29 |             return img, label
30 |         except Exception:
31 |             # fall back to a random sample if the image or label is broken
32 |             return self.__getitem__(np.random.randint(self.__len__()))
33 | 
34 |     def __len__(self):
35 |         return len(self.data_list)
36 | 
37 |     def make_label(self, label):
38 |         return label
39 | 
40 |     def pre_processing(self, img):
41 |         return img
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     from tqdm import tqdm
46 |     from torchvision import transforms
47 |     from matplotlib import pyplot as plt
48 | 
49 |     # matplotlib settings for rendering Chinese text
50 |     plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
51 |     plt.rcParams['axes.unicode_minus'] = False  # display minus signs correctly
52 | 
53 |     txt_path = r'D:\dataset\icdar2017rctw\recognition\train.txt'
54 | 
55 |     dataset = RecDataSet(txt_path, transform=transforms.ToTensor())
56 |     train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
57 |     pbar = tqdm(total=len(train_loader))
58 |     for i, (img, label) in enumerate(train_loader):
59 |         img = img[0].numpy().transpose(1, 2, 0)
60 |         plt.title(label[0])
61 |         plt.imshow(img)
62 |         plt.show()
63 |         pbar.update(1)
64 |     pbar.close()
--------------------------------------------------------------------------------
/gt_detection.json:
--------------------------------------------------------------------------------
1 | {
2 |     "data_root": "",
3 |     "data_list": [
4 |         {
5 |             "img_name": "relative/path/xxx.jpg",
6 |             "annotations": [
7 |                 {
8 |                     "polygon": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
9 |                     "text": "label",
10 |                     "illegibility": false,
11 |                     "language": "Latin",
12 |                     "chars": [
13 |                         {
14 |                             "polygon": [[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
15 |                             "char": "c",
16 |                             "illegibility": false,
17 |                             "language": "Latin"
18 |                         }
19 |                     ]
20 |                 }
21 |             ]
22 |         }
23 |     ]
24 | }
--------------------------------------------------------------------------------
/ocr公开数据集信息.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/OCR_DataSet/862956dd894e025e1a636f17660770bb25880129/ocr公开数据集信息.xlsx
--------------------------------------------------------------------------------
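
To close the loop, the unified gt_detection.json schema above can be checked after each conversion script runs. A minimal sketch; validate_gt is illustrative and not part of the repository:

import json

def validate_gt(json_path):
    with open(json_path, 'r', encoding='utf-8') as f:
        gt = json.load(f)
    assert 'data_root' in gt and 'data_list' in gt, 'missing top-level fields'
    for entry in gt['data_list']:
        assert 'img_name' in entry, 'every entry needs an img_name'
        for ann in entry['annotations']:
            for key in ('polygon', 'text', 'illegibility', 'language', 'chars'):
                assert key in ann, 'missing annotation field: %s' % key
            for point in ann['polygon']:
                assert len(point) == 2, 'each polygon point is [x, y]'
    print('%s looks valid (%d images)' % (json_path, len(gt['data_list'])))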