├── README.md
├── adaplayer.py
├── mimic_loss.py
├── convert_coco2custom.py
├── main.py
└── vis_groundtruth_keypoints.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Distillation-of-Faster-rcnn
Distillation for Faster R-CNN at the classification level, the regression level, the feature level, and the feature level + mask.

## Details in my CSDN blog:
https://blog.csdn.net/qq_33547191/article/details/95014337

https://blog.csdn.net/qq_33547191/article/details/95049838

## The code is heavily borrowed from:

### 1. Distillation for Faster R-CNN at the classification, regression, and feature level
http://papers.nips.cc/paper/6676-learning-efficient-object-detection-models-with-knowledge-distillation.pdf

### 2. Distillation for Faster R-CNN at the feature level + mask
http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Distilling_Object_Detectors_With_Fine-Grained_Feature_Imitation_CVPR_2019_paper.pdf

#### code:
https://github.com/twangnh/Distilling-Object-Detectors

main.py shows where to plug distillation into the training loop.

adaplayer.py adapts the student feature map when the teacher and student feature maps differ in size.

mimic_loss.py computes the classification and regression distillation losses, which essentially measure the similarity between teacher and student outputs.

Distillation is most effective when the accuracy gap between teacher and student is large.

--------------------------------------------------------------------------------
/adaplayer.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class Stu_Feature_Adap(nn.Module):
    """Adaptation layer: projects the student feature map to the teacher's
    channel dimension (e.g. 256 -> 1024) before the feature-imitation loss
    is computed."""

    def __init__(self, input_channel=256, output_channel=1024, kernel_size=2, padding=0):
        super(Stu_Feature_Adap, self).__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel,
                               kernel_size=kernel_size, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        return x
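
# Minimal usage sketch (not part of the original file): adapting a
# 256-channel student feature map to 1024 teacher channels. With the
# default kernel_size=2 and padding=0 the spatial size shrinks by one
# pixel in each dimension, so kernel_size/padding should be chosen to
# match the teacher's spatial resolution (kernel_size=1 keeps it intact).
if __name__ == '__main__':
    import torch
    adap = Stu_Feature_Adap(input_channel=256, output_channel=1024,
                            kernel_size=1, padding=0)
    stu_feat = torch.randn(2, 256, 38, 50)  # fake student backbone feature
    adapted = adap(stu_feat)
    print(adapted.shape)                    # torch.Size([2, 1024, 38, 50])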
--------------------------------------------------------------------------------
/mimic_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def compute_loss_classification(Z_t, Z_s, mu, L_hard, T=1, weighted=True):
    """Soft-label classification distillation.

    Z_t / Z_s: teacher / student logits; L_hard: the student's ordinary
    classification loss; mu: weight of the hard loss; T: temperature.
    Returns the combined classification loss and the soft (distillation) term.
    """
    # Per-sample weight vector. The class-weighted variant (extra weight on
    # background samples) is disabled; all samples are weighted equally.
    wc = torch.ones(Z_s.shape[0]).cuda().float()
    Z_s = Z_s.float()
    Z_t = Z_t.float()

    P_t = F.softmax(Z_t / T, dim=1)
    P_s = F.softmax(Z_s / T, dim=1)

    try:
        # Cross-entropy between the teacher and student distributions.
        P = torch.sum(P_t * torch.log(P_s), dim=1)
    except RuntimeError:
        # Teacher and student outputs have incompatible shapes: fall back
        # to the hard loss alone.
        L_soft = torch.zeros([1]).cuda().float()
        L_cls = L_hard.float()
        return L_cls, L_soft
    L_soft = -torch.mean(P * wc)

    L_cls = mu * L_hard.float() + (1 - mu) * L_soft

    return L_cls, L_soft


def compute_loss_regression(smooth_l1_loss, Rs, Rt, y_reg_s, y_reg_t, m, ni):
    """Teacher-bounded regression distillation.

    The student's regression output Rs is penalised only on samples where
    its squared distance to the target exceeds the teacher's by more than
    the margin m. smooth_l1_loss: the student's ordinary regression loss;
    ni: weight of the bounded term.
    """
    s_box_diff = Rs - y_reg_s
    t_box_diff = Rt - y_reg_t

    # Squared L2 distance to the regression targets, summed over every
    # non-batch dimension.
    norm_s = s_box_diff.pow(2)
    norm_t = t_box_diff.pow(2)
    for i in sorted(range(1, len(s_box_diff.shape)), reverse=True):
        norm_s = norm_s.sum(i)
        norm_t = norm_t.sum(i)

    if torch.cuda.is_available():
        zeros = torch.zeros(norm_s.shape).cuda()
    else:
        zeros = torch.zeros(norm_s.shape)
    try:
        # Zero the bounded loss wherever the student already beats the
        # teacher by at least the margin m.
        l_b = torch.where((norm_s + m <= norm_t), zeros, norm_s)
        l_reg = smooth_l1_loss + ni * l_b.mean()
    except RuntimeError:
        l_reg = smooth_l1_loss
        l_b = torch.zeros([1]).cuda().float()

    return l_reg, l_b.mean(), norm_s.mean(), norm_t.mean()
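
# Minimal usage sketch (not part of the original file), with random logits
# standing in for real RPN/RCNN outputs. A CUDA device is required because
# the functions above allocate their weight tensors with .cuda().
if __name__ == '__main__':
    Z_t = torch.randn(8, 21).cuda()          # teacher logits
    Z_s = torch.randn(8, 21).cuda()          # student logits
    L_hard = torch.tensor(0.7).cuda()        # student's own cls loss
    L_cls, L_soft = compute_loss_classification(Z_t, Z_s, mu=0.5,
                                                L_hard=L_hard, T=1)

    Rs = torch.randn(8, 4).cuda()            # student box regression
    Rt = torch.randn(8, 4).cuda()            # teacher box regression
    y_reg = torch.randn(8, 4).cuda()         # regression targets
    l1 = torch.tensor(0.3).cuda()            # student's own smooth-L1 loss
    l_reg, l_b, n_s, n_t = compute_loss_regression(l1, Rs, Rt,
                                                   y_reg, y_reg,
                                                   m=0.01, ni=0.5)
    print(L_cls.item(), l_reg.item())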
--------------------------------------------------------------------------------
/convert_coco2custom.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-

from __future__ import print_function
import os
import json

'''
"keypoints" is an array of length 3*k, where k is the number of keypoints
defined for the category. Each keypoint is a triple (x, y, v): x and y are
the coordinates and v is a visibility flag. v=0 means the keypoint is not
labelled (then x=y=v=0), v=1 means labelled but not visible (occluded), and
v=2 means labelled and visible.
"num_keypoints" is the number of labelled keypoints (v>0) on the object;
very small objects may have none. Example:
"keypoints": [623, 171, 2,
              0, 0, 0,
              602, 152, 2,
              0, 0, 0,
              571, 126, 1,
              0, 0, 0,
              499, 222, 2,
              0, 0, 0,
              438, 358, 2,
              592, 407, 2,
              399, 465, 2],
'''

def convert_coco2custom():
    fn = open('./custom_coco_train_17.json', 'w')
    json_file = 'coco/person_keypoints_train2017.json'  # Object Keypoint annotations
    # person_keypoints_val2017.json  # Object Keypoint annotation format
    # captions_val2017.json          # Image Caption annotation format
    count = 0
    data = json.load(open(json_file, 'r'))

    # Group annotations by image id so each image only scans its own
    # annotations instead of the full annotation list.
    annotations_by_image = {}
    for ann in data['annotations']:
        annotations_by_image.setdefault(ann['image_id'], []).append(ann)

    for image in data['images']:
        img_name = image['file_name']
        img_height = image['height']
        img_width = image['width']
        img_id = image['id']
        instances = []
        for ann in annotations_by_image.get(img_id, []):
            ann_bbox = ann['bbox']
            # COCO stores [x, y, w, h]; convert to [x1, y1, x2, y2].
            bbox = [ann_bbox[0], ann_bbox[1],
                    ann_bbox[0] + ann_bbox[2], ann_bbox[1] + ann_bbox[3]]
            keypoints = ann['keypoints']
            # Reshape the flat 51-element list into 17 (x, y, v) triples.
            need_keypoints = [keypoints[k * 3:k * 3 + 3] for k in range(17)]
            instances.append({
                'is_ignored': False,
                'bbox': bbox,
                'keypoints': need_keypoints,
                'label': 1})
        if len(instances) < 1:
            continue
        data_write = {
            'filename': 'data/train2017/' + img_name,
            'image_height': img_height,
            'image_width': img_width,
            'instances': instances
        }
        fn.write(json.dumps(data_write, ensure_ascii=False) + '\n')
        count += 1
        print(count)
    fn.close()

convert_coco2custom()


def read_one_image():
    json_file = 'coco/person_keypoints_val2017.json'  # Object Keypoint annotations

    data = json.load(open(json_file, 'r'))
    data_2 = {}
    data_2['info'] = data['info']
    data_2['licenses'] = data['licenses']
    data_2['images'] = [data['images'][27]]  # extract a single image
    data_2['categories'] = data['categories']
    annotation = []

    # Collect all annotations belonging to that image via its image id.
    imgID = data_2['images'][0]['id']
    for ann in data['annotations']:
        if ann['image_id'] == imgID:
            annotation.append(ann)

    data_2['annotations'] = annotation

    # Save to a new JSON file to make the data format easy to inspect.
    json.dump(data_2, open('./one.json', 'w'), indent=4)
# read_one_image()
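
# Output format note (not part of the original file): each line written to
# custom_coco_train_17.json is an independent JSON object of the form
#
#   {"filename": "data/train2017/<file_name>.jpg",
#    "image_height": H, "image_width": W,
#    "instances": [{"is_ignored": false,
#                   "bbox": [x1, y1, x2, y2],
#                   "keypoints": [[x, y, v], ...17 triples...],
#                   "label": 1}, ...]}
#
# One image per line is what lets vis_groundtruth_keypoints.py fetch a
# single groundtruth entry with readlines() + eval().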
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# Snippet showing where to add distillation inside the training loop; it
# assumes `model`, `model_teacher`, `model_adap`, `cfg_distillation`,
# `to_device`, `input`, `epoch` and `max_epoch` come from the surrounding
# trainer.
import torch

from mimic_loss import compute_loss_classification, compute_loss_regression

output = model(to_device(input))                  # student output
output_teacher = model_teacher(to_device(input))  # teacher output

'''
Classification and regression distillation in RPN:
'''
if cfg_distillation.get('cls_distillation_rpn', None):
    cfg_cls_distillation = cfg_distillation.get('cls_distillation_rpn')
    rpn_cls_score_t = output_teacher['RoINet.cls_pred']
    rpn_cls_score_s = output['RoINet.cls_pred']
    RPN_loss_cls_s = output['RoINet.cls_loss']
    # The hard-loss weight mu ramps linearly from start_mu to end_mu.
    start_mu = cfg_cls_distillation.get('start_mu')
    end_mu = cfg_cls_distillation.get('end_mu')
    mu = start_mu + (end_mu - start_mu) * (float(epoch) / max_epoch)
    loss_rpn_cls, loss_rpn_cls_soft = compute_loss_classification(
        rpn_cls_score_t, rpn_cls_score_s, mu, RPN_loss_cls_s, T=1, weighted=True)
    output['RoINet.cls_loss'] = loss_rpn_cls

if cfg_distillation.get('loc_distillation_rpn', None):
    cfg_loc_distillation = cfg_distillation.get('loc_distillation_rpn')
    RPN_loss_bbox_s = output['RoINet.loc_loss']
    bbox_pred_s = output['RoINet.loc_pred']
    bbox_pred_t = output_teacher['RoINet.loc_pred']
    rpn_rois_target_s = output['RoINet.loc_target']
    rpn_rois_target_t = output_teacher['RoINet.loc_target']

    start_ni = cfg_loc_distillation.get('start_ni')
    end_ni = cfg_loc_distillation.get('end_ni')
    ni = start_ni + (end_ni - start_ni) * (float(epoch) / max_epoch)
    loss_rpn_reg, loss_rpn_reg_soft, _, _ = \
        compute_loss_regression(RPN_loss_bbox_s, bbox_pred_s, bbox_pred_t,
                                rpn_rois_target_s, rpn_rois_target_t,
                                m=0.01, ni=ni)
    output['RoINet.loc_loss'] = loss_rpn_reg

'''
Classification and regression distillation in RCNN:
'''
if cfg_distillation.get('cls_distillation', None):
    cfg_cls_distillation = cfg_distillation.get('cls_distillation')
    rcn_cls_score_t = output_teacher['cls_pred']
    rcn_cls_score_s = output['cls_pred']
    RCNN_loss_cls_s = output['BboxNet.cls_loss']
    start_mu = cfg_cls_distillation.get('start_mu')
    end_mu = cfg_cls_distillation.get('end_mu')
    mu = start_mu + (end_mu - start_mu) * (float(epoch) / max_epoch)
    loss_rcn_cls, loss_rcn_cls_soft = compute_loss_classification(
        rcn_cls_score_t, rcn_cls_score_s, mu, RCNN_loss_cls_s, T=1, weighted=True)
    output['BboxNet.cls_loss'] = loss_rcn_cls

if cfg_distillation.get('loc_distillation', None):
    cfg_loc_distillation = cfg_distillation.get('loc_distillation')
    RCNN_loss_bbox_s = output['BboxNet.loc_loss']
    bbox_pred_s = output['loc_pred']
    bbox_pred_t = output_teacher['loc_pred']
    rois_target_s = output['loc_target']
    rois_target_t = output_teacher['loc_target']

    start_ni = cfg_loc_distillation.get('start_ni')
    end_ni = cfg_loc_distillation.get('end_ni')
    ni = start_ni + (end_ni - start_ni) * (float(epoch) / max_epoch)
    loss_rcn_reg, loss_rcn_reg_soft, _, _ = \
        compute_loss_regression(RCNN_loss_bbox_s, bbox_pred_s, bbox_pred_t,
                                rois_target_s, rois_target_t, m=0.01, ni=ni)
    output['BboxNet.loc_loss'] = loss_rcn_reg
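
# All four schedules above are the same linear interpolation; a
# hypothetical helper (not in the original code) capturing it would be:
#
#   def linear_schedule(start, end, epoch, max_epoch):
#       """Ramp a distillation weight linearly over training."""
#       return start + (end - start) * (float(epoch) / max_epoch)
#
# e.g. mu = linear_schedule(start_mu, end_mu, epoch, max_epoch).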

'''
Feature level distillation (fine-grained feature imitation):
'''
if cfg_distillation.get('feature_distillation', None):
    cfg_feature_distillation = cfg_distillation.get('feature_distillation')
    sup_feature = output_teacher['features'][0]
    stu_feature = output['features'][0]
    # Adapt the student feature map to the teacher's channel dimension.
    stu_feature_adap = model_adap(stu_feature)

    start_weigth = cfg_feature_distillation.get('start_weigth')
    end_weigth = cfg_feature_distillation.get('end_weigth')
    imitation_loss_weigth = start_weigth + (end_weigth - start_weigth) * (float(epoch) / max_epoch)
    if cfg_feature_distillation.get('need_mask', None):
        # Binarise the teacher's per-image masks and restrict the
        # imitation loss to the masked locations.
        mask_batch = output_teacher['RoINet.mask_batch']
        mask_list = []
        for mask in mask_batch:
            mask = (mask > 0).float().unsqueeze(0)
            mask_list.append(mask)
        mask_batch = torch.stack(mask_list, dim=0)
        norms = mask_batch.sum() ** 2
        sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2) * mask_batch).sum() / norms

        # Guard against loss explosions from near-empty masks.
        if sup_loss > 100 or sup_loss < -100:
            print('sup_loss:%f' % sup_loss, flush=True)
            sup_loss = sup_loss * 0.000001
    else:
        sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2)).sum()

    sup_loss = sup_loss * imitation_loss_weigth
    output['sup.loss'] = sup_loss
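
# Wiring sketch (an assumption, not part of the original snippet): main.py
# uses `model_adap` without constructing it. A plausible setup matching
# adaplayer.py, where the student backbone emits 256 channels and the
# teacher 1024:
#
#   from adaplayer import Stu_Feature_Adap
#   model_adap = Stu_Feature_Adap(input_channel=256, output_channel=1024).cuda()
#
# model_adap's parameters must be optimised together with the student's,
# and the distilled terms already stored in `output` would be summed into
# the total training loss, e.g.:
#
#   total_loss = (output['RoINet.cls_loss'] + output['RoINet.loc_loss'] +
#                 output['BboxNet.cls_loss'] + output['BboxNet.loc_loss'] +
#                 output['sup.loss'])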
--------------------------------------------------------------------------------
/vis_groundtruth_keypoints.py:
--------------------------------------------------------------------------------
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import cv2

from matplotlib import pyplot as plt
from matplotlib import patches
import numpy as np

# The groundtruth file stores one JSON object per line and is read back
# with eval(), so the JSON literals need Python aliases.
global false, null, true
false = False
true = True
null = None

# For reference, the full 17-keypoint COCO layout produced by
# convert_coco2custom.py:
# PERSON_KEYPOINTS = [
#     "nose", "left_eye", "right_eye", "left_ear", "right_ear", "left_shoulder",
#     "right_shoulder", "left_elbow", "right_elbow", "left_wrist", "right_wrist",
#     "left_hip", "right_hip", "left_knee", "right_knee", "left_ankle",
#     "right_ankle"
# ]

KEYPOINT_COLORS = ['magenta', 'cyan', 'yellow', 'green',
                   'lime', 'blue', 'purple', 'orange',
                   'white', 'lightcoral', 'lime', 'olive',
                   'steelblue', 'red', 'gold', 'navy',
                   'dodgerblue', 'mediumaquamarine', 'black', 'gray']

'''
Skeleton pairs for the first 11 keypoints:
PERSON_KEYPOINTS = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear", "left_shoulder",
    "right_shoulder", "left_elbow", "right_elbow", "left_wrist", "right_wrist"]
'''
KEYPOINT_PAIRS = \
    [
        (0, 1), (0, 2),
        (1, 3), (3, 5), (5, 7), (7, 9),
        (2, 4), (4, 6), (6, 8), (8, 10)
    ]
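
# For reference (not in the original file): with the 11-keypoint layout
# above, the pairs trace the face and the two arm chains, e.g.
# (0, 1) = nose-left_eye, (1, 3) = left_eye-left_ear,
# (3, 5) = left_ear-left_shoulder, (5, 7) = left_shoulder-left_elbow,
# (7, 9) = left_elbow-left_wrist, and symmetrically (2, 4)...(8, 10) on
# the right side. Keypoints 11+ (hips, knees, ankles) are still drawn as
# circles but never connected.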

def vis_one_img(img_root, groundtruth_file, vis_num):
    """Draw the boxes and keypoints of the vis_num-th groundtruth line on
    top of its image and save the result to ./labeledImgs/."""
    f = open(groundtruth_file, 'r')
    lines = f.readlines()
    line = lines[vis_num - 1]
    line_dic = eval(line)

    img_name = line_dic['filename']
    print(img_name, flush=True)
    img_position = os.path.join(img_root, img_name)

    image = cv2.imread(img_position, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.ion()
    fig = plt.figure(figsize=(5, 5))
    plt.imshow(img)
    currentAxis = plt.gca()
    instances = line_dic['instances']
    for i in range(len(instances)):
        roi = instances[i]['bbox']
        if instances[i]['label'] == 1:
            if instances[i]['is_ignored']:
                # Ignored person boxes: yellow outline, blue index.
                currentAxis.add_patch(plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],
                                                    roi[3] - roi[1], fill=False,
                                                    edgecolor='y', linewidth=1.5))
                currentAxis.text(float(roi[0]), roi[1], str(i),
                                 color='b', fontsize=10, bbox={'facecolor': 'g', 'alpha': 0.001})
            else:
                # Regular person boxes: red outline, red index.
                currentAxis.add_patch(plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],
                                                    roi[3] - roi[1], fill=False,
                                                    edgecolor='r', linewidth=2.5))
                currentAxis.text(float(roi[0]), roi[1], str(i),
                                 color='r', fontsize=10, bbox={'facecolor': 'g', 'alpha': 0.001})
        else:
            if instances[i]['is_ignored']:
                currentAxis.add_patch(plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],
                                                    roi[3] - roi[1], fill=False,
                                                    edgecolor='g', linewidth=1.5))
            else:
                currentAxis.add_patch(plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],
                                                    roi[3] - roi[1], fill=False,
                                                    edgecolor='w', linewidth=1.5))
        keypoints = instances[i]['keypoints']
        if keypoints is not None and len(keypoints) > 0:
            # Draw each labelled keypoint (visibility flag v > 0) as a circle.
            for corner_xy, color in zip(keypoints, KEYPOINT_COLORS):
                if corner_xy[2] > 0:
                    currentAxis.add_patch(patches.Circle((corner_xy[0], corner_xy[1]),
                                                         radius=4, fill=True, color=color))

            pts = np.array(keypoints)
            joint_visible = pts[:, 2] > 0

            # Connect skeleton pairs whose endpoints are both labelled.
            colormap_index = np.linspace(0, 1, len(KEYPOINT_PAIRS))
            for cm_ind, jp in zip(colormap_index, KEYPOINT_PAIRS):
                if joint_visible[jp[0]] and joint_visible[jp[1]]:
                    currentAxis.plot(pts[jp, 0], pts[jp, 1],
                                     linewidth=2.0, alpha=0.7, color=plt.cm.cool(cm_ind))
                    currentAxis.scatter(pts[jp, 0], pts[jp, 1], s=5)

    target_dir = './labeledImgs/'
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    plt.axis('off')
    plt.subplots_adjust(top=0.9, bottom=0.01, right=0.9, left=0.01, hspace=0.01, wspace=0.01)
    plt.margins(0, 0)
    plt.savefig(target_dir + img_name.replace('/', '_'), format='jpg', transparent=True, dpi=300, pad_inches=0)

    plt.show()
    plt.pause(400)
    plt.close()


if __name__ == '__main__':
    img_root = '/data/COCO/walter1218-datasets-mscoco-1'
    groundtruth_file = './custom_coco_train.json'
    vis_one_img(img_root, groundtruth_file, vis_num=57000)
    # for i in range(60):
    #     vis_one_img(img_root, groundtruth_file, vis_num=i * 90)

--------------------------------------------------------------------------------