# -*- coding: utf-8 -*-
# Flattened dump of the repository SSD_data_augment, restored to runnable form.
#
#   ├── README
#   ├── README.md
#   ├── Transform.py
#   ├── data
#   │   └── test.txt
#   ├── jaccard.py
#   ├── link.py
#   └── xmlSet.py
#
# /README:
#   https://raw.githubusercontent.com/lidc54/SSD_data_augment/cd22221f1c180b54b8102f3fbc49a94d27fbd246/README
#
# /README.md (translated):
#   SSD uses a rather distinctive data-augmentation scheme; its random crop in
#   particular differs from earlier approaches.  After reading the source code
#   I re-implemented it in Python and am recording it here.
#
#   Although it is still data augmentation, SSD's usage is a bit different,
#   especially the random crop.  SSD's method both zooms in and zooms out.
#
#   The random crop does not pay much attention to whether the main body of
#   the object falls inside the crop box.  For a picture of a person with the
#   limbs stretched out, the ground-truth box is large while the object fills
#   it only sparsely (a bit like a sparse matrix).
#
#   So the `jaccard` code re-implements this method (pure preprocessing).
"""SSD-style data augmentation: mirror, colour jitter, expand and random crop.

The sections below correspond to the original files Transform.py, jaccard.py,
link.py and xmlSet.py.  Annotations are nested dicts as produced by
``gotXMLInfo``: every entry that is a dict holding a ``'bndbox'`` dict with the
keys ``xmin``/``xmax``/``ymin``/``ymax`` is treated as one object box.
"""
import copy
import math
import os
import random
import sys
import xml.sax

import numpy as np

# Third-party packages that only the image I/O, colour-jitter, resize and
# plotting helpers need.  They are imported defensively so the pure geometry
# helpers stay importable on machines without OpenCV / matplotlib / Pillow.
try:
    import cv2
except ImportError:
    cv2 = None
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
try:
    from PIL import Image, ImageEnhance, ImageOps, ImageFile
except ImportError:
    Image = ImageEnhance = ImageOps = ImageFile = None
try:
    from scipy.misc import imread  # removed from SciPy 1.2 onwards
except ImportError:
    imread = None

# Order used for every box list in this module.
_COORD = ('xmin', 'xmax', 'ymin', 'ymax')


def _boxed_keys(objects):
    """Return the keys of *objects* whose value is a dict holding a 'bndbox'."""
    return [key for key, value in objects.items()
            if isinstance(value, dict) and 'bndbox' in value]


def _read_box(entry):
    """Return [xmin, xmax, ymin, ymax] of one annotation entry as ints."""
    return [int(entry['bndbox'][name]) for name in _COORD]


def _center_inside(box, region):
    """Tell whether the centre of *box* lies strictly inside *region*."""
    xmin, xmax, ymin, ymax = region
    center_x = (box[0] + box[1]) / 2.0
    center_y = (box[2] + box[3]) / 2.0
    return xmin < center_x < xmax and ymin < center_y < ymax


def _read_image(image_path):
    """Load an image as an array, falling back to Pillow without scipy.misc.imread."""
    if imread is not None:
        return imread(image_path)
    return np.array(Image.open(image_path).convert('RGB'))


# ---------------------------------------------------------------------------
# /Transform.py
#   mirror : applied with probability 0.5
#   distort: colour jitter
#   expand : place the image on a larger mean-valued canvas (zoom out)
#   crop   : random crop constrained by jaccard overlap
# ---------------------------------------------------------------------------
def transform(origin_image, objects):
    """Build the augmented variants of one annotated image.

    :param origin_image: image array of shape (height, width, channel)
    :param objects: annotation dict (see module docstring)
    :return: dict mapping a variant name ('expand', 'origin', optionally
             'mirror', plus the integer keys produced by ``corp_image``) to
             ``[whitened image, labels, un-whitened image]``
    """
    image = whiter(origin_image)
    trans_dict = {}
    sz = (300, 300)  # (width, height) -- in that order

    # 1. expand; whitening has to come after expand() because the canvas is
    #    filled with the raw mean pixel value.
    image_expand, lables_expand = expand(origin_image, objects, sz)
    trans_dict['expand'] = [whiter(image_expand), lables_expand, image_expand]

    # 2. plain resize of the original
    image, objects = resize_imgAnno(sz, image, objects)
    origin_image, _ = resize_imgAnno(sz, origin_image, objects)
    trans_dict['origin'] = [image.copy(), copy.deepcopy(objects), origin_image.copy()]

    # 3. horizontal mirror with probability 0.5
    if random.randint(1, 2) > 1:
        image = image[:, ::-1]
        origin_image = origin_image[:, ::-1]
        mirrot_anno(image, objects)
        trans_dict['mirror'] = [image.copy(), copy.deepcopy(objects), origin_image.copy()]

    # 4. colour distortion
    image = jitter(image)

    # 5. jaccard-constrained random crops.  Best effort: a failure is reported
    #    and the variants gathered so far are still returned.  update() is used
    #    because the crop keys are ints, which dict(a, **b) rejects on Python 3.
    try:
        trans_dict.update(corp_image(image, objects, sz, origin_image))
    except Exception as e:
        print("what's wrong: %s" % (e,))

    return trans_dict


