├── README.md
├── create_data_lists.py
├── data_crop.py
├── image
│   └── dota.png
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Remote-Sensing-Object-Detection-with-Oriented-Bouding-Box
2 | Some object detection code for the DOTA dataset
3 | 
4 | The dataset used here is DOTA, which contains 15 categories: `'small-vehicle', 'plane', 'large-vehicle', 'ship', 'harbor', 'tennis-court', 'ground-track-field', 'soccer-ball-field', 'baseball-diamond', 'swimming-pool', 'roundabout', 'basketball-court', 'storage-tank', 'bridge', 'helicopter'`
5 | 
6 | ![Example images from the DOTA dataset](https://github.com/weihancug/Remote-Sensing-Object-Detection-with-Oriented-Bouding-Box/blob/master/image/dota.png)
7 | 
8 | 1 First, use `data_crop.py` to crop the DOTA images into patches of a trainable size, e.g. 1000x1000
9 | 
10 | 2 Next, use `create_data_lists.py` to create JSON files listing all training and test set files, which are read by the model
11 | 
12 | 3 Model training: two sets of arguments are used. The first is the interpreter options `-m torch.distributed.launch --nproc_per_node=2`
13 | The second is the script arguments: `--skip-test --config-file config_path DATALOADER.2 OUTPUT_DIR output_path`
14 | 
15 | ```
16 | python -m torch.distributed.launch --nproc_per_node=2 train_net.py --skip-test --config-file ../configs/fcos/orientedfcos_R50_1x.yaml DATALOADER.2 OUTPUT_DIR ../training_dir/orientedfcos_R_50_FPN_1x
17 | ```
18 | 4 Model testing:
19 | 
--------------------------------------------------------------------------------
/create_data_lists.py:
--------------------------------------------------------------------------------
1 | from utils import create_data_lists,create_data_list_DOTA
2 | 
3 | if __name__ == '__main__':
4 |     '''
5 |     # For processing VOC data
6 |     create_data_lists(voc07_path='E:/object detection competetion/VOC2007/VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007',
7 |                       voc12_path='E:/object detection competetion/VOC2007/VOCtrainval_11-May-2012/VOCdevkit/VOC2012/',
8 |                       output_folder='E:/object detection competetion/VOC2007/VOC-output/')
9 |     '''
10 |     create_data_list_DOTA(dota_train='/home/han/Desktop/DOTA/dataset_crop/train/',
11 |                           dota_test='/home/han/Desktop/DOTA/dataset_crop/val/',
12 |                           output_folder='/home/han/Desktop/DOTA/output/')
13 | 
--------------------------------------------------------------------------------
/data_crop.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import scipy.misc as misc
4 | from xml.dom.minidom import Document
5 | import numpy as np
6 | import copy, cv2
7 | 
8 | 
9 | raw_data = '/home/han/Desktop/DOTA/dataset/train/'
10 | save_dir = '/home/han/Desktop/DOTA/dataset_crop/train/'
11 | # raw_data = '/home/han/Desktop/DOTA/dataset/val/'
12 | # save_dir = '/home/han/Desktop/DOTA/dataset_crop/val/'
13 | # DOTA
14 | class_list = ['small-vehicle', 'plane', 'large-vehicle', 'ship', 'harbor', 'tennis-court',
15 |               'ground-track-field', 'soccer-ball-field', 'baseball-diamond', 'swimming-pool', 'roundabout',
16 |               'basketball-court', 'storage-tank', 'bridge', 'helicopter']
17 | 
18 | if not os.path.exists(save_dir):
19 |     os.mkdir(save_dir)
20 | if not os.path.exists(os.path.join(save_dir,'images')):
21 |     os.mkdir(os.path.join(save_dir,'images'))
22 | if not os.path.exists(os.path.join(save_dir,'labelTxt')):
23 |     os.mkdir(os.path.join(save_dir,'labelTxt'))
24 | 
25 | def save_to_xml(save_path, im_height, im_width, objects_axis, label_name):
26 |     im_depth = 0
27 |     object_num = len(objects_axis)
28 |     doc = Document()
29 | 
30 |     annotation = doc.createElement('annotation')
31 |     doc.appendChild(annotation)
32 | 
33 |     folder = doc.createElement('folder')
34 |     folder_name = doc.createTextNode('VOC2007')
35 | 
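    # Note: the elements below reproduce the PASCAL VOC 2007 annotation skeleton (folder, filename,
    # source, size, segmented) before the per-object entries are written. Unlike standard VOC,
    # each object's <bndbox> stores the four corners of the oriented box (x0, y0, ..., x3, y3)
    # rather than xmin/ymin/xmax/ymax.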
folder.appendChild(folder_name) 36 | annotation.appendChild(folder) 37 | 38 | filename = doc.createElement('filename') 39 | filename_name = doc.createTextNode('000024.jpg') 40 | filename.appendChild(filename_name) 41 | annotation.appendChild(filename) 42 | 43 | source = doc.createElement('source') 44 | annotation.appendChild(source) 45 | 46 | database = doc.createElement('database') 47 | database.appendChild(doc.createTextNode('The VOC2007 Database')) 48 | source.appendChild(database) 49 | 50 | annotation_s = doc.createElement('annotation') 51 | annotation_s.appendChild(doc.createTextNode('PASCAL VOC2007')) 52 | source.appendChild(annotation_s) 53 | 54 | image = doc.createElement('image') 55 | image.appendChild(doc.createTextNode('flickr')) 56 | source.appendChild(image) 57 | 58 | flickrid = doc.createElement('flickrid') 59 | flickrid.appendChild(doc.createTextNode('322409915')) 60 | source.appendChild(flickrid) 61 | 62 | size = doc.createElement('size') 63 | annotation.appendChild(size) 64 | width = doc.createElement('width') 65 | width.appendChild(doc.createTextNode(str(im_width))) 66 | height = doc.createElement('height') 67 | height.appendChild(doc.createTextNode(str(im_height))) 68 | depth = doc.createElement('depth') 69 | depth.appendChild(doc.createTextNode(str(im_depth))) 70 | size.appendChild(width) 71 | size.appendChild(height) 72 | size.appendChild(depth) 73 | segmented = doc.createElement('segmented') 74 | segmented.appendChild(doc.createTextNode('0')) 75 | annotation.appendChild(segmented) 76 | for i in range(object_num): 77 | objects = doc.createElement('object') 78 | annotation.appendChild(objects) 79 | object_name = doc.createElement('name') 80 | object_name.appendChild(doc.createTextNode(label_name[int(objects_axis[i][8])])) 81 | objects.appendChild(object_name) 82 | pose = doc.createElement('pose') 83 | pose.appendChild(doc.createTextNode('Unspecified')) 84 | objects.appendChild(pose) 85 | truncated = doc.createElement('truncated') 86 | truncated.appendChild(doc.createTextNode('1')) 87 | objects.appendChild(truncated) 88 | difficult = doc.createElement('difficult') 89 | difficult.appendChild(doc.createTextNode(str((objects_axis[i][9])))) 90 | objects.appendChild(difficult) 91 | bndbox = doc.createElement('bndbox') 92 | objects.appendChild(bndbox) 93 | 94 | x0 = doc.createElement('x0') 95 | x0.appendChild(doc.createTextNode(str((objects_axis[i][0])))) 96 | bndbox.appendChild(x0) 97 | y0 = doc.createElement('y0') 98 | y0.appendChild(doc.createTextNode(str((objects_axis[i][1])))) 99 | bndbox.appendChild(y0) 100 | 101 | x1 = doc.createElement('x1') 102 | x1.appendChild(doc.createTextNode(str((objects_axis[i][2])))) 103 | bndbox.appendChild(x1) 104 | y1 = doc.createElement('y1') 105 | y1.appendChild(doc.createTextNode(str((objects_axis[i][3])))) 106 | bndbox.appendChild(y1) 107 | 108 | x2 = doc.createElement('x2') 109 | x2.appendChild(doc.createTextNode(str((objects_axis[i][4])))) 110 | bndbox.appendChild(x2) 111 | y2 = doc.createElement('y2') 112 | y2.appendChild(doc.createTextNode(str((objects_axis[i][5])))) 113 | bndbox.appendChild(y2) 114 | 115 | x3 = doc.createElement('x3') 116 | x3.appendChild(doc.createTextNode(str((objects_axis[i][6])))) 117 | bndbox.appendChild(x3) 118 | y3 = doc.createElement('y3') 119 | y3.appendChild(doc.createTextNode(str((objects_axis[i][7])))) 120 | bndbox.appendChild(y3) 121 | 122 | f = open(save_path,'w') 123 | f.write(doc.toprettyxml(indent = '')) 124 | f.close() 125 | 126 | 127 | def format_label(txt_list): 128 | format_data = [] 129 | for i 
in txt_list[0:]: 130 | if len(i.split(' ')) != 10: 131 | continue 132 | format_data.append( 133 | [int(float(xy)) for xy in i.split(' ')[:8]] + [class_list.index(i.split(' ')[8])] + [int(i.split(' ')[9])] 134 | # {'x0': int(i.split(' ')[0]), 135 | # 'x1': int(i.split(' ')[2]), 136 | # 'x2': int(i.split(' ')[4]), 137 | # 'x3': int(i.split(' ')[6]), 138 | # 'y1': int(i.split(' ')[1]), 139 | # 'y2': int(i.split(' ')[3]), 140 | # 'y3': int(i.split(' ')[5]), 141 | # 'y4': int(i.split(' ')[7]), 142 | # 'class': class_list.index(i.split(' ')[8]) if i.split(' ')[8] in class_list else 0, 143 | # 'difficulty': int(i.split(' ')[9])} 144 | ) 145 | if i.split(' ')[8] not in class_list : 146 | print ('warning found a new label :', i.split(' ')[8]) 147 | exit() 148 | return np.array(format_data) 149 | 150 | def clip_image(file_idx, image, boxes_all, width, height): 151 | # print ('image shape', image.shape) 152 | if len(boxes_all) > 0: 153 | shape = image.shape 154 | for start_h in range(0, shape[0], 256): 155 | for start_w in range(0, shape[1], 256): 156 | boxes = copy.deepcopy(boxes_all) 157 | box = np.zeros_like(boxes_all) 158 | start_h_new = start_h 159 | start_w_new = start_w 160 | if start_h + height > shape[0]: 161 | start_h_new = shape[0] - height 162 | if start_w + width > shape[1]: 163 | start_w_new = shape[1] - width 164 | top_left_row = max(start_h_new, 0) 165 | top_left_col = max(start_w_new, 0) 166 | bottom_right_row = min(start_h + height, shape[0]) 167 | bottom_right_col = min(start_w + width, shape[1]) 168 | 169 | 170 | subImage = image[top_left_row:bottom_right_row, top_left_col: bottom_right_col] 171 | 172 | box[:, 0] = boxes[:, 0] - top_left_col 173 | box[:, 2] = boxes[:, 2] - top_left_col 174 | box[:, 4] = boxes[:, 4] - top_left_col 175 | box[:, 6] = boxes[:, 6] - top_left_col 176 | 177 | box[:, 1] = boxes[:, 1] - top_left_row 178 | box[:, 3] = boxes[:, 3] - top_left_row 179 | box[:, 5] = boxes[:, 5] - top_left_row 180 | box[:, 7] = boxes[:, 7] - top_left_row 181 | box[:, 8] = boxes[:, 8] 182 | box[:, 9] = boxes[:, 9] 183 | center_y = 0.25*(box[:, 1] + box[:, 3] + box[:, 5] + box[:, 7]) 184 | center_x = 0.25*(box[:, 0] + box[:, 2] + box[:, 4] + box[:, 6]) 185 | # print('center_y', center_y) 186 | # print('center_x', center_x) 187 | # print ('boxes', boxes) 188 | # print ('boxes_all', boxes_all) 189 | # print ('top_left_col', top_left_col, 'top_left_row', top_left_row) 190 | 191 | cond1 = np.intersect1d(np.where(center_y[:] >=0 )[0], np.where(center_x[:] >=0 )[0]) 192 | cond2 = np.intersect1d(np.where(center_y[:] <= (bottom_right_row - top_left_row))[0], 193 | np.where(center_x[:] <= (bottom_right_col - top_left_col))[0]) 194 | idx = np.intersect1d(cond1, cond2) 195 | # idx = np.where(center_y[:]>=0 and center_x[:]>=0 and center_y[:] <= (bottom_right_row - top_left_row) and center_x[:] <= (bottom_right_col - top_left_col))[0] 196 | # save_path, im_width, im_height, objects_axis, label_name 197 | if len(idx) > 0: 198 | xml = os.path.join(save_dir, 'labelTxt', "%s_%04d_%04d.xml" % (file_idx, top_left_row, top_left_col)) 199 | save_to_xml(xml, subImage.shape[0], subImage.shape[1], box[idx, :], class_list) 200 | # print ('save xml : ', xml) 201 | if subImage.shape[0] > 5 and subImage.shape[1] >5: 202 | img = os.path.join(save_dir, 'images', "%s_%04d_%04d.png" % (file_idx, top_left_row, top_left_col)) 203 | cv2.imwrite(img, subImage) 204 | 205 | raw_images_dir = os.path.join(raw_data, 'images') 206 | raw_label_dir = os.path.join(raw_data, 'labelTxt') 207 | 208 | 209 | images = [i for i in 
os.listdir(raw_images_dir) if 'png' in i] 210 | labels = [i for i in os.listdir(raw_label_dir) if 'txt' in i] 211 | 212 | print ('find image', len(images)) 213 | print ('find label', len(labels)) 214 | 215 | min_length = 1e10 216 | max_length = 1 217 | 218 | for idx, img in enumerate(images): 219 | # img = 'P1524.png' 220 | print (idx, 'read image', img) 221 | #img_data = misc.imread(os.path.join(raw_images_dir, img)) 222 | img_data = cv2.imread(os.path.join(raw_images_dir, img)) 223 | 224 | txt_data = open(os.path.join(raw_label_dir, 225 | img.replace('png', 'txt')), 226 | 'r').readlines() 227 | 228 | box = format_label(txt_data) 229 | clip_image(img.strip('.png'), img_data, box, 1000, 1000) 230 | 231 | 232 | 233 | # rm train/images/* && rm train/labeltxt/* 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | -------------------------------------------------------------------------------- /image/dota.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihancug/Remote-Sensing-Object-Detection-with-Oriented-Bouding-Box/a31f40c7e2acf751a0827348fafaff0b870bd61c/image/dota.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import random 5 | import xml.etree.ElementTree as ET 6 | import torchvision.transforms.functional as FT 7 | import torch 8 | from torch.utils.data import Dataset 9 | import json 10 | import os 11 | from PIL import Image 12 | import random 13 | import torchvision.transforms.functional as FT 14 | import math 15 | import numpy as np 16 | from shapely.geometry import Polygon, MultiPoint 17 | import shapely 18 | import pandas as pd 19 | 20 | #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | device = torch.device("cuda:0,1" if torch.cuda.is_available() else "cpu") 23 | # Label map 24 | #plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, 25 | # large vehicle, small vehicle, helicopter, roundabout, soccer ball field and swimming pool 26 | #voc_labels = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 27 | # 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') 28 | # voc_labels = ('small-vehicle', 'plane', 'large-vehicle', 'ship', 'harbor', 'tennis-court', 29 | # 'ground-track-field', 'soccer-ball-field', 'baseball-diamond', 'swimming-pool', 'roundabout', 30 | # 'basketball-court', 'storage-tank', 'bridge', 'helicopter') 31 | voc_labels = ('0', '1', '2', '3', '4', '5', 32 | '6', '7', '8', '9', '10', 33 | '11', '12', '13', '14') 34 | label_map = {k: v + 1 for v, k in enumerate(voc_labels)} 35 | label_map['background'] = 0 36 | rev_label_map = {v: k for k, v in label_map.items()} # Inverse mapping 37 | 38 | # Color map for bounding boxes of detected objects from https://sashat.me/2017/01/11/list-of-20-simple-distinct-colors/ 39 | distinct_colors = ['#e6194b', '#3cb44b', '#ffe119', '#0082c8', '#f58231', '#911eb4', '#46f0f0', '#f032e6', 40 | '#d2f53c', '#fabebe', '#008080', '#000080', '#aa6e28', 41 | '#fffac8', '#800000', '#aaffc3', '#808000','#808900', 42 | '#ffd8b1', '#e6beff', '#808080', '#FFFFFF'] 43 | label_color_map = {k: distinct_colors[i] for i, k in enumerate(label_map.keys())} 44 | class_names_print =[] 45 | img_extension = 
'.png' 46 | def parse_annotation(annotation_path): 47 | tree = ET.parse(annotation_path) 48 | root = tree.getroot() 49 | boxes = list() 50 | labels = list() 51 | difficulties = list() 52 | for object in root.iter('object'): 53 | difficult = int(object.find('difficult').text == '1') 54 | label = object.find('name').text.lower().strip() 55 | ''' 56 | #collect class name for DOTA 57 | if class_names_print == None or label not in class_names_print: 58 | class_names_print.append(label) 59 | ''' 60 | if label not in label_map: 61 | continue 62 | bbox = object.find('bndbox') 63 | ##### 64 | #VOC# 65 | #### 66 | # xmin = int(bbox.find('xmin').text) - 1 67 | # ymin = int(bbox.find('ymin').text) - 1 68 | # xmax = int(bbox.find('xmax').text) - 1 69 | # ymax = int(bbox.find('ymax').text) - 1 70 | # boxes.append([xmin, ymin, xmax, ymax]) 71 | #DOTA 72 | x0 = int(bbox.find('x0').text) - 1 73 | x1 = int(bbox.find('x1').text) - 1 74 | x2 = int(bbox.find('x2').text) - 1 75 | x3 = int(bbox.find('x3').text) - 1 76 | y0 = int(bbox.find('y0').text) - 1 77 | y1 = int(bbox.find('y1').text) - 1 78 | y2 = int(bbox.find('y2').text) - 1 79 | y3 = int(bbox.find('y3').text) - 1 80 | xmin = min(x1, x2, x3, x0) 81 | xmax = max(x1, x2, x3, x0) 82 | ymin = min(y1, y2, y3, y0) 83 | ymax = max(y1, y2, y3, y0) 84 | #boxes.append([xmin, ymin, xmax, ymax]) 85 | boxes.append([x0, y0, x1, y1, x2,y2,x3,y3]) 86 | labels.append(label_map[label]) 87 | difficulties.append(difficult) 88 | return {'boxes': boxes, 'labels': labels, 'difficulties': difficulties} 89 | 90 | 91 | def create_data_list_DOTA(dota_train,output_folder,dota_test=None): 92 | ''' 93 | 用于生成DOTA数据集的训练集和测试机json文件 94 | :param dota_train: 训练集目录 格式为: images, labeltext 95 | :param dota_test: 测试集目录 格式为: images,labeltext 96 | :param output_folder: 用于存放训练集和测试机生成的各种json 数据 97 | :return: 98 | ''' 99 | 100 | print ('the processing dataset is: DOTA \n') 101 | print ("training dir is : {} \n".format(dota_train)) 102 | print("test dir is : {} \n".format(dota_test)) 103 | print("json file output is : {} \n".format(output_folder)) 104 | train_images = list() 105 | train_objects =list() 106 | n_objects_train = 0 107 | ids_train = list() 108 | #training_data preparation 109 | for root, dirs, files in os.walk(os.path.join(dota_train,"images")): 110 | for name in files: 111 | name = name.replace('.png','') 112 | ids_train.append(name) 113 | for id in ids_train: 114 | # Parse annotation's XML file 115 | objects = parse_annotation(os.path.join(dota_train, 'labelTxt', id + '.xml')) 116 | if len(objects) == 0: 117 | continue 118 | n_objects_train += len(objects) 119 | train_objects.append(objects) 120 | train_images.append(os.path.join(dota_train, 'images', id + '.tif')) 121 | 122 | assert len(train_objects) == len(train_images) 123 | #print (class_names_print) 124 | with open(os.path.join(output_folder, 'TRAIN_images.json'), 'w') as j: 125 | json.dump(train_images, j) 126 | with open(os.path.join(output_folder,'list'),'w') as f: 127 | for item in train_images: 128 | f.write("%s\n" % item) 129 | with open(os.path.join(output_folder, 'TRAIN_objects.json'), 'w') as j: 130 | json.dump(train_objects, j) 131 | with open(os.path.join(output_folder, 'label_map.json'), 'w') as j: 132 | json.dump(label_map, j) # save label map too 133 | 134 | print('\n There are %d training images containing a total of %d objects. Files have been saved to %s.' 
% ( 135 | len(train_images), n_objects_train, output_folder)) 136 | 137 | #val_dir 138 | if dota_test is not None: 139 | ids_val = list() 140 | test_images = list() 141 | test_objects = list() 142 | n_objects_val = 0 143 | for root, dirs, files in os.walk(os.path.join(dota_test,"images")): 144 | for name in files: 145 | name = name.replace(img_extension,'') 146 | ids_val.append(name) 147 | for id in ids_val: 148 | # Parse annotation's XML file 149 | objects = parse_annotation(os.path.join(dota_test, 'labelTxt', id + '.xml')) 150 | if len(objects) == 0: 151 | continue 152 | n_objects_val += len(objects) 153 | test_objects.append(objects) 154 | test_images.append(os.path.join(dota_test, 'images', id + '.tif')) 155 | assert len(test_objects) == len(test_images) 156 | 157 | # Save to file 158 | with open(os.path.join(output_folder, 'TEST_images.json'), 'w') as j: 159 | json.dump(test_images, j) 160 | with open(os.path.join(output_folder, 'TEST_objects.json'), 'w') as j: 161 | json.dump(test_objects, j) 162 | 163 | print('\nThere are %d validation images containing a total of %d objects. Files have been saved to %s.' % ( 164 | len(test_images), n_objects_val, os.path.abspath(output_folder))) 165 | 166 | def create_data_lists(voc07_path, voc12_path, output_folder): 167 | """ 168 | Create lists of images, the bounding boxes and labels of the objects in these images, and save these to file. 169 | 170 | :param voc07_path: path to the 'VOC2007' folder 171 | :param voc12_path: path to the 'VOC2012' folder 172 | :param output_folder: folder where the JSONs must be saved 173 | """ 174 | #voc07_path = os.path.abspath(voc07_path) 175 | #voc12_path = os.path.abspath(voc12_path) 176 | train_images = list() 177 | train_objects = list() 178 | n_objects = 0 179 | 180 | # Training data 181 | for path in [voc07_path, voc12_path]: 182 | # Find IDs of images in training data 183 | with open(os.path.join(path, 'ImageSets/Main/trainval.txt')) as f: 184 | ids = f.read().splitlines() 185 | 186 | for id in ids: 187 | # Parse annotation's XML file 188 | objects = parse_annotation(os.path.join(path, 'Annotations', id + '.xml')) 189 | print ('objects is : {}'.format(len(objects))) 190 | if len(objects) == 0: 191 | continue 192 | n_objects += len(objects) 193 | train_objects.append(objects) 194 | train_images.append(os.path.join(path, 'JPEGImages', id + '.jpg')) 195 | 196 | assert len(train_objects) == len(train_images) 197 | 198 | # Save to file 199 | with open(os.path.join(output_folder, 'TRAIN_images.json'), 'w') as j: 200 | json.dump(train_images, j) 201 | with open(os.path.join(output_folder, 'TRAIN_objects.json'), 'w') as j: 202 | json.dump(train_objects, j) 203 | with open(os.path.join(output_folder, 'label_map.json'), 'w') as j: 204 | json.dump(label_map, j) # save label map too 205 | 206 | print('\nThere are %d training images containing a total of %d objects. Files have been saved to %s.' 
% ( 207 | len(train_images), n_objects, output_folder)) 208 | 209 | # Validation data 210 | test_images = list() 211 | test_objects = list() 212 | n_objects = 0 213 | 214 | # Find IDs of images in validation data 215 | with open(os.path.join(voc07_path, 'ImageSets/Main/test.txt')) as f: 216 | ids = f.read().splitlines() 217 | 218 | for id in ids: 219 | # Parse annotation's XML file 220 | objects = parse_annotation(os.path.join(voc07_path, 'Annotations', id + '.xml')) 221 | if len(objects) == 0: 222 | continue 223 | test_objects.append(objects) 224 | n_objects += len(objects) 225 | test_images.append(os.path.join(voc07_path, 'JPEGImages', id + '.jpg')) 226 | 227 | assert len(test_objects) == len(test_images) 228 | 229 | # Save to file 230 | with open(os.path.join(output_folder, 'TEST_images.json'), 'w') as j: 231 | json.dump(test_images, j) 232 | with open(os.path.join(output_folder, 'TEST_objects.json'), 'w') as j: 233 | json.dump(test_objects, j) 234 | 235 | print('\nThere are %d validation images containing a total of %d objects. Files have been saved to %s.' % ( 236 | len(test_images), n_objects, os.path.abspath(output_folder))) 237 | 238 | 239 | def decimate(tensor, m): 240 | """ 241 | Decimate a tensor by a factor 'm', i.e. downsample by keeping every 'm'th value. 242 | 243 | This is used when we convert FC layers to equivalent Convolutional layers, BUT of a smaller size. 244 | 245 | :param tensor: tensor to be decimated 246 | :param m: list of decimation factors for each dimension of the tensor; None if not to be decimated along a dimension 247 | :return: decimated tensor 248 | """ 249 | assert tensor.dim() == len(m) 250 | for d in range(tensor.dim()): 251 | if m[d] is not None: 252 | tensor = tensor.index_select(dim=d, 253 | index=torch.arange(start=0, end=tensor.size(d), step=m[d]).long()) 254 | 255 | return tensor 256 | 257 | 258 | def calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties): 259 | """ 260 | Calculate the Mean Average Precision (mAP) of detected objects. 261 | 262 | See https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173 for an explanation 263 | 264 | :param det_boxes: list of tensors, one tensor for each image containing detected objects' bounding boxes 265 | :param det_labels: list of tensors, one tensor for each image containing detected objects' labels 266 | :param det_scores: list of tensors, one tensor for each image containing detected objects' labels' scores 267 | :param true_boxes: list of tensors, one tensor for each image containing actual objects' bounding boxes 268 | :param true_labels: list of tensors, one tensor for each image containing actual objects' labels 269 | :param true_difficulties: list of tensors, one tensor for each image containing actual objects' difficulty (0 or 1) 270 | :return: list of average precisions for all classes, mean average precision (mAP) 271 | """ 272 | assert len(det_boxes) == len(det_labels) == len(det_scores) == len(true_boxes) == len( 273 | true_labels) == len( 274 | true_difficulties) # these are all lists of tensors of the same length, i.e. number of images 275 | n_classes = len(label_map) 276 | 277 | # Store all (true) objects in a single continuous tensor while keeping track of the image it is from 278 | true_images = list() 279 | for i in range(len(true_labels)): 280 | true_images.extend([i] * true_labels[i].size(0)) 281 | true_images = torch.LongTensor(true_images).to( 282 | device) # (n_objects), n_objects is the total no. 
of objects across all images 283 | true_boxes = torch.cat(true_boxes, dim=0) # (n_objects, 4) 284 | true_labels = torch.cat(true_labels, dim=0) # (n_objects) 285 | true_difficulties = torch.cat(true_difficulties, dim=0) # (n_objects) 286 | 287 | assert true_images.size(0) == true_boxes.size(0) == true_labels.size(0) 288 | 289 | # Store all detections in a single continuous tensor while keeping track of the image it is from 290 | det_images = list() 291 | for i in range(len(det_labels)): 292 | det_images.extend([i] * det_labels[i].size(0)) 293 | det_images = torch.LongTensor(det_images).to(device) # (n_detections) 294 | det_boxes = torch.cat(det_boxes, dim=0) # (n_detections, 4) 295 | det_labels = torch.cat(det_labels, dim=0) # (n_detections) 296 | det_scores = torch.cat(det_scores, dim=0) # (n_detections) 297 | 298 | assert det_images.size(0) == det_boxes.size(0) == det_labels.size(0) == det_scores.size(0) 299 | 300 | # Calculate APs for each class (except background) 301 | average_precisions = torch.zeros((n_classes - 1), dtype=torch.float) # (n_classes - 1) 302 | for c in range(1, n_classes): 303 | # Extract only objects with this class 304 | true_class_images = true_images[true_labels == c] # (n_class_objects) 305 | true_class_boxes = true_boxes[true_labels == c] # (n_class_objects, 4) 306 | true_class_difficulties = true_difficulties[true_labels == c] # (n_class_objects) 307 | n_easy_class_objects = (1 - true_class_difficulties).sum().item() # ignore difficult objects 308 | 309 | # Keep track of which true objects with this class have already been 'detected' 310 | # So far, none 311 | true_class_boxes_detected = torch.zeros((true_class_difficulties.size(0)), dtype=torch.uint8).to( 312 | device) # (n_class_objects) 313 | 314 | # Extract only detections with this class 315 | det_class_images = det_images[det_labels == c] # (n_class_detections) 316 | det_class_boxes = det_boxes[det_labels == c] # (n_class_detections, 4) 317 | det_class_scores = det_scores[det_labels == c] # (n_class_detections) 318 | n_class_detections = det_class_boxes.size(0) 319 | if n_class_detections == 0: 320 | continue 321 | 322 | # Sort detections in decreasing order of confidence/scores 323 | det_class_scores, sort_ind = torch.sort(det_class_scores, dim=0, descending=True) # (n_class_detections) 324 | det_class_images = det_class_images[sort_ind] # (n_class_detections) 325 | det_class_boxes = det_class_boxes[sort_ind] # (n_class_detections, 4) 326 | 327 | # In the order of decreasing scores, check if true or false positive 328 | true_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device) # (n_class_detections) 329 | false_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device) # (n_class_detections) 330 | for d in range(n_class_detections): 331 | this_detection_box = det_class_boxes[d].unsqueeze(0) # (1, 4) 332 | this_image = det_class_images[d] # (), scalar 333 | 334 | # Find objects in the same image with this class, their difficulties, and whether they have been detected before 335 | object_boxes = true_class_boxes[true_class_images == this_image] # (n_class_objects_in_img) 336 | object_difficulties = true_class_difficulties[true_class_images == this_image] # (n_class_objects_in_img) 337 | # If no such object in this image, then the detection is a false positive 338 | if object_boxes.size(0) == 0: 339 | false_positives[d] = 1 340 | continue 341 | 342 | # Find maximum overlap of this detection with objects in this image of this class 343 | overlaps = 
find_jaccard_overlap(this_detection_box, object_boxes) # (1, n_class_objects_in_img) 344 | max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0) # (), () - scalars 345 | 346 | # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties' 347 | # In the original class-level tensors 'true_class_boxes', etc., 'ind' corresponds to object with index... 348 | original_ind = torch.LongTensor(range(true_class_boxes.size(0)))[true_class_images == this_image][ind] 349 | # We need 'original_ind' to update 'true_class_boxes_detected' 350 | 351 | # If the maximum overlap is greater than the threshold of 0.5, it's a match 352 | if max_overlap.item() > 0.5: 353 | # If the object it matched with is 'difficult', ignore it 354 | if object_difficulties[ind] == 0: 355 | # If this object has already not been detected, it's a true positive 356 | if true_class_boxes_detected[original_ind] == 0: 357 | true_positives[d] = 1 358 | true_class_boxes_detected[original_ind] = 1 # this object has now been detected/accounted for 359 | # Otherwise, it's a false positive (since this object is already accounted for) 360 | else: 361 | false_positives[d] = 1 362 | # Otherwise, the detection occurs in a different location than the actual object, and is a false positive 363 | else: 364 | false_positives[d] = 1 365 | 366 | # Compute cumulative precision and recall at each detection in the order of decreasing scores 367 | cumul_true_positives = torch.cumsum(true_positives, dim=0) # (n_class_detections) 368 | cumul_false_positives = torch.cumsum(false_positives, dim=0) # (n_class_detections) 369 | cumul_precision = cumul_true_positives / ( 370 | cumul_true_positives + cumul_false_positives + 1e-10) # (n_class_detections) 371 | cumul_recall = cumul_true_positives / n_easy_class_objects # (n_class_detections) 372 | 373 | # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't' 374 | recall_thresholds = torch.arange(start=0, end=1.1, step=.1).tolist() # (11) 375 | precisions = torch.zeros((len(recall_thresholds)), dtype=torch.float).to(device) # (11) 376 | for i, t in enumerate(recall_thresholds): 377 | recalls_above_t = cumul_recall >= t 378 | if recalls_above_t.any(): 379 | precisions[i] = cumul_precision[recalls_above_t].max() 380 | else: 381 | precisions[i] = 0. 382 | average_precisions[c - 1] = precisions.mean() # c is in [1, n_classes - 1] 383 | 384 | # Calculate Mean Average Precision (mAP) 385 | mean_average_precision = average_precisions.mean().item() 386 | 387 | # Keep class-wise average precisions in a dictionary 388 | average_precisions = {rev_label_map[c + 1]: v for c, v in enumerate(average_precisions.tolist())} 389 | 390 | return average_precisions, mean_average_precision 391 | 392 | 393 | def xy_to_cxcy_old(xy): 394 | """ 395 | Convert bounding boxes from boundary coordinates (x_min, y_min, x_max, y_max) to center-size coordinates (c_x, c_y, w, h). 
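    For example, the box (10, 20, 50, 80) becomes (30.0, 50.0, 40, 60): the center is the midpoint of the two corners and the size is their difference.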
396 | 397 | :param xy: bounding boxes in boundary coordinates, a tensor of size (n_boxes, 4) 398 | :return: bounding boxes in center-size coordinates, a tensor of size (n_boxes, 4) 399 | """ 400 | return torch.cat([(xy[:, 2:] + xy[:, :2]) / 2, # c_x, c_y 401 | xy[:, 2:] - xy[:, :2]], 1) # w, h 402 | 403 | def xy_to_cxcy(xy): #xy_to_cxcy_oriented 404 | ''' 405 | 带有方向的转换,将cxcy坐标变成8个值的坐标 406 | :param cxcy: 407 | :return: 408 | ''' 409 | x_ = torch.cat([xy[:,0].view(-1,1),xy[:,2].view(-1,1),xy[:,4].view(-1,1),xy[:,6].view(-1,1)],1) 410 | y_ = torch.cat([xy[:,1].view(-1,1),xy[:,3].view(-1,1),xy[:,5].view(-1,1),xy[:,7].view(-1,1)],1) 411 | 412 | xmin = torch.min(x_,1).values 413 | 414 | ymin = torch.min(y_, 1).values 415 | xmax = torch.max(x_, 1).values 416 | 417 | ymax = torch.max(y_, 1).values 418 | cx, cy = torch.div((xmin+xmax).cpu(),2.0).view(-1,1) ,torch.div((ymin+ymax).cpu(),2.0).view(-1,1) 419 | w = torch.sqrt((xy[:,0]-xy[:,2])**2+ (xy[:,1]-xy[:,3])**2).view(-1,1) 420 | h = torch.sqrt((xy[:,0]-xy[:,6])**2+ (xy[:,1]-xy[:,7])**2).view(-1,1) 421 | angle = torch.atan2((xy[:,1]-xy[:,3]),(xy[:,0]-xy[:,2])).view(-1,1) 422 | index = angle.gt(math.pi).nonzero() 423 | angle[index]-= math.pi 424 | index = angle.lt(0).nonzero() 425 | angle[index]+=math.pi 426 | return torch.cat([cx.cuda(),cy.cuda(),w,h,angle.view(-1,1)],1) 427 | 428 | 429 | def cxcy_minmaxxy(set1): 430 | xy=cxcy_to_xy(set1) 431 | x_ = torch.cat([xy[:,0].view(-1,1),xy[:,2].view(-1,1),xy[:,4].view(-1,1),xy[:,6].view(-1,1)],1) 432 | y_ = torch.cat([xy[:,1].view(-1,1),xy[:,3].view(-1,1),xy[:,5].view(-1,1),xy[:,7].view(-1,1)],1) 433 | xmin = torch.min(x_,1).values 434 | ymin = torch.min(y_, 1).values 435 | xmax = torch.max(x_, 1).values 436 | ymax = torch.max(y_, 1).values 437 | return torch.cat([xmin.view(-1, 1), ymin.view(-1, 1), xmax.view(-1, 1), ymax.view(-1, 1)], 1) 438 | 439 | 440 | 441 | def cxcy_to_xy_old(cxcy): 442 | """ 443 | Convert bounding boxes from center-size coordinates (c_x, c_y, w, h) to boundary coordinates (x_min, y_min, x_max, y_max). 444 | 445 | :param cxcy: bounding boxes in center-size coordinates, a tensor of size (n_boxes, 4) 446 | :return: bounding boxes in boundary coordinates, a tensor of size (n_boxes, 4) 447 | """ 448 | return torch.cat([cxcy[:, :2] - (cxcy[:, 2:] / 2), # x_min, y_min 449 | cxcy[:, :2] + (cxcy[:, 2:] / 2)], 1) # x_max, y_max 450 | 451 | def cxcy_to_xy(cxcy): #cxcy_to_xy_oriented 452 | ''' 453 | 带有方向的转换,将cxcy坐标变成8个值的坐标 454 | :param cxcy: 455 | :return: 456 | ''' 457 | bow_x = cxcy[:,0] + cxcy[:, 2]/2*torch.cos((cxcy[:, 4])) 458 | bow_y = cxcy[:,1] - cxcy[:, 2]/2*torch.sin((cxcy[:, 4])) 459 | tail_x = cxcy[:,0] - cxcy[:, 2]/2*torch.cos((cxcy[:, 4])) 460 | tail_y = cxcy[:,1] + cxcy[:, 2]/2*torch.sin((cxcy[:, 4])) 461 | 462 | x1 = (bow_x+cxcy[:, 3]/2*torch.sin((cxcy[:, 4]))).view(-1,1) 463 | y1 = (bow_y+cxcy[:, 3]/2*torch.cos((cxcy[:, 4]))).view(-1,1) 464 | 465 | x2 = (tail_x + cxcy[:, 3] / 2 * torch.sin((cxcy[:, 4]))).view(-1,1) 466 | y2 = (tail_y + cxcy[:, 3] / 2 * torch.cos((cxcy[:, 4]))).view(-1,1) 467 | x3 = (tail_x - cxcy[:, 3] / 2 * torch.sin((cxcy[:, 4]))).view(-1,1) 468 | y3 = (tail_y - cxcy[:, 3] / 2 * torch.cos((cxcy[:, 4]))).view(-1,1) 469 | x4 = (bow_x - cxcy[:, 3] / 2 * torch.sin((cxcy[:, 4]))).view(-1,1) 470 | y4 = (bow_y - cxcy[:, 3] / 2 * torch.cos((cxcy[:, 4]))).view(-1,1) 471 | return torch.cat([x1, y1, x2, y2, x3, y3, x4, y4],1) 472 | 473 | def cxcy_to_gcxgcy_old(cxcy, priors_cxcy): 474 | """ 475 | Encode bounding boxes (that are in center-size form) w.r.t. 
the corresponding prior boxes (that are in center-size form). 476 | 477 | For the center coordinates, find the offset with respect to the prior box, and scale by the size of the prior box. 478 | For the size coordinates, scale by the size of the prior box, and convert to the log-space. 479 | 480 | In the model, we are predicting bounding box coordinates in this encoded form. 481 | 482 | :param cxcy: bounding boxes in center-size coordinates, a tensor of size (n_priors, 4) 483 | :param priors_cxcy: prior boxes with respect to which the encoding must be performed, a tensor of size (n_priors, 4) 484 | :return: encoded bounding boxes, a tensor of size (n_priors, 4) 485 | """ 486 | eps = 1e-5 487 | # The 10 and 5 below are referred to as 'variances' in the original Caffe repo, completely empirical 488 | # They are for some sort of numerical conditioning, for 'scaling the localization gradient' 489 | # See https://github.com/weiliu89/caffe/issues/155 490 | return torch.cat([(cxcy[:, :2] - priors_cxcy[:, :2]) / (priors_cxcy[:, 2:] / 10), # g_c_x, g_c_y 491 | torch.log(cxcy[:, 2:] / priors_cxcy[:, 2:]+eps) * 5], 1) # g_w, g_h 492 | 493 | 494 | def gcxgcy_to_cxcy_old(gcxgcy, priors_cxcy): 495 | """ 496 | Decode bounding box coordinates predicted by the model, since they are encoded in the form mentioned above. 497 | 498 | They are decoded into center-size coordinates. 499 | 500 | This is the inverse of the function above. 501 | 502 | :param gcxgcy: encoded bounding boxes, i.e. output of the model, a tensor of size (n_priors, 4) 503 | :param priors_cxcy: prior boxes with respect to which the encoding is defined, a tensor of size (n_priors, 4) 504 | :return: decoded bounding boxes in center-size form, a tensor of size (n_priors, 4) 505 | """ 506 | 507 | return torch.cat([gcxgcy[:, :2] * priors_cxcy[:, 2:] / 10 + priors_cxcy[:, :2], # c_x, c_y 508 | torch.exp(gcxgcy[:, 2:] / 5) * priors_cxcy[:, 2:]], 1) # w, h 509 | #带方向iou计算 510 | def gcxgcy_to_cxcy(gcxgcy,priors_cxcy): #gcxgcy_to_cxcy_oriented 511 | xy = gcxgcy[:,:2]*priors_cxcy[:,2:4]/10 +priors_cxcy[:,:2] 512 | wh = torch.exp(gcxgcy[:,2:4]/5)*priors_cxcy[:,2:4] 513 | angle = (torch.atan(gcxgcy[:,4])+priors_cxcy[:,4]).view(-1,1) 514 | return torch.cat([xy,wh,angle],1) 515 | 516 | #带方向iou计算 517 | def cxcy_to_gcxgcy(cxcy, priors_cxcy): #开始写这个function cxcy_to_gcxgcy_oriented 518 | ''' 519 | :param cxcy: 520 | :param priors_cxcy: 521 | :return: 522 | ''' 523 | eps = 1e-5 524 | return torch.cat([(cxcy[:,:2]-priors_cxcy[:,:2])/(priors_cxcy[:,2:4]/10), 525 | torch.log(cxcy[:,2:4]/priors_cxcy[:,2:4]+eps)*5, 526 | torch.tan(cxcy[:,4]-priors_cxcy[:,4]).view(-1,1)],1) 527 | 528 | def angleiou(cxcy1,cxcy2): 529 | set1 = cxcy_to_xy(cxcy1) 530 | set2 =cxcy_to_xy(cxcy2) 531 | xmin = torch.min(set1[:,0],set1[:,2],set1[:,4],set1[:,6],set2[:,0],set2[:,2],set2[:,4],set2[:,6]) 532 | ymin = torch.min(set1[:,1],set1[:,3],set1[:,5],set1[:,7],set2[:,1],set2[:,3],set2[:,5],set2[:,7]) 533 | xmax = torch.min(set1[:, 0], set1[:, 2], set1[:, 4], set1[:, 6], set2[:, 0], set2[:, 2], set2[:, 4], set2[:, 6]) 534 | ymax = torch.min(set1[:, 1], set1[:, 3], set1[:, 5], set1[:, 7], set2[:, 1], set2[:, 3], set2[:, 5], set2[:, 7]) 535 | cxcy = torch.cat[[torch.div((xmin + xmax).cpu(), 2.0).view(-1, 1), torch.div((ymin + ymax).cpu(), 2.0).view(-1, 1)],1].unsqueeze(1) 536 | set1 = torch.cat[[set1[:,0],set1[:,1]],[set1[:,2],set1[:,3]],[set1[:,4],set1[:,5]],[set1[:,6],set1[:,7]],1] 537 | set2 = torch.cat[ 538 | [set2[:, 0], set2[:, 1]], [set2[:, 2], set2[:, 3]], [set2[:, 4], set2[:, 5]], 
[set2[:, 6], set2[:, 7]], 1] 539 | set1_origin,set2_origin = set1-cxcy, set2-cxcy 540 | 541 | angle = cxcy2[:4].expand[cxcy1.shape(0),4,1] 542 | set1_oritate = torch.cat[[[set1_origin[:,:,0]*torch.cos(-angle[:,:0])-set1_origin[:,:,1]*torch.sin(-angle[:,:0])],\ 543 | [set1_origin[:,:,1]*torch.cos(-angle[:,:0])+set1_origin[:,:,0]*torch.sin(-angle[:,:0])]],2] 544 | set2_oritate = torch.cat[ 545 | [[set2_origin[:, :, 0] * torch.cos(-angle[:, :0]) - set2_origin[:, :, 1] * torch.sin(-angle[:, :0])], \ 546 | [set2_origin[:, :, 1] * torch.cos(-angle[:, :0]) + set2_origin[:, :, 0] * torch.sin(-angle[:, :0])]], 2] 547 | set1_oritate = [torch.min(set1_oritate[:,:,0].view(-1,4),1),torch.min(set1_oritate[:,:,1].view(-1,4),1), 548 | torch.max(set1_oritate[:,:,0].view(-1,4),1),torch.max(set1_oritate[:,:,1].view(-1,4),1)] 549 | set2_oritate = [torch.min(set2_oritate[:, :, 0].view(-1, 4), 1), torch.min(set2_oritate[:, :, 1].view(-1, 4), 1), 550 | torch.max(set2_oritate[:, :, 0].view(-1, 4), 1), torch.max(set2_oritate[:, :, 1].view(-1, 4), 1)] 551 | iou = find_jaccard_overlap_old(set1_oritate,set2_oritate) 552 | iou = iou*torch.abs(torch.cos(cxcy1[:,:,4].view(-1,1)-cxcy2[:,:,4].view(1,-1))) 553 | return iou 554 | 555 | 556 | 557 | 558 | 559 | def find_intersection(set_1, set_2): 560 | """ 561 | Find the intersection of every box combination between two sets of boxes that are in boundary coordinates. 562 | 563 | :param set_1: set 1, a tensor of dimensions (n1, 4) 564 | :param set_2: set 2, a tensor of dimensions (n2, 4) 565 | :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2) 566 | """ 567 | 568 | # PyTorch auto-broadcasts singleton dimensions 569 | lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0)) # (n1, n2, 2) 570 | upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0)) # (n1, n2, 2) 571 | intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) # (n1, n2, 2) 572 | return intersection_dims[:, :, 0] * intersection_dims[:, :, 1] # (n1, n2) 573 | 574 | def find_jaccard_verlap_v2(set_1,set_2): 575 | set_1 = cxcy_minmaxxy(set_1) 576 | set_2 = cxcy_minmaxxy(set_2) 577 | return find_jaccard_overlap_old(set_1,set_2) 578 | 579 | 580 | 581 | 582 | def find_jaccard_overlap_old(set_1, set_2): # 无方向iou 583 | """ 584 | Find the Jaccard Overlap (IoU) of every box combination between two sets of boxes that are in boundary coordinates. 
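    For example, the axis-aligned boxes (0, 0, 10, 10) and (5, 0, 15, 10) each have area 100 and intersect in an area of 50, so the union is 150 and the IoU is 1/3.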
585 | 
586 |     :param set_1: set 1, a tensor of dimensions (n1, 4)
587 |     :param set_2: set 2, a tensor of dimensions (n2, 4)
588 |     :return: Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
589 |     """
590 | 
591 |     # Find intersections
592 |     intersection = find_intersection(set_1, set_2)  # (n1, n2)
593 | 
594 |     # Find areas of each box in both sets
595 |     areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
596 |     areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)
597 | 
598 |     # Find the union
599 |     # PyTorch auto-broadcasts singleton dimensions
600 |     union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)
601 | 
602 |     return intersection / union  # (n1, n2)
603 | 
604 | # Oriented (rotated-box) IoU computation
605 | def find_jaccard_overlap(set1,set2): #find_jaccard_overlap_oriented
606 |     '''
607 |     :param set1: oriented boxes given as 8 corner coordinates (x0, y0, ..., x3, y3), a tensor of dimensions (n1, 8)
608 |     :param set2: oriented boxes given as 8 corner coordinates, a tensor of dimensions (n2, 8)
609 |     :return: IoU of each box in set1 with respect to each box in set2, a tensor of dimensions (n1, n2)
610 |     '''
611 |     # Find intersections
612 |     set1 = set1.cpu().detach().numpy()
613 |     set2 = set2.cpu().detach().numpy()
614 |     # set1 = set1.cpu().numpy()
615 |     # set2 = set2.cpu().numpy()
616 |     iou_all = []
617 |     for i in range(set1.shape[0]):
618 |         iou_row = []
619 |         for j in range(set2.shape[0]):
620 |             iou_row.append(rbox_iou(set1[i],set2[j]))
621 |         iou_all.append(iou_row)
622 |     iou_all = np.array(iou_all)
623 |     return torch.from_numpy(iou_all).to(device)
624 |     # inter_area = plo11.interpolate(ploy2).area
625 |     # union_area = MultiPoint(union_poly).convex_hull.area
626 |     # iou = float(inter_area) / union_area
627 |     # return torch.from_numpy(iou).cuda()
628 | 
629 | def rbox_iou(a,b):
630 |     a = [[a[0], a[1]], [a[2], a[3]], [a[4], a[5]], [a[6], a[7]]]
631 |     b = [[b[0], b[1]], [b[2], b[3]], [b[4], b[5]], [b[6], b[7]]]
632 |     poly1 = Polygon(a).convex_hull  # Shapely quadrilateral; convex_hull re-orders the corners automatically (top-left, bottom-left, bottom-right, top-right)
633 |     poly2 = Polygon(b).convex_hull
634 |     union_poly = np.concatenate((a, b))  # stack the corners of both boxes into an 8x2 array
635 | 
636 |     if not poly1.intersects(poly2):  # the two quadrilaterals do not intersect
637 |         iou = 0
638 |     else:
639 |         try:
640 |             inter_area = poly1.intersection(poly2).area  # intersection area
641 |             # print(inter_area)
642 |             # union_area = poly1.area + poly2.area - inter_area
643 |             union_area = MultiPoint(union_poly).convex_hull.area
644 |             # print(union_area)
645 |             if union_area == 0:
646 |                 iou = 0
647 |             else:
648 |                 iou = float(inter_area) / union_area
649 |             # iou = float(inter_area) / (poly1.area + poly2.area - inter_area)
650 |             # Two IoU variants are possible here: (1) intersection / area of the smallest convex polygon enclosing both quadrilaterals (used above);
651 |             # (2) intersection / union (the usual rectangular-box IoU)
652 |         except shapely.geos.TopologicalError:
653 |             print('shapely.geos.TopologicalError occurred, iou set to 0')
654 |             iou = 0
655 |     return iou
656 | 
657 | # Some augmentation functions below have been adapted from
658 | # https://github.com/amdegroot/ssd.pytorch/blob/master/utils/augmentations.py
659 | 
660 | def find_intersection_rbox(set_1,set_2):
661 |     """
662 |     Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.
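    Unlike find_intersection above, the two inputs are expected to be already broadcast to a common shape (n1, n2, 4), so no unsqueeze is applied before taking the element-wise max/min of the corners.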
663 | 664 | :param set_1: set 1, a tensor of dimensions (n1, 4) 665 | :param set_2: set 2, a tensor of dimensions (n2, 4) 666 | :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2) 667 | """ 668 | # PyTorch auto-broadcasts singleton dimensions 669 | lower_bounds = torch.max(set_1[:,:, :2], set_2[:,:, :2]) # (n1, n2, 2) 670 | upper_bounds = torch.min(set_1[:,:, 2:], set_2[:,:, 2:]) # (n1, n2, 2) 671 | intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) # (n1, n2, 2) 672 | return intersection_dims[:, :, 0] * intersection_dims[:, :, 1] # (n1, n2) 673 | 674 | 675 | def angleiou(cxcy1,cxcy2): 676 | set1 = cxcy_to_xy(cxcy1) 677 | set2 = cxcy_to_xy(cxcy2) 678 | # x_collect = torch.cat([set1[:,0].view(-1,1),set1[:,2].view(-1,1),set1[:,4].view(-1,1),set1[:,6].view(-1,1),set2[:,0].view(-1,1),set2[:,2].view(-1,1),set2[:,4].view(-1,1),set2[:,6].view(-1,1)],1) 679 | # y_collocet = torch.cat([set1[:,1].view(-1,1),set1[:,3].view(-1,1),set1[:,5].view(-1,1),set1[:,7].view(-1,1),set2[:,1].view(-1,1),set2[:,3].view(-1,1),set2[:,5].view(-1,1),set2[:,7].view(-1,1)],1) 680 | x_collect_1 = torch.cat( 681 | [set1[:, 0].view(-1, 1), set1[:, 2].view(-1, 1), set1[:, 4].view(-1, 1), set1[:, 6].view(-1, 1)], 1).unsqueeze(1).expand(cxcy1.shape[0],cxcy2.shape[0],4) 682 | y_collect_1 = torch.cat( 683 | [set1[:, 1].view(-1, 1), set1[:, 3].view(-1, 1), set1[:, 5].view(-1, 1), set1[:, 7].view(-1, 1)], 1).unsqueeze(1).expand(cxcy1.shape[0],cxcy2.shape[0],4) 684 | x_collect_2 = torch.cat( 685 | [set2[:, 0].view(-1, 1), set2[:, 2].view(-1, 1), set2[:, 4].view(-1, 1), set2[:, 6].view(-1, 1)], 1).unsqueeze(0).expand(cxcy1.shape[0],cxcy2.shape[0],4) 686 | y_collect_2 = torch.cat( 687 | [set2[:, 1].view(-1, 1), set2[:, 3].view(-1, 1), set2[:, 5].view(-1, 1), set2[:, 7].view(-1, 1)], 1).unsqueeze(0).expand(cxcy1.shape[0],cxcy2.shape[0],4) 688 | x_collect = torch.cat([x_collect_1,x_collect_2],2) 689 | y_collect = torch.cat([y_collect_1,y_collect_2],2) 690 | xmin = torch.min(x_collect,2).values 691 | ymin = torch.min(y_collect,2).values 692 | xmax = torch.max(x_collect,2).values 693 | ymax = torch.max(y_collect,2).values 694 | cxcy = torch.cat([torch.div((xmin + xmax), 2.0).unsqueeze(2), torch.div((ymin + ymax), 2.0).unsqueeze(2)],2) 695 | set1_expand = set1.view(-1,4,2).unsqueeze(1).expand(cxcy1.shape[0],cxcy2.shape[0],4,2) 696 | set2_expand = set2.view(-1,4,2).unsqueeze(0).expand(cxcy1.shape[0],cxcy2.shape[0],4,2) 697 | cxcy = cxcy.unsqueeze(2).expand(cxcy.shape[0],cxcy.shape[1],4,cxcy.shape[2]) 698 | # # set1 = torch.cat([[set1[:,0],set1[:,1]],[set1[:,2],set1[:,3]],[set1[:,4],set1[:,5]],[set1[:,6],set1[:,7]]],1) 699 | # # set2 = torch.cat( 700 | # # [set2[:, 0], set2[:, 1]], [set2[:, 2], set2[:, 3]], [set2[:, 4], set2[:, 5]], [set2[:, 6], set2[:, 7]], 1) 701 | set1_origin,set2_origin = set1_expand-cxcy, set2_expand-cxcy 702 | #e 703 | angle = cxcy2[:,4].view(-1,1,1).expand(cxcy1.shape[0], cxcy2.shape[0],4,1) 704 | set1_oritate = torch.cat([(set1_origin[:,:,:,0]*torch.cos(-angle[:,:,:,0])-set1_origin[:,:,:,1]*torch.sin(-angle[:,:,:,0])).unsqueeze(3),\ 705 | (set1_origin[:,:,:, 1]*torch.cos(-angle[:,:,:,0])+set1_origin[:,:,:,0]*torch.sin(-angle[:,:,:,0])).unsqueeze(3)],3) 706 | set2_oritate = torch.cat( 707 | [(set2_origin[:, :,:, 0] * torch.cos(-angle[:, :,:,0]) - set2_origin[:, :,:, 1] * torch.sin(-angle[:, :,:,0])).unsqueeze(3), \ 708 | (set2_origin[:, :,:, 1] * torch.cos(-angle[:, :,:,0]) + set2_origin[:, :,:, 0] * torch.sin(-angle[:, 
:,:,0])).unsqueeze(3)], 3) 709 | set1_oritate = torch.cat([torch.min(set1_oritate[:,:,:,0],2).values.unsqueeze(2),torch.min(set1_oritate[:,:,:,1],2).values.unsqueeze(2), 710 | torch.max(set1_oritate[:,:,:,0],2).values.unsqueeze(2),torch.max(set1_oritate[:,:,:,1],2).values.unsqueeze(2)],2) 711 | set2_oritate = torch.cat([torch.min(set2_oritate[:, :,:, 0],2).values.unsqueeze(2), torch.min(set2_oritate[:, :,:, 1],2).values.unsqueeze(2), 712 | torch.max(set2_oritate[:, :,:, 0],2).values.unsqueeze(2), torch.max(set2_oritate[:, :,:, 1],2).values.unsqueeze(2)],2) 713 | intersection = find_intersection_rbox(set1_oritate, set2_oritate) 714 | area_set1 = (set1_oritate[:,:,2]-set1_oritate[:,:,0])*(set1_oritate[:,:,3]-set1_oritate[:,:,1]) 715 | area_set2 = (set2_oritate[:, :, 2] - set2_oritate[:, :, 0]) * (set2_oritate[:, :, 3] - set2_oritate[:, :, 1]) 716 | iou = intersection /(area_set1 + area_set2 - intersection) # (n1, n2) 717 | angle_cos = torch.abs(torch.cos(cxcy1[:,4].unsqueeze(1).expand(cxcy1.shape[0],cxcy2.shape[0]) - cxcy2[:,4].unsqueeze(0).expand(cxcy1.shape[0],cxcy2.shape[0]))) 718 | iou = iou*angle_cos 719 | return iou 720 | 721 | 722 | 723 | def expand(image, boxes, filler): 724 | """ 725 | Perform a zooming out operation by placing the image in a larger canvas of filler material. 726 | 727 | Helps to learn to detect smaller objects. 728 | 729 | :param image: image, a tensor of dimensions (3, original_h, original_w) 730 | :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4) 731 | :param filler: RBG values of the filler material, a list like [R, G, B] 732 | :return: expanded image, updated bounding box coordinates 733 | """ 734 | # Calculate dimensions of proposed expanded (zoomed-out) image 735 | original_h = image.size(1) 736 | original_w = image.size(2) 737 | max_scale = 4 738 | scale = random.uniform(1, max_scale) 739 | new_h = int(scale * original_h) 740 | new_w = int(scale * original_w) 741 | 742 | # Create such an image with the filler 743 | filler = torch.FloatTensor(filler) # (3) 744 | new_image = torch.ones((3, new_h, new_w), dtype=torch.float) * filler.unsqueeze(1).unsqueeze(1) # (3, new_h, new_w) 745 | # Note - do not use expand() like new_image = filler.unsqueeze(1).unsqueeze(1).expand(3, new_h, new_w) 746 | # because all expanded values will share the same memory, so changing one pixel will change all 747 | 748 | # Place the original image at random coordinates in this new image (origin at top-left of image) 749 | left = random.randint(0, new_w - original_w) 750 | right = left + original_w 751 | top = random.randint(0, new_h - original_h) 752 | bottom = top + original_h 753 | new_image[:, top:bottom, left:right] = image 754 | 755 | # Adjust bounding boxes' coordinates accordingly 756 | 757 | new_boxes = boxes + torch.FloatTensor([left, top, left, top]).unsqueeze( 758 | 0) # (n_objects, 4), n_objects is the no. of objects in this image 759 | return new_image, new_boxes 760 | 761 | 762 | def random_crop(image, boxes, labels, difficulties): 763 | """ 764 | Performs a random crop in the manner stated in the paper. Helps to learn to detect larger and partial objects. 765 | 766 | Note that some objects may be cut out entirely. 
767 | 768 | Adapted from https://github.com/amdegroot/ssd.pytorch/blob/master/utils/augmentations.py 769 | 770 | :param image: image, a tensor of dimensions (3, original_h, original_w) 771 | :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4) 772 | :param labels: labels of objects, a tensor of dimensions (n_objects) 773 | :param difficulties: difficulties of detection of these objects, a tensor of dimensions (n_objects) 774 | :return: cropped image, updated bounding box coordinates, updated labels, updated difficulties 775 | """ 776 | original_h = image.size(1) 777 | original_w = image.size(2) 778 | # Keep choosing a minimum overlap until a successful crop is made 779 | while True: 780 | # Randomly draw the value for minimum overlap 781 | min_overlap = random.choice([0., .1, .3, .5, .7, .9, None]) # 'None' refers to no cropping 782 | 783 | # If not cropping 784 | if min_overlap is None: 785 | return image, boxes, labels, difficulties 786 | 787 | # Try up to 50 times for this choice of minimum overlap 788 | # This isn't mentioned in the paper, of course, but 50 is chosen in paper authors' original Caffe repo 789 | max_trials = 50 790 | for _ in range(max_trials): 791 | # Crop dimensions must be in [0.3, 1] of original dimensions 792 | # Note - it's [0.1, 1] in the paper, but actually [0.3, 1] in the authors' repo 793 | min_scale = 0.3 794 | scale_h = random.uniform(min_scale, 1) 795 | scale_w = random.uniform(min_scale, 1) 796 | new_h = int(scale_h * original_h) 797 | new_w = int(scale_w * original_w) 798 | 799 | # Aspect ratio has to be in [0.5, 2] 800 | aspect_ratio = new_h / new_w 801 | if not 0.5 < aspect_ratio < 2: 802 | continue 803 | 804 | # Crop coordinates (origin at top-left of image) 805 | left = random.randint(0, original_w - new_w) 806 | right = left + new_w 807 | top = random.randint(0, original_h - new_h) 808 | bottom = top + new_h 809 | crop = torch.FloatTensor([left, top, right, bottom]) # (4) 810 | 811 | # Calculate Jaccard overlap between the crop and the bounding boxes 812 | overlap = find_jaccard_overlap(crop.unsqueeze(0), 813 | boxes) # (1, n_objects), n_objects is the no. of objects in this image 814 | overlap = overlap.squeeze(0) # (n_objects) 815 | 816 | # If not a single bounding box has a Jaccard overlap of greater than the minimum, try again 817 | if overlap.max().item() < min_overlap: 818 | continue 819 | 820 | # Crop image 821 | new_image = image[:, top:bottom, left:right] # (3, new_h, new_w) 822 | 823 | # Find centers of original bounding boxes 824 | bb_centers = (boxes[:, :2] + boxes[:, 2:]) / 2. 
# (n_objects, 2) 825 | 826 | # Find bounding boxes whose centers are in the crop 827 | centers_in_crop = (bb_centers[:, 0] > left) * (bb_centers[:, 0] < right) * (bb_centers[:, 1] > top) * ( 828 | bb_centers[:, 1] < bottom) # (n_objects), a Torch uInt8/Byte tensor, can be used as a boolean index 829 | 830 | # If not a single bounding box has its center in the crop, try again 831 | if not centers_in_crop.any(): 832 | continue 833 | 834 | # Discard bounding boxes that don't meet this criterion 835 | new_boxes = boxes[centers_in_crop, :] 836 | new_labels = labels[centers_in_crop] 837 | new_difficulties = difficulties[centers_in_crop] 838 | 839 | # Calculate bounding boxes' new coordinates in the crop 840 | new_boxes[:, :2] = torch.max(new_boxes[:, :2], crop[:2]) # crop[:2] is [left, top] 841 | new_boxes[:, :2] -= crop[:2] 842 | new_boxes[:, 2:] = torch.min(new_boxes[:, 2:], crop[2:]) # crop[2:] is [right, bottom] 843 | new_boxes[:, 2:] -= crop[:2] 844 | 845 | return new_image, new_boxes, new_labels, new_difficulties 846 | 847 | 848 | def flip(image, boxes): 849 | """ 850 | Flip image horizontally. 851 | 852 | :param image: image, a PIL Image 853 | :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4) 854 | :return: flipped image, updated bounding box coordinates 855 | """ 856 | # Flip image 857 | new_image = FT.hflip(image) 858 | 859 | # Flip boxes 860 | new_boxes = boxes 861 | new_boxes[:, 0] = image.width - boxes[:, 0] - 1 862 | new_boxes[:, 2] = image.width - boxes[:, 2] - 1 863 | new_boxes = new_boxes[:, [2, 1, 0, 3]] 864 | 865 | return new_image, new_boxes 866 | 867 | 868 | def resize_old(image, boxes, dims=(300, 300), return_percent_coords=True): 869 | """ 870 | Resize image. For the SSD300, resize to (300, 300). 871 | 872 | Since percent/fractional coordinates are calculated for the bounding boxes (w.r.t image dimensions) in this process, 873 | you may choose to retain them. 
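    For example, resizing a 600x400 image to (300, 300) maps the box (60, 40, 300, 200) to the fractional coordinates (0.1, 0.1, 0.5, 0.5); with return_percent_coords=False these are scaled back to (30, 30, 150, 150) in the resized image.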
874 | 
875 |     :param image: image, a PIL Image
876 |     :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4)
877 |     :return: resized image, updated bounding box coordinates (or fractional coordinates, in which case they remain the same)
878 |     """
879 |     # Resize image
880 |     new_image = FT.resize(image, dims)
881 | 
882 |     # Resize bounding boxes
883 |     old_dims = torch.FloatTensor([image.width, image.height, image.width, image.height]).unsqueeze(0)
884 |     new_boxes = boxes / old_dims  # percent coordinates
885 | 
886 |     if not return_percent_coords:
887 |         new_dims = torch.FloatTensor([dims[1], dims[0], dims[1], dims[0]]).unsqueeze(0)
888 |         new_boxes = new_boxes * new_dims
889 | 
890 |     return new_image, new_boxes
891 | 
892 | def resize(image,boxes,dims=(300,300),return_percent_coords=True):
893 |     '''
894 |     Resize the image to a fixed size and return fractional (percent) coordinates for the boxes.
895 |     :param image: image, a PIL Image
896 |     :param boxes: oriented bounding boxes as 8 corner coordinates, a tensor of dimensions (n_objects, 8)
897 |     :param dims: target size, e.g. (300, 300)
898 |     :param return_percent_coords: if True, return fractional coordinates; otherwise scale them to the new size
899 |     :return: resized image, updated bounding box coordinates
900 |     '''
901 |     new_image = FT.resize(image,dims)
902 |     # resize the box coordinates
903 |     old_dims = torch.FloatTensor([image.width,image.height,image.width,image.height,image.width,image.height,image.width,image.height]).unsqueeze(0)
904 |     new_boxes = boxes / old_dims
905 |     if not return_percent_coords:
906 |         new_dims = torch.FloatTensor([dims[1], dims[0], dims[1], dims[0],dims[1], dims[0], dims[1], dims[0]]).unsqueeze(0)
907 |         new_boxes = new_boxes * new_dims
908 |     return new_image, new_boxes
909 | 
910 | 
911 | 
912 | def photometric_distort(image):
913 |     """
914 |     Distort brightness, contrast, saturation, and hue, each with a 50% chance, in random order.
915 | 
916 |     :param image: image, a PIL Image
917 |     :return: distorted image
918 |     """
919 |     new_image = image
920 | 
921 |     distortions = [FT.adjust_brightness,
922 |                    FT.adjust_contrast,
923 |                    FT.adjust_saturation,
924 |                    FT.adjust_hue]
925 | 
926 |     random.shuffle(distortions)
927 | 
928 |     for d in distortions:
929 |         if random.random() < 0.5:
930 |             if d.__name__ == 'adjust_hue':
931 |                 # Caffe repo uses a 'hue_delta' of 18 - we divide by 255 because PyTorch needs a normalized value
932 |                 adjust_factor = random.uniform(-18 / 255., 18 / 255.)
933 |             else:
934 |                 # Caffe repo uses 'lower' and 'upper' values of 0.5 and 1.5 for brightness, contrast, and saturation
935 |                 adjust_factor = random.uniform(0.5, 1.5)
936 | 
937 |             # Apply this distortion
938 |             new_image = d(new_image, adjust_factor)
939 | 
940 |     return new_image
941 | 
942 | 
943 | def transform_old(image, boxes, labels, difficulties, split):
944 |     """
945 |     Apply the transformations above.
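    In 'TRAIN' mode the pipeline is: photometric distortion, optional zoom-out (expand), random crop, optional horizontal flip, resize to 300x300 (which also converts boxes to fractional coordinates), and ImageNet mean/std normalization; in 'TEST' mode only the resize and normalization are applied.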
946 | 947 | :param image: image, a PIL Image 948 | :param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4) 949 | :param labels: labels of objects, a tensor of dimensions (n_objects) 950 | :param difficulties: difficulties of detection of these objects, a tensor of dimensions (n_objects) 951 | :param split: one of 'TRAIN' or 'TEST', since different sets of transformations are applied 952 | :return: transformed image, transformed bounding box coordinates, transformed labels, transformed difficulties 953 | """ 954 | assert split in {'TRAIN', 'TEST'} 955 | 956 | # Mean and standard deviation of ImageNet data that our base VGG from torchvision was trained on 957 | # see: https://pytorch.org/docs/stable/torchvision/models.html 958 | mean = [0.485, 0.456, 0.406] 959 | std = [0.229, 0.224, 0.225] 960 | 961 | new_image = image 962 | new_boxes = boxes 963 | new_labels = labels 964 | new_difficulties = difficulties 965 | # Skip the following operations if validation/evaluation 966 | if split == 'TRAIN': 967 | # A series of photometric distortions in random order, each with 50% chance of occurrence, as in Caffe repo 968 | new_image = photometric_distort(new_image) 969 | 970 | # Convert PIL image to Torch tensor 971 | new_image = FT.to_tensor(new_image) 972 | 973 | # Expand image (zoom out) with a 50% chance - helpful for training detection of small objects 974 | # Fill surrounding space with the mean of ImageNet data that our base VGG was trained on 975 | if random.random() < 0.5: 976 | new_image, new_boxes = expand(new_image, boxes, filler=mean) 977 | 978 | # Randomly crop image (zoom in) 979 | new_image, new_boxes, new_labels, new_difficulties = random_crop(new_image, new_boxes, new_labels, 980 | new_difficulties) 981 | 982 | # Convert Torch tensor to PIL image 983 | new_image = FT.to_pil_image(new_image) 984 | 985 | # Flip image with a 50% chance 986 | if random.random() < 0.5: 987 | new_image, new_boxes = flip(new_image, new_boxes) 988 | 989 | # Resize image to (300, 300) - this also converts absolute boundary coordinates to their fractional form 990 | new_image, new_boxes = resize(new_image, new_boxes, dims=(300, 300)) 991 | 992 | # Convert PIL image to Torch tensor 993 | new_image = FT.to_tensor(new_image) 994 | 995 | # Normalize by mean and standard deviation of ImageNet data that our base VGG was trained on 996 | new_image = FT.normalize(new_image, mean=mean, std=std) 997 | 998 | return new_image, new_boxes, new_labels, new_difficulties 999 | 1000 | def transform (image, boxes,labels,difficulties,split): 1001 | ''' 1002 | :param image: 1003 | :param boxes: 1004 | :param labels: 1005 | :param difficulties: 1006 | :param split: 1007 | :return: 1008 | ''' 1009 | assert split in {'TRAIN','TEST'} 1010 | 1011 | mean = [0.485, 0.456, 0.406] 1012 | std = [0.229, 0.224, 0.225] 1013 | 1014 | new_image = image 1015 | new_boxes = boxes 1016 | new_labels = labels 1017 | new_difficulties = difficulties 1018 | if split == 'TRAIN': 1019 | # 一系列畸变 1020 | new_image = photometric_distort(new_image) 1021 | # 变成tensor 以下省略了一系列数据增广的操作 1022 | #new_image = FT.to_tensor(new_image) 1023 | # if random.random()<0.5: 1024 | # new_image,new_boxes = expand(new_image,new_boxes,filler=mean) 1025 | #randomly crop 1026 | new_image,new_boxes = resize(new_image,new_boxes,dims=(300,300)) 1027 | # Convert PIL image to Torch tensor 1028 | new_image = FT.to_tensor(new_image) 1029 | 1030 | # Normalize by mean and standard deviation of ImageNet data that our base VGG was trained on 1031 | new_image = 
FT.normalize(new_image, mean=mean, std=std) 1032 | 1033 | return new_image, new_boxes, new_labels, new_difficulties 1034 | 1035 | def adjust_learning_rate(optimizer, scale): 1036 | """ 1037 | Scale learning rate by a specified factor. 1038 | 1039 | :param optimizer: optimizer whose learning rate must be shrunk. 1040 | :param scale: factor to multiply learning rate with. 1041 | """ 1042 | for param_group in optimizer.param_groups: 1043 | param_group['lr'] = param_group['lr'] * scale 1044 | print("DECAYING learning rate.\n The new LR is %f\n" % (optimizer.param_groups[1]['lr'],)) 1045 | 1046 | 1047 | def accuracy(scores, targets, k): 1048 | """ 1049 | Computes top-k accuracy, from predicted and true labels. 1050 | 1051 | :param scores: scores from the model 1052 | :param targets: true labels 1053 | :param k: k in top-k accuracy 1054 | :return: top-k accuracy 1055 | """ 1056 | batch_size = targets.size(0) 1057 | _, ind = scores.topk(k, 1, True, True) 1058 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 1059 | correct_total = correct.view(-1).float().sum() # 0D tensor 1060 | return correct_total.item() * (100.0 / batch_size) 1061 | 1062 | 1063 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, best_loss, is_best): 1064 | """ 1065 | Save model checkpoint. 1066 | 1067 | :param epoch: epoch number 1068 | :param epochs_since_improvement: number of epochs since last improvement 1069 | :param model: model 1070 | :param optimizer: optimizer 1071 | :param loss: validation loss in this epoch 1072 | :param best_loss: best validation loss achieved so far (not necessarily in this checkpoint) 1073 | :param is_best: is this checkpoint the best so far? 1074 | """ 1075 | state = {'epoch': epoch, 1076 | 'epochs_since_improvement': epochs_since_improvement, 1077 | 'loss': loss, 1078 | 'best_loss': best_loss, 1079 | 'model': model, 1080 | 'optimizer': optimizer} 1081 | filename = 'checkpoint_ssd300.pth.tar' 1082 | torch.save(state, filename) 1083 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 1084 | if is_best: 1085 | torch.save(state, 'BEST_' + filename) 1086 | 1087 | 1088 | class AverageMeter(object): 1089 | """ 1090 | Keeps track of most recent, average, sum, and count of a metric. 1091 | """ 1092 | 1093 | def __init__(self): 1094 | self.reset() 1095 | 1096 | def reset(self): 1097 | self.val = 0 1098 | self.avg = 0 1099 | self.sum = 0 1100 | self.count = 0 1101 | 1102 | def update(self, val, n=1): 1103 | self.val = val 1104 | self.sum += val * n 1105 | self.count += n 1106 | self.avg = self.sum / self.count 1107 | 1108 | 1109 | def clip_gradient(optimizer, grad_clip): 1110 | """ 1111 | Clips gradients computed during backpropagation to avoid explosion of gradients. 1112 | 1113 | :param optimizer: optimizer with the gradients to be clipped 1114 | :param grad_clip: clip value 1115 | """ 1116 | for group in optimizer.param_groups: 1117 | for param in group['params']: 1118 | if param.grad is not None: 1119 | param.grad.data.clamp_(-grad_clip, grad_clip) 1120 | --------------------------------------------------------------------------------
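For reference, the snippet below is a minimal, self-contained sanity check of the convex-hull based oriented IoU used in `utils.py` (`rbox_iou`, called pairwise by `find_jaccard_overlap`). It is a sketch that mirrors `rbox_iou` rather than importing `utils.py` (so it runs without torch/CUDA); it only assumes `numpy` and `shapely` are installed, and the two example boxes are made-up values chosen for easy hand-checking.

```
# Standalone sanity check that mirrors rbox_iou() from utils.py.
import numpy as np
from shapely.geometry import Polygon, MultiPoint


def rbox_iou(a, b):
    # a, b: flat corner lists [x0, y0, x1, y1, x2, y2, x3, y3]
    pa = [[a[0], a[1]], [a[2], a[3]], [a[4], a[5]], [a[6], a[7]]]
    pb = [[b[0], b[1]], [b[2], b[3]], [b[4], b[5]], [b[6], b[7]]]
    poly1 = Polygon(pa).convex_hull
    poly2 = Polygon(pb).convex_hull
    if not poly1.intersects(poly2):
        return 0.0
    inter_area = poly1.intersection(poly2).area
    # As in the repository version, the denominator is the area of the convex hull
    # enclosing both quadrilaterals, not the strict geometric union.
    hull_area = MultiPoint(np.concatenate((pa, pb))).convex_hull.area
    return float(inter_area) / hull_area if hull_area > 0 else 0.0


if __name__ == '__main__':
    # Two axis-aligned 10x10 squares overlapping by half:
    # intersection = 50, enclosing hull = 15 x 10 = 150, so the score is 1/3.
    box_a = [0, 0, 10, 0, 10, 10, 0, 10]
    box_b = [5, 0, 15, 0, 15, 10, 5, 10]
    print(rbox_iou(box_a, box_b))  # ~0.3333
```

Because the denominator is the enclosing convex hull rather than the true union, this score is slightly stricter than the usual intersection-over-union whenever the union of the two boxes is not itself convex; `find_jaccard_overlap` simply evaluates it for every pair of boxes and returns an (n1, n2) tensor.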