├── .gitignore ├── LICENSE ├── README.md ├── docs ├── examples │ ├── conan_obrien.mp4 │ └── shinzo_abe.mp4 └── teaser.gif ├── fsgan ├── __init__.py ├── criterions │ ├── __init__.py │ ├── gan_loss.py │ └── vgg_loss.py ├── datasets │ ├── __init__.py │ ├── appearance_map.py │ ├── image_list_dataset.py │ ├── image_seg_dataset.py │ ├── img_landmarks_transforms.py │ ├── img_lms_pose_transforms.py │ ├── opencv_video_seq_dataset.py │ ├── seq_dataset.py │ └── video_inference_dataset.py ├── experiments │ ├── reenactment │ │ ├── ijbc_msrunet_reenactment_attr.py │ │ ├── ijbc_msrunet_reenactment_attr_no_seg.py │ │ └── nfv_msrunet_reenactment_attr_no_seg_v2.1.py │ ├── segmentation │ │ └── celeba_unet.py │ └── swapping │ │ ├── ijbc_msrunet_blending.py │ │ └── ijbc_msrunet_inpainting.py ├── inference │ ├── face_swapping.ipynb │ ├── reenact.py │ └── swap.py ├── models │ ├── __init__.py │ ├── classifier1d.py │ ├── discriminators_pix2pix.py │ ├── hopenet.py │ ├── hrnet.py │ ├── msba.py │ ├── res_unet.py │ ├── res_unet_msba.py │ ├── res_unet_split.py │ ├── simple_unet.py │ ├── simple_unet_02.py │ └── vgg.py ├── preprocess │ ├── __init__.py │ ├── clear_cache.py │ ├── crop_image_sequences.py │ ├── crop_video_sequences.py │ ├── crop_video_sequences_batch.py │ ├── detections2sequences_1euro.py │ ├── detections2sequences_center.py │ ├── euler_sequences.py │ ├── preprocess_video.py │ ├── produce_train_val.py │ ├── render_sequences.py │ └── sequence_stats.py ├── train_blending.py ├── train_inpainting.py ├── train_reenactment_attr.py ├── train_reenactment_attr_no_seg.py ├── train_reenactment_attr_no_seg_v2_1.py ├── train_segmentation.py └── utils │ ├── __init__.py │ ├── batch.py │ ├── bbox_utils.py │ ├── blur.py │ ├── confusionmatrix.py │ ├── img_utils.py │ ├── iou_metric.py │ ├── landmarks_utils.py │ ├── obj_factory.py │ ├── one_euro_filter.py │ ├── seg_utils.py │ ├── set_checkpoint_arch.py │ ├── temporal_smoothing.py │ ├── tensorboard_logger.py │ ├── utils.py │ ├── video_renderer.py │ └── video_utils.py ├── fsgan_env.yml └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.pyc 3 | __pycache__ 4 | fsgan/experiments/results 5 | weights 6 | download_fsgan_models.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 
20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). 
Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 
122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## FSGAN - Official PyTorch Implementation 2 | ![Teaser](./docs/teaser.gif) 3 | Example video face swapping: Barack Obama to Benjamin Netanyahu, Shinzo Abe to Theresa May, and Xi Jinping to 4 | Justin Trudeau. 5 | 6 | This repository contains the source code for the video face swapping and face reenactment method described in the paper: 7 | > **FSGAN: Subject Agnostic Face Swapping and Reenactment** 8 | > *International Conference on Computer Vision (ICCV), Seoul, Korea, 2019* 9 | > Yuval Nirkin, Yosi Keller, Tal Hassner 10 | > [Paper](https://arxiv.org/pdf/1908.05932.pdf)   [Video](https://www.youtube.com/watch?v=BsITEVX6hkE) 11 | > 12 | > **Abstract:** *We present Face Swapping GAN (FSGAN) for face swapping and reenactment. Unlike previous work, FSGAN is subject agnostic and can be applied to pairs of faces without requiring training on those faces. To this end, we describe a number of technical contributions. We derive a novel recurrent neural network (RNN)-based approach for face reenactment which adjusts for both pose and expression variations and can be applied to a single image or a video sequence. For video sequences, we introduce continuous interpolation of the face views based on reenactment, Delaunay Triangulation, and barycentric coordinates. Occluded face regions are handled by a face completion network. Finally, we use a face blending network for seamless blending of the two faces while preserving target skin color and lighting conditions. This network uses a novel Poisson blending loss which combines Poisson optimization with perceptual loss. We compare our approach to existing state-of-the-art systems and show our results to be both qualitatively and quantitatively superior.* 13 | 14 | ## Important note 15 | **THE METHODS PROVIDED IN THIS REPOSITORY ARE NOT TO BE USED FOR MALICIOUS OR INAPPROPRIATE USE CASES.** 16 | We release this code to help facilitate research into technical counter-measures for detecting this 17 | kind of forgery. Suppressing publications of this kind will not stop their development, but will only make 18 | them more difficult to detect. 19 | 20 | Please note that this is a work in progress: while we make every effort to improve the results of this method, not 21 | every pair of faces can produce a high-quality face swap. 22 | 23 | 24 | ## Requirements 25 | - High-end NVIDIA GPUs with at least 11GB of DRAM. 26 | - Either Linux or Windows. We recommend Linux for better performance. 27 | - CUDA Toolkit 10.1+, cuDNN 7.5+, and the latest NVIDIA driver. 28 | 29 | ## Installation 30 | ```Bash 31 | git clone https://github.com/YuvalNirkin/fsgan 32 | cd fsgan 33 | conda env create -f fsgan_env.yml 34 | conda activate fsgan 35 | pip install . # Alternatively add the root directory of the repository to PYTHONPATH. 36 | ``` 37 | 38 | To access FSGAN's pretrained models and auxiliary data, please fill out 39 | [this form](https://docs.google.com/forms/d/e/1FAIpQLScyyNWoFvyaxxfyaPLnCIAxXgdxLEMwR9Sayjh3JpWseuYlOA/viewform?usp=sf_link). 40 | We will then send you a link to FSGAN's shared directory and download script.
41 | ```Bash 42 | python download_fsgan_models.py # From the repository root directory 43 | ``` 44 | 45 | ## Inference 46 | - [Face swapping guide](https://github.com/YuvalNirkin/fsgan/wiki/Face-Swapping-Inference) 47 | - [Face swapping Google Colab](fsgan/inference/face_swapping.ipynb) 48 | - [Paper models guide](https://github.com/YuvalNirkin/fsgan/wiki/Paper-Models-Inference) 49 | 50 | ## Training 51 | - [Training V2](https://github.com/YuvalNirkin/fsgan/wiki/Training-V2) 52 | 53 | ## Comparison on FaceForensics++ 54 | To make it easier to compare against FSGAN, we have provided the FSGAN (original paper) results on the [FaceForensics++](https://github.com/ondyari/FaceForensics) dataset for both the C23 and C40 compressions: 55 | - [FaceForensics++ FSGANv1 C40](https://github.com/YuvalNirkin/fsgan/releases/download/v1.0.1/face_forensics_fsgan_v1_c40.zip) 56 | - [FaceForensics++ FSGANv1 C23 (part 1)](https://github.com/YuvalNirkin/fsgan/releases/download/v1.0.1/face_forensics_fsgan_v1_c23_part1.zip) 57 | - [FaceForensics++ FSGANv1 C23 (part 2)](https://github.com/YuvalNirkin/fsgan/releases/download/v1.0.1/face_forensics_fsgan_v1_c23_part2.zip) 58 | 59 | ## Citation 60 | ``` 61 | @inproceedings{nirkin2019fsgan, 62 | title={{FSGAN}: Subject agnostic face swapping and reenactment}, 63 | author={Nirkin, Yuval and Keller, Yosi and Hassner, Tal}, 64 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 65 | pages={7184--7193}, 66 | year={2019} 67 | } 68 | 69 | @article{nirkin2022fsganv2, 70 | title={{FSGANv2}: Improved Subject Agnostic Face Swapping and Reenactment}, 71 | author={Nirkin, Yuval and Keller, Yosi and Hassner, Tal}, 72 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 73 | year={2022}, 74 | publisher={IEEE} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /docs/examples/conan_obrien.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/docs/examples/conan_obrien.mp4 -------------------------------------------------------------------------------- /docs/examples/shinzo_abe.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/docs/examples/shinzo_abe.mp4 -------------------------------------------------------------------------------- /docs/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/docs/teaser.gif -------------------------------------------------------------------------------- /fsgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/fsgan/__init__.py -------------------------------------------------------------------------------- /fsgan/criterions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/fsgan/criterions/__init__.py -------------------------------------------------------------------------------- /fsgan/criterions/gan_loss.py: -------------------------------------------------------------------------------- 1
| import torch.nn as nn 2 | 3 | 4 | # Adapted from https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py 5 | class GANLoss(nn.Module): 6 | """ Defines the GAN loss as described in the paper: 7 | `"High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs" 8 | `_. 9 | 10 | Args: 11 | use_lsgan (bool): If True, the least squares version will be used 12 | """ 13 | def __init__(self, use_lsgan=True): 14 | super(GANLoss, self).__init__() 15 | if use_lsgan: 16 | self.loss = nn.MSELoss() 17 | else: 18 | self.loss = nn.BCELoss() 19 | 20 | def __call__(self, input, target_is_real): 21 | if isinstance(input[0], list): 22 | loss = 0 23 | for input_i in input: 24 | pred = input_i[-1] 25 | target_tensor = pred.new_full(pred.shape, target_is_real) 26 | loss += self.loss(pred, target_tensor) 27 | return loss 28 | else: 29 | target_tensor = input[-1].new_full(input[-1].shape, target_is_real) # Build the target tensor to match the prediction 30 | return self.loss(input[-1], target_tensor) 31 | -------------------------------------------------------------------------------- /fsgan/criterions/vgg_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | from fsgan.models.vgg import vgg19 5 | 6 | 7 | # Adapted from https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py 8 | class Vgg19(torch.nn.Module): 9 | """ First layers of the VGG 19 model for the VGG loss. 10 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 11 | 12 | Args: 13 | model_path (str): Path to model weights file (.pth) 14 | requires_grad (bool): Enables or disables the "requires_grad" flag for all model parameters 15 | """ 16 | def __init__(self, model_path: str = None, requires_grad: bool = False): 17 | super(Vgg19, self).__init__() 18 | if model_path is None: 19 | vgg_pretrained_features = models.vgg19(pretrained=True).features 20 | else: 21 | model = vgg19(pretrained=False) 22 | checkpoint = torch.load(model_path) 23 | del checkpoint['state_dict']['classifier.6.weight'] 24 | del checkpoint['state_dict']['classifier.6.bias'] 25 | model.load_state_dict(checkpoint['state_dict'], strict=False) 26 | vgg_pretrained_features = model.features 27 | self.slice1 = torch.nn.Sequential() 28 | self.slice2 = torch.nn.Sequential() 29 | self.slice3 = torch.nn.Sequential() 30 | self.slice4 = torch.nn.Sequential() 31 | self.slice5 = torch.nn.Sequential() 32 | for x in range(2): 33 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 34 | for x in range(2, 7): 35 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(7, 12): 37 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(12, 21): 39 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 40 | for x in range(21, 30): 41 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 42 | if not requires_grad: 43 | for param in self.parameters(): 44 | param.requires_grad = False 45 | 46 | def forward(self, x): 47 | h_relu1 = self.slice1(x) 48 | h_relu2 = self.slice2(h_relu1) 49 | h_relu3 = self.slice3(h_relu2) 50 | h_relu4 = self.slice4(h_relu3) 51 | h_relu5 = self.slice5(h_relu4) 52 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 53 | return out 54 | 55 | 56 | # Adapted from https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py 57 | class VGGLoss(nn.Module): 58 | """ Defines a criterion that captures the perceptual differences between two images using VGG-19 features.
59 | `"Perceptual Losses for Real-Time Style Transfer and Super-Resolution" `_ 60 | 61 | Args: 62 | model_path (str): Path to model weights file (.pth) 63 | """ 64 | def __init__(self, model_path: str = None): 65 | super(VGGLoss, self).__init__() 66 | self.vgg = Vgg19(model_path) 67 | self.criterion = nn.L1Loss() 68 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 69 | 70 | def forward(self, x, y): 71 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 72 | loss = 0 73 | for i in range(len(x_vgg)): 74 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 75 | return loss 76 | -------------------------------------------------------------------------------- /fsgan/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/fsgan/datasets/__init__.py -------------------------------------------------------------------------------- /fsgan/datasets/appearance_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | from tqdm import tqdm 5 | import numpy as np 6 | import cv2 7 | from scipy.spatial import cKDTree, Delaunay 8 | import torch 9 | import torch.utils.data as data 10 | from fsgan.utils.seg_utils import decode_binary_mask 11 | from fsgan.utils.video_utils import get_video_info 12 | 13 | 14 | def fuse_clusters(points, r=0.5): 15 | """ Select a single point from each cluster of points. 16 | 17 | The clustering is done using a KD-Tree data structure for querying points by radius. 18 | 19 | Args: 20 | points (np.array): A set of points of shape (N, 2) to fuse 21 | r (float): The radius for which to fuse the points 22 | 23 | Returns: 24 | np.array: The indices of remaining points. 25 | """ 26 | kdt = cKDTree(points) 27 | indices = kdt.query_ball_point(points, r=r) 28 | 29 | # Build sorted neightbor list 30 | neighbors = [(i, l) for i, l in enumerate(indices)] 31 | neighbors.sort(key=lambda t: len(t[1]), reverse=True) 32 | 33 | # Mark remaining indices 34 | keep = np.ones(points.shape[0], dtype=bool) 35 | for i, cluster in neighbors: 36 | if not keep[i]: 37 | continue 38 | for j in cluster: 39 | if i == j: 40 | continue 41 | keep[j] = False 42 | 43 | return np.nonzero(keep)[0] 44 | 45 | 46 | class AppearanceMapDataset(data.Dataset): 47 | """A dataset representing the appearance map of a video sequence 48 | 49 | Args: 50 | root (string): Root directory path or file list path. 51 | transform (callable, optional): A function/transform that takes in an PIL image 52 | and returns a transformed version. 
E.g., ``transforms.RandomCrop`` 53 | Attributes: 54 | points (np.array): The appearance map points, built from the filtered source poses 55 | """ 56 | def __init__(self, src_vid_seq_path, tgt_vid_seq_path, src_transform=None, tgt_transform=None, 57 | landmarks_postfix='_lms.npz', pose_postfix='_pose.npz', seg_postfix='_seg.pkl', min_radius=0.5): 58 | assert os.path.isfile(src_vid_seq_path), f'src_vid_seq_path is not a path to a file: {src_vid_seq_path}' 59 | assert os.path.isfile(tgt_vid_seq_path), f'tgt_vid_seq_path is not a path to a file: {tgt_vid_seq_path}' 60 | self.src_transform = src_transform 61 | self.tgt_transform = tgt_transform 62 | self.src_vid_seq_path = src_vid_seq_path 63 | self.tgt_vid_seq_path = tgt_vid_seq_path 64 | self.src_vid = None 65 | self.tgt_vid = None 66 | 67 | # Get target video info 68 | self.width, self.height, self.total_frames, self.fps = get_video_info(tgt_vid_seq_path) 69 | 70 | # Load landmarks 71 | src_lms_path = os.path.splitext(src_vid_seq_path)[0] + landmarks_postfix 72 | self.src_landmarks = np.load(src_lms_path)['landmarks_smoothed'] 73 | tgt_lms_path = os.path.splitext(tgt_vid_seq_path)[0] + landmarks_postfix 74 | self.tgt_landmarks = np.load(tgt_lms_path)['landmarks_smoothed'] 75 | 76 | # Load poses 77 | src_pose_path = os.path.splitext(src_vid_seq_path)[0] + pose_postfix 78 | self.src_poses = np.load(src_pose_path)['poses_smoothed'] 79 | tgt_pose_path = os.path.splitext(tgt_vid_seq_path)[0] + pose_postfix 80 | self.tgt_poses = np.load(tgt_pose_path)['poses_smoothed'] 81 | 82 | # Load target segmentations 83 | tgt_seg_path = os.path.splitext(tgt_vid_seq_path)[0] + seg_postfix 84 | with open(tgt_seg_path, "rb") as fp: # Unpickling 85 | self.tgt_encoded_seg = pickle.load(fp) 86 | 87 | # Initialize appearance map 88 | self.filtered_indices = fuse_clusters(self.src_poses[:, :2], r=min_radius / 99.) 89 | self.points = self.src_poses[self.filtered_indices, :2] 90 | limit_points = np.array([[-75., -75.], [-75., 75.], [75., -75.], [75., 75.]]) / 99. 91 | self.points = np.concatenate((self.points, limit_points)) 92 | self.tri = Delaunay(self.points) 93 | self.valid_size = len(self.filtered_indices) 94 | 95 | # Filter source landmarks and poses and handle edge cases 96 | self.src_landmarks = self.src_landmarks[self.filtered_indices] 97 | self.src_landmarks = np.vstack((self.src_landmarks, np.zeros_like(self.src_landmarks[-1:]))) 98 | self.src_poses = self.src_poses[self.filtered_indices] 99 | self.src_poses = np.vstack((self.src_poses, np.zeros_like(self.src_poses[-1:]))) 100 | 101 | # Initialize cached frames 102 | self.src_frames = [None for i in range(len(self.filtered_indices) + 1)] 103 | 104 | # Handle edge cases 105 | black_rgb = np.zeros((self.height, self.width, 3), dtype='uint8') 106 | self.src_frames[-1] = black_rgb 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index 112 | Returns: 113 | tuple: (src_frames, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_seg) for the target frame at the given index.
114 | """ 115 | if self.src_vid is None: 116 | # Open source video on the data loader's process 117 | self.src_vid = cv2.VideoCapture(self.src_vid_seq_path) 118 | if self.tgt_vid is None: 119 | # Open target video on the data loader's process 120 | self.tgt_vid = cv2.VideoCapture(self.tgt_vid_seq_path) 121 | 122 | # Read next target frame and meta-data 123 | ret, tgt_frame_bgr = self.tgt_vid.read() 124 | assert tgt_frame_bgr is not None, 'Failed to read frame from video in index: %d' % index 125 | tgt_frame = tgt_frame_bgr[:, :, ::-1] 126 | tgt_landmarks = self.tgt_landmarks[index] 127 | tgt_pose = self.tgt_poses[index] 128 | tgt_seg = decode_binary_mask(self.tgt_encoded_seg[index]) 129 | 130 | # Query source frames and meta-data given the current target pose 131 | query_point, tilt_angle = tgt_pose[:2], tgt_pose[2] 132 | tri_index = self.tri.find_simplex(query_point[:2]) 133 | tri_vertices = self.tri.simplices[tri_index] 134 | tri_vertices = np.minimum(tri_vertices, self.valid_size) 135 | 136 | # Compute barycentric weights 137 | b = self.tri.transform[tri_index, :2].dot(query_point[:2] - self.tri.transform[tri_index, 2]) 138 | bw = np.array([b[0], b[1], 1 - b.sum()], dtype='float32') 139 | bw[tri_vertices >= self.valid_size] = 0. # Set zero weight for edge points 140 | bw /= bw.sum() 141 | 142 | # Cache source frames 143 | for tv in np.sort(tri_vertices): 144 | if self.src_frames[tv] is None: 145 | self.src_vid.set(cv2.CAP_PROP_POS_FRAMES, self.filtered_indices[tv]) 146 | ret, frame_bgr = self.src_vid.read() 147 | assert frame_bgr is not None, 'Failed to read frame from source video in index: %d' % tv 148 | frame_rgb = frame_bgr[:, :, ::-1] 149 | self.src_frames[tv] = frame_rgb 150 | 151 | # Get source data from appearance map 152 | src_frames = [self.src_frames[tv] for tv in tri_vertices] 153 | src_landmarks = self.src_landmarks[tri_vertices].astype('float32') 154 | src_poses = self.src_poses[tri_vertices].astype('float32') 155 | 156 | # Apply source transformation 157 | if self.src_transform is not None: 158 | src_data = [(src_frames[i], src_landmarks[i], (src_poses[i][2] - tilt_angle) * 99.) 
159 | for i in range(len(src_frames))] 160 | src_data = self.src_transform(src_data) 161 | src_landmarks = torch.stack([src_data[i][1] for i in range(len(src_data))]) 162 | src_frames = [src_data[i][0] for i in range(len(src_data))] 163 | src_poses[:, 2] = tilt_angle 164 | 165 | # Apply target transformation 166 | if self.tgt_transform is not None: 167 | tgt_frame = self.tgt_transform(tgt_frame) 168 | 169 | # Combine pyramids in source frames if they exist 170 | if isinstance(src_frames[0], (list, tuple)): 171 | src_frames = [torch.stack([src_frames[f][p] for f in range(len(src_frames))], dim=0) 172 | for p in range(len(src_frames[0]))] 173 | 174 | return src_frames, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_seg 175 | 176 | def __len__(self): 177 | return self.tgt_poses.shape[0] 178 | -------------------------------------------------------------------------------- /fsgan/datasets/image_seg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from itertools import groupby 4 | import numpy as np 5 | import cv2 6 | from PIL import Image 7 | import torch 8 | import torch.utils.data as data 9 | from fsgan.datasets.image_list_dataset import ImageListDataset 10 | import fsgan.datasets.img_landmarks_transforms as img_landmarks_transforms 11 | 12 | 13 | def seg_label2img(seg, classes=3): 14 | out_seg = np.zeros(seg.shape + (classes,), dtype=seg.dtype) 15 | # out_seg = np.full(seg.shape + (classes,), 127.5, dtype='float32') 16 | for i in range(classes): 17 | out_seg[:, :, i][seg == i] = 255 18 | 19 | return out_seg 20 | 21 | 22 | class ImageSegDataset(ImageListDataset): 23 | """An image list dataset with corresponding bounding boxes and segmentation masks, where the images can be arranged in this way: 24 | 25 | root/id1/xxx.png 26 | root/id1/xxy.png 27 | root/id1/xxz.png 28 | 29 | root/id2/123.png 30 | root/id2/nsdf3.png 31 | root/id2/asd932_.png 32 | 33 | Args: 34 | root (string): Root directory path. 35 | img_list (string): Image list file path. 36 | transform (callable, optional): A function/transform that takes in a PIL image 37 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 38 | target_transform (callable, optional): A function/transform that takes in the 39 | target and transforms it. 40 | loader (string, optional): 'opencv', 'accimage', or 'pil' 41 | 42 | Attributes: 43 | classes (list): List of the class names. 44 | class_to_idx (dict): Dict with items (class_name, class_index).
45 | imgs (list): List of (image path, class_index) tuples 46 | """ 47 | def __init__(self, root, img_list, bboxes_list=None, transform=None, target_transform=None, loader='opencv', 48 | seg_postfix='_mask.png', seg_classes=3, mask_root=None, classification=False): 49 | super(ImageSegDataset, self).__init__(root, img_list, bboxes_list, None, transform, target_transform, loader) 50 | self.seg_classes = seg_classes 51 | if not classification: 52 | self.classes = list(range(seg_classes)) 53 | self.classification = classification 54 | if mask_root is None: 55 | self.segs = [os.path.splitext(p)[0] + seg_postfix for p in self.imgs] 56 | else: 57 | self.segs = [os.path.join(mask_root, os.path.splitext(os.path.relpath(p, root))[0] + seg_postfix) 58 | for p in self.imgs] 59 | 60 | # Validate that all mask files exist 61 | for seg_path in self.segs: 62 | assert os.path.isfile(seg_path), 'Could not find mask file: "%s"' % seg_path 63 | 64 | def get_data(self, index): 65 | img, target, bbox = super(ImageSegDataset, self).get_data(index) 66 | seg = np.array(Image.open(self.segs[index])) 67 | 68 | # Convert segmentation format to a channel for each class 69 | seg = seg_label2img(seg, self.seg_classes) 70 | 71 | return img, seg, bbox, target 72 | 73 | def __getitem__(self, index): 74 | """ 75 | Args: 76 | index (int): Index 77 | 78 | Returns: 79 | tuple: (image, segmentation) 80 | """ 81 | img, seg, bbox, target = self.get_data(index) 82 | if self.transform is not None: 83 | seg_scale = seg.shape[0] / img.shape[0] 84 | bboxes = [bbox, bbox * seg_scale] 85 | img, seg = tuple(self.transform([img, seg], bboxes) if bbox is None else self.transform([img, seg], bboxes)) 86 | # seg[(seg[:, :, 0] < 0) & (seg[:, :, 1] < 0) & (seg[:, :, 2] < 0)] = 1.0 87 | # seg = torch.clamp(seg, min=0.0) 88 | # if self.transform is not None: 89 | # img = self.transform(img) if bbox is None else self.transform(img, bbox) 90 | # if self.target_transform is not None: 91 | # seg_scale = seg.shape[0] / img.shape[0] 92 | # seg = self.target_transform(seg, bbox * seg_scale) 93 | 94 | # Postprocess segmentation: assign pixels that belong to no class to the background channel (0) 95 | seg[0, :, :][(seg[0, :, :] <= 0) & (seg[1, :, :] <= 0) & (seg[2, :, :] <= 0)] = 1.
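# Clamp any remaining negative values in the segmentation channels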
96 | seg = torch.clamp(seg, min=0.0) 97 | 98 | if self.classification: 99 | return img, seg, target 100 | else: 101 | return img, seg 102 | 103 | def __len__(self): 104 | return len(self.imgs) 105 | 106 | 107 | def main(dataset='fsgan.datasets.image_seg_dataset.ImageSegDataset', 108 | np_transforms1=None, np_transforms2=None, 109 | tensor_transforms1=('img_landmarks_transforms.ToTensor()', 110 | 'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'), 111 | tensor_transforms2=('img_landmarks_transforms.ToTensor()',), 112 | workers=4, batch_size=4): 113 | import time 114 | from fsgan.utils.obj_factory import obj_factory 115 | from fsgan.utils.seg_utils import blend_seg_pred, blend_seg_label 116 | from fsgan.utils.img_utils import tensor2bgr 117 | 118 | np_transforms1 = obj_factory(np_transforms1) if np_transforms1 is not None else [] 119 | tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else [] 120 | img_transforms1 = img_landmarks_transforms.Compose(np_transforms1 + tensor_transforms1) 121 | np_transforms2 = obj_factory(np_transforms2) if np_transforms2 is not None else [] 122 | tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else [] 123 | img_transforms2 = img_landmarks_transforms.Compose(np_transforms2 + tensor_transforms2) 124 | dataset = obj_factory(dataset, transform=img_transforms1, target_transform=img_transforms2) 125 | dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers, pin_memory=True, drop_last=True, 126 | shuffle=True) 127 | 128 | start = time.time() 129 | for img, seg in dataloader: 130 | # For each batch 131 | for b in range(img.shape[0]): 132 | blend_tensor = blend_seg_pred(img, seg) 133 | render_img = tensor2bgr(blend_tensor[b]) 134 | # render_img = tensor2bgr(img[b]) 135 | cv2.imshow('render_img', render_img) 136 | if cv2.waitKey(0) & 0xFF == ord('q'): 137 | break 138 | end = time.time() 139 | print('elapsed time: %f[s]' % (end - start)) 140 | 141 | 142 | if __name__ == "__main__": 143 | # Parse program arguments 144 | import argparse 145 | parser = argparse.ArgumentParser('image_seg_dataset') 146 | parser.add_argument('dataset', metavar='OBJ', default='fsgan.datasets.image_seg_dataset.ImageSegDataset', 147 | help='dataset object') 148 | parser.add_argument('-nt1', '--np_transforms1', nargs='+', help='Numpy transforms') 149 | parser.add_argument('-nt2', '--np_transforms2', nargs='+', help='Numpy transforms') 150 | parser.add_argument('-tt1', '--tensor_transforms1', nargs='+', help='tensor transforms', 151 | default=('img_landmarks_transforms.ToTensor()', 152 | 'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])')) 153 | parser.add_argument('-tt2', '--tensor_transforms2', nargs='+', help='tensor transforms', 154 | default=('img_landmarks_transforms.ToTensor()',)) 155 | parser.add_argument('-w', '--workers', default=4, type=int, metavar='N', 156 | help='number of data loading workers (default: 4)') 157 | parser.add_argument('-b', '--batch-size', default=4, type=int, metavar='N', 158 | help='mini-batch size (default: 4)') 159 | main(**vars(parser.parse_args())) 160 | -------------------------------------------------------------------------------- /fsgan/datasets/video_inference_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle 4 | import torch.utils.data as data 5 | import numpy as np 6 | import cv2 7 | import torch 8 | from fsgan.utils.video_utils import Sequence, 
get_video_info 9 | 10 | 11 | class VideoInferenceDataset(data.Dataset): 12 | """A dataset for loading video sequences. 13 | 14 | Args: 15 | vid_path (string): Path to the video file. 16 | seq (Sequence, optional): Face sequence specifying the start frame and per-frame detections. 17 | transform (callable, optional): A function/transform that takes in a frame and returns a transformed version. 18 | Attributes: 19 | cap (cv2.VideoCapture): Video capture handle, opened lazily on first access. 20 | """ 21 | 22 | def __init__(self, vid_path, seq=None, transform=None): 23 | self.vid_path = vid_path 24 | self.seq = seq 25 | self.transform = transform 26 | self.cap = None 27 | 28 | # Get video info 29 | self.width, self.height, self.total_frames, self.fps = get_video_info(vid_path) 30 | 31 | def __getitem__(self, index): 32 | """ 33 | Args: 34 | index (int): Index 35 | Returns: 36 | The RGB frame at the given index, after the optional transform is applied. 37 | """ 38 | if self.cap is None: 39 | # Open video file 40 | self.cap = cv2.VideoCapture(self.vid_path) 41 | if self.seq is not None: 42 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.seq.start_index) 43 | 44 | ret, frame_bgr = self.cap.read() 45 | assert frame_bgr is not None, 'Failed to read frame from video in index: %d' % index 46 | frame_rgb = frame_bgr[:, :, ::-1] 47 | bbox = self.seq.detections[index] if self.seq is not None else None 48 | 49 | # Apply transformation 50 | if self.transform is not None: 51 | if bbox is None: 52 | frame_rgb = self.transform(frame_rgb) 53 | else: 54 | frame_rgb = self.transform(frame_rgb, bbox) 55 | 56 | return frame_rgb 57 | 58 | def __len__(self): 59 | return self.total_frames if self.seq is None else len(self.seq) 60 | -------------------------------------------------------------------------------- /fsgan/experiments/reenactment/ijbc_msrunet_reenactment_attr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | from fsgan.datasets.opencv_video_seq_dataset import VideoSeqPairDataset 7 | from fsgan.datasets.img_landmarks_transforms import Crop, Resize, RandomHorizontalFlip, Pyramids, ToTensor 8 | from fsgan.criterions.vgg_loss import VGGLoss 9 | from fsgan.criterions.gan_loss import GANLoss 10 | from fsgan.models.res_unet_split import MultiScaleResUNet 11 | from fsgan.models.discriminators_pix2pix import MultiscaleDiscriminator 12 | from fsgan.train_reenactment_attr import main 13 | 14 | 15 | if __name__ == '__main__': 16 | exp_name = os.path.splitext(os.path.basename(__file__))[0] 17 | exp_dir = os.path.join('../results/reenactment', exp_name) 18 | train_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 19 | 'train_list.txt', frame_window=1, ignore_landmarks=True, same_prob=1.0) 20 | val_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 21 | 'val_list.txt', frame_window=1, ignore_landmarks=True, same_prob=1.0) 22 | numpy_transforms = [RandomHorizontalFlip(), Pyramids(2)] 23 | tensor_transforms = [ToTensor()] 24 | resolutions = [128, 256] 25 | lr_gen = [1e-4, 4e-5] 26 | lr_dis = [1e-5, 4e-6] 27 | epochs = [24, 50] 28 | iterations = ['20k'] 29 | batch_size = [24, 12] 30 | workers = 32 31 | pretrained = False 32 | criterion_id = VGGLoss('../../../weights/vggface2_vgg19_256_1_2_id.pth') 33 | criterion_attr = VGGLoss('../../../weights/celeba_vgg19_256_2_0_28_attr.pth') 34 | criterion_gan =
GANLoss(use_lsgan=True) 35 | generator = MultiScaleResUNet(in_nc=101, out_nc=(3, 3), flat_layers=(2, 2, 2, 2), ngf=128) 36 | discriminator = MultiscaleDiscriminator(use_sigmoid=True, num_D=2) 37 | optimizer = partial(optim.Adam, betas=(0.5, 0.999)) 38 | scheduler = partial(lr_scheduler.StepLR, step_size=10, gamma=0.5) 39 | seg_model = '../../weights/lfw_figaro_unet_256_segmentation.pth' 40 | lms_model = '../../weights/hr18_wflw_landmarks.pth' 41 | seg_weight = 0.1 42 | rec_weight = 1.0 43 | gan_weight = 0.001 44 | 45 | if not os.path.exists(exp_dir): 46 | os.makedirs(exp_dir) 47 | 48 | main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, 49 | numpy_transforms=numpy_transforms, tensor_transforms=tensor_transforms, resolutions=resolutions, 50 | lr_gen=lr_gen, lr_dis=lr_dis, epochs=epochs, iterations=iterations, batch_size=batch_size, workers=workers, 51 | optimizer=optimizer, scheduler=scheduler, pretrained=pretrained, 52 | criterion_id=criterion_id, criterion_attr=criterion_attr, criterion_gan=criterion_gan, 53 | generator=generator, discriminator=discriminator, seg_model=seg_model, lms_model=lms_model, 54 | seg_weight=seg_weight, rec_weight=rec_weight, gan_weight=gan_weight) 55 | 56 | os.system('sudo shutdown') 57 | -------------------------------------------------------------------------------- /fsgan/experiments/reenactment/ijbc_msrunet_reenactment_attr_no_seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | from fsgan.datasets.opencv_video_seq_dataset import VideoSeqPairDataset 7 | from fsgan.datasets.img_landmarks_transforms import RandomHorizontalFlip, Pyramids, ToTensor 8 | from fsgan.criterions.vgg_loss import VGGLoss 9 | from fsgan.criterions.gan_loss import GANLoss 10 | from fsgan.models.res_unet import MultiScaleResUNet 11 | from fsgan.models.discriminators_pix2pix import MultiscaleDiscriminator 12 | from fsgan.train_reenactment_attr_no_seg import main 13 | 14 | 15 | if __name__ == '__main__': 16 | exp_name = os.path.splitext(os.path.basename(__file__))[0] 17 | exp_dir = os.path.join('../results/reenactment', exp_name) 18 | train_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 19 | 'train_list.txt', frame_window=1, ignore_landmarks=True, same_prob=1.0) 20 | val_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 21 | 'val_list.txt', frame_window=1, ignore_landmarks=True, same_prob=1.0) 22 | numpy_transforms = [RandomHorizontalFlip(), Pyramids(2)] 23 | tensor_transforms = [ToTensor()] 24 | resolutions = [128, 256] 25 | lr_gen = [1e-4, 4e-5] 26 | lr_dis = [1e-5, 4e-6] 27 | epochs = [24, 50] 28 | iterations = ['20k'] 29 | batch_size = [32, 16] 30 | workers = 32 31 | pretrained = False 32 | criterion_id = VGGLoss('../../../weights/vggface2_vgg19_256_1_2_id.pth') 33 | criterion_attr = VGGLoss('../../../weights/celeba_vgg19_256_2_0_28_attr.pth') 34 | criterion_gan = GANLoss(use_lsgan=True) 35 | generator = MultiScaleResUNet(in_nc=101, out_nc=3, flat_layers=(2, 2, 2, 2), ngf=128) 36 | discriminator = MultiscaleDiscriminator(use_sigmoid=True, num_D=2) 37 | optimizer = partial(optim.Adam, betas=(0.5, 0.999)) 38 | scheduler = partial(lr_scheduler.StepLR, step_size=10, gamma=0.5) 39 | lms_model = '../../weights/hr18_wflw_landmarks.pth' 40 | rec_weight = 1.0 41 | gan_weight = 0.001 42 
| 43 | if not os.path.exists(exp_dir): 44 | os.makedirs(exp_dir) 45 | 46 | main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, 47 | numpy_transforms=numpy_transforms, tensor_transforms=tensor_transforms, resolutions=resolutions, 48 | lr_gen=lr_gen, lr_dis=lr_dis, epochs=epochs, iterations=iterations, batch_size=batch_size, workers=workers, 49 | optimizer=optimizer, scheduler=scheduler, pretrained=pretrained, 50 | criterion_id=criterion_id, criterion_attr=criterion_attr, criterion_gan=criterion_gan, 51 | generator=generator, discriminator=discriminator, lms_model=lms_model, 52 | rec_weight=rec_weight, gan_weight=gan_weight) 53 | 54 | os.system('sudo shutdown') 55 | -------------------------------------------------------------------------------- /fsgan/experiments/reenactment/nfv_msrunet_reenactment_attr_no_seg_v2.1.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | from fsgan.datasets.seq_dataset import SeqPairDataset 7 | from fsgan.datasets.img_lms_pose_transforms import RandomHorizontalFlip, Pyramids 8 | from fsgan.criterions.vgg_loss import VGGLoss 9 | from fsgan.criterions.gan_loss import GANLoss 10 | from fsgan.models.res_unet import MultiScaleResUNet 11 | from fsgan.models.discriminators_pix2pix import MultiscaleDiscriminator 12 | from fsgan.train_reenactment_attr_no_seg_v2_1 import main 13 | 14 | 15 | if __name__ == '__main__': 16 | exp_name = os.path.splitext(os.path.basename(__file__))[0] 17 | exp_dir = os.path.join('../results/reenactment', exp_name) 18 | root = '/data/datasets/nirkin_face_videos' 19 | train_dataset = partial(SeqPairDataset, root, 'videos_train.txt', postfixes=('.mp4', '_lms.npz'), same_prob=1.0) 20 | val_dataset = partial(SeqPairDataset, root, 'videos_val.txt', postfixes=('.mp4', '_lms.npz'), same_prob=1.0) 21 | numpy_transforms = [RandomHorizontalFlip(), Pyramids(2)] 22 | resolutions = [128, 256] 23 | lr_gen = [1e-4, 4e-5] 24 | lr_dis = [1e-5, 4e-6] 25 | epochs = [24, 50] 26 | iterations = ['20k'] 27 | batch_size = [48, 24] 28 | workers = 32 29 | pretrained = False 30 | criterion_id = VGGLoss('../../../weights/vggface2_vgg19_256_1_2_id.pth') 31 | criterion_attr = VGGLoss('../../../weights/celeba_vgg19_256_2_0_28_attr.pth') 32 | criterion_gan = GANLoss(use_lsgan=True) 33 | generator = MultiScaleResUNet(in_nc=101, out_nc=3, flat_layers=(2, 2, 2, 2), ngf=128) 34 | discriminator = MultiscaleDiscriminator(use_sigmoid=True, num_D=2) 35 | optimizer = partial(optim.Adam, betas=(0.5, 0.999)) 36 | scheduler = partial(lr_scheduler.StepLR, step_size=10, gamma=0.5) 37 | lms_model = '../../weights/hr18_wflw_landmarks.pth' 38 | rec_weight = 1.0 39 | gan_weight = 0.001 40 | 41 | if not os.path.exists(exp_dir): 42 | os.makedirs(exp_dir) 43 | 44 | main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, numpy_transforms=numpy_transforms, 45 | resolutions=resolutions, lr_gen=lr_gen, lr_dis=lr_dis, epochs=epochs, iterations=iterations, 46 | batch_size=batch_size, workers=workers, optimizer=optimizer, scheduler=scheduler, pretrained=pretrained, 47 | criterion_id=criterion_id, criterion_attr=criterion_attr, criterion_gan=criterion_gan, 48 | generator=generator, discriminator=discriminator, rec_weight=rec_weight, gan_weight=gan_weight) 49 | 50 | os.system('sudo shutdown') 51 | -------------------------------------------------------------------------------- 
/fsgan/experiments/segmentation/celeba_unet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | from fsgan.datasets.image_seg_dataset import ImageSegDataset 7 | from fsgan.datasets.img_landmarks_transforms import Crop, Resize, RandomHorizontalFlip, RandomRotation 8 | from fsgan.datasets.img_landmarks_transforms import ColorJitter, RandomGaussianBlur 9 | from fsgan.models.simple_unet_02 import UNet 10 | from fsgan.train_segmentation import main 11 | 12 | 13 | if __name__ == '__main__': 14 | exp_name = os.path.splitext(os.path.basename(__file__))[0] 15 | exp_dir = os.path.join('../results/segmentation', exp_name) 16 | train_dataset = partial(ImageSegDataset, '/data/datasets/celeba_mask_hq', 17 | 'img_list_train.txt', 'bboxes_train.npy', seg_classes=3) 18 | val_dataset = partial(ImageSegDataset, '/data/datasets/celeba_mask_hq', 19 | 'img_list_val.txt', 'bboxes_val.npy', seg_classes=3) 20 | numpy_transforms = [RandomRotation(30.0, ('cubic', 'nearest')), Crop(), Resize(256, ('cubic', 'nearest')), 21 | RandomHorizontalFlip(), ColorJitter(0.5, 0.5, 0.5, 0.5, filter=(True, False)), 22 | RandomGaussianBlur(filter=(True, False))] 23 | resolutions = [256] 24 | learning_rate = [1e-4] 25 | epochs = [60] 26 | iterations = ['40k'] 27 | batch_size = [48] 28 | workers = 12 29 | pretrained = False 30 | optimizer = partial(optim.Adam, betas=(0.5, 0.999)) 31 | scheduler = partial(lr_scheduler.StepLR, step_size=10, gamma=0.5) 32 | criterion = nn.CrossEntropyLoss() 33 | model = partial(UNet) 34 | 35 | if not os.path.exists(exp_dir): 36 | os.makedirs(exp_dir) 37 | 38 | main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, numpy_transforms=numpy_transforms, 39 | resolutions=resolutions, learning_rate=learning_rate, epochs=epochs, iterations=iterations, 40 | batch_size=batch_size, workers=workers, optimizer=optimizer, scheduler=scheduler, pretrained=pretrained, 41 | criterion=criterion, model=model) 42 | 43 | os.system('sudo shutdown') 44 | -------------------------------------------------------------------------------- /fsgan/experiments/swapping/ijbc_msrunet_blending.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | from fsgan.datasets.opencv_video_seq_dataset import VideoSeqPairDataset 7 | from fsgan.datasets.img_landmarks_transforms import RandomHorizontalFlip, Pyramids, ToTensor 8 | from fsgan.criterions.vgg_loss import VGGLoss 9 | from fsgan.criterions.gan_loss import GANLoss 10 | from fsgan.models.res_unet import MultiScaleResUNet 11 | from fsgan.models.discriminators_pix2pix import MultiscaleDiscriminator 12 | from fsgan.train_blending import main 13 | 14 | 15 | if __name__ == '__main__': 16 | exp_name = os.path.splitext(os.path.basename(__file__))[0] 17 | exp_dir = os.path.join('../results/swapping', exp_name) 18 | train_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 19 | 'train_list.txt', frame_window=1, ignore_landmarks=True, same_prob=0.0) 20 | val_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 21 | 'val_list.txt', frame_window=1, ignore_landmarks=True, same_prob=0.0) 22 | numpy_transforms = 
[RandomHorizontalFlip(), Pyramids(2)] 23 | tensor_transforms = [ToTensor()] 24 | resolutions = [128, 256] 25 | lr_gen = [1e-4, 4e-5] 26 | lr_dis = [1e-5, 4e-6] 27 | epochs = [24, 50] 28 | iterations = ['20k'] 29 | batch_size = [32, 16] 30 | workers = 32 31 | pretrained = False 32 | criterion_id = VGGLoss('../../../weights/vggface2_vgg19_256_1_2_id.pth') 33 | criterion_attr = VGGLoss('../../../weights/celeba_vgg19_256_2_0_28_attr.pth') 34 | criterion_gan = GANLoss(use_lsgan=True) 35 | generator = MultiScaleResUNet(in_nc=7, out_nc=3, flat_layers=(2, 2, 2, 2), ngf=128) 36 | discriminator = MultiscaleDiscriminator(use_sigmoid=True, num_D=2) 37 | optimizer = partial(optim.Adam, betas=(0.5, 0.999)) 38 | scheduler = partial(lr_scheduler.StepLR, step_size=10, gamma=0.5) 39 | reenactment_model = '../results/reenactment/ijbc_msrunet_reenactment_attr_no_seg/G_latest.pth' 40 | seg_model = '../../weights/lfw_figaro_unet_256_segmentation.pth' 41 | lms_model = '../../weights/hr18_wflw_landmarks.pth' 42 | rec_weight = 1.0 43 | gan_weight = 0.1 44 | background_value = -1.0 45 | 46 | if not os.path.exists(exp_dir): 47 | os.makedirs(exp_dir) 48 | 49 | main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, 50 | numpy_transforms=numpy_transforms, tensor_transforms=tensor_transforms, resolutions=resolutions, 51 | lr_gen=lr_gen, lr_dis=lr_dis, epochs=epochs, iterations=iterations, batch_size=batch_size, workers=workers, 52 | optimizer=optimizer, scheduler=scheduler, pretrained=pretrained, 53 | criterion_id=criterion_id, criterion_attr=criterion_attr, criterion_gan=criterion_gan, 54 | generator=generator, discriminator=discriminator, reenactment_model=reenactment_model, seg_model=seg_model, 55 | lms_model=lms_model, rec_weight=rec_weight, gan_weight=gan_weight, background_value=background_value) 56 | 57 | os.system('sudo shutdown') 58 | -------------------------------------------------------------------------------- /fsgan/experiments/swapping/ijbc_msrunet_inpainting.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | from fsgan.datasets.opencv_video_seq_dataset import VideoSeqPairDataset 7 | from fsgan.datasets.img_landmarks_transforms import RandomHorizontalFlip, Pyramids, ToTensor 8 | from fsgan.criterions.vgg_loss import VGGLoss 9 | from fsgan.criterions.gan_loss import GANLoss 10 | from fsgan.models.res_unet import MultiScaleResUNet 11 | from fsgan.models.discriminators_pix2pix import MultiscaleDiscriminator 12 | from fsgan.train_inpainting import main 13 | 14 | 15 | if __name__ == '__main__': 16 | exp_name = os.path.splitext(os.path.basename(__file__))[0] 17 | exp_dir = os.path.join('../results/swapping', exp_name) 18 | train_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 19 | 'train_list.txt', frame_window=1, ignore_landmarks=True, same_prob=1.0) 20 | val_dataset = partial(VideoSeqPairDataset, '/data/datasets/ijb-c/ijbc_cropped/ijbc_cropped_r256_cs1.2', 21 | 'val_list.txt', frame_window=1, ignore_landmarks=True, same_prob=1.0) 22 | numpy_transforms = [RandomHorizontalFlip(), Pyramids(2)] 23 | tensor_transforms = [ToTensor()] 24 | resolutions = [128, 256] 25 | lr_gen = [1e-4, 4e-5] 26 | lr_dis = [1e-5, 4e-6] 27 | epochs = [24, 50] 28 | iterations = ['20k'] 29 | batch_size = [32, 16] 30 | workers = 32 31 | pretrained = False 32 | criterion_id = 
VGGLoss('../../../weights/vggface2_vgg19_256_1_2_id.pth') 33 | criterion_attr = VGGLoss('../../../weights/celeba_vgg19_256_2_0_28_attr.pth') 34 | criterion_gan = GANLoss(use_lsgan=True) 35 | generator = MultiScaleResUNet(in_nc=4, out_nc=3, flat_layers=(2, 2, 2, 2), ngf=128) 36 | discriminator = MultiscaleDiscriminator(use_sigmoid=True, num_D=2) 37 | optimizer = partial(optim.Adam, betas=(0.5, 0.999)) 38 | scheduler = partial(lr_scheduler.StepLR, step_size=10, gamma=0.5) 39 | reenactment_model = '../results/reenactment/ijbc_msrunet_reenactment_attr_no_seg/G_latest.pth' 40 | seg_model = '../../weights/lfw_figaro_unet_256_segmentation.pth' 41 | lms_model = '../../weights/hr18_wflw_landmarks.pth' 42 | rec_weight = 1.0 43 | gan_weight = 0.001 44 | background_value = -1.0 45 | 46 | if not os.path.exists(exp_dir): 47 | os.makedirs(exp_dir) 48 | 49 | main(exp_dir, train_dataset=train_dataset, val_dataset=val_dataset, 50 | numpy_transforms=numpy_transforms, tensor_transforms=tensor_transforms, resolutions=resolutions, 51 | lr_gen=lr_gen, lr_dis=lr_dis, epochs=epochs, iterations=iterations, batch_size=batch_size, workers=workers, 52 | optimizer=optimizer, scheduler=scheduler, pretrained=pretrained, 53 | criterion_id=criterion_id, criterion_attr=criterion_attr, criterion_gan=criterion_gan, 54 | generator=generator, discriminator=discriminator, reenactment_model=reenactment_model, seg_model=seg_model, 55 | lms_model=lms_model, rec_weight=rec_weight, gan_weight=gan_weight, background_value=background_value) 56 | 57 | os.system('sudo shutdown') 58 | -------------------------------------------------------------------------------- /fsgan/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/fsgan/models/__init__.py -------------------------------------------------------------------------------- /fsgan/models/classifier1d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import inspect 6 | 7 | 8 | def make_linear_block(in_nc, out_nc, bias=False, norm_layer=nn.BatchNorm1d, act_layer=nn.ReLU(True), use_dropout=False): 9 | linear_block = [] 10 | linear_block.append(nn.Linear(in_nc, out_nc, bias=bias)) 11 | if norm_layer is not None: 12 | linear_block.append(norm_layer(out_nc)) 13 | if act_layer is not None: 14 | linear_block.append(act_layer) 15 | 16 | if use_dropout: 17 | linear_block += [nn.Dropout(0.5)] 18 | 19 | return linear_block 20 | 21 | 22 | class Classifier(nn.Module): 23 | def __init__(self, in_nc=2048, out_nc=2, layers=(2048,), norm_layer=nn.BatchNorm1d, act_layer=nn.ReLU(True), 24 | use_dropout=False): 25 | super(Classifier, self).__init__() 26 | self.idx_tensor = None 27 | 28 | # Add linear layers 29 | channels = [in_nc] + list(layers) + [out_nc] 30 | self.model = [] 31 | for i in range(1, len(channels) - 1): 32 | self.model += make_linear_block(channels[i - 1], channels[i], norm_layer=norm_layer, act_layer=act_layer, 33 | use_dropout=use_dropout) 34 | self.model += make_linear_block(channels[-2], channels[-1], norm_layer=None, act_layer=None, 35 | use_dropout=use_dropout) 36 | self.model = nn.Sequential(*self.model) 37 | 38 | def forward(self, x): 39 | return self.model(x) 40 | 41 | 42 | def classifier(pretrained=False, **kwargs): 43 | model = Classifier(**kwargs) 44 | 45 | if pretrained: 46 | if os.path.isfile(pretrained): 47 
| checkpoint = torch.load(pretrained) 48 | model.load_state_dict(checkpoint['state_dict']) 49 | else: 50 | raise RuntimeError('Could not find weights file: %s' % pretrained) 51 | 52 | return model 53 | 54 | 55 | def main(obj_exp): 56 | from fake_detection.utils.obj_batch import obj_factory 57 | obj = obj_factory(obj_exp) 58 | print(obj) 59 | 60 | 61 | if __name__ == "__main__": 62 | # Parse program arguments 63 | import argparse 64 | parser = argparse.ArgumentParser('classifier1d') 65 | parser.add_argument('obj_exp', help='object string') 66 | args = parser.parse_args() 67 | 68 | main(args.obj_exp) -------------------------------------------------------------------------------- /fsgan/models/discriminators_pix2pix.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from fsgan.utils.img_utils import create_pyramid 4 | 5 | 6 | # Adapted from https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py 7 | class NLayerDiscriminator(nn.Module): 8 | """ Defines the PatchGAN discriminator. 9 | `"Image-to-Image Translation with Conditional Adversarial Networks" `_ 10 | 11 | Args: 12 | input_nc (int): Input number of channels 13 | ndf (int): Number of the discriminator feature channels of the first layer 14 | n_layers (int): Number of intermediate layers 15 | norm_layer (nn.Module): Type of feature normalization 16 | use_sigmoid (bool): If True, a Sigmoid activation will be used after the final layer 17 | getIntermFeat (bool): If True, all intermediate features will be returned else only the final feature 18 | """ 19 | def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, 20 | getIntermFeat=False): 21 | super(NLayerDiscriminator, self).__init__() 22 | self.getIntermFeat = getIntermFeat 23 | self.n_layers = n_layers 24 | 25 | kw = 4 26 | padw = int(np.ceil((kw - 1.0) / 2)) 27 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 28 | 29 | nf = ndf 30 | for n in range(1, n_layers): 31 | nf_prev = nf 32 | nf = min(nf * 2, 512) 33 | sequence += [[ 34 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 35 | norm_layer(nf), nn.LeakyReLU(0.2, True) 36 | ]] 37 | 38 | nf_prev = nf 39 | nf = min(nf * 2, 512) 40 | sequence += [[ 41 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 42 | norm_layer(nf), 43 | nn.LeakyReLU(0.2, True) 44 | ]] 45 | 46 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 47 | 48 | if use_sigmoid: 49 | sequence += [[nn.Sigmoid()]] 50 | 51 | if getIntermFeat: 52 | for n in range(len(sequence)): 53 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) 54 | else: 55 | sequence_stream = [] 56 | for n in range(len(sequence)): 57 | sequence_stream += sequence[n] 58 | self.model = nn.Sequential(*sequence_stream) 59 | 60 | def forward(self, input): 61 | if self.getIntermFeat: 62 | res = [input] 63 | for n in range(self.n_layers + 2): 64 | model = getattr(self, 'model' + str(n)) 65 | res.append(model(res[-1])) 66 | return res[1:] 67 | else: 68 | return self.model(input) 69 | 70 | 71 | # Adapted from https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py 72 | class MultiscaleDiscriminator(nn.Module): 73 | """ Defines the multi-scale descriminator. 74 | `"High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs" 75 | `_. 
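To make the PatchGAN idea concrete, a short illustrative sketch of driving the NLayerDiscriminator above (the input resolution is an assumption, not something the file fixes); note that it emits a one-channel patch map, not a single scalar per image:

    import torch
    from fsgan.models.discriminators_pix2pix import NLayerDiscriminator

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_sigmoid=True)
    img = torch.rand(2, 3, 256, 256)   # assumed 256x256 RGB inputs
    patch_scores = disc(img)           # one score per receptive-field patch
    print(patch_scores.shape)          # roughly (2, 1, 35, 35) with these defaults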
76 | 77 | Args: 78 | input_nc (int): Input number of channels 79 | ndf (int): Number of the discriminator feature channels of the first layer 80 | n_layers (int): Number of intermediate layers 81 | norm_layer (nn.Module): Type of feature normalization 82 | use_sigmoid (bool): If True, a Sigmoid activation will be used after the final layer 83 | num_D (int): Number of discriminators 84 | getIntermFeat (bool): If True, all intermediate features will be returned else only the final feature 85 | """ 86 | def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 87 | use_sigmoid=False, num_D=3, getIntermFeat=False): 88 | super(MultiscaleDiscriminator, self).__init__() 89 | self.num_D = num_D 90 | self.n_layers = n_layers 91 | self.getIntermFeat = getIntermFeat 92 | 93 | for i in range(num_D): 94 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) 95 | if getIntermFeat: 96 | for j in range(n_layers + 2): 97 | setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))) 98 | else: 99 | setattr(self, 'layer' + str(i), netD.model) 100 | 101 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 102 | 103 | def singleD_forward(self, model, input): 104 | if self.getIntermFeat: 105 | result = [input] 106 | for i in range(len(model)): 107 | result.append(model[i](result[-1])) 108 | return result[1:] 109 | else: 110 | return [model(input)] 111 | 112 | def forward(self, input): 113 | input = create_pyramid(input, self.num_D) 114 | levels = len(input) 115 | result = [] 116 | for i in range(levels): 117 | curr_input = input[i] 118 | if self.getIntermFeat: 119 | model = [getattr(self, 'scale' + str(levels - 1 - i) + '_layer' + str(j)) for j in 120 | range(self.n_layers + 2)] 121 | else: 122 | model = getattr(self, 'layer' + str(levels - 1 - i)) 123 | result.append(self.singleD_forward(model, curr_input)) 124 | 125 | return result 126 | -------------------------------------------------------------------------------- /fsgan/models/hopenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | from torchvision.models.resnet import Bottleneck 6 | 7 | 8 | class Hopenet(nn.Module): 9 | """ Defines a head pose estimation network with 3 output layers: yaw, pitch and roll. 10 | `"Fine-Grained Head Pose Estimation Without Keypoints" `_. 11 | 12 | Predicts Euler angles by binning and regression. 
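The multi-scale wiring of the MultiscaleDiscriminator above can be summarized as: forward builds an image pyramid with num_D levels via create_pyramid and returns a list with one entry per scale. A brief illustrative sketch (input size assumed; the num_D=2, use_sigmoid=True setting mirrors the blending experiment configuration earlier in this dump):

    import torch
    from fsgan.models.discriminators_pix2pix import MultiscaleDiscriminator

    disc = MultiscaleDiscriminator(input_nc=3, num_D=2, use_sigmoid=True)
    img = torch.rand(1, 3, 256, 256)   # assumed input resolution
    outputs = disc(img)                # list of length num_D, one entry per pyramid level
    for scale, out in enumerate(outputs):
        # with getIntermFeat=False each entry is a single-element list holding the patch map
        print(scale, out[0].shape)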
13 | 14 | Args: 15 | block (nn.Module): Main convolution block 16 | layers (list of ints): Number of blocks per intermediate layer 17 | num_bins (int): Number of regression bins 18 | """ 19 | def __init__(self, block=Bottleneck, layers=(3, 4, 6, 3), num_bins=66): 20 | self.inplanes = 64 21 | super(Hopenet, self).__init__() 22 | self.idx_tensor = None 23 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 24 | bias=False) 25 | self.bn1 = nn.BatchNorm2d(64) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | self.layer1 = self._make_layer(block, 64, layers[0]) 29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 30 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 31 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 32 | self.avgpool = nn.AvgPool2d(7) 33 | self.fc_yaw = nn.Linear(512 * block.expansion, num_bins) 34 | self.fc_pitch = nn.Linear(512 * block.expansion, num_bins) 35 | self.fc_roll = nn.Linear(512 * block.expansion, num_bins) 36 | 37 | # Vestigial layer from previous experiments 38 | self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3) 39 | 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 43 | m.weight.data.normal_(0, math.sqrt(2. / n)) 44 | elif isinstance(m, nn.BatchNorm2d): 45 | m.weight.data.fill_(1) 46 | m.bias.data.zero_() 47 | 48 | def _make_layer(self, block, planes, blocks, stride=1): 49 | downsample = None 50 | if stride != 1 or self.inplanes != planes * block.expansion: 51 | downsample = nn.Sequential( 52 | nn.Conv2d(self.inplanes, planes * block.expansion, 53 | kernel_size=1, stride=stride, bias=False), 54 | nn.BatchNorm2d(planes * block.expansion), 55 | ) 56 | 57 | layers = [] 58 | layers.append(block(self.inplanes, planes, stride, downsample)) 59 | self.inplanes = planes * block.expansion 60 | for i in range(1, blocks): 61 | layers.append(block(self.inplanes, planes)) 62 | 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.bn1(x) 68 | x = self.relu(x) 69 | x = self.maxpool(x) 70 | 71 | x = self.layer1(x) 72 | x = self.layer2(x) 73 | x = self.layer3(x) 74 | x = self.layer4(x) 75 | 76 | x = self.avgpool(x) 77 | x = x.view(x.size(0), -1) 78 | pred_yaw = self.fc_yaw(x) 79 | pred_pitch = self.fc_pitch(x) 80 | pred_roll = self.fc_roll(x) 81 | 82 | yaw_predicted = F.softmax(pred_yaw, dim=1) 83 | pitch_predicted = F.softmax(pred_pitch, dim=1) 84 | roll_predicted = F.softmax(pred_roll, dim=1) 85 | 86 | if self.idx_tensor is None: 87 | self.idx_tensor = torch.arange(0, 66, out=torch.FloatTensor()).to(x.device) 88 | 89 | # Get continuous predictions in degrees. 90 | yaw_predicted = torch.sum(yaw_predicted * self.idx_tensor, axis=1).unsqueeze(1) * 3 - 99 91 | pitch_predicted = torch.sum(pitch_predicted * self.idx_tensor, axis=1).unsqueeze(1) * 3 - 99 92 | roll_predicted = torch.sum(roll_predicted * self.idx_tensor, axis=1).unsqueeze(1) * 3 - 99 93 | 94 | return torch.cat((yaw_predicted, pitch_predicted, roll_predicted), axis=1) 95 | -------------------------------------------------------------------------------- /fsgan/models/msba.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multi-Scale Binned Activation. 
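The "binning and regression" step in the Hopenet forward pass above reduces each 66-way bin distribution to a continuous angle by taking an expectation over bin indices and mapping it to degrees (3-degree bins spanning roughly -99 to +96 degrees). A small self-contained sketch of that computation, with made-up logits:

    import torch
    import torch.nn.functional as F

    num_bins = 66
    logits = torch.randn(1, num_bins)          # e.g. the yaw head output for one image
    probs = F.softmax(logits, dim=1)           # bin probabilities
    idx = torch.arange(num_bins, dtype=torch.float32)
    expected_bin = (probs * idx).sum(dim=1)    # soft-argmax over bins
    yaw_deg = expected_bin * 3 - 99            # same mapping as Hopenet.forward
    print(yaw_deg)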
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class MSBA(nn.Module): 10 | def __init__(self, out_nc=3, bins=64): 11 | super(MSBA, self).__init__() 12 | self.in_nc = out_nc * bins 13 | self.out_nc = out_nc 14 | self.bins = bins 15 | self.norm_factor = 2. / (self.bins - 1) 16 | 17 | # Initialize scales tensor 18 | self.register_buffer('scales', torch.arange(0., bins).view(1, -1)) 19 | 20 | def forward(self, x): 21 | assert x.shape[1] == self.in_nc 22 | scales = self.scales.view(self.scales.shape + (1,) * (x.ndim - self.scales.ndim)) 23 | 24 | out = [] 25 | for i in range(self.out_nc): 26 | xc = F.softmax(x[:, i * self.bins:(i + 1) * self.bins], dim=1) 27 | xc = torch.sum(xc * scales, dim=1, keepdim=True).mul_(self.norm_factor).sub_(1.) 28 | out.append(xc) 29 | out = torch.cat(out, dim=1) 30 | 31 | return out 32 | 33 | 34 | def main(): 35 | msba = MSBA() 36 | img = torch.rand(2, msba.in_nc, 64, 64) 37 | out = msba(img) 38 | print(out.shape) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /fsgan/models/simple_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | pretrained_models = {'pascal': 'path/to/pretrained_model.pth'} 7 | 8 | 9 | # Adapted from: https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/unet.py 10 | class UNet(nn.Module): 11 | """ Defines a variant of the UNet architecture described in the paper: 12 | `"U-Net: Convolutional Networks for Biomedical Image Segmentation `_. 13 | 14 | Args: 15 | feature_scale (int): Divides the intermediate feature map number of channels 16 | n_classes (int): Output number of channels 17 | is_deconv (bool): If True, transposed convolution will be used for the upsampling operation instead of 18 | bilinear interpolation 19 | in_channels (int): Input number of channels 20 | is_batchnorm (bool): If True, enables the use of batch normalization 21 | """ 22 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=False, in_channels=3, is_batchnorm=True): 23 | super(UNet, self).__init__() 24 | self.n_classes = n_classes 25 | self.is_deconv = is_deconv 26 | self.in_channels = in_channels 27 | self.is_batchnorm = is_batchnorm 28 | self.feature_scale = feature_scale 29 | 30 | filters = [64, 128, 256, 512, 1024] 31 | filters = [int(x / self.feature_scale) for x in filters] 32 | 33 | # downsampling 34 | self.conv1 = UnetConv2(self.in_channels, filters[0], self.is_batchnorm) 35 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 36 | 37 | self.conv2 = UnetConv2(filters[0], filters[1], self.is_batchnorm) 38 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 39 | 40 | self.conv3 = UnetConv2(filters[1], filters[2], self.is_batchnorm) 41 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 42 | 43 | self.conv4 = UnetConv2(filters[2], filters[3], self.is_batchnorm) 44 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 45 | 46 | self.center = UnetConv2(filters[3], filters[4], self.is_batchnorm) 47 | 48 | # upsampling 49 | self.up_concat4 = UnetUp(filters[4], filters[3], self.is_deconv) 50 | self.up_concat3 = UnetUp(filters[3], filters[2], self.is_deconv) 51 | self.up_concat2 = UnetUp(filters[2], filters[1], self.is_deconv) 52 | self.up_concat1 = UnetUp(filters[1], filters[0], self.is_deconv) 53 | 54 | # final conv (without any concat) 55 | self.final = nn.Conv2d(filters[0], n_classes, 1) 56 | 57 | 
def forward(self, inputs): 58 | conv1 = self.conv1(inputs) 59 | maxpool1 = self.maxpool1(conv1) 60 | 61 | conv2 = self.conv2(maxpool1) 62 | maxpool2 = self.maxpool2(conv2) 63 | 64 | conv3 = self.conv3(maxpool2) 65 | maxpool3 = self.maxpool3(conv3) 66 | 67 | conv4 = self.conv4(maxpool3) 68 | maxpool4 = self.maxpool4(conv4) 69 | 70 | center = self.center(maxpool4) 71 | up4 = self.up_concat4(conv4, center) 72 | up3 = self.up_concat3(conv3, up4) 73 | up2 = self.up_concat2(conv2, up3) 74 | up1 = self.up_concat1(conv1, up2) 75 | 76 | final = self.final(up1) 77 | 78 | return final 79 | 80 | 81 | class UnetConv2(nn.Module): 82 | """ Defines the UNet's convolution block. 83 | 84 | Args: 85 | in_size (int): Input number of channels 86 | out_size (int): Output number of channels 87 | is_batchnorm (bool): If True, enables the use of batch normalization 88 | """ 89 | def __init__(self, in_size, out_size, is_batchnorm): 90 | super(UnetConv2, self).__init__() 91 | 92 | if is_batchnorm: 93 | self.conv1 = nn.Sequential( 94 | nn.Conv2d(in_size, out_size, 3, 1, 1), 95 | nn.BatchNorm2d(out_size), 96 | nn.ReLU(), 97 | ) 98 | self.conv2 = nn.Sequential( 99 | nn.Conv2d(out_size, out_size, 3, 1, 1), 100 | nn.BatchNorm2d(out_size), 101 | nn.ReLU(), 102 | ) 103 | else: 104 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1), nn.ReLU()) 105 | self.conv2 = nn.Sequential( 106 | nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU() 107 | ) 108 | 109 | def forward(self, inputs): 110 | outputs = self.conv1(inputs) 111 | outputs = self.conv2(outputs) 112 | return outputs 113 | 114 | 115 | class UnetUp(nn.Module): 116 | """ Defines the UNet's upsampling block. 117 | 118 | Args: 119 | in_size (int): Input number of channels 120 | out_size (int): Output number of channels 121 | is_deconv (bool): If True, transposed convolution will be used for the upsampling operation instead of 122 | bilinear interpolation 123 | """ 124 | def __init__(self, in_size, out_size, is_deconv): 125 | super(UnetUp, self).__init__() 126 | self.conv = UnetConv2(in_size, out_size, False) 127 | if is_deconv: 128 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) 129 | else: 130 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 131 | self.conv1d = nn.Conv1d(in_size, out_size, kernel_size=(1,1)) 132 | 133 | def forward(self, inputs1, inputs2): 134 | outputs2 = self.up(inputs2) 135 | outputs2 = self.conv1d(outputs2,) 136 | offset = outputs2.size()[2] - inputs1.size()[2] 137 | padding = 2 * [offset // 2, offset // 2] 138 | outputs1 = F.pad(inputs1, padding) 139 | return self.conv(torch.cat([outputs1, outputs2], 1)) 140 | 141 | 142 | def unet(num_classes=21, is_deconv=False, feature_scale=1, is_batchnorm=True, pretrained=False): 143 | """ Creates a UNet model with pretrained optiopn. 
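One detail worth noting in the UnetUp block above: the skip-path projection self.conv1d is declared as nn.Conv1d with a (1, 1) kernel yet is applied to 4-D feature maps, whereas the simple_unet_02.py variant further down swaps it for an nn.Conv2d with the same 1x1 kernel, the usual way to project channels on (N, C, H, W) tensors. A generic illustration of that 1x1-convolution channel projection (not the repository's code; channel counts are arbitrary):

    import torch
    import torch.nn as nn

    # Project 1024 channels down to 512 without touching the spatial dimensions,
    # which is what the upsampling skip path needs before concatenation.
    project = nn.Conv2d(1024, 512, kernel_size=1)
    feat = torch.rand(1, 1024, 16, 16)
    print(project(feat).shape)   # torch.Size([1, 512, 16, 16])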
144 | 145 | Args: 146 | num_classes (int): Output number of channels 147 | is_deconv (bool): If True, transposed convolution will be used for the upsampling operation instead of 148 | bilinear interpolation 149 | feature_scale (int): Divides the intermediate feature map number of channels 150 | is_batchnorm (bool): If True, enables the use of batch normalization 151 | pretrained (bool): If True, return a pretrained model on Pascal dataset 152 | 153 | Returns: 154 | UNet model 155 | """ 156 | if pretrained: 157 | model_path = pretrained_models['pascal'] 158 | model = UNet(n_classes=num_classes, feature_scale=feature_scale, is_batchnorm=is_batchnorm, is_deconv=is_deconv) 159 | checkpoint = torch.load(model_path) 160 | weights = checkpoint['state_dict'] 161 | weights['notinuse'] = weights.pop('final.weight') 162 | weights['notinuse2'] = weights.pop('final.bias') 163 | model.load_state_dict(weights, strict=False) 164 | else: 165 | model = UNet(n_classes=num_classes, feature_scale=feature_scale, is_batchnorm=is_batchnorm, is_deconv=is_deconv) 166 | 167 | return model 168 | -------------------------------------------------------------------------------- /fsgan/models/simple_unet_02.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Code is borrowed from 7 | # https://github.com/meetshah1995/pytorch-semseg/commits/master/ptsemseg/models/unet.py 8 | 9 | 10 | class UNet(nn.Module): 11 | """ Defines a variant of the UNet architecture described in the paper: 12 | `"U-Net: Convolutional Networks for Biomedical Image Segmentation `_. 13 | 14 | Args: 15 | feature_scale (int): Divides the intermediate feature map number of channels 16 | num_classes (int): Output number of channels 17 | is_deconv (bool): If True, transposed convolution will be used for the upsampling operation instead of 18 | bilinear interpolation 19 | in_channels (int): Input number of channels 20 | is_batchnorm (bool): If True, enables the use of batch normalization 21 | """ 22 | def __init__(self, feature_scale=1, num_classes=21, is_deconv=False, in_channels=3, is_batchnorm=True): 23 | super(UNet, self).__init__() 24 | self.num_classes = num_classes 25 | self.is_deconv = is_deconv 26 | self.in_channels = in_channels 27 | self.is_batchnorm = is_batchnorm 28 | self.feature_scale = feature_scale 29 | 30 | filters = [64, 128, 256, 512, 1024] 31 | filters = [int(x / self.feature_scale) for x in filters] 32 | 33 | # downsampling 34 | self.conv1 = UnetConv2(self.in_channels, filters[0], self.is_batchnorm) 35 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 36 | 37 | self.conv2 = UnetConv2(filters[0], filters[1], self.is_batchnorm) 38 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 39 | 40 | self.conv3 = UnetConv2(filters[1], filters[2], self.is_batchnorm) 41 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 42 | 43 | self.conv4 = UnetConv2(filters[2], filters[3], self.is_batchnorm) 44 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 45 | 46 | self.center = UnetConv2(filters[3], filters[4], self.is_batchnorm) 47 | 48 | # upsampling 49 | self.up_concat4 = UnetUp(filters[4], filters[3], self.is_deconv) 50 | self.up_concat3 = UnetUp(filters[3], filters[2], self.is_deconv) 51 | self.up_concat2 = UnetUp(filters[2], filters[1], self.is_deconv) 52 | self.up_concat1 = UnetUp(filters[1], filters[0], self.is_deconv) 53 | 54 | # final conv (without any concat) 55 | self.final = nn.Conv2d(filters[0], num_classes, 1) 56 | 57 | def 
forward(self, inputs): 58 | conv1 = self.conv1(inputs) 59 | maxpool1 = self.maxpool1(conv1) 60 | 61 | conv2 = self.conv2(maxpool1) 62 | maxpool2 = self.maxpool2(conv2) 63 | 64 | conv3 = self.conv3(maxpool2) 65 | maxpool3 = self.maxpool3(conv3) 66 | 67 | conv4 = self.conv4(maxpool3) 68 | maxpool4 = self.maxpool4(conv4) 69 | 70 | center = self.center(maxpool4) 71 | up4 = self.up_concat4(conv4, center) 72 | up3 = self.up_concat3(conv3, up4) 73 | up2 = self.up_concat2(conv2, up3) 74 | up1 = self.up_concat1(conv1, up2) 75 | 76 | final = self.final(up1) 77 | 78 | return final 79 | 80 | 81 | class UnetConv2(nn.Module): 82 | """ Defines the UNet's convolution block. 83 | 84 | Args: 85 | in_size (int): Input number of channels 86 | out_size (int): Output number of channels 87 | is_batchnorm (bool): If True, enables the use of batch normalization 88 | """ 89 | def __init__(self, in_size, out_size, is_batchnorm): 90 | super(UnetConv2, self).__init__() 91 | 92 | if is_batchnorm: 93 | self.conv1 = nn.Sequential( 94 | nn.Conv2d(in_size, out_size, 3, 1, 1), 95 | nn.BatchNorm2d(out_size), 96 | nn.ReLU(), 97 | ) 98 | self.conv2 = nn.Sequential( 99 | nn.Conv2d(out_size, out_size, 3, 1, 1), 100 | nn.BatchNorm2d(out_size), 101 | nn.ReLU(), 102 | ) 103 | else: 104 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1), nn.ReLU()) 105 | self.conv2 = nn.Sequential( 106 | nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU() 107 | ) 108 | 109 | def forward(self, inputs): 110 | outputs = self.conv1(inputs) 111 | outputs = self.conv2(outputs) 112 | return outputs 113 | 114 | 115 | class UnetUp(nn.Module): 116 | """ Defines the UNet's upsampling block. 117 | 118 | Args: 119 | in_size (int): Input number of channels 120 | out_size (int): Output number of channels 121 | is_deconv (bool): If True, transposed convolution will be used for the upsampling operation instead of 122 | bilinear interpolation 123 | """ 124 | def __init__(self, in_size, out_size, is_deconv): 125 | super(UnetUp, self).__init__() 126 | self.conv = UnetConv2(in_size, out_size, False) 127 | if is_deconv: 128 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) 129 | else: 130 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 131 | # self.conv1d = nn.Conv1d(in_size, out_size, kernel_size=(1,1)) 132 | self.conv1d = nn.Conv2d(in_size, out_size, kernel_size=(1, 1)) 133 | 134 | def forward(self, inputs1, inputs2): 135 | outputs2 = self.up(inputs2) 136 | outputs2 = self.conv1d(outputs2,) 137 | offset = outputs2.size()[2] - inputs1.size()[2] 138 | padding = 2 * [offset // 2, offset // 2] 139 | outputs1 = F.pad(inputs1, padding) 140 | return self.conv(torch.cat([outputs1, outputs2], 1)) 141 | -------------------------------------------------------------------------------- /fsgan/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | __all__ = [ 5 | 'VGG', 6 | 'vgg19', 7 | 'vgg_fcn' 8 | ] 9 | 10 | model_urls = { 11 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 12 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 13 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 14 | } 15 | 16 | cfg = { 17 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 18 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 19 | } 20 | 21 | 22 | # Batchnorm removed 
from the code, as Johnson didn't use it for his transfer learning work 23 | # First nn.Linear shape is changed from 512 * 7 * 7 to 512 * 8 * 8 to meet our 256x256 input requirements 24 | # Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 25 | class VGG(nn.Module): 26 | 27 | def __init__(self, features, num_classes=1000, init_weights=True, verification=False): 28 | super(VGG, self).__init__() 29 | self.features = features 30 | self.classifier = nn.Sequential( 31 | nn.Linear(512 * 8 * 8, 4096), # for input 256, 8x8 instead of 7x7 32 | nn.ReLU(True), 33 | nn.Dropout(), 34 | nn.Linear(4096, 4096), 35 | nn.ReLU(True), 36 | nn.Dropout(), 37 | nn.Linear(4096, num_classes), 38 | ) 39 | if verification: 40 | self.classifier = nn.Sequential(*list(self.classifier.children())[:-1]) 41 | if init_weights: 42 | self._initialize_weights() 43 | 44 | def forward(self, x): 45 | x = self.features(x) 46 | x = x.view(x.size(0), -1) 47 | x = self.classifier(x) 48 | return x 49 | 50 | def _initialize_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 54 | if m.bias is not None: 55 | nn.init.constant_(m.bias, 0) 56 | elif isinstance(m, nn.BatchNorm2d): 57 | nn.init.constant_(m.weight, 1) 58 | nn.init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.Linear): 60 | nn.init.normal_(m.weight, 0, 0.01) 61 | nn.init.constant_(m.bias, 0) 62 | 63 | 64 | def make_layers(cfg, batch_norm=False): 65 | layers = [] 66 | in_channels = 3 67 | for v in cfg: 68 | if v == 'M': 69 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 70 | else: 71 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 72 | if batch_norm: 73 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 74 | else: 75 | layers += [conv2d, nn.ReLU(inplace=True)] 76 | in_channels = v 77 | return nn.Sequential(*layers) 78 | 79 | 80 | def vgg19(num_classes=1000, pretrained=False, batch_norm=True, verifcation=False, **kwargs): 81 | """VGG 19-layer model (configuration "E") 82 | 83 | Args: 84 | num_classes(int): the number of classes at dataset 85 | pretrained (bool): If True, returns a model pre-trained on ImageNet 86 | with a new FC layer 512x8x8 instead of 512x7x7 87 | batch_norm: if you want to introduce batch normalization 88 | verifcation (bool): Toggle verification mode (removes last fc from classifier) 89 | """ 90 | if pretrained: 91 | kwargs['init_weights'] = True 92 | model = VGG(make_layers(cfg['E'], batch_norm=batch_norm), num_classes, **kwargs) 93 | 94 | # if verifcation: 95 | # verifier = nn.Sequential() 96 | # for x in range(2): 97 | # verifier.add_module(str(x), model.classifier[x]) 98 | # for x in range(3, 5): 99 | # verifier.add_module(str(x), model.classifier[x]) 100 | # model.classifier = verifier 101 | 102 | if pretrained: 103 | # loading weights 104 | if batch_norm: 105 | pretrained_weights = model_zoo.load_url(model_urls['vgg19_bn']) 106 | else: 107 | pretrained_weights = model_zoo.load_url(model_urls['vgg19']) 108 | # loading only CONV layers weights 109 | for i in [0, 3, 6]: 110 | w = 'classifier.{}.weight'.format(str(i)) 111 | new_w = 'not_used_{}'.format(str(i)) 112 | b = 'classifier.{}.bias'.format(str(i)) 113 | new_b ='not_used_{}'.format(str(i*10)) 114 | pretrained_weights[new_w] = pretrained_weights.pop(w) 115 | pretrained_weights[new_b] = pretrained_weights.pop(b) 116 | 117 | model.load_state_dict(pretrained_weights, strict=False) 118 | 119 | return model 120 | 121 | 122 | def 
vgg_fcn(num_classes=1000, pretrained=False, batch_norm=False, **kwargs): 123 | """VGG 16-layer model (configuration "D") 124 | 125 | Args: 126 | num_classes(int): the number of classes at dataset 127 | pretrained (bool): If True, returns a model pre-trained on ImageNet 128 | batch_norm: if you want to introduce batch normalization 129 | """ 130 | if pretrained: 131 | kwargs['init_weights'] = True 132 | model = VGG(make_layers(cfg['D'], batch_norm=batch_norm), num_classes, **kwargs) 133 | 134 | if pretrained: 135 | # loading weights 136 | if batch_norm: 137 | pretrained_weights = model_zoo.load_url(model_urls['vgg19_bn']) 138 | else: 139 | pretrained_weights = model_zoo.load_url(model_urls['vgg19']) 140 | model.load_state_dict(pretrained_weights, strict=False) 141 | 142 | return model 143 | -------------------------------------------------------------------------------- /fsgan/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/fsgan/preprocess/__init__.py -------------------------------------------------------------------------------- /fsgan/preprocess/clear_cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from glob import glob 4 | 5 | 6 | parser = argparse.ArgumentParser(os.path.splitext(os.path.basename(__file__))[0]) 7 | parser.add_argument('input', metavar='VIDEO', 8 | help='path to input sequence video') 9 | parser.add_argument('-o', '--output', metavar='PATH', 10 | help='output video path') 11 | parser.add_argument('-ep', '--except_postfix', default=('_dsfd.pkl',), nargs='+', metavar='POSTFIX', 12 | help='cache postfixes not to delete') 13 | default = parser.get_default 14 | 15 | 16 | def main(input, output=default('output'), except_postfix=default('except_postfix')): 17 | except_postfix = tuple(except_postfix) 18 | 19 | # Validation 20 | assert os.path.isfile(input), f'Input path "{input}" does not exist' 21 | 22 | # Parse cache files 23 | cache_dir = os.path.splitext(input)[0] 24 | cache_files = glob(os.path.join(cache_dir, '*')) 25 | 26 | # Warning and exit 27 | if not os.path.isdir(cache_dir): 28 | print(f'Warning: cache dir "{cache_dir}" does not exist') 29 | return 30 | if any([os.path.isdir(f) for f in cache_files]): 31 | print(f'Warning: "{cache_dir}" is not a cache directory') 32 | return 33 | 34 | # For each cache file 35 | delete_cache_dir = True 36 | for cache_file in cache_files: 37 | if cache_file.endswith(except_postfix): 38 | delete_cache_dir = False 39 | continue 40 | os.remove(cache_file) 41 | 42 | # Delete cache directory if it is empty 43 | if delete_cache_dir: 44 | os.rmdir(cache_dir) 45 | 46 | 47 | if __name__ == "__main__": 48 | main(**vars(parser.parse_args())) 49 | -------------------------------------------------------------------------------- /fsgan/preprocess/crop_image_sequences.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import cv2 5 | from fsgan.utils.bbox_utils import scale_bbox, crop_img 6 | from fsgan.utils.video_utils import Sequence 7 | 8 | 9 | def main(input_path, output_dir=None, cache_path=None, seq_postfix='_dsfd_seq.pkl', out_postfix='.jpg', resolution=256, 10 | crop_scale=1.2): 11 | cache_path = os.path.splitext(input_path)[0] + seq_postfix if cache_path is None else cache_path 12 | if output_dir is None: 13 | output_dir = 
os.path.splitext(input_path)[0] 14 | if not os.path.isdir(output_dir): 15 | os.mkdir(output_dir) 16 | 17 | # Verification 18 | if not os.path.isfile(input_path): 19 | raise RuntimeError('Input video does not exist: ' + input_path) 20 | if not os.path.isfile(cache_path): 21 | raise RuntimeError('Cache file does not exist: ' + cache_path) 22 | if not os.path.isdir(output_dir): 23 | raise RuntimeError('Output directory does not exist: ' + output_dir) 24 | 25 | print('=> Cropping image sequences from image: "%s"...' % os.path.basename(input_path)) 26 | 27 | # Load sequences from file 28 | with open(cache_path, "rb") as fp: # Unpickling 29 | seq_list = pickle.load(fp) 30 | 31 | # Read image from file 32 | img = cv2.imread(input_path) 33 | if img is None: 34 | raise RuntimeError('Failed to read image: ' + input_path) 35 | 36 | # For each sequence 37 | for s, seq in enumerate(seq_list): 38 | det = seq[0] 39 | 40 | # Crop image 41 | bbox = np.concatenate((det[:2], det[2:] - det[:2])) 42 | bbox = scale_bbox(bbox, crop_scale) 43 | img_cropped = crop_img(img, bbox) 44 | img_cropped = cv2.resize(img_cropped, (resolution, resolution), interpolation=cv2.INTER_CUBIC) 45 | 46 | # Write cropped image to file 47 | out_img_name = os.path.splitext(os.path.basename(input_path))[0] + '_seq%02d%s' % (seq.id, out_postfix) 48 | out_img_path = os.path.join(output_dir, out_img_name) 49 | cv2.imwrite(out_img_path, img_cropped) 50 | 51 | 52 | if __name__ == "__main__": 53 | # Parse program arguments 54 | import argparse 55 | parser = argparse.ArgumentParser('crop_image_sequences') 56 | parser.add_argument('input', metavar='VIDEO', 57 | help='path to input video') 58 | parser.add_argument('-o', '--output', metavar='DIR', 59 | help='output directory') 60 | parser.add_argument('-c', '--cache', metavar='PATH', 61 | help='path to sequence cache file') 62 | parser.add_argument('-sp', '--seq_postfix', default='_dsfd_seq.pkl', metavar='POSTFIX', 63 | help='input sequence file postfix') 64 | parser.add_argument('-op', '--out_postfix', default='.jpg', metavar='POSTFIX', 65 | help='input sequence file postfix') 66 | parser.add_argument('-r', '--resolution', default=256, type=int, metavar='N', 67 | help='output video resolution (default: 256)') 68 | parser.add_argument('-cs', '--crop_scale', default=1.2, type=float, metavar='F', 69 | help='crop scale relative to bounding box (default: 1.2)') 70 | args = parser.parse_args() 71 | main(args.input, args.output, args.cache, args.seq_postfix, args.out_postfix, args.resolution, args.crop_scale) 72 | -------------------------------------------------------------------------------- /fsgan/preprocess/crop_video_sequences.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | from tqdm import tqdm 5 | import numpy as np 6 | import cv2 7 | from fsgan.utils.bbox_utils import scale_bbox, crop_img 8 | from fsgan.utils.video_utils import Sequence 9 | 10 | 11 | def main(input_path, output_dir=None, cache_path=None, seq_postfix='_dsfd_seq.pkl', resolution=256, crop_scale=2.0, 12 | select='all', disable_tqdm=False, encoder_codec='mp4v'): 13 | cache_path = os.path.splitext(input_path)[0] + seq_postfix if cache_path is None else cache_path 14 | if output_dir is None: 15 | output_dir = os.path.splitext(input_path)[0] 16 | if not os.path.isdir(output_dir): 17 | os.mkdir(output_dir) 18 | 19 | # Verification 20 | if not os.path.isfile(input_path): 21 | raise RuntimeError('Input video does not exist: ' + input_path) 22 | if 
not os.path.isfile(cache_path): 23 | raise RuntimeError('Cache file does not exist: ' + cache_path) 24 | if not os.path.isdir(output_dir): 25 | raise RuntimeError('Output directory does not exist: ' + output_dir) 26 | 27 | print('=> Cropping video sequences from video: "%s"...' % os.path.basename(input_path)) 28 | 29 | # Load sequences from file 30 | with open(cache_path, "rb") as fp: # Unpickling 31 | seq_list = pickle.load(fp) 32 | 33 | # Select sequences 34 | if select == 'longest': 35 | selected_seq_index = np.argmax([len(s) for s in seq_list]) 36 | seq = seq_list[selected_seq_index] 37 | seq.id = 0 38 | seq_list = [seq] 39 | 40 | # Open input video file 41 | cap = cv2.VideoCapture(input_path) 42 | if not cap.isOpened(): 43 | raise RuntimeError('Failed to read video: ' + input_path) 44 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 45 | fps = cap.get(cv2.CAP_PROP_FPS) 46 | input_vid_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 47 | input_vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 48 | 49 | # For each sequence initialize output video file 50 | out_vids = [] 51 | fourcc = cv2.VideoWriter_fourcc(*encoder_codec) 52 | for seq in seq_list: 53 | curr_vid_name = os.path.splitext(os.path.basename(input_path))[0] + '_seq%02d.mp4' % seq.id 54 | curr_vid_path = os.path.join(output_dir, curr_vid_name) 55 | out_vids.append(cv2.VideoWriter(curr_vid_path, fourcc, fps, (resolution, resolution))) 56 | 57 | # For each frame in the target video 58 | cropped_detections = [[] for seq in seq_list] 59 | cropped_landmarks = [[] for seq in seq_list] 60 | pbar = range(total_frames) if disable_tqdm else tqdm(range(total_frames), file=sys.stdout) 61 | for i in pbar: 62 | ret, frame = cap.read() 63 | if frame is None: 64 | continue 65 | 66 | # For each sequence 67 | for s, seq in enumerate(seq_list): 68 | if i < seq.start_index or (seq.start_index + len(seq) - 1) < i: 69 | continue 70 | det = seq[i - seq.start_index] 71 | 72 | # Crop frame 73 | bbox = np.concatenate((det[:2], det[2:] - det[:2])) 74 | bbox = scale_bbox(bbox, crop_scale) 75 | frame_cropped = crop_img(frame, bbox) 76 | frame_cropped = cv2.resize(frame_cropped, (resolution, resolution), interpolation=cv2.INTER_CUBIC) 77 | 78 | # Write cropped frame to output video 79 | out_vids[s].write(frame_cropped) 80 | 81 | # Add cropped detection to list 82 | orig_size = bbox[2:] 83 | axes_scale = np.array([resolution, resolution]) / orig_size 84 | det[:2] -= bbox[:2] 85 | det[2:] -= bbox[:2] 86 | det[:2] *= axes_scale 87 | det[2:] *= axes_scale 88 | cropped_detections[s].append(det) 89 | 90 | # Add cropped landmarks to list 91 | if hasattr(seq, 'landmarks'): 92 | curr_landmarks = seq.landmarks[i - seq.start_index] 93 | curr_landmarks[:, :2] -= bbox[:2] 94 | 95 | # 3D landmarks case 96 | if curr_landmarks.shape[1] == 3: 97 | axes_scale = np.append(axes_scale, axes_scale.mean()) 98 | 99 | curr_landmarks *= axes_scale 100 | cropped_landmarks[s].append(curr_landmarks) 101 | 102 | # For each sequence write cropped sequence to file 103 | for s, seq in enumerate(seq_list): 104 | # seq.detections = np.array(cropped_detections[s]) 105 | # if hasattr(seq, 'landmarks'): 106 | # seq.landmarks = np.array(cropped_landmarks[s]) 107 | # seq.start_index = 0 108 | 109 | # TODO: this is a hack to change class type (remove this later) 110 | out_seq = Sequence(0) 111 | out_seq.detections = np.array(cropped_detections[s]) 112 | if hasattr(seq, 'landmarks'): 113 | out_seq.landmarks = np.array(cropped_landmarks[s]) 114 | out_seq.id, out_seq.obj_id, out_seq.size_avg 
= seq.id, seq.obj_id, seq.size_avg 115 | 116 | # Write to file 117 | curr_out_name = os.path.splitext(os.path.basename(input_path))[0] + '_seq%02d%s' % (out_seq.id, seq_postfix) 118 | curr_out_path = os.path.join(output_dir, curr_out_name) 119 | with open(curr_out_path, "wb") as fp: # Pickling 120 | pickle.dump([out_seq], fp) 121 | 122 | 123 | if __name__ == "__main__": 124 | # Parse program arguments 125 | import argparse 126 | parser = argparse.ArgumentParser('crop_video_sequences') 127 | parser.add_argument('input', metavar='VIDEO', 128 | help='path to input video') 129 | parser.add_argument('-o', '--output', metavar='DIR', 130 | help='output directory') 131 | parser.add_argument('-c', '--cache', metavar='PATH', 132 | help='path to sequence cache file') 133 | parser.add_argument('-sp', '--seq_postfix', default='_dsfd_seq.pkl', metavar='POSTFIX', 134 | help='input sequence file postfix') 135 | parser.add_argument('-r', '--resolution', default=256, type=int, metavar='N', 136 | help='output video resolution (default: 256)') 137 | parser.add_argument('-cs', '--crop_scale', default=2.0, type=float, metavar='F', 138 | help='crop scale relative to bounding box (default: 2.0)') 139 | parser.add_argument('-s', '--select', default='all', metavar='STR', 140 | help='selection method [all|longest]') 141 | parser.add_argument('-dt', '--disable_tqdm', dest='disable_tqdm', action='store_true', 142 | help='if specified disables tqdm progress bar') 143 | parser.add_argument('-ec', '--encoder_codec', default='mp4v', metavar='STR', 144 | help='encoder codec code') 145 | args = parser.parse_args() 146 | main(args.input, args.output, args.cache, args.seq_postfix, args.resolution, args.crop_scale, args.select, 147 | args.disable_tqdm, args.encoder_codec) 148 | -------------------------------------------------------------------------------- /fsgan/preprocess/crop_video_sequences_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from multiprocessing import Pool 4 | from functools import partial 5 | from itertools import groupby 6 | import numpy as np 7 | from fsgan.preprocess.crop_video_sequences import main as crop_video_sequences 8 | 9 | 10 | def parse_videos(root): 11 | vid_rel_paths = [] 12 | for r, d, f in os.walk(root): 13 | for file in f: 14 | if file.endswith('.mp4'): 15 | vid_rel_paths.append(os.path.join(os.path.relpath(r, root), file).replace('\\', '/')) 16 | 17 | return vid_rel_paths 18 | 19 | 20 | def process_video(input, cache_postfix='_dsfd_seq.pkl', resolution=256, crop_scale=1.2, select='all'): 21 | file_path, out_dir = input[0], input[1] 22 | filename = os.path.basename(file_path) 23 | curr_out_cache_path = os.path.join(out_dir, os.path.splitext(filename)[0] + '_seq00' + cache_postfix) 24 | if os.path.exists(curr_out_cache_path): 25 | return True 26 | 27 | # Process video 28 | crop_video_sequences(file_path, out_dir, None, cache_postfix, resolution, crop_scale, select, disable_tqdm=True) 29 | return True 30 | 31 | 32 | def main(root, output_dir, file_lists=None, cache_postfix='_dsfd_seq.pkl', resolution=256, crop_scale=2.0, workers=4, 33 | select='all'): 34 | # Validation 35 | if not os.path.isdir(root): 36 | raise RuntimeError('root directory does not exist: ' + root) 37 | if not os.path.isdir(output_dir): 38 | raise RuntimeError('Output directory does not exist: ' + output_dir) 39 | 40 | # Parse files from directory or file lists (if specified) 41 | if file_lists is None: 42 | vid_rel_paths = 
parse_videos(root) 43 | else: 44 | vid_rel_paths = [] 45 | for file_list in file_lists: 46 | vid_rel_paths.append(np.loadtxt(os.path.join(root, file_list), dtype=str)) 47 | vid_rel_paths = np.concatenate(vid_rel_paths) 48 | 49 | vid_out_dirs = [os.path.join(output_dir, os.path.split(p)[0]) for p in vid_rel_paths] 50 | vid_paths = [os.path.join(root, p) for p in vid_rel_paths] 51 | 52 | # Make directory structure 53 | for out_dir in vid_out_dirs: 54 | if not os.path.exists(out_dir): 55 | os.makedirs(out_dir) 56 | 57 | # Process all videos 58 | f = partial(process_video, cache_postfix=cache_postfix, resolution=resolution, crop_scale=crop_scale, select=select) 59 | with Pool(workers) as p: 60 | list(tqdm(p.imap(f, zip(vid_paths, vid_out_dirs)), total=len(vid_paths))) 61 | 62 | # Parse generated sequence videos 63 | vid_seq_rel_paths = parse_videos(output_dir) 64 | vid_seq_keys, vid_seq_groups = zip(*[(key, list(group)) for key, group in 65 | groupby(vid_seq_rel_paths, lambda p: (p[:-10] + '.mp4'))]) 66 | vid_seq_groups = np.array(vid_seq_groups) 67 | 68 | for file_list in file_lists: 69 | # Adjust file list to generated sequence videos 70 | list_rel_paths = np.loadtxt(os.path.join(root, file_list), dtype=str) 71 | _, indices, _ = np.intersect1d(vid_seq_keys, list_rel_paths, return_indices=True) 72 | list_seq_rel_paths = np.concatenate(vid_seq_groups[indices]) 73 | 74 | # Write output list to file 75 | np.savetxt(os.path.join(output_dir, file_list), list_seq_rel_paths, fmt='%s') 76 | 77 | 78 | if __name__ == "__main__": 79 | # Parse program arguments 80 | import argparse 81 | parser = argparse.ArgumentParser(os.path.splitext(os.path.basename(__file__))[0]) 82 | parser.add_argument('root', metavar='DIR', 83 | help='root directory') 84 | parser.add_argument('-o', '--output', metavar='DIR', required=True, 85 | help='output directory') 86 | parser.add_argument('-fl', '--file_lists', metavar='PATH', nargs='+', 87 | help='file lists') 88 | parser.add_argument('-cp', '--cache_postfix', default='_dsfd_seq.pkl', metavar='POSTFIX', 89 | help='cache file postfix') 90 | parser.add_argument('-r', '--resolution', default=256, type=int, metavar='N', 91 | help='output video resolution (default: 256)') 92 | parser.add_argument('-cs', '--crop_scale', type=float, metavar='F', default=2.0, 93 | help='crop scale relative to detection bounding box') 94 | parser.add_argument('-w', '--workers', default=4, type=int, metavar='N', 95 | help='number of data loading workers (default: 4)') 96 | parser.add_argument('-s', '--select', default='all', metavar='STR', 97 | help='selection method [all|longest]') 98 | args = parser.parse_args() 99 | main(args.root, args.output, args.file_lists, args.cache_postfix, args.resolution, args.crop_scale, args.workers, 100 | args.select) 101 | -------------------------------------------------------------------------------- /fsgan/preprocess/detections2sequences_1euro.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | import numpy as np 5 | import cv2 6 | from fsgan.utils.bbox_utils import batch_iou 7 | from fsgan.utils.video_utils import Sequence, smooth_detections_1euro 8 | 9 | 10 | def main(input_path, cache_path=None, output_path=None, iou_thresh=0.75, min_length=10, min_size=64, crop_scale=1.2, 11 | kernel_size=7, smooth=False, display=False, write_empty=False): 12 | cache_path = os.path.splitext(input_path)[0] + '_dsfd.pkl' if cache_path is None else cache_path 13 | output_path = 
os.path.splitext(input_path)[0] + '_dsfd_seq.pkl' if output_path is None else output_path 14 | min_length = 1 if os.path.splitext(input_path)[1] == '.jpg' else min_length 15 | 16 | # Validation 17 | if not os.path.isfile(cache_path): 18 | raise RuntimeError('Cache file does not exist: ' + cache_path) 19 | 20 | print('=> Extracting sequences from detections in video: "%s"...' % os.path.basename(input_path)) 21 | 22 | # Load detections from file 23 | with open(cache_path, "rb") as fp: # Unpickling 24 | det_list = pickle.load(fp) 25 | det_list.append(np.array([], dtype='float32')) # Makes sure the final sequences are added to the seq_list 26 | 27 | # Open input video file 28 | if display: 29 | cap = cv2.VideoCapture(input_path) 30 | if not cap.isOpened(): 31 | raise RuntimeError('Failed to read video: ' + input_path) 32 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 33 | fps = cap.get(cv2.CAP_PROP_FPS) 34 | input_vid_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 35 | input_vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 36 | 37 | # For each frame detection 38 | seq_list = [] 39 | curr_seq_list = [] 40 | # for i, frame_det in enumerate(det_list): # Debug 41 | for i, frame_det in tqdm(enumerate(det_list), total=len(det_list)): 42 | frame_det = list(frame_det) 43 | if len(curr_seq_list) > 0: 44 | # For each sequence find matching detections 45 | keep_indices = np.full(len(curr_seq_list), False) 46 | for s, curr_seq in enumerate(curr_seq_list): 47 | if len(frame_det) > 0: 48 | curr_seq_det_rep = np.repeat(np.expand_dims(curr_seq[-1], 0), len(frame_det), axis=0) 49 | ious = batch_iou(curr_seq_det_rep, np.array(frame_det)) 50 | best_match_ind = ious.argmax() 51 | if ious[best_match_ind] > iou_thresh: 52 | # Match found 53 | curr_seq.add(frame_det[best_match_ind]) 54 | del frame_det[best_match_ind] 55 | keep_indices[s] = True 56 | 57 | # Remove unmatched sequences and add the suitable ones to the final sequence list 58 | if not np.all(keep_indices): 59 | seq_list += [seq for k, seq in enumerate(curr_seq_list) 60 | if (not keep_indices[k]) and len(seq) >= min_length and 61 | (seq.size_avg * crop_scale) >= min_size] 62 | curr_seq_list = [seq for k, seq in enumerate(curr_seq_list) if keep_indices[k]] 63 | 64 | # Add remaining detections to current sequences list as new sequences 65 | curr_seq_list += [Sequence(i, d) for d in frame_det] 66 | 67 | # Render current sequence list 68 | if display: 69 | ret, render_img = cap.read() 70 | if render_img is None: 71 | continue 72 | for j, seq in enumerate(curr_seq_list): 73 | rect = seq[-1] 74 | # cv2.rectangle(render_img, tuple(rect[:2]), tuple(rect[:2] + rect[2:]), (0, 0, 255), 1) 75 | cv2.rectangle(render_img, tuple(rect[:2]), tuple(rect[2:]), (0, 0, 255), 1) 76 | text_pos = (rect[:2] + np.array([0, -8])).astype('float32') 77 | cv2.putText(render_img, 'id: %d' % seq.id, tuple(text_pos), 78 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) 79 | cv2.imshow('render_img', render_img) 80 | if cv2.waitKey(1) & 0xFF == ord('q'): 81 | break 82 | 83 | # Reduce the n sequence ids to [0, ..., n - 1] 84 | ids = np.sort([seq.id for seq in seq_list]) 85 | ids_map = {k: v for v, k in enumerate(ids)} 86 | for seq in seq_list: 87 | seq.id = ids_map[seq.id] 88 | 89 | # Smooth sequence bounding boxes or finalize detections (convert to numpy array) 90 | for seq in seq_list: 91 | if smooth: 92 | seq.detections = smooth_detections_1euro(seq.detections, kernel_size) 93 | else: 94 | seq.finalize() 95 | 96 | # Write final sequence list to file 97 | if 
len(seq_list) > 0 or write_empty: 98 | with open(output_path, "wb") as fp: # Pickling 99 | pickle.dump(seq_list, fp) 100 | 101 | 102 | if __name__ == "__main__": 103 | # Parse program arguments 104 | import argparse 105 | parser = argparse.ArgumentParser('detections2sequences') 106 | parser.add_argument('input', metavar='VIDEO', 107 | help='path to input video') 108 | parser.add_argument('-c', '--cache', metavar='PATH', 109 | help='path to detections cache file') 110 | parser.add_argument('-o', '--output', metavar='PATH', 111 | help='output directory') 112 | parser.add_argument('-it', '--iou_thresh', default=0.75, type=float, 113 | metavar='F', help='IOU threshold') 114 | parser.add_argument('-ml', '--min_length', default=10, type=int, 115 | metavar='N', help='minimum sequence length') 116 | parser.add_argument('-ms', '--min_size', default=64, type=int, 117 | metavar='N', help='minimum sequence average bounding box size') 118 | parser.add_argument('-cs', '--crop_scale', default=1.2, type=float, metavar='F', 119 | help='crop scale relative to bounding box (default: 1.2)') 120 | parser.add_argument('-ks', '--kernel_size', default=7, type=int, 121 | metavar='N', help='average kernel size') 122 | parser.add_argument('-s', '--smooth', action='store_true', 123 | help='smooth the sequence bounding boxes') 124 | parser.add_argument('-d', '--display', action='store_true', 125 | help='display the rendering') 126 | parser.add_argument('-we', '--write_empty', action='store_true', 127 | help='write empty sequence lists to file') 128 | args = parser.parse_args() 129 | main(args.input, args.cache, args.output, args.iou_thresh, args.min_length, args.min_size, args.crop_scale, 130 | args.kernel_size, args.smooth, args.display, args.write_empty) 131 | -------------------------------------------------------------------------------- /fsgan/preprocess/detections2sequences_center.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | import numpy as np 5 | import cv2 6 | from fsgan.utils.bbox_utils import batch_iou 7 | from fsgan.utils.bbox_utils import smooth_bboxes 8 | from fsgan.utils.video_utils import Sequence 9 | # from fsgan.utils.video_utils import Sequence, smooth_detections_avg_center 10 | 11 | 12 | def main(input_path, output_path=None, cache_path=None, iou_thresh=0.75, min_length=10, min_size=64, crop_scale=1.2, 13 | center_kernel=25, size_kernel=51, smooth=False, display=False, write_empty=False): 14 | cache_path = os.path.splitext(input_path)[0] + '_dsfd.pkl' if cache_path is None else cache_path 15 | output_path = os.path.splitext(input_path)[0] + '_dsfd_seq.pkl' if output_path is None else output_path 16 | min_length = 1 if os.path.splitext(input_path)[1] == '.jpg' else min_length 17 | 18 | # Validation 19 | if not os.path.isfile(cache_path): 20 | raise RuntimeError('Cache file does not exist: ' + cache_path) 21 | 22 | print('=> Extracting sequences from detections in video: "%s"...' 
% os.path.basename(input_path)) 23 | 24 | # Load detections from file 25 | with open(cache_path, "rb") as fp: # Unpickling 26 | det_list = pickle.load(fp) 27 | det_list.append(np.array([], dtype='float32')) # Makes sure the final sequences are added to the seq_list 28 | 29 | # Open input video file 30 | if display: 31 | cap = cv2.VideoCapture(input_path) 32 | if not cap.isOpened(): 33 | raise RuntimeError('Failed to read video: ' + input_path) 34 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 35 | fps = cap.get(cv2.CAP_PROP_FPS) 36 | input_vid_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 37 | input_vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 38 | 39 | # For each frame detection 40 | seq_list = [] 41 | curr_seq_list = [] 42 | # for i, frame_det in enumerate(det_list): # Debug 43 | for i, frame_det in tqdm(enumerate(det_list), total=len(det_list)): 44 | frame_det = list(frame_det) 45 | if len(curr_seq_list) > 0: 46 | # For each sequence find matching detections 47 | keep_indices = np.full(len(curr_seq_list), False) 48 | for s, curr_seq in enumerate(curr_seq_list): 49 | if len(frame_det) > 0: 50 | curr_seq_det_rep = np.repeat(np.expand_dims(curr_seq[-1], 0), len(frame_det), axis=0) 51 | ious = batch_iou(curr_seq_det_rep, np.array(frame_det)) 52 | best_match_ind = ious.argmax() 53 | if ious[best_match_ind] > iou_thresh: 54 | # Match found 55 | curr_seq.add(frame_det[best_match_ind]) 56 | del frame_det[best_match_ind] 57 | keep_indices[s] = True 58 | 59 | # Remove unmatched sequences and add the suitable ones to the final sequence list 60 | if not np.all(keep_indices): 61 | seq_list += [seq for k, seq in enumerate(curr_seq_list) 62 | if (not keep_indices[k]) and len(seq) >= min_length and 63 | (seq.size_avg * crop_scale) >= min_size] 64 | curr_seq_list = [seq for k, seq in enumerate(curr_seq_list) if keep_indices[k]] 65 | 66 | # Add remaining detections to current sequences list as new sequences 67 | curr_seq_list += [Sequence(i, d) for d in frame_det] 68 | 69 | # Render current sequence list 70 | if display: 71 | ret, render_img = cap.read() 72 | if render_img is None: 73 | continue 74 | for j, seq in enumerate(curr_seq_list): 75 | rect = seq[-1] 76 | # cv2.rectangle(render_img, tuple(rect[:2]), tuple(rect[:2] + rect[2:]), (0, 0, 255), 1) 77 | cv2.rectangle(render_img, tuple(rect[:2]), tuple(rect[2:]), (0, 0, 255), 1) 78 | text_pos = (rect[:2] + np.array([0, -8])).astype('float32') 79 | cv2.putText(render_img, 'id: %d' % seq.id, tuple(text_pos), 80 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) 81 | cv2.imshow('render_img', render_img) 82 | if cv2.waitKey(1) & 0xFF == ord('q'): 83 | break 84 | 85 | # Reduce the n sequence ids to [0, ..., n - 1] 86 | ids = np.sort([seq.id for seq in seq_list]) 87 | ids_map = {k: v for v, k in enumerate(ids)} 88 | for seq in seq_list: 89 | seq.id = ids_map[seq.id] 90 | 91 | # Smooth sequence bounding boxes or finalize detections (convert to numpy array) 92 | for seq in seq_list: 93 | if smooth: 94 | # seq.detections = smooth_detections_avg_center(seq.detections, center_kernel, size_kernel) 95 | seq.detections = smooth_bboxes(seq.detections, center_kernel, size_kernel) 96 | else: 97 | seq.finalize() 98 | 99 | # Write final sequence list to file 100 | if len(seq_list) > 0 or write_empty: 101 | with open(output_path, "wb") as fp: # Pickling 102 | pickle.dump(seq_list, fp) 103 | 104 | 105 | if __name__ == "__main__": 106 | # Parse program arguments 107 | import argparse 108 | parser = 
argparse.ArgumentParser('detections2sequences_02') 109 | parser.add_argument('input', metavar='VIDEO', 110 | help='path to input video') 111 | parser.add_argument('-o', '--output', metavar='PATH', 112 | help='output directory') 113 | parser.add_argument('-c', '--cache', metavar='PATH', 114 | help='path to detections cache file') 115 | parser.add_argument('-it', '--iou_thresh', default=0.75, type=float, 116 | metavar='F', help='IOU threshold') 117 | parser.add_argument('-ml', '--min_length', default=10, type=int, 118 | metavar='N', help='minimum sequence length') 119 | parser.add_argument('-ms', '--min_size', default=64, type=int, 120 | metavar='N', help='minimum sequence average bounding box size') 121 | parser.add_argument('-cs', '--crop_scale', default=1.2, type=float, metavar='F', 122 | help='crop scale relative to bounding box (default: 1.2)') 123 | parser.add_argument('-ck', '--center_kernel', default=25, type=int, 124 | metavar='N', help='center average kernel size') 125 | parser.add_argument('-sk', '--size_kernel', default=51, type=int, 126 | metavar='N', help='size average kernel size') 127 | parser.add_argument('-s', '--smooth', action='store_true', 128 | help='smooth the sequence bounding boxes') 129 | parser.add_argument('-d', '--display', action='store_true', 130 | help='display the rendering') 131 | parser.add_argument('-we', '--write_empty', action='store_true', 132 | help='write empty sequence lists to file') 133 | args = parser.parse_args() 134 | main(args.input, args.output, args.cache, args.iou_thresh, args.min_length, args.min_size, args.crop_scale, 135 | args.center_kernel, args.size_kernel, args.smooth, args.display, args.write_empty) 136 | -------------------------------------------------------------------------------- /fsgan/preprocess/euler_sequences.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from fsgan.models.hopenet import Hopenet 10 | from fsgan.utils.utils import set_device 11 | from fsgan.utils.bbox_utils import scale_bbox, crop_img 12 | from fsgan.utils.img_utils import rgb2tensor 13 | from fsgan.utils.video_utils import Sequence 14 | from fsgan.utils.img_utils import tensor2bgr # Debug 15 | 16 | 17 | def main(input_path, output_path=None, seq_postfix='_dsfd_seq.pkl', output_postfix='_dsfd_seq_lms_euler.pkl', 18 | pose_model_path='weights/hopenet_robust_alpha1.pkl', smooth_det=False, smooth_euler=False, gpus=None, 19 | cpu_only=False, batch_size=16): 20 | cache_path = os.path.splitext(input_path)[0] + seq_postfix 21 | output_path = os.path.splitext(input_path)[0] + output_postfix if output_path is None else output_path 22 | 23 | # Initialize device 24 | torch.set_grad_enabled(False) 25 | device, gpus = set_device(gpus, not cpu_only) 26 | 27 | # Load sequences from file 28 | with open(cache_path, "rb") as fp: # Unpickling 29 | seq_list = pickle.load(fp) 30 | 31 | # Load pose model 32 | face_pose = Hopenet().to(device) 33 | checkpoint = torch.load(pose_model_path) 34 | face_pose.load_state_dict(checkpoint) 35 | face_pose.train(False) 36 | 37 | # Open input video file 38 | cap = cv2.VideoCapture(input_path) 39 | if not cap.isOpened(): 40 | raise RuntimeError('Failed to read video: ' + input_path) 41 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 42 | fps = cap.get(cv2.CAP_PROP_FPS) 43 | input_vid_width = 
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 44 | input_vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 45 | 46 | # Smooth sequence bounding boxes 47 | if smooth_det: 48 | for seq in seq_list: 49 | seq.smooth() 50 | 51 | # For each sequence 52 | total_detections = sum([len(s) for s in seq_list]) 53 | pbar = tqdm(range(total_detections), unit='detections') 54 | for seq in seq_list: 55 | euler = [] 56 | frame_cropped_tensor_list = [] 57 | cap.set(cv2.CAP_PROP_POS_FRAMES, seq.start_index) 58 | 59 | # For each detection bounding box in the current sequence 60 | for i, det in enumerate(seq.detections): 61 | ret, frame_bgr = cap.read() 62 | if frame_bgr is None: 63 | raise RuntimeError('Failed to read frame from video!') 64 | frame_rgb = frame_bgr[:, :, ::-1] 65 | 66 | # Crop frame 67 | bbox = np.concatenate((det[:2], det[2:] - det[:2])) 68 | bbox = scale_bbox(bbox, 1.2) 69 | frame_cropped_rgb = crop_img(frame_rgb, bbox) 70 | frame_cropped_rgb = cv2.resize(frame_cropped_rgb, (224, 224), interpolation=cv2.INTER_CUBIC) 71 | frame_cropped_tensor = rgb2tensor(frame_cropped_rgb).to(device) 72 | 73 | # Gather batches 74 | frame_cropped_tensor_list.append(frame_cropped_tensor) 75 | if len(frame_cropped_tensor_list) < batch_size and (i + 1) < len(seq): 76 | continue 77 | frame_cropped_tensor_batch = torch.cat(frame_cropped_tensor_list, dim=0) 78 | 79 | # Calculate euler angles 80 | curr_euler_batch = face_pose(frame_cropped_tensor_batch) # Yaw, Pitch, Roll 81 | curr_euler_batch = curr_euler_batch.cpu().numpy() 82 | 83 | # For each prediction in the batch 84 | for b, curr_euler in enumerate(curr_euler_batch): 85 | # Add euler to list 86 | euler.append(curr_euler) 87 | 88 | # Render 89 | # render_img = tensor2bgr(frame_cropped_tensor_batch[b]).copy() 90 | # cv2.putText(render_img, '(%.2f, %.2f, %.2f)' % (curr_euler[0], curr_euler[1], curr_euler[2]), (15, 15), 91 | # cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA) 92 | # cv2.imshow('render_img', render_img) 93 | # if cv2.waitKey(0) & 0xFF == ord('q'): 94 | # break 95 | 96 | # Clear lists 97 | frame_cropped_tensor_list.clear() 98 | 99 | pbar.update(len(frame_cropped_tensor_batch)) 100 | 101 | # Add landmarks to sequence and optionally smooth them 102 | euler = np.array(euler) 103 | if smooth_euler: 104 | euler = smooth(euler) 105 | seq.euler = euler 106 | 107 | # Write final sequence list to file 108 | with open(output_path, "wb") as fp: # Pickling 109 | pickle.dump(seq_list, fp) 110 | 111 | 112 | def smooth(x, kernel_size=7): 113 | # Prepare smoothing kernel 114 | w = np.hamming(kernel_size) 115 | w /= w.sum() 116 | 117 | # Smooth euler 118 | x_padded = np.pad(x, ((kernel_size // 2, kernel_size // 2), (0, 0)), 'reflect') 119 | for i in range(x.shape[1]): 120 | x[:, i] = np.convolve(w, x_padded[:, i], mode='valid') 121 | 122 | return x 123 | 124 | 125 | if __name__ == "__main__": 126 | # Parse program arguments 127 | import argparse 128 | parser = argparse.ArgumentParser('landmarks_sequences') 129 | parser.add_argument('input', metavar='VIDEO', 130 | help='path to input video') 131 | parser.add_argument('-o', '--output', default=None, metavar='PATH', 132 | help='output directory') 133 | parser.add_argument('-sp', '--seq_postfix', default='_dsfd_seq.pkl', metavar='POSTFIX', 134 | help='input sequence file postfix') 135 | parser.add_argument('-op', '--output_postfix', default='_dsfd_seq_lms_euler.pkl', metavar='POSTFIX', 136 | help='output file postfix') 137 | parser.add_argument('-p', '--pose_model', default='weights/hopenet_robust_alpha1.pkl', 
metavar='PATH', 138 | help='path to pose model file') 139 | parser.add_argument('-sd', '--smooth_det', action='store_true', 140 | help='smooth the sequence detection bounding boxes') 141 | parser.add_argument('-se', '--smooth_euler', action='store_true', 142 | help='smooth the sequence landmarks') 143 | parser.add_argument('--gpus', default=None, nargs='+', type=int, metavar='N', 144 | help='list of gpu ids to use (default: all)') 145 | parser.add_argument('--cpu_only', action='store_true', 146 | help='force cpu only') 147 | parser.add_argument('-b', '--batch-size', default=16, type=int, metavar='N', 148 | help='batch size (default: 16)') 149 | args = parser.parse_args() 150 | main(args.input, args.output, args.seq_postfix, args.output_postfix, args.pose_model, args.smooth_det, 151 | args.smooth_euler, args.gpus, args.cpu_only, args.batch_size) 152 | -------------------------------------------------------------------------------- /fsgan/preprocess/produce_train_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob1 3 | import numpy as np 4 | 5 | 6 | def parse_files(dir, postfix='.mp4', cache_postfix=None): 7 | files = glob1(dir, '*' + postfix) 8 | if cache_postfix is not None: 9 | files = [f for f in files if os.path.isfile(os.path.join(dir, os.path.splitext(f)[0] + cache_postfix))] 10 | dir = os.path.expanduser(dir) 11 | for fname in sorted(os.listdir(dir)): 12 | path = os.path.join(dir, fname) 13 | if os.path.isdir(path): 14 | files += [os.path.join(fname, f).replace('\\', '/') for f in parse_files(path, postfix, cache_postfix)] 15 | 16 | return sorted(files) 17 | 18 | 19 | def main(in_dir, out_dir=None, ratio=0.1, postfix='.mp4', cache_postfix=None): 20 | # Validation 21 | if not os.path.isdir(in_dir): 22 | raise RuntimeError('Input directory does not exist: ' + in_dir) 23 | out_dir = in_dir if out_dir is None else out_dir 24 | if not os.path.isdir(out_dir): 25 | raise RuntimeError('Output directory does not exist: ' + out_dir) 26 | 27 | # Parse files 28 | file_rel_paths = np.array(parse_files(in_dir, postfix, cache_postfix)) 29 | 30 | # Generate directory splits 31 | n = len(file_rel_paths) 32 | val_indices = np.random.choice(n, int(np.round(n * ratio)), replace=False).astype(int) 33 | train_indices = np.setdiff1d(np.arange(n), val_indices) 34 | train_indices.sort() 35 | val_indices.sort() 36 | 37 | train_file_list = file_rel_paths[train_indices] 38 | val_file_list = file_rel_paths[val_indices] 39 | 40 | # Output splits to file 41 | np.savetxt(os.path.join(out_dir, 'train_list.txt'), train_file_list, fmt='%s') 42 | np.savetxt(os.path.join(out_dir, 'val_list.txt'), val_file_list, fmt='%s') 43 | 44 | 45 | if __name__ == '__main__': 46 | import argparse 47 | parser = argparse.ArgumentParser('produce_train_val') 48 | parser.add_argument('input', help='dataset root directory') 49 | parser.add_argument('-o', '--output', help='output directory') 50 | parser.add_argument('-r', '--ratio', default=0.1, type=float, help='ratio of validation split') 51 | parser.add_argument('-p', '--postfix', default='.mp4', help='files postfix') 52 | parser.add_argument('-cp', '--cache_postfix', help='cache postfix') 53 | args = parser.parse_args() 54 | main(args.input, args.output, args.ratio, args.postfix, args.cache_postfix) 55 | -------------------------------------------------------------------------------- /fsgan/preprocess/render_sequences.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import pickle 3 | from tqdm import tqdm 4 | import numpy as np 5 | import cv2 6 | from fsgan.utils.video_utils import Sequence 7 | 8 | 9 | def main(input_path, output_path=None, postfix='_dsfd_seq.pkl', smooth=False, fps=None): 10 | cache_path = os.path.splitext(input_path)[0] + postfix 11 | # output_path = os.path.splitext(input_path)[0] + '.mp4' if output_path is None else output_path 12 | 13 | # Load sequences from file 14 | with open(cache_path, "rb") as fp: # Unpickling 15 | seq_list = pickle.load(fp) 16 | 17 | # Open input video file 18 | cap = cv2.VideoCapture(input_path) 19 | if not cap.isOpened(): 20 | raise RuntimeError('Failed to read video: ' + input_path) 21 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 22 | fps = cap.get(cv2.CAP_PROP_FPS) if fps is None else fps 23 | input_vid_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 24 | input_vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 25 | 26 | # Initialize output video file 27 | if output_path is not None: 28 | if os.path.isdir(output_path): 29 | output_filename = os.path.basename(input_path) 30 | output_path = os.path.join(output_path, output_filename) 31 | fourcc = cv2.VideoWriter_fourcc(*'x264') 32 | out_vid = cv2.VideoWriter(output_path, fourcc, fps, (input_vid_width, input_vid_height)) 33 | else: 34 | out_vid = None 35 | 36 | # Smooth sequence bounding boxes 37 | if smooth: 38 | for seq in seq_list: 39 | seq.smooth() 40 | 41 | # For each frame in the target video 42 | for i in tqdm(range(total_frames)): 43 | ret, frame = cap.read() 44 | if frame is None: 45 | continue 46 | 47 | # For each sequence 48 | render_img = frame 49 | for seq in seq_list: 50 | if i < seq.start_index or (seq.start_index + len(seq) - 1) < i: 51 | continue 52 | rect = seq[i - seq.start_index] 53 | cv2.rectangle(render_img, tuple(rect[:2]), tuple(rect[2:]), (0, 255, 0), 1) 54 | 55 | if hasattr(seq, 'landmarks'): 56 | landmarks = seq.landmarks[i - seq.start_index] 57 | for point in np.round(landmarks).astype(int): 58 | cv2.circle(render_img, (point[0], point[1]), 1, (0, 0, 255), -1) 59 | 60 | text_pos = (rect[:2] + np.array([0, -8])).astype('float32') 61 | cv2.putText(render_img, 'id: %d' % seq.id, tuple(text_pos), 62 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA) 63 | 64 | # Render 65 | if out_vid is not None: 66 | out_vid.write(render_img) 67 | cv2.imshow('render_img', render_img) 68 | delay = np.round(1000.0 / fps).astype(int) if fps > 1e-5 else 0 69 | if cv2.waitKey(delay) & 0xFF == ord('q'): 70 | break 71 | 72 | 73 | if __name__ == "__main__": 74 | # Parse program arguments 75 | import argparse 76 | parser = argparse.ArgumentParser('detections2sequences') 77 | parser.add_argument('input', metavar='VIDEO', 78 | help='path to input video') 79 | parser.add_argument('-o', '--output', default=None, metavar='PATH', 80 | help='output directory') 81 | parser.add_argument('-p', '--postfix', default='_dsfd_seq.pkl', metavar='POSTFIX', 82 | help='input sequence file postfix') 83 | parser.add_argument('-s', '--smooth', action='store_true', 84 | help='smooth the sequence bounding boxes') 85 | parser.add_argument('-f', '--fps', type=float, metavar='F', 86 | help='force video fps, set 0 to pause after each frame') 87 | args = parser.parse_args() 88 | main(args.input, args.output, args.postfix, args.smooth, args.fps) 89 | -------------------------------------------------------------------------------- /fsgan/preprocess/sequence_stats.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Sequence statistics: Count, length, bounding boxes size. 3 | """ 4 | import os 5 | from glob import glob 6 | import pickle 7 | from tqdm import tqdm 8 | 9 | 10 | def extract_stats(cache_path): 11 | # Load sequences from file 12 | with open(cache_path, "rb") as fp: # Unpickling 13 | seq_list = pickle.load(fp) 14 | 15 | if len(seq_list) == 0: 16 | return 0, 0., 0. 17 | 18 | # For each sequence 19 | len_sum, size_sum = 0., 0. 20 | for seq in seq_list: 21 | len_sum += len(seq) 22 | size_sum += seq.size_avg 23 | 24 | return len(seq_list), len_sum / len(seq_list), size_sum / len(seq_list) 25 | 26 | 27 | def main(in_dir, out_path=None, postfix='_dsfd_seq.pkl'): 28 | out_path = os.path.join(in_dir, 'sequence_stats.txt') if out_path is None else out_path 29 | 30 | # Validation 31 | if not os.path.isdir(in_dir): 32 | raise RuntimeError('Input directory not exist: ' + in_dir) 33 | 34 | # Parse file paths 35 | input_query = os.path.join(in_dir, '*' + postfix) 36 | file_paths = sorted(glob(input_query)) 37 | 38 | # For each file in the input directory with the specified postfix 39 | pbar = tqdm(file_paths, unit='files') 40 | count_sum, len_sum, size_sum = 0., 0., 0. 41 | vid_count = 0 42 | for i, file_path in enumerate(pbar): 43 | curr_count, curr_mean_len, curr_mean_size = extract_stats(file_path) 44 | if curr_count == 0: 45 | continue 46 | count_sum += curr_count 47 | len_sum += curr_mean_len 48 | size_sum += curr_mean_size 49 | vid_count += 1 50 | pbar.set_description('mean_count = %.1f, mean_len = %.1f, mean_size = %.1f, valid_vids = %d / %d' % 51 | (count_sum / vid_count, len_sum / vid_count, size_sum / vid_count, vid_count, i + 1)) 52 | 53 | # Write result to file 54 | if out_path is not None: 55 | with open(out_path, "w") as f: 56 | f.write('mean_count = %.1f\n' % (count_sum / vid_count)) 57 | f.write('mean_len = %.1f\n' % (len_sum / vid_count)) 58 | f.write('mean_size = %.1f\n' % (size_sum / vid_count)) 59 | f.write('valid videos = %d / %d\n' % (vid_count, len(file_paths))) 60 | 61 | 62 | if __name__ == "__main__": 63 | # Parse program arguments 64 | import argparse 65 | parser = argparse.ArgumentParser('detections2sequences') 66 | parser.add_argument('input', metavar='DIR', 67 | help='input directory') 68 | parser.add_argument('-o', '--output', default=None, metavar='PATH', 69 | help='output directory') 70 | parser.add_argument('-p', '--postfix', metavar='POSTFIX', default='_dsfd_seq.pkl', 71 | help='the files postfix to search the input directory for') 72 | args = parser.parse_args() 73 | main(args.input, args.output, args.postfix) 74 | -------------------------------------------------------------------------------- /fsgan/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuvalNirkin/fsgan/0605e7c521ec697cbcd12d4f7d46b02beb808fd1/fsgan/utils/__init__.py -------------------------------------------------------------------------------- /fsgan/utils/batch.py: -------------------------------------------------------------------------------- 1 | """ Batch processing utility. 
""" 2 | 3 | import os 4 | import argparse 5 | from glob import glob 6 | import inspect 7 | from itertools import product 8 | import traceback 9 | import logging 10 | from fsgan.utils.obj_factory import partial_obj_factory 11 | 12 | 13 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 14 | parser.add_argument('source', metavar='STR', nargs='+', 15 | help='sources input') 16 | parser.add_argument('-t', '--target', metavar='STR', nargs='*', 17 | help='targets input') 18 | parser.add_argument('-o', '--output', metavar='DIR', 19 | help='output directory') 20 | parser.add_argument('-fo', '--func_obj', default='fsgan.utils.batch.echo', metavar='OBJ', 21 | help='function object including kwargs') 22 | parser.add_argument('-p', '--postfix', metavar='POSTFIX', 23 | help='input files postfix') 24 | parser.add_argument('-op', '--out_postfix', metavar='POSTFIX', 25 | help='output files postfix') 26 | parser.add_argument('-i', '--indices', 27 | help='python style indices (e.g 0:10') 28 | parser.add_argument('-se', '--skip_existing', action='store_true', 29 | help='skip existing output file our directory') 30 | parser.add_argument('-ro', '--reverse_output', action='store_true', 31 | help='reverse the output name to be _') 32 | parser.add_argument('-io', '--ignore_output', action='store_true', 33 | help='avoid specifying an output parameter for the function object') 34 | 35 | 36 | def main(source, target=None, output=None, func_obj=None, postfix=None, out_postfix=None, indices=None, 37 | skip_existing=False, reverse_output=False, ignore_output=False): 38 | out_postfix = postfix if out_postfix is None else out_postfix 39 | out_postfix = '' if out_postfix is None else out_postfix 40 | ignore_output = True if output is None else ignore_output 41 | 42 | # Parse input paths 43 | source_paths = parse_paths(source, postfix) 44 | target_paths = parse_paths(target, postfix) 45 | assert len(source_paths) > 0, 'Found 0 source paths' 46 | assert target_paths is None or len(target_paths) > 0, 'Found 0 target paths' 47 | 48 | if target_paths is None: 49 | input_paths = source_paths 50 | else: 51 | input_paths = list(product(source_paths, target_paths)) 52 | input_paths = [(p1, p2) for p1, p2 in input_paths if os.path.basename(p1) != os.path.basename(p2)] 53 | input_paths = eval('input_paths[%s]' % indices) if indices is not None else input_paths 54 | 55 | # Get function object instance 56 | partial_func_obj = partial_obj_factory(func_obj) 57 | func_obj = partial_func_obj() if inspect.isclass(partial_func_obj.func) else partial_func_obj 58 | 59 | # For each input path 60 | for i, curr_input in enumerate(input_paths): 61 | if isinstance(curr_input, (list, tuple)): 62 | if reverse_output: 63 | out_vid_name = os.path.splitext(os.path.basename(curr_input[1]))[0] + '_' + \ 64 | os.path.splitext(os.path.basename(curr_input[0]))[0] + out_postfix 65 | else: 66 | out_vid_name = os.path.splitext(os.path.basename(curr_input[0]))[0] + '_' + \ 67 | os.path.splitext(os.path.basename(curr_input[1]))[0] + out_postfix 68 | else: 69 | out_vid_name = os.path.splitext(os.path.basename(curr_input))[0] 70 | curr_input = [curr_input] 71 | out_vid_path = os.path.join(output, out_vid_name) if output is not None else None 72 | if skip_existing and os.path.exists(out_vid_path): 73 | print('[%d/%d] Skipping "%s"' % (i + 1, len(input_paths), out_vid_name)) 74 | continue 75 | 76 | print('[%d/%d] Processing "%s"...' 
% (i + 1, len(input_paths), out_vid_name)) 77 | try: 78 | func_obj(*curr_input) if ignore_output else func_obj(*curr_input, out_vid_path) 79 | except Exception as e: 80 | logging.error(traceback.format_exc()) 81 | 82 | 83 | def parse_paths(inputs, postfix=None): 84 | postfix = '' if postfix is None else postfix 85 | if inputs is None: 86 | return None 87 | input_paths = [] 88 | i = 0 89 | while i < len(inputs): 90 | if os.path.isfile(inputs[i]): 91 | ext = os.path.splitext(inputs[i])[1] 92 | if ext == '.txt': 93 | # Found a list file with absolute paths 94 | with open(inputs[i], 'r') as f: 95 | file_abs_paths = f.read().splitlines() 96 | input_paths += file_abs_paths 97 | else: 98 | input_paths.append(inputs[i]) 99 | elif os.path.isdir(inputs[i]): 100 | if (i + 1) < len(inputs) and os.path.splitext(inputs[i + 1])[1] == '.txt': 101 | # Found root directory and list file pair 102 | file_list_path = inputs[i + 1] if os.path.exists(inputs[i + 1]) \ 103 | else os.path.join(inputs[i], inputs[i + 1]) 104 | assert os.path.isfile(file_list_path), f'List file does not exist: "{inputs[i + 1]}"' 105 | with open(file_list_path, 'r') as f: 106 | file_rel_paths = f.read().splitlines() 107 | input_paths += [os.path.join(inputs[i], p) for p in file_rel_paths] 108 | i += 1 109 | else: 110 | # Found a directory 111 | # Parse the files in the directory 112 | input_paths += glob(os.path.join(inputs[i], '*' + postfix)) 113 | elif any(c in inputs[i] for c in ['*']): 114 | input_paths += glob(inputs[i]) 115 | 116 | i += 1 117 | 118 | return input_paths 119 | 120 | 121 | def echo(*args, **kwargs): 122 | print('Received the following input:') 123 | print(f'args = {args}') 124 | print(f'kwargs = {kwargs}') 125 | 126 | 127 | if __name__ == "__main__": 128 | main(**vars(parser.parse_args())) 129 | -------------------------------------------------------------------------------- /fsgan/utils/bbox_utils.py: -------------------------------------------------------------------------------- 1 | """ Bounding box utilities. """ 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | # Adapted from: http://ronny.rest/tutorials/module/localization_001/iou/ 8 | def get_iou(a, b, epsilon=1e-5): 9 | """ Given two boxes `a` and `b` defined as a list of four numbers: 10 | [x1,y1,x2,y2] 11 | where: 12 | x1,y1 represent the upper left corner 13 | x2,y2 represent the lower right corner 14 | It returns the Intersect of Union score for these two boxes. 15 | 16 | Args: 17 | a: (list of 4 numbers) [x1,y1,x2,y2] 18 | b: (list of 4 numbers) [x1,y1,x2,y2] 19 | epsilon: (float) Small value to prevent division by zero 20 | 21 | Returns: 22 | (float) The Intersect of Union score. 
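# Example invocation of the batch utility above (the paths below are hypothetical):
#   python fsgan/utils/batch.py /data/sources -t /data/targets -o /data/out -p .mp4 \
#       -fo fsgan.utils.batch.echo
# Every (source, target) pair with distinct basenames is passed to the function object;
# with the default echo function the received arguments are simply printed.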
23 | """ 24 | # COORDINATES OF THE INTERSECTION BOX 25 | x1 = max(a[0], b[0]) 26 | y1 = max(a[1], b[1]) 27 | x2 = min(a[2], b[2]) 28 | y2 = min(a[3], b[3]) 29 | 30 | # AREA OF OVERLAP - Area where the boxes intersect 31 | width = (x2 - x1) 32 | height = (y2 - y1) 33 | # handle case where there is NO overlap 34 | if (width<0) or (height <0): 35 | return 0.0 36 | area_overlap = width * height 37 | 38 | # COMBINED AREA 39 | area_a = (a[2] - a[0]) * (a[3] - a[1]) 40 | area_b = (b[2] - b[0]) * (b[3] - b[1]) 41 | area_combined = area_a + area_b - area_overlap 42 | 43 | # RATIO OF AREA OF OVERLAP OVER COMBINED AREA 44 | iou = area_overlap / (area_combined+epsilon) 45 | return iou 46 | 47 | 48 | # Adapted from: http://ronny.rest/tutorials/module/localization_001/iou/ 49 | def batch_iou(a, b, epsilon=1e-5): 50 | """ Given two arrays `a` and `b` where each row contains a bounding 51 | box defined as a list of four numbers: 52 | [x1,y1,x2,y2] 53 | where: 54 | x1,y1 represent the upper left corner 55 | x2,y2 represent the lower right corner 56 | It returns the Intersect of Union scores for each corresponding 57 | pair of boxes. 58 | 59 | Args: 60 | a: (numpy array) each row containing [x1,y1,x2,y2] coordinates 61 | b: (numpy array) each row containing [x1,y1,x2,y2] coordinates 62 | epsilon: (float) Small value to prevent division by zero 63 | 64 | Returns: 65 | (numpy array) The Intersect of Union scores for each pair of bounding 66 | boxes. 67 | """ 68 | # COORDINATES OF THE INTERSECTION BOXES 69 | x1 = np.array([a[:, 0], b[:, 0]]).max(axis=0) 70 | y1 = np.array([a[:, 1], b[:, 1]]).max(axis=0) 71 | x2 = np.array([a[:, 2], b[:, 2]]).min(axis=0) 72 | y2 = np.array([a[:, 3], b[:, 3]]).min(axis=0) 73 | 74 | # AREAS OF OVERLAP - Area where the boxes intersect 75 | width = (x2 - x1) 76 | height = (y2 - y1) 77 | 78 | # handle case where there is NO overlap 79 | width[width < 0] = 0 80 | height[height < 0] = 0 81 | 82 | area_overlap = width * height 83 | 84 | # COMBINED AREAS 85 | area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]) 86 | area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) 87 | area_combined = area_a + area_b - area_overlap 88 | 89 | # RATIO OF AREA OF OVERLAP OVER COMBINED AREA 90 | iou = area_overlap / (area_combined + epsilon) 91 | return iou 92 | 93 | 94 | def scale_bbox(bbox, scale=1.2, square=True): 95 | """ Scale bounding box by the specified scale and optionally make it square. 96 | 97 | Args: 98 | bbox (np.array): Input bounding box in the format [left, top, width, height] 99 | scale (float): Multiply the bounding box by this scale 100 | square (bool): If True, make the shorter edges of the bounding box equal the length as the longer edges 101 | 102 | Returns: 103 | np.array. The scaled bounding box 104 | """ 105 | bbox_center = bbox[:2] + bbox[2:] / 2 106 | bbox_size = np.round(bbox[2:] * scale).astype(int) 107 | if square: 108 | bbox_max_size = np.max(bbox_size) 109 | bbox_size = np.array([bbox_max_size, bbox_max_size], dtype=int) 110 | bbox_min = np.round(bbox_center - bbox_size / 2).astype(int) 111 | bbox_scaled = np.concatenate((bbox_min, bbox_size)) 112 | 113 | return bbox_scaled 114 | 115 | 116 | def crop_img(img, bbox, landmarks=None, border=cv2.BORDER_CONSTANT, value=None): 117 | """ Crop image and corresponding landmarks by bounding box. 118 | 119 | If the bounding box is out the image bounds, the image will be padded in the corresponding regions. 
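# Minimal usage sketch for the IoU helpers and scale_bbox above (the box values are
# illustrative; assumes this module is importable as fsgan.utils.bbox_utils):
import numpy as np
from fsgan.utils.bbox_utils import get_iou, batch_iou, scale_bbox

a = [0, 0, 2, 2]            # [x1, y1, x2, y2]
b = [1, 0, 3, 2]
print(get_iou(a, b))        # overlap = 2, union = 6 -> ~0.333

dets_a = np.array([a, a], dtype='float32')
dets_b = np.array([b, a], dtype='float32')
print(batch_iou(dets_a, dets_b))  # per-row IoU -> [~0.333, ~1.0]

bbox = np.array([10, 10, 20, 30])                # [left, top, width, height]
print(scale_bbox(bbox, scale=1.2, square=True))  # -> [2, 7, 36, 36]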
120 | 121 | Args: 122 | img (np.array): An image of shape (H, W, 3) 123 | landmarks (np.array): Face landmarks points of shape (68, 2) 124 | bbox (np.array): Bounding box in the format [left, top, width, height] 125 | border (int): OpenCV's border code 126 | value (int, optional): Border value if border==cv2.BORDER_CONSTANT 127 | 128 | Returns: 129 | (np.array, np.array (optional)): A tuple of numpy arrays containing: 130 | - Cropped image (np.array) 131 | - Cropped landmarks (np.array): Will be returned if the landmarks parameter is not None 132 | """ 133 | left = -bbox[0] if bbox[0] < 0 else 0 134 | top = -bbox[1] if bbox[1] < 0 else 0 135 | right = bbox[0] + bbox[2] - img.shape[1] if (bbox[0] + bbox[2] - img.shape[1]) > 0 else 0 136 | bottom = bbox[1] + bbox[3] - img.shape[0] if (bbox[1] + bbox[3] - img.shape[0]) > 0 else 0 137 | img_bbox = bbox.copy() 138 | if any((left, top, right, bottom)): 139 | img = cv2.copyMakeBorder(img, top, bottom, left, right, border, value=value) 140 | img_bbox[0] += left 141 | img_bbox[1] += top 142 | 143 | if landmarks is not None: 144 | # Adjust landmarks 145 | new_landmarks = landmarks.copy() 146 | new_landmarks[:, :2] += (np.array([left, top]) - img_bbox[:2]) 147 | return img[img_bbox[1]:img_bbox[1] + img_bbox[3], img_bbox[0]:img_bbox[0] + img_bbox[2]], new_landmarks 148 | else: 149 | return img[img_bbox[1]:img_bbox[1] + img_bbox[3], img_bbox[0]:img_bbox[0] + img_bbox[2]] 150 | 151 | 152 | def crop2img(img, crop, bbox): 153 | """ Writes cropped image into another image corresponding to the specified bounding box. 154 | 155 | Args: 156 | img (np.array): The image to write into of shape (H, W, 3) 157 | crop (np.array): The cropped image of shape (H, W, 3) 158 | bbox (np.array): Bounding box in the format [left, top, width, height] 159 | 160 | Returns: 161 | np.array: Result image. 162 | """ 163 | scaled_bbox = bbox 164 | scaled_crop = cv2.resize(crop, (scaled_bbox[3], scaled_bbox[2]), interpolation=cv2.INTER_CUBIC) 165 | left = -scaled_bbox[0] if scaled_bbox[0] < 0 else 0 166 | top = -scaled_bbox[1] if scaled_bbox[1] < 0 else 0 167 | right = scaled_bbox[0] + scaled_bbox[2] - img.shape[1] if (scaled_bbox[0] + scaled_bbox[2] - img.shape[1]) > 0 else 0 168 | bottom = scaled_bbox[1] + scaled_bbox[3] - img.shape[0] if (scaled_bbox[1] + scaled_bbox[3] - img.shape[0]) > 0 else 0 169 | crop_bbox = np.array([left, top, scaled_bbox[2] - left - right, scaled_bbox[3] - top - bottom]) 170 | scaled_bbox += np.array([left, top, -left - right, -top - bottom]) 171 | 172 | out_img = img.copy() 173 | out_img[scaled_bbox[1]:scaled_bbox[1] + scaled_bbox[3], scaled_bbox[0]:scaled_bbox[0] + scaled_bbox[2]] = \ 174 | scaled_crop[crop_bbox[1]:crop_bbox[1] + crop_bbox[3], crop_bbox[0]:crop_bbox[0] + crop_bbox[2]] 175 | 176 | return out_img 177 | 178 | 179 | def get_main_bbox(bboxes, img_size): 180 | """ Returns the main bounding box in a list of bounding boxes according to their size and how central they are. 181 | 182 | Args: 183 | bboxes (list of np.array): A list of bounding boxes in the format [left, top, width, height] 184 | img_size (tuple of int): The size of the corresponding image in the format [height, width] 185 | 186 | Returns: 187 | np.array: The main bounding box. 
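# Minimal sketch of cropping a detection with crop_img above. Detections elsewhere in this
# repo are stored as [x1, y1, x2, y2], so they are first converted to the
# [left, top, width, height] format expected here (the frame and detection are made up):
import numpy as np
from fsgan.utils.bbox_utils import scale_bbox, crop_img

frame_rgb = np.zeros((480, 640, 3), dtype='uint8')    # placeholder frame
det = np.array([300, 200, 380, 280])                  # [x1, y1, x2, y2]
bbox = np.concatenate((det[:2], det[2:] - det[:2]))   # -> [left, top, width, height]
bbox = scale_bbox(bbox, 1.2)                          # enlarge and make square
face_crop = crop_img(frame_rgb, bbox)                 # padded automatically if out of bounds
print(face_crop.shape)                                # (96, 96, 3) for this example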
188 | """ 189 | if len(bboxes) == 0: 190 | return None 191 | 192 | # Calculate frame max distance and size 193 | img_center = np.array([img_size[1], img_size[0]]) * 0.5 194 | max_dist = 0.25 * np.linalg.norm(img_size) 195 | max_size = 0.25 * (img_size[0] + img_size[1]) 196 | 197 | # For each bounding box 198 | scores = [] 199 | for bbox in bboxes: 200 | # Calculate center distance 201 | bbox_center = bbox[:2] + bbox[2:] * 0.5 202 | bbox_dist = np.linalg.norm(bbox_center - img_center) 203 | 204 | # Calculate bbox size 205 | bbox_size = bbox[2:].mean() 206 | 207 | # Calculate central ratio 208 | central_ratio = 1.0 if max_size < 1e-6 else (1.0 - bbox_dist / max_dist) 209 | central_ratio = np.clip(central_ratio, 0.0, 1.0) 210 | 211 | # Calculate size ratio 212 | size_ratio = 1.0 if max_size < 1e-6 else (bbox_size / max_size) 213 | size_ratio = np.clip(size_ratio, 0.0, 1.0) 214 | 215 | # Add score 216 | score = (central_ratio + size_ratio) * 0.5 217 | scores.append(score) 218 | 219 | return bboxes[np.argmax(scores)] 220 | 221 | 222 | def estimate_motion(points, kernel_size=5): 223 | """ Estimate motion of temporally sampled points. 224 | 225 | Args: 226 | points (np.array): An array of temporally sampled points of shape (N, 2) 227 | kernel_size (int): The temporal kernel size 228 | 229 | Returns: 230 | motion (np.array): Array of scalars of shape (N,) representing the amount of motion. 231 | """ 232 | deltas = np.zeros(points.shape) 233 | deltas[1:] = points[1:] - points[:-1] 234 | 235 | # Prepare smoothing kernel 236 | w = np.ones(kernel_size) 237 | w /= w.sum() 238 | 239 | # Smooth points 240 | deltas_padded = np.pad(deltas, ((kernel_size // 2, kernel_size // 2), (0, 0)), 'reflect') 241 | for i in range(points.shape[1]): 242 | deltas[:, i] = np.convolve(w, deltas_padded[:, i], mode='valid') 243 | 244 | motion = np.linalg.norm(deltas, axis=1) 245 | 246 | return motion 247 | 248 | 249 | def smooth_bboxes(detections, center_kernel=25, size_kernel=51, max_motion=0.01): 250 | """ Temporally smooth a series of bounding boxes by motion estimate. 251 | 252 | Based on the idea of the one Euro filter described in the paper: 253 | `"1 € filter: a simple speed-based low-pass filter for noisy input in interactive systems" 254 | `_ 255 | 256 | Args: 257 | detections (list of np.array): A list of detection bounding boxes in the format [left, top, bottom, right] 258 | center_kernel (int): The temporal kernel size for smoothing the bounding box centers 259 | size_kernel (int): The temporal kernel size for smoothing the bounding box sizes 260 | max_motion (float): The maximum allowed motion (for normalization) 261 | 262 | Returns: 263 | (list of np.array): The smoothed bounding boxes. 
264 | """ 265 | # Prepare smoothing kernel 266 | center_w = np.ones(center_kernel) 267 | center_w /= center_w.sum() 268 | size_w = np.ones(size_kernel) 269 | size_w /= size_w.sum() 270 | 271 | # Convert bounding boxes to center and size format 272 | bboxes = np.array(detections) 273 | centers = (bboxes[:, :2] + bboxes[:, 2:]) / 2.0 274 | sizes = bboxes[:, 2:] - bboxes[:, :2] 275 | 276 | # Smooth sizes 277 | sizes_padded = np.pad(sizes, ((size_kernel // 2, size_kernel // 2), (0, 0)), 'reflect') 278 | for i in range(centers.shape[1]): 279 | sizes[:, i] = np.convolve(size_w, sizes_padded[:, i], mode='valid') 280 | 281 | # Estimate motion 282 | centers_normalized = centers / sizes[:, 1:] 283 | motion = estimate_motion(centers_normalized, center_kernel) 284 | 285 | # Average smooth centers 286 | centers_padded = np.pad(centers, ((center_kernel // 2, center_kernel // 2), (0, 0)), 'reflect') 287 | centers_avg = centers.copy() 288 | for i in range(centers.shape[1]): 289 | centers_avg[:, i] = np.convolve(center_w, centers_padded[:, i], mode='valid') 290 | 291 | # Smooth centers by motion 292 | a = np.minimum(motion / max_motion, 1.)[..., np.newaxis] 293 | centers_smoothed = centers * a + centers_avg * (1 - a) 294 | 295 | # Change back to detections format 296 | sizes /= 2.0 297 | bboxes = np.concatenate((centers_smoothed - sizes, centers_smoothed + sizes), axis=1) 298 | 299 | return bboxes 300 | -------------------------------------------------------------------------------- /fsgan/utils/blur.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | # Adapted from: 9 | # https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 10 | class GaussianSmoothing(nn.Module): 11 | """ Apply gaussian smoothing on a 1d, 2d or 3d tensor. 12 | 13 | Filtering is performed seperately for each channel in the input using a depthwise convolution. 14 | 15 | Args: 16 | channels (int): Number of channels for both input and output tensors. 17 | kernel_size(int or list of int): Size of the gaussian kernel 18 | sigma (float or list of float): Standard deviation of the gaussian kernel 19 | padding (int, optional): Padding size. The default is half the kernel size 20 | dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial) 21 | """ 22 | def __init__(self, channels, kernel_size, sigma, padding=None, dim=2): 23 | super(GaussianSmoothing, self).__init__() 24 | self.padding = kernel_size // 2 if padding is None else padding 25 | if isinstance(kernel_size, numbers.Number): 26 | kernel_size = [kernel_size] * dim 27 | if isinstance(sigma, numbers.Number): 28 | sigma = [sigma] * dim 29 | 30 | # The gaussian kernel is the product of the 31 | # gaussian function of each dimension. 32 | kernel = 1 33 | meshgrids = torch.meshgrid( 34 | [ 35 | torch.arange(size, dtype=torch.float32) 36 | for size in kernel_size 37 | ] 38 | ) 39 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 40 | mean = (size - 1) / 2 41 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 42 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 43 | 44 | # Make sure sum of values in gaussian kernel equals 1. 
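# Minimal usage sketch for smooth_bboxes above (the synthetic track below is made up;
# assumes this module is importable as fsgan.utils.bbox_utils):
import numpy as np
from fsgan.utils.bbox_utils import smooth_bboxes

rng = np.random.RandomState(0)
centers = np.stack([np.linspace(100.0, 120.0, 200), np.full(200, 120.0)], axis=1)
centers += rng.randn(200, 2)                                       # per-frame jitter
detections = [np.concatenate([c - 32, c + 32]) for c in centers]   # [x1, y1, x2, y2] per frame
smoothed = smooth_bboxes(detections, center_kernel=25, size_kernel=51)
print(smoothed.shape)  # (200, 4): one smoothed box per input detection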
45 | kernel = kernel / torch.sum(kernel) 46 | 47 | # Reshape to depthwise convolutional weight 48 | kernel = kernel.view(1, 1, *kernel.size()) 49 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 50 | 51 | self.register_buffer('weight', kernel) 52 | self.groups = channels 53 | 54 | if dim == 1: 55 | self.conv = F.conv1d 56 | elif dim == 2: 57 | self.conv = F.conv2d 58 | elif dim == 3: 59 | self.conv = F.conv3d 60 | else: 61 | raise RuntimeError( 62 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 63 | ) 64 | 65 | def forward(self, input): 66 | """ Apply gaussian filter to input. 67 | 68 | Args: 69 | input (torch.Tensor): Input to apply gaussian filter on. 70 | 71 | Returns: 72 | filtered (torch.Tensor): Filtered output. 73 | """ 74 | return self.conv(input, weight=self.weight, groups=self.groups, padding=self.padding) 75 | -------------------------------------------------------------------------------- /fsgan/utils/confusionmatrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ConfusionMatrix(object): 6 | """Constructs a confusion matrix for a multi-class classification problems. 7 | Does not support multi-label, multi-class problems. 8 | Keyword arguments: 9 | - num_classes (int): number of classes in the classification problem. 10 | - normalized (boolean, optional): Determines whether or not the confusion 11 | matrix is normalized or not. Default: False. 12 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py 13 | """ 14 | 15 | def __init__(self, num_classes, normalized=False): 16 | super().__init__() 17 | 18 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32) 19 | self.normalized = normalized 20 | self.num_classes = num_classes 21 | self.reset() 22 | 23 | def reset(self): 24 | self.conf.fill(0) 25 | 26 | def add(self, predicted, target): 27 | """Computes the confusion matrix 28 | The shape of the confusion matrix is K x K, where K is the number 29 | of classes. 30 | Keyword arguments: 31 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of 32 | predicted scores obtained from the model for N examples and K classes, 33 | or an N-tensor/array of integer values between 0 and K-1. 34 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of 35 | ground-truth classes for N examples and K classes, or an N-tensor/array 36 | of integer values between 0 and K-1. 
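# Minimal usage sketch for GaussianSmoothing above (tensor sizes are illustrative;
# assumes this module is importable as fsgan.utils.blur):
import torch
from fsgan.utils.blur import GaussianSmoothing

smoother = GaussianSmoothing(channels=3, kernel_size=5, sigma=1.0)  # depthwise 2D Gaussian blur
img = torch.rand(1, 3, 64, 64)
blurred = smoother(img)
print(blurred.shape)  # torch.Size([1, 3, 64, 64]) - default padding keeps the spatial size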
37 | """ 38 | # If target and/or predicted are tensors, convert them to numpy arrays 39 | if torch.is_tensor(predicted): 40 | predicted = predicted.cpu().numpy() 41 | if torch.is_tensor(target): 42 | target = target.cpu().numpy() 43 | 44 | assert predicted.shape[0] == target.shape[0], \ 45 | 'number of targets and predicted outputs do not match' 46 | 47 | if np.ndim(predicted) != 1: 48 | assert predicted.shape[1] == self.num_classes, \ 49 | 'number of predictions does not match size of confusion matrix' 50 | predicted = np.argmax(predicted, 1) 51 | else: 52 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \ 53 | 'predicted values are not between 0 and k-1' 54 | 55 | if np.ndim(target) != 1: 56 | assert target.shape[1] == self.num_classes, \ 57 | 'Onehot target does not match size of confusion matrix' 58 | assert (target >= 0).all() and (target <= 1).all(), \ 59 | 'in one-hot encoding, target values should be 0 or 1' 60 | assert (target.sum(1) == 1).all(), \ 61 | 'multi-label setting is not supported' 62 | target = np.argmax(target, 1) 63 | else: 64 | assert (target.max() < self.num_classes) and (target.min() >= 0), \ 65 | 'target values are not between 0 and k-1' 66 | 67 | # hack for bincounting 2 arrays together 68 | x = predicted + self.num_classes * target 69 | bincount_2d = np.bincount( 70 | x.astype(np.int32), minlength=self.num_classes**2) 71 | assert bincount_2d.size == self.num_classes**2 72 | conf = bincount_2d.reshape((self.num_classes, self.num_classes)) 73 | 74 | self.conf += conf 75 | 76 | def value(self): 77 | """ 78 | Returns: 79 | Confustion matrix of K rows and K columns, where rows corresponds 80 | to ground-truth targets and columns corresponds to predicted 81 | targets. 82 | """ 83 | if self.normalized: 84 | conf = self.conf.astype(np.float32) 85 | return conf / conf.sum(1).clip(min=1e-12)[:, None] 86 | else: 87 | return self.conf 88 | -------------------------------------------------------------------------------- /fsgan/utils/img_utils.py: -------------------------------------------------------------------------------- 1 | """ Image utilities. """ 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.utils 7 | import torchvision.transforms.functional as F 8 | 9 | 10 | def rgb2tensor(img, normalize=True): 11 | """ Converts a RGB image to tensor. 12 | 13 | Args: 14 | img (np.array or list of np.array): RGB image of shape (H, W, 3) or a list of images 15 | normalize (bool): If True, the tensor will be normalized to the range [-1, 1] 16 | 17 | Returns: 18 | torch.Tensor or list of torch.Tensor: The converted image tensor or a list of converted tensors. 19 | """ 20 | if isinstance(img, (list, tuple)): 21 | return [rgb2tensor(o) for o in img] 22 | tensor = F.to_tensor(img) 23 | if normalize: 24 | tensor = F.normalize(tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 25 | 26 | return tensor.unsqueeze(0) 27 | 28 | 29 | def bgr2tensor(img, normalize=True): 30 | """ Converts a BGR image to tensor. 31 | 32 | Args: 33 | img (np.array or list of np.array): BGR image of shape (H, W, 3) or a list of images 34 | normalize (bool): If True, the tensor will be normalized to the range [-1, 1] 35 | 36 | Returns: 37 | torch.Tensor or list of torch.Tensor: The converted image tensor or a list of converted tensors. 
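# Minimal usage sketch for ConfusionMatrix above (the toy labels are made up;
# assumes this module is importable as fsgan.utils.confusionmatrix):
import torch
from fsgan.utils.confusionmatrix import ConfusionMatrix

cm = ConfusionMatrix(num_classes=3)
predicted = torch.tensor([0, 1, 2, 2, 1])  # predicted class per sample
target = torch.tensor([0, 1, 2, 1, 1])     # ground-truth class per sample
cm.add(predicted, target)
print(cm.value())  # 3x3 matrix: rows are ground-truth classes, columns are predictions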
38 | """ 39 | if isinstance(img, (list, tuple)): 40 | return [bgr2tensor(o, normalize) for o in img] 41 | return rgb2tensor(img[:, :, ::-1].copy(), normalize) 42 | 43 | 44 | def unnormalize(tensor, mean, std): 45 | """Normalize a tensor image with mean and standard deviation. 46 | 47 | See :class:`~torchvision.transforms.Normalize` for more details. 48 | 49 | Args: 50 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 51 | mean (sequence): Sequence of means for each channel. 52 | std (sequence): Sequence of standard deviations for each channely. 53 | 54 | Returns: 55 | Tensor: Normalized Tensor image. 56 | """ 57 | for t, m, s in zip(tensor, mean, std): 58 | t.mul_(s).add_(m) 59 | return tensor 60 | 61 | 62 | def tensor2rgb(img_tensor): 63 | """ Convert an image tensor to a numpy RGB image. 64 | 65 | Args: 66 | img_tensor (torch.Tensor): Tensor image of shape (3, H, W) 67 | 68 | Returns: 69 | np.array: RGB image of shape (H, W, 3) 70 | """ 71 | output_img = unnormalize(img_tensor.clone(), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 72 | output_img = output_img.squeeze().permute(1, 2, 0).cpu().numpy() 73 | output_img = np.round(output_img * 255).astype('uint8') 74 | 75 | return output_img 76 | 77 | 78 | def tensor2bgr(img_tensor): 79 | """ Convert an image tensor to a numpy BGR image. 80 | 81 | Args: 82 | img_tensor (torch.Tensor): Tensor image of shape (3, H, W) 83 | 84 | Returns: 85 | np.array: BGR image of shape (H, W, 3) 86 | """ 87 | output_img = tensor2rgb(img_tensor) 88 | output_img = output_img[:, :, ::-1] 89 | 90 | return output_img 91 | 92 | 93 | def make_grid(*args, cols=8): 94 | """ Create an image grid from a batch of images. 95 | 96 | Args: 97 | *args: (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) 98 | or a list of images all of the same size 99 | cols: The maximum number of columns in the grid 100 | 101 | Returns: 102 | torch.Tensor: The grid of images. 103 | """ 104 | assert len(args) > 0, 'At least one input tensor must be given!' 105 | imgs = torch.cat([a.cpu() for a in args], dim=2) 106 | 107 | return torchvision.utils.make_grid(imgs, nrow=cols, normalize=True, scale_each=False) 108 | 109 | 110 | def create_pyramid(img, n=1): 111 | """ Create an image pyramid. 112 | 113 | Args: 114 | img (torch.Tensor): An image tensor of shape (B, C, H, W) 115 | n (int): The number of pyramids to create 116 | 117 | Returns: 118 | list of torch.Tensor: The computed image pyramid. 119 | """ 120 | # If input is a list or tuple return it as it is (probably already a pyramid) 121 | if isinstance(img, (list, tuple)): 122 | return img 123 | 124 | pyd = [img] 125 | for i in range(n - 1): 126 | pyd.append(nn.functional.avg_pool2d(pyd[-1], 3, stride=2, padding=1, count_include_pad=False)) 127 | 128 | return pyd 129 | -------------------------------------------------------------------------------- /fsgan/utils/iou_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from fsgan.utils.confusionmatrix import ConfusionMatrix 4 | 5 | 6 | # Adapted from: https://github.com/davidtvs/PyTorch-ENet/blob/master/metric/iou.py 7 | class IOUMetric(object): 8 | """Computes the intersection over union (IoU) per class and corresponding 9 | mean (mIoU). 10 | Intersection over union (IoU) is a common evaluation metric for semantic 11 | segmentation. 
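# Minimal round-trip sketch for rgb2tensor / tensor2rgb above (the random image is made up;
# assumes this module is importable as fsgan.utils.img_utils):
import numpy as np
from fsgan.utils.img_utils import rgb2tensor, tensor2rgb

img = (np.random.rand(128, 128, 3) * 255).astype('uint8')  # HxWx3 RGB image
tensor = rgb2tensor(img)          # -> (1, 3, 128, 128), normalized to [-1, 1]
recovered = tensor2rgb(tensor)    # back to HxWx3 uint8 in [0, 255]
print(tensor.shape, recovered.shape, np.abs(recovered.astype(int) - img.astype(int)).max())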
The predictions are first accumulated in a confusion matrix 12 | and the IoU is computed from it as follows: 13 | IoU = true_positive / (true_positive + false_positive + false_negative). 14 | Keyword arguments: 15 | - num_classes (int): number of classes in the classification problem 16 | - normalized (boolean, optional): Determines whether or not the confusion 17 | matrix is normalized or not. Default: False. 18 | - ignore_index (int or iterable, optional): Index of the classes to ignore 19 | when computing the IoU. Can be an int, or any iterable of ints. 20 | """ 21 | def __init__(self, num_classes, normalized=False, ignore_index=None): 22 | super().__init__() 23 | self.conf_metric = ConfusionMatrix(num_classes, normalized) 24 | 25 | if ignore_index is None: 26 | self.ignore_index = None 27 | elif isinstance(ignore_index, int): 28 | self.ignore_index = (ignore_index,) 29 | else: 30 | try: 31 | self.ignore_index = tuple(ignore_index) 32 | except TypeError: 33 | raise ValueError("'ignore_index' must be an int or iterable") 34 | 35 | def reset(self): 36 | self.conf_metric.reset() 37 | 38 | def add(self, predicted, target): 39 | """Adds the predicted and target pair to the IoU metric. 40 | Keyword arguments: 41 | - predicted (Tensor): Can be a (N, K, H, W) tensor of 42 | predicted scores obtained from the model for N examples and K classes, 43 | or (N, H, W) tensor of integer values between 0 and K-1. 44 | - target (Tensor): Can be a (N, K, H, W) tensor of 45 | target scores for N examples and K classes, or (N, H, W) tensor of 46 | integer values between 0 and K-1. 47 | """ 48 | # Dimensions check 49 | assert predicted.size(0) == target.size(0), \ 50 | 'number of targets and predicted outputs do not match' 51 | assert predicted.dim() == 3 or predicted.dim() == 4, \ 52 | "predictions must be of dimension (N, H, W) or (N, K, H, W)" 53 | assert target.dim() == 3 or target.dim() == 4, \ 54 | "targets must be of dimension (N, H, W) or (N, K, H, W)" 55 | 56 | # If the tensor is in categorical format convert it to integer format 57 | if predicted.dim() == 4: 58 | _, predicted = predicted.max(1) 59 | if target.dim() == 4: 60 | _, target = target.max(1) 61 | 62 | self.conf_metric.add(predicted.view(-1), target.view(-1)) 63 | 64 | def value(self): 65 | """Computes the IoU and mean IoU. 66 | The mean computation ignores NaN elements of the IoU array. 67 | Returns: 68 | Tuple: (IoU, mIoU). The first output is the per class IoU, 69 | for K classes it's numpy.ndarray with K elements. The second output, 70 | is the mean IoU. 
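# Minimal usage sketch for IOUMetric above (the toy segmentation maps are made up;
# assumes this module is importable as fsgan.utils.iou_metric):
import torch
from fsgan.utils.iou_metric import IOUMetric

metric = IOUMetric(num_classes=2)
predicted = torch.zeros(1, 4, 4, dtype=torch.long)
predicted[0, :2, :] = 1                 # predicted foreground: top half
target = torch.zeros(1, 4, 4, dtype=torch.long)
target[0, :, :2] = 1                    # ground-truth foreground: left half
metric.add(predicted, target)
iou, miou = metric.value()
print(iou, miou)  # per-class IoU and its NaN-ignoring mean; foreground IoU here is 4/12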
71 | """ 72 | conf_matrix = self.conf_metric.value() 73 | if self.ignore_index is not None: 74 | for index in self.ignore_index: 75 | conf_matrix[:, self.ignore_index] = 0 76 | conf_matrix[self.ignore_index, :] = 0 77 | true_positive = np.diag(conf_matrix) 78 | false_positive = np.sum(conf_matrix, 0) - true_positive 79 | false_negative = np.sum(conf_matrix, 1) - true_positive 80 | 81 | # Just in case we get a division by 0, ignore/hide the error 82 | with np.errstate(divide='ignore', invalid='ignore'): 83 | iou = true_positive / (true_positive + false_positive + false_negative) 84 | 85 | return iou, np.nanmean(iou) 86 | -------------------------------------------------------------------------------- /fsgan/utils/obj_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | from functools import partial 4 | 5 | 6 | KNOWN_MODULES = { 7 | # datasets 8 | 'image_list_dataset': 'fsgan.datasets.image_list_dataset', 9 | 'opencv_video_seq_dataset': 'fsgan.datasets.opencv_video_seq_dataset', 10 | 'seq_dataset': 'fsgan.datasets.seq_dataset', 11 | 'img_landmarks_transforms': 'fsgan.datasets.img_landmarks_transforms', 12 | 'img_lms_pose_transforms': 'fsgan.datasets.img_lms_pose_transforms', 13 | 'transforms': 'torchvision.transforms', 14 | 15 | # models 16 | 'res_unet': 'fsgan.models.res_unet', 17 | 'res_unet_split': 'fsgan.models.res_unet_split', 18 | 'res_unet_msba': 'fsgan.models.res_unet_msba', 19 | 20 | # criterions 21 | 'vgg_loss': 'fsgan.criterions.vgg_loss', 22 | 'gan_loss': 'fsgan.criterions.gan_loss', 23 | 24 | # Torch 25 | 'nn': 'torch.nn', 26 | 'optim': 'torch.optim', 27 | 'lr_scheduler': 'torch.optim.lr_scheduler', 28 | } 29 | 30 | 31 | def extract_args(*args, **kwargs): 32 | return args, kwargs 33 | 34 | 35 | def obj_factory(obj_exp, *args, **kwargs): 36 | """ Creates objects from strings or partial objects with additional provided arguments. 37 | 38 | In case a sequence is provided, all objects in the sequence will be created recursively. 39 | Objects that are not strings or partials be returned as they are. 40 | 41 | Args: 42 | obj_exp (str or partial): The object string expresion or partial to be converted into an object. 
Can also be 43 | a sequence of object expressions 44 | *args: Additional arguments to pass to the object 45 | **kwargs: Additional keyword arguments to pass to the object 46 | 47 | Returns: 48 | object or object list: Created object or list of recursively created objects 49 | """ 50 | if isinstance(obj_exp, (list, tuple)): 51 | return [obj_factory(o, *args, **kwargs) for o in obj_exp] 52 | if isinstance(obj_exp, partial): 53 | return obj_exp(*args, **kwargs) 54 | if not isinstance(obj_exp, str): 55 | return obj_exp 56 | 57 | # Handle arguments 58 | if '(' in obj_exp and ')' in obj_exp: 59 | args_exp = obj_exp[obj_exp.find('('):] 60 | obj_args, obj_kwargs = eval('extract_args' + args_exp) 61 | 62 | # Concatenate arguments 63 | args = obj_args + args 64 | kwargs.update(obj_kwargs) 65 | 66 | obj_exp = obj_exp[:obj_exp.find('(')] 67 | 68 | # From here we can assume that dots in the remaining of the expression 69 | # only separate between modules and classes 70 | module_name, class_name = os.path.splitext(obj_exp) 71 | class_name = class_name[1:] 72 | module = importlib.import_module(KNOWN_MODULES[module_name] if module_name in KNOWN_MODULES else module_name) 73 | module_class = getattr(module, class_name) 74 | class_instance = module_class(*args, **kwargs) 75 | 76 | return class_instance 77 | 78 | 79 | def partial_obj_factory(obj_exp, *args, **kwargs): 80 | """ Creates objects from strings or partial objects with additional provided arguments. 81 | 82 | In case a sequence is provided, all objects in the sequence will be created recursively. 83 | Objects that are not strings or partials be returned as they are. 84 | 85 | Args: 86 | obj_exp (str or partial): The object string expresion or partial to be converted into an object. Can also be 87 | a sequence of object expressions 88 | *args: Additional arguments to pass to the object 89 | **kwargs: Additional keyword arguments to pass to the object 90 | 91 | Returns: 92 | object or object list: Created object or list of recursively created objects 93 | """ 94 | if isinstance(obj_exp, (list, tuple)): 95 | return [partial_obj_factory(o, *args, **kwargs) for o in obj_exp] 96 | if isinstance(obj_exp, partial): 97 | return partial(obj_exp.func, *(obj_exp.args + args), **{**obj_exp.keywords, **kwargs}) 98 | if not isinstance(obj_exp, str): 99 | return partial(obj_exp) 100 | 101 | # Handle arguments 102 | if '(' in obj_exp and ')' in obj_exp: 103 | args_exp = obj_exp[obj_exp.find('('):] 104 | obj_args, obj_kwargs = eval('extract_args' + args_exp) 105 | 106 | # Concatenate arguments 107 | args = obj_args + args 108 | kwargs.update(obj_kwargs) 109 | 110 | obj_exp = obj_exp[:obj_exp.find('(')] 111 | 112 | # From here we can assume that dots in the remaining of the expression 113 | # only separate between modules and classes 114 | module_name, class_name = os.path.splitext(obj_exp) 115 | class_name = class_name[1:] 116 | module = importlib.import_module(KNOWN_MODULES[module_name] if module_name in KNOWN_MODULES else module_name) 117 | module_class = getattr(module, class_name) 118 | 119 | return partial(module_class, *args, **kwargs) 120 | 121 | 122 | def main(obj_exp): 123 | # obj = obj_factory(obj_exp) 124 | # print(obj) 125 | 126 | import inspect 127 | partial_obj = partial_obj_factory(obj_exp) 128 | print(f'is obj_exp a class = {inspect.isclass(partial_obj.func)}') 129 | print(partial_obj) 130 | 131 | 132 | if __name__ == "__main__": 133 | # Parse program arguments 134 | import argparse 135 | parser = argparse.ArgumentParser('utils test') 136 | 
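# Minimal usage sketch for obj_factory / partial_obj_factory above (the expressions below are
# illustrative; 'nn' and 'optim' are aliases from KNOWN_MODULES, and the module is assumed
# importable as fsgan.utils.obj_factory):
import torch
from fsgan.utils.obj_factory import obj_factory, partial_obj_factory

act = obj_factory('nn.LeakyReLU(negative_slope=0.2)')        # -> torch.nn.LeakyReLU instance
model = torch.nn.Linear(4, 2)
opt = obj_factory('optim.SGD(lr=0.1)', model.parameters())   # extra positional args are appended
partial_opt = partial_obj_factory('optim.SGD(lr=0.1)')        # partial; call later with the params
print(act, opt, partial_opt)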
parser.add_argument('obj_exp', help='object string') 137 | args = parser.parse_args() 138 | 139 | main(args.obj_exp) 140 | -------------------------------------------------------------------------------- /fsgan/utils/one_euro_filter.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | 5 | def smoothing_factor(t_e, cutoff): 6 | r = 2 * math.pi * cutoff * t_e 7 | return r / (r + 1) 8 | 9 | 10 | def exponential_smoothing(a, x, x_prev): 11 | return a * x + (1 - a) * x_prev 12 | 13 | 14 | class OneEuroFilter: 15 | def __init__(self, min_cutoff=1.0, beta=0.0, d_cutoff=1.0, t_e=33.333): 16 | """Initialize the one euro filter.""" 17 | # The parameters. 18 | self.min_cutoff = float(min_cutoff) 19 | self.beta = float(beta) 20 | self.d_cutoff = float(d_cutoff) 21 | self.t_e = t_e 22 | self.a_d = smoothing_factor(self.t_e, self.d_cutoff) 23 | self.x_prev = self.dx_prev = None 24 | 25 | def reset(self): 26 | self.x_prev = self.dx_prev = None 27 | 28 | def __call__(self, x): 29 | """Compute the filtered signal.""" 30 | if self.x_prev is None: 31 | self.x_prev = x 32 | self.dx_prev = 0.0 33 | return x, 0. 34 | 35 | # The filtered derivative of the signal. 36 | # dx = (x - self.x_prev) / self.t_e 37 | dx = np.linalg.norm(x - self.x_prev) / self.t_e 38 | dx_hat = exponential_smoothing(self.a_d, dx, self.dx_prev) 39 | 40 | # The filtered signal. 41 | cutoff = self.min_cutoff + self.beta * abs(dx_hat) 42 | a = smoothing_factor(self.t_e, cutoff) 43 | x_hat = exponential_smoothing(a, x, self.x_prev) 44 | 45 | # Memorize the previous values. 46 | self.x_prev = x_hat 47 | self.dx_prev = dx_hat 48 | 49 | return x_hat, a 50 | -------------------------------------------------------------------------------- /fsgan/utils/seg_utils.py: -------------------------------------------------------------------------------- 1 | """ Face segmentation utilities. """ 2 | 3 | import io 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import cv2 9 | from PIL import Image 10 | 11 | 12 | def blend_seg_pred(img, seg, alpha=0.5): 13 | """ Blend images with their corresponding segmentation prediction. 14 | 15 | Args: 16 | img (torch.Tensor): A batch of image tensors of shape (B, 3, H, W) where B is the batch size, 17 | H is the images height and W is the images width 18 | seg (torch.Tensor): A batch of segmentation predictions of shape (B, C, H, W) where B is the batch size, 19 | C is the number of segmentation classes, H is the images height and W is the images width 20 | alpha: alpha (float): Opacity value for the segmentation in the range [0, 1] where 0 is completely transparent 21 | and 1 is completely opaque 22 | 23 | Returns: 24 | torch.Tensor: The blended image. 25 | """ 26 | pred = seg.argmax(1) 27 | pred = pred.view(pred.shape[0], 1, pred.shape[1], pred.shape[2]).repeat(1, 3, 1, 1) 28 | blend = img 29 | 30 | # For each segmentation class except the background (label 0) 31 | for i in range(1, seg.shape[1]): 32 | color_mask = -torch.ones_like(img) 33 | color_mask[:, -i, :, :] = 1 34 | alpha_mask = 1 - (pred == i).float() * alpha 35 | blend = blend * alpha_mask + color_mask * (1 - alpha_mask) 36 | 37 | return blend 38 | 39 | 40 | def blend_seg_label(img, seg, alpha=0.5): 41 | """ Blend images with their corresponding segmentation labels. 
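# Minimal usage sketch for OneEuroFilter above (the noisy landmark stream is made up;
# assumes this module is importable as fsgan.utils.one_euro_filter):
import numpy as np
from fsgan.utils.one_euro_filter import OneEuroFilter

filt = OneEuroFilter(min_cutoff=1.0, beta=0.0, t_e=33.333)  # t_e ~ frame period in ms
rng = np.random.RandomState(0)
for i in range(5):
    point = np.array([100.0, 50.0]) + rng.randn(2)  # noisy 2D point for this frame
    smoothed, alpha = filt(point)                   # filtered point and smoothing factor used
    print(i, smoothed, alpha)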
42 | 43 | Args: 44 | img (torch.Tensor): A batch of image tensors of shape (B, 3, H, W) where B is the batch size, 45 | H is the images height and W is the images width 46 | seg (torch.Tensor): A batch of segmentation labels of shape (B, H, W) where B is the batch size, 47 | H is the images height and W is the images width 48 | alpha: alpha (float): Opacity value for the segmentation in the range [0, 1] where 0 is completely transparent 49 | and 1 is completely opaque 50 | 51 | Returns: 52 | torch.Tensor: The blended image. 53 | """ 54 | pred = seg.unsqueeze(1).repeat(1, 3, 1, 1) 55 | blend = img 56 | 57 | # For each segmentation class except the background (label 0) 58 | for i in range(1, pred.shape[1]): 59 | color_mask = -torch.ones_like(img) 60 | color_mask[:, -i, :, :] = 1 61 | alpha_mask = 1 - (pred == i).float() * alpha 62 | blend = blend * alpha_mask + color_mask * (1 - alpha_mask) 63 | 64 | return blend 65 | 66 | 67 | # TODO: Move this somewhere else later 68 | def random_hair_inpainting_mask(face_mask): 69 | """ Simulate random hair occlusions on face mask. 70 | 71 | The algorithm works as follows: 72 | 1. The method first randomly choose a y coordinate of the face mask 73 | 2. x coordinate is chosen randomly: Either minimum or maximum x value of the selected line 74 | 3. A random ellipse is rendered with its center in (x, y) 75 | 4. The inpainting map is the intersection of the face mask with the ellipse. 76 | 77 | Args: 78 | face_mask (np.array): A binary mask tensor of shape (H, W) where `1` means face region 79 | and `0` means background 80 | 81 | Returns: 82 | np.array: Result mask. 83 | """ 84 | mask = face_mask == 1 85 | inpainting_mask = np.zeros(mask.shape, 'uint8') 86 | a = np.where(mask != 0) 87 | if len(a[0]) == 0 or len(a[1]) == 0: 88 | return inpainting_mask 89 | if (np.max(a[0]) - np.min(a[0])) <= 10 or (np.max(a[1]) - np.min(a[1])) <= 10: 90 | return inpainting_mask 91 | 92 | # Select a random point on the mask borders 93 | try: 94 | y_coords = np.unique(a[0]) 95 | y_ind = np.random.randint(len(y_coords)) 96 | y = y_coords[y_ind] 97 | x_ind = np.where(a[0] == y) 98 | x_coords = a[1][x_ind[0]] 99 | x = x_coords.min() if np.random.rand() > 0.5 else x_coords.max() 100 | except: 101 | print(y_coords) 102 | print(x_coords) 103 | 104 | # Draw inpainting shape 105 | width = a[1].max() - a[1].min() + 1 106 | # height = a[0].max() - a[0].min() + 1 107 | scale = (np.random.randint(width // 4, width // 2), np.random.randint(width // 4, width // 2)) 108 | rotation_angle = np.random.randint(0, 180) 109 | cv2.ellipse(inpainting_mask, (x, y), scale, rotation_angle, 0, 360, (255, 255, 255), -1, 8) 110 | 111 | # Get inpainting mask by intersection with face mask 112 | inpainting_mask *= mask 113 | inpainting_mask = inpainting_mask > 0 114 | 115 | ### Debug ### 116 | # cv2.imshow('face_mask', inpainting_mask) 117 | # cv2.waitKey(0) 118 | ############# 119 | 120 | return inpainting_mask 121 | 122 | 123 | def random_hair_inpainting_mask_tensor(face_mask): 124 | """ Simulate random hair occlusions on face mask. 125 | 126 | The algorithm works as follows: 127 | 1. The method first randomly choose a y coordinate of the face mask 128 | 2. x coordinate is chosen randomly: Either minimum or maximum x value of the selected line 129 | 3. A random ellipse is rendered with its center in (x, y) 130 | 4. The inpainting map is the intersection of the face mask with the ellipse. 
131 | 132 | Args: 133 | face_mask (torch.Tensor): A binary mask tensor of shape (B, H, W) where `1` means face region 134 | and `0` means background 135 | 136 | Returns: 137 | torch.Tensor: Result mask. 138 | """ 139 | out_tensors = [] 140 | for b in range(face_mask.shape[0]): 141 | curr_face_mask = face_mask[b] 142 | inpainting_mask = random_hair_inpainting_mask(curr_face_mask.cpu().numpy()) 143 | out_tensors.append(torch.from_numpy(inpainting_mask.astype(float)).unsqueeze(0)) 144 | 145 | return torch.cat(out_tensors, dim=0) 146 | 147 | 148 | # TODO: Remove this later 149 | def encode_segmentation(segmentation): 150 | seg_min, seg_max = segmentation.min(), segmentation.max() 151 | segmentation = segmentation.sub(seg_min).div_((seg_max - seg_min) * 0.5).sub_(1.0) 152 | 153 | return segmentation 154 | 155 | 156 | def encode_binary_mask(mask): 157 | """ Encode binary mask using binary PNG encoding. 158 | 159 | Args: 160 | mask (np.array): Binary mask of shape (H, W) 161 | 162 | Returns: 163 | bytes: Encoded binary mask. 164 | """ 165 | mask_pil = Image.fromarray(mask.astype('uint8') * 255, mode='L').convert('1') 166 | in_mem_file = io.BytesIO() 167 | mask_pil.save(in_mem_file, format='png') 168 | in_mem_file.seek(0) 169 | 170 | return in_mem_file.read() 171 | 172 | 173 | def decode_binary_mask(bytes): 174 | """ Decode an encoded binary mask. 175 | 176 | Args: 177 | bytes: Encoded binary mask of shape (H, W) 178 | 179 | Returns: 180 | np.array: Decoded binary mask. 181 | """ 182 | return np.array(Image.open(io.BytesIO(bytes))) 183 | # return np.array(Image.open(io.BytesIO(bytes)).convert('L')) 184 | 185 | 186 | class SoftErosion(nn.Module): 187 | """ Applies *soft erosion* on a binary mask, similar to the 188 | erosion morphology operation, 189 | returning both a soft mask and a hard binary mask. 190 | 191 | All values greater than or equal to the specified threshold will be set to 1 in both the soft and hard masks, 192 | the other values will be 0 in the hard mask and will be gradually reduced to 0 in the soft mask. 193 | 194 | Args: 195 | kernel_size (int): The size of the erosion kernel 196 | threshold (float): The erosion threshold 197 | iterations (int): The number of times to apply the erosion kernel 198 | """ 199 | def __init__(self, kernel_size=15, threshold=0.6, iterations=1): 200 | super(SoftErosion, self).__init__() 201 | r = kernel_size // 2 202 | self.padding = r 203 | self.iterations = iterations 204 | self.threshold = threshold 205 | 206 | # Create kernel 207 | y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size)) 208 | dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2) 209 | kernel = dist.max() - dist 210 | kernel /= kernel.sum() 211 | kernel = kernel.view(1, 1, *kernel.shape) 212 | self.register_buffer('weight', kernel) 213 | 214 | def forward(self, x): 215 | """ Apply the soft erosion operation.
216 | 217 | Args: 218 | x (torch.Tensor): A binary mask of shape (1, H, W) 219 | 220 | Returns: 221 | (torch.Tensor, torch.Tensor): Tuple containing: 222 | - soft_mask (torch.Tensor): The soft mask of shape (1, H, W) 223 | - hard_mask (torch.Tensor): The hard mask of shape (1, H, W) 224 | """ 225 | x = x.float() 226 | for i in range(self.iterations - 1): 227 | x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)) 228 | x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding) 229 | 230 | mask = x >= self.threshold 231 | x[mask] = 1.0 232 | x[~mask] /= x[~mask].max() 233 | 234 | return x, mask 235 | 236 | 237 | def remove_inner_mouth(seg, landmarks): 238 | """ Removes the inner part of the mouth, corresponding to the face landmarks, from a binary mask. 239 | 240 | Args: 241 | seg (np.array): A binary mask of shape (H, W) 242 | landmarks (np.array): Face landmarks of shape (98, 2) 243 | 244 | Returns: 245 | np.array: The binary mask with the inner part of the mouth removed. 246 | """ 247 | size = np.array(seg.shape[::-1]) 248 | mouth_pts = landmarks[88:96] * size 249 | mouth_pts = np.round(mouth_pts).astype(int) 250 | out_seg = cv2.fillPoly(seg.astype('uint8'), [mouth_pts], (0, 0, 0)) 251 | 252 | return out_seg.astype(seg.dtype) 253 | 254 | 255 | def main(input_path): 256 | from PIL import Image 257 | seg = np.array(Image.open(input_path)) 258 | # while True: 259 | # random_hair_inpainting_mask(seg) 260 | 261 | 262 | if __name__ == "__main__": 263 | # Parse program arguments 264 | import argparse 265 | 266 | parser = argparse.ArgumentParser('seg_utils') 267 | parser.add_argument('input', help='input path') 268 | args = parser.parse_args() 269 | main(args.input) -------------------------------------------------------------------------------- /fsgan/utils/set_checkpoint_arch.py: -------------------------------------------------------------------------------- 1 | """ Utility script for overriding the architecture of saved checkpoints. """ 2 | 3 | import argparse 4 | import torch 5 | 6 | 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__) 8 | parser.add_argument('checkpoint_path', metavar='PATH', 9 | help='path to checkpoint file') 10 | parser.add_argument('-a', '--arch', 11 | help='network architecture') 12 | parser.add_argument('-o', '--output', metavar='PATH', 13 | help='output checkpoint path') 14 | parser.add_argument('--override', action='store_true', 15 | help='override existing architecture') 16 | 17 | 18 | def main(checkpoint_path, arch, output=None, override=False): 19 | output = checkpoint_path if output is None else output 20 | checkpoint = torch.load(checkpoint_path) 21 | if 'arch' in checkpoint and not override: 22 | print('checkpoint already contains "arch": ' + checkpoint['arch']) 23 | return 24 | print('Setting checkpoint\'s arch: ' + arch) 25 | checkpoint['arch'] = arch 26 | print('Writing checkpoint to path: ' + output) 27 | torch.save(checkpoint, output) 28 | 29 | 30 | if __name__ == "__main__": 31 | main(**vars(parser.parse_args())) 32 | -------------------------------------------------------------------------------- /fsgan/utils/temporal_smoothing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class TemporalSmoothing(nn.Module): 7 | """ Apply temporal smoothing kernel on the batch dimension of a 4d tensor.
8 | 9 | Filtering is performed separately for each channel in the input using a depthwise convolution. 10 | 11 | Args: 12 | channels (int): Number of channels of the input tensors. Output will 13 | have this number of channels as well 14 | kernel_size (int): Size of the average kernel 15 | """ 16 | def __init__(self, channels, kernel_size=5): 17 | super(TemporalSmoothing, self).__init__() 18 | self.kernel_size = kernel_size 19 | self.kernel_radius = kernel_size // 2 20 | self.groups = channels 21 | 22 | # Create kernel 23 | kernel = torch.ones(channels, 1, kernel_size, 1) 24 | kernel.div_(kernel_size) 25 | self.register_buffer('weight', kernel) 26 | 27 | def forward(self, x, pad_prev=0, pad_next=0): 28 | """ Apply temporal smoothing to x. 29 | 30 | Args: 31 | x (torch.Tensor): Input to apply temporal smoothing on 32 | pad_prev (int): The amount of reflection padding from the left side of the batch dimension 33 | pad_next (int): The amount of reflection padding from the right side of the batch dimension 34 | 35 | Returns: 36 | torch.Tensor: Filtered output. 37 | """ 38 | orig_shape = x.shape 39 | 40 | # Transform tensor for temporal filtering 41 | x = x.permute(1, 0, 2, 3) 42 | x = x.view(1, x.shape[0], x.shape[1], -1) 43 | if pad_prev > 0 or pad_next > 0: 44 | x = F.pad(x, (0, 0, pad_prev, pad_next), 'reflect') 45 | 46 | # Apply temporal convolution 47 | x = F.conv2d(x, self.weight, groups=self.groups) 48 | 49 | # Transform tensor back to original shape 50 | x = x.permute(0, 2, 1, 3) 51 | x = x.view((x.shape[1],) + orig_shape[1:]) 52 | 53 | return x 54 | 55 | 56 | def smooth_temporal(x, kernel_size=5, pad_prev=0, pad_next=0): 57 | """ Apply dynamic temporal smoothing kernel on the batch dimension of a 4d tensor. 58 | 59 | Filtering is performed separately for each channel in the input using a depthwise convolution. 60 | 61 | Args: 62 | x (torch.Tensor): Input to apply temporal smoothing on 63 | kernel_size (int): Size of the average kernel 64 | pad_prev (int): The amount of reflection padding from the left side of the batch dimension 65 | pad_next (int): The amount of reflection padding from the right side of the batch dimension 66 | 67 | Returns: 68 | torch.Tensor: Filtered output. 69 | """ 70 | orig_shape = x.shape 71 | 72 | # Create kernel 73 | kernel = torch.ones(x.shape[1], 1, kernel_size, 1).to(x.device) 74 | kernel.div_(kernel_size) 75 | 76 | # Transform tensor for temporal filtering 77 | x = x.permute(1, 0, 2, 3) 78 | x = x.view(1, x.shape[0], x.shape[1], -1) 79 | if pad_prev > 0 or pad_next > 0: 80 | x = F.pad(x, (0, 0, pad_prev, pad_next), 'reflect') 81 | 82 | # Apply temporal convolution 83 | x = F.conv2d(x, kernel, groups=x.shape[1]) 84 | 85 | # Transform tensor back to original shape 86 | x = x.permute(0, 2, 1, 3) 87 | x = x.view((x.shape[1],) + orig_shape[1:]) 88 | 89 | return x 90 | -------------------------------------------------------------------------------- /fsgan/utils/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | # from torch.utils.tensorboard import SummaryWriter 2 | from tensorboardX import SummaryWriter 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | """ Resets the statistics of previous values. """ 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | """ Add new value. 
20 | 21 | Args: 22 | val (float): Value to add 23 | n (int): Count the value n times. Default is 1 24 | """ 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | class TensorBoardLogger(SummaryWriter): 32 | """ Writes entries directly to event files in the logdir to be consumed by TensorBoard. 33 | 34 | The logger keeps track of scalar values, allowing to easily log either the last value or the average value. 35 | 36 | Args: 37 | log_dir (str): The directory in which the log files will be written to 38 | """ 39 | 40 | def __init__(self, log_dir=None): 41 | super(TensorBoardLogger, self).__init__(log_dir) 42 | self.__tb_logger = SummaryWriter(log_dir) if log_dir is not None else None 43 | self.log_dict = {} 44 | 45 | def reset(self, prefix=None): 46 | """ Resets all saved scalars and description prefix. 47 | 48 | Args: 49 | prefix (str, optional): The logger's prefix description used when printing the logger status 50 | """ 51 | self.prefix = prefix 52 | self.log_dict.clear() 53 | 54 | def update(self, category='losses', **kwargs): 55 | """ Add named scalar values to the logger. If a scalar with the same name already exists, the new value will 56 | be associated with it. 57 | 58 | Args: 59 | category (str): The scalar category that will be concatenated with the main tag 60 | **kwargs: Named scalar values to be added to the logger. 61 | """ 62 | if category not in self.log_dict: 63 | self.log_dict[category] = {} 64 | category_dict = self.log_dict[category] 65 | for key, val in kwargs.items(): 66 | if key not in category_dict: 67 | category_dict[key] = AverageMeter() 68 | category_dict[key].update(val) 69 | 70 | def log_scalars_val(self, main_tag, global_step=None): 71 | """ Log the last value of all scalars. 72 | 73 | Args: 74 | main_tag (str): The parent name for the tags 75 | global_step (int, optional): Global step value to record 76 | """ 77 | if self.__tb_logger is not None: 78 | for category, category_dict in self.log_dict.items(): 79 | val_dict = {k: v.val for k, v in category_dict.items()} 80 | self.__tb_logger.add_scalars(main_tag + '/' + category, val_dict, global_step) 81 | 82 | def log_scalars_avg(self, main_tag, global_step=None): 83 | """ Log the average value of all scalars. 84 | 85 | Args: 86 | main_tag (str): The parent name for the tags 87 | global_step (int, optional): Global step value to record 88 | """ 89 | if self.__tb_logger is not None: 90 | for category, category_dict in self.log_dict.items(): 91 | val_dict = {k: v.avg for k, v in category_dict.items()} 92 | self.__tb_logger.add_scalars(main_tag + '/' + category, val_dict, global_step) 93 | 94 | def log_image(self, tag, img_tensor, global_step=None): 95 | """ Add an image tensor to the log. 
96 | 97 | Args: 98 | tag (str): Name identifier for the image 99 | img_tensor (torch.Tensor): The image tensor to log 100 | global_step (int, optional): Global step value to record 101 | """ 102 | if self.__tb_logger is not None: 103 | self.__tb_logger.add_image(tag, img_tensor, global_step) 104 | 105 | def __str__(self): 106 | desc = '' if self.prefix is None else self.prefix 107 | for category, category_dict in self.log_dict.items(): 108 | desc += '{}: ['.format(category) 109 | for key, log in category_dict.items(): 110 | desc += '{}: {:.4f} ({:.4f}); '.format(key, log.val, log.avg) 111 | desc += '] ' 112 | 113 | return desc 114 | -------------------------------------------------------------------------------- /fsgan/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ General utilities. """ 2 | 3 | import os 4 | import shutil 5 | from functools import partial 6 | import torch 7 | import random 8 | import warnings 9 | import requests 10 | from tqdm import tqdm 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn.init as init 13 | from fsgan.utils.obj_factory import extract_args, obj_factory 14 | 15 | 16 | def init_weights(m, init_type='normal', gain=0.02): 17 | """ Randomly initialize a module's weights. 18 | 19 | Args: 20 | m (nn.Module): The module to initialize its weights 21 | init_type (str): Initialization type: 'normal', 'xavier', 'kaiming', or 'orthogonal' 22 | gain (float): Standard deviation of the normal distribution 23 | """ 24 | classname = m.__class__.__name__ 25 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 26 | if init_type == 'normal': 27 | init.normal_(m.weight.data, 0.0, gain) 28 | elif init_type == 'xavier': 29 | init.xavier_normal_(m.weight.data, gain=gain) 30 | elif init_type == 'kaiming': 31 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 32 | elif init_type == 'orthogonal': 33 | init.orthogonal_(m.weight.data, gain=gain) 34 | else: 35 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 36 | if hasattr(m, 'bias') and m.bias is not None: 37 | init.constant_(m.bias.data, 0.0) 38 | elif classname.find('BatchNorm2d') != -1 or classname.find('BatchNorm3d') != -1: 39 | init.normal_(m.weight.data, 1.0, gain) 40 | init.constant_(m.bias.data, 0.0) 41 | 42 | 43 | def set_device(gpus=None, use_cuda=True): 44 | """ Sets computing device. Either the CPU or any of the available GPUs. 45 | 46 | Args: 47 | gpus (list of int, optional): The GPU ids to use. If not specified, all available GPUs will be used 48 | use_cuda (bool, optional): If True, CUDA enabled GPUs will be used, else the CPU will be used 49 | 50 | Returns: 51 | torch.device: The selected computing device. 52 | """ 53 | use_cuda = torch.cuda.is_available() if use_cuda else use_cuda 54 | if use_cuda: 55 | gpus = list(range(torch.cuda.device_count())) if not gpus else gpus 56 | print('=> using GPU devices: {}'.format(', '.join(map(str, gpus)))) 57 | else: 58 | gpus = None 59 | print('=> using CPU device') 60 | device = torch.device('cuda:{}'.format(gpus[0])) if gpus else torch.device('cpu') 61 | 62 | return device, gpus 63 | 64 | 65 | def set_seed(seed): 66 | """ Sets computing device. Either the CPU or any of the available GPUs. 67 | 68 | Args: 69 | gpus (list of int, optional): The GPU ids to use. 
If not specified, all available GPUs will be used 70 | use_cuda (bool, optional): If True, CUDA enabled GPUs will be used, else the CPU will be used 71 | 72 | Returns: 73 | torch.device: The selected computing device. 74 | """ 75 | if seed is not None: 76 | random.seed(seed) 77 | torch.manual_seed(seed) 78 | cudnn.deterministic = True 79 | warnings.warn('You have chosen to seed training. ' 80 | 'This will turn on the CUDNN deterministic setting, ' 81 | 'which can slow down your training considerably! ' 82 | 'You may see unexpected behavior when restarting ' 83 | 'from checkpoints.') 84 | 85 | 86 | def save_checkpoint(exp_dir, base_name, state, is_best=False): 87 | """ Saves a model's checkpoint. 88 | 89 | Args: 90 | exp_dir (str): Experiment directory to save the checkpoint into. 91 | base_name (str): The output file name will be _latest.pth and optionally _best.pth 92 | state (dict): The model state to save. 93 | is_best (bool): If True, _best.pth will be saved as well. 94 | """ 95 | filename = os.path.join(exp_dir, base_name + '_latest.pth') 96 | torch.save(state, filename) 97 | if is_best: 98 | shutil.copyfile(filename, os.path.join(exp_dir, base_name + '_best.pth')) 99 | 100 | 101 | mag_map = {'K': 3, 'M': 6, 'B': 9} 102 | 103 | 104 | def str2int(s): 105 | """ Converts a string containing a number with 'K', 'M', or 'B' to an integer. """ 106 | if isinstance(s, (list, tuple)): 107 | return [str2int(o) for o in s] 108 | if not isinstance(s, str): 109 | return s 110 | return int(float(s[:-1]) * 10 ** mag_map[s[-1].upper()]) if s[-1].upper() in mag_map else int(s) 111 | 112 | 113 | def get_arch(obj, *args, **kwargs): 114 | """ Extract the architecture (string representation) of an object given as a string or partial together 115 | with additional provided arguments. 116 | 117 | The returned architecture can be used to create the object using the obj_factory function. 118 | 119 | Args: 120 | obj (str or partial): The object string expresion or partial to be converted into an object 121 | *args: Additional arguments to pass to the object 122 | **kwargs: Additional keyword arguments to pass to the object 123 | 124 | Returns: 125 | arch (str): The object's architecture (string representation). 126 | """ 127 | obj_args, obj_kwargs = [], {} 128 | if isinstance(obj, str): 129 | if '(' in obj and ')' in obj: 130 | arg_pos = obj.find('(') 131 | func = obj[:arg_pos] 132 | args_exp = obj[arg_pos:] 133 | obj_args, obj_kwargs = eval('extract_args' + args_exp) 134 | else: 135 | func = obj 136 | elif isinstance(obj, partial): 137 | func = obj.func.__module__ + '.' + obj.func.__name__ 138 | obj_args, obj_kwargs = obj.args, obj.keywords 139 | else: 140 | return None 141 | 142 | # Concatenate arguments 143 | obj_args = obj_args + args 144 | obj_kwargs.update(kwargs) 145 | 146 | # Convert object components to string representation 147 | args = ",".join(map(repr, obj_args)) 148 | kwargs = ",".join("{}={!r}".format(k, v) for k, v in obj_kwargs.items()) 149 | comma = ',' if args != '' and kwargs != '' else '' 150 | format_string = '{func}({args}{comma}{kwargs})' 151 | arch = format_string.format(func=func, args=args, comma=comma, kwargs=kwargs).replace(' ', '') 152 | 153 | return arch 154 | 155 | 156 | def load_model(model_path, name='', device=None, arch=None, return_checkpoint=False, train=False): 157 | """ Load a model from checkpoint. 
158 | 159 | This is a utility function that combines the model weights and architecture (string representation) to easily 160 | load any model without explicit knowledge of its class. 161 | 162 | Args: 163 | model_path (str): Path to the model's checkpoint (.pth) 164 | name (str): The name of the model (for printing and error management) 165 | device (torch.device): The device to load the model to 166 | arch (str): The model's architecture (string representation) 167 | return_checkpoint (bool): If True, the checkpoint will be returned as well 168 | train (bool): If True, the model will be set to train mode, else it will be set to test mode 169 | 170 | Returns: 171 | (nn.Module, dict (optional)): A tuple that contains: 172 | - model (nn.Module): The loaded model 173 | - checkpoint (dict, optional): The model's checkpoint (only if return_checkpoint is True) 174 | """ 175 | assert model_path is not None, '%s model must be specified!' % name 176 | assert os.path.exists(model_path), 'Couldn\'t find %s model in path: %s' % (name, model_path) 177 | print('=> Loading %s model: "%s"...' % (name, os.path.basename(model_path))) 178 | checkpoint = torch.load(model_path) 179 | assert arch is not None or 'arch' in checkpoint, 'Couldn\'t determine %s model architecture!' % name 180 | arch = checkpoint['arch'] if arch is None else arch 181 | model = obj_factory(arch) 182 | if device is not None: 183 | model.to(device) 184 | model.load_state_dict(checkpoint['state_dict']) 185 | model.train(train) 186 | 187 | if return_checkpoint: 188 | return model, checkpoint 189 | else: 190 | return model 191 | 192 | 193 | def random_pair(n, min_dist=1, index1=None): 194 | """ Return a random pair of integers in the range [0, n) with a minimum distance between them. 195 | 196 | Args: 197 | n (int): Determine the range size 198 | min_dist (int): The minimum distance between the random pair 199 | index1 (int, optional): If specified, this will determine the first integer 200 | 201 | Returns: 202 | (int, int): The random pair of integers. 203 | """ 204 | r1 = random.randint(0, n - 1) if index1 is None else index1 205 | d_left = min(r1, min_dist) 206 | d_right = min(n - 1 - r1, min_dist) 207 | r2 = random.randint(0, n - 2 - d_left - d_right) 208 | r2 = r2 + d_left + 1 + d_right if r2 >= (r1 - d_left) else r2 209 | 210 | return r1, r2 211 | 212 | 213 | def random_pair_range(a, b, min_dist=1, index1=None): 214 | """ Return a random pair of integers in the range [a, b] with a minimum distance between them. 215 | 216 | Args: 217 | a (int): The minimum number in the range 218 | b (int): The maximum number in the range 219 | min_dist (int): The minimum distance between the random pair 220 | index1 (int, optional): If specified, this will determine the first integer 221 | 222 | Returns: 223 | (int, int): The random pair of integers. 224 | """ 225 | r1 = random.randint(a, b) if index1 is None else index1 226 | d_left = min(r1 - a, min_dist) 227 | d_right = min(b - r1, min_dist) 228 | r2 = random.randint(a, b - 1 - d_left - d_right) 229 | r2 = r2 + d_left + 1 + d_right if r2 >= (r1 - d_left) else r2 230 | 231 | return r1, r2 232 | 233 | 234 | # Adapted from: https://github.com/Sudy/coling2018/blob/master/torchtext/utils.py 235 | def download_from_url(url, output_path): 236 | """ Download file from url including Google Drive. 
237 | 238 | Args: 239 | url (str): File URL 240 | output_path (str): Output path to write the file to 241 | """ 242 | def process_response(r): 243 | chunk_size = 16 * 1024 244 | total_size = int(r.headers.get('Content-length', 0)) 245 | with open(output_path, "wb") as file: 246 | with tqdm(total=total_size, unit='B', unit_scale=1, desc=os.path.split(output_path)[1]) as t: 247 | for chunk in r.iter_content(chunk_size): 248 | if chunk: 249 | file.write(chunk) 250 | t.update(len(chunk)) 251 | 252 | if 'drive.google.com' not in url: 253 | response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) 254 | process_response(response) 255 | return 256 | 257 | # print('downloading from Google Drive; may take a few minutes') 258 | confirm_token = None 259 | session = requests.Session() 260 | response = session.get(url, stream=True) 261 | for k, v in response.cookies.items(): 262 | if k.startswith("download_warning"): 263 | confirm_token = v 264 | 265 | if confirm_token: 266 | url = url + "&confirm=" + confirm_token 267 | response = session.get(url, stream=True) 268 | 269 | process_response(response) 270 | 271 | 272 | def main(): 273 | from torch.optim.lr_scheduler import StepLR 274 | scheduler = partial(StepLR, step_size=10, gamma=0.5) 275 | print(get_arch(scheduler)) 276 | scheduler = partial(StepLR, 10, 0.5) 277 | print(get_arch(scheduler)) 278 | scheduler = partial(StepLR, 10, gamma=0.5) 279 | print(get_arch(scheduler)) 280 | scheduler = partial(StepLR) 281 | print(get_arch(scheduler)) 282 | print(get_arch(scheduler, 10, gamma=0.5)) 283 | 284 | scheduler = 'torch.optim.lr_scheduler.StepLR(step_size=10,gamma=0.5)' 285 | print(get_arch(scheduler)) 286 | scheduler = 'torch.optim.lr_scheduler.StepLR(10,0.5)' 287 | print(get_arch(scheduler)) 288 | scheduler = 'torch.optim.lr_scheduler.StepLR(10,gamma=0.5)' 289 | print(get_arch(scheduler)) 290 | scheduler = 'torch.optim.lr_scheduler.StepLR()' 291 | print(get_arch(scheduler)) 292 | print(get_arch(scheduler, 10, gamma=0.5)) 293 | 294 | 295 | if __name__ == "__main__": 296 | main() 297 | -------------------------------------------------------------------------------- /fsgan/utils/video_renderer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torch.multiprocessing as mp 5 | from fsgan.utils.img_utils import tensor2bgr 6 | from fsgan.utils.bbox_utils import crop2img, scale_bbox 7 | 8 | 9 | class VideoRenderer(mp.Process): 10 | """ Renders input video frames to both screen and video file. 11 | 12 | For more control on the rendering, this class should be inherited from and the on_render method overridden 13 | with an application specific implementation. 14 | 15 | Args: 16 | display (bool): If True, the rendered video will be displayed on screen 17 | verbose (int): Verbose level. 
Controls the amount of debug information in the rendering 18 | verbose_size (tuple of int): The rendered frame size for verbose level other than zero (width, height) 19 | output_crop (bool): If True, a cropped frame of size (resolution, resolution) will be rendered for 20 | verbose level zero 21 | resolution (int): Determines the size of cropped frames to be (resolution, resolution) 22 | crop_scale (float): Multiplier factor to scale tight bounding boxes 23 | encoder_codec (str): Encoder codec code 24 | separate_process (bool): If True, the renderer will be run in a separate process 25 | """ 26 | def __init__(self, display=False, verbose=0, verbose_size=None, output_crop=False, resolution=256, crop_scale=1.2, 27 | encoder_codec='mp4v', separate_process=False): 28 | super(VideoRenderer, self).__init__() 29 | self._display = display 30 | self._verbose = verbose 31 | self._verbose_size = verbose_size 32 | self._output_crop = output_crop 33 | self._resolution = resolution 34 | self._crop_scale = crop_scale 35 | self._running = True 36 | self._input_queue = mp.Queue() 37 | self._reply_queue = mp.Queue() 38 | self._fourcc = cv2.VideoWriter_fourcc(*encoder_codec) 39 | self._separate_process = separate_process 40 | self._in_vid = None 41 | self._out_vid = None 42 | self._seq = None 43 | self._in_vid_path = None 44 | self._total_frames = None 45 | self._frame_count = 0 46 | 47 | def init(self, in_vid_path, seq, out_vid_path=None, **kwargs): 48 | """ Initialize the video render for a new video rendering job. 49 | 50 | Args: 51 | in_vid_path (str): Input video path 52 | seq (Sequence): Input sequence corresponding to the input video 53 | out_vid_path (str, optional): If specified, the rendering will be written to an output video in that path 54 | **kwargs (dict): Additional keyword arguments that will be added as members of the class. This allows 55 | inheriting classes to access those arguments from the new process 56 | """ 57 | if self._separate_process: 58 | self._input_queue.put([in_vid_path, seq, out_vid_path, kwargs]) 59 | else: 60 | self._init_task(in_vid_path, seq, out_vid_path, kwargs) 61 | 62 | def write(self, *args): 63 | """ Add tensors for rendering. 64 | 65 | Args: 66 | *args (tuple of torch.Tensor): The tensors for rendering 67 | """ 68 | if self._separate_process: 69 | self._input_queue.put([a.cpu() for a in args]) 70 | else: 71 | self._write_batch([a.cpu() for a in args]) 72 | 73 | def finalize(self): 74 | if self._separate_process: 75 | self._input_queue.put(True) 76 | else: 77 | self._finalize_task() 78 | 79 | def wait_until_finished(self): 80 | """ Wait for the video renderer to finish the current video rendering job. """ 81 | if self._separate_process: 82 | return self._reply_queue.get() 83 | else: 84 | return True 85 | 86 | def on_render(self, *args): 87 | """ Given the input tensors this method produces a cropped rendered image. 88 | 89 | This method should be overridden by inheriting classes to customize the rendering. By default this method 90 | expects the first tensor to be a cropped image tensor of shape (B, 3, H, W) where B is the batch size, 91 | H is the height of the image and W is the width of the image. 
92 | 93 | Args: 94 | *args (tuple of torch.Tensor): The tensors for rendering 95 | 96 | Returns: 97 | render_bgr (np.array): The cropped rendered image 98 | """ 99 | return tensor2bgr(args[0]) 100 | 101 | def start(self): 102 | if self._separate_process: 103 | super(VideoRenderer, self).start() 104 | 105 | def kill(self): 106 | if self._separate_process: 107 | super(VideoRenderer, self).kill() 108 | 109 | def run(self): 110 | """ Main processing loop. Intended to be executed on a separate process. """ 111 | while self._running: 112 | task = self._input_queue.get() 113 | 114 | # Initialize new video rendering task 115 | if self._in_vid is None: 116 | self._init_task(*task[:3], task[3]) 117 | continue 118 | 119 | # Finalize task 120 | if isinstance(task, bool): 121 | self._finalize_task() 122 | 123 | # Notify job is finished 124 | self._reply_queue.put(True) 125 | continue 126 | 127 | # Write a batch of frames 128 | self._write_batch(task) 129 | 130 | def _render(self, render_bgr, full_frame_bgr=None, bbox=None): 131 | if self._verbose == 0 and not self._output_crop and full_frame_bgr is not None: 132 | render_bgr = crop2img(full_frame_bgr, render_bgr, bbox) 133 | if self._out_vid is not None: 134 | self._out_vid.write(render_bgr) 135 | if self._display: 136 | cv2.imshow('render', render_bgr) 137 | if cv2.waitKey(1) & 0xFF == ord('q'): 138 | self._running = False 139 | 140 | def _init_task(self, in_vid_path, seq, out_vid_path, additional_attributes): 141 | # print('_init_task start') 142 | self._in_vid_path, self._seq = in_vid_path, seq 143 | self._frame_count = 0 144 | 145 | # Add additional arguments as members 146 | for attr_name, attr_val in additional_attributes.items(): 147 | setattr(self, attr_name, attr_val) 148 | 149 | # Open input video 150 | self._in_vid = cv2.VideoCapture(self._in_vid_path) 151 | assert self._in_vid.isOpened(), f'Failed to open video: "{self._in_vid_path}"' 152 | 153 | in_total_frames = int(self._in_vid.get(cv2.CAP_PROP_FRAME_COUNT)) 154 | fps = self._in_vid.get(cv2.CAP_PROP_FPS) 155 | in_vid_width = int(self._in_vid.get(cv2.CAP_PROP_FRAME_WIDTH)) 156 | in_vid_height = int(self._in_vid.get(cv2.CAP_PROP_FRAME_HEIGHT)) 157 | self._total_frames = in_total_frames if self._verbose == 0 else len(self._seq) 158 | # print(f'Debug: initializing video: "{self._in_vid_path}", total_frames={self._total_frames}') 159 | 160 | # Initialize output video 161 | if out_vid_path is not None: 162 | out_size = (in_vid_width, in_vid_height) 163 | if self._verbose <= 0 and self._output_crop: 164 | out_size = (self._resolution, self._resolution) 165 | elif self._verbose_size is not None: 166 | out_size = self._verbose_size 167 | self._out_vid = cv2.VideoWriter(out_vid_path, self._fourcc, fps, out_size) 168 | 169 | # Write frames as they are until the start of the sequence 170 | if self._verbose == 0: 171 | for i in range(self._seq.start_index): 172 | # Read frame 173 | ret, frame_bgr = self._in_vid.read() 174 | assert frame_bgr is not None, f'Failed to read frame {i} from input video: "{self._in_vid_path}"' 175 | self._render(frame_bgr) 176 | self._frame_count += 1 177 | 178 | def _write_batch(self, tensors): 179 | batch_size = tensors[0].shape[0] 180 | 181 | # For each frame in the current batch of tensors 182 | for b in range(batch_size): 183 | # Handle full frames if output_crop was not specified 184 | full_frame_bgr, bbox = None, None 185 | if self._verbose == 0 and not self._output_crop: 186 | # Read frame from input video 187 | ret, full_frame_bgr = self._in_vid.read() 188 | assert 
full_frame_bgr is not None, \ 189 | f'Failed to read frame {self._frame_count} from input video: "{self._in_vid_path}"' 190 | 191 | # Get bounding box from sequence 192 | det = self._seq[self._frame_count - self._seq.start_index] 193 | bbox = np.concatenate((det[:2], det[2:] - det[:2])) 194 | bbox = scale_bbox(bbox, self._crop_scale) 195 | 196 | render_bgr = self.on_render(*[t[b] for t in tensors]) 197 | self._render(render_bgr, full_frame_bgr, bbox) 198 | self._frame_count += 1 199 | # print(f'Debug: Wrote frame: {self._frame_count}') 200 | 201 | def _finalize_task(self): 202 | if self._verbose == 0 and self._frame_count >= (self._seq.start_index + len(self._seq)): 203 | for i in range(self._seq.start_index + len(self._seq), self._total_frames): 204 | # Read frame 205 | ret, frame_bgr = self._in_vid.read() 206 | assert frame_bgr is not None, f'Failed to read frame {i} from input video: "{self._in_vid_path}"' 207 | self._render(frame_bgr) 208 | self._frame_count += 1 209 | # print(f'Debug: Wrote frame: {self._frame_count}') 210 | 211 | # if self._frame_count >= self._total_frames: 212 | # Clean up 213 | self._in_vid.release() 214 | self._out_vid.release() 215 | self._in_vid = None 216 | self._out_vid = None 217 | self._seq = None 218 | self._in_vid_path = None 219 | self._total_frames = None 220 | self._frame_count = 0 221 | -------------------------------------------------------------------------------- /fsgan/utils/video_utils.py: -------------------------------------------------------------------------------- 1 | """ Video utilities. """ 2 | 3 | import numpy as np 4 | from itertools import count 5 | import ffmpeg 6 | from fsgan.utils.one_euro_filter import OneEuroFilter 7 | 8 | 9 | class Sequence(object): 10 | """ Represents a sequence of detected faces in a video. 11 | 12 | Args: 13 | start_index (int): The frame index in the video from which the sequence starts 14 | det (np.array): Frame face detections bounding boxes of shape (N, 4), in the format [left, top, right, bottom] 15 | """ 16 | _ids = count(0) 17 | 18 | def __init__(self, start_index, det=None): 19 | self.start_index = start_index 20 | self.size_sum = 0. 21 | self.size_avg = 0. 22 | self.id = next(self._ids) 23 | self.obj_id = -1 24 | self.detections = [] 25 | 26 | if det is not None: 27 | self.add(det) 28 | 29 | def add(self, det): 30 | """ Add new frame detection bounding boxes. 31 | 32 | Args: 33 | det (np.array): Frame face detections bounding boxes of shape (N, 4), in the format 34 | [left, top, right, bottom] 35 | """ 36 | self.detections.append(det) 37 | size = det[3] - det[1] 38 | self.size_sum += size 39 | self.size_avg = self.size_sum / len(self.detections) 40 | 41 | def smooth(self, kernel_size=7): 42 | """ Temporally smooth the detection bounding boxes. 43 | 44 | Args: 45 | kernel_size (int): The temporal kernel size 46 | """ 47 | # Prepare smoothing kernel 48 | w = np.hamming(kernel_size) 49 | w /= w.sum() 50 | 51 | # Smooth bounding boxes 52 | bboxes = np.array(self.detections) 53 | bboxes_padded = np.pad(bboxes, ((kernel_size // 2, kernel_size // 2), (0, 0)), 'reflect') 54 | for i in range(bboxes.shape[1]): 55 | bboxes[:, i] = np.convolve(w, bboxes_padded[:, i], mode='valid') 56 | 57 | self.detections = bboxes 58 | 59 | def finalize(self): 60 | """ Packs all list of added detections into a single numpy array. 61 | 62 | Should be called after all detections were added if smooth was not called. 
63 | """ 64 | self.detections = np.array(self.detections) 65 | 66 | def __getitem__(self, index): 67 | return self.detections[index] 68 | 69 | def __len__(self): 70 | return len(self.detections) 71 | 72 | 73 | # TODO: Remove this 74 | def estimate_motion(detections, min_cutoff=0.0, beta=3.0, d_cutoff=5.0, fps=30.0): 75 | one_euro_filter = OneEuroFilter(min_cutoff=min_cutoff, beta=beta, d_cutoff=d_cutoff, t_e=(1.0 / fps)) 76 | detections_n = np.array(detections) 77 | center = np.mean((detections_n[:, 2:] + detections_n[:, :2])*0.5, axis=0) 78 | size = np.mean(detections_n[:, 2:] - detections_n[:, :2], axis=0) 79 | detections_n = (detections_n - np.concatenate((center, center))) / np.concatenate((size, size)) 80 | 81 | motion = [] 82 | for det in detections_n: 83 | det_s, a = one_euro_filter(det) 84 | motion.append(a) 85 | 86 | return np.array(motion) 87 | 88 | 89 | # TODO: Remove this 90 | def smooth_detections_avg(detections, kernel_size=7): 91 | # Prepare smoothing kernel 92 | # w = np.hamming(kernel_size) 93 | w = np.ones(kernel_size) 94 | w /= w.sum() 95 | 96 | # Smooth bounding boxes 97 | bboxes = np.array(detections) 98 | bboxes_padded = np.pad(bboxes, ((kernel_size // 2, kernel_size // 2), (0, 0)), 'reflect') 99 | for i in range(bboxes.shape[1]): 100 | bboxes[:, i] = np.convolve(w, bboxes_padded[:, i], mode='valid') 101 | 102 | return bboxes 103 | 104 | 105 | # TODO: Remove this 106 | def smooth_detections_1euro(detections, kernel_size=7, min_cutoff=0.0, beta=3.0, d_cutoff=5.0, fps=30.0): 107 | detections_np = np.array(detections) 108 | detections_avg = smooth_detections_avg(detections, kernel_size) 109 | motion = np.expand_dims(estimate_motion(detections, min_cutoff, beta, d_cutoff, fps), 1).astype('float32') 110 | out_detections = detections_np * motion + detections_avg * (1 - motion) 111 | 112 | return out_detections 113 | 114 | 115 | # TODO: Remove this 116 | def smooth_detections_avg_center(detections, center_kernel=11, size_kernel=21): 117 | # Prepare smoothing kernel 118 | center_w = np.ones(center_kernel) 119 | center_w /= center_w.sum() 120 | size_w = np.ones(size_kernel) 121 | size_w /= size_w.sum() 122 | 123 | # Convert bounding boxes to center and size format 124 | bboxes = np.array(detections) 125 | centers = (bboxes[:, :2] + bboxes[:, 2:]) / 2.0 126 | sizes = bboxes[:, 2:] - bboxes[:, :2] 127 | 128 | # Smooth bounding boxes 129 | centers_padded = np.pad(centers, ((center_kernel // 2, center_kernel // 2), (0, 0)), 'reflect') 130 | sizes_padded = np.pad(sizes, ((size_kernel // 2, size_kernel // 2), (0, 0)), 'reflect') 131 | for i in range(centers.shape[1]): 132 | centers[:, i] = np.convolve(center_w, centers_padded[:, i], mode='valid') 133 | sizes[:, i] = np.convolve(size_w, sizes_padded[:, i], mode='valid') 134 | 135 | # Change back to detections format 136 | sizes /= 2.0 137 | bboxes = np.concatenate((centers - sizes, centers + sizes), axis=1) 138 | 139 | return bboxes 140 | 141 | 142 | def get_main_sequence(seq_list, frame_size): 143 | """ Return the main sequence in a list of sequences according to their size and how central they are. 144 | 145 | Args: 146 | seq_list (list of Sequence): List of sequences 147 | frame_size (tuple of int): The corresponding sequence video's frame size of shape (H, W) 148 | 149 | Returns: 150 | Sequence: The main sequence. 
151 | """ 152 | if len(seq_list) == 0: 153 | return None 154 | 155 | # Calculate frame max distance and size 156 | img_center = np.array([frame_size[1], frame_size[0]]) * 0.5 157 | max_dist = 0.25 * np.linalg.norm(frame_size) 158 | max_size = 0.25 * (frame_size[0] + frame_size[1]) 159 | 160 | # For each sequence 161 | seq_scores = [] 162 | for seq in seq_list: 163 | 164 | # For each detection in the sequence 165 | det_scores = [] 166 | for det in seq: 167 | bbox = np.concatenate((det[:2], det[2:] - det[:2])) 168 | 169 | # Calculate center distance 170 | bbox_center = bbox[:2] + bbox[2:] * 0.5 171 | bbox_dist = np.linalg.norm(bbox_center - img_center) 172 | 173 | # Calculate bbox size 174 | bbox_size = bbox[2:].mean() 175 | 176 | # Calculate central ratio 177 | central_ratio = 1.0 if max_size < 1e-6 else (1.0 - bbox_dist / max_dist) 178 | central_ratio = np.clip(central_ratio, 0.0, 1.0) 179 | 180 | # Calculate size ratio 181 | size_ratio = 1.0 if max_size < 1e-6 else (bbox_size / max_size) 182 | size_ratio = np.clip(size_ratio, 0.0, 1.0) 183 | 184 | # Add score 185 | score = (central_ratio + size_ratio) * 0.5 186 | det_scores.append(score) 187 | 188 | seq_scores.append(np.array(det_scores).mean()) 189 | 190 | return seq_list[np.argmax(seq_scores)] 191 | 192 | 193 | def get_media_info(media_path): 194 | """ Return media information. 195 | 196 | Args: 197 | media_path (str): Path to media file 198 | 199 | Returns: 200 | (int, int, int, float): Tuple containing: 201 | - width (int): Frame width 202 | - height (int): Frame height 203 | - total_frames (int): Total number of frames (will be 1 for images) 204 | - fps (float): Frames per second (irrelevant for images) 205 | """ 206 | probe = ffmpeg.probe(media_path) 207 | video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) 208 | width = int(video_stream['width']) 209 | height = int(video_stream['height']) 210 | total_frames = int(video_stream['nb_frames']) if 'nb_frames' in video_stream else 1 211 | fps_part1, fps_part2 = video_stream['r_frame_rate'].split(sep='/') 212 | fps = float(fps_part1) / float(fps_part2) 213 | 214 | return width, height, total_frames, fps 215 | 216 | 217 | def get_media_resolution(media_path): 218 | return get_media_info(media_path)[:2] 219 | 220 | 221 | # TODO: Remove this 222 | def get_video_info(vid_path): 223 | return get_media_info(vid_path) 224 | -------------------------------------------------------------------------------- /fsgan_env.yml: -------------------------------------------------------------------------------- 1 | name: fsgan 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | dependencies: 7 | - python=3.9.13 8 | - pip=20.3.1 9 | - cudatoolkit=11.6 10 | - pytorch=1.12.1 11 | - torchvision=0.13.1 12 | - ffmpeg=4.* 13 | - yacs=0.1.8 14 | - pip: 15 | - setuptools==58.2.0 16 | - torch-summary==1.4.5 17 | - opencv-contrib-python==4.5.4.60 18 | - tensorflow==2.7.0 19 | - tqdm==4.64.1 20 | - matplotlib==3.6.2 21 | - ffmpeg-python==0.2.0 22 | - PyYAML==6.0 23 | - pandas==1.5.1 24 | - seaborn==0.12.1 25 | - scipy==1.9.3 26 | - git+https://github.com/YuvalNirkin/face_detection_dsfd.git -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="fsgan", 5 | version="1.0.1", 6 | author="Dr. 
Yuval Nirkin", 7 | author_email="yuval.nirkin@gmail.com", 8 | description="FSGAN: Subject Agnostic Face Swapping and Reenactment", 9 | long_description_content_type="text/markdown", 10 | package_data={'': ['license.txt']}, 11 | include_package_data=True, 12 | packages=setuptools.find_packages(), 13 | python_requires='>=3.6', 14 | ) 15 | --------------------------------------------------------------------------------