def mirrot_anno(image, objects):
    """Mirror the x coordinates of every box in *objects* in place.

    :param image: the (already mirrored) image; only its width is read
    :param objects: annotation dict, modified in place
    """
    width = image.shape[1]
    for key in _boxed_keys(objects):
        xmin, xmax, _, _ = _read_box(objects[key])
        bndbox = objects[key]['bndbox']
        bndbox['xmin'] = width - xmax
        bndbox['xmax'] = width - xmin


def jitter(data):
    """Randomly distort colour, brightness and contrast of *data*.

    The image is mapped to 0-255 for PIL, enhanced with random factors
    (colour 0.0-3.0, brightness 0.5-1.0, contrast 1.0-2.0) and mapped back to
    its original per-channel value range.
    """
    image, min_max = transfer(data)
    image = Image.fromarray(image)
    for enhancer, low, high in ((ImageEnhance.Color, 0, 31),
                                (ImageEnhance.Brightness, 5, 11),
                                (ImageEnhance.Contrast, 10, 21)):
        random_factor = np.random.randint(low, high) / 10.
        image = enhancer(image).enhance(random_factor)
    return re_transfer(np.array(image), min_max)


def expand(image, objects, sz, ratio=3):
    """Zoom out: paste *image* at a random spot of a mean-valued canvas.

    If whitening is applied it must happen after this function, not before,
    because the canvas is filled with the raw mean pixel value.

    :param image: uint8 image of shape (height, width, channel)
    :param objects: annotation dict (not modified)
    :param sz: (width, height) of the returned image
    :param ratio: canvas size as a multiple of *sz*
    :return: ``[resized canvas, shifted and resized labels]``
    :raises ValueError: when the image does not fit on the canvas
    """
    height, width, channel = image.shape
    aug_width, aug_height = sz[0] * ratio, sz[1] * ratio
    if height > aug_height or width > aug_width:
        raise ValueError('image (%dx%d) is larger than the expand canvas (%dx%d)'
                         % (width, height, aug_width, aug_height))
    mean_value = [104, 117, 123]

    # numpy arrays are (rows, columns): height first, then width
    canvus = np.zeros((aug_height, aug_width, channel), dtype="uint8")
    for i in range(channel):
        canvus[:, :, i] = mean_value[i]

    h_off = random.randint(0, aug_height - height)
    w_off = random.randint(0, aug_width - width)
    canvus[h_off:h_off + height, w_off:w_off + width, :] = image

    # shift the labels by the paste offset
    new_objects = copy.deepcopy(objects)
    offsets = (w_off, w_off, h_off, h_off)
    for key in _boxed_keys(new_objects):
        coor = new_objects[key]['bndbox']
        for name, value, offset in zip(_COORD, _read_box(new_objects[key]), offsets):
            coor[name] = value + offset

    canvus, new_objects = resize_imgAnno(sz, canvus, new_objects)
    return [canvus, new_objects]


def whiter(image):
    """Standardise every channel to zero mean and unit variance.

    :param image: array of shape (height, width, channel)
    :return: float16 array of the same shape; a constant channel becomes all
             zeros instead of NaN
    """
    im = image.astype('float64')
    mean = im.mean(axis=(0, 1))
    std = im.std(axis=(0, 1))
    std = np.where(std == 0, 1.0, std)  # avoid 0/0 on flat channels
    return ((im - mean) / std).astype('float16')


