├── .gitignore ├── libs ├── sfd │ ├── __init__.py │ ├── sfd_detector.py │ ├── detect.py │ ├── bbox.py │ ├── net_s3fd.py │ └── core.py ├── ffhq_cropping.py ├── utilities.py ├── landmarks_estimation.py └── fan_model │ ├── models.py │ └── utils.py ├── images └── example.png ├── requirements.txt ├── README.md ├── download_voxCeleb.py └── preprocess_voxCeleb.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | # Models 3 | *.pth 4 | 5 | -------------------------------------------------------------------------------- /libs/sfd/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfd_detector import SFDDetector as FaceDetector -------------------------------------------------------------------------------- /images/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StelaBou/voxceleb_preprocessing/HEAD/images/example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | opencv-python 5 | scikit-image 6 | tqdm 7 | matplotlib 8 | scipy>=0.17.0 9 | numba 10 | yt-dlp 11 | -------------------------------------------------------------------------------- /libs/sfd/sfd_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from torch.utils.model_zoo import load_url 4 | import sys 5 | import matplotlib.pyplot as plt 6 | from .core import FaceDetector 7 | 8 | from .net_s3fd import s3fd 9 | from .bbox import * 10 | from .detect import * 11 | import torch.backends.cudnn as cudnn 12 | 13 | 14 | models_urls = { 15 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', 16 | } 17 | 18 | 19 | class SFDDetector(FaceDetector): 20 | def __init__(self, device, path_to_detector=None, verbose=False): 21 | super(SFDDetector, self).__init__(device, verbose) 22 | 23 | self.device = device 24 | model_weights = torch.load(path_to_detector) 25 | 26 | self.face_detector = s3fd() 27 | self.face_detector.load_state_dict(model_weights) 28 | self.face_detector.to(self.device) 29 | self.face_detector.eval() 30 | 31 | def detect_from_batch(self, tensor): 32 | 33 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device) 34 | 35 | new_bboxlists = [] 36 | for i in range(bboxlists.shape[0]): 37 | bboxlist = bboxlists[i] 38 | keep = nms(bboxlist, 0.3) 39 | # print(keep) 40 | if len(keep)>0: 41 | bboxlist = bboxlist[keep, :] 42 | bboxlist = [x for x in bboxlist if x[-1] > 0.5] 43 | new_bboxlists.append(bboxlist) 44 | else: 45 | new_bboxlists.append([]) 46 | 47 | return new_bboxlists 48 | 49 | @property 50 | def reference_scale(self): 51 | return 195 52 | 53 | @property 54 | def reference_x_shift(self): 55 | return 0 56 | 57 | @property 58 | def reference_y_shift(self): 59 | return 0 60 | -------------------------------------------------------------------------------- /libs/ffhq_cropping.py: -------------------------------------------------------------------------------- 1 | """ 2 | Crop images using facial landmarks 3 | """ 4 | import numpy as np 5 | import cv2 6 | import os 7 | import collections 8 | import PIL.Image 9 | import PIL.ImageFile 10 | from PIL import Image 11 | import scipy.ndimage 12 | 13 | def pad_img_to_fit_bbox(img, x1, x2, y1, y2, crop_box): 14 | 
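	# Reflect-pad the image so that the requested crop box (x1, x2, y1, y2) fits inside it,
	# then blur and fade the padded border toward the median colour (FFHQ-style padding).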
img_or = img.copy() 15 | img = cv2.copyMakeBorder(img, 16 | -min(0, y1), max(y2 - img.shape[0], 0), 17 | -min(0, x1), max(x2 - img.shape[1], 0), cv2.BORDER_REFLECT) 18 | 19 | y2 += -min(0, y1) 20 | y1 += -min(0, y1) 21 | x2 += -min(0, x1) 22 | x1 += -min(0, x1) 23 | 24 | pad = crop_box 25 | pad = (max(-pad[0], 0), max(-pad[1], 0), max(pad[2] - img_or.shape[1] , 0), max(pad[3] - img_or.shape[0] , 0)) 26 | 27 | h, w, _ = img.shape 28 | y, x, _ = np.ogrid[:h, :w, :1] 29 | pad = np.array(pad, dtype=np.float32) 30 | pad[pad == 0] = 1e-10 31 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 32 | img = np.array(img, dtype=np.float32) 33 | blur = 5.0 34 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 35 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 36 | 37 | return img, x1, x2, y1, y2 38 | 39 | def crop_from_bbox(img, bbox): 40 | """ 41 | bbox: tuple, (x1, y1, x2, y2) 42 | x: horizontal, y: vertical, exclusive 43 | """ 44 | x1, y1, x2, y2 = bbox 45 | if x1 < 0 or y1 < 0 or x2 > img.shape[1] or y2 > img.shape[0]: 46 | img, x1, x2, y1, y2 = pad_img_to_fit_bbox(img, x1, x2, y1, y2, bbox) 47 | return img[y1:y2, x1:x2] 48 | 49 | def crop_using_landmarks(image, landmarks): 50 | image_size = 256 51 | center = ((landmarks.min(0) + landmarks.max(0)) / 2).round().astype(int) 52 | size = int(max(landmarks[:, 0].max() - landmarks[:, 0].min(), landmarks[:, 1].max() - landmarks[:, 1].min())) 53 | try: 54 | center[1] -= size // 6 55 | except: 56 | return None 57 | 58 | # Crop images and poses 59 | h, w, _ = image.shape 60 | img = Image.fromarray(image) 61 | crop_box = (center[0]-size, center[1]-size, center[0]+size, center[1]+size) 62 | image = crop_from_bbox(image, crop_box) 63 | try: 64 | img = Image.fromarray(image.astype(np.uint8)) 65 | img = img.resize((image_size, image_size), Image.BICUBIC) 66 | pix = np.array(img) 67 | return pix 68 | except: 69 | return None 70 | -------------------------------------------------------------------------------- /libs/sfd/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import datetime 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | import scipy.io as sio 14 | import zipfile 15 | from .net_s3fd import s3fd 16 | from .bbox import * 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | def detect(net, img, device): 21 | img = img - np.array([104, 117, 123]) 22 | img = img.transpose(2, 0, 1) 23 | # Creates a batch of 1 24 | img = img.reshape((1,) + img.shape) 25 | 26 | 27 | if torch.cuda.current_device() == 0: 28 | torch.backends.cudnn.benchmark = True 29 | 30 | img = torch.from_numpy(img).float().to(device) 31 | 32 | 33 | return batch_detect(net, img, device) 34 | 35 | 36 | def batch_detect(net, img_batch, device): 37 | """ 38 | Inputs: 39 | - img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width) 40 | """ 41 | 42 | BB, CC, HH, WW = img_batch.size() 43 | 44 | with torch.no_grad(): 45 | olist = net(img_batch.float()) # patched uint8_t overflow error 46 | 47 | 48 | for i in range(len(olist) // 2): 49 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 50 | 51 | bboxlists = [] 52 | 53 | olist = [oelem.data.cpu() for oelem in olist] 54 | for j in range(BB): 55 | bboxlist = [] 56 | for i in 
range(len(olist) // 2): 57 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 58 | FB, FC, FH, FW = ocls.size() # feature map size 59 | stride = 2**(i + 2) # 4,8,16,32,64,128 60 | anchor = stride * 4 61 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 62 | 63 | for Iindex, hindex, windex in poss: 64 | 65 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 66 | score = ocls[j, 1, hindex, windex] 67 | loc = oreg[j, :, hindex, windex].contiguous().view(1, 4) 68 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) 69 | variances = [0.1, 0.2] 70 | box = decode(loc, priors, variances) 71 | x1, y1, x2, y2 = box[0] * 1.0 72 | bboxlist.append([x1, y1, x2, y2, score]) 73 | bboxlists.append(bboxlist) 74 | 75 | bboxlists = np.array(bboxlists) 76 | 77 | if 0 == len(bboxlists): 78 | bboxlists = np.zeros((1, 1, 5)) 79 | 80 | 81 | return bboxlists 82 | 83 | 84 | def flip_detect(net, img, device): 85 | img = cv2.flip(img, 1) 86 | b = detect(net, img, device) 87 | 88 | bboxlist = np.zeros(b.shape) 89 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 90 | bboxlist[:, 1] = b[:, 1] 91 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 92 | bboxlist[:, 3] = b[:, 3] 93 | bboxlist[:, 4] = b[:, 4] 94 | return bboxlist 95 | 96 | 97 | def pts_to_bb(pts): 98 | min_x, min_y = np.min(pts, axis=0) 99 | max_x, max_y = np.max(pts, axis=0) 100 | return np.array([min_x, min_y, max_x, max_y]) 101 | -------------------------------------------------------------------------------- /libs/utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | def make_path(path): 6 | if not os.path.exists(path): 7 | os.makedirs(path, exist_ok= True) 8 | 9 | def _parse_metadata_file(metadata, dataset = 'vox1', frame = None): 10 | 11 | if dataset == 'vox1': 12 | metadata_info = [] 13 | for path in metadata: 14 | frames = [] 15 | bboxes = [] 16 | with open(path) as f: 17 | segment_desc = f.read().strip() 18 | 19 | header, lines = segment_desc.split("\n\n") 20 | lines = lines.split("\n") 21 | 22 | for i in range(1,len(lines)): 23 | 24 | info = lines[i].split("\t") 25 | frames.append(int(info[0])) 26 | x1 = int(info[1]) 27 | y1 = int(info[2]) 28 | x2 = int(info[1]) + int(info[3]) 29 | y2 = int(info[2]) + int(info[4]) 30 | bboxes.append([x1, y1, x2, y2]) 31 | 32 | info = { 33 | 'frames': frames, 34 | 'bboxes': bboxes 35 | } 36 | metadata_info.append(info) 37 | elif dataset == 'vox2': 38 | metadata_info = [] 39 | for path in metadata: 40 | frames = [] 41 | bboxes = [] 42 | with open(path) as f: 43 | segment_desc = f.read().strip() 44 | 45 | header, lines = segment_desc.split("\n\n") 46 | lines = lines.split("\n") 47 | 48 | for i in range(1,len(lines)): 49 | 50 | info = lines[i].split("\t") 51 | frames.append(int(info[0])) 52 | x1 = int(float(info[1]) * frame.shape[1]) 53 | y1 = int(float(info[2]) * frame.shape[0]) 54 | x2 = int(float(info[3]) * frame.shape[1]) + x1 55 | y2 = int(float(info[4]) * frame.shape[0]) + y1 56 | bboxes.append([x1, y1, x2, y2]) 57 | 58 | info = { 59 | 'frames': frames, 60 | 'bboxes': bboxes 61 | } 62 | metadata_info.append(info) 63 | 64 | else: 65 | print('Specify correct dataset') 66 | exit() 67 | 68 | return metadata_info 69 | 70 | def crop_box(image, bbox, scale_crop = 1.0): 71 | 72 | h_im, w_im, c = image.shape 73 | 74 | y1_hat = bbox[1] 75 | y2_hat = bbox[3] 76 | x1_hat = bbox[0] 77 | x2_hat = bbox[2] 78 | 79 | new_w = x2_hat - x1_hat 80 | w = x2_hat - x1_hat 81 | h = y2_hat - y1_hat 82 | cx = 
int(x1_hat + w/2)
83 | 	cy = int(y1_hat + h/2)
84 | 
85 | 	w_hat = int(w * scale_crop)
86 | 	h_hat = int(h * scale_crop)
87 | 	x1_hat = cx - int(w_hat/2)
88 | 	if x1_hat < 0:
89 | 		x1_hat = 0
90 | 	y1_hat = cy - int(h_hat/2)
91 | 	if y1_hat < 0:
92 | 		y1_hat = 0
93 | 	x2_hat = x1_hat + w_hat
94 | 	y2_hat = y1_hat + h_hat
95 | 
96 | 	if x2_hat > w_im:
97 | 		x2_hat = w_im
98 | 	if y2_hat > h_im:
99 | 		y2_hat = h_im
100 | 
101 | 	if (y2_hat - y1_hat) > 20 and (x2_hat - x1_hat) > 20:
102 | 		crop = image[y1_hat:y2_hat, x1_hat:x2_hat, :]
103 | 	else:
104 | 		crop = image
105 | 
106 | 	bbox_scaled = [x1_hat, y1_hat, x2_hat, y2_hat]
107 | 	return crop, bbox_scaled
108 | 
109 | 
110 | # Read image from path
111 | def read_image_opencv(image_path):
112 | 	img = cv2.imread(image_path, cv2.IMREAD_COLOR) # BGR order
113 | 	img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
114 | 	return img.astype('uint8')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Download and Preprocessing scripts for VoxCeleb datasets
2 | 
3 | This is an auxiliary repo for downloading VoxCeleb videos and preprocessing the extracted frames by cropping them around the face. For detecting and cropping the face area we use the landmark estimation method proposed in [1], [face-alignment](https://github.com/1adrianb/face-alignment).
4 | 
5 | 

6 | <p align="center"><img src="images/example.png" /></p>
7 | 
8 | 
9 | ## Installation
10 | 
11 | 
12 | * Python 3.5+
13 | * Linux
14 | * PyTorch (>= 1.5)
15 | 
16 | ### Install requirements:
17 | ```
18 | pip install -r requirements.txt
19 | ```
20 | 
21 | ### Install yt-dlp (the download script imports `yt_dlp`):
22 | ```
23 | pip install --upgrade yt-dlp
24 | ```
25 | 
26 | ### Install ffmpeg
27 | ```
28 | sudo apt-get install ffmpeg
29 | ```
30 | 
31 | ### Download auxiliary models and save them under `./pretrained_models`
32 | 
33 | | Path | Description
34 | | :--- | :----------
35 | |[FaceDetector](https://drive.google.com/file/d/1IWqJUTAZCelAZrUzfU38zK_ZM25fK32S/view?usp=sharing) | SFD face detector for [face-alignment](https://github.com/1adrianb/face-alignment).
36 | 
37 | ## Overview
38 | 
39 | * Download videos of the VoxCeleb1 or VoxCeleb2 dataset from YouTube
40 | * Split each video into smaller chunks using the metadata provided by the datasets and delete the original videos
41 | * Extract frames from each chunk video (one frame every REF_FPS = 25 frames)
42 | * Crop frames using the face boxes from the metadata and facial landmarks
43 | * Files are saved as:
44 | ```
45 | path/to/voxdataset
46 | |-- id10271 # identity index
47 | | |-- 37nktPRUJ58 # video index
48 | | | |-- chunk_videos # chunk_videos: original video split into smaller chunks
49 | | | | |-- 37nktPRUJ58#00001#257-396.mp4
50 | | | | |-- ...
51 | | | |-- frames # extracted frames
52 | | | | |-- 00_000025.png
53 | | | | |-- ...
54 | | | |-- frames_cropped # preprocessed frames
55 | | | | |-- 00_000025.png
56 | | | | |-- ...
57 | | |-- Zjc7Xy7aT8c
58 | | | | ...
59 | |-- id10273
60 | | | ...
61 | ```
62 | 
63 | ## Download VoxCeleb datasets
64 | 
65 | 1. Download the metadata from [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) and [VoxCeleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html)
66 | 
67 | ```
68 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox1_test_txt.zip
69 | unzip vox1_test_txt.zip
70 | mv ./txt ./vox1_txt_test
71 | 
72 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox1_dev_txt.zip
73 | unzip vox1_dev_txt.zip
74 | mv ./txt ./vox1_txt_train
75 | 
76 | ```
77 | 
78 | ```
79 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox2_test_txt.zip
80 | unzip vox2_test_txt.zip
81 | mv ./txt ./vox2_txt_test
82 | 
83 | wget www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox2_dev_txt.zip
84 | unzip vox2_dev_txt.zip
85 | mv ./txt ./vox2_txt_train
86 | 
87 | ```
88 | 
89 | 2. Run this script to download the videos from YouTube and split them into chunks. With `--delete_mp4` the original downloaded videos are removed after splitting. Optionally extract and preprocess frames with `--extract_frames` and `--preprocessing`.
90 | 
91 | ```
92 | python download_voxCeleb.py --dataset vox1 --output_path ./VoxCeleb1_test --metadata_path ./vox1_txt_test --delete_mp4
93 | ```
94 | 
95 | ## Preprocessing of video frames
96 | 
97 | 
98 | 1. If the chunk videos have already been generated, run this script to extract and preprocess frames.
99 | 
100 | ```
101 | python preprocess_voxCeleb.py --dataset vox1 --root_path ./VoxCeleb1_test --metadata_path ./vox1_txt_test
102 | ```
103 | ## Acknowledgments
104 | 
105 | This code borrows from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing) and [face-alignment](https://github.com/1adrianb/face-alignment).
106 | 
107 | ## References
108 | 
109 | [1] Bulat, Adrian, and Georgios Tzimiropoulos. "How far are we from solving the 2D & 3D face alignment problem? (and a dataset of 230,000 3D facial landmarks)." Proceedings of the IEEE International Conference on Computer Vision. 2017.
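
## Example: cropping a single frame

For a quick sanity check outside the full scripts, the snippet below shows how the repo's landmark estimation (`libs/landmarks_estimation.py`) and FFHQ-style cropping (`libs/ffhq_cropping.py`) can be applied to one extracted frame. This is a minimal sketch: the frame path and output filename are made-up examples, a CUDA device is assumed (as in the preprocessing scripts), and a face is assumed to be detected with high confidence. The full pipeline in `preprocess_voxCeleb.py` additionally crops around the VoxCeleb metadata boxes before running the landmark-based crop.

```
import torch
import cv2

from libs.utilities import read_image_opencv
from libs.landmarks_estimation import LandmarksEstimation
from libs.ffhq_cropping import crop_using_landmarks

# Hypothetical paths -- point these at one of your extracted frames and the downloaded SFD weights.
frame_path = './VoxCeleb1_test/id10271/37nktPRUJ58/frames/00_000025.png'
detector_path = './pretrained_models/s3fd-619a316812.pth'

landmark_est = LandmarksEstimation(type='2D', path_to_detector=detector_path)

image = read_image_opencv(frame_path)                                 # RGB uint8 array, (H, W, 3)
image_tensor = torch.tensor(image.transpose(2, 0, 1)).float().cuda()  # (3, H, W), CUDA assumed

with torch.no_grad():
    landmarks = landmark_est.detect_landmarks(image_tensor)           # (1, 68, 2) with the 2D model

cropped = crop_using_landmarks(image, landmarks[0].detach().cpu().numpy())
if cropped is not None:  # None is returned when the crop cannot be built
    cv2.imwrite('example_cropped.png', cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR))
```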
110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /libs/sfd/bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import cv2 5 | import random 6 | import datetime 7 | import time 8 | import math 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | try: 14 | from iou import IOU 15 | except BaseException: 16 | # IOU cython speedup 10x 17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): 18 | sa = abs((ax2 - ax1) * (ay2 - ay1)) 19 | sb = abs((bx2 - bx1) * (by2 - by1)) 20 | x1, y1 = max(ax1, bx1), max(ay1, by1) 21 | x2, y2 = min(ax2, bx2), min(ay2, by2) 22 | w = x2 - x1 23 | h = y2 - y1 24 | if w < 0 or h < 0: 25 | return 0.0 26 | else: 27 | return 1.0 * w * h / (sa + sb - w * h) 28 | 29 | 30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): 31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh 33 | dw, dh = math.log(ww / aww), math.log(hh / ahh) 34 | return dx, dy, dw, dh 35 | 36 | 37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): 38 | xc, yc = dx * aww + axc, dy * ahh + ayc 39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh 40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 41 | return x1, y1, x2, y2 42 | 43 | 44 | def nms(dets, thresh): 45 | # print(dets) 46 | if 0 == len(dets): 47 | return [] 48 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] 49 | # print(x1,x2,y1,y2) 50 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 51 | order = scores.argsort()[::-1] 52 | 53 | keep = [] 54 | while order.size > 0: 55 | i = order[0] 56 | keep.append(i) 57 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) 58 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) 59 | 60 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) 61 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h) 62 | 63 | inds = np.where(ovr <= thresh)[0] 64 | order = order[inds + 1] 65 | 66 | return keep 67 | 68 | 69 | def encode(matched, priors, variances): 70 | """Encode the variances from the priorbox layers into the ground truth boxes 71 | we have matched (based on jaccard overlap) with the prior boxes. 72 | Args: 73 | matched: (tensor) Coords of ground truth for each prior in point-form 74 | Shape: [num_priors, 4]. 75 | priors: (tensor) Prior boxes in center-offset form 76 | Shape: [num_priors,4]. 77 | variances: (list[float]) Variances of priorboxes 78 | Return: 79 | encoded boxes (tensor), Shape: [num_priors, 4] 80 | """ 81 | 82 | # dist b/t match center and prior's center 83 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 84 | # encode variance 85 | g_cxcy /= (variances[0] * priors[:, 2:]) 86 | # match wh / prior wh 87 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 88 | g_wh = torch.log(g_wh) / variances[1] 89 | # return target for smooth_l1_loss 90 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 91 | 92 | 93 | def decode(loc, priors, variances): 94 | """Decode locations from predictions using priors to undo 95 | the encoding we did for offset regression at train time. 96 | Args: 97 | loc (tensor): location predictions for loc layers, 98 | Shape: [num_priors,4] 99 | priors (tensor): Prior boxes in center-offset form. 100 | Shape: [num_priors,4]. 
101 | variances: (list[float]) Variances of priorboxes 102 | Return: 103 | decoded bounding box predictions 104 | """ 105 | 106 | boxes = torch.cat(( 107 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 108 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 109 | boxes[:, :2] -= boxes[:, 2:] / 2 110 | boxes[:, 2:] += boxes[:, :2] 111 | return boxes 112 | -------------------------------------------------------------------------------- /libs/sfd/net_s3fd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Norm(nn.Module): 7 | def __init__(self, n_channels, scale=1.0): 8 | super(L2Norm, self).__init__() 9 | self.n_channels = n_channels 10 | self.scale = scale 11 | self.eps = 1e-10 12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 13 | self.weight.data *= 0.0 14 | self.weight.data += self.scale 15 | 16 | def forward(self, x): 17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 18 | x = x / norm * self.weight.view(1, -1, 1, 1) 19 | return x 20 | 21 | 22 | class s3fd(nn.Module): 23 | def __init__(self): 24 | super(s3fd, self).__init__() 25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 27 | 28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 30 | 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | 35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 42 | 43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) 44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) 45 | 46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 48 | 49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 51 | 52 | self.conv3_3_norm = L2Norm(256, scale=10) 53 | self.conv4_3_norm = L2Norm(512, scale=8) 54 | self.conv5_3_norm = L2Norm(512, scale=5) 55 | 56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 62 | 63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) 64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) 65 | 
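		# Each detection scale has a *_mbox_conf (classification) head and a *_mbox_loc
		# (4-value box regression) head; conv3_3's conf head outputs 4 channels so the
		# background scores can be max-out combined in forward().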
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) 68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 69 | 70 | def forward(self, x): 71 | h = F.relu(self.conv1_1(x)) 72 | h = F.relu(self.conv1_2(h)) 73 | h = F.max_pool2d(h, 2, 2) 74 | 75 | h = F.relu(self.conv2_1(h)) 76 | h = F.relu(self.conv2_2(h)) 77 | h = F.max_pool2d(h, 2, 2) 78 | 79 | h = F.relu(self.conv3_1(h)) 80 | h = F.relu(self.conv3_2(h)) 81 | h = F.relu(self.conv3_3(h)) 82 | f3_3 = h 83 | h = F.max_pool2d(h, 2, 2) 84 | 85 | h = F.relu(self.conv4_1(h)) 86 | h = F.relu(self.conv4_2(h)) 87 | h = F.relu(self.conv4_3(h)) 88 | f4_3 = h 89 | h = F.max_pool2d(h, 2, 2) 90 | 91 | h = F.relu(self.conv5_1(h)) 92 | h = F.relu(self.conv5_2(h)) 93 | h = F.relu(self.conv5_3(h)) 94 | f5_3 = h 95 | h = F.max_pool2d(h, 2, 2) 96 | 97 | h = F.relu(self.fc6(h)) 98 | h = F.relu(self.fc7(h)) 99 | ffc7 = h 100 | h = F.relu(self.conv6_1(h)) 101 | h = F.relu(self.conv6_2(h)) 102 | f6_2 = h 103 | h = F.relu(self.conv7_1(h)) 104 | h = F.relu(self.conv7_2(h)) 105 | f7_2 = h 106 | 107 | f3_3 = self.conv3_3_norm(f3_3) 108 | f4_3 = self.conv4_3_norm(f4_3) 109 | f5_3 = self.conv5_3_norm(f5_3) 110 | 111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3) 112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3) 113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3) 114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3) 115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3) 116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3) 117 | cls4 = self.fc7_mbox_conf(ffc7) 118 | reg4 = self.fc7_mbox_loc(ffc7) 119 | cls5 = self.conv6_2_mbox_conf(f6_2) 120 | reg5 = self.conv6_2_mbox_loc(f6_2) 121 | cls6 = self.conv7_2_mbox_conf(f7_2) 122 | reg6 = self.conv7_2_mbox_loc(f7_2) 123 | 124 | # max-out background label 125 | chunk = torch.chunk(cls1, 4, 1) 126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) 127 | cls1 = torch.cat([bmax, chunk[3]], dim=1) 128 | 129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] 130 | -------------------------------------------------------------------------------- /libs/sfd/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from skimage import io 8 | 9 | 10 | class FaceDetector(object): 11 | """An abstract class representing a face detector. 12 | 13 | Any other face detection implementation must subclass it. All subclasses 14 | must implement ``detect_from_image``, that return a list of detected 15 | bounding boxes. Optionally, for speed considerations detect from path is 16 | recommended. 17 | """ 18 | 19 | def __init__(self, device, verbose): 20 | self.device = device 21 | self.verbose = verbose 22 | 23 | # if verbose: 24 | # if 'cpu' in device: 25 | # logger = logging.getLogger(__name__) 26 | # logger.warning("Detection running on CPU, this may be potentially slow.") 27 | 28 | # if 'cpu' not in device and 'cuda' not in device: 29 | # if verbose: 30 | # logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) 31 | # raise ValueError 32 | 33 | def detect_from_image(self, tensor_or_path): 34 | """Detects faces in a given image. 35 | 36 | This function detects the faces present in a provided BGR(usually) 37 | image. 
The input can be either the image itself or the path to it. 38 | 39 | Arguments: 40 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path 41 | to an image or the image itself. 42 | 43 | Example:: 44 | 45 | >>> path_to_image = 'data/image_01.jpg' 46 | ... detected_faces = detect_from_image(path_to_image) 47 | [A list of bounding boxes (x1, y1, x2, y2)] 48 | >>> image = cv2.imread(path_to_image) 49 | ... detected_faces = detect_from_image(image) 50 | [A list of bounding boxes (x1, y1, x2, y2)] 51 | 52 | """ 53 | raise NotImplementedError 54 | 55 | def detect_from_batch(self, tensor): 56 | """Detects faces in a given image. 57 | 58 | This function detects the faces present in a provided BGR(usually) 59 | image. The input can be either the image itself or the path to it. 60 | 61 | Arguments: 62 | tensor {torch.tensor} -- image batch tensor. 63 | 64 | Example:: 65 | 66 | >>> path_to_image = 'data/image_01.jpg' 67 | ... detected_faces = detect_from_image(path_to_image) 68 | [A list of bounding boxes (x1, y1, x2, y2)] 69 | >>> image = cv2.imread(path_to_image) 70 | ... detected_faces = detect_from_image(image) 71 | [A list of bounding boxes (x1, y1, x2, y2)] 72 | 73 | """ 74 | raise NotImplementedError 75 | 76 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): 77 | """Detects faces from all the images present in a given directory. 78 | 79 | Arguments: 80 | path {string} -- a string containing a path that points to the folder containing the images 81 | 82 | Keyword Arguments: 83 | extensions {list} -- list of string containing the extensions to be 84 | consider in the following format: ``.extension_name`` (default: 85 | {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the 86 | folder recursively (default: {False}) show_progress_bar {bool} -- 87 | display a progressbar (default: {True}) 88 | 89 | Example: 90 | >>> directory = 'data' 91 | ... detected_faces = detect_from_directory(directory) 92 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} 93 | 94 | """ 95 | if self.verbose: 96 | logger = logging.getLogger(__name__) 97 | 98 | if len(extensions) == 0: 99 | if self.verbose: 100 | logger.error("Expected at list one extension, but none was received.") 101 | raise ValueError 102 | 103 | if self.verbose: 104 | logger.info("Constructing the list of images.") 105 | additional_pattern = '/**/*' if recursive else '/*' 106 | files = [] 107 | for extension in extensions: 108 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) 109 | 110 | if self.verbose: 111 | logger.info("Finished searching for images. 
%s images found", len(files)) 112 | logger.info("Preparing to run the detection.") 113 | 114 | predictions = {} 115 | for image_path in tqdm(files, disable=not show_progress_bar): 116 | if self.verbose: 117 | logger.info("Running the face detector on image: %s", image_path) 118 | predictions[image_path] = self.detect_from_image(image_path) 119 | 120 | if self.verbose: 121 | logger.info("The detector was successfully run on all %s images", len(files)) 122 | 123 | return predictions 124 | 125 | @property 126 | def reference_scale(self): 127 | raise NotImplementedError 128 | 129 | @property 130 | def reference_x_shift(self): 131 | raise NotImplementedError 132 | 133 | @property 134 | def reference_y_shift(self): 135 | raise NotImplementedError 136 | 137 | @staticmethod 138 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): 139 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray 140 | 141 | Arguments: 142 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself 143 | """ 144 | if isinstance(tensor_or_path, str): 145 | return cv2.imread(tensor_or_path) if not rgb else io.imread(tensor_or_path) 146 | elif torch.is_tensor(tensor_or_path): 147 | # Call cpu in case its coming from cuda 148 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() 149 | elif isinstance(tensor_or_path, np.ndarray): 150 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path 151 | else: 152 | raise TypeError 153 | -------------------------------------------------------------------------------- /libs/landmarks_estimation.py: -------------------------------------------------------------------------------- 1 | """ 2 | The face detector used is SFD (taken from face-alignment) 3 | https://github.com/1adrianb/face-alignment 4 | """ 5 | import os 6 | import numpy as np 7 | import cv2 8 | from enum import Enum 9 | import torch 10 | from torch.utils.model_zoo import load_url 11 | 12 | 13 | from libs.sfd.sfd_detector import SFDDetector as FaceDetector 14 | from libs.fan_model.models import FAN, ResNetDepth 15 | from libs.fan_model.utils import * 16 | 17 | models_urls = { 18 | '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-11f355bf06.pth.tar', 19 | '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-7835d9f11d.pth.tar', 20 | 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-2a464da4ea.pth.tar', 21 | } 22 | 23 | class LandmarksType(Enum): 24 | """Enum class defining the type of landmarks to detect. 25 | 26 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face 27 | ``_2halfD`` - this points represent the projection of the 3D points into 3D 28 | ``_3D`` - detect the points ``(x,y,z)``` in a 3D space 29 | 30 | """ 31 | _2D = 1 32 | _2halfD = 2 33 | _3D = 3 34 | 35 | class NetworkSize(Enum): 36 | # TINY = 1 37 | # SMALL = 2 38 | # MEDIUM = 3 39 | LARGE = 4 40 | 41 | def __new__(cls, value): 42 | member = object.__new__(cls) 43 | member._value_ = value 44 | return member 45 | 46 | def __int__(self): 47 | return self.value 48 | 49 | 50 | def get_preds_fromhm(hm, center=None, scale=None): 51 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center 52 | and the scale is provided the function will return the points also in 53 | the original coordinate frame. 
54 | 55 | Arguments: 56 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 57 | 58 | Keyword Arguments: 59 | center {torch.tensor} -- the center of the bounding box (default: {None}) 60 | scale {float} -- face scale (default: {None}) 61 | """ 62 | max, idx = torch.max( 63 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 64 | idx = idx + 1 65 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 66 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 67 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 68 | 69 | for i in range(preds.size(0)): 70 | for j in range(preds.size(1)): 71 | hm_ = hm[i, j, :] 72 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 73 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 74 | diff = torch.FloatTensor( 75 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 76 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 77 | preds[i, j].add_(diff.sign_().mul_(.25)) 78 | 79 | preds.add_(-.5) 80 | 81 | preds_orig = torch.zeros(preds.size()) 82 | if center is not None and scale is not None: 83 | for i in range(hm.size(0)): 84 | for j in range(hm.size(1)): 85 | preds_orig[i, j] = transform( 86 | preds[i, j], center, scale, hm.size(2), True) 87 | 88 | return preds, preds_orig 89 | 90 | def draw_detected_face(img, face): 91 | x_min = int(face[0]) 92 | y_min = int(face[1]) 93 | x_max = int(face[2]) 94 | y_max = int(face[3]) 95 | 96 | cv2.rectangle(img, (int(x_min),int(y_min)), (int(x_max),int(y_max)), (255,0,0), 2) 97 | 98 | return img 99 | 100 | 101 | class LandmarksEstimation(): 102 | def __init__(self, type = '3D', path_to_detector = './pretrained_models/s3fd-619a316812.pth'): 103 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 104 | # Load all needed models - Face detector and Pose detector 105 | network_size = NetworkSize.LARGE 106 | network_size = int(network_size) 107 | if type == '3D': 108 | self.landmarks_type = LandmarksType._3D 109 | else: 110 | self.landmarks_type = LandmarksType._2D 111 | self.flip_input = False 112 | 113 | #################### SFD face detection ################### 114 | if not os.path.exists(path_to_detector): 115 | print('Pretrained model of SFD face detector does not exist in {}'.format(path_to_detector)) 116 | exit() 117 | self.face_detector = FaceDetector(device=self.device, verbose=False, path_to_detector = path_to_detector) 118 | ########################################################### 119 | 120 | ################### Initialise the face alignemnt networks ################### 121 | self.face_alignment_net = FAN(network_size) 122 | if self.landmarks_type == LandmarksType._2D: # 123 | network_name = '2DFAN-' + str(network_size) 124 | else: 125 | network_name = '3DFAN-' + str(network_size) 126 | fan_weights = load_url(models_urls[network_name], map_location=lambda storage, loc: storage) 127 | self.face_alignment_net.load_state_dict(fan_weights) 128 | self.face_alignment_net.to(self.device) 129 | self.face_alignment_net.eval() 130 | ############################################################################## 131 | 132 | # Initialiase the depth prediciton network if 3D landmarks 133 | if self.landmarks_type == LandmarksType._3D: 134 | self.depth_prediciton_net = ResNetDepth() 135 | depth_weights = load_url(models_urls['depth'], map_location=lambda storage, loc: storage) 136 | depth_dict = { 137 | k.replace('module.', ''): v for k, 138 | v in depth_weights['state_dict'].items()} 139 | self.depth_prediciton_net.load_state_dict(depth_dict) 140 | 
self.depth_prediciton_net.to(self.device) 141 | self.depth_prediciton_net.eval() 142 | 143 | def get_landmarks(self, face, image): 144 | 145 | center = torch.FloatTensor( 146 | [(face[2] + face[0]) / 2.0, 147 | (face[3] + face[1]) / 2.0]) 148 | 149 | center[1] = center[1] - (face[3] - face[1]) * 0.12 150 | scale = (face[2] - face[0] + face[3] - face[1]) / self.face_detector.reference_scale 151 | 152 | inp = crop_torch(image, center, scale).float().cuda() 153 | inp = inp.div(255.0) 154 | 155 | out = self.face_alignment_net(inp)[-1] 156 | 157 | if self.flip_input: 158 | out = out + flip(self.face_alignment_net(flip(inp)) 159 | [-1], is_label=True) 160 | out = out.cpu() 161 | 162 | pts, pts_img = get_preds_fromhm(out, center, scale) 163 | out = out.cuda() 164 | # Added 3D landmark support 165 | if self.landmarks_type == LandmarksType._3D: 166 | pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2) 167 | heatmaps = torch.zeros((68,256,256), dtype=torch.float32) 168 | for i in range(68): 169 | if pts[i, 0] > 0: 170 | heatmaps[i] = draw_gaussian( 171 | heatmaps[i], pts[i], 2) 172 | 173 | heatmaps = heatmaps.unsqueeze(0) 174 | 175 | heatmaps = heatmaps.to(self.device) 176 | depth_pred = self.depth_prediciton_net( 177 | torch.cat((inp, heatmaps), 1)).view(68, 1) 178 | 179 | pts_img = pts_img.cuda() 180 | pts_img = torch.cat( 181 | (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1) 182 | else: 183 | pts, pts_img = pts.view(-1, 68, 2) * 4, pts_img.view(-1, 68, 2) 184 | 185 | return pts_img, out 186 | 187 | def detect_landmarks(self, image): 188 | 189 | if len(image.shape) == 3: 190 | image = image.unsqueeze(0) 191 | 192 | if self.device == 'cuda': 193 | image = image.cuda() 194 | 195 | with torch.no_grad(): 196 | detected_faces = self.face_detector.detect_from_batch(image) 197 | 198 | if self.landmarks_type == LandmarksType._3D: 199 | landmarks = torch.empty((1, 68, 3)) 200 | else: 201 | landmarks = torch.empty((1, 68, 2)) 202 | 203 | for face in detected_faces[0]: 204 | conf = face[4] 205 | if conf > 0.99: 206 | pts_img, heatmaps = self.get_landmarks(face, image) 207 | landmarks[0] = pts_img 208 | 209 | return landmarks -------------------------------------------------------------------------------- /download_voxCeleb.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import imageio 6 | import os 7 | import warnings 8 | import glob 9 | import time 10 | from tqdm import tqdm 11 | from argparse import ArgumentParser 12 | from skimage import img_as_ubyte 13 | from skimage.transform import resize 14 | warnings.filterwarnings("ignore") 15 | import cv2 16 | import yt_dlp as youtube_dl # Changed import 17 | import subprocess 18 | 19 | from libs.utilities import make_path 20 | from preprocess_voxCeleb import extract_frames_opencv, preprocess_frames 21 | from libs.landmarks_estimation import LandmarksEstimation 22 | 23 | """ 24 | 1. Download videos from youtube for VoxCeleb1 dataset 25 | 2. Generate chunk videos using the metadata provided by VoxCeleb1 dataset 26 | 27 | Optionally: 28 | 3. Extract frames from chunk videos 29 | 4. 
Preprocess extracted frames by cropping them around the detected faces 30 | 31 | Arguments: 32 | output_path: path to save the videos 33 | metadata_path: txt files from VoxCeleb 34 | dataset: dataset name: vox1 or vox2 35 | fail_video_ids: txt file to save the videos ids that fail to download 36 | extract_frames: select for frame extraction from videos 37 | preprocessing: select for frame preprocessing 38 | delete_mp4: select to delete the original video from youtube 39 | delete_or_frames: select to delete the original extracted frames 40 | 41 | python download_voxCeleb.py --output_path ./VoxCeleb1_test --metadata_path ./txt_test --dataset vox1 \ 42 | --fail_video_ids ./fail_video_ids_test.txt --delete_mp4 --extract_frames --preprocessing 43 | 44 | """ 45 | 46 | DEVNULL = open(os.devnull, 'wb') 47 | 48 | REF_FPS = 25 49 | 50 | parser = ArgumentParser() 51 | parser.add_argument("--output_path", required = True, help='Path to save the videos') 52 | parser.add_argument("--metadata_path", required = True, help='Path to metadata') 53 | parser.add_argument("--dataset", required = True, type = str, choices=('vox1', 'vox2'), help="Download vox1 or vox2 dataset") 54 | 55 | parser.add_argument("--fail_video_ids", default=None, help='Txt file to save videos that fail to download') 56 | parser.add_argument("--extract_frames", action='store_true', help='Extract frames from videos') 57 | parser.set_defaults(extract_frames=False) 58 | parser.add_argument("--preprocessing", action='store_true', help='Preprocess extracted frames') 59 | parser.set_defaults(preprocessing=False) 60 | parser.add_argument("--delete_mp4", action='store_true', help='Delete original video downloaded from youtube') 61 | parser.set_defaults(delete_mp4=False) 62 | parser.add_argument("--delete_or_frames", dest='delete_or_frames', action='store_true', help="Delete original frames and keep only the cropped frames") 63 | parser.set_defaults(delete_or_frames=False) 64 | 65 | def my_hook(d): 66 | if d['status'] == 'finished': 67 | print('Done downloading, now converting ...') 68 | 69 | def download_video(video_id, video_path, id_path, fail_video_ids = None): 70 | ydl_opts = { 71 | 'format': 'mp4', 72 | 'outtmpl': video_path, 73 | 'progress_hooks': [my_hook], 74 | 'cookies': 'cookies.txt' 75 | } 76 | success = True 77 | try: 78 | with youtube_dl.YoutubeDL(ydl_opts) as ydl: 79 | ydl.download(['https://www.youtube.com/watch?v=' + video_id]) 80 | except KeyboardInterrupt: 81 | print ('Stopped') 82 | exit() 83 | except: 84 | print('Error downloading video {}'.format(video_id)) 85 | success = False 86 | if fail_video_ids is not None: 87 | f = open(fail_video_ids, "a") 88 | f.write(id_path + '/' + video_id + '\n') 89 | f.close() 90 | return success 91 | 92 | def split_in_utterances(video_id, video_path, utterance_files, chunk_folder): 93 | chunk_videos = [] 94 | utterances = [pd.read_csv(f, sep='\t', skiprows=6) for f in utterance_files] 95 | for i, utterance in enumerate(utterances): 96 | first_frame, last_frame = utterance['FRAME '].iloc[0], utterance['FRAME '].iloc[-1] 97 | st = first_frame 98 | en = last_frame 99 | first_frame = round(first_frame / float(REF_FPS), 3) 100 | last_frame = round(last_frame / float(REF_FPS), 3) 101 | head, tail = os.path.split(utterance_files[i])#utterance_files[i]. 
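	# Each chunk is cut with ffmpeg between the utterance's first and last frame
	# (converted to seconds at REF_FPS) and named <video_id>#<txt_name>#<first>-<last>.mp4.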
102 | tail = tail.split('.tx')[0] 103 | chunk_name = os.path.join(chunk_folder, video_id + '#' + tail + '#' + str(st) + '-' + str(en) + '.mp4') 104 | 105 | command_fps = 'ffmpeg -y -i {} -qscale:v 5 -r 25 -threads 1 -ss {} -to {} -strict -2 {} -loglevel quiet'.format(video_path, first_frame, last_frame, chunk_name) 106 | os.system(command_fps) 107 | chunk_videos.append(chunk_name) 108 | 109 | return chunk_videos 110 | 111 | if __name__ == "__main__": 112 | 113 | 114 | args = parser.parse_args() 115 | extract_frames = args.extract_frames 116 | preprocessing = args.preprocessing 117 | fail_video_ids = args.fail_video_ids 118 | output_path = args.output_path 119 | make_path(output_path) 120 | delete_mp4 = args.delete_mp4 121 | delete_or_frames = args.delete_or_frames 122 | metadata_path = args.metadata_path 123 | dataset = args.dataset 124 | 125 | if not os.path.exists(metadata_path): 126 | print('Please download the metadata for {} dataset'.format(dataset)) 127 | exit() 128 | 129 | ids_path = glob.glob(os.path.join(metadata_path, '*/')) 130 | ids_path.sort() 131 | print('{} dataset has {} identities'.format(dataset, len(ids_path))) 132 | 133 | print('--Delete original mp4 videos: \t\t{}'.format(delete_mp4)) 134 | print('--Delete original frames: \t\t{}'.format(delete_or_frames)) 135 | print('--Extract frames from chunk videos: \t{}'.format(extract_frames)) 136 | print('--Preprocess original frames: \t\t{}'.format(preprocessing)) 137 | 138 | if preprocessing: 139 | landmark_est = LandmarksEstimation(type = '2D') 140 | 141 | for i, id_path in enumerate(ids_path): 142 | id_index = id_path.split('/')[-2] 143 | videos_path = glob.glob(os.path.join(id_path, '*/')) 144 | videos_path.sort() 145 | print('*********************************************************') 146 | print('Identity {}/{}: {} videos for {} identity'.format(i, len(ids_path), len(videos_path), id_index)) 147 | 148 | for j, video_path in enumerate(videos_path): 149 | 150 | print('{}/{} videos'.format(j, len(videos_path))) 151 | 152 | video_id = video_path.split('/')[-2] 153 | output_path_video = os.path.join(output_path, id_index, video_id) 154 | make_path(output_path_video) 155 | 156 | print('Download video id {}. 
Save to {}'.format(video_id, output_path_video)) 157 | 158 | txt_metadata = glob.glob(os.path.join(video_path, '*.txt')) 159 | txt_metadata.sort() 160 | 161 | mp4_path = os.path.join(output_path_video, '{}.mp4'.format(video_id)) 162 | if not os.path.exists(mp4_path): 163 | success = download_video(video_id, mp4_path, id_index, fail_video_ids = fail_video_ids) 164 | else: 165 | # Video already exists 166 | success = True 167 | 168 | if success: 169 | # Split in small videos using the metadata 170 | output_path_chunk_videos = os.path.join(output_path, id_index, video_id, 'chunk_videos') 171 | make_path(output_path_chunk_videos) 172 | chunk_videos = split_in_utterances(video_id, mp4_path, txt_metadata, output_path_chunk_videos) 173 | if delete_mp4: # Delete original video downloaded from youtube 174 | command_delete = 'rm -rf {}'.format(mp4_path) 175 | os.system(command_delete) 176 | 177 | extracted_frames_path = os.path.join(output_path_video, 'frames') 178 | if extract_frames: 179 | # Run frame extraction 180 | extract_frames_opencv(chunk_videos, REF_FPS, extracted_frames_path) 181 | if preprocessing: 182 | # Run preprocessing 183 | image_files = glob.glob(os.path.join(extracted_frames_path, '*.png')) 184 | image_files.sort() 185 | if len(image_files) > 0: 186 | save_dir = os.path.join(output_path_video, 'frames_cropped') 187 | make_path(save_dir) 188 | preprocess_frames(dataset, output_path_video, extracted_frames_path, image_files, save_dir, txt_metadata, landmark_est) 189 | else: 190 | print('There are no extracted frames on path: {}'.format(extracted_frames_path)) 191 | 192 | if delete_or_frames and len(image_files) > 0: # Delete original frames 193 | command_delete = 'rm -rf {}'.format(extracted_frames_path) 194 | os.system(command_delete) 195 | else: 196 | print('Error downloading video {}/{}. Deleting folder {}'.format(id_index, video_id, output_path_video)) 197 | command_delete = 'rm -rf {}'.format(output_path_video) 198 | os.system(command_delete) 199 | 200 | -------------------------------------------------------------------------------- /preprocess_voxCeleb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import os 4 | import glob 5 | from argparse import ArgumentParser 6 | import cv2 7 | import torch 8 | from skimage.transform import resize 9 | 10 | from libs.utilities import make_path, _parse_metadata_file, crop_box, read_image_opencv 11 | from libs.ffhq_cropping import crop_using_landmarks 12 | from libs.landmarks_estimation import LandmarksEstimation 13 | 14 | """ 15 | If chunk videos have already been generated using download_voxCeleb.py: 16 | 17 | 1. Extract frames from chunk videos 18 | 2. 
Preprocess extracted frames by cropping them around the detected faces 19 | 20 | Arguments: 21 | root_path: path where chunk videos are saved 22 | metadata_path: txt files from VoxCeleb 23 | dataset: dataset name: vox1 or vox2 24 | delete_videos: select to delete all videos 25 | delete_or_frames: select to delete the original extracted frames 26 | 27 | python preprocess_voxCeleb.py --root_path ./VoxCeleb1_test --metadata_path ./vox1_txt_test --dataset vox1 28 | 29 | """ 30 | 31 | REF_FPS = 25 # fps to extract frames 32 | REF_SIZE = 360 # Height 33 | LOW_RES_SIZE = 400 34 | 35 | 36 | parser = ArgumentParser() 37 | parser.add_argument("--root_path", default='videos', required = True, help='Path to youtube videos') 38 | parser.add_argument("--metadata_path", default='metadata', required = True, help='Path to metadata') 39 | parser.add_argument("--dataset", required = True, type = str, choices=('vox1', 'vox2'), help="Download vox1 or vox2 dataset") 40 | 41 | parser.add_argument("--delete_videos", action='store_true', help='Delete chunk videos') 42 | parser.set_defaults(delete_videos=False) 43 | parser.add_argument("--delete_or_frames", dest='delete_or_frames', action='store_true', help="Delete original frames and keep only the cropped frames") 44 | parser.set_defaults(delete_or_frames=False) 45 | 46 | 47 | def get_frames(video_path, frames_path, video_index, fps): 48 | cap = cv2.VideoCapture(video_path) 49 | counter = 0 50 | # a variable to set how many frames you want to skip 51 | frame_skip = fps 52 | while cap.isOpened(): 53 | ret, frame = cap.read() 54 | if not ret: 55 | break 56 | if counter % frame_skip == 0: 57 | cv2.imwrite(os.path.join(frames_path, '{:02d}_{:06d}.png'.format(video_index, counter)), frame) 58 | counter += 1 59 | 60 | cap.release() 61 | cv2.destroyAllWindows() 62 | 63 | 64 | def extract_frames_opencv(videos_tmp, fps, frames_path): 65 | 66 | print('1. Extract frames') 67 | make_path(frames_path) 68 | for i in tqdm(range(len(videos_tmp))): 69 | get_frames(videos_tmp[i], frames_path, i, fps) 70 | 71 | 72 | def preprocess_frames(dataset, output_path_video, frames_path, image_files, save_dir, txt_metadata, landmark_est = None): 73 | 74 | if dataset == 'vox2': 75 | image_ref = read_image_opencv(image_files[0]) 76 | mult = image_ref.shape[0] / REF_SIZE 77 | image_ref = resize(image_ref, (REF_SIZE, int(image_ref.shape[1] / mult)), preserve_range=True) 78 | else: 79 | image_ref = None 80 | 81 | info_metadata = _parse_metadata_file(txt_metadata, dataset = dataset, frame = image_ref) 82 | 83 | errors = [] 84 | chunk_id = 0 85 | frame_i = 0 86 | print('2. Preprocess frames') 87 | for i in tqdm(range(len(image_files))): 88 | 89 | # Check from which chunk video each frame is extracted. 
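		# Frames were saved every REF_FPS-th video frame (see get_frames), so extracted
		# frame number frame_i corresponds to metadata row frame_i * REF_FPS + 1 below.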
90 | # Frames are saved as chunkid_index.png 91 | image_file = image_files[i] 92 | image_name = image_file.split('/')[-1] 93 | image_chunk_id = image_name.split('.')[0] 94 | image_chunk_id = int(image_chunk_id.split('_')[0]) 95 | bbox = None 96 | if chunk_id != image_chunk_id: 97 | chunk_id += 1 98 | frame_i = 0 99 | ######################################### 100 | if chunk_id < len(info_metadata): 101 | frames = info_metadata[chunk_id]['frames'] 102 | bboxes_metadata = info_metadata[chunk_id]['bboxes'] 103 | # print('Index with chunk videos every REF_FPS frames..') 104 | index = frame_i+1 + frame_i*(REF_FPS-1) 105 | if index < len(bboxes_metadata): 106 | bbox = bboxes_metadata[index] 107 | frame = frames[index] 108 | 109 | if bbox is not None: 110 | image = read_image_opencv(image_file) 111 | frame = image.copy() 112 | (h, w) = image.shape[:2] 113 | 114 | scale_res = REF_SIZE / float(h) 115 | bbox_new = bbox.copy() 116 | bbox_new[0] = bbox_new[0] / scale_res 117 | bbox_new[1] = bbox_new[1] / scale_res 118 | bbox_new[2] = bbox_new[2] / scale_res 119 | bbox_new[3] = bbox_new[3] / scale_res 120 | 121 | cropped_image, bbox_scaled = crop_box(frame, bbox_new, scale_crop = 2.0) 122 | filename = os.path.join(save_dir, image_name) 123 | cv2.imwrite(filename, cv2.cvtColor(cropped_image.copy(), cv2.COLOR_RGB2BGR)) 124 | h, w, _ = cropped_image.shape 125 | image_tensor = torch.tensor(np.transpose(cropped_image, (2,0,1))).float().cuda() 126 | 127 | if landmark_est is not None: 128 | with torch.no_grad(): 129 | landmarks = landmark_est.detect_landmarks( image_tensor.unsqueeze(0)) 130 | landmarks = landmarks[0].detach().cpu().numpy() 131 | landmarks = np.asarray(landmarks) 132 | condition = np.any(landmarks > w) or np.any(landmarks < 0) 133 | if (condition == False) : 134 | img = crop_using_landmarks(cropped_image, landmarks) 135 | if img is not None: 136 | filename = os.path.join(save_dir, image_name) 137 | cv2.imwrite(filename, cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR)) 138 | frame_i += 1 139 | 140 | if __name__ == "__main__": 141 | 142 | 143 | args = parser.parse_args() 144 | root_path = args.root_path 145 | if not os.path.exists(root_path): 146 | print('Videos path {} does not exist'.format(root_path)) 147 | 148 | metadata_path = args.metadata_path 149 | delete_videos = args.delete_videos 150 | delete_or_frames = args.delete_or_frames 151 | dataset = args.dataset 152 | 153 | if not os.path.exists(metadata_path): 154 | print('Please download the metadata for {} dataset'.format(dataset)) 155 | exit() 156 | landmark_est = LandmarksEstimation(type = '2D') 157 | 158 | print('--Delete chunk videos: \t\t\t{}'.format(delete_videos)) 159 | print('--Delete original frames: \t\t{}'.format(delete_or_frames)) 160 | 161 | ids_path = glob.glob(os.path.join(root_path, '*/')) 162 | ids_path.sort() 163 | print('Dataset has {} identities'.format(len(ids_path))) 164 | 165 | data_csv = [] 166 | data_low_res = [] 167 | for i, id_path in enumerate(ids_path): 168 | id_index = id_path.split('/')[-2] 169 | videos_path = glob.glob(os.path.join(id_path, '*/')) 170 | videos_path.sort() 171 | print('*********************************************************') 172 | print('Identity {}/{}: {} videos for {} identity'.format(i, len(ids_path), len(videos_path), id_index)) 173 | 174 | count = 0 175 | for j, video_path in enumerate(videos_path): 176 | video_id = video_path.split('/')[-2] 177 | 178 | print('{}/{} videos'.format(j, len(videos_path))) 179 | 180 | output_path_video = os.path.join(root_path, id_index, video_id) 181 | 
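			# Chunk videos produced by download_voxCeleb.py are expected under
			# <root_path>/<identity>/<video_id>/chunk_videos.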
output_path_chunk_videos = os.path.join(output_path_video, 'chunk_videos') 182 | if not os.path.exists(output_path_chunk_videos): 183 | print('path {} does not exist.'.format(output_path_chunk_videos)) 184 | else: 185 | 186 | txt_metadata = glob.glob(os.path.join(metadata_path, id_index, video_id, '*.txt')) 187 | txt_metadata.sort() 188 | 189 | ############################################################ 190 | ### Frame extraction ### 191 | ############################################################ 192 | videos_tmp = glob.glob(os.path.join(output_path_chunk_videos, '*.mp4')) 193 | videos_tmp.sort() 194 | extracted_frames_path = os.path.join(output_path_video, 'frames') 195 | if len(videos_tmp) > 0: 196 | extract_frames_opencv(videos_tmp, REF_FPS, extracted_frames_path) 197 | else: 198 | print('No videos in {}'.format(output_path_video)) 199 | count += 1 200 | continue 201 | 202 | 203 | ############################################################ 204 | ### Preprocessing ### 205 | ############################################################ 206 | image_files = glob.glob(os.path.join(extracted_frames_path, '*.png')) 207 | image_files.sort() 208 | if len(image_files) > 0: 209 | save_dir = os.path.join(output_path_video, 'frames_cropped') 210 | make_path(save_dir) 211 | preprocess_frames(dataset, output_path_video, extracted_frames_path, image_files, save_dir, txt_metadata, landmark_est) 212 | else: 213 | print('No frames in {}'.format(extracted_frames_path)) 214 | 215 | # Delete all chunk videos 216 | if delete_videos: 217 | command_delete = 'rm -rf {}'.format(os.path.join(output_path_video, '*.mp4')) 218 | os.system(command_delete) 219 | # Delete original frames 220 | if delete_or_frames: 221 | command_delete = 'rm -rf {}'.format(os.path.join(output_path_video, frames_folder_name)) 222 | os.system(command_delete) 223 | ################################################ 224 | count += 1 225 | print('*********************************************************') 226 | -------------------------------------------------------------------------------- /libs/fan_model/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 10 | stride=strd, padding=padding, bias=bias) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes): 15 | super(ConvBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 22 | 23 | if in_planes != out_planes: 24 | self.downsample = nn.Sequential( 25 | nn.BatchNorm2d(in_planes), 26 | nn.ReLU(True), 27 | nn.Conv2d(in_planes, out_planes, 28 | kernel_size=1, stride=1, bias=False), 29 | ) 30 | else: 31 | self.downsample = None 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out1 = self.bn1(x) 37 | out1 = F.relu(out1, True) 38 | out1 = self.conv1(out1) 39 | 40 | out2 = self.bn2(out1) 41 | out2 = F.relu(out2, True) 42 | out2 = self.conv2(out2) 43 | 44 | out3 = self.bn3(out2) 45 | out3 = F.relu(out3, True) 46 | out3 = self.conv3(out3) 47 | 48 
| out3 = torch.cat((out1, out2, out3), 1) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(residual) 52 | 53 | out3 += residual 54 | 55 | return out3 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class HourGlass(nn.Module): 99 | def __init__(self, num_modules, depth, num_features): 100 | super(HourGlass, self).__init__() 101 | self.num_modules = num_modules 102 | self.depth = depth 103 | self.features = num_features 104 | 105 | self._generate_network(self.depth) 106 | 107 | def _generate_network(self, level): 108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) 109 | 110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) 111 | 112 | if level > 1: 113 | self._generate_network(level - 1) 114 | else: 115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) 116 | 117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) 118 | 119 | def _forward(self, level, inp): 120 | # Upper branch 121 | up1 = inp 122 | up1 = self._modules['b1_' + str(level)](up1) 123 | 124 | # Lower branch 125 | low1 = F.avg_pool2d(inp, 2, stride=2) 126 | low1 = self._modules['b2_' + str(level)](low1) 127 | 128 | if level > 1: 129 | low2 = self._forward(level - 1, low1) 130 | else: 131 | low2 = low1 132 | low2 = self._modules['b2_plus_' + str(level)](low2) 133 | 134 | low3 = low2 135 | low3 = self._modules['b3_' + str(level)](low3) 136 | 137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest') 138 | 139 | return up1 + up2 140 | 141 | def forward(self, x): 142 | return self._forward(self.depth, x) 143 | 144 | 145 | class FAN(nn.Module): 146 | 147 | def __init__(self, num_modules=1): 148 | super(FAN, self).__init__() 149 | self.num_modules = num_modules 150 | 151 | # Base part 152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 153 | self.bn1 = nn.BatchNorm2d(64) 154 | self.conv2 = ConvBlock(64, 128) 155 | self.conv3 = ConvBlock(128, 128) 156 | self.conv4 = ConvBlock(128, 256) 157 | 158 | # Stacking part 159 | for hg_module in range(self.num_modules): 160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) 161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 162 | self.add_module('conv_last' + str(hg_module), 163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 165 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 166 | 68, 
kernel_size=1, stride=1, padding=0)) 167 | 168 | if hg_module < self.num_modules - 1: 169 | self.add_module( 170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 171 | self.add_module('al' + str(hg_module), nn.Conv2d(68, 172 | 256, kernel_size=1, stride=1, padding=0)) 173 | 174 | def forward(self, x): 175 | x = F.relu(self.bn1(self.conv1(x)), True) 176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 177 | x = self.conv3(x) 178 | x = self.conv4(x) 179 | 180 | previous = x 181 | 182 | outputs = [] 183 | for i in range(self.num_modules): 184 | hg = self._modules['m' + str(i)](previous) 185 | 186 | ll = hg 187 | ll = self._modules['top_m_' + str(i)](ll) 188 | 189 | ll = F.relu(self._modules['bn_end' + str(i)] 190 | (self._modules['conv_last' + str(i)](ll)), True) 191 | 192 | # Predict heatmaps 193 | tmp_out = self._modules['l' + str(i)](ll) 194 | outputs.append(tmp_out) 195 | 196 | if i < self.num_modules - 1: 197 | ll = self._modules['bl' + str(i)](ll) 198 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 199 | previous = previous + ll + tmp_out_ 200 | 201 | # x.register_hook(lambda grad: print('images',grad)) 202 | return outputs 203 | 204 | 205 | class ResNetDepth(nn.Module): 206 | 207 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): 208 | self.inplanes = 64 209 | super(ResNetDepth, self).__init__() 210 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, 211 | bias=False) 212 | self.bn1 = nn.BatchNorm2d(64) 213 | self.relu = nn.ReLU(inplace=True) 214 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 215 | self.layer1 = self._make_layer(block, 64, layers[0]) 216 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 217 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 218 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 219 | self.avgpool = nn.AvgPool2d(7) 220 | self.fc = nn.Linear(512 * block.expansion, num_classes) 221 | 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 225 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 226 | elif isinstance(m, nn.BatchNorm2d): 227 | m.weight.data.fill_(1) 228 | m.bias.data.zero_() 229 | 230 | def _make_layer(self, block, planes, blocks, stride=1): 231 | downsample = None 232 | if stride != 1 or self.inplanes != planes * block.expansion: 233 | downsample = nn.Sequential( 234 | nn.Conv2d(self.inplanes, planes * block.expansion, 235 | kernel_size=1, stride=stride, bias=False), 236 | nn.BatchNorm2d(planes * block.expansion), 237 | ) 238 | 239 | layers = [] 240 | layers.append(block(self.inplanes, planes, stride, downsample)) 241 | self.inplanes = planes * block.expansion 242 | for i in range(1, blocks): 243 | layers.append(block(self.inplanes, planes)) 244 | 245 | return nn.Sequential(*layers) 246 | 247 | def forward(self, x): 248 | # print(x.shape) 249 | # x.register_hook(lambda grad: print('images',grad)) 250 | x = self.conv1(x) 251 | x = self.bn1(x) 252 | x = self.relu(x) 253 | x = self.maxpool(x) 254 | 255 | x = self.layer1(x) 256 | x = self.layer2(x) 257 | x = self.layer3(x) 258 | x = self.layer4(x) 259 | 260 | x = self.avgpool(x) 261 | x = x.view(x.size(0), -1) 262 | x = self.fc(x) 263 | 264 | 265 | return x 266 | -------------------------------------------------------------------------------- /libs/fan_model/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import time 5 | import torch 6 | import math 7 | import numpy as np 8 | import cv2 9 | 10 | import torchvision.transforms as transforms 11 | 12 | 13 | def _gaussian( 14 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None, 15 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, 16 | mean_vert=0.5): 17 | # handle some defaults 18 | if width is None: 19 | width = size 20 | if height is None: 21 | height = size 22 | if sigma_horz is None: 23 | sigma_horz = sigma 24 | if sigma_vert is None: 25 | sigma_vert = sigma 26 | center_x = mean_horz * width + 0.5 27 | center_y = mean_vert * height + 0.5 28 | gauss = np.empty((height, width), dtype=np.float32) 29 | # generate kernel 30 | for i in range(height): 31 | for j in range(width): 32 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( 33 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) 34 | if normalize: 35 | gauss = gauss / np.sum(gauss) 36 | return gauss 37 | 38 | 39 | def draw_gaussian(image, point, sigma): 40 | # print(type(image)) 41 | # 42 | # Check if the gaussian is inside 43 | ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] 44 | br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] 45 | if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): 46 | return image 47 | size = 6 * sigma + 1 48 | g = _gaussian(size) 49 | # print(type(g)) 50 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] 51 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] 52 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] 53 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] 54 | assert (g_x[0] > 0 and g_y[1] > 0) 55 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] 56 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] 57 | image[image > 1] = 1 58 | 59 | # quit() 60 | return image 61 | 62 | 63 | def transform(point, center, scale, 
resolution, invert=False):
 64 |     """Generate an affine transformation matrix.
 65 | 
 66 |     Given a set of points, a center, a scale and a target resolution, the
 67 |     function generates an affine transformation matrix. If invert is ``True``
 68 |     it will produce the inverse transformation.
 69 | 
 70 |     Arguments:
 71 |         point {torch.tensor} -- the input 2D point
 72 |         center {torch.tensor or numpy.array} -- the center around which to perform the transformations
 73 |         scale {float} -- the scale of the face/object
 74 |         resolution {float} -- the output resolution
 75 | 
 76 |     Keyword Arguments:
 77 |         invert {bool} -- defines whether the function should produce the direct or the
 78 |         inverse transformation matrix (default: {False})
 79 |     """
 80 | 
 81 |     _pt = torch.ones(3)
 82 |     _pt[0] = point[0]
 83 |     _pt[1] = point[1]
 84 | 
 85 |     h = 200.0 * scale
 86 |     t = torch.eye(3)
 87 |     t[0, 0] = resolution / h
 88 |     t[1, 1] = resolution / h
 89 |     t[0, 2] = resolution * (-center[0] / h + 0.5)
 90 |     t[1, 2] = resolution * (-center[1] / h + 0.5)
 91 | 
 92 |     if invert:
 93 |         t = torch.inverse(t)
 94 | 
 95 |     new_point = (torch.matmul(t, _pt))[0:2]
 96 | 
 97 |     return new_point.int()
 98 | 
 99 | 
100 | def crop(image, center, scale, resolution=256.0):
101 |     """Center crops an image or set of heatmaps
102 | 
103 |     Arguments:
104 |         image {numpy.array} -- an rgb image
105 |         center {numpy.array} -- the center of the object, usually the same as of the bounding box
106 |         scale {float} -- scale of the face
107 | 
108 |     Keyword Arguments:
109 |         resolution {float} -- the size of the output cropped image (default: {256.0})
110 | 
111 |     Returns:
112 |         numpy.array -- the cropped image, resized to (resolution, resolution)
113 |     """  # Crop around the center point
114 |     """ Crops the image around the center. Input is expected to be an np.ndarray """
115 |     ul = transform([1, 1], center, scale, resolution, True)
116 |     br = transform([resolution, resolution], center, scale, resolution, True)
117 |     # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
118 |     if image.ndim > 2:
119 |         newDim = np.array([br[1] - ul[1], br[0] - ul[0],
120 |                            image.shape[2]], dtype=np.int32)
121 |         newImg = np.zeros(newDim, dtype=np.uint8)
122 |     else:
123 |         newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
124 |         newImg = np.zeros(newDim, dtype=np.uint8)
125 | 
126 |     ht = image.shape[0]
127 |     wd = image.shape[1]
128 |     newX = np.array(
129 |         [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
130 |     newY = np.array(
131 |         [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
132 |     oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
133 |     oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
134 |     newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
135 |            ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
136 |     newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
137 |                         interpolation=cv2.INTER_LINEAR)
138 |     return newImg
139 | 
140 | def crop_torch(image, center, scale, resolution = 256.0):
141 | 
142 |     # print(type(image))
143 |     l1 = transform([1, 1], center, scale, resolution, True)
144 |     l2 = transform([resolution, resolution], center, scale, resolution, True)
145 | 
146 |     newImg = torch.zeros((image.shape[0],image.shape[1], l2[1] - l1[1], l2[0] - l1[0]))
147 |     height, width = image.shape[2],image.shape[3]
148 | 
149 | 
150 |     newX = torch.Tensor([max(1, -l1[0] + 1), min(l2[0], width) - l1[0]])
151 |     newY = torch.Tensor([max(1, -l1[1] + 1), min(l2[1], height) - l1[1]])
152 |     oldX = torch.Tensor([max(1, l1[0] + 1), min(l2[0], width)])
153 |     oldY = torch.Tensor([max(1, l1[1] + 1), min(l2[1], height)])
154 | 
155 |     newImg[:,:,int(newY[0].data.item()) - 1:int(newY[1].data.item()),
156 |            int(newX[0].data.item())- 1:int(newX[1].data.item())] = image[:,:,int(oldY[0].data.item()) - 1:int(oldY[1].data.item()),
157 |            int(oldX[0].data.item()) - 1:int(oldX[1].data.item())]
158 | 
159 |     # newImg = newImg.resize(resolution,resolution)
160 |     # newImg = torch.nn.functional.interpolate(newImg, size=resolution)
161 | 
162 |     transformations = transforms.Resize((256,256))
163 |     newImg = transformations(newImg)
164 | 
165 |     return newImg
166 | 
167 | 
168 | def get_preds_fromhm(hm, center=None, scale=None):
169 |     """Obtain (x,y) coordinates given a set of N heatmaps. If the center
170 |     and the scale are provided the function will return the points also in
171 |     the original coordinate frame.
172 | 
173 |     Arguments:
174 |         hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
175 | 
176 |     Keyword Arguments:
177 |         center {torch.tensor} -- the center of the bounding box (default: {None})
178 |         scale {float} -- face scale (default: {None})
179 |     """
180 |     max, idx = torch.max(
181 |         hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
182 |     idx += 1
183 |     preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
184 |     preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
185 |     preds[..., 1].add_(-1).div_(hm.size(2)).add_(1)  # previously used .floor() here
186 | 
187 |     for i in range(preds.size(0)):
188 |         for j in range(preds.size(1)):
189 |             hm_ = hm[i, j, :]
190 |             pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
191 |             if pX > 0 and pX < 63 and pY > 0 and pY < 63:
192 |                 diff = torch.FloatTensor(
193 |                     [hm_[pY, pX + 1] - hm_[pY, pX - 1],
194 |                      hm_[pY + 1, pX] - hm_[pY - 1, pX]])
195 |                 preds[i, j].add_(diff.sign_().mul_(.25))
196 | 
197 |     preds.add_(-.5)
198 | 
199 |     preds_orig = torch.zeros(preds.size())
200 |     if center is not None and scale is not None:
201 |         for i in range(hm.size(0)):
202 |             for j in range(hm.size(1)):
203 |                 preds_orig[i, j] = transform(
204 |                     preds[i, j], center, scale, hm.size(2), True)
205 | 
206 |     return preds, preds_orig
207 | 
208 | 
209 | def create_target_heatmap(target_landmarks, centers, scales):
210 |     # print(type(target_landmarks))
211 |     # print(centers,scales)
212 |     # quit()
213 |     heatmaps = np.zeros((target_landmarks.shape[0], 68, 64, 64), dtype=np.float32)
214 |     for i in range(heatmaps.shape[0]):
215 |         for p in range(68):
216 |             landmark_cropped_coor = transform(target_landmarks[i, p] + 1, centers[i], scales[i], 64, invert=False)
217 |             heatmaps[i, p] = draw_gaussian(heatmaps[i, p], landmark_cropped_coor + 1, 1)
218 |     return torch.tensor(heatmaps)
219 | 
220 | 
221 | def create_bounding_box(target_landmarks, expansion_factor=0.0):
222 |     """
223 |     gets a batch of landmarks and calculates a bounding box that includes all the landmarks per set of landmarks in
224 |     the batch
225 |     :param target_landmarks: batch of landmarks of dim (n x 68 x 2). Where n is the batch size
226 |     :param expansion_factor: expands the bounding box by this factor.
For example, an `expansion_factor` of 0.2 leads
227 |     to a 20% increase in width and height of the boxes
228 |     :return: a batch of bounding boxes of dim (n x 4) where the second dim is (x1,y1,x2,y2)
229 |     """
230 |     # Calc bounding box
231 |     x_y_min, _ = target_landmarks.reshape(-1, 68, 2).min(dim=1)
232 |     x_y_max, _ = target_landmarks.reshape(-1, 68, 2).max(dim=1)
233 |     # expanding the bounding box
234 |     expansion_factor /= 2
235 |     bb_expansion_x = (x_y_max[:, 0] - x_y_min[:, 0]) * expansion_factor
236 |     bb_expansion_y = (x_y_max[:, 1] - x_y_min[:, 1]) * expansion_factor
237 |     x_y_min[:, 0] -= bb_expansion_x
238 |     x_y_max[:, 0] += bb_expansion_x
239 |     x_y_min[:, 1] -= bb_expansion_y
240 |     x_y_max[:, 1] += bb_expansion_y
241 |     return torch.cat([x_y_min, x_y_max], dim=1)
242 | 
243 | 
244 | def shuffle_lr(parts, pairs=None):
245 |     """Shuffle the points left-right according to the axis of symmetry
246 |     of the object.
247 | 
248 |     Arguments:
249 |         parts {torch.tensor} -- a 3D or 4D object containing the
250 |         heatmaps.
251 | 
252 |     Keyword Arguments:
253 |         pairs {list of integers} -- [order of the flipped points] (default: {None})
254 |     """
255 |     if pairs is None:
256 |         pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
257 |                  26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
258 |                  34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
259 |                  40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
260 |                  62, 61, 60, 67, 66, 65]
261 |     if parts.ndimension() == 3:
262 |         parts = parts[pairs, ...]
263 |     else:
264 |         parts = parts[:, pairs, ...]
265 | 
266 |     return parts
267 | 
268 | 
269 | def flip(tensor, is_label=False):
270 |     """Flip an image or a set of heatmaps left-right
271 | 
272 |     Arguments:
273 |         tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
274 | 
275 |     Keyword Arguments:
276 |         is_label {bool} -- [denotes whether the input is an image or a set of heatmaps] (default: {False})
277 |     """
278 |     if not torch.is_tensor(tensor):
279 |         tensor = torch.from_numpy(tensor)
280 | 
281 |     if is_label:
282 |         tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
283 |     else:
284 |         tensor = tensor.flip(tensor.ndimension() - 1)
285 | 
286 |     return tensor
287 | 
288 | # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
289 | 
290 | 
291 | def appdata_dir(appname=None, roaming=False):
292 |     """ appdata_dir(appname=None, roaming=False)
293 | 
294 |     Get the path to the application directory, where applications are allowed
295 |     to write user specific files (e.g. configurations). For non-user specific
296 |     data, consider using common_appdata_dir().
297 |     If appname is given, a subdir is appended (and created if necessary).
298 |     If roaming is True, will prefer a roaming directory (Windows Vista/7).
299 |     """
300 | 
301 |     # Define default user directory
302 |     userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
303 |     if userDir is None:
304 |         userDir = os.path.expanduser('~')
305 |         if not os.path.isdir(userDir):  # pragma: no cover
306 |             userDir = '/var/tmp'  # issue #54
307 | 
308 |     # Get system app data dir
309 |     path = None
310 |     if sys.platform.startswith('win'):
311 |         path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
312 |         path = (path2 or path1) if roaming else (path1 or path2)
313 |     elif sys.platform.startswith('darwin'):
314 |         path = os.path.join(userDir, 'Library', 'Application Support')
315 |     # On Linux and as fallback
316 |     if not (path and os.path.isdir(path)):
317 |         path = userDir
318 | 
319 |     # Maybe we should store things local to the executable (in case of a
320 |     # portable distro or a frozen application that wants to be portable)
321 |     prefix = sys.prefix
322 |     if getattr(sys, 'frozen', None):
323 |         prefix = os.path.abspath(os.path.dirname(sys.executable))
324 |     for reldir in ('settings', '../settings'):
325 |         localpath = os.path.abspath(os.path.join(prefix, reldir))
326 |         if os.path.isdir(localpath):  # pragma: no cover
327 |             try:
328 |                 open(os.path.join(localpath, 'test.write'), 'wb').close()
329 |                 os.remove(os.path.join(localpath, 'test.write'))
330 |             except IOError:
331 |                 pass  # We cannot write in this directory
332 |             else:
333 |                 path = localpath
334 |                 break
335 | 
336 |     # Get path specific for this app
337 |     if appname:
338 |         if path == userDir:
339 |             appname = '.' + appname.lstrip('.')  # Make it a hidden directory
340 |         path = os.path.join(path, appname)
341 |         if not os.path.isdir(path):  # pragma: no cover
342 |             os.mkdir(path)
343 | 
344 |     # Done
345 |     return path
346 | 
347 | def show_landmarks(image, heatmap, pred_landmarks):
348 |     """Show image with pred_landmarks"""
349 |     import matplotlib.pyplot as plt
350 |     from skimage import transform as ski_transform
351 |     # pred_landmarks = []
352 |     # pred_landmarks, _ = get_preds_fromhm(torch.from_numpy(heatmap).unsqueeze(0))
353 |     # pred_landmarks = pred_landmarks.squeeze()*4
354 |     # pred_landmarks2 = get_preds_fromhm2(heatmap)
355 |     # Collapse the per-landmark heatmaps into a single map and normalise it
356 |     heatmap = np.max(heatmap, axis=0)
357 |     heatmap = heatmap / np.max(heatmap)
358 |     # image = ski_transform.resize(image, (64, 64))*255
359 |     image = image.astype(np.uint8)
360 |     heatmap = ski_transform.resize(heatmap, (image.shape[0], image.shape[1]))
361 |     heatmap *= 255
362 |     heatmap = heatmap.astype(np.uint8)
363 |     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
364 |     plt.imshow(image)
365 |     # plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker='.', c='g')  # ground-truth landmarks are not passed to this helper
366 |     plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker='.', c='r')
367 |     plt.pause(0.001)  # pause a bit so that plots are updated
--------------------------------------------------------------------------------
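The `crop` and `transform` helpers in libs/fan_model/utils.py implement the FAN cropping convention: a square window of side 200 * scale around `center` is mapped to a resolution x resolution patch, and `transform(..., invert=True)` maps a point from the patch back into the original frame. A minimal sketch, assuming an arbitrary face center and scale (the real values come from the face detector and the calling script, not from these files) and assuming the repository root is on the Python path:

import numpy as np
from libs.fan_model.utils import crop, transform

frame = np.zeros((720, 1280, 3), dtype=np.uint8)                # stand-in for a decoded video frame
center = np.array([640.0, 360.0])                               # assumed face center
scale = 2.0                                                     # assumed face scale (window side = 200 * scale)
patch = crop(frame, center, scale, resolution=256.0)            # (256, 256, 3) face crop
corner = transform([1, 1], center, scale, 256.0, invert=True)   # top-left of the crop in frame coordinates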
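The FAN network in libs/fan_model/models.py downsamples a 256 x 256 input by a factor of 4 (stride-2 conv plus one average pool) and emits, for each stacked hourglass module, a tensor of 68 heatmaps at 64 x 64; get_preds_fromhm then decodes those heatmaps into landmark coordinates. A minimal sketch; the number of modules, the commented checkpoint path and the center/scale values are assumptions for illustration only:

import torch
from libs.fan_model.models import FAN
from libs.fan_model.utils import get_preds_fromhm

model = FAN(num_modules=4)                       # assumption: 4 stacked hourglass modules
# model.load_state_dict(torch.load('2DFAN-4.pth', map_location='cpu'))  # hypothetical checkpoint path
model.eval()

face = torch.rand(1, 3, 256, 256)                # one RGB face crop with values in [0, 1]
with torch.no_grad():
    heatmaps = model(face)                       # list of (1, 68, 64, 64) tensors, one per module
hm = heatmaps[-1]                                # keep the most refined prediction

center = torch.tensor([128.0, 128.0])            # assumed face center inside the crop
scale = 1.28                                     # assumed face scale
preds, preds_orig = get_preds_fromhm(hm, center, scale)
print(preds.shape, preds_orig.shape)             # both (1, 68, 2): heatmap-space and original-space points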
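create_target_heatmap builds 64 x 64 training targets: each landmark is first mapped into heatmap coordinates with transform and then rendered as a small Gaussian by draw_gaussian. A sketch with made-up landmarks, centers and scales (all illustrative values):

import numpy as np
from libs.fan_model.utils import create_target_heatmap

landmarks = np.random.rand(2, 68, 2) * 256        # two made-up landmark sets in image coordinates
centers = [np.array([128.0, 128.0])] * 2          # assumed per-image face centers
scales = [1.28] * 2                               # assumed per-image face scales
targets = create_target_heatmap(landmarks, centers, scales)
print(targets.shape)                              # torch.Size([2, 68, 64, 64])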
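create_bounding_box turns a batch of 68-point landmark sets into one (x1, y1, x2, y2) box per set, optionally grown symmetrically: an expansion_factor of 0.2 adds 10% on each side, i.e. 20% in total width and height. A sketch with random landmarks:

import torch
from libs.fan_model.utils import create_bounding_box

landmarks = torch.rand(4, 68, 2) * 256            # four made-up landmark sets
boxes = create_bounding_box(landmarks, expansion_factor=0.2)
print(boxes.shape)                                # torch.Size([4, 4]), columns are x1, y1, x2, y2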
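flip mirrors an image or a heatmap stack along the width axis; with is_label=True it also applies shuffle_lr so that left/right landmark channels swap consistently. A short sketch with random tensors:

import torch
from libs.fan_model.utils import flip

heatmaps = torch.rand(1, 68, 64, 64)
flipped_heatmaps = flip(heatmaps, is_label=True)  # mirror width and swap left/right landmark channels
image = torch.rand(3, 256, 256)
flipped_image = flip(image)                       # plain left-right mirror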