├── __init__.py
├── ckpt
│   └── .gitkeep
├── core
│   ├── __init__.py
│   ├── evaler.py
│   ├── coord_conv.py
│   ├── models.py
│   └── dataloader.py
├── utils
│   ├── __init__.py
│   └── utils.py
├── images
│   ├── wflw.png
│   └── wflw_table.png
├── .gitignore
├── requirements.txt
├── scripts
│   └── eval_wflw.sh
├── README.md
├── eval.py
├── dataset
│   └── convert_WFLW.py
└── LICENSE
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ckpt/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/images/wflw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protossw512/AdaptiveWingLoss/HEAD/images/wflw.png
--------------------------------------------------------------------------------
/images/wflw_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/protossw512/AdaptiveWingLoss/HEAD/images/wflw_table.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python generated files
2 | *.pyc
3 |
4 | # Project related files
5 | ckpt/*.pth
6 | dataset/*
7 | !dataset/*.py
8 | experiments/*
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | scipy>=0.17.0
3 | scikit-image
4 | numpy
5 | matplotlib
6 | Pillow>=4.3.0
7 | imgaug
8 | tensorflow
9 | git+https://github.com/lanpa/tensorboardX
10 | joblib
11 | torch==1.3.0
12 | torchvision==0.4.1
13 |
--------------------------------------------------------------------------------
/scripts/eval_wflw.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python ../eval.py \
2 | --val_img_dir='../dataset/WFLW_test/images/' \
3 | --val_landmarks_dir='../dataset/WFLW_test/landmarks/' \
4 | --ckpt_save_path='../experiments/eval_iccv_0620' \
5 | --hg_blocks=4 \
6 | --pretrained_weights='../ckpt/WFLW_4HG.pth' \
7 | --num_landmarks=98 \
8 | --end_relu='False' \
9 | --batch_size=20
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaptiveWingLoss
2 | ## [arXiv](https://arxiv.org/abs/1904.07399)
3 | PyTorch implementation of Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression (ICCV 2019).
4 |
5 | ![WFLW sample results](images/wflw.png)
6 |
7 | ## Update Logs:
8 | ### October 28, 2019
9 | * Pretrained model and evaluation code on the WFLW dataset are released.
10 |
11 | ## Installation
12 | #### Note: The code was originally developed under Python 2.x and PyTorch 0.4. This released version was revised from the original code and tested on Python 3.5.7 and PyTorch 1.3.0.
13 |
14 | Install system requirements:
15 | ```
16 | sudo apt-get install python3-dev python3-pip python3-tk libglib2.0-0
17 | ```
18 |
19 | Install python dependencies:
20 | ```
21 | pip3 install -r requirements.txt
22 | ```
23 |
24 | ## Run Evaluation on the WFLW Dataset
25 | 1. Download and process the WFLW dataset
26 | * Download the WFLW dataset and annotations from [here](https://wywu.github.io/projects/LAB/WFLW.html).
27 | * Unzip the WFLW dataset and annotations and move the files into the ```./dataset``` directory. Your directory should look like this:
28 | ```
29 | AdaptiveWingLoss
30 | └───dataset
31 | │
32 | └───WFLW_annotations
33 | │ └───list_98pt_rect_attr_train_test
34 | │ │
35 | │ └───list_98pt_test
36 | │
37 | └───WFLW_images
38 | └───0--Parade
39 | │
40 | └───...
41 | ```
42 | * Inside the ```./dataset``` directory, run:
43 | ```
44 | python convert_WFLW.py
45 | ```
46 | A new directory ```./dataset/WFLW_test``` should be generated with 2500 processed testing images and corresponding landmarks.
47 |
48 | 2. Download the pretrained model from [Google Drive](https://drive.google.com/file/d/1HZaSjLoorQ4QCEx7PRTxOmg0bBPYSqhH/view?usp=sharing) and put it in the ```./ckpt``` directory.
49 |
50 | 3. Within the ```./scripts``` directory, run the following command:
51 | ```
52 | sh eval_wflw.sh
53 | ```
54 |
55 | ![WFLW evaluation results](images/wflw_table.png)
56 |
57 | *GTBbox indicates that the ground-truth landmarks are used as the bounding box to crop faces.
57 |
58 | ## Future Plans
59 | - [x] Release evaluation code and pretrained model on WFLW dataset.
60 |
61 | - [ ] Release training code on WFLW dataset.
62 |
63 | - [ ] Release pretrained models and code on the 300W, AFLW, and COFW datasets.
64 |
65 | - [ ] Release facial landmark detection API.
66 |
67 |
68 | ## Citation
69 | If you find this useful for your research, please cite the following paper.
70 |
71 | ```
72 | @InProceedings{Wang_2019_ICCV,
73 | author = {Wang, Xinyao and Bo, Liefeng and Fuxin, Li},
74 | title = {Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression},
75 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
76 | month = {October},
77 | year = {2019}
78 | }
79 | ```
80 |
81 | ## Acknowledgments
82 | This repository borrows or partially modifies the hourglass model and data processing code from [face alignment](https://github.com/1adrianb/face-alignment) and [pose-hg-train](https://github.com/princeton-vl/pose-hg-train).
83 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import torch
3 | import argparse
4 | import numpy as np
5 | import torch.nn as nn
6 | import time
7 | import os
8 | from core.evaler import eval_model
9 | from core.dataloader import get_dataset
10 | from core import models
11 | from tensorboardX import SummaryWriter
12 |
13 | # Parse arguments
14 | parser = argparse.ArgumentParser()
15 | # Dataset paths
16 | parser.add_argument('--val_img_dir', type=str,
17 | help='Validation image directory')
18 | parser.add_argument('--val_landmarks_dir', type=str,
19 | help='Validation landmarks directory')
20 | parser.add_argument('--num_landmarks', type=int, default=68,
21 | help='Number of landmarks')
22 |
23 | # Checkpoint and pretrained weights
24 | parser.add_argument('--ckpt_save_path', type=str,
25 | help='a directory to save checkpoint file')
26 | parser.add_argument('--pretrained_weights', type=str,
27 |                     help='Path to the pretrained weights file')
28 |
29 | # Eval options
30 | parser.add_argument('--batch_size', type=int, default=25,
31 |                     help='Batch size for evaluation')
32 |
33 | # Network parameters
34 | parser.add_argument('--hg_blocks', type=int, default=4,
35 | help='Number of HG blocks to stack')
36 | parser.add_argument('--gray_scale', type=str, default="False",
37 | help='Whether to convert RGB image into gray scale during training')
38 | parser.add_argument('--end_relu', type=str, default="False",
39 | help='Whether to add relu at the end of each HG module')
40 |
41 | args = parser.parse_args()
42 |
43 | VAL_IMG_DIR = args.val_img_dir
44 | VAL_LANDMARKS_DIR = args.val_landmarks_dir
45 | CKPT_SAVE_PATH = args.ckpt_save_path
46 | BATCH_SIZE = args.batch_size
47 | PRETRAINED_WEIGHTS = args.pretrained_weights
48 | GRAY_SCALE = False if args.gray_scale == 'False' else True
49 | HG_BLOCKS = args.hg_blocks
50 | END_RELU = False if args.end_relu == 'False' else True
51 | NUM_LANDMARKS = args.num_landmarks
52 |
53 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
54 |
55 | writer = SummaryWriter(CKPT_SAVE_PATH)
56 |
57 | dataloaders, dataset_sizes = get_dataset(VAL_IMG_DIR, VAL_LANDMARKS_DIR,
58 | BATCH_SIZE, NUM_LANDMARKS)
59 | use_gpu = torch.cuda.is_available()
60 | model_ft = models.FAN(HG_BLOCKS, END_RELU, GRAY_SCALE, NUM_LANDMARKS)
61 |
62 | if PRETRAINED_WEIGHTS != "None":
63 | checkpoint = torch.load(PRETRAINED_WEIGHTS)
64 | if 'state_dict' not in checkpoint:
65 | model_ft.load_state_dict(checkpoint)
66 | else:
67 | pretrained_weights = checkpoint['state_dict']
68 | model_weights = model_ft.state_dict()
69 | pretrained_weights = {k: v for k, v in pretrained_weights.items() \
70 | if k in model_weights}
71 | model_weights.update(pretrained_weights)
72 | model_ft.load_state_dict(model_weights)
73 |
74 | model_ft = model_ft.to(device)
75 |
76 | model_ft = eval_model(model_ft, dataloaders, dataset_sizes, writer, use_gpu, 1, 'val', CKPT_SAVE_PATH, NUM_LANDMARKS)
77 |
78 |
--------------------------------------------------------------------------------
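For orientation, here is a minimal single-image inference sketch built from the pieces above. It is not part of the repository: the image path is hypothetical, the preprocessing only approximates core/dataloader.py, and it assumes the checkpoint is a plain state_dict (eval.py above also handles the 'state_dict'-keyed case).

```python
# Hypothetical sketch -- assumes a pre-cropped 256x256 face image.
import torch
from skimage import io
from skimage import transform as ski_transform
from core import models
from utils.utils import get_preds_fromhm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = models.FAN(4, False, False, 98)  # 4 HG blocks, 98 WFLW landmarks
model.load_state_dict(torch.load('ckpt/WFLW_4HG.pth', map_location=device))
model = model.to(device).eval()

img = io.imread('face.png')                                    # hypothetical input
img = ski_transform.resize(img, (256, 256)).astype('float32')  # resize also rescales to [0, 1]
inp = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)

with torch.no_grad():
    outputs, _ = model(inp)                     # one heatmap tensor per HG stack
heatmap = outputs[-1][:, :-1, :, :].cpu()       # drop the extra boundary channel
landmarks, _ = get_preds_fromhm(heatmap)
landmarks = landmarks.squeeze(0).numpy() * 4.0  # 64x64 heatmap -> 256x256 crop
```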
/core/evaler.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import math
4 | import torch
5 | import copy
6 | import time
7 | from torch.autograd import Variable
8 | import shutil
9 | from skimage import io
10 | import numpy as np
11 | from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
12 | from PIL import Image, ImageDraw
13 | import os
14 | import sys
15 | import cv2
16 | import matplotlib.pyplot as plt
17 |
18 |
19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20 |
21 | def eval_model(model, dataloaders, dataset_sizes,
22 |                writer, use_gpu=True, epochs=5, dataset='val',
23 |                save_path='./', num_landmarks=68):
24 | global_nme = 0
25 | model.eval()
26 |     for epoch in range(epochs):
27 |         running_loss = 0
28 |         step = 0
29 |         total_nme = 0
30 |         total_count = 0
31 |         fail_count = 0
32 |         nmes = []
33 |         # accumulate runtime across batches (not reset per step)
34 |         total_runtime = 0
35 |         run_count = 0
33 | # running_corrects = 0
34 |
35 | # Iterate over data.
36 | with torch.no_grad():
37 | for data in dataloaders[dataset]:
40 | step_start = time.time()
41 | step += 1
42 | # get the inputs
43 | inputs = data['image'].type(torch.FloatTensor)
44 | labels_heatmap = data['heatmap'].type(torch.FloatTensor)
45 | labels_boundary = data['boundary'].type(torch.FloatTensor)
46 | landmarks = data['landmarks'].type(torch.FloatTensor)
47 | loss_weight_map = data['weight_map'].type(torch.FloatTensor)
48 | # wrap them in Variable
49 | if use_gpu:
50 | inputs = inputs.to(device)
51 | labels_heatmap = labels_heatmap.to(device)
52 | labels_boundary = labels_boundary.to(device)
53 | loss_weight_map = loss_weight_map.to(device)
54 | else:
55 | inputs, labels_heatmap = Variable(inputs), Variable(labels_heatmap)
56 | labels_boundary = Variable(labels_boundary)
57 | labels = torch.cat((labels_heatmap, labels_boundary), 1)
58 | single_start = time.time()
59 | outputs, boundary_channels = model(inputs)
60 | single_end = time.time()
61 | total_runtime += time.time() - single_start
62 | run_count += 1
63 | step_end = time.time()
64 | for i in range(inputs.shape[0]):
65 | img = inputs[i]
66 | img = img.cpu().numpy()
67 | img = img.transpose((1, 2, 0))*255.0
68 | img = img.astype(np.uint8)
69 | img = Image.fromarray(img)
70 | # pred_heatmap = outputs[-1][i].detach().cpu()[:-1, :, :]
71 | pred_heatmap = outputs[-1][:, :-1, :, :][i].detach().cpu()
72 | pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
73 | pred_landmarks = pred_landmarks.squeeze().numpy()
74 |
75 | gt_landmarks = data['landmarks'][i].numpy()
76 | if num_landmarks == 68:
77 | left_eye = np.average(gt_landmarks[36:42], axis=0)
78 | right_eye = np.average(gt_landmarks[42:48], axis=0)
79 | norm_factor = np.linalg.norm(left_eye - right_eye)
80 | # norm_factor = np.linalg.norm(gt_landmarks[36]- gt_landmarks[45])
81 |
82 | elif num_landmarks == 98:
83 | norm_factor = np.linalg.norm(gt_landmarks[60]- gt_landmarks[72])
84 | elif num_landmarks == 19:
85 | left, top = gt_landmarks[-2, :]
86 | right, bottom = gt_landmarks[-1, :]
87 | norm_factor = math.sqrt(abs(right - left)*abs(top-bottom))
88 | gt_landmarks = gt_landmarks[:-2, :]
89 | elif num_landmarks == 29:
90 | # norm_factor = np.linalg.norm(gt_landmarks[8]- gt_landmarks[9])
91 | norm_factor = np.linalg.norm(gt_landmarks[16]- gt_landmarks[17])
92 | single_nme = (np.sum(np.linalg.norm(pred_landmarks*4 - gt_landmarks, axis=1)) / pred_landmarks.shape[0]) / norm_factor
93 |
94 | nmes.append(single_nme)
95 | total_count += 1
96 | if single_nme > 0.1:
97 | fail_count += 1
98 | if step % 10 == 0:
99 | print('Step {} Time: {:.6f} Input Mean: {:.6f} Output Mean: {:.6f}'.format(
100 | step, step_end - step_start,
101 | torch.mean(labels),
102 | torch.mean(outputs[0])))
103 | # gt_landmarks = landmarks.numpy()
104 | # pred_heatmap = outputs[-1].to('cpu').numpy()
105 | gt_landmarks = landmarks
106 | batch_nme = fan_NME(outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
107 | # batch_nme = 0
108 | total_nme += batch_nme
109 | epoch_nme = total_nme / dataset_sizes['val']
110 | global_nme += epoch_nme
111 | nme_save_path = os.path.join(save_path, 'nme_log.npy')
112 | np.save(nme_save_path, np.array(nmes))
113 |         print('NME: {:.6f} Failure Rate: {:.6f} Total Count: {:d} Fail Count: {:d}'.format(epoch_nme, fail_count/total_count, total_count, fail_count))
114 |     print('Evaluation done! Average NME: {:.6f}'.format(global_nme/epochs))
115 |     print('Average runtime for a single batch: {:.6f}'.format(total_runtime/run_count))
116 | return model
117 |
--------------------------------------------------------------------------------
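The per-image normalization in eval_model depends on the landmark layout: the inter-pupil distance for 68-point data, the outer-eye-corner distance between points 60 and 72 for 98-point WFLW, and the geometric mean of the bounding-box sides for 19-point data. A minimal sketch of the metric itself, with placeholder arrays standing in for real predictions (assumed already scaled to image coordinates, as done with the *4 factor above):

```python
import numpy as np

def nme(pred, gt, norm_factor):
    """Mean per-point Euclidean error divided by a normalization distance,
    matching the per-image computation in eval_model above."""
    return np.mean(np.linalg.norm(pred - gt, axis=1)) / norm_factor

gt = np.random.rand(98, 2) * 256        # placeholder ground-truth landmarks
pred = gt + np.random.randn(98, 2)      # placeholder predictions
norm = np.linalg.norm(gt[60] - gt[72])  # WFLW: outer eye corners
print('NME: {:.6f}'.format(nme(pred, gt, norm)))
```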
/core/coord_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AddCoordsTh(nn.Module):
6 | def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
7 | super(AddCoordsTh, self).__init__()
8 | self.x_dim = x_dim
9 | self.y_dim = y_dim
10 | self.with_r = with_r
11 | self.with_boundary = with_boundary
12 |
13 | def forward(self, input_tensor, heatmap=None):
14 | """
15 | input_tensor: (batch, c, x_dim, y_dim)
16 | """
17 |         batch_size_tensor = input_tensor.shape[0]
18 |         # create coordinate tensors on the input's device so the module
19 |         # also runs on CPU (eval.py falls back to CPU when CUDA is absent)
20 |         device = input_tensor.device
21 |
22 |         xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=device)
23 |         xx_ones = xx_ones.unsqueeze(-1)
24 |
25 |         xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=device).unsqueeze(0)
23 | xx_range = xx_range.unsqueeze(1)
24 |
25 | xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
26 | xx_channel = xx_channel.unsqueeze(-1)
27 |
28 |
29 |         yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=device)
30 |         yy_ones = yy_ones.unsqueeze(1)
31 |
32 |         yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=device).unsqueeze(0)
33 | yy_range = yy_range.unsqueeze(-1)
34 |
35 | yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
36 | yy_channel = yy_channel.unsqueeze(-1)
37 |
38 | xx_channel = xx_channel.permute(0, 3, 2, 1)
39 | yy_channel = yy_channel.permute(0, 3, 2, 1)
40 |
41 | xx_channel = xx_channel / (self.x_dim - 1)
42 | yy_channel = yy_channel / (self.y_dim - 1)
43 |
44 | xx_channel = xx_channel * 2 - 1
45 | yy_channel = yy_channel * 2 - 1
46 |
47 | xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
48 | yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
49 |
50 |         if self.with_boundary and heatmap is not None:
51 | boundary_channel = torch.clamp(heatmap[:, -1:, :, :],
52 | 0.0, 1.0)
53 |
54 | zero_tensor = torch.zeros_like(xx_channel)
55 | xx_boundary_channel = torch.where(boundary_channel>0.05,
56 | xx_channel, zero_tensor)
57 | yy_boundary_channel = torch.where(boundary_channel>0.05,
58 | yy_channel, zero_tensor)
62 | ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
63 |
64 |
65 | if self.with_r:
66 | rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
67 | rr = rr / torch.max(rr)
68 | ret = torch.cat([ret, rr], dim=1)
69 |
70 |         if self.with_boundary and heatmap is not None:
71 | ret = torch.cat([ret, xx_boundary_channel,
72 | yy_boundary_channel], dim=1)
73 | return ret
74 |
75 |
76 | class CoordConvTh(nn.Module):
77 | """CoordConv layer as in the paper."""
78 | def __init__(self, x_dim, y_dim, with_r, with_boundary,
79 | in_channels, first_one=False, *args, **kwargs):
80 | super(CoordConvTh, self).__init__()
81 | self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
82 | with_boundary=with_boundary)
83 | in_channels += 2
84 | if with_r:
85 | in_channels += 1
86 | if with_boundary and not first_one:
87 | in_channels += 2
88 | self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
89 |
90 | def forward(self, input_tensor, heatmap=None):
91 | ret = self.addcoords(input_tensor, heatmap)
92 | last_channel = ret[:, -2:, :, :]
93 | ret = self.conv(ret)
94 | return ret, last_channel
95 |
96 |
97 | '''
98 | An alternative implementation for PyTorch that auto-infers the x-y dimensions.
99 | '''
100 | class AddCoords(nn.Module):
101 |
102 | def __init__(self, with_r=False):
103 | super().__init__()
104 | self.with_r = with_r
105 |
106 | def forward(self, input_tensor):
107 | """
108 | Args:
109 | input_tensor: shape(batch, channel, x_dim, y_dim)
110 | """
111 | batch_size, _, x_dim, y_dim = input_tensor.size()
112 |
113 |         xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
114 |         yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
115 |
116 |         # cast to float before normalizing to avoid integer division
117 |         xx_channel = xx_channel.float() / (x_dim - 1)
118 |         yy_channel = yy_channel.float() / (y_dim - 1)
118 |
119 | xx_channel = xx_channel * 2 - 1
120 | yy_channel = yy_channel * 2 - 1
121 |
122 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
123 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
124 |
125 | if input_tensor.is_cuda:
126 | xx_channel = xx_channel.cuda()
127 | yy_channel = yy_channel.cuda()
128 |
129 | ret = torch.cat([
130 | input_tensor,
131 | xx_channel.type_as(input_tensor),
132 | yy_channel.type_as(input_tensor)], dim=1)
133 |
134 | if self.with_r:
135 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
136 | if input_tensor.is_cuda:
137 | rr = rr.cuda()
138 | ret = torch.cat([ret, rr], dim=1)
139 |
140 | return ret
141 |
142 |
143 | class CoordConv(nn.Module):
144 |
145 | def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
146 | super().__init__()
147 | self.addcoords = AddCoords(with_r=with_r)
148 |         in_size = in_channels + 2
149 |         if with_r:
150 |             in_size += 1  # account for the extra radius channel
151 |         self.conv = nn.Conv2d(in_size, out_channels, **kwargs)
149 |
150 | def forward(self, x):
151 | ret = self.addcoords(x)
152 | ret = self.conv(ret)
153 | return ret
154 |
--------------------------------------------------------------------------------
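A quick usage sketch for the self-contained CoordConv variant above: AddCoords appends an x and a y coordinate channel (plus an optional radius channel) ahead of the convolution, so the caller passes the raw in_channels and the layer accounts for the extras internally.

```python
import torch
from core.coord_conv import CoordConv

# 3 input channels become 5 internally (x and y coordinate channels appended)
conv = CoordConv(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 64, 64)
y = conv(x)
print(y.shape)  # torch.Size([2, 8, 64, 64])
```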
/dataset/convert_WFLW.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, "../utils/")
3 | import numpy as np
4 | import os
5 | import glob
6 | import scipy.io as sio
7 | import cv2
8 | from skimage import io
9 | from utils import cv_crop
10 | import torch
11 | import math
12 | from functools import reduce  # used by transform() when rotation != 0
13 | from joblib import Parallel, delayed
12 |
13 | def transform(point, center, scale, resolution, rotation=0, invert=False):
14 | _pt = np.ones(3)
15 | _pt[0] = point[0]
16 | _pt[1] = point[1]
17 |
18 | h = 200.0 * scale
19 | t = np.eye(3)
20 | t[0, 0] = resolution / h
21 | t[1, 1] = resolution / h
22 | t[0, 2] = resolution * (-center[0] / h + 0.5)
23 | t[1, 2] = resolution * (-center[1] / h + 0.5)
24 |
25 | if rotation != 0:
26 | rotation = -rotation
27 | r = np.eye(3)
28 | ang = rotation * math.pi / 180.0
29 | s = math.sin(ang)
30 | c = math.cos(ang)
31 | r[0][0] = c
32 | r[0][1] = -s
33 | r[1][0] = s
34 | r[1][1] = c
35 |
36 | t_ = np.eye(3)
37 | t_[0][2] = -resolution / 2.0
38 | t_[1][2] = -resolution / 2.0
39 |         t_inv = np.eye(3)
40 | t_inv[0][2] = resolution / 2.0
41 | t_inv[1][2] = resolution / 2.0
42 | t = reduce(np.matmul, [t_inv, r, t_, t])
43 |
44 | if invert:
45 | t = np.linalg.inv(t)
46 | new_point = (np.matmul(t, _pt))[0:2]
47 |
48 | return new_point.astype(float)
49 |
50 | def parse_pts(pts_file):
51 | pts = []
52 | with open(pts_file) as f:
53 |         for line in f.readlines():
54 |             line = line.strip()
55 |             if not line or not line[0].isdigit():
56 |                 continue
57 |             idx = line.find(' ')
58 |             x, y = float(line[:idx]), float(line[idx+1:])
59 |             pts.append([x, y])
61 | if len(pts) != 68:
62 | print('Not enough points')
63 | else:
64 | return np.array(pts)
65 |
66 | class WFLWInstance():
67 | def __init__(self, line, idx):
68 | self.idx = idx
69 | line = line.strip().split(' ')
70 | # convert landmarks
71 | landmarks_list = list(map(float, line[:196]))
72 | self.landmarks = []
73 | for i in range(0, 196, 2):
74 | self.landmarks.append([landmarks_list[i], landmarks_list[i+1]])
75 | self.landmarks = np.array(self.landmarks)
76 |
77 | # convert bboxes
78 | if len(line) == 207:
79 | self.bbox = list(map(float, line[196:200]))
80 | else:
81 | self.bbox = None
82 |
83 | # convert image name
84 | self.image_base_name = line[-1]
85 | self.image_first_point = line[0]
86 |
87 | def load_meta_subset_data(meta_path):
88 | with open(meta_path) as f:
89 | lines = f.readlines()
90 |
91 | meta_data = []
92 | idx = 0
93 | for line in lines:
94 | line = line.strip().split(' ')
95 | meta_data.append(line[-1]+line[0])
96 | return meta_data
97 |
98 | def load_meta_data(meta_path, meta_subset_data=None):
99 | with open(meta_path) as f:
100 | lines = f.readlines()
101 |
102 | meta_data = []
103 | idx = 0
104 | for line in lines:
105 | wflw_instance = WFLWInstance(line, idx)
106 | if meta_subset_data is not None and (wflw_instance.image_base_name+wflw_instance.image_first_point) in meta_subset_data:
107 | meta_data.append(wflw_instance)
108 | idx += 1
109 | return meta_data
110 |
111 | def process_single(single, image_path, image_save_path, landmarks_save_path):
112 | # print('Processing: {}'.format(single.image_base_name))
113 | image_full_path = os.path.join(image_path, single.image_base_name)
114 | image = io.imread(image_full_path)
115 | if len(image.shape) == 2:
116 | image = np.stack((image, image, image), -1)
117 |
118 | pts = single.landmarks
119 | left, top, right, bottom = [int(x) for x in single.bbox]
120 | lr_pad = int(0.05 * (right - left) / 2)
121 | tb_pad = int(0.05 * (bottom - top) / 2)
122 | left = max(0, left - lr_pad)
123 | right = right + lr_pad
124 | top = max(0, top - tb_pad)
125 | bottom = bottom + tb_pad
126 |
127 | center = torch.FloatTensor(
128 | [right - (right - left) / 2.0, bottom -
129 | (bottom - top) / 2.0])
130 | scale_factor = 250.0
131 | scale = (right - left + bottom - top) / scale_factor
132 | new_image, new_landmarks = cv_crop(image, pts, center, scale, 450, 0)
133 | while np.min(new_landmarks) < 10 or np.max(new_landmarks) > 440:
134 | scale_factor -= 10
135 | scale = (right - left + bottom - top) / scale_factor
136 | new_image, new_landmarks = cv_crop(image, pts, center, scale, 450, 0)
137 | assert (scale_factor > 0), "Landmarks out of boundary!"
138 |     if new_image.size > 0:  # cv_crop returns an ndarray; check it is non-empty
139 | io.imsave(os.path.join(image_save_path, os.path.basename(image_full_path[:-4]+'_' + str(single.idx) + image_full_path[-4:])), new_image)
140 | np.save(os.path.join(landmarks_save_path, os.path.basename(image_full_path[:-4]+ '_' + str(single.idx) + '.pts')), new_landmarks)
141 |
142 | if __name__ == '__main__':
143 | image_path = './WFLW_images/'
144 | meta_subset_path = './WFLW_annotations/list_98pt_rect_attr_train_test/list_98pt_rect_attr_test.txt'
145 | meta_path = './WFLW_annotations/list_98pt_rect_attr_train_test/list_98pt_rect_attr_test.txt'
146 | image_save_path = './WFLW_test/images/'
147 | landmarks_save_path = './WFLW_test/landmarks/'
148 | if not os.path.exists(image_save_path):
149 | os.makedirs(image_save_path)
150 | if not os.path.exists(landmarks_save_path):
151 | os.makedirs(landmarks_save_path)
152 | exts = ['*.png', '*.jpg']
153 | meta_subset_data = load_meta_subset_data(meta_subset_path)
154 | meta_data = load_meta_data(meta_path, meta_subset_data)
155 | assert (len(meta_data) == len(meta_subset_data)), "Some images are missing!"
156 | print("Total images: {0:d}".format(len(meta_data)))
157 | Parallel(n_jobs=10,
158 | backend='threading',
159 | verbose=10)(delayed(process_single)(single, image_path,
160 | image_save_path,
161 | landmarks_save_path) for single in meta_data)
162 |
--------------------------------------------------------------------------------
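For reference, the 207-field annotation layout that WFLWInstance indexes into follows the published WFLW format; a small parsing sketch (the attribute names are taken from the WFLW documentation, not from this repository):

```python
def parse_wflw_line(line):
    """Split one 207-field WFLW annotation line (mirrors WFLWInstance above)."""
    fields = line.strip().split(' ')
    assert len(fields) == 207
    # fields [0:196]: 98 (x, y) landmark pairs
    landmarks = [(float(fields[i]), float(fields[i + 1])) for i in range(0, 196, 2)]
    # fields [196:200]: face bounding box x_min, y_min, x_max, y_max
    bbox = list(map(float, fields[196:200]))
    # fields [200:206]: binary attributes -- pose, expression, illumination,
    # make-up, occlusion, blur
    attributes = list(map(int, fields[200:206]))
    # field [206]: image path relative to WFLW_images/
    image_name = fields[206]
    return landmarks, bbox, attributes, image_name
```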
/core/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from core.coord_conv import CoordConvTh
6 |
7 |
8 | def conv3x3(in_planes, out_planes, strd=1, padding=1,
9 |             bias=False, dilation=1):
10 | "3x3 convolution with padding"
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3,
12 | stride=strd, padding=padding, bias=bias,
13 | dilation=dilation)
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None):
19 | super(BasicBlock, self).__init__()
20 | self.conv1 = conv3x3(inplanes, planes, stride)
21 | # self.bn1 = nn.BatchNorm2d(planes)
22 | self.relu = nn.ReLU(inplace=True)
23 | self.conv2 = conv3x3(planes, planes)
24 | # self.bn2 = nn.BatchNorm2d(planes)
25 | self.downsample = downsample
26 | self.stride = stride
27 |
28 | def forward(self, x):
29 | residual = x
30 |
31 | out = self.conv1(x)
32 | # out = self.bn1(out)
33 | out = self.relu(out)
34 |
35 | out = self.conv2(out)
36 | # out = self.bn2(out)
37 |
38 | if self.downsample is not None:
39 | residual = self.downsample(x)
40 |
41 | out += residual
42 | out = self.relu(out)
43 |
44 | return out
45 |
46 | class ConvBlock(nn.Module):
47 | def __init__(self, in_planes, out_planes):
48 | super(ConvBlock, self).__init__()
49 | self.bn1 = nn.BatchNorm2d(in_planes)
50 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
51 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
52 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4),
53 | padding=1, dilation=1)
54 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
55 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4),
56 | padding=1, dilation=1)
57 |
58 | if in_planes != out_planes:
59 | self.downsample = nn.Sequential(
60 | nn.BatchNorm2d(in_planes),
61 | nn.ReLU(True),
62 | nn.Conv2d(in_planes, out_planes,
63 | kernel_size=1, stride=1, bias=False),
64 | )
65 | else:
66 | self.downsample = None
67 |
68 | def forward(self, x):
69 | residual = x
70 |
71 | out1 = self.bn1(x)
72 | out1 = F.relu(out1, True)
73 | out1 = self.conv1(out1)
74 |
75 | out2 = self.bn2(out1)
76 | out2 = F.relu(out2, True)
77 | out2 = self.conv2(out2)
78 |
79 | out3 = self.bn3(out2)
80 | out3 = F.relu(out3, True)
81 | out3 = self.conv3(out3)
82 |
83 | out3 = torch.cat((out1, out2, out3), 1)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(residual)
87 |
88 | out3 += residual
89 |
90 | return out3
91 |
92 | class HourGlass(nn.Module):
93 | def __init__(self, num_modules, depth, num_features, first_one=False):
94 | super(HourGlass, self).__init__()
95 | self.num_modules = num_modules
96 | self.depth = depth
97 | self.features = num_features
98 | self.coordconv = CoordConvTh(x_dim=64, y_dim=64,
99 | with_r=True, with_boundary=True,
100 | in_channels=256, first_one=first_one,
101 | out_channels=256,
102 | kernel_size=1,
103 | stride=1, padding=0)
104 | self._generate_network(self.depth)
105 |
106 | def _generate_network(self, level):
107 | self.add_module('b1_' + str(level), ConvBlock(256, 256))
108 |
109 | self.add_module('b2_' + str(level), ConvBlock(256, 256))
110 |
111 | if level > 1:
112 | self._generate_network(level - 1)
113 | else:
114 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
115 |
116 | self.add_module('b3_' + str(level), ConvBlock(256, 256))
117 |
118 | def _forward(self, level, inp):
119 | # Upper branch
120 | up1 = inp
121 | up1 = self._modules['b1_' + str(level)](up1)
122 |
123 | # Lower branch
124 | low1 = F.avg_pool2d(inp, 2, stride=2)
125 | low1 = self._modules['b2_' + str(level)](low1)
126 |
127 | if level > 1:
128 | low2 = self._forward(level - 1, low1)
129 | else:
130 | low2 = low1
131 | low2 = self._modules['b2_plus_' + str(level)](low2)
132 |
133 | low3 = low2
134 | low3 = self._modules['b3_' + str(level)](low3)
135 |
136 |         up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
137 |
138 | return up1 + up2
139 |
140 | def forward(self, x, heatmap):
141 | x, last_channel = self.coordconv(x, heatmap)
142 | return self._forward(self.depth, x), last_channel
143 |
144 | class FAN(nn.Module):
145 |
146 | def __init__(self, num_modules=1, end_relu=False, gray_scale=False,
147 | num_landmarks=68):
148 | super(FAN, self).__init__()
149 | self.num_modules = num_modules
150 | self.gray_scale = gray_scale
151 | self.end_relu = end_relu
152 | self.num_landmarks = num_landmarks
153 |
154 |         # Base part
155 |         # NOTE: grayscale images are stacked to 3 channels upstream
156 |         # (see dataset/convert_WFLW.py), so in_channels=3 in both cases
157 |         self.conv1 = CoordConvTh(x_dim=256, y_dim=256,
158 |                                  with_r=True, with_boundary=False,
159 |                                  in_channels=3, out_channels=64,
160 |                                  kernel_size=7,
161 |                                  stride=2, padding=3)
167 | self.bn1 = nn.BatchNorm2d(64)
168 | self.conv2 = ConvBlock(64, 128)
169 | self.conv3 = ConvBlock(128, 128)
170 | self.conv4 = ConvBlock(128, 256)
171 |
172 | # Stacking part
173 | for hg_module in range(self.num_modules):
174 | if hg_module == 0:
175 | first_one = True
176 | else:
177 | first_one = False
178 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256,
179 | first_one))
180 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
181 | self.add_module('conv_last' + str(hg_module),
182 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
183 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
184 | self.add_module('l' + str(hg_module), nn.Conv2d(256,
185 | num_landmarks+1, kernel_size=1, stride=1, padding=0))
186 |
187 | if hg_module < self.num_modules - 1:
188 | self.add_module(
189 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
190 | self.add_module('al' + str(hg_module), nn.Conv2d(num_landmarks+1,
191 | 256, kernel_size=1, stride=1, padding=0))
192 |
193 | def forward(self, x):
194 | x, _ = self.conv1(x)
195 | x = F.relu(self.bn1(x), True)
196 | # x = F.relu(self.bn1(self.conv1(x)), True)
197 | x = F.avg_pool2d(self.conv2(x), 2, stride=2)
198 | x = self.conv3(x)
199 | x = self.conv4(x)
200 |
201 | previous = x
202 |
203 | outputs = []
204 | boundary_channels = []
205 | tmp_out = None
206 | for i in range(self.num_modules):
207 | hg, boundary_channel = self._modules['m' + str(i)](previous,
208 | tmp_out)
209 |
210 | ll = hg
211 | ll = self._modules['top_m_' + str(i)](ll)
212 |
213 | ll = F.relu(self._modules['bn_end' + str(i)]
214 | (self._modules['conv_last' + str(i)](ll)), True)
215 |
216 | # Predict heatmaps
217 | tmp_out = self._modules['l' + str(i)](ll)
218 | if self.end_relu:
219 | tmp_out = F.relu(tmp_out) # HACK: Added relu
220 | outputs.append(tmp_out)
221 | boundary_channels.append(boundary_channel)
222 |
223 | if i < self.num_modules - 1:
224 | ll = self._modules['bl' + str(i)](ll)
225 | tmp_out_ = self._modules['al' + str(i)](tmp_out)
226 | previous = previous + ll + tmp_out_
227 |
228 | return outputs, boundary_channels
229 |
--------------------------------------------------------------------------------
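A quick shape check for FAN (a sketch, runnable on CPU given the device-agnostic coord_conv above): each stacked hourglass module emits a (num_landmarks + 1)-channel heatmap at one quarter of the input resolution, the extra channel being the boundary heatmap consumed by the next stack's CoordConvTh.

```python
import torch
from core.models import FAN

model = FAN(num_modules=4, end_relu=False, gray_scale=False, num_landmarks=98)
model.eval()
with torch.no_grad():
    outputs, boundary_channels = model(torch.randn(1, 3, 256, 256))
print(len(outputs))       # 4 -- one heatmap tensor per hourglass stack
print(outputs[-1].shape)  # torch.Size([1, 99, 64, 64]) -- 98 landmarks + boundary
```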
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import sys
4 | import math
5 | from functools import reduce  # used by transform() when rotation != 0
6 | import torch
6 | import cv2
7 | from PIL import Image
8 | from skimage import io
9 | from skimage import transform as ski_transform
10 | from scipy import ndimage
11 | import numpy as np
12 | import matplotlib
13 | import matplotlib.pyplot as plt
14 | from torch.utils.data import Dataset, DataLoader
15 | from torchvision import transforms, utils
16 |
17 | def _gaussian(
18 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
19 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
20 | mean_vert=0.5):
21 | # handle some defaults
22 | if width is None:
23 | width = size
24 | if height is None:
25 | height = size
26 | if sigma_horz is None:
27 | sigma_horz = sigma
28 | if sigma_vert is None:
29 | sigma_vert = sigma
30 | center_x = mean_horz * width + 0.5
31 | center_y = mean_vert * height + 0.5
32 | gauss = np.empty((height, width), dtype=np.float32)
33 | # generate kernel
34 | for i in range(height):
35 | for j in range(width):
36 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
37 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
38 | if normalize:
39 | gauss = gauss / np.sum(gauss)
40 | return gauss
41 |
42 | def draw_gaussian(image, point, sigma):
43 | # Check if the gaussian is inside
44 | ul = [np.floor(np.floor(point[0]) - 3 * sigma),
45 | np.floor(np.floor(point[1]) - 3 * sigma)]
46 | br = [np.floor(np.floor(point[0]) + 3 * sigma),
47 | np.floor(np.floor(point[1]) + 3 * sigma)]
48 | if (ul[0] > image.shape[1] or ul[1] >
49 | image.shape[0] or br[0] < 1 or br[1] < 1):
50 | return image
51 | size = 6 * sigma + 1
52 | g = _gaussian(size)
53 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) -
54 | int(max(1, ul[0])) + int(max(1, -ul[0]))]
55 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) -
56 | int(max(1, ul[1])) + int(max(1, -ul[1]))]
57 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
58 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
59 | assert (g_x[0] > 0 and g_y[1] > 0)
60 | correct = False
61 | while not correct:
62 | try:
63 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
64 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
65 | correct = True
66 | except:
67 | print('img_x: {}, img_y: {}, g_x:{}, g_y:{}, point:{}, g_shape:{}, ul:{}, br:{}'.format(img_x, img_y, g_x, g_y, point, g.shape, ul, br))
68 | ul = [np.floor(np.floor(point[0]) - 3 * sigma),
69 | np.floor(np.floor(point[1]) - 3 * sigma)]
70 | br = [np.floor(np.floor(point[0]) + 3 * sigma),
71 | np.floor(np.floor(point[1]) + 3 * sigma)]
72 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) -
73 | int(max(1, ul[0])) + int(max(1, -ul[0]))]
74 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) -
75 | int(max(1, ul[1])) + int(max(1, -ul[1]))]
76 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
77 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
78 | pass
79 | image[image > 1] = 1
80 | return image
81 |
82 | def transform(point, center, scale, resolution, rotation=0, invert=False):
83 | _pt = np.ones(3)
84 | _pt[0] = point[0]
85 | _pt[1] = point[1]
86 |
87 | h = 200.0 * scale
88 | t = np.eye(3)
89 | t[0, 0] = resolution / h
90 | t[1, 1] = resolution / h
91 | t[0, 2] = resolution * (-center[0] / h + 0.5)
92 | t[1, 2] = resolution * (-center[1] / h + 0.5)
93 |
94 | if rotation != 0:
95 | rotation = -rotation
96 | r = np.eye(3)
97 | ang = rotation * math.pi / 180.0
98 | s = math.sin(ang)
99 | c = math.cos(ang)
100 | r[0][0] = c
101 | r[0][1] = -s
102 | r[1][0] = s
103 | r[1][1] = c
104 |
105 | t_ = np.eye(3)
106 | t_[0][2] = -resolution / 2.0
107 | t_[1][2] = -resolution / 2.0
108 |         t_inv = np.eye(3)
109 | t_inv[0][2] = resolution / 2.0
110 | t_inv[1][2] = resolution / 2.0
111 | t = reduce(np.matmul, [t_inv, r, t_, t])
112 |
113 | if invert:
114 | t = np.linalg.inv(t)
115 | new_point = (np.matmul(t, _pt))[0:2]
116 |
117 | return new_point.astype(int)
118 |
119 | def cv_crop(image, landmarks, center, scale, resolution=256, center_shift=0):
120 | new_image = cv2.copyMakeBorder(image, center_shift,
121 | center_shift,
122 | center_shift,
123 | center_shift,
124 | cv2.BORDER_CONSTANT, value=[0,0,0])
125 | new_landmarks = landmarks.copy()
126 | if center_shift != 0:
127 | center[0] += center_shift
128 | center[1] += center_shift
129 | new_landmarks = new_landmarks + center_shift
130 | length = 200 * scale
131 | top = int(center[1] - length // 2)
132 | bottom = int(center[1] + length // 2)
133 | left = int(center[0] - length // 2)
134 | right = int(center[0] + length // 2)
135 | y_pad = abs(min(top, new_image.shape[0] - bottom, 0))
136 | x_pad = abs(min(left, new_image.shape[1] - right, 0))
137 | top, bottom, left, right = top + y_pad, bottom + y_pad, left + x_pad, right + x_pad
138 | new_image = cv2.copyMakeBorder(new_image, y_pad,
139 | y_pad,
140 | x_pad,
141 | x_pad,
142 | cv2.BORDER_CONSTANT, value=[0,0,0])
143 | new_image = new_image[top:bottom, left:right]
144 | new_image = cv2.resize(new_image, dsize=(int(resolution), int(resolution)),
145 | interpolation=cv2.INTER_LINEAR)
146 | new_landmarks[:, 0] = (new_landmarks[:, 0] + x_pad - left) * resolution / length
147 | new_landmarks[:, 1] = (new_landmarks[:, 1] + y_pad - top) * resolution / length
148 | return new_image, new_landmarks
149 |
150 | def cv_rotate(image, landmarks, heatmap, rot, scale, resolution=256):
151 | img_mat = cv2.getRotationMatrix2D((resolution//2, resolution//2), rot, scale)
152 | ones = np.ones(shape=(landmarks.shape[0], 1))
153 | stacked_landmarks = np.hstack([landmarks, ones])
154 | new_landmarks = img_mat.dot(stacked_landmarks.T).T
155 |     if np.max(new_landmarks) > resolution - 1 or np.min(new_landmarks) < 0:
156 | return image, landmarks, heatmap
157 | else:
158 | new_image = cv2.warpAffine(image, img_mat, (resolution, resolution))
159 | if heatmap is not None:
160 | new_heatmap = np.zeros((heatmap.shape[0], 64, 64))
161 | for i in range(heatmap.shape[0]):
162 | if new_landmarks[i][0] > 0:
163 | new_heatmap[i] = draw_gaussian(new_heatmap[i],
164 | new_landmarks[i]/4.0+1, 1)
165 | return new_image, new_landmarks, new_heatmap
166 |
167 | def show_landmarks(image, heatmap, gt_landmarks, gt_heatmap):
168 | """Show image with pred_landmarks"""
169 | pred_landmarks = []
170 | pred_landmarks, _ = get_preds_fromhm(torch.from_numpy(heatmap).unsqueeze(0))
171 | pred_landmarks = pred_landmarks.squeeze()*4
172 |
173 | # pred_landmarks2 = get_preds_fromhm2(heatmap)
174 | heatmap = np.max(gt_heatmap, axis=0)
175 | heatmap = heatmap / np.max(heatmap)
176 | # image = ski_transform.resize(image, (64, 64))*255
177 | image = image.astype(np.uint8)
178 | heatmap = np.max(gt_heatmap, axis=0)
179 | heatmap = ski_transform.resize(heatmap, (image.shape[0], image.shape[1]))
180 | heatmap *= 255
181 | heatmap = heatmap.astype(np.uint8)
182 | heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
183 | plt.imshow(image)
184 | plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker='.', c='g')
185 | plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker='.', c='r')
186 | plt.pause(0.001) # pause a bit so that plots are updated
187 |
188 | def fan_NME(pred_heatmaps, gt_landmarks, num_landmarks=68):
189 | '''
190 | Calculate total NME for a batch of data
191 |
192 | Args:
193 | pred_heatmaps: torch tensor of size [batch, points, height, width]
194 |         gt_landmarks: torch tensor of size [batch, points, 2]
195 |
196 | Returns:
197 | nme: sum of nme for this batch
198 | '''
199 | nme = 0
200 | pred_landmarks, _ = get_preds_fromhm(pred_heatmaps)
201 | pred_landmarks = pred_landmarks.numpy()
202 | gt_landmarks = gt_landmarks.numpy()
203 | for i in range(pred_landmarks.shape[0]):
204 | pred_landmark = pred_landmarks[i] * 4.0
205 | gt_landmark = gt_landmarks[i]
206 |
207 | if num_landmarks == 68:
208 | left_eye = np.average(gt_landmark[36:42], axis=0)
209 | right_eye = np.average(gt_landmark[42:48], axis=0)
210 | norm_factor = np.linalg.norm(left_eye - right_eye)
211 | # norm_factor = np.linalg.norm(gt_landmark[36]- gt_landmark[45])
212 | elif num_landmarks == 98:
213 | norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72])
214 | elif num_landmarks == 19:
215 | left, top = gt_landmark[-2, :]
216 | right, bottom = gt_landmark[-1, :]
217 | norm_factor = math.sqrt(abs(right - left)*abs(top-bottom))
218 | gt_landmark = gt_landmark[:-2, :]
219 | elif num_landmarks == 29:
220 | # norm_factor = np.linalg.norm(gt_landmark[8]- gt_landmark[9])
221 | norm_factor = np.linalg.norm(gt_landmark[16]- gt_landmark[17])
222 | nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
223 | return nme
224 |
225 | def fan_NME_hm(pred_heatmaps, gt_landmarks, num_landmarks=68):
226 | '''
227 | Calculate total NME for a batch of data
228 |
229 | Args:
230 | pred_heatmaps: torch tensor of size [batch, points, height, width]
231 |         gt_landmarks: torch tensor of size [batch, points, 2]
232 |
233 | Returns:
234 | nme: sum of nme for this batch
235 | '''
236 | nme = 0
237 |     pred_landmarks = get_index_fromhm(pred_heatmaps)  # returns a single tensor
238 | pred_landmarks = pred_landmarks.numpy()
239 | gt_landmarks = gt_landmarks.numpy()
240 | for i in range(pred_landmarks.shape[0]):
241 | pred_landmark = pred_landmarks[i] * 4.0
242 | gt_landmark = gt_landmarks[i]
243 | if num_landmarks == 68:
244 | left_eye = np.average(gt_landmark[36:42], axis=0)
245 | right_eye = np.average(gt_landmark[42:48], axis=0)
246 | norm_factor = np.linalg.norm(left_eye - right_eye)
247 | else:
248 | norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72])
249 | nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
250 | return nme
251 |
252 | def power_transform(img, power):
253 | img = np.array(img)
254 | img_new = np.power((img/255.0), power) * 255.0
255 | img_new = img_new.astype(np.uint8)
256 | img_new = Image.fromarray(img_new)
257 | return img_new
258 |
259 | def get_preds_fromhm(hm, center=None, scale=None, rot=None):
260 |     _, idx = torch.max(
261 |         hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
262 | idx += 1
263 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
264 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
265 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
266 |
267 | for i in range(preds.size(0)):
268 | for j in range(preds.size(1)):
269 | hm_ = hm[i, j, :]
270 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
271 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
272 | diff = torch.FloatTensor(
273 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
274 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
275 | preds[i, j].add_(diff.sign_().mul_(.25))
276 |
277 | preds.add_(-0.5)
278 |
279 | preds_orig = torch.zeros(preds.size())
280 | if center is not None and scale is not None:
281 | for i in range(hm.size(0)):
282 | for j in range(hm.size(1)):
283 | preds_orig[i, j] = transform(
284 | preds[i, j], center, scale, hm.size(2), rot, True)
285 |
286 | return preds, preds_orig
287 |
288 | def get_index_fromhm(hm):
289 |     _, idx = torch.max(
290 |         hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
291 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
292 | preds[..., 0].remainder_(hm.size(3))
293 | preds[..., 1].div_(hm.size(2)).floor_()
294 |
295 | for i in range(preds.size(0)):
296 | for j in range(preds.size(1)):
297 | hm_ = hm[i, j, :]
298 | pX, pY = int(preds[i, j, 0]), int(preds[i, j, 1])
299 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
300 | diff = torch.FloatTensor(
301 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
302 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
303 | preds[i, j].add_(diff.sign_().mul_(.25))
304 |
305 | return preds
306 |
307 | def shuffle_lr(parts, num_landmarks=68, pairs=None):
308 | if num_landmarks == 68:
309 | if pairs is None:
310 | pairs = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10],
311 | [7, 9], [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], [36, 45],
312 | [37, 44], [38, 43], [39, 42], [41, 46], [40, 47], [31, 35], [32, 34],
313 | [50, 52], [49, 53], [48, 54], [61, 63], [60, 64], [67, 65], [59, 55], [58, 56]]
314 | elif num_landmarks == 98:
315 | if pairs is None:
316 | pairs = [[0, 32], [1,31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22], [11, 21], [12, 20], [13, 19], [14, 18], [15, 17], [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73], [96, 97], [55, 59], [56, 58], [76, 82], [77, 81], [78, 80], [88, 92], [89, 91], [95, 93], [87, 83], [86, 84]]
317 | elif num_landmarks == 19:
318 | if pairs is None:
319 | pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9], [12, 14], [15, 17]]
320 | elif num_landmarks == 29:
321 | if pairs is None:
322 | pairs = [[0, 1], [4, 6], [5, 7], [2, 3], [8, 9], [12, 14], [16, 17], [13, 15], [10, 11], [18, 19], [22, 23]]
323 | for matched_p in pairs:
324 | idx1, idx2 = matched_p[0], matched_p[1]
325 | tmp = np.copy(parts[idx1])
326 | np.copyto(parts[idx1], parts[idx2])
327 | np.copyto(parts[idx2], tmp)
328 | return parts
329 |
330 |
331 | def generate_weight_map(weight_map, heatmap):
332 |     # mark every pixel whose dilated heatmap response exceeds 0.2 as foreground
333 |     k_size = 3
334 |     dilate = ndimage.grey_dilation(heatmap, size=(k_size, k_size))
335 |     weight_map[np.where(dilate > 0.2)] = 1
336 |     return weight_map
337 |
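# Editorial usage sketch: the weight map is 1 wherever the dilated heatmap is
# active and 0 elsewhere; the Adaptive Wing Loss paper uses such a mask to
# upweight foreground pixels (the loss is scaled by roughly w * M + 1 for a
# scalar w). `heatmap_channel` is a hypothetical (64, 64) array.
#
#     weight_map = np.zeros_like(heatmap_channel)
#     weight_map = generate_weight_map(weight_map, heatmap_channel)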
338 | def fig2data(fig):
339 |     """
340 |     @brief Convert a Matplotlib figure to a numpy array with RGB channels and return it
341 |     @param fig a matplotlib figure
342 |     @return a numpy 3D array of RGB values with shape (height, width, 3)
343 |     """
344 |     # draw the renderer
345 |     fig.canvas.draw()
346 |
347 |     # Get the RGB buffer from the figure (np.fromstring is deprecated; use np.frombuffer)
348 |     w, h = fig.canvas.get_width_height()
349 |     buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()  # copy -> writable array
350 |     buf.shape = (h, w, 3)  # rows correspond to the height; (w, h, 3) only worked for square figures
351 |
352 |     # tostring_rgb() already returns RGB byte order, so no channel rolling is
353 |     # needed (the old np.roll by 3 over a 3-channel axis was a no-op)
354 |     return buf
355 |
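# Editorial usage sketch: rasterize a small figure into an RGB array, matching
# the 64x64 figures built in core/dataloader.py (figsize 64/96 inch at 96 dpi).
#
#     import matplotlib.pyplot as plt
#     fig = plt.figure(figsize=[64 / 96.0, 64 / 96.0], dpi=96)
#     plt.plot([0, 1], [1, 0])
#     arr = fig2data(fig)   # ndarray with shape (64, 64, 3), dtype uint8
#     plt.close(fig)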
--------------------------------------------------------------------------------
/core/dataloader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import random
4 | import glob
5 | import torch
6 | from skimage import io
7 | from skimage import transform as ski_transform
8 | from skimage.color import rgb2gray
9 | import scipy.io as sio
10 | from scipy import interpolate
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from torch.utils.data import Dataset, DataLoader
14 | from torchvision import transforms, utils
15 | from torchvision.transforms import Lambda, Compose
16 | from torchvision.transforms.functional import adjust_brightness, adjust_contrast, adjust_saturation, adjust_hue
17 | from utils.utils import cv_crop, cv_rotate, draw_gaussian, transform, power_transform, shuffle_lr, fig2data, generate_weight_map
18 | from PIL import Image
19 | import cv2
20 | import copy
21 | import math
22 | from imgaug import augmenters as iaa
23 |
24 |
25 | class AddBoundary(object):
26 | def __init__(self, num_landmarks=68):
27 | self.num_landmarks = num_landmarks
28 |
29 | def __call__(self, sample):
30 | landmarks_64 = np.floor(sample['landmarks'] / 4.0)
31 | if self.num_landmarks == 68:
32 | boundaries = {}
33 | boundaries['cheek'] = landmarks_64[0:17]
34 | boundaries['left_eyebrow'] = landmarks_64[17:22]
35 | boundaries['right_eyebrow'] = landmarks_64[22:27]
36 |             boundaries['upper_left_eyelid'] = landmarks_64[36:40]
37 | boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [36, 41, 40, 39]])
38 | boundaries['upper_right_eyelid'] = landmarks_64[42:46]
39 | boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [42, 47, 46, 45]])
40 |             boundaries['nose'] = landmarks_64[27:31]
41 |             boundaries['nose_bot'] = landmarks_64[31:36]
42 | boundaries['upper_outer_lip'] = landmarks_64[48:55]
43 | boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [60, 61, 62, 63, 64]])
44 | boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [48, 59, 58, 57, 56, 55, 54]])
45 | boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
46 | elif self.num_landmarks == 98:
47 | boundaries = {}
48 | boundaries['cheek'] = landmarks_64[0:33]
49 | boundaries['left_eyebrow'] = landmarks_64[33:38]
50 | boundaries['right_eyebrow'] = landmarks_64[42:47]
51 |             boundaries['upper_left_eyelid'] = landmarks_64[60:65]
52 | boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
53 | boundaries['upper_right_eyelid'] = landmarks_64[68:73]
54 | boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [68, 75, 74, 73, 72]])
55 |             boundaries['nose'] = landmarks_64[51:55]
56 |             boundaries['nose_bot'] = landmarks_64[55:60]
57 | boundaries['upper_outer_lip'] = landmarks_64[76:83]
58 | boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [88, 89, 90, 91, 92]])
59 | boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [76, 87, 86, 85, 84, 83, 82]])
60 | boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [88, 95, 94, 93, 92]])
61 | elif self.num_landmarks == 19:
62 | boundaries = {}
63 | boundaries['left_eyebrow'] = landmarks_64[0:3]
64 | boundaries['right_eyebrow'] = landmarks_64[3:5]
65 | boundaries['left_eye'] = landmarks_64[6:9]
66 | boundaries['right_eye'] = landmarks_64[9:12]
67 |             boundaries['nose'] = landmarks_64[12:15]
68 |
69 | elif self.num_landmarks == 29:
70 | boundaries = {}
71 | boundaries['upper_left_eyebrow'] = np.stack([
72 | landmarks_64[0],
73 | landmarks_64[4],
74 | landmarks_64[2]
75 | ], axis=0)
76 | boundaries['lower_left_eyebrow'] = np.stack([
77 | landmarks_64[0],
78 | landmarks_64[5],
79 | landmarks_64[2]
80 | ], axis=0)
81 | boundaries['upper_right_eyebrow'] = np.stack([
82 | landmarks_64[1],
83 | landmarks_64[6],
84 | landmarks_64[3]
85 | ], axis=0)
86 | boundaries['lower_right_eyebrow'] = np.stack([
87 | landmarks_64[1],
88 | landmarks_64[7],
89 | landmarks_64[3]
90 | ], axis=0)
91 | boundaries['upper_left_eye'] = np.stack([
92 | landmarks_64[8],
93 | landmarks_64[12],
94 | landmarks_64[10]
95 | ], axis=0)
96 | boundaries['lower_left_eye'] = np.stack([
97 | landmarks_64[8],
98 | landmarks_64[13],
99 | landmarks_64[10]
100 | ], axis=0)
101 | boundaries['upper_right_eye'] = np.stack([
102 | landmarks_64[9],
103 | landmarks_64[14],
104 | landmarks_64[11]
105 | ], axis=0)
106 | boundaries['lower_right_eye'] = np.stack([
107 | landmarks_64[9],
108 | landmarks_64[15],
109 | landmarks_64[11]
110 | ], axis=0)
111 |             boundaries['nose'] = np.stack([
112 | landmarks_64[18],
113 | landmarks_64[21],
114 | landmarks_64[19]
115 | ], axis=0)
116 | boundaries['outer_upper_lip'] = np.stack([
117 | landmarks_64[22],
118 | landmarks_64[24],
119 | landmarks_64[23]
120 | ], axis=0)
121 | boundaries['inner_upper_lip'] = np.stack([
122 | landmarks_64[22],
123 | landmarks_64[25],
124 | landmarks_64[23]
125 | ], axis=0)
126 | boundaries['outer_lower_lip'] = np.stack([
127 | landmarks_64[22],
128 | landmarks_64[26],
129 | landmarks_64[23]
130 | ], axis=0)
131 | boundaries['inner_lower_lip'] = np.stack([
132 | landmarks_64[22],
133 | landmarks_64[27],
134 | landmarks_64[23]
135 | ], axis=0)
136 | functions = {}
137 |
138 | for key, points in boundaries.items():
139 | temp = points[0]
140 | new_points = points[0:1, :]
141 | for point in points[1:]:
142 | if point[0] == temp[0] and point[1] == temp[1]:
143 | continue
144 | else:
145 | new_points = np.concatenate((new_points, np.expand_dims(point, 0)), axis=0)
146 | temp = point
147 | points = new_points
148 | if points.shape[0] == 1:
149 | points = np.concatenate((points, points+0.001), axis=0)
150 | k = min(4, points.shape[0])
151 |             functions[key] = interpolate.splprep([points[:, 0], points[:, 1]], k=k - 1, s=0)
152 |
153 | boundary_map = np.zeros((64, 64))
154 |
155 | fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96)
156 |
157 | ax = fig.add_axes([0, 0, 1, 1])
158 |
159 | ax.axis('off')
160 |
161 | ax.imshow(boundary_map, interpolation='nearest', cmap='gray')
162 | #ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1, marker=',', c='w')
163 |
164 | for key in functions.keys():
165 | xnew = np.arange(0, 1, 0.01)
166 | out = interpolate.splev(xnew, functions[key][0], der=0)
167 | plt.plot(out[0], out[1], ',', linewidth=1, color='w')
168 |
169 | img = fig2data(fig)
170 |
171 | plt.close()
172 |
173 | sigma = 1
174 |         temp = 255 - img[:, :, 1]
175 |         temp = cv2.distanceTransform(temp, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
176 |         temp = temp.astype(np.float32)
177 |         temp = np.where(temp < 3 * sigma, np.exp(-(temp * temp) / (2 * sigma * sigma)), 0)
178 |
179 | fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96)
180 |
181 | ax = fig.add_axes([0, 0, 1, 1])
182 |
183 | ax.axis('off')
184 |         ax.imshow(temp, cmap='gray')
185 |
186 |         boundary_map = fig2data(fig)  # rasterize before closing, otherwise the canvas may already be destroyed
187 |         plt.close()
188 |
189 |         sample['boundary'] = boundary_map[:, :, 0]
190 |
191 | return sample
192 |
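# Editorial summary of AddBoundary: the landmarks are scaled to the 64x64
# heatmap grid, grouped into facial-part curves, fitted with B-splines
# (scipy.interpolate.splprep, degree min(3, n_points - 1)), rasterized through
# a 64x64 matplotlib figure, and turned into a soft boundary heatmap by taking
# a distance transform and applying a Gaussian falloff (sigma = 1, cut off at
# 3 * sigma). The result is stored in sample['boundary'] as a single channel.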
193 | class AddWeightMap(object):
194 | def __call__(self, sample):
195 |         heatmap = sample['heatmap']
196 | boundary = sample['boundary']
197 |         heatmap = np.concatenate((heatmap, np.expand_dims(boundary, axis=0)), 0)  # the boundary becomes one extra channel, so weight_map has num_landmarks + 1 channels
198 | weight_map = np.zeros_like(heatmap)
199 | for i in range(heatmap.shape[0]):
200 | weight_map[i] = generate_weight_map(weight_map[i],
201 | heatmap[i])
202 | sample['weight_map'] = weight_map
203 | return sample
204 |
205 | class ToTensor(object):
206 | """Convert ndarrays in sample to Tensors."""
207 |
208 | def __call__(self, sample):
209 |         image, heatmap, landmarks, boundary, weight_map = sample['image'], sample['heatmap'], sample['landmarks'], sample['boundary'], sample['weight_map']
210 |
211 | # swap color axis because
212 | # numpy image: H x W x C
213 | # torch image: C X H X W
214 |         if len(image.shape) == 2:
215 |             # grayscale input: add a channel axis before the HWC -> CHW transpose
216 |             image = np.expand_dims(image, axis=2)
217 | image = image.transpose((2, 0, 1))
218 | boundary = np.expand_dims(boundary, axis=2)
219 | boundary = boundary.transpose((2, 0, 1))
220 | return {'image': torch.from_numpy(image).float().div(255.0),
221 | 'heatmap': torch.from_numpy(heatmap).float(),
222 | 'landmarks': torch.from_numpy(landmarks).float(),
223 | 'boundary': torch.from_numpy(boundary).float().div(255.0),
224 | 'weight_map': torch.from_numpy(weight_map).float()}
225 |
226 | class FaceLandmarksDataset(Dataset):
227 | """Face Landmarks dataset."""
228 |
229 | def __init__(self, img_dir, landmarks_dir, num_landmarks=68, gray_scale=False,
230 | detect_face=False, enhance=False, center_shift=0,
231 |                  transform=None):
232 | """
233 | Args:
234 |             landmarks_dir (string): Directory with the landmark files (*_pts.mat or *.pts.npy).
235 | img_dir (string): Directory with all the images.
236 | transform (callable, optional): Optional transform to be applied
237 | on a sample.
238 | """
239 | self.img_dir = img_dir
240 | self.landmarks_dir = landmarks_dir
241 |         self.num_landmarks = num_landmarks
242 | self.transform = transform
243 | self.img_names = glob.glob(self.img_dir+'*.jpg') + \
244 | glob.glob(self.img_dir+'*.png')
245 | self.gray_scale = gray_scale
246 | self.detect_face = detect_face
247 | self.enhance = enhance
248 | self.center_shift = center_shift
249 | if self.detect_face:
250 |             self.face_detector = MTCNN(thresh=[0.5, 0.6, 0.7])  # NOTE: MTCNN is not imported in this module; a compatible face detector must be provided when detect_face=True
251 | def __len__(self):
252 | return len(self.img_names)
253 |
254 | def __getitem__(self, idx):
255 | img_name = self.img_names[idx]
256 | pil_image = Image.open(img_name)
257 | if pil_image.mode != "RGB":
258 |             # if the input is a grayscale image, convert it to a 3-channel RGB image
259 | if self.enhance:
260 | pil_image = power_transform(pil_image, 0.5)
261 | temp_image = Image.new('RGB', pil_image.size)
262 | temp_image.paste(pil_image)
263 | pil_image = temp_image
264 | image = np.array(pil_image)
265 | if self.gray_scale:
266 | image = rgb2gray(image)
267 | image = np.expand_dims(image, axis=2)
268 | image = np.concatenate((image, image, image), axis=2)
269 | image = image * 255.0
270 | image = image.astype(np.uint8)
271 | if not self.detect_face:
272 |             center = [450 // 2, 450 // 2]  # images in this dataset are 450x450, so start at the image center
273 | if self.center_shift != 0:
274 | center[0] += int(np.random.uniform(-self.center_shift,
275 | self.center_shift))
276 | center[1] += int(np.random.uniform(-self.center_shift,
277 | self.center_shift))
278 | scale = 1.8
279 | else:
280 | detected_faces = self.face_detector.detect_image(image)
281 | if len(detected_faces) > 0:
282 | box = detected_faces[0]
283 | left, top, right, bottom, _ = box
284 | center = [right - (right - left) / 2.0,
285 | bottom - (bottom - top) / 2.0]
286 | center[1] = center[1] - (bottom - top) * 0.12
287 | scale = (right - left + bottom - top) / 195.0
288 | else:
289 |                 center = [450 // 2, 450 // 2]
290 | scale = 1.8
291 | if self.center_shift != 0:
292 |                 shift = self.center_shift  # assumption: the original "self.center * self.center_shift / 450" referenced an undefined attribute
293 | center[0] += int(np.random.uniform(-shift, shift))
294 | center[1] += int(np.random.uniform(-shift, shift))
295 | base_name = os.path.basename(img_name)
296 | landmarks_base_name = base_name[:-4] + '_pts.mat'
297 | landmarks_name = os.path.join(self.landmarks_dir, landmarks_base_name)
298 | if os.path.isfile(landmarks_name):
299 | mat_data = sio.loadmat(landmarks_name)
300 | landmarks = mat_data['pts_2d']
301 | elif os.path.isfile(landmarks_name[:-8] + '.pts.npy'):
302 | landmarks = np.load(landmarks_name[:-8] + '.pts.npy')
303 | else:
304 | landmarks = []
305 | heatmap = []
306 |
307 |         if len(landmarks) != 0:  # explicit length check; comparing an ndarray to [] is ambiguous
308 | new_image, new_landmarks = cv_crop(image, landmarks, center,
309 | scale, 256, self.center_shift)
310 | tries = 0
311 | while self.center_shift != 0 and tries < 5 and (np.max(new_landmarks) > 240 or np.min(new_landmarks) < 15):
312 |                 center = [450 // 2, 450 // 2]
313 | scale += 0.05
314 | center[0] += int(np.random.uniform(-self.center_shift,
315 | self.center_shift))
316 | center[1] += int(np.random.uniform(-self.center_shift,
317 | self.center_shift))
318 |
319 | new_image, new_landmarks = cv_crop(image, landmarks,
320 | center, scale, 256,
321 | self.center_shift)
322 | tries += 1
323 | if np.max(new_landmarks) > 250 or np.min(new_landmarks) < 5:
324 |                 center = [450 // 2, 450 // 2]
325 | scale = 2.25
326 | new_image, new_landmarks = cv_crop(image, landmarks,
327 | center, scale, 256,
328 | 100)
329 | assert (np.min(new_landmarks) > 0 and np.max(new_landmarks) < 256), \
330 | "Landmarks out of boundary!"
331 | image = new_image
332 | landmarks = new_landmarks
333 |             heatmap = np.zeros((self.num_landmarks, 64, 64))
334 |             for i in range(self.num_landmarks):
335 | if landmarks[i][0] > 0:
336 | heatmap[i] = draw_gaussian(heatmap[i], landmarks[i]/4.0+1, 1)
337 | sample = {'image': image, 'heatmap': heatmap, 'landmarks': landmarks}
338 | if self.transform:
339 | sample = self.transform(sample)
340 |
341 | return sample
342 |
343 | def get_dataset(val_img_dir, val_landmarks_dir, batch_size,
344 | num_landmarks=68, rotation=0, scale=0,
345 | center_shift=0, random_flip=False,
346 | brightness=0, contrast=0, saturation=0,
347 | blur=False, noise=False, jpeg_effect=False,
348 | random_occlusion=False, gray_scale=False,
349 |                 detect_face=False, enhance=False):  # the augmentation arguments (rotation, scale, flips, color, occlusion) are currently unused: this builds the eval-only pipeline
350 | val_transforms = transforms.Compose([AddBoundary(num_landmarks),
351 | AddWeightMap(),
352 | ToTensor()])
353 |
354 | val_dataset = FaceLandmarksDataset(val_img_dir, val_landmarks_dir,
355 | num_landmarks=num_landmarks,
356 | gray_scale=gray_scale,
357 | detect_face=detect_face,
358 | enhance=enhance,
359 | transform=val_transforms)
360 |
361 | val_dataloader = torch.utils.data.DataLoader(val_dataset,
362 | batch_size=batch_size,
363 | shuffle=False,
364 | num_workers=6)
365 | data_loaders = {'val': val_dataloader}
366 | dataset_sizes = {}
367 | dataset_sizes['val'] = len(val_dataset)
368 | return data_loaders, dataset_sizes
369 |
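# Editorial usage sketch: build the WFLW validation loader and inspect one
# batch; the paths and batch size mirror scripts/eval_wflw.sh.
#
#     data_loaders, dataset_sizes = get_dataset(
#         '../dataset/WFLW_test/images/', '../dataset/WFLW_test/landmarks/',
#         batch_size=20, num_landmarks=98)
#     batch = next(iter(data_loaders['val']))
#     batch['image'].shape       # (20, 3, 256, 256), float scaled to [0, 1]
#     batch['heatmap'].shape     # (20, 98, 64, 64)
#     batch['weight_map'].shape  # (20, 99, 64, 64): 98 landmarks + 1 boundary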
--------------------------------------------------------------------------------