# ---------------------------------------------------------------------------
# /jaccard.py
# ---------------------------------------------------------------------------
def transfer(image):
    """Stretch every channel of *image* to 0-255 so PIL can handle it.

    ``transfer`` and ``re_transfer`` bracket the PIL calls in
    ``resize_imgAnno`` and ``jitter``.

    :param image: array of shape (height, width, channel)
    :return: ``(uint8 image, [(min, max) per channel])``
    """
    data = np.empty(image.shape, dtype='uint8')
    min_max = []  # value range of every channel
    for i in range(image.shape[2]):
        channel = image[:, :, i]
        min_d, max_d = np.min(channel), np.max(channel)
        min_max.append((min_d, max_d))
        zone = max(float(max_d) - float(min_d), 0.1)  # guard flat channels
        data[:, :, i] = (channel.astype('float64') - float(min_d)) / zone * 255
    return data, min_max


def re_transfer(image, min_max):
    """Map a 0-255 image back to the per-channel ranges in *min_max*.

    :param image: array of shape (height, width, channel)
    :param min_max: ``[(min, max) per channel]`` as returned by ``transfer``
    :return: float16 array
    """
    data = image.astype('float64')
    for i in range(data.shape[2]):
        min_d, max_d = float(min_max[i][0]), float(min_max[i][1])
        channel = data[:, :, i]
        min_now = channel.min()
        zone = max(channel.max() - min_now, 0.1)
        data[:, :, i] = (channel - min_now) / zone * (max_d - min_d) + min_d
    return data.astype('float16')


def show_data(image, label):
    """Plot *image* with the boxes of *label* drawn on top (needs matplotlib)."""
    if image.dtype != np.uint8:
        data, _ = transfer(image)
    else:
        data = image
    colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist()
    print(colors[1])
    plt.imshow(data)

    AX = plt.gca()
    for key in _boxed_keys(label):
        xmin, xmax, ymin, ymax = _read_box(label[key])
        AX.add_patch(plt.Rectangle((xmin + 1, ymin + 1), xmax - xmin - 2, ymax - ymin - 2,
                                   fill=False, edgecolor=colors[random.randint(0, 20)],
                                   linewidth=2))
    plt.show()


def resize_imgAnno(sz, data, oj):
    """Resize an image and its boxes to *sz*.

    :param sz: target (width, height)
    :param data: array of shape (height, width, channel)
    :param oj: annotation dict (not modified)
    :return: ``(float16 image, rescaled copy of the annotations)``
    """
    objects = copy.deepcopy(oj)
    height, width, channel = data.shape
    top_width, top_height = sz
    source = [width, width, height, height]
    target = [top_width, top_width, top_height, top_height]
    for key in _boxed_keys(objects):
        bndbox = objects[key]['bndbox']
        for name, value, old, new in zip(_COORD, _read_box(objects[key]), source, target):
            bndbox[name] = int(1.0 * value / old * new)

    # PIL needs 0-255 data; the original value range is restored afterwards
    image, min_max = transfer(data)
    image = np.array(Image.fromarray(image).resize(sz))
    return re_transfer(image, min_max), objects


def lap(coor_random, coor):
    """Return the [xmin, xmax, ymin, ymax] intersection of two boxes.

    The result is not clamped: for disjoint boxes ``xmax < xmin`` and/or
    ``ymax < ymin``.
    """
    return [max(coor_random[0], coor[0]), min(coor_random[1], coor[1]),
            max(coor_random[2], coor[2]), min(coor_random[3], coor[3])]


def overlap(coor_random, coor):
    """Jaccard overlap (IoU) between a sampled box and a ground-truth box.

    Both boxes are ``[xmin, xmax, ymin, ymax]``.

    :return: ``(ratio, intersection box)``; the ratio is 0.0 for boxes that do
             not intersect
    """
    lap_box = lap(coor_random, coor)
    lap_width = lap_box[1] - lap_box[0]
    lap_height = lap_box[3] - lap_box[2]
    # A negative extent means "no intersection"; multiplying two negative
    # extents would otherwise fake a positive area.
    s_lap = lap_width * lap_height if lap_width > 0 and lap_height > 0 else 0
    s_random = (coor_random[1] - coor_random[0]) * (coor_random[3] - coor_random[2])
    s = (coor[1] - coor[0]) * (coor[3] - coor[2])
    union = s_random + s - s_lap
    ratio = 1.0 * s_lap / union if union > 0 else 0.0
    return ratio, lap_box


