├── .gitignore ├── LICENSE.md ├── README.md ├── animate.py ├── assets ├── driving.mp4 ├── src.png └── visual_vox1.png ├── augmentation.py ├── config └── vox-adv-256.yaml ├── crop-video.py ├── data ├── celeV_cross_id_evaluation.csv ├── utils.py ├── vox256.csv ├── vox_cross_id_animate.csv ├── vox_cross_id_evaluation.csv ├── vox_cross_id_evaluation_best_frame.csv └── vox_evaluation.csv ├── demo.py ├── demo_multi.py ├── depth ├── __init__.py ├── depth_decoder.py ├── layers.py ├── models │ └── opt.json ├── pose_cnn.py ├── pose_decoder.py └── resnet_encoder.py ├── evaluation_dataset.py ├── face-alignment ├── .gitattributes ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── conda │ └── meta.yaml ├── docs │ └── images │ │ ├── 2dlandmarks.png │ │ └── face-alignment-adrian.gif ├── examples │ ├── demo.ipynb │ └── detect_landmarks_in_image.py ├── face_alignment │ ├── __init__.py │ ├── api.py │ ├── detection │ │ ├── __init__.py │ │ ├── blazeface │ │ │ ├── __init__.py │ │ │ ├── blazeface_detector.py │ │ │ ├── detect.py │ │ │ ├── net_blazeface.py │ │ │ └── utils.py │ │ ├── core.py │ │ ├── dlib │ │ │ ├── __init__.py │ │ │ └── dlib_detector.py │ │ ├── folder │ │ │ ├── __init__.py │ │ │ └── folder_detector.py │ │ └── sfd │ │ │ ├── __init__.py │ │ │ ├── bbox.py │ │ │ ├── detect.py │ │ │ ├── net_s3fd.py │ │ │ └── sfd_detector.py │ └── utils.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── test │ ├── facealignment_test.py │ ├── smoke_test.py │ └── test_utils.py └── tox.ini ├── frames_dataset.py ├── kill_port.py ├── logger.py ├── modules ├── AdaIN.py ├── dense_motion.py ├── discriminator.py ├── dynamic_conv.py ├── generator.py ├── keypoint_detector.py ├── model.py ├── model_dataparallel.py └── util.py ├── reconstruction.py ├── requirements.txt ├── run.py ├── run_dataparallel.py ├── sync_batchnorm ├── __init__.py ├── batchnorm.py ├── comm.py ├── replicate.py └── unittest.py ├── train.py ├── train_dataparallel.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | log 3 | *.pth* 4 | readmesam* 5 | *.jpg 6 | *.mp4 7 | run.sh 8 | source.png 9 | tools 10 | *.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## :book: Depth-Aware Generative Adversarial Network for Talking Head Video Generation (CVPR 2022) 3 |

4 | :fire: If DaGAN is helpful for your photos/projects, please :star: it or recommend it to your friends. Thanks! :fire: 5 |

6 |

7 | :fire: Seeking collaboration and internship opportunities. :fire: 8 |

9 | 10 | > [[Paper](https://arxiv.org/abs/2203.06605)]   [[Project Page](https://harlanhong.github.io/publications/dagan.html)]   [[Demo](https://huggingface.co/spaces/HarlanHong/DaGAN)]   [[Poster Video](https://www.youtube.com/watch?v=nahsJNjWzGo&t=1s)]
11 | 12 | 13 | > [Fa-Ting Hong](https://harlanhong.github.io), [Longhao Zhang](), [Li Shen](), [Dan Xu](https://www.danxurgb.net)
14 | > The Hong Kong University of Science and Technology
15 | > Alibaba Cloud 16 | 17 | ### Cartoon Sample 18 | https://user-images.githubusercontent.com/19970321/162151632-0195292f-30b8-4122-8afd-9b1698f1e4fe.mp4 19 | 20 | ### Human Sample 21 | https://user-images.githubusercontent.com/19970321/162151327-f2930231-42e3-40f2-bfca-a88529599f0f.mp4 22 | 23 | ### Voxceleb1 Dataset 24 |

25 | ![Qualitative results on VoxCeleb1](assets/visual_vox1.png) 26 |

27 | 28 | :triangular_flag_on_post: **Updates** 29 | - :fire::fire::white_check_mark: July 20, 2023: Our new talking-head work **[MCNet](https://harlanhong.github.io/publications/mcnet.html)** was accepted by ICCV 2023. It does not require training a facial depth network, which makes it more convenient for users to test and fine-tune. 30 | - :fire::fire::white_check_mark: July 26, 2022: The normal DataParallel training scripts were released, since some researchers informed me that they ran into **DistributedDataParallel** problems. Please try training your own model with this [command](#dataparallel). We also removed the line `with torch.autograd.set_detect_anomaly(True)` to speed up training. 31 | - :fire::fire::white_check_mark: June 26, 2022: The repo of our face depth network is released; please refer to [Face-Depth-Network](https://github.com/harlanhong/Face-Depth-Network), and feel free to email me if you run into any problems. 32 | - :fire::fire::white_check_mark: June 21, 2022: [Digression] I am looking for research intern/research assistant opportunities in Europe next year. Please contact me if you think I'm qualified for your position. 33 | - :fire::fire::white_check_mark: May 19, 2022: The depth face model (50 layers) trained on VoxCeleb2 is released! (The corresponding DaGAN checkpoint will be released soon.) Click the [LINK](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EkxzfH7zbGJNr-WVmPU6fcABWAMq_WJoExAl4SttKK6hBQ?e=fbtGlX) 34 | 35 | - :fire::fire::white_check_mark: April 25, 2022: Integrated into Hugging Face Spaces 🤗 using Gradio. Try out the web demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/HarlanHong/DaGAN) (GPU version will come soon!) 36 | - :fire::fire::white_check_mark: Added the **[SPADE model](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EjfeXuzwo3JMn7s0oOPN_q0B81P5Wgu_kbYJAh7uSAKS2w?e=XNZl3K)**, which produces **more natural** results. 37 | 38 | 39 | ## :wrench: Dependencies and Installation 40 | 41 | - Python >= 3.7 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 42 | - [PyTorch >= 1.7](https://pytorch.org/) 43 | - Optional: NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 44 | - Optional: Linux 45 | 46 | ### Installation 47 | We now provide a *clean* version of DaGAN that does not require customized CUDA extensions.
48 | 49 | 1. Clone the repo 50 | 51 | ```bash 52 | git clone https://github.com/harlanhong/CVPR2022-DaGAN.git 53 | cd CVPR2022-DaGAN 54 | ``` 55 | 56 | 2. Install the dependent packages 57 | 58 | ```bash 59 | pip install -r requirements.txt 60 | 61 | ## Install the Face Alignment lib 62 | cd face-alignment 63 | pip install -r requirements.txt 64 | python setup.py install 65 | ``` 66 | ## :zap: Quick Inference 67 | 68 | We take the paper version as an example; more models can be found [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EjfeXuzwo3JMn7s0oOPN_q0B81P5Wgu_kbYJAh7uSAKS2w?e=KaQcPk). 69 | 70 | ### YAML configs 71 | See ```config/vox-adv-256.yaml``` for a description of each parameter. 72 | 73 | ### Pre-trained checkpoint 74 | The pre-trained checkpoint of the face depth network and our DaGAN checkpoints can be found at the following link: [OneDrive](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EjfeXuzwo3JMn7s0oOPN_q0B81P5Wgu_kbYJAh7uSAKS2w?e=KaQcPk). 75 | 76 | **Inference!** 77 | To run a demo, download a checkpoint and run the following command: 78 | 79 | ```bash 80 | CUDA_VISIBLE_DEVICES=0 python demo.py --config config/vox-adv-256.yaml --driving_video path/to/driving --source_image path/to/source --checkpoint path/to/checkpoint --relative --adapt_scale --kp_num 15 --generator DepthAwareGenerator 81 | ``` 82 | The result will be stored in ```result.mp4```. The driving videos and source images should be cropped before they can be used in our method. To obtain some semi-automatic crop suggestions, you can use ```python crop-video.py --inp some_youtube_video.mp4```; it will generate ffmpeg commands for the crops. 83 | 84 | 85 | 86 | 87 | ## :computer: Training 88 | 89 | 90 | ### Datasets 91 | 92 | 1) **VoxCeleb**. Please follow the instructions from https://github.com/AliaksandrSiarohin/video-preprocessing. 93 | 94 | ### Train on VoxCeleb 95 | To train a model on a specific dataset, run: 96 | ``` 97 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --master_addr="0.0.0.0" --master_port=12348 run.py --config config/vox-adv-256.yaml --name DaGAN --rgbd --batchsize 12 --kp_num 15 --generator DepthAwareGenerator 98 | ``` 99 |
Or
100 | 101 | ``` 102 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_dataparallel.py --config config/vox-adv-256.yaml --device_ids 0,1,2,3 --name DaGAN_voxceleb2_depth --rgbd --batchsize 48 --kp_num 15 --generator DepthAwareGenerator 103 | ``` 104 | 105 | 106 | 107 | 108 | The code will create a folder in the log directory (each run will create a new name-specific directory). 109 | Checkpoints will be saved to this folder. 110 | To check the loss values during training, see ```log.txt```. 111 | By default, the batch size is tuned to run on 8 GeForce RTX 3090 GPUs (you can obtain the best performance after about 150 epochs). You can change the batch size in the train_params section of the ```.yaml``` file. 112 | 113 | 114 | You can also watch the training loss by running the following command: 115 | ```bash 116 | tensorboard --logdir log/DaGAN/log 117 | ``` 118 | If you kill the training process partway through, a zombie process may remain; you can kill it using our provided tool: 119 | ```bash 120 | python kill_port.py PORT 121 | ``` 122 | 123 | ### Training on your own dataset 124 | 1) Resize all the videos to the same size, e.g. 256x256; the videos can be in '.gif' or '.mp4' format, or a folder with images. 125 | We recommend the latter: for each video, make a separate folder with all the frames in '.png' format. This format is lossless and has better I/O performance. 126 | 127 | 2) Create a folder ```data/dataset_name``` with two subfolders, ```train``` and ```test```; put the training videos in ```train``` and the testing videos in ```test```. 128 | 129 | 3) Create a config ```config/dataset_name.yaml``` and, in dataset_params, specify the root directory: ```root_dir: data/dataset_name```. Also adjust the number of epochs in train_params (see the example config sketch at the end of this README). 130 | 131 | 132 | 133 | ## :scroll: Acknowledgement 134 | 135 | Our DaGAN implementation is inspired by [FOMM](https://github.com/AliaksandrSiarohin/first-order-model). We appreciate the authors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) for making their code publicly available. 136 | 137 | ## :scroll: BibTeX 138 | 139 | ``` 140 | @inproceedings{hong2022depth, 141 | title={Depth-Aware Generative Adversarial Network for Talking Head Video Generation}, 142 | author={Hong, Fa-Ting and Zhang, Longhao and Shen, Li and Xu, Dan}, 143 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 144 | year={2022} 145 | } 146 | 147 | @article{hong2023dagan, 148 | title={DaGAN++: Depth-Aware Generative Adversarial Network for Talking Head Video Generation}, 149 | author={Hong, Fa-Ting and Shen, Li and Xu, Dan}, 150 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, 151 | year={2023} 152 | } 153 | ``` 154 | 155 | ### :e-mail: Contact 156 | 157 | If you have any questions or collaboration needs (research or commercial), please email `fhongac@cse.ust.hk`.
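### Example: minimal dataset config (sketch)

For step 3 of "Training on your own dataset", the sketch below shows what a minimal ```config/dataset_name.yaml``` might look like. The dataset name and the epoch/batch values are illustrative placeholders, not part of the original repo; copy the full parameter set from ```config/vox-adv-256.yaml``` and override only the fields you need.

```yaml
# Sketch of config/dataset_name.yaml (illustrative values only).
# Start from config/vox-adv-256.yaml and override the fields below.
dataset_params:
  root_dir: data/my_dataset      # folder containing the train/ and test/ subfolders
  frame_shape: [256, 256, 3]     # videos are expected to be resized to 256x256
  id_sampling: False             # True only if videos are grouped by identity, as in VoxCeleb

train_params:
  num_epochs: 100                # adjust to the size of your dataset
  batch_size: 4                  # adjust to your GPU memory
```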
158 | -------------------------------------------------------------------------------- /animate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from frames_dataset import PairedDataset 8 | from logger import Logger, Visualizer 9 | import imageio 10 | from scipy.spatial import ConvexHull 11 | import numpy as np 12 | import depth 13 | from sync_batchnorm import DataParallelWithCallback 14 | 15 | 16 | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, 17 | use_relative_movement=False, use_relative_jacobian=False): 18 | if adapt_movement_scale: 19 | source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume 20 | driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume 21 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) 22 | else: 23 | adapt_movement_scale = 1 24 | 25 | kp_new = {k: v for k, v in kp_driving.items()} 26 | 27 | if use_relative_movement: 28 | kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) 29 | kp_value_diff *= adapt_movement_scale 30 | kp_new['value'] = kp_value_diff + kp_source['value'] 31 | 32 | if use_relative_jacobian: 33 | jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) 34 | kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) 35 | return kp_new 36 | 37 | 38 | def animate(config, generator, kp_detector, checkpoint, log_dir, dataset,opt): 39 | log_dir = os.path.join(log_dir, 'animation') 40 | png_dir = os.path.join(log_dir, 'png') 41 | animate_params = config['animate_params'] 42 | 43 | dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs']) 44 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 45 | 46 | if checkpoint is not None: 47 | Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector) 48 | else: 49 | raise AttributeError("Checkpoint should be specified for mode='animate'.") 50 | 51 | if not os.path.exists(log_dir): 52 | os.makedirs(log_dir) 53 | 54 | if not os.path.exists(png_dir): 55 | os.makedirs(png_dir) 56 | 57 | 58 | depth_encoder = depth.ResnetEncoder(18, False).cuda() 59 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)).cuda() 60 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth') 61 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth') 62 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()} 63 | depth_encoder.load_state_dict(filtered_dict_enc) 64 | depth_decoder.load_state_dict(loaded_dict_dec) 65 | depth_decoder.eval() 66 | depth_encoder.eval() 67 | generator.eval() 68 | kp_detector.eval() 69 | 70 | for it, x in tqdm(enumerate(dataloader)): 71 | with torch.no_grad(): 72 | predictions = [] 73 | visualizations = [] 74 | 75 | driving_video = x['driving_video'].cuda() 76 | source_frame = x['source_video'][:, :, 0, :, :].cuda() 77 | 78 | outputs = depth_decoder(depth_encoder(source_frame)) 79 | depth_source = outputs[("disp", 0)] 80 | outputs = depth_decoder(depth_encoder(driving_video[:, :, 0])) 81 | depth_driving = outputs[("disp", 0)] 82 | 83 | source = torch.cat((source_frame,depth_source),1) 84 | driving = torch.cat((driving_video[:, :, 0],depth_driving),1) 85 | 86 | kp_source = kp_detector(source) 87 | kp_driving_initial = 
kp_detector(driving) 88 | 89 | for frame_idx in range(driving_video.shape[2]): 90 | driving_frame = driving_video[:, :, frame_idx].cuda() 91 | outputs = depth_decoder(depth_encoder(driving_frame)) 92 | depth_map = outputs[("disp", 0)] 93 | driving = torch.cat((driving_frame,depth_map),1) 94 | kp_driving = kp_detector(driving) 95 | 96 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, 97 | kp_driving_initial=kp_driving_initial, **animate_params['normalization_params']) 98 | out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm) 99 | 100 | out['kp_driving'] = kp_driving 101 | out['kp_source'] = kp_source 102 | out['kp_norm'] = kp_norm 103 | 104 | del out['sparse_deformed'] 105 | 106 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 107 | 108 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame, 109 | driving=driving_frame, out=out) 110 | visualization = visualization 111 | visualizations.append(visualization) 112 | 113 | predictions = np.concatenate(predictions, axis=1) 114 | result_name = "-".join([x['driving_name'][0], x['source_name'][0]]) 115 | imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8)) 116 | 117 | image_name = result_name + animate_params['format'] 118 | imageio.mimsave(os.path.join(log_dir, image_name), visualizations) 119 | 120 | -------------------------------------------------------------------------------- /assets/driving.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/assets/driving.mp4 -------------------------------------------------------------------------------- /assets/src.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/assets/src.png -------------------------------------------------------------------------------- /assets/visual_vox1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/assets/visual_vox1.png -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from https://github.com/hassony2/torch_videovision 3 | """ 4 | 5 | import numbers 6 | 7 | import random 8 | import numpy as np 9 | import PIL 10 | 11 | from skimage.transform import resize, rotate 12 | # from skimage.util import pad 13 | # import numpy.pad as pad 14 | import torchvision 15 | 16 | import warnings 17 | 18 | from skimage import img_as_ubyte, img_as_float 19 | 20 | 21 | def crop_clip(clip, min_h, min_w, h, w): 22 | if isinstance(clip[0], np.ndarray): 23 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 24 | 25 | elif isinstance(clip[0], PIL.Image.Image): 26 | cropped = [ 27 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 28 | ] 29 | else: 30 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 31 | 'but got list of {0}'.format(type(clip[0]))) 32 | return cropped 33 | 34 | 35 | def pad_clip(clip, h, w): 36 | im_h, im_w = clip[0].shape[:2] 37 | pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2) 38 | pad_w = (0, 0) if w < im_w else ((w - 
im_w) // 2, (w - im_w + 1) // 2) 39 | 40 | return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge') 41 | 42 | 43 | def resize_clip(clip, size, interpolation='bilinear'): 44 | if isinstance(clip[0], np.ndarray): 45 | if isinstance(size, numbers.Number): 46 | im_h, im_w, im_c = clip[0].shape 47 | # Min spatial dim already matches minimal size 48 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 49 | and im_h == size): 50 | return clip 51 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 52 | size = (new_w, new_h) 53 | else: 54 | size = size[1], size[0] 55 | 56 | scaled = [ 57 | resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True, 58 | mode='constant', anti_aliasing=True) for img in clip 59 | ] 60 | elif isinstance(clip[0], PIL.Image.Image): 61 | if isinstance(size, numbers.Number): 62 | im_w, im_h = clip[0].size 63 | # Min spatial dim already matches minimal size 64 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 65 | and im_h == size): 66 | return clip 67 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 68 | size = (new_w, new_h) 69 | else: 70 | size = size[1], size[0] 71 | if interpolation == 'bilinear': 72 | pil_inter = PIL.Image.NEAREST 73 | else: 74 | pil_inter = PIL.Image.BILINEAR 75 | scaled = [img.resize(size, pil_inter) for img in clip] 76 | else: 77 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 78 | 'but got list of {0}'.format(type(clip[0]))) 79 | return scaled 80 | 81 | 82 | def get_resize_sizes(im_h, im_w, size): 83 | if im_w < im_h: 84 | ow = size 85 | oh = int(size * im_h / im_w) 86 | else: 87 | oh = size 88 | ow = int(size * im_w / im_h) 89 | return oh, ow 90 | 91 | 92 | class RandomFlip(object): 93 | def __init__(self, time_flip=False, horizontal_flip=False): 94 | self.time_flip = time_flip 95 | self.horizontal_flip = horizontal_flip 96 | 97 | def __call__(self, clip): 98 | if random.random() < 0.5 and self.time_flip: 99 | return clip[::-1] 100 | if random.random() < 0.5 and self.horizontal_flip: 101 | return [np.fliplr(img) for img in clip] 102 | 103 | return clip 104 | 105 | 106 | class RandomResize(object): 107 | """Resizes a list of (H x W x C) numpy.ndarray to the final size 108 | The larger the original image is, the more times it takes to 109 | interpolate 110 | Args: 111 | interpolation (str): Can be one of 'nearest', 'bilinear' 112 | defaults to nearest 113 | size (tuple): (widht, height) 114 | """ 115 | 116 | def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): 117 | self.ratio = ratio 118 | self.interpolation = interpolation 119 | 120 | def __call__(self, clip): 121 | scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) 122 | 123 | if isinstance(clip[0], np.ndarray): 124 | im_h, im_w, im_c = clip[0].shape 125 | elif isinstance(clip[0], PIL.Image.Image): 126 | im_w, im_h = clip[0].size 127 | 128 | new_w = int(im_w * scaling_factor) 129 | new_h = int(im_h * scaling_factor) 130 | new_size = (new_w, new_h) 131 | resized = resize_clip( 132 | clip, new_size, interpolation=self.interpolation) 133 | 134 | return resized 135 | 136 | 137 | class RandomCrop(object): 138 | """Extract random crop at the same location for a list of videos 139 | Args: 140 | size (sequence or int): Desired output size for the 141 | crop in format (h, w) 142 | """ 143 | 144 | def __init__(self, size): 145 | if isinstance(size, numbers.Number): 146 | size = (size, size) 147 | 148 | self.size = size 149 | 150 | def __call__(self, clip): 151 | """ 152 | Args: 153 | img (PIL.Image or numpy.ndarray): List of videos to be cropped 154 | in format (h, w, c) in numpy.ndarray 155 | Returns: 156 | PIL.Image or numpy.ndarray: Cropped list of videos 157 | """ 158 | h, w = self.size 159 | if isinstance(clip[0], np.ndarray): 160 | im_h, im_w, im_c = clip[0].shape 161 | elif isinstance(clip[0], PIL.Image.Image): 162 | im_w, im_h = clip[0].size 163 | else: 164 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 165 | 'but got list of {0}'.format(type(clip[0]))) 166 | 167 | clip = pad_clip(clip, h, w) 168 | im_h, im_w = clip.shape[1:3] 169 | x1 = 0 if h == im_h else random.randint(0, im_w - w) 170 | y1 = 0 if w == im_w else random.randint(0, im_h - h) 171 | cropped = crop_clip(clip, y1, x1, h, w) 172 | 173 | return cropped 174 | 175 | 176 | class RandomRotation(object): 177 | """Rotate entire clip randomly by a random angle within 178 | given bounds 179 | Args: 180 | degrees (sequence or int): Range of degrees to select from 181 | If degrees is a number instead of sequence like (min, max), 182 | the range of degrees, will be (-degrees, +degrees). 183 | """ 184 | 185 | def __init__(self, degrees): 186 | if isinstance(degrees, numbers.Number): 187 | if degrees < 0: 188 | raise ValueError('If degrees is a single number,' 189 | 'must be positive') 190 | degrees = (-degrees, degrees) 191 | else: 192 | if len(degrees) != 2: 193 | raise ValueError('If degrees is a sequence,' 194 | 'it must be of len 2.') 195 | 196 | self.degrees = degrees 197 | 198 | def __call__(self, clip): 199 | """ 200 | Args: 201 | img (PIL.Image or numpy.ndarray): List of videos to be cropped 202 | in format (h, w, c) in numpy.ndarray 203 | Returns: 204 | PIL.Image or numpy.ndarray: Cropped list of videos 205 | """ 206 | angle = random.uniform(self.degrees[0], self.degrees[1]) 207 | if isinstance(clip[0], np.ndarray): 208 | rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip] 209 | elif isinstance(clip[0], PIL.Image.Image): 210 | rotated = [img.rotate(angle) for img in clip] 211 | else: 212 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 213 | 'but got list of {0}'.format(type(clip[0]))) 214 | 215 | return rotated 216 | 217 | 218 | class ColorJitter(object): 219 | """Randomly change the brightness, contrast and saturation and hue of the clip 220 | Args: 221 | brightness (float): How much to jitter brightness. brightness_factor 222 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 
223 | contrast (float): How much to jitter contrast. contrast_factor 224 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 225 | saturation (float): How much to jitter saturation. saturation_factor 226 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 227 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 228 | [-hue, hue]. Should be >=0 and <= 0.5. 229 | """ 230 | 231 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 232 | self.brightness = brightness 233 | self.contrast = contrast 234 | self.saturation = saturation 235 | self.hue = hue 236 | 237 | def get_params(self, brightness, contrast, saturation, hue): 238 | if brightness > 0: 239 | brightness_factor = random.uniform( 240 | max(0, 1 - brightness), 1 + brightness) 241 | else: 242 | brightness_factor = None 243 | 244 | if contrast > 0: 245 | contrast_factor = random.uniform( 246 | max(0, 1 - contrast), 1 + contrast) 247 | else: 248 | contrast_factor = None 249 | 250 | if saturation > 0: 251 | saturation_factor = random.uniform( 252 | max(0, 1 - saturation), 1 + saturation) 253 | else: 254 | saturation_factor = None 255 | 256 | if hue > 0: 257 | hue_factor = random.uniform(-hue, hue) 258 | else: 259 | hue_factor = None 260 | return brightness_factor, contrast_factor, saturation_factor, hue_factor 261 | 262 | def __call__(self, clip): 263 | """ 264 | Args: 265 | clip (list): list of PIL.Image 266 | Returns: 267 | list PIL.Image : list of transformed PIL.Image 268 | """ 269 | if isinstance(clip[0], np.ndarray): 270 | brightness, contrast, saturation, hue = self.get_params( 271 | self.brightness, self.contrast, self.saturation, self.hue) 272 | 273 | # Create img transform function sequence 274 | img_transforms = [] 275 | if brightness is not None: 276 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 277 | if saturation is not None: 278 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 279 | if hue is not None: 280 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 281 | if contrast is not None: 282 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 283 | random.shuffle(img_transforms) 284 | img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array, 285 | img_as_float] 286 | 287 | with warnings.catch_warnings(): 288 | warnings.simplefilter("ignore") 289 | jittered_clip = [] 290 | for img in clip: 291 | jittered_img = img 292 | for func in img_transforms: 293 | jittered_img = func(jittered_img) 294 | jittered_clip.append(jittered_img.astype('float32')) 295 | elif isinstance(clip[0], PIL.Image.Image): 296 | brightness, contrast, saturation, hue = self.get_params( 297 | self.brightness, self.contrast, self.saturation, self.hue) 298 | 299 | # Create img transform function sequence 300 | img_transforms = [] 301 | if brightness is not None: 302 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 303 | if saturation is not None: 304 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 305 | if hue is not None: 306 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 307 | if contrast is not None: 308 | img_transforms.append(lambda img: 
torchvision.transforms.functional.adjust_contrast(img, contrast)) 309 | random.shuffle(img_transforms) 310 | 311 | # Apply to all videos 312 | jittered_clip = [] 313 | for img in clip: 314 | for func in img_transforms: 315 | jittered_img = func(img) 316 | jittered_clip.append(jittered_img) 317 | 318 | else: 319 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 320 | 'but got list of {0}'.format(type(clip[0]))) 321 | return jittered_clip 322 | 323 | 324 | class AllAugmentationTransform: 325 | def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None): 326 | self.transforms = [] 327 | 328 | if flip_param is not None: 329 | self.transforms.append(RandomFlip(**flip_param)) 330 | 331 | if rotation_param is not None: 332 | self.transforms.append(RandomRotation(**rotation_param)) 333 | 334 | if resize_param is not None: 335 | self.transforms.append(RandomResize(**resize_param)) 336 | 337 | if crop_param is not None: 338 | self.transforms.append(RandomCrop(**crop_param)) 339 | 340 | if jitter_param is not None: 341 | self.transforms.append(ColorJitter(**jitter_param)) 342 | 343 | def __call__(self, clip): 344 | for t in self.transforms: 345 | clip = t(clip) 346 | return clip 347 | -------------------------------------------------------------------------------- /config/vox-adv-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: /data/fhongac/origDataset/vox1_frames 3 | frame_shape: [256, 256, 3] 4 | id_sampling: True 5 | pairs_list: data/vox256.csv 6 | augmentation_params: 7 | flip_param: 8 | horizontal_flip: True 9 | time_flip: True 10 | jitter_param: 11 | brightness: 0.1 12 | contrast: 0.1 13 | saturation: 0.1 14 | hue: 0.1 15 | 16 | 17 | model_params: 18 | common_params: 19 | num_kp: 10 20 | num_channels: 3 21 | estimate_jacobian: True 22 | kp_detector_params: 23 | temperature: 0.1 24 | block_expansion: 32 25 | max_features: 1024 26 | scale_factor: 0.25 27 | num_blocks: 5 28 | generator_params: 29 | block_expansion: 64 30 | max_features: 512 31 | num_down_blocks: 2 32 | num_bottleneck_blocks: 6 33 | estimate_occlusion_map: True 34 | dense_motion_params: 35 | block_expansion: 64 36 | max_features: 1024 37 | num_blocks: 5 38 | scale_factor: 0.25 39 | discriminator_params: 40 | scales: [1] 41 | block_expansion: 32 42 | max_features: 512 43 | num_blocks: 4 44 | use_kp: True 45 | 46 | 47 | train_params: 48 | num_epochs: 150 49 | num_repeats: 75 50 | epoch_milestones: [] 51 | lr_generator: 2.0e-4 52 | lr_discriminator: 2.0e-4 53 | lr_kp_detector: 2.0e-4 54 | batch_size: 4 55 | scales: [1, 0.5, 0.25, 0.125] 56 | checkpoint_freq: 10 57 | transform_params: 58 | sigma_affine: 0.05 59 | sigma_tps: 0.005 60 | points_tps: 5 61 | loss_weights: 62 | generator_gan: 1 63 | discriminator_gan: 1 64 | feature_matching: [10, 10, 10, 10] 65 | perceptual: [10, 10, 10, 10, 10] 66 | equivariance_value: 10 67 | equivariance_jacobian: 10 68 | kp_distance: 10 69 | kp_prior: 0 70 | kp_scale: 0 71 | depth_constraint: 0 72 | 73 | reconstruction_params: 74 | num_videos: 1000 75 | format: '.mp4' 76 | 77 | animate_params: 78 | num_pairs: 50 79 | format: '.mp4' 80 | normalization_params: 81 | adapt_movement_scale: False 82 | use_relative_movement: True 83 | use_relative_jacobian: True 84 | 85 | visualizer_params: 86 | kp_size: 5 87 | draw_border: True 88 | colormap: 'gist_rainbow' 89 | -------------------------------------------------------------------------------- /crop-video.py: 
-------------------------------------------------------------------------------- 1 | import face_alignment 2 | import skimage.io 3 | import numpy 4 | from argparse import ArgumentParser 5 | from skimage import img_as_ubyte 6 | from skimage.transform import resize 7 | from tqdm import tqdm 8 | import os 9 | import imageio 10 | import numpy as np 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | 14 | def extract_bbox(frame, fa): 15 | if max(frame.shape[0], frame.shape[1]) > 640: 16 | scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0 17 | frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor))) 18 | frame = img_as_ubyte(frame) 19 | else: 20 | scale_factor = 1 21 | frame = frame[..., :3] 22 | bboxes = fa.face_detector.detect_from_image(frame[..., ::-1]) 23 | if len(bboxes) == 0: 24 | return [] 25 | return np.array(bboxes)[:, :-1] * scale_factor 26 | 27 | 28 | 29 | def bb_intersection_over_union(boxA, boxB): 30 | xA = max(boxA[0], boxB[0]) 31 | yA = max(boxA[1], boxB[1]) 32 | xB = min(boxA[2], boxB[2]) 33 | yB = min(boxA[3], boxB[3]) 34 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 35 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 36 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 37 | iou = interArea / float(boxAArea + boxBArea - interArea) 38 | return iou 39 | 40 | 41 | def join(tube_bbox, bbox): 42 | xA = min(tube_bbox[0], bbox[0]) 43 | yA = min(tube_bbox[1], bbox[1]) 44 | xB = max(tube_bbox[2], bbox[2]) 45 | yB = max(tube_bbox[3], bbox[3]) 46 | return (xA, yA, xB, yB) 47 | 48 | 49 | def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1): 50 | left, top, right, bot = tube_bbox 51 | width = right - left 52 | height = bot - top 53 | 54 | #Computing aspect preserving bbox 55 | width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) 56 | height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) 57 | 58 | left = int(left - width_increase * width) 59 | top = int(top - height_increase * height) 60 | right = int(right + width_increase * width) 61 | bot = int(bot + height_increase * height) 62 | 63 | top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1]) 64 | h, w = bot - top, right - left 65 | 66 | start = start / fps 67 | end = end / fps 68 | time = end - start 69 | 70 | scale = f'{image_shape[0]}:{image_shape[1]}' 71 | 72 | return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4' 73 | 74 | 75 | def compute_bbox_trajectories(trajectories, fps, frame_shape, args): 76 | commands = [] 77 | for i, (bbox, tube_bbox, start, end) in enumerate(trajectories): 78 | if (end - start) > args.min_frames: 79 | command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase) 80 | commands.append(command) 81 | return commands 82 | 83 | 84 | def process_video(args): 85 | device = 'cpu' if args.cpu else 'cuda' 86 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device) 87 | video = imageio.get_reader(args.inp) 88 | 89 | trajectories = [] 90 | previous_frame = None 91 | fps = video.get_meta_data()['fps'] 92 | commands = [] 93 | try: 94 | for i, frame in tqdm(enumerate(video)): 95 | frame_shape = frame.shape 96 | bboxes = extract_bbox(frame, fa) 97 | ## For each trajectory check the 
criterion 98 | not_valid_trajectories = [] 99 | valid_trajectories = [] 100 | 101 | for trajectory in trajectories: 102 | tube_bbox = trajectory[0] 103 | intersection = 0 104 | for bbox in bboxes: 105 | intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox)) 106 | if intersection > args.iou_with_initial: 107 | valid_trajectories.append(trajectory) 108 | else: 109 | not_valid_trajectories.append(trajectory) 110 | 111 | commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args) 112 | trajectories = valid_trajectories 113 | 114 | ## Assign bbox to trajectories, create new trajectories 115 | for bbox in bboxes: 116 | intersection = 0 117 | current_trajectory = None 118 | for trajectory in trajectories: 119 | tube_bbox = trajectory[0] 120 | current_intersection = bb_intersection_over_union(tube_bbox, bbox) 121 | if intersection < current_intersection and current_intersection > args.iou_with_initial: 122 | intersection = bb_intersection_over_union(tube_bbox, bbox) 123 | current_trajectory = trajectory 124 | 125 | ## Create new trajectory 126 | if current_trajectory is None: 127 | trajectories.append([bbox, bbox, i, i]) 128 | else: 129 | current_trajectory[3] = i 130 | current_trajectory[1] = join(current_trajectory[1], bbox) 131 | 132 | 133 | except IndexError as e: 134 | raise (e) 135 | 136 | commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args) 137 | return commands 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = ArgumentParser() 142 | 143 | parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))), 144 | help="Image shape") 145 | parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount') 146 | parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox") 147 | parser.add_argument("--inp", required=True, help='Input image or video') 148 | parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames') 149 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") 150 | 151 | 152 | args = parser.parse_args() 153 | 154 | commands = process_video(args) 155 | for command in commands: 156 | print (command) 157 | 158 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import csv 4 | import pdb 5 | import numpy as np 6 | 7 | def create_csv(path): 8 | videos = os.listdir(path) 9 | source = videos.copy() 10 | driving = videos.copy() 11 | random.shuffle(source) 12 | random.shuffle(driving) 13 | source = np.array(source).reshape(-1,1) 14 | driving = np.array(driving).reshape(-1,1) 15 | zeros = np.zeros((len(source),1)) 16 | content = np.concatenate((source,driving,zeros),1) 17 | f = open('vox256.csv','w',encoding='utf-8') 18 | csv_writer = csv.writer(f) 19 | csv_writer.writerow(["source","driving","frame"]) 20 | csv_writer.writerows(content) 21 | f.close() 22 | 23 | 24 | if __name__ == '__main__': 25 | create_csv('/data/fhongac/origDataset/vox1/test') -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import os, sys 4 | import yaml 5 | from argparse import ArgumentParser 6 | from tqdm import tqdm 7 | import 
modules.generator as GEN 8 | import imageio 9 | import numpy as np 10 | from skimage.transform import resize 11 | from skimage import img_as_ubyte 12 | import torch 13 | from sync_batchnorm import DataParallelWithCallback 14 | import depth 15 | from modules.keypoint_detector import KPDetector 16 | from animate import normalize_kp 17 | from scipy.spatial import ConvexHull 18 | from collections import OrderedDict 19 | import pdb 20 | import cv2 21 | if sys.version_info[0] < 3: 22 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") 23 | 24 | def load_checkpoints(config_path, checkpoint_path, cpu=False): 25 | 26 | with open(config_path) as f: 27 | config = yaml.load(f) 28 | if opt.kp_num != -1: 29 | config['model_params']['common_params']['num_kp'] = opt.kp_num 30 | generator = getattr(GEN, opt.generator)(**config['model_params']['generator_params'],**config['model_params']['common_params']) 31 | if not cpu: 32 | generator.cuda() 33 | config['model_params']['common_params']['num_channels'] = 4 34 | kp_detector = KPDetector(**config['model_params']['kp_detector_params'], 35 | **config['model_params']['common_params']) 36 | if not cpu: 37 | kp_detector.cuda() 38 | if cpu: 39 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) 40 | else: 41 | checkpoint = torch.load(checkpoint_path,map_location="cuda:0") 42 | 43 | ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['generator'].items()) 44 | generator.load_state_dict(ckp_generator) 45 | ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['kp_detector'].items()) 46 | kp_detector.load_state_dict(ckp_kp_detector) 47 | 48 | if not cpu: 49 | generator = DataParallelWithCallback(generator) 50 | kp_detector = DataParallelWithCallback(kp_detector) 51 | 52 | generator.eval() 53 | kp_detector.eval() 54 | 55 | return generator, kp_detector 56 | 57 | 58 | def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False): 59 | sources = [] 60 | drivings = [] 61 | with torch.no_grad(): 62 | predictions = [] 63 | depth_gray = [] 64 | source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) 65 | driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3) 66 | if not cpu: 67 | source = source.cuda() 68 | driving = driving.cuda() 69 | outputs = depth_decoder(depth_encoder(source)) 70 | depth_source = outputs[("disp", 0)] 71 | 72 | outputs = depth_decoder(depth_encoder(driving[:, :, 0])) 73 | depth_driving = outputs[("disp", 0)] 74 | source_kp = torch.cat((source,depth_source),1) 75 | driving_kp = torch.cat((driving[:, :, 0],depth_driving),1) 76 | 77 | kp_source = kp_detector(source_kp) 78 | kp_driving_initial = kp_detector(driving_kp) 79 | 80 | # kp_source = kp_detector(source) 81 | # kp_driving_initial = kp_detector(driving[:, :, 0]) 82 | 83 | for frame_idx in tqdm(range(driving.shape[2])): 84 | driving_frame = driving[:, :, frame_idx] 85 | 86 | if not cpu: 87 | driving_frame = driving_frame.cuda() 88 | outputs = depth_decoder(depth_encoder(driving_frame)) 89 | depth_map = outputs[("disp", 0)] 90 | 91 | gray_driving = np.transpose(depth_map.data.cpu().numpy(), [0, 2, 3, 1])[0] 92 | gray_driving = 1-gray_driving/np.max(gray_driving) 93 | 94 | frame = torch.cat((driving_frame,depth_map),1) 95 | kp_driving = kp_detector(frame) 96 | 97 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, 98 | 
kp_driving_initial=kp_driving_initial, use_relative_movement=relative, 99 | use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale) 100 | out = generator(source, kp_source=kp_source, kp_driving=kp_norm,source_depth = depth_source, driving_depth = depth_map) 101 | 102 | drivings.append(np.transpose(driving_frame.data.cpu().numpy(), [0, 2, 3, 1])[0]) 103 | sources.append(np.transpose(source.data.cpu().numpy(), [0, 2, 3, 1])[0]) 104 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 105 | depth_gray.append(gray_driving) 106 | return sources, drivings, predictions,depth_gray 107 | 108 | 109 | def find_best_frame(source, driving, cpu=False): 110 | import face_alignment 111 | 112 | def normalize_kp(kp): 113 | kp = kp - kp.mean(axis=0, keepdims=True) 114 | area = ConvexHull(kp[:, :2]).volume 115 | area = np.sqrt(area) 116 | kp[:, :2] = kp[:, :2] / area 117 | return kp 118 | 119 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True, 120 | device='cpu' if cpu else 'cuda') 121 | kp_source = fa.get_landmarks(255 * source)[0] 122 | kp_source = normalize_kp(kp_source) 123 | norm = float('inf') 124 | frame_num = 0 125 | for i, image in tqdm(enumerate(driving)): 126 | kp_driving = fa.get_landmarks(255 * image)[0] 127 | kp_driving = normalize_kp(kp_driving) 128 | new_norm = (np.abs(kp_source - kp_driving) ** 2).sum() 129 | if new_norm < norm: 130 | norm = new_norm 131 | frame_num = i 132 | return frame_num 133 | 134 | if __name__ == "__main__": 135 | parser = ArgumentParser() 136 | parser.add_argument("--config", required=True, help="path to config") 137 | parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore") 138 | 139 | parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image") 140 | parser.add_argument("--driving_video", default='sup-mat/source.png', help="path to driving video") 141 | parser.add_argument("--result_video", default='result.mp4', help="path to output") 142 | 143 | parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates") 144 | parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints") 145 | parser.add_argument("--generator", type=str, required=True) 146 | parser.add_argument("--kp_num", type=int, required=True) 147 | 148 | 149 | parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true", 150 | help="Generate from the frame that is the most alligned with source. 
(Only for faces, requires face_aligment lib)") 151 | 152 | parser.add_argument("--best_frame", dest="best_frame", type=int, default=None, 153 | help="Set frame to start from.") 154 | 155 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") 156 | 157 | 158 | parser.set_defaults(relative=False) 159 | parser.set_defaults(adapt_scale=False) 160 | 161 | opt = parser.parse_args() 162 | 163 | depth_encoder = depth.ResnetEncoder(18, False) 164 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)) 165 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth') 166 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth') 167 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()} 168 | depth_encoder.load_state_dict(filtered_dict_enc) 169 | depth_decoder.load_state_dict(loaded_dict_dec) 170 | depth_encoder.eval() 171 | depth_decoder.eval() 172 | if not opt.cpu: 173 | depth_encoder.cuda() 174 | depth_decoder.cuda() 175 | 176 | source_image = imageio.imread(opt.source_image) 177 | reader = imageio.get_reader(opt.driving_video) 178 | fps = reader.get_meta_data()['fps'] 179 | driving_video = [] 180 | try: 181 | for im in reader: 182 | driving_video.append(im) 183 | except RuntimeError: 184 | pass 185 | reader.close() 186 | 187 | source_image = resize(source_image, (256, 256))[..., :3] 188 | driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video] 189 | generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu) 190 | 191 | if opt.find_best_frame or opt.best_frame is not None: 192 | i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu) 193 | print ("Best frame: " + str(i)) 194 | driving_forward = driving_video[i:] 195 | driving_backward = driving_video[:(i+1)][::-1] 196 | sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 197 | sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 198 | predictions = predictions_backward[::-1] + predictions_forward[1:] 199 | sources = sources_backward[::-1] + sources_forward[1:] 200 | drivings = drivings_backward[::-1] + drivings_forward[1:] 201 | depth_gray = depth_backward[::-1] + depth_forward[1:] 202 | 203 | else: 204 | # predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 205 | sources, drivings, predictions,depth_gray = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 206 | imageio.mimsave(opt.result_video, [img_as_ubyte(p) for p in predictions], fps=fps) 207 | # imageio.mimsave(opt.result_video, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps) 208 | # imageio.mimsave("gray.mp4", depth_gray, fps=fps) 209 | # merge the gray video 210 | # animation = np.array(imageio.mimread(opt.result_video,memtest=False)) 211 | # gray = np.array(imageio.mimread("gray.mp4",memtest=False)) 212 | 213 | # src_dst = 
animation[:,:,:512,:] 214 | # animate = animation[:,:,512:,:] 215 | # merge = np.concatenate((src_dst,gray,animate),2) 216 | # imageio.mimsave(opt.result_video, animate, fps=fps) 217 | #Transfer to gif 218 | # from moviepy.editor import * 219 | # clip = (VideoFileClip(opt.result_video)) 220 | # clip.write_gif("{}.gif".format(opt.result_video)) -------------------------------------------------------------------------------- /demo_multi.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import os, sys 4 | import yaml 5 | from argparse import ArgumentParser 6 | from tqdm import tqdm 7 | import modules.generator as GEN 8 | import imageio 9 | import numpy as np 10 | from skimage.transform import resize 11 | from skimage import img_as_ubyte 12 | import torch 13 | from sync_batchnorm import DataParallelWithCallback 14 | import depth 15 | from modules.keypoint_detector import KPDetector 16 | from animate import normalize_kp 17 | from scipy.spatial import ConvexHull 18 | from collections import OrderedDict 19 | import pdb 20 | import cv2 21 | from glob import glob 22 | if sys.version_info[0] < 3: 23 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") 24 | 25 | def load_checkpoints(config_path, checkpoint_path, cpu=False): 26 | 27 | with open(config_path) as f: 28 | config = yaml.load(f,Loader=yaml.FullLoader) 29 | if opt.kp_num != -1: 30 | config['model_params']['common_params']['num_kp'] = opt.kp_num 31 | generator = getattr(GEN, opt.generator)(**config['model_params']['generator_params'],**config['model_params']['common_params']) 32 | if not cpu: 33 | generator.cuda() 34 | config['model_params']['common_params']['num_channels'] = 4 35 | kp_detector = KPDetector(**config['model_params']['kp_detector_params'], 36 | **config['model_params']['common_params']) 37 | if not cpu: 38 | kp_detector.cuda() 39 | if cpu: 40 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) 41 | else: 42 | checkpoint = torch.load(checkpoint_path,map_location="cuda:0") 43 | 44 | ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['generator'].items()) 45 | generator.load_state_dict(ckp_generator) 46 | ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['kp_detector'].items()) 47 | kp_detector.load_state_dict(ckp_kp_detector) 48 | 49 | if not cpu: 50 | generator = DataParallelWithCallback(generator) 51 | kp_detector = DataParallelWithCallback(kp_detector) 52 | 53 | generator.eval() 54 | kp_detector.eval() 55 | 56 | return generator, kp_detector 57 | 58 | 59 | def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False): 60 | sources = [] 61 | drivings = [] 62 | with torch.no_grad(): 63 | predictions = [] 64 | depth_gray = [] 65 | source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) 66 | driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3) 67 | if not cpu: 68 | source = source.cuda() 69 | driving = driving.cuda() 70 | outputs = depth_decoder(depth_encoder(source)) 71 | depth_source = outputs[("disp", 0)] 72 | 73 | outputs = depth_decoder(depth_encoder(driving[:, :, 0])) 74 | depth_driving = outputs[("disp", 0)] 75 | source_kp = torch.cat((source,depth_source),1) 76 | driving_kp = torch.cat((driving[:, :, 0],depth_driving),1) 77 | 78 | kp_source = kp_detector(source_kp) 79 | 
kp_driving_initial = kp_detector(driving_kp) 80 | 81 | # kp_source = kp_detector(source) 82 | # kp_driving_initial = kp_detector(driving[:, :, 0]) 83 | 84 | for frame_idx in tqdm(range(driving.shape[2])): 85 | driving_frame = driving[:, :, frame_idx] 86 | 87 | if not cpu: 88 | driving_frame = driving_frame.cuda() 89 | outputs = depth_decoder(depth_encoder(driving_frame)) 90 | depth_map = outputs[("disp", 0)] 91 | 92 | gray_driving = np.transpose(depth_map.data.cpu().numpy(), [0, 2, 3, 1])[0] 93 | gray_driving = 1-gray_driving/np.max(gray_driving) 94 | 95 | frame = torch.cat((driving_frame,depth_map),1) 96 | kp_driving = kp_detector(frame) 97 | 98 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, 99 | kp_driving_initial=kp_driving_initial, use_relative_movement=relative, 100 | use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale) 101 | out = generator(source, kp_source=kp_source, kp_driving=kp_norm,source_depth = depth_source, driving_depth = depth_map) 102 | 103 | drivings.append(np.transpose(driving_frame.data.cpu().numpy(), [0, 2, 3, 1])[0]) 104 | sources.append(np.transpose(source.data.cpu().numpy(), [0, 2, 3, 1])[0]) 105 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 106 | depth_gray.append(gray_driving) 107 | return sources, drivings, predictions,depth_gray 108 | 109 | 110 | def find_best_frame(source, driving, cpu=False): 111 | import face_alignment 112 | 113 | def normalize_kp(kp): 114 | kp = kp - kp.mean(axis=0, keepdims=True) 115 | area = ConvexHull(kp[:, :2]).volume 116 | area = np.sqrt(area) 117 | kp[:, :2] = kp[:, :2] / area 118 | return kp 119 | 120 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True, 121 | device='cpu' if cpu else 'cuda') 122 | kp_source = fa.get_landmarks(255 * source)[0] 123 | kp_source = normalize_kp(kp_source) 124 | norm = float('inf') 125 | frame_num = 0 126 | for i, image in tqdm(enumerate(driving)): 127 | kp_driving = fa.get_landmarks(255 * image)[0] 128 | kp_driving = normalize_kp(kp_driving) 129 | new_norm = (np.abs(kp_source - kp_driving) ** 2).sum() 130 | if new_norm < norm: 131 | norm = new_norm 132 | frame_num = i 133 | return frame_num 134 | 135 | if __name__ == "__main__": 136 | parser = ArgumentParser() 137 | parser.add_argument("--config", required=True, help="path to config") 138 | parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore") 139 | 140 | parser.add_argument("--source_folder", default='sup-mat/source.png', help="path to source image") 141 | parser.add_argument("--driving_video", default='sup-mat/source.png', help="path to driving video") 142 | parser.add_argument("--save_folder", default='result.mp4', help="path to output") 143 | 144 | parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates") 145 | parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints") 146 | parser.add_argument("--generator", type=str, required=True) 147 | parser.add_argument("--kp_num", type=int, required=True) 148 | 149 | 150 | parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true", 151 | help="Generate from the frame that is the most alligned with source. 
(Only for faces, requires face_aligment lib)") 152 | 153 | parser.add_argument("--best_frame", dest="best_frame", type=int, default=None, 154 | help="Set frame to start from.") 155 | 156 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") 157 | 158 | 159 | parser.set_defaults(relative=False) 160 | parser.set_defaults(adapt_scale=False) 161 | 162 | opt = parser.parse_args() 163 | 164 | depth_encoder = depth.ResnetEncoder(18, False) 165 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)) 166 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth') 167 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth') 168 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()} 169 | depth_encoder.load_state_dict(filtered_dict_enc) 170 | depth_decoder.load_state_dict(loaded_dict_dec) 171 | depth_encoder.eval() 172 | depth_decoder.eval() 173 | if not opt.cpu: 174 | depth_encoder.cuda() 175 | depth_decoder.cuda() 176 | 177 | reader = imageio.get_reader(opt.driving_video) 178 | fps = reader.get_meta_data()['fps'] 179 | driving_video = [] 180 | try: 181 | for im in reader: 182 | driving_video.append(im) 183 | except RuntimeError: 184 | pass 185 | reader.close() 186 | 187 | driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video] 188 | generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu) 189 | if not os.path.exists(opt.save_folder): 190 | os.makedirs(opt.save_folder) 191 | sources = glob(opt.source_folder+"/*") 192 | for src in tqdm(sources): 193 | source_image = imageio.imread(src) 194 | source_image = resize(source_image, (256, 256))[..., :3] 195 | # predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 196 | sources, drivings, predictions,depth_gray = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 197 | fn = os.path.basename(src) 198 | imageio.mimsave(os.path.join(opt.save_folder,fn+'.mp4'), [img_as_ubyte(p) for p in predictions], fps=fps) 199 | -------------------------------------------------------------------------------- /depth/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_encoder import ResnetEncoder 2 | from .depth_decoder import DepthDecoder 3 | from .pose_decoder import PoseDecoder 4 | from .pose_cnn import PoseCNN 5 | 6 | -------------------------------------------------------------------------------- /depth/depth_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | from collections import OrderedDict 14 | from depth.layers import * 15 | 16 | 17 | class DepthDecoder(nn.Module): 18 | def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True): 19 | super(DepthDecoder, self).__init__() 20 | 21 | self.num_output_channels = num_output_channels 22 | self.use_skips = use_skips 23 | self.upsample_mode = 'nearest' 24 | self.scales = scales 25 | 26 | self.num_ch_enc = num_ch_enc 27 | self.num_ch_dec = np.array([16, 32, 64, 128, 256]) 28 | 29 | # decoder 30 | self.convs = OrderedDict() 31 | for i in range(4, -1, -1): 32 | # upconv_0 33 | num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] 34 | num_ch_out = self.num_ch_dec[i] 35 | self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) 36 | 37 | # upconv_1 38 | num_ch_in = self.num_ch_dec[i] 39 | if self.use_skips and i > 0: 40 | num_ch_in += self.num_ch_enc[i - 1] 41 | num_ch_out = self.num_ch_dec[i] 42 | self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) 43 | 44 | for s in self.scales: 45 | self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels) 46 | 47 | self.decoder = nn.ModuleList(list(self.convs.values())) 48 | self.sigmoid = nn.Sigmoid() 49 | 50 | def forward(self, input_features): 51 | self.outputs = {} 52 | 53 | # decoder 54 | x = input_features[-1] 55 | for i in range(4, -1, -1): 56 | x = self.convs[("upconv", i, 0)](x) 57 | x = [upsample(x)] 58 | if self.use_skips and i > 0: 59 | x += [input_features[i - 1]] 60 | x = torch.cat(x, 1) 61 | x = self.convs[("upconv", i, 1)](x) 62 | if i in self.scales: 63 | self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x)) 64 | 65 | return self.outputs 66 | -------------------------------------------------------------------------------- /depth/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | import pdb 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def disp_to_depth(disp, min_depth, max_depth): 17 | """Convert network's sigmoid output into depth prediction 18 | The formula for this conversion is given in the 'additional considerations' 19 | section of the paper. 
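    For example, plugging in the min_depth=0.1 and max_depth=100.0 values listed in depth/models/opt.json: disp = 0 gives scaled_disp = 1/100 and hence depth = 100, while disp = 1 gives scaled_disp = 10 and hence depth = 0.1, so the sigmoid output sweeps the whole [min_depth, max_depth] range.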
20 | """ 21 | min_disp = 1 / max_depth 22 | max_disp = 1 / min_depth 23 | scaled_disp = min_disp + (max_disp - min_disp) * disp 24 | depth = 1 / scaled_disp 25 | return scaled_disp, depth 26 | 27 | 28 | def transformation_from_parameters(axisangle, translation, invert=False): 29 | """Convert the network's (axisangle, translation) output into a 4x4 matrix 30 | """ 31 | R = rot_from_axisangle(axisangle) 32 | t = translation.clone() 33 | 34 | if invert: 35 | R = R.transpose(1, 2) 36 | t *= -1 37 | 38 | T = get_translation_matrix(t) 39 | 40 | if invert: 41 | M = torch.matmul(R, T) 42 | else: 43 | M = torch.matmul(T, R) 44 | 45 | return M 46 | 47 | 48 | def get_translation_matrix(translation_vector): 49 | """Convert a translation vector into a 4x4 transformation matrix 50 | """ 51 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 52 | 53 | t = translation_vector.contiguous().view(-1, 3, 1) 54 | 55 | T[:, 0, 0] = 1 56 | T[:, 1, 1] = 1 57 | T[:, 2, 2] = 1 58 | T[:, 3, 3] = 1 59 | T[:, :3, 3, None] = t 60 | 61 | return T 62 | 63 | 64 | def rot_from_axisangle(vec): 65 | """Convert an axisangle rotation into a 4x4 transformation matrix 66 | (adapted from https://github.com/Wallacoloo/printipi) 67 | Input 'vec' has to be Bx1x3 68 | """ 69 | angle = torch.norm(vec, 2, 2, True) 70 | axis = vec / (angle + 1e-7) 71 | 72 | ca = torch.cos(angle) 73 | sa = torch.sin(angle) 74 | C = 1 - ca 75 | 76 | x = axis[..., 0].unsqueeze(1) 77 | y = axis[..., 1].unsqueeze(1) 78 | z = axis[..., 2].unsqueeze(1) 79 | 80 | xs = x * sa 81 | ys = y * sa 82 | zs = z * sa 83 | xC = x * C 84 | yC = y * C 85 | zC = z * C 86 | xyC = x * yC 87 | yzC = y * zC 88 | zxC = z * xC 89 | 90 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 91 | 92 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 93 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 94 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 95 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 96 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 97 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 98 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 99 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 100 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 101 | rot[:, 3, 3] = 1 102 | 103 | return rot 104 | 105 | 106 | class ConvBlock(nn.Module): 107 | """Layer to perform a convolution followed by ELU 108 | """ 109 | def __init__(self, in_channels, out_channels): 110 | super(ConvBlock, self).__init__() 111 | 112 | self.conv = Conv3x3(in_channels, out_channels) 113 | self.nonlin = nn.ELU(inplace=True) 114 | 115 | def forward(self, x): 116 | out = self.conv(x) 117 | out = self.nonlin(out) 118 | return out 119 | 120 | 121 | class Conv3x3(nn.Module): 122 | """Layer to pad and convolve input 123 | """ 124 | def __init__(self, in_channels, out_channels, use_refl=True): 125 | super(Conv3x3, self).__init__() 126 | 127 | if use_refl: 128 | self.pad = nn.ReflectionPad2d(1) 129 | else: 130 | self.pad = nn.ZeroPad2d(1) 131 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 132 | 133 | def forward(self, x): 134 | out = self.pad(x) 135 | out = self.conv(out) 136 | return out 137 | 138 | 139 | class BackprojectDepth(nn.Module): 140 | """Layer to transform a depth image into a point cloud 141 | """ 142 | def __init__(self, batch_size, height, width): 143 | super(BackprojectDepth, self).__init__() 144 | 145 | self.batch_size = batch_size 146 | self.height = height 147 | self.width = width 148 | 149 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 150 | self.id_coords = 
np.stack(meshgrid, axis=0).astype(np.float32) 151 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 152 | requires_grad=False) 153 | 154 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 155 | requires_grad=False) 156 | 157 | self.pix_coords = torch.unsqueeze(torch.stack( 158 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 159 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 160 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), 161 | requires_grad=False) 162 | 163 | def forward(self, depth, K,scale): 164 | K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc() 165 | b,n,n = K.shape 166 | inv_K = torch.linalg.inv(K) 167 | #inv_K = torch.cholesky_inverse(K) 168 | pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda() 169 | inv_K = torch.cat([inv_K,pad],-1) 170 | pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda() 171 | inv_K = torch.cat([inv_K,pad],1) 172 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 173 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 174 | cam_points = torch.cat([cam_points, self.ones], 1) 175 | 176 | return cam_points 177 | 178 | 179 | class Project3D(nn.Module): 180 | """Layer which projects 3D points into a camera with intrinsics K and at position T 181 | """ 182 | def __init__(self, batch_size, height, width, eps=1e-7): 183 | super(Project3D, self).__init__() 184 | 185 | self.batch_size = batch_size 186 | self.height = height 187 | self.width = width 188 | self.eps = eps 189 | 190 | def forward(self, points, K, T,scale=0): 191 | # K[0, :] *= self.width // (2 ** scale) 192 | # K[1, :] *= self.height // (2 ** scale) 193 | K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc() 194 | b,n,n = K.shape 195 | pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda() 196 | K = torch.cat([K,pad],-1) 197 | pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda() 198 | K = torch.cat([K,pad],1) 199 | P = torch.matmul(K, T)[:, :3, :] 200 | 201 | cam_points = torch.matmul(P, points) 202 | 203 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 204 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) 205 | pix_coords = pix_coords.permute(0, 2, 3, 1) 206 | pix_coords[..., 0] /= self.width - 1 207 | pix_coords[..., 1] /= self.height - 1 208 | pix_coords = (pix_coords - 0.5) * 2 209 | return pix_coords 210 | 211 | 212 | def upsample(x): 213 | """Upsample input tensor by a factor of 2 214 | """ 215 | return F.interpolate(x, scale_factor=2, mode="nearest") 216 | 217 | 218 | def get_smooth_loss(disp, img): 219 | """Computes the smoothness loss for a disparity image 220 | The color image is used for edge-aware smoothness 221 | """ 222 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) 223 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) 224 | 225 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) 226 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) 227 | 228 | grad_disp_x *= torch.exp(-grad_img_x) 229 | grad_disp_y *= torch.exp(-grad_img_y) 230 | 231 | return grad_disp_x.mean() + grad_disp_y.mean() 232 | 233 | 234 | class SSIM(nn.Module): 235 | """Layer to compute the SSIM loss between a pair of images 236 | """ 237 | def __init__(self): 238 | super(SSIM, self).__init__() 239 | self.mu_x_pool = nn.AvgPool2d(3, 1) 240 | self.mu_y_pool = nn.AvgPool2d(3, 
1) 241 | self.sig_x_pool = nn.AvgPool2d(3, 1) 242 | self.sig_y_pool = nn.AvgPool2d(3, 1) 243 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 244 | 245 | self.refl = nn.ReflectionPad2d(1) 246 | 247 | self.C1 = 0.01 ** 2 248 | self.C2 = 0.03 ** 2 249 | 250 | def forward(self, x, y): 251 | x = self.refl(x) 252 | y = self.refl(y) 253 | 254 | mu_x = self.mu_x_pool(x) 255 | mu_y = self.mu_y_pool(y) 256 | 257 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 258 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 259 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 260 | 261 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 262 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 263 | 264 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 265 | 266 | 267 | def compute_depth_errors(gt, pred): 268 | """Computation of error metrics between predicted and ground truth depths 269 | """ 270 | thresh = torch.max((gt / pred), (pred / gt)) 271 | a1 = (thresh < 1.25 ).float().mean() 272 | a2 = (thresh < 1.25 ** 2).float().mean() 273 | a3 = (thresh < 1.25 ** 3).float().mean() 274 | 275 | rmse = (gt - pred) ** 2 276 | rmse = torch.sqrt(rmse.mean()) 277 | 278 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 279 | rmse_log = torch.sqrt(rmse_log.mean()) 280 | 281 | abs_rel = torch.mean(torch.abs(gt - pred) / gt) 282 | 283 | sq_rel = torch.mean((gt - pred) ** 2 / gt) 284 | 285 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 286 | -------------------------------------------------------------------------------- /depth/models/opt.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_path": "/data/fhongac/workspace/src/talkhead/kitti_data", 3 | "log_dir": "tmp", 4 | "model_name": "taking_head_10w", 5 | "split": "eigen_zhou", 6 | "num_layers": 18, 7 | "dataset": "celeb", 8 | "png": false, 9 | "height": 224, 10 | "width": 224, 11 | "disparity_smoothness": 0.001, 12 | "scales": [ 13 | 0, 14 | 1, 15 | 2, 16 | 3 17 | ], 18 | "sample_num": 100000, 19 | "min_depth": 0.1, 20 | "max_depth": 100.0, 21 | "use_stereo": false, 22 | "frame_ids": [ 23 | 0, 24 | -1, 25 | 1 26 | ], 27 | "batch_size": 64, 28 | "learning_rate": 1e-05, 29 | "num_epochs": 20, 30 | "scheduler_step_size": 15, 31 | "v1_multiscale": false, 32 | "avg_reprojection": false, 33 | "disable_automasking": false, 34 | "predictive_mask": false, 35 | "no_ssim": false, 36 | "weights_init": "pretrained", 37 | "pose_model_input": "pairs", 38 | "pose_model_type": "separate_resnet", 39 | "no_cuda": false, 40 | "num_workers": 12, 41 | "load_weights_folder": null, 42 | "models_to_load": [ 43 | "encoder", 44 | "depth", 45 | "pose_encoder", 46 | "pose" 47 | ], 48 | "log_frequency": 250, 49 | "save_frequency": 1, 50 | "eval_stereo": false, 51 | "eval_mono": false, 52 | "disable_median_scaling": false, 53 | "pred_depth_scale_factor": 1, 54 | "ext_disp_to_eval": null, 55 | "eval_split": "eigen", 56 | "save_pred_disps": false, 57 | "no_eval": false, 58 | "eval_eigen_to_benchmark": false, 59 | "eval_out_dir": null, 60 | "post_process": false 61 | } -------------------------------------------------------------------------------- /depth/pose_cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
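# A minimal usage sketch of the PoseCNN module defined below (output shapes follow
# from its forward pass; the 224x224 resolution and random input are illustrative):
#
#   import torch
#
#   pose_net = PoseCNN(num_input_frames=2)
#   pair = torch.rand(1, 3 * 2, 224, 224)     # two RGB frames stacked along the channel axis
#   axisangle, translation = pose_net(pair)   # each has shape (1, 1, 1, 3)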
6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class PoseCNN(nn.Module): 14 | def __init__(self, num_input_frames): 15 | super(PoseCNN, self).__init__() 16 | 17 | self.num_input_frames = num_input_frames 18 | 19 | self.convs = {} 20 | self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3) 21 | self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2) 22 | self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1) 23 | self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1) 24 | self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1) 25 | self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1) 26 | self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1) 27 | 28 | self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1) 29 | 30 | self.num_convs = len(self.convs) 31 | 32 | self.relu = nn.ReLU(True) 33 | 34 | self.net = nn.ModuleList(list(self.convs.values())) 35 | 36 | def forward(self, out): 37 | 38 | for i in range(self.num_convs): 39 | out = self.convs[i](out) 40 | out = self.relu(out) 41 | 42 | out = self.pose_conv(out) 43 | out = out.mean(3).mean(2) 44 | 45 | out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6) 46 | 47 | axisangle = out[..., :3] 48 | translation = out[..., 3:] 49 | 50 | return axisangle, translation 51 | -------------------------------------------------------------------------------- /depth/pose_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | from collections import OrderedDict 12 | import pdb 13 | import torch.nn.functional as F 14 | # from options import MonodepthOptions 15 | # options = MonodepthOptions() 16 | # opts = options.parse() 17 | class PoseDecoder(nn.Module): 18 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1): 19 | super(PoseDecoder, self).__init__() 20 | self.num_ch_enc = num_ch_enc 21 | self.num_input_features = num_input_features 22 | 23 | if num_frames_to_predict_for is None: 24 | num_frames_to_predict_for = num_input_features - 1 25 | self.num_frames_to_predict_for = num_frames_to_predict_for 26 | 27 | self.convs = OrderedDict() 28 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) 29 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) 30 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) 31 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1) 32 | self.convs[("intrinsics", 'focal')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1) 33 | self.convs[("intrinsics", 'offset')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1) 34 | 35 | self.relu = nn.ReLU() 36 | self.net = nn.ModuleList(list(self.convs.values())) 37 | 38 | def forward(self, input_features): 39 | last_features = [f[-1] for f in input_features] 40 | 41 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] 42 | cat_features = torch.cat(cat_features, 1) 43 | 44 | feat = cat_features 45 | for i in range(2): 46 | feat = self.convs[("pose", i)](feat) 47 | feat = self.relu(feat) 48 | out = self.convs[("pose", 2)](feat) 49 | 50 | out = out.mean(3).mean(2) 51 | out = 0.01 
* out.view(-1, self.num_frames_to_predict_for, 1, 6) 52 | 53 | axisangle = out[..., :3] 54 | translation = out[..., 3:] 55 | 56 | #add_intrinsics_head 57 | scales = torch.tensor([256,256]).cuda() 58 | focals = F.softplus(self.convs[("intrinsics", 'focal')](feat)).mean(3).mean(2)*scales 59 | offset = (F.softplus(self.convs[("intrinsics", 'offset')](feat)).mean(3).mean(2)+0.5)*scales 60 | #focals = F.softplus(self.convs[("intrinsics",'focal')](feat).mean(3).mean(2)) 61 | #offset = F.softplus(self.convs[("intrinsics",'offset')](feat).mean(3).mean(2)) 62 | eyes = torch.eye(2).cuda() 63 | b,xy = focals.shape 64 | focals = focals.unsqueeze(-1).expand(b,xy,xy) 65 | eyes = eyes.unsqueeze(0).expand(b,xy,xy) 66 | intrin = focals*eyes 67 | offset = offset.view(b,2,1).contiguous() 68 | intrin = torch.cat([intrin,offset],-1) 69 | pad = torch.tensor([0.0,0.0,1.0]).view(1,1,3).expand(b,1,3).cuda() 70 | intrinsics = torch.cat([intrin,pad],1) 71 | return axisangle, translation,intrinsics 72 | -------------------------------------------------------------------------------- /depth/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.models as models 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | 17 | class ResNetMultiImageInput(models.ResNet): 18 | """Constructs a resnet model with varying number of input images. 19 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 20 | """ 21 | def __init__(self, block, layers, num_classes=1000, num_input_images=1): 22 | super(ResNetMultiImageInput, self).__init__(block, layers) 23 | self.inplanes = 64 24 | self.conv1 = nn.Conv2d( 25 | num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 26 | self.bn1 = nn.BatchNorm2d(64) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 29 | self.layer1 = self._make_layer(block, 64, layers[0]) 30 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 31 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 32 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 33 | 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 37 | elif isinstance(m, nn.BatchNorm2d): 38 | nn.init.constant_(m.weight, 1) 39 | nn.init.constant_(m.bias, 0) 40 | 41 | 42 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): 43 | """Constructs a ResNet model. 44 | Args: 45 | num_layers (int): Number of resnet layers. 
Must be 18 or 50 46 | pretrained (bool): If True, returns a model pre-trained on ImageNet 47 | num_input_images (int): Number of frames stacked as input 48 | """ 49 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 50 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 51 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 52 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) 53 | 54 | if pretrained: 55 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) 56 | loaded['conv1.weight'] = torch.cat( 57 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images 58 | model.load_state_dict(loaded) 59 | return model 60 | 61 | 62 | class ResnetEncoder(nn.Module): 63 | """Pytorch module for a resnet encoder 64 | """ 65 | def __init__(self, num_layers, pretrained, num_input_images=1): 66 | super(ResnetEncoder, self).__init__() 67 | 68 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 69 | 70 | resnets = {18: models.resnet18, 71 | 34: models.resnet34, 72 | 50: models.resnet50, 73 | 101: models.resnet101, 74 | 152: models.resnet152} 75 | 76 | if num_layers not in resnets: 77 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 78 | 79 | if num_input_images > 1: 80 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 81 | else: 82 | self.encoder = resnets[num_layers](pretrained) 83 | 84 | if num_layers > 34: 85 | self.num_ch_enc[1:] *= 4 86 | 87 | def forward(self, input_image): 88 | self.features = [] 89 | x = (input_image - 0.45) / 0.225 90 | x = self.encoder.conv1(x) 91 | x = self.encoder.bn1(x) 92 | self.features.append(self.encoder.relu(x)) 93 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 94 | self.features.append(self.encoder.layer2(self.features[-1])) 95 | self.features.append(self.encoder.layer3(self.features[-1])) 96 | self.features.append(self.encoder.layer4(self.features[-1])) 97 | 98 | return self.features 99 | -------------------------------------------------------------------------------- /evaluation_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 
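    For the EvaluationDataset implemented below, the pairs_list argument is read with pandas and is expected to provide at least 'source' and 'driving' columns containing image paths (see __init__).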
13 | """ 14 | from torch.utils.data import Dataset 15 | import torchvision.transforms as T 16 | from PIL import Image 17 | from PIL import ImageFile 18 | from skimage import io, img_as_float32 19 | import numpy as np 20 | ImageFile.LOAD_TRUNCATED_IMAGES = True 21 | # from data.image_folder import make_dataset 22 | # from PIL import Image 23 | import os 24 | import torch 25 | import pdb 26 | import pandas as pd 27 | 28 | class EvaluationDataset(Dataset): 29 | """A template dataset class for you to implement custom datasets.""" 30 | 31 | def __init__(self, dataroot, pairs_list=None): 32 | """Initialize this dataset class. 33 | 34 | Parameters: 35 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 36 | 37 | A few things can be done here. 38 | - save the options (have been done in BaseDataset) 39 | - get image paths and meta information of the dataset. 40 | - define the image transformation. 41 | """ 42 | # save the option and dataset root 43 | # get the image paths of your dataset; 44 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 45 | # define the default transform function. You can use ; You can also define your custom transform function 46 | self.dataroot = dataroot 47 | # self.videos = self.videos[5000] 48 | self.frame_shape = (3,256,256) 49 | test_videos = os.listdir(os.path.join(self.dataroot,'test')) 50 | self.videos = test_videos 51 | pairs = pd.read_csv(pairs_list) 52 | self.source = pairs['source'].tolist() 53 | self.driving = pairs['driving'].tolist() 54 | # self.pose_anchors = pairs['best_frame'].tolist() 55 | 56 | self.transforms = T.Compose([T.ToTensor(), 57 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 58 | def __getitem__(self, idx): 59 | """Return a data point and its metadata information. 60 | 61 | Parameters: 62 | index -- a random integer for data indexing 63 | 64 | Returns: 65 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 66 | 67 | Step 1: get a random image path: e.g., path = self.image_paths[index] 68 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 69 | Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) 70 | Step 4: return a data point as a dictionary. 
71 | """ 72 | path_source = self.source[idx] 73 | path_driving = self.driving[idx] 74 | # path_anchor = self.pose_anchors[idx] 75 | anchor = '' 76 | source = img_as_float32(io.imread(path_source)) 77 | source = np.array(source, dtype='float32') 78 | source = torch.tensor(source.transpose((2, 0, 1))) 79 | 80 | driving = img_as_float32(io.imread(path_driving)) 81 | driving = np.array(driving, dtype='float32') 82 | driving = torch.tensor(driving.transpose((2, 0, 1))) 83 | 84 | # anchor = img_as_float32(io.imread(path_anchor)) 85 | # anchor = np.array(anchor, dtype='float32') 86 | # anchor = torch.tensor(anchor.transpose((2, 0, 1))) 87 | 88 | # source = Image.open(path_source).convert('RGB') 89 | # driving = Image.open(path_driving).convert('RGB') 90 | # source = T.ToTensor()(source) 91 | # driving = T.ToTensor()(driving) 92 | return {'source': source, 'driving': driving, 'path_source': path_source,'path_driving':path_driving, 'anchor': anchor} 93 | 94 | def __len__(self): 95 | """Return the total number of images.""" 96 | return len(self.source) 97 | 98 | 99 | -------------------------------------------------------------------------------- /face-alignment/.gitattributes: -------------------------------------------------------------------------------- 1 | *.py linguist-language=python 2 | *.ipynb linguist-documentation 3 | -------------------------------------------------------------------------------- /face-alignment/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /face-alignment/Dockerfile: -------------------------------------------------------------------------------- 1 | # Based on a older version of https://github.com/pytorch/pytorch/blob/master/Dockerfile 2 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 3 | 4 | RUN apt-get update && apt-get install -y --no-install-recommends \ 5 | build-essential \ 6 | cmake \ 7 | git \ 8 | curl \ 9 | vim \ 10 | ca-certificates \ 11 | libboost-all-dev \ 12 | python-qt4 \ 13 | libjpeg-dev \ 14 | libpng-dev &&\ 15 | rm -rf /var/lib/apt/lists/* 16 | 17 | RUN curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 18 | chmod +x ~/miniconda.sh && \ 19 | ~/miniconda.sh -b -p /opt/conda && \ 20 | rm ~/miniconda.sh 21 | 22 | ENV PATH /opt/conda/bin:$PATH 23 | 24 | RUN conda config --set always_yes yes --set changeps1 no && conda update -q conda 25 | RUN conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 26 | 27 | # Install face-alignment package 28 | WORKDIR /workspace 29 | RUN chmod -R a+w /workspace 30 | RUN git clone https://github.com/1adrianb/face-alignment 31 | WORKDIR /workspace/face-alignment 32 | RUN pip install -r requirements.txt 33 | RUN python setup.py install 34 | -------------------------------------------------------------------------------- /face-alignment/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Adrian Bulat 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /face-alignment/README.md: -------------------------------------------------------------------------------- 1 | # Face Recognition 2 | 3 | Detect facial landmarks from Python using the world's most accurate face alignment network, capable of detecting points in both 2D and 3D coordinates. 4 | 5 | Build using [FAN](https://www.adrianbulat.com)'s state-of-the-art deep learning based face alignment method. 6 | 7 |

8 | 9 | **Note:** The Lua version is available [here](https://github.com/1adrianb/2D-and-3D-face-alignment). 10 | 11 | For numerical evaluations it is highly recommended to use the Lua version, which uses models identical to the ones evaluated in the paper. More models will be added soon. 12 | 13 | [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![Test Face alignmnet](https://github.com/1adrianb/face-alignment/workflows/Test%20Face%20alignmnet/badge.svg)](https://github.com/1adrianb/face-alignment/actions?query=workflow%3A%22Test+Face+alignmnet%22) [![Anaconda-Server Badge](https://anaconda.org/1adrianb/face_alignment/badges/version.svg)](https://anaconda.org/1adrianb/face_alignment) 14 | [![PyPI version](https://badge.fury.io/py/face-alignment.svg)](https://pypi.org/project/face-alignment/) 15 | 16 | ## Features 17 | 18 | #### Detect 2D facial landmarks in pictures 19 | 20 |

21 | 22 |

23 | 24 | ```python 25 | import face_alignment 26 | from skimage import io 27 | 28 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) 29 | 30 | input = io.imread('../test/assets/aflw-test.jpg') 31 | preds = fa.get_landmarks(input) 32 | ``` 33 | 34 | #### Detect 3D facial landmarks in pictures 35 | 36 |

37 | 38 |

39 | 40 | ```python 41 | import face_alignment 42 | from skimage import io 43 | 44 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False) 45 | 46 | input = io.imread('../test/assets/aflw-test.jpg') 47 | preds = fa.get_landmarks(input) 48 | ``` 49 | 50 | #### Process an entire directory in one go 51 | 52 | ```python 53 | import face_alignment 54 | from skimage import io 55 | 56 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) 57 | 58 | preds = fa.get_landmarks_from_directory('../test/assets/') 59 | ``` 60 | 61 | #### Detect the landmarks using a specific face detector. 62 | 63 | By default the package will use the SFD face detector. However, users can alternatively use dlib, BlazeFace, or pre-existing ground truth bounding boxes. 64 | 65 | ```python 66 | import face_alignment 67 | 68 | # sfd for SFD, dlib for Dlib and folder for existing bounding boxes. 69 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, face_detector='sfd') 70 | ``` 71 | 72 | #### Running on CPU/GPU 73 | In order to specify the device (GPU or CPU) on which the code will run, one can explicitly pass the device flag: 74 | 75 | ```python 76 | import face_alignment 77 | 78 | # cuda for CUDA 79 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cpu') 80 | ``` 81 | 82 | Please also see the ``examples`` folder. 83 | 84 | ## Installation 85 | 86 | ### Requirements 87 | 88 | * Python 3.5+ (it may work with other versions too). The last version with support for Python 2.7 was v1.1.1 89 | * Linux, Windows or macOS 90 | * pytorch (>=1.5) 91 | 92 | While not required, for optimal performance (especially for the detector) it is **highly** recommended to run the code using a CUDA-enabled GPU. 93 | 94 | ### Binaries 95 | 96 | The easiest way to install it is using either pip or conda: 97 | 98 | | **Using pip** | **Using conda** | 99 | |------------------------------|--------------------------------------------| 100 | | `pip install face-alignment` | `conda install -c 1adrianb face_alignment` | 101 | | | | 102 | 103 | Alternatively, below, you can find instructions to build it from source. 104 | 105 | ### From source 106 | 107 | Install pytorch and pytorch dependencies. Please check the [pytorch readme](https://github.com/pytorch/pytorch) for this. 108 | 109 | #### Get the Face Alignment source code 110 | ```bash 111 | git clone https://github.com/1adrianb/face-alignment 112 | ``` 113 | #### Install the Face Alignment lib 114 | ```bash 115 | pip install -r requirements.txt 116 | python setup.py install 117 | ``` 118 | 119 | ### Docker image 120 | 121 | A Dockerfile is provided to build images with CUDA and cuDNN support. For more instructions about running and building a docker image check the original Docker documentation. 122 | ``` 123 | docker build -t face-alignment . 124 | ``` 125 | 126 | ## How does it work? 127 | 128 | While here the work is presented as a black box, if you want to know more about the inner workings of the method please check the original paper either on arXiv or my [webpage](https://www.adrianbulat.com). 129 | 130 | ## Contributions 131 | 132 | All contributions are welcomed. If you encounter any issue (including examples of images where it fails) feel free to open an issue. If you plan to add a new feature please open an issue to discuss this prior to making a pull request.
133 | 134 | ## Citation 135 | 136 | ``` 137 | @inproceedings{bulat2017far, 138 | title={How far are we from solving the 2D \& 3D Face Alignment problem? (and a dataset of 230,000 3D facial landmarks)}, 139 | author={Bulat, Adrian and Tzimiropoulos, Georgios}, 140 | booktitle={International Conference on Computer Vision}, 141 | year={2017} 142 | } 143 | ``` 144 | 145 | For citing dlib, pytorch or any other packages used here please check the original page of their respective authors. 146 | 147 | ## Acknowledgements 148 | 149 | * To the [pytorch](http://pytorch.org/) team for providing such an awesome deeplearning framework 150 | * To [my supervisor](http://www.cs.nott.ac.uk/~pszyt/) for his patience and suggestions. 151 | * To all other python developers that made available the rest of the packages used in this repository. -------------------------------------------------------------------------------- /face-alignment/conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set version = "1.3.4" %} 2 | 3 | package: 4 | name: face_alignment 5 | version: {{ version }} 6 | 7 | source: 8 | path: .. 9 | 10 | build: 11 | number: 1 12 | noarch: python 13 | script: python setup.py install --single-version-externally-managed --record=record.txt 14 | 15 | requirements: 16 | build: 17 | - setuptools 18 | - python 19 | run: 20 | - python 21 | - pytorch 22 | - numpy 23 | - scikit-image 24 | - scipy 25 | - opencv 26 | - tqdm 27 | - numba 28 | 29 | about: 30 | home: https://github.com/1adrianb/face-alignment 31 | license: BSD 32 | license_file: LICENSE 33 | summary: A 2D and 3D face alignment libray in python 34 | 35 | extra: 36 | recipe-maintainers: 37 | - 1adrianb 38 | -------------------------------------------------------------------------------- /face-alignment/docs/images/2dlandmarks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/face-alignment/docs/images/2dlandmarks.png -------------------------------------------------------------------------------- /face-alignment/docs/images/face-alignment-adrian.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/face-alignment/docs/images/face-alignment-adrian.gif -------------------------------------------------------------------------------- /face-alignment/examples/detect_landmarks_in_image.py: -------------------------------------------------------------------------------- 1 | import face_alignment 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from skimage import io 5 | import collections 6 | 7 | 8 | # Optionally set detector and some additional detector parameters 9 | face_detector = 'sfd' 10 | face_detector_kwargs = { 11 | "filter_threshold" : 0.8 12 | } 13 | 14 | # Run the 3D face alignment on a test image, without CUDA. 
15 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu', flip_input=True, 16 | face_detector=face_detector, face_detector_kwargs=face_detector_kwargs) 17 | 18 | try: 19 | input_img = io.imread('../test/assets/aflw-test.jpg') 20 | except FileNotFoundError: 21 | input_img = io.imread('test/assets/aflw-test.jpg') 22 | 23 | preds = fa.get_landmarks(input_img)[-1] 24 | 25 | # 2D-Plot 26 | plot_style = dict(marker='o', 27 | markersize=4, 28 | linestyle='-', 29 | lw=2) 30 | 31 | pred_type = collections.namedtuple('prediction_type', ['slice', 'color']) 32 | pred_types = {'face': pred_type(slice(0, 17), (0.682, 0.780, 0.909, 0.5)), 33 | 'eyebrow1': pred_type(slice(17, 22), (1.0, 0.498, 0.055, 0.4)), 34 | 'eyebrow2': pred_type(slice(22, 27), (1.0, 0.498, 0.055, 0.4)), 35 | 'nose': pred_type(slice(27, 31), (0.345, 0.239, 0.443, 0.4)), 36 | 'nostril': pred_type(slice(31, 36), (0.345, 0.239, 0.443, 0.4)), 37 | 'eye1': pred_type(slice(36, 42), (0.596, 0.875, 0.541, 0.3)), 38 | 'eye2': pred_type(slice(42, 48), (0.596, 0.875, 0.541, 0.3)), 39 | 'lips': pred_type(slice(48, 60), (0.596, 0.875, 0.541, 0.3)), 40 | 'teeth': pred_type(slice(60, 68), (0.596, 0.875, 0.541, 0.4)) 41 | } 42 | 43 | fig = plt.figure(figsize=plt.figaspect(.5)) 44 | ax = fig.add_subplot(1, 2, 1) 45 | ax.imshow(input_img) 46 | 47 | for pred_type in pred_types.values(): 48 | ax.plot(preds[pred_type.slice, 0], 49 | preds[pred_type.slice, 1], 50 | color=pred_type.color, **plot_style) 51 | 52 | ax.axis('off') 53 | 54 | # 3D-Plot 55 | ax = fig.add_subplot(1, 2, 2, projection='3d') 56 | surf = ax.scatter(preds[:, 0] * 1.2, 57 | preds[:, 1], 58 | preds[:, 2], 59 | c='cyan', 60 | alpha=1.0, 61 | edgecolor='b') 62 | 63 | for pred_type in pred_types.values(): 64 | ax.plot3D(preds[pred_type.slice, 0] * 1.2, 65 | preds[pred_type.slice, 1], 66 | preds[pred_type.slice, 2], color='blue') 67 | 68 | ax.view_init(elev=90., azim=90.) 
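# (top-down viewing angle; the x-axis is flipped on the next line so that the 3D landmarks roughly match the image orientation)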
69 | ax.set_xlim(ax.get_xlim()[::-1]) 70 | plt.show() 71 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __author__ = """Adrian Bulat""" 4 | __email__ = 'adrian@adrianbulat.com' 5 | __version__ = '1.3.4' 6 | 7 | from .api import FaceAlignment, LandmarksType, NetworkSize -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/__init__.py: -------------------------------------------------------------------------------- 1 | from .blazeface_detector import BlazeFaceDetector as FaceDetector 2 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/blazeface_detector.py: -------------------------------------------------------------------------------- 1 | from torch.utils.model_zoo import load_url 2 | 3 | from ..core import FaceDetector 4 | from ...utils import load_file_from_url 5 | 6 | from .net_blazeface import BlazeFace 7 | from .detect import * 8 | 9 | models_urls = { 10 | 'blazeface_weights': 'https://github.com/hollance/BlazeFace-PyTorch/blob/master/blazeface.pth?raw=true', 11 | 'blazeface_anchors': 'https://github.com/hollance/BlazeFace-PyTorch/blob/master/anchors.npy?raw=true' 12 | } 13 | 14 | 15 | class BlazeFaceDetector(FaceDetector): 16 | def __init__(self, device, path_to_detector=None, path_to_anchor=None, verbose=False, 17 | min_score_thresh=0.5, min_suppression_threshold=0.3): 18 | super(BlazeFaceDetector, self).__init__(device, verbose) 19 | 20 | # Initialise the face detector 21 | if path_to_detector is None: 22 | model_weights = load_url(models_urls['blazeface_weights']) 23 | model_anchors = np.load(load_file_from_url(models_urls['blazeface_anchors'])) 24 | else: 25 | model_weights = torch.load(path_to_detector) 26 | model_anchors = np.load(path_to_anchor) 27 | 28 | self.face_detector = BlazeFace() 29 | self.face_detector.load_state_dict(model_weights) 30 | self.face_detector.load_anchors_from_npy(model_anchors, device) 31 | 32 | # Optionally change the thresholds: 33 | self.face_detector.min_score_thresh = min_score_thresh 34 | self.face_detector.min_suppression_threshold = min_suppression_threshold 35 | 36 | self.face_detector.to(device) 37 | self.face_detector.eval() 38 | 39 | def detect_from_image(self, tensor_or_path): 40 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 41 | 42 | bboxlist = detect(self.face_detector, image, device=self.device)[0] 43 | 44 | return bboxlist 45 | 46 | def detect_from_batch(self, tensor): 47 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device) 48 | return bboxlists 49 | 50 | @property 51 | def reference_scale(self): 52 | return 195 53 | 54 | @property 55 | def reference_x_shift(self): 56 | return 0 57 | 58 | @property 59 | def reference_y_shift(self): 60 | return 0 61 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn.functional as F 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from .utils import * 8 | 9 | 10 | def detect(net, img, device): 11 | H, W, C = img.shape 12 | orig_size = min(H, W) 13 | img, (xshift, yshift) = resize_and_crop_image(img, 128) 14 | preds = net.predict_on_image(img) 15 | 16 | if 0 == len(preds): 17 | return [[]] 18 | 19 | shift = np.array([xshift, yshift] * 2) 20 | scores = preds[:, -1:] 21 | 22 | # TODO: ugly 23 | # reverses, x and y to adapt with face-alignment code 24 | locs = np.concatenate((preds[:, 1:2], preds[:, 0:1], preds[:, 3:4], preds[:, 2:3]), axis=1) 25 | return [np.concatenate((locs * orig_size + shift, scores), axis=1)] 26 | 27 | 28 | def batch_detect(net, img_batch, device): 29 | """ 30 | Inputs: 31 | - img_batch: a numpy array or tensor of shape (Batch size, Channels, Height, Width) 32 | Outputs: 33 | - list of 2-dim numpy arrays with shape (faces_on_this_image, 5): x1, y1, x2, y2, confidence 34 | (x1, y1) - top left corner, (x2, y2) - bottom right corner 35 | """ 36 | B, C, H, W = img_batch.shape 37 | orig_size = min(H, W) 38 | 39 | if isinstance(img_batch, torch.Tensor): 40 | img_batch = img_batch.cpu().numpy() 41 | 42 | img_batch = img_batch.transpose((0, 2, 3, 1)) 43 | 44 | imgs, (xshift, yshift) = resize_and_crop_batch(img_batch, 128) 45 | preds = net.predict_on_batch(imgs) 46 | bboxlists = [] 47 | for pred in preds: 48 | shift = np.array([xshift, yshift] * 2) 49 | scores = pred[:, -1:] 50 | locs = np.concatenate((pred[:, 1:2], pred[:, 0:1], pred[:, 3:4], pred[:, 2:3]), axis=1) 51 | bboxlists.append(np.concatenate((locs * orig_size + shift, scores), axis=1)) 52 | 53 | return bboxlists 54 | 55 | 56 | def flip_detect(net, img, device): 57 | img = cv2.flip(img, 1) 58 | b = detect(net, img, device) 59 | 60 | bboxlist = np.zeros(b.shape) 61 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 62 | bboxlist[:, 1] = b[:, 1] 63 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 64 | bboxlist[:, 3] = b[:, 3] 65 | bboxlist[:, 4] = b[:, 4] 66 | return bboxlist 67 | 68 | 69 | def pts_to_bb(pts): 70 | min_x, min_y = np.min(pts, axis=0) 71 | max_x, max_y = np.max(pts, axis=0) 72 | return np.array([min_x, min_y, max_x, max_y]) 73 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/blazeface/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA): 6 | # initialize the dimensions of the image to be resized and 7 | # grab the image size 8 | dim = None 9 | (h, w) = image.shape[:2] 10 | 11 | # if both the width and height are None, then return the 12 | # original image 13 | if width is None and height is None: 14 | return image 15 | 16 | # check to see if the width is None 17 | if width is None: 18 | # calculate the ratio of the height and construct the 19 | # dimensions 20 | r = height / float(h) 21 | dim = (int(w * r), height) 22 | 23 | # otherwise, the height is None 24 | else: 25 | # calculate the ratio of the width and construct the 26 | # dimensions 27 | r = width / float(w) 28 | dim = (width, int(h * r)) 29 | 30 | # resize the image 31 | resized = cv2.resize(image, dim, interpolation=inter) 32 | 33 | # return the resized image 34 | return resized 35 | 36 | 37 | def resize_and_crop_image(image, dim): 38 | if image.shape[0] > image.shape[1]: 39 | img = image_resize(image, width=dim) 40 | yshift, xshift = (image.shape[0] - image.shape[1]) // 2, 0 41 | 
y_start = (img.shape[0] - img.shape[1]) // 2 42 | y_end = y_start + dim 43 | return img[y_start:y_end, :, :], (xshift, yshift) 44 | else: 45 | img = image_resize(image, height=dim) 46 | yshift, xshift = 0, (image.shape[1] - image.shape[0]) // 2 47 | x_start = (img.shape[1] - img.shape[0]) // 2 48 | x_end = x_start + dim 49 | return img[:, x_start:x_end, :], (xshift, yshift) 50 | 51 | 52 | def resize_and_crop_batch(frames, dim): 53 | """ 54 | Center crop + resize to (dim x dim) 55 | inputs: 56 | - frames: list of images (numpy arrays) 57 | - dim: output dimension size 58 | """ 59 | smframes = [] 60 | xshift, yshift = 0, 0 61 | for i in range(len(frames)): 62 | smframe, (xshift, yshift) = resize_and_crop_image(frames[i], dim) 63 | smframes.append(smframe) 64 | smframes = np.stack(smframes) 65 | return smframes, (xshift, yshift) 66 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from skimage import io 7 | 8 | 9 | class FaceDetector(object): 10 | """An abstract class representing a face detector. 11 | 12 | Any other face detection implementation must subclass it. All subclasses 13 | must implement ``detect_from_image``, that return a list of detected 14 | bounding boxes. Optionally, for speed considerations detect from path is 15 | recommended. 16 | """ 17 | 18 | def __init__(self, device, verbose): 19 | self.device = device 20 | self.verbose = verbose 21 | 22 | if verbose: 23 | if 'cpu' in device: 24 | logger = logging.getLogger(__name__) 25 | logger.warning("Detection running on CPU, this may be potentially slow.") 26 | 27 | if 'cpu' not in device and 'cuda' not in device: 28 | if verbose: 29 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) 30 | raise ValueError 31 | 32 | def detect_from_image(self, tensor_or_path): 33 | """Detects faces in a given image. 34 | 35 | This function detects the faces present in a provided BGR(usually) 36 | image. The input can be either the image itself or the path to it. 37 | 38 | Arguments: 39 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path 40 | to an image or the image itself. 41 | 42 | Example:: 43 | 44 | >>> path_to_image = 'data/image_01.jpg' 45 | ... detected_faces = detect_from_image(path_to_image) 46 | [A list of bounding boxes (x1, y1, x2, y2)] 47 | >>> image = cv2.imread(path_to_image) 48 | ... detected_faces = detect_from_image(image) 49 | [A list of bounding boxes (x1, y1, x2, y2)] 50 | 51 | """ 52 | raise NotImplementedError 53 | 54 | def detect_from_batch(self, tensor): 55 | """Detects faces in a given image. 56 | 57 | This function detects the faces present in a provided BGR(usually) 58 | image. The input can be either the image itself or the path to it. 59 | 60 | Arguments: 61 | tensor {torch.tensor} -- image batch tensor. 62 | 63 | Example:: 64 | 65 | >>> path_to_image = 'data/image_01.jpg' 66 | ... detected_faces = detect_from_image(path_to_image) 67 | [A list of bounding boxes (x1, y1, x2, y2)] 68 | >>> image = cv2.imread(path_to_image) 69 | ... 
detected_faces = detect_from_image(image) 70 | [A list of bounding boxes (x1, y1, x2, y2)] 71 | 72 | """ 73 | raise NotImplementedError 74 | 75 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): 76 | """Detects faces from all the images present in a given directory. 77 | 78 | Arguments: 79 | path {string} -- a string containing a path that points to the folder containing the images 80 | 81 | Keyword Arguments: 82 | extensions {list} -- list of string containing the extensions to be 83 | consider in the following format: ``.extension_name`` (default: 84 | {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the 85 | folder recursively (default: {False}) show_progress_bar {bool} -- 86 | display a progressbar (default: {True}) 87 | 88 | Example: 89 | >>> directory = 'data' 90 | ... detected_faces = detect_from_directory(directory) 91 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} 92 | 93 | """ 94 | if self.verbose: 95 | logger = logging.getLogger(__name__) 96 | 97 | if len(extensions) == 0: 98 | if self.verbose: 99 | logger.error("Expected at list one extension, but none was received.") 100 | raise ValueError 101 | 102 | if self.verbose: 103 | logger.info("Constructing the list of images.") 104 | additional_pattern = '/**/*' if recursive else '/*' 105 | files = [] 106 | for extension in extensions: 107 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) 108 | 109 | if self.verbose: 110 | logger.info("Finished searching for images. %s images found", len(files)) 111 | logger.info("Preparing to run the detection.") 112 | 113 | predictions = {} 114 | for image_path in tqdm(files, disable=not show_progress_bar): 115 | if self.verbose: 116 | logger.info("Running the face detector on image: %s", image_path) 117 | predictions[image_path] = self.detect_from_image(image_path) 118 | 119 | if self.verbose: 120 | logger.info("The detector was successfully run on all %s images", len(files)) 121 | 122 | return predictions 123 | 124 | @property 125 | def reference_scale(self): 126 | raise NotImplementedError 127 | 128 | @property 129 | def reference_x_shift(self): 130 | raise NotImplementedError 131 | 132 | @property 133 | def reference_y_shift(self): 134 | raise NotImplementedError 135 | 136 | @staticmethod 137 | def tensor_or_path_to_ndarray(tensor_or_path): 138 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray 139 | 140 | Arguments: 141 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself 142 | """ 143 | if isinstance(tensor_or_path, str): 144 | return io.imread(tensor_or_path) 145 | elif torch.is_tensor(tensor_or_path): 146 | return tensor_or_path.cpu().numpy() 147 | elif isinstance(tensor_or_path, np.ndarray): 148 | return tensor_or_path 149 | else: 150 | raise TypeError 151 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/dlib/__init__.py: -------------------------------------------------------------------------------- 1 | from .dlib_detector import DlibDetector as FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/dlib/dlib_detector.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import cv2 3 | import dlib 4 | 5 | from ..core import FaceDetector 6 | from ...utils import 
load_file_from_url 7 | 8 | 9 | class DlibDetector(FaceDetector): 10 | def __init__(self, device, path_to_detector=None, verbose=False): 11 | super().__init__(device, verbose) 12 | 13 | warnings.warn('Warning: this detector is deprecated. Please use a different one, i.e.: S3FD.') 14 | 15 | # Initialise the face detector 16 | if 'cuda' in device: 17 | if path_to_detector is None: 18 | path_to_detector = load_file_from_url( 19 | "https://www.adrianbulat.com/downloads/dlib/mmod_human_face_detector.dat") 20 | 21 | self.face_detector = dlib.cnn_face_detection_model_v1(path_to_detector) 22 | else: 23 | self.face_detector = dlib.get_frontal_face_detector() 24 | 25 | def detect_from_image(self, tensor_or_path): 26 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 27 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 28 | 29 | detected_faces = self.face_detector(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)) 30 | 31 | if 'cuda' not in self.device: 32 | detected_faces = [[d.left(), d.top(), d.right(), d.bottom()] for d in detected_faces] 33 | else: 34 | detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces] 35 | 36 | return detected_faces 37 | 38 | @property 39 | def reference_scale(self): 40 | return 195 41 | 42 | @property 43 | def reference_x_shift(self): 44 | return 0 45 | 46 | @property 47 | def reference_y_shift(self): 48 | return 0 49 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/folder/__init__.py: -------------------------------------------------------------------------------- 1 | from .folder_detector import FolderDetector as FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/folder/folder_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | from ..core import FaceDetector 6 | 7 | 8 | class FolderDetector(FaceDetector): 9 | '''This is a simple helper module that assumes the faces were detected already 10 | (either previously or are provided as ground truth). 11 | 12 | The class expects to find the bounding boxes in the same format used by 13 | the rest of face detectors, mainly ``list[(x1,y1,x2,y2),...]``. 
14 | For each image the detector will search for a file with the same name and with one of the 15 | following extensions: .npy, .t7 or .pth 16 | 17 | ''' 18 | 19 | def __init__(self, device, path_to_detector=None, verbose=False): 20 | super(FolderDetector, self).__init__(device, verbose) 21 | 22 | def detect_from_image(self, tensor_or_path): 23 | # Only strings supported 24 | if not isinstance(tensor_or_path, str): 25 | raise ValueError 26 | 27 | base_name = os.path.splitext(tensor_or_path)[0] 28 | 29 | if os.path.isfile(base_name + '.npy'): 30 | detected_faces = np.load(base_name + '.npy') 31 | elif os.path.isfile(base_name + '.t7'): 32 | detected_faces = torch.load(base_name + '.t7') 33 | elif os.path.isfile(base_name + '.pth'): 34 | detected_faces = torch.load(base_name + '.pth') 35 | else: 36 | raise FileNotFoundError 37 | 38 | if not isinstance(detected_faces, list): 39 | raise TypeError 40 | 41 | return detected_faces 42 | 43 | @property 44 | def reference_scale(self): 45 | return 195 46 | 47 | @property 48 | def reference_x_shift(self): 49 | return 0 50 | 51 | @property 52 | def reference_y_shift(self): 53 | return 0 54 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfd_detector import SFDDetector as FaceDetector -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/bbox.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | 5 | def nms(dets, thresh): 6 | if 0 == len(dets): 7 | return [] 8 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] 9 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 10 | order = scores.argsort()[::-1] 11 | 12 | keep = [] 13 | while order.size > 0: 14 | i = order[0] 15 | keep.append(i) 16 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) 17 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) 18 | 19 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) 20 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h) 21 | 22 | inds = np.where(ovr <= thresh)[0] 23 | order = order[inds + 1] 24 | 25 | return keep 26 | 27 | 28 | def encode(matched, priors, variances): 29 | """Encode the variances from the priorbox layers into the ground truth boxes 30 | we have matched (based on jaccard overlap) with the prior boxes. 31 | Args: 32 | matched: (tensor) Coords of ground truth for each prior in point-form 33 | Shape: [num_priors, 4]. 34 | priors: (tensor) Prior boxes in center-offset form 35 | Shape: [num_priors,4]. 36 | variances: (list[float]) Variances of priorboxes 37 | Return: 38 | encoded boxes (tensor), Shape: [num_priors, 4] 39 | """ 40 | 41 | # dist b/t match center and prior's center 42 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 43 | # encode variance 44 | g_cxcy /= (variances[0] * priors[:, 2:]) 45 | # match wh / prior wh 46 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 47 | g_wh = np.log(g_wh) / variances[1] 48 | 49 | # return target for smooth_l1_loss 50 | return np.concatenate([g_cxcy, g_wh], 1) # [num_priors,4] 51 | 52 | 53 | def decode(loc, priors, variances): 54 | """Decode locations from predictions using priors to undo 55 | the encoding we did for offset regression at train time. 
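        Concretely, the box center is recovered as priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        the box size as priors[:, 2:] * exp(loc[:, 2:] * variances[1]), and each resulting (center, size)
        pair is then converted to (x1, y1, x2, y2) corner coordinates before being returned.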
56 | Args: 57 | loc (tensor): location predictions for loc layers, 58 | Shape: [num_priors,4] 59 | priors (tensor): Prior boxes in center-offset form. 60 | Shape: [num_priors,4]. 61 | variances: (list[float]) Variances of priorboxes 62 | Return: 63 | decoded bounding box predictions 64 | """ 65 | 66 | boxes = np.concatenate(( 67 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 68 | priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), 1) 69 | boxes[:, :2] -= boxes[:, 2:] / 2 70 | boxes[:, 2:] += boxes[:, :2] 71 | return boxes 72 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from .bbox import * 8 | 9 | 10 | def detect(net, img, device): 11 | img = img.transpose(2, 0, 1) 12 | # Creates a batch of 1 13 | img = np.expand_dims(img, 0) 14 | 15 | img = torch.from_numpy(img.copy()).to(device, dtype=torch.float32) 16 | 17 | return batch_detect(net, img, device) 18 | 19 | 20 | def batch_detect(net, img_batch, device): 21 | """ 22 | Inputs: 23 | - img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width) 24 | """ 25 | 26 | if 'cuda' in device: 27 | torch.backends.cudnn.benchmark = True 28 | 29 | batch_size = img_batch.size(0) 30 | img_batch = img_batch.to(device, dtype=torch.float32) 31 | 32 | img_batch = img_batch.flip(-3) # RGB to BGR 33 | img_batch = img_batch - torch.tensor([104.0, 117.0, 123.0], device=device).view(1, 3, 1, 1) 34 | 35 | with torch.no_grad(): 36 | olist = net(img_batch) # patched uint8_t overflow error 37 | 38 | for i in range(len(olist) // 2): 39 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 40 | 41 | olist = [oelem.data.cpu().numpy() for oelem in olist] 42 | 43 | bboxlists = get_predictions(olist, batch_size) 44 | return bboxlists 45 | 46 | 47 | def get_predictions(olist, batch_size): 48 | bboxlists = [] 49 | variances = [0.1, 0.2] 50 | for j in range(batch_size): 51 | bboxlist = [] 52 | for i in range(len(olist) // 2): 53 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 54 | stride = 2**(i + 2) # 4,8,16,32,64,128 55 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 56 | for Iindex, hindex, windex in poss: 57 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 58 | score = ocls[j, 1, hindex, windex] 59 | loc = oreg[j, :, hindex, windex].copy().reshape(1, 4) 60 | priors = np.array([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) 61 | box = decode(loc, priors, variances) 62 | x1, y1, x2, y2 = box[0] 63 | bboxlist.append([x1, y1, x2, y2, score]) 64 | 65 | bboxlists.append(bboxlist) 66 | 67 | bboxlists = np.array(bboxlists) 68 | return bboxlists 69 | 70 | 71 | def flip_detect(net, img, device): 72 | img = cv2.flip(img, 1) 73 | b = detect(net, img, device) 74 | 75 | bboxlist = np.zeros(b.shape) 76 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 77 | bboxlist[:, 1] = b[:, 1] 78 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 79 | bboxlist[:, 3] = b[:, 3] 80 | bboxlist[:, 4] = b[:, 4] 81 | return bboxlist 82 | 83 | 84 | def pts_to_bb(pts): 85 | min_x, min_y = np.min(pts, axis=0) 86 | max_x, max_y = np.max(pts, axis=0) 87 | return np.array([min_x, min_y, max_x, max_y]) 88 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/net_s3fd.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Norm(nn.Module): 7 | def __init__(self, n_channels, scale=1.0): 8 | super(L2Norm, self).__init__() 9 | self.n_channels = n_channels 10 | self.scale = scale 11 | self.eps = 1e-10 12 | self.weight = nn.Parameter(torch.empty(self.n_channels).fill_(self.scale)) 13 | 14 | def forward(self, x): 15 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 16 | x = x / norm * self.weight.view(1, -1, 1, 1) 17 | return x 18 | 19 | 20 | class s3fd(nn.Module): 21 | def __init__(self): 22 | super(s3fd, self).__init__() 23 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 24 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 25 | 26 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 27 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 28 | 29 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 30 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 31 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 32 | 33 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 34 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 35 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 36 | 37 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 39 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | 41 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) 42 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) 43 | 44 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 45 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 46 | 47 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 48 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 49 | 50 | self.conv3_3_norm = L2Norm(256, scale=10) 51 | self.conv4_3_norm = L2Norm(512, scale=8) 52 | self.conv5_3_norm = L2Norm(512, scale=5) 53 | 54 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 55 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 56 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 57 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 58 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 59 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 60 | 61 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) 62 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) 63 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 64 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 65 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) 66 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 67 | 68 | def forward(self, x): 69 | h = F.relu(self.conv1_1(x), inplace=True) 70 | h = F.relu(self.conv1_2(h), inplace=True) 71 | h = F.max_pool2d(h, 2, 2) 72 | 73 | h = F.relu(self.conv2_1(h), inplace=True) 
74 | h = F.relu(self.conv2_2(h), inplace=True) 75 | h = F.max_pool2d(h, 2, 2) 76 | 77 | h = F.relu(self.conv3_1(h), inplace=True) 78 | h = F.relu(self.conv3_2(h), inplace=True) 79 | h = F.relu(self.conv3_3(h), inplace=True) 80 | f3_3 = h 81 | h = F.max_pool2d(h, 2, 2) 82 | 83 | h = F.relu(self.conv4_1(h), inplace=True) 84 | h = F.relu(self.conv4_2(h), inplace=True) 85 | h = F.relu(self.conv4_3(h), inplace=True) 86 | f4_3 = h 87 | h = F.max_pool2d(h, 2, 2) 88 | 89 | h = F.relu(self.conv5_1(h), inplace=True) 90 | h = F.relu(self.conv5_2(h), inplace=True) 91 | h = F.relu(self.conv5_3(h), inplace=True) 92 | f5_3 = h 93 | h = F.max_pool2d(h, 2, 2) 94 | 95 | h = F.relu(self.fc6(h), inplace=True) 96 | h = F.relu(self.fc7(h), inplace=True) 97 | ffc7 = h 98 | h = F.relu(self.conv6_1(h), inplace=True) 99 | h = F.relu(self.conv6_2(h), inplace=True) 100 | f6_2 = h 101 | h = F.relu(self.conv7_1(h), inplace=True) 102 | h = F.relu(self.conv7_2(h), inplace=True) 103 | f7_2 = h 104 | 105 | f3_3 = self.conv3_3_norm(f3_3) 106 | f4_3 = self.conv4_3_norm(f4_3) 107 | f5_3 = self.conv5_3_norm(f5_3) 108 | 109 | cls1 = self.conv3_3_norm_mbox_conf(f3_3) 110 | reg1 = self.conv3_3_norm_mbox_loc(f3_3) 111 | cls2 = self.conv4_3_norm_mbox_conf(f4_3) 112 | reg2 = self.conv4_3_norm_mbox_loc(f4_3) 113 | cls3 = self.conv5_3_norm_mbox_conf(f5_3) 114 | reg3 = self.conv5_3_norm_mbox_loc(f5_3) 115 | cls4 = self.fc7_mbox_conf(ffc7) 116 | reg4 = self.fc7_mbox_loc(ffc7) 117 | cls5 = self.conv6_2_mbox_conf(f6_2) 118 | reg5 = self.conv6_2_mbox_loc(f6_2) 119 | cls6 = self.conv7_2_mbox_conf(f7_2) 120 | reg6 = self.conv7_2_mbox_loc(f7_2) 121 | 122 | # max-out background label 123 | chunk = torch.chunk(cls1, 4, 1) 124 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) 125 | cls1 = torch.cat([bmax, chunk[3]], dim=1) 126 | 127 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] 128 | -------------------------------------------------------------------------------- /face-alignment/face_alignment/detection/sfd/sfd_detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.model_zoo import load_url 3 | 4 | from ..core import FaceDetector 5 | 6 | from .net_s3fd import s3fd 7 | from .bbox import nms 8 | from .detect import detect, batch_detect 9 | 10 | models_urls = { 11 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', 12 | } 13 | 14 | 15 | class SFDDetector(FaceDetector): 16 | '''SF3D Detector. 
17 | ''' 18 | 19 | def __init__(self, device, path_to_detector=None, verbose=False, filter_threshold=0.5): 20 | super(SFDDetector, self).__init__(device, verbose) 21 | 22 | # Initialise the face detector 23 | if path_to_detector is None: 24 | model_weights = load_url(models_urls['s3fd']) 25 | else: 26 | model_weights = torch.load(path_to_detector) 27 | 28 | self.fiter_threshold = filter_threshold 29 | self.face_detector = s3fd() 30 | self.face_detector.load_state_dict(model_weights) 31 | self.face_detector.to(device) 32 | self.face_detector.eval() 33 | 34 | def _filter_bboxes(self, bboxlist): 35 | if len(bboxlist) > 0: 36 | keep = nms(bboxlist, 0.3) 37 | bboxlist = bboxlist[keep, :] 38 | bboxlist = [x for x in bboxlist if x[-1] > self.fiter_threshold] 39 | 40 | return bboxlist 41 | 42 | def detect_from_image(self, tensor_or_path): 43 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 44 | 45 | bboxlist = detect(self.face_detector, image, device=self.device)[0] 46 | bboxlist = self._filter_bboxes(bboxlist) 47 | 48 | return bboxlist 49 | 50 | def detect_from_batch(self, tensor): 51 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device) 52 | 53 | new_bboxlists = [] 54 | for i in range(bboxlists.shape[0]): 55 | bboxlist = bboxlists[i] 56 | bboxlist = self._filter_bboxes(bboxlist) 57 | new_bboxlists.append(bboxlist) 58 | 59 | return new_bboxlists 60 | 61 | @property 62 | def reference_scale(self): 63 | return 195 64 | 65 | @property 66 | def reference_x_shift(self): 67 | return 0 68 | 69 | @property 70 | def reference_y_shift(self): 71 | return 0 72 | -------------------------------------------------------------------------------- /face-alignment/requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | scipy>=0.17.0 3 | scikit-image 4 | numba 5 | -------------------------------------------------------------------------------- /face-alignment/setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.3.4 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:face_alignment/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [metadata] 15 | description_file = README.md 16 | 17 | [bdist_wheel] 18 | universal = 1 19 | 20 | [flake8] 21 | exclude = 22 | .github, 23 | examples, 24 | docs, 25 | .tox, 26 | bin, 27 | dist, 28 | tools, 29 | *.egg-info, 30 | __init__.py, 31 | *.yml 32 | max-line-length = 160 -------------------------------------------------------------------------------- /face-alignment/setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | from os import path 4 | import re 5 | from setuptools import setup, find_packages 6 | # To use consisten encodings 7 | from codecs import open 8 | 9 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py 10 | 11 | 12 | def read(*names, **kwargs): 13 | with io.open( 14 | os.path.join(os.path.dirname(__file__), *names), 15 | encoding=kwargs.get("encoding", "utf8") 16 | ) as fp: 17 | return fp.read() 18 | 19 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py 20 | 21 | 22 | def find_version(*file_paths): 23 | version_file = read(*file_paths) 24 | version_match = re.search(r"^__version__ = 
['\"]([^'\"]*)['\"]", 25 | version_file, re.M) 26 | if version_match: 27 | return version_match.group(1) 28 | raise RuntimeError("Unable to find version string.") 29 | 30 | here = path.abspath(path.dirname(__file__)) 31 | 32 | # Get the long description from the README file 33 | with open(path.join(here, 'README.md'), encoding='utf-8') as readme_file: 34 | long_description = readme_file.read() 35 | 36 | VERSION = find_version('face_alignment', '__init__.py') 37 | 38 | requirements = [ 39 | 'torch', 40 | 'numpy', 41 | 'scipy>=0.17', 42 | 'scikit-image', 43 | 'opencv-python', 44 | 'tqdm', 45 | 'numba', 46 | 'enum34;python_version<"3.4"' 47 | ] 48 | 49 | setup( 50 | name='face_alignment', 51 | version=VERSION, 52 | 53 | description="Detector 2D or 3D face landmarks from Python", 54 | long_description=long_description, 55 | long_description_content_type="text/markdown", 56 | 57 | # Author details 58 | author="Adrian Bulat", 59 | author_email="adrian@adrianbulat.com", 60 | url="https://github.com/1adrianb/face-alignment", 61 | 62 | # Package info 63 | packages=find_packages(exclude=('test',)), 64 | 65 | python_requires='>=3', 66 | install_requires=requirements, 67 | license='BSD', 68 | zip_safe=True, 69 | 70 | classifiers=[ 71 | 'Development Status :: 5 - Production/Stable', 72 | 'Operating System :: OS Independent', 73 | 'License :: OSI Approved :: BSD License', 74 | 'Natural Language :: English', 75 | 76 | # Supported python versions 77 | 'Programming Language :: Python :: 3', 78 | 'Programming Language :: Python :: 3.3', 79 | 'Programming Language :: Python :: 3.4', 80 | 'Programming Language :: Python :: 3.5', 81 | 'Programming Language :: Python :: 3.6', 82 | 'Programming Language :: Python :: 3.7', 83 | 'Programming Language :: Python :: 3.8', 84 | ], 85 | ) 86 | -------------------------------------------------------------------------------- /face-alignment/test/facealignment_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import face_alignment 4 | import sys 5 | import torch 6 | sys.path.append('.') 7 | from face_alignment.utils import get_image 8 | 9 | 10 | class Tester(unittest.TestCase): 11 | def setUp(self) -> None: 12 | self.reference_data = [np.array([[137., 240., -85.907196], 13 | [140., 264., -81.1443], 14 | [143., 288., -76.25633], 15 | [146., 306., -69.01708], 16 | [152., 327., -53.775352], 17 | [161., 342., -30.029667], 18 | [170., 348., -2.792292], 19 | [185., 354., 23.522688], 20 | [212., 360., 38.664257], 21 | [239., 357., 31.747217], 22 | [263., 354., 12.192401], 23 | [284., 348., -10.0569725], 24 | [302., 333., -29.42916], 25 | [314., 315., -41.675602], 26 | [320., 297., -46.924263], 27 | [326., 276., -50.33218], 28 | [335., 252., -53.945686], 29 | [152., 207., -7.6189857], 30 | [164., 201., 6.1879144], 31 | [176., 198., 16.991247], 32 | [188., 198., 24.690582], 33 | [200., 201., 29.248188], 34 | [245., 204., 37.878166], 35 | [257., 201., 37.420483], 36 | [269., 201., 34.163113], 37 | [284., 204., 28.480812], 38 | [299., 216., 18.31863], 39 | [221., 225., 37.93351], 40 | [218., 237., 48.337395], 41 | [215., 249., 60.502884], 42 | [215., 261., 63.353687], 43 | [203., 273., 40.186855], 44 | [209., 276., 45.057003], 45 | [218., 276., 48.56715], 46 | [227., 276., 47.744766], 47 | [233., 276., 45.01401], 48 | [170., 228., 7.166072], 49 | [179., 222., 17.168053], 50 | [188., 222., 19.775822], 51 | [200., 228., 19.06176], 52 | [191., 231., 20.636724], 53 | [179., 231., 16.125824], 54 | 
[248., 231., 28.566122], 55 | [257., 225., 33.024036], 56 | [269., 225., 34.384735], 57 | [278., 231., 27.014532], 58 | [269., 234., 32.867023], 59 | [257., 234., 33.34033], 60 | [185., 306., 29.927242], 61 | [194., 297., 42.611233], 62 | [209., 291., 50.563396], 63 | [215., 291., 52.831104], 64 | [221., 291., 52.9225], 65 | [236., 300., 48.32575], 66 | [248., 309., 38.2375], 67 | [236., 312., 48.377922], 68 | [224., 315., 52.63793], 69 | [212., 315., 52.330444], 70 | [203., 315., 49.552994], 71 | [194., 309., 42.64459], 72 | [188., 303., 30.746407], 73 | [206., 300., 46.514435], 74 | [215., 300., 49.611156], 75 | [224., 300., 49.058918], 76 | [248., 309., 38.084103], 77 | [224., 303., 49.817806], 78 | [215., 303., 49.59815], 79 | [206., 303., 47.13894]], dtype=np.float32)] 80 | 81 | def test_predict_points(self): 82 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu') 83 | preds = fa.get_landmarks('test/assets/aflw-test.jpg') 84 | self.assertEqual(len(preds), len(self.reference_data)) 85 | for pred, reference in zip(preds, self.reference_data): 86 | self.assertTrue(np.allclose(pred, reference)) 87 | 88 | def test_predict_batch_points(self): 89 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu') 90 | 91 | reference_data = self.reference_data + self.reference_data 92 | reference_data.append([]) 93 | image = get_image('test/assets/aflw-test.jpg') 94 | batch = np.stack([image, image, np.zeros_like(image)]) 95 | batch = torch.Tensor(batch.transpose(0, 3, 1, 2)) 96 | 97 | preds = fa.get_landmarks_from_batch(batch) 98 | 99 | self.assertEqual(len(preds), len(reference_data)) 100 | for pred, reference in zip(preds, reference_data): 101 | self.assertTrue(np.allclose(pred, reference)) 102 | 103 | def test_predict_points_from_dir(self): 104 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu') 105 | 106 | reference_data = { 107 | 'test/assets/grass.jpg': None, 108 | 'test/assets/aflw-test.jpg': self.reference_data} 109 | 110 | preds = fa.get_landmarks_from_directory('test/assests/') 111 | 112 | for k, points in preds.items(): 113 | if isinstance(points, list): 114 | for p, p_reference in zip(points, reference_data[k]): 115 | self.assertTrue(np.allclose(p, p_reference)) 116 | else: 117 | self.assertEqual(points, reference_data[k]) 118 | 119 | 120 | if __name__ == '__main__': 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /face-alignment/test/smoke_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import face_alignment 3 | -------------------------------------------------------------------------------- /face-alignment/test/test_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | import unittest 4 | from face_alignment.utils import * 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class Tester(unittest.TestCase): 10 | def test_flip_is_label(self): 11 | # Generate the points 12 | heatmaps = torch.from_numpy(np.random.randint(1, high=250, size=(68, 64, 64)).astype('float32')) 13 | 14 | flipped_heatmaps = flip(flip(heatmaps.clone(), is_label=True), is_label=True) 15 | 16 | assert np.allclose(heatmaps.numpy(), flipped_heatmaps.numpy()) 17 | 18 | def test_flip_is_image(self): 19 | fake_image = torch.torch.rand(3, 256, 256) 20 | fliped_fake_image = flip(flip(fake_image.clone())) 21 | 22 | assert 
np.allclose(fake_image.numpy(), fliped_fake_image.numpy()) 23 | 24 | def test_getpreds(self): 25 | pts = np.random.randint(1, high=63, size=(68, 2)).astype('float32') 26 | 27 | heatmaps = np.zeros((68, 256, 256)) 28 | for i in range(68): 29 | if pts[i, 0] > 0: 30 | heatmaps[i] = draw_gaussian(heatmaps[i], pts[i], 2) 31 | heatmaps = np.expand_dims(heatmaps, axis=0) 32 | 33 | preds, _, _ = get_preds_fromhm(heatmaps) 34 | 35 | assert np.allclose(pts, preds, atol=5) 36 | 37 | def test_create_heatmaps(self): 38 | reference_scale = 195 39 | target_landmarks = torch.randint(0, 255, (1, 68, 2)).type(torch.float) # simulated dataset 40 | bb = create_bounding_box(target_landmarks) 41 | centers = torch.stack([bb[:, 2] - (bb[:, 2] - bb[:, 0]) / 2.0, bb[:, 3] - (bb[:, 3] - bb[:, 1]) / 2.0], dim=1) 42 | centers[:, 1] = centers[:, 1] - (bb[:, 3] - bb[:, 1]) * 0.12 # Not sure where 0.12 comes from 43 | scales = (bb[:, 2] - bb[:, 0] + bb[:, 3] - bb[:, 1]) / reference_scale 44 | heatmaps = create_target_heatmap(target_landmarks, centers, scales) 45 | preds = get_preds_fromhm(heatmaps.numpy(), centers.squeeze().numpy(), scales.squeeze().numpy())[1] 46 | 47 | assert np.allclose(preds, target_landmarks, atol=5) 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /face-alignment/tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E305,E402,E721,F401,F403,F405,F821,F841,F999,W503 -------------------------------------------------------------------------------- /frames_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from skimage import io, img_as_float32 3 | from skimage.color import gray2rgb 4 | from sklearn.model_selection import train_test_split 5 | from imageio import mimread 6 | 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | import pandas as pd 10 | from augmentation import AllAugmentationTransform 11 | import glob 12 | from PIL import Image 13 | import pdb 14 | def read_video(name, frame_shape): 15 | """ 16 | Read video which can be: 17 | - an image of concatenated frames 18 | - '.mp4' and'.gif' 19 | - folder with videos 20 | """ 21 | 22 | if os.path.isdir(name): 23 | frames = sorted(os.listdir(name)) 24 | num_frames = len(frames) 25 | video_array = np.array( 26 | [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)]) 27 | elif name.lower().endswith('.png') or name.lower().endswith('.jpg'): 28 | image = io.imread(name) 29 | 30 | if len(image.shape) == 2 or image.shape[2] == 1: 31 | image = gray2rgb(image) 32 | 33 | if image.shape[2] == 4: 34 | image = image[..., :3] 35 | 36 | image = img_as_float32(image) 37 | 38 | video_array = np.moveaxis(image, 1, 0) 39 | 40 | video_array = video_array.reshape((-1,) + frame_shape) 41 | video_array = np.moveaxis(video_array, 1, 2) 42 | elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'): 43 | video = np.array(mimread(name)) 44 | if len(video.shape) == 3: 45 | video = np.array([gray2rgb(frame) for frame in video]) 46 | if video.shape[-1] == 4: 47 | video = video[..., :3] 48 | video_array = img_as_float32(video) 49 | else: 50 | raise Exception("Unknown file extensions %s" % name) 51 | 52 | return video_array 53 | 54 | 55 | class FramesDataset(Dataset): 56 | """ 57 | Dataset of videos, each video can be represented as: 58 | - an 
image of concatenated frames 59 | - '.mp4' or '.gif' 60 | - folder with all frames 61 | """ 62 | 63 | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, 64 | random_seed=0, pairs_list=None, augmentation_params=None): 65 | self.root_dir = root_dir 66 | self.videos = os.listdir(root_dir) 67 | # self.videos = self.videos[5000] 68 | self.frame_shape = tuple(frame_shape) 69 | self.pairs_list = pairs_list 70 | self.id_sampling = id_sampling 71 | if os.path.exists(os.path.join(root_dir, 'train')): 72 | assert os.path.exists(os.path.join(root_dir, 'test')) 73 | print("Use predefined train-test split.") 74 | if id_sampling: 75 | train_videos = {os.path.basename(video).split('#')[0] for video in 76 | os.listdir(os.path.join(root_dir, 'train'))} 77 | train_videos = list(train_videos) 78 | else: 79 | train_videos = os.listdir(os.path.join(root_dir, 'train')) 80 | test_videos = os.listdir(os.path.join(root_dir, 'test')) 81 | self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test') 82 | else: 83 | print("Use random train-test split.") 84 | train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) 85 | 86 | if is_train: 87 | self.videos = train_videos 88 | else: 89 | self.videos = test_videos 90 | 91 | self.is_train = is_train 92 | 93 | if self.is_train: 94 | self.transform = AllAugmentationTransform(**augmentation_params) 95 | else: 96 | self.transform = None 97 | 98 | def __len__(self): 99 | return len(self.videos) 100 | 101 | def __getitem__(self, idx): 102 | if self.is_train and self.id_sampling: 103 | name = self.videos[idx] 104 | path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) 105 | else: 106 | name = self.videos[idx] 107 | path = os.path.join(self.root_dir, name) 108 | 109 | video_name = os.path.basename(path) 110 | 111 | if self.is_train and os.path.isdir(path): 112 | frames = os.listdir(path) 113 | num_frames = len(frames) 114 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) 115 | # video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx].decode()))) for idx in frame_idx] 116 | video_array = [] 117 | for idx in frame_idx: 118 | try: 119 | video_array.append(img_as_float32(io.imread(os.path.join(path, frames[idx].decode())))) 120 | except Exception as e: 121 | print(e) 122 | else: 123 | video_array = read_video(path, frame_shape=self.frame_shape) 124 | num_frames = len(video_array) 125 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range( 126 | num_frames) 127 | video_array = video_array[frame_idx] 128 | 129 | if self.transform is not None: 130 | video_array = self.transform(video_array) 131 | 132 | out = {} 133 | if self.is_train: 134 | source = np.array(video_array[0], dtype='float32') 135 | driving = np.array(video_array[1], dtype='float32') 136 | 137 | out['driving'] = driving.transpose((2, 0, 1)) 138 | out['source'] = source.transpose((2, 0, 1)) 139 | else: 140 | video = np.array(video_array, dtype='float32') 141 | out['video'] = video.transpose((3, 0, 1, 2)) 142 | 143 | out['name'] = video_name 144 | 145 | return out 146 | 147 | 148 | class DatasetRepeater(Dataset): 149 | """ 150 | Pass several times over the same dataset for better i/o performance 151 | """ 152 | 153 | def __init__(self, dataset, num_repeats=100): 154 | self.dataset = dataset 155 | self.num_repeats = num_repeats 156 | 157 | def __len__(self): 158 | return self.num_repeats * self.dataset.__len__() 159 | 160 
| def __getitem__(self, idx): 161 | return self.dataset[idx % self.dataset.__len__()] 162 | 163 | 164 | class PairedDataset(Dataset): 165 | """ 166 | Dataset of pairs for animation. 167 | """ 168 | 169 | def __init__(self, initial_dataset, number_of_pairs, seed=0): 170 | self.initial_dataset = initial_dataset 171 | pairs_list = self.initial_dataset.pairs_list 172 | np.random.seed(seed) 173 | 174 | if pairs_list is None: 175 | max_idx = min(number_of_pairs, len(initial_dataset)) 176 | nx, ny = max_idx, max_idx 177 | xy = np.mgrid[:nx, :ny].reshape(2, -1).T 178 | number_of_pairs = min(xy.shape[0], number_of_pairs) 179 | self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0) 180 | else: 181 | videos = self.initial_dataset.videos 182 | name_to_index = {name: index for index, name in enumerate(videos)} 183 | pairs = pd.read_csv(pairs_list) 184 | pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))] 185 | number_of_pairs = min(pairs.shape[0], number_of_pairs) 186 | self.pairs = [] 187 | self.start_frames = [] 188 | for ind in range(number_of_pairs): 189 | self.pairs.append( 190 | (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]])) 191 | 192 | def __len__(self): 193 | return len(self.pairs) 194 | 195 | def __getitem__(self, idx): 196 | pair = self.pairs[idx] 197 | first = self.initial_dataset[pair[0]] 198 | second = self.initial_dataset[pair[1]] 199 | first = {'driving_' + key: value for key, value in first.items()} 200 | second = {'source_' + key: value for key, value in second.items()} 201 | 202 | return {**first, **second} 203 | -------------------------------------------------------------------------------- /kill_port.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import signal 4 | import pdb 5 | def kill_process(*pids): 6 | for pid in pids: 7 | a = os.kill(pid, signal.SIGKILL) 8 | print('已杀死pid为%s的进程, 返回值是:%s' % (pid, a)) 9 | 10 | def get_pid(*ports): 11 | #其中\"为转义" 12 | pids = [] 13 | print(ports) 14 | for port in ports: 15 | msg = os.popen('lsof -i:{}'.format(port)).read() 16 | msg = msg.split('\n')[1:-1] 17 | for m in msg: 18 | m = m.replace(' ', ' ') 19 | m = m.replace(' ', ' ') 20 | tokens = m.split(' ') 21 | pids.append(int(tokens[1])) 22 | return pids 23 | 24 | if __name__ == "__main__": 25 | # 杀死占用端口号的ps进程 26 | ports = sys.argv[1:] 27 | kill_process(*get_pid(*ports)) 28 | 29 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import imageio 5 | 6 | import os 7 | from skimage.draw import ellipse 8 | import pdb 9 | import matplotlib.pyplot as plt 10 | import collections 11 | 12 | class Logger: 13 | def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'): 14 | 15 | self.loss_list = [] 16 | self.cpk_dir = log_dir 17 | self.visualizations_dir = os.path.join(log_dir, 'train-vis') 18 | if not os.path.exists(self.visualizations_dir): 19 | os.makedirs(self.visualizations_dir) 20 | self.log_file = open(os.path.join(log_dir, log_file_name), 'a') 21 | self.zfill_num = zfill_num 22 | self.visualizer = Visualizer(**visualizer_params) 23 | self.checkpoint_freq = checkpoint_freq 24 | self.epoch = 0 25 | self.best_loss = float('inf') 26 | self.names = None 27 | 28 | def 
log_scores(self, loss_names): 29 | loss_mean = np.array(self.loss_list).mean(axis=0) 30 | 31 | loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)]) 32 | loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string 33 | 34 | print(loss_string, file=self.log_file) 35 | self.loss_list = [] 36 | self.log_file.flush() 37 | 38 | def visualize_rec(self, inp, out): 39 | image = self.visualizer.visualize(inp['driving'], inp['source'], out) 40 | imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image) 41 | 42 | def save_cpk(self, emergent=False): 43 | cpk = {k: v.state_dict() for k, v in self.models.items()} 44 | cpk['epoch'] = self.epoch 45 | cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num)) 46 | if not (os.path.exists(cpk_path) and emergent): 47 | torch.save(cpk, cpk_path) 48 | 49 | @staticmethod 50 | def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None, 51 | optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None): 52 | num_gpu = torch.cuda.device_count() 53 | if num_gpu == 1: 54 | checkpoint = torch.load(checkpoint_path,map_location='cuda:0') 55 | else: 56 | checkpoint = torch.load(checkpoint_path,map_location='cpu') 57 | if generator is not None: 58 | ckp_generator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['generator'].items()) 59 | generator.load_state_dict(ckp_generator) 60 | if kp_detector is not None: 61 | ckp_kp_detector = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['kp_detector'].items()) 62 | kp_detector.load_state_dict(ckp_kp_detector) 63 | if discriminator is not None: 64 | try: 65 | ckp_discriminator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['discriminator'].items()) 66 | discriminator.load_state_dict(ckp_discriminator) 67 | except: 68 | print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') 69 | if optimizer_generator is not None: 70 | ckp_optimizer_generator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['optimizer_generator'].items()) 71 | optimizer_generator.load_state_dict(ckp_optimizer_generator) 72 | if optimizer_discriminator is not None: 73 | try: 74 | ckp_optimizer_discriminator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['optimizer_discriminator'].items()) 75 | optimizer_discriminator.load_state_dict(ckp_optimizer_discriminator) 76 | except RuntimeError as e: 77 | print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') 78 | if optimizer_kp_detector is not None: 79 | ckp_optimizer_kp_detector = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['optimizer_kp_detector'].items()) 80 | optimizer_kp_detector.load_state_dict(ckp_optimizer_kp_detector) 81 | 82 | return checkpoint['epoch'] 83 | 84 | def __enter__(self): 85 | return self 86 | 87 | def __exit__(self, exc_type, exc_val, exc_tb): 88 | if 'models' in self.__dict__: 89 | self.save_cpk() 90 | self.log_file.close() 91 | 92 | def log_iter(self, losses): 93 | losses = collections.OrderedDict(losses.items()) 94 | if self.names is None: 95 | self.names = list(losses.keys()) 96 | self.loss_list.append(list(losses.values())) 97 | 98 | def log_epoch(self, epoch, models, inp, out): 99 | self.epoch = epoch 100 | self.models = models 101 | if (self.epoch + 1) % self.checkpoint_freq == 0: 102 | self.save_cpk() 103 | self.log_scores(self.names) 104 | # self.visualize_rec(inp, out) 105 | 106 | 107 | class Visualizer: 108 | def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'): 109 | self.kp_size = kp_size 110 | self.draw_border = draw_border 111 | self.colormap = plt.get_cmap(colormap) 112 | 113 | def draw_image_with_kp(self, image, kp_array): 114 | image = np.copy(image) 115 | spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] 116 | kp_array = spatial_size * (kp_array + 1) / 2 117 | num_kp = kp_array.shape[0] 118 | for kp_ind, kp in enumerate(kp_array): 119 | rr, cc = ellipse(kp[1], kp[0], self.kp_size,self.kp_size, shape=image.shape[:2]) 120 | image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3] 121 | return image 122 | 123 | def create_image_column_with_kp(self, images, kp): 124 | image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) 125 | return self.create_image_column(image_array) 126 | 127 | def create_image_column(self, images): 128 | if self.draw_border: 129 | images = np.copy(images) 130 | images[:, :, [0, -1]] = (1, 1, 1) 131 | images[:, :, [0, -1]] = (1, 1, 1) 132 | return np.concatenate(list(images), axis=0) 133 | 134 | def create_image_grid(self, *args): 135 | out = [] 136 | for arg in args: 137 | if type(arg) == tuple: 138 | out.append(self.create_image_column_with_kp(arg[0], arg[1])) 139 | else: 140 | out.append(self.create_image_column(arg)) 141 | return np.concatenate(out, axis=1) 142 | 143 | def visualize(self, driving, source, out): 144 | images = [] 145 | 146 | # Source image with keypoints 147 | source = source.data.cpu() 148 | kp_source = out['kp_source']['value'].data.cpu().numpy() 149 | source = np.transpose(source, [0, 2, 3, 1]) 150 | images.append((source, kp_source)) 151 | 152 | # Equivariance visualization 153 | if 'transformed_frame' in out: 154 | transformed = out['transformed_frame'].data.cpu().numpy() 155 | transformed = np.transpose(transformed, [0, 2, 3, 1]) 156 | transformed_kp = out['transformed_kp']['value'].data.cpu().numpy() 157 | images.append((transformed, transformed_kp)) 158 | 159 | # Driving image with keypoints 160 | kp_driving = out['kp_driving']['value'].data.cpu().numpy() 161 | driving = driving.data.cpu().numpy() 162 | driving = np.transpose(driving, [0, 2, 3, 1]) 163 | images.append((driving, kp_driving)) 164 | 165 | # Deformed image 166 | if 'deformed' in out: 167 | deformed = out['deformed'].data.cpu().numpy() 168 | deformed = np.transpose(deformed, [0, 2, 3, 1]) 169 | images.append(deformed) 170 | 171 | # Result with and without keypoints 172 | prediction = 
out['prediction'].data.cpu().numpy() 173 | prediction = np.transpose(prediction, [0, 2, 3, 1]) 174 | if 'kp_norm' in out: 175 | kp_norm = out['kp_norm']['value'].data.cpu().numpy() 176 | images.append((prediction, kp_norm)) 177 | images.append(prediction) 178 | 179 | 180 | ## Occlusion map 181 | if 'occlusion_map' in out: 182 | occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1) 183 | occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() 184 | occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) 185 | images.append(occlusion_map) 186 | 187 | # Deformed images according to each individual transform 188 | if 'sparse_deformed' in out: 189 | full_mask = [] 190 | for i in range(out['sparse_deformed'].shape[1]): 191 | image = out['sparse_deformed'][:, i].data.cpu() 192 | image = F.interpolate(image, size=source.shape[1:3]) 193 | mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1) 194 | mask = F.interpolate(mask, size=source.shape[1:3]) 195 | image = np.transpose(image.numpy(), (0, 2, 3, 1)) 196 | mask = np.transpose(mask.numpy(), (0, 2, 3, 1)) 197 | 198 | if i != 0: 199 | color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3] 200 | else: 201 | color = np.array((0, 0, 0)) 202 | 203 | color = color.reshape((1, 1, 1, 3)) 204 | 205 | images.append(image) 206 | if i != 0: 207 | images.append(mask * color) 208 | else: 209 | images.append(mask) 210 | 211 | full_mask.append(mask * color) 212 | 213 | images.append(sum(full_mask)) 214 | 215 | image = self.create_image_grid(*images) 216 | image = (255 * image).astype(np.uint8) 217 | return image 218 | -------------------------------------------------------------------------------- /modules/AdaIN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def calc_mean_std(feat, eps=1e-5): 4 | # eps is a small value added to the variance to avoid divide-by-zero. 
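    # Statistics are taken per sample and per channel: the (N, C, H, W) features are flattened
    # to (N, C, H*W), and the resulting mean/std are reshaped to (N, C, 1, 1) so they broadcast
    # against the full-resolution feature map in adaptive_instance_normalization below.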
5 | size = feat.size() 6 | assert (len(size) == 4) 7 | N, C = size[:2] 8 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 9 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 10 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 11 | return feat_mean, feat_std 12 | 13 | def adaptive_instance_normalization(content_feat, style_feat): 14 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 15 | size = content_feat.size() 16 | style_mean, style_std = calc_mean_std(style_feat) 17 | content_mean, content_std = calc_mean_std(content_feat) 18 | normalized_feat = (content_feat - content_mean.expand( 19 | size)) / content_std.expand(size) 20 | 21 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 22 | 23 | def _calc_feat_flatten_mean_std(feat): 24 | # takes 3D feat (C, H, W), return mean and std of array within channels 25 | assert (feat.size()[0] == 3) 26 | assert (isinstance(feat, torch.FloatTensor)) 27 | feat_flatten = feat.view(3, -1) 28 | mean = feat_flatten.mean(dim=-1, keepdim=True) 29 | std = feat_flatten.std(dim=-1, keepdim=True) 30 | return feat_flatten, mean, std 31 | 32 | def _mat_sqrt(x): 33 | U, D, V = torch.svd(x) 34 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 35 | 36 | def coral(source, target): 37 | # assume both source and target are 3D array (C, H, W) 38 | # Note: flatten -> f 39 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 40 | source_f_norm = (source_f - source_f_mean.expand_as( 41 | source_f)) / source_f_std.expand_as(source_f) 42 | source_f_cov_eye = \ 43 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 44 | 45 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 46 | target_f_norm = (target_f - target_f_mean.expand_as( 47 | target_f)) / target_f_std.expand_as(target_f) 48 | target_f_cov_eye = \ 49 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 50 | 51 | source_f_norm_transfer = torch.mm( 52 | _mat_sqrt(target_f_cov_eye), 53 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), 54 | source_f_norm) 55 | ) 56 | 57 | source_f_transfer = source_f_norm_transfer * \ 58 | target_f_std.expand_as(source_f_norm) + \ 59 | target_f_mean.expand_as(source_f_norm) 60 | 61 | return source_f_transfer.view(source.size()) -------------------------------------------------------------------------------- /modules/dense_motion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian 5 | import pdb 6 | from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d 7 | 8 | 9 | class DenseMotionNetwork(nn.Module): 10 | """ 11 | Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving 12 | """ 13 | 14 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False, 15 | scale_factor=1, kp_variance=0.01): 16 | super(DenseMotionNetwork, self).__init__() 17 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1), 18 | max_features=max_features, num_blocks=num_blocks) 19 | 20 | self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3)) 21 | 22 | if estimate_occlusion_map: 23 | self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) 24 | 
else: 25 | self.occlusion = None 26 | 27 | self.num_kp = num_kp 28 | self.scale_factor = scale_factor 29 | self.kp_variance = kp_variance 30 | 31 | if self.scale_factor != 1: 32 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 33 | 34 | def create_heatmap_representations(self, source_image, kp_driving, kp_source): 35 | """ 36 | Eq 6. in the paper H_k(z) 37 | """ 38 | spatial_size = source_image.shape[2:] 39 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance) 40 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance) 41 | heatmap = gaussian_driving - gaussian_source 42 | #adding background feature 43 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()) 44 | heatmap = torch.cat([zeros, heatmap], dim=1) 45 | heatmap = heatmap.unsqueeze(2) 46 | return heatmap 47 | 48 | def create_sparse_motions(self, source_image, kp_driving, kp_source): 49 | """ 50 | Eq 4. in the paper T_{s<-d}(z) 51 | """ 52 | bs, _, h, w = source_image.shape 53 | identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type()) 54 | identity_grid = identity_grid.view(1, 1, h, w, 2) 55 | coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) 56 | if 'jacobian' in kp_driving: 57 | jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) 58 | jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) 59 | jacobian = jacobian.repeat(1, 1, h, w, 1, 1) 60 | coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) 61 | coordinate_grid = coordinate_grid.squeeze(-1) 62 | 63 | driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2) 64 | 65 | #adding background feature 66 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) 67 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs, num_kp+1,w,h,2 68 | return sparse_motions 69 | 70 | def create_deformed_source_image(self, source_image, sparse_motions): 71 | """ 72 | Eq 7. 
in the paper \hat{T}_{s<-d}(z) 73 | """ 74 | bs, _, h, w = source_image.shape 75 | source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1) 76 | source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) 77 | sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) 78 | sparse_deformed = F.grid_sample(source_repeat, sparse_motions) 79 | sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) 80 | return sparse_deformed 81 | 82 | def forward(self, source_image, kp_driving, kp_source): 83 | if self.scale_factor != 1: 84 | source_image = self.down(source_image) 85 | bs, _, h, w = source_image.shape 86 | out_dict = dict() 87 | heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) 88 | sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) 89 | deformed_source = self.create_deformed_source_image(source_image, sparse_motion) 90 | out_dict['sparse_deformed'] = deformed_source 91 | 92 | input = torch.cat([heatmap_representation, deformed_source], dim=2) 93 | input = input.view(bs, -1, h, w) 94 | 95 | prediction = self.hourglass(input) 96 | 97 | mask = self.mask(prediction) 98 | mask = F.softmax(mask, dim=1) 99 | out_dict['mask'] = mask 100 | mask = mask.unsqueeze(2) 101 | sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) 102 | deformation = (sparse_motion * mask).sum(dim=1) 103 | deformation = deformation.permute(0, 2, 3, 1) 104 | 105 | out_dict['deformation'] = deformation 106 | 107 | # Sec. 3.2 in the paper 108 | if self.occlusion: 109 | occlusion_map = torch.sigmoid(self.occlusion(prediction)) 110 | out_dict['occlusion_map'] = occlusion_map 111 | 112 | return out_dict 113 | -------------------------------------------------------------------------------- /modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from modules.util import kp2gaussian 4 | import torch 5 | import pdb 6 | 7 | class DownBlock2d(nn.Module): 8 | """ 9 | Simple block for processing video (encoder). 
10 | """ 11 | 12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): 13 | super(DownBlock2d, self).__init__() 14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 15 | 16 | if sn: 17 | self.conv = nn.utils.spectral_norm(self.conv) 18 | 19 | if norm: 20 | self.norm = nn.InstanceNorm2d(out_features, affine=True) 21 | else: 22 | self.norm = None 23 | self.pool = pool 24 | 25 | def forward(self, x): 26 | out = x 27 | out = self.conv(out) 28 | if self.norm: 29 | out = self.norm(out) 30 | out = F.leaky_relu(out, 0.2) 31 | if self.pool: 32 | out = F.avg_pool2d(out, (2, 2)) 33 | return out 34 | 35 | 36 | class Discriminator(nn.Module): 37 | """ 38 | Discriminator similar to Pix2Pix 39 | """ 40 | 41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, 42 | sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs): 43 | super(Discriminator, self).__init__() 44 | 45 | down_blocks = [] 46 | for i in range(num_blocks): 47 | down_blocks.append( 48 | DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)), 49 | min(max_features, block_expansion * (2 ** (i + 1))), 50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) 51 | self.down_blocks = nn.ModuleList(down_blocks) 52 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 53 | if sn: 54 | self.conv = nn.utils.spectral_norm(self.conv) 55 | self.use_kp = use_kp 56 | self.kp_variance = kp_variance 57 | 58 | def forward(self, x, kp=None): 59 | feature_maps = [] 60 | out = x 61 | if self.use_kp: 62 | heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance) 63 | out = torch.cat([out, heatmap], dim=1) 64 | # print(out.shape) 65 | for down_block in self.down_blocks: 66 | feature_maps.append(down_block(out)) 67 | out = feature_maps[-1] 68 | # print(out.shape) 69 | prediction_map = self.conv(out) 70 | 71 | return feature_maps, prediction_map 72 | 73 | 74 | class MultiScaleDiscriminator(nn.Module): 75 | """ 76 | Multi-scale (scale) discriminator 77 | """ 78 | 79 | def __init__(self, scales=(), **kwargs): 80 | super(MultiScaleDiscriminator, self).__init__() 81 | self.scales = scales 82 | discs = {} 83 | for scale in scales: 84 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) 85 | self.discs = nn.ModuleDict(discs) 86 | 87 | def forward(self, x, kp=None): 88 | out_dict = {} 89 | for scale, disc in self.discs.items(): 90 | scale = str(scale).replace('-', '.') 91 | key = 'prediction_' + scale 92 | feature_maps, prediction_map = disc(x[key], kp) 93 | out_dict['feature_maps_' + scale] = feature_maps 94 | out_dict['prediction_map_' + scale] = prediction_map 95 | return out_dict 96 | -------------------------------------------------------------------------------- /modules/keypoint_detector.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d,Hourglass_2branch 5 | import pdb 6 | 7 | class KPDetector(nn.Module): 8 | """ 9 | Detecting a keypoints. Return keypoint position and jacobian near each keypoint. 
10 | """ 11 | 12 | def __init__(self, block_expansion, num_kp, num_channels, max_features, 13 | num_blocks, temperature, estimate_jacobian=False, scale_factor=1, 14 | single_jacobian_map=False, pad=0): 15 | super(KPDetector, self).__init__() 16 | self.predictor = Hourglass(block_expansion, in_features=num_channels, 17 | max_features=max_features, num_blocks=num_blocks) 18 | 19 | self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), 20 | padding=pad) 21 | 22 | if estimate_jacobian: 23 | self.num_jacobian_maps = 1 if single_jacobian_map else num_kp 24 | self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, 25 | out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) 26 | self.jacobian.weight.data.zero_() 27 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) 28 | else: 29 | self.jacobian = None 30 | 31 | self.temperature = temperature 32 | self.scale_factor = scale_factor 33 | if self.scale_factor != 1: 34 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 35 | 36 | def gaussian2kp(self, heatmap): 37 | """ 38 | Extract the mean and from a heatmap 39 | """ 40 | shape = heatmap.shape 41 | heatmap = heatmap.unsqueeze(-1) 42 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) 43 | value = (heatmap * grid).sum(dim=(2, 3)) 44 | kp = {'value': value} 45 | 46 | return kp 47 | 48 | def forward(self, x): 49 | if self.scale_factor != 1: 50 | x = self.down(x) 51 | feature_map = self.predictor(x) #x bz,4,64,64 52 | prediction = self.kp(feature_map) 53 | 54 | final_shape = prediction.shape 55 | heatmap = prediction.view(final_shape[0], final_shape[1], -1) 56 | heatmap = F.softmax(heatmap / self.temperature, dim=2) 57 | heatmap = heatmap.view(*final_shape) 58 | 59 | out = self.gaussian2kp(heatmap) 60 | 61 | if self.jacobian is not None: 62 | jacobian_map = self.jacobian(feature_map) 63 | # pdb.set_trace() 64 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], 65 | final_shape[3]) 66 | heatmap = heatmap.unsqueeze(2) 67 | 68 | jacobian = heatmap * jacobian_map 69 | jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) 70 | jacobian = jacobian.sum(dim=-1) 71 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) 72 | out['jacobian'] = jacobian 73 | 74 | return out 75 | 76 | -------------------------------------------------------------------------------- /reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from logger import Logger, Visualizer 6 | import numpy as np 7 | import imageio 8 | from sync_batchnorm import DataParallelWithCallback 9 | import depth 10 | 11 | 12 | def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset): 13 | png_dir = os.path.join(log_dir, 'reconstruction/png') 14 | log_dir = os.path.join(log_dir, 'reconstruction') 15 | 16 | if checkpoint is not None: 17 | Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector) 18 | else: 19 | raise AttributeError("Checkpoint should be specified for mode='reconstruction'.") 20 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 21 | 22 | if not os.path.exists(log_dir): 23 | os.makedirs(log_dir) 24 | 25 | if not os.path.exists(png_dir): 26 | os.makedirs(png_dir) 27 | 28 | loss_list = [] 
29 | if torch.cuda.is_available(): 30 | generator = DataParallelWithCallback(generator) 31 | kp_detector = DataParallelWithCallback(kp_detector) 32 | 33 | depth_encoder = depth.ResnetEncoder(18, False).cuda() 34 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)).cuda() 35 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth') 36 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth') 37 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()} 38 | depth_encoder.load_state_dict(filtered_dict_enc) 39 | depth_decoder.load_state_dict(loaded_dict_dec) 40 | depth_decoder.eval() 41 | depth_encoder.eval() 42 | 43 | generator.eval() 44 | kp_detector.eval() 45 | 46 | for it, x in tqdm(enumerate(dataloader)): 47 | if config['reconstruction_params']['num_videos'] is not None: 48 | if it > config['reconstruction_params']['num_videos']: 49 | break 50 | with torch.no_grad(): 51 | predictions = [] 52 | visualizations = [] 53 | if torch.cuda.is_available(): 54 | x['video'] = x['video'].cuda() 55 | 56 | outputs = depth_decoder(depth_encoder(x['video'][:, :, 0])) 57 | depth_source = outputs[("disp", 0)] 58 | source_rgbd = torch.cat((x['video'][:, :, 0],depth_source),1) 59 | 60 | kp_source = kp_detector(source_rgbd) 61 | 62 | for frame_idx in range(x['video'].shape[2]): 63 | source = x['video'][:, :, 0] 64 | 65 | driving = x['video'][:, :, frame_idx] 66 | outputs = depth_decoder(depth_encoder(driving)) 67 | depth_driving = outputs[("disp", 0)] 68 | driving_rgbd = torch.cat((driving,depth_driving),1) 69 | 70 | kp_driving = kp_detector(driving_rgbd) 71 | out = generator(source, kp_source=kp_source, kp_driving=kp_driving) 72 | out['kp_source'] = kp_source 73 | out['kp_driving'] = kp_driving 74 | del out['sparse_deformed'] 75 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 76 | 77 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source, 78 | driving=driving, out=out) 79 | visualizations.append(visualization) 80 | 81 | loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy()) 82 | 83 | predictions = np.concatenate(predictions, axis=1) 84 | imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8)) 85 | 86 | image_name = x['name'][0] + config['reconstruction_params']['format'] 87 | imageio.mimsave(os.path.join(log_dir, image_name), visualizations) 88 | 89 | print("Reconstruction loss: %s" % np.mean(loss_list)) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | certifi==2021.10.8 3 | cycler==0.11.0 4 | fonttools==4.33.2 5 | grpcio==1.44.0 6 | imageio==2.17.0 7 | importlib-metadata==4.11.3 8 | joblib==1.1.0 9 | kiwisolver==1.4.2 10 | Markdown==3.3.6 11 | matplotlib==3.5.1 12 | networkx==2.6.3 13 | numpy==1.21.6 14 | packaging==21.3 15 | pandas==1.3.5 16 | Pillow==9.1.0 17 | protobuf==3.20.1 18 | pyparsing==3.0.8 19 | python-dateutil==2.8.2 20 | pytz==2022.1 21 | PyWavelets==1.3.0 22 | PyYAML==5.4.1 23 | scikit-image==0.16.2 24 | scikit-learn==1.0.2 25 | scipy==1.7.3 26 | six==1.16.0 27 | sklearn==0.0 28 | tensorboard==1.15.0 29 | threadpoolctl==3.1.0 30 | tifffile==2021.11.2 31 | torch 32 | torchaudio==0.10.1+rocm4.1 33 | torchvision 34 | tqdm==4.64.0 35 | typing_extensions==4.2.0 36 | Werkzeug==2.1.1 37 | zipp==3.8.0 38 | 
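The snippet below is a minimal, illustrative sketch (it is not a file from this repo) of the RGB-D keypoint step performed in reconstruction.py above: the face depth network predicts a disparity map, which is concatenated with the RGB frame into a 4-channel input for the keypoint detector. It assumes the pretrained depth weights are available under depth/models/weights_19/ as in reconstruction.py; the hyperparameters passed to KPDetector are placeholders for illustration, not necessarily the values in config/vox-adv-256.yaml.

import torch
import depth
from modules.keypoint_detector import KPDetector

# Hypothetical hyperparameters; in practice they come from the 'kp_detector_params'
# and 'common_params' sections of the training config.
kp_detector = KPDetector(block_expansion=32, num_kp=15, num_channels=4, max_features=1024,
                         num_blocks=5, temperature=0.1, scale_factor=0.25).eval()

# Face depth network, loaded the same way as in reconstruction.py.
depth_encoder = depth.ResnetEncoder(18, False).eval()
depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)).eval()
loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth', map_location='cpu')
depth_encoder.load_state_dict({k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()})
depth_decoder.load_state_dict(torch.load('depth/models/weights_19/depth.pth', map_location='cpu'))

with torch.no_grad():
    rgb = torch.rand(1, 3, 256, 256)                        # stand-in for one video frame in [0, 1]
    disp = depth_decoder(depth_encoder(rgb))[("disp", 0)]   # (1, 1, 256, 256) disparity map
    rgbd = torch.cat((rgb, disp), dim=1)                    # (1, 4, 256, 256) RGB-D input
    kp = kp_detector(rgbd)                                  # {'value': (1, num_kp, 2) coordinates in [-1, 1]}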
-------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | 5 | import os, sys 6 | import yaml 7 | from argparse import ArgumentParser 8 | from time import gmtime, strftime 9 | from shutil import copy 10 | 11 | from frames_dataset import FramesDataset 12 | import pdb 13 | # from modules.generator import OcclusionAwareGenerator 14 | import modules.generator as generator 15 | from modules.discriminator import MultiScaleDiscriminator 16 | # from modules.keypoint_detector import KPDetector 17 | import modules.keypoint_detector as KPD 18 | import torch.distributed as dist 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | 21 | import torch 22 | from torch.utils.tensorboard import SummaryWriter 23 | from train import train 24 | # from reconstruction import reconstruction 25 | from animate import animate 26 | import random 27 | import numpy as np 28 | 29 | 30 | if __name__ == "__main__": 31 | 32 | if sys.version_info[0] < 3: 33 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") 34 | 35 | parser = ArgumentParser() 36 | parser.add_argument("--config", required=True, help="path to config") 37 | parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"]) 38 | parser.add_argument("--log_dir", default='log', help="path to log into") 39 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") 40 | parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))), 41 | help="Names of the devices comma separated.") 42 | parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture") 43 | parser.add_argument("--local_rank", type=int) 44 | parser.add_argument("--use_depth",action='store_true',help='depth mode') 45 | parser.add_argument("--rgbd",action='store_true',help='rgbd mode') 46 | parser.add_argument("--kp_prior",action='store_true',help='use kp_prior in final objective function') 47 | 48 | # alter model 49 | parser.add_argument("--generator",required=True,help='the type of genertor') 50 | parser.add_argument("--kp_detector",default='KPDetector',type=str,help='the type of KPDetector') 51 | parser.add_argument("--GFM",default='GeneratorFullModel',help='the type of GeneratorFullModel') 52 | 53 | parser.add_argument("--batchsize",type=int, default=-1,help='user defined batchsize') 54 | parser.add_argument("--kp_num",type=int, default=-1,help='user defined keypoint number') 55 | parser.add_argument("--kp_distance",type=int, default=10,help='the weight of kp_distance loss') 56 | parser.add_argument("--depth_constraint",type=int, default=0,help='the weight of depth_constraint loss') 57 | 58 | parser.add_argument("--name",type=str,help='user defined model saved name') 59 | 60 | parser.set_defaults(verbose=False) 61 | opt = parser.parse_args() 62 | with open(opt.config) as f: 63 | config = yaml.load(f) 64 | 65 | if opt.checkpoint is not None: 66 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) 67 | else: 68 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) 69 | log_dir += opt.name 70 | 71 | 72 | print("Training...") 73 | 74 | dist.init_process_group(backend='nccl', init_method='env://') 75 | torch.cuda.set_device(opt.local_rank) 76 | device=torch.device("cuda",opt.local_rank) 77 | 
config['train_params']['loss_weights']['depth_constraint'] = opt.depth_constraint 78 | config['train_params']['loss_weights']['kp_distance'] = opt.kp_distance 79 | if opt.kp_prior: 80 | config['train_params']['loss_weights']['kp_distance'] = 0 81 | config['train_params']['loss_weights']['kp_prior'] = 10 82 | if opt.batchsize != -1: 83 | config['train_params']['batch_size'] = opt.batchsize 84 | if opt.kp_num != -1: 85 | config['model_params']['common_params']['num_kp'] = opt.kp_num 86 | # create generator 87 | generator = getattr(generator, opt.generator)(**config['model_params']['generator_params'], 88 | **config['model_params']['common_params']) 89 | generator.to(device) 90 | if opt.verbose: 91 | print(generator) 92 | generator= torch.nn.SyncBatchNorm.convert_sync_batchnorm(generator) 93 | 94 | # create discriminator 95 | discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'], 96 | **config['model_params']['common_params']) 97 | 98 | discriminator.to(device) 99 | if opt.verbose: 100 | print(discriminator) 101 | discriminator= torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator) 102 | 103 | # create kp_detector 104 | if opt.use_depth: 105 | config['model_params']['common_params']['num_channels'] = 1 106 | if opt.rgbd: 107 | config['model_params']['common_params']['num_channels'] = 4 108 | 109 | kp_detector = getattr(KPD, opt.kp_detector)(**config['model_params']['kp_detector_params'], 110 | **config['model_params']['common_params']) 111 | kp_detector.to(device) 112 | if opt.verbose: 113 | print(kp_detector) 114 | kp_detector= torch.nn.SyncBatchNorm.convert_sync_batchnorm(kp_detector) 115 | 116 | kp_detector = DDP(kp_detector,device_ids=[opt.local_rank],broadcast_buffers=False) 117 | discriminator = DDP(discriminator,device_ids=[opt.local_rank],broadcast_buffers=False) 118 | generator = DDP(generator,device_ids=[opt.local_rank],broadcast_buffers=False) 119 | 120 | dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params']) 121 | if not os.path.exists(log_dir): 122 | os.makedirs(log_dir) 123 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): 124 | copy(opt.config, log_dir) 125 | 126 | if not os.path.exists(os.path.join(log_dir,'log')): 127 | os.makedirs(os.path.join(log_dir,'log')) 128 | writer = SummaryWriter(os.path.join(log_dir,'log')) 129 | if opt.mode == 'train': 130 | train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.local_rank,device,opt,writer) -------------------------------------------------------------------------------- /run_dataparallel.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys 3 | import yaml 4 | from argparse import ArgumentParser 5 | from shutil import copy 6 | 7 | from frames_dataset import FramesDataset 8 | import pdb 9 | import modules.generator as generator 10 | from modules.discriminator import MultiScaleDiscriminator 11 | import modules.keypoint_detector as KPD 12 | 13 | import torch 14 | from torch.utils.tensorboard import SummaryWriter 15 | from train_dataparallel import train 16 | # from reconstruction import reconstruction 17 | from animate import animate 18 | import random 19 | import numpy as np 20 | 21 | if __name__ == "__main__": 22 | 23 | if sys.version_info[0] < 3: 24 | raise Exception("You must use Python 3 or higher. 
Recommended version is Python 3.7") 25 | 26 | parser = ArgumentParser() 27 | parser.add_argument("--config", required=True, help="path to config") 28 | parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"]) 29 | parser.add_argument("--log_dir", default='log', help="path to log into") 30 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") 31 | parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))), 32 | help="Names of the devices comma separated.") 33 | parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture") 34 | parser.add_argument("--use_depth",action='store_true',help='depth mode') 35 | parser.add_argument("--rgbd",action='store_true',help='rgbd mode') 36 | parser.add_argument("--kp_prior",action='store_true',help='use kp_prior in final objective function') 37 | 38 | # alter model 39 | parser.add_argument("--generator",required=True,help='the type of genertor') 40 | parser.add_argument("--kp_detector",default='KPDetector',type=str,help='the type of KPDetector') 41 | parser.add_argument("--GFM",default='GeneratorFullModel',help='the type of GeneratorFullModel') 42 | 43 | parser.add_argument("--batchsize",type=int, default=-1,help='user defined batchsize') 44 | parser.add_argument("--kp_num",type=int, default=-1,help='user defined keypoint number') 45 | parser.add_argument("--kp_distance",type=int, default=10,help='the weight of kp_distance loss') 46 | parser.add_argument("--depth_constraint",type=int, default=0,help='the weight of depth_constraint loss') 47 | 48 | parser.add_argument("--name",type=str,help='user defined model saved name') 49 | 50 | parser.set_defaults(verbose=False) 51 | opt = parser.parse_args() 52 | with open(opt.config) as f: 53 | config = yaml.load(f) 54 | 55 | if opt.checkpoint is not None: 56 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) 57 | else: 58 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) 59 | log_dir += opt.name 60 | 61 | 62 | print("Training...") 63 | 64 | config['train_params']['loss_weights']['depth_constraint'] = opt.depth_constraint 65 | config['train_params']['loss_weights']['kp_distance'] = opt.kp_distance 66 | if opt.kp_prior: 67 | config['train_params']['loss_weights']['kp_distance'] = 0 68 | config['train_params']['loss_weights']['kp_prior'] = 10 69 | if opt.batchsize != -1: 70 | config['train_params']['batch_size'] = opt.batchsize 71 | if opt.kp_num != -1: 72 | config['model_params']['common_params']['num_kp'] = opt.kp_num 73 | # create generator 74 | generator = getattr(generator, opt.generator)(**config['model_params']['generator_params'], 75 | **config['model_params']['common_params']) 76 | if torch.cuda.is_available(): 77 | generator.to(opt.device_ids[0]) 78 | if opt.verbose: 79 | print(generator) 80 | 81 | # create discriminator 82 | discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'], 83 | **config['model_params']['common_params']) 84 | 85 | if torch.cuda.is_available(): 86 | discriminator.to(opt.device_ids[0]) 87 | if opt.verbose: 88 | print(discriminator) 89 | 90 | # create kp_detector 91 | if opt.use_depth: 92 | config['model_params']['common_params']['num_channels'] = 1 93 | if opt.rgbd: 94 | config['model_params']['common_params']['num_channels'] = 4 95 | 96 | kp_detector = getattr(KPD, opt.kp_detector)(**config['model_params']['kp_detector_params'], 97 | 
**config['model_params']['common_params']) 98 | if torch.cuda.is_available(): 99 | kp_detector.to(opt.device_ids[0]) 100 | if opt.verbose: 101 | print(kp_detector) 102 | 103 | dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params']) 104 | if not os.path.exists(log_dir): 105 | os.makedirs(log_dir) 106 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): 107 | copy(opt.config, log_dir) 108 | 109 | if not os.path.exists(os.path.join(log_dir,'log')): 110 | os.makedirs(os.path.join(log_dir,'log')) 111 | writer = SummaryWriter(os.path.join(log_dir,'log')) 112 | if opt.mode == 'train': 113 | train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids, opt,writer) -------------------------------------------------------------------------------- /sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger a callback on each module, all slave devices should 60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages are first collected from each device (including the master device), and then 106 | a callback is invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from logger import Logger 7 | from modules.model import GeneratorFullModel, DiscriminatorFullModel 8 | import modules.model as MODEL 9 | from tqdm import tqdm 10 | from torch.optim.lr_scheduler import MultiStepLR 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | import pdb 13 | from sync_batchnorm import DataParallelWithCallback 14 | from evaluation_dataset import EvaluationDataset 15 | 16 | from frames_dataset import DatasetRepeater 17 | 18 | 19 | def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, rank,device,opt,writer): 20 | train_params = config['train_params'] 21 | 22 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999)) 23 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999)) 24 | optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999)) 25 | 26 | if checkpoint is not None: 27 | start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, 28 | optimizer_generator, optimizer_discriminator, 29 | None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector) 30 | else: 31 | start_epoch = 0 32 | 33 | scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1, 34 | last_epoch=start_epoch - 1) 35 | scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1, 36 | last_epoch=start_epoch - 1) 37 | scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1, 38 | last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0)) 39 | 40 | if 'num_repeats' in train_params or train_params['num_repeats'] != 1: 41 | dataset = DatasetRepeater(dataset, train_params['num_repeats']) 42 | sampler = torch.utils.data.distributed.DistributedSampler(dataset,num_replicas=torch.cuda.device_count(),rank=rank) 43 | dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=False, num_workers=16,sampler=sampler, drop_last=True) 44 | 45 | 46 | generator_full = getattr(MODEL,opt.GFM)(kp_detector, generator, discriminator, train_params,opt) 47 | discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params) 48 | test_dataset = EvaluationDataset(dataroot='/data/fhongac/origDataset/vox1_frames',pairs_list='data/vox_evaluation.csv') 49 | test_dataloader = torch.utils.data.DataLoader( 50 | test_dataset, 51 | batch_size = 1, 52 | shuffle=False, 53 | num_workers=4) 54 | with Logger(log_dir=log_dir, 
visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger: 55 | for epoch in trange(start_epoch, train_params['num_epochs']): 56 | #parallel 57 | sampler.set_epoch(epoch) 58 | total = len(dataloader) 59 | epoch_train_loss = 0 60 | generator.train(), discriminator.train(), kp_detector.train() 61 | with tqdm(total=total) as par: 62 | for i,x in enumerate(dataloader): 63 | x['source'] = x['source'].to(device) 64 | x['driving'] = x['driving'].to(device) 65 | losses_generator, generated = generator_full(x) 66 | 67 | loss_values = [val.mean() for val in losses_generator.values()] 68 | loss = sum(loss_values) 69 | loss.backward() 70 | optimizer_generator.step() 71 | optimizer_generator.zero_grad() 72 | optimizer_kp_detector.step() 73 | optimizer_kp_detector.zero_grad() 74 | epoch_train_loss+=loss.item() 75 | 76 | if train_params['loss_weights']['generator_gan'] != 0: 77 | optimizer_discriminator.zero_grad() 78 | losses_discriminator = discriminator_full(x, generated) 79 | loss_values = [val.mean() for val in losses_discriminator.values()] 80 | loss = sum(loss_values) 81 | 82 | loss.backward() 83 | optimizer_discriminator.step() 84 | optimizer_discriminator.zero_grad() 85 | else: 86 | losses_discriminator = {} 87 | 88 | losses_generator.update(losses_discriminator) 89 | losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} 90 | # for k,v in losses.items(): 91 | # writer.add_scalar(k, v, total*epoch+i) 92 | logger.log_iter(losses=losses) 93 | par.update(1) 94 | epoch_train_loss = epoch_train_loss/total 95 | if (epoch + 1) % train_params['checkpoint_freq'] == 0: 96 | writer.add_scalar('epoch_train_loss', epoch_train_loss, epoch) 97 | scheduler_generator.step() 98 | scheduler_discriminator.step() 99 | scheduler_kp_detector.step() 100 | logger.log_epoch(epoch, {'generator': generator, 101 | 'discriminator': discriminator, 102 | 'kp_detector': kp_detector, 103 | 'optimizer_generator': optimizer_generator, 104 | 'optimizer_discriminator': optimizer_discriminator, 105 | 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated) 106 | generator.eval(), discriminator.eval(), kp_detector.eval() 107 | if (epoch + 1) % train_params['checkpoint_freq'] == 0: 108 | epoch_eval_loss = 0 109 | for i, data in tqdm(enumerate(test_dataloader)): 110 | data['source'] = data['source'].cuda() 111 | data['driving'] = data['driving'].cuda() 112 | losses_generator, generated = generator_full(data) 113 | loss_values = [val.mean() for val in losses_generator.values()] 114 | loss = sum(loss_values) 115 | epoch_eval_loss+=loss.item() 116 | epoch_eval_loss = epoch_eval_loss/len(test_dataloader) 117 | writer.add_scalar('epoch_eval_loss', epoch_eval_loss, epoch) 118 | -------------------------------------------------------------------------------- /train_dataparallel.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from logger import Logger 7 | from modules.model_dataparallel import DiscriminatorFullModel 8 | import modules.model_dataparallel as MODEL 9 | from tqdm import tqdm 10 | from torch.optim.lr_scheduler import MultiStepLR 11 | import pdb 12 | from sync_batchnorm import DataParallelWithCallback 13 | from evaluation_dataset import EvaluationDataset 14 | 15 | from frames_dataset import DatasetRepeater 16 | 17 | 18 | def train(config, generator, discriminator, kp_detector, checkpoint, 
log_dir, dataset, device_ids,opt,writer): 19 | train_params = config['train_params'] 20 | 21 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999)) 22 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999)) 23 | optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999)) 24 | 25 | if checkpoint is not None: 26 | start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, 27 | optimizer_generator, optimizer_discriminator, 28 | None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector) 29 | else: 30 | start_epoch = 0 31 | 32 | scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1, 33 | last_epoch=start_epoch - 1) 34 | scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1, 35 | last_epoch=start_epoch - 1) 36 | scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1, 37 | last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0)) 38 | 39 | if 'num_repeats' in train_params or train_params['num_repeats'] != 1: 40 | dataset = DatasetRepeater(dataset, train_params['num_repeats']) 41 | dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=16,drop_last=True) 42 | 43 | 44 | generator_full = getattr(MODEL,opt.GFM)(kp_detector, generator, discriminator, train_params,opt) 45 | discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params) 46 | test_dataset = EvaluationDataset(dataroot='/data/fhongac/origDataset/vox1_frames',pairs_list='data/vox_evaluation.csv') 47 | test_dataloader = torch.utils.data.DataLoader( 48 | test_dataset, 49 | batch_size = 1, 50 | shuffle=False, 51 | num_workers=4) 52 | if torch.cuda.is_available(): 53 | generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids) 54 | discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids) 55 | 56 | with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger: 57 | for epoch in trange(start_epoch, train_params['num_epochs']): 58 | #parallel 59 | total = len(dataloader) 60 | epoch_train_loss = 0 61 | generator.train(), discriminator.train(), kp_detector.train() 62 | with tqdm(total=total) as par: 63 | for i,x in enumerate(dataloader): 64 | # x['source'] = x['source'].to(device) 65 | # x['driving'] = x['driving'].to(device) 66 | losses_generator, generated = generator_full(x) 67 | 68 | loss_values = [val.mean() for val in losses_generator.values()] 69 | loss = sum(loss_values) 70 | loss.backward() 71 | optimizer_generator.step() 72 | optimizer_generator.zero_grad() 73 | optimizer_kp_detector.step() 74 | optimizer_kp_detector.zero_grad() 75 | epoch_train_loss+=loss.item() 76 | 77 | if train_params['loss_weights']['generator_gan'] != 0: 78 | optimizer_discriminator.zero_grad() 79 | losses_discriminator = discriminator_full(x, generated) 80 | loss_values = [val.mean() for val in losses_discriminator.values()] 81 | loss = sum(loss_values) 82 | 83 | loss.backward() 84 | optimizer_discriminator.step() 85 | optimizer_discriminator.zero_grad() 86 | else: 87 | losses_discriminator = {} 88 | 89 | losses_generator.update(losses_discriminator) 90 | losses = {key: 
value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} 91 | # for k,v in losses.items(): 92 | # writer.add_scalar(k, v, total*epoch+i) 93 | logger.log_iter(losses=losses) 94 | par.update(1) 95 | epoch_train_loss = epoch_train_loss/total 96 | if (epoch + 1) % train_params['checkpoint_freq'] == 0: 97 | writer.add_scalar('epoch_train_loss', epoch_train_loss, epoch) 98 | scheduler_generator.step() 99 | scheduler_discriminator.step() 100 | scheduler_kp_detector.step() 101 | logger.log_epoch(epoch, {'generator': generator, 102 | 'discriminator': discriminator, 103 | 'kp_detector': kp_detector, 104 | 'optimizer_generator': optimizer_generator, 105 | 'optimizer_discriminator': optimizer_discriminator, 106 | 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated) 107 | generator.eval(), discriminator.eval(), kp_detector.eval() 108 | if (epoch + 1) % train_params['checkpoint_freq'] == 0: 109 | epoch_eval_loss = 0 110 | for i, data in tqdm(enumerate(test_dataloader)): 111 | data['source'] = data['source'].cuda() 112 | data['driving'] = data['driving'].cuda() 113 | losses_generator, generated = generator_full(data) 114 | loss_values = [val.mean() for val in losses_generator.values()] 115 | loss = sum(loss_values) 116 | epoch_eval_loss+=loss.item() 117 | epoch_eval_loss = epoch_eval_loss/len(test_dataloader) 118 | writer.add_scalar('epoch_eval_loss', epoch_eval_loss, epoch) 119 | --------------------------------------------------------------------------------
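As a closing note on the vendored sync_batchnorm package used by the data-parallel training scripts: the master/slave round trip documented in sync_batchnorm/comm.py can be exercised on its own. The toy script below is not part of the repo; it registers a single slave pipe and uses a callback that sums the exchanged messages. In the real code this same round trip is driven by _SynchronizedBatchNorm, where the messages carry per-device batch statistics.

import threading
from sync_batchnorm.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is a list of (copy_id, message) pairs; index 0 is always the master.
    total = sum(msg for _, msg in intermediates)
    # Return one (copy_id, result) pair per device, with the master's pair first.
    return [(copy_id, total) for copy_id, _ in intermediates]

master = SyncMaster(master_callback)
slave_pipe = master.register_slave(1)       # one slave copy, identified as device 1

out = {}
def slave_worker():
    out['slave'] = slave_pipe.run_slave(2)  # blocks until the master sends the result back

worker = threading.Thread(target=slave_worker)
worker.start()
out['master'] = master.run_master(1)        # collects the slave message and runs the callback
worker.join()
print(out)                                  # both entries are 3 (= 1 + 2)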