├── .idea
    ├── .gitignore
    ├── inspectionProfiles
    │   ├── Project_Default.xml
    │   └── profiles_settings.xml
    ├── layout_analysis.iml
    ├── misc.xml
    └── modules.xml
├── 8mpt
    ├── F1_curve.png
    ├── PR_curve.png
    ├── P_curve.png
    ├── R_curve.png
    ├── best.pt.txt
    └── results.png
├── 8npt
    ├── F1_curve.png
    ├── PR_curve.png
    ├── P_curve.png
    ├── R_curve.png
    ├── best.pt.txt
    └── results.png
├── README.md
├── coco_2_yolo.py
├── img-layout.yaml
├── infer_app.py
├── requirements.txt
├── result
    ├── test1_result.png
    ├── test2_result.png
    └── test3_result.png
└── train_app.py


/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/

--------------------------------------------------------------------------------
/8mpt/F1_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/F1_curve.png

--------------------------------------------------------------------------------
/8mpt/PR_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/PR_curve.png

--------------------------------------------------------------------------------
/8mpt/P_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/P_curve.png

--------------------------------------------------------------------------------
/8mpt/R_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/R_curve.png

--------------------------------------------------------------------------------
/8mpt/best.pt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/best.pt.txt

--------------------------------------------------------------------------------
/8mpt/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/results.png

--------------------------------------------------------------------------------
/8npt/F1_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/F1_curve.png

--------------------------------------------------------------------------------
/8npt/PR_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/PR_curve.png

--------------------------------------------------------------------------------
/8npt/P_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/P_curve.png

--------------------------------------------------------------------------------
/8npt/R_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/R_curve.png

--------------------------------------------------------------------------------
/8npt/best.pt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/best.pt.txt

--------------------------------------------------------------------------------
/8npt/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/results.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Layout detection for Chinese document images with YOLOv8
YOLOv8 is used to detect the layout of Chinese document images.

#### Model download, training, and inference
This project trains two YOLOv8 models, 8mpt and 8npt, on the open-source Chinese layout dataset [CDLA](https://github.com/buptlihang/CDLA).

CDLA is a Chinese document layout analysis dataset aimed at Chinese literature (academic paper) scenarios. It contains the following 10 labels:

|正文|标题|图片|图片标题|表格|表格标题|页眉|页脚|注释|公式|
|---|---|---|---|---|---|---|---|---|---|
|Text|Title|Figure|Figure caption|Table|Table caption|Header|Footer|Reference|Equation|

Download the 8mpt and 8npt models:

Link: https://pan.baidu.com/s/1YakM5AYrakoG9hYN-w7mJw

Extraction code: j2za

Training:
```
from ultralytics import YOLO

def train_model():
    # Load the model
    print('model load...')
    model = YOLO("8npt/best.pt")  # load the model
    print('model load completed...')
    # Use the model
    model.train(data="img-layout.yaml", epochs=300, device=1)  # train the model (lr0=0.0001 was left commented out)
    metrics = model.val()  # evaluate model performance on the validation set
```
8npt


8mpt

Inference:
```
import cv2
from ultralytics import YOLO

def infer():
    model = YOLO('8npt/best.pt')
    results = model('img.jpg')
    # plot() returns the annotated image as a NumPy array
    annotated = results[0].plot()
    print(annotated)
    cv2.imwrite('result.png', annotated)
```

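The class indices predicted by the model correspond to the order of `names` in `img-layout.yaml`. The sketch below shows one way to turn class ids into label names. It assumes the model was trained with that file and that the class ids and confidences have already been pulled out of the results as plain Python numbers. The helper name `label_detections` and the `min_conf` threshold are hypothetical, not part of the repository.

```python
# Label order copied from img-layout.yaml; index i is the class id i.
NAMES = ['Header', 'Text', 'Reference', 'Figure caption', 'Figure',
         'Table caption', 'Table', 'Title', 'Footer', 'Equation']


def label_detections(class_ids, confidences, min_conf=0.25):
    """Pair each class id with its label name, skipping low-confidence detections.

    Hypothetical helper: class_ids and confidences are plain sequences of numbers.
    """
    labelled = []
    for cls_id, conf in zip(class_ids, confidences):
        if conf < min_conf:
            continue  # drop detections below the illustrative threshold
        labelled.append((NAMES[int(cls_id)], float(conf)))
    return labelled


if __name__ == '__main__':
    # Class 4 ('Figure') is dropped because its confidence is below min_conf.
    print(label_detections([1, 7, 4], [0.91, 0.88, 0.10]))
    # prints [('Text', 0.91), ('Title', 0.88)]
```

Keeping the list in one place next to the inference code avoids reading the YAML at run time, at the cost of having to update it if the training configuration changes.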

#### contact

1. GitHub: https://github.com/jiangnanboy

2. Blog: https://www.cnblogs.com/little-horse/

3. Email: 2229029156@qq.com

#### reference
https://github.com/ultralytics/ultralytics

https://github.com/buptlihang/CDLA

--------------------------------------------------------------------------------
/coco_2_yolo.py:
--------------------------------------------------------------------------------
from pycocotools.coco import COCO
import numpy as np
import tqdm
import argparse


def arg_parser():
    parser = argparse.ArgumentParser('code by rbj')
    parser.add_argument('--annotation_path', type=str,
                        default='/home/sy/data/img_layout/instance_train.json')
    # Directory where the generated .txt label files are saved
    parser.add_argument('--save_base_path', type=str, default='/home/sy/data/img_layout/train_labels/')
    args = parser.parse_args(args=[])
    # The web page this was adapted from used args = parser.parse_args(), which raised an error;
    # changing it to this resolved it. Note that args=[] ignores command-line flags,
    # so the defaults above are always used.
    return args

if __name__ == '__main__':
    args = arg_parser()
    annotation_path = args.annotation_path
    save_base_path = args.save_base_path

    data_source = COCO(annotation_file=annotation_path)
    catIds = data_source.getCatIds()
    categories = data_source.loadCats(catIds)
    categories.sort(key=lambda x: x['id'])
    classes = {}
    coco_labels = {}
    coco_labels_inverse = {}
    for c in categories:
        coco_labels[len(classes)] = c['id']
        coco_labels_inverse[c['id']] = len(classes)
        classes[c['name']] = len(classes)

    img_ids = data_source.getImgIds()
    for index, img_id in tqdm.tqdm(enumerate(img_ids), desc='change .json file to .txt file'):
        img_info = data_source.loadImgs(img_id)[0]
        file_name = img_info['file_name'].split('.')[0]
        height = img_info['height']
        width = img_info['width']

        save_path = save_base_path + file_name + '.txt'
        with open(save_path, mode='w') as fp:
            annotation_id = data_source.getAnnIds(img_id)
            boxes = np.zeros((0, 5))
            if len(annotation_id) == 0:
                fp.write('')
                continue
            annotations = data_source.loadAnns(annotation_id)
            lines = ''
            for annotation in annotations:
                box = annotation['bbox']
                # some annotations have basically no width / height, skip them
                if box[2] < 1 or box[3] < 1:
                    continue
                # top_x, top_y, width, height ----> cen_x, cen_y, width, height (normalized)
                box[0] = round((box[0] + box[2] / 2) / width, 6)
                box[1] = round((box[1] + box[3] / 2) / height, 6)
                box[2] = round(box[2] / width, 6)
                box[3] = round(box[3] / height, 6)
                # The lookup uses category_id - 1, so it only succeeds when that
                # value is an id present in the categories list.
                label = coco_labels_inverse[annotation['category_id'] - 1]
                lines = lines + str(label)
                for i in box:
                    lines += ' ' + str(i)
                lines += '\n'
            fp.writelines(lines)
    print('finish')

--------------------------------------------------------------------------------
/img-layout.yaml:
--------------------------------------------------------------------------------
train: /home/sy/img_layout/images/train  # training set
val: /home/sy/img_layout/images/val  # validation set
test:  # test set (left empty)

nc: 10  # number of classes in the dataset
names: ['Header', 'Text', 'Reference', 'Figure caption', 'Figure', 'Table caption', 'Table', 'Title', 'Footer', 'Equation']

--------------------------------------------------------------------------------
/infer_app.py:
--------------------------------------------------------------------------------
import sys
import os
import cv2

sys.path.insert(0, os.path.dirname(os.getcwd()))
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from ultralytics import YOLO

def infer():
    model = YOLO('8npt/best.pt')
    results = model('img.jpg')
    print(results[0].plot())
    cv2.imwrite('result.png', results[0].plot())

if __name__ == '__main__':
    infer()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
opencv-contrib-python==4.7.0.68
opencv-python==4.6.0.66
opencv-python-headless==4.7.0.68

ultralytics==8.0.51
# Python>=3.7
torch>=1.7

--------------------------------------------------------------------------------
/result/test1_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/result/test1_result.png

--------------------------------------------------------------------------------
/result/test2_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/result/test2_result.png

--------------------------------------------------------------------------------
/result/test3_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/result/test3_result.png

--------------------------------------------------------------------------------
/train_app.py:
--------------------------------------------------------------------------------
import sys
import os
sys.path.insert(0, os.path.dirname(os.getcwd()))
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from ultralytics import YOLO

def train_model():
    # Load the model
    # model = YOLO("yolov8n.yaml")  # build a new model from scratch
    print('model load...')
    model = YOLO("8npt/best.pt")  # load the model
    print('model load completed...')

    # Use the model
    # model.train(data="img-layout.yaml", epochs=300, device=1)# , lr0=0.0001)  # train the model
    #
    # metrics = model.val()  # evaluate model performance on the validation set
    #
    # print('metric : {}'.format(metrics))

    # results = model("https://ultralytics.com/images/bus.jpg")  # run prediction on an image
    success = model.export(format="onnx")  # export the model to ONNX format

if __name__ == '__main__':
    train_model()
--------------------------------------------------------------------------------
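
The box arithmetic in `coco_2_yolo.py` can be isolated into a small function, which makes it easier to check by hand. The sketch below is a minimal standalone version, assuming COCO boxes are `[x_min, y_min, width, height]` in pixels, as the script's own comment states. The function names `coco_to_yolo` and `to_yolo_line` are hypothetical and do not appear in the repository.

```python
def coco_to_yolo(box, img_width, img_height):
    """Convert a COCO box [x_min, y_min, w, h] in pixels to a YOLO box
    [cx, cy, w, h] normalized by the image size, rounded to 6 decimals."""
    x_min, y_min, w, h = box
    cx = round((x_min + w / 2) / img_width, 6)   # box center x, normalized
    cy = round((y_min + h / 2) / img_height, 6)  # box center y, normalized
    return [cx, cy, round(w / img_width, 6), round(h / img_height, 6)]


def to_yolo_line(label, box, img_width, img_height):
    """Format one label-file line: '<class> <cx> <cy> <w> <h>'."""
    values = coco_to_yolo(box, img_width, img_height)
    return ' '.join([str(label)] + [str(v) for v in values])


if __name__ == '__main__':
    # A 50x80 box with top-left corner (100, 200) in a 1000x800 image.
    print(to_yolo_line(3, [100, 200, 50, 80], 1000, 800))
    # prints 3 0.125 0.3 0.05 0.1
```

Unlike the loop in the script, this version returns a new list instead of overwriting the annotation's `bbox` in place, so the original pixel coordinates remain available after conversion.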