def random_box():
    """Sample a normalised crop box the way the SSD batch sampler does.

    The scale is drawn from [0.3, 1.0] and the aspect ratio (width / height)
    from [0.5, 2.0], narrowed to [scale**2, 1 / scale**2] so that the box fits
    into the unit square; the offset is then drawn uniformly.

    :return: ``(xmin, xmax, ymin, ymax)`` with all values in [0, 1]
    """
    aspcect_constraint = [0.5, 2.0]  # width / height
    scale_constraint = [0.3, 1.0]  # random box vs. origin

    scale = random.uniform(scale_constraint[0], scale_constraint[1])
    min_aspect = max(aspcect_constraint[0], scale ** 2)
    max_aspect = min(aspcect_constraint[1], 1.0 / scale ** 2)
    aspect = random.uniform(min_aspect, max_aspect)

    # min() only absorbs floating point overshoot at the bounds
    box_width = min(1.0, scale * aspect ** 0.5)
    box_height = min(1.0, scale / aspect ** 0.5)

    w_off = random.uniform(0.0, 1 - box_width)
    h_off = random.uniform(0.0, 1 - box_height)

    return w_off, w_off + box_width, h_off, h_off + box_height


def satisfy_constraint(objects, coord_random, min_jaccard):
    """Tell whether a sampled box is acceptable for at least one object.

    An object qualifies when its centre lies inside the sampled box and its
    jaccard overlap with the box exceeds *min_jaccard*.
    """
    for key in _boxed_keys(objects):
        object_box = _read_box(objects[key])
        if not _center_inside(object_box, coord_random):
            continue
        ratio, _ = overlap(coord_random, object_box)
        if ratio > min_jaccard:
            return True
    return False


def corp(image, objects_origin, coor):
    """Crop *image* and its labels to the box *coor*.

    Only objects whose centre lies inside the crop are kept; their boxes are
    clipped to the crop and expressed relative to its top-left corner.  Every
    other entry of the annotation dict is dropped.

    :param image: 2-D or 3-D array
    :param objects_origin: annotation dict (not modified)
    :param coor: integer ``[xmin, xmax, ymin, ymax]``
    :return: ``(cropped image, cropped annotations)``
    """
    objects = copy.deepcopy(objects_origin)
    xmin, xmax, ymin, ymax = coor
    data = image[ymin:ymax, xmin:xmax]

    top_left = [xmin, xmin, ymin, ymin]
    keep = set()
    for key in _boxed_keys(objects):
        object_box = _read_box(objects[key])
        if not _center_inside(object_box, coor):
            continue
        bndbox = objects[key]['bndbox']
        for name, value, offset in zip(_COORD, lap(coor, object_box), top_left):
            bndbox[name] = value - offset
        keep.add(key)

    # iterate over a snapshot: the dict shrinks while we walk it
    for key in list(objects.keys()):
        if key not in keep:
            objects.pop(key)

    return data, objects


def generate_batch_samples(image, objects, min_jaccard, max_trials=50):
    """Generate random crop boxes for *image*.

    :param image: array of shape (height, width, channel)
    :param objects: annotation dict
    :param min_jaccard: minimum overlap between a sampled box and an object
    :param max_trials: number of boxes to try
    :return: list of accepted ``[xmin, xmax, ymin, ymax]`` boxes in pixels.
             When *min_jaccard* is below 0.05 there is no constraint and the
             whole image is returned as one flat box instead.
    """
    height, width, channel = image.shape
    if min_jaccard < 0.05:
        return [0, width, 0, height] if max_trials > 0 else []

    size_img = [width, width, height, height]
    trials = []
    for _ in range(max_trials):
        coor_random = [int(x * y) for x, y in zip(random_box(), size_img)]
        if satisfy_constraint(objects, coor_random, min_jaccard):
            trials.append(coor_random)
    return trials


