├── Attention ├── attention.py └── se_attention.py ├── DataAugForObjectDetection ├── DataAugmentForObejctDetection.py ├── README.md ├── crop_boudingbox.py └── xml_helper.py ├── KITTI_2_VOC ├── README.md ├── create_train_test_txt.py ├── modify_annotations_txt.py └── txt_to_xml.py ├── README.md └── VOC_2_COCO ├── VOC2COCO_pic.py ├── VOC2COCO_xml.py └── xml_helper.py /Attention/attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding=utf-8 -*- 2 | def add_attention(base_model): 3 | ''' 4 | 输入: 5 | base_model:基础模型 6 | 输出: 7 | 加了attention后的basemodel 8 | ''' 9 | ################## add attention by mao 2019-6-5 21:00 ################## 10 | #加attention, 点乘后求和 11 | def getSum(input_tensor): 12 | ''' 13 | input_tensor : [None, 49, 2048] 14 | Note: 15 | 函数里面要用的都得在函数内import!!!!!! 16 | ''' 17 | import keras.backend as K 18 | res = K.sum(input_tensor, axis=-2) 19 | return res 20 | 21 | from keras.layers import multiply, Reshape, RepeatVector, Permute, Lambda, Dense, BatchNormalization 22 | from keras.layers.pooling import GlobalAveragePooling2D 23 | from keras.models import Model 24 | 25 | x = base_model.output 26 | print(x.shape) 27 | _, H,W,C = x.shape 28 | H = int(H) 29 | W = int(W) 30 | C = int(C) 31 | x = GlobalAveragePooling2D(name='avg_pool_for_attention')(x) #[None, 7, 7, 2048] -> [None, 1, 1, 2048] 32 | x = Dense(H*W, activation='softmax', name='attention_w')(x) #全连接层,输出系数 33 | x = Reshape((H*W,))(x) #[None, 1, 49, 1] -> [None, 49] 34 | x = RepeatVector(C)(x) #[None, 49] -> [None, 2048, 49] 35 | x = Permute((2,1))(x) #[None, 2048, 49] -> [None, 49, 2048] 36 | x = Reshape((H, W, C))(x) #[None, 49, 2048] -> [None, 7, 7, 2048] 37 | x = multiply([base_model.output, x]) #逐个元素乘积 38 | base_model = Model(inputs=base_model.input, outputs=x) 39 | ############################## end of attention ######################### 40 | 41 | return base_model -------------------------------------------------------------------------------- /Attention/se_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | def squeeze_excitation_layer(x, out_dim, ratio=16): 4 | ''' 5 | SE channel attention 6 | input: 7 | x 8 | out_dim : channel default 9 | ratio : reduction rate, defualt 16 10 | ''' 11 | squeeze = layers.GlobalAveragePooling2D()(x) 12 | excitation = layers.Dense(out_dim//ratio, activation='relu')(squeeze) 13 | excitation = layers.Dense(out_dim, activation='sigmoid')(excitation) 14 | excitation = layers.Reshape((1, 1, out_dim))(excitation) 15 | 16 | scale = layers.multiply([x, excitation]) 17 | 18 | return scale 19 | -------------------------------------------------------------------------------- /DataAugForObjectDetection/DataAugmentForObejctDetection.py: -------------------------------------------------------------------------------- 1 | # -*- coding=utf-8 -*- 2 | ############################################################## 3 | # description: 4 | # data augmentation for obeject detection 5 | # author: 6 | # maozezhong 2018-6-27 7 | ############################################################## 8 | 9 | # 包括: 10 | # 1. 裁剪(需改变bbox) 11 | # 2. 平移(需改变bbox) 12 | # 3. 改变亮度 13 | # 4. 加噪声 14 | # 5. 旋转角度(需要改变bbox) 15 | # 6. 镜像(需要改变bbox) 16 | # 7. cutout 17 | # 注意: 18 | # random.seed(),相同的seed,产生的随机数是一样的!! 19 | 20 | import time 21 | import random 22 | import cv2 23 | import os 24 | import math 25 | import numpy as np 26 | from skimage.util import random_noise 27 | from skimage import exposure 28 | 29 | def show_pic(img, bboxes=None): 30 | ''' 31 | 输入: 32 | img:图像array 33 | bboxes:图像的所有boudning box list, 格式为[[x_min, y_min, x_max, y_max]....] 34 | names:每个box对应的名称 35 | ''' 36 | cv2.imwrite('./1.jpg', img) 37 | img = cv2.imread('./1.jpg') 38 | for i in range(len(bboxes)): 39 | bbox = bboxes[i] 40 | x_min = bbox[0] 41 | y_min = bbox[1] 42 | x_max = bbox[2] 43 | y_max = bbox[3] 44 | cv2.rectangle(img,(int(x_min),int(y_min)),(int(x_max),int(y_max)),(0,255,0),3) 45 | cv2.namedWindow('pic', 0) # 1表示原图 46 | cv2.moveWindow('pic', 0, 0) 47 | cv2.resizeWindow('pic', 1200,800) # 可视化的图片大小 48 | cv2.imshow('pic', img) 49 | cv2.waitKey(0) 50 | cv2.destroyAllWindows() 51 | os.remove('./1.jpg') 52 | 53 | # 图像均为cv2读取 54 | class DataAugmentForObjectDetection(): 55 | def __init__(self, rotation_rate=0.5, max_rotation_angle=5, 56 | crop_rate=0.5, shift_rate=0.5, change_light_rate=0.5, 57 | add_noise_rate=0.5, flip_rate=0.5, 58 | cutout_rate=0.5, cut_out_length=50, cut_out_holes=1, cut_out_threshold=0.5): 59 | self.rotation_rate = rotation_rate 60 | self.max_rotation_angle = max_rotation_angle 61 | self.crop_rate = crop_rate 62 | self.shift_rate = shift_rate 63 | self.change_light_rate = change_light_rate 64 | self.add_noise_rate = add_noise_rate 65 | self.flip_rate = flip_rate 66 | self.cutout_rate = cutout_rate 67 | 68 | self.cut_out_length = cut_out_length 69 | self.cut_out_holes = cut_out_holes 70 | self.cut_out_threshold = cut_out_threshold 71 | 72 | # 加噪声 73 | def _addNoise(self, img): 74 | ''' 75 | 输入: 76 | img:图像array 77 | 输出: 78 | 加噪声后的图像array,由于输出的像素是在[0,1]之间,所以得乘以255 79 | ''' 80 | # random.seed(int(time.time())) 81 | # return random_noise(img, mode='gaussian', seed=int(time.time()), clip=True)*255 82 | return random_noise(img, mode='gaussian', clip=True)*255 83 | 84 | 85 | # 调整亮度 86 | def _changeLight(self, img): 87 | # random.seed(int(time.time())) 88 | flag = random.uniform(0.5, 1.5) #flag>1为调暗,小于1为调亮 89 | return exposure.adjust_gamma(img, flag) 90 | 91 | # cutout 92 | def _cutout(self, img, bboxes, length=100, n_holes=1, threshold=0.5): 93 | ''' 94 | 原版本:https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 95 | Randomly mask out one or more patches from an image. 96 | Args: 97 | img : a 3D numpy array,(h,w,c) 98 | bboxes : 框的坐标 99 | n_holes (int): Number of patches to cut out of each image. 100 | length (int): The length (in pixels) of each square patch. 101 | ''' 102 | 103 | def cal_iou(boxA, boxB): 104 | ''' 105 | boxA, boxB为两个框,返回iou 106 | boxB为bouding box 107 | ''' 108 | 109 | # determine the (x, y)-coordinates of the intersection rectangle 110 | xA = max(boxA[0], boxB[0]) 111 | yA = max(boxA[1], boxB[1]) 112 | xB = min(boxA[2], boxB[2]) 113 | yB = min(boxA[3], boxB[3]) 114 | 115 | if xB <= xA or yB <= yA: 116 | return 0.0 117 | 118 | # compute the area of intersection rectangle 119 | interArea = (xB - xA + 1) * (yB - yA + 1) 120 | 121 | # compute the area of both the prediction and ground-truth 122 | # rectangles 123 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 124 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 125 | 126 | # compute the intersection over union by taking the intersection 127 | # area and dividing it by the sum of prediction + ground-truth 128 | # areas - the interesection area 129 | # iou = interArea / float(boxAArea + boxBArea - interArea) 130 | iou = interArea / float(boxBArea) 131 | 132 | # return the intersection over union value 133 | return iou 134 | 135 | # 得到h和w 136 | if img.ndim == 3: 137 | h,w,c = img.shape 138 | else: 139 | _,h,w,c = img.shape 140 | 141 | mask = np.ones((h,w,c), np.float32) 142 | 143 | for n in range(n_holes): 144 | 145 | chongdie = True #看切割的区域是否与box重叠太多 146 | 147 | while chongdie: 148 | y = np.random.randint(h) 149 | x = np.random.randint(w) 150 | 151 | y1 = np.clip(y - length // 2, 0, h) #numpy.clip(a, a_min, a_max, out=None), clip这个函数将将数组中的元素限制在a_min, a_max之间,大于a_max的就使得它等于 a_max,小于a_min,的就使得它等于a_min 152 | y2 = np.clip(y + length // 2, 0, h) 153 | x1 = np.clip(x - length // 2, 0, w) 154 | x2 = np.clip(x + length // 2, 0, w) 155 | 156 | chongdie = False 157 | for box in bboxes: 158 | if cal_iou([x1,y1,x2,y2], box) > threshold: 159 | chongdie = True 160 | break 161 | 162 | mask[y1: y2, x1: x2, :] = 0. 163 | 164 | # mask = np.expand_dims(mask, axis=0) 165 | img = img * mask 166 | 167 | return img 168 | 169 | # 旋转 170 | def _rotate_img_bbox(self, img, bboxes, angle=5, scale=1.): 171 | ''' 172 | 参考:https://blog.csdn.net/u014540717/article/details/53301195crop_rate 173 | 输入: 174 | img:图像array,(h,w,c) 175 | bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值 176 | angle:旋转角度 177 | scale:默认1 178 | 输出: 179 | rot_img:旋转后的图像array 180 | rot_bboxes:旋转后的boundingbox坐标list 181 | ''' 182 | #---------------------- 旋转图像 ---------------------- 183 | w = img.shape[1] 184 | h = img.shape[0] 185 | # 角度变弧度 186 | rangle = np.deg2rad(angle) # angle in radians 187 | # now calculate new image width and height 188 | nw = (abs(np.sin(rangle)*h) + abs(np.cos(rangle)*w))*scale 189 | nh = (abs(np.cos(rangle)*h) + abs(np.sin(rangle)*w))*scale 190 | # ask OpenCV for the rotation matrix 191 | rot_mat = cv2.getRotationMatrix2D((nw*0.5, nh*0.5), angle, scale) 192 | # calculate the move from the old center to the new center combined 193 | # with the rotation 194 | rot_move = np.dot(rot_mat, np.array([(nw-w)*0.5, (nh-h)*0.5,0])) 195 | # the move only affects the translation, so update the translation 196 | # part of the transform 197 | rot_mat[0,2] += rot_move[0] 198 | rot_mat[1,2] += rot_move[1] 199 | # 仿射变换 200 | rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4) 201 | 202 | #---------------------- 矫正bbox坐标 ---------------------- 203 | # rot_mat是最终的旋转矩阵 204 | # 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下 205 | rot_bboxes = list() 206 | for bbox in bboxes: 207 | xmin = bbox[0] 208 | ymin = bbox[1] 209 | xmax = bbox[2] 210 | ymax = bbox[3] 211 | point1 = np.dot(rot_mat, np.array([(xmin+xmax)/2, ymin, 1])) 212 | point2 = np.dot(rot_mat, np.array([xmax, (ymin+ymax)/2, 1])) 213 | point3 = np.dot(rot_mat, np.array([(xmin+xmax)/2, ymax, 1])) 214 | point4 = np.dot(rot_mat, np.array([xmin, (ymin+ymax)/2, 1])) 215 | # 合并np.array 216 | concat = np.vstack((point1, point2, point3, point4)) 217 | # 改变array类型 218 | concat = concat.astype(np.int32) 219 | # 得到旋转后的坐标 220 | rx, ry, rw, rh = cv2.boundingRect(concat) 221 | rx_min = rx 222 | ry_min = ry 223 | rx_max = rx+rw 224 | ry_max = ry+rh 225 | # 加入list中 226 | rot_bboxes.append([rx_min, ry_min, rx_max, ry_max]) 227 | 228 | return rot_img, rot_bboxes 229 | 230 | # 裁剪 231 | def _crop_img_bboxes(self, img, bboxes): 232 | ''' 233 | 裁剪后的图片要包含所有的框 234 | 输入: 235 | img:图像array 236 | bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值 237 | 输出: 238 | crop_img:裁剪后的图像array 239 | crop_bboxes:裁剪后的bounding box的坐标list 240 | ''' 241 | #---------------------- 裁剪图像 ---------------------- 242 | w = img.shape[1] 243 | h = img.shape[0] 244 | x_min = w #裁剪后的包含所有目标框的最小的框 245 | x_max = 0 246 | y_min = h 247 | y_max = 0 248 | for bbox in bboxes: 249 | x_min = min(x_min, bbox[0]) 250 | y_min = min(y_min, bbox[1]) 251 | x_max = max(x_max, bbox[2]) 252 | y_max = max(y_max, bbox[3]) 253 | 254 | d_to_left = x_min #包含所有目标框的最小框到左边的距离 255 | d_to_right = w - x_max #包含所有目标框的最小框到右边的距离 256 | d_to_top = y_min #包含所有目标框的最小框到顶端的距离 257 | d_to_bottom = h - y_max #包含所有目标框的最小框到底部的距离 258 | 259 | #随机扩展这个最小框 260 | crop_x_min = int(x_min - random.uniform(0, d_to_left)) 261 | crop_y_min = int(y_min - random.uniform(0, d_to_top)) 262 | crop_x_max = int(x_max + random.uniform(0, d_to_right)) 263 | crop_y_max = int(y_max + random.uniform(0, d_to_bottom)) 264 | 265 | # 随机扩展这个最小框 , 防止别裁的太小 266 | # crop_x_min = int(x_min - random.uniform(d_to_left//2, d_to_left)) 267 | # crop_y_min = int(y_min - random.uniform(d_to_top//2, d_to_top)) 268 | # crop_x_max = int(x_max + random.uniform(d_to_right//2, d_to_right)) 269 | # crop_y_max = int(y_max + random.uniform(d_to_bottom//2, d_to_bottom)) 270 | 271 | #确保不要越界 272 | crop_x_min = max(0, crop_x_min) 273 | crop_y_min = max(0, crop_y_min) 274 | crop_x_max = min(w, crop_x_max) 275 | crop_y_max = min(h, crop_y_max) 276 | 277 | crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max] 278 | 279 | #---------------------- 裁剪boundingbox ---------------------- 280 | #裁剪后的boundingbox坐标计算 281 | crop_bboxes = list() 282 | for bbox in bboxes: 283 | crop_bboxes.append([bbox[0]-crop_x_min, bbox[1]-crop_y_min, bbox[2]-crop_x_min, bbox[3]-crop_y_min]) 284 | 285 | return crop_img, crop_bboxes 286 | 287 | # 平移 288 | def _shift_pic_bboxes(self, img, bboxes): 289 | ''' 290 | 参考:https://blog.csdn.net/sty945/article/details/79387054 291 | 平移后的图片要包含所有的框 292 | 输入: 293 | img:图像array 294 | bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值 295 | 输出: 296 | shift_img:平移后的图像array 297 | shift_bboxes:平移后的bounding box的坐标list 298 | ''' 299 | #---------------------- 平移图像 ---------------------- 300 | w = img.shape[1] 301 | h = img.shape[0] 302 | x_min = w #裁剪后的包含所有目标框的最小的框 303 | x_max = 0 304 | y_min = h 305 | y_max = 0 306 | for bbox in bboxes: 307 | x_min = min(x_min, bbox[0]) 308 | y_min = min(y_min, bbox[1]) 309 | x_max = max(x_max, bbox[2]) 310 | y_max = max(y_max, bbox[3]) 311 | 312 | d_to_left = x_min #包含所有目标框的最大左移动距离 313 | d_to_right = w - x_max #包含所有目标框的最大右移动距离 314 | d_to_top = y_min #包含所有目标框的最大上移动距离 315 | d_to_bottom = h - y_max #包含所有目标框的最大下移动距离 316 | 317 | x = random.uniform(-(d_to_left-1) / 3, (d_to_right-1) / 3) 318 | y = random.uniform(-(d_to_top-1) / 3, (d_to_bottom-1) / 3) 319 | 320 | M = np.float32([[1, 0, x], [0, 1, y]]) #x为向左或右移动的像素值,正为向右负为向左; y为向上或者向下移动的像素值,正为向下负为向上 321 | shift_img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0])) 322 | 323 | #---------------------- 平移boundingbox ---------------------- 324 | shift_bboxes = list() 325 | for bbox in bboxes: 326 | shift_bboxes.append([bbox[0]+x, bbox[1]+y, bbox[2]+x, bbox[3]+y]) 327 | 328 | return shift_img, shift_bboxes 329 | 330 | # 镜像 331 | def _filp_pic_bboxes(self, img, bboxes): 332 | ''' 333 | 参考:https://blog.csdn.net/jningwei/article/details/78753607 334 | 平移后的图片要包含所有的框 335 | 输入: 336 | img:图像array 337 | bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值 338 | 输出: 339 | flip_img:平移后的图像array 340 | flip_bboxes:平移后的bounding box的坐标list 341 | ''' 342 | # ---------------------- 翻转图像 ---------------------- 343 | import copy 344 | flip_img = copy.deepcopy(img) 345 | if random.random() < 0.5: #0.5的概率水平翻转,0.5的概率垂直翻转 346 | horizon = True 347 | else: 348 | horizon = False 349 | h,w,_ = img.shape 350 | if horizon: #水平翻转 351 | flip_img = cv2.flip(flip_img, 1) #1是水平,-1是水平垂直 352 | else: 353 | flip_img = cv2.flip(flip_img, 0) 354 | 355 | # ---------------------- 调整boundingbox ---------------------- 356 | flip_bboxes = list() 357 | for box in bboxes: 358 | x_min = box[0] 359 | y_min = box[1] 360 | x_max = box[2] 361 | y_max = box[3] 362 | if horizon: 363 | flip_bboxes.append([w-x_max, y_min, w-x_min, y_max]) 364 | else: 365 | flip_bboxes.append([x_min, h-y_max, x_max, h-y_min]) 366 | 367 | return flip_img, flip_bboxes 368 | 369 | def dataAugment(self, img, bboxes): 370 | ''' 371 | 图像增强 372 | 输入: 373 | img:图像array 374 | bboxes:该图像的所有框坐标 375 | 输出: 376 | img:增强后的图像 377 | bboxes:增强后图片对应的box 378 | ''' 379 | change_num = 0 #改变的次数 380 | print('------') 381 | while change_num < 1: #默认至少有一种数据增强生效 382 | if random.random() < self.crop_rate: #裁剪 383 | print('裁剪') 384 | change_num += 1 385 | img, bboxes = self._crop_img_bboxes(img, bboxes) 386 | 387 | if random.random() > self.rotation_rate: #旋转 388 | print('旋转') 389 | change_num += 1 390 | # angle = random.uniform(-self.max_rotation_angle, self.max_rotation_angle) 391 | angle = random.sample([90, 180, 270],1)[0] 392 | scale = random.uniform(0.7, 0.8) 393 | img, bboxes = self._rotate_img_bbox(img, bboxes, angle, scale) 394 | 395 | if random.random() < self.shift_rate: #平移 396 | print('平移') 397 | change_num += 1 398 | img, bboxes = self._shift_pic_bboxes(img, bboxes) 399 | 400 | if random.random() > self.change_light_rate: #改变亮度 401 | print('亮度') 402 | change_num += 1 403 | img = self._changeLight(img) 404 | 405 | if random.random() < self.add_noise_rate: #加噪声 406 | print('加噪声') 407 | change_num += 1 408 | img = self._addNoise(img) 409 | 410 | if random.random() < self.cutout_rate: #cutout 411 | print('cutout') 412 | change_num += 1 413 | img = self._cutout(img, bboxes, length=self.cut_out_length, n_holes=self.cut_out_holes, threshold=self.cut_out_threshold) 414 | 415 | if random.random() < self.flip_rate: #翻转 416 | print('翻转') 417 | change_num += 1 418 | img, bboxes = self._filp_pic_bboxes(img, bboxes) 419 | print('\n') 420 | # print('------') 421 | return img, bboxes 422 | 423 | 424 | if __name__ == '__main__': 425 | 426 | ### test ### 427 | 428 | import shutil 429 | from xml_helper import * 430 | 431 | need_aug_num = 1 432 | 433 | dataAug = DataAugmentForObjectDetection() 434 | 435 | source_pic_root_path = './data_split' 436 | source_xml_root_path = './data_voc/VOC2007/Annotations' 437 | 438 | 439 | for parent, _, files in os.walk(source_pic_root_path): 440 | for file in files: 441 | cnt = 0 442 | while cnt < need_aug_num: 443 | pic_path = os.path.join(parent, file) 444 | xml_path = os.path.join(source_xml_root_path, file[:-4]+'.xml') 445 | coords = parse_xml(xml_path) #解析得到box信息,格式为[[x_min,y_min,x_max,y_max,name]] 446 | coords = [coord[:4] for coord in coords] 447 | 448 | img = cv2.imread(pic_path) 449 | show_pic(img, coords) # 原图 450 | 451 | auged_img, auged_bboxes = dataAug.dataAugment(img, coords) 452 | cnt += 1 453 | 454 | show_pic(auged_img, auged_bboxes) # 强化后的图 455 | 456 | 457 | -------------------------------------------------------------------------------- /DataAugForObjectDetection/README.md: -------------------------------------------------------------------------------- 1 | ## 文件说明 2 | - DataAugmentForObejectDetection.py : 针对目标检测的数据增强脚本 3 | - xml_helper.py : 辅助处理xml文件的脚本 4 | - crop_boundingbox.py : 根据原始图像以及其xml文件,将boundingbox截取下来作为新的图片,存放在data_crop文件夹下 5 | -------------------------------------------------------------------------------- /DataAugForObjectDetection/crop_boudingbox.py: -------------------------------------------------------------------------------- 1 | # - *-coding=utf-8 -*- 2 | ''' 3 | 将瑕疵图片的boundingbox截取下来作为新的一张图片 4 | 同时resize到跟原始图片统一大小 5 | ''' 6 | 7 | import cv2 8 | import os 9 | 10 | def show_pic(img, bboxes=None): 11 | ''' 12 | 输入: 13 | img:图像array 14 | bboxes:图像的所有boudning box list, 格式为[[x_min, y_min, x_max, y_max]....] 15 | names:每个box对应的名称 16 | ''' 17 | cv2.imwrite('./1.jpg', img) 18 | img = cv2.imread('./1.jpg') 19 | for i in range(len(bboxes)): 20 | bbox = bboxes[i] 21 | x_min = bbox[0] 22 | y_min = bbox[1] 23 | x_max = bbox[2] 24 | y_max = bbox[3] 25 | cv2.rectangle(img,(int(x_min),int(y_min)),(int(x_max),int(y_max)),(0,255,0),3) 26 | cv2.namedWindow('pic', 0) # 1表示原图 27 | cv2.moveWindow('pic', 0, 0) 28 | cv2.resizeWindow('pic', 1200,800) # 可视化的图片大小 29 | cv2.imshow('pic', img) 30 | cv2.waitKey(0) 31 | cv2.destroyAllWindows() 32 | os.remove('./1.jpg') 33 | 34 | def crop_bd(img, bboxes): 35 | ''' 36 | 裁剪后的图片为boudingbox围起来的区域 37 | 输入: 38 | img:图像array 39 | bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值 40 | 输出: 41 | crop_bds:裁剪后的boudingbox图像array list,(可能有多个boudingbox) 42 | ''' 43 | #---------------------- 裁剪图像 ---------------------- 44 | crop_bds = list() 45 | for bbox in bboxes: 46 | x_min = bbox[0] 47 | y_min = bbox[1] 48 | x_max = bbox[2] 49 | y_max = bbox[3] 50 | crop_bd = img[y_min:y_max, x_min:x_max] 51 | crop_bds.append(crop_bd) 52 | return crop_bds 53 | 54 | if __name__ == '__main__': 55 | 56 | from xml_helper import * 57 | import shutil 58 | 59 | source_bad_pic_root_path = './data_split/bad' 60 | source_xml_root_path = './tianchi_detect/data_voc/VOC2007/Annotations' 61 | 62 | target_pic_root_path = './data_croped' #只包括有瑕疵的图片 63 | if os.path.exists(target_pic_root_path): 64 | shutil.rmtree(target_pic_root_path) 65 | os.mkdir(target_pic_root_path) 66 | 67 | cnt = 0 68 | process_flag = 0 69 | for parent, _, files in os.walk(source_bad_pic_root_path): 70 | for file in files: 71 | process_flag += 1 72 | print(str(process_flag)+'/'+str(len(files))) 73 | bad_pic_path = os.path.join(parent, file) 74 | xml_path = os.path.join(source_xml_root_path, file[:-4]+'.xml') 75 | img = cv2.imread(bad_pic_path) 76 | bboxes = parse_xml(xml_path) 77 | 78 | # #原图可视化一下 79 | # show_pic(img,bboxes) 80 | 81 | croped_imgs = crop_bd(img, bboxes) 82 | for croped_img in croped_imgs: 83 | cnt += 1 84 | target_pic_path = os.path.join(target_pic_root_path, file[:-4]+'_croped'+str(cnt)+'.jpg') 85 | # #reize一下 86 | # croped_img = cv2.resize(croped_img, (img.shape[1], img.shape[0])) 87 | # #可视化一下截取的图 88 | # show_pic(croped_img,[[0,0,croped_img.shape[1],img.shape[0]]]) 89 | #写入 90 | cv2.imwrite(target_pic_path, croped_img) 91 | 92 | -------------------------------------------------------------------------------- /DataAugForObjectDetection/xml_helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding=utf-8 -*- 2 | import xml.etree.ElementTree as ET 3 | import xml.dom.minidom as DOC 4 | 5 | # 从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]] 6 | def parse_xml(xml_path): 7 | ''' 8 | 输入: 9 | xml_path: xml的文件路径 10 | 输出: 11 | 从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]] 12 | ''' 13 | tree = ET.parse(xml_path) 14 | root = tree.getroot() 15 | objs = root.findall('object') 16 | coords = list() 17 | for ix, obj in enumerate(objs): 18 | name = obj.find('name').text 19 | box = obj.find('bndbox') 20 | x_min = int(box[0].text) 21 | y_min = int(box[1].text) 22 | x_max = int(box[2].text) 23 | y_max = int(box[3].text) 24 | coords.append([x_min, y_min, x_max, y_max, name]) 25 | return coords 26 | 27 | #将bounding box信息写入xml文件中, bouding box格式为[[x_min, y_min, x_max, y_max, name]] 28 | def generate_xml(img_name,coords,img_size,out_root_path): 29 | ''' 30 | 输入: 31 | img_name:图片名称,如a.jpg 32 | coords:坐标list,格式为[[x_min, y_min, x_max, y_max, name]],name为概况的标注 33 | img_size:图像的大小,格式为[h,w,c] 34 | out_root_path: xml文件输出的根路径 35 | ''' 36 | doc = DOC.Document() # 创建DOM文档对象 37 | 38 | annotation = doc.createElement('annotation') 39 | doc.appendChild(annotation) 40 | 41 | title = doc.createElement('folder') 42 | title_text = doc.createTextNode('Tianchi') 43 | title.appendChild(title_text) 44 | annotation.appendChild(title) 45 | 46 | title = doc.createElement('filename') 47 | title_text = doc.createTextNode(img_name) 48 | title.appendChild(title_text) 49 | annotation.appendChild(title) 50 | 51 | source = doc.createElement('source') 52 | annotation.appendChild(source) 53 | 54 | title = doc.createElement('database') 55 | title_text = doc.createTextNode('The Tianchi Database') 56 | title.appendChild(title_text) 57 | source.appendChild(title) 58 | 59 | title = doc.createElement('annotation') 60 | title_text = doc.createTextNode('Tianchi') 61 | title.appendChild(title_text) 62 | source.appendChild(title) 63 | 64 | size = doc.createElement('size') 65 | annotation.appendChild(size) 66 | 67 | title = doc.createElement('width') 68 | title_text = doc.createTextNode(str(img_size[1])) 69 | title.appendChild(title_text) 70 | size.appendChild(title) 71 | 72 | title = doc.createElement('height') 73 | title_text = doc.createTextNode(str(img_size[0])) 74 | title.appendChild(title_text) 75 | size.appendChild(title) 76 | 77 | title = doc.createElement('depth') 78 | title_text = doc.createTextNode(str(img_size[2])) 79 | title.appendChild(title_text) 80 | size.appendChild(title) 81 | 82 | for coord in coords: 83 | 84 | object = doc.createElement('object') 85 | annotation.appendChild(object) 86 | 87 | title = doc.createElement('name') 88 | title_text = doc.createTextNode(coord[4]) 89 | title.appendChild(title_text) 90 | object.appendChild(title) 91 | 92 | pose = doc.createElement('pose') 93 | pose.appendChild(doc.createTextNode('Unspecified')) 94 | object.appendChild(pose) 95 | truncated = doc.createElement('truncated') 96 | truncated.appendChild(doc.createTextNode('1')) 97 | object.appendChild(truncated) 98 | difficult = doc.createElement('difficult') 99 | difficult.appendChild(doc.createTextNode('0')) 100 | object.appendChild(difficult) 101 | 102 | bndbox = doc.createElement('bndbox') 103 | object.appendChild(bndbox) 104 | title = doc.createElement('xmin') 105 | title_text = doc.createTextNode(str(int(float(coord[0])))) 106 | title.appendChild(title_text) 107 | bndbox.appendChild(title) 108 | title = doc.createElement('ymin') 109 | title_text = doc.createTextNode(str(int(float(coord[1])))) 110 | title.appendChild(title_text) 111 | bndbox.appendChild(title) 112 | title = doc.createElement('xmax') 113 | title_text = doc.createTextNode(str(int(float(coord[2])))) 114 | title.appendChild(title_text) 115 | bndbox.appendChild(title) 116 | title = doc.createElement('ymax') 117 | title_text = doc.createTextNode(str(int(float(coord[3])))) 118 | title.appendChild(title_text) 119 | bndbox.appendChild(title) 120 | 121 | # 将DOM对象doc写入文件 122 | f = open(os.path.join(out_root_path, img_name[:-4]+'.xml'),'w') 123 | f.write(doc.toprettyxml(indent = '')) 124 | f.close() 125 | -------------------------------------------------------------------------------- /KITTI_2_VOC/README.md: -------------------------------------------------------------------------------- 1 | # KITTIdata_to_voc 2 | transform KITTI data to voc file form 3 | ##### reference: 4 | - [SSD: Single Shot MultiBox Detector 训练KITTI数据集(1)](https://blog.csdn.net/jesse_mx/article/details/65634482) 5 | 6 | -------------------------------------------------------------------------------- /KITTI_2_VOC/create_train_test_txt.py: -------------------------------------------------------------------------------- 1 | # create_train_test_txt.py 2 | # encoding:utf-8 3 | import pdb 4 | import glob 5 | import os 6 | import random 7 | import math 8 | 9 | def get_sample_value(txt_name, category_name): 10 | label_path = './Labels/' 11 | txt_path = label_path + txt_name+'.txt' 12 | try: 13 | with open(txt_path) as r_tdf: 14 | if category_name in r_tdf.read(): 15 | return ' 1' 16 | else: 17 | return '-1' 18 | except IOError as ioerr: 19 | print('File error:'+str(ioerr)) 20 | 21 | txt_list_path = glob.glob('./Labels/*.txt') 22 | txt_list = [] 23 | 24 | for item in txt_list_path: 25 | temp1,temp2 = os.path.splitext(os.path.basename(item)) 26 | txt_list.append(temp1) 27 | txt_list.sort() 28 | print(txt_list, end = '\n\n') 29 | 30 | # 有博客建议train:val:test=8:1:1,先尝试用一下 31 | num_trainval = random.sample(txt_list, math.floor(len(txt_list)*9/10.0)) # 可修改百分比 32 | num_trainval.sort() 33 | print(num_trainval, end = '\n\n') 34 | 35 | num_train = random.sample(num_trainval,math.floor(len(num_trainval)*8/9.0)) # 可修改百分比 36 | num_train.sort() 37 | print(num_train, end = '\n\n') 38 | 39 | num_val = list(set(num_trainval).difference(set(num_train))) 40 | num_val.sort() 41 | print(num_val, end = '\n\n') 42 | 43 | num_test = list(set(txt_list).difference(set(num_trainval))) 44 | num_test.sort() 45 | print(num_test, end = '\n\n') 46 | 47 | pdb.set_trace() 48 | 49 | Main_path = './ImageSets/Main/' 50 | train_test_name = ['trainval','train','val','test'] 51 | category_name = ['Car','Pedestrian','Cyclist'] 52 | pic_absolute_path = os.getcwd()+'/JPEGImages/' 53 | 54 | # 循环写trainvl train val test 55 | for item_train_test_name in train_test_name: 56 | list_name = 'num_' 57 | list_name += item_train_test_name 58 | train_test_txt_name = Main_path + item_train_test_name + '.txt' 59 | try: 60 | # 写单个文件 61 | with open(train_test_txt_name, 'w') as w_tdf: 62 | # 一行一行写 63 | for item in eval(list_name): 64 | w_tdf.write(pic_absolute_path+item+'.png\n') 65 | # 循环写Car Pedestrian Cyclist 66 | for item_category_name in category_name: 67 | category_txt_name = Main_path + item_category_name + '_' + item_train_test_name + '.txt' 68 | with open(category_txt_name, 'w') as w_tdf: 69 | # 一行一行写 70 | for item in eval(list_name): 71 | w_tdf.write(pic_absolute_path + item+' '+ get_sample_value(item, item_category_name)+'.png\n') 72 | except IOError as ioerr: 73 | print('File error:'+str(ioerr)) -------------------------------------------------------------------------------- /KITTI_2_VOC/modify_annotations_txt.py: -------------------------------------------------------------------------------- 1 | # modify_annotations_txt.py 2 | # 注意:是在原始数据上直接更改的,最好先备份一下原始数据集 3 | # 不过亲测没什么问题 4 | import glob 5 | import string 6 | import os 7 | 8 | ##备份一下 9 | os.system('cp -rf '+os.getcwd()+'/Labels '+os.getcwd()+'/Labels_ori') 10 | 11 | txt_list = glob.glob('./Labels/*.txt') # 存储Labels文件夹所有txt文件路径 12 | def show_category(txt_list): 13 | category_list= [] 14 | for item in txt_list: 15 | try: 16 | with open(item) as tdf: 17 | for each_line in tdf: 18 | labeldata = each_line.strip().split(' ') # 去掉前后多余的字符并把其分开 19 | category_list.append(labeldata[0]) # 只要第一个字段,即类别 20 | except IOError as ioerr: 21 | print('File error:'+str(ioerr)) 22 | print(set(category_list)) # 输出集合 23 | 24 | def merge(line): 25 | each_line='' 26 | for i in range(len(line)): 27 | if i!= (len(line)-1): 28 | each_line=each_line+line[i]+' ' 29 | else: 30 | each_line=each_line+line[i] # 最后一条字段后面不加空格 31 | each_line=each_line+'\n' 32 | return (each_line) 33 | 34 | print('before modify categories are:\n') 35 | show_category(txt_list) 36 | 37 | for item in txt_list: 38 | new_txt=[] 39 | try: 40 | with open(item, 'r') as r_tdf: 41 | for each_line in r_tdf: 42 | labeldata = each_line.strip().split(' ') 43 | if labeldata[0] in ['Truck','Van','Tram']: # 合并汽车类 44 | labeldata[0] = labeldata[0].replace(labeldata[0],'Car') 45 | if labeldata[0] == 'Person_sitting': # 合并行人类 46 | labeldata[0] = labeldata[0].replace(labeldata[0],'Pedestrian') 47 | if labeldata[0] == 'DontCare': # 忽略Dontcare类 48 | continue 49 | if labeldata[0] == 'Misc': # 忽略Misc类 50 | continue 51 | new_txt.append(merge(labeldata)) # 重新写入新的txt文件 52 | with open(item,'w+') as w_tdf: # w+是打开原文件将内容删除,另写新内容进去 53 | for temp in new_txt: 54 | w_tdf.write(temp) 55 | except IOError as ioerr: 56 | print('File error:'+str(ioerr)) 57 | 58 | print('\nafter modify categories are:\n') 59 | show_category(txt_list) -------------------------------------------------------------------------------- /KITTI_2_VOC/txt_to_xml.py: -------------------------------------------------------------------------------- 1 | # txt_to_xml.py 2 | # encoding:utf-8 3 | # 根据一个给定的XML Schema,使用DOM树的形式从空白文件生成一个XML 4 | from xml.dom.minidom import Document 5 | import cv2 6 | import os 7 | 8 | def generate_xml(name,split_lines,img_size,class_ind): 9 | doc = Document() # 创建DOM文档对象 10 | 11 | annotation = doc.createElement('annotation') 12 | doc.appendChild(annotation) 13 | 14 | title = doc.createElement('folder') 15 | title_text = doc.createTextNode('KITTI') 16 | title.appendChild(title_text) 17 | annotation.appendChild(title) 18 | 19 | img_name=name+'.png' 20 | 21 | title = doc.createElement('filename') 22 | title_text = doc.createTextNode(img_name) 23 | title.appendChild(title_text) 24 | annotation.appendChild(title) 25 | 26 | source = doc.createElement('source') 27 | annotation.appendChild(source) 28 | 29 | title = doc.createElement('database') 30 | title_text = doc.createTextNode('The KITTI Database') 31 | title.appendChild(title_text) 32 | source.appendChild(title) 33 | 34 | title = doc.createElement('annotation') 35 | title_text = doc.createTextNode('KITTI') 36 | title.appendChild(title_text) 37 | source.appendChild(title) 38 | 39 | size = doc.createElement('size') 40 | annotation.appendChild(size) 41 | 42 | title = doc.createElement('width') 43 | title_text = doc.createTextNode(str(img_size[1])) 44 | title.appendChild(title_text) 45 | size.appendChild(title) 46 | 47 | title = doc.createElement('height') 48 | title_text = doc.createTextNode(str(img_size[0])) 49 | title.appendChild(title_text) 50 | size.appendChild(title) 51 | 52 | title = doc.createElement('depth') 53 | title_text = doc.createTextNode(str(img_size[2])) 54 | title.appendChild(title_text) 55 | size.appendChild(title) 56 | 57 | for split_line in split_lines: 58 | line=split_line.strip().split() 59 | if line[0] in class_ind: 60 | object = doc.createElement('object') 61 | annotation.appendChild(object) 62 | 63 | title = doc.createElement('name') 64 | title_text = doc.createTextNode(line[0]) 65 | title.appendChild(title_text) 66 | object.appendChild(title) 67 | 68 | bndbox = doc.createElement('bndbox') 69 | object.appendChild(bndbox) 70 | title = doc.createElement('xmin') 71 | title_text = doc.createTextNode(str(int(float(line[4])))) 72 | title.appendChild(title_text) 73 | bndbox.appendChild(title) 74 | title = doc.createElement('ymin') 75 | title_text = doc.createTextNode(str(int(float(line[5])))) 76 | title.appendChild(title_text) 77 | bndbox.appendChild(title) 78 | title = doc.createElement('xmax') 79 | title_text = doc.createTextNode(str(int(float(line[6])))) 80 | title.appendChild(title_text) 81 | bndbox.appendChild(title) 82 | title = doc.createElement('ymax') 83 | title_text = doc.createTextNode(str(int(float(line[7])))) 84 | title.appendChild(title_text) 85 | bndbox.appendChild(title) 86 | 87 | # 将DOM对象doc写入文件 88 | f = open('Annotations/'+name+'.xml','w') 89 | f.write(doc.toprettyxml(indent = '')) 90 | f.close() 91 | 92 | if __name__ == '__main__': 93 | class_ind=('Pedestrian', 'Car', 'Cyclist') 94 | cur_dir=os.getcwd() 95 | labels_dir=os.path.join(cur_dir,'Labels') 96 | for parent, dirnames, filenames in os.walk(labels_dir): # 分别得到根目录,子目录和根目录下文件 97 | for file_name in filenames: 98 | full_path=os.path.join(parent, file_name) # 获取文件全路径 99 | f=open(full_path) 100 | split_lines = f.readlines() 101 | name= file_name[:-4] # 后四位是扩展名.txt,只取前面的文件名 102 | img_name=name+'.png' 103 | img_path=os.path.join(os.getcwd()+'/training/',img_name) # 路径需要自行修改 104 | img_size=cv2.imread(img_path).shape 105 | generate_xml(name,split_lines,img_size,class_ind) 106 | print('all txts has converted into xmls') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 文件 2 | - DataAugForObjectDetection : 针对目标检测的增强脚本文件夹,增强方式包括 3 | - 裁剪(会改变bbox) 4 | - 平移(会改变bbox) 5 | - 旋转(会改变bbox) 6 | - 镜像(会改变bbox) 7 | - 改变亮度 8 | - 加噪声 9 | - [cutout](https://arxiv.org/abs/1708.04552) 10 | - KITTI_2_VOC : 将KITTI数据形式转换为VOC形式 11 | - VOC_2_COCO : 讲VOC形式数据转换为COCO格式 12 | - Attention : feature map Attention && channel attention (SE) 13 | 14 | ## To_Do_List 15 | - [ ] GAN with LSR for data augment 16 | -------------------------------------------------------------------------------- /VOC_2_COCO/VOC2COCO_pic.py: -------------------------------------------------------------------------------- 1 | # -*- coding=utf-8 -*- 2 | import os 3 | import shutil 4 | 5 | source_pic_root_path = './data_voc/VOC2007/JPEGImages' 6 | target_pic_root_path = './data_coco/coco_train2014' 7 | if os.path.exists(target_pic_root_path): 8 | shutil.rmtree(target_pic_root_path) 9 | os.makedirs(target_pic_root_path) 10 | for parent, _, files in os.walk(source_pic_root_path): 11 | for file in files: 12 | target_pic_path = os.path.join(target_pic_root_path, file) 13 | source_pic_path = os.path.join(source_pic_root_path, file) 14 | shutil.copyfile(source_pic_path, target_pic_path) 15 | print('done') -------------------------------------------------------------------------------- /VOC_2_COCO/VOC2COCO_xml.py: -------------------------------------------------------------------------------- 1 | # -*- coding=utf-8 -*- 2 | import json 3 | import os 4 | import cv2 5 | import xml.etree.ElementTree as ET 6 | import shutil 7 | 8 | # 从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]] 9 | def parse_xml(xml_path): 10 | tree = ET.parse(xml_path) 11 | root = tree.getroot() 12 | objs = root.findall('object') 13 | coords = list() 14 | for ix, obj in enumerate(objs): 15 | name = obj.find('name').text 16 | box = obj.find('bndbox') 17 | x_min = int(box[0].text) 18 | y_min = int(box[1].text) 19 | x_max = int(box[2].text) 20 | y_max = int(box[3].text) 21 | coords.append([x_min, y_min, x_max, y_max, name]) 22 | return coords 23 | 24 | def convert(root_path, source_xml_root_path, target_xml_root_path, phase='train', split=80000): 25 | ''' 26 | root_path: 27 | 根路径,里面包含JPEGImages(图片文件夹),classes.txt(类别标签),以及annotations文件夹(如果没有则会自动创建,用于保存最后的json) 28 | source_xml_root_path: 29 | VOC xml文件存放的根目录 30 | target_xml_root_path: 31 | coco xml存放的根目录 32 | phase: 33 | 状态:'train'或者'test' 34 | split: 35 | train和test图片的分界点数目 36 | 37 | ''' 38 | 39 | dataset = {'categories':[], 'images':[], 'annotations':[]} 40 | 41 | # 打开类别标签 42 | with open(os.path.join(root_path, 'classes.txt')) as f: 43 | classes = f.read().strip().split() 44 | 45 | # 建立类别标签和数字id的对应关系 46 | for i, cls in enumerate(classes, 1): 47 | dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'beverage'}) #mark 48 | 49 | # 读取images文件夹的图片名称 50 | pics = [f for f in os.listdir(os.path.join(root_path, 'JPEGImages'))] 51 | 52 | # 判断是建立训练集还是验证集 53 | if phase == 'train': 54 | pics = [line for i, line in enumerate(pics) if i <= split] 55 | elif phase == 'val': 56 | pics = [line for i, line in enumerate(pics) if i > split] 57 | 58 | print('---------------- start convert ---------------') 59 | bnd_id = 1 #初始为1 60 | for i, pic in enumerate(pics): 61 | # print('pic '+str(i+1)+'/'+str(len(pics))) 62 | xml_path = os.path.join(source_xml_root_path, pic[:-4]+'.xml') 63 | pic_path = os.path.join(root_path, 'JPEGImages/' + pic) 64 | # 用opencv读取图片,得到图像的宽和高 65 | im = cv2.imread(pic_path) 66 | height, width, _ = im.shape 67 | # 添加图像的信息到dataset中 68 | dataset['images'].append({'file_name': pic, 69 | 'id': i, 70 | 'width': width, 71 | 'height': height}) 72 | try: 73 | coords = parse_xml(xml_path) 74 | except: 75 | print(pic[:-4]+'.xml not exists~') 76 | continue 77 | for coord in coords: 78 | # x_min 79 | x1 = int(coord[0])-1 80 | x1 = max(x1, 0) 81 | # y_min 82 | y1 = int(coord[1])-1 83 | y1 = max(y1, 0) 84 | # x_max 85 | x2 = int(coord[2]) 86 | # y_max 87 | y2 = int(coord[3]) 88 | assert x1