├── models ├── __init__.py ├── losses.py └── models.py ├── .idea ├── encodings.xml ├── vcs.xml ├── other.xml ├── dictionaries │ ├── jin.xml │ └── zhijun.xml ├── modules.xml ├── misc.xml ├── Look_At_Boundary_PyTorch.iml ├── workspace.xml └── inspectionProfiles │ └── Project_Default.xml ├── utils ├── __init__.py ├── pdb.py ├── args.py ├── dataset_info.py ├── train_eval_utils.py ├── visual.py └── dataload.py ├── README.md ├── dataset.py ├── evaluate.py └── train.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import WingLoss 2 | from .models import Estimator, Regressor, Discrim 3 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .args import args 2 | from .dataload import * 3 | from .dataset_info import * 4 | from .train_eval_utils import * 5 | from .visual import * 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/dictionaries/jin.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | caffe 5 | cfss 6 | cofw 7 | sapm 8 | tcdcn 9 | 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Say nothing for now, maybe later... 2 | 3 | My best result (mean error rate % normalized by inter_pupil) on 300W is as follow: 4 | 5 | | Challenge Subset | Common Subset | Fullset | 6 | | :--------------: | :--------------: | :--------------: | 7 | | 9.11 | 4.97 | 5.78 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from utils import args, get_annotations_list, get_item_from 3 | 4 | 5 | class GeneralDataset(data.Dataset): 6 | 7 | def __init__(self, dataset='WFLW', split='train'): 8 | self.dataset = dataset 9 | self.split = split 10 | self.list = get_annotations_list(dataset, split, ispdb=args.PDB) 11 | 12 | def __len__(self): 13 | return len(self.list) 14 | 15 | def __getitem__(self, item): 16 | return get_item_from(self.dataset, self.split, self.list[item]) 17 | -------------------------------------------------------------------------------- /.idea/Look_At_Boundary_PyTorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class HeatmapLoss(nn.Module): 7 | def __init__(self): 8 | super(HeatmapLoss, self).__init__() 9 | 10 | def forward(self, pred, gt): 11 | assert pred.size() == gt.size() 12 | loss = ((pred - gt)**2) 13 | loss = loss.sum(dim=3).sum(dim=2).sum(dim=1).mean() / 2. 14 | return loss 15 | 16 | 17 | class WingLoss(nn.Module): 18 | 19 | def __init__(self, w=10, epsilon=2, weight=None): 20 | super(WingLoss, self).__init__() 21 | self.w = w 22 | self.epsilon = epsilon 23 | self.C = self.w - self.w * np.log(1 + self.w / self.epsilon) 24 | self.weight = weight 25 | 26 | def forward(self, predictions, targets): 27 | x = predictions - targets 28 | if self.weight is not None: 29 | x = x * self.weight 30 | t = torch.abs(x) 31 | 32 | return torch.mean(torch.where(t < self.w, self.w * torch.log(1 + t / self.epsilon), t - self.C)) 33 | -------------------------------------------------------------------------------- /.idea/dictionaries/zhijun.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | accuracys 5 | aflw 6 | allshapes 7 | anno 8 | annos 9 | batchnorm 10 | batchsize 11 | bcsize 12 | bilinear 13 | ckpts 14 | conv 15 | cuda 16 | cudnn 17 | dataload 18 | dataloader 19 | dataloading 20 | datasets 21 | dfake 22 | discrim 23 | downsample 24 | fmfhourglass 25 | frobenius 26 | frontalset 27 | fullset 28 | gthm 29 | heatmap 30 | heatmaps 31 | idxs 32 | imgs 33 | imread 34 | imshow 35 | inchannels 36 | inplanes 37 | ispdb 38 | keypoint 39 | keypoints 40 | largepose 41 | lsll 42 | lsul 43 | maxpool 44 | newidx 45 | numbins 46 | numel 47 | regressor 48 | relu 49 | standarised 50 | strd 51 | testset 52 | tform 53 | trainset 54 | usll 55 | usul 56 | wflw 57 | wingloss 58 | zhijun 59 | 60 | 61 | -------------------------------------------------------------------------------- /utils/pdb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.decomposition import PCA 3 | 4 | 5 | def procrustes(X, Y, scaling=True, reflection='best'): 6 | n, m = X.shape 7 | ny, my = Y.shape 8 | muX = X.mean(0) 9 | muY = Y.mean(0) 10 | X0 = X - muX 11 | Y0 = Y - muY 12 | ssX = (X0**2.).sum() 13 | ssY = (Y0**2.).sum() 14 | # centred Frobenius norm 15 | normX = np.sqrt(ssX) 16 | normY = np.sqrt(ssY) 17 | # scale to equal (unit) norm 18 | X0 = X0/normX if normX > 1e-6 else X0 19 | Y0 = Y0/normY if normY > 1e-6 else Y0 20 | if my < m: 21 | Y0 = np.concatenate((Y0, np.zeros(n, m-my)), 0) 22 | # optimum rotation matrix of Y 23 | A = np.dot(X0.T, Y0) 24 | U, s, Vt = np.linalg.svd(A, full_matrices=False) 25 | V = Vt.T 26 | T = np.dot(V, U.T) 27 | if reflection is not 'best': 28 | # does the current solution use a reflection? 29 | have_reflection = np.linalg.det(T) < 0 30 | # if that's not what was specified, force another reflection 31 | if reflection != have_reflection: 32 | V[:, -1] *= -1 33 | s[-1] *= -1 34 | T = np.dot(V, U.T) 35 | traceTA = s.sum() 36 | if scaling: 37 | # optimum scaling of Y 38 | b = traceTA * normX / normY if normY > 1e-6 else traceTA * normX 39 | # standarised distance between X and b*Y*T + c 40 | d = 1 - traceTA**2 41 | # transformed coords 42 | Z = normX*traceTA*np.dot(Y0, T) + muX 43 | else: 44 | b = 1 45 | d = 1 + ssY/ssX - 2 * traceTA * normY / normX 46 | Z = normY*np.dot(Y0, T) + muX 47 | # transformation matrix 48 | if my < m: 49 | T = T[:my, :] 50 | c = muX - b*np.dot(muY, T) 51 | # transformation values 52 | tform = {'rotation': T, 'scale': b, 'translation': c} 53 | 54 | return d, Z, tform 55 | 56 | 57 | # input as array 58 | def pdb(dataset, allShapes, numBins): 59 | alignedShape = allShapes 60 | meanShape = np.mean(alignedShape, 1) 61 | for i in range(len(alignedShape[0])): 62 | _, tmpS, _ = procrustes(meanShape.reshape((-1, 2), order='F'), 63 | alignedShape[:, i].reshape((-1, 2), order='F')) 64 | alignedShape[:, i] = tmpS.reshape((1, -1), order='F') 65 | 66 | meanShape = np.mean(alignedShape, 1) 67 | meanShape = meanShape.repeat(len(alignedShape[0])).reshape(-1, len(alignedShape[0])) 68 | alignedShape = alignedShape - meanShape 69 | pca = PCA(n_components=2) if dataset in ['AFLW', 'COFW'] else PCA(n_components=1) 70 | posePara = pca.fit_transform(np.transpose(alignedShape)) 71 | 72 | absPosePara = np.abs(posePara[:, 1]) if dataset in ['AFLW', 'COFW'] else np.abs(posePara) 73 | maxPosePara = np.max(absPosePara) 74 | maxSampleInBins = np.max(np.histogram(absPosePara, numBins)[0]) 75 | 76 | newIdx = np.array([]) 77 | for i in range(numBins): 78 | tmp1 = set([index for index in range(len(absPosePara)) 79 | if absPosePara[index] >= i*maxPosePara/numBins]) 80 | tmp2 = set([index for index in range(len(absPosePara)) 81 | if absPosePara[index] <= (i+1)*maxPosePara/numBins]) 82 | tmpTrainIdx = np.array(list(tmp1 & tmp2)) 83 | ratio = round(maxSampleInBins / len(tmpTrainIdx)) if len(tmpTrainIdx) > 0 else 0 84 | newIdx = np.insert(newIdx, 0, values=tmpTrainIdx.repeat(ratio), axis=0) 85 | return newIdx 86 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='LAB') 4 | 5 | # dataset 6 | parser.add_argument('--dataset_route', default='/home/jin/new_datasets/', type=str) 7 | parser.add_argument('--dataset', default='WFLW', type=str) 8 | parser.add_argument('--split', default='pose', type=str) 9 | 10 | # dataloader 11 | parser.add_argument('--crop_size', default=256, type=int) 12 | parser.add_argument('--batch_size', default=4, type=int) 13 | parser.add_argument('--workers', default=8, type=int) 14 | parser.add_argument('--shuffle', default=True, type=bool) 15 | parser.add_argument('--PDB', default=False, type=bool) 16 | parser.add_argument('--RGB', default=False, type=bool) 17 | parser.add_argument('--trans_ratio', default=0.1, type=float) 18 | parser.add_argument('--rotate_limit', default=20., type=float) 19 | parser.add_argument('--scale_ratio', default=0.1, type=float) 20 | 21 | # devices 22 | parser.add_argument('--cuda', default=True, type=bool) 23 | parser.add_argument('--gpu_id', default='0', type=str) 24 | 25 | # learning parameters 26 | parser.add_argument('--momentum', default=0.9, type=float) 27 | parser.add_argument('--weight_decay', default=5e-4, type=float) 28 | parser.add_argument('--lr', default=2e-5, type=float) 29 | parser.add_argument('--gamma', default=0.2, type=float) 30 | parser.add_argument('--step_values', default=[1000, 1500], type=list) 31 | parser.add_argument('--max_epoch', default=2000, type=int) 32 | 33 | # losses setting 34 | parser.add_argument('--loss_type', default='smoothL1', type=str, 35 | choices=['L1', 'L2', 'smoothL1', 'wingloss']) 36 | parser.add_argument('--wingloss_w', default=10, type=int) 37 | parser.add_argument('--wingloss_e', default=2, type=int) 38 | 39 | # resume training parameters 40 | parser.add_argument('--resume_epoch', default=0, type=int) 41 | parser.add_argument('--resume_folder', default='./weights/ckpts/', type=str) 42 | 43 | # model saving parameters 44 | parser.add_argument('--save_folder', default='./weights/', type=str) 45 | parser.add_argument('--save_interval', default=100, type=int) 46 | 47 | # model setting 48 | parser.add_argument('--hour_stack', default=4, type=int) 49 | parser.add_argument('--msg_pass', default=True, type=bool) 50 | parser.add_argument('--GAN', default=True, type=bool) 51 | parser.add_argument('--fuse_stage', default=4, type=int) 52 | parser.add_argument('--sigma', default=1.0, type=float) 53 | parser.add_argument('--theta', default=1.5, type=float) 54 | parser.add_argument('--delta', default=0.8, type=float) 55 | 56 | # evaluate parameters 57 | parser.add_argument('--eval_epoch', default=900, type=int) 58 | parser.add_argument('--max_threshold', default=0.1, type=float) 59 | parser.add_argument('--norm_way', default='inter_ocular', type=str, 60 | choices=['inter_pupil', 'inter_ocular', 'face_size']) 61 | parser.add_argument('--eval_visual', default=True, type=bool) 62 | parser.add_argument('--save_img', default=True, type=bool) 63 | 64 | args = parser.parse_args() 65 | 66 | assert args.resume_epoch < args.step_values[0] 67 | assert args.resume_epoch < args.max_epoch 68 | assert args.step_values[-1] < args.max_epoch 69 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import time 3 | import numpy as np 4 | from dataset import GeneralDataset 5 | from models import * 6 | from utils import * 7 | 8 | 9 | def evaluate(arg): 10 | devices = torch.device('cuda:'+arg.gpu_id) 11 | error_rate = [] 12 | failure_count = 0 13 | max_threshold = arg.max_threshold 14 | 15 | testset = GeneralDataset(dataset=arg.dataset, split=arg.split) 16 | dataloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, pin_memory=True) 17 | 18 | print('***** Normal Evaluating *****') 19 | print('Evaluating parameters:\n' + 20 | '# Dataset: ' + arg.dataset + '\n' + 21 | '# Dataset split: ' + arg.split + '\n' + 22 | '# Epoch of the model: ' + str(arg.eval_epoch) + '\n' + 23 | '# Normalize way: ' + arg.norm_way + '\n' + 24 | '# Max threshold: ' + str(arg.max_threshold) + '\n') 25 | 26 | print('Loading network ...') 27 | estimator = Estimator(stacks=arg.hour_stack, msg_pass=arg.msg_pass) 28 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2*kp_num[arg.dataset]) 29 | estimator = load_weights(estimator, arg.save_folder+'estimator_'+str(arg.eval_epoch)+'.pth', devices) 30 | regressor = load_weights(regressor, arg.save_folder+arg.dataset+'_regressor_'+str(arg.eval_epoch)+'.pth', devices) 31 | if arg.cuda: 32 | estimator = estimator.cuda(device=devices) 33 | regressor = regressor.cuda(device=devices) 34 | estimator.eval() 35 | regressor.eval() 36 | print('Loading network done!\nStart testing ...') 37 | 38 | time_records = [] 39 | with torch.no_grad(): 40 | for data in tqdm.tqdm(dataloader): 41 | start = time.time() 42 | 43 | input_images, gt_coords_xy, gt_heatmap, coords_xy, bbox, img_name = data 44 | gt_coords_xy = gt_coords_xy.squeeze().numpy() 45 | bbox = bbox.squeeze().numpy() 46 | error_normalize_factor = calc_normalize_factor(arg.dataset, coords_xy.numpy(), arg.norm_way) \ 47 | if arg.norm_way in ['inter_pupil', 'inter_ocular'] else (bbox[2] - bbox[0]) 48 | input_images = input_images.unsqueeze(1) 49 | input_images = input_images.cuda(device=devices) 50 | 51 | pred_heatmaps = estimator(input_images) 52 | pred_coords = regressor(input_images, pred_heatmaps[-1].detach()).detach().cpu().squeeze().numpy() 53 | pred_coords_map_back = inverse_affine(arg, pred_coords, bbox) 54 | 55 | time_records.append(time.time() - start) 56 | 57 | error_rate_i = calc_error_rate_i( 58 | arg.dataset, 59 | pred_coords_map_back, 60 | coords_xy[0].numpy(), 61 | error_normalize_factor 62 | ) 63 | 64 | if arg.eval_visual: 65 | eval_heatmap(arg, pred_heatmaps[-1], img_name, bbox, save_img=arg.save_img) 66 | eval_pred_points(arg, pred_coords, img_name, bbox, save_img=arg.save_img) 67 | 68 | failure_count = failure_count + 1 if error_rate_i > max_threshold else failure_count 69 | error_rate.append(error_rate_i) 70 | 71 | area_under_curve, auc_record = calc_auc(arg.dataset, arg.split, error_rate, max_threshold) 72 | error_rate = sum(error_rate) / dataset_size[arg.dataset][arg.split] * 100 73 | failure_rate = failure_count / dataset_size[arg.dataset][arg.split] * 100 74 | 75 | print('\nEvaluating results:\n# AUC: {:.4f}\n# Error Rate: {:.2f}%\n# Failure Rate: {:.2f}%\n'.format( 76 | area_under_curve, error_rate, failure_rate)) 77 | print('Average speed: {:.2f}FPS'.format(1./np.mean(np.array(time_records)))) 78 | 79 | 80 | def evaluate_with_gt_heatmap(arg): 81 | devices = torch.device('cuda:' + arg.gpu_id) 82 | error_rate = [] 83 | failure_count = 0 84 | max_threshold = arg.max_threshold 85 | 86 | testset = GeneralDataset(dataset=arg.dataset, split=arg.split) 87 | dataloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, pin_memory=True) 88 | 89 | print('***** Evaluating with ground truth heatmap *****') 90 | print('Evaluating parameters:\n' + 91 | '# Dataset: ' + arg.dataset + '\n' + 92 | '# Dataset split: ' + arg.split + '\n' + 93 | '# Epoch of the model: ' + str(arg.eval_epoch) + '\n' + 94 | '# Normalize way: ' + arg.norm_way + '\n' + 95 | '# Max threshold: ' + str(arg.max_threshold) + '\n') 96 | 97 | print('Loading network...') 98 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2 * kp_num[arg.dataset]) 99 | regressor = load_weights(regressor, arg.save_folder + arg.dataset + '_regressor_' + str(arg.eval_epoch) + '.pth', 100 | devices) 101 | if arg.cuda: 102 | regressor = regressor.cuda(device=devices) 103 | regressor.eval() 104 | print('Loading network done!\nStart testing...') 105 | 106 | time_records = [] 107 | with torch.no_grad(): 108 | for data in tqdm.tqdm(dataloader): 109 | start = time.time() 110 | 111 | input_images, gt_coords_xy, gt_heatmap, coords_xy, bbox, img_name = data 112 | bbox = bbox.squeeze().numpy() 113 | error_normalize_factor = calc_normalize_factor(arg.dataset, coords_xy.numpy(), arg.norm_way) \ 114 | if arg.norm_way in ['inter_pupil', 'inter_ocular'] else (bbox[2] - bbox[0]) 115 | input_images = input_images.unsqueeze(1) 116 | input_images = input_images.cuda(device=devices) 117 | gt_heatmap = gt_heatmap.cuda(device=devices) 118 | 119 | pred_coords = regressor(input_images, gt_heatmap).detach().cpu().squeeze().numpy() 120 | pred_coords_map_back = inverse_affine(arg, pred_coords, bbox) 121 | 122 | time_records.append(time.time() - start) 123 | 124 | error_rate_i = calc_error_rate_i( 125 | arg.dataset, 126 | pred_coords_map_back, 127 | coords_xy[0].numpy(), 128 | error_normalize_factor 129 | ) 130 | 131 | if arg.eval_visual: 132 | eval_gt_pred_points(arg, gt_coords_xy, pred_coords, img_name, bbox, save_img=arg.save_img) 133 | 134 | failure_count = failure_count + 1 if error_rate_i > max_threshold else failure_count 135 | error_rate.append(error_rate_i) 136 | 137 | area_under_curve, auc_record = calc_auc(arg.dataset, arg.split, error_rate, max_threshold) 138 | error_rate = sum(error_rate) / dataset_size[arg.dataset][arg.split] * 100 139 | failure_rate = failure_count / dataset_size[arg.dataset][arg.split] * 100 140 | 141 | print('\nEvaluating results:\n# AUC: {:.4f}\n# Error Rate: {:.2f}%\n# Failure Rate: {:.2f}%\n'.format( 142 | area_under_curve, error_rate, failure_rate)) 143 | print('Average speed: {:.2f}FPS'.format(1. / np.mean(np.array(time_records)))) 144 | 145 | 146 | if __name__ == '__main__': 147 | evaluate(args) 148 | -------------------------------------------------------------------------------- /utils/dataset_info.py: -------------------------------------------------------------------------------- 1 | from .args import args 2 | 3 | heatmap_size = 64 4 | boundary_num = 13 5 | 6 | boundary_keys = ['chin', 'leb', 'reb', 'bon', 'breath', 'lue', 'lle', 'rue', 'rle', 'usul', 'lsul', 'usll', 'lsll'] 7 | 8 | interp_points_num = { 9 | 'chin': 120, 10 | 'leb': 32, 11 | 'reb': 32, 12 | 'bon': 32, 13 | 'breath': 25, 14 | 'lue': 25, 15 | 'lle': 25, 16 | 'rue': 25, 17 | 'rle': 25, 18 | 'usul': 32, 19 | 'lsul': 32, 20 | 'usll': 32, 21 | 'lsll': 32 22 | } 23 | 24 | dataset_pdb_numbins = { 25 | '300W': 9, 26 | 'AFLW': 17, 27 | 'COFW': 7, 28 | 'WFLW': 13 29 | } 30 | 31 | dataset_route = { 32 | '300W': args.dataset_route+'/300W/', 33 | 'AFLW': args.dataset_route+'/AFLW/', 34 | 'COFW': args.dataset_route+'/COFW/', 35 | 'WFLW': args.dataset_route+'/WFLW/' 36 | } 37 | 38 | dataset_size = { 39 | '300W': { 40 | 'train': 3148, 41 | 'common_subset': 554, 42 | 'challenge_subset': 135, 43 | 'fullset': 689, 44 | '300W_testset': 600, 45 | 'COFW68': 507 # 该数据集用于300W数据集上训练模型的测试 46 | }, 47 | 'AFLW': { 48 | 'train': 20000, 49 | 'test': 24386, 50 | 'frontal': 1314 51 | }, 52 | 'COFW': { 53 | 'train': 1345, 54 | 'test': 507 55 | }, 56 | 'WFLW': { 57 | 'train': 7500, 58 | 'test': 2500, 59 | 'pose': 326, 60 | 'expression': 314, 61 | 'illumination': 698, 62 | 'makeup': 206, 63 | 'occlusion': 736, 64 | 'blur': 773 65 | } 66 | } 67 | 68 | kp_num = { 69 | '300W': 68, 70 | 'AFLW': 19, 71 | 'COFW': 29, 72 | 'WFLW': 98 73 | } 74 | 75 | point_num_per_boundary = { 76 | '300W': [17., 5., 5., 4., 5., 4., 4., 4., 4., 7., 5., 5., 7.], 77 | 'AFLW': [1., 3., 3., 1., 2., 3., 3., 3., 3., 3., 3., 3., 3.], 78 | 'COFW': [1., 3., 3., 1., 3., 3., 3., 3., 3., 3., 1., 1., 3.], 79 | 'WFLW': [33., 9., 9., 4., 5., 5., 5., 5., 5., 7., 5., 5., 7.] 80 | } 81 | 82 | boundary_special = { # 有些边界线条使用的关键点和其他边界形成不连续交集,特殊处理 83 | 'lle': ['300W', 'COFW', 'WFLW'], 84 | 'rle': ['300W', 'COFW', 'WFLW'], 85 | 'usll': ['300W', 'WFLW'], 86 | 'lsll': ['300W', 'COFW', 'WFLW'] 87 | } 88 | 89 | duplicate_point = { # 需要重复使用的关键点的序号,从0开始计数 90 | '300W': { 91 | 'lle': 36, 92 | 'rle': 42, 93 | 'usll': 60, 94 | 'lsll': 48 95 | }, 96 | 'COFW': { 97 | 'lle': 13, 98 | 'rle': 17, 99 | 'lsll': 21 100 | }, 101 | 'WFLW': { 102 | 'lle': 60, 103 | 'rle': 68, 104 | 'usll': 88, 105 | 'lsll': 76 106 | } 107 | } 108 | 109 | point_range = { # notice: this is 'range', the later number pluses 1; the order is boundary order; index starts from 0 110 | '300W': [ 111 | [0, 17], [17, 22], [22, 27], [27, 31], [31, 36], 112 | [36, 40], [39, 42], [42, 46], [45, 48], [48, 55], 113 | [60, 65], [64, 68], [54, 60] 114 | ], 115 | 'AFLW': [ 116 | [0, 1], [1, 4], [4, 7], [7, 8], [8, 10], 117 | [10, 13], [10, 13], [13, 16], [13, 16], [16, 19], 118 | [16, 19], [16, 19], [16, 19] 119 | ], 120 | 'COFW': [ 121 | [0, 1], [1, 4], [5, 8], [9, 10], [10, 13], 122 | [13, 16], [15, 17], [17, 20], [19, 21], [21, 24], 123 | [25, 26], [26, 27], [23, 25] 124 | ], 125 | 'WFLW': [ 126 | [0, 33], [33, 38], [42, 47], [51, 55], [55, 60], 127 | [60, 65], [64, 68], [68, 73], [72, 76], [76, 83], 128 | [88, 93], [92, 96], [82, 88] 129 | ] 130 | } 131 | 132 | flip_relation = { 133 | '300W': [ 134 | [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], 135 | [6, 10], [7, 9], [8, 8], [9, 7], [10, 6], [11, 5], 136 | [12, 4], [13, 3], [14, 2], [15, 1], [16, 0], [17, 26], 137 | [18, 25], [19, 24], [20, 23], [21, 22], [22, 21], [23, 20], 138 | [24, 19], [25, 18], [26, 17], [27, 27], [28, 28], [29, 29], 139 | [30, 30], [31, 35], [32, 34], [33, 33], [34, 32], [35, 31], 140 | [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46], 141 | [42, 39], [43, 38], [44, 37], [45, 36], [46, 41], [47, 40], 142 | [48, 54], [49, 53], [50, 52], [51, 51], [52, 50], [53, 49], 143 | [54, 48], [55, 59], [56, 58], [57, 57], [58, 56], [59, 55], 144 | [60, 64], [61, 63], [62, 62], [63, 61], [64, 60], [65, 67], 145 | [66, 66], [67, 65] 146 | ], 147 | 'AFLW': [ 148 | [0, 0], [1, 6], [2, 5], [3, 4], [4, 3], [5, 2], 149 | [6, 1], [7, 7], [8, 9], [9, 8], [10, 15], [11, 14], 150 | [12, 13], [13, 12], [14, 11], [15, 10], [16, 18], [17, 17], 151 | [18, 16] 152 | ], 153 | 'COFW': [ 154 | [0, 0], [1, 7], [2, 6], [3, 5], [4, 8], [5, 3], 155 | [6, 2], [7, 1], [8, 4], [9, 9], [10, 12], [11, 11], 156 | [12, 10], [13, 19], [14, 18], [15, 17], [16, 20], [17, 15], 157 | [18, 14], [19, 13], [20, 16], [21, 23], [22, 22], [23, 21], 158 | [24, 24], [25, 25], [26, 26], [27, 28], [28, 27] 159 | ], 160 | 'WFLW': [ 161 | [0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27], 162 | [6, 26], [7, 25], [8, 24], [9, 23], [10, 22], [11, 21], 163 | [12, 20], [13, 19], [14, 18], [15, 17], [16, 16], [17, 15], 164 | [18, 14], [19, 13], [20, 12], [21, 11], [22, 10], [23, 9], 165 | [24, 8], [25, 7], [26, 6], [27, 5], [28, 4], [29, 3], 166 | [30, 2], [31, 1], [32, 0], [33, 46], [34, 45], [35, 44], 167 | [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], 168 | [42, 37], [43, 36], [44, 35], [45, 34], [46, 33], [47, 41], 169 | [48, 40], [49, 39], [50, 38], [51, 51], [52, 52], [53, 53], 170 | [54, 54], [55, 59], [56, 58], [57, 57], [58, 56], [59, 55], 171 | [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], 172 | [66, 74], [67, 73], [68, 64], [69, 63], [70, 62], [71, 61], 173 | [72, 60], [73, 67], [74, 66], [75, 65], [76, 82], [77, 81], 174 | [78, 80], [79, 79], [80, 78], [81, 77], [82, 76], [83, 87], 175 | [84, 86], [85, 85], [86, 84], [87, 83], [88, 92], [89, 91], 176 | [90, 90], [91, 89], [92, 88], [93, 95], [94, 94], [95, 93], 177 | [96, 97], [97, 96] 178 | ] 179 | } 180 | 181 | lo_eye_corner_index_x = {'300W': 72, 'AFLW': 20, 'COFW': 26, 'WFLW': 120} 182 | lo_eye_corner_index_y = {'300W': 73, 'AFLW': 21, 'COFW': 27, 'WFLW': 121} 183 | ro_eye_corner_index_x = {'300W': 90, 'AFLW': 30, 'COFW': 38, 'WFLW': 144} 184 | ro_eye_corner_index_y = {'300W': 91, 'AFLW': 31, 'COFW': 39, 'WFLW': 145} 185 | l_eye_center_index_x = {'300W': [72, 74, 76, 78, 80, 82], 'AFLW': 22, 'COFW': 54, 'WFLW': 192} 186 | l_eye_center_index_y = {'300W': [73, 75, 77, 79, 81, 83], 'AFLW': 23, 'COFW': 55, 'WFLW': 193} 187 | r_eye_center_index_x = {'300W': [84, 86, 88, 90, 92, 94], 'AFLW': 28, 'COFW': 56, 'WFLW': 194} 188 | r_eye_center_index_y = {'300W': [85, 87, 89, 91, 93, 95], 'AFLW': 29, 'COFW': 57, 'WFLW': 195} 189 | -------------------------------------------------------------------------------- /utils/train_eval_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | from collections import OrderedDict 4 | import numpy as np 5 | from sklearn.metrics import auc 6 | from utils import * 7 | 8 | 9 | def get_devices_list(arg): 10 | devices_list = [torch.device('cpu')] 11 | if arg.cuda and torch.cuda.is_available(): 12 | devices_list = [] 13 | for dev in arg.gpu_id.split(','): 14 | devices_list.append(torch.device('cuda:'+dev)) 15 | cudnn.benchmark = True 16 | cudnn.enabled = True 17 | return devices_list 18 | 19 | 20 | def load_weights(net, pth_file, device): 21 | state_dict = torch.load(pth_file, map_location=device) 22 | # create new OrderedDict that does not contain `module.` 23 | new_state_dict = OrderedDict() 24 | for k, v in state_dict.items(): 25 | head = k[:7] 26 | if head == 'module.': 27 | name = k[7:] # remove `module.` 28 | else: 29 | name = k 30 | new_state_dict[name] = v 31 | net.load_state_dict(new_state_dict) 32 | return net 33 | 34 | 35 | def create_model(arg, devices_list): 36 | from models import Estimator, Discrim, Regressor 37 | 38 | estimator = Estimator(stacks=arg.hour_stack, msg_pass=arg.msg_pass) 39 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2*kp_num[arg.dataset]) 40 | discrim = Discrim() if arg.GAN else None 41 | 42 | if arg.resume_epoch > 0: 43 | estimator = load_weights(estimator, arg.resume_folder + 'estimator_' + str(arg.resume_epoch) + '.pth', 44 | devices_list[0]) 45 | regressor = load_weights(regressor, arg.resume_folder + arg.dataset+'_regressor_' + 46 | str(arg.resume_epoch) + '.pth', devices_list[0]) 47 | discrim = load_weights(discrim, arg.resume_folder + 'discrim_' + str(arg.resume_epoch) + '.pth', 48 | devices_list[0]) if arg.GAN else None 49 | 50 | if arg.cuda: 51 | estimator = estimator.cuda(device=devices_list[0]) 52 | regressor = regressor.cuda(device=devices_list[0]) 53 | discrim = discrim.cuda(device=devices_list[0]) if arg.GAN else None 54 | 55 | return estimator, regressor, discrim 56 | 57 | 58 | def calc_d_fake(dataset, pred_coords, gt_coords, bcsize, bcsize_set): 59 | error_regressor = (pred_coords - gt_coords) ** 2 60 | dist_regressor = torch.zeros(bcsize, kp_num[dataset]) 61 | dfake = torch.zeros(bcsize_set, boundary_num) 62 | for batch in range(bcsize): 63 | dist_regressor[batch, :] = \ 64 | (error_regressor[batch][:2*kp_num[dataset]:2] + error_regressor[batch][1:2*kp_num[dataset]:2]) \ 65 | < args.theta*args.theta 66 | for batch_index in range(bcsize): 67 | for boundary_index in range(boundary_num): 68 | for kp_index in range( 69 | point_range[dataset][boundary_index][0], 70 | point_range[dataset][boundary_index][1] 71 | ): 72 | if dist_regressor[batch_index][kp_index] == 1: 73 | dfake[batch_index][boundary_index] += 1 74 | if boundary_keys[boundary_index] in boundary_special.keys() and \ 75 | dataset in boundary_special[boundary_keys[boundary_index]] and \ 76 | dist_regressor[batch_index][duplicate_point[dataset][boundary_keys[boundary_index]]] == 1: 77 | dfake[batch_index][boundary_index] += 1 78 | for boundary_index in range(boundary_num): 79 | if dfake[batch_index][boundary_index] / point_num_per_boundary[dataset][boundary_index] < args.delta: 80 | dfake[batch_index][boundary_index] = 0. 81 | else: 82 | dfake[batch_index][boundary_index] = 1. 83 | if bcsize < bcsize_set: 84 | for batch_index in range(bcsize, bcsize_set): 85 | dfake[batch_index] = dfake[batch_index - bcsize] 86 | return dfake 87 | 88 | 89 | def calc_normalize_factor(dataset, gt_coords_xy, normalize_way='inter_pupil'): 90 | if normalize_way == 'inter_ocular': 91 | error_normalize_factor = np.sqrt( 92 | (gt_coords_xy[0][lo_eye_corner_index_x[dataset]] - gt_coords_xy[0][ro_eye_corner_index_x[dataset]]) * 93 | (gt_coords_xy[0][lo_eye_corner_index_x[dataset]] - gt_coords_xy[0][ro_eye_corner_index_x[dataset]]) + 94 | (gt_coords_xy[0][lo_eye_corner_index_y[dataset]] - gt_coords_xy[0][ro_eye_corner_index_y[dataset]]) * 95 | (gt_coords_xy[0][lo_eye_corner_index_y[dataset]] - gt_coords_xy[0][ro_eye_corner_index_y[dataset]])) 96 | return error_normalize_factor 97 | elif normalize_way == 'inter_pupil': 98 | if l_eye_center_index_x[dataset].__class__ != list: 99 | error_normalize_factor = np.sqrt( 100 | (gt_coords_xy[0][l_eye_center_index_x[dataset]] - gt_coords_xy[0][r_eye_center_index_x[dataset]]) * 101 | (gt_coords_xy[0][l_eye_center_index_x[dataset]] - gt_coords_xy[0][r_eye_center_index_x[dataset]]) + 102 | (gt_coords_xy[0][l_eye_center_index_y[dataset]] - gt_coords_xy[0][r_eye_center_index_y[dataset]]) * 103 | (gt_coords_xy[0][l_eye_center_index_y[dataset]] - gt_coords_xy[0][r_eye_center_index_y[dataset]])) 104 | return error_normalize_factor 105 | else: 106 | length = len(l_eye_center_index_x[dataset]) 107 | l_eye_x_avg, l_eye_y_avg, r_eye_x_avg, r_eye_y_avg = 0., 0., 0., 0. 108 | for i in range(length): 109 | l_eye_x_avg += gt_coords_xy[0][l_eye_center_index_x[dataset][i]] 110 | l_eye_y_avg += gt_coords_xy[0][l_eye_center_index_y[dataset][i]] 111 | r_eye_x_avg += gt_coords_xy[0][r_eye_center_index_x[dataset][i]] 112 | r_eye_y_avg += gt_coords_xy[0][r_eye_center_index_y[dataset][i]] 113 | l_eye_x_avg /= length 114 | l_eye_y_avg /= length 115 | r_eye_x_avg /= length 116 | r_eye_y_avg /= length 117 | error_normalize_factor = np.sqrt((l_eye_x_avg - r_eye_x_avg) * (l_eye_x_avg - r_eye_x_avg) + 118 | (l_eye_y_avg - r_eye_y_avg) * (l_eye_y_avg - r_eye_y_avg)) 119 | return error_normalize_factor 120 | 121 | 122 | def inverse_affine(arg, pred_coords, bbox): 123 | import copy 124 | pred_coords = copy.deepcopy(pred_coords) 125 | for i in range(kp_num[arg.dataset]): 126 | pred_coords[2 * i] = bbox[0] + pred_coords[2 * i]/(arg.crop_size-1)*(bbox[2] - bbox[0]) 127 | pred_coords[2 * i + 1] = bbox[1] + pred_coords[2 * i + 1]/(arg.crop_size-1)*(bbox[3] - bbox[1]) 128 | return pred_coords 129 | 130 | 131 | def calc_error_rate_i(dataset, pred_coords, gt_coords_xy, error_normalize_factor): 132 | temp, error = (pred_coords - gt_coords_xy)**2, 0. 133 | for i in range(kp_num[dataset]): 134 | error += np.sqrt(temp[2*i] + temp[2*i+1]) 135 | return error/kp_num[dataset]/error_normalize_factor 136 | 137 | 138 | def calc_auc(dataset, split, error_rate, max_threshold): 139 | error_rate = np.array(error_rate) 140 | threshold = np.linspace(0, max_threshold, num=2000) 141 | accuracys = np.zeros(threshold.shape) 142 | for i in range(threshold.size): 143 | accuracys[i] = np.sum(error_rate < threshold[i]) * 1.0 / dataset_size[dataset][split] 144 | return auc(threshold, accuracys) / max_threshold, accuracys 145 | -------------------------------------------------------------------------------- /utils/visual.py: -------------------------------------------------------------------------------- 1 | from .dataset_info import * 2 | 3 | import cv2 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import matplotlib.pyplot as plt 7 | from scipy.interpolate import spline 8 | 9 | 10 | def show_img(pic, name='pic', x=0, y=0, wait=0): 11 | cv2.imshow(name, pic) 12 | cv2.moveWindow(name, x, y) 13 | cv2.waitKey(wait) 14 | cv2.destroyAllWindows() 15 | 16 | 17 | def watch_gray_heatmap(gt_heatmap): 18 | heatmap_sum = gt_heatmap[0] 19 | for index in range(boundary_num - 1): 20 | heatmap_sum += gt_heatmap[index + 1] 21 | show_img(heatmap_sum, 'heatmap_sum') 22 | 23 | 24 | def watch_pic_kp(dataset, pic, kp): 25 | for kp_index in range(kp_num[dataset]): 26 | cv2.circle( 27 | pic, 28 | (int(kp[2*kp_index]), int(kp[2*kp_index+1])), 29 | 1, 30 | (0, 0, 255) 31 | ) 32 | show_img(pic) 33 | 34 | 35 | def watch_pic_kp_xy(dataset, pic, coord_x, coord_y): 36 | for kp_index in range(kp_num[dataset]): 37 | cv2.circle( 38 | pic, 39 | (int(coord_x[kp_index]), int(coord_y[kp_index])), 40 | 1, 41 | (0, 0, 255) 42 | ) 43 | show_img(pic) 44 | 45 | 46 | def eval_heatmap(arg, heatmaps, img_name, bbox, save_img=False): 47 | heatmaps = F.interpolate(heatmaps, scale_factor=4, mode='bilinear', align_corners=True) 48 | heatmaps = heatmaps.squeeze(0).detach().cpu().numpy() 49 | heatmaps_sum = heatmaps[0] 50 | for heatmaps_index in range(boundary_num-1): 51 | heatmaps_sum += heatmaps[heatmaps_index+1] 52 | plt.axis('off') 53 | plt.imshow(heatmaps_sum, interpolation='nearest', vmax=1., vmin=0.) 54 | if save_img: 55 | import os 56 | if not os.path.exists('./imgs'): 57 | os.mkdir('./imgs') 58 | fig = plt.gcf() 59 | fig.set_size_inches(2.56 / 3, 2.56 / 3) 60 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 61 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 62 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 63 | plt.margins(0, 0) 64 | name = (img_name[0]).split('/')[-1] 65 | fig.savefig('./imgs/'+name.split('.')[0]+'_hm.png', format='png', transparent=True, dpi=300, pad_inches=0) 66 | 67 | pic = cv2.imread(dataset_route[arg.dataset] + img_name[0]) 68 | position_before = np.float32([ 69 | [int(bbox[0]), int(bbox[1])], 70 | [int(bbox[0]), int(bbox[3])], 71 | [int(bbox[2]), int(bbox[3])] 72 | ]) 73 | position_after = np.float32([ 74 | [0, 0], 75 | [0, arg.crop_size - 1], 76 | [arg.crop_size - 1, arg.crop_size - 1] 77 | ]) 78 | crop_matrix = cv2.getAffineTransform(position_before, position_after) 79 | pic = cv2.warpAffine(pic, crop_matrix, (arg.crop_size, arg.crop_size)) 80 | cv2.imwrite('./imgs/' + name.split('.')[0] + '_init.png', pic) 81 | hm = cv2.imread('./imgs/'+name.split('.')[0]+'_hm.png') 82 | syn = cv2.addWeighted(pic, 0.4, hm, 0.6, 0) 83 | cv2.imwrite('./imgs/'+name.split('.')[0]+'_syn.png', syn) 84 | plt.show() 85 | 86 | 87 | def eval_pred_points(arg, pred_coords, img_name, bbox, save_img=False): 88 | pic = cv2.imread(dataset_route[arg.dataset] + img_name[0]) 89 | position_before = np.float32([ 90 | [int(bbox[0]), int(bbox[1])], 91 | [int(bbox[0]), int(bbox[3])], 92 | [int(bbox[2]), int(bbox[3])] 93 | ]) 94 | position_after = np.float32([ 95 | [0, 0], 96 | [0, arg.crop_size - 1], 97 | [arg.crop_size - 1, arg.crop_size - 1] 98 | ]) 99 | crop_matrix = cv2.getAffineTransform(position_before, position_after) 100 | pic = cv2.warpAffine(pic, crop_matrix, (arg.crop_size, arg.crop_size)) 101 | 102 | for coord_index in range(kp_num[arg.dataset]): 103 | cv2.circle(pic, (int(pred_coords[2 * coord_index]), int(pred_coords[2 * coord_index + 1])), 2, (0, 0, 255)) 104 | if save_img: 105 | import os 106 | if not os.path.exists('./imgs'): 107 | os.mkdir('./imgs') 108 | name = (img_name[0]).split('/')[-1] 109 | cv2.imwrite('./imgs/'+name.split('.')[0]+'_lmk.png', pic) 110 | show_img(pic) 111 | 112 | 113 | def eval_gt_pred_points(arg, gt_coords, pred_coords, img_name, bbox, save_img=False): 114 | pic = cv2.imread(dataset_route[arg.dataset] + img_name[0]) 115 | position_before = np.float32([ 116 | [int(bbox[0]), int(bbox[1])], 117 | [int(bbox[0]), int(bbox[3])], 118 | [int(bbox[2]), int(bbox[3])] 119 | ]) 120 | position_after = np.float32([ 121 | [0, 0], 122 | [0, arg.crop_size - 1], 123 | [arg.crop_size - 1, arg.crop_size - 1] 124 | ]) 125 | crop_matrix = cv2.getAffineTransform(position_before, position_after) 126 | pic = cv2.warpAffine(pic, crop_matrix, (arg.crop_size, arg.crop_size)) 127 | 128 | for coord_index in range(kp_num[arg.dataset]): 129 | cv2.circle(pic, (int(pred_coords[2 * coord_index]), int(pred_coords[2 * coord_index + 1])), 2, (0, 0, 255)) 130 | cv2.circle(pic, (int(gt_coords[2 * coord_index]), int(gt_coords[2 * coord_index + 1])), 2, (0, 255, 0)) 131 | if save_img: 132 | import os 133 | if not os.path.exists('./imgs'): 134 | os.mkdir('./imgs') 135 | name = (img_name[0]).split('/')[-1] 136 | cv2.imwrite('./imgs/'+name.split('.')[0]+'_lmk.png', pic) 137 | show_img(pic) 138 | 139 | 140 | def eval_CED(auc_record): 141 | error = np.linspace(0., 0.1, 21) 142 | error_new = np.linspace(error.min(), error.max(), 300) 143 | auc_value = np.array([auc_record[0], auc_record[99], auc_record[199], auc_record[299], 144 | auc_record[399], auc_record[499], auc_record[599], auc_record[699], 145 | auc_record[799], auc_record[899], auc_record[999], auc_record[1099], 146 | auc_record[1199], auc_record[1299], auc_record[1399], auc_record[1499], 147 | auc_record[1599], auc_record[1699], auc_record[1799], auc_record[1899], 148 | auc_record[1999]]) 149 | CFSS_auc_value = np.array([0., 0., 0., 0., 0., 150 | 0., 0.02, 0.09, 0.18, 0.30, 151 | 0.45, 0.60, 0.70, 0.75, 0.79, 152 | 0.82, 0.85, 0.87, 0.88, 0.89, 0.90]) 153 | SAPM_auc_value = np.array([0., 0., 0., 0., 0., 154 | 0., 0., 0., 0.02, 0.08, 155 | 0.17, 0.28, 0.43, 0.58, 0.71, 156 | 0.78, 0.83, 0.86, 0.89, 0.91, 0.92]) 157 | TCDCN_auc_value = np.array([0., 0., 0., 0., 0., 158 | 0., 0., 0.02, 0.05, 0.10, 159 | 0.19, 0.29, 0.38, 0.47, 0.56, 160 | 0.64, 0.70, 0.75, 0.79, 0.82, 0.826]) 161 | auc_smooth = spline(error, auc_value, error_new) 162 | CFSS_auc_smooth = spline(error, CFSS_auc_value, error_new) 163 | SAPM_auc_smooth = spline(error, SAPM_auc_value, error_new) 164 | TCDCN_auc_smooth = spline(error, TCDCN_auc_value, error_new) 165 | plt.plot(error_new, auc_smooth, 'r-') 166 | plt.plot(error_new, CFSS_auc_smooth, 'g-') 167 | plt.plot(error_new, SAPM_auc_smooth, 'y-') 168 | plt.plot(error_new, TCDCN_auc_smooth, 'm-') 169 | plt.legend(['LAB, Error: 5.35%, Failure: 4.73%', 170 | 'CFSS, Error: 6.28%, Failure: 9.07%', 171 | 'SAPM, Error: 6.64%, Failure: 5.72%', 172 | 'TCDCN, Error: 7.66%, Failure: 16.17%'], loc=4) 173 | plt.plot(error, auc_value, 'rs') 174 | plt.plot(error, CFSS_auc_value, 'go') 175 | plt.plot(error, SAPM_auc_value, 'y^') 176 | plt.plot(error, TCDCN_auc_value, 'mx') 177 | plt.axis([0., 0.1, 0., 1.]) 178 | plt.show() 179 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from models import WingLoss, Estimator, Regressor, Discrim 5 | from dataset import GeneralDataset 6 | from utils import * 7 | import tqdm 8 | 9 | if not os.path.exists(args.save_folder): 10 | os.mkdir(args.save_folder) 11 | if not os.path.exists(args.resume_folder): 12 | os.mkdir(args.resume_folder) 13 | 14 | 15 | def train(arg): 16 | epoch = None 17 | devices = get_devices_list(arg) 18 | 19 | print('***** Normal Training *****') 20 | print('Training parameters:\n' + 21 | '# Dataset: ' + arg.dataset + '\n' + 22 | '# Dataset split: ' + arg.split + '\n' + 23 | '# Batchsize: ' + str(arg.batch_size) + '\n' + 24 | '# Num workers: ' + str(arg.workers) + '\n' + 25 | '# PDB: ' + str(arg.PDB) + '\n' + 26 | '# Use GPU: ' + str(arg.cuda) + '\n' + 27 | '# Start lr: ' + str(arg.lr) + '\n' + 28 | '# Max epoch: ' + str(arg.max_epoch) + '\n' + 29 | '# Loss type: ' + arg.loss_type + '\n' + 30 | '# Resumed model: ' + str(arg.resume_epoch > 0)) 31 | if arg.resume_epoch > 0: 32 | print('# Resumed epoch: ' + str(arg.resume_epoch)) 33 | 34 | print('Creating networks ...') 35 | estimator, regressor, discrim = create_model(arg, devices) 36 | estimator.train() 37 | regressor.train() 38 | if discrim is not None: 39 | discrim.train() 40 | print('Creating networks done!') 41 | 42 | optimizer_estimator = torch.optim.SGD(estimator.parameters(), lr=arg.lr, momentum=arg.momentum, 43 | weight_decay=arg.weight_decay) 44 | optimizer_regressor = torch.optim.SGD(regressor.parameters(), lr=arg.lr, momentum=arg.momentum, 45 | weight_decay=arg.weight_decay) 46 | optimizer_discrim = torch.optim.SGD(discrim.parameters(), lr=arg.lr, momentum=arg.momentum, 47 | weight_decay=arg.weight_decay) if discrim is not None else None 48 | 49 | if arg.loss_type == 'L2': 50 | criterion = nn.MSELoss() 51 | elif arg.loss_type == 'L1': 52 | criterion = nn.L1Loss() 53 | elif arg.loss_type == 'smoothL1': 54 | criterion = nn.SmoothL1Loss() 55 | else: 56 | criterion = WingLoss(w=arg.wingloss_w, epsilon=arg.wingloss_e) 57 | 58 | print('Loading dataset ...') 59 | trainset = GeneralDataset(dataset=arg.dataset) 60 | print('Loading dataset done!') 61 | 62 | d_fake = (torch.zeros(arg.batch_size, 13)).cuda(device=devices[0]) if arg.GAN \ 63 | else torch.zeros(arg.batch_size, 13) 64 | 65 | # evolving training 66 | print('Start training ...') 67 | for epoch in range(arg.resume_epoch, arg.max_epoch): 68 | forward_times_per_epoch, sum_loss_estimator, sum_loss_regressor = 0, 0., 0. 69 | dataloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=arg.shuffle, 70 | num_workers=arg.workers, pin_memory=True) 71 | 72 | if epoch in arg.step_values: 73 | optimizer_estimator.param_groups[0]['lr'] *= arg.gamma 74 | optimizer_regressor.param_groups[0]['lr'] *= arg.gamma 75 | optimizer_discrim.param_groups[0]['lr'] *= arg.gamma 76 | 77 | for data in tqdm.tqdm(dataloader): 78 | forward_times_per_epoch += 1 79 | input_images, gt_coords_xy, gt_heatmap, _, _, _ = data 80 | true_batchsize = input_images.size()[0] 81 | input_images = input_images.unsqueeze(1) 82 | input_images = input_images.cuda(device=devices[0]) 83 | gt_coords_xy = gt_coords_xy.cuda(device=devices[0]) 84 | gt_heatmap = gt_heatmap.cuda(device=devices[0]) 85 | 86 | optimizer_estimator.zero_grad() 87 | heatmaps = estimator(input_images) 88 | loss_G = estimator.calc_loss(heatmaps, gt_heatmap) 89 | loss_A = torch.mean(torch.log2(1. - discrim(heatmaps[-1]))) 90 | loss_estimator = loss_G + loss_A 91 | loss_estimator.backward() 92 | optimizer_estimator.step() 93 | 94 | sum_loss_estimator += loss_estimator 95 | 96 | optimizer_discrim.zero_grad() 97 | loss_D_real = -torch.mean(torch.log2(discrim(gt_heatmap))) 98 | loss_D_fake = -torch.mean(torch.log2(1.-torch.abs(discrim(heatmaps[-1].detach()) - 99 | d_fake[:true_batchsize]))) 100 | loss_D = loss_D_real + loss_D_fake 101 | loss_D.backward() 102 | optimizer_discrim.step() 103 | 104 | optimizer_regressor.zero_grad() 105 | out = regressor(input_images, heatmaps[-1].detach()) 106 | loss_regressor = criterion(out, gt_coords_xy) 107 | loss_regressor.backward() 108 | optimizer_regressor.step() 109 | 110 | d_fake = (calc_d_fake(arg.dataset, out.detach(), gt_coords_xy, true_batchsize, 111 | arg.batch_size)).cuda(device=devices[0]) 112 | 113 | sum_loss_regressor += loss_regressor 114 | 115 | if (epoch+1) % arg.save_interval == 0: 116 | torch.save(estimator.state_dict(), arg.save_folder + 'estimator_'+str(epoch+1)+'.pth') 117 | torch.save(discrim.state_dict(), arg.save_folder + 'discrim_'+str(epoch+1)+'.pth') 118 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset+'_regressor_'+str(epoch+1)+'.pth') 119 | 120 | print('\nepoch: {:0>4d} | loss_estimator: {:.2f} | loss_regressor: {:.2f}'.format( 121 | epoch, 122 | sum_loss_estimator.item()/forward_times_per_epoch, 123 | sum_loss_regressor.item()/forward_times_per_epoch 124 | )) 125 | 126 | torch.save(estimator.state_dict(), arg.save_folder + 'estimator_'+str(epoch+1)+'.pth') 127 | torch.save(discrim.state_dict(), arg.save_folder + 'discrim_'+str(epoch+1)+'.pth') 128 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset+'_regressor_'+str(epoch+1)+'.pth') 129 | print('Training done!') 130 | 131 | 132 | def train_with_gt_heatmap(arg): 133 | epoch = None 134 | devices = get_devices_list(arg) 135 | 136 | print('***** Training with ground truth heatmap *****') 137 | print('Training parameters:\n' + 138 | '# Dataset: ' + arg.dataset + '\n' + 139 | '# Dataset split: ' + arg.split + '\n' + 140 | '# Batchsize: ' + str(arg.batch_size) + '\n' + 141 | '# Num workers: ' + str(arg.workers) + '\n' + 142 | '# PDB: ' + str(arg.PDB) + '\n' + 143 | '# Use GPU: ' + str(arg.cuda) + '\n' + 144 | '# Start lr: ' + str(arg.lr) + '\n' + 145 | '# Lr step values: ' + str(arg.step_values) + '\n' + 146 | '# Lr step gamma: ' + str(arg.gamma) + '\n' + 147 | '# Max epoch: ' + str(arg.max_epoch) + '\n' + 148 | '# Loss type: ' + arg.loss_type + '\n' + 149 | '# Resumed model: ' + str(arg.resume_epoch > 0)) 150 | if arg.resume_epoch > 0: 151 | print('# Resumed epoch: ' + str(arg.resume_epoch)) 152 | 153 | print('Creating networks ...') 154 | regressor = Regressor(fuse_stages=arg.fuse_stage, output=2 * kp_num[arg.dataset]) 155 | regressor = load_weights(regressor, arg.resume_folder + arg.dataset + '_regressor_' + 156 | str(arg.resume_epoch) + '.pth', devices_list[0]) if arg.resume_epoch > 0 else regressor 157 | regressor = regressor.cuda(device=devices[0]) 158 | regressor.train() 159 | print('Creating networks done!') 160 | 161 | optimizer_regressor = torch.optim.SGD(regressor.parameters(), lr=arg.lr, momentum=arg.momentum, 162 | weight_decay=arg.weight_decay) 163 | 164 | if arg.loss_type == 'L2': 165 | criterion = nn.MSELoss() 166 | elif arg.loss_type == 'L1': 167 | criterion = nn.L1Loss() 168 | elif arg.loss_type == 'smoothL1': 169 | criterion = nn.SmoothL1Loss() 170 | else: 171 | criterion = WingLoss(w=arg.wingloss_w, epsilon=arg.wingloss_e) 172 | 173 | print('Loading dataset ...') 174 | trainset = GeneralDataset(dataset=arg.dataset) 175 | print('Loading dataset done!') 176 | 177 | print('Start training ...') 178 | for epoch in range(arg.resume_epoch, arg.max_epoch): 179 | forward_times_per_epoch, sum_loss_regressor = 0, 0. 180 | dataloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=arg.shuffle, 181 | num_workers=arg.workers, pin_memory=True) 182 | 183 | if epoch in arg.step_values: 184 | optimizer_regressor.param_groups[0]['lr'] *= arg.gamma 185 | 186 | for data in tqdm.tqdm(dataloader): 187 | forward_times_per_epoch += 1 188 | input_images, gt_coords_xy, gt_heatmap, _, _, _ = data 189 | input_images = input_images.unsqueeze(1) 190 | input_images = input_images.cuda(device=devices[0]) 191 | gt_coords_xy = gt_coords_xy.cuda(device=devices[0]) 192 | gt_heatmap = gt_heatmap.cuda(device=devices[0]) 193 | 194 | optimizer_regressor.zero_grad() 195 | out = regressor(input_images, gt_heatmap) 196 | loss_regressor = criterion(out, gt_coords_xy) 197 | loss_regressor.backward() 198 | optimizer_regressor.step() 199 | 200 | sum_loss_regressor += loss_regressor 201 | 202 | if (epoch + 1) % arg.save_interval == 0: 203 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset + '_regressor_' + str(epoch + 1) + '.pth') 204 | 205 | print('\nepoch: {:0>4d} | loss_regressor: {:.2f}'.format( 206 | epoch, 207 | sum_loss_regressor.item() / forward_times_per_epoch 208 | )) 209 | 210 | torch.save(regressor.state_dict(), arg.save_folder + arg.dataset + '_regressor_' + str(epoch + 1) + '.pth') 211 | print('Training done!') 212 | 213 | 214 | if __name__ == '__main__': 215 | train(args) 216 | -------------------------------------------------------------------------------- /utils/dataload.py: -------------------------------------------------------------------------------- 1 | from .dataset_info import * 2 | from .args import args 3 | from .pdb import pdb 4 | from .visual import * 5 | 6 | import cv2 7 | import time 8 | import random 9 | import numpy as np 10 | from scipy.interpolate import splprep, splev 11 | 12 | 13 | def get_annotations_list(dataset, split, ispdb=False): 14 | annotations = [] 15 | annotation_file = open(dataset_route[dataset] + dataset + '_' + split + '_annos.txt') 16 | 17 | for line in range(dataset_size[dataset][split]): 18 | annotations.append(annotation_file.readline().rstrip().split()) 19 | annotation_file.close() 20 | 21 | if ispdb: 22 | annos = [] 23 | allshapes = np.zeros((2 * kp_num[dataset], len(annotations))) 24 | for line_index, line in enumerate(annotations): 25 | coord_x = np.array(list(map(float, line[:2*kp_num[dataset]:2]))) 26 | coord_y = np.array(list(map(float, line[1:2*kp_num[dataset]:2]))) 27 | position_before = np.float32([[int(line[-7]), int(line[-6])], 28 | [int(line[-7]), int(line[-4])], 29 | [int(line[-5]), int(line[-4])]]) 30 | position_after = np.float32([[0, 0], 31 | [0, args.crop_size - 1], 32 | [args.crop_size - 1, args.crop_size - 1]]) 33 | crop_matrix = cv2.getAffineTransform(position_before, position_after) 34 | coord_x_after_crop = crop_matrix[0][0] * coord_x + crop_matrix[0][1] * coord_y + crop_matrix[0][2] 35 | coord_y_after_crop = crop_matrix[1][0] * coord_x + crop_matrix[1][1] * coord_y + crop_matrix[1][2] 36 | allshapes[0:kp_num[dataset], line_index] = list(coord_x_after_crop) 37 | allshapes[kp_num[dataset]:2*kp_num[dataset], line_index] = list(coord_y_after_crop) 38 | newidx = pdb(dataset, allshapes, dataset_pdb_numbins[dataset]) 39 | for id_index in newidx: 40 | annos.append(annotations[int(id_index)]) 41 | return annos 42 | 43 | return annotations 44 | 45 | 46 | def convert_img_to_gray(img): 47 | if img.shape[2] == 1: 48 | return img 49 | elif img.shape[2] == 4: 50 | gray = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY) 51 | return gray 52 | elif img.shape[2] == 3: 53 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 54 | return gray 55 | else: 56 | raise Exception("img shape wrong!\n") 57 | 58 | 59 | def get_random_transform_param(split, bbox): 60 | translation, trans_dir, rotation, scaling, flip, gaussian_blur = 0, 0, 0, 1., 0, 0 61 | if split in ['train']: 62 | random.seed(time.time()) 63 | translate_param = int(args.trans_ratio * abs(bbox[2] - bbox[0])) 64 | translation = random.randint(-translate_param, translate_param) 65 | trans_dir = random.randint(0, 3) # LU:0 RU:1 LL:2 RL:3 66 | rotation = random.uniform(-args.rotate_limit, args.rotate_limit) 67 | scaling = random.uniform(1-args.scale_ratio, 1+args.scale_ratio) 68 | flip = random.randint(0, 1) 69 | gaussian_blur = random.randint(0, 1) 70 | return translation, trans_dir, rotation, scaling, flip, gaussian_blur 71 | 72 | 73 | def further_transform(pic, bbox, flip, gaussian_blur): 74 | if flip == 1: 75 | pic = cv2.flip(pic, 1) 76 | if abs(bbox[2] - bbox[0]) < 120 or gaussian_blur == 0: 77 | return pic 78 | else: 79 | return cv2.GaussianBlur(pic, (5, 5), 1) 80 | 81 | 82 | def get_affine_matrix(crop_size, rotation, scaling): 83 | center = (crop_size / 2.0, crop_size / 2.0) 84 | return cv2.getRotationMatrix2D(center, rotation, scaling) 85 | 86 | 87 | def pic_normalize(pic): # for accelerate, now support gray pic only 88 | pic = np.float32(pic) 89 | mean, std = cv2.meanStdDev(pic) 90 | pic_channel = 1 if len(pic.shape) == 2 else 3 91 | for channel in range(0, pic_channel): 92 | if std[channel][0] < 1e-6: 93 | std[channel][0] = 1 94 | pic = (pic - mean) / std 95 | return np.float32(pic) 96 | 97 | 98 | def get_cropped_coords(dataset, crop_matrix, coord_x, coord_y, flip=0): 99 | coord_x, coord_y = np.array(coord_x), np.array(coord_y) 100 | temp_x = crop_matrix[0][0] * coord_x + crop_matrix[0][1] * coord_y + crop_matrix[0][2] if flip == 0 else \ 101 | float(args.crop_size) - (crop_matrix[0][0] * coord_x + crop_matrix[0][1] * coord_y + crop_matrix[0][2]) - 1 102 | temp_y = crop_matrix[1][0] * coord_x + crop_matrix[1][1] * coord_y + crop_matrix[1][2] 103 | if flip: 104 | temp_x = temp_x[np.array(flip_relation[dataset])[:, 1]] 105 | temp_y = temp_y[np.array(flip_relation[dataset])[:, 1]] 106 | return temp_x, temp_y 107 | 108 | 109 | def get_gt_coords(dataset, affine_matrix, coord_x, coord_y): 110 | out = np.zeros(2*kp_num[dataset]) 111 | out[:2*kp_num[dataset]:2] = affine_matrix[0][0] * coord_x + affine_matrix[0][1] * coord_y + affine_matrix[0][2] 112 | out[1:2*kp_num[dataset]:2] = affine_matrix[1][0] * coord_x + affine_matrix[1][1] * coord_y + affine_matrix[1][2] 113 | return np.array(np.float32(out)) 114 | 115 | 116 | def get_gt_heatmap(dataset, gt_coords): 117 | coord_x, coord_y, gt_heatmap = [], [], [] 118 | for index in range(boundary_num): 119 | gt_heatmap.append(np.ones((64, 64))) 120 | gt_heatmap[index].tolist() 121 | boundary_x = {'chin': [], 'leb': [], 'reb': [], 'bon': [], 'breath': [], 'lue': [], 'lle': [], 122 | 'rue': [], 'rle': [], 'usul': [], 'lsul': [], 'usll': [], 'lsll': []} 123 | boundary_y = {'chin': [], 'leb': [], 'reb': [], 'bon': [], 'breath': [], 'lue': [], 'lle': [], 124 | 'rue': [], 'rle': [], 'usul': [], 'lsul': [], 'usll': [], 'lsll': []} 125 | points = {'chin': [], 'leb': [], 'reb': [], 'bon': [], 'breath': [], 'lue': [], 'lle': [], 126 | 'rue': [], 'rle': [], 'usul': [], 'lsul': [], 'usll': [], 'lsll': []} 127 | resize_matrix = cv2.getAffineTransform(np.float32([[0, 0], [0, args.crop_size-1], 128 | [args.crop_size-1, args.crop_size-1]]), 129 | np.float32([[0, 0], [0, heatmap_size-1], 130 | [heatmap_size-1, heatmap_size-1]])) 131 | for kp_index in range(kp_num[dataset]): 132 | coord_x.append( 133 | resize_matrix[0][0] * gt_coords[2 * kp_index] + 134 | resize_matrix[0][1] * gt_coords[2 * kp_index + 1] + 135 | resize_matrix[0][2] + random.uniform(-0.2, 0.2) 136 | ) 137 | coord_y.append( 138 | resize_matrix[1][0] * gt_coords[2 * kp_index] + 139 | resize_matrix[1][1] * gt_coords[2 * kp_index + 1] + 140 | resize_matrix[1][2] + random.uniform(-0.2, 0.2) 141 | ) 142 | for boundary_index in range(boundary_num): 143 | for kp_index in range( 144 | point_range[dataset][boundary_index][0], 145 | point_range[dataset][boundary_index][1] 146 | ): 147 | boundary_x[boundary_keys[boundary_index]].append(coord_x[kp_index]) 148 | boundary_y[boundary_keys[boundary_index]].append(coord_y[kp_index]) 149 | if boundary_keys[boundary_index] in boundary_special.keys() and\ 150 | dataset in boundary_special[boundary_keys[boundary_index]]: 151 | boundary_x[boundary_keys[boundary_index]].append( 152 | coord_x[duplicate_point[dataset][boundary_keys[boundary_index]]]) 153 | boundary_y[boundary_keys[boundary_index]].append( 154 | coord_y[duplicate_point[dataset][boundary_keys[boundary_index]]]) 155 | for k_index, k in enumerate(boundary_keys): 156 | if point_num_per_boundary[dataset][k_index] >= 2.: 157 | if len(boundary_x[k]) == len(set(boundary_x[k])) or len(boundary_y[k]) == len(set(boundary_y[k])): 158 | points[k].append(boundary_x[k]) 159 | points[k].append(boundary_y[k]) 160 | res = splprep(points[k], s=0.0, k=1) 161 | u_new = np.linspace(res[1].min(), res[1].max(), interp_points_num[k]) 162 | boundary_x[k], boundary_y[k] = splev(u_new, res[0], der=0) 163 | for index, k in enumerate(boundary_keys): 164 | if point_num_per_boundary[dataset][index] >= 2.: 165 | for i in range(len(boundary_x[k]) - 1): 166 | cv2.line(gt_heatmap[index], (int(boundary_x[k][i]), int(boundary_y[k][i])), 167 | (int(boundary_x[k][i+1]), int(boundary_y[k][i+1])), 0) 168 | else: 169 | cv2.circle(gt_heatmap[index], (int(boundary_x[k][0]), int(boundary_y[k][0])), 2, 0, -1) 170 | gt_heatmap[index] = np.uint8(gt_heatmap[index]) 171 | gt_heatmap[index] = cv2.distanceTransform(gt_heatmap[index], cv2.DIST_L2, 5) 172 | gt_heatmap[index] = np.float32(np.array(gt_heatmap[index])) 173 | gt_heatmap[index] = gt_heatmap[index].reshape(64*64) 174 | (gt_heatmap[index])[(gt_heatmap[index]) < 3. * args.sigma] = \ 175 | np.exp(-(gt_heatmap[index])[(gt_heatmap[index]) < 3 * args.sigma] * 176 | (gt_heatmap[index])[(gt_heatmap[index]) < 3 * args.sigma] / 2. * args.sigma * args.sigma) 177 | (gt_heatmap[index])[(gt_heatmap[index]) >= 3. * args.sigma] = 0. 178 | gt_heatmap[index] = gt_heatmap[index].reshape([64, 64]) 179 | return np.array(gt_heatmap) 180 | 181 | 182 | def get_item_from(dataset, split, annotation): 183 | pic = cv2.imread(dataset_route[dataset]+annotation[-1]) 184 | pic = convert_img_to_gray(pic) if not args.RGB else pic 185 | coord_x = list(map(float, annotation[:2*kp_num[dataset]:2])) 186 | coord_y = list(map(float, annotation[1:2*kp_num[dataset]:2])) 187 | coord_xy = np.array(np.float32(list(map(float, annotation[:2*kp_num[dataset]])))) 188 | bbox = np.array(list(map(int, annotation[-7:-3]))) 189 | 190 | translation, trans_dir, rotation, scaling, flip, gaussian_blur = get_random_transform_param(split, bbox) 191 | 192 | position_before = np.float32([[int(bbox[0]) + pow(-1, trans_dir+1)*translation, 193 | int(bbox[1]) + pow(-1, trans_dir//2+1)*translation], 194 | [int(bbox[0]) + pow(-1, trans_dir+1)*translation, 195 | int(bbox[3]) + pow(-1, trans_dir//2+1)*translation], 196 | [int(bbox[2]) + pow(-1, trans_dir+1)*translation, 197 | int(bbox[3]) + pow(-1, trans_dir//2+1)*translation]]) 198 | position_after = np.float32([[0, 0], 199 | [0, args.crop_size - 1], 200 | [args.crop_size - 1, args.crop_size - 1]]) 201 | crop_matrix = cv2.getAffineTransform(position_before, position_after) 202 | pic_crop = cv2.warpAffine(pic, crop_matrix, (args.crop_size, args.crop_size)) 203 | pic_crop = further_transform(pic_crop, bbox, flip, gaussian_blur) if args.split in ['train'] else pic_crop 204 | affine_matrix = get_affine_matrix(args.crop_size, rotation, scaling) 205 | pic_affine = cv2.warpAffine(pic_crop, affine_matrix, (args.crop_size, args.crop_size)) 206 | pic_affine = pic_normalize(pic_affine) if not args.RGB else pic_affine 207 | 208 | coord_x_cropped, coord_y_cropped = get_cropped_coords(dataset, crop_matrix, coord_x, coord_y, flip=flip) 209 | gt_coords_xy = get_gt_coords(dataset, affine_matrix, coord_x_cropped, coord_y_cropped) 210 | 211 | gt_heatmap = get_gt_heatmap(dataset, gt_coords_xy) 212 | 213 | return pic_affine, gt_coords_xy, gt_heatmap, coord_xy, bbox, annotation[-1] 214 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .losses import HeatmapLoss 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | expansion = 4 9 | 10 | def __init__(self, inplanes, planes, stride=1, downsample=None): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(inplanes) 13 | self.relu1 = nn.ReLU(inplace=False) 14 | self.conv1 = nn.Conv2d(inplanes, planes, padding=0, 15 | kernel_size=1, stride=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | self.relu2 = nn.ReLU(inplace=False) 18 | self.conv2 = nn.Conv2d(planes, planes, padding=1, 19 | kernel_size=3, stride=stride, bias=False) 20 | self.bn3 = nn.BatchNorm2d(planes) 21 | self.relu3 = nn.ReLU(inplace=False) 22 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, padding=0, 23 | kernel_size=1, stride=1, bias=False) 24 | if stride != 1 or inplanes != planes * self.expansion: 25 | downsample = nn.Conv2d(inplanes, planes * self.expansion, padding=0, 26 | kernel_size=1, stride=stride, bias=False) 27 | self.downsample = downsample 28 | 29 | for m in self.modules(): 30 | if m.__class__.__name__ in ['Conv2d']: 31 | nn.init.kaiming_uniform_(m.weight.data) 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | if self.downsample is not None: 37 | out = self.conv1(x) 38 | out = self.bn2(out) 39 | out = self.relu2(out) 40 | out = self.conv2(out) 41 | out = self.bn3(out) 42 | out = self.relu3(out) 43 | out = self.conv3(out) 44 | else: 45 | out = self.bn1(x) 46 | out = self.relu1(out) 47 | out = self.conv1(out) 48 | out = self.bn2(out) 49 | out = self.relu2(out) 50 | out = self.conv2(out) 51 | out = self.bn3(out) 52 | out = self.relu3(out) 53 | out = self.conv3(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out = out + residual 59 | 60 | return out 61 | 62 | 63 | class Hourglass(nn.Module): 64 | 65 | def __init__(self, block=Bottleneck, num_blocks=1, planes=64, depth=4): 66 | super(Hourglass, self).__init__() 67 | self.depth = depth 68 | self.maxpool = nn.MaxPool2d(2, stride=2) 69 | self.hg = self._make_hourglass(block, num_blocks, planes, depth) 70 | 71 | for m in self.modules(): 72 | if m.__class__.__name__ in ['Conv2d']: 73 | nn.init.kaiming_uniform_(m.weight.data) 74 | 75 | @staticmethod 76 | def _make_residual(block, num_blocks, planes): 77 | layers = [] 78 | for index in range(0, num_blocks): 79 | layers.append(block(planes * block.expansion, planes)) 80 | return nn.Sequential(*layers) 81 | 82 | def _make_hourglass(self, block, num_blocks, planes, depth): 83 | hourglass = [] 84 | for index in range(depth): 85 | res = [] 86 | for j in range(3): 87 | res.append(self._make_residual(block, num_blocks, planes)) 88 | if index == 0: 89 | res.append(self._make_residual(block, num_blocks, planes)) 90 | hourglass.append(nn.ModuleList(res)) 91 | return nn.ModuleList(hourglass) 92 | 93 | def _hourglass_forward(self, n, x): 94 | up1 = self.hg[n - 1][0](x) 95 | low1 = self.maxpool(x) 96 | low1 = self.hg[n - 1][1](low1) 97 | 98 | if n > 1: 99 | low2 = self._hourglass_forward(n - 1, low1) 100 | else: 101 | low2 = self.hg[n - 1][3](low1) 102 | low3 = self.hg[n - 1][2](low2) 103 | up2 = F.interpolate(low3, scale_factor=2, mode='bilinear', align_corners=True) 104 | out = up1 + up2 105 | return out 106 | 107 | def forward(self, x): 108 | return self._hourglass_forward(self.depth, x) 109 | 110 | 111 | class FMFHourglass(nn.Module): 112 | 113 | def __init__(self, planes, depth): 114 | super(FMFHourglass, self).__init__() 115 | self.depth = depth 116 | self.maxpool = nn.MaxPool2d(2, stride=2) 117 | hourglass = [] 118 | for index in range(depth): 119 | res = [] 120 | for j in range(3): 121 | res.append(Bottleneck(planes * Bottleneck.expansion, planes)) 122 | if index == depth - 1: 123 | del(res[-1]) 124 | hourglass.append(nn.ModuleList(res)) 125 | self.hg = nn.ModuleList(hourglass) 126 | 127 | for m in self.modules(): 128 | if m.__class__.__name__ in ['Conv2d']: 129 | nn.init.kaiming_uniform_(m.weight.data) 130 | 131 | def _hourglass_forward(self, n, x): 132 | up1 = self.hg[n - 1][2](x) 133 | low1 = self.maxpool(x) 134 | low1 = self.hg[n - 1][0](low1) 135 | 136 | if n > 1: 137 | low2 = self._hourglass_forward(n - 1, low1) 138 | low2 = self.hg[n - 1][1](low2) 139 | else: 140 | low2 = self.hg[n - 1][1](low1) 141 | up2 = F.interpolate(low2, scale_factor=2, mode='bilinear', align_corners=True) 142 | out = up1 + up2 143 | return out 144 | 145 | def forward(self, x): 146 | out = self.maxpool(x) 147 | out = self.hg[self.depth-1][0](out) 148 | if self.depth > 1: 149 | out = self._hourglass_forward(self.depth-1, out) 150 | out = self.hg[self.depth-1][1](out) 151 | out = F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True) 152 | return out 153 | 154 | 155 | class MessagePassing(nn.Module): 156 | pass_order = {'A': ['1', '13', '12', '11', '10', '5', '4', '7', '9', '6', '8', '2', '3'], 157 | 'B': ['2', '3', '6', '8', '7', '9', '4', '5', '10', '11', '12', '13', '1']} 158 | boundary_relation = {'A': {'1': ['2', '3', '7', '9', '13'], 159 | '2': [], 160 | '3': [], 161 | '4': ['7', '9'], 162 | '5': ['4'], 163 | '6': ['2'], 164 | '7': ['6'], 165 | '8': ['3'], 166 | '9': ['8'], 167 | '10': ['5'], 168 | '11': ['10'], 169 | '12': ['11'], 170 | '13': ['12']}, 171 | 'B': {'1': [], 172 | '2': ['1', '6'], 173 | '3': ['1', '8'], 174 | '4': ['5'], 175 | '5': ['10'], 176 | '6': ['7'], 177 | '7': ['1', '4'], 178 | '8': ['9'], 179 | '9': ['1', '4'], 180 | '10': ['11'], 181 | '11': ['12'], 182 | '12': ['13'], 183 | '13': ['1']}} 184 | 185 | def __init__(self, classes=13, step=2, inchannels=256, channels=16, first=0, last=0): 186 | super(MessagePassing, self).__init__() 187 | self.first = first # 标识当前次message passing是否是第一次message passing, 1表示是第一次 188 | self.last = last # 标识当前次message passing是否是最后一次message passing, 1表示是最后一次 189 | self.classes = classes # boundary number: 13 190 | self.step = step # message passing steps: 2 191 | prepare_conv, prepare_bn, prepare_relu = [], [], [] 192 | after_bn, after_relu, after_conv = [], [], [] 193 | inner_level_pass, inter_level_pass = [], [] 194 | for index in range(2 * classes): 195 | prepare_conv.append(nn.Conv2d(inchannels, channels, padding=0, 196 | kernel_size=1, stride=1, bias=False)) 197 | prepare_bn.append(nn.BatchNorm2d(channels)) 198 | prepare_relu.append(nn.ReLU()) 199 | for index in range(classes): 200 | after_bn.append(nn.BatchNorm2d(2*channels)) 201 | after_relu.append(nn.ReLU()) 202 | after_conv.append(nn.Conv2d(2*channels, 1, padding=0, 203 | kernel_size=1, stride=1, bias=False)) 204 | for item in self.pass_order['A']: 205 | for index in range(len(self.boundary_relation['A'][item])): 206 | inner_level_pass.append(self._make_passing()) 207 | for item in self.pass_order['B']: 208 | for index in range(len(self.boundary_relation['B'][item])): 209 | inner_level_pass.append(self._make_passing()) 210 | if self.last == 0: 211 | for index in range(2*self.classes): 212 | inter_level_pass.append(self._make_passing()) 213 | self.pre_conv = nn.ModuleList(prepare_conv) 214 | self.pre_bn = nn.ModuleList(prepare_bn) 215 | self.pre_relu = nn.ModuleList(prepare_relu) 216 | self.aft_bn = nn.ModuleList(after_bn) 217 | self.aft_relu = nn.ModuleList(after_relu) 218 | self.aft_conv = nn.ModuleList(after_conv) 219 | self.inner_pass = nn.ModuleList(inner_level_pass) 220 | self.inter_pass = nn.ModuleList(inter_level_pass) 221 | 222 | for m in self.modules(): 223 | if m.__class__.__name__ in ['Conv2d']: 224 | nn.init.kaiming_uniform_(m.weight.data) 225 | 226 | def _make_passing(self, inplanes=16, planes=8, pad=3, ker_size=7, stride=1, bias=False): 227 | passing = [] 228 | for pass_step in range(self.step): 229 | if pass_step == 0: 230 | passing.append(nn.Conv2d(inplanes, planes, padding=pad, 231 | kernel_size=ker_size, stride=stride, bias=bias)) 232 | passing.append(nn.BatchNorm2d(planes)) 233 | passing.append(nn.ReLU()) 234 | elif pass_step == self.step - 1: 235 | passing.append(nn.Conv2d(planes, inplanes, padding=pad, 236 | kernel_size=ker_size, stride=stride, bias=bias)) 237 | else: 238 | passing.append(nn.Conv2d(planes, planes, padding=pad, 239 | kernel_size=ker_size, stride=stride, bias=bias)) 240 | passing.append(nn.BatchNorm2d(planes)) 241 | passing.append(nn.ReLU()) 242 | return nn.Sequential(*passing) 243 | 244 | def forward(self, x, ahead_msg): 245 | inner_msg_count = 0 246 | feature_map = [] 247 | result = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [], 248 | '10': [], '11': [], '12': [], '13': []} 249 | result_a = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [], 250 | '10': [], '11': [], '12': [], '13': []} 251 | result_b = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [], 252 | '10': [], '11': [], '12': [], '13': []} 253 | msg_box_a = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [], 254 | '10': [], '11': [], '12': [], '13': []} 255 | msg_box_b = {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': [], 256 | '10': [], '11': [], '12': [], '13': []} 257 | inter_level_msg = {'A': {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], 258 | '9': [], '10': [], '11': [], '12': [], '13': []}, 259 | 'B': {'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], 260 | '9': [], '10': [], '11': [], '12': [], '13': []}} 261 | 262 | for index in range(self.classes): # direction 'A' 263 | out = self.pre_conv[index](x) 264 | for get_msg_index in range(len(msg_box_a[self.pass_order['A'][index]])): # get inner level msg 265 | out = out + msg_box_a[self.pass_order['A'][index]][get_msg_index] 266 | if self.first == 0: # 即不是第一次message passing, get inter level msg 267 | out = out + ahead_msg['A'][self.pass_order['A'][index]][0] 268 | out = self.pre_bn[index](out) 269 | out = self.pre_relu[index](out) 270 | result_a[self.pass_order['A'][index]].append(out) # save to be concatenated 271 | for send_msg_index in range(len(self.boundary_relation['A'][self.pass_order['A'][index]])): # message pass 272 | temp = self.inner_pass[inner_msg_count](out) 273 | inner_msg_count = inner_msg_count + 1 274 | msg_box_a[self.boundary_relation['A'][self.pass_order['A'][index]][send_msg_index]].append(temp) 275 | if self.last == 0: # 即不是最后一次message passing,则向下一个stack传递消息 276 | temp = self.inter_pass[index](out) 277 | inter_level_msg['A'][self.pass_order['A'][index]].append(temp) 278 | 279 | for index in range(self.classes): # direction 'B' 280 | out = self.pre_conv[index + self.classes](x) 281 | for get_msg_index in range(len(msg_box_b[self.pass_order['B'][index]])): # get inner level msg 282 | out = out + msg_box_b[self.pass_order['B'][index]][get_msg_index] 283 | if self.first == 0: # 即不是第一次message passing, get inter level msg 284 | out = out + ahead_msg['B'][self.pass_order['B'][index]][0] 285 | out = self.pre_bn[index + self.classes](out) 286 | out = self.pre_relu[index + self.classes](out) 287 | result_b[self.pass_order['B'][index]].append(out) # save to be concatenated 288 | for send_msg_index in range(len(self.boundary_relation['B'][self.pass_order['B'][index]])): # message pass 289 | temp = self.inner_pass[inner_msg_count](out) 290 | inner_msg_count = inner_msg_count + 1 291 | msg_box_b[self.boundary_relation['B'][self.pass_order['B'][index]][send_msg_index]].append(temp) 292 | if self.last == 0: # 即不是最后一次message passing,则向下一个stack传递消息 293 | temp = self.inter_pass[index + self.classes](out) 294 | inter_level_msg['B'][self.pass_order['B'][index]].append(temp) 295 | 296 | for index in range(self.classes): # concatenation and conv to get feature_map 297 | result[str(index + 1)] = torch.cat((result_a[str(index + 1)][0], 298 | result_b[str(index + 1)][0]), 1) # after concat: 1 32 64 64 299 | result[str(index + 1)] = self.aft_bn[index](result[str(index + 1)]) 300 | result[str(index + 1)] = self.aft_relu[index](result[str(index + 1)]) 301 | result[str(index + 1)] = self.aft_conv[index](result[str(index + 1)]) 302 | 303 | feature_map.append(result['1']) 304 | for index in range(self.classes - 1): # concat all 'classes' feature maps 305 | feature_map[0] = torch.cat((feature_map[0], result[str(index + 2)]), 1) 306 | 307 | if self.last == 0: # 如果不是最后一个stack的message passing,则除了输出feature map外还输出层间消息 308 | return feature_map[0], inter_level_msg 309 | else: 310 | return feature_map[0] 311 | 312 | 313 | class Estimator(nn.Module): 314 | 315 | def __init__(self, stacks=4, msg_pass=1): 316 | super(Estimator, self).__init__() 317 | self.stacks = stacks 318 | self.msg_pass = msg_pass 319 | self.hm_loss = HeatmapLoss() 320 | self.conv1 = nn.Conv2d(1, 64, padding=3, kernel_size=7, 321 | stride=2, bias=False) 322 | self.conv1_bn = nn.BatchNorm2d(64) 323 | self.conv1_relu = nn.ReLU(inplace=False) 324 | self.pre_res_1 = Bottleneck(64, 32) 325 | self.pool1 = nn.MaxPool2d(3, stride=2, padding=1) # problem, need to see the source code of caffe 326 | self.pre_res_2 = Bottleneck(128, 32) 327 | self.pre_res_2_bn = nn.BatchNorm2d(128) 328 | self.pre_res_2_relu = nn.ReLU(inplace=False) 329 | self.hourglass_0 = Bottleneck(128, 64) 330 | hg, mp = [], [] 331 | linear_1_res, linear_1_bn, linear_1_relu, linear_1_conv = [], [], [], [] 332 | linear_2_bn, linear_2_relu, linear_2_conv = [], [], [] 333 | linear_3 = [] 334 | linear_mp_bn, linear_mp_relu, linear_mp_conv = [], [], [] 335 | for index in range(self.stacks): 336 | hg.append(Hourglass()) 337 | linear_1_res.append(Bottleneck(256, 64)) 338 | linear_1_bn.append(nn.BatchNorm2d(256)) 339 | linear_1_relu.append(nn.ReLU()) 340 | linear_1_conv.append(nn.Conv2d(256, 256, padding=0, kernel_size=1, 341 | stride=1, bias=False)) 342 | if msg_pass: 343 | if index == 0: 344 | mp.append(MessagePassing(first=1)) 345 | elif index == self.stacks - 1: 346 | mp.append(MessagePassing(last=1)) 347 | else: 348 | mp.append(MessagePassing()) 349 | else: 350 | linear_mp_bn.append(nn.BatchNorm2d(256)) 351 | linear_mp_relu.append(nn.ReLU()) 352 | linear_mp_conv.append(nn.Conv2d(256, 13, padding=0, kernel_size=1, 353 | stride=1, bias=False)) 354 | if index != self.stacks - 1: 355 | linear_2_bn.append(nn.BatchNorm2d(256)) 356 | linear_2_relu.append(nn.ReLU()) 357 | linear_2_conv.append(nn.Conv2d(256, 256, padding=0, kernel_size=1, 358 | stride=1, bias=False)) 359 | linear_3.append(nn.Conv2d(13, 256, padding=0, kernel_size=1, 360 | stride=1, bias=False)) 361 | self.hg = nn.ModuleList(hg) 362 | self.linear_1_res = nn.ModuleList(linear_1_res) 363 | self.linear_1_bn = nn.ModuleList(linear_1_bn) 364 | self.linear_1_relu = nn.ModuleList(linear_1_relu) 365 | self.linear_1_conv = nn.ModuleList(linear_1_conv) 366 | self.mp = nn.ModuleList(mp) 367 | self.linear_2_bn = nn.ModuleList(linear_2_bn) 368 | self.linear_2_relu = nn.ModuleList(linear_2_relu) 369 | self.linear_2_conv = nn.ModuleList(linear_2_conv) 370 | self.linear_3 = nn.ModuleList(linear_3) 371 | self.linear_mp_bn = nn.ModuleList(linear_mp_bn) 372 | self.linear_mp_relu = nn.ModuleList(linear_mp_relu) 373 | self.linear_mp_conv = nn.ModuleList(linear_mp_conv) 374 | 375 | for m in self.modules(): 376 | if m.__class__.__name__ in ['Conv2d']: 377 | nn.init.kaiming_uniform_(m.weight.data) 378 | 379 | def forward(self, x): 380 | heatmaps = [] # save all the stacks output feature maps 381 | inter_level_msg = [] 382 | out = self.conv1(x) 383 | out = self.conv1_bn(out) 384 | out = self.conv1_relu(out) 385 | out = self.pre_res_1(out) 386 | out = self.pool1(out) 387 | out = self.pre_res_2(out) 388 | out = self.pre_res_2_bn(out) 389 | out = self.pre_res_2_relu(out) 390 | out = self.hourglass_0(out) 391 | for index in range(self.stacks): 392 | temp = self.hg[index](out) 393 | temp = self.linear_1_res[index](temp) 394 | temp = self.linear_1_bn[index](temp) 395 | temp = self.linear_1_relu[index](temp) 396 | temp = self.linear_1_conv[index](temp) 397 | if self.msg_pass: 398 | if index != self.stacks - 1: 399 | heatmap, inter_level_msg = self.mp[index](temp, inter_level_msg) 400 | else: 401 | heatmap = self.mp[index](temp, inter_level_msg) 402 | else: 403 | heatmap = self.linear_mp_bn[index](temp) 404 | heatmap = self.linear_mp_relu[index](heatmap) 405 | heatmap = self.linear_mp_conv[index](heatmap) 406 | heatmaps.append(heatmap) 407 | if index != self.stacks - 1: 408 | temp = self.linear_2_bn[index](temp) 409 | temp = self.linear_2_relu[index](temp) 410 | linear2_out = self.linear_2_conv[index](temp) 411 | linear3_out = self.linear_3[index](heatmap) 412 | out = out + linear2_out + linear3_out 413 | return heatmaps # 每一个stack的输出heatmap经过append 414 | 415 | def calc_loss(self, pred_heatmaps, gt_heatmap): 416 | heatmap_loss = [] 417 | for stack in range(self.stacks): 418 | heatmap_loss.append(self.hm_loss(pred_heatmaps[stack], gt_heatmap)) 419 | heatmap_loss = torch.stack(heatmap_loss, dim=0) 420 | heatmap_loss = torch.sum(heatmap_loss) 421 | return heatmap_loss 422 | 423 | 424 | class Regressor(nn.Module): 425 | 426 | def __init__(self, classes=13, fuse_stages=4, planes=16, output=196): 427 | super(Regressor, self).__init__() 428 | self.classes = classes 429 | self.FMF_stages = 3 430 | self.fuse_stages = fuse_stages 431 | self.planes = planes 432 | self.conv1 = nn.Conv2d(14, self.planes, padding=3, kernel_size=7, stride=2, bias=False) \ 433 | if fuse_stages > 0 else nn.Conv2d(1, self.planes, padding=3, kernel_size=7, stride=2, bias=False) 434 | self.bn1 = nn.BatchNorm2d(self.planes) 435 | self.bn2 = nn.BatchNorm2d(256) # regressor最后一个Batchnorm 436 | self.relu1 = nn.ReLU(inplace=False) 437 | self.relu2 = nn.ReLU(inplace=False) # regressor ip之前最后一个relu 438 | self.pool1 = nn.MaxPool2d(3, stride=2, padding=1) # problem, need to see the source code of caffe (solved) 439 | baseline_bn, baseline_relu, baseline_res_1, baseline_res_2 = [], [], [], [] 440 | pre_fmf_bn, pre_fmf_relu, pre_fmf_conv = [], [], [] 441 | aft_fmf_bn, aft_fmf_relu, aft_fmf_conv = [], [], [] 442 | tanh = [] 443 | fmfhourglass = [] 444 | for index in range(self.FMF_stages + 1): 445 | if index == 0: 446 | baseline_bn.append(nn.BatchNorm2d(self.planes)) 447 | baseline_relu.append(nn.ReLU()) 448 | baseline_res_1.append(Bottleneck(self.planes, self.planes//2)) 449 | baseline_res_2.append(Bottleneck(self.planes * 2, self.planes//2)) 450 | else: 451 | baseline_bn.append(nn.BatchNorm2d(self.planes * pow(2, index))) 452 | baseline_relu.append(nn.ReLU()) 453 | baseline_res_1.append(Bottleneck(self.planes * pow(2, index), self.planes * pow(2, index-1), stride=2)) 454 | baseline_res_2.append(Bottleneck(self.planes * pow(2, index+1), self.planes * pow(2, index-1))) 455 | for index in range(self.FMF_stages): 456 | pre_fmf_bn.append(nn.BatchNorm2d(self.planes * pow(2, index+1) + self.classes)) 457 | pre_fmf_relu.append(nn.ReLU()) 458 | pre_fmf_conv.append(nn.Conv2d(self.planes*pow(2, index+1) + self.classes, self.planes*pow(2, index+1), 459 | padding=0, kernel_size=1, stride=1, bias=False)) 460 | for index in range(self.FMF_stages): 461 | fmfhourglass.append(FMFHourglass(planes=8*pow(2, index), depth=3-index)) 462 | for index in range(self.FMF_stages): 463 | aft_fmf_bn.append(nn.BatchNorm2d(self.planes * pow(2, index + 1))) 464 | aft_fmf_bn.append(nn.BatchNorm2d(self.planes * pow(2, index + 1))) 465 | aft_fmf_relu.append(nn.ReLU()) 466 | aft_fmf_relu.append(nn.ReLU()) 467 | aft_fmf_conv.append(nn.Conv2d(self.planes * pow(2, index + 1), self.planes * pow(2, index + 1), 468 | padding=0, kernel_size=1, stride=1, bias=False)) 469 | aft_fmf_conv.append(nn.Conv2d(self.planes * pow(2, index + 1), self.planes * pow(2, index + 1), 470 | padding=0, kernel_size=1, stride=1, bias=False)) 471 | tanh.append(nn.Tanh()) 472 | self.bl_bn = nn.ModuleList(baseline_bn) 473 | self.bl_relu = nn.ModuleList(baseline_relu) 474 | self.bl_res_1 = nn.ModuleList(baseline_res_1) 475 | self.bl_res_2 = nn.ModuleList(baseline_res_2) 476 | self.pre_fmf_bn = nn.ModuleList(pre_fmf_bn) 477 | self.pre_fmf_relu = nn.ModuleList(pre_fmf_relu) 478 | self.pre_fmf_conv = nn.ModuleList(pre_fmf_conv) 479 | self.FMF_Hourglass = nn.ModuleList(fmfhourglass) 480 | self.aft_fmf_bn = nn.ModuleList(aft_fmf_bn) 481 | self.aft_fmf_relu = nn.ModuleList(aft_fmf_relu) 482 | self.aft_fmf_conv = nn.ModuleList(aft_fmf_conv) 483 | self.tanh = nn.ModuleList(tanh) 484 | self.fc1 = nn.Linear(256 * 8 * 8, 256) # 目前的代码暂时不考虑通用性,很多数字暂时都强硬地固定下来了 485 | self.fc2 = nn.Linear(256, 256) 486 | self.fc3 = nn.Linear(256, output) 487 | self.fc_relu1 = nn.ReLU(inplace=False) 488 | self.fc_relu2 = nn.ReLU(inplace=False) 489 | 490 | for m in self.modules(): 491 | if m.__class__.__name__ in ['Conv2d']: 492 | nn.init.kaiming_uniform_(m.weight.data) 493 | 494 | @staticmethod 495 | def num_flat_features(x): 496 | size = x.size()[1:] # all dimensions except the batch dimension 497 | num_features = 1 498 | for s in size: 499 | num_features *= s 500 | return num_features 501 | 502 | def forward(self, input_img, heatmap): 503 | data_concat = [] 504 | if self.fuse_stages > 0: 505 | out = F.interpolate(heatmap, scale_factor=4, mode='bilinear', align_corners=True) 506 | data_concat.append(input_img) 507 | for index in range(self.classes - 1): 508 | data_concat[0] = torch.cat((data_concat[0], input_img), 1) 509 | out = data_concat[0]*out 510 | out = torch.cat((out, input_img), 1) 511 | else: 512 | out = input_img 513 | out = self.conv1(out) 514 | out = self.bn1(out) 515 | out = self.relu1(out) 516 | out = self.pool1(out) 517 | out = self.bl_bn[0](out) 518 | out = self.bl_relu[0](out) 519 | out = self.bl_res_1[0](out) 520 | out = self.bl_res_2[0](out) 521 | for index in range(self.FMF_stages): 522 | if index < self.fuse_stages - 1: 523 | temp = F.interpolate(heatmap, scale_factor=pow(2, -1*index), mode='bilinear', align_corners=True) 524 | temp_out = torch.cat((temp, out), 1) 525 | temp_out = self.pre_fmf_bn[index](temp_out) 526 | temp_out = self.pre_fmf_relu[index](temp_out) 527 | temp_out = self.pre_fmf_conv[index](temp_out) 528 | temp_out = self.FMF_Hourglass[index](temp_out) 529 | temp_out = self.aft_fmf_bn[2 * index](temp_out) 530 | temp_out = self.aft_fmf_relu[2 * index](temp_out) 531 | temp_out = self.aft_fmf_conv[2 * index](temp_out) 532 | temp_out = self.aft_fmf_bn[2 * index + 1](temp_out) 533 | temp_out = self.aft_fmf_relu[2 * index + 1](temp_out) 534 | temp_out = self.aft_fmf_conv[2 * index + 1](temp_out) 535 | temp_out = self.tanh[index](temp_out) 536 | temp_out = temp_out * out 537 | out = temp_out + out 538 | out = self.bl_bn[index+1](out) 539 | out = self.bl_relu[index + 1](out) 540 | out = self.bl_res_1[index + 1](out) 541 | out = self.bl_res_2[index + 1](out) 542 | out = self.bn2(out) 543 | out = self.relu2(out) 544 | out = out.view(-1, self.num_flat_features(out)) 545 | out = self.fc1(out) 546 | out = self.fc_relu1(out) 547 | out = self.fc2(out) 548 | out = self.fc_relu2(out) 549 | out = self.fc3(out) 550 | 551 | return out 552 | 553 | 554 | class Discrim(nn.Module): 555 | channels, linear_n = [13, 64, 192, 384, 256, 256], [4096, 1024, 256, 13] 556 | ker_size, strd, pad = [2, 5, 3, 3, 3], [2, 1, 1, 1, 1], [0, 2, 1, 1, 1] 557 | maxpool_mask = [1, 1, 0, 0, 1] 558 | 559 | def __init__(self, conv_layers=5, linear_layers=3): 560 | super(Discrim, self).__init__() 561 | conv_features = [] 562 | linear_classify = [] 563 | for index in range(conv_layers): 564 | conv_features.append(nn.Conv2d(Discrim.channels[index], Discrim.channels[index + 1], 565 | kernel_size=Discrim.ker_size[index], 566 | stride=Discrim.strd[index], 567 | padding=Discrim.pad[index], 568 | bias=False)) 569 | conv_features.append(nn.BatchNorm2d(Discrim.channels[index + 1])) 570 | conv_features.append(nn.ReLU(inplace=False)) 571 | if Discrim.maxpool_mask[index] == 1: 572 | conv_features.append(nn.MaxPool2d(3, stride=2, padding=1)) 573 | else: 574 | conv_features.append(nn.ReLU(inplace=False)) 575 | for index in range(linear_layers): 576 | linear_classify.append(nn.Linear(Discrim.linear_n[index], Discrim.linear_n[index+1])) 577 | if index != linear_layers - 1: 578 | linear_classify.append(nn.ReLU(inplace=False)) 579 | else: 580 | linear_classify.append(nn.Sigmoid()) 581 | self.features = nn.Sequential(*conv_features) 582 | self.classifier = nn.Sequential(*linear_classify) 583 | 584 | @staticmethod 585 | def num_flat_features(x): 586 | size = x.size()[1:] # all dimensions except the batch dimension 587 | num_features = 1 588 | for s in size: 589 | num_features *= s 590 | return num_features 591 | 592 | def forward(self, x): 593 | out = self.features(x) 594 | out = out.view(-1, self.num_flat_features(out)) 595 | out = self.classifier(out) 596 | return out 597 | 598 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 141 | 142 | 143 | 144 | ratote_limit 145 | ratate_limit 146 | ratotion 147 | cat 148 | assert 149 | dataset 150 | test68 151 | 152 | 153 | 154 | ratate_limit 155 | rotate_limit 156 | rotation 157 | COFW68 158 | 159 | 160 | 161 | 163 | 164 | 183 | 184 | 185 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 |