├── .gitignore
├── libs
│   ├── sfd
│   │   ├── __init__.py
│   │   ├── sfd_detector.py
│   │   ├── detect.py
│   │   ├── bbox.py
│   │   ├── net_s3fd.py
│   │   └── core.py
│   ├── ffhq_cropping.py
│   ├── utilities.py
│   ├── landmarks_estimation.py
│   └── fan_model
│       ├── models.py
│       └── utils.py
├── images
│   └── example.png
├── requirements.txt
├── README.md
├── download_voxCeleb.py
└── preprocess_voxCeleb.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | # Models
3 | *.pth
4 |
5 |
--------------------------------------------------------------------------------
/libs/sfd/__init__.py:
--------------------------------------------------------------------------------
1 | from .sfd_detector import SFDDetector as FaceDetector
--------------------------------------------------------------------------------
/images/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StelaBou/voxceleb_preprocessing/HEAD/images/example.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | opencv-python
5 | scikit-image
6 | tqdm
7 | matplotlib
8 | scipy>=0.17.0
9 | numba
10 | yt-dlp
11 |
--------------------------------------------------------------------------------
/libs/sfd/sfd_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from torch.utils.model_zoo import load_url
4 | import torch
5 | import matplotlib.pyplot as plt
6 | from .core import FaceDetector
7 |
8 | from .net_s3fd import s3fd
9 | from .bbox import *
10 | from .detect import *
11 | import torch.backends.cudnn as cudnn
12 |
13 |
14 | models_urls = {
15 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
16 | }
17 |
18 |
19 | class SFDDetector(FaceDetector):
20 | def __init__(self, device, path_to_detector=None, verbose=False):
21 | super(SFDDetector, self).__init__(device, verbose)
22 |
23 | self.device = device
24 |         model_weights = torch.load(path_to_detector, map_location=self.device)
25 |
26 | self.face_detector = s3fd()
27 | self.face_detector.load_state_dict(model_weights)
28 | self.face_detector.to(self.device)
29 | self.face_detector.eval()
30 |
31 | def detect_from_batch(self, tensor):
32 |
33 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device)
34 |
35 | new_bboxlists = []
36 | for i in range(bboxlists.shape[0]):
37 | bboxlist = bboxlists[i]
38 | keep = nms(bboxlist, 0.3)
39 | # print(keep)
40 | if len(keep)>0:
41 | bboxlist = bboxlist[keep, :]
42 | bboxlist = [x for x in bboxlist if x[-1] > 0.5]
43 | new_bboxlists.append(bboxlist)
44 | else:
45 | new_bboxlists.append([])
46 |
47 | return new_bboxlists
48 |
49 | @property
50 | def reference_scale(self):
51 | return 195
52 |
53 | @property
54 | def reference_x_shift(self):
55 | return 0
56 |
57 | @property
58 | def reference_y_shift(self):
59 | return 0
60 |
--------------------------------------------------------------------------------
/libs/ffhq_cropping.py:
--------------------------------------------------------------------------------
1 | """
2 | Crop images using facial landmarks
3 | """
4 | import numpy as np
5 | import cv2
6 | import os
7 | import collections
8 | import PIL.Image
9 | import PIL.ImageFile
10 | from PIL import Image
11 | import scipy.ndimage
12 |
13 | def pad_img_to_fit_bbox(img, x1, x2, y1, y2, crop_box):
14 | img_or = img.copy()
15 | img = cv2.copyMakeBorder(img,
16 | -min(0, y1), max(y2 - img.shape[0], 0),
17 | -min(0, x1), max(x2 - img.shape[1], 0), cv2.BORDER_REFLECT)
18 |
19 | y2 += -min(0, y1)
20 | y1 += -min(0, y1)
21 | x2 += -min(0, x1)
22 | x1 += -min(0, x1)
23 |
24 | pad = crop_box
25 | pad = (max(-pad[0], 0), max(-pad[1], 0), max(pad[2] - img_or.shape[1] , 0), max(pad[3] - img_or.shape[0] , 0))
26 |
27 | h, w, _ = img.shape
28 | y, x, _ = np.ogrid[:h, :w, :1]
29 | pad = np.array(pad, dtype=np.float32)
30 | pad[pad == 0] = 1e-10
31 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
32 | img = np.array(img, dtype=np.float32)
33 | blur = 5.0
34 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
35 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
36 |
37 | return img, x1, x2, y1, y2
38 |
39 | def crop_from_bbox(img, bbox):
40 | """
41 | bbox: tuple, (x1, y1, x2, y2)
42 | x: horizontal, y: vertical, exclusive
43 | """
44 | x1, y1, x2, y2 = bbox
45 | if x1 < 0 or y1 < 0 or x2 > img.shape[1] or y2 > img.shape[0]:
46 | img, x1, x2, y1, y2 = pad_img_to_fit_bbox(img, x1, x2, y1, y2, bbox)
47 | return img[y1:y2, x1:x2]
48 |
49 | def crop_using_landmarks(image, landmarks):
50 | image_size = 256
51 | center = ((landmarks.min(0) + landmarks.max(0)) / 2).round().astype(int)
52 | size = int(max(landmarks[:, 0].max() - landmarks[:, 0].min(), landmarks[:, 1].max() - landmarks[:, 1].min()))
53 | try:
54 | center[1] -= size // 6
55 | except:
56 | return None
57 |
58 | # Crop images and poses
59 | h, w, _ = image.shape
60 | img = Image.fromarray(image)
61 | crop_box = (center[0]-size, center[1]-size, center[0]+size, center[1]+size)
62 | image = crop_from_bbox(image, crop_box)
63 | try:
64 | img = Image.fromarray(image.astype(np.uint8))
65 | img = img.resize((image_size, image_size), Image.BICUBIC)
66 | pix = np.array(img)
67 | return pix
68 | except:
69 | return None
70 |
--------------------------------------------------------------------------------
/libs/sfd/detect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import os
5 | import sys
6 | import cv2
7 | import random
8 | import datetime
9 | import math
10 | import argparse
11 | import numpy as np
12 |
13 | import scipy.io as sio
14 | import zipfile
15 | from .net_s3fd import s3fd
16 | from .bbox import *
17 | import matplotlib.pyplot as plt
18 |
19 |
20 | def detect(net, img, device):
21 | img = img - np.array([104, 117, 123])
22 | img = img.transpose(2, 0, 1)
23 | # Creates a batch of 1
24 | img = img.reshape((1,) + img.shape)
25 |
26 |
27 | if torch.cuda.current_device() == 0:
28 | torch.backends.cudnn.benchmark = True
29 |
30 | img = torch.from_numpy(img).float().to(device)
31 |
32 |
33 | return batch_detect(net, img, device)
34 |
35 |
36 | def batch_detect(net, img_batch, device):
37 | """
38 | Inputs:
39 | - img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width)
40 | """
41 |
42 | BB, CC, HH, WW = img_batch.size()
43 |
44 | with torch.no_grad():
45 | olist = net(img_batch.float()) # patched uint8_t overflow error
46 |
47 |
48 | for i in range(len(olist) // 2):
49 | olist[i * 2] = F.softmax(olist[i * 2], dim=1)
50 |
51 | bboxlists = []
52 |
53 | olist = [oelem.data.cpu() for oelem in olist]
54 | for j in range(BB):
55 | bboxlist = []
56 | for i in range(len(olist) // 2):
57 | ocls, oreg = olist[i * 2], olist[i * 2 + 1]
58 | FB, FC, FH, FW = ocls.size() # feature map size
59 | stride = 2**(i + 2) # 4,8,16,32,64,128
60 | anchor = stride * 4
61 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
62 |
63 | for Iindex, hindex, windex in poss:
64 |
65 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
66 | score = ocls[j, 1, hindex, windex]
67 | loc = oreg[j, :, hindex, windex].contiguous().view(1, 4)
68 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
69 | variances = [0.1, 0.2]
70 | box = decode(loc, priors, variances)
71 | x1, y1, x2, y2 = box[0] * 1.0
72 | bboxlist.append([x1, y1, x2, y2, score])
73 | bboxlists.append(bboxlist)
74 |
75 | bboxlists = np.array(bboxlists)
76 |
77 | if 0 == len(bboxlists):
78 | bboxlists = np.zeros((1, 1, 5))
79 |
80 |
81 | return bboxlists
82 |
83 |
84 | def flip_detect(net, img, device):
85 | img = cv2.flip(img, 1)
86 | b = detect(net, img, device)
87 |
88 | bboxlist = np.zeros(b.shape)
89 | bboxlist[:, 0] = img.shape[1] - b[:, 2]
90 | bboxlist[:, 1] = b[:, 1]
91 | bboxlist[:, 2] = img.shape[1] - b[:, 0]
92 | bboxlist[:, 3] = b[:, 3]
93 | bboxlist[:, 4] = b[:, 4]
94 | return bboxlist
95 |
96 |
97 | def pts_to_bb(pts):
98 | min_x, min_y = np.min(pts, axis=0)
99 | max_x, max_y = np.max(pts, axis=0)
100 | return np.array([min_x, min_y, max_x, max_y])
101 |
--------------------------------------------------------------------------------
/libs/utilities.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 |
5 | def make_path(path):
6 | if not os.path.exists(path):
7 | os.makedirs(path, exist_ok= True)
8 |
9 | def _parse_metadata_file(metadata, dataset = 'vox1', frame = None):
10 |
11 | if dataset == 'vox1':
12 | metadata_info = []
13 | for path in metadata:
14 | frames = []
15 | bboxes = []
16 | with open(path) as f:
17 | segment_desc = f.read().strip()
18 |
19 | header, lines = segment_desc.split("\n\n")
20 | lines = lines.split("\n")
21 |
22 | for i in range(1,len(lines)):
23 |
24 | info = lines[i].split("\t")
25 | frames.append(int(info[0]))
26 | x1 = int(info[1])
27 | y1 = int(info[2])
28 | x2 = int(info[1]) + int(info[3])
29 | y2 = int(info[2]) + int(info[4])
30 | bboxes.append([x1, y1, x2, y2])
31 |
32 | info = {
33 | 'frames': frames,
34 | 'bboxes': bboxes
35 | }
36 | metadata_info.append(info)
37 | elif dataset == 'vox2':
38 | metadata_info = []
39 | for path in metadata:
40 | frames = []
41 | bboxes = []
42 | with open(path) as f:
43 | segment_desc = f.read().strip()
44 |
45 | header, lines = segment_desc.split("\n\n")
46 | lines = lines.split("\n")
47 |
48 | for i in range(1,len(lines)):
49 |
50 | info = lines[i].split("\t")
51 | frames.append(int(info[0]))
52 | x1 = int(float(info[1]) * frame.shape[1])
53 | y1 = int(float(info[2]) * frame.shape[0])
54 | x2 = int(float(info[3]) * frame.shape[1]) + x1
55 | y2 = int(float(info[4]) * frame.shape[0]) + y1
56 | bboxes.append([x1, y1, x2, y2])
57 |
58 | info = {
59 | 'frames': frames,
60 | 'bboxes': bboxes
61 | }
62 | metadata_info.append(info)
63 |
64 | else:
65 | print('Specify correct dataset')
66 | exit()
67 |
68 | return metadata_info
69 |
70 | def crop_box(image, bbox, scale_crop = 1.0):
71 |
72 | h_im, w_im, c = image.shape
73 |
74 | y1_hat = bbox[1]
75 | y2_hat = bbox[3]
76 | x1_hat = bbox[0]
77 | x2_hat = bbox[2]
78 |
79 | new_w = x2_hat - x1_hat
80 | w = x2_hat - x1_hat
81 | h = y2_hat - y1_hat
82 | cx = int(x1_hat + w/2)
83 | cy = int(y1_hat + h/2)
84 |
85 | w_hat = int(w * scale_crop)
86 | h_hat = int(h * scale_crop)
87 | x1_hat = cx - int(w_hat/2)
88 | if x1_hat < 0:
89 | x1_hat = 0
90 | y1_hat = cy - int(h_hat/2)
91 | if y1_hat < 0:
92 | y1_hat = 0
93 | x2_hat = x1_hat + w_hat
94 | y2_hat = y1_hat + h_hat
95 |
96 | if x2_hat > w_im:
97 | x2_hat = w_im
98 | if y2_hat > h_im:
99 | y2_hat = h_im
100 |
101 | if (y2_hat - y1_hat) > 20 and (x2_hat - x1_hat) > 20:
102 | crop = image[y1_hat:y2_hat, x1_hat:x2_hat, :]
103 | else:
104 | crop = image
105 |
106 | 	bbox_scaled = [x1_hat, y1_hat, x2_hat, y2_hat]
107 | 	return crop, bbox_scaled
108 |
109 |
110 | # Read image from path
111 | def read_image_opencv(image_path):
112 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) # BGR order
113 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
114 | return img.astype('uint8')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Download and Preprocessing scripts for VoxCeleb datasets
2 |
3 | This is an auxiliary repo for downloading VoxCeleb videos and preprocessing the extracted frames by cropping them around the face. For detecting faces and estimating the facial landmarks used for cropping, we use the method proposed in [1], [face-alignment](https://github.com/1adrianb/face-alignment).
4 |
5 |
6 |
7 |
8 |
9 | ## Installation
10 |
11 |
12 | * Python 3.5+
13 | * Linux
14 | * PyTorch (>= 1.5)
15 |
16 | ### Install requirements:
17 | ```
18 | pip install -r requirements.txt
19 | ```
20 |
21 | ### Install yt-dlp (used to download the YouTube videos; also listed in requirements.txt):
22 | ```
23 | pip install --upgrade yt-dlp
24 | ```
25 |
26 | ### Install ffmpeg
27 | ```
28 | sudo apt-get install ffmpeg
29 | ```
30 |
31 | ### Download the auxiliary models and save them under `./pretrained_models`
32 |
33 | | Path | Description
34 | | :--- | :----------
35 | |[FaceDetector](https://drive.google.com/file/d/1IWqJUTAZCelAZrUzfU38zK_ZM25fK32S/view?usp=sharing) | SFD face detector for [face-alignment](https://github.com/1adrianb/face-alignment).
36 |
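The scripts expect the checkpoint at `./pretrained_models/s3fd-619a316812.pth` (the default path in `libs/landmarks_estimation.py`). One way to fetch it from the command line; gdown is only a suggestion, not a project dependency:

```
mkdir -p pretrained_models
# e.g. using gdown (pip install gdown) to fetch the Google Drive file linked above:
# gdown --fuzzy "https://drive.google.com/file/d/1IWqJUTAZCelAZrUzfU38zK_ZM25fK32S/view?usp=sharing" -O pretrained_models/s3fd-619a316812.pth
```
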
37 | ## Overview
38 |
39 | * Download the videos of the VoxCeleb1 or VoxCeleb2 dataset from YouTube
40 | * Split each video into smaller chunk videos using the metadata provided by the datasets and (optionally) delete the original videos
41 | * Extract frames from each chunk video with REF_FPS = 25
42 | * Crop the frames using the face boxes from the metadata and the estimated facial landmarks
43 | * Files are saved as:
44 | ```
45 | ./path/to/voxdataset
46 | |-- id10271 # identity index
47 | | |-- 37nktPRUJ58 # video index
48 | | | |-- chunk_videos # chunk_videos: original video split into smaller ones
49 | | | | |-- 37nktPRUJ58#00001#257-396.mp4
50 | | | | |-- ...
51 | | | |-- frames # extracted frames
52 | | | | |-- 00_000025.png
53 | | | | |-- ...
54 | | | |-- frames_cropped # preprocessed frames
55 | | | | |-- 00_000025.png
56 | | | | |-- ...
57 | | |-- Zjc7Xy7aT8c
58 | | | | ...
59 | |-- id10273
60 | | | ...
61 | ```
62 |
63 | ## Download VoxCeleb datasets
64 |
65 | 1. Download metadata from [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) and [VoxCeleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html)
66 |
67 | ```
68 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox1_test_txt.zip
69 | unzip vox1_test_txt.zip
70 | mv ./txt ./vox1_txt_test
71 |
72 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox1_dev_txt.zip
73 | unzip vox1_dev_txt.zip
74 | mv ./txt ./vox1_txt_train
75 |
76 | ```
77 |
78 | ```
79 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox2_test_txt.zip
80 | unzip vox2_test_txt.zip
81 | mv ./txt ./vox2_txt_test
82 |
83 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox2_dev_txt.zip
84 | unzip vox2_dev_txt.zip
85 | mv ./txt ./vox2_txt_train
86 |
87 | ```
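Each `*.txt` file in these folders describes one utterance: a short header block, a blank line, and then a tab-separated table with one face box per frame. This is the structure that `_parse_metadata_file` in `libs/utilities.py` relies on; VoxCeleb1 stores absolute pixel values (x, y, width, height), while VoxCeleb2 stores them relative to the frame size. Illustrative layout only (header fields and column names are indicative, not verbatim):

```
<header block: identity / reference video information>

FRAME 	X 	Y 	W 	H
000245	112	90	205	205
000246	113	90	204	206
...
```
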
88 |
89 | 2. Run this script to download the videos from YouTube. With `--delete_mp4` the original full-length videos are removed once they have been split into chunks. Add `--extract_frames` and `--preprocessing` to also extract and preprocess the frames.
90 |
91 | ```
92 | python download_voxCeleb.py --dataset vox1 --output_path ./VoxCeleb1_test --metadata_path ./vox1_txt_test --delete_mp4
93 | ```
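
The VoxCeleb2 download works the same way; assuming the vox2 metadata was unpacked as above:

```
python download_voxCeleb.py --dataset vox2 --output_path ./VoxCeleb2_test --metadata_path ./vox2_txt_test --delete_mp4
```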
94 |
95 | ## Preprocessing of video frames
96 |
97 |
98 | 1. If videos have already been downloaded, run this script to extract and preprocess frames.
99 |
100 | ```
101 | python preprocess_voxCeleb.py --dataset vox1 --root_path ./VoxCeleb1_test --metadata_path ./vox1_txt_test
102 | ```
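
The cropping step can also be driven directly from Python for a single image. A minimal sketch, mirroring how `preprocess_voxCeleb.py` uses these modules (it assumes a CUDA device and the SFD checkpoint under `./pretrained_models`; the FAN weights are downloaded automatically):

```
import numpy as np
import torch
import cv2

from libs.utilities import read_image_opencv
from libs.landmarks_estimation import LandmarksEstimation
from libs.ffhq_cropping import crop_using_landmarks

# Loads the SFD face detector and the FAN landmark network
est = LandmarksEstimation(type='2D')

image = read_image_opencv('images/example.png')                  # RGB uint8, H x W x 3
tensor = torch.tensor(np.transpose(image, (2, 0, 1))).float().cuda()

with torch.no_grad():
    landmarks = est.detect_landmarks(tensor.unsqueeze(0))        # (1, 68, 2) for 2D landmarks
landmarks = landmarks[0].detach().cpu().numpy()

cropped = crop_using_landmarks(image, landmarks)                 # 256x256 RGB crop, or None on failure
if cropped is not None:
    cv2.imwrite('example_cropped.png', cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR))
```
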
103 | ## Acknowledgments
104 |
105 | This code borrows from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing) and [face-alignment](https://github.com/1adrianb/face-alignment).
106 |
107 | ## References
108 |
109 | [1] Bulat, Adrian, and Georgios Tzimiropoulos. "How far are we from solving the 2D & 3D face alignment problem? (And a dataset of 230,000 3D facial landmarks)." Proceedings of the IEEE International Conference on Computer Vision. 2017.
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------
/libs/sfd/bbox.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import cv2
5 | import random
6 | import datetime
7 | import time
8 | import math
9 | import argparse
10 | import numpy as np
11 | import torch
12 |
13 | try:
14 | from iou import IOU
15 | except BaseException:
16 | # IOU cython speedup 10x
17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18 | sa = abs((ax2 - ax1) * (ay2 - ay1))
19 | sb = abs((bx2 - bx1) * (by2 - by1))
20 | x1, y1 = max(ax1, bx1), max(ay1, by1)
21 | x2, y2 = min(ax2, bx2), min(ay2, by2)
22 | w = x2 - x1
23 | h = y2 - y1
24 | if w < 0 or h < 0:
25 | return 0.0
26 | else:
27 | return 1.0 * w * h / (sa + sb - w * h)
28 |
29 |
30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33 | dw, dh = math.log(ww / aww), math.log(hh / ahh)
34 | return dx, dy, dw, dh
35 |
36 |
37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38 | xc, yc = dx * aww + axc, dy * ahh + ayc
39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41 | return x1, y1, x2, y2
42 |
43 |
44 | def nms(dets, thresh):
45 | # print(dets)
46 | if 0 == len(dets):
47 | return []
48 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
49 | # print(x1,x2,y1,y2)
50 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
51 | order = scores.argsort()[::-1]
52 |
53 | keep = []
54 | while order.size > 0:
55 | i = order[0]
56 | keep.append(i)
57 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
58 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
59 |
60 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
61 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
62 |
63 | inds = np.where(ovr <= thresh)[0]
64 | order = order[inds + 1]
65 |
66 | return keep
67 |
68 |
69 | def encode(matched, priors, variances):
70 | """Encode the variances from the priorbox layers into the ground truth boxes
71 | we have matched (based on jaccard overlap) with the prior boxes.
72 | Args:
73 | matched: (tensor) Coords of ground truth for each prior in point-form
74 | Shape: [num_priors, 4].
75 | priors: (tensor) Prior boxes in center-offset form
76 | Shape: [num_priors,4].
77 | variances: (list[float]) Variances of priorboxes
78 | Return:
79 | encoded boxes (tensor), Shape: [num_priors, 4]
80 | """
81 |
82 | # dist b/t match center and prior's center
83 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
84 | # encode variance
85 | g_cxcy /= (variances[0] * priors[:, 2:])
86 | # match wh / prior wh
87 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
88 | g_wh = torch.log(g_wh) / variances[1]
89 | # return target for smooth_l1_loss
90 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
91 |
92 |
93 | def decode(loc, priors, variances):
94 | """Decode locations from predictions using priors to undo
95 | the encoding we did for offset regression at train time.
96 | Args:
97 | loc (tensor): location predictions for loc layers,
98 | Shape: [num_priors,4]
99 | priors (tensor): Prior boxes in center-offset form.
100 | Shape: [num_priors,4].
101 | variances: (list[float]) Variances of priorboxes
102 | Return:
103 | decoded bounding box predictions
104 | """
105 |
106 | boxes = torch.cat((
107 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
108 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
109 | boxes[:, :2] -= boxes[:, 2:] / 2
110 | boxes[:, 2:] += boxes[:, :2]
111 | return boxes
112 |
--------------------------------------------------------------------------------
/libs/sfd/net_s3fd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class L2Norm(nn.Module):
7 | def __init__(self, n_channels, scale=1.0):
8 | super(L2Norm, self).__init__()
9 | self.n_channels = n_channels
10 | self.scale = scale
11 | self.eps = 1e-10
12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13 | self.weight.data *= 0.0
14 | self.weight.data += self.scale
15 |
16 | def forward(self, x):
17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18 | x = x / norm * self.weight.view(1, -1, 1, 1)
19 | return x
20 |
21 |
22 | class s3fd(nn.Module):
23 | def __init__(self):
24 | super(s3fd, self).__init__()
25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27 |
28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30 |
31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34 |
35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38 |
39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42 |
43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45 |
46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48 |
49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51 |
52 | self.conv3_3_norm = L2Norm(256, scale=10)
53 | self.conv4_3_norm = L2Norm(512, scale=8)
54 | self.conv5_3_norm = L2Norm(512, scale=5)
55 |
56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62 |
63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69 |
70 | def forward(self, x):
71 | h = F.relu(self.conv1_1(x))
72 | h = F.relu(self.conv1_2(h))
73 | h = F.max_pool2d(h, 2, 2)
74 |
75 | h = F.relu(self.conv2_1(h))
76 | h = F.relu(self.conv2_2(h))
77 | h = F.max_pool2d(h, 2, 2)
78 |
79 | h = F.relu(self.conv3_1(h))
80 | h = F.relu(self.conv3_2(h))
81 | h = F.relu(self.conv3_3(h))
82 | f3_3 = h
83 | h = F.max_pool2d(h, 2, 2)
84 |
85 | h = F.relu(self.conv4_1(h))
86 | h = F.relu(self.conv4_2(h))
87 | h = F.relu(self.conv4_3(h))
88 | f4_3 = h
89 | h = F.max_pool2d(h, 2, 2)
90 |
91 | h = F.relu(self.conv5_1(h))
92 | h = F.relu(self.conv5_2(h))
93 | h = F.relu(self.conv5_3(h))
94 | f5_3 = h
95 | h = F.max_pool2d(h, 2, 2)
96 |
97 | h = F.relu(self.fc6(h))
98 | h = F.relu(self.fc7(h))
99 | ffc7 = h
100 | h = F.relu(self.conv6_1(h))
101 | h = F.relu(self.conv6_2(h))
102 | f6_2 = h
103 | h = F.relu(self.conv7_1(h))
104 | h = F.relu(self.conv7_2(h))
105 | f7_2 = h
106 |
107 | f3_3 = self.conv3_3_norm(f3_3)
108 | f4_3 = self.conv4_3_norm(f4_3)
109 | f5_3 = self.conv5_3_norm(f5_3)
110 |
111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117 | cls4 = self.fc7_mbox_conf(ffc7)
118 | reg4 = self.fc7_mbox_loc(ffc7)
119 | cls5 = self.conv6_2_mbox_conf(f6_2)
120 | reg5 = self.conv6_2_mbox_loc(f6_2)
121 | cls6 = self.conv7_2_mbox_conf(f7_2)
122 | reg6 = self.conv7_2_mbox_loc(f7_2)
123 |
124 | # max-out background label
125 | chunk = torch.chunk(cls1, 4, 1)
126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127 | cls1 = torch.cat([bmax, chunk[3]], dim=1)
128 |
129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
130 |
--------------------------------------------------------------------------------
/libs/sfd/core.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import glob
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | import cv2
7 | from skimage import io
8 |
9 |
10 | class FaceDetector(object):
11 | """An abstract class representing a face detector.
12 |
13 | Any other face detection implementation must subclass it. All subclasses
14 | must implement ``detect_from_image``, that return a list of detected
15 | bounding boxes. Optionally, for speed considerations detect from path is
16 | recommended.
17 | """
18 |
19 | def __init__(self, device, verbose):
20 | self.device = device
21 | self.verbose = verbose
22 |
23 | # if verbose:
24 | # if 'cpu' in device:
25 | # logger = logging.getLogger(__name__)
26 | # logger.warning("Detection running on CPU, this may be potentially slow.")
27 |
28 | # if 'cpu' not in device and 'cuda' not in device:
29 | # if verbose:
30 | # logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
31 | # raise ValueError
32 |
33 | def detect_from_image(self, tensor_or_path):
34 | """Detects faces in a given image.
35 |
36 | This function detects the faces present in a provided BGR(usually)
37 | image. The input can be either the image itself or the path to it.
38 |
39 | Arguments:
40 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
41 | to an image or the image itself.
42 |
43 | Example::
44 |
45 | >>> path_to_image = 'data/image_01.jpg'
46 | ... detected_faces = detect_from_image(path_to_image)
47 | [A list of bounding boxes (x1, y1, x2, y2)]
48 | >>> image = cv2.imread(path_to_image)
49 | ... detected_faces = detect_from_image(image)
50 | [A list of bounding boxes (x1, y1, x2, y2)]
51 |
52 | """
53 | raise NotImplementedError
54 |
55 | def detect_from_batch(self, tensor):
56 |         """Detects faces in a batch of images.
57 | 
58 |         This function detects the faces present in a batch of images,
59 |         given as a tensor of shape (batch, channels, height, width).
60 | 
61 |         Arguments:
62 |             tensor {torch.tensor} -- image batch tensor.
63 | 
64 |         Example::
65 | 
66 |             >>> image = cv2.imread('data/image_01.jpg')
67 |             ... batch = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float()
68 |             ... detected_faces = detect_from_batch(batch)
69 |             [One list of bounding boxes (x1, y1, x2, y2, score) per image]
70 |             >>> detected_faces[0]
71 |             [[x1, y1, x2, y2, score], ...]
72 | 
73 |         """
74 | raise NotImplementedError
75 |
76 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
77 | """Detects faces from all the images present in a given directory.
78 |
79 | Arguments:
80 | path {string} -- a string containing a path that points to the folder containing the images
81 |
82 |         Keyword Arguments:
83 |             extensions {list} -- list of strings containing the extensions to be
84 |                 considered, in the format ``.extension_name``
85 |                 (default: {['.jpg', '.png']})
86 |             recursive {bool} -- whether to scan the folder recursively (default: {False})
87 |             show_progress_bar {bool} -- display a progress bar (default: {True})
88 |
89 | Example:
90 | >>> directory = 'data'
91 | ... detected_faces = detect_from_directory(directory)
92 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
93 |
94 | """
95 | if self.verbose:
96 | logger = logging.getLogger(__name__)
97 |
98 | if len(extensions) == 0:
99 | if self.verbose:
100 |                 logger.error("Expected at least one extension, but none was received.")
101 | raise ValueError
102 |
103 | if self.verbose:
104 | logger.info("Constructing the list of images.")
105 | additional_pattern = '/**/*' if recursive else '/*'
106 | files = []
107 | for extension in extensions:
108 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
109 |
110 | if self.verbose:
111 | logger.info("Finished searching for images. %s images found", len(files))
112 | logger.info("Preparing to run the detection.")
113 |
114 | predictions = {}
115 | for image_path in tqdm(files, disable=not show_progress_bar):
116 | if self.verbose:
117 | logger.info("Running the face detector on image: %s", image_path)
118 | predictions[image_path] = self.detect_from_image(image_path)
119 |
120 | if self.verbose:
121 | logger.info("The detector was successfully run on all %s images", len(files))
122 |
123 | return predictions
124 |
125 | @property
126 | def reference_scale(self):
127 | raise NotImplementedError
128 |
129 | @property
130 | def reference_x_shift(self):
131 | raise NotImplementedError
132 |
133 | @property
134 | def reference_y_shift(self):
135 | raise NotImplementedError
136 |
137 | @staticmethod
138 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
139 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
140 |
141 | Arguments:
142 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
143 | """
144 | if isinstance(tensor_or_path, str):
145 | return cv2.imread(tensor_or_path) if not rgb else io.imread(tensor_or_path)
146 | elif torch.is_tensor(tensor_or_path):
147 | # Call cpu in case its coming from cuda
148 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
149 | elif isinstance(tensor_or_path, np.ndarray):
150 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
151 | else:
152 | raise TypeError
153 |
--------------------------------------------------------------------------------
/libs/landmarks_estimation.py:
--------------------------------------------------------------------------------
1 | """
2 | The face detector used is SFD (taken from face-alignment)
3 | https://github.com/1adrianb/face-alignment
4 | """
5 | import os
6 | import numpy as np
7 | import cv2
8 | from enum import Enum
9 | import torch
10 | from torch.utils.model_zoo import load_url
11 |
12 |
13 | from libs.sfd.sfd_detector import SFDDetector as FaceDetector
14 | from libs.fan_model.models import FAN, ResNetDepth
15 | from libs.fan_model.utils import *
16 |
17 | models_urls = {
18 | '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-11f355bf06.pth.tar',
19 | '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-7835d9f11d.pth.tar',
20 | 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-2a464da4ea.pth.tar',
21 | }
22 |
23 | class LandmarksType(Enum):
24 | """Enum class defining the type of landmarks to detect.
25 |
26 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
27 |     ``_2halfD`` - these points represent the projection of the 3D points onto the 2D image plane
28 |     ``_3D`` - detect the points ``(x,y,z)`` in a 3D space
29 |
30 | """
31 | _2D = 1
32 | _2halfD = 2
33 | _3D = 3
34 |
35 | class NetworkSize(Enum):
36 | # TINY = 1
37 | # SMALL = 2
38 | # MEDIUM = 3
39 | LARGE = 4
40 |
41 | def __new__(cls, value):
42 | member = object.__new__(cls)
43 | member._value_ = value
44 | return member
45 |
46 | def __int__(self):
47 | return self.value
48 |
49 |
50 | def get_preds_fromhm(hm, center=None, scale=None):
51 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center
52 | and the scale is provided the function will return the points also in
53 | the original coordinate frame.
54 |
55 | Arguments:
56 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
57 |
58 | Keyword Arguments:
59 | center {torch.tensor} -- the center of the bounding box (default: {None})
60 | scale {float} -- face scale (default: {None})
61 | """
62 | max, idx = torch.max(
63 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
64 | idx = idx + 1
65 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
66 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
67 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
68 |
69 | for i in range(preds.size(0)):
70 | for j in range(preds.size(1)):
71 | hm_ = hm[i, j, :]
72 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
73 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
74 | diff = torch.FloatTensor(
75 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
76 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
77 | preds[i, j].add_(diff.sign_().mul_(.25))
78 |
79 | preds.add_(-.5)
80 |
81 | preds_orig = torch.zeros(preds.size())
82 | if center is not None and scale is not None:
83 | for i in range(hm.size(0)):
84 | for j in range(hm.size(1)):
85 | preds_orig[i, j] = transform(
86 | preds[i, j], center, scale, hm.size(2), True)
87 |
88 | return preds, preds_orig
89 |
90 | def draw_detected_face(img, face):
91 | x_min = int(face[0])
92 | y_min = int(face[1])
93 | x_max = int(face[2])
94 | y_max = int(face[3])
95 |
96 | cv2.rectangle(img, (int(x_min),int(y_min)), (int(x_max),int(y_max)), (255,0,0), 2)
97 |
98 | return img
99 |
100 |
101 | class LandmarksEstimation():
102 | def __init__(self, type = '3D', path_to_detector = './pretrained_models/s3fd-619a316812.pth'):
103 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104 | # Load all needed models - Face detector and Pose detector
105 | network_size = NetworkSize.LARGE
106 | network_size = int(network_size)
107 | if type == '3D':
108 | self.landmarks_type = LandmarksType._3D
109 | else:
110 | self.landmarks_type = LandmarksType._2D
111 | self.flip_input = False
112 |
113 | #################### SFD face detection ###################
114 | if not os.path.exists(path_to_detector):
115 | print('Pretrained model of SFD face detector does not exist in {}'.format(path_to_detector))
116 | exit()
117 | self.face_detector = FaceDetector(device=self.device, verbose=False, path_to_detector = path_to_detector)
118 | ###########################################################
119 |
120 |         ################### Initialise the face alignment networks ###################
121 | self.face_alignment_net = FAN(network_size)
122 | if self.landmarks_type == LandmarksType._2D: #
123 | network_name = '2DFAN-' + str(network_size)
124 | else:
125 | network_name = '3DFAN-' + str(network_size)
126 | fan_weights = load_url(models_urls[network_name], map_location=lambda storage, loc: storage)
127 | self.face_alignment_net.load_state_dict(fan_weights)
128 | self.face_alignment_net.to(self.device)
129 | self.face_alignment_net.eval()
130 | ##############################################################################
131 |
132 |         # Initialise the depth prediction network if 3D landmarks are requested
133 | if self.landmarks_type == LandmarksType._3D:
134 | self.depth_prediciton_net = ResNetDepth()
135 | depth_weights = load_url(models_urls['depth'], map_location=lambda storage, loc: storage)
136 | depth_dict = {
137 | k.replace('module.', ''): v for k,
138 | v in depth_weights['state_dict'].items()}
139 | self.depth_prediciton_net.load_state_dict(depth_dict)
140 | self.depth_prediciton_net.to(self.device)
141 | self.depth_prediciton_net.eval()
142 |
143 | def get_landmarks(self, face, image):
144 |
145 | center = torch.FloatTensor(
146 | [(face[2] + face[0]) / 2.0,
147 | (face[3] + face[1]) / 2.0])
148 |
149 | center[1] = center[1] - (face[3] - face[1]) * 0.12
150 | scale = (face[2] - face[0] + face[3] - face[1]) / self.face_detector.reference_scale
151 |
152 |         inp = crop_torch(image, center, scale).float().to(self.device)
153 | inp = inp.div(255.0)
154 |
155 | out = self.face_alignment_net(inp)[-1]
156 |
157 | if self.flip_input:
158 | out = out + flip(self.face_alignment_net(flip(inp))
159 | [-1], is_label=True)
160 | out = out.cpu()
161 |
162 | pts, pts_img = get_preds_fromhm(out, center, scale)
163 |         out = out.to(self.device)
164 | # Added 3D landmark support
165 | if self.landmarks_type == LandmarksType._3D:
166 | pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)
167 | heatmaps = torch.zeros((68,256,256), dtype=torch.float32)
168 | for i in range(68):
169 | if pts[i, 0] > 0:
170 | heatmaps[i] = draw_gaussian(
171 | heatmaps[i], pts[i], 2)
172 |
173 | heatmaps = heatmaps.unsqueeze(0)
174 |
175 | heatmaps = heatmaps.to(self.device)
176 | depth_pred = self.depth_prediciton_net(
177 | torch.cat((inp, heatmaps), 1)).view(68, 1)
178 |
179 |             pts_img = pts_img.to(self.device)
180 | pts_img = torch.cat(
181 | (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)
182 | else:
183 | pts, pts_img = pts.view(-1, 68, 2) * 4, pts_img.view(-1, 68, 2)
184 |
185 | return pts_img, out
186 |
187 | def detect_landmarks(self, image):
188 |
189 | if len(image.shape) == 3:
190 | image = image.unsqueeze(0)
191 |
192 |         if self.device.type == 'cuda':
193 | image = image.cuda()
194 |
195 | with torch.no_grad():
196 | detected_faces = self.face_detector.detect_from_batch(image)
197 |
198 | if self.landmarks_type == LandmarksType._3D:
199 | landmarks = torch.empty((1, 68, 3))
200 | else:
201 | landmarks = torch.empty((1, 68, 2))
202 |
203 | for face in detected_faces[0]:
204 | conf = face[4]
205 | if conf > 0.99:
206 | pts_img, heatmaps = self.get_landmarks(face, image)
207 | landmarks[0] = pts_img
208 |
209 | return landmarks
--------------------------------------------------------------------------------
/download_voxCeleb.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import imageio
6 | import os
7 | import warnings
8 | import glob
9 | import time
10 | from tqdm import tqdm
11 | from argparse import ArgumentParser
12 | from skimage import img_as_ubyte
13 | from skimage.transform import resize
14 | warnings.filterwarnings("ignore")
15 | import cv2
16 | import yt_dlp as youtube_dl # Changed import
17 | import subprocess
18 |
19 | from libs.utilities import make_path
20 | from preprocess_voxCeleb import extract_frames_opencv, preprocess_frames
21 | from libs.landmarks_estimation import LandmarksEstimation
22 |
23 | """
24 | 1. Download videos from youtube for VoxCeleb1 dataset
25 | 2. Generate chunk videos using the metadata provided by VoxCeleb1 dataset
26 |
27 | Optionally:
28 | 3. Extract frames from chunk videos
29 | 4. Preprocess extracted frames by cropping them around the detected faces
30 |
31 | Arguments:
32 | output_path: path to save the videos
33 | metadata_path: txt files from VoxCeleb
34 | dataset: dataset name: vox1 or vox2
35 | fail_video_ids: txt file to save the videos ids that fail to download
36 | extract_frames: select for frame extraction from videos
37 | preprocessing: select for frame preprocessing
38 | delete_mp4: select to delete the original video from youtube
39 | delete_or_frames: select to delete the original extracted frames
40 |
41 | python download_voxCeleb.py --output_path ./VoxCeleb1_test --metadata_path ./txt_test --dataset vox1 \
42 | --fail_video_ids ./fail_video_ids_test.txt --delete_mp4 --extract_frames --preprocessing
43 |
44 | """
45 |
46 | DEVNULL = open(os.devnull, 'wb')
47 |
48 | REF_FPS = 25
49 |
50 | parser = ArgumentParser()
51 | parser.add_argument("--output_path", required = True, help='Path to save the videos')
52 | parser.add_argument("--metadata_path", required = True, help='Path to metadata')
53 | parser.add_argument("--dataset", required = True, type = str, choices=('vox1', 'vox2'), help="Download vox1 or vox2 dataset")
54 |
55 | parser.add_argument("--fail_video_ids", default=None, help='Txt file to save videos that fail to download')
56 | parser.add_argument("--extract_frames", action='store_true', help='Extract frames from videos')
57 | parser.set_defaults(extract_frames=False)
58 | parser.add_argument("--preprocessing", action='store_true', help='Preprocess extracted frames')
59 | parser.set_defaults(preprocessing=False)
60 | parser.add_argument("--delete_mp4", action='store_true', help='Delete original video downloaded from youtube')
61 | parser.set_defaults(delete_mp4=False)
62 | parser.add_argument("--delete_or_frames", dest='delete_or_frames', action='store_true', help="Delete original frames and keep only the cropped frames")
63 | parser.set_defaults(delete_or_frames=False)
64 |
65 | def my_hook(d):
66 | if d['status'] == 'finished':
67 | print('Done downloading, now converting ...')
68 |
69 | def download_video(video_id, video_path, id_path, fail_video_ids = None):
70 | ydl_opts = {
71 | 'format': 'mp4',
72 | 'outtmpl': video_path,
73 | 'progress_hooks': [my_hook],
74 |         'cookiefile': 'cookies.txt'  # yt-dlp reads the cookies file from the 'cookiefile' option
75 | }
76 | success = True
77 | try:
78 | with youtube_dl.YoutubeDL(ydl_opts) as ydl:
79 | ydl.download(['https://www.youtube.com/watch?v=' + video_id])
80 | except KeyboardInterrupt:
81 | print ('Stopped')
82 | exit()
83 |     except Exception:
84 | print('Error downloading video {}'.format(video_id))
85 | success = False
86 | if fail_video_ids is not None:
87 | f = open(fail_video_ids, "a")
88 | f.write(id_path + '/' + video_id + '\n')
89 | f.close()
90 | return success
91 |
92 | def split_in_utterances(video_id, video_path, utterance_files, chunk_folder):
93 | chunk_videos = []
94 | utterances = [pd.read_csv(f, sep='\t', skiprows=6) for f in utterance_files]
95 | for i, utterance in enumerate(utterances):
96 |         first_frame, last_frame = utterance['FRAME '].iloc[0], utterance['FRAME '].iloc[-1]  # the column name in the VoxCeleb txt files includes a trailing space
97 | st = first_frame
98 | en = last_frame
99 | first_frame = round(first_frame / float(REF_FPS), 3)
100 | last_frame = round(last_frame / float(REF_FPS), 3)
101 |         head, tail = os.path.split(utterance_files[i])
102 | tail = tail.split('.tx')[0]
103 | chunk_name = os.path.join(chunk_folder, video_id + '#' + tail + '#' + str(st) + '-' + str(en) + '.mp4')
104 |
105 | command_fps = 'ffmpeg -y -i {} -qscale:v 5 -r 25 -threads 1 -ss {} -to {} -strict -2 {} -loglevel quiet'.format(video_path, first_frame, last_frame, chunk_name)
106 | os.system(command_fps)
107 | chunk_videos.append(chunk_name)
108 |
109 | return chunk_videos
110 |
111 | if __name__ == "__main__":
112 |
113 |
114 | args = parser.parse_args()
115 | extract_frames = args.extract_frames
116 | preprocessing = args.preprocessing
117 | fail_video_ids = args.fail_video_ids
118 | output_path = args.output_path
119 | make_path(output_path)
120 | delete_mp4 = args.delete_mp4
121 | delete_or_frames = args.delete_or_frames
122 | metadata_path = args.metadata_path
123 | dataset = args.dataset
124 |
125 | if not os.path.exists(metadata_path):
126 | print('Please download the metadata for {} dataset'.format(dataset))
127 | exit()
128 |
129 | ids_path = glob.glob(os.path.join(metadata_path, '*/'))
130 | ids_path.sort()
131 | print('{} dataset has {} identities'.format(dataset, len(ids_path)))
132 |
133 | print('--Delete original mp4 videos: \t\t{}'.format(delete_mp4))
134 | print('--Delete original frames: \t\t{}'.format(delete_or_frames))
135 | print('--Extract frames from chunk videos: \t{}'.format(extract_frames))
136 | print('--Preprocess original frames: \t\t{}'.format(preprocessing))
137 |
138 | if preprocessing:
139 | landmark_est = LandmarksEstimation(type = '2D')
140 |
141 | for i, id_path in enumerate(ids_path):
142 | id_index = id_path.split('/')[-2]
143 | videos_path = glob.glob(os.path.join(id_path, '*/'))
144 | videos_path.sort()
145 | print('*********************************************************')
146 | print('Identity {}/{}: {} videos for {} identity'.format(i, len(ids_path), len(videos_path), id_index))
147 |
148 | for j, video_path in enumerate(videos_path):
149 |
150 | print('{}/{} videos'.format(j, len(videos_path)))
151 |
152 | video_id = video_path.split('/')[-2]
153 | output_path_video = os.path.join(output_path, id_index, video_id)
154 | make_path(output_path_video)
155 |
156 | print('Download video id {}. Save to {}'.format(video_id, output_path_video))
157 |
158 | txt_metadata = glob.glob(os.path.join(video_path, '*.txt'))
159 | txt_metadata.sort()
160 |
161 | mp4_path = os.path.join(output_path_video, '{}.mp4'.format(video_id))
162 | if not os.path.exists(mp4_path):
163 | success = download_video(video_id, mp4_path, id_index, fail_video_ids = fail_video_ids)
164 | else:
165 | # Video already exists
166 | success = True
167 |
168 | if success:
169 | # Split in small videos using the metadata
170 | output_path_chunk_videos = os.path.join(output_path, id_index, video_id, 'chunk_videos')
171 | make_path(output_path_chunk_videos)
172 | chunk_videos = split_in_utterances(video_id, mp4_path, txt_metadata, output_path_chunk_videos)
173 | if delete_mp4: # Delete original video downloaded from youtube
174 | command_delete = 'rm -rf {}'.format(mp4_path)
175 | os.system(command_delete)
176 |
177 | extracted_frames_path = os.path.join(output_path_video, 'frames')
178 | if extract_frames:
179 | # Run frame extraction
180 | extract_frames_opencv(chunk_videos, REF_FPS, extracted_frames_path)
181 | if preprocessing:
182 | # Run preprocessing
183 | image_files = glob.glob(os.path.join(extracted_frames_path, '*.png'))
184 | image_files.sort()
185 | if len(image_files) > 0:
186 | save_dir = os.path.join(output_path_video, 'frames_cropped')
187 | make_path(save_dir)
188 | preprocess_frames(dataset, output_path_video, extracted_frames_path, image_files, save_dir, txt_metadata, landmark_est)
189 | else:
190 | print('There are no extracted frames on path: {}'.format(extracted_frames_path))
191 |
192 |                 if delete_or_frames and preprocessing and len(image_files) > 0: # Delete original frames (image_files exists only when preprocessing ran)
193 | command_delete = 'rm -rf {}'.format(extracted_frames_path)
194 | os.system(command_delete)
195 | else:
196 | print('Error downloading video {}/{}. Deleting folder {}'.format(id_index, video_id, output_path_video))
197 | command_delete = 'rm -rf {}'.format(output_path_video)
198 | os.system(command_delete)
199 |
200 |
--------------------------------------------------------------------------------
/preprocess_voxCeleb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import os
4 | import glob
5 | from argparse import ArgumentParser
6 | import cv2
7 | import torch
8 | from skimage.transform import resize
9 |
10 | from libs.utilities import make_path, _parse_metadata_file, crop_box, read_image_opencv
11 | from libs.ffhq_cropping import crop_using_landmarks
12 | from libs.landmarks_estimation import LandmarksEstimation
13 |
14 | """
15 | If chunk videos have already been generated using download_voxCeleb.py:
16 |
17 | 1. Extract frames from chunk videos
18 | 2. Preprocess extracted frames by cropping them around the detected faces
19 |
20 | Arguments:
21 | root_path: path where chunk videos are saved
22 | metadata_path: txt files from VoxCeleb
23 | dataset: dataset name: vox1 or vox2
24 | delete_videos: select to delete all videos
25 | delete_or_frames: select to delete the original extracted frames
26 |
27 | python preprocess_voxCeleb.py --root_path ./VoxCeleb1_test --metadata_path ./vox1_txt_test --dataset vox1
28 |
29 | """
30 |
31 | REF_FPS = 25 # fps to extract frames
32 | REF_SIZE = 360 # Height
33 | LOW_RES_SIZE = 400
34 |
35 |
36 | parser = ArgumentParser()
37 | parser.add_argument("--root_path", default='videos', required = True, help='Path to youtube videos')
38 | parser.add_argument("--metadata_path", default='metadata', required = True, help='Path to metadata')
39 | parser.add_argument("--dataset", required = True, type = str, choices=('vox1', 'vox2'), help="Download vox1 or vox2 dataset")
40 |
41 | parser.add_argument("--delete_videos", action='store_true', help='Delete chunk videos')
42 | parser.set_defaults(delete_videos=False)
43 | parser.add_argument("--delete_or_frames", dest='delete_or_frames', action='store_true', help="Delete original frames and keep only the cropped frames")
44 | parser.set_defaults(delete_or_frames=False)
45 |
46 |
47 | def get_frames(video_path, frames_path, video_index, fps):
48 | cap = cv2.VideoCapture(video_path)
49 | counter = 0
50 |     # keep one frame every `fps` frames (i.e. one saved frame per second of video)
51 |     frame_skip = fps
52 | while cap.isOpened():
53 | ret, frame = cap.read()
54 | if not ret:
55 | break
56 | if counter % frame_skip == 0:
57 | cv2.imwrite(os.path.join(frames_path, '{:02d}_{:06d}.png'.format(video_index, counter)), frame)
58 | counter += 1
59 |
60 | cap.release()
61 | cv2.destroyAllWindows()
62 |
63 |
64 | def extract_frames_opencv(videos_tmp, fps, frames_path):
65 |
66 | print('1. Extract frames')
67 | make_path(frames_path)
68 | for i in tqdm(range(len(videos_tmp))):
69 | get_frames(videos_tmp[i], frames_path, i, fps)
70 |
71 |
72 | def preprocess_frames(dataset, output_path_video, frames_path, image_files, save_dir, txt_metadata, landmark_est = None):
73 |
74 | if dataset == 'vox2':
75 | image_ref = read_image_opencv(image_files[0])
76 | mult = image_ref.shape[0] / REF_SIZE
77 | image_ref = resize(image_ref, (REF_SIZE, int(image_ref.shape[1] / mult)), preserve_range=True)
78 | else:
79 | image_ref = None
80 |
81 | info_metadata = _parse_metadata_file(txt_metadata, dataset = dataset, frame = image_ref)
82 |
83 | errors = []
84 | chunk_id = 0
85 | frame_i = 0
86 | print('2. Preprocess frames')
87 | for i in tqdm(range(len(image_files))):
88 |
89 | # Check from which chunk video each frame is extracted.
90 | # Frames are saved as chunkid_index.png
91 | image_file = image_files[i]
92 | image_name = image_file.split('/')[-1]
93 | image_chunk_id = image_name.split('.')[0]
94 | image_chunk_id = int(image_chunk_id.split('_')[0])
95 | bbox = None
96 | if chunk_id != image_chunk_id:
97 | chunk_id += 1
98 | frame_i = 0
99 | #########################################
100 | if chunk_id < len(info_metadata):
101 | frames = info_metadata[chunk_id]['frames']
102 | bboxes_metadata = info_metadata[chunk_id]['bboxes']
103 |                 # Map the frame_i-th saved frame back to its metadata row: one frame is kept every REF_FPS rows
104 |                 index = frame_i+1 + frame_i*(REF_FPS-1)
105 | if index < len(bboxes_metadata):
106 | bbox = bboxes_metadata[index]
107 | frame = frames[index]
108 |
109 | if bbox is not None:
110 | image = read_image_opencv(image_file)
111 | frame = image.copy()
112 | (h, w) = image.shape[:2]
113 |
114 | scale_res = REF_SIZE / float(h)
115 | bbox_new = bbox.copy()
116 | bbox_new[0] = bbox_new[0] / scale_res
117 | bbox_new[1] = bbox_new[1] / scale_res
118 | bbox_new[2] = bbox_new[2] / scale_res
119 | bbox_new[3] = bbox_new[3] / scale_res
120 |
121 | cropped_image, bbox_scaled = crop_box(frame, bbox_new, scale_crop = 2.0)
122 | filename = os.path.join(save_dir, image_name)
123 | cv2.imwrite(filename, cv2.cvtColor(cropped_image.copy(), cv2.COLOR_RGB2BGR))
124 | h, w, _ = cropped_image.shape
125 | image_tensor = torch.tensor(np.transpose(cropped_image, (2,0,1))).float().cuda()
126 |
127 | if landmark_est is not None:
128 | with torch.no_grad():
129 | landmarks = landmark_est.detect_landmarks( image_tensor.unsqueeze(0))
130 | landmarks = landmarks[0].detach().cpu().numpy()
131 | landmarks = np.asarray(landmarks)
132 | condition = np.any(landmarks > w) or np.any(landmarks < 0)
133 | if (condition == False) :
134 | img = crop_using_landmarks(cropped_image, landmarks)
135 | if img is not None:
136 | filename = os.path.join(save_dir, image_name)
137 | cv2.imwrite(filename, cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR))
138 | frame_i += 1
139 |
140 | if __name__ == "__main__":
141 |
142 |
143 | args = parser.parse_args()
144 | root_path = args.root_path
145 | if not os.path.exists(root_path):
146 | print('Videos path {} does not exist'.format(root_path))
147 |
148 | metadata_path = args.metadata_path
149 | delete_videos = args.delete_videos
150 | delete_or_frames = args.delete_or_frames
151 | dataset = args.dataset
152 |
153 | if not os.path.exists(metadata_path):
154 | print('Please download the metadata for {} dataset'.format(dataset))
155 | exit()
156 | landmark_est = LandmarksEstimation(type = '2D')
157 |
158 | print('--Delete chunk videos: \t\t\t{}'.format(delete_videos))
159 | print('--Delete original frames: \t\t{}'.format(delete_or_frames))
160 |
161 | ids_path = glob.glob(os.path.join(root_path, '*/'))
162 | ids_path.sort()
163 | print('Dataset has {} identities'.format(len(ids_path)))
164 |
165 | data_csv = []
166 | data_low_res = []
167 | for i, id_path in enumerate(ids_path):
168 | id_index = id_path.split('/')[-2]
169 | videos_path = glob.glob(os.path.join(id_path, '*/'))
170 | videos_path.sort()
171 | print('*********************************************************')
172 | print('Identity {}/{}: {} videos for {} identity'.format(i, len(ids_path), len(videos_path), id_index))
173 |
174 | count = 0
175 | for j, video_path in enumerate(videos_path):
176 | video_id = video_path.split('/')[-2]
177 |
178 | print('{}/{} videos'.format(j, len(videos_path)))
179 |
180 | output_path_video = os.path.join(root_path, id_index, video_id)
181 | output_path_chunk_videos = os.path.join(output_path_video, 'chunk_videos')
182 | if not os.path.exists(output_path_chunk_videos):
183 | print('path {} does not exist.'.format(output_path_chunk_videos))
184 | else:
185 |
186 | txt_metadata = glob.glob(os.path.join(metadata_path, id_index, video_id, '*.txt'))
187 | txt_metadata.sort()
188 |
189 | ############################################################
190 | ### Frame extraction ###
191 | ############################################################
192 | videos_tmp = glob.glob(os.path.join(output_path_chunk_videos, '*.mp4'))
193 | videos_tmp.sort()
194 | extracted_frames_path = os.path.join(output_path_video, 'frames')
195 | if len(videos_tmp) > 0:
196 | extract_frames_opencv(videos_tmp, REF_FPS, extracted_frames_path)
197 | else:
198 | print('No videos in {}'.format(output_path_video))
199 | count += 1
200 | continue
201 |
202 |
203 | ############################################################
204 | ### Preprocessing ###
205 | ############################################################
206 | image_files = glob.glob(os.path.join(extracted_frames_path, '*.png'))
207 | image_files.sort()
208 | if len(image_files) > 0:
209 | save_dir = os.path.join(output_path_video, 'frames_cropped')
210 | make_path(save_dir)
211 | preprocess_frames(dataset, output_path_video, extracted_frames_path, image_files, save_dir, txt_metadata, landmark_est)
212 | else:
213 | print('No frames in {}'.format(extracted_frames_path))
214 |
215 | # Delete all chunk videos
216 | if delete_videos:
217 |                         command_delete = 'rm -rf {}'.format(os.path.join(output_path_chunk_videos, '*.mp4'))
218 | os.system(command_delete)
219 | # Delete original frames
220 | if delete_or_frames:
221 |                         command_delete = 'rm -rf {}'.format(extracted_frames_path)  # 'frames' folder created by extract_frames_opencv
222 | os.system(command_delete)
223 | ################################################
224 | count += 1
225 | print('*********************************************************')
226 |
--------------------------------------------------------------------------------
/libs/fan_model/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8 | "3x3 convolution with padding"
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10 | stride=strd, padding=padding, bias=bias)
11 |
12 |
13 | class ConvBlock(nn.Module):
14 | def __init__(self, in_planes, out_planes):
15 | super(ConvBlock, self).__init__()
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22 |
23 | if in_planes != out_planes:
24 | self.downsample = nn.Sequential(
25 | nn.BatchNorm2d(in_planes),
26 | nn.ReLU(True),
27 | nn.Conv2d(in_planes, out_planes,
28 | kernel_size=1, stride=1, bias=False),
29 | )
30 | else:
31 | self.downsample = None
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out1 = self.bn1(x)
37 | out1 = F.relu(out1, True)
38 | out1 = self.conv1(out1)
39 |
40 | out2 = self.bn2(out1)
41 | out2 = F.relu(out2, True)
42 | out2 = self.conv2(out2)
43 |
44 | out3 = self.bn3(out2)
45 | out3 = F.relu(out3, True)
46 | out3 = self.conv3(out3)
47 |
48 | out3 = torch.cat((out1, out2, out3), 1)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(residual)
52 |
53 | out3 += residual
54 |
55 | return out3
56 |
57 |
58 | class Bottleneck(nn.Module):
59 |
60 | expansion = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67 | padding=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(planes)
69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70 | self.bn3 = nn.BatchNorm2d(planes * 4)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class HourGlass(nn.Module):
99 | def __init__(self, num_modules, depth, num_features):
100 | super(HourGlass, self).__init__()
101 | self.num_modules = num_modules
102 | self.depth = depth
103 | self.features = num_features
104 |
105 | self._generate_network(self.depth)
106 |
107 | def _generate_network(self, level):
108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109 |
110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111 |
112 | if level > 1:
113 | self._generate_network(level - 1)
114 | else:
115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116 |
117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118 |
119 | def _forward(self, level, inp):
120 | # Upper branch
121 | up1 = inp
122 | up1 = self._modules['b1_' + str(level)](up1)
123 |
124 | # Lower branch
125 | low1 = F.avg_pool2d(inp, 2, stride=2)
126 | low1 = self._modules['b2_' + str(level)](low1)
127 |
128 | if level > 1:
129 | low2 = self._forward(level - 1, low1)
130 | else:
131 | low2 = low1
132 | low2 = self._modules['b2_plus_' + str(level)](low2)
133 |
134 | low3 = low2
135 | low3 = self._modules['b3_' + str(level)](low3)
136 |
137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138 |
139 | return up1 + up2
140 |
141 | def forward(self, x):
142 | return self._forward(self.depth, x)
143 |
144 |
145 | class FAN(nn.Module):
146 |
147 | def __init__(self, num_modules=1):
148 | super(FAN, self).__init__()
149 | self.num_modules = num_modules
150 |
151 | # Base part
152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153 | self.bn1 = nn.BatchNorm2d(64)
154 | self.conv2 = ConvBlock(64, 128)
155 | self.conv3 = ConvBlock(128, 128)
156 | self.conv4 = ConvBlock(128, 256)
157 |
158 | # Stacking part
159 | for hg_module in range(self.num_modules):
160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162 | self.add_module('conv_last' + str(hg_module),
163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165 | self.add_module('l' + str(hg_module), nn.Conv2d(256,
166 | 68, kernel_size=1, stride=1, padding=0))
167 |
168 | if hg_module < self.num_modules - 1:
169 | self.add_module(
170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171 | self.add_module('al' + str(hg_module), nn.Conv2d(68,
172 | 256, kernel_size=1, stride=1, padding=0))
173 |
174 | def forward(self, x):
175 | x = F.relu(self.bn1(self.conv1(x)), True)
176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177 | x = self.conv3(x)
178 | x = self.conv4(x)
179 |
180 | previous = x
181 |
182 | outputs = []
183 | for i in range(self.num_modules):
184 | hg = self._modules['m' + str(i)](previous)
185 |
186 | ll = hg
187 | ll = self._modules['top_m_' + str(i)](ll)
188 |
189 | ll = F.relu(self._modules['bn_end' + str(i)]
190 | (self._modules['conv_last' + str(i)](ll)), True)
191 |
192 | # Predict heatmaps
193 | tmp_out = self._modules['l' + str(i)](ll)
194 | outputs.append(tmp_out)
195 |
196 | if i < self.num_modules - 1:
197 | ll = self._modules['bl' + str(i)](ll)
198 | tmp_out_ = self._modules['al' + str(i)](tmp_out)
199 | previous = previous + ll + tmp_out_
200 |
201 | # x.register_hook(lambda grad: print('images',grad))
202 | return outputs
203 |
204 |
205 | class ResNetDepth(nn.Module):
206 |
207 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
208 | self.inplanes = 64
209 | super(ResNetDepth, self).__init__()
210 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
211 | bias=False)
212 | self.bn1 = nn.BatchNorm2d(64)
213 | self.relu = nn.ReLU(inplace=True)
214 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
215 | self.layer1 = self._make_layer(block, 64, layers[0])
216 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
217 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
218 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
219 | self.avgpool = nn.AvgPool2d(7)
220 | self.fc = nn.Linear(512 * block.expansion, num_classes)
221 |
222 | for m in self.modules():
223 | if isinstance(m, nn.Conv2d):
224 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
225 | m.weight.data.normal_(0, math.sqrt(2. / n))
226 | elif isinstance(m, nn.BatchNorm2d):
227 | m.weight.data.fill_(1)
228 | m.bias.data.zero_()
229 |
230 | def _make_layer(self, block, planes, blocks, stride=1):
231 | downsample = None
232 | if stride != 1 or self.inplanes != planes * block.expansion:
233 | downsample = nn.Sequential(
234 | nn.Conv2d(self.inplanes, planes * block.expansion,
235 | kernel_size=1, stride=stride, bias=False),
236 | nn.BatchNorm2d(planes * block.expansion),
237 | )
238 |
239 | layers = []
240 | layers.append(block(self.inplanes, planes, stride, downsample))
241 | self.inplanes = planes * block.expansion
242 | for i in range(1, blocks):
243 | layers.append(block(self.inplanes, planes))
244 |
245 | return nn.Sequential(*layers)
246 |
247 | def forward(self, x):
248 | # print(x.shape)
249 | # x.register_hook(lambda grad: print('images',grad))
250 | x = self.conv1(x)
251 | x = self.bn1(x)
252 | x = self.relu(x)
253 | x = self.maxpool(x)
254 |
255 | x = self.layer1(x)
256 | x = self.layer2(x)
257 | x = self.layer3(x)
258 | x = self.layer4(x)
259 |
260 | x = self.avgpool(x)
261 | x = x.view(x.size(0), -1)
262 | x = self.fc(x)
263 |
264 |
265 | return x
266 |
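For orientation, a minimal usage sketch of the FAN model defined above; the 256x256 input size and the 64x64, 68-channel heatmap outputs follow directly from the layer definitions, while `num_modules=4` and the import path are assumptions rather than values taken from the repository.

import torch
from libs.fan_model.models import FAN  # adjust the import path to your setup

net = FAN(num_modules=4)   # 4 stacked hourglasses is the common FAN configuration
net.eval()

x = torch.rand(1, 3, 256, 256)          # a batch of face crops in [0, 1]
with torch.no_grad():
    heatmaps = net(x)                   # one tensor per hourglass module

# Each entry has shape (batch, 68 landmarks, 64, 64); the last one is normally used.
print(len(heatmaps), heatmaps[-1].shape)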
--------------------------------------------------------------------------------
/libs/fan_model/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import time
5 | import torch
6 | import math
7 | import numpy as np
8 | import cv2
9 |
10 | import torchvision.transforms as transforms
11 |
12 |
13 | def _gaussian(
14 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
15 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
16 | mean_vert=0.5):
17 | # handle some defaults
18 | if width is None:
19 | width = size
20 | if height is None:
21 | height = size
22 | if sigma_horz is None:
23 | sigma_horz = sigma
24 | if sigma_vert is None:
25 | sigma_vert = sigma
26 | center_x = mean_horz * width + 0.5
27 | center_y = mean_vert * height + 0.5
28 | gauss = np.empty((height, width), dtype=np.float32)
29 | # generate kernel
30 | for i in range(height):
31 | for j in range(width):
32 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
33 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
34 | if normalize:
35 | gauss = gauss / np.sum(gauss)
36 | return gauss
37 |
38 |
39 | def draw_gaussian(image, point, sigma):
40 | # print(type(image))
41 | #
42 | # Check if the gaussian is inside
43 | ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
44 | br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
45 | if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
46 | return image
47 | size = 6 * sigma + 1
48 | g = _gaussian(size)
49 | # print(type(g))
50 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
51 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
52 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
53 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
54 | assert (g_x[0] > 0 and g_y[1] > 0)
55 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
56 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
57 | image[image > 1] = 1
58 |
59 | # quit()
60 | return image
61 |
62 |
63 | def transform(point, center, scale, resolution, invert=False):
64 | """Generate and affine transformation matrix.
65 |
66 | Given a set of points, a center, a scale and a target resolution, the
67 | function generates an affine transformation matrix. If invert is ``True``
68 | it will produce the inverse transformation.
69 |
70 | Arguments:
71 | point {torch.tensor} -- the input 2D point
72 | center {torch.tensor or numpy.array} -- the center around which to perform the transformations
73 | scale {float} -- the scale of the face/object
74 | resolution {float} -- the output resolution
75 |
76 | Keyword Arguments:
77 | invert {bool} -- defines whether the function should produce the direct or the
78 | inverse transformation matrix (default: {False})
79 | """
80 |
81 | _pt = torch.ones(3)
82 | _pt[0] = point[0]
83 | _pt[1] = point[1]
84 |
85 | h = 200.0 * scale
86 | t = torch.eye(3)
87 | t[0, 0] = resolution / h
88 | t[1, 1] = resolution / h
89 | t[0, 2] = resolution * (-center[0] / h + 0.5)
90 | t[1, 2] = resolution * (-center[1] / h + 0.5)
91 |
92 | if invert:
93 | t = torch.inverse(t)
94 |
95 | new_point = (torch.matmul(t, _pt))[0:2]
96 |
97 | return new_point.int()
98 |
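# Worked example (illustrative): with center = (128, 128), scale = 1.28 and
# resolution = 64, h = 200 * 1.28 = 256, so the matrix scales by 64 / 256 = 0.25
# with zero translation, and the point (200, 200) maps to (50, 50). Passing
# invert=True applies the inverse mapping (scale 4) from heatmap coordinates
# back to image coordinates.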
99 |
100 | def crop(image, center, scale, resolution=256.0):
101 | """Center crops an image or set of heatmaps
102 |
103 | Arguments:
104 | image {numpy.array} -- an rgb image
105 | center {numpy.array} -- the center of the object, usually the same as of the bounding box
106 | scale {float} -- scale of the face
107 |
108 | Keyword Arguments:
109 | resolution {float} -- the size of the output cropped image (default: {256.0})
110 |
111 | Returns:
112 | numpy.array -- the cropped image, resized to (resolution, resolution)
113 | """
114 | # Crop around the center point; the input is expected to be an np.ndarray
115 | ul = transform([1, 1], center, scale, resolution, True)
116 | br = transform([resolution, resolution], center, scale, resolution, True)
117 | # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
118 | if image.ndim > 2:
119 | newDim = np.array([br[1] - ul[1], br[0] - ul[0],
120 | image.shape[2]], dtype=np.int32)
121 | newImg = np.zeros(newDim, dtype=np.uint8)
122 | else:
123 | newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
124 | newImg = np.zeros(newDim, dtype=np.uint8)
125 |
126 | ht = image.shape[0]
127 | wd = image.shape[1]
128 | newX = np.array(
129 | [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
130 | newY = np.array(
131 | [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
132 | oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
133 | oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
134 | newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
135 | ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
136 | newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
137 | interpolation=cv2.INTER_LINEAR)
138 | return newImg
139 |
140 | def crop_torch(image, center, scale, resolution = 256.0):
141 |
142 | # print(type(image))
143 | l1 = transform([1, 1], center, scale, resolution, True)
144 | l2 = transform([resolution, resolution], center, scale, resolution, True)
145 |
146 | newImg = torch.zeros((image.shape[0],image.shape[1], l2[1] - l1[1], l2[0] - l1[0]))
147 | height, width = image.shape[2],image.shape[3]
148 |
149 |
150 | newX = torch.Tensor([max(1, -l1[0] + 1), min(l2[0], width) - l1[0]])
151 | newY = torch.Tensor([max(1, -l1[1] + 1), min(l2[1], height) - l1[1]])
152 | oldX = torch.Tensor([max(1, l1[0] + 1), min(l2[0], width)])
153 | oldY = torch.Tensor([max(1, l1[1] + 1), min(l2[1], height)])
154 |
155 | newImg[:,:,int(newY[0].data.item()) - 1:int(newY[1].data.item()),
156 | int(newX[0].data.item())- 1:int(newX[1].data.item())] = image[:,:,int(oldY[0].data.item()) - 1:int(oldY[1].data.item()),
157 | int(oldX[0].data.item()) - 1:int(oldX[1].data.item())]
158 |
159 | # newImg = newImg.resize(resolution,resolution)
160 | # newImg = torch.nn.functional.interpolate(newImg, size=resolution)
161 |
162 | transformations = transforms.Resize((int(resolution), int(resolution)))
163 | newImg = transformations(newImg)
164 |
165 | return newImg
166 |
167 |
168 | def get_preds_fromhm(hm, center=None, scale=None):
169 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center
170 | and the scale are provided, the function will also return the points in
171 | the original coordinate frame.
172 |
173 | Arguments:
174 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
175 |
176 | Keyword Arguments:
177 | center {torch.tensor} -- the center of the bounding box (default: {None})
178 | scale {float} -- face scale (default: {None})
179 | """
180 | max, idx = torch.max(
181 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
182 | idx += 1
183 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
184 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
185 | preds[..., 1].add_(-1).div_(hm.size(2)).add_(1) # note: this previously applied .floor() here
186 |
187 | for i in range(preds.size(0)):
188 | for j in range(preds.size(1)):
189 | hm_ = hm[i, j, :]
190 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
191 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
192 | diff = torch.FloatTensor(
193 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
194 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
195 | preds[i, j].add_(diff.sign_().mul_(.25))
196 |
197 | preds.add_(-.5)
198 |
199 | preds_orig = torch.zeros(preds.size())
200 | if center is not None and scale is not None:
201 | for i in range(hm.size(0)):
202 | for j in range(hm.size(1)):
203 | preds_orig[i, j] = transform(
204 | preds[i, j], center, scale, hm.size(2), True)
205 |
206 | return preds, preds_orig
207 |
208 |
209 | def create_target_heatmap(target_landmarks, centers, scales):
210 | # print(type(target_landmarks))
211 | # print(centers,scales)
212 | # quit()
213 | heatmaps = np.zeros((target_landmarks.shape[0], 68, 64, 64), dtype=np.float32)
214 | for i in range(heatmaps.shape[0]):
215 | for p in range(68):
216 | landmark_cropped_coor = transform(target_landmarks[i, p] + 1, centers[i], scales[i], 64, invert=False)
217 | heatmaps[i, p] = draw_gaussian(heatmaps[i, p], landmark_cropped_coor + 1, 1)
218 | return torch.tensor(heatmaps)
219 |
220 |
221 | def create_bounding_box(target_landmarks, expansion_factor=0.0):
222 | """
223 | Takes a batch of landmarks and, for each set of landmarks in the batch, calculates a bounding box
224 | that encloses all of its points.
225 | :param target_landmarks: batch of landmarks of dim (n x 68 x 2), where n is the batch size
226 | :param expansion_factor: expands the bounding box by this factor. For example, an `expansion_factor` of 0.2 leads
227 | to 20% increase in width and height of the boxes
228 | :return: a batch of bounding boxes of dim (n x 4) where the second dim is (x1,y1,x2,y2)
229 | """
230 | # Calc bounding box
231 | x_y_min, _ = target_landmarks.reshape(-1, 68, 2).min(dim=1)
232 | x_y_max, _ = target_landmarks.reshape(-1, 68, 2).max(dim=1)
233 | # expanding the bounding box
234 | expansion_factor /= 2
235 | bb_expansion_x = (x_y_max[:, 0] - x_y_min[:, 0]) * expansion_factor
236 | bb_expansion_y = (x_y_max[:, 1] - x_y_min[:, 1]) * expansion_factor
237 | x_y_min[:, 0] -= bb_expansion_x
238 | x_y_max[:, 0] += bb_expansion_x
239 | x_y_min[:, 1] -= bb_expansion_y
240 | x_y_max[:, 1] += bb_expansion_y
241 | return torch.cat([x_y_min, x_y_max], dim=1)
242 |
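# Worked example (illustrative): landmarks whose x coordinates span [10, 50] and
# y coordinates span [20, 60] give the tight box (10, 20, 50, 60); with
# expansion_factor=0.2 each side is pushed out by 10% of the width/height,
# yielding (6, 16, 54, 64).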
243 |
244 | def shuffle_lr(parts, pairs=None):
245 | """Shuffle the points left-right according to the axis of symmetry
246 | of the object.
247 |
248 | Arguments:
249 | parts {torch.tensor} -- a 3D or 4D object containing the
250 | heatmaps.
251 |
252 | Keyword Arguments:
253 | pairs {list of integers} -- [order of the flipped points] (default: {None})
254 | """
255 | if pairs is None:
256 | pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
257 | 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
258 | 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
259 | 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
260 | 62, 61, 60, 67, 66, 65]
261 | if parts.ndimension() == 3:
262 | parts = parts[pairs, ...]
263 | else:
264 | parts = parts[:, pairs, ...]
265 |
266 | return parts
267 |
268 |
269 | def flip(tensor, is_label=False):
270 | """Flip an image or a set of heatmaps left-right
271 |
272 | Arguments:
273 | tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
274 |
275 | Keyword Arguments:
276 | is_label {bool} -- [denotes whether the input is an image or a set of heatmaps] (default: {False})
277 | """
278 | if not torch.is_tensor(tensor):
279 | tensor = torch.from_numpy(tensor)
280 |
281 | if is_label:
282 | tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
283 | else:
284 | tensor = tensor.flip(tensor.ndimension() - 1)
285 |
286 | return tensor
287 |
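# Example (illustrative): a common use of flip/shuffle_lr is horizontal-flip
# test-time augmentation: run the network on flip(img), map the predicted
# heatmaps back with flip(hm, is_label=True), and average them with the
# prediction on the original image.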
288 | # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
289 |
290 |
291 | def appdata_dir(appname=None, roaming=False):
292 | """ appdata_dir(appname=None, roaming=False)
293 |
294 | Get the path to the application directory, where applications are allowed
295 | to write user specific files (e.g. configurations). For non-user specific
296 | data, consider using common_appdata_dir().
297 | If appname is given, a subdir is appended (and created if necessary).
298 | If roaming is True, will prefer a roaming directory (Windows Vista/7).
299 | """
300 |
301 | # Define default user directory
302 | userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
303 | if userDir is None:
304 | userDir = os.path.expanduser('~')
305 | if not os.path.isdir(userDir): # pragma: no cover
306 | userDir = '/var/tmp' # issue #54
307 |
308 | # Get system app data dir
309 | path = None
310 | if sys.platform.startswith('win'):
311 | path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
312 | path = (path2 or path1) if roaming else (path1 or path2)
313 | elif sys.platform.startswith('darwin'):
314 | path = os.path.join(userDir, 'Library', 'Application Support')
315 | # On Linux and as fallback
316 | if not (path and os.path.isdir(path)):
317 | path = userDir
318 |
319 | # Maybe we should store things local to the executable (in case of a
320 | # portable distro or a frozen application that wants to be portable)
321 | prefix = sys.prefix
322 | if getattr(sys, 'frozen', None):
323 | prefix = os.path.abspath(os.path.dirname(sys.executable))
324 | for reldir in ('settings', '../settings'):
325 | localpath = os.path.abspath(os.path.join(prefix, reldir))
326 | if os.path.isdir(localpath): # pragma: no cover
327 | try:
328 | open(os.path.join(localpath, 'test.write'), 'wb').close()
329 | os.remove(os.path.join(localpath, 'test.write'))
330 | except IOError:
331 | pass # We cannot write in this directory
332 | else:
333 | path = localpath
334 | break
335 |
336 | # Get path specific for this app
337 | if appname:
338 | if path == userDir:
339 | appname = '.' + appname.lstrip('.') # Make it a hidden directory
340 | path = os.path.join(path, appname)
341 | if not os.path.isdir(path): # pragma: no cover
342 | os.mkdir(path)
343 |
344 | # Done
345 | return path
346 |
347 | def show_landmarks(image, heatmap, pred_landmarks, gt_landmarks=None):
348 |     """Show the image overlaid with the predicted (and optionally the ground-truth) landmarks."""
349 |     import matplotlib.pyplot as plt
350 | 
351 |     image = image.astype(np.uint8)
352 |     # Collapse the per-landmark heatmaps into a single map and rescale it to the image size
353 |     heatmap = np.max(heatmap, axis=0)
354 |     heatmap = heatmap / np.max(heatmap)
355 |     heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
356 |     heatmap = (heatmap * 255).astype(np.uint8)
357 |     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # colour-mapped heatmap, kept for optional overlay
358 | 
359 |     plt.imshow(image)
360 |     if gt_landmarks is not None:
361 |         plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker='.', c='g')
362 |     plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker='.', c='r')
363 |     plt.pause(0.001)  # pause a bit so that plots are updated
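Taken together, `crop`, the FAN model and `get_preds_fromhm` form the usual landmark-inference path. Below is a minimal, self-contained sketch of that path; the detector box, the scale heuristic (dividing by the S3FD reference scale of 195) and the import paths are illustrative assumptions, not the repository's exact pipeline.

import numpy as np
import torch

from libs.fan_model.models import FAN                    # adjust paths to your setup
from libs.fan_model.utils import crop, get_preds_fromhm

image = np.zeros((480, 640, 3), dtype=np.uint8)          # stand-in video frame
x1, y1, x2, y2 = 200, 120, 400, 360                      # stand-in face-detector box

center = np.array([(x1 + x2) / 2.0, (y1 + y2) / 2.0])
scale = (x2 - x1 + y2 - y1) / 195.0                      # 195: detector reference scale (assumed heuristic)

inp = crop(image, center, scale)                         # 256x256x3 crop around the face
inp = torch.from_numpy(inp.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0

fan = FAN(num_modules=4)                                 # untrained here; load a checkpoint in practice
fan.eval()
with torch.no_grad():
    heatmaps = fan(inp)[-1]                              # (1, 68, 64, 64)

# preds is in 64x64 heatmap coordinates, preds_orig in original image coordinates
preds, preds_orig = get_preds_fromhm(heatmaps, center, scale)
landmarks = preds_orig.squeeze(0).numpy()                # (68, 2)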
--------------------------------------------------------------------------------