def corp_image(image, objects, sz, origin_image):
    """Produce one random crop per jaccard threshold.

    :param image: the whitened image
    :param objects: annotation dict
    :param sz: (width, height) every crop is resized to, e.g. 300*300
    :param origin_image: the un-whitened image, cropped identically
    :return: dict keyed by the threshold index (0 = untouched copy, 1..5 =
             thresholds 0.1/0.3/0.5/0.7/0.9) with
             ``[image, labels, origin image]`` values; thresholds for which no
             box was found, or whose crop failed, are left out
    """
    min_jaccard = [0, 1, 3, 5, 7, 9]
    all_crop = {}
    for i, tenth in enumerate(min_jaccard):
        if tenth == 0:
            all_crop[i] = [image.copy(), copy.deepcopy(objects), origin_image.copy()]
            continue

        trials = generate_batch_samples(image, objects, tenth / 10.0)
        if not trials:
            continue
        if isinstance(trials[0], list):
            coorAnno = trials[random.randint(0, len(trials) - 1)]
        else:  # a single flat box
            coorAnno = trials

        try:
            crop_imageAnno = corp(image, objects, coorAnno)
            crop_imageAnno_origin = corp(origin_image, objects, coorAnno)
            data_crop, labels_crop = resize_imgAnno(sz, *crop_imageAnno)
            origin_crop, _ = resize_imgAnno(sz, *crop_imageAnno_origin)
            all_crop[i] = [data_crop, labels_crop, origin_crop]
        except Exception as e:
            # best effort: e.g. a degenerate zero-area crop cannot be resized
            print('--here: %s' % (e,))

    return all_crop


