├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── layout_analysis.iml
│   ├── misc.xml
│   └── modules.xml
├── 8mpt
│   ├── F1_curve.png
│   ├── PR_curve.png
│   ├── P_curve.png
│   ├── R_curve.png
│   ├── best.pt.txt
│   └── results.png
├── 8npt
│   ├── F1_curve.png
│   ├── PR_curve.png
│   ├── P_curve.png
│   ├── R_curve.png
│   ├── best.pt.txt
│   └── results.png
├── README.md
├── coco_2_yolo.py
├── img-layout.yaml
├── infer_app.py
├── requirements.txt
├── result
│   ├── test1_result.png
│   ├── test2_result.png
│   └── test3_result.png
└── train_app.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/layout_analysis.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/8mpt/F1_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/F1_curve.png
--------------------------------------------------------------------------------
/8mpt/PR_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/PR_curve.png
--------------------------------------------------------------------------------
/8mpt/P_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/P_curve.png
--------------------------------------------------------------------------------
/8mpt/R_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/R_curve.png
--------------------------------------------------------------------------------
/8mpt/best.pt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/best.pt.txt
--------------------------------------------------------------------------------
/8mpt/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8mpt/results.png
--------------------------------------------------------------------------------
/8npt/F1_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/F1_curve.png
--------------------------------------------------------------------------------
/8npt/PR_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/PR_curve.png
--------------------------------------------------------------------------------
/8npt/P_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/P_curve.png
--------------------------------------------------------------------------------
/8npt/R_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/R_curve.png
--------------------------------------------------------------------------------
/8npt/best.pt.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/best.pt.txt
--------------------------------------------------------------------------------
/8npt/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/8npt/results.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### 利用yolov8对中文文档图片进行版面检测
2 | yolov8 is used to detect the layout of Chinese document images
3 |
4 | #### 模型下载、训练及推理
5 | 本项目基于开源中文版面数据集[CDLA](https://github.com/buptlihang/CDLA),利用yolov8训练了两个模型:8mpt与8npt。
6 |
7 | CDLA是一个中文文档版面分析数据集,面向中文文献类(论文)场景。包含以下10个label:
8 |
9 | |正文|标题|图片|图片标题|表格|表格标题|页眉|页脚|注释|公式|
10 | |---|---|---|---|---|---|---|---|---|---|
11 | |Text|Title|Figure|Figure caption|Table|Table caption|Header|Footer|Reference|Equation|
12 |
13 | 8mpt模型与8npt模型下载:
14 |
15 | 链接:https://pan.baidu.com/s/1YakM5AYrakoG9hYN-w7mJw
16 |
17 | 提取码:j2za
18 |
19 | 训练:
20 | ```
21 | from ultralytics import YOLO
22 |
23 | def train_model():
24 |     # 加载模型
25 |     print('model load。。。')
26 |     model = YOLO("8npt/best.pt")  # 加载模型
27 |     print('model load completed。。。')
28 |     # 使用模型
29 |     model.train(data="img-layout.yaml", epochs=300, device=1)  # 训练模型(可选参数:lr0=0.0001)
30 |     metrics = model.val()  # 在验证集上评估模型性能
31 | ```
32 | 8npt
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | 8mpt
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 | 推理:
51 | ```
52 | import cv2
53 | from ultralytics import YOLO
54 | def infer():
55 |     model = YOLO('8npt/best.pt')
56 |     results = model('img.jpg')
57 |     print(results[0].plot())
58 |     cv2.imwrite('result.png', results[0].plot())
59 | ```
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 | #### contact
70 |
71 | 1、github:https://github.com/jiangnanboy
72 |
73 | 2、博客:https://www.cnblogs.com/little-horse/
74 |
75 | 3、邮件:2229029156@qq.com
76 |
77 | #### reference
78 | https://github.com/ultralytics/ultralytics
79 |
80 | https://github.com/buptlihang/CDLA
81 |
82 |
--------------------------------------------------------------------------------
/coco_2_yolo.py:
--------------------------------------------------------------------------------
1 | from pycocotools.coco import COCO
2 | import numpy as np
3 | import tqdm
4 | import argparse
5 |
6 |
def arg_parser(argv=None):
    """Parse the command-line options of the COCO -> YOLO label converter.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) the process's own command line (``sys.argv``) is used.

    Returns:
        argparse.Namespace with two attributes:
        ``annotation_path`` - path of the COCO annotation JSON file, and
        ``save_base_path`` - directory that receives the generated .txt files.
    """
    parser = argparse.ArgumentParser('code by rbj')
    parser.add_argument('--annotation_path', type=str,
                        default='/home/sy/data/img_layout/instance_train.json')
    # Directory in which the generated .txt label files are saved.
    parser.add_argument('--save_base_path', type=str,
                        default='/home/sy/data/img_layout/train_labels/')
    # Plain parse_args() reportedly failed in the author's environment,
    # presumably because the host (e.g. a notebook kernel) passes extra
    # arguments -- TODO confirm. parse_known_args ignores unrecognised
    # arguments but, unlike parse_args(args=[]), still honours the two
    # options defined above.
    args, _unknown = parser.parse_known_args(argv)
    return args
16 |
def build_label_index(data_source):
    """Map every COCO category id to a contiguous 0-based class index.

    Indices are assigned in ascending order of the categories' COCO ``id``.

    Args:
        data_source: a loaded ``pycocotools.coco.COCO`` object.

    Returns:
        dict mapping COCO category id -> class index.
    """
    categories = sorted(data_source.loadCats(data_source.getCatIds()),
                        key=lambda category: category['id'])
    return {category['id']: index for index, category in enumerate(categories)}


def coco_bbox_to_yolo(bbox, width, height, ndigits=6):
    """Convert a COCO box to a normalised YOLO box.

    Args:
        bbox: ``[top_left_x, top_left_y, box_width, box_height]`` in pixels.
            It is read only, never modified.
        width: image width in pixels.
        height: image height in pixels.
        ndigits: number of decimals each value is rounded to.

    Returns:
        New list ``[center_x, center_y, box_width, box_height]`` with every
        value divided by the image width/height.
    """
    top_x, top_y, box_w, box_h = bbox
    return [
        round((top_x + box_w / 2) / width, ndigits),
        round((top_y + box_h / 2) / height, ndigits),
        round(box_w / width, ndigits),
        round(box_h / height, ndigits),
    ]


def label_file_path(save_base_path, image_file_name):
    """Return the path of the .txt label file for an image.

    Only the extension of the image's file name is replaced, so names that
    contain dots keep their full stem (``a.b.jpg`` -> ``a.b.txt``). Any
    directory part of ``image_file_name`` is dropped, so every label file is
    placed directly in ``save_base_path``.
    """
    stem = os.path.splitext(os.path.basename(image_file_name))[0]
    return os.path.join(save_base_path, stem + '.txt')


def image_label_lines(data_source, img_id, width, height, label_index):
    """Build the YOLO label lines for one image.

    Each kept annotation becomes one line ``"<class> <cx> <cy> <w> <h>"``
    terminated by a newline. Returns an empty list when the image has no
    usable annotation.
    """
    lines = []
    for annotation in data_source.loadAnns(data_source.getAnnIds(img_id)):
        bbox = annotation['bbox']
        # some annotations have basically no width / height, skip them
        if bbox[2] < 1 or bbox[3] < 1:
            continue
        # NOTE(review): the lookup uses category_id - 1, not category_id. It
        # only succeeds when id - 1 is itself a category id (e.g. the dataset
        # has a placeholder category with id 0, presumably a background class
        # from the labelme -> COCO conversion). Kept unchanged because the
        # resulting class order must match the trained models -- verify against
        # the dataset's categories before changing it.
        label = label_index[annotation['category_id'] - 1]
        fields = [str(label)] + [str(value) for value in coco_bbox_to_yolo(bbox, width, height)]
        lines.append(' '.join(fields) + '\n')
    return lines


def convert(annotation_path, save_base_path):
    """Write one YOLO .txt label file per image listed in a COCO annotation file.

    Images without usable annotations still get an (empty) label file. The
    output directory is created if it does not exist yet.
    """
    data_source = COCO(annotation_file=annotation_path)
    label_index = build_label_index(data_source)
    os.makedirs(save_base_path, exist_ok=True)

    for img_id in tqdm.tqdm(data_source.getImgIds(), desc='change .json file to .txt file'):
        img_info = data_source.loadImgs(img_id)[0]
        # Compute the content first so a failure never leaves a truncated file.
        lines = image_label_lines(data_source, img_id, img_info['width'],
                                  img_info['height'], label_index)
        with open(label_file_path(save_base_path, img_info['file_name']), mode='w') as fp:
            fp.write(''.join(lines))
    print('finish')


if __name__ == '__main__':
    args = arg_parser()
    convert(args.annotation_path, args.save_base_path)
67 |
--------------------------------------------------------------------------------
/img-layout.yaml:
--------------------------------------------------------------------------------
1 | train: /home/sy/img_layout/images/train # 训练集
2 | val: /home/sy/img_layout/images/val # 验证集
3 | test:  # 测试集(可选)
4 |
5 | nc: 10 # 数据集类别数量
6 | names: ['Header', 'Text', 'Reference', 'Figure caption', 'Figure', 'Table caption', 'Table', 'Title', 'Footer', 'Equation']
7 |
8 |
--------------------------------------------------------------------------------
/infer_app.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import cv2
4 |
# Put the parent of the current working directory first on the import path
# (presumably so a local ultralytics checkout next to this project is picked
# up -- TODO confirm).
sys.path.insert(0, os.path.dirname(os.getcwd()))
# Default to GPU 1, but respect a CUDA_VISIBLE_DEVICES value the caller has
# already set. This runs before the ultralytics import below.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
7 |
8 | from ultralytics import YOLO
9 |
def infer(model_path='8npt/best.pt', image_path='img.jpg', output_path='result.png'):
    """Detect the layout of one document image and save the annotated result.

    Args:
        model_path: YOLOv8 weights to load.
        image_path: image to run layout detection on.
        output_path: file the image with the predicted boxes drawn on it is
            written to.

    Returns:
        The annotated image returned by ``results[0].plot()``.

    Raises:
        OSError: if OpenCV could not write ``output_path``.
    """
    model = YOLO(model_path)
    results = model(image_path)
    # Render the predictions once; the same image is saved and returned.
    annotated = results[0].plot()
    # cv2.imwrite reports failure (bad directory, unsupported extension) by
    # returning False instead of raising, so check it explicitly.
    if not cv2.imwrite(output_path, annotated):
        raise OSError(f'could not write layout result to {output_path}')
    print(f'layout result saved to {output_path}')
    return annotated


if __name__ == '__main__':
    infer()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-contrib-python==4.7.0.68
2 | opencv-python==4.6.0.66
3 | opencv-python-headless==4.7.0.68
4 |
5 | ultralytics==8.0.51
6 | # Python>=3.7
7 | torch>=1.7
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/result/test1_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/result/test1_result.png
--------------------------------------------------------------------------------
/result/test2_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/result/test2_result.png
--------------------------------------------------------------------------------
/result/test3_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/layout_analysis/b42520519267fc5b931484d30a026f0e7477724e/result/test3_result.png
--------------------------------------------------------------------------------
/train_app.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
# Put the parent of the current working directory first on the import path
# (presumably so a local ultralytics checkout next to this project is picked
# up -- TODO confirm).
sys.path.insert(0, os.path.dirname(os.getcwd()))
# Default to GPU 1, but respect a CUDA_VISIBLE_DEVICES value the caller has
# already set. This runs before the ultralytics import below.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
5 |
6 | from ultralytics import YOLO
7 |
def train_model(model_path="8npt/best.pt", *, train=False, data="img-layout.yaml",
                epochs=300, device=0, export_format="onnx"):
    """Load a YOLOv8 layout model, optionally train it, then export it.

    With the default arguments this loads ``8npt/best.pt`` and exports it to
    ONNX without training.

    Args:
        model_path: weights to load, or a model config such as
            ``"yolov8n.yaml"`` to build a new model from scratch.
        train: when True, train on ``data`` for ``epochs`` epochs and print
            the validation metrics before exporting.
        data: dataset YAML used for training.
        epochs: number of training epochs.
        device: training device index. When CUDA_VISIBLE_DEVICES limits the
            process to a single GPU, that GPU is index 0 inside the process,
            so ``device=1`` would not exist there.
        export_format: format passed to ``model.export``; ``None`` skips the
            export.

    Returns:
        The value returned by ``model.export``, or ``None`` when the export is
        skipped.
    """
    print('model load。。。')
    model = YOLO(model_path)
    print('model load completed。。。')

    if train:
        model.train(data=data, epochs=epochs, device=device)
        metrics = model.val()  # evaluate the model on the validation split
        print('metric : {}'.format(metrics))

    if export_format is None:
        return None
    return model.export(format=export_format)


if __name__ == '__main__':
    train_model()
--------------------------------------------------------------------------------