├── __init__.py ├── ckpt └── .gitkeep ├── core ├── __init__.py ├── evaler.py ├── coord_conv.py ├── models.py └── dataloader.py ├── utils ├── __init__.py └── utils.py ├── images ├── wflw.png └── wflw_table.png ├── .gitignore ├── requirements.txt ├── scripts └── eval_wflw.sh ├── README.md ├── eval.py ├── dataset └── convert_WFLW.py └── LICENSE /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ckpt/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/wflw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protossw512/AdaptiveWingLoss/HEAD/images/wflw.png -------------------------------------------------------------------------------- /images/wflw_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protossw512/AdaptiveWingLoss/HEAD/images/wflw_table.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python generated files 2 | *.pyc 3 | 4 | # Project related files 5 | ckpt/*.pth 6 | dataset/* 7 | !dataset/!.py 8 | experiments/* -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | scipy>=0.17.0 3 | scikit-image 4 | numpy 5 | matplotlib 6 | Pillow>=4.3.0 7 | imgaug 8 | tensorflow 9 | git+https://github.com/lanpa/tensorboardX 10 | joblib 11 | torch==1.3.0 12 | torchvision==0.4.1 13 | -------------------------------------------------------------------------------- /scripts/eval_wflw.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python ../eval.py \ 2 | --val_img_dir='../dataset/WFLW_test/images/' \ 3 | --val_landmarks_dir='../dataset/WFLW_test/landmarks/' \ 4 | --ckpt_save_path='../experiments/eval_iccv_0620' \ 5 | --hg_blocks=4 \ 6 | --pretrained_weights='../ckpt/WFLW_4HG.pth' \ 7 | --num_landmarks=98 \ 8 | --end_relu='False' \ 9 | --batch_size=20 \ 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaptiveWingLoss 2 | ## [arXiv](https://arxiv.org/abs/1904.07399) 3 | Pytorch Implementation of Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression. 4 | 5 | 6 | 7 | ## Update Logs: 8 | ### October 28, 2019 9 | * Pretrained Model and evaluation code on WFLW dataset is released. 10 | 11 | ## Installation 12 | #### Note: Code was originally developed under Python2.X and Pytorch 0.4. This released version was revisioned from original code and was tested on Python3.5.7 and Pytorch 1.3.0. 13 | 14 | Install system requirements: 15 | ``` 16 | sudo apt-get install python3-dev python3-pip python3-tk libglib2.0-0 17 | ``` 18 | 19 | Install python dependencies: 20 | ``` 21 | pip3 install -r requirements.txt 22 | ``` 23 | 24 | ## Run Evaluation on WFLW dataset 25 | 1. Download and process WFLW dataset 26 | * Download WFLW dataset and annotation from [Here](https://wywu.github.io/projects/LAB/WFLW.html). 27 | * Unzip WFLW dataset and annotations and move files into ```./dataset``` directory. Your directory should look like this: 28 | ``` 29 | AdaptiveWingLoss 30 | └───dataset 31 | │ 32 | └───WFLW_annotations 33 | │ └───list_98pt_rect_attr_train_test 34 | │ │ 35 | │ └───list_98pt_test 36 | │ 37 | └───WFLW_images 38 | └───0--Parade 39 | │ 40 | └───... 41 | ``` 42 | * Inside ```./dataset``` directory, run: 43 | ``` 44 | python convert_WFLW.py 45 | ``` 46 | A new directory ```./dataset/WFLW_test``` should be generated with 2500 processed testing images and corresponding landmarks. 47 | 48 | 2. Download pretrained model from [Google Drive](https://drive.google.com/file/d/1HZaSjLoorQ4QCEx7PRTxOmg0bBPYSqhH/view?usp=sharing) and put it in ```./ckpt``` directory. 49 | 50 | 3. Within ```./Scripts``` directory, run following command: 51 | ``` 52 | sh eval_wflw.sh 53 | ``` 54 | 55 | 56 | *GTBbox indicates the ground truth landmarks are used as bounding box to crop faces. 57 | 58 | ## Future Plans 59 | - [x] Release evaluation code and pretrained model on WFLW dataset. 60 | 61 | - [ ] Release training code on WFLW dataset. 62 | 63 | - [ ] Release pretrained model and code on 300W, AFLW and COFW dataset. 64 | 65 | - [ ] Replease facial landmark detection API 66 | 67 | 68 | ## Citation 69 | If you find this useful for your research, please cite the following paper. 70 | 71 | ``` 72 | @InProceedings{Wang_2019_ICCV, 73 | author = {Wang, Xinyao and Bo, Liefeng and Fuxin, Li}, 74 | title = {Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression}, 75 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 76 | month = {October}, 77 | year = {2019} 78 | } 79 | ``` 80 | 81 | ## Acknowledgments 82 | This repository borrows or partially modifies hourglass model and data processing code from [face alignment](https://github.com/1adrianb/face-alignment) and [pose-hg-train](https://github.com/princeton-vl/pose-hg-train). 83 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import torch.nn as nn 6 | import time 7 | import os 8 | from core.evaler import eval_model 9 | from core.dataloader import get_dataset 10 | from core import models 11 | from tensorboardX import SummaryWriter 12 | 13 | # Parse arguments 14 | parser = argparse.ArgumentParser() 15 | # Dataset paths 16 | parser.add_argument('--val_img_dir', type=str, 17 | help='Validation image directory') 18 | parser.add_argument('--val_landmarks_dir', type=str, 19 | help='Validation landmarks directory') 20 | parser.add_argument('--num_landmarks', type=int, default=68, 21 | help='Number of landmarks') 22 | 23 | # Checkpoint and pretrained weights 24 | parser.add_argument('--ckpt_save_path', type=str, 25 | help='a directory to save checkpoint file') 26 | parser.add_argument('--pretrained_weights', type=str, 27 | help='a directory to save pretrained_weights') 28 | 29 | # Eval options 30 | parser.add_argument('--batch_size', type=int, default=25, 31 | help='learning rate decay after each epoch') 32 | 33 | # Network parameters 34 | parser.add_argument('--hg_blocks', type=int, default=4, 35 | help='Number of HG blocks to stack') 36 | parser.add_argument('--gray_scale', type=str, default="False", 37 | help='Whether to convert RGB image into gray scale during training') 38 | parser.add_argument('--end_relu', type=str, default="False", 39 | help='Whether to add relu at the end of each HG module') 40 | 41 | args = parser.parse_args() 42 | 43 | VAL_IMG_DIR = args.val_img_dir 44 | VAL_LANDMARKS_DIR = args.val_landmarks_dir 45 | CKPT_SAVE_PATH = args.ckpt_save_path 46 | BATCH_SIZE = args.batch_size 47 | PRETRAINED_WEIGHTS = args.pretrained_weights 48 | GRAY_SCALE = False if args.gray_scale == 'False' else True 49 | HG_BLOCKS = args.hg_blocks 50 | END_RELU = False if args.end_relu == 'False' else True 51 | NUM_LANDMARKS = args.num_landmarks 52 | 53 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 54 | 55 | writer = SummaryWriter(CKPT_SAVE_PATH) 56 | 57 | dataloaders, dataset_sizes = get_dataset(VAL_IMG_DIR, VAL_LANDMARKS_DIR, 58 | BATCH_SIZE, NUM_LANDMARKS) 59 | use_gpu = torch.cuda.is_available() 60 | model_ft = models.FAN(HG_BLOCKS, END_RELU, GRAY_SCALE, NUM_LANDMARKS) 61 | 62 | if PRETRAINED_WEIGHTS != "None": 63 | checkpoint = torch.load(PRETRAINED_WEIGHTS) 64 | if 'state_dict' not in checkpoint: 65 | model_ft.load_state_dict(checkpoint) 66 | else: 67 | pretrained_weights = checkpoint['state_dict'] 68 | model_weights = model_ft.state_dict() 69 | pretrained_weights = {k: v for k, v in pretrained_weights.items() \ 70 | if k in model_weights} 71 | model_weights.update(pretrained_weights) 72 | model_ft.load_state_dict(model_weights) 73 | 74 | model_ft = model_ft.to(device) 75 | 76 | model_ft = eval_model(model_ft, dataloaders, dataset_sizes, writer, use_gpu, 1, 'val', CKPT_SAVE_PATH, NUM_LANDMARKS) 77 | 78 | -------------------------------------------------------------------------------- /core/evaler.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import math 4 | import torch 5 | import copy 6 | import time 7 | from torch.autograd import Variable 8 | import shutil 9 | from skimage import io 10 | import numpy as np 11 | from utils.utils import fan_NME, show_landmarks, get_preds_fromhm 12 | from PIL import Image, ImageDraw 13 | import os 14 | import sys 15 | import cv2 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 20 | 21 | def eval_model(model, dataloaders, dataset_sizes, 22 | writer, use_gpu=True, epoches=5, dataset='val', 23 | save_path='./', num_landmarks=68): 24 | global_nme = 0 25 | model.eval() 26 | for epoch in range(epoches): 27 | running_loss = 0 28 | step = 0 29 | total_nme = 0 30 | total_count = 0 31 | fail_count = 0 32 | nmes = [] 33 | # running_corrects = 0 34 | 35 | # Iterate over data. 36 | with torch.no_grad(): 37 | for data in dataloaders[dataset]: 38 | total_runtime = 0 39 | run_count = 0 40 | step_start = time.time() 41 | step += 1 42 | # get the inputs 43 | inputs = data['image'].type(torch.FloatTensor) 44 | labels_heatmap = data['heatmap'].type(torch.FloatTensor) 45 | labels_boundary = data['boundary'].type(torch.FloatTensor) 46 | landmarks = data['landmarks'].type(torch.FloatTensor) 47 | loss_weight_map = data['weight_map'].type(torch.FloatTensor) 48 | # wrap them in Variable 49 | if use_gpu: 50 | inputs = inputs.to(device) 51 | labels_heatmap = labels_heatmap.to(device) 52 | labels_boundary = labels_boundary.to(device) 53 | loss_weight_map = loss_weight_map.to(device) 54 | else: 55 | inputs, labels_heatmap = Variable(inputs), Variable(labels_heatmap) 56 | labels_boundary = Variable(labels_boundary) 57 | labels = torch.cat((labels_heatmap, labels_boundary), 1) 58 | single_start = time.time() 59 | outputs, boundary_channels = model(inputs) 60 | single_end = time.time() 61 | total_runtime += time.time() - single_start 62 | run_count += 1 63 | step_end = time.time() 64 | for i in range(inputs.shape[0]): 65 | img = inputs[i] 66 | img = img.cpu().numpy() 67 | img = img.transpose((1, 2, 0))*255.0 68 | img = img.astype(np.uint8) 69 | img = Image.fromarray(img) 70 | # pred_heatmap = outputs[-1][i].detach().cpu()[:-1, :, :] 71 | pred_heatmap = outputs[-1][:, :-1, :, :][i].detach().cpu() 72 | pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0)) 73 | pred_landmarks = pred_landmarks.squeeze().numpy() 74 | 75 | gt_landmarks = data['landmarks'][i].numpy() 76 | if num_landmarks == 68: 77 | left_eye = np.average(gt_landmarks[36:42], axis=0) 78 | right_eye = np.average(gt_landmarks[42:48], axis=0) 79 | norm_factor = np.linalg.norm(left_eye - right_eye) 80 | # norm_factor = np.linalg.norm(gt_landmarks[36]- gt_landmarks[45]) 81 | 82 | elif num_landmarks == 98: 83 | norm_factor = np.linalg.norm(gt_landmarks[60]- gt_landmarks[72]) 84 | elif num_landmarks == 19: 85 | left, top = gt_landmarks[-2, :] 86 | right, bottom = gt_landmarks[-1, :] 87 | norm_factor = math.sqrt(abs(right - left)*abs(top-bottom)) 88 | gt_landmarks = gt_landmarks[:-2, :] 89 | elif num_landmarks == 29: 90 | # norm_factor = np.linalg.norm(gt_landmarks[8]- gt_landmarks[9]) 91 | norm_factor = np.linalg.norm(gt_landmarks[16]- gt_landmarks[17]) 92 | single_nme = (np.sum(np.linalg.norm(pred_landmarks*4 - gt_landmarks, axis=1)) / pred_landmarks.shape[0]) / norm_factor 93 | 94 | nmes.append(single_nme) 95 | total_count += 1 96 | if single_nme > 0.1: 97 | fail_count += 1 98 | if step % 10 == 0: 99 | print('Step {} Time: {:.6f} Input Mean: {:.6f} Output Mean: {:.6f}'.format( 100 | step, step_end - step_start, 101 | torch.mean(labels), 102 | torch.mean(outputs[0]))) 103 | # gt_landmarks = landmarks.numpy() 104 | # pred_heatmap = outputs[-1].to('cpu').numpy() 105 | gt_landmarks = landmarks 106 | batch_nme = fan_NME(outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks) 107 | # batch_nme = 0 108 | total_nme += batch_nme 109 | epoch_nme = total_nme / dataset_sizes['val'] 110 | global_nme += epoch_nme 111 | nme_save_path = os.path.join(save_path, 'nme_log.npy') 112 | np.save(nme_save_path, np.array(nmes)) 113 | print('NME: {:.6f} Failure Rate: {:.6f} Total Count: {:.6f} Fail Count: {:.6f}'.format(epoch_nme, fail_count/total_count, total_count, fail_count)) 114 | print('Evaluation done! Average NME: {:.6f}'.format(global_nme/epoches)) 115 | print('Everage runtime for a single batch: {:.6f}'.format(total_runtime/run_count)) 116 | return model 117 | -------------------------------------------------------------------------------- /core/coord_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AddCoordsTh(nn.Module): 6 | def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False): 7 | super(AddCoordsTh, self).__init__() 8 | self.x_dim = x_dim 9 | self.y_dim = y_dim 10 | self.with_r = with_r 11 | self.with_boundary = with_boundary 12 | 13 | def forward(self, input_tensor, heatmap=None): 14 | """ 15 | input_tensor: (batch, c, x_dim, y_dim) 16 | """ 17 | batch_size_tensor = input_tensor.shape[0] 18 | 19 | xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).cuda() 20 | xx_ones = xx_ones.unsqueeze(-1) 21 | 22 | xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).cuda() 23 | xx_range = xx_range.unsqueeze(1) 24 | 25 | xx_channel = torch.matmul(xx_ones.float(), xx_range.float()) 26 | xx_channel = xx_channel.unsqueeze(-1) 27 | 28 | 29 | yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).cuda() 30 | yy_ones = yy_ones.unsqueeze(1) 31 | 32 | yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).cuda() 33 | yy_range = yy_range.unsqueeze(-1) 34 | 35 | yy_channel = torch.matmul(yy_range.float(), yy_ones.float()) 36 | yy_channel = yy_channel.unsqueeze(-1) 37 | 38 | xx_channel = xx_channel.permute(0, 3, 2, 1) 39 | yy_channel = yy_channel.permute(0, 3, 2, 1) 40 | 41 | xx_channel = xx_channel / (self.x_dim - 1) 42 | yy_channel = yy_channel / (self.y_dim - 1) 43 | 44 | xx_channel = xx_channel * 2 - 1 45 | yy_channel = yy_channel * 2 - 1 46 | 47 | xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1) 48 | yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1) 49 | 50 | if self.with_boundary and type(heatmap) != type(None): 51 | boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 52 | 0.0, 1.0) 53 | 54 | zero_tensor = torch.zeros_like(xx_channel) 55 | xx_boundary_channel = torch.where(boundary_channel>0.05, 56 | xx_channel, zero_tensor) 57 | yy_boundary_channel = torch.where(boundary_channel>0.05, 58 | yy_channel, zero_tensor) 59 | if self.with_boundary and type(heatmap) != type(None): 60 | xx_boundary_channel = xx_boundary_channel.cuda() 61 | yy_boundary_channel = yy_boundary_channel.cuda() 62 | ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 63 | 64 | 65 | if self.with_r: 66 | rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2)) 67 | rr = rr / torch.max(rr) 68 | ret = torch.cat([ret, rr], dim=1) 69 | 70 | if self.with_boundary and type(heatmap) != type(None): 71 | ret = torch.cat([ret, xx_boundary_channel, 72 | yy_boundary_channel], dim=1) 73 | return ret 74 | 75 | 76 | class CoordConvTh(nn.Module): 77 | """CoordConv layer as in the paper.""" 78 | def __init__(self, x_dim, y_dim, with_r, with_boundary, 79 | in_channels, first_one=False, *args, **kwargs): 80 | super(CoordConvTh, self).__init__() 81 | self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, 82 | with_boundary=with_boundary) 83 | in_channels += 2 84 | if with_r: 85 | in_channels += 1 86 | if with_boundary and not first_one: 87 | in_channels += 2 88 | self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs) 89 | 90 | def forward(self, input_tensor, heatmap=None): 91 | ret = self.addcoords(input_tensor, heatmap) 92 | last_channel = ret[:, -2:, :, :] 93 | ret = self.conv(ret) 94 | return ret, last_channel 95 | 96 | 97 | ''' 98 | An alternative implementation for PyTorch with auto-infering the x-y dimensions. 99 | ''' 100 | class AddCoords(nn.Module): 101 | 102 | def __init__(self, with_r=False): 103 | super().__init__() 104 | self.with_r = with_r 105 | 106 | def forward(self, input_tensor): 107 | """ 108 | Args: 109 | input_tensor: shape(batch, channel, x_dim, y_dim) 110 | """ 111 | batch_size, _, x_dim, y_dim = input_tensor.size() 112 | 113 | xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1) 114 | yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2) 115 | 116 | xx_channel = xx_channel / (x_dim - 1) 117 | yy_channel = yy_channel / (y_dim - 1) 118 | 119 | xx_channel = xx_channel * 2 - 1 120 | yy_channel = yy_channel * 2 - 1 121 | 122 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 123 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) 124 | 125 | if input_tensor.is_cuda: 126 | xx_channel = xx_channel.cuda() 127 | yy_channel = yy_channel.cuda() 128 | 129 | ret = torch.cat([ 130 | input_tensor, 131 | xx_channel.type_as(input_tensor), 132 | yy_channel.type_as(input_tensor)], dim=1) 133 | 134 | if self.with_r: 135 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) 136 | if input_tensor.is_cuda: 137 | rr = rr.cuda() 138 | ret = torch.cat([ret, rr], dim=1) 139 | 140 | return ret 141 | 142 | 143 | class CoordConv(nn.Module): 144 | 145 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs): 146 | super().__init__() 147 | self.addcoords = AddCoords(with_r=with_r) 148 | self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs) 149 | 150 | def forward(self, x): 151 | ret = self.addcoords(x) 152 | ret = self.conv(ret) 153 | return ret 154 | -------------------------------------------------------------------------------- /dataset/convert_WFLW.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "../utils/") 3 | import numpy as np 4 | import os 5 | import glob 6 | import scipy.io as sio 7 | import cv2 8 | from skimage import io 9 | from utils import cv_crop 10 | import torch 11 | from joblib import Parallel, delayed 12 | 13 | def transform(point, center, scale, resolution, rotation=0, invert=False): 14 | _pt = np.ones(3) 15 | _pt[0] = point[0] 16 | _pt[1] = point[1] 17 | 18 | h = 200.0 * scale 19 | t = np.eye(3) 20 | t[0, 0] = resolution / h 21 | t[1, 1] = resolution / h 22 | t[0, 2] = resolution * (-center[0] / h + 0.5) 23 | t[1, 2] = resolution * (-center[1] / h + 0.5) 24 | 25 | if rotation != 0: 26 | rotation = -rotation 27 | r = np.eye(3) 28 | ang = rotation * math.pi / 180.0 29 | s = math.sin(ang) 30 | c = math.cos(ang) 31 | r[0][0] = c 32 | r[0][1] = -s 33 | r[1][0] = s 34 | r[1][1] = c 35 | 36 | t_ = np.eye(3) 37 | t_[0][2] = -resolution / 2.0 38 | t_[1][2] = -resolution / 2.0 39 | t_inv = torch.eye(3) 40 | t_inv[0][2] = resolution / 2.0 41 | t_inv[1][2] = resolution / 2.0 42 | t = reduce(np.matmul, [t_inv, r, t_, t]) 43 | 44 | if invert: 45 | t = np.linalg.inv(t) 46 | new_point = (np.matmul(t, _pt))[0:2] 47 | 48 | return new_point.astype(float) 49 | 50 | def parse_pts(pts_file): 51 | pts = [] 52 | with open(pts_file) as f: 53 | for line in f.readlines(): 54 | line = line.strip() 55 | if line[0].isdigit() == False: 56 | continue 57 | else: 58 | idx = line.find(' ') 59 | x, y = float(line[:idx]), float(line[idx+1:]) 60 | pts.append([x, y]) 61 | if len(pts) != 68: 62 | print('Not enough points') 63 | else: 64 | return np.array(pts) 65 | 66 | class WFLWInstance(): 67 | def __init__(self, line, idx): 68 | self.idx = idx 69 | line = line.strip().split(' ') 70 | # convert landmarks 71 | landmarks_list = list(map(float, line[:196])) 72 | self.landmarks = [] 73 | for i in range(0, 196, 2): 74 | self.landmarks.append([landmarks_list[i], landmarks_list[i+1]]) 75 | self.landmarks = np.array(self.landmarks) 76 | 77 | # convert bboxes 78 | if len(line) == 207: 79 | self.bbox = list(map(float, line[196:200])) 80 | else: 81 | self.bbox = None 82 | 83 | # convert image name 84 | self.image_base_name = line[-1] 85 | self.image_first_point = line[0] 86 | 87 | def load_meta_subset_data(meta_path): 88 | with open(meta_path) as f: 89 | lines = f.readlines() 90 | 91 | meta_data = [] 92 | idx = 0 93 | for line in lines: 94 | line = line.strip().split(' ') 95 | meta_data.append(line[-1]+line[0]) 96 | return meta_data 97 | 98 | def load_meta_data(meta_path, meta_subset_data=None): 99 | with open(meta_path) as f: 100 | lines = f.readlines() 101 | 102 | meta_data = [] 103 | idx = 0 104 | for line in lines: 105 | wflw_instance = WFLWInstance(line, idx) 106 | if meta_subset_data is not None and (wflw_instance.image_base_name+wflw_instance.image_first_point) in meta_subset_data: 107 | meta_data.append(wflw_instance) 108 | idx += 1 109 | return meta_data 110 | 111 | def process_single(single, image_path, image_save_path, landmarks_save_path): 112 | # print('Processing: {}'.format(single.image_base_name)) 113 | image_full_path = os.path.join(image_path, single.image_base_name) 114 | image = io.imread(image_full_path) 115 | if len(image.shape) == 2: 116 | image = np.stack((image, image, image), -1) 117 | 118 | pts = single.landmarks 119 | left, top, right, bottom = [int(x) for x in single.bbox] 120 | lr_pad = int(0.05 * (right - left) / 2) 121 | tb_pad = int(0.05 * (bottom - top) / 2) 122 | left = max(0, left - lr_pad) 123 | right = right + lr_pad 124 | top = max(0, top - tb_pad) 125 | bottom = bottom + tb_pad 126 | 127 | center = torch.FloatTensor( 128 | [right - (right - left) / 2.0, bottom - 129 | (bottom - top) / 2.0]) 130 | scale_factor = 250.0 131 | scale = (right - left + bottom - top) / scale_factor 132 | new_image, new_landmarks = cv_crop(image, pts, center, scale, 450, 0) 133 | while np.min(new_landmarks) < 10 or np.max(new_landmarks) > 440: 134 | scale_factor -= 10 135 | scale = (right - left + bottom - top) / scale_factor 136 | new_image, new_landmarks = cv_crop(image, pts, center, scale, 450, 0) 137 | assert (scale_factor > 0), "Landmarks out of boundary!" 138 | if new_image != []: 139 | io.imsave(os.path.join(image_save_path, os.path.basename(image_full_path[:-4]+'_' + str(single.idx) + image_full_path[-4:])), new_image) 140 | np.save(os.path.join(landmarks_save_path, os.path.basename(image_full_path[:-4]+ '_' + str(single.idx) + '.pts')), new_landmarks) 141 | 142 | if __name__ == '__main__': 143 | image_path = './WFLW_images/' 144 | meta_subset_path = './WFLW_annotations/list_98pt_rect_attr_train_test/list_98pt_rect_attr_test.txt' 145 | meta_path = './WFLW_annotations/list_98pt_rect_attr_train_test/list_98pt_rect_attr_test.txt' 146 | image_save_path = './WFLW_test/images/' 147 | landmarks_save_path = './WFLW_test/landmarks/' 148 | if not os.path.exists(image_save_path): 149 | os.makedirs(image_save_path) 150 | if not os.path.exists(landmarks_save_path): 151 | os.makedirs(landmarks_save_path) 152 | exts = ['*.png', '*.jpg'] 153 | meta_subset_data = load_meta_subset_data(meta_subset_path) 154 | meta_data = load_meta_data(meta_path, meta_subset_data) 155 | assert (len(meta_data) == len(meta_subset_data)), "Some images are missing!" 156 | print("Total images: {0:d}".format(len(meta_data))) 157 | Parallel(n_jobs=10, 158 | backend='threading', 159 | verbose=10)(delayed(process_single)(single, image_path, 160 | image_save_path, 161 | landmarks_save_path) for single in meta_data) 162 | -------------------------------------------------------------------------------- /core/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from core.coord_conv import CoordConvTh 6 | 7 | 8 | def conv3x3(in_planes, out_planes, strd=1, padding=1, 9 | bias=False,dilation=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 12 | stride=strd, padding=padding, bias=bias, 13 | dilation=dilation) 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | # self.bn1 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = conv3x3(planes, planes) 24 | # self.bn2 = nn.BatchNorm2d(planes) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | # out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | # out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | class ConvBlock(nn.Module): 47 | def __init__(self, in_planes, out_planes): 48 | super(ConvBlock, self).__init__() 49 | self.bn1 = nn.BatchNorm2d(in_planes) 50 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 51 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 52 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), 53 | padding=1, dilation=1) 54 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 55 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), 56 | padding=1, dilation=1) 57 | 58 | if in_planes != out_planes: 59 | self.downsample = nn.Sequential( 60 | nn.BatchNorm2d(in_planes), 61 | nn.ReLU(True), 62 | nn.Conv2d(in_planes, out_planes, 63 | kernel_size=1, stride=1, bias=False), 64 | ) 65 | else: 66 | self.downsample = None 67 | 68 | def forward(self, x): 69 | residual = x 70 | 71 | out1 = self.bn1(x) 72 | out1 = F.relu(out1, True) 73 | out1 = self.conv1(out1) 74 | 75 | out2 = self.bn2(out1) 76 | out2 = F.relu(out2, True) 77 | out2 = self.conv2(out2) 78 | 79 | out3 = self.bn3(out2) 80 | out3 = F.relu(out3, True) 81 | out3 = self.conv3(out3) 82 | 83 | out3 = torch.cat((out1, out2, out3), 1) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(residual) 87 | 88 | out3 += residual 89 | 90 | return out3 91 | 92 | class HourGlass(nn.Module): 93 | def __init__(self, num_modules, depth, num_features, first_one=False): 94 | super(HourGlass, self).__init__() 95 | self.num_modules = num_modules 96 | self.depth = depth 97 | self.features = num_features 98 | self.coordconv = CoordConvTh(x_dim=64, y_dim=64, 99 | with_r=True, with_boundary=True, 100 | in_channels=256, first_one=first_one, 101 | out_channels=256, 102 | kernel_size=1, 103 | stride=1, padding=0) 104 | self._generate_network(self.depth) 105 | 106 | def _generate_network(self, level): 107 | self.add_module('b1_' + str(level), ConvBlock(256, 256)) 108 | 109 | self.add_module('b2_' + str(level), ConvBlock(256, 256)) 110 | 111 | if level > 1: 112 | self._generate_network(level - 1) 113 | else: 114 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) 115 | 116 | self.add_module('b3_' + str(level), ConvBlock(256, 256)) 117 | 118 | def _forward(self, level, inp): 119 | # Upper branch 120 | up1 = inp 121 | up1 = self._modules['b1_' + str(level)](up1) 122 | 123 | # Lower branch 124 | low1 = F.avg_pool2d(inp, 2, stride=2) 125 | low1 = self._modules['b2_' + str(level)](low1) 126 | 127 | if level > 1: 128 | low2 = self._forward(level - 1, low1) 129 | else: 130 | low2 = low1 131 | low2 = self._modules['b2_plus_' + str(level)](low2) 132 | 133 | low3 = low2 134 | low3 = self._modules['b3_' + str(level)](low3) 135 | 136 | up2 = F.upsample(low3, scale_factor=2, mode='nearest') 137 | 138 | return up1 + up2 139 | 140 | def forward(self, x, heatmap): 141 | x, last_channel = self.coordconv(x, heatmap) 142 | return self._forward(self.depth, x), last_channel 143 | 144 | class FAN(nn.Module): 145 | 146 | def __init__(self, num_modules=1, end_relu=False, gray_scale=False, 147 | num_landmarks=68): 148 | super(FAN, self).__init__() 149 | self.num_modules = num_modules 150 | self.gray_scale = gray_scale 151 | self.end_relu = end_relu 152 | self.num_landmarks = num_landmarks 153 | 154 | # Base part 155 | if self.gray_scale: 156 | self.conv1 = CoordConvTh(x_dim=256, y_dim=256, 157 | with_r=True, with_boundary=False, 158 | in_channels=3, out_channels=64, 159 | kernel_size=7, 160 | stride=2, padding=3) 161 | else: 162 | self.conv1 = CoordConvTh(x_dim=256, y_dim=256, 163 | with_r=True, with_boundary=False, 164 | in_channels=3, out_channels=64, 165 | kernel_size=7, 166 | stride=2, padding=3) 167 | self.bn1 = nn.BatchNorm2d(64) 168 | self.conv2 = ConvBlock(64, 128) 169 | self.conv3 = ConvBlock(128, 128) 170 | self.conv4 = ConvBlock(128, 256) 171 | 172 | # Stacking part 173 | for hg_module in range(self.num_modules): 174 | if hg_module == 0: 175 | first_one = True 176 | else: 177 | first_one = False 178 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, 179 | first_one)) 180 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 181 | self.add_module('conv_last' + str(hg_module), 182 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 183 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 184 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 185 | num_landmarks+1, kernel_size=1, stride=1, padding=0)) 186 | 187 | if hg_module < self.num_modules - 1: 188 | self.add_module( 189 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 190 | self.add_module('al' + str(hg_module), nn.Conv2d(num_landmarks+1, 191 | 256, kernel_size=1, stride=1, padding=0)) 192 | 193 | def forward(self, x): 194 | x, _ = self.conv1(x) 195 | x = F.relu(self.bn1(x), True) 196 | # x = F.relu(self.bn1(self.conv1(x)), True) 197 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 198 | x = self.conv3(x) 199 | x = self.conv4(x) 200 | 201 | previous = x 202 | 203 | outputs = [] 204 | boundary_channels = [] 205 | tmp_out = None 206 | for i in range(self.num_modules): 207 | hg, boundary_channel = self._modules['m' + str(i)](previous, 208 | tmp_out) 209 | 210 | ll = hg 211 | ll = self._modules['top_m_' + str(i)](ll) 212 | 213 | ll = F.relu(self._modules['bn_end' + str(i)] 214 | (self._modules['conv_last' + str(i)](ll)), True) 215 | 216 | # Predict heatmaps 217 | tmp_out = self._modules['l' + str(i)](ll) 218 | if self.end_relu: 219 | tmp_out = F.relu(tmp_out) # HACK: Added relu 220 | outputs.append(tmp_out) 221 | boundary_channels.append(boundary_channel) 222 | 223 | if i < self.num_modules - 1: 224 | ll = self._modules['bl' + str(i)](ll) 225 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 226 | previous = previous + ll + tmp_out_ 227 | 228 | return outputs, boundary_channels 229 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | import math 5 | import torch 6 | import cv2 7 | from PIL import Image 8 | from skimage import io 9 | from skimage import transform as ski_transform 10 | from scipy import ndimage 11 | import numpy as np 12 | import matplotlib 13 | import matplotlib.pyplot as plt 14 | from torch.utils.data import Dataset, DataLoader 15 | from torchvision import transforms, utils 16 | 17 | def _gaussian( 18 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None, 19 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, 20 | mean_vert=0.5): 21 | # handle some defaults 22 | if width is None: 23 | width = size 24 | if height is None: 25 | height = size 26 | if sigma_horz is None: 27 | sigma_horz = sigma 28 | if sigma_vert is None: 29 | sigma_vert = sigma 30 | center_x = mean_horz * width + 0.5 31 | center_y = mean_vert * height + 0.5 32 | gauss = np.empty((height, width), dtype=np.float32) 33 | # generate kernel 34 | for i in range(height): 35 | for j in range(width): 36 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( 37 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) 38 | if normalize: 39 | gauss = gauss / np.sum(gauss) 40 | return gauss 41 | 42 | def draw_gaussian(image, point, sigma): 43 | # Check if the gaussian is inside 44 | ul = [np.floor(np.floor(point[0]) - 3 * sigma), 45 | np.floor(np.floor(point[1]) - 3 * sigma)] 46 | br = [np.floor(np.floor(point[0]) + 3 * sigma), 47 | np.floor(np.floor(point[1]) + 3 * sigma)] 48 | if (ul[0] > image.shape[1] or ul[1] > 49 | image.shape[0] or br[0] < 1 or br[1] < 1): 50 | return image 51 | size = 6 * sigma + 1 52 | g = _gaussian(size) 53 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - 54 | int(max(1, ul[0])) + int(max(1, -ul[0]))] 55 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - 56 | int(max(1, ul[1])) + int(max(1, -ul[1]))] 57 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] 58 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] 59 | assert (g_x[0] > 0 and g_y[1] > 0) 60 | correct = False 61 | while not correct: 62 | try: 63 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] 64 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] 65 | correct = True 66 | except: 67 | print('img_x: {}, img_y: {}, g_x:{}, g_y:{}, point:{}, g_shape:{}, ul:{}, br:{}'.format(img_x, img_y, g_x, g_y, point, g.shape, ul, br)) 68 | ul = [np.floor(np.floor(point[0]) - 3 * sigma), 69 | np.floor(np.floor(point[1]) - 3 * sigma)] 70 | br = [np.floor(np.floor(point[0]) + 3 * sigma), 71 | np.floor(np.floor(point[1]) + 3 * sigma)] 72 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - 73 | int(max(1, ul[0])) + int(max(1, -ul[0]))] 74 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - 75 | int(max(1, ul[1])) + int(max(1, -ul[1]))] 76 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] 77 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] 78 | pass 79 | image[image > 1] = 1 80 | return image 81 | 82 | def transform(point, center, scale, resolution, rotation=0, invert=False): 83 | _pt = np.ones(3) 84 | _pt[0] = point[0] 85 | _pt[1] = point[1] 86 | 87 | h = 200.0 * scale 88 | t = np.eye(3) 89 | t[0, 0] = resolution / h 90 | t[1, 1] = resolution / h 91 | t[0, 2] = resolution * (-center[0] / h + 0.5) 92 | t[1, 2] = resolution * (-center[1] / h + 0.5) 93 | 94 | if rotation != 0: 95 | rotation = -rotation 96 | r = np.eye(3) 97 | ang = rotation * math.pi / 180.0 98 | s = math.sin(ang) 99 | c = math.cos(ang) 100 | r[0][0] = c 101 | r[0][1] = -s 102 | r[1][0] = s 103 | r[1][1] = c 104 | 105 | t_ = np.eye(3) 106 | t_[0][2] = -resolution / 2.0 107 | t_[1][2] = -resolution / 2.0 108 | t_inv = torch.eye(3) 109 | t_inv[0][2] = resolution / 2.0 110 | t_inv[1][2] = resolution / 2.0 111 | t = reduce(np.matmul, [t_inv, r, t_, t]) 112 | 113 | if invert: 114 | t = np.linalg.inv(t) 115 | new_point = (np.matmul(t, _pt))[0:2] 116 | 117 | return new_point.astype(int) 118 | 119 | def cv_crop(image, landmarks, center, scale, resolution=256, center_shift=0): 120 | new_image = cv2.copyMakeBorder(image, center_shift, 121 | center_shift, 122 | center_shift, 123 | center_shift, 124 | cv2.BORDER_CONSTANT, value=[0,0,0]) 125 | new_landmarks = landmarks.copy() 126 | if center_shift != 0: 127 | center[0] += center_shift 128 | center[1] += center_shift 129 | new_landmarks = new_landmarks + center_shift 130 | length = 200 * scale 131 | top = int(center[1] - length // 2) 132 | bottom = int(center[1] + length // 2) 133 | left = int(center[0] - length // 2) 134 | right = int(center[0] + length // 2) 135 | y_pad = abs(min(top, new_image.shape[0] - bottom, 0)) 136 | x_pad = abs(min(left, new_image.shape[1] - right, 0)) 137 | top, bottom, left, right = top + y_pad, bottom + y_pad, left + x_pad, right + x_pad 138 | new_image = cv2.copyMakeBorder(new_image, y_pad, 139 | y_pad, 140 | x_pad, 141 | x_pad, 142 | cv2.BORDER_CONSTANT, value=[0,0,0]) 143 | new_image = new_image[top:bottom, left:right] 144 | new_image = cv2.resize(new_image, dsize=(int(resolution), int(resolution)), 145 | interpolation=cv2.INTER_LINEAR) 146 | new_landmarks[:, 0] = (new_landmarks[:, 0] + x_pad - left) * resolution / length 147 | new_landmarks[:, 1] = (new_landmarks[:, 1] + y_pad - top) * resolution / length 148 | return new_image, new_landmarks 149 | 150 | def cv_rotate(image, landmarks, heatmap, rot, scale, resolution=256): 151 | img_mat = cv2.getRotationMatrix2D((resolution//2, resolution//2), rot, scale) 152 | ones = np.ones(shape=(landmarks.shape[0], 1)) 153 | stacked_landmarks = np.hstack([landmarks, ones]) 154 | new_landmarks = img_mat.dot(stacked_landmarks.T).T 155 | if np.max(new_landmarks) > 255 or np.min(new_landmarks) < 0: 156 | return image, landmarks, heatmap 157 | else: 158 | new_image = cv2.warpAffine(image, img_mat, (resolution, resolution)) 159 | if heatmap is not None: 160 | new_heatmap = np.zeros((heatmap.shape[0], 64, 64)) 161 | for i in range(heatmap.shape[0]): 162 | if new_landmarks[i][0] > 0: 163 | new_heatmap[i] = draw_gaussian(new_heatmap[i], 164 | new_landmarks[i]/4.0+1, 1) 165 | return new_image, new_landmarks, new_heatmap 166 | 167 | def show_landmarks(image, heatmap, gt_landmarks, gt_heatmap): 168 | """Show image with pred_landmarks""" 169 | pred_landmarks = [] 170 | pred_landmarks, _ = get_preds_fromhm(torch.from_numpy(heatmap).unsqueeze(0)) 171 | pred_landmarks = pred_landmarks.squeeze()*4 172 | 173 | # pred_landmarks2 = get_preds_fromhm2(heatmap) 174 | heatmap = np.max(gt_heatmap, axis=0) 175 | heatmap = heatmap / np.max(heatmap) 176 | # image = ski_transform.resize(image, (64, 64))*255 177 | image = image.astype(np.uint8) 178 | heatmap = np.max(gt_heatmap, axis=0) 179 | heatmap = ski_transform.resize(heatmap, (image.shape[0], image.shape[1])) 180 | heatmap *= 255 181 | heatmap = heatmap.astype(np.uint8) 182 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) 183 | plt.imshow(image) 184 | plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker='.', c='g') 185 | plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker='.', c='r') 186 | plt.pause(0.001) # pause a bit so that plots are updated 187 | 188 | def fan_NME(pred_heatmaps, gt_landmarks, num_landmarks=68): 189 | ''' 190 | Calculate total NME for a batch of data 191 | 192 | Args: 193 | pred_heatmaps: torch tensor of size [batch, points, height, width] 194 | gt_landmarks: torch tesnsor of size [batch, points, x, y] 195 | 196 | Returns: 197 | nme: sum of nme for this batch 198 | ''' 199 | nme = 0 200 | pred_landmarks, _ = get_preds_fromhm(pred_heatmaps) 201 | pred_landmarks = pred_landmarks.numpy() 202 | gt_landmarks = gt_landmarks.numpy() 203 | for i in range(pred_landmarks.shape[0]): 204 | pred_landmark = pred_landmarks[i] * 4.0 205 | gt_landmark = gt_landmarks[i] 206 | 207 | if num_landmarks == 68: 208 | left_eye = np.average(gt_landmark[36:42], axis=0) 209 | right_eye = np.average(gt_landmark[42:48], axis=0) 210 | norm_factor = np.linalg.norm(left_eye - right_eye) 211 | # norm_factor = np.linalg.norm(gt_landmark[36]- gt_landmark[45]) 212 | elif num_landmarks == 98: 213 | norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72]) 214 | elif num_landmarks == 19: 215 | left, top = gt_landmark[-2, :] 216 | right, bottom = gt_landmark[-1, :] 217 | norm_factor = math.sqrt(abs(right - left)*abs(top-bottom)) 218 | gt_landmark = gt_landmark[:-2, :] 219 | elif num_landmarks == 29: 220 | # norm_factor = np.linalg.norm(gt_landmark[8]- gt_landmark[9]) 221 | norm_factor = np.linalg.norm(gt_landmark[16]- gt_landmark[17]) 222 | nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor 223 | return nme 224 | 225 | def fan_NME_hm(pred_heatmaps, gt_heatmaps, num_landmarks=68): 226 | ''' 227 | Calculate total NME for a batch of data 228 | 229 | Args: 230 | pred_heatmaps: torch tensor of size [batch, points, height, width] 231 | gt_landmarks: torch tesnsor of size [batch, points, x, y] 232 | 233 | Returns: 234 | nme: sum of nme for this batch 235 | ''' 236 | nme = 0 237 | pred_landmarks, _ = get_index_fromhm(pred_heatmaps) 238 | pred_landmarks = pred_landmarks.numpy() 239 | gt_landmarks = gt_landmarks.numpy() 240 | for i in range(pred_landmarks.shape[0]): 241 | pred_landmark = pred_landmarks[i] * 4.0 242 | gt_landmark = gt_landmarks[i] 243 | if num_landmarks == 68: 244 | left_eye = np.average(gt_landmark[36:42], axis=0) 245 | right_eye = np.average(gt_landmark[42:48], axis=0) 246 | norm_factor = np.linalg.norm(left_eye - right_eye) 247 | else: 248 | norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72]) 249 | nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor 250 | return nme 251 | 252 | def power_transform(img, power): 253 | img = np.array(img) 254 | img_new = np.power((img/255.0), power) * 255.0 255 | img_new = img_new.astype(np.uint8) 256 | img_new = Image.fromarray(img_new) 257 | return img_new 258 | 259 | def get_preds_fromhm(hm, center=None, scale=None, rot=None): 260 | max, idx = torch.max( 261 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 262 | idx += 1 263 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 264 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 265 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 266 | 267 | for i in range(preds.size(0)): 268 | for j in range(preds.size(1)): 269 | hm_ = hm[i, j, :] 270 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 271 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 272 | diff = torch.FloatTensor( 273 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 274 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 275 | preds[i, j].add_(diff.sign_().mul_(.25)) 276 | 277 | preds.add_(-0.5) 278 | 279 | preds_orig = torch.zeros(preds.size()) 280 | if center is not None and scale is not None: 281 | for i in range(hm.size(0)): 282 | for j in range(hm.size(1)): 283 | preds_orig[i, j] = transform( 284 | preds[i, j], center, scale, hm.size(2), rot, True) 285 | 286 | return preds, preds_orig 287 | 288 | def get_index_fromhm(hm): 289 | max, idx = torch.max( 290 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 291 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 292 | preds[..., 0].remainder_(hm.size(3)) 293 | preds[..., 1].div_(hm.size(2)).floor_() 294 | 295 | for i in range(preds.size(0)): 296 | for j in range(preds.size(1)): 297 | hm_ = hm[i, j, :] 298 | pX, pY = int(preds[i, j, 0]), int(preds[i, j, 1]) 299 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 300 | diff = torch.FloatTensor( 301 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 302 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 303 | preds[i, j].add_(diff.sign_().mul_(.25)) 304 | 305 | return preds 306 | 307 | def shuffle_lr(parts, num_landmarks=68, pairs=None): 308 | if num_landmarks == 68: 309 | if pairs is None: 310 | pairs = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], 311 | [7, 9], [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], [36, 45], 312 | [37, 44], [38, 43], [39, 42], [41, 46], [40, 47], [31, 35], [32, 34], 313 | [50, 52], [49, 53], [48, 54], [61, 63], [60, 64], [67, 65], [59, 55], [58, 56]] 314 | elif num_landmarks == 98: 315 | if pairs is None: 316 | pairs = [[0, 32], [1,31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22], [11, 21], [12, 20], [13, 19], [14, 18], [15, 17], [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73], [96, 97], [55, 59], [56, 58], [76, 82], [77, 81], [78, 80], [88, 92], [89, 91], [95, 93], [87, 83], [86, 84]] 317 | elif num_landmarks == 19: 318 | if pairs is None: 319 | pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9], [12, 14], [15, 17]] 320 | elif num_landmarks == 29: 321 | if pairs is None: 322 | pairs = [[0, 1], [4, 6], [5, 7], [2, 3], [8, 9], [12, 14], [16, 17], [13, 15], [10, 11], [18, 19], [22, 23]] 323 | for matched_p in pairs: 324 | idx1, idx2 = matched_p[0], matched_p[1] 325 | tmp = np.copy(parts[idx1]) 326 | np.copyto(parts[idx1], parts[idx2]) 327 | np.copyto(parts[idx2], tmp) 328 | return parts 329 | 330 | 331 | def generate_weight_map(weight_map,heatmap): 332 | 333 | k_size = 3 334 | dilate = ndimage.grey_dilation(heatmap ,size=(k_size,k_size)) 335 | weight_map[np.where(dilate>0.2)] = 1 336 | return weight_map 337 | 338 | def fig2data(fig): 339 | """ 340 | @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it 341 | @param fig a matplotlib figure 342 | @return a numpy 3D array of RGBA values 343 | """ 344 | # draw the renderer 345 | fig.canvas.draw ( ) 346 | 347 | # Get the RGB buffer from the figure 348 | w,h = fig.canvas.get_width_height() 349 | buf = np.fromstring (fig.canvas.tostring_rgb(), dtype=np.uint8) 350 | buf.shape = (w, h, 3) 351 | 352 | # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode 353 | buf = np.roll (buf, 3, axis=2) 354 | return buf 355 | -------------------------------------------------------------------------------- /core/dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import random 4 | import glob 5 | import torch 6 | from skimage import io 7 | from skimage import transform as ski_transform 8 | from skimage.color import rgb2gray 9 | import scipy.io as sio 10 | from scipy import interpolate 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import transforms, utils 15 | from torchvision.transforms import Lambda, Compose 16 | from torchvision.transforms.functional import adjust_brightness, adjust_contrast, adjust_saturation, adjust_hue 17 | from utils.utils import cv_crop, cv_rotate, draw_gaussian, transform, power_transform, shuffle_lr, fig2data, generate_weight_map 18 | from PIL import Image 19 | import cv2 20 | import copy 21 | import math 22 | from imgaug import augmenters as iaa 23 | 24 | 25 | class AddBoundary(object): 26 | def __init__(self, num_landmarks=68): 27 | self.num_landmarks = num_landmarks 28 | 29 | def __call__(self, sample): 30 | landmarks_64 = np.floor(sample['landmarks'] / 4.0) 31 | if self.num_landmarks == 68: 32 | boundaries = {} 33 | boundaries['cheek'] = landmarks_64[0:17] 34 | boundaries['left_eyebrow'] = landmarks_64[17:22] 35 | boundaries['right_eyebrow'] = landmarks_64[22:27] 36 | boundaries['uper_left_eyelid'] = landmarks_64[36:40] 37 | boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [36, 41, 40, 39]]) 38 | boundaries['upper_right_eyelid'] = landmarks_64[42:46] 39 | boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [42, 47, 46, 45]]) 40 | boundaries['noise'] = landmarks_64[27:31] 41 | boundaries['noise_bot'] = landmarks_64[31:36] 42 | boundaries['upper_outer_lip'] = landmarks_64[48:55] 43 | boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [60, 61, 62, 63, 64]]) 44 | boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [48, 59, 58, 57, 56, 55, 54]]) 45 | boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]]) 46 | elif self.num_landmarks == 98: 47 | boundaries = {} 48 | boundaries['cheek'] = landmarks_64[0:33] 49 | boundaries['left_eyebrow'] = landmarks_64[33:38] 50 | boundaries['right_eyebrow'] = landmarks_64[42:47] 51 | boundaries['uper_left_eyelid'] = landmarks_64[60:65] 52 | boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]]) 53 | boundaries['upper_right_eyelid'] = landmarks_64[68:73] 54 | boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [68, 75, 74, 73, 72]]) 55 | boundaries['noise'] = landmarks_64[51:55] 56 | boundaries['noise_bot'] = landmarks_64[55:60] 57 | boundaries['upper_outer_lip'] = landmarks_64[76:83] 58 | boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [88, 89, 90, 91, 92]]) 59 | boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [76, 87, 86, 85, 84, 83, 82]]) 60 | boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [88, 95, 94, 93, 92]]) 61 | elif self.num_landmarks == 19: 62 | boundaries = {} 63 | boundaries['left_eyebrow'] = landmarks_64[0:3] 64 | boundaries['right_eyebrow'] = landmarks_64[3:5] 65 | boundaries['left_eye'] = landmarks_64[6:9] 66 | boundaries['right_eye'] = landmarks_64[9:12] 67 | boundaries['noise'] = landmarks_64[12:15] 68 | 69 | elif self.num_landmarks == 29: 70 | boundaries = {} 71 | boundaries['upper_left_eyebrow'] = np.stack([ 72 | landmarks_64[0], 73 | landmarks_64[4], 74 | landmarks_64[2] 75 | ], axis=0) 76 | boundaries['lower_left_eyebrow'] = np.stack([ 77 | landmarks_64[0], 78 | landmarks_64[5], 79 | landmarks_64[2] 80 | ], axis=0) 81 | boundaries['upper_right_eyebrow'] = np.stack([ 82 | landmarks_64[1], 83 | landmarks_64[6], 84 | landmarks_64[3] 85 | ], axis=0) 86 | boundaries['lower_right_eyebrow'] = np.stack([ 87 | landmarks_64[1], 88 | landmarks_64[7], 89 | landmarks_64[3] 90 | ], axis=0) 91 | boundaries['upper_left_eye'] = np.stack([ 92 | landmarks_64[8], 93 | landmarks_64[12], 94 | landmarks_64[10] 95 | ], axis=0) 96 | boundaries['lower_left_eye'] = np.stack([ 97 | landmarks_64[8], 98 | landmarks_64[13], 99 | landmarks_64[10] 100 | ], axis=0) 101 | boundaries['upper_right_eye'] = np.stack([ 102 | landmarks_64[9], 103 | landmarks_64[14], 104 | landmarks_64[11] 105 | ], axis=0) 106 | boundaries['lower_right_eye'] = np.stack([ 107 | landmarks_64[9], 108 | landmarks_64[15], 109 | landmarks_64[11] 110 | ], axis=0) 111 | boundaries['noise'] = np.stack([ 112 | landmarks_64[18], 113 | landmarks_64[21], 114 | landmarks_64[19] 115 | ], axis=0) 116 | boundaries['outer_upper_lip'] = np.stack([ 117 | landmarks_64[22], 118 | landmarks_64[24], 119 | landmarks_64[23] 120 | ], axis=0) 121 | boundaries['inner_upper_lip'] = np.stack([ 122 | landmarks_64[22], 123 | landmarks_64[25], 124 | landmarks_64[23] 125 | ], axis=0) 126 | boundaries['outer_lower_lip'] = np.stack([ 127 | landmarks_64[22], 128 | landmarks_64[26], 129 | landmarks_64[23] 130 | ], axis=0) 131 | boundaries['inner_lower_lip'] = np.stack([ 132 | landmarks_64[22], 133 | landmarks_64[27], 134 | landmarks_64[23] 135 | ], axis=0) 136 | functions = {} 137 | 138 | for key, points in boundaries.items(): 139 | temp = points[0] 140 | new_points = points[0:1, :] 141 | for point in points[1:]: 142 | if point[0] == temp[0] and point[1] == temp[1]: 143 | continue 144 | else: 145 | new_points = np.concatenate((new_points, np.expand_dims(point, 0)), axis=0) 146 | temp = point 147 | points = new_points 148 | if points.shape[0] == 1: 149 | points = np.concatenate((points, points+0.001), axis=0) 150 | k = min(4, points.shape[0]) 151 | functions[key] = interpolate.splprep([points[:, 0], points[:, 1]], k=k-1,s=0) 152 | 153 | boundary_map = np.zeros((64, 64)) 154 | 155 | fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96) 156 | 157 | ax = fig.add_axes([0, 0, 1, 1]) 158 | 159 | ax.axis('off') 160 | 161 | ax.imshow(boundary_map, interpolation='nearest', cmap='gray') 162 | #ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1, marker=',', c='w') 163 | 164 | for key in functions.keys(): 165 | xnew = np.arange(0, 1, 0.01) 166 | out = interpolate.splev(xnew, functions[key][0], der=0) 167 | plt.plot(out[0], out[1], ',', linewidth=1, color='w') 168 | 169 | img = fig2data(fig) 170 | 171 | plt.close() 172 | 173 | sigma = 1 174 | temp = 255-img[:,:,1] 175 | temp = cv2.distanceTransform(temp, cv2.DIST_L2, cv2.DIST_MASK_PRECISE) 176 | temp = temp.astype(np.float32) 177 | temp = np.where(temp < 3*sigma, np.exp(-(temp*temp)/(2*sigma*sigma)), 0 ) 178 | 179 | fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96) 180 | 181 | ax = fig.add_axes([0, 0, 1, 1]) 182 | 183 | ax.axis('off') 184 | ax.imshow(temp, cmap='gray') 185 | plt.close() 186 | 187 | boundary_map = fig2data(fig) 188 | 189 | sample['boundary'] = boundary_map[:, :, 0] 190 | 191 | return sample 192 | 193 | class AddWeightMap(object): 194 | def __call__(self, sample): 195 | heatmap= sample['heatmap'] 196 | boundary = sample['boundary'] 197 | heatmap = np.concatenate((heatmap, np.expand_dims(boundary, axis=0)), 0) 198 | weight_map = np.zeros_like(heatmap) 199 | for i in range(heatmap.shape[0]): 200 | weight_map[i] = generate_weight_map(weight_map[i], 201 | heatmap[i]) 202 | sample['weight_map'] = weight_map 203 | return sample 204 | 205 | class ToTensor(object): 206 | """Convert ndarrays in sample to Tensors.""" 207 | 208 | def __call__(self, sample): 209 | image, heatmap, landmarks, boundary, weight_map= sample['image'], sample['heatmap'], sample['landmarks'], sample['boundary'], sample['weight_map'] 210 | 211 | # swap color axis because 212 | # numpy image: H x W x C 213 | # torch image: C X H X W 214 | if len(image.shape) == 2: 215 | image = np.expand_dims(image, axis=2) 216 | image_small = np.expand_dims(image_small, axis=2) 217 | image = image.transpose((2, 0, 1)) 218 | boundary = np.expand_dims(boundary, axis=2) 219 | boundary = boundary.transpose((2, 0, 1)) 220 | return {'image': torch.from_numpy(image).float().div(255.0), 221 | 'heatmap': torch.from_numpy(heatmap).float(), 222 | 'landmarks': torch.from_numpy(landmarks).float(), 223 | 'boundary': torch.from_numpy(boundary).float().div(255.0), 224 | 'weight_map': torch.from_numpy(weight_map).float()} 225 | 226 | class FaceLandmarksDataset(Dataset): 227 | """Face Landmarks dataset.""" 228 | 229 | def __init__(self, img_dir, landmarks_dir, num_landmarks=68, gray_scale=False, 230 | detect_face=False, enhance=False, center_shift=0, 231 | transform=None,): 232 | """ 233 | Args: 234 | landmark_dir (string): Path to the mat file with landmarks saved. 235 | img_dir (string): Directory with all the images. 236 | transform (callable, optional): Optional transform to be applied 237 | on a sample. 238 | """ 239 | self.img_dir = img_dir 240 | self.landmarks_dir = landmarks_dir 241 | self.num_lanmdkars = num_landmarks 242 | self.transform = transform 243 | self.img_names = glob.glob(self.img_dir+'*.jpg') + \ 244 | glob.glob(self.img_dir+'*.png') 245 | self.gray_scale = gray_scale 246 | self.detect_face = detect_face 247 | self.enhance = enhance 248 | self.center_shift = center_shift 249 | if self.detect_face: 250 | self.face_detector = MTCNN(thresh=[0.5, 0.6, 0.7]) 251 | def __len__(self): 252 | return len(self.img_names) 253 | 254 | def __getitem__(self, idx): 255 | img_name = self.img_names[idx] 256 | pil_image = Image.open(img_name) 257 | if pil_image.mode != "RGB": 258 | # if input is grayscale image, convert it to 3 channel image 259 | if self.enhance: 260 | pil_image = power_transform(pil_image, 0.5) 261 | temp_image = Image.new('RGB', pil_image.size) 262 | temp_image.paste(pil_image) 263 | pil_image = temp_image 264 | image = np.array(pil_image) 265 | if self.gray_scale: 266 | image = rgb2gray(image) 267 | image = np.expand_dims(image, axis=2) 268 | image = np.concatenate((image, image, image), axis=2) 269 | image = image * 255.0 270 | image = image.astype(np.uint8) 271 | if not self.detect_face: 272 | center = [450//2, 450//2+0] 273 | if self.center_shift != 0: 274 | center[0] += int(np.random.uniform(-self.center_shift, 275 | self.center_shift)) 276 | center[1] += int(np.random.uniform(-self.center_shift, 277 | self.center_shift)) 278 | scale = 1.8 279 | else: 280 | detected_faces = self.face_detector.detect_image(image) 281 | if len(detected_faces) > 0: 282 | box = detected_faces[0] 283 | left, top, right, bottom, _ = box 284 | center = [right - (right - left) / 2.0, 285 | bottom - (bottom - top) / 2.0] 286 | center[1] = center[1] - (bottom - top) * 0.12 287 | scale = (right - left + bottom - top) / 195.0 288 | else: 289 | center = [450//2, 450//2+0] 290 | scale = 1.8 291 | if self.center_shift != 0: 292 | shift = self.center * self.center_shift / 450 293 | center[0] += int(np.random.uniform(-shift, shift)) 294 | center[1] += int(np.random.uniform(-shift, shift)) 295 | base_name = os.path.basename(img_name) 296 | landmarks_base_name = base_name[:-4] + '_pts.mat' 297 | landmarks_name = os.path.join(self.landmarks_dir, landmarks_base_name) 298 | if os.path.isfile(landmarks_name): 299 | mat_data = sio.loadmat(landmarks_name) 300 | landmarks = mat_data['pts_2d'] 301 | elif os.path.isfile(landmarks_name[:-8] + '.pts.npy'): 302 | landmarks = np.load(landmarks_name[:-8] + '.pts.npy') 303 | else: 304 | landmarks = [] 305 | heatmap = [] 306 | 307 | if landmarks != []: 308 | new_image, new_landmarks = cv_crop(image, landmarks, center, 309 | scale, 256, self.center_shift) 310 | tries = 0 311 | while self.center_shift != 0 and tries < 5 and (np.max(new_landmarks) > 240 or np.min(new_landmarks) < 15): 312 | center = [450//2, 450//2+0] 313 | scale += 0.05 314 | center[0] += int(np.random.uniform(-self.center_shift, 315 | self.center_shift)) 316 | center[1] += int(np.random.uniform(-self.center_shift, 317 | self.center_shift)) 318 | 319 | new_image, new_landmarks = cv_crop(image, landmarks, 320 | center, scale, 256, 321 | self.center_shift) 322 | tries += 1 323 | if np.max(new_landmarks) > 250 or np.min(new_landmarks) < 5: 324 | center = [450//2, 450//2+0] 325 | scale = 2.25 326 | new_image, new_landmarks = cv_crop(image, landmarks, 327 | center, scale, 256, 328 | 100) 329 | assert (np.min(new_landmarks) > 0 and np.max(new_landmarks) < 256), \ 330 | "Landmarks out of boundary!" 331 | image = new_image 332 | landmarks = new_landmarks 333 | heatmap = np.zeros((self.num_lanmdkars, 64, 64)) 334 | for i in range(self.num_lanmdkars): 335 | if landmarks[i][0] > 0: 336 | heatmap[i] = draw_gaussian(heatmap[i], landmarks[i]/4.0+1, 1) 337 | sample = {'image': image, 'heatmap': heatmap, 'landmarks': landmarks} 338 | if self.transform: 339 | sample = self.transform(sample) 340 | 341 | return sample 342 | 343 | def get_dataset(val_img_dir, val_landmarks_dir, batch_size, 344 | num_landmarks=68, rotation=0, scale=0, 345 | center_shift=0, random_flip=False, 346 | brightness=0, contrast=0, saturation=0, 347 | blur=False, noise=False, jpeg_effect=False, 348 | random_occlusion=False, gray_scale=False, 349 | detect_face=False, enhance=False): 350 | val_transforms = transforms.Compose([AddBoundary(num_landmarks), 351 | AddWeightMap(), 352 | ToTensor()]) 353 | 354 | val_dataset = FaceLandmarksDataset(val_img_dir, val_landmarks_dir, 355 | num_landmarks=num_landmarks, 356 | gray_scale=gray_scale, 357 | detect_face=detect_face, 358 | enhance=enhance, 359 | transform=val_transforms) 360 | 361 | val_dataloader = torch.utils.data.DataLoader(val_dataset, 362 | batch_size=batch_size, 363 | shuffle=False, 364 | num_workers=6) 365 | data_loaders = {'val': val_dataloader} 366 | dataset_sizes = {} 367 | dataset_sizes['val'] = len(val_dataset) 368 | return data_loaders, dataset_sizes 369 | --------------------------------------------------------------------------------