├── .gitattributes ├── coco_eval.py ├── coco_utils.py ├── engine.py ├── group_by_aspect_ratio.py ├── keypoint_rcnn.py ├── plot.py ├── predict_visualize.py ├── readme.md ├── result ├── 10.jpg ├── 14.jpg ├── 5.jpg ├── 8.jpg └── 9.jpg ├── train.py ├── transforms.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /coco_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | 4 | import numpy as np 5 | import copy 6 | import time 7 | import torch 8 | import torch._six 9 | 10 | from pycocotools.cocoeval import COCOeval 11 | from pycocotools.coco import COCO 12 | import pycocotools.mask as mask_util 13 | 14 | from collections import defaultdict 15 | 16 | import utils 17 | 18 | 19 | class CocoEvaluator(object): 20 | def __init__(self, coco_gt, iou_types): 21 | assert isinstance(iou_types, (list, tuple)) 22 | coco_gt = copy.deepcopy(coco_gt) 23 | self.coco_gt = coco_gt 24 | 25 | self.iou_types = iou_types 26 | self.coco_eval = {} 27 | for iou_type in iou_types: 28 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 29 | 30 | self.img_ids = [] 31 | self.eval_imgs = {k: [] for k in iou_types} 32 | 33 | def update(self, predictions): 34 | img_ids = list(np.unique(list(predictions.keys()))) 35 | self.img_ids.extend(img_ids) 36 | 37 | for iou_type in self.iou_types: 38 | results = self.prepare(predictions, iou_type) 39 | coco_dt = loadRes(self.coco_gt, results) if results else COCO() 40 | coco_eval = self.coco_eval[iou_type] 41 | 42 | coco_eval.cocoDt = coco_dt 43 | coco_eval.params.imgIds = list(img_ids) 44 | img_ids, eval_imgs = evaluate(coco_eval) 45 | 46 | self.eval_imgs[iou_type].append(eval_imgs) 47 | 48 | def synchronize_between_processes(self): 49 | for iou_type in self.iou_types: 50 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 51 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 52 | 53 | def accumulate(self): 54 | for coco_eval in self.coco_eval.values(): 55 | coco_eval.accumulate() 56 | 57 | def summarize(self): 58 | for iou_type, coco_eval in self.coco_eval.items(): 59 | print("IoU metric: {}".format(iou_type)) 60 | coco_eval.summarize() 61 | 62 | def prepare(self, predictions, iou_type): 63 | if iou_type == "bbox": 64 | return self.prepare_for_coco_detection(predictions) 65 | elif iou_type == "segm": 66 | return self.prepare_for_coco_segmentation(predictions) 67 | elif iou_type == "keypoints": 68 | return self.prepare_for_coco_keypoint(predictions) 69 | else: 70 | raise ValueError("Unknown iou type {}".format(iou_type)) 71 | 72 | def prepare_for_coco_detection(self, predictions): 73 | coco_results = [] 74 | for original_id, prediction in predictions.items(): 75 | if len(prediction) == 0: 76 | continue 77 | 78 | boxes = prediction["boxes"] 79 | boxes = convert_to_xywh(boxes).tolist() 80 | scores = prediction["scores"].tolist() 81 | labels = prediction["labels"].tolist() 82 | 83 | coco_results.extend( 84 | [ 85 | { 86 | "image_id": original_id, 87 | "category_id": labels[k], 88 | "bbox": box, 89 | "score": scores[k], 90 | } 91 | for k, box in enumerate(boxes) 92 | ] 93 | ) 94 | return coco_results 95 | 96 | def prepare_for_coco_segmentation(self, predictions): 97 | coco_results = [] 98 | for original_id, prediction in 
predictions.items(): 99 | if len(prediction) == 0: 100 | continue 101 | 102 | scores = prediction["scores"] 103 | labels = prediction["labels"] 104 | masks = prediction["masks"] 105 | 106 | masks = masks > 0.5 107 | 108 | scores = prediction["scores"].tolist() 109 | labels = prediction["labels"].tolist() 110 | 111 | rles = [ 112 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] 113 | for mask in masks 114 | ] 115 | for rle in rles: 116 | rle["counts"] = rle["counts"].decode("utf-8") 117 | 118 | coco_results.extend( 119 | [ 120 | { 121 | "image_id": original_id, 122 | "category_id": labels[k], 123 | "segmentation": rle, 124 | "score": scores[k], 125 | } 126 | for k, rle in enumerate(rles) 127 | ] 128 | ) 129 | return coco_results 130 | 131 | def prepare_for_coco_keypoint(self, predictions): 132 | coco_results = [] 133 | for original_id, prediction in predictions.items(): 134 | if len(prediction) == 0: 135 | continue 136 | 137 | boxes = prediction["boxes"] 138 | boxes = convert_to_xywh(boxes).tolist() 139 | scores = prediction["scores"].tolist() 140 | labels = prediction["labels"].tolist() 141 | keypoints = prediction["keypoints"] 142 | keypoints = keypoints.flatten(start_dim=1).tolist() 143 | 144 | coco_results.extend( 145 | [ 146 | { 147 | "image_id": original_id, 148 | "category_id": labels[k], 149 | 'keypoints': keypoint, 150 | "score": scores[k], 151 | } 152 | for k, keypoint in enumerate(keypoints) 153 | ] 154 | ) 155 | return coco_results 156 | 157 | 158 | def convert_to_xywh(boxes): 159 | xmin, ymin, xmax, ymax = boxes.unbind(1) 160 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 161 | 162 | 163 | def merge(img_ids, eval_imgs): 164 | all_img_ids = utils.all_gather(img_ids) 165 | all_eval_imgs = utils.all_gather(eval_imgs) 166 | 167 | merged_img_ids = [] 168 | for p in all_img_ids: 169 | merged_img_ids.extend(p) 170 | 171 | merged_eval_imgs = [] 172 | for p in all_eval_imgs: 173 | merged_eval_imgs.append(p) 174 | 175 | merged_img_ids = np.array(merged_img_ids) 176 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 177 | 178 | # keep only unique (and in sorted order) images 179 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 180 | merged_eval_imgs = merged_eval_imgs[..., idx] 181 | 182 | return merged_img_ids, merged_eval_imgs 183 | 184 | 185 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 186 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 187 | img_ids = list(img_ids) 188 | eval_imgs = list(eval_imgs.flatten()) 189 | 190 | coco_eval.evalImgs = eval_imgs 191 | coco_eval.params.imgIds = img_ids 192 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 193 | 194 | 195 | ################################################################# 196 | # From pycocotools, just removed the prints and fixed 197 | # a Python3 bug about unicode not defined 198 | ################################################################# 199 | 200 | # Ideally, pycocotools wouldn't have hard-coded prints 201 | # so that we could avoid copy-pasting those two functions 202 | 203 | def createIndex(self): 204 | # create index 205 | # print('creating index...') 206 | anns, cats, imgs = {}, {}, {} 207 | imgToAnns, catToImgs = defaultdict(list), defaultdict(list) 208 | if 'annotations' in self.dataset: 209 | for ann in self.dataset['annotations']: 210 | imgToAnns[ann['image_id']].append(ann) 211 | anns[ann['id']] = ann 212 | 213 | if 'images' in self.dataset: 214 | for img in self.dataset['images']: 215 | imgs[img['id']] = img 216 | 217 
| if 'categories' in self.dataset: 218 | for cat in self.dataset['categories']: 219 | cats[cat['id']] = cat 220 | 221 | if 'annotations' in self.dataset and 'categories' in self.dataset: 222 | for ann in self.dataset['annotations']: 223 | catToImgs[ann['category_id']].append(ann['image_id']) 224 | 225 | # print('index created!') 226 | 227 | # create class members 228 | self.anns = anns 229 | self.imgToAnns = imgToAnns 230 | self.catToImgs = catToImgs 231 | self.imgs = imgs 232 | self.cats = cats 233 | 234 | 235 | maskUtils = mask_util 236 | 237 | 238 | def loadRes(self, resFile): 239 | """ 240 | Load result file and return a result api object. 241 | :param resFile (str) : file name of result file 242 | :return: res (obj) : result api object 243 | """ 244 | res = COCO() 245 | res.dataset['images'] = [img for img in self.dataset['images']] 246 | 247 | # print('Loading and preparing results...') 248 | # tic = time.time() 249 | if isinstance(resFile, torch._six.string_classes): 250 | anns = json.load(open(resFile)) 251 | elif type(resFile) == np.ndarray: 252 | anns = self.loadNumpyAnnotations(resFile) 253 | else: 254 | anns = resFile 255 | assert type(anns) == list, 'results in not an array of objects' 256 | annsImgIds = [ann['image_id'] for ann in anns] 257 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 258 | 'Results do not correspond to current coco set' 259 | if 'caption' in anns[0]: 260 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 261 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 262 | for id, ann in enumerate(anns): 263 | ann['id'] = id + 1 264 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 265 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 266 | for id, ann in enumerate(anns): 267 | bb = ann['bbox'] 268 | x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] 269 | if 'segmentation' not in ann: 270 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 271 | ann['area'] = bb[2] * bb[3] 272 | ann['id'] = id + 1 273 | ann['iscrowd'] = 0 274 | elif 'segmentation' in anns[0]: 275 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 276 | for id, ann in enumerate(anns): 277 | # now only support compressed RLE format as segmentation results 278 | ann['area'] = maskUtils.area(ann['segmentation']) 279 | if 'bbox' not in ann: 280 | ann['bbox'] = maskUtils.toBbox(ann['segmentation']) 281 | ann['id'] = id + 1 282 | ann['iscrowd'] = 0 283 | elif 'keypoints' in anns[0]: 284 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 285 | for id, ann in enumerate(anns): 286 | s = ann['keypoints'] 287 | x = s[0::3] 288 | y = s[1::3] 289 | x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y) 290 | ann['area'] = (x1 - x0) * (y1 - y0) 291 | ann['id'] = id + 1 292 | ann['bbox'] = [x0, y0, x1 - x0, y1 - y0] 293 | # print('DONE (t={:0.2f}s)'.format(time.time()- tic)) 294 | 295 | res.dataset['annotations'] = anns 296 | createIndex(res) 297 | return res 298 | 299 | 300 | def evaluate(self): 301 | ''' 302 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 303 | :return: None 304 | ''' 305 | # tic = time.time() 306 | # print('Running per image evaluation...') 307 | p = self.params 308 | # add backward compatibility if useSegm is specified in params 309 | if p.useSegm is not None: 310 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 311 | print('useSegm 
(deprecated) is not None. Running {} evaluation'.format(p.iouType)) 312 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 313 | p.imgIds = list(np.unique(p.imgIds)) 314 | if p.useCats: 315 | p.catIds = list(np.unique(p.catIds)) 316 | p.maxDets = sorted(p.maxDets) 317 | self.params = p 318 | 319 | self._prepare() 320 | # loop through images, area range, max detection number 321 | catIds = p.catIds if p.useCats else [-1] 322 | 323 | if p.iouType == 'segm' or p.iouType == 'bbox': 324 | computeIoU = self.computeIoU 325 | elif p.iouType == 'keypoints': 326 | computeIoU = self.computeOks 327 | self.ious = { 328 | (imgId, catId): computeIoU(imgId, catId) 329 | for imgId in p.imgIds 330 | for catId in catIds} 331 | 332 | evaluateImg = self.evaluateImg 333 | maxDet = p.maxDets[-1] 334 | evalImgs = [ 335 | evaluateImg(imgId, catId, areaRng, maxDet) 336 | for catId in catIds 337 | for areaRng in p.areaRng 338 | for imgId in p.imgIds 339 | ] 340 | # this is NOT in the pycocotools code, but could be done outside 341 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 342 | self._paramsEval = copy.deepcopy(self.params) 343 | # toc = time.time() 344 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 345 | return p.imgIds, evalImgs 346 | 347 | ################################################################# 348 | # end of straight copy from pycocotools, just removing the prints 349 | ################################################################# 350 | -------------------------------------------------------------------------------- /coco_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.utils.data 7 | import torchvision 8 | 9 | from pycocotools import mask as coco_mask 10 | from pycocotools.coco import COCO 11 | 12 | import transforms as T 13 | 14 | 15 | class FilterAndRemapCocoCategories(object): 16 | def __init__(self, categories, remap=True): 17 | self.categories = categories 18 | self.remap = remap 19 | 20 | def __call__(self, image, target): 21 | anno = target["annotations"] 22 | anno = [obj for obj in anno if obj["category_id"] in self.categories] 23 | if not self.remap: 24 | target["annotations"] = anno 25 | return image, target 26 | anno = copy.deepcopy(anno) 27 | for obj in anno: 28 | obj["category_id"] = self.categories.index(obj["category_id"]) 29 | target["annotations"] = anno 30 | return image, target 31 | 32 | 33 | def convert_coco_poly_to_mask(segmentations, height, width): 34 | masks = [] 35 | for polygons in segmentations: 36 | rles = coco_mask.frPyObjects(polygons, height, width) 37 | mask = coco_mask.decode(rles) 38 | if len(mask.shape) < 3: 39 | mask = mask[..., None] 40 | mask = torch.as_tensor(mask, dtype=torch.uint8) 41 | mask = mask.any(dim=2) 42 | masks.append(mask) 43 | if masks: 44 | masks = torch.stack(masks, dim=0) 45 | else: 46 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 47 | return masks 48 | 49 | 50 | class ConvertCocoPolysToMask(object): 51 | def __call__(self, image, target): 52 | w, h = image.size 53 | 54 | image_id = target["image_id"] 55 | image_id = torch.tensor([image_id]) 56 | 57 | anno = target["annotations"] 58 | 59 | anno = [obj for obj in anno if obj['iscrowd'] == 0] 60 | 61 | boxes = [obj["bbox"] for obj in anno] 62 | # guard against no boxes via resizing 63 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 64 | boxes[:, 2:] += boxes[:, :2] 65 | 
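# COCO annotations store boxes as [x, y, width, height]; the addition above turns them into [x0, y0, x1, y1] corners, and the clamps that follow keep those corners inside the image bounds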
boxes[:, 0::2].clamp_(min=0, max=w) 66 | boxes[:, 1::2].clamp_(min=0, max=h) 67 | 68 | classes = [obj["category_id"] for obj in anno] 69 | classes = torch.tensor(classes, dtype=torch.int64) 70 | 71 | segmentations = [obj["segmentation"] for obj in anno] 72 | masks = convert_coco_poly_to_mask(segmentations, h, w) 73 | 74 | keypoints = None 75 | if anno and "keypoints" in anno[0]: 76 | keypoints = [obj["keypoints"] for obj in anno] 77 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 78 | num_keypoints = keypoints.shape[0] 79 | if num_keypoints: 80 | keypoints = keypoints.view(num_keypoints, -1, 3) 81 | 82 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 83 | boxes = boxes[keep] 84 | classes = classes[keep] 85 | masks = masks[keep] 86 | if keypoints is not None: 87 | keypoints = keypoints[keep] 88 | 89 | target = {} 90 | target["boxes"] = boxes 91 | target["labels"] = classes 92 | target["masks"] = masks 93 | target["image_id"] = image_id 94 | if keypoints is not None: 95 | target["keypoints"] = keypoints 96 | 97 | # for conversion to coco api 98 | area = torch.tensor([obj["area"] for obj in anno]) 99 | iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) 100 | target["area"] = area 101 | target["iscrowd"] = iscrowd 102 | 103 | return image, target 104 | 105 | 106 | def _coco_remove_images_without_annotations(dataset, cat_list=None): 107 | def _has_only_empty_bbox(anno): 108 | return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) 109 | 110 | def _count_visible_keypoints(anno): 111 | return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) 112 | 113 | min_keypoints_per_image = 10 114 | 115 | def _has_valid_annotation(anno): 116 | # if it's empty, there is no annotation 117 | if len(anno) == 0: 118 | return False 119 | # if all boxes have close to zero area, there is no annotation 120 | if _has_only_empty_bbox(anno): 121 | return False 122 | # keypoints task have a slight different critera for considering 123 | # if an annotation is valid 124 | if "keypoints" not in anno[0]: 125 | return True 126 | # for keypoint detection tasks, only consider valid images those 127 | # containing at least min_keypoints_per_image 128 | if _count_visible_keypoints(anno) >= min_keypoints_per_image: 129 | return True 130 | return False 131 | 132 | assert isinstance(dataset, torchvision.datasets.CocoDetection) 133 | ids = [] 134 | for ds_idx, img_id in enumerate(dataset.ids): 135 | ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) 136 | anno = dataset.coco.loadAnns(ann_ids) 137 | if cat_list: 138 | anno = [obj for obj in anno if obj["category_id"] in cat_list] 139 | if _has_valid_annotation(anno): 140 | ids.append(ds_idx) 141 | 142 | dataset = torch.utils.data.Subset(dataset, ids) 143 | return dataset 144 | 145 | 146 | def convert_to_coco_api(ds): 147 | coco_ds = COCO() 148 | ann_id = 0 149 | dataset = {'images': [], 'categories': [], 'annotations': []} 150 | categories = set() 151 | for img_idx in range(len(ds)): 152 | # find better way to get target 153 | # targets = ds.get_annotations(img_idx) 154 | img, targets = ds[img_idx] 155 | image_id = targets["image_id"].item() 156 | img_dict = {} 157 | img_dict['id'] = image_id 158 | img_dict['height'] = img.shape[-2] 159 | img_dict['width'] = img.shape[-1] 160 | dataset['images'].append(img_dict) 161 | bboxes = targets["boxes"] 162 | bboxes[:, 2:] -= bboxes[:, :2] 163 | bboxes = bboxes.tolist() 164 | labels = targets['labels'].tolist() 165 | areas = targets['area'].tolist() 166 | iscrowd = 
targets['iscrowd'].tolist() 167 | if 'masks' in targets: 168 | masks = targets['masks'] 169 | # make masks Fortran contiguous for coco_mask 170 | masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) 171 | if 'keypoints' in targets: 172 | keypoints = targets['keypoints'] 173 | keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() 174 | num_objs = len(bboxes) 175 | for i in range(num_objs): 176 | ann = {} 177 | ann['image_id'] = image_id 178 | ann['bbox'] = bboxes[i] 179 | ann['category_id'] = labels[i] 180 | categories.add(labels[i]) 181 | ann['area'] = areas[i] 182 | ann['iscrowd'] = iscrowd[i] 183 | ann['id'] = ann_id 184 | if 'masks' in targets: 185 | ann["segmentation"] = coco_mask.encode(masks[i].numpy()) 186 | if 'keypoints' in targets: 187 | ann['keypoints'] = keypoints[i] 188 | ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) 189 | dataset['annotations'].append(ann) 190 | ann_id += 1 191 | dataset['categories'] = [{'id': i} for i in sorted(categories)] 192 | coco_ds.dataset = dataset 193 | coco_ds.createIndex() 194 | return coco_ds 195 | 196 | 197 | def get_coco_api_from_dataset(dataset): 198 | for i in range(10): 199 | if isinstance(dataset, torchvision.datasets.CocoDetection): 200 | break 201 | if isinstance(dataset, torch.utils.data.Subset): 202 | dataset = dataset.dataset 203 | if isinstance(dataset, torchvision.datasets.CocoDetection): 204 | return dataset.coco 205 | return convert_to_coco_api(dataset) 206 | 207 | 208 | class CocoDetection(torchvision.datasets.CocoDetection): 209 | def __init__(self, img_folder, ann_file, transforms): 210 | super(CocoDetection, self).__init__(img_folder, ann_file) 211 | self._transforms = transforms 212 | 213 | def __getitem__(self, idx): 214 | img, target = super(CocoDetection, self).__getitem__(idx) 215 | image_id = self.ids[idx] 216 | target = dict(image_id=image_id, annotations=target) 217 | if self._transforms is not None: 218 | img, target = self._transforms(img, target) 219 | return img, target 220 | 221 | 222 | def get_coco(root, image_set, transforms, mode='instances'): 223 | anno_file_template = "{}_{}2017.json" 224 | PATHS = { 225 | "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), 226 | "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))), 227 | # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) 228 | } 229 | 230 | t = [ConvertCocoPolysToMask()] 231 | 232 | if transforms is not None: 233 | t.append(transforms) 234 | transforms = T.Compose(t) 235 | 236 | img_folder, ann_file = PATHS[image_set] 237 | img_folder = os.path.join(root, img_folder) 238 | ann_file = os.path.join(root, ann_file) 239 | 240 | dataset = CocoDetection(img_folder, ann_file, transforms=transforms) 241 | 242 | if image_set == "train": 243 | dataset = _coco_remove_images_without_annotations(dataset) 244 | 245 | # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) 246 | 247 | return dataset 248 | 249 | 250 | def get_coco_kp(root, image_set, transforms): 251 | return get_coco(root, image_set, transforms, mode="person_keypoints") 252 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import time 4 | import torch 5 | 6 | import torchvision.models.detection.mask_rcnn 7 | 8 | from coco_utils import get_coco_api_from_dataset 9 | from coco_eval import 
CocoEvaluator 10 | import utils 11 | 12 | 13 | def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): 14 | model.train() 15 | metric_logger = utils.MetricLogger(delimiter=" ") 16 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 17 | header = 'Epoch: [{}]'.format(epoch) 18 | 19 | lr_scheduler = None 20 | if epoch == 0: 21 | warmup_factor = 1. / 1000 22 | warmup_iters = min(1000, len(data_loader) - 1) 23 | 24 | lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 25 | 26 | for images, targets in metric_logger.log_every(data_loader, print_freq, header): 27 | images = list(image.to(device) for image in images) 28 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 29 | 30 | loss_dict = model(images, targets) 31 | 32 | losses = sum(loss for loss in loss_dict.values()) 33 | 34 | # reduce losses over all GPUs for logging purposes 35 | loss_dict_reduced = utils.reduce_dict(loss_dict) 36 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 37 | 38 | loss_value = losses_reduced.item() 39 | 40 | if not math.isfinite(loss_value): 41 | print("Loss is {}, stopping training".format(loss_value)) 42 | print(loss_dict_reduced) 43 | sys.exit(1) 44 | 45 | optimizer.zero_grad() 46 | losses.backward() 47 | optimizer.step() 48 | 49 | if lr_scheduler is not None: 50 | lr_scheduler.step() 51 | 52 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 53 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 54 | 55 | 56 | def _get_iou_types(model): 57 | model_without_ddp = model 58 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 59 | model_without_ddp = model.module 60 | iou_types = ["bbox"] 61 | if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): 62 | iou_types.append("segm") 63 | if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): 64 | iou_types.append("keypoints") 65 | return iou_types 66 | 67 | 68 | @torch.no_grad() 69 | def evaluate(model, data_loader, device): 70 | n_threads = torch.get_num_threads() 71 | # FIXME remove this and make paste_masks_in_image run on the GPU 72 | torch.set_num_threads(1) 73 | cpu_device = torch.device("cpu") 74 | model.eval() 75 | metric_logger = utils.MetricLogger(delimiter=" ") 76 | header = 'Test:' 77 | 78 | coco = get_coco_api_from_dataset(data_loader.dataset) 79 | iou_types = _get_iou_types(model) 80 | coco_evaluator = CocoEvaluator(coco, iou_types) 81 | 82 | for image, targets in metric_logger.log_every(data_loader, 100, header): 83 | image = list(img.to(device) for img in image) 84 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 85 | 86 | torch.cuda.synchronize() 87 | model_time = time.time() 88 | outputs = model(image) 89 | 90 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] 91 | model_time = time.time() - model_time 92 | 93 | res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} 94 | evaluator_time = time.time() 95 | coco_evaluator.update(res) 96 | evaluator_time = time.time() - evaluator_time 97 | metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) 98 | 99 | # gather the stats from all processes 100 | metric_logger.synchronize_between_processes() 101 | print("Averaged stats:", metric_logger) 102 | coco_evaluator.synchronize_between_processes() 103 | 104 | # accumulate predictions from all images 105 | coco_evaluator.accumulate() 106 | coco_evaluator.summarize() 107 | 
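# summarize() prints the standard COCO AP/AR tables for each requested IoU type (bbox, plus keypoints for Keypoint R-CNN); the raw metric values stay available afterwards in coco_evaluator.coco_eval[iou_type].stats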
torch.set_num_threads(n_threads) 108 | return coco_evaluator 109 | -------------------------------------------------------------------------------- /group_by_aspect_ratio.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from collections import defaultdict 3 | import copy 4 | import numpy as np 5 | 6 | import torch 7 | import torch.utils.data 8 | from torch.utils.data.sampler import BatchSampler, Sampler 9 | from torch.utils.model_zoo import tqdm 10 | import torchvision 11 | 12 | from PIL import Image 13 | 14 | 15 | class GroupedBatchSampler(BatchSampler): 16 | """ 17 | Wraps another sampler to yield a mini-batch of indices. 18 | It enforces that the batch only contain elements from the same group. 19 | It also tries to provide mini-batches which follows an ordering which is 20 | as close as possible to the ordering from the original sampler. 21 | Arguments: 22 | sampler (Sampler): Base sampler. 23 | group_ids (list[int]): If the sampler produces indices in range [0, N), 24 | `group_ids` must be a list of `N` ints which contains the group id of each sample. 25 | The group ids must be a continuous set of integers starting from 26 | 0, i.e. they must be in the range [0, num_groups). 27 | batch_size (int): Size of mini-batch. 28 | """ 29 | def __init__(self, sampler, group_ids, batch_size): 30 | if not isinstance(sampler, Sampler): 31 | raise ValueError( 32 | "sampler should be an instance of " 33 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 34 | ) 35 | self.sampler = sampler 36 | self.group_ids = group_ids 37 | self.batch_size = batch_size 38 | 39 | def __iter__(self): 40 | buffer_per_group = defaultdict(list) 41 | samples_per_group = defaultdict(list) 42 | 43 | num_batches = 0 44 | for idx in self.sampler: 45 | group_id = self.group_ids[idx] 46 | buffer_per_group[group_id].append(idx) 47 | samples_per_group[group_id].append(idx) 48 | if len(buffer_per_group[group_id]) == self.batch_size: 49 | yield buffer_per_group[group_id] 50 | num_batches += 1 51 | del buffer_per_group[group_id] 52 | assert len(buffer_per_group[group_id]) < self.batch_size 53 | 54 | # now we have run out of elements that satisfy 55 | # the group criteria, let's return the remaining 56 | # elements so that the size of the sampler is 57 | # deterministic 58 | expected_num_batches = len(self) 59 | num_remaining = expected_num_batches - num_batches 60 | if num_remaining > 0: 61 | # for the remaining batches, take first the buffers with largest number 62 | # of elements 63 | for group_id, _ in sorted(buffer_per_group.items(), 64 | key=lambda x: len(x[1]), reverse=True): 65 | remaining = self.batch_size - len(buffer_per_group[group_id]) 66 | buffer_per_group[group_id].extend( 67 | samples_per_group[group_id][:remaining]) 68 | assert len(buffer_per_group[group_id]) == self.batch_size 69 | yield buffer_per_group[group_id] 70 | num_remaining -= 1 71 | if num_remaining == 0: 72 | break 73 | assert num_remaining == 0 74 | 75 | def __len__(self): 76 | return len(self.sampler) // self.batch_size 77 | 78 | 79 | def _compute_aspect_ratios_slow(dataset, indices=None): 80 | print("Your dataset doesn't support the fast path for " 81 | "computing the aspect ratios, so will iterate over " 82 | "the full dataset and load every image instead. 
" 83 | "This might take some time...") 84 | if indices is None: 85 | indices = range(len(dataset)) 86 | 87 | class SubsetSampler(Sampler): 88 | def __init__(self, indices): 89 | self.indices = indices 90 | 91 | def __iter__(self): 92 | return iter(self.indices) 93 | 94 | def __len__(self): 95 | return len(self.indices) 96 | 97 | sampler = SubsetSampler(indices) 98 | data_loader = torch.utils.data.DataLoader( 99 | dataset, batch_size=1, sampler=sampler, 100 | num_workers=14, # you might want to increase it for faster processing 101 | collate_fn=lambda x: x[0]) 102 | aspect_ratios = [] 103 | with tqdm(total=len(dataset)) as pbar: 104 | for i, (img, _) in enumerate(data_loader): 105 | pbar.update(1) 106 | height, width = img.shape[-2:] 107 | aspect_ratio = float(height) / float(width) 108 | aspect_ratios.append(aspect_ratio) 109 | return aspect_ratios 110 | 111 | 112 | def _compute_aspect_ratios_custom_dataset(dataset, indices=None): 113 | if indices is None: 114 | indices = range(len(dataset)) 115 | aspect_ratios = [] 116 | for i in indices: 117 | height, width = dataset.get_height_and_width(i) 118 | aspect_ratio = float(height) / float(width) 119 | aspect_ratios.append(aspect_ratio) 120 | return aspect_ratios 121 | 122 | 123 | def _compute_aspect_ratios_coco_dataset(dataset, indices=None): 124 | if indices is None: 125 | indices = range(len(dataset)) 126 | aspect_ratios = [] 127 | for i in indices: 128 | img_info = dataset.coco.imgs[dataset.ids[i]] 129 | aspect_ratio = float(img_info["height"]) / float(img_info["width"]) 130 | aspect_ratios.append(aspect_ratio) 131 | return aspect_ratios 132 | 133 | 134 | def _compute_aspect_ratios_voc_dataset(dataset, indices=None): 135 | if indices is None: 136 | indices = range(len(dataset)) 137 | aspect_ratios = [] 138 | for i in indices: 139 | # this doesn't load the data into memory, because PIL loads it lazily 140 | width, height = Image.open(dataset.images[i]).size 141 | aspect_ratio = float(height) / float(width) 142 | aspect_ratios.append(aspect_ratio) 143 | return aspect_ratios 144 | 145 | 146 | def _compute_aspect_ratios_subset_dataset(dataset, indices=None): 147 | if indices is None: 148 | indices = range(len(dataset)) 149 | 150 | ds_indices = [dataset.indices[i] for i in indices] 151 | return compute_aspect_ratios(dataset.dataset, ds_indices) 152 | 153 | 154 | def compute_aspect_ratios(dataset, indices=None): 155 | if hasattr(dataset, "get_height_and_width"): 156 | return _compute_aspect_ratios_custom_dataset(dataset, indices) 157 | 158 | if isinstance(dataset, torchvision.datasets.CocoDetection): 159 | return _compute_aspect_ratios_coco_dataset(dataset, indices) 160 | 161 | if isinstance(dataset, torchvision.datasets.VOCDetection): 162 | return _compute_aspect_ratios_voc_dataset(dataset, indices) 163 | 164 | if isinstance(dataset, torch.utils.data.Subset): 165 | return _compute_aspect_ratios_subset_dataset(dataset, indices) 166 | 167 | # slow path 168 | return _compute_aspect_ratios_slow(dataset, indices) 169 | 170 | 171 | def _quantize(x, bins): 172 | bins = copy.deepcopy(bins) 173 | bins = sorted(bins) 174 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 175 | return quantized 176 | 177 | 178 | def create_aspect_ratio_groups(dataset, k=0): 179 | aspect_ratios = compute_aspect_ratios(dataset) 180 | bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0] 181 | groups = _quantize(aspect_ratios, bins) 182 | # count number of elements per group 183 | counts = np.unique(groups, return_counts=True)[1] 184 | fbins 
= [0] + bins + [np.inf] 185 | print("Using {} as bins for aspect ratio quantization".format(fbins)) 186 | print("Count of instances per bin: {}".format(counts)) 187 | return groups 188 | -------------------------------------------------------------------------------- /keypoint_rcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torchvision.ops import misc as misc_nn_ops 5 | from torchvision.ops import MultiScaleRoIAlign 6 | 7 | from ..utils import load_state_dict_from_url 8 | 9 | from .faster_rcnn import FasterRCNN 10 | from .backbone_utils import resnet_fpn_backbone 11 | 12 | 13 | __all__ = [ 14 | "KeypointRCNN", "keypointrcnn_resnet50_fpn" 15 | ] 16 | 17 | 18 | class KeypointRCNN(FasterRCNN): 19 | """ 20 | Implements Keypoint R-CNN. 21 | 22 | The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each 23 | image, and should be in 0-1 range. Different images can have different sizes. 24 | 25 | The behavior of the model changes depending if it is in training or evaluation mode. 26 | 27 | During training, the model expects both the input tensors, as well as a targets dictionary, 28 | containing: 29 | - boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values 30 | between 0 and H and 0 and W 31 | - labels (Tensor[N]): the class label for each ground-truth box 32 | - keypoints (Tensor[N, K, 3]): the K keypoints location for each of the N instances, in the 33 | format [x, y, visibility], where visibility=0 means that the keypoint is not visible. 34 | 35 | The model returns a Dict[Tensor] during training, containing the classification and regression 36 | losses for both the RPN and the R-CNN, and the keypoint loss. 37 | 38 | During inference, the model requires only the input tensors, and returns the post-processed 39 | predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as 40 | follows: 41 | - boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between 42 | 0 and H and 0 and W 43 | - labels (Tensor[N]): the predicted labels for each image 44 | - scores (Tensor[N]): the scores or each prediction 45 | - keypoints (Tensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format. 46 | 47 | Arguments: 48 | backbone (nn.Module): the network used to compute the features for the model. 49 | It should contain a out_channels attribute, which indicates the number of output 50 | channels that each feature map has (and it should be the same for all feature maps). 51 | The backbone should return a single Tensor or and OrderedDict[Tensor]. 52 | num_classes (int): number of output classes of the model (including the background). 53 | If box_predictor is specified, num_classes should be None. 54 | min_size (int): minimum size of the image to be rescaled before feeding it to the backbone 55 | max_size (int): maximum size of the image to be rescaled before feeding it to the backbone 56 | image_mean (Tuple[float, float, float]): mean values used for input normalization. 57 | They are generally the mean values of the dataset on which the backbone has been trained 58 | on 59 | image_std (Tuple[float, float, float]): std values used for input normalization. 60 | They are generally the std values of the dataset on which the backbone has been trained on 61 | rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature 62 | maps. 
63 | rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN 64 | rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training 65 | rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing 66 | rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training 67 | rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing 68 | rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals 69 | rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be 70 | considered as positive during training of the RPN. 71 | rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be 72 | considered as negative during training of the RPN. 73 | rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN 74 | for computing the loss 75 | rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training 76 | of the RPN 77 | box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in 78 | the locations indicated by the bounding boxes 79 | box_head (nn.Module): module that takes the cropped feature maps as input 80 | box_predictor (nn.Module): module that takes the output of box_head and returns the 81 | classification logits and box regression deltas. 82 | box_score_thresh (float): during inference, only return proposals with a classification score 83 | greater than box_score_thresh 84 | box_nms_thresh (float): NMS threshold for the prediction head. Used during inference 85 | box_detections_per_img (int): maximum number of detections per image, for all classes. 86 | box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be 87 | considered as positive during training of the classification head 88 | box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be 89 | considered as negative during training of the classification head 90 | box_batch_size_per_image (int): number of proposals that are sampled during training of the 91 | classification head 92 | box_positive_fraction (float): proportion of positive proposals in a mini-batch during training 93 | of the classification head 94 | bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the 95 | bounding boxes 96 | keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in 97 | the locations indicated by the bounding boxes, which will be used for the keypoint head. 98 | keypoint_head (nn.Module): module that takes the cropped feature maps as input 99 | keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the 100 | heatmap logits 101 | 102 | Example:: 103 | 104 | >>> import torchvision 105 | >>> from torchvision.models.detection import KeypointRCNN 106 | >>> from torchvision.models.detection.rpn import AnchorGenerator 107 | >>> 108 | >>> # load a pre-trained model for classification and return 109 | >>> # only the features 110 | >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features 111 | >>> # KeypointRCNN needs to know the number of 112 | >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 113 | >>> # so we need to add it here 114 | >>> backbone.out_channels = 1280 115 | >>> 116 | >>> # let's make the RPN generate 5 x 3 anchors per spatial 117 | >>> # location, with 5 different sizes and 3 different aspect 118 | >>> # ratios. We have a Tuple[Tuple[int]] because each feature 119 | >>> # map could potentially have different sizes and 120 | >>> # aspect ratios 121 | >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), 122 | >>> aspect_ratios=((0.5, 1.0, 2.0),)) 123 | >>> 124 | >>> # let's define what are the feature maps that we will 125 | >>> # use to perform the region of interest cropping, as well as 126 | >>> # the size of the crop after rescaling. 127 | >>> # if your backbone returns a Tensor, featmap_names is expected to 128 | >>> # be [0]. More generally, the backbone should return an 129 | >>> # OrderedDict[Tensor], and in featmap_names you can choose which 130 | >>> # feature maps to use. 131 | >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0], 132 | >>> output_size=7, 133 | >>> sampling_ratio=2) 134 | >>> 135 | >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0], 136 | >>> output_size=14, 137 | >>> sampling_ratio=2) 138 | >>> # put the pieces together inside a FasterRCNN model 139 | >>> model = KeypointRCNN(backbone, 140 | >>> num_classes=2, 141 | >>> rpn_anchor_generator=anchor_generator, 142 | >>> box_roi_pool=roi_pooler, 143 | >>> keypoint_roi_pool=keypoint_roi_pooler) 144 | >>> model.eval() 145 | >>> model.eval() 146 | >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] 147 | >>> predictions = model(x) 148 | """ 149 | def __init__(self, backbone, num_classes=None, 150 | # transform parameters 151 | min_size=None, max_size=1333, 152 | image_mean=None, image_std=None, 153 | # RPN parameters 154 | rpn_anchor_generator=None, rpn_head=None, 155 | rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, 156 | rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, 157 | rpn_nms_thresh=0.7, 158 | rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, 159 | rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, 160 | # Box parameters 161 | box_roi_pool=None, box_head=None, box_predictor=None, 162 | box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, 163 | box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, 164 | box_batch_size_per_image=512, box_positive_fraction=0.25, 165 | bbox_reg_weights=None, 166 | # keypoint parameters 167 | keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None, 168 | num_keypoints=17): 169 | 170 | assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) 171 | if min_size is None: 172 | min_size = (640, 672, 704, 736, 768, 800) 173 | 174 | if num_classes is not None: 175 | if keypoint_predictor is not None: 176 | raise ValueError("num_classes should be None when keypoint_predictor is specified") 177 | 178 | out_channels = backbone.out_channels 179 | 180 | if keypoint_roi_pool is None: 181 | keypoint_roi_pool = MultiScaleRoIAlign( 182 | featmap_names=[0, 1, 2, 3], 183 | output_size=14, 184 | sampling_ratio=2) 185 | 186 | if keypoint_head is None: 187 | keypoint_layers = tuple(512 for _ in range(8)) 188 | keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers) 189 | 190 | if keypoint_predictor is None: 191 | keypoint_dim_reduced = 512 # == keypoint_layers[-1] 192 | keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints) 193 | 194 | super(KeypointRCNN, self).__init__( 195 | 
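# every argument below is consumed by FasterRCNN.__init__; the keypoint-specific modules are attached to self.roi_heads right after this call returns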
backbone, num_classes, 196 | # transform parameters 197 | min_size, max_size, 198 | image_mean, image_std, 199 | # RPN-specific parameters 200 | rpn_anchor_generator, rpn_head, 201 | rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test, 202 | rpn_post_nms_top_n_train, rpn_post_nms_top_n_test, 203 | rpn_nms_thresh, 204 | rpn_fg_iou_thresh, rpn_bg_iou_thresh, 205 | rpn_batch_size_per_image, rpn_positive_fraction, 206 | # Box parameters 207 | box_roi_pool, box_head, box_predictor, 208 | box_score_thresh, box_nms_thresh, box_detections_per_img, 209 | box_fg_iou_thresh, box_bg_iou_thresh, 210 | box_batch_size_per_image, box_positive_fraction, 211 | bbox_reg_weights) 212 | 213 | self.roi_heads.keypoint_roi_pool = keypoint_roi_pool 214 | self.roi_heads.keypoint_head = keypoint_head 215 | self.roi_heads.keypoint_predictor = keypoint_predictor 216 | 217 | 218 | class KeypointRCNNHeads(nn.Sequential): 219 | def __init__(self, in_channels, layers): 220 | d = [] 221 | next_feature = in_channels 222 | for l in layers: 223 | d.append(misc_nn_ops.Conv2d(next_feature, l, 3, stride=1, padding=1)) 224 | d.append(nn.ReLU(inplace=True)) 225 | next_feature = l 226 | super(KeypointRCNNHeads, self).__init__(*d) 227 | for m in self.children(): 228 | if isinstance(m, misc_nn_ops.Conv2d): 229 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 230 | nn.init.constant_(m.bias, 0) 231 | 232 | 233 | class KeypointRCNNPredictor(nn.Module): 234 | def __init__(self, in_channels, num_keypoints): 235 | super(KeypointRCNNPredictor, self).__init__() 236 | input_features = in_channels 237 | deconv_kernel = 4 238 | self.kps_score_lowres = misc_nn_ops.ConvTranspose2d( 239 | input_features, 240 | num_keypoints, 241 | deconv_kernel, 242 | stride=2, 243 | padding=deconv_kernel // 2 - 1, 244 | ) 245 | nn.init.kaiming_normal_( 246 | self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu" 247 | ) 248 | nn.init.constant_(self.kps_score_lowres.bias, 0) 249 | self.up_scale = 2 250 | self.out_channels = num_keypoints 251 | 252 | def forward(self, x): 253 | x = self.kps_score_lowres(x) 254 | x = misc_nn_ops.interpolate( 255 | x, scale_factor=self.up_scale, mode="bilinear", align_corners=False 256 | ) 257 | return x 258 | 259 | 260 | model_urls = { 261 | 'keypointrcnn_resnet50_fpn_coco': 262 | 'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth', 263 | } 264 | 265 | 266 | def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, 267 | num_classes=2, num_keypoints=17, 268 | pretrained_backbone=True, **kwargs): 269 | """ 270 | Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. 271 | 272 | The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each 273 | image, and should be in ``0-1`` range. Different images can have different sizes. 274 | 275 | The behavior of the model changes depending if it is in training or evaluation mode. 276 | 277 | During training, the model expects both the input tensors, as well as a targets dictionary, 278 | containing: 279 | - boxes (``Tensor[N, 4]``): the ground-truth boxes in ``[x0, y0, x1, y1]`` format, with values 280 | between ``0`` and ``H`` and ``0`` and ``W`` 281 | - labels (``Tensor[N]``): the class label for each ground-truth box 282 | - keypoints (``Tensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the 283 | format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible. 
284 | 285 | The model returns a ``Dict[Tensor]`` during training, containing the classification and regression 286 | losses for both the RPN and the R-CNN, and the keypoint loss. 287 | 288 | During inference, the model requires only the input tensors, and returns the post-processed 289 | predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as 290 | follows: 291 | - boxes (``Tensor[N, 4]``): the predicted boxes in ``[x0, y0, x1, y1]`` format, with values between 292 | ``0`` and ``H`` and ``0`` and ``W`` 293 | - labels (``Tensor[N]``): the predicted labels for each image 294 | - scores (``Tensor[N]``): the scores or each prediction 295 | - keypoints (``Tensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format. 296 | 297 | Example:: 298 | 299 | >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True) 300 | >>> model.eval() 301 | >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] 302 | >>> predictions = model(x) 303 | 304 | Arguments: 305 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 306 | progress (bool): If True, displays a progress bar of the download to stderr 307 | """ 308 | if pretrained: 309 | # no need to download the backbone if pretrained is set 310 | pretrained_backbone = False 311 | backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) 312 | model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) 313 | if pretrained: 314 | state_dict = load_state_dict_from_url(model_urls['keypointrcnn_resnet50_fpn_coco'], 315 | progress=progress) 316 | model.load_state_dict(state_dict) 317 | return model 318 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib 3 | import cv2 as cv 4 | import numpy as np 5 | import math 6 | plt.switch_backend('agg') 7 | 8 | 9 | def map_coco_to_personlab(keypoints): 10 | permute = [0, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3] 11 | return keypoints[:, permute, :] 12 | 13 | def plot_poses(img, skeletons, save_name='pose.jpg'): 14 | EDGES = [ 15 | (0, 14), 16 | (0, 13), 17 | (0, 4), 18 | (0, 1), 19 | (14, 16), 20 | (13, 15), 21 | (4, 10), 22 | (1, 7), 23 | (10, 11), 24 | (7, 8), 25 | (11, 12), 26 | (8, 9), 27 | (4, 5), 28 | (1, 2), 29 | (5, 6), 30 | (2, 3) 31 | ] 32 | NUM_EDGES = len(EDGES) 33 | 34 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 35 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 36 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 37 | cmap = matplotlib.cm.get_cmap('hsv') 38 | plt.figure() 39 | 40 | #img = img.astype('uint8') 41 | canvas = img.copy() 42 | 43 | for i in range(17): 44 | rgba = np.array(cmap(1 - i/17. 
- 1./34)) 45 | rgba[0:3] *= 255 46 | for j in range(len(skeletons)): 47 | cv.circle(canvas, tuple(skeletons[j][i, 0:2].astype('int32')), 2, colors[i], thickness=-1) 48 | 49 | to_plot = cv.addWeighted(img, 0.3, canvas, 0.7, 0) 50 | fig = matplotlib.pyplot.gcf() 51 | 52 | stickwidth = 2 53 | 54 | skeletons = map_coco_to_personlab(skeletons) 55 | for i in range(NUM_EDGES): 56 | for j in range(len(skeletons)): 57 | edge = EDGES[i] 58 | if skeletons[j][edge[0],2] == 0 or skeletons[j][edge[1],2] == 0: 59 | continue 60 | 61 | cur_canvas = canvas.copy() 62 | X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]] 63 | Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]] 64 | mX = np.mean(X) 65 | mY = np.mean(Y) 66 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 67 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 68 | polygon = cv.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1) 69 | cv.fillConvexPoly(cur_canvas, polygon, colors[i]) 70 | canvas = cv.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) 71 | 72 | plt.imsave(save_name,canvas[:,:,:]) 73 | plt.close() -------------------------------------------------------------------------------- /predict_visualize.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.utils.data 7 | from torch import nn 8 | import torchvision 9 | import torchvision.models.detection 10 | import torchvision.models.detection.mask_rcnn 11 | 12 | from torchvision import transforms 13 | 14 | from coco_utils import get_coco, get_coco_kp 15 | 16 | from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups 17 | from engine import train_one_epoch, evaluate 18 | 19 | import utils 20 | import transforms as T 21 | 22 | from PIL import Image 23 | from plot import plot_poses 24 | import numpy as np 25 | def get_dataset(name, image_set, transform): 26 | paths = { 27 | "coco": ('/home/hzj/data/COCO2017/', get_coco, 91), 28 | "coco_kp": ('/home/hzj/data/COCO2017/', get_coco_kp, 2) 29 | } 30 | p, ds_fn, num_classes = paths[name] 31 | 32 | ds = ds_fn(p, image_set=image_set, transforms=transform) 33 | return ds, num_classes 34 | 35 | 36 | def get_transform(train): 37 | transforms = [] 38 | transforms.append(T.ToTensor()) 39 | if train: 40 | transforms.append(T.RandomHorizontalFlip(0.5)) 41 | return T.Compose(transforms) 42 | 43 | def main(): 44 | 45 | device = torch.device("cuda:0") 46 | 47 | # Data loading code 48 | print("Loading data") 49 | 50 | #dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True)) 51 | dataset_test, num_classes = get_dataset("coco_kp", "val", get_transform(train=False)) 52 | 53 | print("Creating data loaders") 54 | 55 | #train_sampler = torch.utils.data.RandomSampler(dataset) 56 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 57 | 58 | 59 | #train_batch_sampler = torch.utils.data.BatchSampler( 60 | # train_sampler, args.batch_size, drop_last=True) 61 | 62 | #data_loader = torch.utils.data.DataLoader( 63 | # dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, 64 | # collate_fn=utils.collate_fn) 65 | 66 | data_loader_test = torch.utils.data.DataLoader( 67 | dataset_test, batch_size=1, 68 | sampler=test_sampler, num_workers=4, 69 | collate_fn=utils.collate_fn) 70 | 71 | print("Creating model") 72 | model = torchvision.models.detection.__dict__['keypointrcnn_resnet50_fpn'](num_classes=num_classes, 73 | 
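# pretrained=True downloads the COCO-trained keypointrcnn_resnet50_fpn weights, so the commented-out checkpoint loading below is not needed just to visualize predictions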
pretrained=True) 74 | model.to(device) 75 | 76 | #checkpoint = torch.load(args.resume, map_location='cpu') 77 | #model_without_ddp.load_state_dict(checkpoint['model']) 78 | #optimizer.load_state_dict(checkpoint['optimizer']) 79 | #lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 80 | 81 | model.eval() 82 | 83 | detect_threshold = 0.7 84 | keypoint_score_threshold = 2 85 | with torch.no_grad(): 86 | for i in range(20): 87 | img,_ = dataset_test[i] 88 | prediction = model([img.to(device)]) 89 | keypoints = prediction[0]['keypoints'].cpu().numpy() 90 | scores = prediction[0]['scores'].cpu().numpy() 91 | keypoints_scores = prediction[0]['keypoints_scores'].cpu().numpy() 92 | idx = np.where(scores>detect_threshold) 93 | keypoints = keypoints[idx] 94 | keypoints_scores = keypoints_scores[idx] 95 | for j in range(keypoints.shape[0]): 96 | for num in range(17): 97 | if keypoints_scores[j][num] < keypoint_score_threshold: 98 | keypoints[j][num][2] = 0 # assumed handling: zero the visibility entry so plot_poses skips low-confidence keypoints -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 62 | if args.aspect_ratio_group_factor >= 0: 63 | group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) 64 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) 65 | else: 66 | train_batch_sampler = torch.utils.data.BatchSampler( 67 | train_sampler, args.batch_size, drop_last=True) 68 | 69 | data_loader = torch.utils.data.DataLoader( 70 | dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, 71 | collate_fn=utils.collate_fn) 72 | 73 | data_loader_test = torch.utils.data.DataLoader( 74 | dataset_test, batch_size=1, 75 | sampler=test_sampler, num_workers=args.workers, 76 | collate_fn=utils.collate_fn) 77 | 78 | print("Creating model") 79 | model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, 80 | pretrained=args.pretrained) 81 | model.to(device) 82 | 83 | model_without_ddp = model 84 | if args.distributed: 85 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 86 | model_without_ddp = model.module 87 | 88 | params = [p for p in model.parameters() if p.requires_grad] 89 | optimizer = torch.optim.SGD( 90 | params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 91 | 92 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) 93 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) 94 | 95 | if args.resume: 96 | checkpoint = torch.load(args.resume, map_location='cpu') 97 | model_without_ddp.load_state_dict(checkpoint['model']) 98 | optimizer.load_state_dict(checkpoint['optimizer']) 99 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 100 | 101 | if args.test_only: 102 | evaluate(model, data_loader_test, device=device) 103 | return 104 | 105 | print("Start training") 106 | start_time = time.time() 107 | for epoch in range(args.epochs): 108 | if args.distributed: 109 | train_sampler.set_epoch(epoch) 110 | train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq) 111 | lr_scheduler.step() 112 | if args.output_dir: 113 | utils.save_on_master({ 114 | 'model': model_without_ddp.state_dict(), 115 | 'optimizer': optimizer.state_dict(), 116 | 'lr_scheduler': lr_scheduler.state_dict(), 117 | 'args': args}, 118 | os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) 119 | 120 | # evaluate after every epoch 121 | evaluate(model, data_loader_test, device=device) 122 | 123 | total_time = time.time() - start_time 124 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 125 | print('Training time {}'.format(total_time_str)) 126 | 127 | 128 | if
__name__ == "__main__": 129 | import argparse 130 | parser = argparse.ArgumentParser(description='PyTorch Detection Training') 131 | 132 | parser.add_argument('--data-path', default='/home/hzj/data/COCO2017/', help='dataset') 133 | parser.add_argument('--dataset', default='coco_kp', help='dataset') 134 | parser.add_argument('--model', default='keypointrcnn_resnet50_fpn', help='model') 135 | parser.add_argument('--device', default='cuda:0', help='device') 136 | parser.add_argument('-b', '--batch-size', default=2, type=int) 137 | parser.add_argument('--epochs', default=13, type=int, metavar='N', 138 | help='number of total epochs to run') 139 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 140 | help='number of data loading workers (default: 16)') 141 | parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate') 142 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 143 | help='momentum') 144 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 145 | metavar='W', help='weight decay (default: 1e-4)', 146 | dest='weight_decay') 147 | parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs') 148 | parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, help='decrease lr every step-size epochs') 149 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') 150 | parser.add_argument('--print-freq', default=20, type=int, help='print frequency') 151 | parser.add_argument('--output-dir', default='.', help='path where to save') 152 | parser.add_argument('--resume', default='', help='resume from checkpoint') 153 | parser.add_argument('--aspect-ratio-group-factor', default=0, type=int) 154 | parser.add_argument( 155 | "--test-only", 156 | dest="test_only", 157 | help="Only test the model", 158 | action="store_true", 159 | ) 160 | 161 | parser.add_argument( 162 | "--pretrained", 163 | dest="pretrained", 164 | help="Use pre-trained models from the modelzoo", 165 | action="store_true", 166 | ) 167 | 168 | # distributed training parameters 169 | parser.add_argument('--world-size', default=1, type=int, 170 | help='number of distributed processes') 171 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 172 | 173 | args = parser.parse_args() 174 | 175 | if args.output_dir: 176 | utils.mkdir(args.output_dir) 177 | 178 | main(args) 179 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from torchvision.transforms import functional as F 5 | 6 | 7 | def _flip_coco_person_keypoints(kps, width): 8 | flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] 9 | flipped_data = kps[:, flip_inds] 10 | flipped_data[..., 0] = width - flipped_data[..., 0] 11 | # Maintain COCO convention that if visibility == 0, then x, y = 0 12 | inds = flipped_data[..., 2] == 0 13 | flipped_data[inds] = 0 14 | return flipped_data 15 | 16 | 17 | class Compose(object): 18 | def __init__(self, transforms): 19 | self.transforms = transforms 20 | 21 | def __call__(self, image, target): 22 | for t in self.transforms: 23 | image, target = t(image, target) 24 | return image, target 25 | 26 | 27 | class RandomHorizontalFlip(object): 28 | def __init__(self, prob): 29 | self.prob = prob 30 | 31 | def __call__(self, 
image, target): 32 | if random.random() < self.prob: 33 | height, width = image.shape[-2:] 34 | image = image.flip(-1) 35 | bbox = target["boxes"] 36 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]] 37 | target["boxes"] = bbox 38 | if "masks" in target: 39 | target["masks"] = target["masks"].flip(-1) 40 | if "keypoints" in target: 41 | keypoints = target["keypoints"] 42 | keypoints = _flip_coco_person_keypoints(keypoints, width) 43 | target["keypoints"] = keypoints 44 | return image, target 45 | 46 | 47 | class ToTensor(object): 48 | def __call__(self, image, target): 49 | image = F.to_tensor(image) 50 | return image, target 51 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from collections import defaultdict, deque 4 | import datetime 5 | import pickle 6 | import time 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | import errno 12 | import os 13 | 14 | 15 | class SmoothedValue(object): 16 | """Track a series of values and provide access to smoothed values over a 17 | window or the global series average. 18 | """ 19 | 20 | def __init__(self, window_size=20, fmt=None): 21 | if fmt is None: 22 | fmt = "{median:.4f} ({global_avg:.4f})" 23 | self.deque = deque(maxlen=window_size) 24 | self.total = 0.0 25 | self.count = 0 26 | self.fmt = fmt 27 | 28 | def update(self, value, n=1): 29 | self.deque.append(value) 30 | self.count += n 31 | self.total += value * n 32 | 33 | def synchronize_between_processes(self): 34 | """ 35 | Warning: does not synchronize the deque! 36 | """ 37 | if not is_dist_avail_and_initialized(): 38 | return 39 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 40 | dist.barrier() 41 | dist.all_reduce(t) 42 | t = t.tolist() 43 | self.count = int(t[0]) 44 | self.total = t[1] 45 | 46 | @property 47 | def median(self): 48 | d = torch.tensor(list(self.deque)) 49 | return d.median().item() 50 | 51 | @property 52 | def avg(self): 53 | d = torch.tensor(list(self.deque), dtype=torch.float32) 54 | return d.mean().item() 55 | 56 | @property 57 | def global_avg(self): 58 | return self.total / self.count 59 | 60 | @property 61 | def max(self): 62 | return max(self.deque) 63 | 64 | @property 65 | def value(self): 66 | return self.deque[-1] 67 | 68 | def __str__(self): 69 | return self.fmt.format( 70 | median=self.median, 71 | avg=self.avg, 72 | global_avg=self.global_avg, 73 | max=self.max, 74 | value=self.value) 75 | 76 | 77 | def all_gather(data): 78 | """ 79 | Run all_gather on arbitrary picklable data (not necessarily tensors) 80 | Args: 81 | data: any picklable object 82 | Returns: 83 | list[data]: list of data gathered from each rank 84 | """ 85 | world_size = get_world_size() 86 | if world_size == 1: 87 | return [data] 88 | 89 | # serialized to a Tensor 90 | buffer = pickle.dumps(data) 91 | storage = torch.ByteStorage.from_buffer(buffer) 92 | tensor = torch.ByteTensor(storage).to("cuda") 93 | 94 | # obtain Tensor size of each rank 95 | local_size = torch.tensor([tensor.numel()], device="cuda") 96 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 97 | dist.all_gather(size_list, local_size) 98 | size_list = [int(size.item()) for size in size_list] 99 | max_size = max(size_list) 100 | 101 | # receiving Tensor from all ranks 102 | # we pad the tensor because torch all_gather does not support 103 | # gathering tensors of different shapes 104 
| tensor_list = [] 105 | for _ in size_list: 106 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 107 | if local_size != max_size: 108 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 109 | tensor = torch.cat((tensor, padding), dim=0) 110 | dist.all_gather(tensor_list, tensor) 111 | 112 | data_list = [] 113 | for size, tensor in zip(size_list, tensor_list): 114 | buffer = tensor.cpu().numpy().tobytes()[:size] 115 | data_list.append(pickle.loads(buffer)) 116 | 117 | return data_list 118 | 119 | 120 | def reduce_dict(input_dict, average=True): 121 | """ 122 | Args: 123 | input_dict (dict): all the values will be reduced 124 | average (bool): whether to do average or sum 125 | Reduce the values in the dictionary from all processes so that all processes 126 | have the averaged results. Returns a dict with the same fields as 127 | input_dict, after reduction. 128 | """ 129 | world_size = get_world_size() 130 | if world_size < 2: 131 | return input_dict 132 | with torch.no_grad(): 133 | names = [] 134 | values = [] 135 | # sort the keys so that they are consistent across processes 136 | for k in sorted(input_dict.keys()): 137 | names.append(k) 138 | values.append(input_dict[k]) 139 | values = torch.stack(values, dim=0) 140 | dist.all_reduce(values) 141 | if average: 142 | values /= world_size 143 | reduced_dict = {k: v for k, v in zip(names, values)} 144 | return reduced_dict 145 | 146 | 147 | class MetricLogger(object): 148 | def __init__(self, delimiter="\t"): 149 | self.meters = defaultdict(SmoothedValue) 150 | self.delimiter = delimiter 151 | 152 | def update(self, **kwargs): 153 | for k, v in kwargs.items(): 154 | if isinstance(v, torch.Tensor): 155 | v = v.item() 156 | assert isinstance(v, (float, int)) 157 | self.meters[k].update(v) 158 | 159 | def __getattr__(self, attr): 160 | if attr in self.meters: 161 | return self.meters[attr] 162 | if attr in self.__dict__: 163 | return self.__dict__[attr] 164 | raise AttributeError("'{}' object has no attribute '{}'".format( 165 | type(self).__name__, attr)) 166 | 167 | def __str__(self): 168 | loss_str = [] 169 | for name, meter in self.meters.items(): 170 | loss_str.append( 171 | "{}: {}".format(name, str(meter)) 172 | ) 173 | return self.delimiter.join(loss_str) 174 | 175 | def synchronize_between_processes(self): 176 | for meter in self.meters.values(): 177 | meter.synchronize_between_processes() 178 | 179 | def add_meter(self, name, meter): 180 | self.meters[name] = meter 181 | 182 | def log_every(self, iterable, print_freq, header=None): 183 | i = 0 184 | if not header: 185 | header = '' 186 | start_time = time.time() 187 | end = time.time() 188 | iter_time = SmoothedValue(fmt='{avg:.4f}') 189 | data_time = SmoothedValue(fmt='{avg:.4f}') 190 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 191 | log_msg = self.delimiter.join([ 192 | header, 193 | '[{0' + space_fmt + '}/{1}]', 194 | 'eta: {eta}', 195 | '{meters}', 196 | 'time: {time}', 197 | 'data: {data}', 198 | 'max mem: {memory:.0f}' 199 | ]) 200 | MB = 1024.0 * 1024.0 201 | for obj in iterable: 202 | data_time.update(time.time() - end) 203 | yield obj 204 | iter_time.update(time.time() - end) 205 | if i % print_freq == 0 or i == len(iterable) - 1: 206 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 207 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 208 | print(log_msg.format( 209 | i, len(iterable), eta=eta_string, 210 | meters=str(self), 211 | time=str(iter_time), 
data=str(data_time), 212 | memory=torch.cuda.max_memory_allocated() / MB)) 213 | i += 1 214 | end = time.time() 215 | total_time = time.time() - start_time 216 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 217 | print('{} Total time: {} ({:.4f} s / it)'.format( 218 | header, total_time_str, total_time / len(iterable))) 219 | 220 | 221 | def collate_fn(batch): 222 | return tuple(zip(*batch)) 223 | 224 | 225 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): 226 | 227 | def f(x): 228 | if x >= warmup_iters: 229 | return 1 230 | alpha = float(x) / warmup_iters 231 | return warmup_factor * (1 - alpha) + alpha 232 | 233 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f) 234 | 235 | 236 | def mkdir(path): 237 | try: 238 | os.makedirs(path) 239 | except OSError as e: 240 | if e.errno != errno.EEXIST: 241 | raise 242 | 243 | 244 | def setup_for_distributed(is_master): 245 | """ 246 | This function disables printing when not in master process 247 | """ 248 | import builtins as __builtin__ 249 | builtin_print = __builtin__.print 250 | 251 | def print(*args, **kwargs): 252 | force = kwargs.pop('force', False) 253 | if is_master or force: 254 | builtin_print(*args, **kwargs) 255 | 256 | __builtin__.print = print 257 | 258 | 259 | def is_dist_avail_and_initialized(): 260 | if not dist.is_available(): 261 | return False 262 | if not dist.is_initialized(): 263 | return False 264 | return True 265 | 266 | 267 | def get_world_size(): 268 | if not is_dist_avail_and_initialized(): 269 | return 1 270 | return dist.get_world_size() 271 | 272 | 273 | def get_rank(): 274 | if not is_dist_avail_and_initialized(): 275 | return 0 276 | return dist.get_rank() 277 | 278 | 279 | def is_main_process(): 280 | return get_rank() == 0 281 | 282 | 283 | def save_on_master(*args, **kwargs): 284 | if is_main_process(): 285 | torch.save(*args, **kwargs) 286 | 287 | 288 | def init_distributed_mode(args): 289 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 290 | args.rank = int(os.environ["RANK"]) 291 | args.world_size = int(os.environ['WORLD_SIZE']) 292 | args.gpu = int(os.environ['LOCAL_RANK']) 293 | elif 'SLURM_PROCID' in os.environ: 294 | args.rank = int(os.environ['SLURM_PROCID']) 295 | args.gpu = args.rank % torch.cuda.device_count() 296 | else: 297 | print('Not using distributed mode') 298 | args.distributed = False 299 | return 300 | 301 | args.distributed = True 302 | 303 | torch.cuda.set_device(args.gpu) 304 | args.dist_backend = 'nccl' 305 | print('| distributed init (rank {}): {}'.format( 306 | args.rank, args.dist_url), flush=True) 307 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 308 | world_size=args.world_size, rank=args.rank) 309 | torch.distributed.barrier() 310 | setup_for_distributed(args.rank == 0) 311 | --------------------------------------------------------------------------------
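Taken together, train.py builds the keypoint R-CNN model and optimizer, transforms.py supplies the flip/tensor augmentations, and utils.py provides the logging, checkpointing and distributed helpers. As a rough illustration of the inference side, the following is a minimal, self-contained sketch (not a file from this repository) that runs torchvision's keypointrcnn_resnet50_fpn with the same 0.7 detection-score and 2.0 keypoint-score thresholds used in predict_visualize.py; the use of OpenCV for drawing and the file names test.jpg / keypoints_out.jpg are assumptions made only for this example.

# Hedged usage sketch, not part of the repository: assumes OpenCV (cv2) is
# installed and that a test image named 'test.jpg' exists (placeholder name).
import cv2
import torch
import torchvision
from torchvision.transforms import functional as F

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Same model family that train.py and predict_visualize.py use above.
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
model.to(device)
model.eval()

detect_threshold = 0.7         # minimum person detection score
keypoint_score_threshold = 2   # minimum per-keypoint score (logit)

img_bgr = cv2.imread('test.jpg')                                    # HWC, BGR
img_tensor = F.to_tensor(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))  # CHW, [0, 1]

with torch.no_grad():
    prediction = model([img_tensor.to(device)])[0]

scores = prediction['scores'].cpu().numpy()                       # (N,)
keypoints = prediction['keypoints'].cpu().numpy()                 # (N, 17, 3)
keypoints_scores = prediction['keypoints_scores'].cpu().numpy()   # (N, 17)

# Keep confident person detections, then draw each confident keypoint.
keep = scores > detect_threshold
for person_kps, person_kp_scores in zip(keypoints[keep], keypoints_scores[keep]):
    for (x, y, _vis), kp_score in zip(person_kps, person_kp_scores):
        if kp_score > keypoint_score_threshold:
            cv2.circle(img_bgr, (int(x), int(y)), 3, (0, 255, 0), -1)

cv2.imwrite('keypoints_out.jpg', img_bgr)

The training side is driven entirely by the argparse flags defined in train.py; for example --data-path, --pretrained, --resume and --test-only select the corresponding code paths shown above.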