├── .gitignore
├── LICENSE.md
├── README.md
├── animate.py
├── assets
│   ├── driving.mp4
│   ├── src.png
│   └── visual_vox1.png
├── augmentation.py
├── config
│   └── vox-adv-256.yaml
├── crop-video.py
├── data
│   ├── celeV_cross_id_evaluation.csv
│   ├── utils.py
│   ├── vox256.csv
│   ├── vox_cross_id_animate.csv
│   ├── vox_cross_id_evaluation.csv
│   ├── vox_cross_id_evaluation_best_frame.csv
│   └── vox_evaluation.csv
├── demo.py
├── demo_multi.py
├── depth
│   ├── __init__.py
│   ├── depth_decoder.py
│   ├── layers.py
│   ├── models
│   │   └── opt.json
│   ├── pose_cnn.py
│   ├── pose_decoder.py
│   └── resnet_encoder.py
├── evaluation_dataset.py
├── face-alignment
│   ├── .gitattributes
│   ├── .gitignore
│   ├── Dockerfile
│   ├── LICENSE
│   ├── README.md
│   ├── conda
│   │   └── meta.yaml
│   ├── docs
│   │   └── images
│   │       ├── 2dlandmarks.png
│   │       └── face-alignment-adrian.gif
│   ├── examples
│   │   ├── demo.ipynb
│   │   └── detect_landmarks_in_image.py
│   ├── face_alignment
│   │   ├── __init__.py
│   │   ├── api.py
│   │   ├── detection
│   │   │   ├── __init__.py
│   │   │   ├── blazeface
│   │   │   │   ├── __init__.py
│   │   │   │   ├── blazeface_detector.py
│   │   │   │   ├── detect.py
│   │   │   │   ├── net_blazeface.py
│   │   │   │   └── utils.py
│   │   │   ├── core.py
│   │   │   ├── dlib
│   │   │   │   ├── __init__.py
│   │   │   │   └── dlib_detector.py
│   │   │   ├── folder
│   │   │   │   ├── __init__.py
│   │   │   │   └── folder_detector.py
│   │   │   └── sfd
│   │   │       ├── __init__.py
│   │   │       ├── bbox.py
│   │   │       ├── detect.py
│   │   │       ├── net_s3fd.py
│   │   │       └── sfd_detector.py
│   │   └── utils.py
│   ├── requirements.txt
│   ├── setup.cfg
│   ├── setup.py
│   ├── test
│   │   ├── facealignment_test.py
│   │   ├── smoke_test.py
│   │   └── test_utils.py
│   └── tox.ini
├── frames_dataset.py
├── kill_port.py
├── logger.py
├── modules
│   ├── AdaIN.py
│   ├── dense_motion.py
│   ├── discriminator.py
│   ├── dynamic_conv.py
│   ├── generator.py
│   ├── keypoint_detector.py
│   ├── model.py
│   ├── model_dataparallel.py
│   └── util.py
├── reconstruction.py
├── requirements.txt
├── run.py
├── run_dataparallel.py
├── sync_batchnorm
│   ├── __init__.py
│   ├── batchnorm.py
│   ├── comm.py
│   ├── replicate.py
│   └── unittest.py
├── train.py
├── train_dataparallel.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | log
3 | *.pth*
4 | readmesam*
5 | *.jpg
6 | *.mp4
7 | run.sh
8 | source.png
9 | tools
10 | *.pt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## :book: Depth-Aware Generative Adversarial Network for Talking Head Video Generation (CVPR 2022)
3 |
4 | :fire: If DaGAN is helpful for your photos/projects, please help to :star: it or recommend it to your friends. Thanks! :fire:
5 |
6 |
7 | :fire: Seeking collaboration and internship opportunities. :fire:
8 |
9 |
10 | > [[Paper](https://arxiv.org/abs/2203.06605)] [[Project Page](https://harlanhong.github.io/publications/dagan.html)] [[Demo](https://huggingface.co/spaces/HarlanHong/DaGAN)] [[Poster Video](https://www.youtube.com/watch?v=nahsJNjWzGo&t=1s)]
11 |
12 |
13 | > [Fa-Ting Hong](https://harlanhong.github.io), [Longhao Zhang](), [Li Shen](), [Dan Xu](https://www.danxurgb.net)
14 | > The Hong Kong University of Science and Technology
15 | > Alibaba Cloud
16 |
17 | ### Cartoon Sample
18 | https://user-images.githubusercontent.com/19970321/162151632-0195292f-30b8-4122-8afd-9b1698f1e4fe.mp4
19 |
20 | ### Human Sample
21 | https://user-images.githubusercontent.com/19970321/162151327-f2930231-42e3-40f2-bfca-a88529599f0f.mp4
22 |
23 | ### Voxceleb1 Dataset
24 |
25 |
26 |
27 |
28 | :triangular_flag_on_post: **Updates**
29 | - :fire::fire::white_check_mark: July 20, 2023: Our new talking head work **[MCNet](https://harlanhong.github.io/publications/mcnet.html)** was accepted by ICCV 2023. There is no need to train a facial depth network, which makes it more convenient for users to test and fine-tune.
30 | - :fire::fire::white_check_mark: July 26, 2022: The plain DataParallel training scripts were released, since some researchers reported running into **DistributedDataParallel** problems. Please try training your own model with this [command](#dataparallel). We also removed the line "with torch.autograd.set_detect_anomaly(True)" to speed up training.
31 | - :fire::fire::white_check_mark: June 26, 2022: The repo of our face depth network is released, please refer to [Face-Depth-Network](https://github.com/harlanhong/Face-Depth-Network) and feel free to email me if you meet any problem.
32 | - :fire::fire::white_check_mark: June 21, 2022: [Digression] I am looking for research intern/research assistant opportunities in Europe next year. Please contact me if you think I am qualified for your position.
33 | - :fire::fire::white_check_mark: May 19, 2022: The face depth model (50 layers) trained on VoxCeleb2 is released! (The corresponding DaGAN checkpoint will be released soon.) Click the [LINK](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EkxzfH7zbGJNr-WVmPU6fcABWAMq_WJoExAl4SttKK6hBQ?e=fbtGlX)
34 |
35 | - :fire::fire::white_check_mark: April 25, 2022: Integrated into Hugging Face Spaces 🤗 using Gradio. Try out the web demo: [Hugging Face Spaces](https://huggingface.co/spaces/HarlanHong/DaGAN) (GPU version will come soon!)
36 | - :fire::fire::white_check_mark: Add **[SPADE model](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EjfeXuzwo3JMn7s0oOPN_q0B81P5Wgu_kbYJAh7uSAKS2w?e=XNZl3K)**, which produces **more natural** results.
37 |
38 |
39 | ## :wrench: Dependencies and Installation
40 |
41 | - Python >= 3.7 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
42 | - [PyTorch >= 1.7](https://pytorch.org/)
43 | - Optional: NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
44 | - Optional: Linux
45 |
46 | ### Installation
47 | We now provide a *clean* version of DaGAN, which does not require customized CUDA extensions.
48 |
49 | 1. Clone repo
50 |
51 | ```bash
52 | git clone https://github.com/harlanhong/CVPR2022-DaGAN.git
53 | cd CVPR2022-DaGAN
54 | ```
55 |
56 | 2. Install dependent packages
57 |
58 | ```bash
59 | pip install -r requirements.txt
60 |
61 | ## Install the Face Alignment lib
62 | cd face-alignment
63 | pip install -r requirements.txt
64 | python setup.py install
65 | ```
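
After installation, you can optionally verify the environment with a short sanity check. This is only an illustrative sketch (not part of the repository); it assumes the packages installed above and simply confirms that they import and that a GPU is visible:

```python
# Illustrative environment check (not part of the repo).
import torch
import face_alignment  # installed from the bundled face-alignment folder above

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Construct the same 2D landmark detector that crop-video.py and demo.py use
# (this may download detector weights on first use).
fa = face_alignment.FaceAlignment(
    face_alignment.LandmarksType._2D,
    device='cuda' if torch.cuda.is_available() else 'cpu')
print("face_alignment detector ready")
```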
66 | ## :zap: Quick Inference
67 |
68 | We take the paper version as an example. More models can be found [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EjfeXuzwo3JMn7s0oOPN_q0B81P5Wgu_kbYJAh7uSAKS2w?e=KaQcPk).
69 |
70 | ### YAML configs
71 | See ```config/vox-adv-256.yaml``` for a description of each parameter.
72 |
73 | ### Pre-trained checkpoint
74 | The pre-trained checkpoint of the face depth network and our DaGAN checkpoints can be found at the following link: [OneDrive](https://hkustconnect-my.sharepoint.com/:f:/g/personal/fhongac_connect_ust_hk/EjfeXuzwo3JMn7s0oOPN_q0B81P5Wgu_kbYJAh7uSAKS2w?e=KaQcPk).
75 |
76 | **Inference!**
77 | To run a demo, download a checkpoint and run the following command:
78 |
79 | ```bash
80 | CUDA_VISIBLE_DEVICES=0 python demo.py --config config/vox-adv-256.yaml --driving_video path/to/driving --source_image path/to/source --checkpoint path/to/checkpoint --relative --adapt_scale --kp_num 15 --generator DepthAwareGenerator
81 | ```
82 | The result will be stored in ```result.mp4```. The driving videos and source images should be cropped before they can be used in our method. To obtain semi-automatic crop suggestions you can use ```python crop-video.py --inp some_youtube_video.mp4```, which generates ffmpeg commands for the crops.
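
If you prefer to drive this from Python, the following is a small illustrative sketch (not part of the repository) that captures the ffmpeg commands printed by ```crop-video.py``` and runs them. It assumes ffmpeg is on your PATH, and the input file name is a placeholder; note that every suggested command writes to ```crop.mp4```, so rename the output between runs if several crops are suggested.

```python
# Illustrative sketch: run crop-video.py and execute the ffmpeg commands it prints.
# Assumes ffmpeg is installed; "some_youtube_video.mp4" is a placeholder input.
import subprocess

result = subprocess.run(
    ["python", "crop-video.py", "--inp", "some_youtube_video.mp4"],
    capture_output=True, text=True, check=True)

for command in result.stdout.splitlines():
    if command.strip().startswith("ffmpeg"):
        print("Running:", command)
        # Each command crops one detected face trajectory and writes crop.mp4.
        subprocess.run(command, shell=True, check=True)
```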
83 |
84 |
85 |
86 |
87 | ## :computer: Training
88 |
89 |
90 | ### Datasets
91 |
92 | 1) **VoxCeleb**. Please follow the instruction from https://github.com/AliaksandrSiarohin/video-preprocessing.
93 |
94 | ### Train on VoxCeleb
95 | To train a model on a specific dataset, run:
96 | ```
97 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --master_addr="0.0.0.0" --master_port=12348 run.py --config config/vox-adv-256.yaml --name DaGAN --rgbd --batchsize 12 --kp_num 15 --generator DepthAwareGenerator
98 | ```
99 | Or
100 |
101 | ```
102 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_dataparallel.py --config config/vox-adv-256.yaml --device_ids 0,1,2,3 --name DaGAN_voxceleb2_depth --rgbd --batchsize 48 --kp_num 15 --generator DepthAwareGenerator
103 | ```
104 |
105 |
106 |
107 |
108 | The code will create a folder in the log directory (each run will create a new name-specific directory).
109 | Checkpoints will be saved to this folder.
110 | To check the loss values during training, see ```log.txt```.
111 | By default, the batch size is tuned to run on 8 GeForce RTX 3090 GPUs (you can obtain the best performance after about 150 epochs). You can change the batch size in the train_params section of the ```.yaml``` file.
112 |
113 |
114 | Also, you can watch the training loss by running the following command:
115 | ```bash
116 | tensorboard --logdir log/DaGAN/log
117 | ```
118 | If you kill the training process midway, a zombie process may remain; you can kill it using our provided tool:
119 | ```bash
120 | python kill_port.py PORT
121 | ```
122 |
123 | ### Training on your own dataset
124 | 1) Resize all the videos to the same size, e.g. 256x256; the videos can be '.gif' files, '.mp4' files, or folders of images.
125 | We recommend the latter: for each video, make a separate folder with all the frames in '.png' format. This format is lossless and has better I/O performance. (A minimal preparation sketch is given after this list.)
126 |
127 | 2) Create a folder ```data/dataset_name``` with 2 subfolders ```train``` and ```test```, put training videos in the ```train``` and testing in the ```test```.
128 |
129 | 3) Create a config ```config/dataset_name.yaml```; in dataset_params, specify the root directory: ```root_dir: data/dataset_name```. Also adjust the number of epochs in train_params.
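
A minimal preparation sketch for steps 1) and 2) above (not part of the repository; it assumes ```imageio``` and ```scikit-image``` from ```requirements.txt```, and the folder names are placeholders):

```python
# Illustrative sketch: dump each raw video as a folder of 256x256 PNG frames
# under data/dataset_name/{train,test}, matching the layout described above.
import os
import imageio
from skimage import img_as_ubyte
from skimage.transform import resize

raw_dir = "raw_videos"                          # placeholder input folder
videos = sorted(os.listdir(raw_dir))
split_at = int(0.9 * len(videos))               # simple 90/10 train/test split

for idx, name in enumerate(videos):
    subset = "train" if idx < split_at else "test"
    out_dir = os.path.join("data", "dataset_name", subset, os.path.splitext(name)[0])
    os.makedirs(out_dir, exist_ok=True)
    reader = imageio.get_reader(os.path.join(raw_dir, name))
    for i, frame in enumerate(reader):
        frame = resize(frame, (256, 256))[..., :3]          # same resize as demo.py
        imageio.imsave(os.path.join(out_dir, "%07d.png" % i), img_as_ubyte(frame))
    reader.close()
```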
130 |
131 |
132 |
133 | ## :scroll: Acknowledgement
134 |
135 | Our DaGAN implementation is inspired by [FOMM](https://github.com/AliaksandrSiarohin/first-order-model). We appreciate the authors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) for making their code available to the public.
136 |
137 | ## :scroll: BibTeX
138 |
139 | ```
140 | @inproceedings{hong2022depth,
141 | title={Depth-Aware Generative Adversarial Network for Talking Head Video Generation},
142 | author={Hong, Fa-Ting and Zhang, Longhao and Shen, Li and Xu, Dan},
143 |   booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
144 | year={2022}
145 | }
146 |
147 | @article{hong2023dagan,
148 | title={DaGAN++: Depth-Aware Generative Adversarial Network for Talking Head Video Generation},
149 |   author={Hong, Fa-Ting and Shen, Li and Xu, Dan},
150 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
151 | year={2023}
152 | }
153 | ```
154 |
155 | ### :e-mail: Contact
156 |
157 | If you have any questions or would like to collaborate (for research or commercial purposes), please email `fhongac@cse.ust.hk`.
158 |
--------------------------------------------------------------------------------
/animate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from frames_dataset import PairedDataset
8 | from logger import Logger, Visualizer
9 | import imageio
10 | from scipy.spatial import ConvexHull
11 | import numpy as np
12 | import depth
13 | from sync_batchnorm import DataParallelWithCallback
14 |
15 |
16 | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
17 | use_relative_movement=False, use_relative_jacobian=False):
18 | if adapt_movement_scale:
19 | source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
20 | driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
21 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
22 | else:
23 | adapt_movement_scale = 1
24 |
25 | kp_new = {k: v for k, v in kp_driving.items()}
26 |
27 | if use_relative_movement:
28 | kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
29 | kp_value_diff *= adapt_movement_scale
30 | kp_new['value'] = kp_value_diff + kp_source['value']
31 |
32 | if use_relative_jacobian:
33 | jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
34 | kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
35 | return kp_new
36 |
37 |
38 | def animate(config, generator, kp_detector, checkpoint, log_dir, dataset,opt):
39 | log_dir = os.path.join(log_dir, 'animation')
40 | png_dir = os.path.join(log_dir, 'png')
41 | animate_params = config['animate_params']
42 |
43 | dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
44 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
45 |
46 | if checkpoint is not None:
47 | Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
48 | else:
49 | raise AttributeError("Checkpoint should be specified for mode='animate'.")
50 |
51 | if not os.path.exists(log_dir):
52 | os.makedirs(log_dir)
53 |
54 | if not os.path.exists(png_dir):
55 | os.makedirs(png_dir)
56 |
57 |
58 | depth_encoder = depth.ResnetEncoder(18, False).cuda()
59 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)).cuda()
60 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth')
61 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth')
62 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
63 | depth_encoder.load_state_dict(filtered_dict_enc)
64 | depth_decoder.load_state_dict(loaded_dict_dec)
65 | depth_decoder.eval()
66 | depth_encoder.eval()
67 | generator.eval()
68 | kp_detector.eval()
69 |
70 | for it, x in tqdm(enumerate(dataloader)):
71 | with torch.no_grad():
72 | predictions = []
73 | visualizations = []
74 |
75 | driving_video = x['driving_video'].cuda()
76 | source_frame = x['source_video'][:, :, 0, :, :].cuda()
77 |
78 | outputs = depth_decoder(depth_encoder(source_frame))
79 | depth_source = outputs[("disp", 0)]
80 | outputs = depth_decoder(depth_encoder(driving_video[:, :, 0]))
81 | depth_driving = outputs[("disp", 0)]
82 |
83 | source = torch.cat((source_frame,depth_source),1)
84 | driving = torch.cat((driving_video[:, :, 0],depth_driving),1)
85 |
86 | kp_source = kp_detector(source)
87 | kp_driving_initial = kp_detector(driving)
88 |
89 | for frame_idx in range(driving_video.shape[2]):
90 | driving_frame = driving_video[:, :, frame_idx].cuda()
91 | outputs = depth_decoder(depth_encoder(driving_frame))
92 | depth_map = outputs[("disp", 0)]
93 | driving = torch.cat((driving_frame,depth_map),1)
94 | kp_driving = kp_detector(driving)
95 |
96 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
97 | kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
98 | out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)
99 |
100 | out['kp_driving'] = kp_driving
101 | out['kp_source'] = kp_source
102 | out['kp_norm'] = kp_norm
103 |
104 | del out['sparse_deformed']
105 |
106 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
107 |
108 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
109 | driving=driving_frame, out=out)
110 | visualization = visualization
111 | visualizations.append(visualization)
112 |
113 | predictions = np.concatenate(predictions, axis=1)
114 | result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
115 | imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))
116 |
117 | image_name = result_name + animate_params['format']
118 | imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
119 |
120 |
--------------------------------------------------------------------------------
/assets/driving.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/assets/driving.mp4
--------------------------------------------------------------------------------
/assets/src.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/assets/src.png
--------------------------------------------------------------------------------
/assets/visual_vox1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/assets/visual_vox1.png
--------------------------------------------------------------------------------
/augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Code from https://github.com/hassony2/torch_videovision
3 | """
4 |
5 | import numbers
6 |
7 | import random
8 | import numpy as np
9 | import PIL
10 |
11 | from skimage.transform import resize, rotate
12 | # from skimage.util import pad
13 | # import numpy.pad as pad
14 | import torchvision
15 |
16 | import warnings
17 |
18 | from skimage import img_as_ubyte, img_as_float
19 |
20 |
21 | def crop_clip(clip, min_h, min_w, h, w):
22 | if isinstance(clip[0], np.ndarray):
23 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
24 |
25 | elif isinstance(clip[0], PIL.Image.Image):
26 | cropped = [
27 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
28 | ]
29 | else:
30 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
31 | 'but got list of {0}'.format(type(clip[0])))
32 | return cropped
33 |
34 |
35 | def pad_clip(clip, h, w):
36 | im_h, im_w = clip[0].shape[:2]
37 | pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
38 | pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
39 |
40 | return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
41 |
42 |
43 | def resize_clip(clip, size, interpolation='bilinear'):
44 | if isinstance(clip[0], np.ndarray):
45 | if isinstance(size, numbers.Number):
46 | im_h, im_w, im_c = clip[0].shape
47 | # Min spatial dim already matches minimal size
48 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
49 | and im_h == size):
50 | return clip
51 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
52 | size = (new_w, new_h)
53 | else:
54 | size = size[1], size[0]
55 |
56 | scaled = [
57 | resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
58 | mode='constant', anti_aliasing=True) for img in clip
59 | ]
60 | elif isinstance(clip[0], PIL.Image.Image):
61 | if isinstance(size, numbers.Number):
62 | im_w, im_h = clip[0].size
63 | # Min spatial dim already matches minimal size
64 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
65 | and im_h == size):
66 | return clip
67 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
68 | size = (new_w, new_h)
69 | else:
70 | size = size[1], size[0]
71 |         if interpolation == 'bilinear':  # use the matching PIL interpolation mode
72 |             pil_inter = PIL.Image.BILINEAR
73 |         else:
74 |             pil_inter = PIL.Image.NEAREST
75 | scaled = [img.resize(size, pil_inter) for img in clip]
76 | else:
77 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
78 | 'but got list of {0}'.format(type(clip[0])))
79 | return scaled
80 |
81 |
82 | def get_resize_sizes(im_h, im_w, size):
83 | if im_w < im_h:
84 | ow = size
85 | oh = int(size * im_h / im_w)
86 | else:
87 | oh = size
88 | ow = int(size * im_w / im_h)
89 | return oh, ow
90 |
91 |
92 | class RandomFlip(object):
93 | def __init__(self, time_flip=False, horizontal_flip=False):
94 | self.time_flip = time_flip
95 | self.horizontal_flip = horizontal_flip
96 |
97 | def __call__(self, clip):
98 | if random.random() < 0.5 and self.time_flip:
99 | return clip[::-1]
100 | if random.random() < 0.5 and self.horizontal_flip:
101 | return [np.fliplr(img) for img in clip]
102 |
103 | return clip
104 |
105 |
106 | class RandomResize(object):
107 | """Resizes a list of (H x W x C) numpy.ndarray to the final size
108 | The larger the original image is, the more times it takes to
109 | interpolate
110 | Args:
111 | interpolation (str): Can be one of 'nearest', 'bilinear'
112 | defaults to nearest
113 | size (tuple): (widht, height)
114 | """
115 |
116 | def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
117 | self.ratio = ratio
118 | self.interpolation = interpolation
119 |
120 | def __call__(self, clip):
121 | scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
122 |
123 | if isinstance(clip[0], np.ndarray):
124 | im_h, im_w, im_c = clip[0].shape
125 | elif isinstance(clip[0], PIL.Image.Image):
126 | im_w, im_h = clip[0].size
127 |
128 | new_w = int(im_w * scaling_factor)
129 | new_h = int(im_h * scaling_factor)
130 | new_size = (new_w, new_h)
131 | resized = resize_clip(
132 | clip, new_size, interpolation=self.interpolation)
133 |
134 | return resized
135 |
136 |
137 | class RandomCrop(object):
138 | """Extract random crop at the same location for a list of videos
139 | Args:
140 | size (sequence or int): Desired output size for the
141 | crop in format (h, w)
142 | """
143 |
144 | def __init__(self, size):
145 | if isinstance(size, numbers.Number):
146 | size = (size, size)
147 |
148 | self.size = size
149 |
150 | def __call__(self, clip):
151 | """
152 | Args:
153 | img (PIL.Image or numpy.ndarray): List of videos to be cropped
154 | in format (h, w, c) in numpy.ndarray
155 | Returns:
156 | PIL.Image or numpy.ndarray: Cropped list of videos
157 | """
158 | h, w = self.size
159 | if isinstance(clip[0], np.ndarray):
160 | im_h, im_w, im_c = clip[0].shape
161 | elif isinstance(clip[0], PIL.Image.Image):
162 | im_w, im_h = clip[0].size
163 | else:
164 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
165 | 'but got list of {0}'.format(type(clip[0])))
166 |
167 | clip = pad_clip(clip, h, w)
168 | im_h, im_w = clip.shape[1:3]
169 |         x1 = 0 if w == im_w else random.randint(0, im_w - w)
170 |         y1 = 0 if h == im_h else random.randint(0, im_h - h)
171 | cropped = crop_clip(clip, y1, x1, h, w)
172 |
173 | return cropped
174 |
175 |
176 | class RandomRotation(object):
177 | """Rotate entire clip randomly by a random angle within
178 | given bounds
179 | Args:
180 | degrees (sequence or int): Range of degrees to select from
181 | If degrees is a number instead of sequence like (min, max),
182 | the range of degrees, will be (-degrees, +degrees).
183 | """
184 |
185 | def __init__(self, degrees):
186 | if isinstance(degrees, numbers.Number):
187 | if degrees < 0:
188 | raise ValueError('If degrees is a single number,'
189 | 'must be positive')
190 | degrees = (-degrees, degrees)
191 | else:
192 | if len(degrees) != 2:
193 | raise ValueError('If degrees is a sequence,'
194 | 'it must be of len 2.')
195 |
196 | self.degrees = degrees
197 |
198 | def __call__(self, clip):
199 | """
200 | Args:
201 | img (PIL.Image or numpy.ndarray): List of videos to be cropped
202 | in format (h, w, c) in numpy.ndarray
203 | Returns:
204 | PIL.Image or numpy.ndarray: Cropped list of videos
205 | """
206 | angle = random.uniform(self.degrees[0], self.degrees[1])
207 | if isinstance(clip[0], np.ndarray):
208 | rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
209 | elif isinstance(clip[0], PIL.Image.Image):
210 | rotated = [img.rotate(angle) for img in clip]
211 | else:
212 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
213 | 'but got list of {0}'.format(type(clip[0])))
214 |
215 | return rotated
216 |
217 |
218 | class ColorJitter(object):
219 | """Randomly change the brightness, contrast and saturation and hue of the clip
220 | Args:
221 | brightness (float): How much to jitter brightness. brightness_factor
222 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
223 | contrast (float): How much to jitter contrast. contrast_factor
224 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
225 | saturation (float): How much to jitter saturation. saturation_factor
226 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
227 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
228 | [-hue, hue]. Should be >=0 and <= 0.5.
229 | """
230 |
231 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
232 | self.brightness = brightness
233 | self.contrast = contrast
234 | self.saturation = saturation
235 | self.hue = hue
236 |
237 | def get_params(self, brightness, contrast, saturation, hue):
238 | if brightness > 0:
239 | brightness_factor = random.uniform(
240 | max(0, 1 - brightness), 1 + brightness)
241 | else:
242 | brightness_factor = None
243 |
244 | if contrast > 0:
245 | contrast_factor = random.uniform(
246 | max(0, 1 - contrast), 1 + contrast)
247 | else:
248 | contrast_factor = None
249 |
250 | if saturation > 0:
251 | saturation_factor = random.uniform(
252 | max(0, 1 - saturation), 1 + saturation)
253 | else:
254 | saturation_factor = None
255 |
256 | if hue > 0:
257 | hue_factor = random.uniform(-hue, hue)
258 | else:
259 | hue_factor = None
260 | return brightness_factor, contrast_factor, saturation_factor, hue_factor
261 |
262 | def __call__(self, clip):
263 | """
264 | Args:
265 | clip (list): list of PIL.Image
266 | Returns:
267 | list PIL.Image : list of transformed PIL.Image
268 | """
269 | if isinstance(clip[0], np.ndarray):
270 | brightness, contrast, saturation, hue = self.get_params(
271 | self.brightness, self.contrast, self.saturation, self.hue)
272 |
273 | # Create img transform function sequence
274 | img_transforms = []
275 | if brightness is not None:
276 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
277 | if saturation is not None:
278 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
279 | if hue is not None:
280 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
281 | if contrast is not None:
282 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
283 | random.shuffle(img_transforms)
284 | img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
285 | img_as_float]
286 |
287 | with warnings.catch_warnings():
288 | warnings.simplefilter("ignore")
289 | jittered_clip = []
290 | for img in clip:
291 | jittered_img = img
292 | for func in img_transforms:
293 | jittered_img = func(jittered_img)
294 | jittered_clip.append(jittered_img.astype('float32'))
295 | elif isinstance(clip[0], PIL.Image.Image):
296 | brightness, contrast, saturation, hue = self.get_params(
297 | self.brightness, self.contrast, self.saturation, self.hue)
298 |
299 | # Create img transform function sequence
300 | img_transforms = []
301 | if brightness is not None:
302 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
303 | if saturation is not None:
304 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
305 | if hue is not None:
306 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
307 | if contrast is not None:
308 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
309 | random.shuffle(img_transforms)
310 |
311 | # Apply to all videos
312 | jittered_clip = []
313 |             for img in clip:
314 |                 jittered_img = img
315 |                 for func in img_transforms: jittered_img = func(jittered_img)  # apply jitters in sequence
316 |                 jittered_clip.append(jittered_img)
317 |
318 | else:
319 | raise TypeError('Expected numpy.ndarray or PIL.Image' +
320 | 'but got list of {0}'.format(type(clip[0])))
321 | return jittered_clip
322 |
323 |
324 | class AllAugmentationTransform:
325 | def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
326 | self.transforms = []
327 |
328 | if flip_param is not None:
329 | self.transforms.append(RandomFlip(**flip_param))
330 |
331 | if rotation_param is not None:
332 | self.transforms.append(RandomRotation(**rotation_param))
333 |
334 | if resize_param is not None:
335 | self.transforms.append(RandomResize(**resize_param))
336 |
337 | if crop_param is not None:
338 | self.transforms.append(RandomCrop(**crop_param))
339 |
340 | if jitter_param is not None:
341 | self.transforms.append(ColorJitter(**jitter_param))
342 |
343 | def __call__(self, clip):
344 | for t in self.transforms:
345 | clip = t(clip)
346 | return clip
347 |
--------------------------------------------------------------------------------
/config/vox-adv-256.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | root_dir: /data/fhongac/origDataset/vox1_frames
3 | frame_shape: [256, 256, 3]
4 | id_sampling: True
5 | pairs_list: data/vox256.csv
6 | augmentation_params:
7 | flip_param:
8 | horizontal_flip: True
9 | time_flip: True
10 | jitter_param:
11 | brightness: 0.1
12 | contrast: 0.1
13 | saturation: 0.1
14 | hue: 0.1
15 |
16 |
17 | model_params:
18 | common_params:
19 | num_kp: 10
20 | num_channels: 3
21 | estimate_jacobian: True
22 | kp_detector_params:
23 | temperature: 0.1
24 | block_expansion: 32
25 | max_features: 1024
26 | scale_factor: 0.25
27 | num_blocks: 5
28 | generator_params:
29 | block_expansion: 64
30 | max_features: 512
31 | num_down_blocks: 2
32 | num_bottleneck_blocks: 6
33 | estimate_occlusion_map: True
34 | dense_motion_params:
35 | block_expansion: 64
36 | max_features: 1024
37 | num_blocks: 5
38 | scale_factor: 0.25
39 | discriminator_params:
40 | scales: [1]
41 | block_expansion: 32
42 | max_features: 512
43 | num_blocks: 4
44 | use_kp: True
45 |
46 |
47 | train_params:
48 | num_epochs: 150
49 | num_repeats: 75
50 | epoch_milestones: []
51 | lr_generator: 2.0e-4
52 | lr_discriminator: 2.0e-4
53 | lr_kp_detector: 2.0e-4
54 | batch_size: 4
55 | scales: [1, 0.5, 0.25, 0.125]
56 | checkpoint_freq: 10
57 | transform_params:
58 | sigma_affine: 0.05
59 | sigma_tps: 0.005
60 | points_tps: 5
61 | loss_weights:
62 | generator_gan: 1
63 | discriminator_gan: 1
64 | feature_matching: [10, 10, 10, 10]
65 | perceptual: [10, 10, 10, 10, 10]
66 | equivariance_value: 10
67 | equivariance_jacobian: 10
68 | kp_distance: 10
69 | kp_prior: 0
70 | kp_scale: 0
71 | depth_constraint: 0
72 |
73 | reconstruction_params:
74 | num_videos: 1000
75 | format: '.mp4'
76 |
77 | animate_params:
78 | num_pairs: 50
79 | format: '.mp4'
80 | normalization_params:
81 | adapt_movement_scale: False
82 | use_relative_movement: True
83 | use_relative_jacobian: True
84 |
85 | visualizer_params:
86 | kp_size: 5
87 | draw_border: True
88 | colormap: 'gist_rainbow'
89 |
--------------------------------------------------------------------------------
/crop-video.py:
--------------------------------------------------------------------------------
1 | import face_alignment
2 | import skimage.io
3 | import numpy
4 | from argparse import ArgumentParser
5 | from skimage import img_as_ubyte
6 | from skimage.transform import resize
7 | from tqdm import tqdm
8 | import os
9 | import imageio
10 | import numpy as np
11 | import warnings
12 | warnings.filterwarnings("ignore")
13 |
14 | def extract_bbox(frame, fa):
15 | if max(frame.shape[0], frame.shape[1]) > 640:
16 | scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0
17 | frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor)))
18 | frame = img_as_ubyte(frame)
19 | else:
20 | scale_factor = 1
21 | frame = frame[..., :3]
22 | bboxes = fa.face_detector.detect_from_image(frame[..., ::-1])
23 | if len(bboxes) == 0:
24 | return []
25 | return np.array(bboxes)[:, :-1] * scale_factor
26 |
27 |
28 |
29 | def bb_intersection_over_union(boxA, boxB):
30 | xA = max(boxA[0], boxB[0])
31 | yA = max(boxA[1], boxB[1])
32 | xB = min(boxA[2], boxB[2])
33 | yB = min(boxA[3], boxB[3])
34 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
35 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
36 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
37 | iou = interArea / float(boxAArea + boxBArea - interArea)
38 | return iou
39 |
40 |
41 | def join(tube_bbox, bbox):
42 | xA = min(tube_bbox[0], bbox[0])
43 | yA = min(tube_bbox[1], bbox[1])
44 | xB = max(tube_bbox[2], bbox[2])
45 | yB = max(tube_bbox[3], bbox[3])
46 | return (xA, yA, xB, yB)
47 |
48 |
49 | def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
50 | left, top, right, bot = tube_bbox
51 | width = right - left
52 | height = bot - top
53 |
54 | #Computing aspect preserving bbox
55 | width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
56 | height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
57 |
58 | left = int(left - width_increase * width)
59 | top = int(top - height_increase * height)
60 | right = int(right + width_increase * width)
61 | bot = int(bot + height_increase * height)
62 |
63 | top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1])
64 | h, w = bot - top, right - left
65 |
66 | start = start / fps
67 | end = end / fps
68 | time = end - start
69 |
70 | scale = f'{image_shape[0]}:{image_shape[1]}'
71 |
72 | return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
73 |
74 |
75 | def compute_bbox_trajectories(trajectories, fps, frame_shape, args):
76 | commands = []
77 | for i, (bbox, tube_bbox, start, end) in enumerate(trajectories):
78 | if (end - start) > args.min_frames:
79 | command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase)
80 | commands.append(command)
81 | return commands
82 |
83 |
84 | def process_video(args):
85 | device = 'cpu' if args.cpu else 'cuda'
86 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
87 | video = imageio.get_reader(args.inp)
88 |
89 | trajectories = []
90 | previous_frame = None
91 | fps = video.get_meta_data()['fps']
92 | commands = []
93 | try:
94 | for i, frame in tqdm(enumerate(video)):
95 | frame_shape = frame.shape
96 | bboxes = extract_bbox(frame, fa)
97 | ## For each trajectory check the criterion
98 | not_valid_trajectories = []
99 | valid_trajectories = []
100 |
101 | for trajectory in trajectories:
102 | tube_bbox = trajectory[0]
103 | intersection = 0
104 | for bbox in bboxes:
105 | intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox))
106 | if intersection > args.iou_with_initial:
107 | valid_trajectories.append(trajectory)
108 | else:
109 | not_valid_trajectories.append(trajectory)
110 |
111 | commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args)
112 | trajectories = valid_trajectories
113 |
114 | ## Assign bbox to trajectories, create new trajectories
115 | for bbox in bboxes:
116 | intersection = 0
117 | current_trajectory = None
118 | for trajectory in trajectories:
119 | tube_bbox = trajectory[0]
120 | current_intersection = bb_intersection_over_union(tube_bbox, bbox)
121 | if intersection < current_intersection and current_intersection > args.iou_with_initial:
122 | intersection = bb_intersection_over_union(tube_bbox, bbox)
123 | current_trajectory = trajectory
124 |
125 | ## Create new trajectory
126 | if current_trajectory is None:
127 | trajectories.append([bbox, bbox, i, i])
128 | else:
129 | current_trajectory[3] = i
130 | current_trajectory[1] = join(current_trajectory[1], bbox)
131 |
132 |
133 | except IndexError as e:
134 | raise (e)
135 |
136 | commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args)
137 | return commands
138 |
139 |
140 | if __name__ == "__main__":
141 | parser = ArgumentParser()
142 |
143 | parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
144 | help="Image shape")
145 | parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount')
146 |     parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed IOU with the initial bbox")
147 | parser.add_argument("--inp", required=True, help='Input image or video')
148 | parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames')
149 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
150 |
151 |
152 | args = parser.parse_args()
153 |
154 | commands = process_video(args)
155 | for command in commands:
156 | print (command)
157 |
158 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import csv
4 | import pdb
5 | import numpy as np
6 |
7 | def create_csv(path):
8 | videos = os.listdir(path)
9 | source = videos.copy()
10 | driving = videos.copy()
11 | random.shuffle(source)
12 | random.shuffle(driving)
13 | source = np.array(source).reshape(-1,1)
14 | driving = np.array(driving).reshape(-1,1)
15 | zeros = np.zeros((len(source),1))
16 | content = np.concatenate((source,driving,zeros),1)
17 | f = open('vox256.csv','w',encoding='utf-8')
18 | csv_writer = csv.writer(f)
19 | csv_writer.writerow(["source","driving","frame"])
20 | csv_writer.writerows(content)
21 | f.close()
22 |
23 |
24 | if __name__ == '__main__':
25 | create_csv('/data/fhongac/origDataset/vox1/test')
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import os, sys
4 | import yaml
5 | from argparse import ArgumentParser
6 | from tqdm import tqdm
7 | import modules.generator as GEN
8 | import imageio
9 | import numpy as np
10 | from skimage.transform import resize
11 | from skimage import img_as_ubyte
12 | import torch
13 | from sync_batchnorm import DataParallelWithCallback
14 | import depth
15 | from modules.keypoint_detector import KPDetector
16 | from animate import normalize_kp
17 | from scipy.spatial import ConvexHull
18 | from collections import OrderedDict
19 | import pdb
20 | import cv2
21 | if sys.version_info[0] < 3:
22 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
23 |
24 | def load_checkpoints(config_path, checkpoint_path, cpu=False):
25 |
26 | with open(config_path) as f:
27 |         config = yaml.load(f, Loader=yaml.FullLoader)
28 | if opt.kp_num != -1:
29 | config['model_params']['common_params']['num_kp'] = opt.kp_num
30 | generator = getattr(GEN, opt.generator)(**config['model_params']['generator_params'],**config['model_params']['common_params'])
31 | if not cpu:
32 | generator.cuda()
33 | config['model_params']['common_params']['num_channels'] = 4
34 | kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
35 | **config['model_params']['common_params'])
36 | if not cpu:
37 | kp_detector.cuda()
38 | if cpu:
39 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
40 | else:
41 | checkpoint = torch.load(checkpoint_path,map_location="cuda:0")
42 |
43 | ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['generator'].items())
44 | generator.load_state_dict(ckp_generator)
45 | ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['kp_detector'].items())
46 | kp_detector.load_state_dict(ckp_kp_detector)
47 |
48 | if not cpu:
49 | generator = DataParallelWithCallback(generator)
50 | kp_detector = DataParallelWithCallback(kp_detector)
51 |
52 | generator.eval()
53 | kp_detector.eval()
54 |
55 | return generator, kp_detector
56 |
57 |
58 | def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
59 | sources = []
60 | drivings = []
61 | with torch.no_grad():
62 | predictions = []
63 | depth_gray = []
64 | source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
65 | driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
66 | if not cpu:
67 | source = source.cuda()
68 | driving = driving.cuda()
69 | outputs = depth_decoder(depth_encoder(source))
70 | depth_source = outputs[("disp", 0)]
71 |
72 | outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
73 | depth_driving = outputs[("disp", 0)]
74 | source_kp = torch.cat((source,depth_source),1)
75 | driving_kp = torch.cat((driving[:, :, 0],depth_driving),1)
76 |
77 | kp_source = kp_detector(source_kp)
78 | kp_driving_initial = kp_detector(driving_kp)
79 |
80 | # kp_source = kp_detector(source)
81 | # kp_driving_initial = kp_detector(driving[:, :, 0])
82 |
83 | for frame_idx in tqdm(range(driving.shape[2])):
84 | driving_frame = driving[:, :, frame_idx]
85 |
86 | if not cpu:
87 | driving_frame = driving_frame.cuda()
88 | outputs = depth_decoder(depth_encoder(driving_frame))
89 | depth_map = outputs[("disp", 0)]
90 |
91 | gray_driving = np.transpose(depth_map.data.cpu().numpy(), [0, 2, 3, 1])[0]
92 | gray_driving = 1-gray_driving/np.max(gray_driving)
93 |
94 | frame = torch.cat((driving_frame,depth_map),1)
95 | kp_driving = kp_detector(frame)
96 |
97 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
98 | kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
99 | use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
100 | out = generator(source, kp_source=kp_source, kp_driving=kp_norm,source_depth = depth_source, driving_depth = depth_map)
101 |
102 | drivings.append(np.transpose(driving_frame.data.cpu().numpy(), [0, 2, 3, 1])[0])
103 | sources.append(np.transpose(source.data.cpu().numpy(), [0, 2, 3, 1])[0])
104 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
105 | depth_gray.append(gray_driving)
106 | return sources, drivings, predictions,depth_gray
107 |
108 |
109 | def find_best_frame(source, driving, cpu=False):
110 | import face_alignment
111 |
112 | def normalize_kp(kp):
113 | kp = kp - kp.mean(axis=0, keepdims=True)
114 | area = ConvexHull(kp[:, :2]).volume
115 | area = np.sqrt(area)
116 | kp[:, :2] = kp[:, :2] / area
117 | return kp
118 |
119 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
120 | device='cpu' if cpu else 'cuda')
121 | kp_source = fa.get_landmarks(255 * source)[0]
122 | kp_source = normalize_kp(kp_source)
123 | norm = float('inf')
124 | frame_num = 0
125 | for i, image in tqdm(enumerate(driving)):
126 | kp_driving = fa.get_landmarks(255 * image)[0]
127 | kp_driving = normalize_kp(kp_driving)
128 | new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
129 | if new_norm < norm:
130 | norm = new_norm
131 | frame_num = i
132 | return frame_num
133 |
134 | if __name__ == "__main__":
135 | parser = ArgumentParser()
136 | parser.add_argument("--config", required=True, help="path to config")
137 | parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")
138 |
139 | parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image")
140 | parser.add_argument("--driving_video", default='sup-mat/source.png', help="path to driving video")
141 | parser.add_argument("--result_video", default='result.mp4', help="path to output")
142 |
143 | parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
144 | parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
145 | parser.add_argument("--generator", type=str, required=True)
146 | parser.add_argument("--kp_num", type=int, required=True)
147 |
148 |
149 | parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
150 |                         help="Generate from the frame that is best aligned with the source. (Only for faces, requires the face_alignment lib.)")
151 |
152 | parser.add_argument("--best_frame", dest="best_frame", type=int, default=None,
153 | help="Set frame to start from.")
154 |
155 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
156 |
157 |
158 | parser.set_defaults(relative=False)
159 | parser.set_defaults(adapt_scale=False)
160 |
161 | opt = parser.parse_args()
162 |
163 | depth_encoder = depth.ResnetEncoder(18, False)
164 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
165 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth')
166 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth')
167 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
168 | depth_encoder.load_state_dict(filtered_dict_enc)
169 | depth_decoder.load_state_dict(loaded_dict_dec)
170 | depth_encoder.eval()
171 | depth_decoder.eval()
172 | if not opt.cpu:
173 | depth_encoder.cuda()
174 | depth_decoder.cuda()
175 |
176 | source_image = imageio.imread(opt.source_image)
177 | reader = imageio.get_reader(opt.driving_video)
178 | fps = reader.get_meta_data()['fps']
179 | driving_video = []
180 | try:
181 | for im in reader:
182 | driving_video.append(im)
183 | except RuntimeError:
184 | pass
185 | reader.close()
186 |
187 | source_image = resize(source_image, (256, 256))[..., :3]
188 | driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
189 | generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
190 |
191 | if opt.find_best_frame or opt.best_frame is not None:
192 | i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
193 | print ("Best frame: " + str(i))
194 | driving_forward = driving_video[i:]
195 | driving_backward = driving_video[:(i+1)][::-1]
196 | sources_forward, drivings_forward, predictions_forward,depth_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
197 | sources_backward, drivings_backward, predictions_backward,depth_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
198 | predictions = predictions_backward[::-1] + predictions_forward[1:]
199 | sources = sources_backward[::-1] + sources_forward[1:]
200 | drivings = drivings_backward[::-1] + drivings_forward[1:]
201 | depth_gray = depth_backward[::-1] + depth_forward[1:]
202 |
203 | else:
204 | # predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
205 | sources, drivings, predictions,depth_gray = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
206 | imageio.mimsave(opt.result_video, [img_as_ubyte(p) for p in predictions], fps=fps)
207 | # imageio.mimsave(opt.result_video, [np.concatenate((img_as_ubyte(s),img_as_ubyte(d),img_as_ubyte(p)),1) for (s,d,p) in zip(sources, drivings, predictions)], fps=fps)
208 | # imageio.mimsave("gray.mp4", depth_gray, fps=fps)
209 | # merge the gray video
210 | # animation = np.array(imageio.mimread(opt.result_video,memtest=False))
211 | # gray = np.array(imageio.mimread("gray.mp4",memtest=False))
212 |
213 | # src_dst = animation[:,:,:512,:]
214 | # animate = animation[:,:,512:,:]
215 | # merge = np.concatenate((src_dst,gray,animate),2)
216 | # imageio.mimsave(opt.result_video, animate, fps=fps)
217 | #Transfer to gif
218 | # from moviepy.editor import *
219 | # clip = (VideoFileClip(opt.result_video))
220 | # clip.write_gif("{}.gif".format(opt.result_video))
--------------------------------------------------------------------------------
/demo_multi.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import os, sys
4 | import yaml
5 | from argparse import ArgumentParser
6 | from tqdm import tqdm
7 | import modules.generator as GEN
8 | import imageio
9 | import numpy as np
10 | from skimage.transform import resize
11 | from skimage import img_as_ubyte
12 | import torch
13 | from sync_batchnorm import DataParallelWithCallback
14 | import depth
15 | from modules.keypoint_detector import KPDetector
16 | from animate import normalize_kp
17 | from scipy.spatial import ConvexHull
18 | from collections import OrderedDict
19 | import pdb
20 | import cv2
21 | from glob import glob
22 | if sys.version_info[0] < 3:
23 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
24 |
25 | def load_checkpoints(config_path, checkpoint_path, cpu=False):
26 |
27 | with open(config_path) as f:
28 | config = yaml.load(f,Loader=yaml.FullLoader)
29 | if opt.kp_num != -1:
30 | config['model_params']['common_params']['num_kp'] = opt.kp_num
31 | generator = getattr(GEN, opt.generator)(**config['model_params']['generator_params'],**config['model_params']['common_params'])
32 | if not cpu:
33 | generator.cuda()
34 | config['model_params']['common_params']['num_channels'] = 4
35 | kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
36 | **config['model_params']['common_params'])
37 | if not cpu:
38 | kp_detector.cuda()
39 | if cpu:
40 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
41 | else:
42 | checkpoint = torch.load(checkpoint_path,map_location="cuda:0")
43 |
44 | ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['generator'].items())
45 | generator.load_state_dict(ckp_generator)
46 | ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['kp_detector'].items())
47 | kp_detector.load_state_dict(ckp_kp_detector)
48 |
49 | if not cpu:
50 | generator = DataParallelWithCallback(generator)
51 | kp_detector = DataParallelWithCallback(kp_detector)
52 |
53 | generator.eval()
54 | kp_detector.eval()
55 |
56 | return generator, kp_detector
57 |
58 |
59 | def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
60 | sources = []
61 | drivings = []
62 | with torch.no_grad():
63 | predictions = []
64 | depth_gray = []
65 | source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
66 | driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
67 | if not cpu:
68 | source = source.cuda()
69 | driving = driving.cuda()
70 | outputs = depth_decoder(depth_encoder(source))
71 | depth_source = outputs[("disp", 0)]
72 |
73 | outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
74 | depth_driving = outputs[("disp", 0)]
75 | source_kp = torch.cat((source,depth_source),1)
76 | driving_kp = torch.cat((driving[:, :, 0],depth_driving),1)
77 |
78 | kp_source = kp_detector(source_kp)
79 | kp_driving_initial = kp_detector(driving_kp)
80 |
81 | # kp_source = kp_detector(source)
82 | # kp_driving_initial = kp_detector(driving[:, :, 0])
83 |
84 | for frame_idx in tqdm(range(driving.shape[2])):
85 | driving_frame = driving[:, :, frame_idx]
86 |
87 | if not cpu:
88 | driving_frame = driving_frame.cuda()
89 | outputs = depth_decoder(depth_encoder(driving_frame))
90 | depth_map = outputs[("disp", 0)]
91 |
92 | gray_driving = np.transpose(depth_map.data.cpu().numpy(), [0, 2, 3, 1])[0]
93 | gray_driving = 1-gray_driving/np.max(gray_driving)
94 |
95 | frame = torch.cat((driving_frame,depth_map),1)
96 | kp_driving = kp_detector(frame)
97 |
98 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
99 | kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
100 | use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
101 | out = generator(source, kp_source=kp_source, kp_driving=kp_norm,source_depth = depth_source, driving_depth = depth_map)
102 |
103 | drivings.append(np.transpose(driving_frame.data.cpu().numpy(), [0, 2, 3, 1])[0])
104 | sources.append(np.transpose(source.data.cpu().numpy(), [0, 2, 3, 1])[0])
105 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
106 | depth_gray.append(gray_driving)
107 | return sources, drivings, predictions,depth_gray
108 |
109 |
110 | def find_best_frame(source, driving, cpu=False):
111 | import face_alignment
112 |
113 | def normalize_kp(kp):
114 | kp = kp - kp.mean(axis=0, keepdims=True)
115 | area = ConvexHull(kp[:, :2]).volume
116 | area = np.sqrt(area)
117 | kp[:, :2] = kp[:, :2] / area
118 | return kp
119 |
120 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
121 | device='cpu' if cpu else 'cuda')
122 | kp_source = fa.get_landmarks(255 * source)[0]
123 | kp_source = normalize_kp(kp_source)
124 | norm = float('inf')
125 | frame_num = 0
126 | for i, image in tqdm(enumerate(driving)):
127 | kp_driving = fa.get_landmarks(255 * image)[0]
128 | kp_driving = normalize_kp(kp_driving)
129 | new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
130 | if new_norm < norm:
131 | norm = new_norm
132 | frame_num = i
133 | return frame_num
134 |
135 | if __name__ == "__main__":
136 | parser = ArgumentParser()
137 | parser.add_argument("--config", required=True, help="path to config")
138 | parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")
139 |
140 | parser.add_argument("--source_folder", default='sup-mat/source.png', help="path to source image")
141 | parser.add_argument("--driving_video", default='sup-mat/source.png', help="path to driving video")
142 | parser.add_argument("--save_folder", default='result.mp4', help="path to output")
143 |
144 | parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
145 | parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
146 | parser.add_argument("--generator", type=str, required=True)
147 | parser.add_argument("--kp_num", type=int, required=True)
148 |
149 |
150 | parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
151 |                         help="Generate from the frame that is best aligned with the source. (Only for faces, requires the face_alignment lib.)")
152 |
153 | parser.add_argument("--best_frame", dest="best_frame", type=int, default=None,
154 | help="Set frame to start from.")
155 |
156 | parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
157 |
158 |
159 | parser.set_defaults(relative=False)
160 | parser.set_defaults(adapt_scale=False)
161 |
162 | opt = parser.parse_args()
163 |
164 | depth_encoder = depth.ResnetEncoder(18, False)
165 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
166 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth')
167 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth')
168 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
169 | depth_encoder.load_state_dict(filtered_dict_enc)
170 | depth_decoder.load_state_dict(loaded_dict_dec)
171 | depth_encoder.eval()
172 | depth_decoder.eval()
173 | if not opt.cpu:
174 | depth_encoder.cuda()
175 | depth_decoder.cuda()
176 |
177 | reader = imageio.get_reader(opt.driving_video)
178 | fps = reader.get_meta_data()['fps']
179 | driving_video = []
180 | try:
181 | for im in reader:
182 | driving_video.append(im)
183 | except RuntimeError:
184 | pass
185 | reader.close()
186 |
187 | driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
188 | generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
189 | if not os.path.exists(opt.save_folder):
190 | os.makedirs(opt.save_folder)
191 | sources = glob(opt.source_folder+"/*")
192 | for src in tqdm(sources):
193 | source_image = imageio.imread(src)
194 | source_image = resize(source_image, (256, 256))[..., :3]
195 | # predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
196 | sources, drivings, predictions,depth_gray = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
197 | fn = os.path.basename(src)
198 | imageio.mimsave(os.path.join(opt.save_folder,fn+'.mp4'), [img_as_ubyte(p) for p in predictions], fps=fps)
199 |
--------------------------------------------------------------------------------
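
The best-frame search at lines 126–133 above picks the driving frame whose facial landmarks are closest to the source before animation starts. A minimal, self-contained sketch of the same idea follows; the `normalize_kp` helper here (centering plus a simple scale normalisation) is an assumption and may differ from the repository's own implementation.

```python
import numpy as np
import face_alignment

def normalize_kp(kp):
    # Remove translation and global scale so only pose/expression differences remain.
    # (Assumption: the repo's own normalize_kp may use a convex-hull-based scale instead.)
    kp = kp - kp.mean(axis=0, keepdims=True)
    return kp / (np.abs(kp).max() + 1e-8)

def find_best_frame(source, driving, cpu=False):
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cpu' if cpu else 'cuda')
    kp_source = normalize_kp(fa.get_landmarks(255 * source)[0])
    best_norm, best_idx = float('inf'), 0
    for i, frame in enumerate(driving):
        kp_driving = normalize_kp(fa.get_landmarks(255 * frame)[0])
        norm = (np.abs(kp_source - kp_driving) ** 2).sum()
        if norm < best_norm:
            best_norm, best_idx = norm, i
    return best_idx
```
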
/depth/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet_encoder import ResnetEncoder
2 | from .depth_decoder import DepthDecoder
3 | from .pose_decoder import PoseDecoder
4 | from .pose_cnn import PoseCNN
5 |
6 |
--------------------------------------------------------------------------------
/depth/depth_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright Niantic 2019. Patent Pending. All rights reserved.
2 | #
3 | # This software is licensed under the terms of the Monodepth2 licence
4 | # which allows for non-commercial use only, the full terms of which are made
5 | # available in the LICENSE file.
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 |
13 | from collections import OrderedDict
14 | from depth.layers import *
15 |
16 |
17 | class DepthDecoder(nn.Module):
18 | def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
19 | super(DepthDecoder, self).__init__()
20 |
21 | self.num_output_channels = num_output_channels
22 | self.use_skips = use_skips
23 | self.upsample_mode = 'nearest'
24 | self.scales = scales
25 |
26 | self.num_ch_enc = num_ch_enc
27 | self.num_ch_dec = np.array([16, 32, 64, 128, 256])
28 |
29 | # decoder
30 | self.convs = OrderedDict()
31 | for i in range(4, -1, -1):
32 | # upconv_0
33 | num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
34 | num_ch_out = self.num_ch_dec[i]
35 | self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
36 |
37 | # upconv_1
38 | num_ch_in = self.num_ch_dec[i]
39 | if self.use_skips and i > 0:
40 | num_ch_in += self.num_ch_enc[i - 1]
41 | num_ch_out = self.num_ch_dec[i]
42 | self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
43 |
44 | for s in self.scales:
45 | self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
46 |
47 | self.decoder = nn.ModuleList(list(self.convs.values()))
48 | self.sigmoid = nn.Sigmoid()
49 |
50 | def forward(self, input_features):
51 | self.outputs = {}
52 |
53 | # decoder
54 | x = input_features[-1]
55 | for i in range(4, -1, -1):
56 | x = self.convs[("upconv", i, 0)](x)
57 | x = [upsample(x)]
58 | if self.use_skips and i > 0:
59 | x += [input_features[i - 1]]
60 | x = torch.cat(x, 1)
61 | x = self.convs[("upconv", i, 1)](x)
62 | if i in self.scales:
63 | self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
64 |
65 | return self.outputs
66 |
--------------------------------------------------------------------------------
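
A minimal sketch (not part of the repository) of how `DepthDecoder` is driven by `ResnetEncoder`, mirroring the instantiation in `demo_multi.py`; the decoder returns a dictionary of sigmoid disparity maps keyed by `("disp", scale)`.

```python
import torch
from depth import ResnetEncoder, DepthDecoder

encoder = ResnetEncoder(18, False)
decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
encoder.eval()
decoder.eval()

with torch.no_grad():
    image = torch.rand(1, 3, 256, 256)   # the encoder normalises [0, 1] inputs itself
    features = encoder(image)            # five feature maps, from 1/2 down to 1/32 resolution
    outputs = decoder(features)

disp = outputs[("disp", 0)]              # finest-scale disparity, shape (1, 1, 256, 256)
print(disp.shape, disp.min().item() >= 0.0)
```
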
/depth/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright Niantic 2019. Patent Pending. All rights reserved.
2 | #
3 | # This software is licensed under the terms of the Monodepth2 licence
4 | # which allows for non-commercial use only, the full terms of which are made
5 | # available in the LICENSE file.
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import numpy as np
10 | import pdb
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | def disp_to_depth(disp, min_depth, max_depth):
17 | """Convert network's sigmoid output into depth prediction
18 | The formula for this conversion is given in the 'additional considerations'
19 | section of the paper.
20 | """
21 | min_disp = 1 / max_depth
22 | max_disp = 1 / min_depth
23 | scaled_disp = min_disp + (max_disp - min_disp) * disp
24 | depth = 1 / scaled_disp
25 | return scaled_disp, depth
26 |
27 |
28 | def transformation_from_parameters(axisangle, translation, invert=False):
29 | """Convert the network's (axisangle, translation) output into a 4x4 matrix
30 | """
31 | R = rot_from_axisangle(axisangle)
32 | t = translation.clone()
33 |
34 | if invert:
35 | R = R.transpose(1, 2)
36 | t *= -1
37 |
38 | T = get_translation_matrix(t)
39 |
40 | if invert:
41 | M = torch.matmul(R, T)
42 | else:
43 | M = torch.matmul(T, R)
44 |
45 | return M
46 |
47 |
48 | def get_translation_matrix(translation_vector):
49 | """Convert a translation vector into a 4x4 transformation matrix
50 | """
51 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
52 |
53 | t = translation_vector.contiguous().view(-1, 3, 1)
54 |
55 | T[:, 0, 0] = 1
56 | T[:, 1, 1] = 1
57 | T[:, 2, 2] = 1
58 | T[:, 3, 3] = 1
59 | T[:, :3, 3, None] = t
60 |
61 | return T
62 |
63 |
64 | def rot_from_axisangle(vec):
65 | """Convert an axisangle rotation into a 4x4 transformation matrix
66 | (adapted from https://github.com/Wallacoloo/printipi)
67 | Input 'vec' has to be Bx1x3
68 | """
69 | angle = torch.norm(vec, 2, 2, True)
70 | axis = vec / (angle + 1e-7)
71 |
72 | ca = torch.cos(angle)
73 | sa = torch.sin(angle)
74 | C = 1 - ca
75 |
76 | x = axis[..., 0].unsqueeze(1)
77 | y = axis[..., 1].unsqueeze(1)
78 | z = axis[..., 2].unsqueeze(1)
79 |
80 | xs = x * sa
81 | ys = y * sa
82 | zs = z * sa
83 | xC = x * C
84 | yC = y * C
85 | zC = z * C
86 | xyC = x * yC
87 | yzC = y * zC
88 | zxC = z * xC
89 |
90 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
91 |
92 | rot[:, 0, 0] = torch.squeeze(x * xC + ca)
93 | rot[:, 0, 1] = torch.squeeze(xyC - zs)
94 | rot[:, 0, 2] = torch.squeeze(zxC + ys)
95 | rot[:, 1, 0] = torch.squeeze(xyC + zs)
96 | rot[:, 1, 1] = torch.squeeze(y * yC + ca)
97 | rot[:, 1, 2] = torch.squeeze(yzC - xs)
98 | rot[:, 2, 0] = torch.squeeze(zxC - ys)
99 | rot[:, 2, 1] = torch.squeeze(yzC + xs)
100 | rot[:, 2, 2] = torch.squeeze(z * zC + ca)
101 | rot[:, 3, 3] = 1
102 |
103 | return rot
104 |
105 |
106 | class ConvBlock(nn.Module):
107 | """Layer to perform a convolution followed by ELU
108 | """
109 | def __init__(self, in_channels, out_channels):
110 | super(ConvBlock, self).__init__()
111 |
112 | self.conv = Conv3x3(in_channels, out_channels)
113 | self.nonlin = nn.ELU(inplace=True)
114 |
115 | def forward(self, x):
116 | out = self.conv(x)
117 | out = self.nonlin(out)
118 | return out
119 |
120 |
121 | class Conv3x3(nn.Module):
122 | """Layer to pad and convolve input
123 | """
124 | def __init__(self, in_channels, out_channels, use_refl=True):
125 | super(Conv3x3, self).__init__()
126 |
127 | if use_refl:
128 | self.pad = nn.ReflectionPad2d(1)
129 | else:
130 | self.pad = nn.ZeroPad2d(1)
131 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
132 |
133 | def forward(self, x):
134 | out = self.pad(x)
135 | out = self.conv(out)
136 | return out
137 |
138 |
139 | class BackprojectDepth(nn.Module):
140 | """Layer to transform a depth image into a point cloud
141 | """
142 | def __init__(self, batch_size, height, width):
143 | super(BackprojectDepth, self).__init__()
144 |
145 | self.batch_size = batch_size
146 | self.height = height
147 | self.width = width
148 |
149 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
150 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
151 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
152 | requires_grad=False)
153 |
154 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
155 | requires_grad=False)
156 |
157 | self.pix_coords = torch.unsqueeze(torch.stack(
158 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
159 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
160 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
161 | requires_grad=False)
162 |
163 | def forward(self, depth, K,scale):
164 | K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc()
165 |         b, _, _ = K.shape
166 | inv_K = torch.linalg.inv(K)
167 | #inv_K = torch.cholesky_inverse(K)
168 | pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda()
169 | inv_K = torch.cat([inv_K,pad],-1)
170 | pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda()
171 | inv_K = torch.cat([inv_K,pad],1)
172 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
173 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points
174 | cam_points = torch.cat([cam_points, self.ones], 1)
175 |
176 | return cam_points
177 |
178 |
179 | class Project3D(nn.Module):
180 | """Layer which projects 3D points into a camera with intrinsics K and at position T
181 | """
182 | def __init__(self, batch_size, height, width, eps=1e-7):
183 | super(Project3D, self).__init__()
184 |
185 | self.batch_size = batch_size
186 | self.height = height
187 | self.width = width
188 | self.eps = eps
189 |
190 | def forward(self, points, K, T,scale=0):
191 | # K[0, :] *= self.width // (2 ** scale)
192 | # K[1, :] *= self.height // (2 ** scale)
193 | K[:,:2,:] = (K[:,:2,:]/(2 ** scale)).trunc()
194 |         b, _, _ = K.shape
195 | pad = torch.tensor([0.0,0.0,0.0]).view(1,3,1).expand(b,3,1).cuda()
196 | K = torch.cat([K,pad],-1)
197 | pad = torch.tensor([0.0,0.0,0.0,1.0]).view(1,1,4).expand(b,1,4).cuda()
198 | K = torch.cat([K,pad],1)
199 | P = torch.matmul(K, T)[:, :3, :]
200 |
201 | cam_points = torch.matmul(P, points)
202 |
203 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
204 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
205 | pix_coords = pix_coords.permute(0, 2, 3, 1)
206 | pix_coords[..., 0] /= self.width - 1
207 | pix_coords[..., 1] /= self.height - 1
208 | pix_coords = (pix_coords - 0.5) * 2
209 | return pix_coords
210 |
211 |
212 | def upsample(x):
213 | """Upsample input tensor by a factor of 2
214 | """
215 | return F.interpolate(x, scale_factor=2, mode="nearest")
216 |
217 |
218 | def get_smooth_loss(disp, img):
219 | """Computes the smoothness loss for a disparity image
220 | The color image is used for edge-aware smoothness
221 | """
222 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
223 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
224 |
225 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
226 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
227 |
228 | grad_disp_x *= torch.exp(-grad_img_x)
229 | grad_disp_y *= torch.exp(-grad_img_y)
230 |
231 | return grad_disp_x.mean() + grad_disp_y.mean()
232 |
233 |
234 | class SSIM(nn.Module):
235 | """Layer to compute the SSIM loss between a pair of images
236 | """
237 | def __init__(self):
238 | super(SSIM, self).__init__()
239 | self.mu_x_pool = nn.AvgPool2d(3, 1)
240 | self.mu_y_pool = nn.AvgPool2d(3, 1)
241 | self.sig_x_pool = nn.AvgPool2d(3, 1)
242 | self.sig_y_pool = nn.AvgPool2d(3, 1)
243 | self.sig_xy_pool = nn.AvgPool2d(3, 1)
244 |
245 | self.refl = nn.ReflectionPad2d(1)
246 |
247 | self.C1 = 0.01 ** 2
248 | self.C2 = 0.03 ** 2
249 |
250 | def forward(self, x, y):
251 | x = self.refl(x)
252 | y = self.refl(y)
253 |
254 | mu_x = self.mu_x_pool(x)
255 | mu_y = self.mu_y_pool(y)
256 |
257 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
258 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
259 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
260 |
261 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
262 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
263 |
264 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
265 |
266 |
267 | def compute_depth_errors(gt, pred):
268 | """Computation of error metrics between predicted and ground truth depths
269 | """
270 | thresh = torch.max((gt / pred), (pred / gt))
271 | a1 = (thresh < 1.25 ).float().mean()
272 | a2 = (thresh < 1.25 ** 2).float().mean()
273 | a3 = (thresh < 1.25 ** 3).float().mean()
274 |
275 | rmse = (gt - pred) ** 2
276 | rmse = torch.sqrt(rmse.mean())
277 |
278 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
279 | rmse_log = torch.sqrt(rmse_log.mean())
280 |
281 | abs_rel = torch.mean(torch.abs(gt - pred) / gt)
282 |
283 | sq_rel = torch.mean((gt - pred) ** 2 / gt)
284 |
285 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
286 |
--------------------------------------------------------------------------------
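
Two of the helpers above turn raw network outputs into geometry: `disp_to_depth` rescales the decoder's sigmoid disparity into a depth map, and `transformation_from_parameters` converts an (axis-angle, translation) pair into a 4x4 camera transform. A small usage sketch, reusing the depth range from `depth/models/opt.json`:

```python
import torch
from depth.layers import disp_to_depth, transformation_from_parameters

# Sigmoid disparity (as produced by DepthDecoder) -> depth bounded by [min_depth, max_depth].
disp = torch.rand(1, 1, 256, 256)
scaled_disp, depth = disp_to_depth(disp, min_depth=0.1, max_depth=100.0)

# (axis-angle, translation) -> 4x4 rigid transform; rot_from_axisangle expects a Bx1x3 input.
axisangle = torch.tensor([[[0.0, 0.1, 0.0]]])      # small rotation about the y axis
translation = torch.tensor([[[0.0, 0.0, 0.05]]])   # small forward translation
T = transformation_from_parameters(axisangle, translation)
print(depth.min().item(), depth.max().item(), T.shape)   # depth in (0.1, 100), T is (1, 4, 4)
```
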
/depth/models/opt.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_path": "/data/fhongac/workspace/src/talkhead/kitti_data",
3 | "log_dir": "tmp",
4 | "model_name": "taking_head_10w",
5 | "split": "eigen_zhou",
6 | "num_layers": 18,
7 | "dataset": "celeb",
8 | "png": false,
9 | "height": 224,
10 | "width": 224,
11 | "disparity_smoothness": 0.001,
12 | "scales": [
13 | 0,
14 | 1,
15 | 2,
16 | 3
17 | ],
18 | "sample_num": 100000,
19 | "min_depth": 0.1,
20 | "max_depth": 100.0,
21 | "use_stereo": false,
22 | "frame_ids": [
23 | 0,
24 | -1,
25 | 1
26 | ],
27 | "batch_size": 64,
28 | "learning_rate": 1e-05,
29 | "num_epochs": 20,
30 | "scheduler_step_size": 15,
31 | "v1_multiscale": false,
32 | "avg_reprojection": false,
33 | "disable_automasking": false,
34 | "predictive_mask": false,
35 | "no_ssim": false,
36 | "weights_init": "pretrained",
37 | "pose_model_input": "pairs",
38 | "pose_model_type": "separate_resnet",
39 | "no_cuda": false,
40 | "num_workers": 12,
41 | "load_weights_folder": null,
42 | "models_to_load": [
43 | "encoder",
44 | "depth",
45 | "pose_encoder",
46 | "pose"
47 | ],
48 | "log_frequency": 250,
49 | "save_frequency": 1,
50 | "eval_stereo": false,
51 | "eval_mono": false,
52 | "disable_median_scaling": false,
53 | "pred_depth_scale_factor": 1,
54 | "ext_disp_to_eval": null,
55 | "eval_split": "eigen",
56 | "save_pred_disps": false,
57 | "no_eval": false,
58 | "eval_eigen_to_benchmark": false,
59 | "eval_out_dir": null,
60 | "post_process": false
61 | }
--------------------------------------------------------------------------------
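
`opt.json` records the options the facial depth network was trained with (Monodepth2-style settings: 224x224 inputs, four output scales, a 0.1–100 depth range). A small sketch, assuming the file is read from its in-repo path, showing the stored range being reused with `disp_to_depth`:

```python
import json
import torch
from depth.layers import disp_to_depth

with open('depth/models/opt.json') as f:
    opts = json.load(f)

print(opts['height'], opts['width'], opts['scales'])   # 224 224 [0, 1, 2, 3]

# Rescale a (dummy) sigmoid disparity with the same bounds used at training time.
disp = torch.rand(1, 1, opts['height'], opts['width'])
scaled_disp, depth = disp_to_depth(disp, opts['min_depth'], opts['max_depth'])
print(depth.min().item() >= opts['min_depth'])         # True
```
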
/depth/pose_cnn.py:
--------------------------------------------------------------------------------
1 | # Copyright Niantic 2019. Patent Pending. All rights reserved.
2 | #
3 | # This software is licensed under the terms of the Monodepth2 licence
4 | # which allows for non-commercial use only, the full terms of which are made
5 | # available in the LICENSE file.
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class PoseCNN(nn.Module):
14 | def __init__(self, num_input_frames):
15 | super(PoseCNN, self).__init__()
16 |
17 | self.num_input_frames = num_input_frames
18 |
19 | self.convs = {}
20 | self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
21 | self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
22 | self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
23 | self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
24 | self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
25 | self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
26 | self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)
27 |
28 | self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)
29 |
30 | self.num_convs = len(self.convs)
31 |
32 | self.relu = nn.ReLU(True)
33 |
34 | self.net = nn.ModuleList(list(self.convs.values()))
35 |
36 | def forward(self, out):
37 |
38 | for i in range(self.num_convs):
39 | out = self.convs[i](out)
40 | out = self.relu(out)
41 |
42 | out = self.pose_conv(out)
43 | out = out.mean(3).mean(2)
44 |
45 | out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)
46 |
47 | axisangle = out[..., :3]
48 | translation = out[..., 3:]
49 |
50 | return axisangle, translation
51 |
--------------------------------------------------------------------------------
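
`PoseCNN` takes the input frames stacked along the channel axis and regresses one 6-DoF motion (3 axis-angle + 3 translation values) per frame pair. A minimal sketch of a forward pass with two frames:

```python
import torch
from depth import PoseCNN

pose_net = PoseCNN(num_input_frames=2)
frames = torch.rand(1, 6, 224, 224)          # two RGB frames concatenated channel-wise
axisangle, translation = pose_net(frames)
print(axisangle.shape, translation.shape)    # both torch.Size([1, 1, 1, 3])
```
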
/depth/pose_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright Niantic 2019. Patent Pending. All rights reserved.
2 | #
3 | # This software is licensed under the terms of the Monodepth2 licence
4 | # which allows for non-commercial use only, the full terms of which are made
5 | # available in the LICENSE file.
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import torch
10 | import torch.nn as nn
11 | from collections import OrderedDict
12 | import pdb
13 | import torch.nn.functional as F
14 | # from options import MonodepthOptions
15 | # options = MonodepthOptions()
16 | # opts = options.parse()
17 | class PoseDecoder(nn.Module):
18 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
19 | super(PoseDecoder, self).__init__()
20 | self.num_ch_enc = num_ch_enc
21 | self.num_input_features = num_input_features
22 |
23 | if num_frames_to_predict_for is None:
24 | num_frames_to_predict_for = num_input_features - 1
25 | self.num_frames_to_predict_for = num_frames_to_predict_for
26 |
27 | self.convs = OrderedDict()
28 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
29 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
30 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
31 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
32 | self.convs[("intrinsics", 'focal')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1)
33 | self.convs[("intrinsics", 'offset')] = nn.Conv2d(256, 2, kernel_size = 3,stride = 1,padding = 1)
34 |
35 | self.relu = nn.ReLU()
36 | self.net = nn.ModuleList(list(self.convs.values()))
37 |
38 | def forward(self, input_features):
39 | last_features = [f[-1] for f in input_features]
40 |
41 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
42 | cat_features = torch.cat(cat_features, 1)
43 |
44 | feat = cat_features
45 | for i in range(2):
46 | feat = self.convs[("pose", i)](feat)
47 | feat = self.relu(feat)
48 | out = self.convs[("pose", 2)](feat)
49 |
50 | out = out.mean(3).mean(2)
51 | out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
52 |
53 | axisangle = out[..., :3]
54 | translation = out[..., 3:]
55 |
56 | #add_intrinsics_head
57 | scales = torch.tensor([256,256]).cuda()
58 | focals = F.softplus(self.convs[("intrinsics", 'focal')](feat)).mean(3).mean(2)*scales
59 | offset = (F.softplus(self.convs[("intrinsics", 'offset')](feat)).mean(3).mean(2)+0.5)*scales
60 | #focals = F.softplus(self.convs[("intrinsics",'focal')](feat).mean(3).mean(2))
61 | #offset = F.softplus(self.convs[("intrinsics",'offset')](feat).mean(3).mean(2))
62 | eyes = torch.eye(2).cuda()
63 | b,xy = focals.shape
64 | focals = focals.unsqueeze(-1).expand(b,xy,xy)
65 | eyes = eyes.unsqueeze(0).expand(b,xy,xy)
66 | intrin = focals*eyes
67 | offset = offset.view(b,2,1).contiguous()
68 | intrin = torch.cat([intrin,offset],-1)
69 | pad = torch.tensor([0.0,0.0,1.0]).view(1,1,3).expand(b,1,3).cuda()
70 | intrinsics = torch.cat([intrin,pad],1)
71 | return axisangle, translation,intrinsics
72 |
--------------------------------------------------------------------------------
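
`PoseDecoder` regresses relative poses from encoder features and, in this modified version, also predicts camera intrinsics through the extra softplus heads. Because the intrinsics head hard-codes `.cuda()` and a 256x256 image size, the sketch below assumes a CUDA device is available:

```python
import torch
from depth import ResnetEncoder, PoseDecoder

encoder = ResnetEncoder(18, False).cuda()
pose_decoder = PoseDecoder(encoder.num_ch_enc, num_input_features=1,
                           num_frames_to_predict_for=2).cuda()

with torch.no_grad():
    # The decoder expects a list of encoder outputs (one entry per input feature stream).
    features = [encoder(torch.rand(1, 3, 256, 256).cuda())]
    axisangle, translation, intrinsics = pose_decoder(features)

print(axisangle.shape, translation.shape, intrinsics.shape)   # (1, 2, 1, 3), (1, 2, 1, 3), (1, 3, 3)
```
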
/depth/resnet_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright Niantic 2019. Patent Pending. All rights reserved.
2 | #
3 | # This software is licensed under the terms of the Monodepth2 licence
4 | # which allows for non-commercial use only, the full terms of which are made
5 | # available in the LICENSE file.
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torchvision.models as models
14 | import torch.utils.model_zoo as model_zoo
15 |
16 |
17 | class ResNetMultiImageInput(models.ResNet):
18 | """Constructs a resnet model with varying number of input images.
19 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
20 | """
21 | def __init__(self, block, layers, num_classes=1000, num_input_images=1):
22 | super(ResNetMultiImageInput, self).__init__(block, layers)
23 | self.inplanes = 64
24 | self.conv1 = nn.Conv2d(
25 | num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
26 | self.bn1 = nn.BatchNorm2d(64)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
29 | self.layer1 = self._make_layer(block, 64, layers[0])
30 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
31 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
32 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
33 |
34 | for m in self.modules():
35 | if isinstance(m, nn.Conv2d):
36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
37 | elif isinstance(m, nn.BatchNorm2d):
38 | nn.init.constant_(m.weight, 1)
39 | nn.init.constant_(m.bias, 0)
40 |
41 |
42 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
43 | """Constructs a ResNet model.
44 | Args:
45 | num_layers (int): Number of resnet layers. Must be 18 or 50
46 | pretrained (bool): If True, returns a model pre-trained on ImageNet
47 | num_input_images (int): Number of frames stacked as input
48 | """
49 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
50 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
51 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
52 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
53 |
54 | if pretrained:
55 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
56 | loaded['conv1.weight'] = torch.cat(
57 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
58 | model.load_state_dict(loaded)
59 | return model
60 |
61 |
62 | class ResnetEncoder(nn.Module):
63 | """Pytorch module for a resnet encoder
64 | """
65 | def __init__(self, num_layers, pretrained, num_input_images=1):
66 | super(ResnetEncoder, self).__init__()
67 |
68 | self.num_ch_enc = np.array([64, 64, 128, 256, 512])
69 |
70 | resnets = {18: models.resnet18,
71 | 34: models.resnet34,
72 | 50: models.resnet50,
73 | 101: models.resnet101,
74 | 152: models.resnet152}
75 |
76 | if num_layers not in resnets:
77 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
78 |
79 | if num_input_images > 1:
80 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
81 | else:
82 | self.encoder = resnets[num_layers](pretrained)
83 |
84 | if num_layers > 34:
85 | self.num_ch_enc[1:] *= 4
86 |
87 | def forward(self, input_image):
88 | self.features = []
89 | x = (input_image - 0.45) / 0.225
90 | x = self.encoder.conv1(x)
91 | x = self.encoder.bn1(x)
92 | self.features.append(self.encoder.relu(x))
93 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
94 | self.features.append(self.encoder.layer2(self.features[-1]))
95 | self.features.append(self.encoder.layer3(self.features[-1]))
96 | self.features.append(self.encoder.layer4(self.features[-1]))
97 |
98 | return self.features
99 |
--------------------------------------------------------------------------------
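
`ResnetEncoder` wraps a torchvision ResNet and returns the five intermediate feature maps; with `num_input_images > 1` it widens the first convolution so stacked frames can be fed to a pose encoder. A small sketch printing the feature channel counts:

```python
import torch
from depth import ResnetEncoder

# Single-image encoder, as used for the facial depth network.
encoder = ResnetEncoder(18, pretrained=False)
feats = encoder(torch.rand(1, 3, 256, 256))
print([f.shape[1] for f in feats])           # [64, 64, 128, 256, 512], matching encoder.num_ch_enc

# Two stacked frames (6 input channels), e.g. for a pose encoder.
pair_encoder = ResnetEncoder(18, pretrained=False, num_input_images=2)
pair_feats = pair_encoder(torch.rand(1, 6, 256, 256))
print(pair_feats[-1].shape)                  # torch.Size([1, 512, 8, 8])
```
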
/evaluation_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset
8 | You need to implement the following functions:
9 |     -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | from torch.utils.data import Dataset
15 | import torchvision.transforms as T
16 | from PIL import Image
17 | from PIL import ImageFile
18 | from skimage import io, img_as_float32
19 | import numpy as np
20 | ImageFile.LOAD_TRUNCATED_IMAGES = True
21 | # from data.image_folder import make_dataset
22 | # from PIL import Image
23 | import os
24 | import torch
25 | import pdb
26 | import pandas as pd
27 |
28 | class EvaluationDataset(Dataset):
29 |     """Dataset of (source, driving) frame pairs for evaluation, built from a CSV pairs list."""
30 |
31 | def __init__(self, dataroot, pairs_list=None):
32 | """Initialize this dataset class.
33 |
34 | Parameters:
35 |             dataroot (str) -- dataset root directory (must contain a 'test' split); pairs_list (str) -- path to a CSV file with 'source' and 'driving' columns
36 |
37 | A few things can be done here.
38 | - save the options (have been done in BaseDataset)
39 | - get image paths and meta information of the dataset.
40 | - define the image transformation.
41 | """
42 | # save the option and dataset root
43 | # get the image paths of your dataset;
44 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
45 |         # define the default transform function; you can also define your own custom transform function
46 | self.dataroot = dataroot
47 | # self.videos = self.videos[5000]
48 | self.frame_shape = (3,256,256)
49 | test_videos = os.listdir(os.path.join(self.dataroot,'test'))
50 | self.videos = test_videos
51 | pairs = pd.read_csv(pairs_list)
52 | self.source = pairs['source'].tolist()
53 | self.driving = pairs['driving'].tolist()
54 | # self.pose_anchors = pairs['best_frame'].tolist()
55 |
56 | self.transforms = T.Compose([T.ToTensor(),
57 | T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
58 | def __getitem__(self, idx):
59 | """Return a data point and its metadata information.
60 |
61 | Parameters:
62 | index -- a random integer for data indexing
63 |
64 | Returns:
65 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
66 |
67 | Step 1: get a random image path: e.g., path = self.image_paths[index]
68 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
69 |         Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
70 | Step 4: return a data point as a dictionary.
71 | """
72 | path_source = self.source[idx]
73 | path_driving = self.driving[idx]
74 | # path_anchor = self.pose_anchors[idx]
75 | anchor = ''
76 | source = img_as_float32(io.imread(path_source))
77 | source = np.array(source, dtype='float32')
78 | source = torch.tensor(source.transpose((2, 0, 1)))
79 |
80 | driving = img_as_float32(io.imread(path_driving))
81 | driving = np.array(driving, dtype='float32')
82 | driving = torch.tensor(driving.transpose((2, 0, 1)))
83 |
84 | # anchor = img_as_float32(io.imread(path_anchor))
85 | # anchor = np.array(anchor, dtype='float32')
86 | # anchor = torch.tensor(anchor.transpose((2, 0, 1)))
87 |
88 | # source = Image.open(path_source).convert('RGB')
89 | # driving = Image.open(path_driving).convert('RGB')
90 | # source = T.ToTensor()(source)
91 | # driving = T.ToTensor()(driving)
92 | return {'source': source, 'driving': driving, 'path_source': path_source,'path_driving':path_driving, 'anchor': anchor}
93 |
94 | def __len__(self):
95 | """Return the total number of images."""
96 | return len(self.source)
97 |
98 |
99 |
--------------------------------------------------------------------------------
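
`EvaluationDataset` reads (source, driving) frame paths from one of the CSV pair lists under `data/`. A minimal loading sketch; the `dataroot` value is a placeholder and the paths inside the CSV must resolve on the local machine:

```python
from torch.utils.data import DataLoader
from evaluation_dataset import EvaluationDataset

dataset = EvaluationDataset(dataroot='path/to/vox1',             # placeholder: VoxCeleb1 root with a 'test' split
                            pairs_list='data/vox_evaluation.csv')
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)

batch = next(iter(loader))
print(batch['source'].shape, batch['driving'].shape)             # (4, 3, H, W) for pre-cropped frames
print(batch['path_source'][0], batch['path_driving'][0])
```
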
/face-alignment/.gitattributes:
--------------------------------------------------------------------------------
1 | *.py linguist-language=python
2 | *.ipynb linguist-documentation
3 |
--------------------------------------------------------------------------------
/face-alignment/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/face-alignment/Dockerfile:
--------------------------------------------------------------------------------
1 | # Based on an older version of https://github.com/pytorch/pytorch/blob/master/Dockerfile
2 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
3 |
4 | RUN apt-get update && apt-get install -y --no-install-recommends \
5 | build-essential \
6 | cmake \
7 | git \
8 | curl \
9 | vim \
10 | ca-certificates \
11 | libboost-all-dev \
12 | python-qt4 \
13 | libjpeg-dev \
14 | libpng-dev &&\
15 | rm -rf /var/lib/apt/lists/*
16 |
17 | RUN curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
18 | chmod +x ~/miniconda.sh && \
19 | ~/miniconda.sh -b -p /opt/conda && \
20 | rm ~/miniconda.sh
21 |
22 | ENV PATH /opt/conda/bin:$PATH
23 |
24 | RUN conda config --set always_yes yes --set changeps1 no && conda update -q conda
25 | RUN conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
26 |
27 | # Install face-alignment package
28 | WORKDIR /workspace
29 | RUN chmod -R a+w /workspace
30 | RUN git clone https://github.com/1adrianb/face-alignment
31 | WORKDIR /workspace/face-alignment
32 | RUN pip install -r requirements.txt
33 | RUN python setup.py install
34 |
--------------------------------------------------------------------------------
/face-alignment/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2017, Adrian Bulat
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/face-alignment/README.md:
--------------------------------------------------------------------------------
1 | # Face Recognition
2 |
3 | Detect facial landmarks from Python using the world's most accurate face alignment network, capable of detecting points in both 2D and 3D coordinates.
4 |
5 | Built using [FAN](https://www.adrianbulat.com)'s state-of-the-art deep-learning-based face alignment method.
6 |
7 | 
8 |
9 | **Note:** The lua version is available [here](https://github.com/1adrianb/2D-and-3D-face-alignment).
10 |
11 | For numerical evaluations it is highly recommended to use the Lua version, which uses identical models to the ones evaluated in the paper. More models will be added soon.
12 |
13 | [License](https://opensource.org/licenses/BSD-3-Clause) [Test status](https://github.com/1adrianb/face-alignment/actions?query=workflow%3A%22Test+Face+alignmnet%22) [Anaconda](https://anaconda.org/1adrianb/face_alignment)
14 | [PyPI](https://pypi.org/project/face-alignment/)
15 |
16 | ## Features
17 |
18 | #### Detect 2D facial landmarks in pictures
19 |
20 |
21 |
22 |
23 |
24 | ```python
25 | import face_alignment
26 | from skimage import io
27 |
28 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
29 |
30 | input = io.imread('../test/assets/aflw-test.jpg')
31 | preds = fa.get_landmarks(input)
32 | ```
33 |
34 | #### Detect 3D facial landmarks in pictures
35 |
36 |
37 |
38 |
39 |
40 | ```python
41 | import face_alignment
42 | from skimage import io
43 |
44 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False)
45 |
46 | input = io.imread('../test/assets/aflw-test.jpg')
47 | preds = fa.get_landmarks(input)
48 | ```
49 |
50 | #### Process an entire directory in one go
51 |
52 | ```python
53 | import face_alignment
54 | from skimage import io
55 |
56 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
57 |
58 | preds = fa.get_landmarks_from_directory('../test/assets/')
59 | ```
60 |
61 | #### Detect the landmarks using a specific face detector.
62 |
63 | By default the package will use the SFD face detector. However, users can alternatively use dlib, BlazeFace, or pre-existing ground-truth bounding boxes.
64 |
65 | ```python
66 | import face_alignment
67 |
68 | # sfd for SFD, dlib for Dlib and folder for existing bounding boxes.
69 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, face_detector='sfd')
70 | ```
71 |
72 | #### Running on CPU/GPU
73 | In order to specify the device (GPU or CPU) on which the code will run one can explicitly pass the device flag:
74 |
75 | ```python
76 | import face_alignment
77 |
78 | # cuda for CUDA
79 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cpu')
80 | ```
81 |
82 | Please also see the ``examples`` folder
83 |
84 | ## Installation
85 |
86 | ### Requirements
87 |
88 | * Python 3.5+ (it may work with other versions too). Last version with support for python 2.7 was v1.1.1
89 | * Linux, Windows or macOS
90 | * pytorch (>=1.5)
91 |
92 | While not required, for optimal performance (especially for the detector) it is **highly** recommended to run the code using a CUDA-enabled GPU.
93 |
94 | ### Binaries
95 |
96 | The easiest way to install it is using either pip or conda:
97 |
98 | | **Using pip** | **Using conda** |
99 | |------------------------------|--------------------------------------------|
100 | | `pip install face-alignment` | `conda install -c 1adrianb face_alignment` |
101 | | | |
102 |
103 | Alternatively, below you can find instructions to build it from source.
104 |
105 | ### From source
106 |
107 | Install pytorch and pytorch dependencies. Please check the [pytorch readme](https://github.com/pytorch/pytorch) for this.
108 |
109 | #### Get the Face Alignment source code
110 | ```bash
111 | git clone https://github.com/1adrianb/face-alignment
112 | ```
113 | #### Install the Face Alignment lib
114 | ```bash
115 | pip install -r requirements.txt
116 | python setup.py install
117 | ```
118 |
119 | ### Docker image
120 |
121 | A Dockerfile is provided to build images with CUDA and cuDNN support. For more instructions about running and building a Docker image, check the original Docker documentation.
122 | ```
123 | docker build -t face-alignment .
124 | ```
125 |
126 | ## How does it work?
127 |
128 | While the work is presented here as a black box, if you want to know more about the internals of the method please check the original paper, either on arXiv or my [webpage](https://www.adrianbulat.com).
129 |
130 | ## Contributions
131 |
132 | All contributions are welcome. If you encounter any issue (including examples of images where it fails) feel free to open an issue. If you plan to add a new feature, please open an issue to discuss it prior to making a pull request.
133 |
134 | ## Citation
135 |
136 | ```
137 | @inproceedings{bulat2017far,
138 | title={How far are we from solving the 2D \& 3D Face Alignment problem? (and a dataset of 230,000 3D facial landmarks)},
139 | author={Bulat, Adrian and Tzimiropoulos, Georgios},
140 | booktitle={International Conference on Computer Vision},
141 | year={2017}
142 | }
143 | ```
144 |
145 | For citing dlib, pytorch or any other packages used here please check the original page of their respective authors.
146 |
147 | ## Acknowledgements
148 |
149 | * To the [pytorch](http://pytorch.org/) team for providing such an awesome deep learning framework
150 | * To [my supervisor](http://www.cs.nott.ac.uk/~pszyt/) for his patience and suggestions.
151 | * To all other python developers that made available the rest of the packages used in this repository.
--------------------------------------------------------------------------------
/face-alignment/conda/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set version = "1.3.4" %}
2 |
3 | package:
4 | name: face_alignment
5 | version: {{ version }}
6 |
7 | source:
8 | path: ..
9 |
10 | build:
11 | number: 1
12 | noarch: python
13 | script: python setup.py install --single-version-externally-managed --record=record.txt
14 |
15 | requirements:
16 | build:
17 | - setuptools
18 | - python
19 | run:
20 | - python
21 | - pytorch
22 | - numpy
23 | - scikit-image
24 | - scipy
25 | - opencv
26 | - tqdm
27 | - numba
28 |
29 | about:
30 | home: https://github.com/1adrianb/face-alignment
31 | license: BSD
32 | license_file: LICENSE
33 |   summary: A 2D and 3D face alignment library in Python
34 |
35 | extra:
36 | recipe-maintainers:
37 | - 1adrianb
38 |
--------------------------------------------------------------------------------
/face-alignment/docs/images/2dlandmarks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/face-alignment/docs/images/2dlandmarks.png
--------------------------------------------------------------------------------
/face-alignment/docs/images/face-alignment-adrian.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harlanhong/CVPR2022-DaGAN/052ad359a019492893c098e8717dc993cdf846ed/face-alignment/docs/images/face-alignment-adrian.gif
--------------------------------------------------------------------------------
/face-alignment/examples/detect_landmarks_in_image.py:
--------------------------------------------------------------------------------
1 | import face_alignment
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D
4 | from skimage import io
5 | import collections
6 |
7 |
8 | # Optionally set detector and some additional detector parameters
9 | face_detector = 'sfd'
10 | face_detector_kwargs = {
11 | "filter_threshold" : 0.8
12 | }
13 |
14 | # Run the 3D face alignment on a test image, without CUDA.
15 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu', flip_input=True,
16 | face_detector=face_detector, face_detector_kwargs=face_detector_kwargs)
17 |
18 | try:
19 | input_img = io.imread('../test/assets/aflw-test.jpg')
20 | except FileNotFoundError:
21 | input_img = io.imread('test/assets/aflw-test.jpg')
22 |
23 | preds = fa.get_landmarks(input_img)[-1]
24 |
25 | # 2D-Plot
26 | plot_style = dict(marker='o',
27 | markersize=4,
28 | linestyle='-',
29 | lw=2)
30 |
31 | pred_type = collections.namedtuple('prediction_type', ['slice', 'color'])
32 | pred_types = {'face': pred_type(slice(0, 17), (0.682, 0.780, 0.909, 0.5)),
33 | 'eyebrow1': pred_type(slice(17, 22), (1.0, 0.498, 0.055, 0.4)),
34 | 'eyebrow2': pred_type(slice(22, 27), (1.0, 0.498, 0.055, 0.4)),
35 | 'nose': pred_type(slice(27, 31), (0.345, 0.239, 0.443, 0.4)),
36 | 'nostril': pred_type(slice(31, 36), (0.345, 0.239, 0.443, 0.4)),
37 | 'eye1': pred_type(slice(36, 42), (0.596, 0.875, 0.541, 0.3)),
38 | 'eye2': pred_type(slice(42, 48), (0.596, 0.875, 0.541, 0.3)),
39 | 'lips': pred_type(slice(48, 60), (0.596, 0.875, 0.541, 0.3)),
40 | 'teeth': pred_type(slice(60, 68), (0.596, 0.875, 0.541, 0.4))
41 | }
42 |
43 | fig = plt.figure(figsize=plt.figaspect(.5))
44 | ax = fig.add_subplot(1, 2, 1)
45 | ax.imshow(input_img)
46 |
47 | for pred_type in pred_types.values():
48 | ax.plot(preds[pred_type.slice, 0],
49 | preds[pred_type.slice, 1],
50 | color=pred_type.color, **plot_style)
51 |
52 | ax.axis('off')
53 |
54 | # 3D-Plot
55 | ax = fig.add_subplot(1, 2, 2, projection='3d')
56 | surf = ax.scatter(preds[:, 0] * 1.2,
57 | preds[:, 1],
58 | preds[:, 2],
59 | c='cyan',
60 | alpha=1.0,
61 | edgecolor='b')
62 |
63 | for pred_type in pred_types.values():
64 | ax.plot3D(preds[pred_type.slice, 0] * 1.2,
65 | preds[pred_type.slice, 1],
66 | preds[pred_type.slice, 2], color='blue')
67 |
68 | ax.view_init(elev=90., azim=90.)
69 | ax.set_xlim(ax.get_xlim()[::-1])
70 | plt.show()
71 |
--------------------------------------------------------------------------------
/face-alignment/face_alignment/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | __author__ = """Adrian Bulat"""
4 | __email__ = 'adrian@adrianbulat.com'
5 | __version__ = '1.3.4'
6 |
7 | from .api import FaceAlignment, LandmarksType, NetworkSize
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import FaceDetector
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/blazeface/__init__.py:
--------------------------------------------------------------------------------
1 | from .blazeface_detector import BlazeFaceDetector as FaceDetector
2 |
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/blazeface/blazeface_detector.py:
--------------------------------------------------------------------------------
1 | from torch.utils.model_zoo import load_url
2 |
3 | from ..core import FaceDetector
4 | from ...utils import load_file_from_url
5 |
6 | from .net_blazeface import BlazeFace
7 | from .detect import *
8 |
9 | models_urls = {
10 | 'blazeface_weights': 'https://github.com/hollance/BlazeFace-PyTorch/blob/master/blazeface.pth?raw=true',
11 | 'blazeface_anchors': 'https://github.com/hollance/BlazeFace-PyTorch/blob/master/anchors.npy?raw=true'
12 | }
13 |
14 |
15 | class BlazeFaceDetector(FaceDetector):
16 | def __init__(self, device, path_to_detector=None, path_to_anchor=None, verbose=False,
17 | min_score_thresh=0.5, min_suppression_threshold=0.3):
18 | super(BlazeFaceDetector, self).__init__(device, verbose)
19 |
20 | # Initialise the face detector
21 | if path_to_detector is None:
22 | model_weights = load_url(models_urls['blazeface_weights'])
23 | model_anchors = np.load(load_file_from_url(models_urls['blazeface_anchors']))
24 | else:
25 | model_weights = torch.load(path_to_detector)
26 | model_anchors = np.load(path_to_anchor)
27 |
28 | self.face_detector = BlazeFace()
29 | self.face_detector.load_state_dict(model_weights)
30 | self.face_detector.load_anchors_from_npy(model_anchors, device)
31 |
32 | # Optionally change the thresholds:
33 | self.face_detector.min_score_thresh = min_score_thresh
34 | self.face_detector.min_suppression_threshold = min_suppression_threshold
35 |
36 | self.face_detector.to(device)
37 | self.face_detector.eval()
38 |
39 | def detect_from_image(self, tensor_or_path):
40 | image = self.tensor_or_path_to_ndarray(tensor_or_path)
41 |
42 | bboxlist = detect(self.face_detector, image, device=self.device)[0]
43 |
44 | return bboxlist
45 |
46 | def detect_from_batch(self, tensor):
47 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device)
48 | return bboxlists
49 |
50 | @property
51 | def reference_scale(self):
52 | return 195
53 |
54 | @property
55 | def reference_x_shift(self):
56 | return 0
57 |
58 | @property
59 | def reference_y_shift(self):
60 | return 0
61 |
--------------------------------------------------------------------------------
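
The BlazeFace detector plugs into `FaceAlignment` through the same `face_detector` string mechanism the README shows for `'sfd'`; the `'blazeface'` key and the pass-through of `face_detector_kwargs` into this constructor are inferred from the package layout and the example script, so treat them as assumptions:

```python
import face_alignment
from skimage import io

# 'blazeface' selects face_alignment.detection.blazeface; weights and anchors are
# downloaded automatically when no local paths are given (see the constructor above).
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
                                  face_detector='blazeface',
                                  face_detector_kwargs={'min_score_thresh': 0.6},
                                  device='cpu')

input_img = io.imread('test/assets/aflw-test.jpg')
preds = fa.get_landmarks(input_img)
```
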
/face-alignment/face_alignment/detection/blazeface/detect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | from .utils import *
8 |
9 |
10 | def detect(net, img, device):
11 | H, W, C = img.shape
12 | orig_size = min(H, W)
13 | img, (xshift, yshift) = resize_and_crop_image(img, 128)
14 | preds = net.predict_on_image(img)
15 |
16 | if 0 == len(preds):
17 | return [[]]
18 |
19 | shift = np.array([xshift, yshift] * 2)
20 | scores = preds[:, -1:]
21 |
22 | # TODO: ugly
23 | # reverses, x and y to adapt with face-alignment code
24 | locs = np.concatenate((preds[:, 1:2], preds[:, 0:1], preds[:, 3:4], preds[:, 2:3]), axis=1)
25 | return [np.concatenate((locs * orig_size + shift, scores), axis=1)]
26 |
27 |
28 | def batch_detect(net, img_batch, device):
29 | """
30 | Inputs:
31 | - img_batch: a numpy array or tensor of shape (Batch size, Channels, Height, Width)
32 | Outputs:
33 | - list of 2-dim numpy arrays with shape (faces_on_this_image, 5): x1, y1, x2, y2, confidence
34 | (x1, y1) - top left corner, (x2, y2) - bottom right corner
35 | """
36 | B, C, H, W = img_batch.shape
37 | orig_size = min(H, W)
38 |
39 | if isinstance(img_batch, torch.Tensor):
40 | img_batch = img_batch.cpu().numpy()
41 |
42 | img_batch = img_batch.transpose((0, 2, 3, 1))
43 |
44 | imgs, (xshift, yshift) = resize_and_crop_batch(img_batch, 128)
45 | preds = net.predict_on_batch(imgs)
46 | bboxlists = []
47 | for pred in preds:
48 | shift = np.array([xshift, yshift] * 2)
49 | scores = pred[:, -1:]
50 | locs = np.concatenate((pred[:, 1:2], pred[:, 0:1], pred[:, 3:4], pred[:, 2:3]), axis=1)
51 | bboxlists.append(np.concatenate((locs * orig_size + shift, scores), axis=1))
52 |
53 | return bboxlists
54 |
55 |
56 | def flip_detect(net, img, device):
57 | img = cv2.flip(img, 1)
58 | b = detect(net, img, device)
59 |
60 | bboxlist = np.zeros(b.shape)
61 | bboxlist[:, 0] = img.shape[1] - b[:, 2]
62 | bboxlist[:, 1] = b[:, 1]
63 | bboxlist[:, 2] = img.shape[1] - b[:, 0]
64 | bboxlist[:, 3] = b[:, 3]
65 | bboxlist[:, 4] = b[:, 4]
66 | return bboxlist
67 |
68 |
69 | def pts_to_bb(pts):
70 | min_x, min_y = np.min(pts, axis=0)
71 | max_x, max_y = np.max(pts, axis=0)
72 | return np.array([min_x, min_y, max_x, max_y])
73 |
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/blazeface/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
6 | # initialize the dimensions of the image to be resized and
7 | # grab the image size
8 | dim = None
9 | (h, w) = image.shape[:2]
10 |
11 | # if both the width and height are None, then return the
12 | # original image
13 | if width is None and height is None:
14 | return image
15 |
16 | # check to see if the width is None
17 | if width is None:
18 | # calculate the ratio of the height and construct the
19 | # dimensions
20 | r = height / float(h)
21 | dim = (int(w * r), height)
22 |
23 | # otherwise, the height is None
24 | else:
25 | # calculate the ratio of the width and construct the
26 | # dimensions
27 | r = width / float(w)
28 | dim = (width, int(h * r))
29 |
30 | # resize the image
31 | resized = cv2.resize(image, dim, interpolation=inter)
32 |
33 | # return the resized image
34 | return resized
35 |
36 |
37 | def resize_and_crop_image(image, dim):
38 | if image.shape[0] > image.shape[1]:
39 | img = image_resize(image, width=dim)
40 | yshift, xshift = (image.shape[0] - image.shape[1]) // 2, 0
41 | y_start = (img.shape[0] - img.shape[1]) // 2
42 | y_end = y_start + dim
43 | return img[y_start:y_end, :, :], (xshift, yshift)
44 | else:
45 | img = image_resize(image, height=dim)
46 | yshift, xshift = 0, (image.shape[1] - image.shape[0]) // 2
47 | x_start = (img.shape[1] - img.shape[0]) // 2
48 | x_end = x_start + dim
49 | return img[:, x_start:x_end, :], (xshift, yshift)
50 |
51 |
52 | def resize_and_crop_batch(frames, dim):
53 | """
54 | Center crop + resize to (dim x dim)
55 | inputs:
56 | - frames: list of images (numpy arrays)
57 | - dim: output dimension size
58 | """
59 | smframes = []
60 | xshift, yshift = 0, 0
61 | for i in range(len(frames)):
62 | smframe, (xshift, yshift) = resize_and_crop_image(frames[i], dim)
63 | smframes.append(smframe)
64 | smframes = np.stack(smframes)
65 | return smframes, (xshift, yshift)
66 |
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/core.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import glob
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | from skimage import io
7 |
8 |
9 | class FaceDetector(object):
10 | """An abstract class representing a face detector.
11 |
12 | Any other face detection implementation must subclass it. All subclasses
13 | must implement ``detect_from_image``, that return a list of detected
14 | bounding boxes. Optionally, for speed considerations detect from path is
15 | recommended.
16 | """
17 |
18 | def __init__(self, device, verbose):
19 | self.device = device
20 | self.verbose = verbose
21 |
22 | if verbose:
23 | if 'cpu' in device:
24 | logger = logging.getLogger(__name__)
25 |                 logger.warning("Detection running on CPU, this may be slow.")
26 |
27 | if 'cpu' not in device and 'cuda' not in device:
28 | if verbose:
29 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30 | raise ValueError
31 |
32 | def detect_from_image(self, tensor_or_path):
33 | """Detects faces in a given image.
34 |
35 | This function detects the faces present in a provided BGR(usually)
36 | image. The input can be either the image itself or the path to it.
37 |
38 | Arguments:
39 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40 | to an image or the image itself.
41 |
42 | Example::
43 |
44 | >>> path_to_image = 'data/image_01.jpg'
45 | ... detected_faces = detect_from_image(path_to_image)
46 | [A list of bounding boxes (x1, y1, x2, y2)]
47 | >>> image = cv2.imread(path_to_image)
48 | ... detected_faces = detect_from_image(image)
49 | [A list of bounding boxes (x1, y1, x2, y2)]
50 |
51 | """
52 | raise NotImplementedError
53 |
54 | def detect_from_batch(self, tensor):
55 | """Detects faces in a given image.
56 |
57 | This function detects the faces present in a provided BGR(usually)
58 | image. The input can be either the image itself or the path to it.
59 |
60 | Arguments:
61 | tensor {torch.tensor} -- image batch tensor.
62 |
63 | Example::
64 |
65 | >>> path_to_image = 'data/image_01.jpg'
66 | ... detected_faces = detect_from_image(path_to_image)
67 | [A list of bounding boxes (x1, y1, x2, y2)]
68 | >>> image = cv2.imread(path_to_image)
69 | ... detected_faces = detect_from_image(image)
70 | [A list of bounding boxes (x1, y1, x2, y2)]
71 |
72 | """
73 | raise NotImplementedError
74 |
75 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
76 | """Detects faces from all the images present in a given directory.
77 |
78 | Arguments:
79 | path {string} -- a string containing a path that points to the folder containing the images
80 |
81 | Keyword Arguments:
82 |             extensions {list} -- list of strings containing the extensions to be
83 |                 considered, in the format ``.extension_name`` (default: {['.jpg', '.png']})
84 |             recursive {bool} -- whether to scan the folder recursively (default: {False})
85 |             show_progress_bar {bool} -- whether to display a progress bar (default: {True})
86 |
87 |
88 | Example:
89 | >>> directory = 'data'
90 | ... detected_faces = detect_from_directory(directory)
91 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
92 |
93 | """
94 | if self.verbose:
95 | logger = logging.getLogger(__name__)
96 |
97 | if len(extensions) == 0:
98 | if self.verbose:
99 |                 logger.error("Expected at least one extension, but none was received.")
100 | raise ValueError
101 |
102 | if self.verbose:
103 | logger.info("Constructing the list of images.")
104 | additional_pattern = '/**/*' if recursive else '/*'
105 | files = []
106 | for extension in extensions:
107 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
108 |
109 | if self.verbose:
110 | logger.info("Finished searching for images. %s images found", len(files))
111 | logger.info("Preparing to run the detection.")
112 |
113 | predictions = {}
114 | for image_path in tqdm(files, disable=not show_progress_bar):
115 | if self.verbose:
116 | logger.info("Running the face detector on image: %s", image_path)
117 | predictions[image_path] = self.detect_from_image(image_path)
118 |
119 | if self.verbose:
120 | logger.info("The detector was successfully run on all %s images", len(files))
121 |
122 | return predictions
123 |
124 | @property
125 | def reference_scale(self):
126 | raise NotImplementedError
127 |
128 | @property
129 | def reference_x_shift(self):
130 | raise NotImplementedError
131 |
132 | @property
133 | def reference_y_shift(self):
134 | raise NotImplementedError
135 |
136 | @staticmethod
137 | def tensor_or_path_to_ndarray(tensor_or_path):
138 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
139 |
140 | Arguments:
141 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
142 | """
143 | if isinstance(tensor_or_path, str):
144 | return io.imread(tensor_or_path)
145 | elif torch.is_tensor(tensor_or_path):
146 | return tensor_or_path.cpu().numpy()
147 | elif isinstance(tensor_or_path, np.ndarray):
148 | return tensor_or_path
149 | else:
150 | raise TypeError
151 |
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/dlib/__init__.py:
--------------------------------------------------------------------------------
1 | from .dlib_detector import DlibDetector as FaceDetector
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/dlib/dlib_detector.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import cv2
3 | import dlib
4 |
5 | from ..core import FaceDetector
6 | from ...utils import load_file_from_url
7 |
8 |
9 | class DlibDetector(FaceDetector):
10 | def __init__(self, device, path_to_detector=None, verbose=False):
11 | super().__init__(device, verbose)
12 |
13 |         warnings.warn('Warning: this detector is deprecated. Please use a different one, e.g. S3FD.')
14 |
15 | # Initialise the face detector
16 | if 'cuda' in device:
17 | if path_to_detector is None:
18 | path_to_detector = load_file_from_url(
19 | "https://www.adrianbulat.com/downloads/dlib/mmod_human_face_detector.dat")
20 |
21 | self.face_detector = dlib.cnn_face_detection_model_v1(path_to_detector)
22 | else:
23 | self.face_detector = dlib.get_frontal_face_detector()
24 |
25 | def detect_from_image(self, tensor_or_path):
26 | image = self.tensor_or_path_to_ndarray(tensor_or_path)
27 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
28 |
29 | detected_faces = self.face_detector(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
30 |
31 | if 'cuda' not in self.device:
32 | detected_faces = [[d.left(), d.top(), d.right(), d.bottom()] for d in detected_faces]
33 | else:
34 | detected_faces = [[d.rect.left(), d.rect.top(), d.rect.right(), d.rect.bottom()] for d in detected_faces]
35 |
36 | return detected_faces
37 |
38 | @property
39 | def reference_scale(self):
40 | return 195
41 |
42 | @property
43 | def reference_x_shift(self):
44 | return 0
45 |
46 | @property
47 | def reference_y_shift(self):
48 | return 0
49 |
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/folder/__init__.py:
--------------------------------------------------------------------------------
1 | from .folder_detector import FolderDetector as FaceDetector
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/folder/folder_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 | from ..core import FaceDetector
6 |
7 |
8 | class FolderDetector(FaceDetector):
9 | '''This is a simple helper module that assumes the faces were already detected
10 | (either in a previous run or provided as ground truth).
11 |
12 | The class expects to find the bounding boxes in the same format used by
13 | the rest of the face detectors, namely ``list[(x1,y1,x2,y2),...]``.
14 | For each image the detector will search for a file with the same name and with one of the
15 | following extensions: .npy, .t7 or .pth
16 |
17 | '''
18 |
19 | def __init__(self, device, path_to_detector=None, verbose=False):
20 | super(FolderDetector, self).__init__(device, verbose)
21 |
22 | def detect_from_image(self, tensor_or_path):
23 | # Only strings supported
24 | if not isinstance(tensor_or_path, str):
25 | raise ValueError
26 |
27 | base_name = os.path.splitext(tensor_or_path)[0]
28 |
29 | if os.path.isfile(base_name + '.npy'):
30 | detected_faces = np.load(base_name + '.npy')
31 | elif os.path.isfile(base_name + '.t7'):
32 | detected_faces = torch.load(base_name + '.t7')
33 | elif os.path.isfile(base_name + '.pth'):
34 | detected_faces = torch.load(base_name + '.pth')
35 | else:
36 | raise FileNotFoundError
37 |
38 | if not isinstance(detected_faces, list):
39 | raise TypeError
40 |
41 | return detected_faces
42 |
43 | @property
44 | def reference_scale(self):
45 | return 195
46 |
47 | @property
48 | def reference_x_shift(self):
49 | return 0
50 |
51 | @property
52 | def reference_y_shift(self):
53 | return 0
54 |
--------------------------------------------------------------------------------
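A hedged sketch of the file layout FolderDetector expects: a box list stored next to the (hypothetical) image, with the same basename and a .npy/.t7/.pth extension. torch.save is used here so the stored object round-trips as a plain Python list, which is what detect_from_image checks for:

import os
import torch
from face_alignment.detection.folder import FaceDetector  # FolderDetector alias

os.makedirs('frames', exist_ok=True)                 # hypothetical frame folder
torch.save([(10, 20, 110, 140)], 'frames/0001.pth')  # list[(x1, y1, x2, y2)]

detector = FaceDetector(device='cpu')
print(detector.detect_from_image('frames/0001.jpg'))  # -> [(10, 20, 110, 140)]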
/face-alignment/face_alignment/detection/sfd/__init__.py:
--------------------------------------------------------------------------------
1 | from .sfd_detector import SFDDetector as FaceDetector
--------------------------------------------------------------------------------
/face-alignment/face_alignment/detection/sfd/bbox.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 |
5 | def nms(dets, thresh):
6 | if 0 == len(dets):
7 | return []
8 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
9 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
10 | order = scores.argsort()[::-1]
11 |
12 | keep = []
13 | while order.size > 0:
14 | i = order[0]
15 | keep.append(i)
16 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
17 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
18 |
19 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
20 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
21 |
22 | inds = np.where(ovr <= thresh)[0]
23 | order = order[inds + 1]
24 |
25 | return keep
26 |
27 |
28 | def encode(matched, priors, variances):
29 | """Encode the variances from the priorbox layers into the ground truth boxes
30 | we have matched (based on jaccard overlap) with the prior boxes.
31 | Args:
32 | matched: (tensor) Coords of ground truth for each prior in point-form
33 | Shape: [num_priors, 4].
34 | priors: (tensor) Prior boxes in center-offset form
35 | Shape: [num_priors,4].
36 | variances: (list[float]) Variances of priorboxes
37 | Return:
38 | encoded boxes (tensor), Shape: [num_priors, 4]
39 | """
40 |
41 | # dist b/t match center and prior's center
42 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
43 | # encode variance
44 | g_cxcy /= (variances[0] * priors[:, 2:])
45 | # match wh / prior wh
46 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
47 | g_wh = np.log(g_wh) / variances[1]
48 |
49 | # return target for smooth_l1_loss
50 | return np.concatenate([g_cxcy, g_wh], 1) # [num_priors,4]
51 |
52 |
53 | def decode(loc, priors, variances):
54 | """Decode locations from predictions using priors to undo
55 | the encoding we did for offset regression at train time.
56 | Args:
57 | loc (tensor): location predictions for loc layers,
58 | Shape: [num_priors,4]
59 | priors (tensor): Prior boxes in center-offset form.
60 | Shape: [num_priors,4].
61 | variances: (list[float]) Variances of priorboxes
62 | Return:
63 | decoded bounding box predictions
64 | """
65 |
66 | boxes = np.concatenate((
67 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
68 | priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), 1)
69 | boxes[:, :2] -= boxes[:, 2:] / 2
70 | boxes[:, 2:] += boxes[:, :2]
71 | return boxes
72 |
--------------------------------------------------------------------------------
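A tiny numeric check of the ``nms`` helper above: two heavily overlapping detections and one disjoint detection, with the lower-scoring overlap suppressed at an IoU threshold of 0.5:

import numpy as np
from face_alignment.detection.sfd.bbox import nms

dets = np.array([
    [10, 10, 50, 50, 0.90],      # highest score, kept
    [12, 12, 52, 52, 0.80],      # IoU ~0.83 with the first box, suppressed
    [100, 100, 140, 140, 0.70],  # disjoint, kept
])
print(nms(dets, thresh=0.5))     # indices [0, 2]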
/face-alignment/face_alignment/detection/sfd/detect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | from .bbox import *
8 |
9 |
10 | def detect(net, img, device):
11 | img = img.transpose(2, 0, 1)
12 | # Creates a batch of 1
13 | img = np.expand_dims(img, 0)
14 |
15 | img = torch.from_numpy(img.copy()).to(device, dtype=torch.float32)
16 |
17 | return batch_detect(net, img, device)
18 |
19 |
20 | def batch_detect(net, img_batch, device):
21 | """
22 | Inputs:
23 | - img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width)
24 | """
25 |
26 | if 'cuda' in device:
27 | torch.backends.cudnn.benchmark = True
28 |
29 | batch_size = img_batch.size(0)
30 | img_batch = img_batch.to(device, dtype=torch.float32)
31 |
32 | img_batch = img_batch.flip(-3) # RGB to BGR
33 | img_batch = img_batch - torch.tensor([104.0, 117.0, 123.0], device=device).view(1, 3, 1, 1)
34 |
35 | with torch.no_grad():
36 | olist = net(img_batch) # patched uint8_t overflow error
37 |
38 | for i in range(len(olist) // 2):
39 | olist[i * 2] = F.softmax(olist[i * 2], dim=1)
40 |
41 | olist = [oelem.data.cpu().numpy() for oelem in olist]
42 |
43 | bboxlists = get_predictions(olist, batch_size)
44 | return bboxlists
45 |
46 |
47 | def get_predictions(olist, batch_size):
48 | bboxlists = []
49 | variances = [0.1, 0.2]
50 | for j in range(batch_size):
51 | bboxlist = []
52 | for i in range(len(olist) // 2):
53 | ocls, oreg = olist[i * 2], olist[i * 2 + 1]
54 | stride = 2**(i + 2) # 4,8,16,32,64,128
55 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
56 | for Iindex, hindex, windex in poss:
57 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
58 | score = ocls[j, 1, hindex, windex]
59 | loc = oreg[j, :, hindex, windex].copy().reshape(1, 4)
60 | priors = np.array([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
61 | box = decode(loc, priors, variances)
62 | x1, y1, x2, y2 = box[0]
63 | bboxlist.append([x1, y1, x2, y2, score])
64 |
65 | bboxlists.append(bboxlist)
66 |
67 | bboxlists = np.array(bboxlists)
68 | return bboxlists
69 |
70 |
71 | def flip_detect(net, img, device):
72 | img = cv2.flip(img, 1)
73 | b = detect(net, img, device)
74 |
75 | bboxlist = np.zeros(b.shape)
76 | bboxlist[:, 0] = img.shape[1] - b[:, 2]
77 | bboxlist[:, 1] = b[:, 1]
78 | bboxlist[:, 2] = img.shape[1] - b[:, 0]
79 | bboxlist[:, 3] = b[:, 3]
80 | bboxlist[:, 4] = b[:, 4]
81 | return bboxlist
82 |
83 |
84 | def pts_to_bb(pts):
85 | min_x, min_y = np.min(pts, axis=0)
86 | max_x, max_y = np.max(pts, axis=0)
87 | return np.array([min_x, min_y, max_x, max_y])
88 |
--------------------------------------------------------------------------------
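The preprocessing applied by ``detect``/``batch_detect`` before the s3fd forward pass, restated as a standalone sketch on a dummy frame (no network weights needed): an HWC uint8 RGB image becomes an NCHW float tensor, channels are flipped to BGR, and a fixed BGR mean is subtracted:

import numpy as np
import torch

img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)            # dummy RGB frame

x = torch.from_numpy(img.transpose(2, 0, 1).copy()).unsqueeze(0).float()  # (1, 3, H, W)
x = x.flip(-3)                                                             # RGB -> BGR
x = x - torch.tensor([104.0, 117.0, 123.0]).view(1, 3, 1, 1)               # subtract BGR mean
print(x.shape, round(x.mean().item(), 2))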
/face-alignment/face_alignment/detection/sfd/net_s3fd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class L2Norm(nn.Module):
7 | def __init__(self, n_channels, scale=1.0):
8 | super(L2Norm, self).__init__()
9 | self.n_channels = n_channels
10 | self.scale = scale
11 | self.eps = 1e-10
12 | self.weight = nn.Parameter(torch.empty(self.n_channels).fill_(self.scale))
13 |
14 | def forward(self, x):
15 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
16 | x = x / norm * self.weight.view(1, -1, 1, 1)
17 | return x
18 |
19 |
20 | class s3fd(nn.Module):
21 | def __init__(self):
22 | super(s3fd, self).__init__()
23 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
24 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
25 |
26 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
27 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
28 |
29 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
30 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
31 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
32 |
33 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
34 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
35 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
36 |
37 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
39 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40 |
41 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
42 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
43 |
44 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
45 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
46 |
47 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
48 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
49 |
50 | self.conv3_3_norm = L2Norm(256, scale=10)
51 | self.conv4_3_norm = L2Norm(512, scale=8)
52 | self.conv5_3_norm = L2Norm(512, scale=5)
53 |
54 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
55 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
56 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
57 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
58 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60 |
61 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
62 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
63 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
64 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
65 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
66 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
67 |
68 | def forward(self, x):
69 | h = F.relu(self.conv1_1(x), inplace=True)
70 | h = F.relu(self.conv1_2(h), inplace=True)
71 | h = F.max_pool2d(h, 2, 2)
72 |
73 | h = F.relu(self.conv2_1(h), inplace=True)
74 | h = F.relu(self.conv2_2(h), inplace=True)
75 | h = F.max_pool2d(h, 2, 2)
76 |
77 | h = F.relu(self.conv3_1(h), inplace=True)
78 | h = F.relu(self.conv3_2(h), inplace=True)
79 | h = F.relu(self.conv3_3(h), inplace=True)
80 | f3_3 = h
81 | h = F.max_pool2d(h, 2, 2)
82 |
83 | h = F.relu(self.conv4_1(h), inplace=True)
84 | h = F.relu(self.conv4_2(h), inplace=True)
85 | h = F.relu(self.conv4_3(h), inplace=True)
86 | f4_3 = h
87 | h = F.max_pool2d(h, 2, 2)
88 |
89 | h = F.relu(self.conv5_1(h), inplace=True)
90 | h = F.relu(self.conv5_2(h), inplace=True)
91 | h = F.relu(self.conv5_3(h), inplace=True)
92 | f5_3 = h
93 | h = F.max_pool2d(h, 2, 2)
94 |
95 | h = F.relu(self.fc6(h), inplace=True)
96 | h = F.relu(self.fc7(h), inplace=True)
97 | ffc7 = h
98 | h = F.relu(self.conv6_1(h), inplace=True)
99 | h = F.relu(self.conv6_2(h), inplace=True)
100 | f6_2 = h
101 | h = F.relu(self.conv7_1(h), inplace=True)
102 | h = F.relu(self.conv7_2(h), inplace=True)
103 | f7_2 = h
104 |
105 | f3_3 = self.conv3_3_norm(f3_3)
106 | f4_3 = self.conv4_3_norm(f4_3)
107 | f5_3 = self.conv5_3_norm(f5_3)
108 |
109 | cls1 = self.conv3_3_norm_mbox_conf(f3_3)
110 | reg1 = self.conv3_3_norm_mbox_loc(f3_3)
111 | cls2 = self.conv4_3_norm_mbox_conf(f4_3)
112 | reg2 = self.conv4_3_norm_mbox_loc(f4_3)
113 | cls3 = self.conv5_3_norm_mbox_conf(f5_3)
114 | reg3 = self.conv5_3_norm_mbox_loc(f5_3)
115 | cls4 = self.fc7_mbox_conf(ffc7)
116 | reg4 = self.fc7_mbox_loc(ffc7)
117 | cls5 = self.conv6_2_mbox_conf(f6_2)
118 | reg5 = self.conv6_2_mbox_loc(f6_2)
119 | cls6 = self.conv7_2_mbox_conf(f7_2)
120 | reg6 = self.conv7_2_mbox_loc(f7_2)
121 |
122 | # max-out background label
123 | chunk = torch.chunk(cls1, 4, 1)
124 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
125 | cls1 = torch.cat([bmax, chunk[3]], dim=1)
126 |
127 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
128 |
--------------------------------------------------------------------------------
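Two quick checks of the s3fd-specific pieces above: L2Norm rescales each spatial location to a per-channel-weighted unit norm (initially equal to ``scale``), and the "max-out background label" trick collapses the three background channels of cls1 into a single one:

import torch
from face_alignment.detection.sfd.net_s3fd import L2Norm

x = torch.randn(1, 256, 8, 8)
y = L2Norm(256, scale=10)(x)
print(y.pow(2).sum(dim=1).sqrt()[0, 0, 0].item())  # ~10 (the initial scale)

cls1 = torch.randn(1, 4, 8, 8)                     # 3 background + 1 face channel
chunk = torch.chunk(cls1, 4, 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
cls1 = torch.cat([bmax, chunk[3]], dim=1)          # now 2 channels, like cls2..cls6
print(cls1.shape)                                  # torch.Size([1, 2, 8, 8])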
/face-alignment/face_alignment/detection/sfd/sfd_detector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.model_zoo import load_url
3 |
4 | from ..core import FaceDetector
5 |
6 | from .net_s3fd import s3fd
7 | from .bbox import nms
8 | from .detect import detect, batch_detect
9 |
10 | models_urls = {
11 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
12 | }
13 |
14 |
15 | class SFDDetector(FaceDetector):
16 | '''S3FD Detector.
17 | '''
18 |
19 | def __init__(self, device, path_to_detector=None, verbose=False, filter_threshold=0.5):
20 | super(SFDDetector, self).__init__(device, verbose)
21 |
22 | # Initialise the face detector
23 | if path_to_detector is None:
24 | model_weights = load_url(models_urls['s3fd'])
25 | else:
26 | model_weights = torch.load(path_to_detector)
27 |
28 | self.filter_threshold = filter_threshold
29 | self.face_detector = s3fd()
30 | self.face_detector.load_state_dict(model_weights)
31 | self.face_detector.to(device)
32 | self.face_detector.eval()
33 |
34 | def _filter_bboxes(self, bboxlist):
35 | if len(bboxlist) > 0:
36 | keep = nms(bboxlist, 0.3)
37 | bboxlist = bboxlist[keep, :]
38 | bboxlist = [x for x in bboxlist if x[-1] > self.filter_threshold]
39 |
40 | return bboxlist
41 |
42 | def detect_from_image(self, tensor_or_path):
43 | image = self.tensor_or_path_to_ndarray(tensor_or_path)
44 |
45 | bboxlist = detect(self.face_detector, image, device=self.device)[0]
46 | bboxlist = self._filter_bboxes(bboxlist)
47 |
48 | return bboxlist
49 |
50 | def detect_from_batch(self, tensor):
51 | bboxlists = batch_detect(self.face_detector, tensor, device=self.device)
52 |
53 | new_bboxlists = []
54 | for i in range(bboxlists.shape[0]):
55 | bboxlist = bboxlists[i]
56 | bboxlist = self._filter_bboxes(bboxlist)
57 | new_bboxlists.append(bboxlist)
58 |
59 | return new_bboxlists
60 |
61 | @property
62 | def reference_scale(self):
63 | return 195
64 |
65 | @property
66 | def reference_x_shift(self):
67 | return 0
68 |
69 | @property
70 | def reference_y_shift(self):
71 | return 0
72 |
--------------------------------------------------------------------------------
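A hedged usage sketch for SFDDetector: the s3fd weights are fetched via ``load_url`` on first use, and 'example.jpg' is a hypothetical input image:

import torch
from face_alignment.detection.sfd import FaceDetector  # SFDDetector alias

device = 'cuda' if torch.cuda.is_available() else 'cpu'
detector = FaceDetector(device=device, filter_threshold=0.5)

boxes = detector.detect_from_image('example.jpg')  # [[x1, y1, x2, y2, score], ...]
for x1, y1, x2, y2, score in boxes:
    print('face at (%.0f, %.0f)-(%.0f, %.0f), score %.2f' % (x1, y1, x2, y2, score))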
/face-alignment/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | scipy>=0.17.0
3 | scikit-image
4 | numba
5 |
--------------------------------------------------------------------------------
/face-alignment/setup.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 1.3.4
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:setup.py]
7 | search = version='{current_version}'
8 | replace = version='{new_version}'
9 |
10 | [bumpversion:file:face_alignment/__init__.py]
11 | search = __version__ = '{current_version}'
12 | replace = __version__ = '{new_version}'
13 |
14 | [metadata]
15 | description_file = README.md
16 |
17 | [bdist_wheel]
18 | universal = 1
19 |
20 | [flake8]
21 | exclude =
22 | .github,
23 | examples,
24 | docs,
25 | .tox,
26 | bin,
27 | dist,
28 | tools,
29 | *.egg-info,
30 | __init__.py,
31 | *.yml
32 | max-line-length = 160
--------------------------------------------------------------------------------
/face-alignment/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | from os import path
4 | import re
5 | from setuptools import setup, find_packages
6 | # To use consistent encodings
7 | from codecs import open
8 |
9 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py
10 |
11 |
12 | def read(*names, **kwargs):
13 | with io.open(
14 | os.path.join(os.path.dirname(__file__), *names),
15 | encoding=kwargs.get("encoding", "utf8")
16 | ) as fp:
17 | return fp.read()
18 |
19 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py
20 |
21 |
22 | def find_version(*file_paths):
23 | version_file = read(*file_paths)
24 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
25 | version_file, re.M)
26 | if version_match:
27 | return version_match.group(1)
28 | raise RuntimeError("Unable to find version string.")
29 |
30 | here = path.abspath(path.dirname(__file__))
31 |
32 | # Get the long description from the README file
33 | with open(path.join(here, 'README.md'), encoding='utf-8') as readme_file:
34 | long_description = readme_file.read()
35 |
36 | VERSION = find_version('face_alignment', '__init__.py')
37 |
38 | requirements = [
39 | 'torch',
40 | 'numpy',
41 | 'scipy>=0.17',
42 | 'scikit-image',
43 | 'opencv-python',
44 | 'tqdm',
45 | 'numba',
46 | 'enum34;python_version<"3.4"'
47 | ]
48 |
49 | setup(
50 | name='face_alignment',
51 | version=VERSION,
52 |
53 | description="Detect 2D or 3D face landmarks from Python",
54 | long_description=long_description,
55 | long_description_content_type="text/markdown",
56 |
57 | # Author details
58 | author="Adrian Bulat",
59 | author_email="adrian@adrianbulat.com",
60 | url="https://github.com/1adrianb/face-alignment",
61 |
62 | # Package info
63 | packages=find_packages(exclude=('test',)),
64 |
65 | python_requires='>=3',
66 | install_requires=requirements,
67 | license='BSD',
68 | zip_safe=True,
69 |
70 | classifiers=[
71 | 'Development Status :: 5 - Production/Stable',
72 | 'Operating System :: OS Independent',
73 | 'License :: OSI Approved :: BSD License',
74 | 'Natural Language :: English',
75 |
76 | # Supported python versions
77 | 'Programming Language :: Python :: 3',
78 | 'Programming Language :: Python :: 3.3',
79 | 'Programming Language :: Python :: 3.4',
80 | 'Programming Language :: Python :: 3.5',
81 | 'Programming Language :: Python :: 3.6',
82 | 'Programming Language :: Python :: 3.7',
83 | 'Programming Language :: Python :: 3.8',
84 | ],
85 | )
86 |
--------------------------------------------------------------------------------
/face-alignment/test/facealignment_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import face_alignment
4 | import sys
5 | import torch
6 | sys.path.append('.')
7 | from face_alignment.utils import get_image
8 |
9 |
10 | class Tester(unittest.TestCase):
11 | def setUp(self) -> None:
12 | self.reference_data = [np.array([[137., 240., -85.907196],
13 | [140., 264., -81.1443],
14 | [143., 288., -76.25633],
15 | [146., 306., -69.01708],
16 | [152., 327., -53.775352],
17 | [161., 342., -30.029667],
18 | [170., 348., -2.792292],
19 | [185., 354., 23.522688],
20 | [212., 360., 38.664257],
21 | [239., 357., 31.747217],
22 | [263., 354., 12.192401],
23 | [284., 348., -10.0569725],
24 | [302., 333., -29.42916],
25 | [314., 315., -41.675602],
26 | [320., 297., -46.924263],
27 | [326., 276., -50.33218],
28 | [335., 252., -53.945686],
29 | [152., 207., -7.6189857],
30 | [164., 201., 6.1879144],
31 | [176., 198., 16.991247],
32 | [188., 198., 24.690582],
33 | [200., 201., 29.248188],
34 | [245., 204., 37.878166],
35 | [257., 201., 37.420483],
36 | [269., 201., 34.163113],
37 | [284., 204., 28.480812],
38 | [299., 216., 18.31863],
39 | [221., 225., 37.93351],
40 | [218., 237., 48.337395],
41 | [215., 249., 60.502884],
42 | [215., 261., 63.353687],
43 | [203., 273., 40.186855],
44 | [209., 276., 45.057003],
45 | [218., 276., 48.56715],
46 | [227., 276., 47.744766],
47 | [233., 276., 45.01401],
48 | [170., 228., 7.166072],
49 | [179., 222., 17.168053],
50 | [188., 222., 19.775822],
51 | [200., 228., 19.06176],
52 | [191., 231., 20.636724],
53 | [179., 231., 16.125824],
54 | [248., 231., 28.566122],
55 | [257., 225., 33.024036],
56 | [269., 225., 34.384735],
57 | [278., 231., 27.014532],
58 | [269., 234., 32.867023],
59 | [257., 234., 33.34033],
60 | [185., 306., 29.927242],
61 | [194., 297., 42.611233],
62 | [209., 291., 50.563396],
63 | [215., 291., 52.831104],
64 | [221., 291., 52.9225],
65 | [236., 300., 48.32575],
66 | [248., 309., 38.2375],
67 | [236., 312., 48.377922],
68 | [224., 315., 52.63793],
69 | [212., 315., 52.330444],
70 | [203., 315., 49.552994],
71 | [194., 309., 42.64459],
72 | [188., 303., 30.746407],
73 | [206., 300., 46.514435],
74 | [215., 300., 49.611156],
75 | [224., 300., 49.058918],
76 | [248., 309., 38.084103],
77 | [224., 303., 49.817806],
78 | [215., 303., 49.59815],
79 | [206., 303., 47.13894]], dtype=np.float32)]
80 |
81 | def test_predict_points(self):
82 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu')
83 | preds = fa.get_landmarks('test/assets/aflw-test.jpg')
84 | self.assertEqual(len(preds), len(self.reference_data))
85 | for pred, reference in zip(preds, self.reference_data):
86 | self.assertTrue(np.allclose(pred, reference))
87 |
88 | def test_predict_batch_points(self):
89 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu')
90 |
91 | reference_data = self.reference_data + self.reference_data
92 | reference_data.append([])
93 | image = get_image('test/assets/aflw-test.jpg')
94 | batch = np.stack([image, image, np.zeros_like(image)])
95 | batch = torch.Tensor(batch.transpose(0, 3, 1, 2))
96 |
97 | preds = fa.get_landmarks_from_batch(batch)
98 |
99 | self.assertEqual(len(preds), len(reference_data))
100 | for pred, reference in zip(preds, reference_data):
101 | self.assertTrue(np.allclose(pred, reference))
102 |
103 | def test_predict_points_from_dir(self):
104 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cpu')
105 |
106 | reference_data = {
107 | 'test/assets/grass.jpg': None,
108 | 'test/assets/aflw-test.jpg': self.reference_data}
109 |
110 | preds = fa.get_landmarks_from_directory('test/assets/')
111 |
112 | for k, points in preds.items():
113 | if isinstance(points, list):
114 | for p, p_reference in zip(points, reference_data[k]):
115 | self.assertTrue(np.allclose(p, p_reference))
116 | else:
117 | self.assertEqual(points, reference_data[k])
118 |
119 |
120 | if __name__ == '__main__':
121 | unittest.main()
122 |
--------------------------------------------------------------------------------
/face-alignment/test/smoke_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import face_alignment
3 |
--------------------------------------------------------------------------------
/face-alignment/test/test_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 | import unittest
4 | from face_alignment.utils import *
5 | import numpy as np
6 | import torch
7 |
8 |
9 | class Tester(unittest.TestCase):
10 | def test_flip_is_label(self):
11 | # Generate the points
12 | heatmaps = torch.from_numpy(np.random.randint(1, high=250, size=(68, 64, 64)).astype('float32'))
13 |
14 | flipped_heatmaps = flip(flip(heatmaps.clone(), is_label=True), is_label=True)
15 |
16 | assert np.allclose(heatmaps.numpy(), flipped_heatmaps.numpy())
17 |
18 | def test_flip_is_image(self):
19 | fake_image = torch.rand(3, 256, 256)
20 | flipped_fake_image = flip(flip(fake_image.clone()))
21 |
22 | assert np.allclose(fake_image.numpy(), flipped_fake_image.numpy())
23 |
24 | def test_getpreds(self):
25 | pts = np.random.randint(1, high=63, size=(68, 2)).astype('float32')
26 |
27 | heatmaps = np.zeros((68, 256, 256))
28 | for i in range(68):
29 | if pts[i, 0] > 0:
30 | heatmaps[i] = draw_gaussian(heatmaps[i], pts[i], 2)
31 | heatmaps = np.expand_dims(heatmaps, axis=0)
32 |
33 | preds, _, _ = get_preds_fromhm(heatmaps)
34 |
35 | assert np.allclose(pts, preds, atol=5)
36 |
37 | def test_create_heatmaps(self):
38 | reference_scale = 195
39 | target_landmarks = torch.randint(0, 255, (1, 68, 2)).type(torch.float) # simulated dataset
40 | bb = create_bounding_box(target_landmarks)
41 | centers = torch.stack([bb[:, 2] - (bb[:, 2] - bb[:, 0]) / 2.0, bb[:, 3] - (bb[:, 3] - bb[:, 1]) / 2.0], dim=1)
42 | centers[:, 1] = centers[:, 1] - (bb[:, 3] - bb[:, 1]) * 0.12 # Not sure where 0.12 comes from
43 | scales = (bb[:, 2] - bb[:, 0] + bb[:, 3] - bb[:, 1]) / reference_scale
44 | heatmaps = create_target_heatmap(target_landmarks, centers, scales)
45 | preds = get_preds_fromhm(heatmaps.numpy(), centers.squeeze().numpy(), scales.squeeze().numpy())[1]
46 |
47 | assert np.allclose(preds, target_landmarks, atol=5)
48 |
49 | if __name__ == '__main__':
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
/face-alignment/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = E305,E402,E721,F401,F403,F405,F821,F841,F999,W503
--------------------------------------------------------------------------------
/frames_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from skimage import io, img_as_float32
3 | from skimage.color import gray2rgb
4 | from sklearn.model_selection import train_test_split
5 | from imageio import mimread
6 |
7 | import numpy as np
8 | from torch.utils.data import Dataset
9 | import pandas as pd
10 | from augmentation import AllAugmentationTransform
11 | import glob
12 | from PIL import Image
13 | import pdb
14 | def read_video(name, frame_shape):
15 | """
16 | Read video which can be:
17 | - an image of concatenated frames
18 | - '.mp4' and '.gif'
19 | - folder with videos
20 | """
21 |
22 | if os.path.isdir(name):
23 | frames = sorted(os.listdir(name))
24 | num_frames = len(frames)
25 | video_array = np.array(
26 | [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
27 | elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
28 | image = io.imread(name)
29 |
30 | if len(image.shape) == 2 or image.shape[2] == 1:
31 | image = gray2rgb(image)
32 |
33 | if image.shape[2] == 4:
34 | image = image[..., :3]
35 |
36 | image = img_as_float32(image)
37 |
38 | video_array = np.moveaxis(image, 1, 0)
39 |
40 | video_array = video_array.reshape((-1,) + frame_shape)
41 | video_array = np.moveaxis(video_array, 1, 2)
42 | elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
43 | video = np.array(mimread(name))
44 | if len(video.shape) == 3:
45 | video = np.array([gray2rgb(frame) for frame in video])
46 | if video.shape[-1] == 4:
47 | video = video[..., :3]
48 | video_array = img_as_float32(video)
49 | else:
50 | raise Exception("Unknown file extension: %s" % name)
51 |
52 | return video_array
53 |
54 |
55 | class FramesDataset(Dataset):
56 | """
57 | Dataset of videos, each video can be represented as:
58 | - an image of concatenated frames
59 | - '.mp4' or '.gif'
60 | - folder with all frames
61 | """
62 |
63 | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
64 | random_seed=0, pairs_list=None, augmentation_params=None):
65 | self.root_dir = root_dir
66 | self.videos = os.listdir(root_dir)
67 | # self.videos = self.videos[5000]
68 | self.frame_shape = tuple(frame_shape)
69 | self.pairs_list = pairs_list
70 | self.id_sampling = id_sampling
71 | if os.path.exists(os.path.join(root_dir, 'train')):
72 | assert os.path.exists(os.path.join(root_dir, 'test'))
73 | print("Use predefined train-test split.")
74 | if id_sampling:
75 | train_videos = {os.path.basename(video).split('#')[0] for video in
76 | os.listdir(os.path.join(root_dir, 'train'))}
77 | train_videos = list(train_videos)
78 | else:
79 | train_videos = os.listdir(os.path.join(root_dir, 'train'))
80 | test_videos = os.listdir(os.path.join(root_dir, 'test'))
81 | self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
82 | else:
83 | print("Use random train-test split.")
84 | train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
85 |
86 | if is_train:
87 | self.videos = train_videos
88 | else:
89 | self.videos = test_videos
90 |
91 | self.is_train = is_train
92 |
93 | if self.is_train:
94 | self.transform = AllAugmentationTransform(**augmentation_params)
95 | else:
96 | self.transform = None
97 |
98 | def __len__(self):
99 | return len(self.videos)
100 |
101 | def __getitem__(self, idx):
102 | if self.is_train and self.id_sampling:
103 | name = self.videos[idx]
104 | path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
105 | else:
106 | name = self.videos[idx]
107 | path = os.path.join(self.root_dir, name)
108 |
109 | video_name = os.path.basename(path)
110 |
111 | if self.is_train and os.path.isdir(path):
112 | frames = os.listdir(path)
113 | num_frames = len(frames)
114 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
115 | # video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx].decode()))) for idx in frame_idx]
116 | video_array = []
117 | for idx in frame_idx:
118 | try:
119 | video_array.append(img_as_float32(io.imread(os.path.join(path, frames[idx].decode()))))
120 | except Exception as e:
121 | print(e)
122 | else:
123 | video_array = read_video(path, frame_shape=self.frame_shape)
124 | num_frames = len(video_array)
125 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
126 | num_frames)
127 | video_array = video_array[frame_idx]
128 |
129 | if self.transform is not None:
130 | video_array = self.transform(video_array)
131 |
132 | out = {}
133 | if self.is_train:
134 | source = np.array(video_array[0], dtype='float32')
135 | driving = np.array(video_array[1], dtype='float32')
136 |
137 | out['driving'] = driving.transpose((2, 0, 1))
138 | out['source'] = source.transpose((2, 0, 1))
139 | else:
140 | video = np.array(video_array, dtype='float32')
141 | out['video'] = video.transpose((3, 0, 1, 2))
142 |
143 | out['name'] = video_name
144 |
145 | return out
146 |
147 |
148 | class DatasetRepeater(Dataset):
149 | """
150 | Pass several times over the same dataset for better i/o performance
151 | """
152 |
153 | def __init__(self, dataset, num_repeats=100):
154 | self.dataset = dataset
155 | self.num_repeats = num_repeats
156 |
157 | def __len__(self):
158 | return self.num_repeats * self.dataset.__len__()
159 |
160 | def __getitem__(self, idx):
161 | return self.dataset[idx % self.dataset.__len__()]
162 |
163 |
164 | class PairedDataset(Dataset):
165 | """
166 | Dataset of pairs for animation.
167 | """
168 |
169 | def __init__(self, initial_dataset, number_of_pairs, seed=0):
170 | self.initial_dataset = initial_dataset
171 | pairs_list = self.initial_dataset.pairs_list
172 | np.random.seed(seed)
173 |
174 | if pairs_list is None:
175 | max_idx = min(number_of_pairs, len(initial_dataset))
176 | nx, ny = max_idx, max_idx
177 | xy = np.mgrid[:nx, :ny].reshape(2, -1).T
178 | number_of_pairs = min(xy.shape[0], number_of_pairs)
179 | self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
180 | else:
181 | videos = self.initial_dataset.videos
182 | name_to_index = {name: index for index, name in enumerate(videos)}
183 | pairs = pd.read_csv(pairs_list)
184 | pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]
185 | number_of_pairs = min(pairs.shape[0], number_of_pairs)
186 | self.pairs = []
187 | self.start_frames = []
188 | for ind in range(number_of_pairs):
189 | self.pairs.append(
190 | (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))
191 |
192 | def __len__(self):
193 | return len(self.pairs)
194 |
195 | def __getitem__(self, idx):
196 | pair = self.pairs[idx]
197 | first = self.initial_dataset[pair[0]]
198 | second = self.initial_dataset[pair[1]]
199 | first = {'driving_' + key: value for key, value in first.items()}
200 | second = {'source_' + key: value for key, value in second.items()}
201 |
202 | return {**first, **second}
203 |
--------------------------------------------------------------------------------
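A minimal sketch of how FramesDataset and DatasetRepeater are consumed at training time. 'data/vox-mp4' is a hypothetical root directory of preprocessed .mp4 clips, and the augmentation keys are assumed to follow config/vox-adv-256.yaml (they are forwarded verbatim to AllAugmentationTransform):

from torch.utils.data import DataLoader
from frames_dataset import FramesDataset, DatasetRepeater

dataset = FramesDataset(root_dir='data/vox-mp4', frame_shape=(256, 256, 3),
                        is_train=True, id_sampling=False,
                        augmentation_params={'flip_param': {'horizontal_flip': True, 'time_flip': True},
                                             'jitter_param': {'hue': 0.1}})
dataset = DatasetRepeater(dataset, num_repeats=75)     # stretch one epoch over many passes
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, drop_last=True)

batch = next(iter(loader))
print(batch['source'].shape, batch['driving'].shape)   # (4, 3, 256, 256) each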
/kill_port.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import signal
4 | import pdb
5 | def kill_process(*pids):
6 | for pid in pids:
7 | a = os.kill(pid, signal.SIGKILL)
8 | print('Killed process with pid %s, os.kill returned: %s' % (pid, a))
9 |
10 | def get_pid(*ports):
11 | # note: \" denotes an escaped double quote
12 | pids = []
13 | print(ports)
14 | for port in ports:
15 | msg = os.popen('lsof -i:{}'.format(port)).read()
16 | msg = msg.split('\n')[1:-1]
17 | for m in msg:
18 | # collapse runs of whitespace so the PID is the second token
19 | m = ' '.join(m.split())
20 | tokens = m.split(' ')
21 | pids.append(int(tokens[1]))
22 | return pids
23 |
24 | if __name__ == "__main__":
25 | # kill the processes occupying the given port numbers
26 | ports = sys.argv[1:]
27 | kill_process(*get_pid(*ports))
28 |
29 |
--------------------------------------------------------------------------------
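The script frees TCP ports by killing whatever processes ``lsof`` reports on them (Linux/macOS only). A typical invocation, with example port numbers:

# From the shell (ports are examples):
#   python kill_port.py 6006 12345
#
# Or programmatically:
from kill_port import get_pid, kill_process

kill_process(*get_pid(8080))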
/logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import imageio
5 |
6 | import os
7 | from skimage.draw import ellipse
8 | import pdb
9 | import matplotlib.pyplot as plt
10 | import collections
11 |
12 | class Logger:
13 | def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
14 |
15 | self.loss_list = []
16 | self.cpk_dir = log_dir
17 | self.visualizations_dir = os.path.join(log_dir, 'train-vis')
18 | if not os.path.exists(self.visualizations_dir):
19 | os.makedirs(self.visualizations_dir)
20 | self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
21 | self.zfill_num = zfill_num
22 | self.visualizer = Visualizer(**visualizer_params)
23 | self.checkpoint_freq = checkpoint_freq
24 | self.epoch = 0
25 | self.best_loss = float('inf')
26 | self.names = None
27 |
28 | def log_scores(self, loss_names):
29 | loss_mean = np.array(self.loss_list).mean(axis=0)
30 |
31 | loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
32 | loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
33 |
34 | print(loss_string, file=self.log_file)
35 | self.loss_list = []
36 | self.log_file.flush()
37 |
38 | def visualize_rec(self, inp, out):
39 | image = self.visualizer.visualize(inp['driving'], inp['source'], out)
40 | imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
41 |
42 | def save_cpk(self, emergent=False):
43 | cpk = {k: v.state_dict() for k, v in self.models.items()}
44 | cpk['epoch'] = self.epoch
45 | cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
46 | if not (os.path.exists(cpk_path) and emergent):
47 | torch.save(cpk, cpk_path)
48 |
49 | @staticmethod
50 | def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None,
51 | optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None):
52 | num_gpu = torch.cuda.device_count()
53 | if num_gpu == 1:
54 | checkpoint = torch.load(checkpoint_path,map_location='cuda:0')
55 | else:
56 | checkpoint = torch.load(checkpoint_path,map_location='cpu')
57 | if generator is not None:
58 | ckp_generator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['generator'].items())
59 | generator.load_state_dict(ckp_generator)
60 | if kp_detector is not None:
61 | ckp_kp_detector = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['kp_detector'].items())
62 | kp_detector.load_state_dict(ckp_kp_detector)
63 | if discriminator is not None:
64 | try:
65 | ckp_discriminator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['discriminator'].items())
66 | discriminator.load_state_dict(ckp_discriminator)
67 | except:
68 | print('No discriminator in the state-dict. Discriminator will be randomly initialized')
69 | if optimizer_generator is not None:
70 | ckp_optimizer_generator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['optimizer_generator'].items())
71 | optimizer_generator.load_state_dict(ckp_optimizer_generator)
72 | if optimizer_discriminator is not None:
73 | try:
74 | ckp_optimizer_discriminator = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['optimizer_discriminator'].items())
75 | optimizer_discriminator.load_state_dict(ckp_optimizer_discriminator)
76 | except RuntimeError as e:
77 | print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
78 | if optimizer_kp_detector is not None:
79 | ckp_optimizer_kp_detector = collections.OrderedDict((k.replace('.module.','.'),v) for k,v in checkpoint['optimizer_kp_detector'].items())
80 | optimizer_kp_detector.load_state_dict(ckp_optimizer_kp_detector)
81 |
82 | return checkpoint['epoch']
83 |
84 | def __enter__(self):
85 | return self
86 |
87 | def __exit__(self, exc_type, exc_val, exc_tb):
88 | if 'models' in self.__dict__:
89 | self.save_cpk()
90 | self.log_file.close()
91 |
92 | def log_iter(self, losses):
93 | losses = collections.OrderedDict(losses.items())
94 | if self.names is None:
95 | self.names = list(losses.keys())
96 | self.loss_list.append(list(losses.values()))
97 |
98 | def log_epoch(self, epoch, models, inp, out):
99 | self.epoch = epoch
100 | self.models = models
101 | if (self.epoch + 1) % self.checkpoint_freq == 0:
102 | self.save_cpk()
103 | self.log_scores(self.names)
104 | # self.visualize_rec(inp, out)
105 |
106 |
107 | class Visualizer:
108 | def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
109 | self.kp_size = kp_size
110 | self.draw_border = draw_border
111 | self.colormap = plt.get_cmap(colormap)
112 |
113 | def draw_image_with_kp(self, image, kp_array):
114 | image = np.copy(image)
115 | spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
116 | kp_array = spatial_size * (kp_array + 1) / 2
117 | num_kp = kp_array.shape[0]
118 | for kp_ind, kp in enumerate(kp_array):
119 | rr, cc = ellipse(kp[1], kp[0], self.kp_size,self.kp_size, shape=image.shape[:2])
120 | image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
121 | return image
122 |
123 | def create_image_column_with_kp(self, images, kp):
124 | image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
125 | return self.create_image_column(image_array)
126 |
127 | def create_image_column(self, images):
128 | if self.draw_border:
129 | images = np.copy(images)
130 | images[:, :, [0, -1]] = (1, 1, 1)
131 | images[:, :, [0, -1]] = (1, 1, 1)
132 | return np.concatenate(list(images), axis=0)
133 |
134 | def create_image_grid(self, *args):
135 | out = []
136 | for arg in args:
137 | if type(arg) == tuple:
138 | out.append(self.create_image_column_with_kp(arg[0], arg[1]))
139 | else:
140 | out.append(self.create_image_column(arg))
141 | return np.concatenate(out, axis=1)
142 |
143 | def visualize(self, driving, source, out):
144 | images = []
145 |
146 | # Source image with keypoints
147 | source = source.data.cpu()
148 | kp_source = out['kp_source']['value'].data.cpu().numpy()
149 | source = np.transpose(source, [0, 2, 3, 1])
150 | images.append((source, kp_source))
151 |
152 | # Equivariance visualization
153 | if 'transformed_frame' in out:
154 | transformed = out['transformed_frame'].data.cpu().numpy()
155 | transformed = np.transpose(transformed, [0, 2, 3, 1])
156 | transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
157 | images.append((transformed, transformed_kp))
158 |
159 | # Driving image with keypoints
160 | kp_driving = out['kp_driving']['value'].data.cpu().numpy()
161 | driving = driving.data.cpu().numpy()
162 | driving = np.transpose(driving, [0, 2, 3, 1])
163 | images.append((driving, kp_driving))
164 |
165 | # Deformed image
166 | if 'deformed' in out:
167 | deformed = out['deformed'].data.cpu().numpy()
168 | deformed = np.transpose(deformed, [0, 2, 3, 1])
169 | images.append(deformed)
170 |
171 | # Result with and without keypoints
172 | prediction = out['prediction'].data.cpu().numpy()
173 | prediction = np.transpose(prediction, [0, 2, 3, 1])
174 | if 'kp_norm' in out:
175 | kp_norm = out['kp_norm']['value'].data.cpu().numpy()
176 | images.append((prediction, kp_norm))
177 | images.append(prediction)
178 |
179 |
180 | ## Occlusion map
181 | if 'occlusion_map' in out:
182 | occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
183 | occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
184 | occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
185 | images.append(occlusion_map)
186 |
187 | # Deformed images according to each individual transform
188 | if 'sparse_deformed' in out:
189 | full_mask = []
190 | for i in range(out['sparse_deformed'].shape[1]):
191 | image = out['sparse_deformed'][:, i].data.cpu()
192 | image = F.interpolate(image, size=source.shape[1:3])
193 | mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
194 | mask = F.interpolate(mask, size=source.shape[1:3])
195 | image = np.transpose(image.numpy(), (0, 2, 3, 1))
196 | mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
197 |
198 | if i != 0:
199 | color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
200 | else:
201 | color = np.array((0, 0, 0))
202 |
203 | color = color.reshape((1, 1, 1, 3))
204 |
205 | images.append(image)
206 | if i != 0:
207 | images.append(mask * color)
208 | else:
209 | images.append(mask)
210 |
211 | full_mask.append(mask * color)
212 |
213 | images.append(sum(full_mask))
214 |
215 | image = self.create_image_grid(*images)
216 | image = (255 * image).astype(np.uint8)
217 | return image
218 |
--------------------------------------------------------------------------------
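Logger.load_cpk strips an embedded '.module.' from checkpoint keys so weights saved from parallel-wrapped sub-networks load into bare modules. The sketch below shows the analogous fix-up for a toy top-level DataParallel wrapper (the key names here are illustrative, not DaGAN's):

import collections
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4))
wrapped = nn.DataParallel(model)

state = wrapped.state_dict()   # keys look like 'module.0.weight'
cleaned = collections.OrderedDict((k.replace('module.', '', 1), v) for k, v in state.items())
model.load_state_dict(cleaned)  # loads cleanly into the unwrapped model
print(list(cleaned.keys()))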
/modules/AdaIN.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def calc_mean_std(feat, eps=1e-5):
4 | # eps is a small value added to the variance to avoid divide-by-zero.
5 | size = feat.size()
6 | assert (len(size) == 4)
7 | N, C = size[:2]
8 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
9 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
10 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
11 | return feat_mean, feat_std
12 |
13 | def adaptive_instance_normalization(content_feat, style_feat):
14 | assert (content_feat.size()[:2] == style_feat.size()[:2])
15 | size = content_feat.size()
16 | style_mean, style_std = calc_mean_std(style_feat)
17 | content_mean, content_std = calc_mean_std(content_feat)
18 | normalized_feat = (content_feat - content_mean.expand(
19 | size)) / content_std.expand(size)
20 |
21 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
22 |
23 | def _calc_feat_flatten_mean_std(feat):
24 | # takes 3D feat (C, H, W), return mean and std of array within channels
25 | assert (feat.size()[0] == 3)
26 | assert (isinstance(feat, torch.FloatTensor))
27 | feat_flatten = feat.view(3, -1)
28 | mean = feat_flatten.mean(dim=-1, keepdim=True)
29 | std = feat_flatten.std(dim=-1, keepdim=True)
30 | return feat_flatten, mean, std
31 |
32 | def _mat_sqrt(x):
33 | U, D, V = torch.svd(x)
34 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
35 |
36 | def coral(source, target):
37 | # assume both source and target are 3D array (C, H, W)
38 | # Note: flatten -> f
39 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
40 | source_f_norm = (source_f - source_f_mean.expand_as(
41 | source_f)) / source_f_std.expand_as(source_f)
42 | source_f_cov_eye = \
43 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
44 |
45 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
46 | target_f_norm = (target_f - target_f_mean.expand_as(
47 | target_f)) / target_f_std.expand_as(target_f)
48 | target_f_cov_eye = \
49 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
50 |
51 | source_f_norm_transfer = torch.mm(
52 | _mat_sqrt(target_f_cov_eye),
53 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
54 | source_f_norm)
55 | )
56 |
57 | source_f_transfer = source_f_norm_transfer * \
58 | target_f_std.expand_as(source_f_norm) + \
59 | target_f_mean.expand_as(source_f_norm)
60 |
61 | return source_f_transfer.view(source.size())
--------------------------------------------------------------------------------
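A quick check of ``adaptive_instance_normalization``: after AdaIN the content features carry the style features' per-channel mean and standard deviation:

import torch
from modules.AdaIN import adaptive_instance_normalization, calc_mean_std

content = torch.randn(2, 64, 32, 32)
style = torch.randn(2, 64, 32, 32) * 3.0 + 1.0

out = adaptive_instance_normalization(content, style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
print(torch.allclose(out_mean, style_mean, atol=1e-4),
      torch.allclose(out_std, style_std, atol=1e-2))  # True True (up to numerics)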
/modules/dense_motion.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import torch
4 | from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
5 | import pdb
6 | from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
7 |
8 |
9 | class DenseMotionNetwork(nn.Module):
10 | """
11 | Module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
12 | """
13 |
14 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
15 | scale_factor=1, kp_variance=0.01):
16 | super(DenseMotionNetwork, self).__init__()
17 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
18 | max_features=max_features, num_blocks=num_blocks)
19 |
20 | self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
21 |
22 | if estimate_occlusion_map:
23 | self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
24 | else:
25 | self.occlusion = None
26 |
27 | self.num_kp = num_kp
28 | self.scale_factor = scale_factor
29 | self.kp_variance = kp_variance
30 |
31 | if self.scale_factor != 1:
32 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
33 |
34 | def create_heatmap_representations(self, source_image, kp_driving, kp_source):
35 | """
36 | Eq 6. in the paper H_k(z)
37 | """
38 | spatial_size = source_image.shape[2:]
39 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
40 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
41 | heatmap = gaussian_driving - gaussian_source
42 | #adding background feature
43 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
44 | heatmap = torch.cat([zeros, heatmap], dim=1)
45 | heatmap = heatmap.unsqueeze(2)
46 | return heatmap
47 |
48 | def create_sparse_motions(self, source_image, kp_driving, kp_source):
49 | """
50 | Eq 4. in the paper T_{s<-d}(z)
51 | """
52 | bs, _, h, w = source_image.shape
53 | identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
54 | identity_grid = identity_grid.view(1, 1, h, w, 2)
55 | coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
56 | if 'jacobian' in kp_driving:
57 | jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
58 | jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
59 | jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
60 | coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
61 | coordinate_grid = coordinate_grid.squeeze(-1)
62 |
63 | driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
64 |
65 | #adding background feature
66 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
67 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs, num_kp+1,w,h,2
68 | return sparse_motions
69 |
70 | def create_deformed_source_image(self, source_image, sparse_motions):
71 | """
72 | Eq 7. in the paper \hat{T}_{s<-d}(z)
73 | """
74 | bs, _, h, w = source_image.shape
75 | source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
76 | source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
77 | sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
78 | sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
79 | sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
80 | return sparse_deformed
81 |
82 | def forward(self, source_image, kp_driving, kp_source):
83 | if self.scale_factor != 1:
84 | source_image = self.down(source_image)
85 | bs, _, h, w = source_image.shape
86 | out_dict = dict()
87 | heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
88 | sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
89 | deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
90 | out_dict['sparse_deformed'] = deformed_source
91 |
92 | input = torch.cat([heatmap_representation, deformed_source], dim=2)
93 | input = input.view(bs, -1, h, w)
94 |
95 | prediction = self.hourglass(input)
96 |
97 | mask = self.mask(prediction)
98 | mask = F.softmax(mask, dim=1)
99 | out_dict['mask'] = mask
100 | mask = mask.unsqueeze(2)
101 | sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
102 | deformation = (sparse_motion * mask).sum(dim=1)
103 | deformation = deformation.permute(0, 2, 3, 1)
104 |
105 | out_dict['deformation'] = deformation
106 |
107 | # Sec. 3.2 in the paper
108 | if self.occlusion:
109 | occlusion_map = torch.sigmoid(self.occlusion(prediction))
110 | out_dict['occlusion_map'] = occlusion_map
111 |
112 | return out_dict
113 |
--------------------------------------------------------------------------------
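A standalone sketch of the sparse-motion term in ``create_sparse_motions`` (Eq. 4, with the optional Jacobian term omitted): each keypoint contributes T_{s<-d}(z) = z - kp_driving + kp_source on the normalized [-1, 1]^2 grid. The (x, y) channel order below is assumed to match make_coordinate_grid, which lives in modules/util.py and is not shown here:

import torch

h = w = 4
ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij')
identity_grid = torch.stack([xs, ys], dim=-1)   # (h, w, 2), assumed (x, y) order

kp_driving = torch.tensor([0.5, 0.0])           # one keypoint in [-1, 1]^2
kp_source = torch.tensor([-0.25, 0.25])

driving_to_source = identity_grid - kp_driving + kp_source
print(driving_to_source.shape)                  # torch.Size([4, 4, 2])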
/modules/discriminator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from modules.util import kp2gaussian
4 | import torch
5 | import pdb
6 |
7 | class DownBlock2d(nn.Module):
8 | """
9 | Simple block for processing video (encoder).
10 | """
11 |
12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
13 | super(DownBlock2d, self).__init__()
14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
15 |
16 | if sn:
17 | self.conv = nn.utils.spectral_norm(self.conv)
18 |
19 | if norm:
20 | self.norm = nn.InstanceNorm2d(out_features, affine=True)
21 | else:
22 | self.norm = None
23 | self.pool = pool
24 |
25 | def forward(self, x):
26 | out = x
27 | out = self.conv(out)
28 | if self.norm:
29 | out = self.norm(out)
30 | out = F.leaky_relu(out, 0.2)
31 | if self.pool:
32 | out = F.avg_pool2d(out, (2, 2))
33 | return out
34 |
35 |
36 | class Discriminator(nn.Module):
37 | """
38 | Discriminator similar to Pix2Pix
39 | """
40 |
41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
42 | sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
43 | super(Discriminator, self).__init__()
44 |
45 | down_blocks = []
46 | for i in range(num_blocks):
47 | down_blocks.append(
48 | DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
49 | min(max_features, block_expansion * (2 ** (i + 1))),
50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
51 | self.down_blocks = nn.ModuleList(down_blocks)
52 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
53 | if sn:
54 | self.conv = nn.utils.spectral_norm(self.conv)
55 | self.use_kp = use_kp
56 | self.kp_variance = kp_variance
57 |
58 | def forward(self, x, kp=None):
59 | feature_maps = []
60 | out = x
61 | if self.use_kp:
62 | heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
63 | out = torch.cat([out, heatmap], dim=1)
64 | # print(out.shape)
65 | for down_block in self.down_blocks:
66 | feature_maps.append(down_block(out))
67 | out = feature_maps[-1]
68 | # print(out.shape)
69 | prediction_map = self.conv(out)
70 |
71 | return feature_maps, prediction_map
72 |
73 |
74 | class MultiScaleDiscriminator(nn.Module):
75 | """
76 | Multi-scale discriminator: one Discriminator instance per scale in ``scales``
77 | """
78 |
79 | def __init__(self, scales=(), **kwargs):
80 | super(MultiScaleDiscriminator, self).__init__()
81 | self.scales = scales
82 | discs = {}
83 | for scale in scales:
84 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
85 | self.discs = nn.ModuleDict(discs)
86 |
87 | def forward(self, x, kp=None):
88 | out_dict = {}
89 | for scale, disc in self.discs.items():
90 | scale = str(scale).replace('-', '.')
91 | key = 'prediction_' + scale
92 | feature_maps, prediction_map = disc(x[key], kp)
93 | out_dict['feature_maps_' + scale] = feature_maps
94 | out_dict['prediction_map_' + scale] = prediction_map
95 | return out_dict
96 |
--------------------------------------------------------------------------------
/modules/keypoint_detector.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 | from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d,Hourglass_2branch
5 | import pdb
6 |
7 | class KPDetector(nn.Module):
8 | """
9 | Detects keypoints. Returns keypoint positions and a Jacobian near each keypoint.
10 | """
11 |
12 | def __init__(self, block_expansion, num_kp, num_channels, max_features,
13 | num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
14 | single_jacobian_map=False, pad=0):
15 | super(KPDetector, self).__init__()
16 | self.predictor = Hourglass(block_expansion, in_features=num_channels,
17 | max_features=max_features, num_blocks=num_blocks)
18 |
19 | self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
20 | padding=pad)
21 |
22 | if estimate_jacobian:
23 | self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
24 | self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
25 | out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
26 | self.jacobian.weight.data.zero_()
27 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
28 | else:
29 | self.jacobian = None
30 |
31 | self.temperature = temperature
32 | self.scale_factor = scale_factor
33 | if self.scale_factor != 1:
34 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
35 |
36 | def gaussian2kp(self, heatmap):
37 | """
38 | Extract the mean (keypoint position) from a heatmap
39 | """
40 | shape = heatmap.shape
41 | heatmap = heatmap.unsqueeze(-1)
42 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
43 | value = (heatmap * grid).sum(dim=(2, 3))
44 | kp = {'value': value}
45 |
46 | return kp
47 |
48 | def forward(self, x):
49 | if self.scale_factor != 1:
50 | x = self.down(x)
51 |         feature_map = self.predictor(x)  # x: (bz, num_channels, 64, 64), e.g. 4 channels in RGB-D mode
52 | prediction = self.kp(feature_map)
53 |
54 | final_shape = prediction.shape
55 | heatmap = prediction.view(final_shape[0], final_shape[1], -1)
56 | heatmap = F.softmax(heatmap / self.temperature, dim=2)
57 | heatmap = heatmap.view(*final_shape)
58 |
59 | out = self.gaussian2kp(heatmap)
60 |
61 | if self.jacobian is not None:
62 | jacobian_map = self.jacobian(feature_map)
63 |             # reshape to (bz, num_jacobian_maps, 4, H, W) before weighting by the heatmap
64 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
65 | final_shape[3])
66 | heatmap = heatmap.unsqueeze(2)
67 |
68 | jacobian = heatmap * jacobian_map
69 | jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
70 | jacobian = jacobian.sum(dim=-1)
71 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
72 | out['jacobian'] = jacobian
73 |
74 | return out
75 |
76 |
--------------------------------------------------------------------------------
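
For reference, a minimal sketch of the soft-argmax performed by KPDetector.gaussian2kp above: the softmaxed heatmap is treated as a spatial probability map, and each keypoint is its expectation over a [-1, 1] coordinate grid. The shapes and temperature are illustrative; the real grid comes from make_coordinate_grid and the temperature from the config.

    import torch
    import torch.nn.functional as F

    B, K, H, W = 1, 10, 58, 58  # e.g. a 64x64 map shrinks to 58x58 after a 7x7 conv with pad=0
    logits = torch.randn(B, K, H, W)
    heatmap = F.softmax(logits.view(B, K, -1) / 0.1, dim=2).view(B, K, H, W)

    xs = torch.linspace(-1, 1, W)
    ys = torch.linspace(-1, 1, H)
    # (H, W, 2) grid whose last dimension is (x, y), mirroring make_coordinate_grid
    grid = torch.stack(torch.meshgrid(xs, ys, indexing='xy'), dim=-1)
    value = (heatmap.unsqueeze(-1) * grid).sum(dim=(2, 3))  # (B, K, 2), coordinates in [-1, 1]
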
/reconstruction.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from logger import Logger, Visualizer
6 | import numpy as np
7 | import imageio
8 | from sync_batchnorm import DataParallelWithCallback
9 | import depth
10 |
11 |
12 | def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
13 | png_dir = os.path.join(log_dir, 'reconstruction/png')
14 | log_dir = os.path.join(log_dir, 'reconstruction')
15 |
16 | if checkpoint is not None:
17 | Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
18 | else:
19 | raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
20 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
21 |
22 | if not os.path.exists(log_dir):
23 | os.makedirs(log_dir)
24 |
25 | if not os.path.exists(png_dir):
26 | os.makedirs(png_dir)
27 |
28 | loss_list = []
29 | if torch.cuda.is_available():
30 | generator = DataParallelWithCallback(generator)
31 | kp_detector = DataParallelWithCallback(kp_detector)
32 |
33 | depth_encoder = depth.ResnetEncoder(18, False).cuda()
34 | depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4)).cuda()
35 | loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth')
36 | loaded_dict_dec = torch.load('depth/models/weights_19/depth.pth')
37 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
38 | depth_encoder.load_state_dict(filtered_dict_enc)
39 | depth_decoder.load_state_dict(loaded_dict_dec)
40 | depth_decoder.eval()
41 | depth_encoder.eval()
42 |
43 | generator.eval()
44 | kp_detector.eval()
45 |
46 | for it, x in tqdm(enumerate(dataloader)):
47 | if config['reconstruction_params']['num_videos'] is not None:
48 | if it > config['reconstruction_params']['num_videos']:
49 | break
50 | with torch.no_grad():
51 | predictions = []
52 | visualizations = []
53 | if torch.cuda.is_available():
54 | x['video'] = x['video'].cuda()
55 |
56 | outputs = depth_decoder(depth_encoder(x['video'][:, :, 0]))
57 | depth_source = outputs[("disp", 0)]
58 | source_rgbd = torch.cat((x['video'][:, :, 0],depth_source),1)
59 |
60 | kp_source = kp_detector(source_rgbd)
61 |
62 | for frame_idx in range(x['video'].shape[2]):
63 | source = x['video'][:, :, 0]
64 |
65 | driving = x['video'][:, :, frame_idx]
66 | outputs = depth_decoder(depth_encoder(driving))
67 | depth_driving = outputs[("disp", 0)]
68 | driving_rgbd = torch.cat((driving,depth_driving),1)
69 |
70 | kp_driving = kp_detector(driving_rgbd)
71 | out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
72 | out['kp_source'] = kp_source
73 | out['kp_driving'] = kp_driving
74 | del out['sparse_deformed']
75 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
76 |
77 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
78 | driving=driving, out=out)
79 | visualizations.append(visualization)
80 |
81 | loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())
82 |
83 | predictions = np.concatenate(predictions, axis=1)
84 | imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
85 |
86 | image_name = x['name'][0] + config['reconstruction_params']['format']
87 | imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
88 |
89 | print("Reconstruction loss: %s" % np.mean(loss_list))
--------------------------------------------------------------------------------
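
A minimal sketch (dummy tensors, illustrative resolution) of the RGB-D keypoint input assembled in reconstruction() above: the face depth network predicts a disparity map for each frame, and that single channel is concatenated to the RGB channels so the keypoint detector receives a 4-channel input.

    import torch

    frame = torch.rand(1, 3, 256, 256)           # one RGB frame in [0, 1]
    disparity = torch.rand(1, 1, 256, 256)       # stands in for outputs[("disp", 0)]
    rgbd = torch.cat((frame, disparity), dim=1)  # (1, 4, 256, 256), passed to kp_detector
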
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | certifi==2021.10.8
3 | cycler==0.11.0
4 | fonttools==4.33.2
5 | grpcio==1.44.0
6 | imageio==2.17.0
7 | importlib-metadata==4.11.3
8 | joblib==1.1.0
9 | kiwisolver==1.4.2
10 | Markdown==3.3.6
11 | matplotlib==3.5.1
12 | networkx==2.6.3
13 | numpy==1.21.6
14 | packaging==21.3
15 | pandas==1.3.5
16 | Pillow==9.1.0
17 | protobuf==3.20.1
18 | pyparsing==3.0.8
19 | python-dateutil==2.8.2
20 | pytz==2022.1
21 | PyWavelets==1.3.0
22 | PyYAML==5.4.1
23 | scikit-image==0.16.2
24 | scikit-learn==1.0.2
25 | scipy==1.7.3
26 | six==1.16.0
27 | sklearn==0.0
28 | tensorboard==1.15.0
29 | threadpoolctl==3.1.0
30 | tifffile==2021.11.2
31 | torch
32 | torchaudio==0.10.1+rocm4.1
33 | torchvision
34 | tqdm==4.64.0
35 | typing_extensions==4.2.0
36 | Werkzeug==2.1.1
37 | zipp==3.8.0
38 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 |
3 | matplotlib.use('Agg')
4 |
5 | import os, sys
6 | import yaml
7 | from argparse import ArgumentParser
8 | from time import gmtime, strftime
9 | from shutil import copy
10 |
11 | from frames_dataset import FramesDataset
12 | import pdb
13 | # from modules.generator import OcclusionAwareGenerator
14 | import modules.generator as generator
15 | from modules.discriminator import MultiScaleDiscriminator
16 | # from modules.keypoint_detector import KPDetector
17 | import modules.keypoint_detector as KPD
18 | import torch.distributed as dist
19 | from torch.nn.parallel import DistributedDataParallel as DDP
20 |
21 | import torch
22 | from torch.utils.tensorboard import SummaryWriter
23 | from train import train
24 | # from reconstruction import reconstruction
25 | from animate import animate
26 | import random
27 | import numpy as np
28 |
29 |
30 | if __name__ == "__main__":
31 |
32 | if sys.version_info[0] < 3:
33 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
34 |
35 | parser = ArgumentParser()
36 | parser.add_argument("--config", required=True, help="path to config")
37 | parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
38 | parser.add_argument("--log_dir", default='log', help="path to log into")
39 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
40 | parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
41 | help="Names of the devices comma separated.")
42 | parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
43 | parser.add_argument("--local_rank", type=int)
44 | parser.add_argument("--use_depth",action='store_true',help='depth mode')
45 | parser.add_argument("--rgbd",action='store_true',help='rgbd mode')
46 | parser.add_argument("--kp_prior",action='store_true',help='use kp_prior in final objective function')
47 |
48 | # alter model
49 |     parser.add_argument("--generator", required=True, help='the type of generator')
50 | parser.add_argument("--kp_detector",default='KPDetector',type=str,help='the type of KPDetector')
51 | parser.add_argument("--GFM",default='GeneratorFullModel',help='the type of GeneratorFullModel')
52 |
53 | parser.add_argument("--batchsize",type=int, default=-1,help='user defined batchsize')
54 | parser.add_argument("--kp_num",type=int, default=-1,help='user defined keypoint number')
55 | parser.add_argument("--kp_distance",type=int, default=10,help='the weight of kp_distance loss')
56 | parser.add_argument("--depth_constraint",type=int, default=0,help='the weight of depth_constraint loss')
57 |
58 |     parser.add_argument("--name", type=str, default='', help='user-defined suffix appended to the log directory name')
59 |
60 | parser.set_defaults(verbose=False)
61 | opt = parser.parse_args()
62 | with open(opt.config) as f:
63 |         config = yaml.load(f, Loader=yaml.FullLoader)
64 |
65 | if opt.checkpoint is not None:
66 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
67 | else:
68 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
69 | log_dir += opt.name
70 |
71 |
72 | print("Training...")
73 |
74 | dist.init_process_group(backend='nccl', init_method='env://')
75 | torch.cuda.set_device(opt.local_rank)
76 | device=torch.device("cuda",opt.local_rank)
77 | config['train_params']['loss_weights']['depth_constraint'] = opt.depth_constraint
78 | config['train_params']['loss_weights']['kp_distance'] = opt.kp_distance
79 | if opt.kp_prior:
80 | config['train_params']['loss_weights']['kp_distance'] = 0
81 | config['train_params']['loss_weights']['kp_prior'] = 10
82 | if opt.batchsize != -1:
83 | config['train_params']['batch_size'] = opt.batchsize
84 | if opt.kp_num != -1:
85 | config['model_params']['common_params']['num_kp'] = opt.kp_num
86 | # create generator
87 | generator = getattr(generator, opt.generator)(**config['model_params']['generator_params'],
88 | **config['model_params']['common_params'])
89 | generator.to(device)
90 | if opt.verbose:
91 | print(generator)
92 | generator= torch.nn.SyncBatchNorm.convert_sync_batchnorm(generator)
93 |
94 | # create discriminator
95 | discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
96 | **config['model_params']['common_params'])
97 |
98 | discriminator.to(device)
99 | if opt.verbose:
100 | print(discriminator)
101 | discriminator= torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
102 |
103 | # create kp_detector
104 | if opt.use_depth:
105 | config['model_params']['common_params']['num_channels'] = 1
106 | if opt.rgbd:
107 | config['model_params']['common_params']['num_channels'] = 4
108 |
109 | kp_detector = getattr(KPD, opt.kp_detector)(**config['model_params']['kp_detector_params'],
110 | **config['model_params']['common_params'])
111 | kp_detector.to(device)
112 | if opt.verbose:
113 | print(kp_detector)
114 | kp_detector= torch.nn.SyncBatchNorm.convert_sync_batchnorm(kp_detector)
115 |
116 | kp_detector = DDP(kp_detector,device_ids=[opt.local_rank],broadcast_buffers=False)
117 | discriminator = DDP(discriminator,device_ids=[opt.local_rank],broadcast_buffers=False)
118 | generator = DDP(generator,device_ids=[opt.local_rank],broadcast_buffers=False)
119 |
120 | dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])
121 | if not os.path.exists(log_dir):
122 | os.makedirs(log_dir)
123 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
124 | copy(opt.config, log_dir)
125 |
126 | if not os.path.exists(os.path.join(log_dir,'log')):
127 | os.makedirs(os.path.join(log_dir,'log'))
128 | writer = SummaryWriter(os.path.join(log_dir,'log'))
129 | if opt.mode == 'train':
130 | train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.local_rank,device,opt,writer)
--------------------------------------------------------------------------------
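
Because run.py above reads --local_rank and calls dist.init_process_group(backend='nccl', init_method='env://'), it is meant to be started through PyTorch's distributed launcher. A typical invocation might look like the line below; the process count, generator class name, and run name are placeholders to adapt to your setup:

    python -m torch.distributed.launch --nproc_per_node=2 run.py \
        --config config/vox-adv-256.yaml --generator <GeneratorClass> --name <run_name>
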
/run_dataparallel.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys
3 | import yaml
4 | from argparse import ArgumentParser
5 | from shutil import copy
6 |
7 | from frames_dataset import FramesDataset
8 | import pdb
9 | import modules.generator as generator
10 | from modules.discriminator import MultiScaleDiscriminator
11 | import modules.keypoint_detector as KPD
12 |
13 | import torch
14 | from torch.utils.tensorboard import SummaryWriter
15 | from train_dataparallel import train
16 | # from reconstruction import reconstruction
17 | from animate import animate
18 | import random
19 | import numpy as np
20 |
21 | if __name__ == "__main__":
22 |
23 | if sys.version_info[0] < 3:
24 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
25 |
26 | parser = ArgumentParser()
27 | parser.add_argument("--config", required=True, help="path to config")
28 | parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
29 | parser.add_argument("--log_dir", default='log', help="path to log into")
30 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
31 | parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
32 | help="Names of the devices comma separated.")
33 | parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
34 | parser.add_argument("--use_depth",action='store_true',help='depth mode')
35 | parser.add_argument("--rgbd",action='store_true',help='rgbd mode')
36 | parser.add_argument("--kp_prior",action='store_true',help='use kp_prior in final objective function')
37 |
38 | # alter model
39 |     parser.add_argument("--generator", required=True, help='the type of generator')
40 | parser.add_argument("--kp_detector",default='KPDetector',type=str,help='the type of KPDetector')
41 | parser.add_argument("--GFM",default='GeneratorFullModel',help='the type of GeneratorFullModel')
42 |
43 | parser.add_argument("--batchsize",type=int, default=-1,help='user defined batchsize')
44 | parser.add_argument("--kp_num",type=int, default=-1,help='user defined keypoint number')
45 | parser.add_argument("--kp_distance",type=int, default=10,help='the weight of kp_distance loss')
46 | parser.add_argument("--depth_constraint",type=int, default=0,help='the weight of depth_constraint loss')
47 |
48 |     parser.add_argument("--name", type=str, default='', help='user-defined suffix appended to the log directory name')
49 |
50 | parser.set_defaults(verbose=False)
51 | opt = parser.parse_args()
52 | with open(opt.config) as f:
53 |         config = yaml.load(f, Loader=yaml.FullLoader)
54 |
55 | if opt.checkpoint is not None:
56 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
57 | else:
58 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
59 | log_dir += opt.name
60 |
61 |
62 | print("Training...")
63 |
64 | config['train_params']['loss_weights']['depth_constraint'] = opt.depth_constraint
65 | config['train_params']['loss_weights']['kp_distance'] = opt.kp_distance
66 | if opt.kp_prior:
67 | config['train_params']['loss_weights']['kp_distance'] = 0
68 | config['train_params']['loss_weights']['kp_prior'] = 10
69 | if opt.batchsize != -1:
70 | config['train_params']['batch_size'] = opt.batchsize
71 | if opt.kp_num != -1:
72 | config['model_params']['common_params']['num_kp'] = opt.kp_num
73 | # create generator
74 | generator = getattr(generator, opt.generator)(**config['model_params']['generator_params'],
75 | **config['model_params']['common_params'])
76 | if torch.cuda.is_available():
77 | generator.to(opt.device_ids[0])
78 | if opt.verbose:
79 | print(generator)
80 |
81 | # create discriminator
82 | discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
83 | **config['model_params']['common_params'])
84 |
85 | if torch.cuda.is_available():
86 | discriminator.to(opt.device_ids[0])
87 | if opt.verbose:
88 | print(discriminator)
89 |
90 | # create kp_detector
91 | if opt.use_depth:
92 | config['model_params']['common_params']['num_channels'] = 1
93 | if opt.rgbd:
94 | config['model_params']['common_params']['num_channels'] = 4
95 |
96 | kp_detector = getattr(KPD, opt.kp_detector)(**config['model_params']['kp_detector_params'],
97 | **config['model_params']['common_params'])
98 | if torch.cuda.is_available():
99 | kp_detector.to(opt.device_ids[0])
100 | if opt.verbose:
101 | print(kp_detector)
102 |
103 | dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])
104 | if not os.path.exists(log_dir):
105 | os.makedirs(log_dir)
106 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
107 | copy(opt.config, log_dir)
108 |
109 | if not os.path.exists(os.path.join(log_dir,'log')):
110 | os.makedirs(os.path.join(log_dir,'log'))
111 | writer = SummaryWriter(os.path.join(log_dir,'log'))
112 | if opt.mode == 'train':
113 | train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids, opt,writer)
--------------------------------------------------------------------------------
/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, data parallel triggers a callback on each module; every slave device should
60 |     call `register_slave(identifier)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 |     and passed to the registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 |             identifier: an identifier, usually the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         the callback is invoked to compute the message to be sent back to each device
107 |         (including the master device).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
--------------------------------------------------------------------------------
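
A minimal single-process sketch of the SyncMaster handshake documented above. No slave copies are registered here, so only the master's own message reaches the callback; in _SynchronizedBatchNorm, replicas obtain SlavePipes via register_slave during replication and call run_slave at forward time.

    from sync_batchnorm.comm import SyncMaster

    def master_callback(intermediates):
        # intermediates: list of (copy_id, message); return one (copy_id, result) per copy.
        total = sum(msg for _, msg in intermediates)
        return [(copy_id, total) for copy_id, _ in intermediates]

    master = SyncMaster(master_callback)
    print(master.run_master(3))  # -> 3, since no slave copies contributed messages
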
/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all replicated modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55 | original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import torch
3 |
4 | from torch.utils.data import DataLoader
5 |
6 | from logger import Logger
7 | from modules.model import GeneratorFullModel, DiscriminatorFullModel
8 | import modules.model as MODEL
9 | from tqdm import tqdm
10 | from torch.optim.lr_scheduler import MultiStepLR
11 | from torch.nn.parallel import DistributedDataParallel as DDP
12 | import pdb
13 | from sync_batchnorm import DataParallelWithCallback
14 | from evaluation_dataset import EvaluationDataset
15 |
16 | from frames_dataset import DatasetRepeater
17 |
18 |
19 | def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, rank,device,opt,writer):
20 | train_params = config['train_params']
21 |
22 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
23 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
24 | optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))
25 |
26 | if checkpoint is not None:
27 | start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
28 | optimizer_generator, optimizer_discriminator,
29 | None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
30 | else:
31 | start_epoch = 0
32 |
33 | scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
34 | last_epoch=start_epoch - 1)
35 | scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
36 | last_epoch=start_epoch - 1)
37 | scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
38 | last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))
39 |
40 |     if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
41 | dataset = DatasetRepeater(dataset, train_params['num_repeats'])
42 | sampler = torch.utils.data.distributed.DistributedSampler(dataset,num_replicas=torch.cuda.device_count(),rank=rank)
43 | dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=False, num_workers=16,sampler=sampler, drop_last=True)
44 |
45 |
46 | generator_full = getattr(MODEL,opt.GFM)(kp_detector, generator, discriminator, train_params,opt)
47 | discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
48 | test_dataset = EvaluationDataset(dataroot='/data/fhongac/origDataset/vox1_frames',pairs_list='data/vox_evaluation.csv')
49 | test_dataloader = torch.utils.data.DataLoader(
50 | test_dataset,
51 | batch_size = 1,
52 | shuffle=False,
53 | num_workers=4)
54 | with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
55 | for epoch in trange(start_epoch, train_params['num_epochs']):
56 |             # distributed training: let the sampler reshuffle for this epoch
57 | sampler.set_epoch(epoch)
58 | total = len(dataloader)
59 | epoch_train_loss = 0
60 | generator.train(), discriminator.train(), kp_detector.train()
61 | with tqdm(total=total) as par:
62 | for i,x in enumerate(dataloader):
63 | x['source'] = x['source'].to(device)
64 | x['driving'] = x['driving'].to(device)
65 | losses_generator, generated = generator_full(x)
66 |
67 | loss_values = [val.mean() for val in losses_generator.values()]
68 | loss = sum(loss_values)
69 | loss.backward()
70 | optimizer_generator.step()
71 | optimizer_generator.zero_grad()
72 | optimizer_kp_detector.step()
73 | optimizer_kp_detector.zero_grad()
74 | epoch_train_loss+=loss.item()
75 |
76 | if train_params['loss_weights']['generator_gan'] != 0:
77 | optimizer_discriminator.zero_grad()
78 | losses_discriminator = discriminator_full(x, generated)
79 | loss_values = [val.mean() for val in losses_discriminator.values()]
80 | loss = sum(loss_values)
81 |
82 | loss.backward()
83 | optimizer_discriminator.step()
84 | optimizer_discriminator.zero_grad()
85 | else:
86 | losses_discriminator = {}
87 |
88 | losses_generator.update(losses_discriminator)
89 | losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
90 | # for k,v in losses.items():
91 | # writer.add_scalar(k, v, total*epoch+i)
92 | logger.log_iter(losses=losses)
93 | par.update(1)
94 | epoch_train_loss = epoch_train_loss/total
95 | if (epoch + 1) % train_params['checkpoint_freq'] == 0:
96 | writer.add_scalar('epoch_train_loss', epoch_train_loss, epoch)
97 | scheduler_generator.step()
98 | scheduler_discriminator.step()
99 | scheduler_kp_detector.step()
100 | logger.log_epoch(epoch, {'generator': generator,
101 | 'discriminator': discriminator,
102 | 'kp_detector': kp_detector,
103 | 'optimizer_generator': optimizer_generator,
104 | 'optimizer_discriminator': optimizer_discriminator,
105 | 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
106 | generator.eval(), discriminator.eval(), kp_detector.eval()
107 | if (epoch + 1) % train_params['checkpoint_freq'] == 0:
108 | epoch_eval_loss = 0
109 | for i, data in tqdm(enumerate(test_dataloader)):
110 | data['source'] = data['source'].cuda()
111 | data['driving'] = data['driving'].cuda()
112 | losses_generator, generated = generator_full(data)
113 | loss_values = [val.mean() for val in losses_generator.values()]
114 | loss = sum(loss_values)
115 | epoch_eval_loss+=loss.item()
116 | epoch_eval_loss = epoch_eval_loss/len(test_dataloader)
117 | writer.add_scalar('epoch_eval_loss', epoch_eval_loss, epoch)
118 |
--------------------------------------------------------------------------------
/train_dataparallel.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import torch
3 |
4 | from torch.utils.data import DataLoader
5 |
6 | from logger import Logger
7 | from modules.model_dataparallel import DiscriminatorFullModel
8 | import modules.model_dataparallel as MODEL
9 | from tqdm import tqdm
10 | from torch.optim.lr_scheduler import MultiStepLR
11 | import pdb
12 | from sync_batchnorm import DataParallelWithCallback
13 | from evaluation_dataset import EvaluationDataset
14 |
15 | from frames_dataset import DatasetRepeater
16 |
17 |
18 | def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids,opt,writer):
19 | train_params = config['train_params']
20 |
21 | optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
22 | optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
23 | optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))
24 |
25 | if checkpoint is not None:
26 | start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
27 | optimizer_generator, optimizer_discriminator,
28 | None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
29 | else:
30 | start_epoch = 0
31 |
32 | scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
33 | last_epoch=start_epoch - 1)
34 | scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
35 | last_epoch=start_epoch - 1)
36 | scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
37 | last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))
38 |
39 |     if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
40 | dataset = DatasetRepeater(dataset, train_params['num_repeats'])
41 | dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=16,drop_last=True)
42 |
43 |
44 | generator_full = getattr(MODEL,opt.GFM)(kp_detector, generator, discriminator, train_params,opt)
45 | discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
46 | test_dataset = EvaluationDataset(dataroot='/data/fhongac/origDataset/vox1_frames',pairs_list='data/vox_evaluation.csv')
47 | test_dataloader = torch.utils.data.DataLoader(
48 | test_dataset,
49 | batch_size = 1,
50 | shuffle=False,
51 | num_workers=4)
52 | if torch.cuda.is_available():
53 | generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
54 | discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)
55 |
56 | with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
57 | for epoch in trange(start_epoch, train_params['num_epochs']):
58 | #parallel
59 | total = len(dataloader)
60 | epoch_train_loss = 0
61 | generator.train(), discriminator.train(), kp_detector.train()
62 | with tqdm(total=total) as par:
63 | for i,x in enumerate(dataloader):
64 | # x['source'] = x['source'].to(device)
65 | # x['driving'] = x['driving'].to(device)
66 | losses_generator, generated = generator_full(x)
67 |
68 | loss_values = [val.mean() for val in losses_generator.values()]
69 | loss = sum(loss_values)
70 | loss.backward()
71 | optimizer_generator.step()
72 | optimizer_generator.zero_grad()
73 | optimizer_kp_detector.step()
74 | optimizer_kp_detector.zero_grad()
75 | epoch_train_loss+=loss.item()
76 |
77 | if train_params['loss_weights']['generator_gan'] != 0:
78 | optimizer_discriminator.zero_grad()
79 | losses_discriminator = discriminator_full(x, generated)
80 | loss_values = [val.mean() for val in losses_discriminator.values()]
81 | loss = sum(loss_values)
82 |
83 | loss.backward()
84 | optimizer_discriminator.step()
85 | optimizer_discriminator.zero_grad()
86 | else:
87 | losses_discriminator = {}
88 |
89 | losses_generator.update(losses_discriminator)
90 | losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
91 | # for k,v in losses.items():
92 | # writer.add_scalar(k, v, total*epoch+i)
93 | logger.log_iter(losses=losses)
94 | par.update(1)
95 | epoch_train_loss = epoch_train_loss/total
96 | if (epoch + 1) % train_params['checkpoint_freq'] == 0:
97 | writer.add_scalar('epoch_train_loss', epoch_train_loss, epoch)
98 | scheduler_generator.step()
99 | scheduler_discriminator.step()
100 | scheduler_kp_detector.step()
101 | logger.log_epoch(epoch, {'generator': generator,
102 | 'discriminator': discriminator,
103 | 'kp_detector': kp_detector,
104 | 'optimizer_generator': optimizer_generator,
105 | 'optimizer_discriminator': optimizer_discriminator,
106 | 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
107 | generator.eval(), discriminator.eval(), kp_detector.eval()
108 | if (epoch + 1) % train_params['checkpoint_freq'] == 0:
109 | epoch_eval_loss = 0
110 | for i, data in tqdm(enumerate(test_dataloader)):
111 | data['source'] = data['source'].cuda()
112 | data['driving'] = data['driving'].cuda()
113 | losses_generator, generated = generator_full(data)
114 | loss_values = [val.mean() for val in losses_generator.values()]
115 | loss = sum(loss_values)
116 | epoch_eval_loss+=loss.item()
117 | epoch_eval_loss = epoch_eval_loss/len(test_dataloader)
118 | writer.add_scalar('epoch_eval_loss', epoch_eval_loss, epoch)
119 |
--------------------------------------------------------------------------------