def show(imgs, point):
    """Plot a photo and its landmark points, given as a flat [x0, y0, x1, ...] list.

    The points are drawn onto *imgs* itself (the slice below is a view).
    """
    img = imgs[:, :]
    for i in range(len(point) // 2):
        cv2.circle(img, (int(point[i * 2]), int(point[i * 2 + 1])), 1, (255, 250, 0), 1)
    plt.imshow(img)
    plt.show()


def ImgRotate(srcImg, degree):
    """Rotate a photo by *degree* on a square canvas sized by its diagonal."""
    h, w, c = srcImg.shape
    diaLength = int(math.sqrt(h ** 2 + w ** 2))
    tempImg = np.zeros((diaLength, diaLength, c), dtype=srcImg.dtype)
    tx = diaLength // 2 - w // 2  # left
    ty = diaLength // 2 - h // 2  # top
    tempImg[ty:ty + h, tx:tx + w] = srcImg
    matRotation = cv2.getRotationMatrix2D((diaLength // 2, diaLength // 2), degree, 1)
    imgRotation = cv2.warpAffine(tempImg, matRotation, (diaLength, diaLength),
                                 borderValue=(0, 0, 0))
    return imgRotation


def getPointAffinedPos(x, y, h, w, degree):
    """Return where the point (x, y) of an h*w photo lands after ``ImgRotate``."""
    diaLength = math.sqrt(h ** 2 + w ** 2)
    center_x, center_y = diaLength / 2, diaLength / 2
    x -= w / 2.0
    y -= h / 2.0
    angle = degree * np.pi / 180.0

    dst_x = round(x * math.cos(angle) + y * math.sin(angle) + center_x)
    dst_y = round(-x * math.sin(angle) + y * math.cos(angle) + center_y)
    return dst_x, dst_y


# ---------------------------------------------------------------------------
# /link.py
# ---------------------------------------------------------------------------
def linkImgAnn(cp):
    """Unfinished helper: derives directory names for an (image, annotation) pair.

    Nothing is returned or written yet.
    """
    data, ann = cp
    dataDir = '/'.join(data.split('/')[0:2])
    annDir = '/'.join(ann.split('/')[0:2])
    extend = '_Aug'


def readFile(testFile='/home/flag54/Downloads/caffe-ssd/data/VOC0712/trainval.txt'):
    """Yield the whitespace-separated fields of every line of *testFile*.

    Each line is expected to hold an image path and an annotation path.
    """
    with open(testFile) as handle:
        for bond in handle:
            yield bond.split()


def readAnnoImage():
    """Run ``mainFunction`` on every trainval pair and print the label names seen."""
    prex = "/home/flag54/Downloads/caffe-ssd/data/VOCdevkit/"

    xx = (1,)
    for cp in readFile():
        try:
            img, anno = cp
            sys.stdout.write(prex + img + ' ')  # no newline: the count follows
            if img == 'VOC2012/JPEGImages/2009_002851.jpg':
                print('oo')
            data, labels, origin_data = mainFunction(prex + img, prex + anno)
            xx += tuple(labels[key]['name'] for key in labels.keys())
        except Exception as e:
            # best effort: report the broken sample and go on with the next one
            print('line: %s' % (e,))
    print(xx)


# ---------------------------------------------------------------------------
# /xmlSet.py
# ---------------------------------------------------------------------------
class xmlReader(xml.sax.ContentHandler):
    """SAX handler collecting the ``<object>`` elements of an annotation file.

    Nested elements become nested dicts in ``contents``; a repeated tag name
    gets a numeric suffix so that no entry is overwritten.

    NOTE(review): the nesting of a few statements was reconstructed from a
    flattened listing (marked below) -- confirm against the original file.
    """

    def __init__(self):
        self.contents = {}  # everything collected from the xml
        self.backend = 0  # numeric suffix used to de-duplicate key names
        self.tag = ''  # name of the current tag
        self.ParentTag = []  # keys of the currently open container elements
        self.backend_str = []  # suffix that was given to each open container
        self.useful_now = False  # True while inside an <object> element
        self.is_common = ''  # last closed tag; guards characters() after a close

    def startElement(self, name, attrs):
        """Remember the opened tag and switch collecting on for ``<object>``."""
        self.tag = name
        self.is_common = ''
        if name == "object":
            self.useful_now = True

    def endElement(self, name):
        """Close the current container, if *name* (plus its suffix) is one."""
        if len(self.backend_str):
            name += self.backend_str[-1]
        # assumes tags are properly paired
        if name in self.ParentTag:
            self.ParentTag.pop()
            self.backend_str.pop()
        # NOTE(review): placed at method level -- confirm
        if name == "object":
            self.useful_now = False

        self.tag = name  # makes characters() ignore text following the close
        self.is_common = name

    def characters(self, content):
        """Store leaf text, or open a nested dict when a lone newline follows a tag."""
        if not self.useful_now:
            return
        # indentation text, or text after a tag that has already been closed
        if (len(content) > 0 and content[0] == '\t') or self.is_common == self.tag:
            return

        tmp_dict = self.contents
        for i in self.ParentTag:  # walk down to the innermost open container
            tmp_dict = tmp_dict[i]

        # pick a key that does not clash with an existing, non-empty entry
        key_name = self.tag
        if tmp_dict.get(key_name):
            key_name += str(self.backend)

        if content == '\n':  # the tag holds child elements
            tmp_dict[key_name] = {}
            self.ParentTag.append(key_name)
            if key_name == self.tag:
                self.backend_str.append('')
            else:
                self.backend_str.append(str(self.backend))
                # NOTE(review): increment placed in the suffix branch -- confirm
                self.backend += 1
        else:
            tmp_dict[key_name] = content


def gotXMLInfo(info):
    """Parse the annotation file *info* and return the collected contents dict."""
    parser = xml.sax.make_parser()
    # turn off namespaces
    parser.setFeature(xml.sax.handler.feature_namespaces, 0)

    Handler = xmlReader()
    parser.setContentHandler(Handler)
    parser.parse(info)
    return Handler.contents


def remove_unbox(objects):
    """Drop, in place, every entry of *objects* that carries no 'bndbox'."""
    keep = set(_boxed_keys(objects))
    for key in list(objects.keys()):
        if key not in keep:
            objects.pop(key)


def mainFunction(image_path, anno_path):
    """Augment one annotated image and return one randomly chosen variant.

    :return: ``[image, labels, un-whitened image]`` as built by ``transform``
    """
    objects = gotXMLInfo(anno_path)
    remove_unbox(objects)
    image = _read_image(image_path)
    trans_dict = transform(image, objects)
    keys = list(trans_dict.keys())
    print('  %d' % len(keys))
    idx = random.randint(0, len(keys) - 1)

    return trans_dict[keys[idx]]


def test():
    """Smoke test on one hard-coded sample of the author's machine."""
    info = "/home/flag54/Documents/dataSetAugument/data/anno/009653.xml"
    photo = "/home/flag54/Documents/dataSetAugument/data/dataSet/009653.jpg"
    print(os.curdir)

    data, newoj, origin_data = mainFunction(photo, info)
    print('o')
    return data, newoj


# Not a unit test: keep pytest/nose from collecting it by name.
test.__test__ = False


if __name__ == "__main__":
    # Entry points of the two scripts in the dump: link.py walks the VOC
    # trainval list, xmlSet.py runs the single-sample smoke test.
    readAnnoImage()
    test()