├── LICENSE ├── README.md ├── config.py ├── voc0712_meta.py └── voc07test_meta.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 TengfeiZhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NIST-FSD 2 | NIST-FSD is a benchmark devised for few-shot object detection, based on the Pascal VOC dataset. 3 | This repo contains the code to build this dataset. Besides, we set a parameter IDX to switch different class partitions flexibly. 4 | 5 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from easydict import EasyDict as edict 6 | 7 | 8 | # config for meta-learning 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Set some path 13 | __C.voc0712_meta_info = './voc2017_metainfo' 14 | __C.voc07test_meta_info = './voc2007test_metainfo' 15 | __C.VOC0712_ROOT = './path to VOC2007 + VOC2012/VOCdevkit/' 16 | __C.VOC07testVOC_ROOT = './path to VOC2007test/VOCdevkit/' 17 | 18 | # idx for dataset (seen / unseen split) 19 | __C.IDX = 1 20 | 21 | # idx=1 22 | if __C.IDX == 1: 23 | __C.seen_classes = ['pottedplant', 'tvmonitor', 'sofa', 'motorbike', 24 | 'horse', 'boat', 'dog', 'bicycle', 'train', 25 | 'sheep','bottle', 'person', 'aeroplane', 'diningtable', 'bird'] 26 | __C.unseen_classes = ['cow', 'bus', 'cat', 'car', 'chair'] 27 | 28 | # idx=2 29 | elif __C.IDX == 2: 30 | __C.seen_classes = ['pottedplant', 'tvmonitor', 'sofa', 'bus', 31 | 'boat', 'bicycle', 'train', 'cow', 'cat', 32 | 'car', 'chair', 'sheep', 'bottle', 'aeroplane', 'bird'] 33 | __C.unseen_classes = ['diningtable', 'dog', 'horse', 'motorbike', 'person'] 34 | 35 | # idx=3 36 | elif __C.IDX == 3: 37 | __C.seen_classes = ['bus', 'motorbike', 38 | 'horse', 'boat', 'dog', 'bicycle', 'cow', 'cat', 39 | 'car', 'chair', 'bottle', 'person', 'aeroplane', 'diningtable', 'bird'] 40 | __C.unseen_classes = ['pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',] 41 | 42 | # idx=4 43 | elif __C.IDX == 4: 44 | __C.seen_classes = ['bus', 'motorbike', 'horse', 'dog', 'cow', 'cat', 45 | 'car', 'chair', 'person', 'diningtable', 'pottedplant', 46 | 'sheep', 'sofa', 'train', 'tvmonitor',] 47 | __C.unseen_classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',] 48 | 49 | else: 50 | raise NotImplementedError('No this idx {}'.format(__C.IDX)) 51 | # TODO other class partitions: 52 | -------------------------------------------------------------------------------- /voc0712_meta.py: -------------------------------------------------------------------------------- 1 | """ 2 | NIST-FSD dataset by ztf-ucas 3 | Reference: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | """ 6 | import os.path as osp 7 | import sys 8 | import torch 9 | import torch.utils.data as data 10 | import cv2 11 | import numpy as np 12 | from data.config import cfg 13 | import os 14 | import pickle 15 | 16 | if sys.version_info[0] == 2: 17 | import xml.etree.cElementTree as ET 18 | else: 19 | import xml.etree.ElementTree as ET 20 | 21 | VOC_CLASSES = [ # always index 0 22 | 'aeroplane', 'bicycle', 'bird', 'boat', 23 | 'bottle', 'bus', 'car', 'cat', 'chair', 24 | 'cow', 'diningtable', 'dog', 'horse', 25 | 'motorbike', 'person', 'pottedplant', 26 | 'sheep', 'sofa', 'train', 'tvmonitor'] 27 | 28 | VOC_ROOT = cfg.VOC0712_ROOT 29 | 30 | 31 | def get_all_crop(target): 32 | """ 33 | Arguments: 34 | target (annotation) : the target annotation to be made usable 35 | will be an ET.Element 36 | Returns: 37 | a list containing lists of bounding boxes [bbox coords, class name] 38 | """ 39 | res = [] 40 | keep_difficult = False 41 | for obj in target.iter('object'): 42 | difficult = int(obj.find('difficult').text) == 1 43 | if not keep_difficult and difficult: 44 | continue 45 | name = obj.find('name').text.lower().strip() 46 | bbox = obj.find('bndbox') 47 | 48 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 49 | bndbox = [] 50 | for i, pt in enumerate(pts): 51 | cur_pt = int(bbox.find(pt).text) - 1 52 | # scale height or width 53 | bndbox.append(cur_pt) 54 | bndbox.append(name) 55 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 56 | 57 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 58 | 59 | 60 | class VOCAnnotationTransform_meta(object): 61 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 62 | Initilized with a dictionary lookup of classnames to indexes 63 | 64 | Arguments: 65 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 66 | (default: alphabetic indexing of VOC's 20 classes) 67 | keep_difficult (bool, optional): keep difficult instances or not 68 | (default: False) 69 | height (int): height 70 | width (int): width 71 | """ 72 | 73 | def __init__(self, meta_classes=None, keep_difficult=False): 74 | self.class_to_ind = dict( 75 | zip(meta_classes, range(len(meta_classes)))) 76 | self.keep_difficult = keep_difficult 77 | 78 | def __call__(self, target, width, height, classes): 79 | """ 80 | Arguments: 81 | target (annotation) : the target annotation to be made usable 82 | will be an ET.Element 83 | Returns: 84 | a list containing lists of bounding boxes [bbox coords, class name] 85 | """ 86 | res = [] 87 | for obj in target.iter('object'): 88 | difficult = int(obj.find('difficult').text) == 1 89 | if not self.keep_difficult and difficult: 90 | continue 91 | name = obj.find('name').text.lower().strip() 92 | if name not in classes: 93 | continue 94 | bbox = obj.find('bndbox') 95 | 96 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 97 | bndbox = [] 98 | for i, pt in enumerate(pts): 99 | cur_pt = int(bbox.find(pt).text) - 1 100 | # scale height or width 101 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 102 | bndbox.append(cur_pt) 103 | label_idx = self.class_to_ind[name] 104 | bndbox.append(label_idx) 105 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 106 | 107 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 108 | 109 | 110 | class NIST_FSD(data.Dataset): 111 | """VOC Detection Dataset Object 112 | input is image, target is annotation 113 | Arguments: 114 | root (string): filepath to VOCdevkit folder. 115 | image_set (string): imageset to use (eg. 'train', 'val', 'test') 116 | transform (callable, optional): transformation to perform on the 117 | input image 118 | target_transform (callable, optional): transformation to perform on the 119 | target `annotation` 120 | (eg: take in caption string, return tensor of word indices) 121 | dataset_name (string, optional): which dataset to load 122 | (default: 'VOC2007') 123 | """ 124 | 125 | def __init__(self, root=VOC_ROOT, 126 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 127 | transform=None, target_transform=None, 128 | dataset_name='VOC0712', idx=cfg.IDX, meta_idx=None, meta_classes=None): 129 | # idx: id for class partition 130 | # meta_idx: samples for the current task 131 | # meta_class: classes for the current task 132 | self.root = root 133 | self.image_set = image_sets 134 | self.transform = transform 135 | self.target_transform = target_transform 136 | self.name = dataset_name 137 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 138 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 139 | self.ids = list() 140 | self.idx = idx 141 | self.meta_idx = meta_idx 142 | self.meta_classes = meta_classes 143 | # supervised learning setting 144 | if self.meta_classes is None: 145 | self.classes = VOC_CLASSES 146 | # meta-learning setting 147 | else: 148 | self.classes = self.meta_classes 149 | if self.meta_idx is None: 150 | for (year, name) in image_sets: 151 | rootpath = osp.join(self.root, 'VOC' + year) 152 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 153 | self.ids.append((rootpath, line.strip())) 154 | else: 155 | self.ids = self.meta_idx 156 | 157 | self.voc0712_meta_info = cfg.voc0712_meta_info 158 | if not osp.exists(self.voc0712_meta_info): 159 | os.mkdir(self.voc0712_meta_info) 160 | # get seen unseen classes 161 | self._get_seen_unseen_classes() 162 | 163 | # split images contain seen or unseen objects 164 | self.train_seen_pkl = osp.join(self.voc0712_meta_info, 'train_seen_' + str(self.idx) + '.pkl') 165 | self.train_unseen_pkl = osp.join(self.voc0712_meta_info, 'train_unseen_' + str(self.idx) + '.pkl') 166 | self.train_seen_unseen_pkl = osp.join(self.voc0712_meta_info, 'train_seen_unseen_' + str(self.idx) + '.pkl') 167 | if not osp.exists(self.train_seen_pkl): 168 | print('No seen_unseen split file, execute _filter_samples()') 169 | self._filter_samples() 170 | else: 171 | print('seen_unseen split file: ', self.train_seen_pkl) 172 | self.cls_dict_pkl = osp.join(self.voc0712_meta_info, 'cls_dict'+str(self.idx)+'.pkl') 173 | if not osp.exists(self.cls_dict_pkl): 174 | print('No cls_dict file, excute generate_cls_dict()') 175 | self.generate_cls_dict() 176 | else: 177 | print('cls_dict file: ', self.cls_dict_pkl) 178 | 179 | def __getitem__(self, index): 180 | im, gt, h, w = self.pull_item(index) 181 | 182 | return im, gt 183 | 184 | def __len__(self): 185 | return len(self.ids) 186 | 187 | def pull_item(self, index): 188 | img_id = self.ids[index] 189 | target = ET.parse(self._annopath % img_id).getroot() 190 | img = cv2.imread(self._imgpath % img_id) 191 | height, width, channels = img.shape 192 | 193 | if self.target_transform is not None: 194 | target = self.target_transform(target, width, height, self.classes) 195 | 196 | if self.transform is not None: 197 | target = np.array(target) 198 | # [[xmin, ymin, xmax, ymax, label_ind], ...] 199 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 200 | # to rgb 201 | img = img[:, :, (2, 1, 0)] 202 | # img = img.transpose(2, 0, 1) 203 | # [[xmin, ymin, xmax, ymax, label_ind], ...] 204 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 205 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 206 | 207 | def pull_image(self, index): 208 | '''Returns the original image object at index in PIL form 209 | 210 | Note: not using self.__getitem__(), as any transformations passed in 211 | could mess up this functionality. 212 | 213 | Argument: 214 | index (int): index of img to show 215 | Return: 216 | PIL img 217 | ''' 218 | img_id = self.ids[index] 219 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 220 | 221 | def pull_anno(self, index): 222 | '''Returns the original annotation of image at index 223 | 224 | Note: not using self.__getitem__(), as any transformations passed in 225 | could mess up this functionality. 226 | 227 | Argument: 228 | index (int): index of img to get annotation of 229 | Return: 230 | list: [img_id, [(label, bbox coords),...]] 231 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 232 | ''' 233 | img_id = self.ids[index] 234 | anno = ET.parse(self._annopath % img_id).getroot() 235 | gt = self.target_transform(anno, 1, 1) 236 | return img_id[1], gt 237 | 238 | def pull_tensor(self, index): 239 | '''Returns the original image at an index in tensor form 240 | 241 | Note: not using self.__getitem__(), as any transformations passed in 242 | could mess up this functionality. 243 | 244 | Argument: 245 | index (int): index of img to show 246 | Return: 247 | tensorized version of img, squeezed 248 | ''' 249 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 250 | 251 | def _get_seen_unseen_classes(self): 252 | self.seen_classes = cfg.seen_classes 253 | self.unseen_classes = cfg.unseen_classes 254 | 255 | def _filter_samples(self): 256 | """ 257 | train_seen: images contain only seen objects 258 | train_unseen: images contain only unseen objects 259 | train_seen_unseen: images contain seen and unseen objects 260 | """ 261 | train_seen_names = [] 262 | train_unseen_names = [] 263 | train_seen_unseen_names = [] 264 | name_list = self.ids 265 | for index, name in enumerate(name_list): 266 | xml_path = self._annopath % name 267 | target = ET.parse(xml_path).getroot() 268 | boxes = get_all_crop(target) 269 | cls_names = [box_info[-1] for box_info in boxes] 270 | # only seen class 271 | if self._check_if_clas_intersect(cls_names) == 's': 272 | train_seen_names.append(name) 273 | # only unseen class 274 | if self._check_if_clas_intersect(cls_names) == 'u': 275 | train_unseen_names.append(name) 276 | # only seen class 277 | if self._check_if_clas_intersect(cls_names) == 'su': 278 | train_seen_unseen_names.append(name) 279 | # seen (train+val) 280 | with open(os.path.join(self.voc0712_meta_info, 'train_seen_'+str(self.idx)+'.pkl'), 'wb') as f: 281 | pickle.dump(train_seen_names, f) 282 | # unseen (train+val) 283 | with open(os.path.join(self.voc0712_meta_info, 'train_unseen_' + str(self.idx) + '.pkl'), 'wb') as f: 284 | pickle.dump(train_unseen_names, f) 285 | # seen_unseen (train+val) 286 | with open(os.path.join(self.voc0712_meta_info, 'train_seen_unseen_' + str(self.idx) + '.pkl'), 'wb') as f: 287 | pickle.dump(train_seen_unseen_names, f) 288 | 289 | def _check_if_clas_intersect(self, cls_names): 290 | """ 291 | s: this image contains only seen objects 292 | u: this image contains only unseen objects 293 | su: this image contains seen and unseen objects 294 | """ 295 | seen_flag = True 296 | unseen_flag = True 297 | for cls_name in cls_names: 298 | if cls_name not in self.seen_classes: 299 | seen_flag = False 300 | if cls_name not in self.unseen_classes: 301 | unseen_flag = False 302 | 303 | if seen_flag: 304 | return 's' 305 | if unseen_flag: 306 | return 'u' 307 | else: 308 | return 'su' 309 | 310 | def generate_cls_dict(self): 311 | """ 312 | { 313 | cls_name1: [idx1, idx2, ...], 314 | cls_name2: [idx1, idx2, ...], 315 | ... 316 | } 317 | """ 318 | # process seen classes 319 | cls_dict = {} 320 | with open(self.train_seen_pkl, 'rb') as f: 321 | seen_idx_list = pickle.load(f) 322 | for cls_name in self.seen_classes: 323 | cls_dict[cls_name] = [] 324 | 325 | for idx in seen_idx_list: 326 | xml_path = self._annopath % idx 327 | target = ET.parse(xml_path).getroot() 328 | boxes = get_all_crop(target) 329 | cls_names = [box_info[-1] for box_info in boxes] 330 | for cls_name in self.seen_classes: 331 | if cls_name in cls_names: 332 | cls_dict[cls_name].append(idx) 333 | # process unseen classes 334 | with open(self.train_unseen_pkl, 'rb') as f: 335 | unseen_idx_list = pickle.load(f) 336 | for cls_name in self.unseen_classes: 337 | cls_dict[cls_name] = [] 338 | for idx in unseen_idx_list: 339 | xml_path = self._annopath % idx 340 | target = ET.parse(xml_path).getroot() 341 | boxes = get_all_crop(target) 342 | cls_names = [box_info[-1] for box_info in boxes] 343 | for cls_name in self.unseen_classes: 344 | if cls_name in cls_names: 345 | cls_dict[cls_name].append(idx) 346 | # save cls_dict into disk 347 | with open(self.cls_dict_pkl, 'wb') as f: 348 | pickle.dump(cls_dict, f) 349 | 350 | 351 | 352 | if __name__ == '__main__': 353 | # test 354 | db = NIST_FSD(root=VOC_ROOT,transform=None, idx=1) 355 | print(db.seen_classes) 356 | print(db.unseen_classes) 357 | -------------------------------------------------------------------------------- /voc07test_meta.py: -------------------------------------------------------------------------------- 1 | """ 2 | NIST-FSD dataset by ztf-ucas 3 | Reference: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | """ 6 | import os.path as osp 7 | import sys 8 | import torch 9 | import torch.utils.data as data 10 | import cv2 11 | import numpy as np 12 | from data.config import cfg 13 | import os 14 | import pickle 15 | 16 | if sys.version_info[0] == 2: 17 | import xml.etree.cElementTree as ET 18 | else: 19 | import xml.etree.ElementTree as ET 20 | 21 | VOC_CLASSES = [ # always index 0 22 | 'aeroplane', 'bicycle', 'bird', 'boat', 23 | 'bottle', 'bus', 'car', 'cat', 'chair', 24 | 'cow', 'diningtable', 'dog', 'horse', 25 | 'motorbike', 'person', 'pottedplant', 26 | 'sheep', 'sofa', 'train', 'tvmonitor'] 27 | 28 | VOC_ROOT = cfg.VOC07testVOC_ROOT 29 | 30 | 31 | def get_all_crop(target): 32 | """ 33 | Arguments: 34 | target (annotation) : the target annotation to be made usable 35 | will be an ET.Element 36 | Returns: 37 | a list containing lists of bounding boxes [bbox coords, class name] 38 | """ 39 | res = [] 40 | keep_difficult = False 41 | for obj in target.iter('object'): 42 | difficult = int(obj.find('difficult').text) == 1 43 | if not keep_difficult and difficult: 44 | continue 45 | name = obj.find('name').text.lower().strip() 46 | bbox = obj.find('bndbox') 47 | 48 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 49 | bndbox = [] 50 | for i, pt in enumerate(pts): 51 | cur_pt = int(bbox.find(pt).text) - 1 52 | # scale height or width 53 | bndbox.append(cur_pt) 54 | bndbox.append(name) 55 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 56 | # img_id = target.find('filename').text[:-4] 57 | 58 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 59 | 60 | 61 | class VOC07test_meta(data.Dataset): 62 | """VOC Detection Dataset Object 63 | 64 | input is image, target is annotation 65 | 66 | Arguments: 67 | root (string): filepath to VOCdevkit folder. 68 | image_set (string): imageset to use (eg. 'train', 'val', 'test') 69 | transform (callable, optional): transformation to perform on the 70 | input image 71 | target_transform (callable, optional): transformation to perform on the 72 | target `annotation` 73 | (eg: take in caption string, return tensor of word indices) 74 | dataset_name (string, optional): which dataset to load 75 | (default: 'VOC2007') 76 | """ 77 | 78 | def __init__(self, root=VOC_ROOT, 79 | image_sets=[('2007', 'test')], 80 | transform=None, target_transform=None, 81 | dataset_name='VOC07test', idx=cfg.IDX, meta_idx=None, meta_classes=None): 82 | # idx: id for class partition 83 | # meta_idx: samples for the current task 84 | # meta_class: classes for the current task 85 | self.root = root 86 | self.image_set = image_sets 87 | self.transform = transform 88 | self.target_transform = target_transform 89 | self.name = dataset_name 90 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 91 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 92 | self.ids = list() 93 | self.idx = idx 94 | self.meta_idx = meta_idx 95 | self.meta_classes = meta_classes 96 | if self.meta_classes is None: 97 | self.classes = VOC_CLASSES 98 | else: # meta 99 | self.classes = self.meta_classes 100 | if self.meta_idx is None: # not meta 101 | for (year, name) in image_sets: 102 | rootpath = osp.join(self.root, 'VOC' + year) 103 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 104 | self.ids.append((rootpath, line.strip())) 105 | else: # meta 106 | self.ids = self.meta_idx 107 | 108 | self.voc07test_meta_info = cfg.voc07test_meta_info 109 | if not osp.exists(self.voc07test_meta_info): 110 | os.mkdir(self.voc07test_meta_info) 111 | # get seen unseen classes 112 | self._get_seen_unseen_classes() 113 | 114 | # split images contain seen or unseen objects 115 | self.train_seen_pkl = osp.join(self.voc07test_meta_info, 'train_seen_' + str(self.idx) + '.pkl') 116 | self.train_unseen_pkl = osp.join(self.voc07test_meta_info, 'train_unseen_' + str(self.idx) + '.pkl') 117 | self.train_seen_unseen_pkl = osp.join(self.voc07test_meta_info, 'train_seen_unseen_' + str(self.idx) + '.pkl') 118 | if not osp.exists(self.train_seen_pkl): 119 | print('No seen_unseen split file, execute _filter_samples()') 120 | self._filter_samples() 121 | else: 122 | print('seen_unseen split file: ', self.train_seen_pkl) 123 | self.cls_dict_pkl = osp.join(self.voc07test_meta_info, 'cls_dict'+str(self.idx)+'.pkl') 124 | if not osp.exists(self.cls_dict_pkl): 125 | print('No cls_dict file, excute generate_cls_dict()') 126 | self.generate_cls_dict() 127 | else: 128 | print('cls_dict file: ', self.cls_dict_pkl) 129 | 130 | def __getitem__(self, index): 131 | im, gt, h, w = self.pull_item(index) 132 | 133 | return im, gt 134 | 135 | def __len__(self): 136 | return len(self.ids) 137 | 138 | def pull_item(self, index): 139 | img_id = self.ids[index] 140 | target = ET.parse(self._annopath % img_id).getroot() 141 | img = cv2.imread(self._imgpath % img_id) 142 | height, width, channels = img.shape 143 | 144 | if self.target_transform is not None: 145 | target = self.target_transform(target, width, height, self.classes) 146 | 147 | if self.transform is not None: 148 | target = np.array(target) 149 | # [[xmin, ymin, xmax, ymax, label_ind], ...] 150 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 151 | # to rgb 152 | img = img[:, :, (2, 1, 0)] 153 | # img = img.transpose(2, 0, 1) 154 | # [[xmin, ymin, xmax, ymax, label_ind], ...] 155 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 156 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 157 | 158 | 159 | def pull_image(self, index): 160 | '''Returns the original image object at index in PIL form 161 | 162 | Note: not using self.__getitem__(), as any transformations passed in 163 | could mess up this functionality. 164 | 165 | Argument: 166 | index (int): index of img to show 167 | Return: 168 | PIL img 169 | ''' 170 | img_id = self.ids[index] 171 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 172 | 173 | def pull_anno(self, index): 174 | '''Returns the original annotation of image at index 175 | 176 | Note: not using self.__getitem__(), as any transformations passed in 177 | could mess up this functionality. 178 | 179 | Argument: 180 | index (int): index of img to get annotation of 181 | Return: 182 | list: [img_id, [(label, bbox coords),...]] 183 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 184 | ''' 185 | img_id = self.ids[index] 186 | anno = ET.parse(self._annopath % img_id).getroot() 187 | gt = self.target_transform(anno, 1, 1) 188 | return img_id[1], gt 189 | 190 | def pull_tensor(self, index): 191 | '''Returns the original image at an index in tensor form 192 | 193 | Note: not using self.__getitem__(), as any transformations passed in 194 | could mess up this functionality. 195 | 196 | Argument: 197 | index (int): index of img to show 198 | Return: 199 | tensorized version of img, squeezed 200 | ''' 201 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 202 | 203 | def _get_seen_unseen_classes(self): 204 | self.seen_classes = cfg.seen_classes 205 | self.unseen_classes = cfg.unseen_classes 206 | 207 | def _filter_samples(self): 208 | """ 209 | train_seen: images contain only seen objects 210 | train_unseen: images contain only unseen objects 211 | train_seen_unseen: images contain seen and unseen objects 212 | """ 213 | train_seen_names = [] 214 | train_unseen_names = [] 215 | train_seen_unseen_names = [] 216 | name_list = self.ids 217 | for index, name in enumerate(name_list): 218 | xml_path = self._annopath % name 219 | target = ET.parse(xml_path).getroot() 220 | boxes = get_all_crop(target) 221 | cls_names = [box_info[-1] for box_info in boxes] 222 | # only seen class 223 | if self._check_if_clas_intersect(cls_names) == 's': 224 | train_seen_names.append(name) 225 | # only unseen class 226 | if self._check_if_clas_intersect(cls_names) == 'u': 227 | train_unseen_names.append(name) 228 | # only seen class 229 | if self._check_if_clas_intersect(cls_names) == 'su': 230 | train_seen_unseen_names.append(name) 231 | # seen (train+val) 232 | with open(os.path.join(self.voc0712_meta_info, 'train_seen_'+str(self.idx)+'.pkl'), 'wb') as f: 233 | pickle.dump(train_seen_names, f) 234 | # unseen (train+val) 235 | with open(os.path.join(self.voc0712_meta_info, 'train_unseen_' + str(self.idx) + '.pkl'), 'wb') as f: 236 | pickle.dump(train_unseen_names, f) 237 | # seen_unseen (train+val) 238 | with open(os.path.join(self.voc0712_meta_info, 'train_seen_unseen_' + str(self.idx) + '.pkl'), 'wb') as f: 239 | pickle.dump(train_seen_unseen_names, f) 240 | 241 | def _check_if_clas_intersect(self, cls_names): 242 | """ 243 | s: this image contains only seen objects 244 | u: this image contains only unseen objects 245 | su: this image contains seen and unseen objects 246 | """ 247 | seen_flag = True 248 | unseen_flag = True 249 | for cls_name in cls_names: 250 | if cls_name not in self.seen_classes: 251 | seen_flag = False 252 | if cls_name not in self.unseen_classes: 253 | unseen_flag = False 254 | 255 | if seen_flag: 256 | return 's' 257 | if unseen_flag: 258 | return 'u' 259 | else: 260 | return 'su' 261 | 262 | def generate_cls_dict(self): 263 | """ 264 | { 265 | cls_name1: [idx1, idx2, ...], 266 | cls_name2: [idx1, idx2, ...], 267 | ... 268 | } 269 | """ 270 | # process seen classes 271 | cls_dict = {} 272 | with open(self.train_seen_pkl, 'rb') as f: 273 | seen_idx_list = pickle.load(f) 274 | for cls_name in self.seen_classes: 275 | cls_dict[cls_name] = [] 276 | 277 | for idx in seen_idx_list: 278 | xml_path = self._annopath % idx 279 | target = ET.parse(xml_path).getroot() 280 | boxes = get_all_crop(target) 281 | cls_names = [box_info[-1] for box_info in boxes] 282 | for cls_name in self.seen_classes: 283 | if cls_name in cls_names: 284 | cls_dict[cls_name].append(idx) 285 | # process unseen classes 286 | with open(self.train_unseen_pkl, 'rb') as f: 287 | unseen_idx_list = pickle.load(f) 288 | for cls_name in self.unseen_classes: 289 | cls_dict[cls_name] = [] 290 | for idx in unseen_idx_list: 291 | xml_path = self._annopath % idx 292 | target = ET.parse(xml_path).getroot() 293 | boxes = get_all_crop(target) 294 | cls_names = [box_info[-1] for box_info in boxes] 295 | for cls_name in self.unseen_classes: 296 | if cls_name in cls_names: 297 | cls_dict[cls_name].append(idx) 298 | # save cls_dict into disk 299 | with open(self.cls_dict_pkl, 'wb') as f: 300 | pickle.dump(cls_dict, f) 301 | 302 | 303 | 304 | if __name__ == '__main__': 305 | db = VOC07test_meta(root=VOC_ROOT,transform=None) 306 | print(db.seen_classes) 307 | print(db.unseen_classes) 308 | 309 | 310 | 311 | --------------------------------------------------------------------------------