├── .gitignore ├── ImgHelper.py ├── LICENSE ├── LogHelper.py ├── NpyHelper.py ├── README.en.md ├── README.md ├── __init__.py ├── box.py ├── box_cluster.py ├── coco2xml.py ├── convert.py ├── crop_img_base_bbox.py ├── cv2_utils.py ├── data_aug.py ├── data_augment.py ├── data_loader.py ├── file2img.py ├── generate_test_json.py ├── line.py ├── line_cluster.py ├── pdf.py ├── point.py ├── rm_watermark.py ├── similary.py ├── spilt_train_val.py ├── table.py ├── text.py ├── text_angle.py ├── utils.py ├── video.py ├── video_speech_word ├── audioSeg.py └── video_speech_word.py └── xml2coco.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | -------------------------------------------------------------------------------- /ImgHelper.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | from PIL import Image 5 | import numpy 6 | 7 | 8 | def PIL_image_to_Opencv_image(path): 9 | ''' 10 | PIL.Image格式转换为OpenCV BGR格式 11 | :param path: 12 | :return: 13 | ''' 14 | image = Image.open(path) 15 | image.show() 16 | img = cv2.cvtColor(numpy.asarray(image), cv2.COLOR_RGB2BGR) 17 | # cv2.imshow("OpenCV", img) 18 | # cv2.waitKey() 19 | return img 20 | 21 | def Opencv_image_to_PIL_image(path): 22 | ''' 23 | OpenCV BGR格式转换为PIL.Image格式 24 | :param path: 25 | :return: 26 | ''' 27 
| img = cv2.imread(path) 28 | cv2.imshow("OpenCV",img) 29 | image = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB)) 30 | image.show() 31 | cv2.waitKey() 32 | return image 33 | 34 | 35 | def png2jpg(path): 36 | print(path) 37 | for filename in os.listdir(path): 38 | if os.path.splitext(filename)[1] == '.png': 39 | # print(filename) 40 | img = cv2.imread(path + filename) 41 | print(filename.replace(".png", ".jpg")) 42 | newfilename = filename.replace(".png", ".jpg") 43 | cv2.imshow("Image",img) 44 | cv2.waitKey(0) 45 | cv2.imwrite(path + newfilename, img) 46 | 47 | if __name__ == '__main__': 48 | path = '/media/hxzh02/SB@home/hxzh/Dataset/Plane_detect_datasets/VOCdevkit_lineextract_detect/VOC2007/JPEGImages/' 49 | png2jpg(path) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 杂杳 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LogHelper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class Logger(object): 5 | def __init__(self, filename="Default.log"): 6 | self.terminal = sys.stdout 7 | self.log = open(filename, "a") 8 | 9 | def write(self, message): 10 | self.terminal.write(message) 11 | self.log.write(message) 12 | 13 | def flush(self): 14 | pass 15 | 16 | 17 | if __name__ == '__main__': 18 | sys.stdout = Logger('a.txt') 19 | 20 | print("测试") -------------------------------------------------------------------------------- /NpyHelper.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | @ Author: LinXu 5 | @ Contact: 17746071609@163.com 6 | @ Date: 2022/01/15 15:35 PM 7 | @ Software: PyCharm 8 | @ File: NpyHelper.py 9 | @ Desc: .npy格式相关支持函数 10 | """ 11 | 12 | import os 13 | import numpy as np 14 | 15 | def readFiledir_saveNpy(file_dir, npy_name): 16 | a = os.listdir(file_dir) # 读取文件夹中的目录文件 17 | # print(a) 18 | save_path = npy_name + '.npy' 19 | np.save(save_path, a) 20 | 21 | 22 | def read_npyfile(npy_path): 23 | # 读取.npy文件 24 | arr = np.load(npy_path) 25 | # print(arr) 26 | return arr 27 | 28 | if __name__ == '__main__': 29 | file_dir = '/media/hxzh02/SB@home/hxzh/Dataset/Plane_detect_datasets/VOCdevkit_lineextract_detect/VOC2007/Annotations/' # 文件夹的路径 30 | npy_name = 'xml' 31 | # readFiledir_saveNpy(file_dir,npy_name) 32 | 33 | npy_path = 'xml.npy' 34 | read_npyfile(npy_path) -------------------------------------------------------------------------------- /README.en.md: 
-------------------------------------------------------------------------------- 1 | # AI-toolsBox 2 | 3 | #### Description 4 | AI算法工具箱 5 | 6 | #### Software Architecture 7 | Software architecture description 8 | 9 | #### Installation 10 | 11 | 1. xxxx 12 | 2. xxxx 13 | 3. xxxx 14 | 15 | #### Instructions 16 | 17 | 1. xxxx 18 | 2. xxxx 19 | 3. xxxx 20 | 21 | #### Contribution 22 | 23 | 1. Fork the repository 24 | 2. Create Feat_xxx branch 25 | 3. Commit your code 26 | 4. Create Pull Request 27 | 28 | 29 | #### Gitee Feature 30 | 31 | 1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md 32 | 2. Gitee blog [blog.gitee.com](https://blog.gitee.com) 33 | 3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore) 34 | 4. The most valuable open source project [GVP](https://gitee.com/gvp) 35 | 5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help) 36 | 6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI-toolsBox 2 | 3 | #### 介绍 4 | AI算法工具箱 5 | 6 | #### 软件架构 7 | 软件架构说明 8 | 9 | 10 | #### 安装教程 11 | 12 | 1. xxxx 13 | 2. xxxx 14 | 3. xxxx 15 | 16 | #### 使用说明 17 | 18 | 1. xxxx 19 | 2. xxxx 20 | 3. xxxx 21 | 22 | #### 参与贡献 23 | 24 | 1. Fork 本仓库 25 | 2. 新建 Feat_xxx 分支 26 | 3. 提交代码 27 | 4. 新建 Pull Request 28 | 29 | 30 | #### 特技 31 | 32 | 1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md 33 | 2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) 34 | 3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 35 | 4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 36 | 5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) 37 | 6. 
Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) 38 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # 4 | # Author: alex 5 | # Created Time: 2020年03月18日 星期三 17时08分50秒 6 | 7 | -------------------------------------------------------------------------------- /box.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # box相关函数 4 | # 5 | # Author: alex 6 | # Created Time: 2020年03月10日 星期二 7 | import numpy as np 8 | 9 | 10 | def intersection_area(box1, box2): 11 | """计算两个矩形的重叠面积""" 12 | x1, y1, xb1, yb1 = box1 13 | x2, y2, xb2, yb2 = box2 14 | 15 | # 相交矩形 16 | ax, ay, bx, by = max(x1, x2), max(y1, y2), min(xb1, xb2), min(yb1, yb2) 17 | if ax >= bx or ay >= by: 18 | return 0 19 | 20 | # 重叠面积 21 | in_area = (bx-ax) * (by-ay) 22 | # print((ax, ay, bx, by)) 23 | # print('相交面积:%d' % in_area) 24 | return in_area 25 | 26 | 27 | def iou(box1, box2): 28 | """计算两个矩形的交并比""" 29 | in_area = intersection_area(box1, box2) 30 | if in_area == 0: 31 | return 0. 32 | 33 | x1, y1, xb1, yb1 = box1 34 | x2, y2, xb2, yb2 = box2 35 | area1 = abs((xb1-x1) * (yb1-y1)) 36 | area2 = abs((xb2-x2) * (yb2-y2)) 37 | return in_area / (area1 + area2 - in_area) 38 | 39 | 40 | def in_box_rate(box, container_box): 41 | """判断一个box在一个容器box里的占比 42 | :return float 例如返回值为0.6,则表示box在容器box中的面积占box的60% 43 | """ 44 | in_area = intersection_area(box, container_box) 45 | if in_area == 0: 46 | return 0. 
47 | x1, y1, xb1, yb1 = box 48 | area = abs((xb1-x1) * (yb1-y1)) 49 | return in_area / area 50 | 51 | 52 | def boxes_in_row(box1, box2): 53 | """判断两个box是否在同一行""" 54 | if iou(box1, box2) > 0: 55 | return False # 不能有交集 56 | if box1[0] > box2[0]: 57 | box1, box2 = box2, box1 58 | 59 | _, y1, xb1, yb1 = box1 60 | x2, y2, _, yb2 = box2 61 | if xb1 > x2: 62 | return False # box2必须在box1的右边 63 | 64 | # 垂直方向上交集 65 | min_yb, max_y = min(yb1, yb2), max(y1, y2) 66 | if min_yb <= max_y: 67 | return False # 如果没有交集 68 | max_h = max(yb1, yb2) - min(y1, y2) 69 | min_h = min_yb - max_y 70 | 71 | # 高度差 72 | h1, h2 = yb1-y1, yb2-y2 73 | h1, h2 = min(h1, h2), max(h1, h2) 74 | print((h2-h1)/h1) 75 | 76 | # 水平方向需要相邻 & 重叠部分超过80% & 高度差不能超过20% 77 | # TODO 这里的参数可能不是最优的,可以经过测试调整 78 | return min_h/max_h > 0.8 and (x2-xb1) < 2*min_h and (h2-h1)/h1 < 0.2 79 | 80 | 81 | def solve(box): 82 | """ 83 | 绕 cx,cy点 w,h 旋转 angle 的坐标 84 | x = cx-w/2 85 | y = cy-h/2 86 | x1-cx = -w/2*cos(angle) +h/2*sin(angle) 87 | y1 -cy= -w/2*sin(angle) -h/2*cos(angle) 88 | 89 | h(x1-cx) = -wh/2*cos(angle) +hh/2*sin(angle) 90 | w(y1 -cy)= -ww/2*sin(angle) -hw/2*cos(angle) 91 | (hh+ww)/2sin(angle) = h(x1-cx)-w(y1 -cy) 92 | :param box 四个顶点坐标[x1, y1, x2, y2, x3, y3, x4, y4] 93 | """ 94 | x1, y1, x2, y2, x3, y3, x4, y4 = box[:8] 95 | cx = (x1+x3+x2+x4)/4.0 96 | cy = (y1+y3+y4+y2)/4.0 97 | w = (np.sqrt((x2-x1)**2+(y2-y1)**2)+np.sqrt((x3-x4)**2+(y3-y4)**2))/2 98 | h = (np.sqrt((x2-x3)**2+(y2-y3)**2)+np.sqrt((x1-x4)**2+(y1-y4)**2))/2 99 | # x = cx-w/2 100 | # y = cy-h/2 101 | # sinA = ((y2+y3)/2 - (y1+y4)/2) / w 102 | sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2 103 | angle = np.arcsin(sinA) 104 | return angle, w, h, cx, cy 105 | 106 | 107 | def rotate_cut_img(img, box, degree, wh, center, 108 | rotate=False, leftAdjust=1.0, rightAdjust=1.0): 109 | """四边形旋转并裁剪图像,通常和solve搭配使用 110 | :param img PIL图像 111 | :param box 四个顶点坐标[x1, y1, x2, y2, x3, y3, x4, y4] 112 | :param degree 选择角度, 对应solve函数中的angle 113 | 
:param wh 对应solve函数中的w和h 114 | :param center 中心点坐标, 对应solve函数中的cx和cy 115 | """ 116 | # 原图坐标 117 | # degree, w, h, x_center, y_center = solve(box) 118 | w, h = wh 119 | x_center, y_center = center 120 | xmin_ = min(box[0::2]) 121 | xmax_ = max(box[0::2]) 122 | ymin_ = min(box[1::2]) 123 | ymax_ = max(box[1::2]) 124 | 125 | # 第一次裁剪 126 | img = img.crop([xmin_, ymin_, xmax_, ymax_]) 127 | 128 | # 裁剪后的中心点 129 | x_center = x_center - xmin_ 130 | y_center = y_center - ymin_ 131 | 132 | # 旋转时长度不变: 左上右下点坐标 133 | xmin = max(0, x_center-w/2-leftAdjust*h) 134 | ymin = y_center-h/2 135 | xmax = min(x_center+w/2+rightAdjust*h, img.size[0]-1) 136 | ymax = y_center+h/2 137 | 138 | # 按照裁剪后的中心点旋转并裁剪 139 | degree_ = degree*180.0/np.pi 140 | if rotate is False: 141 | crop_img = img.crop([xmin, ymin, xmax, ymax]) 142 | else: 143 | if abs(degree_) <= 0.0001: 144 | # 不需要进行旋转 145 | degree_ = 0 146 | crop_img = img.crop([xmin, ymin, xmax, ymax]) 147 | else: 148 | crop_img = img.rotate(degree_, center=(x_center, y_center))\ 149 | .crop([xmin, ymin, xmax, ymax]) 150 | 151 | return crop_img 152 | 153 | 154 | if __name__ == '__main__': 155 | box1 = [325.8022766113281, 393.0766296386719, 156 | 592.3567504882812, 435.80364990234375] 157 | box2 = [620.7103881835938, 397.5979309082031, 158 | 660.7562255859375, 433.3531188964844] 159 | print(boxes_in_row(box1, box2)) 160 | box1 = [339.4958190917969, 222.9739532470703, 161 | 581.1708374023438, 261.8145751953125] 162 | box2 = [634.3968505859375, 222.9739532470703, 163 | 670.3604125976562, 264.6916809082031] 164 | print(boxes_in_row(box1, box2)) 165 | -------------------------------------------------------------------------------- /box_cluster.py: -------------------------------------------------------------------------------- 1 | ''' 2 | box聚类 3 | 4 | Author: alex 5 | Created Time: 2020年07月02日 星期四 17时09分54秒 6 | ''' 7 | from sklearn.cluster import DBSCAN 8 | from image_utils.line import iou_line 9 | 10 | 11 | def boxes_cluster(boxes, row=True, 
col=True, iou_score=0.5): 12 | """将box框进行聚类 13 | :param boxes List[List[float]] 待聚类的box列表,每个box的格式:[x1,y1,x2,y2] 14 | :param row bool 是否进行行聚类 15 | :param col bool 是否进行列聚类 16 | :return row_labels List[int]|None 每个box的行id 17 | :return col_labels List[int]|None 每个box的列id 18 | """ 19 | row_labels = None 20 | if row: 21 | row_labels = row_cluster(boxes, iou_score) 22 | 23 | col_labels = None 24 | if col: 25 | col_labels = col_cluster(boxes, iou_score) 26 | 27 | return row_labels, col_labels 28 | 29 | 30 | def row_cluster(boxes, iou_score): 31 | """按行聚类""" 32 | boxes = [[box[1], box[3]] for box in boxes] 33 | return do_cluster(boxes, iou_score) 34 | 35 | 36 | def col_cluster(boxes, iou_score): 37 | """按列聚类""" 38 | boxes = [[box[0], box[2]] for box in boxes] 39 | return do_cluster(boxes, iou_score) 40 | 41 | 42 | def do_cluster(boxes, iou_score): 43 | """实际聚类""" 44 | # boxes = sorted(boxes, key=lambda x: x[0]) 45 | db = DBSCAN(eps=iou_score, min_samples=1, metric=distance).fit(boxes) 46 | return db.labels_ 47 | 48 | 49 | def distance(box1, box2): 50 | """距离函数""" 51 | return 1 - iou_line(box1, box2) 52 | 53 | 54 | if __name__ == '__main__': 55 | boxes = [ 56 | [1, 1, 10, 10], [12, 2, 21, 12], 57 | [2, 9, 11, 20], [10, 11, 20, 22], 58 | [1.5, 22, 11.5, 30], 59 | [2, 39, 11, 50], [10, 41, 20, 52], [22, 40, 30, 51], 60 | ] 61 | row_labels, col_labels = boxes_cluster(boxes) 62 | print(row_labels) 63 | print(col_labels) 64 | assert row_labels.tolist() == [0, 0, 1, 1, 2, 3, 3, 3] 65 | assert col_labels.tolist() == [0, 1, 0, 1, 0, 0, 1, 2] 66 | -------------------------------------------------------------------------------- /coco2xml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | #!/usr/bin/env python 4 | # -*- coding:utf-8 -*- 5 | import json 6 | import xml.etree.ElementTree as ET 7 | from collections import defaultdict 8 | from PIL import Image 9 | import numpy as np 10 | 11 | 12 | 
num_to_class = { 13 | "1": "1", 14 | "2": "2", 15 | "3": "3"} 16 | 17 | 18 | def write_xml(sh, sw, imgname, filepath, labeldicts): 19 | root = ET.Element('Annotation') 20 | ET.SubElement(root, 'filename').text = str(imgname) 21 | sizes = ET.SubElement(root, 'size') 22 | ET.SubElement(sizes, 'width').text = str(sw) 23 | ET.SubElement(sizes, 'height').text = str(sh) 24 | ET.SubElement(sizes, 'depth').text = '3' 25 | for labeldict in labeldicts: 26 | objects = ET.SubElement(root, 'object') 27 | ET.SubElement(objects, 'name').text = labeldict['name'] 28 | ET.SubElement(objects, 'pose').text = 'Unspecified' 29 | ET.SubElement(objects, 'truncated').text = '0' 30 | ET.SubElement(objects, 'difficult').text = '0' 31 | bndbox = ET.SubElement(objects, 'bndbox') 32 | if labeldict['xmin'] < 0: 33 | labeldict['xmin'] = 0 34 | if labeldict['ymin'] < 0: 35 | labeldict['ymin'] = 0 36 | if labeldict['xmax'] > sw: 37 | labeldict['xmax'] = sw 38 | if labeldict['ymax'] > sh: 39 | labeldict['ymax'] = sh 40 | ET.SubElement(bndbox, 'xmin').text = str(int(labeldict['xmin'])) 41 | ET.SubElement(bndbox, 'ymin').text = str(int(labeldict['ymin'])) 42 | ET.SubElement(bndbox, 'xmax').text = str(int(labeldict['xmax'])) 43 | ET.SubElement(bndbox, 'ymax').text = str(int(labeldict['ymax'])) 44 | tree = ET.ElementTree(root) 45 | tree.write(filepath, encoding='utf-8') 46 | 47 | 48 | def my_xml(file_name, annotations_path): 49 | # file_name = "train.json" 50 | # annotations_path = "xml/" 51 | with open(file_name, 'r', encoding='utf-8') as fr: 52 | load_dict = json.load(fr) 53 | imgToAnns = defaultdict(list) 54 | imgs = {} 55 | idToName = defaultdict(list) 56 | for ann in load_dict['annotations']: 57 | imgToAnns[str(ann['image_id'])].append(ann) 58 | for img in load_dict['images']: 59 | imgs[str(img['id'])] = img 60 | idToName[str(img['id'])] = img['file_name'] 61 | for key, values in imgToAnns.items(): 62 | label_dicts = [] 63 | for value in values: 64 | category_id = value["category_id"] 65 | new_dict 
= {'name': num_to_class[str(category_id)], 66 | # 'name': load_dict["categories"][category_id]["name"], 67 | 'difficult': '0', 68 | 'xmin': value["bbox"][0], 69 | 'ymin': value["bbox"][1], 70 | 'xmax': value["bbox"][0] + value["bbox"][2], 71 | 'ymax': value["bbox"][1] + value["bbox"][3] 72 | } 73 | label_dicts.append(new_dict) 74 | write_xml(imgs[key]["height"], imgs[key]["width"], idToName[key], 75 | annotations_path + idToName[key][0:-4] + '.xml', label_dicts) 76 | 77 | 78 | if __name__ == '__main__': 79 | my_xml('train.json', './traffic_voc/Annotations/') 80 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # 图像格式转换相关函数 4 | # Author: alex 5 | # Created Time: 2019年09月04日 星期三 09时31分18秒 6 | """ 7 | from PIL import Image 8 | img = Image.open(path) 9 | 10 | img.mode模式: 11 | 1 1位像素,黑和白,存成8位的像素 12 | L 8位像素,黑白 13 | P 8位像素,使用调色板映射到任何其他模式 14 | RGB 3×8位像素,真彩 15 | RGBA 4×8位像素,真彩+透明通道 16 | CMYK 4×8位像素,颜色隔离 17 | YCbCr 3×8位像素,彩色视频格式 18 | I 32位整型像素 19 | F 32位浮点型像素 20 | 21 | img.format: 这和文件后缀对应 22 | 23 | 注意P模式的图片,对应的format格式可能是GIF 24 | """ 25 | import re 26 | import cv2 27 | import base64 28 | import numpy as np 29 | from PIL import Image 30 | from io import BytesIO 31 | 32 | 33 | def base64_cv2(b64, is_color=True): 34 | """将base64格式的图片转换为cv2格式 35 | :param b64 str base64字符串 36 | :param is_color bool 是否为彩色图像,若为True则返回彩色图像,否则返回灰度图像 37 | :return numpy.ndarray cv2图像 38 | """ 39 | b64 = base64.b64decode(b64) 40 | nparr = np.fromstring(b64, np.uint8) 41 | color = cv2.IMREAD_COLOR if is_color else cv2.IMREAD_GRAYSCALE 42 | return cv2.imdecode(nparr, color) 43 | 44 | 45 | def cv2_base64(img, format='JPEG'): 46 | """将cv2格式的图像转换为base64格式 47 | :param img numpy.ndarray cv2图像 48 | :param format str 转化后的图像格式 49 | :return str base64字符串 50 | """ 51 | out_img = Image.fromarray(img) 52 | output_buffer = BytesIO() 53 | 
out_img.save(output_buffer, format=format) 54 | binary_data = output_buffer.getvalue() 55 | return str(base64.b64encode(binary_data), encoding='utf8') 56 | 57 | 58 | def base64_pil(b64): 59 | """将图片从base64格式转换为PIL格式""" 60 | base64_data = re.sub('^data:image/.+;base64,', '', b64) 61 | byte_data = base64.b64decode(base64_data) 62 | image_data = BytesIO(byte_data) 63 | return Image.open(image_data) 64 | 65 | 66 | def pil_base64(img, format='JPEG'): 67 | """将PIL图片转换为base64格式""" 68 | buf = BytesIO() 69 | if img.mode != 'RGB': 70 | img = img.convert('RGB') 71 | 72 | img.save(buf, format=format) 73 | binary_data = buf.getvalue() 74 | return str(base64.b64encode(binary_data), encoding='utf8') 75 | 76 | 77 | def cv2_pil(img): 78 | """将图片从cv2转换为PIL格式 79 | :param img numpy.ndarray cv2图像 80 | """ 81 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 82 | return Image.fromarray(img) 83 | 84 | 85 | def pil_cv2(img): 86 | """将图片从PIL转换为cv2格式""" 87 | return np.asarray(img) 88 | # return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) 89 | 90 | 91 | def gif_jpg(img): 92 | """GIF格式图片转化为jpg""" 93 | palette = img.getpalette() 94 | img.putpalette(palette) 95 | new_img = Image.new("RGB", img.size) 96 | new_img.paste(img) 97 | return new_img 98 | 99 | 100 | def rotate(image, angle, center=None, scale=1.0, borderValue=(255, 255, 255)): 101 | """cv2旋转图像 102 | 效果比Image.rotate效果要好 103 | :param image cv2图像对象 104 | :param angle 旋转角度(注意这里是角度,而不是弧度) 105 | :param center 中心点 106 | :param scale 缩放比例 107 | :param borderValue 填充颜色,默认为白色 108 | """ 109 | # 获取图像尺寸 110 | (h, w) = image.shape[:2] 111 | 112 | # 若未指定旋转中心,则将图像中心设为旋转中心 113 | if center is None: 114 | center = (w / 2, h / 2) 115 | 116 | # 执行旋转 117 | M = cv2.getRotationMatrix2D(center, angle, scale) 118 | rotated = cv2.warpAffine(image, M, (w, h), borderValue=borderValue) 119 | 120 | # 返回旋转后的图像 121 | return rotated 122 | 123 | 124 | def rotate_pil(image, angle, center=None, scale=1.0): 125 | """PIL旋转图像 126 | 效果比Image.rotate效果要好,调用rotate进行实现 127 | 
""" 128 | image = np.asarray(image) 129 | rotated = rotate(image, angle) 130 | return Image.fromarray(rotated) 131 | 132 | 133 | def auto_rotate(image, angle, scale=1.0, borderValue=(255, 255, 255)): 134 | """cv2旋转图像(自动扩充图像) 135 | 效果比Image.rotate效果要好 136 | :param image cv2图像对象 137 | :param angle 旋转角度(注意这里是角度,而不是弧度) 138 | :param scale 缩放比例 139 | :param borderValue 填充颜色,默认为白色 140 | """ 141 | # 获取图像尺寸 142 | (h, w) = image.shape[:2] 143 | # 将图像中心设为旋转中心 144 | center = (w / 2, h / 2) 145 | 146 | # 执行旋转 147 | M = cv2.getRotationMatrix2D(center, angle, scale) 148 | cos = np.abs(M[0, 0]) 149 | sin = np.abs(M[0, 1]) 150 | 151 | # compute the new bounding dimensions of the image 152 | nw = int((h * sin) + (w * cos)) 153 | nh = int((h * cos) + (w * sin)) 154 | 155 | # adjust the rotation matrix to take into account translation 156 | M[0, 2] += (nw / 2) - center[0] 157 | M[1, 2] += (nh / 2) - center[1] 158 | 159 | # 返回旋转后的图像 160 | rotated = cv2.warpAffine(image, M, (nw, nh), borderValue=borderValue) 161 | return rotated 162 | 163 | 164 | def auto_rotate_pil(image, angle, center=None, scale=1.0): 165 | """PIL旋转图像(对应auto_rotate函数) 166 | 效果比Image.rotate效果要好,调用rotate进行实现 167 | """ 168 | image = np.asarray(image) 169 | rotated = auto_rotate(image, angle) 170 | return Image.fromarray(rotated) 171 | -------------------------------------------------------------------------------- /crop_img_base_bbox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from os.path import join 4 | import cv2 5 | import glob 6 | 7 | root_dir = "./fruit" # 原始图片保存的位置 8 | save_dir = "./bbox" # 生成截取图片的位置 9 | 10 | jpg_list = glob.glob(root_dir + "/*.jpg") 11 | 12 | fo = open("dpj_small.txt", "w") # 截取出来的图片位置 13 | 14 | max_s = -1 15 | min_s = 1000 16 | 17 | for jpg_path in jpg_list: # 遍历所有图片 18 | # jpg_path = jpg_list[3] 19 | txt_path = jpg_path.replace("jpg", "txt") # 得到文件中相应注释的文件 20 | jpg_name = os.path.basename(jpg_path) # 21 | 22 | f = 
open(txt_path, "r") # 打开注释 23 | 24 | img = cv2.imread(jpg_path) # 打开图片 25 | 26 | height, width, channel = img.shape # 得到图片的尺寸 27 | 28 | file_contents = f.readlines() # 读取注释 29 | 30 | for num, file_content in enumerate(file_contents): # 31 | print(num) # 打印种类 32 | clss, xc, yc, w, h = file_content.split() # 得到种类和具体的坐标 33 | xc, yc, w, h = float(xc), float(yc), float(w), float(h) # 对坐标浮点化 34 | # 将归一化的坐标转换为实际的坐标 35 | xc *= width 36 | yc *= height 37 | w *= width 38 | h *= height 39 | # 防止坐标超出实际范围 40 | max_s = max(w * h, max_s) 41 | min_s = min(w * h, min_s) 42 | # 得到图像坐标系下的位置 43 | half_w, half_h = w // 2, h // 2 44 | 45 | x1, y1 = int(xc - half_w), int(yc - half_h) 46 | x2, y2 = int(xc + half_w), int(yc + half_h) 47 | # 进行截取 48 | crop_img = img[y1:y2, x1:x2] 49 | 50 | new_jpg_name = jpg_name.split('.')[0] + "_crop_" + str(num) + ".jpg" # 存储图片的名称 51 | cv2.imwrite(os.path.join(save_dir, new_jpg_name), crop_img) # 截取的图片 52 | # cv2.imshow("croped",crop_img) 53 | # cv2.waitKey(0) 54 | fo.write(os.path.join(save_dir, new_jpg_name) + "\n") # 截取后的注释 55 | 56 | f.close() 57 | 58 | fo.close() 59 | 60 | print(max_s, min_s) 61 | -------------------------------------------------------------------------------- /cv2_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | cv2相关函数 3 | 4 | Author: alex 5 | Created Time: 2020年11月03日 星期二 09时44分42秒 6 | ''' 7 | import cv2 8 | 9 | 10 | def lapulase(gray): 11 | """计算拉普拉斯算子:图像模糊度 12 | 注意: 13 | 1. 在比较模糊度的时候,图像应该resize到相同的大小 14 | 2. 
返回的得分阈值很重要 15 | @param gray cv2灰度图像 16 | @return 清晰度得分 该值越大通常越清晰 17 | """ 18 | return cv2.Laplacian(gray, cv2.CV_64F).var() 19 | -------------------------------------------------------------------------------- /data_aug.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: hxy 3 | # 2021-4-25,python3.8 4 | """ 5 | 针对图片内小目标数据增强的脚本; 6 | 增加目标的数量; 7 | 更新标签文件.xml; 8 | """ 9 | import os 10 | import cv2 11 | import time 12 | 13 | try: 14 | import xml.etree.cElementTree as ET 15 | except ImportError: 16 | import xml.etree.ElementTree as ET 17 | 18 | from xml.dom.minidom import parseString 19 | 20 | from lxml.etree import Element, SubElement, tostring 21 | 22 | 23 | # 获取原始xml标签文件中的标注信息:obj_name、bbox 24 | def GetAnnotBoxLoc(AnotPath): 25 | tree = ET.ElementTree(file=AnotPath) 26 | root = tree.getroot() 27 | ObjectSet = root.findall('object') 28 | ObjBndBoxSet = list() 29 | for Object in ObjectSet: 30 | BndBoxLoc = dict() 31 | ObjName = Object.find('name').text 32 | BndBox = Object.find('bndbox') 33 | x1 = int(BndBox.find('xmin').text.split('.')[0]) 34 | y1 = int(BndBox.find('ymin').text.split('.')[0]) 35 | x2 = int(BndBox.find('xmax').text.split('.')[0]) 36 | y2 = int(BndBox.find('ymax').text.split('.')[0]) 37 | BndBoxLoc[ObjName] = [x1, y1, x2, y2] 38 | ObjBndBoxSet.append(BndBoxLoc) 39 | return ObjBndBoxSet 40 | 41 | 42 | # 裁剪标注的目标,并将其粘贴至原始目标上方和下方 43 | def paste_object(img_file, bndboxes): 44 | img = cv2.imread(img_file) 45 | new_all_bbox = list() 46 | for i in range(len(bndboxes)): 47 | new_bbox_up = dict() 48 | new_bbox_down = dict() 49 | for obj_name, bboxes in zip(bndboxes[i].keys(), bndboxes[i].values()): 50 | xmin = bboxes[0] 51 | ymin = bboxes[1] 52 | xmax = bboxes[2] 53 | ymax = bboxes[3] 54 | 55 | # 裁剪原始标注的目标物体 56 | obj_pic = img[ymin - 5:ymax + 5, xmin :xmax] 57 | h, w = obj_pic.shape[:2] 58 | 59 | # 将目标下移 60 | img[ymin - 5 + h:ymax + 5 + h, xmin:xmin + w] = obj_pic 61 | new_bbox_down[obj_name] = 
[xmin, ymin + h, xmin + w, ymax + h] 62 | new_all_bbox.append(new_bbox_down) 63 | # 将目标上移:这里为了防止越界,进行一个逻辑判断; 64 | if ymin - h - 5 > 0: 65 | img[ymin - 5 - h:ymax + 5 - h, xmin:xmin + w] = obj_pic 66 | new_bbox_up[obj_name] = [xmin, ymin - h, xmin + w, ymax - h] 67 | else: 68 | img[ymin - 5 + 2 * h:ymax + 5 + 2 * h, xmin:xmin + w] = obj_pic 69 | new_bbox_up[obj_name] = [xmin, ymin + 2 * h, xmin + w, ymax + 2 * h] 70 | new_all_bbox.append(new_bbox_up) 71 | output_all_bbox = bndboxes + new_all_bbox 72 | 73 | # valid_data_aug(img, output_all_bbox, "val_check") # 绘制box,验证数据增强的正确性 74 | return output_all_bbox, img.shape, img 75 | 76 | 77 | # 验证数据增强的结果 78 | def valid_data_aug(img, bndboxes, val_check_dir): 79 | if not os.path.exists(val_check_dir): 80 | os.mkdir(val_check_dir) 81 | # img = cv2.imread(img_file) 82 | for i in range(len(bndboxes)): 83 | for obj_name, bboxes in zip(bndboxes[i].keys(), bndboxes[i].values()): 84 | cv2.rectangle(img, (bboxes[0], bboxes[1]), (bboxes[2], bboxes[3]), (255, 0, 0), 1) 85 | cv2.putText(img, str(obj_name), (bboxes[0], bboxes[1]), cv2.FONT_HERSHEY_SIMPLEX, 86 | 0.5, (255, 255, 255), 1) 87 | 88 | cv2.imwrite(os.path.join(val_check_dir, 'result' + str(time.time()) + '.jpg'), img) 89 | 90 | return 91 | 92 | 93 | # 生成新的标签文件:.xml 94 | def write_xml(all_bndboxes, img_name, img_path, shape): 95 | node_root = Element('annotation') 96 | node_folder = SubElement(node_root, 'folder') 97 | node_folder.text = 'JPEGImage' 98 | 99 | node_img_name = SubElement(node_root, 'filename') 100 | node_img_name.text = img_name + '.jpg' 101 | node_img_path = SubElement(node_root, 'path') 102 | node_img_path.text = img_path 103 | 104 | node_source = SubElement(node_root, 'source') 105 | node_database = SubElement(node_source, 'database') 106 | node_database.text = 'Unknown' 107 | 108 | node_img_size = SubElement(node_root, 'size') 109 | node_img_width = SubElement(node_img_size, 'width') 110 | node_img_width.text = str(shape[1]) # 照片的w 111 | node_img_height = 
SubElement(node_img_size, 'height') 112 | node_img_height.text = str(shape[0]) # 照片的h 113 | node_img_depth = SubElement(node_img_size, 'depth') 114 | node_img_depth.text = str(shape[2]) # 照片的depth 115 | 116 | node_img_seg = SubElement(node_root, 'segmented') 117 | node_img_seg.text = '0' 118 | 119 | for i in range(len(all_bndboxes)): 120 | for obj_name, bboxes in zip(all_bndboxes[i].keys(), all_bndboxes[i].values()): 121 | node_obj = SubElement(node_root, 'object') 122 | node_obj_name = SubElement(node_obj, 'name') 123 | node_obj_name.text = obj_name # obj的名字 124 | 125 | node_bbox = SubElement(node_obj, 'bndbox') 126 | node_bbox_xmin = SubElement(node_bbox, 'xmin') 127 | node_bbox_xmin.text = str(bboxes[0]) 128 | node_bbox_ymin = SubElement(node_bbox, 'ymin') 129 | node_bbox_ymin.text = str(bboxes[1]) 130 | node_bbox_xmax = SubElement(node_bbox, 'xmax') 131 | node_bbox_xmax.text = str(bboxes[2]) 132 | node_bbox_ymax = SubElement(node_bbox, 'ymax') 133 | node_bbox_ymax.text = str(bboxes[3]) 134 | 135 | node_difficult = SubElement(node_obj, 'difficult') 136 | node_difficult.text = '0' # 全部设定为0 137 | 138 | xml = tostring(node_root) 139 | dom = parseString(xml) 140 | return dom 141 | 142 | 143 | if __name__ == '__main__': 144 | xml_dir = './traffic_voc/Annotations' # 原始xml文件存储文件夹 145 | img_dir = 'D:/Downloads/train' # 原始照片存储文件夹 146 | 147 | new_xml_dir = './traffic_voc/aug_xmls' # 生成的新xml存储文件夹 148 | if not os.path.exists(new_xml_dir): 149 | os.mkdir(new_xml_dir) 150 | 151 | new_pic_dir = './traffic_voc/aug_imgs' # 生成的新照片存储文件夹 152 | if not os.path.exists(new_pic_dir): 153 | os.mkdir(new_pic_dir) 154 | 155 | for file in os.listdir(xml_dir): 156 | file_name = file.split('.')[0] 157 | xml = file_name + '.xml' 158 | pic = file_name + '.jpg' 159 | #new_file_name = 'aug_' + file_name 160 | new_file_name = file_name 161 | # 生成新的照片文件 162 | bndboxes = GetAnnotBoxLoc(AnotPath=os.path.join(xml_dir, xml)) 163 | all_bboxes, img_shape, img = paste_object(img_file=os.path.join(img_dir, 
pic), 164 | bndboxes=bndboxes) 165 | cv2.imwrite(os.path.join(new_pic_dir, new_file_name + '.jpg'), img) # 存储数据增强后的照片 166 | 167 | # 生成新的标签文件xml 168 | dom = write_xml(all_bndboxes=all_bboxes, img_name=file_name, 169 | img_path=os.path.join(img_dir, pic), shape=img_shape) 170 | # 存储数据增强后的标签文件 171 | with open(os.path.join(new_xml_dir, new_file_name + '.xml'), 'wb') as x: 172 | x.write(dom.toprettyxml(indent='\t', encoding='utf-8')) 173 | x.close() 174 | -------------------------------------------------------------------------------- /data_augment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | ################################################ 3 | # 数据增广,包括 4 | # 2018.09.02 add 5 | ################################################ 6 | import numpy as np 7 | import cv2 8 | import copy 9 | 10 | 11 | class DataAugment: 12 | def __init__(self, debug=False): 13 | self.debug = debug 14 | print("Data augment...") 15 | 16 | def basic_matrix(self, translation): 17 | """基础变换矩阵""" 18 | return np.array([[1, 0, translation[0]], [0, 1, translation[1]], 19 | [0, 0, 1]]) 20 | 21 | def adjust_transform_for_image(self, img, trans_matrix): 22 | """根据图像调整当前变换矩阵""" 23 | transform_matrix = copy.deepcopy(trans_matrix) 24 | height, width, channels = img.shape 25 | transform_matrix[0:2, 2] *= [width, height] 26 | center = np.array((0.5 * width, 0.5 * height)) 27 | transform_matrix = np.linalg.multi_dot( 28 | [self.basic_matrix(center), transform_matrix, 29 | self.basic_matrix(-center)]) 30 | return transform_matrix 31 | 32 | def apply_transform(self, img, transform): 33 | """仿射变换""" 34 | # cv2.BORDER_REPLICATE,cv2.BORDER_TRANSPARENT 35 | output = cv2.warpAffine(img, transform[:2, :], 36 | dsize=(img.shape[1], img.shape[0]), 37 | flags=cv2.INTER_LINEAR, 38 | borderMode=cv2.BORDER_REFLECT, 39 | borderValue=0,) 40 | return output 41 | 42 | def apply(self, img, trans_matrix): 43 | """应用变换""" 44 | tmp_matrix = 
self.adjust_transform_for_image(img, trans_matrix) 45 | out_img = self.apply_transform(img, tmp_matrix) 46 | if self.debug: 47 | self.show(out_img) 48 | return out_img 49 | 50 | def random_vector(self, min, max): 51 | """生成范围矩阵""" 52 | min = np.array(min) 53 | max = np.array(max) 54 | print(min.shape, max.shape) 55 | assert min.shape == max.shape 56 | assert len(min.shape) == 1 57 | return np.random.uniform(min, max) 58 | 59 | def show(self, img): 60 | """可视化""" 61 | cv2.imshow("outimg", img) 62 | cv2.waitKey() 63 | 64 | def random_transform(self, img, min_translation, max_translation): 65 | """平移变换""" 66 | factor = self.random_vector(min_translation, max_translation) 67 | trans_matrix = np.array([[1, 0, factor[0]], [0, 1, factor[1]], 68 | [0, 0, 1]]) 69 | out_img = self.apply(img, trans_matrix) 70 | return trans_matrix, out_img 71 | 72 | def random_flip(self, img, factor): 73 | """水平或垂直翻转""" 74 | flip_matrix = np.array([[factor[0], 0, 0], [0, factor[1], 0], 75 | [0, 0, 1]]) 76 | out_img = self.apply(img, flip_matrix) 77 | return flip_matrix, out_img 78 | 79 | def random_rotate(self, img, factor): 80 | """随机旋转""" 81 | angle = np.random.uniform(factor[0], factor[1]) 82 | print("angle:{}".format(angle)) 83 | rotate_matrix = np.array([[np.cos(angle), -np.sin(angle), 0], 84 | [np.sin(angle), np.cos(angle), 0], 85 | [0, 0, 1]]) 86 | out_img = self.apply(img, rotate_matrix) 87 | return rotate_matrix, out_img 88 | 89 | def random_scale(self, img, min_translation, max_translation): 90 | """随机缩放""" 91 | factor = self.random_vector(min_translation, max_translation) 92 | scale_matrix = np.array([[factor[0], 0, 0], [0, factor[1], 0], 93 | [0, 0, 1]]) 94 | out_img = self.apply(img, scale_matrix) 95 | return scale_matrix, out_img 96 | 97 | def random_shear(self, img, factor): 98 | """随机剪切,包括横向和众向剪切""" 99 | angle = np.random.uniform(factor[0], factor[1]) 100 | print("fc:{}".format(angle)) 101 | crop_matrix = np.array([[1, factor[0], 0], [factor[1], 1, 0], 102 | [0, 0, 1]]) 103 | 
out_img = self.apply(img, crop_matrix) 104 | return crop_matrix, out_img 105 | 106 | 107 | if __name__ == "__main__": 108 | demo = DataAugment(debug=True) 109 | img = cv2.imread("/pathto/dataArgu/wr.jpg") 110 | 111 | # 平移测试 112 | # (-0.3,-0.3),(0.3,0.3) 113 | _, outimg = demo.random_transform(img, (0.1, 0.1), (0.2, 0.2)) 114 | 115 | # 垂直变换测试 116 | _, outimg = demo.random_flip(img, (1.0, -1.0)) 117 | 118 | # 水平变换测试 119 | _, outimg = demo.random_flip(img, (-1.0, 1.0)) 120 | 121 | # 旋转变换测试 122 | _, outimg = demo.random_rotate(img, (0.5, 0.8)) 123 | 124 | # # 缩放变换测试 125 | _, outimg = demo.random_scale(img, (1.2, 1.2), (1.3, 1.3)) 126 | 127 | # 随机裁剪测试 128 | _, outimg = demo.random_shear(img, (0.2, 0.3)) 129 | 130 | # 组合变换 131 | t1, _ = demo.random_transform(img, (-0.3, -0.3), (0.3, 0.3)) 132 | t2, _ = demo.random_rotate(img, (0.5, 0.8)) 133 | t3, _ = demo.random_scale(img, (1.5, 1.5), (1.7, 1.7)) 134 | tmp = np.linalg.multi_dot([t1, t2, t3]) 135 | print("tmp:{}".format(tmp)) 136 | out = demo.apply(img, tmp) 137 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | 5 | class Dataset(): 6 | def __init__(self, images, labels): 7 | # convert from [0, 255] -> [0.0, 1.0] 8 | images = images.astype(np.float32) 9 | images = np.multiply(images, 1.0 / 255.0) 10 | self._images = images 11 | self._labels = labels 12 | 13 | @property # getter 14 | def images(self): 15 | return self._images 16 | 17 | @property 18 | def labels(self): 19 | return self._labels 20 | 21 | 22 | def extract_images(image_dir, name): 23 | files = open(os.path.join(image_dir, name), 'rb') 24 | files.read(16) 25 | 26 | buf = files.read(28 * 28 * 60000) 27 | 28 | images = np.frombuffer(buf, dtype=np.uint8) 29 | # images = images.reshape(-1, 784) 30 | images = images.reshape(-1, 1, 28, 28) 31 | return images 32 | 33 | 34 | def 
extract_labels(image_dir, name): 35 | files = open(os.path.join(image_dir, name), 'rb') 36 | files.read(8) 37 | 38 | buf = files.read(28 * 28 * 10000) 39 | 40 | labels = np.frombuffer(buf, dtype=np.uint8) 41 | return labels 42 | 43 | 44 | def read_data_sets(image_dir): 45 | class DataSets(): 46 | pass 47 | 48 | data_sets = DataSets() 49 | 50 | TRAIN_IMAGES = 'train-images-idx3-ubyte' 51 | TRAIN_LABELS = 'train-labels-idx1-ubyte' 52 | TEST_IMAGES = 't10k-images-idx3-ubyte' 53 | TEST_LABELS = 't10k-labels-idx1-ubyte' 54 | VALIDATION_SIZE = 5000 55 | 56 | train_images = extract_images(image_dir, TRAIN_IMAGES) 57 | train_labels = extract_labels(image_dir, TRAIN_LABELS) 58 | 59 | train_images = train_images[VALIDATION_SIZE:] 60 | train_labels = train_labels[VALIDATION_SIZE:] 61 | 62 | validation_images = train_images[:VALIDATION_SIZE] 63 | validation_labels = train_labels[:VALIDATION_SIZE] 64 | 65 | test_images = extract_images(image_dir, TEST_IMAGES) 66 | test_labels = extract_labels(image_dir, TEST_LABELS) 67 | 68 | data_sets.train = Dataset(train_images, train_labels) 69 | data_sets.validation = Dataset(validation_images, validation_labels) 70 | data_sets.test = Dataset(test_images, test_labels) 71 | 72 | return data_sets -------------------------------------------------------------------------------- /file2img.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | @ Author: LinXu 5 | @ Contact: 17746071609@163.com 6 | @ Date: 2021/12/9 下午3:40 7 | @ Software: PyCharm 8 | @ File: file2img.py 9 | @ Desc: 从file或file_url中获取图片 10 | """ 11 | 12 | import io 13 | import os 14 | import torch 15 | import urllib 16 | import datetime 17 | from PIL import Image 18 | from pathlib import Path 19 | 20 | now = datetime.datetime.now() 21 | 22 | 23 | def valid4file(f_url): 24 | if f_url is not "": 25 | flag = "file_url" 26 | else: 27 | flag = "file" 28 | 29 | return flag 30 | 31 | 32 | def get4img(flag4file, u_f, 
temp_path): 33 | 34 | if flag4file == "file": 35 | image_bytes = u_f.read() 36 | img = Image.open(io.BytesIO(image_bytes)) 37 | elif flag4file == "file_url": 38 | if not os.path.exists(temp_path): 39 | os.makedirs(temp_path) 40 | 41 | url = str(Path(u_f)).replace(':/', '://') # Pathlib turns :// -> :/ 42 | u = Path(urllib.parse.unquote(u_f).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth 43 | t = now.isoformat() 44 | s = os.path.join(temp_path, t + u.split('.')[-1]) 45 | if Path(s).is_file(): 46 | print(f'Found {url} locally at {s}') # file already exists 47 | else: 48 | print(f'Downloading {url} to {s}...') 49 | torch.hub.download_url_to_file(url, s) 50 | assert Path(s).exists() and Path(s).stat().st_size > 0, f'File download failed: {url}' # check 51 | img = Image.open(s) 52 | os.system("rm " + s) 53 | 54 | return img 55 | 56 | 57 | if __name__ == '__main__': 58 | file_url = "http:/oss.straituav.com/DJI_0458.jpg?versionId=" \ 59 | "CAEQGhiBgMCWl7Xt6hciIDNiYTYwMmY3MzFkZDQ3MGRhYzMyM2ZlMjRhZDUzMDE4" 60 | image = get4img(flag4file="file_url", u_f=file_url, temp_path="./") 61 | image.show() 62 | -------------------------------------------------------------------------------- /generate_test_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from glob import glob 4 | from tqdm import tqdm 5 | from PIL import Image 6 | 7 | 8 | #生成测试图片对应的COCO格式的Json文件,通常测试不需要此文件,可用于制作伪标签 9 | 10 | cls_classes = ['dog','cat','human'] #检测目标类别(不含background) 11 | label_ids = {name: i + 1 for i, name in enumerate(cls_classes)} 12 | def save(images, annotations): 13 | ann = {} 14 | ann['type'] = 'instances' 15 | ann['images'] = images 16 | ann['annotations'] = annotations 17 | 18 | categories = [] 19 | for k, v in label_ids.items(): 20 | categories.append({"name": k, "id": v}) 21 | ann['categories'] = categories 22 | json.dump(ann, 
open('/home1/huangqiangHD/dataset/UnderWater/train/annotations/testA.json', 'w')) 23 | 24 | 25 | def test_dataset(im_dir): 26 | im_list = glob(os.path.join(im_dir, '*.jpg')) 27 | idx = 1 28 | image_id = 1 29 | images = [] 30 | annotations = [] 31 | for im_path in tqdm(im_list): 32 | image_id += 1 33 | im = Image.open(im_path) 34 | w, h = im.size 35 | image = {'file_name': os.path.basename(im_path), 'width': w, 'height': h, 'id': image_id} 36 | images.append(image) 37 | labels = [[10, 10, 20, 20]] 38 | for label in labels: 39 | bbox = [label[0], label[1], label[2] - label[0], label[3] - label[1]] 40 | seg = [] 41 | ann = {'segmentation': [seg], 'area': bbox[2] * bbox[3], 'iscrowd': 0, 'image_id': image_id, 42 | 'bbox': bbox, 'category_id': 1, 'id': idx, 'ignore': 0} 43 | idx += 1 44 | annotations.append(ann) 45 | save(images, annotations) 46 | 47 | 48 | if __name__ == '__main__': 49 | test_dir = '/home1/huangqiangHD/dataset/UnderWater/test/testA/' 50 | print("generate test json label file.") 51 | test_dataset(test_dir) -------------------------------------------------------------------------------- /line.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # 直线检测与识别 4 | # Author: alex 5 | # Created Time: 2020年03月20日 星期五 16时41分48秒 6 | import cv2 7 | import numpy as np 8 | from math import atan 9 | from sklearn.cluster import DBSCAN 10 | 11 | 12 | def intersection_line(line_a, line_b): 13 | """计算两个线段的重叠部分""" 14 | a1, b1 = line_a 15 | a2, b2 = line_b 16 | if b1 <= a2 or b2 <= a1: 17 | return 0 18 | return min(b1, b2) - max(a1, a2) 19 | 20 | 21 | def in_line_rate(line, container_line): 22 | """一个线段和另一个线段的重合部分,占该线段总长的占比""" 23 | inter = intersection_line(line, container_line) 24 | return inter / (line[1] - line[0]) 25 | 26 | 27 | def iou_line(line_a, line_b): 28 | """两个线段的重叠占比""" 29 | inter = intersection_line(line_a, line_b) 30 | if inter == 0: 31 | return 0 32 | a1, b1 = line_a 33 | a2, b2 = line_b 34 
| return inter / (max(b1, b2) - min(a1, a2)) 35 | 36 | 37 | def intersection_points(line_img1, line_img2): 38 | """计算两个直线的交点图像""" 39 | points = cv2.bitwise_and(line_img1, line_img2) 40 | return points 41 | 42 | 43 | def detect_lines_angle(gray, scale, line_type='row'): 44 | """检测直线的倾斜角度 45 | :param gray 灰度图 46 | :param scale 检测参数 47 | :param line_type 直线类型,值为row(横线)或者col(竖线) 48 | :return 倾斜弧度,如果需要转换为角度: math.degrees 49 | """ 50 | binary = cv2.adaptiveThreshold(~gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 51 | cv2.THRESH_BINARY, 15, -10) 52 | if line_type == 'row': 53 | line_img = detect_row_line(binary, scale) 54 | elif line_type == 'col': 55 | line_img = detect_col_line(binary, scale) 56 | else: 57 | raise Exception('error line_type value') 58 | 59 | n, lines, _ = cluster_fit_lines(line_img) 60 | if n < 3: 61 | return None 62 | angles = [line[0] for line in lines] 63 | return atan(sum(angles)/len(angles)) 64 | 65 | 66 | def detect_line(gray, scale): 67 | """检测横线和竖线""" 68 | binary = cv2.adaptiveThreshold(~gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 69 | cv2.THRESH_BINARY, 15, -10) 70 | row_img = detect_row_line(binary, scale) 71 | col_img = detect_col_line(binary, scale) 72 | return row_img, col_img 73 | 74 | 75 | def detect_col_line(binary, scale): 76 | """检测竖线""" 77 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, scale)) 78 | eroded = cv2.erode(binary, kernel, iterations=1) 79 | return cv2.dilate(eroded, kernel, iterations=1) 80 | 81 | 82 | def detect_row_line(binary, scale): 83 | """检测横线""" 84 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (scale, 1)) 85 | eroded = cv2.erode(binary, kernel, iterations=1) 86 | return cv2.dilate(eroded, kernel, iterations=1) 87 | 88 | 89 | def cluster_lines(data, eps=3, min_samples=2, metric='manhattan'): 90 | """线条聚类""" 91 | db = DBSCAN(eps=eps, min_samples=min_samples, metric=metric).fit(data) 92 | labels = db.labels_ 93 | n_clusters_ = max(labels) + 1 94 | return n_clusters_, labels 95 | 96 | 97 | def 
fit_line(points, exchange_xy=False): 98 | """拟合直线 99 | 注意:如果是针对竖线,应该对xy轴进行交换,避免出现x=b这样的直线 100 | :params points [[y, x]] 图像坐标:[y, x] 101 | :param exchange_xy bool 拟合直线时,决定是否需要交换x和y轴 102 | :return [a, b] 直线参数:y=ax+b 103 | """ 104 | X = points[:, 1] 105 | Y = points[:, 0] 106 | if exchange_xy: 107 | X, Y = Y, X 108 | 109 | line = np.polyfit(X, Y, 1) 110 | return line 111 | 112 | 113 | def cluster_fit_lines(line_img, exchange_xy=False, sorted_b=False): 114 | """直线聚合并拟合直线 115 | 注意:如果是竖线,类似y=b这种,应该讲x轴和y轴进行交换 116 | :param line_img 直线的黑白图像 117 | :param exchange_xy bool 拟合直线时,决定是否需要交换x和y轴 118 | :param sorted_b bool 是否按截距进行排序 119 | :return n int 直线数量 120 | :return lines [[a, b]] 直线方程的参数 121 | :return endpoints [[y, x]] 线段的端点 122 | """ 123 | points_idx = np.argwhere(line_img == 255) 124 | if len(points_idx) < 3: 125 | return 0, [], [] 126 | 127 | n, labels = cluster_lines(points_idx) 128 | lines = [] 129 | endpoints = [] 130 | for i in range(n): 131 | line_points_idx = points_idx[labels == i] 132 | line = fit_line(line_points_idx, exchange_xy=exchange_xy) 133 | lines.append(line) 134 | endpoint = get_endpoint(line, line_points_idx, exchange_xy=exchange_xy) 135 | endpoints.append(endpoint) 136 | 137 | if sorted_b and len(lines) > 1: 138 | data = [(l, p) for l, p in zip(lines, endpoints)] 139 | data = sorted(data, key=lambda x: x[0][-1]) 140 | lines = [l for l, _ in data] 141 | endpoints = [p for _, p in data] 142 | 143 | return n, lines, endpoints 144 | 145 | 146 | def get_endpoint(line, points_idx, exchange_xy=False): 147 | """获取线段的端点 148 | :param line list 直线方程的参数,值为[a, b] 149 | :param points_idx list 图像直线上所有点的坐标,坐标格式是[y, x] 150 | :param exchange_xy bool 是否对x轴和y轴进行交换,默认为False,即y=ax+b,若为True, 则是x=ay+b 151 | :return endpoint1, endpoint2: 每个顶点格式[y, x] 152 | """ 153 | a, b = line 154 | # 数据中点的坐标格式是[y, x] 155 | if exchange_xy: 156 | # 这个是竖线,需要找y的最大最小值 157 | X = [y for y, _ in points_idx] 158 | v_min, v_max = min(X), max(X) 159 | return (v_min, v_min*a+b), (v_max, v_max*a+b) 160 | 
161 | # 对于横线,需要找到x的最大最小值 162 | X = [x for _, x in points_idx] 163 | v_min, v_max = min(X), max(X) 164 | return (v_min*a+b, v_min), (v_max*a+b, v_max) 165 | 166 | 167 | def check_segment_collinear(seg1, seg2, exchange_xy=False, b_err=4., 168 | a_err=0.1): 169 | """判断两个线段是否共线 170 | :param seg1 list 线段1, 格式:[a, b, (y1, x1), (y2, x2)]。线段所在直线y=ax+b, (y1, x1)与(y2, x2)是线段两个端点 171 | :param seg2 list 线段2, 格式:[a, b, (y1, x1), (y2, x2)] 172 | :param exchange_xy bool 默认为False即可,如果是竖线,则设置该值为True,这时对应的直线方程是x=ay+b 173 | :param b_err float 直线方程中b参数允许的误差 174 | :param a_err float 直线方程中a参数允许的误差 175 | :return bool 两个线段是否共线 176 | """ 177 | a1, b1 = seg1[:2] 178 | a2, b2 = seg2[:2] 179 | if abs(a1 - a2) > a_err or abs(b1 - b2) > b_err: 180 | return False 181 | 182 | # 避免两个线段有交叉 183 | if exchange_xy: 184 | # 竖线 185 | (v11, _), (v12, _) = seg1[2:] 186 | (v21, _), (v22, _) = seg2[2:] 187 | else: 188 | (_, v11), (_, v12) = seg1[2:] 189 | (_, v21), (_, v22) = seg2[2:] 190 | 191 | min1, max1 = min(v11, v12), max(v11, v12) 192 | min2, max2 = min(v21, v22), max(v21, v22) 193 | return max1 < min2 or min1 > max2 194 | 195 | 196 | if __name__ == '__main__': 197 | assert intersection_line((1, 10), (20, 30)) == 0 198 | assert intersection_line((1, 10), (5, 30)) == 5 199 | assert in_line_rate((1, 10), (5, 30)) == 5/9 200 | assert in_line_rate((5, 30), (1, 10)) == 5/25 201 | assert iou_line((5, 30), (1, 10)) == 5/29 202 | -------------------------------------------------------------------------------- /line_cluster.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # 线段聚类 4 | # Author: alex 5 | # Created Time: 2020年03月26日 星期四 15时18分41秒 6 | import numpy as np 7 | from sklearn.cluster import DBSCAN 8 | # from ibbd_algo.optics import Optics 9 | 10 | 11 | def line_cluster(lines, line_types, enpoints, eps=2, min_samples=2): 12 | """将有交点的线段聚合在一起 13 | 线段方程类型为: 14 | True: y=ax+b 15 | False: x=ay+b 16 | :param lines list 线段方程参数[(a, b)] 17 
| :param line_types list 线段方程的类型,跟lines参数对应,取值True or False 18 | :param enpoints list 线段的端点,注意每个线段有两个端点: [[(y1, x1), (y2, x2)]] 19 | :param eps, min_samples: DBSCAN聚类参数 20 | :return labels 21 | """ 22 | data = [(l_type, a, b, x1, y1, x2, y2) 23 | for ((a, b), l_type, ((x1, y1), (x2, y2))) in 24 | zip(lines, line_types, enpoints)] 25 | db = DBSCAN(eps=eps, min_samples=min_samples, metric=distance).fit(data) 26 | return db.labels_ 27 | # optics = Optics(max_radius, min_samples, distance=distance) 28 | # optics.fit(data) 29 | # return optics.cluster(cluster_thr) 30 | 31 | 32 | def distance(line1, line2): 33 | """计算两个线段的距离 34 | :param line1,line2: [a, b, x1, y1, x2, y2] 35 | """ 36 | l_type1, a1, b1, x11, y11, x12, y12 = line1 37 | l_type2, a2, b2, x21, y21, x22, y22 = line2 38 | 39 | def format_fraction(val): 40 | c_val = 0.000000001 41 | if -c_val < val < 0: 42 | val = -c_val 43 | elif c_val > val >= 0: 44 | val = c_val 45 | 46 | return val 47 | 48 | # 计算直线交点 49 | if l_type1 and l_type2: 50 | x0 = (b2-b1) / format_fraction(a1-a2) 51 | y0 = a1*x0 + b1 52 | elif not l_type1 and not l_type2: 53 | y0 = (b2-b1) / format_fraction(a1-a2) 54 | x0 = a1*y0 + b1 55 | elif l_type1 and not l_type2: 56 | # y=a1*x+b1 and x=a2*y+b2 57 | y0 = (a1*b2+b1) / format_fraction(1-a1*a2) 58 | x0 = a2*y0 + b2 59 | elif not l_type1 and l_type2: 60 | # x=a1*y+b1 and y=a2*x+b2 61 | x0 = (a1*b2+b1) / format_fraction(1-a1*a2) 62 | y0 = a2*x0 + b2 63 | 64 | def point_line_dist(x1, y1, x2, y2): 65 | """计算点到线的距离""" 66 | if x1 <= x0 <= x2 and y1 <= y0 <= y2: 67 | dist = 0 68 | else: 69 | # 到两端点的最小距离 70 | dist = min(np.linalg.norm([x0-x1, y0-y1]), 71 | np.linalg.norm([x0-x2, y0-y2])) 72 | 73 | return dist 74 | 75 | # 计算到线1的距离 76 | dist1 = point_line_dist(x11, y11, x12, y12) 77 | # 计算到线2的距离 78 | dist2 = point_line_dist(x21, y21, x22, y22) 79 | return dist1 + dist2 80 | 81 | 82 | if __name__ == '__main__': 83 | pass 84 | -------------------------------------------------------------------------------- 
# ---- /pdf.py ----
from PIL import Image
import fitz  # fitz: pip install PyMuPDF


def pdf2images(doc, zoom=2, color='RGB'):
    """Render every page of a PyMuPDF document to a PIL image.

    Examples::

        doc = fitz.open('/path/to/pdf')
        images = pdf2images(doc)

        with open('/path/to/pdf', 'rb') as f:
            doc = fitz.open(stream=f.read(), filetype='pdf')
        images = pdf2images(doc)

    :param doc: an opened fitz document; it is iterated page by page
    :param zoom: scale factor applied to both axes via fitz.Matrix(zoom, zoom)
    :param color: PIL mode passed to Image.frombytes; must match the channel
        layout of the rendered pixmap (rendered with alpha=False)
    :return: list of PIL images, one per page, in iteration order
    """
    mat = fitz.Matrix(zoom, zoom)
    images = []
    for page in doc:
        # PyMuPDF renamed getPixmap -> get_pixmap; the camelCase alias was
        # deprecated and later removed, so prefer the new name when present.
        render = getattr(page, 'get_pixmap', None) or page.getPixmap
        pix = render(matrix=mat, alpha=False)
        images.append(Image.frombytes(color, [pix.width, pix.height], pix.samples))

    return images


if __name__ == "__main__":
    import sys
    doc = fitz.open(sys.argv[1])
    images = pdf2images(doc)
    print(len(images), images[0].size)


# ---- /point.py ----
# Point detection and recognition
# Author: alex
# Created Time: 2020-03-21 11:50:41
import numpy as np
from sklearn.cluster import DBSCAN


def cluster_points(point_img, eps=3, min_samples=2, metric='manhattan'):
    """Cluster the pixels of an intersection image with DBSCAN.

    :param point_img: intersection image; pixels equal to 255 are clustered
    :param eps, min_samples, metric: DBSCAN parameters
    :return n_clusters: int, number of clusters (max label + 1)
    :return points: list, mean [y, x] coordinate of each cluster
    :return labels: cluster label of every 255-valued pixel
    When fewer than 3 such pixels exist, (0, [], []) is returned.
    """
    idx = np.argwhere(point_img == 255)
    if len(idx) < 3:
        return 0, [], []

    db = DBSCAN(eps=eps, min_samples=min_samples, metric=metric).fit(idx)
    labels = db.labels_
    n_clusters = max(labels) + 1

    # Centre of each cluster = average coordinate of its member pixels.
    points = [np.average(idx[labels == i], axis=0)
              for i in range(n_clusters)]
    return n_clusters, points, labels


def point_on_line(point, a, b, e=0.01):
    """Check whether a point lies on the line y = a*x + b, within tolerance.

    :param point: [y, x]
    :param a, b: float line parameters
    :param e: float allowed absolute error in y
    :return: bool, True when |a*x + b - y| < e
    """
    y, x = point
    return abs(a*x+b-y) < e


# ---- /rm_watermark.py ----
'''
Author: alex
Created Time: 2020-08-20 16:09:37
'''
import cv2
import numpy as np


def remove_watermark(image, thr=200, convol=3):
    """Crudely remove a light watermark from a scanned/PDF page image.

    A pixel is kept only if at least one pixel in the convol x convol window
    around it is darker than ``thr``; every other pixel is painted white.
    A box filter computes the window counts in one pass.

    :param image: grayscale image as a numpy array (cv2 format)
    :param thr: pixel values below this threshold count as dark content
    :param convol: side length of the square window
    :return: np.ndarray with the same shape as ``image``
    """
    distance = (convol - 1) // 2
    # Pad with white so border pixels get a full window without picking up
    # spurious dark neighbours.
    padded = cv2.copyMakeBorder(image, distance, distance, distance, distance,
                                cv2.BORDER_CONSTANT, value=255)
    # The threshold was previously hard-coded to 200, silently ignoring thr.
    # float32 is a depth boxFilter supports; astype(int) gives int64 on most
    # platforms, which OpenCV 4.x has no depth for.
    dark = (padded < thr).astype(np.float32)
    # Un-normalised box filter == number of dark pixels in each window.
    dark_count = cv2.boxFilter(dark, -1, (convol, convol), normalize=False)
    # No dark pixel in the window -> treat as watermark/background.
    padded[dark_count < 1] = 255
    h, w = padded.shape[:2]
    return padded[distance:h - distance, distance:w - distance]


def bak_remove_watermark(image, thr=200, distance=1):
    """Pixel-by-pixel (slow) variant of remove_watermark.

    :param image: PIL Image
    :param thr: an RGB pixel counts as dark when all channels are below thr
    :param distance: half-size of the square neighbourhood that is inspected
    :return: PIL Image in RGB mode with isolated light pixels set to white
    """
    w, h = image.size
    rgb_im = image.convert('RGB')
    # Visit every pixel, including the last row and column.
    for x in range(w):
        for y in range(h):
            if not hasBlackAround(x, y, distance, rgb_im, thr=thr):
                rgb_im.putpixel((x, y), (255, 255, 255))

    return rgb_im


def hasBlackAround(x, y, distance, img, thr=200):
    """Return True if any pixel within ``distance`` of (x, y) is dark.

    The window [x-distance, x+distance] x [y-distance, y+distance] is
    inclusive and clipped to the image bounds; a pixel is dark when all of
    its R, G, B channels are below ``thr``.
    """
    w, h = img.size
    startX = max(0, x-distance)
    startY = max(0, y-distance)
    endX = min(w-1, x+distance)
    endY = min(h-1, y+distance)
    # +1: the end coordinates are inclusive; the old exclusive range made the
    # window asymmetric and empty for distance=0.
    for j in range(startX, endX + 1):
        for k in range(startY, endY + 1):
            r, g, b = img.getpixel((j, k))
            if r < thr and g < thr and b < thr:
                return True

    return False


if __name__ == '__main__':
    from PIL import Image
    image_path = "gf-png/gf1.png"
    img = Image.open(image_path)
    # The PIL-based variant takes (image, thr, distance); remove_watermark
    # expects a cv2 array and has no ``distance`` parameter.
    res_img = bak_remove_watermark(img, thr=100, distance=1)


# ---- /similary.py ----
# Image similarity
# Author: alex
# Created Time: 2020-01-06 16:19:23
import cv2


def orb_similary(img1, img2, distance_thr=0.75, **kwargs):
    """Compute an ORB-feature similarity score between two images.

    :param img1: cv2.imread(img1_path, cv2.IMREAD_GRAYSCALE)
    :param img2: cv2.imread(img2_path, cv2.IMREAD_GRAYSCALE)
    :param distance_thr: ratio-test threshold for accepting a match
    :param kwargs: forwarded to cv2.ORB_create
        ORB_create([, nfeatures[, scaleFactor[, nlevels[, edgeThreshold[,
        firstLevel[, WTA_K[, scoreType[, patchSize[, fastThreshold]]]]]]]]])
    :return: float, accepted matches / descriptors of img1 that were matched;
        0.0 when either image yields no descriptors or nothing is matched
    """
    orb = cv2.ORB_create(**kwargs)
    # Detect keypoints and compute their descriptors.
    _, des1 = orb.detectAndCompute(img1, None)
    _, des2 = orb.detectAndCompute(img2, None)
    # Descriptors come back as None when no keypoints are found; there is
    # nothing to compare then.
    if des1 is None or des2 is None:
        return 0.0

    bf = cv2.BFMatcher(cv2.NORM_HAMMING)
    # Two nearest neighbours per query descriptor, for the ratio test.
    matches = bf.knnMatch(des1, trainDescriptors=des2, k=2)
    if not matches:
        return 0.0

    # Ratio test. An entry can hold fewer than two neighbours (e.g. when img2
    # has a single descriptor); it cannot be tested and is not counted as good.
    good = [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < distance_thr * pair[1].distance]
    return float(len(good)) / len(matches)
# ---- /spilt_train_val.py ----
import os
import os.path as pt
import json
import funcy
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from sklearn.model_selection import train_test_split


# ********************************************* #
# Split a COCO-format dataset into train and val #
# ********************************************* #

def split_train_val(annotations, trainpath, valpath, splitrate=0.99,
                    random_state=None):
    """Split a COCO annotation file into train/val files, by image.

    :param annotations: path to the full COCO-format JSON file
    :param trainpath: output path for the train split
    :param valpath: output path for the val split
    :param splitrate: fraction of images put in the train split (default 99:1)
    :param random_state: forwarded to train_test_split; pass an int for a
        reproducible split (None keeps the previous, unseeded behaviour)
    """
    # Read and close the input before writing outputs; do not reuse the
    # parameter name for the file handle or the annotation list.
    with open(annotations, 'rt', encoding='UTF-8') as f:
        coco = json.load(f)
    images = coco['images']
    all_annotations = coco['annotations']
    categories = coco['categories']

    train, val = train_test_split(images, train_size=splitrate,
                                  random_state=random_state)

    save_coco(trainpath, train, filter_annotations(all_annotations, train),
              categories)
    save_coco(valpath, val, filter_annotations(all_annotations, val),
              categories)

    print("Saved {} entries in {} and {} in {}".format(
        len(train), trainpath, len(val), valpath))


def save_coco(file, images, annotations, categories):
    """Write images/annotations/categories to ``file`` as indented JSON."""
    with open(file, 'wt', encoding='UTF-8') as coco:
        json.dump(
            {
                'images': images,
                'annotations': annotations,
                'categories': categories
            },
            coco,
            indent=4,
            sort_keys=True)


def filter_annotations(annotations, images):
    """Return the annotations whose image_id matches the id of one of ``images``."""
    # A set gives O(1) membership; the former list made this
    # O(len(annotations) * len(images)).
    image_ids = {int(i['id']) for i in images}
    return funcy.lfilter(lambda a: int(a['image_id']) in image_ids,
                         annotations)


def main():
    COCO_FORMAT_JSON_PATH = '/home1/huangqiangHD/dataset/X_ray/train'

    annotations = pt.join(COCO_FORMAT_JSON_PATH, 'train.json')  # full dataset
    trainpath = pt.join(COCO_FORMAT_JSON_PATH, 'train_.json')  # train set
    valpath = pt.join(COCO_FORMAT_JSON_PATH, 'val.json')  # val set
    split_train_val(annotations, trainpath, valpath)


if __name__ == '__main__':
    main()
-------------------------------------------------------------------------------- /table.py: -------------------------------------------------------------------------------- 1 | """ 2 | 表格聚类: 将有关联的线段聚合在一起 3 | Author: alex 4 | Created Time: 2020年05月25日 5 | """ 6 | import numpy as np 7 | from sklearn.cluster import DBSCAN 8 | 9 | # 线段a,b参数的最大值 10 | A_MAX = 1e8 11 | B_MAX = 1e8 12 | # 斜率最小值 13 | A_MIN = 1e-8 14 | # 距离最大值 15 | D_MAX = 1e8 16 | 17 | 18 | def table_lines_cluster(lines, eps=3, min_samples=2): 19 | """表格线段聚类 20 | 将有关联的线段聚合在一起 21 | 线段:y = a*x+b,(x1, y1)和(x2, y2)是其两个端点,两种表示形式 22 | 1. 用线段的两个端点来表示一个线段:(x1, y1, x2, y2) 23 | 2. 如果已经计算出参数a和b,则线段:(x1, y1, x2, y2, a, b) 24 | :params lines list 线段列表 25 | :params eps, min_samples: DBSCAN聚类所使用的参数 26 | :return labels np.array 聚类结果 27 | """ 28 | new_lines = [] 29 | if len(lines[0]) == 6: 30 | for x1, y1, x2, y2, a, b in lines: 31 | a = max(min(a, A_MAX), -A_MAX) 32 | b = max(min(b, B_MAX), -B_MAX) 33 | if abs(a) < A_MIN: 34 | a = A_MIN 35 | 36 | new_lines.append([x1, y1, x2, y2, a, b]) 37 | elif len(lines[0]) == 4: 38 | for x1, y1, x2, y2 in lines: 39 | a, b = cal_line_params(x1, y1, x2, y2) 40 | new_lines.append([x1, y1, x2, y2, a, b]) 41 | else: 42 | raise Exception('lines: param error!') 43 | 44 | cluster = DBSCAN(eps=eps, min_samples=min_samples, 45 | metric=distance).fit(new_lines) 46 | return cluster.labels_ 47 | 48 | 49 | def cal_line_params(x1, y1, x2, y2): 50 | """计算直线的参数""" 51 | a = (y2-y1)/(x2-x1) 52 | b = y1 - a*x1 53 | a = max(min(a, A_MAX), -A_MAX) 54 | b = max(min(b, B_MAX), -B_MAX) 55 | if abs(a) < A_MIN: 56 | a = A_MIN 57 | 58 | return a, b 59 | 60 | 61 | def distance(line1, line2): 62 | """计算两个线段的距离 63 | a, b: 线段直线参数: y=ax+b 64 | x1, y1, x2, y2: 线段的两个端点 65 | :param line1, line2: [x1, y1, x2, y2, a, b] 66 | :return float 两个线段的距离 67 | """ 68 | a1, b1 = line1[4:] 69 | a2, b2 = line2[4:] 70 | 71 | # 计算交点 72 | if abs(a1-a2) < 0.01: 73 | return D_MAX 74 | x0 = (b2-b1)/(a1-a2) 75 | y0 = a1 * x0 + b1 76 | 77 | 
def point_line_dist(x1, y1, x2, y2): 78 | """计算点到线的距离""" 79 | v1 = [x1-x0, y1-y0] 80 | v2 = [x2-x0, y2-y0] 81 | return min(np.linalg.norm(v1), np.linalg.norm(v2)) 82 | 83 | # 计算到线1的距离 84 | dist1 = point_line_dist(*line1[:4]) 85 | # 计算到线2的距离 86 | dist2 = point_line_dist(*line2[:4]) 87 | # print(dist1+dist2) 88 | return dist1 + dist2 89 | 90 | 91 | if __name__ == '__main__': 92 | def create_line(p1, p2): 93 | x1, y1 = p1 94 | x2, y2 = p2 95 | a = (y2-y1)/(x2-x1) 96 | b = y1 - a*x1 97 | return (x1, y1, x2, y2, a, b) 98 | 99 | points1 = [(1, 9), (2, 1), (7, 1.5), (10, 10)] 100 | points2 = [(4, 2), (5, 2), (6, 4), (3, 5)] 101 | data = [create_line(points1[i], points1[i+1]) 102 | for i in range(len(points1)-1)] 103 | for i in range(len(points2)-1): 104 | data.append(create_line(points2[i], points2[i+1])) 105 | 106 | data.append(create_line(points2[0], points2[len(points2)-1])) 107 | labels = table_lines_cluster(data, eps=2.5) 108 | print(labels) 109 | 110 | data = [points1[i]+points1[i+1] for i in range(len(points1)-1)] 111 | for i in range(len(points2)-1): 112 | data.append(points2[i]+points2[i+1]) 113 | 114 | data.append(points2[0]+points2[len(points2)-1]) 115 | labels = table_lines_cluster(data, eps=2.5) 116 | print(labels) 117 | -------------------------------------------------------------------------------- /text.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # 文字相关 4 | # Author: alex 5 | # Created Time: 2020年05月23日 星期六 17时05分27秒 6 | import os 7 | from PIL import ImageFont, ImageDraw 8 | 9 | font = None 10 | package_path = os.path.dirname(os.path.realpath(__file__)) 11 | default_font_path = os.path.join(package_path, 'SimHei.ttf') 12 | 13 | 14 | def get_default_font_path(): 15 | """获取默认的字体文件路径 16 | :return default_font_path str 17 | """ 18 | return default_font_path 19 | 20 | 21 | def set_font(font_size=12, font_path=None, encoding='utf-8'): 22 | """设置字体 23 | :param font_path str|None 
字体路径,默认则使用的字体是:SimHei 24 | :param font_size int 字体大小,默认为12 25 | :param encoding str 字体编码,默认为utf-8 26 | :return font ImageFont.truetype 27 | """ 28 | global font 29 | if font_path is None: 30 | font_path = default_font_path 31 | 32 | font = ImageFont.truetype(font_path, font_size, encoding=encoding) 33 | return font 34 | 35 | 36 | def add_chinese(draw, pos, text, fill=(255, 0, 0)): 37 | """往图像上添加中文 38 | 注意:在执行该函数之前,需要先初始化字体,对应函数:set_font 39 | :param draw ImageDraw.Draw(img) 40 | :param pos list|tuple 显示文字的位置,格式: (x, y) 41 | :param text str 需要显示的文字 42 | """ 43 | global font 44 | if font is None: 45 | font = set_font() 46 | 47 | draw.text(pos, text, font=font, fill=fill) 48 | return draw 49 | 50 | 51 | def add_chinese_img(img, pos, text, fill=(255, 0, 0)): 52 | """往图像上添加中文 53 | 注意:在执行该函数之前,需要先初始化字体,对应函数:set_font 54 | :param img PIL图像格式 55 | :param pos list|tuple 显示文字的位置,格式: (x, y) 56 | :param text str 需要显示的文字 57 | """ 58 | draw = ImageDraw.Draw(img) 59 | return add_chinese(draw, pos, text, fill=fill) 60 | -------------------------------------------------------------------------------- /text_angle.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # 文字角度相关函数 4 | # Author: alex 5 | # Created Time: 2020年01月03日 星期五 18时36分08秒 6 | import cv2 7 | import numpy as np 8 | from scipy.ndimage import filters, interpolation 9 | from image_utils.utils import conc_map 10 | 11 | 12 | def estimate_skew_angle(gray, fine_tune_num=4, step_start=0.75, 13 | max_workers=None, scale=600., max_scale=900.): 14 | """ 15 | 估计图像文字角度 16 | :param gray 待纠正的灰度图像 17 | :param fine_tune_num 微调的次数, 界定了微调的精度 18 | 当该值为n时,表示微调角度精确到step_start乘以10的-(n-1)次方 19 | :param step_start 步长的初始值 20 | 当该值为a时,其纠正的角度范围是[-10*a, 10*a]。该值不应该大于4.5 21 | :param max_workers int|None 并发的进程数量限制 22 | :param scale, max_scale float 计算时缩放的最小最大宽高 23 | :return angle 需要纠正的角度 24 | """ 25 | def resize_im(im, scale, max_scale): 26 | f = scale / min(im.shape[:2]) 27 | 
max_rate = max_scale / max(im.shape[:2]) 28 | f = min(f, max_rate) 29 | return cv2.resize(im, (0, 0), fx=f, fy=f) 30 | 31 | gray = resize_im(gray, scale, max_scale) 32 | g_min, g_max = np.amin(gray), np.amax(gray) 33 | if g_max - g_min < 30: 34 | return 0. 35 | # 归一化 36 | image = (gray-g_min) / (g_max-g_min) 37 | m = interpolation.zoom(image, 0.5) 38 | m = filters.percentile_filter(m, 80, size=(20, 2)) 39 | m = filters.percentile_filter(m, 80, size=(2, 20)) 40 | m = interpolation.zoom(m, 1.0/0.5) 41 | 42 | w, h = min(image.shape[1], m.shape[1]), min(image.shape[0], m.shape[0]) 43 | flat = np.clip(image[:h, :w]-m[:h, :w]+1, 0, 1) 44 | d0, d1 = flat.shape 45 | o0, o1 = int(0.1*d0), int(0.1*d1) 46 | flat = np.amax(flat)-flat 47 | flat -= np.amin(flat) 48 | est = flat[o0:d0-o0, o1:d1-o1] 49 | 50 | angle, step = 0, step_start # 纠正角度的初始值和步长 51 | for _ in range(fine_tune_num): 52 | angle = fine_tune_angle(est, step, start=angle, 53 | max_workers=max_workers) 54 | step /= 10 55 | 56 | return angle 57 | 58 | 59 | def fine_tune_angle(image, step, start=0, max_workers=None): 60 | """微调纠正 61 | 在某个角度start的周围进行微调 62 | """ 63 | def var(i): 64 | # 从-10到10 65 | angle = start + (i-5)*step 66 | roest = interpolation.rotate(image, angle, order=0, mode='constant') 67 | v = np.mean(roest, axis=1) 68 | v = np.var(v) 69 | return (v, angle) 70 | 71 | estimates = conc_map(var, range(11), max_workers=max_workers) 72 | _, angle = max(estimates) 73 | return angle 74 | 75 | 76 | if __name__ == '__main__': 77 | import sys 78 | from convert import rotate 79 | img = cv2.imread(sys.argv[1], cv2.COLOR_BGR2GRAY) 80 | angle = estimate_skew_angle(img) 81 | print(angle) 82 | new_img = rotate(img, angle) 83 | cv2.imwrite(sys.argv[2], new_img) 84 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # 其他图像工具函数 4 | # Author: alex 5 | # Created Time: 
2020年03月18日 星期三 17时38分00秒 6 | # import cv2 7 | import numpy as np 8 | from copy import deepcopy 9 | from concurrent import futures 10 | 11 | 12 | def conc_map(func, map_data, max_workers=None): 13 | """并发执行""" 14 | with futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 15 | results = executor.map(func, map_data) 16 | return list(results) 17 | 18 | 19 | def find_empty_size_all(gray, size, std_thr=15, mean_thr=200): 20 | """查找图像中所有特定size的空白区域 21 | :param gray cv2格式的灰度图 22 | :param size 指定size,值如:(100, 100) 23 | :param std_thr int 区域内标准差 24 | :param mean_thr int 区域内均值 25 | :return [box1, box2, ...] 返回所有满足条件的box,每个box的格式(x1, y1, x2, y2) 26 | """ 27 | boxes = [] # 返回值 28 | cw, ch = size 29 | h, w = gray.shape[:2] 30 | gray = deepcopy(gray) 31 | for row in range(0, h, ch): 32 | for col in range(0, w, cw): 33 | roi = gray[row:row+ch, col:col+cw] # 获取分块 34 | dev = np.std(roi) 35 | avg = np.mean(roi) 36 | if dev < std_thr and avg > mean_thr: 37 | # 满足条件,接近空白区域,让他变黑 38 | boxes.append((col, col+cw, row, row+ch)) 39 | gray[row:row+ch, col:col+cw] = 0 # 全部都赋值为0 40 | 41 | return boxes 42 | 43 | 44 | def find_empty_size(gray, size, std_thr=15, mean_thr=200): 45 | """查找图像中特定size的空白区域 46 | :param gray cv2格式的灰度图 47 | :param size 指定size,值如:(100, 100) 48 | :param std_thr int 区域内标准差 49 | :param mean_thr int 区域内均值 50 | :return (x1, y1, x2, y2) 返回满足条件的第一个box 51 | """ 52 | cw, ch = size 53 | h, w = gray.shape[:2] 54 | for row in range(0, h, ch): 55 | for col in range(0, w, cw): 56 | roi = gray[row:row+ch, col:col+cw] # 获取分块 57 | dev = np.std(roi) 58 | avg = np.mean(roi) 59 | if dev < std_thr and avg > mean_thr: 60 | # 满足条件,接近空白区域,让他变黑 61 | return (col, col+cw, row, row+ch) 62 | 63 | return None 64 | 65 | 66 | def count_black_points(gray, thr=200): 67 | """计算黑点的个数,通常用于文档图像 68 | :param gray cv2格式的灰度图像 69 | """ 70 | return len(np.argwhere(gray < thr)) 71 | -------------------------------------------------------------------------------- /video.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | 视频相关 3 | 4 | Author: alex 5 | Created Time: 2020年11月03日 星期二 16时59分44秒 6 | ''' 7 | import skvideo.io 8 | 9 | 10 | def get_video_rotate(video_path): 11 | """获取视频旋转角度 12 | 注意:手机拍摄的视频,其角度可能需要进行旋转 13 | """ 14 | metadata = skvideo.io.ffprobe(video_path) 15 | d = metadata['video'].get('tag')[0] 16 | if d.setdefault('@key') == 'rotate': # 获取视频自选择角度 17 | return 360-int(d.setdefault('@value')) 18 | return 0 19 | -------------------------------------------------------------------------------- /video_speech_word/audioSeg.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from pydub import AudioSegment 4 | 5 | def get_second_part_wav(main_wav_path, start_time, end_time, part_wav_path): 6 | """ 7 | 音频切片,获取部分音频,单位秒 8 | :param main_wav_path: 原音频文件路径 9 | :param start_time: 截取的开始时间 10 | :param end_time: 截取的结束时间 11 | :param part_wav_path: 截取后的音频路径 12 | :return: 13 | """ 14 | # 因为是毫秒所以需要乘以1000 15 | start_time = start_time * 1000 16 | end_time = end_time * 1000 17 | 18 | sound = AudioSegment.from_mp3(main_wav_path) 19 | word = sound[start_time:end_time] 20 | 21 | word.export(part_wav_path, format="wav") 22 | 23 | if __name__ == '__main__': 24 | wav_path = "test.wav" 25 | part_path = "2.wav" 26 | s = 0 27 | e = 10 28 | get_second_part_wav(wav_path, s, e, part_path) 29 | 30 | os.system('2.wav') 31 | -------------------------------------------------------------------------------- /video_speech_word/video_speech_word.py: -------------------------------------------------------------------------------- 1 | import moviepy.editor as mp 2 | 3 | import moviepy.editor as mp 4 | 5 | import auditok 6 | import os 7 | 8 | 9 | # 采样率16k 保证和paddlespeech一致 10 | def extract_audio(videos_file_path): 11 | my_clip = mp.VideoFileClip(videos_file_path, audio_fps=16000) 12 | if (videos_file_path.split(".")[-1] == 'MP4' or videos_file_path.split(".")[-1] 
== 'mp4'): 13 | p = videos_file_path.split('.MP4')[0] 14 | my_clip.audio.write_audiofile(p + '_video.wav') 15 | new_path = p + '_video.wav' 16 | return new_path 17 | 18 | 19 | def qiefen(path, ty='video', mmin_dur=1, mmax_dur=100000, mmax_silence=1, menergy_threshold=55): 20 | global mk, file_pre 21 | file = path 22 | 23 | audio_regions = auditok.split( 24 | file, 25 | min_dur=mmin_dur, # minimum duration of a valid audio event in seconds 26 | max_dur=mmax_dur, # maximum duration of an event 27 | # maximum duration of tolerated continuous silence within an event 28 | max_silence=mmax_silence, 29 | energy_threshold=menergy_threshold # threshold of detection 30 | ) 31 | 32 | for i, r in enumerate(audio_regions): 33 | # Regions returned by `split` have 'start' and 'end' metadata fields 34 | print( 35 | "Region {i}: {r.meta.start:.3f}s -- {r.meta.end:.3f}s".format(i=i, r=r)) 36 | 37 | epath = '' 38 | file_pre = str(epath.join(file.split('.')[0].split('/')[-1])) 39 | 40 | # mk = '/change' 41 | mk = '/media/linxu/mobilePan/AICA6首席架构师计划/QA/change' 42 | if (os.path.exists(mk) == False): 43 | os.mkdir(mk) 44 | if (os.path.exists(mk + '/' + ty) == False): 45 | os.mkdir(mk + '/' + ty) 46 | if (os.path.exists(mk + '/' + ty + '/' + file_pre) == False): 47 | os.mkdir(mk + '/' + ty + '/' + file_pre) 48 | 49 | num = i 50 | # 为了取前三位数字排序 51 | s = '000000' + str(num) 52 | 53 | file_save = mk + '/' + ty + '/' + file_pre + '/' + \ 54 | s[-3:] + '-' + '{meta.start:.3f}-{meta.end:.3f}' + '.wav' 55 | 56 | filename = r.save(file_save) 57 | print("region saved as: {}".format(filename)) 58 | o_path = mk + '/' + ty + '/' + file_pre 59 | return o_path 60 | 61 | 62 | import paddle 63 | from paddlespeech.cli.asr import ASRExecutor 64 | from paddlespeech.cli.text import TextExecutor 65 | 66 | import warnings 67 | 68 | warnings.filterwarnings('ignore') 69 | 70 | asr_executor = ASRExecutor() 71 | text_executor = TextExecutor() 72 | 73 | 74 | def audio2txt(path): 75 | # 返回path下所有文件构成的一个list列表 76 | 
filelist = os.listdir(path) 77 | # 保证读取按照文件的顺序 78 | filelist.sort(key=lambda x: int(x[:3])) 79 | # 遍历输出每一个文件的名字和类型 80 | words = [] 81 | for file in filelist: 82 | print(path + '/' + file) 83 | text = asr_executor( 84 | audio_file=path + '/' + file, 85 | device=paddle.get_device()) 86 | if text: 87 | result = text_executor( 88 | text=text, 89 | task='punc', 90 | model='ernie_linear_p3_wudao', 91 | device=paddle.get_device()) 92 | else: 93 | result = text 94 | words.append(result) 95 | return words 96 | 97 | 98 | import csv 99 | 100 | 101 | def txt2csv(txt): 102 | with open(path + '.csv', 'w', encoding='utf-8') as f: 103 | f_csv = csv.writer(f) 104 | for row in txt: 105 | f_csv.writerow([row]) 106 | 107 | 108 | if __name__ == '__main__': 109 | # 拿到新生成的音频的路径 110 | path = extract_audio('/media/linxu/mobilePan/AICA6首席架构师计划/QA/《跨上AI的战车》QA视频.mp4') 111 | # 划分音频 112 | path = qiefen(path=path, ty='video30', mmin_dur=0.5, mmax_dur=30, mmax_silence=0.5, menergy_threshold=55) 113 | # 音频转文本 需要GPU 114 | txt_all = audio2txt(path) 115 | 116 | # 存储文本 117 | # with open('Text Result.txt', 'w') as f: # 设置文件对象 118 | # print('Text Result: \n{}'.format(txt_all)) 119 | 120 | # 存入csv 121 | txt2csv(txt_all) 122 | -------------------------------------------------------------------------------- /xml2coco.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path as osp 3 | import xml.etree.ElementTree as ET 4 | 5 | #import mmcv #需在mmdet环境下! 
6 | import json 7 | 8 | 9 | from glob import glob 10 | from tqdm import tqdm 11 | from PIL import Image 12 | 13 | cls_classes = ['dog','cat','human'] #检测目标类别(不含background) 14 | label_ids = {name: i + 1 for i, name in enumerate(cls_classes)} 15 | 16 | def get_segmentation(points): 17 | return [points[0], points[1], points[2] + points[0], points[1], 18 | points[2] + points[0], points[3] + points[1], points[0], points[3] + points[1]] 19 | 20 | def parse_xml(xml_path, img_id, anno_id): 21 | tree = ET.parse(xml_path) 22 | root = tree.getroot() 23 | annotation = [] 24 | for obj in root.findall('object'): 25 | try: 26 | name = obj.find('name').text 27 | category_id = label_ids[name] 28 | bnd_box = obj.find('bndbox') 29 | xmin = int(bnd_box.find('xmin').text) 30 | ymin = int(bnd_box.find('ymin').text) 31 | xmax = int(bnd_box.find('xmax').text) 32 | ymax = int(bnd_box.find('ymax').text) 33 | if xmin>=xmax or ymin>=ymax: 34 | continue 35 | w = xmax - xmin + 1 36 | h = ymax - ymin + 1 37 | area = w*h 38 | segmentation = get_segmentation([xmin, ymin, w, h]) 39 | annotation.append({ 40 | "segmentation": segmentation, 41 | "area": area, 42 | "iscrowd": 0, 43 | "image_id": img_id, 44 | "bbox": [xmin, ymin, w, h], 45 | "category_id": category_id, 46 | "id": anno_id, 47 | "ignore": 0}) 48 | anno_id += 1 49 | except: 50 | continue 51 | return annotation, anno_id 52 | 53 | def cvt_annotations(img_path, xml_path, out_file): 54 | images = [] 55 | annotations = [] 56 | img_id = 1 57 | anno_id = 1 58 | for img_path in tqdm(glob(img_path + '/*.jpg')): 59 | w, h = Image.open(img_path).size 60 | img_name = osp.basename(img_path) 61 | img = {"file_name": img_name, "height": int(h), "width": int(w), "id": img_id} 62 | images.append(img) 63 | 64 | xml_file_name = img_name.split('.')[0] + '.xml' 65 | xml_file_path = osp.join(xml_path, xml_file_name) 66 | annos, anno_id = parse_xml(xml_file_path, img_id, anno_id) 67 | annotations.extend(annos) 68 | img_id += 1 69 | 70 | categories = [] 71 | for 
k,v in label_ids.items(): 72 | categories.append({"name": k, "id": v}) 73 | final_result = {"images": images, "annotations": annotations, "categories": categories} #COCO数据集格式 74 | #mmcv.dump(final_result, out_file) 需在mmdet环境下! 75 | with open(out_file, 'w') as f: 76 | json.dump(final_result, f, indent=4) 77 | return annotations 78 | 79 | 80 | def main(): 81 | xml_path = "/home1/huangqiangHD/dataset/train/xml/" #XML文件位置 82 | img_path = "/home1/huangqiangHD/dataset/train/images/"#Image文件位置 83 | out_path = "/home1/huangqiangHD/dataset/train/train.json"#COCO文件位置 84 | print('processing {} ...'.format("xml format annotations")) 85 | cvt_annotations(img_path, xml_path, out_path) 86 | print('Done!') 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | --------------------------------------------------------------------------------